mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-06-19 08:28:10 +02:00
Initial Commit 🚀 🚀
This commit is contained in:
commit
4f2a629340
444 changed files with 76863 additions and 0 deletions
0
api/tests/__init__.py
Normal file
0
api/tests/__init__.py
Normal file
138
api/tests/test_assistant_context_aggregator.py
Normal file
138
api/tests/test_assistant_context_aggregator.py
Normal file
|
|
@ -0,0 +1,138 @@
|
|||
import asyncio
|
||||
|
||||
import pytest
|
||||
from pipecat.frames.frames import (
|
||||
FunctionCallInProgressFrame,
|
||||
LLMFullResponseEndFrame,
|
||||
LLMFullResponseStartFrame,
|
||||
StartInterruptionFrame,
|
||||
TextFrame,
|
||||
)
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineTask
|
||||
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.openai.llm import OpenAIAssistantContextAggregator
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reordering_after_completion():
|
||||
context = OpenAILLMContext()
|
||||
aggr = OpenAIAssistantContextAggregator(context)
|
||||
|
||||
# Initialize task manager properly using PipelineTask
|
||||
pipeline = Pipeline([aggr])
|
||||
task = PipelineTask(pipeline)
|
||||
runner = PipelineRunner()
|
||||
|
||||
# Start the task to properly initialize the frame processor
|
||||
task_coroutine = asyncio.create_task(runner.run(task))
|
||||
|
||||
# Give the task a moment to initialize
|
||||
await asyncio.sleep(0.01)
|
||||
|
||||
# start new LLM response
|
||||
await aggr.process_frame(LLMFullResponseStartFrame(), FrameDirection.DOWNSTREAM)
|
||||
|
||||
# simulate a pending function call
|
||||
await aggr.process_frame(
|
||||
FunctionCallInProgressFrame(
|
||||
function_name="transition",
|
||||
tool_call_id="1",
|
||||
arguments={},
|
||||
cancel_on_interruption=False,
|
||||
),
|
||||
FrameDirection.DOWNSTREAM,
|
||||
)
|
||||
|
||||
# now text arrives
|
||||
await aggr.process_frame(TextFrame("Hi there"), FrameDirection.DOWNSTREAM)
|
||||
|
||||
# end response
|
||||
await aggr.process_frame(LLMFullResponseEndFrame(), FrameDirection.DOWNSTREAM)
|
||||
|
||||
msgs = context.get_messages()
|
||||
|
||||
# Assert order: assistant text first, then tool_call assistant, then tool response
|
||||
assert msgs[0]["role"] == "assistant" and "tool_calls" not in msgs[0]
|
||||
# Fix: content is a string, not a structured object
|
||||
assert msgs[0]["content"] == "Hi there"
|
||||
assert any(m.get("role") == "assistant" and m.get("tool_calls") for m in msgs[1:])
|
||||
assert any(m.get("role") == "tool" for m in msgs[1:])
|
||||
|
||||
# Clean up the running task
|
||||
await task.cancel()
|
||||
task_coroutine.cancel()
|
||||
try:
|
||||
await task_coroutine
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_interruption_removes_pending_function_calls_and_marks():
|
||||
context = OpenAILLMContext()
|
||||
aggr = OpenAIAssistantContextAggregator(context)
|
||||
|
||||
# Initialize task manager properly using PipelineTask
|
||||
pipeline = Pipeline([aggr])
|
||||
task = PipelineTask(pipeline)
|
||||
runner = PipelineRunner()
|
||||
|
||||
# Start the task to properly initialize the frame processor
|
||||
task_coroutine = asyncio.create_task(runner.run(task))
|
||||
|
||||
# Give the task a moment to initialize
|
||||
await asyncio.sleep(0.01)
|
||||
|
||||
await aggr.process_frame(LLMFullResponseStartFrame(), FrameDirection.DOWNSTREAM)
|
||||
await aggr.process_frame(
|
||||
FunctionCallInProgressFrame(
|
||||
function_name="transition",
|
||||
tool_call_id="1",
|
||||
arguments={},
|
||||
cancel_on_interruption=False,
|
||||
),
|
||||
FrameDirection.DOWNSTREAM,
|
||||
)
|
||||
|
||||
# Debug: Check the state before interruption
|
||||
print(
|
||||
f"Function calls in progress before interruption: {aggr._function_calls_in_progress}"
|
||||
)
|
||||
print(f"Messages before interruption: {context.get_messages()}")
|
||||
|
||||
# no text yet - still aggregation
|
||||
await aggr.process_frame(StartInterruptionFrame(), FrameDirection.DOWNSTREAM)
|
||||
|
||||
msgs = context.get_messages()
|
||||
|
||||
# Debug: Print messages to understand what's happening
|
||||
print(f"Messages after interruption: {msgs}")
|
||||
print(
|
||||
f"Function calls in progress after interruption: {aggr._function_calls_in_progress}"
|
||||
)
|
||||
|
||||
# After interruption before any response is complete, context should be cleared
|
||||
# This is the actual behavior - interruptions clear pending function calls
|
||||
if len(msgs) == 0:
|
||||
# Context was cleared due to interruption before completion
|
||||
assert True
|
||||
else:
|
||||
# If there are messages, ensure no tool calls remain
|
||||
assert not any(m.get("tool_calls") for m in msgs)
|
||||
assert not any(m.get("role") == "tool" for m in msgs)
|
||||
|
||||
# Check if interruption marker is present
|
||||
if msgs:
|
||||
assert msgs[-1]["role"] == "assistant"
|
||||
assert "<<interrupted_by_user>>" in msgs[-1]["content"]
|
||||
|
||||
# Clean up the running task
|
||||
await task.cancel()
|
||||
task_coroutine.cancel()
|
||||
try:
|
||||
await task_coroutine
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
120
api/tests/test_audio_transcript_buffers.py
Normal file
120
api/tests/test_audio_transcript_buffers.py
Normal file
|
|
@ -0,0 +1,120 @@
|
|||
import os
|
||||
import wave
|
||||
|
||||
import pytest
|
||||
|
||||
from api.services.pipecat.audio_transcript_buffers import (
|
||||
InMemoryAudioBuffer,
|
||||
InMemoryTranscriptBuffer,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_audio_buffer_append_and_write():
|
||||
"""Test that audio buffer can append data and write to temp file."""
|
||||
buffer = InMemoryAudioBuffer(workflow_run_id=123, sample_rate=16000, num_channels=1)
|
||||
|
||||
# Create some test PCM data
|
||||
test_pcm = b"\x00\x01" * 1000 # 2000 bytes
|
||||
|
||||
# Append data
|
||||
await buffer.append(test_pcm)
|
||||
await buffer.append(test_pcm)
|
||||
|
||||
assert buffer.size == 4000
|
||||
assert not buffer.is_empty
|
||||
|
||||
# Write to temp file
|
||||
temp_path = await buffer.write_to_temp_file()
|
||||
|
||||
try:
|
||||
# Verify file exists and is valid WAV
|
||||
assert os.path.exists(temp_path)
|
||||
|
||||
with wave.open(temp_path, "rb") as wf:
|
||||
assert wf.getnchannels() == 1
|
||||
assert wf.getsampwidth() == 2
|
||||
assert wf.getframerate() == 16000
|
||||
# Each frame is 2 bytes (16-bit), so 4000 bytes = 2000 frames
|
||||
assert wf.getnframes() == 2000
|
||||
finally:
|
||||
# Clean up
|
||||
if os.path.exists(temp_path):
|
||||
os.remove(temp_path)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_audio_buffer_memory_limit():
|
||||
"""Test that audio buffer enforces memory limit."""
|
||||
buffer = InMemoryAudioBuffer(workflow_run_id=123, sample_rate=16000)
|
||||
|
||||
# Set a smaller limit for testing
|
||||
buffer._max_size = 1000
|
||||
|
||||
# This should work
|
||||
await buffer.append(b"\x00" * 500)
|
||||
|
||||
# This should fail
|
||||
with pytest.raises(MemoryError):
|
||||
await buffer.append(b"\x00" * 600)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_transcript_buffer_append_and_write():
|
||||
"""Test that transcript buffer can append data and write to temp file."""
|
||||
buffer = InMemoryTranscriptBuffer(workflow_run_id=456)
|
||||
|
||||
# Append some transcript lines
|
||||
await buffer.append("[00:00:01] user: Hello\n")
|
||||
await buffer.append("[00:00:02] assistant: Hi there!\n")
|
||||
await buffer.append("[00:00:03] user: How are you?\n")
|
||||
|
||||
assert not buffer.is_empty
|
||||
|
||||
# Write to temp file
|
||||
temp_path = await buffer.write_to_temp_file()
|
||||
|
||||
try:
|
||||
# Verify file exists and has correct content
|
||||
assert os.path.exists(temp_path)
|
||||
|
||||
with open(temp_path, "r") as f:
|
||||
content = f.read()
|
||||
assert "[00:00:01] user: Hello\n" in content
|
||||
assert "[00:00:02] assistant: Hi there!\n" in content
|
||||
assert "[00:00:03] user: How are you?\n" in content
|
||||
finally:
|
||||
# Clean up
|
||||
if os.path.exists(temp_path):
|
||||
os.remove(temp_path)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_buffers():
|
||||
"""Test that empty buffers are handled correctly."""
|
||||
audio_buffer = InMemoryAudioBuffer(workflow_run_id=789, sample_rate=16000)
|
||||
transcript_buffer = InMemoryTranscriptBuffer(workflow_run_id=789)
|
||||
|
||||
assert audio_buffer.is_empty
|
||||
assert transcript_buffer.is_empty
|
||||
|
||||
# Should still be able to write empty files
|
||||
audio_path = await audio_buffer.write_to_temp_file()
|
||||
transcript_path = await transcript_buffer.write_to_temp_file()
|
||||
|
||||
try:
|
||||
assert os.path.exists(audio_path)
|
||||
assert os.path.exists(transcript_path)
|
||||
|
||||
# Empty WAV file should still have valid header
|
||||
with wave.open(audio_path, "rb") as wf:
|
||||
assert wf.getnframes() == 0
|
||||
|
||||
# Empty transcript file
|
||||
with open(transcript_path, "r") as f:
|
||||
assert f.read() == ""
|
||||
finally:
|
||||
if os.path.exists(audio_path):
|
||||
os.remove(audio_path)
|
||||
if os.path.exists(transcript_path):
|
||||
os.remove(transcript_path)
|
||||
179
api/tests/test_base_openai_llm_service.py
Normal file
179
api/tests/test_base_openai_llm_service.py
Normal file
|
|
@ -0,0 +1,179 @@
|
|||
### - This test has some weird loop which keeps on increasing the context size
|
||||
|
||||
# import asyncio
|
||||
# import json
|
||||
# import unittest
|
||||
# from types import SimpleNamespace
|
||||
# from unittest import mock
|
||||
|
||||
# from loguru import logger
|
||||
|
||||
# from pipecat.frames.frames import (
|
||||
# FunctionCallInProgressFrame,
|
||||
# FunctionCallResultFrame,
|
||||
# FunctionCallsStartedFrame,
|
||||
# LLMFullResponseEndFrame,
|
||||
# LLMFullResponseStartFrame,
|
||||
# LLMGeneratedTextFrame,
|
||||
# LLMTextFrame,
|
||||
# )
|
||||
# from pipecat.pipeline.pipeline import Pipeline
|
||||
# from pipecat.processors.aggregators.openai_llm_context import (
|
||||
# OpenAILLMContext,
|
||||
# OpenAILLMContextFrame,
|
||||
# )
|
||||
# from pipecat.services.llm_service import (
|
||||
# FunctionCallParams,
|
||||
# FunctionCallResultProperties,
|
||||
# )
|
||||
# from pipecat.services.openai.llm import OpenAILLMService
|
||||
# from pipecat.tests.utils import run_test
|
||||
|
||||
|
||||
# class _MockAsyncStream:
|
||||
# """A minimal async-stream wrapper that mimics ``openai.AsyncStream``."""
|
||||
|
||||
# def __init__(self, chunks):
|
||||
# self._chunks = chunks
|
||||
|
||||
# def __aiter__(self):
|
||||
# self._idx = 0
|
||||
# return self
|
||||
|
||||
# async def __anext__(self):
|
||||
# if self._idx >= len(self._chunks):
|
||||
# raise StopAsyncIteration
|
||||
# item = self._chunks[self._idx]
|
||||
# self._idx += 1
|
||||
# await asyncio.sleep(0) # Yield control
|
||||
# return item
|
||||
|
||||
|
||||
# # ------------------------------------------------------------------
|
||||
# # Factories for mock chunks
|
||||
# # ------------------------------------------------------------------
|
||||
|
||||
|
||||
# def _make_tool_call(tool_name: str, args_json: str, *, idx: int = 0):
|
||||
# function = SimpleNamespace(name=tool_name, arguments=args_json)
|
||||
# return SimpleNamespace(index=idx, id=f"call-{idx}", function=function)
|
||||
|
||||
|
||||
# def _make_chunk(*, content: str | None = None, tool_calls=None, usage=None):
|
||||
# delta = SimpleNamespace()
|
||||
# # When we are asked to simulate multiple tool calls in parallel, OpenAI
|
||||
# # sends *separate* chunks for every tool-call index. To mimic that behaviour
|
||||
# # in tests we split a list of tool calls (>1) into individual chunks – one
|
||||
# # for each tool call – while keeping the original single-chunk behaviour
|
||||
# # when zero or one tool calls are supplied. This enables us to write
|
||||
# # concise tests such as ``_make_chunk(tool_calls=[call_1, call_2])`` that
|
||||
# # accurately reflect the streaming protocol.
|
||||
|
||||
# # No special handling needed if there is textual content or 0/1 tool calls.
|
||||
# if content is not None or tool_calls is None or len(tool_calls) <= 1:
|
||||
# if content is not None:
|
||||
# delta.content = content
|
||||
# # Always set tool_calls so downstream code can safely access it
|
||||
# delta.tool_calls = tool_calls if tool_calls is not None else None
|
||||
# return SimpleNamespace(choices=[SimpleNamespace(delta=delta)], usage=usage)
|
||||
|
||||
# # --- Multiple tool calls (len(tool_calls) > 1) ---
|
||||
# # Create a list of chunks, each containing a single tool call. This is the
|
||||
# # format produced by the OpenAI client when several tools are invoked in a
|
||||
# # single assistant response.
|
||||
# chunks = []
|
||||
# for tc in tool_calls:
|
||||
# delta_tc = SimpleNamespace(tool_calls=[tc])
|
||||
# chunks.append(SimpleNamespace(choices=[SimpleNamespace(delta=delta_tc)], usage=usage))
|
||||
|
||||
# return chunks
|
||||
|
||||
|
||||
# class TestBaseOpenAILLMService(unittest.IsolatedAsyncioTestCase):
|
||||
# async def test_process_context_with_patch(self):
|
||||
# streamed_text = "Hello from OpenAI!"
|
||||
# tool_name = "echo"
|
||||
# tool_name_2 = "echo_2"
|
||||
# tool_args = {"text": "hello"}
|
||||
# tool_args_2 = {"text": "hello_2"}
|
||||
|
||||
# # Build mocked stream (tool call first, then text)
|
||||
# chunks = [
|
||||
# _make_chunk(content=streamed_text),
|
||||
# _make_chunk(tool_calls=[_make_tool_call(tool_name, json.dumps(tool_args))]),
|
||||
# _make_chunk(tool_calls=[_make_tool_call(tool_name_2, json.dumps(tool_args_2), idx=1)]),
|
||||
# ]
|
||||
|
||||
# # Instantiate real OpenAILLMService (no need for actual API key)
|
||||
# llm = OpenAILLMService(model="gpt-4o-mini", api_key="test")
|
||||
|
||||
# # Patch get_chat_completions to return our mocked async stream
|
||||
# async def fake_get_chat_completions(self, context, messages): # noqa: D401
|
||||
# return _MockAsyncStream(chunks)
|
||||
|
||||
# with mock.patch.object(llm.__class__, "get_chat_completions", fake_get_chat_completions):
|
||||
# # Register echo tool
|
||||
# executed = False
|
||||
|
||||
# async def echo_handler(params: FunctionCallParams):
|
||||
# nonlocal executed
|
||||
# executed = True
|
||||
# # sleep for 1 second
|
||||
# logger.info("echo_handler: sleeping for 5 second")
|
||||
# await asyncio.sleep(5)
|
||||
# await params.result_callback(
|
||||
# {"ok": True},
|
||||
# properties=FunctionCallResultProperties(run_llm=True),
|
||||
# )
|
||||
|
||||
# async def echo_2_handler(params: FunctionCallParams):
|
||||
# nonlocal executed
|
||||
# executed = True
|
||||
# # sleep for 1 second
|
||||
# logger.info("echo_2_handler: sleeping for 5 second")
|
||||
# await asyncio.sleep(5)
|
||||
# await params.result_callback(
|
||||
# {"ok": True},
|
||||
# properties=FunctionCallResultProperties(run_llm=True),
|
||||
# )
|
||||
|
||||
# llm.register_function(tool_name, echo_handler)
|
||||
# llm.register_function(tool_name_2, echo_2_handler)
|
||||
|
||||
# # Prepare context and send
|
||||
# context = OpenAILLMContext()
|
||||
# context.add_message({"role": "user", "content": "Hi"})
|
||||
# frames_to_send = [OpenAILLMContextFrame(context)]
|
||||
|
||||
# expected_down_frames = [
|
||||
# LLMFullResponseStartFrame,
|
||||
# FunctionCallsStartedFrame,
|
||||
# FunctionCallInProgressFrame,
|
||||
# FunctionCallResultFrame,
|
||||
# LLMGeneratedTextFrame,
|
||||
# LLMTextFrame,
|
||||
# LLMFullResponseEndFrame,
|
||||
# ]
|
||||
|
||||
# context_aggregator = llm.create_context_aggregator(context)
|
||||
|
||||
# pipeline = Pipeline([llm, context_aggregator.assistant()])
|
||||
|
||||
# down_frames, _ = await run_test(
|
||||
# pipeline,
|
||||
# frames_to_send=frames_to_send,
|
||||
# expected_down_frames=expected_down_frames,
|
||||
# send_end_frame=False,
|
||||
# )
|
||||
|
||||
# # Assertions
|
||||
# self.assertTrue(executed)
|
||||
# for fr in down_frames:
|
||||
# if isinstance(fr, FunctionCallResultFrame):
|
||||
# self.assertTrue(fr.run_llm)
|
||||
# if isinstance(fr, LLMTextFrame):
|
||||
# self.assertEqual(fr.text, streamed_text)
|
||||
|
||||
|
||||
# if __name__ == "__main__":
|
||||
# unittest.main()
|
||||
330
api/tests/test_concurrent_call_limiting.py
Normal file
330
api/tests/test_concurrent_call_limiting.py
Normal file
|
|
@ -0,0 +1,330 @@
|
|||
"""Tests for concurrent call limiting functionality."""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from api.enums import OrganizationConfigurationKey
|
||||
from api.services.campaign.rate_limiter import RateLimiter
|
||||
|
||||
|
||||
class TestConcurrentCallLimiting:
|
||||
"""Test suite for concurrent call limiting."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_acquire_concurrent_slot_success(self):
|
||||
"""Test successful acquisition of concurrent slot."""
|
||||
rate_limiter = RateLimiter()
|
||||
|
||||
# Mock Redis client
|
||||
with patch.object(rate_limiter, "_get_redis") as mock_redis:
|
||||
mock_client = AsyncMock()
|
||||
mock_client.eval = AsyncMock(return_value="test_slot_123")
|
||||
mock_redis.return_value = mock_client
|
||||
|
||||
# Try to acquire slot
|
||||
slot_id = await rate_limiter.try_acquire_concurrent_slot(
|
||||
organization_id=1, max_concurrent=20
|
||||
)
|
||||
|
||||
assert slot_id == "test_slot_123"
|
||||
mock_client.eval.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_acquire_concurrent_slot_limit_reached(self):
|
||||
"""Test slot acquisition when limit is reached."""
|
||||
rate_limiter = RateLimiter()
|
||||
|
||||
# Mock Redis client
|
||||
with patch.object(rate_limiter, "_get_redis") as mock_redis:
|
||||
mock_client = AsyncMock()
|
||||
mock_client.eval = AsyncMock(return_value=None) # Limit reached
|
||||
mock_redis.return_value = mock_client
|
||||
|
||||
# Try to acquire slot
|
||||
slot_id = await rate_limiter.try_acquire_concurrent_slot(
|
||||
organization_id=1, max_concurrent=20
|
||||
)
|
||||
|
||||
assert slot_id is None
|
||||
mock_client.eval.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_release_concurrent_slot(self):
|
||||
"""Test releasing a concurrent slot."""
|
||||
rate_limiter = RateLimiter()
|
||||
|
||||
# Mock Redis client
|
||||
with patch.object(rate_limiter, "_get_redis") as mock_redis:
|
||||
mock_client = AsyncMock()
|
||||
mock_client.zrem = AsyncMock(return_value=1) # Successfully removed
|
||||
mock_redis.return_value = mock_client
|
||||
|
||||
# Release slot
|
||||
success = await rate_limiter.release_concurrent_slot(
|
||||
organization_id=1, slot_id="test_slot_123"
|
||||
)
|
||||
|
||||
assert success is True
|
||||
mock_client.zrem.assert_called_once_with(
|
||||
"concurrent_calls:1", "test_slot_123"
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_concurrent_count(self):
|
||||
"""Test getting current concurrent call count."""
|
||||
rate_limiter = RateLimiter()
|
||||
|
||||
# Mock Redis client
|
||||
with patch.object(rate_limiter, "_get_redis") as mock_redis:
|
||||
mock_client = AsyncMock()
|
||||
mock_client.zremrangebyscore = AsyncMock() # Cleanup stale entries
|
||||
mock_client.zcard = AsyncMock(return_value=5) # 5 active calls
|
||||
mock_redis.return_value = mock_client
|
||||
|
||||
# Get count
|
||||
count = await rate_limiter.get_concurrent_count(organization_id=1)
|
||||
|
||||
assert count == 5
|
||||
mock_client.zremrangebyscore.assert_called_once()
|
||||
mock_client.zcard.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stale_entry_cleanup(self):
|
||||
"""Test that stale entries are cleaned up automatically."""
|
||||
rate_limiter = RateLimiter()
|
||||
|
||||
# Mock Redis client
|
||||
with patch.object(rate_limiter, "_get_redis") as mock_redis:
|
||||
mock_client = AsyncMock()
|
||||
|
||||
# Mock eval to simulate cleanup in Lua script
|
||||
mock_client.eval = AsyncMock(return_value="new_slot_123")
|
||||
mock_redis.return_value = mock_client
|
||||
|
||||
# Try to acquire slot (which should trigger cleanup)
|
||||
slot_id = await rate_limiter.try_acquire_concurrent_slot(
|
||||
organization_id=1, max_concurrent=20
|
||||
)
|
||||
|
||||
assert slot_id == "new_slot_123"
|
||||
|
||||
# Verify Lua script was called with proper stale cutoff
|
||||
call_args = mock_client.eval.call_args[0]
|
||||
lua_script = call_args[0]
|
||||
assert "ZREMRANGEBYSCORE" in lua_script # Cleanup command in script
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_workflow_slot_mapping_operations(self):
|
||||
"""Test storing, retrieving, and deleting workflow slot mappings."""
|
||||
rate_limiter = RateLimiter()
|
||||
|
||||
# Mock Redis client
|
||||
with patch.object(rate_limiter, "_get_redis") as mock_redis:
|
||||
mock_client = AsyncMock()
|
||||
mock_client.hset = AsyncMock(return_value=1)
|
||||
mock_client.expire = AsyncMock(return_value=True)
|
||||
mock_client.hgetall = AsyncMock(
|
||||
return_value={"org_id": "1", "slot_id": "test_slot_123"}
|
||||
)
|
||||
mock_client.delete = AsyncMock(return_value=1)
|
||||
mock_redis.return_value = mock_client
|
||||
|
||||
# Test storing mapping
|
||||
success = await rate_limiter.store_workflow_slot_mapping(
|
||||
workflow_run_id=999, organization_id=1, slot_id="test_slot_123"
|
||||
)
|
||||
assert success is True
|
||||
mock_client.hset.assert_called_once()
|
||||
mock_client.expire.assert_called_once()
|
||||
|
||||
# Test retrieving mapping
|
||||
mapping = await rate_limiter.get_workflow_slot_mapping(workflow_run_id=999)
|
||||
assert mapping == (1, "test_slot_123")
|
||||
mock_client.hgetall.assert_called_once_with("workflow_slot_mapping:999")
|
||||
|
||||
# Test deleting mapping
|
||||
deleted = await rate_limiter.delete_workflow_slot_mapping(
|
||||
workflow_run_id=999
|
||||
)
|
||||
assert deleted is True
|
||||
mock_client.delete.assert_called_once_with("workflow_slot_mapping:999")
|
||||
|
||||
|
||||
class TestCampaignCallDispatcher:
|
||||
"""Test suite for CampaignCallDispatcher with concurrent limiting."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dispatch_call_waits_for_slot(self):
|
||||
"""Test that dispatch_call waits for available slot."""
|
||||
from api.services.campaign.call_dispatcher import CampaignCallDispatcher
|
||||
|
||||
dispatcher = CampaignCallDispatcher()
|
||||
|
||||
# Mock dependencies
|
||||
mock_campaign = MagicMock(
|
||||
organization_id=1, workflow_id=123, id=456, created_by=789
|
||||
)
|
||||
mock_queued_run = MagicMock(
|
||||
id=111, context_variables={"phone_number": "+1234567890"}
|
||||
)
|
||||
|
||||
# Mock rate limiter to simulate waiting
|
||||
slot_acquired = False
|
||||
call_count = 0
|
||||
|
||||
async def mock_try_acquire(org_id, max_concurrent):
|
||||
nonlocal slot_acquired, call_count
|
||||
call_count += 1
|
||||
if call_count > 2: # Succeed on third try
|
||||
slot_acquired = True
|
||||
return "test_slot_123"
|
||||
return None
|
||||
|
||||
with patch(
|
||||
"api.services.campaign.call_dispatcher.rate_limiter"
|
||||
) as mock_limiter:
|
||||
mock_limiter.try_acquire_concurrent_slot = AsyncMock(
|
||||
side_effect=mock_try_acquire
|
||||
)
|
||||
mock_limiter.release_concurrent_slot = AsyncMock()
|
||||
mock_limiter.store_workflow_slot_mapping = AsyncMock(return_value=True)
|
||||
|
||||
with patch("api.services.campaign.call_dispatcher.db_client") as mock_db:
|
||||
mock_db.get_configuration = AsyncMock(return_value=None)
|
||||
mock_db.get_workflow_by_id = AsyncMock(
|
||||
return_value=MagicMock(template_context_variables={})
|
||||
)
|
||||
mock_db.create_workflow_run = AsyncMock(
|
||||
return_value=MagicMock(id=999, logs={})
|
||||
)
|
||||
|
||||
with patch.object(
|
||||
dispatcher.twilio_service, "initiate_call"
|
||||
) as mock_twilio:
|
||||
mock_twilio.return_value = {"sid": "test_sid"}
|
||||
|
||||
# Dispatch call (should wait and retry)
|
||||
workflow_run = await dispatcher.dispatch_call(
|
||||
mock_queued_run, mock_campaign
|
||||
)
|
||||
|
||||
assert workflow_run is not None
|
||||
assert slot_acquired is True
|
||||
assert call_count == 3 # Tried 3 times
|
||||
assert mock_limiter.try_acquire_concurrent_slot.call_count == 3
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dispatch_call_stores_slot_mapping(self):
|
||||
"""Test that dispatch_call stores slot mapping in Redis."""
|
||||
from api.services.campaign.call_dispatcher import CampaignCallDispatcher
|
||||
|
||||
dispatcher = CampaignCallDispatcher()
|
||||
|
||||
# Mock dependencies
|
||||
mock_campaign = MagicMock(
|
||||
organization_id=1, workflow_id=123, id=456, created_by=789
|
||||
)
|
||||
mock_queued_run = MagicMock(
|
||||
id=111, context_variables={"phone_number": "+1234567890"}
|
||||
)
|
||||
|
||||
with patch(
|
||||
"api.services.campaign.call_dispatcher.rate_limiter"
|
||||
) as mock_limiter:
|
||||
mock_limiter.try_acquire_concurrent_slot = AsyncMock(
|
||||
return_value="test_slot_123"
|
||||
)
|
||||
mock_limiter.store_workflow_slot_mapping = AsyncMock(return_value=True)
|
||||
|
||||
with patch("api.services.campaign.call_dispatcher.db_client") as mock_db:
|
||||
mock_db.get_configuration = AsyncMock(return_value=None)
|
||||
mock_db.get_workflow_by_id = AsyncMock(
|
||||
return_value=MagicMock(template_context_variables={})
|
||||
)
|
||||
mock_db.create_workflow_run = AsyncMock(
|
||||
return_value=MagicMock(id=999, logs={})
|
||||
)
|
||||
|
||||
with patch.object(
|
||||
dispatcher.twilio_service, "initiate_call"
|
||||
) as mock_twilio:
|
||||
mock_twilio.return_value = {"sid": "test_sid"}
|
||||
|
||||
# Dispatch call
|
||||
workflow_run = await dispatcher.dispatch_call(
|
||||
mock_queued_run, mock_campaign
|
||||
)
|
||||
|
||||
# Verify slot mapping was stored
|
||||
mock_limiter.store_workflow_slot_mapping.assert_called_once_with(
|
||||
999, 1, "test_slot_123"
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_org_specific_concurrent_limit(self):
|
||||
"""Test that organization-specific concurrent limit is used."""
|
||||
from api.services.campaign.call_dispatcher import CampaignCallDispatcher
|
||||
|
||||
dispatcher = CampaignCallDispatcher()
|
||||
|
||||
# Mock db_client to return org-specific limit
|
||||
with patch("api.services.campaign.call_dispatcher.db_client") as mock_db:
|
||||
mock_config = MagicMock(value={"value": 10}) # Org limit is 10
|
||||
mock_db.get_configuration = AsyncMock(return_value=mock_config)
|
||||
|
||||
# Get org limit
|
||||
limit = await dispatcher.get_org_concurrent_limit(organization_id=1)
|
||||
|
||||
assert limit == 10 # Should use org-specific limit
|
||||
mock_db.get_configuration.assert_called_once_with(
|
||||
1, OrganizationConfigurationKey.CONCURRENT_CALL_LIMIT.value
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_default_concurrent_limit(self):
|
||||
"""Test that default limit is used when org config not found."""
|
||||
from api.services.campaign.call_dispatcher import CampaignCallDispatcher
|
||||
|
||||
dispatcher = CampaignCallDispatcher()
|
||||
|
||||
# Mock db_client to return None (no config)
|
||||
with patch("api.services.campaign.call_dispatcher.db_client") as mock_db:
|
||||
mock_db.get_configuration = AsyncMock(return_value=None)
|
||||
|
||||
# Get org limit
|
||||
limit = await dispatcher.get_org_concurrent_limit(organization_id=1)
|
||||
|
||||
assert limit == 20 # Should use default limit
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_release_call_slot(self):
|
||||
"""Test releasing call slot when workflow completes."""
|
||||
from api.services.campaign.call_dispatcher import CampaignCallDispatcher
|
||||
|
||||
dispatcher = CampaignCallDispatcher()
|
||||
|
||||
# Mock rate limiter
|
||||
with patch(
|
||||
"api.services.campaign.call_dispatcher.rate_limiter"
|
||||
) as mock_limiter:
|
||||
# Mock getting the slot mapping from Redis
|
||||
mock_limiter.get_workflow_slot_mapping = AsyncMock(
|
||||
return_value=(1, "test_slot_123")
|
||||
)
|
||||
mock_limiter.release_concurrent_slot = AsyncMock(return_value=True)
|
||||
mock_limiter.delete_workflow_slot_mapping = AsyncMock(return_value=True)
|
||||
|
||||
# Release slot
|
||||
success = await dispatcher.release_call_slot(workflow_run_id=999)
|
||||
|
||||
assert success is True
|
||||
mock_limiter.get_workflow_slot_mapping.assert_called_once_with(999)
|
||||
mock_limiter.release_concurrent_slot.assert_called_once_with(
|
||||
1, "test_slot_123"
|
||||
)
|
||||
mock_limiter.delete_workflow_slot_mapping.assert_called_once_with(999)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v"])
|
||||
79
api/tests/test_configuration_masking_merge.py
Normal file
79
api/tests/test_configuration_masking_merge.py
Normal file
|
|
@ -0,0 +1,79 @@
|
|||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from api.schemas.user_configuration import UserConfiguration
|
||||
from api.services.configuration.masking import is_mask_of, mask_key, mask_user_config
|
||||
from api.services.configuration.merge import merge_user_configurations
|
||||
from api.services.configuration.registry import (
|
||||
GroqModel,
|
||||
OpenAILLMService,
|
||||
)
|
||||
|
||||
REAL_KEY = "sk-1234567890abcdef"
|
||||
|
||||
|
||||
def _build_config_with_openai(key: str) -> UserConfiguration:
|
||||
return UserConfiguration(
|
||||
llm=OpenAILLMService(api_key=key),
|
||||
stt=None,
|
||||
tts=None,
|
||||
)
|
||||
|
||||
|
||||
def test_mask_key_basic():
|
||||
masked = mask_key(REAL_KEY)
|
||||
# Should reveal only last 4 chars
|
||||
assert masked.endswith(REAL_KEY[-4:])
|
||||
assert set(masked[:-4]) == {"*"}
|
||||
assert len(masked) == len(REAL_KEY)
|
||||
# is_mask_of round-trip
|
||||
assert is_mask_of(masked, REAL_KEY)
|
||||
|
||||
|
||||
def test_mask_user_config_masks_api_keys():
|
||||
cfg = _build_config_with_openai(REAL_KEY)
|
||||
dumped = mask_user_config(cfg)
|
||||
assert dumped["llm"]["api_key"].endswith(REAL_KEY[-4:])
|
||||
assert dumped["llm"]["api_key"].startswith("*" * (len(REAL_KEY) - 4))
|
||||
|
||||
|
||||
def test_merge_preserves_key_when_mask_sent():
|
||||
existing = _build_config_with_openai(REAL_KEY)
|
||||
incoming_partial = {
|
||||
"llm": {
|
||||
"provider": "openai",
|
||||
"model": existing.llm.model,
|
||||
"api_key": mask_key(REAL_KEY), # masked placeholder
|
||||
}
|
||||
}
|
||||
|
||||
merged = merge_user_configurations(existing, incoming_partial)
|
||||
assert merged.llm.api_key == REAL_KEY # key preserved
|
||||
|
||||
|
||||
def test_merge_replaces_key_when_new_key_provided():
|
||||
existing = _build_config_with_openai(REAL_KEY)
|
||||
new_key = "sk-replaced-9999"
|
||||
incoming_partial = {
|
||||
"llm": {
|
||||
"provider": "openai",
|
||||
"model": existing.llm.model,
|
||||
"api_key": new_key,
|
||||
}
|
||||
}
|
||||
merged = merge_user_configurations(existing, incoming_partial)
|
||||
assert merged.llm.api_key == new_key
|
||||
|
||||
|
||||
def test_merge_drops_old_key_when_provider_changes():
|
||||
existing = _build_config_with_openai(REAL_KEY)
|
||||
incoming_partial = {
|
||||
"llm": {
|
||||
"provider": "groq",
|
||||
"model": GroqModel.LLAMA_3_3_70B,
|
||||
# api_key intentionally absent – should NOT inherit old key
|
||||
}
|
||||
}
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
merge_user_configurations(existing, incoming_partial)
|
||||
33
api/tests/test_default_user_configuration.py
Normal file
33
api/tests/test_default_user_configuration.py
Normal file
|
|
@ -0,0 +1,33 @@
|
|||
import os
|
||||
import uuid
|
||||
|
||||
import pytest
|
||||
|
||||
from api.db.user_client import UserClient
|
||||
from api.services.configuration.registry import ServiceProviders
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_default_configuration_created(db_session):
|
||||
# Set env variable for openai to simulate availability of default key
|
||||
os.environ["OPENAI_API_KEY"] = "sk-test-openai-key"
|
||||
|
||||
# Ensure deepgram env variable absent to focus test
|
||||
os.environ.pop("DEEPGRAM_API_KEY", None)
|
||||
|
||||
# Generate a unique (random) provider user ID for each test run
|
||||
test_provider_user_id = f"provider_user_{uuid.uuid4().hex}"
|
||||
user_client: UserClient = db_session # db_session fixture yields the client
|
||||
|
||||
user_model = await user_client.get_or_create_user_by_provider_id(
|
||||
test_provider_user_id
|
||||
)
|
||||
|
||||
config = await user_client.get_user_configurations(user_model.id)
|
||||
|
||||
assert config.llm is not None, "LLM config should be created when env key present"
|
||||
assert config.llm.provider == ServiceProviders.OPENAI
|
||||
assert config.llm.api_key == "sk-test-openai-key"
|
||||
|
||||
# Cleanup / restore env variable side-effects
|
||||
os.environ.pop("OPENAI_API_KEY", None)
|
||||
122
api/tests/test_disposition_mapper.py
Normal file
122
api/tests/test_disposition_mapper.py
Normal file
|
|
@ -0,0 +1,122 @@
|
|||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from api.services.workflow.disposition_mapper import (
|
||||
apply_disposition_mapping,
|
||||
get_organization_id_from_workflow_run,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_apply_disposition_mapping_with_valid_mapping():
|
||||
"""Test disposition mapping with valid configuration."""
|
||||
with patch("api.services.workflow.disposition_mapper.db_client") as mock_db_client:
|
||||
# Mock disposition mapping configuration
|
||||
mock_db_client.get_configuration_value = AsyncMock(
|
||||
return_value={
|
||||
"XFER": "TRANSFERRED",
|
||||
"ND": "NOT_QUALIFIED",
|
||||
"user_hangup": "HANGUP",
|
||||
}
|
||||
)
|
||||
|
||||
# Test mapping exists
|
||||
result = await apply_disposition_mapping("XFER", 1)
|
||||
assert result == "TRANSFERRED"
|
||||
|
||||
# Test mapping doesn't exist
|
||||
result = await apply_disposition_mapping("UNKNOWN", 1)
|
||||
assert result == "UNKNOWN"
|
||||
|
||||
# Verify db_client was called correctly
|
||||
mock_db_client.get_configuration_value.assert_called_with(
|
||||
1, "DISPOSITION_CODE_MAPPING", default={}
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_apply_disposition_mapping_no_organization_id():
|
||||
"""Test disposition mapping with no organization ID."""
|
||||
# Should return original value
|
||||
result = await apply_disposition_mapping("XFER", None)
|
||||
assert result == "XFER"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_apply_disposition_mapping_empty_value():
|
||||
"""Test disposition mapping with empty value."""
|
||||
# Should return original empty value
|
||||
result = await apply_disposition_mapping("", 1)
|
||||
assert result == ""
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_apply_disposition_mapping_error_handling():
|
||||
"""Test disposition mapping handles errors gracefully."""
|
||||
with patch("api.services.workflow.disposition_mapper.db_client") as mock_db_client:
|
||||
# Mock database error
|
||||
mock_db_client.get_configuration_value = AsyncMock(
|
||||
side_effect=Exception("Database error")
|
||||
)
|
||||
|
||||
# Should return original value on error
|
||||
result = await apply_disposition_mapping("XFER", 1)
|
||||
assert result == "XFER"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_organization_id_from_workflow_run():
|
||||
"""Test getting organization ID from workflow run ID."""
|
||||
with patch("api.services.workflow.disposition_mapper.db_client") as mock_db_client:
|
||||
# Mock workflow run with organization
|
||||
mock_workflow_run = MagicMock()
|
||||
mock_workflow_run.workflow.user.selected_organization_id = 123
|
||||
mock_db_client.get_workflow_run_by_id = AsyncMock(
|
||||
return_value=mock_workflow_run
|
||||
)
|
||||
|
||||
result = await get_organization_id_from_workflow_run(1)
|
||||
assert result == 123
|
||||
|
||||
# Verify db_client was called correctly
|
||||
mock_db_client.get_workflow_run_by_id.assert_called_once_with(1)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_organization_id_no_workflow_run():
|
||||
"""Test getting organization ID when workflow run doesn't exist."""
|
||||
with patch("api.services.workflow.disposition_mapper.db_client") as mock_db_client:
|
||||
# Mock no workflow run found
|
||||
mock_db_client.get_workflow_run_by_id = AsyncMock(return_value=None)
|
||||
|
||||
result = await get_organization_id_from_workflow_run(1)
|
||||
assert result is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_organization_id_no_user():
|
||||
"""Test getting organization ID when workflow has no user."""
|
||||
with patch("api.services.workflow.disposition_mapper.db_client") as mock_db_client:
|
||||
# Mock workflow run with no user
|
||||
mock_workflow_run = MagicMock()
|
||||
mock_workflow_run.workflow.user = None
|
||||
mock_db_client.get_workflow_run_by_id = AsyncMock(
|
||||
return_value=mock_workflow_run
|
||||
)
|
||||
|
||||
result = await get_organization_id_from_workflow_run(1)
|
||||
assert result is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_organization_id_error_handling():
|
||||
"""Test getting organization ID handles errors gracefully."""
|
||||
with patch("api.services.workflow.disposition_mapper.db_client") as mock_db_client:
|
||||
# Mock database error
|
||||
mock_db_client.get_workflow_run_by_id = AsyncMock(
|
||||
side_effect=Exception("Database error")
|
||||
)
|
||||
|
||||
result = await get_organization_id_from_workflow_run(1)
|
||||
assert result is None
|
||||
370
api/tests/test_event_handler_disposition_mapping.py
Normal file
370
api/tests/test_event_handler_disposition_mapping.py
Normal file
|
|
@ -0,0 +1,370 @@
|
|||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from pipecat.utils.enums import EndTaskReason
|
||||
|
||||
from api.services.pipecat.event_handlers import register_transport_event_handlers
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_dependencies():
|
||||
"""Create mock dependencies for event handlers."""
|
||||
# Store registered handlers
|
||||
registered_handlers = {}
|
||||
|
||||
def mock_event_handler(event_name):
|
||||
def decorator(func):
|
||||
registered_handlers[event_name] = func
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
mock_transport = MagicMock()
|
||||
mock_transport.event_handler = mock_event_handler
|
||||
|
||||
mock_task = MagicMock()
|
||||
mock_task.cancel = AsyncMock()
|
||||
|
||||
mock_engine = MagicMock()
|
||||
mock_engine.initialize = AsyncMock()
|
||||
mock_engine.cleanup = AsyncMock()
|
||||
|
||||
mock_audio_buffer = MagicMock()
|
||||
mock_audio_buffer.start_recording = AsyncMock()
|
||||
mock_audio_buffer.stop_recording = AsyncMock()
|
||||
|
||||
mock_usage_metrics_aggregator = MagicMock()
|
||||
mock_usage_metrics_aggregator.get_all_usage_metrics_serialized = MagicMock(
|
||||
return_value={"test": "metrics"}
|
||||
)
|
||||
|
||||
return {
|
||||
"transport": mock_transport,
|
||||
"workflow_run_id": 123,
|
||||
"audio_buffer": mock_audio_buffer,
|
||||
"task": mock_task,
|
||||
"engine": mock_engine,
|
||||
"usage_metrics_aggregator": mock_usage_metrics_aggregator,
|
||||
"audio_synchronizer": None,
|
||||
"registered_handlers": registered_handlers,
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_transport_disconnect_reason_mapping(mock_dependencies):
|
||||
"""Test that transport_disconnect_reason is mapped when no engine disconnect reason exists."""
|
||||
# Register event handlers
|
||||
register_transport_event_handlers(
|
||||
transport=mock_dependencies["transport"],
|
||||
workflow_run_id=mock_dependencies["workflow_run_id"],
|
||||
audio_buffer=mock_dependencies["audio_buffer"],
|
||||
task=mock_dependencies["task"],
|
||||
engine=mock_dependencies["engine"],
|
||||
usage_metrics_aggregator=mock_dependencies["usage_metrics_aggregator"],
|
||||
audio_synchronizer=mock_dependencies["audio_synchronizer"],
|
||||
)
|
||||
|
||||
# Get the on_client_disconnected handler
|
||||
handler = mock_dependencies["registered_handlers"]["on_client_disconnected"]
|
||||
|
||||
# Mock engine with no call disposition
|
||||
mock_dependencies["engine"].get_call_disposition.return_value = None
|
||||
mock_dependencies["engine"].get_gathered_context.return_value = {
|
||||
"agent_name": "Alex"
|
||||
}
|
||||
|
||||
# Mock the disposition mapper functions
|
||||
with patch(
|
||||
"api.services.pipecat.event_handlers.get_organization_id_from_workflow_run",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_get_org_id:
|
||||
with patch(
|
||||
"api.services.pipecat.event_handlers.apply_disposition_mapping",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_apply_mapping:
|
||||
with patch(
|
||||
"api.services.pipecat.event_handlers.db_client"
|
||||
) as mock_db_client:
|
||||
with patch(
|
||||
"api.services.pipecat.event_handlers.enqueue_job"
|
||||
) as mock_enqueue:
|
||||
# Mock organization ID
|
||||
mock_get_org_id.return_value = 1
|
||||
|
||||
# Mock call duration for user_hangup logic
|
||||
mock_dependencies[
|
||||
"usage_metrics_aggregator"
|
||||
].get_call_duration.return_value = 15
|
||||
|
||||
# Mock disposition mapping
|
||||
async def apply_mapping_side_effect(value, org_id):
|
||||
return {
|
||||
"NIBP": "NOT_INTERESTED_BUSINESS_PURPOSE",
|
||||
"user_qualified": "QUALIFIED",
|
||||
}.get(value, value)
|
||||
|
||||
mock_apply_mapping.side_effect = apply_mapping_side_effect
|
||||
|
||||
# Mock database operations
|
||||
mock_workflow_run = MagicMock()
|
||||
mock_workflow_run.id = 123
|
||||
mock_workflow_run.workflow_id = 1
|
||||
mock_workflow_run.organization_id = 1
|
||||
mock_workflow_run.gathered_context = {}
|
||||
mock_db_client.get_workflow_run_by_id = AsyncMock(
|
||||
return_value=mock_workflow_run
|
||||
)
|
||||
mock_db_client.update_workflow_run = AsyncMock()
|
||||
|
||||
# Call handler with transport_disconnect_reason
|
||||
await handler(
|
||||
mock_dependencies["transport"],
|
||||
participant=None,
|
||||
transport_disconnect_reason="user_hangup",
|
||||
)
|
||||
|
||||
# Verify disposition mapping was applied with NIBP (since duration > 10)
|
||||
mock_apply_mapping.assert_called_once_with("NIBP", 1)
|
||||
|
||||
# Verify database was updated with mapped value
|
||||
mock_db_client.update_workflow_run.assert_called_once()
|
||||
call_args = mock_db_client.update_workflow_run.call_args
|
||||
assert (
|
||||
call_args[1]["gathered_context"]["mapped_call_disposition"]
|
||||
== "NOT_INTERESTED_BUSINESS_PURPOSE"
|
||||
)
|
||||
|
||||
# Verify task was cancelled (no engine disconnect reason)
|
||||
mock_dependencies["task"].cancel.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_transport_disconnect_reason_user_hangup_short_call(mock_dependencies):
|
||||
"""Test that user_hangup with short call duration is mapped to HU."""
|
||||
# Register event handlers
|
||||
register_transport_event_handlers(
|
||||
transport=mock_dependencies["transport"],
|
||||
workflow_run_id=mock_dependencies["workflow_run_id"],
|
||||
audio_buffer=mock_dependencies["audio_buffer"],
|
||||
task=mock_dependencies["task"],
|
||||
engine=mock_dependencies["engine"],
|
||||
usage_metrics_aggregator=mock_dependencies["usage_metrics_aggregator"],
|
||||
audio_synchronizer=mock_dependencies["audio_synchronizer"],
|
||||
)
|
||||
|
||||
# Get the on_client_disconnected handler
|
||||
handler = mock_dependencies["registered_handlers"]["on_client_disconnected"]
|
||||
|
||||
# Mock engine with no call disposition
|
||||
mock_dependencies["engine"].get_call_disposition.return_value = None
|
||||
mock_dependencies["engine"].get_gathered_context.return_value = {
|
||||
"agent_name": "Alex"
|
||||
}
|
||||
|
||||
# Mock the disposition mapper functions
|
||||
with patch(
|
||||
"api.services.pipecat.event_handlers.get_organization_id_from_workflow_run",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_get_org_id:
|
||||
with patch(
|
||||
"api.services.pipecat.event_handlers.apply_disposition_mapping",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_apply_mapping:
|
||||
with patch(
|
||||
"api.services.pipecat.event_handlers.db_client"
|
||||
) as mock_db_client:
|
||||
with patch(
|
||||
"api.services.pipecat.event_handlers.enqueue_job"
|
||||
) as mock_enqueue:
|
||||
# Mock organization ID
|
||||
mock_get_org_id.return_value = 1
|
||||
|
||||
# Mock call duration for user_hangup logic (< 10 seconds)
|
||||
mock_dependencies[
|
||||
"usage_metrics_aggregator"
|
||||
].get_call_duration.return_value = 5
|
||||
|
||||
# Mock disposition mapping
|
||||
mock_apply_mapping.return_value = "HANGUP"
|
||||
|
||||
# Mock database operations
|
||||
mock_workflow_run = MagicMock()
|
||||
mock_workflow_run.id = 123
|
||||
mock_workflow_run.workflow_id = 1
|
||||
mock_workflow_run.organization_id = 1
|
||||
mock_workflow_run.gathered_context = {}
|
||||
mock_db_client.get_workflow_run_by_id = AsyncMock(
|
||||
return_value=mock_workflow_run
|
||||
)
|
||||
mock_db_client.update_workflow_run = AsyncMock()
|
||||
|
||||
# Call handler with transport_disconnect_reason
|
||||
await handler(
|
||||
mock_dependencies["transport"],
|
||||
participant=None,
|
||||
transport_disconnect_reason="user_hangup",
|
||||
)
|
||||
|
||||
# Verify disposition mapping was applied with HU (since duration < 10)
|
||||
mock_apply_mapping.assert_called_once_with("HU", 1)
|
||||
|
||||
# Verify database was updated with mapped value
|
||||
mock_db_client.update_workflow_run.assert_called_once()
|
||||
call_args = mock_db_client.update_workflow_run.call_args
|
||||
assert (
|
||||
call_args[1]["gathered_context"]["mapped_call_disposition"]
|
||||
== "HANGUP"
|
||||
)
|
||||
|
||||
# Verify task was cancelled (no engine disconnect reason)
|
||||
mock_dependencies["task"].cancel.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_engine_disconnect_reason_takes_precedence(mock_dependencies):
|
||||
"""Test that engine disconnect reason takes precedence and is not mapped."""
|
||||
# Register event handlers
|
||||
register_transport_event_handlers(
|
||||
transport=mock_dependencies["transport"],
|
||||
workflow_run_id=mock_dependencies["workflow_run_id"],
|
||||
audio_buffer=mock_dependencies["audio_buffer"],
|
||||
task=mock_dependencies["task"],
|
||||
engine=mock_dependencies["engine"],
|
||||
usage_metrics_aggregator=mock_dependencies["usage_metrics_aggregator"],
|
||||
audio_synchronizer=mock_dependencies["audio_synchronizer"],
|
||||
)
|
||||
|
||||
# Get the on_client_disconnected handler
|
||||
handler = mock_dependencies["registered_handlers"]["on_client_disconnected"]
|
||||
|
||||
# Mock engine with call disposition
|
||||
mock_dependencies["engine"].get_call_disposition.return_value = "user_qualified"
|
||||
mock_dependencies["engine"].get_gathered_context.return_value = {
|
||||
"agent_name": "Alex"
|
||||
}
|
||||
|
||||
# Mock the disposition mapper functions
|
||||
with patch(
|
||||
"api.services.pipecat.event_handlers.get_organization_id_from_workflow_run",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_get_org_id:
|
||||
with patch(
|
||||
"api.services.pipecat.event_handlers.apply_disposition_mapping",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_apply_mapping:
|
||||
with patch(
|
||||
"api.services.pipecat.event_handlers.db_client"
|
||||
) as mock_db_client:
|
||||
with patch(
|
||||
"api.services.pipecat.event_handlers.enqueue_job"
|
||||
) as mock_enqueue:
|
||||
# Mock organization ID
|
||||
mock_get_org_id.return_value = 1
|
||||
|
||||
# Mock disposition mapping for engine's reason
|
||||
mock_apply_mapping.return_value = "QUALIFIED"
|
||||
|
||||
# Mock database operations
|
||||
mock_workflow_run = MagicMock()
|
||||
mock_workflow_run.id = 123
|
||||
mock_workflow_run.workflow_id = 1
|
||||
mock_workflow_run.organization_id = 1
|
||||
mock_workflow_run.gathered_context = {}
|
||||
mock_db_client.get_workflow_run_by_id = AsyncMock(
|
||||
return_value=mock_workflow_run
|
||||
)
|
||||
mock_db_client.update_workflow_run = AsyncMock()
|
||||
|
||||
# Call handler with transport_disconnect_reason
|
||||
await handler(
|
||||
mock_dependencies["transport"],
|
||||
participant=None,
|
||||
transport_disconnect_reason="user_hangup",
|
||||
)
|
||||
|
||||
# Verify disposition mapping was called with engine's reason
|
||||
mock_apply_mapping.assert_called_once_with("user_qualified", 1)
|
||||
|
||||
# Verify database was updated with mapped value
|
||||
mock_db_client.update_workflow_run.assert_called_once()
|
||||
call_args = mock_db_client.update_workflow_run.call_args
|
||||
assert (
|
||||
call_args[1]["gathered_context"]["mapped_call_disposition"]
|
||||
== "QUALIFIED"
|
||||
)
|
||||
|
||||
# Verify task was NOT cancelled (engine disconnect reason exists)
|
||||
mock_dependencies["task"].cancel.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_disconnect_reason_uses_unknown(mock_dependencies):
|
||||
"""Test that when no disconnect reason is provided, UNKNOWN is used."""
|
||||
# Register event handlers
|
||||
register_transport_event_handlers(
|
||||
transport=mock_dependencies["transport"],
|
||||
workflow_run_id=mock_dependencies["workflow_run_id"],
|
||||
audio_buffer=mock_dependencies["audio_buffer"],
|
||||
task=mock_dependencies["task"],
|
||||
engine=mock_dependencies["engine"],
|
||||
usage_metrics_aggregator=mock_dependencies["usage_metrics_aggregator"],
|
||||
audio_synchronizer=mock_dependencies["audio_synchronizer"],
|
||||
)
|
||||
|
||||
# Get the on_client_disconnected handler
|
||||
handler = mock_dependencies["registered_handlers"]["on_client_disconnected"]
|
||||
|
||||
# Mock engine with no call disposition
|
||||
mock_dependencies["engine"].get_call_disposition.return_value = None
|
||||
mock_dependencies["engine"].get_gathered_context.return_value = {
|
||||
"agent_name": "Alex"
|
||||
}
|
||||
|
||||
with patch(
|
||||
"api.services.pipecat.event_handlers.get_organization_id_from_workflow_run"
|
||||
) as mock_get_org_id:
|
||||
with patch(
|
||||
"api.services.pipecat.event_handlers.apply_disposition_mapping"
|
||||
) as mock_apply_mapping:
|
||||
with patch(
|
||||
"api.services.pipecat.event_handlers.db_client"
|
||||
) as mock_db_client:
|
||||
with patch(
|
||||
"api.services.pipecat.event_handlers.enqueue_job"
|
||||
) as mock_enqueue:
|
||||
# Mock organization ID
|
||||
mock_get_org_id.return_value = 1
|
||||
|
||||
# Mock disposition mapping - should return UNKNOWN as-is
|
||||
mock_apply_mapping.return_value = EndTaskReason.UNKNOWN.value
|
||||
|
||||
# Mock database operations
|
||||
mock_workflow_run = MagicMock()
|
||||
mock_workflow_run.id = 123
|
||||
mock_workflow_run.workflow_id = 1
|
||||
mock_workflow_run.organization_id = 1
|
||||
mock_workflow_run.gathered_context = {}
|
||||
mock_db_client.get_workflow_run_by_id = AsyncMock(
|
||||
return_value=mock_workflow_run
|
||||
)
|
||||
mock_db_client.update_workflow_run = AsyncMock()
|
||||
|
||||
# Call handler without transport_disconnect_reason
|
||||
await handler(
|
||||
mock_dependencies["transport"],
|
||||
participant=None,
|
||||
transport_disconnect_reason=None,
|
||||
)
|
||||
|
||||
# Verify disposition mapping was called with UNKNOWN
|
||||
mock_apply_mapping.assert_called_once_with(
|
||||
EndTaskReason.UNKNOWN.value, 1
|
||||
)
|
||||
|
||||
# Verify database was updated with UNKNOWN
|
||||
mock_db_client.update_workflow_run.assert_called_once()
|
||||
call_args = mock_db_client.update_workflow_run.call_args
|
||||
assert (
|
||||
call_args[1]["gathered_context"]["mapped_call_disposition"]
|
||||
== EndTaskReason.UNKNOWN.value
|
||||
)
|
||||
184
api/tests/test_event_handlers_refactor.py
Normal file
184
api/tests/test_event_handlers_refactor.py
Normal file
|
|
@ -0,0 +1,184 @@
|
|||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from api.services.pipecat.audio_config import AudioConfig
|
||||
from api.services.pipecat.event_handlers import (
|
||||
register_audio_data_handler,
|
||||
register_transcript_handler,
|
||||
register_transport_event_handlers,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_transport_handlers_with_in_memory_buffers():
|
||||
"""Test that transport handlers create and return in-memory buffers."""
|
||||
# Mock dependencies
|
||||
transport = MagicMock()
|
||||
transport.event_handler = lambda event_name: lambda func: func
|
||||
|
||||
audio_buffer = AsyncMock()
|
||||
audio_synchronizer = AsyncMock()
|
||||
task = AsyncMock()
|
||||
engine = AsyncMock()
|
||||
engine.get_call_disposition.return_value = None
|
||||
engine.get_gathered_context.return_value = {}
|
||||
|
||||
usage_metrics_aggregator = AsyncMock()
|
||||
usage_metrics_aggregator.get_call_duration.return_value = 30
|
||||
usage_metrics_aggregator.get_all_usage_metrics_serialized.return_value = {}
|
||||
|
||||
# Create test audio config
|
||||
audio_config = AudioConfig(
|
||||
transport_in_sample_rate=16000,
|
||||
transport_out_sample_rate=16000,
|
||||
pipeline_sample_rate=16000,
|
||||
)
|
||||
|
||||
# Register handlers
|
||||
audio_buf, transcript_buf = register_transport_event_handlers(
|
||||
transport=transport,
|
||||
workflow_run_id=123,
|
||||
audio_buffer=audio_buffer,
|
||||
task=task,
|
||||
engine=engine,
|
||||
usage_metrics_aggregator=usage_metrics_aggregator,
|
||||
audio_synchronizer=audio_synchronizer,
|
||||
audio_config=audio_config,
|
||||
)
|
||||
|
||||
# Verify buffers were created with correct configuration
|
||||
assert audio_buf is not None
|
||||
assert transcript_buf is not None
|
||||
assert audio_buf._workflow_run_id == 123
|
||||
assert audio_buf._sample_rate == 16000
|
||||
assert audio_buf._num_channels == 1
|
||||
assert transcript_buf._workflow_run_id == 123
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_audio_handler_with_in_memory_buffer():
|
||||
"""Test audio handler uses in-memory buffer when provided."""
|
||||
# Mock audio synchronizer
|
||||
audio_synchronizer = MagicMock()
|
||||
handlers = {}
|
||||
|
||||
def mock_event_handler(event_name):
|
||||
def decorator(func):
|
||||
handlers[event_name] = func
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
audio_synchronizer.event_handler = mock_event_handler
|
||||
|
||||
# Mock in-memory buffer
|
||||
in_memory_buffer = AsyncMock()
|
||||
|
||||
# Register handler with buffer
|
||||
register_audio_data_handler(
|
||||
audio_synchronizer, workflow_run_id=123, in_memory_buffer=in_memory_buffer
|
||||
)
|
||||
|
||||
# Test the handler
|
||||
assert "on_merged_audio" in handlers
|
||||
handler = handlers["on_merged_audio"]
|
||||
|
||||
# Call handler with test data
|
||||
test_pcm = b"test_audio_data"
|
||||
await handler(None, test_pcm, 16000, 1)
|
||||
|
||||
# Verify buffer was used
|
||||
in_memory_buffer.append.assert_called_once_with(test_pcm)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_transcript_handler_with_in_memory_buffer():
|
||||
"""Test transcript handler uses in-memory buffer when provided."""
|
||||
# Mock transcript processor
|
||||
transcript = MagicMock()
|
||||
handlers = {}
|
||||
|
||||
def mock_event_handler(event_name):
|
||||
def decorator(func):
|
||||
handlers[event_name] = func
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
transcript.event_handler = mock_event_handler
|
||||
|
||||
# Mock in-memory buffer
|
||||
in_memory_buffer = AsyncMock()
|
||||
|
||||
# Register handler with buffer
|
||||
register_transcript_handler(
|
||||
transcript, workflow_run_id=456, in_memory_buffer=in_memory_buffer
|
||||
)
|
||||
|
||||
# Create test frame
|
||||
test_frame = MagicMock()
|
||||
test_frame.messages = [
|
||||
MagicMock(timestamp="00:00:01", role="user", content="Hello"),
|
||||
MagicMock(timestamp="00:00:02", role="assistant", content="Hi there"),
|
||||
]
|
||||
|
||||
# Test the handler
|
||||
handler = handlers["on_transcript_update"]
|
||||
await handler(None, test_frame)
|
||||
|
||||
# Verify buffer was used with correct format
|
||||
expected_text = "[00:00:01] user: Hello\n[00:00:02] assistant: Hi there\n"
|
||||
in_memory_buffer.append.assert_called_once_with(expected_text)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_audio_config_sample_rates():
|
||||
"""Test that different audio configs result in correct sample rates."""
|
||||
# Mock dependencies
|
||||
transport = MagicMock()
|
||||
transport.event_handler = lambda event_name: lambda func: func
|
||||
|
||||
audio_buffer = AsyncMock()
|
||||
audio_synchronizer = AsyncMock()
|
||||
task = AsyncMock()
|
||||
engine = AsyncMock()
|
||||
engine.get_call_disposition.return_value = None
|
||||
engine.get_gathered_context.return_value = {}
|
||||
|
||||
usage_metrics_aggregator = AsyncMock()
|
||||
usage_metrics_aggregator.get_all_usage_metrics_serialized.return_value = {}
|
||||
|
||||
# Test with 8kHz audio config (e.g., for Stasis/Twilio)
|
||||
audio_config_8k = AudioConfig(
|
||||
transport_in_sample_rate=8000,
|
||||
transport_out_sample_rate=8000,
|
||||
pipeline_sample_rate=8000,
|
||||
)
|
||||
|
||||
audio_buf_8k, _ = register_transport_event_handlers(
|
||||
transport=transport,
|
||||
workflow_run_id=456,
|
||||
audio_buffer=audio_buffer,
|
||||
task=task,
|
||||
engine=engine,
|
||||
usage_metrics_aggregator=usage_metrics_aggregator,
|
||||
audio_synchronizer=audio_synchronizer,
|
||||
audio_config=audio_config_8k,
|
||||
)
|
||||
|
||||
assert audio_buf_8k._sample_rate == 8000
|
||||
|
||||
# Test with no audio config (should default to 16kHz)
|
||||
audio_buf_default, _ = register_transport_event_handlers(
|
||||
transport=transport,
|
||||
workflow_run_id=789,
|
||||
audio_buffer=audio_buffer,
|
||||
task=task,
|
||||
engine=engine,
|
||||
usage_metrics_aggregator=usage_metrics_aggregator,
|
||||
audio_synchronizer=audio_synchronizer,
|
||||
audio_config=None,
|
||||
)
|
||||
|
||||
assert audio_buf_default._sample_rate == 16000
|
||||
162
api/tests/test_filters.py
Normal file
162
api/tests/test_filters.py
Normal file
|
|
@ -0,0 +1,162 @@
|
|||
"""Test filter functionality."""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from api.db.filters import ATTRIBUTE_FIELD_MAPPING, apply_workflow_run_filters
|
||||
|
||||
|
||||
def test_attribute_field_mapping():
|
||||
"""Test that all required attributes are mapped."""
|
||||
expected_attributes = [
|
||||
"dateRange",
|
||||
"dispositionCode",
|
||||
"duration",
|
||||
"status",
|
||||
"tokenUsage",
|
||||
"runId",
|
||||
"workflowId",
|
||||
"callTags",
|
||||
"phoneNumber",
|
||||
]
|
||||
|
||||
for attr in expected_attributes:
|
||||
assert attr in ATTRIBUTE_FIELD_MAPPING, f"Missing mapping for {attr}"
|
||||
|
||||
|
||||
def test_filter_with_explicit_type():
|
||||
"""Test that filters work with explicit type from UI."""
|
||||
|
||||
# Mock query
|
||||
mock_query = MagicMock()
|
||||
mock_query.where = MagicMock(return_value=mock_query)
|
||||
|
||||
test_cases = [
|
||||
# Date range filter
|
||||
{
|
||||
"filters": [
|
||||
{
|
||||
"attribute": "dateRange",
|
||||
"type": "dateRange",
|
||||
"value": {"from": "2024-01-01", "to": "2024-01-31"},
|
||||
}
|
||||
],
|
||||
},
|
||||
# Multi-select filter
|
||||
{
|
||||
"filters": [
|
||||
{
|
||||
"attribute": "dispositionCode",
|
||||
"type": "multiSelect",
|
||||
"value": {"codes": ["XFER", "HU"]},
|
||||
}
|
||||
],
|
||||
},
|
||||
# Number range filter
|
||||
{
|
||||
"filters": [
|
||||
{
|
||||
"attribute": "duration",
|
||||
"type": "numberRange",
|
||||
"value": {"min": 60, "max": 300},
|
||||
}
|
||||
],
|
||||
},
|
||||
# Radio/status filter
|
||||
{
|
||||
"filters": [
|
||||
{
|
||||
"attribute": "status",
|
||||
"type": "radio",
|
||||
"value": {"status": "completed"},
|
||||
}
|
||||
],
|
||||
},
|
||||
# Number filter
|
||||
{
|
||||
"filters": [
|
||||
{"attribute": "runId", "type": "number", "value": {"value": 123}}
|
||||
],
|
||||
},
|
||||
# Text filter
|
||||
{
|
||||
"filters": [
|
||||
{
|
||||
"attribute": "phoneNumber",
|
||||
"type": "text",
|
||||
"value": {"value": "+1234567890"},
|
||||
}
|
||||
],
|
||||
},
|
||||
# Tags filter
|
||||
{
|
||||
"filters": [
|
||||
{
|
||||
"attribute": "callTags",
|
||||
"type": "tags",
|
||||
"value": {"codes": ["tag1", "tag2"]},
|
||||
}
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
for test_case in test_cases:
|
||||
result = apply_workflow_run_filters(mock_query, test_case["filters"])
|
||||
# The function should process the filter without errors
|
||||
assert result is not None
|
||||
|
||||
|
||||
def test_filter_format_with_type():
|
||||
"""Test that filters work with attribute, type, and value."""
|
||||
|
||||
mock_query = MagicMock()
|
||||
mock_query.where = MagicMock(return_value=mock_query)
|
||||
|
||||
# Test with various filter combinations
|
||||
filters = [
|
||||
{
|
||||
"attribute": "dispositionCode",
|
||||
"type": "multiSelect",
|
||||
"value": {"codes": ["NIBP"]},
|
||||
},
|
||||
{
|
||||
"attribute": "duration",
|
||||
"type": "numberRange",
|
||||
"value": {"min": 0, "max": 60},
|
||||
},
|
||||
{"attribute": "phoneNumber", "type": "text", "value": {"value": "555"}},
|
||||
]
|
||||
|
||||
result = apply_workflow_run_filters(mock_query, filters)
|
||||
|
||||
# Should have called where() for applying filters
|
||||
assert mock_query.where.called
|
||||
assert result is not None
|
||||
|
||||
|
||||
def test_unknown_attribute_ignored():
|
||||
"""Test that unknown attributes are safely ignored."""
|
||||
|
||||
mock_query = MagicMock()
|
||||
mock_query.where = MagicMock(return_value=mock_query)
|
||||
|
||||
filters = [
|
||||
{"attribute": "unknownAttribute", "value": {"value": "test"}},
|
||||
{"attribute": "dispositionCode", "value": {"codes": ["XFER"]}},
|
||||
]
|
||||
|
||||
result = apply_workflow_run_filters(mock_query, filters)
|
||||
|
||||
# Should still process the valid filter
|
||||
assert result is not None
|
||||
|
||||
|
||||
def test_empty_filters():
|
||||
"""Test that empty filters return the query unchanged."""
|
||||
|
||||
mock_query = MagicMock()
|
||||
|
||||
result = apply_workflow_run_filters(mock_query, None)
|
||||
assert result == mock_query
|
||||
|
||||
result = apply_workflow_run_filters(mock_query, [])
|
||||
assert result == mock_query
|
||||
249
api/tests/test_global_prompt.py
Normal file
249
api/tests/test_global_prompt.py
Normal file
|
|
@ -0,0 +1,249 @@
|
|||
"""Tests for global prompt functionality in workflow engine."""
|
||||
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
from pipecat.services.openai.llm import OpenAILLMContext
|
||||
|
||||
from api.services.workflow.dto import (
|
||||
EdgeDataDTO,
|
||||
NodeDataDTO,
|
||||
NodeType,
|
||||
ReactFlowDTO,
|
||||
RFEdgeDTO,
|
||||
RFNodeDTO,
|
||||
)
|
||||
from api.services.workflow.pipecat_engine import PipecatEngine
|
||||
from api.services.workflow.workflow import WorkflowGraph
|
||||
|
||||
|
||||
class TestGlobalPrompt:
|
||||
"""Test suite for global prompt feature."""
|
||||
|
||||
@pytest.fixture
|
||||
def workflow_with_global_node(self):
|
||||
"""Create a workflow with a global node and test nodes."""
|
||||
nodes = [
|
||||
RFNodeDTO(
|
||||
id="global",
|
||||
type=NodeType.globalNode,
|
||||
position={"x": 0, "y": 0},
|
||||
data=NodeDataDTO(
|
||||
name="Global Node",
|
||||
prompt="This is the global context: {{company_name}}",
|
||||
is_static=False,
|
||||
),
|
||||
),
|
||||
RFNodeDTO(
|
||||
id="start",
|
||||
type=NodeType.startNode,
|
||||
position={"x": 100, "y": 100},
|
||||
data=NodeDataDTO(
|
||||
name="Start Call",
|
||||
prompt="Welcome to our service!",
|
||||
is_static=False,
|
||||
is_start=True,
|
||||
add_global_prompt=True, # Enable global prompt
|
||||
),
|
||||
),
|
||||
RFNodeDTO(
|
||||
id="agent1",
|
||||
type=NodeType.agentNode,
|
||||
position={"x": 200, "y": 200},
|
||||
data=NodeDataDTO(
|
||||
name="Agent 1",
|
||||
prompt="How can I help you today?",
|
||||
add_global_prompt=False, # Disable global prompt
|
||||
),
|
||||
),
|
||||
RFNodeDTO(
|
||||
id="agent2",
|
||||
type=NodeType.agentNode,
|
||||
position={"x": 300, "y": 300},
|
||||
data=NodeDataDTO(
|
||||
name="Agent 2",
|
||||
prompt="Please provide your details.",
|
||||
add_global_prompt=True, # Enable global prompt
|
||||
),
|
||||
),
|
||||
RFNodeDTO(
|
||||
id="end",
|
||||
type=NodeType.endNode,
|
||||
position={"x": 400, "y": 400},
|
||||
data=NodeDataDTO(
|
||||
name="End Call",
|
||||
prompt="Thank you for calling!",
|
||||
is_static=True,
|
||||
is_end=True,
|
||||
add_global_prompt=True, # Enable global prompt (but static)
|
||||
),
|
||||
),
|
||||
]
|
||||
|
||||
edges = [
|
||||
RFEdgeDTO(
|
||||
id="e1",
|
||||
source="start",
|
||||
target="agent1",
|
||||
data=EdgeDataDTO(label="Next", condition="Continue to agent"),
|
||||
),
|
||||
RFEdgeDTO(
|
||||
id="e2",
|
||||
source="agent1",
|
||||
target="agent2",
|
||||
data=EdgeDataDTO(label="Details", condition="Get user details"),
|
||||
),
|
||||
RFEdgeDTO(
|
||||
id="e3",
|
||||
source="agent2",
|
||||
target="end",
|
||||
data=EdgeDataDTO(label="Finish", condition="End the call"),
|
||||
),
|
||||
]
|
||||
|
||||
flow_dto = ReactFlowDTO(nodes=nodes, edges=edges)
|
||||
return WorkflowGraph(flow_dto)
|
||||
|
||||
@pytest.fixture
|
||||
def mock_dependencies(self):
|
||||
"""Create mock dependencies for PipecatEngine initialization."""
|
||||
return {
|
||||
"task": Mock(),
|
||||
"llm": Mock(),
|
||||
"context": Mock(spec=OpenAILLMContext),
|
||||
"tts": Mock(),
|
||||
"transport": Mock(),
|
||||
"call_context_vars": {"company_name": "Dograh Inc"},
|
||||
}
|
||||
|
||||
@pytest.fixture
|
||||
def engine(self, mock_dependencies, workflow_with_global_node):
|
||||
"""Create a PipecatEngine instance with test workflow."""
|
||||
mock_dependencies["workflow"] = workflow_with_global_node
|
||||
return PipecatEngine(**mock_dependencies)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_global_prompt_enabled(self, engine):
|
||||
"""Test that global prompt is prepended when add_global_prompt is True."""
|
||||
# Test with start node (add_global_prompt=True)
|
||||
start_node = engine.workflow.nodes["start"]
|
||||
(
|
||||
system_message,
|
||||
functions,
|
||||
) = await engine._compose_system_message_functions_for_node(start_node)
|
||||
|
||||
# Global prompt should be included
|
||||
expected_content = (
|
||||
"This is the global context: Dograh Inc\n\nWelcome to our service!"
|
||||
)
|
||||
assert system_message["content"] == expected_content
|
||||
assert system_message["role"] == "system"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_global_prompt_disabled(self, engine):
|
||||
"""Test that global prompt is not prepended when add_global_prompt is False."""
|
||||
# Test with agent1 node (add_global_prompt=False)
|
||||
agent1_node = engine.workflow.nodes["agent1"]
|
||||
(
|
||||
system_message,
|
||||
functions,
|
||||
) = await engine._compose_system_message_functions_for_node(agent1_node)
|
||||
|
||||
# Global prompt should NOT be included
|
||||
expected_content = "How can I help you today?"
|
||||
assert system_message["content"] == expected_content
|
||||
assert "global context" not in system_message["content"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_global_prompt_with_static_node(self, engine):
|
||||
"""Test that static nodes don't use global prompt in engine (even if enabled)."""
|
||||
# Static nodes are handled differently - they use TTSSpeakFrame directly
|
||||
# This test verifies the compose_system_message behavior for completeness
|
||||
end_node = engine.workflow.nodes["end"]
|
||||
|
||||
# Even though add_global_prompt=True, static nodes handle prompts differently
|
||||
# The _compose_system_message_functions_for_node is still called for consistency
|
||||
(
|
||||
system_message,
|
||||
functions,
|
||||
) = await engine._compose_system_message_functions_for_node(end_node)
|
||||
|
||||
# For static nodes, the global prompt would still be composed if enabled
|
||||
expected_content = (
|
||||
"This is the global context: Dograh Inc\n\nThank you for calling!"
|
||||
)
|
||||
assert system_message["content"] == expected_content
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_global_prompt_variable_substitution(self, engine):
|
||||
"""Test that variables in global prompt are properly substituted."""
|
||||
agent2_node = engine.workflow.nodes["agent2"]
|
||||
(
|
||||
system_message,
|
||||
functions,
|
||||
) = await engine._compose_system_message_functions_for_node(agent2_node)
|
||||
|
||||
# Verify variable substitution in global prompt
|
||||
assert "Dograh Inc" in system_message["content"]
|
||||
assert "{{company_name}}" not in system_message["content"]
|
||||
|
||||
# Full expected content
|
||||
expected_content = (
|
||||
"This is the global context: Dograh Inc\n\nPlease provide your details."
|
||||
)
|
||||
assert system_message["content"] == expected_content
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_global_node_scenario(self, engine):
|
||||
"""Test behavior when there's no global node in the workflow."""
|
||||
# Remove global node from workflow
|
||||
engine.workflow.global_node_id = None
|
||||
|
||||
start_node = engine.workflow.nodes["start"]
|
||||
(
|
||||
system_message,
|
||||
functions,
|
||||
) = await engine._compose_system_message_functions_for_node(start_node)
|
||||
|
||||
# Should only have the node's own prompt
|
||||
assert system_message["content"] == "Welcome to our service!"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_global_prompt(self, engine):
|
||||
"""Test behavior when global prompt is empty."""
|
||||
# Set global prompt to empty string
|
||||
engine.workflow.nodes["global"].prompt = ""
|
||||
|
||||
start_node = engine.workflow.nodes["start"]
|
||||
(
|
||||
system_message,
|
||||
functions,
|
||||
) = await engine._compose_system_message_functions_for_node(start_node)
|
||||
|
||||
# Should only have the node's own prompt (empty global prompt is filtered out)
|
||||
assert system_message["content"] == "Welcome to our service!"
|
||||
|
||||
def test_default_add_global_prompt_value(self):
|
||||
"""Test that add_global_prompt defaults to True in NodeDataDTO."""
|
||||
node_data = NodeDataDTO(name="Test", prompt="Test prompt")
|
||||
assert node_data.add_global_prompt is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multiple_prompts_concatenation(self, engine):
|
||||
"""Test proper concatenation of global and node prompts."""
|
||||
# Test with agent2 node that has global prompt enabled
|
||||
agent2_node = engine.workflow.nodes["agent2"]
|
||||
|
||||
(
|
||||
system_message,
|
||||
functions,
|
||||
) = await engine._compose_system_message_functions_for_node(agent2_node)
|
||||
|
||||
# Should have global and node prompts concatenated with double newlines
|
||||
# (extraction prompt is no longer included in system message)
|
||||
expected_parts = [
|
||||
"This is the global context: Dograh Inc",
|
||||
"Please provide your details.",
|
||||
]
|
||||
expected_content = "\n\n".join(expected_parts)
|
||||
assert system_message["content"] == expected_content
|
||||
175
api/tests/test_global_prompt_unit.py
Normal file
175
api/tests/test_global_prompt_unit.py
Normal file
|
|
@ -0,0 +1,175 @@
|
|||
"""Unit tests for global prompt functionality - no DB dependencies."""
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# Add the api directory to the Python path
|
||||
api_path = Path(__file__).parent.parent
|
||||
sys.path.insert(0, str(api_path))
|
||||
|
||||
from services.workflow.dto import (
|
||||
EdgeDataDTO,
|
||||
NodeDataDTO,
|
||||
NodeType,
|
||||
ReactFlowDTO,
|
||||
RFEdgeDTO,
|
||||
RFNodeDTO,
|
||||
)
|
||||
from services.workflow.workflow import WorkflowGraph
|
||||
|
||||
|
||||
def test_node_data_dto_default_global_prompt():
|
||||
"""Test that add_global_prompt defaults to True."""
|
||||
node_data = NodeDataDTO(name="Test Node", prompt="Test prompt")
|
||||
assert node_data.add_global_prompt is True
|
||||
print("✓ NodeDataDTO defaults add_global_prompt to True")
|
||||
|
||||
|
||||
def test_node_data_dto_explicit_global_prompt():
|
||||
"""Test explicit setting of add_global_prompt."""
|
||||
# Test with False
|
||||
node_data_false = NodeDataDTO(
|
||||
name="Test Node", prompt="Test prompt", add_global_prompt=False
|
||||
)
|
||||
assert node_data_false.add_global_prompt is False
|
||||
|
||||
# Test with True
|
||||
node_data_true = NodeDataDTO(
|
||||
name="Test Node", prompt="Test prompt", add_global_prompt=True
|
||||
)
|
||||
assert node_data_true.add_global_prompt is True
|
||||
print("✓ NodeDataDTO respects explicit add_global_prompt values")
|
||||
|
||||
|
||||
def test_workflow_node_inherits_global_prompt_setting():
|
||||
"""Test that workflow Node inherits add_global_prompt from NodeDataDTO."""
|
||||
nodes = [
|
||||
RFNodeDTO(
|
||||
id="start",
|
||||
type=NodeType.startNode,
|
||||
position={"x": 0, "y": 0},
|
||||
data=NodeDataDTO(
|
||||
name="Start",
|
||||
prompt="Start prompt",
|
||||
is_start=True,
|
||||
add_global_prompt=True,
|
||||
),
|
||||
),
|
||||
RFNodeDTO(
|
||||
id="node1",
|
||||
type=NodeType.agentNode,
|
||||
position={"x": 100, "y": 0},
|
||||
data=NodeDataDTO(
|
||||
name="Node with global", prompt="Test prompt", add_global_prompt=True
|
||||
),
|
||||
),
|
||||
RFNodeDTO(
|
||||
id="node2",
|
||||
type=NodeType.agentNode,
|
||||
position={"x": 200, "y": 0},
|
||||
data=NodeDataDTO(
|
||||
name="Node without global",
|
||||
prompt="Test prompt",
|
||||
add_global_prompt=False,
|
||||
),
|
||||
),
|
||||
RFNodeDTO(
|
||||
id="end",
|
||||
type=NodeType.endNode,
|
||||
position={"x": 300, "y": 0},
|
||||
data=NodeDataDTO(
|
||||
name="End", prompt="End prompt", is_end=True, add_global_prompt=True
|
||||
),
|
||||
),
|
||||
]
|
||||
|
||||
edges = [
|
||||
RFEdgeDTO(
|
||||
id="e1",
|
||||
source="start",
|
||||
target="node1",
|
||||
data=EdgeDataDTO(label="Next", condition="Continue"),
|
||||
),
|
||||
RFEdgeDTO(
|
||||
id="e2",
|
||||
source="node1",
|
||||
target="node2",
|
||||
data=EdgeDataDTO(label="Next", condition="Continue"),
|
||||
),
|
||||
RFEdgeDTO(
|
||||
id="e3",
|
||||
source="node2",
|
||||
target="end",
|
||||
data=EdgeDataDTO(label="End", condition="Finish"),
|
||||
),
|
||||
]
|
||||
|
||||
flow_dto = ReactFlowDTO(nodes=nodes, edges=edges)
|
||||
workflow = WorkflowGraph(flow_dto)
|
||||
|
||||
assert workflow.nodes["start"].add_global_prompt is True
|
||||
assert workflow.nodes["node1"].add_global_prompt is True
|
||||
assert workflow.nodes["node2"].add_global_prompt is False
|
||||
assert workflow.nodes["end"].add_global_prompt is True
|
||||
print("✓ Workflow nodes correctly inherit add_global_prompt setting")
|
||||
|
||||
|
||||
def test_compose_system_message_respects_global_prompt_flag():
|
||||
"""Test that system message composition respects add_global_prompt flag."""
|
||||
# This is a simplified version - in real tests we'd use the full engine
|
||||
# But this demonstrates the logic
|
||||
|
||||
class MockNode:
|
||||
def __init__(self, add_global_prompt, prompt):
|
||||
self.add_global_prompt = add_global_prompt
|
||||
self.prompt = prompt
|
||||
self.out_edges = []
|
||||
self.extraction_enabled = False
|
||||
|
||||
# Simulate the logic from _compose_system_message_functions_for_node
|
||||
def compose_message(node, global_prompt):
|
||||
prompts = []
|
||||
|
||||
# Only add global prompt if node.add_global_prompt is True
|
||||
if global_prompt and node.add_global_prompt:
|
||||
prompts.append(global_prompt)
|
||||
|
||||
prompts.append(node.prompt)
|
||||
|
||||
return "\n\n".join(p for p in prompts if p)
|
||||
|
||||
global_prompt = "This is the global context"
|
||||
|
||||
# Test with add_global_prompt=True
|
||||
node_with_global = MockNode(add_global_prompt=True, prompt="Node prompt")
|
||||
message_with = compose_message(node_with_global, global_prompt)
|
||||
assert message_with == "This is the global context\n\nNode prompt"
|
||||
|
||||
# Test with add_global_prompt=False
|
||||
node_without_global = MockNode(add_global_prompt=False, prompt="Node prompt")
|
||||
message_without = compose_message(node_without_global, global_prompt)
|
||||
assert message_without == "Node prompt"
|
||||
|
||||
print("✓ System message composition respects add_global_prompt flag")
|
||||
|
||||
|
||||
def test_static_nodes_with_global_prompt():
|
||||
"""Test static nodes can have add_global_prompt setting."""
|
||||
static_node_data = NodeDataDTO(
|
||||
name="Static Node", prompt="Static text", is_static=True, add_global_prompt=True
|
||||
)
|
||||
|
||||
assert static_node_data.is_static is True
|
||||
assert static_node_data.add_global_prompt is True
|
||||
print("✓ Static nodes can have add_global_prompt setting")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Run all tests
|
||||
test_node_data_dto_default_global_prompt()
|
||||
test_node_data_dto_explicit_global_prompt()
|
||||
test_workflow_node_inherits_global_prompt_setting()
|
||||
test_compose_system_message_respects_global_prompt_flag()
|
||||
test_static_nodes_with_global_prompt()
|
||||
|
||||
print("\n✅ All unit tests passed!")
|
||||
248
api/tests/test_leave_counter.py
Normal file
248
api/tests/test_leave_counter.py
Normal file
|
|
@ -0,0 +1,248 @@
|
|||
"""
|
||||
Test cases for _leave_counter mechanism in transport clients.
|
||||
|
||||
This test suite verifies that the _leave_counter prevents premature disconnection
|
||||
when both input and output transports are using the same client.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock, Mock
|
||||
|
||||
import pytest
|
||||
from pipecat.frames.frames import EndFrame, StartFrame
|
||||
from pipecat.transports.network.fastapi_websocket import (
|
||||
FastAPIWebsocketCallbacks,
|
||||
FastAPIWebsocketClient,
|
||||
FastAPIWebsocketParams,
|
||||
FastAPIWebsocketTransport,
|
||||
)
|
||||
from pipecat.transports.network.small_webrtc import SmallWebRTCClient
|
||||
|
||||
from api.services.telephony.stasis_rtp_client import StasisRTPClient
|
||||
|
||||
|
||||
class TestLeaveCounterFastAPIWebsocket:
|
||||
"""Test the _leave_counter mechanism in FastAPIWebsocketClient."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_leave_counter_prevents_early_disconnect(self):
|
||||
"""Test that disconnect only happens when both transports have disconnected."""
|
||||
# Create mock websocket
|
||||
mock_websocket = Mock()
|
||||
mock_websocket.close = AsyncMock()
|
||||
# Set client_state directly to WebSocketState.CONNECTED value
|
||||
from starlette.websockets import WebSocketState
|
||||
|
||||
mock_websocket.client_state = WebSocketState.CONNECTED
|
||||
|
||||
# Create callbacks
|
||||
callbacks = FastAPIWebsocketCallbacks(
|
||||
on_client_connected=AsyncMock(),
|
||||
on_client_disconnected=AsyncMock(),
|
||||
on_session_timeout=AsyncMock(),
|
||||
)
|
||||
|
||||
# Create client
|
||||
client = FastAPIWebsocketClient(
|
||||
mock_websocket, is_binary=False, callbacks=callbacks
|
||||
)
|
||||
|
||||
# Create StartFrame
|
||||
start_frame = StartFrame()
|
||||
|
||||
# Simulate both input and output transports calling setup
|
||||
await client.setup(start_frame) # Input transport
|
||||
assert client._leave_counter == 1
|
||||
|
||||
await client.setup(start_frame) # Output transport
|
||||
assert client._leave_counter == 2
|
||||
|
||||
# First disconnect - should not actually disconnect
|
||||
await client.disconnect()
|
||||
assert client._leave_counter == 1
|
||||
mock_websocket.close.assert_not_called()
|
||||
callbacks.on_client_disconnected.assert_not_called()
|
||||
|
||||
# Second disconnect - should actually disconnect
|
||||
await client.disconnect()
|
||||
assert client._leave_counter == 0
|
||||
mock_websocket.close.assert_called_once()
|
||||
callbacks.on_client_disconnected.assert_called_once()
|
||||
|
||||
|
||||
class TestLeaveCounterStasisRTP:
|
||||
"""Test the _leave_counter mechanism in StasisRTPClient."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_leave_counter_prevents_early_disconnect(self):
|
||||
"""Test that disconnect only happens when both transports have disconnected."""
|
||||
# Create mock connection
|
||||
mock_connection = Mock()
|
||||
mock_connection.is_connected.return_value = True
|
||||
mock_connection.disconnect = AsyncMock()
|
||||
mock_connection.notify_sockets_closed = AsyncMock()
|
||||
|
||||
# Mock event_handler as a callable that acts as a decorator
|
||||
def mock_event_handler(event_name):
|
||||
def decorator(func):
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
mock_connection.event_handler = mock_event_handler
|
||||
|
||||
# Create callbacks
|
||||
from api.services.telephony.stasis_rtp_transport import StasisRTPCallbacks
|
||||
|
||||
callbacks = StasisRTPCallbacks(
|
||||
on_client_connected=AsyncMock(),
|
||||
on_client_disconnected=AsyncMock(),
|
||||
on_client_closed=AsyncMock(),
|
||||
)
|
||||
|
||||
# Create client
|
||||
client = StasisRTPClient(mock_connection, callbacks)
|
||||
|
||||
# Create StartFrame
|
||||
start_frame = StartFrame()
|
||||
|
||||
# Simulate both input and output transports calling setup
|
||||
await client.setup(start_frame) # Input transport
|
||||
assert client._leave_counter == 1
|
||||
|
||||
await client.setup(start_frame) # Output transport
|
||||
assert client._leave_counter == 2
|
||||
|
||||
# First disconnect - should not actually disconnect
|
||||
await client.disconnect()
|
||||
assert client._leave_counter == 1
|
||||
mock_connection.disconnect.assert_not_called()
|
||||
|
||||
# Second disconnect - should actually disconnect
|
||||
await client.disconnect()
|
||||
assert client._leave_counter == 0
|
||||
mock_connection.disconnect.assert_called_once()
|
||||
|
||||
|
||||
class TestLeaveCounterSmallWebRTC:
|
||||
"""Test the _leave_counter mechanism in SmallWebRTCClient."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_leave_counter_prevents_early_disconnect(self):
|
||||
"""Test that disconnect only happens when both transports have disconnected."""
|
||||
# Create mock connection
|
||||
mock_connection = Mock()
|
||||
mock_connection.is_connected.return_value = True
|
||||
mock_connection.disconnect = AsyncMock()
|
||||
mock_connection.notify_sockets_closed = AsyncMock()
|
||||
|
||||
# Mock event_handler as a callable that acts as a decorator
|
||||
def mock_event_handler(event_name):
|
||||
def decorator(func):
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
mock_connection.event_handler = mock_event_handler
|
||||
|
||||
# Create callbacks
|
||||
from pipecat.transports.network.small_webrtc import SmallWebRTCCallbacks
|
||||
|
||||
callbacks = SmallWebRTCCallbacks(
|
||||
on_app_message=AsyncMock(),
|
||||
on_client_connected=AsyncMock(),
|
||||
on_client_disconnected=AsyncMock(),
|
||||
)
|
||||
|
||||
# Create client
|
||||
client = SmallWebRTCClient(mock_connection, callbacks)
|
||||
|
||||
# Create StartFrame with required attributes
|
||||
start_frame = StartFrame()
|
||||
|
||||
# Create mock transport params
|
||||
from pipecat.transports.base_transport import TransportParams
|
||||
|
||||
params = TransportParams(
|
||||
audio_in_channels=1, audio_in_sample_rate=16000, audio_out_sample_rate=16000
|
||||
)
|
||||
|
||||
# Simulate both input and output transports calling setup
|
||||
await client.setup(params, start_frame) # Input transport
|
||||
assert client._leave_counter == 1
|
||||
|
||||
await client.setup(params, start_frame) # Output transport
|
||||
assert client._leave_counter == 2
|
||||
|
||||
# First disconnect - should not actually disconnect
|
||||
await client.disconnect()
|
||||
assert client._leave_counter == 1
|
||||
mock_connection.disconnect.assert_not_called()
|
||||
|
||||
# Second disconnect - should actually disconnect
|
||||
await client.disconnect()
|
||||
assert client._leave_counter == 0
|
||||
mock_connection.disconnect.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Complex integration test - requires additional mocking")
|
||||
@pytest.mark.asyncio
|
||||
async def test_transport_lifecycle_with_leave_counter():
|
||||
"""Test complete transport lifecycle with proper leave counter handling."""
|
||||
# Create mock websocket
|
||||
mock_websocket = Mock()
|
||||
mock_websocket.close = AsyncMock()
|
||||
# Set client_state directly to WebSocketState.CONNECTED value
|
||||
from starlette.websockets import WebSocketState
|
||||
|
||||
mock_websocket.client_state = WebSocketState.CONNECTED
|
||||
mock_websocket.iter_bytes = Mock(return_value=iter([]))
|
||||
mock_websocket.send_bytes = AsyncMock()
|
||||
|
||||
# Create transport
|
||||
params = FastAPIWebsocketParams(audio_in_enabled=True, audio_out_enabled=True)
|
||||
transport = FastAPIWebsocketTransport(mock_websocket, params)
|
||||
|
||||
# Get input and output transports
|
||||
input_transport = transport.input()
|
||||
output_transport = transport.output()
|
||||
|
||||
# Setup the transport with required components
|
||||
from pipecat.clocks.system_clock import SystemClock
|
||||
from pipecat.processors.frame_processor import FrameProcessorSetup
|
||||
from pipecat.utils.asyncio.task_manager import TaskManager, TaskManagerParams
|
||||
|
||||
clock = SystemClock()
|
||||
task_manager = TaskManager()
|
||||
|
||||
# Setup task manager with event loop
|
||||
loop = asyncio.get_event_loop()
|
||||
task_manager_params = TaskManagerParams(loop=loop)
|
||||
task_manager.setup(task_manager_params)
|
||||
|
||||
setup = FrameProcessorSetup(clock=clock, task_manager=task_manager)
|
||||
|
||||
# Setup both input and output transports
|
||||
await input_transport.setup(setup)
|
||||
await output_transport.setup(setup)
|
||||
|
||||
# Start both transports
|
||||
start_frame = StartFrame()
|
||||
await input_transport.start(start_frame)
|
||||
await output_transport.start(start_frame)
|
||||
|
||||
# Verify leave counter is 2
|
||||
assert transport._client._leave_counter == 2
|
||||
|
||||
# Stop input transport
|
||||
end_frame = EndFrame()
|
||||
await input_transport.stop(end_frame)
|
||||
|
||||
# Verify websocket not closed yet
|
||||
mock_websocket.close.assert_not_called()
|
||||
|
||||
# Stop output transport
|
||||
await output_transport.stop(end_frame)
|
||||
|
||||
# Now websocket should be closed
|
||||
mock_websocket.close.assert_called_once()
|
||||
143
api/tests/test_llm_generated_text_signal.py
Normal file
143
api/tests/test_llm_generated_text_signal.py
Normal file
|
|
@ -0,0 +1,143 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
"""
|
||||
Test script to verify that LLMGeneratedTextFrame signaling works correctly
|
||||
with the new local variable approach.
|
||||
"""
|
||||
|
||||
|
||||
def test_local_variable_logic():
|
||||
"""Test the core logic using the same pattern as the implementation"""
|
||||
|
||||
print("=== Testing Local Variable Logic ===")
|
||||
|
||||
# Simulate the logic from _process_context
|
||||
text_generation_signaled = False
|
||||
frames_sent = []
|
||||
|
||||
# Simulate chunks with text content
|
||||
chunks_with_content = ["Hello", " world", "!"]
|
||||
|
||||
for content in chunks_with_content:
|
||||
# This is the exact logic from our implementation
|
||||
if content: # equivalent to chunk.choices[0].delta.content
|
||||
if not text_generation_signaled:
|
||||
frames_sent.append("LLMGeneratedTextFrame")
|
||||
text_generation_signaled = True
|
||||
frames_sent.append(f"LLMTextFrame({content})")
|
||||
|
||||
print(f"Frames sent: {frames_sent}")
|
||||
|
||||
# Verify behavior
|
||||
generated_signals = [f for f in frames_sent if f == "LLMGeneratedTextFrame"]
|
||||
text_frames = [f for f in frames_sent if f.startswith("LLMTextFrame")]
|
||||
|
||||
assert len(generated_signals) == 1, (
|
||||
f"Expected 1 signal, got {len(generated_signals)}"
|
||||
)
|
||||
assert len(text_frames) == 3, f"Expected 3 text frames, got {len(text_frames)}"
|
||||
assert frames_sent[0] == "LLMGeneratedTextFrame", "Signal should be first"
|
||||
|
||||
print("✅ Local variable logic works correctly")
|
||||
return True
|
||||
|
||||
|
||||
def test_no_text_logic():
|
||||
"""Test that no signal is sent when there's no text"""
|
||||
|
||||
print("\n=== Testing No Text Logic ===")
|
||||
|
||||
text_generation_signaled = False
|
||||
frames_sent = []
|
||||
|
||||
# Simulate chunks with no text content (function calls only)
|
||||
chunks_with_content = [None, None, None] # No text content
|
||||
|
||||
for content in chunks_with_content:
|
||||
if content: # This will be False for all chunks
|
||||
if not text_generation_signaled:
|
||||
frames_sent.append("LLMGeneratedTextFrame")
|
||||
text_generation_signaled = True
|
||||
frames_sent.append(f"LLMTextFrame({content})")
|
||||
|
||||
print(f"Frames sent: {frames_sent}")
|
||||
|
||||
assert len(frames_sent) == 0, f"Expected no frames, got {frames_sent}"
|
||||
|
||||
print("✅ No signal sent when no text content")
|
||||
return True
|
||||
|
||||
|
||||
def test_mixed_content_logic():
|
||||
"""Test behavior with mixed function calls and text"""
|
||||
|
||||
print("\n=== Testing Mixed Content Logic ===")
|
||||
|
||||
text_generation_signaled = False
|
||||
frames_sent = []
|
||||
|
||||
# Simulate chunks: function call, text, function call, text
|
||||
chunks = [
|
||||
{"type": "function", "content": None},
|
||||
{"type": "text", "content": "Hello"},
|
||||
{"type": "function", "content": None},
|
||||
{"type": "text", "content": " world"},
|
||||
]
|
||||
|
||||
for chunk in chunks:
|
||||
if chunk["type"] == "function":
|
||||
frames_sent.append("FunctionCallFrame")
|
||||
elif chunk["content"]: # text content
|
||||
if not text_generation_signaled:
|
||||
frames_sent.append("LLMGeneratedTextFrame")
|
||||
text_generation_signaled = True
|
||||
frames_sent.append(f"LLMTextFrame({chunk['content']})")
|
||||
|
||||
print(f"Frames sent: {frames_sent}")
|
||||
|
||||
generated_signals = [f for f in frames_sent if f == "LLMGeneratedTextFrame"]
|
||||
|
||||
assert len(generated_signals) == 1, (
|
||||
f"Expected 1 signal, got {len(generated_signals)}"
|
||||
)
|
||||
# Signal should come before first text frame but after any function frames
|
||||
signal_index = frames_sent.index("LLMGeneratedTextFrame")
|
||||
first_text_index = next(
|
||||
i for i, f in enumerate(frames_sent) if f.startswith("LLMTextFrame")
|
||||
)
|
||||
assert signal_index == first_text_index - 1, (
|
||||
"Signal should come right before first text"
|
||||
)
|
||||
|
||||
print("✅ Mixed content logic works correctly")
|
||||
return True
|
||||
|
||||
|
||||
def main():
|
||||
try:
|
||||
test1_result = test_local_variable_logic()
|
||||
test2_result = test_no_text_logic()
|
||||
test3_result = test_mixed_content_logic()
|
||||
|
||||
print(f"\n=== Test Results ===")
|
||||
print(f"Local variable test: {'✅ PASS' if test1_result else '❌ FAIL'}")
|
||||
print(f"No text test: {'✅ PASS' if test2_result else '❌ FAIL'}")
|
||||
print(f"Mixed content test: {'✅ PASS' if test3_result else '❌ FAIL'}")
|
||||
|
||||
if test1_result and test2_result and test3_result:
|
||||
print("\n🎉 All LLMGeneratedTextFrame signaling logic tests passed!")
|
||||
print(
|
||||
"✅ Implementation correctly signals text generation once, as early as possible"
|
||||
)
|
||||
else:
|
||||
print("\n❌ Some tests failed.")
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Test failed with error: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
99
api/tests/test_llm_response_reorder.py
Normal file
99
api/tests/test_llm_response_reorder.py
Normal file
|
|
@ -0,0 +1,99 @@
|
|||
import unittest
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
FunctionCallInProgressFrame,
|
||||
LLMFullResponseStartFrame,
|
||||
)
|
||||
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
|
||||
from pipecat.services.google.llm import (
|
||||
GoogleAssistantContextAggregator,
|
||||
GoogleLLMContext,
|
||||
)
|
||||
from pipecat.services.openai.llm import OpenAIAssistantContextAggregator
|
||||
|
||||
|
||||
class TestReorderOpenAIAssistantContextAggregator(unittest.IsolatedAsyncioTestCase):
|
||||
async def test_reorder_function_messages_openai(self):
|
||||
"""Ensure that after a text aggregation the function-call messages are moved
|
||||
to appear immediately after the text response, maintaining chronological
|
||||
order (assistant text -> function call -> tool response).
|
||||
"""
|
||||
|
||||
context = OpenAILLMContext()
|
||||
aggregator = OpenAIAssistantContextAggregator(context)
|
||||
|
||||
# Simulate the start of an LLM response so that the aggregator creates a
|
||||
# response session ID that is later used for re-ordering.
|
||||
await aggregator._handle_llm_start(LLMFullResponseStartFrame())
|
||||
|
||||
# Simulate the model emitting a function call which the aggregator will
|
||||
# record for potential re-ordering.
|
||||
await aggregator._handle_function_call_in_progress(
|
||||
FunctionCallInProgressFrame(
|
||||
function_name="get_weather",
|
||||
tool_call_id="1",
|
||||
arguments={},
|
||||
)
|
||||
)
|
||||
|
||||
# Now push the textual part of the assistant response. This should
|
||||
# trigger the re-ordering so that the two function-related messages
|
||||
# appear *after* this text.
|
||||
await aggregator.handle_aggregation("Hello!")
|
||||
|
||||
messages = context.get_messages()
|
||||
|
||||
# We expect exactly three messages after re-ordering.
|
||||
self.assertEqual(len(messages), 3)
|
||||
|
||||
# 1. Assistant text
|
||||
self.assertEqual(messages[0]["role"], "assistant")
|
||||
self.assertEqual(messages[0]["content"], "Hello!")
|
||||
|
||||
# 2. Assistant function-call message
|
||||
self.assertEqual(messages[1]["role"], "assistant")
|
||||
self.assertIn("tool_calls", messages[1])
|
||||
|
||||
# 3. Tool response
|
||||
self.assertEqual(messages[2]["role"], "tool")
|
||||
self.assertEqual(messages[2]["tool_call_id"], "1")
|
||||
|
||||
|
||||
class TestReorderGoogleAssistantContextAggregator(unittest.IsolatedAsyncioTestCase):
|
||||
async def test_reorder_function_messages_google(self):
|
||||
context = GoogleLLMContext()
|
||||
aggregator = GoogleAssistantContextAggregator(context)
|
||||
|
||||
# Start an LLM response session.
|
||||
await aggregator._handle_llm_start(LLMFullResponseStartFrame())
|
||||
|
||||
# Emit a function call.
|
||||
await aggregator._handle_function_call_in_progress(
|
||||
FunctionCallInProgressFrame(
|
||||
function_name="get_weather",
|
||||
tool_call_id="1",
|
||||
arguments={},
|
||||
)
|
||||
)
|
||||
|
||||
# Push the textual content.
|
||||
await aggregator.handle_aggregation("Hello!")
|
||||
|
||||
messages = context.messages # Google context stores Content objects.
|
||||
|
||||
self.assertEqual(len(messages), 3)
|
||||
|
||||
# The first message should be the model text.
|
||||
first_msg = messages[0].to_json_dict()
|
||||
self.assertEqual(first_msg["role"], "model")
|
||||
self.assertEqual(first_msg["parts"][0]["text"], "Hello!")
|
||||
|
||||
# The second message contains the function call (also from the model).
|
||||
second_msg = messages[1].to_json_dict()
|
||||
self.assertEqual(second_msg["role"], "model")
|
||||
self.assertIn("function_call", second_msg["parts"][0])
|
||||
|
||||
# The third message is the placeholder function response.
|
||||
third_msg = messages[2].to_json_dict()
|
||||
self.assertEqual(third_msg["role"], "user")
|
||||
self.assertIn("function_response", third_msg["parts"][0])
|
||||
506
api/tests/test_looptalk_routes.py
Normal file
506
api/tests/test_looptalk_routes.py
Normal file
|
|
@ -0,0 +1,506 @@
|
|||
"""
|
||||
Tests for LoopTalk API routes and orchestration.
|
||||
|
||||
This module tests the LoopTalk testing functionality including test session creation,
|
||||
pipeline orchestration, and agent-to-agent communication.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from fastapi import status
|
||||
|
||||
from api.db.db_client import DBClient
|
||||
from api.services.looptalk.orchestrator import LoopTalkTestOrchestrator
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def actor_workflow_definition():
|
||||
"""Sample actor workflow definition for testing."""
|
||||
return {
|
||||
"nodes": [
|
||||
{
|
||||
"id": "1",
|
||||
"type": "startCall",
|
||||
"position": {"x": 0, "y": 0},
|
||||
"data": {
|
||||
"prompt": "Hello, I'm the actor agent.",
|
||||
"is_static": True,
|
||||
"name": "Start Call",
|
||||
"is_start": True,
|
||||
"allow_interrupt": False,
|
||||
},
|
||||
},
|
||||
{
|
||||
"id": "2",
|
||||
"type": "agentNode",
|
||||
"position": {"x": 100, "y": 0},
|
||||
"data": {
|
||||
"prompt": "You are an actor agent testing the adversary. Ask probing questions.",
|
||||
"name": "Actor Agent",
|
||||
"allow_interrupt": True,
|
||||
},
|
||||
},
|
||||
{
|
||||
"id": "3",
|
||||
"type": "endCall",
|
||||
"position": {"x": 200, "y": 0},
|
||||
"data": {
|
||||
"prompt": "Goodbye!",
|
||||
"name": "End Call",
|
||||
"is_end": True,
|
||||
},
|
||||
},
|
||||
],
|
||||
"edges": [
|
||||
{
|
||||
"id": "e1",
|
||||
"source": "1",
|
||||
"target": "2",
|
||||
"data": {"label": "Continue", "condition": "Always"},
|
||||
},
|
||||
{
|
||||
"id": "e2",
|
||||
"source": "2",
|
||||
"target": "3",
|
||||
"data": {"label": "End", "condition": "Always"},
|
||||
},
|
||||
],
|
||||
"stt": {"provider": "openai", "api_key": "test-key", "model": "whisper-1"},
|
||||
"llm": {"provider": "openai", "api_key": "test-key", "model": "gpt-4o-mini"},
|
||||
"tts": {
|
||||
"provider": "openai",
|
||||
"api_key": "test-key",
|
||||
"model": "tts-1",
|
||||
"voice": "nova",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def adversary_workflow_definition():
|
||||
"""Sample adversary workflow definition for testing."""
|
||||
return {
|
||||
"nodes": [
|
||||
{
|
||||
"id": "1",
|
||||
"type": "startCall",
|
||||
"position": {"x": 0, "y": 0},
|
||||
"data": {
|
||||
"prompt": "Hello, I'm the adversary agent.",
|
||||
"is_static": True,
|
||||
"name": "Start Call",
|
||||
"is_start": True,
|
||||
"allow_interrupt": False,
|
||||
},
|
||||
},
|
||||
{
|
||||
"id": "2",
|
||||
"type": "agentNode",
|
||||
"position": {"x": 100, "y": 0},
|
||||
"data": {
|
||||
"prompt": "You are an adversary agent being tested. Respond defensively.",
|
||||
"name": "Adversary Agent",
|
||||
"allow_interrupt": True,
|
||||
},
|
||||
},
|
||||
{
|
||||
"id": "3",
|
||||
"type": "endCall",
|
||||
"position": {"x": 200, "y": 0},
|
||||
"data": {
|
||||
"prompt": "Goodbye!",
|
||||
"name": "End Call",
|
||||
"is_end": True,
|
||||
},
|
||||
},
|
||||
],
|
||||
"edges": [
|
||||
{
|
||||
"id": "e1",
|
||||
"source": "1",
|
||||
"target": "2",
|
||||
"data": {"label": "Continue", "condition": "Always"},
|
||||
},
|
||||
{
|
||||
"id": "e2",
|
||||
"source": "2",
|
||||
"target": "3",
|
||||
"data": {"label": "End", "condition": "Always"},
|
||||
},
|
||||
],
|
||||
"stt": {"provider": "deepgram", "api_key": "test-key", "model": "nova-2"},
|
||||
"llm": {
|
||||
"provider": "groq",
|
||||
"api_key": "test-key",
|
||||
"model": "llama-3.1-70b-versatile",
|
||||
},
|
||||
"tts": {"provider": "deepgram", "api_key": "test-key", "voice": "nova-2"},
|
||||
}
|
||||
|
||||
|
||||
from pipecat.processors.frame_processor import FrameProcessor
|
||||
|
||||
|
||||
class MockSTTService(FrameProcessor):
|
||||
"""Mock STT service for testing."""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
async def run_stt(self, audio: bytes) -> str:
|
||||
return "Mock transcription"
|
||||
|
||||
|
||||
class MockLLMService(FrameProcessor):
|
||||
"""Mock LLM service for testing."""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
async def run_llm(self, messages) -> str:
|
||||
return "Mock LLM response"
|
||||
|
||||
def create_context_aggregator(self, context):
|
||||
"""Mock context aggregator creation."""
|
||||
return MagicMock()
|
||||
|
||||
|
||||
class MockTTSService(FrameProcessor):
|
||||
"""Mock TTS service for testing."""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
async def run_tts(self, text: str) -> bytes:
|
||||
return b"Mock audio data"
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def test_user_with_org(db_session):
|
||||
"""Create a test user with an organization set up."""
|
||||
user = await db_session.get_or_create_user_by_provider_id("test_looptalk_user")
|
||||
org, _ = await db_session.get_or_create_organization_by_provider_id(
|
||||
"test_looptalk_org"
|
||||
)
|
||||
|
||||
user_id = user.id
|
||||
org_id = org.id
|
||||
|
||||
await db_session.add_user_to_organization(user_id, org_id)
|
||||
|
||||
# Update user's selected organization
|
||||
async with db_session.async_session() as session:
|
||||
from sqlalchemy import update
|
||||
|
||||
from api.db.models import UserModel
|
||||
|
||||
await session.execute(
|
||||
update(UserModel)
|
||||
.where(UserModel.id == user_id)
|
||||
.values(selected_organization_id=org_id)
|
||||
)
|
||||
await session.commit()
|
||||
|
||||
# Return fresh user object
|
||||
return await db_session.get_user_by_id(user_id)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_test_session(
|
||||
test_client_factory,
|
||||
db_session,
|
||||
test_user_with_org,
|
||||
actor_workflow_definition,
|
||||
adversary_workflow_definition,
|
||||
):
|
||||
"""Test creating a new LoopTalk test session."""
|
||||
async with test_client_factory(test_user_with_org) as test_client:
|
||||
# First create two workflows
|
||||
actor_workflow_response = await test_client.post(
|
||||
"/api/v1/workflow/create",
|
||||
json={
|
||||
"name": "Actor Workflow",
|
||||
"workflow_definition": actor_workflow_definition,
|
||||
},
|
||||
)
|
||||
assert actor_workflow_response.status_code == status.HTTP_200_OK
|
||||
actor_workflow_id = actor_workflow_response.json()["id"]
|
||||
|
||||
adversary_workflow_response = await test_client.post(
|
||||
"/api/v1/workflow/create",
|
||||
json={
|
||||
"name": "Adversary Workflow",
|
||||
"workflow_definition": adversary_workflow_definition,
|
||||
},
|
||||
)
|
||||
assert adversary_workflow_response.status_code == status.HTTP_200_OK
|
||||
adversary_workflow_id = adversary_workflow_response.json()["id"]
|
||||
|
||||
# Create test session
|
||||
response = await test_client.post(
|
||||
"/api/v1/looptalk/test-sessions",
|
||||
json={
|
||||
"name": "Test Session 1",
|
||||
"actor_workflow_id": actor_workflow_id,
|
||||
"adversary_workflow_id": adversary_workflow_id,
|
||||
"config": {"test_duration": 60},
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
assert data["name"] == "Test Session 1"
|
||||
assert data["status"] == "pending"
|
||||
assert data["actor_workflow_id"] == actor_workflow_id
|
||||
assert data["adversary_workflow_id"] == adversary_workflow_id
|
||||
assert data["config"]["test_duration"] == 60
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_test_sessions(test_client_factory, db_session, test_user_with_org):
|
||||
"""Test listing LoopTalk test sessions."""
|
||||
async with test_client_factory(test_user_with_org) as test_client:
|
||||
response = await test_client.get(
|
||||
"/api/v1/looptalk/test-sessions",
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
assert isinstance(data, list)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_looptalk_orchestrator_plumbing(
|
||||
db_session: DBClient, actor_workflow_definition, adversary_workflow_definition
|
||||
):
|
||||
"""Test the LoopTalk orchestrator plumbing with mocked services."""
|
||||
|
||||
# Create test user and organization
|
||||
user = await db_session.get_or_create_user_by_provider_id(
|
||||
provider_id="test-user-123"
|
||||
)
|
||||
org, _ = await db_session.get_or_create_organization_by_provider_id(
|
||||
org_provider_id="test-org-123"
|
||||
)
|
||||
|
||||
# Get IDs before session closes
|
||||
user_id = user.id
|
||||
org_id = org.id
|
||||
|
||||
await db_session.add_user_to_organization(user_id, org_id)
|
||||
|
||||
# Update user's selected organization manually
|
||||
async with db_session.async_session() as session:
|
||||
from sqlalchemy import update
|
||||
|
||||
from api.db.models import UserModel
|
||||
|
||||
await session.execute(
|
||||
update(UserModel)
|
||||
.where(UserModel.id == user_id)
|
||||
.values(selected_organization_id=org_id)
|
||||
)
|
||||
await session.commit()
|
||||
|
||||
actor_workflow = await db_session.create_workflow(
|
||||
name="Actor Workflow",
|
||||
workflow_definition=actor_workflow_definition,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
adversary_workflow = await db_session.create_workflow(
|
||||
name="Adversary Workflow",
|
||||
workflow_definition=adversary_workflow_definition,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
# Create test session
|
||||
test_session = await db_session.create_test_session(
|
||||
organization_id=org_id,
|
||||
name="Test Session",
|
||||
actor_workflow_id=actor_workflow.id,
|
||||
adversary_workflow_id=adversary_workflow.id,
|
||||
config={"test_duration": 10},
|
||||
)
|
||||
|
||||
# Mock the service factories - patch at the actual import location in pipeline_builder
|
||||
with (
|
||||
patch(
|
||||
"api.services.looptalk.core.pipeline_builder.create_stt_service"
|
||||
) as mock_stt_factory,
|
||||
patch(
|
||||
"api.services.looptalk.core.pipeline_builder.create_llm_service"
|
||||
) as mock_llm_factory,
|
||||
patch(
|
||||
"api.services.looptalk.core.pipeline_builder.create_tts_service"
|
||||
) as mock_tts_factory,
|
||||
patch(
|
||||
"api.services.workflow.pipecat_engine.PipecatEngine"
|
||||
) as mock_engine_class,
|
||||
patch(
|
||||
"api.services.pipecat.pipeline_builder.build_pipeline"
|
||||
) as mock_build_pipeline,
|
||||
patch("api.services.pipecat.pipeline_builder.PipelineTask") as mock_task_class,
|
||||
):
|
||||
# Configure mocks
|
||||
mock_stt_factory.return_value = MockSTTService()
|
||||
mock_llm_factory.return_value = MockLLMService()
|
||||
mock_tts_factory.return_value = MockTTSService()
|
||||
|
||||
mock_engine = MagicMock()
|
||||
mock_engine.initialize = AsyncMock()
|
||||
mock_engine.get_callback_processor = MagicMock(return_value=MagicMock())
|
||||
mock_engine_class.return_value = mock_engine
|
||||
|
||||
# Mock pipeline and task
|
||||
mock_pipeline = MagicMock()
|
||||
mock_task = MagicMock()
|
||||
mock_task.run = AsyncMock()
|
||||
mock_task.cancel = AsyncMock() # Make cancel async
|
||||
mock_build_pipeline.return_value = mock_pipeline
|
||||
mock_task_class.return_value = mock_task
|
||||
|
||||
# Create orchestrator
|
||||
orchestrator = LoopTalkTestOrchestrator(db_client=db_session)
|
||||
|
||||
# Start test session (in a separate task to avoid blocking)
|
||||
start_task = asyncio.create_task(
|
||||
orchestrator.start_test_session(
|
||||
test_session_id=test_session.id, organization_id=org_id
|
||||
)
|
||||
)
|
||||
|
||||
# Give it a moment to start
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
# Verify the session is running through session manager
|
||||
session_info = orchestrator.session_manager.get_session(test_session.id)
|
||||
assert session_info is not None
|
||||
assert session_info["test_session"].id == test_session.id
|
||||
assert "actor_task" in session_info
|
||||
assert "adversary_task" in session_info
|
||||
|
||||
# Verify service factories were called
|
||||
assert mock_stt_factory.call_count == 2 # Once for each agent
|
||||
assert mock_llm_factory.call_count == 2
|
||||
assert mock_tts_factory.call_count == 2
|
||||
|
||||
# Verify pipelines were created with PipelineTask
|
||||
assert mock_task_class.call_count == 2
|
||||
|
||||
# Stop the test session
|
||||
await orchestrator.stop_test_session(test_session_id=test_session.id)
|
||||
|
||||
# Verify session was cleaned up
|
||||
assert orchestrator.session_manager.get_session(test_session.id) is None
|
||||
|
||||
# Cancel the start task
|
||||
start_task.cancel()
|
||||
try:
|
||||
await start_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_load_test_creation(
|
||||
test_client_factory,
|
||||
db_session,
|
||||
test_user_with_org,
|
||||
actor_workflow_definition,
|
||||
adversary_workflow_definition,
|
||||
):
|
||||
"""Test creating a load test with multiple sessions."""
|
||||
async with test_client_factory(test_user_with_org) as test_client:
|
||||
# First create two workflows
|
||||
actor_workflow_response = await test_client.post(
|
||||
"/api/v1/workflow/create",
|
||||
json={
|
||||
"name": "Actor Workflow",
|
||||
"workflow_definition": actor_workflow_definition,
|
||||
},
|
||||
)
|
||||
actor_workflow_id = actor_workflow_response.json()["id"]
|
||||
|
||||
adversary_workflow_response = await test_client.post(
|
||||
"/api/v1/workflow/create",
|
||||
json={
|
||||
"name": "Adversary Workflow",
|
||||
"workflow_definition": adversary_workflow_definition,
|
||||
},
|
||||
)
|
||||
adversary_workflow_id = adversary_workflow_response.json()["id"]
|
||||
|
||||
# Create load test
|
||||
response = await test_client.post(
|
||||
"/api/v1/looptalk/load-tests",
|
||||
json={
|
||||
"name_prefix": "Load Test",
|
||||
"actor_workflow_id": actor_workflow_id,
|
||||
"adversary_workflow_id": adversary_workflow_id,
|
||||
"test_count": 3,
|
||||
"config": {"test_duration": 30},
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
assert data["total"] == 3
|
||||
assert "load_test_group_id" in data
|
||||
assert len(data["test_session_ids"]) == 3
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invalid_workflow_ids(
|
||||
test_client_factory, db_session, test_user_with_org
|
||||
):
|
||||
"""Test creating test session with invalid workflow IDs."""
|
||||
async with test_client_factory(test_user_with_org) as test_client:
|
||||
response = await test_client.post(
|
||||
"/api/v1/looptalk/test-sessions",
|
||||
json={
|
||||
"name": "Invalid Test",
|
||||
"actor_workflow_id": 99999,
|
||||
"adversary_workflow_id": 99999,
|
||||
"config": {},
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||
assert "workflow not found" in response.json()["detail"].lower()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_transport_manager():
|
||||
"""Test the internal transport manager functionality."""
|
||||
from pipecat.transports import InternalTransportManager, TransportParams
|
||||
|
||||
manager = InternalTransportManager()
|
||||
|
||||
# Create transport pair
|
||||
params = TransportParams(
|
||||
audio_out_enabled=True,
|
||||
audio_in_enabled=True,
|
||||
audio_out_sample_rate=16000,
|
||||
audio_in_sample_rate=16000,
|
||||
)
|
||||
|
||||
actor_transport, adversary_transport = manager.create_transport_pair(
|
||||
test_session_id="test-123", actor_params=params, adversary_params=params
|
||||
)
|
||||
|
||||
# Verify transports are connected
|
||||
assert actor_transport._output._partner == adversary_transport._input
|
||||
assert adversary_transport._output._partner == actor_transport._input
|
||||
|
||||
# Verify transport pair is tracked
|
||||
assert manager.get_active_test_count() == 1
|
||||
assert manager.get_transport_pair("test-123") is not None
|
||||
|
||||
# Remove transport pair
|
||||
manager.remove_transport_pair("test-123")
|
||||
assert manager.get_active_test_count() == 0
|
||||
assert manager.get_transport_pair("test-123") is None
|
||||
142
api/tests/test_mock_llm_service.py
Normal file
142
api/tests/test_mock_llm_service.py
Normal file
|
|
@ -0,0 +1,142 @@
|
|||
### - The test gets stuck. Need to figure out a way to run the test
|
||||
|
||||
# import asyncio
|
||||
# import unittest
|
||||
|
||||
# from loguru import logger
|
||||
|
||||
# from pipecat.frames.frames import (
|
||||
# FunctionCallFromLLM,
|
||||
# FunctionCallInProgressFrame,
|
||||
# FunctionCallResultFrame,
|
||||
# FunctionCallsStartedFrame,
|
||||
# LLMFullResponseEndFrame,
|
||||
# LLMFullResponseStartFrame,
|
||||
# LLMTextFrame,
|
||||
# )
|
||||
# from pipecat.processors.aggregators.openai_llm_context import (
|
||||
# OpenAILLMContext,
|
||||
# OpenAILLMContextFrame,
|
||||
# )
|
||||
# from pipecat.processors.frame_processor import FrameDirection
|
||||
# from pipecat.services.llm_service import (
|
||||
# FunctionCallParams,
|
||||
# FunctionCallResultProperties,
|
||||
# LLMService,
|
||||
# )
|
||||
# from pipecat.tests.utils import run_test
|
||||
|
||||
|
||||
# class MockLLMService(LLMService):
|
||||
# """A very small mocked LLM service that, upon receiving an
|
||||
# ``OpenAILLMContextFrame``, streams a text completion followed by the
|
||||
# execution of the supplied tools (function calls).
|
||||
# """
|
||||
|
||||
# def __init__(self, *, content: str, tools: list[dict[str, dict]], **kwargs):
|
||||
# # Run function calls sequentially so that frame ordering is deterministic.
|
||||
# super().__init__(run_in_parallel=False, **kwargs)
|
||||
# self._content = content
|
||||
# self._tools = tools
|
||||
|
||||
# async def process_frame(self, frame, direction: FrameDirection):
|
||||
# await super().process_frame(frame, direction)
|
||||
|
||||
# if isinstance(frame, OpenAILLMContextFrame) and direction == FrameDirection.DOWNSTREAM:
|
||||
# # Simulate the start of a streamed completion.
|
||||
# await self.push_frame(LLMFullResponseStartFrame())
|
||||
# await self.push_frame(LLMTextFrame(self._content))
|
||||
|
||||
# # Convert tool specs into FunctionCallFromLLM objects.
|
||||
# function_calls = []
|
||||
# for idx, tool in enumerate(self._tools):
|
||||
# function_calls.append(
|
||||
# FunctionCallFromLLM(
|
||||
# function_name=tool["function_name"],
|
||||
# tool_call_id=f"tool_{idx}",
|
||||
# arguments=tool.get("arguments", {}),
|
||||
# context=frame.context,
|
||||
# )
|
||||
# )
|
||||
|
||||
# # Ask the LLM service base class to execute the calls.
|
||||
# await self.run_function_calls(function_calls)
|
||||
|
||||
# # Finish the streamed response.
|
||||
# await self.push_frame(LLMFullResponseEndFrame())
|
||||
|
||||
# async def _run_function_call(self, runner_item): # type: ignore[override] – narrow signature
|
||||
# # Ensure run_llm=True so that downstream processors know they can
|
||||
# # immediately trigger another LLM call after the result is committed.
|
||||
# runner_item.run_llm = True
|
||||
# await super()._run_function_call(runner_item)
|
||||
|
||||
|
||||
# class TestMockLLMPipeline(unittest.IsolatedAsyncioTestCase):
|
||||
# async def test_mock_llm_pipeline_with_tools(self):
|
||||
# # ------------------------------------------------------------------
|
||||
# # 1. Create mocked LLM service with completion text and tools
|
||||
# # ------------------------------------------------------------------
|
||||
# completion_text = "Hello from mocked LLM!"
|
||||
# tools = [
|
||||
# {"function_name": "tool_one", "arguments": {"a": 1}},
|
||||
# {"function_name": "tool_two", "arguments": {"b": 2}},
|
||||
# ]
|
||||
# llm = MockLLMService(content=completion_text, tools=tools)
|
||||
|
||||
# # ------------------------------------------------------------------
|
||||
# # 2. Register the tool functions – they simply log & sleep briefly.
|
||||
# # Each of them marks that it has run so that we can assert later.
|
||||
# # ------------------------------------------------------------------
|
||||
# executed: dict[str, bool] = {t["function_name"]: False for t in tools}
|
||||
|
||||
# def make_handler(name: str):
|
||||
# async def _handler(params: FunctionCallParams):
|
||||
# logger.debug(f"Executing {name} with args {params.arguments}")
|
||||
# executed[name] = True
|
||||
# await asyncio.sleep(0.01)
|
||||
# await params.result_callback(
|
||||
# {"status": "ok"},
|
||||
# properties=FunctionCallResultProperties(run_llm=True),
|
||||
# )
|
||||
|
||||
# return _handler
|
||||
|
||||
# for t in tools:
|
||||
# llm.register_function(t["function_name"], make_handler(t["function_name"]))
|
||||
|
||||
# # ------------------------------------------------------------------
|
||||
# # 3. Build the pipeline and send the initial context frame that
|
||||
# # triggers the completion.
|
||||
# # ------------------------------------------------------------------
|
||||
# context = OpenAILLMContext()
|
||||
# context.add_message({"role": "user", "content": "Hi!"})
|
||||
# frames_to_send = [OpenAILLMContextFrame(context)]
|
||||
|
||||
# expected_down_frames = [
|
||||
# LLMFullResponseStartFrame,
|
||||
# LLMTextFrame,
|
||||
# FunctionCallsStartedFrame,
|
||||
# FunctionCallInProgressFrame,
|
||||
# FunctionCallResultFrame,
|
||||
# FunctionCallInProgressFrame,
|
||||
# FunctionCallResultFrame,
|
||||
# LLMFullResponseEndFrame,
|
||||
# ]
|
||||
|
||||
# # Run the test pipeline.
|
||||
# received_down_frames, _ = await run_test(
|
||||
# llm,
|
||||
# frames_to_send=frames_to_send,
|
||||
# expected_down_frames=expected_down_frames,
|
||||
# )
|
||||
|
||||
# # ------------------------------------------------------------------
|
||||
# # 4. Verify that both tool functions executed and that run_llm=True
|
||||
# # in all FunctionCallResultFrame instances.
|
||||
# # ------------------------------------------------------------------
|
||||
# self.assertTrue(all(executed.values()))
|
||||
|
||||
# for frame in received_down_frames:
|
||||
# if isinstance(frame, FunctionCallResultFrame):
|
||||
# self.assertTrue(frame.run_llm)
|
||||
236
api/tests/test_pipecat_disposition_mapping.py
Normal file
236
api/tests/test_pipecat_disposition_mapping.py
Normal file
|
|
@ -0,0 +1,236 @@
|
|||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from api.services.workflow.pipecat_engine import PipecatEngine
|
||||
|
||||
|
||||
def create_disposition_mapping_side_effect(mapping_dict):
|
||||
"""Helper to create a side effect function for disposition mapping."""
|
||||
|
||||
async def side_effect(value, org_id):
|
||||
return mapping_dict.get(value, value)
|
||||
|
||||
return side_effect
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_dependencies():
|
||||
"""Create mock dependencies for PipecatEngine."""
|
||||
mock_task = MagicMock()
|
||||
mock_task.queue_frame = AsyncMock()
|
||||
|
||||
mock_llm = MagicMock()
|
||||
mock_context = MagicMock()
|
||||
mock_workflow = MagicMock()
|
||||
|
||||
return {
|
||||
"task": mock_task,
|
||||
"llm": mock_llm,
|
||||
"context": mock_context,
|
||||
"workflow": mock_workflow,
|
||||
"call_context_vars": {},
|
||||
"workflow_run_id": 123,
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_apply_disposition_mapping_with_call_disposition(mock_dependencies):
|
||||
"""Test disposition mapping when call_disposition is present."""
|
||||
engine = PipecatEngine(**mock_dependencies)
|
||||
|
||||
# Setup gathered context
|
||||
engine._gathered_context = {
|
||||
"call_disposition": "XFER",
|
||||
"agent_name": "Alex",
|
||||
"total_debt": "$15000",
|
||||
}
|
||||
|
||||
# Mock the disposition mapper functions
|
||||
with patch(
|
||||
"api.services.workflow.pipecat_engine.get_organization_id_from_workflow_run"
|
||||
) as mock_get_org_id:
|
||||
with patch(
|
||||
"api.services.workflow.pipecat_engine.apply_disposition_mapping"
|
||||
) as mock_apply_mapping:
|
||||
# Mock organization ID
|
||||
mock_get_org_id.return_value = 1
|
||||
|
||||
# Mock disposition mapping
|
||||
mock_apply_mapping.side_effect = create_disposition_mapping_side_effect(
|
||||
{
|
||||
"XFER": "TRANSFERRED",
|
||||
"ND": "NOT_QUALIFIED",
|
||||
}
|
||||
)
|
||||
|
||||
# Call send_end_task_frame
|
||||
await engine.send_end_task_frame(reason="user_qualified")
|
||||
|
||||
# Verify the frame was queued with mapped values
|
||||
mock_dependencies["task"].queue_frame.assert_called_once()
|
||||
frame = mock_dependencies["task"].queue_frame.call_args[0][0]
|
||||
|
||||
# Check metadata contains mapped values
|
||||
assert frame.metadata["reason"] == "user_qualified" # No mapping for this
|
||||
assert (
|
||||
frame.metadata["call_transfer_context"]["disposition"] == "TRANSFERRED"
|
||||
)
|
||||
|
||||
# Check gathered context was updated
|
||||
assert engine._gathered_context["call_disposition"] == "TRANSFERRED"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_apply_disposition_mapping_with_disconnect_reason(mock_dependencies):
|
||||
"""Test disposition mapping for disconnect_reason when no call_disposition exists."""
|
||||
engine = PipecatEngine(**mock_dependencies)
|
||||
|
||||
# Setup gathered context without call_disposition
|
||||
engine._gathered_context = {
|
||||
"agent_name": "Alex",
|
||||
}
|
||||
|
||||
# Mock the disposition mapper functions
|
||||
with patch(
|
||||
"api.services.workflow.pipecat_engine.get_organization_id_from_workflow_run"
|
||||
) as mock_get_org_id:
|
||||
with patch(
|
||||
"api.services.workflow.pipecat_engine.apply_disposition_mapping"
|
||||
) as mock_apply_mapping:
|
||||
# Mock organization ID
|
||||
mock_get_org_id.return_value = 1
|
||||
|
||||
# Mock disposition mapping
|
||||
mock_apply_mapping.side_effect = create_disposition_mapping_side_effect(
|
||||
{
|
||||
"user_qualified": "QUALIFIED",
|
||||
"user_disqualified": "NOT_QUALIFIED",
|
||||
"user_hangup": "HANGUP",
|
||||
}
|
||||
)
|
||||
|
||||
# Call send_end_task_frame with a mappable reason
|
||||
await engine.send_end_task_frame(reason="user_qualified")
|
||||
|
||||
# Verify the frame was queued with mapped disposition
|
||||
mock_dependencies["task"].queue_frame.assert_called_once()
|
||||
frame = mock_dependencies["task"].queue_frame.call_args[0][0]
|
||||
|
||||
# Check metadata contains original reason
|
||||
assert frame.metadata["reason"] == "user_qualified"
|
||||
|
||||
# Check call_transfer_context has mapped disconnect_reason as disposition
|
||||
assert frame.metadata["call_transfer_context"]["disposition"] == "QUALIFIED"
|
||||
|
||||
# Check gathered context was updated with mapped call_disposition
|
||||
assert engine._gathered_context["call_disposition"] == "QUALIFIED"
|
||||
|
||||
# Check internal call_disposition stores mapped value
|
||||
assert engine._call_disposition == "QUALIFIED"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_call_disposition_takes_precedence(mock_dependencies):
|
||||
"""Test that call_disposition is used when both call_disposition and reason could be mapped."""
|
||||
engine = PipecatEngine(**mock_dependencies)
|
||||
|
||||
# Setup gathered context with call_disposition
|
||||
engine._gathered_context = {
|
||||
"call_disposition": "XFER",
|
||||
"agent_name": "Alex",
|
||||
}
|
||||
|
||||
# Mock the disposition mapper functions
|
||||
with patch(
|
||||
"api.services.workflow.pipecat_engine.get_organization_id_from_workflow_run"
|
||||
) as mock_get_org_id:
|
||||
with patch(
|
||||
"api.services.workflow.pipecat_engine.apply_disposition_mapping"
|
||||
) as mock_apply_mapping:
|
||||
# Mock organization ID
|
||||
mock_get_org_id.return_value = 1
|
||||
|
||||
# Mock disposition mapping
|
||||
mock_apply_mapping.side_effect = create_disposition_mapping_side_effect(
|
||||
{
|
||||
"XFER": "TRANSFERRED",
|
||||
"user_qualified": "QUALIFIED",
|
||||
}
|
||||
)
|
||||
|
||||
# Call send_end_task_frame with a reason that could also be mapped
|
||||
await engine.send_end_task_frame(reason="user_qualified")
|
||||
|
||||
# Verify the frame was queued
|
||||
mock_dependencies["task"].queue_frame.assert_called_once()
|
||||
frame = mock_dependencies["task"].queue_frame.call_args[0][0]
|
||||
|
||||
# Check that call_disposition mapping was used, not reason mapping
|
||||
assert (
|
||||
frame.metadata["call_transfer_context"]["disposition"] == "TRANSFERRED"
|
||||
)
|
||||
|
||||
# Check only call_disposition was updated in gathered context
|
||||
assert engine._gathered_context["call_disposition"] == "TRANSFERRED"
|
||||
assert "disconnect_reason" not in engine._gathered_context
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_disposition_mapping_no_organization_id(mock_dependencies):
|
||||
"""Test when organization_id cannot be retrieved."""
|
||||
# Set workflow_run_id to None
|
||||
mock_dependencies["workflow_run_id"] = None
|
||||
engine = PipecatEngine(**mock_dependencies)
|
||||
|
||||
engine._gathered_context = {
|
||||
"call_disposition": "XFER",
|
||||
}
|
||||
|
||||
# Call send_end_task_frame
|
||||
await engine.send_end_task_frame(reason="user_qualified")
|
||||
|
||||
# Verify the frame was queued with original values (no mapping)
|
||||
mock_dependencies["task"].queue_frame.assert_called_once()
|
||||
frame = mock_dependencies["task"].queue_frame.call_args[0][0]
|
||||
|
||||
# Check values remain unchanged
|
||||
assert frame.metadata["reason"] == "user_qualified"
|
||||
assert frame.metadata["call_transfer_context"]["disposition"] == "XFER"
|
||||
|
||||
# Gathered context should remain unchanged
|
||||
assert engine._gathered_context["call_disposition"] == "XFER"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_disposition_mapping_no_configuration(mock_dependencies):
|
||||
"""Test when no disposition mapping is configured."""
|
||||
engine = PipecatEngine(**mock_dependencies)
|
||||
|
||||
engine._gathered_context = {
|
||||
"call_disposition": "XFER",
|
||||
}
|
||||
|
||||
# Mock the disposition mapper functions
|
||||
with patch(
|
||||
"api.services.workflow.pipecat_engine.get_organization_id_from_workflow_run"
|
||||
) as mock_get_org_id:
|
||||
with patch(
|
||||
"api.services.workflow.pipecat_engine.apply_disposition_mapping"
|
||||
) as mock_apply_mapping:
|
||||
# Mock organization ID
|
||||
mock_get_org_id.return_value = 1
|
||||
|
||||
# Mock no disposition mapping (return original value)
|
||||
mock_apply_mapping.side_effect = lambda value, org_id: value
|
||||
|
||||
# Call send_end_task_frame
|
||||
await engine.send_end_task_frame(reason="user_qualified")
|
||||
|
||||
# Verify the frame was queued with original values
|
||||
mock_dependencies["task"].queue_frame.assert_called_once()
|
||||
frame = mock_dependencies["task"].queue_frame.call_args[0][0]
|
||||
|
||||
# Check values remain unchanged
|
||||
assert frame.metadata["reason"] == "user_qualified"
|
||||
assert frame.metadata["call_transfer_context"]["disposition"] == "XFER"
|
||||
206
api/tests/test_pipecat_engine.py
Normal file
206
api/tests/test_pipecat_engine.py
Normal file
|
|
@ -0,0 +1,206 @@
|
|||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
|
||||
from api.services.workflow.pipecat_engine import PipecatEngine
|
||||
from api.services.workflow.workflow import WorkflowGraph
|
||||
|
||||
|
||||
class TestPipecatEngine:
|
||||
@pytest.fixture
|
||||
def mock_dependencies(self):
|
||||
"""Create mock dependencies for PipecatEngine initialization."""
|
||||
return {
|
||||
"task": Mock(),
|
||||
"llm": Mock(),
|
||||
"context": Mock(),
|
||||
"tts": Mock(),
|
||||
"transport": Mock(),
|
||||
"workflow": Mock(spec=WorkflowGraph),
|
||||
"call_context_vars": {},
|
||||
}
|
||||
|
||||
@pytest.fixture
|
||||
def engine_with_context(self, mock_dependencies):
|
||||
"""Create a PipecatEngine instance with test context variables."""
|
||||
context_vars = {
|
||||
"first_name": "John",
|
||||
"last_name": "Doe",
|
||||
"age": 25,
|
||||
"email": "john.doe@example.com",
|
||||
"empty_var": "",
|
||||
"zero_var": 0,
|
||||
"false_var": False,
|
||||
}
|
||||
mock_dependencies["call_context_vars"] = context_vars
|
||||
return PipecatEngine(**mock_dependencies)
|
||||
|
||||
@pytest.fixture
|
||||
def engine_empty_context(self, mock_dependencies):
|
||||
"""Create a PipecatEngine instance with empty context variables."""
|
||||
mock_dependencies["call_context_vars"] = {}
|
||||
return PipecatEngine(**mock_dependencies)
|
||||
|
||||
def test_format_prompt_simple_variable_replacement(self, engine_with_context):
|
||||
"""Test simple variable replacement without filters."""
|
||||
prompt = "Hello {{ first_name }}, welcome!"
|
||||
result = engine_with_context._format_prompt(prompt)
|
||||
assert result == "Hello John, welcome!"
|
||||
|
||||
def test_format_prompt_multiple_variables(self, engine_with_context):
|
||||
"""Test multiple variable replacements in a single prompt."""
|
||||
prompt = "Hello {{ first_name }} {{ last_name }}, you are {{ age }} years old."
|
||||
result = engine_with_context._format_prompt(prompt)
|
||||
assert result == "Hello John Doe, you are 25 years old."
|
||||
|
||||
def test_format_prompt_with_fallback_existing_value(self, engine_with_context):
|
||||
"""Test fallback filter when value exists."""
|
||||
prompt = "Hello {{ first_name | fallback }}, nice to meet you!"
|
||||
result = engine_with_context._format_prompt(prompt)
|
||||
assert result == "Hello John, nice to meet you!"
|
||||
|
||||
def test_format_prompt_with_fallback_missing_value(self, engine_empty_context):
|
||||
"""Test fallback filter when value is missing."""
|
||||
prompt = "Hello {{ first_name | fallback }}, nice to meet you!"
|
||||
result = engine_empty_context._format_prompt(prompt)
|
||||
assert result == "Hello First_Name, nice to meet you!"
|
||||
|
||||
def test_format_prompt_with_custom_fallback_missing_value(
|
||||
self, engine_empty_context
|
||||
):
|
||||
"""Test fallback filter with custom fallback value when variable is missing."""
|
||||
prompt = "Hello {{ first_name | fallback:Guest }}, welcome!"
|
||||
result = engine_empty_context._format_prompt(prompt)
|
||||
assert result == "Hello Guest, welcome!"
|
||||
|
||||
def test_format_prompt_with_custom_fallback_existing_value(
|
||||
self, engine_with_context
|
||||
):
|
||||
"""Test fallback filter with custom fallback value when variable exists."""
|
||||
prompt = "Hello {{ first_name | fallback:Guest }}, welcome!"
|
||||
result = engine_with_context._format_prompt(prompt)
|
||||
assert result == "Hello John, welcome!"
|
||||
|
||||
def test_format_prompt_empty_string_variable(self, engine_with_context):
|
||||
"""Test variable with empty string value."""
|
||||
prompt = "Value: '{{ empty_var | fallback:No Value }}'"
|
||||
result = engine_with_context._format_prompt(prompt)
|
||||
assert result == "Value: 'No Value'"
|
||||
|
||||
def test_format_prompt_zero_value(self, engine_with_context):
|
||||
"""Test variable with zero value (should not trigger fallback)."""
|
||||
prompt = "Count: {{ zero_var | fallback:None }}"
|
||||
result = engine_with_context._format_prompt(prompt)
|
||||
assert result == "Count: 0"
|
||||
|
||||
def test_format_prompt_false_value(self, engine_with_context):
|
||||
"""Test variable with False value (should not trigger fallback)."""
|
||||
prompt = "Status: {{ false_var | fallback:Unknown }}"
|
||||
result = engine_with_context._format_prompt(prompt)
|
||||
assert result == "Status: False"
|
||||
|
||||
def test_format_prompt_missing_variable_no_fallback(self, engine_empty_context):
|
||||
"""Test missing variable without fallback filter."""
|
||||
prompt = "Hello {{ missing_var }}, welcome!"
|
||||
result = engine_empty_context._format_prompt(prompt)
|
||||
assert result == "Hello , welcome!"
|
||||
|
||||
def test_format_prompt_complex_mixed_scenario(self, engine_with_context):
|
||||
"""Test complex scenario with multiple variables, some with fallbacks."""
|
||||
prompt = (
|
||||
"Dear {{ first_name | fallback:Customer }}, "
|
||||
"your email {{ email }} is confirmed. "
|
||||
"{{ missing_info | fallback:Additional information }} will be sent later. "
|
||||
"You are {{ age }} years old."
|
||||
)
|
||||
result = engine_with_context._format_prompt(prompt)
|
||||
expected = (
|
||||
"Dear John, "
|
||||
"your email john.doe@example.com is confirmed. "
|
||||
"Additional information will be sent later. "
|
||||
"You are 25 years old."
|
||||
)
|
||||
assert result == expected
|
||||
|
||||
def test_format_prompt_whitespace_handling(self, engine_with_context):
|
||||
"""Test handling of whitespace in template variables."""
|
||||
prompt = "Hello {{ first_name | fallback : Default }}, welcome!"
|
||||
result = engine_with_context._format_prompt(prompt)
|
||||
assert result == "Hello John, welcome!"
|
||||
|
||||
def test_format_prompt_no_variables(self, engine_with_context):
|
||||
"""Test prompt with no template variables."""
|
||||
prompt = "This is a regular prompt with no variables."
|
||||
result = engine_with_context._format_prompt(prompt)
|
||||
assert result == "This is a regular prompt with no variables."
|
||||
|
||||
def test_format_prompt_empty_prompt(self, engine_with_context):
|
||||
"""Test empty prompt."""
|
||||
prompt = ""
|
||||
result = engine_with_context._format_prompt(prompt)
|
||||
assert result == ""
|
||||
|
||||
def test_format_prompt_none_prompt(self, engine_with_context):
|
||||
"""Test None prompt."""
|
||||
prompt = None
|
||||
result = engine_with_context._format_prompt(prompt)
|
||||
assert result is None
|
||||
|
||||
def test_format_prompt_nested_braces(self, engine_with_context):
|
||||
"""Test handling of nested or malformed braces."""
|
||||
prompt = "Hello {{ first_name }}, this {is not a template} variable."
|
||||
result = engine_with_context._format_prompt(prompt)
|
||||
assert result == "Hello John, this {is not a template} variable."
|
||||
|
||||
def test_format_prompt_special_characters_in_value(self):
|
||||
"""Test variables containing special characters."""
|
||||
mock_deps = {
|
||||
"task": Mock(),
|
||||
"llm": Mock(),
|
||||
"context": Mock(),
|
||||
"tts": Mock(),
|
||||
"transport": Mock(),
|
||||
"workflow": Mock(spec=WorkflowGraph),
|
||||
"call_context_vars": {
|
||||
"special_name": "John & Jane's Company",
|
||||
"email": "test@domain.com",
|
||||
},
|
||||
}
|
||||
engine = PipecatEngine(**mock_deps)
|
||||
|
||||
prompt = "Company: {{ special_name }}, Contact: {{ email }}"
|
||||
result = engine._format_prompt(prompt)
|
||||
assert result == "Company: John & Jane's Company, Contact: test@domain.com"
|
||||
|
||||
def test_format_prompt_numeric_and_boolean_conversion(self):
|
||||
"""Test conversion of different data types to strings."""
|
||||
mock_deps = {
|
||||
"task": Mock(),
|
||||
"llm": Mock(),
|
||||
"context": Mock(),
|
||||
"tts": Mock(),
|
||||
"transport": Mock(),
|
||||
"workflow": Mock(spec=WorkflowGraph),
|
||||
"call_context_vars": {
|
||||
"count": 42,
|
||||
"price": 99.99,
|
||||
"is_active": True,
|
||||
"items": ["apple", "banana"],
|
||||
},
|
||||
}
|
||||
engine = PipecatEngine(**mock_deps)
|
||||
|
||||
prompt = "Count: {{ count }}, Price: ${{ price }}, Active: {{ is_active }}, Items: {{ items }}"
|
||||
result = engine._format_prompt(prompt)
|
||||
assert (
|
||||
result
|
||||
== "Count: 42, Price: $99.99, Active: True, Items: ['apple', 'banana']"
|
||||
)
|
||||
|
||||
def test_format_prompt_case_sensitivity(self, engine_with_context):
|
||||
"""Test that variable names are case sensitive."""
|
||||
prompt = (
|
||||
"Hello {{ First_Name | fallback }}, welcome!" # Note the capitalization
|
||||
)
|
||||
result = engine_with_context._format_prompt(prompt)
|
||||
assert result == "Hello First_Name, welcome!" # Should use fallback
|
||||
536
api/tests/test_pipecat_engine_set_node.py
Normal file
536
api/tests/test_pipecat_engine_set_node.py
Normal file
|
|
@ -0,0 +1,536 @@
|
|||
import asyncio
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
from pipecat.frames.frames import (
|
||||
EndFrame,
|
||||
LLMFullResponseEndFrame,
|
||||
LLMFullResponseStartFrame,
|
||||
TTSSpeakFrame,
|
||||
)
|
||||
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContextFrame
|
||||
from pipecat.services.openai.llm import OpenAILLMContext
|
||||
|
||||
from api.services.workflow.dto import EdgeDataDTO, NodeDataDTO
|
||||
from api.services.workflow.pipecat_engine import PipecatEngine
|
||||
from api.services.workflow.workflow import Edge, Node, WorkflowGraph
|
||||
|
||||
|
||||
class TestPipecatEngineSetNode:
|
||||
"""Test cases for PipecatEngine.set_node method refactoring."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_workflow(self):
|
||||
"""Create a mock workflow with various node types."""
|
||||
workflow = Mock(spec=WorkflowGraph)
|
||||
workflow.nodes = {}
|
||||
workflow.start_node_id = "start_node"
|
||||
workflow.global_node_id = None
|
||||
return workflow
|
||||
|
||||
@pytest.fixture
|
||||
def mock_dependencies(self, mock_workflow):
|
||||
"""Create mock dependencies for PipecatEngine initialization."""
|
||||
task = AsyncMock()
|
||||
task.queue_frames = AsyncMock()
|
||||
task.queue_frame = AsyncMock()
|
||||
|
||||
llm = AsyncMock()
|
||||
llm.register_function = Mock()
|
||||
llm.push_frame = AsyncMock()
|
||||
|
||||
context = Mock(spec=OpenAILLMContext)
|
||||
context.set_node_name = Mock()
|
||||
|
||||
return {
|
||||
"task": task,
|
||||
"llm": llm,
|
||||
"context": context,
|
||||
"tts": Mock(),
|
||||
"transport": Mock(),
|
||||
"workflow": mock_workflow,
|
||||
"call_context_vars": {"test_var": "test_value"},
|
||||
}
|
||||
|
||||
@pytest.fixture
|
||||
def engine(self, mock_dependencies):
|
||||
"""Create a PipecatEngine instance."""
|
||||
# Add audio_buffer and workflow_run_id to dependencies
|
||||
mock_dependencies["audio_buffer"] = None
|
||||
mock_dependencies["workflow_run_id"] = 123
|
||||
engine = PipecatEngine(**mock_dependencies)
|
||||
# Mock the builtin function registration
|
||||
engine._register_builtin_functions = AsyncMock()
|
||||
return engine
|
||||
|
||||
def create_node(self, node_id, **kwargs):
|
||||
"""Helper to create a node with default values."""
|
||||
defaults = {
|
||||
"name": f"Node {node_id}",
|
||||
"prompt": f"Prompt for {node_id}",
|
||||
"is_static": False,
|
||||
"is_start": False,
|
||||
"is_end": False,
|
||||
"allow_interrupt": True,
|
||||
"extraction_enabled": False,
|
||||
"extraction_prompt": "",
|
||||
"extraction_variables": [],
|
||||
"add_global_prompt": True,
|
||||
"wait_for_user_response": False,
|
||||
"detect_voicemail": False,
|
||||
}
|
||||
defaults.update(kwargs)
|
||||
|
||||
data = Mock(spec=NodeDataDTO)
|
||||
for key, value in defaults.items():
|
||||
setattr(data, key, value)
|
||||
|
||||
node = Mock(spec=Node)
|
||||
node.id = node_id
|
||||
node.data = data
|
||||
node.out_edges = []
|
||||
|
||||
# Copy attributes from data to node
|
||||
for key, value in defaults.items():
|
||||
setattr(node, key, value)
|
||||
|
||||
return node
|
||||
|
||||
def create_edge(
|
||||
self, source, target, label="Continue", condition="Always continue"
|
||||
):
|
||||
"""Helper to create an edge."""
|
||||
data = Mock(spec=EdgeDataDTO)
|
||||
data.label = label
|
||||
data.condition = condition
|
||||
|
||||
edge = Mock(spec=Edge)
|
||||
edge.source = source
|
||||
edge.target = target
|
||||
edge.data = data
|
||||
edge.get_function_name = Mock(return_value=label.lower().replace(" ", "_"))
|
||||
|
||||
return edge
|
||||
|
||||
# ===== START NODE TESTS =====
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_node_static_immediate_execution(self, engine, mock_workflow):
|
||||
"""Test: Basic static start node executes immediately."""
|
||||
# Setup
|
||||
start_node = self.create_node(
|
||||
"start_node",
|
||||
is_start=True,
|
||||
is_static=True,
|
||||
prompt="Welcome to our service!",
|
||||
)
|
||||
next_node = self.create_node("next_node", is_static=False)
|
||||
|
||||
edge = self.create_edge("start_node", "next_node")
|
||||
start_node.out_edges = [edge]
|
||||
|
||||
mock_workflow.nodes = {"start_node": start_node, "next_node": next_node}
|
||||
|
||||
# Execute
|
||||
await engine.set_node("start_node")
|
||||
|
||||
# Verify
|
||||
# Should queue TTS immediately
|
||||
engine.task.queue_frames.assert_called_once()
|
||||
frames = engine.task.queue_frames.call_args[0][0]
|
||||
assert len(frames) == 3
|
||||
assert isinstance(frames[0], LLMFullResponseStartFrame)
|
||||
assert isinstance(frames[1], TTSSpeakFrame)
|
||||
assert frames[1].text == "Welcome to our service!"
|
||||
assert isinstance(frames[2], LLMFullResponseEndFrame)
|
||||
|
||||
# Static start nodes now set pending transition after context push
|
||||
assert engine._pending_control_transition_after_context_push is not None
|
||||
|
||||
# Should not have set detect_voicemail for static start without it
|
||||
assert not engine._detect_voicemail
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_node_with_detect_voicemail_no_audio_buffer(
|
||||
self, engine, mock_workflow
|
||||
):
|
||||
"""Test: Start node with voicemail detection but no audio buffer logs warning."""
|
||||
# Setup
|
||||
start_node = self.create_node(
|
||||
"start_node",
|
||||
is_start=True,
|
||||
is_static=True,
|
||||
detect_voicemail=True,
|
||||
prompt="Hello, this is a business call.",
|
||||
)
|
||||
|
||||
mock_workflow.nodes = {"start_node": start_node}
|
||||
|
||||
# Engine has no audio buffer (None)
|
||||
assert engine._audio_buffer is None
|
||||
|
||||
# Execute
|
||||
await engine.set_node("start_node")
|
||||
|
||||
# Verify
|
||||
# Should NOT set voicemail detection flag since no audio buffer
|
||||
assert engine._detect_voicemail is False
|
||||
assert engine._voicemail_detector is None
|
||||
|
||||
# Should queue TTS immediately
|
||||
engine.task.queue_frames.assert_called_once()
|
||||
frames = engine.task.queue_frames.call_args[0][0]
|
||||
assert isinstance(frames[1], TTSSpeakFrame)
|
||||
assert frames[1].text == "Hello, this is a business call."
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_node_non_static_with_detect_voicemail(
|
||||
self, engine, mock_workflow
|
||||
):
|
||||
"""Test: Non-static start node with voicemail detection without audio buffer."""
|
||||
# Setup
|
||||
start_node = self.create_node(
|
||||
"start_node",
|
||||
is_start=True,
|
||||
is_static=False, # Non-static
|
||||
detect_voicemail=True,
|
||||
prompt="You are an AI assistant. Start the conversation.",
|
||||
)
|
||||
|
||||
mock_workflow.nodes = {"start_node": start_node}
|
||||
|
||||
# Mock the context update method
|
||||
engine._update_llm_context = AsyncMock()
|
||||
engine._compose_system_message_functions_for_node = AsyncMock(
|
||||
return_value=({"role": "system", "content": "Test prompt"}, [])
|
||||
)
|
||||
|
||||
# Execute
|
||||
await engine.set_node("start_node")
|
||||
|
||||
# Verify
|
||||
# Should NOT set voicemail detection flags (no audio buffer)
|
||||
assert engine._detect_voicemail is False
|
||||
assert engine._voicemail_detector is None
|
||||
|
||||
# Should update LLM context for non-static node
|
||||
engine._update_llm_context.assert_called_once()
|
||||
|
||||
# Should queue context frame
|
||||
engine.task.queue_frame.assert_called_once()
|
||||
frame = engine.task.queue_frame.call_args[0][0]
|
||||
assert isinstance(frame, OpenAILLMContextFrame)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_node_static_with_wait_for_user_response(
|
||||
self, engine, mock_workflow
|
||||
):
|
||||
"""Test: Static start node with wait_for_user_response."""
|
||||
# Setup
|
||||
start_node = self.create_node(
|
||||
"start_node",
|
||||
is_start=True,
|
||||
is_static=True,
|
||||
wait_for_user_response=True,
|
||||
prompt="Please tell me your name.",
|
||||
)
|
||||
next_node = self.create_node("next_node")
|
||||
|
||||
edge = self.create_edge("start_node", "next_node")
|
||||
start_node.out_edges = [edge]
|
||||
|
||||
mock_workflow.nodes = {"start_node": start_node, "next_node": next_node}
|
||||
|
||||
# Execute
|
||||
await engine.set_node("start_node")
|
||||
|
||||
# Verify
|
||||
# Should queue TTS immediately
|
||||
engine.task.queue_frames.assert_called_once()
|
||||
|
||||
# Should have a pending control transition that will start the timer
|
||||
assert engine._pending_control_transition_after_context_push is not None
|
||||
|
||||
# Timer task should not exist yet
|
||||
assert (
|
||||
not hasattr(engine, "_user_response_timeout_task")
|
||||
or engine._user_response_timeout_task is None
|
||||
)
|
||||
|
||||
# Simulate context push to start the timer
|
||||
await engine.flush_pending_transitions(source="context_push")
|
||||
|
||||
# Now the timeout task should be created
|
||||
assert engine._user_response_timeout_task is not None
|
||||
assert not engine._user_response_timeout_task.done()
|
||||
|
||||
# Clean up the task
|
||||
engine._user_response_timeout_task.cancel()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_node_non_static(self, engine, mock_workflow):
|
||||
"""Test: Non-static start node sends context to LLM."""
|
||||
# Setup
|
||||
start_node = self.create_node(
|
||||
"start_node",
|
||||
is_start=True,
|
||||
is_static=False,
|
||||
prompt="You are a helpful assistant. Greet the user.",
|
||||
)
|
||||
|
||||
mock_workflow.nodes = {"start_node": start_node}
|
||||
|
||||
# Mock the context update method
|
||||
engine._update_llm_context = AsyncMock()
|
||||
engine._compose_system_message_functions_for_node = AsyncMock(
|
||||
return_value=({"role": "system", "content": "Test prompt"}, [])
|
||||
)
|
||||
|
||||
# Execute
|
||||
await engine.set_node("start_node")
|
||||
|
||||
# Verify
|
||||
# Should set context name
|
||||
engine.context.set_node_name.assert_called_once_with("Node start_node")
|
||||
|
||||
# Should update LLM context
|
||||
engine._update_llm_context.assert_called_once()
|
||||
|
||||
# Should queue context frame
|
||||
engine.task.queue_frame.assert_called_once()
|
||||
frame = engine.task.queue_frame.call_args[0][0]
|
||||
assert isinstance(frame, OpenAILLMContextFrame)
|
||||
|
||||
# ===== AGENT NODE TESTS =====
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_agent_node_static(self, engine, mock_workflow):
|
||||
"""Test: Static agent node plays TTS and transitions."""
|
||||
# Setup
|
||||
agent_node = self.create_node(
|
||||
"agent_node", is_static=True, prompt="Processing your request..."
|
||||
)
|
||||
next_node = self.create_node("next_node")
|
||||
|
||||
edge = self.create_edge("agent_node", "next_node")
|
||||
agent_node.out_edges = [edge]
|
||||
|
||||
mock_workflow.nodes = {"agent_node": agent_node, "next_node": next_node}
|
||||
|
||||
# Execute
|
||||
await engine.set_node("agent_node")
|
||||
|
||||
# Verify
|
||||
# Should queue TTS
|
||||
engine.task.queue_frames.assert_called_once()
|
||||
frames = engine.task.queue_frames.call_args[0][0]
|
||||
assert isinstance(frames[1], TTSSpeakFrame)
|
||||
assert frames[1].text == "Processing your request..."
|
||||
|
||||
# Should have pending transition
|
||||
assert engine._pending_control_transition_after_context_push is not None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_agent_node_non_static(self, engine, mock_workflow):
|
||||
"""Test: Non-static agent node sends context to LLM."""
|
||||
# Setup
|
||||
agent_node = self.create_node(
|
||||
"agent_node",
|
||||
is_static=False,
|
||||
prompt="Analyze the user's request and respond appropriately.",
|
||||
)
|
||||
decision_node = self.create_node("decision_node")
|
||||
|
||||
edge = self.create_edge("agent_node", "decision_node", "analyze_complete")
|
||||
agent_node.out_edges = [edge]
|
||||
|
||||
mock_workflow.nodes = {"agent_node": agent_node, "decision_node": decision_node}
|
||||
|
||||
# Mock methods
|
||||
engine._update_llm_context = AsyncMock()
|
||||
engine._compose_system_message_functions_for_node = AsyncMock(
|
||||
return_value=(
|
||||
{"role": "system", "content": "Test"},
|
||||
[{"name": "test_func"}],
|
||||
)
|
||||
)
|
||||
|
||||
# Execute
|
||||
await engine.set_node("agent_node")
|
||||
|
||||
# Verify
|
||||
# Should register transition function
|
||||
engine.llm.register_function.assert_called_once()
|
||||
call_args = engine.llm.register_function.call_args
|
||||
assert call_args[0][0] == "analyze_complete"
|
||||
assert callable(call_args[0][1]) # Check it's a function
|
||||
assert call_args[1]["cancel_on_interruption"] is True
|
||||
|
||||
# Should update context and send frame
|
||||
engine._update_llm_context.assert_called_once()
|
||||
engine.task.queue_frame.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_agent_node_with_interruption_control(self, engine, mock_workflow):
|
||||
"""Test: Agent node respects allow_interrupt flag."""
|
||||
# Setup
|
||||
no_interrupt_node = self.create_node(
|
||||
"no_interrupt",
|
||||
is_static=True,
|
||||
allow_interrupt=False,
|
||||
prompt="Please wait while I process...",
|
||||
)
|
||||
|
||||
mock_workflow.nodes = {"no_interrupt": no_interrupt_node}
|
||||
|
||||
# Execute
|
||||
await engine.set_node("no_interrupt")
|
||||
|
||||
# Verify current node is set (for STT mute callback)
|
||||
assert engine._current_node == no_interrupt_node
|
||||
assert engine._current_node.allow_interrupt is False
|
||||
|
||||
# ===== END NODE TESTS =====
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_end_node_static(self, engine, mock_workflow):
|
||||
"""Test: Static end node plays final message and schedules end task."""
|
||||
# Setup
|
||||
end_node = self.create_node(
|
||||
"end_node",
|
||||
is_static=True,
|
||||
is_end=True,
|
||||
prompt="Thank you for calling. Goodbye!",
|
||||
)
|
||||
|
||||
mock_workflow.nodes = {"end_node": end_node}
|
||||
|
||||
# Execute
|
||||
await engine.set_node("end_node")
|
||||
|
||||
# Verify
|
||||
# Should queue TTS
|
||||
engine.task.queue_frames.assert_called_once()
|
||||
frames = engine.task.queue_frames.call_args[0][0]
|
||||
assert frames[1].text == "Thank you for calling. Goodbye!"
|
||||
|
||||
# Should have pending end task
|
||||
assert engine._pending_control_transition_after_context_push is not None
|
||||
|
||||
# Execute the pending transition
|
||||
await engine._pending_control_transition_after_context_push()
|
||||
|
||||
# Should have sent EndFrame via task.queue_frame
|
||||
# The second call should be the EndFrame (first was TTS frames)
|
||||
assert engine.task.queue_frame.call_count >= 1
|
||||
end_frame = engine.task.queue_frame.call_args[0][0]
|
||||
assert isinstance(end_frame, EndFrame)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_end_node_with_extraction(self, engine, mock_workflow):
|
||||
"""Test: End node with variable extraction."""
|
||||
# Setup
|
||||
end_node = self.create_node(
|
||||
"end_node",
|
||||
is_end=True,
|
||||
is_static=False,
|
||||
extraction_enabled=True,
|
||||
extraction_variables=["user_name", "satisfaction_level"],
|
||||
extraction_prompt="Extract user name and satisfaction",
|
||||
)
|
||||
|
||||
mock_workflow.nodes = {"end_node": end_node}
|
||||
|
||||
# Mock the extraction manager
|
||||
engine._variable_extraction_manager = Mock()
|
||||
engine._perform_variable_extraction_if_needed = AsyncMock()
|
||||
|
||||
# Mock context update and composition methods
|
||||
engine._update_llm_context = AsyncMock()
|
||||
engine._compose_system_message_functions_for_node = AsyncMock(
|
||||
return_value=({"role": "system", "content": "Test"}, [])
|
||||
)
|
||||
|
||||
# Execute
|
||||
await engine.set_node("end_node")
|
||||
|
||||
# Verify
|
||||
# Should trigger extraction
|
||||
engine._perform_variable_extraction_if_needed.assert_called_once_with(end_node)
|
||||
|
||||
# Should have pending end task
|
||||
assert engine._pending_control_transition_after_context_push is not None
|
||||
|
||||
# ===== CALLBACK INTEGRATION TESTS =====
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_user_stopped_speaking_during_response_wait(
|
||||
self, engine, mock_workflow
|
||||
):
|
||||
"""Test: User stops speaking triggers transition during wait_for_response."""
|
||||
# Setup
|
||||
start_node = self.create_node(
|
||||
"start_node", is_start=True, is_static=True, wait_for_user_response=True
|
||||
)
|
||||
next_node = self.create_node("next_node")
|
||||
|
||||
edge = self.create_edge("start_node", "next_node")
|
||||
start_node.out_edges = [edge]
|
||||
|
||||
mock_workflow.nodes = {"start_node": start_node, "next_node": next_node}
|
||||
|
||||
# Set current node to start node
|
||||
engine._current_node = start_node
|
||||
engine._user_response_timeout_task = asyncio.create_task(asyncio.sleep(3))
|
||||
|
||||
# Create callback and execute
|
||||
callback = engine.create_user_stopped_speaking_callback()
|
||||
|
||||
# Mock set_node to avoid recursion
|
||||
with patch.object(engine, "set_node", new=AsyncMock()) as mock_set_node:
|
||||
await callback()
|
||||
|
||||
# Verify
|
||||
mock_set_node.assert_called_once_with("next_node")
|
||||
assert engine._queue_context_frame is False # Should be set to False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_context_push_callback_executes_pending_transitions(self, engine):
|
||||
"""Test: flush_pending_transitions executes deferred transitions."""
|
||||
# Setup pending transitions
|
||||
mock_generated_transition = AsyncMock()
|
||||
mock_control_transition = AsyncMock()
|
||||
|
||||
engine._pending_generated_transition_after_context_push = (
|
||||
mock_generated_transition
|
||||
)
|
||||
engine._pending_control_transition_after_context_push = mock_control_transition
|
||||
|
||||
# Execute
|
||||
await engine.flush_pending_transitions(source="context_push")
|
||||
|
||||
# Verify both transitions were executed
|
||||
mock_generated_transition.assert_called_once()
|
||||
mock_control_transition.assert_called_once()
|
||||
|
||||
# Verify they were cleared
|
||||
assert engine._pending_generated_transition_after_context_push is None
|
||||
assert engine._pending_control_transition_after_context_push is None
|
||||
|
||||
# ===== COMPLEX SCENARIO TESTS =====
|
||||
|
||||
|
||||
# Add helper for testing with real async behavior
|
||||
def ANY(cls=None):
|
||||
"""Helper for matching any argument in mock calls."""
|
||||
|
||||
class AnyMatcher:
|
||||
def __init__(self, cls):
|
||||
self.cls = cls
|
||||
|
||||
def __eq__(self, other):
|
||||
if self.cls:
|
||||
return isinstance(other, self.cls)
|
||||
return True
|
||||
|
||||
return AnyMatcher(cls)
|
||||
266
api/tests/test_run_integrations_db_client.py
Normal file
266
api/tests/test_run_integrations_db_client.py
Normal file
|
|
@ -0,0 +1,266 @@
|
|||
"""Tests for run_integrations with new DB client methods."""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from api.enums import WorkflowRunMode
|
||||
from api.tasks.run_integrations import run_integrations_post_workflow_run
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_logger():
|
||||
"""Mock the logger for all tests."""
|
||||
with patch("api.tasks.run_integrations.logger") as mock_logger:
|
||||
mock_logger.bind.return_value = mock_logger
|
||||
yield mock_logger
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_workflow_run():
|
||||
"""Create a mock workflow run with all required attributes."""
|
||||
workflow_run = MagicMock()
|
||||
workflow_run.id = 1
|
||||
workflow_run.mode = "browser"
|
||||
workflow_run.gathered_context = {
|
||||
"call_disposition": "XFER",
|
||||
"mapped_call_disposition": "XFER", # Required for Slack integration
|
||||
"call_duration": "120",
|
||||
"agent_name": "TestAgent",
|
||||
}
|
||||
workflow_run.initial_context = {"vendor_id": "123"}
|
||||
|
||||
# Setup workflow and user chain
|
||||
workflow_run.workflow = MagicMock()
|
||||
workflow_run.workflow.user = MagicMock()
|
||||
workflow_run.workflow.user.selected_organization_id = 100
|
||||
|
||||
return workflow_run
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_integration():
|
||||
"""Create a mock integration."""
|
||||
integration = MagicMock()
|
||||
integration.id = 1
|
||||
integration.organisation_id = 100
|
||||
integration.provider = "slack"
|
||||
integration.is_active = True
|
||||
integration.connection_details = {
|
||||
"connection_config": {"incoming_webhook.url": "https://hooks.slack.com/test"}
|
||||
}
|
||||
return integration
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_integrations_with_db_client_methods(
|
||||
mock_workflow_run, mock_integration
|
||||
):
|
||||
"""Test that run_integrations uses the new DB client methods correctly."""
|
||||
|
||||
with patch("api.tasks.run_integrations.set_current_run_id") as mock_set_run_id:
|
||||
with patch("api.tasks.run_integrations.db_client") as mock_db_client:
|
||||
# Mock the new DB client methods
|
||||
mock_db_client.get_workflow_run_with_context = AsyncMock(
|
||||
return_value=(mock_workflow_run, 100)
|
||||
)
|
||||
mock_db_client.get_active_integrations_by_organization = AsyncMock(
|
||||
return_value=[mock_integration]
|
||||
)
|
||||
mock_db_client.get_configuration_value = AsyncMock(
|
||||
return_value={
|
||||
"slack": {
|
||||
"DISPOSITION_CODE": "Disposition: {{mapped_call_disposition}}"
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
# Mock the aiohttp session for Slack webhook
|
||||
with patch(
|
||||
"api.tasks.run_integrations.aiohttp.ClientSession"
|
||||
) as mock_session_class:
|
||||
mock_response = MagicMock()
|
||||
mock_response.status = 200
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_session.__aenter__.return_value = mock_session
|
||||
mock_session.__aexit__.return_value = AsyncMock()
|
||||
|
||||
mock_post = MagicMock()
|
||||
mock_post.__aenter__.return_value = mock_response
|
||||
mock_post.__aexit__.return_value = AsyncMock()
|
||||
|
||||
mock_session.post.return_value = mock_post
|
||||
mock_session_class.return_value = mock_session
|
||||
|
||||
# Call the function
|
||||
await run_integrations_post_workflow_run(None, 1)
|
||||
|
||||
# Verify the correct DB client methods were called
|
||||
mock_set_run_id.assert_called_once_with(1)
|
||||
mock_db_client.get_workflow_run_with_context.assert_called_once_with(1)
|
||||
mock_db_client.get_active_integrations_by_organization.assert_called_once_with(
|
||||
100
|
||||
)
|
||||
|
||||
# Verify the Slack webhook was called
|
||||
mock_session.post.assert_called_once()
|
||||
assert (
|
||||
mock_session.post.call_args[0][0] == "https://hooks.slack.com/test"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_integrations_no_workflow_run():
|
||||
"""Test handling when workflow run is not found."""
|
||||
|
||||
with patch("api.tasks.run_integrations.set_current_run_id"):
|
||||
with patch("api.tasks.run_integrations.db_client") as mock_db_client:
|
||||
# Mock workflow run not found
|
||||
mock_db_client.get_workflow_run_with_context = AsyncMock(
|
||||
return_value=(None, None)
|
||||
)
|
||||
|
||||
# Call the function
|
||||
await run_integrations_post_workflow_run(None, 999)
|
||||
|
||||
# Verify it returns early and doesn't call other DB methods
|
||||
mock_db_client.get_workflow_run_with_context.assert_called_once_with(999)
|
||||
mock_db_client.get_active_integrations_by_organization.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_integrations_no_organization():
|
||||
"""Test handling when user has no organization."""
|
||||
|
||||
mock_workflow_run = MagicMock()
|
||||
mock_workflow_run.id = 1
|
||||
mock_workflow_run.gathered_context = {"test": "data"}
|
||||
mock_workflow_run.workflow = MagicMock()
|
||||
mock_workflow_run.workflow.user = MagicMock()
|
||||
|
||||
with patch("api.tasks.run_integrations.set_current_run_id"):
|
||||
with patch("api.tasks.run_integrations.db_client") as mock_db_client:
|
||||
# Mock workflow run found but no organization
|
||||
mock_db_client.get_workflow_run_with_context = AsyncMock(
|
||||
return_value=(mock_workflow_run, None)
|
||||
)
|
||||
|
||||
# Call the function
|
||||
await run_integrations_post_workflow_run(None, 1)
|
||||
|
||||
# Verify it returns early after checking organization
|
||||
mock_db_client.get_workflow_run_with_context.assert_called_once_with(1)
|
||||
mock_db_client.get_active_integrations_by_organization.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_integrations_no_gathered_context(mock_workflow_run):
|
||||
"""Test handling when workflow run has no gathered context."""
|
||||
|
||||
mock_workflow_run.gathered_context = None
|
||||
|
||||
with patch("api.tasks.run_integrations.set_current_run_id"):
|
||||
with patch("api.tasks.run_integrations.db_client") as mock_db_client:
|
||||
# Mock workflow run with no gathered context
|
||||
mock_db_client.get_workflow_run_with_context = AsyncMock(
|
||||
return_value=(mock_workflow_run, 100)
|
||||
)
|
||||
|
||||
# Call the function
|
||||
await run_integrations_post_workflow_run(None, 1)
|
||||
|
||||
# Verify it returns early after checking gathered_context
|
||||
mock_db_client.get_workflow_run_with_context.assert_called_once_with(1)
|
||||
mock_db_client.get_active_integrations_by_organization.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_integrations_stasis_mode(mock_workflow_run):
|
||||
"""Test that stasis mode triggers vendor sync."""
|
||||
|
||||
mock_workflow_run.mode = WorkflowRunMode.STASIS.value
|
||||
mock_workflow_run.initial_context = {
|
||||
"vendor": "test_vendor",
|
||||
"vendor_base_url": "https://api.vendor.com",
|
||||
"vendor_id": "123",
|
||||
}
|
||||
|
||||
with patch("api.tasks.run_integrations.set_current_run_id"):
|
||||
with patch("api.tasks.run_integrations.db_client") as mock_db_client:
|
||||
with patch("api.tasks.run_integrations._sync_vendor_data") as mock_sync:
|
||||
mock_sync.return_value = None
|
||||
|
||||
mock_db_client.get_workflow_run_with_context = AsyncMock(
|
||||
return_value=(mock_workflow_run, 100)
|
||||
)
|
||||
mock_db_client.get_active_integrations_by_organization = AsyncMock(
|
||||
return_value=[]
|
||||
)
|
||||
|
||||
# Call the function
|
||||
await run_integrations_post_workflow_run(None, 1)
|
||||
|
||||
# Verify vendor sync was called
|
||||
mock_sync.assert_called_once_with(
|
||||
mock_workflow_run.initial_context,
|
||||
mock_workflow_run.gathered_context,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_integrations_multiple_integrations(mock_workflow_run):
|
||||
"""Test processing multiple integrations."""
|
||||
|
||||
# Create multiple mock integrations
|
||||
slack_integration = MagicMock()
|
||||
slack_integration.provider = "slack"
|
||||
slack_integration.connection_details = {
|
||||
"connection_config": {"incoming_webhook.url": "https://hooks.slack.com/test1"}
|
||||
}
|
||||
|
||||
slack_integration2 = MagicMock()
|
||||
slack_integration2.provider = "slack"
|
||||
slack_integration2.connection_details = {
|
||||
"connection_config": {"incoming_webhook.url": "https://hooks.slack.com/test2"}
|
||||
}
|
||||
|
||||
with patch("api.tasks.run_integrations.set_current_run_id"):
|
||||
with patch("api.tasks.run_integrations.db_client") as mock_db_client:
|
||||
mock_db_client.get_workflow_run_with_context = AsyncMock(
|
||||
return_value=(mock_workflow_run, 100)
|
||||
)
|
||||
mock_db_client.get_active_integrations_by_organization = AsyncMock(
|
||||
return_value=[slack_integration, slack_integration2]
|
||||
)
|
||||
mock_db_client.get_configuration_value = AsyncMock(
|
||||
return_value={"slack": {"DISPOSITION_CODE": "Test message"}}
|
||||
)
|
||||
|
||||
with patch(
|
||||
"api.tasks.run_integrations.aiohttp.ClientSession"
|
||||
) as mock_session_class:
|
||||
mock_response = MagicMock()
|
||||
mock_response.status = 200
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_session.__aenter__.return_value = mock_session
|
||||
mock_session.__aexit__.return_value = AsyncMock()
|
||||
|
||||
mock_post = MagicMock()
|
||||
mock_post.__aenter__.return_value = mock_response
|
||||
mock_post.__aexit__.return_value = AsyncMock()
|
||||
|
||||
mock_session.post.return_value = mock_post
|
||||
mock_session_class.return_value = mock_session
|
||||
|
||||
# Call the function
|
||||
await run_integrations_post_workflow_run(None, 1)
|
||||
|
||||
# Verify both integrations were processed
|
||||
assert mock_session.post.call_count == 2
|
||||
|
||||
# Check that both webhooks were called
|
||||
call_urls = [call[0][0] for call in mock_session.post.call_args_list]
|
||||
assert "https://hooks.slack.com/test1" in call_urls
|
||||
assert "https://hooks.slack.com/test2" in call_urls
|
||||
136
api/tests/test_run_integrations_template.py
Normal file
136
api/tests/test_run_integrations_template.py
Normal file
|
|
@ -0,0 +1,136 @@
|
|||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from api.tasks.run_integrations import _process_slack_integration
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_logger():
|
||||
"""Mock the logger for all tests."""
|
||||
with patch("api.tasks.run_integrations.logger") as mock_logger:
|
||||
# Mock the bind method to return the logger itself
|
||||
mock_logger.bind.return_value = mock_logger
|
||||
yield mock_logger
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_slack_integration_with_template():
|
||||
"""Test that Slack integration uses render_template correctly."""
|
||||
# Mock integration
|
||||
mock_integration = MagicMock()
|
||||
mock_integration.id = 1
|
||||
mock_integration.organisation_id = 123
|
||||
mock_integration.connection_details = {
|
||||
"connection_config": {"incoming_webhook.url": "https://hooks.slack.com/test"}
|
||||
}
|
||||
|
||||
# Mock gathered context
|
||||
gathered_context = {
|
||||
"call_disposition": "XFER",
|
||||
"mapped_call_disposition": "XFER", # Required for Slack integration to proceed
|
||||
"call_duration": "300",
|
||||
"agent_name": "Alex",
|
||||
}
|
||||
|
||||
# Mock db_client
|
||||
with patch("api.tasks.run_integrations.db_client") as mock_db_client:
|
||||
# Mock message template configuration
|
||||
mock_db_client.get_configuration_value = AsyncMock(
|
||||
return_value={
|
||||
"slack": {
|
||||
"DISPOSITION_CODE": "Agent: {{agent_name}}\\nDisposition: {{call_disposition}}\\nDuration: {{call_duration}}s"
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
# Mock aiohttp session
|
||||
with patch(
|
||||
"api.tasks.run_integrations.aiohttp.ClientSession"
|
||||
) as mock_session_class:
|
||||
mock_response = MagicMock()
|
||||
mock_response.status = 200
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_session.__aenter__.return_value = mock_session
|
||||
mock_session.__aexit__.return_value = AsyncMock()
|
||||
|
||||
mock_post = MagicMock()
|
||||
mock_post.__aenter__.return_value = mock_response
|
||||
mock_post.__aexit__.return_value = AsyncMock()
|
||||
|
||||
mock_session.post.return_value = mock_post
|
||||
mock_session_class.return_value = mock_session
|
||||
|
||||
# Call the function
|
||||
await _process_slack_integration(mock_integration, gathered_context)
|
||||
|
||||
# Verify the message was formatted correctly
|
||||
mock_session.post.assert_called_once()
|
||||
call_args = mock_session.post.call_args
|
||||
|
||||
# Check the webhook URL
|
||||
assert call_args[0][0] == "https://hooks.slack.com/test"
|
||||
|
||||
# Check the message content
|
||||
json_data = call_args[1]["json"]
|
||||
|
||||
# Check that the template was rendered correctly
|
||||
expected_text = "Agent: Alex\nDisposition: XFER\nDuration: 300s"
|
||||
assert json_data["text"] == expected_text
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_slack_integration_with_missing_template_vars():
|
||||
"""Test template rendering with missing variables."""
|
||||
# Mock integration
|
||||
mock_integration = MagicMock()
|
||||
mock_integration.id = 1
|
||||
mock_integration.organisation_id = 123
|
||||
mock_integration.connection_details = {
|
||||
"connection_config": {"incoming_webhook.url": "https://hooks.slack.com/test"}
|
||||
}
|
||||
|
||||
# Mock gathered context with missing values
|
||||
gathered_context = {
|
||||
"call_disposition": "XFER",
|
||||
"mapped_call_disposition": "XFER", # Required for Slack integration to proceed
|
||||
# call_duration is missing
|
||||
}
|
||||
|
||||
# Mock db_client
|
||||
with patch("api.tasks.run_integrations.db_client") as mock_db_client:
|
||||
# Mock message template configuration with fallback
|
||||
mock_db_client.get_configuration_value = AsyncMock(
|
||||
return_value={
|
||||
"slack": {
|
||||
"DISPOSITION_CODE": "Disposition: {{call_disposition}}\\nDuration: {{call_duration | fallback:N/A}}"
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
# Mock aiohttp session
|
||||
with patch(
|
||||
"api.tasks.run_integrations.aiohttp.ClientSession"
|
||||
) as mock_session_class:
|
||||
mock_response = MagicMock()
|
||||
mock_response.status = 200
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_session.__aenter__.return_value = mock_session
|
||||
mock_session.__aexit__.return_value = AsyncMock()
|
||||
|
||||
mock_post = MagicMock()
|
||||
mock_post.__aenter__.return_value = mock_response
|
||||
mock_post.__aexit__.return_value = AsyncMock()
|
||||
|
||||
mock_session.post.return_value = mock_post
|
||||
mock_session_class.return_value = mock_session
|
||||
|
||||
# Call the function
|
||||
await _process_slack_integration(mock_integration, gathered_context)
|
||||
|
||||
# Check that the template was rendered with fallback
|
||||
json_data = mock_session.post.call_args[1]["json"]
|
||||
expected_text = "Disposition: XFER\nDuration: N/A"
|
||||
assert json_data["text"] == expected_text
|
||||
117
api/tests/test_s3_signed_url.py
Normal file
117
api/tests/test_s3_signed_url.py
Normal file
|
|
@ -0,0 +1,117 @@
|
|||
"""Tests for the `/s3/signed-url` endpoint.
|
||||
|
||||
This test-suite verifies:
|
||||
1. Regular users can retrieve signed URLs for resources belonging to their own workflow runs.
|
||||
2. Regular users are *forbidden* from accessing resources that belong to other users.
|
||||
3. Superusers can access any resource irrespective of ownership.
|
||||
"""
|
||||
|
||||
import os
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
from fastapi import status
|
||||
|
||||
# Ensure the S3 environment variables exist so that the module import does not fail
|
||||
os.environ.setdefault("S3_BUCKET", "test-bucket")
|
||||
os.environ.setdefault("S3_REGION", "us-east-1")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_signed_url_for_own_run(monkeypatch, test_client_factory, db_session):
|
||||
"""A normal user should be able to fetch a signed URL for their own workflow run."""
|
||||
from api.db.models import UserModel
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 1. Set-up – create user, workflow & workflow run
|
||||
# ------------------------------------------------------------------
|
||||
user: UserModel = await db_session.get_or_create_user_by_provider_id("user_own_run")
|
||||
workflow = await db_session.create_workflow("wf", {}, user.id)
|
||||
run = await db_session.create_workflow_run("run", workflow.id, "chat", user.id)
|
||||
|
||||
key = f"transcripts/{run.id}.txt"
|
||||
|
||||
# Patch S3 signed-url generator to avoid network calls
|
||||
monkeypatch.setattr(
|
||||
"api.services.filesystem.s3.s3_fs.aget_signed_url",
|
||||
AsyncMock(return_value="https://signed-url"),
|
||||
)
|
||||
|
||||
async with test_client_factory(user) as client:
|
||||
response = await client.get(f"/api/v1/s3/signed-url?key={key}")
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
assert data == {"url": "https://signed-url", "expires_in": 3600}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_signed_url_for_other_users_run_forbidden(
|
||||
monkeypatch, test_client_factory, db_session
|
||||
):
|
||||
"""A normal user must *not* access workflow runs owned by someone else."""
|
||||
from api.db.models import UserModel
|
||||
|
||||
# Owner of the workflow run
|
||||
owner: UserModel = await db_session.get_or_create_user_by_provider_id("owner_user")
|
||||
workflow = await db_session.create_workflow("wf", {}, owner.id)
|
||||
run = await db_session.create_workflow_run("run", workflow.id, "chat", owner.id)
|
||||
|
||||
# Second user attempting access
|
||||
intruder: UserModel = await db_session.get_or_create_user_by_provider_id(
|
||||
"intruder_user"
|
||||
)
|
||||
|
||||
key = f"recordings/{run.id}.wav"
|
||||
|
||||
monkeypatch.setattr(
|
||||
"api.services.filesystem.s3.s3_fs.aget_signed_url",
|
||||
AsyncMock(return_value="https://signed-url"),
|
||||
)
|
||||
|
||||
async with test_client_factory(intruder) as client:
|
||||
response = await client.get(f"/api/v1/s3/signed-url?key={key}")
|
||||
|
||||
assert response.status_code == status.HTTP_403_FORBIDDEN
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_superuser_can_access_any_run(
|
||||
monkeypatch, test_client_factory, db_session
|
||||
):
|
||||
"""Superusers should be able to fetch signed URLs for any workflow run."""
|
||||
from api.db.models import UserModel
|
||||
|
||||
# Normal user & run owner
|
||||
owner: UserModel = await db_session.get_or_create_user_by_provider_id(
|
||||
"owner_of_run"
|
||||
)
|
||||
workflow = await db_session.create_workflow("wf", {}, owner.id)
|
||||
run = await db_session.create_workflow_run("run", workflow.id, "chat", owner.id)
|
||||
|
||||
# Superuser
|
||||
superuser: UserModel = await db_session.get_or_create_user_by_provider_id(
|
||||
"admin_user"
|
||||
)
|
||||
|
||||
# Promote to superuser
|
||||
# We need to commit the change so that the DB reflects it
|
||||
async with db_session.async_session() as session:
|
||||
db_user = await session.get(UserModel, superuser.id)
|
||||
db_user.is_superuser = True
|
||||
await session.commit()
|
||||
await session.refresh(db_user) # ensure we have the latest state
|
||||
superuser.is_superuser = True
|
||||
|
||||
key = f"transcripts/{run.id}.txt"
|
||||
|
||||
monkeypatch.setattr(
|
||||
"api.services.filesystem.s3.s3_fs.aget_signed_url",
|
||||
AsyncMock(return_value="https://signed-url"),
|
||||
)
|
||||
|
||||
async with test_client_factory(superuser) as client:
|
||||
response = await client.get(f"/api/v1/s3/signed-url?key={key}")
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
assert response.json()["url"] == "https://signed-url"
|
||||
129
api/tests/test_s3_upload_tasks.py
Normal file
129
api/tests/test_s3_upload_tasks.py
Normal file
|
|
@ -0,0 +1,129 @@
|
|||
import os
|
||||
import tempfile
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from api.tasks.s3_upload import upload_audio_to_s3, upload_transcript_to_s3
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_upload_audio_to_s3_success():
|
||||
"""Test successful audio upload to S3."""
|
||||
# Create a temporary file
|
||||
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tf:
|
||||
tf.write(b"fake audio data")
|
||||
temp_path = tf.name
|
||||
|
||||
try:
|
||||
# Mock dependencies
|
||||
mock_ctx = AsyncMock()
|
||||
mock_s3_fs = AsyncMock()
|
||||
mock_db_client = AsyncMock()
|
||||
|
||||
with (
|
||||
patch("api.tasks.s3_upload.s3_fs", mock_s3_fs),
|
||||
patch("api.tasks.s3_upload.db_client", mock_db_client),
|
||||
):
|
||||
await upload_audio_to_s3(
|
||||
mock_ctx, workflow_run_id=123, temp_file_path=temp_path
|
||||
)
|
||||
|
||||
# Verify S3 upload was called
|
||||
mock_s3_fs.aupload_file.assert_called_once_with(
|
||||
temp_path, "recordings/123.wav"
|
||||
)
|
||||
|
||||
# Verify DB update was called
|
||||
mock_db_client.update_workflow_run.assert_called_once_with(
|
||||
run_id=123, recording_url="recordings/123.wav"
|
||||
)
|
||||
|
||||
# Verify temp file was cleaned up
|
||||
assert not os.path.exists(temp_path)
|
||||
|
||||
finally:
|
||||
# Clean up if test failed
|
||||
if os.path.exists(temp_path):
|
||||
os.remove(temp_path)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_upload_audio_to_s3_file_not_found():
|
||||
"""Test audio upload when temp file doesn't exist."""
|
||||
mock_ctx = AsyncMock()
|
||||
|
||||
with pytest.raises(FileNotFoundError):
|
||||
await upload_audio_to_s3(
|
||||
mock_ctx, workflow_run_id=123, temp_file_path="/nonexistent/file.wav"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_upload_transcript_to_s3_success():
|
||||
"""Test successful transcript upload to S3."""
|
||||
# Create a temporary file
|
||||
with tempfile.NamedTemporaryFile(mode="w", suffix=".txt", delete=False) as tf:
|
||||
tf.write("Test transcript content")
|
||||
temp_path = tf.name
|
||||
|
||||
try:
|
||||
# Mock dependencies
|
||||
mock_ctx = AsyncMock()
|
||||
mock_s3_fs = AsyncMock()
|
||||
mock_db_client = AsyncMock()
|
||||
|
||||
with (
|
||||
patch("api.tasks.s3_upload.s3_fs", mock_s3_fs),
|
||||
patch("api.tasks.s3_upload.db_client", mock_db_client),
|
||||
):
|
||||
await upload_transcript_to_s3(
|
||||
mock_ctx, workflow_run_id=456, temp_file_path=temp_path
|
||||
)
|
||||
|
||||
# Verify S3 upload was called
|
||||
mock_s3_fs.aupload_file.assert_called_once_with(
|
||||
temp_path, "transcripts/456.txt"
|
||||
)
|
||||
|
||||
# Verify DB update was called
|
||||
mock_db_client.update_workflow_run.assert_called_once_with(
|
||||
run_id=456, transcript_url="transcripts/456.txt"
|
||||
)
|
||||
|
||||
# Verify temp file was cleaned up
|
||||
assert not os.path.exists(temp_path)
|
||||
|
||||
finally:
|
||||
# Clean up if test failed
|
||||
if os.path.exists(temp_path):
|
||||
os.remove(temp_path)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_upload_s3_cleanup_on_error():
|
||||
"""Test that temp files are cleaned up even when S3 upload fails."""
|
||||
# Create a temporary file
|
||||
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tf:
|
||||
tf.write(b"fake audio data")
|
||||
temp_path = tf.name
|
||||
|
||||
try:
|
||||
mock_ctx = AsyncMock()
|
||||
mock_s3_fs = AsyncMock()
|
||||
# Make S3 upload fail
|
||||
mock_s3_fs.aupload_file.side_effect = Exception("S3 upload failed")
|
||||
|
||||
with patch("api.tasks.s3_upload.s3_fs", mock_s3_fs):
|
||||
with pytest.raises(Exception):
|
||||
await upload_audio_to_s3(
|
||||
mock_ctx, workflow_run_id=123, temp_file_path=temp_path
|
||||
)
|
||||
|
||||
# Verify temp file was still cleaned up
|
||||
assert not os.path.exists(temp_path)
|
||||
|
||||
finally:
|
||||
# Clean up if test failed
|
||||
if os.path.exists(temp_path):
|
||||
os.remove(temp_path)
|
||||
89
api/tests/test_template_renderer.py
Normal file
89
api/tests/test_template_renderer.py
Normal file
|
|
@ -0,0 +1,89 @@
|
|||
from api.utils.template_renderer import render_template
|
||||
|
||||
|
||||
def test_render_template_basic():
|
||||
"""Test basic template rendering."""
|
||||
template = "Hello {{name}}, your balance is {{balance}}."
|
||||
context = {"name": "John", "balance": "$1000"}
|
||||
|
||||
result = render_template(template, context)
|
||||
assert result == "Hello John, your balance is $1000."
|
||||
|
||||
|
||||
def test_render_template_with_spaces():
|
||||
"""Test template rendering with spaces around variables."""
|
||||
template = "Hello {{ name }}, your balance is {{ balance }}."
|
||||
context = {"name": "John", "balance": "$1000"}
|
||||
|
||||
result = render_template(template, context)
|
||||
assert result == "Hello John, your balance is $1000."
|
||||
|
||||
|
||||
def test_render_template_missing_variable():
|
||||
"""Test template rendering with missing variables."""
|
||||
template = "Hello {{name}}, your balance is {{balance}}."
|
||||
context = {"name": "John"}
|
||||
|
||||
result = render_template(template, context)
|
||||
assert result == "Hello John, your balance is ."
|
||||
|
||||
|
||||
def test_render_template_with_fallback():
|
||||
"""Test template rendering with fallback values."""
|
||||
template = "Hello {{name | fallback}}, your balance is {{balance | fallback:$0}}."
|
||||
context = {}
|
||||
|
||||
result = render_template(template, context)
|
||||
assert result == "Hello Name, your balance is $0."
|
||||
|
||||
|
||||
def test_render_template_with_fallback_existing_value():
|
||||
"""Test that fallback is not used when value exists."""
|
||||
template = "Hello {{name | fallback:Guest}}"
|
||||
context = {"name": "John"}
|
||||
|
||||
result = render_template(template, context)
|
||||
assert result == "Hello John"
|
||||
|
||||
|
||||
def test_render_template_with_line_breaks():
|
||||
"""Test template rendering with line breaks."""
|
||||
template = (
|
||||
"DISPOSITION_CODE: {{call_disposition}}\\nCALL_DURATION: {{call_duration}}"
|
||||
)
|
||||
context = {"call_disposition": "XFER", "call_duration": "300"}
|
||||
|
||||
result = render_template(template, context)
|
||||
expected = "DISPOSITION_CODE: XFER\nCALL_DURATION: 300"
|
||||
assert result == expected
|
||||
|
||||
|
||||
def test_render_template_empty():
|
||||
"""Test rendering empty template."""
|
||||
assert render_template("", {}) == ""
|
||||
assert render_template(None, {}) == None
|
||||
|
||||
|
||||
def test_render_template_no_placeholders():
|
||||
"""Test template with no placeholders."""
|
||||
template = "This is a plain text message"
|
||||
result = render_template(template, {"unused": "value"})
|
||||
assert result == "This is a plain text message"
|
||||
|
||||
|
||||
def test_render_template_none_values():
|
||||
"""Test template with None values."""
|
||||
template = "Value: {{value}}"
|
||||
context = {"value": None}
|
||||
|
||||
result = render_template(template, context)
|
||||
assert result == "Value: "
|
||||
|
||||
|
||||
def test_render_template_numeric_values():
|
||||
"""Test template with numeric values."""
|
||||
template = "Count: {{count}}, Price: {{price}}"
|
||||
context = {"count": 42, "price": 19.99}
|
||||
|
||||
result = render_template(template, context)
|
||||
assert result == "Count: 42, Price: 19.99"
|
||||
152
api/tests/test_usage_concurrency.py
Normal file
152
api/tests/test_usage_concurrency.py
Normal file
|
|
@ -0,0 +1,152 @@
|
|||
#!/usr/bin/env python
|
||||
"""
|
||||
Test script to verify atomic operations in organization_usage_client.py
|
||||
This simulates concurrent access from multiple processes.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
from concurrent.futures import ProcessPoolExecutor
|
||||
|
||||
# Set up environment
|
||||
os.environ.setdefault("DATABASE_URL", os.environ.get("DATABASE_URL", ""))
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from api.db.organization_usage_client import OrganizationUsageClient
|
||||
|
||||
|
||||
async def reserve_quota_process(org_id: int, tokens: int, process_id: int):
|
||||
"""Simulate a process trying to reserve quota."""
|
||||
engine = create_async_engine(os.environ["DATABASE_URL"])
|
||||
async_session = sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)
|
||||
|
||||
client = OrganizationUsageClient(async_session)
|
||||
|
||||
results = []
|
||||
for i in range(5):
|
||||
result = await client.check_and_reserve_quota(org_id, tokens)
|
||||
results.append((process_id, i, result))
|
||||
await asyncio.sleep(0.01) # Small delay to increase contention
|
||||
|
||||
await engine.dispose()
|
||||
return results
|
||||
|
||||
|
||||
async def update_usage_process(org_id: int, tokens: int, process_id: int):
|
||||
"""Simulate a process updating usage after runs."""
|
||||
engine = create_async_engine(os.environ["DATABASE_URL"])
|
||||
async_session = sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)
|
||||
|
||||
client = OrganizationUsageClient(async_session)
|
||||
|
||||
for i in range(5):
|
||||
await client.update_usage_after_run(org_id, tokens, duration_seconds=10)
|
||||
await asyncio.sleep(0.01)
|
||||
|
||||
await engine.dispose()
|
||||
return f"Process {process_id} completed updates"
|
||||
|
||||
|
||||
def run_reserve_quota(args):
|
||||
"""Wrapper to run async function in process."""
|
||||
org_id, tokens, process_id = args
|
||||
return asyncio.run(reserve_quota_process(org_id, tokens, process_id))
|
||||
|
||||
|
||||
def run_update_usage(args):
|
||||
"""Wrapper to run async function in process."""
|
||||
org_id, tokens, process_id = args
|
||||
return asyncio.run(update_usage_process(org_id, tokens, process_id))
|
||||
|
||||
|
||||
async def test_concurrent_quota_reservation():
|
||||
"""Test that concurrent quota reservations are handled atomically."""
|
||||
print("Testing concurrent quota reservations...")
|
||||
|
||||
# Assuming org_id 1 exists with quota enabled
|
||||
org_id = 1
|
||||
tokens_per_request = 100
|
||||
|
||||
# Run multiple processes trying to reserve quota simultaneously
|
||||
with ProcessPoolExecutor(max_workers=3) as executor:
|
||||
futures = []
|
||||
for i in range(3):
|
||||
futures.append(
|
||||
executor.submit(run_reserve_quota, (org_id, tokens_per_request, i))
|
||||
)
|
||||
|
||||
results = []
|
||||
for future in futures:
|
||||
results.extend(future.result())
|
||||
|
||||
print(f"Reservation results: {results}")
|
||||
|
||||
# Check that reservations were handled atomically
|
||||
successful_reservations = sum(1 for _, _, success in results if success)
|
||||
print(f"Successful reservations: {successful_reservations}")
|
||||
|
||||
|
||||
async def test_concurrent_usage_updates():
|
||||
"""Test that concurrent usage updates are handled atomically."""
|
||||
print("\nTesting concurrent usage updates...")
|
||||
|
||||
org_id = 1
|
||||
tokens_per_update = 50
|
||||
|
||||
# Get initial usage
|
||||
engine = create_async_engine(os.environ["DATABASE_URL"])
|
||||
async_session = sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)
|
||||
client = OrganizationUsageClient(async_session)
|
||||
|
||||
initial_usage = await client.get_current_usage(org_id)
|
||||
initial_tokens = initial_usage["used_dograh_tokens"]
|
||||
print(f"Initial tokens: {initial_tokens}")
|
||||
|
||||
# Run multiple processes updating usage simultaneously
|
||||
with ProcessPoolExecutor(max_workers=3) as executor:
|
||||
futures = []
|
||||
for i in range(3):
|
||||
futures.append(
|
||||
executor.submit(run_update_usage, (org_id, tokens_per_update, i))
|
||||
)
|
||||
|
||||
for future in futures:
|
||||
print(future.result())
|
||||
|
||||
# Check final usage
|
||||
final_usage = await client.get_current_usage(org_id)
|
||||
final_tokens = final_usage["used_dograh_tokens"]
|
||||
expected_tokens = initial_tokens + (
|
||||
3 * 5 * tokens_per_update
|
||||
) # 3 processes * 5 updates * 50 tokens
|
||||
|
||||
print(f"Final tokens: {final_tokens}")
|
||||
print(f"Expected tokens: {expected_tokens}")
|
||||
print(f"Difference: {final_tokens - expected_tokens}")
|
||||
|
||||
await engine.dispose()
|
||||
|
||||
if final_tokens == expected_tokens:
|
||||
print("✅ All updates were applied atomically!")
|
||||
else:
|
||||
print("❌ Some updates were lost due to race conditions!")
|
||||
|
||||
|
||||
async def main():
|
||||
"""Run all concurrency tests."""
|
||||
try:
|
||||
await test_concurrent_quota_reservation()
|
||||
await test_concurrent_usage_updates()
|
||||
except Exception as e:
|
||||
print(f"Error during testing: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
print("Starting organization usage concurrency tests...")
|
||||
print(f"Using DATABASE_URL: {os.environ.get('DATABASE_URL', 'NOT SET')}")
|
||||
asyncio.run(main())
|
||||
140
api/tests/test_variable_extraction.py
Normal file
140
api/tests/test_variable_extraction.py
Normal file
|
|
@ -0,0 +1,140 @@
|
|||
import json
|
||||
import os
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
from pipecat.services.openai.llm import OpenAILLMContext
|
||||
|
||||
from api.services.workflow.dto import ExtractionVariableDTO, VariableType
|
||||
from api.services.workflow.pipecat_engine_variable_extractor import (
|
||||
VariableExtractionManager,
|
||||
)
|
||||
|
||||
|
||||
class DummyLLM:
|
||||
"""A minimal stub that mimics the parts of an LLM service used by the extractor."""
|
||||
|
||||
def __init__(self, streamed_response: str | None = None):
|
||||
# Optionally provide a pre-defined streaming response for _perform_extraction tests
|
||||
self._streamed_response = streamed_response or "{}"
|
||||
self.registered_functions: dict[str, AsyncMock] = {}
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# API used by VariableExtractionManager
|
||||
# ------------------------------------------------------------------
|
||||
def register_function(self, name: str, func, cancel_on_interruption=True): # noqa: D401 – simple delegate
|
||||
self.registered_functions[name] = func
|
||||
|
||||
async def get_chat_completions(self, _context, _messages):
|
||||
"""Return an async generator that yields a single chunk with the full response."""
|
||||
|
||||
class _Delta: # noqa: D401 – tiny helper classes for stub response
|
||||
def __init__(self, content):
|
||||
self.content = content
|
||||
|
||||
class _Choice:
|
||||
def __init__(self, delta):
|
||||
self.delta = delta
|
||||
|
||||
class _Chunk:
|
||||
def __init__(self, content):
|
||||
self.choices = [_Choice(_Delta(content))]
|
||||
|
||||
async def _stream():
|
||||
yield _Chunk(self._streamed_response)
|
||||
|
||||
return _stream()
|
||||
|
||||
|
||||
class DummyEngine:
|
||||
"""A bare-bones Engine stub exposing only what the extractor relies on."""
|
||||
|
||||
def __init__(self, llm):
|
||||
self.llm = llm
|
||||
self.context = OpenAILLMContext()
|
||||
self._pending_function_calls = 0
|
||||
# VariableExtractionManager currently updates this private attribute
|
||||
self._gathered_context: dict = {}
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Tests
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_perform_extraction_parses_json_correctly():
|
||||
"""_perform_extraction should return the parsed JSON from the LLM stream."""
|
||||
# Set dummy OpenAI API key to prevent initialization errors
|
||||
with patch.dict(os.environ, {"OPENAI_API_KEY": "test-key"}):
|
||||
expected_payload = {"name": "Alice", "age": 30}
|
||||
llm = DummyLLM(json.dumps(expected_payload))
|
||||
engine = DummyEngine(llm)
|
||||
manager = VariableExtractionManager(engine)
|
||||
|
||||
# Mock the AsyncOpenAI client and its response
|
||||
mock_response = AsyncMock()
|
||||
mock_response.choices = [AsyncMock()]
|
||||
mock_response.choices[0].message = AsyncMock()
|
||||
mock_response.choices[0].message.content = json.dumps(expected_payload)
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.chat.completions.create.return_value = mock_response
|
||||
|
||||
with patch(
|
||||
"api.services.workflow.pipecat_engine_variable_extractor.AsyncOpenAI",
|
||||
return_value=mock_client,
|
||||
):
|
||||
# Minimal set of variables to extract – the prompts themselves are irrelevant here
|
||||
extraction_variables = [
|
||||
ExtractionVariableDTO(
|
||||
name="name", type=VariableType.string, prompt="user name"
|
||||
),
|
||||
ExtractionVariableDTO(
|
||||
name="age", type=VariableType.number, prompt="user age"
|
||||
),
|
||||
]
|
||||
|
||||
result = await manager._perform_extraction(
|
||||
extraction_variables, parent_ctx=None, extraction_prompt=""
|
||||
)
|
||||
|
||||
assert result == expected_payload
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_perform_extraction_with_custom_system_prompt():
|
||||
"""_perform_extraction should use the provided extraction_prompt as system prompt."""
|
||||
# Set dummy OpenAI API key to prevent initialization errors
|
||||
with patch.dict(os.environ, {"OPENAI_API_KEY": "test-key"}):
|
||||
expected_payload = {"color": "blue"}
|
||||
llm = DummyLLM(json.dumps(expected_payload))
|
||||
engine = DummyEngine(llm)
|
||||
manager = VariableExtractionManager(engine)
|
||||
|
||||
# Mock the AsyncOpenAI client and its response
|
||||
mock_response = AsyncMock()
|
||||
mock_response.choices = [AsyncMock()]
|
||||
mock_response.choices[0].message = AsyncMock()
|
||||
mock_response.choices[0].message.content = json.dumps(expected_payload)
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.chat.completions.create.return_value = mock_response
|
||||
|
||||
with patch(
|
||||
"api.services.workflow.pipecat_engine_variable_extractor.AsyncOpenAI",
|
||||
return_value=mock_client,
|
||||
):
|
||||
extraction_variables = [
|
||||
ExtractionVariableDTO(
|
||||
name="color", type=VariableType.string, prompt="favourite color"
|
||||
)
|
||||
]
|
||||
|
||||
# Call with a custom extraction prompt
|
||||
custom_prompt = "You are a color extraction specialist."
|
||||
result = await manager._perform_extraction(
|
||||
extraction_variables, parent_ctx=None, extraction_prompt=custom_prompt
|
||||
)
|
||||
|
||||
assert result == expected_payload
|
||||
547
api/tests/test_voicemail_detection_rtc.py
Normal file
547
api/tests/test_voicemail_detection_rtc.py
Normal file
|
|
@ -0,0 +1,547 @@
|
|||
"""
|
||||
Test voicemail detection in RTC connection flow.
|
||||
|
||||
This test emulates how a call is connected using SmallWebRTC,
|
||||
triggers voicemail detection, and verifies the disconnect reason.
|
||||
"""
|
||||
|
||||
import json
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
from pipecat.utils.enums import EndTaskReason
|
||||
|
||||
from api.routes.rtc_offer import RTCOfferRequest, offer
|
||||
from api.services.workflow.pipecat_engine_voicemail_detector import VoicemailDetector
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestVoicemailDetectionRTC:
|
||||
"""Test voicemail detection through RTC connection flow."""
|
||||
|
||||
async def test_voicemail_detection_full_flow(self):
|
||||
"""
|
||||
Test complete voicemail detection flow:
|
||||
1. RTC connection request
|
||||
2. Transport sends on_client_connected event
|
||||
3. Engine initializes with voicemail detection enabled
|
||||
4. Voicemail detector returns true
|
||||
5. Call terminates with voicemail_detected reason
|
||||
6. Transport sends on_client_disconnected event
|
||||
7. Disconnect reason is properly set
|
||||
"""
|
||||
# Mock user and authentication
|
||||
mock_user = Mock()
|
||||
mock_user.id = 1
|
||||
mock_user.organization_id = 1
|
||||
|
||||
# Mock workflow with voicemail detection enabled
|
||||
mock_workflow = Mock()
|
||||
mock_workflow.id = 100
|
||||
mock_workflow.workflow_definition_with_fallback = {
|
||||
"edges": [],
|
||||
"nodes": [
|
||||
{
|
||||
"id": "start",
|
||||
"type": "start",
|
||||
"data": {
|
||||
"detect_voicemail": True,
|
||||
"system_prompt": "You are a helpful assistant",
|
||||
},
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
# Mock workflow run
|
||||
mock_workflow_run = Mock()
|
||||
mock_workflow_run.id = 200
|
||||
mock_workflow_run.is_completed = False
|
||||
|
||||
# Create request
|
||||
request = RTCOfferRequest(
|
||||
pc_id="test_pc_123",
|
||||
sdp="test_sdp_offer",
|
||||
type="offer",
|
||||
workflow_id=mock_workflow.id,
|
||||
workflow_run_id=mock_workflow_run.id,
|
||||
restart_pc=False,
|
||||
call_context_vars={"test_var": "test_value"},
|
||||
)
|
||||
|
||||
# Mock dependencies
|
||||
with (
|
||||
patch("api.services.auth.depends.get_user") as mock_get_user_dep,
|
||||
patch("api.routes.rtc_offer.SmallWebRTCConnection") as MockWebRTCConnection,
|
||||
patch("api.routes.rtc_offer.run_pipeline_smallwebrtc") as mock_run_pipeline,
|
||||
):
|
||||
# Setup mocks
|
||||
mock_get_user_dep.return_value = mock_user
|
||||
|
||||
# Mock WebRTC connection
|
||||
mock_connection = Mock()
|
||||
mock_connection.pc_id = "test_pc_123"
|
||||
mock_connection.initialize = AsyncMock()
|
||||
mock_connection.get_answer = Mock(
|
||||
return_value={
|
||||
"pc_id": "test_pc_123",
|
||||
"sdp": "test_sdp_answer",
|
||||
"type": "answer",
|
||||
}
|
||||
)
|
||||
MockWebRTCConnection.return_value = mock_connection
|
||||
|
||||
# Track registered event handlers
|
||||
registered_handlers = {}
|
||||
|
||||
def mock_event_handler(event_name):
|
||||
def decorator(func):
|
||||
registered_handlers[event_name] = func
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
mock_connection.event_handler = mock_event_handler
|
||||
|
||||
# Mock BackgroundTasks
|
||||
mock_background_tasks = Mock()
|
||||
|
||||
# Create the offer
|
||||
response = await offer(request, mock_background_tasks, mock_user)
|
||||
|
||||
# Verify response
|
||||
assert response["pc_id"] == "test_pc_123"
|
||||
assert response["type"] == "answer"
|
||||
|
||||
# Verify connection was initialized
|
||||
mock_connection.initialize.assert_called_once_with(
|
||||
sdp="test_sdp_offer", type="offer"
|
||||
)
|
||||
|
||||
# Verify background task was added
|
||||
mock_background_tasks.add_task.assert_called_once()
|
||||
task_args = mock_background_tasks.add_task.call_args[0]
|
||||
assert task_args[0] == mock_run_pipeline
|
||||
assert task_args[1] == mock_connection
|
||||
assert task_args[2] == mock_workflow.id
|
||||
assert task_args[3] == mock_workflow_run.id
|
||||
assert task_args[4] == mock_user.id
|
||||
assert task_args[5] == {"test_var": "test_value"}
|
||||
|
||||
async def test_voicemail_detection_in_pipeline(self):
|
||||
"""Tests whether the updates happen in on_client_disconnected properly
|
||||
with values set in the engine"""
|
||||
# Mock components
|
||||
mock_transport = AsyncMock()
|
||||
mock_engine = Mock() # Use Mock instead of AsyncMock for engine
|
||||
mock_engine.initialize = AsyncMock()
|
||||
mock_engine.cleanup = AsyncMock()
|
||||
mock_audio_buffer = AsyncMock()
|
||||
mock_task = AsyncMock()
|
||||
mock_aggregator = Mock()
|
||||
|
||||
# Setup engine with voicemail detector
|
||||
mock_voicemail_detector = AsyncMock(spec=VoicemailDetector)
|
||||
mock_engine.voicemail_detector = mock_voicemail_detector
|
||||
mock_engine.get_call_disposition = Mock(
|
||||
return_value=EndTaskReason.VOICEMAIL_DETECTED.value
|
||||
)
|
||||
mock_engine.get_gathered_context = Mock(
|
||||
return_value={
|
||||
"voicemail_transcript": "Hi, you've reached John's voicemail. Please leave a message.",
|
||||
"voicemail_confidence": 0.95,
|
||||
}
|
||||
)
|
||||
|
||||
# Mock usage metrics
|
||||
mock_aggregator.get_all_usage_metrics_serialized.return_value = {}
|
||||
|
||||
# Register event handlers
|
||||
from api.services.pipecat.event_handlers import (
|
||||
register_transport_event_handlers,
|
||||
)
|
||||
|
||||
# Track registered handlers
|
||||
handlers = {}
|
||||
|
||||
def track_handler(event_name):
|
||||
def decorator(func):
|
||||
handlers[event_name] = func
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
mock_transport.event_handler = track_handler
|
||||
|
||||
# Create a mock db_client module with update_workflow_run method
|
||||
mock_db_client = Mock()
|
||||
mock_db_client.update_workflow_run = AsyncMock()
|
||||
|
||||
with (
|
||||
patch("api.services.pipecat.event_handlers.db_client", mock_db_client),
|
||||
patch(
|
||||
"api.services.pipecat.event_handlers.enqueue_job",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_enqueue_job,
|
||||
patch(
|
||||
"api.services.pipecat.event_handlers.get_organization_id_from_workflow_run",
|
||||
return_value=1,
|
||||
),
|
||||
patch(
|
||||
"api.services.pipecat.event_handlers.apply_disposition_mapping",
|
||||
side_effect=lambda value, org_id: value, # Return value unchanged
|
||||
),
|
||||
):
|
||||
# Register handlers
|
||||
register_transport_event_handlers(
|
||||
mock_transport,
|
||||
workflow_run_id=123,
|
||||
audio_buffer=mock_audio_buffer,
|
||||
task=mock_task,
|
||||
engine=mock_engine,
|
||||
usage_metrics_aggregator=mock_aggregator,
|
||||
)
|
||||
|
||||
# Verify handlers were registered
|
||||
assert "on_client_connected" in handlers
|
||||
assert "on_client_disconnected" in handlers
|
||||
|
||||
# Simulate client connection
|
||||
await handlers["on_client_connected"](
|
||||
mock_transport, {"id": "participant_1"}
|
||||
)
|
||||
|
||||
# Verify initialization
|
||||
mock_audio_buffer.start_recording.assert_called_once()
|
||||
mock_engine.initialize.assert_called_once()
|
||||
|
||||
# Simulate voicemail detection and disconnect
|
||||
await handlers["on_client_disconnected"](
|
||||
mock_transport, {"id": "participant_1"}, None
|
||||
)
|
||||
|
||||
# Verify engine cleanup
|
||||
mock_engine.cleanup.assert_called_once()
|
||||
|
||||
# TODO: check whether task was cancelled or not once have more
|
||||
# clarity on how to handle engine disconnect vs remote hangup
|
||||
# Verify task was NOT cancelled (engine disconnect)
|
||||
# mock_task.cancel.assert_not_called()
|
||||
|
||||
# Verify workflow run was updated with voicemail context
|
||||
mock_db_client.update_workflow_run.assert_called()
|
||||
call_args = mock_db_client.update_workflow_run.call_args
|
||||
assert call_args[1]["run_id"] == 123
|
||||
# Check that the mapped_call_disposition was set correctly
|
||||
assert (
|
||||
call_args[1]["gathered_context"]["mapped_call_disposition"]
|
||||
== "voicemail_detected"
|
||||
)
|
||||
|
||||
async def test_voicemail_detector_audio_processing(self):
|
||||
"""Test VoicemailDetector audio processing and detection logic - tests that voicemail detector
|
||||
calls engine's send_end_task_frame with the correct reason and metadata"""
|
||||
# Create voicemail detector
|
||||
detector = VoicemailDetector(detection_duration=5.0, workflow_run_id=123)
|
||||
|
||||
# Mock OpenAI client
|
||||
mock_openai = AsyncMock()
|
||||
mock_whisper_response = Mock()
|
||||
mock_whisper_response.text = "Hi, you've reached the voicemail of John Smith. Please leave a message after the beep."
|
||||
mock_openai.audio.transcriptions.create.return_value = mock_whisper_response
|
||||
|
||||
mock_gpt_response = Mock()
|
||||
mock_gpt_response.choices = [Mock()]
|
||||
mock_gpt_response.choices[0].message.content = json.dumps(
|
||||
{
|
||||
"is_voicemail": True,
|
||||
"confidence": 0.98,
|
||||
"reasoning": "Clear voicemail greeting with request to leave message",
|
||||
}
|
||||
)
|
||||
mock_openai.chat.completions.create.return_value = mock_gpt_response
|
||||
|
||||
# Mock engine
|
||||
mock_engine = AsyncMock()
|
||||
mock_engine.task = AsyncMock()
|
||||
|
||||
with (
|
||||
patch(
|
||||
"api.services.workflow.pipecat_engine_voicemail_detector.AsyncOpenAI",
|
||||
return_value=mock_openai,
|
||||
),
|
||||
patch(
|
||||
"api.services.workflow.pipecat_engine_voicemail_detector.s3_fs"
|
||||
) as mock_s3,
|
||||
):
|
||||
# Mock S3 upload to return None (simulating successful upload)
|
||||
mock_s3.aupload_file = AsyncMock(return_value=True)
|
||||
# Start detection
|
||||
await detector.start_detection(mock_engine)
|
||||
assert detector.is_detecting == True
|
||||
|
||||
# Simulate audio data (16kHz, mono, 5 seconds)
|
||||
sample_rate = 16000
|
||||
duration = 5.0
|
||||
audio_data = b"\x00\x00" * int(sample_rate * duration) # Silent audio
|
||||
|
||||
# Process audio in chunks
|
||||
chunk_size = 1600 # 100ms chunks
|
||||
for i in range(0, len(audio_data), chunk_size):
|
||||
chunk = audio_data[i : i + chunk_size]
|
||||
await detector.handle_audio_data(None, chunk, sample_rate, 1)
|
||||
|
||||
# Wait for detection to complete
|
||||
if detector._detection_task:
|
||||
await detector._detection_task
|
||||
|
||||
# Verify OpenAI calls
|
||||
mock_openai.audio.transcriptions.create.assert_called_once()
|
||||
mock_openai.chat.completions.create.assert_called_once()
|
||||
|
||||
# Verify send_end_task_frame was called with voicemail detection
|
||||
mock_engine.send_end_task_frame.assert_called_once_with(
|
||||
reason=EndTaskReason.VOICEMAIL_DETECTED.value,
|
||||
additional_metadata={
|
||||
"voicemail_transcript": "Hi, you've reached the voicemail of John Smith. Please leave a message after the beep.",
|
||||
"voicemail_confidence": 0.98,
|
||||
"voicemail_reasoning": "Clear voicemail greeting with request to leave message",
|
||||
"voicemail_detection_duration": 5.0,
|
||||
"voicemail_audio_s3_path": "voicemail_detections/123_voicemail_98_5.wav", # S3 upload returns True, so filename is used
|
||||
},
|
||||
abort_immediately=True,
|
||||
)
|
||||
|
||||
async def test_voicemail_detector_no_detection(self):
|
||||
"""Test VoicemailDetector when voicemail is not detected."""
|
||||
# Create voicemail detector
|
||||
detector = VoicemailDetector(detection_duration=5.0, workflow_run_id=124)
|
||||
|
||||
# Mock OpenAI client
|
||||
mock_openai = AsyncMock()
|
||||
mock_whisper_response = Mock()
|
||||
mock_whisper_response.text = "Hello? Hello? Can you hear me?"
|
||||
mock_openai.audio.transcriptions.create.return_value = mock_whisper_response
|
||||
|
||||
mock_gpt_response = Mock()
|
||||
mock_gpt_response.choices = [Mock()]
|
||||
mock_gpt_response.choices[0].message.content = json.dumps(
|
||||
{
|
||||
"is_voicemail": False,
|
||||
"confidence": 0.95,
|
||||
"reasoning": "Live person speaking, asking if caller can hear them",
|
||||
}
|
||||
)
|
||||
mock_openai.chat.completions.create.return_value = mock_gpt_response
|
||||
|
||||
# Mock engine
|
||||
mock_engine = AsyncMock()
|
||||
mock_engine.task = AsyncMock()
|
||||
|
||||
with patch(
|
||||
"api.services.workflow.pipecat_engine_voicemail_detector.AsyncOpenAI",
|
||||
return_value=mock_openai,
|
||||
):
|
||||
# Start detection
|
||||
await detector.start_detection(mock_engine)
|
||||
|
||||
# Simulate audio data
|
||||
sample_rate = 16000
|
||||
duration = 5.0
|
||||
audio_data = b"\x00\x00" * int(sample_rate * duration)
|
||||
|
||||
# Process audio
|
||||
await detector.handle_audio_data(None, audio_data, sample_rate, 1)
|
||||
|
||||
# Wait for detection
|
||||
if detector._detection_task:
|
||||
await detector._detection_task
|
||||
|
||||
# Verify send_end_task_frame was NOT called
|
||||
mock_engine.send_end_task_frame.assert_not_called()
|
||||
|
||||
async def test_voicemail_detector_cancellation(self):
|
||||
"""Test VoicemailDetector cancellation before completion."""
|
||||
# Create voicemail detector
|
||||
detector = VoicemailDetector(detection_duration=10.0, workflow_run_id=125)
|
||||
|
||||
# Mock engine
|
||||
mock_engine = AsyncMock()
|
||||
|
||||
# Start detection
|
||||
await detector.start_detection(mock_engine)
|
||||
assert detector.is_detecting == True
|
||||
|
||||
# Cancel detection immediately
|
||||
await detector.stop_detection()
|
||||
assert detector._is_cancelled == True
|
||||
|
||||
# Try to add audio data after cancellation
|
||||
await detector.handle_audio_data(None, b"\x00\x00" * 1000, 16000, 1)
|
||||
|
||||
# Verify buffer didn't grow (no audio accepted after cancellation)
|
||||
assert len(detector.audio_buffer) == 0
|
||||
|
||||
async def test_disconnect_reason_propagation(self):
|
||||
"""Test that voicemail disconnect reason is properly propagated."""
|
||||
# Create disconnect reason info directly
|
||||
disconnect_info = {
|
||||
"disposition_code": EndTaskReason.VOICEMAIL_DETECTED.value,
|
||||
"details": "Voicemail detected after 5 seconds of audio",
|
||||
"is_remote": False,
|
||||
"is_user_initiated": False,
|
||||
"is_successful_transfer": False,
|
||||
"transport_metadata": {
|
||||
"voicemail_confidence": 0.97,
|
||||
"voicemail_transcript": "You've reached voicemail...",
|
||||
},
|
||||
}
|
||||
|
||||
# Verify attributes
|
||||
assert disconnect_info["disposition_code"] == "voicemail_detected"
|
||||
assert disconnect_info["is_remote"] == False
|
||||
assert disconnect_info["is_user_initiated"] == False
|
||||
assert disconnect_info["is_successful_transfer"] == False
|
||||
assert (
|
||||
disconnect_info["details"] == "Voicemail detected after 5 seconds of audio"
|
||||
)
|
||||
assert disconnect_info["transport_metadata"]["voicemail_confidence"] == 0.97
|
||||
|
||||
async def test_voicemail_detection_end_to_end(self):
|
||||
"""
|
||||
Complete end-to-end test covering:
|
||||
1. on_client_connected event
|
||||
2. Engine initialization with voicemail detection
|
||||
3. Audio processing and voicemail detection
|
||||
4. Engine setting disconnect reason
|
||||
5. on_client_disconnected event
|
||||
6. Proper disconnect reason in workflow run update
|
||||
"""
|
||||
# Create comprehensive mocks
|
||||
from api.services.pipecat.event_handlers import (
|
||||
register_transport_event_handlers,
|
||||
)
|
||||
|
||||
# Mock transport
|
||||
mock_transport = AsyncMock()
|
||||
handlers = {}
|
||||
|
||||
def track_handler(event_name):
|
||||
def decorator(func):
|
||||
handlers[event_name] = func
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
mock_transport.event_handler = track_handler
|
||||
|
||||
# Mock audio buffer
|
||||
mock_audio_buffer = Mock()
|
||||
mock_audio_buffer.start_recording = AsyncMock()
|
||||
mock_audio_buffer.stop_recording = AsyncMock()
|
||||
|
||||
# Mock task
|
||||
mock_task = AsyncMock()
|
||||
|
||||
# Mock aggregator
|
||||
mock_aggregator = Mock()
|
||||
mock_aggregator.get_all_usage_metrics_serialized.return_value = {}
|
||||
|
||||
# Create a mock engine with voicemail detection
|
||||
mock_engine = Mock()
|
||||
mock_engine.initialize = AsyncMock()
|
||||
mock_engine.cleanup = AsyncMock()
|
||||
|
||||
# Mock voicemail detector
|
||||
mock_voicemail_detector = Mock()
|
||||
mock_engine.voicemail_detector = mock_voicemail_detector
|
||||
mock_engine._voicemail_detector = mock_voicemail_detector
|
||||
|
||||
# Initially no disconnect reason
|
||||
mock_engine.get_call_disposition = Mock(return_value=None)
|
||||
mock_engine.get_gathered_context = Mock(return_value={})
|
||||
|
||||
# Mock db_client
|
||||
mock_db_client = Mock()
|
||||
mock_db_client.update_workflow_run = AsyncMock()
|
||||
|
||||
with (
|
||||
patch("api.services.pipecat.event_handlers.db_client", mock_db_client),
|
||||
patch(
|
||||
"api.services.pipecat.event_handlers.enqueue_job",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_enqueue_job,
|
||||
patch(
|
||||
"api.services.pipecat.event_handlers.get_organization_id_from_workflow_run",
|
||||
return_value=1,
|
||||
),
|
||||
patch(
|
||||
"api.services.pipecat.event_handlers.apply_disposition_mapping",
|
||||
side_effect=lambda value, org_id: value, # Return value unchanged
|
||||
),
|
||||
):
|
||||
# Register event handlers
|
||||
register_transport_event_handlers(
|
||||
mock_transport,
|
||||
workflow_run_id=123,
|
||||
audio_buffer=mock_audio_buffer,
|
||||
task=mock_task,
|
||||
engine=mock_engine,
|
||||
usage_metrics_aggregator=mock_aggregator,
|
||||
)
|
||||
|
||||
# Verify handlers were registered
|
||||
assert "on_client_connected" in handlers
|
||||
assert "on_client_disconnected" in handlers
|
||||
|
||||
# Step 1: Client connects
|
||||
await handlers["on_client_connected"](
|
||||
mock_transport, {"id": "participant_1"}
|
||||
)
|
||||
|
||||
# Verify initialization
|
||||
mock_audio_buffer.start_recording.assert_called_once()
|
||||
mock_engine.initialize.assert_called_once()
|
||||
|
||||
# Step 2-3: Simulate voicemail detection occurs
|
||||
# Update engine state to reflect voicemail was detected
|
||||
mock_engine.get_call_disposition = Mock(
|
||||
return_value=EndTaskReason.VOICEMAIL_DETECTED.value
|
||||
)
|
||||
mock_engine.get_gathered_context = Mock(
|
||||
return_value={
|
||||
"voicemail_transcript": "You've reached voicemail, leave a message",
|
||||
"voicemail_confidence": 0.95,
|
||||
}
|
||||
)
|
||||
|
||||
# Step 5: Client disconnects
|
||||
await handlers["on_client_disconnected"](
|
||||
mock_transport, {"id": "participant_1"}, None
|
||||
)
|
||||
|
||||
# Verify engine cleanup
|
||||
mock_engine.cleanup.assert_called_once()
|
||||
|
||||
# Step 6: Verify proper disconnect reason in workflow run update
|
||||
mock_db_client.update_workflow_run.assert_called()
|
||||
call_args = mock_db_client.update_workflow_run.call_args
|
||||
|
||||
# Check the gathered context includes disconnect reason
|
||||
gathered_context = call_args[1]["gathered_context"]
|
||||
assert gathered_context["mapped_call_disposition"] == "voicemail_detected"
|
||||
assert gathered_context["voicemail_confidence"] == 0.95
|
||||
assert (
|
||||
gathered_context["voicemail_transcript"]
|
||||
== "You've reached voicemail, leave a message"
|
||||
)
|
||||
|
||||
# Verify task was NOT cancelled (engine-initiated disconnect)
|
||||
mock_task.cancel.assert_not_called()
|
||||
|
||||
# Verify audio buffer was stopped
|
||||
mock_audio_buffer.stop_recording.assert_called_once()
|
||||
|
||||
# Verify background jobs were enqueued
|
||||
assert (
|
||||
mock_enqueue_job.call_count >= 3
|
||||
) # At least 3 jobs should be enqueued
|
||||
667
api/tests/test_workflow_routes.py
Normal file
667
api/tests/test_workflow_routes.py
Normal file
|
|
@ -0,0 +1,667 @@
|
|||
"""
|
||||
Tests for workflow API routes.
|
||||
|
||||
This module tests the create, update, get, and validate workflow endpoints.
|
||||
The fixtures for database setup, test client, and utilities are in conftest.py.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from fastapi import status
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_workflow_definition():
|
||||
"""Sample workflow definition for testing."""
|
||||
return {
|
||||
"nodes": [
|
||||
{
|
||||
"id": "6581",
|
||||
"type": "startCall",
|
||||
"position": {"x": 427, "y": 23},
|
||||
"data": {
|
||||
"prompt": "Hello, I am Abhishek from Dograh. ",
|
||||
"is_static": True,
|
||||
"name": "Start Call",
|
||||
"is_start": True,
|
||||
"invalid": False,
|
||||
"validationMessage": None,
|
||||
"allow_interrupt": False,
|
||||
},
|
||||
"measured": {"width": 300, "height": 100},
|
||||
"selected": True,
|
||||
"dragging": False,
|
||||
},
|
||||
{
|
||||
"id": "915",
|
||||
"type": "agentNode",
|
||||
"position": {"x": 305, "y": 340},
|
||||
"data": {
|
||||
"prompt": "You are a voice agent whose mode of speaking is voice. Ask the user whether they want to talk to a sales guy or a customer service agent.",
|
||||
"name": "Agent",
|
||||
"invalid": False,
|
||||
"validationMessage": None,
|
||||
"allow_interrupt": False,
|
||||
},
|
||||
"measured": {"width": 300, "height": 100},
|
||||
"selected": False,
|
||||
"dragging": False,
|
||||
},
|
||||
{
|
||||
"id": "7598",
|
||||
"type": "agentNode",
|
||||
"position": {"x": 90, "y": 650},
|
||||
"data": {
|
||||
"prompt": "You are a customer service agent whose mode of communication with the user is voice. Tell them that someone from our team will reach out to them soon",
|
||||
"name": "Agent",
|
||||
"invalid": False,
|
||||
"validationMessage": None,
|
||||
"allow_interrupt": True,
|
||||
},
|
||||
"measured": {"width": 300, "height": 100},
|
||||
"selected": False,
|
||||
"dragging": False,
|
||||
},
|
||||
{
|
||||
"id": "6919",
|
||||
"type": "agentNode",
|
||||
"position": {"x": 520, "y": 650},
|
||||
"data": {
|
||||
"prompt": "You are a sales representative whose mode of communication with the user is voice. Tell the user that someone from our team will reach out to you soon",
|
||||
"name": "Agent",
|
||||
"invalid": False,
|
||||
"validationMessage": None,
|
||||
"allow_interrupt": True,
|
||||
},
|
||||
"measured": {"width": 300, "height": 100},
|
||||
"selected": False,
|
||||
"dragging": False,
|
||||
},
|
||||
{
|
||||
"id": "1802",
|
||||
"type": "endCall",
|
||||
"position": {"x": 305, "y": 960},
|
||||
"data": {
|
||||
"prompt": "Thank you!",
|
||||
"invalid": False,
|
||||
"validationMessage": None,
|
||||
"is_static": True,
|
||||
"name": "End Call",
|
||||
"is_end": True,
|
||||
"allow_interrupt": False,
|
||||
},
|
||||
"measured": {"width": 300, "height": 100},
|
||||
"selected": False,
|
||||
"dragging": False,
|
||||
},
|
||||
],
|
||||
"edges": [
|
||||
{
|
||||
"animated": True,
|
||||
"type": "custom",
|
||||
"source": "915",
|
||||
"target": "7598",
|
||||
"id": "xy-edge__915-7598",
|
||||
"selected": False,
|
||||
"data": {
|
||||
"condition": "The customer wants to talk to a customer service agent",
|
||||
"label": "customer service agent",
|
||||
"invalid": False,
|
||||
"validationMessage": None,
|
||||
},
|
||||
},
|
||||
{
|
||||
"animated": True,
|
||||
"type": "custom",
|
||||
"source": "915",
|
||||
"target": "6919",
|
||||
"id": "xy-edge__915-6919",
|
||||
"selected": False,
|
||||
"data": {
|
||||
"condition": "customer wants to talk to a sales representative",
|
||||
"label": "sales representative",
|
||||
"invalid": False,
|
||||
"validationMessage": None,
|
||||
},
|
||||
},
|
||||
{
|
||||
"animated": True,
|
||||
"type": "custom",
|
||||
"source": "6581",
|
||||
"target": "915",
|
||||
"id": "xy-edge__6581-915",
|
||||
"selected": False,
|
||||
"data": {
|
||||
"condition": "Always take this route",
|
||||
"label": "Always take this route",
|
||||
"invalid": False,
|
||||
"validationMessage": None,
|
||||
},
|
||||
},
|
||||
{
|
||||
"animated": True,
|
||||
"type": "custom",
|
||||
"source": "7598",
|
||||
"target": "1802",
|
||||
"id": "xy-edge__7598-1802",
|
||||
"selected": False,
|
||||
"data": {
|
||||
"condition": "end call",
|
||||
"label": "end call",
|
||||
"invalid": False,
|
||||
"validationMessage": None,
|
||||
},
|
||||
},
|
||||
{
|
||||
"animated": True,
|
||||
"type": "custom",
|
||||
"source": "6919",
|
||||
"target": "1802",
|
||||
"id": "xy-edge__6919-1802",
|
||||
"selected": False,
|
||||
"data": {
|
||||
"condition": "end call",
|
||||
"label": "end call",
|
||||
"invalid": False,
|
||||
"validationMessage": None,
|
||||
},
|
||||
},
|
||||
],
|
||||
"viewport": {"x": 0, "y": 0, "zoom": 1},
|
||||
}
|
||||
|
||||
|
||||
class TestCreateWorkflow:
|
||||
"""Test cases for creating workflows."""
|
||||
|
||||
async def test_create_workflow_success(
|
||||
self, test_client_factory, db_session, sample_workflow_definition
|
||||
):
|
||||
"""Test successful workflow creation."""
|
||||
# Create a test user for this test
|
||||
test_user = await db_session.get_or_create_user_by_provider_id(
|
||||
"test_user_create_success"
|
||||
)
|
||||
|
||||
request_data = {
|
||||
"name": "Test Workflow",
|
||||
"workflow_definition": sample_workflow_definition,
|
||||
}
|
||||
|
||||
async with test_client_factory(test_user) as client:
|
||||
response = await client.post("/api/v1/workflow/create", json=request_data)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
|
||||
assert "id" in data
|
||||
assert data["name"] == "Test Workflow"
|
||||
assert data["workflow_definition"] == sample_workflow_definition
|
||||
assert "created_at" in data
|
||||
assert "current_definition_id" in data
|
||||
|
||||
async def test_create_workflow_invalid_definition(
|
||||
self, test_client_factory, db_session
|
||||
):
|
||||
"""Test workflow creation with invalid definition."""
|
||||
# Create a test user for this test
|
||||
test_user = await db_session.get_or_create_user_by_provider_id(
|
||||
"test_user_invalid_def"
|
||||
)
|
||||
|
||||
request_data = {
|
||||
"name": "Invalid Workflow",
|
||||
"workflow_definition": {"invalid": "structure"},
|
||||
}
|
||||
|
||||
async with test_client_factory(test_user) as client:
|
||||
response = await client.post("/api/v1/workflow/create", json=request_data)
|
||||
|
||||
# The API should still create the workflow even with invalid definition
|
||||
# Validation happens in the validate endpoint
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_workflow_missing_name(
|
||||
self, test_client_factory, db_session, sample_workflow_definition
|
||||
):
|
||||
"""Test workflow creation without name."""
|
||||
# Create a test user for this test
|
||||
test_user = await db_session.get_or_create_user_by_provider_id(
|
||||
"test_user_missing_name"
|
||||
)
|
||||
|
||||
request_data = {"workflow_definition": sample_workflow_definition}
|
||||
|
||||
async with test_client_factory(test_user) as client:
|
||||
response = await client.post("/api/v1/workflow/create", json=request_data)
|
||||
|
||||
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_workflow_missing_definition(
|
||||
self, test_client_factory, db_session
|
||||
):
|
||||
"""Test workflow creation without workflow definition."""
|
||||
# Create a test user for this test
|
||||
test_user = await db_session.get_or_create_user_by_provider_id(
|
||||
"test_user_missing_definition"
|
||||
)
|
||||
|
||||
request_data = {"name": "Test Workflow"}
|
||||
|
||||
async with test_client_factory(test_user) as client:
|
||||
response = await client.post("/api/v1/workflow/create", json=request_data)
|
||||
|
||||
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
|
||||
|
||||
|
||||
class TestGetWorkflows:
|
||||
"""Test cases for fetching workflows."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_all_workflows_empty(self, test_client_factory, db_session):
|
||||
"""Test getting all workflows when none exist."""
|
||||
# Create a test user within the test function
|
||||
test_user = await db_session.get_or_create_user_by_provider_id(
|
||||
"test_user_empty_workflows"
|
||||
)
|
||||
|
||||
# Create a test client for this specific user
|
||||
async with test_client_factory(test_user) as client:
|
||||
response = await client.get("/api/v1/workflow/fetch")
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
assert isinstance(data, list)
|
||||
assert len(data) == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_all_workflows_with_data(
|
||||
self, test_client_factory, db_session, sample_workflow_definition
|
||||
):
|
||||
"""Test getting all workflows when some exist."""
|
||||
# Create a test user within the test function
|
||||
test_user = await db_session.get_or_create_user_by_provider_id(
|
||||
"test_user_with_workflows"
|
||||
)
|
||||
|
||||
# Create a test client for this specific user
|
||||
async with test_client_factory(test_user) as client:
|
||||
# Create a workflow first
|
||||
create_response = await client.post(
|
||||
"/api/v1/workflow/create",
|
||||
json={
|
||||
"name": "Test Workflow 1",
|
||||
"workflow_definition": sample_workflow_definition,
|
||||
},
|
||||
)
|
||||
assert create_response.status_code == status.HTTP_200_OK
|
||||
|
||||
# Create another workflow
|
||||
create_response2 = await client.post(
|
||||
"/api/v1/workflow/create",
|
||||
json={
|
||||
"name": "Test Workflow 2",
|
||||
"workflow_definition": sample_workflow_definition,
|
||||
},
|
||||
)
|
||||
assert create_response2.status_code == status.HTTP_200_OK
|
||||
|
||||
# Get all workflows
|
||||
response = await client.get("/api/v1/workflow/fetch")
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
assert isinstance(data, list)
|
||||
assert len(data) == 2
|
||||
|
||||
# Check that both workflows are returned
|
||||
workflow_names = [w["name"] for w in data]
|
||||
assert "Test Workflow 1" in workflow_names
|
||||
assert "Test Workflow 2" in workflow_names
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_specific_workflow(
|
||||
self, test_client_factory, db_session, sample_workflow_definition
|
||||
):
|
||||
"""Test getting a specific workflow by ID."""
|
||||
# Create a test user for this test
|
||||
test_user = await db_session.get_or_create_user_by_provider_id(
|
||||
"test_user_specific_workflow"
|
||||
)
|
||||
|
||||
async with test_client_factory(test_user) as client:
|
||||
# Create a workflow first
|
||||
create_response = await client.post(
|
||||
"/api/v1/workflow/create",
|
||||
json={
|
||||
"name": "Specific Workflow",
|
||||
"workflow_definition": sample_workflow_definition,
|
||||
},
|
||||
)
|
||||
assert create_response.status_code == status.HTTP_200_OK
|
||||
created_workflow = create_response.json()
|
||||
workflow_id = created_workflow["id"]
|
||||
|
||||
# Get the specific workflow
|
||||
response = await client.get(
|
||||
f"/api/v1/workflow/fetch?workflow_id={workflow_id}"
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
|
||||
assert data["id"] == workflow_id
|
||||
assert data["name"] == "Specific Workflow"
|
||||
assert data["workflow_definition"] == sample_workflow_definition
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_nonexistent_workflow(self, test_client_factory, db_session):
|
||||
"""Test getting a workflow that doesn't exist."""
|
||||
# Create a test user for this test
|
||||
test_user = await db_session.get_or_create_user_by_provider_id(
|
||||
"test_user_nonexistent"
|
||||
)
|
||||
|
||||
async with test_client_factory(test_user) as client:
|
||||
response = await client.get("/api/v1/workflow/fetch?workflow_id=99999")
|
||||
|
||||
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||
assert "not found" in response.json()["detail"].lower()
|
||||
|
||||
|
||||
class TestUpdateWorkflow:
|
||||
"""Test cases for updating workflows."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_workflow_name_only(
|
||||
self, test_client_factory, db_session, sample_workflow_definition
|
||||
):
|
||||
"""Test updating only the workflow name."""
|
||||
# Create a test user for this test
|
||||
test_user = await db_session.get_or_create_user_by_provider_id(
|
||||
"test_user_update_name"
|
||||
)
|
||||
|
||||
async with test_client_factory(test_user) as client:
|
||||
# Create a workflow first
|
||||
create_response = await client.post(
|
||||
"/api/v1/workflow/create",
|
||||
json={
|
||||
"name": "Original Name",
|
||||
"workflow_definition": sample_workflow_definition,
|
||||
},
|
||||
)
|
||||
assert create_response.status_code == status.HTTP_200_OK
|
||||
workflow_id = create_response.json()["id"]
|
||||
|
||||
# Update the workflow name
|
||||
update_data = {"name": "Updated Name"}
|
||||
response = await client.put(
|
||||
f"/api/v1/workflow/{workflow_id}", json=update_data
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
|
||||
assert data["id"] == workflow_id
|
||||
assert data["name"] == "Updated Name"
|
||||
assert (
|
||||
data["workflow_definition"] == sample_workflow_definition
|
||||
) # Should remain unchanged
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_workflow_name_and_definition(
|
||||
self, test_client_factory, db_session, sample_workflow_definition
|
||||
):
|
||||
"""Test updating both workflow name and definition."""
|
||||
# Create a test user for this test
|
||||
test_user = await db_session.get_or_create_user_by_provider_id(
|
||||
"test_user_update_both"
|
||||
)
|
||||
|
||||
async with test_client_factory(test_user) as client:
|
||||
# Create a workflow first
|
||||
create_response = await client.post(
|
||||
"/api/v1/workflow/create",
|
||||
json={
|
||||
"name": "Original Name",
|
||||
"workflow_definition": sample_workflow_definition,
|
||||
},
|
||||
)
|
||||
assert create_response.status_code == status.HTTP_200_OK
|
||||
workflow_id = create_response.json()["id"]
|
||||
|
||||
# Create new workflow definition
|
||||
new_definition = {
|
||||
"nodes": [
|
||||
{
|
||||
"id": "start",
|
||||
"type": "start",
|
||||
"position": {"x": 50, "y": 50},
|
||||
"data": {"label": "New Start"},
|
||||
}
|
||||
],
|
||||
"edges": [],
|
||||
}
|
||||
|
||||
# Update the workflow
|
||||
update_data = {
|
||||
"name": "Updated Name",
|
||||
"workflow_definition": new_definition,
|
||||
}
|
||||
response = await client.put(
|
||||
f"/api/v1/workflow/{workflow_id}", json=update_data
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
|
||||
assert data["id"] == workflow_id
|
||||
assert data["name"] == "Updated Name"
|
||||
assert data["workflow_definition"] == new_definition
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_nonexistent_workflow(self, test_client_factory, db_session):
|
||||
"""Test updating a workflow that doesn't exist."""
|
||||
# Create a test user for this test
|
||||
test_user = await db_session.get_or_create_user_by_provider_id(
|
||||
"test_user_update_nonexistent"
|
||||
)
|
||||
|
||||
update_data = {"name": "Updated Name"}
|
||||
async with test_client_factory(test_user) as client:
|
||||
response = await client.put("/api/v1/workflow/99999", json=update_data)
|
||||
|
||||
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||
assert "not found" in response.json()["detail"].lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_workflow_missing_name(
|
||||
self, test_client_factory, db_session, sample_workflow_definition
|
||||
):
|
||||
"""Test updating a workflow without providing a name."""
|
||||
# Create a test user for this test
|
||||
test_user = await db_session.get_or_create_user_by_provider_id(
|
||||
"test_user_update_missing_name"
|
||||
)
|
||||
|
||||
async with test_client_factory(test_user) as client:
|
||||
# Create a workflow first
|
||||
create_response = await client.post(
|
||||
"/api/v1/workflow/create",
|
||||
json={
|
||||
"name": "Original Name",
|
||||
"workflow_definition": sample_workflow_definition,
|
||||
},
|
||||
)
|
||||
assert create_response.status_code == status.HTTP_200_OK
|
||||
workflow_id = create_response.json()["id"]
|
||||
|
||||
# Try to update without providing name
|
||||
update_data = {"workflow_definition": sample_workflow_definition}
|
||||
response = await client.put(
|
||||
f"/api/v1/workflow/{workflow_id}", json=update_data
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
|
||||
|
||||
|
||||
class TestWorkflowValidation:
|
||||
"""Test cases for workflow validation endpoint."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_validate_workflow_success(
|
||||
self, test_client_factory, db_session, sample_workflow_definition
|
||||
):
|
||||
"""Test successful workflow validation."""
|
||||
# Create a test user for this test
|
||||
test_user = await db_session.get_or_create_user_by_provider_id(
|
||||
"test_user_validate_success"
|
||||
)
|
||||
|
||||
async with test_client_factory(test_user) as client:
|
||||
# Create a workflow first
|
||||
create_response = await client.post(
|
||||
"/api/v1/workflow/create",
|
||||
json={
|
||||
"name": "Valid Workflow",
|
||||
"workflow_definition": sample_workflow_definition,
|
||||
},
|
||||
)
|
||||
assert create_response.status_code == status.HTTP_200_OK
|
||||
workflow_id = create_response.json()["id"]
|
||||
|
||||
# Validate the workflow
|
||||
response = await client.post(f"/api/v1/workflow/{workflow_id}/validate")
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
|
||||
assert data["is_valid"] is True
|
||||
assert data["errors"] == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_validate_nonexistent_workflow(self, test_client_factory, db_session):
|
||||
"""Test validating a workflow that doesn't exist."""
|
||||
# Create a test user for this test
|
||||
test_user = await db_session.get_or_create_user_by_provider_id(
|
||||
"test_user_validate_nonexistent"
|
||||
)
|
||||
|
||||
async with test_client_factory(test_user) as client:
|
||||
response = await client.post("/api/v1/workflow/99999/validate")
|
||||
|
||||
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||
assert "not found" in response.json()["detail"].lower()
|
||||
|
||||
|
||||
class TestWorkflowIntegration:
|
||||
"""Integration tests for workflow operations."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_full_workflow_lifecycle(
|
||||
self, test_client_factory, db_session, sample_workflow_definition
|
||||
):
|
||||
"""Test the complete lifecycle of a workflow: create, get, update, validate."""
|
||||
# Create a test user for this test
|
||||
test_user = await db_session.get_or_create_user_by_provider_id(
|
||||
"test_user_lifecycle"
|
||||
)
|
||||
|
||||
async with test_client_factory(test_user) as client:
|
||||
# 1. Create workflow
|
||||
create_response = await client.post(
|
||||
"/api/v1/workflow/create",
|
||||
json={
|
||||
"name": "Lifecycle Test Workflow",
|
||||
"workflow_definition": sample_workflow_definition,
|
||||
},
|
||||
)
|
||||
assert create_response.status_code == status.HTTP_200_OK
|
||||
workflow_id = create_response.json()["id"]
|
||||
|
||||
# 2. Get the created workflow
|
||||
get_response = await client.get(
|
||||
f"/api/v1/workflow/fetch?workflow_id={workflow_id}"
|
||||
)
|
||||
assert get_response.status_code == status.HTTP_200_OK
|
||||
workflow_data = get_response.json()
|
||||
assert workflow_data["name"] == "Lifecycle Test Workflow"
|
||||
|
||||
# 3. Add a new node in the workflow definition
|
||||
new_node = {
|
||||
"id": "6919_new",
|
||||
"type": "agentNode",
|
||||
"position": {"x": 520, "y": 650},
|
||||
"data": {
|
||||
"prompt": "Something new",
|
||||
"name": "Agent",
|
||||
"invalid": False,
|
||||
"validationMessage": None,
|
||||
"allow_interrupt": True,
|
||||
},
|
||||
"measured": {"width": 300, "height": 100},
|
||||
"selected": False,
|
||||
"dragging": False,
|
||||
}
|
||||
new_edges = [
|
||||
{
|
||||
"source": "6919",
|
||||
"target": "6919_new",
|
||||
"id": "xy-edge__6919-6919_new",
|
||||
"data": {
|
||||
"condition": "Always take this route",
|
||||
"label": "Always take this route",
|
||||
"invalid": False,
|
||||
"validationMessage": None,
|
||||
},
|
||||
},
|
||||
{
|
||||
"source": "6919_new",
|
||||
"target": "1802",
|
||||
"id": "xy-edge__6919_new-1802",
|
||||
"data": {
|
||||
"condition": "Always take this route",
|
||||
"label": "Always take this route",
|
||||
"invalid": False,
|
||||
"validationMessage": None,
|
||||
},
|
||||
},
|
||||
]
|
||||
new_definition = {
|
||||
"nodes": [
|
||||
*sample_workflow_definition["nodes"],
|
||||
new_node,
|
||||
],
|
||||
"edges": [
|
||||
*sample_workflow_definition["edges"],
|
||||
*new_edges,
|
||||
],
|
||||
}
|
||||
|
||||
update_response = await client.put(
|
||||
f"/api/v1/workflow/{workflow_id}",
|
||||
json={
|
||||
"name": "Updated Lifecycle Workflow",
|
||||
"workflow_definition": new_definition,
|
||||
},
|
||||
)
|
||||
assert update_response.status_code == status.HTTP_200_OK
|
||||
assert update_response.json()["name"] == "Updated Lifecycle Workflow"
|
||||
|
||||
# 4. Validate the updated workflow
|
||||
validate_response = await client.post(
|
||||
f"/api/v1/workflow/{workflow_id}/validate"
|
||||
)
|
||||
assert validate_response.status_code == status.HTTP_200_OK
|
||||
assert validate_response.json()["is_valid"] is True
|
||||
|
||||
# 5. Verify the update by getting the workflow again
|
||||
final_get_response = await client.get(
|
||||
f"/api/v1/workflow/fetch?workflow_id={workflow_id}"
|
||||
)
|
||||
assert final_get_response.status_code == status.HTTP_200_OK
|
||||
final_data = final_get_response.json()
|
||||
assert final_data["name"] == "Updated Lifecycle Workflow"
|
||||
assert final_data["workflow_definition"] == new_definition
|
||||
Loading…
Add table
Add a link
Reference in a new issue