Fix realtime initial greeting handling (#481)

This commit is contained in:
Abhishek 2026-06-29 17:25:42 +05:30 committed by GitHub
parent d9800fddd6
commit 090d042a78
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
20 changed files with 714 additions and 70 deletions

View file

@ -0,0 +1,88 @@
from types import SimpleNamespace
from unittest.mock import AsyncMock
import pytest
from pipecat.frames.frames import TTSSpeakFrame
from pipecat.processors.aggregators.llm_context import LLMContext
from pipecat.processors.frame_processor import FrameDirection
from pipecat.services.openai.realtime import events
from api.services.pipecat.realtime.azure_realtime import (
DograhAzureRealtimeLLMService,
)
def _make_service() -> DograhAzureRealtimeLLMService:
    """Build an Azure realtime service with its response hooks stubbed out.

    ``_create_response`` and ``_process_completed_function_calls`` are replaced
    with ``AsyncMock`` instances so tests can assert on whether they were
    awaited.
    """
    constructor_kwargs = {
        "api_key": "test-key",
        "base_url": "wss://example.test/openai/realtime",
    }
    azure_service = DograhAzureRealtimeLLMService(**constructor_kwargs)
    for hook_name in ("_create_response", "_process_completed_function_calls"):
        setattr(azure_service, hook_name, AsyncMock())
    return azure_service
@pytest.mark.asyncio
async def test_tts_greeting_sends_exact_static_greeting_prompt():
    """Check the events emitted for a greeting frame when the session is ready.

    With ``_api_session_ready`` set, processing a ``TTSSpeakFrame`` is expected
    to send the existing context item, then a session update, and finally a
    response-create event whose instructions end with the quoted greeting and
    whose ``tool_choice`` is ``"none"``.
    """
    svc = _make_service()
    svc._context = LLMContext([{"role": "user", "content": "Existing context"}])
    svc._api_session_ready = True
    # Replace outbound hooks so the emitted events can be inspected afterwards.
    for hook in (
        "send_client_event",
        "push_frame",
        "start_processing_metrics",
        "start_ttfb_metrics",
    ):
        setattr(svc, hook, AsyncMock())

    greeting_frame = TTSSpeakFrame(
        "Hi Sam, this is Sarah from Acme.", append_to_context=True
    )
    await svc.process_frame(greeting_frame, FrameDirection.DOWNSTREAM)

    sent_events = [
        awaited.args[0] for awaited in svc.send_client_event.await_args_list
    ]
    context_item_event = sent_events[0]
    assert isinstance(context_item_event, events.ConversationItemCreateEvent)
    assert context_item_event.item.role == "user"
    assert context_item_event.item.content[0].text == "Existing context"
    assert isinstance(sent_events[1], events.SessionUpdateEvent)

    response_event = sent_events[-1]
    assert isinstance(response_event, events.ResponseCreateEvent)
    assert response_event.response.tool_choice == "none"
    instructions = response_event.response.instructions
    assert "The phone call has just connected. Greet the caller now:" in instructions
    assert instructions.endswith('"Hi Sam, this is Sarah from Acme."')

    assert svc._llm_needs_conversation_setup is False
    svc._create_response.assert_not_awaited()
@pytest.mark.asyncio
async def test_tts_greeting_waits_for_session_updated_before_sending_prompt():
    """Check that a greeting is queued, then flushed on ``session.updated``.

    ``_api_session_ready`` is left untouched by this test, and the greeting
    frame is expected to be parked in ``_pending_initial_greeting_text``.
    Calling ``_handle_evt_session_updated`` afterwards should emit the context
    item, a session update and the response-create event, then clear the
    pending state.
    """
    svc = _make_service()
    svc._context = LLMContext([{"role": "user", "content": "Existing context"}])

    greeting_frame = TTSSpeakFrame("Hello from Dograh.", append_to_context=True)
    await svc.process_frame(greeting_frame, FrameDirection.DOWNSTREAM)

    assert svc._handled_initial_context is True
    assert svc._run_llm_when_api_session_ready is True
    assert svc._pending_initial_greeting_text == "Hello from Dograh."

    # Hooks are stubbed only now, so just the post-update traffic is recorded.
    for hook in (
        "send_client_event",
        "push_frame",
        "start_processing_metrics",
        "start_ttfb_metrics",
    ):
        setattr(svc, hook, AsyncMock())

    await svc._handle_evt_session_updated(SimpleNamespace())

    sent_events = [
        awaited.args[0] for awaited in svc.send_client_event.await_args_list
    ]
    context_item_event = sent_events[0]
    assert isinstance(context_item_event, events.ConversationItemCreateEvent)
    assert context_item_event.item.content[0].text == "Existing context"
    assert isinstance(sent_events[1], events.SessionUpdateEvent)

    response_event = sent_events[-1]
    assert isinstance(response_event, events.ResponseCreateEvent)
    assert response_event.response.tool_choice == "none"
    instructions = response_event.response.instructions
    assert instructions.endswith('"Hello from Dograh."')

    assert svc._run_llm_when_api_session_ready is False
    assert svc._pending_initial_greeting_text is None
    assert svc._llm_needs_conversation_setup is False
    svc._create_response.assert_not_awaited()

View file

@ -3,7 +3,7 @@ from types import SimpleNamespace
from unittest.mock import AsyncMock
import pytest
from pipecat.frames.frames import TranscriptionFrame
from pipecat.frames.frames import TranscriptionFrame, TTSSpeakFrame
from pipecat.processors.aggregators.llm_context import LLMContext
from pipecat.processors.frame_processor import FrameDirection
@ -21,6 +21,7 @@ class _TestDograhGeminiLiveLLMService(DograhGeminiLiveLLMService):
class _FakeSession:
def __init__(self):
self.send_client_content = AsyncMock()
self.send_tool_response = AsyncMock()
self.send_realtime_input = AsyncMock()
self.close = AsyncMock()
@ -108,3 +109,57 @@ async def test_user_transcription_matches_upstream_upstream_push_behavior():
assert frame.text == "Hi there"
assert frame.finalized is False
assert direction == FrameDirection.UPSTREAM
@pytest.mark.asyncio
async def test_tts_greeting_sends_exact_static_greeting_prompt_to_gemini():
    """Check the single user turn sent to the session for a greeting frame.

    With a fake session already attached, processing a ``TTSSpeakFrame`` is
    expected to call ``send_client_content`` exactly once with
    ``turn_complete=True`` and one user turn whose text carries the greeting
    instruction followed by the quoted greeting.
    """
    svc = _make_service()
    svc._context = LLMContext()
    svc._session = _FakeSession()

    greeting_frame = TTSSpeakFrame(
        "Hi Sam, this is Sarah from Acme.", append_to_context=True
    )
    await svc.process_frame(greeting_frame, FrameDirection.DOWNSTREAM)

    send_content = svc._session.send_client_content
    send_content.assert_awaited_once()
    call_kwargs = send_content.await_args.kwargs
    assert call_kwargs["turn_complete"] is True

    turns = call_kwargs["turns"]
    assert len(turns) == 1
    only_turn = turns[0]
    assert only_turn.role == "user"

    prompt = only_turn.parts[0].text
    assert "The phone call has just connected. Greet the caller now:" in prompt
    expected_tail = (
        'Do not add anything before or after it.\n\n"Hi Sam, this is Sarah from Acme."'
    )
    assert expected_tail in prompt

    assert svc._handled_initial_context is True
    assert svc._pending_initial_greeting_text is None
    assert svc._ready_for_realtime_input is True
@pytest.mark.asyncio
async def test_tts_greeting_waits_for_gemini_session_before_sending_prompt():
    """Check that a greeting is queued, then sent once a session is ready.

    This test does not assign ``_session`` before processing the frame, and
    the greeting is expected to be held in ``_pending_initial_greeting_text``.
    Calling ``_handle_session_ready`` with a fake session should then send the
    prompt once and clear the pending state.
    """
    svc = _make_service()
    svc._context = LLMContext()

    greeting_frame = TTSSpeakFrame("Hello from Dograh.", append_to_context=True)
    await svc.process_frame(greeting_frame, FrameDirection.DOWNSTREAM)

    assert svc._handled_initial_context is True
    assert svc._run_llm_when_session_ready is True
    assert svc._pending_initial_greeting_text == "Hello from Dograh."

    fake_session = _FakeSession()
    await svc._handle_session_ready(fake_session)

    send_content = fake_session.send_client_content
    send_content.assert_awaited_once()
    first_turn = send_content.await_args.kwargs["turns"][0]
    prompt = first_turn.parts[0].text
    assert prompt.endswith('"Hello from Dograh."')

    assert svc._run_llm_when_session_ready is False
    assert svc._pending_initial_greeting_text is None

View file

@ -37,17 +37,71 @@ async def test_initial_context_triggers_response_when_context_was_prepopulated()
@pytest.mark.asyncio
async def test_tts_greeting_sends_exact_static_greeting_prompt():
    """Check the events emitted for a greeting frame when the session is ready.

    With ``_api_session_ready`` set, processing a ``TTSSpeakFrame`` is expected
    to send the existing context item, a session update, a user message item
    carrying the greeting instruction, and finally a response-create event.

    The superseded variant of this test (old name, ``_handle_context`` mock,
    ``"hello"`` frame) is intentionally not kept: mocking ``_handle_context``
    would prevent the events asserted below from ever being sent.
    """
    service = _make_service()
    service._context = LLMContext([{"role": "user", "content": "Existing context"}])
    service._api_session_ready = True
    # Stub outbound hooks so the emitted events can be inspected afterwards.
    service.send_client_event = AsyncMock()
    service.push_frame = AsyncMock()
    service.start_processing_metrics = AsyncMock()
    service.start_ttfb_metrics = AsyncMock()

    await service.process_frame(
        TTSSpeakFrame("Hi Sam, this is Sarah from Acme.", append_to_context=True),
        FrameDirection.DOWNSTREAM,
    )

    sent_events = [call.args[0] for call in service.send_client_event.await_args_list]
    assert isinstance(sent_events[0], events.ConversationItemCreateEvent)
    assert sent_events[0].item.role == "user"
    assert sent_events[0].item.content[0].text == "Existing context"
    assert isinstance(sent_events[1], events.SessionUpdateEvent)

    # The greeting travels as its own user message item in this variant.
    greeting_event = sent_events[2]
    assert isinstance(greeting_event, events.ConversationItemCreateEvent)
    assert greeting_event.item.role == "user"
    assert greeting_event.item.type == "message"
    prompt = greeting_event.item.content[0].text
    assert "The phone call has just connected. Greet the caller now:" in prompt
    assert prompt.endswith('"Hi Sam, this is Sarah from Acme."')

    assert isinstance(sent_events[-1], events.ResponseCreateEvent)
    assert service._llm_needs_conversation_setup is False
    service._create_response.assert_not_awaited()
@pytest.mark.asyncio
async def test_tts_greeting_waits_for_session_updated_before_sending_prompt():
    """Check that a greeting is queued, then flushed on ``session.updated``.

    ``_api_session_ready`` is left untouched by this test, and the greeting is
    expected to be parked in ``_pending_initial_greeting_text``. Calling
    ``_handle_evt_session_updated`` afterwards should emit the context item, a
    session update, the greeting message item and a response-create event,
    then clear the pending state.
    """
    svc = _make_service()
    svc._context = LLMContext([{"role": "user", "content": "Existing context"}])

    greeting_frame = TTSSpeakFrame("Hello from Dograh.", append_to_context=True)
    await svc.process_frame(greeting_frame, FrameDirection.DOWNSTREAM)

    assert svc._handled_initial_context is True
    assert svc._run_llm_when_api_session_ready is True
    assert svc._pending_initial_greeting_text == "Hello from Dograh."

    # Hooks are stubbed only now, so just the post-update traffic is recorded.
    for hook in (
        "send_client_event",
        "push_frame",
        "start_processing_metrics",
        "start_ttfb_metrics",
    ):
        setattr(svc, hook, AsyncMock())

    await svc._handle_evt_session_updated(SimpleNamespace())

    sent_events = [
        awaited.args[0] for awaited in svc.send_client_event.await_args_list
    ]
    context_item_event = sent_events[0]
    assert isinstance(context_item_event, events.ConversationItemCreateEvent)
    assert context_item_event.item.content[0].text == "Existing context"
    assert isinstance(sent_events[1], events.SessionUpdateEvent)

    greeting_event = sent_events[2]
    assert isinstance(greeting_event, events.ConversationItemCreateEvent)
    greeting_prompt = greeting_event.item.content[0].text
    assert greeting_prompt.endswith('"Hello from Dograh."')
    assert isinstance(sent_events[-1], events.ResponseCreateEvent)

    assert svc._run_llm_when_api_session_ready is False
    assert svc._pending_initial_greeting_text is None
    assert svc._llm_needs_conversation_setup is False
    svc._create_response.assert_not_awaited()
@pytest.mark.asyncio

View file

@ -5,6 +5,7 @@ import pytest
from pipecat.frames.frames import TTSSpeakFrame
from pipecat.processors.aggregators.llm_context import LLMContext
from pipecat.processors.frame_processor import FrameDirection
from pipecat.services.openai.realtime import events
from api.services.pipecat.realtime.openai_realtime import (
DograhOpenAIRealtimeLLMService,
@ -48,17 +49,69 @@ async def test_updated_context_uses_tool_result_path_after_initial_context():
@pytest.mark.asyncio
async def test_tts_greeting_sends_exact_static_greeting_prompt():
    """Check the events emitted for a greeting frame when the session is ready.

    With an empty context and ``_api_session_ready`` set, processing a
    ``TTSSpeakFrame`` is expected to send a session update first and a
    response-create event last, with no conversation item in between. The
    greeting is carried in the response instructions and ``tool_choice`` is
    ``"none"``.

    The superseded variant of this test (old name, ``_handle_context`` mock,
    ``"hello"`` frame) is intentionally not kept: mocking ``_handle_context``
    would prevent the events asserted below from ever being sent.
    """
    service = _make_service()
    service._context = LLMContext()
    service._api_session_ready = True
    # Stub outbound hooks so the emitted events can be inspected afterwards.
    service.send_client_event = AsyncMock()
    service.push_frame = AsyncMock()
    service.start_processing_metrics = AsyncMock()
    service.start_ttfb_metrics = AsyncMock()

    await service.process_frame(
        TTSSpeakFrame("Hi Sam, this is Sarah from Acme.", append_to_context=True),
        FrameDirection.DOWNSTREAM,
    )

    sent_events = [call.args[0] for call in service.send_client_event.await_args_list]
    assert not any(
        isinstance(event, events.ConversationItemCreateEvent) for event in sent_events
    )
    assert isinstance(sent_events[0], events.SessionUpdateEvent)

    response_event = sent_events[-1]
    assert isinstance(response_event, events.ResponseCreateEvent)
    assert response_event.response.tool_choice == "none"
    prompt = response_event.response.instructions
    assert "The phone call has just connected. Greet the caller now:" in prompt
    assert prompt.endswith('"Hi Sam, this is Sarah from Acme."')

    assert service._llm_needs_conversation_setup is False
    service._create_response.assert_not_awaited()
@pytest.mark.asyncio
async def test_tts_greeting_waits_for_session_updated_before_sending_prompt():
    """Check that a greeting is queued, then flushed on ``session.updated``.

    ``_api_session_ready`` is left untouched by this test, and the greeting is
    expected to be parked in ``_pending_initial_greeting_text``. Calling
    ``_handle_evt_session_updated`` afterwards should emit a session update
    and a response-create event (and no conversation item), then clear the
    pending state.
    """
    svc = _make_service()
    svc._context = LLMContext()

    greeting_frame = TTSSpeakFrame("Hello from Dograh.", append_to_context=True)
    await svc.process_frame(greeting_frame, FrameDirection.DOWNSTREAM)

    assert svc._handled_initial_context is True
    assert svc._run_llm_when_api_session_ready is True
    assert svc._pending_initial_greeting_text == "Hello from Dograh."

    # Hooks are stubbed only now, so just the post-update traffic is recorded.
    for hook in (
        "send_client_event",
        "push_frame",
        "start_processing_metrics",
        "start_ttfb_metrics",
    ):
        setattr(svc, hook, AsyncMock())

    await svc._handle_evt_session_updated(SimpleNamespace())

    sent_events = [
        awaited.args[0] for awaited in svc.send_client_event.await_args_list
    ]
    item_events_present = any(
        isinstance(sent, events.ConversationItemCreateEvent) for sent in sent_events
    )
    assert not item_events_present
    assert isinstance(sent_events[0], events.SessionUpdateEvent)

    response_event = sent_events[-1]
    assert isinstance(response_event, events.ResponseCreateEvent)
    assert response_event.response.tool_choice == "none"
    instructions = response_event.response.instructions
    assert instructions.endswith('"Hello from Dograh."')

    assert svc._run_llm_when_api_session_ready is False
    assert svc._pending_initial_greeting_text is None
    assert svc._llm_needs_conversation_setup is False
    svc._create_response.assert_not_awaited()

View file

@ -7,7 +7,23 @@ from pipecat.processors.frame_processor import FrameDirection
from pipecat.transports.base_output import BaseOutputTransport
from pipecat.transports.base_transport import TransportParams
from api.services.pipecat.realtime_feedback_observer import RealtimeFeedbackObserver
from api.services.pipecat.in_memory_buffers import InMemoryLogsBuffer
from api.services.pipecat.realtime_feedback_observer import (
RealtimeFeedbackObserver,
register_turn_log_handlers,
)
class _FakeAggregator:
def __init__(self):
self.handlers = {}
def event_handler(self, event_name):
def decorator(handler):
self.handlers[event_name] = handler
return handler
return decorator
def _frame_pushed(frame, direction, *, source=None):
@ -98,3 +114,33 @@ async def test_observer_waits_for_tts_text_from_output_transport():
"payload": {"text": "Hello"},
}
]
@pytest.mark.asyncio
async def test_turn_log_handlers_persist_user_message_added_events():
    """Check that a user-message-added event lands in the logs buffer.

    After ``register_turn_log_handlers`` runs, the user aggregator should have
    an ``on_user_turn_message_added`` handler and no ``on_user_turn_stopped``
    handler. Invoking the handler is expected to append one
    ``rtf-user-transcription`` event for turn 1.
    """
    buffer = InMemoryLogsBuffer(workflow_run_id=123)
    user_side = _FakeAggregator()
    assistant_side = _FakeAggregator()

    register_turn_log_handlers(buffer, user_side, assistant_side)

    assert "on_user_turn_message_added" in user_side.handlers
    assert "on_user_turn_stopped" not in user_side.handlers

    on_message_added = user_side.handlers["on_user_turn_message_added"]
    user_message = SimpleNamespace(
        content="Hi there",
        timestamp="2026-01-01T00:00:00+00:00",
    )
    await on_message_added(user_side, user_message)

    recorded = buffer.get_events()
    assert len(recorded) == 1
    logged = recorded[0]
    assert logged["type"] == "rtf-user-transcription"
    assert logged["payload"] == {
        "text": "Hi there",
        "final": True,
        "timestamp": "2026-01-01T00:00:00+00:00",
    }
    assert logged["turn"] == 1

View file

@ -30,6 +30,7 @@ def test_gemini_realtime_uses_local_vad_without_local_interruptions():
assert strategies.start[0]._enable_interruptions is False
assert len(strategies.stop) == 1
assert isinstance(strategies.stop[0], SpeechTimeoutUserTurnStopStrategy)
assert strategies.stop[0].wait_for_transcript is False
def test_gemini_vertex_realtime_uses_same_turn_config_as_gemini_live():
@ -41,6 +42,9 @@ def test_gemini_vertex_realtime_uses_same_turn_config_as_gemini_live():
assert len(strategies.start) == 1
assert isinstance(strategies.start[0], VADUserTurnStartStrategy)
assert strategies.start[0]._enable_interruptions is False
assert len(strategies.stop) == 1
assert isinstance(strategies.stop[0], SpeechTimeoutUserTurnStopStrategy)
assert strategies.stop[0].wait_for_transcript is False
def test_openai_realtime_uses_provider_turn_frames_without_local_vad():
@ -54,6 +58,21 @@ def test_openai_realtime_uses_provider_turn_frames_without_local_vad():
assert strategies.start[0]._enable_interruptions is False
assert len(strategies.stop) == 1
assert isinstance(strategies.stop[0], ExternalUserTurnStopStrategy)
assert strategies.stop[0].wait_for_transcript is False
def test_azure_realtime_uses_provider_turn_frames_without_local_vad():
    """Check the turn config built for the Azure realtime provider.

    Expects no VAD analyzer, one external start strategy with interruptions
    disabled, and one external stop strategy that does not wait for a
    transcript.
    """
    provider = ServiceProviders.AZURE_REALTIME.value
    turn_strategies, vad = _create_realtime_user_turn_config(provider)

    assert vad is None

    start_strategies = turn_strategies.start
    assert len(start_strategies) == 1
    assert isinstance(start_strategies[0], ExternalUserTurnStartStrategy)
    assert start_strategies[0]._enable_interruptions is False

    stop_strategies = turn_strategies.stop
    assert len(stop_strategies) == 1
    assert isinstance(stop_strategies[0], ExternalUserTurnStopStrategy)
    assert stop_strategies[0].wait_for_transcript is False
def test_grok_realtime_uses_provider_turn_frames_without_local_vad():
@ -67,6 +86,21 @@ def test_grok_realtime_uses_provider_turn_frames_without_local_vad():
assert strategies.start[0]._enable_interruptions is False
assert len(strategies.stop) == 1
assert isinstance(strategies.stop[0], ExternalUserTurnStopStrategy)
assert strategies.stop[0].wait_for_transcript is False
def test_ultravox_realtime_uses_local_vad_with_local_interruptions():
    """Check the turn config built for the Ultravox realtime provider.

    Expects a Silero VAD analyzer, one VAD start strategy with interruptions
    enabled, and one speech-timeout stop strategy that does not wait for a
    transcript.
    """
    provider = ServiceProviders.ULTRAVOX_REALTIME.value
    turn_strategies, vad = _create_realtime_user_turn_config(provider)

    assert isinstance(vad, SileroVADAnalyzer)

    start_strategies = turn_strategies.start
    assert len(start_strategies) == 1
    assert isinstance(start_strategies[0], VADUserTurnStartStrategy)
    assert start_strategies[0]._enable_interruptions is True

    stop_strategies = turn_strategies.stop
    assert len(stop_strategies) == 1
    assert isinstance(stop_strategies[0], SpeechTimeoutUserTurnStopStrategy)
    assert stop_strategies[0].wait_for_transcript is False
def test_unknown_realtime_providers_keep_local_vad():
@ -75,8 +109,10 @@ def test_unknown_realtime_providers_keep_local_vad():
assert isinstance(vad_analyzer, SileroVADAnalyzer)
assert len(strategies.start) == 1
assert isinstance(strategies.start[0], VADUserTurnStartStrategy)
assert strategies.start[0]._enable_interruptions is True
assert len(strategies.stop) == 1
assert isinstance(strategies.stop[0], SpeechTimeoutUserTurnStopStrategy)
assert strategies.stop[0].wait_for_transcript is False
def test_external_turn_stt_uses_longer_stop_timeout():

View file

@ -0,0 +1,38 @@
from api.utils.template_renderer import render_template
def test_initial_context_prefix_resolves_against_flat_context():
    """Check ``initial_context.*`` lookups against a context with no such key.

    Both a top-level value and a nested value are expected to resolve from the
    flat context when referenced through the ``initial_context.`` prefix.
    """
    flat_context = {
        "first_name": "Abhishek",
        "runtime_configuration": {
            "realtime_model": "gpt-realtime-2",
        },
    }

    greeting = render_template(
        "Hi {{initial_context.first_name | there}}", flat_context
    )
    assert greeting == "Hi Abhishek"

    model_line = render_template(
        "Model {{initial_context.runtime_configuration.realtime_model}}",
        flat_context,
    )
    assert model_line == "Model gpt-realtime-2"
def test_initial_context_prefix_prefers_explicit_initial_context():
    """Check that a real ``initial_context`` entry wins over a flat key.

    When the context has both a top-level ``first_name`` and an
    ``initial_context.first_name``, the nested value is expected.
    """
    layered_context = {
        "first_name": "Flat",
        "initial_context": {
            "first_name": "Nested",
        },
    }

    rendered = render_template("Hi {{initial_context.first_name}}", layered_context)
    assert rendered == "Hi Nested"
def test_initial_context_prefix_uses_fallback_when_missing_from_both_contexts():
    """Check that the ``| there`` fallback is used with an empty context."""
    empty_context = {}
    rendered = render_template(
        "Hi {{initial_context.first_name | there}}", empty_context
    )
    assert rendered == "Hi there"