mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-06-07 07:55:16 +02:00
Fix/multiple generation (#104)
* fixes #100 * Fix test * fix: fix bad configuration issue
This commit is contained in:
parent
90b690efff
commit
56953bbd09
18 changed files with 758 additions and 460 deletions
|
|
@ -1,5 +1,7 @@
|
|||
from datetime import datetime, timezone
|
||||
|
||||
from loguru import logger
|
||||
from pydantic import ValidationError
|
||||
from sqlalchemy.future import select
|
||||
|
||||
from api.db.base_client import BaseDBClient
|
||||
|
|
@ -66,12 +68,21 @@ class UserClient(BaseDBClient):
|
|||
if not configuration_obj:
|
||||
return UserConfiguration()
|
||||
|
||||
return UserConfiguration.model_validate(
|
||||
{
|
||||
**configuration_obj.configuration,
|
||||
"last_validated_at": configuration_obj.last_validated_at,
|
||||
}
|
||||
)
|
||||
try:
|
||||
return UserConfiguration.model_validate(
|
||||
{
|
||||
**configuration_obj.configuration,
|
||||
"last_validated_at": configuration_obj.last_validated_at,
|
||||
}
|
||||
)
|
||||
except ValidationError as e:
|
||||
# If configuration contains an unsupported provider,
|
||||
# return a default configuration without failing
|
||||
logger.warning(
|
||||
f"Failed to validate user configuration for user {user_id}: {e}. "
|
||||
"Returning default configuration."
|
||||
)
|
||||
return UserConfiguration()
|
||||
|
||||
async def update_user_configuration(
|
||||
self, user_id: int, configuration: UserConfiguration
|
||||
|
|
|
|||
|
|
@ -494,7 +494,6 @@ async def _run_pipeline(
|
|||
max_duration_end_task_callback=engine.create_max_duration_callback(),
|
||||
generation_started_callback=engine.create_generation_started_callback(),
|
||||
llm_text_frame_callback=engine.handle_llm_text_frame,
|
||||
# Note: speaking event callbacks are now handled by pre-aggregator processor
|
||||
)
|
||||
|
||||
pipeline_metrics_aggregator = PipelineMetricsAggregator()
|
||||
|
|
|
|||
|
|
@ -13,9 +13,8 @@ from pipecat.frames.frames import (
|
|||
CancelFrame,
|
||||
EndFrame,
|
||||
FunctionCallResultProperties,
|
||||
FunctionCallsFromLLMInfoFrame,
|
||||
LLMContextFrame,
|
||||
LLMFullResponseEndFrame,
|
||||
LLMFullResponseStartFrame,
|
||||
TTSSpeakFrame,
|
||||
)
|
||||
from pipecat.pipeline.task import PipelineTask
|
||||
|
|
@ -104,7 +103,7 @@ class PipecatEngine:
|
|||
self._builtin_function_schemas: Optional[list[dict]] = None
|
||||
|
||||
# Track current LLM reference text for TTS aggregation correction
|
||||
self._current_llm_reference_text: str = ""
|
||||
self._current_llm_generation_reference_text: str = ""
|
||||
|
||||
# Custom tool manager (initialized in initialize())
|
||||
self._custom_tool_manager: Optional[CustomToolManager] = None
|
||||
|
|
@ -173,6 +172,9 @@ class PipecatEngine:
|
|||
await self._register_builtin_functions()
|
||||
|
||||
await self.set_node(self.workflow.start_node_id)
|
||||
|
||||
# Trigger initial LLM generation
|
||||
await self.task.queue_frame(LLMContextFrame(self.context))
|
||||
logger.debug(f"{self.__class__.__name__} initialized")
|
||||
except Exception as e:
|
||||
logger.error(f"Error initializing {self.__class__.__name__}: {e}")
|
||||
|
|
@ -218,7 +220,6 @@ class PipecatEngine:
|
|||
result = {"status": "done"}
|
||||
|
||||
properties = FunctionCallResultProperties(
|
||||
run_llm=False,
|
||||
on_context_updated=on_context_updated,
|
||||
)
|
||||
|
||||
|
|
@ -256,8 +257,6 @@ class PipecatEngine:
|
|||
"""Register built-in functions (calculator and timezone) with the LLM."""
|
||||
logger.debug("Registering built-in functions with LLM")
|
||||
|
||||
properties = FunctionCallResultProperties(run_llm=True)
|
||||
|
||||
# Register calculator function
|
||||
async def calculate_func(function_call_params: FunctionCallParams) -> None:
|
||||
logger.info(f"LLM Function Call EXECUTED: safe_calculator")
|
||||
|
|
@ -266,12 +265,10 @@ class PipecatEngine:
|
|||
expr = function_call_params.arguments.get("expression", "")
|
||||
result = safe_calculator(expr)
|
||||
await function_call_params.result_callback(
|
||||
{"expression": expr, "result": result}, properties=properties
|
||||
{"expression": expr, "result": result}
|
||||
)
|
||||
except Exception as e:
|
||||
await function_call_params.result_callback(
|
||||
{"error": str(e)}, properties=properties
|
||||
)
|
||||
await function_call_params.result_callback({"error": str(e)})
|
||||
|
||||
# Register timezone functions
|
||||
async def get_current_time_func(
|
||||
|
|
@ -282,13 +279,9 @@ class PipecatEngine:
|
|||
try:
|
||||
timezone = function_call_params.arguments.get("timezone", "UTC")
|
||||
result = get_current_time(timezone)
|
||||
await function_call_params.result_callback(
|
||||
result, properties=properties
|
||||
)
|
||||
await function_call_params.result_callback(result)
|
||||
except Exception as e:
|
||||
await function_call_params.result_callback(
|
||||
{"error": str(e)}, properties=properties
|
||||
)
|
||||
await function_call_params.result_callback({"error": str(e)})
|
||||
|
||||
async def convert_time_func(function_call_params: FunctionCallParams) -> None:
|
||||
logger.info(f"LLM Function Call EXECUTED: convert_time")
|
||||
|
|
@ -299,29 +292,15 @@ class PipecatEngine:
|
|||
function_call_params.arguments.get("time"),
|
||||
function_call_params.arguments.get("target_timezone"),
|
||||
)
|
||||
await function_call_params.result_callback(
|
||||
result, properties=properties
|
||||
)
|
||||
await function_call_params.result_callback(result)
|
||||
except Exception as e:
|
||||
await function_call_params.result_callback(
|
||||
{"error": str(e)}, properties=properties
|
||||
)
|
||||
await function_call_params.result_callback({"error": str(e)})
|
||||
|
||||
# Register all built-in functions
|
||||
self.llm.register_function("safe_calculator", calculate_func)
|
||||
self.llm.register_function("get_current_time", get_current_time_func)
|
||||
self.llm.register_function("convert_time", convert_time_func)
|
||||
|
||||
async def _queue_tts_response(self, text: str) -> None:
|
||||
"""Queue TTS frames for static text response."""
|
||||
await self.task.queue_frames(
|
||||
[
|
||||
LLMFullResponseStartFrame(),
|
||||
TTSSpeakFrame(text=text),
|
||||
LLMFullResponseEndFrame(),
|
||||
]
|
||||
)
|
||||
|
||||
async def _perform_variable_extraction_if_needed(
|
||||
self, previous_node: Optional[Node]
|
||||
) -> None:
|
||||
|
|
@ -384,7 +363,6 @@ class PipecatEngine:
|
|||
functions,
|
||||
) = await self._compose_system_message_functions_for_node(node)
|
||||
await self._update_llm_context(system_message, functions)
|
||||
await self.task.queue_frame(LLMContextFrame(self.context))
|
||||
|
||||
async def set_node(self, node_id: str):
|
||||
"""
|
||||
|
|
@ -733,7 +711,7 @@ class PipecatEngine:
|
|||
|
||||
async def handle_llm_text_frame(self, text: str):
|
||||
"""Accumulate LLM text frames to build reference text."""
|
||||
self._current_llm_reference_text += text
|
||||
self._current_llm_generation_reference_text += text
|
||||
|
||||
def handle_client_disconnected(self):
|
||||
"""Handle client disconnected event."""
|
||||
|
|
|
|||
|
|
@ -114,10 +114,10 @@ def create_max_duration_callback(engine: "PipecatEngine"):
|
|||
def create_generation_started_callback(engine: "PipecatEngine"):
|
||||
"""Return a callback that resets flags at the start of each LLM generation."""
|
||||
|
||||
async def handle_generation_started(): # noqa: D401
|
||||
async def handle_generation_started():
|
||||
logger.debug("LLM generation started in callback processor")
|
||||
# Clear reference text from previous generation
|
||||
engine._current_llm_reference_text = ""
|
||||
engine._current_llm_generation_reference_text = ""
|
||||
|
||||
return handle_generation_started
|
||||
|
||||
|
|
@ -184,7 +184,7 @@ def create_aggregation_correction_callback(engine: "PipecatEngine"):
|
|||
return "".join(out_chars)
|
||||
|
||||
def correct_aggregation(corrupted: str) -> str:
|
||||
reference = engine._current_llm_reference_text
|
||||
reference = engine._current_llm_generation_reference_text
|
||||
|
||||
if not reference:
|
||||
logger.warning("No reference text available for aggregation correction")
|
||||
|
|
|
|||
|
|
@ -1,128 +0,0 @@
|
|||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
from pipecat.processors.aggregators.llm_response import LLMAssistantAggregatorParams
|
||||
from pipecat.services.openai.llm import OpenAILLMContext
|
||||
|
||||
from api.services.workflow.pipecat_engine import PipecatEngine
|
||||
from api.services.workflow.pipecat_engine_callbacks import (
|
||||
create_generation_started_callback,
|
||||
)
|
||||
|
||||
|
||||
class TestAggregationIntegration:
|
||||
"""Integration tests for the TTS aggregation correction flow."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_engine_reference_text_tracking(self):
|
||||
"""Test that the engine properly tracks LLM reference text."""
|
||||
# Create mock dependencies
|
||||
mock_task = Mock()
|
||||
mock_llm = Mock()
|
||||
mock_context = Mock(spec=OpenAILLMContext)
|
||||
mock_tts = Mock()
|
||||
mock_workflow = Mock()
|
||||
mock_workflow.start_node_id = "start"
|
||||
mock_workflow.nodes = {
|
||||
"start": Mock(is_start=True, is_static=True, is_end=False, out_edges=[])
|
||||
}
|
||||
|
||||
# Create engine
|
||||
engine = PipecatEngine(
|
||||
task=mock_task,
|
||||
llm=mock_llm,
|
||||
context=mock_context,
|
||||
tts=mock_tts,
|
||||
workflow=mock_workflow,
|
||||
call_context_vars={},
|
||||
workflow_run_id=1,
|
||||
)
|
||||
|
||||
# Test initial state
|
||||
assert engine._current_llm_reference_text == ""
|
||||
|
||||
# Test accumulating LLM text
|
||||
await engine.handle_llm_text_frame("Hello ")
|
||||
assert engine._current_llm_reference_text == "Hello "
|
||||
|
||||
await engine.handle_llm_text_frame("world!")
|
||||
assert engine._current_llm_reference_text == "Hello world!"
|
||||
|
||||
# Test generation started callback clears reference text
|
||||
callback = create_generation_started_callback(engine)
|
||||
await callback()
|
||||
assert engine._current_llm_reference_text == ""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_aggregation_correction_callback_creation(self):
|
||||
"""Test creating the aggregation correction callback."""
|
||||
# Create mock engine
|
||||
mock_task = Mock()
|
||||
mock_llm = Mock()
|
||||
mock_context = Mock(spec=OpenAILLMContext)
|
||||
mock_workflow = Mock()
|
||||
|
||||
engine = PipecatEngine(
|
||||
task=mock_task,
|
||||
llm=mock_llm,
|
||||
context=mock_context,
|
||||
workflow=mock_workflow,
|
||||
call_context_vars={},
|
||||
workflow_run_id=1,
|
||||
)
|
||||
|
||||
# Set reference text
|
||||
engine._current_llm_reference_text = "Hello, world! How are you?"
|
||||
|
||||
# Create correction callback
|
||||
callback = engine.create_aggregation_correction_callback()
|
||||
|
||||
# Test correction - note that trailing punctuation might be stripped if not in corrupted text
|
||||
corrected = callback("Hello world How are you")
|
||||
assert corrected == "Hello, world! How are you"
|
||||
|
||||
def test_llm_assistant_aggregator_params_with_callback(self):
|
||||
"""Test that LLMAssistantAggregatorParams accepts correction callback."""
|
||||
|
||||
def mock_callback(text: str) -> str:
|
||||
return text.upper()
|
||||
|
||||
params = LLMAssistantAggregatorParams(
|
||||
expect_stripped_words=True, correct_aggregation_callback=mock_callback
|
||||
)
|
||||
|
||||
assert params.expect_stripped_words is True
|
||||
assert params.correct_aggregation_callback is not None
|
||||
assert params.correct_aggregation_callback("hello") == "HELLO"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pipeline_callbacks_processor_llm_text_frame(self):
|
||||
"""Test that PipelineEngineCallbacksProcessor handles LLMTextFrame."""
|
||||
from pipecat.frames.frames import LLMTextFrame
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
|
||||
from api.services.pipecat.pipeline_engine_callbacks_processor import (
|
||||
PipelineEngineCallbacksProcessor,
|
||||
)
|
||||
|
||||
# Track callback invocations
|
||||
callback_invoked = False
|
||||
callback_text = None
|
||||
|
||||
async def mock_llm_text_callback(text: str):
|
||||
nonlocal callback_invoked, callback_text
|
||||
callback_invoked = True
|
||||
callback_text = text
|
||||
|
||||
# Create processor with callback
|
||||
processor = PipelineEngineCallbacksProcessor(
|
||||
llm_text_frame_callback=mock_llm_text_callback
|
||||
)
|
||||
|
||||
# Process LLMTextFrame
|
||||
frame = LLMTextFrame(text="Hello world")
|
||||
await processor.process_frame(frame, FrameDirection.DOWNSTREAM)
|
||||
|
||||
# Verify callback was invoked
|
||||
assert callback_invoked is True
|
||||
assert callback_text == "Hello world"
|
||||
|
|
@ -1,159 +0,0 @@
|
|||
from unittest.mock import AsyncMock, Mock
|
||||
|
||||
import pytest
|
||||
from pipecat.frames.frames import StartInterruptionFrame
|
||||
from pipecat.processors.aggregators.llm_response import LLMAssistantAggregatorParams
|
||||
from pipecat.services.openai.llm import (
|
||||
OpenAIAssistantContextAggregator,
|
||||
OpenAILLMContext,
|
||||
)
|
||||
|
||||
|
||||
class TestInterruptionCorrection:
|
||||
"""Test that TTS aggregation correction works during interruptions."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_openai_interruption_with_correction(self):
|
||||
"""Test OpenAI assistant context aggregator applies correction during interruption."""
|
||||
# Create mock context
|
||||
mock_context = Mock(spec=OpenAILLMContext)
|
||||
mock_context.get_messages.return_value = []
|
||||
mock_context.add_message = Mock()
|
||||
|
||||
# Create correction callback
|
||||
def correction_callback(text: str) -> str:
|
||||
# Simulate fixing corrupted text
|
||||
if text == "Hello world how are you":
|
||||
return "Hello world, how are you"
|
||||
return text
|
||||
|
||||
# Create aggregator with correction callback
|
||||
params = LLMAssistantAggregatorParams(
|
||||
expect_stripped_words=True, correct_aggregation_callback=correction_callback
|
||||
)
|
||||
|
||||
aggregator = OpenAIAssistantContextAggregator(
|
||||
context=mock_context, params=params
|
||||
)
|
||||
|
||||
# Set up aggregation state
|
||||
aggregator._aggregation = "Hello world how are you"
|
||||
aggregator._current_llm_response_id = "test-id"
|
||||
aggregator._response_function_messages = {}
|
||||
aggregator._function_calls_in_progress = {}
|
||||
aggregator._started = 1
|
||||
|
||||
# Mock push_context_frame and reset methods
|
||||
aggregator.push_context_frame = AsyncMock()
|
||||
aggregator.reset = AsyncMock()
|
||||
|
||||
# Process interruption
|
||||
interruption_frame = StartInterruptionFrame()
|
||||
await aggregator._handle_interruptions(interruption_frame)
|
||||
|
||||
# Verify the corrected text was added to context
|
||||
mock_context.add_message.assert_called_once()
|
||||
added_message = mock_context.add_message.call_args[0][0]
|
||||
assert added_message["role"] == "assistant"
|
||||
assert (
|
||||
added_message["content"]
|
||||
== "Hello world, how are you <<interrupted_by_user>>"
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_google_interruption_with_correction(self):
|
||||
"""Test Google assistant context aggregator applies correction during interruption."""
|
||||
from pipecat.services.google.llm import (
|
||||
Content,
|
||||
GoogleAssistantContextAggregator,
|
||||
)
|
||||
|
||||
# Create mock context
|
||||
mock_context = Mock(spec=OpenAILLMContext)
|
||||
mock_context.get_messages.return_value = []
|
||||
mock_context.add_message = Mock()
|
||||
|
||||
# Create correction callback
|
||||
def correction_callback(text: str) -> str:
|
||||
# Simulate fixing corrupted text
|
||||
if text == "I am here to help":
|
||||
return "I am here to help"
|
||||
return text
|
||||
|
||||
# Create aggregator with correction callback
|
||||
params = LLMAssistantAggregatorParams(
|
||||
expect_stripped_words=True, correct_aggregation_callback=correction_callback
|
||||
)
|
||||
|
||||
aggregator = GoogleAssistantContextAggregator(
|
||||
context=mock_context, params=params
|
||||
)
|
||||
|
||||
# Set up aggregation state
|
||||
aggregator._aggregation = "I am here to help"
|
||||
aggregator._current_llm_response_id = "test-id"
|
||||
aggregator._response_function_messages = {}
|
||||
aggregator._function_calls_in_progress = {}
|
||||
aggregator._started = 1
|
||||
|
||||
# Mock push_context_frame and reset methods
|
||||
aggregator.push_context_frame = AsyncMock()
|
||||
aggregator.reset = AsyncMock()
|
||||
|
||||
# Process interruption
|
||||
interruption_frame = StartInterruptionFrame()
|
||||
await aggregator._handle_interruptions(interruption_frame)
|
||||
|
||||
# Verify the corrected text was added to context
|
||||
mock_context.add_message.assert_called_once()
|
||||
added_content = mock_context.add_message.call_args[0][0]
|
||||
|
||||
# Google uses Content objects
|
||||
assert isinstance(added_content, Content)
|
||||
assert added_content.role == "model"
|
||||
assert len(added_content.parts) == 1
|
||||
assert (
|
||||
added_content.parts[0].text == "I am here to help <<interrupted_by_user>>"
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_interruption_correction_error_handling(self):
|
||||
"""Test that interruption handling continues even if correction callback fails."""
|
||||
# Create mock context
|
||||
mock_context = Mock(spec=OpenAILLMContext)
|
||||
mock_context.get_messages.return_value = []
|
||||
mock_context.add_message = Mock()
|
||||
|
||||
# Create correction callback that raises error
|
||||
def failing_callback(text: str) -> str:
|
||||
raise ValueError("Correction failed")
|
||||
|
||||
# Create aggregator with failing callback
|
||||
params = LLMAssistantAggregatorParams(
|
||||
expect_stripped_words=True, correct_aggregation_callback=failing_callback
|
||||
)
|
||||
|
||||
aggregator = OpenAIAssistantContextAggregator(
|
||||
context=mock_context, params=params
|
||||
)
|
||||
|
||||
# Set up aggregation state
|
||||
aggregator._aggregation = "Some text"
|
||||
aggregator._current_llm_response_id = "test-id"
|
||||
aggregator._response_function_messages = {}
|
||||
aggregator._function_calls_in_progress = {}
|
||||
aggregator._started = 1
|
||||
|
||||
# Mock push_context_frame and reset methods
|
||||
aggregator.push_context_frame = AsyncMock()
|
||||
aggregator.reset = AsyncMock()
|
||||
|
||||
# Process interruption - should not raise
|
||||
interruption_frame = StartInterruptionFrame()
|
||||
await aggregator._handle_interruptions(interruption_frame)
|
||||
|
||||
# Verify the original text was still added (fallback behavior)
|
||||
mock_context.add_message.assert_called_once()
|
||||
added_message = mock_context.add_message.call_args[0][0]
|
||||
assert added_message["role"] == "assistant"
|
||||
assert added_message["content"] == "Some text <<interrupted_by_user>>"
|
||||
256
api/tests/conftest.py
Normal file
256
api/tests/conftest.py
Normal file
|
|
@ -0,0 +1,256 @@
|
|||
from dataclasses import dataclass
|
||||
from typing import Any, Dict
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
|
||||
from api.services.workflow.dto import (
|
||||
EdgeDataDTO,
|
||||
NodeDataDTO,
|
||||
NodeType,
|
||||
Position,
|
||||
ReactFlowDTO,
|
||||
RFEdgeDTO,
|
||||
RFNodeDTO,
|
||||
)
|
||||
from api.services.workflow.workflow import WorkflowGraph
|
||||
|
||||
START_CALL_SYSTEM_PROMPT = "start_call_system_prompt"
|
||||
END_CALL_SYSTEM_PROMPT = "end_call_system_prompt"
|
||||
|
||||
|
||||
@dataclass
|
||||
class MockToolModel:
|
||||
"""Mock tool model for testing."""
|
||||
|
||||
tool_uuid: str
|
||||
name: str
|
||||
description: str
|
||||
definition: Dict[str, Any]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_engine():
|
||||
"""Create a mock PipecatEngine."""
|
||||
engine = Mock()
|
||||
engine._workflow_run_id = 1
|
||||
engine._call_context_vars = {"customer_name": "John Doe"}
|
||||
engine.llm = Mock()
|
||||
engine.llm.register_function = Mock()
|
||||
return engine
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_tools():
|
||||
"""Create sample mock tools for testing."""
|
||||
return [
|
||||
MockToolModel(
|
||||
tool_uuid="weather-uuid-123",
|
||||
name="Get Weather",
|
||||
description="Get current weather for a location",
|
||||
definition={
|
||||
"schema_version": 1,
|
||||
"type": "http_api",
|
||||
"config": {
|
||||
"method": "GET",
|
||||
"url": "https://api.weather.com/current",
|
||||
"parameters": [
|
||||
{
|
||||
"name": "location",
|
||||
"type": "string",
|
||||
"description": "City name (e.g., San Francisco, CA)",
|
||||
"required": True,
|
||||
},
|
||||
{
|
||||
"name": "units",
|
||||
"type": "string",
|
||||
"description": "Temperature units: celsius or fahrenheit",
|
||||
"required": False,
|
||||
},
|
||||
],
|
||||
},
|
||||
},
|
||||
),
|
||||
MockToolModel(
|
||||
tool_uuid="booking-uuid-456",
|
||||
name="Book Appointment",
|
||||
description="Book an appointment for the customer",
|
||||
definition={
|
||||
"schema_version": 1,
|
||||
"type": "http_api",
|
||||
"config": {
|
||||
"method": "POST",
|
||||
"url": "https://api.example.com/appointments",
|
||||
"parameters": [
|
||||
{
|
||||
"name": "customer_name",
|
||||
"type": "string",
|
||||
"description": "Customer's full name",
|
||||
"required": True,
|
||||
},
|
||||
{
|
||||
"name": "date",
|
||||
"type": "string",
|
||||
"description": "Appointment date (YYYY-MM-DD)",
|
||||
"required": True,
|
||||
},
|
||||
{
|
||||
"name": "time",
|
||||
"type": "string",
|
||||
"description": "Appointment time (HH:MM)",
|
||||
"required": True,
|
||||
},
|
||||
{
|
||||
"name": "notes",
|
||||
"type": "string",
|
||||
"description": "Additional notes",
|
||||
"required": False,
|
||||
},
|
||||
],
|
||||
},
|
||||
},
|
||||
),
|
||||
MockToolModel(
|
||||
tool_uuid="lookup-uuid-789",
|
||||
name="Customer Lookup",
|
||||
description="Look up customer information by phone number",
|
||||
definition={
|
||||
"schema_version": 1,
|
||||
"type": "http_api",
|
||||
"config": {
|
||||
"method": "GET",
|
||||
"url": "https://api.example.com/customers/lookup",
|
||||
"parameters": [
|
||||
{
|
||||
"name": "phone",
|
||||
"type": "string",
|
||||
"description": "Customer phone number",
|
||||
"required": True,
|
||||
},
|
||||
],
|
||||
},
|
||||
},
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def simple_workflow() -> WorkflowGraph:
|
||||
"""Create a simple two-node workflow for testing.
|
||||
|
||||
The workflow has:
|
||||
- Start node with a prompt
|
||||
- End node with a prompt
|
||||
- One edge connecting them with label "End Call"
|
||||
"""
|
||||
dto = ReactFlowDTO(
|
||||
nodes=[
|
||||
RFNodeDTO(
|
||||
id="1",
|
||||
type=NodeType.startNode,
|
||||
position=Position(x=0, y=0),
|
||||
data=NodeDataDTO(
|
||||
name="Start Call",
|
||||
prompt=START_CALL_SYSTEM_PROMPT,
|
||||
is_start=True,
|
||||
allow_interrupt=False,
|
||||
add_global_prompt=False,
|
||||
),
|
||||
),
|
||||
RFNodeDTO(
|
||||
id="2",
|
||||
type=NodeType.endNode,
|
||||
position=Position(x=0, y=200),
|
||||
data=NodeDataDTO(
|
||||
name="End Call",
|
||||
prompt=END_CALL_SYSTEM_PROMPT,
|
||||
is_end=True,
|
||||
allow_interrupt=False,
|
||||
add_global_prompt=False,
|
||||
),
|
||||
),
|
||||
],
|
||||
edges=[
|
||||
RFEdgeDTO(
|
||||
id="1-2",
|
||||
source="1",
|
||||
target="2",
|
||||
data=EdgeDataDTO(
|
||||
label="End Call",
|
||||
condition="When the user says to end the call, end the call",
|
||||
),
|
||||
),
|
||||
],
|
||||
)
|
||||
return WorkflowGraph(dto)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def three_node_workflow() -> WorkflowGraph:
|
||||
"""Create a three-node workflow for testing with an intermediate agent node.
|
||||
|
||||
The workflow has:
|
||||
- Start node
|
||||
- Agent node (for collecting information)
|
||||
- End node
|
||||
"""
|
||||
dto = ReactFlowDTO(
|
||||
nodes=[
|
||||
RFNodeDTO(
|
||||
id="1",
|
||||
type=NodeType.startNode,
|
||||
position=Position(x=0, y=0),
|
||||
data=NodeDataDTO(
|
||||
name="Start Call",
|
||||
prompt=START_CALL_SYSTEM_PROMPT,
|
||||
is_start=True,
|
||||
allow_interrupt=True,
|
||||
add_global_prompt=False,
|
||||
),
|
||||
),
|
||||
RFNodeDTO(
|
||||
id="2",
|
||||
type=NodeType.agentNode,
|
||||
position=Position(x=0, y=200),
|
||||
data=NodeDataDTO(
|
||||
name="Collect Info",
|
||||
prompt="Help the user with their request. Ask clarifying questions if needed.",
|
||||
allow_interrupt=True,
|
||||
add_global_prompt=False,
|
||||
),
|
||||
),
|
||||
RFNodeDTO(
|
||||
id="3",
|
||||
type=NodeType.endNode,
|
||||
position=Position(x=0, y=400),
|
||||
data=NodeDataDTO(
|
||||
name="End Call",
|
||||
prompt=END_CALL_SYSTEM_PROMPT,
|
||||
is_end=True,
|
||||
allow_interrupt=False,
|
||||
add_global_prompt=False,
|
||||
),
|
||||
),
|
||||
],
|
||||
edges=[
|
||||
RFEdgeDTO(
|
||||
id="1-2",
|
||||
source="1",
|
||||
target="2",
|
||||
data=EdgeDataDTO(
|
||||
label="Collect Info",
|
||||
condition="When the user wants help, collect their information",
|
||||
),
|
||||
),
|
||||
RFEdgeDTO(
|
||||
id="2-3",
|
||||
source="2",
|
||||
target="3",
|
||||
data=EdgeDataDTO(
|
||||
label="End Call",
|
||||
condition="When the user is done or wants to end the call",
|
||||
),
|
||||
),
|
||||
],
|
||||
)
|
||||
return WorkflowGraph(dto)
|
||||
|
|
@ -10,14 +10,14 @@ def test_aggregation_fixer():
|
|||
creates a fresh callback for every (reference, corrupted) pair.
|
||||
|
||||
The production callback now needs a PipecatEngine instance with the
|
||||
`_current_llm_reference_text` set. For test-friendliness we mock a bare
|
||||
`_current_llm_generation_reference_text` set. For test-friendliness we mock a bare
|
||||
object providing just that attribute for each assertion so the original
|
||||
two-argument test cases remain unchanged.
|
||||
"""
|
||||
|
||||
def fixer(reference: str, corrupted: str) -> str: # noqa: D401
|
||||
mock_engine = Mock()
|
||||
mock_engine._current_llm_reference_text = reference
|
||||
mock_engine._current_llm_generation_reference_text = reference
|
||||
return create_aggregation_correction_callback(mock_engine)(corrupted)
|
||||
|
||||
##### Trailing extra Chars #####
|
||||
|
|
@ -172,7 +172,7 @@ def test_create_aggregation_correction_callback():
|
|||
"""Test the new aggregation correction callback creator."""
|
||||
# Mock engine with reference text
|
||||
mock_engine = Mock()
|
||||
mock_engine._current_llm_reference_text = "Good Morning Mr NARGES, My name is Alex and I am calling you from Consumer Services."
|
||||
mock_engine._current_llm_generation_reference_text = "Good Morning Mr NARGES, My name is Alex and I am calling you from Consumer Services."
|
||||
|
||||
# Create callback
|
||||
callback = create_aggregation_correction_callback(mock_engine)
|
||||
|
|
@ -187,6 +187,6 @@ def test_create_aggregation_correction_callback():
|
|||
)
|
||||
|
||||
# Test with no reference text
|
||||
mock_engine._current_llm_reference_text = ""
|
||||
mock_engine._current_llm_generation_reference_text = ""
|
||||
corrected = callback("Some corrupted text")
|
||||
assert corrected == "Some corrupted text" # Should return as-is when no reference
|
||||
|
|
@ -6,9 +6,7 @@ This module tests the full flow of:
|
|||
3. Verifying the context is properly configured for LLM generation
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
|
@ -17,126 +15,14 @@ from api.services.workflow.pipecat_engine_utils import (
|
|||
get_function_schema,
|
||||
update_llm_context,
|
||||
)
|
||||
from api.tests.conftest import MockToolModel
|
||||
from pipecat.adapters.schemas.function_schema import FunctionSchema
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
|
||||
|
||||
@dataclass
|
||||
class MockToolModel:
|
||||
"""Mock tool model for testing."""
|
||||
|
||||
tool_uuid: str
|
||||
name: str
|
||||
description: str
|
||||
definition: Dict[str, Any]
|
||||
|
||||
|
||||
class TestCustomToolManagerContextIntegration:
|
||||
"""Integration tests for CustomToolManager with LLMContext."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_engine(self):
|
||||
"""Create a mock PipecatEngine."""
|
||||
engine = Mock()
|
||||
engine._workflow_run_id = 1
|
||||
engine._call_context_vars = {"customer_name": "John Doe"}
|
||||
engine.llm = Mock()
|
||||
engine.llm.register_function = Mock()
|
||||
return engine
|
||||
|
||||
@pytest.fixture
|
||||
def sample_tools(self):
|
||||
"""Create sample mock tools for testing."""
|
||||
return [
|
||||
MockToolModel(
|
||||
tool_uuid="weather-uuid-123",
|
||||
name="Get Weather",
|
||||
description="Get current weather for a location",
|
||||
definition={
|
||||
"schema_version": 1,
|
||||
"type": "http_api",
|
||||
"config": {
|
||||
"method": "GET",
|
||||
"url": "https://api.weather.com/current",
|
||||
"parameters": [
|
||||
{
|
||||
"name": "location",
|
||||
"type": "string",
|
||||
"description": "City name (e.g., San Francisco, CA)",
|
||||
"required": True,
|
||||
},
|
||||
{
|
||||
"name": "units",
|
||||
"type": "string",
|
||||
"description": "Temperature units: celsius or fahrenheit",
|
||||
"required": False,
|
||||
},
|
||||
],
|
||||
},
|
||||
},
|
||||
),
|
||||
MockToolModel(
|
||||
tool_uuid="booking-uuid-456",
|
||||
name="Book Appointment",
|
||||
description="Book an appointment for the customer",
|
||||
definition={
|
||||
"schema_version": 1,
|
||||
"type": "http_api",
|
||||
"config": {
|
||||
"method": "POST",
|
||||
"url": "https://api.example.com/appointments",
|
||||
"parameters": [
|
||||
{
|
||||
"name": "customer_name",
|
||||
"type": "string",
|
||||
"description": "Customer's full name",
|
||||
"required": True,
|
||||
},
|
||||
{
|
||||
"name": "date",
|
||||
"type": "string",
|
||||
"description": "Appointment date (YYYY-MM-DD)",
|
||||
"required": True,
|
||||
},
|
||||
{
|
||||
"name": "time",
|
||||
"type": "string",
|
||||
"description": "Appointment time (HH:MM)",
|
||||
"required": True,
|
||||
},
|
||||
{
|
||||
"name": "notes",
|
||||
"type": "string",
|
||||
"description": "Additional notes",
|
||||
"required": False,
|
||||
},
|
||||
],
|
||||
},
|
||||
},
|
||||
),
|
||||
MockToolModel(
|
||||
tool_uuid="lookup-uuid-789",
|
||||
name="Customer Lookup",
|
||||
description="Look up customer information by phone number",
|
||||
definition={
|
||||
"schema_version": 1,
|
||||
"type": "http_api",
|
||||
"config": {
|
||||
"method": "GET",
|
||||
"url": "https://api.example.com/customers/lookup",
|
||||
"parameters": [
|
||||
{
|
||||
"name": "phone",
|
||||
"type": "string",
|
||||
"description": "Customer phone number",
|
||||
"required": True,
|
||||
},
|
||||
],
|
||||
},
|
||||
},
|
||||
),
|
||||
]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_tool_schemas_and_update_context(self, mock_engine, sample_tools):
|
||||
"""Test fetching tool schemas via CustomToolManager and updating LLM context."""
|
||||
|
|
|
|||
|
|
@ -6,6 +6,6 @@ from api.services.workflow.dto import ReactFlowDTO
|
|||
@pytest.mark.asyncio
|
||||
async def test_dto():
|
||||
# assert no exceptions are raised
|
||||
with open("services/workflow/test/definitions/rf-1.json", "r") as f:
|
||||
with open("tests/definitions/rf-1.json", "r") as f:
|
||||
dto = ReactFlowDTO.model_validate_json(f.read())
|
||||
assert dto is not None
|
||||
340
api/tests/test_pipecat_engine_tool_calls.py
Normal file
340
api/tests/test_pipecat_engine_tool_calls.py
Normal file
|
|
@ -0,0 +1,340 @@
|
|||
"""Tests for tool calls with PipecatEngine and MockLLM.
|
||||
|
||||
This module tests the behavior when the LLM generates tool calls (single or parallel),
|
||||
using PipecatEngine's actual function registration and execution logic.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from typing import Any, Dict, List
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from api.services.pipecat.pipeline_engine_callbacks_processor import (
|
||||
PipelineEngineCallbacksProcessor,
|
||||
)
|
||||
from api.services.workflow.pipecat_engine import PipecatEngine
|
||||
from api.services.workflow.workflow import WorkflowGraph
|
||||
from api.tests.conftest import END_CALL_SYSTEM_PROMPT
|
||||
from pipecat.frames.frames import (
|
||||
BotStartedSpeakingFrame,
|
||||
BotStoppedSpeakingFrame,
|
||||
Frame,
|
||||
TextFrame,
|
||||
)
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.processors.aggregators.llm_response import LLMAssistantAggregatorParams
|
||||
from pipecat.processors.aggregators.llm_response_universal import (
|
||||
LLMContextAggregatorPair,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
from pipecat.tests import MockLLMService
|
||||
|
||||
|
||||
class MockBotStoppedSpeakingOnLLMTextFrameProcessor(FrameProcessor):
|
||||
"""
|
||||
Mocking the transport, where transport sends BotStartedSpeakingFrame
|
||||
and BotStoppedSpeakingFrame when it encounters a LLMTextFrame.
|
||||
"""
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
if isinstance(frame, TextFrame):
|
||||
await self.push_frame(BotStartedSpeakingFrame())
|
||||
await self.push_frame(
|
||||
BotStartedSpeakingFrame(), direction=FrameDirection.UPSTREAM
|
||||
)
|
||||
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
await self.push_frame(BotStoppedSpeakingFrame())
|
||||
await self.push_frame(
|
||||
BotStoppedSpeakingFrame(), direction=FrameDirection.UPSTREAM
|
||||
)
|
||||
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
|
||||
async def run_pipeline_with_tool_calls(
|
||||
workflow: WorkflowGraph,
|
||||
functions: List[Dict[str, Any]],
|
||||
text: str | None = None,
|
||||
num_text_steps: int = 1,
|
||||
) -> tuple[MockLLMService, LLMContext]:
|
||||
"""Run a pipeline with mock tool calls and return the LLM for assertions.
|
||||
|
||||
Args:
|
||||
workflow: The workflow graph to use.
|
||||
functions: List of function call definitions with name, arguments, and tool_call_id.
|
||||
text: Text to add to the first step (streamed before the tool calls).
|
||||
num_text_steps: Number of text response steps after the tool calls.
|
||||
|
||||
Returns:
|
||||
The MockLLMService instance for making assertions.
|
||||
"""
|
||||
# Create first step chunks
|
||||
if text:
|
||||
# Create text chunks (without final chunk) followed by function call chunks
|
||||
text_chunks = MockLLMService.create_text_chunks(text)
|
||||
func_chunks = MockLLMService.create_multiple_function_call_chunks(functions)
|
||||
# Exclude the final chunk from text_chunks (which has finish_reason="stop")
|
||||
first_step_chunks = text_chunks[:-1] + func_chunks
|
||||
else:
|
||||
first_step_chunks = MockLLMService.create_multiple_function_call_chunks(
|
||||
functions
|
||||
)
|
||||
|
||||
# Create multi-step responses
|
||||
mock_steps = MockLLMService.create_multi_step_responses(
|
||||
first_step_chunks, num_text_steps=num_text_steps, step_prefix="Response"
|
||||
)
|
||||
|
||||
# Create MockLLMService with multi-step support
|
||||
llm = MockLLMService(mock_steps=mock_steps, chunk_delay=0.001)
|
||||
|
||||
mock_transport_emulator = MockBotStoppedSpeakingOnLLMTextFrameProcessor()
|
||||
|
||||
# Create LLM context
|
||||
context = LLMContext()
|
||||
|
||||
# Add assistant context aggregator
|
||||
assistant_params = LLMAssistantAggregatorParams(expect_stripped_words=True)
|
||||
context_aggregator = LLMContextAggregatorPair(
|
||||
context, assistant_params=assistant_params
|
||||
)
|
||||
assistant_context_aggregator = context_aggregator.assistant()
|
||||
|
||||
# Create PipecatEngine with the workflow
|
||||
engine = PipecatEngine(
|
||||
llm=llm,
|
||||
context=context,
|
||||
workflow=workflow,
|
||||
call_context_vars={"customer_name": "Test User"},
|
||||
workflow_run_id=1,
|
||||
)
|
||||
|
||||
# Create the pipeline with the mock LLM
|
||||
pipeline = Pipeline(
|
||||
[
|
||||
llm,
|
||||
mock_transport_emulator,
|
||||
assistant_context_aggregator,
|
||||
]
|
||||
)
|
||||
|
||||
# Create a real pipeline task
|
||||
task = PipelineTask(
|
||||
pipeline,
|
||||
params=PipelineParams(allow_interruptions=False),
|
||||
)
|
||||
|
||||
engine.set_task(task)
|
||||
|
||||
# Patch DB calls to avoid actual database access
|
||||
with patch(
|
||||
"api.services.workflow.pipecat_engine.get_organization_id_from_workflow_run",
|
||||
new_callable=AsyncMock,
|
||||
return_value=1,
|
||||
):
|
||||
with patch(
|
||||
"api.services.workflow.pipecat_engine.apply_disposition_mapping",
|
||||
new_callable=AsyncMock,
|
||||
return_value="completed",
|
||||
):
|
||||
runner = PipelineRunner()
|
||||
|
||||
async def run_pipeline():
|
||||
await runner.run(task)
|
||||
|
||||
async def initialize_engine():
|
||||
# Small delay to let runner start
|
||||
await asyncio.sleep(0.01)
|
||||
await engine.initialize()
|
||||
|
||||
# Run both concurrently
|
||||
await asyncio.gather(run_pipeline(), initialize_engine())
|
||||
|
||||
return llm, context
|
||||
|
||||
|
||||
class TestPipecatEngineToolCalls:
|
||||
"""Test tool calls through PipecatEngine."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_parallel_builtin_and_transition_calls_through_engine(
|
||||
self, simple_workflow: WorkflowGraph
|
||||
):
|
||||
"""Test parallel function calls using PipecatEngine's actual handlers.
|
||||
|
||||
This test verifies that when the LLM generates parallel tool calls for:
|
||||
1. A built-in function (safe_calculator) - registered by _register_builtin_functions
|
||||
2. A transition function (end_call) - registered by _register_transition_function_with_llm
|
||||
|
||||
Both functions are properly executed through the engine's handlers and
|
||||
the transition correctly moves to the end node.
|
||||
|
||||
The test uses multi-step mock responses:
|
||||
- Step 1: Parallel tool calls (safe_calculator + end_call)
|
||||
- Step 2+: Text responses for subsequent node prompts
|
||||
"""
|
||||
functions = [
|
||||
{
|
||||
"name": "end_call",
|
||||
"arguments": {},
|
||||
"tool_call_id": "call_transition",
|
||||
},
|
||||
{
|
||||
"name": "safe_calculator",
|
||||
"arguments": {"expression": "25 * 4"},
|
||||
"tool_call_id": "call_calc",
|
||||
},
|
||||
]
|
||||
|
||||
llm, context = await run_pipeline_with_tool_calls(
|
||||
workflow=simple_workflow,
|
||||
functions=functions,
|
||||
num_text_steps=2,
|
||||
)
|
||||
|
||||
# Assert that the LLM generation was called a total of 2 times,
|
||||
# 1st time when StartNode was executed, and second time
|
||||
# when EndCall generation happened
|
||||
assert llm.get_current_step() == 2, (
|
||||
"LLM generation should have happened 2 times"
|
||||
)
|
||||
|
||||
# Assert that the context was updated with END_CALL_SYSTEM_PROMPT
|
||||
assert context.messages[0]["content"] == END_CALL_SYSTEM_PROMPT
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_parallel_builtin_and_transition_calls_through_engine_1(
|
||||
self, simple_workflow: WorkflowGraph
|
||||
):
|
||||
"""Test parallel function calls using PipecatEngine's actual handlers.
|
||||
|
||||
This test verifies that when the LLM generates parallel tool calls for:
|
||||
1. A built-in function (safe_calculator) - registered by _register_builtin_functions
|
||||
2. A transition function (end_call) - registered by _register_transition_function_with_llm
|
||||
|
||||
Both functions are properly executed through the engine's handlers and
|
||||
the transition correctly moves to the end node.
|
||||
|
||||
The test uses multi-step mock responses:
|
||||
- Step 1: Parallel tool calls (safe_calculator + end_call)
|
||||
- Step 2+: Text responses for subsequent node prompts
|
||||
"""
|
||||
functions = [
|
||||
{
|
||||
"name": "safe_calculator",
|
||||
"arguments": {"expression": "25 * 4"},
|
||||
"tool_call_id": "call_calc",
|
||||
},
|
||||
{
|
||||
"name": "end_call",
|
||||
"arguments": {},
|
||||
"tool_call_id": "call_transition",
|
||||
},
|
||||
]
|
||||
|
||||
llm, context = await run_pipeline_with_tool_calls(
|
||||
workflow=simple_workflow,
|
||||
functions=functions,
|
||||
num_text_steps=2,
|
||||
)
|
||||
|
||||
# Assert that the LLM generation was called a total of 2 times,
|
||||
# 1st time when StartNode was executed, and second time
|
||||
# when EndCall generation happened. The tool should not invoke
|
||||
# an LLM generation
|
||||
assert llm.get_current_step() == 2, (
|
||||
"LLM generation should have happened 2 times"
|
||||
)
|
||||
|
||||
# Assert that the context was updated with END_CALL_SYSTEM_PROMPT
|
||||
assert context.messages[0]["content"] == END_CALL_SYSTEM_PROMPT
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_parallel_builtin_and_transition_calls_through_engine_with_text(
|
||||
self, simple_workflow: WorkflowGraph
|
||||
):
|
||||
"""Test parallel function calls using PipecatEngine's actual handlers.
|
||||
|
||||
This test verifies that when the LLM generates parallel tool calls for:
|
||||
1. A built-in function (safe_calculator) - registered by _register_builtin_functions
|
||||
2. A transition function (end_call) - registered by _register_transition_function_with_llm
|
||||
|
||||
Both functions are properly executed through the engine's handlers and
|
||||
the transition correctly moves to the end node.
|
||||
|
||||
The test uses multi-step mock responses:
|
||||
- Step 1: Parallel tool calls (safe_calculator + end_call)
|
||||
- Step 2+: Text responses for subsequent node prompts
|
||||
"""
|
||||
functions = [
|
||||
{
|
||||
"name": "end_call",
|
||||
"arguments": {},
|
||||
"tool_call_id": "call_transition",
|
||||
},
|
||||
{
|
||||
"name": "safe_calculator",
|
||||
"arguments": {"expression": "25 * 4"},
|
||||
"tool_call_id": "call_calc",
|
||||
},
|
||||
]
|
||||
|
||||
llm, context = await run_pipeline_with_tool_calls(
|
||||
workflow=simple_workflow,
|
||||
functions=functions,
|
||||
text="Hello There!",
|
||||
num_text_steps=2,
|
||||
)
|
||||
|
||||
# Assert that the LLM generation was called a total of 2 times,
|
||||
# 1st time when StartNode was executed, and second time
|
||||
# when EndCall generation happened. The tool should not invoke
|
||||
# an LLM generation
|
||||
assert llm.get_current_step() == 2, (
|
||||
"LLM generation should have happened 2 times"
|
||||
)
|
||||
|
||||
# Assert that the context was updated with END_CALL_SYSTEM_PROMPT
|
||||
assert context.messages[0]["content"] == END_CALL_SYSTEM_PROMPT
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_single_transition_call_through_engine(
|
||||
self, simple_workflow: WorkflowGraph
|
||||
):
|
||||
"""Test a single transition function call (end_call) through PipecatEngine.
|
||||
|
||||
This test verifies that when the LLM generates only a transition tool call,
|
||||
the engine properly executes it and transitions to the end node.
|
||||
Since end_call transitions to the end node which triggers another LLM
|
||||
generation, the LLM is called exactly once for the initial StartNode.
|
||||
"""
|
||||
functions = [
|
||||
{
|
||||
"name": "end_call",
|
||||
"arguments": {},
|
||||
"tool_call_id": "call_transition",
|
||||
},
|
||||
]
|
||||
|
||||
llm, context = await run_pipeline_with_tool_calls(
|
||||
workflow=simple_workflow,
|
||||
functions=functions,
|
||||
num_text_steps=1,
|
||||
)
|
||||
|
||||
# LLM is called once for the StartNode, then end_call transitions to EndNode
|
||||
# which triggers a second generation
|
||||
assert llm.get_current_step() == 2, (
|
||||
"LLM generation should have happened 2 times"
|
||||
)
|
||||
|
||||
# Assert that the context was updated with END_CALL_SYSTEM_PROMPT
|
||||
assert context.messages[0]["content"] == END_CALL_SYSTEM_PROMPT
|
||||
|
|
@ -27,6 +27,8 @@ import {
|
|||
type KeyValueItem,
|
||||
ParameterEditor,
|
||||
type ToolParameter,
|
||||
UrlInput,
|
||||
validateUrl,
|
||||
} from "@/components/http";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import {
|
||||
|
|
@ -151,8 +153,9 @@ export default function ToolDetailPage() {
|
|||
|
||||
const handleSave = async () => {
|
||||
// Validate URL
|
||||
if (!url.trim()) {
|
||||
setError("URL is required");
|
||||
const urlValidation = validateUrl(url);
|
||||
if (!urlValidation.valid) {
|
||||
setError(urlValidation.error || "Invalid URL");
|
||||
return;
|
||||
}
|
||||
|
||||
|
|
@ -431,10 +434,11 @@ const data = await response.json();`;
|
|||
|
||||
<div className="grid gap-2">
|
||||
<Label>Endpoint URL</Label>
|
||||
<Input
|
||||
<UrlInput
|
||||
value={url}
|
||||
onChange={(e) => setUrl(e.target.value)}
|
||||
onChange={setUrl}
|
||||
placeholder="https://api.example.com/appointments"
|
||||
showValidation
|
||||
/>
|
||||
</div>
|
||||
</TabsContent>
|
||||
|
|
|
|||
|
|
@ -10,6 +10,8 @@ import {
|
|||
HttpMethodSelector,
|
||||
KeyValueEditor,
|
||||
type KeyValueItem,
|
||||
UrlInput,
|
||||
validateUrl,
|
||||
} from "@/components/http";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { Input } from "@/components/ui/input";
|
||||
|
|
@ -57,8 +59,9 @@ export const WebhookNode = memo(({ data, selected, id }: WebhookNodeProps) => {
|
|||
|
||||
const handleSave = async () => {
|
||||
// Validate endpoint URL
|
||||
if (!endpointUrl.trim()) {
|
||||
setEndpointError('Endpoint URL is required');
|
||||
const urlValidation = validateUrl(endpointUrl);
|
||||
if (!urlValidation.valid) {
|
||||
setEndpointError(urlValidation.error || 'Invalid URL');
|
||||
return;
|
||||
}
|
||||
setEndpointError(null);
|
||||
|
|
@ -284,10 +287,11 @@ const WebhookNodeEditForm = ({
|
|||
<Label className="text-xs text-muted-foreground">
|
||||
The URL to send the webhook request to.
|
||||
</Label>
|
||||
<Input
|
||||
<UrlInput
|
||||
value={endpointUrl}
|
||||
onChange={(e) => setEndpointUrl(e.target.value)}
|
||||
onChange={setEndpointUrl}
|
||||
placeholder="https://api.example.com/webhook"
|
||||
showValidation
|
||||
/>
|
||||
</div>
|
||||
</TabsContent>
|
||||
|
|
|
|||
|
|
@ -3,3 +3,4 @@ export { CredentialSelector } from "./credential-selector";
|
|||
export { type HttpMethod, HttpMethodSelector } from "./http-method-selector";
|
||||
export { KeyValueEditor, type KeyValueItem } from "./key-value-editor";
|
||||
export { ParameterEditor, type ParameterType,type ToolParameter } from "./parameter-editor";
|
||||
export { UrlInput, type UrlValidationResult,validateUrl } from "./url-input";
|
||||
|
|
|
|||
106
ui/src/components/http/url-input.tsx
Normal file
106
ui/src/components/http/url-input.tsx
Normal file
|
|
@ -0,0 +1,106 @@
|
|||
"use client";
|
||||
|
||||
import { useCallback, useState } from "react";
|
||||
|
||||
import { Input } from "@/components/ui/input";
|
||||
import { cn } from "@/lib/utils";
|
||||
|
||||
// URL regex pattern that validates:
|
||||
// - http:// or https:// protocol (required)
|
||||
// - Optional username:password@
|
||||
// - Domain name or IP address
|
||||
// - Optional port number
|
||||
// - Optional path, query string, and fragment
|
||||
const URL_REGEX =
|
||||
/^https?:\/\/(?:[\w-]+(?::[\w-]+)?@)?(?:[\w-]+\.)*[\w-]+(?::\d{1,5})?(?:\/[^\s]*)?$/i;
|
||||
|
||||
export interface UrlValidationResult {
|
||||
valid: boolean;
|
||||
error?: string;
|
||||
}
|
||||
|
||||
export function validateUrl(url: string): UrlValidationResult {
|
||||
const trimmedUrl = url.trim();
|
||||
|
||||
if (!trimmedUrl) {
|
||||
return { valid: false, error: "URL is required" };
|
||||
}
|
||||
|
||||
if (!URL_REGEX.test(trimmedUrl)) {
|
||||
return {
|
||||
valid: false,
|
||||
error: "Invalid URL format. Must start with http:// or https://",
|
||||
};
|
||||
}
|
||||
|
||||
return { valid: true };
|
||||
}
|
||||
|
||||
interface UrlInputProps {
|
||||
value: string;
|
||||
onChange: (value: string) => void;
|
||||
placeholder?: string;
|
||||
disabled?: boolean;
|
||||
className?: string;
|
||||
/** Show validation error styling and message inline */
|
||||
showValidation?: boolean;
|
||||
/** Called when validation state changes */
|
||||
onValidationChange?: (result: UrlValidationResult) => void;
|
||||
}
|
||||
|
||||
export function UrlInput({
|
||||
value,
|
||||
onChange,
|
||||
placeholder = "https://api.example.com/endpoint",
|
||||
disabled = false,
|
||||
className,
|
||||
showValidation = false,
|
||||
onValidationChange,
|
||||
}: UrlInputProps) {
|
||||
const [touched, setTouched] = useState(false);
|
||||
|
||||
const handleChange = useCallback(
|
||||
(e: React.ChangeEvent<HTMLInputElement>) => {
|
||||
const newValue = e.target.value;
|
||||
onChange(newValue);
|
||||
|
||||
if (onValidationChange && (touched || newValue)) {
|
||||
onValidationChange(validateUrl(newValue));
|
||||
}
|
||||
},
|
||||
[onChange, onValidationChange, touched]
|
||||
);
|
||||
|
||||
const handleBlur = useCallback(() => {
|
||||
setTouched(true);
|
||||
const trimmedValue = value.trim();
|
||||
if (trimmedValue !== value) {
|
||||
onChange(trimmedValue);
|
||||
}
|
||||
if (onValidationChange && trimmedValue) {
|
||||
onValidationChange(validateUrl(trimmedValue));
|
||||
}
|
||||
}, [onChange, onValidationChange, value]);
|
||||
|
||||
const validation = validateUrl(value);
|
||||
const showError = showValidation && touched && !validation.valid && value;
|
||||
|
||||
return (
|
||||
<div className="space-y-1">
|
||||
<Input
|
||||
value={value}
|
||||
onChange={handleChange}
|
||||
onBlur={handleBlur}
|
||||
placeholder={placeholder}
|
||||
disabled={disabled}
|
||||
className={cn(
|
||||
showError && "border-destructive focus-visible:ring-destructive",
|
||||
className
|
||||
)}
|
||||
/>
|
||||
{showError && (
|
||||
<p className="text-xs text-destructive">{validation.error}</p>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue