mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-06-07 07:55:16 +02:00
feat: add gemini live and speaches integration (#220)
* feat: add speaches models * feat: add gemini realtime and speaches integration - Add gemini realtime support - Add speaches support for locally hosted LLMs * chore: bump pipecat * feat: add language option * fix: add skip aggregator types to tts settings * fix: make API key optional for realtime
This commit is contained in:
parent
e0c3d6c3bf
commit
87e72d5f6f
19 changed files with 743 additions and 270 deletions
|
|
@ -67,6 +67,10 @@ def setup_logging():
|
|||
# Handler might already be removed
|
||||
pass
|
||||
|
||||
# Set default extra values on the shared core so ALL logger references
|
||||
# (including ones imported before this runs) have run_id available.
|
||||
loguru.logger.configure(extra={"run_id": None})
|
||||
|
||||
# Patch loguru to inject run_id
|
||||
patched = loguru.logger.patch(inject_run_id)
|
||||
|
||||
|
|
|
|||
|
|
@ -33,6 +33,7 @@ class DefaultConfigurationsResponse(TypedDict):
|
|||
tts: dict[str, dict]
|
||||
stt: dict[str, dict]
|
||||
embeddings: dict[str, dict]
|
||||
realtime: dict[str, dict]
|
||||
default_providers: dict[str, str]
|
||||
|
||||
|
||||
|
|
@ -55,6 +56,10 @@ async def get_default_configurations() -> DefaultConfigurationsResponse:
|
|||
provider: model_cls.model_json_schema()
|
||||
for provider, model_cls in REGISTRY[ServiceType.EMBEDDINGS].items()
|
||||
},
|
||||
"realtime": {
|
||||
provider: model_cls.model_json_schema()
|
||||
for provider, model_cls in REGISTRY[ServiceType.REALTIME].items()
|
||||
},
|
||||
"default_providers": DEFAULT_SERVICE_PROVIDERS,
|
||||
}
|
||||
return configurations
|
||||
|
|
@ -75,6 +80,8 @@ class UserConfigurationRequestResponseSchema(BaseModel):
|
|||
tts: dict[str, Union[str, float, list[str], None]] | None = None
|
||||
stt: dict[str, Union[str, float, list[str], None]] | None = None
|
||||
embeddings: dict[str, Union[str, float, list[str], None]] | None = None
|
||||
realtime: dict[str, Union[str, float, list[str], None]] | None = None
|
||||
is_realtime: bool | None = None
|
||||
test_phone_number: str | None = None
|
||||
timezone: str | None = None
|
||||
organization_pricing: dict[str, Union[float, str, bool]] | None = None
|
||||
|
|
|
|||
|
|
@ -1,10 +1,11 @@
|
|||
from datetime import datetime
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, model_validator
|
||||
|
||||
from api.services.configuration.registry import (
|
||||
EmbeddingsConfig,
|
||||
LLMConfig,
|
||||
RealtimeConfig,
|
||||
STTConfig,
|
||||
TTSConfig,
|
||||
)
|
||||
|
|
@ -15,6 +16,18 @@ class UserConfiguration(BaseModel):
|
|||
stt: STTConfig | None = None
|
||||
tts: TTSConfig | None = None
|
||||
embeddings: EmbeddingsConfig | None = None
|
||||
realtime: RealtimeConfig | None = None
|
||||
is_realtime: bool = False
|
||||
test_phone_number: str | None = None
|
||||
timezone: str | None = None
|
||||
last_validated_at: datetime | None = None
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def strip_incomplete_realtime_when_disabled(cls, data):
|
||||
"""Skip realtime validation when is_realtime is False and api_key is missing."""
|
||||
if isinstance(data, dict) and not data.get("is_realtime", False):
|
||||
realtime = data.get("realtime")
|
||||
if isinstance(realtime, dict) and not realtime.get("api_key"):
|
||||
data.pop("realtime", None)
|
||||
return data
|
||||
|
|
|
|||
|
|
@ -46,7 +46,9 @@ class UserConfigurationValidator:
|
|||
ServiceProviders.SPEECHMATICS.value: self._check_speechmatics_api_key,
|
||||
ServiceProviders.CAMB.value: self._check_camb_api_key,
|
||||
ServiceProviders.AWS_BEDROCK.value: self._check_aws_bedrock_api_key,
|
||||
ServiceProviders.SELF_HOSTED.value: self._check_self_hosted_api_key,
|
||||
ServiceProviders.SPEACHES.value: self._check_speaches_api_key,
|
||||
ServiceProviders.OPENAI_REALTIME.value: self._check_openai_api_key,
|
||||
ServiceProviders.GOOGLE_REALTIME.value: self._check_google_api_key,
|
||||
}
|
||||
|
||||
async def validate(
|
||||
|
|
@ -70,6 +72,13 @@ class UserConfigurationValidator:
|
|||
configuration.embeddings, "embeddings", required=False
|
||||
)
|
||||
)
|
||||
# Realtime is optional - only validate if is_realtime is enabled
|
||||
if configuration.is_realtime:
|
||||
status_list.extend(
|
||||
self._validate_service(
|
||||
configuration.realtime, "realtime", required=True
|
||||
)
|
||||
)
|
||||
|
||||
if status_list:
|
||||
raise ValueError(status_list)
|
||||
|
|
@ -90,10 +99,10 @@ class UserConfigurationValidator:
|
|||
|
||||
provider = service_config.provider
|
||||
|
||||
# Self-hosted doesn't require an API key
|
||||
if provider == ServiceProviders.SELF_HOSTED.value:
|
||||
# Speaches doesn't require an API key
|
||||
if provider == ServiceProviders.SPEACHES.value:
|
||||
try:
|
||||
if not self._check_self_hosted_api_key(provider, service_config):
|
||||
if not self._check_speaches_api_key(provider, service_config):
|
||||
return [
|
||||
{
|
||||
"model": service_name,
|
||||
|
|
@ -199,9 +208,9 @@ class UserConfigurationValidator:
|
|||
def _check_camb_api_key(self, model: str, api_key: str) -> bool:
|
||||
return True
|
||||
|
||||
def _check_self_hosted_api_key(self, model: str, service_config) -> bool:
|
||||
def _check_speaches_api_key(self, model: str, service_config) -> bool:
|
||||
if not getattr(service_config, "base_url", None):
|
||||
raise ValueError("base_url is required for self-hosted LLM")
|
||||
raise ValueError("base_url is required for Speaches services")
|
||||
return True
|
||||
|
||||
def _check_aws_bedrock_api_key(self, model: str, service_config) -> bool:
|
||||
|
|
|
|||
|
|
@ -29,7 +29,7 @@ def contains_masked_key(api_key: str | list[str] | None) -> bool:
|
|||
|
||||
def check_for_masked_keys(config: "UserConfiguration") -> None:
|
||||
"""Raise ValueError if any service in *config* still has a masked API key."""
|
||||
for field in ("llm", "tts", "stt", "embeddings"):
|
||||
for field in ("llm", "tts", "stt", "embeddings", "realtime"):
|
||||
service = getattr(config, field, None)
|
||||
if service is None:
|
||||
continue
|
||||
|
|
@ -121,6 +121,8 @@ def mask_user_config(config: UserConfiguration) -> Dict[str, Any]:
|
|||
"tts": _mask_service(config.tts),
|
||||
"stt": _mask_service(config.stt),
|
||||
"embeddings": _mask_service(config.embeddings),
|
||||
"realtime": _mask_service(config.realtime),
|
||||
"is_realtime": config.is_realtime,
|
||||
"test_phone_number": config.test_phone_number,
|
||||
"timezone": config.timezone,
|
||||
}
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ from typing import Dict
|
|||
from api.schemas.user_configuration import UserConfiguration
|
||||
from api.services.configuration.masking import resolve_masked_api_keys
|
||||
|
||||
SERVICE_FIELDS = ("llm", "tts", "stt", "embeddings")
|
||||
SERVICE_FIELDS = ("llm", "tts", "stt", "embeddings", "realtime")
|
||||
|
||||
|
||||
def merge_user_configurations(
|
||||
|
|
@ -64,6 +64,9 @@ def merge_user_configurations(
|
|||
_merge_service_block(service)
|
||||
|
||||
# other simple fields
|
||||
if "is_realtime" in incoming_partial:
|
||||
merged["is_realtime"] = incoming_partial["is_realtime"]
|
||||
|
||||
if "test_phone_number" in incoming_partial:
|
||||
merged["test_phone_number"] = incoming_partial["test_phone_number"]
|
||||
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ class ServiceType(Enum):
|
|||
TTS = auto()
|
||||
STT = auto()
|
||||
EMBEDDINGS = auto()
|
||||
REALTIME = auto()
|
||||
|
||||
|
||||
class ServiceProviders(str, Enum):
|
||||
|
|
@ -27,7 +28,9 @@ class ServiceProviders(str, Enum):
|
|||
SPEECHMATICS = "speechmatics"
|
||||
CAMB = "camb"
|
||||
AWS_BEDROCK = "aws_bedrock"
|
||||
SELF_HOSTED = "self_hosted"
|
||||
SPEACHES = "speaches"
|
||||
OPENAI_REALTIME = "openai_realtime"
|
||||
GOOGLE_REALTIME = "google_realtime"
|
||||
|
||||
|
||||
class BaseServiceConfiguration(BaseModel):
|
||||
|
|
@ -41,7 +44,9 @@ class BaseServiceConfiguration(BaseModel):
|
|||
ServiceProviders.AZURE,
|
||||
ServiceProviders.DOGRAH,
|
||||
ServiceProviders.AWS_BEDROCK,
|
||||
ServiceProviders.SELF_HOSTED,
|
||||
ServiceProviders.SPEACHES,
|
||||
ServiceProviders.OPENAI_REALTIME,
|
||||
ServiceProviders.GOOGLE_REALTIME,
|
||||
# ServiceProviders.SARVAM,
|
||||
]
|
||||
api_key: str | list[str]
|
||||
|
|
@ -97,6 +102,7 @@ REGISTRY: Dict[ServiceType, Dict[str, Type[BaseServiceConfiguration]]] = {
|
|||
ServiceType.TTS: {},
|
||||
ServiceType.STT: {},
|
||||
ServiceType.EMBEDDINGS: {},
|
||||
ServiceType.REALTIME: {},
|
||||
}
|
||||
|
||||
T = TypeVar("T", bound=BaseServiceConfiguration)
|
||||
|
|
@ -191,14 +197,18 @@ AWS_BEDROCK_MODELS = [
|
|||
@register_llm
|
||||
class OpenAILLMService(BaseLLMConfiguration):
|
||||
provider: Literal[ServiceProviders.OPENAI] = ServiceProviders.OPENAI
|
||||
model: str = Field(default="gpt-4.1", json_schema_extra={"examples": OPENAI_MODELS})
|
||||
model: str = Field(
|
||||
default="gpt-4.1",
|
||||
json_schema_extra={"examples": OPENAI_MODELS, "allow_custom_input": True},
|
||||
)
|
||||
|
||||
|
||||
@register_llm
|
||||
class GoogleLLMService(BaseLLMConfiguration):
|
||||
provider: Literal[ServiceProviders.GOOGLE] = ServiceProviders.GOOGLE
|
||||
model: str = Field(
|
||||
default="gemini-2.0-flash", json_schema_extra={"examples": GOOGLE_MODELS}
|
||||
default="gemini-2.0-flash",
|
||||
json_schema_extra={"examples": GOOGLE_MODELS, "allow_custom_input": True},
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -206,7 +216,8 @@ class GoogleLLMService(BaseLLMConfiguration):
|
|||
class GroqLLMService(BaseLLMConfiguration):
|
||||
provider: Literal[ServiceProviders.GROQ] = ServiceProviders.GROQ
|
||||
model: str = Field(
|
||||
default="llama-3.3-70b-versatile", json_schema_extra={"examples": GROQ_MODELS}
|
||||
default="llama-3.3-70b-versatile",
|
||||
json_schema_extra={"examples": GROQ_MODELS, "allow_custom_input": True},
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -214,7 +225,8 @@ class GroqLLMService(BaseLLMConfiguration):
|
|||
class OpenRouterLLMConfiguration(BaseLLMConfiguration):
|
||||
provider: Literal[ServiceProviders.OPENROUTER] = ServiceProviders.OPENROUTER
|
||||
model: str = Field(
|
||||
default="openai/gpt-4.1", json_schema_extra={"examples": OPENROUTER_MODELS}
|
||||
default="openai/gpt-4.1",
|
||||
json_schema_extra={"examples": OPENROUTER_MODELS, "allow_custom_input": True},
|
||||
)
|
||||
|
||||
base_url: str = Field(default="https://openrouter.ai/api/v1")
|
||||
|
|
@ -224,7 +236,8 @@ class OpenRouterLLMConfiguration(BaseLLMConfiguration):
|
|||
class AzureLLMService(BaseLLMConfiguration):
|
||||
provider: Literal[ServiceProviders.AZURE] = ServiceProviders.AZURE
|
||||
model: str = Field(
|
||||
default="gpt-4.1-mini", json_schema_extra={"examples": AZURE_MODELS}
|
||||
default="gpt-4.1-mini",
|
||||
json_schema_extra={"examples": AZURE_MODELS, "allow_custom_input": True},
|
||||
)
|
||||
|
||||
endpoint: str
|
||||
|
|
@ -234,7 +247,8 @@ class AzureLLMService(BaseLLMConfiguration):
|
|||
class DograhLLMService(BaseLLMConfiguration):
|
||||
provider: Literal[ServiceProviders.DOGRAH] = ServiceProviders.DOGRAH
|
||||
model: str = Field(
|
||||
default="default", json_schema_extra={"examples": DOGRAH_LLM_MODELS}
|
||||
default="default",
|
||||
json_schema_extra={"examples": DOGRAH_LLM_MODELS, "allow_custom_input": True},
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -243,7 +257,7 @@ class AWSBedrockLLMConfiguration(BaseLLMConfiguration):
|
|||
provider: Literal[ServiceProviders.AWS_BEDROCK] = ServiceProviders.AWS_BEDROCK
|
||||
model: str = Field(
|
||||
default="us.amazon.nova-pro-v1:0",
|
||||
json_schema_extra={"examples": AWS_BEDROCK_MODELS},
|
||||
json_schema_extra={"examples": AWS_BEDROCK_MODELS, "allow_custom_input": True},
|
||||
)
|
||||
aws_access_key: str = Field(default="")
|
||||
aws_secret_key: str = Field(default="")
|
||||
|
|
@ -251,14 +265,18 @@ class AWSBedrockLLMConfiguration(BaseLLMConfiguration):
|
|||
api_key: str | list[str] | None = Field(default=None)
|
||||
|
||||
|
||||
SELF_HOSTED_LLM_MODELS = ["llama3", "mistral", "phi3", "qwen2", "gemma2", "deepseek-r1"]
|
||||
SPEACHES_LLM_MODELS = ["llama3", "mistral", "phi3", "qwen2", "gemma2", "deepseek-r1"]
|
||||
|
||||
|
||||
@register_llm
|
||||
class SelfHostedLLMConfiguration(BaseLLMConfiguration):
|
||||
provider: Literal[ServiceProviders.SELF_HOSTED] = ServiceProviders.SELF_HOSTED
|
||||
class SpeachesLLMConfiguration(BaseLLMConfiguration):
|
||||
provider: Literal[ServiceProviders.SPEACHES] = ServiceProviders.SPEACHES
|
||||
model: str = Field(
|
||||
default="llama3", json_schema_extra={"examples": SELF_HOSTED_LLM_MODELS}
|
||||
default="llama3",
|
||||
json_schema_extra={
|
||||
"examples": SPEACHES_LLM_MODELS,
|
||||
"allow_custom_input": True,
|
||||
},
|
||||
)
|
||||
base_url: str = Field(
|
||||
default="http://localhost:11434/v1",
|
||||
|
|
@ -267,6 +285,78 @@ class SelfHostedLLMConfiguration(BaseLLMConfiguration):
|
|||
api_key: str | list[str] | None = Field(default=None)
|
||||
|
||||
|
||||
OPENAI_REALTIME_MODELS = ["gpt-4o-realtime-preview", "gpt-4o-mini-realtime-preview"]
|
||||
OPENAI_REALTIME_VOICES = [
|
||||
"alloy",
|
||||
"ash",
|
||||
"ballad",
|
||||
"coral",
|
||||
"echo",
|
||||
"sage",
|
||||
"shimmer",
|
||||
"verse",
|
||||
]
|
||||
|
||||
|
||||
# @register_service(ServiceType.REALTIME)
|
||||
# class OpenAIRealtimeLLMConfiguration(BaseLLMConfiguration):
|
||||
# provider: Literal[ServiceProviders.OPENAI_REALTIME] = (
|
||||
# ServiceProviders.OPENAI_REALTIME
|
||||
# )
|
||||
# model: str = Field(
|
||||
# default="gpt-4o-realtime-preview",
|
||||
# json_schema_extra={
|
||||
# "examples": OPENAI_REALTIME_MODELS,
|
||||
# "allow_custom_input": True,
|
||||
# },
|
||||
# )
|
||||
# voice: str = Field(
|
||||
# default="alloy",
|
||||
# json_schema_extra={"examples": OPENAI_REALTIME_VOICES},
|
||||
# )
|
||||
|
||||
|
||||
GOOGLE_REALTIME_MODELS = ["gemini-3.1-flash-live-preview"]
|
||||
GOOGLE_REALTIME_VOICES = ["Puck", "Charon", "Kore", "Fenrir", "Aoede"]
|
||||
GOOGLE_REALTIME_LANGUAGES = [
|
||||
"en"
|
||||
]
|
||||
|
||||
|
||||
@register_service(ServiceType.REALTIME)
|
||||
class GoogleRealtimeLLMConfiguration(BaseLLMConfiguration):
|
||||
provider: Literal[ServiceProviders.GOOGLE_REALTIME] = (
|
||||
ServiceProviders.GOOGLE_REALTIME
|
||||
)
|
||||
model: str = Field(
|
||||
default="gemini-3.1-flash-live-preview",
|
||||
json_schema_extra={
|
||||
"examples": GOOGLE_REALTIME_MODELS,
|
||||
"allow_custom_input": True,
|
||||
},
|
||||
)
|
||||
voice: str = Field(
|
||||
default="Puck",
|
||||
json_schema_extra={
|
||||
"examples": GOOGLE_REALTIME_VOICES,
|
||||
"allow_custom_input": True,
|
||||
},
|
||||
)
|
||||
language: str = Field(
|
||||
default="en",
|
||||
json_schema_extra={
|
||||
"examples": GOOGLE_REALTIME_LANGUAGES,
|
||||
"allow_custom_input": True,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
REALTIME_PROVIDERS = {
|
||||
ServiceProviders.OPENAI_REALTIME.value,
|
||||
ServiceProviders.GOOGLE_REALTIME.value,
|
||||
}
|
||||
|
||||
|
||||
LLMConfig = Annotated[
|
||||
Union[
|
||||
OpenAILLMService,
|
||||
|
|
@ -276,7 +366,15 @@ LLMConfig = Annotated[
|
|||
AzureLLMService,
|
||||
DograhLLMService,
|
||||
AWSBedrockLLMConfiguration,
|
||||
SelfHostedLLMConfiguration,
|
||||
SpeachesLLMConfiguration,
|
||||
],
|
||||
Field(discriminator="provider"),
|
||||
]
|
||||
|
||||
RealtimeConfig = Annotated[
|
||||
Union[
|
||||
# OpenAIRealtimeLLMConfiguration,
|
||||
GoogleRealtimeLLMConfiguration,
|
||||
],
|
||||
Field(discriminator="provider"),
|
||||
]
|
||||
|
|
@ -462,6 +560,34 @@ class CambTTSConfiguration(BaseTTSConfiguration):
|
|||
language: str = Field(default="en-us", description="BCP-47 language code")
|
||||
|
||||
|
||||
SPEACHES_TTS_MODELS = ["hexgrad/Kokoro-82M"]
|
||||
|
||||
|
||||
@register_tts
|
||||
class SpeachesTTSConfiguration(BaseTTSConfiguration):
|
||||
provider: Literal[ServiceProviders.SPEACHES] = ServiceProviders.SPEACHES
|
||||
model: str = Field(
|
||||
default="kokoro",
|
||||
json_schema_extra={
|
||||
"examples": SPEACHES_TTS_MODELS,
|
||||
"allow_custom_input": True,
|
||||
},
|
||||
)
|
||||
voice: str = Field(
|
||||
default="af_heart",
|
||||
json_schema_extra={"allow_custom_input": True},
|
||||
description="Voice ID for the TTS engine",
|
||||
)
|
||||
base_url: str = Field(
|
||||
default="http://localhost:8000/v1",
|
||||
description="OpenAI-compatible TTS endpoint (Kokoro-FastAPI, etc.)",
|
||||
)
|
||||
speed: float = Field(
|
||||
default=1.0, ge=0.25, le=4.0, description="Speech speed (0.25 to 4.0)"
|
||||
)
|
||||
api_key: str | list[str] | None = Field(default=None)
|
||||
|
||||
|
||||
TTSConfig = Annotated[
|
||||
Union[
|
||||
DeepgramTTSConfiguration,
|
||||
|
|
@ -471,6 +597,7 @@ TTSConfig = Annotated[
|
|||
DograhTTSService,
|
||||
SarvamTTSConfiguration,
|
||||
CambTTSConfiguration,
|
||||
SpeachesTTSConfiguration,
|
||||
],
|
||||
Field(discriminator="provider"),
|
||||
]
|
||||
|
|
@ -674,6 +801,37 @@ class SpeechmaticsSTTConfiguration(BaseSTTConfiguration):
|
|||
)
|
||||
|
||||
|
||||
SPEACHES_STT_MODELS = [
|
||||
"Systran/faster-distil-whisper-small.en",
|
||||
"Systran/faster-whisper-large-v3",
|
||||
]
|
||||
SPEACHES_STT_LANGUAGES = ["en", "ar", "nl", "fr", "de", "hi", "it", "pt", "es"]
|
||||
|
||||
|
||||
@register_stt
|
||||
class SpeachesSTTConfiguration(BaseSTTConfiguration):
|
||||
provider: Literal[ServiceProviders.SPEACHES] = ServiceProviders.SPEACHES
|
||||
model: str = Field(
|
||||
default="Systran/faster-distil-whisper-small.en",
|
||||
json_schema_extra={
|
||||
"examples": SPEACHES_STT_MODELS,
|
||||
"allow_custom_input": True,
|
||||
},
|
||||
)
|
||||
language: str = Field(
|
||||
default="en",
|
||||
json_schema_extra={
|
||||
"examples": SPEACHES_STT_LANGUAGES,
|
||||
"allow_custom_input": True,
|
||||
},
|
||||
)
|
||||
base_url: str = Field(
|
||||
default="http://localhost:8000/v1",
|
||||
description="OpenAI-compatible STT endpoint (Speaches, etc.)",
|
||||
)
|
||||
api_key: str | list[str] | None = Field(default=None)
|
||||
|
||||
|
||||
STTConfig = Annotated[
|
||||
Union[
|
||||
DeepgramSTTConfiguration,
|
||||
|
|
@ -682,6 +840,7 @@ STTConfig = Annotated[
|
|||
DograhSTTService,
|
||||
SpeechmaticsSTTConfiguration,
|
||||
SarvamSTTConfiguration,
|
||||
SpeachesSTTConfiguration,
|
||||
],
|
||||
Field(discriminator="provider"),
|
||||
]
|
||||
|
|
@ -720,6 +879,6 @@ EmbeddingsConfig = Annotated[
|
|||
]
|
||||
|
||||
ServiceConfig = Annotated[
|
||||
Union[LLMConfig, TTSConfig, STTConfig, EmbeddingsConfig],
|
||||
Union[LLMConfig, RealtimeConfig, TTSConfig, STTConfig, EmbeddingsConfig],
|
||||
Field(discriminator="provider"),
|
||||
]
|
||||
|
|
|
|||
|
|
@ -97,6 +97,34 @@ def build_pipeline(
|
|||
return Pipeline(processors)
|
||||
|
||||
|
||||
def build_realtime_pipeline(
|
||||
transport,
|
||||
realtime_llm,
|
||||
audio_buffer,
|
||||
user_context_aggregator,
|
||||
assistant_context_aggregator,
|
||||
pipeline_engine_callback_processor,
|
||||
pipeline_metrics_aggregator,
|
||||
):
|
||||
"""Build a pipeline for realtime (speech-to-speech) LLM services.
|
||||
|
||||
Realtime services (e.g. OpenAI Realtime, Gemini Live) handle STT+LLM+TTS
|
||||
internally, so no separate STT or TTS processors are needed.
|
||||
"""
|
||||
processors = [
|
||||
transport.input(),
|
||||
user_context_aggregator,
|
||||
realtime_llm,
|
||||
pipeline_engine_callback_processor,
|
||||
transport.output(),
|
||||
audio_buffer,
|
||||
assistant_context_aggregator,
|
||||
pipeline_metrics_aggregator,
|
||||
]
|
||||
|
||||
return Pipeline(processors)
|
||||
|
||||
|
||||
def create_pipeline_task(pipeline, workflow_run_id, audio_config: AudioConfig = None):
|
||||
"""Create a pipeline task with appropriate parameters"""
|
||||
# Set up pipeline params with audio configuration if provided
|
||||
|
|
|
|||
|
|
@ -16,6 +16,7 @@ from api.services.pipecat.event_handlers import (
|
|||
from api.services.pipecat.in_memory_buffers import InMemoryLogsBuffer
|
||||
from api.services.pipecat.pipeline_builder import (
|
||||
build_pipeline,
|
||||
build_realtime_pipeline,
|
||||
create_pipeline_components,
|
||||
create_pipeline_task,
|
||||
)
|
||||
|
|
@ -35,6 +36,7 @@ from api.services.pipecat.recording_router_processor import RecordingRouterProce
|
|||
from api.services.pipecat.service_factory import (
|
||||
create_llm_service,
|
||||
create_llm_service_from_provider,
|
||||
create_realtime_llm_service,
|
||||
create_stt_service,
|
||||
create_tts_service,
|
||||
)
|
||||
|
|
@ -603,10 +605,18 @@ async def _run_pipeline(
|
|||
term.strip() for term in dictionary.split(",") if term.strip()
|
||||
]
|
||||
|
||||
# Detect realtime mode (speech-to-speech services like OpenAI Realtime, Gemini Live)
|
||||
is_realtime = user_config.is_realtime and user_config.realtime is not None
|
||||
|
||||
# Create services based on user configuration
|
||||
stt = create_stt_service(user_config, audio_config, keyterms=keyterms)
|
||||
tts = create_tts_service(user_config, audio_config)
|
||||
llm = create_llm_service(user_config)
|
||||
if is_realtime:
|
||||
llm = create_realtime_llm_service(user_config, audio_config)
|
||||
stt = None
|
||||
tts = None
|
||||
else:
|
||||
stt = create_stt_service(user_config, audio_config, keyterms=keyterms)
|
||||
tts = create_tts_service(user_config, audio_config)
|
||||
llm = create_llm_service(user_config)
|
||||
|
||||
workflow_graph = WorkflowGraph(
|
||||
ReactFlowDTO.model_validate(workflow.workflow_definition_with_fallback)
|
||||
|
|
@ -694,46 +704,66 @@ async def _run_pipeline(
|
|||
)
|
||||
|
||||
# Configure turn strategies based on STT provider, model, and workflow configuration
|
||||
# Deepgram Flux uses external turn detection (VAD + External start/stop)
|
||||
# Other models use configurable turn detection strategy
|
||||
is_deepgram_flux = (
|
||||
user_config.stt.provider == ServiceProviders.DEEPGRAM.value
|
||||
and user_config.stt.model == "flux-general-en"
|
||||
)
|
||||
if is_realtime:
|
||||
# Realtime services have server-side VAD/turn detection.
|
||||
# For stop strategy, lets rely on SmartTurnAnalyzer which is
|
||||
# enabled by default
|
||||
user_turn_strategies = UserTurnStrategies(
|
||||
start=[VADUserTurnStartStrategy()], stop=[]
|
||||
)
|
||||
|
||||
if is_deepgram_flux:
|
||||
user_turn_strategies = UserTurnStrategies(
|
||||
start=[
|
||||
VADUserTurnStartStrategy(),
|
||||
ExternalUserTurnStartStrategy(enable_interruptions=True),
|
||||
],
|
||||
stop=[ExternalUserTurnStopStrategy()],
|
||||
)
|
||||
elif turn_stop_strategy == "turn_analyzer":
|
||||
# Smart Turn Analyzer: best for longer responses with natural pauses
|
||||
smart_turn_params = SmartTurnParams(stop_secs=smart_turn_stop_secs)
|
||||
user_turn_strategies = UserTurnStrategies(
|
||||
start=[VADUserTurnStartStrategy(), TranscriptionUserTurnStartStrategy()],
|
||||
stop=[
|
||||
TurnAnalyzerUserTurnStopStrategy(
|
||||
turn_analyzer=LocalSmartTurnAnalyzerV3(params=smart_turn_params)
|
||||
)
|
||||
],
|
||||
)
|
||||
# Lets not start the pipeline as muted for Realtime
|
||||
# - CallbackUserMuteStrategy: mutes based on engine's _mute_pipeline state
|
||||
user_mute_strategies = [
|
||||
FunctionCallUserMuteStrategy(),
|
||||
CallbackUserMuteStrategy(should_mute_callback=engine.should_mute_user),
|
||||
]
|
||||
else:
|
||||
# Transcription-based (default): best for short 1-2 word responses
|
||||
user_turn_strategies = UserTurnStrategies(
|
||||
start=[VADUserTurnStartStrategy(), TranscriptionUserTurnStartStrategy()],
|
||||
stop=[SpeechTimeoutUserTurnStopStrategy()],
|
||||
# Deepgram Flux uses external turn detection (VAD + External start/stop)
|
||||
# Other models use configurable turn detection strategy
|
||||
is_deepgram_flux = (
|
||||
user_config.stt.provider == ServiceProviders.DEEPGRAM.value
|
||||
and user_config.stt.model == "flux-general-en"
|
||||
)
|
||||
|
||||
# Create user mute strategies
|
||||
# - CallbackUserMuteStrategy: mutes based on engine's _mute_pipeline state
|
||||
user_mute_strategies = [
|
||||
MuteUntilFirstBotCompleteUserMuteStrategy(),
|
||||
FunctionCallUserMuteStrategy(),
|
||||
CallbackUserMuteStrategy(should_mute_callback=engine.should_mute_user),
|
||||
]
|
||||
if is_deepgram_flux:
|
||||
user_turn_strategies = UserTurnStrategies(
|
||||
start=[
|
||||
VADUserTurnStartStrategy(),
|
||||
ExternalUserTurnStartStrategy(enable_interruptions=True),
|
||||
],
|
||||
stop=[ExternalUserTurnStopStrategy()],
|
||||
)
|
||||
elif turn_stop_strategy == "turn_analyzer":
|
||||
# Smart Turn Analyzer: best for longer responses with natural pauses
|
||||
smart_turn_params = SmartTurnParams(stop_secs=smart_turn_stop_secs)
|
||||
user_turn_strategies = UserTurnStrategies(
|
||||
start=[
|
||||
VADUserTurnStartStrategy(),
|
||||
TranscriptionUserTurnStartStrategy(),
|
||||
],
|
||||
stop=[
|
||||
TurnAnalyzerUserTurnStopStrategy(
|
||||
turn_analyzer=LocalSmartTurnAnalyzerV3(params=smart_turn_params)
|
||||
)
|
||||
],
|
||||
)
|
||||
else:
|
||||
# Transcription-based (default): best for short 1-2 word responses
|
||||
user_turn_strategies = UserTurnStrategies(
|
||||
start=[
|
||||
VADUserTurnStartStrategy(),
|
||||
TranscriptionUserTurnStartStrategy(),
|
||||
],
|
||||
stop=[SpeechTimeoutUserTurnStopStrategy()],
|
||||
)
|
||||
|
||||
# - CallbackUserMuteStrategy: mutes based on engine's _mute_pipeline state
|
||||
user_mute_strategies = [
|
||||
MuteUntilFirstBotCompleteUserMuteStrategy(),
|
||||
FunctionCallUserMuteStrategy(),
|
||||
CallbackUserMuteStrategy(should_mute_callback=engine.should_mute_user),
|
||||
]
|
||||
|
||||
user_params = LLMUserAggregatorParams(
|
||||
user_turn_strategies=user_turn_strategies,
|
||||
|
|
@ -769,77 +799,93 @@ async def _run_pipeline(
|
|||
async def on_user_turn_started(aggregator, strategy):
|
||||
user_idle_handler.reset()
|
||||
|
||||
# Create voicemail detector if enabled in workflow configurations
|
||||
# Voicemail detection and recording router are not supported in realtime mode
|
||||
voicemail_detector = None
|
||||
voicemail_config = (workflow.workflow_configurations or {}).get(
|
||||
"voicemail_detection", {}
|
||||
)
|
||||
if voicemail_config.get("enabled", False):
|
||||
logger.info(f"Voicemail detection enabled for workflow run {workflow_run_id}")
|
||||
# Create a separate LLM instance for the voicemail sub-pipeline
|
||||
# (can't share with main pipeline as it would mess up frame linking)
|
||||
if voicemail_config.get("use_workflow_llm", True):
|
||||
voicemail_llm = create_llm_service(user_config)
|
||||
else:
|
||||
voicemail_llm = create_llm_service_from_provider(
|
||||
provider=voicemail_config.get("provider", "openai"),
|
||||
model=voicemail_config.get("model", "gpt-4.1"),
|
||||
api_key=voicemail_config.get("api_key", ""),
|
||||
)
|
||||
|
||||
long_speech_timeout = voicemail_config.get("long_speech_timeout", 8.0)
|
||||
custom_system_prompt = voicemail_config.get("system_prompt") or None
|
||||
|
||||
voicemail_detector = VoicemailDetector(
|
||||
llm=voicemail_llm,
|
||||
long_speech_timeout=long_speech_timeout,
|
||||
custom_system_prompt=custom_system_prompt,
|
||||
)
|
||||
|
||||
# Register event handler to end task when voicemail is detected
|
||||
@voicemail_detector.event_handler("on_voicemail_detected")
|
||||
async def _on_voicemail_detected(_processor):
|
||||
logger.info(f"Voicemail detected for workflow run {workflow_run_id}")
|
||||
await engine.end_call_with_reason(
|
||||
reason=EndTaskReason.VOICEMAIL_DETECTED.value,
|
||||
abort_immediately=True,
|
||||
)
|
||||
|
||||
# Create recording router if workflow has active recordings
|
||||
recording_router = None
|
||||
if has_recordings:
|
||||
fetch_audio = create_recording_audio_fetcher(
|
||||
organization_id=workflow.organization_id,
|
||||
pipeline_sample_rate=audio_config.pipeline_sample_rate,
|
||||
|
||||
if not is_realtime:
|
||||
# Create voicemail detector if enabled in workflow configurations
|
||||
voicemail_config = (workflow.workflow_configurations or {}).get(
|
||||
"voicemail_detection", {}
|
||||
)
|
||||
recording_router = RecordingRouterProcessor(
|
||||
audio_sample_rate=audio_config.pipeline_sample_rate,
|
||||
fetch_recording_audio=fetch_audio,
|
||||
)
|
||||
# Warm the recording cache in the background so audio is ready
|
||||
# before the first playback request.
|
||||
asyncio.create_task(
|
||||
warm_recording_cache(
|
||||
workflow_id=workflow_id,
|
||||
if voicemail_config.get("enabled", False):
|
||||
logger.info(
|
||||
f"Voicemail detection enabled for workflow run {workflow_run_id}"
|
||||
)
|
||||
# Create a separate LLM instance for the voicemail sub-pipeline
|
||||
# (can't share with main pipeline as it would mess up frame linking)
|
||||
if voicemail_config.get("use_workflow_llm", True):
|
||||
voicemail_llm = create_llm_service(user_config)
|
||||
else:
|
||||
voicemail_llm = create_llm_service_from_provider(
|
||||
provider=voicemail_config.get("provider", "openai"),
|
||||
model=voicemail_config.get("model", "gpt-4.1"),
|
||||
api_key=voicemail_config.get("api_key", ""),
|
||||
)
|
||||
|
||||
long_speech_timeout = voicemail_config.get("long_speech_timeout", 8.0)
|
||||
custom_system_prompt = voicemail_config.get("system_prompt") or None
|
||||
|
||||
voicemail_detector = VoicemailDetector(
|
||||
llm=voicemail_llm,
|
||||
long_speech_timeout=long_speech_timeout,
|
||||
custom_system_prompt=custom_system_prompt,
|
||||
)
|
||||
|
||||
# Register event handler to end task when voicemail is detected
|
||||
@voicemail_detector.event_handler("on_voicemail_detected")
|
||||
async def _on_voicemail_detected(_processor):
|
||||
logger.info(f"Voicemail detected for workflow run {workflow_run_id}")
|
||||
await engine.end_call_with_reason(
|
||||
reason=EndTaskReason.VOICEMAIL_DETECTED.value,
|
||||
abort_immediately=True,
|
||||
)
|
||||
|
||||
# Create recording router if workflow has active recordings
|
||||
if has_recordings:
|
||||
fetch_audio = create_recording_audio_fetcher(
|
||||
organization_id=workflow.organization_id,
|
||||
pipeline_sample_rate=audio_config.pipeline_sample_rate,
|
||||
)
|
||||
)
|
||||
recording_router = RecordingRouterProcessor(
|
||||
audio_sample_rate=audio_config.pipeline_sample_rate,
|
||||
fetch_recording_audio=fetch_audio,
|
||||
)
|
||||
# Warm the recording cache in the background so audio is ready
|
||||
# before the first playback request.
|
||||
asyncio.create_task(
|
||||
warm_recording_cache(
|
||||
workflow_id=workflow_id,
|
||||
organization_id=workflow.organization_id,
|
||||
pipeline_sample_rate=audio_config.pipeline_sample_rate,
|
||||
)
|
||||
)
|
||||
|
||||
# Build the pipeline with the STT mute filter and context controller
|
||||
pipeline = build_pipeline(
|
||||
transport,
|
||||
stt,
|
||||
audio_buffer,
|
||||
llm,
|
||||
tts,
|
||||
user_context_aggregator,
|
||||
assistant_context_aggregator,
|
||||
pipeline_engine_callback_processor,
|
||||
pipeline_metrics_aggregator,
|
||||
voicemail_detector=voicemail_detector,
|
||||
recording_router=recording_router,
|
||||
)
|
||||
# Build the pipeline
|
||||
if is_realtime:
|
||||
pipeline = build_realtime_pipeline(
|
||||
transport,
|
||||
llm,
|
||||
audio_buffer,
|
||||
user_context_aggregator,
|
||||
assistant_context_aggregator,
|
||||
pipeline_engine_callback_processor,
|
||||
pipeline_metrics_aggregator,
|
||||
)
|
||||
else:
|
||||
pipeline = build_pipeline(
|
||||
transport,
|
||||
stt,
|
||||
audio_buffer,
|
||||
llm,
|
||||
tts,
|
||||
user_context_aggregator,
|
||||
assistant_context_aggregator,
|
||||
pipeline_engine_callback_processor,
|
||||
pipeline_metrics_aggregator,
|
||||
voicemail_detector=voicemail_detector,
|
||||
recording_router=recording_router,
|
||||
)
|
||||
|
||||
# Create pipeline task with audio configuration
|
||||
task = create_pipeline_task(pipeline, workflow_run_id, audio_config)
|
||||
|
|
@ -847,7 +893,8 @@ async def _run_pipeline(
|
|||
# Now set the task on the engine
|
||||
engine.set_task(task)
|
||||
|
||||
# Initialize the engine to set the initial context
|
||||
# Initialize the engine to set the initial context with
|
||||
# System Prompt and Tools
|
||||
await engine.initialize()
|
||||
|
||||
# Add real-time feedback observer (always logs to buffer, streams to WS if available)
|
||||
|
|
|
|||
|
|
@ -27,11 +27,17 @@ from pipecat.services.google.llm import GoogleLLMService, GoogleLLMSettings
|
|||
from pipecat.services.groq.llm import GroqLLMService, GroqLLMSettings
|
||||
from pipecat.services.openai.base_llm import OpenAILLMSettings
|
||||
from pipecat.services.openai.llm import OpenAILLMService
|
||||
from pipecat.services.openai.stt import OpenAISTTService, OpenAISTTSettings
|
||||
from pipecat.services.openai.stt import (
|
||||
OpenAISTTService,
|
||||
OpenAISTTSettings,
|
||||
)
|
||||
from pipecat.services.openai.tts import OpenAITTSService, OpenAITTSSettings
|
||||
from pipecat.services.openrouter.llm import OpenRouterLLMService, OpenRouterLLMSettings
|
||||
from pipecat.services.sarvam.stt import SarvamSTTService, SarvamSTTSettings
|
||||
from pipecat.services.sarvam.tts import SarvamTTSService, SarvamTTSSettings
|
||||
from pipecat.services.speaches.llm import SpeachesLLMService, SpeachesLLMSettings
|
||||
from pipecat.services.speaches.stt import SpeachesSTTService, SpeachesSTTSettings
|
||||
from pipecat.services.speaches.tts import SpeachesTTSService, SpeachesTTSSettings
|
||||
from pipecat.services.speechmatics.stt import (
|
||||
SpeechmaticsSTTService,
|
||||
SpeechmaticsSTTSettings,
|
||||
|
|
@ -58,7 +64,6 @@ def create_stt_service(
|
|||
if user_config.stt.provider == ServiceProviders.DEEPGRAM.value:
|
||||
# Check if using Flux model (English-only, no language selection)
|
||||
if user_config.stt.model == "flux-general-en":
|
||||
logger.debug("Using DeepGram Flux Model")
|
||||
return DeepgramFluxSTTService(
|
||||
api_key=user_config.stt.api_key,
|
||||
settings=DeepgramFluxSTTSettings(
|
||||
|
|
@ -137,6 +142,20 @@ def create_stt_service(
|
|||
),
|
||||
sample_rate=audio_config.transport_in_sample_rate,
|
||||
)
|
||||
elif user_config.stt.provider == ServiceProviders.SPEACHES.value:
|
||||
base_url = user_config.stt.base_url.replace("http://", "ws://").replace(
|
||||
"https://", "wss://"
|
||||
)
|
||||
language = getattr(user_config.stt, "language", None) or "multi"
|
||||
return SpeachesSTTService(
|
||||
base_url=base_url,
|
||||
api_key=user_config.stt.api_key or "none",
|
||||
settings=SpeachesSTTSettings(
|
||||
model=user_config.stt.model,
|
||||
language=language,
|
||||
),
|
||||
sample_rate=audio_config.transport_in_sample_rate,
|
||||
)
|
||||
elif user_config.stt.provider == ServiceProviders.SPEECHMATICS.value:
|
||||
from pipecat.services.speechmatics.stt import (
|
||||
AdditionalVocabEntry,
|
||||
|
|
@ -186,6 +205,7 @@ def create_tts_service(user_config, audio_config: "AudioConfig"):
|
|||
api_key=user_config.tts.api_key,
|
||||
settings=DeepgramTTSSettings(voice=user_config.tts.voice),
|
||||
text_filters=[xml_function_tag_filter],
|
||||
skip_aggregator_types=["recording_router"],
|
||||
silence_time_s=1.0,
|
||||
)
|
||||
elif user_config.tts.provider == ServiceProviders.OPENAI.value:
|
||||
|
|
@ -193,6 +213,7 @@ def create_tts_service(user_config, audio_config: "AudioConfig"):
|
|||
api_key=user_config.tts.api_key,
|
||||
settings=OpenAITTSSettings(model=user_config.tts.model),
|
||||
text_filters=[xml_function_tag_filter],
|
||||
skip_aggregator_types=["recording_router"],
|
||||
silence_time_s=1.0,
|
||||
)
|
||||
elif user_config.tts.provider == ServiceProviders.ELEVENLABS.value:
|
||||
|
|
@ -212,6 +233,7 @@ def create_tts_service(user_config, audio_config: "AudioConfig"):
|
|||
similarity_boost=0.75,
|
||||
),
|
||||
text_filters=[xml_function_tag_filter],
|
||||
skip_aggregator_types=["recording_router"],
|
||||
silence_time_s=1.0,
|
||||
)
|
||||
elif user_config.tts.provider == ServiceProviders.CARTESIA.value:
|
||||
|
|
@ -231,6 +253,7 @@ def create_tts_service(user_config, audio_config: "AudioConfig"):
|
|||
),
|
||||
),
|
||||
text_filters=[xml_function_tag_filter],
|
||||
skip_aggregator_types=["recording_router"],
|
||||
silence_time_s=1.0,
|
||||
)
|
||||
elif user_config.tts.provider == ServiceProviders.DOGRAH.value:
|
||||
|
|
@ -245,6 +268,7 @@ def create_tts_service(user_config, audio_config: "AudioConfig"):
|
|||
speed=user_config.tts.speed,
|
||||
),
|
||||
text_filters=[xml_function_tag_filter],
|
||||
skip_aggregator_types=["recording_router"],
|
||||
silence_time_s=1.0,
|
||||
)
|
||||
elif user_config.tts.provider == ServiceProviders.CAMB.value:
|
||||
|
|
@ -257,10 +281,24 @@ def create_tts_service(user_config, audio_config: "AudioConfig"):
|
|||
voice_id=voice_id,
|
||||
model=user_config.tts.model,
|
||||
text_filters=[xml_function_tag_filter],
|
||||
skip_aggregator_types=["recording_router"],
|
||||
)
|
||||
# Set language directly as BCP-47 code (bypasses Language enum conversion)
|
||||
tts._settings.language = language
|
||||
return tts
|
||||
elif user_config.tts.provider == ServiceProviders.SPEACHES.value:
|
||||
return SpeachesTTSService(
|
||||
base_url=user_config.tts.base_url,
|
||||
api_key=user_config.tts.api_key or "none",
|
||||
settings=SpeachesTTSSettings(
|
||||
model=user_config.tts.model,
|
||||
voice=user_config.tts.voice,
|
||||
speed=user_config.tts.speed,
|
||||
),
|
||||
text_filters=[xml_function_tag_filter],
|
||||
skip_aggregator_types=["recording_router"],
|
||||
silence_time_s=1.0,
|
||||
)
|
||||
elif user_config.tts.provider == ServiceProviders.SARVAM.value:
|
||||
# Map Sarvam language code to pipecat Language enum for TTS
|
||||
language_mapping = {
|
||||
|
|
@ -288,6 +326,7 @@ def create_tts_service(user_config, audio_config: "AudioConfig"):
|
|||
language=pipecat_language,
|
||||
),
|
||||
text_filters=[xml_function_tag_filter],
|
||||
skip_aggregator_types=["recording_router"],
|
||||
silence_time_s=1.0,
|
||||
)
|
||||
else:
|
||||
|
|
@ -363,16 +402,80 @@ def create_llm_service_from_provider(
|
|||
aws_region=aws_region,
|
||||
settings=AWSBedrockLLMSettings(model=model),
|
||||
)
|
||||
elif provider == ServiceProviders.SELF_HOSTED.value:
|
||||
return OpenAILLMService(
|
||||
elif provider == ServiceProviders.SPEACHES.value:
|
||||
return SpeachesLLMService(
|
||||
base_url=base_url or "http://localhost:11434/v1",
|
||||
api_key=api_key or "none",
|
||||
settings=OpenAILLMSettings(model=model),
|
||||
settings=SpeachesLLMSettings(model=model),
|
||||
)
|
||||
else:
|
||||
raise HTTPException(status_code=400, detail=f"Invalid LLM provider {provider}")
|
||||
|
||||
|
||||
def create_realtime_llm_service(user_config, audio_config: "AudioConfig"):
|
||||
"""Create a realtime (speech-to-speech) LLM service that handles STT+LLM+TTS.
|
||||
|
||||
These services bypass separate STT/TTS and handle audio directly via
|
||||
a bidirectional WebSocket connection. Reads from user_config.realtime.
|
||||
"""
|
||||
realtime_config = user_config.realtime
|
||||
provider = realtime_config.provider
|
||||
model = realtime_config.model
|
||||
api_key = realtime_config.api_key
|
||||
voice = getattr(realtime_config, "voice", None)
|
||||
language = getattr(realtime_config, "language", None)
|
||||
|
||||
logger.info(
|
||||
f"Creating realtime LLM service: provider={provider}, model={model}, voice={voice}, language={language}"
|
||||
)
|
||||
|
||||
if provider == ServiceProviders.OPENAI_REALTIME.value:
|
||||
from pipecat.services.openai.realtime.events import (
|
||||
AudioConfiguration,
|
||||
AudioInput,
|
||||
AudioOutput,
|
||||
InputAudioTranscription,
|
||||
SessionProperties,
|
||||
)
|
||||
from pipecat.services.openai.realtime.llm import OpenAIRealtimeLLMService
|
||||
|
||||
return OpenAIRealtimeLLMService(
|
||||
api_key=api_key,
|
||||
settings=OpenAIRealtimeLLMService.Settings(
|
||||
model=model,
|
||||
session_properties=SessionProperties(
|
||||
audio=AudioConfiguration(
|
||||
input=AudioInput(
|
||||
transcription=InputAudioTranscription(),
|
||||
),
|
||||
output=AudioOutput(
|
||||
voice=voice or "alloy",
|
||||
),
|
||||
),
|
||||
),
|
||||
),
|
||||
)
|
||||
elif provider == ServiceProviders.GOOGLE_REALTIME.value:
|
||||
from pipecat.services.google.gemini_live.llm import GeminiLiveLLMService
|
||||
|
||||
# Gemini Live enables input/output audio transcription by default
|
||||
# in its _connect() method — no need to configure it explicitly.
|
||||
settings_kwargs = {
|
||||
"model": model,
|
||||
"voice": voice or "Puck",
|
||||
}
|
||||
if language:
|
||||
settings_kwargs["language"] = language
|
||||
return GeminiLiveLLMService(
|
||||
api_key=api_key,
|
||||
settings=GeminiLiveLLMService.Settings(**settings_kwargs),
|
||||
)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=400, detail=f"Invalid realtime LLM provider {provider}"
|
||||
)
|
||||
|
||||
|
||||
def create_llm_service(user_config):
|
||||
"""Create and return appropriate LLM service based on user configuration."""
|
||||
provider = user_config.llm.provider
|
||||
|
|
@ -384,7 +487,7 @@ def create_llm_service(user_config):
|
|||
kwargs["base_url"] = user_config.llm.base_url
|
||||
elif provider == ServiceProviders.AZURE.value:
|
||||
kwargs["endpoint"] = user_config.llm.endpoint
|
||||
elif provider == ServiceProviders.SELF_HOSTED.value:
|
||||
elif provider == ServiceProviders.SPEACHES.value:
|
||||
kwargs["base_url"] = user_config.llm.base_url
|
||||
elif provider == ServiceProviders.AWS_BEDROCK.value:
|
||||
kwargs["aws_access_key"] = user_config.llm.aws_access_key
|
||||
|
|
|
|||
|
|
@ -210,12 +210,17 @@ class PipecatEngine:
|
|||
async def _update_llm_context(self, system_prompt: str, functions: list[dict]):
|
||||
"""Update LLM settings with the composed system prompt and tool list."""
|
||||
|
||||
await self.llm._update_settings(LLMSettings(system_instruction=system_prompt))
|
||||
|
||||
if functions:
|
||||
tools_schema = ToolsSchema(standard_tools=functions)
|
||||
self.context.set_tools(tools_schema)
|
||||
|
||||
await self.llm._update_settings(LLMSettings(system_instruction=system_prompt))
|
||||
|
||||
# For Gemini Live, set context on the LLM before _update_settings so that
|
||||
# _connect (triggered by reconnect) can read tools from it.
|
||||
if hasattr(self.llm, "_context") and not self.llm._context and self.context:
|
||||
self.llm._context = self.context
|
||||
|
||||
def _format_prompt(self, prompt: str) -> str:
|
||||
"""Delegate prompt formatting to the shared workflow.utils implementation."""
|
||||
|
||||
|
|
|
|||
|
|
@ -215,13 +215,17 @@ class VariableExtractionManager:
|
|||
with tracer.start_as_current_span(
|
||||
"llm-variable-extraction", context=parent_ctx
|
||||
) as span:
|
||||
tracing_messages = [
|
||||
{"role": "system", "content": system_prompt},
|
||||
*extraction_messages,
|
||||
]
|
||||
add_llm_span_attributes(
|
||||
span,
|
||||
service_name=self._engine.llm.__class__.__name__,
|
||||
model=model_name,
|
||||
operation_name="llm-variable-extraction",
|
||||
messages=extraction_messages,
|
||||
output=llm_response,
|
||||
messages=tracing_messages,
|
||||
output=json.dumps({"content": llm_response}),
|
||||
stream=False,
|
||||
parameters={},
|
||||
)
|
||||
|
|
|
|||
|
|
@ -59,7 +59,14 @@ async def _generate_conversation_summary(
|
|||
)
|
||||
|
||||
span_name = f"conversation-summary-before-{node_name}"
|
||||
add_qa_span_to_trace(parent_ctx, model, messages, summary, span_name)
|
||||
add_qa_span_to_trace(
|
||||
parent_ctx,
|
||||
model,
|
||||
messages,
|
||||
summary,
|
||||
span_name,
|
||||
CONVERSATION_SUMMARY_SYSTEM_PROMPT,
|
||||
)
|
||||
|
||||
return summary
|
||||
except Exception as e:
|
||||
|
|
@ -189,7 +196,9 @@ async def run_per_node_qa_analysis(
|
|||
|
||||
# Trace
|
||||
span_name = f"qa-node-{node_name}"
|
||||
add_qa_span_to_trace(parent_ctx, model, messages, raw_response, span_name)
|
||||
add_qa_span_to_trace(
|
||||
parent_ctx, model, messages, raw_response, span_name, system_content
|
||||
)
|
||||
|
||||
# Parse response
|
||||
node_result: dict[str, Any] = {
|
||||
|
|
@ -299,7 +308,9 @@ async def _run_whole_call_qa_analysis(
|
|||
|
||||
# Langfuse tracing
|
||||
parent_ctx = setup_langfuse_parent_context(workflow_run)
|
||||
add_qa_span_to_trace(parent_ctx, model, messages, raw_response, "qa-analysis")
|
||||
add_qa_span_to_trace(
|
||||
parent_ctx, model, messages, raw_response, "qa-analysis", system_content
|
||||
)
|
||||
|
||||
return {
|
||||
"node_results": {"whole_call": node_result},
|
||||
|
|
|
|||
|
|
@ -1,5 +1,7 @@
|
|||
"""LLM configuration resolution and token usage accumulation."""
|
||||
|
||||
import random
|
||||
|
||||
from api.db import db_client
|
||||
from api.db.models import WorkflowRunModel
|
||||
|
||||
|
|
@ -57,6 +59,8 @@ async def resolve_user_llm_config(
|
|||
|
||||
provider = llm_config.get("provider", "openai")
|
||||
api_key = llm_config.get("api_key", "")
|
||||
if isinstance(api_key, list):
|
||||
api_key = random.choice(api_key)
|
||||
model = llm_config.get("model", "gpt-4.1")
|
||||
|
||||
kwargs = {}
|
||||
|
|
|
|||
|
|
@ -166,7 +166,9 @@ async def ensure_node_summaries(
|
|||
continue
|
||||
|
||||
# Create a Langfuse trace for this summary generation
|
||||
trace_url = create_node_summary_trace(model, messages, summary_text, node_name)
|
||||
trace_url = create_node_summary_trace(
|
||||
model, messages, summary_text, node_name, NODE_SUMMARY_SYSTEM_PROMPT
|
||||
)
|
||||
|
||||
entry: dict[str, Any] = {"summary": summary_text}
|
||||
if trace_url:
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
"""Langfuse / OpenTelemetry tracing helpers for QA analysis."""
|
||||
|
||||
import json
|
||||
import re
|
||||
|
||||
from loguru import logger
|
||||
|
|
@ -70,6 +71,7 @@ def add_qa_span_to_trace(
|
|||
messages: list[dict],
|
||||
output: str,
|
||||
span_name: str,
|
||||
system_prompt: str = "",
|
||||
) -> None:
|
||||
"""Create a child span under the conversation trace."""
|
||||
if parent_ctx is None:
|
||||
|
|
@ -84,13 +86,21 @@ def add_qa_span_to_trace(
|
|||
span_name,
|
||||
context=parent_ctx,
|
||||
) as span:
|
||||
tracing_messages = (
|
||||
[
|
||||
{"role": "system", "content": system_prompt},
|
||||
*messages,
|
||||
]
|
||||
if system_prompt
|
||||
else messages
|
||||
)
|
||||
add_llm_span_attributes(
|
||||
span,
|
||||
service_name="OpenAILLMService",
|
||||
model=model,
|
||||
operation_name=span_name,
|
||||
messages=messages,
|
||||
output=output,
|
||||
messages=tracing_messages,
|
||||
output=json.dumps({"content": output}),
|
||||
stream=False,
|
||||
parameters={"temperature": 0},
|
||||
)
|
||||
|
|
@ -103,6 +113,7 @@ def create_node_summary_trace(
|
|||
messages: list[dict],
|
||||
output: str,
|
||||
node_name: str,
|
||||
system_prompt: str = "",
|
||||
) -> str | None:
|
||||
"""Create a standalone Langfuse trace for a node summary generation.
|
||||
|
||||
|
|
@ -125,13 +136,21 @@ def create_node_summary_trace(
|
|||
f"node-summary-{node_name}",
|
||||
context=Context(),
|
||||
) as span:
|
||||
tracing_messages = (
|
||||
[
|
||||
{"role": "system", "content": system_prompt},
|
||||
*messages,
|
||||
]
|
||||
if system_prompt
|
||||
else messages
|
||||
)
|
||||
add_llm_span_attributes(
|
||||
span,
|
||||
service_name="OpenAILLMService",
|
||||
model=model,
|
||||
operation_name=f"node-summary-{node_name}",
|
||||
messages=messages,
|
||||
output=output,
|
||||
messages=tracing_messages,
|
||||
output=json.dumps({"content": output}),
|
||||
stream=False,
|
||||
parameters={"temperature": 0},
|
||||
)
|
||||
|
|
|
|||
2
pipecat
2
pipecat
|
|
@ -1 +1 @@
|
|||
Subproject commit 2e2171e2a64ec87b3964fbc2440b5291489912a8
|
||||
Subproject commit a2dc39c0d706e420121d045183554f378fe9d841
|
||||
|
|
@ -448,6 +448,11 @@ export type DefaultConfigurationsResponse = {
|
|||
[key: string]: unknown;
|
||||
};
|
||||
};
|
||||
realtime: {
|
||||
[key: string]: {
|
||||
[key: string]: unknown;
|
||||
};
|
||||
};
|
||||
default_providers: {
|
||||
[key: string]: string;
|
||||
};
|
||||
|
|
@ -1329,6 +1334,10 @@ export type UserConfigurationRequestResponseSchema = {
|
|||
embeddings?: {
|
||||
[key: string]: string | number | Array<string> | null;
|
||||
} | null;
|
||||
realtime?: {
|
||||
[key: string]: string | number | Array<string> | null;
|
||||
} | null;
|
||||
is_realtime?: boolean | null;
|
||||
test_phone_number?: string | null;
|
||||
timezone?: string | null;
|
||||
organization_pricing?: {
|
||||
|
|
|
|||
|
|
@ -11,12 +11,13 @@ import { Checkbox } from "@/components/ui/checkbox";
|
|||
import { Input } from "@/components/ui/input";
|
||||
import { Label } from "@/components/ui/label";
|
||||
import { Select, SelectContent, SelectItem, SelectTrigger, SelectValue } from "@/components/ui/select";
|
||||
import { Switch } from "@/components/ui/switch";
|
||||
import { Tabs, TabsContent, TabsList, TabsTrigger } from "@/components/ui/tabs";
|
||||
import { VoiceSelector } from "@/components/VoiceSelector";
|
||||
import { LANGUAGE_DISPLAY_NAMES } from "@/constants/languages";
|
||||
import { useUserConfig } from "@/context/UserConfigContext";
|
||||
|
||||
type ServiceSegment = "llm" | "tts" | "stt" | "embeddings";
|
||||
type ServiceSegment = "llm" | "tts" | "stt" | "embeddings" | "realtime";
|
||||
|
||||
interface SchemaProperty {
|
||||
type?: string;
|
||||
|
|
@ -24,6 +25,7 @@ interface SchemaProperty {
|
|||
enum?: string[];
|
||||
examples?: string[];
|
||||
model_options?: Record<string, string[]>;
|
||||
allow_custom_input?: boolean;
|
||||
$ref?: string;
|
||||
description?: string;
|
||||
format?: string;
|
||||
|
|
@ -40,13 +42,18 @@ interface FormValues {
|
|||
[key: string]: string | number | boolean;
|
||||
}
|
||||
|
||||
const TAB_CONFIG: { key: ServiceSegment; label: string }[] = [
|
||||
const STANDARD_TABS: { key: ServiceSegment; label: string }[] = [
|
||||
{ key: "llm", label: "LLM" },
|
||||
{ key: "tts", label: "Voice" },
|
||||
{ key: "stt", label: "Transcriber" },
|
||||
{ key: "embeddings", label: "Embedding" },
|
||||
];
|
||||
|
||||
const REALTIME_TABS: { key: ServiceSegment; label: string }[] = [
|
||||
{ key: "realtime", label: "Realtime Model" },
|
||||
{ key: "embeddings", label: "Embedding" },
|
||||
];
|
||||
|
||||
// Display names for Sarvam voices
|
||||
const VOICE_DISPLAY_NAMES: Record<string, string> = {
|
||||
"anushka": "Anushka (Female)",
|
||||
|
|
@ -61,27 +68,30 @@ const VOICE_DISPLAY_NAMES: Record<string, string> = {
|
|||
export default function ServiceConfiguration() {
|
||||
const [apiError, setApiError] = useState<string | null>(null);
|
||||
const [isSaving, setIsSaving] = useState(false);
|
||||
const [isRealtime, setIsRealtime] = useState(false);
|
||||
const { userConfig, saveUserConfig } = useUserConfig();
|
||||
const [schemas, setSchemas] = useState<Record<ServiceSegment, Record<string, ProviderSchema>>>({
|
||||
llm: {},
|
||||
tts: {},
|
||||
stt: {},
|
||||
embeddings: {}
|
||||
embeddings: {},
|
||||
realtime: {},
|
||||
});
|
||||
const [serviceProviders, setServiceProviders] = useState<Record<ServiceSegment, string>>({
|
||||
llm: "",
|
||||
tts: "",
|
||||
stt: "",
|
||||
embeddings: ""
|
||||
embeddings: "",
|
||||
realtime: "",
|
||||
});
|
||||
const [apiKeys, setApiKeys] = useState<Record<ServiceSegment, string[]>>({
|
||||
llm: [""],
|
||||
tts: [""],
|
||||
stt: [""],
|
||||
embeddings: [""],
|
||||
realtime: [""],
|
||||
});
|
||||
const [isManualModelInput, setIsManualModelInput] = useState(false);
|
||||
const [hasCheckedManualMode, setHasCheckedManualMode] = useState(false);
|
||||
const [isCustomInput, setIsCustomInput] = useState<Record<string, boolean>>({});
|
||||
|
||||
const {
|
||||
register,
|
||||
|
|
@ -97,12 +107,20 @@ export default function ServiceConfiguration() {
|
|||
const fetchConfigurations = async () => {
|
||||
const response = await getDefaultConfigurationsApiV1UserConfigurationsDefaultsGet();
|
||||
if (response.data) {
|
||||
const data = response.data as Record<string, unknown>;
|
||||
setSchemas({
|
||||
llm: response.data.llm as Record<string, ProviderSchema>,
|
||||
tts: response.data.tts as Record<string, ProviderSchema>,
|
||||
stt: response.data.stt as Record<string, ProviderSchema>,
|
||||
embeddings: response.data.embeddings as Record<string, ProviderSchema>
|
||||
embeddings: response.data.embeddings as Record<string, ProviderSchema>,
|
||||
realtime: (data.realtime || {}) as Record<string, ProviderSchema>,
|
||||
});
|
||||
|
||||
// Restore realtime toggle from saved config
|
||||
const configData = userConfig as Record<string, unknown> | null;
|
||||
if (configData?.is_realtime) {
|
||||
setIsRealtime(true);
|
||||
}
|
||||
} else {
|
||||
console.error("Failed to fetch configurations");
|
||||
return;
|
||||
|
|
@ -113,23 +131,41 @@ export default function ServiceConfiguration() {
|
|||
llm: response.data.default_providers.llm,
|
||||
tts: response.data.default_providers.tts,
|
||||
stt: response.data.default_providers.stt,
|
||||
embeddings: response.data.default_providers.embeddings
|
||||
embeddings: response.data.default_providers.embeddings,
|
||||
realtime: "",
|
||||
};
|
||||
|
||||
// Set default realtime provider from schema keys
|
||||
const data = response.data as Record<string, unknown>;
|
||||
const realtimeSchemas = (data.realtime || {}) as Record<string, ProviderSchema>;
|
||||
const realtimeProviderKeys = Object.keys(realtimeSchemas);
|
||||
if (realtimeProviderKeys.length > 0) {
|
||||
selectedProviders.realtime = realtimeProviderKeys[0];
|
||||
}
|
||||
|
||||
const loadedApiKeys: Record<ServiceSegment, string[]> = {
|
||||
llm: [""],
|
||||
tts: [""],
|
||||
stt: [""],
|
||||
embeddings: [""],
|
||||
realtime: [""],
|
||||
};
|
||||
|
||||
const setServicePropertyValues = (service: ServiceSegment) => {
|
||||
if (userConfig?.[service]?.provider) {
|
||||
Object.entries(userConfig?.[service]).forEach(([field, value]) => {
|
||||
// For realtime, read from userConfig.realtime; for others, read from userConfig[service]
|
||||
const configSource = service === "realtime"
|
||||
? (userConfig as Record<string, unknown> | null)?.realtime as Record<string, unknown> | undefined
|
||||
: userConfig?.[service as "llm" | "tts" | "stt" | "embeddings"];
|
||||
|
||||
const schemaSource = service === "realtime"
|
||||
? realtimeSchemas
|
||||
: response.data[service as "llm" | "tts" | "stt" | "embeddings"] as Record<string, ProviderSchema> | undefined;
|
||||
|
||||
if (configSource?.provider) {
|
||||
Object.entries(configSource).forEach(([field, value]) => {
|
||||
if (field === "api_key") {
|
||||
// Handle api_key separately — it can be string or string[]
|
||||
if (Array.isArray(value)) {
|
||||
loadedApiKeys[service] = value.length > 0 ? value : [""];
|
||||
loadedApiKeys[service] = (value as string[]).length > 0 ? value as string[] : [""];
|
||||
} else {
|
||||
loadedApiKeys[service] = value ? [value as string] : [""];
|
||||
}
|
||||
|
|
@ -137,9 +173,9 @@ export default function ServiceConfiguration() {
|
|||
defaultValues[`${service}_${field}`] = value as string | number | boolean;
|
||||
}
|
||||
});
|
||||
selectedProviders[service] = userConfig?.[service]?.provider as string;
|
||||
// Fill in schema defaults for fields not present in userConfig
|
||||
const properties = response.data[service]?.[selectedProviders[service]]?.properties as Record<string, SchemaProperty>;
|
||||
selectedProviders[service] = configSource.provider as string;
|
||||
// Fill in schema defaults for fields not present in config
|
||||
const properties = schemaSource?.[selectedProviders[service]]?.properties as Record<string, SchemaProperty>;
|
||||
if (properties) {
|
||||
Object.entries(properties).forEach(([field, schema]) => {
|
||||
const key = `${service}_${field}`;
|
||||
|
|
@ -149,7 +185,7 @@ export default function ServiceConfiguration() {
|
|||
});
|
||||
}
|
||||
} else {
|
||||
const properties = response.data[service]?.[selectedProviders[service]]?.properties as Record<string, SchemaProperty>;
|
||||
const properties = schemaSource?.[selectedProviders[service]]?.properties as Record<string, SchemaProperty>;
|
||||
if (properties) {
|
||||
Object.entries(properties).forEach(([field, schema]) => {
|
||||
if (field !== "provider" && schema.default !== undefined) {
|
||||
|
|
@ -164,6 +200,33 @@ export default function ServiceConfiguration() {
|
|||
setServicePropertyValues("tts");
|
||||
setServicePropertyValues("stt");
|
||||
setServicePropertyValues("embeddings");
|
||||
setServicePropertyValues("realtime");
|
||||
|
||||
// Detect saved values that are not in suggested options (custom value)
|
||||
const detectedCustomInput: Record<string, boolean> = {};
|
||||
const allSchemas = { ...response.data, realtime: realtimeSchemas } as unknown as Record<string, Record<string, ProviderSchema>>;
|
||||
(["llm", "tts", "stt", "embeddings", "realtime"] as ServiceSegment[]).forEach(service => {
|
||||
const provider = selectedProviders[service];
|
||||
const providerSchema = allSchemas[service]?.[provider];
|
||||
if (!providerSchema) return;
|
||||
|
||||
const configSource = service === "realtime"
|
||||
? (userConfig as Record<string, unknown> | null)?.realtime as Record<string, unknown> | undefined
|
||||
: userConfig?.[service as "llm" | "tts" | "stt" | "embeddings"];
|
||||
|
||||
Object.entries(providerSchema.properties).forEach(([field, schema]) => {
|
||||
const actualSchema = (schema as SchemaProperty).$ref && providerSchema.$defs
|
||||
? providerSchema.$defs[(schema as SchemaProperty).$ref!.split('/').pop() || '']
|
||||
: schema as SchemaProperty;
|
||||
|
||||
if (!actualSchema?.allow_custom_input || !actualSchema?.examples) return;
|
||||
|
||||
const savedValue = configSource?.[field] as string | undefined;
|
||||
if (savedValue && !actualSchema.examples.includes(savedValue)) {
|
||||
detectedCustomInput[`${service}_${field}`] = true;
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
// IMPORTANT: Reset form values BEFORE changing providers
|
||||
// Otherwise, Radix Select sees old values that don't match new provider's enum
|
||||
|
|
@ -171,33 +234,11 @@ export default function ServiceConfiguration() {
|
|||
reset(defaultValues);
|
||||
setApiKeys(loadedApiKeys);
|
||||
setServiceProviders(selectedProviders);
|
||||
setIsCustomInput(detectedCustomInput);
|
||||
};
|
||||
fetchConfigurations();
|
||||
}, [reset, userConfig]);
|
||||
|
||||
// Check if the saved LLM model is not in the suggested options (custom model)
|
||||
useEffect(() => {
|
||||
if (hasCheckedManualMode) return;
|
||||
|
||||
const currentProvider = serviceProviders.llm;
|
||||
const providerSchema = schemas?.llm?.[currentProvider];
|
||||
if (!providerSchema) return;
|
||||
|
||||
const modelSchema = providerSchema.properties.model;
|
||||
const actualModelSchema = modelSchema?.$ref && providerSchema.$defs
|
||||
? providerSchema.$defs[modelSchema.$ref.split('/').pop() || '']
|
||||
: modelSchema;
|
||||
|
||||
if (actualModelSchema?.examples && userConfig?.llm?.model) {
|
||||
const savedModel = userConfig.llm.model as string;
|
||||
const isInOptions = actualModelSchema.examples.includes(savedModel);
|
||||
if (!isInOptions) {
|
||||
setIsManualModelInput(true);
|
||||
}
|
||||
setHasCheckedManualMode(true);
|
||||
}
|
||||
}, [schemas, serviceProviders.llm, userConfig?.llm?.model, hasCheckedManualMode]);
|
||||
|
||||
// Reset voice when TTS model changes if the provider has model-dependent voice options
|
||||
const ttsModel = watch("tts_model");
|
||||
useEffect(() => {
|
||||
|
|
@ -256,10 +297,14 @@ export default function ServiceConfiguration() {
|
|||
setServiceProviders(prev => ({ ...prev, [service]: providerName }));
|
||||
setApiKeys(prev => ({ ...prev, [service]: [""] }));
|
||||
|
||||
// Reset manual model input when LLM provider changes
|
||||
if (service === "llm") {
|
||||
setIsManualModelInput(false);
|
||||
}
|
||||
// Reset custom input toggles when provider changes
|
||||
setIsCustomInput(prev => {
|
||||
const next = { ...prev };
|
||||
Object.keys(next).forEach(key => {
|
||||
if (key.startsWith(`${service}_`)) delete next[key];
|
||||
});
|
||||
return next;
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
|
|
@ -271,55 +316,42 @@ export default function ServiceConfiguration() {
|
|||
const getServiceApiKeys = (service: ServiceSegment): string[] =>
|
||||
apiKeys[service].map(k => k.trim()).filter(k => k.length > 0);
|
||||
|
||||
const userConfig: Record<ServiceSegment, Record<string, string | number | string[]>> = {
|
||||
llm: {
|
||||
provider: serviceProviders.llm,
|
||||
...(getServiceApiKeys("llm").length > 0 && { api_key: getServiceApiKeys("llm") }),
|
||||
model: data.llm_model as string
|
||||
},
|
||||
tts: {
|
||||
provider: serviceProviders.tts,
|
||||
...(getServiceApiKeys("tts").length > 0 && { api_key: getServiceApiKeys("tts") }),
|
||||
},
|
||||
stt: {
|
||||
provider: serviceProviders.stt,
|
||||
...(getServiceApiKeys("stt").length > 0 && { api_key: getServiceApiKeys("stt") }),
|
||||
},
|
||||
embeddings: {
|
||||
provider: serviceProviders.embeddings,
|
||||
...(getServiceApiKeys("embeddings").length > 0 && { api_key: getServiceApiKeys("embeddings") }),
|
||||
model: data.embeddings_model as string
|
||||
// Build service configs from form data
|
||||
const buildServiceConfig = (service: ServiceSegment) => {
|
||||
const config: Record<string, string | number | string[]> = {
|
||||
provider: serviceProviders[service],
|
||||
};
|
||||
const keys = getServiceApiKeys(service);
|
||||
if (keys.length > 0) {
|
||||
config.api_key = keys;
|
||||
}
|
||||
// Add all form fields for this service
|
||||
Object.entries(data).forEach(([property, value]) => {
|
||||
if (!property.startsWith(`${service}_`)) return;
|
||||
const field = property.slice(service.length + 1);
|
||||
if (field === "api_key" || field === "provider") return;
|
||||
config[field] = value as string | number;
|
||||
});
|
||||
return config;
|
||||
};
|
||||
|
||||
// Add any extra properties in the payload
|
||||
Object.entries(data).forEach(([property, value]) => {
|
||||
const parts = property.split('_');
|
||||
const service = parts[0] as ServiceSegment;
|
||||
const field = parts.slice(1).join('_');
|
||||
|
||||
if (field === "api_key") return; // handled via apiKeys state
|
||||
if (userConfig[service] && !(field in userConfig[service])) {
|
||||
(userConfig[service] as Record<string, string>)[field] = value as string;
|
||||
}
|
||||
});
|
||||
|
||||
// Build save config - only include embeddings if api_key is provided
|
||||
const saveConfig: {
|
||||
llm: Record<string, string | number | string[]>;
|
||||
tts: Record<string, string | number | string[]>;
|
||||
stt: Record<string, string | number | string[]>;
|
||||
embeddings?: Record<string, string | number | string[]>;
|
||||
} = {
|
||||
llm: userConfig.llm,
|
||||
tts: userConfig.tts,
|
||||
stt: userConfig.stt
|
||||
// Always save all configs so switching modes preserves everything
|
||||
const saveConfig: Record<string, unknown> = {
|
||||
llm: buildServiceConfig("llm"),
|
||||
tts: buildServiceConfig("tts"),
|
||||
stt: buildServiceConfig("stt"),
|
||||
is_realtime: isRealtime,
|
||||
};
|
||||
|
||||
// Save realtime config if provider is set
|
||||
if (serviceProviders.realtime) {
|
||||
saveConfig.realtime = buildServiceConfig("realtime");
|
||||
}
|
||||
|
||||
// Only include embeddings if user has configured it (has api_key)
|
||||
const embeddingsKeys = getServiceApiKeys("embeddings");
|
||||
if (embeddingsKeys.length > 0) {
|
||||
saveConfig.embeddings = userConfig.embeddings;
|
||||
saveConfig.embeddings = buildServiceConfig("embeddings");
|
||||
}
|
||||
|
||||
try {
|
||||
|
|
@ -459,15 +491,13 @@ export default function ServiceConfiguration() {
|
|||
? providerSchema.$defs[schema.$ref.split('/').pop() || '']
|
||||
: schema;
|
||||
|
||||
// Use VoiceSelector for voice field in TTS service (except Sarvam which uses predefined options)
|
||||
if (service === "tts" && field === "voice") {
|
||||
const currentProvider = serviceProviders.tts;
|
||||
// Sarvam uses predefined voice options, not VoiceSelector
|
||||
// VoiceSelector for TTS voice fields without predefined options or manual input flag
|
||||
if (service === "tts" && field === "voice" && !actualSchema?.allow_custom_input) {
|
||||
const hasVoiceOptions = actualSchema?.enum || actualSchema?.examples;
|
||||
if (currentProvider !== "sarvam" && !hasVoiceOptions) {
|
||||
if (!hasVoiceOptions) {
|
||||
return (
|
||||
<VoiceSelector
|
||||
provider={currentProvider}
|
||||
provider={serviceProviders.tts}
|
||||
value={watch(`${service}_${field}`) as string || ""}
|
||||
onChange={(voiceId) => {
|
||||
setValue(`${service}_${field}`, voiceId, { shouldDirty: true });
|
||||
|
|
@ -477,39 +507,36 @@ export default function ServiceConfiguration() {
|
|||
}
|
||||
}
|
||||
|
||||
// Handle LLM model field with manual input toggle (uses examples from schema)
|
||||
if (service === "llm" && field === "model" && actualSchema?.examples) {
|
||||
const currentValue = watch(`${service}_${field}`) as string || "";
|
||||
const modelOptions = actualSchema.examples;
|
||||
// Generic allow_custom_input handler for any field (model, voice with options, etc.)
|
||||
if (actualSchema?.allow_custom_input && actualSchema?.examples) {
|
||||
const fieldKey = `${service}_${field}`;
|
||||
const currentValue = watch(fieldKey) as string || "";
|
||||
const options = actualSchema.examples;
|
||||
|
||||
if (isManualModelInput) {
|
||||
if (isCustomInput[fieldKey]) {
|
||||
return (
|
||||
<div className="space-y-2">
|
||||
<Input
|
||||
type="text"
|
||||
placeholder="Enter model name"
|
||||
placeholder={`Enter ${field}`}
|
||||
value={currentValue}
|
||||
onChange={(e) => {
|
||||
setValue(`${service}_${field}`, e.target.value, { shouldDirty: true });
|
||||
setValue(fieldKey, e.target.value, { shouldDirty: true });
|
||||
}}
|
||||
/>
|
||||
<div className="flex items-center space-x-2">
|
||||
<Checkbox
|
||||
id="manual-model-input"
|
||||
checked={isManualModelInput}
|
||||
id={`custom-input-${fieldKey}`}
|
||||
checked={true}
|
||||
onCheckedChange={(checked) => {
|
||||
setIsManualModelInput(checked as boolean);
|
||||
if (!checked && modelOptions.length > 0) {
|
||||
// Reset to first option when switching back
|
||||
setValue(`${service}_${field}`, modelOptions[0], { shouldDirty: true });
|
||||
setIsCustomInput(prev => ({ ...prev, [fieldKey]: checked as boolean }));
|
||||
if (!checked && options.length > 0) {
|
||||
setValue(fieldKey, options[0], { shouldDirty: true });
|
||||
}
|
||||
}}
|
||||
/>
|
||||
<Label
|
||||
htmlFor="manual-model-input"
|
||||
className="text-sm font-normal cursor-pointer"
|
||||
>
|
||||
Add Model Manually
|
||||
<Label htmlFor={`custom-input-${fieldKey}`} className="text-sm font-normal cursor-pointer">
|
||||
Enter Custom Value
|
||||
</Label>
|
||||
</div>
|
||||
</div>
|
||||
|
|
@ -522,14 +549,14 @@ export default function ServiceConfiguration() {
|
|||
value={currentValue}
|
||||
onValueChange={(value) => {
|
||||
if (!value) return;
|
||||
setValue(`${service}_${field}`, value, { shouldDirty: true });
|
||||
setValue(fieldKey, value, { shouldDirty: true });
|
||||
}}
|
||||
>
|
||||
<SelectTrigger className="w-full">
|
||||
<SelectValue placeholder="Select model" />
|
||||
<SelectValue placeholder={`Select ${field}`} />
|
||||
</SelectTrigger>
|
||||
<SelectContent>
|
||||
{modelOptions.map((value: string) => (
|
||||
{options.map((value: string) => (
|
||||
<SelectItem key={value} value={value}>
|
||||
{value}
|
||||
</SelectItem>
|
||||
|
|
@ -538,17 +565,14 @@ export default function ServiceConfiguration() {
|
|||
</Select>
|
||||
<div className="flex items-center space-x-2">
|
||||
<Checkbox
|
||||
id="manual-model-input-dropdown"
|
||||
checked={isManualModelInput}
|
||||
id={`custom-input-${fieldKey}-dropdown`}
|
||||
checked={false}
|
||||
onCheckedChange={(checked) => {
|
||||
setIsManualModelInput(checked as boolean);
|
||||
setIsCustomInput(prev => ({ ...prev, [fieldKey]: checked as boolean }));
|
||||
}}
|
||||
/>
|
||||
<Label
|
||||
htmlFor="manual-model-input-dropdown"
|
||||
className="text-sm font-normal cursor-pointer"
|
||||
>
|
||||
Add Model Manually
|
||||
<Label htmlFor={`custom-input-${fieldKey}-dropdown`} className="text-sm font-normal cursor-pointer">
|
||||
Enter Custom Value
|
||||
</Label>
|
||||
</div>
|
||||
</div>
|
||||
|
|
@ -616,6 +640,9 @@ export default function ServiceConfiguration() {
|
|||
);
|
||||
};
|
||||
|
||||
const visibleTabs = isRealtime ? REALTIME_TABS : STANDARD_TABS;
|
||||
const defaultTab = isRealtime ? "realtime" : "llm";
|
||||
|
||||
return (
|
||||
<div className="w-full max-w-2xl mx-auto">
|
||||
<div className="mb-6">
|
||||
|
|
@ -626,18 +653,35 @@ export default function ServiceConfiguration() {
|
|||
</div>
|
||||
|
||||
<form onSubmit={handleSubmit(onSubmit)}>
|
||||
{/* Realtime toggle */}
|
||||
<div className="flex items-center justify-between mb-4 p-4 border rounded-lg">
|
||||
<div>
|
||||
<Label htmlFor="realtime-toggle" className="text-sm font-medium">
|
||||
Realtime Mode
|
||||
</Label>
|
||||
<p className="text-xs text-muted-foreground mt-0.5">
|
||||
Uses a single speech-to-speech model (no separate STT/TTS)
|
||||
</p>
|
||||
</div>
|
||||
<Switch
|
||||
id="realtime-toggle"
|
||||
checked={isRealtime}
|
||||
onCheckedChange={setIsRealtime}
|
||||
/>
|
||||
</div>
|
||||
|
||||
<Card>
|
||||
<CardContent className="pt-6">
|
||||
<Tabs defaultValue="llm" className="w-full">
|
||||
<TabsList className="grid w-full grid-cols-4 mb-6">
|
||||
{TAB_CONFIG.map(({ key, label }) => (
|
||||
<Tabs key={defaultTab} defaultValue={defaultTab} className="w-full">
|
||||
<TabsList className="grid w-full mb-6" style={{ gridTemplateColumns: `repeat(${visibleTabs.length}, 1fr)` }}>
|
||||
{visibleTabs.map(({ key, label }) => (
|
||||
<TabsTrigger key={key} value={key}>
|
||||
{label}
|
||||
</TabsTrigger>
|
||||
))}
|
||||
</TabsList>
|
||||
|
||||
{TAB_CONFIG.map(({ key }) => (
|
||||
{visibleTabs.map(({ key }) => (
|
||||
<TabsContent key={key} value={key} className="mt-0">
|
||||
{renderServiceFields(key)}
|
||||
</TabsContent>
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue