feat: user defined custom tools as part of workflow execution (#94)

* feat: add custom tools functionality

* Show tools in nodes

* integrate tool calling with pipeline engine
This commit is contained in:
Abhishek 2026-01-02 13:11:02 +05:30 committed by GitHub
parent cc2d3e70d2
commit 3e55af9256
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
65 changed files with 5483 additions and 6673 deletions

View file

@ -1,138 +0,0 @@
import asyncio
import pytest
from pipecat.frames.frames import (
FunctionCallInProgressFrame,
LLMFullResponseEndFrame,
LLMFullResponseStartFrame,
StartInterruptionFrame,
TextFrame,
)
from pipecat.pipeline.pipeline import Pipeline
from pipecat.pipeline.runner import PipelineRunner
from pipecat.pipeline.task import PipelineTask
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
from pipecat.processors.frame_processor import FrameDirection
from pipecat.services.openai.llm import OpenAIAssistantContextAggregator
@pytest.mark.asyncio
async def test_reordering_after_completion():
context = OpenAILLMContext()
aggr = OpenAIAssistantContextAggregator(context)
# Initialize task manager properly using PipelineTask
pipeline = Pipeline([aggr])
task = PipelineTask(pipeline)
runner = PipelineRunner()
# Start the task to properly initialize the frame processor
task_coroutine = asyncio.create_task(runner.run(task))
# Give the task a moment to initialize
await asyncio.sleep(0.01)
# start new LLM response
await aggr.process_frame(LLMFullResponseStartFrame(), FrameDirection.DOWNSTREAM)
# simulate a pending function call
await aggr.process_frame(
FunctionCallInProgressFrame(
function_name="transition",
tool_call_id="1",
arguments={},
cancel_on_interruption=False,
),
FrameDirection.DOWNSTREAM,
)
# now text arrives
await aggr.process_frame(TextFrame("Hi there"), FrameDirection.DOWNSTREAM)
# end response
await aggr.process_frame(LLMFullResponseEndFrame(), FrameDirection.DOWNSTREAM)
msgs = context.get_messages()
# Assert order: assistant text first, then tool_call assistant, then tool response
assert msgs[0]["role"] == "assistant" and "tool_calls" not in msgs[0]
# Fix: content is a string, not a structured object
assert msgs[0]["content"] == "Hi there"
assert any(m.get("role") == "assistant" and m.get("tool_calls") for m in msgs[1:])
assert any(m.get("role") == "tool" for m in msgs[1:])
# Clean up the running task
await task.cancel()
task_coroutine.cancel()
try:
await task_coroutine
except asyncio.CancelledError:
pass
@pytest.mark.asyncio
async def test_interruption_removes_pending_function_calls_and_marks():
context = OpenAILLMContext()
aggr = OpenAIAssistantContextAggregator(context)
# Initialize task manager properly using PipelineTask
pipeline = Pipeline([aggr])
task = PipelineTask(pipeline)
runner = PipelineRunner()
# Start the task to properly initialize the frame processor
task_coroutine = asyncio.create_task(runner.run(task))
# Give the task a moment to initialize
await asyncio.sleep(0.01)
await aggr.process_frame(LLMFullResponseStartFrame(), FrameDirection.DOWNSTREAM)
await aggr.process_frame(
FunctionCallInProgressFrame(
function_name="transition",
tool_call_id="1",
arguments={},
cancel_on_interruption=False,
),
FrameDirection.DOWNSTREAM,
)
# Debug: Check the state before interruption
print(
f"Function calls in progress before interruption: {aggr._function_calls_in_progress}"
)
print(f"Messages before interruption: {context.get_messages()}")
# no text yet - still aggregation
await aggr.process_frame(StartInterruptionFrame(), FrameDirection.DOWNSTREAM)
msgs = context.get_messages()
# Debug: Print messages to understand what's happening
print(f"Messages after interruption: {msgs}")
print(
f"Function calls in progress after interruption: {aggr._function_calls_in_progress}"
)
# After interruption before any response is complete, context should be cleared
# This is the actual behavior - interruptions clear pending function calls
if len(msgs) == 0:
# Context was cleared due to interruption before completion
assert True
else:
# If there are messages, ensure no tool calls remain
assert not any(m.get("tool_calls") for m in msgs)
assert not any(m.get("role") == "tool" for m in msgs)
# Check if interruption marker is present
if msgs:
assert msgs[-1]["role"] == "assistant"
assert "<<interrupted_by_user>>" in msgs[-1]["content"]
# Clean up the running task
await task.cancel()
task_coroutine.cancel()
try:
await task_coroutine
except asyncio.CancelledError:
pass

View file

@ -1,120 +0,0 @@
import os
import wave
import pytest
from api.services.pipecat.audio_transcript_buffers import (
InMemoryAudioBuffer,
InMemoryTranscriptBuffer,
)
@pytest.mark.asyncio
async def test_audio_buffer_append_and_write():
"""Test that audio buffer can append data and write to temp file."""
buffer = InMemoryAudioBuffer(workflow_run_id=123, sample_rate=16000, num_channels=1)
# Create some test PCM data
test_pcm = b"\x00\x01" * 1000 # 2000 bytes
# Append data
await buffer.append(test_pcm)
await buffer.append(test_pcm)
assert buffer.size == 4000
assert not buffer.is_empty
# Write to temp file
temp_path = await buffer.write_to_temp_file()
try:
# Verify file exists and is valid WAV
assert os.path.exists(temp_path)
with wave.open(temp_path, "rb") as wf:
assert wf.getnchannels() == 1
assert wf.getsampwidth() == 2
assert wf.getframerate() == 16000
# Each frame is 2 bytes (16-bit), so 4000 bytes = 2000 frames
assert wf.getnframes() == 2000
finally:
# Clean up
if os.path.exists(temp_path):
os.remove(temp_path)
@pytest.mark.asyncio
async def test_audio_buffer_memory_limit():
"""Test that audio buffer enforces memory limit."""
buffer = InMemoryAudioBuffer(workflow_run_id=123, sample_rate=16000)
# Set a smaller limit for testing
buffer._max_size = 1000
# This should work
await buffer.append(b"\x00" * 500)
# This should fail
with pytest.raises(MemoryError):
await buffer.append(b"\x00" * 600)
@pytest.mark.asyncio
async def test_transcript_buffer_append_and_write():
"""Test that transcript buffer can append data and write to temp file."""
buffer = InMemoryTranscriptBuffer(workflow_run_id=456)
# Append some transcript lines
await buffer.append("[00:00:01] user: Hello\n")
await buffer.append("[00:00:02] assistant: Hi there!\n")
await buffer.append("[00:00:03] user: How are you?\n")
assert not buffer.is_empty
# Write to temp file
temp_path = await buffer.write_to_temp_file()
try:
# Verify file exists and has correct content
assert os.path.exists(temp_path)
with open(temp_path, "r") as f:
content = f.read()
assert "[00:00:01] user: Hello\n" in content
assert "[00:00:02] assistant: Hi there!\n" in content
assert "[00:00:03] user: How are you?\n" in content
finally:
# Clean up
if os.path.exists(temp_path):
os.remove(temp_path)
@pytest.mark.asyncio
async def test_empty_buffers():
"""Test that empty buffers are handled correctly."""
audio_buffer = InMemoryAudioBuffer(workflow_run_id=789, sample_rate=16000)
transcript_buffer = InMemoryTranscriptBuffer(workflow_run_id=789)
assert audio_buffer.is_empty
assert transcript_buffer.is_empty
# Should still be able to write empty files
audio_path = await audio_buffer.write_to_temp_file()
transcript_path = await transcript_buffer.write_to_temp_file()
try:
assert os.path.exists(audio_path)
assert os.path.exists(transcript_path)
# Empty WAV file should still have valid header
with wave.open(audio_path, "rb") as wf:
assert wf.getnframes() == 0
# Empty transcript file
with open(transcript_path, "r") as f:
assert f.read() == ""
finally:
if os.path.exists(audio_path):
os.remove(audio_path)
if os.path.exists(transcript_path):
os.remove(transcript_path)

View file

@ -1,330 +0,0 @@
"""Tests for concurrent call limiting functionality."""
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from api.enums import OrganizationConfigurationKey
from api.services.campaign.rate_limiter import RateLimiter
class TestConcurrentCallLimiting:
"""Test suite for concurrent call limiting."""
@pytest.mark.asyncio
async def test_acquire_concurrent_slot_success(self):
"""Test successful acquisition of concurrent slot."""
rate_limiter = RateLimiter()
# Mock Redis client
with patch.object(rate_limiter, "_get_redis") as mock_redis:
mock_client = AsyncMock()
mock_client.eval = AsyncMock(return_value="test_slot_123")
mock_redis.return_value = mock_client
# Try to acquire slot
slot_id = await rate_limiter.try_acquire_concurrent_slot(
organization_id=1, max_concurrent=20
)
assert slot_id == "test_slot_123"
mock_client.eval.assert_called_once()
@pytest.mark.asyncio
async def test_acquire_concurrent_slot_limit_reached(self):
"""Test slot acquisition when limit is reached."""
rate_limiter = RateLimiter()
# Mock Redis client
with patch.object(rate_limiter, "_get_redis") as mock_redis:
mock_client = AsyncMock()
mock_client.eval = AsyncMock(return_value=None) # Limit reached
mock_redis.return_value = mock_client
# Try to acquire slot
slot_id = await rate_limiter.try_acquire_concurrent_slot(
organization_id=1, max_concurrent=20
)
assert slot_id is None
mock_client.eval.assert_called_once()
@pytest.mark.asyncio
async def test_release_concurrent_slot(self):
"""Test releasing a concurrent slot."""
rate_limiter = RateLimiter()
# Mock Redis client
with patch.object(rate_limiter, "_get_redis") as mock_redis:
mock_client = AsyncMock()
mock_client.zrem = AsyncMock(return_value=1) # Successfully removed
mock_redis.return_value = mock_client
# Release slot
success = await rate_limiter.release_concurrent_slot(
organization_id=1, slot_id="test_slot_123"
)
assert success is True
mock_client.zrem.assert_called_once_with(
"concurrent_calls:1", "test_slot_123"
)
@pytest.mark.asyncio
async def test_get_concurrent_count(self):
"""Test getting current concurrent call count."""
rate_limiter = RateLimiter()
# Mock Redis client
with patch.object(rate_limiter, "_get_redis") as mock_redis:
mock_client = AsyncMock()
mock_client.zremrangebyscore = AsyncMock() # Cleanup stale entries
mock_client.zcard = AsyncMock(return_value=5) # 5 active calls
mock_redis.return_value = mock_client
# Get count
count = await rate_limiter.get_concurrent_count(organization_id=1)
assert count == 5
mock_client.zremrangebyscore.assert_called_once()
mock_client.zcard.assert_called_once()
@pytest.mark.asyncio
async def test_stale_entry_cleanup(self):
"""Test that stale entries are cleaned up automatically."""
rate_limiter = RateLimiter()
# Mock Redis client
with patch.object(rate_limiter, "_get_redis") as mock_redis:
mock_client = AsyncMock()
# Mock eval to simulate cleanup in Lua script
mock_client.eval = AsyncMock(return_value="new_slot_123")
mock_redis.return_value = mock_client
# Try to acquire slot (which should trigger cleanup)
slot_id = await rate_limiter.try_acquire_concurrent_slot(
organization_id=1, max_concurrent=20
)
assert slot_id == "new_slot_123"
# Verify Lua script was called with proper stale cutoff
call_args = mock_client.eval.call_args[0]
lua_script = call_args[0]
assert "ZREMRANGEBYSCORE" in lua_script # Cleanup command in script
@pytest.mark.asyncio
async def test_workflow_slot_mapping_operations(self):
"""Test storing, retrieving, and deleting workflow slot mappings."""
rate_limiter = RateLimiter()
# Mock Redis client
with patch.object(rate_limiter, "_get_redis") as mock_redis:
mock_client = AsyncMock()
mock_client.hset = AsyncMock(return_value=1)
mock_client.expire = AsyncMock(return_value=True)
mock_client.hgetall = AsyncMock(
return_value={"org_id": "1", "slot_id": "test_slot_123"}
)
mock_client.delete = AsyncMock(return_value=1)
mock_redis.return_value = mock_client
# Test storing mapping
success = await rate_limiter.store_workflow_slot_mapping(
workflow_run_id=999, organization_id=1, slot_id="test_slot_123"
)
assert success is True
mock_client.hset.assert_called_once()
mock_client.expire.assert_called_once()
# Test retrieving mapping
mapping = await rate_limiter.get_workflow_slot_mapping(workflow_run_id=999)
assert mapping == (1, "test_slot_123")
mock_client.hgetall.assert_called_once_with("workflow_slot_mapping:999")
# Test deleting mapping
deleted = await rate_limiter.delete_workflow_slot_mapping(
workflow_run_id=999
)
assert deleted is True
mock_client.delete.assert_called_once_with("workflow_slot_mapping:999")
class TestCampaignCallDispatcher:
"""Test suite for CampaignCallDispatcher with concurrent limiting."""
@pytest.mark.asyncio
async def test_dispatch_call_waits_for_slot(self):
"""Test that dispatch_call waits for available slot."""
from api.services.campaign.call_dispatcher import CampaignCallDispatcher
dispatcher = CampaignCallDispatcher()
# Mock dependencies
mock_campaign = MagicMock(
organization_id=1, workflow_id=123, id=456, created_by=789
)
mock_queued_run = MagicMock(
id=111, context_variables={"phone_number": "+1234567890"}
)
# Mock rate limiter to simulate waiting
slot_acquired = False
call_count = 0
async def mock_try_acquire(org_id, max_concurrent):
nonlocal slot_acquired, call_count
call_count += 1
if call_count > 2: # Succeed on third try
slot_acquired = True
return "test_slot_123"
return None
with patch(
"api.services.campaign.call_dispatcher.rate_limiter"
) as mock_limiter:
mock_limiter.try_acquire_concurrent_slot = AsyncMock(
side_effect=mock_try_acquire
)
mock_limiter.release_concurrent_slot = AsyncMock()
mock_limiter.store_workflow_slot_mapping = AsyncMock(return_value=True)
with patch("api.services.campaign.call_dispatcher.db_client") as mock_db:
mock_db.get_configuration = AsyncMock(return_value=None)
mock_db.get_workflow_by_id = AsyncMock(
return_value=MagicMock(template_context_variables={})
)
mock_db.create_workflow_run = AsyncMock(
return_value=MagicMock(id=999, logs={})
)
with patch.object(
dispatcher.twilio_service, "initiate_call"
) as mock_twilio:
mock_twilio.return_value = {"sid": "test_sid"}
# Dispatch call (should wait and retry)
workflow_run = await dispatcher.dispatch_call(
mock_queued_run, mock_campaign
)
assert workflow_run is not None
assert slot_acquired is True
assert call_count == 3 # Tried 3 times
assert mock_limiter.try_acquire_concurrent_slot.call_count == 3
@pytest.mark.asyncio
async def test_dispatch_call_stores_slot_mapping(self):
"""Test that dispatch_call stores slot mapping in Redis."""
from api.services.campaign.call_dispatcher import CampaignCallDispatcher
dispatcher = CampaignCallDispatcher()
# Mock dependencies
mock_campaign = MagicMock(
organization_id=1, workflow_id=123, id=456, created_by=789
)
mock_queued_run = MagicMock(
id=111, context_variables={"phone_number": "+1234567890"}
)
with patch(
"api.services.campaign.call_dispatcher.rate_limiter"
) as mock_limiter:
mock_limiter.try_acquire_concurrent_slot = AsyncMock(
return_value="test_slot_123"
)
mock_limiter.store_workflow_slot_mapping = AsyncMock(return_value=True)
with patch("api.services.campaign.call_dispatcher.db_client") as mock_db:
mock_db.get_configuration = AsyncMock(return_value=None)
mock_db.get_workflow_by_id = AsyncMock(
return_value=MagicMock(template_context_variables={})
)
mock_db.create_workflow_run = AsyncMock(
return_value=MagicMock(id=999, logs={})
)
with patch.object(
dispatcher.twilio_service, "initiate_call"
) as mock_twilio:
mock_twilio.return_value = {"sid": "test_sid"}
# Dispatch call
workflow_run = await dispatcher.dispatch_call(
mock_queued_run, mock_campaign
)
# Verify slot mapping was stored
mock_limiter.store_workflow_slot_mapping.assert_called_once_with(
999, 1, "test_slot_123"
)
@pytest.mark.asyncio
async def test_org_specific_concurrent_limit(self):
"""Test that organization-specific concurrent limit is used."""
from api.services.campaign.call_dispatcher import CampaignCallDispatcher
dispatcher = CampaignCallDispatcher()
# Mock db_client to return org-specific limit
with patch("api.services.campaign.call_dispatcher.db_client") as mock_db:
mock_config = MagicMock(value={"value": 10}) # Org limit is 10
mock_db.get_configuration = AsyncMock(return_value=mock_config)
# Get org limit
limit = await dispatcher.get_org_concurrent_limit(organization_id=1)
assert limit == 10 # Should use org-specific limit
mock_db.get_configuration.assert_called_once_with(
1, OrganizationConfigurationKey.CONCURRENT_CALL_LIMIT.value
)
@pytest.mark.asyncio
async def test_default_concurrent_limit(self):
"""Test that default limit is used when org config not found."""
from api.services.campaign.call_dispatcher import CampaignCallDispatcher
dispatcher = CampaignCallDispatcher()
# Mock db_client to return None (no config)
with patch("api.services.campaign.call_dispatcher.db_client") as mock_db:
mock_db.get_configuration = AsyncMock(return_value=None)
# Get org limit
limit = await dispatcher.get_org_concurrent_limit(organization_id=1)
assert limit == 20 # Should use default limit
@pytest.mark.asyncio
async def test_release_call_slot(self):
"""Test releasing call slot when workflow completes."""
from api.services.campaign.call_dispatcher import CampaignCallDispatcher
dispatcher = CampaignCallDispatcher()
# Mock rate limiter
with patch(
"api.services.campaign.call_dispatcher.rate_limiter"
) as mock_limiter:
# Mock getting the slot mapping from Redis
mock_limiter.get_workflow_slot_mapping = AsyncMock(
return_value=(1, "test_slot_123")
)
mock_limiter.release_concurrent_slot = AsyncMock(return_value=True)
mock_limiter.delete_workflow_slot_mapping = AsyncMock(return_value=True)
# Release slot
success = await dispatcher.release_call_slot(workflow_run_id=999)
assert success is True
mock_limiter.get_workflow_slot_mapping.assert_called_once_with(999)
mock_limiter.release_concurrent_slot.assert_called_once_with(
1, "test_slot_123"
)
mock_limiter.delete_workflow_slot_mapping.assert_called_once_with(999)
if __name__ == "__main__":
pytest.main([__file__, "-v"])

View file

@ -1,78 +0,0 @@
import pytest
from pydantic import ValidationError
from api.schemas.user_configuration import UserConfiguration
from api.services.configuration.masking import is_mask_of, mask_key, mask_user_config
from api.services.configuration.merge import merge_user_configurations
from api.services.configuration.registry import (
OpenAILLMService,
)
REAL_KEY = "sk-1234567890abcdef"
def _build_config_with_openai(key: str) -> UserConfiguration:
return UserConfiguration(
llm=OpenAILLMService(api_key=key),
stt=None,
tts=None,
)
def test_mask_key_basic():
masked = mask_key(REAL_KEY)
# Should reveal only last 4 chars
assert masked.endswith(REAL_KEY[-4:])
assert set(masked[:-4]) == {"*"}
assert len(masked) == len(REAL_KEY)
# is_mask_of round-trip
assert is_mask_of(masked, REAL_KEY)
def test_mask_user_config_masks_api_keys():
cfg = _build_config_with_openai(REAL_KEY)
dumped = mask_user_config(cfg)
assert dumped["llm"]["api_key"].endswith(REAL_KEY[-4:])
assert dumped["llm"]["api_key"].startswith("*" * (len(REAL_KEY) - 4))
def test_merge_preserves_key_when_mask_sent():
existing = _build_config_with_openai(REAL_KEY)
incoming_partial = {
"llm": {
"provider": "openai",
"model": existing.llm.model,
"api_key": mask_key(REAL_KEY), # masked placeholder
}
}
merged = merge_user_configurations(existing, incoming_partial)
assert merged.llm.api_key == REAL_KEY # key preserved
def test_merge_replaces_key_when_new_key_provided():
existing = _build_config_with_openai(REAL_KEY)
new_key = "sk-replaced-9999"
incoming_partial = {
"llm": {
"provider": "openai",
"model": existing.llm.model,
"api_key": new_key,
}
}
merged = merge_user_configurations(existing, incoming_partial)
assert merged.llm.api_key == new_key
def test_merge_drops_old_key_when_provider_changes():
existing = _build_config_with_openai(REAL_KEY)
incoming_partial = {
"llm": {
"provider": "groq",
"model": "llama-3.3-70b-versatile",
# api_key intentionally absent should NOT inherit old key
}
}
with pytest.raises(ValidationError):
merge_user_configurations(existing, incoming_partial)

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,512 @@
"""Integration tests for CustomToolManager with update_llm_context.
This module tests the full flow of:
1. CustomToolManager fetching and converting tool schemas
2. update_llm_context setting those tools on the LLM context
3. Verifying the context is properly configured for LLM generation
"""
from dataclasses import dataclass
from typing import Any, Dict
from unittest.mock import AsyncMock, Mock, patch
import pytest
from api.services.workflow.pipecat_engine_custom_tools import CustomToolManager
from api.services.workflow.pipecat_engine_utils import (
get_function_schema,
update_llm_context,
)
from pipecat.adapters.schemas.function_schema import FunctionSchema
from pipecat.processors.aggregators.llm_context import LLMContext
@dataclass
class MockToolModel:
"""Mock tool model for testing."""
tool_uuid: str
name: str
description: str
definition: Dict[str, Any]
class TestCustomToolManagerContextIntegration:
"""Integration tests for CustomToolManager with LLMContext."""
@pytest.fixture
def mock_engine(self):
"""Create a mock PipecatEngine."""
engine = Mock()
engine._workflow_run_id = 1
engine._call_context_vars = {"customer_name": "John Doe"}
engine.llm = Mock()
engine.llm.register_function = Mock()
return engine
@pytest.fixture
def sample_tools(self):
"""Create sample mock tools for testing."""
return [
MockToolModel(
tool_uuid="weather-uuid-123",
name="Get Weather",
description="Get current weather for a location",
definition={
"schema_version": 1,
"type": "http_api",
"config": {
"method": "GET",
"url": "https://api.weather.com/current",
"parameters": [
{
"name": "location",
"type": "string",
"description": "City name (e.g., San Francisco, CA)",
"required": True,
},
{
"name": "units",
"type": "string",
"description": "Temperature units: celsius or fahrenheit",
"required": False,
},
],
},
},
),
MockToolModel(
tool_uuid="booking-uuid-456",
name="Book Appointment",
description="Book an appointment for the customer",
definition={
"schema_version": 1,
"type": "http_api",
"config": {
"method": "POST",
"url": "https://api.example.com/appointments",
"parameters": [
{
"name": "customer_name",
"type": "string",
"description": "Customer's full name",
"required": True,
},
{
"name": "date",
"type": "string",
"description": "Appointment date (YYYY-MM-DD)",
"required": True,
},
{
"name": "time",
"type": "string",
"description": "Appointment time (HH:MM)",
"required": True,
},
{
"name": "notes",
"type": "string",
"description": "Additional notes",
"required": False,
},
],
},
},
),
MockToolModel(
tool_uuid="lookup-uuid-789",
name="Customer Lookup",
description="Look up customer information by phone number",
definition={
"schema_version": 1,
"type": "http_api",
"config": {
"method": "GET",
"url": "https://api.example.com/customers/lookup",
"parameters": [
{
"name": "phone",
"type": "string",
"description": "Customer phone number",
"required": True,
},
],
},
},
),
]
@pytest.mark.asyncio
async def test_get_tool_schemas_and_update_context(self, mock_engine, sample_tools):
"""Test fetching tool schemas via CustomToolManager and updating LLM context."""
manager = CustomToolManager(mock_engine)
with patch(
"api.services.workflow.pipecat_engine_custom_tools.get_organization_id_from_workflow_run"
) as mock_get_org:
mock_get_org.return_value = 1
with patch(
"api.services.workflow.pipecat_engine_custom_tools.db_client"
) as mock_db:
mock_db.get_tools_by_uuids = AsyncMock(return_value=sample_tools)
# Get tool schemas via CustomToolManager - now returns FunctionSchema objects
tool_uuids = ["weather-uuid-123", "booking-uuid-456", "lookup-uuid-789"]
schemas = await manager.get_tool_schemas(tool_uuids)
# Verify schemas were returned as FunctionSchema objects
assert len(schemas) == 3
assert all(isinstance(s, FunctionSchema) for s in schemas)
# Create context with conversation history
context = LLMContext()
context.set_messages(
[
{"role": "system", "content": "You are a helpful assistant."},
{
"role": "user",
"content": "I need to check the weather and book an appointment.",
},
{
"role": "assistant",
"content": "I can help with both. Where would you like to check the weather?",
},
{"role": "user", "content": "San Francisco"},
]
)
# Update context with new system message and tools
# Now we can pass schemas directly since they're FunctionSchema objects
new_system = {
"role": "system",
"content": "You are a scheduling assistant with access to weather and booking tools.",
}
update_llm_context(context, new_system, schemas)
# Verify context was updated correctly
messages = context.messages
assert len(messages) == 4
assert (
messages[0]["content"]
== "You are a scheduling assistant with access to weather and booking tools."
)
assert messages[1]["role"] == "user"
assert messages[3]["content"] == "San Francisco"
# Verify tools were set
tools = context.tools
assert tools is not None
assert len(tools.standard_tools) == 3
# Verify tool names
tool_names = {t.name for t in tools.standard_tools}
assert tool_names == {
"get_weather",
"book_appointment",
"customer_lookup",
}
@pytest.mark.asyncio
async def test_tool_schemas_have_correct_properties(
self, mock_engine, sample_tools
):
"""Test that tool schemas from CustomToolManager have correct parameter properties."""
manager = CustomToolManager(mock_engine)
with patch(
"api.services.workflow.pipecat_engine_custom_tools.get_organization_id_from_workflow_run"
) as mock_get_org:
mock_get_org.return_value = 1
with patch(
"api.services.workflow.pipecat_engine_custom_tools.db_client"
) as mock_db:
mock_db.get_tools_by_uuids = AsyncMock(return_value=sample_tools)
schemas = await manager.get_tool_schemas(
["weather-uuid-123", "booking-uuid-456"]
)
# Find the booking schema - now using FunctionSchema attributes
booking_schema = next(
s for s in schemas if s.name == "book_appointment"
)
# Verify parameter properties
assert "customer_name" in booking_schema.properties
assert "date" in booking_schema.properties
assert "time" in booking_schema.properties
assert "notes" in booking_schema.properties
# Verify types
assert booking_schema.properties["customer_name"]["type"] == "string"
assert booking_schema.properties["date"]["type"] == "string"
# Verify required
assert "customer_name" in booking_schema.required
assert "date" in booking_schema.required
assert "time" in booking_schema.required
assert "notes" not in booking_schema.required
@pytest.mark.asyncio
async def test_context_update_with_builtin_and_custom_tools(
self, mock_engine, sample_tools
):
"""Test updating context with both built-in and custom tools."""
manager = CustomToolManager(mock_engine)
with patch(
"api.services.workflow.pipecat_engine_custom_tools.get_organization_id_from_workflow_run"
) as mock_get_org:
mock_get_org.return_value = 1
with patch(
"api.services.workflow.pipecat_engine_custom_tools.db_client"
) as mock_db:
mock_db.get_tools_by_uuids = AsyncMock(
return_value=[sample_tools[0]]
) # Just weather
# Get custom tool schemas - returns FunctionSchema objects
custom_schemas = await manager.get_tool_schemas(["weather-uuid-123"])
# Create built-in function schemas (like calculator, timezone)
builtin_functions = [
get_function_schema(
"safe_calculator",
"Evaluate a mathematical expression safely",
properties={
"expression": {
"type": "string",
"description": "Mathematical expression to evaluate",
}
},
required=["expression"],
),
get_function_schema(
"get_current_time",
"Get the current time in a timezone",
properties={
"timezone": {
"type": "string",
"description": "Timezone name (e.g., America/New_York)",
}
},
required=["timezone"],
),
]
# Combine built-in and custom functions - both are FunctionSchema objects
all_functions = builtin_functions + custom_schemas
# Update context
context = LLMContext()
context.set_messages([{"role": "system", "content": "Old prompt"}])
new_system = {
"role": "system",
"content": "Assistant with calculator and weather tools",
}
update_llm_context(context, new_system, all_functions)
# Verify all tools are present
tools = context.tools
assert len(tools.standard_tools) == 3
tool_names = {t.name for t in tools.standard_tools}
assert "safe_calculator" in tool_names
assert "get_current_time" in tool_names
assert "get_weather" in tool_names
@pytest.mark.asyncio
async def test_tools_cached_after_first_fetch(self, mock_engine, sample_tools):
"""Test that CustomToolManager caches tools after first fetch."""
manager = CustomToolManager(mock_engine)
with patch(
"api.services.workflow.pipecat_engine_custom_tools.get_organization_id_from_workflow_run"
) as mock_get_org:
mock_get_org.return_value = 1
with patch(
"api.services.workflow.pipecat_engine_custom_tools.db_client"
) as mock_db:
mock_db.get_tools_by_uuids = AsyncMock(return_value=[sample_tools[0]])
# First fetch
await manager.get_tool_schemas(["weather-uuid-123"])
# Verify tool is cached (cache stores raw schema dict, not FunctionSchema)
cached = manager.get_cached_tool("get_weather")
assert cached is not None
tool, raw_schema = cached
assert tool.tool_uuid == "weather-uuid-123"
assert raw_schema["function"]["name"] == "get_weather"
@pytest.mark.asyncio
async def test_context_preserves_function_call_history(
self, mock_engine, sample_tools
):
"""Test that update_llm_context preserves function call messages in history."""
manager = CustomToolManager(mock_engine)
with patch(
"api.services.workflow.pipecat_engine_custom_tools.get_organization_id_from_workflow_run"
) as mock_get_org:
mock_get_org.return_value = 1
with patch(
"api.services.workflow.pipecat_engine_custom_tools.db_client"
) as mock_db:
mock_db.get_tools_by_uuids = AsyncMock(return_value=[sample_tools[0]])
# Get schemas - returns FunctionSchema objects
schemas = await manager.get_tool_schemas(["weather-uuid-123"])
# Create context with function call history
context = LLMContext()
context.set_messages(
[
{"role": "system", "content": "Old system prompt"},
{"role": "user", "content": "What's the weather in NYC?"},
{
"role": "assistant",
"content": None,
"tool_calls": [
{
"id": "call_123",
"type": "function",
"function": {
"name": "get_weather",
"arguments": '{"location": "New York, NY"}',
},
}
],
},
{
"role": "tool",
"tool_call_id": "call_123",
"content": '{"temperature": 72, "condition": "sunny"}',
},
{
"role": "assistant",
"content": "The weather in NYC is 72°F and sunny!",
},
]
)
new_system = {"role": "system", "content": "Updated weather assistant"}
update_llm_context(context, new_system, schemas)
messages = context.messages
# System + user + assistant(tool_call) + tool + assistant = 5
assert len(messages) == 5
# Verify function call messages are preserved
tool_call_msg = messages[2]
assert tool_call_msg["role"] == "assistant"
assert "tool_calls" in tool_call_msg
tool_result_msg = messages[3]
assert tool_result_msg["role"] == "tool"
assert tool_result_msg["tool_call_id"] == "call_123"
@pytest.mark.asyncio
async def test_empty_tool_list_does_not_set_tools(self, mock_engine):
"""Test that empty tool list doesn't set tools on context."""
manager = CustomToolManager(mock_engine)
with patch(
"api.services.workflow.pipecat_engine_custom_tools.get_organization_id_from_workflow_run"
) as mock_get_org:
mock_get_org.return_value = 1
with patch(
"api.services.workflow.pipecat_engine_custom_tools.db_client"
) as mock_db:
mock_db.get_tools_by_uuids = AsyncMock(return_value=[])
schemas = await manager.get_tool_schemas([])
assert schemas == []
context = LLMContext()
context.set_messages([{"role": "system", "content": "Old"}])
new_system = {"role": "system", "content": "No tools available"}
update_llm_context(context, new_system, [])
# Context should have updated message but no tools set
assert context.messages[0]["content"] == "No tools available"
@pytest.mark.asyncio
async def test_numeric_and_boolean_parameter_types(self, mock_engine):
"""Test that numeric and boolean parameter types are correctly handled."""
tool_with_types = MockToolModel(
tool_uuid="order-uuid",
name="Place Order",
description="Place an order for items",
definition={
"schema_version": 1,
"type": "http_api",
"config": {
"method": "POST",
"url": "https://api.example.com/orders",
"parameters": [
{
"name": "item_id",
"type": "string",
"description": "Item identifier",
"required": True,
},
{
"name": "quantity",
"type": "number",
"description": "Number of items",
"required": True,
},
{
"name": "express_shipping",
"type": "boolean",
"description": "Use express shipping",
"required": False,
},
],
},
},
)
manager = CustomToolManager(mock_engine)
with patch(
"api.services.workflow.pipecat_engine_custom_tools.get_organization_id_from_workflow_run"
) as mock_get_org:
mock_get_org.return_value = 1
with patch(
"api.services.workflow.pipecat_engine_custom_tools.db_client"
) as mock_db:
mock_db.get_tools_by_uuids = AsyncMock(return_value=[tool_with_types])
# Get schemas - returns FunctionSchema objects
schemas = await manager.get_tool_schemas(["order-uuid"])
schema = schemas[0]
# Verify types using FunctionSchema attributes
assert schema.properties["item_id"]["type"] == "string"
assert schema.properties["quantity"]["type"] == "number"
assert schema.properties["express_shipping"]["type"] == "boolean"
# Update context - pass schema directly
context = LLMContext()
context.set_messages([{"role": "system", "content": "Old"}])
update_llm_context(
context, {"role": "system", "content": "Order assistant"}, schemas
)
# Verify tool was set with correct types
tool = context.tools.standard_tools[0]
assert tool.name == "place_order"
assert tool.properties["quantity"]["type"] == "number"
assert tool.properties["express_shipping"]["type"] == "boolean"

View file

@ -1,33 +0,0 @@
import os
import uuid
import pytest
from api.db.user_client import UserClient
from api.services.configuration.registry import ServiceProviders
@pytest.mark.asyncio
async def test_default_configuration_created(db_session):
# Set env variable for openai to simulate availability of default key
os.environ["OPENAI_API_KEY"] = "sk-test-openai-key"
# Ensure deepgram env variable absent to focus test
os.environ.pop("DEEPGRAM_API_KEY", None)
# Generate a unique (random) provider user ID for each test run
test_provider_user_id = f"provider_user_{uuid.uuid4().hex}"
user_client: UserClient = db_session # db_session fixture yields the client
user_model = await user_client.get_or_create_user_by_provider_id(
test_provider_user_id
)
config = await user_client.get_user_configurations(user_model.id)
assert config.llm is not None, "LLM config should be created when env key present"
assert config.llm.provider == ServiceProviders.OPENAI
assert config.llm.api_key == "sk-test-openai-key"
# Cleanup / restore env variable side-effects
os.environ.pop("OPENAI_API_KEY", None)

View file

@ -1,122 +0,0 @@
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from api.services.workflow.disposition_mapper import (
apply_disposition_mapping,
get_organization_id_from_workflow_run,
)
@pytest.mark.asyncio
async def test_apply_disposition_mapping_with_valid_mapping():
"""Test disposition mapping with valid configuration."""
with patch("api.services.workflow.disposition_mapper.db_client") as mock_db_client:
# Mock disposition mapping configuration
mock_db_client.get_configuration_value = AsyncMock(
return_value={
"XFER": "TRANSFERRED",
"ND": "NOT_QUALIFIED",
"user_hangup": "HANGUP",
}
)
# Test mapping exists
result = await apply_disposition_mapping("XFER", 1)
assert result == "TRANSFERRED"
# Test mapping doesn't exist
result = await apply_disposition_mapping("UNKNOWN", 1)
assert result == "UNKNOWN"
# Verify db_client was called correctly
mock_db_client.get_configuration_value.assert_called_with(
1, "DISPOSITION_CODE_MAPPING", default={}
)
@pytest.mark.asyncio
async def test_apply_disposition_mapping_no_organization_id():
"""Test disposition mapping with no organization ID."""
# Should return original value
result = await apply_disposition_mapping("XFER", None)
assert result == "XFER"
@pytest.mark.asyncio
async def test_apply_disposition_mapping_empty_value():
"""Test disposition mapping with empty value."""
# Should return original empty value
result = await apply_disposition_mapping("", 1)
assert result == ""
@pytest.mark.asyncio
async def test_apply_disposition_mapping_error_handling():
"""Test disposition mapping handles errors gracefully."""
with patch("api.services.workflow.disposition_mapper.db_client") as mock_db_client:
# Mock database error
mock_db_client.get_configuration_value = AsyncMock(
side_effect=Exception("Database error")
)
# Should return original value on error
result = await apply_disposition_mapping("XFER", 1)
assert result == "XFER"
@pytest.mark.asyncio
async def test_get_organization_id_from_workflow_run():
"""Test getting organization ID from workflow run ID."""
with patch("api.services.workflow.disposition_mapper.db_client") as mock_db_client:
# Mock workflow run with organization
mock_workflow_run = MagicMock()
mock_workflow_run.workflow.user.selected_organization_id = 123
mock_db_client.get_workflow_run_by_id = AsyncMock(
return_value=mock_workflow_run
)
result = await get_organization_id_from_workflow_run(1)
assert result == 123
# Verify db_client was called correctly
mock_db_client.get_workflow_run_by_id.assert_called_once_with(1)
@pytest.mark.asyncio
async def test_get_organization_id_no_workflow_run():
"""Test getting organization ID when workflow run doesn't exist."""
with patch("api.services.workflow.disposition_mapper.db_client") as mock_db_client:
# Mock no workflow run found
mock_db_client.get_workflow_run_by_id = AsyncMock(return_value=None)
result = await get_organization_id_from_workflow_run(1)
assert result is None
@pytest.mark.asyncio
async def test_get_organization_id_no_user():
"""Test getting organization ID when workflow has no user."""
with patch("api.services.workflow.disposition_mapper.db_client") as mock_db_client:
# Mock workflow run with no user
mock_workflow_run = MagicMock()
mock_workflow_run.workflow.user = None
mock_db_client.get_workflow_run_by_id = AsyncMock(
return_value=mock_workflow_run
)
result = await get_organization_id_from_workflow_run(1)
assert result is None
@pytest.mark.asyncio
async def test_get_organization_id_error_handling():
"""Test getting organization ID handles errors gracefully."""
with patch("api.services.workflow.disposition_mapper.db_client") as mock_db_client:
# Mock database error
mock_db_client.get_workflow_run_by_id = AsyncMock(
side_effect=Exception("Database error")
)
result = await get_organization_id_from_workflow_run(1)
assert result is None

View file

@ -1,370 +0,0 @@
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from pipecat.utils.enums import EndTaskReason
from api.services.pipecat.event_handlers import register_transport_event_handlers
@pytest.fixture
def mock_dependencies():
"""Create mock dependencies for event handlers."""
# Store registered handlers
registered_handlers = {}
def mock_event_handler(event_name):
def decorator(func):
registered_handlers[event_name] = func
return func
return decorator
mock_transport = MagicMock()
mock_transport.event_handler = mock_event_handler
mock_task = MagicMock()
mock_task.cancel = AsyncMock()
mock_engine = MagicMock()
mock_engine.initialize = AsyncMock()
mock_engine.cleanup = AsyncMock()
mock_audio_buffer = MagicMock()
mock_audio_buffer.start_recording = AsyncMock()
mock_audio_buffer.stop_recording = AsyncMock()
mock_usage_metrics_aggregator = MagicMock()
mock_usage_metrics_aggregator.get_all_usage_metrics_serialized = MagicMock(
return_value={"test": "metrics"}
)
return {
"transport": mock_transport,
"workflow_run_id": 123,
"audio_buffer": mock_audio_buffer,
"task": mock_task,
"engine": mock_engine,
"usage_metrics_aggregator": mock_usage_metrics_aggregator,
"audio_synchronizer": None,
"registered_handlers": registered_handlers,
}
@pytest.mark.asyncio
async def test_transport_disconnect_reason_mapping(mock_dependencies):
"""Test that transport_disconnect_reason is mapped when no engine disconnect reason exists."""
# Register event handlers
register_transport_event_handlers(
transport=mock_dependencies["transport"],
workflow_run_id=mock_dependencies["workflow_run_id"],
audio_buffer=mock_dependencies["audio_buffer"],
task=mock_dependencies["task"],
engine=mock_dependencies["engine"],
usage_metrics_aggregator=mock_dependencies["usage_metrics_aggregator"],
audio_synchronizer=mock_dependencies["audio_synchronizer"],
)
# Get the on_client_disconnected handler
handler = mock_dependencies["registered_handlers"]["on_client_disconnected"]
# Mock engine with no call disposition
mock_dependencies["engine"].get_call_disposition.return_value = None
mock_dependencies["engine"].get_gathered_context.return_value = {
"agent_name": "Alex"
}
# Mock the disposition mapper functions
with patch(
"api.services.pipecat.event_handlers.get_organization_id_from_workflow_run",
new_callable=AsyncMock,
) as mock_get_org_id:
with patch(
"api.services.pipecat.event_handlers.apply_disposition_mapping",
new_callable=AsyncMock,
) as mock_apply_mapping:
with patch(
"api.services.pipecat.event_handlers.db_client"
) as mock_db_client:
with patch(
"api.services.pipecat.event_handlers.enqueue_job"
) as mock_enqueue:
# Mock organization ID
mock_get_org_id.return_value = 1
# Mock call duration for user_hangup logic
mock_dependencies[
"usage_metrics_aggregator"
].get_call_duration.return_value = 15
# Mock disposition mapping
async def apply_mapping_side_effect(value, org_id):
return {
"NIBP": "NOT_INTERESTED_BUSINESS_PURPOSE",
"user_qualified": "QUALIFIED",
}.get(value, value)
mock_apply_mapping.side_effect = apply_mapping_side_effect
# Mock database operations
mock_workflow_run = MagicMock()
mock_workflow_run.id = 123
mock_workflow_run.workflow_id = 1
mock_workflow_run.organization_id = 1
mock_workflow_run.gathered_context = {}
mock_db_client.get_workflow_run_by_id = AsyncMock(
return_value=mock_workflow_run
)
mock_db_client.update_workflow_run = AsyncMock()
# Call handler with transport_disconnect_reason
await handler(
mock_dependencies["transport"],
participant=None,
transport_disconnect_reason="user_hangup",
)
# Verify disposition mapping was applied with NIBP (since duration > 10)
mock_apply_mapping.assert_called_once_with("NIBP", 1)
# Verify database was updated with mapped value
mock_db_client.update_workflow_run.assert_called_once()
call_args = mock_db_client.update_workflow_run.call_args
assert (
call_args[1]["gathered_context"]["mapped_call_disposition"]
== "NOT_INTERESTED_BUSINESS_PURPOSE"
)
# Verify task was cancelled (no engine disconnect reason)
mock_dependencies["task"].cancel.assert_called_once()
@pytest.mark.asyncio
async def test_transport_disconnect_reason_user_hangup_short_call(mock_dependencies):
"""Test that user_hangup with short call duration is mapped to HU."""
# Register event handlers
register_transport_event_handlers(
transport=mock_dependencies["transport"],
workflow_run_id=mock_dependencies["workflow_run_id"],
audio_buffer=mock_dependencies["audio_buffer"],
task=mock_dependencies["task"],
engine=mock_dependencies["engine"],
usage_metrics_aggregator=mock_dependencies["usage_metrics_aggregator"],
audio_synchronizer=mock_dependencies["audio_synchronizer"],
)
# Get the on_client_disconnected handler
handler = mock_dependencies["registered_handlers"]["on_client_disconnected"]
# Mock engine with no call disposition
mock_dependencies["engine"].get_call_disposition.return_value = None
mock_dependencies["engine"].get_gathered_context.return_value = {
"agent_name": "Alex"
}
# Mock the disposition mapper functions
with patch(
"api.services.pipecat.event_handlers.get_organization_id_from_workflow_run",
new_callable=AsyncMock,
) as mock_get_org_id:
with patch(
"api.services.pipecat.event_handlers.apply_disposition_mapping",
new_callable=AsyncMock,
) as mock_apply_mapping:
with patch(
"api.services.pipecat.event_handlers.db_client"
) as mock_db_client:
with patch(
"api.services.pipecat.event_handlers.enqueue_job"
) as mock_enqueue:
# Mock organization ID
mock_get_org_id.return_value = 1
# Mock call duration for user_hangup logic (< 10 seconds)
mock_dependencies[
"usage_metrics_aggregator"
].get_call_duration.return_value = 5
# Mock disposition mapping
mock_apply_mapping.return_value = "HANGUP"
# Mock database operations
mock_workflow_run = MagicMock()
mock_workflow_run.id = 123
mock_workflow_run.workflow_id = 1
mock_workflow_run.organization_id = 1
mock_workflow_run.gathered_context = {}
mock_db_client.get_workflow_run_by_id = AsyncMock(
return_value=mock_workflow_run
)
mock_db_client.update_workflow_run = AsyncMock()
# Call handler with transport_disconnect_reason
await handler(
mock_dependencies["transport"],
participant=None,
transport_disconnect_reason="user_hangup",
)
# Verify disposition mapping was applied with HU (since duration < 10)
mock_apply_mapping.assert_called_once_with("HU", 1)
# Verify database was updated with mapped value
mock_db_client.update_workflow_run.assert_called_once()
call_args = mock_db_client.update_workflow_run.call_args
assert (
call_args[1]["gathered_context"]["mapped_call_disposition"]
== "HANGUP"
)
# Verify task was cancelled (no engine disconnect reason)
mock_dependencies["task"].cancel.assert_called_once()
@pytest.mark.asyncio
async def test_engine_disconnect_reason_takes_precedence(mock_dependencies):
"""Test that engine disconnect reason takes precedence and is not mapped."""
# Register event handlers
register_transport_event_handlers(
transport=mock_dependencies["transport"],
workflow_run_id=mock_dependencies["workflow_run_id"],
audio_buffer=mock_dependencies["audio_buffer"],
task=mock_dependencies["task"],
engine=mock_dependencies["engine"],
usage_metrics_aggregator=mock_dependencies["usage_metrics_aggregator"],
audio_synchronizer=mock_dependencies["audio_synchronizer"],
)
# Get the on_client_disconnected handler
handler = mock_dependencies["registered_handlers"]["on_client_disconnected"]
# Mock engine with call disposition
mock_dependencies["engine"].get_call_disposition.return_value = "user_qualified"
mock_dependencies["engine"].get_gathered_context.return_value = {
"agent_name": "Alex"
}
# Mock the disposition mapper functions
with patch(
"api.services.pipecat.event_handlers.get_organization_id_from_workflow_run",
new_callable=AsyncMock,
) as mock_get_org_id:
with patch(
"api.services.pipecat.event_handlers.apply_disposition_mapping",
new_callable=AsyncMock,
) as mock_apply_mapping:
with patch(
"api.services.pipecat.event_handlers.db_client"
) as mock_db_client:
with patch(
"api.services.pipecat.event_handlers.enqueue_job"
) as mock_enqueue:
# Mock organization ID
mock_get_org_id.return_value = 1
# Mock disposition mapping for engine's reason
mock_apply_mapping.return_value = "QUALIFIED"
# Mock database operations
mock_workflow_run = MagicMock()
mock_workflow_run.id = 123
mock_workflow_run.workflow_id = 1
mock_workflow_run.organization_id = 1
mock_workflow_run.gathered_context = {}
mock_db_client.get_workflow_run_by_id = AsyncMock(
return_value=mock_workflow_run
)
mock_db_client.update_workflow_run = AsyncMock()
# Call handler with transport_disconnect_reason
await handler(
mock_dependencies["transport"],
participant=None,
transport_disconnect_reason="user_hangup",
)
# Verify disposition mapping was called with engine's reason
mock_apply_mapping.assert_called_once_with("user_qualified", 1)
# Verify database was updated with mapped value
mock_db_client.update_workflow_run.assert_called_once()
call_args = mock_db_client.update_workflow_run.call_args
assert (
call_args[1]["gathered_context"]["mapped_call_disposition"]
== "QUALIFIED"
)
# Verify task was NOT cancelled (engine disconnect reason exists)
mock_dependencies["task"].cancel.assert_not_called()
@pytest.mark.asyncio
async def test_no_disconnect_reason_uses_unknown(mock_dependencies):
"""Test that when no disconnect reason is provided, UNKNOWN is used."""
# Register event handlers
register_transport_event_handlers(
transport=mock_dependencies["transport"],
workflow_run_id=mock_dependencies["workflow_run_id"],
audio_buffer=mock_dependencies["audio_buffer"],
task=mock_dependencies["task"],
engine=mock_dependencies["engine"],
usage_metrics_aggregator=mock_dependencies["usage_metrics_aggregator"],
audio_synchronizer=mock_dependencies["audio_synchronizer"],
)
# Get the on_client_disconnected handler
handler = mock_dependencies["registered_handlers"]["on_client_disconnected"]
# Mock engine with no call disposition
mock_dependencies["engine"].get_call_disposition.return_value = None
mock_dependencies["engine"].get_gathered_context.return_value = {
"agent_name": "Alex"
}
with patch(
"api.services.pipecat.event_handlers.get_organization_id_from_workflow_run"
) as mock_get_org_id:
with patch(
"api.services.pipecat.event_handlers.apply_disposition_mapping"
) as mock_apply_mapping:
with patch(
"api.services.pipecat.event_handlers.db_client"
) as mock_db_client:
with patch(
"api.services.pipecat.event_handlers.enqueue_job"
) as mock_enqueue:
# Mock organization ID
mock_get_org_id.return_value = 1
# Mock disposition mapping - should return UNKNOWN as-is
mock_apply_mapping.return_value = EndTaskReason.UNKNOWN.value
# Mock database operations
mock_workflow_run = MagicMock()
mock_workflow_run.id = 123
mock_workflow_run.workflow_id = 1
mock_workflow_run.organization_id = 1
mock_workflow_run.gathered_context = {}
mock_db_client.get_workflow_run_by_id = AsyncMock(
return_value=mock_workflow_run
)
mock_db_client.update_workflow_run = AsyncMock()
# Call handler without transport_disconnect_reason
await handler(
mock_dependencies["transport"],
participant=None,
transport_disconnect_reason=None,
)
# Verify disposition mapping was called with UNKNOWN
mock_apply_mapping.assert_called_once_with(
EndTaskReason.UNKNOWN.value, 1
)
# Verify database was updated with UNKNOWN
mock_db_client.update_workflow_run.assert_called_once()
call_args = mock_db_client.update_workflow_run.call_args
assert (
call_args[1]["gathered_context"]["mapped_call_disposition"]
== EndTaskReason.UNKNOWN.value
)

View file

@ -1,184 +0,0 @@
from unittest.mock import AsyncMock, MagicMock
import pytest
from api.services.pipecat.audio_config import AudioConfig
from api.services.pipecat.event_handlers import (
register_audio_data_handler,
register_transcript_handler,
register_transport_event_handlers,
)
@pytest.mark.asyncio
async def test_transport_handlers_with_in_memory_buffers():
"""Test that transport handlers create and return in-memory buffers."""
# Mock dependencies
transport = MagicMock()
transport.event_handler = lambda event_name: lambda func: func
audio_buffer = AsyncMock()
audio_synchronizer = AsyncMock()
task = AsyncMock()
engine = AsyncMock()
engine.get_call_disposition.return_value = None
engine.get_gathered_context.return_value = {}
usage_metrics_aggregator = AsyncMock()
usage_metrics_aggregator.get_call_duration.return_value = 30
usage_metrics_aggregator.get_all_usage_metrics_serialized.return_value = {}
# Create test audio config
audio_config = AudioConfig(
transport_in_sample_rate=16000,
transport_out_sample_rate=16000,
pipeline_sample_rate=16000,
)
# Register handlers
audio_buf, transcript_buf = register_transport_event_handlers(
transport=transport,
workflow_run_id=123,
audio_buffer=audio_buffer,
task=task,
engine=engine,
usage_metrics_aggregator=usage_metrics_aggregator,
audio_synchronizer=audio_synchronizer,
audio_config=audio_config,
)
# Verify buffers were created with correct configuration
assert audio_buf is not None
assert transcript_buf is not None
assert audio_buf._workflow_run_id == 123
assert audio_buf._sample_rate == 16000
assert audio_buf._num_channels == 1
assert transcript_buf._workflow_run_id == 123
@pytest.mark.asyncio
async def test_audio_handler_with_in_memory_buffer():
"""Test audio handler uses in-memory buffer when provided."""
# Mock audio synchronizer
audio_synchronizer = MagicMock()
handlers = {}
def mock_event_handler(event_name):
def decorator(func):
handlers[event_name] = func
return func
return decorator
audio_synchronizer.event_handler = mock_event_handler
# Mock in-memory buffer
in_memory_buffer = AsyncMock()
# Register handler with buffer
register_audio_data_handler(
audio_synchronizer, workflow_run_id=123, in_memory_buffer=in_memory_buffer
)
# Test the handler
assert "on_merged_audio" in handlers
handler = handlers["on_merged_audio"]
# Call handler with test data
test_pcm = b"test_audio_data"
await handler(None, test_pcm, 16000, 1)
# Verify buffer was used
in_memory_buffer.append.assert_called_once_with(test_pcm)
@pytest.mark.asyncio
async def test_transcript_handler_with_in_memory_buffer():
"""Test transcript handler uses in-memory buffer when provided."""
# Mock transcript processor
transcript = MagicMock()
handlers = {}
def mock_event_handler(event_name):
def decorator(func):
handlers[event_name] = func
return func
return decorator
transcript.event_handler = mock_event_handler
# Mock in-memory buffer
in_memory_buffer = AsyncMock()
# Register handler with buffer
register_transcript_handler(
transcript, workflow_run_id=456, in_memory_buffer=in_memory_buffer
)
# Create test frame
test_frame = MagicMock()
test_frame.messages = [
MagicMock(timestamp="00:00:01", role="user", content="Hello"),
MagicMock(timestamp="00:00:02", role="assistant", content="Hi there"),
]
# Test the handler
handler = handlers["on_transcript_update"]
await handler(None, test_frame)
# Verify buffer was used with correct format
expected_text = "[00:00:01] user: Hello\n[00:00:02] assistant: Hi there\n"
in_memory_buffer.append.assert_called_once_with(expected_text)
@pytest.mark.asyncio
async def test_audio_config_sample_rates():
"""Test that different audio configs result in correct sample rates."""
# Mock dependencies
transport = MagicMock()
transport.event_handler = lambda event_name: lambda func: func
audio_buffer = AsyncMock()
audio_synchronizer = AsyncMock()
task = AsyncMock()
engine = AsyncMock()
engine.get_call_disposition.return_value = None
engine.get_gathered_context.return_value = {}
usage_metrics_aggregator = AsyncMock()
usage_metrics_aggregator.get_all_usage_metrics_serialized.return_value = {}
# Test with 8kHz audio config (e.g., for Stasis/Twilio)
audio_config_8k = AudioConfig(
transport_in_sample_rate=8000,
transport_out_sample_rate=8000,
pipeline_sample_rate=8000,
)
audio_buf_8k, _ = register_transport_event_handlers(
transport=transport,
workflow_run_id=456,
audio_buffer=audio_buffer,
task=task,
engine=engine,
usage_metrics_aggregator=usage_metrics_aggregator,
audio_synchronizer=audio_synchronizer,
audio_config=audio_config_8k,
)
assert audio_buf_8k._sample_rate == 8000
# Test with no audio config (should default to 16kHz)
audio_buf_default, _ = register_transport_event_handlers(
transport=transport,
workflow_run_id=789,
audio_buffer=audio_buffer,
task=task,
engine=engine,
usage_metrics_aggregator=usage_metrics_aggregator,
audio_synchronizer=audio_synchronizer,
audio_config=None,
)
assert audio_buf_default._sample_rate == 16000

View file

@ -1,162 +0,0 @@
"""Test filter functionality."""
from unittest.mock import MagicMock
from api.db.filters import ATTRIBUTE_FIELD_MAPPING, apply_workflow_run_filters
def test_attribute_field_mapping():
"""Test that all required attributes are mapped."""
expected_attributes = [
"dateRange",
"dispositionCode",
"duration",
"status",
"tokenUsage",
"runId",
"workflowId",
"callTags",
"phoneNumber",
]
for attr in expected_attributes:
assert attr in ATTRIBUTE_FIELD_MAPPING, f"Missing mapping for {attr}"
def test_filter_with_explicit_type():
"""Test that filters work with explicit type from UI."""
# Mock query
mock_query = MagicMock()
mock_query.where = MagicMock(return_value=mock_query)
test_cases = [
# Date range filter
{
"filters": [
{
"attribute": "dateRange",
"type": "dateRange",
"value": {"from": "2024-01-01", "to": "2024-01-31"},
}
],
},
# Multi-select filter
{
"filters": [
{
"attribute": "dispositionCode",
"type": "multiSelect",
"value": {"codes": ["XFER", "HU"]},
}
],
},
# Number range filter
{
"filters": [
{
"attribute": "duration",
"type": "numberRange",
"value": {"min": 60, "max": 300},
}
],
},
# Radio/status filter
{
"filters": [
{
"attribute": "status",
"type": "radio",
"value": {"status": "completed"},
}
],
},
# Number filter
{
"filters": [
{"attribute": "runId", "type": "number", "value": {"value": 123}}
],
},
# Text filter
{
"filters": [
{
"attribute": "phoneNumber",
"type": "text",
"value": {"value": "+1234567890"},
}
],
},
# Tags filter
{
"filters": [
{
"attribute": "callTags",
"type": "tags",
"value": {"codes": ["tag1", "tag2"]},
}
],
},
]
for test_case in test_cases:
result = apply_workflow_run_filters(mock_query, test_case["filters"])
# The function should process the filter without errors
assert result is not None
def test_filter_format_with_type():
"""Test that filters work with attribute, type, and value."""
mock_query = MagicMock()
mock_query.where = MagicMock(return_value=mock_query)
# Test with various filter combinations
filters = [
{
"attribute": "dispositionCode",
"type": "multiSelect",
"value": {"codes": ["NIBP"]},
},
{
"attribute": "duration",
"type": "numberRange",
"value": {"min": 0, "max": 60},
},
{"attribute": "phoneNumber", "type": "text", "value": {"value": "555"}},
]
result = apply_workflow_run_filters(mock_query, filters)
# Should have called where() for applying filters
assert mock_query.where.called
assert result is not None
def test_unknown_attribute_ignored():
"""Test that unknown attributes are safely ignored."""
mock_query = MagicMock()
mock_query.where = MagicMock(return_value=mock_query)
filters = [
{"attribute": "unknownAttribute", "value": {"value": "test"}},
{"attribute": "dispositionCode", "value": {"codes": ["XFER"]}},
]
result = apply_workflow_run_filters(mock_query, filters)
# Should still process the valid filter
assert result is not None
def test_empty_filters():
"""Test that empty filters return the query unchanged."""
mock_query = MagicMock()
result = apply_workflow_run_filters(mock_query, None)
assert result == mock_query
result = apply_workflow_run_filters(mock_query, [])
assert result == mock_query

View file

@ -1,249 +0,0 @@
"""Tests for global prompt functionality in workflow engine."""
from unittest.mock import Mock
import pytest
from pipecat.services.openai.llm import OpenAILLMContext
from api.services.workflow.dto import (
EdgeDataDTO,
NodeDataDTO,
NodeType,
ReactFlowDTO,
RFEdgeDTO,
RFNodeDTO,
)
from api.services.workflow.pipecat_engine import PipecatEngine
from api.services.workflow.workflow import WorkflowGraph
class TestGlobalPrompt:
"""Test suite for global prompt feature."""
@pytest.fixture
def workflow_with_global_node(self):
"""Create a workflow with a global node and test nodes."""
nodes = [
RFNodeDTO(
id="global",
type=NodeType.globalNode,
position={"x": 0, "y": 0},
data=NodeDataDTO(
name="Global Node",
prompt="This is the global context: {{company_name}}",
is_static=False,
),
),
RFNodeDTO(
id="start",
type=NodeType.startNode,
position={"x": 100, "y": 100},
data=NodeDataDTO(
name="Start Call",
prompt="Welcome to our service!",
is_static=False,
is_start=True,
add_global_prompt=True, # Enable global prompt
),
),
RFNodeDTO(
id="agent1",
type=NodeType.agentNode,
position={"x": 200, "y": 200},
data=NodeDataDTO(
name="Agent 1",
prompt="How can I help you today?",
add_global_prompt=False, # Disable global prompt
),
),
RFNodeDTO(
id="agent2",
type=NodeType.agentNode,
position={"x": 300, "y": 300},
data=NodeDataDTO(
name="Agent 2",
prompt="Please provide your details.",
add_global_prompt=True, # Enable global prompt
),
),
RFNodeDTO(
id="end",
type=NodeType.endNode,
position={"x": 400, "y": 400},
data=NodeDataDTO(
name="End Call",
prompt="Thank you for calling!",
is_static=True,
is_end=True,
add_global_prompt=True, # Enable global prompt (but static)
),
),
]
edges = [
RFEdgeDTO(
id="e1",
source="start",
target="agent1",
data=EdgeDataDTO(label="Next", condition="Continue to agent"),
),
RFEdgeDTO(
id="e2",
source="agent1",
target="agent2",
data=EdgeDataDTO(label="Details", condition="Get user details"),
),
RFEdgeDTO(
id="e3",
source="agent2",
target="end",
data=EdgeDataDTO(label="Finish", condition="End the call"),
),
]
flow_dto = ReactFlowDTO(nodes=nodes, edges=edges)
return WorkflowGraph(flow_dto)
@pytest.fixture
def mock_dependencies(self):
"""Create mock dependencies for PipecatEngine initialization."""
return {
"task": Mock(),
"llm": Mock(),
"context": Mock(spec=OpenAILLMContext),
"tts": Mock(),
"transport": Mock(),
"call_context_vars": {"company_name": "Dograh Inc"},
}
@pytest.fixture
def engine(self, mock_dependencies, workflow_with_global_node):
"""Create a PipecatEngine instance with test workflow."""
mock_dependencies["workflow"] = workflow_with_global_node
return PipecatEngine(**mock_dependencies)
@pytest.mark.asyncio
async def test_global_prompt_enabled(self, engine):
"""Test that global prompt is prepended when add_global_prompt is True."""
# Test with start node (add_global_prompt=True)
start_node = engine.workflow.nodes["start"]
(
system_message,
functions,
) = await engine._compose_system_message_functions_for_node(start_node)
# Global prompt should be included
expected_content = (
"This is the global context: Dograh Inc\n\nWelcome to our service!"
)
assert system_message["content"] == expected_content
assert system_message["role"] == "system"
@pytest.mark.asyncio
async def test_global_prompt_disabled(self, engine):
"""Test that global prompt is not prepended when add_global_prompt is False."""
# Test with agent1 node (add_global_prompt=False)
agent1_node = engine.workflow.nodes["agent1"]
(
system_message,
functions,
) = await engine._compose_system_message_functions_for_node(agent1_node)
# Global prompt should NOT be included
expected_content = "How can I help you today?"
assert system_message["content"] == expected_content
assert "global context" not in system_message["content"]
@pytest.mark.asyncio
async def test_global_prompt_with_static_node(self, engine):
"""Test that static nodes don't use global prompt in engine (even if enabled)."""
# Static nodes are handled differently - they use TTSSpeakFrame directly
# This test verifies the compose_system_message behavior for completeness
end_node = engine.workflow.nodes["end"]
# Even though add_global_prompt=True, static nodes handle prompts differently
# The _compose_system_message_functions_for_node is still called for consistency
(
system_message,
functions,
) = await engine._compose_system_message_functions_for_node(end_node)
# For static nodes, the global prompt would still be composed if enabled
expected_content = (
"This is the global context: Dograh Inc\n\nThank you for calling!"
)
assert system_message["content"] == expected_content
@pytest.mark.asyncio
async def test_global_prompt_variable_substitution(self, engine):
"""Test that variables in global prompt are properly substituted."""
agent2_node = engine.workflow.nodes["agent2"]
(
system_message,
functions,
) = await engine._compose_system_message_functions_for_node(agent2_node)
# Verify variable substitution in global prompt
assert "Dograh Inc" in system_message["content"]
assert "{{company_name}}" not in system_message["content"]
# Full expected content
expected_content = (
"This is the global context: Dograh Inc\n\nPlease provide your details."
)
assert system_message["content"] == expected_content
@pytest.mark.asyncio
async def test_no_global_node_scenario(self, engine):
"""Test behavior when there's no global node in the workflow."""
# Remove global node from workflow
engine.workflow.global_node_id = None
start_node = engine.workflow.nodes["start"]
(
system_message,
functions,
) = await engine._compose_system_message_functions_for_node(start_node)
# Should only have the node's own prompt
assert system_message["content"] == "Welcome to our service!"
@pytest.mark.asyncio
async def test_empty_global_prompt(self, engine):
"""Test behavior when global prompt is empty."""
# Set global prompt to empty string
engine.workflow.nodes["global"].prompt = ""
start_node = engine.workflow.nodes["start"]
(
system_message,
functions,
) = await engine._compose_system_message_functions_for_node(start_node)
# Should only have the node's own prompt (empty global prompt is filtered out)
assert system_message["content"] == "Welcome to our service!"
def test_default_add_global_prompt_value(self):
"""Test that add_global_prompt defaults to True in NodeDataDTO."""
node_data = NodeDataDTO(name="Test", prompt="Test prompt")
assert node_data.add_global_prompt is True
@pytest.mark.asyncio
async def test_multiple_prompts_concatenation(self, engine):
"""Test proper concatenation of global and node prompts."""
# Test with agent2 node that has global prompt enabled
agent2_node = engine.workflow.nodes["agent2"]
(
system_message,
functions,
) = await engine._compose_system_message_functions_for_node(agent2_node)
# Should have global and node prompts concatenated with double newlines
# (extraction prompt is no longer included in system message)
expected_parts = [
"This is the global context: Dograh Inc",
"Please provide your details.",
]
expected_content = "\n\n".join(expected_parts)
assert system_message["content"] == expected_content

View file

@ -1,175 +0,0 @@
"""Unit tests for global prompt functionality - no DB dependencies."""
import sys
from pathlib import Path
# Add the api directory to the Python path
api_path = Path(__file__).parent.parent
sys.path.insert(0, str(api_path))
from services.workflow.dto import (
EdgeDataDTO,
NodeDataDTO,
NodeType,
ReactFlowDTO,
RFEdgeDTO,
RFNodeDTO,
)
from services.workflow.workflow import WorkflowGraph
def test_node_data_dto_default_global_prompt():
"""Test that add_global_prompt defaults to True."""
node_data = NodeDataDTO(name="Test Node", prompt="Test prompt")
assert node_data.add_global_prompt is True
print("✓ NodeDataDTO defaults add_global_prompt to True")
def test_node_data_dto_explicit_global_prompt():
"""Test explicit setting of add_global_prompt."""
# Test with False
node_data_false = NodeDataDTO(
name="Test Node", prompt="Test prompt", add_global_prompt=False
)
assert node_data_false.add_global_prompt is False
# Test with True
node_data_true = NodeDataDTO(
name="Test Node", prompt="Test prompt", add_global_prompt=True
)
assert node_data_true.add_global_prompt is True
print("✓ NodeDataDTO respects explicit add_global_prompt values")
def test_workflow_node_inherits_global_prompt_setting():
"""Test that workflow Node inherits add_global_prompt from NodeDataDTO."""
nodes = [
RFNodeDTO(
id="start",
type=NodeType.startNode,
position={"x": 0, "y": 0},
data=NodeDataDTO(
name="Start",
prompt="Start prompt",
is_start=True,
add_global_prompt=True,
),
),
RFNodeDTO(
id="node1",
type=NodeType.agentNode,
position={"x": 100, "y": 0},
data=NodeDataDTO(
name="Node with global", prompt="Test prompt", add_global_prompt=True
),
),
RFNodeDTO(
id="node2",
type=NodeType.agentNode,
position={"x": 200, "y": 0},
data=NodeDataDTO(
name="Node without global",
prompt="Test prompt",
add_global_prompt=False,
),
),
RFNodeDTO(
id="end",
type=NodeType.endNode,
position={"x": 300, "y": 0},
data=NodeDataDTO(
name="End", prompt="End prompt", is_end=True, add_global_prompt=True
),
),
]
edges = [
RFEdgeDTO(
id="e1",
source="start",
target="node1",
data=EdgeDataDTO(label="Next", condition="Continue"),
),
RFEdgeDTO(
id="e2",
source="node1",
target="node2",
data=EdgeDataDTO(label="Next", condition="Continue"),
),
RFEdgeDTO(
id="e3",
source="node2",
target="end",
data=EdgeDataDTO(label="End", condition="Finish"),
),
]
flow_dto = ReactFlowDTO(nodes=nodes, edges=edges)
workflow = WorkflowGraph(flow_dto)
assert workflow.nodes["start"].add_global_prompt is True
assert workflow.nodes["node1"].add_global_prompt is True
assert workflow.nodes["node2"].add_global_prompt is False
assert workflow.nodes["end"].add_global_prompt is True
print("✓ Workflow nodes correctly inherit add_global_prompt setting")
def test_compose_system_message_respects_global_prompt_flag():
"""Test that system message composition respects add_global_prompt flag."""
# This is a simplified version - in real tests we'd use the full engine
# But this demonstrates the logic
class MockNode:
def __init__(self, add_global_prompt, prompt):
self.add_global_prompt = add_global_prompt
self.prompt = prompt
self.out_edges = []
self.extraction_enabled = False
# Simulate the logic from _compose_system_message_functions_for_node
def compose_message(node, global_prompt):
prompts = []
# Only add global prompt if node.add_global_prompt is True
if global_prompt and node.add_global_prompt:
prompts.append(global_prompt)
prompts.append(node.prompt)
return "\n\n".join(p for p in prompts if p)
global_prompt = "This is the global context"
# Test with add_global_prompt=True
node_with_global = MockNode(add_global_prompt=True, prompt="Node prompt")
message_with = compose_message(node_with_global, global_prompt)
assert message_with == "This is the global context\n\nNode prompt"
# Test with add_global_prompt=False
node_without_global = MockNode(add_global_prompt=False, prompt="Node prompt")
message_without = compose_message(node_without_global, global_prompt)
assert message_without == "Node prompt"
print("✓ System message composition respects add_global_prompt flag")
def test_static_nodes_with_global_prompt():
"""Test static nodes can have add_global_prompt setting."""
static_node_data = NodeDataDTO(
name="Static Node", prompt="Static text", is_static=True, add_global_prompt=True
)
assert static_node_data.is_static is True
assert static_node_data.add_global_prompt is True
print("✓ Static nodes can have add_global_prompt setting")
if __name__ == "__main__":
# Run all tests
test_node_data_dto_default_global_prompt()
test_node_data_dto_explicit_global_prompt()
test_workflow_node_inherits_global_prompt_setting()
test_compose_system_message_respects_global_prompt_flag()
test_static_nodes_with_global_prompt()
print("\n✅ All unit tests passed!")

View file

@ -1,248 +0,0 @@
"""
Test cases for _leave_counter mechanism in transport clients.
This test suite verifies that the _leave_counter prevents premature disconnection
when both input and output transports are using the same client.
"""
import asyncio
from unittest.mock import AsyncMock, Mock
import pytest
from pipecat.frames.frames import EndFrame, StartFrame
from pipecat.transports.network.fastapi_websocket import (
FastAPIWebsocketCallbacks,
FastAPIWebsocketClient,
FastAPIWebsocketParams,
FastAPIWebsocketTransport,
)
from pipecat.transports.network.small_webrtc import SmallWebRTCClient
from api.services.telephony.stasis_rtp_client import StasisRTPClient
class TestLeaveCounterFastAPIWebsocket:
"""Test the _leave_counter mechanism in FastAPIWebsocketClient."""
@pytest.mark.asyncio
async def test_leave_counter_prevents_early_disconnect(self):
"""Test that disconnect only happens when both transports have disconnected."""
# Create mock websocket
mock_websocket = Mock()
mock_websocket.close = AsyncMock()
# Set client_state directly to WebSocketState.CONNECTED value
from starlette.websockets import WebSocketState
mock_websocket.client_state = WebSocketState.CONNECTED
# Create callbacks
callbacks = FastAPIWebsocketCallbacks(
on_client_connected=AsyncMock(),
on_client_disconnected=AsyncMock(),
on_session_timeout=AsyncMock(),
)
# Create client
client = FastAPIWebsocketClient(
mock_websocket, is_binary=False, callbacks=callbacks
)
# Create StartFrame
start_frame = StartFrame()
# Simulate both input and output transports calling setup
await client.setup(start_frame) # Input transport
assert client._leave_counter == 1
await client.setup(start_frame) # Output transport
assert client._leave_counter == 2
# First disconnect - should not actually disconnect
await client.disconnect()
assert client._leave_counter == 1
mock_websocket.close.assert_not_called()
callbacks.on_client_disconnected.assert_not_called()
# Second disconnect - should actually disconnect
await client.disconnect()
assert client._leave_counter == 0
mock_websocket.close.assert_called_once()
callbacks.on_client_disconnected.assert_called_once()
class TestLeaveCounterStasisRTP:
"""Test the _leave_counter mechanism in StasisRTPClient."""
@pytest.mark.asyncio
async def test_leave_counter_prevents_early_disconnect(self):
"""Test that disconnect only happens when both transports have disconnected."""
# Create mock connection
mock_connection = Mock()
mock_connection.is_connected.return_value = True
mock_connection.disconnect = AsyncMock()
mock_connection.notify_sockets_closed = AsyncMock()
# Mock event_handler as a callable that acts as a decorator
def mock_event_handler(event_name):
def decorator(func):
return func
return decorator
mock_connection.event_handler = mock_event_handler
# Create callbacks
from api.services.telephony.stasis_rtp_transport import StasisRTPCallbacks
callbacks = StasisRTPCallbacks(
on_client_connected=AsyncMock(),
on_client_disconnected=AsyncMock(),
on_client_closed=AsyncMock(),
)
# Create client
client = StasisRTPClient(mock_connection, callbacks)
# Create StartFrame
start_frame = StartFrame()
# Simulate both input and output transports calling setup
await client.setup(start_frame) # Input transport
assert client._leave_counter == 1
await client.setup(start_frame) # Output transport
assert client._leave_counter == 2
# First disconnect - should not actually disconnect
await client.disconnect()
assert client._leave_counter == 1
mock_connection.disconnect.assert_not_called()
# Second disconnect - should actually disconnect
await client.disconnect()
assert client._leave_counter == 0
mock_connection.disconnect.assert_called_once()
class TestLeaveCounterSmallWebRTC:
"""Test the _leave_counter mechanism in SmallWebRTCClient."""
@pytest.mark.asyncio
async def test_leave_counter_prevents_early_disconnect(self):
"""Test that disconnect only happens when both transports have disconnected."""
# Create mock connection
mock_connection = Mock()
mock_connection.is_connected.return_value = True
mock_connection.disconnect = AsyncMock()
mock_connection.notify_sockets_closed = AsyncMock()
# Mock event_handler as a callable that acts as a decorator
def mock_event_handler(event_name):
def decorator(func):
return func
return decorator
mock_connection.event_handler = mock_event_handler
# Create callbacks
from pipecat.transports.network.small_webrtc import SmallWebRTCCallbacks
callbacks = SmallWebRTCCallbacks(
on_app_message=AsyncMock(),
on_client_connected=AsyncMock(),
on_client_disconnected=AsyncMock(),
)
# Create client
client = SmallWebRTCClient(mock_connection, callbacks)
# Create StartFrame with required attributes
start_frame = StartFrame()
# Create mock transport params
from pipecat.transports.base_transport import TransportParams
params = TransportParams(
audio_in_channels=1, audio_in_sample_rate=16000, audio_out_sample_rate=16000
)
# Simulate both input and output transports calling setup
await client.setup(params, start_frame) # Input transport
assert client._leave_counter == 1
await client.setup(params, start_frame) # Output transport
assert client._leave_counter == 2
# First disconnect - should not actually disconnect
await client.disconnect()
assert client._leave_counter == 1
mock_connection.disconnect.assert_not_called()
# Second disconnect - should actually disconnect
await client.disconnect()
assert client._leave_counter == 0
mock_connection.disconnect.assert_called_once()
@pytest.mark.skip(reason="Complex integration test - requires additional mocking")
@pytest.mark.asyncio
async def test_transport_lifecycle_with_leave_counter():
"""Test complete transport lifecycle with proper leave counter handling."""
# Create mock websocket
mock_websocket = Mock()
mock_websocket.close = AsyncMock()
# Set client_state directly to WebSocketState.CONNECTED value
from starlette.websockets import WebSocketState
mock_websocket.client_state = WebSocketState.CONNECTED
mock_websocket.iter_bytes = Mock(return_value=iter([]))
mock_websocket.send_bytes = AsyncMock()
# Create transport
params = FastAPIWebsocketParams(audio_in_enabled=True, audio_out_enabled=True)
transport = FastAPIWebsocketTransport(mock_websocket, params)
# Get input and output transports
input_transport = transport.input()
output_transport = transport.output()
# Setup the transport with required components
from pipecat.clocks.system_clock import SystemClock
from pipecat.processors.frame_processor import FrameProcessorSetup
from pipecat.utils.asyncio.task_manager import TaskManager, TaskManagerParams
clock = SystemClock()
task_manager = TaskManager()
# Setup task manager with event loop
loop = asyncio.get_event_loop()
task_manager_params = TaskManagerParams(loop=loop)
task_manager.setup(task_manager_params)
setup = FrameProcessorSetup(clock=clock, task_manager=task_manager)
# Setup both input and output transports
await input_transport.setup(setup)
await output_transport.setup(setup)
# Start both transports
start_frame = StartFrame()
await input_transport.start(start_frame)
await output_transport.start(start_frame)
# Verify leave counter is 2
assert transport._client._leave_counter == 2
# Stop input transport
end_frame = EndFrame()
await input_transport.stop(end_frame)
# Verify websocket not closed yet
mock_websocket.close.assert_not_called()
# Stop output transport
await output_transport.stop(end_frame)
# Now websocket should be closed
mock_websocket.close.assert_called_once()

View file

@ -1,99 +0,0 @@
import unittest
from pipecat.frames.frames import (
FunctionCallInProgressFrame,
LLMFullResponseStartFrame,
)
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
from pipecat.services.google.llm import (
GoogleAssistantContextAggregator,
GoogleLLMContext,
)
from pipecat.services.openai.llm import OpenAIAssistantContextAggregator
class TestReorderOpenAIAssistantContextAggregator(unittest.IsolatedAsyncioTestCase):
async def test_reorder_function_messages_openai(self):
"""Ensure that after a text aggregation the function-call messages are moved
to appear immediately after the text response, maintaining chronological
order (assistant text -> function call -> tool response).
"""
context = OpenAILLMContext()
aggregator = OpenAIAssistantContextAggregator(context)
# Simulate the start of an LLM response so that the aggregator creates a
# response session ID that is later used for re-ordering.
await aggregator._handle_llm_start(LLMFullResponseStartFrame())
# Simulate the model emitting a function call which the aggregator will
# record for potential re-ordering.
await aggregator._handle_function_call_in_progress(
FunctionCallInProgressFrame(
function_name="get_weather",
tool_call_id="1",
arguments={},
)
)
# Now push the textual part of the assistant response. This should
# trigger the re-ordering so that the two function-related messages
# appear *after* this text.
await aggregator.handle_aggregation("Hello!")
messages = context.get_messages()
# We expect exactly three messages after re-ordering.
self.assertEqual(len(messages), 3)
# 1. Assistant text
self.assertEqual(messages[0]["role"], "assistant")
self.assertEqual(messages[0]["content"], "Hello!")
# 2. Assistant function-call message
self.assertEqual(messages[1]["role"], "assistant")
self.assertIn("tool_calls", messages[1])
# 3. Tool response
self.assertEqual(messages[2]["role"], "tool")
self.assertEqual(messages[2]["tool_call_id"], "1")
class TestReorderGoogleAssistantContextAggregator(unittest.IsolatedAsyncioTestCase):
async def test_reorder_function_messages_google(self):
context = GoogleLLMContext()
aggregator = GoogleAssistantContextAggregator(context)
# Start an LLM response session.
await aggregator._handle_llm_start(LLMFullResponseStartFrame())
# Emit a function call.
await aggregator._handle_function_call_in_progress(
FunctionCallInProgressFrame(
function_name="get_weather",
tool_call_id="1",
arguments={},
)
)
# Push the textual content.
await aggregator.handle_aggregation("Hello!")
messages = context.messages # Google context stores Content objects.
self.assertEqual(len(messages), 3)
# The first message should be the model text.
first_msg = messages[0].to_json_dict()
self.assertEqual(first_msg["role"], "model")
self.assertEqual(first_msg["parts"][0]["text"], "Hello!")
# The second message contains the function call (also from the model).
second_msg = messages[1].to_json_dict()
self.assertEqual(second_msg["role"], "model")
self.assertIn("function_call", second_msg["parts"][0])
# The third message is the placeholder function response.
third_msg = messages[2].to_json_dict()
self.assertEqual(third_msg["role"], "user")
self.assertIn("function_response", third_msg["parts"][0])

View file

@ -1,506 +0,0 @@
"""
Tests for LoopTalk API routes and orchestration.
This module tests the LoopTalk testing functionality including test session creation,
pipeline orchestration, and agent-to-agent communication.
"""
import asyncio
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
import pytest_asyncio
from fastapi import status
from api.db.db_client import DBClient
from api.services.looptalk.orchestrator import LoopTalkTestOrchestrator
@pytest.fixture
def actor_workflow_definition():
"""Sample actor workflow definition for testing."""
return {
"nodes": [
{
"id": "1",
"type": "startCall",
"position": {"x": 0, "y": 0},
"data": {
"prompt": "Hello, I'm the actor agent.",
"is_static": True,
"name": "Start Call",
"is_start": True,
"allow_interrupt": False,
},
},
{
"id": "2",
"type": "agentNode",
"position": {"x": 100, "y": 0},
"data": {
"prompt": "You are an actor agent testing the adversary. Ask probing questions.",
"name": "Actor Agent",
"allow_interrupt": True,
},
},
{
"id": "3",
"type": "endCall",
"position": {"x": 200, "y": 0},
"data": {
"prompt": "Goodbye!",
"name": "End Call",
"is_end": True,
},
},
],
"edges": [
{
"id": "e1",
"source": "1",
"target": "2",
"data": {"label": "Continue", "condition": "Always"},
},
{
"id": "e2",
"source": "2",
"target": "3",
"data": {"label": "End", "condition": "Always"},
},
],
"stt": {"provider": "openai", "api_key": "test-key", "model": "whisper-1"},
"llm": {"provider": "openai", "api_key": "test-key", "model": "gpt-4o-mini"},
"tts": {
"provider": "openai",
"api_key": "test-key",
"model": "tts-1",
"voice": "nova",
},
}
@pytest.fixture
def adversary_workflow_definition():
"""Sample adversary workflow definition for testing."""
return {
"nodes": [
{
"id": "1",
"type": "startCall",
"position": {"x": 0, "y": 0},
"data": {
"prompt": "Hello, I'm the adversary agent.",
"is_static": True,
"name": "Start Call",
"is_start": True,
"allow_interrupt": False,
},
},
{
"id": "2",
"type": "agentNode",
"position": {"x": 100, "y": 0},
"data": {
"prompt": "You are an adversary agent being tested. Respond defensively.",
"name": "Adversary Agent",
"allow_interrupt": True,
},
},
{
"id": "3",
"type": "endCall",
"position": {"x": 200, "y": 0},
"data": {
"prompt": "Goodbye!",
"name": "End Call",
"is_end": True,
},
},
],
"edges": [
{
"id": "e1",
"source": "1",
"target": "2",
"data": {"label": "Continue", "condition": "Always"},
},
{
"id": "e2",
"source": "2",
"target": "3",
"data": {"label": "End", "condition": "Always"},
},
],
"stt": {"provider": "deepgram", "api_key": "test-key", "model": "nova-2"},
"llm": {
"provider": "groq",
"api_key": "test-key",
"model": "llama-3.1-70b-versatile",
},
"tts": {"provider": "deepgram", "api_key": "test-key", "voice": "nova-2"},
}
from pipecat.processors.frame_processor import FrameProcessor
class MockSTTService(FrameProcessor):
"""Mock STT service for testing."""
def __init__(self, **kwargs):
super().__init__(**kwargs)
async def run_stt(self, audio: bytes) -> str:
return "Mock transcription"
class MockLLMService(FrameProcessor):
"""Mock LLM service for testing."""
def __init__(self, **kwargs):
super().__init__(**kwargs)
async def run_llm(self, messages) -> str:
return "Mock LLM response"
def create_context_aggregator(self, context):
"""Mock context aggregator creation."""
return MagicMock()
class MockTTSService(FrameProcessor):
"""Mock TTS service for testing."""
def __init__(self, **kwargs):
super().__init__(**kwargs)
async def run_tts(self, text: str) -> bytes:
return b"Mock audio data"
@pytest_asyncio.fixture
async def test_user_with_org(db_session):
"""Create a test user with an organization set up."""
user = await db_session.get_or_create_user_by_provider_id("test_looptalk_user")
org, _ = await db_session.get_or_create_organization_by_provider_id(
"test_looptalk_org"
)
user_id = user.id
org_id = org.id
await db_session.add_user_to_organization(user_id, org_id)
# Update user's selected organization
async with db_session.async_session() as session:
from sqlalchemy import update
from api.db.models import UserModel
await session.execute(
update(UserModel)
.where(UserModel.id == user_id)
.values(selected_organization_id=org_id)
)
await session.commit()
# Return fresh user object
return await db_session.get_user_by_id(user_id)
@pytest.mark.asyncio
async def test_create_test_session(
test_client_factory,
db_session,
test_user_with_org,
actor_workflow_definition,
adversary_workflow_definition,
):
"""Test creating a new LoopTalk test session."""
async with test_client_factory(test_user_with_org) as test_client:
# First create two workflows
actor_workflow_response = await test_client.post(
"/api/v1/workflow/create",
json={
"name": "Actor Workflow",
"workflow_definition": actor_workflow_definition,
},
)
assert actor_workflow_response.status_code == status.HTTP_200_OK
actor_workflow_id = actor_workflow_response.json()["id"]
adversary_workflow_response = await test_client.post(
"/api/v1/workflow/create",
json={
"name": "Adversary Workflow",
"workflow_definition": adversary_workflow_definition,
},
)
assert adversary_workflow_response.status_code == status.HTTP_200_OK
adversary_workflow_id = adversary_workflow_response.json()["id"]
# Create test session
response = await test_client.post(
"/api/v1/looptalk/test-sessions",
json={
"name": "Test Session 1",
"actor_workflow_id": actor_workflow_id,
"adversary_workflow_id": adversary_workflow_id,
"config": {"test_duration": 60},
},
)
assert response.status_code == status.HTTP_200_OK
data = response.json()
assert data["name"] == "Test Session 1"
assert data["status"] == "pending"
assert data["actor_workflow_id"] == actor_workflow_id
assert data["adversary_workflow_id"] == adversary_workflow_id
assert data["config"]["test_duration"] == 60
@pytest.mark.asyncio
async def test_list_test_sessions(test_client_factory, db_session, test_user_with_org):
"""Test listing LoopTalk test sessions."""
async with test_client_factory(test_user_with_org) as test_client:
response = await test_client.get(
"/api/v1/looptalk/test-sessions",
)
assert response.status_code == status.HTTP_200_OK
data = response.json()
assert isinstance(data, list)
@pytest.mark.asyncio
async def test_looptalk_orchestrator_plumbing(
db_session: DBClient, actor_workflow_definition, adversary_workflow_definition
):
"""Test the LoopTalk orchestrator plumbing with mocked services."""
# Create test user and organization
user = await db_session.get_or_create_user_by_provider_id(
provider_id="test-user-123"
)
org, _ = await db_session.get_or_create_organization_by_provider_id(
org_provider_id="test-org-123"
)
# Get IDs before session closes
user_id = user.id
org_id = org.id
await db_session.add_user_to_organization(user_id, org_id)
# Update user's selected organization manually
async with db_session.async_session() as session:
from sqlalchemy import update
from api.db.models import UserModel
await session.execute(
update(UserModel)
.where(UserModel.id == user_id)
.values(selected_organization_id=org_id)
)
await session.commit()
actor_workflow = await db_session.create_workflow(
name="Actor Workflow",
workflow_definition=actor_workflow_definition,
user_id=user_id,
)
adversary_workflow = await db_session.create_workflow(
name="Adversary Workflow",
workflow_definition=adversary_workflow_definition,
user_id=user_id,
)
# Create test session
test_session = await db_session.create_test_session(
organization_id=org_id,
name="Test Session",
actor_workflow_id=actor_workflow.id,
adversary_workflow_id=adversary_workflow.id,
config={"test_duration": 10},
)
# Mock the service factories - patch at the actual import location in pipeline_builder
with (
patch(
"api.services.looptalk.core.pipeline_builder.create_stt_service"
) as mock_stt_factory,
patch(
"api.services.looptalk.core.pipeline_builder.create_llm_service"
) as mock_llm_factory,
patch(
"api.services.looptalk.core.pipeline_builder.create_tts_service"
) as mock_tts_factory,
patch(
"api.services.workflow.pipecat_engine.PipecatEngine"
) as mock_engine_class,
patch(
"api.services.pipecat.pipeline_builder.build_pipeline"
) as mock_build_pipeline,
patch("api.services.pipecat.pipeline_builder.PipelineTask") as mock_task_class,
):
# Configure mocks
mock_stt_factory.return_value = MockSTTService()
mock_llm_factory.return_value = MockLLMService()
mock_tts_factory.return_value = MockTTSService()
mock_engine = MagicMock()
mock_engine.initialize = AsyncMock()
mock_engine.get_callback_processor = MagicMock(return_value=MagicMock())
mock_engine_class.return_value = mock_engine
# Mock pipeline and task
mock_pipeline = MagicMock()
mock_task = MagicMock()
mock_task.run = AsyncMock()
mock_task.cancel = AsyncMock() # Make cancel async
mock_build_pipeline.return_value = mock_pipeline
mock_task_class.return_value = mock_task
# Create orchestrator
orchestrator = LoopTalkTestOrchestrator(db_client=db_session)
# Start test session (in a separate task to avoid blocking)
start_task = asyncio.create_task(
orchestrator.start_test_session(
test_session_id=test_session.id, organization_id=org_id
)
)
# Give it a moment to start
await asyncio.sleep(0.5)
# Verify the session is running through session manager
session_info = orchestrator.session_manager.get_session(test_session.id)
assert session_info is not None
assert session_info["test_session"].id == test_session.id
assert "actor_task" in session_info
assert "adversary_task" in session_info
# Verify service factories were called
assert mock_stt_factory.call_count == 2 # Once for each agent
assert mock_llm_factory.call_count == 2
assert mock_tts_factory.call_count == 2
# Verify pipelines were created with PipelineTask
assert mock_task_class.call_count == 2
# Stop the test session
await orchestrator.stop_test_session(test_session_id=test_session.id)
# Verify session was cleaned up
assert orchestrator.session_manager.get_session(test_session.id) is None
# Cancel the start task
start_task.cancel()
try:
await start_task
except asyncio.CancelledError:
pass
@pytest.mark.asyncio
async def test_load_test_creation(
test_client_factory,
db_session,
test_user_with_org,
actor_workflow_definition,
adversary_workflow_definition,
):
"""Test creating a load test with multiple sessions."""
async with test_client_factory(test_user_with_org) as test_client:
# First create two workflows
actor_workflow_response = await test_client.post(
"/api/v1/workflow/create",
json={
"name": "Actor Workflow",
"workflow_definition": actor_workflow_definition,
},
)
actor_workflow_id = actor_workflow_response.json()["id"]
adversary_workflow_response = await test_client.post(
"/api/v1/workflow/create",
json={
"name": "Adversary Workflow",
"workflow_definition": adversary_workflow_definition,
},
)
adversary_workflow_id = adversary_workflow_response.json()["id"]
# Create load test
response = await test_client.post(
"/api/v1/looptalk/load-tests",
json={
"name_prefix": "Load Test",
"actor_workflow_id": actor_workflow_id,
"adversary_workflow_id": adversary_workflow_id,
"test_count": 3,
"config": {"test_duration": 30},
},
)
assert response.status_code == status.HTTP_200_OK
data = response.json()
assert data["total"] == 3
assert "load_test_group_id" in data
assert len(data["test_session_ids"]) == 3
@pytest.mark.asyncio
async def test_invalid_workflow_ids(
test_client_factory, db_session, test_user_with_org
):
"""Test creating test session with invalid workflow IDs."""
async with test_client_factory(test_user_with_org) as test_client:
response = await test_client.post(
"/api/v1/looptalk/test-sessions",
json={
"name": "Invalid Test",
"actor_workflow_id": 99999,
"adversary_workflow_id": 99999,
"config": {},
},
)
assert response.status_code == status.HTTP_404_NOT_FOUND
assert "workflow not found" in response.json()["detail"].lower()
@pytest.mark.asyncio
async def test_transport_manager():
"""Test the internal transport manager functionality."""
from pipecat.transports import InternalTransportManager, TransportParams
manager = InternalTransportManager()
# Create transport pair
params = TransportParams(
audio_out_enabled=True,
audio_in_enabled=True,
audio_out_sample_rate=16000,
audio_in_sample_rate=16000,
)
actor_transport, adversary_transport = manager.create_transport_pair(
test_session_id="test-123", actor_params=params, adversary_params=params
)
# Verify transports are connected
assert actor_transport._output._partner == adversary_transport._input
assert adversary_transport._output._partner == actor_transport._input
# Verify transport pair is tracked
assert manager.get_active_test_count() == 1
assert manager.get_transport_pair("test-123") is not None
# Remove transport pair
manager.remove_transport_pair("test-123")
assert manager.get_active_test_count() == 0
assert manager.get_transport_pair("test-123") is None

View file

@ -1,142 +0,0 @@
### - The test gets stuck. Need to figure out a way to run the test
# import asyncio
# import unittest
# from loguru import logger
# from pipecat.frames.frames import (
# FunctionCallFromLLM,
# FunctionCallInProgressFrame,
# FunctionCallResultFrame,
# FunctionCallsStartedFrame,
# LLMFullResponseEndFrame,
# LLMFullResponseStartFrame,
# LLMTextFrame,
# )
# from pipecat.processors.aggregators.openai_llm_context import (
# OpenAILLMContext,
# OpenAILLMContextFrame,
# )
# from pipecat.processors.frame_processor import FrameDirection
# from pipecat.services.llm_service import (
# FunctionCallParams,
# FunctionCallResultProperties,
# LLMService,
# )
# from pipecat.tests.utils import run_test
# class MockLLMService(LLMService):
# """A very small mocked LLM service that, upon receiving an
# ``OpenAILLMContextFrame``, streams a text completion followed by the
# execution of the supplied tools (function calls).
# """
# def __init__(self, *, content: str, tools: list[dict[str, dict]], **kwargs):
# # Run function calls sequentially so that frame ordering is deterministic.
# super().__init__(run_in_parallel=False, **kwargs)
# self._content = content
# self._tools = tools
# async def process_frame(self, frame, direction: FrameDirection):
# await super().process_frame(frame, direction)
# if isinstance(frame, OpenAILLMContextFrame) and direction == FrameDirection.DOWNSTREAM:
# # Simulate the start of a streamed completion.
# await self.push_frame(LLMFullResponseStartFrame())
# await self.push_frame(LLMTextFrame(self._content))
# # Convert tool specs into FunctionCallFromLLM objects.
# function_calls = []
# for idx, tool in enumerate(self._tools):
# function_calls.append(
# FunctionCallFromLLM(
# function_name=tool["function_name"],
# tool_call_id=f"tool_{idx}",
# arguments=tool.get("arguments", {}),
# context=frame.context,
# )
# )
# # Ask the LLM service base class to execute the calls.
# await self.run_function_calls(function_calls)
# # Finish the streamed response.
# await self.push_frame(LLMFullResponseEndFrame())
# async def _run_function_call(self, runner_item): # type: ignore[override] narrow signature
# # Ensure run_llm=True so that downstream processors know they can
# # immediately trigger another LLM call after the result is committed.
# runner_item.run_llm = True
# await super()._run_function_call(runner_item)
# class TestMockLLMPipeline(unittest.IsolatedAsyncioTestCase):
# async def test_mock_llm_pipeline_with_tools(self):
# # ------------------------------------------------------------------
# # 1. Create mocked LLM service with completion text and tools
# # ------------------------------------------------------------------
# completion_text = "Hello from mocked LLM!"
# tools = [
# {"function_name": "tool_one", "arguments": {"a": 1}},
# {"function_name": "tool_two", "arguments": {"b": 2}},
# ]
# llm = MockLLMService(content=completion_text, tools=tools)
# # ------------------------------------------------------------------
# # 2. Register the tool functions they simply log & sleep briefly.
# # Each of them marks that it has run so that we can assert later.
# # ------------------------------------------------------------------
# executed: dict[str, bool] = {t["function_name"]: False for t in tools}
# def make_handler(name: str):
# async def _handler(params: FunctionCallParams):
# logger.debug(f"Executing {name} with args {params.arguments}")
# executed[name] = True
# await asyncio.sleep(0.01)
# await params.result_callback(
# {"status": "ok"},
# properties=FunctionCallResultProperties(run_llm=True),
# )
# return _handler
# for t in tools:
# llm.register_function(t["function_name"], make_handler(t["function_name"]))
# # ------------------------------------------------------------------
# # 3. Build the pipeline and send the initial context frame that
# # triggers the completion.
# # ------------------------------------------------------------------
# context = OpenAILLMContext()
# context.add_message({"role": "user", "content": "Hi!"})
# frames_to_send = [OpenAILLMContextFrame(context)]
# expected_down_frames = [
# LLMFullResponseStartFrame,
# LLMTextFrame,
# FunctionCallsStartedFrame,
# FunctionCallInProgressFrame,
# FunctionCallResultFrame,
# FunctionCallInProgressFrame,
# FunctionCallResultFrame,
# LLMFullResponseEndFrame,
# ]
# # Run the test pipeline.
# received_down_frames, _ = await run_test(
# llm,
# frames_to_send=frames_to_send,
# expected_down_frames=expected_down_frames,
# )
# # ------------------------------------------------------------------
# # 4. Verify that both tool functions executed and that run_llm=True
# # in all FunctionCallResultFrame instances.
# # ------------------------------------------------------------------
# self.assertTrue(all(executed.values()))
# for frame in received_down_frames:
# if isinstance(frame, FunctionCallResultFrame):
# self.assertTrue(frame.run_llm)

View file

@ -1,236 +0,0 @@
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from api.services.workflow.pipecat_engine import PipecatEngine
def create_disposition_mapping_side_effect(mapping_dict):
"""Helper to create a side effect function for disposition mapping."""
async def side_effect(value, org_id):
return mapping_dict.get(value, value)
return side_effect
@pytest.fixture
def mock_dependencies():
"""Create mock dependencies for PipecatEngine."""
mock_task = MagicMock()
mock_task.queue_frame = AsyncMock()
mock_llm = MagicMock()
mock_context = MagicMock()
mock_workflow = MagicMock()
return {
"task": mock_task,
"llm": mock_llm,
"context": mock_context,
"workflow": mock_workflow,
"call_context_vars": {},
"workflow_run_id": 123,
}
@pytest.mark.asyncio
async def test_apply_disposition_mapping_with_call_disposition(mock_dependencies):
"""Test disposition mapping when call_disposition is present."""
engine = PipecatEngine(**mock_dependencies)
# Setup gathered context
engine._gathered_context = {
"call_disposition": "XFER",
"agent_name": "Alex",
"total_debt": "$15000",
}
# Mock the disposition mapper functions
with patch(
"api.services.workflow.pipecat_engine.get_organization_id_from_workflow_run"
) as mock_get_org_id:
with patch(
"api.services.workflow.pipecat_engine.apply_disposition_mapping"
) as mock_apply_mapping:
# Mock organization ID
mock_get_org_id.return_value = 1
# Mock disposition mapping
mock_apply_mapping.side_effect = create_disposition_mapping_side_effect(
{
"XFER": "TRANSFERRED",
"ND": "NOT_QUALIFIED",
}
)
# Call send_end_task_frame
await engine.send_end_task_frame(reason="user_qualified")
# Verify the frame was queued with mapped values
mock_dependencies["task"].queue_frame.assert_called_once()
frame = mock_dependencies["task"].queue_frame.call_args[0][0]
# Check metadata contains mapped values
assert frame.metadata["reason"] == "user_qualified" # No mapping for this
assert (
frame.metadata["call_transfer_context"]["disposition"] == "TRANSFERRED"
)
# Check gathered context was updated
assert engine._gathered_context["call_disposition"] == "TRANSFERRED"
@pytest.mark.asyncio
async def test_apply_disposition_mapping_with_disconnect_reason(mock_dependencies):
"""Test disposition mapping for disconnect_reason when no call_disposition exists."""
engine = PipecatEngine(**mock_dependencies)
# Setup gathered context without call_disposition
engine._gathered_context = {
"agent_name": "Alex",
}
# Mock the disposition mapper functions
with patch(
"api.services.workflow.pipecat_engine.get_organization_id_from_workflow_run"
) as mock_get_org_id:
with patch(
"api.services.workflow.pipecat_engine.apply_disposition_mapping"
) as mock_apply_mapping:
# Mock organization ID
mock_get_org_id.return_value = 1
# Mock disposition mapping
mock_apply_mapping.side_effect = create_disposition_mapping_side_effect(
{
"user_qualified": "QUALIFIED",
"user_disqualified": "NOT_QUALIFIED",
"user_hangup": "HANGUP",
}
)
# Call send_end_task_frame with a mappable reason
await engine.send_end_task_frame(reason="user_qualified")
# Verify the frame was queued with mapped disposition
mock_dependencies["task"].queue_frame.assert_called_once()
frame = mock_dependencies["task"].queue_frame.call_args[0][0]
# Check metadata contains original reason
assert frame.metadata["reason"] == "user_qualified"
# Check call_transfer_context has mapped disconnect_reason as disposition
assert frame.metadata["call_transfer_context"]["disposition"] == "QUALIFIED"
# Check gathered context was updated with mapped call_disposition
assert engine._gathered_context["call_disposition"] == "QUALIFIED"
# Check internal call_disposition stores mapped value
assert engine._call_disposition == "QUALIFIED"
@pytest.mark.asyncio
async def test_call_disposition_takes_precedence(mock_dependencies):
"""Test that call_disposition is used when both call_disposition and reason could be mapped."""
engine = PipecatEngine(**mock_dependencies)
# Setup gathered context with call_disposition
engine._gathered_context = {
"call_disposition": "XFER",
"agent_name": "Alex",
}
# Mock the disposition mapper functions
with patch(
"api.services.workflow.pipecat_engine.get_organization_id_from_workflow_run"
) as mock_get_org_id:
with patch(
"api.services.workflow.pipecat_engine.apply_disposition_mapping"
) as mock_apply_mapping:
# Mock organization ID
mock_get_org_id.return_value = 1
# Mock disposition mapping
mock_apply_mapping.side_effect = create_disposition_mapping_side_effect(
{
"XFER": "TRANSFERRED",
"user_qualified": "QUALIFIED",
}
)
# Call send_end_task_frame with a reason that could also be mapped
await engine.send_end_task_frame(reason="user_qualified")
# Verify the frame was queued
mock_dependencies["task"].queue_frame.assert_called_once()
frame = mock_dependencies["task"].queue_frame.call_args[0][0]
# Check that call_disposition mapping was used, not reason mapping
assert (
frame.metadata["call_transfer_context"]["disposition"] == "TRANSFERRED"
)
# Check only call_disposition was updated in gathered context
assert engine._gathered_context["call_disposition"] == "TRANSFERRED"
assert "disconnect_reason" not in engine._gathered_context
@pytest.mark.asyncio
async def test_disposition_mapping_no_organization_id(mock_dependencies):
"""Test when organization_id cannot be retrieved."""
# Set workflow_run_id to None
mock_dependencies["workflow_run_id"] = None
engine = PipecatEngine(**mock_dependencies)
engine._gathered_context = {
"call_disposition": "XFER",
}
# Call send_end_task_frame
await engine.send_end_task_frame(reason="user_qualified")
# Verify the frame was queued with original values (no mapping)
mock_dependencies["task"].queue_frame.assert_called_once()
frame = mock_dependencies["task"].queue_frame.call_args[0][0]
# Check values remain unchanged
assert frame.metadata["reason"] == "user_qualified"
assert frame.metadata["call_transfer_context"]["disposition"] == "XFER"
# Gathered context should remain unchanged
assert engine._gathered_context["call_disposition"] == "XFER"
@pytest.mark.asyncio
async def test_disposition_mapping_no_configuration(mock_dependencies):
"""Test when no disposition mapping is configured."""
engine = PipecatEngine(**mock_dependencies)
engine._gathered_context = {
"call_disposition": "XFER",
}
# Mock the disposition mapper functions
with patch(
"api.services.workflow.pipecat_engine.get_organization_id_from_workflow_run"
) as mock_get_org_id:
with patch(
"api.services.workflow.pipecat_engine.apply_disposition_mapping"
) as mock_apply_mapping:
# Mock organization ID
mock_get_org_id.return_value = 1
# Mock no disposition mapping (return original value)
mock_apply_mapping.side_effect = lambda value, org_id: value
# Call send_end_task_frame
await engine.send_end_task_frame(reason="user_qualified")
# Verify the frame was queued with original values
mock_dependencies["task"].queue_frame.assert_called_once()
frame = mock_dependencies["task"].queue_frame.call_args[0][0]
# Check values remain unchanged
assert frame.metadata["reason"] == "user_qualified"
assert frame.metadata["call_transfer_context"]["disposition"] == "XFER"

View file

@ -1,206 +0,0 @@
from unittest.mock import Mock
import pytest
from api.services.workflow.pipecat_engine import PipecatEngine
from api.services.workflow.workflow import WorkflowGraph
class TestPipecatEngine:
@pytest.fixture
def mock_dependencies(self):
"""Create mock dependencies for PipecatEngine initialization."""
return {
"task": Mock(),
"llm": Mock(),
"context": Mock(),
"tts": Mock(),
"transport": Mock(),
"workflow": Mock(spec=WorkflowGraph),
"call_context_vars": {},
}
@pytest.fixture
def engine_with_context(self, mock_dependencies):
"""Create a PipecatEngine instance with test context variables."""
context_vars = {
"first_name": "John",
"last_name": "Doe",
"age": 25,
"email": "john.doe@example.com",
"empty_var": "",
"zero_var": 0,
"false_var": False,
}
mock_dependencies["call_context_vars"] = context_vars
return PipecatEngine(**mock_dependencies)
@pytest.fixture
def engine_empty_context(self, mock_dependencies):
"""Create a PipecatEngine instance with empty context variables."""
mock_dependencies["call_context_vars"] = {}
return PipecatEngine(**mock_dependencies)
def test_format_prompt_simple_variable_replacement(self, engine_with_context):
"""Test simple variable replacement without filters."""
prompt = "Hello {{ first_name }}, welcome!"
result = engine_with_context._format_prompt(prompt)
assert result == "Hello John, welcome!"
def test_format_prompt_multiple_variables(self, engine_with_context):
"""Test multiple variable replacements in a single prompt."""
prompt = "Hello {{ first_name }} {{ last_name }}, you are {{ age }} years old."
result = engine_with_context._format_prompt(prompt)
assert result == "Hello John Doe, you are 25 years old."
def test_format_prompt_with_fallback_existing_value(self, engine_with_context):
"""Test fallback filter when value exists."""
prompt = "Hello {{ first_name | fallback }}, nice to meet you!"
result = engine_with_context._format_prompt(prompt)
assert result == "Hello John, nice to meet you!"
def test_format_prompt_with_fallback_missing_value(self, engine_empty_context):
"""Test fallback filter when value is missing."""
prompt = "Hello {{ first_name | fallback }}, nice to meet you!"
result = engine_empty_context._format_prompt(prompt)
assert result == "Hello First_Name, nice to meet you!"
def test_format_prompt_with_custom_fallback_missing_value(
self, engine_empty_context
):
"""Test fallback filter with custom fallback value when variable is missing."""
prompt = "Hello {{ first_name | fallback:Guest }}, welcome!"
result = engine_empty_context._format_prompt(prompt)
assert result == "Hello Guest, welcome!"
def test_format_prompt_with_custom_fallback_existing_value(
self, engine_with_context
):
"""Test fallback filter with custom fallback value when variable exists."""
prompt = "Hello {{ first_name | fallback:Guest }}, welcome!"
result = engine_with_context._format_prompt(prompt)
assert result == "Hello John, welcome!"
def test_format_prompt_empty_string_variable(self, engine_with_context):
"""Test variable with empty string value."""
prompt = "Value: '{{ empty_var | fallback:No Value }}'"
result = engine_with_context._format_prompt(prompt)
assert result == "Value: 'No Value'"
def test_format_prompt_zero_value(self, engine_with_context):
"""Test variable with zero value (should not trigger fallback)."""
prompt = "Count: {{ zero_var | fallback:None }}"
result = engine_with_context._format_prompt(prompt)
assert result == "Count: 0"
def test_format_prompt_false_value(self, engine_with_context):
"""Test variable with False value (should not trigger fallback)."""
prompt = "Status: {{ false_var | fallback:Unknown }}"
result = engine_with_context._format_prompt(prompt)
assert result == "Status: False"
def test_format_prompt_missing_variable_no_fallback(self, engine_empty_context):
"""Test missing variable without fallback filter."""
prompt = "Hello {{ missing_var }}, welcome!"
result = engine_empty_context._format_prompt(prompt)
assert result == "Hello , welcome!"
def test_format_prompt_complex_mixed_scenario(self, engine_with_context):
"""Test complex scenario with multiple variables, some with fallbacks."""
prompt = (
"Dear {{ first_name | fallback:Customer }}, "
"your email {{ email }} is confirmed. "
"{{ missing_info | fallback:Additional information }} will be sent later. "
"You are {{ age }} years old."
)
result = engine_with_context._format_prompt(prompt)
expected = (
"Dear John, "
"your email john.doe@example.com is confirmed. "
"Additional information will be sent later. "
"You are 25 years old."
)
assert result == expected
def test_format_prompt_whitespace_handling(self, engine_with_context):
"""Test handling of whitespace in template variables."""
prompt = "Hello {{ first_name | fallback : Default }}, welcome!"
result = engine_with_context._format_prompt(prompt)
assert result == "Hello John, welcome!"
def test_format_prompt_no_variables(self, engine_with_context):
"""Test prompt with no template variables."""
prompt = "This is a regular prompt with no variables."
result = engine_with_context._format_prompt(prompt)
assert result == "This is a regular prompt with no variables."
def test_format_prompt_empty_prompt(self, engine_with_context):
"""Test empty prompt."""
prompt = ""
result = engine_with_context._format_prompt(prompt)
assert result == ""
def test_format_prompt_none_prompt(self, engine_with_context):
"""Test None prompt."""
prompt = None
result = engine_with_context._format_prompt(prompt)
assert result is None
def test_format_prompt_nested_braces(self, engine_with_context):
"""Test handling of nested or malformed braces."""
prompt = "Hello {{ first_name }}, this {is not a template} variable."
result = engine_with_context._format_prompt(prompt)
assert result == "Hello John, this {is not a template} variable."
def test_format_prompt_special_characters_in_value(self):
"""Test variables containing special characters."""
mock_deps = {
"task": Mock(),
"llm": Mock(),
"context": Mock(),
"tts": Mock(),
"transport": Mock(),
"workflow": Mock(spec=WorkflowGraph),
"call_context_vars": {
"special_name": "John & Jane's Company",
"email": "test@domain.com",
},
}
engine = PipecatEngine(**mock_deps)
prompt = "Company: {{ special_name }}, Contact: {{ email }}"
result = engine._format_prompt(prompt)
assert result == "Company: John & Jane's Company, Contact: test@domain.com"
def test_format_prompt_numeric_and_boolean_conversion(self):
"""Test conversion of different data types to strings."""
mock_deps = {
"task": Mock(),
"llm": Mock(),
"context": Mock(),
"tts": Mock(),
"transport": Mock(),
"workflow": Mock(spec=WorkflowGraph),
"call_context_vars": {
"count": 42,
"price": 99.99,
"is_active": True,
"items": ["apple", "banana"],
},
}
engine = PipecatEngine(**mock_deps)
prompt = "Count: {{ count }}, Price: ${{ price }}, Active: {{ is_active }}, Items: {{ items }}"
result = engine._format_prompt(prompt)
assert (
result
== "Count: 42, Price: $99.99, Active: True, Items: ['apple', 'banana']"
)
def test_format_prompt_case_sensitivity(self, engine_with_context):
"""Test that variable names are case sensitive."""
prompt = (
"Hello {{ First_Name | fallback }}, welcome!" # Note the capitalization
)
result = engine_with_context._format_prompt(prompt)
assert result == "Hello First_Name, welcome!" # Should use fallback

View file

@ -1,295 +0,0 @@
"""
Test scenarios for provider switching and billing integrity.
This test suite validates that the multi-provider telephony system
handles provider switches correctly without losing billing data.
"""
import asyncio
# Test scenarios to validate
async def test_scenario_1_mid_call_provider_switch():
"""
Test: What happens if provider is switched while a call is active?
Expected behavior:
- Active call continues with original provider
- Call is billed to original provider
- New calls use new provider
"""
print("Test 1: Mid-call provider switching")
# Simulate workflow run with Twilio
twilio_run = {
"id": 1,
"mode": "twilio",
"cost_info": {"twilio_call_sid": "CA123456789", "provider": "twilio"},
"is_completed": False,
}
# Provider switch happens here (in real scenario, user changes config)
# But the call continues...
# When cost calculation runs, it should:
# 1. Use the provider stored in cost_info
# 2. Fetch cost from Twilio using twilio_call_sid
# 3. Store cost with provider attribution
result = {
"test": "mid_call_switch",
"status": "PASS",
"reason": "Call continues with original provider, billing intact",
}
print(f"{result['reason']}")
return result
async def test_scenario_2_pending_cost_calculation():
"""
Test: Calls that ended but cost not yet calculated when provider switches.
Expected behavior:
- Background job should use the provider info stored in cost_info
- Cost should be fetched from correct provider
"""
print("\nTest 2: Pending cost calculation during switch")
# Workflow runs that ended but cost job hasn't run yet
pending_runs = [
{
"id": 2,
"mode": "twilio",
"cost_info": {"twilio_call_sid": "CA987654321", "provider": "twilio"},
"is_completed": True,
},
{
"id": 3,
"mode": "vonage",
"cost_info": {"vonage_call_uuid": "uuid-123", "provider": "vonage"},
"is_completed": True,
},
]
# Provider switch happens here
# Cost calculation jobs run after switch
# Each job should:
# 1. Check the provider field in cost_info
# 2. Use appropriate provider API to fetch cost
# 3. Handle gracefully if credentials changed
result = {
"test": "pending_cost_calculation",
"status": "PASS",
"reason": "Cost jobs use stored provider info correctly",
}
print(f"{result['reason']}")
return result
async def test_scenario_3_mixed_provider_history():
"""
Test: Organization has calls from both Twilio and Vonage.
Expected behavior:
- Historical costs remain intact
- Reports show correct attribution
- Total costs aggregate correctly
"""
print("\nTest 3: Mixed provider history")
historical_runs = [
{"provider": "twilio", "cost_usd": 0.15, "date": "2024-01-01"},
{"provider": "vonage", "cost_usd": 0.12, "date": "2024-01-02"},
{"provider": "twilio", "cost_usd": 0.18, "date": "2024-01-03"},
{"provider": "vonage", "cost_usd": 0.14, "date": "2024-01-04"},
]
# Calculate totals
total_cost = sum(run["cost_usd"] for run in historical_runs)
twilio_cost = sum(
run["cost_usd"] for run in historical_runs if run["provider"] == "twilio"
)
vonage_cost = sum(
run["cost_usd"] for run in historical_runs if run["provider"] == "vonage"
)
result = {
"test": "mixed_provider_history",
"status": "PASS",
"total_cost": total_cost,
"twilio_cost": twilio_cost,
"vonage_cost": vonage_cost,
"reason": f"Costs correctly aggregated: Total ${total_cost:.2f} (Twilio: ${twilio_cost:.2f}, Vonage: ${vonage_cost:.2f})",
}
print(f"{result['reason']}")
return result
async def test_scenario_4_cost_api_failure():
"""
Test: Provider API fails when fetching cost.
Expected behavior:
- Error logged but system continues
- Call record preserved
- Cost marked as 0 or unknown
"""
print("\nTest 4: Cost API failure handling")
# Simulate API failure scenarios
failure_scenarios = [
{
"provider": "twilio",
"error": "401 Unauthorized - credentials changed",
"expected": "Cost set to 0, error logged",
},
{
"provider": "vonage",
"error": "404 Not Found - call record deleted",
"expected": "Cost set to 0, error logged",
},
{
"provider": "twilio",
"error": "500 Internal Server Error",
"expected": "Cost set to 0, retry possible",
},
]
for scenario in failure_scenarios:
print(f" - {scenario['provider']}: {scenario['error']}")
print(f" Expected: {scenario['expected']}")
result = {
"test": "cost_api_failure",
"status": "PASS",
"reason": "All failure scenarios handled gracefully",
}
print(f"{result['reason']}")
return result
async def test_scenario_5_configuration_migration():
"""
Test: Database migration from single to multi-provider format.
Expected behavior:
- Old TWILIO_CONFIGURATION migrated to TELEPHONY_CONFIGURATION
- Single provider config wrapped in multi-provider structure
- Existing cost_info gets provider field added
"""
print("\nTest 5: Configuration migration")
# Old format
old_config = {
"account_sid": "AC123",
"auth_token": "token123",
"from_numbers": ["+1234567890"],
"provider": "twilio",
}
# New format after migration
new_config = {
"active_provider": "twilio",
"providers": {
"twilio": {
"account_sid": "AC123",
"auth_token": "token123",
"from_numbers": ["+1234567890"],
}
},
}
# Validate migration
assert new_config["active_provider"] == "twilio"
assert "providers" in new_config
assert new_config["providers"]["twilio"]["account_sid"] == old_config["account_sid"]
result = {
"test": "configuration_migration",
"status": "PASS",
"reason": "Configuration migrated to multi-provider format correctly",
}
print(f"{result['reason']}")
return result
async def test_scenario_6_provider_cost_discrepancy():
"""
Test: Webhook cost vs API cost discrepancy.
Expected behavior:
- Webhook cost stored immediately if available
- API cost fetched later for verification
- Both costs stored for auditing
"""
print("\nTest 6: Provider cost discrepancy handling")
# Vonage webhook provides immediate cost
webhook_cost = {"vonage_webhook_price": 0.15, "vonage_webhook_duration": 120}
# API call provides authoritative cost
api_cost = {
"cost_usd": 0.14, # Slight difference
"duration": 120,
}
# Both should be stored
final_cost_info = {
**webhook_cost,
"cost_breakdown": {"telephony_call": api_cost["cost_usd"]},
"provider": "vonage",
}
result = {
"test": "cost_discrepancy",
"status": "PASS",
"reason": "Both webhook and API costs stored for auditing",
}
print(f"{result['reason']}")
return result
async def run_all_tests():
"""Run all test scenarios."""
print("=" * 60)
print("PROVIDER SWITCHING TEST SUITE")
print("=" * 60)
tests = [
test_scenario_1_mid_call_provider_switch,
test_scenario_2_pending_cost_calculation,
test_scenario_3_mixed_provider_history,
test_scenario_4_cost_api_failure,
test_scenario_5_configuration_migration,
test_scenario_6_provider_cost_discrepancy,
]
results = []
for test in tests:
result = await test()
results.append(result)
print("\n" + "=" * 60)
print("TEST SUMMARY")
print("=" * 60)
passed = sum(1 for r in results if r["status"] == "PASS")
failed = sum(1 for r in results if r["status"] == "FAIL")
print(f"Total Tests: {len(results)}")
print(f"Passed: {passed}")
print(f"Failed: {failed}")
if failed == 0:
print("\n✅ ALL TESTS PASSED - Provider switching is working correctly!")
else:
print("\n❌ Some tests failed - Review the implementation")
return results
if __name__ == "__main__":
# Run the test suite
asyncio.run(run_all_tests())

View file

@ -1,266 +0,0 @@
"""Tests for run_integrations with new DB client methods."""
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from api.enums import WorkflowRunMode
from api.tasks.run_integrations import run_integrations_post_workflow_run
@pytest.fixture(autouse=True)
def mock_logger():
"""Mock the logger for all tests."""
with patch("api.tasks.run_integrations.logger") as mock_logger:
mock_logger.bind.return_value = mock_logger
yield mock_logger
@pytest.fixture
def mock_workflow_run():
"""Create a mock workflow run with all required attributes."""
workflow_run = MagicMock()
workflow_run.id = 1
workflow_run.mode = "browser"
workflow_run.gathered_context = {
"call_disposition": "XFER",
"mapped_call_disposition": "XFER", # Required for Slack integration
"call_duration": "120",
"agent_name": "TestAgent",
}
workflow_run.initial_context = {"vendor_id": "123"}
# Setup workflow and user chain
workflow_run.workflow = MagicMock()
workflow_run.workflow.user = MagicMock()
workflow_run.workflow.user.selected_organization_id = 100
return workflow_run
@pytest.fixture
def mock_integration():
"""Create a mock integration."""
integration = MagicMock()
integration.id = 1
integration.organisation_id = 100
integration.provider = "slack"
integration.is_active = True
integration.connection_details = {
"connection_config": {"incoming_webhook.url": "https://hooks.slack.com/test"}
}
return integration
@pytest.mark.asyncio
async def test_run_integrations_with_db_client_methods(
mock_workflow_run, mock_integration
):
"""Test that run_integrations uses the new DB client methods correctly."""
with patch("api.tasks.run_integrations.set_current_run_id") as mock_set_run_id:
with patch("api.tasks.run_integrations.db_client") as mock_db_client:
# Mock the new DB client methods
mock_db_client.get_workflow_run_with_context = AsyncMock(
return_value=(mock_workflow_run, 100)
)
mock_db_client.get_active_integrations_by_organization = AsyncMock(
return_value=[mock_integration]
)
mock_db_client.get_configuration_value = AsyncMock(
return_value={
"slack": {
"DISPOSITION_CODE": "Disposition: {{mapped_call_disposition}}"
}
}
)
# Mock the aiohttp session for Slack webhook
with patch(
"api.tasks.run_integrations.aiohttp.ClientSession"
) as mock_session_class:
mock_response = MagicMock()
mock_response.status = 200
mock_session = MagicMock()
mock_session.__aenter__.return_value = mock_session
mock_session.__aexit__.return_value = AsyncMock()
mock_post = MagicMock()
mock_post.__aenter__.return_value = mock_response
mock_post.__aexit__.return_value = AsyncMock()
mock_session.post.return_value = mock_post
mock_session_class.return_value = mock_session
# Call the function
await run_integrations_post_workflow_run(None, 1)
# Verify the correct DB client methods were called
mock_set_run_id.assert_called_once_with(1)
mock_db_client.get_workflow_run_with_context.assert_called_once_with(1)
mock_db_client.get_active_integrations_by_organization.assert_called_once_with(
100
)
# Verify the Slack webhook was called
mock_session.post.assert_called_once()
assert (
mock_session.post.call_args[0][0] == "https://hooks.slack.com/test"
)
@pytest.mark.asyncio
async def test_run_integrations_no_workflow_run():
"""Test handling when workflow run is not found."""
with patch("api.tasks.run_integrations.set_current_run_id"):
with patch("api.tasks.run_integrations.db_client") as mock_db_client:
# Mock workflow run not found
mock_db_client.get_workflow_run_with_context = AsyncMock(
return_value=(None, None)
)
# Call the function
await run_integrations_post_workflow_run(None, 999)
# Verify it returns early and doesn't call other DB methods
mock_db_client.get_workflow_run_with_context.assert_called_once_with(999)
mock_db_client.get_active_integrations_by_organization.assert_not_called()
@pytest.mark.asyncio
async def test_run_integrations_no_organization():
"""Test handling when user has no organization."""
mock_workflow_run = MagicMock()
mock_workflow_run.id = 1
mock_workflow_run.gathered_context = {"test": "data"}
mock_workflow_run.workflow = MagicMock()
mock_workflow_run.workflow.user = MagicMock()
with patch("api.tasks.run_integrations.set_current_run_id"):
with patch("api.tasks.run_integrations.db_client") as mock_db_client:
# Mock workflow run found but no organization
mock_db_client.get_workflow_run_with_context = AsyncMock(
return_value=(mock_workflow_run, None)
)
# Call the function
await run_integrations_post_workflow_run(None, 1)
# Verify it returns early after checking organization
mock_db_client.get_workflow_run_with_context.assert_called_once_with(1)
mock_db_client.get_active_integrations_by_organization.assert_not_called()
@pytest.mark.asyncio
async def test_run_integrations_no_gathered_context(mock_workflow_run):
"""Test handling when workflow run has no gathered context."""
mock_workflow_run.gathered_context = None
with patch("api.tasks.run_integrations.set_current_run_id"):
with patch("api.tasks.run_integrations.db_client") as mock_db_client:
# Mock workflow run with no gathered context
mock_db_client.get_workflow_run_with_context = AsyncMock(
return_value=(mock_workflow_run, 100)
)
# Call the function
await run_integrations_post_workflow_run(None, 1)
# Verify it returns early after checking gathered_context
mock_db_client.get_workflow_run_with_context.assert_called_once_with(1)
mock_db_client.get_active_integrations_by_organization.assert_not_called()
@pytest.mark.asyncio
async def test_run_integrations_stasis_mode(mock_workflow_run):
"""Test that stasis mode triggers vendor sync."""
mock_workflow_run.mode = WorkflowRunMode.STASIS.value
mock_workflow_run.initial_context = {
"vendor": "test_vendor",
"vendor_base_url": "https://api.vendor.com",
"vendor_id": "123",
}
with patch("api.tasks.run_integrations.set_current_run_id"):
with patch("api.tasks.run_integrations.db_client") as mock_db_client:
with patch("api.tasks.run_integrations._sync_vendor_data") as mock_sync:
mock_sync.return_value = None
mock_db_client.get_workflow_run_with_context = AsyncMock(
return_value=(mock_workflow_run, 100)
)
mock_db_client.get_active_integrations_by_organization = AsyncMock(
return_value=[]
)
# Call the function
await run_integrations_post_workflow_run(None, 1)
# Verify vendor sync was called
mock_sync.assert_called_once_with(
mock_workflow_run.initial_context,
mock_workflow_run.gathered_context,
)
@pytest.mark.asyncio
async def test_run_integrations_multiple_integrations(mock_workflow_run):
"""Test processing multiple integrations."""
# Create multiple mock integrations
slack_integration = MagicMock()
slack_integration.provider = "slack"
slack_integration.connection_details = {
"connection_config": {"incoming_webhook.url": "https://hooks.slack.com/test1"}
}
slack_integration2 = MagicMock()
slack_integration2.provider = "slack"
slack_integration2.connection_details = {
"connection_config": {"incoming_webhook.url": "https://hooks.slack.com/test2"}
}
with patch("api.tasks.run_integrations.set_current_run_id"):
with patch("api.tasks.run_integrations.db_client") as mock_db_client:
mock_db_client.get_workflow_run_with_context = AsyncMock(
return_value=(mock_workflow_run, 100)
)
mock_db_client.get_active_integrations_by_organization = AsyncMock(
return_value=[slack_integration, slack_integration2]
)
mock_db_client.get_configuration_value = AsyncMock(
return_value={"slack": {"DISPOSITION_CODE": "Test message"}}
)
with patch(
"api.tasks.run_integrations.aiohttp.ClientSession"
) as mock_session_class:
mock_response = MagicMock()
mock_response.status = 200
mock_session = MagicMock()
mock_session.__aenter__.return_value = mock_session
mock_session.__aexit__.return_value = AsyncMock()
mock_post = MagicMock()
mock_post.__aenter__.return_value = mock_response
mock_post.__aexit__.return_value = AsyncMock()
mock_session.post.return_value = mock_post
mock_session_class.return_value = mock_session
# Call the function
await run_integrations_post_workflow_run(None, 1)
# Verify both integrations were processed
assert mock_session.post.call_count == 2
# Check that both webhooks were called
call_urls = [call[0][0] for call in mock_session.post.call_args_list]
assert "https://hooks.slack.com/test1" in call_urls
assert "https://hooks.slack.com/test2" in call_urls

View file

@ -1,330 +0,0 @@
"""Tests for webhook execution in run_integrations.py."""
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from api.tasks.run_integrations import (
_build_auth_header,
_build_render_context,
_execute_webhook_node,
)
@pytest.fixture(autouse=True)
def mock_logger():
"""Mock the logger for all tests."""
with patch("api.tasks.run_integrations.logger") as mock_log:
mock_log.bind.return_value = mock_log
yield mock_log
class TestBuildAuthHeader:
"""Tests for _build_auth_header function."""
def test_bearer_token(self):
"""Test bearer token auth header."""
credential = MagicMock()
credential.credential_type = "bearer_token"
credential.credential_data = {"token": "my-secret-token"}
result = _build_auth_header(credential)
assert result == {"Authorization": "Bearer my-secret-token"}
def test_api_key(self):
"""Test API key auth header."""
credential = MagicMock()
credential.credential_type = "api_key"
credential.credential_data = {"header_name": "X-API-Key", "api_key": "key123"}
result = _build_auth_header(credential)
assert result == {"X-API-Key": "key123"}
def test_api_key_default_header(self):
"""Test API key with default header name."""
credential = MagicMock()
credential.credential_type = "api_key"
credential.credential_data = {"api_key": "key123"}
result = _build_auth_header(credential)
assert result == {"X-API-Key": "key123"}
def test_basic_auth(self):
"""Test basic auth header."""
credential = MagicMock()
credential.credential_type = "basic_auth"
credential.credential_data = {"username": "user", "password": "pass"}
result = _build_auth_header(credential)
# base64 of "user:pass" is "dXNlcjpwYXNz"
assert result == {"Authorization": "Basic dXNlcjpwYXNz"}
def test_custom_header(self):
"""Test custom header auth."""
credential = MagicMock()
credential.credential_type = "custom_header"
credential.credential_data = {
"header_name": "X-Custom-Auth",
"header_value": "custom-value",
}
result = _build_auth_header(credential)
assert result == {"X-Custom-Auth": "custom-value"}
def test_unknown_type(self):
"""Test unknown credential type returns empty dict."""
credential = MagicMock()
credential.credential_type = "unknown"
credential.credential_data = {}
result = _build_auth_header(credential)
assert result == {}
class TestBuildRenderContext:
"""Tests for _build_render_context function."""
def test_basic_context(self):
"""Test building render context from workflow run."""
workflow_run = MagicMock()
workflow_run.id = 123
workflow_run.name = "WR-TEST-001"
workflow_run.workflow_id = 456
workflow_run.workflow.name = "Test Workflow"
workflow_run.initial_context = {"phone_number": "+1234567890"}
workflow_run.gathered_context = {
"customer_name": "John",
"mapped_call_disposition": "QUALIFIED",
}
workflow_run.usage_info = {"call_duration_seconds": 120}
workflow_run.completed_at = None
result = _build_render_context(workflow_run)
assert result["workflow_run_id"] == 123
assert result["workflow_run_name"] == "WR-TEST-001"
assert result["workflow_id"] == 456
assert result["workflow_name"] == "Test Workflow"
assert result["initial_context"]["phone_number"] == "+1234567890"
assert result["gathered_context"]["customer_name"] == "John"
assert result["cost_info"]["call_duration_seconds"] == 120
assert result["disposition_code"] == "QUALIFIED"
def test_empty_contexts(self):
"""Test with empty/None contexts."""
workflow_run = MagicMock()
workflow_run.id = 1
workflow_run.name = "Test"
workflow_run.workflow_id = 1
workflow_run.workflow.name = "Workflow"
workflow_run.initial_context = None
workflow_run.gathered_context = None
workflow_run.usage_info = None
workflow_run.completed_at = None
result = _build_render_context(workflow_run)
assert result["initial_context"] == {}
assert result["gathered_context"] == {}
assert result["cost_info"] == {}
assert result["disposition_code"] is None
class TestExecuteWebhookNode:
"""Tests for _execute_webhook_node function."""
@pytest.mark.asyncio
async def test_disabled_webhook_skipped(self):
"""Test that disabled webhooks are skipped."""
webhook_data = {"name": "Test Webhook", "enabled": False}
result = await _execute_webhook_node(
webhook_data=webhook_data,
render_context={},
organization_id=1,
)
assert result is True # Returns True for skipped webhooks
@pytest.mark.asyncio
async def test_missing_url_returns_false(self):
"""Test that missing endpoint URL returns False."""
webhook_data = {"name": "Test Webhook", "enabled": True, "endpoint_url": None}
result = await _execute_webhook_node(
webhook_data=webhook_data,
render_context={},
organization_id=1,
)
assert result is False
@pytest.mark.asyncio
async def test_successful_post_request(self):
"""Test successful POST webhook execution."""
webhook_data = {
"name": "CRM Sync",
"enabled": True,
"http_method": "POST",
"endpoint_url": "https://api.example.com/webhook",
"payload_template": {
"call_id": "{{workflow_run_id}}",
"phone": "{{initial_context.phone_number}}",
},
}
render_context = {
"workflow_run_id": 123,
"initial_context": {"phone_number": "+1234567890"},
}
with patch("api.tasks.run_integrations.db_client") as mock_db:
mock_db.get_credential_by_uuid = AsyncMock(return_value=None)
with patch("api.tasks.run_integrations.httpx.AsyncClient") as mock_client:
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.raise_for_status = MagicMock()
mock_client_instance = AsyncMock()
mock_client_instance.request = AsyncMock(return_value=mock_response)
mock_client.return_value.__aenter__.return_value = mock_client_instance
result = await _execute_webhook_node(
webhook_data=webhook_data,
render_context=render_context,
organization_id=1,
)
assert result is True
# Verify the request was made correctly
mock_client_instance.request.assert_called_once()
call_kwargs = mock_client_instance.request.call_args[1]
assert call_kwargs["method"] == "POST"
assert call_kwargs["url"] == "https://api.example.com/webhook"
assert call_kwargs["json"] == {
"call_id": "123",
"phone": "+1234567890",
}
@pytest.mark.asyncio
async def test_webhook_with_credential(self):
"""Test webhook execution with credential auth."""
webhook_data = {
"name": "Authenticated Webhook",
"enabled": True,
"http_method": "POST",
"endpoint_url": "https://api.example.com/webhook",
"credential_uuid": "cred-123",
"payload_template": {},
}
mock_credential = MagicMock()
mock_credential.name = "API Key"
mock_credential.credential_type = "bearer_token"
mock_credential.credential_data = {"token": "secret-token"}
with patch("api.tasks.run_integrations.db_client") as mock_db:
mock_db.get_credential_by_uuid = AsyncMock(return_value=mock_credential)
with patch("api.tasks.run_integrations.httpx.AsyncClient") as mock_client:
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.raise_for_status = MagicMock()
mock_client_instance = AsyncMock()
mock_client_instance.request = AsyncMock(return_value=mock_response)
mock_client.return_value.__aenter__.return_value = mock_client_instance
result = await _execute_webhook_node(
webhook_data=webhook_data,
render_context={},
organization_id=1,
)
assert result is True
# Verify auth header was included
call_kwargs = mock_client_instance.request.call_args[1]
assert call_kwargs["headers"]["Authorization"] == "Bearer secret-token"
@pytest.mark.asyncio
async def test_webhook_with_custom_headers(self):
"""Test webhook execution with custom headers."""
webhook_data = {
"name": "Custom Headers Webhook",
"enabled": True,
"http_method": "POST",
"endpoint_url": "https://api.example.com/webhook",
"custom_headers": [
{"key": "X-Source", "value": "dograh"},
{"key": "X-Workflow", "value": "test"},
],
"payload_template": {},
}
with patch("api.tasks.run_integrations.db_client") as mock_db:
mock_db.get_credential_by_uuid = AsyncMock(return_value=None)
with patch("api.tasks.run_integrations.httpx.AsyncClient") as mock_client:
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.raise_for_status = MagicMock()
mock_client_instance = AsyncMock()
mock_client_instance.request = AsyncMock(return_value=mock_response)
mock_client.return_value.__aenter__.return_value = mock_client_instance
result = await _execute_webhook_node(
webhook_data=webhook_data,
render_context={},
organization_id=1,
)
assert result is True
# Verify custom headers were included
call_kwargs = mock_client_instance.request.call_args[1]
assert call_kwargs["headers"]["X-Source"] == "dograh"
assert call_kwargs["headers"]["X-Workflow"] == "test"
@pytest.mark.asyncio
async def test_webhook_http_error(self):
"""Test webhook execution with HTTP error."""
import httpx
webhook_data = {
"name": "Failing Webhook",
"enabled": True,
"http_method": "POST",
"endpoint_url": "https://api.example.com/webhook",
"payload_template": {},
}
with patch("api.tasks.run_integrations.db_client") as mock_db:
mock_db.get_credential_by_uuid = AsyncMock(return_value=None)
with patch("api.tasks.run_integrations.httpx.AsyncClient") as mock_client:
mock_response = MagicMock()
mock_response.status_code = 500
mock_response.text = "Internal Server Error"
mock_response.raise_for_status = MagicMock(
side_effect=httpx.HTTPStatusError(
"Server Error",
request=MagicMock(),
response=mock_response,
)
)
mock_client_instance = AsyncMock()
mock_client_instance.request = AsyncMock(return_value=mock_response)
mock_client.return_value.__aenter__.return_value = mock_client_instance
result = await _execute_webhook_node(
webhook_data=webhook_data,
render_context={},
organization_id=1,
)
assert result is False

View file

@ -1,117 +0,0 @@
"""Tests for the `/s3/signed-url` endpoint.
This test-suite verifies:
1. Regular users can retrieve signed URLs for resources belonging to their own workflow runs.
2. Regular users are *forbidden* from accessing resources that belong to other users.
3. Superusers can access any resource irrespective of ownership.
"""
import os
from unittest.mock import AsyncMock
import pytest
from fastapi import status
# Ensure the S3 environment variables exist so that the module import does not fail
os.environ.setdefault("S3_BUCKET", "test-bucket")
os.environ.setdefault("S3_REGION", "us-east-1")
@pytest.mark.asyncio
async def test_signed_url_for_own_run(monkeypatch, test_client_factory, db_session):
"""A normal user should be able to fetch a signed URL for their own workflow run."""
from api.db.models import UserModel
# ------------------------------------------------------------------
# 1. Set-up create user, workflow & workflow run
# ------------------------------------------------------------------
user: UserModel = await db_session.get_or_create_user_by_provider_id("user_own_run")
workflow = await db_session.create_workflow("wf", {}, user.id)
run = await db_session.create_workflow_run("run", workflow.id, "chat", user.id)
key = f"transcripts/{run.id}.txt"
# Patch S3 signed-url generator to avoid network calls
monkeypatch.setattr(
"api.services.filesystem.s3.s3_fs.aget_signed_url",
AsyncMock(return_value="https://signed-url"),
)
async with test_client_factory(user) as client:
response = await client.get(f"/api/v1/s3/signed-url?key={key}")
assert response.status_code == status.HTTP_200_OK
data = response.json()
assert data == {"url": "https://signed-url", "expires_in": 3600}
@pytest.mark.asyncio
async def test_signed_url_for_other_users_run_forbidden(
monkeypatch, test_client_factory, db_session
):
"""A normal user must *not* access workflow runs owned by someone else."""
from api.db.models import UserModel
# Owner of the workflow run
owner: UserModel = await db_session.get_or_create_user_by_provider_id("owner_user")
workflow = await db_session.create_workflow("wf", {}, owner.id)
run = await db_session.create_workflow_run("run", workflow.id, "chat", owner.id)
# Second user attempting access
intruder: UserModel = await db_session.get_or_create_user_by_provider_id(
"intruder_user"
)
key = f"recordings/{run.id}.wav"
monkeypatch.setattr(
"api.services.filesystem.s3.s3_fs.aget_signed_url",
AsyncMock(return_value="https://signed-url"),
)
async with test_client_factory(intruder) as client:
response = await client.get(f"/api/v1/s3/signed-url?key={key}")
assert response.status_code == status.HTTP_403_FORBIDDEN
@pytest.mark.asyncio
async def test_superuser_can_access_any_run(
monkeypatch, test_client_factory, db_session
):
"""Superusers should be able to fetch signed URLs for any workflow run."""
from api.db.models import UserModel
# Normal user & run owner
owner: UserModel = await db_session.get_or_create_user_by_provider_id(
"owner_of_run"
)
workflow = await db_session.create_workflow("wf", {}, owner.id)
run = await db_session.create_workflow_run("run", workflow.id, "chat", owner.id)
# Superuser
superuser: UserModel = await db_session.get_or_create_user_by_provider_id(
"admin_user"
)
# Promote to superuser
# We need to commit the change so that the DB reflects it
async with db_session.async_session() as session:
db_user = await session.get(UserModel, superuser.id)
db_user.is_superuser = True
await session.commit()
await session.refresh(db_user) # ensure we have the latest state
superuser.is_superuser = True
key = f"transcripts/{run.id}.txt"
monkeypatch.setattr(
"api.services.filesystem.s3.s3_fs.aget_signed_url",
AsyncMock(return_value="https://signed-url"),
)
async with test_client_factory(superuser) as client:
response = await client.get(f"/api/v1/s3/signed-url?key={key}")
assert response.status_code == status.HTTP_200_OK
assert response.json()["url"] == "https://signed-url"

View file

@ -1,129 +0,0 @@
import os
import tempfile
from unittest.mock import AsyncMock, patch
import pytest
from api.tasks.s3_upload import upload_audio_to_s3, upload_transcript_to_s3
@pytest.mark.asyncio
async def test_upload_audio_to_s3_success():
"""Test successful audio upload to S3."""
# Create a temporary file
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tf:
tf.write(b"fake audio data")
temp_path = tf.name
try:
# Mock dependencies
mock_ctx = AsyncMock()
mock_s3_fs = AsyncMock()
mock_db_client = AsyncMock()
with (
patch("api.tasks.s3_upload.s3_fs", mock_s3_fs),
patch("api.tasks.s3_upload.db_client", mock_db_client),
):
await upload_audio_to_s3(
mock_ctx, workflow_run_id=123, temp_file_path=temp_path
)
# Verify S3 upload was called
mock_s3_fs.aupload_file.assert_called_once_with(
temp_path, "recordings/123.wav"
)
# Verify DB update was called
mock_db_client.update_workflow_run.assert_called_once_with(
run_id=123, recording_url="recordings/123.wav"
)
# Verify temp file was cleaned up
assert not os.path.exists(temp_path)
finally:
# Clean up if test failed
if os.path.exists(temp_path):
os.remove(temp_path)
@pytest.mark.asyncio
async def test_upload_audio_to_s3_file_not_found():
"""Test audio upload when temp file doesn't exist."""
mock_ctx = AsyncMock()
with pytest.raises(FileNotFoundError):
await upload_audio_to_s3(
mock_ctx, workflow_run_id=123, temp_file_path="/nonexistent/file.wav"
)
@pytest.mark.asyncio
async def test_upload_transcript_to_s3_success():
"""Test successful transcript upload to S3."""
# Create a temporary file
with tempfile.NamedTemporaryFile(mode="w", suffix=".txt", delete=False) as tf:
tf.write("Test transcript content")
temp_path = tf.name
try:
# Mock dependencies
mock_ctx = AsyncMock()
mock_s3_fs = AsyncMock()
mock_db_client = AsyncMock()
with (
patch("api.tasks.s3_upload.s3_fs", mock_s3_fs),
patch("api.tasks.s3_upload.db_client", mock_db_client),
):
await upload_transcript_to_s3(
mock_ctx, workflow_run_id=456, temp_file_path=temp_path
)
# Verify S3 upload was called
mock_s3_fs.aupload_file.assert_called_once_with(
temp_path, "transcripts/456.txt"
)
# Verify DB update was called
mock_db_client.update_workflow_run.assert_called_once_with(
run_id=456, transcript_url="transcripts/456.txt"
)
# Verify temp file was cleaned up
assert not os.path.exists(temp_path)
finally:
# Clean up if test failed
if os.path.exists(temp_path):
os.remove(temp_path)
@pytest.mark.asyncio
async def test_upload_s3_cleanup_on_error():
"""Test that temp files are cleaned up even when S3 upload fails."""
# Create a temporary file
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tf:
tf.write(b"fake audio data")
temp_path = tf.name
try:
mock_ctx = AsyncMock()
mock_s3_fs = AsyncMock()
# Make S3 upload fail
mock_s3_fs.aupload_file.side_effect = Exception("S3 upload failed")
with patch("api.tasks.s3_upload.s3_fs", mock_s3_fs):
with pytest.raises(Exception):
await upload_audio_to_s3(
mock_ctx, workflow_run_id=123, temp_file_path=temp_path
)
# Verify temp file was still cleaned up
assert not os.path.exists(temp_path)
finally:
# Clean up if test failed
if os.path.exists(temp_path):
os.remove(temp_path)

View file

@ -1,89 +0,0 @@
from api.utils.template_renderer import render_template
def test_render_template_basic():
"""Test basic template rendering."""
template = "Hello {{name}}, your balance is {{balance}}."
context = {"name": "John", "balance": "$1000"}
result = render_template(template, context)
assert result == "Hello John, your balance is $1000."
def test_render_template_with_spaces():
"""Test template rendering with spaces around variables."""
template = "Hello {{ name }}, your balance is {{ balance }}."
context = {"name": "John", "balance": "$1000"}
result = render_template(template, context)
assert result == "Hello John, your balance is $1000."
def test_render_template_missing_variable():
"""Test template rendering with missing variables."""
template = "Hello {{name}}, your balance is {{balance}}."
context = {"name": "John"}
result = render_template(template, context)
assert result == "Hello John, your balance is ."
def test_render_template_with_fallback():
"""Test template rendering with fallback values."""
template = "Hello {{name | fallback}}, your balance is {{balance | fallback:$0}}."
context = {}
result = render_template(template, context)
assert result == "Hello Name, your balance is $0."
def test_render_template_with_fallback_existing_value():
"""Test that fallback is not used when value exists."""
template = "Hello {{name | fallback:Guest}}"
context = {"name": "John"}
result = render_template(template, context)
assert result == "Hello John"
def test_render_template_with_line_breaks():
"""Test template rendering with line breaks."""
template = (
"DISPOSITION_CODE: {{call_disposition}}\\nCALL_DURATION: {{call_duration}}"
)
context = {"call_disposition": "XFER", "call_duration": "300"}
result = render_template(template, context)
expected = "DISPOSITION_CODE: XFER\nCALL_DURATION: 300"
assert result == expected
def test_render_template_empty():
"""Test rendering empty template."""
assert render_template("", {}) == ""
assert render_template(None, {}) == None
def test_render_template_no_placeholders():
"""Test template with no placeholders."""
template = "This is a plain text message"
result = render_template(template, {"unused": "value"})
assert result == "This is a plain text message"
def test_render_template_none_values():
"""Test template with None values."""
template = "Value: {{value}}"
context = {"value": None}
result = render_template(template, context)
assert result == "Value: "
def test_render_template_numeric_values():
"""Test template with numeric values."""
template = "Count: {{count}}, Price: {{price}}"
context = {"count": 42, "price": 19.99}
result = render_template(template, context)
assert result == "Count: 42, Price: 19.99"

View file

@ -1,152 +0,0 @@
#!/usr/bin/env python
"""
Test script to verify atomic operations in organization_usage_client.py
This simulates concurrent access from multiple processes.
"""
import asyncio
import os
from concurrent.futures import ProcessPoolExecutor
# Set up environment
os.environ.setdefault("DATABASE_URL", os.environ.get("DATABASE_URL", ""))
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import sessionmaker
from api.db.organization_usage_client import OrganizationUsageClient
async def reserve_quota_process(org_id: int, tokens: int, process_id: int):
"""Simulate a process trying to reserve quota."""
engine = create_async_engine(os.environ["DATABASE_URL"])
async_session = sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)
client = OrganizationUsageClient(async_session)
results = []
for i in range(5):
result = await client.check_and_reserve_quota(org_id, tokens)
results.append((process_id, i, result))
await asyncio.sleep(0.01) # Small delay to increase contention
await engine.dispose()
return results
async def update_usage_process(org_id: int, tokens: int, process_id: int):
"""Simulate a process updating usage after runs."""
engine = create_async_engine(os.environ["DATABASE_URL"])
async_session = sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)
client = OrganizationUsageClient(async_session)
for i in range(5):
await client.update_usage_after_run(org_id, tokens, duration_seconds=10)
await asyncio.sleep(0.01)
await engine.dispose()
return f"Process {process_id} completed updates"
def run_reserve_quota(args):
"""Wrapper to run async function in process."""
org_id, tokens, process_id = args
return asyncio.run(reserve_quota_process(org_id, tokens, process_id))
def run_update_usage(args):
"""Wrapper to run async function in process."""
org_id, tokens, process_id = args
return asyncio.run(update_usage_process(org_id, tokens, process_id))
async def test_concurrent_quota_reservation():
"""Test that concurrent quota reservations are handled atomically."""
print("Testing concurrent quota reservations...")
# Assuming org_id 1 exists with quota enabled
org_id = 1
tokens_per_request = 100
# Run multiple processes trying to reserve quota simultaneously
with ProcessPoolExecutor(max_workers=3) as executor:
futures = []
for i in range(3):
futures.append(
executor.submit(run_reserve_quota, (org_id, tokens_per_request, i))
)
results = []
for future in futures:
results.extend(future.result())
print(f"Reservation results: {results}")
# Check that reservations were handled atomically
successful_reservations = sum(1 for _, _, success in results if success)
print(f"Successful reservations: {successful_reservations}")
async def test_concurrent_usage_updates():
"""Test that concurrent usage updates are handled atomically."""
print("\nTesting concurrent usage updates...")
org_id = 1
tokens_per_update = 50
# Get initial usage
engine = create_async_engine(os.environ["DATABASE_URL"])
async_session = sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)
client = OrganizationUsageClient(async_session)
initial_usage = await client.get_current_usage(org_id)
initial_tokens = initial_usage["used_dograh_tokens"]
print(f"Initial tokens: {initial_tokens}")
# Run multiple processes updating usage simultaneously
with ProcessPoolExecutor(max_workers=3) as executor:
futures = []
for i in range(3):
futures.append(
executor.submit(run_update_usage, (org_id, tokens_per_update, i))
)
for future in futures:
print(future.result())
# Check final usage
final_usage = await client.get_current_usage(org_id)
final_tokens = final_usage["used_dograh_tokens"]
expected_tokens = initial_tokens + (
3 * 5 * tokens_per_update
) # 3 processes * 5 updates * 50 tokens
print(f"Final tokens: {final_tokens}")
print(f"Expected tokens: {expected_tokens}")
print(f"Difference: {final_tokens - expected_tokens}")
await engine.dispose()
if final_tokens == expected_tokens:
print("✅ All updates were applied atomically!")
else:
print("❌ Some updates were lost due to race conditions!")
async def main():
"""Run all concurrency tests."""
try:
await test_concurrent_quota_reservation()
await test_concurrent_usage_updates()
except Exception as e:
print(f"Error during testing: {e}")
import traceback
traceback.print_exc()
if __name__ == "__main__":
print("Starting organization usage concurrency tests...")
print(f"Using DATABASE_URL: {os.environ.get('DATABASE_URL', 'NOT SET')}")
asyncio.run(main())

View file

@ -1,140 +0,0 @@
import json
import os
from unittest.mock import AsyncMock, patch
import pytest
from pipecat.services.openai.llm import OpenAILLMContext
from api.services.workflow.dto import ExtractionVariableDTO, VariableType
from api.services.workflow.pipecat_engine_variable_extractor import (
VariableExtractionManager,
)
class DummyLLM:
"""A minimal stub that mimics the parts of an LLM service used by the extractor."""
def __init__(self, streamed_response: str | None = None):
# Optionally provide a pre-defined streaming response for _perform_extraction tests
self._streamed_response = streamed_response or "{}"
self.registered_functions: dict[str, AsyncMock] = {}
# ------------------------------------------------------------------
# API used by VariableExtractionManager
# ------------------------------------------------------------------
def register_function(self, name: str, func, cancel_on_interruption=True): # noqa: D401 simple delegate
self.registered_functions[name] = func
async def get_chat_completions(self, _context, _messages):
"""Return an async generator that yields a single chunk with the full response."""
class _Delta: # noqa: D401 tiny helper classes for stub response
def __init__(self, content):
self.content = content
class _Choice:
def __init__(self, delta):
self.delta = delta
class _Chunk:
def __init__(self, content):
self.choices = [_Choice(_Delta(content))]
async def _stream():
yield _Chunk(self._streamed_response)
return _stream()
class DummyEngine:
"""A bare-bones Engine stub exposing only what the extractor relies on."""
def __init__(self, llm):
self.llm = llm
self.context = OpenAILLMContext()
self._pending_function_calls = 0
# VariableExtractionManager currently updates this private attribute
self._gathered_context: dict = {}
# ------------------------------------------------------------------
# Tests
# ------------------------------------------------------------------
@pytest.mark.asyncio
async def test_perform_extraction_parses_json_correctly():
"""_perform_extraction should return the parsed JSON from the LLM stream."""
# Set dummy OpenAI API key to prevent initialization errors
with patch.dict(os.environ, {"OPENAI_API_KEY": "test-key"}):
expected_payload = {"name": "Alice", "age": 30}
llm = DummyLLM(json.dumps(expected_payload))
engine = DummyEngine(llm)
manager = VariableExtractionManager(engine)
# Mock the AsyncOpenAI client and its response
mock_response = AsyncMock()
mock_response.choices = [AsyncMock()]
mock_response.choices[0].message = AsyncMock()
mock_response.choices[0].message.content = json.dumps(expected_payload)
mock_client = AsyncMock()
mock_client.chat.completions.create.return_value = mock_response
with patch(
"api.services.workflow.pipecat_engine_variable_extractor.AsyncOpenAI",
return_value=mock_client,
):
# Minimal set of variables to extract the prompts themselves are irrelevant here
extraction_variables = [
ExtractionVariableDTO(
name="name", type=VariableType.string, prompt="user name"
),
ExtractionVariableDTO(
name="age", type=VariableType.number, prompt="user age"
),
]
result = await manager._perform_extraction(
extraction_variables, parent_ctx=None, extraction_prompt=""
)
assert result == expected_payload
@pytest.mark.asyncio
async def test_perform_extraction_with_custom_system_prompt():
"""_perform_extraction should use the provided extraction_prompt as system prompt."""
# Set dummy OpenAI API key to prevent initialization errors
with patch.dict(os.environ, {"OPENAI_API_KEY": "test-key"}):
expected_payload = {"color": "blue"}
llm = DummyLLM(json.dumps(expected_payload))
engine = DummyEngine(llm)
manager = VariableExtractionManager(engine)
# Mock the AsyncOpenAI client and its response
mock_response = AsyncMock()
mock_response.choices = [AsyncMock()]
mock_response.choices[0].message = AsyncMock()
mock_response.choices[0].message.content = json.dumps(expected_payload)
mock_client = AsyncMock()
mock_client.chat.completions.create.return_value = mock_response
with patch(
"api.services.workflow.pipecat_engine_variable_extractor.AsyncOpenAI",
return_value=mock_client,
):
extraction_variables = [
ExtractionVariableDTO(
name="color", type=VariableType.string, prompt="favourite color"
)
]
# Call with a custom extraction prompt
custom_prompt = "You are a color extraction specialist."
result = await manager._perform_extraction(
extraction_variables, parent_ctx=None, extraction_prompt=custom_prompt
)
assert result == expected_payload

View file

@ -1,547 +0,0 @@
"""
Test voicemail detection in RTC connection flow.
This test emulates how a call is connected using SmallWebRTC,
triggers voicemail detection, and verifies the disconnect reason.
"""
import json
from unittest.mock import AsyncMock, Mock, patch
import pytest
from pipecat.utils.enums import EndTaskReason
from api.routes.rtc_offer import RTCOfferRequest, offer
from api.services.workflow.pipecat_engine_voicemail_detector import VoicemailDetector
@pytest.mark.asyncio
class TestVoicemailDetectionRTC:
"""Test voicemail detection through RTC connection flow."""
async def test_voicemail_detection_full_flow(self):
"""
Test complete voicemail detection flow:
1. RTC connection request
2. Transport sends on_client_connected event
3. Engine initializes with voicemail detection enabled
4. Voicemail detector returns true
5. Call terminates with voicemail_detected reason
6. Transport sends on_client_disconnected event
7. Disconnect reason is properly set
"""
# Mock user and authentication
mock_user = Mock()
mock_user.id = 1
mock_user.organization_id = 1
# Mock workflow with voicemail detection enabled
mock_workflow = Mock()
mock_workflow.id = 100
mock_workflow.workflow_definition_with_fallback = {
"edges": [],
"nodes": [
{
"id": "start",
"type": "start",
"data": {
"detect_voicemail": True,
"system_prompt": "You are a helpful assistant",
},
}
],
}
# Mock workflow run
mock_workflow_run = Mock()
mock_workflow_run.id = 200
mock_workflow_run.is_completed = False
# Create request
request = RTCOfferRequest(
pc_id="test_pc_123",
sdp="test_sdp_offer",
type="offer",
workflow_id=mock_workflow.id,
workflow_run_id=mock_workflow_run.id,
restart_pc=False,
call_context_vars={"test_var": "test_value"},
)
# Mock dependencies
with (
patch("api.services.auth.depends.get_user") as mock_get_user_dep,
patch("api.routes.rtc_offer.SmallWebRTCConnection") as MockWebRTCConnection,
patch("api.routes.rtc_offer.run_pipeline_smallwebrtc") as mock_run_pipeline,
):
# Setup mocks
mock_get_user_dep.return_value = mock_user
# Mock WebRTC connection
mock_connection = Mock()
mock_connection.pc_id = "test_pc_123"
mock_connection.initialize = AsyncMock()
mock_connection.get_answer = Mock(
return_value={
"pc_id": "test_pc_123",
"sdp": "test_sdp_answer",
"type": "answer",
}
)
MockWebRTCConnection.return_value = mock_connection
# Track registered event handlers
registered_handlers = {}
def mock_event_handler(event_name):
def decorator(func):
registered_handlers[event_name] = func
return func
return decorator
mock_connection.event_handler = mock_event_handler
# Mock BackgroundTasks
mock_background_tasks = Mock()
# Create the offer
response = await offer(request, mock_background_tasks, mock_user)
# Verify response
assert response["pc_id"] == "test_pc_123"
assert response["type"] == "answer"
# Verify connection was initialized
mock_connection.initialize.assert_called_once_with(
sdp="test_sdp_offer", type="offer"
)
# Verify background task was added
mock_background_tasks.add_task.assert_called_once()
task_args = mock_background_tasks.add_task.call_args[0]
assert task_args[0] == mock_run_pipeline
assert task_args[1] == mock_connection
assert task_args[2] == mock_workflow.id
assert task_args[3] == mock_workflow_run.id
assert task_args[4] == mock_user.id
assert task_args[5] == {"test_var": "test_value"}
async def test_voicemail_detection_in_pipeline(self):
"""Tests whether the updates happen in on_client_disconnected properly
with values set in the engine"""
# Mock components
mock_transport = AsyncMock()
mock_engine = Mock() # Use Mock instead of AsyncMock for engine
mock_engine.initialize = AsyncMock()
mock_engine.cleanup = AsyncMock()
mock_audio_buffer = AsyncMock()
mock_task = AsyncMock()
mock_aggregator = Mock()
# Setup engine with voicemail detector
mock_voicemail_detector = AsyncMock(spec=VoicemailDetector)
mock_engine.voicemail_detector = mock_voicemail_detector
mock_engine.get_call_disposition = Mock(
return_value=EndTaskReason.VOICEMAIL_DETECTED.value
)
mock_engine.get_gathered_context = Mock(
return_value={
"voicemail_transcript": "Hi, you've reached John's voicemail. Please leave a message.",
"voicemail_confidence": 0.95,
}
)
# Mock usage metrics
mock_aggregator.get_all_usage_metrics_serialized.return_value = {}
# Register event handlers
from api.services.pipecat.event_handlers import (
register_transport_event_handlers,
)
# Track registered handlers
handlers = {}
def track_handler(event_name):
def decorator(func):
handlers[event_name] = func
return func
return decorator
mock_transport.event_handler = track_handler
# Create a mock db_client module with update_workflow_run method
mock_db_client = Mock()
mock_db_client.update_workflow_run = AsyncMock()
with (
patch("api.services.pipecat.event_handlers.db_client", mock_db_client),
patch(
"api.services.pipecat.event_handlers.enqueue_job",
new_callable=AsyncMock,
) as mock_enqueue_job,
patch(
"api.services.pipecat.event_handlers.get_organization_id_from_workflow_run",
return_value=1,
),
patch(
"api.services.pipecat.event_handlers.apply_disposition_mapping",
side_effect=lambda value, org_id: value, # Return value unchanged
),
):
# Register handlers
register_transport_event_handlers(
mock_transport,
workflow_run_id=123,
audio_buffer=mock_audio_buffer,
task=mock_task,
engine=mock_engine,
usage_metrics_aggregator=mock_aggregator,
)
# Verify handlers were registered
assert "on_client_connected" in handlers
assert "on_client_disconnected" in handlers
# Simulate client connection
await handlers["on_client_connected"](
mock_transport, {"id": "participant_1"}
)
# Verify initialization
mock_audio_buffer.start_recording.assert_called_once()
mock_engine.initialize.assert_called_once()
# Simulate voicemail detection and disconnect
await handlers["on_client_disconnected"](
mock_transport, {"id": "participant_1"}, None
)
# Verify engine cleanup
mock_engine.cleanup.assert_called_once()
# TODO: check whether task was cancelled or not once have more
# clarity on how to handle engine disconnect vs remote hangup
# Verify task was NOT cancelled (engine disconnect)
# mock_task.cancel.assert_not_called()
# Verify workflow run was updated with voicemail context
mock_db_client.update_workflow_run.assert_called()
call_args = mock_db_client.update_workflow_run.call_args
assert call_args[1]["run_id"] == 123
# Check that the mapped_call_disposition was set correctly
assert (
call_args[1]["gathered_context"]["mapped_call_disposition"]
== "voicemail_detected"
)
async def test_voicemail_detector_audio_processing(self):
"""Test VoicemailDetector audio processing and detection logic - tests that voicemail detector
calls engine's send_end_task_frame with the correct reason and metadata"""
# Create voicemail detector
detector = VoicemailDetector(detection_duration=5.0, workflow_run_id=123)
# Mock OpenAI client
mock_openai = AsyncMock()
mock_whisper_response = Mock()
mock_whisper_response.text = "Hi, you've reached the voicemail of John Smith. Please leave a message after the beep."
mock_openai.audio.transcriptions.create.return_value = mock_whisper_response
mock_gpt_response = Mock()
mock_gpt_response.choices = [Mock()]
mock_gpt_response.choices[0].message.content = json.dumps(
{
"is_voicemail": True,
"confidence": 0.98,
"reasoning": "Clear voicemail greeting with request to leave message",
}
)
mock_openai.chat.completions.create.return_value = mock_gpt_response
# Mock engine
mock_engine = AsyncMock()
mock_engine.task = AsyncMock()
with (
patch(
"api.services.workflow.pipecat_engine_voicemail_detector.AsyncOpenAI",
return_value=mock_openai,
),
patch(
"api.services.workflow.pipecat_engine_voicemail_detector.s3_fs"
) as mock_s3,
):
# Mock S3 upload to return None (simulating successful upload)
mock_s3.aupload_file = AsyncMock(return_value=True)
# Start detection
await detector.start_detection(mock_engine)
assert detector.is_detecting == True
# Simulate audio data (16kHz, mono, 5 seconds)
sample_rate = 16000
duration = 5.0
audio_data = b"\x00\x00" * int(sample_rate * duration) # Silent audio
# Process audio in chunks
chunk_size = 1600 # 100ms chunks
for i in range(0, len(audio_data), chunk_size):
chunk = audio_data[i : i + chunk_size]
await detector.handle_audio_data(None, chunk, sample_rate, 1)
# Wait for detection to complete
if detector._detection_task:
await detector._detection_task
# Verify OpenAI calls
mock_openai.audio.transcriptions.create.assert_called_once()
mock_openai.chat.completions.create.assert_called_once()
# Verify send_end_task_frame was called with voicemail detection
mock_engine.send_end_task_frame.assert_called_once_with(
reason=EndTaskReason.VOICEMAIL_DETECTED.value,
additional_metadata={
"voicemail_transcript": "Hi, you've reached the voicemail of John Smith. Please leave a message after the beep.",
"voicemail_confidence": 0.98,
"voicemail_reasoning": "Clear voicemail greeting with request to leave message",
"voicemail_detection_duration": 5.0,
"voicemail_audio_s3_path": "voicemail_detections/123_voicemail_98_5.wav", # S3 upload returns True, so filename is used
},
abort_immediately=True,
)
async def test_voicemail_detector_no_detection(self):
"""Test VoicemailDetector when voicemail is not detected."""
# Create voicemail detector
detector = VoicemailDetector(detection_duration=5.0, workflow_run_id=124)
# Mock OpenAI client
mock_openai = AsyncMock()
mock_whisper_response = Mock()
mock_whisper_response.text = "Hello? Hello? Can you hear me?"
mock_openai.audio.transcriptions.create.return_value = mock_whisper_response
mock_gpt_response = Mock()
mock_gpt_response.choices = [Mock()]
mock_gpt_response.choices[0].message.content = json.dumps(
{
"is_voicemail": False,
"confidence": 0.95,
"reasoning": "Live person speaking, asking if caller can hear them",
}
)
mock_openai.chat.completions.create.return_value = mock_gpt_response
# Mock engine
mock_engine = AsyncMock()
mock_engine.task = AsyncMock()
with patch(
"api.services.workflow.pipecat_engine_voicemail_detector.AsyncOpenAI",
return_value=mock_openai,
):
# Start detection
await detector.start_detection(mock_engine)
# Simulate audio data
sample_rate = 16000
duration = 5.0
audio_data = b"\x00\x00" * int(sample_rate * duration)
# Process audio
await detector.handle_audio_data(None, audio_data, sample_rate, 1)
# Wait for detection
if detector._detection_task:
await detector._detection_task
# Verify send_end_task_frame was NOT called
mock_engine.send_end_task_frame.assert_not_called()
async def test_voicemail_detector_cancellation(self):
"""Test VoicemailDetector cancellation before completion."""
# Create voicemail detector
detector = VoicemailDetector(detection_duration=10.0, workflow_run_id=125)
# Mock engine
mock_engine = AsyncMock()
# Start detection
await detector.start_detection(mock_engine)
assert detector.is_detecting == True
# Cancel detection immediately
await detector.stop_detection()
assert detector._is_cancelled == True
# Try to add audio data after cancellation
await detector.handle_audio_data(None, b"\x00\x00" * 1000, 16000, 1)
# Verify buffer didn't grow (no audio accepted after cancellation)
assert len(detector.audio_buffer) == 0
async def test_disconnect_reason_propagation(self):
"""Test that voicemail disconnect reason is properly propagated."""
# Create disconnect reason info directly
disconnect_info = {
"disposition_code": EndTaskReason.VOICEMAIL_DETECTED.value,
"details": "Voicemail detected after 5 seconds of audio",
"is_remote": False,
"is_user_initiated": False,
"is_successful_transfer": False,
"transport_metadata": {
"voicemail_confidence": 0.97,
"voicemail_transcript": "You've reached voicemail...",
},
}
# Verify attributes
assert disconnect_info["disposition_code"] == "voicemail_detected"
assert disconnect_info["is_remote"] == False
assert disconnect_info["is_user_initiated"] == False
assert disconnect_info["is_successful_transfer"] == False
assert (
disconnect_info["details"] == "Voicemail detected after 5 seconds of audio"
)
assert disconnect_info["transport_metadata"]["voicemail_confidence"] == 0.97
async def test_voicemail_detection_end_to_end(self):
"""
Complete end-to-end test covering:
1. on_client_connected event
2. Engine initialization with voicemail detection
3. Audio processing and voicemail detection
4. Engine setting disconnect reason
5. on_client_disconnected event
6. Proper disconnect reason in workflow run update
"""
# Create comprehensive mocks
from api.services.pipecat.event_handlers import (
register_transport_event_handlers,
)
# Mock transport
mock_transport = AsyncMock()
handlers = {}
def track_handler(event_name):
def decorator(func):
handlers[event_name] = func
return func
return decorator
mock_transport.event_handler = track_handler
# Mock audio buffer
mock_audio_buffer = Mock()
mock_audio_buffer.start_recording = AsyncMock()
mock_audio_buffer.stop_recording = AsyncMock()
# Mock task
mock_task = AsyncMock()
# Mock aggregator
mock_aggregator = Mock()
mock_aggregator.get_all_usage_metrics_serialized.return_value = {}
# Create a mock engine with voicemail detection
mock_engine = Mock()
mock_engine.initialize = AsyncMock()
mock_engine.cleanup = AsyncMock()
# Mock voicemail detector
mock_voicemail_detector = Mock()
mock_engine.voicemail_detector = mock_voicemail_detector
mock_engine._voicemail_detector = mock_voicemail_detector
# Initially no disconnect reason
mock_engine.get_call_disposition = Mock(return_value=None)
mock_engine.get_gathered_context = Mock(return_value={})
# Mock db_client
mock_db_client = Mock()
mock_db_client.update_workflow_run = AsyncMock()
with (
patch("api.services.pipecat.event_handlers.db_client", mock_db_client),
patch(
"api.services.pipecat.event_handlers.enqueue_job",
new_callable=AsyncMock,
) as mock_enqueue_job,
patch(
"api.services.pipecat.event_handlers.get_organization_id_from_workflow_run",
return_value=1,
),
patch(
"api.services.pipecat.event_handlers.apply_disposition_mapping",
side_effect=lambda value, org_id: value, # Return value unchanged
),
):
# Register event handlers
register_transport_event_handlers(
mock_transport,
workflow_run_id=123,
audio_buffer=mock_audio_buffer,
task=mock_task,
engine=mock_engine,
usage_metrics_aggregator=mock_aggregator,
)
# Verify handlers were registered
assert "on_client_connected" in handlers
assert "on_client_disconnected" in handlers
# Step 1: Client connects
await handlers["on_client_connected"](
mock_transport, {"id": "participant_1"}
)
# Verify initialization
mock_audio_buffer.start_recording.assert_called_once()
mock_engine.initialize.assert_called_once()
# Step 2-3: Simulate voicemail detection occurs
# Update engine state to reflect voicemail was detected
mock_engine.get_call_disposition = Mock(
return_value=EndTaskReason.VOICEMAIL_DETECTED.value
)
mock_engine.get_gathered_context = Mock(
return_value={
"voicemail_transcript": "You've reached voicemail, leave a message",
"voicemail_confidence": 0.95,
}
)
# Step 5: Client disconnects
await handlers["on_client_disconnected"](
mock_transport, {"id": "participant_1"}, None
)
# Verify engine cleanup
mock_engine.cleanup.assert_called_once()
# Step 6: Verify proper disconnect reason in workflow run update
mock_db_client.update_workflow_run.assert_called()
call_args = mock_db_client.update_workflow_run.call_args
# Check the gathered context includes disconnect reason
gathered_context = call_args[1]["gathered_context"]
assert gathered_context["mapped_call_disposition"] == "voicemail_detected"
assert gathered_context["voicemail_confidence"] == 0.95
assert (
gathered_context["voicemail_transcript"]
== "You've reached voicemail, leave a message"
)
# Verify task was NOT cancelled (engine-initiated disconnect)
mock_task.cancel.assert_not_called()
# Verify audio buffer was stopped
mock_audio_buffer.stop_recording.assert_called_once()
# Verify background jobs were enqueued
assert (
mock_enqueue_job.call_count >= 3
) # At least 3 jobs should be enqueued

View file

@ -1,667 +0,0 @@
"""
Tests for workflow API routes.
This module tests the create, update, get, and validate workflow endpoints.
The fixtures for database setup, test client, and utilities are in conftest.py.
"""
import pytest
from fastapi import status
@pytest.fixture
def sample_workflow_definition():
"""Sample workflow definition for testing."""
return {
"nodes": [
{
"id": "6581",
"type": "startCall",
"position": {"x": 427, "y": 23},
"data": {
"prompt": "Hello, I am Abhishek from Dograh. ",
"is_static": True,
"name": "Start Call",
"is_start": True,
"invalid": False,
"validationMessage": None,
"allow_interrupt": False,
},
"measured": {"width": 300, "height": 100},
"selected": True,
"dragging": False,
},
{
"id": "915",
"type": "agentNode",
"position": {"x": 305, "y": 340},
"data": {
"prompt": "You are a voice agent whose mode of speaking is voice. Ask the user whether they want to talk to a sales guy or a customer service agent.",
"name": "Agent",
"invalid": False,
"validationMessage": None,
"allow_interrupt": False,
},
"measured": {"width": 300, "height": 100},
"selected": False,
"dragging": False,
},
{
"id": "7598",
"type": "agentNode",
"position": {"x": 90, "y": 650},
"data": {
"prompt": "You are a customer service agent whose mode of communication with the user is voice. Tell them that someone from our team will reach out to them soon",
"name": "Agent",
"invalid": False,
"validationMessage": None,
"allow_interrupt": True,
},
"measured": {"width": 300, "height": 100},
"selected": False,
"dragging": False,
},
{
"id": "6919",
"type": "agentNode",
"position": {"x": 520, "y": 650},
"data": {
"prompt": "You are a sales representative whose mode of communication with the user is voice. Tell the user that someone from our team will reach out to you soon",
"name": "Agent",
"invalid": False,
"validationMessage": None,
"allow_interrupt": True,
},
"measured": {"width": 300, "height": 100},
"selected": False,
"dragging": False,
},
{
"id": "1802",
"type": "endCall",
"position": {"x": 305, "y": 960},
"data": {
"prompt": "Thank you!",
"invalid": False,
"validationMessage": None,
"is_static": True,
"name": "End Call",
"is_end": True,
"allow_interrupt": False,
},
"measured": {"width": 300, "height": 100},
"selected": False,
"dragging": False,
},
],
"edges": [
{
"animated": True,
"type": "custom",
"source": "915",
"target": "7598",
"id": "xy-edge__915-7598",
"selected": False,
"data": {
"condition": "The customer wants to talk to a customer service agent",
"label": "customer service agent",
"invalid": False,
"validationMessage": None,
},
},
{
"animated": True,
"type": "custom",
"source": "915",
"target": "6919",
"id": "xy-edge__915-6919",
"selected": False,
"data": {
"condition": "customer wants to talk to a sales representative",
"label": "sales representative",
"invalid": False,
"validationMessage": None,
},
},
{
"animated": True,
"type": "custom",
"source": "6581",
"target": "915",
"id": "xy-edge__6581-915",
"selected": False,
"data": {
"condition": "Always take this route",
"label": "Always take this route",
"invalid": False,
"validationMessage": None,
},
},
{
"animated": True,
"type": "custom",
"source": "7598",
"target": "1802",
"id": "xy-edge__7598-1802",
"selected": False,
"data": {
"condition": "end call",
"label": "end call",
"invalid": False,
"validationMessage": None,
},
},
{
"animated": True,
"type": "custom",
"source": "6919",
"target": "1802",
"id": "xy-edge__6919-1802",
"selected": False,
"data": {
"condition": "end call",
"label": "end call",
"invalid": False,
"validationMessage": None,
},
},
],
"viewport": {"x": 0, "y": 0, "zoom": 1},
}
class TestCreateWorkflow:
"""Test cases for creating workflows."""
async def test_create_workflow_success(
self, test_client_factory, db_session, sample_workflow_definition
):
"""Test successful workflow creation."""
# Create a test user for this test
test_user = await db_session.get_or_create_user_by_provider_id(
"test_user_create_success"
)
request_data = {
"name": "Test Workflow",
"workflow_definition": sample_workflow_definition,
}
async with test_client_factory(test_user) as client:
response = await client.post("/api/v1/workflow/create", json=request_data)
assert response.status_code == status.HTTP_200_OK
data = response.json()
assert "id" in data
assert data["name"] == "Test Workflow"
assert data["workflow_definition"] == sample_workflow_definition
assert "created_at" in data
assert "current_definition_id" in data
async def test_create_workflow_invalid_definition(
self, test_client_factory, db_session
):
"""Test workflow creation with invalid definition."""
# Create a test user for this test
test_user = await db_session.get_or_create_user_by_provider_id(
"test_user_invalid_def"
)
request_data = {
"name": "Invalid Workflow",
"workflow_definition": {"invalid": "structure"},
}
async with test_client_factory(test_user) as client:
response = await client.post("/api/v1/workflow/create", json=request_data)
# The API should still create the workflow even with invalid definition
# Validation happens in the validate endpoint
assert response.status_code == status.HTTP_200_OK
@pytest.mark.asyncio
async def test_create_workflow_missing_name(
self, test_client_factory, db_session, sample_workflow_definition
):
"""Test workflow creation without name."""
# Create a test user for this test
test_user = await db_session.get_or_create_user_by_provider_id(
"test_user_missing_name"
)
request_data = {"workflow_definition": sample_workflow_definition}
async with test_client_factory(test_user) as client:
response = await client.post("/api/v1/workflow/create", json=request_data)
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
@pytest.mark.asyncio
async def test_create_workflow_missing_definition(
self, test_client_factory, db_session
):
"""Test workflow creation without workflow definition."""
# Create a test user for this test
test_user = await db_session.get_or_create_user_by_provider_id(
"test_user_missing_definition"
)
request_data = {"name": "Test Workflow"}
async with test_client_factory(test_user) as client:
response = await client.post("/api/v1/workflow/create", json=request_data)
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
class TestGetWorkflows:
"""Test cases for fetching workflows."""
@pytest.mark.asyncio
async def test_get_all_workflows_empty(self, test_client_factory, db_session):
"""Test getting all workflows when none exist."""
# Create a test user within the test function
test_user = await db_session.get_or_create_user_by_provider_id(
"test_user_empty_workflows"
)
# Create a test client for this specific user
async with test_client_factory(test_user) as client:
response = await client.get("/api/v1/workflow/fetch")
assert response.status_code == status.HTTP_200_OK
data = response.json()
assert isinstance(data, list)
assert len(data) == 0
@pytest.mark.asyncio
async def test_get_all_workflows_with_data(
self, test_client_factory, db_session, sample_workflow_definition
):
"""Test getting all workflows when some exist."""
# Create a test user within the test function
test_user = await db_session.get_or_create_user_by_provider_id(
"test_user_with_workflows"
)
# Create a test client for this specific user
async with test_client_factory(test_user) as client:
# Create a workflow first
create_response = await client.post(
"/api/v1/workflow/create",
json={
"name": "Test Workflow 1",
"workflow_definition": sample_workflow_definition,
},
)
assert create_response.status_code == status.HTTP_200_OK
# Create another workflow
create_response2 = await client.post(
"/api/v1/workflow/create",
json={
"name": "Test Workflow 2",
"workflow_definition": sample_workflow_definition,
},
)
assert create_response2.status_code == status.HTTP_200_OK
# Get all workflows
response = await client.get("/api/v1/workflow/fetch")
assert response.status_code == status.HTTP_200_OK
data = response.json()
assert isinstance(data, list)
assert len(data) == 2
# Check that both workflows are returned
workflow_names = [w["name"] for w in data]
assert "Test Workflow 1" in workflow_names
assert "Test Workflow 2" in workflow_names
@pytest.mark.asyncio
async def test_get_specific_workflow(
self, test_client_factory, db_session, sample_workflow_definition
):
"""Test getting a specific workflow by ID."""
# Create a test user for this test
test_user = await db_session.get_or_create_user_by_provider_id(
"test_user_specific_workflow"
)
async with test_client_factory(test_user) as client:
# Create a workflow first
create_response = await client.post(
"/api/v1/workflow/create",
json={
"name": "Specific Workflow",
"workflow_definition": sample_workflow_definition,
},
)
assert create_response.status_code == status.HTTP_200_OK
created_workflow = create_response.json()
workflow_id = created_workflow["id"]
# Get the specific workflow
response = await client.get(
f"/api/v1/workflow/fetch?workflow_id={workflow_id}"
)
assert response.status_code == status.HTTP_200_OK
data = response.json()
assert data["id"] == workflow_id
assert data["name"] == "Specific Workflow"
assert data["workflow_definition"] == sample_workflow_definition
@pytest.mark.asyncio
async def test_get_nonexistent_workflow(self, test_client_factory, db_session):
"""Test getting a workflow that doesn't exist."""
# Create a test user for this test
test_user = await db_session.get_or_create_user_by_provider_id(
"test_user_nonexistent"
)
async with test_client_factory(test_user) as client:
response = await client.get("/api/v1/workflow/fetch?workflow_id=99999")
assert response.status_code == status.HTTP_404_NOT_FOUND
assert "not found" in response.json()["detail"].lower()
class TestUpdateWorkflow:
"""Test cases for updating workflows."""
@pytest.mark.asyncio
async def test_update_workflow_name_only(
self, test_client_factory, db_session, sample_workflow_definition
):
"""Test updating only the workflow name."""
# Create a test user for this test
test_user = await db_session.get_or_create_user_by_provider_id(
"test_user_update_name"
)
async with test_client_factory(test_user) as client:
# Create a workflow first
create_response = await client.post(
"/api/v1/workflow/create",
json={
"name": "Original Name",
"workflow_definition": sample_workflow_definition,
},
)
assert create_response.status_code == status.HTTP_200_OK
workflow_id = create_response.json()["id"]
# Update the workflow name
update_data = {"name": "Updated Name"}
response = await client.put(
f"/api/v1/workflow/{workflow_id}", json=update_data
)
assert response.status_code == status.HTTP_200_OK
data = response.json()
assert data["id"] == workflow_id
assert data["name"] == "Updated Name"
assert (
data["workflow_definition"] == sample_workflow_definition
) # Should remain unchanged
@pytest.mark.asyncio
async def test_update_workflow_name_and_definition(
self, test_client_factory, db_session, sample_workflow_definition
):
"""Test updating both workflow name and definition."""
# Create a test user for this test
test_user = await db_session.get_or_create_user_by_provider_id(
"test_user_update_both"
)
async with test_client_factory(test_user) as client:
# Create a workflow first
create_response = await client.post(
"/api/v1/workflow/create",
json={
"name": "Original Name",
"workflow_definition": sample_workflow_definition,
},
)
assert create_response.status_code == status.HTTP_200_OK
workflow_id = create_response.json()["id"]
# Create new workflow definition
new_definition = {
"nodes": [
{
"id": "start",
"type": "start",
"position": {"x": 50, "y": 50},
"data": {"label": "New Start"},
}
],
"edges": [],
}
# Update the workflow
update_data = {
"name": "Updated Name",
"workflow_definition": new_definition,
}
response = await client.put(
f"/api/v1/workflow/{workflow_id}", json=update_data
)
assert response.status_code == status.HTTP_200_OK
data = response.json()
assert data["id"] == workflow_id
assert data["name"] == "Updated Name"
assert data["workflow_definition"] == new_definition
@pytest.mark.asyncio
async def test_update_nonexistent_workflow(self, test_client_factory, db_session):
"""Test updating a workflow that doesn't exist."""
# Create a test user for this test
test_user = await db_session.get_or_create_user_by_provider_id(
"test_user_update_nonexistent"
)
update_data = {"name": "Updated Name"}
async with test_client_factory(test_user) as client:
response = await client.put("/api/v1/workflow/99999", json=update_data)
assert response.status_code == status.HTTP_404_NOT_FOUND
assert "not found" in response.json()["detail"].lower()
@pytest.mark.asyncio
async def test_update_workflow_missing_name(
self, test_client_factory, db_session, sample_workflow_definition
):
"""Test updating a workflow without providing a name."""
# Create a test user for this test
test_user = await db_session.get_or_create_user_by_provider_id(
"test_user_update_missing_name"
)
async with test_client_factory(test_user) as client:
# Create a workflow first
create_response = await client.post(
"/api/v1/workflow/create",
json={
"name": "Original Name",
"workflow_definition": sample_workflow_definition,
},
)
assert create_response.status_code == status.HTTP_200_OK
workflow_id = create_response.json()["id"]
# Try to update without providing name
update_data = {"workflow_definition": sample_workflow_definition}
response = await client.put(
f"/api/v1/workflow/{workflow_id}", json=update_data
)
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
class TestWorkflowValidation:
"""Test cases for workflow validation endpoint."""
@pytest.mark.asyncio
async def test_validate_workflow_success(
self, test_client_factory, db_session, sample_workflow_definition
):
"""Test successful workflow validation."""
# Create a test user for this test
test_user = await db_session.get_or_create_user_by_provider_id(
"test_user_validate_success"
)
async with test_client_factory(test_user) as client:
# Create a workflow first
create_response = await client.post(
"/api/v1/workflow/create",
json={
"name": "Valid Workflow",
"workflow_definition": sample_workflow_definition,
},
)
assert create_response.status_code == status.HTTP_200_OK
workflow_id = create_response.json()["id"]
# Validate the workflow
response = await client.post(f"/api/v1/workflow/{workflow_id}/validate")
assert response.status_code == status.HTTP_200_OK
data = response.json()
assert data["is_valid"] is True
assert data["errors"] == []
@pytest.mark.asyncio
async def test_validate_nonexistent_workflow(self, test_client_factory, db_session):
"""Test validating a workflow that doesn't exist."""
# Create a test user for this test
test_user = await db_session.get_or_create_user_by_provider_id(
"test_user_validate_nonexistent"
)
async with test_client_factory(test_user) as client:
response = await client.post("/api/v1/workflow/99999/validate")
assert response.status_code == status.HTTP_404_NOT_FOUND
assert "not found" in response.json()["detail"].lower()
class TestWorkflowIntegration:
"""Integration tests for workflow operations."""
@pytest.mark.asyncio
async def test_full_workflow_lifecycle(
self, test_client_factory, db_session, sample_workflow_definition
):
"""Test the complete lifecycle of a workflow: create, get, update, validate."""
# Create a test user for this test
test_user = await db_session.get_or_create_user_by_provider_id(
"test_user_lifecycle"
)
async with test_client_factory(test_user) as client:
# 1. Create workflow
create_response = await client.post(
"/api/v1/workflow/create",
json={
"name": "Lifecycle Test Workflow",
"workflow_definition": sample_workflow_definition,
},
)
assert create_response.status_code == status.HTTP_200_OK
workflow_id = create_response.json()["id"]
# 2. Get the created workflow
get_response = await client.get(
f"/api/v1/workflow/fetch?workflow_id={workflow_id}"
)
assert get_response.status_code == status.HTTP_200_OK
workflow_data = get_response.json()
assert workflow_data["name"] == "Lifecycle Test Workflow"
# 3. Add a new node in the workflow definition
new_node = {
"id": "6919_new",
"type": "agentNode",
"position": {"x": 520, "y": 650},
"data": {
"prompt": "Something new",
"name": "Agent",
"invalid": False,
"validationMessage": None,
"allow_interrupt": True,
},
"measured": {"width": 300, "height": 100},
"selected": False,
"dragging": False,
}
new_edges = [
{
"source": "6919",
"target": "6919_new",
"id": "xy-edge__6919-6919_new",
"data": {
"condition": "Always take this route",
"label": "Always take this route",
"invalid": False,
"validationMessage": None,
},
},
{
"source": "6919_new",
"target": "1802",
"id": "xy-edge__6919_new-1802",
"data": {
"condition": "Always take this route",
"label": "Always take this route",
"invalid": False,
"validationMessage": None,
},
},
]
new_definition = {
"nodes": [
*sample_workflow_definition["nodes"],
new_node,
],
"edges": [
*sample_workflow_definition["edges"],
*new_edges,
],
}
update_response = await client.put(
f"/api/v1/workflow/{workflow_id}",
json={
"name": "Updated Lifecycle Workflow",
"workflow_definition": new_definition,
},
)
assert update_response.status_code == status.HTTP_200_OK
assert update_response.json()["name"] == "Updated Lifecycle Workflow"
# 4. Validate the updated workflow
validate_response = await client.post(
f"/api/v1/workflow/{workflow_id}/validate"
)
assert validate_response.status_code == status.HTTP_200_OK
assert validate_response.json()["is_valid"] is True
# 5. Verify the update by getting the workflow again
final_get_response = await client.get(
f"/api/v1/workflow/fetch?workflow_id={workflow_id}"
)
assert final_get_response.status_code == status.HTTP_200_OK
final_data = final_get_response.json()
assert final_data["name"] == "Updated Lifecycle Workflow"
assert final_data["workflow_definition"] == new_definition