mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-06-16 08:25:18 +02:00
Fix/multiple generation (#104)
* fixes #100 * Fix test * fix: fix bad configuration issue
This commit is contained in:
parent
90b690efff
commit
56953bbd09
18 changed files with 758 additions and 460 deletions
|
|
@ -494,7 +494,6 @@ async def _run_pipeline(
|
|||
max_duration_end_task_callback=engine.create_max_duration_callback(),
|
||||
generation_started_callback=engine.create_generation_started_callback(),
|
||||
llm_text_frame_callback=engine.handle_llm_text_frame,
|
||||
# Note: speaking event callbacks are now handled by pre-aggregator processor
|
||||
)
|
||||
|
||||
pipeline_metrics_aggregator = PipelineMetricsAggregator()
|
||||
|
|
|
|||
|
|
@ -13,9 +13,8 @@ from pipecat.frames.frames import (
|
|||
CancelFrame,
|
||||
EndFrame,
|
||||
FunctionCallResultProperties,
|
||||
FunctionCallsFromLLMInfoFrame,
|
||||
LLMContextFrame,
|
||||
LLMFullResponseEndFrame,
|
||||
LLMFullResponseStartFrame,
|
||||
TTSSpeakFrame,
|
||||
)
|
||||
from pipecat.pipeline.task import PipelineTask
|
||||
|
|
@ -104,7 +103,7 @@ class PipecatEngine:
|
|||
self._builtin_function_schemas: Optional[list[dict]] = None
|
||||
|
||||
# Track current LLM reference text for TTS aggregation correction
|
||||
self._current_llm_reference_text: str = ""
|
||||
self._current_llm_generation_reference_text: str = ""
|
||||
|
||||
# Custom tool manager (initialized in initialize())
|
||||
self._custom_tool_manager: Optional[CustomToolManager] = None
|
||||
|
|
@ -173,6 +172,9 @@ class PipecatEngine:
|
|||
await self._register_builtin_functions()
|
||||
|
||||
await self.set_node(self.workflow.start_node_id)
|
||||
|
||||
# Trigger initial LLM generation
|
||||
await self.task.queue_frame(LLMContextFrame(self.context))
|
||||
logger.debug(f"{self.__class__.__name__} initialized")
|
||||
except Exception as e:
|
||||
logger.error(f"Error initializing {self.__class__.__name__}: {e}")
|
||||
|
|
@ -218,7 +220,6 @@ class PipecatEngine:
|
|||
result = {"status": "done"}
|
||||
|
||||
properties = FunctionCallResultProperties(
|
||||
run_llm=False,
|
||||
on_context_updated=on_context_updated,
|
||||
)
|
||||
|
||||
|
|
@ -256,8 +257,6 @@ class PipecatEngine:
|
|||
"""Register built-in functions (calculator and timezone) with the LLM."""
|
||||
logger.debug("Registering built-in functions with LLM")
|
||||
|
||||
properties = FunctionCallResultProperties(run_llm=True)
|
||||
|
||||
# Register calculator function
|
||||
async def calculate_func(function_call_params: FunctionCallParams) -> None:
|
||||
logger.info(f"LLM Function Call EXECUTED: safe_calculator")
|
||||
|
|
@ -266,12 +265,10 @@ class PipecatEngine:
|
|||
expr = function_call_params.arguments.get("expression", "")
|
||||
result = safe_calculator(expr)
|
||||
await function_call_params.result_callback(
|
||||
{"expression": expr, "result": result}, properties=properties
|
||||
{"expression": expr, "result": result}
|
||||
)
|
||||
except Exception as e:
|
||||
await function_call_params.result_callback(
|
||||
{"error": str(e)}, properties=properties
|
||||
)
|
||||
await function_call_params.result_callback({"error": str(e)})
|
||||
|
||||
# Register timezone functions
|
||||
async def get_current_time_func(
|
||||
|
|
@ -282,13 +279,9 @@ class PipecatEngine:
|
|||
try:
|
||||
timezone = function_call_params.arguments.get("timezone", "UTC")
|
||||
result = get_current_time(timezone)
|
||||
await function_call_params.result_callback(
|
||||
result, properties=properties
|
||||
)
|
||||
await function_call_params.result_callback(result)
|
||||
except Exception as e:
|
||||
await function_call_params.result_callback(
|
||||
{"error": str(e)}, properties=properties
|
||||
)
|
||||
await function_call_params.result_callback({"error": str(e)})
|
||||
|
||||
async def convert_time_func(function_call_params: FunctionCallParams) -> None:
|
||||
logger.info(f"LLM Function Call EXECUTED: convert_time")
|
||||
|
|
@ -299,29 +292,15 @@ class PipecatEngine:
|
|||
function_call_params.arguments.get("time"),
|
||||
function_call_params.arguments.get("target_timezone"),
|
||||
)
|
||||
await function_call_params.result_callback(
|
||||
result, properties=properties
|
||||
)
|
||||
await function_call_params.result_callback(result)
|
||||
except Exception as e:
|
||||
await function_call_params.result_callback(
|
||||
{"error": str(e)}, properties=properties
|
||||
)
|
||||
await function_call_params.result_callback({"error": str(e)})
|
||||
|
||||
# Register all built-in functions
|
||||
self.llm.register_function("safe_calculator", calculate_func)
|
||||
self.llm.register_function("get_current_time", get_current_time_func)
|
||||
self.llm.register_function("convert_time", convert_time_func)
|
||||
|
||||
async def _queue_tts_response(self, text: str) -> None:
|
||||
"""Queue TTS frames for static text response."""
|
||||
await self.task.queue_frames(
|
||||
[
|
||||
LLMFullResponseStartFrame(),
|
||||
TTSSpeakFrame(text=text),
|
||||
LLMFullResponseEndFrame(),
|
||||
]
|
||||
)
|
||||
|
||||
async def _perform_variable_extraction_if_needed(
|
||||
self, previous_node: Optional[Node]
|
||||
) -> None:
|
||||
|
|
@ -384,7 +363,6 @@ class PipecatEngine:
|
|||
functions,
|
||||
) = await self._compose_system_message_functions_for_node(node)
|
||||
await self._update_llm_context(system_message, functions)
|
||||
await self.task.queue_frame(LLMContextFrame(self.context))
|
||||
|
||||
async def set_node(self, node_id: str):
|
||||
"""
|
||||
|
|
@ -733,7 +711,7 @@ class PipecatEngine:
|
|||
|
||||
async def handle_llm_text_frame(self, text: str):
|
||||
"""Accumulate LLM text frames to build reference text."""
|
||||
self._current_llm_reference_text += text
|
||||
self._current_llm_generation_reference_text += text
|
||||
|
||||
def handle_client_disconnected(self):
|
||||
"""Handle client disconnected event."""
|
||||
|
|
|
|||
|
|
@ -114,10 +114,10 @@ def create_max_duration_callback(engine: "PipecatEngine"):
|
|||
def create_generation_started_callback(engine: "PipecatEngine"):
|
||||
"""Return a callback that resets flags at the start of each LLM generation."""
|
||||
|
||||
async def handle_generation_started(): # noqa: D401
|
||||
async def handle_generation_started():
|
||||
logger.debug("LLM generation started in callback processor")
|
||||
# Clear reference text from previous generation
|
||||
engine._current_llm_reference_text = ""
|
||||
engine._current_llm_generation_reference_text = ""
|
||||
|
||||
return handle_generation_started
|
||||
|
||||
|
|
@ -184,7 +184,7 @@ def create_aggregation_correction_callback(engine: "PipecatEngine"):
|
|||
return "".join(out_chars)
|
||||
|
||||
def correct_aggregation(corrupted: str) -> str:
|
||||
reference = engine._current_llm_reference_text
|
||||
reference = engine._current_llm_generation_reference_text
|
||||
|
||||
if not reference:
|
||||
logger.warning("No reference text available for aggregation correction")
|
||||
|
|
|
|||
|
|
@ -1,164 +0,0 @@
|
|||
{
|
||||
"nodes": [
|
||||
{
|
||||
"id": "915",
|
||||
"type": "agentNode",
|
||||
"position": {
|
||||
"x": 633,
|
||||
"y": 324
|
||||
},
|
||||
"data": {
|
||||
"prompt": "You are a voice agent whose mode of speaking is voice. Ask the user whether they want to talk to a sales guy or a customer service agent",
|
||||
"name": "Agent"
|
||||
},
|
||||
"measured": {
|
||||
"width": 300,
|
||||
"height": 100
|
||||
},
|
||||
"selected": false,
|
||||
"dragging": false
|
||||
},
|
||||
{
|
||||
"id": "7598",
|
||||
"type": "agentNode",
|
||||
"position": {
|
||||
"x": 460.1247806640531,
|
||||
"y": 610.3714977079578
|
||||
},
|
||||
"data": {
|
||||
"prompt": "You are a customer service agent whose mode of communication with the user is voice. Tell them that someone from our team will reach out to them soon",
|
||||
"name": "Agent"
|
||||
},
|
||||
"measured": {
|
||||
"width": 300,
|
||||
"height": 100
|
||||
},
|
||||
"selected": false,
|
||||
"dragging": false
|
||||
},
|
||||
{
|
||||
"id": "6919",
|
||||
"type": "agentNode",
|
||||
"position": {
|
||||
"x": 914.666735413607,
|
||||
"y": 642.9800281289787
|
||||
},
|
||||
"data": {
|
||||
"prompt": "You are a sales representative whose mode of communication with the user is voice. Tell the user that someone from our team will reach out to you soon",
|
||||
"name": "Agent"
|
||||
},
|
||||
"measured": {
|
||||
"width": 300,
|
||||
"height": 100
|
||||
},
|
||||
"selected": false,
|
||||
"dragging": false
|
||||
},
|
||||
{
|
||||
"id": "6581",
|
||||
"type": "startCall",
|
||||
"position": {
|
||||
"x": 648,
|
||||
"y": 35
|
||||
},
|
||||
"data": {
|
||||
"prompt": "Hello, I am Abhishek from Dograh. ",
|
||||
"is_static": true,
|
||||
"name": "Start Call",
|
||||
"is_start": true
|
||||
},
|
||||
"measured": {
|
||||
"width": 300,
|
||||
"height": 100
|
||||
},
|
||||
"selected": false,
|
||||
"dragging": false
|
||||
},
|
||||
{
|
||||
"id": "1802",
|
||||
"type": "endCall",
|
||||
"position": {
|
||||
"x": 666.7733431033548,
|
||||
"y": 987.4345801025363
|
||||
},
|
||||
"data": {
|
||||
"prompt": "Thank you for calling Dograh. Have a great day!",
|
||||
"is_static": true,
|
||||
"name": "End Call"
|
||||
},
|
||||
"measured": {
|
||||
"width": 300,
|
||||
"height": 100
|
||||
},
|
||||
"selected": false,
|
||||
"dragging": false
|
||||
}
|
||||
],
|
||||
"edges": [
|
||||
{
|
||||
"animated": true,
|
||||
"type": "custom",
|
||||
"source": "915",
|
||||
"target": "7598",
|
||||
"id": "xy-edge__915-7598",
|
||||
"selected": false,
|
||||
"data": {
|
||||
"condition": "The customer wants to talk to a customer service agent",
|
||||
"label": "customer service agent"
|
||||
}
|
||||
},
|
||||
{
|
||||
"animated": true,
|
||||
"type": "custom",
|
||||
"source": "915",
|
||||
"target": "6919",
|
||||
"id": "xy-edge__915-6919",
|
||||
"selected": false,
|
||||
"data": {
|
||||
"condition": "customer wants to talk to a sales representative",
|
||||
"label": "sales representative"
|
||||
}
|
||||
},
|
||||
{
|
||||
"animated": true,
|
||||
"type": "custom",
|
||||
"source": "6581",
|
||||
"target": "915",
|
||||
"id": "xy-edge__6581-915",
|
||||
"selected": false,
|
||||
"data": {
|
||||
"condition": "Always take this route",
|
||||
"label": "Always take this route"
|
||||
}
|
||||
},
|
||||
{
|
||||
"animated": true,
|
||||
"type": "custom",
|
||||
"source": "7598",
|
||||
"target": "1802",
|
||||
"id": "xy-edge__7598-1802",
|
||||
"selected": false,
|
||||
"data": {
|
||||
"condition": "end call",
|
||||
"label": "end call"
|
||||
}
|
||||
},
|
||||
{
|
||||
"animated": true,
|
||||
"type": "custom",
|
||||
"source": "6919",
|
||||
"target": "1802",
|
||||
"id": "xy-edge__6919-1802",
|
||||
"selected": false,
|
||||
"data": {
|
||||
"condition": "end call",
|
||||
"label": "end call"
|
||||
}
|
||||
}
|
||||
],
|
||||
"viewport": {
|
||||
"x": 0,
|
||||
"y": 0,
|
||||
"zoom": 1
|
||||
}
|
||||
}
|
||||
|
|
@ -1,192 +0,0 @@
|
|||
from unittest.mock import Mock
|
||||
|
||||
from api.services.workflow.pipecat_engine_callbacks import (
|
||||
create_aggregation_correction_callback,
|
||||
)
|
||||
|
||||
|
||||
def test_aggregation_fixer():
|
||||
"""Validate the aggregation correction algorithm using a helper that
|
||||
creates a fresh callback for every (reference, corrupted) pair.
|
||||
|
||||
The production callback now needs a PipecatEngine instance with the
|
||||
`_current_llm_reference_text` set. For test-friendliness we mock a bare
|
||||
object providing just that attribute for each assertion so the original
|
||||
two-argument test cases remain unchanged.
|
||||
"""
|
||||
|
||||
def fixer(reference: str, corrupted: str) -> str: # noqa: D401
|
||||
mock_engine = Mock()
|
||||
mock_engine._current_llm_reference_text = reference
|
||||
return create_aggregation_correction_callback(mock_engine)(corrupted)
|
||||
|
||||
##### Trailing extra Chars #####
|
||||
|
||||
assert (
|
||||
fixer(
|
||||
"Good Morning Mr NARGES , My name is Alex and I am calling you from Consumer Services. The reason of my call today, as I can see in our records that you are making your monthly credit card payments on time, but you STILL carry a balance of over 7 thousand dollars, right?",
|
||||
"My name is Alex and I am calling you from Cons umer Services. The reason of my call today, as I can see in our records that you are making your monthly credit card payments on time, but you STILL carry a balance of over 7 thousand dollars, right?",
|
||||
)
|
||||
== "My name is Alex and I am calling you from Consumer Services. The reason of my call today, as I can see in our records that you are making your monthly credit card payments on time, but you STILL carry a balance of over 7 thousand dollars, right?"
|
||||
), "leading_whole_sentence"
|
||||
|
||||
# Whole sentences
|
||||
assert (
|
||||
fixer(
|
||||
"Good Morning Mr NARGES , My name is Alex and I am calling you from Consumer Services. The reason of my call today, as I can see in our records that you are making your monthly credit card payments on time, but you STILL carry a balance of over 7 thousand dollars, right?",
|
||||
"Good Morning Mr NAR GES , My name is Alex and I am calling you from Cons umer Services. The reason of my call today, as I can see in our records that you are making your monthly credit card payments on time, but you STILL carry a balance of over 7 thousand dollars, right?",
|
||||
)
|
||||
== "Good Morning Mr NARGES , My name is Alex and I am calling you from Consumer Services. The reason of my call today, as I can see in our records that you are making your monthly credit card payments on time, but you STILL carry a balance of over 7 thousand dollars, right?"
|
||||
), "whole_sentences"
|
||||
|
||||
# With a period in the end
|
||||
assert (
|
||||
fixer(
|
||||
"Good Morning Mr NARGES , My name is Alex and I am calling you from Consumer Services. The reason of my call today, as I can see in our records that you are making your monthly credit card payments on time, but you STILL carry a balance of over 7 thousand dollars, right?",
|
||||
"Good Morning Mr NAR GES , My name is Alex and I am calling you from Cons umer Services.",
|
||||
)
|
||||
== "Good Morning Mr NARGES , My name is Alex and I am calling you from Consumer Services."
|
||||
), "period_end"
|
||||
|
||||
# without a period in the end
|
||||
assert (
|
||||
fixer(
|
||||
"Good Morning Mr NARGES , My name is Alex and I am calling you from Consumer Services. The reason of my call today, as I can see in our records that you are making your monthly credit card payments on time, but you STILL carry a balance of over 7 thousand dollars, right?",
|
||||
"Good Morning Mr NAR GES , My name is Alex and I am calling you from Cons umer Services",
|
||||
)
|
||||
== "Good Morning Mr NARGES , My name is Alex and I am calling you from Consumer Services"
|
||||
), "without_period_end"
|
||||
|
||||
# Extra space in the end
|
||||
assert (
|
||||
fixer(
|
||||
"Good Morning Mr NARGES , My name is Alex and I am calling you from Consumer Services. The reason of my call today, as I can see in our records that you are making your monthly credit card payments on time, but you STILL carry a balance of over 7 thousand dollars, right?",
|
||||
"Good Morning Mr NAR GES , My name is Alex and I am calling you from Cons umer Services ",
|
||||
)
|
||||
== "Good Morning Mr NARGES , My name is Alex and I am calling you from Consumer Services"
|
||||
), "extra_space"
|
||||
|
||||
# Multiple spaces in corruption
|
||||
assert (
|
||||
fixer(
|
||||
"Good Morning Mr NARGES , My name is Alex and I am calling you from Consumer Services. The reason of my call today, as I can see in our records that you are making your monthly credit card payments on time, but you STILL carry a balance of over 7 thousand dollars, right?",
|
||||
"Good Morning Mr NAR GES , My name is Alex and I am calling you from Cons umer Servi ces ",
|
||||
)
|
||||
== "Good Morning Mr NARGES , My name is Alex and I am calling you from Consumer Services"
|
||||
), "multiple_space"
|
||||
|
||||
# Multiple spaces in corruption ending in a whitespace
|
||||
assert (
|
||||
fixer(
|
||||
"Good Morning Mr NARGES , My name is Alex and I am calling you from Consumer Services. The reason of my call today, as I can see in our records that you are making your monthly credit card payments on time, but you STILL carry a balance of over 7 thousand dollars, right?",
|
||||
"Good Morning Mr NAR GES , My name is Alex and I am calling you from Cons umer Servi ces. ",
|
||||
)
|
||||
== "Good Morning Mr NARGES , My name is Alex and I am calling you from Consumer Services. "
|
||||
), "multiple_space_end_ws"
|
||||
|
||||
##### Leading extra Chars #####
|
||||
|
||||
# Whole sentences
|
||||
assert (
|
||||
fixer(
|
||||
"Good Morning Mr NARGES , My name is Alex and I am calling you from Consumer Services. The reason of my call today, as I can see in our records that you are making your monthly credit card payments on time, but you STILL carry a balance of over 7 thousand dollars, right?",
|
||||
"My name is Alex and I am calling you from Cons umer Services. The reason of my call today, as I can see in our records that you are making your monthly credit card payments on time, but you STILL carry a balance of over 7 thousand dollars, right?",
|
||||
)
|
||||
== "My name is Alex and I am calling you from Consumer Services. The reason of my call today, as I can see in our records that you are making your monthly credit card payments on time, but you STILL carry a balance of over 7 thousand dollars, right?"
|
||||
), "leading_whole_sentence"
|
||||
|
||||
# With a period in the end
|
||||
assert (
|
||||
fixer(
|
||||
"Good Morning Mr NARGES , My name is Alex and I am calling you from Consumer Services. The reason of my call today, as I can see in our records that you are making your monthly credit card payments on time, but you STILL carry a balance of over 7 thousand dollars, right?",
|
||||
"My name is Alex and I am calling you from Cons umer Services.",
|
||||
)
|
||||
== "My name is Alex and I am calling you from Consumer Services."
|
||||
), "leading_period_end"
|
||||
|
||||
# without a period in the end
|
||||
assert (
|
||||
fixer(
|
||||
"Good Morning Mr NARGES , My name is Alex and I am calling you from Consumer Services. The reason of my call today, as I can see in our records that you are making your monthly credit card payments on time, but you STILL carry a balance of over 7 thousand dollars, right?",
|
||||
"My name is Alex and I am calling you from Cons umer Services",
|
||||
)
|
||||
== "My name is Alex and I am calling you from Consumer Services"
|
||||
), "leading_without_period_end"
|
||||
|
||||
# Extra space in the end
|
||||
assert (
|
||||
fixer(
|
||||
"Good Morning Mr NARGES , My name is Alex and I am calling you from Consumer Services. The reason of my call today, as I can see in our records that you are making your monthly credit card payments on time, but you STILL carry a balance of over 7 thousand dollars, right?",
|
||||
"My name is Alex and I am calling you from Cons umer Services ",
|
||||
)
|
||||
== "My name is Alex and I am calling you from Consumer Services"
|
||||
), "leading_extra_space"
|
||||
|
||||
# Multiple spaces in corruption
|
||||
assert (
|
||||
fixer(
|
||||
"Good Morning Mr NARGES , My name is Alex and I am calling you from Consumer Services. The reason of my call today, as I can see in our records that you are making your monthly credit card payments on time, but you STILL carry a balance of over 7 thousand dollars, right?",
|
||||
"My name is Alex and I am calling you from Cons umer Servi ces ",
|
||||
)
|
||||
== "My name is Alex and I am calling you from Consumer Services"
|
||||
), "leading_multiple_space"
|
||||
|
||||
# Multiple spaces in corruption ending in a whitespace
|
||||
assert (
|
||||
fixer(
|
||||
"Good Morning Mr NARGES , My name is Alex and I am calling you from Consumer Services. The reason of my call today, as I can see in our records that you are making your monthly credit card payments on time, but you STILL carry a balance of over 7 thousand dollars, right?",
|
||||
"My name is Alex and I am calling you from Cons umer Servi ces. ",
|
||||
)
|
||||
== "My name is Alex and I am calling you from Consumer Services. "
|
||||
), "leading_multiple_space_end_ws"
|
||||
|
||||
# Whitespace
|
||||
assert fixer("", "") == ""
|
||||
|
||||
# Missing reference
|
||||
assert (
|
||||
fixer("", "My name is Alex and I am calling you from Cons umer Servi ces.")
|
||||
== "My name is Alex and I am calling you from Cons umer Servi ces."
|
||||
), "missing_reference"
|
||||
|
||||
# Smaller reference
|
||||
assert (
|
||||
fixer(
|
||||
"My name is Alex",
|
||||
"My name is Alex and I am calling you from Cons umer Servi ces.",
|
||||
)
|
||||
== "My name is Alex and I am calling you from Cons umer Servi ces."
|
||||
), "smaller_reference"
|
||||
|
||||
# Unrelated reference
|
||||
assert (
|
||||
fixer(
|
||||
"Hello Hello",
|
||||
"My name is Alex and I am calling you from Cons umer Servi ces.",
|
||||
)
|
||||
== "My name is Alex and I am calling you from Cons umer Servi ces."
|
||||
), "unrelated_reference"
|
||||
|
||||
|
||||
def test_create_aggregation_correction_callback():
|
||||
"""Test the new aggregation correction callback creator."""
|
||||
# Mock engine with reference text
|
||||
mock_engine = Mock()
|
||||
mock_engine._current_llm_reference_text = "Good Morning Mr NARGES, My name is Alex and I am calling you from Consumer Services."
|
||||
|
||||
# Create callback
|
||||
callback = create_aggregation_correction_callback(mock_engine)
|
||||
|
||||
# Test correction
|
||||
corrected = callback(
|
||||
"Good Morning Mr NAR GES, My name is Alex and I am calling you from Cons umer Services."
|
||||
)
|
||||
assert (
|
||||
corrected
|
||||
== "Good Morning Mr NARGES, My name is Alex and I am calling you from Consumer Services."
|
||||
)
|
||||
|
||||
# Test with no reference text
|
||||
mock_engine._current_llm_reference_text = ""
|
||||
corrected = callback("Some corrupted text")
|
||||
assert corrected == "Some corrupted text" # Should return as-is when no reference
|
||||
|
|
@ -1,128 +0,0 @@
|
|||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
from pipecat.processors.aggregators.llm_response import LLMAssistantAggregatorParams
|
||||
from pipecat.services.openai.llm import OpenAILLMContext
|
||||
|
||||
from api.services.workflow.pipecat_engine import PipecatEngine
|
||||
from api.services.workflow.pipecat_engine_callbacks import (
|
||||
create_generation_started_callback,
|
||||
)
|
||||
|
||||
|
||||
class TestAggregationIntegration:
|
||||
"""Integration tests for the TTS aggregation correction flow."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_engine_reference_text_tracking(self):
|
||||
"""Test that the engine properly tracks LLM reference text."""
|
||||
# Create mock dependencies
|
||||
mock_task = Mock()
|
||||
mock_llm = Mock()
|
||||
mock_context = Mock(spec=OpenAILLMContext)
|
||||
mock_tts = Mock()
|
||||
mock_workflow = Mock()
|
||||
mock_workflow.start_node_id = "start"
|
||||
mock_workflow.nodes = {
|
||||
"start": Mock(is_start=True, is_static=True, is_end=False, out_edges=[])
|
||||
}
|
||||
|
||||
# Create engine
|
||||
engine = PipecatEngine(
|
||||
task=mock_task,
|
||||
llm=mock_llm,
|
||||
context=mock_context,
|
||||
tts=mock_tts,
|
||||
workflow=mock_workflow,
|
||||
call_context_vars={},
|
||||
workflow_run_id=1,
|
||||
)
|
||||
|
||||
# Test initial state
|
||||
assert engine._current_llm_reference_text == ""
|
||||
|
||||
# Test accumulating LLM text
|
||||
await engine.handle_llm_text_frame("Hello ")
|
||||
assert engine._current_llm_reference_text == "Hello "
|
||||
|
||||
await engine.handle_llm_text_frame("world!")
|
||||
assert engine._current_llm_reference_text == "Hello world!"
|
||||
|
||||
# Test generation started callback clears reference text
|
||||
callback = create_generation_started_callback(engine)
|
||||
await callback()
|
||||
assert engine._current_llm_reference_text == ""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_aggregation_correction_callback_creation(self):
|
||||
"""Test creating the aggregation correction callback."""
|
||||
# Create mock engine
|
||||
mock_task = Mock()
|
||||
mock_llm = Mock()
|
||||
mock_context = Mock(spec=OpenAILLMContext)
|
||||
mock_workflow = Mock()
|
||||
|
||||
engine = PipecatEngine(
|
||||
task=mock_task,
|
||||
llm=mock_llm,
|
||||
context=mock_context,
|
||||
workflow=mock_workflow,
|
||||
call_context_vars={},
|
||||
workflow_run_id=1,
|
||||
)
|
||||
|
||||
# Set reference text
|
||||
engine._current_llm_reference_text = "Hello, world! How are you?"
|
||||
|
||||
# Create correction callback
|
||||
callback = engine.create_aggregation_correction_callback()
|
||||
|
||||
# Test correction - note that trailing punctuation might be stripped if not in corrupted text
|
||||
corrected = callback("Hello world How are you")
|
||||
assert corrected == "Hello, world! How are you"
|
||||
|
||||
def test_llm_assistant_aggregator_params_with_callback(self):
|
||||
"""Test that LLMAssistantAggregatorParams accepts correction callback."""
|
||||
|
||||
def mock_callback(text: str) -> str:
|
||||
return text.upper()
|
||||
|
||||
params = LLMAssistantAggregatorParams(
|
||||
expect_stripped_words=True, correct_aggregation_callback=mock_callback
|
||||
)
|
||||
|
||||
assert params.expect_stripped_words is True
|
||||
assert params.correct_aggregation_callback is not None
|
||||
assert params.correct_aggregation_callback("hello") == "HELLO"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pipeline_callbacks_processor_llm_text_frame(self):
|
||||
"""Test that PipelineEngineCallbacksProcessor handles LLMTextFrame."""
|
||||
from pipecat.frames.frames import LLMTextFrame
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
|
||||
from api.services.pipecat.pipeline_engine_callbacks_processor import (
|
||||
PipelineEngineCallbacksProcessor,
|
||||
)
|
||||
|
||||
# Track callback invocations
|
||||
callback_invoked = False
|
||||
callback_text = None
|
||||
|
||||
async def mock_llm_text_callback(text: str):
|
||||
nonlocal callback_invoked, callback_text
|
||||
callback_invoked = True
|
||||
callback_text = text
|
||||
|
||||
# Create processor with callback
|
||||
processor = PipelineEngineCallbacksProcessor(
|
||||
llm_text_frame_callback=mock_llm_text_callback
|
||||
)
|
||||
|
||||
# Process LLMTextFrame
|
||||
frame = LLMTextFrame(text="Hello world")
|
||||
await processor.process_frame(frame, FrameDirection.DOWNSTREAM)
|
||||
|
||||
# Verify callback was invoked
|
||||
assert callback_invoked is True
|
||||
assert callback_text == "Hello world"
|
||||
|
|
@ -1,31 +0,0 @@
|
|||
from api.services.pricing.cost_calculator import cost_calculator
|
||||
|
||||
|
||||
def test_cost_calculator():
|
||||
"""Test function to verify cost calculation works"""
|
||||
sample_usage = {
|
||||
"llm": {
|
||||
"OpenAILLMService#0|||gpt-4.1-mini": {
|
||||
"prompt_tokens": 45380,
|
||||
"completion_tokens": 496,
|
||||
"total_tokens": 45876,
|
||||
"cache_read_input_tokens": 0,
|
||||
"cache_creation_input_tokens": 0,
|
||||
}
|
||||
},
|
||||
"tts": {"ElevenLabsTTSService#0|||eleven_flash_v2_5": 2399},
|
||||
"stt": {"DeepgramSTTService#0|||nova-3-general": 177.21536946296692},
|
||||
"call_duration_seconds": 179,
|
||||
}
|
||||
|
||||
result = cost_calculator.calculate_total_cost(sample_usage)
|
||||
assert result["llm_cost"] == 45380 * 0.40 / 1_000_000 + 496 * 1.60 / 1_000_000
|
||||
assert result["tts_cost"] == 2399 * 0.0256 / 1_000
|
||||
assert result["stt_cost"] == 177.21536946296692 / 60 * 0.0077
|
||||
assert (
|
||||
abs(
|
||||
result["total"]
|
||||
- (result["llm_cost"] + result["tts_cost"] + result["stt_cost"])
|
||||
)
|
||||
< 1e-10
|
||||
)
|
||||
|
|
@ -1,11 +0,0 @@
|
|||
import pytest
|
||||
|
||||
from api.services.workflow.dto import ReactFlowDTO
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dto():
|
||||
# assert no exceptions are raised
|
||||
with open("services/workflow/test/definitions/rf-1.json", "r") as f:
|
||||
dto = ReactFlowDTO.model_validate_json(f.read())
|
||||
assert dto is not None
|
||||
|
|
@ -1,159 +0,0 @@
|
|||
from unittest.mock import AsyncMock, Mock
|
||||
|
||||
import pytest
|
||||
from pipecat.frames.frames import StartInterruptionFrame
|
||||
from pipecat.processors.aggregators.llm_response import LLMAssistantAggregatorParams
|
||||
from pipecat.services.openai.llm import (
|
||||
OpenAIAssistantContextAggregator,
|
||||
OpenAILLMContext,
|
||||
)
|
||||
|
||||
|
||||
class TestInterruptionCorrection:
|
||||
"""Test that TTS aggregation correction works during interruptions."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_openai_interruption_with_correction(self):
|
||||
"""Test OpenAI assistant context aggregator applies correction during interruption."""
|
||||
# Create mock context
|
||||
mock_context = Mock(spec=OpenAILLMContext)
|
||||
mock_context.get_messages.return_value = []
|
||||
mock_context.add_message = Mock()
|
||||
|
||||
# Create correction callback
|
||||
def correction_callback(text: str) -> str:
|
||||
# Simulate fixing corrupted text
|
||||
if text == "Hello world how are you":
|
||||
return "Hello world, how are you"
|
||||
return text
|
||||
|
||||
# Create aggregator with correction callback
|
||||
params = LLMAssistantAggregatorParams(
|
||||
expect_stripped_words=True, correct_aggregation_callback=correction_callback
|
||||
)
|
||||
|
||||
aggregator = OpenAIAssistantContextAggregator(
|
||||
context=mock_context, params=params
|
||||
)
|
||||
|
||||
# Set up aggregation state
|
||||
aggregator._aggregation = "Hello world how are you"
|
||||
aggregator._current_llm_response_id = "test-id"
|
||||
aggregator._response_function_messages = {}
|
||||
aggregator._function_calls_in_progress = {}
|
||||
aggregator._started = 1
|
||||
|
||||
# Mock push_context_frame and reset methods
|
||||
aggregator.push_context_frame = AsyncMock()
|
||||
aggregator.reset = AsyncMock()
|
||||
|
||||
# Process interruption
|
||||
interruption_frame = StartInterruptionFrame()
|
||||
await aggregator._handle_interruptions(interruption_frame)
|
||||
|
||||
# Verify the corrected text was added to context
|
||||
mock_context.add_message.assert_called_once()
|
||||
added_message = mock_context.add_message.call_args[0][0]
|
||||
assert added_message["role"] == "assistant"
|
||||
assert (
|
||||
added_message["content"]
|
||||
== "Hello world, how are you <<interrupted_by_user>>"
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_google_interruption_with_correction(self):
|
||||
"""Test Google assistant context aggregator applies correction during interruption."""
|
||||
from pipecat.services.google.llm import (
|
||||
Content,
|
||||
GoogleAssistantContextAggregator,
|
||||
)
|
||||
|
||||
# Create mock context
|
||||
mock_context = Mock(spec=OpenAILLMContext)
|
||||
mock_context.get_messages.return_value = []
|
||||
mock_context.add_message = Mock()
|
||||
|
||||
# Create correction callback
|
||||
def correction_callback(text: str) -> str:
|
||||
# Simulate fixing corrupted text
|
||||
if text == "I am here to help":
|
||||
return "I am here to help"
|
||||
return text
|
||||
|
||||
# Create aggregator with correction callback
|
||||
params = LLMAssistantAggregatorParams(
|
||||
expect_stripped_words=True, correct_aggregation_callback=correction_callback
|
||||
)
|
||||
|
||||
aggregator = GoogleAssistantContextAggregator(
|
||||
context=mock_context, params=params
|
||||
)
|
||||
|
||||
# Set up aggregation state
|
||||
aggregator._aggregation = "I am here to help"
|
||||
aggregator._current_llm_response_id = "test-id"
|
||||
aggregator._response_function_messages = {}
|
||||
aggregator._function_calls_in_progress = {}
|
||||
aggregator._started = 1
|
||||
|
||||
# Mock push_context_frame and reset methods
|
||||
aggregator.push_context_frame = AsyncMock()
|
||||
aggregator.reset = AsyncMock()
|
||||
|
||||
# Process interruption
|
||||
interruption_frame = StartInterruptionFrame()
|
||||
await aggregator._handle_interruptions(interruption_frame)
|
||||
|
||||
# Verify the corrected text was added to context
|
||||
mock_context.add_message.assert_called_once()
|
||||
added_content = mock_context.add_message.call_args[0][0]
|
||||
|
||||
# Google uses Content objects
|
||||
assert isinstance(added_content, Content)
|
||||
assert added_content.role == "model"
|
||||
assert len(added_content.parts) == 1
|
||||
assert (
|
||||
added_content.parts[0].text == "I am here to help <<interrupted_by_user>>"
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_interruption_correction_error_handling(self):
|
||||
"""Test that interruption handling continues even if correction callback fails."""
|
||||
# Create mock context
|
||||
mock_context = Mock(spec=OpenAILLMContext)
|
||||
mock_context.get_messages.return_value = []
|
||||
mock_context.add_message = Mock()
|
||||
|
||||
# Create correction callback that raises error
|
||||
def failing_callback(text: str) -> str:
|
||||
raise ValueError("Correction failed")
|
||||
|
||||
# Create aggregator with failing callback
|
||||
params = LLMAssistantAggregatorParams(
|
||||
expect_stripped_words=True, correct_aggregation_callback=failing_callback
|
||||
)
|
||||
|
||||
aggregator = OpenAIAssistantContextAggregator(
|
||||
context=mock_context, params=params
|
||||
)
|
||||
|
||||
# Set up aggregation state
|
||||
aggregator._aggregation = "Some text"
|
||||
aggregator._current_llm_response_id = "test-id"
|
||||
aggregator._response_function_messages = {}
|
||||
aggregator._function_calls_in_progress = {}
|
||||
aggregator._started = 1
|
||||
|
||||
# Mock push_context_frame and reset methods
|
||||
aggregator.push_context_frame = AsyncMock()
|
||||
aggregator.reset = AsyncMock()
|
||||
|
||||
# Process interruption - should not raise
|
||||
interruption_frame = StartInterruptionFrame()
|
||||
await aggregator._handle_interruptions(interruption_frame)
|
||||
|
||||
# Verify the original text was still added (fallback behavior)
|
||||
mock_context.add_message.assert_called_once()
|
||||
added_message = mock_context.add_message.call_args[0][0]
|
||||
assert added_message["role"] == "assistant"
|
||||
assert added_message["content"] == "Some text <<interrupted_by_user>>"
|
||||
Loading…
Add table
Add a link
Reference in a new issue