mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-06-07 07:55:16 +02:00
feat: allow uploading recording as part of node transition
This commit is contained in:
parent
bb5f56bfb7
commit
65c76ca7ff
36 changed files with 2255 additions and 201 deletions
|
|
@ -54,6 +54,8 @@ class NodeDataDTO(BaseModel):
|
|||
extraction_variables: Optional[list[ExtractionVariableDTO]] = None
|
||||
add_global_prompt: bool = True
|
||||
greeting: Optional[str] = None
|
||||
greeting_type: Optional[str] = None # 'text' or 'audio'
|
||||
greeting_recording_id: Optional[str] = None
|
||||
wait_for_user_response: bool = False
|
||||
wait_for_user_response_timeout: Optional[float] = None
|
||||
detect_voicemail: bool = False
|
||||
|
|
@ -102,6 +104,8 @@ class EdgeDataDTO(BaseModel):
|
|||
label: str = Field(..., min_length=1)
|
||||
condition: str = Field(..., min_length=1)
|
||||
transition_speech: Optional[str] = None
|
||||
transition_speech_type: Optional[str] = None # 'text' or 'audio'
|
||||
transition_speech_recording_id: Optional[str] = None
|
||||
|
||||
|
||||
class RFEdgeDTO(BaseModel):
|
||||
|
|
|
|||
|
|
@ -1,14 +1,12 @@
|
|||
"""Service for duplicating workflows including recordings."""
|
||||
|
||||
import copy
|
||||
import json
|
||||
import posixpath
|
||||
import uuid
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from api.db import db_client
|
||||
from api.db.workflow_recording_client import generate_short_id
|
||||
from api.enums import StorageBackend
|
||||
from api.services.storage import get_storage_for_backend, storage_fs
|
||||
|
||||
|
|
@ -41,16 +39,6 @@ def _regenerate_trigger_uuids(workflow_definition: dict) -> dict:
|
|||
return updated_definition
|
||||
|
||||
|
||||
async def _generate_unique_recording_id() -> str:
|
||||
"""Generate a globally unique short recording ID."""
|
||||
for _ in range(10):
|
||||
rid = generate_short_id(8)
|
||||
exists = await db_client.check_recording_id_exists(rid)
|
||||
if not exists:
|
||||
return rid
|
||||
raise RuntimeError("Failed to generate unique recording ID")
|
||||
|
||||
|
||||
async def duplicate_workflow(
|
||||
workflow_id: int,
|
||||
organization_id: int,
|
||||
|
|
@ -130,29 +118,15 @@ async def duplicate_workflow(
|
|||
organization_id=organization_id,
|
||||
)
|
||||
|
||||
# 6. Copy recordings with new IDs and storage paths scoped to new workflow
|
||||
recording_id_map = await _duplicate_recordings(
|
||||
# 6. Copy recordings (recording_ids are preserved since they're scoped per workflow)
|
||||
await _duplicate_recordings(
|
||||
source_workflow_id=workflow_id,
|
||||
new_workflow_id=new_workflow.id,
|
||||
organization_id=organization_id,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
# 7. Replace old recording IDs with new ones in the workflow definition
|
||||
if recording_id_map:
|
||||
workflow_definition = _replace_recording_ids(
|
||||
workflow_definition, recording_id_map
|
||||
)
|
||||
new_workflow = await db_client.update_workflow(
|
||||
workflow_id=new_workflow.id,
|
||||
name=None,
|
||||
workflow_definition=workflow_definition,
|
||||
template_context_variables=None,
|
||||
workflow_configurations=None,
|
||||
organization_id=organization_id,
|
||||
)
|
||||
|
||||
# 8. Sync triggers for the new workflow
|
||||
# 7. Sync triggers for the new workflow
|
||||
if workflow_definition:
|
||||
trigger_paths = _extract_trigger_paths(workflow_definition)
|
||||
if trigger_paths:
|
||||
|
|
@ -170,34 +144,28 @@ async def _duplicate_recordings(
|
|||
new_workflow_id: int,
|
||||
organization_id: int,
|
||||
user_id: int,
|
||||
) -> dict[str, str]:
|
||||
) -> None:
|
||||
"""Duplicate all recordings for a workflow.
|
||||
|
||||
Copies each recording file to a new storage path scoped under the new
|
||||
workflow ID, and creates new DB records pointing to the copied files.
|
||||
|
||||
Returns:
|
||||
Mapping of old_recording_id -> new_recording_id
|
||||
workflow ID. Recording IDs are preserved since they are unique per
|
||||
(org, workflow).
|
||||
"""
|
||||
recordings = await db_client.get_recordings_for_workflow(
|
||||
recordings = await db_client.get_recordings(
|
||||
workflow_id=source_workflow_id,
|
||||
organization_id=organization_id,
|
||||
)
|
||||
|
||||
if not recordings:
|
||||
return {}
|
||||
|
||||
recording_id_map: dict[str, str] = {}
|
||||
return
|
||||
|
||||
for rec in recordings:
|
||||
try:
|
||||
new_recording_id = await _generate_unique_recording_id()
|
||||
|
||||
# Build new storage key: recordings/{org_id}/{new_workflow_id}/{new_recording_id}/{filename}
|
||||
# Build new storage key: recordings/{org_id}/{new_workflow_id}/{recording_id}/{filename}
|
||||
filename = posixpath.basename(rec.storage_key)
|
||||
new_storage_key = (
|
||||
f"recordings/{organization_id}"
|
||||
f"/{new_workflow_id}/{new_recording_id}"
|
||||
f"/{new_workflow_id}/{rec.recording_id}"
|
||||
f"/{filename}"
|
||||
)
|
||||
|
||||
|
|
@ -211,7 +179,7 @@ async def _duplicate_recordings(
|
|||
continue
|
||||
|
||||
await db_client.create_recording(
|
||||
recording_id=new_recording_id,
|
||||
recording_id=rec.recording_id,
|
||||
workflow_id=new_workflow_id,
|
||||
organization_id=organization_id,
|
||||
tts_provider=rec.tts_provider,
|
||||
|
|
@ -224,34 +192,12 @@ async def _duplicate_recordings(
|
|||
metadata=copy.deepcopy(rec.recording_metadata),
|
||||
)
|
||||
|
||||
recording_id_map[rec.recording_id] = new_recording_id
|
||||
logger.info(
|
||||
f"Duplicated recording {rec.recording_id} -> {new_recording_id}"
|
||||
)
|
||||
logger.info(f"Duplicated recording {rec.recording_id}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error duplicating recording {rec.recording_id}: {e}")
|
||||
continue
|
||||
|
||||
return recording_id_map
|
||||
|
||||
|
||||
def _replace_recording_ids(
|
||||
workflow_definition: dict,
|
||||
recording_id_map: dict[str, str],
|
||||
) -> dict:
|
||||
"""Replace old recording IDs with new ones throughout the workflow definition.
|
||||
|
||||
Uses JSON serialization to do a thorough find-and-replace across all
|
||||
nested fields (node prompts, data, etc.).
|
||||
"""
|
||||
definition_str = json.dumps(workflow_definition)
|
||||
|
||||
for old_id, new_id in recording_id_map.items():
|
||||
definition_str = definition_str.replace(old_id, new_id)
|
||||
|
||||
return json.loads(definition_str)
|
||||
|
||||
|
||||
async def _copy_storage_object(
|
||||
source_key: str, dest_key: str, storage_backend: str
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
from typing import TYPE_CHECKING, Awaitable, Callable, Optional, Union
|
||||
|
||||
from api.services.pipecat.recording_playback import queue_recording_audio
|
||||
from api.services.workflow.disposition_mapper import (
|
||||
apply_disposition_mapping,
|
||||
get_organization_id_from_workflow_run,
|
||||
|
|
@ -114,6 +115,9 @@ class PipecatEngine:
|
|||
# Audio configuration (set via set_audio_config from _run_pipeline)
|
||||
self._audio_config = None
|
||||
|
||||
# Recording audio fetcher (set via set_fetch_recording_audio from _run_pipeline)
|
||||
self._fetch_recording_audio = None
|
||||
|
||||
# True when the workflow has active recordings; enables recording
|
||||
# response mode instructions on all nodes for in-context learning.
|
||||
self._has_recordings: bool = has_recordings
|
||||
|
|
@ -191,6 +195,8 @@ class PipecatEngine:
|
|||
name: str,
|
||||
transition_to_node: str,
|
||||
transition_speech: Optional[str] = None,
|
||||
transition_speech_type: Optional[str] = None,
|
||||
transition_speech_recording_id: Optional[str] = None,
|
||||
):
|
||||
async def transition_func(function_call_params: FunctionCallParams) -> None:
|
||||
"""Inner function that handles the node change tool calls"""
|
||||
|
|
@ -204,8 +210,33 @@ class PipecatEngine:
|
|||
# Perform variable extraction before transitioning to new node
|
||||
await self._perform_variable_extraction_if_needed(self._current_node)
|
||||
|
||||
# Queue transition speech before switching nodes
|
||||
if transition_speech:
|
||||
# Queue transition speech/audio before switching nodes
|
||||
speech_type = transition_speech_type or "text"
|
||||
if (
|
||||
speech_type == "audio"
|
||||
and transition_speech_recording_id
|
||||
and self._fetch_recording_audio
|
||||
):
|
||||
logger.info(
|
||||
f"Playing transition audio: {transition_speech_recording_id}"
|
||||
)
|
||||
self._queued_speech_mute_state = "waiting"
|
||||
audio_data = await self._fetch_recording_audio(
|
||||
transition_speech_recording_id
|
||||
)
|
||||
if audio_data:
|
||||
await queue_recording_audio(
|
||||
audio_data,
|
||||
sample_rate=self._audio_config.pipeline_sample_rate
|
||||
if self._audio_config
|
||||
else 16000,
|
||||
queue_frame=self.task.queue_frame,
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f"Failed to fetch transition audio {transition_speech_recording_id}"
|
||||
)
|
||||
elif transition_speech:
|
||||
logger.info(f"Playing transition speech: {transition_speech}")
|
||||
self._queued_speech_mute_state = "waiting"
|
||||
await self.task.queue_frame(
|
||||
|
|
@ -259,6 +290,8 @@ class PipecatEngine:
|
|||
name: str,
|
||||
transition_to_node: str,
|
||||
transition_speech: Optional[str] = None,
|
||||
transition_speech_type: Optional[str] = None,
|
||||
transition_speech_recording_id: Optional[str] = None,
|
||||
):
|
||||
logger.debug(
|
||||
f"Registering function {name} to transition to node {transition_to_node} with LLM"
|
||||
|
|
@ -266,7 +299,11 @@ class PipecatEngine:
|
|||
|
||||
# Create transition function
|
||||
transition_func = await self._create_transition_func(
|
||||
name, transition_to_node, transition_speech
|
||||
name,
|
||||
transition_to_node,
|
||||
transition_speech,
|
||||
transition_speech_type,
|
||||
transition_speech_recording_id,
|
||||
)
|
||||
|
||||
# Register function with LLM
|
||||
|
|
@ -442,6 +479,8 @@ class PipecatEngine:
|
|||
outgoing_edge.get_function_name(),
|
||||
outgoing_edge.target,
|
||||
outgoing_edge.transition_speech,
|
||||
outgoing_edge.data.transition_speech_type,
|
||||
outgoing_edge.data.transition_speech_recording_id,
|
||||
)
|
||||
|
||||
# Register custom tool handlers for this node
|
||||
|
|
@ -533,11 +572,27 @@ class PipecatEngine:
|
|||
# Setup LLM Context with Prompts and Functions
|
||||
await self._setup_llm_context(node)
|
||||
|
||||
def get_start_greeting(self) -> Optional[str]:
|
||||
"""Return the rendered greeting for the start node, or None if not configured."""
|
||||
def get_start_greeting(self) -> Optional[tuple[str, Optional[str]]]:
|
||||
"""Return the greeting info for the start node, or None if not configured.
|
||||
|
||||
Returns:
|
||||
A tuple of (greeting_type, value) where:
|
||||
- ("text", rendered_text) for text greetings spoken via TTS
|
||||
- ("audio", recording_id) for pre-recorded audio greetings
|
||||
Or None if no greeting is configured.
|
||||
"""
|
||||
start_node = self.workflow.nodes.get(self.workflow.start_node_id)
|
||||
if start_node and start_node.greeting:
|
||||
return self._format_prompt(start_node.greeting)
|
||||
if not start_node:
|
||||
return None
|
||||
|
||||
greeting_type = start_node.greeting_type or "text"
|
||||
|
||||
if greeting_type == "audio" and start_node.greeting_recording_id:
|
||||
return ("audio", start_node.greeting_recording_id)
|
||||
|
||||
if start_node.greeting:
|
||||
return ("text", self._format_prompt(start_node.greeting))
|
||||
|
||||
return None
|
||||
|
||||
async def _handle_end_node(self, node: Node) -> None:
|
||||
|
|
@ -698,6 +753,10 @@ class PipecatEngine:
|
|||
"""Set the audio configuration for the pipeline."""
|
||||
self._audio_config = audio_config
|
||||
|
||||
def set_fetch_recording_audio(self, fetch_fn) -> None:
|
||||
"""Set the recording audio fetcher callback."""
|
||||
self._fetch_recording_audio = fetch_fn
|
||||
|
||||
def set_mute_pipeline(self, mute: bool) -> None:
|
||||
"""Set the pipeline mute state.
|
||||
|
||||
|
|
|
|||
|
|
@ -16,6 +16,7 @@ from loguru import logger
|
|||
|
||||
from api.db import db_client
|
||||
from api.enums import ToolCategory, WorkflowRunMode
|
||||
from api.services.pipecat.recording_playback import queue_recording_audio
|
||||
from api.services.telephony.call_transfer_manager import get_call_transfer_manager
|
||||
from api.services.telephony.factory import get_telephony_provider
|
||||
from api.services.telephony.transfer_event_protocol import TransferContext
|
||||
|
|
@ -77,6 +78,42 @@ class CustomToolManager:
|
|||
self._engine = engine
|
||||
self._organization_id: Optional[int] = None
|
||||
|
||||
async def _play_config_message(
|
||||
self, config: dict, *, append_to_context: bool = False
|
||||
) -> bool:
|
||||
"""Play a message from tool config — text or pre-recorded audio.
|
||||
|
||||
Returns True if a message was queued, False otherwise.
|
||||
"""
|
||||
message_type = config.get("messageType", "none")
|
||||
|
||||
if message_type == "audio":
|
||||
recording_id = config.get("audioRecordingId", "")
|
||||
if recording_id and self._engine._fetch_recording_audio:
|
||||
audio_data = await self._engine._fetch_recording_audio(recording_id)
|
||||
if audio_data:
|
||||
await queue_recording_audio(
|
||||
audio_data,
|
||||
sample_rate=self._engine._audio_config.pipeline_sample_rate
|
||||
if self._engine._audio_config
|
||||
else 16000,
|
||||
queue_frame=self._engine.task.queue_frame,
|
||||
)
|
||||
return True
|
||||
else:
|
||||
logger.warning(f"Failed to fetch recording {recording_id}")
|
||||
return False
|
||||
|
||||
if message_type == "custom":
|
||||
custom_message = config.get("customMessage", "")
|
||||
if custom_message:
|
||||
await self._engine.task.queue_frame(
|
||||
TTSSpeakFrame(custom_message, append_to_context=append_to_context)
|
||||
)
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
async def get_organization_id(self) -> Optional[int]:
|
||||
"""Get and cache the organization ID from workflow run."""
|
||||
if self._organization_id is None:
|
||||
|
|
@ -250,9 +287,29 @@ class CustomToolManager:
|
|||
|
||||
try:
|
||||
# Queue custom message before executing the API call
|
||||
# Queue custom message (text or audio) before executing the API call
|
||||
config = tool.definition.get("config", {}) if tool.definition else {}
|
||||
custom_msg_type = config.get("customMessageType", "text")
|
||||
custom_message = config.get("customMessage", "")
|
||||
if custom_message:
|
||||
if custom_msg_type == "audio":
|
||||
recording_id = config.get("customMessageRecordingId", "")
|
||||
if recording_id and self._engine._fetch_recording_audio:
|
||||
logger.info(
|
||||
f"Playing audio message before HTTP tool: {recording_id}"
|
||||
)
|
||||
self._engine._queued_speech_mute_state = "waiting"
|
||||
audio_data = await self._engine._fetch_recording_audio(
|
||||
recording_id
|
||||
)
|
||||
if audio_data:
|
||||
await queue_recording_audio(
|
||||
audio_data,
|
||||
sample_rate=self._engine._audio_config.pipeline_sample_rate
|
||||
if self._engine._audio_config
|
||||
else 16000,
|
||||
queue_frame=self._engine.task.queue_frame,
|
||||
)
|
||||
elif custom_message:
|
||||
logger.info(
|
||||
f"Playing custom message before HTTP tool: {custom_message}"
|
||||
)
|
||||
|
|
@ -299,8 +356,6 @@ class CustomToolManager:
|
|||
try:
|
||||
# Get the end call configuration
|
||||
config = tool.definition.get("config", {})
|
||||
message_type = config.get("messageType", "none")
|
||||
custom_message = config.get("customMessage", "")
|
||||
|
||||
# Handle end call reason if enabled
|
||||
end_call_reason_enabled = config.get("endCallReason", False)
|
||||
|
|
@ -322,10 +377,8 @@ class CustomToolManager:
|
|||
properties=properties,
|
||||
)
|
||||
|
||||
if message_type == "custom" and custom_message:
|
||||
# Queue the custom message to be spoken
|
||||
logger.info(f"Playing custom goodbye message: {custom_message}")
|
||||
await self._engine.task.queue_frame(TTSSpeakFrame(custom_message))
|
||||
played = await self._play_config_message(config)
|
||||
if played:
|
||||
# End the call after the message (not immediately)
|
||||
await self._engine.end_call_with_reason(
|
||||
EndTaskReason.END_CALL_TOOL_REASON.value,
|
||||
|
|
@ -370,8 +423,6 @@ class CustomToolManager:
|
|||
# Get the transfer call configuration
|
||||
config = tool.definition.get("config", {})
|
||||
destination = config.get("destination", "")
|
||||
message_type = config.get("messageType", "none")
|
||||
custom_message = config.get("customMessage", "")
|
||||
timeout_seconds = config.get(
|
||||
"timeout", 30
|
||||
) # Default 30 seconds if not configured
|
||||
|
|
@ -443,10 +494,9 @@ class CustomToolManager:
|
|||
)
|
||||
return
|
||||
|
||||
if message_type == "custom" and custom_message:
|
||||
logger.info(f"Playing pre-transfer message: {custom_message}")
|
||||
played = await self._play_config_message(config)
|
||||
if played:
|
||||
self._engine._queued_speech_mute_state = "waiting"
|
||||
await self._engine.task.queue_frame(TTSSpeakFrame(custom_message))
|
||||
|
||||
# Get organization ID for provider configuration
|
||||
organization_id = await self.get_organization_id()
|
||||
|
|
|
|||
|
|
@ -77,6 +77,8 @@ class Node:
|
|||
self.extraction_variables = data.extraction_variables
|
||||
self.add_global_prompt = data.add_global_prompt
|
||||
self.greeting = data.greeting
|
||||
self.greeting_type = data.greeting_type
|
||||
self.greeting_recording_id = data.greeting_recording_id
|
||||
self.detect_voicemail = data.detect_voicemail
|
||||
self.delayed_start = data.delayed_start
|
||||
self.delayed_start_duration = data.delayed_start_duration
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue