mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-06-10 08:05:22 +02:00
fix: fix rtf logs and gemini live turn taking
This commit is contained in:
parent
25751efe3c
commit
0c0b8383bf
6 changed files with 159 additions and 148 deletions
|
|
@ -3,7 +3,9 @@ from types import SimpleNamespace
|
|||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
from pipecat.frames.frames import TranscriptionFrame
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
|
||||
from api.services.pipecat.realtime.gemini_live import DograhGeminiLiveLLMService
|
||||
|
||||
|
|
@ -84,3 +86,25 @@ async def test_disconnect_does_not_forget_previously_delivered_tool_results():
|
|||
|
||||
service._tool_result.assert_not_awaited()
|
||||
assert service._completed_tool_calls == {"call-transition"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_user_transcription_matches_upstream_upstream_push_behavior():
    """Verify what ``_push_user_transcription`` calls for a user utterance.

    The transcription handler, ``push_frame`` and ``broadcast_frame`` are
    swapped for ``AsyncMock`` objects so that only the calls made by the
    method under test are recorded. The test expects exactly one
    ``push_frame`` call carrying a non-finalized ``TranscriptionFrame`` in
    the upstream direction, and no ``broadcast_frame`` call at all.
    """
    spoken_text = "Hi there"
    service = _make_service()

    # Replace every collaborator the method may touch with a fresh mock.
    for attr_name in ("_handle_user_transcription", "push_frame", "broadcast_frame"):
        setattr(service, attr_name, AsyncMock())

    await service._push_user_transcription(spoken_text)

    expected_language = service._settings.language
    service._handle_user_transcription.assert_awaited_once_with(
        spoken_text, True, expected_language
    )
    service.broadcast_frame.assert_not_awaited()
    service.push_frame.assert_awaited_once()

    # Positional arguments of the single push_frame await: (frame, direction).
    pushed_frame, pushed_direction = service.push_frame.await_args[0]
    assert isinstance(pushed_frame, TranscriptionFrame)
    assert pushed_frame.text == "Hi there"
    assert pushed_frame.finalized is False
    assert pushed_direction == FrameDirection.UPSTREAM
|
||||
|
|
|
|||
100
api/tests/test_realtime_feedback_observer.py
Normal file
100
api/tests/test_realtime_feedback_observer.py
Normal file
|
|
@ -0,0 +1,100 @@
|
|||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
from pipecat.frames.frames import TranscriptionFrame, TTSTextFrame
|
||||
from pipecat.observers.base_observer import FramePushed
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.transports.base_output import BaseOutputTransport
|
||||
from pipecat.transports.base_transport import TransportParams
|
||||
|
||||
from api.services.pipecat.realtime_feedback_observer import RealtimeFeedbackObserver
|
||||
|
||||
|
||||
def _frame_pushed(frame, direction, *, source=None):
    """Build a ``FramePushed`` event for feeding into an observer.

    Args:
        frame: The frame to wrap in the event.
        direction: The direction recorded on the event.
        source: Optional object to use as the event source. When omitted
            (or falsy) an empty ``SimpleNamespace`` stands in for it.

    Returns:
        A ``FramePushed`` whose destination is an empty ``SimpleNamespace``
        and whose timestamp is ``0``.
    """
    origin = source or SimpleNamespace()
    sink = SimpleNamespace()
    event_fields = {
        "source": origin,
        "destination": sink,
        "frame": frame,
        "direction": direction,
        "timestamp": 0,
    }
    return FramePushed(**event_fields)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_observer_streams_upstream_only_transcription_frames():
    """An upstream ``TranscriptionFrame`` yields one user-transcription message.

    The observer is given a recording sender; after a single upstream push
    of a transcription frame, exactly one ``rtf-user-transcription`` message
    with the frame's text, timestamp and user id is expected.
    """
    received = []

    async def record(message):
        received.append(message)

    stamp = "2026-01-01T00:00:00+00:00"
    transcription = TranscriptionFrame("Hi there", user_id="user-1", timestamp=stamp)
    observer = RealtimeFeedbackObserver(ws_sender=record)

    event = _frame_pushed(transcription, FrameDirection.UPSTREAM)
    await observer.on_push_frame(event)

    expected_payload = {
        "text": "Hi there",
        "final": True,
        "timestamp": "2026-01-01T00:00:00+00:00",
        "user_id": "user-1",
    }
    expected_message = {
        "type": "rtf-user-transcription",
        "payload": expected_payload,
    }
    assert received == [expected_message]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_observer_ignores_upstream_broadcast_transcription_sibling():
    """An upstream transcription with a broadcast sibling id sends nothing.

    The frame is identical to a regular transcription frame except that
    ``broadcast_sibling_id`` is set on it before the push; the recording
    sender is expected to stay empty.
    """
    received = []

    async def record(message):
        received.append(message)

    transcription = TranscriptionFrame(
        "Hi there",
        timestamp="2026-01-01T00:00:00+00:00",
        user_id="user-1",
    )
    # Mark the frame as one half of a broadcast pair.
    transcription.broadcast_sibling_id = 1234

    observer = RealtimeFeedbackObserver(ws_sender=record)
    event = _frame_pushed(transcription, FrameDirection.UPSTREAM)
    await observer.on_push_frame(event)

    assert received == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_observer_waits_for_tts_text_from_output_transport():
    """Bot text is only reported when the frame comes from an output transport.

    The same downstream ``TTSTextFrame`` is pushed twice: first from a
    placeholder source, which must produce no message, then from a
    ``BaseOutputTransport`` instance, which must produce a single
    ``rtf-bot-text`` message carrying the frame's text.
    """
    received = []

    async def record(message):
        received.append(message)

    observer = RealtimeFeedbackObserver(ws_sender=record)

    tts_frame = TTSTextFrame("Hello", aggregated_by="word")
    tts_frame.pts = 123

    # First push: the source is the helper's default placeholder object.
    placeholder_event = _frame_pushed(tts_frame, FrameDirection.DOWNSTREAM)
    await observer.on_push_frame(placeholder_event)
    assert received == []

    # Second push: same frame, now originating from an output transport.
    output_transport = BaseOutputTransport(TransportParams())
    transport_event = _frame_pushed(
        tts_frame, FrameDirection.DOWNSTREAM, source=output_transport
    )
    await observer.on_push_frame(transport_event)

    expected_message = {"type": "rtf-bot-text", "payload": {"text": "Hello"}}
    assert received == [expected_message]
|
||||
Loading…
Add table
Add a link
Reference in a new issue