mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-06-07 07:55:16 +02:00
feat: add ultravox realtime and fix signature issue in telephony
- Add UltraVox realtime - Fix signature issue on telephony
This commit is contained in:
parent
9135c2da13
commit
ea0cac63cd
24 changed files with 2082 additions and 133 deletions
|
|
@ -132,6 +132,11 @@ Visit [https://www.dograh.com](https://www.dograh.com/) for our managed cloud of
|
|||
|
||||
You can go to [https://docs.dograh.com](https://docs.dograh.com/) for our documentation.
|
||||
|
||||
## 📦 SDKs
|
||||
|
||||
- **Python SDK** — [pypi.org/project/dograh-sdk](https://pypi.org/project/dograh-sdk/)
|
||||
- **Node SDK** — [npmjs.com/package/@dograh/sdk](https://www.npmjs.com/package/@dograh/sdk)
|
||||
|
||||
## 🤝Community & Support
|
||||
|
||||
> 👋 **Coming from the Better Stack video?** Drop your use case in our [pinned GitHub Discussion](https://github.com/orgs/dograh-hq/discussions/291) — we read every reply and the founders personally onboard early adopters.
|
||||
|
|
|
|||
|
|
@ -25,7 +25,7 @@ from api.enums import CallType, WorkflowRunState
|
|||
from api.errors.telephony_errors import TelephonyError
|
||||
from api.sdk_expose import sdk_expose
|
||||
from api.services.auth.depends import get_user
|
||||
from api.services.quota_service import check_dograh_quota, check_dograh_quota_by_user_id
|
||||
from api.services.quota_service import check_dograh_quota_by_user_id
|
||||
from api.services.telephony.call_transfer_manager import get_call_transfer_manager
|
||||
from api.services.telephony.factory import (
|
||||
get_all_telephony_providers,
|
||||
|
|
@ -60,6 +60,15 @@ class InitiateCallRequest(BaseModel):
|
|||
from_phone_number_id: int | None = None
|
||||
|
||||
|
||||
def _get_execution_user_id(workflow) -> int:
|
||||
if workflow.user_id is None:
|
||||
raise HTTPException(
|
||||
status_code=409,
|
||||
detail="Workflow has no execution owner",
|
||||
)
|
||||
return workflow.user_id
|
||||
|
||||
|
||||
@router.post(
|
||||
"/initiate-call",
|
||||
**sdk_expose(
|
||||
|
|
@ -107,15 +116,6 @@ async def initiate_call(
|
|||
detail="telephony_not_configured",
|
||||
)
|
||||
|
||||
# Check Dograh quota before initiating the call (apply per-workflow
|
||||
# model_overrides so the keys we will actually use are the ones checked).
|
||||
quota_result = await check_dograh_quota(user, workflow_id=request.workflow_id)
|
||||
if not quota_result.has_quota:
|
||||
raise HTTPException(status_code=402, detail=quota_result.error_message)
|
||||
|
||||
# Determine the workflow run mode based on provider type
|
||||
workflow_run_mode = provider.PROVIDER_NAME
|
||||
|
||||
phone_number = request.phone_number or user_configuration.test_phone_number
|
||||
|
||||
if not phone_number:
|
||||
|
|
@ -125,25 +125,38 @@ async def initiate_call(
|
|||
"configuration",
|
||||
)
|
||||
|
||||
workflow = await db_client.get_workflow(
|
||||
request.workflow_id, organization_id=user.selected_organization_id
|
||||
)
|
||||
if not workflow:
|
||||
raise HTTPException(status_code=404, detail="Workflow not found")
|
||||
execution_user_id = _get_execution_user_id(workflow)
|
||||
|
||||
# Check Dograh quota before initiating the call (apply per-workflow
|
||||
# model_overrides so the keys we will actually use are the ones checked).
|
||||
quota_result = await check_dograh_quota_by_user_id(
|
||||
execution_user_id, workflow_id=workflow.id
|
||||
)
|
||||
if not quota_result.has_quota:
|
||||
raise HTTPException(status_code=402, detail=quota_result.error_message)
|
||||
|
||||
# Determine the workflow run mode based on provider type
|
||||
workflow_run_mode = provider.PROVIDER_NAME
|
||||
|
||||
workflow_run_id = request.workflow_run_id
|
||||
|
||||
if not workflow_run_id:
|
||||
# Fetch workflow to merge template context variables (e.g. caller_number,
|
||||
# called_number set in workflow settings for testing pre-call data fetch)
|
||||
workflow = await db_client.get_workflow(
|
||||
request.workflow_id, organization_id=user.selected_organization_id
|
||||
)
|
||||
if not workflow:
|
||||
raise HTTPException(status_code=404, detail="Workflow not found")
|
||||
# Merge template context variables (e.g. caller_number, called_number
|
||||
# set in workflow settings for testing pre-call data fetch).
|
||||
template_vars = workflow.template_context_variables or {}
|
||||
|
||||
numeric_suffix = int(str(uuid.uuid4()).replace("-", "")[:8], 16) % 100000000
|
||||
workflow_run_name = f"WR-TEL-OUT-{numeric_suffix:08d}"
|
||||
workflow_run = await db_client.create_workflow_run(
|
||||
workflow_run_name,
|
||||
request.workflow_id,
|
||||
workflow.id,
|
||||
workflow_run_mode,
|
||||
user_id=user.id,
|
||||
user_id=execution_user_id,
|
||||
call_type=CallType.OUTBOUND,
|
||||
initial_context={
|
||||
**template_vars,
|
||||
|
|
@ -157,9 +170,16 @@ async def initiate_call(
|
|||
)
|
||||
workflow_run_id = workflow_run.id
|
||||
else:
|
||||
workflow_run = await db_client.get_workflow_run(workflow_run_id, user.id)
|
||||
workflow_run = await db_client.get_workflow_run(
|
||||
workflow_run_id, organization_id=user.selected_organization_id
|
||||
)
|
||||
if not workflow_run:
|
||||
raise HTTPException(status_code=400, detail="Workflow run not found")
|
||||
if workflow_run.workflow_id != workflow.id:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="workflow_run_workflow_mismatch",
|
||||
)
|
||||
workflow_run_name = workflow_run.name
|
||||
|
||||
# Construct webhook URL based on provider type
|
||||
|
|
@ -169,13 +189,13 @@ async def initiate_call(
|
|||
|
||||
webhook_url = (
|
||||
f"{backend_endpoint}/api/v1/telephony/{webhook_endpoint}"
|
||||
f"?workflow_id={request.workflow_id}"
|
||||
f"&user_id={user.id}"
|
||||
f"?workflow_id={workflow.id}"
|
||||
f"&user_id={execution_user_id}"
|
||||
f"&workflow_run_id={workflow_run_id}"
|
||||
f"&organization_id={user.selected_organization_id}"
|
||||
)
|
||||
|
||||
keywords = {"workflow_id": request.workflow_id, "user_id": user.id}
|
||||
keywords = {"workflow_id": workflow.id, "user_id": execution_user_id}
|
||||
|
||||
# Resolve optional caller-ID. The config has already been validated against
|
||||
# the user's organization, so filtering by config_id is sufficient for
|
||||
|
|
@ -293,6 +313,7 @@ async def _detect_provider(webhook_data: dict, headers: dict):
|
|||
|
||||
async def _validate_inbound_request(
|
||||
workflow_id: int,
|
||||
webhook_url: str,
|
||||
provider_class,
|
||||
normalized_data,
|
||||
webhook_data: dict,
|
||||
|
|
@ -364,8 +385,6 @@ async def _validate_inbound_request(
|
|||
# Verify webhook signature using the matched config's credentials. The
|
||||
# provider extracts its own signature/timestamp/nonce headers from the
|
||||
# dict, so this dispatcher stays generic.
|
||||
backend_endpoint, _ = await get_backend_endpoints()
|
||||
webhook_url = f"{backend_endpoint}/api/v1/telephony/inbound/{workflow_id}"
|
||||
provider_instance = await get_telephony_provider_by_id(
|
||||
telephony_configuration_id, organization_id
|
||||
)
|
||||
|
|
@ -701,13 +720,11 @@ async def handle_inbound_run(request: Request):
|
|||
user_id = workflow.user_id
|
||||
|
||||
# 3. Verify webhook signature against the matched config's credentials.
|
||||
backend_endpoint, wss_backend_endpoint = await get_backend_endpoints()
|
||||
webhook_url = f"{backend_endpoint}/api/v1/telephony/inbound/run"
|
||||
provider_instance = await get_telephony_provider_by_id(
|
||||
telephony_configuration_id, config.organization_id
|
||||
)
|
||||
signature_valid = await provider_instance.verify_inbound_signature(
|
||||
webhook_url, webhook_data, headers, raw_body
|
||||
str(request.url), webhook_data, headers, raw_body
|
||||
)
|
||||
if not signature_valid:
|
||||
logger.warning(
|
||||
|
|
@ -840,6 +857,7 @@ async def handle_inbound_telephony(
|
|||
provider_instance,
|
||||
) = await _validate_inbound_request(
|
||||
workflow_id,
|
||||
str(request.url),
|
||||
provider_class,
|
||||
normalized_data,
|
||||
webhook_data,
|
||||
|
|
|
|||
|
|
@ -296,7 +296,6 @@ class CampaignCallDispatcher:
|
|||
f"?workflow_id={campaign.workflow_id}"
|
||||
f"&user_id={campaign.created_by}"
|
||||
f"&workflow_run_id={workflow_run.id}"
|
||||
f"&campaign_id={campaign.id}"
|
||||
f"&organization_id={campaign.organization_id}"
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -50,6 +50,7 @@ class UserConfigurationValidator:
|
|||
ServiceProviders.GOOGLE_VERTEX.value: self._check_google_vertex_llm_api_key,
|
||||
ServiceProviders.OPENAI_REALTIME.value: self._check_openai_api_key,
|
||||
ServiceProviders.GROK_REALTIME.value: self._check_grok_realtime_api_key,
|
||||
ServiceProviders.ULTRAVOX_REALTIME.value: self._check_ultravox_realtime_api_key,
|
||||
ServiceProviders.GOOGLE_REALTIME.value: self._check_google_api_key,
|
||||
ServiceProviders.GOOGLE_VERTEX_REALTIME.value: self._check_google_vertex_realtime_api_key,
|
||||
ServiceProviders.ASSEMBLYAI.value: self._check_assemblyai_api_key,
|
||||
|
|
@ -255,6 +256,9 @@ class UserConfigurationValidator:
|
|||
def _check_grok_realtime_api_key(self, model: str, api_key: str) -> bool:
|
||||
return True
|
||||
|
||||
def _check_ultravox_realtime_api_key(self, model: str, api_key: str) -> bool:
|
||||
return True
|
||||
|
||||
def _check_speechmatics_api_key(self, model: str, api_key: str) -> bool:
|
||||
return True
|
||||
|
||||
|
|
|
|||
|
|
@ -38,7 +38,11 @@ def check_for_masked_keys(config: "UserConfiguration") -> None:
|
|||
for secret_field in SERVICE_SECRET_FIELDS:
|
||||
if not hasattr(service, secret_field):
|
||||
continue
|
||||
if contains_masked_key(getattr(service, secret_field, None)):
|
||||
if secret_field == "api_key" and hasattr(service, "get_all_api_keys"):
|
||||
secret_value = service.get_all_api_keys()
|
||||
else:
|
||||
secret_value = getattr(service, secret_field, None)
|
||||
if contains_masked_key(secret_value):
|
||||
raise ValueError(
|
||||
f"The {field} {secret_field} appears to be masked. "
|
||||
"Please provide the actual value, not the masked value."
|
||||
|
|
|
|||
|
|
@ -62,6 +62,7 @@ class ServiceProviders(str, Enum):
|
|||
GOOGLE_VERTEX = "google_vertex"
|
||||
OPENAI_REALTIME = "openai_realtime"
|
||||
GROK_REALTIME = "grok_realtime"
|
||||
ULTRAVOX_REALTIME = "ultravox_realtime"
|
||||
GOOGLE_REALTIME = "google_realtime"
|
||||
GOOGLE_VERTEX_REALTIME = "google_vertex_realtime"
|
||||
|
||||
|
|
@ -85,6 +86,7 @@ class BaseServiceConfiguration(BaseModel):
|
|||
ServiceProviders.GOOGLE_VERTEX,
|
||||
ServiceProviders.OPENAI_REALTIME,
|
||||
ServiceProviders.GROK_REALTIME,
|
||||
ServiceProviders.ULTRAVOX_REALTIME,
|
||||
ServiceProviders.GOOGLE_REALTIME,
|
||||
ServiceProviders.GOOGLE_VERTEX_REALTIME,
|
||||
# ServiceProviders.SARVAM,
|
||||
|
|
@ -214,6 +216,7 @@ AWS_BEDROCK_PROVIDER_MODEL_CONFIG = provider_model_config("AWS Bedrock")
|
|||
GOOGLE_VERTEX_PROVIDER_MODEL_CONFIG = provider_model_config("Google Vertex")
|
||||
OPENAI_REALTIME_PROVIDER_MODEL_CONFIG = provider_model_config("OpenAI Realtime")
|
||||
GROK_REALTIME_PROVIDER_MODEL_CONFIG = provider_model_config("Grok Realtime")
|
||||
ULTRAVOX_REALTIME_PROVIDER_MODEL_CONFIG = provider_model_config("Ultravox Realtime")
|
||||
GOOGLE_REALTIME_PROVIDER_MODEL_CONFIG = provider_model_config("Google Realtime")
|
||||
GOOGLE_VERTEX_REALTIME_PROVIDER_MODEL_CONFIG = provider_model_config(
|
||||
"Google Vertex Realtime"
|
||||
|
|
@ -504,6 +507,7 @@ class OpenAIRealtimeLLMConfiguration(BaseLLMConfiguration):
|
|||
|
||||
GROK_REALTIME_MODELS = ["grok-voice-think-fast-1.0"]
|
||||
GROK_REALTIME_VOICES = ["Ara", "Rex", "Sal", "Eve", "Leo"]
|
||||
ULTRAVOX_REALTIME_MODELS = ["ultravox-v0.7", "fixie-ai/ultravox"]
|
||||
|
||||
|
||||
@register_service(ServiceType.REALTIME)
|
||||
|
|
@ -528,6 +532,26 @@ class GrokRealtimeLLMConfiguration(BaseLLMConfiguration):
|
|||
)
|
||||
|
||||
|
||||
@register_service(ServiceType.REALTIME)
|
||||
class UltravoxRealtimeLLMConfiguration(BaseLLMConfiguration):
|
||||
model_config = ULTRAVOX_REALTIME_PROVIDER_MODEL_CONFIG
|
||||
provider: Literal[ServiceProviders.ULTRAVOX_REALTIME] = (
|
||||
ServiceProviders.ULTRAVOX_REALTIME
|
||||
)
|
||||
model: str = Field(
|
||||
default="ultravox-v0.7",
|
||||
description="Ultravox realtime voice-agent model.",
|
||||
json_schema_extra={
|
||||
"examples": ULTRAVOX_REALTIME_MODELS,
|
||||
"allow_custom_input": True,
|
||||
},
|
||||
)
|
||||
voice: str = Field(
|
||||
default="Mark",
|
||||
description="Ultravox voice name or voice ID.",
|
||||
)
|
||||
|
||||
|
||||
@register_service(ServiceType.REALTIME)
|
||||
class GoogleRealtimeLLMConfiguration(BaseLLMConfiguration):
|
||||
model_config = GOOGLE_REALTIME_PROVIDER_MODEL_CONFIG
|
||||
|
|
@ -615,6 +639,7 @@ class GoogleVertexRealtimeLLMConfiguration(BaseLLMConfiguration):
|
|||
REALTIME_PROVIDERS = {
|
||||
ServiceProviders.OPENAI_REALTIME.value,
|
||||
ServiceProviders.GROK_REALTIME.value,
|
||||
ServiceProviders.ULTRAVOX_REALTIME.value,
|
||||
ServiceProviders.GOOGLE_REALTIME.value,
|
||||
ServiceProviders.GOOGLE_VERTEX_REALTIME.value,
|
||||
}
|
||||
|
|
@ -640,6 +665,7 @@ RealtimeConfig = Annotated[
|
|||
Union[
|
||||
OpenAIRealtimeLLMConfiguration,
|
||||
GrokRealtimeLLMConfiguration,
|
||||
UltravoxRealtimeLLMConfiguration,
|
||||
GoogleRealtimeLLMConfiguration,
|
||||
GoogleVertexRealtimeLLMConfiguration,
|
||||
],
|
||||
|
|
|
|||
653
api/services/pipecat/realtime/ultravox_realtime.py
Normal file
653
api/services/pipecat/realtime/ultravox_realtime.py
Normal file
|
|
@ -0,0 +1,653 @@
|
|||
"""Dograh subclass of pipecat's Ultravox realtime LLM service.
|
||||
|
||||
Ultravox is audio-native and realtime, but prompt and tool configuration is
|
||||
bound to call creation. Dograh therefore cannot lean on in-session updates or
|
||||
Gemini-style session resumption handles. This wrapper adapts Ultravox to the
|
||||
Dograh engine contract by:
|
||||
|
||||
- deferring the first call creation until the engine queues the initial node
|
||||
opening via ``TTSSpeakFrame`` or ``LLMContextFrame``
|
||||
- marking the call for recreation when ``system_instruction`` changes across
|
||||
node transitions, then rebuilding it on the follow-up ``LLMContextFrame``
|
||||
so the transition tool result is present in ``initialMessages``
|
||||
- reconstructing Ultravox ``initialMessages`` from Dograh context when the
|
||||
call must be recreated after a node transition
|
||||
- appending a transient resumptive user nudge to recreated ``initialMessages``
|
||||
after tool-result transitions, without mutating Dograh's stored context
|
||||
- handling Dograh-only frames such as user mute and idle append prompts
|
||||
- tagging user transcripts with ``finalized=True`` for downstream parity
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
from loguru import logger
|
||||
from pydantic import Field
|
||||
from websockets.exceptions import ConnectionClosed
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
Frame,
|
||||
LLMMessagesAppendFrame,
|
||||
TranscriptionFrame,
|
||||
TTSSpeakFrame,
|
||||
UserMuteStartedFrame,
|
||||
UserMuteStoppedFrame,
|
||||
)
|
||||
from pipecat.processors.aggregators import async_tool_messages
|
||||
from pipecat.processors.aggregators.llm_context import (
|
||||
LLMContext,
|
||||
LLMSpecificMessage,
|
||||
is_given,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.llm_service import LLMService
|
||||
from pipecat.services.settings import _NotGiven, assert_given
|
||||
from pipecat.services.ultravox.llm import (
|
||||
OneShotInputParams,
|
||||
UltravoxRealtimeLLMService,
|
||||
websocket_client,
|
||||
)
|
||||
from pipecat.utils.time import time_now_iso8601
|
||||
|
||||
|
||||
class DograhUltravoxOneShotInputParams(OneShotInputParams):
|
||||
"""Dograh-friendly OneShot params with string voice support."""
|
||||
|
||||
voice: str | None = Field(default=None)
|
||||
|
||||
|
||||
_ULTRAVOX_MAX_TOOL_TIMEOUT_SECS = 40.0
|
||||
_RESUMPTION_USER_MESSAGE = (
|
||||
"IMPORTANT: We are resuming an existing conversation. You are given previous turns ONLY for your reference. "
|
||||
"Do not use that to frame your response. Follow your ORIGINAL INSTRUCTIONS ONLY."
|
||||
)
|
||||
|
||||
|
||||
class DograhUltravoxRealtimeLLMService(UltravoxRealtimeLLMService):
|
||||
"""Ultravox realtime with Dograh engine integration quirks."""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self._context: LLMContext | None = None
|
||||
self._selected_tools = None
|
||||
self._user_is_muted: bool = False
|
||||
self._call_system_instruction: str | None = None
|
||||
self._reconnect_required: bool = False
|
||||
self._call_started: bool = False
|
||||
self._has_connected_once: bool = False
|
||||
self._pending_reconnect_system_instruction: str | None = None
|
||||
self._pending_initial_messages: list[dict[str, Any]] | None = None
|
||||
self._pending_user_text_messages: list[str] = []
|
||||
|
||||
async def start(self, frame):
|
||||
# Dograh defers call creation until the engine queues the node opening.
|
||||
await LLMService.start(self, frame)
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
if isinstance(frame, UserMuteStartedFrame):
|
||||
self._user_is_muted = True
|
||||
await self.push_frame(frame, direction)
|
||||
return
|
||||
if isinstance(frame, UserMuteStoppedFrame):
|
||||
self._user_is_muted = False
|
||||
await self.push_frame(frame, direction)
|
||||
return
|
||||
if isinstance(frame, TTSSpeakFrame):
|
||||
if not self._socket:
|
||||
await self._connect_call(
|
||||
system_instruction=self._current_system_instruction(),
|
||||
greeting_text=frame.text,
|
||||
initial_messages=None,
|
||||
agent_speaks_first=True,
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f"{self}: TTSSpeakFrame received after the Ultravox call was "
|
||||
"already created; ignoring because Ultravox owns speech output"
|
||||
)
|
||||
return
|
||||
if isinstance(frame, LLMMessagesAppendFrame):
|
||||
await self._handle_messages_append(frame)
|
||||
return
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
async def _update_settings(self, delta: UltravoxRealtimeLLMService.Settings):
|
||||
changed = await super(UltravoxRealtimeLLMService, self)._update_settings(delta)
|
||||
if "output_medium" in changed:
|
||||
await self._update_output_medium(assert_given(self._settings.output_medium))
|
||||
if "system_instruction" in changed and self._has_connected_once:
|
||||
# Mirror Gemini's "settings change means reconnect" intent, but
|
||||
# defer the actual new-call creation until the subsequent
|
||||
# LLMContextFrame arrives with the transition tool result. Ultravox
|
||||
# cannot accept that historical tool result over a formal
|
||||
# post-connect tool-response channel the way Gemini can.
|
||||
self._reconnect_required = True
|
||||
handled = {"output_medium", "system_instruction"}
|
||||
self._warn_unhandled_updated_settings(changed.keys() - handled)
|
||||
return changed
|
||||
|
||||
async def _disconnect(self, preserve_completed_tool_calls: bool = True):
|
||||
self._disconnecting = True
|
||||
await self.stop_all_metrics()
|
||||
if self._socket:
|
||||
await self._socket.close()
|
||||
self._socket = None
|
||||
if self._receive_task:
|
||||
await self.cancel_task(self._receive_task, timeout=1.0)
|
||||
self._receive_task = None
|
||||
if not preserve_completed_tool_calls:
|
||||
self._completed_tool_calls = set()
|
||||
self._call_started = False
|
||||
self._started_placeholder_sent = set()
|
||||
self._disconnecting = False
|
||||
|
||||
async def _send_user_audio(self, frame):
|
||||
if self._user_is_muted:
|
||||
return
|
||||
await super()._send_user_audio(frame)
|
||||
|
||||
async def _handle_context(self, context: LLMContext):
|
||||
self._context = context
|
||||
system_instruction = self._current_system_instruction()
|
||||
|
||||
if self._socket and not self._reconnect_required:
|
||||
await super()._handle_context(context)
|
||||
return
|
||||
|
||||
initial_messages, history_tool_call_ids = self._build_initial_messages(context)
|
||||
if history_tool_call_ids:
|
||||
self._completed_tool_calls.update(history_tool_call_ids)
|
||||
|
||||
if self._bot_responding:
|
||||
self._pending_reconnect_system_instruction = system_instruction
|
||||
self._pending_initial_messages = initial_messages
|
||||
return
|
||||
|
||||
await self._reconnect_with_context(
|
||||
system_instruction=system_instruction,
|
||||
initial_messages=initial_messages,
|
||||
)
|
||||
|
||||
async def _handle_response_end(self):
|
||||
await super()._handle_response_end()
|
||||
if self._pending_reconnect_system_instruction is None:
|
||||
return
|
||||
|
||||
system_instruction = self._pending_reconnect_system_instruction
|
||||
initial_messages = self._pending_initial_messages
|
||||
self._pending_reconnect_system_instruction = None
|
||||
self._pending_initial_messages = None
|
||||
await self._reconnect_with_context(
|
||||
system_instruction=system_instruction,
|
||||
initial_messages=initial_messages,
|
||||
)
|
||||
|
||||
async def _handle_messages_append(self, frame: LLMMessagesAppendFrame):
|
||||
texts = [
|
||||
text
|
||||
for text in (
|
||||
self._extract_text_content(message.get("content"))
|
||||
for message in frame.messages
|
||||
if isinstance(message, dict)
|
||||
)
|
||||
if text
|
||||
]
|
||||
if not texts:
|
||||
return
|
||||
|
||||
if not self._socket:
|
||||
self._pending_user_text_messages.extend(texts)
|
||||
await self._connect_call(
|
||||
system_instruction=self._current_system_instruction(),
|
||||
greeting_text=None,
|
||||
initial_messages=None,
|
||||
agent_speaks_first=False,
|
||||
)
|
||||
return
|
||||
|
||||
if not self._call_started:
|
||||
self._pending_user_text_messages.extend(texts)
|
||||
logger.debug(
|
||||
f"{self}: queueing {len(texts)} user text message(s) until call_started"
|
||||
)
|
||||
return
|
||||
|
||||
for text in texts:
|
||||
await self._send_user_text(text)
|
||||
|
||||
async def _handle_user_transcript(self, text: str):
|
||||
transcript = text.strip() if text else ""
|
||||
if not transcript:
|
||||
return
|
||||
await self.broadcast_frame(
|
||||
TranscriptionFrame,
|
||||
user_id=self._last_user_id or "",
|
||||
timestamp=time_now_iso8601(),
|
||||
result=text,
|
||||
text=transcript,
|
||||
finalized=True,
|
||||
)
|
||||
|
||||
async def _connect_call(
|
||||
self,
|
||||
*,
|
||||
system_instruction: str | None,
|
||||
greeting_text: str | None,
|
||||
initial_messages: list[dict[str, Any]] | None,
|
||||
agent_speaks_first: bool,
|
||||
):
|
||||
params = self._build_one_shot_params(
|
||||
greeting_text=greeting_text,
|
||||
initial_messages=initial_messages,
|
||||
agent_speaks_first=agent_speaks_first,
|
||||
)
|
||||
self._params = params
|
||||
self._selected_tools = self._current_tools_schema(self._context)
|
||||
tool_names = (
|
||||
[tool.name for tool in self._selected_tools.standard_tools]
|
||||
if self._selected_tools
|
||||
else []
|
||||
)
|
||||
prompt = params.system_prompt or ""
|
||||
prompt_hash = hashlib.sha256(prompt.encode("utf-8")).hexdigest()[:12]
|
||||
|
||||
try:
|
||||
logger.info(
|
||||
f"{self}: creating Ultravox call "
|
||||
f"(agent_speaks_first={agent_speaks_first}, "
|
||||
f"voice={params.voice!r}, "
|
||||
f"tools={tool_names}, "
|
||||
f"system_prompt_len={len(prompt)}, "
|
||||
f"system_prompt_sha256={prompt_hash})"
|
||||
)
|
||||
join_url = await self._start_one_shot_call(params)
|
||||
logger.info(f"Joining Ultravox Realtime call via URL: {join_url}")
|
||||
self._socket = await websocket_client.connect(join_url)
|
||||
self._receive_task = self.create_task(self._receive_messages())
|
||||
self._call_system_instruction = system_instruction
|
||||
self._call_started = False
|
||||
self._has_connected_once = True
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"{self}: Ultravox call creation/join failed "
|
||||
f"for tools={tool_names}: {e}"
|
||||
)
|
||||
await self.push_error(f"Failed to connect to Ultravox: {e}", e, fatal=True)
|
||||
|
||||
async def _receive_messages(self):
|
||||
"""Receive messages from the Ultravox Realtime WebSocket.
|
||||
|
||||
Upstream handles exceptions raised while processing individual messages,
|
||||
but websocket close exceptions are raised by the async iterator itself.
|
||||
During user hangup / pipeline teardown that close is expected, so treat
|
||||
normal websocket shutdown as a debug condition rather than a pipeline
|
||||
error.
|
||||
"""
|
||||
if not self._socket:
|
||||
return
|
||||
|
||||
try:
|
||||
async for message in self._socket:
|
||||
try:
|
||||
if isinstance(message, bytes):
|
||||
await self._handle_audio(message)
|
||||
continue
|
||||
|
||||
data = json.loads(message)
|
||||
match data.get("type"):
|
||||
case "call_started":
|
||||
self._call_started = True
|
||||
logger.debug(
|
||||
f"{self}: Ultravox call_started received for callId="
|
||||
f"{data.get('callId')}"
|
||||
)
|
||||
await self._flush_pending_user_text_messages()
|
||||
case "state":
|
||||
if self._bot_responding and data.get("state") != "speaking":
|
||||
await self._handle_response_end()
|
||||
case "client_tool_invocation":
|
||||
await self._handle_tool_invocation(
|
||||
data.get("toolName"),
|
||||
data.get("invocationId"),
|
||||
data.get("parameters"),
|
||||
)
|
||||
case "transcript":
|
||||
match data.get("role"):
|
||||
case "user":
|
||||
if not data.get("final"):
|
||||
logger.warning(
|
||||
"Unexpected non-final user transcript from Ultravox Realtime; ignoring."
|
||||
)
|
||||
else:
|
||||
await self._handle_user_transcript(
|
||||
data.get("text")
|
||||
)
|
||||
case "agent":
|
||||
await self._handle_agent_transcript(
|
||||
data.get("medium"),
|
||||
data.get("text"),
|
||||
data.get("delta"),
|
||||
data.get("final", False),
|
||||
)
|
||||
case _:
|
||||
logger.debug(
|
||||
f"Received transcript with unknown role from Ultravox Realtime: {data}"
|
||||
)
|
||||
case _:
|
||||
logger.debug(f"Received unhandled Ultravox message: {data}")
|
||||
except Exception as e:
|
||||
if self._disconnecting or not self._socket:
|
||||
return
|
||||
await self.push_error(
|
||||
"Ultravox websocket receive error", e, fatal=True
|
||||
)
|
||||
except ConnectionClosed as e:
|
||||
if (
|
||||
self._disconnecting
|
||||
or not self._socket
|
||||
or self._is_benign_websocket_close(e)
|
||||
):
|
||||
logger.debug(f"{self}: Ultravox websocket closed: {e}")
|
||||
return
|
||||
await self.push_error("Ultravox websocket receive error", e, fatal=True)
|
||||
|
||||
async def _flush_pending_user_text_messages(self):
|
||||
if (
|
||||
not self._socket
|
||||
or not self._call_started
|
||||
or not self._pending_user_text_messages
|
||||
):
|
||||
return
|
||||
|
||||
pending_texts = self._pending_user_text_messages
|
||||
self._pending_user_text_messages = []
|
||||
for pending_text in pending_texts:
|
||||
await self._send_user_text(pending_text)
|
||||
|
||||
async def _reconnect_with_context(
|
||||
self,
|
||||
*,
|
||||
system_instruction: str | None,
|
||||
initial_messages: list[dict[str, Any]] | None,
|
||||
):
|
||||
call_initial_messages = self._initial_messages_for_call(initial_messages)
|
||||
logger.debug(
|
||||
f"{self}: reconnecting Ultravox call with initialMessages="
|
||||
f"{json.dumps(call_initial_messages, ensure_ascii=True, default=str)}"
|
||||
)
|
||||
if self._socket:
|
||||
await self._disconnect(preserve_completed_tool_calls=True)
|
||||
|
||||
await self._connect_call(
|
||||
system_instruction=system_instruction,
|
||||
greeting_text=None,
|
||||
initial_messages=initial_messages,
|
||||
agent_speaks_first=self._should_agent_speak_first(initial_messages),
|
||||
)
|
||||
self._reconnect_required = False
|
||||
|
||||
def _build_one_shot_params(
|
||||
self,
|
||||
*,
|
||||
greeting_text: str | None,
|
||||
initial_messages: list[dict[str, Any]] | None,
|
||||
agent_speaks_first: bool,
|
||||
) -> DograhUltravoxOneShotInputParams:
|
||||
current_params = self._params
|
||||
extra = {
|
||||
key: value
|
||||
for key, value in current_params.extra.items()
|
||||
if key not in {"firstSpeakerSettings", "initialMessages"}
|
||||
}
|
||||
|
||||
if greeting_text is not None:
|
||||
extra["firstSpeakerSettings"] = {"agent": {"text": greeting_text}}
|
||||
elif agent_speaks_first:
|
||||
extra["firstSpeakerSettings"] = {"agent": {}}
|
||||
else:
|
||||
extra["firstSpeakerSettings"] = {"user": {}}
|
||||
call_initial_messages = self._initial_messages_for_call(initial_messages)
|
||||
if call_initial_messages:
|
||||
extra["initialMessages"] = call_initial_messages
|
||||
|
||||
output_medium = self._settings.output_medium
|
||||
if isinstance(output_medium, _NotGiven):
|
||||
output_medium = current_params.output_medium
|
||||
|
||||
return DograhUltravoxOneShotInputParams(
|
||||
api_key=current_params.api_key,
|
||||
system_prompt=self._current_system_instruction(),
|
||||
temperature=current_params.temperature,
|
||||
model=assert_given(self._settings.model),
|
||||
voice=current_params.voice,
|
||||
metadata=current_params.metadata,
|
||||
output_medium=output_medium,
|
||||
max_duration=current_params.max_duration,
|
||||
extra=extra,
|
||||
)
|
||||
|
||||
def _current_tools_schema(self, context: LLMContext | None):
|
||||
if context is None or not is_given(context.tools):
|
||||
return None
|
||||
return context.tools
|
||||
|
||||
def _to_selected_tools(self, tool: Any) -> list[dict[str, Any]]:
|
||||
selected_tools = super()._to_selected_tools(tool)
|
||||
for selected_tool in selected_tools:
|
||||
temporary_tool = selected_tool.get("temporaryTool")
|
||||
if not isinstance(temporary_tool, dict):
|
||||
continue
|
||||
|
||||
tool_name = temporary_tool.get("modelToolName")
|
||||
if not isinstance(tool_name, str):
|
||||
continue
|
||||
|
||||
timeout = self._ultravox_timeout_for_tool(tool_name)
|
||||
if timeout is not None:
|
||||
temporary_tool["timeout"] = timeout
|
||||
return selected_tools
|
||||
|
||||
def _current_system_instruction(self) -> str | None:
|
||||
system_instruction = self._settings.system_instruction
|
||||
if isinstance(system_instruction, _NotGiven):
|
||||
return None
|
||||
return system_instruction
|
||||
|
||||
def _ultravox_timeout_for_tool(self, function_name: str) -> str | None:
|
||||
item = self._functions.get(function_name) or self._functions.get(None)
|
||||
if item is None or item.timeout_secs is None or item.timeout_secs <= 0:
|
||||
return None
|
||||
|
||||
timeout_secs = min(float(item.timeout_secs), _ULTRAVOX_MAX_TOOL_TIMEOUT_SECS)
|
||||
return f"{timeout_secs:g}s"
|
||||
|
||||
def _initial_messages_for_call(
|
||||
self, initial_messages: list[dict[str, Any]] | None
|
||||
) -> list[dict[str, Any]] | None:
|
||||
if not initial_messages:
|
||||
return None
|
||||
if not self._should_add_resumption_user_message(initial_messages):
|
||||
return initial_messages
|
||||
|
||||
return [
|
||||
*initial_messages,
|
||||
{
|
||||
"role": "MESSAGE_ROLE_USER",
|
||||
"text": _RESUMPTION_USER_MESSAGE,
|
||||
},
|
||||
]
|
||||
|
||||
def _build_initial_messages(
|
||||
self, context: LLMContext
|
||||
) -> tuple[list[dict[str, Any]] | None, set[str]]:
|
||||
initial_messages: list[dict[str, Any]] = []
|
||||
tool_call_id_to_name: dict[str, str] = {}
|
||||
completed_tool_call_ids: set[str] = set()
|
||||
|
||||
for message in context.get_messages():
|
||||
if isinstance(message, LLMSpecificMessage):
|
||||
continue
|
||||
|
||||
async_payload = async_tool_messages.parse_message(message)
|
||||
if async_payload is not None:
|
||||
if async_payload.kind == "intermediate":
|
||||
logger.error(
|
||||
f"{self}: Ultravox does not support streamed async tool results; "
|
||||
f"dropping intermediate result from initialMessages for "
|
||||
f"tool_call_id={async_payload.tool_call_id}."
|
||||
)
|
||||
continue
|
||||
if async_payload.kind == "final":
|
||||
initial_message = self._build_ultravox_message(
|
||||
role="MESSAGE_ROLE_TOOL_RESULT",
|
||||
text=async_payload.result or "",
|
||||
invocation_id=async_payload.tool_call_id,
|
||||
tool_name=tool_call_id_to_name.get(async_payload.tool_call_id),
|
||||
)
|
||||
if initial_message is not None:
|
||||
initial_messages.append(initial_message)
|
||||
completed_tool_call_ids.add(async_payload.tool_call_id)
|
||||
continue
|
||||
|
||||
role = message.get("role")
|
||||
if role == "user":
|
||||
initial_message = self._build_ultravox_message(
|
||||
role="MESSAGE_ROLE_USER",
|
||||
text=self._extract_text_content(message.get("content")),
|
||||
)
|
||||
if initial_message is not None:
|
||||
initial_messages.append(initial_message)
|
||||
elif role == "assistant":
|
||||
text = self._extract_text_content(message.get("content"))
|
||||
initial_message = self._build_ultravox_message(
|
||||
role="MESSAGE_ROLE_AGENT",
|
||||
text=text,
|
||||
)
|
||||
if initial_message is not None:
|
||||
initial_messages.append(initial_message)
|
||||
|
||||
tool_calls = message.get("tool_calls")
|
||||
if isinstance(tool_calls, list):
|
||||
for tool_call in tool_calls:
|
||||
if not isinstance(tool_call, dict):
|
||||
continue
|
||||
tool_id = tool_call.get("id")
|
||||
function = tool_call.get("function")
|
||||
tool_name = (
|
||||
function.get("name") if isinstance(function, dict) else None
|
||||
)
|
||||
if isinstance(tool_id, str) and isinstance(tool_name, str):
|
||||
tool_call_id_to_name[tool_id] = tool_name
|
||||
initial_message = self._build_ultravox_message(
|
||||
role="MESSAGE_ROLE_TOOL_CALL",
|
||||
text="",
|
||||
invocation_id=tool_id,
|
||||
tool_name=tool_name,
|
||||
)
|
||||
if initial_message is not None:
|
||||
initial_messages.append(initial_message)
|
||||
elif (
|
||||
role == "tool"
|
||||
and message.get("content") != "IN_PROGRESS"
|
||||
and message.get("content") != "CANCELLED"
|
||||
):
|
||||
tool_call_id = message.get("tool_call_id")
|
||||
initial_message = self._build_ultravox_message(
|
||||
role="MESSAGE_ROLE_TOOL_RESULT",
|
||||
text=self._stringify_tool_result(message.get("content")),
|
||||
invocation_id=tool_call_id
|
||||
if isinstance(tool_call_id, str)
|
||||
else None,
|
||||
tool_name=(
|
||||
tool_call_id_to_name.get(tool_call_id)
|
||||
if isinstance(tool_call_id, str)
|
||||
else None
|
||||
),
|
||||
)
|
||||
if initial_message is not None:
|
||||
initial_messages.append(initial_message)
|
||||
if isinstance(tool_call_id, str):
|
||||
completed_tool_call_ids.add(tool_call_id)
|
||||
|
||||
return (initial_messages or None), completed_tool_call_ids
|
||||
|
||||
@staticmethod
|
||||
def _build_ultravox_message(
|
||||
*,
|
||||
role: str,
|
||||
text: str | None,
|
||||
invocation_id: str | None = None,
|
||||
tool_name: str | None = None,
|
||||
) -> dict[str, Any] | None:
|
||||
if text is None:
|
||||
return None
|
||||
|
||||
message: dict[str, Any] = {
|
||||
"role": role,
|
||||
"text": text,
|
||||
}
|
||||
if invocation_id is not None:
|
||||
message["invocationId"] = invocation_id
|
||||
if tool_name is not None:
|
||||
message["toolName"] = tool_name
|
||||
return message
|
||||
|
||||
@staticmethod
|
||||
def _should_agent_speak_first(
|
||||
initial_messages: list[dict[str, Any]] | None,
|
||||
) -> bool:
|
||||
if not initial_messages:
|
||||
return True
|
||||
return initial_messages[-1].get("role") in {
|
||||
"MESSAGE_ROLE_USER",
|
||||
"MESSAGE_ROLE_TOOL_RESULT",
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _should_add_resumption_user_message(
|
||||
initial_messages: list[dict[str, Any]] | None,
|
||||
) -> bool:
|
||||
if not initial_messages:
|
||||
return False
|
||||
return initial_messages[-1].get("role") == "MESSAGE_ROLE_TOOL_RESULT"
|
||||
|
||||
@staticmethod
|
||||
def _is_benign_websocket_close(exc: ConnectionClosed) -> bool:
|
||||
return any(
|
||||
close is not None and close.code in {1000, 1001}
|
||||
for close in (exc.sent, exc.rcvd)
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _extract_text_content(content: Any) -> str | None:
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
if isinstance(content, list):
|
||||
parts: list[str] = []
|
||||
for part in content:
|
||||
if not isinstance(part, dict):
|
||||
return None
|
||||
if part.get("type") != "text":
|
||||
return None
|
||||
text = part.get("text")
|
||||
if not isinstance(text, str):
|
||||
return None
|
||||
parts.append(text)
|
||||
return "\n".join(parts) if parts else None
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _stringify_tool_result(content: Any) -> str:
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
if isinstance(content, list):
|
||||
parts: list[str] = []
|
||||
for part in content:
|
||||
if isinstance(part, dict):
|
||||
text = part.get("text")
|
||||
if isinstance(text, str):
|
||||
parts.append(text)
|
||||
if parts:
|
||||
return "".join(parts)
|
||||
return json.dumps(content, ensure_ascii=True, default=str)
|
||||
|
|
@ -640,6 +640,24 @@ def create_realtime_llm_service(user_config, audio_config: "AudioConfig"):
|
|||
),
|
||||
),
|
||||
)
|
||||
elif provider == ServiceProviders.ULTRAVOX_REALTIME.value:
|
||||
from api.services.pipecat.realtime.ultravox_realtime import (
|
||||
DograhUltravoxOneShotInputParams,
|
||||
DograhUltravoxRealtimeLLMService,
|
||||
)
|
||||
|
||||
return DograhUltravoxRealtimeLLMService(
|
||||
params=DograhUltravoxOneShotInputParams(
|
||||
api_key=api_key,
|
||||
model=model,
|
||||
voice=voice,
|
||||
output_medium="voice",
|
||||
),
|
||||
settings=DograhUltravoxRealtimeLLMService.Settings(
|
||||
model=model,
|
||||
output_medium="voice",
|
||||
),
|
||||
)
|
||||
elif provider == ServiceProviders.GOOGLE_REALTIME.value:
|
||||
from api.services.pipecat.realtime.gemini_live import (
|
||||
DograhGeminiLiveLLMService,
|
||||
|
|
|
|||
|
|
@ -5,9 +5,8 @@ provider registry — see ProviderSpec.router.
|
|||
"""
|
||||
|
||||
import json
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Header, Request
|
||||
from fastapi import APIRouter, Request
|
||||
from loguru import logger
|
||||
from pipecat.utils.run_context import set_current_run_id
|
||||
from starlette.responses import HTMLResponse
|
||||
|
|
@ -18,7 +17,6 @@ from api.services.telephony.status_processor import (
|
|||
StatusCallbackRequest,
|
||||
_process_status_update,
|
||||
)
|
||||
from api.utils.common import get_backend_endpoints
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
|
@ -26,9 +24,6 @@ router = APIRouter()
|
|||
async def _handle_plivo_status_callback(
|
||||
workflow_run_id: int,
|
||||
request: Request,
|
||||
x_plivo_signature_v3: Optional[str],
|
||||
x_plivo_signature_ma_v3: Optional[str],
|
||||
x_plivo_signature_v3_nonce: Optional[str],
|
||||
):
|
||||
set_current_run_id(workflow_run_id)
|
||||
|
||||
|
|
@ -52,19 +47,14 @@ async def _handle_plivo_status_callback(
|
|||
workflow_run, workflow.organization_id
|
||||
)
|
||||
|
||||
signature = x_plivo_signature_v3 or x_plivo_signature_ma_v3
|
||||
if signature:
|
||||
backend_endpoint, _ = await get_backend_endpoints()
|
||||
callback_kind = request.url.path.split("/")[-2]
|
||||
full_url = f"{backend_endpoint}/api/v1/telephony/plivo/{callback_kind}/{workflow_run_id}"
|
||||
is_valid = await provider.verify_inbound_signature(
|
||||
full_url,
|
||||
callback_data,
|
||||
dict(request.headers),
|
||||
)
|
||||
if not is_valid:
|
||||
logger.warning(f"[run {workflow_run_id}] Invalid Plivo webhook signature")
|
||||
return {"status": "error", "reason": "invalid_signature"}
|
||||
is_valid = await provider.verify_inbound_signature(
|
||||
str(request.url),
|
||||
callback_data,
|
||||
dict(request.headers),
|
||||
)
|
||||
if not is_valid:
|
||||
logger.warning(f"[run {workflow_run_id}] Invalid Plivo webhook signature")
|
||||
return {"status": "error", "reason": "invalid_signature"}
|
||||
|
||||
parsed_data = provider.parse_status_callback(callback_data)
|
||||
status_update = StatusCallbackRequest(
|
||||
|
|
@ -88,9 +78,6 @@ async def handle_plivo_xml_webhook(
|
|||
workflow_run_id: int,
|
||||
organization_id: int,
|
||||
request: Request,
|
||||
x_plivo_signature_v3: Optional[str] = Header(None),
|
||||
x_plivo_signature_ma_v3: Optional[str] = Header(None),
|
||||
x_plivo_signature_v3_nonce: Optional[str] = Header(None),
|
||||
):
|
||||
"""
|
||||
Handle initial webhook from Plivo when an outbound call is answered.
|
||||
|
|
@ -103,26 +90,16 @@ async def handle_plivo_xml_webhook(
|
|||
form_data = await request.form()
|
||||
callback_data = dict(form_data)
|
||||
|
||||
signature = x_plivo_signature_v3 or x_plivo_signature_ma_v3
|
||||
if signature:
|
||||
backend_endpoint, _ = await get_backend_endpoints()
|
||||
full_url = (
|
||||
f"{backend_endpoint}/api/v1/telephony/plivo-xml"
|
||||
f"?workflow_id={workflow_id}"
|
||||
f"&user_id={user_id}"
|
||||
f"&workflow_run_id={workflow_run_id}"
|
||||
f"&organization_id={organization_id}"
|
||||
is_valid = await provider.verify_inbound_signature(
|
||||
str(request.url), callback_data, dict(request.headers)
|
||||
)
|
||||
if not is_valid:
|
||||
logger.warning(
|
||||
f"[run {workflow_run_id}] Invalid Plivo signature on answer webhook"
|
||||
)
|
||||
is_valid = await provider.verify_inbound_signature(
|
||||
full_url, callback_data, dict(request.headers)
|
||||
return provider.generate_error_response(
|
||||
"invalid_signature", "Invalid webhook signature."
|
||||
)
|
||||
if not is_valid:
|
||||
logger.warning(
|
||||
f"[run {workflow_run_id}] Invalid Plivo signature on answer webhook"
|
||||
)
|
||||
return provider.generate_error_response(
|
||||
"invalid_signature", "Invalid webhook signature."
|
||||
)
|
||||
|
||||
call_id = callback_data.get("CallUUID") or callback_data.get("RequestUUID")
|
||||
if call_id:
|
||||
|
|
@ -142,33 +119,15 @@ async def handle_plivo_xml_webhook(
|
|||
async def handle_plivo_hangup_callback(
|
||||
workflow_run_id: int,
|
||||
request: Request,
|
||||
x_plivo_signature_v3: Optional[str] = Header(None),
|
||||
x_plivo_signature_ma_v3: Optional[str] = Header(None),
|
||||
x_plivo_signature_v3_nonce: Optional[str] = Header(None),
|
||||
):
|
||||
"""Handle Plivo hangup callbacks."""
|
||||
return await _handle_plivo_status_callback(
|
||||
workflow_run_id,
|
||||
request,
|
||||
x_plivo_signature_v3,
|
||||
x_plivo_signature_ma_v3,
|
||||
x_plivo_signature_v3_nonce,
|
||||
)
|
||||
return await _handle_plivo_status_callback(workflow_run_id, request)
|
||||
|
||||
|
||||
@router.post("/plivo/ring-callback/{workflow_run_id}")
|
||||
async def handle_plivo_ring_callback(
|
||||
workflow_run_id: int,
|
||||
request: Request,
|
||||
x_plivo_signature_v3: Optional[str] = Header(None),
|
||||
x_plivo_signature_ma_v3: Optional[str] = Header(None),
|
||||
x_plivo_signature_v3_nonce: Optional[str] = Header(None),
|
||||
):
|
||||
"""Handle Plivo ring callbacks."""
|
||||
return await _handle_plivo_status_callback(
|
||||
workflow_run_id,
|
||||
request,
|
||||
x_plivo_signature_v3,
|
||||
x_plivo_signature_ma_v3,
|
||||
x_plivo_signature_v3_nonce,
|
||||
)
|
||||
return await _handle_plivo_status_callback(workflow_run_id, request)
|
||||
|
|
|
|||
|
|
@ -5,9 +5,8 @@ provider registry — see ProviderSpec.router.
|
|||
"""
|
||||
|
||||
import json
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Header, Request
|
||||
from fastapi import APIRouter, HTTPException, Request
|
||||
from loguru import logger
|
||||
from pipecat.utils.run_context import set_current_run_id
|
||||
from starlette.responses import HTMLResponse
|
||||
|
|
@ -18,14 +17,17 @@ from api.services.telephony.status_processor import (
|
|||
StatusCallbackRequest,
|
||||
_process_status_update,
|
||||
)
|
||||
from api.utils.common import get_backend_endpoints
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.post("/twiml", include_in_schema=False)
|
||||
async def handle_twiml_webhook(
|
||||
workflow_id: int, user_id: int, workflow_run_id: int, organization_id: int
|
||||
workflow_id: int,
|
||||
user_id: int,
|
||||
workflow_run_id: int,
|
||||
organization_id: int,
|
||||
request: Request,
|
||||
):
|
||||
"""
|
||||
Handle initial webhook from telephony provider.
|
||||
|
|
@ -34,6 +36,18 @@ async def handle_twiml_webhook(
|
|||
|
||||
workflow_run = await db_client.get_workflow_run_by_id(workflow_run_id)
|
||||
provider = await get_telephony_provider_for_run(workflow_run, organization_id)
|
||||
callback_data = dict(await request.form())
|
||||
|
||||
is_valid = await provider.verify_inbound_signature(
|
||||
str(request.url),
|
||||
callback_data,
|
||||
dict(request.headers),
|
||||
)
|
||||
if not is_valid:
|
||||
logger.warning(
|
||||
f"[run {workflow_run_id}] Invalid Twilio signature on answer webhook"
|
||||
)
|
||||
raise HTTPException(status_code=401, detail="Invalid webhook signature")
|
||||
|
||||
response_content = await provider.get_webhook_response(
|
||||
workflow_id, user_id, workflow_run_id
|
||||
|
|
@ -46,7 +60,6 @@ async def handle_twiml_webhook(
|
|||
async def handle_twilio_status_callback(
|
||||
workflow_run_id: int,
|
||||
request: Request,
|
||||
x_webhook_signature: Optional[str] = Header(None),
|
||||
):
|
||||
"""Handle Twilio-specific status callbacks."""
|
||||
set_current_run_id(workflow_run_id)
|
||||
|
|
@ -75,19 +88,14 @@ async def handle_twilio_status_callback(
|
|||
workflow_run, workflow.organization_id
|
||||
)
|
||||
|
||||
if x_webhook_signature:
|
||||
backend_endpoint, _ = await get_backend_endpoints()
|
||||
full_url = f"{backend_endpoint}/api/v1/telephony/twilio/status-callback/{workflow_run_id}"
|
||||
|
||||
is_valid = await provider.verify_webhook_signature(
|
||||
full_url, callback_data, x_webhook_signature
|
||||
)
|
||||
|
||||
if not is_valid:
|
||||
logger.warning(
|
||||
f"Invalid webhook signature for workflow run {workflow_run_id}"
|
||||
)
|
||||
return {"status": "error", "reason": "invalid_signature"}
|
||||
is_valid = await provider.verify_inbound_signature(
|
||||
str(request.url),
|
||||
callback_data,
|
||||
dict(request.headers),
|
||||
)
|
||||
if not is_valid:
|
||||
logger.warning(f"Invalid webhook signature for workflow run {workflow_run_id}")
|
||||
raise HTTPException(status_code=401, detail="Invalid webhook signature")
|
||||
|
||||
# Parse the callback data into generic format
|
||||
parsed_data = provider.parse_status_callback(callback_data)
|
||||
|
|
|
|||
|
|
@ -81,9 +81,9 @@ async def handle_vobiz_hangup_callback(
|
|||
f"[run {workflow_run_id}] Vobiz hangup callback - Headers: {json.dumps(all_headers)}"
|
||||
)
|
||||
|
||||
# Parse the callback data (Vobiz sends form data or JSON)
|
||||
form_data = await request.form()
|
||||
callback_data = dict(form_data)
|
||||
# Parse the callback data from the raw body so signed webhooks can verify
|
||||
# the exact bytes Vobiz sent without draining the request stream first.
|
||||
callback_data, raw_body = await parse_webhook_request(request)
|
||||
|
||||
# TODO: Remove this debug logging after Vobiz team clarifies webhook authentication
|
||||
logger.info(
|
||||
|
|
@ -114,10 +114,6 @@ async def handle_vobiz_hangup_callback(
|
|||
workflow_run, workflow.organization_id
|
||||
)
|
||||
|
||||
# Get raw body for signature verification
|
||||
raw_body = await request.body()
|
||||
webhook_body = raw_body.decode("utf-8")
|
||||
|
||||
# Verify signature
|
||||
backend_endpoint, _ = await get_backend_endpoints()
|
||||
webhook_url = f"{backend_endpoint}/api/v1/telephony/vobiz/hangup-callback/{workflow_run_id}"
|
||||
|
|
@ -127,7 +123,7 @@ async def handle_vobiz_hangup_callback(
|
|||
callback_data,
|
||||
x_vobiz_signature,
|
||||
x_vobiz_timestamp,
|
||||
webhook_body,
|
||||
raw_body,
|
||||
)
|
||||
|
||||
if not is_valid:
|
||||
|
|
@ -206,9 +202,9 @@ async def handle_vobiz_ring_callback(
|
|||
f"[run {workflow_run_id}] Vobiz ring callback - Headers: {json.dumps(all_headers)}"
|
||||
)
|
||||
|
||||
# Parse the callback data
|
||||
form_data = await request.form()
|
||||
callback_data = dict(form_data)
|
||||
# Parse the callback data from the raw body so signed webhooks can verify
|
||||
# the exact bytes Vobiz sent without draining the request stream first.
|
||||
callback_data, raw_body = await parse_webhook_request(request)
|
||||
|
||||
# TODO: Remove this debug logging after Vobiz team clarifies webhook authentication
|
||||
logger.info(
|
||||
|
|
@ -240,10 +236,6 @@ async def handle_vobiz_ring_callback(
|
|||
workflow_run, workflow.organization_id
|
||||
)
|
||||
|
||||
# Get raw body for signature verification
|
||||
raw_body = await request.body()
|
||||
webhook_body = raw_body.decode("utf-8")
|
||||
|
||||
# Verify signature
|
||||
backend_endpoint, _ = await get_backend_endpoints()
|
||||
webhook_url = (
|
||||
|
|
@ -255,7 +247,7 @@ async def handle_vobiz_ring_callback(
|
|||
callback_data,
|
||||
x_vobiz_signature,
|
||||
x_vobiz_timestamp,
|
||||
webhook_body,
|
||||
raw_body,
|
||||
)
|
||||
|
||||
if not is_valid:
|
||||
|
|
@ -311,9 +303,10 @@ async def handle_vobiz_hangup_callback_by_workflow(
|
|||
)
|
||||
|
||||
try:
|
||||
callback_data, _ = await parse_webhook_request(request)
|
||||
callback_data, raw_body = await parse_webhook_request(request)
|
||||
except ValueError:
|
||||
callback_data = {}
|
||||
raw_body = ""
|
||||
|
||||
call_uuid = callback_data.get("CallUUID") or callback_data.get("call_uuid")
|
||||
logger.info(
|
||||
|
|
@ -356,8 +349,6 @@ async def handle_vobiz_hangup_callback_by_workflow(
|
|||
)
|
||||
|
||||
if x_vobiz_signature:
|
||||
raw_body = await request.body()
|
||||
webhook_body = raw_body.decode("utf-8")
|
||||
backend_endpoint, _ = await get_backend_endpoints()
|
||||
webhook_url = f"{backend_endpoint}/api/v1/telephony/vobiz/hangup-callback/workflow/{workflow_id}"
|
||||
|
||||
|
|
@ -366,7 +357,7 @@ async def handle_vobiz_hangup_callback_by_workflow(
|
|||
callback_data,
|
||||
x_vobiz_signature,
|
||||
x_vobiz_timestamp,
|
||||
webhook_body,
|
||||
raw_body,
|
||||
)
|
||||
|
||||
if not is_valid:
|
||||
|
|
|
|||
|
|
@ -297,6 +297,10 @@ class CustomToolManager:
|
|||
timeout_secs = 120.0
|
||||
handler = self._create_transfer_call_handler(tool, function_name)
|
||||
else:
|
||||
timeout_ms = ((tool.definition or {}).get("config", {}) or {}).get(
|
||||
"timeout_ms", 5000
|
||||
)
|
||||
timeout_secs = float(timeout_ms) / 1000
|
||||
handler = self._create_http_tool_handler(tool, function_name)
|
||||
|
||||
return handler, timeout_secs
|
||||
|
|
|
|||
185
api/tests/telephony/plivo/test_routes.py
Normal file
185
api/tests/telephony/plivo/test_routes.py
Normal file
|
|
@ -0,0 +1,185 @@
|
|||
import base64
|
||||
import hashlib
|
||||
import hmac
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, patch
|
||||
from urllib.parse import urlencode
|
||||
|
||||
import pytest
|
||||
from starlette.requests import Request
|
||||
|
||||
from api.services.telephony.providers.plivo.provider import PlivoProvider
|
||||
from api.services.telephony.providers.plivo.routes import (
|
||||
handle_plivo_hangup_callback,
|
||||
handle_plivo_xml_webhook,
|
||||
)
|
||||
|
||||
|
||||
def _provider() -> PlivoProvider:
|
||||
return PlivoProvider(
|
||||
{
|
||||
"auth_id": "MA123",
|
||||
"auth_token": "plivo-auth-token",
|
||||
"from_numbers": ["+15551230002"],
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def _request(
|
||||
*,
|
||||
path: str,
|
||||
query: dict[str, str | int],
|
||||
form_data: dict[str, str],
|
||||
headers: dict[str, str] | None = None,
|
||||
) -> Request:
|
||||
body = urlencode(form_data).encode("utf-8")
|
||||
query_string = urlencode(query).encode("utf-8")
|
||||
request_headers = [
|
||||
(b"content-type", b"application/x-www-form-urlencoded"),
|
||||
*[
|
||||
(name.lower().encode("ascii"), value.encode("ascii"))
|
||||
for name, value in (headers or {}).items()
|
||||
],
|
||||
]
|
||||
|
||||
async def receive():
|
||||
return {
|
||||
"type": "http.request",
|
||||
"body": body,
|
||||
"more_body": False,
|
||||
}
|
||||
|
||||
return Request(
|
||||
{
|
||||
"type": "http",
|
||||
"method": "POST",
|
||||
"scheme": "https",
|
||||
"server": ("example.test", 443),
|
||||
"path": path,
|
||||
"query_string": query_string,
|
||||
"headers": request_headers,
|
||||
},
|
||||
receive,
|
||||
)
|
||||
|
||||
|
||||
def _signature(
|
||||
provider: PlivoProvider,
|
||||
*,
|
||||
path: str,
|
||||
query: dict[str, str | int],
|
||||
form_data: dict[str, str],
|
||||
nonce: str,
|
||||
) -> str:
|
||||
url = f"https://example.test{path}"
|
||||
if query:
|
||||
url = f"{url}?{urlencode(query)}"
|
||||
payload = f"{provider._construct_post_url(url, form_data)}.{nonce}"
|
||||
return base64.b64encode(
|
||||
hmac.new(
|
||||
provider.auth_token.encode("utf-8"),
|
||||
payload.encode("utf-8"),
|
||||
hashlib.sha256,
|
||||
).digest()
|
||||
).decode("utf-8")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_plivo_xml_route_accepts_valid_signature_with_extra_query_param():
|
||||
provider = _provider()
|
||||
query = {
|
||||
"workflow_id": 7,
|
||||
"user_id": 8,
|
||||
"workflow_run_id": 123,
|
||||
"campaign_id": 42,
|
||||
"organization_id": 11,
|
||||
}
|
||||
form_data = {
|
||||
"CallUUID": "call-123",
|
||||
"Direction": "outbound",
|
||||
"From": "15551230001",
|
||||
"To": "15551230002",
|
||||
}
|
||||
nonce = "nonce-123"
|
||||
request = _request(
|
||||
path="/api/v1/telephony/plivo-xml",
|
||||
query=query,
|
||||
form_data=form_data,
|
||||
headers={
|
||||
"x-plivo-signature-v3": _signature(
|
||||
provider,
|
||||
path="/api/v1/telephony/plivo-xml",
|
||||
query=query,
|
||||
form_data=form_data,
|
||||
nonce=nonce,
|
||||
),
|
||||
"x-plivo-signature-v3-nonce": nonce,
|
||||
},
|
||||
)
|
||||
|
||||
with (
|
||||
patch("api.services.telephony.providers.plivo.routes.db_client") as db_client,
|
||||
patch(
|
||||
"api.services.telephony.providers.plivo.routes.get_telephony_provider_for_run",
|
||||
new_callable=AsyncMock,
|
||||
return_value=provider,
|
||||
),
|
||||
patch.object(
|
||||
provider,
|
||||
"get_webhook_response",
|
||||
new_callable=AsyncMock,
|
||||
return_value="<Response/>",
|
||||
) as get_webhook_response,
|
||||
):
|
||||
db_client.get_workflow_run_by_id = AsyncMock(
|
||||
return_value=SimpleNamespace(gathered_context={}, workflow_id=7)
|
||||
)
|
||||
db_client.update_workflow_run = AsyncMock()
|
||||
|
||||
response = await handle_plivo_xml_webhook(
|
||||
workflow_id=7,
|
||||
user_id=8,
|
||||
workflow_run_id=123,
|
||||
organization_id=11,
|
||||
request=request,
|
||||
)
|
||||
|
||||
assert response.body == b"<Response/>"
|
||||
get_webhook_response.assert_awaited_once_with(7, 8, 123)
|
||||
db_client.update_workflow_run.assert_awaited_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_plivo_status_callback_rejects_missing_signature():
|
||||
provider = _provider()
|
||||
request = _request(
|
||||
path="/api/v1/telephony/plivo/hangup-callback/123",
|
||||
query={},
|
||||
form_data={"CallUUID": "call-123", "Event": "hangup"},
|
||||
)
|
||||
|
||||
with (
|
||||
patch("api.services.telephony.providers.plivo.routes.db_client") as db_client,
|
||||
patch(
|
||||
"api.services.telephony.providers.plivo.routes.get_telephony_provider_for_run",
|
||||
new_callable=AsyncMock,
|
||||
return_value=provider,
|
||||
),
|
||||
patch(
|
||||
"api.services.telephony.providers.plivo.routes._process_status_update",
|
||||
new_callable=AsyncMock,
|
||||
) as process_status,
|
||||
):
|
||||
db_client.get_workflow_run_by_id = AsyncMock(
|
||||
return_value=SimpleNamespace(workflow_id=7)
|
||||
)
|
||||
db_client.get_workflow_by_id = AsyncMock(
|
||||
return_value=SimpleNamespace(organization_id=11)
|
||||
)
|
||||
|
||||
result = await handle_plivo_hangup_callback(
|
||||
workflow_run_id=123, request=request
|
||||
)
|
||||
|
||||
assert result == {"status": "error", "reason": "invalid_signature"}
|
||||
process_status.assert_not_awaited()
|
||||
253
api/tests/telephony/twilio/test_routes.py
Normal file
253
api/tests/telephony/twilio/test_routes.py
Normal file
|
|
@ -0,0 +1,253 @@
|
|||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, patch
|
||||
from urllib.parse import urlencode
|
||||
|
||||
import pytest
|
||||
from fastapi import HTTPException
|
||||
from starlette.requests import Request
|
||||
from twilio.request_validator import RequestValidator
|
||||
|
||||
from api.services.telephony.providers.twilio.provider import TwilioProvider
|
||||
from api.services.telephony.providers.twilio.routes import (
|
||||
handle_twilio_status_callback,
|
||||
handle_twiml_webhook,
|
||||
)
|
||||
|
||||
|
||||
def _provider() -> TwilioProvider:
|
||||
return TwilioProvider(
|
||||
{
|
||||
"account_sid": "AC123",
|
||||
"auth_token": "twilio-auth-token",
|
||||
"from_numbers": ["+15551230002"],
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def _request(
|
||||
*,
|
||||
path: str,
|
||||
query: dict[str, str | int],
|
||||
form_data: dict[str, str],
|
||||
headers: dict[str, str] | None = None,
|
||||
) -> Request:
|
||||
body = urlencode(form_data).encode("utf-8")
|
||||
query_string = urlencode(query).encode("utf-8")
|
||||
request_headers = [
|
||||
(b"content-type", b"application/x-www-form-urlencoded"),
|
||||
*[
|
||||
(name.lower().encode("ascii"), value.encode("ascii"))
|
||||
for name, value in (headers or {}).items()
|
||||
],
|
||||
]
|
||||
|
||||
async def receive():
|
||||
return {
|
||||
"type": "http.request",
|
||||
"body": body,
|
||||
"more_body": False,
|
||||
}
|
||||
|
||||
return Request(
|
||||
{
|
||||
"type": "http",
|
||||
"method": "POST",
|
||||
"scheme": "https",
|
||||
"server": ("example.test", 443),
|
||||
"path": path,
|
||||
"query_string": query_string,
|
||||
"headers": request_headers,
|
||||
},
|
||||
receive,
|
||||
)
|
||||
|
||||
|
||||
def _signature(
|
||||
provider: TwilioProvider,
|
||||
*,
|
||||
path: str,
|
||||
query: dict[str, str | int],
|
||||
form_data: dict[str, str],
|
||||
) -> str:
|
||||
url = f"https://example.test{path}"
|
||||
if query:
|
||||
url = f"{url}?{urlencode(query)}"
|
||||
validator = RequestValidator(provider.auth_token)
|
||||
return validator.compute_signature(url, form_data)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_twiml_route_accepts_valid_signature_with_extra_query_param():
|
||||
provider = _provider()
|
||||
query = {
|
||||
"workflow_id": 7,
|
||||
"user_id": 8,
|
||||
"workflow_run_id": 123,
|
||||
"campaign_id": 42,
|
||||
"organization_id": 11,
|
||||
}
|
||||
form_data = {"CallSid": "CA123", "CallStatus": "in-progress"}
|
||||
request = _request(
|
||||
path="/api/v1/telephony/twiml",
|
||||
query=query,
|
||||
form_data=form_data,
|
||||
headers={
|
||||
"x-twilio-signature": _signature(
|
||||
provider,
|
||||
path="/api/v1/telephony/twiml",
|
||||
query=query,
|
||||
form_data=form_data,
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
with (
|
||||
patch("api.services.telephony.providers.twilio.routes.db_client") as db_client,
|
||||
patch(
|
||||
"api.services.telephony.providers.twilio.routes.get_telephony_provider_for_run",
|
||||
new_callable=AsyncMock,
|
||||
return_value=provider,
|
||||
),
|
||||
patch.object(
|
||||
provider,
|
||||
"get_webhook_response",
|
||||
new_callable=AsyncMock,
|
||||
return_value="<Response/>",
|
||||
) as get_webhook_response,
|
||||
):
|
||||
db_client.get_workflow_run_by_id = AsyncMock(
|
||||
return_value=SimpleNamespace(id=123)
|
||||
)
|
||||
|
||||
response = await handle_twiml_webhook(
|
||||
workflow_id=7,
|
||||
user_id=8,
|
||||
workflow_run_id=123,
|
||||
organization_id=11,
|
||||
request=request,
|
||||
)
|
||||
|
||||
assert response.body == b"<Response/>"
|
||||
get_webhook_response.assert_awaited_once_with(7, 8, 123)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_twiml_route_rejects_missing_signature():
|
||||
provider = _provider()
|
||||
request = _request(
|
||||
path="/api/v1/telephony/twiml",
|
||||
query={
|
||||
"workflow_id": 7,
|
||||
"user_id": 8,
|
||||
"workflow_run_id": 123,
|
||||
"organization_id": 11,
|
||||
},
|
||||
form_data={"CallSid": "CA123"},
|
||||
)
|
||||
|
||||
with (
|
||||
patch("api.services.telephony.providers.twilio.routes.db_client") as db_client,
|
||||
patch(
|
||||
"api.services.telephony.providers.twilio.routes.get_telephony_provider_for_run",
|
||||
new_callable=AsyncMock,
|
||||
return_value=provider,
|
||||
),
|
||||
):
|
||||
db_client.get_workflow_run_by_id = AsyncMock(
|
||||
return_value=SimpleNamespace(id=123)
|
||||
)
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await handle_twiml_webhook(
|
||||
workflow_id=7,
|
||||
user_id=8,
|
||||
workflow_run_id=123,
|
||||
organization_id=11,
|
||||
request=request,
|
||||
)
|
||||
|
||||
assert exc_info.value.status_code == 401
|
||||
assert exc_info.value.detail == "Invalid webhook signature"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_twilio_status_callback_rejects_legacy_header_name():
|
||||
provider = _provider()
|
||||
form_data = {"CallSid": "CA123", "CallStatus": "completed"}
|
||||
request = _request(
|
||||
path="/api/v1/telephony/twilio/status-callback/123",
|
||||
query={},
|
||||
form_data=form_data,
|
||||
headers={"x-webhook-signature": "not-a-twilio-signature"},
|
||||
)
|
||||
|
||||
with (
|
||||
patch("api.services.telephony.providers.twilio.routes.db_client") as db_client,
|
||||
patch(
|
||||
"api.services.telephony.providers.twilio.routes.get_telephony_provider_for_run",
|
||||
new_callable=AsyncMock,
|
||||
return_value=provider,
|
||||
),
|
||||
patch(
|
||||
"api.services.telephony.providers.twilio.routes._process_status_update",
|
||||
new_callable=AsyncMock,
|
||||
) as process_status,
|
||||
):
|
||||
db_client.get_workflow_run_by_id = AsyncMock(
|
||||
return_value=SimpleNamespace(workflow_id=7)
|
||||
)
|
||||
db_client.get_workflow_by_id = AsyncMock(
|
||||
return_value=SimpleNamespace(organization_id=11)
|
||||
)
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await handle_twilio_status_callback(workflow_run_id=123, request=request)
|
||||
|
||||
assert exc_info.value.status_code == 401
|
||||
assert exc_info.value.detail == "Invalid webhook signature"
|
||||
process_status.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_twilio_status_callback_accepts_valid_signature():
|
||||
provider = _provider()
|
||||
form_data = {"CallSid": "CA123", "CallStatus": "completed"}
|
||||
request = _request(
|
||||
path="/api/v1/telephony/twilio/status-callback/123",
|
||||
query={},
|
||||
form_data=form_data,
|
||||
headers={
|
||||
"x-twilio-signature": _signature(
|
||||
provider,
|
||||
path="/api/v1/telephony/twilio/status-callback/123",
|
||||
query={},
|
||||
form_data=form_data,
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
with (
|
||||
patch("api.services.telephony.providers.twilio.routes.db_client") as db_client,
|
||||
patch(
|
||||
"api.services.telephony.providers.twilio.routes.get_telephony_provider_for_run",
|
||||
new_callable=AsyncMock,
|
||||
return_value=provider,
|
||||
),
|
||||
patch(
|
||||
"api.services.telephony.providers.twilio.routes._process_status_update",
|
||||
new_callable=AsyncMock,
|
||||
) as process_status,
|
||||
):
|
||||
db_client.get_workflow_run_by_id = AsyncMock(
|
||||
return_value=SimpleNamespace(workflow_id=7)
|
||||
)
|
||||
db_client.get_workflow_by_id = AsyncMock(
|
||||
return_value=SimpleNamespace(organization_id=11)
|
||||
)
|
||||
|
||||
result = await handle_twilio_status_callback(
|
||||
workflow_run_id=123, request=request
|
||||
)
|
||||
|
||||
assert result == {"status": "success"}
|
||||
process_status.assert_awaited_once()
|
||||
178
api/tests/telephony/vobiz/test_routes.py
Normal file
178
api/tests/telephony/vobiz/test_routes.py
Normal file
|
|
@ -0,0 +1,178 @@
|
|||
import hashlib
|
||||
import hmac
|
||||
from datetime import UTC, datetime
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, patch
|
||||
from urllib.parse import urlencode
|
||||
|
||||
import pytest
|
||||
from starlette.requests import Request
|
||||
|
||||
from api.services.telephony.providers.vobiz.provider import VobizProvider
|
||||
from api.services.telephony.providers.vobiz.routes import (
|
||||
handle_vobiz_hangup_callback,
|
||||
handle_vobiz_ring_callback,
|
||||
)
|
||||
|
||||
|
||||
def _provider() -> VobizProvider:
|
||||
return VobizProvider(
|
||||
{
|
||||
"auth_id": "MA123",
|
||||
"auth_token": "vobiz-auth-token",
|
||||
"from_numbers": ["+15551230002"],
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def _request(
|
||||
*,
|
||||
path: str,
|
||||
form_data: dict[str, str],
|
||||
headers: dict[str, str] | None = None,
|
||||
) -> Request:
|
||||
body = urlencode(form_data).encode("utf-8")
|
||||
request_headers = [
|
||||
(b"content-type", b"application/x-www-form-urlencoded"),
|
||||
*[
|
||||
(name.lower().encode("ascii"), value.encode("ascii"))
|
||||
for name, value in (headers or {}).items()
|
||||
],
|
||||
]
|
||||
|
||||
async def receive():
|
||||
return {
|
||||
"type": "http.request",
|
||||
"body": body,
|
||||
"more_body": False,
|
||||
}
|
||||
|
||||
return Request(
|
||||
{
|
||||
"type": "http",
|
||||
"method": "POST",
|
||||
"scheme": "https",
|
||||
"server": ("example.test", 443),
|
||||
"path": path,
|
||||
"query_string": b"",
|
||||
"headers": request_headers,
|
||||
},
|
||||
receive,
|
||||
)
|
||||
|
||||
|
||||
def _signed_headers(
|
||||
provider: VobizProvider, *, form_data: dict[str, str]
|
||||
) -> dict[str, str]:
|
||||
timestamp = str(int(datetime.now(UTC).timestamp()))
|
||||
body = urlencode(form_data)
|
||||
signature = hmac.new(
|
||||
provider.auth_token.encode("utf-8"),
|
||||
f"{timestamp}.{body}".encode("utf-8"),
|
||||
hashlib.sha256,
|
||||
).hexdigest()
|
||||
return {
|
||||
"x-vobiz-signature": signature,
|
||||
"x-vobiz-timestamp": timestamp,
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_vobiz_hangup_callback_accepts_signed_form_body():
|
||||
provider = _provider()
|
||||
form_data = {
|
||||
"CallUUID": "call-123",
|
||||
"CallStatus": "completed",
|
||||
"From": "15551230001",
|
||||
"To": "15551230002",
|
||||
"Direction": "outbound",
|
||||
"Duration": "12",
|
||||
}
|
||||
headers = _signed_headers(provider, form_data=form_data)
|
||||
request = _request(
|
||||
path="/api/v1/telephony/vobiz/hangup-callback/123",
|
||||
form_data=form_data,
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
with (
|
||||
patch("api.services.telephony.providers.vobiz.routes.db_client") as db_client,
|
||||
patch(
|
||||
"api.services.telephony.providers.vobiz.routes.get_telephony_provider_for_run",
|
||||
new_callable=AsyncMock,
|
||||
return_value=provider,
|
||||
),
|
||||
patch(
|
||||
"api.services.telephony.providers.vobiz.routes.get_backend_endpoints",
|
||||
new_callable=AsyncMock,
|
||||
return_value=("https://example.test", "wss://example.test"),
|
||||
),
|
||||
patch(
|
||||
"api.services.telephony.providers.vobiz.routes._process_status_update",
|
||||
new_callable=AsyncMock,
|
||||
) as process_status,
|
||||
):
|
||||
db_client.get_workflow_run_by_id = AsyncMock(
|
||||
return_value=SimpleNamespace(workflow_id=7)
|
||||
)
|
||||
db_client.get_workflow_by_id = AsyncMock(
|
||||
return_value=SimpleNamespace(organization_id=11)
|
||||
)
|
||||
|
||||
result = await handle_vobiz_hangup_callback(
|
||||
workflow_run_id=123,
|
||||
request=request,
|
||||
x_vobiz_signature=headers["x-vobiz-signature"],
|
||||
x_vobiz_timestamp=headers["x-vobiz-timestamp"],
|
||||
)
|
||||
|
||||
assert result == {"status": "success"}
|
||||
process_status.assert_awaited_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_vobiz_ring_callback_accepts_signed_form_body():
|
||||
provider = _provider()
|
||||
form_data = {
|
||||
"CallUUID": "call-123",
|
||||
"CallStatus": "ringing",
|
||||
"From": "15551230001",
|
||||
"To": "15551230002",
|
||||
}
|
||||
headers = _signed_headers(provider, form_data=form_data)
|
||||
request = _request(
|
||||
path="/api/v1/telephony/vobiz/ring-callback/123",
|
||||
form_data=form_data,
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
workflow_run = SimpleNamespace(workflow_id=7, logs={})
|
||||
|
||||
with (
|
||||
patch("api.services.telephony.providers.vobiz.routes.db_client") as db_client,
|
||||
patch(
|
||||
"api.services.telephony.providers.vobiz.routes.get_telephony_provider_for_run",
|
||||
new_callable=AsyncMock,
|
||||
return_value=provider,
|
||||
),
|
||||
patch(
|
||||
"api.services.telephony.providers.vobiz.routes.get_backend_endpoints",
|
||||
new_callable=AsyncMock,
|
||||
return_value=("https://example.test", "wss://example.test"),
|
||||
),
|
||||
):
|
||||
db_client.get_workflow_run_by_id = AsyncMock(return_value=workflow_run)
|
||||
db_client.get_workflow_by_id = AsyncMock(
|
||||
return_value=SimpleNamespace(organization_id=11)
|
||||
)
|
||||
db_client.update_workflow_run = AsyncMock()
|
||||
|
||||
result = await handle_vobiz_ring_callback(
|
||||
workflow_run_id=123,
|
||||
request=request,
|
||||
x_vobiz_signature=headers["x-vobiz-signature"],
|
||||
x_vobiz_timestamp=headers["x-vobiz-timestamp"],
|
||||
)
|
||||
|
||||
assert result == {"status": "success"}
|
||||
db_client.update_workflow_run.assert_awaited_once()
|
||||
|
|
@ -935,9 +935,11 @@ class TestCustomToolManagerUnit:
|
|||
# Create a mock engine with a mock LLM
|
||||
mock_llm = Mock()
|
||||
registered_handlers = {}
|
||||
registered_kwargs = {}
|
||||
|
||||
def capture_register(name, handler, **kwargs):
|
||||
registered_handlers[name] = handler
|
||||
registered_kwargs[name] = kwargs
|
||||
|
||||
mock_llm.register_function = capture_register
|
||||
|
||||
|
|
@ -986,6 +988,7 @@ class TestCustomToolManagerUnit:
|
|||
|
||||
# Verify handler was registered
|
||||
assert "api_call" in registered_handlers
|
||||
assert registered_kwargs["api_call"]["timeout_secs"] == pytest.approx(5)
|
||||
|
||||
# Now test that the handler works
|
||||
handler = registered_handlers["api_call"]
|
||||
|
|
|
|||
|
|
@ -313,6 +313,13 @@ class TestDispatcherThreadsTelephonyConfig:
|
|||
f"kwargs={store_kwargs}"
|
||||
)
|
||||
|
||||
assert provider.initiate_call.await_count == 1
|
||||
webhook_url = provider.initiate_call.await_args.kwargs["webhook_url"]
|
||||
assert "campaign_id=" not in webhook_url, (
|
||||
"campaign outbound answer_url should not include campaign_id; "
|
||||
f"got {webhook_url}"
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_release_call_slot_uses_stored_telephony_config(self):
|
||||
"""When a call completes, release_call_slot must release the from_number
|
||||
|
|
|
|||
|
|
@ -19,7 +19,7 @@ class TestGoogleVertexLLMConfiguration:
|
|||
config = GoogleVertexLLMConfiguration(project_id="demo-project")
|
||||
assert config.provider == ServiceProviders.GOOGLE_VERTEX
|
||||
assert config.model == "gemini-2.5-flash"
|
||||
assert config.location == "us-east4"
|
||||
assert config.location == "global"
|
||||
assert config.credentials is None
|
||||
assert config.api_key is None
|
||||
|
||||
|
|
|
|||
|
|
@ -17,6 +17,7 @@ from api.services.configuration.registry import (
|
|||
GoogleVertexLLMConfiguration,
|
||||
GrokRealtimeLLMConfiguration,
|
||||
OpenAILLMService,
|
||||
UltravoxRealtimeLLMConfiguration,
|
||||
)
|
||||
from api.services.configuration.resolve import resolve_effective_config
|
||||
|
||||
|
|
@ -261,6 +262,22 @@ class TestRealtimeOverride:
|
|||
assert result.realtime.provider == "grok_realtime"
|
||||
assert result.realtime.voice == "Sal"
|
||||
|
||||
def test_switch_realtime_provider_to_ultravox(self, global_config_realtime):
|
||||
result = resolve_effective_config(
|
||||
global_config_realtime,
|
||||
{
|
||||
"realtime": {
|
||||
"provider": "ultravox_realtime",
|
||||
"api_key": "ultra-key",
|
||||
"model": "ultravox-v0.7",
|
||||
"voice": "Mark",
|
||||
}
|
||||
},
|
||||
)
|
||||
assert isinstance(result.realtime, UltravoxRealtimeLLMConfiguration)
|
||||
assert result.realtime.provider == "ultravox_realtime"
|
||||
assert result.realtime.voice == "Mark"
|
||||
|
||||
def test_override_is_realtime_only_without_realtime_section(self, global_config):
|
||||
"""Override is_realtime=True but provide no realtime config.
|
||||
Should set the flag; realtime section stays None from global."""
|
||||
|
|
|
|||
158
api/tests/test_telephony_routes.py
Normal file
158
api/tests/test_telephony_routes.py
Normal file
|
|
@ -0,0 +1,158 @@
|
|||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from api.routes.telephony import router
|
||||
from api.services.auth.depends import get_user
|
||||
|
||||
|
||||
def _make_test_app() -> FastAPI:
|
||||
app = FastAPI()
|
||||
app.include_router(router)
|
||||
app.dependency_overrides[get_user] = lambda: SimpleNamespace(
|
||||
id=7,
|
||||
selected_organization_id=11,
|
||||
)
|
||||
return app
|
||||
|
||||
|
||||
def _workflow(*, workflow_id: int = 33, user_id: int = 99):
|
||||
return SimpleNamespace(
|
||||
id=workflow_id,
|
||||
user_id=user_id,
|
||||
organization_id=11,
|
||||
template_context_variables={"template_key": "template-value"},
|
||||
)
|
||||
|
||||
|
||||
def _provider():
|
||||
return SimpleNamespace(
|
||||
PROVIDER_NAME="twilio",
|
||||
WEBHOOK_ENDPOINT="twilio/voice",
|
||||
validate_config=Mock(return_value=True),
|
||||
initiate_call=AsyncMock(
|
||||
return_value=SimpleNamespace(
|
||||
caller_number="+15550001111",
|
||||
provider_metadata={"call_id": "call-123"},
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def test_initiate_call_executes_as_workflow_owner_for_shared_org_workflow():
|
||||
app = _make_test_app()
|
||||
client = TestClient(app)
|
||||
|
||||
workflow = _workflow()
|
||||
provider = _provider()
|
||||
quota_mock = AsyncMock(
|
||||
return_value=SimpleNamespace(has_quota=True, error_message="")
|
||||
)
|
||||
|
||||
with (
|
||||
patch("api.routes.telephony.db_client") as mock_db,
|
||||
patch(
|
||||
"api.routes.telephony.check_dograh_quota_by_user_id",
|
||||
new=quota_mock,
|
||||
),
|
||||
patch(
|
||||
"api.routes.telephony.get_default_telephony_provider",
|
||||
new=AsyncMock(return_value=provider),
|
||||
),
|
||||
patch(
|
||||
"api.routes.telephony.get_backend_endpoints",
|
||||
new=AsyncMock(return_value=("https://api.example.com", "wss://ignored")),
|
||||
),
|
||||
):
|
||||
mock_db.get_user_configurations = AsyncMock(
|
||||
return_value=SimpleNamespace(test_phone_number=None)
|
||||
)
|
||||
mock_db.get_default_telephony_configuration = AsyncMock(
|
||||
return_value=SimpleNamespace(id=55)
|
||||
)
|
||||
mock_db.get_workflow = AsyncMock(return_value=workflow)
|
||||
mock_db.create_workflow_run = AsyncMock(
|
||||
return_value=SimpleNamespace(
|
||||
id=501,
|
||||
name="WR-TEL-OUT-00000001",
|
||||
initial_context={"template_key": "template-value"},
|
||||
)
|
||||
)
|
||||
mock_db.update_workflow_run = AsyncMock()
|
||||
|
||||
response = client.post(
|
||||
"/telephony/initiate-call",
|
||||
json={"workflow_id": workflow.id, "phone_number": "+15551234567"},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
quota_mock.assert_awaited_once_with(workflow.user_id, workflow_id=workflow.id)
|
||||
mock_db.get_workflow.assert_awaited_once_with(workflow.id, organization_id=11)
|
||||
|
||||
create_call = mock_db.create_workflow_run.await_args
|
||||
create_args = create_call.args
|
||||
create_kwargs = create_call.kwargs
|
||||
assert create_args[1] == workflow.id
|
||||
assert create_kwargs["user_id"] == workflow.user_id
|
||||
assert create_kwargs["organization_id"] == workflow.organization_id
|
||||
assert create_kwargs["initial_context"]["template_key"] == "template-value"
|
||||
|
||||
initiate_kwargs = provider.initiate_call.await_args.kwargs
|
||||
assert initiate_kwargs["workflow_id"] == workflow.id
|
||||
assert initiate_kwargs["user_id"] == workflow.user_id
|
||||
assert "user_id=99" in initiate_kwargs["webhook_url"]
|
||||
|
||||
|
||||
def test_initiate_call_rejects_existing_run_for_different_workflow():
|
||||
app = _make_test_app()
|
||||
client = TestClient(app)
|
||||
|
||||
workflow = _workflow()
|
||||
provider = _provider()
|
||||
quota_mock = AsyncMock(
|
||||
return_value=SimpleNamespace(has_quota=True, error_message="")
|
||||
)
|
||||
|
||||
with (
|
||||
patch("api.routes.telephony.db_client") as mock_db,
|
||||
patch(
|
||||
"api.routes.telephony.check_dograh_quota_by_user_id",
|
||||
new=quota_mock,
|
||||
),
|
||||
patch(
|
||||
"api.routes.telephony.get_default_telephony_provider",
|
||||
new=AsyncMock(return_value=provider),
|
||||
),
|
||||
):
|
||||
mock_db.get_user_configurations = AsyncMock(
|
||||
return_value=SimpleNamespace(test_phone_number=None)
|
||||
)
|
||||
mock_db.get_default_telephony_configuration = AsyncMock(
|
||||
return_value=SimpleNamespace(id=55)
|
||||
)
|
||||
mock_db.get_workflow = AsyncMock(return_value=workflow)
|
||||
mock_db.get_workflow_run = AsyncMock(
|
||||
return_value=SimpleNamespace(
|
||||
id=501,
|
||||
workflow_id=44,
|
||||
name="WR-TEL-OUT-00000044",
|
||||
initial_context={},
|
||||
)
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
"/telephony/initiate-call",
|
||||
json={
|
||||
"workflow_id": workflow.id,
|
||||
"workflow_run_id": 501,
|
||||
"phone_number": "+15551234567",
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 400
|
||||
assert response.json()["detail"] == "workflow_run_workflow_mismatch"
|
||||
mock_db.get_workflow_run.assert_awaited_once_with(501, organization_id=11)
|
||||
assert not mock_db.create_workflow_run.called
|
||||
assert provider.initiate_call.await_count == 0
|
||||
459
api/tests/test_ultravox_realtime_wrapper.py
Normal file
459
api/tests/test_ultravox_realtime_wrapper.py
Normal file
|
|
@ -0,0 +1,459 @@
|
|||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, call
|
||||
|
||||
import pytest
|
||||
from pipecat.adapters.schemas.function_schema import FunctionSchema
|
||||
from pipecat.adapters.schemas.tools_schema import ToolsSchema
|
||||
from pipecat.frames.frames import LLMMessagesAppendFrame, TTSSpeakFrame
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from websockets.exceptions import ConnectionClosedError
|
||||
from websockets.frames import Close
|
||||
|
||||
from api.schemas.user_configuration import UserConfiguration
|
||||
from api.services.configuration.registry import UltravoxRealtimeLLMConfiguration
|
||||
from api.services.pipecat.realtime.ultravox_realtime import (
|
||||
_RESUMPTION_USER_MESSAGE,
|
||||
DograhUltravoxOneShotInputParams,
|
||||
DograhUltravoxRealtimeLLMService,
|
||||
)
|
||||
from api.services.pipecat.service_factory import create_realtime_llm_service
|
||||
|
||||
|
||||
class _ClosingSocket:
|
||||
def __init__(self, exc):
|
||||
self._exc = exc
|
||||
|
||||
def __aiter__(self):
|
||||
return self
|
||||
|
||||
async def __anext__(self):
|
||||
raise self._exc
|
||||
|
||||
|
||||
class _MessageSocket:
|
||||
def __init__(self, messages):
|
||||
self._messages = iter(messages)
|
||||
|
||||
def __aiter__(self):
|
||||
return self
|
||||
|
||||
async def __anext__(self):
|
||||
try:
|
||||
return next(self._messages)
|
||||
except StopIteration:
|
||||
raise StopAsyncIteration
|
||||
|
||||
|
||||
def _make_service() -> DograhUltravoxRealtimeLLMService:
|
||||
service = DograhUltravoxRealtimeLLMService(
|
||||
params=DograhUltravoxOneShotInputParams(
|
||||
api_key="test-key",
|
||||
model="ultravox-v0.7",
|
||||
output_medium="voice",
|
||||
),
|
||||
settings=DograhUltravoxRealtimeLLMService.Settings(
|
||||
model="ultravox-v0.7",
|
||||
output_medium="voice",
|
||||
),
|
||||
)
|
||||
service.stop_all_metrics = AsyncMock()
|
||||
service.cancel_task = AsyncMock()
|
||||
service.push_error = AsyncMock()
|
||||
return service
|
||||
|
||||
|
||||
def _tool_schema() -> ToolsSchema:
|
||||
return ToolsSchema(
|
||||
standard_tools=[
|
||||
FunctionSchema(
|
||||
name="transition_to_next_node",
|
||||
description="Move to the next workflow node",
|
||||
properties={"reason": {"type": "string"}},
|
||||
required=[],
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tts_greeting_triggers_initial_connect():
|
||||
service = _make_service()
|
||||
service._connect_call = AsyncMock()
|
||||
|
||||
await service.process_frame(
|
||||
TTSSpeakFrame("Hello there", append_to_context=True),
|
||||
FrameDirection.DOWNSTREAM,
|
||||
)
|
||||
|
||||
service._connect_call.assert_awaited_once()
|
||||
assert service._connect_call.await_args.kwargs["greeting_text"] == "Hello there"
|
||||
assert service._connect_call.await_args.kwargs["agent_speaks_first"] is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_initial_context_connects_without_replay():
|
||||
service = _make_service()
|
||||
service._connect_call = AsyncMock()
|
||||
context = LLMContext()
|
||||
|
||||
await service._handle_context(context)
|
||||
|
||||
service._connect_call.assert_awaited_once()
|
||||
assert service._connect_call.await_args.kwargs["initial_messages"] is None
|
||||
assert service._connect_call.await_args.kwargs["agent_speaks_first"] is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_system_instruction_update_marks_reconnect_required():
|
||||
service = _make_service()
|
||||
service._has_connected_once = True
|
||||
|
||||
changed = await service._update_settings(
|
||||
DograhUltravoxRealtimeLLMService.Settings(system_instruction="new instruction")
|
||||
)
|
||||
|
||||
assert "system_instruction" in changed
|
||||
assert service._reconnect_required is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_system_instruction_change_reconnects_with_full_initial_messages():
|
||||
service = _make_service()
|
||||
service._socket = object()
|
||||
service._has_connected_once = True
|
||||
service._call_system_instruction = "old instruction"
|
||||
service._reconnect_required = True
|
||||
service._settings.system_instruction = "new instruction"
|
||||
service._reconnect_with_context = AsyncMock()
|
||||
|
||||
context = LLMContext(
|
||||
messages=[
|
||||
{"role": "user", "content": "I want to hear the pricing."},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "Let me check that for you.",
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "call-transition",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "transition_to_next_node",
|
||||
"arguments": '{"reason":"pricing requested"}',
|
||||
},
|
||||
}
|
||||
],
|
||||
},
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": "call-transition",
|
||||
"content": '{"status":"done"}',
|
||||
},
|
||||
],
|
||||
tools=_tool_schema(),
|
||||
)
|
||||
|
||||
await service._handle_context(context)
|
||||
|
||||
service._reconnect_with_context.assert_awaited_once()
|
||||
initial_messages = service._reconnect_with_context.await_args.kwargs[
|
||||
"initial_messages"
|
||||
]
|
||||
assert initial_messages == [
|
||||
{
|
||||
"role": "MESSAGE_ROLE_USER",
|
||||
"text": "I want to hear the pricing.",
|
||||
},
|
||||
{
|
||||
"role": "MESSAGE_ROLE_AGENT",
|
||||
"text": "Let me check that for you.",
|
||||
},
|
||||
{
|
||||
"role": "MESSAGE_ROLE_TOOL_CALL",
|
||||
"text": "",
|
||||
"invocationId": "call-transition",
|
||||
"toolName": "transition_to_next_node",
|
||||
},
|
||||
{
|
||||
"role": "MESSAGE_ROLE_TOOL_RESULT",
|
||||
"text": '{"status":"done"}',
|
||||
"invocationId": "call-transition",
|
||||
"toolName": "transition_to_next_node",
|
||||
},
|
||||
]
|
||||
assert "call-transition" in service._completed_tool_calls
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tool_context_update_does_not_reconnect_when_system_instruction_is_unchanged():
|
||||
service = _make_service()
|
||||
service._socket = object()
|
||||
service._call_system_instruction = "same instruction"
|
||||
service._settings.system_instruction = "same instruction"
|
||||
service._reconnect_with_context = AsyncMock()
|
||||
service._send_tool_result = AsyncMock()
|
||||
|
||||
context = LLMContext(
|
||||
messages=[
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": "call-transition",
|
||||
"content": '{"status":"done"}',
|
||||
},
|
||||
],
|
||||
tools=_tool_schema(),
|
||||
)
|
||||
|
||||
await service._handle_context(context)
|
||||
|
||||
service._reconnect_with_context.assert_not_awaited()
|
||||
service._send_tool_result.assert_awaited_once_with(
|
||||
"call-transition",
|
||||
'{"status":"done"}',
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_messages_append_frame_sends_user_text():
|
||||
service = _make_service()
|
||||
service._socket = object()
|
||||
service._call_started = True
|
||||
service._send_user_text = AsyncMock()
|
||||
|
||||
await service._handle_messages_append(
|
||||
LLMMessagesAppendFrame(
|
||||
[{"role": "user", "content": "Are you still there?"}],
|
||||
run_llm=True,
|
||||
)
|
||||
)
|
||||
|
||||
service._send_user_text.assert_awaited_once_with("Are you still there?")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_messages_append_frame_queues_user_text_until_call_started():
|
||||
service = _make_service()
|
||||
service._socket = object()
|
||||
service._call_started = False
|
||||
service._send_user_text = AsyncMock()
|
||||
|
||||
await service._handle_messages_append(
|
||||
LLMMessagesAppendFrame(
|
||||
[{"role": "user", "content": "Are you still there?"}],
|
||||
run_llm=True,
|
||||
)
|
||||
)
|
||||
|
||||
assert service._pending_user_text_messages == ["Are you still there?"]
|
||||
service._send_user_text.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_call_started_flushes_pending_user_text_messages():
|
||||
service = _make_service()
|
||||
service._pending_user_text_messages = [
|
||||
"First queued message",
|
||||
"Second queued message",
|
||||
]
|
||||
service._send_user_text = AsyncMock()
|
||||
service._socket = _MessageSocket(['{"type":"call_started","callId":"call-123"}'])
|
||||
|
||||
await service._receive_messages()
|
||||
|
||||
assert service._call_started is True
|
||||
assert service._pending_user_text_messages == []
|
||||
assert service._send_user_text.await_args_list == [
|
||||
call("First queued message"),
|
||||
call("Second queued message"),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_completed_input_transcription_is_broadcast_as_finalized():
|
||||
service = _make_service()
|
||||
service.broadcast_frame = AsyncMock()
|
||||
service._last_user_id = "caller-1"
|
||||
|
||||
await service._handle_user_transcript("Hello there")
|
||||
|
||||
service.broadcast_frame.assert_awaited_once()
|
||||
assert service.broadcast_frame.await_args.args[0].__name__ == "TranscriptionFrame"
|
||||
assert service.broadcast_frame.await_args.kwargs["text"] == "Hello there"
|
||||
assert service.broadcast_frame.await_args.kwargs["finalized"] is True
|
||||
|
||||
|
||||
def test_build_one_shot_params_uses_explicit_greeting_text():
|
||||
service = _make_service()
|
||||
|
||||
params = service._build_one_shot_params(
|
||||
greeting_text="Welcome to Dograh",
|
||||
initial_messages=None,
|
||||
agent_speaks_first=True,
|
||||
)
|
||||
|
||||
assert params.extra["firstSpeakerSettings"] == {
|
||||
"agent": {"text": "Welcome to Dograh"}
|
||||
}
|
||||
|
||||
|
||||
def test_build_one_shot_params_includes_initial_messages():
|
||||
service = _make_service()
|
||||
service._settings.system_instruction = "Base instruction"
|
||||
|
||||
params = service._build_one_shot_params(
|
||||
greeting_text=None,
|
||||
initial_messages=[
|
||||
{"role": "MESSAGE_ROLE_USER", "text": "User asked a question."},
|
||||
{"role": "MESSAGE_ROLE_TOOL_RESULT", "text": '{"status":"done"}'},
|
||||
],
|
||||
agent_speaks_first=True,
|
||||
)
|
||||
|
||||
assert params.extra["initialMessages"] == [
|
||||
{"role": "MESSAGE_ROLE_USER", "text": "User asked a question."},
|
||||
{"role": "MESSAGE_ROLE_TOOL_RESULT", "text": '{"status":"done"}'},
|
||||
{"role": "MESSAGE_ROLE_USER", "text": _RESUMPTION_USER_MESSAGE},
|
||||
]
|
||||
assert params.system_prompt == "Base instruction"
|
||||
|
||||
|
||||
def test_build_one_shot_params_without_tool_result_does_not_add_resumption_user_message():
|
||||
service = _make_service()
|
||||
service._settings.system_instruction = "Base instruction"
|
||||
|
||||
params = service._build_one_shot_params(
|
||||
greeting_text=None,
|
||||
initial_messages=[
|
||||
{"role": "MESSAGE_ROLE_USER", "text": "User asked a question."},
|
||||
{"role": "MESSAGE_ROLE_AGENT", "text": "Assistant replied."},
|
||||
],
|
||||
agent_speaks_first=False,
|
||||
)
|
||||
|
||||
assert params.system_prompt == "Base instruction"
|
||||
|
||||
|
||||
def test_should_agent_speak_first_when_history_ends_with_tool_result():
|
||||
service = _make_service()
|
||||
|
||||
assert (
|
||||
service._should_agent_speak_first(
|
||||
[
|
||||
{"role": "MESSAGE_ROLE_USER", "text": "Hello"},
|
||||
{"role": "MESSAGE_ROLE_TOOL_RESULT", "text": '{"status":"done"}'},
|
||||
]
|
||||
)
|
||||
is True
|
||||
)
|
||||
|
||||
|
||||
def test_should_not_force_agent_speaks_first_when_history_ends_with_agent():
|
||||
service = _make_service()
|
||||
|
||||
assert (
|
||||
service._should_agent_speak_first(
|
||||
[{"role": "MESSAGE_ROLE_AGENT", "text": "How else can I help?"}]
|
||||
)
|
||||
is False
|
||||
)
|
||||
|
||||
|
||||
def test_should_add_resumption_user_message_only_when_history_ends_with_tool_result():
|
||||
service = _make_service()
|
||||
|
||||
assert (
|
||||
service._should_add_resumption_user_message(
|
||||
[{"role": "MESSAGE_ROLE_TOOL_RESULT", "text": '{"status":"done"}'}]
|
||||
)
|
||||
is True
|
||||
)
|
||||
assert (
|
||||
service._should_add_resumption_user_message(
|
||||
[{"role": "MESSAGE_ROLE_AGENT", "text": "Assistant replied."}]
|
||||
)
|
||||
is False
|
||||
)
|
||||
|
||||
|
||||
def test_to_selected_tools_includes_registered_timeout():
|
||||
service = _make_service()
|
||||
service.register_function(
|
||||
"transition_to_next_node",
|
||||
AsyncMock(),
|
||||
timeout_secs=5.5,
|
||||
)
|
||||
|
||||
selected_tools = service._to_selected_tools(_tool_schema())
|
||||
|
||||
assert selected_tools == [
|
||||
{
|
||||
"temporaryTool": {
|
||||
"modelToolName": "transition_to_next_node",
|
||||
"description": "Move to the next workflow node",
|
||||
"dynamicParameters": [
|
||||
{
|
||||
"name": "reason",
|
||||
"location": "PARAMETER_LOCATION_BODY",
|
||||
"schema": {"type": "string"},
|
||||
"required": False,
|
||||
}
|
||||
],
|
||||
"client": {},
|
||||
"timeout": "5.5s",
|
||||
}
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_receive_messages_ignores_benign_websocket_close():
|
||||
service = _make_service()
|
||||
service._socket = _ClosingSocket(
|
||||
ConnectionClosedError(None, Close(1000, "OK"), None)
|
||||
)
|
||||
|
||||
await service._receive_messages()
|
||||
|
||||
service.push_error.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_receive_messages_reports_unexpected_websocket_close():
|
||||
service = _make_service()
|
||||
service._socket = _ClosingSocket(
|
||||
ConnectionClosedError(None, Close(1011, "internal error"), None)
|
||||
)
|
||||
|
||||
await service._receive_messages()
|
||||
|
||||
service.push_error.assert_awaited_once()
|
||||
|
||||
|
||||
def test_factory_creates_dograh_ultravox_realtime_service():
|
||||
user_config = UserConfiguration(
|
||||
is_realtime=True,
|
||||
realtime=UltravoxRealtimeLLMConfiguration(
|
||||
provider="ultravox_realtime",
|
||||
api_key="ultra-key",
|
||||
model="ultravox-v0.7",
|
||||
voice="Mark",
|
||||
),
|
||||
)
|
||||
|
||||
service = create_realtime_llm_service(
|
||||
user_config,
|
||||
audio_config=SimpleNamespace(),
|
||||
)
|
||||
|
||||
assert isinstance(service, DograhUltravoxRealtimeLLMService)
|
||||
assert service._params.voice == "Mark"
|
||||
|
||||
|
||||
def test_ultravox_realtime_configuration_defaults_to_mark_voice():
|
||||
config = UltravoxRealtimeLLMConfiguration(
|
||||
provider="ultravox_realtime",
|
||||
api_key="ultra-key",
|
||||
model="ultravox-v0.7",
|
||||
)
|
||||
|
||||
assert config.voice == "Mark"
|
||||
|
|
@ -1,6 +1,6 @@
|
|||
[project]
|
||||
name = "dograh-sdk"
|
||||
version = "0.1.5"
|
||||
version = "0.1.6"
|
||||
description = "Typed builder for Dograh voice-AI workflows"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.10"
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
# generated by datamodel-codegen:
|
||||
# filename: dograh-openapi-XXXXXX.json.SafScGt2nh
|
||||
# timestamp: 2026-05-22T09:06:50+00:00
|
||||
# filename: dograh-openapi-XXXXXX.json.ahnZ2z2E21
|
||||
# timestamp: 2026-05-23T03:32:29+00:00
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
{
|
||||
"name": "@dograh/sdk",
|
||||
"version": "0.1.5",
|
||||
"version": "0.1.6",
|
||||
"description": "Typed builder for Dograh voice-AI workflows",
|
||||
"license": "BSD-2-Clause",
|
||||
"author": "Zansat Technologies Private Limited",
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue