feat: add AWS Bedrock support

This commit is contained in:
Abhishek Kumar 2026-03-19 15:06:59 +05:30
parent 1604e306ec
commit fe84f086ba
30 changed files with 546 additions and 195 deletions

View file

@ -206,7 +206,7 @@ class WorkflowClient(BaseDBClient):
async def update_workflow(
self,
workflow_id: int,
name: str,
name: str | None,
workflow_definition: dict | None,
template_context_variables: dict | None,
workflow_configurations: dict | None,
@ -249,7 +249,8 @@ class WorkflowClient(BaseDBClient):
if not workflow:
raise ValueError(f"Workflow with ID {workflow_id} not found")
workflow.name = name
if name is not None:
workflow.name = name
if template_context_variables is not None:
workflow.template_context_variables = template_context_variables

View file

@ -108,9 +108,7 @@ async def get_mps_credits(user: UserModel = Depends(get_user)):
)
else:
if not user.selected_organization_id:
raise HTTPException(
status_code=400, detail="No organization selected"
)
raise HTTPException(status_code=400, detail="No organization selected")
usage = await mps_service_key_client.get_usage_by_organization(
user.selected_organization_id
)

View file

@ -71,10 +71,10 @@ async def get_auth_user(
class UserConfigurationRequestResponseSchema(BaseModel):
llm: dict[str, Union[str, float, list[str]]] | None = None
tts: dict[str, Union[str, float, list[str]]] | None = None
stt: dict[str, Union[str, float, list[str]]] | None = None
embeddings: dict[str, Union[str, float, list[str]]] | None = None
llm: dict[str, Union[str, float, list[str], None]] | None = None
tts: dict[str, Union[str, float, list[str], None]] | None = None
stt: dict[str, Union[str, float, list[str], None]] | None = None
embeddings: dict[str, Union[str, float, list[str], None]] | None = None
test_phone_number: str | None = None
timezone: str | None = None
organization_pricing: dict[str, Union[float, str, bool]] | None = None

View file

@ -138,7 +138,7 @@ class DuplicateTemplateRequest(BaseModel):
class UpdateWorkflowRequest(BaseModel):
name: str
name: str | None = None
workflow_definition: dict | None = None
template_context_variables: dict | None = None
workflow_configurations: dict | None = None

View file

@ -38,6 +38,7 @@ class UserConfigurationValidator:
ServiceProviders.DOGRAH.value: self._check_dograh_api_key,
ServiceProviders.SARVAM.value: self._check_sarvam_api_key,
ServiceProviders.SPEECHMATICS.value: self._check_speechmatics_api_key,
ServiceProviders.AWS_BEDROCK.value: self._check_aws_bedrock_api_key,
}
async def validate(self, configuration: UserConfiguration) -> APIKeyStatusResponse:
@ -71,6 +72,21 @@ class UserConfigurationValidator:
return [] # Optional service not configured is OK
provider = service_config.provider
# AWS Bedrock uses AWS credentials instead of api_key
if provider == ServiceProviders.AWS_BEDROCK.value:
try:
if not self._check_aws_bedrock_api_key(provider, service_config):
return [
{
"model": service_name,
"message": f"Invalid {provider} credentials",
}
]
except ValueError as e:
return [{"model": service_name, "message": str(e)}]
return []
api_key = service_config.api_key
try:
@ -143,3 +159,8 @@ class UserConfigurationValidator:
def _check_speechmatics_api_key(self, model: str, api_key: str) -> bool:
return True
def _check_aws_bedrock_api_key(self, model: str, service_config) -> bool:
if not service_config.aws_access_key or not service_config.aws_secret_key:
raise ValueError("AWS access key and secret key are required for Bedrock")
return True

View file

@ -25,6 +25,7 @@ class ServiceProviders(str, Enum):
DOGRAH = "dograh"
SARVAM = "sarvam"
SPEECHMATICS = "speechmatics"
AWS_BEDROCK = "aws_bedrock"
class BaseServiceConfiguration(BaseModel):
@ -37,6 +38,7 @@ class BaseServiceConfiguration(BaseModel):
ServiceProviders.GOOGLE,
ServiceProviders.AZURE,
ServiceProviders.DOGRAH,
ServiceProviders.AWS_BEDROCK,
# ServiceProviders.SARVAM,
]
api_key: str | list[str]
@ -44,6 +46,8 @@ class BaseServiceConfiguration(BaseModel):
@field_validator("api_key")
@classmethod
def validate_api_key(cls, v):
if v is None:
return v
if isinstance(v, list) and len(v) == 0:
raise ValueError("api_key list must not be empty")
return v
@ -51,6 +55,8 @@ class BaseServiceConfiguration(BaseModel):
def __getattribute__(self, name: str):
if name == "api_key":
value = super().__getattribute__(name)
if value is None:
return value
if isinstance(value, list):
return random.choice(value)
return value
@ -59,6 +65,8 @@ class BaseServiceConfiguration(BaseModel):
def get_all_api_keys(self) -> list[str]:
"""Get all API keys as a list (bypasses random selection)."""
value = super().__getattribute__("api_key")
if value is None:
return []
if isinstance(value, list):
return list(value)
return [value]
@ -167,6 +175,14 @@ OPENROUTER_MODELS = [
]
AZURE_MODELS = ["gpt-4.1-mini"]
DOGRAH_LLM_MODELS = ["default", "accurate", "fast", "lite", "zen"]
AWS_BEDROCK_MODELS = [
"us.amazon.nova-pro-v1:0",
"us.amazon.nova-lite-v1:0",
"us.amazon.nova-micro-v1:0",
"us.anthropic.claude-sonnet-4-20250514-v1:0",
"us.anthropic.claude-3-5-sonnet-20241022-v2:0",
"us.anthropic.claude-haiku-4-5-20251001-v1:0",
]
@register_llm
@ -219,6 +235,19 @@ class DograhLLMService(BaseLLMConfiguration):
)
@register_llm
class AWSBedrockLLMConfiguration(BaseLLMConfiguration):
provider: Literal[ServiceProviders.AWS_BEDROCK] = ServiceProviders.AWS_BEDROCK
model: str = Field(
default="us.amazon.nova-pro-v1:0",
json_schema_extra={"examples": AWS_BEDROCK_MODELS},
)
aws_access_key: str = Field(default="")
aws_secret_key: str = Field(default="")
aws_region: str = Field(default="us-east-1")
api_key: str | list[str] | None = Field(default=None)
LLMConfig = Annotated[
Union[
OpenAILLMService,
@ -227,6 +256,7 @@ LLMConfig = Annotated[
GoogleLLMService,
AzureLLMService,
DograhLLMService,
AWSBedrockLLMConfiguration,
],
Field(discriminator="provider"),
]

View file

@ -12,7 +12,7 @@ from api.services.pipecat.pipeline_metrics_aggregator import PipelineMetricsAggr
from api.services.workflow.pipecat_engine import PipecatEngine
from api.tasks.arq import enqueue_job
from api.tasks.function_names import FunctionNames
from pipecat.frames.frames import Frame, LLMContextFrame
from pipecat.frames.frames import Frame, LLMContextFrame, TTSSpeakFrame
from pipecat.pipeline.task import PipelineTask
from pipecat.processors.audio.audio_buffer_processor import AudioBufferProcessor
from pipecat.utils.enums import EndTaskReason
@ -47,32 +47,44 @@ def register_event_handlers(
sample_rate=sample_rate,
num_channels=num_channels,
)
# Track both events to ensure LLM is only triggered after both occur
# Track both events to ensure the initial response is only triggered after both occur
ready_state = {
"pipeline_started": False,
"client_connected": False,
"llm_triggered": False,
"initial_response_triggered": False,
}
async def maybe_trigger_llm():
"""Trigger LLM only after both pipeline_started and client_connected events."""
async def maybe_trigger_initial_response():
"""Start the conversation after both pipeline_started and client_connected events.
If the start node has a greeting configured, play it directly via TTS.
Otherwise, trigger an LLM generation for the opening message.
"""
if (
ready_state["pipeline_started"]
and ready_state["client_connected"]
and not ready_state["llm_triggered"]
and not ready_state["initial_response_triggered"]
):
ready_state["llm_triggered"] = True
logger.debug(
"Both pipeline_started and client_connected received - triggering initial LLM generation"
)
await engine.llm.queue_frame(LLMContextFrame(engine.context))
ready_state["initial_response_triggered"] = True
greeting = engine.get_start_greeting()
if greeting:
logger.debug(
"Both pipeline_started and client_connected received - playing greeting via TTS"
)
await task.queue_frame(TTSSpeakFrame(greeting))
else:
logger.debug(
"Both pipeline_started and client_connected received - triggering initial LLM generation"
)
await engine.llm.queue_frame(LLMContextFrame(engine.context))
@transport.event_handler("on_client_connected")
async def on_client_connected(_transport, _participant):
logger.debug("In on_client_connected callback handler")
await audio_buffer.start_recording()
ready_state["client_connected"] = True
await maybe_trigger_llm()
await maybe_trigger_initial_response()
@transport.event_handler("on_client_disconnected")
async def on_client_disconnected(_transport, _participant):
@ -93,7 +105,7 @@ def register_event_handlers(
async def on_pipeline_started(_task: PipelineTask, _frame: Frame):
logger.debug("In on_pipeline_started callback handler")
ready_state["pipeline_started"] = True
await maybe_trigger_llm()
await maybe_trigger_initial_response()
@task.event_handler("on_pipeline_error")
async def on_pipeline_error(_task: PipelineTask, frame: Frame):

View file

@ -74,9 +74,16 @@ def build_pipeline(
if recording_router:
post_llm.append(recording_router)
processors.append(user_context_aggregator)
# Insert LLM gate before the main LLM when voicemail detection is enabled.
# This prevents the main LLM from being triggered until classification
# determines whether a human or voicemail answered the call.
if voicemail_detector:
processors.append(voicemail_detector.llm_gate())
processors.extend(
[
user_context_aggregator,
llm, # LLM
*post_llm,
tts, # TTS

View file

@ -41,6 +41,7 @@ from pipecat.frames.frames import (
MetricsFrame,
StopFrame,
TranscriptionFrame,
TTSSpeakFrame,
)
from pipecat.metrics.metrics import TTFBMetricsData
from pipecat.observers.base_observer import BaseObserver, FramePushed
@ -205,6 +206,17 @@ class RealtimeFeedbackObserver(BaseObserver):
},
}
)
# Handle TTSSpeakFrame (e.g. greeting) - send immediately via WS only
# Final turn text is persisted via on_assistant_turn_stopped to avoid duplication
elif isinstance(frame, TTSSpeakFrame):
await self._send_ws(
{
"type": RealtimeFeedbackType.BOT_TEXT.value,
"payload": {
"text": frame.text,
},
}
)
# Handle bot TTS text - respect pts timing, WebSocket only
# Complete turn text is persisted via register_turn_handlers
elif isinstance(frame, LLMTextFrame):

View file

@ -173,7 +173,9 @@ async def _download_and_convert(
Returns the processed PCM bytes, or None on failure.
"""
ext = _ext_from_key(recording.storage_key)
fd, tmp_path = tempfile.mkstemp(suffix=ext, prefix=f"dograh_dl_{recording.recording_id}_")
fd, tmp_path = tempfile.mkstemp(
suffix=ext, prefix=f"dograh_dl_{recording.recording_id}_"
)
os.close(fd)
try:
storage = get_storage_fn(recording.storage_backend)

View file

@ -34,6 +34,7 @@ from api.services.pipecat.recording_audio_cache import (
from api.services.pipecat.recording_router_processor import RecordingRouterProcessor
from api.services.pipecat.service_factory import (
create_llm_service,
create_llm_service_from_provider,
create_stt_service,
create_tts_service,
)
@ -669,18 +670,31 @@ async def _run_pipeline(
async def on_user_turn_started(aggregator, strategy):
user_idle_handler.reset()
# Create voicemail detector if enabled in the workflow's start node
# Create voicemail detector if enabled in workflow configurations
voicemail_detector = None
start_node = workflow_graph.nodes.get(workflow_graph.start_node_id)
if start_node and start_node.detect_voicemail:
voicemail_config = (workflow.workflow_configurations or {}).get(
"voicemail_detection", {}
)
if voicemail_config.get("enabled", False):
logger.info(f"Voicemail detection enabled for workflow run {workflow_run_id}")
# Create a separate LLM instance for the voicemail sub-pipeline
# (can't share with main pipeline as it would mess up frame linking)
voicemail_llm = create_llm_service(user_config)
if voicemail_config.get("use_workflow_llm", True):
voicemail_llm = create_llm_service(user_config)
else:
voicemail_llm = create_llm_service_from_provider(
provider=voicemail_config.get("provider", "openai"),
model=voicemail_config.get("model", "gpt-4.1"),
api_key=voicemail_config.get("api_key", ""),
)
long_speech_timeout = voicemail_config.get("long_speech_timeout", 8.0)
custom_system_prompt = voicemail_config.get("system_prompt") or None
voicemail_detector = VoicemailDetector(
llm=voicemail_llm,
voicemail_response_delay=1.0,
long_speech_timeout=8.0,
long_speech_timeout=long_speech_timeout,
custom_system_prompt=custom_system_prompt,
)
# Register event handler to end task when voicemail is detected

View file

@ -5,6 +5,7 @@ from loguru import logger
from api.constants import MPS_API_URL
from api.services.configuration.registry import ServiceProviders
from pipecat.services.aws.llm import AWSBedrockLLMService, AWSBedrockLLMSettings
from pipecat.services.azure.llm import AzureLLMService, AzureLLMSettings
from pipecat.services.cartesia.stt import CartesiaSTTService
from pipecat.services.cartesia.tts import CartesiaTTSService, CartesiaTTSSettings
@ -268,56 +269,91 @@ def create_tts_service(user_config, audio_config: "AudioConfig"):
)
def create_llm_service(user_config):
"""Create and return appropriate LLM service based on user configuration"""
model = user_config.llm.model
logger.info(
f"Creating LLM service: provider={user_config.llm.provider}, model={model}"
)
if user_config.llm.provider == ServiceProviders.OPENAI.value:
def create_llm_service_from_provider(
provider: str,
model: str,
api_key: str,
*,
base_url: str | None = None,
endpoint: str | None = None,
aws_access_key: str | None = None,
aws_secret_key: str | None = None,
aws_region: str | None = None,
):
"""Create an LLM service from explicit provider/model/api_key.
Also used by create_llm_service which extracts these from user_config.
"""
logger.info(f"Creating LLM service: provider={provider}, model={model}")
if provider == ServiceProviders.OPENAI.value:
if "gpt-5" in model:
return OpenAILLMService(
api_key=user_config.llm.api_key,
api_key=api_key,
settings=OpenAILLMSettings(
model=model,
extra={"reasoning_effort": "minimal", "verbosity": "low"},
),
)
else:
return OpenAILLMService(
api_key=user_config.llm.api_key,
settings=OpenAILLMSettings(model=model, temperature=0.1),
)
elif user_config.llm.provider == ServiceProviders.GROQ.value:
print(
f"Creating Groq LLM service with API key: {user_config.llm.api_key} and model: {model}"
return OpenAILLMService(
api_key=api_key,
settings=OpenAILLMSettings(model=model, temperature=0.1),
)
elif provider == ServiceProviders.GROQ.value:
return GroqLLMService(
api_key=user_config.llm.api_key,
api_key=api_key,
settings=GroqLLMSettings(model=model, temperature=0.1),
)
elif user_config.llm.provider == ServiceProviders.OPENROUTER.value:
elif provider == ServiceProviders.OPENROUTER.value:
kwargs = {}
if base_url:
kwargs["base_url"] = base_url
return OpenRouterLLMService(
api_key=user_config.llm.api_key,
base_url=user_config.llm.base_url,
api_key=api_key,
settings=OpenRouterLLMSettings(model=model, temperature=0.1),
**kwargs,
)
elif user_config.llm.provider == ServiceProviders.GOOGLE.value:
elif provider == ServiceProviders.GOOGLE.value:
return GoogleLLMService(
api_key=user_config.llm.api_key,
api_key=api_key,
settings=GoogleLLMSettings(model=model, temperature=0.1),
)
elif user_config.llm.provider == ServiceProviders.AZURE.value:
elif provider == ServiceProviders.AZURE.value:
return AzureLLMService(
api_key=user_config.llm.api_key,
endpoint=user_config.llm.endpoint,
api_key=api_key,
endpoint=endpoint,
settings=AzureLLMSettings(model=model, temperature=0.1),
)
elif user_config.llm.provider == ServiceProviders.DOGRAH.value:
elif provider == ServiceProviders.DOGRAH.value:
return DograhLLMService(
base_url=f"{MPS_API_URL}/api/v1/llm",
api_key=user_config.llm.api_key,
api_key=api_key,
settings=OpenAILLMSettings(model=model),
)
elif provider == ServiceProviders.AWS_BEDROCK.value:
return AWSBedrockLLMService(
aws_access_key=aws_access_key,
aws_secret_key=aws_secret_key,
aws_region=aws_region,
settings=AWSBedrockLLMSettings(model=model),
)
else:
raise HTTPException(status_code=400, detail="Invalid LLM provider")
raise HTTPException(status_code=400, detail=f"Invalid LLM provider {provider}")
def create_llm_service(user_config):
"""Create and return appropriate LLM service based on user configuration."""
provider = user_config.llm.provider
model = user_config.llm.model
api_key = user_config.llm.api_key
kwargs = {}
if provider == ServiceProviders.OPENROUTER.value:
kwargs["base_url"] = user_config.llm.base_url
elif provider == ServiceProviders.AZURE.value:
kwargs["endpoint"] = user_config.llm.endpoint
elif provider == ServiceProviders.AWS_BEDROCK.value:
kwargs["aws_access_key"] = user_config.llm.aws_access_key
kwargs["aws_secret_key"] = user_config.llm.aws_secret_key
kwargs["aws_region"] = user_config.llm.aws_region
return create_llm_service_from_provider(provider, model, api_key, **kwargs)

View file

@ -53,6 +53,7 @@ class NodeDataDTO(BaseModel):
extraction_prompt: Optional[str] = None
extraction_variables: Optional[list[ExtractionVariableDTO]] = None
add_global_prompt: bool = True
greeting: Optional[str] = None
wait_for_user_response: bool = False
wait_for_user_response_timeout: Optional[float] = None
detect_voicemail: bool = False

View file

@ -554,6 +554,13 @@ class PipecatEngine:
# Setup LLM Context with Prompts and Functions
await self._setup_llm_context(node)
def get_start_greeting(self) -> Optional[str]:
"""Return the rendered greeting for the start node, or None if not configured."""
start_node = self.workflow.nodes.get(self.workflow.start_node_id)
if start_node and start_node.greeting:
return self._format_prompt(start_node.greeting)
return None
async def _handle_end_node(self, node: Node) -> None:
"""Handle end node execution."""
if node.is_static:

View file

@ -4,19 +4,16 @@ import json
from typing import Any
from loguru import logger
from openai import AsyncOpenAI
from api.db.models import WorkflowRunModel
from api.services.gen_ai.json_parser import parse_llm_json
from api.services.pipecat.service_factory import create_llm_service_from_provider
from api.services.workflow.qa.conversation import (
build_conversation_structure,
format_transcript,
split_events_by_node,
)
from api.services.workflow.qa.llm_config import (
accumulate_token_usage,
resolve_llm_config,
)
from api.services.workflow.qa.llm_config import resolve_llm_config
from api.services.workflow.qa.metrics import compute_call_metrics
from api.services.workflow.qa.node_summary import (
CONVERSATION_SUMMARY_SYSTEM_PROMPT,
@ -28,15 +25,22 @@ from api.services.workflow.qa.tracing import (
setup_langfuse_parent_context,
)
from api.utils.template_renderer import render_template
from pipecat.processors.aggregators.llm_context import LLMContext
async def _run_llm_inference(llm, messages: list[dict]) -> str | None:
"""Run a one-shot LLM inference using the pipecat service."""
context = LLMContext()
context.set_messages(messages)
return await llm.run_inference(context)
async def _generate_conversation_summary(
client: AsyncOpenAI,
llm,
model: str,
transcript: str,
parent_ctx,
node_name: str,
total_token_usage: dict,
) -> str:
"""Generate a summary of the conversation so far (before the current node).
@ -48,13 +52,7 @@ async def _generate_conversation_summary(
]
try:
response = await client.chat.completions.create(
model=model,
messages=messages,
temperature=0,
)
summary = response.choices[0].message.content or ""
accumulate_token_usage(total_token_usage, response)
summary = await _run_llm_inference(llm, messages) or ""
span_name = f"conversation-summary-before-{node_name}"
add_qa_span_to_trace(parent_ctx, model, messages, summary, span_name)
@ -82,7 +80,7 @@ async def run_per_node_qa_analysis(
Falls back to whole-call QA if events lack node_id.
Returns:
Dict with node_results, token_usage, model
Dict with node_results, model
"""
logs = workflow_run.logs or {}
rtf_events = logs.get("realtime_feedback_events", [])
@ -107,7 +105,9 @@ async def run_per_node_qa_analysis(
return {"error": "no_system_prompt", "node_results": {}}
# Resolve LLM config
model, api_key, base_url = await resolve_llm_config(qa_node_data, workflow_run)
provider, model, api_key, service_kwargs = await resolve_llm_config(
qa_node_data, workflow_run
)
if not api_key:
logger.warning(
f"No LLM API key configured for QA analysis on run {workflow_run_id}"
@ -122,13 +122,9 @@ async def run_per_node_qa_analysis(
# Set up Langfuse tracing
parent_ctx = setup_langfuse_parent_context(workflow_run)
# Build LLM client
client_kwargs: dict[str, Any] = {"api_key": api_key}
if base_url:
client_kwargs["base_url"] = base_url
client = AsyncOpenAI(**client_kwargs)
# Build LLM service
llm = create_llm_service_from_provider(provider, model, api_key, **service_kwargs)
total_token_usage: dict[str, int] = {}
node_results: dict[str, Any] = {}
prior_conversation: list[dict] = [] # Running accumulation of all prior nodes
@ -150,12 +146,11 @@ async def run_per_node_qa_analysis(
if idx > 0 and prior_conversation:
prior_transcript = format_transcript(prior_conversation)
previous_conversation_summary = await _generate_conversation_summary(
client,
llm,
model,
prior_transcript,
parent_ctx,
node_name,
total_token_usage,
)
# Substitute placeholders in the user's system prompt
@ -174,14 +169,7 @@ async def run_per_node_qa_analysis(
# Call QA LLM
try:
response = await client.chat.completions.create(
model=model,
messages=messages,
temperature=0,
extra_body={"stream": False},
)
raw_response = response.choices[0].message.content
accumulate_token_usage(total_token_usage, response)
raw_response = await _run_llm_inference(llm, messages)
except Exception as e:
logger.error(
f"QA LLM call failed for node '{node_name}' on run {workflow_run_id}: {e}"
@ -221,13 +209,10 @@ async def run_per_node_qa_analysis(
# Append this node's conversation to running total
prior_conversation.extend(node_conversation)
result: dict[str, Any] = {
return {
"node_results": node_results,
"model": model,
}
if total_token_usage:
result["token_usage"] = total_token_usage
return result
async def _run_whole_call_qa_analysis(
@ -262,7 +247,9 @@ async def _run_whole_call_qa_analysis(
logger.warning("No system prompt defined for QA Node")
return {"error": "no_system_prompt", "node_results": {}}
model, api_key, base_url = await resolve_llm_config(qa_node_data, workflow_run)
provider, model, api_key, service_kwargs = await resolve_llm_config(
qa_node_data, workflow_run
)
if not api_key:
logger.warning(
@ -284,27 +271,14 @@ async def _run_whole_call_qa_analysis(
]
# Call LLM
client_kwargs: dict[str, Any] = {"api_key": api_key}
if base_url:
client_kwargs["base_url"] = base_url
client = AsyncOpenAI(**client_kwargs)
llm = create_llm_service_from_provider(provider, model, api_key, **service_kwargs)
try:
response = await client.chat.completions.create(
model=model,
messages=messages,
temperature=0,
)
raw_response = response.choices[0].message.content
raw_response = await _run_llm_inference(llm, messages)
except Exception as e:
logger.error(f"QA LLM call failed for run {workflow_run_id}: {e}")
return {"error": str(e), "node_results": {}}
# Extract token usage
token_usage: dict[str, int] = {}
accumulate_token_usage(token_usage, response)
# Parse response
node_result: dict[str, Any] = {
"node_name": "whole_call",
@ -325,10 +299,7 @@ async def _run_whole_call_qa_analysis(
parent_ctx = setup_langfuse_parent_context(workflow_run)
add_qa_span_to_trace(parent_ctx, model, messages, raw_response, "qa-analysis")
result: dict[str, Any] = {
return {
"node_results": {"whole_call": node_result},
"model": model,
}
if token_usage:
result["token_usage"] = token_usage
return result

View file

@ -1,63 +1,50 @@
"""LLM configuration resolution and token usage accumulation."""
from api.constants import MPS_API_URL
from api.db import db_client
from api.db.models import WorkflowRunModel
def _provider_base_url(provider: str | None, endpoint: str = "") -> str | None:
"""Return the base URL for a given LLM provider."""
if provider == "openrouter":
return "https://openrouter.ai/api/v1"
if provider == "groq":
return "https://api.groq.com/openai/v1"
if provider == "google":
return "https://generativelanguage.googleapis.com/v1beta/openai/"
if provider == "azure":
return endpoint or None
if provider == "dograh":
return f"{MPS_API_URL}/api/v1/llm"
return None
async def resolve_llm_config(
qa_node_data: dict, workflow_run: WorkflowRunModel
) -> tuple[str, str, str | None]:
"""Resolve the LLM model, API key, and base URL for QA analysis.
) -> tuple[str, str, str, dict]:
"""Resolve the LLM provider, model, API key, and extra kwargs for QA analysis.
If the QA node has its own LLM configuration (qa_use_workflow_llm=False),
use those settings directly. Otherwise, fall back to the user's configured LLM.
Returns:
(model, api_key, base_url) tuple
(provider, model, api_key, service_kwargs) tuple service_kwargs can be
passed directly to create_llm_service_from_provider as keyword arguments.
"""
if not qa_node_data.get("qa_use_workflow_llm", True):
provider = qa_node_data.get("qa_provider", "openai")
kwargs = {}
if provider == "azure":
kwargs["endpoint"] = qa_node_data.get("qa_endpoint", "")
return (
provider,
qa_node_data.get("qa_model"),
qa_node_data.get("qa_api_key"),
_provider_base_url(
qa_node_data.get("qa_provider"),
qa_node_data.get("qa_endpoint", ""),
),
kwargs,
)
# Fall back to user's configured LLM
model, api_key, base_url = await resolve_user_llm_config(workflow_run)
provider, model, api_key, kwargs = await resolve_user_llm_config(workflow_run)
qa_model = qa_node_data.get("qa_model", "default")
if qa_model and qa_model != "default":
model = qa_model
return model, api_key, base_url
return provider, model, api_key, kwargs
async def resolve_user_llm_config(
workflow_run: WorkflowRunModel,
) -> tuple[str, str, str | None]:
) -> tuple[str, str, str, dict]:
"""Resolve the user's configured LLM (from UserConfiguration).
Returns:
(model, api_key, base_url) tuple
(provider, model, api_key, service_kwargs) tuple
"""
user_id = None
if workflow_run.workflow and workflow_run.workflow.user:
@ -71,11 +58,14 @@ async def resolve_user_llm_config(
provider = llm_config.get("provider", "openai")
api_key = llm_config.get("api_key", "")
model = llm_config.get("model", "gpt-4.1")
base_url = _provider_base_url(provider, llm_config.get("endpoint", ""))
if provider == "openrouter" and llm_config.get("base_url"):
base_url = llm_config["base_url"]
return model, api_key, base_url
kwargs = {}
if provider == "azure":
kwargs["endpoint"] = llm_config.get("endpoint", "")
elif provider == "openrouter" and llm_config.get("base_url"):
kwargs["base_url"] = llm_config["base_url"]
return provider, model, api_key, kwargs
def accumulate_token_usage(total: dict, response) -> None:

View file

@ -3,13 +3,14 @@
from typing import Any
from loguru import logger
from openai import AsyncOpenAI
from api.db import db_client
from api.db.models import WorkflowRunModel
from api.services.pipecat.service_factory import create_llm_service_from_provider
from api.services.workflow.dto import NodeType
from api.services.workflow.qa.llm_config import resolve_llm_config
from api.services.workflow.qa.tracing import create_node_summary_trace
from pipecat.processors.aggregators.llm_context import LLMContext
NODE_SUMMARY_SYSTEM_PROMPT = (
"You are analyzing a voice AI agent script. This is only a part of a larger script. "
@ -67,15 +68,14 @@ async def ensure_node_summaries(
if not nodes_needing_summary:
return existing_summaries
model, api_key, base_url = await resolve_llm_config(qa_node_data, workflow_run)
provider, model, api_key, service_kwargs = await resolve_llm_config(
qa_node_data, workflow_run
)
if not api_key:
logger.warning("No API key for node summary generation, skipping")
return existing_summaries
client_kwargs: dict[str, Any] = {"api_key": api_key}
if base_url:
client_kwargs["base_url"] = base_url
client = AsyncOpenAI(**client_kwargs)
llm = create_llm_service_from_provider(provider, model, api_key, **service_kwargs)
updated_summaries = dict(existing_summaries)
@ -153,12 +153,9 @@ async def ensure_node_summaries(
]
try:
response = await client.chat.completions.create(
model=model,
messages=messages,
temperature=0,
)
summary_text = response.choices[0].message.content or ""
context = LLMContext()
context.set_messages(messages)
summary_text = await llm.run_inference(context) or ""
except Exception as e:
logger.warning(f"Failed to generate summary for node {node_id}: {e}")
updated_summaries[node_id] = {"summary": ""}

View file

@ -45,6 +45,7 @@ class Node:
self.extraction_prompt = data.extraction_prompt
self.extraction_variables = data.extraction_variables
self.add_global_prompt = data.add_global_prompt
self.greeting = data.greeting
self.detect_voicemail = data.detect_voicemail
self.delayed_start = data.delayed_start
self.delayed_start_duration = data.delayed_start_duration

View file

@ -139,7 +139,6 @@ class TestVoicemailDetectorWithUserAggregator:
# Create voicemail detector with the classification LLM
voicemail_detector = VoicemailDetector(
llm=voicemail_llm,
voicemail_response_delay=0,
)
# Set up frame counter to track UserStoppedSpeakingFrame in voicemail detector's user aggregator

View file

@ -18,11 +18,11 @@ def generate_transcript_text(events: List[dict]) -> str:
event_type == RealtimeFeedbackType.USER_TRANSCRIPTION.value
and payload.get("final") is True
):
timestamp = payload.get("timestamp", "")
timestamp = payload.get("timestamp") or event.get("timestamp", "")
prefix = f"[{timestamp}] " if timestamp else ""
lines.append(f"{prefix}user: {payload.get('text', '')}\n")
elif event_type == RealtimeFeedbackType.BOT_TEXT.value:
timestamp = payload.get("timestamp", "")
timestamp = payload.get("timestamp") or event.get("timestamp", "")
prefix = f"[{timestamp}] " if timestamp else ""
lines.append(f"{prefix}assistant: {payload.get('text', '')}\n")