mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-06-16 08:25:18 +02:00
feat: add qa node in workflow builder (#172)
* feat: add qa node in workflow builder * feat: add qa analysis token usage in usage_info * fix: mask the API key in QA node * feat: add advanced configuration in QA node
This commit is contained in:
parent
f1f4830012
commit
a836825b83
30 changed files with 1619 additions and 265 deletions
|
|
@ -68,3 +68,65 @@ def mask_user_config(config: UserConfiguration) -> Dict[str, Any]:
|
|||
"test_phone_number": config.test_phone_number,
|
||||
"timezone": config.timezone,
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Workflow definition helpers – mask / merge QA-node API keys
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_QA_API_KEY_FIELD = "qa_api_key"
|
||||
|
||||
|
||||
def mask_workflow_definition(workflow_definition: Optional[Dict]) -> Optional[Dict]:
|
||||
"""Return a *shallow copy* of *workflow_definition* with QA-node API keys masked."""
|
||||
if not workflow_definition:
|
||||
return workflow_definition
|
||||
|
||||
import copy
|
||||
|
||||
masked = copy.deepcopy(workflow_definition)
|
||||
for node in masked.get("nodes", []):
|
||||
if node.get("type") != "qa":
|
||||
continue
|
||||
data = node.get("data", {})
|
||||
raw_key = data.get(_QA_API_KEY_FIELD)
|
||||
if raw_key:
|
||||
data[_QA_API_KEY_FIELD] = mask_key(raw_key)
|
||||
return masked
|
||||
|
||||
|
||||
def merge_workflow_api_keys(
|
||||
incoming_definition: Optional[Dict], existing_definition: Optional[Dict]
|
||||
) -> Optional[Dict]:
|
||||
"""Preserve real QA-node API keys when the incoming value is a masked placeholder.
|
||||
|
||||
For each QA node in *incoming_definition*, if its ``qa_api_key`` equals
|
||||
the masked form of the corresponding node in *existing_definition*, the
|
||||
real key is restored so it is never lost.
|
||||
"""
|
||||
if not incoming_definition or not existing_definition:
|
||||
return incoming_definition
|
||||
|
||||
# Build lookup: node-id → data for existing QA nodes
|
||||
existing_qa: Dict[str, Dict] = {}
|
||||
for node in existing_definition.get("nodes", []):
|
||||
if node.get("type") == "qa":
|
||||
existing_qa[node["id"]] = node.get("data", {})
|
||||
|
||||
for node in incoming_definition.get("nodes", []):
|
||||
if node.get("type") != "qa":
|
||||
continue
|
||||
data = node.get("data", {})
|
||||
incoming_key = data.get(_QA_API_KEY_FIELD)
|
||||
if not incoming_key:
|
||||
continue
|
||||
|
||||
old_data = existing_qa.get(node["id"])
|
||||
if not old_data:
|
||||
continue
|
||||
|
||||
old_key = old_data.get(_QA_API_KEY_FIELD, "")
|
||||
if old_key and is_mask_of(incoming_key, old_key):
|
||||
data[_QA_API_KEY_FIELD] = old_key
|
||||
|
||||
return incoming_definition
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ import asyncio
|
|||
from typing import Dict, Set
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.audio.utils import mix_audio
|
||||
from pipecat.frames.frames import (
|
||||
Frame,
|
||||
|
|
|
|||
|
|
@ -216,9 +216,8 @@ def register_event_handlers(
|
|||
except Exception as e:
|
||||
logger.error(f"Error preparing buffers for S3 upload: {e}", exc_info=True)
|
||||
|
||||
await enqueue_job(FunctionNames.CALCULATE_WORKFLOW_RUN_COST, workflow_run_id)
|
||||
|
||||
# Combined task: uploads artifacts then runs integrations sequentially
|
||||
# Combined task: uploads artifacts, runs integrations (including QA),
|
||||
# then calculates cost (so QA token usage is captured in usage_info)
|
||||
await enqueue_job(
|
||||
FunctionNames.PROCESS_WORKFLOW_COMPLETION,
|
||||
workflow_run_id,
|
||||
|
|
|
|||
|
|
@ -32,7 +32,7 @@ from api.services.pipecat.service_factory import (
|
|||
create_stt_service,
|
||||
create_tts_service,
|
||||
)
|
||||
from api.services.pipecat.tracing_config import setup_pipeline_tracing
|
||||
from api.services.pipecat.tracing_config import setup_tracing_exporter
|
||||
from api.services.pipecat.transport_setup import (
|
||||
create_ari_transport,
|
||||
create_cloudonix_transport,
|
||||
|
|
@ -80,7 +80,7 @@ from pipecat.utils.run_context import set_current_run_id
|
|||
from pipecat.utils.tracing.context_registry import ContextProviderRegistry
|
||||
|
||||
# Setup tracing if enabled
|
||||
setup_pipeline_tracing()
|
||||
setup_tracing_exporter()
|
||||
|
||||
|
||||
async def run_pipeline_twilio(
|
||||
|
|
|
|||
|
|
@ -4,9 +4,16 @@ import os
|
|||
from loguru import logger
|
||||
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
|
||||
|
||||
from api.constants import ENABLE_TRACING
|
||||
from api.constants import (
|
||||
ENABLE_TRACING,
|
||||
LANGFUSE_HOST,
|
||||
LANGFUSE_PUBLIC_KEY,
|
||||
LANGFUSE_SECRET_KEY,
|
||||
)
|
||||
from pipecat.utils.tracing.setup import setup_tracing
|
||||
|
||||
_tracing_initialized = False
|
||||
|
||||
|
||||
def is_tracing_enabled():
|
||||
"""Check if tracing should be enabled based on ENABLE_TRACING flag."""
|
||||
|
|
@ -15,28 +22,31 @@ def is_tracing_enabled():
|
|||
return ENABLE_TRACING
|
||||
|
||||
|
||||
def setup_pipeline_tracing():
|
||||
"""Setup tracing for the pipeline if enabled"""
|
||||
if is_tracing_enabled():
|
||||
# Only set up Langfuse if credentials are provided
|
||||
langfuse_host = os.environ.get("LANGFUSE_HOST")
|
||||
langfuse_public_key = os.environ.get("LANGFUSE_PUBLIC_KEY")
|
||||
langfuse_secret_key = os.environ.get("LANGFUSE_SECRET_KEY")
|
||||
def setup_tracing_exporter():
|
||||
"""Setup the OTEL tracing exporter for Langfuse if enabled.
|
||||
|
||||
if not all([langfuse_host, langfuse_public_key, langfuse_secret_key]):
|
||||
Idempotent — safe to call from both the pipeline process and the ARQ worker.
|
||||
"""
|
||||
global _tracing_initialized
|
||||
if _tracing_initialized:
|
||||
return
|
||||
|
||||
if is_tracing_enabled():
|
||||
if not all([LANGFUSE_HOST, LANGFUSE_PUBLIC_KEY, LANGFUSE_SECRET_KEY]):
|
||||
logger.warning(
|
||||
"Warning: ENABLE_TRACING is true but Langfuse credentials are not configured. Tracing disabled."
|
||||
)
|
||||
return
|
||||
|
||||
LANGFUSE_AUTH = base64.b64encode(
|
||||
f"{langfuse_public_key}:{langfuse_secret_key}".encode()
|
||||
langfuse_auth = base64.b64encode(
|
||||
f"{LANGFUSE_PUBLIC_KEY}:{LANGFUSE_SECRET_KEY}".encode()
|
||||
).decode()
|
||||
|
||||
os.environ["OTEL_EXPORTER_OTLP_ENDPOINT"] = f"{langfuse_host}/api/public/otel"
|
||||
os.environ["OTEL_EXPORTER_OTLP_ENDPOINT"] = f"{LANGFUSE_HOST}/api/public/otel"
|
||||
os.environ["OTEL_EXPORTER_OTLP_HEADERS"] = (
|
||||
f"Authorization=Basic {LANGFUSE_AUTH}"
|
||||
f"Authorization=Basic {langfuse_auth}"
|
||||
)
|
||||
|
||||
otlp_exporter = OTLPSpanExporter()
|
||||
setup_tracing(service_name="dograh-pipeline", exporter=otlp_exporter)
|
||||
_tracing_initialized = True
|
||||
|
|
|
|||
153
api/services/pricing/workflow_run_cost.py
Normal file
153
api/services/pricing/workflow_run_cost.py
Normal file
|
|
@ -0,0 +1,153 @@
|
|||
from loguru import logger
|
||||
|
||||
from api.db import db_client
|
||||
from api.enums import WorkflowRunMode
|
||||
from api.services.pricing.cost_calculator import cost_calculator
|
||||
from api.services.telephony.factory import get_telephony_provider
|
||||
|
||||
|
||||
async def _fetch_telephony_cost(workflow_run) -> dict | None:
|
||||
"""Fetch telephony call cost. Returns a dict with cost_usd and provider_name, or None."""
|
||||
if (
|
||||
workflow_run.mode
|
||||
not in [WorkflowRunMode.TWILIO.value, WorkflowRunMode.VONAGE.value]
|
||||
or not workflow_run.cost_info
|
||||
):
|
||||
return None
|
||||
|
||||
call_id = workflow_run.cost_info.get("call_id")
|
||||
if not call_id:
|
||||
logger.warning(f"call_id not found in cost_info")
|
||||
return None
|
||||
|
||||
provider_name = workflow_run.mode.lower() if workflow_run.mode else ""
|
||||
|
||||
workflow = await db_client.get_workflow_by_id(workflow_run.workflow_id)
|
||||
if not workflow:
|
||||
logger.warning("Workflow not found for workflow run")
|
||||
raise Exception("Workflow not found")
|
||||
|
||||
provider = await get_telephony_provider(workflow.organization_id)
|
||||
call_cost_info = await provider.get_call_cost(call_id)
|
||||
|
||||
if call_cost_info.get("status") == "error":
|
||||
logger.error(
|
||||
f"Failed to fetch {provider_name} call cost: {call_cost_info.get('error')}"
|
||||
)
|
||||
return None
|
||||
|
||||
cost_usd = call_cost_info.get("cost_usd", 0.0)
|
||||
logger.info(
|
||||
f"{provider_name.title()} call cost: ${cost_usd:.6f} USD for call {call_id}"
|
||||
)
|
||||
return {"cost_usd": cost_usd, "provider_name": provider_name}
|
||||
|
||||
|
||||
async def _update_organization_usage(
|
||||
org, dograh_tokens: float, duration_seconds: float, charge_usd: float | None
|
||||
) -> None:
|
||||
"""Update organization usage after a workflow run."""
|
||||
org_id = org.id
|
||||
await db_client.update_usage_after_run(
|
||||
org_id, dograh_tokens, duration_seconds, charge_usd
|
||||
)
|
||||
if charge_usd is not None:
|
||||
logger.info(
|
||||
f"Updated organization usage with ${charge_usd:.2f} USD ({dograh_tokens} Dograh Tokens) and {duration_seconds}s duration for org {org_id}"
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
f"Updated organization usage with {dograh_tokens} Dograh Tokens and {duration_seconds}s duration for org {org_id}"
|
||||
)
|
||||
|
||||
|
||||
async def calculate_workflow_run_cost(workflow_run_id: int):
|
||||
logger.debug("Calculating cost for workflow run")
|
||||
|
||||
workflow_run = await db_client.get_workflow_run_by_id(workflow_run_id)
|
||||
if not workflow_run:
|
||||
logger.warning("Workflow run not found")
|
||||
return
|
||||
|
||||
workflow_usage_info = workflow_run.usage_info
|
||||
if not workflow_usage_info:
|
||||
logger.warning("No usage info available for workflow run")
|
||||
return
|
||||
|
||||
try:
|
||||
# Calculate cost breakdown
|
||||
cost_breakdown = cost_calculator.calculate_total_cost(workflow_usage_info)
|
||||
|
||||
# Fetch telephony call cost
|
||||
try:
|
||||
telephony_cost = await _fetch_telephony_cost(workflow_run)
|
||||
if telephony_cost:
|
||||
telephony_cost_usd = telephony_cost["cost_usd"]
|
||||
provider_name = telephony_cost["provider_name"]
|
||||
cost_breakdown["telephony_call"] = telephony_cost_usd
|
||||
cost_breakdown[f"{provider_name}_call"] = telephony_cost_usd
|
||||
cost_breakdown["total"] = (
|
||||
float(cost_breakdown["total"]) + telephony_cost_usd
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to fetch telephony call cost: {e}")
|
||||
# Don't fail the whole cost calculation if telephony API fails
|
||||
|
||||
# Store cost information back to the workflow run
|
||||
# Convert USD to Dograh Tokens (1 cent = 1 token)
|
||||
dograh_tokens = round(float(cost_breakdown["total"]) * 100, 2)
|
||||
|
||||
# Get organization to check if it has USD pricing
|
||||
org = None
|
||||
charge_usd = None
|
||||
if (
|
||||
workflow_run.workflow
|
||||
and workflow_run.workflow.user
|
||||
and workflow_run.workflow.user.selected_organization_id
|
||||
):
|
||||
org = await db_client.get_organization_by_id(
|
||||
workflow_run.workflow.user.selected_organization_id
|
||||
)
|
||||
|
||||
# Calculate USD cost if organization has pricing configured
|
||||
if org and org.price_per_second_usd:
|
||||
duration_seconds = workflow_usage_info.get("call_duration_seconds", 0)
|
||||
charge_usd = duration_seconds * org.price_per_second_usd
|
||||
|
||||
cost_info = {
|
||||
**workflow_run.cost_info,
|
||||
"cost_breakdown": cost_breakdown,
|
||||
"total_cost_usd": float(cost_breakdown["total"]),
|
||||
"dograh_token_usage": dograh_tokens,
|
||||
"calculated_at": workflow_run.created_at.isoformat(),
|
||||
"call_duration_seconds": workflow_usage_info["call_duration_seconds"],
|
||||
}
|
||||
|
||||
# Add USD cost if available
|
||||
if charge_usd is not None:
|
||||
cost_info["charge_usd"] = charge_usd
|
||||
cost_info["price_per_second_usd"] = org.price_per_second_usd
|
||||
|
||||
# Update workflow run with cost information
|
||||
await db_client.update_workflow_run(run_id=workflow_run_id, cost_info=cost_info)
|
||||
|
||||
# Update organization usage if applicable
|
||||
if org:
|
||||
try:
|
||||
duration_seconds = workflow_usage_info.get("call_duration_seconds", 0)
|
||||
await _update_organization_usage(
|
||||
org, dograh_tokens, duration_seconds, charge_usd
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to update organization usage for org {org.id}: {e}"
|
||||
)
|
||||
# Don't fail the whole task if usage update fails
|
||||
|
||||
logger.info(
|
||||
f"Calculated cost for workflow run: ${cost_breakdown['total']:.6f} USD ({dograh_tokens} Dograh Tokens)"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error calculating cost for workflow run: {e}")
|
||||
raise
|
||||
360
api/services/qa_analysis.py
Normal file
360
api/services/qa_analysis.py
Normal file
|
|
@ -0,0 +1,360 @@
|
|||
"""QA analysis service for post-call quality assessment.
|
||||
|
||||
Runs LLM-based analysis on call transcripts, traces under the same
|
||||
Langfuse trace as the conversation, and returns structured results.
|
||||
"""
|
||||
|
||||
import json
|
||||
import re
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from loguru import logger
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
from api.db import db_client
|
||||
from api.db.models import WorkflowRunModel
|
||||
from api.services.gen_ai.json_parser import parse_llm_json
|
||||
from pipecat.utils.enums import RealtimeFeedbackType
|
||||
|
||||
|
||||
def build_conversation_structure(logs: list[dict]) -> list[dict]:
|
||||
"""Transform raw call logs into a conversation structure for LLM QA analysis."""
|
||||
if not logs:
|
||||
return []
|
||||
|
||||
start_time = datetime.fromisoformat(logs[0]["timestamp"])
|
||||
|
||||
conversation = []
|
||||
for event in logs:
|
||||
if event["type"] == RealtimeFeedbackType.BOT_TEXT.value:
|
||||
speaker = "assistant"
|
||||
utterance_text = event["payload"]["text"]
|
||||
event_time = datetime.fromisoformat(event["payload"]["timestamp"])
|
||||
elif event["type"] == RealtimeFeedbackType.USER_TRANSCRIPTION.value and event[
|
||||
"payload"
|
||||
].get("final", False):
|
||||
speaker = "user"
|
||||
utterance_text = event["payload"]["text"]
|
||||
event_time = datetime.fromisoformat(event["payload"]["timestamp"])
|
||||
else:
|
||||
continue
|
||||
|
||||
time_from_start = (event_time - start_time).total_seconds()
|
||||
|
||||
conversation.append(
|
||||
{
|
||||
"time_from_start_seconds": round(time_from_start, 2),
|
||||
"speaker": speaker,
|
||||
"text": utterance_text,
|
||||
"node_name": event.get("node_name", ""),
|
||||
"turn": event.get("turn", 0),
|
||||
}
|
||||
)
|
||||
|
||||
return conversation
|
||||
|
||||
|
||||
def format_transcript(conversation: list[dict]) -> str:
|
||||
"""Format conversation structure into a readable transcript string for the LLM."""
|
||||
lines = []
|
||||
for entry in conversation:
|
||||
lines.append(
|
||||
f"[{entry['time_from_start_seconds']:.1f}s] "
|
||||
f"{entry['speaker']}: {entry['text']}"
|
||||
)
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def compute_call_metrics(
|
||||
logs: list[dict], call_duration_seconds: float | None = None
|
||||
) -> dict:
|
||||
"""Pre-compute quantitative metrics from raw call logs."""
|
||||
latencies = []
|
||||
ttfb_values = []
|
||||
|
||||
for event in logs:
|
||||
if event["type"] == RealtimeFeedbackType.LATENCY_MEASURED.value:
|
||||
latencies.append(event["payload"]["latency_seconds"])
|
||||
elif event["type"] == RealtimeFeedbackType.TTFB_METRIC.value:
|
||||
ttfb_values.append(event["payload"]["ttfb_seconds"])
|
||||
|
||||
turns = set()
|
||||
for event in logs:
|
||||
if event["type"] in (
|
||||
RealtimeFeedbackType.USER_TRANSCRIPTION.value,
|
||||
RealtimeFeedbackType.BOT_TEXT.value,
|
||||
):
|
||||
turns.add(event.get("turn", 0))
|
||||
|
||||
return {
|
||||
"call_duration_seconds": call_duration_seconds,
|
||||
"num_turns": len(turns),
|
||||
"avg_latency_seconds": (
|
||||
round(sum(latencies) / len(latencies), 2) if latencies else None
|
||||
),
|
||||
"avg_ttfb_seconds": (
|
||||
round(sum(ttfb_values) / len(ttfb_values), 2) if ttfb_values else None
|
||||
),
|
||||
"max_latency_seconds": round(max(latencies), 2) if latencies else None,
|
||||
}
|
||||
|
||||
|
||||
def _extract_trace_id(gathered_context: dict) -> str | None:
|
||||
"""Extract Langfuse trace_id from gathered_context trace_url.
|
||||
|
||||
URL format: https://langfuse.dograh.com/project/<project_id>/traces/<trace_id>
|
||||
"""
|
||||
trace_url = gathered_context.get("trace_url")
|
||||
if not trace_url:
|
||||
return None
|
||||
try:
|
||||
match = re.search(r"/traces/([a-fA-F0-9]+)$", trace_url)
|
||||
if match:
|
||||
return match.group(1)
|
||||
except Exception:
|
||||
pass
|
||||
return None
|
||||
|
||||
|
||||
def _provider_base_url(provider: str | None, endpoint: str = "") -> str | None:
|
||||
"""Return the base URL for a given LLM provider."""
|
||||
if provider == "openrouter":
|
||||
return "https://openrouter.ai/api/v1"
|
||||
if provider == "groq":
|
||||
return "https://api.groq.com/openai/v1"
|
||||
if provider == "google":
|
||||
return "https://generativelanguage.googleapis.com/v1beta/openai/"
|
||||
if provider == "azure":
|
||||
return endpoint or None
|
||||
return None
|
||||
|
||||
|
||||
async def _resolve_llm_config(
|
||||
qa_node_data: dict, workflow_run: WorkflowRunModel
|
||||
) -> tuple[str, str, str | None]:
|
||||
"""Resolve the LLM model, API key, and base URL for QA analysis.
|
||||
|
||||
If the QA node has its own LLM configuration (qa_use_workflow_llm=False),
|
||||
use those settings directly. Otherwise, fall back to the user's configured LLM.
|
||||
|
||||
Returns:
|
||||
(model, api_key, base_url) tuple
|
||||
"""
|
||||
if not qa_node_data.get("qa_use_workflow_llm", True):
|
||||
return (
|
||||
qa_node_data.get("qa_model"),
|
||||
qa_node_data.get("qa_api_key"),
|
||||
_provider_base_url(
|
||||
qa_node_data.get("qa_provider"),
|
||||
qa_node_data.get("qa_endpoint", ""),
|
||||
),
|
||||
)
|
||||
|
||||
# Fall back to user's configured LLM
|
||||
user_id = None
|
||||
if workflow_run.workflow and workflow_run.workflow.user:
|
||||
user_id = workflow_run.workflow.user.id
|
||||
|
||||
llm_config: dict = {}
|
||||
if user_id:
|
||||
user_configuration = await db_client.get_user_configurations(user_id)
|
||||
llm_config = user_configuration.model_dump(exclude_none=True).get("llm", {})
|
||||
|
||||
provider = llm_config.get("provider", "openai")
|
||||
api_key = llm_config.get("api_key", "")
|
||||
|
||||
qa_model = qa_node_data.get("qa_model", "default")
|
||||
if qa_model and qa_model != "default":
|
||||
model = qa_model
|
||||
else:
|
||||
model = llm_config.get("model", "gpt-4.1")
|
||||
|
||||
base_url = _provider_base_url(provider, llm_config.get("endpoint", ""))
|
||||
# For openrouter, prefer user-configured base_url if set
|
||||
if provider == "openrouter" and llm_config.get("base_url"):
|
||||
base_url = llm_config["base_url"]
|
||||
|
||||
return model, api_key, base_url
|
||||
|
||||
|
||||
async def run_qa_analysis(
|
||||
qa_node_data: dict[str, Any],
|
||||
workflow_run: WorkflowRunModel,
|
||||
workflow_run_id: int,
|
||||
) -> dict[str, Any]:
|
||||
"""Run QA analysis on a completed workflow run.
|
||||
|
||||
Args:
|
||||
qa_node_data: The QA node's data dict from workflow definition
|
||||
workflow_run: The workflow run model with logs and context
|
||||
workflow_run_id: The workflow run ID
|
||||
|
||||
Returns:
|
||||
Dict with tags, summary, score, raw_response
|
||||
"""
|
||||
# Extract transcript from logs
|
||||
logs = workflow_run.logs or {}
|
||||
rtf_events = logs.get("realtime_feedback_events", [])
|
||||
if not rtf_events:
|
||||
logger.warning(f"No realtime_feedback_events for run {workflow_run_id}")
|
||||
return {"error": "no_transcript", "tags": [], "summary": "", "score": None}
|
||||
|
||||
conversation = build_conversation_structure(rtf_events)
|
||||
transcript = format_transcript(conversation)
|
||||
if not transcript:
|
||||
logger.warning(f"Empty transcript for run {workflow_run_id}")
|
||||
return {"error": "empty_transcript", "tags": [], "summary": "", "score": None}
|
||||
|
||||
# Compute call metrics
|
||||
usage_info = workflow_run.usage_info or {}
|
||||
call_duration = usage_info.get("call_duration_seconds")
|
||||
metrics = compute_call_metrics(rtf_events, call_duration)
|
||||
|
||||
# Resolve LLM config
|
||||
system_prompt = qa_node_data.get("qa_system_prompt", "")
|
||||
if not system_prompt:
|
||||
logger.warning("No system prompt defined for QA Node")
|
||||
return {"error": "no_system_prompt", "tags": [], "summary": "", "score": None}
|
||||
|
||||
model, api_key, base_url = await _resolve_llm_config(qa_node_data, workflow_run)
|
||||
|
||||
if not api_key:
|
||||
logger.warning(
|
||||
f"No LLM API key configured for QA analysis on run {workflow_run_id}"
|
||||
)
|
||||
return {"error": "no_api_key", "tags": [], "summary": "", "score": None}
|
||||
|
||||
# Build messages
|
||||
system_content = system_prompt.replace("{metrics}", json.dumps(metrics, indent=2))
|
||||
messages = [
|
||||
{"role": "system", "content": system_content},
|
||||
{"role": "user", "content": f"## Transcript\n{transcript}"},
|
||||
]
|
||||
|
||||
# Call LLM
|
||||
client_kwargs: dict[str, Any] = {"api_key": api_key}
|
||||
if base_url:
|
||||
client_kwargs["base_url"] = base_url
|
||||
|
||||
client = AsyncOpenAI(**client_kwargs)
|
||||
|
||||
try:
|
||||
response = await client.chat.completions.create(
|
||||
model=model,
|
||||
messages=messages,
|
||||
temperature=0,
|
||||
)
|
||||
raw_response = response.choices[0].message.content
|
||||
except Exception as e:
|
||||
logger.error(f"QA LLM call failed for run {workflow_run_id}: {e}")
|
||||
return {"error": str(e), "tags": [], "summary": "", "score": None}
|
||||
|
||||
# Extract token usage from LLM response
|
||||
token_usage = None
|
||||
if response.usage:
|
||||
token_usage = {
|
||||
"prompt_tokens": response.usage.prompt_tokens or 0,
|
||||
"completion_tokens": response.usage.completion_tokens or 0,
|
||||
"total_tokens": response.usage.total_tokens or 0,
|
||||
"cache_read_input_tokens": getattr(
|
||||
response.usage, "cache_read_input_tokens", 0
|
||||
)
|
||||
or 0,
|
||||
"cache_creation_input_tokens": getattr(
|
||||
response.usage, "cache_creation_input_tokens", None
|
||||
),
|
||||
}
|
||||
|
||||
# Parse response
|
||||
result: dict[str, Any] = {"raw_response": raw_response, "model": model}
|
||||
if token_usage:
|
||||
result["token_usage"] = token_usage
|
||||
try:
|
||||
parsed = parse_llm_json(raw_response)
|
||||
result["tags"] = parsed.get("tags", [])
|
||||
result["summary"] = parsed.get("summary", "")
|
||||
result["score"] = parsed.get("call_quality_score")
|
||||
result["overall_sentiment"] = parsed.get("overall_sentiment")
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
result["tags"] = []
|
||||
result["summary"] = ""
|
||||
result["score"] = None
|
||||
|
||||
# Langfuse tracing — attach QA generation to the conversation trace
|
||||
_add_qa_span_to_conversation_trace(
|
||||
workflow_run, model, messages, raw_response, result
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def _add_qa_span_to_conversation_trace(
|
||||
workflow_run: WorkflowRunModel,
|
||||
model: str,
|
||||
messages: list[dict],
|
||||
raw_response: str,
|
||||
result: dict,
|
||||
):
|
||||
"""Attach the QA generation to the existing Langfuse conversation trace.
|
||||
|
||||
Uses OpenTelemetry directly to create a child span under the existing trace,
|
||||
matching the same attribute format used by the pipecat pipeline (gen_ai.*).
|
||||
"""
|
||||
try:
|
||||
from opentelemetry import trace as otel_trace
|
||||
from opentelemetry.trace import (
|
||||
NonRecordingSpan,
|
||||
SpanContext,
|
||||
TraceFlags,
|
||||
set_span_in_context,
|
||||
)
|
||||
|
||||
from api.services.pipecat.tracing_config import (
|
||||
is_tracing_enabled,
|
||||
setup_tracing_exporter,
|
||||
)
|
||||
from pipecat.utils.tracing.service_attributes import add_llm_span_attributes
|
||||
|
||||
if not is_tracing_enabled():
|
||||
return
|
||||
|
||||
# Ensure the OTEL exporter is initialized (idempotent — no-op if
|
||||
# already called in the pipeline process, required in the ARQ worker).
|
||||
setup_tracing_exporter()
|
||||
|
||||
gathered_context = workflow_run.gathered_context or {}
|
||||
trace_id = _extract_trace_id(gathered_context)
|
||||
if not trace_id:
|
||||
logger.debug("No trace_id found, skipping Langfuse QA trace")
|
||||
return
|
||||
|
||||
tracer = otel_trace.get_tracer("pipecat")
|
||||
|
||||
# Create a remote parent context from the existing trace ID
|
||||
parent_span_ctx = SpanContext(
|
||||
trace_id=int(trace_id, 16),
|
||||
span_id=0x1, # dummy parent span id
|
||||
is_remote=True,
|
||||
trace_flags=TraceFlags(0x01),
|
||||
)
|
||||
parent_ctx = set_span_in_context(NonRecordingSpan(parent_span_ctx))
|
||||
|
||||
# Create a child span under the existing trace
|
||||
with tracer.start_as_current_span(
|
||||
"qa-analysis",
|
||||
context=parent_ctx,
|
||||
) as span:
|
||||
add_llm_span_attributes(
|
||||
span,
|
||||
service_name="OpenAILLMService",
|
||||
model=model,
|
||||
operation_name="qa-analysis",
|
||||
messages=messages,
|
||||
output=raw_response,
|
||||
stream=False,
|
||||
parameters={"temperature": 0},
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to trace QA to Langfuse: {e}")
|
||||
|
|
@ -11,6 +11,7 @@ class NodeType(str, Enum):
|
|||
globalNode = "globalNode"
|
||||
trigger = "trigger"
|
||||
webhook = "webhook"
|
||||
qa = "qa"
|
||||
|
||||
|
||||
class Position(BaseModel):
|
||||
|
|
@ -68,6 +69,13 @@ class NodeDataDTO(BaseModel):
|
|||
custom_headers: Optional[list[CustomHeaderDTO]] = None
|
||||
payload_template: Optional[dict] = None
|
||||
retry_config: Optional[RetryConfigDTO] = None
|
||||
# QA node specific fields
|
||||
qa_enabled: bool = True
|
||||
qa_system_prompt: Optional[str] = None
|
||||
qa_model: Optional[str] = None
|
||||
qa_min_call_duration: int = 15
|
||||
qa_voicemail_calls: bool = False
|
||||
qa_sample_rate: int = 100
|
||||
|
||||
|
||||
class RFNodeDTO(BaseModel):
|
||||
|
|
@ -78,8 +86,8 @@ class RFNodeDTO(BaseModel):
|
|||
|
||||
@model_validator(mode="after")
|
||||
def _validate_prompt_required(self):
|
||||
"""Require prompt for all node types except trigger and webhook."""
|
||||
if self.type not in (NodeType.trigger, NodeType.webhook):
|
||||
"""Require prompt for all node types except trigger, webhook, and qa."""
|
||||
if self.type not in (NodeType.trigger, NodeType.webhook, NodeType.qa):
|
||||
if not self.data.prompt or len(self.data.prompt.strip()) == 0:
|
||||
raise ValueError("Prompt is required for non-trigger nodes")
|
||||
return self
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue