mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-06-19 08:28:10 +02:00
feat: add xai grok as realtime model
This commit is contained in:
parent
291264de7b
commit
9135c2da13
14 changed files with 776 additions and 36 deletions
103
api/tests/test_google_vertex_llm_service_factory.py
Normal file
103
api/tests/test_google_vertex_llm_service_factory.py
Normal file
|
|
@ -0,0 +1,103 @@
|
|||
from types import SimpleNamespace
|
||||
from unittest.mock import patch
|
||||
|
||||
from api.services.configuration.check_validity import UserConfigurationValidator
|
||||
from api.services.configuration.registry import (
|
||||
REGISTRY,
|
||||
GoogleVertexLLMConfiguration,
|
||||
ServiceProviders,
|
||||
ServiceType,
|
||||
)
|
||||
from api.services.pipecat.service_factory import (
|
||||
create_llm_service,
|
||||
create_llm_service_from_provider,
|
||||
)
|
||||
|
||||
|
||||
class TestGoogleVertexLLMConfiguration:
    """Defaults and registry wiring of ``GoogleVertexLLMConfiguration``."""

    def test_defaults(self):
        """Only ``project_id`` is supplied; every other field uses its default."""
        cfg = GoogleVertexLLMConfiguration(project_id="demo-project")

        assert cfg.provider == ServiceProviders.GOOGLE_VERTEX
        assert cfg.model == "gemini-2.5-flash"
        assert cfg.location == "us-east4"
        # Neither secret is populated unless the caller passes one.
        assert cfg.credentials is None
        assert cfg.api_key is None

    def test_registered_in_llm_registry(self):
        """The LLM section of ``REGISTRY`` maps the Vertex provider to its config."""
        llm_providers = REGISTRY[ServiceType.LLM]

        assert ServiceProviders.GOOGLE_VERTEX in llm_providers
        registered = llm_providers[ServiceProviders.GOOGLE_VERTEX]
        assert registered is GoogleVertexLLMConfiguration
|
||||
|
||||
|
||||
class TestGoogleVertexLLMServiceFactory:
    """Checks that the LLM factories build ``GoogleVertexLLMService`` correctly.

    ``GoogleVertexLLMService`` is patched inside the factory module, so the
    tests only inspect the keyword arguments the factory passes to it.
    """

    def test_create_llm_service_from_provider_uses_vertex_service(self):
        """Explicit provider arguments are forwarded to the Vertex service."""
        with patch(
            "api.services.pipecat.service_factory.GoogleVertexLLMService"
        ) as mock_service:
            create_llm_service_from_provider(
                provider=ServiceProviders.GOOGLE_VERTEX.value,
                model="gemini-2.5-pro",
                api_key=None,
                project_id="demo-project",
                location="us-central1",
                credentials='{"type":"service_account"}',
            )

        # Fail with a readable message instead of an AttributeError on
        # ``None.kwargs`` when the factory never constructs the service.
        mock_service.assert_called_once()
        kwargs = mock_service.call_args.kwargs
        assert kwargs["project_id"] == "demo-project"
        assert kwargs["location"] == "us-central1"
        assert kwargs["credentials"] == '{"type":"service_account"}'
        assert kwargs["settings"].model == "gemini-2.5-pro"
        assert kwargs["settings"].temperature == 0.1

    def test_create_llm_service_extracts_vertex_credentials(self):
        """Vertex fields are read from ``user_config.llm`` and forwarded."""
        user_config = SimpleNamespace(
            llm=SimpleNamespace(
                provider=ServiceProviders.GOOGLE_VERTEX.value,
                api_key=None,
                model="gemini-2.5-flash",
                project_id="demo-project",
                location="us-east4",
                credentials='{"type":"service_account"}',
            )
        )

        with patch(
            "api.services.pipecat.service_factory.GoogleVertexLLMService"
        ) as mock_service:
            create_llm_service(user_config)

        # Same guard as above: make a missing call an explicit assertion failure.
        mock_service.assert_called_once()
        kwargs = mock_service.call_args.kwargs
        assert kwargs["project_id"] == "demo-project"
        assert kwargs["location"] == "us-east4"
        assert kwargs["credentials"] == '{"type":"service_account"}'
|
||||
|
||||
|
||||
class TestGoogleVertexLLMValidation:
    """Exercises ``UserConfigurationValidator._validate_service`` for Vertex."""

    def test_validator_accepts_vertex_llm_without_api_key(self):
        """A Vertex config with project, location and credentials yields no errors."""
        validator = UserConfigurationValidator()
        vertex_config = GoogleVertexLLMConfiguration(
            project_id="demo-project",
            location="us-east4",
            credentials='{"type":"service_account"}',
        )

        errors = validator._validate_service(vertex_config, "llm")

        assert errors == []

    def test_validator_requires_project_id(self):
        """A missing ``project_id`` is reported as the single validation error."""
        validator = UserConfigurationValidator()
        # A plain namespace stands in for the config so project_id can be None.
        incomplete = SimpleNamespace(
            provider=ServiceProviders.GOOGLE_VERTEX.value,
            project_id=None,
            location="us-east4",
            credentials='{"type":"service_account"}',
            api_key=None,
        )
        expected = [
            {"model": "llm", "message": "project_id is required for Google Vertex"}
        ]

        assert validator._validate_service(incomplete, "llm") == expected
|
||||
138
api/tests/test_grok_realtime_wrapper.py
Normal file
138
api/tests/test_grok_realtime_wrapper.py
Normal file
|
|
@ -0,0 +1,138 @@
|
|||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
from pipecat.frames.frames import LLMMessagesAppendFrame, TTSSpeakFrame
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.xai.realtime import events
|
||||
|
||||
from api.schemas.user_configuration import UserConfiguration
|
||||
from api.services.configuration.registry import GrokRealtimeLLMConfiguration
|
||||
from api.services.pipecat.realtime.grok_realtime import (
|
||||
DograhGrokRealtimeLLMService,
|
||||
)
|
||||
from api.services.pipecat.service_factory import create_realtime_llm_service
|
||||
|
||||
|
||||
def _make_service() -> DograhGrokRealtimeLLMService:
    """Build a Grok realtime service whose response hooks are ``AsyncMock``s.

    ``_create_response`` and ``_process_completed_function_calls`` are
    replaced so tests can assert on whether they were awaited.
    """
    grok = DograhGrokRealtimeLLMService(api_key="test-key")
    for hook in ("_create_response", "_process_completed_function_calls"):
        setattr(grok, hook, AsyncMock())
    return grok
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_initial_context_triggers_response_when_context_was_prepopulated():
    """Handling a context the service already holds still creates a response."""
    prepopulated = LLMContext()
    service = _make_service()
    # The service already owns the very context object it is about to handle.
    service._context = prepopulated

    await service._handle_context(prepopulated)

    assert service._handled_initial_context is True
    assert service._context is prepopulated
    service._create_response.assert_awaited_once()
    service._process_completed_function_calls.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_tts_greeting_uses_initial_context_handler():
    """A context-appending ``TTSSpeakFrame`` is routed to ``_handle_context``."""
    service = _make_service()
    service._context = LLMContext()
    service._handle_context = AsyncMock()
    greeting = TTSSpeakFrame("hello", append_to_context=True)

    await service.process_frame(greeting, FrameDirection.DOWNSTREAM)

    # The handler must be awaited exactly once, with the service's own context.
    service._handle_context.assert_awaited_once_with(service._context)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_messages_append_frame_sends_conversation_item():
    """An appended user message becomes one conversation-item client event."""
    service = _make_service()
    service._api_session_ready = True
    service.send_client_event = AsyncMock()
    service._send_manual_response_create = AsyncMock()
    append_frame = LLMMessagesAppendFrame(
        [{"role": "user", "content": "Are you still there?"}],
        run_llm=True,
    )

    await service._handle_messages_append(append_frame)

    service.send_client_event.assert_awaited_once()
    sent_event = service.send_client_event.await_args.args[0]
    assert isinstance(sent_event, events.ConversationItemCreateEvent)
    assert sent_event.item.role == "user"
    assert sent_event.item.type == "message"
    expected_content = [
        events.ItemContent(type="input_text", text="Are you still there?")
    ]
    assert sent_event.item.content == expected_content
    # The frame was built with run_llm=True, and one manual response is awaited.
    service._send_manual_response_create.assert_awaited_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_function_call_is_deferred_until_bot_stops_speaking():
    """A function call finishing while the bot speaks is queued, then run later."""
    service = _make_service()
    service._context = LLMContext()
    service.run_function_calls = AsyncMock()
    service._bot_is_speaking = True
    service._pending_function_calls["call-1"] = SimpleNamespace(name="customer_support")
    arguments_done = SimpleNamespace(
        call_id="call-1",
        name="customer_support",
        arguments='{"department":"sales"}',
    )

    await service._handle_evt_function_call_arguments_done(arguments_done)

    # While the bot is speaking the call is parked, not executed.
    service.run_function_calls.assert_not_awaited()
    assert len(service._deferred_function_calls) == 1

    await service._run_pending_function_calls()

    # Flushing runs the parked call and empties the deferred queue.
    service.run_function_calls.assert_awaited_once()
    assert service._deferred_function_calls == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_completed_input_transcription_is_broadcast_as_finalized():
    """A completed input transcription is broadcast once, flagged as finalized."""
    service = _make_service()
    service._call_event_handler = AsyncMock()
    service.broadcast_frame = AsyncMock()
    completed = SimpleNamespace(item_id="item-1", transcript="Hello there")

    await service._handle_evt_input_audio_transcription_completed(completed)

    service._call_event_handler.assert_awaited_once_with(
        "on_conversation_item_updated", "item-1", None
    )
    service.broadcast_frame.assert_awaited_once()
    broadcast = service.broadcast_frame.await_args
    # The first positional argument is a frame class, identified here by name.
    assert broadcast.args[0].__name__ == "TranscriptionFrame"
    assert broadcast.kwargs["text"] == "Hello there"
    assert broadcast.kwargs["finalized"] is True
|
||||
|
||||
|
||||
def test_factory_creates_dograh_grok_realtime_service():
    """The realtime factory returns the Dograh Grok wrapper for ``grok_realtime``."""
    grok_settings = GrokRealtimeLLMConfiguration(
        provider="grok_realtime",
        api_key="xai-key",
        model="grok-voice-think-fast-1.0",
        voice="Sal",
    )
    user_config = UserConfiguration(is_realtime=True, realtime=grok_settings)

    built = create_realtime_llm_service(user_config, audio_config=SimpleNamespace())

    assert isinstance(built, DograhGrokRealtimeLLMService)
|
||||
|
|
@ -9,6 +9,7 @@ from api.services.auth.depends import get_user
|
|||
from api.services.configuration.masking import mask_key
|
||||
from api.services.configuration.registry import (
|
||||
GoogleLLMService,
|
||||
GoogleVertexLLMConfiguration,
|
||||
OpenAILLMService,
|
||||
)
|
||||
|
||||
|
|
@ -168,3 +169,44 @@ class TestMaskedKeyRejection:
|
|||
# Merge resolves the masked key back to the real one,
|
||||
# so check_for_masked_keys should NOT raise.
|
||||
assert response.status_code == 200
|
||||
|
||||
def test_allows_same_provider_with_masked_vertex_credentials(self):
    """Re-submitting masked Vertex credentials for the same provider returns 200."""
    app = _make_test_app()
    client = TestClient(app)

    real_credentials = '{"type":"service_account","project_id":"demo-project"}'
    masked_credentials = mask_key(real_credentials)
    # Fields shared by the stored configuration and the incoming request.
    vertex_fields = {
        "provider": "google_vertex",
        "model": "gemini-2.5-flash",
        "project_id": "demo-project",
        "location": "us-east4",
    }
    existing = UserConfiguration(
        llm=GoogleVertexLLMConfiguration(
            api_key=None,
            credentials=real_credentials,
            **vertex_fields,
        )
    )
    # The request carries only the masked form of the stored credentials.
    payload = {"llm": {**vertex_fields, "credentials": masked_credentials}}

    with (
        patch("api.routes.user.db_client") as mock_db,
        patch("api.routes.user.UserConfigurationValidator") as mock_validator,
    ):
        mock_db.get_user_configurations = AsyncMock(return_value=existing)
        mock_db.update_user_configuration = AsyncMock(return_value=existing)
        mock_validator.return_value.validate = AsyncMock()

        response = client.put("/user/configurations/user", json=payload)

    assert response.status_code == 200
|
||||
|
|
|
|||
|
|
@ -109,11 +109,12 @@ class TestMiniMaxTTSServiceFactory:
|
|||
)
|
||||
audio_config = SimpleNamespace(transport_in_sample_rate=16000)
|
||||
|
||||
with patch(
|
||||
"api.services.pipecat.service_factory.aiohttp.ClientSession"
|
||||
), patch(
|
||||
"api.services.pipecat.service_factory.MiniMaxOwnedSessionTTSService"
|
||||
) as mock_service:
|
||||
with (
|
||||
patch("api.services.pipecat.service_factory.aiohttp.ClientSession"),
|
||||
patch(
|
||||
"api.services.pipecat.service_factory.MiniMaxOwnedSessionTTSService"
|
||||
) as mock_service,
|
||||
):
|
||||
create_tts_service(user_config, audio_config)
|
||||
|
||||
assert mock_service.call_count == 1
|
||||
|
|
|
|||
|
|
@ -14,6 +14,8 @@ from api.services.configuration.registry import (
|
|||
DeepgramSTTConfiguration,
|
||||
ElevenlabsTTSConfiguration,
|
||||
GoogleRealtimeLLMConfiguration,
|
||||
GoogleVertexLLMConfiguration,
|
||||
GrokRealtimeLLMConfiguration,
|
||||
OpenAILLMService,
|
||||
)
|
||||
from api.services.configuration.resolve import resolve_effective_config
|
||||
|
|
@ -164,6 +166,23 @@ class TestProviderChange:
|
|||
assert result.tts.provider == "elevenlabs"
|
||||
assert result.stt.provider == "deepgram"
|
||||
|
||||
def test_override_llm_to_google_vertex(self, global_config):
    """An ``llm`` override naming ``google_vertex`` resolves to the Vertex config."""
    vertex_override = {
        "llm": {
            "provider": "google_vertex",
            "model": "gemini-2.5-flash",
            "project_id": "demo-project",
            "location": "us-east4",
            "credentials": '{"type":"service_account"}',
        }
    }

    result = resolve_effective_config(global_config, vertex_override)

    assert isinstance(result.llm, GoogleVertexLLMConfiguration)
    assert result.llm.provider == "google_vertex"
    assert result.llm.project_id == "demo-project"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# API key inheritance
|
||||
|
|
@ -226,6 +245,22 @@ class TestRealtimeOverride:
|
|||
assert result.realtime.provider == "google_realtime" # inherited
|
||||
assert result.realtime.api_key == "goog-global-rt" # inherited
|
||||
|
||||
def test_switch_realtime_provider_to_grok(self, global_config_realtime):
    """A ``realtime`` override naming ``grok_realtime`` resolves to the Grok config."""
    grok_override = {
        "realtime": {
            "provider": "grok_realtime",
            "api_key": "xai-key",
            "model": "grok-voice-think-fast-1.0",
            "voice": "Sal",
        }
    }

    result = resolve_effective_config(global_config_realtime, grok_override)

    assert isinstance(result.realtime, GrokRealtimeLLMConfiguration)
    assert result.realtime.provider == "grok_realtime"
    assert result.realtime.voice == "Sal"
|
||||
|
||||
def test_override_is_realtime_only_without_realtime_section(self, global_config):
|
||||
"""Override is_realtime=True but provide no realtime config.
|
||||
Should set the flag; realtime section stays None from global."""
|
||||
|
|
|
|||
|
|
@ -51,6 +51,19 @@ def test_openai_realtime_uses_provider_turn_frames_without_local_vad():
|
|||
assert isinstance(strategies.stop[0], ExternalUserTurnStopStrategy)
|
||||
|
||||
|
||||
def test_grok_realtime_uses_provider_turn_frames_without_local_vad():
    """Grok realtime gets external turn strategies and no local VAD analyzer."""
    provider = ServiceProviders.GROK_REALTIME.value

    strategies, vad_analyzer = _create_realtime_user_turn_config(provider)

    assert vad_analyzer is None
    # Exactly one start strategy: external, with interruptions disabled.
    assert len(strategies.start) == 1
    assert isinstance(strategies.start[0], ExternalUserTurnStartStrategy)
    assert strategies.start[0]._enable_interruptions is False
    # Exactly one stop strategy, also external.
    assert len(strategies.stop) == 1
    assert isinstance(strategies.stop[0], ExternalUserTurnStopStrategy)
|
||||
|
||||
|
||||
def test_unknown_realtime_providers_keep_local_vad():
|
||||
strategies, vad_analyzer = _create_realtime_user_turn_config("other_realtime")
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue