mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-06-13 08:15:21 +02:00
feat: allow recordings in tool transitions
This commit is contained in:
parent
3a272d3a44
commit
ffe9a99401
38 changed files with 1555 additions and 692 deletions
|
|
@ -1,4 +1,4 @@
|
|||
"""Service for duplicating workflows including recordings."""
|
||||
"""Service for duplicating workflows."""
|
||||
|
||||
import copy
|
||||
import posixpath
|
||||
|
|
@ -44,7 +44,9 @@ async def duplicate_workflow(
|
|||
organization_id: int,
|
||||
user_id: int,
|
||||
):
|
||||
"""Duplicate a workflow including its definition, config, recordings, and triggers.
|
||||
"""Duplicate a workflow including its definition, config, and triggers.
|
||||
|
||||
Recordings are org-scoped and shared, so they are not duplicated.
|
||||
|
||||
Args:
|
||||
workflow_id: The source workflow ID to duplicate
|
||||
|
|
@ -118,15 +120,7 @@ async def duplicate_workflow(
|
|||
organization_id=organization_id,
|
||||
)
|
||||
|
||||
# 6. Copy recordings (recording_ids are preserved since they're scoped per workflow)
|
||||
await _duplicate_recordings(
|
||||
source_workflow_id=workflow_id,
|
||||
new_workflow_id=new_workflow.id,
|
||||
organization_id=organization_id,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
# 7. Sync triggers for the new workflow
|
||||
# 6. Sync triggers for the new workflow
|
||||
if workflow_definition:
|
||||
trigger_paths = _extract_trigger_paths(workflow_definition)
|
||||
if trigger_paths:
|
||||
|
|
@ -139,66 +133,6 @@ async def duplicate_workflow(
|
|||
return new_workflow
|
||||
|
||||
|
||||
async def _duplicate_recordings(
|
||||
source_workflow_id: int,
|
||||
new_workflow_id: int,
|
||||
organization_id: int,
|
||||
user_id: int,
|
||||
) -> None:
|
||||
"""Duplicate all recordings for a workflow.
|
||||
|
||||
Copies each recording file to a new storage path scoped under the new
|
||||
workflow ID. Recording IDs are preserved since they are unique per
|
||||
(org, workflow).
|
||||
"""
|
||||
recordings = await db_client.get_recordings(
|
||||
workflow_id=source_workflow_id,
|
||||
organization_id=organization_id,
|
||||
)
|
||||
|
||||
if not recordings:
|
||||
return
|
||||
|
||||
for rec in recordings:
|
||||
try:
|
||||
# Build new storage key: recordings/{org_id}/{new_workflow_id}/{recording_id}/{filename}
|
||||
filename = posixpath.basename(rec.storage_key)
|
||||
new_storage_key = (
|
||||
f"recordings/{organization_id}"
|
||||
f"/{new_workflow_id}/{rec.recording_id}"
|
||||
f"/{filename}"
|
||||
)
|
||||
|
||||
copied = await _copy_storage_object(
|
||||
rec.storage_key, new_storage_key, rec.storage_backend
|
||||
)
|
||||
if not copied:
|
||||
logger.warning(
|
||||
f"Failed to copy recording file {rec.recording_id}, skipping"
|
||||
)
|
||||
continue
|
||||
|
||||
await db_client.create_recording(
|
||||
recording_id=rec.recording_id,
|
||||
workflow_id=new_workflow_id,
|
||||
organization_id=organization_id,
|
||||
tts_provider=rec.tts_provider,
|
||||
tts_model=rec.tts_model,
|
||||
tts_voice_id=rec.tts_voice_id,
|
||||
transcript=rec.transcript,
|
||||
storage_key=new_storage_key,
|
||||
storage_backend=rec.storage_backend,
|
||||
created_by=user_id,
|
||||
metadata=copy.deepcopy(rec.recording_metadata),
|
||||
)
|
||||
|
||||
logger.info(f"Duplicated recording {rec.recording_id}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error duplicating recording {rec.recording_id}: {e}")
|
||||
continue
|
||||
|
||||
|
||||
async def _copy_storage_object(
|
||||
source_key: str, dest_key: str, storage_backend: str
|
||||
) -> bool:
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
from typing import TYPE_CHECKING, Awaitable, Callable, Optional, Union
|
||||
|
||||
from api.services.pipecat.recording_playback import queue_recording_audio
|
||||
from api.services.pipecat.audio_playback import play_audio
|
||||
from api.services.workflow.disposition_mapper import (
|
||||
apply_disposition_mapping,
|
||||
get_organization_id_from_workflow_run,
|
||||
|
|
@ -115,6 +115,10 @@ class PipecatEngine:
|
|||
# Audio configuration (set via set_audio_config from _run_pipeline)
|
||||
self._audio_config = None
|
||||
|
||||
# Transport output processor for injecting audio directly into the
|
||||
# output, bypassing STT (set via set_transport_output from _run_pipeline)
|
||||
self._transport_output = None
|
||||
|
||||
# Recording audio fetcher (set via set_fetch_recording_audio from _run_pipeline)
|
||||
self._fetch_recording_audio = None
|
||||
|
||||
|
|
@ -221,16 +225,17 @@ class PipecatEngine:
|
|||
f"Playing transition audio: {transition_speech_recording_id}"
|
||||
)
|
||||
self._queued_speech_mute_state = "waiting"
|
||||
audio_data = await self._fetch_recording_audio(
|
||||
transition_speech_recording_id
|
||||
result = await self._fetch_recording_audio(
|
||||
recording_pk=int(transition_speech_recording_id)
|
||||
)
|
||||
if audio_data:
|
||||
await queue_recording_audio(
|
||||
audio_data,
|
||||
if result:
|
||||
await play_audio(
|
||||
result.audio,
|
||||
sample_rate=self._audio_config.pipeline_sample_rate
|
||||
if self._audio_config
|
||||
else 16000,
|
||||
queue_frame=self.task.queue_frame,
|
||||
queue_frame=self._transport_output.queue_frame,
|
||||
transcript=result.transcript,
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
|
|
@ -753,6 +758,14 @@ class PipecatEngine:
|
|||
"""Set the audio configuration for the pipeline."""
|
||||
self._audio_config = audio_config
|
||||
|
||||
def set_transport_output(self, transport_output) -> None:
|
||||
"""Set the transport output processor for direct audio playback.
|
||||
|
||||
Audio queued here bypasses STT and the rest of the pipeline,
|
||||
going straight to the caller.
|
||||
"""
|
||||
self._transport_output = transport_output
|
||||
|
||||
def set_fetch_recording_audio(self, fetch_fn) -> None:
|
||||
"""Set the recording audio fetcher callback."""
|
||||
self._fetch_recording_audio = fetch_fn
|
||||
|
|
|
|||
|
|
@ -168,7 +168,6 @@ def create_aggregation_correction_callback(engine: "PipecatEngine"):
|
|||
reference = engine._current_llm_generation_reference_text
|
||||
|
||||
if not reference:
|
||||
logger.warning("No reference text available for aggregation correction")
|
||||
return corrupted
|
||||
|
||||
# Apply the correction algorithm
|
||||
|
|
|
|||
|
|
@ -16,7 +16,7 @@ from loguru import logger
|
|||
|
||||
from api.db import db_client
|
||||
from api.enums import ToolCategory, WorkflowRunMode
|
||||
from api.services.pipecat.recording_playback import queue_recording_audio
|
||||
from api.services.pipecat.audio_playback import play_audio, play_audio_loop
|
||||
from api.services.telephony.call_transfer_manager import get_call_transfer_manager
|
||||
from api.services.telephony.factory import get_telephony_provider
|
||||
from api.services.telephony.transfer_event_protocol import TransferContext
|
||||
|
|
@ -28,7 +28,6 @@ from api.services.workflow.tools.custom_tool import (
|
|||
execute_http_tool,
|
||||
tool_to_function_schema,
|
||||
)
|
||||
from api.utils.hold_audio import play_hold_audio_loop
|
||||
from pipecat.adapters.schemas.function_schema import FunctionSchema
|
||||
from pipecat.frames.frames import (
|
||||
FunctionCallResultProperties,
|
||||
|
|
@ -88,20 +87,23 @@ class CustomToolManager:
|
|||
message_type = config.get("messageType", "none")
|
||||
|
||||
if message_type == "audio":
|
||||
recording_id = config.get("audioRecordingId", "")
|
||||
if recording_id and self._engine._fetch_recording_audio:
|
||||
audio_data = await self._engine._fetch_recording_audio(recording_id)
|
||||
if audio_data:
|
||||
await queue_recording_audio(
|
||||
audio_data,
|
||||
recording_pk = config.get("audioRecordingId")
|
||||
if recording_pk and self._engine._fetch_recording_audio:
|
||||
result = await self._engine._fetch_recording_audio(
|
||||
recording_pk=int(recording_pk)
|
||||
)
|
||||
if result:
|
||||
await play_audio(
|
||||
result.audio,
|
||||
sample_rate=self._engine._audio_config.pipeline_sample_rate
|
||||
if self._engine._audio_config
|
||||
else 16000,
|
||||
queue_frame=self._engine.task.queue_frame,
|
||||
queue_frame=self._engine._transport_output.queue_frame,
|
||||
transcript=result.transcript,
|
||||
)
|
||||
return True
|
||||
else:
|
||||
logger.warning(f"Failed to fetch recording {recording_id}")
|
||||
logger.warning(f"Failed to fetch recording pk={recording_pk}")
|
||||
return False
|
||||
|
||||
if message_type == "custom":
|
||||
|
|
@ -292,22 +294,23 @@ class CustomToolManager:
|
|||
custom_msg_type = config.get("customMessageType", "text")
|
||||
custom_message = config.get("customMessage", "")
|
||||
if custom_msg_type == "audio":
|
||||
recording_id = config.get("customMessageRecordingId", "")
|
||||
if recording_id and self._engine._fetch_recording_audio:
|
||||
recording_pk = config.get("customMessageRecordingId")
|
||||
if recording_pk and self._engine._fetch_recording_audio:
|
||||
logger.info(
|
||||
f"Playing audio message before HTTP tool: {recording_id}"
|
||||
f"Playing audio message before HTTP tool: pk={recording_pk}"
|
||||
)
|
||||
self._engine._queued_speech_mute_state = "waiting"
|
||||
audio_data = await self._engine._fetch_recording_audio(
|
||||
recording_id
|
||||
result = await self._engine._fetch_recording_audio(
|
||||
recording_pk=int(recording_pk)
|
||||
)
|
||||
if audio_data:
|
||||
await queue_recording_audio(
|
||||
audio_data,
|
||||
if result:
|
||||
await play_audio(
|
||||
result.audio,
|
||||
sample_rate=self._engine._audio_config.pipeline_sample_rate
|
||||
if self._engine._audio_config
|
||||
else 16000,
|
||||
queue_frame=self._engine.task.queue_frame,
|
||||
queue_frame=self._engine._transport_output.queue_frame,
|
||||
transcript=result.transcript,
|
||||
)
|
||||
elif custom_message:
|
||||
logger.info(
|
||||
|
|
@ -587,10 +590,10 @@ class CustomToolManager:
|
|||
|
||||
# Start hold music as background task
|
||||
hold_music_task = asyncio.create_task(
|
||||
play_hold_audio_loop(
|
||||
self._engine.task,
|
||||
hold_music_stop_event,
|
||||
sample_rate,
|
||||
play_audio_loop(
|
||||
stop_event=hold_music_stop_event,
|
||||
sample_rate=sample_rate,
|
||||
queue_frame=self._engine._transport_output.queue_frame,
|
||||
)
|
||||
)
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue