feat: allow uploading recording as part of node transition

This commit is contained in:
Abhishek Kumar 2026-04-10 11:54:00 +05:30
parent bb5f56bfb7
commit 65c76ca7ff
36 changed files with 2255 additions and 201 deletions

View file

@ -0,0 +1,70 @@
"""unique recording id per org and workflow
Revision ID: 67a5cf3e09d0
Revises: e7254d2c6c18
Create Date: 2026-04-09 17:03:38.302041
"""
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = "67a5cf3e09d0"
down_revision: Union[str, None] = "e7254d2c6c18"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# Widen column from 16 to 64 chars for descriptive names
op.alter_column(
"workflow_recordings",
"recording_id",
existing_type=sa.VARCHAR(length=16),
type_=sa.String(length=64),
existing_nullable=False,
)
# Drop the old globally-unique index
op.drop_index(
op.f("ix_workflow_recordings_recording_id"), table_name="workflow_recordings"
)
# Re-create as non-unique index for lookups
op.create_index(
"ix_workflow_recordings_recording_id",
"workflow_recordings",
["recording_id"],
unique=False,
)
# Add composite unique constraint (recording_id, organization_id, workflow_id)
op.create_unique_constraint(
"uq_workflow_recordings_recording_id_org_wf",
"workflow_recordings",
["recording_id", "organization_id", "workflow_id"],
)
def downgrade() -> None:
op.drop_constraint(
"uq_workflow_recordings_recording_id_org_wf",
"workflow_recordings",
type_="unique",
)
op.drop_index(
"ix_workflow_recordings_recording_id", table_name="workflow_recordings"
)
op.create_index(
op.f("ix_workflow_recordings_recording_id"),
"workflow_recordings",
["recording_id"],
unique=True,
)
op.alter_column(
"workflow_recordings",
"recording_id",
existing_type=sa.String(length=64),
type_=sa.VARCHAR(length=16),
existing_nullable=False,
)

View file

@ -1015,8 +1015,8 @@ class WorkflowRecordingModel(Base):
id = Column(Integer, primary_key=True, index=True)
# Short globally unique ID (e.g. "xbhfha3k") used in prompts
recording_id = Column(String(16), unique=True, nullable=False, index=True)
# Descriptive ID used in prompts (unique per organization)
recording_id = Column(String(64), nullable=False, index=True)
# Scoping
workflow_id = Column(
@ -1062,6 +1062,12 @@ class WorkflowRecordingModel(Base):
# Indexes
__table_args__ = (
UniqueConstraint(
"recording_id",
"organization_id",
"workflow_id",
name="uq_workflow_recordings_recording_id_org_wf",
),
Index("ix_workflow_recordings_workflow_id", "workflow_id"),
Index("ix_workflow_recordings_org_id", "organization_id"),
Index("ix_workflow_recordings_recording_id", "recording_id"),

View file

@ -77,19 +77,19 @@ class WorkflowRecordingClient(BaseDBClient):
)
return recording
async def get_recordings_for_workflow(
async def get_recordings(
self,
workflow_id: int,
organization_id: int,
workflow_id: Optional[int] = None,
tts_provider: Optional[str] = None,
tts_model: Optional[str] = None,
tts_voice_id: Optional[str] = None,
) -> List[WorkflowRecordingModel]:
"""Get recordings for a workflow, optionally filtered by TTS config.
"""Get recordings for an organization, optionally filtered by workflow and TTS config.
Args:
workflow_id: ID of the workflow
organization_id: ID of the organization
workflow_id: Optional workflow ID filter
tts_provider: Optional TTS provider filter
tts_model: Optional TTS model filter
tts_voice_id: Optional TTS voice ID filter
@ -99,11 +99,12 @@ class WorkflowRecordingClient(BaseDBClient):
"""
async with self.async_session() as session:
query = select(WorkflowRecordingModel).where(
WorkflowRecordingModel.workflow_id == workflow_id,
WorkflowRecordingModel.organization_id == organization_id,
WorkflowRecordingModel.is_active == True,
)
if workflow_id is not None:
query = query.where(WorkflowRecordingModel.workflow_id == workflow_id)
if tts_provider:
query = query.where(WorkflowRecordingModel.tts_provider == tts_provider)
if tts_model:
@ -120,12 +121,14 @@ class WorkflowRecordingClient(BaseDBClient):
self,
recording_id: str,
organization_id: int,
workflow_id: int,
) -> Optional[WorkflowRecordingModel]:
"""Get a recording by its short ID.
"""Get a recording by its string recording_id (unique per org + workflow).
Args:
recording_id: The short unique recording ID
recording_id: The descriptive recording ID
organization_id: ID of the organization
workflow_id: ID of the workflow
Returns:
WorkflowRecordingModel if found, None otherwise
@ -134,6 +137,31 @@ class WorkflowRecordingClient(BaseDBClient):
query = select(WorkflowRecordingModel).where(
WorkflowRecordingModel.recording_id == recording_id,
WorkflowRecordingModel.organization_id == organization_id,
WorkflowRecordingModel.workflow_id == workflow_id,
WorkflowRecordingModel.is_active == True,
)
result = await session.execute(query)
return result.scalar_one_or_none()
async def get_recording_by_id(
self,
id: int,
organization_id: int,
) -> Optional[WorkflowRecordingModel]:
"""Get a recording by its integer primary key.
Args:
id: The primary key ID
organization_id: ID of the organization
Returns:
WorkflowRecordingModel if found, None otherwise
"""
async with self.async_session() as session:
query = select(WorkflowRecordingModel).where(
WorkflowRecordingModel.id == id,
WorkflowRecordingModel.organization_id == organization_id,
WorkflowRecordingModel.is_active == True,
)
@ -167,11 +195,15 @@ class WorkflowRecordingClient(BaseDBClient):
result = await session.execute(query)
return result.scalar_one() > 0
async def check_recording_id_exists(self, recording_id: str) -> bool:
"""Check if a recording ID already exists globally.
async def check_recording_id_exists(
self, recording_id: str, organization_id: int, workflow_id: int
) -> bool:
"""Check if a recording ID already exists within an organization and workflow.
Args:
recording_id: The short recording ID to check
recording_id: The recording ID to check
organization_id: ID of the organization
workflow_id: ID of the workflow
Returns:
True if exists, False otherwise
@ -179,10 +211,52 @@ class WorkflowRecordingClient(BaseDBClient):
async with self.async_session() as session:
query = select(WorkflowRecordingModel.id).where(
WorkflowRecordingModel.recording_id == recording_id,
WorkflowRecordingModel.organization_id == organization_id,
WorkflowRecordingModel.workflow_id == workflow_id,
WorkflowRecordingModel.is_active == True,
)
result = await session.execute(query)
return result.scalar_one_or_none() is not None
async def update_recording_id(
self,
id: int,
new_recording_id: str,
organization_id: int,
) -> Optional[WorkflowRecordingModel]:
"""Update the recording_id of a recording.
Args:
id: Primary key ID of the recording
new_recording_id: New recording ID
organization_id: ID of the organization
Returns:
Updated WorkflowRecordingModel if found, None otherwise
"""
async with self.async_session() as session:
query = select(WorkflowRecordingModel).where(
WorkflowRecordingModel.id == id,
WorkflowRecordingModel.organization_id == organization_id,
WorkflowRecordingModel.is_active == True,
)
result = await session.execute(query)
recording = result.scalar_one_or_none()
if not recording:
return None
old_id = recording.recording_id
recording.recording_id = new_recording_id
await session.commit()
await session.refresh(recording)
logger.info(
f"Updated recording ID {old_id} -> {new_recording_id}, "
f"org {organization_id}"
)
return recording
async def delete_recording(
self,
recording_id: str,

View file

@ -178,6 +178,11 @@ async def initiate_call(
workflow_run_id = request.workflow_run_id
if not workflow_run_id:
# Fetch workflow to merge template context variables (e.g. caller_number,
# called_number set in workflow settings for testing pre-call data fetch)
workflow = await db_client.get_workflow_by_id(request.workflow_id)
template_vars = (workflow.template_context_variables or {}) if workflow else {}
numeric_suffix = int(str(uuid.uuid4()).replace("-", "")[:8], 16) % 100000000
workflow_run_name = f"WR-TEL-OUT-{numeric_suffix:08d}"
workflow_run = await db_client.create_workflow_run(
@ -187,6 +192,7 @@ async def initiate_call(
user_id=user.id,
call_type=CallType.OUTBOUND,
initial_context={
**template_vars,
"phone_number": phone_number,
"called_number": phone_number,
"provider": provider.PROVIDER_NAME,

View file

@ -16,6 +16,7 @@ from api.schemas.workflow_recording import (
BatchRecordingUploadResponseSchema,
RecordingListResponseSchema,
RecordingResponseSchema,
RecordingUpdateRequestSchema,
RecordingUploadResponseSchema,
)
from api.services.auth.depends import get_user
@ -25,11 +26,13 @@ from api.services.storage import storage_fs
router = APIRouter(prefix="/workflow-recordings", tags=["workflow-recordings"])
async def _generate_unique_recording_id() -> str:
"""Generate a globally unique short recording ID."""
async def _generate_unique_recording_id(organization_id: int, workflow_id: int) -> str:
"""Generate a unique short recording ID within an organization and workflow."""
for _ in range(10):
rid = generate_short_id(8)
exists = await db_client.check_recording_id_exists(rid)
exists = await db_client.check_recording_id_exists(
rid, organization_id, workflow_id
)
if not exists:
return rid
raise HTTPException(
@ -69,7 +72,9 @@ async def get_upload_urls(
try:
items = []
for fd in request.files:
recording_id = await _generate_unique_recording_id()
recording_id = await _generate_unique_recording_id(
user.selected_organization_id, request.workflow_id
)
storage_key = (
f"recordings/{user.selected_organization_id}"
@ -163,10 +168,12 @@ async def create_recordings(
@router.get(
"/",
response_model=RecordingListResponseSchema,
summary="List recordings for a workflow",
summary="List recordings",
)
async def list_recordings(
workflow_id: Annotated[int, Query(description="Workflow ID")],
workflow_id: Annotated[
Optional[int], Query(description="Filter by workflow ID")
] = None,
tts_provider: Annotated[
Optional[str], Query(description="Filter by TTS provider")
] = None,
@ -178,11 +185,11 @@ async def list_recordings(
] = None,
user=Depends(get_user),
):
"""List recordings for a workflow, optionally filtered by TTS configuration."""
"""List recordings for the organization, optionally filtered by workflow and TTS configuration."""
try:
recordings = await db_client.get_recordings_for_workflow(
workflow_id=workflow_id,
recordings = await db_client.get_recordings(
organization_id=user.selected_organization_id,
workflow_id=workflow_id,
tts_provider=tts_provider,
tts_model=tts_model,
tts_voice_id=tts_voice_id,
@ -233,6 +240,62 @@ async def delete_recording(
) from exc
@router.patch(
"/{id}",
response_model=RecordingResponseSchema,
summary="Update a recording's Recording ID",
)
async def update_recording(
id: int,
request: RecordingUpdateRequestSchema,
user=Depends(get_user),
):
"""Update the recording_id (descriptive name) of a recording."""
try:
new_id = request.recording_id.strip()
if not new_id:
raise HTTPException(status_code=400, detail="Recording ID cannot be empty")
# Look up by integer PK — globally unique, no ambiguity
existing = await db_client.get_recording_by_id(
id, user.selected_organization_id
)
if not existing:
raise HTTPException(status_code=404, detail="Recording not found")
if new_id == existing.recording_id:
return _build_response(existing)
# Check if the new ID is already taken within this org + workflow
exists = await db_client.check_recording_id_exists(
new_id, user.selected_organization_id, existing.workflow_id
)
if exists:
raise HTTPException(
status_code=409,
detail=f"Recording ID '{new_id}' is already in use in this workflow",
)
recording = await db_client.update_recording_id(
id=id,
new_recording_id=new_id,
organization_id=user.selected_organization_id,
)
if not recording:
raise HTTPException(status_code=404, detail="Recording not found")
return _build_response(recording)
except HTTPException:
raise
except Exception as exc:
logger.error(f"Error updating recording: {exc}")
raise HTTPException(
status_code=500, detail="Failed to update recording"
) from exc
@router.post(
"/transcribe",
summary="Transcribe an audio file",

View file

@ -98,6 +98,17 @@ class BatchRecordingCreateResponseSchema(BaseModel):
)
class RecordingUpdateRequestSchema(BaseModel):
"""Request schema for updating a recording's ID."""
recording_id: str = Field(
...,
min_length=1,
max_length=64,
description="New descriptive recording ID",
)
class RecordingListResponseSchema(BaseModel):
"""Response schema for list of recordings."""

View file

@ -200,7 +200,6 @@ class CampaignCallDispatcher:
# Merge context variables (queued_run context already includes retry info if applicable)
initial_context = {
**workflow.template_context_variables,
**queued_run.context_variables,
"campaign_id": campaign.id,
"provider": provider.PROVIDER_NAME,

View file

@ -11,12 +11,17 @@ from api.services.pipecat.in_memory_buffers import (
InMemoryLogsBuffer,
)
from api.services.pipecat.pipeline_metrics_aggregator import PipelineMetricsAggregator
from api.services.pipecat.recording_playback import queue_recording_audio
from api.services.pipecat.tracing_config import get_trace_url
from api.services.workflow.pipecat_engine import PipecatEngine
from api.tasks.arq import enqueue_job
from api.tasks.function_names import FunctionNames
from api.utils.hold_audio import play_hold_audio_loop
from pipecat.frames.frames import Frame, LLMContextFrame, TTSSpeakFrame
from pipecat.frames.frames import (
Frame,
LLMContextFrame,
TTSSpeakFrame,
)
from pipecat.pipeline.task import PipelineTask
from pipecat.processors.audio.audio_buffer_processor import AudioBufferProcessor
from pipecat.utils.enums import EndTaskReason
@ -32,6 +37,7 @@ def register_event_handlers(
pipeline_metrics_aggregator: PipelineMetricsAggregator,
audio_config=AudioConfig,
pre_call_fetch_task: asyncio.Task | None = None,
fetch_recording_audio=None,
):
"""Register all event handlers for transport and task events.
@ -112,12 +118,31 @@ def register_event_handlers(
# so that render_template() has the complete _call_context_vars.
await engine.set_node(engine.workflow.start_node_id)
greeting = engine.get_start_greeting()
if greeting:
logger.debug(
"Both pipeline_started and client_connected received - playing greeting via TTS"
)
await task.queue_frame(TTSSpeakFrame(greeting))
greeting_info = engine.get_start_greeting()
if greeting_info:
greeting_type, greeting_value = greeting_info
if (
greeting_type == "audio"
and greeting_value
and fetch_recording_audio
):
logger.debug(f"Playing audio greeting recording: {greeting_value}")
audio_data = await fetch_recording_audio(greeting_value)
if audio_data:
await queue_recording_audio(
audio_data,
sample_rate=audio_config.pipeline_sample_rate or 16000,
queue_frame=task.queue_frame,
)
else:
logger.warning(
f"Failed to fetch audio greeting {greeting_value}, "
"falling back to LLM generation"
)
await engine.llm.queue_frame(LLMContextFrame(engine.context))
else:
logger.debug("Playing text greeting via TTS")
await task.queue_frame(TTSSpeakFrame(greeting_value))
else:
logger.debug(
"Both pipeline_started and client_connected received - triggering initial LLM generation"

View file

@ -27,9 +27,13 @@ from .audio_file_cache import (
# ---------------------------------------------------------------------------
def _cache_path(recording_id: str, sample_rate: int) -> str:
def _cache_path(
organization_id: int, workflow_id: int, recording_id: str, sample_rate: int
) -> str:
"""Return the on-disk path for a cached PCM file."""
return os.path.join(CACHE_DIR, f"{recording_id}_{sample_rate}.pcm")
return os.path.join(
CACHE_DIR, f"{organization_id}_{workflow_id}_{recording_id}_{sample_rate}.pcm"
)
# ---------------------------------------------------------------------------
@ -39,18 +43,20 @@ def _cache_path(recording_id: str, sample_rate: int) -> str:
def create_recording_audio_fetcher(
organization_id: int,
workflow_id: int,
pipeline_sample_rate: int,
) -> Callable[[str], Awaitable[Optional[bytes]]]:
"""Create an async callback that returns raw PCM bytes for a recording_id.
The returned callable:
1. Checks the filesystem cache (keyed by ``recording_id`` + sample rate).
1. Checks the filesystem cache (keyed by org/workflow/recording + sample rate).
2. On miss, looks up the recording in the DB, downloads the audio file
from S3/MinIO, converts it to 16-bit mono PCM at *pipeline_sample_rate*,
trims leading/trailing silence, caches the result on disk, and returns it.
Args:
organization_id: Organization owning the recordings.
workflow_id: Workflow the recordings belong to.
pipeline_sample_rate: Target PCM sample rate for the pipeline.
Returns:
@ -68,7 +74,9 @@ def create_recording_audio_fetcher(
return _storage_cache[backend]
async def fetch(recording_id: str) -> Optional[bytes]:
cached = _cache_path(recording_id, pipeline_sample_rate)
cached = _cache_path(
organization_id, workflow_id, recording_id, pipeline_sample_rate
)
# 1. Serve from filesystem cache
if os.path.exists(cached):
@ -77,7 +85,7 @@ def create_recording_audio_fetcher(
# 2. DB lookup
recording = await db_client.get_recording_by_recording_id(
recording_id, organization_id
recording_id, organization_id, workflow_id
)
if not recording:
logger.warning(f"Recording {recording_id} not found in database")
@ -112,8 +120,8 @@ async def warm_recording_cache(
from api.services.storage import get_storage_for_backend
try:
recordings = await db_client.get_recordings_for_workflow(
workflow_id, organization_id
recordings = await db_client.get_recordings(
organization_id=organization_id, workflow_id=workflow_id
)
if not recordings:
return
@ -122,7 +130,11 @@ async def warm_recording_cache(
uncached = [
r
for r in recordings
if not os.path.exists(_cache_path(r.recording_id, pipeline_sample_rate))
if not os.path.exists(
_cache_path(
organization_id, workflow_id, r.recording_id, pipeline_sample_rate
)
)
]
if not uncached:
logger.debug(f"Recording cache already warm for workflow {workflow_id}")
@ -187,7 +199,12 @@ async def _download_and_convert(
pcm_data = _trim_silence(pcm_data, sample_rate)
# Write to disk cache
cached = _cache_path(recording.recording_id, sample_rate)
cached = _cache_path(
recording.organization_id,
recording.workflow_id,
recording.recording_id,
sample_rate,
)
write_cache_file(cached, pcm_data)
return pcm_data

View file

@ -0,0 +1,41 @@
"""Shared helper for pushing pre-recorded audio frames into a pipeline."""
import uuid
from typing import Awaitable, Callable
from pipecat.frames.frames import (
Frame,
TTSAudioRawFrame,
TTSStartedFrame,
TTSStoppedFrame,
)
async def queue_recording_audio(
audio_data: bytes,
*,
sample_rate: int,
queue_frame: Callable[[Frame], Awaitable[None]],
) -> None:
"""Push TTSStarted → TTSAudioRaw → TTSStopped frames.
This is the canonical way to play pre-recorded PCM audio through the
pipeline outside of the RecordingRouterProcessor (which uses its own
``push_frame`` path).
Args:
audio_data: Raw 16-bit mono PCM bytes.
sample_rate: Pipeline sample rate (e.g. 16000).
queue_frame: Typically ``task.queue_frame``.
"""
context_id = str(uuid.uuid4())
await queue_frame(TTSStartedFrame(context_id=context_id))
await queue_frame(
TTSAudioRawFrame(
audio=audio_data,
sample_rate=sample_rate,
num_channels=1,
context_id=context_id,
)
)
await queue_frame(TTSStoppedFrame(context_id=context_id))

View file

@ -828,6 +828,15 @@ async def _run_pipeline(
voicemail_detector = None
recording_router = None
# Create recording audio fetcher (used by recording router, audio greetings,
# and audio transition speech)
fetch_audio = create_recording_audio_fetcher(
organization_id=workflow.organization_id,
workflow_id=workflow_id,
pipeline_sample_rate=audio_config.pipeline_sample_rate,
)
engine.set_fetch_recording_audio(fetch_audio)
if not is_realtime:
# Create voicemail detector if enabled in workflow configurations
voicemail_config = (workflow.workflow_configurations or {}).get(
@ -868,10 +877,6 @@ async def _run_pipeline(
# Create recording router if workflow has active recordings
if has_recordings:
fetch_audio = create_recording_audio_fetcher(
organization_id=workflow.organization_id,
pipeline_sample_rate=audio_config.pipeline_sample_rate,
)
recording_router = RecordingRouterProcessor(
audio_sample_rate=audio_config.pipeline_sample_rate,
fetch_recording_audio=fetch_audio,
@ -973,6 +978,7 @@ async def _run_pipeline(
pipeline_metrics_aggregator=pipeline_metrics_aggregator,
audio_config=audio_config,
pre_call_fetch_task=pre_call_fetch_task,
fetch_recording_audio=fetch_audio,
)
register_audio_data_handler(audio_buffer, workflow_run_id, in_memory_audio_buffer)

View file

@ -54,6 +54,8 @@ class NodeDataDTO(BaseModel):
extraction_variables: Optional[list[ExtractionVariableDTO]] = None
add_global_prompt: bool = True
greeting: Optional[str] = None
greeting_type: Optional[str] = None # 'text' or 'audio'
greeting_recording_id: Optional[str] = None
wait_for_user_response: bool = False
wait_for_user_response_timeout: Optional[float] = None
detect_voicemail: bool = False
@ -102,6 +104,8 @@ class EdgeDataDTO(BaseModel):
label: str = Field(..., min_length=1)
condition: str = Field(..., min_length=1)
transition_speech: Optional[str] = None
transition_speech_type: Optional[str] = None # 'text' or 'audio'
transition_speech_recording_id: Optional[str] = None
class RFEdgeDTO(BaseModel):

View file

@ -1,14 +1,12 @@
"""Service for duplicating workflows including recordings."""
import copy
import json
import posixpath
import uuid
from loguru import logger
from api.db import db_client
from api.db.workflow_recording_client import generate_short_id
from api.enums import StorageBackend
from api.services.storage import get_storage_for_backend, storage_fs
@ -41,16 +39,6 @@ def _regenerate_trigger_uuids(workflow_definition: dict) -> dict:
return updated_definition
async def _generate_unique_recording_id() -> str:
"""Generate a globally unique short recording ID."""
for _ in range(10):
rid = generate_short_id(8)
exists = await db_client.check_recording_id_exists(rid)
if not exists:
return rid
raise RuntimeError("Failed to generate unique recording ID")
async def duplicate_workflow(
workflow_id: int,
organization_id: int,
@ -130,29 +118,15 @@ async def duplicate_workflow(
organization_id=organization_id,
)
# 6. Copy recordings with new IDs and storage paths scoped to new workflow
recording_id_map = await _duplicate_recordings(
# 6. Copy recordings (recording_ids are preserved since they're scoped per workflow)
await _duplicate_recordings(
source_workflow_id=workflow_id,
new_workflow_id=new_workflow.id,
organization_id=organization_id,
user_id=user_id,
)
# 7. Replace old recording IDs with new ones in the workflow definition
if recording_id_map:
workflow_definition = _replace_recording_ids(
workflow_definition, recording_id_map
)
new_workflow = await db_client.update_workflow(
workflow_id=new_workflow.id,
name=None,
workflow_definition=workflow_definition,
template_context_variables=None,
workflow_configurations=None,
organization_id=organization_id,
)
# 8. Sync triggers for the new workflow
# 7. Sync triggers for the new workflow
if workflow_definition:
trigger_paths = _extract_trigger_paths(workflow_definition)
if trigger_paths:
@ -170,34 +144,28 @@ async def _duplicate_recordings(
new_workflow_id: int,
organization_id: int,
user_id: int,
) -> dict[str, str]:
) -> None:
"""Duplicate all recordings for a workflow.
Copies each recording file to a new storage path scoped under the new
workflow ID, and creates new DB records pointing to the copied files.
Returns:
Mapping of old_recording_id -> new_recording_id
workflow ID. Recording IDs are preserved since they are unique per
(org, workflow).
"""
recordings = await db_client.get_recordings_for_workflow(
recordings = await db_client.get_recordings(
workflow_id=source_workflow_id,
organization_id=organization_id,
)
if not recordings:
return {}
recording_id_map: dict[str, str] = {}
return
for rec in recordings:
try:
new_recording_id = await _generate_unique_recording_id()
# Build new storage key: recordings/{org_id}/{new_workflow_id}/{new_recording_id}/{filename}
# Build new storage key: recordings/{org_id}/{new_workflow_id}/{recording_id}/{filename}
filename = posixpath.basename(rec.storage_key)
new_storage_key = (
f"recordings/{organization_id}"
f"/{new_workflow_id}/{new_recording_id}"
f"/{new_workflow_id}/{rec.recording_id}"
f"/{filename}"
)
@ -211,7 +179,7 @@ async def _duplicate_recordings(
continue
await db_client.create_recording(
recording_id=new_recording_id,
recording_id=rec.recording_id,
workflow_id=new_workflow_id,
organization_id=organization_id,
tts_provider=rec.tts_provider,
@ -224,34 +192,12 @@ async def _duplicate_recordings(
metadata=copy.deepcopy(rec.recording_metadata),
)
recording_id_map[rec.recording_id] = new_recording_id
logger.info(
f"Duplicated recording {rec.recording_id} -> {new_recording_id}"
)
logger.info(f"Duplicated recording {rec.recording_id}")
except Exception as e:
logger.error(f"Error duplicating recording {rec.recording_id}: {e}")
continue
return recording_id_map
def _replace_recording_ids(
workflow_definition: dict,
recording_id_map: dict[str, str],
) -> dict:
"""Replace old recording IDs with new ones throughout the workflow definition.
Uses JSON serialization to do a thorough find-and-replace across all
nested fields (node prompts, data, etc.).
"""
definition_str = json.dumps(workflow_definition)
for old_id, new_id in recording_id_map.items():
definition_str = definition_str.replace(old_id, new_id)
return json.loads(definition_str)
async def _copy_storage_object(
source_key: str, dest_key: str, storage_backend: str

View file

@ -1,5 +1,6 @@
from typing import TYPE_CHECKING, Awaitable, Callable, Optional, Union
from api.services.pipecat.recording_playback import queue_recording_audio
from api.services.workflow.disposition_mapper import (
apply_disposition_mapping,
get_organization_id_from_workflow_run,
@ -114,6 +115,9 @@ class PipecatEngine:
# Audio configuration (set via set_audio_config from _run_pipeline)
self._audio_config = None
# Recording audio fetcher (set via set_fetch_recording_audio from _run_pipeline)
self._fetch_recording_audio = None
# True when the workflow has active recordings; enables recording
# response mode instructions on all nodes for in-context learning.
self._has_recordings: bool = has_recordings
@ -191,6 +195,8 @@ class PipecatEngine:
name: str,
transition_to_node: str,
transition_speech: Optional[str] = None,
transition_speech_type: Optional[str] = None,
transition_speech_recording_id: Optional[str] = None,
):
async def transition_func(function_call_params: FunctionCallParams) -> None:
"""Inner function that handles the node change tool calls"""
@ -204,8 +210,33 @@ class PipecatEngine:
# Perform variable extraction before transitioning to new node
await self._perform_variable_extraction_if_needed(self._current_node)
# Queue transition speech before switching nodes
if transition_speech:
# Queue transition speech/audio before switching nodes
speech_type = transition_speech_type or "text"
if (
speech_type == "audio"
and transition_speech_recording_id
and self._fetch_recording_audio
):
logger.info(
f"Playing transition audio: {transition_speech_recording_id}"
)
self._queued_speech_mute_state = "waiting"
audio_data = await self._fetch_recording_audio(
transition_speech_recording_id
)
if audio_data:
await queue_recording_audio(
audio_data,
sample_rate=self._audio_config.pipeline_sample_rate
if self._audio_config
else 16000,
queue_frame=self.task.queue_frame,
)
else:
logger.warning(
f"Failed to fetch transition audio {transition_speech_recording_id}"
)
elif transition_speech:
logger.info(f"Playing transition speech: {transition_speech}")
self._queued_speech_mute_state = "waiting"
await self.task.queue_frame(
@ -259,6 +290,8 @@ class PipecatEngine:
name: str,
transition_to_node: str,
transition_speech: Optional[str] = None,
transition_speech_type: Optional[str] = None,
transition_speech_recording_id: Optional[str] = None,
):
logger.debug(
f"Registering function {name} to transition to node {transition_to_node} with LLM"
@ -266,7 +299,11 @@ class PipecatEngine:
# Create transition function
transition_func = await self._create_transition_func(
name, transition_to_node, transition_speech
name,
transition_to_node,
transition_speech,
transition_speech_type,
transition_speech_recording_id,
)
# Register function with LLM
@ -442,6 +479,8 @@ class PipecatEngine:
outgoing_edge.get_function_name(),
outgoing_edge.target,
outgoing_edge.transition_speech,
outgoing_edge.data.transition_speech_type,
outgoing_edge.data.transition_speech_recording_id,
)
# Register custom tool handlers for this node
@ -533,11 +572,27 @@ class PipecatEngine:
# Setup LLM Context with Prompts and Functions
await self._setup_llm_context(node)
def get_start_greeting(self) -> Optional[str]:
"""Return the rendered greeting for the start node, or None if not configured."""
def get_start_greeting(self) -> Optional[tuple[str, Optional[str]]]:
"""Return the greeting info for the start node, or None if not configured.
Returns:
A tuple of (greeting_type, value) where:
- ("text", rendered_text) for text greetings spoken via TTS
- ("audio", recording_id) for pre-recorded audio greetings
Or None if no greeting is configured.
"""
start_node = self.workflow.nodes.get(self.workflow.start_node_id)
if start_node and start_node.greeting:
return self._format_prompt(start_node.greeting)
if not start_node:
return None
greeting_type = start_node.greeting_type or "text"
if greeting_type == "audio" and start_node.greeting_recording_id:
return ("audio", start_node.greeting_recording_id)
if start_node.greeting:
return ("text", self._format_prompt(start_node.greeting))
return None
async def _handle_end_node(self, node: Node) -> None:
@ -698,6 +753,10 @@ class PipecatEngine:
"""Set the audio configuration for the pipeline."""
self._audio_config = audio_config
def set_fetch_recording_audio(self, fetch_fn) -> None:
"""Set the recording audio fetcher callback."""
self._fetch_recording_audio = fetch_fn
def set_mute_pipeline(self, mute: bool) -> None:
"""Set the pipeline mute state.

View file

@ -16,6 +16,7 @@ from loguru import logger
from api.db import db_client
from api.enums import ToolCategory, WorkflowRunMode
from api.services.pipecat.recording_playback import queue_recording_audio
from api.services.telephony.call_transfer_manager import get_call_transfer_manager
from api.services.telephony.factory import get_telephony_provider
from api.services.telephony.transfer_event_protocol import TransferContext
@ -77,6 +78,42 @@ class CustomToolManager:
self._engine = engine
self._organization_id: Optional[int] = None
async def _play_config_message(
self, config: dict, *, append_to_context: bool = False
) -> bool:
"""Play a message from tool config — text or pre-recorded audio.
Returns True if a message was queued, False otherwise.
"""
message_type = config.get("messageType", "none")
if message_type == "audio":
recording_id = config.get("audioRecordingId", "")
if recording_id and self._engine._fetch_recording_audio:
audio_data = await self._engine._fetch_recording_audio(recording_id)
if audio_data:
await queue_recording_audio(
audio_data,
sample_rate=self._engine._audio_config.pipeline_sample_rate
if self._engine._audio_config
else 16000,
queue_frame=self._engine.task.queue_frame,
)
return True
else:
logger.warning(f"Failed to fetch recording {recording_id}")
return False
if message_type == "custom":
custom_message = config.get("customMessage", "")
if custom_message:
await self._engine.task.queue_frame(
TTSSpeakFrame(custom_message, append_to_context=append_to_context)
)
return True
return False
async def get_organization_id(self) -> Optional[int]:
"""Get and cache the organization ID from workflow run."""
if self._organization_id is None:
@ -250,9 +287,29 @@ class CustomToolManager:
try:
# Queue custom message before executing the API call
# Queue custom message (text or audio) before executing the API call
config = tool.definition.get("config", {}) if tool.definition else {}
custom_msg_type = config.get("customMessageType", "text")
custom_message = config.get("customMessage", "")
if custom_message:
if custom_msg_type == "audio":
recording_id = config.get("customMessageRecordingId", "")
if recording_id and self._engine._fetch_recording_audio:
logger.info(
f"Playing audio message before HTTP tool: {recording_id}"
)
self._engine._queued_speech_mute_state = "waiting"
audio_data = await self._engine._fetch_recording_audio(
recording_id
)
if audio_data:
await queue_recording_audio(
audio_data,
sample_rate=self._engine._audio_config.pipeline_sample_rate
if self._engine._audio_config
else 16000,
queue_frame=self._engine.task.queue_frame,
)
elif custom_message:
logger.info(
f"Playing custom message before HTTP tool: {custom_message}"
)
@ -299,8 +356,6 @@ class CustomToolManager:
try:
# Get the end call configuration
config = tool.definition.get("config", {})
message_type = config.get("messageType", "none")
custom_message = config.get("customMessage", "")
# Handle end call reason if enabled
end_call_reason_enabled = config.get("endCallReason", False)
@ -322,10 +377,8 @@ class CustomToolManager:
properties=properties,
)
if message_type == "custom" and custom_message:
# Queue the custom message to be spoken
logger.info(f"Playing custom goodbye message: {custom_message}")
await self._engine.task.queue_frame(TTSSpeakFrame(custom_message))
played = await self._play_config_message(config)
if played:
# End the call after the message (not immediately)
await self._engine.end_call_with_reason(
EndTaskReason.END_CALL_TOOL_REASON.value,
@ -370,8 +423,6 @@ class CustomToolManager:
# Get the transfer call configuration
config = tool.definition.get("config", {})
destination = config.get("destination", "")
message_type = config.get("messageType", "none")
custom_message = config.get("customMessage", "")
timeout_seconds = config.get(
"timeout", 30
) # Default 30 seconds if not configured
@ -443,10 +494,9 @@ class CustomToolManager:
)
return
if message_type == "custom" and custom_message:
logger.info(f"Playing pre-transfer message: {custom_message}")
played = await self._play_config_message(config)
if played:
self._engine._queued_speech_mute_state = "waiting"
await self._engine.task.queue_frame(TTSSpeakFrame(custom_message))
# Get organization ID for provider configuration
organization_id = await self.get_organization_id()

View file

@ -77,6 +77,8 @@ class Node:
self.extraction_variables = data.extraction_variables
self.add_global_prompt = data.add_global_prompt
self.greeting = data.greeting
self.greeting_type = data.greeting_type
self.greeting_recording_id = data.greeting_recording_id
self.detect_voicemail = data.detect_voicemail
self.delayed_start = data.delayed_start
self.delayed_start_duration = data.delayed_start_duration

View file

@ -0,0 +1,587 @@
"""Tests for text and audio playback in greetings, transitions, and tool messages.
Verifies that:
- Text mode produces TTSSpeakFrame
- Audio mode produces TTSStartedFrame -> TTSAudioRawFrame -> TTSStoppedFrame
- Covers: start node greetings, edge transition speech, tool config messages
"""
import asyncio
from typing import Any, Dict, List
from unittest.mock import AsyncMock, Mock, patch
import pytest
from api.services.workflow.dto import (
EdgeDataDTO,
NodeDataDTO,
NodeType,
Position,
ReactFlowDTO,
RFEdgeDTO,
RFNodeDTO,
)
from api.services.workflow.pipecat_engine import PipecatEngine
from api.services.workflow.pipecat_engine_custom_tools import CustomToolManager
from api.services.workflow.workflow import WorkflowGraph
from pipecat.frames.frames import (
Frame,
LLMContextFrame,
TTSAudioRawFrame,
TTSSpeakFrame,
TTSStartedFrame,
TTSStoppedFrame,
)
from pipecat.pipeline.pipeline import Pipeline
from pipecat.pipeline.runner import PipelineRunner
from pipecat.pipeline.task import PipelineParams, PipelineTask
from pipecat.processors.aggregators.llm_context import LLMContext
from pipecat.processors.aggregators.llm_response_universal import (
LLMAssistantAggregatorParams,
LLMContextAggregatorPair,
)
from pipecat.tests import MockLLMService, MockTTSService
from pipecat.tests.mock_transport import MockTransport
from pipecat.transports.base_transport import TransportParams
# ─── Constants ──────────────────────────────────────────────────
START_PROMPT = "Start Call System Prompt"
END_PROMPT = "End Call System Prompt"
TEXT_GREETING = "Hello, welcome to our service!"
TEXT_TRANSITION = "Thank you for calling, goodbye!"
AUDIO_GREETING_ID = "rec-greeting-001"
AUDIO_TRANSITION_ID = "rec-transition-001"
FAKE_PCM_AUDIO = b"\x00\x01" * 1000 # Fake 16-bit mono PCM data
# ─── Fixtures ───────────────────────────────────────────────────
@pytest.fixture
def text_workflow() -> WorkflowGraph:
"""Start->End workflow with text greeting and text transition speech."""
dto = ReactFlowDTO(
nodes=[
RFNodeDTO(
id="start",
type=NodeType.startNode,
position=Position(x=0, y=0),
data=NodeDataDTO(
name="Start Call",
prompt=START_PROMPT,
is_start=True,
allow_interrupt=False,
add_global_prompt=False,
greeting=TEXT_GREETING,
greeting_type="text",
extraction_enabled=False,
),
),
RFNodeDTO(
id="end",
type=NodeType.endNode,
position=Position(x=0, y=200),
data=NodeDataDTO(
name="End Call",
prompt=END_PROMPT,
is_end=True,
allow_interrupt=False,
add_global_prompt=False,
extraction_enabled=False,
),
),
],
edges=[
RFEdgeDTO(
id="start-end",
source="start",
target="end",
data=EdgeDataDTO(
label="End Call",
condition="When the user says end the call",
transition_speech=TEXT_TRANSITION,
transition_speech_type="text",
),
),
],
)
return WorkflowGraph(dto)
@pytest.fixture
def audio_workflow() -> WorkflowGraph:
"""Start->End workflow with audio greeting and audio transition speech."""
dto = ReactFlowDTO(
nodes=[
RFNodeDTO(
id="start",
type=NodeType.startNode,
position=Position(x=0, y=0),
data=NodeDataDTO(
name="Start Call",
prompt=START_PROMPT,
is_start=True,
allow_interrupt=False,
add_global_prompt=False,
greeting_type="audio",
greeting_recording_id=AUDIO_GREETING_ID,
extraction_enabled=False,
),
),
RFNodeDTO(
id="end",
type=NodeType.endNode,
position=Position(x=0, y=200),
data=NodeDataDTO(
name="End Call",
prompt=END_PROMPT,
is_end=True,
allow_interrupt=False,
add_global_prompt=False,
extraction_enabled=False,
),
),
],
edges=[
RFEdgeDTO(
id="start-end",
source="start",
target="end",
data=EdgeDataDTO(
label="End Call",
condition="When the user says end the call",
transition_speech_type="audio",
transition_speech_recording_id=AUDIO_TRANSITION_ID,
),
),
],
)
return WorkflowGraph(dto)
# ─── Pipeline Helper ────────────────────────────────────────────
async def run_pipeline_and_capture_frames(
workflow: WorkflowGraph,
functions: List[Dict[str, Any]],
fetch_recording_audio=None,
num_text_steps: int = 1,
) -> tuple[MockLLMService, LLMContext, list[Frame]]:
"""Run a pipeline with mock tool calls and capture frames queued via task.queue_frame.
Returns:
Tuple of (llm, context, list of captured frames).
"""
first_step_chunks = MockLLMService.create_multiple_function_call_chunks(functions)
mock_steps = MockLLMService.create_multi_step_responses(
first_step_chunks, num_text_steps=num_text_steps, step_prefix="Response"
)
llm = MockLLMService(mock_steps=mock_steps, chunk_delay=0.001)
tts = MockTTSService(mock_audio_duration_ms=40, frame_delay=0)
mock_transport = MockTransport(
params=TransportParams(
audio_in_enabled=True,
audio_out_enabled=True,
audio_in_sample_rate=16000,
audio_out_sample_rate=16000,
),
)
context = LLMContext()
assistant_params = LLMAssistantAggregatorParams(expect_stripped_words=True)
context_aggregator = LLMContextAggregatorPair(
context, assistant_params=assistant_params
)
engine = PipecatEngine(
llm=llm,
context=context,
workflow=workflow,
call_context_vars={"customer_name": "Test User"},
workflow_run_id=1,
)
if fetch_recording_audio:
engine.set_fetch_recording_audio(fetch_recording_audio)
pipeline = Pipeline(
[llm, tts, mock_transport.output(), context_aggregator.assistant()]
)
task = PipelineTask(pipeline, params=PipelineParams(), enable_rtvi=False)
engine.set_task(task)
# Spy on task.queue_frame to capture all frames queued by the engine
queued_frames: list[Frame] = []
original_queue_frame = task.queue_frame
async def capturing_queue_frame(frame):
queued_frames.append(frame)
await original_queue_frame(frame)
task.queue_frame = capturing_queue_frame
with (
patch(
"api.services.workflow.pipecat_engine.get_organization_id_from_workflow_run",
new_callable=AsyncMock,
return_value=1,
),
patch(
"api.services.workflow.pipecat_engine.apply_disposition_mapping",
new_callable=AsyncMock,
return_value="completed",
),
):
runner = PipelineRunner()
async def run():
await runner.run(task)
async def initialize():
await asyncio.sleep(0.01)
await engine.initialize()
await engine.set_node(engine.workflow.start_node_id)
await engine.llm.queue_frame(LLMContextFrame(engine.context))
await asyncio.gather(run(), initialize())
return llm, context, queued_frames
# ─── Tests: Start Greeting ──────────────────────────────────────
class TestStartGreeting:
"""Unit tests for PipecatEngine.get_start_greeting()."""
def test_text_greeting_returns_text_tuple(self, text_workflow: WorkflowGraph):
"""Text greeting config should return ('text', rendered_text)."""
engine = PipecatEngine(
workflow=text_workflow,
call_context_vars={},
workflow_run_id=1,
)
result = engine.get_start_greeting()
assert result == ("text", TEXT_GREETING)
def test_audio_greeting_returns_audio_tuple(self, audio_workflow: WorkflowGraph):
"""Audio greeting config should return ('audio', recording_id)."""
engine = PipecatEngine(
workflow=audio_workflow,
call_context_vars={},
workflow_run_id=1,
)
result = engine.get_start_greeting()
assert result == ("audio", AUDIO_GREETING_ID)
def test_no_greeting_returns_none(self):
"""No greeting configured should return None."""
dto = ReactFlowDTO(
nodes=[
RFNodeDTO(
id="start",
type=NodeType.startNode,
position=Position(x=0, y=0),
data=NodeDataDTO(
name="Start",
prompt="Prompt",
is_start=True,
add_global_prompt=False,
extraction_enabled=False,
),
),
RFNodeDTO(
id="end",
type=NodeType.endNode,
position=Position(x=0, y=200),
data=NodeDataDTO(
name="End",
prompt="End",
is_end=True,
add_global_prompt=False,
extraction_enabled=False,
),
),
],
edges=[
RFEdgeDTO(
id="e",
source="start",
target="end",
data=EdgeDataDTO(label="End", condition="End"),
),
],
)
engine = PipecatEngine(
workflow=WorkflowGraph(dto),
call_context_vars={},
workflow_run_id=1,
)
assert engine.get_start_greeting() is None
def test_text_greeting_renders_template_variables(self):
"""Text greeting with {{variable}} placeholders should be rendered."""
dto = ReactFlowDTO(
nodes=[
RFNodeDTO(
id="start",
type=NodeType.startNode,
position=Position(x=0, y=0),
data=NodeDataDTO(
name="Start",
prompt="Prompt",
is_start=True,
add_global_prompt=False,
greeting="Hello {{customer_name}}!",
greeting_type="text",
extraction_enabled=False,
),
),
RFNodeDTO(
id="end",
type=NodeType.endNode,
position=Position(x=0, y=200),
data=NodeDataDTO(
name="End",
prompt="End",
is_end=True,
add_global_prompt=False,
extraction_enabled=False,
),
),
],
edges=[
RFEdgeDTO(
id="e",
source="start",
target="end",
data=EdgeDataDTO(label="End", condition="End"),
),
],
)
engine = PipecatEngine(
workflow=WorkflowGraph(dto),
call_context_vars={"customer_name": "Alice"},
workflow_run_id=1,
)
result = engine.get_start_greeting()
assert result == ("text", "Hello Alice!")
# ─── Tests: Transition Speech (Pipeline) ────────────────────────
class TestTransitionSpeech:
"""Pipeline tests for edge transition speech (text and audio)."""
@pytest.mark.asyncio
async def test_text_transition_queues_tts_speak_frame(
self, text_workflow: WorkflowGraph
):
"""Text transition speech should queue a TTSSpeakFrame with the message."""
functions = [
{
"name": "end_call",
"arguments": {},
"tool_call_id": "call_transition",
},
]
llm, context, queued_frames = await run_pipeline_and_capture_frames(
workflow=text_workflow,
functions=functions,
num_text_steps=2,
)
# Pipeline completes: 1st gen on StartNode, 2nd gen on EndNode
assert llm.get_current_step() == 2
# Verify TTSSpeakFrame was queued with the transition speech text
tts_speak_frames = [f for f in queued_frames if isinstance(f, TTSSpeakFrame)]
transition_frames = [f for f in tts_speak_frames if f.text == TEXT_TRANSITION]
assert len(transition_frames) == 1, (
f"Expected one TTSSpeakFrame with text '{TEXT_TRANSITION}', "
f"got: {[f.text for f in tts_speak_frames]}"
)
# No raw audio frames should be queued for text transition
audio_raw = [f for f in queued_frames if isinstance(f, TTSAudioRawFrame)]
assert len(audio_raw) == 0
@pytest.mark.asyncio
async def test_audio_transition_queues_audio_frames(
self, audio_workflow: WorkflowGraph
):
"""Audio transition speech should queue TTSStarted + TTSAudioRaw + TTSStopped."""
functions = [
{
"name": "end_call",
"arguments": {},
"tool_call_id": "call_transition",
},
]
mock_fetch = AsyncMock(return_value=FAKE_PCM_AUDIO)
llm, context, queued_frames = await run_pipeline_and_capture_frames(
workflow=audio_workflow,
functions=functions,
fetch_recording_audio=mock_fetch,
num_text_steps=2,
)
# Pipeline completes
assert llm.get_current_step() == 2
# Verify fetch was called with the correct recording ID
mock_fetch.assert_called_once_with(AUDIO_TRANSITION_ID)
# Verify the three-frame audio sequence was queued
started = [f for f in queued_frames if isinstance(f, TTSStartedFrame)]
audio = [f for f in queued_frames if isinstance(f, TTSAudioRawFrame)]
stopped = [f for f in queued_frames if isinstance(f, TTSStoppedFrame)]
assert len(started) >= 1, (
f"Expected TTSStartedFrame. "
f"Frame types: {[type(f).__name__ for f in queued_frames]}"
)
assert len(audio) >= 1, "Expected TTSAudioRawFrame"
assert len(stopped) >= 1, "Expected TTSStoppedFrame"
# Verify audio content
assert audio[0].audio == FAKE_PCM_AUDIO
assert audio[0].sample_rate == 16000
assert audio[0].num_channels == 1
# Verify context_id consistency across the three frames
ctx_id = started[0].context_id
assert ctx_id is not None
assert audio[0].context_id == ctx_id
assert stopped[0].context_id == ctx_id
# No TTSSpeakFrame should be queued for audio transition
speak = [f for f in queued_frames if isinstance(f, TTSSpeakFrame)]
assert len(speak) == 0
# ─── Tests: Tool Config Messages ────────────────────────────────
class TestPlayConfigMessage:
"""Unit tests for CustomToolManager._play_config_message."""
@pytest.fixture
def mock_engine(self):
"""Create a mock engine with frame capture on task.queue_frame."""
engine = Mock()
engine._workflow_run_id = 1
engine._call_context_vars = {}
engine._fetch_recording_audio = None
engine._audio_config = None
engine.task = Mock()
engine.llm = Mock()
# Capture frames queued via task.queue_frame
engine._queued_frames = []
async def mock_queue_frame(frame):
engine._queued_frames.append(frame)
engine.task.queue_frame = mock_queue_frame
return engine
@pytest.mark.asyncio
async def test_custom_text_queues_tts_speak_frame(self, mock_engine):
"""messageType='custom' queues TTSSpeakFrame with the message text."""
manager = CustomToolManager(mock_engine)
config = {"messageType": "custom", "customMessage": "Ending your call now."}
result = await manager._play_config_message(config)
assert result is True
frames = mock_engine._queued_frames
assert len(frames) == 1
assert isinstance(frames[0], TTSSpeakFrame)
assert frames[0].text == "Ending your call now."
@pytest.mark.asyncio
async def test_audio_queues_started_raw_stopped_frames(self, mock_engine):
"""messageType='audio' queues TTSStarted + TTSAudioRaw + TTSStopped."""
mock_fetch = AsyncMock(return_value=FAKE_PCM_AUDIO)
mock_engine._fetch_recording_audio = mock_fetch
manager = CustomToolManager(mock_engine)
config = {"messageType": "audio", "audioRecordingId": "rec-end-001"}
result = await manager._play_config_message(config)
assert result is True
mock_fetch.assert_called_once_with("rec-end-001")
frames = mock_engine._queued_frames
assert len(frames) == 3
assert isinstance(frames[0], TTSStartedFrame)
assert isinstance(frames[1], TTSAudioRawFrame)
assert isinstance(frames[2], TTSStoppedFrame)
# Verify audio content
assert frames[1].audio == FAKE_PCM_AUDIO
assert frames[1].sample_rate == 16000
assert frames[1].num_channels == 1
# Context IDs should match across all three frames
ctx_id = frames[0].context_id
assert ctx_id is not None
assert frames[1].context_id == ctx_id
assert frames[2].context_id == ctx_id
@pytest.mark.asyncio
async def test_none_message_type_returns_false(self, mock_engine):
"""messageType='none' returns False without queuing frames."""
manager = CustomToolManager(mock_engine)
result = await manager._play_config_message({"messageType": "none"})
assert result is False
assert len(mock_engine._queued_frames) == 0
@pytest.mark.asyncio
async def test_audio_without_fetch_callback_returns_false(self, mock_engine):
"""Audio without fetch_recording_audio callback returns False."""
mock_engine._fetch_recording_audio = None
manager = CustomToolManager(mock_engine)
config = {"messageType": "audio", "audioRecordingId": "rec-123"}
result = await manager._play_config_message(config)
assert result is False
assert len(mock_engine._queued_frames) == 0
@pytest.mark.asyncio
async def test_audio_with_failed_fetch_returns_false(self, mock_engine):
"""Audio with fetch returning None returns False."""
mock_fetch = AsyncMock(return_value=None)
mock_engine._fetch_recording_audio = mock_fetch
manager = CustomToolManager(mock_engine)
config = {"messageType": "audio", "audioRecordingId": "rec-123"}
result = await manager._play_config_message(config)
assert result is False
mock_fetch.assert_called_once_with("rec-123")
assert len(mock_engine._queued_frames) == 0
@pytest.mark.asyncio
async def test_custom_empty_message_returns_false(self, mock_engine):
"""messageType='custom' with empty message returns False."""
manager = CustomToolManager(mock_engine)
config = {"messageType": "custom", "customMessage": ""}
result = await manager._play_config_message(config)
assert result is False
assert len(mock_engine._queued_frames) == 0