mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-06-07 07:55:16 +02:00
feat: add xai grok as realtime model
This commit is contained in:
parent
291264de7b
commit
9135c2da13
14 changed files with 776 additions and 36 deletions
|
|
@ -47,7 +47,9 @@ class UserConfigurationValidator:
|
|||
ServiceProviders.CAMB.value: self._check_camb_api_key,
|
||||
ServiceProviders.AWS_BEDROCK.value: self._check_aws_bedrock_api_key,
|
||||
ServiceProviders.SPEACHES.value: self._check_speaches_api_key,
|
||||
ServiceProviders.GOOGLE_VERTEX.value: self._check_google_vertex_llm_api_key,
|
||||
ServiceProviders.OPENAI_REALTIME.value: self._check_openai_api_key,
|
||||
ServiceProviders.GROK_REALTIME.value: self._check_grok_realtime_api_key,
|
||||
ServiceProviders.GOOGLE_REALTIME.value: self._check_google_api_key,
|
||||
ServiceProviders.GOOGLE_VERTEX_REALTIME.value: self._check_google_vertex_realtime_api_key,
|
||||
ServiceProviders.ASSEMBLYAI.value: self._check_assemblyai_api_key,
|
||||
|
|
@ -134,6 +136,20 @@ class UserConfigurationValidator:
|
|||
return [{"model": service_name, "message": str(e)}]
|
||||
return []
|
||||
|
||||
# Vertex LLM uses service-account credentials (or ADC) instead of api_key
|
||||
if provider == ServiceProviders.GOOGLE_VERTEX.value:
|
||||
try:
|
||||
if not self._check_google_vertex_llm_api_key(provider, service_config):
|
||||
return [
|
||||
{
|
||||
"model": service_name,
|
||||
"message": f"Invalid {provider} configuration",
|
||||
}
|
||||
]
|
||||
except ValueError as e:
|
||||
return [{"model": service_name, "message": str(e)}]
|
||||
return []
|
||||
|
||||
# AWS Bedrock uses AWS credentials instead of api_key
|
||||
if provider == ServiceProviders.AWS_BEDROCK.value:
|
||||
try:
|
||||
|
|
@ -236,6 +252,9 @@ class UserConfigurationValidator:
|
|||
def _check_openrouter_api_key(self, model: str, api_key: str) -> bool:
|
||||
return True
|
||||
|
||||
def _check_grok_realtime_api_key(self, model: str, api_key: str) -> bool:
|
||||
return True
|
||||
|
||||
def _check_speechmatics_api_key(self, model: str, api_key: str) -> bool:
|
||||
return True
|
||||
|
||||
|
|
@ -254,6 +273,13 @@ class UserConfigurationValidator:
|
|||
raise ValueError("location is required for Google Vertex Realtime")
|
||||
return True
|
||||
|
||||
def _check_google_vertex_llm_api_key(self, model: str, service_config) -> bool:
|
||||
if not getattr(service_config, "project_id", None):
|
||||
raise ValueError("project_id is required for Google Vertex")
|
||||
if not getattr(service_config, "location", None):
|
||||
raise ValueError("location is required for Google Vertex")
|
||||
return True
|
||||
|
||||
def _check_aws_bedrock_api_key(self, model: str, service_config) -> bool:
|
||||
if not service_config.aws_access_key or not service_config.aws_secret_key:
|
||||
raise ValueError("AWS access key and secret key are required for Bedrock")
|
||||
|
|
|
|||
|
|
@ -18,27 +18,31 @@ from api.services.integrations import get_node_secret_fields
|
|||
VISIBLE_CHARS = 4 # number of trailing characters to reveal
|
||||
MASK_CHAR = "*"
|
||||
MASK_MARKER = "***" # substring that indicates a masked key
|
||||
SERVICE_SECRET_FIELDS = ("api_key", "credentials", "aws_access_key", "aws_secret_key")
|
||||
|
||||
|
||||
def contains_masked_key(api_key: str | list[str] | None) -> bool:
|
||||
"""Return True if *api_key* looks like a masked placeholder."""
|
||||
if api_key is None:
|
||||
def contains_masked_key(value: str | list[str] | None) -> bool:
|
||||
"""Return True if *value* looks like a masked placeholder."""
|
||||
if value is None:
|
||||
return False
|
||||
keys = api_key if isinstance(api_key, list) else [api_key]
|
||||
keys = value if isinstance(value, list) else [value]
|
||||
return any(MASK_MARKER in k for k in keys)
|
||||
|
||||
|
||||
def check_for_masked_keys(config: "UserConfiguration") -> None:
|
||||
"""Raise ValueError if any service in *config* still has a masked API key."""
|
||||
"""Raise ValueError if any service in *config* still has a masked secret."""
|
||||
for field in ("llm", "tts", "stt", "embeddings", "realtime"):
|
||||
service = getattr(config, field, None)
|
||||
if service is None:
|
||||
continue
|
||||
if contains_masked_key(service.get_all_api_keys()):
|
||||
raise ValueError(
|
||||
f"The {field} api_key appears to be masked. "
|
||||
"Please provide the actual API key, not the masked value."
|
||||
)
|
||||
for secret_field in SERVICE_SECRET_FIELDS:
|
||||
if not hasattr(service, secret_field):
|
||||
continue
|
||||
if contains_masked_key(getattr(service, secret_field, None)):
|
||||
raise ValueError(
|
||||
f"The {field} {secret_field} appears to be masked. "
|
||||
"Please provide the actual value, not the masked value."
|
||||
)
|
||||
|
||||
|
||||
def mask_key(real_key: str, visible: int = VISIBLE_CHARS) -> str:
|
||||
|
|
@ -105,12 +109,14 @@ def _mask_service(service_cfg: Optional[ServiceConfig]) -> Optional[Dict[str, An
|
|||
|
||||
# Work on a dict copy so we don't mutate original models
|
||||
data = service_cfg.model_dump()
|
||||
if "api_key" in data and data["api_key"]:
|
||||
raw = data["api_key"]
|
||||
for secret_field in SERVICE_SECRET_FIELDS:
|
||||
if secret_field not in data or not data[secret_field]:
|
||||
continue
|
||||
raw = data[secret_field]
|
||||
if isinstance(raw, list):
|
||||
data["api_key"] = [mask_key(k) for k in raw]
|
||||
data[secret_field] = [mask_key(k) for k in raw]
|
||||
else:
|
||||
data["api_key"] = mask_key(raw)
|
||||
data[secret_field] = mask_key(raw)
|
||||
return data
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -7,7 +7,10 @@ stored, while honouring masked API keys.
|
|||
from typing import Dict
|
||||
|
||||
from api.schemas.user_configuration import UserConfiguration
|
||||
from api.services.configuration.masking import resolve_masked_api_keys
|
||||
from api.services.configuration.masking import (
|
||||
SERVICE_SECRET_FIELDS,
|
||||
resolve_masked_api_keys,
|
||||
)
|
||||
|
||||
SERVICE_FIELDS = ("llm", "tts", "stt", "embeddings", "realtime")
|
||||
|
||||
|
|
@ -45,18 +48,16 @@ def merge_user_configurations(
|
|||
and incoming_cfg.get("provider") != old_cfg.get("provider")
|
||||
)
|
||||
|
||||
incoming_api_key = incoming_cfg.get("api_key")
|
||||
|
||||
if not provider_changed:
|
||||
# conditional preservation of api_key
|
||||
if incoming_api_key is not None:
|
||||
if old_cfg and "api_key" in old_cfg:
|
||||
incoming_cfg["api_key"] = resolve_masked_api_keys(
|
||||
incoming_api_key, old_cfg["api_key"]
|
||||
)
|
||||
else:
|
||||
if "api_key" in old_cfg:
|
||||
incoming_cfg["api_key"] = old_cfg["api_key"]
|
||||
for secret_field in SERVICE_SECRET_FIELDS:
|
||||
incoming_secret = incoming_cfg.get(secret_field)
|
||||
if incoming_secret is not None:
|
||||
if old_cfg and secret_field in old_cfg:
|
||||
incoming_cfg[secret_field] = resolve_masked_api_keys(
|
||||
incoming_secret, old_cfg[secret_field]
|
||||
)
|
||||
elif secret_field in old_cfg:
|
||||
incoming_cfg[secret_field] = old_cfg[secret_field]
|
||||
|
||||
merged[service_name] = incoming_cfg
|
||||
|
||||
|
|
|
|||
|
|
@ -3,6 +3,14 @@ GOOGLE_MODELS = (
|
|||
"gemini-2.0-flash-lite",
|
||||
"gemini-2.5-flash",
|
||||
"gemini-2.5-flash-lite",
|
||||
"gemini-3.5-flash",
|
||||
"gemini-3.5-flash-lite",
|
||||
)
|
||||
GOOGLE_VERTEX_MODELS = (
|
||||
"gemini-2.5-flash",
|
||||
"gemini-2.5-flash-lite",
|
||||
"gemini-3.1-flash-lite",
|
||||
"gemini-3.5-flash",
|
||||
)
|
||||
|
||||
GOOGLE_REALTIME_MODELS = ("gemini-3.1-flash-live-preview",)
|
||||
|
|
|
|||
|
|
@ -28,6 +28,7 @@ from api.services.configuration.options import (
|
|||
SARVAM_V3_VOICES,
|
||||
SPEECHMATICS_STT_LANGUAGES,
|
||||
)
|
||||
from api.services.configuration.options.google import GOOGLE_VERTEX_MODELS
|
||||
|
||||
|
||||
class ServiceType(Enum):
|
||||
|
|
@ -58,7 +59,9 @@ class ServiceProviders(str, Enum):
|
|||
GLADIA = "gladia"
|
||||
RIME = "rime"
|
||||
MINIMAX = "minimax"
|
||||
GOOGLE_VERTEX = "google_vertex"
|
||||
OPENAI_REALTIME = "openai_realtime"
|
||||
GROK_REALTIME = "grok_realtime"
|
||||
GOOGLE_REALTIME = "google_realtime"
|
||||
GOOGLE_VERTEX_REALTIME = "google_vertex_realtime"
|
||||
|
||||
|
|
@ -79,7 +82,9 @@ class BaseServiceConfiguration(BaseModel):
|
|||
ServiceProviders.GLADIA,
|
||||
ServiceProviders.RIME,
|
||||
ServiceProviders.MINIMAX,
|
||||
ServiceProviders.GOOGLE_VERTEX,
|
||||
ServiceProviders.OPENAI_REALTIME,
|
||||
ServiceProviders.GROK_REALTIME,
|
||||
ServiceProviders.GOOGLE_REALTIME,
|
||||
ServiceProviders.GOOGLE_VERTEX_REALTIME,
|
||||
# ServiceProviders.SARVAM,
|
||||
|
|
@ -206,7 +211,9 @@ OPENROUTER_PROVIDER_MODEL_CONFIG = provider_model_config("Open Router")
|
|||
AZURE_OPENAI_PROVIDER_MODEL_CONFIG = provider_model_config("Azure OpenAI")
|
||||
DOGRAH_PROVIDER_MODEL_CONFIG = provider_model_config("Dograh")
|
||||
AWS_BEDROCK_PROVIDER_MODEL_CONFIG = provider_model_config("AWS Bedrock")
|
||||
GOOGLE_VERTEX_PROVIDER_MODEL_CONFIG = provider_model_config("Google Vertex")
|
||||
OPENAI_REALTIME_PROVIDER_MODEL_CONFIG = provider_model_config("OpenAI Realtime")
|
||||
GROK_REALTIME_PROVIDER_MODEL_CONFIG = provider_model_config("Grok Realtime")
|
||||
GOOGLE_REALTIME_PROVIDER_MODEL_CONFIG = provider_model_config("Google Realtime")
|
||||
GOOGLE_VERTEX_REALTIME_PROVIDER_MODEL_CONFIG = provider_model_config(
|
||||
"Google Vertex Realtime"
|
||||
|
|
@ -239,6 +246,7 @@ OPENAI_MODELS = [
|
|||
"gpt-5-nano",
|
||||
"gpt-3.5-turbo",
|
||||
]
|
||||
|
||||
GROQ_MODELS = [
|
||||
"llama-3.3-70b-versatile",
|
||||
"deepseek-r1-distill-llama-70b",
|
||||
|
|
@ -292,6 +300,40 @@ class GoogleLLMService(BaseLLMConfiguration):
|
|||
)
|
||||
|
||||
|
||||
@register_llm
|
||||
class GoogleVertexLLMConfiguration(BaseLLMConfiguration):
|
||||
model_config = GOOGLE_VERTEX_PROVIDER_MODEL_CONFIG
|
||||
provider: Literal[ServiceProviders.GOOGLE_VERTEX] = ServiceProviders.GOOGLE_VERTEX
|
||||
model: str = Field(
|
||||
default="gemini-2.5-flash",
|
||||
description="Gemini model on Vertex AI.",
|
||||
json_schema_extra={
|
||||
"examples": GOOGLE_VERTEX_MODELS,
|
||||
"allow_custom_input": True,
|
||||
},
|
||||
)
|
||||
project_id: str = Field(description="Google Cloud project ID for Vertex AI.")
|
||||
location: str = Field(
|
||||
default="global",
|
||||
description="GCP region for the Vertex AI endpoint (e.g. 'global').",
|
||||
)
|
||||
credentials: str | None = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"Paste the entire service-account JSON file contents. If omitted, "
|
||||
"falls back to Application Default Credentials (ADC)."
|
||||
),
|
||||
json_schema_extra={"multiline": True},
|
||||
)
|
||||
api_key: str | list[str] | None = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"Not used for Vertex AI — authentication is via the service account "
|
||||
"in `credentials` (or ADC). Leave blank."
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@register_llm
|
||||
class GroqLLMService(BaseLLMConfiguration):
|
||||
model_config = GROQ_PROVIDER_MODEL_CONFIG
|
||||
|
|
@ -460,6 +502,32 @@ class OpenAIRealtimeLLMConfiguration(BaseLLMConfiguration):
|
|||
)
|
||||
|
||||
|
||||
GROK_REALTIME_MODELS = ["grok-voice-think-fast-1.0"]
|
||||
GROK_REALTIME_VOICES = ["Ara", "Rex", "Sal", "Eve", "Leo"]
|
||||
|
||||
|
||||
@register_service(ServiceType.REALTIME)
|
||||
class GrokRealtimeLLMConfiguration(BaseLLMConfiguration):
|
||||
model_config = GROK_REALTIME_PROVIDER_MODEL_CONFIG
|
||||
provider: Literal[ServiceProviders.GROK_REALTIME] = ServiceProviders.GROK_REALTIME
|
||||
model: str = Field(
|
||||
default="grok-voice-think-fast-1.0",
|
||||
description="Grok realtime voice-agent model.",
|
||||
json_schema_extra={
|
||||
"examples": GROK_REALTIME_MODELS,
|
||||
"allow_custom_input": True,
|
||||
},
|
||||
)
|
||||
voice: str = Field(
|
||||
default="Ara",
|
||||
description="Voice the model speaks in.",
|
||||
json_schema_extra={
|
||||
"examples": GROK_REALTIME_VOICES,
|
||||
"allow_custom_input": True,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@register_service(ServiceType.REALTIME)
|
||||
class GoogleRealtimeLLMConfiguration(BaseLLMConfiguration):
|
||||
model_config = GOOGLE_REALTIME_PROVIDER_MODEL_CONFIG
|
||||
|
|
@ -524,8 +592,8 @@ class GoogleVertexRealtimeLLMConfiguration(BaseLLMConfiguration):
|
|||
)
|
||||
project_id: str = Field(description="Google Cloud project ID for Vertex AI.")
|
||||
location: str = Field(
|
||||
default="us-east4",
|
||||
description="GCP region for the Vertex AI endpoint (e.g. 'us-east4').",
|
||||
default="global",
|
||||
description="GCP region for the Vertex AI endpoint (e.g. 'global').",
|
||||
)
|
||||
credentials: str | None = Field(
|
||||
default=None,
|
||||
|
|
@ -546,6 +614,7 @@ class GoogleVertexRealtimeLLMConfiguration(BaseLLMConfiguration):
|
|||
|
||||
REALTIME_PROVIDERS = {
|
||||
ServiceProviders.OPENAI_REALTIME.value,
|
||||
ServiceProviders.GROK_REALTIME.value,
|
||||
ServiceProviders.GOOGLE_REALTIME.value,
|
||||
ServiceProviders.GOOGLE_VERTEX_REALTIME.value,
|
||||
}
|
||||
|
|
@ -554,6 +623,7 @@ REALTIME_PROVIDERS = {
|
|||
LLMConfig = Annotated[
|
||||
Union[
|
||||
OpenAILLMService,
|
||||
GoogleVertexLLMConfiguration,
|
||||
GroqLLMService,
|
||||
OpenRouterLLMConfiguration,
|
||||
GoogleLLMService,
|
||||
|
|
@ -569,6 +639,7 @@ LLMConfig = Annotated[
|
|||
RealtimeConfig = Annotated[
|
||||
Union[
|
||||
OpenAIRealtimeLLMConfiguration,
|
||||
GrokRealtimeLLMConfiguration,
|
||||
GoogleRealtimeLLMConfiguration,
|
||||
GoogleVertexRealtimeLLMConfiguration,
|
||||
],
|
||||
|
|
|
|||
253
api/services/pipecat/realtime/grok_realtime.py
Normal file
253
api/services/pipecat/realtime/grok_realtime.py
Normal file
|
|
@ -0,0 +1,253 @@
|
|||
"""Dograh subclass of pipecat's Grok Realtime LLM service.
|
||||
|
||||
Layers Dograh engine integration quirks onto upstream-pristine
|
||||
:class:`GrokRealtimeLLMService`. Grok already supports runtime session updates,
|
||||
so this wrapper stays close to the OpenAI realtime shim.
|
||||
|
||||
Adds:
|
||||
|
||||
- **User-mute audio gating** via ``UserMuteStarted/StoppedFrame``.
|
||||
- **TTSSpeakFrame as initial-response trigger** so the engine's greeting
|
||||
flow kicks off the bot's first response.
|
||||
- **One-off LLMMessagesAppendFrame handling** for ephemeral realtime prompts
|
||||
like user-idle checks, without mutating Dograh's local ``LLMContext``.
|
||||
- **Function-call deferral** until the bot finishes speaking, to avoid racing
|
||||
tool execution with the active audio turn.
|
||||
- **finalized=True on TranscriptionFrame** for parity with Dograh's other
|
||||
realtime providers.
|
||||
"""
|
||||
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
BotStartedSpeakingFrame,
|
||||
BotStoppedSpeakingFrame,
|
||||
Frame,
|
||||
LLMFullResponseStartFrame,
|
||||
LLMMessagesAppendFrame,
|
||||
TranscriptionFrame,
|
||||
TTSSpeakFrame,
|
||||
UserMuteStartedFrame,
|
||||
UserMuteStoppedFrame,
|
||||
)
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.llm_service import FunctionCallFromLLM
|
||||
from pipecat.services.xai.realtime import events
|
||||
from pipecat.services.xai.realtime.llm import GrokRealtimeLLMService
|
||||
from pipecat.utils.time import time_now_iso8601
|
||||
|
||||
|
||||
class DograhGrokRealtimeLLMService(GrokRealtimeLLMService):
|
||||
"""Grok Realtime with Dograh engine integration quirks."""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self._user_is_muted: bool = False
|
||||
self._handled_initial_context: bool = False
|
||||
self._bot_is_speaking: bool = False
|
||||
self._deferred_function_calls: list[FunctionCallFromLLM] = []
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
if isinstance(frame, UserMuteStartedFrame):
|
||||
self._user_is_muted = True
|
||||
await self.push_frame(frame, direction)
|
||||
return
|
||||
if isinstance(frame, UserMuteStoppedFrame):
|
||||
self._user_is_muted = False
|
||||
await self.push_frame(frame, direction)
|
||||
return
|
||||
if isinstance(frame, TTSSpeakFrame):
|
||||
if not self._handled_initial_context:
|
||||
await self._handle_context(self._context)
|
||||
else:
|
||||
logger.warning(
|
||||
f"{self}: TTSSpeakFrame after initial context already "
|
||||
"handled — Grok Realtime owns audio generation, ignoring"
|
||||
)
|
||||
return
|
||||
if isinstance(frame, LLMMessagesAppendFrame):
|
||||
await self._handle_messages_append(frame)
|
||||
return
|
||||
if isinstance(frame, BotStartedSpeakingFrame):
|
||||
self._bot_is_speaking = True
|
||||
elif isinstance(frame, BotStoppedSpeakingFrame):
|
||||
self._bot_is_speaking = False
|
||||
await self._run_pending_function_calls()
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
async def _handle_messages_append(self, frame: LLMMessagesAppendFrame):
|
||||
"""Consume a one-off append frame without mutating the local LLMContext."""
|
||||
if self._disconnecting:
|
||||
return
|
||||
|
||||
if not self._api_session_ready:
|
||||
if frame.run_llm:
|
||||
logger.debug(
|
||||
f"{self}: LLMMessagesAppendFrame received before session ready; "
|
||||
"deferring response until the session is initialized"
|
||||
)
|
||||
self._run_llm_when_api_session_ready = True
|
||||
return
|
||||
|
||||
appended_any = False
|
||||
for message in frame.messages:
|
||||
item = self._message_to_conversation_item(message)
|
||||
if item is None:
|
||||
continue
|
||||
evt = events.ConversationItemCreateEvent(item=item)
|
||||
self._messages_added_manually[evt.item.id] = True
|
||||
await self.send_client_event(evt)
|
||||
appended_any = True
|
||||
|
||||
if frame.run_llm and appended_any:
|
||||
await self._send_manual_response_create()
|
||||
|
||||
async def _handle_context(self, context: LLMContext):
|
||||
if not self._handled_initial_context:
|
||||
if context is None:
|
||||
logger.warning(
|
||||
f"{self}: received initial context trigger before context was set"
|
||||
)
|
||||
return
|
||||
self._handled_initial_context = True
|
||||
self._context = context
|
||||
await self._create_response()
|
||||
else:
|
||||
self._context = context
|
||||
await self._process_completed_function_calls(send_new_results=True)
|
||||
|
||||
async def _send_user_audio(self, frame):
|
||||
if self._user_is_muted:
|
||||
return
|
||||
await super()._send_user_audio(frame)
|
||||
|
||||
def _message_to_conversation_item(
|
||||
self, message: dict[str, Any]
|
||||
) -> events.ConversationItem | None:
|
||||
if not isinstance(message, dict):
|
||||
logger.warning(
|
||||
f"{self}: skipping unsupported appended message payload {message!r}"
|
||||
)
|
||||
return None
|
||||
|
||||
role = message.get("role")
|
||||
if role not in {"user", "system", "developer"}:
|
||||
logger.warning(
|
||||
f"{self}: skipping unsupported appended message role {role!r}"
|
||||
)
|
||||
return None
|
||||
|
||||
text = self._extract_text_content(message.get("content"))
|
||||
if not text:
|
||||
logger.warning(
|
||||
f"{self}: skipping appended message with unsupported content {message!r}"
|
||||
)
|
||||
return None
|
||||
|
||||
item_role = "system" if role in {"system", "developer"} else "user"
|
||||
return events.ConversationItem(
|
||||
type="message",
|
||||
role=item_role,
|
||||
content=[events.ItemContent(type="input_text", text=text)],
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _extract_text_content(content: Any) -> str | None:
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
if isinstance(content, list):
|
||||
parts: list[str] = []
|
||||
for part in content:
|
||||
if not isinstance(part, dict):
|
||||
return None
|
||||
if part.get("type") != "text":
|
||||
return None
|
||||
text = part.get("text")
|
||||
if not isinstance(text, str):
|
||||
return None
|
||||
parts.append(text)
|
||||
return "\n".join(parts) if parts else None
|
||||
return None
|
||||
|
||||
async def _send_manual_response_create(self):
|
||||
"""Trigger inference after manually appending conversation items."""
|
||||
await self.push_frame(LLMFullResponseStartFrame())
|
||||
await self.start_processing_metrics()
|
||||
await self.start_ttfb_metrics()
|
||||
await self.send_client_event(
|
||||
events.ResponseCreateEvent(
|
||||
response=events.ResponseProperties(modalities=["text", "audio"])
|
||||
)
|
||||
)
|
||||
|
||||
async def _run_pending_function_calls(self):
|
||||
if not self._deferred_function_calls:
|
||||
return
|
||||
function_calls = self._deferred_function_calls
|
||||
self._deferred_function_calls = []
|
||||
logger.debug(
|
||||
f"{self}: executing {len(function_calls)} deferred function call(s) "
|
||||
"after bot turn ended"
|
||||
)
|
||||
await self.run_function_calls(function_calls)
|
||||
|
||||
async def _handle_evt_function_call_arguments_done(self, evt):
|
||||
"""Process or defer tool calls until the bot finishes speaking."""
|
||||
try:
|
||||
args = json.loads(evt.arguments)
|
||||
|
||||
function_call_item = self._pending_function_calls.get(evt.call_id)
|
||||
if function_call_item:
|
||||
del self._pending_function_calls[evt.call_id]
|
||||
|
||||
function_name = getattr(evt, "name", None) or function_call_item.name
|
||||
function_calls = [
|
||||
FunctionCallFromLLM(
|
||||
context=self._context,
|
||||
tool_call_id=evt.call_id,
|
||||
function_name=function_name,
|
||||
arguments=args,
|
||||
)
|
||||
]
|
||||
|
||||
if self._bot_is_speaking:
|
||||
self._deferred_function_calls.extend(function_calls)
|
||||
logger.debug(
|
||||
f"{self}: deferring function call {function_name} "
|
||||
"until bot stops speaking"
|
||||
)
|
||||
else:
|
||||
await self.run_function_calls(function_calls)
|
||||
logger.debug(f"Processed function call: {function_name}")
|
||||
else:
|
||||
logger.warning(
|
||||
f"No tracked function call found for call_id: {evt.call_id}"
|
||||
)
|
||||
logger.warning(
|
||||
f"Available pending calls: {list(self._pending_function_calls.keys())}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to process function call arguments: {e}")
|
||||
|
||||
async def _handle_evt_input_audio_transcription_completed(self, evt):
|
||||
await self._call_event_handler(
|
||||
"on_conversation_item_updated", evt.item_id, None
|
||||
)
|
||||
|
||||
transcript = evt.transcript.strip() if evt.transcript else ""
|
||||
if not transcript:
|
||||
return
|
||||
|
||||
await self.broadcast_frame(
|
||||
TranscriptionFrame,
|
||||
text=transcript,
|
||||
user_id="",
|
||||
timestamp=time_now_iso8601(),
|
||||
result=evt,
|
||||
finalized=True,
|
||||
)
|
||||
|
|
@ -120,6 +120,16 @@ def _create_realtime_user_turn_config(provider: str):
|
|||
),
|
||||
None,
|
||||
)
|
||||
if provider == ServiceProviders.GROK_REALTIME.value:
|
||||
# Grok Voice Agent emits server-side speech-start/stop and
|
||||
# interruption signals, so local VAD should stay out of the way.
|
||||
return (
|
||||
UserTurnStrategies(
|
||||
start=[ExternalUserTurnStartStrategy()],
|
||||
stop=[ExternalUserTurnStopStrategy()],
|
||||
),
|
||||
None,
|
||||
)
|
||||
|
||||
return (
|
||||
UserTurnStrategies(
|
||||
|
|
|
|||
|
|
@ -30,7 +30,13 @@ from pipecat.services.gladia.stt import GladiaSTTService, GladiaSTTSettings
|
|||
from pipecat.services.google.llm import GoogleLLMService, GoogleLLMSettings
|
||||
from pipecat.services.google.stt import GoogleSTTService, GoogleSTTSettings
|
||||
from pipecat.services.google.tts import GoogleTTSService, GoogleTTSSettings
|
||||
from pipecat.services.google.vertex.llm import (
|
||||
GoogleVertexLLMService,
|
||||
GoogleVertexLLMSettings,
|
||||
)
|
||||
from pipecat.services.groq.llm import GroqLLMService, GroqLLMSettings
|
||||
from pipecat.services.minimax.llm import MiniMaxLLMService
|
||||
from pipecat.services.minimax.tts import MiniMaxTTSSettings
|
||||
from pipecat.services.openai.base_llm import OpenAILLMSettings
|
||||
from pipecat.services.openai.llm import OpenAILLMService
|
||||
from pipecat.services.openai.stt import (
|
||||
|
|
@ -40,8 +46,6 @@ from pipecat.services.openai.stt import (
|
|||
from pipecat.services.openai.tts import OpenAITTSService, OpenAITTSSettings
|
||||
from pipecat.services.openrouter.llm import OpenRouterLLMService, OpenRouterLLMSettings
|
||||
from pipecat.services.rime.tts import RimeTTSService, RimeTTSSettings
|
||||
from pipecat.services.minimax.llm import MiniMaxLLMService
|
||||
from pipecat.services.minimax.tts import MiniMaxHttpTTSService, MiniMaxTTSSettings
|
||||
from pipecat.services.sarvam.stt import SarvamSTTService, SarvamSTTSettings
|
||||
from pipecat.services.sarvam.tts import SarvamTTSService, SarvamTTSSettings
|
||||
from pipecat.services.speaches.llm import SpeachesLLMService, SpeachesLLMSettings
|
||||
|
|
@ -482,13 +486,16 @@ def create_tts_service(user_config, audio_config: "AudioConfig"):
|
|||
def create_llm_service_from_provider(
|
||||
provider: str,
|
||||
model: str,
|
||||
api_key: str,
|
||||
api_key: str | None,
|
||||
*,
|
||||
base_url: str | None = None,
|
||||
endpoint: str | None = None,
|
||||
aws_access_key: str | None = None,
|
||||
aws_secret_key: str | None = None,
|
||||
aws_region: str | None = None,
|
||||
project_id: str | None = None,
|
||||
location: str | None = None,
|
||||
credentials: str | None = None,
|
||||
temperature: float | None = None,
|
||||
):
|
||||
"""Create an LLM service from explicit provider/model/api_key.
|
||||
|
|
@ -528,6 +535,13 @@ def create_llm_service_from_provider(
|
|||
api_key=api_key,
|
||||
settings=GoogleLLMSettings(model=model, temperature=0.1),
|
||||
)
|
||||
elif provider == ServiceProviders.GOOGLE_VERTEX.value:
|
||||
return GoogleVertexLLMService(
|
||||
credentials=credentials,
|
||||
project_id=project_id,
|
||||
location=location or "us-east4",
|
||||
settings=GoogleVertexLLMSettings(model=model, temperature=0.1),
|
||||
)
|
||||
elif provider == ServiceProviders.AZURE.value:
|
||||
return AzureLLMService(
|
||||
api_key=api_key,
|
||||
|
|
@ -611,6 +625,21 @@ def create_realtime_llm_service(user_config, audio_config: "AudioConfig"):
|
|||
),
|
||||
),
|
||||
)
|
||||
elif provider == ServiceProviders.GROK_REALTIME.value:
|
||||
from api.services.pipecat.realtime.grok_realtime import (
|
||||
DograhGrokRealtimeLLMService,
|
||||
)
|
||||
from pipecat.services.xai.realtime.events import SessionProperties
|
||||
|
||||
return DograhGrokRealtimeLLMService(
|
||||
api_key=api_key,
|
||||
settings=DograhGrokRealtimeLLMService.Settings(
|
||||
model=model,
|
||||
session_properties=SessionProperties(
|
||||
voice=voice or "Ara",
|
||||
),
|
||||
),
|
||||
)
|
||||
elif provider == ServiceProviders.GOOGLE_REALTIME.value:
|
||||
from api.services.pipecat.realtime.gemini_live import (
|
||||
DograhGeminiLiveLLMService,
|
||||
|
|
@ -672,6 +701,10 @@ def create_llm_service(user_config):
|
|||
kwargs["aws_access_key"] = user_config.llm.aws_access_key
|
||||
kwargs["aws_secret_key"] = user_config.llm.aws_secret_key
|
||||
kwargs["aws_region"] = user_config.llm.aws_region
|
||||
elif provider == ServiceProviders.GOOGLE_VERTEX.value:
|
||||
kwargs["project_id"] = user_config.llm.project_id
|
||||
kwargs["location"] = user_config.llm.location
|
||||
kwargs["credentials"] = user_config.llm.credentials
|
||||
elif provider == ServiceProviders.MINIMAX.value:
|
||||
kwargs["base_url"] = user_config.llm.base_url
|
||||
kwargs["temperature"] = user_config.llm.temperature
|
||||
|
|
|
|||
103
api/tests/test_google_vertex_llm_service_factory.py
Normal file
103
api/tests/test_google_vertex_llm_service_factory.py
Normal file
|
|
@ -0,0 +1,103 @@
|
|||
from types import SimpleNamespace
|
||||
from unittest.mock import patch
|
||||
|
||||
from api.services.configuration.check_validity import UserConfigurationValidator
|
||||
from api.services.configuration.registry import (
|
||||
REGISTRY,
|
||||
GoogleVertexLLMConfiguration,
|
||||
ServiceProviders,
|
||||
ServiceType,
|
||||
)
|
||||
from api.services.pipecat.service_factory import (
|
||||
create_llm_service,
|
||||
create_llm_service_from_provider,
|
||||
)
|
||||
|
||||
|
||||
class TestGoogleVertexLLMConfiguration:
|
||||
def test_defaults(self):
|
||||
config = GoogleVertexLLMConfiguration(project_id="demo-project")
|
||||
assert config.provider == ServiceProviders.GOOGLE_VERTEX
|
||||
assert config.model == "gemini-2.5-flash"
|
||||
assert config.location == "us-east4"
|
||||
assert config.credentials is None
|
||||
assert config.api_key is None
|
||||
|
||||
def test_registered_in_llm_registry(self):
|
||||
assert ServiceProviders.GOOGLE_VERTEX in REGISTRY[ServiceType.LLM]
|
||||
assert (
|
||||
REGISTRY[ServiceType.LLM][ServiceProviders.GOOGLE_VERTEX]
|
||||
is GoogleVertexLLMConfiguration
|
||||
)
|
||||
|
||||
|
||||
class TestGoogleVertexLLMServiceFactory:
|
||||
def test_create_llm_service_from_provider_uses_vertex_service(self):
|
||||
with patch(
|
||||
"api.services.pipecat.service_factory.GoogleVertexLLMService"
|
||||
) as mock_service:
|
||||
create_llm_service_from_provider(
|
||||
provider=ServiceProviders.GOOGLE_VERTEX.value,
|
||||
model="gemini-2.5-pro",
|
||||
api_key=None,
|
||||
project_id="demo-project",
|
||||
location="us-central1",
|
||||
credentials='{"type":"service_account"}',
|
||||
)
|
||||
|
||||
kwargs = mock_service.call_args.kwargs
|
||||
assert kwargs["project_id"] == "demo-project"
|
||||
assert kwargs["location"] == "us-central1"
|
||||
assert kwargs["credentials"] == '{"type":"service_account"}'
|
||||
assert kwargs["settings"].model == "gemini-2.5-pro"
|
||||
assert kwargs["settings"].temperature == 0.1
|
||||
|
||||
def test_create_llm_service_extracts_vertex_credentials(self):
|
||||
user_config = SimpleNamespace(
|
||||
llm=SimpleNamespace(
|
||||
provider=ServiceProviders.GOOGLE_VERTEX.value,
|
||||
api_key=None,
|
||||
model="gemini-2.5-flash",
|
||||
project_id="demo-project",
|
||||
location="us-east4",
|
||||
credentials='{"type":"service_account"}',
|
||||
)
|
||||
)
|
||||
|
||||
with patch(
|
||||
"api.services.pipecat.service_factory.GoogleVertexLLMService"
|
||||
) as mock_service:
|
||||
create_llm_service(user_config)
|
||||
|
||||
kwargs = mock_service.call_args.kwargs
|
||||
assert kwargs["project_id"] == "demo-project"
|
||||
assert kwargs["location"] == "us-east4"
|
||||
assert kwargs["credentials"] == '{"type":"service_account"}'
|
||||
|
||||
|
||||
class TestGoogleVertexLLMValidation:
|
||||
def test_validator_accepts_vertex_llm_without_api_key(self):
|
||||
validator = UserConfigurationValidator()
|
||||
config = GoogleVertexLLMConfiguration(
|
||||
project_id="demo-project",
|
||||
location="us-east4",
|
||||
credentials='{"type":"service_account"}',
|
||||
)
|
||||
|
||||
assert validator._validate_service(config, "llm") == []
|
||||
|
||||
def test_validator_requires_project_id(self):
|
||||
validator = UserConfigurationValidator()
|
||||
config = SimpleNamespace(
|
||||
provider=ServiceProviders.GOOGLE_VERTEX.value,
|
||||
project_id=None,
|
||||
location="us-east4",
|
||||
credentials='{"type":"service_account"}',
|
||||
api_key=None,
|
||||
)
|
||||
|
||||
result = validator._validate_service(config, "llm")
|
||||
|
||||
assert result == [
|
||||
{"model": "llm", "message": "project_id is required for Google Vertex"}
|
||||
]
|
||||
138
api/tests/test_grok_realtime_wrapper.py
Normal file
138
api/tests/test_grok_realtime_wrapper.py
Normal file
|
|
@ -0,0 +1,138 @@
|
|||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
from pipecat.frames.frames import LLMMessagesAppendFrame, TTSSpeakFrame
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.xai.realtime import events
|
||||
|
||||
from api.schemas.user_configuration import UserConfiguration
|
||||
from api.services.configuration.registry import GrokRealtimeLLMConfiguration
|
||||
from api.services.pipecat.realtime.grok_realtime import (
|
||||
DograhGrokRealtimeLLMService,
|
||||
)
|
||||
from api.services.pipecat.service_factory import create_realtime_llm_service
|
||||
|
||||
|
||||
def _make_service() -> DograhGrokRealtimeLLMService:
|
||||
service = DograhGrokRealtimeLLMService(api_key="test-key")
|
||||
service._create_response = AsyncMock()
|
||||
service._process_completed_function_calls = AsyncMock()
|
||||
return service
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_initial_context_triggers_response_when_context_was_prepopulated():
|
||||
service = _make_service()
|
||||
context = LLMContext()
|
||||
service._context = context
|
||||
|
||||
await service._handle_context(context)
|
||||
|
||||
assert service._handled_initial_context is True
|
||||
assert service._context is context
|
||||
service._create_response.assert_awaited_once()
|
||||
service._process_completed_function_calls.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tts_greeting_uses_initial_context_handler():
|
||||
service = _make_service()
|
||||
service._context = LLMContext()
|
||||
service._handle_context = AsyncMock()
|
||||
|
||||
await service.process_frame(
|
||||
TTSSpeakFrame("hello", append_to_context=True),
|
||||
FrameDirection.DOWNSTREAM,
|
||||
)
|
||||
|
||||
service._handle_context.assert_awaited_once_with(service._context)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_messages_append_frame_sends_conversation_item():
|
||||
service = _make_service()
|
||||
service._api_session_ready = True
|
||||
service.send_client_event = AsyncMock()
|
||||
service._send_manual_response_create = AsyncMock()
|
||||
|
||||
await service._handle_messages_append(
|
||||
LLMMessagesAppendFrame(
|
||||
[{"role": "user", "content": "Are you still there?"}],
|
||||
run_llm=True,
|
||||
)
|
||||
)
|
||||
|
||||
service.send_client_event.assert_awaited_once()
|
||||
event = service.send_client_event.await_args.args[0]
|
||||
assert isinstance(event, events.ConversationItemCreateEvent)
|
||||
assert event.item.role == "user"
|
||||
assert event.item.type == "message"
|
||||
assert event.item.content == [
|
||||
events.ItemContent(type="input_text", text="Are you still there?")
|
||||
]
|
||||
service._send_manual_response_create.assert_awaited_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_function_call_is_deferred_until_bot_stops_speaking():
|
||||
service = _make_service()
|
||||
service._context = LLMContext()
|
||||
service.run_function_calls = AsyncMock()
|
||||
service._bot_is_speaking = True
|
||||
service._pending_function_calls["call-1"] = SimpleNamespace(name="customer_support")
|
||||
|
||||
await service._handle_evt_function_call_arguments_done(
|
||||
SimpleNamespace(
|
||||
call_id="call-1",
|
||||
name="customer_support",
|
||||
arguments='{"department":"sales"}',
|
||||
)
|
||||
)
|
||||
|
||||
service.run_function_calls.assert_not_awaited()
|
||||
assert len(service._deferred_function_calls) == 1
|
||||
|
||||
await service._run_pending_function_calls()
|
||||
|
||||
service.run_function_calls.assert_awaited_once()
|
||||
assert service._deferred_function_calls == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_completed_input_transcription_is_broadcast_as_finalized():
|
||||
service = _make_service()
|
||||
service._call_event_handler = AsyncMock()
|
||||
service.broadcast_frame = AsyncMock()
|
||||
|
||||
evt = SimpleNamespace(item_id="item-1", transcript="Hello there")
|
||||
|
||||
await service._handle_evt_input_audio_transcription_completed(evt)
|
||||
|
||||
service._call_event_handler.assert_awaited_once_with(
|
||||
"on_conversation_item_updated", "item-1", None
|
||||
)
|
||||
service.broadcast_frame.assert_awaited_once()
|
||||
assert service.broadcast_frame.await_args.args[0].__name__ == "TranscriptionFrame"
|
||||
assert service.broadcast_frame.await_args.kwargs["text"] == "Hello there"
|
||||
assert service.broadcast_frame.await_args.kwargs["finalized"] is True
|
||||
|
||||
|
||||
def test_factory_creates_dograh_grok_realtime_service():
|
||||
user_config = UserConfiguration(
|
||||
is_realtime=True,
|
||||
realtime=GrokRealtimeLLMConfiguration(
|
||||
provider="grok_realtime",
|
||||
api_key="xai-key",
|
||||
model="grok-voice-think-fast-1.0",
|
||||
voice="Sal",
|
||||
),
|
||||
)
|
||||
|
||||
service = create_realtime_llm_service(
|
||||
user_config,
|
||||
audio_config=SimpleNamespace(),
|
||||
)
|
||||
|
||||
assert isinstance(service, DograhGrokRealtimeLLMService)
|
||||
|
|
@ -9,6 +9,7 @@ from api.services.auth.depends import get_user
|
|||
from api.services.configuration.masking import mask_key
|
||||
from api.services.configuration.registry import (
|
||||
GoogleLLMService,
|
||||
GoogleVertexLLMConfiguration,
|
||||
OpenAILLMService,
|
||||
)
|
||||
|
||||
|
|
@ -168,3 +169,44 @@ class TestMaskedKeyRejection:
|
|||
# Merge resolves the masked key back to the real one,
|
||||
# so check_for_masked_keys should NOT raise.
|
||||
assert response.status_code == 200
|
||||
|
||||
def test_allows_same_provider_with_masked_vertex_credentials(self):
|
||||
"""Same provider with masked credentials should succeed."""
|
||||
app = _make_test_app()
|
||||
client = TestClient(app)
|
||||
|
||||
real_credentials = '{"type":"service_account","project_id":"demo-project"}'
|
||||
masked_credentials = mask_key(real_credentials)
|
||||
existing = UserConfiguration(
|
||||
llm=GoogleVertexLLMConfiguration(
|
||||
provider="google_vertex",
|
||||
api_key=None,
|
||||
model="gemini-2.5-flash",
|
||||
project_id="demo-project",
|
||||
location="us-east4",
|
||||
credentials=real_credentials,
|
||||
)
|
||||
)
|
||||
|
||||
with (
|
||||
patch("api.routes.user.db_client") as mock_db,
|
||||
patch("api.routes.user.UserConfigurationValidator") as mock_validator,
|
||||
):
|
||||
mock_db.get_user_configurations = AsyncMock(return_value=existing)
|
||||
mock_db.update_user_configuration = AsyncMock(return_value=existing)
|
||||
mock_validator.return_value.validate = AsyncMock()
|
||||
|
||||
response = client.put(
|
||||
"/user/configurations/user",
|
||||
json={
|
||||
"llm": {
|
||||
"provider": "google_vertex",
|
||||
"model": "gemini-2.5-flash",
|
||||
"project_id": "demo-project",
|
||||
"location": "us-east4",
|
||||
"credentials": masked_credentials,
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
|
|
|
|||
|
|
@ -109,11 +109,12 @@ class TestMiniMaxTTSServiceFactory:
|
|||
)
|
||||
audio_config = SimpleNamespace(transport_in_sample_rate=16000)
|
||||
|
||||
with patch(
|
||||
"api.services.pipecat.service_factory.aiohttp.ClientSession"
|
||||
), patch(
|
||||
"api.services.pipecat.service_factory.MiniMaxOwnedSessionTTSService"
|
||||
) as mock_service:
|
||||
with (
|
||||
patch("api.services.pipecat.service_factory.aiohttp.ClientSession"),
|
||||
patch(
|
||||
"api.services.pipecat.service_factory.MiniMaxOwnedSessionTTSService"
|
||||
) as mock_service,
|
||||
):
|
||||
create_tts_service(user_config, audio_config)
|
||||
|
||||
assert mock_service.call_count == 1
|
||||
|
|
|
|||
|
|
@ -14,6 +14,8 @@ from api.services.configuration.registry import (
|
|||
DeepgramSTTConfiguration,
|
||||
ElevenlabsTTSConfiguration,
|
||||
GoogleRealtimeLLMConfiguration,
|
||||
GoogleVertexLLMConfiguration,
|
||||
GrokRealtimeLLMConfiguration,
|
||||
OpenAILLMService,
|
||||
)
|
||||
from api.services.configuration.resolve import resolve_effective_config
|
||||
|
|
@ -164,6 +166,23 @@ class TestProviderChange:
|
|||
assert result.tts.provider == "elevenlabs"
|
||||
assert result.stt.provider == "deepgram"
|
||||
|
||||
def test_override_llm_to_google_vertex(self, global_config):
|
||||
result = resolve_effective_config(
|
||||
global_config,
|
||||
{
|
||||
"llm": {
|
||||
"provider": "google_vertex",
|
||||
"model": "gemini-2.5-flash",
|
||||
"project_id": "demo-project",
|
||||
"location": "us-east4",
|
||||
"credentials": '{"type":"service_account"}',
|
||||
}
|
||||
},
|
||||
)
|
||||
assert isinstance(result.llm, GoogleVertexLLMConfiguration)
|
||||
assert result.llm.provider == "google_vertex"
|
||||
assert result.llm.project_id == "demo-project"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# API key inheritance
|
||||
|
|
@ -226,6 +245,22 @@ class TestRealtimeOverride:
|
|||
assert result.realtime.provider == "google_realtime" # inherited
|
||||
assert result.realtime.api_key == "goog-global-rt" # inherited
|
||||
|
||||
def test_switch_realtime_provider_to_grok(self, global_config_realtime):
|
||||
result = resolve_effective_config(
|
||||
global_config_realtime,
|
||||
{
|
||||
"realtime": {
|
||||
"provider": "grok_realtime",
|
||||
"api_key": "xai-key",
|
||||
"model": "grok-voice-think-fast-1.0",
|
||||
"voice": "Sal",
|
||||
}
|
||||
},
|
||||
)
|
||||
assert isinstance(result.realtime, GrokRealtimeLLMConfiguration)
|
||||
assert result.realtime.provider == "grok_realtime"
|
||||
assert result.realtime.voice == "Sal"
|
||||
|
||||
def test_override_is_realtime_only_without_realtime_section(self, global_config):
|
||||
"""Override is_realtime=True but provide no realtime config.
|
||||
Should set the flag; realtime section stays None from global."""
|
||||
|
|
|
|||
|
|
@ -51,6 +51,19 @@ def test_openai_realtime_uses_provider_turn_frames_without_local_vad():
|
|||
assert isinstance(strategies.stop[0], ExternalUserTurnStopStrategy)
|
||||
|
||||
|
||||
def test_grok_realtime_uses_provider_turn_frames_without_local_vad():
|
||||
strategies, vad_analyzer = _create_realtime_user_turn_config(
|
||||
ServiceProviders.GROK_REALTIME.value
|
||||
)
|
||||
|
||||
assert vad_analyzer is None
|
||||
assert len(strategies.start) == 1
|
||||
assert isinstance(strategies.start[0], ExternalUserTurnStartStrategy)
|
||||
assert strategies.start[0]._enable_interruptions is False
|
||||
assert len(strategies.stop) == 1
|
||||
assert isinstance(strategies.stop[0], ExternalUserTurnStopStrategy)
|
||||
|
||||
|
||||
def test_unknown_realtime_providers_keep_local_vad():
|
||||
strategies, vad_analyzer = _create_realtime_user_turn_config("other_realtime")
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue