feat: add xai grok as realtime model

This commit is contained in:
Abhishek Kumar 2026-05-22 18:04:59 +05:30
parent 291264de7b
commit 9135c2da13
14 changed files with 776 additions and 36 deletions

View file

@ -0,0 +1,253 @@
"""Dograh subclass of pipecat's Grok Realtime LLM service.
Layers Dograh engine integration quirks onto upstream-pristine
:class:`GrokRealtimeLLMService`. Grok already supports runtime session updates,
so this wrapper stays close to the OpenAI realtime shim.
Adds:
- **User-mute audio gating** via ``UserMuteStarted/StoppedFrame``.
- **TTSSpeakFrame as initial-response trigger** so the engine's greeting
flow kicks off the bot's first response.
- **One-off LLMMessagesAppendFrame handling** for ephemeral realtime prompts
like user-idle checks, without mutating Dograh's local ``LLMContext``.
- **Function-call deferral** until the bot finishes speaking, to avoid racing
tool execution with the active audio turn.
- **finalized=True on TranscriptionFrame** for parity with Dograh's other
realtime providers.
"""
import json
from typing import Any
from loguru import logger
from pipecat.frames.frames import (
BotStartedSpeakingFrame,
BotStoppedSpeakingFrame,
Frame,
LLMFullResponseStartFrame,
LLMMessagesAppendFrame,
TranscriptionFrame,
TTSSpeakFrame,
UserMuteStartedFrame,
UserMuteStoppedFrame,
)
from pipecat.processors.aggregators.llm_context import LLMContext
from pipecat.processors.frame_processor import FrameDirection
from pipecat.services.llm_service import FunctionCallFromLLM
from pipecat.services.xai.realtime import events
from pipecat.services.xai.realtime.llm import GrokRealtimeLLMService
from pipecat.utils.time import time_now_iso8601
class DograhGrokRealtimeLLMService(GrokRealtimeLLMService):
"""Grok Realtime with Dograh engine integration quirks."""
def __init__(self, **kwargs):
super().__init__(**kwargs)
self._user_is_muted: bool = False
self._handled_initial_context: bool = False
self._bot_is_speaking: bool = False
self._deferred_function_calls: list[FunctionCallFromLLM] = []
async def process_frame(self, frame: Frame, direction: FrameDirection):
if isinstance(frame, UserMuteStartedFrame):
self._user_is_muted = True
await self.push_frame(frame, direction)
return
if isinstance(frame, UserMuteStoppedFrame):
self._user_is_muted = False
await self.push_frame(frame, direction)
return
if isinstance(frame, TTSSpeakFrame):
if not self._handled_initial_context:
await self._handle_context(self._context)
else:
logger.warning(
f"{self}: TTSSpeakFrame after initial context already "
"handled — Grok Realtime owns audio generation, ignoring"
)
return
if isinstance(frame, LLMMessagesAppendFrame):
await self._handle_messages_append(frame)
return
if isinstance(frame, BotStartedSpeakingFrame):
self._bot_is_speaking = True
elif isinstance(frame, BotStoppedSpeakingFrame):
self._bot_is_speaking = False
await self._run_pending_function_calls()
await super().process_frame(frame, direction)
async def _handle_messages_append(self, frame: LLMMessagesAppendFrame):
"""Consume a one-off append frame without mutating the local LLMContext."""
if self._disconnecting:
return
if not self._api_session_ready:
if frame.run_llm:
logger.debug(
f"{self}: LLMMessagesAppendFrame received before session ready; "
"deferring response until the session is initialized"
)
self._run_llm_when_api_session_ready = True
return
appended_any = False
for message in frame.messages:
item = self._message_to_conversation_item(message)
if item is None:
continue
evt = events.ConversationItemCreateEvent(item=item)
self._messages_added_manually[evt.item.id] = True
await self.send_client_event(evt)
appended_any = True
if frame.run_llm and appended_any:
await self._send_manual_response_create()
async def _handle_context(self, context: LLMContext):
if not self._handled_initial_context:
if context is None:
logger.warning(
f"{self}: received initial context trigger before context was set"
)
return
self._handled_initial_context = True
self._context = context
await self._create_response()
else:
self._context = context
await self._process_completed_function_calls(send_new_results=True)
async def _send_user_audio(self, frame):
if self._user_is_muted:
return
await super()._send_user_audio(frame)
def _message_to_conversation_item(
self, message: dict[str, Any]
) -> events.ConversationItem | None:
if not isinstance(message, dict):
logger.warning(
f"{self}: skipping unsupported appended message payload {message!r}"
)
return None
role = message.get("role")
if role not in {"user", "system", "developer"}:
logger.warning(
f"{self}: skipping unsupported appended message role {role!r}"
)
return None
text = self._extract_text_content(message.get("content"))
if not text:
logger.warning(
f"{self}: skipping appended message with unsupported content {message!r}"
)
return None
item_role = "system" if role in {"system", "developer"} else "user"
return events.ConversationItem(
type="message",
role=item_role,
content=[events.ItemContent(type="input_text", text=text)],
)
@staticmethod
def _extract_text_content(content: Any) -> str | None:
if isinstance(content, str):
return content
if isinstance(content, list):
parts: list[str] = []
for part in content:
if not isinstance(part, dict):
return None
if part.get("type") != "text":
return None
text = part.get("text")
if not isinstance(text, str):
return None
parts.append(text)
return "\n".join(parts) if parts else None
return None
async def _send_manual_response_create(self):
"""Trigger inference after manually appending conversation items."""
await self.push_frame(LLMFullResponseStartFrame())
await self.start_processing_metrics()
await self.start_ttfb_metrics()
await self.send_client_event(
events.ResponseCreateEvent(
response=events.ResponseProperties(modalities=["text", "audio"])
)
)
async def _run_pending_function_calls(self):
if not self._deferred_function_calls:
return
function_calls = self._deferred_function_calls
self._deferred_function_calls = []
logger.debug(
f"{self}: executing {len(function_calls)} deferred function call(s) "
"after bot turn ended"
)
await self.run_function_calls(function_calls)
async def _handle_evt_function_call_arguments_done(self, evt):
"""Process or defer tool calls until the bot finishes speaking."""
try:
args = json.loads(evt.arguments)
function_call_item = self._pending_function_calls.get(evt.call_id)
if function_call_item:
del self._pending_function_calls[evt.call_id]
function_name = getattr(evt, "name", None) or function_call_item.name
function_calls = [
FunctionCallFromLLM(
context=self._context,
tool_call_id=evt.call_id,
function_name=function_name,
arguments=args,
)
]
if self._bot_is_speaking:
self._deferred_function_calls.extend(function_calls)
logger.debug(
f"{self}: deferring function call {function_name} "
"until bot stops speaking"
)
else:
await self.run_function_calls(function_calls)
logger.debug(f"Processed function call: {function_name}")
else:
logger.warning(
f"No tracked function call found for call_id: {evt.call_id}"
)
logger.warning(
f"Available pending calls: {list(self._pending_function_calls.keys())}"
)
except Exception as e:
logger.error(f"Failed to process function call arguments: {e}")
async def _handle_evt_input_audio_transcription_completed(self, evt):
await self._call_event_handler(
"on_conversation_item_updated", evt.item_id, None
)
transcript = evt.transcript.strip() if evt.transcript else ""
if not transcript:
return
await self.broadcast_frame(
TranscriptionFrame,
text=transcript,
user_id="",
timestamp=time_now_iso8601(),
result=evt,
finalized=True,
)

View file

@ -120,6 +120,16 @@ def _create_realtime_user_turn_config(provider: str):
),
None,
)
if provider == ServiceProviders.GROK_REALTIME.value:
# Grok Voice Agent emits server-side speech-start/stop and
# interruption signals, so local VAD should stay out of the way.
return (
UserTurnStrategies(
start=[ExternalUserTurnStartStrategy()],
stop=[ExternalUserTurnStopStrategy()],
),
None,
)
return (
UserTurnStrategies(

View file

@ -30,7 +30,13 @@ from pipecat.services.gladia.stt import GladiaSTTService, GladiaSTTSettings
from pipecat.services.google.llm import GoogleLLMService, GoogleLLMSettings
from pipecat.services.google.stt import GoogleSTTService, GoogleSTTSettings
from pipecat.services.google.tts import GoogleTTSService, GoogleTTSSettings
from pipecat.services.google.vertex.llm import (
GoogleVertexLLMService,
GoogleVertexLLMSettings,
)
from pipecat.services.groq.llm import GroqLLMService, GroqLLMSettings
from pipecat.services.minimax.llm import MiniMaxLLMService
from pipecat.services.minimax.tts import MiniMaxTTSSettings
from pipecat.services.openai.base_llm import OpenAILLMSettings
from pipecat.services.openai.llm import OpenAILLMService
from pipecat.services.openai.stt import (
@ -40,8 +46,6 @@ from pipecat.services.openai.stt import (
from pipecat.services.openai.tts import OpenAITTSService, OpenAITTSSettings
from pipecat.services.openrouter.llm import OpenRouterLLMService, OpenRouterLLMSettings
from pipecat.services.rime.tts import RimeTTSService, RimeTTSSettings
from pipecat.services.minimax.llm import MiniMaxLLMService
from pipecat.services.minimax.tts import MiniMaxHttpTTSService, MiniMaxTTSSettings
from pipecat.services.sarvam.stt import SarvamSTTService, SarvamSTTSettings
from pipecat.services.sarvam.tts import SarvamTTSService, SarvamTTSSettings
from pipecat.services.speaches.llm import SpeachesLLMService, SpeachesLLMSettings
@ -482,13 +486,16 @@ def create_tts_service(user_config, audio_config: "AudioConfig"):
def create_llm_service_from_provider(
provider: str,
model: str,
api_key: str,
api_key: str | None,
*,
base_url: str | None = None,
endpoint: str | None = None,
aws_access_key: str | None = None,
aws_secret_key: str | None = None,
aws_region: str | None = None,
project_id: str | None = None,
location: str | None = None,
credentials: str | None = None,
temperature: float | None = None,
):
"""Create an LLM service from explicit provider/model/api_key.
@ -528,6 +535,13 @@ def create_llm_service_from_provider(
api_key=api_key,
settings=GoogleLLMSettings(model=model, temperature=0.1),
)
elif provider == ServiceProviders.GOOGLE_VERTEX.value:
return GoogleVertexLLMService(
credentials=credentials,
project_id=project_id,
location=location or "us-east4",
settings=GoogleVertexLLMSettings(model=model, temperature=0.1),
)
elif provider == ServiceProviders.AZURE.value:
return AzureLLMService(
api_key=api_key,
@ -611,6 +625,21 @@ def create_realtime_llm_service(user_config, audio_config: "AudioConfig"):
),
),
)
elif provider == ServiceProviders.GROK_REALTIME.value:
from api.services.pipecat.realtime.grok_realtime import (
DograhGrokRealtimeLLMService,
)
from pipecat.services.xai.realtime.events import SessionProperties
return DograhGrokRealtimeLLMService(
api_key=api_key,
settings=DograhGrokRealtimeLLMService.Settings(
model=model,
session_properties=SessionProperties(
voice=voice or "Ara",
),
),
)
elif provider == ServiceProviders.GOOGLE_REALTIME.value:
from api.services.pipecat.realtime.gemini_live import (
DograhGeminiLiveLLMService,
@ -672,6 +701,10 @@ def create_llm_service(user_config):
kwargs["aws_access_key"] = user_config.llm.aws_access_key
kwargs["aws_secret_key"] = user_config.llm.aws_secret_key
kwargs["aws_region"] = user_config.llm.aws_region
elif provider == ServiceProviders.GOOGLE_VERTEX.value:
kwargs["project_id"] = user_config.llm.project_id
kwargs["location"] = user_config.llm.location
kwargs["credentials"] = user_config.llm.credentials
elif provider == ServiceProviders.MINIMAX.value:
kwargs["base_url"] = user_config.llm.base_url
kwargs["temperature"] = user_config.llm.temperature