mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-06-07 07:55:16 +02:00
fix: llm generation in case of user idle
Send for LLM generation in case of user idle rather than speaking a hardcoded sentence
This commit is contained in:
parent
56953bbd09
commit
04576ac357
11 changed files with 364 additions and 87 deletions
|
|
@ -14,11 +14,72 @@ from api.services.workflow.dto import (
|
|||
RFNodeDTO,
|
||||
)
|
||||
from api.services.workflow.workflow import WorkflowGraph
|
||||
from pipecat.frames.frames import (
|
||||
BotSpeakingFrame,
|
||||
BotStartedSpeakingFrame,
|
||||
BotStoppedSpeakingFrame,
|
||||
Frame,
|
||||
TTSAudioRawFrame,
|
||||
TTSStartedFrame,
|
||||
TTSStoppedFrame,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
|
||||
START_CALL_SYSTEM_PROMPT = "start_call_system_prompt"
|
||||
END_CALL_SYSTEM_PROMPT = "end_call_system_prompt"
|
||||
|
||||
|
||||
class MockTransportProcessor(FrameProcessor):
|
||||
"""
|
||||
Mocks the transport behavior by emitting Bot speaking frames
|
||||
when it encounters TTS frames.
|
||||
|
||||
This simulates what a real transport would do when the bot is speaking:
|
||||
- TTSStartedFrame -> BotStartedSpeakingFrame
|
||||
- TTSAudioRawFrame -> BotSpeakingFrame
|
||||
- TTSStoppedFrame -> BotStoppedSpeakingFrame
|
||||
|
||||
Args:
|
||||
emit_bot_speaking: If True, also emits BotSpeakingFrame on TTSAudioRawFrame
|
||||
which is needed for UserIdleProcessor to start conversation tracking. Default True.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
emit_bot_speaking: bool = True,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
self._emit_bot_speaking = emit_bot_speaking
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
if isinstance(frame, TTSStartedFrame):
|
||||
# Emit BotStartedSpeakingFrame to indicate bot started speaking
|
||||
await self.push_frame(BotStartedSpeakingFrame())
|
||||
await self.push_frame(
|
||||
BotStartedSpeakingFrame(), direction=FrameDirection.UPSTREAM
|
||||
)
|
||||
elif isinstance(frame, TTSAudioRawFrame):
|
||||
# Emit BotSpeakingFrame - this is what triggers the UserIdleProcessor
|
||||
# to start conversation tracking
|
||||
if self._emit_bot_speaking:
|
||||
await self.push_frame(BotSpeakingFrame())
|
||||
await self.push_frame(
|
||||
BotSpeakingFrame(), direction=FrameDirection.UPSTREAM
|
||||
)
|
||||
elif isinstance(frame, TTSStoppedFrame):
|
||||
# Emit BotStoppedSpeakingFrame to indicate bot stopped speaking
|
||||
await self.push_frame(BotStoppedSpeakingFrame())
|
||||
await self.push_frame(
|
||||
BotStoppedSpeakingFrame(), direction=FrameDirection.UPSTREAM
|
||||
)
|
||||
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
|
||||
@dataclass
|
||||
class MockToolModel:
|
||||
"""Mock tool model for testing."""
|
||||
|
|
|
|||
|
|
@ -920,27 +920,6 @@ class TestUpdateLLMContext:
|
|||
assert messages[1]["role"] == "user"
|
||||
assert messages[2]["role"] == "assistant"
|
||||
|
||||
def test_removes_multiple_old_system_messages(self):
|
||||
"""Test that all old system messages are removed."""
|
||||
context = LLMContext()
|
||||
context.set_messages(
|
||||
[
|
||||
{"role": "system", "content": "First system message"},
|
||||
{"role": "user", "content": "Question 1"},
|
||||
{"role": "system", "content": "Second system message"},
|
||||
{"role": "assistant", "content": "Answer 1"},
|
||||
]
|
||||
)
|
||||
|
||||
new_system = {"role": "system", "content": "Only system message"}
|
||||
update_llm_context(context, new_system, [])
|
||||
|
||||
messages = context.messages
|
||||
# Should only have one system message now
|
||||
system_messages = [m for m in messages if m["role"] == "system"]
|
||||
assert len(system_messages) == 1
|
||||
assert system_messages[0]["content"] == "Only system message"
|
||||
|
||||
def test_preserves_conversation_history(self):
|
||||
"""Test that user/assistant messages are preserved in order."""
|
||||
context = LLMContext()
|
||||
|
|
|
|||
|
|
@ -10,18 +10,9 @@ from unittest.mock import AsyncMock, patch
|
|||
|
||||
import pytest
|
||||
|
||||
from api.services.pipecat.pipeline_engine_callbacks_processor import (
|
||||
PipelineEngineCallbacksProcessor,
|
||||
)
|
||||
from api.services.workflow.pipecat_engine import PipecatEngine
|
||||
from api.services.workflow.workflow import WorkflowGraph
|
||||
from api.tests.conftest import END_CALL_SYSTEM_PROMPT
|
||||
from pipecat.frames.frames import (
|
||||
BotStartedSpeakingFrame,
|
||||
BotStoppedSpeakingFrame,
|
||||
Frame,
|
||||
TextFrame,
|
||||
)
|
||||
from api.tests.conftest import END_CALL_SYSTEM_PROMPT, MockTransportProcessor
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
|
|
@ -30,33 +21,7 @@ from pipecat.processors.aggregators.llm_response import LLMAssistantAggregatorPa
|
|||
from pipecat.processors.aggregators.llm_response_universal import (
|
||||
LLMContextAggregatorPair,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
from pipecat.tests import MockLLMService
|
||||
|
||||
|
||||
class MockBotStoppedSpeakingOnLLMTextFrameProcessor(FrameProcessor):
|
||||
"""
|
||||
Mocking the transport, where transport sends BotStartedSpeakingFrame
|
||||
and BotStoppedSpeakingFrame when it encounters a LLMTextFrame.
|
||||
"""
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
if isinstance(frame, TextFrame):
|
||||
await self.push_frame(BotStartedSpeakingFrame())
|
||||
await self.push_frame(
|
||||
BotStartedSpeakingFrame(), direction=FrameDirection.UPSTREAM
|
||||
)
|
||||
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
await self.push_frame(BotStoppedSpeakingFrame())
|
||||
await self.push_frame(
|
||||
BotStoppedSpeakingFrame(), direction=FrameDirection.UPSTREAM
|
||||
)
|
||||
|
||||
await self.push_frame(frame, direction)
|
||||
from pipecat.tests import MockLLMService, MockTTSService
|
||||
|
||||
|
||||
async def run_pipeline_with_tool_calls(
|
||||
|
|
@ -96,7 +61,10 @@ async def run_pipeline_with_tool_calls(
|
|||
# Create MockLLMService with multi-step support
|
||||
llm = MockLLMService(mock_steps=mock_steps, chunk_delay=0.001)
|
||||
|
||||
mock_transport_emulator = MockBotStoppedSpeakingOnLLMTextFrameProcessor()
|
||||
# Create MockTTSService to generate TTS frames
|
||||
tts = MockTTSService(mock_audio_duration_ms=10, frame_delay=0)
|
||||
|
||||
mock_transport_emulator = MockTransportProcessor(emit_bot_speaking=False)
|
||||
|
||||
# Create LLM context
|
||||
context = LLMContext()
|
||||
|
|
@ -117,10 +85,11 @@ async def run_pipeline_with_tool_calls(
|
|||
workflow_run_id=1,
|
||||
)
|
||||
|
||||
# Create the pipeline with the mock LLM
|
||||
# Create the pipeline with the mock LLM and TTS
|
||||
pipeline = Pipeline(
|
||||
[
|
||||
llm,
|
||||
tts,
|
||||
mock_transport_emulator,
|
||||
assistant_context_aggregator,
|
||||
]
|
||||
|
|
|
|||
236
api/tests/test_user_idle_handler.py
Normal file
236
api/tests/test_user_idle_handler.py
Normal file
|
|
@ -0,0 +1,236 @@
|
|||
"""
|
||||
Simulates a user idle condition and tests the behaviour
|
||||
of the user idle processor.
|
||||
|
||||
This module tests the behavior when the user becomes idle during a conversation,
|
||||
ensuring the UserIdleProcessor properly triggers the callback and the engine
|
||||
handles it correctly.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from api.services.workflow.pipecat_engine import PipecatEngine
|
||||
from api.services.workflow.workflow import WorkflowGraph
|
||||
from api.tests.conftest import MockTransportProcessor
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.processors.aggregators.llm_response import LLMAssistantAggregatorParams
|
||||
from pipecat.processors.aggregators.llm_response_universal import (
|
||||
LLMContextAggregatorPair,
|
||||
)
|
||||
from pipecat.processors.user_idle_processor import UserIdleProcessor
|
||||
from pipecat.tests import MockLLMService, MockTTSService
|
||||
|
||||
|
||||
async def run_pipeline_with_user_idle(
|
||||
workflow: WorkflowGraph,
|
||||
user_idle_timeout: float = 0.2,
|
||||
mock_steps: list | None = None,
|
||||
) -> tuple[MockLLMService, LLMContext, UserIdleProcessor]:
|
||||
"""Run a pipeline with UserIdleProcessor and simulate user idle condition.
|
||||
|
||||
Args:
|
||||
workflow: The workflow graph to use.
|
||||
user_idle_timeout: Timeout in seconds before considering user idle.
|
||||
mock_steps: Optional list of mock step chunks for the LLM. If not provided,
|
||||
defaults to a simple greeting followed by text responses.
|
||||
|
||||
Returns:
|
||||
Tuple of (MockLLMService, LLMContext, UserIdleProcessor) for assertions.
|
||||
"""
|
||||
# Create mock responses - bot will speak first, then respond to idle prompts
|
||||
# Step 1: Initial greeting
|
||||
# Step 2: Response to first idle (asking if user is still there)
|
||||
# Step 3: Response to second idle (goodbye message)
|
||||
|
||||
if mock_steps is None:
|
||||
mock_steps = MockLLMService.create_multi_step_responses(
|
||||
MockLLMService.create_text_chunks("Hello, how can I help you today?"),
|
||||
num_text_steps=3, # Initial + 2 idle responses
|
||||
step_prefix="Response",
|
||||
)
|
||||
|
||||
llm = MockLLMService(mock_steps=mock_steps, chunk_delay=0.001)
|
||||
tts = MockTTSService(mock_audio_duration_ms=10)
|
||||
|
||||
mock_transport = MockTransportProcessor()
|
||||
|
||||
# Create LLM context
|
||||
context = LLMContext()
|
||||
|
||||
# Create context aggregator with both user and assistant aggregators
|
||||
assistant_params = LLMAssistantAggregatorParams(expect_stripped_words=True)
|
||||
context_aggregator = LLMContextAggregatorPair(
|
||||
context, assistant_params=assistant_params
|
||||
)
|
||||
user_context_aggregator = context_aggregator.user()
|
||||
assistant_context_aggregator = context_aggregator.assistant()
|
||||
|
||||
# Create PipecatEngine with the workflow
|
||||
engine = PipecatEngine(
|
||||
llm=llm,
|
||||
context=context,
|
||||
workflow=workflow,
|
||||
call_context_vars={"customer_name": "Test User"},
|
||||
workflow_run_id=1,
|
||||
)
|
||||
|
||||
# Create UserIdleProcessor with engine's callback and a short timeout
|
||||
user_idle_processor = UserIdleProcessor(
|
||||
callback=engine.create_user_idle_callback(),
|
||||
timeout=user_idle_timeout,
|
||||
)
|
||||
|
||||
# Build the pipeline:
|
||||
# llm -> mock_transport -> user_idle_processor -> assistant_context_aggregator
|
||||
# The user_context_aggregator would normally be at the start for user input
|
||||
pipeline = Pipeline(
|
||||
[
|
||||
user_idle_processor,
|
||||
user_context_aggregator,
|
||||
llm,
|
||||
tts,
|
||||
mock_transport,
|
||||
assistant_context_aggregator,
|
||||
]
|
||||
)
|
||||
|
||||
# Create pipeline task
|
||||
task = PipelineTask(
|
||||
pipeline,
|
||||
params=PipelineParams(allow_interruptions=False),
|
||||
)
|
||||
|
||||
engine.set_task(task)
|
||||
|
||||
# Patch DB calls to avoid actual database access
|
||||
with patch(
|
||||
"api.services.workflow.pipecat_engine.get_organization_id_from_workflow_run",
|
||||
new_callable=AsyncMock,
|
||||
return_value=1,
|
||||
):
|
||||
with patch(
|
||||
"api.services.workflow.pipecat_engine.apply_disposition_mapping",
|
||||
new_callable=AsyncMock,
|
||||
return_value="completed",
|
||||
):
|
||||
runner = PipelineRunner()
|
||||
|
||||
async def run_pipeline():
|
||||
await runner.run(task)
|
||||
|
||||
async def initialize_engine():
|
||||
# Small delay to let runner start
|
||||
await asyncio.sleep(0.01)
|
||||
await engine.initialize()
|
||||
|
||||
# Calculate total wait time:
|
||||
# - Initial bot speech
|
||||
# - First idle timeout (user_idle_timeout)
|
||||
# - First idle callback + LLM generation
|
||||
# - Second idle timeout (user_idle_timeout)
|
||||
# - Second idle callback (ends the task)
|
||||
# Add buffer for processing time
|
||||
total_wait_time = (user_idle_timeout * 3) + 1.0
|
||||
|
||||
async def wait_for_idle_to_trigger():
|
||||
# Wait long enough for idle timeouts to trigger
|
||||
await asyncio.sleep(total_wait_time)
|
||||
# Cancel the task if it's still running
|
||||
await task.cancel()
|
||||
|
||||
# Run all concurrently
|
||||
await asyncio.gather(
|
||||
run_pipeline(),
|
||||
initialize_engine(),
|
||||
wait_for_idle_to_trigger(),
|
||||
return_exceptions=True,
|
||||
)
|
||||
|
||||
return llm, context, user_idle_processor
|
||||
|
||||
|
||||
class TestUserIdleHandler:
|
||||
"""Test user idle handling through PipecatEngine and UserIdleProcessor."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_user_idle_triggers_callback(self, simple_workflow: WorkflowGraph):
|
||||
"""Test that user idle condition properly triggers the callback.
|
||||
|
||||
This test verifies that when:
|
||||
1. The bot starts speaking (triggers conversation tracking)
|
||||
2. No user input is received for the timeout period
|
||||
3. The UserIdleProcessor triggers the idle callback
|
||||
|
||||
The engine's user idle callback should:
|
||||
- First retry: Send a message asking if user is still there
|
||||
- Second retry: Send goodbye message and end the call
|
||||
"""
|
||||
llm, context, user_idle_processor = await run_pipeline_with_user_idle(
|
||||
workflow=simple_workflow,
|
||||
user_idle_timeout=0.2, # Short timeout for faster test
|
||||
)
|
||||
|
||||
# Final message in the context should be from the bot
|
||||
assert len(context.messages) == 6, "Total 6 messages"
|
||||
assert context.messages[-1]["content"] == "Response 2", (
|
||||
"Final message in the context should be from LLM"
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_user_idle_with_node_transition(
|
||||
self, three_node_workflow: WorkflowGraph
|
||||
):
|
||||
"""Test user idle handling with node transition via tool call.
|
||||
|
||||
This test verifies that when:
|
||||
1. The bot starts speaking with initial greeting
|
||||
2. The LLM generates a tool call to transition to the next node
|
||||
3. The pipeline correctly handles the transition
|
||||
|
||||
The mock steps are:
|
||||
- Step 1: Text "Hello, how can I help you today?"
|
||||
- Step 2: Tool call "collect_info" to transition to agent node
|
||||
- Step 3+: Text responses after transition
|
||||
"""
|
||||
# Create custom mock steps with tool call for node transition
|
||||
# For three_node_workflow:
|
||||
# - Edge from node 1 -> node 2 has label "Collect Info" -> function: "collect_info"
|
||||
# - Edge from node 2 -> node 3 has label "End Call" -> function: "end_call"
|
||||
mock_steps = [
|
||||
# Step 1: Initial greeting (text)
|
||||
MockLLMService.create_text_chunks("Hello, how can I help you today?"),
|
||||
# Step 2: Transition to Collect Info node (tool call)
|
||||
MockLLMService.create_function_call_chunks(
|
||||
function_name="collect_info",
|
||||
arguments={},
|
||||
tool_call_id="call_collect_info",
|
||||
),
|
||||
# Step 3: Response after transition (text)
|
||||
MockLLMService.create_text_chunks("Response after transition"),
|
||||
# Step 4+: Additional responses for idle handling
|
||||
MockLLMService.create_text_chunks("Response 2"),
|
||||
MockLLMService.create_text_chunks("Response 3"),
|
||||
]
|
||||
|
||||
llm, context, user_idle_processor = await run_pipeline_with_user_idle(
|
||||
workflow=three_node_workflow,
|
||||
user_idle_timeout=0.2,
|
||||
mock_steps=mock_steps,
|
||||
)
|
||||
|
||||
# Verify the LLM was called multiple times (initial + after transition)
|
||||
assert llm.get_current_step() >= 2, (
|
||||
"LLM should have been called at least twice (initial + after transition)"
|
||||
)
|
||||
|
||||
# This should be the message that we inserted from user_idle_handler
|
||||
assert context.messages[2]["role"] == "system", (
|
||||
"The system message should be in the context"
|
||||
)
|
||||
assert "The user has been quiet." in context.messages[2]["content"]
|
||||
Loading…
Add table
Add a link
Reference in a new issue