Fix/multiple generation (#104)

* fixes #100

* Fix test

* fix: fix bad configuration issue
This commit is contained in:
Abhishek 2026-01-03 12:59:18 +05:30 committed by GitHub
parent 90b690efff
commit 56953bbd09
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
18 changed files with 758 additions and 460 deletions

View file

@ -494,7 +494,6 @@ async def _run_pipeline(
max_duration_end_task_callback=engine.create_max_duration_callback(),
generation_started_callback=engine.create_generation_started_callback(),
llm_text_frame_callback=engine.handle_llm_text_frame,
# Note: speaking event callbacks are now handled by pre-aggregator processor
)
pipeline_metrics_aggregator = PipelineMetricsAggregator()

View file

@ -13,9 +13,8 @@ from pipecat.frames.frames import (
CancelFrame,
EndFrame,
FunctionCallResultProperties,
FunctionCallsFromLLMInfoFrame,
LLMContextFrame,
LLMFullResponseEndFrame,
LLMFullResponseStartFrame,
TTSSpeakFrame,
)
from pipecat.pipeline.task import PipelineTask
@ -104,7 +103,7 @@ class PipecatEngine:
self._builtin_function_schemas: Optional[list[dict]] = None
# Track current LLM reference text for TTS aggregation correction
self._current_llm_reference_text: str = ""
self._current_llm_generation_reference_text: str = ""
# Custom tool manager (initialized in initialize())
self._custom_tool_manager: Optional[CustomToolManager] = None
@ -173,6 +172,9 @@ class PipecatEngine:
await self._register_builtin_functions()
await self.set_node(self.workflow.start_node_id)
# Trigger initial LLM generation
await self.task.queue_frame(LLMContextFrame(self.context))
logger.debug(f"{self.__class__.__name__} initialized")
except Exception as e:
logger.error(f"Error initializing {self.__class__.__name__}: {e}")
@ -218,7 +220,6 @@ class PipecatEngine:
result = {"status": "done"}
properties = FunctionCallResultProperties(
run_llm=False,
on_context_updated=on_context_updated,
)
@ -256,8 +257,6 @@ class PipecatEngine:
"""Register built-in functions (calculator and timezone) with the LLM."""
logger.debug("Registering built-in functions with LLM")
properties = FunctionCallResultProperties(run_llm=True)
# Register calculator function
async def calculate_func(function_call_params: FunctionCallParams) -> None:
logger.info(f"LLM Function Call EXECUTED: safe_calculator")
@ -266,12 +265,10 @@ class PipecatEngine:
expr = function_call_params.arguments.get("expression", "")
result = safe_calculator(expr)
await function_call_params.result_callback(
{"expression": expr, "result": result}, properties=properties
{"expression": expr, "result": result}
)
except Exception as e:
await function_call_params.result_callback(
{"error": str(e)}, properties=properties
)
await function_call_params.result_callback({"error": str(e)})
# Register timezone functions
async def get_current_time_func(
@ -282,13 +279,9 @@ class PipecatEngine:
try:
timezone = function_call_params.arguments.get("timezone", "UTC")
result = get_current_time(timezone)
await function_call_params.result_callback(
result, properties=properties
)
await function_call_params.result_callback(result)
except Exception as e:
await function_call_params.result_callback(
{"error": str(e)}, properties=properties
)
await function_call_params.result_callback({"error": str(e)})
async def convert_time_func(function_call_params: FunctionCallParams) -> None:
logger.info(f"LLM Function Call EXECUTED: convert_time")
@ -299,29 +292,15 @@ class PipecatEngine:
function_call_params.arguments.get("time"),
function_call_params.arguments.get("target_timezone"),
)
await function_call_params.result_callback(
result, properties=properties
)
await function_call_params.result_callback(result)
except Exception as e:
await function_call_params.result_callback(
{"error": str(e)}, properties=properties
)
await function_call_params.result_callback({"error": str(e)})
# Register all built-in functions
self.llm.register_function("safe_calculator", calculate_func)
self.llm.register_function("get_current_time", get_current_time_func)
self.llm.register_function("convert_time", convert_time_func)
async def _queue_tts_response(self, text: str) -> None:
"""Queue TTS frames for static text response."""
await self.task.queue_frames(
[
LLMFullResponseStartFrame(),
TTSSpeakFrame(text=text),
LLMFullResponseEndFrame(),
]
)
async def _perform_variable_extraction_if_needed(
self, previous_node: Optional[Node]
) -> None:
@ -384,7 +363,6 @@ class PipecatEngine:
functions,
) = await self._compose_system_message_functions_for_node(node)
await self._update_llm_context(system_message, functions)
await self.task.queue_frame(LLMContextFrame(self.context))
async def set_node(self, node_id: str):
"""
@ -733,7 +711,7 @@ class PipecatEngine:
async def handle_llm_text_frame(self, text: str):
"""Accumulate LLM text frames to build reference text."""
self._current_llm_reference_text += text
self._current_llm_generation_reference_text += text
def handle_client_disconnected(self):
"""Handle client disconnected event."""

View file

@ -114,10 +114,10 @@ def create_max_duration_callback(engine: "PipecatEngine"):
def create_generation_started_callback(engine: "PipecatEngine"):
"""Return a callback that resets flags at the start of each LLM generation."""
async def handle_generation_started(): # noqa: D401
async def handle_generation_started():
logger.debug("LLM generation started in callback processor")
# Clear reference text from previous generation
engine._current_llm_reference_text = ""
engine._current_llm_generation_reference_text = ""
return handle_generation_started
@ -184,7 +184,7 @@ def create_aggregation_correction_callback(engine: "PipecatEngine"):
return "".join(out_chars)
def correct_aggregation(corrupted: str) -> str:
reference = engine._current_llm_reference_text
reference = engine._current_llm_generation_reference_text
if not reference:
logger.warning("No reference text available for aggregation correction")

View file

@ -1,164 +0,0 @@
{
"nodes": [
{
"id": "915",
"type": "agentNode",
"position": {
"x": 633,
"y": 324
},
"data": {
"prompt": "You are a voice agent whose mode of speaking is voice. Ask the user whether they want to talk to a sales guy or a customer service agent",
"name": "Agent"
},
"measured": {
"width": 300,
"height": 100
},
"selected": false,
"dragging": false
},
{
"id": "7598",
"type": "agentNode",
"position": {
"x": 460.1247806640531,
"y": 610.3714977079578
},
"data": {
"prompt": "You are a customer service agent whose mode of communication with the user is voice. Tell them that someone from our team will reach out to them soon",
"name": "Agent"
},
"measured": {
"width": 300,
"height": 100
},
"selected": false,
"dragging": false
},
{
"id": "6919",
"type": "agentNode",
"position": {
"x": 914.666735413607,
"y": 642.9800281289787
},
"data": {
"prompt": "You are a sales representative whose mode of communication with the user is voice. Tell the user that someone from our team will reach out to you soon",
"name": "Agent"
},
"measured": {
"width": 300,
"height": 100
},
"selected": false,
"dragging": false
},
{
"id": "6581",
"type": "startCall",
"position": {
"x": 648,
"y": 35
},
"data": {
"prompt": "Hello, I am Abhishek from Dograh. ",
"is_static": true,
"name": "Start Call",
"is_start": true
},
"measured": {
"width": 300,
"height": 100
},
"selected": false,
"dragging": false
},
{
"id": "1802",
"type": "endCall",
"position": {
"x": 666.7733431033548,
"y": 987.4345801025363
},
"data": {
"prompt": "Thank you for calling Dograh. Have a great day!",
"is_static": true,
"name": "End Call"
},
"measured": {
"width": 300,
"height": 100
},
"selected": false,
"dragging": false
}
],
"edges": [
{
"animated": true,
"type": "custom",
"source": "915",
"target": "7598",
"id": "xy-edge__915-7598",
"selected": false,
"data": {
"condition": "The customer wants to talk to a customer service agent",
"label": "customer service agent"
}
},
{
"animated": true,
"type": "custom",
"source": "915",
"target": "6919",
"id": "xy-edge__915-6919",
"selected": false,
"data": {
"condition": "customer wants to talk to a sales representative",
"label": "sales representative"
}
},
{
"animated": true,
"type": "custom",
"source": "6581",
"target": "915",
"id": "xy-edge__6581-915",
"selected": false,
"data": {
"condition": "Always take this route",
"label": "Always take this route"
}
},
{
"animated": true,
"type": "custom",
"source": "7598",
"target": "1802",
"id": "xy-edge__7598-1802",
"selected": false,
"data": {
"condition": "end call",
"label": "end call"
}
},
{
"animated": true,
"type": "custom",
"source": "6919",
"target": "1802",
"id": "xy-edge__6919-1802",
"selected": false,
"data": {
"condition": "end call",
"label": "end call"
}
}
],
"viewport": {
"x": 0,
"y": 0,
"zoom": 1
}
}

View file

@ -1,192 +0,0 @@
from unittest.mock import Mock
from api.services.workflow.pipecat_engine_callbacks import (
create_aggregation_correction_callback,
)
def test_aggregation_fixer():
"""Validate the aggregation correction algorithm using a helper that
creates a fresh callback for every (reference, corrupted) pair.
The production callback now needs a PipecatEngine instance with the
`_current_llm_reference_text` set. For test-friendliness we mock a bare
object providing just that attribute for each assertion so the original
two-argument test cases remain unchanged.
"""
def fixer(reference: str, corrupted: str) -> str: # noqa: D401
mock_engine = Mock()
mock_engine._current_llm_reference_text = reference
return create_aggregation_correction_callback(mock_engine)(corrupted)
##### Trailing extra Chars #####
assert (
fixer(
"Good Morning Mr NARGES , My name is Alex and I am calling you from Consumer Services. The reason of my call today, as I can see in our records that you are making your monthly credit card payments on time, but you STILL carry a balance of over 7 thousand dollars, right?",
"My name is Alex and I am calling you from Cons umer Services. The reason of my call today, as I can see in our records that you are making your monthly credit card payments on time, but you STILL carry a balance of over 7 thousand dollars, right?",
)
== "My name is Alex and I am calling you from Consumer Services. The reason of my call today, as I can see in our records that you are making your monthly credit card payments on time, but you STILL carry a balance of over 7 thousand dollars, right?"
), "leading_whole_sentence"
# Whole sentences
assert (
fixer(
"Good Morning Mr NARGES , My name is Alex and I am calling you from Consumer Services. The reason of my call today, as I can see in our records that you are making your monthly credit card payments on time, but you STILL carry a balance of over 7 thousand dollars, right?",
"Good Morning Mr NAR GES , My name is Alex and I am calling you from Cons umer Services. The reason of my call today, as I can see in our records that you are making your monthly credit card payments on time, but you STILL carry a balance of over 7 thousand dollars, right?",
)
== "Good Morning Mr NARGES , My name is Alex and I am calling you from Consumer Services. The reason of my call today, as I can see in our records that you are making your monthly credit card payments on time, but you STILL carry a balance of over 7 thousand dollars, right?"
), "whole_sentences"
# With a period in the end
assert (
fixer(
"Good Morning Mr NARGES , My name is Alex and I am calling you from Consumer Services. The reason of my call today, as I can see in our records that you are making your monthly credit card payments on time, but you STILL carry a balance of over 7 thousand dollars, right?",
"Good Morning Mr NAR GES , My name is Alex and I am calling you from Cons umer Services.",
)
== "Good Morning Mr NARGES , My name is Alex and I am calling you from Consumer Services."
), "period_end"
# without a period in the end
assert (
fixer(
"Good Morning Mr NARGES , My name is Alex and I am calling you from Consumer Services. The reason of my call today, as I can see in our records that you are making your monthly credit card payments on time, but you STILL carry a balance of over 7 thousand dollars, right?",
"Good Morning Mr NAR GES , My name is Alex and I am calling you from Cons umer Services",
)
== "Good Morning Mr NARGES , My name is Alex and I am calling you from Consumer Services"
), "without_period_end"
# Extra space in the end
assert (
fixer(
"Good Morning Mr NARGES , My name is Alex and I am calling you from Consumer Services. The reason of my call today, as I can see in our records that you are making your monthly credit card payments on time, but you STILL carry a balance of over 7 thousand dollars, right?",
"Good Morning Mr NAR GES , My name is Alex and I am calling you from Cons umer Services ",
)
== "Good Morning Mr NARGES , My name is Alex and I am calling you from Consumer Services"
), "extra_space"
# Multiple spaces in corruption
assert (
fixer(
"Good Morning Mr NARGES , My name is Alex and I am calling you from Consumer Services. The reason of my call today, as I can see in our records that you are making your monthly credit card payments on time, but you STILL carry a balance of over 7 thousand dollars, right?",
"Good Morning Mr NAR GES , My name is Alex and I am calling you from Cons umer Servi ces ",
)
== "Good Morning Mr NARGES , My name is Alex and I am calling you from Consumer Services"
), "multiple_space"
# Multiple spaces in corruption ending in a whitespace
assert (
fixer(
"Good Morning Mr NARGES , My name is Alex and I am calling you from Consumer Services. The reason of my call today, as I can see in our records that you are making your monthly credit card payments on time, but you STILL carry a balance of over 7 thousand dollars, right?",
"Good Morning Mr NAR GES , My name is Alex and I am calling you from Cons umer Servi ces. ",
)
== "Good Morning Mr NARGES , My name is Alex and I am calling you from Consumer Services. "
), "multiple_space_end_ws"
##### Leading extra Chars #####
# Whole sentences
assert (
fixer(
"Good Morning Mr NARGES , My name is Alex and I am calling you from Consumer Services. The reason of my call today, as I can see in our records that you are making your monthly credit card payments on time, but you STILL carry a balance of over 7 thousand dollars, right?",
"My name is Alex and I am calling you from Cons umer Services. The reason of my call today, as I can see in our records that you are making your monthly credit card payments on time, but you STILL carry a balance of over 7 thousand dollars, right?",
)
== "My name is Alex and I am calling you from Consumer Services. The reason of my call today, as I can see in our records that you are making your monthly credit card payments on time, but you STILL carry a balance of over 7 thousand dollars, right?"
), "leading_whole_sentence"
# With a period in the end
assert (
fixer(
"Good Morning Mr NARGES , My name is Alex and I am calling you from Consumer Services. The reason of my call today, as I can see in our records that you are making your monthly credit card payments on time, but you STILL carry a balance of over 7 thousand dollars, right?",
"My name is Alex and I am calling you from Cons umer Services.",
)
== "My name is Alex and I am calling you from Consumer Services."
), "leading_period_end"
# without a period in the end
assert (
fixer(
"Good Morning Mr NARGES , My name is Alex and I am calling you from Consumer Services. The reason of my call today, as I can see in our records that you are making your monthly credit card payments on time, but you STILL carry a balance of over 7 thousand dollars, right?",
"My name is Alex and I am calling you from Cons umer Services",
)
== "My name is Alex and I am calling you from Consumer Services"
), "leading_without_period_end"
# Extra space in the end
assert (
fixer(
"Good Morning Mr NARGES , My name is Alex and I am calling you from Consumer Services. The reason of my call today, as I can see in our records that you are making your monthly credit card payments on time, but you STILL carry a balance of over 7 thousand dollars, right?",
"My name is Alex and I am calling you from Cons umer Services ",
)
== "My name is Alex and I am calling you from Consumer Services"
), "leading_extra_space"
# Multiple spaces in corruption
assert (
fixer(
"Good Morning Mr NARGES , My name is Alex and I am calling you from Consumer Services. The reason of my call today, as I can see in our records that you are making your monthly credit card payments on time, but you STILL carry a balance of over 7 thousand dollars, right?",
"My name is Alex and I am calling you from Cons umer Servi ces ",
)
== "My name is Alex and I am calling you from Consumer Services"
), "leading_multiple_space"
# Multiple spaces in corruption ending in a whitespace
assert (
fixer(
"Good Morning Mr NARGES , My name is Alex and I am calling you from Consumer Services. The reason of my call today, as I can see in our records that you are making your monthly credit card payments on time, but you STILL carry a balance of over 7 thousand dollars, right?",
"My name is Alex and I am calling you from Cons umer Servi ces. ",
)
== "My name is Alex and I am calling you from Consumer Services. "
), "leading_multiple_space_end_ws"
# Whitespace
assert fixer("", "") == ""
# Missing reference
assert (
fixer("", "My name is Alex and I am calling you from Cons umer Servi ces.")
== "My name is Alex and I am calling you from Cons umer Servi ces."
), "missing_reference"
# Smaller reference
assert (
fixer(
"My name is Alex",
"My name is Alex and I am calling you from Cons umer Servi ces.",
)
== "My name is Alex and I am calling you from Cons umer Servi ces."
), "smaller_reference"
# Unrelated reference
assert (
fixer(
"Hello Hello",
"My name is Alex and I am calling you from Cons umer Servi ces.",
)
== "My name is Alex and I am calling you from Cons umer Servi ces."
), "unrelated_reference"
def test_create_aggregation_correction_callback():
"""Test the new aggregation correction callback creator."""
# Mock engine with reference text
mock_engine = Mock()
mock_engine._current_llm_reference_text = "Good Morning Mr NARGES, My name is Alex and I am calling you from Consumer Services."
# Create callback
callback = create_aggregation_correction_callback(mock_engine)
# Test correction
corrected = callback(
"Good Morning Mr NAR GES, My name is Alex and I am calling you from Cons umer Services."
)
assert (
corrected
== "Good Morning Mr NARGES, My name is Alex and I am calling you from Consumer Services."
)
# Test with no reference text
mock_engine._current_llm_reference_text = ""
corrected = callback("Some corrupted text")
assert corrected == "Some corrupted text" # Should return as-is when no reference

View file

@ -1,128 +0,0 @@
from unittest.mock import Mock
import pytest
from pipecat.processors.aggregators.llm_response import LLMAssistantAggregatorParams
from pipecat.services.openai.llm import OpenAILLMContext
from api.services.workflow.pipecat_engine import PipecatEngine
from api.services.workflow.pipecat_engine_callbacks import (
create_generation_started_callback,
)
class TestAggregationIntegration:
"""Integration tests for the TTS aggregation correction flow."""
@pytest.mark.asyncio
async def test_engine_reference_text_tracking(self):
"""Test that the engine properly tracks LLM reference text."""
# Create mock dependencies
mock_task = Mock()
mock_llm = Mock()
mock_context = Mock(spec=OpenAILLMContext)
mock_tts = Mock()
mock_workflow = Mock()
mock_workflow.start_node_id = "start"
mock_workflow.nodes = {
"start": Mock(is_start=True, is_static=True, is_end=False, out_edges=[])
}
# Create engine
engine = PipecatEngine(
task=mock_task,
llm=mock_llm,
context=mock_context,
tts=mock_tts,
workflow=mock_workflow,
call_context_vars={},
workflow_run_id=1,
)
# Test initial state
assert engine._current_llm_reference_text == ""
# Test accumulating LLM text
await engine.handle_llm_text_frame("Hello ")
assert engine._current_llm_reference_text == "Hello "
await engine.handle_llm_text_frame("world!")
assert engine._current_llm_reference_text == "Hello world!"
# Test generation started callback clears reference text
callback = create_generation_started_callback(engine)
await callback()
assert engine._current_llm_reference_text == ""
@pytest.mark.asyncio
async def test_aggregation_correction_callback_creation(self):
"""Test creating the aggregation correction callback."""
# Create mock engine
mock_task = Mock()
mock_llm = Mock()
mock_context = Mock(spec=OpenAILLMContext)
mock_workflow = Mock()
engine = PipecatEngine(
task=mock_task,
llm=mock_llm,
context=mock_context,
workflow=mock_workflow,
call_context_vars={},
workflow_run_id=1,
)
# Set reference text
engine._current_llm_reference_text = "Hello, world! How are you?"
# Create correction callback
callback = engine.create_aggregation_correction_callback()
# Test correction - note that trailing punctuation might be stripped if not in corrupted text
corrected = callback("Hello world How are you")
assert corrected == "Hello, world! How are you"
def test_llm_assistant_aggregator_params_with_callback(self):
"""Test that LLMAssistantAggregatorParams accepts correction callback."""
def mock_callback(text: str) -> str:
return text.upper()
params = LLMAssistantAggregatorParams(
expect_stripped_words=True, correct_aggregation_callback=mock_callback
)
assert params.expect_stripped_words is True
assert params.correct_aggregation_callback is not None
assert params.correct_aggregation_callback("hello") == "HELLO"
@pytest.mark.asyncio
async def test_pipeline_callbacks_processor_llm_text_frame(self):
"""Test that PipelineEngineCallbacksProcessor handles LLMTextFrame."""
from pipecat.frames.frames import LLMTextFrame
from pipecat.processors.frame_processor import FrameDirection
from api.services.pipecat.pipeline_engine_callbacks_processor import (
PipelineEngineCallbacksProcessor,
)
# Track callback invocations
callback_invoked = False
callback_text = None
async def mock_llm_text_callback(text: str):
nonlocal callback_invoked, callback_text
callback_invoked = True
callback_text = text
# Create processor with callback
processor = PipelineEngineCallbacksProcessor(
llm_text_frame_callback=mock_llm_text_callback
)
# Process LLMTextFrame
frame = LLMTextFrame(text="Hello world")
await processor.process_frame(frame, FrameDirection.DOWNSTREAM)
# Verify callback was invoked
assert callback_invoked is True
assert callback_text == "Hello world"

View file

@ -1,31 +0,0 @@
from api.services.pricing.cost_calculator import cost_calculator
def test_cost_calculator():
"""Test function to verify cost calculation works"""
sample_usage = {
"llm": {
"OpenAILLMService#0|||gpt-4.1-mini": {
"prompt_tokens": 45380,
"completion_tokens": 496,
"total_tokens": 45876,
"cache_read_input_tokens": 0,
"cache_creation_input_tokens": 0,
}
},
"tts": {"ElevenLabsTTSService#0|||eleven_flash_v2_5": 2399},
"stt": {"DeepgramSTTService#0|||nova-3-general": 177.21536946296692},
"call_duration_seconds": 179,
}
result = cost_calculator.calculate_total_cost(sample_usage)
assert result["llm_cost"] == 45380 * 0.40 / 1_000_000 + 496 * 1.60 / 1_000_000
assert result["tts_cost"] == 2399 * 0.0256 / 1_000
assert result["stt_cost"] == 177.21536946296692 / 60 * 0.0077
assert (
abs(
result["total"]
- (result["llm_cost"] + result["tts_cost"] + result["stt_cost"])
)
< 1e-10
)

View file

@ -1,11 +0,0 @@
import pytest
from api.services.workflow.dto import ReactFlowDTO
@pytest.mark.asyncio
async def test_dto():
# assert no exceptions are raised
with open("services/workflow/test/definitions/rf-1.json", "r") as f:
dto = ReactFlowDTO.model_validate_json(f.read())
assert dto is not None

View file

@ -1,159 +0,0 @@
from unittest.mock import AsyncMock, Mock
import pytest
from pipecat.frames.frames import StartInterruptionFrame
from pipecat.processors.aggregators.llm_response import LLMAssistantAggregatorParams
from pipecat.services.openai.llm import (
OpenAIAssistantContextAggregator,
OpenAILLMContext,
)
class TestInterruptionCorrection:
"""Test that TTS aggregation correction works during interruptions."""
@pytest.mark.asyncio
async def test_openai_interruption_with_correction(self):
"""Test OpenAI assistant context aggregator applies correction during interruption."""
# Create mock context
mock_context = Mock(spec=OpenAILLMContext)
mock_context.get_messages.return_value = []
mock_context.add_message = Mock()
# Create correction callback
def correction_callback(text: str) -> str:
# Simulate fixing corrupted text
if text == "Hello world how are you":
return "Hello world, how are you"
return text
# Create aggregator with correction callback
params = LLMAssistantAggregatorParams(
expect_stripped_words=True, correct_aggregation_callback=correction_callback
)
aggregator = OpenAIAssistantContextAggregator(
context=mock_context, params=params
)
# Set up aggregation state
aggregator._aggregation = "Hello world how are you"
aggregator._current_llm_response_id = "test-id"
aggregator._response_function_messages = {}
aggregator._function_calls_in_progress = {}
aggregator._started = 1
# Mock push_context_frame and reset methods
aggregator.push_context_frame = AsyncMock()
aggregator.reset = AsyncMock()
# Process interruption
interruption_frame = StartInterruptionFrame()
await aggregator._handle_interruptions(interruption_frame)
# Verify the corrected text was added to context
mock_context.add_message.assert_called_once()
added_message = mock_context.add_message.call_args[0][0]
assert added_message["role"] == "assistant"
assert (
added_message["content"]
== "Hello world, how are you <<interrupted_by_user>>"
)
@pytest.mark.asyncio
async def test_google_interruption_with_correction(self):
"""Test Google assistant context aggregator applies correction during interruption."""
from pipecat.services.google.llm import (
Content,
GoogleAssistantContextAggregator,
)
# Create mock context
mock_context = Mock(spec=OpenAILLMContext)
mock_context.get_messages.return_value = []
mock_context.add_message = Mock()
# Create correction callback
def correction_callback(text: str) -> str:
# Simulate fixing corrupted text
if text == "I am here to help":
return "I am here to help"
return text
# Create aggregator with correction callback
params = LLMAssistantAggregatorParams(
expect_stripped_words=True, correct_aggregation_callback=correction_callback
)
aggregator = GoogleAssistantContextAggregator(
context=mock_context, params=params
)
# Set up aggregation state
aggregator._aggregation = "I am here to help"
aggregator._current_llm_response_id = "test-id"
aggregator._response_function_messages = {}
aggregator._function_calls_in_progress = {}
aggregator._started = 1
# Mock push_context_frame and reset methods
aggregator.push_context_frame = AsyncMock()
aggregator.reset = AsyncMock()
# Process interruption
interruption_frame = StartInterruptionFrame()
await aggregator._handle_interruptions(interruption_frame)
# Verify the corrected text was added to context
mock_context.add_message.assert_called_once()
added_content = mock_context.add_message.call_args[0][0]
# Google uses Content objects
assert isinstance(added_content, Content)
assert added_content.role == "model"
assert len(added_content.parts) == 1
assert (
added_content.parts[0].text == "I am here to help <<interrupted_by_user>>"
)
@pytest.mark.asyncio
async def test_interruption_correction_error_handling(self):
"""Test that interruption handling continues even if correction callback fails."""
# Create mock context
mock_context = Mock(spec=OpenAILLMContext)
mock_context.get_messages.return_value = []
mock_context.add_message = Mock()
# Create correction callback that raises error
def failing_callback(text: str) -> str:
raise ValueError("Correction failed")
# Create aggregator with failing callback
params = LLMAssistantAggregatorParams(
expect_stripped_words=True, correct_aggregation_callback=failing_callback
)
aggregator = OpenAIAssistantContextAggregator(
context=mock_context, params=params
)
# Set up aggregation state
aggregator._aggregation = "Some text"
aggregator._current_llm_response_id = "test-id"
aggregator._response_function_messages = {}
aggregator._function_calls_in_progress = {}
aggregator._started = 1
# Mock push_context_frame and reset methods
aggregator.push_context_frame = AsyncMock()
aggregator.reset = AsyncMock()
# Process interruption - should not raise
interruption_frame = StartInterruptionFrame()
await aggregator._handle_interruptions(interruption_frame)
# Verify the original text was still added (fallback behavior)
mock_context.add_message.assert_called_once()
added_message = mock_context.add_message.call_args[0][0]
assert added_message["role"] == "assistant"
assert added_message["content"] == "Some text <<interrupted_by_user>>"