Fix/multiple generation (#104)

* fixes #100

* Fix test

* fix: fix bad configuration issue
This commit is contained in:
Abhishek 2026-01-03 12:59:18 +05:30 committed by GitHub
parent 90b690efff
commit 56953bbd09
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
18 changed files with 758 additions and 460 deletions

View file

@ -1,5 +1,7 @@
from datetime import datetime, timezone
from loguru import logger
from pydantic import ValidationError
from sqlalchemy.future import select
from api.db.base_client import BaseDBClient
@ -66,12 +68,21 @@ class UserClient(BaseDBClient):
if not configuration_obj:
return UserConfiguration()
return UserConfiguration.model_validate(
{
**configuration_obj.configuration,
"last_validated_at": configuration_obj.last_validated_at,
}
)
try:
return UserConfiguration.model_validate(
{
**configuration_obj.configuration,
"last_validated_at": configuration_obj.last_validated_at,
}
)
except ValidationError as e:
# If configuration contains an unsupported provider,
# return a default configuration without failing
logger.warning(
f"Failed to validate user configuration for user {user_id}: {e}. "
"Returning default configuration."
)
return UserConfiguration()
async def update_user_configuration(
self, user_id: int, configuration: UserConfiguration

View file

@ -494,7 +494,6 @@ async def _run_pipeline(
max_duration_end_task_callback=engine.create_max_duration_callback(),
generation_started_callback=engine.create_generation_started_callback(),
llm_text_frame_callback=engine.handle_llm_text_frame,
# Note: speaking event callbacks are now handled by pre-aggregator processor
)
pipeline_metrics_aggregator = PipelineMetricsAggregator()

View file

@ -13,9 +13,8 @@ from pipecat.frames.frames import (
CancelFrame,
EndFrame,
FunctionCallResultProperties,
FunctionCallsFromLLMInfoFrame,
LLMContextFrame,
LLMFullResponseEndFrame,
LLMFullResponseStartFrame,
TTSSpeakFrame,
)
from pipecat.pipeline.task import PipelineTask
@ -104,7 +103,7 @@ class PipecatEngine:
self._builtin_function_schemas: Optional[list[dict]] = None
# Track current LLM reference text for TTS aggregation correction
self._current_llm_reference_text: str = ""
self._current_llm_generation_reference_text: str = ""
# Custom tool manager (initialized in initialize())
self._custom_tool_manager: Optional[CustomToolManager] = None
@ -173,6 +172,9 @@ class PipecatEngine:
await self._register_builtin_functions()
await self.set_node(self.workflow.start_node_id)
# Trigger initial LLM generation
await self.task.queue_frame(LLMContextFrame(self.context))
logger.debug(f"{self.__class__.__name__} initialized")
except Exception as e:
logger.error(f"Error initializing {self.__class__.__name__}: {e}")
@ -218,7 +220,6 @@ class PipecatEngine:
result = {"status": "done"}
properties = FunctionCallResultProperties(
run_llm=False,
on_context_updated=on_context_updated,
)
@ -256,8 +257,6 @@ class PipecatEngine:
"""Register built-in functions (calculator and timezone) with the LLM."""
logger.debug("Registering built-in functions with LLM")
properties = FunctionCallResultProperties(run_llm=True)
# Register calculator function
async def calculate_func(function_call_params: FunctionCallParams) -> None:
logger.info(f"LLM Function Call EXECUTED: safe_calculator")
@ -266,12 +265,10 @@ class PipecatEngine:
expr = function_call_params.arguments.get("expression", "")
result = safe_calculator(expr)
await function_call_params.result_callback(
{"expression": expr, "result": result}, properties=properties
{"expression": expr, "result": result}
)
except Exception as e:
await function_call_params.result_callback(
{"error": str(e)}, properties=properties
)
await function_call_params.result_callback({"error": str(e)})
# Register timezone functions
async def get_current_time_func(
@ -282,13 +279,9 @@ class PipecatEngine:
try:
timezone = function_call_params.arguments.get("timezone", "UTC")
result = get_current_time(timezone)
await function_call_params.result_callback(
result, properties=properties
)
await function_call_params.result_callback(result)
except Exception as e:
await function_call_params.result_callback(
{"error": str(e)}, properties=properties
)
await function_call_params.result_callback({"error": str(e)})
async def convert_time_func(function_call_params: FunctionCallParams) -> None:
logger.info(f"LLM Function Call EXECUTED: convert_time")
@ -299,29 +292,15 @@ class PipecatEngine:
function_call_params.arguments.get("time"),
function_call_params.arguments.get("target_timezone"),
)
await function_call_params.result_callback(
result, properties=properties
)
await function_call_params.result_callback(result)
except Exception as e:
await function_call_params.result_callback(
{"error": str(e)}, properties=properties
)
await function_call_params.result_callback({"error": str(e)})
# Register all built-in functions
self.llm.register_function("safe_calculator", calculate_func)
self.llm.register_function("get_current_time", get_current_time_func)
self.llm.register_function("convert_time", convert_time_func)
async def _queue_tts_response(self, text: str) -> None:
"""Queue TTS frames for static text response."""
await self.task.queue_frames(
[
LLMFullResponseStartFrame(),
TTSSpeakFrame(text=text),
LLMFullResponseEndFrame(),
]
)
async def _perform_variable_extraction_if_needed(
self, previous_node: Optional[Node]
) -> None:
@ -384,7 +363,6 @@ class PipecatEngine:
functions,
) = await self._compose_system_message_functions_for_node(node)
await self._update_llm_context(system_message, functions)
await self.task.queue_frame(LLMContextFrame(self.context))
async def set_node(self, node_id: str):
"""
@ -733,7 +711,7 @@ class PipecatEngine:
async def handle_llm_text_frame(self, text: str):
"""Accumulate LLM text frames to build reference text."""
self._current_llm_reference_text += text
self._current_llm_generation_reference_text += text
def handle_client_disconnected(self):
"""Handle client disconnected event."""

View file

@ -114,10 +114,10 @@ def create_max_duration_callback(engine: "PipecatEngine"):
def create_generation_started_callback(engine: "PipecatEngine"):
"""Return a callback that resets flags at the start of each LLM generation."""
async def handle_generation_started(): # noqa: D401
async def handle_generation_started():
logger.debug("LLM generation started in callback processor")
# Clear reference text from previous generation
engine._current_llm_reference_text = ""
engine._current_llm_generation_reference_text = ""
return handle_generation_started
@ -184,7 +184,7 @@ def create_aggregation_correction_callback(engine: "PipecatEngine"):
return "".join(out_chars)
def correct_aggregation(corrupted: str) -> str:
reference = engine._current_llm_reference_text
reference = engine._current_llm_generation_reference_text
if not reference:
logger.warning("No reference text available for aggregation correction")

View file

@ -1,128 +0,0 @@
from unittest.mock import Mock
import pytest
from pipecat.processors.aggregators.llm_response import LLMAssistantAggregatorParams
from pipecat.services.openai.llm import OpenAILLMContext
from api.services.workflow.pipecat_engine import PipecatEngine
from api.services.workflow.pipecat_engine_callbacks import (
create_generation_started_callback,
)
class TestAggregationIntegration:
    """Integration-level checks for the TTS aggregation correction flow."""

    @pytest.mark.asyncio
    async def test_engine_reference_text_tracking(self):
        """Reference text grows with each LLM text frame and resets on a new generation."""
        # Workflow double exposing a single static start node with no outgoing edges.
        workflow = Mock()
        workflow.start_node_id = "start"
        workflow.nodes = {
            "start": Mock(is_start=True, is_static=True, is_end=False, out_edges=[])
        }

        engine = PipecatEngine(
            task=Mock(),
            llm=Mock(),
            context=Mock(spec=OpenAILLMContext),
            tts=Mock(),
            workflow=workflow,
            call_context_vars={},
            workflow_run_id=1,
        )

        # Nothing has been generated yet, so the reference text starts empty.
        assert engine._current_llm_reference_text == ""

        # Every text frame is appended to the running reference text.
        await engine.handle_llm_text_frame("Hello ")
        assert engine._current_llm_reference_text == "Hello "

        await engine.handle_llm_text_frame("world!")
        assert engine._current_llm_reference_text == "Hello world!"

        # The generation-started callback wipes the text of the previous generation.
        on_generation_started = create_generation_started_callback(engine)
        await on_generation_started()
        assert engine._current_llm_reference_text == ""

    @pytest.mark.asyncio
    async def test_aggregation_correction_callback_creation(self):
        """The engine-built correction callback restores punctuation from the reference."""
        engine = PipecatEngine(
            task=Mock(),
            llm=Mock(),
            context=Mock(spec=OpenAILLMContext),
            workflow=Mock(),
            call_context_vars={},
            workflow_run_id=1,
        )

        # Seed the reference text the callback corrects against.
        engine._current_llm_reference_text = "Hello, world! How are you?"

        fix_aggregation = engine.create_aggregation_correction_callback()

        # Trailing punctuation may be dropped when the corrupted text lacks it.
        repaired = fix_aggregation("Hello world How are you")
        assert repaired == "Hello, world! How are you"

    def test_llm_assistant_aggregator_params_with_callback(self):
        """LLMAssistantAggregatorParams stores and exposes a correction callback."""

        def to_upper(text: str) -> str:
            return text.upper()

        aggregator_params = LLMAssistantAggregatorParams(
            expect_stripped_words=True, correct_aggregation_callback=to_upper
        )

        assert aggregator_params.expect_stripped_words is True
        assert aggregator_params.correct_aggregation_callback is not None
        assert aggregator_params.correct_aggregation_callback("hello") == "HELLO"

    @pytest.mark.asyncio
    async def test_pipeline_callbacks_processor_llm_text_frame(self):
        """PipelineEngineCallbacksProcessor forwards LLMTextFrame text to its callback."""
        from pipecat.frames.frames import LLMTextFrame
        from pipecat.processors.frame_processor import FrameDirection

        from api.services.pipecat.pipeline_engine_callbacks_processor import (
            PipelineEngineCallbacksProcessor,
        )

        # Records whether the callback fired and the last text it received.
        seen = {"invoked": False, "text": None}

        async def record_llm_text(text: str):
            seen["invoked"] = True
            seen["text"] = text

        processor = PipelineEngineCallbacksProcessor(
            llm_text_frame_callback=record_llm_text
        )

        # Feed a single text frame downstream through the processor.
        await processor.process_frame(
            LLMTextFrame(text="Hello world"), FrameDirection.DOWNSTREAM
        )

        assert seen["invoked"] is True
        assert seen["text"] == "Hello world"

View file

@ -1,159 +0,0 @@
from unittest.mock import AsyncMock, Mock
import pytest
from pipecat.frames.frames import StartInterruptionFrame
from pipecat.processors.aggregators.llm_response import LLMAssistantAggregatorParams
from pipecat.services.openai.llm import (
OpenAIAssistantContextAggregator,
OpenAILLMContext,
)
class TestInterruptionCorrection:
    """Checks that TTS aggregation correction is applied when the bot is interrupted."""

    @staticmethod
    def _make_context():
        """Return a context double with no messages and a recordable ``add_message``."""
        context = Mock(spec=OpenAILLMContext)
        context.get_messages.return_value = []
        context.add_message = Mock()
        return context

    @staticmethod
    def _prime_for_interruption(aggregator, spoken_text):
        """Put *aggregator* mid-response with *spoken_text* aggregated so far."""
        aggregator._aggregation = spoken_text
        aggregator._current_llm_response_id = "test-id"
        aggregator._response_function_messages = {}
        aggregator._function_calls_in_progress = {}
        aggregator._started = 1

        # Stub out the methods that would otherwise push frames / clear state.
        aggregator.push_context_frame = AsyncMock()
        aggregator.reset = AsyncMock()

    @pytest.mark.asyncio
    async def test_openai_interruption_with_correction(self):
        """The OpenAI assistant aggregator stores the corrected text on interruption."""
        context = self._make_context()

        def correction_callback(text: str) -> str:
            # Simulate repairing one specific corrupted aggregation.
            if text == "Hello world how are you":
                return "Hello world, how are you"
            return text

        aggregator = OpenAIAssistantContextAggregator(
            context=context,
            params=LLMAssistantAggregatorParams(
                expect_stripped_words=True,
                correct_aggregation_callback=correction_callback,
            ),
        )
        self._prime_for_interruption(aggregator, "Hello world how are you")

        await aggregator._handle_interruptions(StartInterruptionFrame())

        # Exactly one assistant message, carrying the corrected text plus the marker.
        context.add_message.assert_called_once()
        stored = context.add_message.call_args[0][0]
        assert stored["role"] == "assistant"
        assert (
            stored["content"]
            == "Hello world, how are you <<interrupted_by_user>>"
        )

    @pytest.mark.asyncio
    async def test_google_interruption_with_correction(self):
        """The Google assistant aggregator stores the corrected text on interruption."""
        from pipecat.services.google.llm import (
            Content,
            GoogleAssistantContextAggregator,
        )

        context = self._make_context()

        def correction_callback(text: str) -> str:
            # NOTE(review): as written, input and output strings are identical,
            # so this callback does not visibly change the text - confirm intent.
            if text == "I am here to help":
                return "I am here to help"
            return text

        aggregator = GoogleAssistantContextAggregator(
            context=context,
            params=LLMAssistantAggregatorParams(
                expect_stripped_words=True,
                correct_aggregation_callback=correction_callback,
            ),
        )
        self._prime_for_interruption(aggregator, "I am here to help")

        await aggregator._handle_interruptions(StartInterruptionFrame())

        context.add_message.assert_called_once()
        stored = context.add_message.call_args[0][0]

        # The Google aggregator records Content objects rather than plain dicts.
        assert isinstance(stored, Content)
        assert stored.role == "model"
        assert len(stored.parts) == 1
        assert (
            stored.parts[0].text == "I am here to help <<interrupted_by_user>>"
        )

    @pytest.mark.asyncio
    async def test_interruption_correction_error_handling(self):
        """A raising correction callback must not break interruption handling."""
        context = self._make_context()

        def failing_callback(text: str) -> str:
            raise ValueError("Correction failed")

        aggregator = OpenAIAssistantContextAggregator(
            context=context,
            params=LLMAssistantAggregatorParams(
                expect_stripped_words=True,
                correct_aggregation_callback=failing_callback,
            ),
        )
        self._prime_for_interruption(aggregator, "Some text")

        # Must complete without propagating the callback's ValueError.
        await aggregator._handle_interruptions(StartInterruptionFrame())

        # Fallback: the uncorrected text is still recorded.
        context.add_message.assert_called_once()
        stored = context.add_message.call_args[0][0]
        assert stored["role"] == "assistant"
        assert stored["content"] == "Some text <<interrupted_by_user>>"

256
api/tests/conftest.py Normal file
View file

@ -0,0 +1,256 @@
from dataclasses import dataclass
from typing import Any, Dict
from unittest.mock import Mock
import pytest
from api.services.workflow.dto import (
EdgeDataDTO,
NodeDataDTO,
NodeType,
Position,
ReactFlowDTO,
RFEdgeDTO,
RFNodeDTO,
)
from api.services.workflow.workflow import WorkflowGraph
START_CALL_SYSTEM_PROMPT = "start_call_system_prompt"
END_CALL_SYSTEM_PROMPT = "end_call_system_prompt"
@dataclass
class MockToolModel:
    """Plain data holder standing in for a tool record in tests.

    It has no behaviour beyond what ``@dataclass`` generates.
    """

    # Identifier of the tool, e.g. "weather-uuid-123" in the sample fixtures.
    tool_uuid: str
    # Human-readable tool name, e.g. "Get Weather".
    name: str
    # Short description of what the tool does.
    description: str
    # Tool definition payload; the sample fixtures populate the
    # "schema_version", "type" and "config" keys.
    definition: Dict[str, Any]
@pytest.fixture
def mock_engine():
    """Provide a ``Mock`` that stands in for a PipecatEngine.

    The double carries a workflow run id, one call-context variable and an
    ``llm`` attribute whose ``register_function`` is itself a ``Mock``.
    """
    llm_double = Mock()
    llm_double.register_function = Mock()

    engine_double = Mock()
    engine_double._workflow_run_id = 1
    engine_double._call_context_vars = {"customer_name": "John Doe"}
    engine_double.llm = llm_double
    return engine_double
@pytest.fixture
def sample_tools():
    """Provide three sample HTTP-API tool records: weather, booking and lookup."""

    def string_param(name, description, required):
        # Every parameter in these samples is a string-typed field.
        return {
            "name": name,
            "type": "string",
            "description": description,
            "required": required,
        }

    def http_tool(tool_uuid, name, description, method, url, parameters):
        # All samples share schema version 1 and the "http_api" tool type.
        return MockToolModel(
            tool_uuid=tool_uuid,
            name=name,
            description=description,
            definition={
                "schema_version": 1,
                "type": "http_api",
                "config": {
                    "method": method,
                    "url": url,
                    "parameters": parameters,
                },
            },
        )

    weather_tool = http_tool(
        "weather-uuid-123",
        "Get Weather",
        "Get current weather for a location",
        "GET",
        "https://api.weather.com/current",
        [
            string_param("location", "City name (e.g., San Francisco, CA)", True),
            string_param(
                "units", "Temperature units: celsius or fahrenheit", False
            ),
        ],
    )

    booking_tool = http_tool(
        "booking-uuid-456",
        "Book Appointment",
        "Book an appointment for the customer",
        "POST",
        "https://api.example.com/appointments",
        [
            string_param("customer_name", "Customer's full name", True),
            string_param("date", "Appointment date (YYYY-MM-DD)", True),
            string_param("time", "Appointment time (HH:MM)", True),
            string_param("notes", "Additional notes", False),
        ],
    )

    lookup_tool = http_tool(
        "lookup-uuid-789",
        "Customer Lookup",
        "Look up customer information by phone number",
        "GET",
        "https://api.example.com/customers/lookup",
        [
            string_param("phone", "Customer phone number", True),
        ],
    )

    return [weather_tool, booking_tool, lookup_tool]
@pytest.fixture
def simple_workflow() -> WorkflowGraph:
    """Build the minimal start -> end workflow used by the tests.

    Layout:
    - node "1": start node prompted with START_CALL_SYSTEM_PROMPT
    - node "2": end node prompted with END_CALL_SYSTEM_PROMPT
    - edge "1-2": the only transition, labelled "End Call"
    """
    start_node = RFNodeDTO(
        id="1",
        type=NodeType.startNode,
        position=Position(x=0, y=0),
        data=NodeDataDTO(
            name="Start Call",
            prompt=START_CALL_SYSTEM_PROMPT,
            is_start=True,
            allow_interrupt=False,
            add_global_prompt=False,
        ),
    )
    end_node = RFNodeDTO(
        id="2",
        type=NodeType.endNode,
        position=Position(x=0, y=200),
        data=NodeDataDTO(
            name="End Call",
            prompt=END_CALL_SYSTEM_PROMPT,
            is_end=True,
            allow_interrupt=False,
            add_global_prompt=False,
        ),
    )
    end_call_edge = RFEdgeDTO(
        id="1-2",
        source="1",
        target="2",
        data=EdgeDataDTO(
            label="End Call",
            condition="When the user says to end the call, end the call",
        ),
    )

    graph_definition = ReactFlowDTO(
        nodes=[start_node, end_node],
        edges=[end_call_edge],
    )
    return WorkflowGraph(graph_definition)
@pytest.fixture
def three_node_workflow() -> WorkflowGraph:
    """Build a start -> agent -> end workflow used by the tests.

    Layout:
    - node "1": start node
    - node "2": agent node (for collecting information)
    - node "3": end node
    - edges "1-2" and "2-3" chain them in that order
    """
    start_node = RFNodeDTO(
        id="1",
        type=NodeType.startNode,
        position=Position(x=0, y=0),
        data=NodeDataDTO(
            name="Start Call",
            prompt=START_CALL_SYSTEM_PROMPT,
            is_start=True,
            allow_interrupt=True,
            add_global_prompt=False,
        ),
    )
    agent_node = RFNodeDTO(
        id="2",
        type=NodeType.agentNode,
        position=Position(x=0, y=200),
        data=NodeDataDTO(
            name="Collect Info",
            prompt="Help the user with their request. Ask clarifying questions if needed.",
            allow_interrupt=True,
            add_global_prompt=False,
        ),
    )
    end_node = RFNodeDTO(
        id="3",
        type=NodeType.endNode,
        position=Position(x=0, y=400),
        data=NodeDataDTO(
            name="End Call",
            prompt=END_CALL_SYSTEM_PROMPT,
            is_end=True,
            allow_interrupt=False,
            add_global_prompt=False,
        ),
    )

    start_to_agent = RFEdgeDTO(
        id="1-2",
        source="1",
        target="2",
        data=EdgeDataDTO(
            label="Collect Info",
            condition="When the user wants help, collect their information",
        ),
    )
    agent_to_end = RFEdgeDTO(
        id="2-3",
        source="2",
        target="3",
        data=EdgeDataDTO(
            label="End Call",
            condition="When the user is done or wants to end the call",
        ),
    )

    graph_definition = ReactFlowDTO(
        nodes=[start_node, agent_node, end_node],
        edges=[start_to_agent, agent_to_end],
    )
    return WorkflowGraph(graph_definition)

View file

@ -10,14 +10,14 @@ def test_aggregation_fixer():
creates a fresh callback for every (reference, corrupted) pair.
The production callback now needs a PipecatEngine instance with the
`_current_llm_reference_text` set. For test-friendliness we mock a bare
`_current_llm_generation_reference_text` set. For test-friendliness we mock a bare
object providing just that attribute for each assertion so the original
two-argument test cases remain unchanged.
"""
def fixer(reference: str, corrupted: str) -> str: # noqa: D401
mock_engine = Mock()
mock_engine._current_llm_reference_text = reference
mock_engine._current_llm_generation_reference_text = reference
return create_aggregation_correction_callback(mock_engine)(corrupted)
##### Trailing extra Chars #####
@ -172,7 +172,7 @@ def test_create_aggregation_correction_callback():
"""Test the new aggregation correction callback creator."""
# Mock engine with reference text
mock_engine = Mock()
mock_engine._current_llm_reference_text = "Good Morning Mr NARGES, My name is Alex and I am calling you from Consumer Services."
mock_engine._current_llm_generation_reference_text = "Good Morning Mr NARGES, My name is Alex and I am calling you from Consumer Services."
# Create callback
callback = create_aggregation_correction_callback(mock_engine)
@ -187,6 +187,6 @@ def test_create_aggregation_correction_callback():
)
# Test with no reference text
mock_engine._current_llm_reference_text = ""
mock_engine._current_llm_generation_reference_text = ""
corrected = callback("Some corrupted text")
assert corrected == "Some corrupted text" # Should return as-is when no reference

View file

@ -6,9 +6,7 @@ This module tests the full flow of:
3. Verifying the context is properly configured for LLM generation
"""
from dataclasses import dataclass
from typing import Any, Dict
from unittest.mock import AsyncMock, Mock, patch
from unittest.mock import AsyncMock, patch
import pytest
@ -17,126 +15,14 @@ from api.services.workflow.pipecat_engine_utils import (
get_function_schema,
update_llm_context,
)
from api.tests.conftest import MockToolModel
from pipecat.adapters.schemas.function_schema import FunctionSchema
from pipecat.processors.aggregators.llm_context import LLMContext
@dataclass
class MockToolModel:
"""Mock tool model for testing."""
tool_uuid: str
name: str
description: str
definition: Dict[str, Any]
class TestCustomToolManagerContextIntegration:
"""Integration tests for CustomToolManager with LLMContext."""
@pytest.fixture
def mock_engine(self):
"""Create a mock PipecatEngine."""
engine = Mock()
engine._workflow_run_id = 1
engine._call_context_vars = {"customer_name": "John Doe"}
engine.llm = Mock()
engine.llm.register_function = Mock()
return engine
@pytest.fixture
def sample_tools(self):
"""Create sample mock tools for testing."""
return [
MockToolModel(
tool_uuid="weather-uuid-123",
name="Get Weather",
description="Get current weather for a location",
definition={
"schema_version": 1,
"type": "http_api",
"config": {
"method": "GET",
"url": "https://api.weather.com/current",
"parameters": [
{
"name": "location",
"type": "string",
"description": "City name (e.g., San Francisco, CA)",
"required": True,
},
{
"name": "units",
"type": "string",
"description": "Temperature units: celsius or fahrenheit",
"required": False,
},
],
},
},
),
MockToolModel(
tool_uuid="booking-uuid-456",
name="Book Appointment",
description="Book an appointment for the customer",
definition={
"schema_version": 1,
"type": "http_api",
"config": {
"method": "POST",
"url": "https://api.example.com/appointments",
"parameters": [
{
"name": "customer_name",
"type": "string",
"description": "Customer's full name",
"required": True,
},
{
"name": "date",
"type": "string",
"description": "Appointment date (YYYY-MM-DD)",
"required": True,
},
{
"name": "time",
"type": "string",
"description": "Appointment time (HH:MM)",
"required": True,
},
{
"name": "notes",
"type": "string",
"description": "Additional notes",
"required": False,
},
],
},
},
),
MockToolModel(
tool_uuid="lookup-uuid-789",
name="Customer Lookup",
description="Look up customer information by phone number",
definition={
"schema_version": 1,
"type": "http_api",
"config": {
"method": "GET",
"url": "https://api.example.com/customers/lookup",
"parameters": [
{
"name": "phone",
"type": "string",
"description": "Customer phone number",
"required": True,
},
],
},
},
),
]
@pytest.mark.asyncio
async def test_get_tool_schemas_and_update_context(self, mock_engine, sample_tools):
"""Test fetching tool schemas via CustomToolManager and updating LLM context."""

View file

@ -6,6 +6,6 @@ from api.services.workflow.dto import ReactFlowDTO
@pytest.mark.asyncio
async def test_dto():
# assert no exceptions are raised
with open("services/workflow/test/definitions/rf-1.json", "r") as f:
with open("tests/definitions/rf-1.json", "r") as f:
dto = ReactFlowDTO.model_validate_json(f.read())
assert dto is not None

View file

@ -0,0 +1,340 @@
"""Tests for tool calls with PipecatEngine and MockLLM.
This module tests the behavior when the LLM generates tool calls (single or parallel),
using PipecatEngine's actual function registration and execution logic.
"""
import asyncio
from typing import Any, Dict, List
from unittest.mock import AsyncMock, patch
import pytest
from api.services.pipecat.pipeline_engine_callbacks_processor import (
PipelineEngineCallbacksProcessor,
)
from api.services.workflow.pipecat_engine import PipecatEngine
from api.services.workflow.workflow import WorkflowGraph
from api.tests.conftest import END_CALL_SYSTEM_PROMPT
from pipecat.frames.frames import (
BotStartedSpeakingFrame,
BotStoppedSpeakingFrame,
Frame,
TextFrame,
)
from pipecat.pipeline.pipeline import Pipeline
from pipecat.pipeline.runner import PipelineRunner
from pipecat.pipeline.task import PipelineParams, PipelineTask
from pipecat.processors.aggregators.llm_context import LLMContext
from pipecat.processors.aggregators.llm_response import LLMAssistantAggregatorParams
from pipecat.processors.aggregators.llm_response_universal import (
LLMContextAggregatorPair,
)
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
from pipecat.tests import MockLLMService
class MockBotStoppedSpeakingOnLLMTextFrameProcessor(FrameProcessor):
    """Test double that emulates a transport's speaking notifications.

    For every ``TextFrame`` it sees, it emits ``BotStartedSpeakingFrame`` and,
    after a short pause, ``BotStoppedSpeakingFrame`` - each one both downstream
    and upstream - before forwarding the original frame.
    """

    async def _push_in_both_directions(self, frame_cls):
        """Push a fresh ``frame_cls()`` downstream, then another one upstream."""
        await self.push_frame(frame_cls())
        await self.push_frame(frame_cls(), direction=FrameDirection.UPSTREAM)

    async def process_frame(self, frame: Frame, direction: FrameDirection):
        """Emit speaking start/stop frames around text frames, then forward *frame*."""
        await super().process_frame(frame, direction)

        if isinstance(frame, TextFrame):
            await self._push_in_both_directions(BotStartedSpeakingFrame)
            # Brief pause standing in for the time the bot spends speaking.
            await asyncio.sleep(0.1)
            await self._push_in_both_directions(BotStoppedSpeakingFrame)

        # Every frame, text or not, continues in its original direction.
        await self.push_frame(frame, direction)
async def run_pipeline_with_tool_calls(
    workflow: WorkflowGraph,
    functions: List[Dict[str, Any]],
    text: str | None = None,
    num_text_steps: int = 1,
) -> tuple[MockLLMService, LLMContext]:
    """Run a pipeline with mock tool calls and return the LLM and context for assertions.

    Args:
        workflow: The workflow graph to use.
        functions: List of function call definitions with name, arguments, and tool_call_id.
        text: Text to add to the first step (streamed before the tool calls).
        num_text_steps: Number of text response steps after the tool calls.

    Returns:
        A ``(llm, context)`` tuple: the MockLLMService that produced the
        scripted responses and the LLMContext shared with the engine.
    """
    # Build the chunks for the first LLM step.
    if text:
        # Text chunks first, then the function-call chunks. The last text chunk
        # (which has finish_reason="stop") is dropped so the step does not end
        # before the tool calls are streamed.
        text_chunks = MockLLMService.create_text_chunks(text)
        func_chunks = MockLLMService.create_multiple_function_call_chunks(functions)
        first_step_chunks = text_chunks[:-1] + func_chunks
    else:
        first_step_chunks = MockLLMService.create_multiple_function_call_chunks(
            functions
        )

    # Script the remaining steps as plain text responses.
    mock_steps = MockLLMService.create_multi_step_responses(
        first_step_chunks, num_text_steps=num_text_steps, step_prefix="Response"
    )

    llm = MockLLMService(mock_steps=mock_steps, chunk_delay=0.001)
    mock_transport_emulator = MockBotStoppedSpeakingOnLLMTextFrameProcessor()

    # Context plus the assistant-side aggregator that writes into it.
    context = LLMContext()
    assistant_params = LLMAssistantAggregatorParams(expect_stripped_words=True)
    context_aggregator = LLMContextAggregatorPair(
        context, assistant_params=assistant_params
    )
    assistant_context_aggregator = context_aggregator.assistant()

    # Engine under test, driven by the supplied workflow.
    engine = PipecatEngine(
        llm=llm,
        context=context,
        workflow=workflow,
        call_context_vars={"customer_name": "Test User"},
        workflow_run_id=1,
    )

    pipeline = Pipeline(
        [
            llm,
            mock_transport_emulator,
            assistant_context_aggregator,
        ]
    )

    # A real pipeline task, so frames flow through actual pipecat machinery.
    task = PipelineTask(
        pipeline,
        params=PipelineParams(allow_interruptions=False),
    )
    engine.set_task(task)

    # Patch DB calls to avoid actual database access; both patches stay active
    # for the whole run.
    with (
        patch(
            "api.services.workflow.pipecat_engine.get_organization_id_from_workflow_run",
            new_callable=AsyncMock,
            return_value=1,
        ),
        patch(
            "api.services.workflow.pipecat_engine.apply_disposition_mapping",
            new_callable=AsyncMock,
            return_value="completed",
        ),
    ):
        runner = PipelineRunner()

        async def run_pipeline():
            await runner.run(task)

        async def initialize_engine():
            # Small delay to let runner start
            await asyncio.sleep(0.01)
            await engine.initialize()

        # Run both concurrently
        await asyncio.gather(run_pipeline(), initialize_engine())

    return llm, context
class TestPipecatEngineToolCalls:
    """Test tool calls through PipecatEngine.

    Every test drives a pipeline via ``run_pipeline_with_tool_calls`` with a
    mocked LLM whose first step emits the given tool calls, then checks that
    the engine ended up on the end node after exactly two LLM generations.
    """

    @staticmethod
    def _assert_transitioned_to_end_node(llm, context) -> None:
        """Assert the shared post-conditions of every test in this class.

        Args:
            llm: The mock LLM returned by ``run_pipeline_with_tool_calls``;
                only ``get_current_step()`` is used.
            context: The LLM context returned by the same helper; only
                ``messages[0]["content"]`` is inspected.
        """
        # Two generations are expected: the first when the StartNode was
        # executed, the second when the EndCall generation happened. The
        # tool call itself should not invoke an LLM generation.
        assert llm.get_current_step() == 2, (
            "LLM generation should have happened 2 times"
        )
        # The context must have been updated with END_CALL_SYSTEM_PROMPT.
        assert context.messages[0]["content"] == END_CALL_SYSTEM_PROMPT

    @pytest.mark.asyncio
    async def test_parallel_builtin_and_transition_calls_through_engine(
        self, simple_workflow: WorkflowGraph
    ):
        """Test parallel calls where the transition precedes the built-in.

        The LLM generates parallel tool calls for:
        1. A transition function (end_call) - registered by
           _register_transition_function_with_llm
        2. A built-in function (safe_calculator) - registered by
           _register_builtin_functions

        Both functions are executed through the engine's handlers and the
        transition moves to the end node.

        The test uses multi-step mock responses:
        - Step 1: Parallel tool calls (end_call + safe_calculator)
        - Step 2+: Text responses for subsequent node prompts
        """
        functions = [
            {
                "name": "end_call",
                "arguments": {},
                "tool_call_id": "call_transition",
            },
            {
                "name": "safe_calculator",
                "arguments": {"expression": "25 * 4"},
                "tool_call_id": "call_calc",
            },
        ]

        llm, context = await run_pipeline_with_tool_calls(
            workflow=simple_workflow,
            functions=functions,
            num_text_steps=2,
        )

        self._assert_transitioned_to_end_node(llm, context)

    @pytest.mark.asyncio
    async def test_parallel_builtin_and_transition_calls_through_engine_1(
        self, simple_workflow: WorkflowGraph
    ):
        """Test parallel calls where the built-in precedes the transition.

        Same scenario as
        ``test_parallel_builtin_and_transition_calls_through_engine`` but with
        the tool calls in the opposite order, so the outcome must not depend
        on which call the LLM lists first.

        The test uses multi-step mock responses:
        - Step 1: Parallel tool calls (safe_calculator + end_call)
        - Step 2+: Text responses for subsequent node prompts
        """
        functions = [
            {
                "name": "safe_calculator",
                "arguments": {"expression": "25 * 4"},
                "tool_call_id": "call_calc",
            },
            {
                "name": "end_call",
                "arguments": {},
                "tool_call_id": "call_transition",
            },
        ]

        llm, context = await run_pipeline_with_tool_calls(
            workflow=simple_workflow,
            functions=functions,
            num_text_steps=2,
        )

        self._assert_transitioned_to_end_node(llm, context)

    @pytest.mark.asyncio
    async def test_parallel_builtin_and_transition_calls_through_engine_with_text(
        self, simple_workflow: WorkflowGraph
    ):
        """Test parallel calls that are accompanied by a text argument.

        Same tool calls (end_call + safe_calculator) as
        ``test_parallel_builtin_and_transition_calls_through_engine``, but the
        helper is additionally given ``text="Hello There!"``. The expected
        outcome is unchanged: two generations and a transition to the end
        node.

        The test uses multi-step mock responses:
        - Step 1: Parallel tool calls (end_call + safe_calculator)
        - Step 2+: Text responses for subsequent node prompts
        """
        functions = [
            {
                "name": "end_call",
                "arguments": {},
                "tool_call_id": "call_transition",
            },
            {
                "name": "safe_calculator",
                "arguments": {"expression": "25 * 4"},
                "tool_call_id": "call_calc",
            },
        ]

        llm, context = await run_pipeline_with_tool_calls(
            workflow=simple_workflow,
            functions=functions,
            text="Hello There!",
            num_text_steps=2,
        )

        self._assert_transitioned_to_end_node(llm, context)

    @pytest.mark.asyncio
    async def test_single_transition_call_through_engine(
        self, simple_workflow: WorkflowGraph
    ):
        """Test a single transition function call (end_call) through PipecatEngine.

        This test verifies that when the LLM generates only a transition tool
        call, the engine executes it and transitions to the end node.

        The LLM is called once for the StartNode; end_call then transitions to
        the end node, which triggers a second generation, so two generations
        are expected in total.
        """
        functions = [
            {
                "name": "end_call",
                "arguments": {},
                "tool_call_id": "call_transition",
            },
        ]

        llm, context = await run_pipeline_with_tool_calls(
            workflow=simple_workflow,
            functions=functions,
            num_text_steps=1,
        )

        self._assert_transitioned_to_end_node(llm, context)

View file

@ -27,6 +27,8 @@ import {
type KeyValueItem,
ParameterEditor,
type ToolParameter,
UrlInput,
validateUrl,
} from "@/components/http";
import { Button } from "@/components/ui/button";
import {
@ -151,8 +153,9 @@ export default function ToolDetailPage() {
const handleSave = async () => {
// Validate URL
if (!url.trim()) {
setError("URL is required");
const urlValidation = validateUrl(url);
if (!urlValidation.valid) {
setError(urlValidation.error || "Invalid URL");
return;
}
@ -431,10 +434,11 @@ const data = await response.json();`;
<div className="grid gap-2">
<Label>Endpoint URL</Label>
<Input
<UrlInput
value={url}
onChange={(e) => setUrl(e.target.value)}
onChange={setUrl}
placeholder="https://api.example.com/appointments"
showValidation
/>
</div>
</TabsContent>

View file

@ -10,6 +10,8 @@ import {
HttpMethodSelector,
KeyValueEditor,
type KeyValueItem,
UrlInput,
validateUrl,
} from "@/components/http";
import { Button } from "@/components/ui/button";
import { Input } from "@/components/ui/input";
@ -57,8 +59,9 @@ export const WebhookNode = memo(({ data, selected, id }: WebhookNodeProps) => {
const handleSave = async () => {
// Validate endpoint URL
if (!endpointUrl.trim()) {
setEndpointError('Endpoint URL is required');
const urlValidation = validateUrl(endpointUrl);
if (!urlValidation.valid) {
setEndpointError(urlValidation.error || 'Invalid URL');
return;
}
setEndpointError(null);
@ -284,10 +287,11 @@ const WebhookNodeEditForm = ({
<Label className="text-xs text-muted-foreground">
The URL to send the webhook request to.
</Label>
<Input
<UrlInput
value={endpointUrl}
onChange={(e) => setEndpointUrl(e.target.value)}
onChange={setEndpointUrl}
placeholder="https://api.example.com/webhook"
showValidation
/>
</div>
</TabsContent>

View file

@ -3,3 +3,4 @@ export { CredentialSelector } from "./credential-selector";
export { type HttpMethod, HttpMethodSelector } from "./http-method-selector";
export { KeyValueEditor, type KeyValueItem } from "./key-value-editor";
export { ParameterEditor, type ParameterType,type ToolParameter } from "./parameter-editor";
export { UrlInput, type UrlValidationResult,validateUrl } from "./url-input";

View file

@ -0,0 +1,106 @@
"use client";
import { useCallback, useState } from "react";
import { Input } from "@/components/ui/input";
import { cn } from "@/lib/utils";
// URL regex pattern that validates:
// - http:// or https:// protocol (required)
// - Optional username:password@
// - Domain name or IP address
// - Optional port number
// - Optional path, query string, and fragment
//
// The trailing group may start with "/", "?" or "#" so that a query string or
// fragment is accepted directly after the host (e.g. https://example.com?x=1),
// not only after an explicit path. Whitespace is never allowed in that tail.
const URL_REGEX =
  /^https?:\/\/(?:[\w-]+(?::[\w-]+)?@)?(?:[\w-]+\.)*[\w-]+(?::\d{1,5})?(?:[\/?#][^\s]*)?$/i;

/** Outcome of validating a URL string. */
export interface UrlValidationResult {
  /** True when the URL is non-empty and matches the expected format. */
  valid: boolean;
  /** Human-readable reason; only set when `valid` is false. */
  error?: string;
}

/**
 * Validate that `url` is a non-empty http(s) URL.
 *
 * Leading and trailing whitespace is ignored before checking.
 *
 * @param url - Raw text to validate.
 * @returns `{ valid: true }` on success, otherwise `{ valid: false, error }`
 *   where `error` distinguishes an empty value from a malformed one.
 */
export function validateUrl(url: string): UrlValidationResult {
  const trimmedUrl = url.trim();

  if (!trimmedUrl) {
    return { valid: false, error: "URL is required" };
  }

  if (!URL_REGEX.test(trimmedUrl)) {
    return {
      valid: false,
      error: "Invalid URL format. Must start with http:// or https://",
    };
  }

  return { valid: true };
}
/** Props accepted by {@link UrlInput}. */
interface UrlInputProps {
  /** Current text of the field (controlled by the parent). */
  value: string;
  /**
   * Receives the new text on every edit, and the trimmed text on blur when
   * trimming actually changed the value.
   */
  onChange: (value: string) => void;
  placeholder?: string;
  disabled?: boolean;
  /** Extra classes merged after the error styling. */
  className?: string;
  /** Show validation error styling and message inline */
  showValidation?: boolean;
  /** Called when validation state changes */
  onValidationChange?: (result: UrlValidationResult) => void;
}

/**
 * Controlled text input for http(s) URLs with optional inline validation.
 *
 * The inline error is only shown when `showValidation` is set, the field has
 * been blurred at least once, and it holds a non-empty invalid value, so the
 * user is not shown an error while typing the first characters or for an
 * empty field.
 */
export function UrlInput({
  value,
  onChange,
  placeholder = "https://api.example.com/endpoint",
  disabled = false,
  className,
  showValidation = false,
  onValidationChange,
}: UrlInputProps) {
  // Flips to true on the first blur and gates the inline error display.
  const [touched, setTouched] = useState(false);

  const handleChange = useCallback(
    (e: React.ChangeEvent<HTMLInputElement>) => {
      const newValue = e.target.value;
      onChange(newValue);
      // Before the first blur, only report validation once there is text;
      // after it, also report when the field has been cleared.
      if (onValidationChange && (touched || newValue)) {
        onValidationChange(validateUrl(newValue));
      }
    },
    [onChange, onValidationChange, touched]
  );

  const handleBlur = useCallback(() => {
    setTouched(true);
    // Normalise surrounding whitespace, notifying the parent only on change.
    const trimmedValue = value.trim();
    if (trimmedValue !== value) {
      onChange(trimmedValue);
    }
    // An empty field is not reported as invalid on blur.
    if (onValidationChange && trimmedValue) {
      onValidationChange(validateUrl(trimmedValue));
    }
  }, [onChange, onValidationChange, value]);

  const validation = validateUrl(value);
  // Coerce to a real boolean so the flag never evaluates to "" and can be
  // used safely both in JSX and as an ARIA attribute value.
  const showError =
    showValidation && touched && !validation.valid && Boolean(value);

  return (
    <div className="space-y-1">
      <Input
        value={value}
        onChange={handleChange}
        onBlur={handleBlur}
        placeholder={placeholder}
        disabled={disabled}
        // Expose the error state to assistive technology, not just via
        // colour; the attribute is omitted entirely when there is no error.
        aria-invalid={showError || undefined}
        className={cn(
          showError && "border-destructive focus-visible:ring-destructive",
          className
        )}
      />
      {showError && (
        <p className="text-xs text-destructive">{validation.error}</p>
      )}
    </div>
  );
}