feat: add xai grok as realtime model

This commit is contained in:
Abhishek Kumar 2026-05-22 18:04:59 +05:30
parent 291264de7b
commit 9135c2da13
14 changed files with 776 additions and 36 deletions

View file

@ -47,7 +47,9 @@ class UserConfigurationValidator:
ServiceProviders.CAMB.value: self._check_camb_api_key,
ServiceProviders.AWS_BEDROCK.value: self._check_aws_bedrock_api_key,
ServiceProviders.SPEACHES.value: self._check_speaches_api_key,
ServiceProviders.GOOGLE_VERTEX.value: self._check_google_vertex_llm_api_key,
ServiceProviders.OPENAI_REALTIME.value: self._check_openai_api_key,
ServiceProviders.GROK_REALTIME.value: self._check_grok_realtime_api_key,
ServiceProviders.GOOGLE_REALTIME.value: self._check_google_api_key,
ServiceProviders.GOOGLE_VERTEX_REALTIME.value: self._check_google_vertex_realtime_api_key,
ServiceProviders.ASSEMBLYAI.value: self._check_assemblyai_api_key,
@ -134,6 +136,20 @@ class UserConfigurationValidator:
return [{"model": service_name, "message": str(e)}]
return []
# Vertex LLM uses service-account credentials (or ADC) instead of api_key
if provider == ServiceProviders.GOOGLE_VERTEX.value:
try:
if not self._check_google_vertex_llm_api_key(provider, service_config):
return [
{
"model": service_name,
"message": f"Invalid {provider} configuration",
}
]
except ValueError as e:
return [{"model": service_name, "message": str(e)}]
return []
# AWS Bedrock uses AWS credentials instead of api_key
if provider == ServiceProviders.AWS_BEDROCK.value:
try:
@ -236,6 +252,9 @@ class UserConfigurationValidator:
def _check_openrouter_api_key(self, model: str, api_key: str) -> bool:
return True
def _check_grok_realtime_api_key(self, model: str, api_key: str) -> bool:
return True
def _check_speechmatics_api_key(self, model: str, api_key: str) -> bool:
return True
@ -254,6 +273,13 @@ class UserConfigurationValidator:
raise ValueError("location is required for Google Vertex Realtime")
return True
def _check_google_vertex_llm_api_key(self, model: str, service_config) -> bool:
if not getattr(service_config, "project_id", None):
raise ValueError("project_id is required for Google Vertex")
if not getattr(service_config, "location", None):
raise ValueError("location is required for Google Vertex")
return True
def _check_aws_bedrock_api_key(self, model: str, service_config) -> bool:
if not service_config.aws_access_key or not service_config.aws_secret_key:
raise ValueError("AWS access key and secret key are required for Bedrock")

View file

@ -18,27 +18,31 @@ from api.services.integrations import get_node_secret_fields
VISIBLE_CHARS = 4 # number of trailing characters to reveal
MASK_CHAR = "*"
MASK_MARKER = "***" # substring that indicates a masked key
SERVICE_SECRET_FIELDS = ("api_key", "credentials", "aws_access_key", "aws_secret_key")
def contains_masked_key(api_key: str | list[str] | None) -> bool:
"""Return True if *api_key* looks like a masked placeholder."""
if api_key is None:
def contains_masked_key(value: str | list[str] | None) -> bool:
"""Return True if *value* looks like a masked placeholder."""
if value is None:
return False
keys = api_key if isinstance(api_key, list) else [api_key]
keys = value if isinstance(value, list) else [value]
return any(MASK_MARKER in k for k in keys)
def check_for_masked_keys(config: "UserConfiguration") -> None:
"""Raise ValueError if any service in *config* still has a masked API key."""
"""Raise ValueError if any service in *config* still has a masked secret."""
for field in ("llm", "tts", "stt", "embeddings", "realtime"):
service = getattr(config, field, None)
if service is None:
continue
if contains_masked_key(service.get_all_api_keys()):
raise ValueError(
f"The {field} api_key appears to be masked. "
"Please provide the actual API key, not the masked value."
)
for secret_field in SERVICE_SECRET_FIELDS:
if not hasattr(service, secret_field):
continue
if contains_masked_key(getattr(service, secret_field, None)):
raise ValueError(
f"The {field} {secret_field} appears to be masked. "
"Please provide the actual value, not the masked value."
)
def mask_key(real_key: str, visible: int = VISIBLE_CHARS) -> str:
@ -105,12 +109,14 @@ def _mask_service(service_cfg: Optional[ServiceConfig]) -> Optional[Dict[str, An
# Work on a dict copy so we don't mutate original models
data = service_cfg.model_dump()
if "api_key" in data and data["api_key"]:
raw = data["api_key"]
for secret_field in SERVICE_SECRET_FIELDS:
if secret_field not in data or not data[secret_field]:
continue
raw = data[secret_field]
if isinstance(raw, list):
data["api_key"] = [mask_key(k) for k in raw]
data[secret_field] = [mask_key(k) for k in raw]
else:
data["api_key"] = mask_key(raw)
data[secret_field] = mask_key(raw)
return data

View file

@ -7,7 +7,10 @@ stored, while honouring masked API keys.
from typing import Dict
from api.schemas.user_configuration import UserConfiguration
from api.services.configuration.masking import resolve_masked_api_keys
from api.services.configuration.masking import (
SERVICE_SECRET_FIELDS,
resolve_masked_api_keys,
)
SERVICE_FIELDS = ("llm", "tts", "stt", "embeddings", "realtime")
@ -45,18 +48,16 @@ def merge_user_configurations(
and incoming_cfg.get("provider") != old_cfg.get("provider")
)
incoming_api_key = incoming_cfg.get("api_key")
if not provider_changed:
# conditional preservation of api_key
if incoming_api_key is not None:
if old_cfg and "api_key" in old_cfg:
incoming_cfg["api_key"] = resolve_masked_api_keys(
incoming_api_key, old_cfg["api_key"]
)
else:
if "api_key" in old_cfg:
incoming_cfg["api_key"] = old_cfg["api_key"]
for secret_field in SERVICE_SECRET_FIELDS:
incoming_secret = incoming_cfg.get(secret_field)
if incoming_secret is not None:
if old_cfg and secret_field in old_cfg:
incoming_cfg[secret_field] = resolve_masked_api_keys(
incoming_secret, old_cfg[secret_field]
)
elif secret_field in old_cfg:
incoming_cfg[secret_field] = old_cfg[secret_field]
merged[service_name] = incoming_cfg

View file

@ -3,6 +3,14 @@ GOOGLE_MODELS = (
"gemini-2.0-flash-lite",
"gemini-2.5-flash",
"gemini-2.5-flash-lite",
"gemini-3.5-flash",
"gemini-3.5-flash-lite",
)
GOOGLE_VERTEX_MODELS = (
"gemini-2.5-flash",
"gemini-2.5-flash-lite",
"gemini-3.1-flash-lite",
"gemini-3.5-flash",
)
GOOGLE_REALTIME_MODELS = ("gemini-3.1-flash-live-preview",)

View file

@ -28,6 +28,7 @@ from api.services.configuration.options import (
SARVAM_V3_VOICES,
SPEECHMATICS_STT_LANGUAGES,
)
from api.services.configuration.options.google import GOOGLE_VERTEX_MODELS
class ServiceType(Enum):
@ -58,7 +59,9 @@ class ServiceProviders(str, Enum):
GLADIA = "gladia"
RIME = "rime"
MINIMAX = "minimax"
GOOGLE_VERTEX = "google_vertex"
OPENAI_REALTIME = "openai_realtime"
GROK_REALTIME = "grok_realtime"
GOOGLE_REALTIME = "google_realtime"
GOOGLE_VERTEX_REALTIME = "google_vertex_realtime"
@ -79,7 +82,9 @@ class BaseServiceConfiguration(BaseModel):
ServiceProviders.GLADIA,
ServiceProviders.RIME,
ServiceProviders.MINIMAX,
ServiceProviders.GOOGLE_VERTEX,
ServiceProviders.OPENAI_REALTIME,
ServiceProviders.GROK_REALTIME,
ServiceProviders.GOOGLE_REALTIME,
ServiceProviders.GOOGLE_VERTEX_REALTIME,
# ServiceProviders.SARVAM,
@ -206,7 +211,9 @@ OPENROUTER_PROVIDER_MODEL_CONFIG = provider_model_config("Open Router")
AZURE_OPENAI_PROVIDER_MODEL_CONFIG = provider_model_config("Azure OpenAI")
DOGRAH_PROVIDER_MODEL_CONFIG = provider_model_config("Dograh")
AWS_BEDROCK_PROVIDER_MODEL_CONFIG = provider_model_config("AWS Bedrock")
GOOGLE_VERTEX_PROVIDER_MODEL_CONFIG = provider_model_config("Google Vertex")
OPENAI_REALTIME_PROVIDER_MODEL_CONFIG = provider_model_config("OpenAI Realtime")
GROK_REALTIME_PROVIDER_MODEL_CONFIG = provider_model_config("Grok Realtime")
GOOGLE_REALTIME_PROVIDER_MODEL_CONFIG = provider_model_config("Google Realtime")
GOOGLE_VERTEX_REALTIME_PROVIDER_MODEL_CONFIG = provider_model_config(
"Google Vertex Realtime"
@ -239,6 +246,7 @@ OPENAI_MODELS = [
"gpt-5-nano",
"gpt-3.5-turbo",
]
GROQ_MODELS = [
"llama-3.3-70b-versatile",
"deepseek-r1-distill-llama-70b",
@ -292,6 +300,40 @@ class GoogleLLMService(BaseLLMConfiguration):
)
@register_llm
class GoogleVertexLLMConfiguration(BaseLLMConfiguration):
model_config = GOOGLE_VERTEX_PROVIDER_MODEL_CONFIG
provider: Literal[ServiceProviders.GOOGLE_VERTEX] = ServiceProviders.GOOGLE_VERTEX
model: str = Field(
default="gemini-2.5-flash",
description="Gemini model on Vertex AI.",
json_schema_extra={
"examples": GOOGLE_VERTEX_MODELS,
"allow_custom_input": True,
},
)
project_id: str = Field(description="Google Cloud project ID for Vertex AI.")
location: str = Field(
default="global",
description="GCP region for the Vertex AI endpoint (e.g. 'global').",
)
credentials: str | None = Field(
default=None,
description=(
"Paste the entire service-account JSON file contents. If omitted, "
"falls back to Application Default Credentials (ADC)."
),
json_schema_extra={"multiline": True},
)
api_key: str | list[str] | None = Field(
default=None,
description=(
"Not used for Vertex AI — authentication is via the service account "
"in `credentials` (or ADC). Leave blank."
),
)
@register_llm
class GroqLLMService(BaseLLMConfiguration):
model_config = GROQ_PROVIDER_MODEL_CONFIG
@ -460,6 +502,32 @@ class OpenAIRealtimeLLMConfiguration(BaseLLMConfiguration):
)
GROK_REALTIME_MODELS = ["grok-voice-think-fast-1.0"]
GROK_REALTIME_VOICES = ["Ara", "Rex", "Sal", "Eve", "Leo"]
@register_service(ServiceType.REALTIME)
class GrokRealtimeLLMConfiguration(BaseLLMConfiguration):
model_config = GROK_REALTIME_PROVIDER_MODEL_CONFIG
provider: Literal[ServiceProviders.GROK_REALTIME] = ServiceProviders.GROK_REALTIME
model: str = Field(
default="grok-voice-think-fast-1.0",
description="Grok realtime voice-agent model.",
json_schema_extra={
"examples": GROK_REALTIME_MODELS,
"allow_custom_input": True,
},
)
voice: str = Field(
default="Ara",
description="Voice the model speaks in.",
json_schema_extra={
"examples": GROK_REALTIME_VOICES,
"allow_custom_input": True,
},
)
@register_service(ServiceType.REALTIME)
class GoogleRealtimeLLMConfiguration(BaseLLMConfiguration):
model_config = GOOGLE_REALTIME_PROVIDER_MODEL_CONFIG
@ -524,8 +592,8 @@ class GoogleVertexRealtimeLLMConfiguration(BaseLLMConfiguration):
)
project_id: str = Field(description="Google Cloud project ID for Vertex AI.")
location: str = Field(
default="us-east4",
description="GCP region for the Vertex AI endpoint (e.g. 'us-east4').",
default="global",
description="GCP region for the Vertex AI endpoint (e.g. 'global').",
)
credentials: str | None = Field(
default=None,
@ -546,6 +614,7 @@ class GoogleVertexRealtimeLLMConfiguration(BaseLLMConfiguration):
REALTIME_PROVIDERS = {
ServiceProviders.OPENAI_REALTIME.value,
ServiceProviders.GROK_REALTIME.value,
ServiceProviders.GOOGLE_REALTIME.value,
ServiceProviders.GOOGLE_VERTEX_REALTIME.value,
}
@ -554,6 +623,7 @@ REALTIME_PROVIDERS = {
LLMConfig = Annotated[
Union[
OpenAILLMService,
GoogleVertexLLMConfiguration,
GroqLLMService,
OpenRouterLLMConfiguration,
GoogleLLMService,
@ -569,6 +639,7 @@ LLMConfig = Annotated[
RealtimeConfig = Annotated[
Union[
OpenAIRealtimeLLMConfiguration,
GrokRealtimeLLMConfiguration,
GoogleRealtimeLLMConfiguration,
GoogleVertexRealtimeLLMConfiguration,
],

View file

@ -0,0 +1,253 @@
"""Dograh subclass of pipecat's Grok Realtime LLM service.
Layers Dograh engine integration quirks onto upstream-pristine
:class:`GrokRealtimeLLMService`. Grok already supports runtime session updates,
so this wrapper stays close to the OpenAI realtime shim.
Adds:
- **User-mute audio gating** via ``UserMuteStarted/StoppedFrame``.
- **TTSSpeakFrame as initial-response trigger** so the engine's greeting
flow kicks off the bot's first response.
- **One-off LLMMessagesAppendFrame handling** for ephemeral realtime prompts
like user-idle checks, without mutating Dograh's local ``LLMContext``.
- **Function-call deferral** until the bot finishes speaking, to avoid racing
tool execution with the active audio turn.
- **finalized=True on TranscriptionFrame** for parity with Dograh's other
realtime providers.
"""
import json
from typing import Any
from loguru import logger
from pipecat.frames.frames import (
BotStartedSpeakingFrame,
BotStoppedSpeakingFrame,
Frame,
LLMFullResponseStartFrame,
LLMMessagesAppendFrame,
TranscriptionFrame,
TTSSpeakFrame,
UserMuteStartedFrame,
UserMuteStoppedFrame,
)
from pipecat.processors.aggregators.llm_context import LLMContext
from pipecat.processors.frame_processor import FrameDirection
from pipecat.services.llm_service import FunctionCallFromLLM
from pipecat.services.xai.realtime import events
from pipecat.services.xai.realtime.llm import GrokRealtimeLLMService
from pipecat.utils.time import time_now_iso8601
class DograhGrokRealtimeLLMService(GrokRealtimeLLMService):
"""Grok Realtime with Dograh engine integration quirks."""
def __init__(self, **kwargs):
super().__init__(**kwargs)
self._user_is_muted: bool = False
self._handled_initial_context: bool = False
self._bot_is_speaking: bool = False
self._deferred_function_calls: list[FunctionCallFromLLM] = []
async def process_frame(self, frame: Frame, direction: FrameDirection):
if isinstance(frame, UserMuteStartedFrame):
self._user_is_muted = True
await self.push_frame(frame, direction)
return
if isinstance(frame, UserMuteStoppedFrame):
self._user_is_muted = False
await self.push_frame(frame, direction)
return
if isinstance(frame, TTSSpeakFrame):
if not self._handled_initial_context:
await self._handle_context(self._context)
else:
logger.warning(
f"{self}: TTSSpeakFrame after initial context already "
"handled — Grok Realtime owns audio generation, ignoring"
)
return
if isinstance(frame, LLMMessagesAppendFrame):
await self._handle_messages_append(frame)
return
if isinstance(frame, BotStartedSpeakingFrame):
self._bot_is_speaking = True
elif isinstance(frame, BotStoppedSpeakingFrame):
self._bot_is_speaking = False
await self._run_pending_function_calls()
await super().process_frame(frame, direction)
async def _handle_messages_append(self, frame: LLMMessagesAppendFrame):
"""Consume a one-off append frame without mutating the local LLMContext."""
if self._disconnecting:
return
if not self._api_session_ready:
if frame.run_llm:
logger.debug(
f"{self}: LLMMessagesAppendFrame received before session ready; "
"deferring response until the session is initialized"
)
self._run_llm_when_api_session_ready = True
return
appended_any = False
for message in frame.messages:
item = self._message_to_conversation_item(message)
if item is None:
continue
evt = events.ConversationItemCreateEvent(item=item)
self._messages_added_manually[evt.item.id] = True
await self.send_client_event(evt)
appended_any = True
if frame.run_llm and appended_any:
await self._send_manual_response_create()
async def _handle_context(self, context: LLMContext):
if not self._handled_initial_context:
if context is None:
logger.warning(
f"{self}: received initial context trigger before context was set"
)
return
self._handled_initial_context = True
self._context = context
await self._create_response()
else:
self._context = context
await self._process_completed_function_calls(send_new_results=True)
async def _send_user_audio(self, frame):
if self._user_is_muted:
return
await super()._send_user_audio(frame)
def _message_to_conversation_item(
self, message: dict[str, Any]
) -> events.ConversationItem | None:
if not isinstance(message, dict):
logger.warning(
f"{self}: skipping unsupported appended message payload {message!r}"
)
return None
role = message.get("role")
if role not in {"user", "system", "developer"}:
logger.warning(
f"{self}: skipping unsupported appended message role {role!r}"
)
return None
text = self._extract_text_content(message.get("content"))
if not text:
logger.warning(
f"{self}: skipping appended message with unsupported content {message!r}"
)
return None
item_role = "system" if role in {"system", "developer"} else "user"
return events.ConversationItem(
type="message",
role=item_role,
content=[events.ItemContent(type="input_text", text=text)],
)
@staticmethod
def _extract_text_content(content: Any) -> str | None:
if isinstance(content, str):
return content
if isinstance(content, list):
parts: list[str] = []
for part in content:
if not isinstance(part, dict):
return None
if part.get("type") != "text":
return None
text = part.get("text")
if not isinstance(text, str):
return None
parts.append(text)
return "\n".join(parts) if parts else None
return None
async def _send_manual_response_create(self):
"""Trigger inference after manually appending conversation items."""
await self.push_frame(LLMFullResponseStartFrame())
await self.start_processing_metrics()
await self.start_ttfb_metrics()
await self.send_client_event(
events.ResponseCreateEvent(
response=events.ResponseProperties(modalities=["text", "audio"])
)
)
async def _run_pending_function_calls(self):
if not self._deferred_function_calls:
return
function_calls = self._deferred_function_calls
self._deferred_function_calls = []
logger.debug(
f"{self}: executing {len(function_calls)} deferred function call(s) "
"after bot turn ended"
)
await self.run_function_calls(function_calls)
async def _handle_evt_function_call_arguments_done(self, evt):
"""Process or defer tool calls until the bot finishes speaking."""
try:
args = json.loads(evt.arguments)
function_call_item = self._pending_function_calls.get(evt.call_id)
if function_call_item:
del self._pending_function_calls[evt.call_id]
function_name = getattr(evt, "name", None) or function_call_item.name
function_calls = [
FunctionCallFromLLM(
context=self._context,
tool_call_id=evt.call_id,
function_name=function_name,
arguments=args,
)
]
if self._bot_is_speaking:
self._deferred_function_calls.extend(function_calls)
logger.debug(
f"{self}: deferring function call {function_name} "
"until bot stops speaking"
)
else:
await self.run_function_calls(function_calls)
logger.debug(f"Processed function call: {function_name}")
else:
logger.warning(
f"No tracked function call found for call_id: {evt.call_id}"
)
logger.warning(
f"Available pending calls: {list(self._pending_function_calls.keys())}"
)
except Exception as e:
logger.error(f"Failed to process function call arguments: {e}")
async def _handle_evt_input_audio_transcription_completed(self, evt):
await self._call_event_handler(
"on_conversation_item_updated", evt.item_id, None
)
transcript = evt.transcript.strip() if evt.transcript else ""
if not transcript:
return
await self.broadcast_frame(
TranscriptionFrame,
text=transcript,
user_id="",
timestamp=time_now_iso8601(),
result=evt,
finalized=True,
)

View file

@ -120,6 +120,16 @@ def _create_realtime_user_turn_config(provider: str):
),
None,
)
if provider == ServiceProviders.GROK_REALTIME.value:
# Grok Voice Agent emits server-side speech-start/stop and
# interruption signals, so local VAD should stay out of the way.
return (
UserTurnStrategies(
start=[ExternalUserTurnStartStrategy()],
stop=[ExternalUserTurnStopStrategy()],
),
None,
)
return (
UserTurnStrategies(

View file

@ -30,7 +30,13 @@ from pipecat.services.gladia.stt import GladiaSTTService, GladiaSTTSettings
from pipecat.services.google.llm import GoogleLLMService, GoogleLLMSettings
from pipecat.services.google.stt import GoogleSTTService, GoogleSTTSettings
from pipecat.services.google.tts import GoogleTTSService, GoogleTTSSettings
from pipecat.services.google.vertex.llm import (
GoogleVertexLLMService,
GoogleVertexLLMSettings,
)
from pipecat.services.groq.llm import GroqLLMService, GroqLLMSettings
from pipecat.services.minimax.llm import MiniMaxLLMService
from pipecat.services.minimax.tts import MiniMaxTTSSettings
from pipecat.services.openai.base_llm import OpenAILLMSettings
from pipecat.services.openai.llm import OpenAILLMService
from pipecat.services.openai.stt import (
@ -40,8 +46,6 @@ from pipecat.services.openai.stt import (
from pipecat.services.openai.tts import OpenAITTSService, OpenAITTSSettings
from pipecat.services.openrouter.llm import OpenRouterLLMService, OpenRouterLLMSettings
from pipecat.services.rime.tts import RimeTTSService, RimeTTSSettings
from pipecat.services.minimax.llm import MiniMaxLLMService
from pipecat.services.minimax.tts import MiniMaxHttpTTSService, MiniMaxTTSSettings
from pipecat.services.sarvam.stt import SarvamSTTService, SarvamSTTSettings
from pipecat.services.sarvam.tts import SarvamTTSService, SarvamTTSSettings
from pipecat.services.speaches.llm import SpeachesLLMService, SpeachesLLMSettings
@ -482,13 +486,16 @@ def create_tts_service(user_config, audio_config: "AudioConfig"):
def create_llm_service_from_provider(
provider: str,
model: str,
api_key: str,
api_key: str | None,
*,
base_url: str | None = None,
endpoint: str | None = None,
aws_access_key: str | None = None,
aws_secret_key: str | None = None,
aws_region: str | None = None,
project_id: str | None = None,
location: str | None = None,
credentials: str | None = None,
temperature: float | None = None,
):
"""Create an LLM service from explicit provider/model/api_key.
@ -528,6 +535,13 @@ def create_llm_service_from_provider(
api_key=api_key,
settings=GoogleLLMSettings(model=model, temperature=0.1),
)
elif provider == ServiceProviders.GOOGLE_VERTEX.value:
return GoogleVertexLLMService(
credentials=credentials,
project_id=project_id,
location=location or "us-east4",
settings=GoogleVertexLLMSettings(model=model, temperature=0.1),
)
elif provider == ServiceProviders.AZURE.value:
return AzureLLMService(
api_key=api_key,
@ -611,6 +625,21 @@ def create_realtime_llm_service(user_config, audio_config: "AudioConfig"):
),
),
)
elif provider == ServiceProviders.GROK_REALTIME.value:
from api.services.pipecat.realtime.grok_realtime import (
DograhGrokRealtimeLLMService,
)
from pipecat.services.xai.realtime.events import SessionProperties
return DograhGrokRealtimeLLMService(
api_key=api_key,
settings=DograhGrokRealtimeLLMService.Settings(
model=model,
session_properties=SessionProperties(
voice=voice or "Ara",
),
),
)
elif provider == ServiceProviders.GOOGLE_REALTIME.value:
from api.services.pipecat.realtime.gemini_live import (
DograhGeminiLiveLLMService,
@ -672,6 +701,10 @@ def create_llm_service(user_config):
kwargs["aws_access_key"] = user_config.llm.aws_access_key
kwargs["aws_secret_key"] = user_config.llm.aws_secret_key
kwargs["aws_region"] = user_config.llm.aws_region
elif provider == ServiceProviders.GOOGLE_VERTEX.value:
kwargs["project_id"] = user_config.llm.project_id
kwargs["location"] = user_config.llm.location
kwargs["credentials"] = user_config.llm.credentials
elif provider == ServiceProviders.MINIMAX.value:
kwargs["base_url"] = user_config.llm.base_url
kwargs["temperature"] = user_config.llm.temperature

View file

@ -0,0 +1,103 @@
from types import SimpleNamespace
from unittest.mock import patch
from api.services.configuration.check_validity import UserConfigurationValidator
from api.services.configuration.registry import (
REGISTRY,
GoogleVertexLLMConfiguration,
ServiceProviders,
ServiceType,
)
from api.services.pipecat.service_factory import (
create_llm_service,
create_llm_service_from_provider,
)
class TestGoogleVertexLLMConfiguration:
def test_defaults(self):
config = GoogleVertexLLMConfiguration(project_id="demo-project")
assert config.provider == ServiceProviders.GOOGLE_VERTEX
assert config.model == "gemini-2.5-flash"
assert config.location == "us-east4"
assert config.credentials is None
assert config.api_key is None
def test_registered_in_llm_registry(self):
assert ServiceProviders.GOOGLE_VERTEX in REGISTRY[ServiceType.LLM]
assert (
REGISTRY[ServiceType.LLM][ServiceProviders.GOOGLE_VERTEX]
is GoogleVertexLLMConfiguration
)
class TestGoogleVertexLLMServiceFactory:
def test_create_llm_service_from_provider_uses_vertex_service(self):
with patch(
"api.services.pipecat.service_factory.GoogleVertexLLMService"
) as mock_service:
create_llm_service_from_provider(
provider=ServiceProviders.GOOGLE_VERTEX.value,
model="gemini-2.5-pro",
api_key=None,
project_id="demo-project",
location="us-central1",
credentials='{"type":"service_account"}',
)
kwargs = mock_service.call_args.kwargs
assert kwargs["project_id"] == "demo-project"
assert kwargs["location"] == "us-central1"
assert kwargs["credentials"] == '{"type":"service_account"}'
assert kwargs["settings"].model == "gemini-2.5-pro"
assert kwargs["settings"].temperature == 0.1
def test_create_llm_service_extracts_vertex_credentials(self):
user_config = SimpleNamespace(
llm=SimpleNamespace(
provider=ServiceProviders.GOOGLE_VERTEX.value,
api_key=None,
model="gemini-2.5-flash",
project_id="demo-project",
location="us-east4",
credentials='{"type":"service_account"}',
)
)
with patch(
"api.services.pipecat.service_factory.GoogleVertexLLMService"
) as mock_service:
create_llm_service(user_config)
kwargs = mock_service.call_args.kwargs
assert kwargs["project_id"] == "demo-project"
assert kwargs["location"] == "us-east4"
assert kwargs["credentials"] == '{"type":"service_account"}'
class TestGoogleVertexLLMValidation:
def test_validator_accepts_vertex_llm_without_api_key(self):
validator = UserConfigurationValidator()
config = GoogleVertexLLMConfiguration(
project_id="demo-project",
location="us-east4",
credentials='{"type":"service_account"}',
)
assert validator._validate_service(config, "llm") == []
def test_validator_requires_project_id(self):
validator = UserConfigurationValidator()
config = SimpleNamespace(
provider=ServiceProviders.GOOGLE_VERTEX.value,
project_id=None,
location="us-east4",
credentials='{"type":"service_account"}',
api_key=None,
)
result = validator._validate_service(config, "llm")
assert result == [
{"model": "llm", "message": "project_id is required for Google Vertex"}
]

View file

@ -0,0 +1,138 @@
from types import SimpleNamespace
from unittest.mock import AsyncMock
import pytest
from pipecat.frames.frames import LLMMessagesAppendFrame, TTSSpeakFrame
from pipecat.processors.aggregators.llm_context import LLMContext
from pipecat.processors.frame_processor import FrameDirection
from pipecat.services.xai.realtime import events
from api.schemas.user_configuration import UserConfiguration
from api.services.configuration.registry import GrokRealtimeLLMConfiguration
from api.services.pipecat.realtime.grok_realtime import (
DograhGrokRealtimeLLMService,
)
from api.services.pipecat.service_factory import create_realtime_llm_service
def _make_service() -> DograhGrokRealtimeLLMService:
service = DograhGrokRealtimeLLMService(api_key="test-key")
service._create_response = AsyncMock()
service._process_completed_function_calls = AsyncMock()
return service
@pytest.mark.asyncio
async def test_initial_context_triggers_response_when_context_was_prepopulated():
service = _make_service()
context = LLMContext()
service._context = context
await service._handle_context(context)
assert service._handled_initial_context is True
assert service._context is context
service._create_response.assert_awaited_once()
service._process_completed_function_calls.assert_not_awaited()
@pytest.mark.asyncio
async def test_tts_greeting_uses_initial_context_handler():
service = _make_service()
service._context = LLMContext()
service._handle_context = AsyncMock()
await service.process_frame(
TTSSpeakFrame("hello", append_to_context=True),
FrameDirection.DOWNSTREAM,
)
service._handle_context.assert_awaited_once_with(service._context)
@pytest.mark.asyncio
async def test_messages_append_frame_sends_conversation_item():
service = _make_service()
service._api_session_ready = True
service.send_client_event = AsyncMock()
service._send_manual_response_create = AsyncMock()
await service._handle_messages_append(
LLMMessagesAppendFrame(
[{"role": "user", "content": "Are you still there?"}],
run_llm=True,
)
)
service.send_client_event.assert_awaited_once()
event = service.send_client_event.await_args.args[0]
assert isinstance(event, events.ConversationItemCreateEvent)
assert event.item.role == "user"
assert event.item.type == "message"
assert event.item.content == [
events.ItemContent(type="input_text", text="Are you still there?")
]
service._send_manual_response_create.assert_awaited_once()
@pytest.mark.asyncio
async def test_function_call_is_deferred_until_bot_stops_speaking():
service = _make_service()
service._context = LLMContext()
service.run_function_calls = AsyncMock()
service._bot_is_speaking = True
service._pending_function_calls["call-1"] = SimpleNamespace(name="customer_support")
await service._handle_evt_function_call_arguments_done(
SimpleNamespace(
call_id="call-1",
name="customer_support",
arguments='{"department":"sales"}',
)
)
service.run_function_calls.assert_not_awaited()
assert len(service._deferred_function_calls) == 1
await service._run_pending_function_calls()
service.run_function_calls.assert_awaited_once()
assert service._deferred_function_calls == []
@pytest.mark.asyncio
async def test_completed_input_transcription_is_broadcast_as_finalized():
service = _make_service()
service._call_event_handler = AsyncMock()
service.broadcast_frame = AsyncMock()
evt = SimpleNamespace(item_id="item-1", transcript="Hello there")
await service._handle_evt_input_audio_transcription_completed(evt)
service._call_event_handler.assert_awaited_once_with(
"on_conversation_item_updated", "item-1", None
)
service.broadcast_frame.assert_awaited_once()
assert service.broadcast_frame.await_args.args[0].__name__ == "TranscriptionFrame"
assert service.broadcast_frame.await_args.kwargs["text"] == "Hello there"
assert service.broadcast_frame.await_args.kwargs["finalized"] is True
def test_factory_creates_dograh_grok_realtime_service():
user_config = UserConfiguration(
is_realtime=True,
realtime=GrokRealtimeLLMConfiguration(
provider="grok_realtime",
api_key="xai-key",
model="grok-voice-think-fast-1.0",
voice="Sal",
),
)
service = create_realtime_llm_service(
user_config,
audio_config=SimpleNamespace(),
)
assert isinstance(service, DograhGrokRealtimeLLMService)

View file

@ -9,6 +9,7 @@ from api.services.auth.depends import get_user
from api.services.configuration.masking import mask_key
from api.services.configuration.registry import (
GoogleLLMService,
GoogleVertexLLMConfiguration,
OpenAILLMService,
)
@ -168,3 +169,44 @@ class TestMaskedKeyRejection:
# Merge resolves the masked key back to the real one,
# so check_for_masked_keys should NOT raise.
assert response.status_code == 200
def test_allows_same_provider_with_masked_vertex_credentials(self):
"""Same provider with masked credentials should succeed."""
app = _make_test_app()
client = TestClient(app)
real_credentials = '{"type":"service_account","project_id":"demo-project"}'
masked_credentials = mask_key(real_credentials)
existing = UserConfiguration(
llm=GoogleVertexLLMConfiguration(
provider="google_vertex",
api_key=None,
model="gemini-2.5-flash",
project_id="demo-project",
location="us-east4",
credentials=real_credentials,
)
)
with (
patch("api.routes.user.db_client") as mock_db,
patch("api.routes.user.UserConfigurationValidator") as mock_validator,
):
mock_db.get_user_configurations = AsyncMock(return_value=existing)
mock_db.update_user_configuration = AsyncMock(return_value=existing)
mock_validator.return_value.validate = AsyncMock()
response = client.put(
"/user/configurations/user",
json={
"llm": {
"provider": "google_vertex",
"model": "gemini-2.5-flash",
"project_id": "demo-project",
"location": "us-east4",
"credentials": masked_credentials,
}
},
)
assert response.status_code == 200

View file

@ -109,11 +109,12 @@ class TestMiniMaxTTSServiceFactory:
)
audio_config = SimpleNamespace(transport_in_sample_rate=16000)
with patch(
"api.services.pipecat.service_factory.aiohttp.ClientSession"
), patch(
"api.services.pipecat.service_factory.MiniMaxOwnedSessionTTSService"
) as mock_service:
with (
patch("api.services.pipecat.service_factory.aiohttp.ClientSession"),
patch(
"api.services.pipecat.service_factory.MiniMaxOwnedSessionTTSService"
) as mock_service,
):
create_tts_service(user_config, audio_config)
assert mock_service.call_count == 1

View file

@ -14,6 +14,8 @@ from api.services.configuration.registry import (
DeepgramSTTConfiguration,
ElevenlabsTTSConfiguration,
GoogleRealtimeLLMConfiguration,
GoogleVertexLLMConfiguration,
GrokRealtimeLLMConfiguration,
OpenAILLMService,
)
from api.services.configuration.resolve import resolve_effective_config
@ -164,6 +166,23 @@ class TestProviderChange:
assert result.tts.provider == "elevenlabs"
assert result.stt.provider == "deepgram"
def test_override_llm_to_google_vertex(self, global_config):
result = resolve_effective_config(
global_config,
{
"llm": {
"provider": "google_vertex",
"model": "gemini-2.5-flash",
"project_id": "demo-project",
"location": "us-east4",
"credentials": '{"type":"service_account"}',
}
},
)
assert isinstance(result.llm, GoogleVertexLLMConfiguration)
assert result.llm.provider == "google_vertex"
assert result.llm.project_id == "demo-project"
# ---------------------------------------------------------------------------
# API key inheritance
@ -226,6 +245,22 @@ class TestRealtimeOverride:
assert result.realtime.provider == "google_realtime" # inherited
assert result.realtime.api_key == "goog-global-rt" # inherited
def test_switch_realtime_provider_to_grok(self, global_config_realtime):
result = resolve_effective_config(
global_config_realtime,
{
"realtime": {
"provider": "grok_realtime",
"api_key": "xai-key",
"model": "grok-voice-think-fast-1.0",
"voice": "Sal",
}
},
)
assert isinstance(result.realtime, GrokRealtimeLLMConfiguration)
assert result.realtime.provider == "grok_realtime"
assert result.realtime.voice == "Sal"
def test_override_is_realtime_only_without_realtime_section(self, global_config):
"""Override is_realtime=True but provide no realtime config.
Should set the flag; realtime section stays None from global."""

View file

@ -51,6 +51,19 @@ def test_openai_realtime_uses_provider_turn_frames_without_local_vad():
assert isinstance(strategies.stop[0], ExternalUserTurnStopStrategy)
def test_grok_realtime_uses_provider_turn_frames_without_local_vad():
strategies, vad_analyzer = _create_realtime_user_turn_config(
ServiceProviders.GROK_REALTIME.value
)
assert vad_analyzer is None
assert len(strategies.start) == 1
assert isinstance(strategies.start[0], ExternalUserTurnStartStrategy)
assert strategies.start[0]._enable_interruptions is False
assert len(strategies.stop) == 1
assert isinstance(strategies.stop[0], ExternalUserTurnStopStrategy)
def test_unknown_realtime_providers_keep_local_vad():
strategies, vad_analyzer = _create_realtime_user_turn_config("other_realtime")