mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-06-19 08:28:10 +02:00
add call transfer skeleton
This commit is contained in:
parent
e8005042e2
commit
c990af2a16
8 changed files with 450 additions and 25 deletions
|
|
@ -20,6 +20,10 @@ from api.services.workflow.tools.custom_tool import (
|
|||
execute_http_tool,
|
||||
tool_to_function_schema,
|
||||
)
|
||||
from api.services.workflow.transfer_event_protocol import (
|
||||
TransferEventType,
|
||||
wait_for_transfer_signal,
|
||||
)
|
||||
from pipecat.adapters.schemas.function_schema import FunctionSchema
|
||||
from pipecat.frames.frames import FunctionCallResultProperties, TTSSpeakFrame
|
||||
from pipecat.services.llm_service import FunctionCallParams
|
||||
|
|
@ -139,6 +143,9 @@ class CustomToolManager:
|
|||
if tool.category == ToolCategory.END_CALL.value:
|
||||
return self._create_end_call_handler(tool, function_name)
|
||||
|
||||
if tool.category == ToolCategory.TRANSFER_CALL.value:
|
||||
return self._create_transfer_call_handler(tool, function_name)
|
||||
|
||||
return self._create_http_tool_handler(tool, function_name)
|
||||
|
||||
def _create_http_tool_handler(self, tool: Any, function_name: str):
|
||||
|
|
@ -230,3 +237,100 @@ class CustomToolManager:
|
|||
)
|
||||
|
||||
return end_call_handler
|
||||
|
||||
def _create_transfer_call_handler(self, tool: Any, function_name: str):
|
||||
"""Create a handler function for a transfer call tool.
|
||||
|
||||
Args:
|
||||
tool: The ToolModel instance
|
||||
function_name: The function name used by the LLM
|
||||
|
||||
Returns:
|
||||
Async handler function for the transfer call tool
|
||||
"""
|
||||
|
||||
async def transfer_call_handler(
|
||||
function_call_params: FunctionCallParams,
|
||||
) -> None:
|
||||
logger.info(f"Transfer Call Tool EXECUTED: {function_name}")
|
||||
|
||||
try:
|
||||
# Get the transfer call configuration
|
||||
config = tool.definition.get("config", {})
|
||||
transfer_number = config.get("transferNumber", "")
|
||||
transfer_message = config.get("transferMessage", "")
|
||||
|
||||
if not transfer_number:
|
||||
logger.error("Transfer number not configured")
|
||||
await function_call_params.result_callback(
|
||||
{"status": "error", "error": "Transfer number not configured"}
|
||||
)
|
||||
return
|
||||
|
||||
logger.info(f"Initiating transfer to: {transfer_number}")
|
||||
|
||||
# Play transfer message if configured
|
||||
if transfer_message:
|
||||
logger.info(f"Playing transfer message: {transfer_message}")
|
||||
await self._engine.task.queue_frame(TTSSpeakFrame(transfer_message))
|
||||
|
||||
# Store transfer intent in gathered context
|
||||
self._engine._gathered_context["transfer_requested"] = True
|
||||
self._engine._gathered_context["transfer_number"] = transfer_number
|
||||
|
||||
# Wait for external signal to proceed with transfer (30s timeout)
|
||||
workflow_run_id = self._engine._workflow_run_id
|
||||
logger.info(
|
||||
f"Waiting for transfer signal for workflow_run_id: {workflow_run_id}"
|
||||
)
|
||||
|
||||
transfer_event = await wait_for_transfer_signal(
|
||||
workflow_run_id=workflow_run_id,
|
||||
timeout_seconds=30.0,
|
||||
)
|
||||
|
||||
if transfer_event is None:
|
||||
# Timeout - transfer failed
|
||||
logger.warning("Transfer signal timed out")
|
||||
self._engine._gathered_context["transfer_status"] = "timed_out"
|
||||
await function_call_params.result_callback(
|
||||
{"status": "error", "error": "Transfer signal timed out"}
|
||||
)
|
||||
return
|
||||
|
||||
if transfer_event.type == TransferEventType.TRANSFER_CANCEL.value:
|
||||
# Cancelled - transfer failed
|
||||
logger.info("Transfer was cancelled")
|
||||
self._engine._gathered_context["transfer_status"] = "cancelled"
|
||||
await function_call_params.result_callback(
|
||||
{"status": "error", "error": "Transfer was cancelled"}
|
||||
)
|
||||
return
|
||||
|
||||
# Success - proceed with transfer
|
||||
logger.info("Transfer signal received - proceeding with transfer")
|
||||
self._engine._gathered_context["transfer_status"] = "success"
|
||||
|
||||
# Lets send result callback so that timeout task is cancelled. Lets not
|
||||
# run llm
|
||||
await function_call_params.result_callback(
|
||||
{"status": "error", "error": "Transfer was cancelled"},
|
||||
properties=FunctionCallResultProperties(run_llm=False),
|
||||
)
|
||||
|
||||
# Terminate the call after the call is added to the conference
|
||||
await self._engine.end_call_with_reason(
|
||||
EndTaskReason.CALL_TRANSFERRED.value,
|
||||
abort_immediately=True,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Transfer call tool '{function_name}' execution failed: {e}"
|
||||
)
|
||||
await function_call_params.result_callback(
|
||||
{"status": "error", "error": str(e)},
|
||||
properties=properties,
|
||||
)
|
||||
|
||||
return transfer_call_handler
|
||||
|
|
|
|||
127
api/services/workflow/transfer_event_protocol.py
Normal file
127
api/services/workflow/transfer_event_protocol.py
Normal file
|
|
@ -0,0 +1,127 @@
|
|||
"""Transfer call event protocol for Redis-based coordination.
|
||||
|
||||
Simple protocol for awaiting transfer completion signal from external trigger.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from dataclasses import asdict, dataclass
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
import redis.asyncio as aioredis
|
||||
from loguru import logger
|
||||
|
||||
from api.constants import REDIS_URL
|
||||
|
||||
|
||||
class TransferEventType(str, Enum):
|
||||
"""Types of transfer events."""
|
||||
|
||||
TRANSFER_PROCEED = "transfer_proceed"
|
||||
TRANSFER_CANCEL = "transfer_cancel"
|
||||
|
||||
|
||||
@dataclass
|
||||
class TransferEvent:
|
||||
"""Event sent to signal transfer status."""
|
||||
|
||||
type: str
|
||||
workflow_run_id: int
|
||||
message: Optional[str] = None
|
||||
|
||||
def to_json(self) -> str:
|
||||
return json.dumps(asdict(self))
|
||||
|
||||
@classmethod
|
||||
def from_json(cls, data: str) -> "TransferEvent":
|
||||
return cls(**json.loads(data))
|
||||
|
||||
|
||||
class TransferRedisChannels:
|
||||
"""Redis channel naming for transfer events."""
|
||||
|
||||
@staticmethod
|
||||
def transfer_await(workflow_run_id: int) -> str:
|
||||
"""Channel for awaiting transfer completion."""
|
||||
return f"transfer:await:{workflow_run_id}"
|
||||
|
||||
|
||||
async def wait_for_transfer_signal(
|
||||
workflow_run_id: int,
|
||||
timeout_seconds: float = 30.0,
|
||||
) -> Optional[TransferEvent]:
|
||||
"""Wait for a transfer signal on Redis pub/sub.
|
||||
|
||||
Args:
|
||||
workflow_run_id: The workflow run ID to wait for
|
||||
timeout_seconds: How long to wait before timing out
|
||||
|
||||
Returns:
|
||||
TransferEvent if received, None if timed out
|
||||
"""
|
||||
channel = TransferRedisChannels.transfer_await(workflow_run_id)
|
||||
redis_client = await aioredis.from_url(REDIS_URL, decode_responses=True)
|
||||
pubsub = redis_client.pubsub()
|
||||
|
||||
try:
|
||||
await pubsub.subscribe(channel)
|
||||
logger.info(f"Waiting for transfer signal on channel: {channel}")
|
||||
|
||||
async def listen_for_event() -> Optional[TransferEvent]:
|
||||
async for message in pubsub.listen():
|
||||
if message["type"] == "message":
|
||||
event = TransferEvent.from_json(message["data"])
|
||||
logger.info(f"Received transfer event: {event.type}")
|
||||
return event
|
||||
# pubsub.listen() ended (connection closed) - return None
|
||||
return None
|
||||
|
||||
# Wait with timeout
|
||||
event = await asyncio.wait_for(listen_for_event(), timeout=timeout_seconds)
|
||||
return event
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning(f"Transfer signal timed out after {timeout_seconds}s")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Error waiting for transfer signal: {e}")
|
||||
return None
|
||||
finally:
|
||||
await pubsub.unsubscribe(channel)
|
||||
await pubsub.aclose()
|
||||
await redis_client.aclose()
|
||||
|
||||
|
||||
async def send_transfer_signal(
|
||||
workflow_run_id: int,
|
||||
event_type: TransferEventType = TransferEventType.TRANSFER_PROCEED,
|
||||
message: Optional[str] = None,
|
||||
) -> bool:
|
||||
"""Send a transfer signal to unblock a waiting handler.
|
||||
|
||||
Args:
|
||||
workflow_run_id: The workflow run ID to signal
|
||||
event_type: Type of signal (proceed or cancel)
|
||||
message: Optional message
|
||||
|
||||
Returns:
|
||||
True if signal was sent successfully
|
||||
"""
|
||||
channel = TransferRedisChannels.transfer_await(workflow_run_id)
|
||||
redis_client = await aioredis.from_url(REDIS_URL, decode_responses=True)
|
||||
|
||||
try:
|
||||
event = TransferEvent(
|
||||
type=event_type.value,
|
||||
workflow_run_id=workflow_run_id,
|
||||
message=message,
|
||||
)
|
||||
await redis_client.publish(channel, event.to_json())
|
||||
logger.info(f"Sent transfer signal to channel: {channel}")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Error sending transfer signal: {e}")
|
||||
return False
|
||||
finally:
|
||||
await redis_client.aclose()
|
||||
Loading…
Add table
Add a link
Reference in a new issue