mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-06-22 08:38:13 +02:00
Initial Commit 🚀 🚀
This commit is contained in:
commit
4f2a629340
444 changed files with 76863 additions and 0 deletions
128
api/services/workflow/test/test_aggregation_integration.py
Normal file
128
api/services/workflow/test/test_aggregation_integration.py
Normal file
|
|
@ -0,0 +1,128 @@
|
|||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
from pipecat.processors.aggregators.llm_response import LLMAssistantAggregatorParams
|
||||
from pipecat.services.openai.llm import OpenAILLMContext
|
||||
|
||||
from api.services.workflow.pipecat_engine import PipecatEngine
|
||||
from api.services.workflow.pipecat_engine_callbacks import (
|
||||
create_generation_started_callback,
|
||||
)
|
||||
|
||||
|
||||
class TestAggregationIntegration:
|
||||
"""Integration tests for the TTS aggregation correction flow."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_engine_reference_text_tracking(self):
|
||||
"""Test that the engine properly tracks LLM reference text."""
|
||||
# Create mock dependencies
|
||||
mock_task = Mock()
|
||||
mock_llm = Mock()
|
||||
mock_context = Mock(spec=OpenAILLMContext)
|
||||
mock_tts = Mock()
|
||||
mock_workflow = Mock()
|
||||
mock_workflow.start_node_id = "start"
|
||||
mock_workflow.nodes = {
|
||||
"start": Mock(is_start=True, is_static=True, is_end=False, out_edges=[])
|
||||
}
|
||||
|
||||
# Create engine
|
||||
engine = PipecatEngine(
|
||||
task=mock_task,
|
||||
llm=mock_llm,
|
||||
context=mock_context,
|
||||
tts=mock_tts,
|
||||
workflow=mock_workflow,
|
||||
call_context_vars={},
|
||||
workflow_run_id=1,
|
||||
)
|
||||
|
||||
# Test initial state
|
||||
assert engine._current_llm_reference_text == ""
|
||||
|
||||
# Test accumulating LLM text
|
||||
await engine.handle_llm_text_frame("Hello ")
|
||||
assert engine._current_llm_reference_text == "Hello "
|
||||
|
||||
await engine.handle_llm_text_frame("world!")
|
||||
assert engine._current_llm_reference_text == "Hello world!"
|
||||
|
||||
# Test generation started callback clears reference text
|
||||
callback = create_generation_started_callback(engine)
|
||||
await callback()
|
||||
assert engine._current_llm_reference_text == ""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_aggregation_correction_callback_creation(self):
|
||||
"""Test creating the aggregation correction callback."""
|
||||
# Create mock engine
|
||||
mock_task = Mock()
|
||||
mock_llm = Mock()
|
||||
mock_context = Mock(spec=OpenAILLMContext)
|
||||
mock_workflow = Mock()
|
||||
|
||||
engine = PipecatEngine(
|
||||
task=mock_task,
|
||||
llm=mock_llm,
|
||||
context=mock_context,
|
||||
workflow=mock_workflow,
|
||||
call_context_vars={},
|
||||
workflow_run_id=1,
|
||||
)
|
||||
|
||||
# Set reference text
|
||||
engine._current_llm_reference_text = "Hello, world! How are you?"
|
||||
|
||||
# Create correction callback
|
||||
callback = engine.create_aggregation_correction_callback()
|
||||
|
||||
# Test correction - note that trailing punctuation might be stripped if not in corrupted text
|
||||
corrected = callback("Hello world How are you")
|
||||
assert corrected == "Hello, world! How are you"
|
||||
|
||||
def test_llm_assistant_aggregator_params_with_callback(self):
|
||||
"""Test that LLMAssistantAggregatorParams accepts correction callback."""
|
||||
|
||||
def mock_callback(text: str) -> str:
|
||||
return text.upper()
|
||||
|
||||
params = LLMAssistantAggregatorParams(
|
||||
expect_stripped_words=True, correct_aggregation_callback=mock_callback
|
||||
)
|
||||
|
||||
assert params.expect_stripped_words is True
|
||||
assert params.correct_aggregation_callback is not None
|
||||
assert params.correct_aggregation_callback("hello") == "HELLO"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pipeline_callbacks_processor_llm_text_frame(self):
|
||||
"""Test that PipelineEngineCallbacksProcessor handles LLMTextFrame."""
|
||||
from pipecat.frames.frames import LLMTextFrame
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
|
||||
from api.services.pipecat.pipeline_engine_callbacks_processor import (
|
||||
PipelineEngineCallbacksProcessor,
|
||||
)
|
||||
|
||||
# Track callback invocations
|
||||
callback_invoked = False
|
||||
callback_text = None
|
||||
|
||||
async def mock_llm_text_callback(text: str):
|
||||
nonlocal callback_invoked, callback_text
|
||||
callback_invoked = True
|
||||
callback_text = text
|
||||
|
||||
# Create processor with callback
|
||||
processor = PipelineEngineCallbacksProcessor(
|
||||
llm_text_frame_callback=mock_llm_text_callback
|
||||
)
|
||||
|
||||
# Process LLMTextFrame
|
||||
frame = LLMTextFrame(text="Hello world")
|
||||
await processor.process_frame(frame, FrameDirection.DOWNSTREAM)
|
||||
|
||||
# Verify callback was invoked
|
||||
assert callback_invoked is True
|
||||
assert callback_text == "Hello world"
|
||||
Loading…
Add table
Add a link
Reference in a new issue