mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-06-16 08:25:18 +02:00
feat: add qa node in workflow builder (#172)
* feat: add qa node in workflow builder * feat: add qa analysis token usage in usage_info * fix: mask the API key in QA node * feat: add advanced configuration in QA node
This commit is contained in:
parent
f1f4830012
commit
a836825b83
30 changed files with 1619 additions and 265 deletions
|
|
@ -15,8 +15,6 @@ setup_logging()
|
|||
from arq import create_pool
|
||||
from arq.connections import ArqRedis, RedisSettings
|
||||
|
||||
from api.tasks.workflow_run_cost import calculate_workflow_run_cost
|
||||
|
||||
parsed_url = urlparse(REDIS_URL)
|
||||
|
||||
# Check if we're using TLS (rediss://)
|
||||
|
|
@ -55,7 +53,6 @@ from api.tasks.s3_upload import (
|
|||
|
||||
class WorkerSettings:
|
||||
functions = [
|
||||
calculate_workflow_run_cost,
|
||||
run_integrations_post_workflow_run,
|
||||
upload_voicemail_audio_to_s3,
|
||||
process_workflow_completion,
|
||||
|
|
|
|||
|
|
@ -1,5 +1,4 @@
|
|||
class FunctionNames:
|
||||
CALCULATE_WORKFLOW_RUN_COST = "calculate_workflow_run_cost"
|
||||
RUN_INTEGRATIONS_POST_WORKFLOW_RUN = "run_integrations_post_workflow_run"
|
||||
PROCESS_WORKFLOW_COMPLETION = "process_workflow_completion"
|
||||
UPLOAD_VOICEMAIL_AUDIO_TO_S3 = "upload_voicemail_audio_to_s3"
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
"""Execute webhook integrations after workflow run completion."""
|
||||
"""Execute integrations (QA analysis, webhooks) after workflow run completion."""
|
||||
|
||||
import random
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import httpx
|
||||
|
|
@ -8,22 +9,141 @@ from loguru import logger
|
|||
from api.constants import BACKEND_API_ENDPOINT
|
||||
from api.db import db_client
|
||||
from api.db.models import WorkflowRunModel
|
||||
from api.services.qa_analysis import run_qa_analysis
|
||||
from api.utils.credential_auth import build_auth_header
|
||||
from api.utils.template_renderer import render_template
|
||||
from pipecat.utils.enums import EndTaskReason
|
||||
from pipecat.utils.run_context import set_current_run_id
|
||||
|
||||
|
||||
def _should_skip_qa(
|
||||
node_data: dict,
|
||||
workflow_run: WorkflowRunModel,
|
||||
) -> str | None:
|
||||
"""Check whether QA analysis should be skipped for this call.
|
||||
|
||||
Returns a reason string if the call should be skipped, or None if it should proceed.
|
||||
"""
|
||||
# Check minimum call duration
|
||||
min_duration = node_data.get("qa_min_call_duration", 15)
|
||||
usage_info = workflow_run.usage_info or {}
|
||||
call_duration = usage_info.get("call_duration_seconds")
|
||||
if call_duration is not None and call_duration < min_duration:
|
||||
return f"call duration ({call_duration:.1f}s) below minimum ({min_duration}s)"
|
||||
|
||||
# Check voicemail calls
|
||||
qa_voicemail_calls = node_data.get("qa_voicemail_calls", False)
|
||||
if not qa_voicemail_calls:
|
||||
gathered_context = workflow_run.gathered_context or {}
|
||||
call_disposition = gathered_context.get("call_disposition", "")
|
||||
if call_disposition == EndTaskReason.VOICEMAIL_DETECTED.value:
|
||||
return "voicemail call and QA voicemail calls is disabled"
|
||||
|
||||
# Check sample rate
|
||||
sample_rate = node_data.get("qa_sample_rate", 100)
|
||||
if sample_rate < 100:
|
||||
roll = random.randint(1, 100)
|
||||
if roll > sample_rate:
|
||||
return f"excluded by sampling ({sample_rate}% sample rate, rolled {roll})"
|
||||
|
||||
return None
|
||||
|
||||
|
||||
async def _run_qa_nodes(
|
||||
qa_nodes: list[dict],
|
||||
workflow_run: WorkflowRunModel,
|
||||
workflow_run_id: int,
|
||||
) -> Dict[str, Any]:
|
||||
"""Run QA analysis for each enabled QA node and aggregate results.
|
||||
|
||||
Returns:
|
||||
Dict keyed by node ID with QA analysis results.
|
||||
"""
|
||||
results: Dict[str, Any] = {}
|
||||
|
||||
for node in qa_nodes:
|
||||
node_data = node.get("data", {})
|
||||
node_id = node.get("id", "unknown")
|
||||
node_name = node_data.get("name", "QA Analysis")
|
||||
|
||||
if not node_data.get("qa_enabled", True):
|
||||
logger.debug(f"QA node '{node_name}' is disabled, skipping")
|
||||
continue
|
||||
|
||||
skip_reason = _should_skip_qa(node_data, workflow_run)
|
||||
if skip_reason:
|
||||
logger.info(f"Skipping QA node '{node_name}' (#{node_id}): {skip_reason}")
|
||||
results[f"qa_{node_id}"] = {"skipped": True, "reason": skip_reason}
|
||||
continue
|
||||
|
||||
try:
|
||||
logger.info(f"Running QA analysis for node '{node_name}' (#{node_id})")
|
||||
result = await run_qa_analysis(node_data, workflow_run, workflow_run_id)
|
||||
results[f"qa_{node_id}"] = result
|
||||
logger.info(
|
||||
f"QA analysis complete for '{node_name}': "
|
||||
f"score={result.get('score')}, tags={len(result.get('tags', []))}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"QA analysis failed for node '{node_name}': {e}")
|
||||
results[f"qa_{node_id}"] = {"error": str(e)}
|
||||
|
||||
return results
|
||||
|
||||
|
||||
async def _update_usage_info_with_qa_tokens(
|
||||
workflow_run_id: int,
|
||||
workflow_run: WorkflowRunModel,
|
||||
qa_results: Dict[str, Any],
|
||||
) -> None:
|
||||
"""Add QA analysis LLM token usage to the workflow run's usage_info."""
|
||||
try:
|
||||
usage_info = dict(workflow_run.usage_info or {})
|
||||
llm_usage = dict(usage_info.get("llm", {}))
|
||||
|
||||
for _node_key, result in qa_results.items():
|
||||
token_usage = result.get("token_usage")
|
||||
model = result.get("model")
|
||||
if not token_usage or not model:
|
||||
continue
|
||||
|
||||
key = f"QAAnalysis|||{model}"
|
||||
if key in llm_usage:
|
||||
# Aggregate if multiple QA nodes use the same model
|
||||
existing = llm_usage[key]
|
||||
for field in (
|
||||
"prompt_tokens",
|
||||
"completion_tokens",
|
||||
"total_tokens",
|
||||
"cache_read_input_tokens",
|
||||
):
|
||||
existing[field] = (existing.get(field) or 0) + (
|
||||
token_usage.get(field) or 0
|
||||
)
|
||||
else:
|
||||
llm_usage[key] = token_usage
|
||||
|
||||
usage_info["llm"] = llm_usage
|
||||
await db_client.update_workflow_run(
|
||||
run_id=workflow_run_id, usage_info=usage_info
|
||||
)
|
||||
logger.info(f"Updated usage_info with QA token usage for run {workflow_run_id}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to update usage_info with QA tokens: {e}")
|
||||
|
||||
|
||||
async def run_integrations_post_workflow_run(_ctx, workflow_run_id: int):
|
||||
"""
|
||||
Run webhook integrations after a workflow run completes.
|
||||
Run integrations after a workflow run completes.
|
||||
|
||||
This function:
|
||||
1. Gets the workflow run and its contexts
|
||||
2. Extracts webhook nodes from workflow definition
|
||||
3. Executes each enabled webhook node
|
||||
2. Runs QA analysis nodes (if any)
|
||||
3. Stores QA results in annotations
|
||||
4. Executes webhook nodes with QA results available in render context
|
||||
"""
|
||||
set_current_run_id(workflow_run_id)
|
||||
logger.info("Running webhook integrations for workflow run")
|
||||
logger.info("Running integrations for workflow run")
|
||||
|
||||
try:
|
||||
# Step 1: Get workflow run with full context
|
||||
|
|
@ -36,39 +156,61 @@ async def run_integrations_post_workflow_run(_ctx, workflow_run_id: int):
|
|||
return
|
||||
|
||||
if not organization_id:
|
||||
logger.warning("No organization found, skipping webhooks")
|
||||
logger.warning("No organization found, skipping integrations")
|
||||
return
|
||||
|
||||
# Step 2: Get workflow definition
|
||||
workflow_definition = workflow_run.workflow.workflow_definition_with_fallback
|
||||
if not workflow_definition:
|
||||
logger.debug("No workflow definition, skipping webhooks")
|
||||
logger.debug("No workflow definition, skipping integrations")
|
||||
return
|
||||
|
||||
# Step 3: Extract webhook nodes
|
||||
# Step 3: Extract integration nodes
|
||||
nodes = workflow_definition.get("nodes", [])
|
||||
qa_nodes = [n for n in nodes if n.get("type") == "qa"]
|
||||
webhook_nodes = [n for n in nodes if n.get("type") == "webhook"]
|
||||
|
||||
# Step 4: Generate public access token if webhooks exist or campaign_id is set
|
||||
has_campaign = workflow_run.campaign_id is not None
|
||||
if not webhook_nodes and not has_campaign:
|
||||
logger.debug("No webhook nodes and no campaign, skipping")
|
||||
if not webhook_nodes and not qa_nodes and not has_campaign:
|
||||
logger.debug("No integration nodes and no campaign, skipping")
|
||||
return
|
||||
|
||||
public_token = None
|
||||
if webhook_nodes or has_campaign:
|
||||
public_token = await db_client.ensure_public_access_token(workflow_run_id)
|
||||
|
||||
# Step 5: Run QA analysis before webhooks
|
||||
if qa_nodes:
|
||||
logger.info(f"Found {len(qa_nodes)} QA nodes to execute")
|
||||
qa_results = await _run_qa_nodes(qa_nodes, workflow_run, workflow_run_id)
|
||||
|
||||
if qa_results:
|
||||
await db_client.update_workflow_run(
|
||||
workflow_run_id, annotations=qa_results
|
||||
)
|
||||
|
||||
# Add QA token usage to workflow run's usage_info
|
||||
await _update_usage_info_with_qa_tokens(
|
||||
workflow_run_id, workflow_run, qa_results
|
||||
)
|
||||
|
||||
# Re-fetch workflow_run to get updated annotations
|
||||
workflow_run, _ = await db_client.get_workflow_run_with_context(
|
||||
workflow_run_id
|
||||
)
|
||||
|
||||
# Step 6: Execute webhooks
|
||||
if not webhook_nodes:
|
||||
logger.debug("No webhook nodes in workflow")
|
||||
return
|
||||
|
||||
logger.info(f"Found {len(webhook_nodes)} webhook nodes to execute")
|
||||
|
||||
# Step 5: Build render context
|
||||
# Step 7: Build render context (includes annotations from QA)
|
||||
render_context = _build_render_context(workflow_run, public_token)
|
||||
|
||||
# Step 6: Execute each webhook node
|
||||
# Step 8: Execute each webhook node
|
||||
for node in webhook_nodes:
|
||||
webhook_data = node.get("data", {})
|
||||
try:
|
||||
|
|
@ -84,7 +226,7 @@ async def run_integrations_post_workflow_run(_ctx, workflow_run_id: int):
|
|||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error running webhook integrations: {e}", exc_info=True)
|
||||
logger.error(f"Error running integrations: {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
|
||||
|
|
@ -110,6 +252,8 @@ def _build_render_context(
|
|||
"initial_context": workflow_run.initial_context or {},
|
||||
"gathered_context": workflow_run.gathered_context or {},
|
||||
"cost_info": workflow_run.usage_info or {},
|
||||
# Annotations (includes QA results)
|
||||
"annotations": workflow_run.annotations or {},
|
||||
}
|
||||
|
||||
# Add public download URLs if token is available
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ from typing import Optional
|
|||
from loguru import logger
|
||||
|
||||
from api.db import db_client
|
||||
from api.services.pricing.workflow_run_cost import calculate_workflow_run_cost
|
||||
from api.services.storage import get_current_storage_backend, storage_fs
|
||||
from api.tasks.run_integrations import run_integrations_post_workflow_run
|
||||
from pipecat.utils.run_context import set_current_run_id
|
||||
|
|
@ -162,10 +163,16 @@ async def process_workflow_completion(
|
|||
except Exception as e:
|
||||
logger.warning(f"Failed to clean up temp transcript file: {e}")
|
||||
|
||||
# Step 3: Run webhook integrations (after uploads are complete)
|
||||
# Step 3: Run integrations including QA analysis (after uploads are complete)
|
||||
try:
|
||||
await run_integrations_post_workflow_run(_ctx, workflow_run_id)
|
||||
except Exception as e:
|
||||
logger.error(f"Error running integrations for workflow {workflow_run_id}: {e}")
|
||||
|
||||
# Step 4: Calculate cost after integrations (so QA token usage is included)
|
||||
try:
|
||||
await calculate_workflow_run_cost(workflow_run_id)
|
||||
except Exception as e:
|
||||
logger.error(f"Error calculating cost for workflow {workflow_run_id}: {e}")
|
||||
|
||||
logger.info(f"Completed workflow completion processing for run {workflow_run_id}")
|
||||
|
|
|
|||
|
|
@ -1,157 +0,0 @@
|
|||
from loguru import logger
|
||||
|
||||
from api.db import db_client
|
||||
from api.enums import WorkflowRunMode
|
||||
from api.services.pricing.cost_calculator import cost_calculator
|
||||
from api.services.telephony.factory import get_telephony_provider
|
||||
from pipecat.utils.run_context import set_current_run_id
|
||||
|
||||
|
||||
async def _fetch_telephony_cost(workflow_run) -> dict | None:
|
||||
"""Fetch telephony call cost. Returns a dict with cost_usd and provider_name, or None."""
|
||||
if (
|
||||
workflow_run.mode
|
||||
not in [WorkflowRunMode.TWILIO.value, WorkflowRunMode.VONAGE.value]
|
||||
or not workflow_run.cost_info
|
||||
):
|
||||
return None
|
||||
|
||||
call_id = workflow_run.cost_info.get("call_id")
|
||||
if not call_id:
|
||||
logger.warning(f"call_id not found in cost_info")
|
||||
return None
|
||||
|
||||
provider_name = workflow_run.mode.lower() if workflow_run.mode else ""
|
||||
|
||||
workflow = await db_client.get_workflow_by_id(workflow_run.workflow_id)
|
||||
if not workflow:
|
||||
logger.warning("Workflow not found for workflow run")
|
||||
raise Exception("Workflow not found")
|
||||
|
||||
provider = await get_telephony_provider(workflow.organization_id)
|
||||
call_cost_info = await provider.get_call_cost(call_id)
|
||||
|
||||
if call_cost_info.get("status") == "error":
|
||||
logger.error(
|
||||
f"Failed to fetch {provider_name} call cost: {call_cost_info.get('error')}"
|
||||
)
|
||||
return None
|
||||
|
||||
cost_usd = call_cost_info.get("cost_usd", 0.0)
|
||||
logger.info(
|
||||
f"{provider_name.title()} call cost: ${cost_usd:.6f} USD for call {call_id}"
|
||||
)
|
||||
return {"cost_usd": cost_usd, "provider_name": provider_name}
|
||||
|
||||
|
||||
async def _update_organization_usage(
|
||||
org, dograh_tokens: float, duration_seconds: float, charge_usd: float | None
|
||||
) -> None:
|
||||
"""Update organization usage after a workflow run."""
|
||||
org_id = org.id
|
||||
await db_client.update_usage_after_run(
|
||||
org_id, dograh_tokens, duration_seconds, charge_usd
|
||||
)
|
||||
if charge_usd is not None:
|
||||
logger.info(
|
||||
f"Updated organization usage with ${charge_usd:.2f} USD ({dograh_tokens} Dograh Tokens) and {duration_seconds}s duration for org {org_id}"
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
f"Updated organization usage with {dograh_tokens} Dograh Tokens and {duration_seconds}s duration for org {org_id}"
|
||||
)
|
||||
|
||||
|
||||
async def calculate_workflow_run_cost(ctx, workflow_run_id: int):
|
||||
# Set the run_id in context variable for consistent logging format
|
||||
set_current_run_id(workflow_run_id)
|
||||
logger.debug("Calculating cost for workflow run")
|
||||
|
||||
workflow_run = await db_client.get_workflow_run_by_id(workflow_run_id)
|
||||
if not workflow_run:
|
||||
logger.warning("Workflow run not found")
|
||||
return
|
||||
|
||||
workflow_usage_info = workflow_run.usage_info
|
||||
if not workflow_usage_info:
|
||||
logger.warning("No usage info available for workflow run")
|
||||
return
|
||||
|
||||
try:
|
||||
# Calculate cost breakdown
|
||||
cost_breakdown = cost_calculator.calculate_total_cost(workflow_usage_info)
|
||||
|
||||
# Fetch telephony call cost
|
||||
try:
|
||||
telephony_cost = await _fetch_telephony_cost(workflow_run)
|
||||
if telephony_cost:
|
||||
telephony_cost_usd = telephony_cost["cost_usd"]
|
||||
provider_name = telephony_cost["provider_name"]
|
||||
cost_breakdown["telephony_call"] = telephony_cost_usd
|
||||
cost_breakdown[f"{provider_name}_call"] = telephony_cost_usd
|
||||
cost_breakdown["total"] = (
|
||||
float(cost_breakdown["total"]) + telephony_cost_usd
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to fetch telephony call cost: {e}")
|
||||
# Don't fail the whole cost calculation if telephony API fails
|
||||
|
||||
# Store cost information back to the workflow run
|
||||
# We'll add the cost breakdown to the workflow run
|
||||
# Convert USD to Dograh Tokens (1 cent = 1 token)
|
||||
dograh_tokens = round(float(cost_breakdown["total"]) * 100, 2)
|
||||
|
||||
# Get organization to check if it has USD pricing
|
||||
org = None
|
||||
charge_usd = None
|
||||
if (
|
||||
workflow_run.workflow
|
||||
and workflow_run.workflow.user
|
||||
and workflow_run.workflow.user.selected_organization_id
|
||||
):
|
||||
org = await db_client.get_organization_by_id(
|
||||
workflow_run.workflow.user.selected_organization_id
|
||||
)
|
||||
|
||||
# Calculate USD cost if organization has pricing configured
|
||||
if org and org.price_per_second_usd:
|
||||
duration_seconds = workflow_usage_info.get("call_duration_seconds", 0)
|
||||
charge_usd = duration_seconds * org.price_per_second_usd
|
||||
|
||||
cost_info = {
|
||||
**workflow_run.cost_info,
|
||||
"cost_breakdown": cost_breakdown,
|
||||
"total_cost_usd": float(cost_breakdown["total"]),
|
||||
"dograh_token_usage": dograh_tokens,
|
||||
"calculated_at": workflow_run.created_at.isoformat(),
|
||||
"call_duration_seconds": workflow_usage_info["call_duration_seconds"],
|
||||
}
|
||||
|
||||
# Add USD cost if available
|
||||
if charge_usd is not None:
|
||||
cost_info["charge_usd"] = charge_usd
|
||||
cost_info["price_per_second_usd"] = org.price_per_second_usd
|
||||
|
||||
# Update workflow run with cost information
|
||||
await db_client.update_workflow_run(run_id=workflow_run_id, cost_info=cost_info)
|
||||
|
||||
# Update organization usage if applicable
|
||||
if org:
|
||||
try:
|
||||
duration_seconds = workflow_usage_info.get("call_duration_seconds", 0)
|
||||
await _update_organization_usage(
|
||||
org, dograh_tokens, duration_seconds, charge_usd
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to update organization usage for org {org.id}: {e}"
|
||||
)
|
||||
# Don't fail the whole task if usage update fails
|
||||
|
||||
logger.info(
|
||||
f"Calculated cost for workflow run: ${cost_breakdown['total']:.6f} USD ({dograh_tokens} Dograh Tokens)"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error calculating cost for workflow run: {e}")
|
||||
raise
|
||||
Loading…
Add table
Add a link
Reference in a new issue