mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-06-07 07:55:16 +02:00
feat: user defined custom tools as part of workflow execution (#94)
* feat: add custom tools functionality * Show tools in nodes * integrate tool calling with pipeline engine
This commit is contained in:
parent
cc2d3e70d2
commit
3e55af9256
65 changed files with 5483 additions and 6673 deletions
|
|
@ -1,138 +0,0 @@
|
|||
import asyncio
|
||||
|
||||
import pytest
|
||||
from pipecat.frames.frames import (
|
||||
FunctionCallInProgressFrame,
|
||||
LLMFullResponseEndFrame,
|
||||
LLMFullResponseStartFrame,
|
||||
StartInterruptionFrame,
|
||||
TextFrame,
|
||||
)
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineTask
|
||||
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.openai.llm import OpenAIAssistantContextAggregator
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reordering_after_completion():
|
||||
context = OpenAILLMContext()
|
||||
aggr = OpenAIAssistantContextAggregator(context)
|
||||
|
||||
# Initialize task manager properly using PipelineTask
|
||||
pipeline = Pipeline([aggr])
|
||||
task = PipelineTask(pipeline)
|
||||
runner = PipelineRunner()
|
||||
|
||||
# Start the task to properly initialize the frame processor
|
||||
task_coroutine = asyncio.create_task(runner.run(task))
|
||||
|
||||
# Give the task a moment to initialize
|
||||
await asyncio.sleep(0.01)
|
||||
|
||||
# start new LLM response
|
||||
await aggr.process_frame(LLMFullResponseStartFrame(), FrameDirection.DOWNSTREAM)
|
||||
|
||||
# simulate a pending function call
|
||||
await aggr.process_frame(
|
||||
FunctionCallInProgressFrame(
|
||||
function_name="transition",
|
||||
tool_call_id="1",
|
||||
arguments={},
|
||||
cancel_on_interruption=False,
|
||||
),
|
||||
FrameDirection.DOWNSTREAM,
|
||||
)
|
||||
|
||||
# now text arrives
|
||||
await aggr.process_frame(TextFrame("Hi there"), FrameDirection.DOWNSTREAM)
|
||||
|
||||
# end response
|
||||
await aggr.process_frame(LLMFullResponseEndFrame(), FrameDirection.DOWNSTREAM)
|
||||
|
||||
msgs = context.get_messages()
|
||||
|
||||
# Assert order: assistant text first, then tool_call assistant, then tool response
|
||||
assert msgs[0]["role"] == "assistant" and "tool_calls" not in msgs[0]
|
||||
# Fix: content is a string, not a structured object
|
||||
assert msgs[0]["content"] == "Hi there"
|
||||
assert any(m.get("role") == "assistant" and m.get("tool_calls") for m in msgs[1:])
|
||||
assert any(m.get("role") == "tool" for m in msgs[1:])
|
||||
|
||||
# Clean up the running task
|
||||
await task.cancel()
|
||||
task_coroutine.cancel()
|
||||
try:
|
||||
await task_coroutine
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_interruption_removes_pending_function_calls_and_marks():
|
||||
context = OpenAILLMContext()
|
||||
aggr = OpenAIAssistantContextAggregator(context)
|
||||
|
||||
# Initialize task manager properly using PipelineTask
|
||||
pipeline = Pipeline([aggr])
|
||||
task = PipelineTask(pipeline)
|
||||
runner = PipelineRunner()
|
||||
|
||||
# Start the task to properly initialize the frame processor
|
||||
task_coroutine = asyncio.create_task(runner.run(task))
|
||||
|
||||
# Give the task a moment to initialize
|
||||
await asyncio.sleep(0.01)
|
||||
|
||||
await aggr.process_frame(LLMFullResponseStartFrame(), FrameDirection.DOWNSTREAM)
|
||||
await aggr.process_frame(
|
||||
FunctionCallInProgressFrame(
|
||||
function_name="transition",
|
||||
tool_call_id="1",
|
||||
arguments={},
|
||||
cancel_on_interruption=False,
|
||||
),
|
||||
FrameDirection.DOWNSTREAM,
|
||||
)
|
||||
|
||||
# Debug: Check the state before interruption
|
||||
print(
|
||||
f"Function calls in progress before interruption: {aggr._function_calls_in_progress}"
|
||||
)
|
||||
print(f"Messages before interruption: {context.get_messages()}")
|
||||
|
||||
# no text yet - still aggregation
|
||||
await aggr.process_frame(StartInterruptionFrame(), FrameDirection.DOWNSTREAM)
|
||||
|
||||
msgs = context.get_messages()
|
||||
|
||||
# Debug: Print messages to understand what's happening
|
||||
print(f"Messages after interruption: {msgs}")
|
||||
print(
|
||||
f"Function calls in progress after interruption: {aggr._function_calls_in_progress}"
|
||||
)
|
||||
|
||||
# After interruption before any response is complete, context should be cleared
|
||||
# This is the actual behavior - interruptions clear pending function calls
|
||||
if len(msgs) == 0:
|
||||
# Context was cleared due to interruption before completion
|
||||
assert True
|
||||
else:
|
||||
# If there are messages, ensure no tool calls remain
|
||||
assert not any(m.get("tool_calls") for m in msgs)
|
||||
assert not any(m.get("role") == "tool" for m in msgs)
|
||||
|
||||
# Check if interruption marker is present
|
||||
if msgs:
|
||||
assert msgs[-1]["role"] == "assistant"
|
||||
assert "<<interrupted_by_user>>" in msgs[-1]["content"]
|
||||
|
||||
# Clean up the running task
|
||||
await task.cancel()
|
||||
task_coroutine.cancel()
|
||||
try:
|
||||
await task_coroutine
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
|
@ -1,120 +0,0 @@
|
|||
import os
|
||||
import wave
|
||||
|
||||
import pytest
|
||||
|
||||
from api.services.pipecat.audio_transcript_buffers import (
|
||||
InMemoryAudioBuffer,
|
||||
InMemoryTranscriptBuffer,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_audio_buffer_append_and_write():
|
||||
"""Test that audio buffer can append data and write to temp file."""
|
||||
buffer = InMemoryAudioBuffer(workflow_run_id=123, sample_rate=16000, num_channels=1)
|
||||
|
||||
# Create some test PCM data
|
||||
test_pcm = b"\x00\x01" * 1000 # 2000 bytes
|
||||
|
||||
# Append data
|
||||
await buffer.append(test_pcm)
|
||||
await buffer.append(test_pcm)
|
||||
|
||||
assert buffer.size == 4000
|
||||
assert not buffer.is_empty
|
||||
|
||||
# Write to temp file
|
||||
temp_path = await buffer.write_to_temp_file()
|
||||
|
||||
try:
|
||||
# Verify file exists and is valid WAV
|
||||
assert os.path.exists(temp_path)
|
||||
|
||||
with wave.open(temp_path, "rb") as wf:
|
||||
assert wf.getnchannels() == 1
|
||||
assert wf.getsampwidth() == 2
|
||||
assert wf.getframerate() == 16000
|
||||
# Each frame is 2 bytes (16-bit), so 4000 bytes = 2000 frames
|
||||
assert wf.getnframes() == 2000
|
||||
finally:
|
||||
# Clean up
|
||||
if os.path.exists(temp_path):
|
||||
os.remove(temp_path)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_audio_buffer_memory_limit():
|
||||
"""Test that audio buffer enforces memory limit."""
|
||||
buffer = InMemoryAudioBuffer(workflow_run_id=123, sample_rate=16000)
|
||||
|
||||
# Set a smaller limit for testing
|
||||
buffer._max_size = 1000
|
||||
|
||||
# This should work
|
||||
await buffer.append(b"\x00" * 500)
|
||||
|
||||
# This should fail
|
||||
with pytest.raises(MemoryError):
|
||||
await buffer.append(b"\x00" * 600)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_transcript_buffer_append_and_write():
|
||||
"""Test that transcript buffer can append data and write to temp file."""
|
||||
buffer = InMemoryTranscriptBuffer(workflow_run_id=456)
|
||||
|
||||
# Append some transcript lines
|
||||
await buffer.append("[00:00:01] user: Hello\n")
|
||||
await buffer.append("[00:00:02] assistant: Hi there!\n")
|
||||
await buffer.append("[00:00:03] user: How are you?\n")
|
||||
|
||||
assert not buffer.is_empty
|
||||
|
||||
# Write to temp file
|
||||
temp_path = await buffer.write_to_temp_file()
|
||||
|
||||
try:
|
||||
# Verify file exists and has correct content
|
||||
assert os.path.exists(temp_path)
|
||||
|
||||
with open(temp_path, "r") as f:
|
||||
content = f.read()
|
||||
assert "[00:00:01] user: Hello\n" in content
|
||||
assert "[00:00:02] assistant: Hi there!\n" in content
|
||||
assert "[00:00:03] user: How are you?\n" in content
|
||||
finally:
|
||||
# Clean up
|
||||
if os.path.exists(temp_path):
|
||||
os.remove(temp_path)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_buffers():
|
||||
"""Test that empty buffers are handled correctly."""
|
||||
audio_buffer = InMemoryAudioBuffer(workflow_run_id=789, sample_rate=16000)
|
||||
transcript_buffer = InMemoryTranscriptBuffer(workflow_run_id=789)
|
||||
|
||||
assert audio_buffer.is_empty
|
||||
assert transcript_buffer.is_empty
|
||||
|
||||
# Should still be able to write empty files
|
||||
audio_path = await audio_buffer.write_to_temp_file()
|
||||
transcript_path = await transcript_buffer.write_to_temp_file()
|
||||
|
||||
try:
|
||||
assert os.path.exists(audio_path)
|
||||
assert os.path.exists(transcript_path)
|
||||
|
||||
# Empty WAV file should still have valid header
|
||||
with wave.open(audio_path, "rb") as wf:
|
||||
assert wf.getnframes() == 0
|
||||
|
||||
# Empty transcript file
|
||||
with open(transcript_path, "r") as f:
|
||||
assert f.read() == ""
|
||||
finally:
|
||||
if os.path.exists(audio_path):
|
||||
os.remove(audio_path)
|
||||
if os.path.exists(transcript_path):
|
||||
os.remove(transcript_path)
|
||||
|
|
@ -1,330 +0,0 @@
|
|||
"""Tests for concurrent call limiting functionality."""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from api.enums import OrganizationConfigurationKey
|
||||
from api.services.campaign.rate_limiter import RateLimiter
|
||||
|
||||
|
||||
class TestConcurrentCallLimiting:
|
||||
"""Test suite for concurrent call limiting."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_acquire_concurrent_slot_success(self):
|
||||
"""Test successful acquisition of concurrent slot."""
|
||||
rate_limiter = RateLimiter()
|
||||
|
||||
# Mock Redis client
|
||||
with patch.object(rate_limiter, "_get_redis") as mock_redis:
|
||||
mock_client = AsyncMock()
|
||||
mock_client.eval = AsyncMock(return_value="test_slot_123")
|
||||
mock_redis.return_value = mock_client
|
||||
|
||||
# Try to acquire slot
|
||||
slot_id = await rate_limiter.try_acquire_concurrent_slot(
|
||||
organization_id=1, max_concurrent=20
|
||||
)
|
||||
|
||||
assert slot_id == "test_slot_123"
|
||||
mock_client.eval.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_acquire_concurrent_slot_limit_reached(self):
|
||||
"""Test slot acquisition when limit is reached."""
|
||||
rate_limiter = RateLimiter()
|
||||
|
||||
# Mock Redis client
|
||||
with patch.object(rate_limiter, "_get_redis") as mock_redis:
|
||||
mock_client = AsyncMock()
|
||||
mock_client.eval = AsyncMock(return_value=None) # Limit reached
|
||||
mock_redis.return_value = mock_client
|
||||
|
||||
# Try to acquire slot
|
||||
slot_id = await rate_limiter.try_acquire_concurrent_slot(
|
||||
organization_id=1, max_concurrent=20
|
||||
)
|
||||
|
||||
assert slot_id is None
|
||||
mock_client.eval.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_release_concurrent_slot(self):
|
||||
"""Test releasing a concurrent slot."""
|
||||
rate_limiter = RateLimiter()
|
||||
|
||||
# Mock Redis client
|
||||
with patch.object(rate_limiter, "_get_redis") as mock_redis:
|
||||
mock_client = AsyncMock()
|
||||
mock_client.zrem = AsyncMock(return_value=1) # Successfully removed
|
||||
mock_redis.return_value = mock_client
|
||||
|
||||
# Release slot
|
||||
success = await rate_limiter.release_concurrent_slot(
|
||||
organization_id=1, slot_id="test_slot_123"
|
||||
)
|
||||
|
||||
assert success is True
|
||||
mock_client.zrem.assert_called_once_with(
|
||||
"concurrent_calls:1", "test_slot_123"
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_concurrent_count(self):
|
||||
"""Test getting current concurrent call count."""
|
||||
rate_limiter = RateLimiter()
|
||||
|
||||
# Mock Redis client
|
||||
with patch.object(rate_limiter, "_get_redis") as mock_redis:
|
||||
mock_client = AsyncMock()
|
||||
mock_client.zremrangebyscore = AsyncMock() # Cleanup stale entries
|
||||
mock_client.zcard = AsyncMock(return_value=5) # 5 active calls
|
||||
mock_redis.return_value = mock_client
|
||||
|
||||
# Get count
|
||||
count = await rate_limiter.get_concurrent_count(organization_id=1)
|
||||
|
||||
assert count == 5
|
||||
mock_client.zremrangebyscore.assert_called_once()
|
||||
mock_client.zcard.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stale_entry_cleanup(self):
|
||||
"""Test that stale entries are cleaned up automatically."""
|
||||
rate_limiter = RateLimiter()
|
||||
|
||||
# Mock Redis client
|
||||
with patch.object(rate_limiter, "_get_redis") as mock_redis:
|
||||
mock_client = AsyncMock()
|
||||
|
||||
# Mock eval to simulate cleanup in Lua script
|
||||
mock_client.eval = AsyncMock(return_value="new_slot_123")
|
||||
mock_redis.return_value = mock_client
|
||||
|
||||
# Try to acquire slot (which should trigger cleanup)
|
||||
slot_id = await rate_limiter.try_acquire_concurrent_slot(
|
||||
organization_id=1, max_concurrent=20
|
||||
)
|
||||
|
||||
assert slot_id == "new_slot_123"
|
||||
|
||||
# Verify Lua script was called with proper stale cutoff
|
||||
call_args = mock_client.eval.call_args[0]
|
||||
lua_script = call_args[0]
|
||||
assert "ZREMRANGEBYSCORE" in lua_script # Cleanup command in script
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_workflow_slot_mapping_operations(self):
|
||||
"""Test storing, retrieving, and deleting workflow slot mappings."""
|
||||
rate_limiter = RateLimiter()
|
||||
|
||||
# Mock Redis client
|
||||
with patch.object(rate_limiter, "_get_redis") as mock_redis:
|
||||
mock_client = AsyncMock()
|
||||
mock_client.hset = AsyncMock(return_value=1)
|
||||
mock_client.expire = AsyncMock(return_value=True)
|
||||
mock_client.hgetall = AsyncMock(
|
||||
return_value={"org_id": "1", "slot_id": "test_slot_123"}
|
||||
)
|
||||
mock_client.delete = AsyncMock(return_value=1)
|
||||
mock_redis.return_value = mock_client
|
||||
|
||||
# Test storing mapping
|
||||
success = await rate_limiter.store_workflow_slot_mapping(
|
||||
workflow_run_id=999, organization_id=1, slot_id="test_slot_123"
|
||||
)
|
||||
assert success is True
|
||||
mock_client.hset.assert_called_once()
|
||||
mock_client.expire.assert_called_once()
|
||||
|
||||
# Test retrieving mapping
|
||||
mapping = await rate_limiter.get_workflow_slot_mapping(workflow_run_id=999)
|
||||
assert mapping == (1, "test_slot_123")
|
||||
mock_client.hgetall.assert_called_once_with("workflow_slot_mapping:999")
|
||||
|
||||
# Test deleting mapping
|
||||
deleted = await rate_limiter.delete_workflow_slot_mapping(
|
||||
workflow_run_id=999
|
||||
)
|
||||
assert deleted is True
|
||||
mock_client.delete.assert_called_once_with("workflow_slot_mapping:999")
|
||||
|
||||
|
||||
class TestCampaignCallDispatcher:
|
||||
"""Test suite for CampaignCallDispatcher with concurrent limiting."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dispatch_call_waits_for_slot(self):
|
||||
"""Test that dispatch_call waits for available slot."""
|
||||
from api.services.campaign.call_dispatcher import CampaignCallDispatcher
|
||||
|
||||
dispatcher = CampaignCallDispatcher()
|
||||
|
||||
# Mock dependencies
|
||||
mock_campaign = MagicMock(
|
||||
organization_id=1, workflow_id=123, id=456, created_by=789
|
||||
)
|
||||
mock_queued_run = MagicMock(
|
||||
id=111, context_variables={"phone_number": "+1234567890"}
|
||||
)
|
||||
|
||||
# Mock rate limiter to simulate waiting
|
||||
slot_acquired = False
|
||||
call_count = 0
|
||||
|
||||
async def mock_try_acquire(org_id, max_concurrent):
|
||||
nonlocal slot_acquired, call_count
|
||||
call_count += 1
|
||||
if call_count > 2: # Succeed on third try
|
||||
slot_acquired = True
|
||||
return "test_slot_123"
|
||||
return None
|
||||
|
||||
with patch(
|
||||
"api.services.campaign.call_dispatcher.rate_limiter"
|
||||
) as mock_limiter:
|
||||
mock_limiter.try_acquire_concurrent_slot = AsyncMock(
|
||||
side_effect=mock_try_acquire
|
||||
)
|
||||
mock_limiter.release_concurrent_slot = AsyncMock()
|
||||
mock_limiter.store_workflow_slot_mapping = AsyncMock(return_value=True)
|
||||
|
||||
with patch("api.services.campaign.call_dispatcher.db_client") as mock_db:
|
||||
mock_db.get_configuration = AsyncMock(return_value=None)
|
||||
mock_db.get_workflow_by_id = AsyncMock(
|
||||
return_value=MagicMock(template_context_variables={})
|
||||
)
|
||||
mock_db.create_workflow_run = AsyncMock(
|
||||
return_value=MagicMock(id=999, logs={})
|
||||
)
|
||||
|
||||
with patch.object(
|
||||
dispatcher.twilio_service, "initiate_call"
|
||||
) as mock_twilio:
|
||||
mock_twilio.return_value = {"sid": "test_sid"}
|
||||
|
||||
# Dispatch call (should wait and retry)
|
||||
workflow_run = await dispatcher.dispatch_call(
|
||||
mock_queued_run, mock_campaign
|
||||
)
|
||||
|
||||
assert workflow_run is not None
|
||||
assert slot_acquired is True
|
||||
assert call_count == 3 # Tried 3 times
|
||||
assert mock_limiter.try_acquire_concurrent_slot.call_count == 3
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dispatch_call_stores_slot_mapping(self):
|
||||
"""Test that dispatch_call stores slot mapping in Redis."""
|
||||
from api.services.campaign.call_dispatcher import CampaignCallDispatcher
|
||||
|
||||
dispatcher = CampaignCallDispatcher()
|
||||
|
||||
# Mock dependencies
|
||||
mock_campaign = MagicMock(
|
||||
organization_id=1, workflow_id=123, id=456, created_by=789
|
||||
)
|
||||
mock_queued_run = MagicMock(
|
||||
id=111, context_variables={"phone_number": "+1234567890"}
|
||||
)
|
||||
|
||||
with patch(
|
||||
"api.services.campaign.call_dispatcher.rate_limiter"
|
||||
) as mock_limiter:
|
||||
mock_limiter.try_acquire_concurrent_slot = AsyncMock(
|
||||
return_value="test_slot_123"
|
||||
)
|
||||
mock_limiter.store_workflow_slot_mapping = AsyncMock(return_value=True)
|
||||
|
||||
with patch("api.services.campaign.call_dispatcher.db_client") as mock_db:
|
||||
mock_db.get_configuration = AsyncMock(return_value=None)
|
||||
mock_db.get_workflow_by_id = AsyncMock(
|
||||
return_value=MagicMock(template_context_variables={})
|
||||
)
|
||||
mock_db.create_workflow_run = AsyncMock(
|
||||
return_value=MagicMock(id=999, logs={})
|
||||
)
|
||||
|
||||
with patch.object(
|
||||
dispatcher.twilio_service, "initiate_call"
|
||||
) as mock_twilio:
|
||||
mock_twilio.return_value = {"sid": "test_sid"}
|
||||
|
||||
# Dispatch call
|
||||
workflow_run = await dispatcher.dispatch_call(
|
||||
mock_queued_run, mock_campaign
|
||||
)
|
||||
|
||||
# Verify slot mapping was stored
|
||||
mock_limiter.store_workflow_slot_mapping.assert_called_once_with(
|
||||
999, 1, "test_slot_123"
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_org_specific_concurrent_limit(self):
|
||||
"""Test that organization-specific concurrent limit is used."""
|
||||
from api.services.campaign.call_dispatcher import CampaignCallDispatcher
|
||||
|
||||
dispatcher = CampaignCallDispatcher()
|
||||
|
||||
# Mock db_client to return org-specific limit
|
||||
with patch("api.services.campaign.call_dispatcher.db_client") as mock_db:
|
||||
mock_config = MagicMock(value={"value": 10}) # Org limit is 10
|
||||
mock_db.get_configuration = AsyncMock(return_value=mock_config)
|
||||
|
||||
# Get org limit
|
||||
limit = await dispatcher.get_org_concurrent_limit(organization_id=1)
|
||||
|
||||
assert limit == 10 # Should use org-specific limit
|
||||
mock_db.get_configuration.assert_called_once_with(
|
||||
1, OrganizationConfigurationKey.CONCURRENT_CALL_LIMIT.value
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_default_concurrent_limit(self):
|
||||
"""Test that default limit is used when org config not found."""
|
||||
from api.services.campaign.call_dispatcher import CampaignCallDispatcher
|
||||
|
||||
dispatcher = CampaignCallDispatcher()
|
||||
|
||||
# Mock db_client to return None (no config)
|
||||
with patch("api.services.campaign.call_dispatcher.db_client") as mock_db:
|
||||
mock_db.get_configuration = AsyncMock(return_value=None)
|
||||
|
||||
# Get org limit
|
||||
limit = await dispatcher.get_org_concurrent_limit(organization_id=1)
|
||||
|
||||
assert limit == 20 # Should use default limit
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_release_call_slot(self):
|
||||
"""Test releasing call slot when workflow completes."""
|
||||
from api.services.campaign.call_dispatcher import CampaignCallDispatcher
|
||||
|
||||
dispatcher = CampaignCallDispatcher()
|
||||
|
||||
# Mock rate limiter
|
||||
with patch(
|
||||
"api.services.campaign.call_dispatcher.rate_limiter"
|
||||
) as mock_limiter:
|
||||
# Mock getting the slot mapping from Redis
|
||||
mock_limiter.get_workflow_slot_mapping = AsyncMock(
|
||||
return_value=(1, "test_slot_123")
|
||||
)
|
||||
mock_limiter.release_concurrent_slot = AsyncMock(return_value=True)
|
||||
mock_limiter.delete_workflow_slot_mapping = AsyncMock(return_value=True)
|
||||
|
||||
# Release slot
|
||||
success = await dispatcher.release_call_slot(workflow_run_id=999)
|
||||
|
||||
assert success is True
|
||||
mock_limiter.get_workflow_slot_mapping.assert_called_once_with(999)
|
||||
mock_limiter.release_concurrent_slot.assert_called_once_with(
|
||||
1, "test_slot_123"
|
||||
)
|
||||
mock_limiter.delete_workflow_slot_mapping.assert_called_once_with(999)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v"])
|
||||
|
|
@ -1,78 +0,0 @@
|
|||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from api.schemas.user_configuration import UserConfiguration
|
||||
from api.services.configuration.masking import is_mask_of, mask_key, mask_user_config
|
||||
from api.services.configuration.merge import merge_user_configurations
|
||||
from api.services.configuration.registry import (
|
||||
OpenAILLMService,
|
||||
)
|
||||
|
||||
REAL_KEY = "sk-1234567890abcdef"
|
||||
|
||||
|
||||
def _build_config_with_openai(key: str) -> UserConfiguration:
|
||||
return UserConfiguration(
|
||||
llm=OpenAILLMService(api_key=key),
|
||||
stt=None,
|
||||
tts=None,
|
||||
)
|
||||
|
||||
|
||||
def test_mask_key_basic():
|
||||
masked = mask_key(REAL_KEY)
|
||||
# Should reveal only last 4 chars
|
||||
assert masked.endswith(REAL_KEY[-4:])
|
||||
assert set(masked[:-4]) == {"*"}
|
||||
assert len(masked) == len(REAL_KEY)
|
||||
# is_mask_of round-trip
|
||||
assert is_mask_of(masked, REAL_KEY)
|
||||
|
||||
|
||||
def test_mask_user_config_masks_api_keys():
|
||||
cfg = _build_config_with_openai(REAL_KEY)
|
||||
dumped = mask_user_config(cfg)
|
||||
assert dumped["llm"]["api_key"].endswith(REAL_KEY[-4:])
|
||||
assert dumped["llm"]["api_key"].startswith("*" * (len(REAL_KEY) - 4))
|
||||
|
||||
|
||||
def test_merge_preserves_key_when_mask_sent():
|
||||
existing = _build_config_with_openai(REAL_KEY)
|
||||
incoming_partial = {
|
||||
"llm": {
|
||||
"provider": "openai",
|
||||
"model": existing.llm.model,
|
||||
"api_key": mask_key(REAL_KEY), # masked placeholder
|
||||
}
|
||||
}
|
||||
|
||||
merged = merge_user_configurations(existing, incoming_partial)
|
||||
assert merged.llm.api_key == REAL_KEY # key preserved
|
||||
|
||||
|
||||
def test_merge_replaces_key_when_new_key_provided():
|
||||
existing = _build_config_with_openai(REAL_KEY)
|
||||
new_key = "sk-replaced-9999"
|
||||
incoming_partial = {
|
||||
"llm": {
|
||||
"provider": "openai",
|
||||
"model": existing.llm.model,
|
||||
"api_key": new_key,
|
||||
}
|
||||
}
|
||||
merged = merge_user_configurations(existing, incoming_partial)
|
||||
assert merged.llm.api_key == new_key
|
||||
|
||||
|
||||
def test_merge_drops_old_key_when_provider_changes():
|
||||
existing = _build_config_with_openai(REAL_KEY)
|
||||
incoming_partial = {
|
||||
"llm": {
|
||||
"provider": "groq",
|
||||
"model": "llama-3.3-70b-versatile",
|
||||
# api_key intentionally absent – should NOT inherit old key
|
||||
}
|
||||
}
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
merge_user_configurations(existing, incoming_partial)
|
||||
1041
api/tests/test_custom_tools.py
Normal file
1041
api/tests/test_custom_tools.py
Normal file
File diff suppressed because it is too large
Load diff
512
api/tests/test_custom_tools_context_integration.py
Normal file
512
api/tests/test_custom_tools_context_integration.py
Normal file
|
|
@ -0,0 +1,512 @@
|
|||
"""Integration tests for CustomToolManager with update_llm_context.
|
||||
|
||||
This module tests the full flow of:
|
||||
1. CustomToolManager fetching and converting tool schemas
|
||||
2. update_llm_context setting those tools on the LLM context
|
||||
3. Verifying the context is properly configured for LLM generation
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from api.services.workflow.pipecat_engine_custom_tools import CustomToolManager
|
||||
from api.services.workflow.pipecat_engine_utils import (
|
||||
get_function_schema,
|
||||
update_llm_context,
|
||||
)
|
||||
from pipecat.adapters.schemas.function_schema import FunctionSchema
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
|
||||
|
||||
@dataclass
|
||||
class MockToolModel:
|
||||
"""Mock tool model for testing."""
|
||||
|
||||
tool_uuid: str
|
||||
name: str
|
||||
description: str
|
||||
definition: Dict[str, Any]
|
||||
|
||||
|
||||
class TestCustomToolManagerContextIntegration:
|
||||
"""Integration tests for CustomToolManager with LLMContext."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_engine(self):
|
||||
"""Create a mock PipecatEngine."""
|
||||
engine = Mock()
|
||||
engine._workflow_run_id = 1
|
||||
engine._call_context_vars = {"customer_name": "John Doe"}
|
||||
engine.llm = Mock()
|
||||
engine.llm.register_function = Mock()
|
||||
return engine
|
||||
|
||||
@pytest.fixture
|
||||
def sample_tools(self):
|
||||
"""Create sample mock tools for testing."""
|
||||
return [
|
||||
MockToolModel(
|
||||
tool_uuid="weather-uuid-123",
|
||||
name="Get Weather",
|
||||
description="Get current weather for a location",
|
||||
definition={
|
||||
"schema_version": 1,
|
||||
"type": "http_api",
|
||||
"config": {
|
||||
"method": "GET",
|
||||
"url": "https://api.weather.com/current",
|
||||
"parameters": [
|
||||
{
|
||||
"name": "location",
|
||||
"type": "string",
|
||||
"description": "City name (e.g., San Francisco, CA)",
|
||||
"required": True,
|
||||
},
|
||||
{
|
||||
"name": "units",
|
||||
"type": "string",
|
||||
"description": "Temperature units: celsius or fahrenheit",
|
||||
"required": False,
|
||||
},
|
||||
],
|
||||
},
|
||||
},
|
||||
),
|
||||
MockToolModel(
|
||||
tool_uuid="booking-uuid-456",
|
||||
name="Book Appointment",
|
||||
description="Book an appointment for the customer",
|
||||
definition={
|
||||
"schema_version": 1,
|
||||
"type": "http_api",
|
||||
"config": {
|
||||
"method": "POST",
|
||||
"url": "https://api.example.com/appointments",
|
||||
"parameters": [
|
||||
{
|
||||
"name": "customer_name",
|
||||
"type": "string",
|
||||
"description": "Customer's full name",
|
||||
"required": True,
|
||||
},
|
||||
{
|
||||
"name": "date",
|
||||
"type": "string",
|
||||
"description": "Appointment date (YYYY-MM-DD)",
|
||||
"required": True,
|
||||
},
|
||||
{
|
||||
"name": "time",
|
||||
"type": "string",
|
||||
"description": "Appointment time (HH:MM)",
|
||||
"required": True,
|
||||
},
|
||||
{
|
||||
"name": "notes",
|
||||
"type": "string",
|
||||
"description": "Additional notes",
|
||||
"required": False,
|
||||
},
|
||||
],
|
||||
},
|
||||
},
|
||||
),
|
||||
MockToolModel(
|
||||
tool_uuid="lookup-uuid-789",
|
||||
name="Customer Lookup",
|
||||
description="Look up customer information by phone number",
|
||||
definition={
|
||||
"schema_version": 1,
|
||||
"type": "http_api",
|
||||
"config": {
|
||||
"method": "GET",
|
||||
"url": "https://api.example.com/customers/lookup",
|
||||
"parameters": [
|
||||
{
|
||||
"name": "phone",
|
||||
"type": "string",
|
||||
"description": "Customer phone number",
|
||||
"required": True,
|
||||
},
|
||||
],
|
||||
},
|
||||
},
|
||||
),
|
||||
]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_tool_schemas_and_update_context(self, mock_engine, sample_tools):
|
||||
"""Test fetching tool schemas via CustomToolManager and updating LLM context."""
|
||||
manager = CustomToolManager(mock_engine)
|
||||
|
||||
with patch(
|
||||
"api.services.workflow.pipecat_engine_custom_tools.get_organization_id_from_workflow_run"
|
||||
) as mock_get_org:
|
||||
mock_get_org.return_value = 1
|
||||
|
||||
with patch(
|
||||
"api.services.workflow.pipecat_engine_custom_tools.db_client"
|
||||
) as mock_db:
|
||||
mock_db.get_tools_by_uuids = AsyncMock(return_value=sample_tools)
|
||||
|
||||
# Get tool schemas via CustomToolManager - now returns FunctionSchema objects
|
||||
tool_uuids = ["weather-uuid-123", "booking-uuid-456", "lookup-uuid-789"]
|
||||
schemas = await manager.get_tool_schemas(tool_uuids)
|
||||
|
||||
# Verify schemas were returned as FunctionSchema objects
|
||||
assert len(schemas) == 3
|
||||
assert all(isinstance(s, FunctionSchema) for s in schemas)
|
||||
|
||||
# Create context with conversation history
|
||||
context = LLMContext()
|
||||
context.set_messages(
|
||||
[
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "I need to check the weather and book an appointment.",
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "I can help with both. Where would you like to check the weather?",
|
||||
},
|
||||
{"role": "user", "content": "San Francisco"},
|
||||
]
|
||||
)
|
||||
|
||||
# Update context with new system message and tools
|
||||
# Now we can pass schemas directly since they're FunctionSchema objects
|
||||
new_system = {
|
||||
"role": "system",
|
||||
"content": "You are a scheduling assistant with access to weather and booking tools.",
|
||||
}
|
||||
update_llm_context(context, new_system, schemas)
|
||||
|
||||
# Verify context was updated correctly
|
||||
messages = context.messages
|
||||
assert len(messages) == 4
|
||||
assert (
|
||||
messages[0]["content"]
|
||||
== "You are a scheduling assistant with access to weather and booking tools."
|
||||
)
|
||||
assert messages[1]["role"] == "user"
|
||||
assert messages[3]["content"] == "San Francisco"
|
||||
|
||||
# Verify tools were set
|
||||
tools = context.tools
|
||||
assert tools is not None
|
||||
assert len(tools.standard_tools) == 3
|
||||
|
||||
# Verify tool names
|
||||
tool_names = {t.name for t in tools.standard_tools}
|
||||
assert tool_names == {
|
||||
"get_weather",
|
||||
"book_appointment",
|
||||
"customer_lookup",
|
||||
}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tool_schemas_have_correct_properties(
|
||||
self, mock_engine, sample_tools
|
||||
):
|
||||
"""Test that tool schemas from CustomToolManager have correct parameter properties."""
|
||||
manager = CustomToolManager(mock_engine)
|
||||
|
||||
with patch(
|
||||
"api.services.workflow.pipecat_engine_custom_tools.get_organization_id_from_workflow_run"
|
||||
) as mock_get_org:
|
||||
mock_get_org.return_value = 1
|
||||
|
||||
with patch(
|
||||
"api.services.workflow.pipecat_engine_custom_tools.db_client"
|
||||
) as mock_db:
|
||||
mock_db.get_tools_by_uuids = AsyncMock(return_value=sample_tools)
|
||||
|
||||
schemas = await manager.get_tool_schemas(
|
||||
["weather-uuid-123", "booking-uuid-456"]
|
||||
)
|
||||
|
||||
# Find the booking schema - now using FunctionSchema attributes
|
||||
booking_schema = next(
|
||||
s for s in schemas if s.name == "book_appointment"
|
||||
)
|
||||
|
||||
# Verify parameter properties
|
||||
assert "customer_name" in booking_schema.properties
|
||||
assert "date" in booking_schema.properties
|
||||
assert "time" in booking_schema.properties
|
||||
assert "notes" in booking_schema.properties
|
||||
|
||||
# Verify types
|
||||
assert booking_schema.properties["customer_name"]["type"] == "string"
|
||||
assert booking_schema.properties["date"]["type"] == "string"
|
||||
|
||||
# Verify required
|
||||
assert "customer_name" in booking_schema.required
|
||||
assert "date" in booking_schema.required
|
||||
assert "time" in booking_schema.required
|
||||
assert "notes" not in booking_schema.required
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_context_update_with_builtin_and_custom_tools(
|
||||
self, mock_engine, sample_tools
|
||||
):
|
||||
"""Test updating context with both built-in and custom tools."""
|
||||
manager = CustomToolManager(mock_engine)
|
||||
|
||||
with patch(
|
||||
"api.services.workflow.pipecat_engine_custom_tools.get_organization_id_from_workflow_run"
|
||||
) as mock_get_org:
|
||||
mock_get_org.return_value = 1
|
||||
|
||||
with patch(
|
||||
"api.services.workflow.pipecat_engine_custom_tools.db_client"
|
||||
) as mock_db:
|
||||
mock_db.get_tools_by_uuids = AsyncMock(
|
||||
return_value=[sample_tools[0]]
|
||||
) # Just weather
|
||||
|
||||
# Get custom tool schemas - returns FunctionSchema objects
|
||||
custom_schemas = await manager.get_tool_schemas(["weather-uuid-123"])
|
||||
|
||||
# Create built-in function schemas (like calculator, timezone)
|
||||
builtin_functions = [
|
||||
get_function_schema(
|
||||
"safe_calculator",
|
||||
"Evaluate a mathematical expression safely",
|
||||
properties={
|
||||
"expression": {
|
||||
"type": "string",
|
||||
"description": "Mathematical expression to evaluate",
|
||||
}
|
||||
},
|
||||
required=["expression"],
|
||||
),
|
||||
get_function_schema(
|
||||
"get_current_time",
|
||||
"Get the current time in a timezone",
|
||||
properties={
|
||||
"timezone": {
|
||||
"type": "string",
|
||||
"description": "Timezone name (e.g., America/New_York)",
|
||||
}
|
||||
},
|
||||
required=["timezone"],
|
||||
),
|
||||
]
|
||||
|
||||
# Combine built-in and custom functions - both are FunctionSchema objects
|
||||
all_functions = builtin_functions + custom_schemas
|
||||
|
||||
# Update context
|
||||
context = LLMContext()
|
||||
context.set_messages([{"role": "system", "content": "Old prompt"}])
|
||||
|
||||
new_system = {
|
||||
"role": "system",
|
||||
"content": "Assistant with calculator and weather tools",
|
||||
}
|
||||
update_llm_context(context, new_system, all_functions)
|
||||
|
||||
# Verify all tools are present
|
||||
tools = context.tools
|
||||
assert len(tools.standard_tools) == 3
|
||||
|
||||
tool_names = {t.name for t in tools.standard_tools}
|
||||
assert "safe_calculator" in tool_names
|
||||
assert "get_current_time" in tool_names
|
||||
assert "get_weather" in tool_names
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tools_cached_after_first_fetch(self, mock_engine, sample_tools):
|
||||
"""Test that CustomToolManager caches tools after first fetch."""
|
||||
manager = CustomToolManager(mock_engine)
|
||||
|
||||
with patch(
|
||||
"api.services.workflow.pipecat_engine_custom_tools.get_organization_id_from_workflow_run"
|
||||
) as mock_get_org:
|
||||
mock_get_org.return_value = 1
|
||||
|
||||
with patch(
|
||||
"api.services.workflow.pipecat_engine_custom_tools.db_client"
|
||||
) as mock_db:
|
||||
mock_db.get_tools_by_uuids = AsyncMock(return_value=[sample_tools[0]])
|
||||
|
||||
# First fetch
|
||||
await manager.get_tool_schemas(["weather-uuid-123"])
|
||||
|
||||
# Verify tool is cached (cache stores raw schema dict, not FunctionSchema)
|
||||
cached = manager.get_cached_tool("get_weather")
|
||||
assert cached is not None
|
||||
tool, raw_schema = cached
|
||||
assert tool.tool_uuid == "weather-uuid-123"
|
||||
assert raw_schema["function"]["name"] == "get_weather"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_context_preserves_function_call_history(
|
||||
self, mock_engine, sample_tools
|
||||
):
|
||||
"""Test that update_llm_context preserves function call messages in history."""
|
||||
manager = CustomToolManager(mock_engine)
|
||||
|
||||
with patch(
|
||||
"api.services.workflow.pipecat_engine_custom_tools.get_organization_id_from_workflow_run"
|
||||
) as mock_get_org:
|
||||
mock_get_org.return_value = 1
|
||||
|
||||
with patch(
|
||||
"api.services.workflow.pipecat_engine_custom_tools.db_client"
|
||||
) as mock_db:
|
||||
mock_db.get_tools_by_uuids = AsyncMock(return_value=[sample_tools[0]])
|
||||
|
||||
# Get schemas - returns FunctionSchema objects
|
||||
schemas = await manager.get_tool_schemas(["weather-uuid-123"])
|
||||
|
||||
# Create context with function call history
|
||||
context = LLMContext()
|
||||
context.set_messages(
|
||||
[
|
||||
{"role": "system", "content": "Old system prompt"},
|
||||
{"role": "user", "content": "What's the weather in NYC?"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": None,
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "call_123",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"arguments": '{"location": "New York, NY"}',
|
||||
},
|
||||
}
|
||||
],
|
||||
},
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": "call_123",
|
||||
"content": '{"temperature": 72, "condition": "sunny"}',
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "The weather in NYC is 72°F and sunny!",
|
||||
},
|
||||
]
|
||||
)
|
||||
|
||||
new_system = {"role": "system", "content": "Updated weather assistant"}
|
||||
update_llm_context(context, new_system, schemas)
|
||||
|
||||
messages = context.messages
|
||||
# System + user + assistant(tool_call) + tool + assistant = 5
|
||||
assert len(messages) == 5
|
||||
|
||||
# Verify function call messages are preserved
|
||||
tool_call_msg = messages[2]
|
||||
assert tool_call_msg["role"] == "assistant"
|
||||
assert "tool_calls" in tool_call_msg
|
||||
|
||||
tool_result_msg = messages[3]
|
||||
assert tool_result_msg["role"] == "tool"
|
||||
assert tool_result_msg["tool_call_id"] == "call_123"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_tool_list_does_not_set_tools(self, mock_engine):
|
||||
"""Test that empty tool list doesn't set tools on context."""
|
||||
manager = CustomToolManager(mock_engine)
|
||||
|
||||
with patch(
|
||||
"api.services.workflow.pipecat_engine_custom_tools.get_organization_id_from_workflow_run"
|
||||
) as mock_get_org:
|
||||
mock_get_org.return_value = 1
|
||||
|
||||
with patch(
|
||||
"api.services.workflow.pipecat_engine_custom_tools.db_client"
|
||||
) as mock_db:
|
||||
mock_db.get_tools_by_uuids = AsyncMock(return_value=[])
|
||||
|
||||
schemas = await manager.get_tool_schemas([])
|
||||
assert schemas == []
|
||||
|
||||
context = LLMContext()
|
||||
context.set_messages([{"role": "system", "content": "Old"}])
|
||||
|
||||
new_system = {"role": "system", "content": "No tools available"}
|
||||
update_llm_context(context, new_system, [])
|
||||
|
||||
# Context should have updated message but no tools set
|
||||
assert context.messages[0]["content"] == "No tools available"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_numeric_and_boolean_parameter_types(self, mock_engine):
|
||||
"""Test that numeric and boolean parameter types are correctly handled."""
|
||||
tool_with_types = MockToolModel(
|
||||
tool_uuid="order-uuid",
|
||||
name="Place Order",
|
||||
description="Place an order for items",
|
||||
definition={
|
||||
"schema_version": 1,
|
||||
"type": "http_api",
|
||||
"config": {
|
||||
"method": "POST",
|
||||
"url": "https://api.example.com/orders",
|
||||
"parameters": [
|
||||
{
|
||||
"name": "item_id",
|
||||
"type": "string",
|
||||
"description": "Item identifier",
|
||||
"required": True,
|
||||
},
|
||||
{
|
||||
"name": "quantity",
|
||||
"type": "number",
|
||||
"description": "Number of items",
|
||||
"required": True,
|
||||
},
|
||||
{
|
||||
"name": "express_shipping",
|
||||
"type": "boolean",
|
||||
"description": "Use express shipping",
|
||||
"required": False,
|
||||
},
|
||||
],
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
manager = CustomToolManager(mock_engine)
|
||||
|
||||
with patch(
|
||||
"api.services.workflow.pipecat_engine_custom_tools.get_organization_id_from_workflow_run"
|
||||
) as mock_get_org:
|
||||
mock_get_org.return_value = 1
|
||||
|
||||
with patch(
|
||||
"api.services.workflow.pipecat_engine_custom_tools.db_client"
|
||||
) as mock_db:
|
||||
mock_db.get_tools_by_uuids = AsyncMock(return_value=[tool_with_types])
|
||||
|
||||
# Get schemas - returns FunctionSchema objects
|
||||
schemas = await manager.get_tool_schemas(["order-uuid"])
|
||||
schema = schemas[0]
|
||||
|
||||
# Verify types using FunctionSchema attributes
|
||||
assert schema.properties["item_id"]["type"] == "string"
|
||||
assert schema.properties["quantity"]["type"] == "number"
|
||||
assert schema.properties["express_shipping"]["type"] == "boolean"
|
||||
|
||||
# Update context - pass schema directly
|
||||
context = LLMContext()
|
||||
context.set_messages([{"role": "system", "content": "Old"}])
|
||||
update_llm_context(
|
||||
context, {"role": "system", "content": "Order assistant"}, schemas
|
||||
)
|
||||
|
||||
# Verify tool was set with correct types
|
||||
tool = context.tools.standard_tools[0]
|
||||
assert tool.name == "place_order"
|
||||
assert tool.properties["quantity"]["type"] == "number"
|
||||
assert tool.properties["express_shipping"]["type"] == "boolean"
|
||||
|
|
@ -1,33 +0,0 @@
|
|||
import os
|
||||
import uuid
|
||||
|
||||
import pytest
|
||||
|
||||
from api.db.user_client import UserClient
|
||||
from api.services.configuration.registry import ServiceProviders
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_default_configuration_created(db_session):
|
||||
# Set env variable for openai to simulate availability of default key
|
||||
os.environ["OPENAI_API_KEY"] = "sk-test-openai-key"
|
||||
|
||||
# Ensure deepgram env variable absent to focus test
|
||||
os.environ.pop("DEEPGRAM_API_KEY", None)
|
||||
|
||||
# Generate a unique (random) provider user ID for each test run
|
||||
test_provider_user_id = f"provider_user_{uuid.uuid4().hex}"
|
||||
user_client: UserClient = db_session # db_session fixture yields the client
|
||||
|
||||
user_model = await user_client.get_or_create_user_by_provider_id(
|
||||
test_provider_user_id
|
||||
)
|
||||
|
||||
config = await user_client.get_user_configurations(user_model.id)
|
||||
|
||||
assert config.llm is not None, "LLM config should be created when env key present"
|
||||
assert config.llm.provider == ServiceProviders.OPENAI
|
||||
assert config.llm.api_key == "sk-test-openai-key"
|
||||
|
||||
# Cleanup / restore env variable side-effects
|
||||
os.environ.pop("OPENAI_API_KEY", None)
|
||||
|
|
@ -1,122 +0,0 @@
|
|||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from api.services.workflow.disposition_mapper import (
|
||||
apply_disposition_mapping,
|
||||
get_organization_id_from_workflow_run,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_apply_disposition_mapping_with_valid_mapping():
|
||||
"""Test disposition mapping with valid configuration."""
|
||||
with patch("api.services.workflow.disposition_mapper.db_client") as mock_db_client:
|
||||
# Mock disposition mapping configuration
|
||||
mock_db_client.get_configuration_value = AsyncMock(
|
||||
return_value={
|
||||
"XFER": "TRANSFERRED",
|
||||
"ND": "NOT_QUALIFIED",
|
||||
"user_hangup": "HANGUP",
|
||||
}
|
||||
)
|
||||
|
||||
# Test mapping exists
|
||||
result = await apply_disposition_mapping("XFER", 1)
|
||||
assert result == "TRANSFERRED"
|
||||
|
||||
# Test mapping doesn't exist
|
||||
result = await apply_disposition_mapping("UNKNOWN", 1)
|
||||
assert result == "UNKNOWN"
|
||||
|
||||
# Verify db_client was called correctly
|
||||
mock_db_client.get_configuration_value.assert_called_with(
|
||||
1, "DISPOSITION_CODE_MAPPING", default={}
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_apply_disposition_mapping_no_organization_id():
|
||||
"""Test disposition mapping with no organization ID."""
|
||||
# Should return original value
|
||||
result = await apply_disposition_mapping("XFER", None)
|
||||
assert result == "XFER"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_apply_disposition_mapping_empty_value():
|
||||
"""Test disposition mapping with empty value."""
|
||||
# Should return original empty value
|
||||
result = await apply_disposition_mapping("", 1)
|
||||
assert result == ""
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_apply_disposition_mapping_error_handling():
|
||||
"""Test disposition mapping handles errors gracefully."""
|
||||
with patch("api.services.workflow.disposition_mapper.db_client") as mock_db_client:
|
||||
# Mock database error
|
||||
mock_db_client.get_configuration_value = AsyncMock(
|
||||
side_effect=Exception("Database error")
|
||||
)
|
||||
|
||||
# Should return original value on error
|
||||
result = await apply_disposition_mapping("XFER", 1)
|
||||
assert result == "XFER"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_organization_id_from_workflow_run():
|
||||
"""Test getting organization ID from workflow run ID."""
|
||||
with patch("api.services.workflow.disposition_mapper.db_client") as mock_db_client:
|
||||
# Mock workflow run with organization
|
||||
mock_workflow_run = MagicMock()
|
||||
mock_workflow_run.workflow.user.selected_organization_id = 123
|
||||
mock_db_client.get_workflow_run_by_id = AsyncMock(
|
||||
return_value=mock_workflow_run
|
||||
)
|
||||
|
||||
result = await get_organization_id_from_workflow_run(1)
|
||||
assert result == 123
|
||||
|
||||
# Verify db_client was called correctly
|
||||
mock_db_client.get_workflow_run_by_id.assert_called_once_with(1)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_organization_id_no_workflow_run():
|
||||
"""Test getting organization ID when workflow run doesn't exist."""
|
||||
with patch("api.services.workflow.disposition_mapper.db_client") as mock_db_client:
|
||||
# Mock no workflow run found
|
||||
mock_db_client.get_workflow_run_by_id = AsyncMock(return_value=None)
|
||||
|
||||
result = await get_organization_id_from_workflow_run(1)
|
||||
assert result is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_organization_id_no_user():
|
||||
"""Test getting organization ID when workflow has no user."""
|
||||
with patch("api.services.workflow.disposition_mapper.db_client") as mock_db_client:
|
||||
# Mock workflow run with no user
|
||||
mock_workflow_run = MagicMock()
|
||||
mock_workflow_run.workflow.user = None
|
||||
mock_db_client.get_workflow_run_by_id = AsyncMock(
|
||||
return_value=mock_workflow_run
|
||||
)
|
||||
|
||||
result = await get_organization_id_from_workflow_run(1)
|
||||
assert result is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_organization_id_error_handling():
|
||||
"""Test getting organization ID handles errors gracefully."""
|
||||
with patch("api.services.workflow.disposition_mapper.db_client") as mock_db_client:
|
||||
# Mock database error
|
||||
mock_db_client.get_workflow_run_by_id = AsyncMock(
|
||||
side_effect=Exception("Database error")
|
||||
)
|
||||
|
||||
result = await get_organization_id_from_workflow_run(1)
|
||||
assert result is None
|
||||
|
|
@ -1,370 +0,0 @@
|
|||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from pipecat.utils.enums import EndTaskReason
|
||||
|
||||
from api.services.pipecat.event_handlers import register_transport_event_handlers
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_dependencies():
|
||||
"""Create mock dependencies for event handlers."""
|
||||
# Store registered handlers
|
||||
registered_handlers = {}
|
||||
|
||||
def mock_event_handler(event_name):
|
||||
def decorator(func):
|
||||
registered_handlers[event_name] = func
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
mock_transport = MagicMock()
|
||||
mock_transport.event_handler = mock_event_handler
|
||||
|
||||
mock_task = MagicMock()
|
||||
mock_task.cancel = AsyncMock()
|
||||
|
||||
mock_engine = MagicMock()
|
||||
mock_engine.initialize = AsyncMock()
|
||||
mock_engine.cleanup = AsyncMock()
|
||||
|
||||
mock_audio_buffer = MagicMock()
|
||||
mock_audio_buffer.start_recording = AsyncMock()
|
||||
mock_audio_buffer.stop_recording = AsyncMock()
|
||||
|
||||
mock_usage_metrics_aggregator = MagicMock()
|
||||
mock_usage_metrics_aggregator.get_all_usage_metrics_serialized = MagicMock(
|
||||
return_value={"test": "metrics"}
|
||||
)
|
||||
|
||||
return {
|
||||
"transport": mock_transport,
|
||||
"workflow_run_id": 123,
|
||||
"audio_buffer": mock_audio_buffer,
|
||||
"task": mock_task,
|
||||
"engine": mock_engine,
|
||||
"usage_metrics_aggregator": mock_usage_metrics_aggregator,
|
||||
"audio_synchronizer": None,
|
||||
"registered_handlers": registered_handlers,
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_transport_disconnect_reason_mapping(mock_dependencies):
|
||||
"""Test that transport_disconnect_reason is mapped when no engine disconnect reason exists."""
|
||||
# Register event handlers
|
||||
register_transport_event_handlers(
|
||||
transport=mock_dependencies["transport"],
|
||||
workflow_run_id=mock_dependencies["workflow_run_id"],
|
||||
audio_buffer=mock_dependencies["audio_buffer"],
|
||||
task=mock_dependencies["task"],
|
||||
engine=mock_dependencies["engine"],
|
||||
usage_metrics_aggregator=mock_dependencies["usage_metrics_aggregator"],
|
||||
audio_synchronizer=mock_dependencies["audio_synchronizer"],
|
||||
)
|
||||
|
||||
# Get the on_client_disconnected handler
|
||||
handler = mock_dependencies["registered_handlers"]["on_client_disconnected"]
|
||||
|
||||
# Mock engine with no call disposition
|
||||
mock_dependencies["engine"].get_call_disposition.return_value = None
|
||||
mock_dependencies["engine"].get_gathered_context.return_value = {
|
||||
"agent_name": "Alex"
|
||||
}
|
||||
|
||||
# Mock the disposition mapper functions
|
||||
with patch(
|
||||
"api.services.pipecat.event_handlers.get_organization_id_from_workflow_run",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_get_org_id:
|
||||
with patch(
|
||||
"api.services.pipecat.event_handlers.apply_disposition_mapping",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_apply_mapping:
|
||||
with patch(
|
||||
"api.services.pipecat.event_handlers.db_client"
|
||||
) as mock_db_client:
|
||||
with patch(
|
||||
"api.services.pipecat.event_handlers.enqueue_job"
|
||||
) as mock_enqueue:
|
||||
# Mock organization ID
|
||||
mock_get_org_id.return_value = 1
|
||||
|
||||
# Mock call duration for user_hangup logic
|
||||
mock_dependencies[
|
||||
"usage_metrics_aggregator"
|
||||
].get_call_duration.return_value = 15
|
||||
|
||||
# Mock disposition mapping
|
||||
async def apply_mapping_side_effect(value, org_id):
|
||||
return {
|
||||
"NIBP": "NOT_INTERESTED_BUSINESS_PURPOSE",
|
||||
"user_qualified": "QUALIFIED",
|
||||
}.get(value, value)
|
||||
|
||||
mock_apply_mapping.side_effect = apply_mapping_side_effect
|
||||
|
||||
# Mock database operations
|
||||
mock_workflow_run = MagicMock()
|
||||
mock_workflow_run.id = 123
|
||||
mock_workflow_run.workflow_id = 1
|
||||
mock_workflow_run.organization_id = 1
|
||||
mock_workflow_run.gathered_context = {}
|
||||
mock_db_client.get_workflow_run_by_id = AsyncMock(
|
||||
return_value=mock_workflow_run
|
||||
)
|
||||
mock_db_client.update_workflow_run = AsyncMock()
|
||||
|
||||
# Call handler with transport_disconnect_reason
|
||||
await handler(
|
||||
mock_dependencies["transport"],
|
||||
participant=None,
|
||||
transport_disconnect_reason="user_hangup",
|
||||
)
|
||||
|
||||
# Verify disposition mapping was applied with NIBP (since duration > 10)
|
||||
mock_apply_mapping.assert_called_once_with("NIBP", 1)
|
||||
|
||||
# Verify database was updated with mapped value
|
||||
mock_db_client.update_workflow_run.assert_called_once()
|
||||
call_args = mock_db_client.update_workflow_run.call_args
|
||||
assert (
|
||||
call_args[1]["gathered_context"]["mapped_call_disposition"]
|
||||
== "NOT_INTERESTED_BUSINESS_PURPOSE"
|
||||
)
|
||||
|
||||
# Verify task was cancelled (no engine disconnect reason)
|
||||
mock_dependencies["task"].cancel.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_transport_disconnect_reason_user_hangup_short_call(mock_dependencies):
|
||||
"""Test that user_hangup with short call duration is mapped to HU."""
|
||||
# Register event handlers
|
||||
register_transport_event_handlers(
|
||||
transport=mock_dependencies["transport"],
|
||||
workflow_run_id=mock_dependencies["workflow_run_id"],
|
||||
audio_buffer=mock_dependencies["audio_buffer"],
|
||||
task=mock_dependencies["task"],
|
||||
engine=mock_dependencies["engine"],
|
||||
usage_metrics_aggregator=mock_dependencies["usage_metrics_aggregator"],
|
||||
audio_synchronizer=mock_dependencies["audio_synchronizer"],
|
||||
)
|
||||
|
||||
# Get the on_client_disconnected handler
|
||||
handler = mock_dependencies["registered_handlers"]["on_client_disconnected"]
|
||||
|
||||
# Mock engine with no call disposition
|
||||
mock_dependencies["engine"].get_call_disposition.return_value = None
|
||||
mock_dependencies["engine"].get_gathered_context.return_value = {
|
||||
"agent_name": "Alex"
|
||||
}
|
||||
|
||||
# Mock the disposition mapper functions
|
||||
with patch(
|
||||
"api.services.pipecat.event_handlers.get_organization_id_from_workflow_run",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_get_org_id:
|
||||
with patch(
|
||||
"api.services.pipecat.event_handlers.apply_disposition_mapping",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_apply_mapping:
|
||||
with patch(
|
||||
"api.services.pipecat.event_handlers.db_client"
|
||||
) as mock_db_client:
|
||||
with patch(
|
||||
"api.services.pipecat.event_handlers.enqueue_job"
|
||||
) as mock_enqueue:
|
||||
# Mock organization ID
|
||||
mock_get_org_id.return_value = 1
|
||||
|
||||
# Mock call duration for user_hangup logic (< 10 seconds)
|
||||
mock_dependencies[
|
||||
"usage_metrics_aggregator"
|
||||
].get_call_duration.return_value = 5
|
||||
|
||||
# Mock disposition mapping
|
||||
mock_apply_mapping.return_value = "HANGUP"
|
||||
|
||||
# Mock database operations
|
||||
mock_workflow_run = MagicMock()
|
||||
mock_workflow_run.id = 123
|
||||
mock_workflow_run.workflow_id = 1
|
||||
mock_workflow_run.organization_id = 1
|
||||
mock_workflow_run.gathered_context = {}
|
||||
mock_db_client.get_workflow_run_by_id = AsyncMock(
|
||||
return_value=mock_workflow_run
|
||||
)
|
||||
mock_db_client.update_workflow_run = AsyncMock()
|
||||
|
||||
# Call handler with transport_disconnect_reason
|
||||
await handler(
|
||||
mock_dependencies["transport"],
|
||||
participant=None,
|
||||
transport_disconnect_reason="user_hangup",
|
||||
)
|
||||
|
||||
# Verify disposition mapping was applied with HU (since duration < 10)
|
||||
mock_apply_mapping.assert_called_once_with("HU", 1)
|
||||
|
||||
# Verify database was updated with mapped value
|
||||
mock_db_client.update_workflow_run.assert_called_once()
|
||||
call_args = mock_db_client.update_workflow_run.call_args
|
||||
assert (
|
||||
call_args[1]["gathered_context"]["mapped_call_disposition"]
|
||||
== "HANGUP"
|
||||
)
|
||||
|
||||
# Verify task was cancelled (no engine disconnect reason)
|
||||
mock_dependencies["task"].cancel.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_engine_disconnect_reason_takes_precedence(mock_dependencies):
|
||||
"""Test that engine disconnect reason takes precedence and is not mapped."""
|
||||
# Register event handlers
|
||||
register_transport_event_handlers(
|
||||
transport=mock_dependencies["transport"],
|
||||
workflow_run_id=mock_dependencies["workflow_run_id"],
|
||||
audio_buffer=mock_dependencies["audio_buffer"],
|
||||
task=mock_dependencies["task"],
|
||||
engine=mock_dependencies["engine"],
|
||||
usage_metrics_aggregator=mock_dependencies["usage_metrics_aggregator"],
|
||||
audio_synchronizer=mock_dependencies["audio_synchronizer"],
|
||||
)
|
||||
|
||||
# Get the on_client_disconnected handler
|
||||
handler = mock_dependencies["registered_handlers"]["on_client_disconnected"]
|
||||
|
||||
# Mock engine with call disposition
|
||||
mock_dependencies["engine"].get_call_disposition.return_value = "user_qualified"
|
||||
mock_dependencies["engine"].get_gathered_context.return_value = {
|
||||
"agent_name": "Alex"
|
||||
}
|
||||
|
||||
# Mock the disposition mapper functions
|
||||
with patch(
|
||||
"api.services.pipecat.event_handlers.get_organization_id_from_workflow_run",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_get_org_id:
|
||||
with patch(
|
||||
"api.services.pipecat.event_handlers.apply_disposition_mapping",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_apply_mapping:
|
||||
with patch(
|
||||
"api.services.pipecat.event_handlers.db_client"
|
||||
) as mock_db_client:
|
||||
with patch(
|
||||
"api.services.pipecat.event_handlers.enqueue_job"
|
||||
) as mock_enqueue:
|
||||
# Mock organization ID
|
||||
mock_get_org_id.return_value = 1
|
||||
|
||||
# Mock disposition mapping for engine's reason
|
||||
mock_apply_mapping.return_value = "QUALIFIED"
|
||||
|
||||
# Mock database operations
|
||||
mock_workflow_run = MagicMock()
|
||||
mock_workflow_run.id = 123
|
||||
mock_workflow_run.workflow_id = 1
|
||||
mock_workflow_run.organization_id = 1
|
||||
mock_workflow_run.gathered_context = {}
|
||||
mock_db_client.get_workflow_run_by_id = AsyncMock(
|
||||
return_value=mock_workflow_run
|
||||
)
|
||||
mock_db_client.update_workflow_run = AsyncMock()
|
||||
|
||||
# Call handler with transport_disconnect_reason
|
||||
await handler(
|
||||
mock_dependencies["transport"],
|
||||
participant=None,
|
||||
transport_disconnect_reason="user_hangup",
|
||||
)
|
||||
|
||||
# Verify disposition mapping was called with engine's reason
|
||||
mock_apply_mapping.assert_called_once_with("user_qualified", 1)
|
||||
|
||||
# Verify database was updated with mapped value
|
||||
mock_db_client.update_workflow_run.assert_called_once()
|
||||
call_args = mock_db_client.update_workflow_run.call_args
|
||||
assert (
|
||||
call_args[1]["gathered_context"]["mapped_call_disposition"]
|
||||
== "QUALIFIED"
|
||||
)
|
||||
|
||||
# Verify task was NOT cancelled (engine disconnect reason exists)
|
||||
mock_dependencies["task"].cancel.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_disconnect_reason_uses_unknown(mock_dependencies):
|
||||
"""Test that when no disconnect reason is provided, UNKNOWN is used."""
|
||||
# Register event handlers
|
||||
register_transport_event_handlers(
|
||||
transport=mock_dependencies["transport"],
|
||||
workflow_run_id=mock_dependencies["workflow_run_id"],
|
||||
audio_buffer=mock_dependencies["audio_buffer"],
|
||||
task=mock_dependencies["task"],
|
||||
engine=mock_dependencies["engine"],
|
||||
usage_metrics_aggregator=mock_dependencies["usage_metrics_aggregator"],
|
||||
audio_synchronizer=mock_dependencies["audio_synchronizer"],
|
||||
)
|
||||
|
||||
# Get the on_client_disconnected handler
|
||||
handler = mock_dependencies["registered_handlers"]["on_client_disconnected"]
|
||||
|
||||
# Mock engine with no call disposition
|
||||
mock_dependencies["engine"].get_call_disposition.return_value = None
|
||||
mock_dependencies["engine"].get_gathered_context.return_value = {
|
||||
"agent_name": "Alex"
|
||||
}
|
||||
|
||||
with patch(
|
||||
"api.services.pipecat.event_handlers.get_organization_id_from_workflow_run"
|
||||
) as mock_get_org_id:
|
||||
with patch(
|
||||
"api.services.pipecat.event_handlers.apply_disposition_mapping"
|
||||
) as mock_apply_mapping:
|
||||
with patch(
|
||||
"api.services.pipecat.event_handlers.db_client"
|
||||
) as mock_db_client:
|
||||
with patch(
|
||||
"api.services.pipecat.event_handlers.enqueue_job"
|
||||
) as mock_enqueue:
|
||||
# Mock organization ID
|
||||
mock_get_org_id.return_value = 1
|
||||
|
||||
# Mock disposition mapping - should return UNKNOWN as-is
|
||||
mock_apply_mapping.return_value = EndTaskReason.UNKNOWN.value
|
||||
|
||||
# Mock database operations
|
||||
mock_workflow_run = MagicMock()
|
||||
mock_workflow_run.id = 123
|
||||
mock_workflow_run.workflow_id = 1
|
||||
mock_workflow_run.organization_id = 1
|
||||
mock_workflow_run.gathered_context = {}
|
||||
mock_db_client.get_workflow_run_by_id = AsyncMock(
|
||||
return_value=mock_workflow_run
|
||||
)
|
||||
mock_db_client.update_workflow_run = AsyncMock()
|
||||
|
||||
# Call handler without transport_disconnect_reason
|
||||
await handler(
|
||||
mock_dependencies["transport"],
|
||||
participant=None,
|
||||
transport_disconnect_reason=None,
|
||||
)
|
||||
|
||||
# Verify disposition mapping was called with UNKNOWN
|
||||
mock_apply_mapping.assert_called_once_with(
|
||||
EndTaskReason.UNKNOWN.value, 1
|
||||
)
|
||||
|
||||
# Verify database was updated with UNKNOWN
|
||||
mock_db_client.update_workflow_run.assert_called_once()
|
||||
call_args = mock_db_client.update_workflow_run.call_args
|
||||
assert (
|
||||
call_args[1]["gathered_context"]["mapped_call_disposition"]
|
||||
== EndTaskReason.UNKNOWN.value
|
||||
)
|
||||
|
|
@ -1,184 +0,0 @@
|
|||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from api.services.pipecat.audio_config import AudioConfig
|
||||
from api.services.pipecat.event_handlers import (
|
||||
register_audio_data_handler,
|
||||
register_transcript_handler,
|
||||
register_transport_event_handlers,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_transport_handlers_with_in_memory_buffers():
|
||||
"""Test that transport handlers create and return in-memory buffers."""
|
||||
# Mock dependencies
|
||||
transport = MagicMock()
|
||||
transport.event_handler = lambda event_name: lambda func: func
|
||||
|
||||
audio_buffer = AsyncMock()
|
||||
audio_synchronizer = AsyncMock()
|
||||
task = AsyncMock()
|
||||
engine = AsyncMock()
|
||||
engine.get_call_disposition.return_value = None
|
||||
engine.get_gathered_context.return_value = {}
|
||||
|
||||
usage_metrics_aggregator = AsyncMock()
|
||||
usage_metrics_aggregator.get_call_duration.return_value = 30
|
||||
usage_metrics_aggregator.get_all_usage_metrics_serialized.return_value = {}
|
||||
|
||||
# Create test audio config
|
||||
audio_config = AudioConfig(
|
||||
transport_in_sample_rate=16000,
|
||||
transport_out_sample_rate=16000,
|
||||
pipeline_sample_rate=16000,
|
||||
)
|
||||
|
||||
# Register handlers
|
||||
audio_buf, transcript_buf = register_transport_event_handlers(
|
||||
transport=transport,
|
||||
workflow_run_id=123,
|
||||
audio_buffer=audio_buffer,
|
||||
task=task,
|
||||
engine=engine,
|
||||
usage_metrics_aggregator=usage_metrics_aggregator,
|
||||
audio_synchronizer=audio_synchronizer,
|
||||
audio_config=audio_config,
|
||||
)
|
||||
|
||||
# Verify buffers were created with correct configuration
|
||||
assert audio_buf is not None
|
||||
assert transcript_buf is not None
|
||||
assert audio_buf._workflow_run_id == 123
|
||||
assert audio_buf._sample_rate == 16000
|
||||
assert audio_buf._num_channels == 1
|
||||
assert transcript_buf._workflow_run_id == 123
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_audio_handler_with_in_memory_buffer():
|
||||
"""Test audio handler uses in-memory buffer when provided."""
|
||||
# Mock audio synchronizer
|
||||
audio_synchronizer = MagicMock()
|
||||
handlers = {}
|
||||
|
||||
def mock_event_handler(event_name):
|
||||
def decorator(func):
|
||||
handlers[event_name] = func
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
audio_synchronizer.event_handler = mock_event_handler
|
||||
|
||||
# Mock in-memory buffer
|
||||
in_memory_buffer = AsyncMock()
|
||||
|
||||
# Register handler with buffer
|
||||
register_audio_data_handler(
|
||||
audio_synchronizer, workflow_run_id=123, in_memory_buffer=in_memory_buffer
|
||||
)
|
||||
|
||||
# Test the handler
|
||||
assert "on_merged_audio" in handlers
|
||||
handler = handlers["on_merged_audio"]
|
||||
|
||||
# Call handler with test data
|
||||
test_pcm = b"test_audio_data"
|
||||
await handler(None, test_pcm, 16000, 1)
|
||||
|
||||
# Verify buffer was used
|
||||
in_memory_buffer.append.assert_called_once_with(test_pcm)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_transcript_handler_with_in_memory_buffer():
|
||||
"""Test transcript handler uses in-memory buffer when provided."""
|
||||
# Mock transcript processor
|
||||
transcript = MagicMock()
|
||||
handlers = {}
|
||||
|
||||
def mock_event_handler(event_name):
|
||||
def decorator(func):
|
||||
handlers[event_name] = func
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
transcript.event_handler = mock_event_handler
|
||||
|
||||
# Mock in-memory buffer
|
||||
in_memory_buffer = AsyncMock()
|
||||
|
||||
# Register handler with buffer
|
||||
register_transcript_handler(
|
||||
transcript, workflow_run_id=456, in_memory_buffer=in_memory_buffer
|
||||
)
|
||||
|
||||
# Create test frame
|
||||
test_frame = MagicMock()
|
||||
test_frame.messages = [
|
||||
MagicMock(timestamp="00:00:01", role="user", content="Hello"),
|
||||
MagicMock(timestamp="00:00:02", role="assistant", content="Hi there"),
|
||||
]
|
||||
|
||||
# Test the handler
|
||||
handler = handlers["on_transcript_update"]
|
||||
await handler(None, test_frame)
|
||||
|
||||
# Verify buffer was used with correct format
|
||||
expected_text = "[00:00:01] user: Hello\n[00:00:02] assistant: Hi there\n"
|
||||
in_memory_buffer.append.assert_called_once_with(expected_text)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_audio_config_sample_rates():
|
||||
"""Test that different audio configs result in correct sample rates."""
|
||||
# Mock dependencies
|
||||
transport = MagicMock()
|
||||
transport.event_handler = lambda event_name: lambda func: func
|
||||
|
||||
audio_buffer = AsyncMock()
|
||||
audio_synchronizer = AsyncMock()
|
||||
task = AsyncMock()
|
||||
engine = AsyncMock()
|
||||
engine.get_call_disposition.return_value = None
|
||||
engine.get_gathered_context.return_value = {}
|
||||
|
||||
usage_metrics_aggregator = AsyncMock()
|
||||
usage_metrics_aggregator.get_all_usage_metrics_serialized.return_value = {}
|
||||
|
||||
# Test with 8kHz audio config (e.g., for Stasis/Twilio)
|
||||
audio_config_8k = AudioConfig(
|
||||
transport_in_sample_rate=8000,
|
||||
transport_out_sample_rate=8000,
|
||||
pipeline_sample_rate=8000,
|
||||
)
|
||||
|
||||
audio_buf_8k, _ = register_transport_event_handlers(
|
||||
transport=transport,
|
||||
workflow_run_id=456,
|
||||
audio_buffer=audio_buffer,
|
||||
task=task,
|
||||
engine=engine,
|
||||
usage_metrics_aggregator=usage_metrics_aggregator,
|
||||
audio_synchronizer=audio_synchronizer,
|
||||
audio_config=audio_config_8k,
|
||||
)
|
||||
|
||||
assert audio_buf_8k._sample_rate == 8000
|
||||
|
||||
# Test with no audio config (should default to 16kHz)
|
||||
audio_buf_default, _ = register_transport_event_handlers(
|
||||
transport=transport,
|
||||
workflow_run_id=789,
|
||||
audio_buffer=audio_buffer,
|
||||
task=task,
|
||||
engine=engine,
|
||||
usage_metrics_aggregator=usage_metrics_aggregator,
|
||||
audio_synchronizer=audio_synchronizer,
|
||||
audio_config=None,
|
||||
)
|
||||
|
||||
assert audio_buf_default._sample_rate == 16000
|
||||
|
|
@ -1,162 +0,0 @@
|
|||
"""Test filter functionality."""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from api.db.filters import ATTRIBUTE_FIELD_MAPPING, apply_workflow_run_filters
|
||||
|
||||
|
||||
def test_attribute_field_mapping():
|
||||
"""Test that all required attributes are mapped."""
|
||||
expected_attributes = [
|
||||
"dateRange",
|
||||
"dispositionCode",
|
||||
"duration",
|
||||
"status",
|
||||
"tokenUsage",
|
||||
"runId",
|
||||
"workflowId",
|
||||
"callTags",
|
||||
"phoneNumber",
|
||||
]
|
||||
|
||||
for attr in expected_attributes:
|
||||
assert attr in ATTRIBUTE_FIELD_MAPPING, f"Missing mapping for {attr}"
|
||||
|
||||
|
||||
def test_filter_with_explicit_type():
|
||||
"""Test that filters work with explicit type from UI."""
|
||||
|
||||
# Mock query
|
||||
mock_query = MagicMock()
|
||||
mock_query.where = MagicMock(return_value=mock_query)
|
||||
|
||||
test_cases = [
|
||||
# Date range filter
|
||||
{
|
||||
"filters": [
|
||||
{
|
||||
"attribute": "dateRange",
|
||||
"type": "dateRange",
|
||||
"value": {"from": "2024-01-01", "to": "2024-01-31"},
|
||||
}
|
||||
],
|
||||
},
|
||||
# Multi-select filter
|
||||
{
|
||||
"filters": [
|
||||
{
|
||||
"attribute": "dispositionCode",
|
||||
"type": "multiSelect",
|
||||
"value": {"codes": ["XFER", "HU"]},
|
||||
}
|
||||
],
|
||||
},
|
||||
# Number range filter
|
||||
{
|
||||
"filters": [
|
||||
{
|
||||
"attribute": "duration",
|
||||
"type": "numberRange",
|
||||
"value": {"min": 60, "max": 300},
|
||||
}
|
||||
],
|
||||
},
|
||||
# Radio/status filter
|
||||
{
|
||||
"filters": [
|
||||
{
|
||||
"attribute": "status",
|
||||
"type": "radio",
|
||||
"value": {"status": "completed"},
|
||||
}
|
||||
],
|
||||
},
|
||||
# Number filter
|
||||
{
|
||||
"filters": [
|
||||
{"attribute": "runId", "type": "number", "value": {"value": 123}}
|
||||
],
|
||||
},
|
||||
# Text filter
|
||||
{
|
||||
"filters": [
|
||||
{
|
||||
"attribute": "phoneNumber",
|
||||
"type": "text",
|
||||
"value": {"value": "+1234567890"},
|
||||
}
|
||||
],
|
||||
},
|
||||
# Tags filter
|
||||
{
|
||||
"filters": [
|
||||
{
|
||||
"attribute": "callTags",
|
||||
"type": "tags",
|
||||
"value": {"codes": ["tag1", "tag2"]},
|
||||
}
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
for test_case in test_cases:
|
||||
result = apply_workflow_run_filters(mock_query, test_case["filters"])
|
||||
# The function should process the filter without errors
|
||||
assert result is not None
|
||||
|
||||
|
||||
def test_filter_format_with_type():
|
||||
"""Test that filters work with attribute, type, and value."""
|
||||
|
||||
mock_query = MagicMock()
|
||||
mock_query.where = MagicMock(return_value=mock_query)
|
||||
|
||||
# Test with various filter combinations
|
||||
filters = [
|
||||
{
|
||||
"attribute": "dispositionCode",
|
||||
"type": "multiSelect",
|
||||
"value": {"codes": ["NIBP"]},
|
||||
},
|
||||
{
|
||||
"attribute": "duration",
|
||||
"type": "numberRange",
|
||||
"value": {"min": 0, "max": 60},
|
||||
},
|
||||
{"attribute": "phoneNumber", "type": "text", "value": {"value": "555"}},
|
||||
]
|
||||
|
||||
result = apply_workflow_run_filters(mock_query, filters)
|
||||
|
||||
# Should have called where() for applying filters
|
||||
assert mock_query.where.called
|
||||
assert result is not None
|
||||
|
||||
|
||||
def test_unknown_attribute_ignored():
|
||||
"""Test that unknown attributes are safely ignored."""
|
||||
|
||||
mock_query = MagicMock()
|
||||
mock_query.where = MagicMock(return_value=mock_query)
|
||||
|
||||
filters = [
|
||||
{"attribute": "unknownAttribute", "value": {"value": "test"}},
|
||||
{"attribute": "dispositionCode", "value": {"codes": ["XFER"]}},
|
||||
]
|
||||
|
||||
result = apply_workflow_run_filters(mock_query, filters)
|
||||
|
||||
# Should still process the valid filter
|
||||
assert result is not None
|
||||
|
||||
|
||||
def test_empty_filters():
|
||||
"""Test that empty filters return the query unchanged."""
|
||||
|
||||
mock_query = MagicMock()
|
||||
|
||||
result = apply_workflow_run_filters(mock_query, None)
|
||||
assert result == mock_query
|
||||
|
||||
result = apply_workflow_run_filters(mock_query, [])
|
||||
assert result == mock_query
|
||||
|
|
@ -1,249 +0,0 @@
|
|||
"""Tests for global prompt functionality in workflow engine."""
|
||||
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
from pipecat.services.openai.llm import OpenAILLMContext
|
||||
|
||||
from api.services.workflow.dto import (
|
||||
EdgeDataDTO,
|
||||
NodeDataDTO,
|
||||
NodeType,
|
||||
ReactFlowDTO,
|
||||
RFEdgeDTO,
|
||||
RFNodeDTO,
|
||||
)
|
||||
from api.services.workflow.pipecat_engine import PipecatEngine
|
||||
from api.services.workflow.workflow import WorkflowGraph
|
||||
|
||||
|
||||
class TestGlobalPrompt:
|
||||
"""Test suite for global prompt feature."""
|
||||
|
||||
@pytest.fixture
|
||||
def workflow_with_global_node(self):
|
||||
"""Create a workflow with a global node and test nodes."""
|
||||
nodes = [
|
||||
RFNodeDTO(
|
||||
id="global",
|
||||
type=NodeType.globalNode,
|
||||
position={"x": 0, "y": 0},
|
||||
data=NodeDataDTO(
|
||||
name="Global Node",
|
||||
prompt="This is the global context: {{company_name}}",
|
||||
is_static=False,
|
||||
),
|
||||
),
|
||||
RFNodeDTO(
|
||||
id="start",
|
||||
type=NodeType.startNode,
|
||||
position={"x": 100, "y": 100},
|
||||
data=NodeDataDTO(
|
||||
name="Start Call",
|
||||
prompt="Welcome to our service!",
|
||||
is_static=False,
|
||||
is_start=True,
|
||||
add_global_prompt=True, # Enable global prompt
|
||||
),
|
||||
),
|
||||
RFNodeDTO(
|
||||
id="agent1",
|
||||
type=NodeType.agentNode,
|
||||
position={"x": 200, "y": 200},
|
||||
data=NodeDataDTO(
|
||||
name="Agent 1",
|
||||
prompt="How can I help you today?",
|
||||
add_global_prompt=False, # Disable global prompt
|
||||
),
|
||||
),
|
||||
RFNodeDTO(
|
||||
id="agent2",
|
||||
type=NodeType.agentNode,
|
||||
position={"x": 300, "y": 300},
|
||||
data=NodeDataDTO(
|
||||
name="Agent 2",
|
||||
prompt="Please provide your details.",
|
||||
add_global_prompt=True, # Enable global prompt
|
||||
),
|
||||
),
|
||||
RFNodeDTO(
|
||||
id="end",
|
||||
type=NodeType.endNode,
|
||||
position={"x": 400, "y": 400},
|
||||
data=NodeDataDTO(
|
||||
name="End Call",
|
||||
prompt="Thank you for calling!",
|
||||
is_static=True,
|
||||
is_end=True,
|
||||
add_global_prompt=True, # Enable global prompt (but static)
|
||||
),
|
||||
),
|
||||
]
|
||||
|
||||
edges = [
|
||||
RFEdgeDTO(
|
||||
id="e1",
|
||||
source="start",
|
||||
target="agent1",
|
||||
data=EdgeDataDTO(label="Next", condition="Continue to agent"),
|
||||
),
|
||||
RFEdgeDTO(
|
||||
id="e2",
|
||||
source="agent1",
|
||||
target="agent2",
|
||||
data=EdgeDataDTO(label="Details", condition="Get user details"),
|
||||
),
|
||||
RFEdgeDTO(
|
||||
id="e3",
|
||||
source="agent2",
|
||||
target="end",
|
||||
data=EdgeDataDTO(label="Finish", condition="End the call"),
|
||||
),
|
||||
]
|
||||
|
||||
flow_dto = ReactFlowDTO(nodes=nodes, edges=edges)
|
||||
return WorkflowGraph(flow_dto)
|
||||
|
||||
@pytest.fixture
|
||||
def mock_dependencies(self):
|
||||
"""Create mock dependencies for PipecatEngine initialization."""
|
||||
return {
|
||||
"task": Mock(),
|
||||
"llm": Mock(),
|
||||
"context": Mock(spec=OpenAILLMContext),
|
||||
"tts": Mock(),
|
||||
"transport": Mock(),
|
||||
"call_context_vars": {"company_name": "Dograh Inc"},
|
||||
}
|
||||
|
||||
@pytest.fixture
|
||||
def engine(self, mock_dependencies, workflow_with_global_node):
|
||||
"""Create a PipecatEngine instance with test workflow."""
|
||||
mock_dependencies["workflow"] = workflow_with_global_node
|
||||
return PipecatEngine(**mock_dependencies)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_global_prompt_enabled(self, engine):
|
||||
"""Test that global prompt is prepended when add_global_prompt is True."""
|
||||
# Test with start node (add_global_prompt=True)
|
||||
start_node = engine.workflow.nodes["start"]
|
||||
(
|
||||
system_message,
|
||||
functions,
|
||||
) = await engine._compose_system_message_functions_for_node(start_node)
|
||||
|
||||
# Global prompt should be included
|
||||
expected_content = (
|
||||
"This is the global context: Dograh Inc\n\nWelcome to our service!"
|
||||
)
|
||||
assert system_message["content"] == expected_content
|
||||
assert system_message["role"] == "system"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_global_prompt_disabled(self, engine):
|
||||
"""Test that global prompt is not prepended when add_global_prompt is False."""
|
||||
# Test with agent1 node (add_global_prompt=False)
|
||||
agent1_node = engine.workflow.nodes["agent1"]
|
||||
(
|
||||
system_message,
|
||||
functions,
|
||||
) = await engine._compose_system_message_functions_for_node(agent1_node)
|
||||
|
||||
# Global prompt should NOT be included
|
||||
expected_content = "How can I help you today?"
|
||||
assert system_message["content"] == expected_content
|
||||
assert "global context" not in system_message["content"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_global_prompt_with_static_node(self, engine):
|
||||
"""Test that static nodes don't use global prompt in engine (even if enabled)."""
|
||||
# Static nodes are handled differently - they use TTSSpeakFrame directly
|
||||
# This test verifies the compose_system_message behavior for completeness
|
||||
end_node = engine.workflow.nodes["end"]
|
||||
|
||||
# Even though add_global_prompt=True, static nodes handle prompts differently
|
||||
# The _compose_system_message_functions_for_node is still called for consistency
|
||||
(
|
||||
system_message,
|
||||
functions,
|
||||
) = await engine._compose_system_message_functions_for_node(end_node)
|
||||
|
||||
# For static nodes, the global prompt would still be composed if enabled
|
||||
expected_content = (
|
||||
"This is the global context: Dograh Inc\n\nThank you for calling!"
|
||||
)
|
||||
assert system_message["content"] == expected_content
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_global_prompt_variable_substitution(self, engine):
|
||||
"""Test that variables in global prompt are properly substituted."""
|
||||
agent2_node = engine.workflow.nodes["agent2"]
|
||||
(
|
||||
system_message,
|
||||
functions,
|
||||
) = await engine._compose_system_message_functions_for_node(agent2_node)
|
||||
|
||||
# Verify variable substitution in global prompt
|
||||
assert "Dograh Inc" in system_message["content"]
|
||||
assert "{{company_name}}" not in system_message["content"]
|
||||
|
||||
# Full expected content
|
||||
expected_content = (
|
||||
"This is the global context: Dograh Inc\n\nPlease provide your details."
|
||||
)
|
||||
assert system_message["content"] == expected_content
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_global_node_scenario(self, engine):
|
||||
"""Test behavior when there's no global node in the workflow."""
|
||||
# Remove global node from workflow
|
||||
engine.workflow.global_node_id = None
|
||||
|
||||
start_node = engine.workflow.nodes["start"]
|
||||
(
|
||||
system_message,
|
||||
functions,
|
||||
) = await engine._compose_system_message_functions_for_node(start_node)
|
||||
|
||||
# Should only have the node's own prompt
|
||||
assert system_message["content"] == "Welcome to our service!"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_global_prompt(self, engine):
|
||||
"""Test behavior when global prompt is empty."""
|
||||
# Set global prompt to empty string
|
||||
engine.workflow.nodes["global"].prompt = ""
|
||||
|
||||
start_node = engine.workflow.nodes["start"]
|
||||
(
|
||||
system_message,
|
||||
functions,
|
||||
) = await engine._compose_system_message_functions_for_node(start_node)
|
||||
|
||||
# Should only have the node's own prompt (empty global prompt is filtered out)
|
||||
assert system_message["content"] == "Welcome to our service!"
|
||||
|
||||
def test_default_add_global_prompt_value(self):
|
||||
"""Test that add_global_prompt defaults to True in NodeDataDTO."""
|
||||
node_data = NodeDataDTO(name="Test", prompt="Test prompt")
|
||||
assert node_data.add_global_prompt is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multiple_prompts_concatenation(self, engine):
|
||||
"""Test proper concatenation of global and node prompts."""
|
||||
# Test with agent2 node that has global prompt enabled
|
||||
agent2_node = engine.workflow.nodes["agent2"]
|
||||
|
||||
(
|
||||
system_message,
|
||||
functions,
|
||||
) = await engine._compose_system_message_functions_for_node(agent2_node)
|
||||
|
||||
# Should have global and node prompts concatenated with double newlines
|
||||
# (extraction prompt is no longer included in system message)
|
||||
expected_parts = [
|
||||
"This is the global context: Dograh Inc",
|
||||
"Please provide your details.",
|
||||
]
|
||||
expected_content = "\n\n".join(expected_parts)
|
||||
assert system_message["content"] == expected_content
|
||||
|
|
@ -1,175 +0,0 @@
|
|||
"""Unit tests for global prompt functionality - no DB dependencies."""
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# Add the api directory to the Python path
|
||||
api_path = Path(__file__).parent.parent
|
||||
sys.path.insert(0, str(api_path))
|
||||
|
||||
from services.workflow.dto import (
|
||||
EdgeDataDTO,
|
||||
NodeDataDTO,
|
||||
NodeType,
|
||||
ReactFlowDTO,
|
||||
RFEdgeDTO,
|
||||
RFNodeDTO,
|
||||
)
|
||||
from services.workflow.workflow import WorkflowGraph
|
||||
|
||||
|
||||
def test_node_data_dto_default_global_prompt():
|
||||
"""Test that add_global_prompt defaults to True."""
|
||||
node_data = NodeDataDTO(name="Test Node", prompt="Test prompt")
|
||||
assert node_data.add_global_prompt is True
|
||||
print("✓ NodeDataDTO defaults add_global_prompt to True")
|
||||
|
||||
|
||||
def test_node_data_dto_explicit_global_prompt():
|
||||
"""Test explicit setting of add_global_prompt."""
|
||||
# Test with False
|
||||
node_data_false = NodeDataDTO(
|
||||
name="Test Node", prompt="Test prompt", add_global_prompt=False
|
||||
)
|
||||
assert node_data_false.add_global_prompt is False
|
||||
|
||||
# Test with True
|
||||
node_data_true = NodeDataDTO(
|
||||
name="Test Node", prompt="Test prompt", add_global_prompt=True
|
||||
)
|
||||
assert node_data_true.add_global_prompt is True
|
||||
print("✓ NodeDataDTO respects explicit add_global_prompt values")
|
||||
|
||||
|
||||
def test_workflow_node_inherits_global_prompt_setting():
|
||||
"""Test that workflow Node inherits add_global_prompt from NodeDataDTO."""
|
||||
nodes = [
|
||||
RFNodeDTO(
|
||||
id="start",
|
||||
type=NodeType.startNode,
|
||||
position={"x": 0, "y": 0},
|
||||
data=NodeDataDTO(
|
||||
name="Start",
|
||||
prompt="Start prompt",
|
||||
is_start=True,
|
||||
add_global_prompt=True,
|
||||
),
|
||||
),
|
||||
RFNodeDTO(
|
||||
id="node1",
|
||||
type=NodeType.agentNode,
|
||||
position={"x": 100, "y": 0},
|
||||
data=NodeDataDTO(
|
||||
name="Node with global", prompt="Test prompt", add_global_prompt=True
|
||||
),
|
||||
),
|
||||
RFNodeDTO(
|
||||
id="node2",
|
||||
type=NodeType.agentNode,
|
||||
position={"x": 200, "y": 0},
|
||||
data=NodeDataDTO(
|
||||
name="Node without global",
|
||||
prompt="Test prompt",
|
||||
add_global_prompt=False,
|
||||
),
|
||||
),
|
||||
RFNodeDTO(
|
||||
id="end",
|
||||
type=NodeType.endNode,
|
||||
position={"x": 300, "y": 0},
|
||||
data=NodeDataDTO(
|
||||
name="End", prompt="End prompt", is_end=True, add_global_prompt=True
|
||||
),
|
||||
),
|
||||
]
|
||||
|
||||
edges = [
|
||||
RFEdgeDTO(
|
||||
id="e1",
|
||||
source="start",
|
||||
target="node1",
|
||||
data=EdgeDataDTO(label="Next", condition="Continue"),
|
||||
),
|
||||
RFEdgeDTO(
|
||||
id="e2",
|
||||
source="node1",
|
||||
target="node2",
|
||||
data=EdgeDataDTO(label="Next", condition="Continue"),
|
||||
),
|
||||
RFEdgeDTO(
|
||||
id="e3",
|
||||
source="node2",
|
||||
target="end",
|
||||
data=EdgeDataDTO(label="End", condition="Finish"),
|
||||
),
|
||||
]
|
||||
|
||||
flow_dto = ReactFlowDTO(nodes=nodes, edges=edges)
|
||||
workflow = WorkflowGraph(flow_dto)
|
||||
|
||||
assert workflow.nodes["start"].add_global_prompt is True
|
||||
assert workflow.nodes["node1"].add_global_prompt is True
|
||||
assert workflow.nodes["node2"].add_global_prompt is False
|
||||
assert workflow.nodes["end"].add_global_prompt is True
|
||||
print("✓ Workflow nodes correctly inherit add_global_prompt setting")
|
||||
|
||||
|
||||
def test_compose_system_message_respects_global_prompt_flag():
|
||||
"""Test that system message composition respects add_global_prompt flag."""
|
||||
# This is a simplified version - in real tests we'd use the full engine
|
||||
# But this demonstrates the logic
|
||||
|
||||
class MockNode:
|
||||
def __init__(self, add_global_prompt, prompt):
|
||||
self.add_global_prompt = add_global_prompt
|
||||
self.prompt = prompt
|
||||
self.out_edges = []
|
||||
self.extraction_enabled = False
|
||||
|
||||
# Simulate the logic from _compose_system_message_functions_for_node
|
||||
def compose_message(node, global_prompt):
|
||||
prompts = []
|
||||
|
||||
# Only add global prompt if node.add_global_prompt is True
|
||||
if global_prompt and node.add_global_prompt:
|
||||
prompts.append(global_prompt)
|
||||
|
||||
prompts.append(node.prompt)
|
||||
|
||||
return "\n\n".join(p for p in prompts if p)
|
||||
|
||||
global_prompt = "This is the global context"
|
||||
|
||||
# Test with add_global_prompt=True
|
||||
node_with_global = MockNode(add_global_prompt=True, prompt="Node prompt")
|
||||
message_with = compose_message(node_with_global, global_prompt)
|
||||
assert message_with == "This is the global context\n\nNode prompt"
|
||||
|
||||
# Test with add_global_prompt=False
|
||||
node_without_global = MockNode(add_global_prompt=False, prompt="Node prompt")
|
||||
message_without = compose_message(node_without_global, global_prompt)
|
||||
assert message_without == "Node prompt"
|
||||
|
||||
print("✓ System message composition respects add_global_prompt flag")
|
||||
|
||||
|
||||
def test_static_nodes_with_global_prompt():
|
||||
"""Test static nodes can have add_global_prompt setting."""
|
||||
static_node_data = NodeDataDTO(
|
||||
name="Static Node", prompt="Static text", is_static=True, add_global_prompt=True
|
||||
)
|
||||
|
||||
assert static_node_data.is_static is True
|
||||
assert static_node_data.add_global_prompt is True
|
||||
print("✓ Static nodes can have add_global_prompt setting")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Run all tests
|
||||
test_node_data_dto_default_global_prompt()
|
||||
test_node_data_dto_explicit_global_prompt()
|
||||
test_workflow_node_inherits_global_prompt_setting()
|
||||
test_compose_system_message_respects_global_prompt_flag()
|
||||
test_static_nodes_with_global_prompt()
|
||||
|
||||
print("\n✅ All unit tests passed!")
|
||||
|
|
@ -1,248 +0,0 @@
|
|||
"""
|
||||
Test cases for _leave_counter mechanism in transport clients.
|
||||
|
||||
This test suite verifies that the _leave_counter prevents premature disconnection
|
||||
when both input and output transports are using the same client.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock, Mock
|
||||
|
||||
import pytest
|
||||
from pipecat.frames.frames import EndFrame, StartFrame
|
||||
from pipecat.transports.network.fastapi_websocket import (
|
||||
FastAPIWebsocketCallbacks,
|
||||
FastAPIWebsocketClient,
|
||||
FastAPIWebsocketParams,
|
||||
FastAPIWebsocketTransport,
|
||||
)
|
||||
from pipecat.transports.network.small_webrtc import SmallWebRTCClient
|
||||
|
||||
from api.services.telephony.stasis_rtp_client import StasisRTPClient
|
||||
|
||||
|
||||
class TestLeaveCounterFastAPIWebsocket:
|
||||
"""Test the _leave_counter mechanism in FastAPIWebsocketClient."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_leave_counter_prevents_early_disconnect(self):
|
||||
"""Test that disconnect only happens when both transports have disconnected."""
|
||||
# Create mock websocket
|
||||
mock_websocket = Mock()
|
||||
mock_websocket.close = AsyncMock()
|
||||
# Set client_state directly to WebSocketState.CONNECTED value
|
||||
from starlette.websockets import WebSocketState
|
||||
|
||||
mock_websocket.client_state = WebSocketState.CONNECTED
|
||||
|
||||
# Create callbacks
|
||||
callbacks = FastAPIWebsocketCallbacks(
|
||||
on_client_connected=AsyncMock(),
|
||||
on_client_disconnected=AsyncMock(),
|
||||
on_session_timeout=AsyncMock(),
|
||||
)
|
||||
|
||||
# Create client
|
||||
client = FastAPIWebsocketClient(
|
||||
mock_websocket, is_binary=False, callbacks=callbacks
|
||||
)
|
||||
|
||||
# Create StartFrame
|
||||
start_frame = StartFrame()
|
||||
|
||||
# Simulate both input and output transports calling setup
|
||||
await client.setup(start_frame) # Input transport
|
||||
assert client._leave_counter == 1
|
||||
|
||||
await client.setup(start_frame) # Output transport
|
||||
assert client._leave_counter == 2
|
||||
|
||||
# First disconnect - should not actually disconnect
|
||||
await client.disconnect()
|
||||
assert client._leave_counter == 1
|
||||
mock_websocket.close.assert_not_called()
|
||||
callbacks.on_client_disconnected.assert_not_called()
|
||||
|
||||
# Second disconnect - should actually disconnect
|
||||
await client.disconnect()
|
||||
assert client._leave_counter == 0
|
||||
mock_websocket.close.assert_called_once()
|
||||
callbacks.on_client_disconnected.assert_called_once()
|
||||
|
||||
|
||||
class TestLeaveCounterStasisRTP:
|
||||
"""Test the _leave_counter mechanism in StasisRTPClient."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_leave_counter_prevents_early_disconnect(self):
|
||||
"""Test that disconnect only happens when both transports have disconnected."""
|
||||
# Create mock connection
|
||||
mock_connection = Mock()
|
||||
mock_connection.is_connected.return_value = True
|
||||
mock_connection.disconnect = AsyncMock()
|
||||
mock_connection.notify_sockets_closed = AsyncMock()
|
||||
|
||||
# Mock event_handler as a callable that acts as a decorator
|
||||
def mock_event_handler(event_name):
|
||||
def decorator(func):
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
mock_connection.event_handler = mock_event_handler
|
||||
|
||||
# Create callbacks
|
||||
from api.services.telephony.stasis_rtp_transport import StasisRTPCallbacks
|
||||
|
||||
callbacks = StasisRTPCallbacks(
|
||||
on_client_connected=AsyncMock(),
|
||||
on_client_disconnected=AsyncMock(),
|
||||
on_client_closed=AsyncMock(),
|
||||
)
|
||||
|
||||
# Create client
|
||||
client = StasisRTPClient(mock_connection, callbacks)
|
||||
|
||||
# Create StartFrame
|
||||
start_frame = StartFrame()
|
||||
|
||||
# Simulate both input and output transports calling setup
|
||||
await client.setup(start_frame) # Input transport
|
||||
assert client._leave_counter == 1
|
||||
|
||||
await client.setup(start_frame) # Output transport
|
||||
assert client._leave_counter == 2
|
||||
|
||||
# First disconnect - should not actually disconnect
|
||||
await client.disconnect()
|
||||
assert client._leave_counter == 1
|
||||
mock_connection.disconnect.assert_not_called()
|
||||
|
||||
# Second disconnect - should actually disconnect
|
||||
await client.disconnect()
|
||||
assert client._leave_counter == 0
|
||||
mock_connection.disconnect.assert_called_once()
|
||||
|
||||
|
||||
class TestLeaveCounterSmallWebRTC:
|
||||
"""Test the _leave_counter mechanism in SmallWebRTCClient."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_leave_counter_prevents_early_disconnect(self):
|
||||
"""Test that disconnect only happens when both transports have disconnected."""
|
||||
# Create mock connection
|
||||
mock_connection = Mock()
|
||||
mock_connection.is_connected.return_value = True
|
||||
mock_connection.disconnect = AsyncMock()
|
||||
mock_connection.notify_sockets_closed = AsyncMock()
|
||||
|
||||
# Mock event_handler as a callable that acts as a decorator
|
||||
def mock_event_handler(event_name):
|
||||
def decorator(func):
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
mock_connection.event_handler = mock_event_handler
|
||||
|
||||
# Create callbacks
|
||||
from pipecat.transports.network.small_webrtc import SmallWebRTCCallbacks
|
||||
|
||||
callbacks = SmallWebRTCCallbacks(
|
||||
on_app_message=AsyncMock(),
|
||||
on_client_connected=AsyncMock(),
|
||||
on_client_disconnected=AsyncMock(),
|
||||
)
|
||||
|
||||
# Create client
|
||||
client = SmallWebRTCClient(mock_connection, callbacks)
|
||||
|
||||
# Create StartFrame with required attributes
|
||||
start_frame = StartFrame()
|
||||
|
||||
# Create mock transport params
|
||||
from pipecat.transports.base_transport import TransportParams
|
||||
|
||||
params = TransportParams(
|
||||
audio_in_channels=1, audio_in_sample_rate=16000, audio_out_sample_rate=16000
|
||||
)
|
||||
|
||||
# Simulate both input and output transports calling setup
|
||||
await client.setup(params, start_frame) # Input transport
|
||||
assert client._leave_counter == 1
|
||||
|
||||
await client.setup(params, start_frame) # Output transport
|
||||
assert client._leave_counter == 2
|
||||
|
||||
# First disconnect - should not actually disconnect
|
||||
await client.disconnect()
|
||||
assert client._leave_counter == 1
|
||||
mock_connection.disconnect.assert_not_called()
|
||||
|
||||
# Second disconnect - should actually disconnect
|
||||
await client.disconnect()
|
||||
assert client._leave_counter == 0
|
||||
mock_connection.disconnect.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Complex integration test - requires additional mocking")
|
||||
@pytest.mark.asyncio
|
||||
async def test_transport_lifecycle_with_leave_counter():
|
||||
"""Test complete transport lifecycle with proper leave counter handling."""
|
||||
# Create mock websocket
|
||||
mock_websocket = Mock()
|
||||
mock_websocket.close = AsyncMock()
|
||||
# Set client_state directly to WebSocketState.CONNECTED value
|
||||
from starlette.websockets import WebSocketState
|
||||
|
||||
mock_websocket.client_state = WebSocketState.CONNECTED
|
||||
mock_websocket.iter_bytes = Mock(return_value=iter([]))
|
||||
mock_websocket.send_bytes = AsyncMock()
|
||||
|
||||
# Create transport
|
||||
params = FastAPIWebsocketParams(audio_in_enabled=True, audio_out_enabled=True)
|
||||
transport = FastAPIWebsocketTransport(mock_websocket, params)
|
||||
|
||||
# Get input and output transports
|
||||
input_transport = transport.input()
|
||||
output_transport = transport.output()
|
||||
|
||||
# Setup the transport with required components
|
||||
from pipecat.clocks.system_clock import SystemClock
|
||||
from pipecat.processors.frame_processor import FrameProcessorSetup
|
||||
from pipecat.utils.asyncio.task_manager import TaskManager, TaskManagerParams
|
||||
|
||||
clock = SystemClock()
|
||||
task_manager = TaskManager()
|
||||
|
||||
# Setup task manager with event loop
|
||||
loop = asyncio.get_event_loop()
|
||||
task_manager_params = TaskManagerParams(loop=loop)
|
||||
task_manager.setup(task_manager_params)
|
||||
|
||||
setup = FrameProcessorSetup(clock=clock, task_manager=task_manager)
|
||||
|
||||
# Setup both input and output transports
|
||||
await input_transport.setup(setup)
|
||||
await output_transport.setup(setup)
|
||||
|
||||
# Start both transports
|
||||
start_frame = StartFrame()
|
||||
await input_transport.start(start_frame)
|
||||
await output_transport.start(start_frame)
|
||||
|
||||
# Verify leave counter is 2
|
||||
assert transport._client._leave_counter == 2
|
||||
|
||||
# Stop input transport
|
||||
end_frame = EndFrame()
|
||||
await input_transport.stop(end_frame)
|
||||
|
||||
# Verify websocket not closed yet
|
||||
mock_websocket.close.assert_not_called()
|
||||
|
||||
# Stop output transport
|
||||
await output_transport.stop(end_frame)
|
||||
|
||||
# Now websocket should be closed
|
||||
mock_websocket.close.assert_called_once()
|
||||
|
|
@ -1,99 +0,0 @@
|
|||
import unittest
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
FunctionCallInProgressFrame,
|
||||
LLMFullResponseStartFrame,
|
||||
)
|
||||
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
|
||||
from pipecat.services.google.llm import (
|
||||
GoogleAssistantContextAggregator,
|
||||
GoogleLLMContext,
|
||||
)
|
||||
from pipecat.services.openai.llm import OpenAIAssistantContextAggregator
|
||||
|
||||
|
||||
class TestReorderOpenAIAssistantContextAggregator(unittest.IsolatedAsyncioTestCase):
|
||||
async def test_reorder_function_messages_openai(self):
|
||||
"""Ensure that after a text aggregation the function-call messages are moved
|
||||
to appear immediately after the text response, maintaining chronological
|
||||
order (assistant text -> function call -> tool response).
|
||||
"""
|
||||
|
||||
context = OpenAILLMContext()
|
||||
aggregator = OpenAIAssistantContextAggregator(context)
|
||||
|
||||
# Simulate the start of an LLM response so that the aggregator creates a
|
||||
# response session ID that is later used for re-ordering.
|
||||
await aggregator._handle_llm_start(LLMFullResponseStartFrame())
|
||||
|
||||
# Simulate the model emitting a function call which the aggregator will
|
||||
# record for potential re-ordering.
|
||||
await aggregator._handle_function_call_in_progress(
|
||||
FunctionCallInProgressFrame(
|
||||
function_name="get_weather",
|
||||
tool_call_id="1",
|
||||
arguments={},
|
||||
)
|
||||
)
|
||||
|
||||
# Now push the textual part of the assistant response. This should
|
||||
# trigger the re-ordering so that the two function-related messages
|
||||
# appear *after* this text.
|
||||
await aggregator.handle_aggregation("Hello!")
|
||||
|
||||
messages = context.get_messages()
|
||||
|
||||
# We expect exactly three messages after re-ordering.
|
||||
self.assertEqual(len(messages), 3)
|
||||
|
||||
# 1. Assistant text
|
||||
self.assertEqual(messages[0]["role"], "assistant")
|
||||
self.assertEqual(messages[0]["content"], "Hello!")
|
||||
|
||||
# 2. Assistant function-call message
|
||||
self.assertEqual(messages[1]["role"], "assistant")
|
||||
self.assertIn("tool_calls", messages[1])
|
||||
|
||||
# 3. Tool response
|
||||
self.assertEqual(messages[2]["role"], "tool")
|
||||
self.assertEqual(messages[2]["tool_call_id"], "1")
|
||||
|
||||
|
||||
class TestReorderGoogleAssistantContextAggregator(unittest.IsolatedAsyncioTestCase):
|
||||
async def test_reorder_function_messages_google(self):
|
||||
context = GoogleLLMContext()
|
||||
aggregator = GoogleAssistantContextAggregator(context)
|
||||
|
||||
# Start an LLM response session.
|
||||
await aggregator._handle_llm_start(LLMFullResponseStartFrame())
|
||||
|
||||
# Emit a function call.
|
||||
await aggregator._handle_function_call_in_progress(
|
||||
FunctionCallInProgressFrame(
|
||||
function_name="get_weather",
|
||||
tool_call_id="1",
|
||||
arguments={},
|
||||
)
|
||||
)
|
||||
|
||||
# Push the textual content.
|
||||
await aggregator.handle_aggregation("Hello!")
|
||||
|
||||
messages = context.messages # Google context stores Content objects.
|
||||
|
||||
self.assertEqual(len(messages), 3)
|
||||
|
||||
# The first message should be the model text.
|
||||
first_msg = messages[0].to_json_dict()
|
||||
self.assertEqual(first_msg["role"], "model")
|
||||
self.assertEqual(first_msg["parts"][0]["text"], "Hello!")
|
||||
|
||||
# The second message contains the function call (also from the model).
|
||||
second_msg = messages[1].to_json_dict()
|
||||
self.assertEqual(second_msg["role"], "model")
|
||||
self.assertIn("function_call", second_msg["parts"][0])
|
||||
|
||||
# The third message is the placeholder function response.
|
||||
third_msg = messages[2].to_json_dict()
|
||||
self.assertEqual(third_msg["role"], "user")
|
||||
self.assertIn("function_response", third_msg["parts"][0])
|
||||
|
|
@ -1,506 +0,0 @@
|
|||
"""
|
||||
Tests for LoopTalk API routes and orchestration.
|
||||
|
||||
This module tests the LoopTalk testing functionality including test session creation,
|
||||
pipeline orchestration, and agent-to-agent communication.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from fastapi import status
|
||||
|
||||
from api.db.db_client import DBClient
|
||||
from api.services.looptalk.orchestrator import LoopTalkTestOrchestrator
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def actor_workflow_definition():
|
||||
"""Sample actor workflow definition for testing."""
|
||||
return {
|
||||
"nodes": [
|
||||
{
|
||||
"id": "1",
|
||||
"type": "startCall",
|
||||
"position": {"x": 0, "y": 0},
|
||||
"data": {
|
||||
"prompt": "Hello, I'm the actor agent.",
|
||||
"is_static": True,
|
||||
"name": "Start Call",
|
||||
"is_start": True,
|
||||
"allow_interrupt": False,
|
||||
},
|
||||
},
|
||||
{
|
||||
"id": "2",
|
||||
"type": "agentNode",
|
||||
"position": {"x": 100, "y": 0},
|
||||
"data": {
|
||||
"prompt": "You are an actor agent testing the adversary. Ask probing questions.",
|
||||
"name": "Actor Agent",
|
||||
"allow_interrupt": True,
|
||||
},
|
||||
},
|
||||
{
|
||||
"id": "3",
|
||||
"type": "endCall",
|
||||
"position": {"x": 200, "y": 0},
|
||||
"data": {
|
||||
"prompt": "Goodbye!",
|
||||
"name": "End Call",
|
||||
"is_end": True,
|
||||
},
|
||||
},
|
||||
],
|
||||
"edges": [
|
||||
{
|
||||
"id": "e1",
|
||||
"source": "1",
|
||||
"target": "2",
|
||||
"data": {"label": "Continue", "condition": "Always"},
|
||||
},
|
||||
{
|
||||
"id": "e2",
|
||||
"source": "2",
|
||||
"target": "3",
|
||||
"data": {"label": "End", "condition": "Always"},
|
||||
},
|
||||
],
|
||||
"stt": {"provider": "openai", "api_key": "test-key", "model": "whisper-1"},
|
||||
"llm": {"provider": "openai", "api_key": "test-key", "model": "gpt-4o-mini"},
|
||||
"tts": {
|
||||
"provider": "openai",
|
||||
"api_key": "test-key",
|
||||
"model": "tts-1",
|
||||
"voice": "nova",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def adversary_workflow_definition():
|
||||
"""Sample adversary workflow definition for testing."""
|
||||
return {
|
||||
"nodes": [
|
||||
{
|
||||
"id": "1",
|
||||
"type": "startCall",
|
||||
"position": {"x": 0, "y": 0},
|
||||
"data": {
|
||||
"prompt": "Hello, I'm the adversary agent.",
|
||||
"is_static": True,
|
||||
"name": "Start Call",
|
||||
"is_start": True,
|
||||
"allow_interrupt": False,
|
||||
},
|
||||
},
|
||||
{
|
||||
"id": "2",
|
||||
"type": "agentNode",
|
||||
"position": {"x": 100, "y": 0},
|
||||
"data": {
|
||||
"prompt": "You are an adversary agent being tested. Respond defensively.",
|
||||
"name": "Adversary Agent",
|
||||
"allow_interrupt": True,
|
||||
},
|
||||
},
|
||||
{
|
||||
"id": "3",
|
||||
"type": "endCall",
|
||||
"position": {"x": 200, "y": 0},
|
||||
"data": {
|
||||
"prompt": "Goodbye!",
|
||||
"name": "End Call",
|
||||
"is_end": True,
|
||||
},
|
||||
},
|
||||
],
|
||||
"edges": [
|
||||
{
|
||||
"id": "e1",
|
||||
"source": "1",
|
||||
"target": "2",
|
||||
"data": {"label": "Continue", "condition": "Always"},
|
||||
},
|
||||
{
|
||||
"id": "e2",
|
||||
"source": "2",
|
||||
"target": "3",
|
||||
"data": {"label": "End", "condition": "Always"},
|
||||
},
|
||||
],
|
||||
"stt": {"provider": "deepgram", "api_key": "test-key", "model": "nova-2"},
|
||||
"llm": {
|
||||
"provider": "groq",
|
||||
"api_key": "test-key",
|
||||
"model": "llama-3.1-70b-versatile",
|
||||
},
|
||||
"tts": {"provider": "deepgram", "api_key": "test-key", "voice": "nova-2"},
|
||||
}
|
||||
|
||||
|
||||
from pipecat.processors.frame_processor import FrameProcessor
|
||||
|
||||
|
||||
class MockSTTService(FrameProcessor):
|
||||
"""Mock STT service for testing."""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
async def run_stt(self, audio: bytes) -> str:
|
||||
return "Mock transcription"
|
||||
|
||||
|
||||
class MockLLMService(FrameProcessor):
|
||||
"""Mock LLM service for testing."""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
async def run_llm(self, messages) -> str:
|
||||
return "Mock LLM response"
|
||||
|
||||
def create_context_aggregator(self, context):
|
||||
"""Mock context aggregator creation."""
|
||||
return MagicMock()
|
||||
|
||||
|
||||
class MockTTSService(FrameProcessor):
|
||||
"""Mock TTS service for testing."""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
async def run_tts(self, text: str) -> bytes:
|
||||
return b"Mock audio data"
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def test_user_with_org(db_session):
|
||||
"""Create a test user with an organization set up."""
|
||||
user = await db_session.get_or_create_user_by_provider_id("test_looptalk_user")
|
||||
org, _ = await db_session.get_or_create_organization_by_provider_id(
|
||||
"test_looptalk_org"
|
||||
)
|
||||
|
||||
user_id = user.id
|
||||
org_id = org.id
|
||||
|
||||
await db_session.add_user_to_organization(user_id, org_id)
|
||||
|
||||
# Update user's selected organization
|
||||
async with db_session.async_session() as session:
|
||||
from sqlalchemy import update
|
||||
|
||||
from api.db.models import UserModel
|
||||
|
||||
await session.execute(
|
||||
update(UserModel)
|
||||
.where(UserModel.id == user_id)
|
||||
.values(selected_organization_id=org_id)
|
||||
)
|
||||
await session.commit()
|
||||
|
||||
# Return fresh user object
|
||||
return await db_session.get_user_by_id(user_id)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_test_session(
|
||||
test_client_factory,
|
||||
db_session,
|
||||
test_user_with_org,
|
||||
actor_workflow_definition,
|
||||
adversary_workflow_definition,
|
||||
):
|
||||
"""Test creating a new LoopTalk test session."""
|
||||
async with test_client_factory(test_user_with_org) as test_client:
|
||||
# First create two workflows
|
||||
actor_workflow_response = await test_client.post(
|
||||
"/api/v1/workflow/create",
|
||||
json={
|
||||
"name": "Actor Workflow",
|
||||
"workflow_definition": actor_workflow_definition,
|
||||
},
|
||||
)
|
||||
assert actor_workflow_response.status_code == status.HTTP_200_OK
|
||||
actor_workflow_id = actor_workflow_response.json()["id"]
|
||||
|
||||
adversary_workflow_response = await test_client.post(
|
||||
"/api/v1/workflow/create",
|
||||
json={
|
||||
"name": "Adversary Workflow",
|
||||
"workflow_definition": adversary_workflow_definition,
|
||||
},
|
||||
)
|
||||
assert adversary_workflow_response.status_code == status.HTTP_200_OK
|
||||
adversary_workflow_id = adversary_workflow_response.json()["id"]
|
||||
|
||||
# Create test session
|
||||
response = await test_client.post(
|
||||
"/api/v1/looptalk/test-sessions",
|
||||
json={
|
||||
"name": "Test Session 1",
|
||||
"actor_workflow_id": actor_workflow_id,
|
||||
"adversary_workflow_id": adversary_workflow_id,
|
||||
"config": {"test_duration": 60},
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
assert data["name"] == "Test Session 1"
|
||||
assert data["status"] == "pending"
|
||||
assert data["actor_workflow_id"] == actor_workflow_id
|
||||
assert data["adversary_workflow_id"] == adversary_workflow_id
|
||||
assert data["config"]["test_duration"] == 60
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_test_sessions(test_client_factory, db_session, test_user_with_org):
|
||||
"""Test listing LoopTalk test sessions."""
|
||||
async with test_client_factory(test_user_with_org) as test_client:
|
||||
response = await test_client.get(
|
||||
"/api/v1/looptalk/test-sessions",
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
assert isinstance(data, list)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_looptalk_orchestrator_plumbing(
|
||||
db_session: DBClient, actor_workflow_definition, adversary_workflow_definition
|
||||
):
|
||||
"""Test the LoopTalk orchestrator plumbing with mocked services."""
|
||||
|
||||
# Create test user and organization
|
||||
user = await db_session.get_or_create_user_by_provider_id(
|
||||
provider_id="test-user-123"
|
||||
)
|
||||
org, _ = await db_session.get_or_create_organization_by_provider_id(
|
||||
org_provider_id="test-org-123"
|
||||
)
|
||||
|
||||
# Get IDs before session closes
|
||||
user_id = user.id
|
||||
org_id = org.id
|
||||
|
||||
await db_session.add_user_to_organization(user_id, org_id)
|
||||
|
||||
# Update user's selected organization manually
|
||||
async with db_session.async_session() as session:
|
||||
from sqlalchemy import update
|
||||
|
||||
from api.db.models import UserModel
|
||||
|
||||
await session.execute(
|
||||
update(UserModel)
|
||||
.where(UserModel.id == user_id)
|
||||
.values(selected_organization_id=org_id)
|
||||
)
|
||||
await session.commit()
|
||||
|
||||
actor_workflow = await db_session.create_workflow(
|
||||
name="Actor Workflow",
|
||||
workflow_definition=actor_workflow_definition,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
adversary_workflow = await db_session.create_workflow(
|
||||
name="Adversary Workflow",
|
||||
workflow_definition=adversary_workflow_definition,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
# Create test session
|
||||
test_session = await db_session.create_test_session(
|
||||
organization_id=org_id,
|
||||
name="Test Session",
|
||||
actor_workflow_id=actor_workflow.id,
|
||||
adversary_workflow_id=adversary_workflow.id,
|
||||
config={"test_duration": 10},
|
||||
)
|
||||
|
||||
# Mock the service factories - patch at the actual import location in pipeline_builder
|
||||
with (
|
||||
patch(
|
||||
"api.services.looptalk.core.pipeline_builder.create_stt_service"
|
||||
) as mock_stt_factory,
|
||||
patch(
|
||||
"api.services.looptalk.core.pipeline_builder.create_llm_service"
|
||||
) as mock_llm_factory,
|
||||
patch(
|
||||
"api.services.looptalk.core.pipeline_builder.create_tts_service"
|
||||
) as mock_tts_factory,
|
||||
patch(
|
||||
"api.services.workflow.pipecat_engine.PipecatEngine"
|
||||
) as mock_engine_class,
|
||||
patch(
|
||||
"api.services.pipecat.pipeline_builder.build_pipeline"
|
||||
) as mock_build_pipeline,
|
||||
patch("api.services.pipecat.pipeline_builder.PipelineTask") as mock_task_class,
|
||||
):
|
||||
# Configure mocks
|
||||
mock_stt_factory.return_value = MockSTTService()
|
||||
mock_llm_factory.return_value = MockLLMService()
|
||||
mock_tts_factory.return_value = MockTTSService()
|
||||
|
||||
mock_engine = MagicMock()
|
||||
mock_engine.initialize = AsyncMock()
|
||||
mock_engine.get_callback_processor = MagicMock(return_value=MagicMock())
|
||||
mock_engine_class.return_value = mock_engine
|
||||
|
||||
# Mock pipeline and task
|
||||
mock_pipeline = MagicMock()
|
||||
mock_task = MagicMock()
|
||||
mock_task.run = AsyncMock()
|
||||
mock_task.cancel = AsyncMock() # Make cancel async
|
||||
mock_build_pipeline.return_value = mock_pipeline
|
||||
mock_task_class.return_value = mock_task
|
||||
|
||||
# Create orchestrator
|
||||
orchestrator = LoopTalkTestOrchestrator(db_client=db_session)
|
||||
|
||||
# Start test session (in a separate task to avoid blocking)
|
||||
start_task = asyncio.create_task(
|
||||
orchestrator.start_test_session(
|
||||
test_session_id=test_session.id, organization_id=org_id
|
||||
)
|
||||
)
|
||||
|
||||
# Give it a moment to start
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
# Verify the session is running through session manager
|
||||
session_info = orchestrator.session_manager.get_session(test_session.id)
|
||||
assert session_info is not None
|
||||
assert session_info["test_session"].id == test_session.id
|
||||
assert "actor_task" in session_info
|
||||
assert "adversary_task" in session_info
|
||||
|
||||
# Verify service factories were called
|
||||
assert mock_stt_factory.call_count == 2 # Once for each agent
|
||||
assert mock_llm_factory.call_count == 2
|
||||
assert mock_tts_factory.call_count == 2
|
||||
|
||||
# Verify pipelines were created with PipelineTask
|
||||
assert mock_task_class.call_count == 2
|
||||
|
||||
# Stop the test session
|
||||
await orchestrator.stop_test_session(test_session_id=test_session.id)
|
||||
|
||||
# Verify session was cleaned up
|
||||
assert orchestrator.session_manager.get_session(test_session.id) is None
|
||||
|
||||
# Cancel the start task
|
||||
start_task.cancel()
|
||||
try:
|
||||
await start_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_load_test_creation(
|
||||
test_client_factory,
|
||||
db_session,
|
||||
test_user_with_org,
|
||||
actor_workflow_definition,
|
||||
adversary_workflow_definition,
|
||||
):
|
||||
"""Test creating a load test with multiple sessions."""
|
||||
async with test_client_factory(test_user_with_org) as test_client:
|
||||
# First create two workflows
|
||||
actor_workflow_response = await test_client.post(
|
||||
"/api/v1/workflow/create",
|
||||
json={
|
||||
"name": "Actor Workflow",
|
||||
"workflow_definition": actor_workflow_definition,
|
||||
},
|
||||
)
|
||||
actor_workflow_id = actor_workflow_response.json()["id"]
|
||||
|
||||
adversary_workflow_response = await test_client.post(
|
||||
"/api/v1/workflow/create",
|
||||
json={
|
||||
"name": "Adversary Workflow",
|
||||
"workflow_definition": adversary_workflow_definition,
|
||||
},
|
||||
)
|
||||
adversary_workflow_id = adversary_workflow_response.json()["id"]
|
||||
|
||||
# Create load test
|
||||
response = await test_client.post(
|
||||
"/api/v1/looptalk/load-tests",
|
||||
json={
|
||||
"name_prefix": "Load Test",
|
||||
"actor_workflow_id": actor_workflow_id,
|
||||
"adversary_workflow_id": adversary_workflow_id,
|
||||
"test_count": 3,
|
||||
"config": {"test_duration": 30},
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
assert data["total"] == 3
|
||||
assert "load_test_group_id" in data
|
||||
assert len(data["test_session_ids"]) == 3
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invalid_workflow_ids(
|
||||
test_client_factory, db_session, test_user_with_org
|
||||
):
|
||||
"""Test creating test session with invalid workflow IDs."""
|
||||
async with test_client_factory(test_user_with_org) as test_client:
|
||||
response = await test_client.post(
|
||||
"/api/v1/looptalk/test-sessions",
|
||||
json={
|
||||
"name": "Invalid Test",
|
||||
"actor_workflow_id": 99999,
|
||||
"adversary_workflow_id": 99999,
|
||||
"config": {},
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||
assert "workflow not found" in response.json()["detail"].lower()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_transport_manager():
|
||||
"""Test the internal transport manager functionality."""
|
||||
from pipecat.transports import InternalTransportManager, TransportParams
|
||||
|
||||
manager = InternalTransportManager()
|
||||
|
||||
# Create transport pair
|
||||
params = TransportParams(
|
||||
audio_out_enabled=True,
|
||||
audio_in_enabled=True,
|
||||
audio_out_sample_rate=16000,
|
||||
audio_in_sample_rate=16000,
|
||||
)
|
||||
|
||||
actor_transport, adversary_transport = manager.create_transport_pair(
|
||||
test_session_id="test-123", actor_params=params, adversary_params=params
|
||||
)
|
||||
|
||||
# Verify transports are connected
|
||||
assert actor_transport._output._partner == adversary_transport._input
|
||||
assert adversary_transport._output._partner == actor_transport._input
|
||||
|
||||
# Verify transport pair is tracked
|
||||
assert manager.get_active_test_count() == 1
|
||||
assert manager.get_transport_pair("test-123") is not None
|
||||
|
||||
# Remove transport pair
|
||||
manager.remove_transport_pair("test-123")
|
||||
assert manager.get_active_test_count() == 0
|
||||
assert manager.get_transport_pair("test-123") is None
|
||||
|
|
@ -1,142 +0,0 @@
|
|||
### - The test gets stuck. Need to figure out a way to run the test
|
||||
|
||||
# import asyncio
|
||||
# import unittest
|
||||
|
||||
# from loguru import logger
|
||||
|
||||
# from pipecat.frames.frames import (
|
||||
# FunctionCallFromLLM,
|
||||
# FunctionCallInProgressFrame,
|
||||
# FunctionCallResultFrame,
|
||||
# FunctionCallsStartedFrame,
|
||||
# LLMFullResponseEndFrame,
|
||||
# LLMFullResponseStartFrame,
|
||||
# LLMTextFrame,
|
||||
# )
|
||||
# from pipecat.processors.aggregators.openai_llm_context import (
|
||||
# OpenAILLMContext,
|
||||
# OpenAILLMContextFrame,
|
||||
# )
|
||||
# from pipecat.processors.frame_processor import FrameDirection
|
||||
# from pipecat.services.llm_service import (
|
||||
# FunctionCallParams,
|
||||
# FunctionCallResultProperties,
|
||||
# LLMService,
|
||||
# )
|
||||
# from pipecat.tests.utils import run_test
|
||||
|
||||
|
||||
# class MockLLMService(LLMService):
|
||||
# """A very small mocked LLM service that, upon receiving an
|
||||
# ``OpenAILLMContextFrame``, streams a text completion followed by the
|
||||
# execution of the supplied tools (function calls).
|
||||
# """
|
||||
|
||||
# def __init__(self, *, content: str, tools: list[dict[str, dict]], **kwargs):
|
||||
# # Run function calls sequentially so that frame ordering is deterministic.
|
||||
# super().__init__(run_in_parallel=False, **kwargs)
|
||||
# self._content = content
|
||||
# self._tools = tools
|
||||
|
||||
# async def process_frame(self, frame, direction: FrameDirection):
|
||||
# await super().process_frame(frame, direction)
|
||||
|
||||
# if isinstance(frame, OpenAILLMContextFrame) and direction == FrameDirection.DOWNSTREAM:
|
||||
# # Simulate the start of a streamed completion.
|
||||
# await self.push_frame(LLMFullResponseStartFrame())
|
||||
# await self.push_frame(LLMTextFrame(self._content))
|
||||
|
||||
# # Convert tool specs into FunctionCallFromLLM objects.
|
||||
# function_calls = []
|
||||
# for idx, tool in enumerate(self._tools):
|
||||
# function_calls.append(
|
||||
# FunctionCallFromLLM(
|
||||
# function_name=tool["function_name"],
|
||||
# tool_call_id=f"tool_{idx}",
|
||||
# arguments=tool.get("arguments", {}),
|
||||
# context=frame.context,
|
||||
# )
|
||||
# )
|
||||
|
||||
# # Ask the LLM service base class to execute the calls.
|
||||
# await self.run_function_calls(function_calls)
|
||||
|
||||
# # Finish the streamed response.
|
||||
# await self.push_frame(LLMFullResponseEndFrame())
|
||||
|
||||
# async def _run_function_call(self, runner_item): # type: ignore[override] – narrow signature
|
||||
# # Ensure run_llm=True so that downstream processors know they can
|
||||
# # immediately trigger another LLM call after the result is committed.
|
||||
# runner_item.run_llm = True
|
||||
# await super()._run_function_call(runner_item)
|
||||
|
||||
|
||||
# class TestMockLLMPipeline(unittest.IsolatedAsyncioTestCase):
|
||||
# async def test_mock_llm_pipeline_with_tools(self):
|
||||
# # ------------------------------------------------------------------
|
||||
# # 1. Create mocked LLM service with completion text and tools
|
||||
# # ------------------------------------------------------------------
|
||||
# completion_text = "Hello from mocked LLM!"
|
||||
# tools = [
|
||||
# {"function_name": "tool_one", "arguments": {"a": 1}},
|
||||
# {"function_name": "tool_two", "arguments": {"b": 2}},
|
||||
# ]
|
||||
# llm = MockLLMService(content=completion_text, tools=tools)
|
||||
|
||||
# # ------------------------------------------------------------------
|
||||
# # 2. Register the tool functions – they simply log & sleep briefly.
|
||||
# # Each of them marks that it has run so that we can assert later.
|
||||
# # ------------------------------------------------------------------
|
||||
# executed: dict[str, bool] = {t["function_name"]: False for t in tools}
|
||||
|
||||
# def make_handler(name: str):
|
||||
# async def _handler(params: FunctionCallParams):
|
||||
# logger.debug(f"Executing {name} with args {params.arguments}")
|
||||
# executed[name] = True
|
||||
# await asyncio.sleep(0.01)
|
||||
# await params.result_callback(
|
||||
# {"status": "ok"},
|
||||
# properties=FunctionCallResultProperties(run_llm=True),
|
||||
# )
|
||||
|
||||
# return _handler
|
||||
|
||||
# for t in tools:
|
||||
# llm.register_function(t["function_name"], make_handler(t["function_name"]))
|
||||
|
||||
# # ------------------------------------------------------------------
|
||||
# # 3. Build the pipeline and send the initial context frame that
|
||||
# # triggers the completion.
|
||||
# # ------------------------------------------------------------------
|
||||
# context = OpenAILLMContext()
|
||||
# context.add_message({"role": "user", "content": "Hi!"})
|
||||
# frames_to_send = [OpenAILLMContextFrame(context)]
|
||||
|
||||
# expected_down_frames = [
|
||||
# LLMFullResponseStartFrame,
|
||||
# LLMTextFrame,
|
||||
# FunctionCallsStartedFrame,
|
||||
# FunctionCallInProgressFrame,
|
||||
# FunctionCallResultFrame,
|
||||
# FunctionCallInProgressFrame,
|
||||
# FunctionCallResultFrame,
|
||||
# LLMFullResponseEndFrame,
|
||||
# ]
|
||||
|
||||
# # Run the test pipeline.
|
||||
# received_down_frames, _ = await run_test(
|
||||
# llm,
|
||||
# frames_to_send=frames_to_send,
|
||||
# expected_down_frames=expected_down_frames,
|
||||
# )
|
||||
|
||||
# # ------------------------------------------------------------------
|
||||
# # 4. Verify that both tool functions executed and that run_llm=True
|
||||
# # in all FunctionCallResultFrame instances.
|
||||
# # ------------------------------------------------------------------
|
||||
# self.assertTrue(all(executed.values()))
|
||||
|
||||
# for frame in received_down_frames:
|
||||
# if isinstance(frame, FunctionCallResultFrame):
|
||||
# self.assertTrue(frame.run_llm)
|
||||
|
|
@ -1,236 +0,0 @@
|
|||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from api.services.workflow.pipecat_engine import PipecatEngine
|
||||
|
||||
|
||||
def create_disposition_mapping_side_effect(mapping_dict):
|
||||
"""Helper to create a side effect function for disposition mapping."""
|
||||
|
||||
async def side_effect(value, org_id):
|
||||
return mapping_dict.get(value, value)
|
||||
|
||||
return side_effect
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_dependencies():
|
||||
"""Create mock dependencies for PipecatEngine."""
|
||||
mock_task = MagicMock()
|
||||
mock_task.queue_frame = AsyncMock()
|
||||
|
||||
mock_llm = MagicMock()
|
||||
mock_context = MagicMock()
|
||||
mock_workflow = MagicMock()
|
||||
|
||||
return {
|
||||
"task": mock_task,
|
||||
"llm": mock_llm,
|
||||
"context": mock_context,
|
||||
"workflow": mock_workflow,
|
||||
"call_context_vars": {},
|
||||
"workflow_run_id": 123,
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_apply_disposition_mapping_with_call_disposition(mock_dependencies):
|
||||
"""Test disposition mapping when call_disposition is present."""
|
||||
engine = PipecatEngine(**mock_dependencies)
|
||||
|
||||
# Setup gathered context
|
||||
engine._gathered_context = {
|
||||
"call_disposition": "XFER",
|
||||
"agent_name": "Alex",
|
||||
"total_debt": "$15000",
|
||||
}
|
||||
|
||||
# Mock the disposition mapper functions
|
||||
with patch(
|
||||
"api.services.workflow.pipecat_engine.get_organization_id_from_workflow_run"
|
||||
) as mock_get_org_id:
|
||||
with patch(
|
||||
"api.services.workflow.pipecat_engine.apply_disposition_mapping"
|
||||
) as mock_apply_mapping:
|
||||
# Mock organization ID
|
||||
mock_get_org_id.return_value = 1
|
||||
|
||||
# Mock disposition mapping
|
||||
mock_apply_mapping.side_effect = create_disposition_mapping_side_effect(
|
||||
{
|
||||
"XFER": "TRANSFERRED",
|
||||
"ND": "NOT_QUALIFIED",
|
||||
}
|
||||
)
|
||||
|
||||
# Call send_end_task_frame
|
||||
await engine.send_end_task_frame(reason="user_qualified")
|
||||
|
||||
# Verify the frame was queued with mapped values
|
||||
mock_dependencies["task"].queue_frame.assert_called_once()
|
||||
frame = mock_dependencies["task"].queue_frame.call_args[0][0]
|
||||
|
||||
# Check metadata contains mapped values
|
||||
assert frame.metadata["reason"] == "user_qualified" # No mapping for this
|
||||
assert (
|
||||
frame.metadata["call_transfer_context"]["disposition"] == "TRANSFERRED"
|
||||
)
|
||||
|
||||
# Check gathered context was updated
|
||||
assert engine._gathered_context["call_disposition"] == "TRANSFERRED"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_apply_disposition_mapping_with_disconnect_reason(mock_dependencies):
|
||||
"""Test disposition mapping for disconnect_reason when no call_disposition exists."""
|
||||
engine = PipecatEngine(**mock_dependencies)
|
||||
|
||||
# Setup gathered context without call_disposition
|
||||
engine._gathered_context = {
|
||||
"agent_name": "Alex",
|
||||
}
|
||||
|
||||
# Mock the disposition mapper functions
|
||||
with patch(
|
||||
"api.services.workflow.pipecat_engine.get_organization_id_from_workflow_run"
|
||||
) as mock_get_org_id:
|
||||
with patch(
|
||||
"api.services.workflow.pipecat_engine.apply_disposition_mapping"
|
||||
) as mock_apply_mapping:
|
||||
# Mock organization ID
|
||||
mock_get_org_id.return_value = 1
|
||||
|
||||
# Mock disposition mapping
|
||||
mock_apply_mapping.side_effect = create_disposition_mapping_side_effect(
|
||||
{
|
||||
"user_qualified": "QUALIFIED",
|
||||
"user_disqualified": "NOT_QUALIFIED",
|
||||
"user_hangup": "HANGUP",
|
||||
}
|
||||
)
|
||||
|
||||
# Call send_end_task_frame with a mappable reason
|
||||
await engine.send_end_task_frame(reason="user_qualified")
|
||||
|
||||
# Verify the frame was queued with mapped disposition
|
||||
mock_dependencies["task"].queue_frame.assert_called_once()
|
||||
frame = mock_dependencies["task"].queue_frame.call_args[0][0]
|
||||
|
||||
# Check metadata contains original reason
|
||||
assert frame.metadata["reason"] == "user_qualified"
|
||||
|
||||
# Check call_transfer_context has mapped disconnect_reason as disposition
|
||||
assert frame.metadata["call_transfer_context"]["disposition"] == "QUALIFIED"
|
||||
|
||||
# Check gathered context was updated with mapped call_disposition
|
||||
assert engine._gathered_context["call_disposition"] == "QUALIFIED"
|
||||
|
||||
# Check internal call_disposition stores mapped value
|
||||
assert engine._call_disposition == "QUALIFIED"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_call_disposition_takes_precedence(mock_dependencies):
|
||||
"""Test that call_disposition is used when both call_disposition and reason could be mapped."""
|
||||
engine = PipecatEngine(**mock_dependencies)
|
||||
|
||||
# Setup gathered context with call_disposition
|
||||
engine._gathered_context = {
|
||||
"call_disposition": "XFER",
|
||||
"agent_name": "Alex",
|
||||
}
|
||||
|
||||
# Mock the disposition mapper functions
|
||||
with patch(
|
||||
"api.services.workflow.pipecat_engine.get_organization_id_from_workflow_run"
|
||||
) as mock_get_org_id:
|
||||
with patch(
|
||||
"api.services.workflow.pipecat_engine.apply_disposition_mapping"
|
||||
) as mock_apply_mapping:
|
||||
# Mock organization ID
|
||||
mock_get_org_id.return_value = 1
|
||||
|
||||
# Mock disposition mapping
|
||||
mock_apply_mapping.side_effect = create_disposition_mapping_side_effect(
|
||||
{
|
||||
"XFER": "TRANSFERRED",
|
||||
"user_qualified": "QUALIFIED",
|
||||
}
|
||||
)
|
||||
|
||||
# Call send_end_task_frame with a reason that could also be mapped
|
||||
await engine.send_end_task_frame(reason="user_qualified")
|
||||
|
||||
# Verify the frame was queued
|
||||
mock_dependencies["task"].queue_frame.assert_called_once()
|
||||
frame = mock_dependencies["task"].queue_frame.call_args[0][0]
|
||||
|
||||
# Check that call_disposition mapping was used, not reason mapping
|
||||
assert (
|
||||
frame.metadata["call_transfer_context"]["disposition"] == "TRANSFERRED"
|
||||
)
|
||||
|
||||
# Check only call_disposition was updated in gathered context
|
||||
assert engine._gathered_context["call_disposition"] == "TRANSFERRED"
|
||||
assert "disconnect_reason" not in engine._gathered_context
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_disposition_mapping_no_organization_id(mock_dependencies):
|
||||
"""Test when organization_id cannot be retrieved."""
|
||||
# Set workflow_run_id to None
|
||||
mock_dependencies["workflow_run_id"] = None
|
||||
engine = PipecatEngine(**mock_dependencies)
|
||||
|
||||
engine._gathered_context = {
|
||||
"call_disposition": "XFER",
|
||||
}
|
||||
|
||||
# Call send_end_task_frame
|
||||
await engine.send_end_task_frame(reason="user_qualified")
|
||||
|
||||
# Verify the frame was queued with original values (no mapping)
|
||||
mock_dependencies["task"].queue_frame.assert_called_once()
|
||||
frame = mock_dependencies["task"].queue_frame.call_args[0][0]
|
||||
|
||||
# Check values remain unchanged
|
||||
assert frame.metadata["reason"] == "user_qualified"
|
||||
assert frame.metadata["call_transfer_context"]["disposition"] == "XFER"
|
||||
|
||||
# Gathered context should remain unchanged
|
||||
assert engine._gathered_context["call_disposition"] == "XFER"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_disposition_mapping_no_configuration(mock_dependencies):
|
||||
"""Test when no disposition mapping is configured."""
|
||||
engine = PipecatEngine(**mock_dependencies)
|
||||
|
||||
engine._gathered_context = {
|
||||
"call_disposition": "XFER",
|
||||
}
|
||||
|
||||
# Mock the disposition mapper functions
|
||||
with patch(
|
||||
"api.services.workflow.pipecat_engine.get_organization_id_from_workflow_run"
|
||||
) as mock_get_org_id:
|
||||
with patch(
|
||||
"api.services.workflow.pipecat_engine.apply_disposition_mapping"
|
||||
) as mock_apply_mapping:
|
||||
# Mock organization ID
|
||||
mock_get_org_id.return_value = 1
|
||||
|
||||
# Mock no disposition mapping (return original value)
|
||||
mock_apply_mapping.side_effect = lambda value, org_id: value
|
||||
|
||||
# Call send_end_task_frame
|
||||
await engine.send_end_task_frame(reason="user_qualified")
|
||||
|
||||
# Verify the frame was queued with original values
|
||||
mock_dependencies["task"].queue_frame.assert_called_once()
|
||||
frame = mock_dependencies["task"].queue_frame.call_args[0][0]
|
||||
|
||||
# Check values remain unchanged
|
||||
assert frame.metadata["reason"] == "user_qualified"
|
||||
assert frame.metadata["call_transfer_context"]["disposition"] == "XFER"
|
||||
|
|
@ -1,206 +0,0 @@
|
|||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
|
||||
from api.services.workflow.pipecat_engine import PipecatEngine
|
||||
from api.services.workflow.workflow import WorkflowGraph
|
||||
|
||||
|
||||
class TestPipecatEngine:
|
||||
@pytest.fixture
|
||||
def mock_dependencies(self):
|
||||
"""Create mock dependencies for PipecatEngine initialization."""
|
||||
return {
|
||||
"task": Mock(),
|
||||
"llm": Mock(),
|
||||
"context": Mock(),
|
||||
"tts": Mock(),
|
||||
"transport": Mock(),
|
||||
"workflow": Mock(spec=WorkflowGraph),
|
||||
"call_context_vars": {},
|
||||
}
|
||||
|
||||
@pytest.fixture
|
||||
def engine_with_context(self, mock_dependencies):
|
||||
"""Create a PipecatEngine instance with test context variables."""
|
||||
context_vars = {
|
||||
"first_name": "John",
|
||||
"last_name": "Doe",
|
||||
"age": 25,
|
||||
"email": "john.doe@example.com",
|
||||
"empty_var": "",
|
||||
"zero_var": 0,
|
||||
"false_var": False,
|
||||
}
|
||||
mock_dependencies["call_context_vars"] = context_vars
|
||||
return PipecatEngine(**mock_dependencies)
|
||||
|
||||
@pytest.fixture
|
||||
def engine_empty_context(self, mock_dependencies):
|
||||
"""Create a PipecatEngine instance with empty context variables."""
|
||||
mock_dependencies["call_context_vars"] = {}
|
||||
return PipecatEngine(**mock_dependencies)
|
||||
|
||||
def test_format_prompt_simple_variable_replacement(self, engine_with_context):
|
||||
"""Test simple variable replacement without filters."""
|
||||
prompt = "Hello {{ first_name }}, welcome!"
|
||||
result = engine_with_context._format_prompt(prompt)
|
||||
assert result == "Hello John, welcome!"
|
||||
|
||||
def test_format_prompt_multiple_variables(self, engine_with_context):
|
||||
"""Test multiple variable replacements in a single prompt."""
|
||||
prompt = "Hello {{ first_name }} {{ last_name }}, you are {{ age }} years old."
|
||||
result = engine_with_context._format_prompt(prompt)
|
||||
assert result == "Hello John Doe, you are 25 years old."
|
||||
|
||||
def test_format_prompt_with_fallback_existing_value(self, engine_with_context):
|
||||
"""Test fallback filter when value exists."""
|
||||
prompt = "Hello {{ first_name | fallback }}, nice to meet you!"
|
||||
result = engine_with_context._format_prompt(prompt)
|
||||
assert result == "Hello John, nice to meet you!"
|
||||
|
||||
def test_format_prompt_with_fallback_missing_value(self, engine_empty_context):
|
||||
"""Test fallback filter when value is missing."""
|
||||
prompt = "Hello {{ first_name | fallback }}, nice to meet you!"
|
||||
result = engine_empty_context._format_prompt(prompt)
|
||||
assert result == "Hello First_Name, nice to meet you!"
|
||||
|
||||
def test_format_prompt_with_custom_fallback_missing_value(
|
||||
self, engine_empty_context
|
||||
):
|
||||
"""Test fallback filter with custom fallback value when variable is missing."""
|
||||
prompt = "Hello {{ first_name | fallback:Guest }}, welcome!"
|
||||
result = engine_empty_context._format_prompt(prompt)
|
||||
assert result == "Hello Guest, welcome!"
|
||||
|
||||
def test_format_prompt_with_custom_fallback_existing_value(
|
||||
self, engine_with_context
|
||||
):
|
||||
"""Test fallback filter with custom fallback value when variable exists."""
|
||||
prompt = "Hello {{ first_name | fallback:Guest }}, welcome!"
|
||||
result = engine_with_context._format_prompt(prompt)
|
||||
assert result == "Hello John, welcome!"
|
||||
|
||||
def test_format_prompt_empty_string_variable(self, engine_with_context):
|
||||
"""Test variable with empty string value."""
|
||||
prompt = "Value: '{{ empty_var | fallback:No Value }}'"
|
||||
result = engine_with_context._format_prompt(prompt)
|
||||
assert result == "Value: 'No Value'"
|
||||
|
||||
def test_format_prompt_zero_value(self, engine_with_context):
|
||||
"""Test variable with zero value (should not trigger fallback)."""
|
||||
prompt = "Count: {{ zero_var | fallback:None }}"
|
||||
result = engine_with_context._format_prompt(prompt)
|
||||
assert result == "Count: 0"
|
||||
|
||||
def test_format_prompt_false_value(self, engine_with_context):
|
||||
"""Test variable with False value (should not trigger fallback)."""
|
||||
prompt = "Status: {{ false_var | fallback:Unknown }}"
|
||||
result = engine_with_context._format_prompt(prompt)
|
||||
assert result == "Status: False"
|
||||
|
||||
def test_format_prompt_missing_variable_no_fallback(self, engine_empty_context):
|
||||
"""Test missing variable without fallback filter."""
|
||||
prompt = "Hello {{ missing_var }}, welcome!"
|
||||
result = engine_empty_context._format_prompt(prompt)
|
||||
assert result == "Hello , welcome!"
|
||||
|
||||
def test_format_prompt_complex_mixed_scenario(self, engine_with_context):
|
||||
"""Test complex scenario with multiple variables, some with fallbacks."""
|
||||
prompt = (
|
||||
"Dear {{ first_name | fallback:Customer }}, "
|
||||
"your email {{ email }} is confirmed. "
|
||||
"{{ missing_info | fallback:Additional information }} will be sent later. "
|
||||
"You are {{ age }} years old."
|
||||
)
|
||||
result = engine_with_context._format_prompt(prompt)
|
||||
expected = (
|
||||
"Dear John, "
|
||||
"your email john.doe@example.com is confirmed. "
|
||||
"Additional information will be sent later. "
|
||||
"You are 25 years old."
|
||||
)
|
||||
assert result == expected
|
||||
|
||||
def test_format_prompt_whitespace_handling(self, engine_with_context):
|
||||
"""Test handling of whitespace in template variables."""
|
||||
prompt = "Hello {{ first_name | fallback : Default }}, welcome!"
|
||||
result = engine_with_context._format_prompt(prompt)
|
||||
assert result == "Hello John, welcome!"
|
||||
|
||||
def test_format_prompt_no_variables(self, engine_with_context):
|
||||
"""Test prompt with no template variables."""
|
||||
prompt = "This is a regular prompt with no variables."
|
||||
result = engine_with_context._format_prompt(prompt)
|
||||
assert result == "This is a regular prompt with no variables."
|
||||
|
||||
def test_format_prompt_empty_prompt(self, engine_with_context):
|
||||
"""Test empty prompt."""
|
||||
prompt = ""
|
||||
result = engine_with_context._format_prompt(prompt)
|
||||
assert result == ""
|
||||
|
||||
def test_format_prompt_none_prompt(self, engine_with_context):
|
||||
"""Test None prompt."""
|
||||
prompt = None
|
||||
result = engine_with_context._format_prompt(prompt)
|
||||
assert result is None
|
||||
|
||||
def test_format_prompt_nested_braces(self, engine_with_context):
|
||||
"""Test handling of nested or malformed braces."""
|
||||
prompt = "Hello {{ first_name }}, this {is not a template} variable."
|
||||
result = engine_with_context._format_prompt(prompt)
|
||||
assert result == "Hello John, this {is not a template} variable."
|
||||
|
||||
def test_format_prompt_special_characters_in_value(self):
|
||||
"""Test variables containing special characters."""
|
||||
mock_deps = {
|
||||
"task": Mock(),
|
||||
"llm": Mock(),
|
||||
"context": Mock(),
|
||||
"tts": Mock(),
|
||||
"transport": Mock(),
|
||||
"workflow": Mock(spec=WorkflowGraph),
|
||||
"call_context_vars": {
|
||||
"special_name": "John & Jane's Company",
|
||||
"email": "test@domain.com",
|
||||
},
|
||||
}
|
||||
engine = PipecatEngine(**mock_deps)
|
||||
|
||||
prompt = "Company: {{ special_name }}, Contact: {{ email }}"
|
||||
result = engine._format_prompt(prompt)
|
||||
assert result == "Company: John & Jane's Company, Contact: test@domain.com"
|
||||
|
||||
def test_format_prompt_numeric_and_boolean_conversion(self):
|
||||
"""Test conversion of different data types to strings."""
|
||||
mock_deps = {
|
||||
"task": Mock(),
|
||||
"llm": Mock(),
|
||||
"context": Mock(),
|
||||
"tts": Mock(),
|
||||
"transport": Mock(),
|
||||
"workflow": Mock(spec=WorkflowGraph),
|
||||
"call_context_vars": {
|
||||
"count": 42,
|
||||
"price": 99.99,
|
||||
"is_active": True,
|
||||
"items": ["apple", "banana"],
|
||||
},
|
||||
}
|
||||
engine = PipecatEngine(**mock_deps)
|
||||
|
||||
prompt = "Count: {{ count }}, Price: ${{ price }}, Active: {{ is_active }}, Items: {{ items }}"
|
||||
result = engine._format_prompt(prompt)
|
||||
assert (
|
||||
result
|
||||
== "Count: 42, Price: $99.99, Active: True, Items: ['apple', 'banana']"
|
||||
)
|
||||
|
||||
def test_format_prompt_case_sensitivity(self, engine_with_context):
|
||||
"""Test that variable names are case sensitive."""
|
||||
prompt = (
|
||||
"Hello {{ First_Name | fallback }}, welcome!" # Note the capitalization
|
||||
)
|
||||
result = engine_with_context._format_prompt(prompt)
|
||||
assert result == "Hello First_Name, welcome!" # Should use fallback
|
||||
|
|
@ -1,295 +0,0 @@
|
|||
"""
|
||||
Test scenarios for provider switching and billing integrity.
|
||||
This test suite validates that the multi-provider telephony system
|
||||
handles provider switches correctly without losing billing data.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
|
||||
# Test scenarios to validate
|
||||
|
||||
|
||||
async def test_scenario_1_mid_call_provider_switch():
|
||||
"""
|
||||
Test: What happens if provider is switched while a call is active?
|
||||
|
||||
Expected behavior:
|
||||
- Active call continues with original provider
|
||||
- Call is billed to original provider
|
||||
- New calls use new provider
|
||||
"""
|
||||
print("Test 1: Mid-call provider switching")
|
||||
|
||||
# Simulate workflow run with Twilio
|
||||
twilio_run = {
|
||||
"id": 1,
|
||||
"mode": "twilio",
|
||||
"cost_info": {"twilio_call_sid": "CA123456789", "provider": "twilio"},
|
||||
"is_completed": False,
|
||||
}
|
||||
|
||||
# Provider switch happens here (in real scenario, user changes config)
|
||||
# But the call continues...
|
||||
|
||||
# When cost calculation runs, it should:
|
||||
# 1. Use the provider stored in cost_info
|
||||
# 2. Fetch cost from Twilio using twilio_call_sid
|
||||
# 3. Store cost with provider attribution
|
||||
|
||||
result = {
|
||||
"test": "mid_call_switch",
|
||||
"status": "PASS",
|
||||
"reason": "Call continues with original provider, billing intact",
|
||||
}
|
||||
print(f" ✓ {result['reason']}")
|
||||
return result
|
||||
|
||||
|
||||
async def test_scenario_2_pending_cost_calculation():
|
||||
"""
|
||||
Test: Calls that ended but cost not yet calculated when provider switches.
|
||||
|
||||
Expected behavior:
|
||||
- Background job should use the provider info stored in cost_info
|
||||
- Cost should be fetched from correct provider
|
||||
"""
|
||||
print("\nTest 2: Pending cost calculation during switch")
|
||||
|
||||
# Workflow runs that ended but cost job hasn't run yet
|
||||
pending_runs = [
|
||||
{
|
||||
"id": 2,
|
||||
"mode": "twilio",
|
||||
"cost_info": {"twilio_call_sid": "CA987654321", "provider": "twilio"},
|
||||
"is_completed": True,
|
||||
},
|
||||
{
|
||||
"id": 3,
|
||||
"mode": "vonage",
|
||||
"cost_info": {"vonage_call_uuid": "uuid-123", "provider": "vonage"},
|
||||
"is_completed": True,
|
||||
},
|
||||
]
|
||||
|
||||
# Provider switch happens here
|
||||
# Cost calculation jobs run after switch
|
||||
|
||||
# Each job should:
|
||||
# 1. Check the provider field in cost_info
|
||||
# 2. Use appropriate provider API to fetch cost
|
||||
# 3. Handle gracefully if credentials changed
|
||||
|
||||
result = {
|
||||
"test": "pending_cost_calculation",
|
||||
"status": "PASS",
|
||||
"reason": "Cost jobs use stored provider info correctly",
|
||||
}
|
||||
print(f" ✓ {result['reason']}")
|
||||
return result
|
||||
|
||||
|
||||
async def test_scenario_3_mixed_provider_history():
|
||||
"""
|
||||
Test: Organization has calls from both Twilio and Vonage.
|
||||
|
||||
Expected behavior:
|
||||
- Historical costs remain intact
|
||||
- Reports show correct attribution
|
||||
- Total costs aggregate correctly
|
||||
"""
|
||||
print("\nTest 3: Mixed provider history")
|
||||
|
||||
historical_runs = [
|
||||
{"provider": "twilio", "cost_usd": 0.15, "date": "2024-01-01"},
|
||||
{"provider": "vonage", "cost_usd": 0.12, "date": "2024-01-02"},
|
||||
{"provider": "twilio", "cost_usd": 0.18, "date": "2024-01-03"},
|
||||
{"provider": "vonage", "cost_usd": 0.14, "date": "2024-01-04"},
|
||||
]
|
||||
|
||||
# Calculate totals
|
||||
total_cost = sum(run["cost_usd"] for run in historical_runs)
|
||||
twilio_cost = sum(
|
||||
run["cost_usd"] for run in historical_runs if run["provider"] == "twilio"
|
||||
)
|
||||
vonage_cost = sum(
|
||||
run["cost_usd"] for run in historical_runs if run["provider"] == "vonage"
|
||||
)
|
||||
|
||||
result = {
|
||||
"test": "mixed_provider_history",
|
||||
"status": "PASS",
|
||||
"total_cost": total_cost,
|
||||
"twilio_cost": twilio_cost,
|
||||
"vonage_cost": vonage_cost,
|
||||
"reason": f"Costs correctly aggregated: Total ${total_cost:.2f} (Twilio: ${twilio_cost:.2f}, Vonage: ${vonage_cost:.2f})",
|
||||
}
|
||||
print(f" ✓ {result['reason']}")
|
||||
return result
|
||||
|
||||
|
||||
async def test_scenario_4_cost_api_failure():
|
||||
"""
|
||||
Test: Provider API fails when fetching cost.
|
||||
|
||||
Expected behavior:
|
||||
- Error logged but system continues
|
||||
- Call record preserved
|
||||
- Cost marked as 0 or unknown
|
||||
"""
|
||||
print("\nTest 4: Cost API failure handling")
|
||||
|
||||
# Simulate API failure scenarios
|
||||
failure_scenarios = [
|
||||
{
|
||||
"provider": "twilio",
|
||||
"error": "401 Unauthorized - credentials changed",
|
||||
"expected": "Cost set to 0, error logged",
|
||||
},
|
||||
{
|
||||
"provider": "vonage",
|
||||
"error": "404 Not Found - call record deleted",
|
||||
"expected": "Cost set to 0, error logged",
|
||||
},
|
||||
{
|
||||
"provider": "twilio",
|
||||
"error": "500 Internal Server Error",
|
||||
"expected": "Cost set to 0, retry possible",
|
||||
},
|
||||
]
|
||||
|
||||
for scenario in failure_scenarios:
|
||||
print(f" - {scenario['provider']}: {scenario['error']}")
|
||||
print(f" Expected: {scenario['expected']}")
|
||||
|
||||
result = {
|
||||
"test": "cost_api_failure",
|
||||
"status": "PASS",
|
||||
"reason": "All failure scenarios handled gracefully",
|
||||
}
|
||||
print(f" ✓ {result['reason']}")
|
||||
return result
|
||||
|
||||
|
||||
async def test_scenario_5_configuration_migration():
|
||||
"""
|
||||
Test: Database migration from single to multi-provider format.
|
||||
|
||||
Expected behavior:
|
||||
- Old TWILIO_CONFIGURATION migrated to TELEPHONY_CONFIGURATION
|
||||
- Single provider config wrapped in multi-provider structure
|
||||
- Existing cost_info gets provider field added
|
||||
"""
|
||||
print("\nTest 5: Configuration migration")
|
||||
|
||||
# Old format
|
||||
old_config = {
|
||||
"account_sid": "AC123",
|
||||
"auth_token": "token123",
|
||||
"from_numbers": ["+1234567890"],
|
||||
"provider": "twilio",
|
||||
}
|
||||
|
||||
# New format after migration
|
||||
new_config = {
|
||||
"active_provider": "twilio",
|
||||
"providers": {
|
||||
"twilio": {
|
||||
"account_sid": "AC123",
|
||||
"auth_token": "token123",
|
||||
"from_numbers": ["+1234567890"],
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
# Validate migration
|
||||
assert new_config["active_provider"] == "twilio"
|
||||
assert "providers" in new_config
|
||||
assert new_config["providers"]["twilio"]["account_sid"] == old_config["account_sid"]
|
||||
|
||||
result = {
|
||||
"test": "configuration_migration",
|
||||
"status": "PASS",
|
||||
"reason": "Configuration migrated to multi-provider format correctly",
|
||||
}
|
||||
print(f" ✓ {result['reason']}")
|
||||
return result
|
||||
|
||||
|
||||
async def test_scenario_6_provider_cost_discrepancy():
|
||||
"""
|
||||
Test: Webhook cost vs API cost discrepancy.
|
||||
|
||||
Expected behavior:
|
||||
- Webhook cost stored immediately if available
|
||||
- API cost fetched later for verification
|
||||
- Both costs stored for auditing
|
||||
"""
|
||||
print("\nTest 6: Provider cost discrepancy handling")
|
||||
|
||||
# Vonage webhook provides immediate cost
|
||||
webhook_cost = {"vonage_webhook_price": 0.15, "vonage_webhook_duration": 120}
|
||||
|
||||
# API call provides authoritative cost
|
||||
api_cost = {
|
||||
"cost_usd": 0.14, # Slight difference
|
||||
"duration": 120,
|
||||
}
|
||||
|
||||
# Both should be stored
|
||||
final_cost_info = {
|
||||
**webhook_cost,
|
||||
"cost_breakdown": {"telephony_call": api_cost["cost_usd"]},
|
||||
"provider": "vonage",
|
||||
}
|
||||
|
||||
result = {
|
||||
"test": "cost_discrepancy",
|
||||
"status": "PASS",
|
||||
"reason": "Both webhook and API costs stored for auditing",
|
||||
}
|
||||
print(f" ✓ {result['reason']}")
|
||||
return result
|
||||
|
||||
|
||||
async def run_all_tests():
|
||||
"""Run all test scenarios."""
|
||||
print("=" * 60)
|
||||
print("PROVIDER SWITCHING TEST SUITE")
|
||||
print("=" * 60)
|
||||
|
||||
tests = [
|
||||
test_scenario_1_mid_call_provider_switch,
|
||||
test_scenario_2_pending_cost_calculation,
|
||||
test_scenario_3_mixed_provider_history,
|
||||
test_scenario_4_cost_api_failure,
|
||||
test_scenario_5_configuration_migration,
|
||||
test_scenario_6_provider_cost_discrepancy,
|
||||
]
|
||||
|
||||
results = []
|
||||
for test in tests:
|
||||
result = await test()
|
||||
results.append(result)
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("TEST SUMMARY")
|
||||
print("=" * 60)
|
||||
|
||||
passed = sum(1 for r in results if r["status"] == "PASS")
|
||||
failed = sum(1 for r in results if r["status"] == "FAIL")
|
||||
|
||||
print(f"Total Tests: {len(results)}")
|
||||
print(f"Passed: {passed}")
|
||||
print(f"Failed: {failed}")
|
||||
|
||||
if failed == 0:
|
||||
print("\n✅ ALL TESTS PASSED - Provider switching is working correctly!")
|
||||
else:
|
||||
print("\n❌ Some tests failed - Review the implementation")
|
||||
|
||||
return results
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Run the test suite
|
||||
asyncio.run(run_all_tests())
|
||||
|
|
@ -1,266 +0,0 @@
|
|||
"""Tests for run_integrations with new DB client methods."""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from api.enums import WorkflowRunMode
|
||||
from api.tasks.run_integrations import run_integrations_post_workflow_run
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_logger():
|
||||
"""Mock the logger for all tests."""
|
||||
with patch("api.tasks.run_integrations.logger") as mock_logger:
|
||||
mock_logger.bind.return_value = mock_logger
|
||||
yield mock_logger
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_workflow_run():
|
||||
"""Create a mock workflow run with all required attributes."""
|
||||
workflow_run = MagicMock()
|
||||
workflow_run.id = 1
|
||||
workflow_run.mode = "browser"
|
||||
workflow_run.gathered_context = {
|
||||
"call_disposition": "XFER",
|
||||
"mapped_call_disposition": "XFER", # Required for Slack integration
|
||||
"call_duration": "120",
|
||||
"agent_name": "TestAgent",
|
||||
}
|
||||
workflow_run.initial_context = {"vendor_id": "123"}
|
||||
|
||||
# Setup workflow and user chain
|
||||
workflow_run.workflow = MagicMock()
|
||||
workflow_run.workflow.user = MagicMock()
|
||||
workflow_run.workflow.user.selected_organization_id = 100
|
||||
|
||||
return workflow_run
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_integration():
|
||||
"""Create a mock integration."""
|
||||
integration = MagicMock()
|
||||
integration.id = 1
|
||||
integration.organisation_id = 100
|
||||
integration.provider = "slack"
|
||||
integration.is_active = True
|
||||
integration.connection_details = {
|
||||
"connection_config": {"incoming_webhook.url": "https://hooks.slack.com/test"}
|
||||
}
|
||||
return integration
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_integrations_with_db_client_methods(
|
||||
mock_workflow_run, mock_integration
|
||||
):
|
||||
"""Test that run_integrations uses the new DB client methods correctly."""
|
||||
|
||||
with patch("api.tasks.run_integrations.set_current_run_id") as mock_set_run_id:
|
||||
with patch("api.tasks.run_integrations.db_client") as mock_db_client:
|
||||
# Mock the new DB client methods
|
||||
mock_db_client.get_workflow_run_with_context = AsyncMock(
|
||||
return_value=(mock_workflow_run, 100)
|
||||
)
|
||||
mock_db_client.get_active_integrations_by_organization = AsyncMock(
|
||||
return_value=[mock_integration]
|
||||
)
|
||||
mock_db_client.get_configuration_value = AsyncMock(
|
||||
return_value={
|
||||
"slack": {
|
||||
"DISPOSITION_CODE": "Disposition: {{mapped_call_disposition}}"
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
# Mock the aiohttp session for Slack webhook
|
||||
with patch(
|
||||
"api.tasks.run_integrations.aiohttp.ClientSession"
|
||||
) as mock_session_class:
|
||||
mock_response = MagicMock()
|
||||
mock_response.status = 200
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_session.__aenter__.return_value = mock_session
|
||||
mock_session.__aexit__.return_value = AsyncMock()
|
||||
|
||||
mock_post = MagicMock()
|
||||
mock_post.__aenter__.return_value = mock_response
|
||||
mock_post.__aexit__.return_value = AsyncMock()
|
||||
|
||||
mock_session.post.return_value = mock_post
|
||||
mock_session_class.return_value = mock_session
|
||||
|
||||
# Call the function
|
||||
await run_integrations_post_workflow_run(None, 1)
|
||||
|
||||
# Verify the correct DB client methods were called
|
||||
mock_set_run_id.assert_called_once_with(1)
|
||||
mock_db_client.get_workflow_run_with_context.assert_called_once_with(1)
|
||||
mock_db_client.get_active_integrations_by_organization.assert_called_once_with(
|
||||
100
|
||||
)
|
||||
|
||||
# Verify the Slack webhook was called
|
||||
mock_session.post.assert_called_once()
|
||||
assert (
|
||||
mock_session.post.call_args[0][0] == "https://hooks.slack.com/test"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_integrations_no_workflow_run():
|
||||
"""Test handling when workflow run is not found."""
|
||||
|
||||
with patch("api.tasks.run_integrations.set_current_run_id"):
|
||||
with patch("api.tasks.run_integrations.db_client") as mock_db_client:
|
||||
# Mock workflow run not found
|
||||
mock_db_client.get_workflow_run_with_context = AsyncMock(
|
||||
return_value=(None, None)
|
||||
)
|
||||
|
||||
# Call the function
|
||||
await run_integrations_post_workflow_run(None, 999)
|
||||
|
||||
# Verify it returns early and doesn't call other DB methods
|
||||
mock_db_client.get_workflow_run_with_context.assert_called_once_with(999)
|
||||
mock_db_client.get_active_integrations_by_organization.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_integrations_no_organization():
|
||||
"""Test handling when user has no organization."""
|
||||
|
||||
mock_workflow_run = MagicMock()
|
||||
mock_workflow_run.id = 1
|
||||
mock_workflow_run.gathered_context = {"test": "data"}
|
||||
mock_workflow_run.workflow = MagicMock()
|
||||
mock_workflow_run.workflow.user = MagicMock()
|
||||
|
||||
with patch("api.tasks.run_integrations.set_current_run_id"):
|
||||
with patch("api.tasks.run_integrations.db_client") as mock_db_client:
|
||||
# Mock workflow run found but no organization
|
||||
mock_db_client.get_workflow_run_with_context = AsyncMock(
|
||||
return_value=(mock_workflow_run, None)
|
||||
)
|
||||
|
||||
# Call the function
|
||||
await run_integrations_post_workflow_run(None, 1)
|
||||
|
||||
# Verify it returns early after checking organization
|
||||
mock_db_client.get_workflow_run_with_context.assert_called_once_with(1)
|
||||
mock_db_client.get_active_integrations_by_organization.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_integrations_no_gathered_context(mock_workflow_run):
|
||||
"""Test handling when workflow run has no gathered context."""
|
||||
|
||||
mock_workflow_run.gathered_context = None
|
||||
|
||||
with patch("api.tasks.run_integrations.set_current_run_id"):
|
||||
with patch("api.tasks.run_integrations.db_client") as mock_db_client:
|
||||
# Mock workflow run with no gathered context
|
||||
mock_db_client.get_workflow_run_with_context = AsyncMock(
|
||||
return_value=(mock_workflow_run, 100)
|
||||
)
|
||||
|
||||
# Call the function
|
||||
await run_integrations_post_workflow_run(None, 1)
|
||||
|
||||
# Verify it returns early after checking gathered_context
|
||||
mock_db_client.get_workflow_run_with_context.assert_called_once_with(1)
|
||||
mock_db_client.get_active_integrations_by_organization.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_integrations_stasis_mode(mock_workflow_run):
|
||||
"""Test that stasis mode triggers vendor sync."""
|
||||
|
||||
mock_workflow_run.mode = WorkflowRunMode.STASIS.value
|
||||
mock_workflow_run.initial_context = {
|
||||
"vendor": "test_vendor",
|
||||
"vendor_base_url": "https://api.vendor.com",
|
||||
"vendor_id": "123",
|
||||
}
|
||||
|
||||
with patch("api.tasks.run_integrations.set_current_run_id"):
|
||||
with patch("api.tasks.run_integrations.db_client") as mock_db_client:
|
||||
with patch("api.tasks.run_integrations._sync_vendor_data") as mock_sync:
|
||||
mock_sync.return_value = None
|
||||
|
||||
mock_db_client.get_workflow_run_with_context = AsyncMock(
|
||||
return_value=(mock_workflow_run, 100)
|
||||
)
|
||||
mock_db_client.get_active_integrations_by_organization = AsyncMock(
|
||||
return_value=[]
|
||||
)
|
||||
|
||||
# Call the function
|
||||
await run_integrations_post_workflow_run(None, 1)
|
||||
|
||||
# Verify vendor sync was called
|
||||
mock_sync.assert_called_once_with(
|
||||
mock_workflow_run.initial_context,
|
||||
mock_workflow_run.gathered_context,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_integrations_multiple_integrations(mock_workflow_run):
|
||||
"""Test processing multiple integrations."""
|
||||
|
||||
# Create multiple mock integrations
|
||||
slack_integration = MagicMock()
|
||||
slack_integration.provider = "slack"
|
||||
slack_integration.connection_details = {
|
||||
"connection_config": {"incoming_webhook.url": "https://hooks.slack.com/test1"}
|
||||
}
|
||||
|
||||
slack_integration2 = MagicMock()
|
||||
slack_integration2.provider = "slack"
|
||||
slack_integration2.connection_details = {
|
||||
"connection_config": {"incoming_webhook.url": "https://hooks.slack.com/test2"}
|
||||
}
|
||||
|
||||
with patch("api.tasks.run_integrations.set_current_run_id"):
|
||||
with patch("api.tasks.run_integrations.db_client") as mock_db_client:
|
||||
mock_db_client.get_workflow_run_with_context = AsyncMock(
|
||||
return_value=(mock_workflow_run, 100)
|
||||
)
|
||||
mock_db_client.get_active_integrations_by_organization = AsyncMock(
|
||||
return_value=[slack_integration, slack_integration2]
|
||||
)
|
||||
mock_db_client.get_configuration_value = AsyncMock(
|
||||
return_value={"slack": {"DISPOSITION_CODE": "Test message"}}
|
||||
)
|
||||
|
||||
with patch(
|
||||
"api.tasks.run_integrations.aiohttp.ClientSession"
|
||||
) as mock_session_class:
|
||||
mock_response = MagicMock()
|
||||
mock_response.status = 200
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_session.__aenter__.return_value = mock_session
|
||||
mock_session.__aexit__.return_value = AsyncMock()
|
||||
|
||||
mock_post = MagicMock()
|
||||
mock_post.__aenter__.return_value = mock_response
|
||||
mock_post.__aexit__.return_value = AsyncMock()
|
||||
|
||||
mock_session.post.return_value = mock_post
|
||||
mock_session_class.return_value = mock_session
|
||||
|
||||
# Call the function
|
||||
await run_integrations_post_workflow_run(None, 1)
|
||||
|
||||
# Verify both integrations were processed
|
||||
assert mock_session.post.call_count == 2
|
||||
|
||||
# Check that both webhooks were called
|
||||
call_urls = [call[0][0] for call in mock_session.post.call_args_list]
|
||||
assert "https://hooks.slack.com/test1" in call_urls
|
||||
assert "https://hooks.slack.com/test2" in call_urls
|
||||
|
|
@ -1,330 +0,0 @@
|
|||
"""Tests for webhook execution in run_integrations.py."""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from api.tasks.run_integrations import (
|
||||
_build_auth_header,
|
||||
_build_render_context,
|
||||
_execute_webhook_node,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_logger():
|
||||
"""Mock the logger for all tests."""
|
||||
with patch("api.tasks.run_integrations.logger") as mock_log:
|
||||
mock_log.bind.return_value = mock_log
|
||||
yield mock_log
|
||||
|
||||
|
||||
class TestBuildAuthHeader:
|
||||
"""Tests for _build_auth_header function."""
|
||||
|
||||
def test_bearer_token(self):
|
||||
"""Test bearer token auth header."""
|
||||
credential = MagicMock()
|
||||
credential.credential_type = "bearer_token"
|
||||
credential.credential_data = {"token": "my-secret-token"}
|
||||
|
||||
result = _build_auth_header(credential)
|
||||
assert result == {"Authorization": "Bearer my-secret-token"}
|
||||
|
||||
def test_api_key(self):
|
||||
"""Test API key auth header."""
|
||||
credential = MagicMock()
|
||||
credential.credential_type = "api_key"
|
||||
credential.credential_data = {"header_name": "X-API-Key", "api_key": "key123"}
|
||||
|
||||
result = _build_auth_header(credential)
|
||||
assert result == {"X-API-Key": "key123"}
|
||||
|
||||
def test_api_key_default_header(self):
|
||||
"""Test API key with default header name."""
|
||||
credential = MagicMock()
|
||||
credential.credential_type = "api_key"
|
||||
credential.credential_data = {"api_key": "key123"}
|
||||
|
||||
result = _build_auth_header(credential)
|
||||
assert result == {"X-API-Key": "key123"}
|
||||
|
||||
def test_basic_auth(self):
|
||||
"""Test basic auth header."""
|
||||
credential = MagicMock()
|
||||
credential.credential_type = "basic_auth"
|
||||
credential.credential_data = {"username": "user", "password": "pass"}
|
||||
|
||||
result = _build_auth_header(credential)
|
||||
# base64 of "user:pass" is "dXNlcjpwYXNz"
|
||||
assert result == {"Authorization": "Basic dXNlcjpwYXNz"}
|
||||
|
||||
def test_custom_header(self):
|
||||
"""Test custom header auth."""
|
||||
credential = MagicMock()
|
||||
credential.credential_type = "custom_header"
|
||||
credential.credential_data = {
|
||||
"header_name": "X-Custom-Auth",
|
||||
"header_value": "custom-value",
|
||||
}
|
||||
|
||||
result = _build_auth_header(credential)
|
||||
assert result == {"X-Custom-Auth": "custom-value"}
|
||||
|
||||
def test_unknown_type(self):
|
||||
"""Test unknown credential type returns empty dict."""
|
||||
credential = MagicMock()
|
||||
credential.credential_type = "unknown"
|
||||
credential.credential_data = {}
|
||||
|
||||
result = _build_auth_header(credential)
|
||||
assert result == {}
|
||||
|
||||
|
||||
class TestBuildRenderContext:
|
||||
"""Tests for _build_render_context function."""
|
||||
|
||||
def test_basic_context(self):
|
||||
"""Test building render context from workflow run."""
|
||||
workflow_run = MagicMock()
|
||||
workflow_run.id = 123
|
||||
workflow_run.name = "WR-TEST-001"
|
||||
workflow_run.workflow_id = 456
|
||||
workflow_run.workflow.name = "Test Workflow"
|
||||
workflow_run.initial_context = {"phone_number": "+1234567890"}
|
||||
workflow_run.gathered_context = {
|
||||
"customer_name": "John",
|
||||
"mapped_call_disposition": "QUALIFIED",
|
||||
}
|
||||
workflow_run.usage_info = {"call_duration_seconds": 120}
|
||||
workflow_run.completed_at = None
|
||||
|
||||
result = _build_render_context(workflow_run)
|
||||
|
||||
assert result["workflow_run_id"] == 123
|
||||
assert result["workflow_run_name"] == "WR-TEST-001"
|
||||
assert result["workflow_id"] == 456
|
||||
assert result["workflow_name"] == "Test Workflow"
|
||||
assert result["initial_context"]["phone_number"] == "+1234567890"
|
||||
assert result["gathered_context"]["customer_name"] == "John"
|
||||
assert result["cost_info"]["call_duration_seconds"] == 120
|
||||
assert result["disposition_code"] == "QUALIFIED"
|
||||
|
||||
def test_empty_contexts(self):
|
||||
"""Test with empty/None contexts."""
|
||||
workflow_run = MagicMock()
|
||||
workflow_run.id = 1
|
||||
workflow_run.name = "Test"
|
||||
workflow_run.workflow_id = 1
|
||||
workflow_run.workflow.name = "Workflow"
|
||||
workflow_run.initial_context = None
|
||||
workflow_run.gathered_context = None
|
||||
workflow_run.usage_info = None
|
||||
workflow_run.completed_at = None
|
||||
|
||||
result = _build_render_context(workflow_run)
|
||||
|
||||
assert result["initial_context"] == {}
|
||||
assert result["gathered_context"] == {}
|
||||
assert result["cost_info"] == {}
|
||||
assert result["disposition_code"] is None
|
||||
|
||||
|
||||
class TestExecuteWebhookNode:
|
||||
"""Tests for _execute_webhook_node function."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_disabled_webhook_skipped(self):
|
||||
"""Test that disabled webhooks are skipped."""
|
||||
webhook_data = {"name": "Test Webhook", "enabled": False}
|
||||
|
||||
result = await _execute_webhook_node(
|
||||
webhook_data=webhook_data,
|
||||
render_context={},
|
||||
organization_id=1,
|
||||
)
|
||||
|
||||
assert result is True # Returns True for skipped webhooks
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_missing_url_returns_false(self):
|
||||
"""Test that missing endpoint URL returns False."""
|
||||
webhook_data = {"name": "Test Webhook", "enabled": True, "endpoint_url": None}
|
||||
|
||||
result = await _execute_webhook_node(
|
||||
webhook_data=webhook_data,
|
||||
render_context={},
|
||||
organization_id=1,
|
||||
)
|
||||
|
||||
assert result is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_successful_post_request(self):
|
||||
"""Test successful POST webhook execution."""
|
||||
webhook_data = {
|
||||
"name": "CRM Sync",
|
||||
"enabled": True,
|
||||
"http_method": "POST",
|
||||
"endpoint_url": "https://api.example.com/webhook",
|
||||
"payload_template": {
|
||||
"call_id": "{{workflow_run_id}}",
|
||||
"phone": "{{initial_context.phone_number}}",
|
||||
},
|
||||
}
|
||||
|
||||
render_context = {
|
||||
"workflow_run_id": 123,
|
||||
"initial_context": {"phone_number": "+1234567890"},
|
||||
}
|
||||
|
||||
with patch("api.tasks.run_integrations.db_client") as mock_db:
|
||||
mock_db.get_credential_by_uuid = AsyncMock(return_value=None)
|
||||
|
||||
with patch("api.tasks.run_integrations.httpx.AsyncClient") as mock_client:
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
|
||||
mock_client_instance = AsyncMock()
|
||||
mock_client_instance.request = AsyncMock(return_value=mock_response)
|
||||
mock_client.return_value.__aenter__.return_value = mock_client_instance
|
||||
|
||||
result = await _execute_webhook_node(
|
||||
webhook_data=webhook_data,
|
||||
render_context=render_context,
|
||||
organization_id=1,
|
||||
)
|
||||
|
||||
assert result is True
|
||||
|
||||
# Verify the request was made correctly
|
||||
mock_client_instance.request.assert_called_once()
|
||||
call_kwargs = mock_client_instance.request.call_args[1]
|
||||
assert call_kwargs["method"] == "POST"
|
||||
assert call_kwargs["url"] == "https://api.example.com/webhook"
|
||||
assert call_kwargs["json"] == {
|
||||
"call_id": "123",
|
||||
"phone": "+1234567890",
|
||||
}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_webhook_with_credential(self):
|
||||
"""Test webhook execution with credential auth."""
|
||||
webhook_data = {
|
||||
"name": "Authenticated Webhook",
|
||||
"enabled": True,
|
||||
"http_method": "POST",
|
||||
"endpoint_url": "https://api.example.com/webhook",
|
||||
"credential_uuid": "cred-123",
|
||||
"payload_template": {},
|
||||
}
|
||||
|
||||
mock_credential = MagicMock()
|
||||
mock_credential.name = "API Key"
|
||||
mock_credential.credential_type = "bearer_token"
|
||||
mock_credential.credential_data = {"token": "secret-token"}
|
||||
|
||||
with patch("api.tasks.run_integrations.db_client") as mock_db:
|
||||
mock_db.get_credential_by_uuid = AsyncMock(return_value=mock_credential)
|
||||
|
||||
with patch("api.tasks.run_integrations.httpx.AsyncClient") as mock_client:
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
|
||||
mock_client_instance = AsyncMock()
|
||||
mock_client_instance.request = AsyncMock(return_value=mock_response)
|
||||
mock_client.return_value.__aenter__.return_value = mock_client_instance
|
||||
|
||||
result = await _execute_webhook_node(
|
||||
webhook_data=webhook_data,
|
||||
render_context={},
|
||||
organization_id=1,
|
||||
)
|
||||
|
||||
assert result is True
|
||||
|
||||
# Verify auth header was included
|
||||
call_kwargs = mock_client_instance.request.call_args[1]
|
||||
assert call_kwargs["headers"]["Authorization"] == "Bearer secret-token"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_webhook_with_custom_headers(self):
|
||||
"""Test webhook execution with custom headers."""
|
||||
webhook_data = {
|
||||
"name": "Custom Headers Webhook",
|
||||
"enabled": True,
|
||||
"http_method": "POST",
|
||||
"endpoint_url": "https://api.example.com/webhook",
|
||||
"custom_headers": [
|
||||
{"key": "X-Source", "value": "dograh"},
|
||||
{"key": "X-Workflow", "value": "test"},
|
||||
],
|
||||
"payload_template": {},
|
||||
}
|
||||
|
||||
with patch("api.tasks.run_integrations.db_client") as mock_db:
|
||||
mock_db.get_credential_by_uuid = AsyncMock(return_value=None)
|
||||
|
||||
with patch("api.tasks.run_integrations.httpx.AsyncClient") as mock_client:
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
|
||||
mock_client_instance = AsyncMock()
|
||||
mock_client_instance.request = AsyncMock(return_value=mock_response)
|
||||
mock_client.return_value.__aenter__.return_value = mock_client_instance
|
||||
|
||||
result = await _execute_webhook_node(
|
||||
webhook_data=webhook_data,
|
||||
render_context={},
|
||||
organization_id=1,
|
||||
)
|
||||
|
||||
assert result is True
|
||||
|
||||
# Verify custom headers were included
|
||||
call_kwargs = mock_client_instance.request.call_args[1]
|
||||
assert call_kwargs["headers"]["X-Source"] == "dograh"
|
||||
assert call_kwargs["headers"]["X-Workflow"] == "test"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_webhook_http_error(self):
|
||||
"""Test webhook execution with HTTP error."""
|
||||
import httpx
|
||||
|
||||
webhook_data = {
|
||||
"name": "Failing Webhook",
|
||||
"enabled": True,
|
||||
"http_method": "POST",
|
||||
"endpoint_url": "https://api.example.com/webhook",
|
||||
"payload_template": {},
|
||||
}
|
||||
|
||||
with patch("api.tasks.run_integrations.db_client") as mock_db:
|
||||
mock_db.get_credential_by_uuid = AsyncMock(return_value=None)
|
||||
|
||||
with patch("api.tasks.run_integrations.httpx.AsyncClient") as mock_client:
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 500
|
||||
mock_response.text = "Internal Server Error"
|
||||
mock_response.raise_for_status = MagicMock(
|
||||
side_effect=httpx.HTTPStatusError(
|
||||
"Server Error",
|
||||
request=MagicMock(),
|
||||
response=mock_response,
|
||||
)
|
||||
)
|
||||
|
||||
mock_client_instance = AsyncMock()
|
||||
mock_client_instance.request = AsyncMock(return_value=mock_response)
|
||||
mock_client.return_value.__aenter__.return_value = mock_client_instance
|
||||
|
||||
result = await _execute_webhook_node(
|
||||
webhook_data=webhook_data,
|
||||
render_context={},
|
||||
organization_id=1,
|
||||
)
|
||||
|
||||
assert result is False
|
||||
|
|
@ -1,117 +0,0 @@
|
|||
"""Tests for the `/s3/signed-url` endpoint.
|
||||
|
||||
This test-suite verifies:
|
||||
1. Regular users can retrieve signed URLs for resources belonging to their own workflow runs.
|
||||
2. Regular users are *forbidden* from accessing resources that belong to other users.
|
||||
3. Superusers can access any resource irrespective of ownership.
|
||||
"""
|
||||
|
||||
import os
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
from fastapi import status
|
||||
|
||||
# Ensure the S3 environment variables exist so that the module import does not fail
|
||||
os.environ.setdefault("S3_BUCKET", "test-bucket")
|
||||
os.environ.setdefault("S3_REGION", "us-east-1")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_signed_url_for_own_run(monkeypatch, test_client_factory, db_session):
|
||||
"""A normal user should be able to fetch a signed URL for their own workflow run."""
|
||||
from api.db.models import UserModel
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 1. Set-up – create user, workflow & workflow run
|
||||
# ------------------------------------------------------------------
|
||||
user: UserModel = await db_session.get_or_create_user_by_provider_id("user_own_run")
|
||||
workflow = await db_session.create_workflow("wf", {}, user.id)
|
||||
run = await db_session.create_workflow_run("run", workflow.id, "chat", user.id)
|
||||
|
||||
key = f"transcripts/{run.id}.txt"
|
||||
|
||||
# Patch S3 signed-url generator to avoid network calls
|
||||
monkeypatch.setattr(
|
||||
"api.services.filesystem.s3.s3_fs.aget_signed_url",
|
||||
AsyncMock(return_value="https://signed-url"),
|
||||
)
|
||||
|
||||
async with test_client_factory(user) as client:
|
||||
response = await client.get(f"/api/v1/s3/signed-url?key={key}")
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
assert data == {"url": "https://signed-url", "expires_in": 3600}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_signed_url_for_other_users_run_forbidden(
|
||||
monkeypatch, test_client_factory, db_session
|
||||
):
|
||||
"""A normal user must *not* access workflow runs owned by someone else."""
|
||||
from api.db.models import UserModel
|
||||
|
||||
# Owner of the workflow run
|
||||
owner: UserModel = await db_session.get_or_create_user_by_provider_id("owner_user")
|
||||
workflow = await db_session.create_workflow("wf", {}, owner.id)
|
||||
run = await db_session.create_workflow_run("run", workflow.id, "chat", owner.id)
|
||||
|
||||
# Second user attempting access
|
||||
intruder: UserModel = await db_session.get_or_create_user_by_provider_id(
|
||||
"intruder_user"
|
||||
)
|
||||
|
||||
key = f"recordings/{run.id}.wav"
|
||||
|
||||
monkeypatch.setattr(
|
||||
"api.services.filesystem.s3.s3_fs.aget_signed_url",
|
||||
AsyncMock(return_value="https://signed-url"),
|
||||
)
|
||||
|
||||
async with test_client_factory(intruder) as client:
|
||||
response = await client.get(f"/api/v1/s3/signed-url?key={key}")
|
||||
|
||||
assert response.status_code == status.HTTP_403_FORBIDDEN
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_superuser_can_access_any_run(
|
||||
monkeypatch, test_client_factory, db_session
|
||||
):
|
||||
"""Superusers should be able to fetch signed URLs for any workflow run."""
|
||||
from api.db.models import UserModel
|
||||
|
||||
# Normal user & run owner
|
||||
owner: UserModel = await db_session.get_or_create_user_by_provider_id(
|
||||
"owner_of_run"
|
||||
)
|
||||
workflow = await db_session.create_workflow("wf", {}, owner.id)
|
||||
run = await db_session.create_workflow_run("run", workflow.id, "chat", owner.id)
|
||||
|
||||
# Superuser
|
||||
superuser: UserModel = await db_session.get_or_create_user_by_provider_id(
|
||||
"admin_user"
|
||||
)
|
||||
|
||||
# Promote to superuser
|
||||
# We need to commit the change so that the DB reflects it
|
||||
async with db_session.async_session() as session:
|
||||
db_user = await session.get(UserModel, superuser.id)
|
||||
db_user.is_superuser = True
|
||||
await session.commit()
|
||||
await session.refresh(db_user) # ensure we have the latest state
|
||||
superuser.is_superuser = True
|
||||
|
||||
key = f"transcripts/{run.id}.txt"
|
||||
|
||||
monkeypatch.setattr(
|
||||
"api.services.filesystem.s3.s3_fs.aget_signed_url",
|
||||
AsyncMock(return_value="https://signed-url"),
|
||||
)
|
||||
|
||||
async with test_client_factory(superuser) as client:
|
||||
response = await client.get(f"/api/v1/s3/signed-url?key={key}")
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
assert response.json()["url"] == "https://signed-url"
|
||||
|
|
@ -1,129 +0,0 @@
|
|||
import os
|
||||
import tempfile
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from api.tasks.s3_upload import upload_audio_to_s3, upload_transcript_to_s3
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_upload_audio_to_s3_success():
|
||||
"""Test successful audio upload to S3."""
|
||||
# Create a temporary file
|
||||
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tf:
|
||||
tf.write(b"fake audio data")
|
||||
temp_path = tf.name
|
||||
|
||||
try:
|
||||
# Mock dependencies
|
||||
mock_ctx = AsyncMock()
|
||||
mock_s3_fs = AsyncMock()
|
||||
mock_db_client = AsyncMock()
|
||||
|
||||
with (
|
||||
patch("api.tasks.s3_upload.s3_fs", mock_s3_fs),
|
||||
patch("api.tasks.s3_upload.db_client", mock_db_client),
|
||||
):
|
||||
await upload_audio_to_s3(
|
||||
mock_ctx, workflow_run_id=123, temp_file_path=temp_path
|
||||
)
|
||||
|
||||
# Verify S3 upload was called
|
||||
mock_s3_fs.aupload_file.assert_called_once_with(
|
||||
temp_path, "recordings/123.wav"
|
||||
)
|
||||
|
||||
# Verify DB update was called
|
||||
mock_db_client.update_workflow_run.assert_called_once_with(
|
||||
run_id=123, recording_url="recordings/123.wav"
|
||||
)
|
||||
|
||||
# Verify temp file was cleaned up
|
||||
assert not os.path.exists(temp_path)
|
||||
|
||||
finally:
|
||||
# Clean up if test failed
|
||||
if os.path.exists(temp_path):
|
||||
os.remove(temp_path)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_upload_audio_to_s3_file_not_found():
|
||||
"""Test audio upload when temp file doesn't exist."""
|
||||
mock_ctx = AsyncMock()
|
||||
|
||||
with pytest.raises(FileNotFoundError):
|
||||
await upload_audio_to_s3(
|
||||
mock_ctx, workflow_run_id=123, temp_file_path="/nonexistent/file.wav"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_upload_transcript_to_s3_success():
|
||||
"""Test successful transcript upload to S3."""
|
||||
# Create a temporary file
|
||||
with tempfile.NamedTemporaryFile(mode="w", suffix=".txt", delete=False) as tf:
|
||||
tf.write("Test transcript content")
|
||||
temp_path = tf.name
|
||||
|
||||
try:
|
||||
# Mock dependencies
|
||||
mock_ctx = AsyncMock()
|
||||
mock_s3_fs = AsyncMock()
|
||||
mock_db_client = AsyncMock()
|
||||
|
||||
with (
|
||||
patch("api.tasks.s3_upload.s3_fs", mock_s3_fs),
|
||||
patch("api.tasks.s3_upload.db_client", mock_db_client),
|
||||
):
|
||||
await upload_transcript_to_s3(
|
||||
mock_ctx, workflow_run_id=456, temp_file_path=temp_path
|
||||
)
|
||||
|
||||
# Verify S3 upload was called
|
||||
mock_s3_fs.aupload_file.assert_called_once_with(
|
||||
temp_path, "transcripts/456.txt"
|
||||
)
|
||||
|
||||
# Verify DB update was called
|
||||
mock_db_client.update_workflow_run.assert_called_once_with(
|
||||
run_id=456, transcript_url="transcripts/456.txt"
|
||||
)
|
||||
|
||||
# Verify temp file was cleaned up
|
||||
assert not os.path.exists(temp_path)
|
||||
|
||||
finally:
|
||||
# Clean up if test failed
|
||||
if os.path.exists(temp_path):
|
||||
os.remove(temp_path)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_upload_s3_cleanup_on_error():
|
||||
"""Test that temp files are cleaned up even when S3 upload fails."""
|
||||
# Create a temporary file
|
||||
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tf:
|
||||
tf.write(b"fake audio data")
|
||||
temp_path = tf.name
|
||||
|
||||
try:
|
||||
mock_ctx = AsyncMock()
|
||||
mock_s3_fs = AsyncMock()
|
||||
# Make S3 upload fail
|
||||
mock_s3_fs.aupload_file.side_effect = Exception("S3 upload failed")
|
||||
|
||||
with patch("api.tasks.s3_upload.s3_fs", mock_s3_fs):
|
||||
with pytest.raises(Exception):
|
||||
await upload_audio_to_s3(
|
||||
mock_ctx, workflow_run_id=123, temp_file_path=temp_path
|
||||
)
|
||||
|
||||
# Verify temp file was still cleaned up
|
||||
assert not os.path.exists(temp_path)
|
||||
|
||||
finally:
|
||||
# Clean up if test failed
|
||||
if os.path.exists(temp_path):
|
||||
os.remove(temp_path)
|
||||
|
|
@ -1,89 +0,0 @@
|
|||
from api.utils.template_renderer import render_template
|
||||
|
||||
|
||||
def test_render_template_basic():
|
||||
"""Test basic template rendering."""
|
||||
template = "Hello {{name}}, your balance is {{balance}}."
|
||||
context = {"name": "John", "balance": "$1000"}
|
||||
|
||||
result = render_template(template, context)
|
||||
assert result == "Hello John, your balance is $1000."
|
||||
|
||||
|
||||
def test_render_template_with_spaces():
|
||||
"""Test template rendering with spaces around variables."""
|
||||
template = "Hello {{ name }}, your balance is {{ balance }}."
|
||||
context = {"name": "John", "balance": "$1000"}
|
||||
|
||||
result = render_template(template, context)
|
||||
assert result == "Hello John, your balance is $1000."
|
||||
|
||||
|
||||
def test_render_template_missing_variable():
|
||||
"""Test template rendering with missing variables."""
|
||||
template = "Hello {{name}}, your balance is {{balance}}."
|
||||
context = {"name": "John"}
|
||||
|
||||
result = render_template(template, context)
|
||||
assert result == "Hello John, your balance is ."
|
||||
|
||||
|
||||
def test_render_template_with_fallback():
|
||||
"""Test template rendering with fallback values."""
|
||||
template = "Hello {{name | fallback}}, your balance is {{balance | fallback:$0}}."
|
||||
context = {}
|
||||
|
||||
result = render_template(template, context)
|
||||
assert result == "Hello Name, your balance is $0."
|
||||
|
||||
|
||||
def test_render_template_with_fallback_existing_value():
|
||||
"""Test that fallback is not used when value exists."""
|
||||
template = "Hello {{name | fallback:Guest}}"
|
||||
context = {"name": "John"}
|
||||
|
||||
result = render_template(template, context)
|
||||
assert result == "Hello John"
|
||||
|
||||
|
||||
def test_render_template_with_line_breaks():
|
||||
"""Test template rendering with line breaks."""
|
||||
template = (
|
||||
"DISPOSITION_CODE: {{call_disposition}}\\nCALL_DURATION: {{call_duration}}"
|
||||
)
|
||||
context = {"call_disposition": "XFER", "call_duration": "300"}
|
||||
|
||||
result = render_template(template, context)
|
||||
expected = "DISPOSITION_CODE: XFER\nCALL_DURATION: 300"
|
||||
assert result == expected
|
||||
|
||||
|
||||
def test_render_template_empty():
|
||||
"""Test rendering empty template."""
|
||||
assert render_template("", {}) == ""
|
||||
assert render_template(None, {}) == None
|
||||
|
||||
|
||||
def test_render_template_no_placeholders():
|
||||
"""Test template with no placeholders."""
|
||||
template = "This is a plain text message"
|
||||
result = render_template(template, {"unused": "value"})
|
||||
assert result == "This is a plain text message"
|
||||
|
||||
|
||||
def test_render_template_none_values():
|
||||
"""Test template with None values."""
|
||||
template = "Value: {{value}}"
|
||||
context = {"value": None}
|
||||
|
||||
result = render_template(template, context)
|
||||
assert result == "Value: "
|
||||
|
||||
|
||||
def test_render_template_numeric_values():
|
||||
"""Test template with numeric values."""
|
||||
template = "Count: {{count}}, Price: {{price}}"
|
||||
context = {"count": 42, "price": 19.99}
|
||||
|
||||
result = render_template(template, context)
|
||||
assert result == "Count: 42, Price: 19.99"
|
||||
|
|
@ -1,152 +0,0 @@
|
|||
#!/usr/bin/env python
|
||||
"""
|
||||
Test script to verify atomic operations in organization_usage_client.py
|
||||
This simulates concurrent access from multiple processes.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
from concurrent.futures import ProcessPoolExecutor
|
||||
|
||||
# Set up environment
|
||||
os.environ.setdefault("DATABASE_URL", os.environ.get("DATABASE_URL", ""))
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from api.db.organization_usage_client import OrganizationUsageClient
|
||||
|
||||
|
||||
async def reserve_quota_process(org_id: int, tokens: int, process_id: int):
|
||||
"""Simulate a process trying to reserve quota."""
|
||||
engine = create_async_engine(os.environ["DATABASE_URL"])
|
||||
async_session = sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)
|
||||
|
||||
client = OrganizationUsageClient(async_session)
|
||||
|
||||
results = []
|
||||
for i in range(5):
|
||||
result = await client.check_and_reserve_quota(org_id, tokens)
|
||||
results.append((process_id, i, result))
|
||||
await asyncio.sleep(0.01) # Small delay to increase contention
|
||||
|
||||
await engine.dispose()
|
||||
return results
|
||||
|
||||
|
||||
async def update_usage_process(org_id: int, tokens: int, process_id: int):
|
||||
"""Simulate a process updating usage after runs."""
|
||||
engine = create_async_engine(os.environ["DATABASE_URL"])
|
||||
async_session = sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)
|
||||
|
||||
client = OrganizationUsageClient(async_session)
|
||||
|
||||
for i in range(5):
|
||||
await client.update_usage_after_run(org_id, tokens, duration_seconds=10)
|
||||
await asyncio.sleep(0.01)
|
||||
|
||||
await engine.dispose()
|
||||
return f"Process {process_id} completed updates"
|
||||
|
||||
|
||||
def run_reserve_quota(args):
|
||||
"""Wrapper to run async function in process."""
|
||||
org_id, tokens, process_id = args
|
||||
return asyncio.run(reserve_quota_process(org_id, tokens, process_id))
|
||||
|
||||
|
||||
def run_update_usage(args):
|
||||
"""Wrapper to run async function in process."""
|
||||
org_id, tokens, process_id = args
|
||||
return asyncio.run(update_usage_process(org_id, tokens, process_id))
|
||||
|
||||
|
||||
async def test_concurrent_quota_reservation():
|
||||
"""Test that concurrent quota reservations are handled atomically."""
|
||||
print("Testing concurrent quota reservations...")
|
||||
|
||||
# Assuming org_id 1 exists with quota enabled
|
||||
org_id = 1
|
||||
tokens_per_request = 100
|
||||
|
||||
# Run multiple processes trying to reserve quota simultaneously
|
||||
with ProcessPoolExecutor(max_workers=3) as executor:
|
||||
futures = []
|
||||
for i in range(3):
|
||||
futures.append(
|
||||
executor.submit(run_reserve_quota, (org_id, tokens_per_request, i))
|
||||
)
|
||||
|
||||
results = []
|
||||
for future in futures:
|
||||
results.extend(future.result())
|
||||
|
||||
print(f"Reservation results: {results}")
|
||||
|
||||
# Check that reservations were handled atomically
|
||||
successful_reservations = sum(1 for _, _, success in results if success)
|
||||
print(f"Successful reservations: {successful_reservations}")
|
||||
|
||||
|
||||
async def test_concurrent_usage_updates():
|
||||
"""Test that concurrent usage updates are handled atomically."""
|
||||
print("\nTesting concurrent usage updates...")
|
||||
|
||||
org_id = 1
|
||||
tokens_per_update = 50
|
||||
|
||||
# Get initial usage
|
||||
engine = create_async_engine(os.environ["DATABASE_URL"])
|
||||
async_session = sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)
|
||||
client = OrganizationUsageClient(async_session)
|
||||
|
||||
initial_usage = await client.get_current_usage(org_id)
|
||||
initial_tokens = initial_usage["used_dograh_tokens"]
|
||||
print(f"Initial tokens: {initial_tokens}")
|
||||
|
||||
# Run multiple processes updating usage simultaneously
|
||||
with ProcessPoolExecutor(max_workers=3) as executor:
|
||||
futures = []
|
||||
for i in range(3):
|
||||
futures.append(
|
||||
executor.submit(run_update_usage, (org_id, tokens_per_update, i))
|
||||
)
|
||||
|
||||
for future in futures:
|
||||
print(future.result())
|
||||
|
||||
# Check final usage
|
||||
final_usage = await client.get_current_usage(org_id)
|
||||
final_tokens = final_usage["used_dograh_tokens"]
|
||||
expected_tokens = initial_tokens + (
|
||||
3 * 5 * tokens_per_update
|
||||
) # 3 processes * 5 updates * 50 tokens
|
||||
|
||||
print(f"Final tokens: {final_tokens}")
|
||||
print(f"Expected tokens: {expected_tokens}")
|
||||
print(f"Difference: {final_tokens - expected_tokens}")
|
||||
|
||||
await engine.dispose()
|
||||
|
||||
if final_tokens == expected_tokens:
|
||||
print("✅ All updates were applied atomically!")
|
||||
else:
|
||||
print("❌ Some updates were lost due to race conditions!")
|
||||
|
||||
|
||||
async def main():
|
||||
"""Run all concurrency tests."""
|
||||
try:
|
||||
await test_concurrent_quota_reservation()
|
||||
await test_concurrent_usage_updates()
|
||||
except Exception as e:
|
||||
print(f"Error during testing: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
print("Starting organization usage concurrency tests...")
|
||||
print(f"Using DATABASE_URL: {os.environ.get('DATABASE_URL', 'NOT SET')}")
|
||||
asyncio.run(main())
|
||||
|
|
@ -1,140 +0,0 @@
|
|||
import json
|
||||
import os
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
from pipecat.services.openai.llm import OpenAILLMContext
|
||||
|
||||
from api.services.workflow.dto import ExtractionVariableDTO, VariableType
|
||||
from api.services.workflow.pipecat_engine_variable_extractor import (
|
||||
VariableExtractionManager,
|
||||
)
|
||||
|
||||
|
||||
class DummyLLM:
|
||||
"""A minimal stub that mimics the parts of an LLM service used by the extractor."""
|
||||
|
||||
def __init__(self, streamed_response: str | None = None):
|
||||
# Optionally provide a pre-defined streaming response for _perform_extraction tests
|
||||
self._streamed_response = streamed_response or "{}"
|
||||
self.registered_functions: dict[str, AsyncMock] = {}
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# API used by VariableExtractionManager
|
||||
# ------------------------------------------------------------------
|
||||
def register_function(self, name: str, func, cancel_on_interruption=True): # noqa: D401 – simple delegate
|
||||
self.registered_functions[name] = func
|
||||
|
||||
async def get_chat_completions(self, _context, _messages):
|
||||
"""Return an async generator that yields a single chunk with the full response."""
|
||||
|
||||
class _Delta: # noqa: D401 – tiny helper classes for stub response
|
||||
def __init__(self, content):
|
||||
self.content = content
|
||||
|
||||
class _Choice:
|
||||
def __init__(self, delta):
|
||||
self.delta = delta
|
||||
|
||||
class _Chunk:
|
||||
def __init__(self, content):
|
||||
self.choices = [_Choice(_Delta(content))]
|
||||
|
||||
async def _stream():
|
||||
yield _Chunk(self._streamed_response)
|
||||
|
||||
return _stream()
|
||||
|
||||
|
||||
class DummyEngine:
|
||||
"""A bare-bones Engine stub exposing only what the extractor relies on."""
|
||||
|
||||
def __init__(self, llm):
|
||||
self.llm = llm
|
||||
self.context = OpenAILLMContext()
|
||||
self._pending_function_calls = 0
|
||||
# VariableExtractionManager currently updates this private attribute
|
||||
self._gathered_context: dict = {}
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Tests
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_perform_extraction_parses_json_correctly():
|
||||
"""_perform_extraction should return the parsed JSON from the LLM stream."""
|
||||
# Set dummy OpenAI API key to prevent initialization errors
|
||||
with patch.dict(os.environ, {"OPENAI_API_KEY": "test-key"}):
|
||||
expected_payload = {"name": "Alice", "age": 30}
|
||||
llm = DummyLLM(json.dumps(expected_payload))
|
||||
engine = DummyEngine(llm)
|
||||
manager = VariableExtractionManager(engine)
|
||||
|
||||
# Mock the AsyncOpenAI client and its response
|
||||
mock_response = AsyncMock()
|
||||
mock_response.choices = [AsyncMock()]
|
||||
mock_response.choices[0].message = AsyncMock()
|
||||
mock_response.choices[0].message.content = json.dumps(expected_payload)
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.chat.completions.create.return_value = mock_response
|
||||
|
||||
with patch(
|
||||
"api.services.workflow.pipecat_engine_variable_extractor.AsyncOpenAI",
|
||||
return_value=mock_client,
|
||||
):
|
||||
# Minimal set of variables to extract – the prompts themselves are irrelevant here
|
||||
extraction_variables = [
|
||||
ExtractionVariableDTO(
|
||||
name="name", type=VariableType.string, prompt="user name"
|
||||
),
|
||||
ExtractionVariableDTO(
|
||||
name="age", type=VariableType.number, prompt="user age"
|
||||
),
|
||||
]
|
||||
|
||||
result = await manager._perform_extraction(
|
||||
extraction_variables, parent_ctx=None, extraction_prompt=""
|
||||
)
|
||||
|
||||
assert result == expected_payload
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_perform_extraction_with_custom_system_prompt():
|
||||
"""_perform_extraction should use the provided extraction_prompt as system prompt."""
|
||||
# Set dummy OpenAI API key to prevent initialization errors
|
||||
with patch.dict(os.environ, {"OPENAI_API_KEY": "test-key"}):
|
||||
expected_payload = {"color": "blue"}
|
||||
llm = DummyLLM(json.dumps(expected_payload))
|
||||
engine = DummyEngine(llm)
|
||||
manager = VariableExtractionManager(engine)
|
||||
|
||||
# Mock the AsyncOpenAI client and its response
|
||||
mock_response = AsyncMock()
|
||||
mock_response.choices = [AsyncMock()]
|
||||
mock_response.choices[0].message = AsyncMock()
|
||||
mock_response.choices[0].message.content = json.dumps(expected_payload)
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.chat.completions.create.return_value = mock_response
|
||||
|
||||
with patch(
|
||||
"api.services.workflow.pipecat_engine_variable_extractor.AsyncOpenAI",
|
||||
return_value=mock_client,
|
||||
):
|
||||
extraction_variables = [
|
||||
ExtractionVariableDTO(
|
||||
name="color", type=VariableType.string, prompt="favourite color"
|
||||
)
|
||||
]
|
||||
|
||||
# Call with a custom extraction prompt
|
||||
custom_prompt = "You are a color extraction specialist."
|
||||
result = await manager._perform_extraction(
|
||||
extraction_variables, parent_ctx=None, extraction_prompt=custom_prompt
|
||||
)
|
||||
|
||||
assert result == expected_payload
|
||||
|
|
@ -1,547 +0,0 @@
|
|||
"""
|
||||
Test voicemail detection in RTC connection flow.
|
||||
|
||||
This test emulates how a call is connected using SmallWebRTC,
|
||||
triggers voicemail detection, and verifies the disconnect reason.
|
||||
"""
|
||||
|
||||
import json
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
from pipecat.utils.enums import EndTaskReason
|
||||
|
||||
from api.routes.rtc_offer import RTCOfferRequest, offer
|
||||
from api.services.workflow.pipecat_engine_voicemail_detector import VoicemailDetector
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestVoicemailDetectionRTC:
|
||||
"""Test voicemail detection through RTC connection flow."""
|
||||
|
||||
async def test_voicemail_detection_full_flow(self):
|
||||
"""
|
||||
Test complete voicemail detection flow:
|
||||
1. RTC connection request
|
||||
2. Transport sends on_client_connected event
|
||||
3. Engine initializes with voicemail detection enabled
|
||||
4. Voicemail detector returns true
|
||||
5. Call terminates with voicemail_detected reason
|
||||
6. Transport sends on_client_disconnected event
|
||||
7. Disconnect reason is properly set
|
||||
"""
|
||||
# Mock user and authentication
|
||||
mock_user = Mock()
|
||||
mock_user.id = 1
|
||||
mock_user.organization_id = 1
|
||||
|
||||
# Mock workflow with voicemail detection enabled
|
||||
mock_workflow = Mock()
|
||||
mock_workflow.id = 100
|
||||
mock_workflow.workflow_definition_with_fallback = {
|
||||
"edges": [],
|
||||
"nodes": [
|
||||
{
|
||||
"id": "start",
|
||||
"type": "start",
|
||||
"data": {
|
||||
"detect_voicemail": True,
|
||||
"system_prompt": "You are a helpful assistant",
|
||||
},
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
# Mock workflow run
|
||||
mock_workflow_run = Mock()
|
||||
mock_workflow_run.id = 200
|
||||
mock_workflow_run.is_completed = False
|
||||
|
||||
# Create request
|
||||
request = RTCOfferRequest(
|
||||
pc_id="test_pc_123",
|
||||
sdp="test_sdp_offer",
|
||||
type="offer",
|
||||
workflow_id=mock_workflow.id,
|
||||
workflow_run_id=mock_workflow_run.id,
|
||||
restart_pc=False,
|
||||
call_context_vars={"test_var": "test_value"},
|
||||
)
|
||||
|
||||
# Mock dependencies
|
||||
with (
|
||||
patch("api.services.auth.depends.get_user") as mock_get_user_dep,
|
||||
patch("api.routes.rtc_offer.SmallWebRTCConnection") as MockWebRTCConnection,
|
||||
patch("api.routes.rtc_offer.run_pipeline_smallwebrtc") as mock_run_pipeline,
|
||||
):
|
||||
# Setup mocks
|
||||
mock_get_user_dep.return_value = mock_user
|
||||
|
||||
# Mock WebRTC connection
|
||||
mock_connection = Mock()
|
||||
mock_connection.pc_id = "test_pc_123"
|
||||
mock_connection.initialize = AsyncMock()
|
||||
mock_connection.get_answer = Mock(
|
||||
return_value={
|
||||
"pc_id": "test_pc_123",
|
||||
"sdp": "test_sdp_answer",
|
||||
"type": "answer",
|
||||
}
|
||||
)
|
||||
MockWebRTCConnection.return_value = mock_connection
|
||||
|
||||
# Track registered event handlers
|
||||
registered_handlers = {}
|
||||
|
||||
def mock_event_handler(event_name):
|
||||
def decorator(func):
|
||||
registered_handlers[event_name] = func
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
mock_connection.event_handler = mock_event_handler
|
||||
|
||||
# Mock BackgroundTasks
|
||||
mock_background_tasks = Mock()
|
||||
|
||||
# Create the offer
|
||||
response = await offer(request, mock_background_tasks, mock_user)
|
||||
|
||||
# Verify response
|
||||
assert response["pc_id"] == "test_pc_123"
|
||||
assert response["type"] == "answer"
|
||||
|
||||
# Verify connection was initialized
|
||||
mock_connection.initialize.assert_called_once_with(
|
||||
sdp="test_sdp_offer", type="offer"
|
||||
)
|
||||
|
||||
# Verify background task was added
|
||||
mock_background_tasks.add_task.assert_called_once()
|
||||
task_args = mock_background_tasks.add_task.call_args[0]
|
||||
assert task_args[0] == mock_run_pipeline
|
||||
assert task_args[1] == mock_connection
|
||||
assert task_args[2] == mock_workflow.id
|
||||
assert task_args[3] == mock_workflow_run.id
|
||||
assert task_args[4] == mock_user.id
|
||||
assert task_args[5] == {"test_var": "test_value"}
|
||||
|
||||
async def test_voicemail_detection_in_pipeline(self):
|
||||
"""Tests whether the updates happen in on_client_disconnected properly
|
||||
with values set in the engine"""
|
||||
# Mock components
|
||||
mock_transport = AsyncMock()
|
||||
mock_engine = Mock() # Use Mock instead of AsyncMock for engine
|
||||
mock_engine.initialize = AsyncMock()
|
||||
mock_engine.cleanup = AsyncMock()
|
||||
mock_audio_buffer = AsyncMock()
|
||||
mock_task = AsyncMock()
|
||||
mock_aggregator = Mock()
|
||||
|
||||
# Setup engine with voicemail detector
|
||||
mock_voicemail_detector = AsyncMock(spec=VoicemailDetector)
|
||||
mock_engine.voicemail_detector = mock_voicemail_detector
|
||||
mock_engine.get_call_disposition = Mock(
|
||||
return_value=EndTaskReason.VOICEMAIL_DETECTED.value
|
||||
)
|
||||
mock_engine.get_gathered_context = Mock(
|
||||
return_value={
|
||||
"voicemail_transcript": "Hi, you've reached John's voicemail. Please leave a message.",
|
||||
"voicemail_confidence": 0.95,
|
||||
}
|
||||
)
|
||||
|
||||
# Mock usage metrics
|
||||
mock_aggregator.get_all_usage_metrics_serialized.return_value = {}
|
||||
|
||||
# Register event handlers
|
||||
from api.services.pipecat.event_handlers import (
|
||||
register_transport_event_handlers,
|
||||
)
|
||||
|
||||
# Track registered handlers
|
||||
handlers = {}
|
||||
|
||||
def track_handler(event_name):
|
||||
def decorator(func):
|
||||
handlers[event_name] = func
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
mock_transport.event_handler = track_handler
|
||||
|
||||
# Create a mock db_client module with update_workflow_run method
|
||||
mock_db_client = Mock()
|
||||
mock_db_client.update_workflow_run = AsyncMock()
|
||||
|
||||
with (
|
||||
patch("api.services.pipecat.event_handlers.db_client", mock_db_client),
|
||||
patch(
|
||||
"api.services.pipecat.event_handlers.enqueue_job",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_enqueue_job,
|
||||
patch(
|
||||
"api.services.pipecat.event_handlers.get_organization_id_from_workflow_run",
|
||||
return_value=1,
|
||||
),
|
||||
patch(
|
||||
"api.services.pipecat.event_handlers.apply_disposition_mapping",
|
||||
side_effect=lambda value, org_id: value, # Return value unchanged
|
||||
),
|
||||
):
|
||||
# Register handlers
|
||||
register_transport_event_handlers(
|
||||
mock_transport,
|
||||
workflow_run_id=123,
|
||||
audio_buffer=mock_audio_buffer,
|
||||
task=mock_task,
|
||||
engine=mock_engine,
|
||||
usage_metrics_aggregator=mock_aggregator,
|
||||
)
|
||||
|
||||
# Verify handlers were registered
|
||||
assert "on_client_connected" in handlers
|
||||
assert "on_client_disconnected" in handlers
|
||||
|
||||
# Simulate client connection
|
||||
await handlers["on_client_connected"](
|
||||
mock_transport, {"id": "participant_1"}
|
||||
)
|
||||
|
||||
# Verify initialization
|
||||
mock_audio_buffer.start_recording.assert_called_once()
|
||||
mock_engine.initialize.assert_called_once()
|
||||
|
||||
# Simulate voicemail detection and disconnect
|
||||
await handlers["on_client_disconnected"](
|
||||
mock_transport, {"id": "participant_1"}, None
|
||||
)
|
||||
|
||||
# Verify engine cleanup
|
||||
mock_engine.cleanup.assert_called_once()
|
||||
|
||||
# TODO: check whether task was cancelled or not once have more
|
||||
# clarity on how to handle engine disconnect vs remote hangup
|
||||
# Verify task was NOT cancelled (engine disconnect)
|
||||
# mock_task.cancel.assert_not_called()
|
||||
|
||||
# Verify workflow run was updated with voicemail context
|
||||
mock_db_client.update_workflow_run.assert_called()
|
||||
call_args = mock_db_client.update_workflow_run.call_args
|
||||
assert call_args[1]["run_id"] == 123
|
||||
# Check that the mapped_call_disposition was set correctly
|
||||
assert (
|
||||
call_args[1]["gathered_context"]["mapped_call_disposition"]
|
||||
== "voicemail_detected"
|
||||
)
|
||||
|
||||
async def test_voicemail_detector_audio_processing(self):
|
||||
"""Test VoicemailDetector audio processing and detection logic - tests that voicemail detector
|
||||
calls engine's send_end_task_frame with the correct reason and metadata"""
|
||||
# Create voicemail detector
|
||||
detector = VoicemailDetector(detection_duration=5.0, workflow_run_id=123)
|
||||
|
||||
# Mock OpenAI client
|
||||
mock_openai = AsyncMock()
|
||||
mock_whisper_response = Mock()
|
||||
mock_whisper_response.text = "Hi, you've reached the voicemail of John Smith. Please leave a message after the beep."
|
||||
mock_openai.audio.transcriptions.create.return_value = mock_whisper_response
|
||||
|
||||
mock_gpt_response = Mock()
|
||||
mock_gpt_response.choices = [Mock()]
|
||||
mock_gpt_response.choices[0].message.content = json.dumps(
|
||||
{
|
||||
"is_voicemail": True,
|
||||
"confidence": 0.98,
|
||||
"reasoning": "Clear voicemail greeting with request to leave message",
|
||||
}
|
||||
)
|
||||
mock_openai.chat.completions.create.return_value = mock_gpt_response
|
||||
|
||||
# Mock engine
|
||||
mock_engine = AsyncMock()
|
||||
mock_engine.task = AsyncMock()
|
||||
|
||||
with (
|
||||
patch(
|
||||
"api.services.workflow.pipecat_engine_voicemail_detector.AsyncOpenAI",
|
||||
return_value=mock_openai,
|
||||
),
|
||||
patch(
|
||||
"api.services.workflow.pipecat_engine_voicemail_detector.s3_fs"
|
||||
) as mock_s3,
|
||||
):
|
||||
# Mock S3 upload to return None (simulating successful upload)
|
||||
mock_s3.aupload_file = AsyncMock(return_value=True)
|
||||
# Start detection
|
||||
await detector.start_detection(mock_engine)
|
||||
assert detector.is_detecting == True
|
||||
|
||||
# Simulate audio data (16kHz, mono, 5 seconds)
|
||||
sample_rate = 16000
|
||||
duration = 5.0
|
||||
audio_data = b"\x00\x00" * int(sample_rate * duration) # Silent audio
|
||||
|
||||
# Process audio in chunks
|
||||
chunk_size = 1600 # 100ms chunks
|
||||
for i in range(0, len(audio_data), chunk_size):
|
||||
chunk = audio_data[i : i + chunk_size]
|
||||
await detector.handle_audio_data(None, chunk, sample_rate, 1)
|
||||
|
||||
# Wait for detection to complete
|
||||
if detector._detection_task:
|
||||
await detector._detection_task
|
||||
|
||||
# Verify OpenAI calls
|
||||
mock_openai.audio.transcriptions.create.assert_called_once()
|
||||
mock_openai.chat.completions.create.assert_called_once()
|
||||
|
||||
# Verify send_end_task_frame was called with voicemail detection
|
||||
mock_engine.send_end_task_frame.assert_called_once_with(
|
||||
reason=EndTaskReason.VOICEMAIL_DETECTED.value,
|
||||
additional_metadata={
|
||||
"voicemail_transcript": "Hi, you've reached the voicemail of John Smith. Please leave a message after the beep.",
|
||||
"voicemail_confidence": 0.98,
|
||||
"voicemail_reasoning": "Clear voicemail greeting with request to leave message",
|
||||
"voicemail_detection_duration": 5.0,
|
||||
"voicemail_audio_s3_path": "voicemail_detections/123_voicemail_98_5.wav", # S3 upload returns True, so filename is used
|
||||
},
|
||||
abort_immediately=True,
|
||||
)
|
||||
|
||||
async def test_voicemail_detector_no_detection(self):
|
||||
"""Test VoicemailDetector when voicemail is not detected."""
|
||||
# Create voicemail detector
|
||||
detector = VoicemailDetector(detection_duration=5.0, workflow_run_id=124)
|
||||
|
||||
# Mock OpenAI client
|
||||
mock_openai = AsyncMock()
|
||||
mock_whisper_response = Mock()
|
||||
mock_whisper_response.text = "Hello? Hello? Can you hear me?"
|
||||
mock_openai.audio.transcriptions.create.return_value = mock_whisper_response
|
||||
|
||||
mock_gpt_response = Mock()
|
||||
mock_gpt_response.choices = [Mock()]
|
||||
mock_gpt_response.choices[0].message.content = json.dumps(
|
||||
{
|
||||
"is_voicemail": False,
|
||||
"confidence": 0.95,
|
||||
"reasoning": "Live person speaking, asking if caller can hear them",
|
||||
}
|
||||
)
|
||||
mock_openai.chat.completions.create.return_value = mock_gpt_response
|
||||
|
||||
# Mock engine
|
||||
mock_engine = AsyncMock()
|
||||
mock_engine.task = AsyncMock()
|
||||
|
||||
with patch(
|
||||
"api.services.workflow.pipecat_engine_voicemail_detector.AsyncOpenAI",
|
||||
return_value=mock_openai,
|
||||
):
|
||||
# Start detection
|
||||
await detector.start_detection(mock_engine)
|
||||
|
||||
# Simulate audio data
|
||||
sample_rate = 16000
|
||||
duration = 5.0
|
||||
audio_data = b"\x00\x00" * int(sample_rate * duration)
|
||||
|
||||
# Process audio
|
||||
await detector.handle_audio_data(None, audio_data, sample_rate, 1)
|
||||
|
||||
# Wait for detection
|
||||
if detector._detection_task:
|
||||
await detector._detection_task
|
||||
|
||||
# Verify send_end_task_frame was NOT called
|
||||
mock_engine.send_end_task_frame.assert_not_called()
|
||||
|
||||
async def test_voicemail_detector_cancellation(self):
|
||||
"""Test VoicemailDetector cancellation before completion."""
|
||||
# Create voicemail detector
|
||||
detector = VoicemailDetector(detection_duration=10.0, workflow_run_id=125)
|
||||
|
||||
# Mock engine
|
||||
mock_engine = AsyncMock()
|
||||
|
||||
# Start detection
|
||||
await detector.start_detection(mock_engine)
|
||||
assert detector.is_detecting == True
|
||||
|
||||
# Cancel detection immediately
|
||||
await detector.stop_detection()
|
||||
assert detector._is_cancelled == True
|
||||
|
||||
# Try to add audio data after cancellation
|
||||
await detector.handle_audio_data(None, b"\x00\x00" * 1000, 16000, 1)
|
||||
|
||||
# Verify buffer didn't grow (no audio accepted after cancellation)
|
||||
assert len(detector.audio_buffer) == 0
|
||||
|
||||
async def test_disconnect_reason_propagation(self):
|
||||
"""Test that voicemail disconnect reason is properly propagated."""
|
||||
# Create disconnect reason info directly
|
||||
disconnect_info = {
|
||||
"disposition_code": EndTaskReason.VOICEMAIL_DETECTED.value,
|
||||
"details": "Voicemail detected after 5 seconds of audio",
|
||||
"is_remote": False,
|
||||
"is_user_initiated": False,
|
||||
"is_successful_transfer": False,
|
||||
"transport_metadata": {
|
||||
"voicemail_confidence": 0.97,
|
||||
"voicemail_transcript": "You've reached voicemail...",
|
||||
},
|
||||
}
|
||||
|
||||
# Verify attributes
|
||||
assert disconnect_info["disposition_code"] == "voicemail_detected"
|
||||
assert disconnect_info["is_remote"] == False
|
||||
assert disconnect_info["is_user_initiated"] == False
|
||||
assert disconnect_info["is_successful_transfer"] == False
|
||||
assert (
|
||||
disconnect_info["details"] == "Voicemail detected after 5 seconds of audio"
|
||||
)
|
||||
assert disconnect_info["transport_metadata"]["voicemail_confidence"] == 0.97
|
||||
|
||||
async def test_voicemail_detection_end_to_end(self):
|
||||
"""
|
||||
Complete end-to-end test covering:
|
||||
1. on_client_connected event
|
||||
2. Engine initialization with voicemail detection
|
||||
3. Audio processing and voicemail detection
|
||||
4. Engine setting disconnect reason
|
||||
5. on_client_disconnected event
|
||||
6. Proper disconnect reason in workflow run update
|
||||
"""
|
||||
# Create comprehensive mocks
|
||||
from api.services.pipecat.event_handlers import (
|
||||
register_transport_event_handlers,
|
||||
)
|
||||
|
||||
# Mock transport
|
||||
mock_transport = AsyncMock()
|
||||
handlers = {}
|
||||
|
||||
def track_handler(event_name):
|
||||
def decorator(func):
|
||||
handlers[event_name] = func
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
mock_transport.event_handler = track_handler
|
||||
|
||||
# Mock audio buffer
|
||||
mock_audio_buffer = Mock()
|
||||
mock_audio_buffer.start_recording = AsyncMock()
|
||||
mock_audio_buffer.stop_recording = AsyncMock()
|
||||
|
||||
# Mock task
|
||||
mock_task = AsyncMock()
|
||||
|
||||
# Mock aggregator
|
||||
mock_aggregator = Mock()
|
||||
mock_aggregator.get_all_usage_metrics_serialized.return_value = {}
|
||||
|
||||
# Create a mock engine with voicemail detection
|
||||
mock_engine = Mock()
|
||||
mock_engine.initialize = AsyncMock()
|
||||
mock_engine.cleanup = AsyncMock()
|
||||
|
||||
# Mock voicemail detector
|
||||
mock_voicemail_detector = Mock()
|
||||
mock_engine.voicemail_detector = mock_voicemail_detector
|
||||
mock_engine._voicemail_detector = mock_voicemail_detector
|
||||
|
||||
# Initially no disconnect reason
|
||||
mock_engine.get_call_disposition = Mock(return_value=None)
|
||||
mock_engine.get_gathered_context = Mock(return_value={})
|
||||
|
||||
# Mock db_client
|
||||
mock_db_client = Mock()
|
||||
mock_db_client.update_workflow_run = AsyncMock()
|
||||
|
||||
with (
|
||||
patch("api.services.pipecat.event_handlers.db_client", mock_db_client),
|
||||
patch(
|
||||
"api.services.pipecat.event_handlers.enqueue_job",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_enqueue_job,
|
||||
patch(
|
||||
"api.services.pipecat.event_handlers.get_organization_id_from_workflow_run",
|
||||
return_value=1,
|
||||
),
|
||||
patch(
|
||||
"api.services.pipecat.event_handlers.apply_disposition_mapping",
|
||||
side_effect=lambda value, org_id: value, # Return value unchanged
|
||||
),
|
||||
):
|
||||
# Register event handlers
|
||||
register_transport_event_handlers(
|
||||
mock_transport,
|
||||
workflow_run_id=123,
|
||||
audio_buffer=mock_audio_buffer,
|
||||
task=mock_task,
|
||||
engine=mock_engine,
|
||||
usage_metrics_aggregator=mock_aggregator,
|
||||
)
|
||||
|
||||
# Verify handlers were registered
|
||||
assert "on_client_connected" in handlers
|
||||
assert "on_client_disconnected" in handlers
|
||||
|
||||
# Step 1: Client connects
|
||||
await handlers["on_client_connected"](
|
||||
mock_transport, {"id": "participant_1"}
|
||||
)
|
||||
|
||||
# Verify initialization
|
||||
mock_audio_buffer.start_recording.assert_called_once()
|
||||
mock_engine.initialize.assert_called_once()
|
||||
|
||||
# Step 2-3: Simulate voicemail detection occurs
|
||||
# Update engine state to reflect voicemail was detected
|
||||
mock_engine.get_call_disposition = Mock(
|
||||
return_value=EndTaskReason.VOICEMAIL_DETECTED.value
|
||||
)
|
||||
mock_engine.get_gathered_context = Mock(
|
||||
return_value={
|
||||
"voicemail_transcript": "You've reached voicemail, leave a message",
|
||||
"voicemail_confidence": 0.95,
|
||||
}
|
||||
)
|
||||
|
||||
# Step 5: Client disconnects
|
||||
await handlers["on_client_disconnected"](
|
||||
mock_transport, {"id": "participant_1"}, None
|
||||
)
|
||||
|
||||
# Verify engine cleanup
|
||||
mock_engine.cleanup.assert_called_once()
|
||||
|
||||
# Step 6: Verify proper disconnect reason in workflow run update
|
||||
mock_db_client.update_workflow_run.assert_called()
|
||||
call_args = mock_db_client.update_workflow_run.call_args
|
||||
|
||||
# Check the gathered context includes disconnect reason
|
||||
gathered_context = call_args[1]["gathered_context"]
|
||||
assert gathered_context["mapped_call_disposition"] == "voicemail_detected"
|
||||
assert gathered_context["voicemail_confidence"] == 0.95
|
||||
assert (
|
||||
gathered_context["voicemail_transcript"]
|
||||
== "You've reached voicemail, leave a message"
|
||||
)
|
||||
|
||||
# Verify task was NOT cancelled (engine-initiated disconnect)
|
||||
mock_task.cancel.assert_not_called()
|
||||
|
||||
# Verify audio buffer was stopped
|
||||
mock_audio_buffer.stop_recording.assert_called_once()
|
||||
|
||||
# Verify background jobs were enqueued
|
||||
assert (
|
||||
mock_enqueue_job.call_count >= 3
|
||||
) # At least 3 jobs should be enqueued
|
||||
|
|
@ -1,667 +0,0 @@
|
|||
"""
|
||||
Tests for workflow API routes.
|
||||
|
||||
This module tests the create, update, get, and validate workflow endpoints.
|
||||
The fixtures for database setup, test client, and utilities are in conftest.py.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from fastapi import status
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_workflow_definition():
|
||||
"""Sample workflow definition for testing."""
|
||||
return {
|
||||
"nodes": [
|
||||
{
|
||||
"id": "6581",
|
||||
"type": "startCall",
|
||||
"position": {"x": 427, "y": 23},
|
||||
"data": {
|
||||
"prompt": "Hello, I am Abhishek from Dograh. ",
|
||||
"is_static": True,
|
||||
"name": "Start Call",
|
||||
"is_start": True,
|
||||
"invalid": False,
|
||||
"validationMessage": None,
|
||||
"allow_interrupt": False,
|
||||
},
|
||||
"measured": {"width": 300, "height": 100},
|
||||
"selected": True,
|
||||
"dragging": False,
|
||||
},
|
||||
{
|
||||
"id": "915",
|
||||
"type": "agentNode",
|
||||
"position": {"x": 305, "y": 340},
|
||||
"data": {
|
||||
"prompt": "You are a voice agent whose mode of speaking is voice. Ask the user whether they want to talk to a sales guy or a customer service agent.",
|
||||
"name": "Agent",
|
||||
"invalid": False,
|
||||
"validationMessage": None,
|
||||
"allow_interrupt": False,
|
||||
},
|
||||
"measured": {"width": 300, "height": 100},
|
||||
"selected": False,
|
||||
"dragging": False,
|
||||
},
|
||||
{
|
||||
"id": "7598",
|
||||
"type": "agentNode",
|
||||
"position": {"x": 90, "y": 650},
|
||||
"data": {
|
||||
"prompt": "You are a customer service agent whose mode of communication with the user is voice. Tell them that someone from our team will reach out to them soon",
|
||||
"name": "Agent",
|
||||
"invalid": False,
|
||||
"validationMessage": None,
|
||||
"allow_interrupt": True,
|
||||
},
|
||||
"measured": {"width": 300, "height": 100},
|
||||
"selected": False,
|
||||
"dragging": False,
|
||||
},
|
||||
{
|
||||
"id": "6919",
|
||||
"type": "agentNode",
|
||||
"position": {"x": 520, "y": 650},
|
||||
"data": {
|
||||
"prompt": "You are a sales representative whose mode of communication with the user is voice. Tell the user that someone from our team will reach out to you soon",
|
||||
"name": "Agent",
|
||||
"invalid": False,
|
||||
"validationMessage": None,
|
||||
"allow_interrupt": True,
|
||||
},
|
||||
"measured": {"width": 300, "height": 100},
|
||||
"selected": False,
|
||||
"dragging": False,
|
||||
},
|
||||
{
|
||||
"id": "1802",
|
||||
"type": "endCall",
|
||||
"position": {"x": 305, "y": 960},
|
||||
"data": {
|
||||
"prompt": "Thank you!",
|
||||
"invalid": False,
|
||||
"validationMessage": None,
|
||||
"is_static": True,
|
||||
"name": "End Call",
|
||||
"is_end": True,
|
||||
"allow_interrupt": False,
|
||||
},
|
||||
"measured": {"width": 300, "height": 100},
|
||||
"selected": False,
|
||||
"dragging": False,
|
||||
},
|
||||
],
|
||||
"edges": [
|
||||
{
|
||||
"animated": True,
|
||||
"type": "custom",
|
||||
"source": "915",
|
||||
"target": "7598",
|
||||
"id": "xy-edge__915-7598",
|
||||
"selected": False,
|
||||
"data": {
|
||||
"condition": "The customer wants to talk to a customer service agent",
|
||||
"label": "customer service agent",
|
||||
"invalid": False,
|
||||
"validationMessage": None,
|
||||
},
|
||||
},
|
||||
{
|
||||
"animated": True,
|
||||
"type": "custom",
|
||||
"source": "915",
|
||||
"target": "6919",
|
||||
"id": "xy-edge__915-6919",
|
||||
"selected": False,
|
||||
"data": {
|
||||
"condition": "customer wants to talk to a sales representative",
|
||||
"label": "sales representative",
|
||||
"invalid": False,
|
||||
"validationMessage": None,
|
||||
},
|
||||
},
|
||||
{
|
||||
"animated": True,
|
||||
"type": "custom",
|
||||
"source": "6581",
|
||||
"target": "915",
|
||||
"id": "xy-edge__6581-915",
|
||||
"selected": False,
|
||||
"data": {
|
||||
"condition": "Always take this route",
|
||||
"label": "Always take this route",
|
||||
"invalid": False,
|
||||
"validationMessage": None,
|
||||
},
|
||||
},
|
||||
{
|
||||
"animated": True,
|
||||
"type": "custom",
|
||||
"source": "7598",
|
||||
"target": "1802",
|
||||
"id": "xy-edge__7598-1802",
|
||||
"selected": False,
|
||||
"data": {
|
||||
"condition": "end call",
|
||||
"label": "end call",
|
||||
"invalid": False,
|
||||
"validationMessage": None,
|
||||
},
|
||||
},
|
||||
{
|
||||
"animated": True,
|
||||
"type": "custom",
|
||||
"source": "6919",
|
||||
"target": "1802",
|
||||
"id": "xy-edge__6919-1802",
|
||||
"selected": False,
|
||||
"data": {
|
||||
"condition": "end call",
|
||||
"label": "end call",
|
||||
"invalid": False,
|
||||
"validationMessage": None,
|
||||
},
|
||||
},
|
||||
],
|
||||
"viewport": {"x": 0, "y": 0, "zoom": 1},
|
||||
}
|
||||
|
||||
|
||||
class TestCreateWorkflow:
|
||||
"""Test cases for creating workflows."""
|
||||
|
||||
async def test_create_workflow_success(
|
||||
self, test_client_factory, db_session, sample_workflow_definition
|
||||
):
|
||||
"""Test successful workflow creation."""
|
||||
# Create a test user for this test
|
||||
test_user = await db_session.get_or_create_user_by_provider_id(
|
||||
"test_user_create_success"
|
||||
)
|
||||
|
||||
request_data = {
|
||||
"name": "Test Workflow",
|
||||
"workflow_definition": sample_workflow_definition,
|
||||
}
|
||||
|
||||
async with test_client_factory(test_user) as client:
|
||||
response = await client.post("/api/v1/workflow/create", json=request_data)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
|
||||
assert "id" in data
|
||||
assert data["name"] == "Test Workflow"
|
||||
assert data["workflow_definition"] == sample_workflow_definition
|
||||
assert "created_at" in data
|
||||
assert "current_definition_id" in data
|
||||
|
||||
async def test_create_workflow_invalid_definition(
|
||||
self, test_client_factory, db_session
|
||||
):
|
||||
"""Test workflow creation with invalid definition."""
|
||||
# Create a test user for this test
|
||||
test_user = await db_session.get_or_create_user_by_provider_id(
|
||||
"test_user_invalid_def"
|
||||
)
|
||||
|
||||
request_data = {
|
||||
"name": "Invalid Workflow",
|
||||
"workflow_definition": {"invalid": "structure"},
|
||||
}
|
||||
|
||||
async with test_client_factory(test_user) as client:
|
||||
response = await client.post("/api/v1/workflow/create", json=request_data)
|
||||
|
||||
# The API should still create the workflow even with invalid definition
|
||||
# Validation happens in the validate endpoint
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_workflow_missing_name(
|
||||
self, test_client_factory, db_session, sample_workflow_definition
|
||||
):
|
||||
"""Test workflow creation without name."""
|
||||
# Create a test user for this test
|
||||
test_user = await db_session.get_or_create_user_by_provider_id(
|
||||
"test_user_missing_name"
|
||||
)
|
||||
|
||||
request_data = {"workflow_definition": sample_workflow_definition}
|
||||
|
||||
async with test_client_factory(test_user) as client:
|
||||
response = await client.post("/api/v1/workflow/create", json=request_data)
|
||||
|
||||
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_workflow_missing_definition(
|
||||
self, test_client_factory, db_session
|
||||
):
|
||||
"""Test workflow creation without workflow definition."""
|
||||
# Create a test user for this test
|
||||
test_user = await db_session.get_or_create_user_by_provider_id(
|
||||
"test_user_missing_definition"
|
||||
)
|
||||
|
||||
request_data = {"name": "Test Workflow"}
|
||||
|
||||
async with test_client_factory(test_user) as client:
|
||||
response = await client.post("/api/v1/workflow/create", json=request_data)
|
||||
|
||||
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
|
||||
|
||||
|
||||
class TestGetWorkflows:
|
||||
"""Test cases for fetching workflows."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_all_workflows_empty(self, test_client_factory, db_session):
|
||||
"""Test getting all workflows when none exist."""
|
||||
# Create a test user within the test function
|
||||
test_user = await db_session.get_or_create_user_by_provider_id(
|
||||
"test_user_empty_workflows"
|
||||
)
|
||||
|
||||
# Create a test client for this specific user
|
||||
async with test_client_factory(test_user) as client:
|
||||
response = await client.get("/api/v1/workflow/fetch")
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
assert isinstance(data, list)
|
||||
assert len(data) == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_all_workflows_with_data(
|
||||
self, test_client_factory, db_session, sample_workflow_definition
|
||||
):
|
||||
"""Test getting all workflows when some exist."""
|
||||
# Create a test user within the test function
|
||||
test_user = await db_session.get_or_create_user_by_provider_id(
|
||||
"test_user_with_workflows"
|
||||
)
|
||||
|
||||
# Create a test client for this specific user
|
||||
async with test_client_factory(test_user) as client:
|
||||
# Create a workflow first
|
||||
create_response = await client.post(
|
||||
"/api/v1/workflow/create",
|
||||
json={
|
||||
"name": "Test Workflow 1",
|
||||
"workflow_definition": sample_workflow_definition,
|
||||
},
|
||||
)
|
||||
assert create_response.status_code == status.HTTP_200_OK
|
||||
|
||||
# Create another workflow
|
||||
create_response2 = await client.post(
|
||||
"/api/v1/workflow/create",
|
||||
json={
|
||||
"name": "Test Workflow 2",
|
||||
"workflow_definition": sample_workflow_definition,
|
||||
},
|
||||
)
|
||||
assert create_response2.status_code == status.HTTP_200_OK
|
||||
|
||||
# Get all workflows
|
||||
response = await client.get("/api/v1/workflow/fetch")
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
assert isinstance(data, list)
|
||||
assert len(data) == 2
|
||||
|
||||
# Check that both workflows are returned
|
||||
workflow_names = [w["name"] for w in data]
|
||||
assert "Test Workflow 1" in workflow_names
|
||||
assert "Test Workflow 2" in workflow_names
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_specific_workflow(
|
||||
self, test_client_factory, db_session, sample_workflow_definition
|
||||
):
|
||||
"""Test getting a specific workflow by ID."""
|
||||
# Create a test user for this test
|
||||
test_user = await db_session.get_or_create_user_by_provider_id(
|
||||
"test_user_specific_workflow"
|
||||
)
|
||||
|
||||
async with test_client_factory(test_user) as client:
|
||||
# Create a workflow first
|
||||
create_response = await client.post(
|
||||
"/api/v1/workflow/create",
|
||||
json={
|
||||
"name": "Specific Workflow",
|
||||
"workflow_definition": sample_workflow_definition,
|
||||
},
|
||||
)
|
||||
assert create_response.status_code == status.HTTP_200_OK
|
||||
created_workflow = create_response.json()
|
||||
workflow_id = created_workflow["id"]
|
||||
|
||||
# Get the specific workflow
|
||||
response = await client.get(
|
||||
f"/api/v1/workflow/fetch?workflow_id={workflow_id}"
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
|
||||
assert data["id"] == workflow_id
|
||||
assert data["name"] == "Specific Workflow"
|
||||
assert data["workflow_definition"] == sample_workflow_definition
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_nonexistent_workflow(self, test_client_factory, db_session):
|
||||
"""Test getting a workflow that doesn't exist."""
|
||||
# Create a test user for this test
|
||||
test_user = await db_session.get_or_create_user_by_provider_id(
|
||||
"test_user_nonexistent"
|
||||
)
|
||||
|
||||
async with test_client_factory(test_user) as client:
|
||||
response = await client.get("/api/v1/workflow/fetch?workflow_id=99999")
|
||||
|
||||
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||
assert "not found" in response.json()["detail"].lower()
|
||||
|
||||
|
||||
class TestUpdateWorkflow:
|
||||
"""Test cases for updating workflows."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_workflow_name_only(
|
||||
self, test_client_factory, db_session, sample_workflow_definition
|
||||
):
|
||||
"""Test updating only the workflow name."""
|
||||
# Create a test user for this test
|
||||
test_user = await db_session.get_or_create_user_by_provider_id(
|
||||
"test_user_update_name"
|
||||
)
|
||||
|
||||
async with test_client_factory(test_user) as client:
|
||||
# Create a workflow first
|
||||
create_response = await client.post(
|
||||
"/api/v1/workflow/create",
|
||||
json={
|
||||
"name": "Original Name",
|
||||
"workflow_definition": sample_workflow_definition,
|
||||
},
|
||||
)
|
||||
assert create_response.status_code == status.HTTP_200_OK
|
||||
workflow_id = create_response.json()["id"]
|
||||
|
||||
# Update the workflow name
|
||||
update_data = {"name": "Updated Name"}
|
||||
response = await client.put(
|
||||
f"/api/v1/workflow/{workflow_id}", json=update_data
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
|
||||
assert data["id"] == workflow_id
|
||||
assert data["name"] == "Updated Name"
|
||||
assert (
|
||||
data["workflow_definition"] == sample_workflow_definition
|
||||
) # Should remain unchanged
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_workflow_name_and_definition(
|
||||
self, test_client_factory, db_session, sample_workflow_definition
|
||||
):
|
||||
"""Test updating both workflow name and definition."""
|
||||
# Create a test user for this test
|
||||
test_user = await db_session.get_or_create_user_by_provider_id(
|
||||
"test_user_update_both"
|
||||
)
|
||||
|
||||
async with test_client_factory(test_user) as client:
|
||||
# Create a workflow first
|
||||
create_response = await client.post(
|
||||
"/api/v1/workflow/create",
|
||||
json={
|
||||
"name": "Original Name",
|
||||
"workflow_definition": sample_workflow_definition,
|
||||
},
|
||||
)
|
||||
assert create_response.status_code == status.HTTP_200_OK
|
||||
workflow_id = create_response.json()["id"]
|
||||
|
||||
# Create new workflow definition
|
||||
new_definition = {
|
||||
"nodes": [
|
||||
{
|
||||
"id": "start",
|
||||
"type": "start",
|
||||
"position": {"x": 50, "y": 50},
|
||||
"data": {"label": "New Start"},
|
||||
}
|
||||
],
|
||||
"edges": [],
|
||||
}
|
||||
|
||||
# Update the workflow
|
||||
update_data = {
|
||||
"name": "Updated Name",
|
||||
"workflow_definition": new_definition,
|
||||
}
|
||||
response = await client.put(
|
||||
f"/api/v1/workflow/{workflow_id}", json=update_data
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
|
||||
assert data["id"] == workflow_id
|
||||
assert data["name"] == "Updated Name"
|
||||
assert data["workflow_definition"] == new_definition
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_nonexistent_workflow(self, test_client_factory, db_session):
|
||||
"""Test updating a workflow that doesn't exist."""
|
||||
# Create a test user for this test
|
||||
test_user = await db_session.get_or_create_user_by_provider_id(
|
||||
"test_user_update_nonexistent"
|
||||
)
|
||||
|
||||
update_data = {"name": "Updated Name"}
|
||||
async with test_client_factory(test_user) as client:
|
||||
response = await client.put("/api/v1/workflow/99999", json=update_data)
|
||||
|
||||
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||
assert "not found" in response.json()["detail"].lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_workflow_missing_name(
|
||||
self, test_client_factory, db_session, sample_workflow_definition
|
||||
):
|
||||
"""Test updating a workflow without providing a name."""
|
||||
# Create a test user for this test
|
||||
test_user = await db_session.get_or_create_user_by_provider_id(
|
||||
"test_user_update_missing_name"
|
||||
)
|
||||
|
||||
async with test_client_factory(test_user) as client:
|
||||
# Create a workflow first
|
||||
create_response = await client.post(
|
||||
"/api/v1/workflow/create",
|
||||
json={
|
||||
"name": "Original Name",
|
||||
"workflow_definition": sample_workflow_definition,
|
||||
},
|
||||
)
|
||||
assert create_response.status_code == status.HTTP_200_OK
|
||||
workflow_id = create_response.json()["id"]
|
||||
|
||||
# Try to update without providing name
|
||||
update_data = {"workflow_definition": sample_workflow_definition}
|
||||
response = await client.put(
|
||||
f"/api/v1/workflow/{workflow_id}", json=update_data
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
|
||||
|
||||
|
||||
class TestWorkflowValidation:
|
||||
"""Test cases for workflow validation endpoint."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_validate_workflow_success(
|
||||
self, test_client_factory, db_session, sample_workflow_definition
|
||||
):
|
||||
"""Test successful workflow validation."""
|
||||
# Create a test user for this test
|
||||
test_user = await db_session.get_or_create_user_by_provider_id(
|
||||
"test_user_validate_success"
|
||||
)
|
||||
|
||||
async with test_client_factory(test_user) as client:
|
||||
# Create a workflow first
|
||||
create_response = await client.post(
|
||||
"/api/v1/workflow/create",
|
||||
json={
|
||||
"name": "Valid Workflow",
|
||||
"workflow_definition": sample_workflow_definition,
|
||||
},
|
||||
)
|
||||
assert create_response.status_code == status.HTTP_200_OK
|
||||
workflow_id = create_response.json()["id"]
|
||||
|
||||
# Validate the workflow
|
||||
response = await client.post(f"/api/v1/workflow/{workflow_id}/validate")
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
|
||||
assert data["is_valid"] is True
|
||||
assert data["errors"] == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_validate_nonexistent_workflow(self, test_client_factory, db_session):
|
||||
"""Test validating a workflow that doesn't exist."""
|
||||
# Create a test user for this test
|
||||
test_user = await db_session.get_or_create_user_by_provider_id(
|
||||
"test_user_validate_nonexistent"
|
||||
)
|
||||
|
||||
async with test_client_factory(test_user) as client:
|
||||
response = await client.post("/api/v1/workflow/99999/validate")
|
||||
|
||||
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||
assert "not found" in response.json()["detail"].lower()
|
||||
|
||||
|
||||
class TestWorkflowIntegration:
|
||||
"""Integration tests for workflow operations."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_full_workflow_lifecycle(
|
||||
self, test_client_factory, db_session, sample_workflow_definition
|
||||
):
|
||||
"""Test the complete lifecycle of a workflow: create, get, update, validate."""
|
||||
# Create a test user for this test
|
||||
test_user = await db_session.get_or_create_user_by_provider_id(
|
||||
"test_user_lifecycle"
|
||||
)
|
||||
|
||||
async with test_client_factory(test_user) as client:
|
||||
# 1. Create workflow
|
||||
create_response = await client.post(
|
||||
"/api/v1/workflow/create",
|
||||
json={
|
||||
"name": "Lifecycle Test Workflow",
|
||||
"workflow_definition": sample_workflow_definition,
|
||||
},
|
||||
)
|
||||
assert create_response.status_code == status.HTTP_200_OK
|
||||
workflow_id = create_response.json()["id"]
|
||||
|
||||
# 2. Get the created workflow
|
||||
get_response = await client.get(
|
||||
f"/api/v1/workflow/fetch?workflow_id={workflow_id}"
|
||||
)
|
||||
assert get_response.status_code == status.HTTP_200_OK
|
||||
workflow_data = get_response.json()
|
||||
assert workflow_data["name"] == "Lifecycle Test Workflow"
|
||||
|
||||
# 3. Add a new node in the workflow definition
|
||||
new_node = {
|
||||
"id": "6919_new",
|
||||
"type": "agentNode",
|
||||
"position": {"x": 520, "y": 650},
|
||||
"data": {
|
||||
"prompt": "Something new",
|
||||
"name": "Agent",
|
||||
"invalid": False,
|
||||
"validationMessage": None,
|
||||
"allow_interrupt": True,
|
||||
},
|
||||
"measured": {"width": 300, "height": 100},
|
||||
"selected": False,
|
||||
"dragging": False,
|
||||
}
|
||||
new_edges = [
|
||||
{
|
||||
"source": "6919",
|
||||
"target": "6919_new",
|
||||
"id": "xy-edge__6919-6919_new",
|
||||
"data": {
|
||||
"condition": "Always take this route",
|
||||
"label": "Always take this route",
|
||||
"invalid": False,
|
||||
"validationMessage": None,
|
||||
},
|
||||
},
|
||||
{
|
||||
"source": "6919_new",
|
||||
"target": "1802",
|
||||
"id": "xy-edge__6919_new-1802",
|
||||
"data": {
|
||||
"condition": "Always take this route",
|
||||
"label": "Always take this route",
|
||||
"invalid": False,
|
||||
"validationMessage": None,
|
||||
},
|
||||
},
|
||||
]
|
||||
new_definition = {
|
||||
"nodes": [
|
||||
*sample_workflow_definition["nodes"],
|
||||
new_node,
|
||||
],
|
||||
"edges": [
|
||||
*sample_workflow_definition["edges"],
|
||||
*new_edges,
|
||||
],
|
||||
}
|
||||
|
||||
update_response = await client.put(
|
||||
f"/api/v1/workflow/{workflow_id}",
|
||||
json={
|
||||
"name": "Updated Lifecycle Workflow",
|
||||
"workflow_definition": new_definition,
|
||||
},
|
||||
)
|
||||
assert update_response.status_code == status.HTTP_200_OK
|
||||
assert update_response.json()["name"] == "Updated Lifecycle Workflow"
|
||||
|
||||
# 4. Validate the updated workflow
|
||||
validate_response = await client.post(
|
||||
f"/api/v1/workflow/{workflow_id}/validate"
|
||||
)
|
||||
assert validate_response.status_code == status.HTTP_200_OK
|
||||
assert validate_response.json()["is_valid"] is True
|
||||
|
||||
# 5. Verify the update by getting the workflow again
|
||||
final_get_response = await client.get(
|
||||
f"/api/v1/workflow/fetch?workflow_id={workflow_id}"
|
||||
)
|
||||
assert final_get_response.status_code == status.HTTP_200_OK
|
||||
final_data = final_get_response.json()
|
||||
assert final_data["name"] == "Updated Lifecycle Workflow"
|
||||
assert final_data["workflow_definition"] == new_definition
|
||||
Loading…
Add table
Add a link
Reference in a new issue