mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-06-13 08:15:21 +02:00
add call transfer skeleton
This commit is contained in:
parent
e8005042e2
commit
c990af2a16
8 changed files with 450 additions and 25 deletions
|
|
@ -122,7 +122,8 @@ class ToolCategory(Enum):
|
|||
|
||||
HTTP_API = "http_api" # Custom HTTP API calls (implemented)
|
||||
END_CALL = "end_call" # End call tool
|
||||
NATIVE = "native" # Built-in integrations (future: call_transfer, dtmf_input)
|
||||
TRANSFER_CALL = "transfer_call" # Transfer call to another number
|
||||
NATIVE = "native" # Built-in integrations (future: dtmf_input)
|
||||
INTEGRATION = "integration" # Third-party integrations (future: Google Calendar, Salesforce, etc.)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -30,6 +30,10 @@ from api.services.telephony.factory import (
|
|||
get_all_telephony_providers,
|
||||
get_telephony_provider,
|
||||
)
|
||||
from api.services.workflow.transfer_event_protocol import (
|
||||
TransferEventType,
|
||||
send_transfer_signal,
|
||||
)
|
||||
from api.utils.common import get_backend_endpoints
|
||||
from api.utils.telephony_helper import (
|
||||
generic_hangup_response,
|
||||
|
|
@ -1480,3 +1484,52 @@ async def handle_cloudonix_cdr(request: Request):
|
|||
)
|
||||
|
||||
return {"status": "success"}
|
||||
|
||||
|
||||
class TransferSignalRequest(BaseModel):
|
||||
"""Request to send a transfer signal."""
|
||||
|
||||
action: str = "proceed" # "proceed" or "cancel"
|
||||
message: Optional[str] = None
|
||||
|
||||
|
||||
@router.post("/transfer-signal/{workflow_run_id}")
|
||||
async def send_transfer_signal_endpoint(
|
||||
workflow_run_id: int,
|
||||
request: TransferSignalRequest,
|
||||
):
|
||||
"""Send a transfer signal to unblock a waiting transfer call handler.
|
||||
|
||||
This is a POC endpoint to test the transfer call flow.
|
||||
Call this endpoint to signal that the transfer is ready to proceed.
|
||||
|
||||
Args:
|
||||
workflow_run_id: The workflow run ID waiting for the signal
|
||||
request: The signal action (proceed or cancel) and optional message
|
||||
"""
|
||||
logger.info(
|
||||
f"[run {workflow_run_id}] Received transfer signal request: action={request.action}"
|
||||
)
|
||||
|
||||
event_type = (
|
||||
TransferEventType.TRANSFER_PROCEED
|
||||
if request.action == "proceed"
|
||||
else TransferEventType.TRANSFER_CANCEL
|
||||
)
|
||||
|
||||
success = await send_transfer_signal(
|
||||
workflow_run_id=workflow_run_id,
|
||||
event_type=event_type,
|
||||
message=request.message,
|
||||
)
|
||||
|
||||
if success:
|
||||
return {
|
||||
"status": "success",
|
||||
"message": f"Transfer signal sent: {request.action}",
|
||||
}
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Failed to send transfer signal",
|
||||
)
|
||||
|
|
|
|||
|
|
@ -20,6 +20,10 @@ from api.services.workflow.tools.custom_tool import (
|
|||
execute_http_tool,
|
||||
tool_to_function_schema,
|
||||
)
|
||||
from api.services.workflow.transfer_event_protocol import (
|
||||
TransferEventType,
|
||||
wait_for_transfer_signal,
|
||||
)
|
||||
from pipecat.adapters.schemas.function_schema import FunctionSchema
|
||||
from pipecat.frames.frames import FunctionCallResultProperties, TTSSpeakFrame
|
||||
from pipecat.services.llm_service import FunctionCallParams
|
||||
|
|
@ -139,6 +143,9 @@ class CustomToolManager:
|
|||
if tool.category == ToolCategory.END_CALL.value:
|
||||
return self._create_end_call_handler(tool, function_name)
|
||||
|
||||
if tool.category == ToolCategory.TRANSFER_CALL.value:
|
||||
return self._create_transfer_call_handler(tool, function_name)
|
||||
|
||||
return self._create_http_tool_handler(tool, function_name)
|
||||
|
||||
def _create_http_tool_handler(self, tool: Any, function_name: str):
|
||||
|
|
@ -230,3 +237,100 @@ class CustomToolManager:
|
|||
)
|
||||
|
||||
return end_call_handler
|
||||
|
||||
def _create_transfer_call_handler(self, tool: Any, function_name: str):
|
||||
"""Create a handler function for a transfer call tool.
|
||||
|
||||
Args:
|
||||
tool: The ToolModel instance
|
||||
function_name: The function name used by the LLM
|
||||
|
||||
Returns:
|
||||
Async handler function for the transfer call tool
|
||||
"""
|
||||
|
||||
async def transfer_call_handler(
|
||||
function_call_params: FunctionCallParams,
|
||||
) -> None:
|
||||
logger.info(f"Transfer Call Tool EXECUTED: {function_name}")
|
||||
|
||||
try:
|
||||
# Get the transfer call configuration
|
||||
config = tool.definition.get("config", {})
|
||||
transfer_number = config.get("transferNumber", "")
|
||||
transfer_message = config.get("transferMessage", "")
|
||||
|
||||
if not transfer_number:
|
||||
logger.error("Transfer number not configured")
|
||||
await function_call_params.result_callback(
|
||||
{"status": "error", "error": "Transfer number not configured"}
|
||||
)
|
||||
return
|
||||
|
||||
logger.info(f"Initiating transfer to: {transfer_number}")
|
||||
|
||||
# Play transfer message if configured
|
||||
if transfer_message:
|
||||
logger.info(f"Playing transfer message: {transfer_message}")
|
||||
await self._engine.task.queue_frame(TTSSpeakFrame(transfer_message))
|
||||
|
||||
# Store transfer intent in gathered context
|
||||
self._engine._gathered_context["transfer_requested"] = True
|
||||
self._engine._gathered_context["transfer_number"] = transfer_number
|
||||
|
||||
# Wait for external signal to proceed with transfer (30s timeout)
|
||||
workflow_run_id = self._engine._workflow_run_id
|
||||
logger.info(
|
||||
f"Waiting for transfer signal for workflow_run_id: {workflow_run_id}"
|
||||
)
|
||||
|
||||
transfer_event = await wait_for_transfer_signal(
|
||||
workflow_run_id=workflow_run_id,
|
||||
timeout_seconds=30.0,
|
||||
)
|
||||
|
||||
if transfer_event is None:
|
||||
# Timeout - transfer failed
|
||||
logger.warning("Transfer signal timed out")
|
||||
self._engine._gathered_context["transfer_status"] = "timed_out"
|
||||
await function_call_params.result_callback(
|
||||
{"status": "error", "error": "Transfer signal timed out"}
|
||||
)
|
||||
return
|
||||
|
||||
if transfer_event.type == TransferEventType.TRANSFER_CANCEL.value:
|
||||
# Cancelled - transfer failed
|
||||
logger.info("Transfer was cancelled")
|
||||
self._engine._gathered_context["transfer_status"] = "cancelled"
|
||||
await function_call_params.result_callback(
|
||||
{"status": "error", "error": "Transfer was cancelled"}
|
||||
)
|
||||
return
|
||||
|
||||
# Success - proceed with transfer
|
||||
logger.info("Transfer signal received - proceeding with transfer")
|
||||
self._engine._gathered_context["transfer_status"] = "success"
|
||||
|
||||
# Lets send result callback so that timeout task is cancelled. Lets not
|
||||
# run llm
|
||||
await function_call_params.result_callback(
|
||||
{"status": "error", "error": "Transfer was cancelled"},
|
||||
properties=FunctionCallResultProperties(run_llm=False),
|
||||
)
|
||||
|
||||
# Terminate the call after the call is added to the conference
|
||||
await self._engine.end_call_with_reason(
|
||||
EndTaskReason.CALL_TRANSFERRED.value,
|
||||
abort_immediately=True,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Transfer call tool '{function_name}' execution failed: {e}"
|
||||
)
|
||||
await function_call_params.result_callback(
|
||||
{"status": "error", "error": str(e)},
|
||||
properties=properties,
|
||||
)
|
||||
|
||||
return transfer_call_handler
|
||||
|
|
|
|||
127
api/services/workflow/transfer_event_protocol.py
Normal file
127
api/services/workflow/transfer_event_protocol.py
Normal file
|
|
@ -0,0 +1,127 @@
|
|||
"""Transfer call event protocol for Redis-based coordination.
|
||||
|
||||
Simple protocol for awaiting transfer completion signal from external trigger.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from dataclasses import asdict, dataclass
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
import redis.asyncio as aioredis
|
||||
from loguru import logger
|
||||
|
||||
from api.constants import REDIS_URL
|
||||
|
||||
|
||||
class TransferEventType(str, Enum):
|
||||
"""Types of transfer events."""
|
||||
|
||||
TRANSFER_PROCEED = "transfer_proceed"
|
||||
TRANSFER_CANCEL = "transfer_cancel"
|
||||
|
||||
|
||||
@dataclass
|
||||
class TransferEvent:
|
||||
"""Event sent to signal transfer status."""
|
||||
|
||||
type: str
|
||||
workflow_run_id: int
|
||||
message: Optional[str] = None
|
||||
|
||||
def to_json(self) -> str:
|
||||
return json.dumps(asdict(self))
|
||||
|
||||
@classmethod
|
||||
def from_json(cls, data: str) -> "TransferEvent":
|
||||
return cls(**json.loads(data))
|
||||
|
||||
|
||||
class TransferRedisChannels:
|
||||
"""Redis channel naming for transfer events."""
|
||||
|
||||
@staticmethod
|
||||
def transfer_await(workflow_run_id: int) -> str:
|
||||
"""Channel for awaiting transfer completion."""
|
||||
return f"transfer:await:{workflow_run_id}"
|
||||
|
||||
|
||||
async def wait_for_transfer_signal(
|
||||
workflow_run_id: int,
|
||||
timeout_seconds: float = 30.0,
|
||||
) -> Optional[TransferEvent]:
|
||||
"""Wait for a transfer signal on Redis pub/sub.
|
||||
|
||||
Args:
|
||||
workflow_run_id: The workflow run ID to wait for
|
||||
timeout_seconds: How long to wait before timing out
|
||||
|
||||
Returns:
|
||||
TransferEvent if received, None if timed out
|
||||
"""
|
||||
channel = TransferRedisChannels.transfer_await(workflow_run_id)
|
||||
redis_client = await aioredis.from_url(REDIS_URL, decode_responses=True)
|
||||
pubsub = redis_client.pubsub()
|
||||
|
||||
try:
|
||||
await pubsub.subscribe(channel)
|
||||
logger.info(f"Waiting for transfer signal on channel: {channel}")
|
||||
|
||||
async def listen_for_event() -> Optional[TransferEvent]:
|
||||
async for message in pubsub.listen():
|
||||
if message["type"] == "message":
|
||||
event = TransferEvent.from_json(message["data"])
|
||||
logger.info(f"Received transfer event: {event.type}")
|
||||
return event
|
||||
# pubsub.listen() ended (connection closed) - return None
|
||||
return None
|
||||
|
||||
# Wait with timeout
|
||||
event = await asyncio.wait_for(listen_for_event(), timeout=timeout_seconds)
|
||||
return event
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning(f"Transfer signal timed out after {timeout_seconds}s")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Error waiting for transfer signal: {e}")
|
||||
return None
|
||||
finally:
|
||||
await pubsub.unsubscribe(channel)
|
||||
await pubsub.aclose()
|
||||
await redis_client.aclose()
|
||||
|
||||
|
||||
async def send_transfer_signal(
|
||||
workflow_run_id: int,
|
||||
event_type: TransferEventType = TransferEventType.TRANSFER_PROCEED,
|
||||
message: Optional[str] = None,
|
||||
) -> bool:
|
||||
"""Send a transfer signal to unblock a waiting handler.
|
||||
|
||||
Args:
|
||||
workflow_run_id: The workflow run ID to signal
|
||||
event_type: Type of signal (proceed or cancel)
|
||||
message: Optional message
|
||||
|
||||
Returns:
|
||||
True if signal was sent successfully
|
||||
"""
|
||||
channel = TransferRedisChannels.transfer_await(workflow_run_id)
|
||||
redis_client = await aioredis.from_url(REDIS_URL, decode_responses=True)
|
||||
|
||||
try:
|
||||
event = TransferEvent(
|
||||
type=event_type.value,
|
||||
workflow_run_id=workflow_run_id,
|
||||
message=message,
|
||||
)
|
||||
await redis_client.publish(channel, event.to_json())
|
||||
logger.info(f"Sent transfer signal to channel: {channel}")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Error sending transfer signal: {e}")
|
||||
return False
|
||||
finally:
|
||||
await redis_client.aclose()
|
||||
|
|
@ -113,6 +113,7 @@ class MockToolModel:
|
|||
name: str
|
||||
description: str
|
||||
definition: Dict[str, Any]
|
||||
category: str = "http_api"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
|
@ -144,6 +145,25 @@ def mock_user_config():
|
|||
return MockUserConfig()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def transfer_call_tool():
|
||||
"""Create a mock transfer call tool for testing."""
|
||||
return MockToolModel(
|
||||
tool_uuid="transfer-uuid-001",
|
||||
name="Transfer to Support",
|
||||
description="Transfer the call to a support representative",
|
||||
category="transfer_call",
|
||||
definition={
|
||||
"schema_version": 1,
|
||||
"type": "transfer_call",
|
||||
"config": {
|
||||
"transferNumber": "+15551234567",
|
||||
"transferMessage": "Please hold while I transfer you to a support representative.",
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_tools():
|
||||
"""Create sample mock tools for testing."""
|
||||
|
|
|
|||
|
|
@ -5,14 +5,15 @@ using PipecatEngine's actual function registration and execution logic.
|
|||
"""
|
||||
|
||||
import asyncio
|
||||
from typing import Any, Dict, List
|
||||
from typing import Any, Callable, Coroutine, Dict, List, Optional
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from api.services.workflow.pipecat_engine import PipecatEngine
|
||||
from api.services.workflow.transfer_event_protocol import send_transfer_signal
|
||||
from api.services.workflow.workflow import WorkflowGraph
|
||||
from api.tests.conftest import END_CALL_SYSTEM_PROMPT
|
||||
from api.tests.conftest import END_CALL_SYSTEM_PROMPT, MockToolModel
|
||||
from pipecat.frames.frames import LLMContextFrame
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
|
|
@ -32,7 +33,11 @@ async def run_pipeline_with_tool_calls(
|
|||
functions: List[Dict[str, Any]],
|
||||
text: str | None = None,
|
||||
num_text_steps: int = 1,
|
||||
) -> tuple[MockLLMService, LLMContext]:
|
||||
mock_tools: Optional[List[MockToolModel]] = None,
|
||||
on_engine_ready: Optional[
|
||||
Callable[[PipecatEngine], Coroutine[Any, Any, None]]
|
||||
] = None,
|
||||
) -> tuple[MockLLMService, LLMContext, PipecatEngine]:
|
||||
"""Run a pipeline with mock tool calls and return the LLM for assertions.
|
||||
|
||||
Args:
|
||||
|
|
@ -40,9 +45,12 @@ async def run_pipeline_with_tool_calls(
|
|||
functions: List of function call definitions with name, arguments, and tool_call_id.
|
||||
text: Text to add to the first step (streamed before the tool calls).
|
||||
num_text_steps: Number of text response steps after the tool calls.
|
||||
mock_tools: Optional list of mock tools to be returned by db_client.get_tools_by_uuids.
|
||||
on_engine_ready: Optional async callback called after engine is initialized.
|
||||
Useful for sending signals or performing actions during pipeline execution.
|
||||
|
||||
Returns:
|
||||
The MockLLMService instance for making assertions.
|
||||
Tuple of (MockLLMService, LLMContext, PipecatEngine) for making assertions.
|
||||
"""
|
||||
# Create first step chunks
|
||||
if text:
|
||||
|
|
@ -118,25 +126,43 @@ async def run_pipeline_with_tool_calls(
|
|||
return_value=1,
|
||||
):
|
||||
with patch(
|
||||
"api.services.workflow.pipecat_engine.apply_disposition_mapping",
|
||||
"api.services.workflow.pipecat_engine_custom_tools.get_organization_id_from_workflow_run",
|
||||
new_callable=AsyncMock,
|
||||
return_value="completed",
|
||||
return_value=1,
|
||||
):
|
||||
runner = PipelineRunner()
|
||||
with patch(
|
||||
"api.services.workflow.pipecat_engine.apply_disposition_mapping",
|
||||
new_callable=AsyncMock,
|
||||
return_value="completed",
|
||||
):
|
||||
with patch(
|
||||
"api.services.workflow.pipecat_engine_custom_tools.db_client.get_tools_by_uuids",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_tools or [],
|
||||
):
|
||||
runner = PipelineRunner()
|
||||
|
||||
async def run_pipeline():
|
||||
await runner.run(task)
|
||||
async def run_pipeline():
|
||||
await runner.run(task)
|
||||
|
||||
async def initialize_engine():
|
||||
# Small delay to let runner start
|
||||
await asyncio.sleep(0.01)
|
||||
await engine.initialize()
|
||||
await engine.llm.queue_frame(LLMContextFrame(engine.context))
|
||||
async def initialize_engine():
|
||||
# Small delay to let runner start
|
||||
await asyncio.sleep(0.01)
|
||||
await engine.initialize()
|
||||
await engine.llm.queue_frame(LLMContextFrame(engine.context))
|
||||
|
||||
# Run both concurrently
|
||||
await asyncio.gather(run_pipeline(), initialize_engine())
|
||||
async def run_callback():
|
||||
if on_engine_ready:
|
||||
# Wait for engine to process tool calls
|
||||
await asyncio.sleep(0.1)
|
||||
await on_engine_ready(engine)
|
||||
|
||||
return llm, context
|
||||
# Run all concurrently
|
||||
await asyncio.gather(
|
||||
run_pipeline(), initialize_engine(), run_callback()
|
||||
)
|
||||
|
||||
return llm, context, engine
|
||||
|
||||
|
||||
class TestPipecatEngineToolCalls:
|
||||
|
|
@ -172,7 +198,7 @@ class TestPipecatEngineToolCalls:
|
|||
},
|
||||
]
|
||||
|
||||
llm, context = await run_pipeline_with_tool_calls(
|
||||
llm, context, _ = await run_pipeline_with_tool_calls(
|
||||
workflow=simple_workflow,
|
||||
functions=functions,
|
||||
num_text_steps=2,
|
||||
|
|
@ -218,7 +244,7 @@ class TestPipecatEngineToolCalls:
|
|||
},
|
||||
]
|
||||
|
||||
llm, context = await run_pipeline_with_tool_calls(
|
||||
llm, context, _ = await run_pipeline_with_tool_calls(
|
||||
workflow=simple_workflow,
|
||||
functions=functions,
|
||||
num_text_steps=2,
|
||||
|
|
@ -265,7 +291,7 @@ class TestPipecatEngineToolCalls:
|
|||
},
|
||||
]
|
||||
|
||||
llm, context = await run_pipeline_with_tool_calls(
|
||||
llm, context, _ = await run_pipeline_with_tool_calls(
|
||||
workflow=simple_workflow,
|
||||
functions=functions,
|
||||
text="Hello There!",
|
||||
|
|
@ -302,7 +328,7 @@ class TestPipecatEngineToolCalls:
|
|||
},
|
||||
]
|
||||
|
||||
llm, context = await run_pipeline_with_tool_calls(
|
||||
llm, context, _ = await run_pipeline_with_tool_calls(
|
||||
workflow=simple_workflow,
|
||||
functions=functions,
|
||||
num_text_steps=1,
|
||||
|
|
@ -316,3 +342,54 @@ class TestPipecatEngineToolCalls:
|
|||
|
||||
# Assert that the context was updated with END_CALL_SYSTEM_PROMPT
|
||||
assert context.messages[0]["content"] == END_CALL_SYSTEM_PROMPT
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_transfer_call_tool_execution(
|
||||
self, simple_workflow: WorkflowGraph, transfer_call_tool: MockToolModel
|
||||
):
|
||||
"""Test transfer call tool execution through PipecatEngine.
|
||||
|
||||
This test verifies that when the LLM calls the transfer_to_support tool:
|
||||
1. The transfer call handler is invoked
|
||||
2. The handler waits for a transfer signal via Redis pub/sub
|
||||
3. When the signal is sent, the handler proceeds
|
||||
4. The gathered_context is updated with transfer_requested=True
|
||||
5. The gathered_context contains the transfer_number
|
||||
"""
|
||||
# Add the transfer tool to the start node at runtime
|
||||
simple_workflow.nodes["start"].tool_uuids = [transfer_call_tool.tool_uuid]
|
||||
simple_workflow.nodes["start"].extraction_enabled = False
|
||||
|
||||
# The function name is derived from the tool name (snake_case)
|
||||
functions = [
|
||||
{
|
||||
"name": "transfer_to_support",
|
||||
"arguments": {},
|
||||
"tool_call_id": "call_transfer",
|
||||
},
|
||||
]
|
||||
|
||||
# Callback to send transfer signal while handler is waiting
|
||||
async def send_signal(engine: PipecatEngine):
|
||||
# Send the transfer signal to unblock the waiting handler
|
||||
await send_transfer_signal(
|
||||
workflow_run_id=engine._workflow_run_id,
|
||||
)
|
||||
|
||||
_, _, engine = await run_pipeline_with_tool_calls(
|
||||
workflow=simple_workflow,
|
||||
functions=functions,
|
||||
num_text_steps=1,
|
||||
mock_tools=[transfer_call_tool],
|
||||
on_engine_ready=send_signal,
|
||||
)
|
||||
|
||||
# Verify the gathered context was updated with transfer information
|
||||
gathered_context = await engine.get_gathered_context()
|
||||
|
||||
assert gathered_context.get("transfer_requested") is True, (
|
||||
"transfer_requested should be True in gathered_context"
|
||||
)
|
||||
assert gathered_context.get("transfer_number") == "+15551234567", (
|
||||
"transfer_number should match the configured number"
|
||||
)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue