mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-06-07 07:55:16 +02:00
506 lines
16 KiB
Python
506 lines
16 KiB
Python
"""
Tests for LoopTalk API routes and orchestration.

This module tests the LoopTalk testing functionality including test session creation,
pipeline orchestration, and agent-to-agent communication.
"""
|
|
|
|
import asyncio
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
|
|
import pytest
|
|
import pytest_asyncio
|
|
from fastapi import status
|
|
|
|
from api.db.db_client import DBClient
|
|
from api.services.looptalk.orchestrator import LoopTalkTestOrchestrator
|
|
|
|
|
|
@pytest.fixture
def actor_workflow_definition():
    """Provide a minimal actor-side workflow definition for LoopTalk tests.

    The graph is a three-node chain (start call -> agent node -> end call)
    joined by two edges whose condition is "Always". STT, LLM and TTS are all
    configured with the "openai" provider and a placeholder API key.
    """
    start_node = {
        "id": "1",
        "type": "startCall",
        "position": {"x": 0, "y": 0},
        "data": {
            "prompt": "Hello, I'm the actor agent.",
            "is_static": True,
            "name": "Start Call",
            "is_start": True,
            "allow_interrupt": False,
        },
    }
    agent_node = {
        "id": "2",
        "type": "agentNode",
        "position": {"x": 100, "y": 0},
        "data": {
            "prompt": "You are an actor agent testing the adversary. Ask probing questions.",
            "name": "Actor Agent",
            "allow_interrupt": True,
        },
    }
    end_node = {
        "id": "3",
        "type": "endCall",
        "position": {"x": 200, "y": 0},
        "data": {
            "prompt": "Goodbye!",
            "name": "End Call",
            "is_end": True,
        },
    }

    start_to_agent = {
        "id": "e1",
        "source": "1",
        "target": "2",
        "data": {"label": "Continue", "condition": "Always"},
    }
    agent_to_end = {
        "id": "e2",
        "source": "2",
        "target": "3",
        "data": {"label": "End", "condition": "Always"},
    }

    return {
        "nodes": [start_node, agent_node, end_node],
        "edges": [start_to_agent, agent_to_end],
        "stt": {"provider": "openai", "api_key": "test-key", "model": "whisper-1"},
        "llm": {"provider": "openai", "api_key": "test-key", "model": "gpt-4o-mini"},
        "tts": {
            "provider": "openai",
            "api_key": "test-key",
            "model": "tts-1",
            "voice": "nova",
        },
    }
|
|
|
|
|
|
@pytest.fixture
def adversary_workflow_definition():
    """Provide a minimal adversary-side workflow definition for LoopTalk tests.

    The graph is a three-node chain (start call -> agent node -> end call)
    joined by two edges whose condition is "Always". STT and TTS use the
    "deepgram" provider and the LLM uses "groq", all with a placeholder key.
    """
    start_node = {
        "id": "1",
        "type": "startCall",
        "position": {"x": 0, "y": 0},
        "data": {
            "prompt": "Hello, I'm the adversary agent.",
            "is_static": True,
            "name": "Start Call",
            "is_start": True,
            "allow_interrupt": False,
        },
    }
    agent_node = {
        "id": "2",
        "type": "agentNode",
        "position": {"x": 100, "y": 0},
        "data": {
            "prompt": "You are an adversary agent being tested. Respond defensively.",
            "name": "Adversary Agent",
            "allow_interrupt": True,
        },
    }
    end_node = {
        "id": "3",
        "type": "endCall",
        "position": {"x": 200, "y": 0},
        "data": {
            "prompt": "Goodbye!",
            "name": "End Call",
            "is_end": True,
        },
    }

    start_to_agent = {
        "id": "e1",
        "source": "1",
        "target": "2",
        "data": {"label": "Continue", "condition": "Always"},
    }
    agent_to_end = {
        "id": "e2",
        "source": "2",
        "target": "3",
        "data": {"label": "End", "condition": "Always"},
    }

    return {
        "nodes": [start_node, agent_node, end_node],
        "edges": [start_to_agent, agent_to_end],
        "stt": {"provider": "deepgram", "api_key": "test-key", "model": "nova-2"},
        "llm": {
            "provider": "groq",
            "api_key": "test-key",
            "model": "llama-3.1-70b-versatile",
        },
        "tts": {"provider": "deepgram", "api_key": "test-key", "voice": "nova-2"},
    }
|
|
|
|
|
|
from pipecat.processors.frame_processor import FrameProcessor
|
|
|
|
|
|
class MockSTTService(FrameProcessor):
    """Stand-in speech-to-text processor that returns a canned transcript."""

    def __init__(self, **kwargs):
        # Forward all keyword arguments untouched to the FrameProcessor base.
        super().__init__(**kwargs)

    async def run_stt(self, audio: bytes) -> str:
        """Ignore ``audio`` and return a fixed transcription string."""
        canned_transcript = "Mock transcription"
        return canned_transcript
|
|
|
|
|
|
class MockLLMService(FrameProcessor):
    """Stand-in LLM processor that returns a canned completion."""

    def __init__(self, **kwargs):
        # Forward all keyword arguments untouched to the FrameProcessor base.
        super().__init__(**kwargs)

    async def run_llm(self, messages) -> str:
        """Ignore ``messages`` and return a fixed response string."""
        canned_reply = "Mock LLM response"
        return canned_reply

    def create_context_aggregator(self, context):
        """Return a fresh ``MagicMock`` in place of a context aggregator.

        ``context`` is accepted but not used.
        """
        aggregator_stub = MagicMock()
        return aggregator_stub
|
|
|
|
|
|
class MockTTSService(FrameProcessor):
    """Stand-in text-to-speech processor that returns canned audio bytes."""

    def __init__(self, **kwargs):
        # Forward all keyword arguments untouched to the FrameProcessor base.
        super().__init__(**kwargs)

    async def run_tts(self, text: str) -> bytes:
        """Ignore ``text`` and return a fixed bytes payload."""
        canned_audio = b"Mock audio data"
        return canned_audio
|
|
|
|
|
|
@pytest_asyncio.fixture
async def test_user_with_org(db_session):
    """Yield-less fixture returning a user that belongs to, and has selected, an org.

    Gets or creates the user and organization by provider id, links them,
    sets the user's ``selected_organization_id`` with a direct UPDATE, and
    finally re-reads the user so the returned object reflects that update.
    """
    user = await db_session.get_or_create_user_by_provider_id("test_looptalk_user")
    org, _ = await db_session.get_or_create_organization_by_provider_id(
        "test_looptalk_org"
    )

    # Keep plain ids around rather than the model objects themselves.
    user_id, org_id = user.id, org.id

    await db_session.add_user_to_organization(user_id, org_id)

    # Point the user's selected organization at the one just linked.
    async with db_session.async_session() as session:
        from sqlalchemy import update

        from api.db.models import UserModel

        select_org_stmt = (
            update(UserModel)
            .where(UserModel.id == user_id)
            .values(selected_organization_id=org_id)
        )
        await session.execute(select_org_stmt)
        await session.commit()

    # Re-fetch so the caller gets a fresh user object.
    refreshed_user = await db_session.get_user_by_id(user_id)
    return refreshed_user
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_create_test_session(
    test_client_factory,
    db_session,
    test_user_with_org,
    actor_workflow_definition,
    adversary_workflow_definition,
):
    """Creating a LoopTalk test session echoes the request and starts as pending.

    Two workflows are created through the workflow API first; their ids are
    then used in the POST to the test-sessions endpoint.
    """
    workflow_create_url = "/api/v1/workflow/create"

    async with test_client_factory(test_user_with_org) as test_client:
        # Actor workflow
        actor_resp = await test_client.post(
            workflow_create_url,
            json={
                "name": "Actor Workflow",
                "workflow_definition": actor_workflow_definition,
            },
        )
        assert actor_resp.status_code == status.HTTP_200_OK
        actor_id = actor_resp.json()["id"]

        # Adversary workflow
        adversary_resp = await test_client.post(
            workflow_create_url,
            json={
                "name": "Adversary Workflow",
                "workflow_definition": adversary_workflow_definition,
            },
        )
        assert adversary_resp.status_code == status.HTTP_200_OK
        adversary_id = adversary_resp.json()["id"]

        # Test session referencing both workflows
        session_payload = {
            "name": "Test Session 1",
            "actor_workflow_id": actor_id,
            "adversary_workflow_id": adversary_id,
            "config": {"test_duration": 60},
        }
        response = await test_client.post(
            "/api/v1/looptalk/test-sessions", json=session_payload
        )

        assert response.status_code == status.HTTP_200_OK
        body = response.json()
        assert body["name"] == "Test Session 1"
        assert body["status"] == "pending"
        assert body["actor_workflow_id"] == actor_id
        assert body["adversary_workflow_id"] == adversary_id
        assert body["config"]["test_duration"] == 60
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_list_test_sessions(test_client_factory, db_session, test_user_with_org):
    """Listing LoopTalk test sessions returns HTTP 200 with a JSON list."""
    sessions_url = "/api/v1/looptalk/test-sessions"

    async with test_client_factory(test_user_with_org) as test_client:
        response = await test_client.get(sessions_url)

        assert response.status_code == status.HTTP_200_OK
        listing = response.json()
        assert isinstance(listing, list)
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_looptalk_orchestrator_plumbing(
    db_session: DBClient, actor_workflow_definition, adversary_workflow_definition
):
    """Test the LoopTalk orchestrator plumbing with mocked services.

    Sets up a user, organization, two workflows and a test session directly
    through the DB client, patches the STT/LLM/TTS factories, the engine, and
    the pipeline/task constructors, then starts the session in a background
    task and checks that:

    * the session manager tracks the session with actor and adversary tasks,
    * each service factory and ``PipelineTask`` was invoked twice,
    * stopping the session removes it from the session manager.

    The background task is always cancelled and awaited in a ``finally``
    block so a failing assertion does not leave it running.
    """

    # Create test user and organization
    user = await db_session.get_or_create_user_by_provider_id(
        provider_id="test-user-123"
    )
    org, _ = await db_session.get_or_create_organization_by_provider_id(
        org_provider_id="test-org-123"
    )

    # Get IDs before session closes
    user_id = user.id
    org_id = org.id

    await db_session.add_user_to_organization(user_id, org_id)

    # Update user's selected organization manually
    async with db_session.async_session() as session:
        from sqlalchemy import update

        from api.db.models import UserModel

        await session.execute(
            update(UserModel)
            .where(UserModel.id == user_id)
            .values(selected_organization_id=org_id)
        )
        await session.commit()

    actor_workflow = await db_session.create_workflow(
        name="Actor Workflow",
        workflow_definition=actor_workflow_definition,
        user_id=user_id,
    )

    adversary_workflow = await db_session.create_workflow(
        name="Adversary Workflow",
        workflow_definition=adversary_workflow_definition,
        user_id=user_id,
    )

    # Create test session
    test_session = await db_session.create_test_session(
        organization_id=org_id,
        name="Test Session",
        actor_workflow_id=actor_workflow.id,
        adversary_workflow_id=adversary_workflow.id,
        config={"test_duration": 10},
    )

    # Mock the service factories - patch at the actual import location in pipeline_builder
    with (
        patch(
            "api.services.looptalk.core.pipeline_builder.create_stt_service"
        ) as mock_stt_factory,
        patch(
            "api.services.looptalk.core.pipeline_builder.create_llm_service"
        ) as mock_llm_factory,
        patch(
            "api.services.looptalk.core.pipeline_builder.create_tts_service"
        ) as mock_tts_factory,
        patch(
            "api.services.workflow.pipecat_engine.PipecatEngine"
        ) as mock_engine_class,
        patch(
            "api.services.pipecat.pipeline_builder.build_pipeline"
        ) as mock_build_pipeline,
        patch("api.services.pipecat.pipeline_builder.PipelineTask") as mock_task_class,
    ):
        # Configure mocks
        mock_stt_factory.return_value = MockSTTService()
        mock_llm_factory.return_value = MockLLMService()
        mock_tts_factory.return_value = MockTTSService()

        mock_engine = MagicMock()
        mock_engine.initialize = AsyncMock()
        mock_engine.get_callback_processor = MagicMock(return_value=MagicMock())
        mock_engine_class.return_value = mock_engine

        # Mock pipeline and task
        mock_pipeline = MagicMock()
        mock_task = MagicMock()
        mock_task.run = AsyncMock()
        mock_task.cancel = AsyncMock()  # Make cancel async
        mock_build_pipeline.return_value = mock_pipeline
        mock_task_class.return_value = mock_task

        # Create orchestrator
        orchestrator = LoopTalkTestOrchestrator(db_client=db_session)

        # Start test session (in a separate task to avoid blocking)
        start_task = asyncio.create_task(
            orchestrator.start_test_session(
                test_session_id=test_session.id, organization_id=org_id
            )
        )

        try:
            # Give it a moment to start
            await asyncio.sleep(0.5)

            # Verify the session is running through session manager
            session_info = orchestrator.session_manager.get_session(test_session.id)
            assert session_info is not None
            assert session_info["test_session"].id == test_session.id
            assert "actor_task" in session_info
            assert "adversary_task" in session_info

            # Verify service factories were called
            assert mock_stt_factory.call_count == 2  # Once for each agent
            assert mock_llm_factory.call_count == 2
            assert mock_tts_factory.call_count == 2

            # Verify pipelines were created with PipelineTask
            assert mock_task_class.call_count == 2

            # Stop the test session
            await orchestrator.stop_test_session(test_session_id=test_session.id)

            # Verify session was cleaned up
            assert orchestrator.session_manager.get_session(test_session.id) is None
        finally:
            # Always reap the background task, even when an assertion above
            # failed, so it is not left pending after the test returns.
            start_task.cancel()
            try:
                await start_task
            except asyncio.CancelledError:
                pass
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_load_test_creation(
    test_client_factory,
    db_session,
    test_user_with_org,
    actor_workflow_definition,
    adversary_workflow_definition,
):
    """Test creating a load test with multiple sessions.

    Creates an actor and an adversary workflow through the workflow API,
    then posts a load test requesting three sessions and checks that the
    response reports a total of 3, a ``load_test_group_id`` and three
    test session ids.
    """
    async with test_client_factory(test_user_with_org) as test_client:
        # First create two workflows
        actor_workflow_response = await test_client.post(
            "/api/v1/workflow/create",
            json={
                "name": "Actor Workflow",
                "workflow_definition": actor_workflow_definition,
            },
        )
        # Fail here with a clear status mismatch instead of a KeyError on "id"
        # if the prerequisite workflow could not be created.
        assert actor_workflow_response.status_code == status.HTTP_200_OK
        actor_workflow_id = actor_workflow_response.json()["id"]

        adversary_workflow_response = await test_client.post(
            "/api/v1/workflow/create",
            json={
                "name": "Adversary Workflow",
                "workflow_definition": adversary_workflow_definition,
            },
        )
        assert adversary_workflow_response.status_code == status.HTTP_200_OK
        adversary_workflow_id = adversary_workflow_response.json()["id"]

        # Create load test
        response = await test_client.post(
            "/api/v1/looptalk/load-tests",
            json={
                "name_prefix": "Load Test",
                "actor_workflow_id": actor_workflow_id,
                "adversary_workflow_id": adversary_workflow_id,
                "test_count": 3,
                "config": {"test_duration": 30},
            },
        )

        assert response.status_code == status.HTTP_200_OK
        data = response.json()
        assert data["total"] == 3
        assert "load_test_group_id" in data
        assert len(data["test_session_ids"]) == 3
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_invalid_workflow_ids(
    test_client_factory, db_session, test_user_with_org
):
    """Posting a test session with workflow id 99999 yields a 404.

    The response detail is expected to mention "workflow not found"
    (compared case-insensitively).
    """
    # Presumably no workflow with this id exists in the test database.
    bogus_payload = {
        "name": "Invalid Test",
        "actor_workflow_id": 99999,
        "adversary_workflow_id": 99999,
        "config": {},
    }

    async with test_client_factory(test_user_with_org) as test_client:
        response = await test_client.post(
            "/api/v1/looptalk/test-sessions", json=bogus_payload
        )

        assert response.status_code == status.HTTP_404_NOT_FOUND
        detail_text = response.json()["detail"].lower()
        assert "workflow not found" in detail_text
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_transport_manager():
    """Exercise pair creation, lookup and removal on InternalTransportManager.

    A transport pair is created for one session key; each side's output is
    expected to have the other side's input as its partner. The pair must be
    counted and retrievable until it is removed.
    """
    from pipecat.transports import InternalTransportManager, TransportParams

    session_key = "test-123"
    transport_manager = InternalTransportManager()

    # Same audio settings are used for both ends of the pair.
    audio_params = TransportParams(
        audio_out_enabled=True,
        audio_in_enabled=True,
        audio_out_sample_rate=16000,
        audio_in_sample_rate=16000,
    )

    actor_transport, adversary_transport = transport_manager.create_transport_pair(
        test_session_id=session_key,
        actor_params=audio_params,
        adversary_params=audio_params,
    )

    # Outputs are cross-wired to the opposite side's input.
    assert actor_transport._output._partner == adversary_transport._input
    assert adversary_transport._output._partner == actor_transport._input

    # The new pair is tracked by the manager.
    assert transport_manager.get_active_test_count() == 1
    assert transport_manager.get_transport_pair(session_key) is not None

    # After removal nothing is tracked for this key.
    transport_manager.remove_transport_pair(session_key)
    assert transport_manager.get_active_test_count() == 0
    assert transport_manager.get_transport_pair(session_key) is None
|