mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-06-22 08:38:13 +02:00
feat: add AWS Bedrock support
This commit is contained in:
parent
1604e306ec
commit
fe84f086ba
30 changed files with 546 additions and 195 deletions
|
|
@ -12,7 +12,7 @@ from api.services.pipecat.pipeline_metrics_aggregator import PipelineMetricsAggr
|
|||
from api.services.workflow.pipecat_engine import PipecatEngine
|
||||
from api.tasks.arq import enqueue_job
|
||||
from api.tasks.function_names import FunctionNames
|
||||
from pipecat.frames.frames import Frame, LLMContextFrame
|
||||
from pipecat.frames.frames import Frame, LLMContextFrame, TTSSpeakFrame
|
||||
from pipecat.pipeline.task import PipelineTask
|
||||
from pipecat.processors.audio.audio_buffer_processor import AudioBufferProcessor
|
||||
from pipecat.utils.enums import EndTaskReason
|
||||
|
|
@ -47,32 +47,44 @@ def register_event_handlers(
|
|||
sample_rate=sample_rate,
|
||||
num_channels=num_channels,
|
||||
)
|
||||
# Track both events to ensure LLM is only triggered after both occur
|
||||
# Track both events to ensure the initial response is only triggered after both occur
|
||||
ready_state = {
|
||||
"pipeline_started": False,
|
||||
"client_connected": False,
|
||||
"llm_triggered": False,
|
||||
"initial_response_triggered": False,
|
||||
}
|
||||
|
||||
async def maybe_trigger_llm():
|
||||
"""Trigger LLM only after both pipeline_started and client_connected events."""
|
||||
async def maybe_trigger_initial_response():
|
||||
"""Start the conversation after both pipeline_started and client_connected events.
|
||||
|
||||
If the start node has a greeting configured, play it directly via TTS.
|
||||
Otherwise, trigger an LLM generation for the opening message.
|
||||
"""
|
||||
if (
|
||||
ready_state["pipeline_started"]
|
||||
and ready_state["client_connected"]
|
||||
and not ready_state["llm_triggered"]
|
||||
and not ready_state["initial_response_triggered"]
|
||||
):
|
||||
ready_state["llm_triggered"] = True
|
||||
logger.debug(
|
||||
"Both pipeline_started and client_connected received - triggering initial LLM generation"
|
||||
)
|
||||
await engine.llm.queue_frame(LLMContextFrame(engine.context))
|
||||
ready_state["initial_response_triggered"] = True
|
||||
|
||||
greeting = engine.get_start_greeting()
|
||||
if greeting:
|
||||
logger.debug(
|
||||
"Both pipeline_started and client_connected received - playing greeting via TTS"
|
||||
)
|
||||
await task.queue_frame(TTSSpeakFrame(greeting))
|
||||
else:
|
||||
logger.debug(
|
||||
"Both pipeline_started and client_connected received - triggering initial LLM generation"
|
||||
)
|
||||
await engine.llm.queue_frame(LLMContextFrame(engine.context))
|
||||
|
||||
@transport.event_handler("on_client_connected")
|
||||
async def on_client_connected(_transport, _participant):
|
||||
logger.debug("In on_client_connected callback handler")
|
||||
await audio_buffer.start_recording()
|
||||
ready_state["client_connected"] = True
|
||||
await maybe_trigger_llm()
|
||||
await maybe_trigger_initial_response()
|
||||
|
||||
@transport.event_handler("on_client_disconnected")
|
||||
async def on_client_disconnected(_transport, _participant):
|
||||
|
|
@ -93,7 +105,7 @@ def register_event_handlers(
|
|||
async def on_pipeline_started(_task: PipelineTask, _frame: Frame):
|
||||
logger.debug("In on_pipeline_started callback handler")
|
||||
ready_state["pipeline_started"] = True
|
||||
await maybe_trigger_llm()
|
||||
await maybe_trigger_initial_response()
|
||||
|
||||
@task.event_handler("on_pipeline_error")
|
||||
async def on_pipeline_error(_task: PipelineTask, frame: Frame):
|
||||
|
|
|
|||
|
|
@ -74,9 +74,16 @@ def build_pipeline(
|
|||
if recording_router:
|
||||
post_llm.append(recording_router)
|
||||
|
||||
processors.append(user_context_aggregator)
|
||||
|
||||
# Insert LLM gate before the main LLM when voicemail detection is enabled.
|
||||
# This prevents the main LLM from being triggered until classification
|
||||
# determines whether a human or voicemail answered the call.
|
||||
if voicemail_detector:
|
||||
processors.append(voicemail_detector.llm_gate())
|
||||
|
||||
processors.extend(
|
||||
[
|
||||
user_context_aggregator,
|
||||
llm, # LLM
|
||||
*post_llm,
|
||||
tts, # TTS
|
||||
|
|
|
|||
|
|
@ -41,6 +41,7 @@ from pipecat.frames.frames import (
|
|||
MetricsFrame,
|
||||
StopFrame,
|
||||
TranscriptionFrame,
|
||||
TTSSpeakFrame,
|
||||
)
|
||||
from pipecat.metrics.metrics import TTFBMetricsData
|
||||
from pipecat.observers.base_observer import BaseObserver, FramePushed
|
||||
|
|
@ -205,6 +206,17 @@ class RealtimeFeedbackObserver(BaseObserver):
|
|||
},
|
||||
}
|
||||
)
|
||||
# Handle TTSSpeakFrame (e.g. greeting) - send immediately via WS only
|
||||
# Final turn text is persisted via on_assistant_turn_stopped to avoid duplication
|
||||
elif isinstance(frame, TTSSpeakFrame):
|
||||
await self._send_ws(
|
||||
{
|
||||
"type": RealtimeFeedbackType.BOT_TEXT.value,
|
||||
"payload": {
|
||||
"text": frame.text,
|
||||
},
|
||||
}
|
||||
)
|
||||
# Handle bot TTS text - respect pts timing, WebSocket only
|
||||
# Complete turn text is persisted via register_turn_handlers
|
||||
elif isinstance(frame, LLMTextFrame):
|
||||
|
|
|
|||
|
|
@ -173,7 +173,9 @@ async def _download_and_convert(
|
|||
Returns the processed PCM bytes, or None on failure.
|
||||
"""
|
||||
ext = _ext_from_key(recording.storage_key)
|
||||
fd, tmp_path = tempfile.mkstemp(suffix=ext, prefix=f"dograh_dl_{recording.recording_id}_")
|
||||
fd, tmp_path = tempfile.mkstemp(
|
||||
suffix=ext, prefix=f"dograh_dl_{recording.recording_id}_"
|
||||
)
|
||||
os.close(fd)
|
||||
try:
|
||||
storage = get_storage_fn(recording.storage_backend)
|
||||
|
|
|
|||
|
|
@ -34,6 +34,7 @@ from api.services.pipecat.recording_audio_cache import (
|
|||
from api.services.pipecat.recording_router_processor import RecordingRouterProcessor
|
||||
from api.services.pipecat.service_factory import (
|
||||
create_llm_service,
|
||||
create_llm_service_from_provider,
|
||||
create_stt_service,
|
||||
create_tts_service,
|
||||
)
|
||||
|
|
@ -669,18 +670,31 @@ async def _run_pipeline(
|
|||
async def on_user_turn_started(aggregator, strategy):
|
||||
user_idle_handler.reset()
|
||||
|
||||
# Create voicemail detector if enabled in the workflow's start node
|
||||
# Create voicemail detector if enabled in workflow configurations
|
||||
voicemail_detector = None
|
||||
start_node = workflow_graph.nodes.get(workflow_graph.start_node_id)
|
||||
if start_node and start_node.detect_voicemail:
|
||||
voicemail_config = (workflow.workflow_configurations or {}).get(
|
||||
"voicemail_detection", {}
|
||||
)
|
||||
if voicemail_config.get("enabled", False):
|
||||
logger.info(f"Voicemail detection enabled for workflow run {workflow_run_id}")
|
||||
# Create a separate LLM instance for the voicemail sub-pipeline
|
||||
# (can't share with main pipeline as it would mess up frame linking)
|
||||
voicemail_llm = create_llm_service(user_config)
|
||||
if voicemail_config.get("use_workflow_llm", True):
|
||||
voicemail_llm = create_llm_service(user_config)
|
||||
else:
|
||||
voicemail_llm = create_llm_service_from_provider(
|
||||
provider=voicemail_config.get("provider", "openai"),
|
||||
model=voicemail_config.get("model", "gpt-4.1"),
|
||||
api_key=voicemail_config.get("api_key", ""),
|
||||
)
|
||||
|
||||
long_speech_timeout = voicemail_config.get("long_speech_timeout", 8.0)
|
||||
custom_system_prompt = voicemail_config.get("system_prompt") or None
|
||||
|
||||
voicemail_detector = VoicemailDetector(
|
||||
llm=voicemail_llm,
|
||||
voicemail_response_delay=1.0,
|
||||
long_speech_timeout=8.0,
|
||||
long_speech_timeout=long_speech_timeout,
|
||||
custom_system_prompt=custom_system_prompt,
|
||||
)
|
||||
|
||||
# Register event handler to end task when voicemail is detected
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ from loguru import logger
|
|||
|
||||
from api.constants import MPS_API_URL
|
||||
from api.services.configuration.registry import ServiceProviders
|
||||
from pipecat.services.aws.llm import AWSBedrockLLMService, AWSBedrockLLMSettings
|
||||
from pipecat.services.azure.llm import AzureLLMService, AzureLLMSettings
|
||||
from pipecat.services.cartesia.stt import CartesiaSTTService
|
||||
from pipecat.services.cartesia.tts import CartesiaTTSService, CartesiaTTSSettings
|
||||
|
|
@ -268,56 +269,91 @@ def create_tts_service(user_config, audio_config: "AudioConfig"):
|
|||
)
|
||||
|
||||
|
||||
def create_llm_service(user_config):
|
||||
"""Create and return appropriate LLM service based on user configuration"""
|
||||
model = user_config.llm.model
|
||||
logger.info(
|
||||
f"Creating LLM service: provider={user_config.llm.provider}, model={model}"
|
||||
)
|
||||
if user_config.llm.provider == ServiceProviders.OPENAI.value:
|
||||
def create_llm_service_from_provider(
|
||||
provider: str,
|
||||
model: str,
|
||||
api_key: str,
|
||||
*,
|
||||
base_url: str | None = None,
|
||||
endpoint: str | None = None,
|
||||
aws_access_key: str | None = None,
|
||||
aws_secret_key: str | None = None,
|
||||
aws_region: str | None = None,
|
||||
):
|
||||
"""Create an LLM service from explicit provider/model/api_key.
|
||||
|
||||
Also used by create_llm_service which extracts these from user_config.
|
||||
"""
|
||||
logger.info(f"Creating LLM service: provider={provider}, model={model}")
|
||||
if provider == ServiceProviders.OPENAI.value:
|
||||
if "gpt-5" in model:
|
||||
return OpenAILLMService(
|
||||
api_key=user_config.llm.api_key,
|
||||
api_key=api_key,
|
||||
settings=OpenAILLMSettings(
|
||||
model=model,
|
||||
extra={"reasoning_effort": "minimal", "verbosity": "low"},
|
||||
),
|
||||
)
|
||||
else:
|
||||
return OpenAILLMService(
|
||||
api_key=user_config.llm.api_key,
|
||||
settings=OpenAILLMSettings(model=model, temperature=0.1),
|
||||
)
|
||||
elif user_config.llm.provider == ServiceProviders.GROQ.value:
|
||||
print(
|
||||
f"Creating Groq LLM service with API key: {user_config.llm.api_key} and model: {model}"
|
||||
return OpenAILLMService(
|
||||
api_key=api_key,
|
||||
settings=OpenAILLMSettings(model=model, temperature=0.1),
|
||||
)
|
||||
elif provider == ServiceProviders.GROQ.value:
|
||||
return GroqLLMService(
|
||||
api_key=user_config.llm.api_key,
|
||||
api_key=api_key,
|
||||
settings=GroqLLMSettings(model=model, temperature=0.1),
|
||||
)
|
||||
elif user_config.llm.provider == ServiceProviders.OPENROUTER.value:
|
||||
elif provider == ServiceProviders.OPENROUTER.value:
|
||||
kwargs = {}
|
||||
if base_url:
|
||||
kwargs["base_url"] = base_url
|
||||
return OpenRouterLLMService(
|
||||
api_key=user_config.llm.api_key,
|
||||
base_url=user_config.llm.base_url,
|
||||
api_key=api_key,
|
||||
settings=OpenRouterLLMSettings(model=model, temperature=0.1),
|
||||
**kwargs,
|
||||
)
|
||||
elif user_config.llm.provider == ServiceProviders.GOOGLE.value:
|
||||
elif provider == ServiceProviders.GOOGLE.value:
|
||||
return GoogleLLMService(
|
||||
api_key=user_config.llm.api_key,
|
||||
api_key=api_key,
|
||||
settings=GoogleLLMSettings(model=model, temperature=0.1),
|
||||
)
|
||||
elif user_config.llm.provider == ServiceProviders.AZURE.value:
|
||||
elif provider == ServiceProviders.AZURE.value:
|
||||
return AzureLLMService(
|
||||
api_key=user_config.llm.api_key,
|
||||
endpoint=user_config.llm.endpoint,
|
||||
api_key=api_key,
|
||||
endpoint=endpoint,
|
||||
settings=AzureLLMSettings(model=model, temperature=0.1),
|
||||
)
|
||||
elif user_config.llm.provider == ServiceProviders.DOGRAH.value:
|
||||
elif provider == ServiceProviders.DOGRAH.value:
|
||||
return DograhLLMService(
|
||||
base_url=f"{MPS_API_URL}/api/v1/llm",
|
||||
api_key=user_config.llm.api_key,
|
||||
api_key=api_key,
|
||||
settings=OpenAILLMSettings(model=model),
|
||||
)
|
||||
elif provider == ServiceProviders.AWS_BEDROCK.value:
|
||||
return AWSBedrockLLMService(
|
||||
aws_access_key=aws_access_key,
|
||||
aws_secret_key=aws_secret_key,
|
||||
aws_region=aws_region,
|
||||
settings=AWSBedrockLLMSettings(model=model),
|
||||
)
|
||||
else:
|
||||
raise HTTPException(status_code=400, detail="Invalid LLM provider")
|
||||
raise HTTPException(status_code=400, detail=f"Invalid LLM provider {provider}")
|
||||
|
||||
|
||||
def create_llm_service(user_config):
|
||||
"""Create and return appropriate LLM service based on user configuration."""
|
||||
provider = user_config.llm.provider
|
||||
model = user_config.llm.model
|
||||
api_key = user_config.llm.api_key
|
||||
|
||||
kwargs = {}
|
||||
if provider == ServiceProviders.OPENROUTER.value:
|
||||
kwargs["base_url"] = user_config.llm.base_url
|
||||
elif provider == ServiceProviders.AZURE.value:
|
||||
kwargs["endpoint"] = user_config.llm.endpoint
|
||||
elif provider == ServiceProviders.AWS_BEDROCK.value:
|
||||
kwargs["aws_access_key"] = user_config.llm.aws_access_key
|
||||
kwargs["aws_secret_key"] = user_config.llm.aws_secret_key
|
||||
kwargs["aws_region"] = user_config.llm.aws_region
|
||||
|
||||
return create_llm_service_from_provider(provider, model, api_key, **kwargs)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue