Fix/multiple generation (#104)

* fixes #100

* Fix test

* fix: fix bad configuration issue
This commit is contained in:
Abhishek 2026-01-03 12:59:18 +05:30 committed by GitHub
parent 90b690efff
commit 56953bbd09
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
18 changed files with 758 additions and 460 deletions

256
api/tests/conftest.py Normal file
View file

@ -0,0 +1,256 @@
from dataclasses import dataclass
from typing import Any, Dict
from unittest.mock import Mock
import pytest
from api.services.workflow.dto import (
EdgeDataDTO,
NodeDataDTO,
NodeType,
Position,
ReactFlowDTO,
RFEdgeDTO,
RFNodeDTO,
)
from api.services.workflow.workflow import WorkflowGraph
START_CALL_SYSTEM_PROMPT = "start_call_system_prompt"
END_CALL_SYSTEM_PROMPT = "end_call_system_prompt"
@dataclass
class MockToolModel:
"""Mock tool model for testing."""
tool_uuid: str
name: str
description: str
definition: Dict[str, Any]
@pytest.fixture
def mock_engine():
"""Create a mock PipecatEngine."""
engine = Mock()
engine._workflow_run_id = 1
engine._call_context_vars = {"customer_name": "John Doe"}
engine.llm = Mock()
engine.llm.register_function = Mock()
return engine
@pytest.fixture
def sample_tools():
"""Create sample mock tools for testing."""
return [
MockToolModel(
tool_uuid="weather-uuid-123",
name="Get Weather",
description="Get current weather for a location",
definition={
"schema_version": 1,
"type": "http_api",
"config": {
"method": "GET",
"url": "https://api.weather.com/current",
"parameters": [
{
"name": "location",
"type": "string",
"description": "City name (e.g., San Francisco, CA)",
"required": True,
},
{
"name": "units",
"type": "string",
"description": "Temperature units: celsius or fahrenheit",
"required": False,
},
],
},
},
),
MockToolModel(
tool_uuid="booking-uuid-456",
name="Book Appointment",
description="Book an appointment for the customer",
definition={
"schema_version": 1,
"type": "http_api",
"config": {
"method": "POST",
"url": "https://api.example.com/appointments",
"parameters": [
{
"name": "customer_name",
"type": "string",
"description": "Customer's full name",
"required": True,
},
{
"name": "date",
"type": "string",
"description": "Appointment date (YYYY-MM-DD)",
"required": True,
},
{
"name": "time",
"type": "string",
"description": "Appointment time (HH:MM)",
"required": True,
},
{
"name": "notes",
"type": "string",
"description": "Additional notes",
"required": False,
},
],
},
},
),
MockToolModel(
tool_uuid="lookup-uuid-789",
name="Customer Lookup",
description="Look up customer information by phone number",
definition={
"schema_version": 1,
"type": "http_api",
"config": {
"method": "GET",
"url": "https://api.example.com/customers/lookup",
"parameters": [
{
"name": "phone",
"type": "string",
"description": "Customer phone number",
"required": True,
},
],
},
},
),
]
@pytest.fixture
def simple_workflow() -> WorkflowGraph:
"""Create a simple two-node workflow for testing.
The workflow has:
- Start node with a prompt
- End node with a prompt
- One edge connecting them with label "End Call"
"""
dto = ReactFlowDTO(
nodes=[
RFNodeDTO(
id="1",
type=NodeType.startNode,
position=Position(x=0, y=0),
data=NodeDataDTO(
name="Start Call",
prompt=START_CALL_SYSTEM_PROMPT,
is_start=True,
allow_interrupt=False,
add_global_prompt=False,
),
),
RFNodeDTO(
id="2",
type=NodeType.endNode,
position=Position(x=0, y=200),
data=NodeDataDTO(
name="End Call",
prompt=END_CALL_SYSTEM_PROMPT,
is_end=True,
allow_interrupt=False,
add_global_prompt=False,
),
),
],
edges=[
RFEdgeDTO(
id="1-2",
source="1",
target="2",
data=EdgeDataDTO(
label="End Call",
condition="When the user says to end the call, end the call",
),
),
],
)
return WorkflowGraph(dto)
@pytest.fixture
def three_node_workflow() -> WorkflowGraph:
"""Create a three-node workflow for testing with an intermediate agent node.
The workflow has:
- Start node
- Agent node (for collecting information)
- End node
"""
dto = ReactFlowDTO(
nodes=[
RFNodeDTO(
id="1",
type=NodeType.startNode,
position=Position(x=0, y=0),
data=NodeDataDTO(
name="Start Call",
prompt=START_CALL_SYSTEM_PROMPT,
is_start=True,
allow_interrupt=True,
add_global_prompt=False,
),
),
RFNodeDTO(
id="2",
type=NodeType.agentNode,
position=Position(x=0, y=200),
data=NodeDataDTO(
name="Collect Info",
prompt="Help the user with their request. Ask clarifying questions if needed.",
allow_interrupt=True,
add_global_prompt=False,
),
),
RFNodeDTO(
id="3",
type=NodeType.endNode,
position=Position(x=0, y=400),
data=NodeDataDTO(
name="End Call",
prompt=END_CALL_SYSTEM_PROMPT,
is_end=True,
allow_interrupt=False,
add_global_prompt=False,
),
),
],
edges=[
RFEdgeDTO(
id="1-2",
source="1",
target="2",
data=EdgeDataDTO(
label="Collect Info",
condition="When the user wants help, collect their information",
),
),
RFEdgeDTO(
id="2-3",
source="2",
target="3",
data=EdgeDataDTO(
label="End Call",
condition="When the user is done or wants to end the call",
),
),
],
)
return WorkflowGraph(dto)

View file

@ -0,0 +1,164 @@
{
"nodes": [
{
"id": "915",
"type": "agentNode",
"position": {
"x": 633,
"y": 324
},
"data": {
"prompt": "You are a voice agent whose mode of speaking is voice. Ask the user whether they want to talk to a sales guy or a customer service agent",
"name": "Agent"
},
"measured": {
"width": 300,
"height": 100
},
"selected": false,
"dragging": false
},
{
"id": "7598",
"type": "agentNode",
"position": {
"x": 460.1247806640531,
"y": 610.3714977079578
},
"data": {
"prompt": "You are a customer service agent whose mode of communication with the user is voice. Tell them that someone from our team will reach out to them soon",
"name": "Agent"
},
"measured": {
"width": 300,
"height": 100
},
"selected": false,
"dragging": false
},
{
"id": "6919",
"type": "agentNode",
"position": {
"x": 914.666735413607,
"y": 642.9800281289787
},
"data": {
"prompt": "You are a sales representative whose mode of communication with the user is voice. Tell the user that someone from our team will reach out to you soon",
"name": "Agent"
},
"measured": {
"width": 300,
"height": 100
},
"selected": false,
"dragging": false
},
{
"id": "6581",
"type": "startCall",
"position": {
"x": 648,
"y": 35
},
"data": {
"prompt": "Hello, I am Abhishek from Dograh. ",
"is_static": true,
"name": "Start Call",
"is_start": true
},
"measured": {
"width": 300,
"height": 100
},
"selected": false,
"dragging": false
},
{
"id": "1802",
"type": "endCall",
"position": {
"x": 666.7733431033548,
"y": 987.4345801025363
},
"data": {
"prompt": "Thank you for calling Dograh. Have a great day!",
"is_static": true,
"name": "End Call"
},
"measured": {
"width": 300,
"height": 100
},
"selected": false,
"dragging": false
}
],
"edges": [
{
"animated": true,
"type": "custom",
"source": "915",
"target": "7598",
"id": "xy-edge__915-7598",
"selected": false,
"data": {
"condition": "The customer wants to talk to a customer service agent",
"label": "customer service agent"
}
},
{
"animated": true,
"type": "custom",
"source": "915",
"target": "6919",
"id": "xy-edge__915-6919",
"selected": false,
"data": {
"condition": "customer wants to talk to a sales representative",
"label": "sales representative"
}
},
{
"animated": true,
"type": "custom",
"source": "6581",
"target": "915",
"id": "xy-edge__6581-915",
"selected": false,
"data": {
"condition": "Always take this route",
"label": "Always take this route"
}
},
{
"animated": true,
"type": "custom",
"source": "7598",
"target": "1802",
"id": "xy-edge__7598-1802",
"selected": false,
"data": {
"condition": "end call",
"label": "end call"
}
},
{
"animated": true,
"type": "custom",
"source": "6919",
"target": "1802",
"id": "xy-edge__6919-1802",
"selected": false,
"data": {
"condition": "end call",
"label": "end call"
}
}
],
"viewport": {
"x": 0,
"y": 0,
"zoom": 1
}
}

View file

@ -0,0 +1,192 @@
from unittest.mock import Mock
from api.services.workflow.pipecat_engine_callbacks import (
create_aggregation_correction_callback,
)
def test_aggregation_fixer():
"""Validate the aggregation correction algorithm using a helper that
creates a fresh callback for every (reference, corrupted) pair.
The production callback now needs a PipecatEngine instance with the
`_current_llm_generation_reference_text` set. For test-friendliness we mock a bare
object providing just that attribute for each assertion so the original
two-argument test cases remain unchanged.
"""
def fixer(reference: str, corrupted: str) -> str: # noqa: D401
mock_engine = Mock()
mock_engine._current_llm_generation_reference_text = reference
return create_aggregation_correction_callback(mock_engine)(corrupted)
##### Trailing extra Chars #####
assert (
fixer(
"Good Morning Mr NARGES , My name is Alex and I am calling you from Consumer Services. The reason of my call today, as I can see in our records that you are making your monthly credit card payments on time, but you STILL carry a balance of over 7 thousand dollars, right?",
"My name is Alex and I am calling you from Cons umer Services. The reason of my call today, as I can see in our records that you are making your monthly credit card payments on time, but you STILL carry a balance of over 7 thousand dollars, right?",
)
== "My name is Alex and I am calling you from Consumer Services. The reason of my call today, as I can see in our records that you are making your monthly credit card payments on time, but you STILL carry a balance of over 7 thousand dollars, right?"
), "leading_whole_sentence"
# Whole sentences
assert (
fixer(
"Good Morning Mr NARGES , My name is Alex and I am calling you from Consumer Services. The reason of my call today, as I can see in our records that you are making your monthly credit card payments on time, but you STILL carry a balance of over 7 thousand dollars, right?",
"Good Morning Mr NAR GES , My name is Alex and I am calling you from Cons umer Services. The reason of my call today, as I can see in our records that you are making your monthly credit card payments on time, but you STILL carry a balance of over 7 thousand dollars, right?",
)
== "Good Morning Mr NARGES , My name is Alex and I am calling you from Consumer Services. The reason of my call today, as I can see in our records that you are making your monthly credit card payments on time, but you STILL carry a balance of over 7 thousand dollars, right?"
), "whole_sentences"
# With a period in the end
assert (
fixer(
"Good Morning Mr NARGES , My name is Alex and I am calling you from Consumer Services. The reason of my call today, as I can see in our records that you are making your monthly credit card payments on time, but you STILL carry a balance of over 7 thousand dollars, right?",
"Good Morning Mr NAR GES , My name is Alex and I am calling you from Cons umer Services.",
)
== "Good Morning Mr NARGES , My name is Alex and I am calling you from Consumer Services."
), "period_end"
# without a period in the end
assert (
fixer(
"Good Morning Mr NARGES , My name is Alex and I am calling you from Consumer Services. The reason of my call today, as I can see in our records that you are making your monthly credit card payments on time, but you STILL carry a balance of over 7 thousand dollars, right?",
"Good Morning Mr NAR GES , My name is Alex and I am calling you from Cons umer Services",
)
== "Good Morning Mr NARGES , My name is Alex and I am calling you from Consumer Services"
), "without_period_end"
# Extra space in the end
assert (
fixer(
"Good Morning Mr NARGES , My name is Alex and I am calling you from Consumer Services. The reason of my call today, as I can see in our records that you are making your monthly credit card payments on time, but you STILL carry a balance of over 7 thousand dollars, right?",
"Good Morning Mr NAR GES , My name is Alex and I am calling you from Cons umer Services ",
)
== "Good Morning Mr NARGES , My name is Alex and I am calling you from Consumer Services"
), "extra_space"
# Multiple spaces in corruption
assert (
fixer(
"Good Morning Mr NARGES , My name is Alex and I am calling you from Consumer Services. The reason of my call today, as I can see in our records that you are making your monthly credit card payments on time, but you STILL carry a balance of over 7 thousand dollars, right?",
"Good Morning Mr NAR GES , My name is Alex and I am calling you from Cons umer Servi ces ",
)
== "Good Morning Mr NARGES , My name is Alex and I am calling you from Consumer Services"
), "multiple_space"
# Multiple spaces in corruption ending in a whitespace
assert (
fixer(
"Good Morning Mr NARGES , My name is Alex and I am calling you from Consumer Services. The reason of my call today, as I can see in our records that you are making your monthly credit card payments on time, but you STILL carry a balance of over 7 thousand dollars, right?",
"Good Morning Mr NAR GES , My name is Alex and I am calling you from Cons umer Servi ces. ",
)
== "Good Morning Mr NARGES , My name is Alex and I am calling you from Consumer Services. "
), "multiple_space_end_ws"
##### Leading extra Chars #####
# Whole sentences
assert (
fixer(
"Good Morning Mr NARGES , My name is Alex and I am calling you from Consumer Services. The reason of my call today, as I can see in our records that you are making your monthly credit card payments on time, but you STILL carry a balance of over 7 thousand dollars, right?",
"My name is Alex and I am calling you from Cons umer Services. The reason of my call today, as I can see in our records that you are making your monthly credit card payments on time, but you STILL carry a balance of over 7 thousand dollars, right?",
)
== "My name is Alex and I am calling you from Consumer Services. The reason of my call today, as I can see in our records that you are making your monthly credit card payments on time, but you STILL carry a balance of over 7 thousand dollars, right?"
), "leading_whole_sentence"
# With a period in the end
assert (
fixer(
"Good Morning Mr NARGES , My name is Alex and I am calling you from Consumer Services. The reason of my call today, as I can see in our records that you are making your monthly credit card payments on time, but you STILL carry a balance of over 7 thousand dollars, right?",
"My name is Alex and I am calling you from Cons umer Services.",
)
== "My name is Alex and I am calling you from Consumer Services."
), "leading_period_end"
# without a period in the end
assert (
fixer(
"Good Morning Mr NARGES , My name is Alex and I am calling you from Consumer Services. The reason of my call today, as I can see in our records that you are making your monthly credit card payments on time, but you STILL carry a balance of over 7 thousand dollars, right?",
"My name is Alex and I am calling you from Cons umer Services",
)
== "My name is Alex and I am calling you from Consumer Services"
), "leading_without_period_end"
# Extra space in the end
assert (
fixer(
"Good Morning Mr NARGES , My name is Alex and I am calling you from Consumer Services. The reason of my call today, as I can see in our records that you are making your monthly credit card payments on time, but you STILL carry a balance of over 7 thousand dollars, right?",
"My name is Alex and I am calling you from Cons umer Services ",
)
== "My name is Alex and I am calling you from Consumer Services"
), "leading_extra_space"
# Multiple spaces in corruption
assert (
fixer(
"Good Morning Mr NARGES , My name is Alex and I am calling you from Consumer Services. The reason of my call today, as I can see in our records that you are making your monthly credit card payments on time, but you STILL carry a balance of over 7 thousand dollars, right?",
"My name is Alex and I am calling you from Cons umer Servi ces ",
)
== "My name is Alex and I am calling you from Consumer Services"
), "leading_multiple_space"
# Multiple spaces in corruption ending in a whitespace
assert (
fixer(
"Good Morning Mr NARGES , My name is Alex and I am calling you from Consumer Services. The reason of my call today, as I can see in our records that you are making your monthly credit card payments on time, but you STILL carry a balance of over 7 thousand dollars, right?",
"My name is Alex and I am calling you from Cons umer Servi ces. ",
)
== "My name is Alex and I am calling you from Consumer Services. "
), "leading_multiple_space_end_ws"
# Whitespace
assert fixer("", "") == ""
# Missing reference
assert (
fixer("", "My name is Alex and I am calling you from Cons umer Servi ces.")
== "My name is Alex and I am calling you from Cons umer Servi ces."
), "missing_reference"
# Smaller reference
assert (
fixer(
"My name is Alex",
"My name is Alex and I am calling you from Cons umer Servi ces.",
)
== "My name is Alex and I am calling you from Cons umer Servi ces."
), "smaller_reference"
# Unrelated reference
assert (
fixer(
"Hello Hello",
"My name is Alex and I am calling you from Cons umer Servi ces.",
)
== "My name is Alex and I am calling you from Cons umer Servi ces."
), "unrelated_reference"
def test_create_aggregation_correction_callback():
"""Test the new aggregation correction callback creator."""
# Mock engine with reference text
mock_engine = Mock()
mock_engine._current_llm_generation_reference_text = "Good Morning Mr NARGES, My name is Alex and I am calling you from Consumer Services."
# Create callback
callback = create_aggregation_correction_callback(mock_engine)
# Test correction
corrected = callback(
"Good Morning Mr NAR GES, My name is Alex and I am calling you from Cons umer Services."
)
assert (
corrected
== "Good Morning Mr NARGES, My name is Alex and I am calling you from Consumer Services."
)
# Test with no reference text
mock_engine._current_llm_generation_reference_text = ""
corrected = callback("Some corrupted text")
assert corrected == "Some corrupted text" # Should return as-is when no reference

View file

@ -0,0 +1,31 @@
from api.services.pricing.cost_calculator import cost_calculator
def test_cost_calculator():
"""Test function to verify cost calculation works"""
sample_usage = {
"llm": {
"OpenAILLMService#0|||gpt-4.1-mini": {
"prompt_tokens": 45380,
"completion_tokens": 496,
"total_tokens": 45876,
"cache_read_input_tokens": 0,
"cache_creation_input_tokens": 0,
}
},
"tts": {"ElevenLabsTTSService#0|||eleven_flash_v2_5": 2399},
"stt": {"DeepgramSTTService#0|||nova-3-general": 177.21536946296692},
"call_duration_seconds": 179,
}
result = cost_calculator.calculate_total_cost(sample_usage)
assert result["llm_cost"] == 45380 * 0.40 / 1_000_000 + 496 * 1.60 / 1_000_000
assert result["tts_cost"] == 2399 * 0.0256 / 1_000
assert result["stt_cost"] == 177.21536946296692 / 60 * 0.0077
assert (
abs(
result["total"]
- (result["llm_cost"] + result["tts_cost"] + result["stt_cost"])
)
< 1e-10
)

View file

@ -6,9 +6,7 @@ This module tests the full flow of:
3. Verifying the context is properly configured for LLM generation
"""
from dataclasses import dataclass
from typing import Any, Dict
from unittest.mock import AsyncMock, Mock, patch
from unittest.mock import AsyncMock, patch
import pytest
@ -17,126 +15,14 @@ from api.services.workflow.pipecat_engine_utils import (
get_function_schema,
update_llm_context,
)
from api.tests.conftest import MockToolModel
from pipecat.adapters.schemas.function_schema import FunctionSchema
from pipecat.processors.aggregators.llm_context import LLMContext
@dataclass
class MockToolModel:
"""Mock tool model for testing."""
tool_uuid: str
name: str
description: str
definition: Dict[str, Any]
class TestCustomToolManagerContextIntegration:
"""Integration tests for CustomToolManager with LLMContext."""
@pytest.fixture
def mock_engine(self):
"""Create a mock PipecatEngine."""
engine = Mock()
engine._workflow_run_id = 1
engine._call_context_vars = {"customer_name": "John Doe"}
engine.llm = Mock()
engine.llm.register_function = Mock()
return engine
@pytest.fixture
def sample_tools(self):
"""Create sample mock tools for testing."""
return [
MockToolModel(
tool_uuid="weather-uuid-123",
name="Get Weather",
description="Get current weather for a location",
definition={
"schema_version": 1,
"type": "http_api",
"config": {
"method": "GET",
"url": "https://api.weather.com/current",
"parameters": [
{
"name": "location",
"type": "string",
"description": "City name (e.g., San Francisco, CA)",
"required": True,
},
{
"name": "units",
"type": "string",
"description": "Temperature units: celsius or fahrenheit",
"required": False,
},
],
},
},
),
MockToolModel(
tool_uuid="booking-uuid-456",
name="Book Appointment",
description="Book an appointment for the customer",
definition={
"schema_version": 1,
"type": "http_api",
"config": {
"method": "POST",
"url": "https://api.example.com/appointments",
"parameters": [
{
"name": "customer_name",
"type": "string",
"description": "Customer's full name",
"required": True,
},
{
"name": "date",
"type": "string",
"description": "Appointment date (YYYY-MM-DD)",
"required": True,
},
{
"name": "time",
"type": "string",
"description": "Appointment time (HH:MM)",
"required": True,
},
{
"name": "notes",
"type": "string",
"description": "Additional notes",
"required": False,
},
],
},
},
),
MockToolModel(
tool_uuid="lookup-uuid-789",
name="Customer Lookup",
description="Look up customer information by phone number",
definition={
"schema_version": 1,
"type": "http_api",
"config": {
"method": "GET",
"url": "https://api.example.com/customers/lookup",
"parameters": [
{
"name": "phone",
"type": "string",
"description": "Customer phone number",
"required": True,
},
],
},
},
),
]
@pytest.mark.asyncio
async def test_get_tool_schemas_and_update_context(self, mock_engine, sample_tools):
"""Test fetching tool schemas via CustomToolManager and updating LLM context."""

11
api/tests/test_dto.py Normal file
View file

@ -0,0 +1,11 @@
import pytest
from api.services.workflow.dto import ReactFlowDTO
@pytest.mark.asyncio
async def test_dto():
# assert no exceptions are raised
with open("tests/definitions/rf-1.json", "r") as f:
dto = ReactFlowDTO.model_validate_json(f.read())
assert dto is not None

View file

@ -0,0 +1,340 @@
"""Tests for tool calls with PipecatEngine and MockLLM.
This module tests the behavior when the LLM generates tool calls (single or parallel),
using PipecatEngine's actual function registration and execution logic.
"""
import asyncio
from typing import Any, Dict, List
from unittest.mock import AsyncMock, patch
import pytest
from api.services.pipecat.pipeline_engine_callbacks_processor import (
PipelineEngineCallbacksProcessor,
)
from api.services.workflow.pipecat_engine import PipecatEngine
from api.services.workflow.workflow import WorkflowGraph
from api.tests.conftest import END_CALL_SYSTEM_PROMPT
from pipecat.frames.frames import (
BotStartedSpeakingFrame,
BotStoppedSpeakingFrame,
Frame,
TextFrame,
)
from pipecat.pipeline.pipeline import Pipeline
from pipecat.pipeline.runner import PipelineRunner
from pipecat.pipeline.task import PipelineParams, PipelineTask
from pipecat.processors.aggregators.llm_context import LLMContext
from pipecat.processors.aggregators.llm_response import LLMAssistantAggregatorParams
from pipecat.processors.aggregators.llm_response_universal import (
LLMContextAggregatorPair,
)
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
from pipecat.tests import MockLLMService
class MockBotStoppedSpeakingOnLLMTextFrameProcessor(FrameProcessor):
"""
Mocking the transport, where transport sends BotStartedSpeakingFrame
and BotStoppedSpeakingFrame when it encounters a LLMTextFrame.
"""
async def process_frame(self, frame: Frame, direction: FrameDirection):
await super().process_frame(frame, direction)
if isinstance(frame, TextFrame):
await self.push_frame(BotStartedSpeakingFrame())
await self.push_frame(
BotStartedSpeakingFrame(), direction=FrameDirection.UPSTREAM
)
await asyncio.sleep(0.1)
await self.push_frame(BotStoppedSpeakingFrame())
await self.push_frame(
BotStoppedSpeakingFrame(), direction=FrameDirection.UPSTREAM
)
await self.push_frame(frame, direction)
async def run_pipeline_with_tool_calls(
workflow: WorkflowGraph,
functions: List[Dict[str, Any]],
text: str | None = None,
num_text_steps: int = 1,
) -> tuple[MockLLMService, LLMContext]:
"""Run a pipeline with mock tool calls and return the LLM for assertions.
Args:
workflow: The workflow graph to use.
functions: List of function call definitions with name, arguments, and tool_call_id.
text: Text to add to the first step (streamed before the tool calls).
num_text_steps: Number of text response steps after the tool calls.
Returns:
The MockLLMService instance for making assertions.
"""
# Create first step chunks
if text:
# Create text chunks (without final chunk) followed by function call chunks
text_chunks = MockLLMService.create_text_chunks(text)
func_chunks = MockLLMService.create_multiple_function_call_chunks(functions)
# Exclude the final chunk from text_chunks (which has finish_reason="stop")
first_step_chunks = text_chunks[:-1] + func_chunks
else:
first_step_chunks = MockLLMService.create_multiple_function_call_chunks(
functions
)
# Create multi-step responses
mock_steps = MockLLMService.create_multi_step_responses(
first_step_chunks, num_text_steps=num_text_steps, step_prefix="Response"
)
# Create MockLLMService with multi-step support
llm = MockLLMService(mock_steps=mock_steps, chunk_delay=0.001)
mock_transport_emulator = MockBotStoppedSpeakingOnLLMTextFrameProcessor()
# Create LLM context
context = LLMContext()
# Add assistant context aggregator
assistant_params = LLMAssistantAggregatorParams(expect_stripped_words=True)
context_aggregator = LLMContextAggregatorPair(
context, assistant_params=assistant_params
)
assistant_context_aggregator = context_aggregator.assistant()
# Create PipecatEngine with the workflow
engine = PipecatEngine(
llm=llm,
context=context,
workflow=workflow,
call_context_vars={"customer_name": "Test User"},
workflow_run_id=1,
)
# Create the pipeline with the mock LLM
pipeline = Pipeline(
[
llm,
mock_transport_emulator,
assistant_context_aggregator,
]
)
# Create a real pipeline task
task = PipelineTask(
pipeline,
params=PipelineParams(allow_interruptions=False),
)
engine.set_task(task)
# Patch DB calls to avoid actual database access
with patch(
"api.services.workflow.pipecat_engine.get_organization_id_from_workflow_run",
new_callable=AsyncMock,
return_value=1,
):
with patch(
"api.services.workflow.pipecat_engine.apply_disposition_mapping",
new_callable=AsyncMock,
return_value="completed",
):
runner = PipelineRunner()
async def run_pipeline():
await runner.run(task)
async def initialize_engine():
# Small delay to let runner start
await asyncio.sleep(0.01)
await engine.initialize()
# Run both concurrently
await asyncio.gather(run_pipeline(), initialize_engine())
return llm, context
class TestPipecatEngineToolCalls:
"""Test tool calls through PipecatEngine."""
@pytest.mark.asyncio
async def test_parallel_builtin_and_transition_calls_through_engine(
self, simple_workflow: WorkflowGraph
):
"""Test parallel function calls using PipecatEngine's actual handlers.
This test verifies that when the LLM generates parallel tool calls for:
1. A built-in function (safe_calculator) - registered by _register_builtin_functions
2. A transition function (end_call) - registered by _register_transition_function_with_llm
Both functions are properly executed through the engine's handlers and
the transition correctly moves to the end node.
The test uses multi-step mock responses:
- Step 1: Parallel tool calls (safe_calculator + end_call)
- Step 2+: Text responses for subsequent node prompts
"""
functions = [
{
"name": "end_call",
"arguments": {},
"tool_call_id": "call_transition",
},
{
"name": "safe_calculator",
"arguments": {"expression": "25 * 4"},
"tool_call_id": "call_calc",
},
]
llm, context = await run_pipeline_with_tool_calls(
workflow=simple_workflow,
functions=functions,
num_text_steps=2,
)
# Assert that the LLM generation was called a total of 2 times,
# 1st time when StartNode was executed, and second time
# when EndCall generation happened
assert llm.get_current_step() == 2, (
"LLM generation should have happened 2 times"
)
# Assert that the context was updated with END_CALL_SYSTEM_PROMPT
assert context.messages[0]["content"] == END_CALL_SYSTEM_PROMPT
@pytest.mark.asyncio
async def test_parallel_builtin_and_transition_calls_through_engine_1(
self, simple_workflow: WorkflowGraph
):
"""Test parallel function calls using PipecatEngine's actual handlers.
This test verifies that when the LLM generates parallel tool calls for:
1. A built-in function (safe_calculator) - registered by _register_builtin_functions
2. A transition function (end_call) - registered by _register_transition_function_with_llm
Both functions are properly executed through the engine's handlers and
the transition correctly moves to the end node.
The test uses multi-step mock responses:
- Step 1: Parallel tool calls (safe_calculator + end_call)
- Step 2+: Text responses for subsequent node prompts
"""
functions = [
{
"name": "safe_calculator",
"arguments": {"expression": "25 * 4"},
"tool_call_id": "call_calc",
},
{
"name": "end_call",
"arguments": {},
"tool_call_id": "call_transition",
},
]
llm, context = await run_pipeline_with_tool_calls(
workflow=simple_workflow,
functions=functions,
num_text_steps=2,
)
# Assert that the LLM generation was called a total of 2 times,
# 1st time when StartNode was executed, and second time
# when EndCall generation happened. The tool should not invoke
# an LLM generation
assert llm.get_current_step() == 2, (
"LLM generation should have happened 2 times"
)
# Assert that the context was updated with END_CALL_SYSTEM_PROMPT
assert context.messages[0]["content"] == END_CALL_SYSTEM_PROMPT
@pytest.mark.asyncio
async def test_parallel_builtin_and_transition_calls_through_engine_with_text(
self, simple_workflow: WorkflowGraph
):
"""Test parallel function calls using PipecatEngine's actual handlers.
This test verifies that when the LLM generates parallel tool calls for:
1. A built-in function (safe_calculator) - registered by _register_builtin_functions
2. A transition function (end_call) - registered by _register_transition_function_with_llm
Both functions are properly executed through the engine's handlers and
the transition correctly moves to the end node.
The test uses multi-step mock responses:
- Step 1: Parallel tool calls (safe_calculator + end_call)
- Step 2+: Text responses for subsequent node prompts
"""
functions = [
{
"name": "end_call",
"arguments": {},
"tool_call_id": "call_transition",
},
{
"name": "safe_calculator",
"arguments": {"expression": "25 * 4"},
"tool_call_id": "call_calc",
},
]
llm, context = await run_pipeline_with_tool_calls(
workflow=simple_workflow,
functions=functions,
text="Hello There!",
num_text_steps=2,
)
# Assert that the LLM generation was called a total of 2 times,
# 1st time when StartNode was executed, and second time
# when EndCall generation happened. The tool should not invoke
# an LLM generation
assert llm.get_current_step() == 2, (
"LLM generation should have happened 2 times"
)
# Assert that the context was updated with END_CALL_SYSTEM_PROMPT
assert context.messages[0]["content"] == END_CALL_SYSTEM_PROMPT
@pytest.mark.asyncio
async def test_single_transition_call_through_engine(
self, simple_workflow: WorkflowGraph
):
"""Test a single transition function call (end_call) through PipecatEngine.
This test verifies that when the LLM generates only a transition tool call,
the engine properly executes it and transitions to the end node.
Since end_call transitions to the end node which triggers another LLM
generation, the LLM is called exactly once for the initial StartNode.
"""
functions = [
{
"name": "end_call",
"arguments": {},
"tool_call_id": "call_transition",
},
]
llm, context = await run_pipeline_with_tool_calls(
workflow=simple_workflow,
functions=functions,
num_text_steps=1,
)
# LLM is called once for the StartNode, then end_call transitions to EndNode
# which triggers a second generation
assert llm.get_current_step() == 2, (
"LLM generation should have happened 2 times"
)
# Assert that the context was updated with END_CALL_SYSTEM_PROMPT
assert context.messages[0]["content"] == END_CALL_SYSTEM_PROMPT