feat: add xai grok as realtime model

This commit is contained in:
Abhishek Kumar 2026-05-22 18:04:59 +05:30
parent 291264de7b
commit 9135c2da13
14 changed files with 776 additions and 36 deletions

View file

@ -0,0 +1,103 @@
from types import SimpleNamespace
from unittest.mock import patch
from api.services.configuration.check_validity import UserConfigurationValidator
from api.services.configuration.registry import (
REGISTRY,
GoogleVertexLLMConfiguration,
ServiceProviders,
ServiceType,
)
from api.services.pipecat.service_factory import (
create_llm_service,
create_llm_service_from_provider,
)
class TestGoogleVertexLLMConfiguration:
    """Covers the Google Vertex LLM configuration model and its registry entry."""

    def test_defaults(self):
        """With only ``project_id`` given, the other fields hold their defaults."""
        vertex_config = GoogleVertexLLMConfiguration(project_id="demo-project")

        assert vertex_config.provider == ServiceProviders.GOOGLE_VERTEX
        assert vertex_config.model == "gemini-2.5-flash"
        assert vertex_config.location == "us-east4"
        # Neither secret is populated unless the caller supplies one.
        assert vertex_config.credentials is None
        assert vertex_config.api_key is None

    def test_registered_in_llm_registry(self):
        """The LLM registry maps the Vertex provider to its configuration class."""
        llm_registry = REGISTRY[ServiceType.LLM]

        assert ServiceProviders.GOOGLE_VERTEX in llm_registry
        registered_cls = llm_registry[ServiceProviders.GOOGLE_VERTEX]
        assert registered_cls is GoogleVertexLLMConfiguration
class TestGoogleVertexLLMServiceFactory:
    """Covers how the service factory constructs the Google Vertex LLM service."""

    def test_create_llm_service_from_provider_uses_vertex_service(self):
        """Explicit provider arguments reach ``GoogleVertexLLMService`` as kwargs."""
        vertex_patch = patch(
            "api.services.pipecat.service_factory.GoogleVertexLLMService"
        )

        with vertex_patch as vertex_cls:
            create_llm_service_from_provider(
                provider=ServiceProviders.GOOGLE_VERTEX.value,
                model="gemini-2.5-pro",
                api_key=None,
                project_id="demo-project",
                location="us-central1",
                credentials='{"type":"service_account"}',
            )

        # The mock keeps its recorded call after the patch has been undone.
        passed = vertex_cls.call_args.kwargs
        assert passed["project_id"] == "demo-project"
        assert passed["location"] == "us-central1"
        assert passed["credentials"] == '{"type":"service_account"}'
        assert passed["settings"].model == "gemini-2.5-pro"
        assert passed["settings"].temperature == 0.1

    def test_create_llm_service_extracts_vertex_credentials(self):
        """Vertex fields on ``user_config.llm`` are handed to the service class."""
        llm_section = SimpleNamespace(
            provider=ServiceProviders.GOOGLE_VERTEX.value,
            api_key=None,
            model="gemini-2.5-flash",
            project_id="demo-project",
            location="us-east4",
            credentials='{"type":"service_account"}',
        )
        user_config = SimpleNamespace(llm=llm_section)

        with patch(
            "api.services.pipecat.service_factory.GoogleVertexLLMService"
        ) as vertex_cls:
            create_llm_service(user_config)

        passed = vertex_cls.call_args.kwargs
        assert passed["project_id"] == "demo-project"
        assert passed["location"] == "us-east4"
        assert passed["credentials"] == '{"type":"service_account"}'
class TestGoogleVertexLLMValidation:
    """Covers ``UserConfigurationValidator`` behaviour for Google Vertex LLM configs."""

    def test_validator_accepts_vertex_llm_without_api_key(self):
        """A Vertex config with project, location and credentials yields no errors."""
        validator = UserConfigurationValidator()
        vertex_config = GoogleVertexLLMConfiguration(
            project_id="demo-project",
            location="us-east4",
            credentials='{"type":"service_account"}',
        )

        errors = validator._validate_service(vertex_config, "llm")

        assert errors == []

    def test_validator_requires_project_id(self):
        """A Vertex config whose ``project_id`` is ``None`` yields one error entry."""
        validator = UserConfigurationValidator()
        # A plain namespace stands in for the config so project_id can be None.
        incomplete_config = SimpleNamespace(
            provider=ServiceProviders.GOOGLE_VERTEX.value,
            project_id=None,
            location="us-east4",
            credentials='{"type":"service_account"}',
            api_key=None,
        )
        expected_errors = [
            {"model": "llm", "message": "project_id is required for Google Vertex"}
        ]

        errors = validator._validate_service(incomplete_config, "llm")

        assert errors == expected_errors

View file

@ -0,0 +1,138 @@
from types import SimpleNamespace
from unittest.mock import AsyncMock
import pytest
from pipecat.frames.frames import LLMMessagesAppendFrame, TTSSpeakFrame
from pipecat.processors.aggregators.llm_context import LLMContext
from pipecat.processors.frame_processor import FrameDirection
from pipecat.services.xai.realtime import events
from api.schemas.user_configuration import UserConfiguration
from api.services.configuration.registry import GrokRealtimeLLMConfiguration
from api.services.pipecat.realtime.grok_realtime import (
DograhGrokRealtimeLLMService,
)
from api.services.pipecat.service_factory import create_realtime_llm_service
def _make_service() -> DograhGrokRealtimeLLMService:
    """Return a Grok realtime service whose response hooks are ``AsyncMock`` stubs.

    ``_create_response`` and ``_process_completed_function_calls`` are replaced
    so tests can assert on whether they were awaited.
    """
    grok_service = DograhGrokRealtimeLLMService(api_key="test-key")
    for hook_name in ("_create_response", "_process_completed_function_calls"):
        setattr(grok_service, hook_name, AsyncMock())
    return grok_service
@pytest.mark.asyncio
async def test_initial_context_triggers_response_when_context_was_prepopulated():
    """Handling a context already stored on the service still creates a response."""
    grok = _make_service()
    prepopulated = LLMContext()
    grok._context = prepopulated

    await grok._handle_context(prepopulated)

    assert grok._handled_initial_context is True
    assert grok._context is prepopulated
    grok._create_response.assert_awaited_once()
    grok._process_completed_function_calls.assert_not_awaited()
@pytest.mark.asyncio
async def test_tts_greeting_uses_initial_context_handler():
    """A ``TTSSpeakFrame`` appended to context is routed through ``_handle_context``."""
    grok = _make_service()
    grok._context = LLMContext()
    grok._handle_context = AsyncMock()
    greeting = TTSSpeakFrame("hello", append_to_context=True)

    await grok.process_frame(greeting, FrameDirection.DOWNSTREAM)

    # Compare against the context held by the service at assertion time.
    grok._handle_context.assert_awaited_once_with(grok._context)
@pytest.mark.asyncio
async def test_messages_append_frame_sends_conversation_item():
    """An appended user message becomes one conversation-item event plus a response."""
    grok = _make_service()
    grok._api_session_ready = True
    grok.send_client_event = AsyncMock()
    grok._send_manual_response_create = AsyncMock()
    append_frame = LLMMessagesAppendFrame(
        [{"role": "user", "content": "Are you still there?"}],
        run_llm=True,
    )

    await grok._handle_messages_append(append_frame)

    grok.send_client_event.assert_awaited_once()
    sent_event = grok.send_client_event.await_args.args[0]
    assert isinstance(sent_event, events.ConversationItemCreateEvent)
    assert sent_event.item.role == "user"
    assert sent_event.item.type == "message"
    expected_content = [
        events.ItemContent(type="input_text", text="Are you still there?")
    ]
    assert sent_event.item.content == expected_content
    # run_llm=True on the frame is expected to request a response as well.
    grok._send_manual_response_create.assert_awaited_once()
@pytest.mark.asyncio
async def test_function_call_is_deferred_until_bot_stops_speaking():
    """A call completed while the bot speaks is queued, then run when flushed."""
    grok = _make_service()
    grok._context = LLMContext()
    grok.run_function_calls = AsyncMock()
    grok._bot_is_speaking = True
    grok._pending_function_calls["call-1"] = SimpleNamespace(name="customer_support")
    arguments_done_event = SimpleNamespace(
        call_id="call-1",
        name="customer_support",
        arguments='{"department":"sales"}',
    )

    await grok._handle_evt_function_call_arguments_done(arguments_done_event)

    # While the bot is speaking nothing runs; the call is parked instead.
    grok.run_function_calls.assert_not_awaited()
    assert len(grok._deferred_function_calls) == 1

    await grok._run_pending_function_calls()

    grok.run_function_calls.assert_awaited_once()
    assert grok._deferred_function_calls == []
@pytest.mark.asyncio
async def test_completed_input_transcription_is_broadcast_as_finalized():
    """A completed input transcription is broadcast as a finalized transcription."""
    grok = _make_service()
    grok._call_event_handler = AsyncMock()
    grok.broadcast_frame = AsyncMock()
    completed_event = SimpleNamespace(item_id="item-1", transcript="Hello there")

    await grok._handle_evt_input_audio_transcription_completed(completed_event)

    grok._call_event_handler.assert_awaited_once_with(
        "on_conversation_item_updated", "item-1", None
    )
    grok.broadcast_frame.assert_awaited_once()
    broadcast_call = grok.broadcast_frame.await_args
    # The first positional argument is checked by its class name.
    assert broadcast_call.args[0].__name__ == "TranscriptionFrame"
    assert broadcast_call.kwargs["text"] == "Hello there"
    assert broadcast_call.kwargs["finalized"] is True
def test_factory_creates_dograh_grok_realtime_service():
    """A ``grok_realtime`` config makes the factory return the Dograh Grok service."""
    grok_settings = GrokRealtimeLLMConfiguration(
        provider="grok_realtime",
        api_key="xai-key",
        model="grok-voice-think-fast-1.0",
        voice="Sal",
    )
    user_config = UserConfiguration(is_realtime=True, realtime=grok_settings)

    built_service = create_realtime_llm_service(
        user_config,
        audio_config=SimpleNamespace(),
    )

    assert isinstance(built_service, DograhGrokRealtimeLLMService)

View file

@ -9,6 +9,7 @@ from api.services.auth.depends import get_user
from api.services.configuration.masking import mask_key
from api.services.configuration.registry import (
GoogleLLMService,
GoogleVertexLLMConfiguration,
OpenAILLMService,
)
@ -168,3 +169,44 @@ class TestMaskedKeyRejection:
# Merge resolves the masked key back to the real one,
# so check_for_masked_keys should NOT raise.
assert response.status_code == 200
def test_allows_same_provider_with_masked_vertex_credentials(self):
    """PUT with masked Vertex credentials for the unchanged provider returns 200."""
    client = TestClient(_make_test_app())

    real_credentials = '{"type":"service_account","project_id":"demo-project"}'
    stored_llm = GoogleVertexLLMConfiguration(
        provider="google_vertex",
        api_key=None,
        model="gemini-2.5-flash",
        project_id="demo-project",
        location="us-east4",
        credentials=real_credentials,
    )
    existing = UserConfiguration(llm=stored_llm)

    # The request carries the masked form of the stored credentials.
    request_body = {
        "llm": {
            "provider": "google_vertex",
            "model": "gemini-2.5-flash",
            "project_id": "demo-project",
            "location": "us-east4",
            "credentials": mask_key(real_credentials),
        }
    }

    with (
        patch("api.routes.user.db_client") as mock_db,
        patch("api.routes.user.UserConfigurationValidator") as mock_validator,
    ):
        mock_db.get_user_configurations = AsyncMock(return_value=existing)
        mock_db.update_user_configuration = AsyncMock(return_value=existing)
        mock_validator.return_value.validate = AsyncMock()

        response = client.put("/user/configurations/user", json=request_body)

    assert response.status_code == 200

View file

@ -109,11 +109,12 @@ class TestMiniMaxTTSServiceFactory:
)
audio_config = SimpleNamespace(transport_in_sample_rate=16000)
with patch(
"api.services.pipecat.service_factory.aiohttp.ClientSession"
), patch(
"api.services.pipecat.service_factory.MiniMaxOwnedSessionTTSService"
) as mock_service:
with (
patch("api.services.pipecat.service_factory.aiohttp.ClientSession"),
patch(
"api.services.pipecat.service_factory.MiniMaxOwnedSessionTTSService"
) as mock_service,
):
create_tts_service(user_config, audio_config)
assert mock_service.call_count == 1

View file

@ -14,6 +14,8 @@ from api.services.configuration.registry import (
DeepgramSTTConfiguration,
ElevenlabsTTSConfiguration,
GoogleRealtimeLLMConfiguration,
GoogleVertexLLMConfiguration,
GrokRealtimeLLMConfiguration,
OpenAILLMService,
)
from api.services.configuration.resolve import resolve_effective_config
@ -164,6 +166,23 @@ class TestProviderChange:
assert result.tts.provider == "elevenlabs"
assert result.stt.provider == "deepgram"
def test_override_llm_to_google_vertex(self, global_config):
    """An ``llm`` override naming ``google_vertex`` resolves to the Vertex config."""
    vertex_override = {
        "provider": "google_vertex",
        "model": "gemini-2.5-flash",
        "project_id": "demo-project",
        "location": "us-east4",
        "credentials": '{"type":"service_account"}',
    }

    resolved = resolve_effective_config(global_config, {"llm": vertex_override})

    assert isinstance(resolved.llm, GoogleVertexLLMConfiguration)
    assert resolved.llm.provider == "google_vertex"
    assert resolved.llm.project_id == "demo-project"
# ---------------------------------------------------------------------------
# API key inheritance
@ -226,6 +245,22 @@ class TestRealtimeOverride:
assert result.realtime.provider == "google_realtime" # inherited
assert result.realtime.api_key == "goog-global-rt" # inherited
def test_switch_realtime_provider_to_grok(self, global_config_realtime):
    """A ``realtime`` override naming ``grok_realtime`` resolves to the Grok config."""
    grok_override = {
        "provider": "grok_realtime",
        "api_key": "xai-key",
        "model": "grok-voice-think-fast-1.0",
        "voice": "Sal",
    }

    resolved = resolve_effective_config(
        global_config_realtime,
        {"realtime": grok_override},
    )

    assert isinstance(resolved.realtime, GrokRealtimeLLMConfiguration)
    assert resolved.realtime.provider == "grok_realtime"
    assert resolved.realtime.voice == "Sal"
def test_override_is_realtime_only_without_realtime_section(self, global_config):
"""Override is_realtime=True but provide no realtime config.
Should set the flag; realtime section stays None from global."""

View file

@ -51,6 +51,19 @@ def test_openai_realtime_uses_provider_turn_frames_without_local_vad():
assert isinstance(strategies.stop[0], ExternalUserTurnStopStrategy)
def test_grok_realtime_uses_provider_turn_frames_without_local_vad():
    """Grok realtime gets external turn strategies and no local VAD analyzer."""
    grok_provider = ServiceProviders.GROK_REALTIME.value

    strategies, vad_analyzer = _create_realtime_user_turn_config(grok_provider)

    assert vad_analyzer is None

    # Exactly one external start strategy, with interruptions disabled.
    assert len(strategies.start) == 1
    assert isinstance(strategies.start[0], ExternalUserTurnStartStrategy)
    assert strategies.start[0]._enable_interruptions is False

    # Exactly one external stop strategy.
    assert len(strategies.stop) == 1
    assert isinstance(strategies.stop[0], ExternalUserTurnStopStrategy)
def test_unknown_realtime_providers_keep_local_vad():
strategies, vad_analyzer = _create_realtime_user_turn_config("other_realtime")