diff --git a/README.md b/README.md index cd62044..369767c 100644 --- a/README.md +++ b/README.md @@ -132,6 +132,11 @@ Visit [https://www.dograh.com](https://www.dograh.com/) for our managed cloud of You can go to [https://docs.dograh.com](https://docs.dograh.com/) for our documentation. +## ๐Ÿ“ฆ SDKs + +- **Python SDK** โ€” [pypi.org/project/dograh-sdk](https://pypi.org/project/dograh-sdk/) +- **Node SDK** โ€” [npmjs.com/package/@dograh/sdk](https://www.npmjs.com/package/@dograh/sdk) + ## ๐ŸคCommunity & Support > ๐Ÿ‘‹ **Coming from the Better Stack video?** Drop your use case in our [pinned GitHub Discussion](https://github.com/orgs/dograh-hq/discussions/291) โ€” we read every reply and the founders personally onboard early adopters. diff --git a/api/routes/telephony.py b/api/routes/telephony.py index 7935ec9..86bbbc0 100644 --- a/api/routes/telephony.py +++ b/api/routes/telephony.py @@ -25,7 +25,7 @@ from api.enums import CallType, WorkflowRunState from api.errors.telephony_errors import TelephonyError from api.sdk_expose import sdk_expose from api.services.auth.depends import get_user -from api.services.quota_service import check_dograh_quota, check_dograh_quota_by_user_id +from api.services.quota_service import check_dograh_quota_by_user_id from api.services.telephony.call_transfer_manager import get_call_transfer_manager from api.services.telephony.factory import ( get_all_telephony_providers, @@ -60,6 +60,15 @@ class InitiateCallRequest(BaseModel): from_phone_number_id: int | None = None +def _get_execution_user_id(workflow) -> int: + if workflow.user_id is None: + raise HTTPException( + status_code=409, + detail="Workflow has no execution owner", + ) + return workflow.user_id + + @router.post( "/initiate-call", **sdk_expose( @@ -107,15 +116,6 @@ async def initiate_call( detail="telephony_not_configured", ) - # Check Dograh quota before initiating the call (apply per-workflow - # model_overrides so the keys we will actually use are the ones checked). - quota_result = await check_dograh_quota(user, workflow_id=request.workflow_id) - if not quota_result.has_quota: - raise HTTPException(status_code=402, detail=quota_result.error_message) - - # Determine the workflow run mode based on provider type - workflow_run_mode = provider.PROVIDER_NAME - phone_number = request.phone_number or user_configuration.test_phone_number if not phone_number: @@ -125,25 +125,38 @@ async def initiate_call( "configuration", ) + workflow = await db_client.get_workflow( + request.workflow_id, organization_id=user.selected_organization_id + ) + if not workflow: + raise HTTPException(status_code=404, detail="Workflow not found") + execution_user_id = _get_execution_user_id(workflow) + + # Check Dograh quota before initiating the call (apply per-workflow + # model_overrides so the keys we will actually use are the ones checked). + quota_result = await check_dograh_quota_by_user_id( + execution_user_id, workflow_id=workflow.id + ) + if not quota_result.has_quota: + raise HTTPException(status_code=402, detail=quota_result.error_message) + + # Determine the workflow run mode based on provider type + workflow_run_mode = provider.PROVIDER_NAME + workflow_run_id = request.workflow_run_id if not workflow_run_id: - # Fetch workflow to merge template context variables (e.g. caller_number, - # called_number set in workflow settings for testing pre-call data fetch) - workflow = await db_client.get_workflow( - request.workflow_id, organization_id=user.selected_organization_id - ) - if not workflow: - raise HTTPException(status_code=404, detail="Workflow not found") + # Merge template context variables (e.g. caller_number, called_number + # set in workflow settings for testing pre-call data fetch). template_vars = workflow.template_context_variables or {} numeric_suffix = int(str(uuid.uuid4()).replace("-", "")[:8], 16) % 100000000 workflow_run_name = f"WR-TEL-OUT-{numeric_suffix:08d}" workflow_run = await db_client.create_workflow_run( workflow_run_name, - request.workflow_id, + workflow.id, workflow_run_mode, - user_id=user.id, + user_id=execution_user_id, call_type=CallType.OUTBOUND, initial_context={ **template_vars, @@ -157,9 +170,16 @@ async def initiate_call( ) workflow_run_id = workflow_run.id else: - workflow_run = await db_client.get_workflow_run(workflow_run_id, user.id) + workflow_run = await db_client.get_workflow_run( + workflow_run_id, organization_id=user.selected_organization_id + ) if not workflow_run: raise HTTPException(status_code=400, detail="Workflow run not found") + if workflow_run.workflow_id != workflow.id: + raise HTTPException( + status_code=400, + detail="workflow_run_workflow_mismatch", + ) workflow_run_name = workflow_run.name # Construct webhook URL based on provider type @@ -169,13 +189,13 @@ async def initiate_call( webhook_url = ( f"{backend_endpoint}/api/v1/telephony/{webhook_endpoint}" - f"?workflow_id={request.workflow_id}" - f"&user_id={user.id}" + f"?workflow_id={workflow.id}" + f"&user_id={execution_user_id}" f"&workflow_run_id={workflow_run_id}" f"&organization_id={user.selected_organization_id}" ) - keywords = {"workflow_id": request.workflow_id, "user_id": user.id} + keywords = {"workflow_id": workflow.id, "user_id": execution_user_id} # Resolve optional caller-ID. The config has already been validated against # the user's organization, so filtering by config_id is sufficient for @@ -293,6 +313,7 @@ async def _detect_provider(webhook_data: dict, headers: dict): async def _validate_inbound_request( workflow_id: int, + webhook_url: str, provider_class, normalized_data, webhook_data: dict, @@ -364,8 +385,6 @@ async def _validate_inbound_request( # Verify webhook signature using the matched config's credentials. The # provider extracts its own signature/timestamp/nonce headers from the # dict, so this dispatcher stays generic. - backend_endpoint, _ = await get_backend_endpoints() - webhook_url = f"{backend_endpoint}/api/v1/telephony/inbound/{workflow_id}" provider_instance = await get_telephony_provider_by_id( telephony_configuration_id, organization_id ) @@ -701,13 +720,11 @@ async def handle_inbound_run(request: Request): user_id = workflow.user_id # 3. Verify webhook signature against the matched config's credentials. - backend_endpoint, wss_backend_endpoint = await get_backend_endpoints() - webhook_url = f"{backend_endpoint}/api/v1/telephony/inbound/run" provider_instance = await get_telephony_provider_by_id( telephony_configuration_id, config.organization_id ) signature_valid = await provider_instance.verify_inbound_signature( - webhook_url, webhook_data, headers, raw_body + str(request.url), webhook_data, headers, raw_body ) if not signature_valid: logger.warning( @@ -740,6 +757,7 @@ async def handle_inbound_run(request: Request): from_phone_number_id=phone_row.id, ) + backend_endpoint, wss_backend_endpoint = await get_backend_endpoints() websocket_url = ( f"{wss_backend_endpoint}/api/v1/telephony/ws/" f"{workflow_id}/{user_id}/{workflow_run_id}" @@ -840,6 +858,7 @@ async def handle_inbound_telephony( provider_instance, ) = await _validate_inbound_request( workflow_id, + str(request.url), provider_class, normalized_data, webhook_data, diff --git a/api/services/campaign/campaign_call_dispatcher.py b/api/services/campaign/campaign_call_dispatcher.py index ece8cac..e00ddb6 100644 --- a/api/services/campaign/campaign_call_dispatcher.py +++ b/api/services/campaign/campaign_call_dispatcher.py @@ -296,7 +296,6 @@ class CampaignCallDispatcher: f"?workflow_id={campaign.workflow_id}" f"&user_id={campaign.created_by}" f"&workflow_run_id={workflow_run.id}" - f"&campaign_id={campaign.id}" f"&organization_id={campaign.organization_id}" ) diff --git a/api/services/configuration/check_validity.py b/api/services/configuration/check_validity.py index d5a724c..721884b 100644 --- a/api/services/configuration/check_validity.py +++ b/api/services/configuration/check_validity.py @@ -50,6 +50,7 @@ class UserConfigurationValidator: ServiceProviders.GOOGLE_VERTEX.value: self._check_google_vertex_llm_api_key, ServiceProviders.OPENAI_REALTIME.value: self._check_openai_api_key, ServiceProviders.GROK_REALTIME.value: self._check_grok_realtime_api_key, + ServiceProviders.ULTRAVOX_REALTIME.value: self._check_ultravox_realtime_api_key, ServiceProviders.GOOGLE_REALTIME.value: self._check_google_api_key, ServiceProviders.GOOGLE_VERTEX_REALTIME.value: self._check_google_vertex_realtime_api_key, ServiceProviders.ASSEMBLYAI.value: self._check_assemblyai_api_key, @@ -255,6 +256,9 @@ class UserConfigurationValidator: def _check_grok_realtime_api_key(self, model: str, api_key: str) -> bool: return True + def _check_ultravox_realtime_api_key(self, model: str, api_key: str) -> bool: + return True + def _check_speechmatics_api_key(self, model: str, api_key: str) -> bool: return True diff --git a/api/services/configuration/masking.py b/api/services/configuration/masking.py index 3b904c6..f1ed1f6 100644 --- a/api/services/configuration/masking.py +++ b/api/services/configuration/masking.py @@ -38,7 +38,11 @@ def check_for_masked_keys(config: "UserConfiguration") -> None: for secret_field in SERVICE_SECRET_FIELDS: if not hasattr(service, secret_field): continue - if contains_masked_key(getattr(service, secret_field, None)): + if secret_field == "api_key" and hasattr(service, "get_all_api_keys"): + secret_value = service.get_all_api_keys() + else: + secret_value = getattr(service, secret_field, None) + if contains_masked_key(secret_value): raise ValueError( f"The {field} {secret_field} appears to be masked. " "Please provide the actual value, not the masked value." diff --git a/api/services/configuration/registry.py b/api/services/configuration/registry.py index 9ba16c3..e60db18 100644 --- a/api/services/configuration/registry.py +++ b/api/services/configuration/registry.py @@ -62,6 +62,7 @@ class ServiceProviders(str, Enum): GOOGLE_VERTEX = "google_vertex" OPENAI_REALTIME = "openai_realtime" GROK_REALTIME = "grok_realtime" + ULTRAVOX_REALTIME = "ultravox_realtime" GOOGLE_REALTIME = "google_realtime" GOOGLE_VERTEX_REALTIME = "google_vertex_realtime" @@ -85,6 +86,7 @@ class BaseServiceConfiguration(BaseModel): ServiceProviders.GOOGLE_VERTEX, ServiceProviders.OPENAI_REALTIME, ServiceProviders.GROK_REALTIME, + ServiceProviders.ULTRAVOX_REALTIME, ServiceProviders.GOOGLE_REALTIME, ServiceProviders.GOOGLE_VERTEX_REALTIME, # ServiceProviders.SARVAM, @@ -214,6 +216,7 @@ AWS_BEDROCK_PROVIDER_MODEL_CONFIG = provider_model_config("AWS Bedrock") GOOGLE_VERTEX_PROVIDER_MODEL_CONFIG = provider_model_config("Google Vertex") OPENAI_REALTIME_PROVIDER_MODEL_CONFIG = provider_model_config("OpenAI Realtime") GROK_REALTIME_PROVIDER_MODEL_CONFIG = provider_model_config("Grok Realtime") +ULTRAVOX_REALTIME_PROVIDER_MODEL_CONFIG = provider_model_config("Ultravox Realtime") GOOGLE_REALTIME_PROVIDER_MODEL_CONFIG = provider_model_config("Google Realtime") GOOGLE_VERTEX_REALTIME_PROVIDER_MODEL_CONFIG = provider_model_config( "Google Vertex Realtime" @@ -504,6 +507,7 @@ class OpenAIRealtimeLLMConfiguration(BaseLLMConfiguration): GROK_REALTIME_MODELS = ["grok-voice-think-fast-1.0"] GROK_REALTIME_VOICES = ["Ara", "Rex", "Sal", "Eve", "Leo"] +ULTRAVOX_REALTIME_MODELS = ["ultravox-v0.7", "fixie-ai/ultravox"] @register_service(ServiceType.REALTIME) @@ -528,6 +532,26 @@ class GrokRealtimeLLMConfiguration(BaseLLMConfiguration): ) +@register_service(ServiceType.REALTIME) +class UltravoxRealtimeLLMConfiguration(BaseLLMConfiguration): + model_config = ULTRAVOX_REALTIME_PROVIDER_MODEL_CONFIG + provider: Literal[ServiceProviders.ULTRAVOX_REALTIME] = ( + ServiceProviders.ULTRAVOX_REALTIME + ) + model: str = Field( + default="ultravox-v0.7", + description="Ultravox realtime voice-agent model.", + json_schema_extra={ + "examples": ULTRAVOX_REALTIME_MODELS, + "allow_custom_input": True, + }, + ) + voice: str = Field( + default="Mark", + description="Ultravox voice name or voice ID.", + ) + + @register_service(ServiceType.REALTIME) class GoogleRealtimeLLMConfiguration(BaseLLMConfiguration): model_config = GOOGLE_REALTIME_PROVIDER_MODEL_CONFIG @@ -615,6 +639,7 @@ class GoogleVertexRealtimeLLMConfiguration(BaseLLMConfiguration): REALTIME_PROVIDERS = { ServiceProviders.OPENAI_REALTIME.value, ServiceProviders.GROK_REALTIME.value, + ServiceProviders.ULTRAVOX_REALTIME.value, ServiceProviders.GOOGLE_REALTIME.value, ServiceProviders.GOOGLE_VERTEX_REALTIME.value, } @@ -640,6 +665,7 @@ RealtimeConfig = Annotated[ Union[ OpenAIRealtimeLLMConfiguration, GrokRealtimeLLMConfiguration, + UltravoxRealtimeLLMConfiguration, GoogleRealtimeLLMConfiguration, GoogleVertexRealtimeLLMConfiguration, ], diff --git a/api/services/pipecat/realtime/ultravox_realtime.py b/api/services/pipecat/realtime/ultravox_realtime.py new file mode 100644 index 0000000..a666bc9 --- /dev/null +++ b/api/services/pipecat/realtime/ultravox_realtime.py @@ -0,0 +1,653 @@ +"""Dograh subclass of pipecat's Ultravox realtime LLM service. + +Ultravox is audio-native and realtime, but prompt and tool configuration is +bound to call creation. Dograh therefore cannot lean on in-session updates or +Gemini-style session resumption handles. This wrapper adapts Ultravox to the +Dograh engine contract by: + +- deferring the first call creation until the engine queues the initial node + opening via ``TTSSpeakFrame`` or ``LLMContextFrame`` +- marking the call for recreation when ``system_instruction`` changes across + node transitions, then rebuilding it on the follow-up ``LLMContextFrame`` + so the transition tool result is present in ``initialMessages`` +- reconstructing Ultravox ``initialMessages`` from Dograh context when the + call must be recreated after a node transition +- appending a transient resumptive user nudge to recreated ``initialMessages`` + after tool-result transitions, without mutating Dograh's stored context +- handling Dograh-only frames such as user mute and idle append prompts +- tagging user transcripts with ``finalized=True`` for downstream parity +""" + +import hashlib +import json +from typing import Any + +from loguru import logger +from pydantic import Field +from websockets.exceptions import ConnectionClosed + +from pipecat.frames.frames import ( + Frame, + LLMMessagesAppendFrame, + TranscriptionFrame, + TTSSpeakFrame, + UserMuteStartedFrame, + UserMuteStoppedFrame, +) +from pipecat.processors.aggregators import async_tool_messages +from pipecat.processors.aggregators.llm_context import ( + LLMContext, + LLMSpecificMessage, + is_given, +) +from pipecat.processors.frame_processor import FrameDirection +from pipecat.services.llm_service import LLMService +from pipecat.services.settings import _NotGiven, assert_given +from pipecat.services.ultravox.llm import ( + OneShotInputParams, + UltravoxRealtimeLLMService, + websocket_client, +) +from pipecat.utils.time import time_now_iso8601 + + +class DograhUltravoxOneShotInputParams(OneShotInputParams): + """Dograh-friendly OneShot params with string voice support.""" + + voice: str | None = Field(default=None) + + +_ULTRAVOX_MAX_TOOL_TIMEOUT_SECS = 40.0 +_RESUMPTION_USER_MESSAGE = ( + "IMPORTANT: We are resuming an existing conversation. You are given previous turns ONLY for your reference. " + "Do not use that to frame your response. Follow your ORIGINAL INSTRUCTIONS ONLY." +) + + +class DograhUltravoxRealtimeLLMService(UltravoxRealtimeLLMService): + """Ultravox realtime with Dograh engine integration quirks.""" + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self._context: LLMContext | None = None + self._selected_tools = None + self._user_is_muted: bool = False + self._call_system_instruction: str | None = None + self._reconnect_required: bool = False + self._call_started: bool = False + self._has_connected_once: bool = False + self._pending_reconnect_system_instruction: str | None = None + self._pending_initial_messages: list[dict[str, Any]] | None = None + self._pending_user_text_messages: list[str] = [] + + async def start(self, frame): + # Dograh defers call creation until the engine queues the node opening. + await LLMService.start(self, frame) + + async def process_frame(self, frame: Frame, direction: FrameDirection): + if isinstance(frame, UserMuteStartedFrame): + self._user_is_muted = True + await self.push_frame(frame, direction) + return + if isinstance(frame, UserMuteStoppedFrame): + self._user_is_muted = False + await self.push_frame(frame, direction) + return + if isinstance(frame, TTSSpeakFrame): + if not self._socket: + await self._connect_call( + system_instruction=self._current_system_instruction(), + greeting_text=frame.text, + initial_messages=None, + agent_speaks_first=True, + ) + else: + logger.warning( + f"{self}: TTSSpeakFrame received after the Ultravox call was " + "already created; ignoring because Ultravox owns speech output" + ) + return + if isinstance(frame, LLMMessagesAppendFrame): + await self._handle_messages_append(frame) + return + await super().process_frame(frame, direction) + + async def _update_settings(self, delta: UltravoxRealtimeLLMService.Settings): + changed = await super(UltravoxRealtimeLLMService, self)._update_settings(delta) + if "output_medium" in changed: + await self._update_output_medium(assert_given(self._settings.output_medium)) + if "system_instruction" in changed and self._has_connected_once: + # Mirror Gemini's "settings change means reconnect" intent, but + # defer the actual new-call creation until the subsequent + # LLMContextFrame arrives with the transition tool result. Ultravox + # cannot accept that historical tool result over a formal + # post-connect tool-response channel the way Gemini can. + self._reconnect_required = True + handled = {"output_medium", "system_instruction"} + self._warn_unhandled_updated_settings(changed.keys() - handled) + return changed + + async def _disconnect(self, preserve_completed_tool_calls: bool = True): + self._disconnecting = True + await self.stop_all_metrics() + if self._socket: + await self._socket.close() + self._socket = None + if self._receive_task: + await self.cancel_task(self._receive_task, timeout=1.0) + self._receive_task = None + if not preserve_completed_tool_calls: + self._completed_tool_calls = set() + self._call_started = False + self._started_placeholder_sent = set() + self._disconnecting = False + + async def _send_user_audio(self, frame): + if self._user_is_muted: + return + await super()._send_user_audio(frame) + + async def _handle_context(self, context: LLMContext): + self._context = context + system_instruction = self._current_system_instruction() + + if self._socket and not self._reconnect_required: + await super()._handle_context(context) + return + + initial_messages, history_tool_call_ids = self._build_initial_messages(context) + if history_tool_call_ids: + self._completed_tool_calls.update(history_tool_call_ids) + + if self._bot_responding: + self._pending_reconnect_system_instruction = system_instruction + self._pending_initial_messages = initial_messages + return + + await self._reconnect_with_context( + system_instruction=system_instruction, + initial_messages=initial_messages, + ) + + async def _handle_response_end(self): + await super()._handle_response_end() + if self._pending_reconnect_system_instruction is None: + return + + system_instruction = self._pending_reconnect_system_instruction + initial_messages = self._pending_initial_messages + self._pending_reconnect_system_instruction = None + self._pending_initial_messages = None + await self._reconnect_with_context( + system_instruction=system_instruction, + initial_messages=initial_messages, + ) + + async def _handle_messages_append(self, frame: LLMMessagesAppendFrame): + texts = [ + text + for text in ( + self._extract_text_content(message.get("content")) + for message in frame.messages + if isinstance(message, dict) + ) + if text + ] + if not texts: + return + + if not self._socket: + self._pending_user_text_messages.extend(texts) + await self._connect_call( + system_instruction=self._current_system_instruction(), + greeting_text=None, + initial_messages=None, + agent_speaks_first=False, + ) + return + + if not self._call_started: + self._pending_user_text_messages.extend(texts) + logger.debug( + f"{self}: queueing {len(texts)} user text message(s) until call_started" + ) + return + + for text in texts: + await self._send_user_text(text) + + async def _handle_user_transcript(self, text: str): + transcript = text.strip() if text else "" + if not transcript: + return + await self.broadcast_frame( + TranscriptionFrame, + user_id=self._last_user_id or "", + timestamp=time_now_iso8601(), + result=text, + text=transcript, + finalized=True, + ) + + async def _connect_call( + self, + *, + system_instruction: str | None, + greeting_text: str | None, + initial_messages: list[dict[str, Any]] | None, + agent_speaks_first: bool, + ): + params = self._build_one_shot_params( + greeting_text=greeting_text, + initial_messages=initial_messages, + agent_speaks_first=agent_speaks_first, + ) + self._params = params + self._selected_tools = self._current_tools_schema(self._context) + tool_names = ( + [tool.name for tool in self._selected_tools.standard_tools] + if self._selected_tools + else [] + ) + prompt = params.system_prompt or "" + prompt_hash = hashlib.sha256(prompt.encode("utf-8")).hexdigest()[:12] + + try: + logger.info( + f"{self}: creating Ultravox call " + f"(agent_speaks_first={agent_speaks_first}, " + f"voice={params.voice!r}, " + f"tools={tool_names}, " + f"system_prompt_len={len(prompt)}, " + f"system_prompt_sha256={prompt_hash})" + ) + join_url = await self._start_one_shot_call(params) + logger.info(f"Joining Ultravox Realtime call via URL: {join_url}") + self._socket = await websocket_client.connect(join_url) + self._receive_task = self.create_task(self._receive_messages()) + self._call_system_instruction = system_instruction + self._call_started = False + self._has_connected_once = True + except Exception as e: + logger.error( + f"{self}: Ultravox call creation/join failed " + f"for tools={tool_names}: {e}" + ) + await self.push_error(f"Failed to connect to Ultravox: {e}", e, fatal=True) + + async def _receive_messages(self): + """Receive messages from the Ultravox Realtime WebSocket. + + Upstream handles exceptions raised while processing individual messages, + but websocket close exceptions are raised by the async iterator itself. + During user hangup / pipeline teardown that close is expected, so treat + normal websocket shutdown as a debug condition rather than a pipeline + error. + """ + if not self._socket: + return + + try: + async for message in self._socket: + try: + if isinstance(message, bytes): + await self._handle_audio(message) + continue + + data = json.loads(message) + match data.get("type"): + case "call_started": + self._call_started = True + logger.debug( + f"{self}: Ultravox call_started received for callId=" + f"{data.get('callId')}" + ) + await self._flush_pending_user_text_messages() + case "state": + if self._bot_responding and data.get("state") != "speaking": + await self._handle_response_end() + case "client_tool_invocation": + await self._handle_tool_invocation( + data.get("toolName"), + data.get("invocationId"), + data.get("parameters"), + ) + case "transcript": + match data.get("role"): + case "user": + if not data.get("final"): + logger.warning( + "Unexpected non-final user transcript from Ultravox Realtime; ignoring." + ) + else: + await self._handle_user_transcript( + data.get("text") + ) + case "agent": + await self._handle_agent_transcript( + data.get("medium"), + data.get("text"), + data.get("delta"), + data.get("final", False), + ) + case _: + logger.debug( + f"Received transcript with unknown role from Ultravox Realtime: {data}" + ) + case _: + logger.debug(f"Received unhandled Ultravox message: {data}") + except Exception as e: + if self._disconnecting or not self._socket: + return + await self.push_error( + "Ultravox websocket receive error", e, fatal=True + ) + except ConnectionClosed as e: + if ( + self._disconnecting + or not self._socket + or self._is_benign_websocket_close(e) + ): + logger.debug(f"{self}: Ultravox websocket closed: {e}") + return + await self.push_error("Ultravox websocket receive error", e, fatal=True) + + async def _flush_pending_user_text_messages(self): + if ( + not self._socket + or not self._call_started + or not self._pending_user_text_messages + ): + return + + pending_texts = self._pending_user_text_messages + self._pending_user_text_messages = [] + for pending_text in pending_texts: + await self._send_user_text(pending_text) + + async def _reconnect_with_context( + self, + *, + system_instruction: str | None, + initial_messages: list[dict[str, Any]] | None, + ): + call_initial_messages = self._initial_messages_for_call(initial_messages) + logger.debug( + f"{self}: reconnecting Ultravox call with initialMessages=" + f"{json.dumps(call_initial_messages, ensure_ascii=True, default=str)}" + ) + if self._socket: + await self._disconnect(preserve_completed_tool_calls=True) + + await self._connect_call( + system_instruction=system_instruction, + greeting_text=None, + initial_messages=initial_messages, + agent_speaks_first=self._should_agent_speak_first(initial_messages), + ) + self._reconnect_required = False + + def _build_one_shot_params( + self, + *, + greeting_text: str | None, + initial_messages: list[dict[str, Any]] | None, + agent_speaks_first: bool, + ) -> DograhUltravoxOneShotInputParams: + current_params = self._params + extra = { + key: value + for key, value in current_params.extra.items() + if key not in {"firstSpeakerSettings", "initialMessages"} + } + + if greeting_text is not None: + extra["firstSpeakerSettings"] = {"agent": {"text": greeting_text}} + elif agent_speaks_first: + extra["firstSpeakerSettings"] = {"agent": {}} + else: + extra["firstSpeakerSettings"] = {"user": {}} + call_initial_messages = self._initial_messages_for_call(initial_messages) + if call_initial_messages: + extra["initialMessages"] = call_initial_messages + + output_medium = self._settings.output_medium + if isinstance(output_medium, _NotGiven): + output_medium = current_params.output_medium + + return DograhUltravoxOneShotInputParams( + api_key=current_params.api_key, + system_prompt=self._current_system_instruction(), + temperature=current_params.temperature, + model=assert_given(self._settings.model), + voice=current_params.voice, + metadata=current_params.metadata, + output_medium=output_medium, + max_duration=current_params.max_duration, + extra=extra, + ) + + def _current_tools_schema(self, context: LLMContext | None): + if context is None or not is_given(context.tools): + return None + return context.tools + + def _to_selected_tools(self, tool: Any) -> list[dict[str, Any]]: + selected_tools = super()._to_selected_tools(tool) + for selected_tool in selected_tools: + temporary_tool = selected_tool.get("temporaryTool") + if not isinstance(temporary_tool, dict): + continue + + tool_name = temporary_tool.get("modelToolName") + if not isinstance(tool_name, str): + continue + + timeout = self._ultravox_timeout_for_tool(tool_name) + if timeout is not None: + temporary_tool["timeout"] = timeout + return selected_tools + + def _current_system_instruction(self) -> str | None: + system_instruction = self._settings.system_instruction + if isinstance(system_instruction, _NotGiven): + return None + return system_instruction + + def _ultravox_timeout_for_tool(self, function_name: str) -> str | None: + item = self._functions.get(function_name) or self._functions.get(None) + if item is None or item.timeout_secs is None or item.timeout_secs <= 0: + return None + + timeout_secs = min(float(item.timeout_secs), _ULTRAVOX_MAX_TOOL_TIMEOUT_SECS) + return f"{timeout_secs:g}s" + + def _initial_messages_for_call( + self, initial_messages: list[dict[str, Any]] | None + ) -> list[dict[str, Any]] | None: + if not initial_messages: + return None + if not self._should_add_resumption_user_message(initial_messages): + return initial_messages + + return [ + *initial_messages, + { + "role": "MESSAGE_ROLE_USER", + "text": _RESUMPTION_USER_MESSAGE, + }, + ] + + def _build_initial_messages( + self, context: LLMContext + ) -> tuple[list[dict[str, Any]] | None, set[str]]: + initial_messages: list[dict[str, Any]] = [] + tool_call_id_to_name: dict[str, str] = {} + completed_tool_call_ids: set[str] = set() + + for message in context.get_messages(): + if isinstance(message, LLMSpecificMessage): + continue + + async_payload = async_tool_messages.parse_message(message) + if async_payload is not None: + if async_payload.kind == "intermediate": + logger.error( + f"{self}: Ultravox does not support streamed async tool results; " + f"dropping intermediate result from initialMessages for " + f"tool_call_id={async_payload.tool_call_id}." + ) + continue + if async_payload.kind == "final": + initial_message = self._build_ultravox_message( + role="MESSAGE_ROLE_TOOL_RESULT", + text=async_payload.result or "", + invocation_id=async_payload.tool_call_id, + tool_name=tool_call_id_to_name.get(async_payload.tool_call_id), + ) + if initial_message is not None: + initial_messages.append(initial_message) + completed_tool_call_ids.add(async_payload.tool_call_id) + continue + + role = message.get("role") + if role == "user": + initial_message = self._build_ultravox_message( + role="MESSAGE_ROLE_USER", + text=self._extract_text_content(message.get("content")), + ) + if initial_message is not None: + initial_messages.append(initial_message) + elif role == "assistant": + text = self._extract_text_content(message.get("content")) + initial_message = self._build_ultravox_message( + role="MESSAGE_ROLE_AGENT", + text=text, + ) + if initial_message is not None: + initial_messages.append(initial_message) + + tool_calls = message.get("tool_calls") + if isinstance(tool_calls, list): + for tool_call in tool_calls: + if not isinstance(tool_call, dict): + continue + tool_id = tool_call.get("id") + function = tool_call.get("function") + tool_name = ( + function.get("name") if isinstance(function, dict) else None + ) + if isinstance(tool_id, str) and isinstance(tool_name, str): + tool_call_id_to_name[tool_id] = tool_name + initial_message = self._build_ultravox_message( + role="MESSAGE_ROLE_TOOL_CALL", + text="", + invocation_id=tool_id, + tool_name=tool_name, + ) + if initial_message is not None: + initial_messages.append(initial_message) + elif ( + role == "tool" + and message.get("content") != "IN_PROGRESS" + and message.get("content") != "CANCELLED" + ): + tool_call_id = message.get("tool_call_id") + initial_message = self._build_ultravox_message( + role="MESSAGE_ROLE_TOOL_RESULT", + text=self._stringify_tool_result(message.get("content")), + invocation_id=tool_call_id + if isinstance(tool_call_id, str) + else None, + tool_name=( + tool_call_id_to_name.get(tool_call_id) + if isinstance(tool_call_id, str) + else None + ), + ) + if initial_message is not None: + initial_messages.append(initial_message) + if isinstance(tool_call_id, str): + completed_tool_call_ids.add(tool_call_id) + + return (initial_messages or None), completed_tool_call_ids + + @staticmethod + def _build_ultravox_message( + *, + role: str, + text: str | None, + invocation_id: str | None = None, + tool_name: str | None = None, + ) -> dict[str, Any] | None: + if text is None: + return None + + message: dict[str, Any] = { + "role": role, + "text": text, + } + if invocation_id is not None: + message["invocationId"] = invocation_id + if tool_name is not None: + message["toolName"] = tool_name + return message + + @staticmethod + def _should_agent_speak_first( + initial_messages: list[dict[str, Any]] | None, + ) -> bool: + if not initial_messages: + return True + return initial_messages[-1].get("role") in { + "MESSAGE_ROLE_USER", + "MESSAGE_ROLE_TOOL_RESULT", + } + + @staticmethod + def _should_add_resumption_user_message( + initial_messages: list[dict[str, Any]] | None, + ) -> bool: + if not initial_messages: + return False + return initial_messages[-1].get("role") == "MESSAGE_ROLE_TOOL_RESULT" + + @staticmethod + def _is_benign_websocket_close(exc: ConnectionClosed) -> bool: + return any( + close is not None and close.code in {1000, 1001} + for close in (exc.sent, exc.rcvd) + ) + + @staticmethod + def _extract_text_content(content: Any) -> str | None: + if isinstance(content, str): + return content + if isinstance(content, list): + parts: list[str] = [] + for part in content: + if not isinstance(part, dict): + return None + if part.get("type") != "text": + return None + text = part.get("text") + if not isinstance(text, str): + return None + parts.append(text) + return "\n".join(parts) if parts else None + return None + + @staticmethod + def _stringify_tool_result(content: Any) -> str: + if isinstance(content, str): + return content + if isinstance(content, list): + parts: list[str] = [] + for part in content: + if isinstance(part, dict): + text = part.get("text") + if isinstance(text, str): + parts.append(text) + if parts: + return "".join(parts) + return json.dumps(content, ensure_ascii=True, default=str) diff --git a/api/services/pipecat/service_factory.py b/api/services/pipecat/service_factory.py index 5ef61cd..ad5c357 100644 --- a/api/services/pipecat/service_factory.py +++ b/api/services/pipecat/service_factory.py @@ -640,6 +640,24 @@ def create_realtime_llm_service(user_config, audio_config: "AudioConfig"): ), ), ) + elif provider == ServiceProviders.ULTRAVOX_REALTIME.value: + from api.services.pipecat.realtime.ultravox_realtime import ( + DograhUltravoxOneShotInputParams, + DograhUltravoxRealtimeLLMService, + ) + + return DograhUltravoxRealtimeLLMService( + params=DograhUltravoxOneShotInputParams( + api_key=api_key, + model=model, + voice=voice, + output_medium="voice", + ), + settings=DograhUltravoxRealtimeLLMService.Settings( + model=model, + output_medium="voice", + ), + ) elif provider == ServiceProviders.GOOGLE_REALTIME.value: from api.services.pipecat.realtime.gemini_live import ( DograhGeminiLiveLLMService, diff --git a/api/services/telephony/providers/plivo/routes.py b/api/services/telephony/providers/plivo/routes.py index be1ecd7..ff64b37 100644 --- a/api/services/telephony/providers/plivo/routes.py +++ b/api/services/telephony/providers/plivo/routes.py @@ -5,9 +5,8 @@ provider registry โ€” see ProviderSpec.router. """ import json -from typing import Optional -from fastapi import APIRouter, Header, Request +from fastapi import APIRouter, Request from loguru import logger from pipecat.utils.run_context import set_current_run_id from starlette.responses import HTMLResponse @@ -18,7 +17,6 @@ from api.services.telephony.status_processor import ( StatusCallbackRequest, _process_status_update, ) -from api.utils.common import get_backend_endpoints router = APIRouter() @@ -26,9 +24,6 @@ router = APIRouter() async def _handle_plivo_status_callback( workflow_run_id: int, request: Request, - x_plivo_signature_v3: Optional[str], - x_plivo_signature_ma_v3: Optional[str], - x_plivo_signature_v3_nonce: Optional[str], ): set_current_run_id(workflow_run_id) @@ -52,19 +47,14 @@ async def _handle_plivo_status_callback( workflow_run, workflow.organization_id ) - signature = x_plivo_signature_v3 or x_plivo_signature_ma_v3 - if signature: - backend_endpoint, _ = await get_backend_endpoints() - callback_kind = request.url.path.split("/")[-2] - full_url = f"{backend_endpoint}/api/v1/telephony/plivo/{callback_kind}/{workflow_run_id}" - is_valid = await provider.verify_inbound_signature( - full_url, - callback_data, - dict(request.headers), - ) - if not is_valid: - logger.warning(f"[run {workflow_run_id}] Invalid Plivo webhook signature") - return {"status": "error", "reason": "invalid_signature"} + is_valid = await provider.verify_inbound_signature( + str(request.url), + callback_data, + dict(request.headers), + ) + if not is_valid: + logger.warning(f"[run {workflow_run_id}] Invalid Plivo webhook signature") + return {"status": "error", "reason": "invalid_signature"} parsed_data = provider.parse_status_callback(callback_data) status_update = StatusCallbackRequest( @@ -88,9 +78,6 @@ async def handle_plivo_xml_webhook( workflow_run_id: int, organization_id: int, request: Request, - x_plivo_signature_v3: Optional[str] = Header(None), - x_plivo_signature_ma_v3: Optional[str] = Header(None), - x_plivo_signature_v3_nonce: Optional[str] = Header(None), ): """ Handle initial webhook from Plivo when an outbound call is answered. @@ -103,26 +90,16 @@ async def handle_plivo_xml_webhook( form_data = await request.form() callback_data = dict(form_data) - signature = x_plivo_signature_v3 or x_plivo_signature_ma_v3 - if signature: - backend_endpoint, _ = await get_backend_endpoints() - full_url = ( - f"{backend_endpoint}/api/v1/telephony/plivo-xml" - f"?workflow_id={workflow_id}" - f"&user_id={user_id}" - f"&workflow_run_id={workflow_run_id}" - f"&organization_id={organization_id}" + is_valid = await provider.verify_inbound_signature( + str(request.url), callback_data, dict(request.headers) + ) + if not is_valid: + logger.warning( + f"[run {workflow_run_id}] Invalid Plivo signature on answer webhook" ) - is_valid = await provider.verify_inbound_signature( - full_url, callback_data, dict(request.headers) + return provider.generate_error_response( + "invalid_signature", "Invalid webhook signature." ) - if not is_valid: - logger.warning( - f"[run {workflow_run_id}] Invalid Plivo signature on answer webhook" - ) - return provider.generate_error_response( - "invalid_signature", "Invalid webhook signature." - ) call_id = callback_data.get("CallUUID") or callback_data.get("RequestUUID") if call_id: @@ -142,33 +119,15 @@ async def handle_plivo_xml_webhook( async def handle_plivo_hangup_callback( workflow_run_id: int, request: Request, - x_plivo_signature_v3: Optional[str] = Header(None), - x_plivo_signature_ma_v3: Optional[str] = Header(None), - x_plivo_signature_v3_nonce: Optional[str] = Header(None), ): """Handle Plivo hangup callbacks.""" - return await _handle_plivo_status_callback( - workflow_run_id, - request, - x_plivo_signature_v3, - x_plivo_signature_ma_v3, - x_plivo_signature_v3_nonce, - ) + return await _handle_plivo_status_callback(workflow_run_id, request) @router.post("/plivo/ring-callback/{workflow_run_id}") async def handle_plivo_ring_callback( workflow_run_id: int, request: Request, - x_plivo_signature_v3: Optional[str] = Header(None), - x_plivo_signature_ma_v3: Optional[str] = Header(None), - x_plivo_signature_v3_nonce: Optional[str] = Header(None), ): """Handle Plivo ring callbacks.""" - return await _handle_plivo_status_callback( - workflow_run_id, - request, - x_plivo_signature_v3, - x_plivo_signature_ma_v3, - x_plivo_signature_v3_nonce, - ) + return await _handle_plivo_status_callback(workflow_run_id, request) diff --git a/api/services/telephony/providers/twilio/routes.py b/api/services/telephony/providers/twilio/routes.py index e8ac939..c779617 100644 --- a/api/services/telephony/providers/twilio/routes.py +++ b/api/services/telephony/providers/twilio/routes.py @@ -5,9 +5,8 @@ provider registry โ€” see ProviderSpec.router. """ import json -from typing import Optional -from fastapi import APIRouter, Header, Request +from fastapi import APIRouter, HTTPException, Request from loguru import logger from pipecat.utils.run_context import set_current_run_id from starlette.responses import HTMLResponse @@ -18,14 +17,17 @@ from api.services.telephony.status_processor import ( StatusCallbackRequest, _process_status_update, ) -from api.utils.common import get_backend_endpoints router = APIRouter() @router.post("/twiml", include_in_schema=False) async def handle_twiml_webhook( - workflow_id: int, user_id: int, workflow_run_id: int, organization_id: int + workflow_id: int, + user_id: int, + workflow_run_id: int, + organization_id: int, + request: Request, ): """ Handle initial webhook from telephony provider. @@ -34,6 +36,18 @@ async def handle_twiml_webhook( workflow_run = await db_client.get_workflow_run_by_id(workflow_run_id) provider = await get_telephony_provider_for_run(workflow_run, organization_id) + callback_data = dict(await request.form()) + + is_valid = await provider.verify_inbound_signature( + str(request.url), + callback_data, + dict(request.headers), + ) + if not is_valid: + logger.warning( + f"[run {workflow_run_id}] Invalid Twilio signature on answer webhook" + ) + raise HTTPException(status_code=401, detail="Invalid webhook signature") response_content = await provider.get_webhook_response( workflow_id, user_id, workflow_run_id @@ -46,7 +60,6 @@ async def handle_twiml_webhook( async def handle_twilio_status_callback( workflow_run_id: int, request: Request, - x_webhook_signature: Optional[str] = Header(None), ): """Handle Twilio-specific status callbacks.""" set_current_run_id(workflow_run_id) @@ -75,19 +88,14 @@ async def handle_twilio_status_callback( workflow_run, workflow.organization_id ) - if x_webhook_signature: - backend_endpoint, _ = await get_backend_endpoints() - full_url = f"{backend_endpoint}/api/v1/telephony/twilio/status-callback/{workflow_run_id}" - - is_valid = await provider.verify_webhook_signature( - full_url, callback_data, x_webhook_signature - ) - - if not is_valid: - logger.warning( - f"Invalid webhook signature for workflow run {workflow_run_id}" - ) - return {"status": "error", "reason": "invalid_signature"} + is_valid = await provider.verify_inbound_signature( + str(request.url), + callback_data, + dict(request.headers), + ) + if not is_valid: + logger.warning(f"Invalid webhook signature for workflow run {workflow_run_id}") + raise HTTPException(status_code=401, detail="Invalid webhook signature") # Parse the callback data into generic format parsed_data = provider.parse_status_callback(callback_data) diff --git a/api/services/telephony/providers/vobiz/routes.py b/api/services/telephony/providers/vobiz/routes.py index 4fffe5b..3c13e4b 100644 --- a/api/services/telephony/providers/vobiz/routes.py +++ b/api/services/telephony/providers/vobiz/routes.py @@ -81,9 +81,9 @@ async def handle_vobiz_hangup_callback( f"[run {workflow_run_id}] Vobiz hangup callback - Headers: {json.dumps(all_headers)}" ) - # Parse the callback data (Vobiz sends form data or JSON) - form_data = await request.form() - callback_data = dict(form_data) + # Parse the callback data from the raw body so signed webhooks can verify + # the exact bytes Vobiz sent without draining the request stream first. + callback_data, raw_body = await parse_webhook_request(request) # TODO: Remove this debug logging after Vobiz team clarifies webhook authentication logger.info( @@ -114,10 +114,6 @@ async def handle_vobiz_hangup_callback( workflow_run, workflow.organization_id ) - # Get raw body for signature verification - raw_body = await request.body() - webhook_body = raw_body.decode("utf-8") - # Verify signature backend_endpoint, _ = await get_backend_endpoints() webhook_url = f"{backend_endpoint}/api/v1/telephony/vobiz/hangup-callback/{workflow_run_id}" @@ -127,7 +123,7 @@ async def handle_vobiz_hangup_callback( callback_data, x_vobiz_signature, x_vobiz_timestamp, - webhook_body, + raw_body, ) if not is_valid: @@ -206,9 +202,9 @@ async def handle_vobiz_ring_callback( f"[run {workflow_run_id}] Vobiz ring callback - Headers: {json.dumps(all_headers)}" ) - # Parse the callback data - form_data = await request.form() - callback_data = dict(form_data) + # Parse the callback data from the raw body so signed webhooks can verify + # the exact bytes Vobiz sent without draining the request stream first. + callback_data, raw_body = await parse_webhook_request(request) # TODO: Remove this debug logging after Vobiz team clarifies webhook authentication logger.info( @@ -240,10 +236,6 @@ async def handle_vobiz_ring_callback( workflow_run, workflow.organization_id ) - # Get raw body for signature verification - raw_body = await request.body() - webhook_body = raw_body.decode("utf-8") - # Verify signature backend_endpoint, _ = await get_backend_endpoints() webhook_url = ( @@ -255,7 +247,7 @@ async def handle_vobiz_ring_callback( callback_data, x_vobiz_signature, x_vobiz_timestamp, - webhook_body, + raw_body, ) if not is_valid: @@ -311,9 +303,10 @@ async def handle_vobiz_hangup_callback_by_workflow( ) try: - callback_data, _ = await parse_webhook_request(request) + callback_data, raw_body = await parse_webhook_request(request) except ValueError: callback_data = {} + raw_body = "" call_uuid = callback_data.get("CallUUID") or callback_data.get("call_uuid") logger.info( @@ -356,8 +349,6 @@ async def handle_vobiz_hangup_callback_by_workflow( ) if x_vobiz_signature: - raw_body = await request.body() - webhook_body = raw_body.decode("utf-8") backend_endpoint, _ = await get_backend_endpoints() webhook_url = f"{backend_endpoint}/api/v1/telephony/vobiz/hangup-callback/workflow/{workflow_id}" @@ -366,7 +357,7 @@ async def handle_vobiz_hangup_callback_by_workflow( callback_data, x_vobiz_signature, x_vobiz_timestamp, - webhook_body, + raw_body, ) if not is_valid: diff --git a/api/services/workflow/pipecat_engine_custom_tools.py b/api/services/workflow/pipecat_engine_custom_tools.py index 3d070b7..25298d7 100644 --- a/api/services/workflow/pipecat_engine_custom_tools.py +++ b/api/services/workflow/pipecat_engine_custom_tools.py @@ -297,6 +297,10 @@ class CustomToolManager: timeout_secs = 120.0 handler = self._create_transfer_call_handler(tool, function_name) else: + timeout_ms = ((tool.definition or {}).get("config", {}) or {}).get( + "timeout_ms", 5000 + ) + timeout_secs = float(timeout_ms) / 1000 handler = self._create_http_tool_handler(tool, function_name) return handler, timeout_secs diff --git a/api/tests/telephony/plivo/test_routes.py b/api/tests/telephony/plivo/test_routes.py new file mode 100644 index 0000000..e3a2b06 --- /dev/null +++ b/api/tests/telephony/plivo/test_routes.py @@ -0,0 +1,185 @@ +import base64 +import hashlib +import hmac +from types import SimpleNamespace +from unittest.mock import AsyncMock, patch +from urllib.parse import urlencode + +import pytest +from starlette.requests import Request + +from api.services.telephony.providers.plivo.provider import PlivoProvider +from api.services.telephony.providers.plivo.routes import ( + handle_plivo_hangup_callback, + handle_plivo_xml_webhook, +) + + +def _provider() -> PlivoProvider: + return PlivoProvider( + { + "auth_id": "MA123", + "auth_token": "plivo-auth-token", + "from_numbers": ["+15551230002"], + } + ) + + +def _request( + *, + path: str, + query: dict[str, str | int], + form_data: dict[str, str], + headers: dict[str, str] | None = None, +) -> Request: + body = urlencode(form_data).encode("utf-8") + query_string = urlencode(query).encode("utf-8") + request_headers = [ + (b"content-type", b"application/x-www-form-urlencoded"), + *[ + (name.lower().encode("ascii"), value.encode("ascii")) + for name, value in (headers or {}).items() + ], + ] + + async def receive(): + return { + "type": "http.request", + "body": body, + "more_body": False, + } + + return Request( + { + "type": "http", + "method": "POST", + "scheme": "https", + "server": ("example.test", 443), + "path": path, + "query_string": query_string, + "headers": request_headers, + }, + receive, + ) + + +def _signature( + provider: PlivoProvider, + *, + path: str, + query: dict[str, str | int], + form_data: dict[str, str], + nonce: str, +) -> str: + url = f"https://example.test{path}" + if query: + url = f"{url}?{urlencode(query)}" + payload = f"{provider._construct_post_url(url, form_data)}.{nonce}" + return base64.b64encode( + hmac.new( + provider.auth_token.encode("utf-8"), + payload.encode("utf-8"), + hashlib.sha256, + ).digest() + ).decode("utf-8") + + +@pytest.mark.asyncio +async def test_plivo_xml_route_accepts_valid_signature_with_extra_query_param(): + provider = _provider() + query = { + "workflow_id": 7, + "user_id": 8, + "workflow_run_id": 123, + "campaign_id": 42, + "organization_id": 11, + } + form_data = { + "CallUUID": "call-123", + "Direction": "outbound", + "From": "15551230001", + "To": "15551230002", + } + nonce = "nonce-123" + request = _request( + path="/api/v1/telephony/plivo-xml", + query=query, + form_data=form_data, + headers={ + "x-plivo-signature-v3": _signature( + provider, + path="/api/v1/telephony/plivo-xml", + query=query, + form_data=form_data, + nonce=nonce, + ), + "x-plivo-signature-v3-nonce": nonce, + }, + ) + + with ( + patch("api.services.telephony.providers.plivo.routes.db_client") as db_client, + patch( + "api.services.telephony.providers.plivo.routes.get_telephony_provider_for_run", + new_callable=AsyncMock, + return_value=provider, + ), + patch.object( + provider, + "get_webhook_response", + new_callable=AsyncMock, + return_value="", + ) as get_webhook_response, + ): + db_client.get_workflow_run_by_id = AsyncMock( + return_value=SimpleNamespace(gathered_context={}, workflow_id=7) + ) + db_client.update_workflow_run = AsyncMock() + + response = await handle_plivo_xml_webhook( + workflow_id=7, + user_id=8, + workflow_run_id=123, + organization_id=11, + request=request, + ) + + assert response.body == b"" + get_webhook_response.assert_awaited_once_with(7, 8, 123) + db_client.update_workflow_run.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_plivo_status_callback_rejects_missing_signature(): + provider = _provider() + request = _request( + path="/api/v1/telephony/plivo/hangup-callback/123", + query={}, + form_data={"CallUUID": "call-123", "Event": "hangup"}, + ) + + with ( + patch("api.services.telephony.providers.plivo.routes.db_client") as db_client, + patch( + "api.services.telephony.providers.plivo.routes.get_telephony_provider_for_run", + new_callable=AsyncMock, + return_value=provider, + ), + patch( + "api.services.telephony.providers.plivo.routes._process_status_update", + new_callable=AsyncMock, + ) as process_status, + ): + db_client.get_workflow_run_by_id = AsyncMock( + return_value=SimpleNamespace(workflow_id=7) + ) + db_client.get_workflow_by_id = AsyncMock( + return_value=SimpleNamespace(organization_id=11) + ) + + result = await handle_plivo_hangup_callback( + workflow_run_id=123, request=request + ) + + assert result == {"status": "error", "reason": "invalid_signature"} + process_status.assert_not_awaited() diff --git a/api/tests/telephony/twilio/test_routes.py b/api/tests/telephony/twilio/test_routes.py new file mode 100644 index 0000000..6748d94 --- /dev/null +++ b/api/tests/telephony/twilio/test_routes.py @@ -0,0 +1,253 @@ +from types import SimpleNamespace +from unittest.mock import AsyncMock, patch +from urllib.parse import urlencode + +import pytest +from fastapi import HTTPException +from starlette.requests import Request +from twilio.request_validator import RequestValidator + +from api.services.telephony.providers.twilio.provider import TwilioProvider +from api.services.telephony.providers.twilio.routes import ( + handle_twilio_status_callback, + handle_twiml_webhook, +) + + +def _provider() -> TwilioProvider: + return TwilioProvider( + { + "account_sid": "AC123", + "auth_token": "twilio-auth-token", + "from_numbers": ["+15551230002"], + } + ) + + +def _request( + *, + path: str, + query: dict[str, str | int], + form_data: dict[str, str], + headers: dict[str, str] | None = None, +) -> Request: + body = urlencode(form_data).encode("utf-8") + query_string = urlencode(query).encode("utf-8") + request_headers = [ + (b"content-type", b"application/x-www-form-urlencoded"), + *[ + (name.lower().encode("ascii"), value.encode("ascii")) + for name, value in (headers or {}).items() + ], + ] + + async def receive(): + return { + "type": "http.request", + "body": body, + "more_body": False, + } + + return Request( + { + "type": "http", + "method": "POST", + "scheme": "https", + "server": ("example.test", 443), + "path": path, + "query_string": query_string, + "headers": request_headers, + }, + receive, + ) + + +def _signature( + provider: TwilioProvider, + *, + path: str, + query: dict[str, str | int], + form_data: dict[str, str], +) -> str: + url = f"https://example.test{path}" + if query: + url = f"{url}?{urlencode(query)}" + validator = RequestValidator(provider.auth_token) + return validator.compute_signature(url, form_data) + + +@pytest.mark.asyncio +async def test_twiml_route_accepts_valid_signature_with_extra_query_param(): + provider = _provider() + query = { + "workflow_id": 7, + "user_id": 8, + "workflow_run_id": 123, + "campaign_id": 42, + "organization_id": 11, + } + form_data = {"CallSid": "CA123", "CallStatus": "in-progress"} + request = _request( + path="/api/v1/telephony/twiml", + query=query, + form_data=form_data, + headers={ + "x-twilio-signature": _signature( + provider, + path="/api/v1/telephony/twiml", + query=query, + form_data=form_data, + ) + }, + ) + + with ( + patch("api.services.telephony.providers.twilio.routes.db_client") as db_client, + patch( + "api.services.telephony.providers.twilio.routes.get_telephony_provider_for_run", + new_callable=AsyncMock, + return_value=provider, + ), + patch.object( + provider, + "get_webhook_response", + new_callable=AsyncMock, + return_value="", + ) as get_webhook_response, + ): + db_client.get_workflow_run_by_id = AsyncMock( + return_value=SimpleNamespace(id=123) + ) + + response = await handle_twiml_webhook( + workflow_id=7, + user_id=8, + workflow_run_id=123, + organization_id=11, + request=request, + ) + + assert response.body == b"" + get_webhook_response.assert_awaited_once_with(7, 8, 123) + + +@pytest.mark.asyncio +async def test_twiml_route_rejects_missing_signature(): + provider = _provider() + request = _request( + path="/api/v1/telephony/twiml", + query={ + "workflow_id": 7, + "user_id": 8, + "workflow_run_id": 123, + "organization_id": 11, + }, + form_data={"CallSid": "CA123"}, + ) + + with ( + patch("api.services.telephony.providers.twilio.routes.db_client") as db_client, + patch( + "api.services.telephony.providers.twilio.routes.get_telephony_provider_for_run", + new_callable=AsyncMock, + return_value=provider, + ), + ): + db_client.get_workflow_run_by_id = AsyncMock( + return_value=SimpleNamespace(id=123) + ) + + with pytest.raises(HTTPException) as exc_info: + await handle_twiml_webhook( + workflow_id=7, + user_id=8, + workflow_run_id=123, + organization_id=11, + request=request, + ) + + assert exc_info.value.status_code == 401 + assert exc_info.value.detail == "Invalid webhook signature" + + +@pytest.mark.asyncio +async def test_twilio_status_callback_rejects_legacy_header_name(): + provider = _provider() + form_data = {"CallSid": "CA123", "CallStatus": "completed"} + request = _request( + path="/api/v1/telephony/twilio/status-callback/123", + query={}, + form_data=form_data, + headers={"x-webhook-signature": "not-a-twilio-signature"}, + ) + + with ( + patch("api.services.telephony.providers.twilio.routes.db_client") as db_client, + patch( + "api.services.telephony.providers.twilio.routes.get_telephony_provider_for_run", + new_callable=AsyncMock, + return_value=provider, + ), + patch( + "api.services.telephony.providers.twilio.routes._process_status_update", + new_callable=AsyncMock, + ) as process_status, + ): + db_client.get_workflow_run_by_id = AsyncMock( + return_value=SimpleNamespace(workflow_id=7) + ) + db_client.get_workflow_by_id = AsyncMock( + return_value=SimpleNamespace(organization_id=11) + ) + + with pytest.raises(HTTPException) as exc_info: + await handle_twilio_status_callback(workflow_run_id=123, request=request) + + assert exc_info.value.status_code == 401 + assert exc_info.value.detail == "Invalid webhook signature" + process_status.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_twilio_status_callback_accepts_valid_signature(): + provider = _provider() + form_data = {"CallSid": "CA123", "CallStatus": "completed"} + request = _request( + path="/api/v1/telephony/twilio/status-callback/123", + query={}, + form_data=form_data, + headers={ + "x-twilio-signature": _signature( + provider, + path="/api/v1/telephony/twilio/status-callback/123", + query={}, + form_data=form_data, + ) + }, + ) + + with ( + patch("api.services.telephony.providers.twilio.routes.db_client") as db_client, + patch( + "api.services.telephony.providers.twilio.routes.get_telephony_provider_for_run", + new_callable=AsyncMock, + return_value=provider, + ), + patch( + "api.services.telephony.providers.twilio.routes._process_status_update", + new_callable=AsyncMock, + ) as process_status, + ): + db_client.get_workflow_run_by_id = AsyncMock( + return_value=SimpleNamespace(workflow_id=7) + ) + db_client.get_workflow_by_id = AsyncMock( + return_value=SimpleNamespace(organization_id=11) + ) + + result = await handle_twilio_status_callback( + workflow_run_id=123, request=request + ) + + assert result == {"status": "success"} + process_status.assert_awaited_once() diff --git a/api/tests/telephony/vobiz/test_routes.py b/api/tests/telephony/vobiz/test_routes.py new file mode 100644 index 0000000..f726eee --- /dev/null +++ b/api/tests/telephony/vobiz/test_routes.py @@ -0,0 +1,178 @@ +import hashlib +import hmac +from datetime import UTC, datetime +from types import SimpleNamespace +from unittest.mock import AsyncMock, patch +from urllib.parse import urlencode + +import pytest +from starlette.requests import Request + +from api.services.telephony.providers.vobiz.provider import VobizProvider +from api.services.telephony.providers.vobiz.routes import ( + handle_vobiz_hangup_callback, + handle_vobiz_ring_callback, +) + + +def _provider() -> VobizProvider: + return VobizProvider( + { + "auth_id": "MA123", + "auth_token": "vobiz-auth-token", + "from_numbers": ["+15551230002"], + } + ) + + +def _request( + *, + path: str, + form_data: dict[str, str], + headers: dict[str, str] | None = None, +) -> Request: + body = urlencode(form_data).encode("utf-8") + request_headers = [ + (b"content-type", b"application/x-www-form-urlencoded"), + *[ + (name.lower().encode("ascii"), value.encode("ascii")) + for name, value in (headers or {}).items() + ], + ] + + async def receive(): + return { + "type": "http.request", + "body": body, + "more_body": False, + } + + return Request( + { + "type": "http", + "method": "POST", + "scheme": "https", + "server": ("example.test", 443), + "path": path, + "query_string": b"", + "headers": request_headers, + }, + receive, + ) + + +def _signed_headers( + provider: VobizProvider, *, form_data: dict[str, str] +) -> dict[str, str]: + timestamp = str(int(datetime.now(UTC).timestamp())) + body = urlencode(form_data) + signature = hmac.new( + provider.auth_token.encode("utf-8"), + f"{timestamp}.{body}".encode("utf-8"), + hashlib.sha256, + ).hexdigest() + return { + "x-vobiz-signature": signature, + "x-vobiz-timestamp": timestamp, + } + + +@pytest.mark.asyncio +async def test_vobiz_hangup_callback_accepts_signed_form_body(): + provider = _provider() + form_data = { + "CallUUID": "call-123", + "CallStatus": "completed", + "From": "15551230001", + "To": "15551230002", + "Direction": "outbound", + "Duration": "12", + } + headers = _signed_headers(provider, form_data=form_data) + request = _request( + path="/api/v1/telephony/vobiz/hangup-callback/123", + form_data=form_data, + headers=headers, + ) + + with ( + patch("api.services.telephony.providers.vobiz.routes.db_client") as db_client, + patch( + "api.services.telephony.providers.vobiz.routes.get_telephony_provider_for_run", + new_callable=AsyncMock, + return_value=provider, + ), + patch( + "api.services.telephony.providers.vobiz.routes.get_backend_endpoints", + new_callable=AsyncMock, + return_value=("https://example.test", "wss://example.test"), + ), + patch( + "api.services.telephony.providers.vobiz.routes._process_status_update", + new_callable=AsyncMock, + ) as process_status, + ): + db_client.get_workflow_run_by_id = AsyncMock( + return_value=SimpleNamespace(workflow_id=7) + ) + db_client.get_workflow_by_id = AsyncMock( + return_value=SimpleNamespace(organization_id=11) + ) + + result = await handle_vobiz_hangup_callback( + workflow_run_id=123, + request=request, + x_vobiz_signature=headers["x-vobiz-signature"], + x_vobiz_timestamp=headers["x-vobiz-timestamp"], + ) + + assert result == {"status": "success"} + process_status.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_vobiz_ring_callback_accepts_signed_form_body(): + provider = _provider() + form_data = { + "CallUUID": "call-123", + "CallStatus": "ringing", + "From": "15551230001", + "To": "15551230002", + } + headers = _signed_headers(provider, form_data=form_data) + request = _request( + path="/api/v1/telephony/vobiz/ring-callback/123", + form_data=form_data, + headers=headers, + ) + + workflow_run = SimpleNamespace(workflow_id=7, logs={}) + + with ( + patch("api.services.telephony.providers.vobiz.routes.db_client") as db_client, + patch( + "api.services.telephony.providers.vobiz.routes.get_telephony_provider_for_run", + new_callable=AsyncMock, + return_value=provider, + ), + patch( + "api.services.telephony.providers.vobiz.routes.get_backend_endpoints", + new_callable=AsyncMock, + return_value=("https://example.test", "wss://example.test"), + ), + ): + db_client.get_workflow_run_by_id = AsyncMock(return_value=workflow_run) + db_client.get_workflow_by_id = AsyncMock( + return_value=SimpleNamespace(organization_id=11) + ) + db_client.update_workflow_run = AsyncMock() + + result = await handle_vobiz_ring_callback( + workflow_run_id=123, + request=request, + x_vobiz_signature=headers["x-vobiz-signature"], + x_vobiz_timestamp=headers["x-vobiz-timestamp"], + ) + + assert result == {"status": "success"} + db_client.update_workflow_run.assert_awaited_once() diff --git a/api/tests/test_custom_tools.py b/api/tests/test_custom_tools.py index 4638e5f..703ae76 100644 --- a/api/tests/test_custom_tools.py +++ b/api/tests/test_custom_tools.py @@ -935,9 +935,11 @@ class TestCustomToolManagerUnit: # Create a mock engine with a mock LLM mock_llm = Mock() registered_handlers = {} + registered_kwargs = {} def capture_register(name, handler, **kwargs): registered_handlers[name] = handler + registered_kwargs[name] = kwargs mock_llm.register_function = capture_register @@ -986,6 +988,7 @@ class TestCustomToolManagerUnit: # Verify handler was registered assert "api_call" in registered_handlers + assert registered_kwargs["api_call"]["timeout_secs"] == pytest.approx(5) # Now test that the handler works handler = registered_handlers["api_call"] diff --git a/api/tests/test_from_number_pool_isolation.py b/api/tests/test_from_number_pool_isolation.py index c22241f..3c65d10 100644 --- a/api/tests/test_from_number_pool_isolation.py +++ b/api/tests/test_from_number_pool_isolation.py @@ -313,6 +313,13 @@ class TestDispatcherThreadsTelephonyConfig: f"kwargs={store_kwargs}" ) + assert provider.initiate_call.await_count == 1 + webhook_url = provider.initiate_call.await_args.kwargs["webhook_url"] + assert "campaign_id=" not in webhook_url, ( + "campaign outbound answer_url should not include campaign_id; " + f"got {webhook_url}" + ) + @pytest.mark.asyncio async def test_release_call_slot_uses_stored_telephony_config(self): """When a call completes, release_call_slot must release the from_number diff --git a/api/tests/test_google_vertex_llm_service_factory.py b/api/tests/test_google_vertex_llm_service_factory.py index ec8c4ce..966d657 100644 --- a/api/tests/test_google_vertex_llm_service_factory.py +++ b/api/tests/test_google_vertex_llm_service_factory.py @@ -19,7 +19,7 @@ class TestGoogleVertexLLMConfiguration: config = GoogleVertexLLMConfiguration(project_id="demo-project") assert config.provider == ServiceProviders.GOOGLE_VERTEX assert config.model == "gemini-2.5-flash" - assert config.location == "us-east4" + assert config.location == "global" assert config.credentials is None assert config.api_key is None diff --git a/api/tests/test_resolve_effective_config.py b/api/tests/test_resolve_effective_config.py index fb7bbd7..c747387 100644 --- a/api/tests/test_resolve_effective_config.py +++ b/api/tests/test_resolve_effective_config.py @@ -17,6 +17,7 @@ from api.services.configuration.registry import ( GoogleVertexLLMConfiguration, GrokRealtimeLLMConfiguration, OpenAILLMService, + UltravoxRealtimeLLMConfiguration, ) from api.services.configuration.resolve import resolve_effective_config @@ -261,6 +262,22 @@ class TestRealtimeOverride: assert result.realtime.provider == "grok_realtime" assert result.realtime.voice == "Sal" + def test_switch_realtime_provider_to_ultravox(self, global_config_realtime): + result = resolve_effective_config( + global_config_realtime, + { + "realtime": { + "provider": "ultravox_realtime", + "api_key": "ultra-key", + "model": "ultravox-v0.7", + "voice": "Mark", + } + }, + ) + assert isinstance(result.realtime, UltravoxRealtimeLLMConfiguration) + assert result.realtime.provider == "ultravox_realtime" + assert result.realtime.voice == "Mark" + def test_override_is_realtime_only_without_realtime_section(self, global_config): """Override is_realtime=True but provide no realtime config. Should set the flag; realtime section stays None from global.""" diff --git a/api/tests/test_telephony_routes.py b/api/tests/test_telephony_routes.py new file mode 100644 index 0000000..49c2f8d --- /dev/null +++ b/api/tests/test_telephony_routes.py @@ -0,0 +1,158 @@ +from types import SimpleNamespace +from unittest.mock import AsyncMock, Mock, patch + +from fastapi import FastAPI +from fastapi.testclient import TestClient + +from api.routes.telephony import router +from api.services.auth.depends import get_user + + +def _make_test_app() -> FastAPI: + app = FastAPI() + app.include_router(router) + app.dependency_overrides[get_user] = lambda: SimpleNamespace( + id=7, + selected_organization_id=11, + ) + return app + + +def _workflow(*, workflow_id: int = 33, user_id: int = 99): + return SimpleNamespace( + id=workflow_id, + user_id=user_id, + organization_id=11, + template_context_variables={"template_key": "template-value"}, + ) + + +def _provider(): + return SimpleNamespace( + PROVIDER_NAME="twilio", + WEBHOOK_ENDPOINT="twilio/voice", + validate_config=Mock(return_value=True), + initiate_call=AsyncMock( + return_value=SimpleNamespace( + caller_number="+15550001111", + provider_metadata={"call_id": "call-123"}, + ) + ), + ) + + +def test_initiate_call_executes_as_workflow_owner_for_shared_org_workflow(): + app = _make_test_app() + client = TestClient(app) + + workflow = _workflow() + provider = _provider() + quota_mock = AsyncMock( + return_value=SimpleNamespace(has_quota=True, error_message="") + ) + + with ( + patch("api.routes.telephony.db_client") as mock_db, + patch( + "api.routes.telephony.check_dograh_quota_by_user_id", + new=quota_mock, + ), + patch( + "api.routes.telephony.get_default_telephony_provider", + new=AsyncMock(return_value=provider), + ), + patch( + "api.routes.telephony.get_backend_endpoints", + new=AsyncMock(return_value=("https://api.example.com", "wss://ignored")), + ), + ): + mock_db.get_user_configurations = AsyncMock( + return_value=SimpleNamespace(test_phone_number=None) + ) + mock_db.get_default_telephony_configuration = AsyncMock( + return_value=SimpleNamespace(id=55) + ) + mock_db.get_workflow = AsyncMock(return_value=workflow) + mock_db.create_workflow_run = AsyncMock( + return_value=SimpleNamespace( + id=501, + name="WR-TEL-OUT-00000001", + initial_context={"template_key": "template-value"}, + ) + ) + mock_db.update_workflow_run = AsyncMock() + + response = client.post( + "/telephony/initiate-call", + json={"workflow_id": workflow.id, "phone_number": "+15551234567"}, + ) + + assert response.status_code == 200 + quota_mock.assert_awaited_once_with(workflow.user_id, workflow_id=workflow.id) + mock_db.get_workflow.assert_awaited_once_with(workflow.id, organization_id=11) + + create_call = mock_db.create_workflow_run.await_args + create_args = create_call.args + create_kwargs = create_call.kwargs + assert create_args[1] == workflow.id + assert create_kwargs["user_id"] == workflow.user_id + assert create_kwargs["organization_id"] == workflow.organization_id + assert create_kwargs["initial_context"]["template_key"] == "template-value" + + initiate_kwargs = provider.initiate_call.await_args.kwargs + assert initiate_kwargs["workflow_id"] == workflow.id + assert initiate_kwargs["user_id"] == workflow.user_id + assert "user_id=99" in initiate_kwargs["webhook_url"] + + +def test_initiate_call_rejects_existing_run_for_different_workflow(): + app = _make_test_app() + client = TestClient(app) + + workflow = _workflow() + provider = _provider() + quota_mock = AsyncMock( + return_value=SimpleNamespace(has_quota=True, error_message="") + ) + + with ( + patch("api.routes.telephony.db_client") as mock_db, + patch( + "api.routes.telephony.check_dograh_quota_by_user_id", + new=quota_mock, + ), + patch( + "api.routes.telephony.get_default_telephony_provider", + new=AsyncMock(return_value=provider), + ), + ): + mock_db.get_user_configurations = AsyncMock( + return_value=SimpleNamespace(test_phone_number=None) + ) + mock_db.get_default_telephony_configuration = AsyncMock( + return_value=SimpleNamespace(id=55) + ) + mock_db.get_workflow = AsyncMock(return_value=workflow) + mock_db.get_workflow_run = AsyncMock( + return_value=SimpleNamespace( + id=501, + workflow_id=44, + name="WR-TEL-OUT-00000044", + initial_context={}, + ) + ) + + response = client.post( + "/telephony/initiate-call", + json={ + "workflow_id": workflow.id, + "workflow_run_id": 501, + "phone_number": "+15551234567", + }, + ) + + assert response.status_code == 400 + assert response.json()["detail"] == "workflow_run_workflow_mismatch" + mock_db.get_workflow_run.assert_awaited_once_with(501, organization_id=11) + assert not mock_db.create_workflow_run.called + assert provider.initiate_call.await_count == 0 diff --git a/api/tests/test_ultravox_realtime_wrapper.py b/api/tests/test_ultravox_realtime_wrapper.py new file mode 100644 index 0000000..1034b8d --- /dev/null +++ b/api/tests/test_ultravox_realtime_wrapper.py @@ -0,0 +1,459 @@ +from types import SimpleNamespace +from unittest.mock import AsyncMock, call + +import pytest +from pipecat.adapters.schemas.function_schema import FunctionSchema +from pipecat.adapters.schemas.tools_schema import ToolsSchema +from pipecat.frames.frames import LLMMessagesAppendFrame, TTSSpeakFrame +from pipecat.processors.aggregators.llm_context import LLMContext +from pipecat.processors.frame_processor import FrameDirection +from websockets.exceptions import ConnectionClosedError +from websockets.frames import Close + +from api.schemas.user_configuration import UserConfiguration +from api.services.configuration.registry import UltravoxRealtimeLLMConfiguration +from api.services.pipecat.realtime.ultravox_realtime import ( + _RESUMPTION_USER_MESSAGE, + DograhUltravoxOneShotInputParams, + DograhUltravoxRealtimeLLMService, +) +from api.services.pipecat.service_factory import create_realtime_llm_service + + +class _ClosingSocket: + def __init__(self, exc): + self._exc = exc + + def __aiter__(self): + return self + + async def __anext__(self): + raise self._exc + + +class _MessageSocket: + def __init__(self, messages): + self._messages = iter(messages) + + def __aiter__(self): + return self + + async def __anext__(self): + try: + return next(self._messages) + except StopIteration: + raise StopAsyncIteration + + +def _make_service() -> DograhUltravoxRealtimeLLMService: + service = DograhUltravoxRealtimeLLMService( + params=DograhUltravoxOneShotInputParams( + api_key="test-key", + model="ultravox-v0.7", + output_medium="voice", + ), + settings=DograhUltravoxRealtimeLLMService.Settings( + model="ultravox-v0.7", + output_medium="voice", + ), + ) + service.stop_all_metrics = AsyncMock() + service.cancel_task = AsyncMock() + service.push_error = AsyncMock() + return service + + +def _tool_schema() -> ToolsSchema: + return ToolsSchema( + standard_tools=[ + FunctionSchema( + name="transition_to_next_node", + description="Move to the next workflow node", + properties={"reason": {"type": "string"}}, + required=[], + ) + ] + ) + + +@pytest.mark.asyncio +async def test_tts_greeting_triggers_initial_connect(): + service = _make_service() + service._connect_call = AsyncMock() + + await service.process_frame( + TTSSpeakFrame("Hello there", append_to_context=True), + FrameDirection.DOWNSTREAM, + ) + + service._connect_call.assert_awaited_once() + assert service._connect_call.await_args.kwargs["greeting_text"] == "Hello there" + assert service._connect_call.await_args.kwargs["agent_speaks_first"] is True + + +@pytest.mark.asyncio +async def test_initial_context_connects_without_replay(): + service = _make_service() + service._connect_call = AsyncMock() + context = LLMContext() + + await service._handle_context(context) + + service._connect_call.assert_awaited_once() + assert service._connect_call.await_args.kwargs["initial_messages"] is None + assert service._connect_call.await_args.kwargs["agent_speaks_first"] is True + + +@pytest.mark.asyncio +async def test_system_instruction_update_marks_reconnect_required(): + service = _make_service() + service._has_connected_once = True + + changed = await service._update_settings( + DograhUltravoxRealtimeLLMService.Settings(system_instruction="new instruction") + ) + + assert "system_instruction" in changed + assert service._reconnect_required is True + + +@pytest.mark.asyncio +async def test_system_instruction_change_reconnects_with_full_initial_messages(): + service = _make_service() + service._socket = object() + service._has_connected_once = True + service._call_system_instruction = "old instruction" + service._reconnect_required = True + service._settings.system_instruction = "new instruction" + service._reconnect_with_context = AsyncMock() + + context = LLMContext( + messages=[ + {"role": "user", "content": "I want to hear the pricing."}, + { + "role": "assistant", + "content": "Let me check that for you.", + "tool_calls": [ + { + "id": "call-transition", + "type": "function", + "function": { + "name": "transition_to_next_node", + "arguments": '{"reason":"pricing requested"}', + }, + } + ], + }, + { + "role": "tool", + "tool_call_id": "call-transition", + "content": '{"status":"done"}', + }, + ], + tools=_tool_schema(), + ) + + await service._handle_context(context) + + service._reconnect_with_context.assert_awaited_once() + initial_messages = service._reconnect_with_context.await_args.kwargs[ + "initial_messages" + ] + assert initial_messages == [ + { + "role": "MESSAGE_ROLE_USER", + "text": "I want to hear the pricing.", + }, + { + "role": "MESSAGE_ROLE_AGENT", + "text": "Let me check that for you.", + }, + { + "role": "MESSAGE_ROLE_TOOL_CALL", + "text": "", + "invocationId": "call-transition", + "toolName": "transition_to_next_node", + }, + { + "role": "MESSAGE_ROLE_TOOL_RESULT", + "text": '{"status":"done"}', + "invocationId": "call-transition", + "toolName": "transition_to_next_node", + }, + ] + assert "call-transition" in service._completed_tool_calls + + +@pytest.mark.asyncio +async def test_tool_context_update_does_not_reconnect_when_system_instruction_is_unchanged(): + service = _make_service() + service._socket = object() + service._call_system_instruction = "same instruction" + service._settings.system_instruction = "same instruction" + service._reconnect_with_context = AsyncMock() + service._send_tool_result = AsyncMock() + + context = LLMContext( + messages=[ + { + "role": "tool", + "tool_call_id": "call-transition", + "content": '{"status":"done"}', + }, + ], + tools=_tool_schema(), + ) + + await service._handle_context(context) + + service._reconnect_with_context.assert_not_awaited() + service._send_tool_result.assert_awaited_once_with( + "call-transition", + '{"status":"done"}', + ) + + +@pytest.mark.asyncio +async def test_messages_append_frame_sends_user_text(): + service = _make_service() + service._socket = object() + service._call_started = True + service._send_user_text = AsyncMock() + + await service._handle_messages_append( + LLMMessagesAppendFrame( + [{"role": "user", "content": "Are you still there?"}], + run_llm=True, + ) + ) + + service._send_user_text.assert_awaited_once_with("Are you still there?") + + +@pytest.mark.asyncio +async def test_messages_append_frame_queues_user_text_until_call_started(): + service = _make_service() + service._socket = object() + service._call_started = False + service._send_user_text = AsyncMock() + + await service._handle_messages_append( + LLMMessagesAppendFrame( + [{"role": "user", "content": "Are you still there?"}], + run_llm=True, + ) + ) + + assert service._pending_user_text_messages == ["Are you still there?"] + service._send_user_text.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_call_started_flushes_pending_user_text_messages(): + service = _make_service() + service._pending_user_text_messages = [ + "First queued message", + "Second queued message", + ] + service._send_user_text = AsyncMock() + service._socket = _MessageSocket(['{"type":"call_started","callId":"call-123"}']) + + await service._receive_messages() + + assert service._call_started is True + assert service._pending_user_text_messages == [] + assert service._send_user_text.await_args_list == [ + call("First queued message"), + call("Second queued message"), + ] + + +@pytest.mark.asyncio +async def test_completed_input_transcription_is_broadcast_as_finalized(): + service = _make_service() + service.broadcast_frame = AsyncMock() + service._last_user_id = "caller-1" + + await service._handle_user_transcript("Hello there") + + service.broadcast_frame.assert_awaited_once() + assert service.broadcast_frame.await_args.args[0].__name__ == "TranscriptionFrame" + assert service.broadcast_frame.await_args.kwargs["text"] == "Hello there" + assert service.broadcast_frame.await_args.kwargs["finalized"] is True + + +def test_build_one_shot_params_uses_explicit_greeting_text(): + service = _make_service() + + params = service._build_one_shot_params( + greeting_text="Welcome to Dograh", + initial_messages=None, + agent_speaks_first=True, + ) + + assert params.extra["firstSpeakerSettings"] == { + "agent": {"text": "Welcome to Dograh"} + } + + +def test_build_one_shot_params_includes_initial_messages(): + service = _make_service() + service._settings.system_instruction = "Base instruction" + + params = service._build_one_shot_params( + greeting_text=None, + initial_messages=[ + {"role": "MESSAGE_ROLE_USER", "text": "User asked a question."}, + {"role": "MESSAGE_ROLE_TOOL_RESULT", "text": '{"status":"done"}'}, + ], + agent_speaks_first=True, + ) + + assert params.extra["initialMessages"] == [ + {"role": "MESSAGE_ROLE_USER", "text": "User asked a question."}, + {"role": "MESSAGE_ROLE_TOOL_RESULT", "text": '{"status":"done"}'}, + {"role": "MESSAGE_ROLE_USER", "text": _RESUMPTION_USER_MESSAGE}, + ] + assert params.system_prompt == "Base instruction" + + +def test_build_one_shot_params_without_tool_result_does_not_add_resumption_user_message(): + service = _make_service() + service._settings.system_instruction = "Base instruction" + + params = service._build_one_shot_params( + greeting_text=None, + initial_messages=[ + {"role": "MESSAGE_ROLE_USER", "text": "User asked a question."}, + {"role": "MESSAGE_ROLE_AGENT", "text": "Assistant replied."}, + ], + agent_speaks_first=False, + ) + + assert params.system_prompt == "Base instruction" + + +def test_should_agent_speak_first_when_history_ends_with_tool_result(): + service = _make_service() + + assert ( + service._should_agent_speak_first( + [ + {"role": "MESSAGE_ROLE_USER", "text": "Hello"}, + {"role": "MESSAGE_ROLE_TOOL_RESULT", "text": '{"status":"done"}'}, + ] + ) + is True + ) + + +def test_should_not_force_agent_speaks_first_when_history_ends_with_agent(): + service = _make_service() + + assert ( + service._should_agent_speak_first( + [{"role": "MESSAGE_ROLE_AGENT", "text": "How else can I help?"}] + ) + is False + ) + + +def test_should_add_resumption_user_message_only_when_history_ends_with_tool_result(): + service = _make_service() + + assert ( + service._should_add_resumption_user_message( + [{"role": "MESSAGE_ROLE_TOOL_RESULT", "text": '{"status":"done"}'}] + ) + is True + ) + assert ( + service._should_add_resumption_user_message( + [{"role": "MESSAGE_ROLE_AGENT", "text": "Assistant replied."}] + ) + is False + ) + + +def test_to_selected_tools_includes_registered_timeout(): + service = _make_service() + service.register_function( + "transition_to_next_node", + AsyncMock(), + timeout_secs=5.5, + ) + + selected_tools = service._to_selected_tools(_tool_schema()) + + assert selected_tools == [ + { + "temporaryTool": { + "modelToolName": "transition_to_next_node", + "description": "Move to the next workflow node", + "dynamicParameters": [ + { + "name": "reason", + "location": "PARAMETER_LOCATION_BODY", + "schema": {"type": "string"}, + "required": False, + } + ], + "client": {}, + "timeout": "5.5s", + } + } + ] + + +@pytest.mark.asyncio +async def test_receive_messages_ignores_benign_websocket_close(): + service = _make_service() + service._socket = _ClosingSocket( + ConnectionClosedError(None, Close(1000, "OK"), None) + ) + + await service._receive_messages() + + service.push_error.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_receive_messages_reports_unexpected_websocket_close(): + service = _make_service() + service._socket = _ClosingSocket( + ConnectionClosedError(None, Close(1011, "internal error"), None) + ) + + await service._receive_messages() + + service.push_error.assert_awaited_once() + + +def test_factory_creates_dograh_ultravox_realtime_service(): + user_config = UserConfiguration( + is_realtime=True, + realtime=UltravoxRealtimeLLMConfiguration( + provider="ultravox_realtime", + api_key="ultra-key", + model="ultravox-v0.7", + voice="Mark", + ), + ) + + service = create_realtime_llm_service( + user_config, + audio_config=SimpleNamespace(), + ) + + assert isinstance(service, DograhUltravoxRealtimeLLMService) + assert service._params.voice == "Mark" + + +def test_ultravox_realtime_configuration_defaults_to_mark_voice(): + config = UltravoxRealtimeLLMConfiguration( + provider="ultravox_realtime", + api_key="ultra-key", + model="ultravox-v0.7", + ) + + assert config.voice == "Mark" diff --git a/sdk/python/pyproject.toml b/sdk/python/pyproject.toml index 95c857f..bdfcdb6 100644 --- a/sdk/python/pyproject.toml +++ b/sdk/python/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "dograh-sdk" -version = "0.1.5" +version = "0.1.6" description = "Typed builder for Dograh voice-AI workflows" readme = "README.md" requires-python = ">=3.10" diff --git a/sdk/python/src/dograh_sdk/_generated_models.py b/sdk/python/src/dograh_sdk/_generated_models.py index 40fbe44..692ab8d 100644 --- a/sdk/python/src/dograh_sdk/_generated_models.py +++ b/sdk/python/src/dograh_sdk/_generated_models.py @@ -1,6 +1,6 @@ # generated by datamodel-codegen: -# filename: dograh-openapi-XXXXXX.json.SafScGt2nh -# timestamp: 2026-05-22T09:06:50+00:00 +# filename: dograh-openapi-XXXXXX.json.ahnZ2z2E21 +# timestamp: 2026-05-23T03:32:29+00:00 from __future__ import annotations diff --git a/sdk/typescript/package.json b/sdk/typescript/package.json index 04972b4..6bc8b26 100644 --- a/sdk/typescript/package.json +++ b/sdk/typescript/package.json @@ -1,6 +1,6 @@ { "name": "@dograh/sdk", - "version": "0.1.5", + "version": "0.1.6", "description": "Typed builder for Dograh voice-AI workflows", "license": "BSD-2-Clause", "author": "Zansat Technologies Private Limited",