mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-06-07 07:55:16 +02:00
feat: add qa node in workflow builder (#172)
* feat: add qa node in workflow builder * feat: add qa analysis token usage in usage_info * fix: mask the API key in QA node * feat: add advanced configuration in QA node
This commit is contained in:
parent
f1f4830012
commit
a836825b83
30 changed files with 1619 additions and 265 deletions
|
|
@ -12,15 +12,16 @@ References:
|
|||
"""
|
||||
|
||||
import os
|
||||
|
||||
# Load environment variables before importing anything else
|
||||
from pathlib import Path
|
||||
from typing import AsyncGenerator
|
||||
from urllib.parse import urlparse, urlunparse
|
||||
|
||||
# Load environment variables before importing anything else
|
||||
from dotenv import load_dotenv
|
||||
|
||||
# Load .env.test from api directory for test configuration
|
||||
env_path = Path(__file__).parent / ".env.test"
|
||||
# Load .env.test before importing api.constants (which reads DATABASE_URL at import time)
|
||||
env_path = Path(__file__).resolve().parent / ".env.test"
|
||||
load_dotenv(env_path)
|
||||
|
||||
import logging
|
||||
|
|
@ -29,6 +30,8 @@ import sys
|
|||
import loguru
|
||||
import pytest
|
||||
|
||||
from api.constants import APP_ROOT_DIR # noqa: E402
|
||||
|
||||
|
||||
def setup_test_logging():
|
||||
"""Configure logging for tests using LOG_LEVEL from .env.test"""
|
||||
|
|
@ -191,7 +194,7 @@ async def run_migrations(database_url: str):
|
|||
from alembic.config import Config
|
||||
|
||||
# Get alembic.ini path
|
||||
alembic_ini_path = Path(__file__).parent / "alembic.ini"
|
||||
alembic_ini_path = APP_ROOT_DIR / "alembic.ini"
|
||||
|
||||
# Create alembic config
|
||||
alembic_cfg = Config(str(alembic_ini_path))
|
||||
|
|
|
|||
|
|
@ -15,7 +15,11 @@ VOICEMAIL_RECORDING_DURATION = 5.0
|
|||
|
||||
# Configuration constants
|
||||
ENABLE_TRACING = os.getenv("ENABLE_TRACING", "false").lower() == "true"
|
||||
ENABLE_RNNOISE = os.getenv("ENABLE_RNNOISE", "false").lower() == "true"
|
||||
|
||||
# Langfuse Configuration
|
||||
LANGFUSE_HOST = os.getenv("LANGFUSE_HOST")
|
||||
LANGFUSE_PUBLIC_KEY = os.getenv("LANGFUSE_PUBLIC_KEY")
|
||||
LANGFUSE_SECRET_KEY = os.getenv("LANGFUSE_SECRET_KEY")
|
||||
|
||||
# URLs for deployment
|
||||
BACKEND_API_ENDPOINT = os.getenv("BACKEND_API_ENDPOINT", "http://localhost:8000")
|
||||
|
|
|
|||
|
|
@ -1,5 +1,4 @@
|
|||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
from sqlalchemy import func
|
||||
|
|
@ -180,10 +179,6 @@ class WorkflowRunClient(BaseDBClient):
|
|||
"cost_info": run.cost_info,
|
||||
"initial_context": run.initial_context,
|
||||
"gathered_context": run.gathered_context,
|
||||
"admin_comment": (run.annotations or {}).get("admin_comment"),
|
||||
"admin_comment_ts": (run.annotations or {}).get(
|
||||
"admin_comment_ts"
|
||||
),
|
||||
"created_at": run.created_at,
|
||||
}
|
||||
)
|
||||
|
|
@ -321,6 +316,7 @@ class WorkflowRunClient(BaseDBClient):
|
|||
gathered_context: dict | None = None,
|
||||
logs: dict | None = None,
|
||||
state: str | None = None,
|
||||
annotations: dict | None = None,
|
||||
) -> WorkflowRunModel:
|
||||
async with self.async_session() as session:
|
||||
# Use SELECT FOR UPDATE to lock the row during the update
|
||||
|
|
@ -353,6 +349,8 @@ class WorkflowRunClient(BaseDBClient):
|
|||
if logs:
|
||||
# Lets merge the incoming logs key with existing ones
|
||||
run.logs = {**run.logs, **logs}
|
||||
if annotations:
|
||||
run.annotations = {**run.annotations, **annotations}
|
||||
if is_completed:
|
||||
run.is_completed = is_completed
|
||||
if state:
|
||||
|
|
@ -365,39 +363,6 @@ class WorkflowRunClient(BaseDBClient):
|
|||
await session.refresh(run)
|
||||
return run
|
||||
|
||||
async def update_admin_comment(
|
||||
self, run_id: int, admin_comment: str
|
||||
) -> WorkflowRunModel:
|
||||
"""Update (or create) the admin comment inside the ``annotations`` JSON column.
|
||||
|
||||
The comment is stored under the key ``admin_comment`` so we do not
|
||||
overwrite any other existing annotations that may be present.
|
||||
"""
|
||||
async with self.async_session() as session:
|
||||
result = await session.execute(
|
||||
select(WorkflowRunModel).where(WorkflowRunModel.id == run_id)
|
||||
)
|
||||
run = result.scalars().first()
|
||||
if run is None:
|
||||
raise ValueError(f"Workflow run with ID {run_id} not found")
|
||||
|
||||
# Ensure we never mutate a shared dict between instances
|
||||
current_annotations = dict(run.annotations or {})
|
||||
current_annotations["admin_comment"] = admin_comment
|
||||
|
||||
current_annotations["admin_comment_ts"] = datetime.now(
|
||||
timezone.utc
|
||||
).isoformat()
|
||||
run.annotations = current_annotations
|
||||
|
||||
try:
|
||||
await session.commit()
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
raise e
|
||||
await session.refresh(run)
|
||||
return run
|
||||
|
||||
async def get_workflow_run_with_context(
|
||||
self, workflow_run_id: int
|
||||
) -> Tuple[Optional[WorkflowRunModel], Optional[int]]:
|
||||
|
|
|
|||
|
|
@ -45,8 +45,6 @@ class SuperuserWorkflowRunResponse(BaseModel):
|
|||
cost_info: Optional[dict]
|
||||
initial_context: Optional[dict]
|
||||
gathered_context: Optional[dict]
|
||||
admin_comment: Optional[str]
|
||||
admin_comment_ts: Optional[datetime]
|
||||
created_at: datetime
|
||||
|
||||
|
||||
|
|
@ -151,49 +149,3 @@ async def get_workflow_runs(
|
|||
limit=limit,
|
||||
total_pages=total_pages,
|
||||
)
|
||||
|
||||
|
||||
# ------------------ Admin Comment ------------------
|
||||
|
||||
|
||||
class AdminCommentRequest(BaseModel):
|
||||
admin_comment: str
|
||||
|
||||
|
||||
class AdminCommentResponse(BaseModel):
|
||||
success: bool
|
||||
admin_comment: str
|
||||
admin_comment_ts: datetime
|
||||
|
||||
|
||||
# ------------------ Routes ------------------
|
||||
|
||||
|
||||
@router.post("/workflow-runs/{run_id}/comment", response_model=AdminCommentResponse)
|
||||
async def set_admin_comment(
|
||||
run_id: int,
|
||||
request: AdminCommentRequest,
|
||||
user: UserModel = Depends(get_superuser),
|
||||
):
|
||||
"""Add or update an *admin-only* comment for a workflow run.
|
||||
|
||||
The comment is stored inside the ``annotations`` JSON column under the
|
||||
``admin_comment`` key so that it does not interfere with any other
|
||||
annotations recorded by the system.
|
||||
"""
|
||||
|
||||
await db_client.update_admin_comment(
|
||||
run_id=run_id, admin_comment=request.admin_comment
|
||||
)
|
||||
|
||||
# Fetch the updated run to get the timestamp from annotations
|
||||
updated_run = await db_client.get_workflow_run_by_id(run_id)
|
||||
admin_comment_ts = None
|
||||
if updated_run and updated_run.annotations:
|
||||
admin_comment_ts = updated_run.annotations.get("admin_comment_ts")
|
||||
|
||||
return AdminCommentResponse(
|
||||
success=True,
|
||||
admin_comment=request.admin_comment,
|
||||
admin_comment_ts=admin_comment_ts,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -15,6 +15,10 @@ from api.db.workflow_template_client import WorkflowTemplateClient
|
|||
from api.enums import CallType
|
||||
from api.schemas.workflow import WorkflowRunResponseSchema
|
||||
from api.services.auth.depends import get_user
|
||||
from api.services.configuration.masking import (
|
||||
mask_workflow_definition,
|
||||
merge_workflow_api_keys,
|
||||
)
|
||||
from api.services.mps_service_key_client import mps_service_key_client
|
||||
from api.services.workflow.dto import ReactFlowDTO
|
||||
from api.services.workflow.errors import ItemKind, WorkflowError
|
||||
|
|
@ -273,7 +277,9 @@ async def create_workflow(
|
|||
"name": workflow.name,
|
||||
"status": workflow.status,
|
||||
"created_at": workflow.created_at,
|
||||
"workflow_definition": workflow.workflow_definition_with_fallback,
|
||||
"workflow_definition": mask_workflow_definition(
|
||||
workflow.workflow_definition_with_fallback
|
||||
),
|
||||
"current_definition_id": workflow.current_definition_id,
|
||||
"template_context_variables": workflow.template_context_variables,
|
||||
"call_disposition_codes": workflow.call_disposition_codes,
|
||||
|
|
@ -351,7 +357,9 @@ async def create_workflow_from_template(
|
|||
"name": workflow.name,
|
||||
"status": workflow.status,
|
||||
"created_at": workflow.created_at,
|
||||
"workflow_definition": workflow.workflow_definition_with_fallback,
|
||||
"workflow_definition": mask_workflow_definition(
|
||||
workflow.workflow_definition_with_fallback
|
||||
),
|
||||
"current_definition_id": workflow.current_definition_id,
|
||||
"template_context_variables": workflow.template_context_variables,
|
||||
"call_disposition_codes": workflow.call_disposition_codes,
|
||||
|
|
@ -462,7 +470,9 @@ async def get_workflow(
|
|||
"name": workflow.name,
|
||||
"status": workflow.status,
|
||||
"created_at": workflow.created_at,
|
||||
"workflow_definition": workflow.workflow_definition_with_fallback,
|
||||
"workflow_definition": mask_workflow_definition(
|
||||
workflow.workflow_definition_with_fallback
|
||||
),
|
||||
"current_definition_id": workflow.current_definition_id,
|
||||
"template_context_variables": workflow.template_context_variables,
|
||||
"call_disposition_codes": workflow.call_disposition_codes,
|
||||
|
|
@ -512,7 +522,9 @@ async def update_workflow_status(
|
|||
"name": workflow.name,
|
||||
"status": workflow.status,
|
||||
"created_at": workflow.created_at,
|
||||
"workflow_definition": workflow.workflow_definition_with_fallback,
|
||||
"workflow_definition": mask_workflow_definition(
|
||||
workflow.workflow_definition_with_fallback
|
||||
),
|
||||
"current_definition_id": workflow.current_definition_id,
|
||||
"template_context_variables": workflow.template_context_variables,
|
||||
"call_disposition_codes": workflow.call_disposition_codes,
|
||||
|
|
@ -545,18 +557,30 @@ async def update_workflow(
|
|||
HTTPException: If the workflow is not found or if there's a database error
|
||||
"""
|
||||
try:
|
||||
# Restore real API keys where the incoming definition has masked placeholders
|
||||
workflow_definition = request.workflow_definition
|
||||
if workflow_definition:
|
||||
existing_workflow = await db_client.get_workflow(
|
||||
workflow_id, organization_id=user.selected_organization_id
|
||||
)
|
||||
if existing_workflow:
|
||||
workflow_definition = merge_workflow_api_keys(
|
||||
workflow_definition,
|
||||
existing_workflow.workflow_definition_with_fallback,
|
||||
)
|
||||
|
||||
workflow = await db_client.update_workflow(
|
||||
workflow_id=workflow_id,
|
||||
name=request.name,
|
||||
workflow_definition=request.workflow_definition,
|
||||
workflow_definition=workflow_definition,
|
||||
template_context_variables=request.template_context_variables,
|
||||
workflow_configurations=request.workflow_configurations,
|
||||
organization_id=user.selected_organization_id,
|
||||
)
|
||||
|
||||
# Sync agent triggers if workflow definition was updated
|
||||
if request.workflow_definition:
|
||||
trigger_paths = extract_trigger_paths(request.workflow_definition)
|
||||
if workflow_definition:
|
||||
trigger_paths = extract_trigger_paths(workflow_definition)
|
||||
await db_client.sync_triggers_for_workflow(
|
||||
workflow_id=workflow.id,
|
||||
organization_id=user.selected_organization_id,
|
||||
|
|
@ -568,7 +592,9 @@ async def update_workflow(
|
|||
"name": workflow.name,
|
||||
"status": workflow.status,
|
||||
"created_at": workflow.created_at,
|
||||
"workflow_definition": workflow.workflow_definition_with_fallback,
|
||||
"workflow_definition": mask_workflow_definition(
|
||||
workflow.workflow_definition_with_fallback
|
||||
),
|
||||
"current_definition_id": workflow.current_definition_id,
|
||||
"template_context_variables": workflow.template_context_variables,
|
||||
"call_disposition_codes": workflow.call_disposition_codes,
|
||||
|
|
@ -798,7 +824,9 @@ async def duplicate_workflow_template(
|
|||
"name": workflow.name,
|
||||
"status": workflow.status,
|
||||
"created_at": workflow.created_at,
|
||||
"workflow_definition": workflow.workflow_definition_with_fallback,
|
||||
"workflow_definition": mask_workflow_definition(
|
||||
workflow.workflow_definition_with_fallback
|
||||
),
|
||||
"current_definition_id": workflow.current_definition_id,
|
||||
"template_context_variables": workflow.template_context_variables,
|
||||
"call_disposition_codes": workflow.call_disposition_codes,
|
||||
|
|
|
|||
|
|
@ -68,3 +68,65 @@ def mask_user_config(config: UserConfiguration) -> Dict[str, Any]:
|
|||
"test_phone_number": config.test_phone_number,
|
||||
"timezone": config.timezone,
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Workflow definition helpers – mask / merge QA-node API keys
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_QA_API_KEY_FIELD = "qa_api_key"
|
||||
|
||||
|
||||
def mask_workflow_definition(workflow_definition: Optional[Dict]) -> Optional[Dict]:
|
||||
"""Return a *shallow copy* of *workflow_definition* with QA-node API keys masked."""
|
||||
if not workflow_definition:
|
||||
return workflow_definition
|
||||
|
||||
import copy
|
||||
|
||||
masked = copy.deepcopy(workflow_definition)
|
||||
for node in masked.get("nodes", []):
|
||||
if node.get("type") != "qa":
|
||||
continue
|
||||
data = node.get("data", {})
|
||||
raw_key = data.get(_QA_API_KEY_FIELD)
|
||||
if raw_key:
|
||||
data[_QA_API_KEY_FIELD] = mask_key(raw_key)
|
||||
return masked
|
||||
|
||||
|
||||
def merge_workflow_api_keys(
|
||||
incoming_definition: Optional[Dict], existing_definition: Optional[Dict]
|
||||
) -> Optional[Dict]:
|
||||
"""Preserve real QA-node API keys when the incoming value is a masked placeholder.
|
||||
|
||||
For each QA node in *incoming_definition*, if its ``qa_api_key`` equals
|
||||
the masked form of the corresponding node in *existing_definition*, the
|
||||
real key is restored so it is never lost.
|
||||
"""
|
||||
if not incoming_definition or not existing_definition:
|
||||
return incoming_definition
|
||||
|
||||
# Build lookup: node-id → data for existing QA nodes
|
||||
existing_qa: Dict[str, Dict] = {}
|
||||
for node in existing_definition.get("nodes", []):
|
||||
if node.get("type") == "qa":
|
||||
existing_qa[node["id"]] = node.get("data", {})
|
||||
|
||||
for node in incoming_definition.get("nodes", []):
|
||||
if node.get("type") != "qa":
|
||||
continue
|
||||
data = node.get("data", {})
|
||||
incoming_key = data.get(_QA_API_KEY_FIELD)
|
||||
if not incoming_key:
|
||||
continue
|
||||
|
||||
old_data = existing_qa.get(node["id"])
|
||||
if not old_data:
|
||||
continue
|
||||
|
||||
old_key = old_data.get(_QA_API_KEY_FIELD, "")
|
||||
if old_key and is_mask_of(incoming_key, old_key):
|
||||
data[_QA_API_KEY_FIELD] = old_key
|
||||
|
||||
return incoming_definition
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ import asyncio
|
|||
from typing import Dict, Set
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.audio.utils import mix_audio
|
||||
from pipecat.frames.frames import (
|
||||
Frame,
|
||||
|
|
|
|||
|
|
@ -216,9 +216,8 @@ def register_event_handlers(
|
|||
except Exception as e:
|
||||
logger.error(f"Error preparing buffers for S3 upload: {e}", exc_info=True)
|
||||
|
||||
await enqueue_job(FunctionNames.CALCULATE_WORKFLOW_RUN_COST, workflow_run_id)
|
||||
|
||||
# Combined task: uploads artifacts then runs integrations sequentially
|
||||
# Combined task: uploads artifacts, runs integrations (including QA),
|
||||
# then calculates cost (so QA token usage is captured in usage_info)
|
||||
await enqueue_job(
|
||||
FunctionNames.PROCESS_WORKFLOW_COMPLETION,
|
||||
workflow_run_id,
|
||||
|
|
|
|||
|
|
@ -32,7 +32,7 @@ from api.services.pipecat.service_factory import (
|
|||
create_stt_service,
|
||||
create_tts_service,
|
||||
)
|
||||
from api.services.pipecat.tracing_config import setup_pipeline_tracing
|
||||
from api.services.pipecat.tracing_config import setup_tracing_exporter
|
||||
from api.services.pipecat.transport_setup import (
|
||||
create_ari_transport,
|
||||
create_cloudonix_transport,
|
||||
|
|
@ -80,7 +80,7 @@ from pipecat.utils.run_context import set_current_run_id
|
|||
from pipecat.utils.tracing.context_registry import ContextProviderRegistry
|
||||
|
||||
# Setup tracing if enabled
|
||||
setup_pipeline_tracing()
|
||||
setup_tracing_exporter()
|
||||
|
||||
|
||||
async def run_pipeline_twilio(
|
||||
|
|
|
|||
|
|
@ -4,9 +4,16 @@ import os
|
|||
from loguru import logger
|
||||
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
|
||||
|
||||
from api.constants import ENABLE_TRACING
|
||||
from api.constants import (
|
||||
ENABLE_TRACING,
|
||||
LANGFUSE_HOST,
|
||||
LANGFUSE_PUBLIC_KEY,
|
||||
LANGFUSE_SECRET_KEY,
|
||||
)
|
||||
from pipecat.utils.tracing.setup import setup_tracing
|
||||
|
||||
_tracing_initialized = False
|
||||
|
||||
|
||||
def is_tracing_enabled():
|
||||
"""Check if tracing should be enabled based on ENABLE_TRACING flag."""
|
||||
|
|
@ -15,28 +22,31 @@ def is_tracing_enabled():
|
|||
return ENABLE_TRACING
|
||||
|
||||
|
||||
def setup_pipeline_tracing():
|
||||
"""Setup tracing for the pipeline if enabled"""
|
||||
if is_tracing_enabled():
|
||||
# Only set up Langfuse if credentials are provided
|
||||
langfuse_host = os.environ.get("LANGFUSE_HOST")
|
||||
langfuse_public_key = os.environ.get("LANGFUSE_PUBLIC_KEY")
|
||||
langfuse_secret_key = os.environ.get("LANGFUSE_SECRET_KEY")
|
||||
def setup_tracing_exporter():
|
||||
"""Setup the OTEL tracing exporter for Langfuse if enabled.
|
||||
|
||||
if not all([langfuse_host, langfuse_public_key, langfuse_secret_key]):
|
||||
Idempotent — safe to call from both the pipeline process and the ARQ worker.
|
||||
"""
|
||||
global _tracing_initialized
|
||||
if _tracing_initialized:
|
||||
return
|
||||
|
||||
if is_tracing_enabled():
|
||||
if not all([LANGFUSE_HOST, LANGFUSE_PUBLIC_KEY, LANGFUSE_SECRET_KEY]):
|
||||
logger.warning(
|
||||
"Warning: ENABLE_TRACING is true but Langfuse credentials are not configured. Tracing disabled."
|
||||
)
|
||||
return
|
||||
|
||||
LANGFUSE_AUTH = base64.b64encode(
|
||||
f"{langfuse_public_key}:{langfuse_secret_key}".encode()
|
||||
langfuse_auth = base64.b64encode(
|
||||
f"{LANGFUSE_PUBLIC_KEY}:{LANGFUSE_SECRET_KEY}".encode()
|
||||
).decode()
|
||||
|
||||
os.environ["OTEL_EXPORTER_OTLP_ENDPOINT"] = f"{langfuse_host}/api/public/otel"
|
||||
os.environ["OTEL_EXPORTER_OTLP_ENDPOINT"] = f"{LANGFUSE_HOST}/api/public/otel"
|
||||
os.environ["OTEL_EXPORTER_OTLP_HEADERS"] = (
|
||||
f"Authorization=Basic {LANGFUSE_AUTH}"
|
||||
f"Authorization=Basic {langfuse_auth}"
|
||||
)
|
||||
|
||||
otlp_exporter = OTLPSpanExporter()
|
||||
setup_tracing(service_name="dograh-pipeline", exporter=otlp_exporter)
|
||||
_tracing_initialized = True
|
||||
|
|
|
|||
|
|
@ -4,7 +4,6 @@ from api.db import db_client
|
|||
from api.enums import WorkflowRunMode
|
||||
from api.services.pricing.cost_calculator import cost_calculator
|
||||
from api.services.telephony.factory import get_telephony_provider
|
||||
from pipecat.utils.run_context import set_current_run_id
|
||||
|
||||
|
||||
async def _fetch_telephony_cost(workflow_run) -> dict | None:
|
||||
|
|
@ -62,9 +61,7 @@ async def _update_organization_usage(
|
|||
)
|
||||
|
||||
|
||||
async def calculate_workflow_run_cost(ctx, workflow_run_id: int):
|
||||
# Set the run_id in context variable for consistent logging format
|
||||
set_current_run_id(workflow_run_id)
|
||||
async def calculate_workflow_run_cost(workflow_run_id: int):
|
||||
logger.debug("Calculating cost for workflow run")
|
||||
|
||||
workflow_run = await db_client.get_workflow_run_by_id(workflow_run_id)
|
||||
|
|
@ -97,7 +94,6 @@ async def calculate_workflow_run_cost(ctx, workflow_run_id: int):
|
|||
# Don't fail the whole cost calculation if telephony API fails
|
||||
|
||||
# Store cost information back to the workflow run
|
||||
# We'll add the cost breakdown to the workflow run
|
||||
# Convert USD to Dograh Tokens (1 cent = 1 token)
|
||||
dograh_tokens = round(float(cost_breakdown["total"]) * 100, 2)
|
||||
|
||||
360
api/services/qa_analysis.py
Normal file
360
api/services/qa_analysis.py
Normal file
|
|
@ -0,0 +1,360 @@
|
|||
"""QA analysis service for post-call quality assessment.
|
||||
|
||||
Runs LLM-based analysis on call transcripts, traces under the same
|
||||
Langfuse trace as the conversation, and returns structured results.
|
||||
"""
|
||||
|
||||
import json
|
||||
import re
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from loguru import logger
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
from api.db import db_client
|
||||
from api.db.models import WorkflowRunModel
|
||||
from api.services.gen_ai.json_parser import parse_llm_json
|
||||
from pipecat.utils.enums import RealtimeFeedbackType
|
||||
|
||||
|
||||
def build_conversation_structure(logs: list[dict]) -> list[dict]:
|
||||
"""Transform raw call logs into a conversation structure for LLM QA analysis."""
|
||||
if not logs:
|
||||
return []
|
||||
|
||||
start_time = datetime.fromisoformat(logs[0]["timestamp"])
|
||||
|
||||
conversation = []
|
||||
for event in logs:
|
||||
if event["type"] == RealtimeFeedbackType.BOT_TEXT.value:
|
||||
speaker = "assistant"
|
||||
utterance_text = event["payload"]["text"]
|
||||
event_time = datetime.fromisoformat(event["payload"]["timestamp"])
|
||||
elif event["type"] == RealtimeFeedbackType.USER_TRANSCRIPTION.value and event[
|
||||
"payload"
|
||||
].get("final", False):
|
||||
speaker = "user"
|
||||
utterance_text = event["payload"]["text"]
|
||||
event_time = datetime.fromisoformat(event["payload"]["timestamp"])
|
||||
else:
|
||||
continue
|
||||
|
||||
time_from_start = (event_time - start_time).total_seconds()
|
||||
|
||||
conversation.append(
|
||||
{
|
||||
"time_from_start_seconds": round(time_from_start, 2),
|
||||
"speaker": speaker,
|
||||
"text": utterance_text,
|
||||
"node_name": event.get("node_name", ""),
|
||||
"turn": event.get("turn", 0),
|
||||
}
|
||||
)
|
||||
|
||||
return conversation
|
||||
|
||||
|
||||
def format_transcript(conversation: list[dict]) -> str:
|
||||
"""Format conversation structure into a readable transcript string for the LLM."""
|
||||
lines = []
|
||||
for entry in conversation:
|
||||
lines.append(
|
||||
f"[{entry['time_from_start_seconds']:.1f}s] "
|
||||
f"{entry['speaker']}: {entry['text']}"
|
||||
)
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def compute_call_metrics(
|
||||
logs: list[dict], call_duration_seconds: float | None = None
|
||||
) -> dict:
|
||||
"""Pre-compute quantitative metrics from raw call logs."""
|
||||
latencies = []
|
||||
ttfb_values = []
|
||||
|
||||
for event in logs:
|
||||
if event["type"] == RealtimeFeedbackType.LATENCY_MEASURED.value:
|
||||
latencies.append(event["payload"]["latency_seconds"])
|
||||
elif event["type"] == RealtimeFeedbackType.TTFB_METRIC.value:
|
||||
ttfb_values.append(event["payload"]["ttfb_seconds"])
|
||||
|
||||
turns = set()
|
||||
for event in logs:
|
||||
if event["type"] in (
|
||||
RealtimeFeedbackType.USER_TRANSCRIPTION.value,
|
||||
RealtimeFeedbackType.BOT_TEXT.value,
|
||||
):
|
||||
turns.add(event.get("turn", 0))
|
||||
|
||||
return {
|
||||
"call_duration_seconds": call_duration_seconds,
|
||||
"num_turns": len(turns),
|
||||
"avg_latency_seconds": (
|
||||
round(sum(latencies) / len(latencies), 2) if latencies else None
|
||||
),
|
||||
"avg_ttfb_seconds": (
|
||||
round(sum(ttfb_values) / len(ttfb_values), 2) if ttfb_values else None
|
||||
),
|
||||
"max_latency_seconds": round(max(latencies), 2) if latencies else None,
|
||||
}
|
||||
|
||||
|
||||
def _extract_trace_id(gathered_context: dict) -> str | None:
|
||||
"""Extract Langfuse trace_id from gathered_context trace_url.
|
||||
|
||||
URL format: https://langfuse.dograh.com/project/<project_id>/traces/<trace_id>
|
||||
"""
|
||||
trace_url = gathered_context.get("trace_url")
|
||||
if not trace_url:
|
||||
return None
|
||||
try:
|
||||
match = re.search(r"/traces/([a-fA-F0-9]+)$", trace_url)
|
||||
if match:
|
||||
return match.group(1)
|
||||
except Exception:
|
||||
pass
|
||||
return None
|
||||
|
||||
|
||||
def _provider_base_url(provider: str | None, endpoint: str = "") -> str | None:
|
||||
"""Return the base URL for a given LLM provider."""
|
||||
if provider == "openrouter":
|
||||
return "https://openrouter.ai/api/v1"
|
||||
if provider == "groq":
|
||||
return "https://api.groq.com/openai/v1"
|
||||
if provider == "google":
|
||||
return "https://generativelanguage.googleapis.com/v1beta/openai/"
|
||||
if provider == "azure":
|
||||
return endpoint or None
|
||||
return None
|
||||
|
||||
|
||||
async def _resolve_llm_config(
|
||||
qa_node_data: dict, workflow_run: WorkflowRunModel
|
||||
) -> tuple[str, str, str | None]:
|
||||
"""Resolve the LLM model, API key, and base URL for QA analysis.
|
||||
|
||||
If the QA node has its own LLM configuration (qa_use_workflow_llm=False),
|
||||
use those settings directly. Otherwise, fall back to the user's configured LLM.
|
||||
|
||||
Returns:
|
||||
(model, api_key, base_url) tuple
|
||||
"""
|
||||
if not qa_node_data.get("qa_use_workflow_llm", True):
|
||||
return (
|
||||
qa_node_data.get("qa_model"),
|
||||
qa_node_data.get("qa_api_key"),
|
||||
_provider_base_url(
|
||||
qa_node_data.get("qa_provider"),
|
||||
qa_node_data.get("qa_endpoint", ""),
|
||||
),
|
||||
)
|
||||
|
||||
# Fall back to user's configured LLM
|
||||
user_id = None
|
||||
if workflow_run.workflow and workflow_run.workflow.user:
|
||||
user_id = workflow_run.workflow.user.id
|
||||
|
||||
llm_config: dict = {}
|
||||
if user_id:
|
||||
user_configuration = await db_client.get_user_configurations(user_id)
|
||||
llm_config = user_configuration.model_dump(exclude_none=True).get("llm", {})
|
||||
|
||||
provider = llm_config.get("provider", "openai")
|
||||
api_key = llm_config.get("api_key", "")
|
||||
|
||||
qa_model = qa_node_data.get("qa_model", "default")
|
||||
if qa_model and qa_model != "default":
|
||||
model = qa_model
|
||||
else:
|
||||
model = llm_config.get("model", "gpt-4.1")
|
||||
|
||||
base_url = _provider_base_url(provider, llm_config.get("endpoint", ""))
|
||||
# For openrouter, prefer user-configured base_url if set
|
||||
if provider == "openrouter" and llm_config.get("base_url"):
|
||||
base_url = llm_config["base_url"]
|
||||
|
||||
return model, api_key, base_url
|
||||
|
||||
|
||||
async def run_qa_analysis(
|
||||
qa_node_data: dict[str, Any],
|
||||
workflow_run: WorkflowRunModel,
|
||||
workflow_run_id: int,
|
||||
) -> dict[str, Any]:
|
||||
"""Run QA analysis on a completed workflow run.
|
||||
|
||||
Args:
|
||||
qa_node_data: The QA node's data dict from workflow definition
|
||||
workflow_run: The workflow run model with logs and context
|
||||
workflow_run_id: The workflow run ID
|
||||
|
||||
Returns:
|
||||
Dict with tags, summary, score, raw_response
|
||||
"""
|
||||
# Extract transcript from logs
|
||||
logs = workflow_run.logs or {}
|
||||
rtf_events = logs.get("realtime_feedback_events", [])
|
||||
if not rtf_events:
|
||||
logger.warning(f"No realtime_feedback_events for run {workflow_run_id}")
|
||||
return {"error": "no_transcript", "tags": [], "summary": "", "score": None}
|
||||
|
||||
conversation = build_conversation_structure(rtf_events)
|
||||
transcript = format_transcript(conversation)
|
||||
if not transcript:
|
||||
logger.warning(f"Empty transcript for run {workflow_run_id}")
|
||||
return {"error": "empty_transcript", "tags": [], "summary": "", "score": None}
|
||||
|
||||
# Compute call metrics
|
||||
usage_info = workflow_run.usage_info or {}
|
||||
call_duration = usage_info.get("call_duration_seconds")
|
||||
metrics = compute_call_metrics(rtf_events, call_duration)
|
||||
|
||||
# Resolve LLM config
|
||||
system_prompt = qa_node_data.get("qa_system_prompt", "")
|
||||
if not system_prompt:
|
||||
logger.warning("No system prompt defined for QA Node")
|
||||
return {"error": "no_system_prompt", "tags": [], "summary": "", "score": None}
|
||||
|
||||
model, api_key, base_url = await _resolve_llm_config(qa_node_data, workflow_run)
|
||||
|
||||
if not api_key:
|
||||
logger.warning(
|
||||
f"No LLM API key configured for QA analysis on run {workflow_run_id}"
|
||||
)
|
||||
return {"error": "no_api_key", "tags": [], "summary": "", "score": None}
|
||||
|
||||
# Build messages
|
||||
system_content = system_prompt.replace("{metrics}", json.dumps(metrics, indent=2))
|
||||
messages = [
|
||||
{"role": "system", "content": system_content},
|
||||
{"role": "user", "content": f"## Transcript\n{transcript}"},
|
||||
]
|
||||
|
||||
# Call LLM
|
||||
client_kwargs: dict[str, Any] = {"api_key": api_key}
|
||||
if base_url:
|
||||
client_kwargs["base_url"] = base_url
|
||||
|
||||
client = AsyncOpenAI(**client_kwargs)
|
||||
|
||||
try:
|
||||
response = await client.chat.completions.create(
|
||||
model=model,
|
||||
messages=messages,
|
||||
temperature=0,
|
||||
)
|
||||
raw_response = response.choices[0].message.content
|
||||
except Exception as e:
|
||||
logger.error(f"QA LLM call failed for run {workflow_run_id}: {e}")
|
||||
return {"error": str(e), "tags": [], "summary": "", "score": None}
|
||||
|
||||
# Extract token usage from LLM response
|
||||
token_usage = None
|
||||
if response.usage:
|
||||
token_usage = {
|
||||
"prompt_tokens": response.usage.prompt_tokens or 0,
|
||||
"completion_tokens": response.usage.completion_tokens or 0,
|
||||
"total_tokens": response.usage.total_tokens or 0,
|
||||
"cache_read_input_tokens": getattr(
|
||||
response.usage, "cache_read_input_tokens", 0
|
||||
)
|
||||
or 0,
|
||||
"cache_creation_input_tokens": getattr(
|
||||
response.usage, "cache_creation_input_tokens", None
|
||||
),
|
||||
}
|
||||
|
||||
# Parse response
|
||||
result: dict[str, Any] = {"raw_response": raw_response, "model": model}
|
||||
if token_usage:
|
||||
result["token_usage"] = token_usage
|
||||
try:
|
||||
parsed = parse_llm_json(raw_response)
|
||||
result["tags"] = parsed.get("tags", [])
|
||||
result["summary"] = parsed.get("summary", "")
|
||||
result["score"] = parsed.get("call_quality_score")
|
||||
result["overall_sentiment"] = parsed.get("overall_sentiment")
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
result["tags"] = []
|
||||
result["summary"] = ""
|
||||
result["score"] = None
|
||||
|
||||
# Langfuse tracing — attach QA generation to the conversation trace
|
||||
_add_qa_span_to_conversation_trace(
|
||||
workflow_run, model, messages, raw_response, result
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def _add_qa_span_to_conversation_trace(
|
||||
workflow_run: WorkflowRunModel,
|
||||
model: str,
|
||||
messages: list[dict],
|
||||
raw_response: str,
|
||||
result: dict,
|
||||
):
|
||||
"""Attach the QA generation to the existing Langfuse conversation trace.
|
||||
|
||||
Uses OpenTelemetry directly to create a child span under the existing trace,
|
||||
matching the same attribute format used by the pipecat pipeline (gen_ai.*).
|
||||
"""
|
||||
try:
|
||||
from opentelemetry import trace as otel_trace
|
||||
from opentelemetry.trace import (
|
||||
NonRecordingSpan,
|
||||
SpanContext,
|
||||
TraceFlags,
|
||||
set_span_in_context,
|
||||
)
|
||||
|
||||
from api.services.pipecat.tracing_config import (
|
||||
is_tracing_enabled,
|
||||
setup_tracing_exporter,
|
||||
)
|
||||
from pipecat.utils.tracing.service_attributes import add_llm_span_attributes
|
||||
|
||||
if not is_tracing_enabled():
|
||||
return
|
||||
|
||||
# Ensure the OTEL exporter is initialized (idempotent — no-op if
|
||||
# already called in the pipeline process, required in the ARQ worker).
|
||||
setup_tracing_exporter()
|
||||
|
||||
gathered_context = workflow_run.gathered_context or {}
|
||||
trace_id = _extract_trace_id(gathered_context)
|
||||
if not trace_id:
|
||||
logger.debug("No trace_id found, skipping Langfuse QA trace")
|
||||
return
|
||||
|
||||
tracer = otel_trace.get_tracer("pipecat")
|
||||
|
||||
# Create a remote parent context from the existing trace ID
|
||||
parent_span_ctx = SpanContext(
|
||||
trace_id=int(trace_id, 16),
|
||||
span_id=0x1, # dummy parent span id
|
||||
is_remote=True,
|
||||
trace_flags=TraceFlags(0x01),
|
||||
)
|
||||
parent_ctx = set_span_in_context(NonRecordingSpan(parent_span_ctx))
|
||||
|
||||
# Create a child span under the existing trace
|
||||
with tracer.start_as_current_span(
|
||||
"qa-analysis",
|
||||
context=parent_ctx,
|
||||
) as span:
|
||||
add_llm_span_attributes(
|
||||
span,
|
||||
service_name="OpenAILLMService",
|
||||
model=model,
|
||||
operation_name="qa-analysis",
|
||||
messages=messages,
|
||||
output=raw_response,
|
||||
stream=False,
|
||||
parameters={"temperature": 0},
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to trace QA to Langfuse: {e}")
|
||||
|
|
@ -11,6 +11,7 @@ class NodeType(str, Enum):
|
|||
globalNode = "globalNode"
|
||||
trigger = "trigger"
|
||||
webhook = "webhook"
|
||||
qa = "qa"
|
||||
|
||||
|
||||
class Position(BaseModel):
|
||||
|
|
@ -68,6 +69,13 @@ class NodeDataDTO(BaseModel):
|
|||
custom_headers: Optional[list[CustomHeaderDTO]] = None
|
||||
payload_template: Optional[dict] = None
|
||||
retry_config: Optional[RetryConfigDTO] = None
|
||||
# QA node specific fields
|
||||
qa_enabled: bool = True
|
||||
qa_system_prompt: Optional[str] = None
|
||||
qa_model: Optional[str] = None
|
||||
qa_min_call_duration: int = 15
|
||||
qa_voicemail_calls: bool = False
|
||||
qa_sample_rate: int = 100
|
||||
|
||||
|
||||
class RFNodeDTO(BaseModel):
|
||||
|
|
@ -78,8 +86,8 @@ class RFNodeDTO(BaseModel):
|
|||
|
||||
@model_validator(mode="after")
|
||||
def _validate_prompt_required(self):
|
||||
"""Require prompt for all node types except trigger and webhook."""
|
||||
if self.type not in (NodeType.trigger, NodeType.webhook):
|
||||
"""Require prompt for all node types except trigger, webhook, and qa."""
|
||||
if self.type not in (NodeType.trigger, NodeType.webhook, NodeType.qa):
|
||||
if not self.data.prompt or len(self.data.prompt.strip()) == 0:
|
||||
raise ValueError("Prompt is required for non-trigger nodes")
|
||||
return self
|
||||
|
|
|
|||
|
|
@ -15,8 +15,6 @@ setup_logging()
|
|||
from arq import create_pool
|
||||
from arq.connections import ArqRedis, RedisSettings
|
||||
|
||||
from api.tasks.workflow_run_cost import calculate_workflow_run_cost
|
||||
|
||||
parsed_url = urlparse(REDIS_URL)
|
||||
|
||||
# Check if we're using TLS (rediss://)
|
||||
|
|
@ -55,7 +53,6 @@ from api.tasks.s3_upload import (
|
|||
|
||||
class WorkerSettings:
|
||||
functions = [
|
||||
calculate_workflow_run_cost,
|
||||
run_integrations_post_workflow_run,
|
||||
upload_voicemail_audio_to_s3,
|
||||
process_workflow_completion,
|
||||
|
|
|
|||
|
|
@ -1,5 +1,4 @@
|
|||
class FunctionNames:
|
||||
CALCULATE_WORKFLOW_RUN_COST = "calculate_workflow_run_cost"
|
||||
RUN_INTEGRATIONS_POST_WORKFLOW_RUN = "run_integrations_post_workflow_run"
|
||||
PROCESS_WORKFLOW_COMPLETION = "process_workflow_completion"
|
||||
UPLOAD_VOICEMAIL_AUDIO_TO_S3 = "upload_voicemail_audio_to_s3"
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
"""Execute webhook integrations after workflow run completion."""
|
||||
"""Execute integrations (QA analysis, webhooks) after workflow run completion."""
|
||||
|
||||
import random
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import httpx
|
||||
|
|
@ -8,22 +9,141 @@ from loguru import logger
|
|||
from api.constants import BACKEND_API_ENDPOINT
|
||||
from api.db import db_client
|
||||
from api.db.models import WorkflowRunModel
|
||||
from api.services.qa_analysis import run_qa_analysis
|
||||
from api.utils.credential_auth import build_auth_header
|
||||
from api.utils.template_renderer import render_template
|
||||
from pipecat.utils.enums import EndTaskReason
|
||||
from pipecat.utils.run_context import set_current_run_id
|
||||
|
||||
|
||||
def _should_skip_qa(
|
||||
node_data: dict,
|
||||
workflow_run: WorkflowRunModel,
|
||||
) -> str | None:
|
||||
"""Check whether QA analysis should be skipped for this call.
|
||||
|
||||
Returns a reason string if the call should be skipped, or None if it should proceed.
|
||||
"""
|
||||
# Check minimum call duration
|
||||
min_duration = node_data.get("qa_min_call_duration", 15)
|
||||
usage_info = workflow_run.usage_info or {}
|
||||
call_duration = usage_info.get("call_duration_seconds")
|
||||
if call_duration is not None and call_duration < min_duration:
|
||||
return f"call duration ({call_duration:.1f}s) below minimum ({min_duration}s)"
|
||||
|
||||
# Check voicemail calls
|
||||
qa_voicemail_calls = node_data.get("qa_voicemail_calls", False)
|
||||
if not qa_voicemail_calls:
|
||||
gathered_context = workflow_run.gathered_context or {}
|
||||
call_disposition = gathered_context.get("call_disposition", "")
|
||||
if call_disposition == EndTaskReason.VOICEMAIL_DETECTED.value:
|
||||
return "voicemail call and QA voicemail calls is disabled"
|
||||
|
||||
# Check sample rate
|
||||
sample_rate = node_data.get("qa_sample_rate", 100)
|
||||
if sample_rate < 100:
|
||||
roll = random.randint(1, 100)
|
||||
if roll > sample_rate:
|
||||
return f"excluded by sampling ({sample_rate}% sample rate, rolled {roll})"
|
||||
|
||||
return None
|
||||
|
||||
|
||||
async def _run_qa_nodes(
|
||||
qa_nodes: list[dict],
|
||||
workflow_run: WorkflowRunModel,
|
||||
workflow_run_id: int,
|
||||
) -> Dict[str, Any]:
|
||||
"""Run QA analysis for each enabled QA node and aggregate results.
|
||||
|
||||
Returns:
|
||||
Dict keyed by node ID with QA analysis results.
|
||||
"""
|
||||
results: Dict[str, Any] = {}
|
||||
|
||||
for node in qa_nodes:
|
||||
node_data = node.get("data", {})
|
||||
node_id = node.get("id", "unknown")
|
||||
node_name = node_data.get("name", "QA Analysis")
|
||||
|
||||
if not node_data.get("qa_enabled", True):
|
||||
logger.debug(f"QA node '{node_name}' is disabled, skipping")
|
||||
continue
|
||||
|
||||
skip_reason = _should_skip_qa(node_data, workflow_run)
|
||||
if skip_reason:
|
||||
logger.info(f"Skipping QA node '{node_name}' (#{node_id}): {skip_reason}")
|
||||
results[f"qa_{node_id}"] = {"skipped": True, "reason": skip_reason}
|
||||
continue
|
||||
|
||||
try:
|
||||
logger.info(f"Running QA analysis for node '{node_name}' (#{node_id})")
|
||||
result = await run_qa_analysis(node_data, workflow_run, workflow_run_id)
|
||||
results[f"qa_{node_id}"] = result
|
||||
logger.info(
|
||||
f"QA analysis complete for '{node_name}': "
|
||||
f"score={result.get('score')}, tags={len(result.get('tags', []))}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"QA analysis failed for node '{node_name}': {e}")
|
||||
results[f"qa_{node_id}"] = {"error": str(e)}
|
||||
|
||||
return results
|
||||
|
||||
|
||||
async def _update_usage_info_with_qa_tokens(
|
||||
workflow_run_id: int,
|
||||
workflow_run: WorkflowRunModel,
|
||||
qa_results: Dict[str, Any],
|
||||
) -> None:
|
||||
"""Add QA analysis LLM token usage to the workflow run's usage_info."""
|
||||
try:
|
||||
usage_info = dict(workflow_run.usage_info or {})
|
||||
llm_usage = dict(usage_info.get("llm", {}))
|
||||
|
||||
for _node_key, result in qa_results.items():
|
||||
token_usage = result.get("token_usage")
|
||||
model = result.get("model")
|
||||
if not token_usage or not model:
|
||||
continue
|
||||
|
||||
key = f"QAAnalysis|||{model}"
|
||||
if key in llm_usage:
|
||||
# Aggregate if multiple QA nodes use the same model
|
||||
existing = llm_usage[key]
|
||||
for field in (
|
||||
"prompt_tokens",
|
||||
"completion_tokens",
|
||||
"total_tokens",
|
||||
"cache_read_input_tokens",
|
||||
):
|
||||
existing[field] = (existing.get(field) or 0) + (
|
||||
token_usage.get(field) or 0
|
||||
)
|
||||
else:
|
||||
llm_usage[key] = token_usage
|
||||
|
||||
usage_info["llm"] = llm_usage
|
||||
await db_client.update_workflow_run(
|
||||
run_id=workflow_run_id, usage_info=usage_info
|
||||
)
|
||||
logger.info(f"Updated usage_info with QA token usage for run {workflow_run_id}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to update usage_info with QA tokens: {e}")
|
||||
|
||||
|
||||
async def run_integrations_post_workflow_run(_ctx, workflow_run_id: int):
|
||||
"""
|
||||
Run webhook integrations after a workflow run completes.
|
||||
Run integrations after a workflow run completes.
|
||||
|
||||
This function:
|
||||
1. Gets the workflow run and its contexts
|
||||
2. Extracts webhook nodes from workflow definition
|
||||
3. Executes each enabled webhook node
|
||||
2. Runs QA analysis nodes (if any)
|
||||
3. Stores QA results in annotations
|
||||
4. Executes webhook nodes with QA results available in render context
|
||||
"""
|
||||
set_current_run_id(workflow_run_id)
|
||||
logger.info("Running webhook integrations for workflow run")
|
||||
logger.info("Running integrations for workflow run")
|
||||
|
||||
try:
|
||||
# Step 1: Get workflow run with full context
|
||||
|
|
@ -36,39 +156,61 @@ async def run_integrations_post_workflow_run(_ctx, workflow_run_id: int):
|
|||
return
|
||||
|
||||
if not organization_id:
|
||||
logger.warning("No organization found, skipping webhooks")
|
||||
logger.warning("No organization found, skipping integrations")
|
||||
return
|
||||
|
||||
# Step 2: Get workflow definition
|
||||
workflow_definition = workflow_run.workflow.workflow_definition_with_fallback
|
||||
if not workflow_definition:
|
||||
logger.debug("No workflow definition, skipping webhooks")
|
||||
logger.debug("No workflow definition, skipping integrations")
|
||||
return
|
||||
|
||||
# Step 3: Extract webhook nodes
|
||||
# Step 3: Extract integration nodes
|
||||
nodes = workflow_definition.get("nodes", [])
|
||||
qa_nodes = [n for n in nodes if n.get("type") == "qa"]
|
||||
webhook_nodes = [n for n in nodes if n.get("type") == "webhook"]
|
||||
|
||||
# Step 4: Generate public access token if webhooks exist or campaign_id is set
|
||||
has_campaign = workflow_run.campaign_id is not None
|
||||
if not webhook_nodes and not has_campaign:
|
||||
logger.debug("No webhook nodes and no campaign, skipping")
|
||||
if not webhook_nodes and not qa_nodes and not has_campaign:
|
||||
logger.debug("No integration nodes and no campaign, skipping")
|
||||
return
|
||||
|
||||
public_token = None
|
||||
if webhook_nodes or has_campaign:
|
||||
public_token = await db_client.ensure_public_access_token(workflow_run_id)
|
||||
|
||||
# Step 5: Run QA analysis before webhooks
|
||||
if qa_nodes:
|
||||
logger.info(f"Found {len(qa_nodes)} QA nodes to execute")
|
||||
qa_results = await _run_qa_nodes(qa_nodes, workflow_run, workflow_run_id)
|
||||
|
||||
if qa_results:
|
||||
await db_client.update_workflow_run(
|
||||
workflow_run_id, annotations=qa_results
|
||||
)
|
||||
|
||||
# Add QA token usage to workflow run's usage_info
|
||||
await _update_usage_info_with_qa_tokens(
|
||||
workflow_run_id, workflow_run, qa_results
|
||||
)
|
||||
|
||||
# Re-fetch workflow_run to get updated annotations
|
||||
workflow_run, _ = await db_client.get_workflow_run_with_context(
|
||||
workflow_run_id
|
||||
)
|
||||
|
||||
# Step 6: Execute webhooks
|
||||
if not webhook_nodes:
|
||||
logger.debug("No webhook nodes in workflow")
|
||||
return
|
||||
|
||||
logger.info(f"Found {len(webhook_nodes)} webhook nodes to execute")
|
||||
|
||||
# Step 5: Build render context
|
||||
# Step 7: Build render context (includes annotations from QA)
|
||||
render_context = _build_render_context(workflow_run, public_token)
|
||||
|
||||
# Step 6: Execute each webhook node
|
||||
# Step 8: Execute each webhook node
|
||||
for node in webhook_nodes:
|
||||
webhook_data = node.get("data", {})
|
||||
try:
|
||||
|
|
@ -84,7 +226,7 @@ async def run_integrations_post_workflow_run(_ctx, workflow_run_id: int):
|
|||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error running webhook integrations: {e}", exc_info=True)
|
||||
logger.error(f"Error running integrations: {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
|
||||
|
|
@ -110,6 +252,8 @@ def _build_render_context(
|
|||
"initial_context": workflow_run.initial_context or {},
|
||||
"gathered_context": workflow_run.gathered_context or {},
|
||||
"cost_info": workflow_run.usage_info or {},
|
||||
# Annotations (includes QA results)
|
||||
"annotations": workflow_run.annotations or {},
|
||||
}
|
||||
|
||||
# Add public download URLs if token is available
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ from typing import Optional
|
|||
from loguru import logger
|
||||
|
||||
from api.db import db_client
|
||||
from api.services.pricing.workflow_run_cost import calculate_workflow_run_cost
|
||||
from api.services.storage import get_current_storage_backend, storage_fs
|
||||
from api.tasks.run_integrations import run_integrations_post_workflow_run
|
||||
from pipecat.utils.run_context import set_current_run_id
|
||||
|
|
@ -162,10 +163,16 @@ async def process_workflow_completion(
|
|||
except Exception as e:
|
||||
logger.warning(f"Failed to clean up temp transcript file: {e}")
|
||||
|
||||
# Step 3: Run webhook integrations (after uploads are complete)
|
||||
# Step 3: Run integrations including QA analysis (after uploads are complete)
|
||||
try:
|
||||
await run_integrations_post_workflow_run(_ctx, workflow_run_id)
|
||||
except Exception as e:
|
||||
logger.error(f"Error running integrations for workflow {workflow_run_id}: {e}")
|
||||
|
||||
# Step 4: Calculate cost after integrations (so QA token usage is included)
|
||||
try:
|
||||
await calculate_workflow_run_cost(workflow_run_id)
|
||||
except Exception as e:
|
||||
logger.error(f"Error calculating cost for workflow {workflow_run_id}: {e}")
|
||||
|
||||
logger.info(f"Completed workflow completion processing for run {workflow_run_id}")
|
||||
|
|
|
|||
|
|
@ -1,11 +1,18 @@
|
|||
"""
|
||||
Shared mock fixtures and workflow helpers for unit tests.
|
||||
|
||||
Database setup (test DB creation, migrations, session isolation) lives in
|
||||
the root api/conftest.py. This module provides lightweight, non-DB fixtures:
|
||||
- Mock objects (engine, workflow model, workflow run, user config, tools)
|
||||
- Pre-built WorkflowGraph fixtures for various node topologies
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, Optional
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
|
||||
|
||||
from api.constants import DATABASE_URL
|
||||
from api.services.workflow.dto import (
|
||||
EdgeDataDTO,
|
||||
ExtractionVariableDTO,
|
||||
|
|
@ -551,22 +558,3 @@ def three_node_workflow_no_variable_extraction() -> WorkflowGraph:
|
|||
],
|
||||
)
|
||||
return WorkflowGraph(dto)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Database fixtures for integration tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
async def db_engine():
|
||||
"""Create database engine for tests."""
|
||||
engine = create_async_engine(DATABASE_URL, echo=False)
|
||||
yield engine
|
||||
await engine.dispose()
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
async def db_session_factory(db_engine):
|
||||
"""Create session factory for tests."""
|
||||
return async_sessionmaker(bind=db_engine, expire_on_commit=False)
|
||||
|
|
|
|||
|
|
@ -15,6 +15,7 @@ from unittest.mock import AsyncMock, MagicMock, patch
|
|||
|
||||
import pytest
|
||||
from sqlalchemy import delete, text
|
||||
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
|
||||
|
||||
from api.db.models import (
|
||||
CampaignModel,
|
||||
|
|
@ -31,6 +32,35 @@ from api.services.campaign.campaign_call_dispatcher import CampaignCallDispatche
|
|||
# =============================================================================
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
async def db_session_factory(setup_test_database):
|
||||
"""
|
||||
Create a real session factory for campaign integration tests.
|
||||
|
||||
These tests need real database commits (not savepoints) to test
|
||||
concurrent SELECT FOR UPDATE SKIP LOCKED behavior across independent
|
||||
connections.
|
||||
|
||||
Patches db_client so CampaignCallDispatcher uses the test database.
|
||||
"""
|
||||
from api.db import db_client
|
||||
|
||||
test_url = setup_test_database
|
||||
engine = create_async_engine(test_url, echo=False)
|
||||
session_factory = async_sessionmaker(bind=engine, expire_on_commit=False)
|
||||
|
||||
original_engine = db_client.engine
|
||||
original_session = db_client.async_session
|
||||
db_client.engine = engine
|
||||
db_client.async_session = session_factory
|
||||
|
||||
yield session_factory
|
||||
|
||||
db_client.engine = original_engine
|
||||
db_client.async_session = original_session
|
||||
await engine.dispose()
|
||||
|
||||
|
||||
@dataclass
|
||||
class CampaignTestData:
|
||||
"""Container for campaign test data IDs"""
|
||||
|
|
|
|||
262
api/tests/test_workflow_qa_masking.py
Normal file
262
api/tests/test_workflow_qa_masking.py
Normal file
|
|
@ -0,0 +1,262 @@
|
|||
from api.services.configuration.masking import (
|
||||
mask_key,
|
||||
mask_workflow_definition,
|
||||
merge_workflow_api_keys,
|
||||
)
|
||||
|
||||
|
||||
def _make_workflow_def(nodes):
|
||||
"""Helper to build a minimal workflow definition dict."""
|
||||
return {"nodes": nodes, "edges": [], "viewport": {"x": 0, "y": 0, "zoom": 1}}
|
||||
|
||||
|
||||
def _qa_node(node_id="qa-1", api_key="", **extra_data):
|
||||
"""Helper to build a QA node."""
|
||||
data = {"name": "QA Analysis", "qa_enabled": True, **extra_data}
|
||||
if api_key:
|
||||
data["qa_api_key"] = api_key
|
||||
return {"id": node_id, "type": "qa", "position": {"x": 0, "y": 0}, "data": data}
|
||||
|
||||
|
||||
def _agent_node(node_id="agent-1"):
|
||||
"""Helper to build a non-QA node."""
|
||||
return {
|
||||
"id": node_id,
|
||||
"type": "agentNode",
|
||||
"position": {"x": 0, "y": 0},
|
||||
"data": {"name": "Agent", "prompt": "hello"},
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# mask_workflow_definition
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestMaskWorkflowDefinition:
|
||||
def test_masks_qa_api_key(self):
|
||||
"""QA node api_key is masked, showing only last 4 chars."""
|
||||
real_key = "sk-proj-abcdefghijklmnop"
|
||||
wf = _make_workflow_def([_qa_node(api_key=real_key)])
|
||||
|
||||
masked = mask_workflow_definition(wf)
|
||||
|
||||
masked_key = masked["nodes"][0]["data"]["qa_api_key"]
|
||||
assert masked_key == mask_key(real_key)
|
||||
assert masked_key.endswith("mnop")
|
||||
assert masked_key.startswith("*")
|
||||
assert real_key not in str(masked)
|
||||
|
||||
def test_does_not_mutate_original(self):
|
||||
"""The original workflow definition is not modified."""
|
||||
real_key = "sk-proj-abcdefghijklmnop"
|
||||
wf = _make_workflow_def([_qa_node(api_key=real_key)])
|
||||
|
||||
mask_workflow_definition(wf)
|
||||
|
||||
assert wf["nodes"][0]["data"]["qa_api_key"] == real_key
|
||||
|
||||
def test_non_qa_nodes_untouched(self):
|
||||
"""Non-QA nodes are not modified."""
|
||||
wf = _make_workflow_def([_agent_node(), _qa_node(api_key="sk-secret1234")])
|
||||
|
||||
masked = mask_workflow_definition(wf)
|
||||
|
||||
assert masked["nodes"][0]["type"] == "agentNode"
|
||||
assert "qa_api_key" not in masked["nodes"][0]["data"]
|
||||
assert masked["nodes"][1]["data"]["qa_api_key"] == mask_key("sk-secret1234")
|
||||
|
||||
def test_qa_node_without_api_key(self):
|
||||
"""QA node with no api_key is left as-is."""
|
||||
wf = _make_workflow_def([_qa_node()])
|
||||
|
||||
masked = mask_workflow_definition(wf)
|
||||
|
||||
assert "qa_api_key" not in masked["nodes"][0]["data"]
|
||||
|
||||
def test_qa_node_with_empty_api_key(self):
|
||||
"""QA node with empty string api_key is left as-is."""
|
||||
node = _qa_node()
|
||||
node["data"]["qa_api_key"] = ""
|
||||
wf = _make_workflow_def([node])
|
||||
|
||||
masked = mask_workflow_definition(wf)
|
||||
|
||||
assert masked["nodes"][0]["data"]["qa_api_key"] == ""
|
||||
|
||||
def test_multiple_qa_nodes(self):
|
||||
"""All QA nodes in a definition are masked."""
|
||||
wf = _make_workflow_def(
|
||||
[
|
||||
_qa_node(node_id="qa-1", api_key="key-aaaa1111"),
|
||||
_qa_node(node_id="qa-2", api_key="key-bbbb2222"),
|
||||
]
|
||||
)
|
||||
|
||||
masked = mask_workflow_definition(wf)
|
||||
|
||||
assert masked["nodes"][0]["data"]["qa_api_key"] == mask_key("key-aaaa1111")
|
||||
assert masked["nodes"][1]["data"]["qa_api_key"] == mask_key("key-bbbb2222")
|
||||
|
||||
def test_none_definition(self):
|
||||
"""None input returns None."""
|
||||
assert mask_workflow_definition(None) is None
|
||||
|
||||
def test_empty_definition(self):
|
||||
"""Empty dict returns empty dict."""
|
||||
assert mask_workflow_definition({}) == {}
|
||||
|
||||
def test_definition_without_nodes(self):
|
||||
"""Definition with no nodes key is returned as-is."""
|
||||
wf = {"edges": [], "viewport": {"x": 0, "y": 0, "zoom": 1}}
|
||||
result = mask_workflow_definition(wf)
|
||||
assert result == wf
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# merge_workflow_api_keys
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestMergeWorkflowApiKeys:
|
||||
def test_masked_key_is_restored(self):
|
||||
"""When incoming key matches the mask of the existing key, real key is preserved."""
|
||||
real_key = "sk-proj-abcdefghijklmnop"
|
||||
masked_val = mask_key(real_key)
|
||||
|
||||
existing = _make_workflow_def([_qa_node(api_key=real_key)])
|
||||
incoming = _make_workflow_def([_qa_node(api_key=masked_val)])
|
||||
|
||||
result = merge_workflow_api_keys(incoming, existing)
|
||||
|
||||
assert result["nodes"][0]["data"]["qa_api_key"] == real_key
|
||||
|
||||
def test_new_key_is_accepted(self):
|
||||
"""When user provides a brand new key, it replaces the old one."""
|
||||
old_key = "sk-proj-abcdefghijklmnop"
|
||||
new_key = "sk-proj-zyxwvutsrqponmlk"
|
||||
|
||||
existing = _make_workflow_def([_qa_node(api_key=old_key)])
|
||||
incoming = _make_workflow_def([_qa_node(api_key=new_key)])
|
||||
|
||||
result = merge_workflow_api_keys(incoming, existing)
|
||||
|
||||
assert result["nodes"][0]["data"]["qa_api_key"] == new_key
|
||||
|
||||
def test_no_existing_qa_node(self):
|
||||
"""New QA node with no prior existing node keeps incoming key."""
|
||||
new_key = "sk-brand-new-key1234"
|
||||
|
||||
existing = _make_workflow_def([_agent_node()])
|
||||
incoming = _make_workflow_def([_qa_node(api_key=new_key)])
|
||||
|
||||
result = merge_workflow_api_keys(incoming, existing)
|
||||
|
||||
assert result["nodes"][0]["data"]["qa_api_key"] == new_key
|
||||
|
||||
def test_no_incoming_api_key(self):
|
||||
"""QA node without api_key in incoming is left alone."""
|
||||
existing = _make_workflow_def([_qa_node(api_key="sk-existing-key1")])
|
||||
incoming = _make_workflow_def([_qa_node()])
|
||||
|
||||
result = merge_workflow_api_keys(incoming, existing)
|
||||
|
||||
assert "qa_api_key" not in result["nodes"][0]["data"]
|
||||
|
||||
def test_multiple_qa_nodes_matched_by_id(self):
|
||||
"""Multiple QA nodes are matched by node ID, not position."""
|
||||
key_1 = "sk-first-key-abcd1234"
|
||||
key_2 = "sk-second-key-efgh5678"
|
||||
|
||||
existing = _make_workflow_def(
|
||||
[
|
||||
_qa_node(node_id="qa-1", api_key=key_1),
|
||||
_qa_node(node_id="qa-2", api_key=key_2),
|
||||
]
|
||||
)
|
||||
incoming = _make_workflow_def(
|
||||
[
|
||||
_qa_node(node_id="qa-2", api_key=mask_key(key_2)),
|
||||
_qa_node(node_id="qa-1", api_key=mask_key(key_1)),
|
||||
]
|
||||
)
|
||||
|
||||
result = merge_workflow_api_keys(incoming, existing)
|
||||
|
||||
node_map = {n["id"]: n for n in result["nodes"]}
|
||||
assert node_map["qa-1"]["data"]["qa_api_key"] == key_1
|
||||
assert node_map["qa-2"]["data"]["qa_api_key"] == key_2
|
||||
|
||||
def test_none_incoming_returns_none(self):
|
||||
existing = _make_workflow_def([_qa_node(api_key="sk-key")])
|
||||
assert merge_workflow_api_keys(None, existing) is None
|
||||
|
||||
def test_none_existing_returns_incoming(self):
|
||||
incoming = _make_workflow_def([_qa_node(api_key="sk-key")])
|
||||
result = merge_workflow_api_keys(incoming, None)
|
||||
assert result["nodes"][0]["data"]["qa_api_key"] == "sk-key"
|
||||
|
||||
def test_non_qa_nodes_not_affected(self):
|
||||
"""Agent nodes pass through without modification."""
|
||||
existing = _make_workflow_def([_agent_node()])
|
||||
incoming = _make_workflow_def([_agent_node()])
|
||||
|
||||
result = merge_workflow_api_keys(incoming, existing)
|
||||
|
||||
assert result["nodes"][0]["type"] == "agentNode"
|
||||
|
||||
def test_existing_node_has_no_key(self):
|
||||
"""If existing QA node had no key, incoming key is kept."""
|
||||
new_key = "sk-new-key-abcd1234"
|
||||
|
||||
existing = _make_workflow_def([_qa_node()])
|
||||
incoming = _make_workflow_def([_qa_node(api_key=new_key)])
|
||||
|
||||
result = merge_workflow_api_keys(incoming, existing)
|
||||
|
||||
assert result["nodes"][0]["data"]["qa_api_key"] == new_key
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Round-trip: mask then merge
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestMaskAndMergeRoundTrip:
|
||||
def test_full_round_trip_preserves_key(self):
|
||||
"""Simulates: save real key → GET masks it → PUT sends masked → merge restores."""
|
||||
real_key = "sk-proj-WZRTVpVvZEXF5s0H4y8N5n2BF6lRZhC79Zq"
|
||||
|
||||
# 1. Real key stored in DB
|
||||
stored = _make_workflow_def(
|
||||
[
|
||||
_qa_node(api_key=real_key, qa_provider="openai", qa_model="gpt-4.1"),
|
||||
]
|
||||
)
|
||||
|
||||
# 2. GET response masks it
|
||||
fetched = mask_workflow_definition(stored)
|
||||
masked_key = fetched["nodes"][0]["data"]["qa_api_key"]
|
||||
assert masked_key != real_key
|
||||
assert masked_key.endswith(real_key[-4:])
|
||||
|
||||
# 3. User saves without changing the key (sends masked value back)
|
||||
incoming = fetched # same as what was fetched
|
||||
|
||||
# 4. PUT merges — real key is restored
|
||||
merged = merge_workflow_api_keys(incoming, stored)
|
||||
assert merged["nodes"][0]["data"]["qa_api_key"] == real_key
|
||||
|
||||
def test_round_trip_with_key_change(self):
|
||||
"""User changes the key mid-round-trip — new key is accepted."""
|
||||
old_key = "sk-old-key-abcdefgh"
|
||||
new_key = "sk-new-key-zyxwvuts"
|
||||
|
||||
stored = _make_workflow_def([_qa_node(api_key=old_key)])
|
||||
fetched = mask_workflow_definition(stored)
|
||||
|
||||
# User replaces the masked key with a new one
|
||||
fetched["nodes"][0]["data"]["qa_api_key"] = new_key
|
||||
|
||||
merged = merge_workflow_api_keys(fetched, stored)
|
||||
assert merged["nodes"][0]["data"]["qa_api_key"] == new_key
|
||||
Loading…
Add table
Add a link
Reference in a new issue