mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-06-16 08:25:18 +02:00
add call transfer skeleton
This commit is contained in:
parent
e8005042e2
commit
c990af2a16
8 changed files with 450 additions and 25 deletions
|
|
@ -113,6 +113,7 @@ class MockToolModel:
|
|||
name: str
|
||||
description: str
|
||||
definition: Dict[str, Any]
|
||||
category: str = "http_api"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
|
@ -144,6 +145,25 @@ def mock_user_config():
|
|||
return MockUserConfig()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def transfer_call_tool():
|
||||
"""Create a mock transfer call tool for testing."""
|
||||
return MockToolModel(
|
||||
tool_uuid="transfer-uuid-001",
|
||||
name="Transfer to Support",
|
||||
description="Transfer the call to a support representative",
|
||||
category="transfer_call",
|
||||
definition={
|
||||
"schema_version": 1,
|
||||
"type": "transfer_call",
|
||||
"config": {
|
||||
"transferNumber": "+15551234567",
|
||||
"transferMessage": "Please hold while I transfer you to a support representative.",
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_tools():
|
||||
"""Create sample mock tools for testing."""
|
||||
|
|
|
|||
|
|
@ -5,14 +5,15 @@ using PipecatEngine's actual function registration and execution logic.
|
|||
"""
|
||||
|
||||
import asyncio
|
||||
from typing import Any, Dict, List
|
||||
from typing import Any, Callable, Coroutine, Dict, List, Optional
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from api.services.workflow.pipecat_engine import PipecatEngine
|
||||
from api.services.workflow.transfer_event_protocol import send_transfer_signal
|
||||
from api.services.workflow.workflow import WorkflowGraph
|
||||
from api.tests.conftest import END_CALL_SYSTEM_PROMPT
|
||||
from api.tests.conftest import END_CALL_SYSTEM_PROMPT, MockToolModel
|
||||
from pipecat.frames.frames import LLMContextFrame
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
|
|
@ -32,7 +33,11 @@ async def run_pipeline_with_tool_calls(
|
|||
functions: List[Dict[str, Any]],
|
||||
text: str | None = None,
|
||||
num_text_steps: int = 1,
|
||||
) -> tuple[MockLLMService, LLMContext]:
|
||||
mock_tools: Optional[List[MockToolModel]] = None,
|
||||
on_engine_ready: Optional[
|
||||
Callable[[PipecatEngine], Coroutine[Any, Any, None]]
|
||||
] = None,
|
||||
) -> tuple[MockLLMService, LLMContext, PipecatEngine]:
|
||||
"""Run a pipeline with mock tool calls and return the LLM for assertions.
|
||||
|
||||
Args:
|
||||
|
|
@ -40,9 +45,12 @@ async def run_pipeline_with_tool_calls(
|
|||
functions: List of function call definitions with name, arguments, and tool_call_id.
|
||||
text: Text to add to the first step (streamed before the tool calls).
|
||||
num_text_steps: Number of text response steps after the tool calls.
|
||||
mock_tools: Optional list of mock tools to be returned by db_client.get_tools_by_uuids.
|
||||
on_engine_ready: Optional async callback called after engine is initialized.
|
||||
Useful for sending signals or performing actions during pipeline execution.
|
||||
|
||||
Returns:
|
||||
The MockLLMService instance for making assertions.
|
||||
Tuple of (MockLLMService, LLMContext, PipecatEngine) for making assertions.
|
||||
"""
|
||||
# Create first step chunks
|
||||
if text:
|
||||
|
|
@ -118,25 +126,43 @@ async def run_pipeline_with_tool_calls(
|
|||
return_value=1,
|
||||
):
|
||||
with patch(
|
||||
"api.services.workflow.pipecat_engine.apply_disposition_mapping",
|
||||
"api.services.workflow.pipecat_engine_custom_tools.get_organization_id_from_workflow_run",
|
||||
new_callable=AsyncMock,
|
||||
return_value="completed",
|
||||
return_value=1,
|
||||
):
|
||||
runner = PipelineRunner()
|
||||
with patch(
|
||||
"api.services.workflow.pipecat_engine.apply_disposition_mapping",
|
||||
new_callable=AsyncMock,
|
||||
return_value="completed",
|
||||
):
|
||||
with patch(
|
||||
"api.services.workflow.pipecat_engine_custom_tools.db_client.get_tools_by_uuids",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_tools or [],
|
||||
):
|
||||
runner = PipelineRunner()
|
||||
|
||||
async def run_pipeline():
|
||||
await runner.run(task)
|
||||
async def run_pipeline():
|
||||
await runner.run(task)
|
||||
|
||||
async def initialize_engine():
|
||||
# Small delay to let runner start
|
||||
await asyncio.sleep(0.01)
|
||||
await engine.initialize()
|
||||
await engine.llm.queue_frame(LLMContextFrame(engine.context))
|
||||
async def initialize_engine():
|
||||
# Small delay to let runner start
|
||||
await asyncio.sleep(0.01)
|
||||
await engine.initialize()
|
||||
await engine.llm.queue_frame(LLMContextFrame(engine.context))
|
||||
|
||||
# Run both concurrently
|
||||
await asyncio.gather(run_pipeline(), initialize_engine())
|
||||
async def run_callback():
|
||||
if on_engine_ready:
|
||||
# Wait for engine to process tool calls
|
||||
await asyncio.sleep(0.1)
|
||||
await on_engine_ready(engine)
|
||||
|
||||
return llm, context
|
||||
# Run all concurrently
|
||||
await asyncio.gather(
|
||||
run_pipeline(), initialize_engine(), run_callback()
|
||||
)
|
||||
|
||||
return llm, context, engine
|
||||
|
||||
|
||||
class TestPipecatEngineToolCalls:
|
||||
|
|
@ -172,7 +198,7 @@ class TestPipecatEngineToolCalls:
|
|||
},
|
||||
]
|
||||
|
||||
llm, context = await run_pipeline_with_tool_calls(
|
||||
llm, context, _ = await run_pipeline_with_tool_calls(
|
||||
workflow=simple_workflow,
|
||||
functions=functions,
|
||||
num_text_steps=2,
|
||||
|
|
@ -218,7 +244,7 @@ class TestPipecatEngineToolCalls:
|
|||
},
|
||||
]
|
||||
|
||||
llm, context = await run_pipeline_with_tool_calls(
|
||||
llm, context, _ = await run_pipeline_with_tool_calls(
|
||||
workflow=simple_workflow,
|
||||
functions=functions,
|
||||
num_text_steps=2,
|
||||
|
|
@ -265,7 +291,7 @@ class TestPipecatEngineToolCalls:
|
|||
},
|
||||
]
|
||||
|
||||
llm, context = await run_pipeline_with_tool_calls(
|
||||
llm, context, _ = await run_pipeline_with_tool_calls(
|
||||
workflow=simple_workflow,
|
||||
functions=functions,
|
||||
text="Hello There!",
|
||||
|
|
@ -302,7 +328,7 @@ class TestPipecatEngineToolCalls:
|
|||
},
|
||||
]
|
||||
|
||||
llm, context = await run_pipeline_with_tool_calls(
|
||||
llm, context, _ = await run_pipeline_with_tool_calls(
|
||||
workflow=simple_workflow,
|
||||
functions=functions,
|
||||
num_text_steps=1,
|
||||
|
|
@ -316,3 +342,54 @@ class TestPipecatEngineToolCalls:
|
|||
|
||||
# Assert that the context was updated with END_CALL_SYSTEM_PROMPT
|
||||
assert context.messages[0]["content"] == END_CALL_SYSTEM_PROMPT
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_transfer_call_tool_execution(
|
||||
self, simple_workflow: WorkflowGraph, transfer_call_tool: MockToolModel
|
||||
):
|
||||
"""Test transfer call tool execution through PipecatEngine.
|
||||
|
||||
This test verifies that when the LLM calls the transfer_to_support tool:
|
||||
1. The transfer call handler is invoked
|
||||
2. The handler waits for a transfer signal via Redis pub/sub
|
||||
3. When the signal is sent, the handler proceeds
|
||||
4. The gathered_context is updated with transfer_requested=True
|
||||
5. The gathered_context contains the transfer_number
|
||||
"""
|
||||
# Add the transfer tool to the start node at runtime
|
||||
simple_workflow.nodes["start"].tool_uuids = [transfer_call_tool.tool_uuid]
|
||||
simple_workflow.nodes["start"].extraction_enabled = False
|
||||
|
||||
# The function name is derived from the tool name (snake_case)
|
||||
functions = [
|
||||
{
|
||||
"name": "transfer_to_support",
|
||||
"arguments": {},
|
||||
"tool_call_id": "call_transfer",
|
||||
},
|
||||
]
|
||||
|
||||
# Callback to send transfer signal while handler is waiting
|
||||
async def send_signal(engine: PipecatEngine):
|
||||
# Send the transfer signal to unblock the waiting handler
|
||||
await send_transfer_signal(
|
||||
workflow_run_id=engine._workflow_run_id,
|
||||
)
|
||||
|
||||
_, _, engine = await run_pipeline_with_tool_calls(
|
||||
workflow=simple_workflow,
|
||||
functions=functions,
|
||||
num_text_steps=1,
|
||||
mock_tools=[transfer_call_tool],
|
||||
on_engine_ready=send_signal,
|
||||
)
|
||||
|
||||
# Verify the gathered context was updated with transfer information
|
||||
gathered_context = await engine.get_gathered_context()
|
||||
|
||||
assert gathered_context.get("transfer_requested") is True, (
|
||||
"transfer_requested should be True in gathered_context"
|
||||
)
|
||||
assert gathered_context.get("transfer_number") == "+15551234567", (
|
||||
"transfer_number should match the configured number"
|
||||
)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue