mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-06-07 07:55:16 +02:00
feat: add gemini realtime and speaches integration
- Add gemini realtime support - Add speaches support for locally hosted LLMs
This commit is contained in:
parent
2eaaabd936
commit
ee2028eb2d
19 changed files with 531 additions and 185 deletions
|
|
@ -67,6 +67,10 @@ def setup_logging():
|
|||
# Handler might already be removed
|
||||
pass
|
||||
|
||||
# Set default extra values on the shared core so ALL logger references
|
||||
# (including ones imported before this runs) have run_id available.
|
||||
loguru.logger.configure(extra={"run_id": None})
|
||||
|
||||
# Patch loguru to inject run_id
|
||||
patched = loguru.logger.patch(inject_run_id)
|
||||
|
||||
|
|
|
|||
|
|
@ -33,6 +33,7 @@ class DefaultConfigurationsResponse(TypedDict):
|
|||
tts: dict[str, dict]
|
||||
stt: dict[str, dict]
|
||||
embeddings: dict[str, dict]
|
||||
realtime: dict[str, dict]
|
||||
default_providers: dict[str, str]
|
||||
|
||||
|
||||
|
|
@ -55,6 +56,10 @@ async def get_default_configurations() -> DefaultConfigurationsResponse:
|
|||
provider: model_cls.model_json_schema()
|
||||
for provider, model_cls in REGISTRY[ServiceType.EMBEDDINGS].items()
|
||||
},
|
||||
"realtime": {
|
||||
provider: model_cls.model_json_schema()
|
||||
for provider, model_cls in REGISTRY[ServiceType.REALTIME].items()
|
||||
},
|
||||
"default_providers": DEFAULT_SERVICE_PROVIDERS,
|
||||
}
|
||||
return configurations
|
||||
|
|
@ -75,6 +80,8 @@ class UserConfigurationRequestResponseSchema(BaseModel):
|
|||
tts: dict[str, Union[str, float, list[str], None]] | None = None
|
||||
stt: dict[str, Union[str, float, list[str], None]] | None = None
|
||||
embeddings: dict[str, Union[str, float, list[str], None]] | None = None
|
||||
realtime: dict[str, Union[str, float, list[str], None]] | None = None
|
||||
is_realtime: bool | None = None
|
||||
test_phone_number: str | None = None
|
||||
timezone: str | None = None
|
||||
organization_pricing: dict[str, Union[float, str, bool]] | None = None
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ from pydantic import BaseModel
|
|||
from api.services.configuration.registry import (
|
||||
EmbeddingsConfig,
|
||||
LLMConfig,
|
||||
RealtimeConfig,
|
||||
STTConfig,
|
||||
TTSConfig,
|
||||
)
|
||||
|
|
@ -15,6 +16,8 @@ class UserConfiguration(BaseModel):
|
|||
stt: STTConfig | None = None
|
||||
tts: TTSConfig | None = None
|
||||
embeddings: EmbeddingsConfig | None = None
|
||||
realtime: RealtimeConfig | None = None
|
||||
is_realtime: bool = False
|
||||
test_phone_number: str | None = None
|
||||
timezone: str | None = None
|
||||
last_validated_at: datetime | None = None
|
||||
|
|
|
|||
|
|
@ -47,6 +47,8 @@ class UserConfigurationValidator:
|
|||
ServiceProviders.CAMB.value: self._check_camb_api_key,
|
||||
ServiceProviders.AWS_BEDROCK.value: self._check_aws_bedrock_api_key,
|
||||
ServiceProviders.SPEACHES.value: self._check_speaches_api_key,
|
||||
ServiceProviders.OPENAI_REALTIME.value: self._check_openai_api_key,
|
||||
ServiceProviders.GOOGLE_REALTIME.value: self._check_google_api_key,
|
||||
}
|
||||
|
||||
async def validate(
|
||||
|
|
@ -70,6 +72,10 @@ class UserConfigurationValidator:
|
|||
configuration.embeddings, "embeddings", required=False
|
||||
)
|
||||
)
|
||||
# Realtime is optional - only validate if configured
|
||||
status_list.extend(
|
||||
self._validate_service(configuration.realtime, "realtime", required=False)
|
||||
)
|
||||
|
||||
if status_list:
|
||||
raise ValueError(status_list)
|
||||
|
|
|
|||
|
|
@ -29,7 +29,7 @@ def contains_masked_key(api_key: str | list[str] | None) -> bool:
|
|||
|
||||
def check_for_masked_keys(config: "UserConfiguration") -> None:
|
||||
"""Raise ValueError if any service in *config* still has a masked API key."""
|
||||
for field in ("llm", "tts", "stt", "embeddings"):
|
||||
for field in ("llm", "tts", "stt", "embeddings", "realtime"):
|
||||
service = getattr(config, field, None)
|
||||
if service is None:
|
||||
continue
|
||||
|
|
@ -121,6 +121,8 @@ def mask_user_config(config: UserConfiguration) -> Dict[str, Any]:
|
|||
"tts": _mask_service(config.tts),
|
||||
"stt": _mask_service(config.stt),
|
||||
"embeddings": _mask_service(config.embeddings),
|
||||
"realtime": _mask_service(config.realtime),
|
||||
"is_realtime": config.is_realtime,
|
||||
"test_phone_number": config.test_phone_number,
|
||||
"timezone": config.timezone,
|
||||
}
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ from typing import Dict
|
|||
from api.schemas.user_configuration import UserConfiguration
|
||||
from api.services.configuration.masking import resolve_masked_api_keys
|
||||
|
||||
SERVICE_FIELDS = ("llm", "tts", "stt", "embeddings")
|
||||
SERVICE_FIELDS = ("llm", "tts", "stt", "embeddings", "realtime")
|
||||
|
||||
|
||||
def merge_user_configurations(
|
||||
|
|
@ -64,6 +64,9 @@ def merge_user_configurations(
|
|||
_merge_service_block(service)
|
||||
|
||||
# other simple fields
|
||||
if "is_realtime" in incoming_partial:
|
||||
merged["is_realtime"] = incoming_partial["is_realtime"]
|
||||
|
||||
if "test_phone_number" in incoming_partial:
|
||||
merged["test_phone_number"] = incoming_partial["test_phone_number"]
|
||||
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ class ServiceType(Enum):
|
|||
TTS = auto()
|
||||
STT = auto()
|
||||
EMBEDDINGS = auto()
|
||||
REALTIME = auto()
|
||||
|
||||
|
||||
class ServiceProviders(str, Enum):
|
||||
|
|
@ -28,6 +29,8 @@ class ServiceProviders(str, Enum):
|
|||
CAMB = "camb"
|
||||
AWS_BEDROCK = "aws_bedrock"
|
||||
SPEACHES = "speaches"
|
||||
OPENAI_REALTIME = "openai_realtime"
|
||||
GOOGLE_REALTIME = "google_realtime"
|
||||
|
||||
|
||||
class BaseServiceConfiguration(BaseModel):
|
||||
|
|
@ -42,6 +45,8 @@ class BaseServiceConfiguration(BaseModel):
|
|||
ServiceProviders.DOGRAH,
|
||||
ServiceProviders.AWS_BEDROCK,
|
||||
ServiceProviders.SPEACHES,
|
||||
ServiceProviders.OPENAI_REALTIME,
|
||||
ServiceProviders.GOOGLE_REALTIME,
|
||||
# ServiceProviders.SARVAM,
|
||||
]
|
||||
api_key: str | list[str]
|
||||
|
|
@ -97,6 +102,7 @@ REGISTRY: Dict[ServiceType, Dict[str, Type[BaseServiceConfiguration]]] = {
|
|||
ServiceType.TTS: {},
|
||||
ServiceType.STT: {},
|
||||
ServiceType.EMBEDDINGS: {},
|
||||
ServiceType.REALTIME: {},
|
||||
}
|
||||
|
||||
T = TypeVar("T", bound=BaseServiceConfiguration)
|
||||
|
|
@ -279,6 +285,68 @@ class SpeachesLLMConfiguration(BaseLLMConfiguration):
|
|||
api_key: str | list[str] | None = Field(default=None)
|
||||
|
||||
|
||||
# Suggested model identifiers for the OpenAI realtime provider.
OPENAI_REALTIME_MODELS = [
    "gpt-4o-realtime-preview",
    "gpt-4o-mini-realtime-preview",
]

# Suggested voice names for the OpenAI realtime provider.
OPENAI_REALTIME_VOICES = [
    "alloy", "ash", "ballad", "coral",
    "echo", "sage", "shimmer", "verse",
]
|
||||
|
||||
|
||||
# @register_service(ServiceType.REALTIME)
|
||||
# class OpenAIRealtimeLLMConfiguration(BaseLLMConfiguration):
|
||||
# provider: Literal[ServiceProviders.OPENAI_REALTIME] = (
|
||||
# ServiceProviders.OPENAI_REALTIME
|
||||
# )
|
||||
# model: str = Field(
|
||||
# default="gpt-4o-realtime-preview",
|
||||
# json_schema_extra={
|
||||
# "examples": OPENAI_REALTIME_MODELS,
|
||||
# "allow_custom_input": True,
|
||||
# },
|
||||
# )
|
||||
# voice: str = Field(
|
||||
# default="alloy",
|
||||
# json_schema_extra={"examples": OPENAI_REALTIME_VOICES},
|
||||
# )
|
||||
|
||||
|
||||
# Suggested model identifiers for the Google realtime provider.
GOOGLE_REALTIME_MODELS = [
    "gemini-3.1-flash-live-preview",
]

# Suggested voice names for the Google realtime provider.
GOOGLE_REALTIME_VOICES = [
    "Puck",
    "Charon",
    "Kore",
    "Fenrir",
    "Aoede",
]
|
||||
|
||||
|
||||
# Configuration model for the Google realtime provider, registered under
# ServiceType.REALTIME.
#
# Documented with comments rather than a class docstring on purpose: a docstring
# would surface as the "description" of the JSON schema generated from this model.
#
# NOTE(review): the default model below is not one of GOOGLE_REALTIME_MODELS, the
# list offered as schema examples. Confirm which identifier is intended.
@register_service(ServiceType.REALTIME)
class GoogleRealtimeLLMConfiguration(BaseLLMConfiguration):
    # Discriminator value identifying this configuration.
    provider: Literal[ServiceProviders.GOOGLE_REALTIME] = ServiceProviders.GOOGLE_REALTIME

    # Model identifier; the schema lists examples and flags custom input as allowed.
    model: str = Field(
        default="gemini-2.0-flash-live-001",
        json_schema_extra={"examples": GOOGLE_REALTIME_MODELS, "allow_custom_input": True},
    )

    # Voice name; the schema lists examples and flags custom input as allowed.
    voice: str = Field(
        default="Puck",
        json_schema_extra={"examples": GOOGLE_REALTIME_VOICES, "allow_custom_input": True},
    )
|
||||
|
||||
|
||||
# String values of every provider that is a realtime (speech-to-speech) provider.
REALTIME_PROVIDERS = {
    member.value
    for member in (
        ServiceProviders.OPENAI_REALTIME,
        ServiceProviders.GOOGLE_REALTIME,
    )
}
|
||||
|
||||
|
||||
LLMConfig = Annotated[
|
||||
Union[
|
||||
OpenAILLMService,
|
||||
|
|
@ -293,6 +361,14 @@ LLMConfig = Annotated[
|
|||
Field(discriminator="provider"),
|
||||
]
|
||||
|
||||
# Union of realtime service configurations, discriminated by their ``provider``
# field. Only the Google configuration is a member for now; the OpenAI realtime
# configuration class is commented out in this module.
RealtimeConfig = Annotated[
    Union[GoogleRealtimeLLMConfiguration],  # OpenAIRealtimeLLMConfiguration: disabled
    Field(discriminator="provider"),
]
|
||||
|
||||
###################################################### TTS ########################################################################
|
||||
|
||||
|
||||
|
|
@ -719,6 +795,7 @@ SPEACHES_STT_MODELS = [
|
|||
"Systran/faster-distil-whisper-small.en",
|
||||
"Systran/faster-whisper-large-v3",
|
||||
]
|
||||
# Language codes offered as examples for the Speaches STT ``language`` field.
SPEACHES_STT_LANGUAGES = [
    "en",
    "ar",
    "nl",
    "fr",
    "de",
    "hi",
    "it",
    "pt",
    "es",
]
|
||||
|
||||
|
||||
@register_stt
|
||||
|
|
@ -731,6 +808,13 @@ class SpeachesSTTConfiguration(BaseSTTConfiguration):
|
|||
"allow_custom_input": True,
|
||||
},
|
||||
)
|
||||
language: str = Field(
|
||||
default="en",
|
||||
json_schema_extra={
|
||||
"examples": SPEACHES_STT_LANGUAGES,
|
||||
"allow_custom_input": True,
|
||||
},
|
||||
)
|
||||
base_url: str = Field(
|
||||
default="http://localhost:8000/v1",
|
||||
description="OpenAI-compatible STT endpoint (Speaches, etc.)",
|
||||
|
|
@ -785,6 +869,6 @@ EmbeddingsConfig = Annotated[
|
|||
]
|
||||
|
||||
ServiceConfig = Annotated[
|
||||
Union[LLMConfig, TTSConfig, STTConfig, EmbeddingsConfig],
|
||||
Union[LLMConfig, RealtimeConfig, TTSConfig, STTConfig, EmbeddingsConfig],
|
||||
Field(discriminator="provider"),
|
||||
]
|
||||
|
|
|
|||
|
|
@ -97,6 +97,34 @@ def build_pipeline(
|
|||
return Pipeline(processors)
|
||||
|
||||
|
||||
def build_realtime_pipeline(
    transport,
    realtime_llm,
    audio_buffer,
    user_context_aggregator,
    assistant_context_aggregator,
    pipeline_engine_callback_processor,
    pipeline_metrics_aggregator,
):
    """Assemble the processor chain for a realtime (speech-to-speech) LLM.

    A realtime service (e.g. OpenAI Realtime, Gemini Live) performs STT, LLM and
    TTS internally, so the chain contains no standalone STT or TTS processor.

    Args:
        transport: Object providing the ``input()`` and ``output()`` processors.
        realtime_llm: The realtime LLM service processor.
        audio_buffer: Processor placed directly after the transport output.
        user_context_aggregator: Processor placed directly before the LLM.
        assistant_context_aggregator: Processor placed after the audio buffer.
        pipeline_engine_callback_processor: Processor placed directly after the LLM.
        pipeline_metrics_aggregator: Final processor in the chain.

    Returns:
        A ``Pipeline`` built from the processors in the order described above.
    """
    # Everything from the transport input up to and including the engine callback.
    upstream = [
        transport.input(),
        user_context_aggregator,
        realtime_llm,
        pipeline_engine_callback_processor,
    ]
    # Everything from the transport output onwards.
    downstream = [
        transport.output(),
        audio_buffer,
        assistant_context_aggregator,
        pipeline_metrics_aggregator,
    ]
    return Pipeline(upstream + downstream)
|
||||
|
||||
|
||||
def create_pipeline_task(pipeline, workflow_run_id, audio_config: AudioConfig = None):
|
||||
"""Create a pipeline task with appropriate parameters"""
|
||||
# Set up pipeline params with audio configuration if provided
|
||||
|
|
|
|||
|
|
@ -16,6 +16,7 @@ from api.services.pipecat.event_handlers import (
|
|||
from api.services.pipecat.in_memory_buffers import InMemoryLogsBuffer
|
||||
from api.services.pipecat.pipeline_builder import (
|
||||
build_pipeline,
|
||||
build_realtime_pipeline,
|
||||
create_pipeline_components,
|
||||
create_pipeline_task,
|
||||
)
|
||||
|
|
@ -35,6 +36,7 @@ from api.services.pipecat.recording_router_processor import RecordingRouterProce
|
|||
from api.services.pipecat.service_factory import (
|
||||
create_llm_service,
|
||||
create_llm_service_from_provider,
|
||||
create_realtime_llm_service,
|
||||
create_stt_service,
|
||||
create_tts_service,
|
||||
)
|
||||
|
|
@ -603,10 +605,18 @@ async def _run_pipeline(
|
|||
term.strip() for term in dictionary.split(",") if term.strip()
|
||||
]
|
||||
|
||||
# Detect realtime mode (speech-to-speech services like OpenAI Realtime, Gemini Live)
|
||||
is_realtime = user_config.is_realtime and user_config.realtime is not None
|
||||
|
||||
# Create services based on user configuration
|
||||
stt = create_stt_service(user_config, audio_config, keyterms=keyterms)
|
||||
tts = create_tts_service(user_config, audio_config)
|
||||
llm = create_llm_service(user_config)
|
||||
if is_realtime:
|
||||
llm = create_realtime_llm_service(user_config, audio_config)
|
||||
stt = None
|
||||
tts = None
|
||||
else:
|
||||
stt = create_stt_service(user_config, audio_config, keyterms=keyterms)
|
||||
tts = create_tts_service(user_config, audio_config)
|
||||
llm = create_llm_service(user_config)
|
||||
|
||||
workflow_graph = WorkflowGraph(
|
||||
ReactFlowDTO.model_validate(workflow.workflow_definition_with_fallback)
|
||||
|
|
@ -694,46 +704,66 @@ async def _run_pipeline(
|
|||
)
|
||||
|
||||
# Configure turn strategies based on STT provider, model, and workflow configuration
|
||||
# Deepgram Flux uses external turn detection (VAD + External start/stop)
|
||||
# Other models use configurable turn detection strategy
|
||||
is_deepgram_flux = (
|
||||
user_config.stt.provider == ServiceProviders.DEEPGRAM.value
|
||||
and user_config.stt.model == "flux-general-en"
|
||||
)
|
||||
if is_realtime:
|
||||
# Realtime services have server-side VAD/turn detection.
|
||||
# For stop strategy, let's rely on SmartTurnAnalyzer which is
|
||||
# enabled by default
|
||||
user_turn_strategies = UserTurnStrategies(
|
||||
start=[VADUserTurnStartStrategy()], stop=[]
|
||||
)
|
||||
|
||||
if is_deepgram_flux:
|
||||
user_turn_strategies = UserTurnStrategies(
|
||||
start=[
|
||||
VADUserTurnStartStrategy(),
|
||||
ExternalUserTurnStartStrategy(enable_interruptions=True),
|
||||
],
|
||||
stop=[ExternalUserTurnStopStrategy()],
|
||||
)
|
||||
elif turn_stop_strategy == "turn_analyzer":
|
||||
# Smart Turn Analyzer: best for longer responses with natural pauses
|
||||
smart_turn_params = SmartTurnParams(stop_secs=smart_turn_stop_secs)
|
||||
user_turn_strategies = UserTurnStrategies(
|
||||
start=[VADUserTurnStartStrategy(), TranscriptionUserTurnStartStrategy()],
|
||||
stop=[
|
||||
TurnAnalyzerUserTurnStopStrategy(
|
||||
turn_analyzer=LocalSmartTurnAnalyzerV3(params=smart_turn_params)
|
||||
)
|
||||
],
|
||||
)
|
||||
# Let's not start the pipeline muted in realtime mode
|
||||
# - CallbackUserMuteStrategy: mutes based on engine's _mute_pipeline state
|
||||
user_mute_strategies = [
|
||||
FunctionCallUserMuteStrategy(),
|
||||
CallbackUserMuteStrategy(should_mute_callback=engine.should_mute_user),
|
||||
]
|
||||
else:
|
||||
# Transcription-based (default): best for short 1-2 word responses
|
||||
user_turn_strategies = UserTurnStrategies(
|
||||
start=[VADUserTurnStartStrategy(), TranscriptionUserTurnStartStrategy()],
|
||||
stop=[SpeechTimeoutUserTurnStopStrategy()],
|
||||
# Deepgram Flux uses external turn detection (VAD + External start/stop)
|
||||
# Other models use configurable turn detection strategy
|
||||
is_deepgram_flux = (
|
||||
user_config.stt.provider == ServiceProviders.DEEPGRAM.value
|
||||
and user_config.stt.model == "flux-general-en"
|
||||
)
|
||||
|
||||
# Create user mute strategies
|
||||
# - CallbackUserMuteStrategy: mutes based on engine's _mute_pipeline state
|
||||
user_mute_strategies = [
|
||||
MuteUntilFirstBotCompleteUserMuteStrategy(),
|
||||
FunctionCallUserMuteStrategy(),
|
||||
CallbackUserMuteStrategy(should_mute_callback=engine.should_mute_user),
|
||||
]
|
||||
if is_deepgram_flux:
|
||||
user_turn_strategies = UserTurnStrategies(
|
||||
start=[
|
||||
VADUserTurnStartStrategy(),
|
||||
ExternalUserTurnStartStrategy(enable_interruptions=True),
|
||||
],
|
||||
stop=[ExternalUserTurnStopStrategy()],
|
||||
)
|
||||
elif turn_stop_strategy == "turn_analyzer":
|
||||
# Smart Turn Analyzer: best for longer responses with natural pauses
|
||||
smart_turn_params = SmartTurnParams(stop_secs=smart_turn_stop_secs)
|
||||
user_turn_strategies = UserTurnStrategies(
|
||||
start=[
|
||||
VADUserTurnStartStrategy(),
|
||||
TranscriptionUserTurnStartStrategy(),
|
||||
],
|
||||
stop=[
|
||||
TurnAnalyzerUserTurnStopStrategy(
|
||||
turn_analyzer=LocalSmartTurnAnalyzerV3(params=smart_turn_params)
|
||||
)
|
||||
],
|
||||
)
|
||||
else:
|
||||
# Transcription-based (default): best for short 1-2 word responses
|
||||
user_turn_strategies = UserTurnStrategies(
|
||||
start=[
|
||||
VADUserTurnStartStrategy(),
|
||||
TranscriptionUserTurnStartStrategy(),
|
||||
],
|
||||
stop=[SpeechTimeoutUserTurnStopStrategy()],
|
||||
)
|
||||
|
||||
# - CallbackUserMuteStrategy: mutes based on engine's _mute_pipeline state
|
||||
user_mute_strategies = [
|
||||
MuteUntilFirstBotCompleteUserMuteStrategy(),
|
||||
FunctionCallUserMuteStrategy(),
|
||||
CallbackUserMuteStrategy(should_mute_callback=engine.should_mute_user),
|
||||
]
|
||||
|
||||
user_params = LLMUserAggregatorParams(
|
||||
user_turn_strategies=user_turn_strategies,
|
||||
|
|
@ -769,77 +799,93 @@ async def _run_pipeline(
|
|||
async def on_user_turn_started(aggregator, strategy):
|
||||
user_idle_handler.reset()
|
||||
|
||||
# Create voicemail detector if enabled in workflow configurations
|
||||
# Voicemail detection and recording router are not supported in realtime mode
|
||||
voicemail_detector = None
|
||||
voicemail_config = (workflow.workflow_configurations or {}).get(
|
||||
"voicemail_detection", {}
|
||||
)
|
||||
if voicemail_config.get("enabled", False):
|
||||
logger.info(f"Voicemail detection enabled for workflow run {workflow_run_id}")
|
||||
# Create a separate LLM instance for the voicemail sub-pipeline
|
||||
# (can't share with main pipeline as it would mess up frame linking)
|
||||
if voicemail_config.get("use_workflow_llm", True):
|
||||
voicemail_llm = create_llm_service(user_config)
|
||||
else:
|
||||
voicemail_llm = create_llm_service_from_provider(
|
||||
provider=voicemail_config.get("provider", "openai"),
|
||||
model=voicemail_config.get("model", "gpt-4.1"),
|
||||
api_key=voicemail_config.get("api_key", ""),
|
||||
)
|
||||
|
||||
long_speech_timeout = voicemail_config.get("long_speech_timeout", 8.0)
|
||||
custom_system_prompt = voicemail_config.get("system_prompt") or None
|
||||
|
||||
voicemail_detector = VoicemailDetector(
|
||||
llm=voicemail_llm,
|
||||
long_speech_timeout=long_speech_timeout,
|
||||
custom_system_prompt=custom_system_prompt,
|
||||
)
|
||||
|
||||
# Register event handler to end task when voicemail is detected
|
||||
@voicemail_detector.event_handler("on_voicemail_detected")
|
||||
async def _on_voicemail_detected(_processor):
|
||||
logger.info(f"Voicemail detected for workflow run {workflow_run_id}")
|
||||
await engine.end_call_with_reason(
|
||||
reason=EndTaskReason.VOICEMAIL_DETECTED.value,
|
||||
abort_immediately=True,
|
||||
)
|
||||
|
||||
# Create recording router if workflow has active recordings
|
||||
recording_router = None
|
||||
if has_recordings:
|
||||
fetch_audio = create_recording_audio_fetcher(
|
||||
organization_id=workflow.organization_id,
|
||||
pipeline_sample_rate=audio_config.pipeline_sample_rate,
|
||||
|
||||
if not is_realtime:
|
||||
# Create voicemail detector if enabled in workflow configurations
|
||||
voicemail_config = (workflow.workflow_configurations or {}).get(
|
||||
"voicemail_detection", {}
|
||||
)
|
||||
recording_router = RecordingRouterProcessor(
|
||||
audio_sample_rate=audio_config.pipeline_sample_rate,
|
||||
fetch_recording_audio=fetch_audio,
|
||||
)
|
||||
# Warm the recording cache in the background so audio is ready
|
||||
# before the first playback request.
|
||||
asyncio.create_task(
|
||||
warm_recording_cache(
|
||||
workflow_id=workflow_id,
|
||||
if voicemail_config.get("enabled", False):
|
||||
logger.info(
|
||||
f"Voicemail detection enabled for workflow run {workflow_run_id}"
|
||||
)
|
||||
# Create a separate LLM instance for the voicemail sub-pipeline
|
||||
# (can't share with main pipeline as it would mess up frame linking)
|
||||
if voicemail_config.get("use_workflow_llm", True):
|
||||
voicemail_llm = create_llm_service(user_config)
|
||||
else:
|
||||
voicemail_llm = create_llm_service_from_provider(
|
||||
provider=voicemail_config.get("provider", "openai"),
|
||||
model=voicemail_config.get("model", "gpt-4.1"),
|
||||
api_key=voicemail_config.get("api_key", ""),
|
||||
)
|
||||
|
||||
long_speech_timeout = voicemail_config.get("long_speech_timeout", 8.0)
|
||||
custom_system_prompt = voicemail_config.get("system_prompt") or None
|
||||
|
||||
voicemail_detector = VoicemailDetector(
|
||||
llm=voicemail_llm,
|
||||
long_speech_timeout=long_speech_timeout,
|
||||
custom_system_prompt=custom_system_prompt,
|
||||
)
|
||||
|
||||
# Register event handler to end task when voicemail is detected
|
||||
@voicemail_detector.event_handler("on_voicemail_detected")
|
||||
async def _on_voicemail_detected(_processor):
|
||||
logger.info(f"Voicemail detected for workflow run {workflow_run_id}")
|
||||
await engine.end_call_with_reason(
|
||||
reason=EndTaskReason.VOICEMAIL_DETECTED.value,
|
||||
abort_immediately=True,
|
||||
)
|
||||
|
||||
# Create recording router if workflow has active recordings
|
||||
if has_recordings:
|
||||
fetch_audio = create_recording_audio_fetcher(
|
||||
organization_id=workflow.organization_id,
|
||||
pipeline_sample_rate=audio_config.pipeline_sample_rate,
|
||||
)
|
||||
)
|
||||
recording_router = RecordingRouterProcessor(
|
||||
audio_sample_rate=audio_config.pipeline_sample_rate,
|
||||
fetch_recording_audio=fetch_audio,
|
||||
)
|
||||
# Warm the recording cache in the background so audio is ready
|
||||
# before the first playback request.
|
||||
asyncio.create_task(
|
||||
warm_recording_cache(
|
||||
workflow_id=workflow_id,
|
||||
organization_id=workflow.organization_id,
|
||||
pipeline_sample_rate=audio_config.pipeline_sample_rate,
|
||||
)
|
||||
)
|
||||
|
||||
# Build the pipeline with the STT mute filter and context controller
|
||||
pipeline = build_pipeline(
|
||||
transport,
|
||||
stt,
|
||||
audio_buffer,
|
||||
llm,
|
||||
tts,
|
||||
user_context_aggregator,
|
||||
assistant_context_aggregator,
|
||||
pipeline_engine_callback_processor,
|
||||
pipeline_metrics_aggregator,
|
||||
voicemail_detector=voicemail_detector,
|
||||
recording_router=recording_router,
|
||||
)
|
||||
# Build the pipeline
|
||||
if is_realtime:
|
||||
pipeline = build_realtime_pipeline(
|
||||
transport,
|
||||
llm,
|
||||
audio_buffer,
|
||||
user_context_aggregator,
|
||||
assistant_context_aggregator,
|
||||
pipeline_engine_callback_processor,
|
||||
pipeline_metrics_aggregator,
|
||||
)
|
||||
else:
|
||||
pipeline = build_pipeline(
|
||||
transport,
|
||||
stt,
|
||||
audio_buffer,
|
||||
llm,
|
||||
tts,
|
||||
user_context_aggregator,
|
||||
assistant_context_aggregator,
|
||||
pipeline_engine_callback_processor,
|
||||
pipeline_metrics_aggregator,
|
||||
voicemail_detector=voicemail_detector,
|
||||
recording_router=recording_router,
|
||||
)
|
||||
|
||||
# Create pipeline task with audio configuration
|
||||
task = create_pipeline_task(pipeline, workflow_run_id, audio_config)
|
||||
|
|
@ -847,7 +893,8 @@ async def _run_pipeline(
|
|||
# Now set the task on the engine
|
||||
engine.set_task(task)
|
||||
|
||||
# Initialize the engine to set the initial context
|
||||
# Initialize the engine to set the initial context with
|
||||
# System Prompt and Tools
|
||||
await engine.initialize()
|
||||
|
||||
# Add real-time feedback observer (always logs to buffer, streams to WS if available)
|
||||
|
|
|
|||
|
|
@ -35,6 +35,7 @@ from pipecat.services.openai.tts import OpenAITTSService, OpenAITTSSettings
|
|||
from pipecat.services.openrouter.llm import OpenRouterLLMService, OpenRouterLLMSettings
|
||||
from pipecat.services.sarvam.stt import SarvamSTTService, SarvamSTTSettings
|
||||
from pipecat.services.sarvam.tts import SarvamTTSService, SarvamTTSSettings
|
||||
from pipecat.services.speaches.llm import SpeachesLLMService, SpeachesLLMSettings
|
||||
from pipecat.services.speaches.stt import SpeachesSTTService, SpeachesSTTSettings
|
||||
from pipecat.services.speaches.tts import SpeachesTTSService, SpeachesTTSSettings
|
||||
from pipecat.services.speechmatics.stt import (
|
||||
|
|
@ -63,7 +64,6 @@ def create_stt_service(
|
|||
if user_config.stt.provider == ServiceProviders.DEEPGRAM.value:
|
||||
# Check if using Flux model (English-only, no language selection)
|
||||
if user_config.stt.model == "flux-general-en":
|
||||
logger.debug("Using DeepGram Flux Model")
|
||||
return DeepgramFluxSTTService(
|
||||
api_key=user_config.stt.api_key,
|
||||
settings=DeepgramFluxSTTSettings(
|
||||
|
|
@ -395,15 +395,75 @@ def create_llm_service_from_provider(
|
|||
settings=AWSBedrockLLMSettings(model=model),
|
||||
)
|
||||
elif provider == ServiceProviders.SPEACHES.value:
|
||||
return OpenAILLMService(
|
||||
return SpeachesLLMService(
|
||||
base_url=base_url or "http://localhost:11434/v1",
|
||||
api_key=api_key or "none",
|
||||
settings=OpenAILLMSettings(model=model),
|
||||
settings=SpeachesLLMSettings(model=model),
|
||||
)
|
||||
else:
|
||||
raise HTTPException(status_code=400, detail=f"Invalid LLM provider {provider}")
|
||||
|
||||
|
||||
def create_realtime_llm_service(user_config, audio_config: "AudioConfig"):
    """Build the speech-to-speech LLM service selected in ``user_config.realtime``.

    A realtime service covers STT, LLM and TTS in a single service, so no
    separate STT or TTS service is created for it.

    Args:
        user_config: Configuration whose ``realtime`` attribute supplies
            ``provider``, ``model``, ``api_key`` and, when present, ``voice``.
        audio_config: Currently not read by this function.

    Returns:
        An ``OpenAIRealtimeLLMService`` or a ``GeminiLiveLLMService``, depending
        on the configured provider.

    Raises:
        HTTPException: With status 400 when the provider is neither the OpenAI
            nor the Google realtime provider.
    """
    rt_settings = user_config.realtime
    provider = rt_settings.provider
    model = rt_settings.model
    api_key = rt_settings.api_key
    # Read defensively: falls back to None when the config has no ``voice``.
    voice = getattr(rt_settings, "voice", None)

    logger.info(
        f"Creating realtime LLM service: provider={provider}, model={model}, voice={voice}"
    )

    if provider == ServiceProviders.OPENAI_REALTIME.value:
        # Imported here so these modules are only loaded when this provider is used.
        from pipecat.services.openai.realtime.events import (
            AudioConfiguration,
            AudioInput,
            AudioOutput,
            InputAudioTranscription,
            SessionProperties,
        )
        from pipecat.services.openai.realtime.llm import OpenAIRealtimeLLMService

        # Input audio requests transcription; output audio uses the chosen
        # voice, or "alloy" when none is set.
        audio_input = AudioInput(transcription=InputAudioTranscription())
        audio_output = AudioOutput(voice=voice or "alloy")
        session = SessionProperties(
            audio=AudioConfiguration(input=audio_input, output=audio_output),
        )
        openai_settings = OpenAIRealtimeLLMService.Settings(
            model=model,
            session_properties=session,
        )
        return OpenAIRealtimeLLMService(api_key=api_key, settings=openai_settings)

    if provider == ServiceProviders.GOOGLE_REALTIME.value:
        # Imported here so this module is only loaded when this provider is used.
        from pipecat.services.google.gemini_live.llm import GeminiLiveLLMService

        # Per the original author's note, Gemini Live turns on input/output audio
        # transcription by default in its _connect() method, so it is not
        # configured here. No VAD parameters are passed either.
        gemini_settings = GeminiLiveLLMService.Settings(
            model=model,
            voice=voice or "Puck",
        )
        return GeminiLiveLLMService(api_key=api_key, settings=gemini_settings)

    raise HTTPException(
        status_code=400, detail=f"Invalid realtime LLM provider {provider}"
    )
|
||||
|
||||
|
||||
def create_llm_service(user_config):
|
||||
"""Create and return appropriate LLM service based on user configuration."""
|
||||
provider = user_config.llm.provider
|
||||
|
|
|
|||
|
|
@ -210,12 +210,17 @@ class PipecatEngine:
|
|||
async def _update_llm_context(self, system_prompt: str, functions: list[dict]):
|
||||
"""Update LLM settings with the composed system prompt and tool list."""
|
||||
|
||||
await self.llm._update_settings(LLMSettings(system_instruction=system_prompt))
|
||||
|
||||
if functions:
|
||||
tools_schema = ToolsSchema(standard_tools=functions)
|
||||
self.context.set_tools(tools_schema)
|
||||
|
||||
await self.llm._update_settings(LLMSettings(system_instruction=system_prompt))
|
||||
|
||||
# For Gemini Live, set context on the LLM before _update_settings so that
|
||||
# _connect (triggered by reconnect) can read tools from it.
|
||||
if hasattr(self.llm, "_context") and not self.llm._context and self.context:
|
||||
self.llm._context = self.context
|
||||
|
||||
def _format_prompt(self, prompt: str) -> str:
|
||||
"""Delegate prompt formatting to the shared workflow.utils implementation."""
|
||||
|
||||
|
|
|
|||
|
|
@ -215,13 +215,17 @@ class VariableExtractionManager:
|
|||
with tracer.start_as_current_span(
|
||||
"llm-variable-extraction", context=parent_ctx
|
||||
) as span:
|
||||
tracing_messages = [
|
||||
{"role": "system", "content": system_prompt},
|
||||
*extraction_messages,
|
||||
]
|
||||
add_llm_span_attributes(
|
||||
span,
|
||||
service_name=self._engine.llm.__class__.__name__,
|
||||
model=model_name,
|
||||
operation_name="llm-variable-extraction",
|
||||
messages=extraction_messages,
|
||||
output=llm_response,
|
||||
messages=tracing_messages,
|
||||
output=json.dumps({"content": llm_response}),
|
||||
stream=False,
|
||||
parameters={},
|
||||
)
|
||||
|
|
|
|||
|
|
@ -59,7 +59,14 @@ async def _generate_conversation_summary(
|
|||
)
|
||||
|
||||
span_name = f"conversation-summary-before-{node_name}"
|
||||
add_qa_span_to_trace(parent_ctx, model, messages, summary, span_name)
|
||||
add_qa_span_to_trace(
|
||||
parent_ctx,
|
||||
model,
|
||||
messages,
|
||||
summary,
|
||||
span_name,
|
||||
CONVERSATION_SUMMARY_SYSTEM_PROMPT,
|
||||
)
|
||||
|
||||
return summary
|
||||
except Exception as e:
|
||||
|
|
@ -189,7 +196,9 @@ async def run_per_node_qa_analysis(
|
|||
|
||||
# Trace
|
||||
span_name = f"qa-node-{node_name}"
|
||||
add_qa_span_to_trace(parent_ctx, model, messages, raw_response, span_name)
|
||||
add_qa_span_to_trace(
|
||||
parent_ctx, model, messages, raw_response, span_name, system_content
|
||||
)
|
||||
|
||||
# Parse response
|
||||
node_result: dict[str, Any] = {
|
||||
|
|
@ -299,7 +308,9 @@ async def _run_whole_call_qa_analysis(
|
|||
|
||||
# Langfuse tracing
|
||||
parent_ctx = setup_langfuse_parent_context(workflow_run)
|
||||
add_qa_span_to_trace(parent_ctx, model, messages, raw_response, "qa-analysis")
|
||||
add_qa_span_to_trace(
|
||||
parent_ctx, model, messages, raw_response, "qa-analysis", system_content
|
||||
)
|
||||
|
||||
return {
|
||||
"node_results": {"whole_call": node_result},
|
||||
|
|
|
|||
|
|
@ -1,5 +1,7 @@
|
|||
"""LLM configuration resolution and token usage accumulation."""
|
||||
|
||||
import random
|
||||
|
||||
from api.db import db_client
|
||||
from api.db.models import WorkflowRunModel
|
||||
|
||||
|
|
@ -57,6 +59,8 @@ async def resolve_user_llm_config(
|
|||
|
||||
provider = llm_config.get("provider", "openai")
|
||||
api_key = llm_config.get("api_key", "")
|
||||
if isinstance(api_key, list):
|
||||
api_key = random.choice(api_key)
|
||||
model = llm_config.get("model", "gpt-4.1")
|
||||
|
||||
kwargs = {}
|
||||
|
|
|
|||
|
|
@ -166,7 +166,9 @@ async def ensure_node_summaries(
|
|||
continue
|
||||
|
||||
# Create a Langfuse trace for this summary generation
|
||||
trace_url = create_node_summary_trace(model, messages, summary_text, node_name)
|
||||
trace_url = create_node_summary_trace(
|
||||
model, messages, summary_text, node_name, NODE_SUMMARY_SYSTEM_PROMPT
|
||||
)
|
||||
|
||||
entry: dict[str, Any] = {"summary": summary_text}
|
||||
if trace_url:
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
"""Langfuse / OpenTelemetry tracing helpers for QA analysis."""
|
||||
|
||||
import json
|
||||
import re
|
||||
|
||||
from loguru import logger
|
||||
|
|
@ -70,6 +71,7 @@ def add_qa_span_to_trace(
|
|||
messages: list[dict],
|
||||
output: str,
|
||||
span_name: str,
|
||||
system_prompt: str = "",
|
||||
) -> None:
|
||||
"""Create a child span under the conversation trace."""
|
||||
if parent_ctx is None:
|
||||
|
|
@ -84,13 +86,21 @@ def add_qa_span_to_trace(
|
|||
span_name,
|
||||
context=parent_ctx,
|
||||
) as span:
|
||||
tracing_messages = (
|
||||
[
|
||||
{"role": "system", "content": system_prompt},
|
||||
*messages,
|
||||
]
|
||||
if system_prompt
|
||||
else messages
|
||||
)
|
||||
add_llm_span_attributes(
|
||||
span,
|
||||
service_name="OpenAILLMService",
|
||||
model=model,
|
||||
operation_name=span_name,
|
||||
messages=messages,
|
||||
output=output,
|
||||
messages=tracing_messages,
|
||||
output=json.dumps({"content": output}),
|
||||
stream=False,
|
||||
parameters={"temperature": 0},
|
||||
)
|
||||
|
|
@ -103,6 +113,7 @@ def create_node_summary_trace(
|
|||
messages: list[dict],
|
||||
output: str,
|
||||
node_name: str,
|
||||
system_prompt: str = "",
|
||||
) -> str | None:
|
||||
"""Create a standalone Langfuse trace for a node summary generation.
|
||||
|
||||
|
|
@ -125,13 +136,21 @@ def create_node_summary_trace(
|
|||
f"node-summary-{node_name}",
|
||||
context=Context(),
|
||||
) as span:
|
||||
tracing_messages = (
|
||||
[
|
||||
{"role": "system", "content": system_prompt},
|
||||
*messages,
|
||||
]
|
||||
if system_prompt
|
||||
else messages
|
||||
)
|
||||
add_llm_span_attributes(
|
||||
span,
|
||||
service_name="OpenAILLMService",
|
||||
model=model,
|
||||
operation_name=f"node-summary-{node_name}",
|
||||
messages=messages,
|
||||
output=output,
|
||||
messages=tracing_messages,
|
||||
output=json.dumps({"content": output}),
|
||||
stream=False,
|
||||
parameters={"temperature": 0},
|
||||
)
|
||||
|
|
|
|||
2
pipecat
2
pipecat
|
|
@ -1 +1 @@
|
|||
Subproject commit 2e2171e2a64ec87b3964fbc2440b5291489912a8
|
||||
Subproject commit 6954990a908b4beab2c7298e23e2dede9d20acdb
|
||||
|
|
@ -448,6 +448,11 @@ export type DefaultConfigurationsResponse = {
|
|||
[key: string]: unknown;
|
||||
};
|
||||
};
|
||||
realtime: {
|
||||
[key: string]: {
|
||||
[key: string]: unknown;
|
||||
};
|
||||
};
|
||||
default_providers: {
|
||||
[key: string]: string;
|
||||
};
|
||||
|
|
@ -1329,6 +1334,10 @@ export type UserConfigurationRequestResponseSchema = {
|
|||
embeddings?: {
|
||||
[key: string]: string | number | Array<string> | null;
|
||||
} | null;
|
||||
realtime?: {
|
||||
[key: string]: string | number | Array<string> | null;
|
||||
} | null;
|
||||
is_realtime?: boolean | null;
|
||||
test_phone_number?: string | null;
|
||||
timezone?: string | null;
|
||||
organization_pricing?: {
|
||||
|
|
|
|||
|
|
@ -11,12 +11,13 @@ import { Checkbox } from "@/components/ui/checkbox";
|
|||
import { Input } from "@/components/ui/input";
|
||||
import { Label } from "@/components/ui/label";
|
||||
import { Select, SelectContent, SelectItem, SelectTrigger, SelectValue } from "@/components/ui/select";
|
||||
import { Switch } from "@/components/ui/switch";
|
||||
import { Tabs, TabsContent, TabsList, TabsTrigger } from "@/components/ui/tabs";
|
||||
import { VoiceSelector } from "@/components/VoiceSelector";
|
||||
import { LANGUAGE_DISPLAY_NAMES } from "@/constants/languages";
|
||||
import { useUserConfig } from "@/context/UserConfigContext";
|
||||
|
||||
type ServiceSegment = "llm" | "tts" | "stt" | "embeddings";
|
||||
type ServiceSegment = "llm" | "tts" | "stt" | "embeddings" | "realtime";
|
||||
|
||||
interface SchemaProperty {
|
||||
type?: string;
|
||||
|
|
@ -41,13 +42,18 @@ interface FormValues {
|
|||
[key: string]: string | number | boolean;
|
||||
}
|
||||
|
||||
const TAB_CONFIG: { key: ServiceSegment; label: string }[] = [
|
||||
const STANDARD_TABS: { key: ServiceSegment; label: string }[] = [
|
||||
{ key: "llm", label: "LLM" },
|
||||
{ key: "tts", label: "Voice" },
|
||||
{ key: "stt", label: "Transcriber" },
|
||||
{ key: "embeddings", label: "Embedding" },
|
||||
];
|
||||
|
||||
const REALTIME_TABS: { key: ServiceSegment; label: string }[] = [
|
||||
{ key: "realtime", label: "Realtime Model" },
|
||||
{ key: "embeddings", label: "Embedding" },
|
||||
];
|
||||
|
||||
// Display names for Sarvam voices
|
||||
const VOICE_DISPLAY_NAMES: Record<string, string> = {
|
||||
"anushka": "Anushka (Female)",
|
||||
|
|
@ -62,24 +68,28 @@ const VOICE_DISPLAY_NAMES: Record<string, string> = {
|
|||
export default function ServiceConfiguration() {
|
||||
const [apiError, setApiError] = useState<string | null>(null);
|
||||
const [isSaving, setIsSaving] = useState(false);
|
||||
const [isRealtime, setIsRealtime] = useState(false);
|
||||
const { userConfig, saveUserConfig } = useUserConfig();
|
||||
const [schemas, setSchemas] = useState<Record<ServiceSegment, Record<string, ProviderSchema>>>({
|
||||
llm: {},
|
||||
tts: {},
|
||||
stt: {},
|
||||
embeddings: {}
|
||||
embeddings: {},
|
||||
realtime: {},
|
||||
});
|
||||
const [serviceProviders, setServiceProviders] = useState<Record<ServiceSegment, string>>({
|
||||
llm: "",
|
||||
tts: "",
|
||||
stt: "",
|
||||
embeddings: ""
|
||||
embeddings: "",
|
||||
realtime: "",
|
||||
});
|
||||
const [apiKeys, setApiKeys] = useState<Record<ServiceSegment, string[]>>({
|
||||
llm: [""],
|
||||
tts: [""],
|
||||
stt: [""],
|
||||
embeddings: [""],
|
||||
realtime: [""],
|
||||
});
|
||||
const [isCustomInput, setIsCustomInput] = useState<Record<string, boolean>>({});
|
||||
|
||||
|
|
@ -97,12 +107,20 @@ export default function ServiceConfiguration() {
|
|||
const fetchConfigurations = async () => {
|
||||
const response = await getDefaultConfigurationsApiV1UserConfigurationsDefaultsGet();
|
||||
if (response.data) {
|
||||
const data = response.data as Record<string, unknown>;
|
||||
setSchemas({
|
||||
llm: response.data.llm as Record<string, ProviderSchema>,
|
||||
tts: response.data.tts as Record<string, ProviderSchema>,
|
||||
stt: response.data.stt as Record<string, ProviderSchema>,
|
||||
embeddings: response.data.embeddings as Record<string, ProviderSchema>
|
||||
embeddings: response.data.embeddings as Record<string, ProviderSchema>,
|
||||
realtime: (data.realtime || {}) as Record<string, ProviderSchema>,
|
||||
});
|
||||
|
||||
// Restore realtime toggle from saved config
|
||||
const configData = userConfig as Record<string, unknown> | null;
|
||||
if (configData?.is_realtime) {
|
||||
setIsRealtime(true);
|
||||
}
|
||||
} else {
|
||||
console.error("Failed to fetch configurations");
|
||||
return;
|
||||
|
|
@ -113,23 +131,41 @@ export default function ServiceConfiguration() {
|
|||
llm: response.data.default_providers.llm,
|
||||
tts: response.data.default_providers.tts,
|
||||
stt: response.data.default_providers.stt,
|
||||
embeddings: response.data.default_providers.embeddings
|
||||
embeddings: response.data.default_providers.embeddings,
|
||||
realtime: "",
|
||||
};
|
||||
|
||||
// Set default realtime provider from schema keys
|
||||
const data = response.data as Record<string, unknown>;
|
||||
const realtimeSchemas = (data.realtime || {}) as Record<string, ProviderSchema>;
|
||||
const realtimeProviderKeys = Object.keys(realtimeSchemas);
|
||||
if (realtimeProviderKeys.length > 0) {
|
||||
selectedProviders.realtime = realtimeProviderKeys[0];
|
||||
}
|
||||
|
||||
const loadedApiKeys: Record<ServiceSegment, string[]> = {
|
||||
llm: [""],
|
||||
tts: [""],
|
||||
stt: [""],
|
||||
embeddings: [""],
|
||||
realtime: [""],
|
||||
};
|
||||
|
||||
const setServicePropertyValues = (service: ServiceSegment) => {
|
||||
if (userConfig?.[service]?.provider) {
|
||||
Object.entries(userConfig?.[service]).forEach(([field, value]) => {
|
||||
// For realtime, read from userConfig.realtime; for others, read from userConfig[service]
|
||||
const configSource = service === "realtime"
|
||||
? (userConfig as Record<string, unknown> | null)?.realtime as Record<string, unknown> | undefined
|
||||
: userConfig?.[service as "llm" | "tts" | "stt" | "embeddings"];
|
||||
|
||||
const schemaSource = service === "realtime"
|
||||
? realtimeSchemas
|
||||
: response.data[service as "llm" | "tts" | "stt" | "embeddings"] as Record<string, ProviderSchema> | undefined;
|
||||
|
||||
if (configSource?.provider) {
|
||||
Object.entries(configSource).forEach(([field, value]) => {
|
||||
if (field === "api_key") {
|
||||
// Handle api_key separately — it can be string or string[]
|
||||
if (Array.isArray(value)) {
|
||||
loadedApiKeys[service] = value.length > 0 ? value : [""];
|
||||
loadedApiKeys[service] = (value as string[]).length > 0 ? value as string[] : [""];
|
||||
} else {
|
||||
loadedApiKeys[service] = value ? [value as string] : [""];
|
||||
}
|
||||
|
|
@ -137,9 +173,9 @@ export default function ServiceConfiguration() {
|
|||
defaultValues[`${service}_${field}`] = value as string | number | boolean;
|
||||
}
|
||||
});
|
||||
selectedProviders[service] = userConfig?.[service]?.provider as string;
|
||||
// Fill in schema defaults for fields not present in userConfig
|
||||
const properties = response.data[service]?.[selectedProviders[service]]?.properties as Record<string, SchemaProperty>;
|
||||
selectedProviders[service] = configSource.provider as string;
|
||||
// Fill in schema defaults for fields not present in config
|
||||
const properties = schemaSource?.[selectedProviders[service]]?.properties as Record<string, SchemaProperty>;
|
||||
if (properties) {
|
||||
Object.entries(properties).forEach(([field, schema]) => {
|
||||
const key = `${service}_${field}`;
|
||||
|
|
@ -149,7 +185,7 @@ export default function ServiceConfiguration() {
|
|||
});
|
||||
}
|
||||
} else {
|
||||
const properties = response.data[service]?.[selectedProviders[service]]?.properties as Record<string, SchemaProperty>;
|
||||
const properties = schemaSource?.[selectedProviders[service]]?.properties as Record<string, SchemaProperty>;
|
||||
if (properties) {
|
||||
Object.entries(properties).forEach(([field, schema]) => {
|
||||
if (field !== "provider" && schema.default !== undefined) {
|
||||
|
|
@ -164,15 +200,20 @@ export default function ServiceConfiguration() {
|
|||
setServicePropertyValues("tts");
|
||||
setServicePropertyValues("stt");
|
||||
setServicePropertyValues("embeddings");
|
||||
setServicePropertyValues("realtime");
|
||||
|
||||
// Detect saved values that are not in suggested options (custom value)
|
||||
const detectedCustomInput: Record<string, boolean> = {};
|
||||
const allSchemas = response.data as Record<string, Record<string, ProviderSchema>>;
|
||||
(["llm", "tts", "stt", "embeddings"] as ServiceSegment[]).forEach(service => {
|
||||
const allSchemas = { ...response.data, realtime: realtimeSchemas } as unknown as Record<string, Record<string, ProviderSchema>>;
|
||||
(["llm", "tts", "stt", "embeddings", "realtime"] as ServiceSegment[]).forEach(service => {
|
||||
const provider = selectedProviders[service];
|
||||
const providerSchema = allSchemas[service]?.[provider];
|
||||
if (!providerSchema) return;
|
||||
|
||||
const configSource = service === "realtime"
|
||||
? (userConfig as Record<string, unknown> | null)?.realtime as Record<string, unknown> | undefined
|
||||
: userConfig?.[service as "llm" | "tts" | "stt" | "embeddings"];
|
||||
|
||||
Object.entries(providerSchema.properties).forEach(([field, schema]) => {
|
||||
const actualSchema = (schema as SchemaProperty).$ref && providerSchema.$defs
|
||||
? providerSchema.$defs[(schema as SchemaProperty).$ref!.split('/').pop() || '']
|
||||
|
|
@ -180,7 +221,7 @@ export default function ServiceConfiguration() {
|
|||
|
||||
if (!actualSchema?.allow_custom_input || !actualSchema?.examples) return;
|
||||
|
||||
const savedValue = userConfig?.[service]?.[field] as string | undefined;
|
||||
const savedValue = configSource?.[field] as string | undefined;
|
||||
if (savedValue && !actualSchema.examples.includes(savedValue)) {
|
||||
detectedCustomInput[`${service}_${field}`] = true;
|
||||
}
|
||||
|
|
@ -275,55 +316,42 @@ export default function ServiceConfiguration() {
|
|||
const getServiceApiKeys = (service: ServiceSegment): string[] =>
|
||||
apiKeys[service].map(k => k.trim()).filter(k => k.length > 0);
|
||||
|
||||
const userConfig: Record<ServiceSegment, Record<string, string | number | string[]>> = {
|
||||
llm: {
|
||||
provider: serviceProviders.llm,
|
||||
...(getServiceApiKeys("llm").length > 0 && { api_key: getServiceApiKeys("llm") }),
|
||||
model: data.llm_model as string
|
||||
},
|
||||
tts: {
|
||||
provider: serviceProviders.tts,
|
||||
...(getServiceApiKeys("tts").length > 0 && { api_key: getServiceApiKeys("tts") }),
|
||||
},
|
||||
stt: {
|
||||
provider: serviceProviders.stt,
|
||||
...(getServiceApiKeys("stt").length > 0 && { api_key: getServiceApiKeys("stt") }),
|
||||
},
|
||||
embeddings: {
|
||||
provider: serviceProviders.embeddings,
|
||||
...(getServiceApiKeys("embeddings").length > 0 && { api_key: getServiceApiKeys("embeddings") }),
|
||||
model: data.embeddings_model as string
|
||||
// Build service configs from form data
|
||||
const buildServiceConfig = (service: ServiceSegment) => {
|
||||
const config: Record<string, string | number | string[]> = {
|
||||
provider: serviceProviders[service],
|
||||
};
|
||||
const keys = getServiceApiKeys(service);
|
||||
if (keys.length > 0) {
|
||||
config.api_key = keys;
|
||||
}
|
||||
// Add all form fields for this service
|
||||
Object.entries(data).forEach(([property, value]) => {
|
||||
if (!property.startsWith(`${service}_`)) return;
|
||||
const field = property.slice(service.length + 1);
|
||||
if (field === "api_key" || field === "provider") return;
|
||||
config[field] = value as string | number;
|
||||
});
|
||||
return config;
|
||||
};
|
||||
|
||||
// Add any extra properties in the payload
|
||||
Object.entries(data).forEach(([property, value]) => {
|
||||
const parts = property.split('_');
|
||||
const service = parts[0] as ServiceSegment;
|
||||
const field = parts.slice(1).join('_');
|
||||
|
||||
if (field === "api_key") return; // handled via apiKeys state
|
||||
if (userConfig[service] && !(field in userConfig[service])) {
|
||||
(userConfig[service] as Record<string, string>)[field] = value as string;
|
||||
}
|
||||
});
|
||||
|
||||
// Build save config - only include embeddings if api_key is provided
|
||||
const saveConfig: {
|
||||
llm: Record<string, string | number | string[]>;
|
||||
tts: Record<string, string | number | string[]>;
|
||||
stt: Record<string, string | number | string[]>;
|
||||
embeddings?: Record<string, string | number | string[]>;
|
||||
} = {
|
||||
llm: userConfig.llm,
|
||||
tts: userConfig.tts,
|
||||
stt: userConfig.stt
|
||||
// Always save all configs so switching modes preserves everything
|
||||
const saveConfig: Record<string, unknown> = {
|
||||
llm: buildServiceConfig("llm"),
|
||||
tts: buildServiceConfig("tts"),
|
||||
stt: buildServiceConfig("stt"),
|
||||
is_realtime: isRealtime,
|
||||
};
|
||||
|
||||
// Save realtime config if provider is set
|
||||
if (serviceProviders.realtime) {
|
||||
saveConfig.realtime = buildServiceConfig("realtime");
|
||||
}
|
||||
|
||||
// Only include embeddings if user has configured it (has api_key)
|
||||
const embeddingsKeys = getServiceApiKeys("embeddings");
|
||||
if (embeddingsKeys.length > 0) {
|
||||
saveConfig.embeddings = userConfig.embeddings;
|
||||
saveConfig.embeddings = buildServiceConfig("embeddings");
|
||||
}
|
||||
|
||||
try {
|
||||
|
|
@ -612,6 +640,9 @@ export default function ServiceConfiguration() {
|
|||
);
|
||||
};
|
||||
|
||||
const visibleTabs = isRealtime ? REALTIME_TABS : STANDARD_TABS;
|
||||
const defaultTab = isRealtime ? "realtime" : "llm";
|
||||
|
||||
return (
|
||||
<div className="w-full max-w-2xl mx-auto">
|
||||
<div className="mb-6">
|
||||
|
|
@ -622,18 +653,35 @@ export default function ServiceConfiguration() {
|
|||
</div>
|
||||
|
||||
<form onSubmit={handleSubmit(onSubmit)}>
|
||||
{/* Realtime toggle */}
|
||||
<div className="flex items-center justify-between mb-4 p-4 border rounded-lg">
|
||||
<div>
|
||||
<Label htmlFor="realtime-toggle" className="text-sm font-medium">
|
||||
Realtime Mode
|
||||
</Label>
|
||||
<p className="text-xs text-muted-foreground mt-0.5">
|
||||
Uses a single speech-to-speech model (no separate STT/TTS)
|
||||
</p>
|
||||
</div>
|
||||
<Switch
|
||||
id="realtime-toggle"
|
||||
checked={isRealtime}
|
||||
onCheckedChange={setIsRealtime}
|
||||
/>
|
||||
</div>
|
||||
|
||||
<Card>
|
||||
<CardContent className="pt-6">
|
||||
<Tabs defaultValue="llm" className="w-full">
|
||||
<TabsList className="grid w-full grid-cols-4 mb-6">
|
||||
{TAB_CONFIG.map(({ key, label }) => (
|
||||
<Tabs key={defaultTab} defaultValue={defaultTab} className="w-full">
|
||||
<TabsList className="grid w-full mb-6" style={{ gridTemplateColumns: `repeat(${visibleTabs.length}, 1fr)` }}>
|
||||
{visibleTabs.map(({ key, label }) => (
|
||||
<TabsTrigger key={key} value={key}>
|
||||
{label}
|
||||
</TabsTrigger>
|
||||
))}
|
||||
</TabsList>
|
||||
|
||||
{TAB_CONFIG.map(({ key }) => (
|
||||
{visibleTabs.map(({ key }) => (
|
||||
<TabsContent key={key} value={key} className="mt-0">
|
||||
{renderServiceFields(key)}
|
||||
</TabsContent>
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue