mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-07-01 08:59:46 +02:00
feat: add xai grok as realtime model
This commit is contained in:
parent
291264de7b
commit
9135c2da13
14 changed files with 776 additions and 36 deletions
253
api/services/pipecat/realtime/grok_realtime.py
Normal file
253
api/services/pipecat/realtime/grok_realtime.py
Normal file
|
|
@ -0,0 +1,253 @@
|
|||
"""Dograh subclass of pipecat's Grok Realtime LLM service.
|
||||
|
||||
Layers Dograh engine integration quirks onto upstream-pristine
|
||||
:class:`GrokRealtimeLLMService`. Grok already supports runtime session updates,
|
||||
so this wrapper stays close to the OpenAI realtime shim.
|
||||
|
||||
Adds:
|
||||
|
||||
- **User-mute audio gating** via ``UserMuteStarted/StoppedFrame``.
|
||||
- **TTSSpeakFrame as initial-response trigger** so the engine's greeting
|
||||
flow kicks off the bot's first response.
|
||||
- **One-off LLMMessagesAppendFrame handling** for ephemeral realtime prompts
|
||||
like user-idle checks, without mutating Dograh's local ``LLMContext``.
|
||||
- **Function-call deferral** until the bot finishes speaking, to avoid racing
|
||||
tool execution with the active audio turn.
|
||||
- **finalized=True on TranscriptionFrame** for parity with Dograh's other
|
||||
realtime providers.
|
||||
"""
|
||||
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
BotStartedSpeakingFrame,
|
||||
BotStoppedSpeakingFrame,
|
||||
Frame,
|
||||
LLMFullResponseStartFrame,
|
||||
LLMMessagesAppendFrame,
|
||||
TranscriptionFrame,
|
||||
TTSSpeakFrame,
|
||||
UserMuteStartedFrame,
|
||||
UserMuteStoppedFrame,
|
||||
)
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.llm_service import FunctionCallFromLLM
|
||||
from pipecat.services.xai.realtime import events
|
||||
from pipecat.services.xai.realtime.llm import GrokRealtimeLLMService
|
||||
from pipecat.utils.time import time_now_iso8601
|
||||
|
||||
|
||||
class DograhGrokRealtimeLLMService(GrokRealtimeLLMService):
|
||||
"""Grok Realtime with Dograh engine integration quirks."""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self._user_is_muted: bool = False
|
||||
self._handled_initial_context: bool = False
|
||||
self._bot_is_speaking: bool = False
|
||||
self._deferred_function_calls: list[FunctionCallFromLLM] = []
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
if isinstance(frame, UserMuteStartedFrame):
|
||||
self._user_is_muted = True
|
||||
await self.push_frame(frame, direction)
|
||||
return
|
||||
if isinstance(frame, UserMuteStoppedFrame):
|
||||
self._user_is_muted = False
|
||||
await self.push_frame(frame, direction)
|
||||
return
|
||||
if isinstance(frame, TTSSpeakFrame):
|
||||
if not self._handled_initial_context:
|
||||
await self._handle_context(self._context)
|
||||
else:
|
||||
logger.warning(
|
||||
f"{self}: TTSSpeakFrame after initial context already "
|
||||
"handled — Grok Realtime owns audio generation, ignoring"
|
||||
)
|
||||
return
|
||||
if isinstance(frame, LLMMessagesAppendFrame):
|
||||
await self._handle_messages_append(frame)
|
||||
return
|
||||
if isinstance(frame, BotStartedSpeakingFrame):
|
||||
self._bot_is_speaking = True
|
||||
elif isinstance(frame, BotStoppedSpeakingFrame):
|
||||
self._bot_is_speaking = False
|
||||
await self._run_pending_function_calls()
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
async def _handle_messages_append(self, frame: LLMMessagesAppendFrame):
|
||||
"""Consume a one-off append frame without mutating the local LLMContext."""
|
||||
if self._disconnecting:
|
||||
return
|
||||
|
||||
if not self._api_session_ready:
|
||||
if frame.run_llm:
|
||||
logger.debug(
|
||||
f"{self}: LLMMessagesAppendFrame received before session ready; "
|
||||
"deferring response until the session is initialized"
|
||||
)
|
||||
self._run_llm_when_api_session_ready = True
|
||||
return
|
||||
|
||||
appended_any = False
|
||||
for message in frame.messages:
|
||||
item = self._message_to_conversation_item(message)
|
||||
if item is None:
|
||||
continue
|
||||
evt = events.ConversationItemCreateEvent(item=item)
|
||||
self._messages_added_manually[evt.item.id] = True
|
||||
await self.send_client_event(evt)
|
||||
appended_any = True
|
||||
|
||||
if frame.run_llm and appended_any:
|
||||
await self._send_manual_response_create()
|
||||
|
||||
async def _handle_context(self, context: LLMContext):
|
||||
if not self._handled_initial_context:
|
||||
if context is None:
|
||||
logger.warning(
|
||||
f"{self}: received initial context trigger before context was set"
|
||||
)
|
||||
return
|
||||
self._handled_initial_context = True
|
||||
self._context = context
|
||||
await self._create_response()
|
||||
else:
|
||||
self._context = context
|
||||
await self._process_completed_function_calls(send_new_results=True)
|
||||
|
||||
async def _send_user_audio(self, frame):
|
||||
if self._user_is_muted:
|
||||
return
|
||||
await super()._send_user_audio(frame)
|
||||
|
||||
def _message_to_conversation_item(
|
||||
self, message: dict[str, Any]
|
||||
) -> events.ConversationItem | None:
|
||||
if not isinstance(message, dict):
|
||||
logger.warning(
|
||||
f"{self}: skipping unsupported appended message payload {message!r}"
|
||||
)
|
||||
return None
|
||||
|
||||
role = message.get("role")
|
||||
if role not in {"user", "system", "developer"}:
|
||||
logger.warning(
|
||||
f"{self}: skipping unsupported appended message role {role!r}"
|
||||
)
|
||||
return None
|
||||
|
||||
text = self._extract_text_content(message.get("content"))
|
||||
if not text:
|
||||
logger.warning(
|
||||
f"{self}: skipping appended message with unsupported content {message!r}"
|
||||
)
|
||||
return None
|
||||
|
||||
item_role = "system" if role in {"system", "developer"} else "user"
|
||||
return events.ConversationItem(
|
||||
type="message",
|
||||
role=item_role,
|
||||
content=[events.ItemContent(type="input_text", text=text)],
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _extract_text_content(content: Any) -> str | None:
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
if isinstance(content, list):
|
||||
parts: list[str] = []
|
||||
for part in content:
|
||||
if not isinstance(part, dict):
|
||||
return None
|
||||
if part.get("type") != "text":
|
||||
return None
|
||||
text = part.get("text")
|
||||
if not isinstance(text, str):
|
||||
return None
|
||||
parts.append(text)
|
||||
return "\n".join(parts) if parts else None
|
||||
return None
|
||||
|
||||
async def _send_manual_response_create(self):
|
||||
"""Trigger inference after manually appending conversation items."""
|
||||
await self.push_frame(LLMFullResponseStartFrame())
|
||||
await self.start_processing_metrics()
|
||||
await self.start_ttfb_metrics()
|
||||
await self.send_client_event(
|
||||
events.ResponseCreateEvent(
|
||||
response=events.ResponseProperties(modalities=["text", "audio"])
|
||||
)
|
||||
)
|
||||
|
||||
async def _run_pending_function_calls(self):
|
||||
if not self._deferred_function_calls:
|
||||
return
|
||||
function_calls = self._deferred_function_calls
|
||||
self._deferred_function_calls = []
|
||||
logger.debug(
|
||||
f"{self}: executing {len(function_calls)} deferred function call(s) "
|
||||
"after bot turn ended"
|
||||
)
|
||||
await self.run_function_calls(function_calls)
|
||||
|
||||
async def _handle_evt_function_call_arguments_done(self, evt):
|
||||
"""Process or defer tool calls until the bot finishes speaking."""
|
||||
try:
|
||||
args = json.loads(evt.arguments)
|
||||
|
||||
function_call_item = self._pending_function_calls.get(evt.call_id)
|
||||
if function_call_item:
|
||||
del self._pending_function_calls[evt.call_id]
|
||||
|
||||
function_name = getattr(evt, "name", None) or function_call_item.name
|
||||
function_calls = [
|
||||
FunctionCallFromLLM(
|
||||
context=self._context,
|
||||
tool_call_id=evt.call_id,
|
||||
function_name=function_name,
|
||||
arguments=args,
|
||||
)
|
||||
]
|
||||
|
||||
if self._bot_is_speaking:
|
||||
self._deferred_function_calls.extend(function_calls)
|
||||
logger.debug(
|
||||
f"{self}: deferring function call {function_name} "
|
||||
"until bot stops speaking"
|
||||
)
|
||||
else:
|
||||
await self.run_function_calls(function_calls)
|
||||
logger.debug(f"Processed function call: {function_name}")
|
||||
else:
|
||||
logger.warning(
|
||||
f"No tracked function call found for call_id: {evt.call_id}"
|
||||
)
|
||||
logger.warning(
|
||||
f"Available pending calls: {list(self._pending_function_calls.keys())}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to process function call arguments: {e}")
|
||||
|
||||
async def _handle_evt_input_audio_transcription_completed(self, evt):
|
||||
await self._call_event_handler(
|
||||
"on_conversation_item_updated", evt.item_id, None
|
||||
)
|
||||
|
||||
transcript = evt.transcript.strip() if evt.transcript else ""
|
||||
if not transcript:
|
||||
return
|
||||
|
||||
await self.broadcast_frame(
|
||||
TranscriptionFrame,
|
||||
text=transcript,
|
||||
user_id="",
|
||||
timestamp=time_now_iso8601(),
|
||||
result=evt,
|
||||
finalized=True,
|
||||
)
|
||||
|
|
@ -120,6 +120,16 @@ def _create_realtime_user_turn_config(provider: str):
|
|||
),
|
||||
None,
|
||||
)
|
||||
if provider == ServiceProviders.GROK_REALTIME.value:
|
||||
# Grok Voice Agent emits server-side speech-start/stop and
|
||||
# interruption signals, so local VAD should stay out of the way.
|
||||
return (
|
||||
UserTurnStrategies(
|
||||
start=[ExternalUserTurnStartStrategy()],
|
||||
stop=[ExternalUserTurnStopStrategy()],
|
||||
),
|
||||
None,
|
||||
)
|
||||
|
||||
return (
|
||||
UserTurnStrategies(
|
||||
|
|
|
|||
|
|
@ -30,7 +30,13 @@ from pipecat.services.gladia.stt import GladiaSTTService, GladiaSTTSettings
|
|||
from pipecat.services.google.llm import GoogleLLMService, GoogleLLMSettings
|
||||
from pipecat.services.google.stt import GoogleSTTService, GoogleSTTSettings
|
||||
from pipecat.services.google.tts import GoogleTTSService, GoogleTTSSettings
|
||||
from pipecat.services.google.vertex.llm import (
|
||||
GoogleVertexLLMService,
|
||||
GoogleVertexLLMSettings,
|
||||
)
|
||||
from pipecat.services.groq.llm import GroqLLMService, GroqLLMSettings
|
||||
from pipecat.services.minimax.llm import MiniMaxLLMService
|
||||
from pipecat.services.minimax.tts import MiniMaxTTSSettings
|
||||
from pipecat.services.openai.base_llm import OpenAILLMSettings
|
||||
from pipecat.services.openai.llm import OpenAILLMService
|
||||
from pipecat.services.openai.stt import (
|
||||
|
|
@ -40,8 +46,6 @@ from pipecat.services.openai.stt import (
|
|||
from pipecat.services.openai.tts import OpenAITTSService, OpenAITTSSettings
|
||||
from pipecat.services.openrouter.llm import OpenRouterLLMService, OpenRouterLLMSettings
|
||||
from pipecat.services.rime.tts import RimeTTSService, RimeTTSSettings
|
||||
from pipecat.services.minimax.llm import MiniMaxLLMService
|
||||
from pipecat.services.minimax.tts import MiniMaxHttpTTSService, MiniMaxTTSSettings
|
||||
from pipecat.services.sarvam.stt import SarvamSTTService, SarvamSTTSettings
|
||||
from pipecat.services.sarvam.tts import SarvamTTSService, SarvamTTSSettings
|
||||
from pipecat.services.speaches.llm import SpeachesLLMService, SpeachesLLMSettings
|
||||
|
|
@ -482,13 +486,16 @@ def create_tts_service(user_config, audio_config: "AudioConfig"):
|
|||
def create_llm_service_from_provider(
|
||||
provider: str,
|
||||
model: str,
|
||||
api_key: str,
|
||||
api_key: str | None,
|
||||
*,
|
||||
base_url: str | None = None,
|
||||
endpoint: str | None = None,
|
||||
aws_access_key: str | None = None,
|
||||
aws_secret_key: str | None = None,
|
||||
aws_region: str | None = None,
|
||||
project_id: str | None = None,
|
||||
location: str | None = None,
|
||||
credentials: str | None = None,
|
||||
temperature: float | None = None,
|
||||
):
|
||||
"""Create an LLM service from explicit provider/model/api_key.
|
||||
|
|
@ -528,6 +535,13 @@ def create_llm_service_from_provider(
|
|||
api_key=api_key,
|
||||
settings=GoogleLLMSettings(model=model, temperature=0.1),
|
||||
)
|
||||
elif provider == ServiceProviders.GOOGLE_VERTEX.value:
|
||||
return GoogleVertexLLMService(
|
||||
credentials=credentials,
|
||||
project_id=project_id,
|
||||
location=location or "us-east4",
|
||||
settings=GoogleVertexLLMSettings(model=model, temperature=0.1),
|
||||
)
|
||||
elif provider == ServiceProviders.AZURE.value:
|
||||
return AzureLLMService(
|
||||
api_key=api_key,
|
||||
|
|
@ -611,6 +625,21 @@ def create_realtime_llm_service(user_config, audio_config: "AudioConfig"):
|
|||
),
|
||||
),
|
||||
)
|
||||
elif provider == ServiceProviders.GROK_REALTIME.value:
|
||||
from api.services.pipecat.realtime.grok_realtime import (
|
||||
DograhGrokRealtimeLLMService,
|
||||
)
|
||||
from pipecat.services.xai.realtime.events import SessionProperties
|
||||
|
||||
return DograhGrokRealtimeLLMService(
|
||||
api_key=api_key,
|
||||
settings=DograhGrokRealtimeLLMService.Settings(
|
||||
model=model,
|
||||
session_properties=SessionProperties(
|
||||
voice=voice or "Ara",
|
||||
),
|
||||
),
|
||||
)
|
||||
elif provider == ServiceProviders.GOOGLE_REALTIME.value:
|
||||
from api.services.pipecat.realtime.gemini_live import (
|
||||
DograhGeminiLiveLLMService,
|
||||
|
|
@ -672,6 +701,10 @@ def create_llm_service(user_config):
|
|||
kwargs["aws_access_key"] = user_config.llm.aws_access_key
|
||||
kwargs["aws_secret_key"] = user_config.llm.aws_secret_key
|
||||
kwargs["aws_region"] = user_config.llm.aws_region
|
||||
elif provider == ServiceProviders.GOOGLE_VERTEX.value:
|
||||
kwargs["project_id"] = user_config.llm.project_id
|
||||
kwargs["location"] = user_config.llm.location
|
||||
kwargs["credentials"] = user_config.llm.credentials
|
||||
elif provider == ServiceProviders.MINIMAX.value:
|
||||
kwargs["base_url"] = user_config.llm.base_url
|
||||
kwargs["temperature"] = user_config.llm.temperature
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue