mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-06-13 08:15:21 +02:00
Fix/multiple generation (#104)
* fixes #100 * Fix test * fix: fix bad configuration issue
This commit is contained in:
parent
90b690efff
commit
56953bbd09
18 changed files with 758 additions and 460 deletions
256
api/tests/conftest.py
Normal file
256
api/tests/conftest.py
Normal file
|
|
@ -0,0 +1,256 @@
|
|||
from dataclasses import dataclass
|
||||
from typing import Any, Dict
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
|
||||
from api.services.workflow.dto import (
|
||||
EdgeDataDTO,
|
||||
NodeDataDTO,
|
||||
NodeType,
|
||||
Position,
|
||||
ReactFlowDTO,
|
||||
RFEdgeDTO,
|
||||
RFNodeDTO,
|
||||
)
|
||||
from api.services.workflow.workflow import WorkflowGraph
|
||||
|
||||
# Placeholder prompt text given to the start node of the fixture workflows.
START_CALL_SYSTEM_PROMPT = "start_call_system_prompt"
# Placeholder prompt text given to the end node of the fixture workflows.
END_CALL_SYSTEM_PROMPT = "end_call_system_prompt"
|
||||
|
||||
|
||||
@dataclass
class MockToolModel:
    """Plain data holder that stands in for a tool record in tests."""

    # Identifier of the tool (the fixtures use values like "weather-uuid-123").
    tool_uuid: str
    # Display name of the tool.
    name: str
    # Short sentence describing what the tool does.
    description: str
    # Tool definition payload; the fixtures populate the keys
    # "schema_version", "type" and "config".
    definition: Dict[str, Any]
|
||||
|
||||
|
||||
@pytest.fixture
def mock_engine():
    """Provide a ``Mock`` that stands in for a PipecatEngine.

    The mock exposes a workflow run id, a single call-context variable and
    an ``llm`` child mock whose ``register_function`` is itself a ``Mock``.
    """
    llm_stub = Mock()
    llm_stub.register_function = Mock()

    engine_stub = Mock()
    engine_stub._workflow_run_id = 1
    engine_stub._call_context_vars = {"customer_name": "John Doe"}
    engine_stub.llm = llm_stub
    return engine_stub
|
||||
|
||||
|
||||
@pytest.fixture
def sample_tools():
    """Return three HTTP-API mock tools: weather, booking and customer lookup."""

    def string_param(name, description, *, required):
        # Every parameter in these fixtures is a string; the key order matches
        # the order used throughout the tool definitions.
        return {
            "name": name,
            "type": "string",
            "description": description,
            "required": required,
        }

    def http_tool(tool_uuid, name, description, method, url, parameters):
        # All sample tools share schema version 1 and the "http_api" type.
        return MockToolModel(
            tool_uuid=tool_uuid,
            name=name,
            description=description,
            definition={
                "schema_version": 1,
                "type": "http_api",
                "config": {
                    "method": method,
                    "url": url,
                    "parameters": parameters,
                },
            },
        )

    weather_tool = http_tool(
        "weather-uuid-123",
        "Get Weather",
        "Get current weather for a location",
        "GET",
        "https://api.weather.com/current",
        [
            string_param(
                "location",
                "City name (e.g., San Francisco, CA)",
                required=True,
            ),
            string_param(
                "units",
                "Temperature units: celsius or fahrenheit",
                required=False,
            ),
        ],
    )

    booking_tool = http_tool(
        "booking-uuid-456",
        "Book Appointment",
        "Book an appointment for the customer",
        "POST",
        "https://api.example.com/appointments",
        [
            string_param("customer_name", "Customer's full name", required=True),
            string_param("date", "Appointment date (YYYY-MM-DD)", required=True),
            string_param("time", "Appointment time (HH:MM)", required=True),
            string_param("notes", "Additional notes", required=False),
        ],
    )

    lookup_tool = http_tool(
        "lookup-uuid-789",
        "Customer Lookup",
        "Look up customer information by phone number",
        "GET",
        "https://api.example.com/customers/lookup",
        [
            string_param("phone", "Customer phone number", required=True),
        ],
    )

    return [weather_tool, booking_tool, lookup_tool]
|
||||
|
||||
|
||||
@pytest.fixture
def simple_workflow() -> WorkflowGraph:
    """Build the smallest useful workflow graph: start node -> end node.

    Contents:
    - Node "1": start node carrying ``START_CALL_SYSTEM_PROMPT``
    - Node "2": end node carrying ``END_CALL_SYSTEM_PROMPT``
    - Edge "1-2": labelled "End Call"

    Both nodes have interruption and the global prompt switched off.
    """
    start_node = RFNodeDTO(
        id="1",
        type=NodeType.startNode,
        position=Position(x=0, y=0),
        data=NodeDataDTO(
            name="Start Call",
            prompt=START_CALL_SYSTEM_PROMPT,
            is_start=True,
            allow_interrupt=False,
            add_global_prompt=False,
        ),
    )
    end_node = RFNodeDTO(
        id="2",
        type=NodeType.endNode,
        position=Position(x=0, y=200),
        data=NodeDataDTO(
            name="End Call",
            prompt=END_CALL_SYSTEM_PROMPT,
            is_end=True,
            allow_interrupt=False,
            add_global_prompt=False,
        ),
    )
    end_call_edge = RFEdgeDTO(
        id="1-2",
        source="1",
        target="2",
        data=EdgeDataDTO(
            label="End Call",
            condition="When the user says to end the call, end the call",
        ),
    )

    flow = ReactFlowDTO(nodes=[start_node, end_node], edges=[end_call_edge])
    return WorkflowGraph(flow)
|
||||
|
||||
|
||||
@pytest.fixture
def three_node_workflow() -> WorkflowGraph:
    """Build a start -> agent -> end workflow graph.

    Contents:
    - Node "1": start node (interruptible)
    - Node "2": agent node "Collect Info" (interruptible)
    - Node "3": end node (not interruptible)
    - Edges "1-2" ("Collect Info") and "2-3" ("End Call")

    None of the nodes add the global prompt.
    """
    start_node = RFNodeDTO(
        id="1",
        type=NodeType.startNode,
        position=Position(x=0, y=0),
        data=NodeDataDTO(
            name="Start Call",
            prompt=START_CALL_SYSTEM_PROMPT,
            is_start=True,
            allow_interrupt=True,
            add_global_prompt=False,
        ),
    )
    agent_node = RFNodeDTO(
        id="2",
        type=NodeType.agentNode,
        position=Position(x=0, y=200),
        data=NodeDataDTO(
            name="Collect Info",
            prompt="Help the user with their request. Ask clarifying questions if needed.",
            allow_interrupt=True,
            add_global_prompt=False,
        ),
    )
    end_node = RFNodeDTO(
        id="3",
        type=NodeType.endNode,
        position=Position(x=0, y=400),
        data=NodeDataDTO(
            name="End Call",
            prompt=END_CALL_SYSTEM_PROMPT,
            is_end=True,
            allow_interrupt=False,
            add_global_prompt=False,
        ),
    )

    start_to_agent = RFEdgeDTO(
        id="1-2",
        source="1",
        target="2",
        data=EdgeDataDTO(
            label="Collect Info",
            condition="When the user wants help, collect their information",
        ),
    )
    agent_to_end = RFEdgeDTO(
        id="2-3",
        source="2",
        target="3",
        data=EdgeDataDTO(
            label="End Call",
            condition="When the user is done or wants to end the call",
        ),
    )

    flow = ReactFlowDTO(
        nodes=[start_node, agent_node, end_node],
        edges=[start_to_agent, agent_to_end],
    )
    return WorkflowGraph(flow)
|
||||
164
api/tests/definitions/rf-1.json
Normal file
164
api/tests/definitions/rf-1.json
Normal file
|
|
@ -0,0 +1,164 @@
|
|||
{
|
||||
"nodes": [
|
||||
{
|
||||
"id": "915",
|
||||
"type": "agentNode",
|
||||
"position": {
|
||||
"x": 633,
|
||||
"y": 324
|
||||
},
|
||||
"data": {
|
||||
"prompt": "You are a voice agent whose mode of speaking is voice. Ask the user whether they want to talk to a sales guy or a customer service agent",
|
||||
"name": "Agent"
|
||||
},
|
||||
"measured": {
|
||||
"width": 300,
|
||||
"height": 100
|
||||
},
|
||||
"selected": false,
|
||||
"dragging": false
|
||||
},
|
||||
{
|
||||
"id": "7598",
|
||||
"type": "agentNode",
|
||||
"position": {
|
||||
"x": 460.1247806640531,
|
||||
"y": 610.3714977079578
|
||||
},
|
||||
"data": {
|
||||
"prompt": "You are a customer service agent whose mode of communication with the user is voice. Tell them that someone from our team will reach out to them soon",
|
||||
"name": "Agent"
|
||||
},
|
||||
"measured": {
|
||||
"width": 300,
|
||||
"height": 100
|
||||
},
|
||||
"selected": false,
|
||||
"dragging": false
|
||||
},
|
||||
{
|
||||
"id": "6919",
|
||||
"type": "agentNode",
|
||||
"position": {
|
||||
"x": 914.666735413607,
|
||||
"y": 642.9800281289787
|
||||
},
|
||||
"data": {
|
||||
"prompt": "You are a sales representative whose mode of communication with the user is voice. Tell the user that someone from our team will reach out to you soon",
|
||||
"name": "Agent"
|
||||
},
|
||||
"measured": {
|
||||
"width": 300,
|
||||
"height": 100
|
||||
},
|
||||
"selected": false,
|
||||
"dragging": false
|
||||
},
|
||||
{
|
||||
"id": "6581",
|
||||
"type": "startCall",
|
||||
"position": {
|
||||
"x": 648,
|
||||
"y": 35
|
||||
},
|
||||
"data": {
|
||||
"prompt": "Hello, I am Abhishek from Dograh. ",
|
||||
"is_static": true,
|
||||
"name": "Start Call",
|
||||
"is_start": true
|
||||
},
|
||||
"measured": {
|
||||
"width": 300,
|
||||
"height": 100
|
||||
},
|
||||
"selected": false,
|
||||
"dragging": false
|
||||
},
|
||||
{
|
||||
"id": "1802",
|
||||
"type": "endCall",
|
||||
"position": {
|
||||
"x": 666.7733431033548,
|
||||
"y": 987.4345801025363
|
||||
},
|
||||
"data": {
|
||||
"prompt": "Thank you for calling Dograh. Have a great day!",
|
||||
"is_static": true,
|
||||
"name": "End Call"
|
||||
},
|
||||
"measured": {
|
||||
"width": 300,
|
||||
"height": 100
|
||||
},
|
||||
"selected": false,
|
||||
"dragging": false
|
||||
}
|
||||
],
|
||||
"edges": [
|
||||
{
|
||||
"animated": true,
|
||||
"type": "custom",
|
||||
"source": "915",
|
||||
"target": "7598",
|
||||
"id": "xy-edge__915-7598",
|
||||
"selected": false,
|
||||
"data": {
|
||||
"condition": "The customer wants to talk to a customer service agent",
|
||||
"label": "customer service agent"
|
||||
}
|
||||
},
|
||||
{
|
||||
"animated": true,
|
||||
"type": "custom",
|
||||
"source": "915",
|
||||
"target": "6919",
|
||||
"id": "xy-edge__915-6919",
|
||||
"selected": false,
|
||||
"data": {
|
||||
"condition": "customer wants to talk to a sales representative",
|
||||
"label": "sales representative"
|
||||
}
|
||||
},
|
||||
{
|
||||
"animated": true,
|
||||
"type": "custom",
|
||||
"source": "6581",
|
||||
"target": "915",
|
||||
"id": "xy-edge__6581-915",
|
||||
"selected": false,
|
||||
"data": {
|
||||
"condition": "Always take this route",
|
||||
"label": "Always take this route"
|
||||
}
|
||||
},
|
||||
{
|
||||
"animated": true,
|
||||
"type": "custom",
|
||||
"source": "7598",
|
||||
"target": "1802",
|
||||
"id": "xy-edge__7598-1802",
|
||||
"selected": false,
|
||||
"data": {
|
||||
"condition": "end call",
|
||||
"label": "end call"
|
||||
}
|
||||
},
|
||||
{
|
||||
"animated": true,
|
||||
"type": "custom",
|
||||
"source": "6919",
|
||||
"target": "1802",
|
||||
"id": "xy-edge__6919-1802",
|
||||
"selected": false,
|
||||
"data": {
|
||||
"condition": "end call",
|
||||
"label": "end call"
|
||||
}
|
||||
}
|
||||
],
|
||||
"viewport": {
|
||||
"x": 0,
|
||||
"y": 0,
|
||||
"zoom": 1
|
||||
}
|
||||
}
|
||||
192
api/tests/test_aggregation_fix.py
Normal file
192
api/tests/test_aggregation_fix.py
Normal file
|
|
@ -0,0 +1,192 @@
|
|||
from unittest.mock import Mock
|
||||
|
||||
from api.services.workflow.pipecat_engine_callbacks import (
|
||||
create_aggregation_correction_callback,
|
||||
)
|
||||
|
||||
|
||||
def test_aggregation_fixer():
    """Validate the aggregation correction algorithm case by case.

    The production callback reads its reference text from the engine's
    ``_current_llm_generation_reference_text`` attribute, so every case gets
    a fresh callback built around a bare ``Mock`` carrying only that
    attribute.

    Cases are table-driven as ``(label, reference, corrupted, expected)``;
    the label doubles as the assertion message so a failure names its case.
    """

    def fixer(reference: str, corrupted: str) -> str:  # noqa: D401
        mock_engine = Mock()
        mock_engine._current_llm_generation_reference_text = reference
        return create_aggregation_correction_callback(mock_engine)(corrupted)

    # Building blocks of the reference utterance.  The "split" variants carry
    # the spurious spaces that the correction is expected to remove.
    greeting_fixed = "Good Morning Mr NARGES , "
    greeting_split = "Good Morning Mr NAR GES , "
    intro = "My name is Alex and I am calling you from "
    reason = (
        " The reason of my call today, as I can see in our records that you"
        " are making your monthly credit card payments on time, but you STILL"
        " carry a balance of over 7 thousand dollars, right?"
    )
    reference = greeting_fixed + intro + "Consumer Services." + reason

    # Input used by the cases where the reference cannot help.
    garbled = intro + "Cons umer Servi ces."

    cases = [
        ##### Trailing extra Chars #####
        # Whole sentences
        (
            "whole_sentences",
            reference,
            greeting_split + intro + "Cons umer Services." + reason,
            reference,
        ),
        # With a period in the end
        (
            "period_end",
            reference,
            greeting_split + intro + "Cons umer Services.",
            greeting_fixed + intro + "Consumer Services.",
        ),
        # Without a period in the end
        (
            "without_period_end",
            reference,
            greeting_split + intro + "Cons umer Services",
            greeting_fixed + intro + "Consumer Services",
        ),
        # Extra space in the end
        (
            "extra_space",
            reference,
            greeting_split + intro + "Cons umer Services ",
            greeting_fixed + intro + "Consumer Services",
        ),
        # Multiple spaces in corruption
        (
            "multiple_space",
            reference,
            greeting_split + intro + "Cons umer Servi ces ",
            greeting_fixed + intro + "Consumer Services",
        ),
        # Multiple spaces in corruption ending in a whitespace
        (
            "multiple_space_end_ws",
            reference,
            greeting_split + intro + "Cons umer Servi ces. ",
            greeting_fixed + intro + "Consumer Services. ",
        ),
        ##### Leading extra Chars #####
        # Whole sentences
        (
            "leading_whole_sentence",
            reference,
            intro + "Cons umer Services." + reason,
            intro + "Consumer Services." + reason,
        ),
        # With a period in the end
        (
            "leading_period_end",
            reference,
            intro + "Cons umer Services.",
            intro + "Consumer Services.",
        ),
        # Without a period in the end
        (
            "leading_without_period_end",
            reference,
            intro + "Cons umer Services",
            intro + "Consumer Services",
        ),
        # Extra space in the end
        (
            "leading_extra_space",
            reference,
            intro + "Cons umer Services ",
            intro + "Consumer Services",
        ),
        # Multiple spaces in corruption
        (
            "leading_multiple_space",
            reference,
            intro + "Cons umer Servi ces ",
            intro + "Consumer Services",
        ),
        # Multiple spaces in corruption ending in a whitespace
        (
            "leading_multiple_space_end_ws",
            reference,
            intro + "Cons umer Servi ces. ",
            intro + "Consumer Services. ",
        ),
        ##### Degenerate references #####
        # Both reference and input empty
        ("empty_inputs", "", "", ""),
        # Missing reference: input is returned untouched
        ("missing_reference", "", garbled, garbled),
        # Reference shorter than the input: input is returned untouched
        ("smaller_reference", "My name is Alex", garbled, garbled),
        # Reference unrelated to the input: input is returned untouched
        ("unrelated_reference", "Hello Hello", garbled, garbled),
    ]

    for label, case_reference, corrupted, expected in cases:
        assert fixer(case_reference, corrupted) == expected, label
|
||||
|
||||
|
||||
def test_create_aggregation_correction_callback():
    """Exercise one callback instance with, then without, reference text."""
    engine_stub = Mock()
    engine_stub._current_llm_generation_reference_text = "Good Morning Mr NARGES, My name is Alex and I am calling you from Consumer Services."

    fix_aggregation = create_aggregation_correction_callback(engine_stub)

    # With a reference present, the stray spaces are expected to be removed.
    garbled = "Good Morning Mr NAR GES, My name is Alex and I am calling you from Cons umer Services."
    repaired = "Good Morning Mr NARGES, My name is Alex and I am calling you from Consumer Services."
    assert fix_aggregation(garbled) == repaired

    # The same callback is reused after the reference is cleared on the
    # engine; the text should then come back as-is.
    engine_stub._current_llm_generation_reference_text = ""
    untouched = "Some corrupted text"
    assert fix_aggregation(untouched) == untouched
|
||||
31
api/tests/test_cost_calculator.py
Normal file
31
api/tests/test_cost_calculator.py
Normal file
|
|
@ -0,0 +1,31 @@
|
|||
from api.services.pricing.cost_calculator import cost_calculator
|
||||
|
||||
|
||||
def test_cost_calculator():
    """Check per-service costs and their total for a sample usage payload.

    Every cost is compared with a tolerance rather than ``==``: the values
    are floating-point results, so exact equality would make the test depend
    on the calculator performing its arithmetic in precisely the same order
    as the expressions below.
    """
    import math  # local import: this test module does not import math at file level

    sample_usage = {
        "llm": {
            "OpenAILLMService#0|||gpt-4.1-mini": {
                "prompt_tokens": 45380,
                "completion_tokens": 496,
                "total_tokens": 45876,
                "cache_read_input_tokens": 0,
                "cache_creation_input_tokens": 0,
            }
        },
        "tts": {"ElevenLabsTTSService#0|||eleven_flash_v2_5": 2399},
        "stt": {"DeepgramSTTService#0|||nova-3-general": 177.21536946296692},
        "call_duration_seconds": 179,
    }

    result = cost_calculator.calculate_total_cost(sample_usage)

    # Expected values use the same usage figures as ``sample_usage`` above.
    expected_llm_cost = 45380 * 0.40 / 1_000_000 + 496 * 1.60 / 1_000_000
    expected_tts_cost = 2399 * 0.0256 / 1_000
    expected_stt_cost = 177.21536946296692 / 60 * 0.0077

    assert math.isclose(result["llm_cost"], expected_llm_cost, rel_tol=1e-9)
    assert math.isclose(result["tts_cost"], expected_tts_cost, rel_tol=1e-9)
    assert math.isclose(result["stt_cost"], expected_stt_cost, rel_tol=1e-9)

    # The total must equal the sum of the reported parts (absolute tolerance).
    parts_sum = result["llm_cost"] + result["tts_cost"] + result["stt_cost"]
    assert math.isclose(result["total"], parts_sum, rel_tol=0.0, abs_tol=1e-10)
|
||||
|
|
@ -6,9 +6,7 @@ This module tests the full flow of:
|
|||
3. Verifying the context is properly configured for LLM generation
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
|
@ -17,126 +15,14 @@ from api.services.workflow.pipecat_engine_utils import (
|
|||
get_function_schema,
|
||||
update_llm_context,
|
||||
)
|
||||
from api.tests.conftest import MockToolModel
|
||||
from pipecat.adapters.schemas.function_schema import FunctionSchema
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
|
||||
|
||||
@dataclass
|
||||
class MockToolModel:
|
||||
"""Mock tool model for testing."""
|
||||
|
||||
tool_uuid: str
|
||||
name: str
|
||||
description: str
|
||||
definition: Dict[str, Any]
|
||||
|
||||
|
||||
class TestCustomToolManagerContextIntegration:
|
||||
"""Integration tests for CustomToolManager with LLMContext."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_engine(self):
|
||||
"""Create a mock PipecatEngine."""
|
||||
engine = Mock()
|
||||
engine._workflow_run_id = 1
|
||||
engine._call_context_vars = {"customer_name": "John Doe"}
|
||||
engine.llm = Mock()
|
||||
engine.llm.register_function = Mock()
|
||||
return engine
|
||||
|
||||
@pytest.fixture
|
||||
def sample_tools(self):
|
||||
"""Create sample mock tools for testing."""
|
||||
return [
|
||||
MockToolModel(
|
||||
tool_uuid="weather-uuid-123",
|
||||
name="Get Weather",
|
||||
description="Get current weather for a location",
|
||||
definition={
|
||||
"schema_version": 1,
|
||||
"type": "http_api",
|
||||
"config": {
|
||||
"method": "GET",
|
||||
"url": "https://api.weather.com/current",
|
||||
"parameters": [
|
||||
{
|
||||
"name": "location",
|
||||
"type": "string",
|
||||
"description": "City name (e.g., San Francisco, CA)",
|
||||
"required": True,
|
||||
},
|
||||
{
|
||||
"name": "units",
|
||||
"type": "string",
|
||||
"description": "Temperature units: celsius or fahrenheit",
|
||||
"required": False,
|
||||
},
|
||||
],
|
||||
},
|
||||
},
|
||||
),
|
||||
MockToolModel(
|
||||
tool_uuid="booking-uuid-456",
|
||||
name="Book Appointment",
|
||||
description="Book an appointment for the customer",
|
||||
definition={
|
||||
"schema_version": 1,
|
||||
"type": "http_api",
|
||||
"config": {
|
||||
"method": "POST",
|
||||
"url": "https://api.example.com/appointments",
|
||||
"parameters": [
|
||||
{
|
||||
"name": "customer_name",
|
||||
"type": "string",
|
||||
"description": "Customer's full name",
|
||||
"required": True,
|
||||
},
|
||||
{
|
||||
"name": "date",
|
||||
"type": "string",
|
||||
"description": "Appointment date (YYYY-MM-DD)",
|
||||
"required": True,
|
||||
},
|
||||
{
|
||||
"name": "time",
|
||||
"type": "string",
|
||||
"description": "Appointment time (HH:MM)",
|
||||
"required": True,
|
||||
},
|
||||
{
|
||||
"name": "notes",
|
||||
"type": "string",
|
||||
"description": "Additional notes",
|
||||
"required": False,
|
||||
},
|
||||
],
|
||||
},
|
||||
},
|
||||
),
|
||||
MockToolModel(
|
||||
tool_uuid="lookup-uuid-789",
|
||||
name="Customer Lookup",
|
||||
description="Look up customer information by phone number",
|
||||
definition={
|
||||
"schema_version": 1,
|
||||
"type": "http_api",
|
||||
"config": {
|
||||
"method": "GET",
|
||||
"url": "https://api.example.com/customers/lookup",
|
||||
"parameters": [
|
||||
{
|
||||
"name": "phone",
|
||||
"type": "string",
|
||||
"description": "Customer phone number",
|
||||
"required": True,
|
||||
},
|
||||
],
|
||||
},
|
||||
},
|
||||
),
|
||||
]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_tool_schemas_and_update_context(self, mock_engine, sample_tools):
|
||||
"""Test fetching tool schemas via CustomToolManager and updating LLM context."""
|
||||
|
|
|
|||
11
api/tests/test_dto.py
Normal file
11
api/tests/test_dto.py
Normal file
|
|
@ -0,0 +1,11 @@
|
|||
import pytest
|
||||
|
||||
from api.services.workflow.dto import ReactFlowDTO
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_dto():
    """Parse the ``rf-1.json`` React Flow definition and expect no error."""
    from pathlib import Path  # local import: pathlib is not imported at file level

    # Resolve the fixture relative to this test file so the test does not
    # depend on the directory pytest happens to be launched from.
    definition_path = Path(__file__).parent / "definitions" / "rf-1.json"

    # Validation raises on a malformed definition, which fails the test.
    dto = ReactFlowDTO.model_validate_json(
        definition_path.read_text(encoding="utf-8")
    )
    assert dto is not None
|
||||
340
api/tests/test_pipecat_engine_tool_calls.py
Normal file
340
api/tests/test_pipecat_engine_tool_calls.py
Normal file
|
|
@ -0,0 +1,340 @@
|
|||
"""Tests for tool calls with PipecatEngine and MockLLM.
|
||||
|
||||
This module tests the behavior when the LLM generates tool calls (single or parallel),
|
||||
using PipecatEngine's actual function registration and execution logic.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from typing import Any, Dict, List
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from api.services.pipecat.pipeline_engine_callbacks_processor import (
|
||||
PipelineEngineCallbacksProcessor,
|
||||
)
|
||||
from api.services.workflow.pipecat_engine import PipecatEngine
|
||||
from api.services.workflow.workflow import WorkflowGraph
|
||||
from api.tests.conftest import END_CALL_SYSTEM_PROMPT
|
||||
from pipecat.frames.frames import (
|
||||
BotStartedSpeakingFrame,
|
||||
BotStoppedSpeakingFrame,
|
||||
Frame,
|
||||
TextFrame,
|
||||
)
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.processors.aggregators.llm_response import LLMAssistantAggregatorParams
|
||||
from pipecat.processors.aggregators.llm_response_universal import (
|
||||
LLMContextAggregatorPair,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
from pipecat.tests import MockLLMService
|
||||
|
||||
|
||||
class MockBotStoppedSpeakingOnLLMTextFrameProcessor(FrameProcessor):
    """Stand-in for the transport's bot-speaking notifications.

    Whenever a ``TextFrame`` passes through, a ``BotStartedSpeakingFrame``
    and, after a short pause, a ``BotStoppedSpeakingFrame`` are emitted. Each
    is pushed once with ``push_frame``'s default direction and once
    explicitly upstream. The incoming frame is then forwarded in its
    original direction.
    """

    async def process_frame(self, frame: Frame, direction: FrameDirection):
        await super().process_frame(frame, direction)

        if isinstance(frame, TextFrame):
            await self._push_both_ways(BotStartedSpeakingFrame)
            # Brief gap between the "started" and "stopped" notifications.
            await asyncio.sleep(0.1)
            await self._push_both_ways(BotStoppedSpeakingFrame)

        # Every frame, text or not, continues along its original direction.
        await self.push_frame(frame, direction)

    async def _push_both_ways(self, frame_cls):
        # A fresh frame instance is created for each push.
        await self.push_frame(frame_cls())
        await self.push_frame(frame_cls(), direction=FrameDirection.UPSTREAM)
|
||||
|
||||
|
||||
async def run_pipeline_with_tool_calls(
    workflow: WorkflowGraph,
    functions: List[Dict[str, Any]],
    text: str | None = None,
    num_text_steps: int = 1,
) -> tuple[MockLLMService, LLMContext]:
    """Run a pipeline driven by mock tool calls and return objects to assert on.

    The pipeline is ``MockLLMService`` -> mock transport emulator -> assistant
    context aggregator. A ``PipecatEngine`` built on ``workflow`` is attached
    to the pipeline task and initialized shortly after the runner starts.

    Args:
        workflow: The workflow graph to use.
        functions: List of function call definitions with name, arguments,
            and tool_call_id.
        text: Optional text streamed in the first step, before the tool calls.
        num_text_steps: Number of text response steps after the tool calls.

    Returns:
        A ``(llm, context)`` tuple: the ``MockLLMService`` that served the
        mock steps and the ``LLMContext`` shared with the engine, both meant
        for assertions by the caller.
    """
    # First step: optional text chunks followed by the function call chunks.
    # The last text chunk is dropped because it carries finish_reason="stop".
    text_prefix = MockLLMService.create_text_chunks(text)[:-1] if text else []
    first_step_chunks = (
        text_prefix + MockLLMService.create_multiple_function_call_chunks(functions)
    )

    # The remaining steps are plain text responses.
    mock_steps = MockLLMService.create_multi_step_responses(
        first_step_chunks, num_text_steps=num_text_steps, step_prefix="Response"
    )
    llm = MockLLMService(mock_steps=mock_steps, chunk_delay=0.001)

    mock_transport_emulator = MockBotStoppedSpeakingOnLLMTextFrameProcessor()

    # LLM context and its assistant-side aggregator.
    context = LLMContext()
    context_aggregator = LLMContextAggregatorPair(
        context,
        assistant_params=LLMAssistantAggregatorParams(expect_stripped_words=True),
    )
    assistant_context_aggregator = context_aggregator.assistant()

    # Engine under test, operating on the same llm and context as the pipeline.
    engine = PipecatEngine(
        llm=llm,
        context=context,
        workflow=workflow,
        call_context_vars={"customer_name": "Test User"},
        workflow_run_id=1,
    )

    pipeline = Pipeline(
        [
            llm,
            mock_transport_emulator,
            assistant_context_aggregator,
        ]
    )

    # A real pipeline task, with interruptions disabled.
    task = PipelineTask(
        pipeline,
        params=PipelineParams(allow_interruptions=False),
    )
    engine.set_task(task)

    # Patch DB calls to avoid actual database access.
    patch_organization_lookup = patch(
        "api.services.workflow.pipecat_engine.get_organization_id_from_workflow_run",
        new_callable=AsyncMock,
        return_value=1,
    )
    patch_disposition_mapping = patch(
        "api.services.workflow.pipecat_engine.apply_disposition_mapping",
        new_callable=AsyncMock,
        return_value="completed",
    )

    with patch_organization_lookup, patch_disposition_mapping:
        runner = PipelineRunner()

        async def initialize_engine():
            # Small delay to let the runner start before the engine kicks in.
            await asyncio.sleep(0.01)
            await engine.initialize()

        # Run the pipeline and the engine initialization concurrently.
        await asyncio.gather(runner.run(task), initialize_engine())

    return llm, context
|
||||
|
||||
|
||||
class TestPipecatEngineToolCalls:
    """Test tool calls through PipecatEngine.

    Every test feeds mock tool calls to the engine via
    ``run_pipeline_with_tool_calls`` and then checks that exactly two LLM
    generations happened and that the first context message holds
    ``END_CALL_SYSTEM_PROMPT``.
    """

    @staticmethod
    def _end_call() -> dict:
        """Return a fresh mock tool call for the ``end_call`` transition."""
        return {
            "name": "end_call",
            "arguments": {},
            "tool_call_id": "call_transition",
        }

    @staticmethod
    def _safe_calculator() -> dict:
        """Return a fresh mock tool call for the ``safe_calculator`` built-in."""
        return {
            "name": "safe_calculator",
            "arguments": {"expression": "25 * 4"},
            "tool_call_id": "call_calc",
        }

    @staticmethod
    def _assert_transitioned_to_end_node(llm, context) -> None:
        """Assert two LLM generations and the end-call system prompt in context."""
        # One generation for the StartNode and one for the EndCall node.
        # A built-in tool call must not add a generation of its own.
        assert llm.get_current_step() == 2, (
            "LLM generation should have happened 2 times"
        )

        # Assert that the context was updated with END_CALL_SYSTEM_PROMPT
        assert context.messages[0]["content"] == END_CALL_SYSTEM_PROMPT

    @pytest.mark.asyncio
    async def test_parallel_builtin_and_transition_calls_through_engine(
        self, simple_workflow: WorkflowGraph
    ):
        """Parallel tool calls with the transition listed before the built-in.

        The first mock step contains two parallel tool calls, in this order:
        1. A transition function (end_call)
        2. A built-in function (safe_calculator)

        The following mock steps are text responses for subsequent node
        prompts.
        """
        llm, context = await run_pipeline_with_tool_calls(
            workflow=simple_workflow,
            functions=[self._end_call(), self._safe_calculator()],
            num_text_steps=2,
        )

        self._assert_transitioned_to_end_node(llm, context)

    @pytest.mark.asyncio
    async def test_parallel_builtin_and_transition_calls_through_engine_1(
        self, simple_workflow: WorkflowGraph
    ):
        """Parallel tool calls with the built-in listed before the transition.

        Same scenario as the test above, with the tool calls in reverse order:
        1. A built-in function (safe_calculator)
        2. A transition function (end_call)
        """
        llm, context = await run_pipeline_with_tool_calls(
            workflow=simple_workflow,
            functions=[self._safe_calculator(), self._end_call()],
            num_text_steps=2,
        )

        self._assert_transitioned_to_end_node(llm, context)

    @pytest.mark.asyncio
    async def test_parallel_builtin_and_transition_calls_through_engine_with_text(
        self, simple_workflow: WorkflowGraph
    ):
        """Parallel tool calls that are preceded by streamed text.

        The first mock step streams "Hello There!" and then the two parallel
        tool calls (end_call followed by safe_calculator).
        """
        llm, context = await run_pipeline_with_tool_calls(
            workflow=simple_workflow,
            functions=[self._end_call(), self._safe_calculator()],
            text="Hello There!",
            num_text_steps=2,
        )

        self._assert_transitioned_to_end_node(llm, context)

    @pytest.mark.asyncio
    async def test_single_transition_call_through_engine(
        self, simple_workflow: WorkflowGraph
    ):
        """Test a single transition function call (end_call) through PipecatEngine.

        The first mock step contains only the end_call tool call. The test
        expects two generations in total: one for the initial StartNode and
        one after end_call transitions to the end node.
        """
        llm, context = await run_pipeline_with_tool_calls(
            workflow=simple_workflow,
            functions=[self._end_call()],
            num_text_steps=1,
        )

        self._assert_transitioned_to_end_node(llm, context)
|
||||
Loading…
Add table
Add a link
Reference in a new issue