mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-06-22 08:38:13 +02:00
Initial Commit 🚀 🚀
This commit is contained in:
commit
4f2a629340
444 changed files with 76863 additions and 0 deletions
1
api/tasks/__init__.py
Normal file
1
api/tasks/__init__.py
Normal file
|
|
@ -0,0 +1 @@
|
|||
|
||||
119
api/tasks/arq.py
Normal file
119
api/tasks/arq.py
Normal file
|
|
@ -0,0 +1,119 @@
|
|||
"""ARQ worker configuration - setup logging before importing tasks"""
|
||||
|
||||
import ssl
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from api.constants import REDIS_URL
|
||||
|
||||
# Setup logging - this is now idempotent and safe to call multiple times
|
||||
from api.logging_config import setup_logging
|
||||
from api.tasks.function_names import FunctionNames
|
||||
|
||||
logging_queue_listener = setup_logging()
|
||||
|
||||
# Now import ARQ and task dependencies
|
||||
from arq import create_pool
|
||||
from arq.connections import ArqRedis, RedisSettings
|
||||
|
||||
from api.tasks.workflow_run_cost import calculate_workflow_run_cost
|
||||
|
||||
parsed_url = urlparse(REDIS_URL)
|
||||
|
||||
# Check if we're using TLS (rediss://)
|
||||
use_ssl = parsed_url.scheme == "rediss"
|
||||
|
||||
# Create SSL context if using rediss://
|
||||
ssl_context = None
|
||||
if use_ssl:
|
||||
ssl_context = ssl.create_default_context(ssl.Purpose.SERVER_AUTH)
|
||||
ssl_context.check_hostname = False
|
||||
ssl_context.verify_mode = ssl.CERT_NONE
|
||||
|
||||
REDIS_SETTINGS = RedisSettings(
|
||||
host=parsed_url.hostname or "localhost",
|
||||
port=parsed_url.port or 6379,
|
||||
password=parsed_url.password,
|
||||
conn_timeout=10,
|
||||
ssl=use_ssl,
|
||||
ssl_ca_certs=None if not use_ssl else None,
|
||||
ssl_certfile=None,
|
||||
ssl_keyfile=None,
|
||||
ssl_check_hostname=False if use_ssl else None,
|
||||
)
|
||||
|
||||
from api.tasks.campaign_tasks import (
|
||||
monitor_campaign_progress,
|
||||
process_campaign_batch,
|
||||
sync_campaign_source,
|
||||
)
|
||||
from api.tasks.run_integrations import run_integrations_post_workflow_run
|
||||
from api.tasks.s3_upload import (
|
||||
upload_audio_to_s3,
|
||||
upload_transcript_to_s3,
|
||||
upload_voicemail_audio_to_s3,
|
||||
)
|
||||
|
||||
|
||||
class WorkerSettings:
|
||||
functions = [
|
||||
calculate_workflow_run_cost,
|
||||
run_integrations_post_workflow_run,
|
||||
upload_audio_to_s3,
|
||||
upload_transcript_to_s3,
|
||||
upload_voicemail_audio_to_s3,
|
||||
sync_campaign_source,
|
||||
process_campaign_batch,
|
||||
monitor_campaign_progress,
|
||||
]
|
||||
cron_jobs = []
|
||||
redis_settings = REDIS_SETTINGS
|
||||
max_jobs = 10
|
||||
|
||||
|
||||
LOG_CONFIG = {
|
||||
"version": 1,
|
||||
"disable_existing_loggers": False,
|
||||
# --- Handlers ---
|
||||
"handlers": {
|
||||
"console": { # everything goes to stdout
|
||||
"class": "logging.StreamHandler",
|
||||
"stream": "ext://sys.stdout",
|
||||
"level": "WARNING", # only WARNING and above
|
||||
"formatter": "simple",
|
||||
},
|
||||
},
|
||||
# --- Formatters (optional) ---
|
||||
"formatters": {
|
||||
"simple": {
|
||||
"format": "%(asctime)s | %(levelname)s | %(name)s | %(message)s",
|
||||
},
|
||||
},
|
||||
# --- Root logger ---
|
||||
"root": {
|
||||
"handlers": ["console"],
|
||||
"level": "WARNING",
|
||||
},
|
||||
# --- Optionally silence Arq itself explicitly ---
|
||||
"loggers": {
|
||||
"arq": { # arq.* loggers
|
||||
"level": "WARNING",
|
||||
"handlers": ["console"],
|
||||
"propagate": False,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
_redis_pool: ArqRedis | None = None
|
||||
|
||||
|
||||
async def get_arq_redis() -> ArqRedis:
|
||||
global _redis_pool
|
||||
if _redis_pool is None:
|
||||
_redis_pool = await create_pool(REDIS_SETTINGS)
|
||||
return _redis_pool
|
||||
|
||||
|
||||
async def enqueue_job(function_name: FunctionNames, *args):
|
||||
redis = await get_arq_redis()
|
||||
await redis.enqueue_job(function_name, *args)
|
||||
199
api/tasks/campaign_tasks.py
Normal file
199
api/tasks/campaign_tasks.py
Normal file
|
|
@ -0,0 +1,199 @@
|
|||
from datetime import UTC, datetime
|
||||
from typing import Dict
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from api.db import db_client
|
||||
from api.enums import RedisChannel
|
||||
from api.services.campaign.call_dispatcher import campaign_call_dispatcher
|
||||
from api.services.campaign.campaign_event_protocol import BatchFailedEvent
|
||||
from api.services.campaign.campaign_event_publisher import (
|
||||
get_campaign_event_publisher,
|
||||
)
|
||||
from api.services.campaign.source_sync import get_sync_service
|
||||
|
||||
|
||||
async def sync_campaign_source(ctx: Dict, campaign_id: int) -> None:
|
||||
"""
|
||||
Phase 1: Syncs data from configured source to queued_runs table
|
||||
- Campaign state should already be 'syncing'
|
||||
- Determines source type from campaign configuration
|
||||
- Fetches data via appropriate sync service (Google Sheets, HubSpot, etc.)
|
||||
- Creates queued_run entries with unique source_uuid
|
||||
- Updates campaign total_rows
|
||||
- Transitions campaign state to 'running' on success
|
||||
- Enqueues process_campaign_batch tasks
|
||||
"""
|
||||
logger.info(f"Starting source sync for campaign {campaign_id}")
|
||||
|
||||
try:
|
||||
# Get campaign
|
||||
campaign = await db_client.get_campaign_by_id(campaign_id)
|
||||
if not campaign:
|
||||
raise ValueError(f"Campaign {campaign_id} not found")
|
||||
|
||||
# Get appropriate sync service
|
||||
sync_service = get_sync_service(campaign.source_type)
|
||||
|
||||
# Sync source data
|
||||
rows_synced = await sync_service.sync_source_data(campaign_id)
|
||||
|
||||
if rows_synced == 0:
|
||||
# No data to process
|
||||
await db_client.update_campaign(
|
||||
campaign_id=campaign_id,
|
||||
state="completed",
|
||||
completed_at=datetime.now(UTC),
|
||||
source_sync_status="completed",
|
||||
source_last_synced_at=datetime.now(UTC),
|
||||
)
|
||||
logger.info(f"Campaign {campaign_id} completed with no data to process")
|
||||
return
|
||||
|
||||
# Update campaign state to running
|
||||
await db_client.update_campaign(
|
||||
campaign_id=campaign_id,
|
||||
state="running",
|
||||
source_sync_status="completed",
|
||||
source_last_synced_at=datetime.now(UTC),
|
||||
)
|
||||
|
||||
# Publish sync completed event - orchestrator will schedule first batch
|
||||
publisher = await get_campaign_event_publisher()
|
||||
await publisher.publish_sync_completed(
|
||||
campaign_id=campaign_id,
|
||||
total_rows=rows_synced,
|
||||
source_type=campaign.source_type,
|
||||
source_id=campaign.source_id,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Campaign {campaign_id} source sync completed, {rows_synced} rows synced"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error syncing campaign {campaign_id} source: {e}")
|
||||
|
||||
# Update campaign with error
|
||||
await db_client.update_campaign(
|
||||
campaign_id=campaign_id,
|
||||
state="failed",
|
||||
source_sync_status="failed",
|
||||
source_sync_error=str(e),
|
||||
)
|
||||
raise
|
||||
|
||||
|
||||
async def process_campaign_batch(
|
||||
ctx: Dict, campaign_id: int, batch_size: int = 10
|
||||
) -> None:
|
||||
"""
|
||||
Phase 2: Processes a batch of queued runs
|
||||
- Fetches next batch of 'queued' runs (including due retries)
|
||||
- Creates workflow runs with context variables
|
||||
- Initiates Twilio calls with rate limiting
|
||||
- Updates queued_run state to 'processed'
|
||||
- Updates campaign.processed_rows counter
|
||||
- Publishes batch_completed event for orchestrator
|
||||
"""
|
||||
logger.info(f"Processing batch for campaign {campaign_id}, batch_size={batch_size}")
|
||||
|
||||
failed_count = 0
|
||||
try:
|
||||
# Process the batch
|
||||
processed_count = await campaign_call_dispatcher.process_batch(
|
||||
campaign_id=campaign_id, batch_size=batch_size
|
||||
)
|
||||
|
||||
# Publish batch completed event - orchestrator will handle next batch scheduling
|
||||
publisher = await get_campaign_event_publisher()
|
||||
await publisher.publish_batch_completed(
|
||||
campaign_id=campaign_id,
|
||||
processed_count=processed_count,
|
||||
failed_count=failed_count,
|
||||
batch_size=batch_size,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Campaign {campaign_id} batch completed: processed={processed_count}, "
|
||||
f"failed={failed_count}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing batch for campaign {campaign_id}: {e}")
|
||||
|
||||
# Publish batch failed event
|
||||
publisher = await get_campaign_event_publisher()
|
||||
event = BatchFailedEvent(
|
||||
campaign_id=campaign_id,
|
||||
error=str(e),
|
||||
processed_count=0,
|
||||
)
|
||||
await publisher.redis.publish(
|
||||
RedisChannel.CAMPAIGN_EVENTS.value, event.to_json()
|
||||
)
|
||||
|
||||
# Update campaign state to failed
|
||||
await db_client.update_campaign(campaign_id=campaign_id, state="failed")
|
||||
raise
|
||||
|
||||
|
||||
async def monitor_campaign_progress(ctx: Dict, campaign_id: int) -> None:
|
||||
"""
|
||||
Phase 3: Monitors campaign completion
|
||||
- Checks if all queued runs are in 'processed' state
|
||||
- Queries workflow_runs for final call statistics
|
||||
- Updates campaign state to 'completed'
|
||||
- Calculates total calls made, successful, failed
|
||||
- Triggers post-campaign integrations
|
||||
"""
|
||||
logger.info(f"Monitoring progress for campaign {campaign_id}")
|
||||
|
||||
try:
|
||||
# Get campaign
|
||||
campaign = await db_client.get_campaign_by_id(campaign_id)
|
||||
if not campaign:
|
||||
raise ValueError(f"Campaign {campaign_id} not found")
|
||||
|
||||
# Check if all runs are processed
|
||||
pending_runs = await db_client.count_queued_runs(
|
||||
campaign_id=campaign_id, state="queued"
|
||||
)
|
||||
|
||||
if pending_runs > 0:
|
||||
logger.info(f"Campaign {campaign_id} still has {pending_runs} pending runs")
|
||||
return
|
||||
|
||||
# All runs processed, mark campaign as completed
|
||||
await db_client.update_campaign(
|
||||
campaign_id=campaign_id, state="completed", completed_at=datetime.now(UTC)
|
||||
)
|
||||
|
||||
# Calculate statistics
|
||||
workflow_runs = await db_client.get_workflow_runs_by_campaign(campaign_id)
|
||||
|
||||
total_calls = len(workflow_runs)
|
||||
successful_calls = 0
|
||||
failed_calls = 0
|
||||
|
||||
for run in workflow_runs:
|
||||
callbacks = run.logs.get("twilio_status_callbacks", [])
|
||||
if callbacks:
|
||||
final_status = callbacks[-1].get("status", "").lower()
|
||||
if final_status == "completed":
|
||||
successful_calls += 1
|
||||
elif final_status in ["failed", "busy", "no-answer"]:
|
||||
failed_calls += 1
|
||||
|
||||
logger.info(
|
||||
f"Campaign {campaign_id} completed: "
|
||||
f"Total calls: {total_calls}, "
|
||||
f"Successful: {successful_calls}, "
|
||||
f"Failed: {failed_calls}"
|
||||
)
|
||||
|
||||
# TODO: Trigger post-campaign integrations if configured
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error monitoring campaign {campaign_id}: {e}")
|
||||
raise
|
||||
9
api/tasks/function_names.py
Normal file
9
api/tasks/function_names.py
Normal file
|
|
@ -0,0 +1,9 @@
|
|||
class FunctionNames:
|
||||
CALCULATE_WORKFLOW_RUN_COST = "calculate_workflow_run_cost"
|
||||
RUN_INTEGRATIONS_POST_WORKFLOW_RUN = "run_integrations_post_workflow_run"
|
||||
UPLOAD_AUDIO_TO_S3 = "upload_audio_to_s3"
|
||||
UPLOAD_TRANSCRIPT_TO_S3 = "upload_transcript_to_s3"
|
||||
UPLOAD_VOICEMAIL_AUDIO_TO_S3 = "upload_voicemail_audio_to_s3"
|
||||
SYNC_CAMPAIGN_SOURCE = "sync_campaign_source"
|
||||
PROCESS_CAMPAIGN_BATCH = "process_campaign_batch"
|
||||
MONITOR_CAMPAIGN_PROGRESS = "monitor_campaign_progress"
|
||||
227
api/tasks/run_integrations.py
Normal file
227
api/tasks/run_integrations.py
Normal file
|
|
@ -0,0 +1,227 @@
|
|||
import os
|
||||
|
||||
import aiohttp
|
||||
import httpx
|
||||
from loguru import logger
|
||||
from pipecat.utils.context import set_current_run_id
|
||||
|
||||
from api.db import db_client
|
||||
from api.db.models import IntegrationModel
|
||||
from api.enums import OrganizationConfigurationKey, WorkflowRunMode
|
||||
from api.utils.template_renderer import render_template
|
||||
|
||||
|
||||
async def run_integrations_post_workflow_run(ctx, workflow_run_id: int):
|
||||
"""
|
||||
Run integrations after a workflow run completes.
|
||||
|
||||
This function:
|
||||
1. Gets the workflow run and its gathered_context
|
||||
2. Determines the organization_id through the workflow -> user -> organization chain
|
||||
3. Fetches all active integrations for that organization
|
||||
4. For Slack integrations, sends the gathered_context to the webhook URL
|
||||
|
||||
Args:
|
||||
workflow_run_id: The ID of the completed workflow run
|
||||
"""
|
||||
# Set the workflow_run_id in context variable for consistent logging format
|
||||
set_current_run_id(workflow_run_id)
|
||||
logger.info("Running integrations for workflow run")
|
||||
|
||||
try:
|
||||
# Step 1: Get workflow run details with gathered_context using DB client
|
||||
workflow_run, organization_id = await db_client.get_workflow_run_with_context(
|
||||
workflow_run_id
|
||||
)
|
||||
|
||||
if not workflow_run:
|
||||
logger.error("Workflow run not found")
|
||||
return
|
||||
|
||||
if not workflow_run.workflow:
|
||||
logger.error("Workflow not found for workflow run")
|
||||
return
|
||||
|
||||
if not workflow_run.workflow.user:
|
||||
logger.error("User not found for workflow run")
|
||||
return
|
||||
|
||||
gathered_context = workflow_run.gathered_context
|
||||
initial_context = workflow_run.initial_context
|
||||
|
||||
if not gathered_context:
|
||||
logger.info("No gathered context for workflow run, skipping integrations")
|
||||
return
|
||||
|
||||
# Check if workflow run mode is stasis and sync with vendor
|
||||
if workflow_run.mode == WorkflowRunMode.STASIS.value:
|
||||
await _sync_vendor_data(initial_context, gathered_context)
|
||||
|
||||
# Step 2: Check if organization_id is available
|
||||
if not organization_id:
|
||||
logger.warning(
|
||||
f"No organization found for workflow run, skipping integrations"
|
||||
)
|
||||
return
|
||||
|
||||
logger.debug(f"Found organization_id {organization_id} for workflow run")
|
||||
|
||||
# Step 3: Get all active integrations for the organization using DB client
|
||||
integrations = await db_client.get_active_integrations_by_organization(
|
||||
organization_id
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Found {len(integrations)} active integrations for organization {organization_id}"
|
||||
)
|
||||
|
||||
# Step 4: Process each integration
|
||||
for integration in integrations:
|
||||
await _process_integration(integration, gathered_context)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error running integrations for workflow run: {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
async def _sync_vendor_data(initial_context: dict, gathered_context: dict):
|
||||
"""
|
||||
Sync data with external vendor for stasis mode workflow runs.
|
||||
|
||||
Args:
|
||||
initial_context: The initial context containing lead_id
|
||||
gathered_context: The gathered context containing mapped_call_disposition
|
||||
"""
|
||||
if not os.getenv("ARI_DATA_SYNCING_URI"):
|
||||
logger.info("ARI_DATA_SYNCING_URI not configured, skipping vendor sync")
|
||||
return
|
||||
|
||||
try:
|
||||
lead_id = initial_context.get("lead_id")
|
||||
status = gathered_context.get("mapped_call_disposition")
|
||||
|
||||
if lead_id and status:
|
||||
ari_data_uri = os.getenv("ARI_DATA_SYNCING_URI")
|
||||
# Add URL params to the base URL
|
||||
sync_url = f"{ari_data_uri}&lead_id={lead_id}&status={status}"
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(sync_url, timeout=10.0)
|
||||
response.raise_for_status()
|
||||
logger.info(
|
||||
f"Successfully synced data for lead_id: {lead_id} with status: {status}"
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f"Missing lead_id or status for syncing - lead_id: {lead_id}, status: {status}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to sync data to ARI_DATA_SYNCING_URI: {e}")
|
||||
|
||||
|
||||
async def _process_integration(
|
||||
integration: IntegrationModel,
|
||||
gathered_context: dict,
|
||||
):
|
||||
"""
|
||||
Process a single integration.
|
||||
|
||||
Args:
|
||||
integration: The integration model
|
||||
gathered_context: The gathered context from the workflow run
|
||||
workflow_run_name: Name of the workflow run
|
||||
run_id: The workflow run ID for logging context
|
||||
"""
|
||||
logger.info(
|
||||
f"Processing integration {integration.id} (provider: {integration.provider})"
|
||||
)
|
||||
|
||||
try:
|
||||
if integration.provider.lower() == "slack":
|
||||
await _process_slack_integration(integration, gathered_context)
|
||||
else:
|
||||
logger.info(
|
||||
f"Integration provider '{integration.provider}' not supported yet"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing integration {integration.id}: {str(e)}")
|
||||
|
||||
|
||||
async def _process_slack_integration(
|
||||
integration: IntegrationModel, gathered_context: dict
|
||||
):
|
||||
"""
|
||||
Process a Slack integration by sending gathered_context to the webhook.
|
||||
|
||||
Args:
|
||||
integration: The Slack integration model
|
||||
gathered_context: The gathered context from the workflow run
|
||||
workflow_run_name: Name of the workflow run
|
||||
run_id: The workflow run ID for logging context
|
||||
"""
|
||||
logger.info(f"Processing Slack integration {integration.id}")
|
||||
|
||||
# TODO: Generalise this, currently tailored to Kapil's use case
|
||||
if gathered_context.get("mapped_call_disposition") != "XFER":
|
||||
logger.debug(
|
||||
f"Not sending message on slack since not XFER: {gathered_context.get('mapped_call_disposition')}"
|
||||
)
|
||||
return
|
||||
|
||||
try:
|
||||
# Extract webhook URL from connection_details
|
||||
connection_details = integration.connection_details
|
||||
|
||||
if not connection_details:
|
||||
logger.error(
|
||||
f"No connection details found for Slack integration {integration.id}"
|
||||
)
|
||||
return
|
||||
|
||||
# Navigate to incoming_webhook.url in the connection_details
|
||||
webhook_url = connection_details.get("connection_config", {}).get(
|
||||
"incoming_webhook.url"
|
||||
)
|
||||
if not webhook_url:
|
||||
logger.error(
|
||||
f"No incoming_webhook found in connection details for integration {integration.id}"
|
||||
)
|
||||
return
|
||||
|
||||
logger.info(f"Found Slack webhook URL for integration {integration.id}")
|
||||
|
||||
# Get message template configuration
|
||||
# Get organization_id from the integration model
|
||||
organization_id = integration.organisation_id
|
||||
message_templates = await db_client.get_configuration_value(
|
||||
organization_id,
|
||||
OrganizationConfigurationKey.DISPOSITION_MESSAGE_TEMPLATE.value,
|
||||
default={},
|
||||
)
|
||||
|
||||
# Check if there's a custom template for Slack
|
||||
slack_template = message_templates.get("slack", {})
|
||||
rendered_text = render_template(slack_template, gathered_context)
|
||||
|
||||
slack_message = {"text": rendered_text}
|
||||
|
||||
# Send to Slack webhook
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
webhook_url,
|
||||
json=slack_message,
|
||||
headers={"Content-Type": "application/json"},
|
||||
) as response:
|
||||
if response.status == 200:
|
||||
logger.info(
|
||||
f"Successfully sent message to Slack for integration {integration.id}"
|
||||
)
|
||||
else:
|
||||
error_text = await response.text()
|
||||
logger.error(
|
||||
f"Failed to send Slack message for integration {integration.id}: {response.status} - {error_text}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing Slack integration {integration.id}: {str(e)}")
|
||||
172
api/tasks/s3_upload.py
Normal file
172
api/tasks/s3_upload.py
Normal file
|
|
@ -0,0 +1,172 @@
|
|||
import os
|
||||
|
||||
from loguru import logger
|
||||
from pipecat.utils.context import set_current_run_id
|
||||
|
||||
from api.db import db_client
|
||||
from api.services.storage import get_current_storage_backend, storage_fs
|
||||
|
||||
|
||||
async def upload_audio_to_s3(ctx, workflow_run_id: int, temp_file_path: str):
|
||||
"""Upload audio file from temp path to S3."""
|
||||
run_id = str(workflow_run_id)
|
||||
set_current_run_id(run_id)
|
||||
|
||||
logger.info(f"Starting audio upload to S3 from {temp_file_path}")
|
||||
|
||||
try:
|
||||
# Verify temp file exists
|
||||
if not os.path.exists(temp_file_path):
|
||||
logger.error(f"Temp audio file not found: {temp_file_path}")
|
||||
raise FileNotFoundError(f"Temp audio file not found: {temp_file_path}")
|
||||
|
||||
file_size = os.path.getsize(temp_file_path)
|
||||
logger.debug(f"Audio file size: {file_size} bytes")
|
||||
|
||||
recording_url = f"recordings/{workflow_run_id}.wav"
|
||||
storage_backend = get_current_storage_backend()
|
||||
|
||||
logger.info(
|
||||
f"UPLOAD: Using {storage_backend.name} (value: {storage_backend.value}) for audio upload - workflow_run_id: {workflow_run_id}"
|
||||
)
|
||||
|
||||
await storage_fs.aupload_file(temp_file_path, recording_url)
|
||||
|
||||
# Update DB with recording URL and storage backend
|
||||
await db_client.update_workflow_run(
|
||||
run_id=workflow_run_id,
|
||||
recording_url=recording_url,
|
||||
storage_backend=storage_backend.value,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Successfully uploaded audio to {storage_backend.name}: {recording_url} (stored backend: {storage_backend.name})"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error uploading audio to S3 for workflow {workflow_run_id}: {e}")
|
||||
raise
|
||||
finally:
|
||||
# Clean up temp file
|
||||
if os.path.exists(temp_file_path):
|
||||
try:
|
||||
os.remove(temp_file_path)
|
||||
logger.debug(f"Cleaned up temp audio file: {temp_file_path}")
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Failed to clean up temp audio file {temp_file_path}: {e}"
|
||||
)
|
||||
|
||||
|
||||
async def upload_transcript_to_s3(ctx, workflow_run_id: int, temp_file_path: str):
|
||||
"""Upload transcript file from temp path to S3."""
|
||||
run_id = str(workflow_run_id)
|
||||
set_current_run_id(run_id)
|
||||
|
||||
logger.info(f"Starting transcript upload to S3 from {temp_file_path}")
|
||||
|
||||
try:
|
||||
# Verify temp file exists
|
||||
if not os.path.exists(temp_file_path):
|
||||
logger.error(f"Temp transcript file not found: {temp_file_path}")
|
||||
raise FileNotFoundError(f"Temp transcript file not found: {temp_file_path}")
|
||||
|
||||
file_size = os.path.getsize(temp_file_path)
|
||||
logger.debug(f"Transcript file size: {file_size} bytes")
|
||||
|
||||
transcript_url = f"transcripts/{workflow_run_id}.txt"
|
||||
storage_backend = get_current_storage_backend()
|
||||
|
||||
logger.info(
|
||||
f"UPLOAD: Using {storage_backend.name} (value: {storage_backend.value}) for transcript upload - workflow_run_id: {workflow_run_id}"
|
||||
)
|
||||
|
||||
await storage_fs.aupload_file(temp_file_path, transcript_url)
|
||||
|
||||
# Update DB with transcript URL and storage backend
|
||||
await db_client.update_workflow_run(
|
||||
run_id=workflow_run_id,
|
||||
transcript_url=transcript_url,
|
||||
storage_backend=storage_backend.value,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Successfully uploaded transcript to {storage_backend.name}: {transcript_url} (stored backend: {storage_backend.name})"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error uploading transcript to S3 for workflow {workflow_run_id}: {e}"
|
||||
)
|
||||
raise
|
||||
finally:
|
||||
# Clean up temp file
|
||||
if os.path.exists(temp_file_path):
|
||||
try:
|
||||
os.remove(temp_file_path)
|
||||
logger.debug(f"Cleaned up temp transcript file: {temp_file_path}")
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Failed to clean up temp transcript file {temp_file_path}: {e}"
|
||||
)
|
||||
|
||||
|
||||
async def upload_voicemail_audio_to_s3(
|
||||
ctx,
|
||||
workflow_run_id: int,
|
||||
temp_file_path: str,
|
||||
s3_key: str,
|
||||
):
|
||||
"""Upload voicemail detection audio from temp file to S3.
|
||||
|
||||
This function is similar to upload_audio_to_s3 but handles voicemail-specific
|
||||
paths and doesn't update the workflow run's recording_url field.
|
||||
|
||||
Args:
|
||||
ctx: ARQ context
|
||||
workflow_run_id: The workflow run ID
|
||||
temp_file_path: Path to the temporary WAV file
|
||||
s3_key: The S3 key where the file should be uploaded
|
||||
"""
|
||||
run_id = str(workflow_run_id)
|
||||
set_current_run_id(run_id)
|
||||
|
||||
logger.info(f"Starting voicemail audio upload to S3 from {temp_file_path}")
|
||||
|
||||
try:
|
||||
# Verify temp file exists
|
||||
if not os.path.exists(temp_file_path):
|
||||
logger.error(f"Temp voicemail audio file not found: {temp_file_path}")
|
||||
raise FileNotFoundError(
|
||||
f"Temp voicemail audio file not found: {temp_file_path}"
|
||||
)
|
||||
|
||||
file_size = os.path.getsize(temp_file_path)
|
||||
logger.debug(f"Voicemail audio file size: {file_size} bytes")
|
||||
|
||||
# Upload to S3
|
||||
upload_ok = await storage_fs.aupload_file(temp_file_path, s3_key)
|
||||
|
||||
if upload_ok:
|
||||
logger.info(f"Successfully uploaded voicemail audio to S3: {s3_key}")
|
||||
else:
|
||||
logger.error(
|
||||
f"Failed to upload voicemail audio to S3 for workflow {workflow_run_id}"
|
||||
)
|
||||
raise Exception(f"S3 upload failed for {s3_key}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error uploading voicemail audio to S3 for workflow {workflow_run_id}: {e}"
|
||||
)
|
||||
raise
|
||||
finally:
|
||||
# Clean up temp file (same pattern as upload_audio_to_s3)
|
||||
if os.path.exists(temp_file_path):
|
||||
try:
|
||||
os.remove(temp_file_path)
|
||||
logger.debug(f"Cleaned up temp voicemail audio file: {temp_file_path}")
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Failed to clean up temp voicemail audio file {temp_file_path}: {e}"
|
||||
)
|
||||
123
api/tasks/workflow_run_cost.py
Normal file
123
api/tasks/workflow_run_cost.py
Normal file
|
|
@ -0,0 +1,123 @@
|
|||
from loguru import logger
|
||||
from pipecat.utils.context import set_current_run_id
|
||||
|
||||
from api.db import db_client
|
||||
from api.enums import WorkflowRunMode
|
||||
from api.services.pricing.cost_calculator import cost_calculator
|
||||
from api.services.telephony.twilio import TwilioService
|
||||
|
||||
|
||||
async def calculate_workflow_run_cost(ctx, workflow_run_id: int):
|
||||
# Set the run_id in context variable for consistent logging format
|
||||
set_current_run_id(workflow_run_id)
|
||||
logger.debug("Calculating cost for workflow run")
|
||||
|
||||
workflow_run = await db_client.get_workflow_run_by_id(workflow_run_id)
|
||||
if not workflow_run:
|
||||
logger.warning("Workflow run not found")
|
||||
return
|
||||
|
||||
workflow_usage_info = workflow_run.usage_info
|
||||
if not workflow_usage_info:
|
||||
logger.warning("No usage info available for workflow run")
|
||||
return
|
||||
|
||||
try:
|
||||
# Calculate cost breakdown
|
||||
cost_breakdown = cost_calculator.calculate_total_cost(workflow_usage_info)
|
||||
|
||||
# If this is a Twilio call, fetch the Twilio call cost
|
||||
twilio_cost_usd = 0.0
|
||||
if workflow_run.mode == WorkflowRunMode.TWILIO.value and workflow_run.cost_info:
|
||||
twilio_call_sid = workflow_run.cost_info.get("twilio_call_sid")
|
||||
if twilio_call_sid:
|
||||
try:
|
||||
twilio_service = TwilioService()
|
||||
call_info = await twilio_service.get_call(twilio_call_sid)
|
||||
# Twilio returns price as a string with negative value (e.g., "-0.0085")
|
||||
if call_info.get("price"):
|
||||
twilio_cost_usd = abs(float(call_info["price"]))
|
||||
cost_breakdown["twilio_call"] = twilio_cost_usd
|
||||
# Add Twilio cost to the total
|
||||
cost_breakdown["total"] = (
|
||||
float(cost_breakdown["total"]) + twilio_cost_usd
|
||||
)
|
||||
logger.info(
|
||||
f"Twilio call cost: ${twilio_cost_usd:.6f} USD for call {twilio_call_sid}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to fetch Twilio call cost: {e}")
|
||||
# Don't fail the whole cost calculation if Twilio API fails
|
||||
|
||||
# Store cost information back to the workflow run
|
||||
# We'll add the cost breakdown to the workflow run
|
||||
# Convert USD to Dograh Tokens (1 cent = 1 token)
|
||||
dograh_tokens = round(float(cost_breakdown["total"]) * 100, 2)
|
||||
|
||||
# Get organization to check if it has USD pricing
|
||||
org = None
|
||||
charge_usd = None
|
||||
if (
|
||||
workflow_run.workflow
|
||||
and workflow_run.workflow.user
|
||||
and workflow_run.workflow.user.selected_organization_id
|
||||
):
|
||||
org = await db_client.get_organization_by_id(
|
||||
workflow_run.workflow.user.selected_organization_id
|
||||
)
|
||||
|
||||
# Calculate USD cost if organization has pricing configured
|
||||
if org and org.price_per_second_usd:
|
||||
duration_seconds = workflow_usage_info.get("call_duration_seconds", 0)
|
||||
charge_usd = duration_seconds * org.price_per_second_usd
|
||||
|
||||
cost_info = {
|
||||
"cost_breakdown": cost_breakdown,
|
||||
"total_cost_usd": float(cost_breakdown["total"]),
|
||||
"dograh_token_usage": dograh_tokens,
|
||||
"calculated_at": workflow_run.created_at.isoformat(),
|
||||
"call_duration_seconds": workflow_usage_info["call_duration_seconds"],
|
||||
}
|
||||
|
||||
# Add USD cost if available
|
||||
if charge_usd is not None:
|
||||
cost_info["charge_usd"] = charge_usd
|
||||
cost_info["price_per_second_usd"] = org.price_per_second_usd
|
||||
|
||||
# Preserve the twilio_call_sid if it exists
|
||||
if workflow_run.cost_info and "twilio_call_sid" in workflow_run.cost_info:
|
||||
cost_info["twilio_call_sid"] = workflow_run.cost_info["twilio_call_sid"]
|
||||
|
||||
# Update workflow run with cost information
|
||||
await db_client.update_workflow_run(run_id=workflow_run_id, cost_info=cost_info)
|
||||
|
||||
# Update organization usage if applicable
|
||||
if org:
|
||||
org_id = org.id
|
||||
try:
|
||||
duration_seconds = workflow_usage_info.get("call_duration_seconds", 0)
|
||||
# Pass USD amount if organization has pricing
|
||||
await db_client.update_usage_after_run(
|
||||
org_id, dograh_tokens, duration_seconds, charge_usd
|
||||
)
|
||||
if charge_usd is not None:
|
||||
logger.info(
|
||||
f"Updated organization usage with ${charge_usd:.2f} USD ({dograh_tokens} Dograh Tokens) and {duration_seconds}s duration for org {org_id}"
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
f"Updated organization usage with {dograh_tokens} Dograh Tokens and {duration_seconds}s duration for org {org_id}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to update organization usage for org {org_id}: {e}"
|
||||
)
|
||||
# Don't fail the whole task if usage update fails
|
||||
|
||||
logger.info(
|
||||
f"Calculated cost for workflow run: ${cost_breakdown['total']:.6f} USD ({dograh_tokens} Dograh Tokens)"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error calculating cost for workflow run: {e}")
|
||||
raise
|
||||
Loading…
Add table
Add a link
Reference in a new issue