feat: add support for self-hosted LLM models

This commit is contained in:
Abhishek Kumar 2026-03-24 17:50:45 +05:30
parent 31e075d114
commit ac0731a374
17 changed files with 179 additions and 48 deletions

View file

@ -40,6 +40,7 @@ class UserConfigurationValidator:
ServiceProviders.SPEECHMATICS.value: self._check_speechmatics_api_key,
ServiceProviders.CAMB.value: self._check_camb_api_key,
ServiceProviders.AWS_BEDROCK.value: self._check_aws_bedrock_api_key,
ServiceProviders.SELF_HOSTED.value: self._check_self_hosted_api_key,
}
async def validate(self, configuration: UserConfiguration) -> APIKeyStatusResponse:
@ -74,6 +75,20 @@ class UserConfigurationValidator:
provider = service_config.provider
# Self-hosted doesn't require an API key; only its base_url is validated
if provider == ServiceProviders.SELF_HOSTED.value:
try:
if not self._check_self_hosted_api_key(provider, service_config):
return [
{
"model": service_name,
"message": f"Invalid {provider} configuration",
}
]
except ValueError as e:
return [{"model": service_name, "message": str(e)}]
return []
# AWS Bedrock uses AWS credentials instead of api_key
if provider == ServiceProviders.AWS_BEDROCK.value:
try:
@ -163,7 +178,12 @@ class UserConfigurationValidator:
def _check_camb_api_key(self, model: str, api_key: str) -> bool:
return True
def _check_self_hosted_api_key(self, model: str, service_config) -> bool:
if not getattr(service_config, "base_url", None):
raise ValueError("base_url is required for self-hosted LLM")
return True
def _check_aws_bedrock_api_key(self, model: str, service_config) -> bool:
if not service_config.aws_access_key or not service_config.aws_secret_key:
raise ValueError("AWS access key and secret key are required for Bedrock")

View file

@ -27,6 +27,7 @@ class ServiceProviders(str, Enum):
SPEECHMATICS = "speechmatics"
CAMB = "camb"
AWS_BEDROCK = "aws_bedrock"
SELF_HOSTED = "self_hosted"
class BaseServiceConfiguration(BaseModel):
@ -40,6 +41,7 @@ class BaseServiceConfiguration(BaseModel):
ServiceProviders.AZURE,
ServiceProviders.DOGRAH,
ServiceProviders.AWS_BEDROCK,
ServiceProviders.SELF_HOSTED,
# ServiceProviders.SARVAM,
]
api_key: str | list[str]
@ -249,6 +251,22 @@ class AWSBedrockLLMConfiguration(BaseLLMConfiguration):
api_key: str | list[str] | None = Field(default=None)
# Model names surfaced as schema examples for the self-hosted provider.
SELF_HOSTED_LLM_MODELS = [
    "llama3",
    "mistral",
    "phi3",
    "qwen2",
    "gemma2",
    "deepseek-r1",
]


# LLM configuration for a self-hosted, OpenAI-compatible endpoint.
# Documented with comments rather than a class docstring so the generated
# JSON schema is left untouched.
@register_llm
class SelfHostedLLMConfiguration(BaseLLMConfiguration):
    provider: Literal[ServiceProviders.SELF_HOSTED] = ServiceProviders.SELF_HOSTED
    # Any string is accepted; SELF_HOSTED_LLM_MODELS only feeds the examples.
    model: str = Field(
        default="llama3",
        json_schema_extra={"examples": SELF_HOSTED_LLM_MODELS},
    )
    # Defaults to a local endpoint on port 11434.
    base_url: str = Field(
        default="http://localhost:11434/v1",
        description="OpenAI-compatible endpoint (Ollama, vLLM, etc.)",
    )
    # Optional for this provider: defaults to None.
    api_key: str | list[str] | None = Field(default=None)
LLMConfig = Annotated[
Union[
OpenAILLMService,
@ -258,6 +276,7 @@ LLMConfig = Annotated[
AzureLLMService,
DograhLLMService,
AWSBedrockLLMConfiguration,
SelfHostedLLMConfiguration,
],
Field(discriminator="provider"),
]
@ -334,6 +353,12 @@ class CartesiaTTSConfiguration(BaseTTSConfiguration):
)
voice: str = Field(default="3faa81ae-d3d8-4ab1-9e44-e50e46d33c30")
speed: float = Field(default=1.0, ge=0.6, le=1.5, description="Speed of the voice")
volume: float = Field(
default=1.0,
ge=0.5,
le=2.0,
description="Volume multiplier for generated speech",
)
SARVAM_TTS_MODELS = ["bulbul:v2", "bulbul:v3"]

View file

@ -66,6 +66,7 @@ class RecordingRouterProcessor(FrameProcessor):
self._frame_buffer: list[tuple[LLMTextFrame, FrameDirection]] = []
self._mode: Optional[str] = None # None = detecting, "tts", "recording"
self._recording_id_buffer = ""
self._recording_playback_started = False
# ------------------------------------------------------------------
# Frame dispatch
@ -99,9 +100,15 @@ class RecordingRouterProcessor(FrameProcessor):
await self.push_frame(frame, direction)
return
# --- Recording mode: accumulate recording_id silently ---
# --- Recording mode: accumulate text and start playback ASAP ---
if self._mode == "recording":
self._recording_id_buffer += frame.text
if not self._recording_playback_started:
buf = self._recording_id_buffer.lstrip()
if " " in buf:
recording_id = buf.split()[0]
self._recording_playback_started = True
await self._play_recording(recording_id)
return
# --- Detection mode: buffer until marker found ---
@ -178,16 +185,21 @@ class RecordingRouterProcessor(FrameProcessor):
self, frame: LLMFullResponseEndFrame, direction: FrameDirection
):
if self._mode == "recording":
recording_id = self._recording_id_buffer.strip()
if recording_id:
# Push accumulated text as TTSTextFrame for UI feedback via observer
full_text = self._recording_id_buffer.strip()
if full_text:
recording_id = full_text.split()[0]
# Push full text (marker + id + transcript) for assistant context
await self.push_frame(
TTSTextFrame(
text=RECORDING_MARKER + self._recording_id_buffer,
aggregated_by="recording_router",
)
)
await self._play_recording(recording_id)
# Fallback: if response ended before a space arrived (no transcript)
if not self._recording_playback_started:
await self._play_recording(recording_id)
else:
logger.warning(
"RecordingRouterProcessor: recording mode but empty recording_id"
@ -256,3 +268,4 @@ class RecordingRouterProcessor(FrameProcessor):
self._frame_buffer = []
self._mode = None
self._recording_id_buffer = ""
self._recording_playback_started = False

View file

@ -8,7 +8,11 @@ from api.services.configuration.registry import ServiceProviders
from pipecat.services.aws.llm import AWSBedrockLLMService, AWSBedrockLLMSettings
from pipecat.services.azure.llm import AzureLLMService, AzureLLMSettings
from pipecat.services.cartesia.stt import CartesiaSTTService
from pipecat.services.cartesia.tts import CartesiaTTSService, CartesiaTTSSettings, GenerationConfig
from pipecat.services.cartesia.tts import (
CartesiaTTSService,
CartesiaTTSSettings,
GenerationConfig,
)
from pipecat.services.deepgram.flux.stt import (
DeepgramFluxSTTService,
DeepgramFluxSTTSettings,
@ -212,13 +216,19 @@ def create_tts_service(user_config, audio_config: "AudioConfig"):
)
elif user_config.tts.provider == ServiceProviders.CARTESIA.value:
speed = getattr(user_config.tts, "speed", None)
generation_config = GenerationConfig(speed=speed) if speed and speed != 1.0 else None
generation_config = (
GenerationConfig(speed=speed) if speed and speed != 1.0 else None
)
return CartesiaTTSService(
api_key=user_config.tts.api_key,
settings=CartesiaTTSSettings(
voice=user_config.tts.voice,
model=user_config.tts.model,
**({"generation_config": generation_config} if generation_config else {}),
**(
{"generation_config": generation_config}
if generation_config
else {}
),
),
text_filters=[xml_function_tag_filter],
silence_time_s=1.0,
@ -353,6 +363,12 @@ def create_llm_service_from_provider(
aws_region=aws_region,
settings=AWSBedrockLLMSettings(model=model),
)
elif provider == ServiceProviders.SELF_HOSTED.value:
return OpenAILLMService(
base_url=base_url or "http://localhost:11434/v1",
api_key=api_key or "none",
settings=OpenAILLMSettings(model=model),
)
else:
raise HTTPException(status_code=400, detail=f"Invalid LLM provider {provider}")
@ -368,6 +384,8 @@ def create_llm_service(user_config):
kwargs["base_url"] = user_config.llm.base_url
elif provider == ServiceProviders.AZURE.value:
kwargs["endpoint"] = user_config.llm.endpoint
elif provider == ServiceProviders.SELF_HOSTED.value:
kwargs["base_url"] = user_config.llm.base_url
elif provider == ServiceProviders.AWS_BEDROCK.value:
kwargs["aws_access_key"] = user_config.llm.aws_access_key
kwargs["aws_secret_key"] = user_config.llm.aws_secret_key

View file

@ -437,9 +437,7 @@ class PipecatEngine:
async def _do_extraction():
try:
logger.debug(
f"Starting variable extraction for node: {node.name}"
)
logger.debug(f"Starting variable extraction for node: {node.name}")
extracted_data = (
await self._variable_extraction_manager._perform_extraction(
extraction_variables, parent_context, extraction_prompt
@ -454,7 +452,9 @@ class PipecatEngine:
f"Variable extraction completed for node: {node.name}. Extracted: {extracted_data}"
)
except Exception as e:
logger.error(f"Error during variable extraction for node {node.name}: {str(e)}")
logger.error(
f"Error during variable extraction for node {node.name}: {str(e)}"
)
if run_in_background:
logger.debug(
@ -497,9 +497,7 @@ class PipecatEngine:
logger.error(
f"Pending extraction task '{task_name}' failed: {result}"
)
logger.debug(
f"All pending extraction tasks completed in {elapsed:.2f}s"
)
logger.debug(f"All pending extraction tasks completed in {elapsed:.2f}s")
except asyncio.TimeoutError:
incomplete = [
t.get_name() for t in self._pending_extraction_tasks if not t.done()

View file

@ -34,13 +34,13 @@ You have two modes for responding:
Example: Hello! How can I help you today?
2. PRE-RECORDED AUDIO (): Play a pre-recorded audio message.
Format: `` followed by a space and ONLY the recording_id. Nothing else.
Example: rec_greeting_01
Format: `` followed by a space followed by recording_id followed by provided transcript. Nothing else.
Example: rec_greeting_01 [ Provided Transcript ]
RULES:
- Your response MUST start with either `` or `` as the very first character.
- For `` (dynamic speech): Follow with a space and your full response text.
- For `` (pre-recorded audio): Follow with a space and ONLY the recording_id. No other text.
- For `` (pre-recorded audio): Follow with a space and the recording_id and the provided transcript. No other text.
- Use `` when a pre-recorded message matches the situation well.
- Use `` when you need to generate a dynamic, contextual response.
- NEVER mix modes in a single response. Choose one."""

View file

@ -28,7 +28,9 @@ from api.utils.template_renderer import render_template
from pipecat.processors.aggregators.llm_context import LLMContext
async def _run_llm_inference(llm, messages: list[dict], system_prompt: str) -> str | None:
async def _run_llm_inference(
llm, messages: list[dict], system_prompt: str
) -> str | None:
"""Run a one-shot LLM inference using the pipecat service."""
context = LLMContext()
context.set_messages(messages)
@ -51,7 +53,10 @@ async def _generate_conversation_summary(
]
try:
summary = await _run_llm_inference(llm, messages, CONVERSATION_SUMMARY_SYSTEM_PROMPT) or ""
summary = (
await _run_llm_inference(llm, messages, CONVERSATION_SUMMARY_SYSTEM_PROMPT)
or ""
)
span_name = f"conversation-summary-before-{node_name}"
add_qa_span_to_trace(parent_ctx, model, messages, summary, span_name)

View file

@ -154,7 +154,12 @@ async def ensure_node_summaries(
try:
context = LLMContext()
context.set_messages(messages)
summary_text = await llm.run_inference(context, system_instruction=NODE_SUMMARY_SYSTEM_PROMPT) or ""
summary_text = (
await llm.run_inference(
context, system_instruction=NODE_SUMMARY_SYSTEM_PROMPT
)
or ""
)
except Exception as e:
logger.warning(f"Failed to generate summary for node {node_id}: {e}")
updated_summaries[node_id] = {"summary": ""}