diff --git a/api/services/configuration/check_validity.py b/api/services/configuration/check_validity.py index 289e643..d5a724c 100644 --- a/api/services/configuration/check_validity.py +++ b/api/services/configuration/check_validity.py @@ -47,7 +47,9 @@ class UserConfigurationValidator: ServiceProviders.CAMB.value: self._check_camb_api_key, ServiceProviders.AWS_BEDROCK.value: self._check_aws_bedrock_api_key, ServiceProviders.SPEACHES.value: self._check_speaches_api_key, + ServiceProviders.GOOGLE_VERTEX.value: self._check_google_vertex_llm_api_key, ServiceProviders.OPENAI_REALTIME.value: self._check_openai_api_key, + ServiceProviders.GROK_REALTIME.value: self._check_grok_realtime_api_key, ServiceProviders.GOOGLE_REALTIME.value: self._check_google_api_key, ServiceProviders.GOOGLE_VERTEX_REALTIME.value: self._check_google_vertex_realtime_api_key, ServiceProviders.ASSEMBLYAI.value: self._check_assemblyai_api_key, @@ -134,6 +136,20 @@ class UserConfigurationValidator: return [{"model": service_name, "message": str(e)}] return [] + # Vertex LLM uses service-account credentials (or ADC) instead of api_key + if provider == ServiceProviders.GOOGLE_VERTEX.value: + try: + if not self._check_google_vertex_llm_api_key(provider, service_config): + return [ + { + "model": service_name, + "message": f"Invalid {provider} configuration", + } + ] + except ValueError as e: + return [{"model": service_name, "message": str(e)}] + return [] + # AWS Bedrock uses AWS credentials instead of api_key if provider == ServiceProviders.AWS_BEDROCK.value: try: @@ -236,6 +252,9 @@ class UserConfigurationValidator: def _check_openrouter_api_key(self, model: str, api_key: str) -> bool: return True + def _check_grok_realtime_api_key(self, model: str, api_key: str) -> bool: + return True + def _check_speechmatics_api_key(self, model: str, api_key: str) -> bool: return True @@ -254,6 +273,13 @@ class UserConfigurationValidator: raise ValueError("location is required for Google Vertex Realtime") return True + def _check_google_vertex_llm_api_key(self, model: str, service_config) -> bool: + if not getattr(service_config, "project_id", None): + raise ValueError("project_id is required for Google Vertex") + if not getattr(service_config, "location", None): + raise ValueError("location is required for Google Vertex") + return True + def _check_aws_bedrock_api_key(self, model: str, service_config) -> bool: if not service_config.aws_access_key or not service_config.aws_secret_key: raise ValueError("AWS access key and secret key are required for Bedrock") diff --git a/api/services/configuration/masking.py b/api/services/configuration/masking.py index 3aee31c..3b904c6 100644 --- a/api/services/configuration/masking.py +++ b/api/services/configuration/masking.py @@ -18,27 +18,31 @@ from api.services.integrations import get_node_secret_fields VISIBLE_CHARS = 4 # number of trailing characters to reveal MASK_CHAR = "*" MASK_MARKER = "***" # substring that indicates a masked key +SERVICE_SECRET_FIELDS = ("api_key", "credentials", "aws_access_key", "aws_secret_key") -def contains_masked_key(api_key: str | list[str] | None) -> bool: - """Return True if *api_key* looks like a masked placeholder.""" - if api_key is None: +def contains_masked_key(value: str | list[str] | None) -> bool: + """Return True if *value* looks like a masked placeholder.""" + if value is None: return False - keys = api_key if isinstance(api_key, list) else [api_key] + keys = value if isinstance(value, list) else [value] return any(MASK_MARKER in k for k in keys) def check_for_masked_keys(config: "UserConfiguration") -> None: - """Raise ValueError if any service in *config* still has a masked API key.""" + """Raise ValueError if any service in *config* still has a masked secret.""" for field in ("llm", "tts", "stt", "embeddings", "realtime"): service = getattr(config, field, None) if service is None: continue - if contains_masked_key(service.get_all_api_keys()): - raise ValueError( - f"The {field} api_key appears to be masked. " - "Please provide the actual API key, not the masked value." - ) + for secret_field in SERVICE_SECRET_FIELDS: + if not hasattr(service, secret_field): + continue + if contains_masked_key(getattr(service, secret_field, None)): + raise ValueError( + f"The {field} {secret_field} appears to be masked. " + "Please provide the actual value, not the masked value." + ) def mask_key(real_key: str, visible: int = VISIBLE_CHARS) -> str: @@ -105,12 +109,14 @@ def _mask_service(service_cfg: Optional[ServiceConfig]) -> Optional[Dict[str, An # Work on a dict copy so we don't mutate original models data = service_cfg.model_dump() - if "api_key" in data and data["api_key"]: - raw = data["api_key"] + for secret_field in SERVICE_SECRET_FIELDS: + if secret_field not in data or not data[secret_field]: + continue + raw = data[secret_field] if isinstance(raw, list): - data["api_key"] = [mask_key(k) for k in raw] + data[secret_field] = [mask_key(k) for k in raw] else: - data["api_key"] = mask_key(raw) + data[secret_field] = mask_key(raw) return data diff --git a/api/services/configuration/merge.py b/api/services/configuration/merge.py index 992637f..937060d 100644 --- a/api/services/configuration/merge.py +++ b/api/services/configuration/merge.py @@ -7,7 +7,10 @@ stored, while honouring masked API keys. from typing import Dict from api.schemas.user_configuration import UserConfiguration -from api.services.configuration.masking import resolve_masked_api_keys +from api.services.configuration.masking import ( + SERVICE_SECRET_FIELDS, + resolve_masked_api_keys, +) SERVICE_FIELDS = ("llm", "tts", "stt", "embeddings", "realtime") @@ -45,18 +48,16 @@ def merge_user_configurations( and incoming_cfg.get("provider") != old_cfg.get("provider") ) - incoming_api_key = incoming_cfg.get("api_key") - if not provider_changed: - # conditional preservation of api_key - if incoming_api_key is not None: - if old_cfg and "api_key" in old_cfg: - incoming_cfg["api_key"] = resolve_masked_api_keys( - incoming_api_key, old_cfg["api_key"] - ) - else: - if "api_key" in old_cfg: - incoming_cfg["api_key"] = old_cfg["api_key"] + for secret_field in SERVICE_SECRET_FIELDS: + incoming_secret = incoming_cfg.get(secret_field) + if incoming_secret is not None: + if old_cfg and secret_field in old_cfg: + incoming_cfg[secret_field] = resolve_masked_api_keys( + incoming_secret, old_cfg[secret_field] + ) + elif secret_field in old_cfg: + incoming_cfg[secret_field] = old_cfg[secret_field] merged[service_name] = incoming_cfg diff --git a/api/services/configuration/options/google.py b/api/services/configuration/options/google.py index 68a02e8..8852f11 100644 --- a/api/services/configuration/options/google.py +++ b/api/services/configuration/options/google.py @@ -3,6 +3,14 @@ GOOGLE_MODELS = ( "gemini-2.0-flash-lite", "gemini-2.5-flash", "gemini-2.5-flash-lite", + "gemini-3.5-flash", + "gemini-3.5-flash-lite", +) +GOOGLE_VERTEX_MODELS = ( + "gemini-2.5-flash", + "gemini-2.5-flash-lite", + "gemini-3.1-flash-lite", + "gemini-3.5-flash", ) GOOGLE_REALTIME_MODELS = ("gemini-3.1-flash-live-preview",) diff --git a/api/services/configuration/registry.py b/api/services/configuration/registry.py index da9b601..9ba16c3 100644 --- a/api/services/configuration/registry.py +++ b/api/services/configuration/registry.py @@ -28,6 +28,7 @@ from api.services.configuration.options import ( SARVAM_V3_VOICES, SPEECHMATICS_STT_LANGUAGES, ) +from api.services.configuration.options.google import GOOGLE_VERTEX_MODELS class ServiceType(Enum): @@ -58,7 +59,9 @@ class ServiceProviders(str, Enum): GLADIA = "gladia" RIME = "rime" MINIMAX = "minimax" + GOOGLE_VERTEX = "google_vertex" OPENAI_REALTIME = "openai_realtime" + GROK_REALTIME = "grok_realtime" GOOGLE_REALTIME = "google_realtime" GOOGLE_VERTEX_REALTIME = "google_vertex_realtime" @@ -79,7 +82,9 @@ class BaseServiceConfiguration(BaseModel): ServiceProviders.GLADIA, ServiceProviders.RIME, ServiceProviders.MINIMAX, + ServiceProviders.GOOGLE_VERTEX, ServiceProviders.OPENAI_REALTIME, + ServiceProviders.GROK_REALTIME, ServiceProviders.GOOGLE_REALTIME, ServiceProviders.GOOGLE_VERTEX_REALTIME, # ServiceProviders.SARVAM, @@ -206,7 +211,9 @@ OPENROUTER_PROVIDER_MODEL_CONFIG = provider_model_config("Open Router") AZURE_OPENAI_PROVIDER_MODEL_CONFIG = provider_model_config("Azure OpenAI") DOGRAH_PROVIDER_MODEL_CONFIG = provider_model_config("Dograh") AWS_BEDROCK_PROVIDER_MODEL_CONFIG = provider_model_config("AWS Bedrock") +GOOGLE_VERTEX_PROVIDER_MODEL_CONFIG = provider_model_config("Google Vertex") OPENAI_REALTIME_PROVIDER_MODEL_CONFIG = provider_model_config("OpenAI Realtime") +GROK_REALTIME_PROVIDER_MODEL_CONFIG = provider_model_config("Grok Realtime") GOOGLE_REALTIME_PROVIDER_MODEL_CONFIG = provider_model_config("Google Realtime") GOOGLE_VERTEX_REALTIME_PROVIDER_MODEL_CONFIG = provider_model_config( "Google Vertex Realtime" @@ -239,6 +246,7 @@ OPENAI_MODELS = [ "gpt-5-nano", "gpt-3.5-turbo", ] + GROQ_MODELS = [ "llama-3.3-70b-versatile", "deepseek-r1-distill-llama-70b", @@ -292,6 +300,40 @@ class GoogleLLMService(BaseLLMConfiguration): ) +@register_llm +class GoogleVertexLLMConfiguration(BaseLLMConfiguration): + model_config = GOOGLE_VERTEX_PROVIDER_MODEL_CONFIG + provider: Literal[ServiceProviders.GOOGLE_VERTEX] = ServiceProviders.GOOGLE_VERTEX + model: str = Field( + default="gemini-2.5-flash", + description="Gemini model on Vertex AI.", + json_schema_extra={ + "examples": GOOGLE_VERTEX_MODELS, + "allow_custom_input": True, + }, + ) + project_id: str = Field(description="Google Cloud project ID for Vertex AI.") + location: str = Field( + default="global", + description="GCP region for the Vertex AI endpoint (e.g. 'global').", + ) + credentials: str | None = Field( + default=None, + description=( + "Paste the entire service-account JSON file contents. If omitted, " + "falls back to Application Default Credentials (ADC)." + ), + json_schema_extra={"multiline": True}, + ) + api_key: str | list[str] | None = Field( + default=None, + description=( + "Not used for Vertex AI — authentication is via the service account " + "in `credentials` (or ADC). Leave blank." + ), + ) + + @register_llm class GroqLLMService(BaseLLMConfiguration): model_config = GROQ_PROVIDER_MODEL_CONFIG @@ -460,6 +502,32 @@ class OpenAIRealtimeLLMConfiguration(BaseLLMConfiguration): ) +GROK_REALTIME_MODELS = ["grok-voice-think-fast-1.0"] +GROK_REALTIME_VOICES = ["Ara", "Rex", "Sal", "Eve", "Leo"] + + +@register_service(ServiceType.REALTIME) +class GrokRealtimeLLMConfiguration(BaseLLMConfiguration): + model_config = GROK_REALTIME_PROVIDER_MODEL_CONFIG + provider: Literal[ServiceProviders.GROK_REALTIME] = ServiceProviders.GROK_REALTIME + model: str = Field( + default="grok-voice-think-fast-1.0", + description="Grok realtime voice-agent model.", + json_schema_extra={ + "examples": GROK_REALTIME_MODELS, + "allow_custom_input": True, + }, + ) + voice: str = Field( + default="Ara", + description="Voice the model speaks in.", + json_schema_extra={ + "examples": GROK_REALTIME_VOICES, + "allow_custom_input": True, + }, + ) + + @register_service(ServiceType.REALTIME) class GoogleRealtimeLLMConfiguration(BaseLLMConfiguration): model_config = GOOGLE_REALTIME_PROVIDER_MODEL_CONFIG @@ -524,8 +592,8 @@ class GoogleVertexRealtimeLLMConfiguration(BaseLLMConfiguration): ) project_id: str = Field(description="Google Cloud project ID for Vertex AI.") location: str = Field( - default="us-east4", - description="GCP region for the Vertex AI endpoint (e.g. 'us-east4').", + default="global", + description="GCP region for the Vertex AI endpoint (e.g. 'global').", ) credentials: str | None = Field( default=None, @@ -546,6 +614,7 @@ class GoogleVertexRealtimeLLMConfiguration(BaseLLMConfiguration): REALTIME_PROVIDERS = { ServiceProviders.OPENAI_REALTIME.value, + ServiceProviders.GROK_REALTIME.value, ServiceProviders.GOOGLE_REALTIME.value, ServiceProviders.GOOGLE_VERTEX_REALTIME.value, } @@ -554,6 +623,7 @@ REALTIME_PROVIDERS = { LLMConfig = Annotated[ Union[ OpenAILLMService, + GoogleVertexLLMConfiguration, GroqLLMService, OpenRouterLLMConfiguration, GoogleLLMService, @@ -569,6 +639,7 @@ LLMConfig = Annotated[ RealtimeConfig = Annotated[ Union[ OpenAIRealtimeLLMConfiguration, + GrokRealtimeLLMConfiguration, GoogleRealtimeLLMConfiguration, GoogleVertexRealtimeLLMConfiguration, ], diff --git a/api/services/pipecat/realtime/grok_realtime.py b/api/services/pipecat/realtime/grok_realtime.py new file mode 100644 index 0000000..84037c4 --- /dev/null +++ b/api/services/pipecat/realtime/grok_realtime.py @@ -0,0 +1,253 @@ +"""Dograh subclass of pipecat's Grok Realtime LLM service. + +Layers Dograh engine integration quirks onto upstream-pristine +:class:`GrokRealtimeLLMService`. Grok already supports runtime session updates, +so this wrapper stays close to the OpenAI realtime shim. + +Adds: + +- **User-mute audio gating** via ``UserMuteStarted/StoppedFrame``. +- **TTSSpeakFrame as initial-response trigger** so the engine's greeting + flow kicks off the bot's first response. +- **One-off LLMMessagesAppendFrame handling** for ephemeral realtime prompts + like user-idle checks, without mutating Dograh's local ``LLMContext``. +- **Function-call deferral** until the bot finishes speaking, to avoid racing + tool execution with the active audio turn. +- **finalized=True on TranscriptionFrame** for parity with Dograh's other + realtime providers. +""" + +import json +from typing import Any + +from loguru import logger + +from pipecat.frames.frames import ( + BotStartedSpeakingFrame, + BotStoppedSpeakingFrame, + Frame, + LLMFullResponseStartFrame, + LLMMessagesAppendFrame, + TranscriptionFrame, + TTSSpeakFrame, + UserMuteStartedFrame, + UserMuteStoppedFrame, +) +from pipecat.processors.aggregators.llm_context import LLMContext +from pipecat.processors.frame_processor import FrameDirection +from pipecat.services.llm_service import FunctionCallFromLLM +from pipecat.services.xai.realtime import events +from pipecat.services.xai.realtime.llm import GrokRealtimeLLMService +from pipecat.utils.time import time_now_iso8601 + + +class DograhGrokRealtimeLLMService(GrokRealtimeLLMService): + """Grok Realtime with Dograh engine integration quirks.""" + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self._user_is_muted: bool = False + self._handled_initial_context: bool = False + self._bot_is_speaking: bool = False + self._deferred_function_calls: list[FunctionCallFromLLM] = [] + + async def process_frame(self, frame: Frame, direction: FrameDirection): + if isinstance(frame, UserMuteStartedFrame): + self._user_is_muted = True + await self.push_frame(frame, direction) + return + if isinstance(frame, UserMuteStoppedFrame): + self._user_is_muted = False + await self.push_frame(frame, direction) + return + if isinstance(frame, TTSSpeakFrame): + if not self._handled_initial_context: + await self._handle_context(self._context) + else: + logger.warning( + f"{self}: TTSSpeakFrame after initial context already " + "handled — Grok Realtime owns audio generation, ignoring" + ) + return + if isinstance(frame, LLMMessagesAppendFrame): + await self._handle_messages_append(frame) + return + if isinstance(frame, BotStartedSpeakingFrame): + self._bot_is_speaking = True + elif isinstance(frame, BotStoppedSpeakingFrame): + self._bot_is_speaking = False + await self._run_pending_function_calls() + await super().process_frame(frame, direction) + + async def _handle_messages_append(self, frame: LLMMessagesAppendFrame): + """Consume a one-off append frame without mutating the local LLMContext.""" + if self._disconnecting: + return + + if not self._api_session_ready: + if frame.run_llm: + logger.debug( + f"{self}: LLMMessagesAppendFrame received before session ready; " + "deferring response until the session is initialized" + ) + self._run_llm_when_api_session_ready = True + return + + appended_any = False + for message in frame.messages: + item = self._message_to_conversation_item(message) + if item is None: + continue + evt = events.ConversationItemCreateEvent(item=item) + self._messages_added_manually[evt.item.id] = True + await self.send_client_event(evt) + appended_any = True + + if frame.run_llm and appended_any: + await self._send_manual_response_create() + + async def _handle_context(self, context: LLMContext): + if not self._handled_initial_context: + if context is None: + logger.warning( + f"{self}: received initial context trigger before context was set" + ) + return + self._handled_initial_context = True + self._context = context + await self._create_response() + else: + self._context = context + await self._process_completed_function_calls(send_new_results=True) + + async def _send_user_audio(self, frame): + if self._user_is_muted: + return + await super()._send_user_audio(frame) + + def _message_to_conversation_item( + self, message: dict[str, Any] + ) -> events.ConversationItem | None: + if not isinstance(message, dict): + logger.warning( + f"{self}: skipping unsupported appended message payload {message!r}" + ) + return None + + role = message.get("role") + if role not in {"user", "system", "developer"}: + logger.warning( + f"{self}: skipping unsupported appended message role {role!r}" + ) + return None + + text = self._extract_text_content(message.get("content")) + if not text: + logger.warning( + f"{self}: skipping appended message with unsupported content {message!r}" + ) + return None + + item_role = "system" if role in {"system", "developer"} else "user" + return events.ConversationItem( + type="message", + role=item_role, + content=[events.ItemContent(type="input_text", text=text)], + ) + + @staticmethod + def _extract_text_content(content: Any) -> str | None: + if isinstance(content, str): + return content + if isinstance(content, list): + parts: list[str] = [] + for part in content: + if not isinstance(part, dict): + return None + if part.get("type") != "text": + return None + text = part.get("text") + if not isinstance(text, str): + return None + parts.append(text) + return "\n".join(parts) if parts else None + return None + + async def _send_manual_response_create(self): + """Trigger inference after manually appending conversation items.""" + await self.push_frame(LLMFullResponseStartFrame()) + await self.start_processing_metrics() + await self.start_ttfb_metrics() + await self.send_client_event( + events.ResponseCreateEvent( + response=events.ResponseProperties(modalities=["text", "audio"]) + ) + ) + + async def _run_pending_function_calls(self): + if not self._deferred_function_calls: + return + function_calls = self._deferred_function_calls + self._deferred_function_calls = [] + logger.debug( + f"{self}: executing {len(function_calls)} deferred function call(s) " + "after bot turn ended" + ) + await self.run_function_calls(function_calls) + + async def _handle_evt_function_call_arguments_done(self, evt): + """Process or defer tool calls until the bot finishes speaking.""" + try: + args = json.loads(evt.arguments) + + function_call_item = self._pending_function_calls.get(evt.call_id) + if function_call_item: + del self._pending_function_calls[evt.call_id] + + function_name = getattr(evt, "name", None) or function_call_item.name + function_calls = [ + FunctionCallFromLLM( + context=self._context, + tool_call_id=evt.call_id, + function_name=function_name, + arguments=args, + ) + ] + + if self._bot_is_speaking: + self._deferred_function_calls.extend(function_calls) + logger.debug( + f"{self}: deferring function call {function_name} " + "until bot stops speaking" + ) + else: + await self.run_function_calls(function_calls) + logger.debug(f"Processed function call: {function_name}") + else: + logger.warning( + f"No tracked function call found for call_id: {evt.call_id}" + ) + logger.warning( + f"Available pending calls: {list(self._pending_function_calls.keys())}" + ) + + except Exception as e: + logger.error(f"Failed to process function call arguments: {e}") + + async def _handle_evt_input_audio_transcription_completed(self, evt): + await self._call_event_handler( + "on_conversation_item_updated", evt.item_id, None + ) + + transcript = evt.transcript.strip() if evt.transcript else "" + if not transcript: + return + + await self.broadcast_frame( + TranscriptionFrame, + text=transcript, + user_id="", + timestamp=time_now_iso8601(), + result=evt, + finalized=True, + ) diff --git a/api/services/pipecat/run_pipeline.py b/api/services/pipecat/run_pipeline.py index fba97a7..6cae498 100644 --- a/api/services/pipecat/run_pipeline.py +++ b/api/services/pipecat/run_pipeline.py @@ -120,6 +120,16 @@ def _create_realtime_user_turn_config(provider: str): ), None, ) + if provider == ServiceProviders.GROK_REALTIME.value: + # Grok Voice Agent emits server-side speech-start/stop and + # interruption signals, so local VAD should stay out of the way. + return ( + UserTurnStrategies( + start=[ExternalUserTurnStartStrategy()], + stop=[ExternalUserTurnStopStrategy()], + ), + None, + ) return ( UserTurnStrategies( diff --git a/api/services/pipecat/service_factory.py b/api/services/pipecat/service_factory.py index 400fbfd..5ef61cd 100644 --- a/api/services/pipecat/service_factory.py +++ b/api/services/pipecat/service_factory.py @@ -30,7 +30,13 @@ from pipecat.services.gladia.stt import GladiaSTTService, GladiaSTTSettings from pipecat.services.google.llm import GoogleLLMService, GoogleLLMSettings from pipecat.services.google.stt import GoogleSTTService, GoogleSTTSettings from pipecat.services.google.tts import GoogleTTSService, GoogleTTSSettings +from pipecat.services.google.vertex.llm import ( + GoogleVertexLLMService, + GoogleVertexLLMSettings, +) from pipecat.services.groq.llm import GroqLLMService, GroqLLMSettings +from pipecat.services.minimax.llm import MiniMaxLLMService +from pipecat.services.minimax.tts import MiniMaxTTSSettings from pipecat.services.openai.base_llm import OpenAILLMSettings from pipecat.services.openai.llm import OpenAILLMService from pipecat.services.openai.stt import ( @@ -40,8 +46,6 @@ from pipecat.services.openai.stt import ( from pipecat.services.openai.tts import OpenAITTSService, OpenAITTSSettings from pipecat.services.openrouter.llm import OpenRouterLLMService, OpenRouterLLMSettings from pipecat.services.rime.tts import RimeTTSService, RimeTTSSettings -from pipecat.services.minimax.llm import MiniMaxLLMService -from pipecat.services.minimax.tts import MiniMaxHttpTTSService, MiniMaxTTSSettings from pipecat.services.sarvam.stt import SarvamSTTService, SarvamSTTSettings from pipecat.services.sarvam.tts import SarvamTTSService, SarvamTTSSettings from pipecat.services.speaches.llm import SpeachesLLMService, SpeachesLLMSettings @@ -482,13 +486,16 @@ def create_tts_service(user_config, audio_config: "AudioConfig"): def create_llm_service_from_provider( provider: str, model: str, - api_key: str, + api_key: str | None, *, base_url: str | None = None, endpoint: str | None = None, aws_access_key: str | None = None, aws_secret_key: str | None = None, aws_region: str | None = None, + project_id: str | None = None, + location: str | None = None, + credentials: str | None = None, temperature: float | None = None, ): """Create an LLM service from explicit provider/model/api_key. @@ -528,6 +535,13 @@ def create_llm_service_from_provider( api_key=api_key, settings=GoogleLLMSettings(model=model, temperature=0.1), ) + elif provider == ServiceProviders.GOOGLE_VERTEX.value: + return GoogleVertexLLMService( + credentials=credentials, + project_id=project_id, + location=location or "us-east4", + settings=GoogleVertexLLMSettings(model=model, temperature=0.1), + ) elif provider == ServiceProviders.AZURE.value: return AzureLLMService( api_key=api_key, @@ -611,6 +625,21 @@ def create_realtime_llm_service(user_config, audio_config: "AudioConfig"): ), ), ) + elif provider == ServiceProviders.GROK_REALTIME.value: + from api.services.pipecat.realtime.grok_realtime import ( + DograhGrokRealtimeLLMService, + ) + from pipecat.services.xai.realtime.events import SessionProperties + + return DograhGrokRealtimeLLMService( + api_key=api_key, + settings=DograhGrokRealtimeLLMService.Settings( + model=model, + session_properties=SessionProperties( + voice=voice or "Ara", + ), + ), + ) elif provider == ServiceProviders.GOOGLE_REALTIME.value: from api.services.pipecat.realtime.gemini_live import ( DograhGeminiLiveLLMService, @@ -672,6 +701,10 @@ def create_llm_service(user_config): kwargs["aws_access_key"] = user_config.llm.aws_access_key kwargs["aws_secret_key"] = user_config.llm.aws_secret_key kwargs["aws_region"] = user_config.llm.aws_region + elif provider == ServiceProviders.GOOGLE_VERTEX.value: + kwargs["project_id"] = user_config.llm.project_id + kwargs["location"] = user_config.llm.location + kwargs["credentials"] = user_config.llm.credentials elif provider == ServiceProviders.MINIMAX.value: kwargs["base_url"] = user_config.llm.base_url kwargs["temperature"] = user_config.llm.temperature diff --git a/api/tests/test_google_vertex_llm_service_factory.py b/api/tests/test_google_vertex_llm_service_factory.py new file mode 100644 index 0000000..ec8c4ce --- /dev/null +++ b/api/tests/test_google_vertex_llm_service_factory.py @@ -0,0 +1,103 @@ +from types import SimpleNamespace +from unittest.mock import patch + +from api.services.configuration.check_validity import UserConfigurationValidator +from api.services.configuration.registry import ( + REGISTRY, + GoogleVertexLLMConfiguration, + ServiceProviders, + ServiceType, +) +from api.services.pipecat.service_factory import ( + create_llm_service, + create_llm_service_from_provider, +) + + +class TestGoogleVertexLLMConfiguration: + def test_defaults(self): + config = GoogleVertexLLMConfiguration(project_id="demo-project") + assert config.provider == ServiceProviders.GOOGLE_VERTEX + assert config.model == "gemini-2.5-flash" + assert config.location == "us-east4" + assert config.credentials is None + assert config.api_key is None + + def test_registered_in_llm_registry(self): + assert ServiceProviders.GOOGLE_VERTEX in REGISTRY[ServiceType.LLM] + assert ( + REGISTRY[ServiceType.LLM][ServiceProviders.GOOGLE_VERTEX] + is GoogleVertexLLMConfiguration + ) + + +class TestGoogleVertexLLMServiceFactory: + def test_create_llm_service_from_provider_uses_vertex_service(self): + with patch( + "api.services.pipecat.service_factory.GoogleVertexLLMService" + ) as mock_service: + create_llm_service_from_provider( + provider=ServiceProviders.GOOGLE_VERTEX.value, + model="gemini-2.5-pro", + api_key=None, + project_id="demo-project", + location="us-central1", + credentials='{"type":"service_account"}', + ) + + kwargs = mock_service.call_args.kwargs + assert kwargs["project_id"] == "demo-project" + assert kwargs["location"] == "us-central1" + assert kwargs["credentials"] == '{"type":"service_account"}' + assert kwargs["settings"].model == "gemini-2.5-pro" + assert kwargs["settings"].temperature == 0.1 + + def test_create_llm_service_extracts_vertex_credentials(self): + user_config = SimpleNamespace( + llm=SimpleNamespace( + provider=ServiceProviders.GOOGLE_VERTEX.value, + api_key=None, + model="gemini-2.5-flash", + project_id="demo-project", + location="us-east4", + credentials='{"type":"service_account"}', + ) + ) + + with patch( + "api.services.pipecat.service_factory.GoogleVertexLLMService" + ) as mock_service: + create_llm_service(user_config) + + kwargs = mock_service.call_args.kwargs + assert kwargs["project_id"] == "demo-project" + assert kwargs["location"] == "us-east4" + assert kwargs["credentials"] == '{"type":"service_account"}' + + +class TestGoogleVertexLLMValidation: + def test_validator_accepts_vertex_llm_without_api_key(self): + validator = UserConfigurationValidator() + config = GoogleVertexLLMConfiguration( + project_id="demo-project", + location="us-east4", + credentials='{"type":"service_account"}', + ) + + assert validator._validate_service(config, "llm") == [] + + def test_validator_requires_project_id(self): + validator = UserConfigurationValidator() + config = SimpleNamespace( + provider=ServiceProviders.GOOGLE_VERTEX.value, + project_id=None, + location="us-east4", + credentials='{"type":"service_account"}', + api_key=None, + ) + + result = validator._validate_service(config, "llm") + + assert result == [ + {"model": "llm", "message": "project_id is required for Google Vertex"} + ] diff --git a/api/tests/test_grok_realtime_wrapper.py b/api/tests/test_grok_realtime_wrapper.py new file mode 100644 index 0000000..f3cfa1a --- /dev/null +++ b/api/tests/test_grok_realtime_wrapper.py @@ -0,0 +1,138 @@ +from types import SimpleNamespace +from unittest.mock import AsyncMock + +import pytest +from pipecat.frames.frames import LLMMessagesAppendFrame, TTSSpeakFrame +from pipecat.processors.aggregators.llm_context import LLMContext +from pipecat.processors.frame_processor import FrameDirection +from pipecat.services.xai.realtime import events + +from api.schemas.user_configuration import UserConfiguration +from api.services.configuration.registry import GrokRealtimeLLMConfiguration +from api.services.pipecat.realtime.grok_realtime import ( + DograhGrokRealtimeLLMService, +) +from api.services.pipecat.service_factory import create_realtime_llm_service + + +def _make_service() -> DograhGrokRealtimeLLMService: + service = DograhGrokRealtimeLLMService(api_key="test-key") + service._create_response = AsyncMock() + service._process_completed_function_calls = AsyncMock() + return service + + +@pytest.mark.asyncio +async def test_initial_context_triggers_response_when_context_was_prepopulated(): + service = _make_service() + context = LLMContext() + service._context = context + + await service._handle_context(context) + + assert service._handled_initial_context is True + assert service._context is context + service._create_response.assert_awaited_once() + service._process_completed_function_calls.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_tts_greeting_uses_initial_context_handler(): + service = _make_service() + service._context = LLMContext() + service._handle_context = AsyncMock() + + await service.process_frame( + TTSSpeakFrame("hello", append_to_context=True), + FrameDirection.DOWNSTREAM, + ) + + service._handle_context.assert_awaited_once_with(service._context) + + +@pytest.mark.asyncio +async def test_messages_append_frame_sends_conversation_item(): + service = _make_service() + service._api_session_ready = True + service.send_client_event = AsyncMock() + service._send_manual_response_create = AsyncMock() + + await service._handle_messages_append( + LLMMessagesAppendFrame( + [{"role": "user", "content": "Are you still there?"}], + run_llm=True, + ) + ) + + service.send_client_event.assert_awaited_once() + event = service.send_client_event.await_args.args[0] + assert isinstance(event, events.ConversationItemCreateEvent) + assert event.item.role == "user" + assert event.item.type == "message" + assert event.item.content == [ + events.ItemContent(type="input_text", text="Are you still there?") + ] + service._send_manual_response_create.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_function_call_is_deferred_until_bot_stops_speaking(): + service = _make_service() + service._context = LLMContext() + service.run_function_calls = AsyncMock() + service._bot_is_speaking = True + service._pending_function_calls["call-1"] = SimpleNamespace(name="customer_support") + + await service._handle_evt_function_call_arguments_done( + SimpleNamespace( + call_id="call-1", + name="customer_support", + arguments='{"department":"sales"}', + ) + ) + + service.run_function_calls.assert_not_awaited() + assert len(service._deferred_function_calls) == 1 + + await service._run_pending_function_calls() + + service.run_function_calls.assert_awaited_once() + assert service._deferred_function_calls == [] + + +@pytest.mark.asyncio +async def test_completed_input_transcription_is_broadcast_as_finalized(): + service = _make_service() + service._call_event_handler = AsyncMock() + service.broadcast_frame = AsyncMock() + + evt = SimpleNamespace(item_id="item-1", transcript="Hello there") + + await service._handle_evt_input_audio_transcription_completed(evt) + + service._call_event_handler.assert_awaited_once_with( + "on_conversation_item_updated", "item-1", None + ) + service.broadcast_frame.assert_awaited_once() + assert service.broadcast_frame.await_args.args[0].__name__ == "TranscriptionFrame" + assert service.broadcast_frame.await_args.kwargs["text"] == "Hello there" + assert service.broadcast_frame.await_args.kwargs["finalized"] is True + + +def test_factory_creates_dograh_grok_realtime_service(): + user_config = UserConfiguration( + is_realtime=True, + realtime=GrokRealtimeLLMConfiguration( + provider="grok_realtime", + api_key="xai-key", + model="grok-voice-think-fast-1.0", + voice="Sal", + ), + ) + + service = create_realtime_llm_service( + user_config, + audio_config=SimpleNamespace(), + ) + + assert isinstance(service, DograhGrokRealtimeLLMService) diff --git a/api/tests/test_masked_key_rejection.py b/api/tests/test_masked_key_rejection.py index bdb6c6a..c6fdb51 100644 --- a/api/tests/test_masked_key_rejection.py +++ b/api/tests/test_masked_key_rejection.py @@ -9,6 +9,7 @@ from api.services.auth.depends import get_user from api.services.configuration.masking import mask_key from api.services.configuration.registry import ( GoogleLLMService, + GoogleVertexLLMConfiguration, OpenAILLMService, ) @@ -168,3 +169,44 @@ class TestMaskedKeyRejection: # Merge resolves the masked key back to the real one, # so check_for_masked_keys should NOT raise. assert response.status_code == 200 + + def test_allows_same_provider_with_masked_vertex_credentials(self): + """Same provider with masked credentials should succeed.""" + app = _make_test_app() + client = TestClient(app) + + real_credentials = '{"type":"service_account","project_id":"demo-project"}' + masked_credentials = mask_key(real_credentials) + existing = UserConfiguration( + llm=GoogleVertexLLMConfiguration( + provider="google_vertex", + api_key=None, + model="gemini-2.5-flash", + project_id="demo-project", + location="us-east4", + credentials=real_credentials, + ) + ) + + with ( + patch("api.routes.user.db_client") as mock_db, + patch("api.routes.user.UserConfigurationValidator") as mock_validator, + ): + mock_db.get_user_configurations = AsyncMock(return_value=existing) + mock_db.update_user_configuration = AsyncMock(return_value=existing) + mock_validator.return_value.validate = AsyncMock() + + response = client.put( + "/user/configurations/user", + json={ + "llm": { + "provider": "google_vertex", + "model": "gemini-2.5-flash", + "project_id": "demo-project", + "location": "us-east4", + "credentials": masked_credentials, + } + }, + ) + + assert response.status_code == 200 diff --git a/api/tests/test_minimax_service_factory.py b/api/tests/test_minimax_service_factory.py index ecf4676..207e39f 100644 --- a/api/tests/test_minimax_service_factory.py +++ b/api/tests/test_minimax_service_factory.py @@ -109,11 +109,12 @@ class TestMiniMaxTTSServiceFactory: ) audio_config = SimpleNamespace(transport_in_sample_rate=16000) - with patch( - "api.services.pipecat.service_factory.aiohttp.ClientSession" - ), patch( - "api.services.pipecat.service_factory.MiniMaxOwnedSessionTTSService" - ) as mock_service: + with ( + patch("api.services.pipecat.service_factory.aiohttp.ClientSession"), + patch( + "api.services.pipecat.service_factory.MiniMaxOwnedSessionTTSService" + ) as mock_service, + ): create_tts_service(user_config, audio_config) assert mock_service.call_count == 1 diff --git a/api/tests/test_resolve_effective_config.py b/api/tests/test_resolve_effective_config.py index 5d37058..fb7bbd7 100644 --- a/api/tests/test_resolve_effective_config.py +++ b/api/tests/test_resolve_effective_config.py @@ -14,6 +14,8 @@ from api.services.configuration.registry import ( DeepgramSTTConfiguration, ElevenlabsTTSConfiguration, GoogleRealtimeLLMConfiguration, + GoogleVertexLLMConfiguration, + GrokRealtimeLLMConfiguration, OpenAILLMService, ) from api.services.configuration.resolve import resolve_effective_config @@ -164,6 +166,23 @@ class TestProviderChange: assert result.tts.provider == "elevenlabs" assert result.stt.provider == "deepgram" + def test_override_llm_to_google_vertex(self, global_config): + result = resolve_effective_config( + global_config, + { + "llm": { + "provider": "google_vertex", + "model": "gemini-2.5-flash", + "project_id": "demo-project", + "location": "us-east4", + "credentials": '{"type":"service_account"}', + } + }, + ) + assert isinstance(result.llm, GoogleVertexLLMConfiguration) + assert result.llm.provider == "google_vertex" + assert result.llm.project_id == "demo-project" + # --------------------------------------------------------------------------- # API key inheritance @@ -226,6 +245,22 @@ class TestRealtimeOverride: assert result.realtime.provider == "google_realtime" # inherited assert result.realtime.api_key == "goog-global-rt" # inherited + def test_switch_realtime_provider_to_grok(self, global_config_realtime): + result = resolve_effective_config( + global_config_realtime, + { + "realtime": { + "provider": "grok_realtime", + "api_key": "xai-key", + "model": "grok-voice-think-fast-1.0", + "voice": "Sal", + } + }, + ) + assert isinstance(result.realtime, GrokRealtimeLLMConfiguration) + assert result.realtime.provider == "grok_realtime" + assert result.realtime.voice == "Sal" + def test_override_is_realtime_only_without_realtime_section(self, global_config): """Override is_realtime=True but provide no realtime config. Should set the flag; realtime section stays None from global.""" diff --git a/api/tests/test_run_pipeline_realtime_turn_config.py b/api/tests/test_run_pipeline_realtime_turn_config.py index 8f09e1b..0ec07bd 100644 --- a/api/tests/test_run_pipeline_realtime_turn_config.py +++ b/api/tests/test_run_pipeline_realtime_turn_config.py @@ -51,6 +51,19 @@ def test_openai_realtime_uses_provider_turn_frames_without_local_vad(): assert isinstance(strategies.stop[0], ExternalUserTurnStopStrategy) +def test_grok_realtime_uses_provider_turn_frames_without_local_vad(): + strategies, vad_analyzer = _create_realtime_user_turn_config( + ServiceProviders.GROK_REALTIME.value + ) + + assert vad_analyzer is None + assert len(strategies.start) == 1 + assert isinstance(strategies.start[0], ExternalUserTurnStartStrategy) + assert strategies.start[0]._enable_interruptions is False + assert len(strategies.stop) == 1 + assert isinstance(strategies.stop[0], ExternalUserTurnStopStrategy) + + def test_unknown_realtime_providers_keep_local_vad(): strategies, vad_analyzer = _create_realtime_user_turn_config("other_realtime")