feat: add hybrid text + recording functionality in agents (#191)

* feat: add recording feature in agents

* chore: pin pipecat version

* feat: show usage in UI

* chore: update pipecat
This commit is contained in:
Abhishek 2026-03-16 15:04:08 +05:30 committed by GitHub
parent f075bcb623
commit 494c60d774
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
43 changed files with 2865 additions and 397 deletions

View file

@ -1,4 +1,4 @@
from typing import Dict, Optional, TypedDict
from typing import Optional, TypedDict
import openai
from deepgram import DeepgramClient
@ -12,6 +12,7 @@ from api.schemas.user_configuration import (
UserConfiguration,
)
from api.services.configuration.registry import ServiceConfig, ServiceProviders
from api.services.mps_service_key_client import mps_service_key_client
class APIKeyStatus(TypedDict):
@ -25,7 +26,6 @@ class APIKeyStatusResponse(TypedDict):
class UserConfigurationValidator:
def __init__(self):
self._provider_api_key_validity_status: Dict[str, bool] = {}
self._validator_map = {
ServiceProviders.OPENAI.value: self._check_openai_api_key,
ServiceProviders.DEEPGRAM.value: self._check_deepgram_api_key,
@ -73,8 +73,13 @@ class UserConfigurationValidator:
provider = service_config.provider
api_key = service_config.api_key
if not self._check_api_key(provider, api_key):
return [{"model": service_name, "message": f"Invalid {provider} API key"}]
try:
if not self._check_api_key(provider, api_key):
return [
{"model": service_name, "message": f"Invalid {provider} API key"}
]
except ValueError as e:
return [{"model": service_name, "message": str(e)}]
return []
@ -87,40 +92,28 @@ class UserConfigurationValidator:
return validator(provider, api_key)
def _check_openai_api_key(self, model: str, api_key: str) -> bool:
if model in self._provider_api_key_validity_status:
return self._provider_api_key_validity_status[model]
client = openai.OpenAI(api_key=api_key)
try:
client.models.list()
self._provider_api_key_validity_status[model] = True
return True
except openai.AuthenticationError:
self._provider_api_key_validity_status[model] = False
return self._provider_api_key_validity_status[model]
return False
def _check_deepgram_api_key(self, model: str, api_key: str) -> bool:
if model in self._provider_api_key_validity_status:
return self._provider_api_key_validity_status[model]
try:
deepgram = DeepgramClient(api_key=api_key)
deepgram.manage.v1.projects.list()
self._provider_api_key_validity_status[model] = True
return True
except Exception:
self._provider_api_key_validity_status[model] = False
return self._provider_api_key_validity_status[model]
return False
def _check_groq_api_key(self, model: str, api_key: str) -> bool:
if model in self._provider_api_key_validity_status:
return self._provider_api_key_validity_status[model]
client = Groq(api_key=api_key)
try:
client.models.list()
self._provider_api_key_validity_status[model] = True
return True
except Exception:
self._provider_api_key_validity_status[model] = False
return self._provider_api_key_validity_status[model]
return False
def _validate_elevenlabs_api_key(self, model: str, api_key: str) -> bool:
return True
@ -135,7 +128,12 @@ class UserConfigurationValidator:
return True
def _check_dograh_api_key(self, model: str, api_key: str) -> bool:
return True
if api_key.startswith("dgr"):
raise ValueError(
"You provided a Dograh API key (dgr...) instead of a service key. "
"Please use a service key (mps...)."
)
return mps_service_key_client.validate_service_key(api_key)
def _check_sarvam_api_key(self, model: str, api_key: str) -> bool:
return True

View file

@ -285,6 +285,90 @@ class MPSServiceKeyClient:
response=response,
)
async def get_usage_by_created_by(self, created_by: str) -> dict:
"""
Get aggregated usage for all service keys created by a user (OSS mode).
Args:
created_by: The user's provider ID
Returns:
Dictionary containing total_credits_used and remaining_credits
"""
async with httpx.AsyncClient(timeout=self.timeout) as client:
response = await client.post(
f"{self.base_url}/api/v1/service-keys/usage/created-by",
json={"created_by": created_by},
headers=self._get_headers(created_by=created_by),
)
if response.status_code == 200:
data = response.json()
return {
"total_credits_used": data.get("total_credits_used", 0.0),
"remaining_credits": data.get("remaining_credits", 0.0),
}
else:
logger.error(
f"Failed to get usage by created_by: {response.status_code} - {response.text}"
)
raise httpx.HTTPStatusError(
f"Failed to get usage by created_by: {response.text}",
request=response.request,
response=response,
)
async def get_usage_by_organization(self, organization_id: int) -> dict:
"""
Get aggregated usage for all service keys belonging to an organization (hosted mode).
Args:
organization_id: The organization's ID
Returns:
Dictionary containing total_credits_used and remaining_credits
"""
async with httpx.AsyncClient(timeout=self.timeout) as client:
response = await client.post(
f"{self.base_url}/api/v1/service-keys/usage/organization",
json={"organization_id": organization_id},
headers=self._get_headers(organization_id=organization_id),
)
if response.status_code == 200:
data = response.json()
return {
"total_credits_used": data.get("total_credits_used", 0.0),
"remaining_credits": data.get("remaining_credits", 0.0),
}
else:
logger.error(
f"Failed to get usage by organization: {response.status_code} - {response.text}"
)
raise httpx.HTTPStatusError(
f"Failed to get usage by organization: {response.text}",
request=response.request,
response=response,
)
def validate_service_key(self, service_key: str) -> bool:
"""
Synchronously validate a Dograh service key by checking usage via MPS.
Returns True if the key is valid, False otherwise.
"""
try:
with httpx.Client(timeout=self.timeout) as client:
response = client.post(
f"{self.base_url}/api/v1/service-keys/usage",
json={"service_key": service_key},
headers=self._get_headers(),
)
return response.status_code == 200
except Exception:
logger.warning("Failed to validate Dograh service key via MPS")
return False
async def get_voices(
self,
provider: str,

View file

@ -39,6 +39,7 @@ def build_pipeline(
pipeline_engine_callback_processor,
pipeline_metrics_aggregator,
voicemail_detector=None,
recording_router=None,
):
"""Build the main pipeline with all components.
@ -47,6 +48,9 @@ def build_pipeline(
voicemail_detector: Optional native pipecat VoicemailDetector. When provided,
inserts voicemail detection after STT. Note: We don't use the TTS gate
to avoid blocking TTS frames during classification.
recording_router: Optional RecordingRouterProcessor. When provided,
inserts between callback processor and TTS to route between
pre-recorded audio playback and dynamic TTS.
"""
# Build processors list with optional voicemail detection
processors = [
@ -66,11 +70,15 @@ def build_pipeline(
processors.append(voicemail_detector.detector())
# Continue with the rest of the pipeline
post_llm = [pipeline_engine_callback_processor]
if recording_router:
post_llm.append(recording_router)
processors.extend(
[
user_context_aggregator,
llm, # LLM
pipeline_engine_callback_processor,
*post_llm,
tts, # TTS
transport.output(), # Transport bot output
audio_buffer, # AudioBufferProcessor - records both input and output audio

View file

@ -37,10 +37,10 @@ from pipecat.frames.frames import (
FunctionCallResultFrame,
InterimTranscriptionFrame,
InterruptionFrame,
LLMTextFrame,
MetricsFrame,
StopFrame,
TranscriptionFrame,
TTSTextFrame,
)
from pipecat.metrics.metrics import TTFBMetricsData
from pipecat.observers.base_observer import BaseObserver, FramePushed
@ -207,7 +207,7 @@ class RealtimeFeedbackObserver(BaseObserver):
)
# Handle bot TTS text - respect pts timing, WebSocket only
# Complete turn text is persisted via register_turn_handlers
elif isinstance(frame, TTSTextFrame):
elif isinstance(frame, LLMTextFrame):
message = {
"type": RealtimeFeedbackType.BOT_TEXT.value,
"payload": {

View file

@ -0,0 +1,338 @@
"""Filesystem-backed cache and audio fetcher for workflow recordings.
Downloads recording files from object storage on first access, converts them
to raw 16-bit mono PCM at the pipeline sample rate via ffmpeg, trims
leading/trailing silence, and caches the processed bytes on disk so
subsequent plays (even from other workers) are instantaneous.
"""
import asyncio
import os
import shutil
import tempfile
from typing import Awaitable, Callable, Optional
import numpy as np
from loguru import logger
from api.constants import APP_ROOT_DIR
from pipecat.audio.utils import SPEAKING_THRESHOLD
# ---------------------------------------------------------------------------
# Filesystem cache directory
# ---------------------------------------------------------------------------
_CACHE_DIR = os.path.join(os.path.dirname(APP_ROOT_DIR), "dograh_pcm_cache")
os.makedirs(_CACHE_DIR, exist_ok=True)
def _cache_path(recording_id: str, sample_rate: int) -> str:
"""Return the on-disk path for a cached PCM file."""
return os.path.join(_CACHE_DIR, f"{recording_id}_{sample_rate}.pcm")
# ---------------------------------------------------------------------------
# Public factory
# ---------------------------------------------------------------------------
def create_recording_audio_fetcher(
organization_id: int,
pipeline_sample_rate: int,
) -> Callable[[str], Awaitable[Optional[bytes]]]:
"""Create an async callback that returns raw PCM bytes for a recording_id.
The returned callable:
1. Checks the filesystem cache (keyed by ``recording_id`` + sample rate).
2. On miss, looks up the recording in the DB, downloads the audio file
from S3/MinIO, converts it to 16-bit mono PCM at *pipeline_sample_rate*,
trims leading/trailing silence, caches the result on disk, and returns it.
Args:
organization_id: Organization owning the recordings.
pipeline_sample_rate: Target PCM sample rate for the pipeline.
Returns:
``async (recording_id: str) -> Optional[bytes]``
"""
from api.db import db_client
from api.services.storage import get_storage_for_backend
# Resolve storage instances once per backend at creation time, not per fetch.
_storage_cache: dict[str, object] = {}
def _get_storage(backend: str):
if backend not in _storage_cache:
_storage_cache[backend] = get_storage_for_backend(backend)
return _storage_cache[backend]
async def fetch(recording_id: str) -> Optional[bytes]:
cached = _cache_path(recording_id, pipeline_sample_rate)
# 1. Serve from filesystem cache
if os.path.exists(cached):
logger.debug(f"Recording {recording_id} served from disk cache")
return _read_file(cached)
# 2. DB lookup
recording = await db_client.get_recording_by_recording_id(
recording_id, organization_id
)
if not recording:
logger.warning(f"Recording {recording_id} not found in database")
return None
# 3. Download, convert, trim, and cache
pcm_data = await _download_and_convert(
recording, pipeline_sample_rate, _get_storage
)
return pcm_data
return fetch
# ---------------------------------------------------------------------------
# Cache warming
# ---------------------------------------------------------------------------
async def warm_recording_cache(
workflow_id: int,
organization_id: int,
pipeline_sample_rate: int,
) -> None:
"""Pre-fetch all active recordings for a workflow into the disk cache.
Launched as a background ``asyncio.Task`` at pipeline startup so that
recordings are ready before the first playback request. Errors are logged
but never propagated a cache miss falls back to the on-demand fetch path.
"""
from api.db import db_client
from api.services.storage import get_storage_for_backend
try:
recordings = await db_client.get_recordings_for_workflow(
workflow_id, organization_id
)
if not recordings:
return
# Skip if every recording is already cached on disk
uncached = [
r
for r in recordings
if not os.path.exists(_cache_path(r.recording_id, pipeline_sample_rate))
]
if not uncached:
logger.debug(f"Recording cache already warm for workflow {workflow_id}")
return
logger.info(
f"Warming recording cache: {len(uncached)}/{len(recordings)} "
f"recording(s) for workflow {workflow_id}"
)
# Resolve storage instances once per backend, not per recording
storage_by_backend: dict[str, object] = {}
def _get_storage(backend: str):
if backend not in storage_by_backend:
storage_by_backend[backend] = get_storage_for_backend(backend)
return storage_by_backend[backend]
for recording in uncached:
try:
pcm_data = await _download_and_convert(
recording, pipeline_sample_rate, _get_storage
)
if pcm_data:
logger.debug(
f"Cache warm: loaded {recording.recording_id} "
f"({len(pcm_data)} bytes)"
)
except Exception:
logger.exception(
f"Cache warm: error processing {recording.recording_id}"
)
logger.info(f"Recording cache warm complete for workflow {workflow_id}")
except Exception:
logger.exception("Recording cache warm failed")
# ---------------------------------------------------------------------------
# Shared download → convert → trim → cache-to-disk helper
# ---------------------------------------------------------------------------
async def _download_and_convert(
recording, sample_rate: int, get_storage_fn
) -> Optional[bytes]:
"""Download a recording from storage, convert to PCM, trim, and cache to disk.
Returns the processed PCM bytes, or None on failure.
"""
ext = _ext_from_key(recording.storage_key)
fd, tmp_path = tempfile.mkstemp(suffix=ext, prefix=f"dograh_dl_{recording.recording_id}_")
os.close(fd)
try:
storage = get_storage_fn(recording.storage_backend)
success = await storage.adownload_file(recording.storage_key, tmp_path)
if not success:
logger.error(f"Failed to download recording {recording.recording_id}")
return None
pcm_data = await _audio_file_to_pcm(tmp_path, sample_rate)
if pcm_data is None:
return None
pcm_data = _trim_silence(pcm_data, sample_rate)
# Write to disk cache atomically (write to tmp then rename)
cached = _cache_path(recording.recording_id, sample_rate)
fd, tmp_cache = tempfile.mkstemp(dir=_CACHE_DIR, suffix=".pcm.tmp")
os.close(fd)
_write_file(tmp_cache, pcm_data)
os.replace(tmp_cache, cached)
return pcm_data
except Exception:
logger.exception(f"Error fetching recording {recording.recording_id}")
return None
finally:
if os.path.exists(tmp_path):
try:
os.unlink(tmp_path)
except OSError:
pass
# ---------------------------------------------------------------------------
# File I/O helpers (run via asyncio.to_thread)
# ---------------------------------------------------------------------------
def _read_file(path: str) -> bytes:
with open(path, "rb") as f:
return f.read()
def _write_file(path: str, data: bytes) -> None:
with open(path, "wb") as f:
f.write(data)
# ---------------------------------------------------------------------------
# Audio conversion
# ---------------------------------------------------------------------------
async def _audio_file_to_pcm(
file_path: str, target_sample_rate: int
) -> Optional[bytes]:
"""Convert an audio file to raw 16-bit mono PCM bytes via ffmpeg."""
ffmpeg = shutil.which("ffmpeg")
if not ffmpeg:
logger.error("ffmpeg not found on PATH — cannot decode recording")
return None
cmd = [
ffmpeg,
"-i",
file_path,
"-f",
"s16le", # raw 16-bit signed little-endian PCM
"-acodec",
"pcm_s16le",
"-ac",
"1", # mono
"-ar",
str(target_sample_rate),
"-loglevel",
"error",
"pipe:1", # output to stdout
]
try:
proc = await asyncio.create_subprocess_exec(
*cmd,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
)
stdout, stderr = await proc.communicate()
if proc.returncode != 0:
logger.error(f"ffmpeg failed (rc={proc.returncode}): {stderr.decode()}")
return None
if not stdout:
logger.error("ffmpeg produced no output")
return None
return stdout
except Exception:
logger.exception("ffmpeg subprocess error")
return None
# ---------------------------------------------------------------------------
# Silence trimming
# ---------------------------------------------------------------------------
def _trim_silence(pcm_data: bytes, sample_rate: int) -> bytes:
"""Trim leading and trailing silence from raw 16-bit mono PCM bytes.
Uses 10ms frames and the same amplitude threshold as pipecat's
``is_silence`` to detect speech boundaries.
"""
data = np.frombuffer(pcm_data, dtype=np.int16)
frame_size = int(sample_rate * 0.01) # 10ms frames
num_frames = len(data) // frame_size
if num_frames == 0:
return pcm_data
# Find first non-silent frame
first_speech = None
for i in range(num_frames):
frame = data[i * frame_size : (i + 1) * frame_size]
if np.abs(frame).max() > SPEAKING_THRESHOLD:
first_speech = i
break
if first_speech is None:
# Entire clip is silence — return as-is to avoid empty audio
return pcm_data
# Find last non-silent frame
last_speech = first_speech
for i in range(num_frames - 1, first_speech - 1, -1):
frame = data[i * frame_size : (i + 1) * frame_size]
if np.abs(frame).max() > SPEAKING_THRESHOLD:
last_speech = i
break
start = first_speech * frame_size
end = (last_speech + 1) * frame_size
trimmed = data[start:end]
trimmed_duration = len(trimmed) / sample_rate
original_duration = len(data) / sample_rate
if original_duration - trimmed_duration > 0.05:
logger.debug(
f"Trimmed silence: {original_duration:.2f}s → {trimmed_duration:.2f}s"
)
return trimmed.tobytes()
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _ext_from_key(storage_key: str) -> str:
"""Extract file extension from a storage key, defaulting to .wav."""
_, ext = os.path.splitext(storage_key)
return ext if ext else ".wav"

View file

@ -0,0 +1,254 @@
"""Recording router processor for routing LLM output between TTS and pre-recorded audio.
Sits between the LLM (after pipeline_engine_callbacks_processor) and TTS in the
pipeline. Detects response mode markers ( for TTS, for recording) and routes
accordingly:
- (TTS): Strips the marker, passes remaining text downstream to TTS.
- (Recording): Suppresses TTS, fetches cached audio, pushes
OutputAudioRawFrame downstream.
Pattern modelled after ``pipecat.turns.user_turn_completion_mixin`` buffer
streaming LLM text tokens until the mode marker is detected, then act.
"""
import uuid
from typing import Awaitable, Callable, Optional
from loguru import logger
from api.services.workflow.pipecat_engine_context_composer import (
RECORDING_MARKER,
TTS_MARKER,
)
from pipecat.frames.frames import (
Frame,
InterruptionFrame,
LLMFullResponseEndFrame,
LLMTextFrame,
TTSAudioRawFrame,
TTSStartedFrame,
TTSStoppedFrame,
)
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
class RecordingRouterProcessor(FrameProcessor):
"""Routes LLM responses between TTS and pre-recorded audio playback.
When the LLM prefixes its response with:
- ```` text flows to TTS as normal speech.
- ```` text is suppressed (skip_tts), and the referenced recording is
fetched (with local disk cache) and streamed as ``OutputAudioRawFrame``.
If no marker is detected by the end of the response, text is passed through
to TTS as a graceful degradation.
Args:
audio_sample_rate: Pipeline sample rate for OutputAudioRawFrame.
fetch_recording_audio: Async callback that takes a recording_id and
returns raw 16-bit mono PCM bytes, or None on failure.
"""
def __init__(
self,
*,
audio_sample_rate: int,
fetch_recording_audio: Callable[[str], Awaitable[Optional[bytes]]],
**kwargs,
):
super().__init__(**kwargs)
self._audio_sample_rate = audio_sample_rate
self._fetch_recording_audio = fetch_recording_audio
# Per-response state
self._frame_buffer: list[tuple[LLMTextFrame, FrameDirection]] = []
self._mode: Optional[str] = None # None = detecting, "tts", "recording"
self._recording_id_buffer = ""
# ------------------------------------------------------------------
# Frame dispatch
# ------------------------------------------------------------------
async def process_frame(self, frame: Frame, direction: FrameDirection):
await super().process_frame(frame, direction)
if isinstance(frame, InterruptionFrame):
self._reset()
await self.push_frame(frame, direction)
elif isinstance(frame, LLMTextFrame):
await self._handle_llm_text(frame, direction)
elif isinstance(frame, LLMFullResponseEndFrame):
await self._handle_response_end(frame, direction)
else:
await self.push_frame(frame, direction)
# ------------------------------------------------------------------
# LLMTextFrame handling
# ------------------------------------------------------------------
async def _handle_llm_text(self, frame: LLMTextFrame, direction: FrameDirection):
# Pass through frames already marked skip_tts (e.g. turn completion ✓)
if frame.skip_tts:
await self.push_frame(frame, direction)
return
# --- TTS mode established: pass text through normally ---
if self._mode == "tts":
await self.push_frame(frame, direction)
return
# --- Recording mode: buffer recording_id, suppress TTS ---
if self._mode == "recording":
self._recording_id_buffer += frame.text
frame.skip_tts = True
await self.push_frame(frame, direction)
return
# --- Detection mode: buffer until marker found ---
self._frame_buffer.append((frame, direction))
buffered_text = self._buffered_text()
# Check for recording marker (●)
if RECORDING_MARKER in buffered_text:
self._mode = "recording"
marker_end = buffered_text.index(RECORDING_MARKER) + len(RECORDING_MARKER)
# Push buffered frames with skip_tts, extract recording_id from post-marker text
cumulative = 0
for buf_frame, buf_dir in self._frame_buffer:
buf_frame.skip_tts = True
frame_start = cumulative
cumulative += len(buf_frame.text)
await self.push_frame(buf_frame, buf_dir)
# Capture any recording_id text after the marker
if cumulative > marker_end:
offset = max(marker_end - frame_start, 0)
remaining = buf_frame.text[offset:]
if not self._recording_id_buffer and remaining.startswith(" "):
remaining = remaining[1:]
self._recording_id_buffer += remaining
self._frame_buffer = []
return
# Check for TTS marker (▸)
if TTS_MARKER in buffered_text:
self._mode = "tts"
marker_end = buffered_text.index(TTS_MARKER) + len(TTS_MARKER)
# Push buffered frames — skip_tts for marker portion, normal for the rest
cumulative = 0
for buf_frame, buf_dir in self._frame_buffer:
frame_start = cumulative
cumulative += len(buf_frame.text)
if cumulative <= marker_end:
# Entirely within marker portion — suppress TTS
buf_frame.skip_tts = True
await self.push_frame(buf_frame, buf_dir)
elif frame_start >= marker_end:
# Entirely after marker — normal TTS speech
if frame_start == marker_end and buf_frame.text.startswith(" "):
buf_frame.text = buf_frame.text[1:]
if buf_frame.text:
await self.push_frame(buf_frame, buf_dir)
else:
# Frame spans the marker boundary — split
offset = marker_end - frame_start
original_text = buf_frame.text
buf_frame.text = original_text[:offset]
buf_frame.skip_tts = True
await self.push_frame(buf_frame, buf_dir)
tts_text = original_text[offset:]
if tts_text.startswith(" "):
tts_text = tts_text[1:]
if tts_text:
await self.push_frame(LLMTextFrame(tts_text), buf_dir)
self._frame_buffer = []
return
# Neither marker found yet — keep buffering (should arrive very soon)
# ------------------------------------------------------------------
# End-of-response handling
# ------------------------------------------------------------------
async def _handle_response_end(
self, frame: LLMFullResponseEndFrame, direction: FrameDirection
):
if self._mode == "recording":
recording_id = self._recording_id_buffer.strip()
if recording_id:
await self._play_recording(recording_id)
else:
logger.warning(
"RecordingRouterProcessor: recording mode but empty recording_id"
)
elif self._mode is None and self._frame_buffer:
# Graceful degradation: no marker detected, pass text to TTS as-is
logger.warning(
"RecordingRouterProcessor: no response mode marker found, "
"passing text to TTS as-is"
)
for buf_frame, buf_dir in self._frame_buffer:
await self.push_frame(buf_frame, buf_dir)
self._reset()
await self.push_frame(frame, direction)
# ------------------------------------------------------------------
# Audio playback
# ------------------------------------------------------------------
async def _play_recording(self, recording_id: str):
"""Fetch recording audio and push TTSStarted → TTSAudioRaw → TTSStopped.
The transport handles chunking automatically. The Started/Stopped
frames ensure downstream processors (transport, audio buffer, observers)
treat this as a proper TTS utterance.
"""
logger.info(f"Playing pre-recorded audio: {recording_id}")
audio_data = await self._fetch_recording_audio(recording_id)
if not audio_data:
logger.warning(
f"Failed to fetch recording {recording_id}, no audio will play"
)
return
context_id = str(uuid.uuid4())
await self.push_frame(TTSStartedFrame(context_id=context_id))
await self.push_frame(
TTSAudioRawFrame(
audio=audio_data,
sample_rate=self._audio_sample_rate,
num_channels=1,
context_id=context_id,
)
)
await self.push_frame(TTSStoppedFrame(context_id=context_id))
duration_secs = len(audio_data) / (self._audio_sample_rate * 2)
logger.debug(
f"Finished pushing recording {recording_id} "
f"({len(audio_data)} bytes, {duration_secs:.1f}s)"
)
# ------------------------------------------------------------------
# State management
# ------------------------------------------------------------------
def _buffered_text(self) -> str:
"""Return concatenated text from the frame buffer."""
return "".join(f.text for f, _ in self._frame_buffer)
def _reset(self):
"""Reset per-response state."""
self._frame_buffer = []
self._mode = None
self._recording_id_buffer = ""

View file

@ -27,6 +27,11 @@ from api.services.pipecat.realtime_feedback_observer import (
RealtimeFeedbackObserver,
register_turn_log_handlers,
)
from api.services.pipecat.recording_audio_cache import (
create_recording_audio_fetcher,
warm_recording_cache,
)
from api.services.pipecat.recording_router_processor import RecordingRouterProcessor
from api.services.pipecat.service_factory import (
create_llm_service,
create_stt_service,
@ -558,6 +563,12 @@ async def _run_pipeline(
embeddings_model = user_config.embeddings.model
embeddings_base_url = getattr(user_config.embeddings, "base_url", None)
# Check if the workflow has any active recordings so the engine can
# include recording response mode instructions in all node prompts.
has_recordings = await db_client.has_active_recordings(
workflow_id, workflow.organization_id
)
engine = PipecatEngine(
llm=llm,
workflow=workflow_graph,
@ -567,6 +578,7 @@ async def _run_pipeline(
embeddings_api_key=embeddings_api_key,
embeddings_model=embeddings_model,
embeddings_base_url=embeddings_base_url,
has_recordings=has_recordings,
)
# Create pipeline components
@ -680,6 +692,27 @@ async def _run_pipeline(
abort_immediately=True,
)
# Create recording router if workflow has active recordings
recording_router = None
if has_recordings:
fetch_audio = create_recording_audio_fetcher(
organization_id=workflow.organization_id,
pipeline_sample_rate=audio_config.pipeline_sample_rate,
)
recording_router = RecordingRouterProcessor(
audio_sample_rate=audio_config.pipeline_sample_rate,
fetch_recording_audio=fetch_audio,
)
# Warm the recording cache in the background so audio is ready
# before the first playback request.
asyncio.create_task(
warm_recording_cache(
workflow_id=workflow_id,
organization_id=workflow.organization_id,
pipeline_sample_rate=audio_config.pipeline_sample_rate,
)
)
# Build the pipeline with the STT mute filter and context controller
pipeline = build_pipeline(
transport,
@ -692,6 +725,7 @@ async def _run_pipeline(
pipeline_engine_callback_processor,
pipeline_metrics_aggregator,
voicemail_detector=voicemail_detector,
recording_router=recording_router,
)
# Create pipeline task with audio configuration

View file

@ -5,25 +5,32 @@ from loguru import logger
from api.constants import MPS_API_URL
from api.services.configuration.registry import ServiceProviders
from pipecat.services.azure.llm import AzureLLMService
from pipecat.services.azure.llm import AzureLLMService, AzureLLMSettings
from pipecat.services.cartesia.stt import CartesiaSTTService
from pipecat.services.cartesia.tts import CartesiaTTSService
from pipecat.services.deepgram.flux.stt import DeepgramFluxSTTService
from pipecat.services.deepgram.stt import DeepgramSTTService, LiveOptions
from pipecat.services.deepgram.tts import DeepgramTTSService
from pipecat.services.cartesia.tts import CartesiaTTSService, CartesiaTTSSettings
from pipecat.services.deepgram.flux.stt import (
DeepgramFluxSTTService,
DeepgramFluxSTTSettings,
)
from pipecat.services.deepgram.stt import DeepgramSTTService, DeepgramSTTSettings
from pipecat.services.deepgram.tts import DeepgramTTSService, DeepgramTTSSettings
from pipecat.services.dograh.llm import DograhLLMService
from pipecat.services.dograh.stt import DograhSTTService
from pipecat.services.dograh.tts import DograhTTSService
from pipecat.services.elevenlabs.tts import ElevenLabsTTSService
from pipecat.services.google.llm import GoogleLLMService
from pipecat.services.groq.llm import GroqLLMService
from pipecat.services.dograh.stt import DograhSTTService, DograhSTTSettings
from pipecat.services.dograh.tts import DograhTTSService, DograhTTSSettings
from pipecat.services.elevenlabs.tts import ElevenLabsTTSService, ElevenLabsTTSSettings
from pipecat.services.google.llm import GoogleLLMService, GoogleLLMSettings
from pipecat.services.groq.llm import GroqLLMService, GroqLLMSettings
from pipecat.services.openai.base_llm import OpenAILLMSettings
from pipecat.services.openai.llm import OpenAILLMService
from pipecat.services.openai.stt import OpenAISTTService
from pipecat.services.openai.tts import OpenAITTSService
from pipecat.services.openrouter.llm import OpenRouterLLMService
from pipecat.services.sarvam.stt import SarvamSTTService
from pipecat.services.sarvam.tts import SarvamTTSService
from pipecat.services.speechmatics.stt import SpeechmaticsSTTService
from pipecat.services.openai.stt import OpenAISTTService, OpenAISTTSettings
from pipecat.services.openai.tts import OpenAITTSService, OpenAITTSSettings
from pipecat.services.openrouter.llm import OpenRouterLLMService, OpenRouterLLMSettings
from pipecat.services.sarvam.stt import SarvamSTTService, SarvamSTTSettings
from pipecat.services.sarvam.tts import SarvamTTSService, SarvamTTSSettings
from pipecat.services.speechmatics.stt import (
SpeechmaticsSTTService,
SpeechmaticsSTTSettings,
)
from pipecat.transcriptions.language import Language
from pipecat.utils.text.xml_function_tag_filter import XMLFunctionTagFilter
@ -49,8 +56,8 @@ def create_stt_service(
logger.debug("Using DeepGram Flux Model")
return DeepgramFluxSTTService(
api_key=user_config.stt.api_key,
model=user_config.stt.model,
params=DeepgramFluxSTTService.InputParams(
settings=DeepgramFluxSTTSettings(
model=user_config.stt.model,
eot_timeout_ms=3000,
eot_threshold=0.7,
eager_eot_threshold=0.5,
@ -63,23 +70,23 @@ def create_stt_service(
# Other models than flux
# Use language from user config, defaulting to "multi" for multilingual support
language = getattr(user_config.stt, "language", None) or "multi"
live_options = LiveOptions(
language=language,
profanity_filter=False,
endpointing=100,
model=user_config.stt.model,
keyterm=keyterms or [],
)
logger.debug(f"Using DeepGram Model - {user_config.stt.model}")
return DeepgramSTTService(
live_options=live_options,
api_key=user_config.stt.api_key,
settings=DeepgramSTTSettings(
language=language,
profanity_filter=False,
endpointing=100,
model=user_config.stt.model,
keyterm=keyterms or [],
),
should_interrupt=False, # Let UserAggregator take care of sending InterruptionFrame
sample_rate=audio_config.transport_in_sample_rate,
)
elif user_config.stt.provider == ServiceProviders.OPENAI.value:
return OpenAISTTService(
api_key=user_config.stt.api_key, model=user_config.stt.model
api_key=user_config.stt.api_key,
settings=OpenAISTTSettings(model=user_config.stt.model),
)
elif user_config.stt.provider == ServiceProviders.CARTESIA.value:
return CartesiaSTTService(
@ -92,8 +99,10 @@ def create_stt_service(
return DograhSTTService(
base_url=base_url,
api_key=user_config.stt.api_key,
model=user_config.stt.model,
language=language,
settings=DograhSTTSettings(
model=user_config.stt.model,
language=language,
),
keyterms=keyterms,
sample_rate=audio_config.transport_in_sample_rate,
)
@ -117,8 +126,10 @@ def create_stt_service(
pipecat_language = language_mapping.get(language, Language.HI_IN)
return SarvamSTTService(
api_key=user_config.stt.api_key,
model=user_config.stt.model,
params=SarvamSTTService.InputParams(language=pipecat_language),
settings=SarvamSTTSettings(
model=user_config.stt.model,
language=pipecat_language,
),
sample_rate=audio_config.transport_in_sample_rate,
)
elif user_config.stt.provider == ServiceProviders.SPEECHMATICS.value:
@ -140,7 +151,7 @@ def create_stt_service(
additional_vocab = [AdditionalVocabEntry(content=term) for term in keyterms]
return SpeechmaticsSTTService(
api_key=user_config.stt.api_key,
params=SpeechmaticsSTTService.InputParams(
settings=SpeechmaticsSTTSettings(
language=language,
operating_point=operating_point,
additional_vocab=additional_vocab,
@ -168,14 +179,16 @@ def create_tts_service(user_config, audio_config: "AudioConfig"):
if user_config.tts.provider == ServiceProviders.DEEPGRAM.value:
return DeepgramTTSService(
api_key=user_config.tts.api_key,
voice=user_config.tts.voice,
settings=DeepgramTTSSettings(voice=user_config.tts.voice),
text_filters=[xml_function_tag_filter],
silence_time_s=1.0,
)
elif user_config.tts.provider == ServiceProviders.OPENAI.value:
return OpenAITTSService(
api_key=user_config.tts.api_key,
model=user_config.tts.model,
settings=OpenAITTSSettings(model=user_config.tts.model),
text_filters=[xml_function_tag_filter],
silence_time_s=1.0,
)
elif user_config.tts.provider == ServiceProviders.ELEVENLABS.value:
# Backward compatible with older configuration "Name - voice_id"
@ -186,19 +199,25 @@ def create_tts_service(user_config, audio_config: "AudioConfig"):
return ElevenLabsTTSService(
reconnect_on_error=False,
api_key=user_config.tts.api_key,
voice_id=voice_id,
model=user_config.tts.model,
params=ElevenLabsTTSService.InputParams(
stability=0.8, speed=user_config.tts.speed, similarity_boost=0.75
settings=ElevenLabsTTSSettings(
voice=voice_id,
model=user_config.tts.model,
stability=0.8,
speed=user_config.tts.speed,
similarity_boost=0.75,
),
text_filters=[xml_function_tag_filter],
silence_time_s=1.0,
)
elif user_config.tts.provider == ServiceProviders.CARTESIA.value:
return CartesiaTTSService(
api_key=user_config.tts.api_key,
voice_id=user_config.tts.voice,
model=user_config.tts.model,
settings=CartesiaTTSSettings(
voice=user_config.tts.voice,
model=user_config.tts.model,
),
text_filters=[xml_function_tag_filter],
silence_time_s=1.0,
)
elif user_config.tts.provider == ServiceProviders.DOGRAH.value:
# Convert HTTP URL to WebSocket URL for TTS
@ -206,10 +225,13 @@ def create_tts_service(user_config, audio_config: "AudioConfig"):
return DograhTTSService(
base_url=base_url,
api_key=user_config.tts.api_key,
model=user_config.tts.model,
voice=user_config.tts.voice,
params=DograhTTSService.InputParams(speed=user_config.tts.speed),
settings=DograhTTSSettings(
model=user_config.tts.model,
voice=user_config.tts.voice,
speed=user_config.tts.speed,
),
text_filters=[xml_function_tag_filter],
silence_time_s=1.0,
)
elif user_config.tts.provider == ServiceProviders.SARVAM.value:
# Map Sarvam language code to pipecat Language enum for TTS
@ -232,10 +254,13 @@ def create_tts_service(user_config, audio_config: "AudioConfig"):
voice = getattr(user_config.tts, "voice", None) or "anushka"
return SarvamTTSService(
api_key=user_config.tts.api_key,
model=user_config.tts.model,
voice_id=voice,
params=SarvamTTSService.InputParams(language=pipecat_language),
settings=SarvamTTSSettings(
model=user_config.tts.model,
voice=voice,
language=pipecat_language,
),
text_filters=[xml_function_tag_filter],
silence_time_s=1.0,
)
else:
raise HTTPException(
@ -253,16 +278,15 @@ def create_llm_service(user_config):
if "gpt-5" in model:
return OpenAILLMService(
api_key=user_config.llm.api_key,
model=model,
params=OpenAILLMService.InputParams(
reasoning_effort="minimal", verbosity="low"
settings=OpenAILLMSettings(
model=model,
extra={"reasoning_effort": "minimal", "verbosity": "low"},
),
)
else:
return OpenAILLMService(
api_key=user_config.llm.api_key,
model=model,
params=OpenAILLMService.InputParams(temperature=0.1),
settings=OpenAILLMSettings(model=model, temperature=0.1),
)
elif user_config.llm.provider == ServiceProviders.GROQ.value:
print(
@ -270,36 +294,30 @@ def create_llm_service(user_config):
)
return GroqLLMService(
api_key=user_config.llm.api_key,
model=model,
params=OpenAILLMService.InputParams(temperature=0.1),
settings=GroqLLMSettings(model=model, temperature=0.1),
)
elif user_config.llm.provider == ServiceProviders.OPENROUTER.value:
return OpenRouterLLMService(
api_key=user_config.llm.api_key,
model=model,
base_url=user_config.llm.base_url,
params=OpenAILLMService.InputParams(temperature=0.1),
settings=OpenRouterLLMSettings(model=model, temperature=0.1),
)
elif user_config.llm.provider == ServiceProviders.GOOGLE.value:
# Use the correct InputParams class for Google to avoid propagating OpenAI-specific
# NOT_GIVEN sentinels that break Pydantic validation in GoogleLLMService.
return GoogleLLMService(
api_key=user_config.llm.api_key,
model=model,
params=GoogleLLMService.InputParams(temperature=0.1),
settings=GoogleLLMSettings(model=model, temperature=0.1),
)
elif user_config.llm.provider == ServiceProviders.AZURE.value:
return AzureLLMService(
api_key=user_config.llm.api_key,
endpoint=user_config.llm.endpoint,
model=model, # Azure uses deployment name as model
params=AzureLLMService.InputParams(temperature=0.1),
settings=AzureLLMSettings(model=model, temperature=0.1),
)
elif user_config.llm.provider == ServiceProviders.DOGRAH.value:
return DograhLLMService(
base_url=f"{MPS_API_URL}/api/v1/llm",
api_key=user_config.llm.api_key,
model=model,
settings=OpenAILLMSettings(model=model),
)
else:
raise HTTPException(status_code=400, detail="Invalid LLM provider")

View file

@ -5,6 +5,7 @@ from api.services.workflow.disposition_mapper import (
get_organization_id_from_workflow_run,
)
from api.services.workflow.workflow import Node, WorkflowGraph
from pipecat.adapters.schemas.tools_schema import ToolsSchema
from pipecat.frames.frames import (
BotStartedSpeakingFrame,
BotStoppedSpeakingFrame,
@ -16,6 +17,7 @@ from pipecat.frames.frames import (
from pipecat.pipeline.task import PipelineTask
from pipecat.processors.aggregators.llm_context import LLMContext
from pipecat.services.llm_service import FunctionCallParams
from pipecat.services.settings import LLMSettings
from pipecat.utils.enums import EndTaskReason
if TYPE_CHECKING:
@ -31,18 +33,19 @@ import asyncio
from loguru import logger
from api.services.workflow import pipecat_engine_callbacks as engine_callbacks
from api.services.workflow.pipecat_engine_custom_tools import CustomToolManager
from api.services.workflow.pipecat_engine_utils import (
from api.services.workflow.pipecat_engine_context_composer import (
compose_functions_for_node,
compose_system_prompt_for_node,
)
from api.services.workflow.pipecat_engine_custom_tools import (
CustomToolManager,
get_function_schema,
render_template,
update_llm_context,
)
from api.services.workflow.pipecat_engine_variable_extractor import (
VariableExtractionManager,
)
from api.services.workflow.tools.calculator import get_calculator_tools, safe_calculator
from api.services.workflow.tools.knowledge_base import (
get_knowledge_base_tool,
retrieve_from_knowledge_base,
)
from api.services.workflow.tools.timezone import (
@ -50,6 +53,7 @@ from api.services.workflow.tools.timezone import (
get_current_time,
get_time_tools,
)
from api.utils.template_renderer import render_template
class PipecatEngine:
@ -68,6 +72,7 @@ class PipecatEngine:
embeddings_api_key: Optional[str] = None,
embeddings_model: Optional[str] = None,
embeddings_base_url: Optional[str] = None,
has_recordings: bool = False,
):
self.task = task
self.llm = llm
@ -113,6 +118,10 @@ class PipecatEngine:
# Audio configuration (set via set_audio_config from _run_pipeline)
self._audio_config = None
# True when the workflow has active recordings; enables recording
# response mode instructions on all nodes for in-context learning.
self._has_recordings: bool = has_recordings
async def _get_organization_id(self) -> Optional[int]:
"""Get and cache the organization ID from workflow run."""
if self._custom_tool_manager:
@ -194,15 +203,14 @@ class PipecatEngine:
logger.error(f"Error initializing {self.__class__.__name__}: {e}")
raise
def _get_function_schema(self, function_name: str, description: str):
"""Thin wrapper around utils.get_function_schema for backwards compatibility."""
async def _update_llm_context(self, system_prompt: str, functions: list[dict]):
"""Update LLM settings with the composed system prompt and tool list."""
return get_function_schema(function_name, description)
await self.llm._update_settings(LLMSettings(system_instruction=system_prompt))
async def _update_llm_context(self, system_message: dict, functions: list[dict]):
"""Delegate context update to the shared workflow.utils implementation."""
update_llm_context(self.context, system_message, functions)
if functions:
tools_schema = ToolsSchema(standard_tools=functions)
self.context.set_tools(tools_schema)
def _format_prompt(self, prompt: str) -> str:
"""Delegate prompt formatting to the shared workflow.utils implementation."""
@ -473,12 +481,19 @@ class PipecatEngine:
if node.document_uuids:
await self._register_knowledge_base_function(node.document_uuids)
# Set up system message and functions
(
system_message,
functions,
) = await self._compose_system_message_functions_for_node(node)
await self._update_llm_context(system_message, functions)
# Compose prompt and functions via the context composer module
system_prompt = compose_system_prompt_for_node(
node=node,
workflow=self.workflow,
format_prompt=self._format_prompt,
has_recordings=self._has_recordings,
)
functions = await compose_functions_for_node(
node=node,
builtin_function_schemas=self.builtin_function_schemas,
custom_tool_manager=self._custom_tool_manager,
)
await self._update_llm_context(system_prompt, functions)
async def set_node(self, node_id: str):
"""
@ -610,62 +625,6 @@ class PipecatEngine:
)
await self.task.queue_frame(frame_to_push)
async def _compose_system_message_functions_for_node(
self, node: "Node"
) -> tuple[list[dict], list[dict]]:
"""Generate the system messages and function schemas for the given node.
This performs the same formatting logic used when entering a node but
does **not** register the functions with the LLM; callers are
responsible for that.
"""
global_prompt = ""
if self.workflow.global_node_id and node.add_global_prompt:
global_node = self.workflow.nodes[self.workflow.global_node_id]
global_prompt = self._format_prompt(global_node.prompt)
functions: list[dict] = []
# Add built-in function schemas (calculator and timezone tools)
functions.extend(self.builtin_function_schemas)
# Add knowledge base retrieval tool if node has documents
if node.document_uuids:
kb_tool_def = get_knowledge_base_tool(node.document_uuids)
kb_schema = get_function_schema(
kb_tool_def["function"]["name"],
kb_tool_def["function"]["description"],
properties=kb_tool_def["function"]["parameters"].get("properties", {}),
required=kb_tool_def["function"]["parameters"].get("required", []),
)
functions.append(kb_schema)
# Add custom tools from node.tool_uuids
if node.tool_uuids and self._custom_tool_manager:
custom_tool_schemas = await self._custom_tool_manager.get_tool_schemas(
node.tool_uuids
)
functions.extend(custom_tool_schemas)
# Transition functions (schema only; registration handled elsewhere)
for outgoing_edge in node.out_edges:
function_schema = self._get_function_schema(
outgoing_edge.get_function_name(), outgoing_edge.condition
)
functions.append(function_schema)
formatted_node_prompt = self._format_prompt(node.prompt)
system_message = {
"role": "system",
"content": "\n\n".join(
p for p in (global_prompt, formatted_node_prompt) if p
),
}
return system_message, functions
async def should_mute_user(self, frame: "Frame") -> bool:
"""
Callback for CallbackUserMuteStrategy to determine if the user should be muted.

View file

@ -0,0 +1,138 @@
"""System prompt and function schema composition for PipecatEngine nodes.
Extracts prompt and function composition logic from PipecatEngine into
reusable functions. Defines recording response mode markers and instructions.
"""
from typing import TYPE_CHECKING, Callable, Optional
if TYPE_CHECKING:
from api.services.workflow.pipecat_engine_custom_tools import CustomToolManager
from api.services.workflow.workflow import Node, WorkflowGraph
from api.services.workflow.pipecat_engine_custom_tools import get_function_schema
from api.services.workflow.tools.knowledge_base import get_knowledge_base_tool
# ---------------------------------------------------------------------------
# Recording response mode markers
# ---------------------------------------------------------------------------
RECORDING_MARKER = "" # Play pre-recorded audio
TTS_MARKER = "" # Generate dynamic TTS text
# ---------------------------------------------------------------------------
# Recording response mode system prompt instructions
# ---------------------------------------------------------------------------
RECORDING_RESPONSE_MODE_INSTRUCTIONS = """\
RESPONSE MODE INSTRUCTIONS - MANDATORY FORMAT:
Every response you generate MUST begin with a response mode indicator.
You have two modes for responding:
1. DYNAMIC SPEECH (): Generate text that will be converted to speech by TTS.
Format: `` followed by a space and your full spoken response.
Example: Hello! How can I help you today?
2. PRE-RECORDED AUDIO (): Play a pre-recorded audio message.
Format: `` followed by a space and ONLY the recording_id. Nothing else.
Example: rec_greeting_01
RULES:
- Your response MUST start with either `` or `` as the very first character.
- For `` (dynamic speech): Follow with a space and your full response text.
- For `` (pre-recorded audio): Follow with a space and ONLY the recording_id. No other text.
- Use `` when a pre-recorded message matches the situation well.
- Use `` when you need to generate a dynamic, contextual response.
- NEVER mix modes in a single response. Choose one."""
def compose_system_prompt_for_node(
*,
node: "Node",
workflow: "WorkflowGraph",
format_prompt: Callable[[str], str],
has_recordings: bool,
) -> str:
"""Compose the full system prompt text for a workflow node.
Combines the global prompt, node-specific prompt, and (when recordings
are enabled anywhere in the workflow) the recording response mode
instructions into a single string.
Args:
node: The workflow node to compose the prompt for.
workflow: The full workflow graph (needed for global node prompt).
format_prompt: Callable to render template variables in prompts.
has_recordings: Whether any node in the workflow uses recordings.
Returns:
The composed system prompt text.
"""
global_prompt = ""
if workflow.global_node_id and node.add_global_prompt:
global_node = workflow.nodes[workflow.global_node_id]
global_prompt = format_prompt(global_node.prompt)
formatted_node_prompt = format_prompt(node.prompt)
parts = [p for p in (global_prompt, formatted_node_prompt) if p]
if has_recordings:
parts.append(RECORDING_RESPONSE_MODE_INSTRUCTIONS)
# TODO: Append per-node available recordings list here once
# Node.recording_ids is populated. The list should include
# recording_id and a short description so the LLM can choose.
return "\n\n".join(parts)
async def compose_functions_for_node(
*,
node: "Node",
builtin_function_schemas: list[dict],
custom_tool_manager: Optional["CustomToolManager"],
) -> list[dict]:
"""Compose the function/tool schemas for a workflow node.
Gathers built-in tools, knowledge-base tools, custom tools,
and transition function schemas into a single list.
Args:
node: The workflow node to compose functions for.
builtin_function_schemas: Pre-computed schemas for built-in tools.
custom_tool_manager: Manager for user-defined custom tools (may be None).
Returns:
A list of function schemas to register with the LLM.
"""
functions: list[dict] = []
# Built-in tools (calculator, timezone)
functions.extend(builtin_function_schemas)
# Knowledge base retrieval tool
if node.document_uuids:
kb_tool_def = get_knowledge_base_tool(node.document_uuids)
kb_schema = get_function_schema(
kb_tool_def["function"]["name"],
kb_tool_def["function"]["description"],
properties=kb_tool_def["function"]["parameters"].get("properties", {}),
required=kb_tool_def["function"]["parameters"].get("required", []),
)
functions.append(kb_schema)
# Custom tools
if node.tool_uuids and custom_tool_manager:
custom_tool_schemas = await custom_tool_manager.get_tool_schemas(
node.tool_uuids
)
functions.extend(custom_tool_schemas)
# Transition function schemas
for outgoing_edge in node.out_edges:
function_schema = get_function_schema(
outgoing_edge.get_function_name(), outgoing_edge.condition
)
functions.append(function_schema)
return functions

View file

@ -10,7 +10,7 @@ import asyncio
import re
import time
import uuid
from typing import TYPE_CHECKING, Any, Optional
from typing import TYPE_CHECKING, Any, Dict, List, Optional
from loguru import logger
@ -23,7 +23,6 @@ from api.services.telephony.transfer_event_protocol import TransferContext
from api.services.workflow.disposition_mapper import (
get_organization_id_from_workflow_run,
)
from api.services.workflow.pipecat_engine_utils import get_function_schema
from api.services.workflow.tools.custom_tool import (
execute_http_tool,
tool_to_function_schema,
@ -42,6 +41,29 @@ if TYPE_CHECKING:
from api.services.workflow.pipecat_engine import PipecatEngine
def get_function_schema(
function_name: str,
description: str,
*,
properties: Dict[str, Any] | None = None,
required: List[str] | None = None,
) -> FunctionSchema:
"""Create a FunctionSchema definition that can later be transformed into
the provider-specific format (OpenAI, Gemini, etc.).
The helper keeps the public signature backward-compatible callers that
only pass ``function_name`` and ``description`` continue to work and will
define a parameter-less function.
"""
return FunctionSchema(
name=function_name,
description=description,
properties=properties or {},
required=required or [],
)
class CustomToolManager:
"""Manager for custom tool registration and execution.

View file

@ -1,68 +0,0 @@
from __future__ import annotations
from typing import Any, Dict, List
from api.utils.template_renderer import render_template
from pipecat.adapters.schemas.function_schema import FunctionSchema
from pipecat.adapters.schemas.tools_schema import ToolsSchema
from pipecat.processors.aggregators.llm_context import LLMContext
__all__ = [
"get_function_schema",
"update_llm_context",
"render_template",
]
def get_function_schema(
function_name: str,
description: str,
*,
properties: Dict[str, Any] | None = None,
required: List[str] | None = None,
) -> FunctionSchema:
"""Create a FunctionSchema definition that can later be transformed into
the provider-specific format (OpenAI, Gemini, etc.).
The helper keeps the public signature backward-compatible callers that
only pass ``function_name`` and ``description`` continue to work and will
define a parameter-less function.
"""
return FunctionSchema(
name=function_name,
description=description,
properties=properties or {},
required=required or [],
)
def update_llm_context(
context: LLMContext,
system_message: Dict[str, Any],
functions: List[FunctionSchema],
) -> None:
"""Update *context* with an up-to-date system message and tool list.
This helper removes any previous system messages before inserting the new
*system_message* at the top of the conversation history and then instructs
the LLM which *functions* (a.k.a. tools) are currently available.
"""
# Wrap the provided function schemas in a ToolsSchema so that the adapter
# associated with the current LLM service can convert them to the correct
# provider-specific representation when required.
tools_schema = ToolsSchema(standard_tools=functions)
previous_interactions = context.messages
# Replace the first message if it's a system message, otherwise prepend.
# Keep any system messages that appear in the middle of the conversation.
if previous_interactions and previous_interactions[0]["role"] == "system":
messages = [system_message] + previous_interactions[1:]
else:
messages = [system_message] + previous_interactions
context.set_messages(messages)
if functions:
context.set_tools(tools_schema)