From ac0731a374692c3d4e4517587cd2bbbbacac1d8a Mon Sep 17 00:00:00 2001 From: Abhishek Kumar Date: Tue, 24 Mar 2026 17:50:45 +0530 Subject: [PATCH] feat: add support for self hosted llm models --- .gitignore | 1 + api/db/workflow_recording_client.py | 2 +- api/services/configuration/check_validity.py | 22 ++++++++++++++- api/services/configuration/registry.py | 25 +++++++++++++++++ .../pipecat/recording_router_processor.py | 23 ++++++++++++---- api/services/pipecat/service_factory.py | 24 ++++++++++++++--- api/services/workflow/pipecat_engine.py | 12 ++++----- .../pipecat_engine_context_composer.py | 6 ++--- api/services/workflow/qa/analysis.py | 9 +++++-- api/services/workflow/qa/node_summary.py | 7 ++++- api/tests/test_camb_tts_integration.py | 5 ++-- ui/AGENTS.md | 6 +++++ ui/src/app/superadmin/runs/page.tsx | 18 ++++++++----- .../components/RecordingsDialog.tsx | 27 ++++++++++++++----- .../[workflowId]/utils/layoutNodes.ts | 24 ++++++++++++----- ui/src/components/ServiceConfiguration.tsx | 12 ++++++++- ui/src/components/flow/MentionTextarea.tsx | 4 ++- 17 files changed, 179 insertions(+), 48 deletions(-) diff --git a/.gitignore b/.gitignore index e92242d..0e6b619 100644 --- a/.gitignore +++ b/.gitignore @@ -16,3 +16,4 @@ venv/ .playwright-mcp coturn/ *.wav +dograh_pcm_cache/ \ No newline at end of file diff --git a/api/db/workflow_recording_client.py b/api/db/workflow_recording_client.py index c8e6604..f0c81c5 100644 --- a/api/db/workflow_recording_client.py +++ b/api/db/workflow_recording_client.py @@ -64,7 +64,7 @@ class WorkflowRecordingClient(BaseDBClient): storage_key=storage_key, storage_backend=storage_backend, created_by=created_by, - metadata=metadata or {}, + recording_metadata=metadata or {}, ) session.add(recording) diff --git a/api/services/configuration/check_validity.py b/api/services/configuration/check_validity.py index 0208db6..399ec28 100644 --- a/api/services/configuration/check_validity.py +++ b/api/services/configuration/check_validity.py @@ -40,6 +40,7 @@ class UserConfigurationValidator: ServiceProviders.SPEECHMATICS.value: self._check_speechmatics_api_key, ServiceProviders.CAMB.value: self._check_camb_api_key, ServiceProviders.AWS_BEDROCK.value: self._check_aws_bedrock_api_key, + ServiceProviders.SELF_HOSTED.value: self._check_self_hosted_api_key, } async def validate(self, configuration: UserConfiguration) -> APIKeyStatusResponse: @@ -74,6 +75,20 @@ class UserConfigurationValidator: provider = service_config.provider + # Self-hosted doesn't require an API key + if provider == ServiceProviders.SELF_HOSTED.value: + try: + if not self._check_self_hosted_api_key(provider, service_config): + return [ + { + "model": service_name, + "message": f"Invalid {provider} configuration", + } + ] + except ValueError as e: + return [{"model": service_name, "message": str(e)}] + return [] + # AWS Bedrock uses AWS credentials instead of api_key if provider == ServiceProviders.AWS_BEDROCK.value: try: @@ -163,7 +178,12 @@ class UserConfigurationValidator: def _check_camb_api_key(self, model: str, api_key: str) -> bool: return True - + + def _check_self_hosted_api_key(self, model: str, service_config) -> bool: + if not getattr(service_config, "base_url", None): + raise ValueError("base_url is required for self-hosted LLM") + return True + def _check_aws_bedrock_api_key(self, model: str, service_config) -> bool: if not service_config.aws_access_key or not service_config.aws_secret_key: raise ValueError("AWS access key and secret key are required for Bedrock") diff --git a/api/services/configuration/registry.py b/api/services/configuration/registry.py index 49288ee..6cc22a0 100644 --- a/api/services/configuration/registry.py +++ b/api/services/configuration/registry.py @@ -27,6 +27,7 @@ class ServiceProviders(str, Enum): SPEECHMATICS = "speechmatics" CAMB = "camb" AWS_BEDROCK = "aws_bedrock" + SELF_HOSTED = "self_hosted" class BaseServiceConfiguration(BaseModel): @@ -40,6 +41,7 @@ class BaseServiceConfiguration(BaseModel): ServiceProviders.AZURE, ServiceProviders.DOGRAH, ServiceProviders.AWS_BEDROCK, + ServiceProviders.SELF_HOSTED, # ServiceProviders.SARVAM, ] api_key: str | list[str] @@ -249,6 +251,22 @@ class AWSBedrockLLMConfiguration(BaseLLMConfiguration): api_key: str | list[str] | None = Field(default=None) +SELF_HOSTED_LLM_MODELS = ["llama3", "mistral", "phi3", "qwen2", "gemma2", "deepseek-r1"] + + +@register_llm +class SelfHostedLLMConfiguration(BaseLLMConfiguration): + provider: Literal[ServiceProviders.SELF_HOSTED] = ServiceProviders.SELF_HOSTED + model: str = Field( + default="llama3", json_schema_extra={"examples": SELF_HOSTED_LLM_MODELS} + ) + base_url: str = Field( + default="http://localhost:11434/v1", + description="OpenAI-compatible endpoint (Ollama, vLLM, etc.)", + ) + api_key: str | list[str] | None = Field(default=None) + + LLMConfig = Annotated[ Union[ OpenAILLMService, @@ -258,6 +276,7 @@ LLMConfig = Annotated[ AzureLLMService, DograhLLMService, AWSBedrockLLMConfiguration, + SelfHostedLLMConfiguration, ], Field(discriminator="provider"), ] @@ -334,6 +353,12 @@ class CartesiaTTSConfiguration(BaseTTSConfiguration): ) voice: str = Field(default="3faa81ae-d3d8-4ab1-9e44-e50e46d33c30") speed: float = Field(default=1.0, ge=0.6, le=1.5, description="Speed of the voice") + volume: float = Field( + default=1.0, + ge=0.5, + le=2.0, + description="Volume multiplier for generated speech", + ) SARVAM_TTS_MODELS = ["bulbul:v2", "bulbul:v3"] diff --git a/api/services/pipecat/recording_router_processor.py b/api/services/pipecat/recording_router_processor.py index be4eb5b..6c22d35 100644 --- a/api/services/pipecat/recording_router_processor.py +++ b/api/services/pipecat/recording_router_processor.py @@ -66,6 +66,7 @@ class RecordingRouterProcessor(FrameProcessor): self._frame_buffer: list[tuple[LLMTextFrame, FrameDirection]] = [] self._mode: Optional[str] = None # None = detecting, "tts", "recording" self._recording_id_buffer = "" + self._recording_playback_started = False # ------------------------------------------------------------------ # Frame dispatch @@ -99,9 +100,15 @@ class RecordingRouterProcessor(FrameProcessor): await self.push_frame(frame, direction) return - # --- Recording mode: accumulate recording_id silently --- + # --- Recording mode: accumulate text and start playback ASAP --- if self._mode == "recording": self._recording_id_buffer += frame.text + if not self._recording_playback_started: + buf = self._recording_id_buffer.lstrip() + if " " in buf: + recording_id = buf.split()[0] + self._recording_playback_started = True + await self._play_recording(recording_id) return # --- Detection mode: buffer until marker found --- @@ -178,16 +185,21 @@ class RecordingRouterProcessor(FrameProcessor): self, frame: LLMFullResponseEndFrame, direction: FrameDirection ): if self._mode == "recording": - recording_id = self._recording_id_buffer.strip() - if recording_id: - # Push accumulated text as TTSTextFrame for UI feedback via observer + full_text = self._recording_id_buffer.strip() + if full_text: + recording_id = full_text.split()[0] + + # Push full text (marker + id + transcript) for assistant context await self.push_frame( TTSTextFrame( text=RECORDING_MARKER + self._recording_id_buffer, aggregated_by="recording_router", ) ) - await self._play_recording(recording_id) + + # Fallback: if response ended before a space arrived (no transcript) + if not self._recording_playback_started: + await self._play_recording(recording_id) else: logger.warning( "RecordingRouterProcessor: recording mode but empty recording_id" @@ -256,3 +268,4 @@ class RecordingRouterProcessor(FrameProcessor): self._frame_buffer = [] self._mode = None self._recording_id_buffer = "" + self._recording_playback_started = False diff --git a/api/services/pipecat/service_factory.py b/api/services/pipecat/service_factory.py index 6d00082..33d54c3 100644 --- a/api/services/pipecat/service_factory.py +++ b/api/services/pipecat/service_factory.py @@ -8,7 +8,11 @@ from api.services.configuration.registry import ServiceProviders from pipecat.services.aws.llm import AWSBedrockLLMService, AWSBedrockLLMSettings from pipecat.services.azure.llm import AzureLLMService, AzureLLMSettings from pipecat.services.cartesia.stt import CartesiaSTTService -from pipecat.services.cartesia.tts import CartesiaTTSService, CartesiaTTSSettings, GenerationConfig +from pipecat.services.cartesia.tts import ( + CartesiaTTSService, + CartesiaTTSSettings, + GenerationConfig, +) from pipecat.services.deepgram.flux.stt import ( DeepgramFluxSTTService, DeepgramFluxSTTSettings, @@ -212,13 +216,19 @@ def create_tts_service(user_config, audio_config: "AudioConfig"): ) elif user_config.tts.provider == ServiceProviders.CARTESIA.value: speed = getattr(user_config.tts, "speed", None) - generation_config = GenerationConfig(speed=speed) if speed and speed != 1.0 else None + generation_config = ( + GenerationConfig(speed=speed) if speed and speed != 1.0 else None + ) return CartesiaTTSService( api_key=user_config.tts.api_key, settings=CartesiaTTSSettings( voice=user_config.tts.voice, model=user_config.tts.model, - **({"generation_config": generation_config} if generation_config else {}), + **( + {"generation_config": generation_config} + if generation_config + else {} + ), ), text_filters=[xml_function_tag_filter], silence_time_s=1.0, @@ -353,6 +363,12 @@ def create_llm_service_from_provider( aws_region=aws_region, settings=AWSBedrockLLMSettings(model=model), ) + elif provider == ServiceProviders.SELF_HOSTED.value: + return OpenAILLMService( + base_url=base_url or "http://localhost:11434/v1", + api_key=api_key or "none", + settings=OpenAILLMSettings(model=model), + ) else: raise HTTPException(status_code=400, detail=f"Invalid LLM provider {provider}") @@ -368,6 +384,8 @@ def create_llm_service(user_config): kwargs["base_url"] = user_config.llm.base_url elif provider == ServiceProviders.AZURE.value: kwargs["endpoint"] = user_config.llm.endpoint + elif provider == ServiceProviders.SELF_HOSTED.value: + kwargs["base_url"] = user_config.llm.base_url elif provider == ServiceProviders.AWS_BEDROCK.value: kwargs["aws_access_key"] = user_config.llm.aws_access_key kwargs["aws_secret_key"] = user_config.llm.aws_secret_key diff --git a/api/services/workflow/pipecat_engine.py b/api/services/workflow/pipecat_engine.py index e41bde4..e28bc4d 100644 --- a/api/services/workflow/pipecat_engine.py +++ b/api/services/workflow/pipecat_engine.py @@ -437,9 +437,7 @@ class PipecatEngine: async def _do_extraction(): try: - logger.debug( - f"Starting variable extraction for node: {node.name}" - ) + logger.debug(f"Starting variable extraction for node: {node.name}") extracted_data = ( await self._variable_extraction_manager._perform_extraction( extraction_variables, parent_context, extraction_prompt @@ -454,7 +452,9 @@ class PipecatEngine: f"Variable extraction completed for node: {node.name}. Extracted: {extracted_data}" ) except Exception as e: - logger.error(f"Error during variable extraction for node {node.name}: {str(e)}") + logger.error( + f"Error during variable extraction for node {node.name}: {str(e)}" + ) if run_in_background: logger.debug( @@ -497,9 +497,7 @@ class PipecatEngine: logger.error( f"Pending extraction task '{task_name}' failed: {result}" ) - logger.debug( - f"All pending extraction tasks completed in {elapsed:.2f}s" - ) + logger.debug(f"All pending extraction tasks completed in {elapsed:.2f}s") except asyncio.TimeoutError: incomplete = [ t.get_name() for t in self._pending_extraction_tasks if not t.done() diff --git a/api/services/workflow/pipecat_engine_context_composer.py b/api/services/workflow/pipecat_engine_context_composer.py index 0674e4b..96b7236 100644 --- a/api/services/workflow/pipecat_engine_context_composer.py +++ b/api/services/workflow/pipecat_engine_context_composer.py @@ -34,13 +34,13 @@ You have two modes for responding: Example: ▸ Hello! How can I help you today? 2. PRE-RECORDED AUDIO (●): Play a pre-recorded audio message. - Format: `●` followed by a space and ONLY the recording_id. Nothing else. - Example: ● rec_greeting_01 + Format: `●` followed by a space followed by recording_id followed by provided transcript. Nothing else. + Example: ● rec_greeting_01 [ Provided Transcript ] RULES: - Your response MUST start with either `▸` or `●` as the very first character. - For `▸` (dynamic speech): Follow with a space and your full response text. -- For `●` (pre-recorded audio): Follow with a space and ONLY the recording_id. No other text. +- For `●` (pre-recorded audio): Follow with a space and the recording_id and the provided transcript. No other text. - Use `●` when a pre-recorded message matches the situation well. - Use `▸` when you need to generate a dynamic, contextual response. - NEVER mix modes in a single response. Choose one.""" diff --git a/api/services/workflow/qa/analysis.py b/api/services/workflow/qa/analysis.py index 55f5ab3..bd064a1 100644 --- a/api/services/workflow/qa/analysis.py +++ b/api/services/workflow/qa/analysis.py @@ -28,7 +28,9 @@ from api.utils.template_renderer import render_template from pipecat.processors.aggregators.llm_context import LLMContext -async def _run_llm_inference(llm, messages: list[dict], system_prompt: str) -> str | None: +async def _run_llm_inference( + llm, messages: list[dict], system_prompt: str +) -> str | None: """Run a one-shot LLM inference using the pipecat service.""" context = LLMContext() context.set_messages(messages) @@ -51,7 +53,10 @@ async def _generate_conversation_summary( ] try: - summary = await _run_llm_inference(llm, messages, CONVERSATION_SUMMARY_SYSTEM_PROMPT) or "" + summary = ( + await _run_llm_inference(llm, messages, CONVERSATION_SUMMARY_SYSTEM_PROMPT) + or "" + ) span_name = f"conversation-summary-before-{node_name}" add_qa_span_to_trace(parent_ctx, model, messages, summary, span_name) diff --git a/api/services/workflow/qa/node_summary.py b/api/services/workflow/qa/node_summary.py index b02e59c..db33980 100644 --- a/api/services/workflow/qa/node_summary.py +++ b/api/services/workflow/qa/node_summary.py @@ -154,7 +154,12 @@ async def ensure_node_summaries( try: context = LLMContext() context.set_messages(messages) - summary_text = await llm.run_inference(context, system_instruction=NODE_SUMMARY_SYSTEM_PROMPT) or "" + summary_text = ( + await llm.run_inference( + context, system_instruction=NODE_SUMMARY_SYSTEM_PROMPT + ) + or "" + ) except Exception as e: logger.warning(f"Failed to generate summary for node {node_id}: {e}") updated_summaries[node_id] = {"summary": ""} diff --git a/api/tests/test_camb_tts_integration.py b/api/tests/test_camb_tts_integration.py index 3d6c4a8..8b53919 100644 --- a/api/tests/test_camb_tts_integration.py +++ b/api/tests/test_camb_tts_integration.py @@ -9,7 +9,7 @@ Covers: """ from types import SimpleNamespace -from unittest.mock import AsyncMock, MagicMock, patch +from unittest.mock import MagicMock, patch import pytest from pydantic import ValidationError @@ -17,13 +17,12 @@ from pydantic import ValidationError from api.services.configuration.check_validity import UserConfigurationValidator from api.services.configuration.registry import ( CAMB_TTS_MODELS, - CambTTSConfiguration, REGISTRY, + CambTTSConfiguration, ServiceProviders, ServiceType, ) - # --------------------------------------------------------------------------- # 1. CambTTSConfiguration model tests # --------------------------------------------------------------------------- diff --git a/ui/AGENTS.md b/ui/AGENTS.md index 34ab371..e821582 100644 --- a/ui/AGENTS.md +++ b/ui/AGENTS.md @@ -48,6 +48,12 @@ new api route in backend, and wish to use it in the UI, generate the client usin npm run generate-client ``` +## Conventions + +### File Uploads + +Always use a hidden `` with a visible `