mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-06-19 08:28:10 +02:00
feat: add openai realtime models
This commit is contained in:
parent
53f1959edf
commit
4d7b681928
33 changed files with 1518 additions and 75 deletions
|
|
@ -140,6 +140,45 @@ class TestToolToFunctionSchema:
|
|||
assert "duration_minutes" in required
|
||||
assert "is_priority" not in required
|
||||
|
||||
def test_preset_parameters_are_not_exposed_to_llm_schema(self):
|
||||
"""Test that preset parameters are injected at runtime, not shown to the LLM."""
|
||||
tool = MockToolModel(
|
||||
tool_uuid="test-uuid-preset",
|
||||
name="Lookup Customer",
|
||||
description="Lookup a customer using contextual identifiers",
|
||||
category="http_api",
|
||||
definition={
|
||||
"schema_version": 1,
|
||||
"type": "http_api",
|
||||
"config": {
|
||||
"method": "POST",
|
||||
"url": "https://api.example.com/customers/lookup",
|
||||
"parameters": [
|
||||
{
|
||||
"name": "customer_name",
|
||||
"type": "string",
|
||||
"description": "Customer name spoken by the caller",
|
||||
"required": True,
|
||||
}
|
||||
],
|
||||
"preset_parameters": [
|
||||
{
|
||||
"name": "phone_number",
|
||||
"type": "string",
|
||||
"value_template": "{{initial_context.phone_number}}",
|
||||
"required": True,
|
||||
}
|
||||
],
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
schema = tool_to_function_schema(tool)
|
||||
props = schema["function"]["parameters"]["properties"]
|
||||
|
||||
assert "customer_name" in props
|
||||
assert "phone_number" not in props
|
||||
|
||||
def test_tool_name_sanitization(self):
|
||||
"""Test that tool names with special characters are sanitized."""
|
||||
tool = MockToolModel(
|
||||
|
|
@ -255,6 +294,108 @@ class TestExecuteHttpTool:
|
|||
assert result["status_code"] == 201
|
||||
assert result["data"]["id"] == 123
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_post_request_injects_preset_parameters(self):
|
||||
"""Test that preset parameters are resolved from runtime context."""
|
||||
tool = MockToolModel(
|
||||
tool_uuid="test-uuid-preset",
|
||||
name="Create Lead",
|
||||
description="Create a lead with caller context",
|
||||
category="http_api",
|
||||
definition={
|
||||
"schema_version": 1,
|
||||
"type": "http_api",
|
||||
"config": {
|
||||
"method": "POST",
|
||||
"url": "https://api.example.com/leads",
|
||||
"timeout_ms": 5000,
|
||||
"preset_parameters": [
|
||||
{
|
||||
"name": "phone_number",
|
||||
"type": "string",
|
||||
"value_template": "{{initial_context.phone_number}}",
|
||||
"required": True,
|
||||
},
|
||||
{
|
||||
"name": "customer_id",
|
||||
"type": "number",
|
||||
"value_template": "{{gathered_context.customer_id}}",
|
||||
"required": True,
|
||||
},
|
||||
{
|
||||
"name": "is_vip",
|
||||
"type": "boolean",
|
||||
"value_template": "{{initial_context.is_vip}}",
|
||||
"required": False,
|
||||
},
|
||||
],
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
arguments = {"name": "John"}
|
||||
|
||||
with patch(
|
||||
"api.services.workflow.tools.custom_tool.httpx.AsyncClient"
|
||||
) as mock_client_class:
|
||||
mock_client = AsyncMock()
|
||||
mock_response = Mock()
|
||||
mock_response.status_code = 201
|
||||
mock_response.json.return_value = {"id": 123}
|
||||
mock_client.request.return_value = mock_response
|
||||
mock_client_class.return_value.__aenter__.return_value = mock_client
|
||||
|
||||
result = await execute_http_tool(
|
||||
tool,
|
||||
arguments,
|
||||
call_context_vars={
|
||||
"phone_number": "+14155550123",
|
||||
"is_vip": "true",
|
||||
},
|
||||
gathered_context_vars={"customer_id": "42"},
|
||||
)
|
||||
|
||||
call_kwargs = mock_client.request.call_args.kwargs
|
||||
assert call_kwargs["json"] == {
|
||||
"name": "John",
|
||||
"phone_number": "+14155550123",
|
||||
"customer_id": 42,
|
||||
"is_vip": True,
|
||||
}
|
||||
assert result["status"] == "success"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_missing_required_preset_parameter_returns_error(self):
|
||||
"""Test that required preset parameters fail before the HTTP request."""
|
||||
tool = MockToolModel(
|
||||
tool_uuid="test-uuid-preset-error",
|
||||
name="Create Lead",
|
||||
description="Create a lead with caller context",
|
||||
category="http_api",
|
||||
definition={
|
||||
"schema_version": 1,
|
||||
"type": "http_api",
|
||||
"config": {
|
||||
"method": "POST",
|
||||
"url": "https://api.example.com/leads",
|
||||
"timeout_ms": 5000,
|
||||
"preset_parameters": [
|
||||
{
|
||||
"name": "phone_number",
|
||||
"type": "string",
|
||||
"value_template": "{{initial_context.phone_number}}",
|
||||
"required": True,
|
||||
}
|
||||
],
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
result = await execute_http_tool(tool, {"name": "John"}, call_context_vars={})
|
||||
|
||||
assert result["status"] == "error"
|
||||
assert "phone_number" in result["error"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_request_sends_query_params(self):
|
||||
"""Test that GET requests send arguments as query parameters."""
|
||||
|
|
|
|||
86
api/tests/test_gemini_live_reconnect_tool_results.py
Normal file
86
api/tests/test_gemini_live_reconnect_tool_results.py
Normal file
|
|
@ -0,0 +1,86 @@
|
|||
import json
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
|
||||
from api.services.pipecat.realtime.gemini_live import DograhGeminiLiveLLMService
|
||||
|
||||
|
||||
class _TestDograhGeminiLiveLLMService(DograhGeminiLiveLLMService):
|
||||
"""Dograh Gemini service with client creation stubbed for unit tests."""
|
||||
|
||||
def create_client(self):
|
||||
self._client = SimpleNamespace(
|
||||
aio=SimpleNamespace(live=SimpleNamespace(connect=None))
|
||||
)
|
||||
|
||||
|
||||
class _FakeSession:
|
||||
def __init__(self):
|
||||
self.send_tool_response = AsyncMock()
|
||||
self.send_realtime_input = AsyncMock()
|
||||
self.close = AsyncMock()
|
||||
|
||||
|
||||
def _make_service() -> _TestDograhGeminiLiveLLMService:
|
||||
service = _TestDograhGeminiLiveLLMService(api_key="test-key")
|
||||
service.stop_all_metrics = AsyncMock()
|
||||
service.start_ttfb_metrics = AsyncMock()
|
||||
service.cancel_task = AsyncMock()
|
||||
service.push_error = AsyncMock()
|
||||
return service
|
||||
|
||||
|
||||
def _make_tool_result_context(tool_call_id: str) -> LLMContext:
|
||||
return LLMContext(
|
||||
messages=[
|
||||
{
|
||||
"role": "tool",
|
||||
"content": json.dumps({"status": "done"}),
|
||||
"tool_call_id": tool_call_id,
|
||||
}
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_updated_context_during_reconnect_keeps_result_pending_until_session_ready():
|
||||
service = _make_service()
|
||||
service._handled_initial_context = True
|
||||
service._tool_call_id_to_name = {"call-transition": "transition_to_next_node"}
|
||||
service._session = _FakeSession()
|
||||
|
||||
context = _make_tool_result_context("call-transition")
|
||||
|
||||
await service._disconnect()
|
||||
await service._handle_context(context)
|
||||
|
||||
# A reconnect gap should not count as successful delivery to Gemini.
|
||||
assert "call-transition" not in service._completed_tool_calls
|
||||
|
||||
session = _FakeSession()
|
||||
await service._handle_session_ready(session)
|
||||
|
||||
session.send_tool_response.assert_awaited_once()
|
||||
sent_response = session.send_tool_response.await_args.kwargs["function_responses"]
|
||||
assert sent_response.id == "call-transition"
|
||||
assert sent_response.name == "transition_to_next_node"
|
||||
assert "call-transition" in service._completed_tool_calls
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_disconnect_does_not_forget_previously_delivered_tool_results():
|
||||
service = _make_service()
|
||||
service._context = _make_tool_result_context("call-transition")
|
||||
service._completed_tool_calls = {"call-transition"}
|
||||
service._tool_call_id_to_name = {"call-transition": "transition_to_next_node"}
|
||||
service._session = _FakeSession()
|
||||
service._tool_result = AsyncMock()
|
||||
|
||||
await service._disconnect()
|
||||
await service._process_completed_function_calls(send_new_results=True)
|
||||
|
||||
service._tool_result.assert_not_awaited()
|
||||
assert service._completed_tool_calls == {"call-transition"}
|
||||
98
api/tests/test_openai_realtime_initial_context.py
Normal file
98
api/tests/test_openai_realtime_initial_context.py
Normal file
|
|
@ -0,0 +1,98 @@
|
|||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
from pipecat.frames.frames import TTSSpeakFrame
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
|
||||
from api.services.pipecat.realtime.openai_realtime import (
|
||||
DograhOpenAIRealtimeLLMService,
|
||||
)
|
||||
|
||||
|
||||
def _make_service() -> DograhOpenAIRealtimeLLMService:
|
||||
service = DograhOpenAIRealtimeLLMService(api_key="test-key")
|
||||
service._create_response = AsyncMock()
|
||||
service._process_completed_function_calls = AsyncMock()
|
||||
return service
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_initial_context_triggers_response_when_context_was_prepopulated():
|
||||
service = _make_service()
|
||||
context = LLMContext()
|
||||
service._context = context
|
||||
|
||||
await service._handle_context(context)
|
||||
|
||||
assert service._handled_initial_context is True
|
||||
assert service._context is context
|
||||
service._create_response.assert_awaited_once()
|
||||
service._process_completed_function_calls.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_updated_context_uses_tool_result_path_after_initial_context():
|
||||
service = _make_service()
|
||||
context = LLMContext()
|
||||
service._handled_initial_context = True
|
||||
|
||||
await service._handle_context(context)
|
||||
|
||||
assert service._context is context
|
||||
service._create_response.assert_not_awaited()
|
||||
service._process_completed_function_calls.assert_awaited_once_with(
|
||||
send_new_results=True
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tts_greeting_uses_initial_context_handler():
|
||||
service = _make_service()
|
||||
service._context = LLMContext()
|
||||
service._handle_context = AsyncMock()
|
||||
|
||||
await service.process_frame(
|
||||
TTSSpeakFrame("hello", append_to_context=True),
|
||||
FrameDirection.DOWNSTREAM,
|
||||
)
|
||||
|
||||
service._handle_context.assert_awaited_once_with(service._context)
|
||||
service._create_response.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_function_call_executes_immediately_when_bot_is_not_speaking():
|
||||
service = _make_service()
|
||||
service._context = LLMContext()
|
||||
service.run_function_calls = AsyncMock()
|
||||
service._pending_function_calls["call-1"] = SimpleNamespace(name="customer_support")
|
||||
|
||||
await service._handle_evt_function_call_arguments_done(
|
||||
SimpleNamespace(call_id="call-1", arguments='{"department":"sales"}')
|
||||
)
|
||||
|
||||
service.run_function_calls.assert_awaited_once()
|
||||
assert service._deferred_function_calls == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_function_call_is_deferred_until_bot_stops_speaking():
|
||||
service = _make_service()
|
||||
service._context = LLMContext()
|
||||
service.run_function_calls = AsyncMock()
|
||||
service._bot_is_speaking = True
|
||||
service._pending_function_calls["call-1"] = SimpleNamespace(name="customer_support")
|
||||
|
||||
await service._handle_evt_function_call_arguments_done(
|
||||
SimpleNamespace(call_id="call-1", arguments='{"department":"sales"}')
|
||||
)
|
||||
|
||||
service.run_function_calls.assert_not_awaited()
|
||||
assert len(service._deferred_function_calls) == 1
|
||||
|
||||
await service._run_pending_function_calls()
|
||||
|
||||
service.run_function_calls.assert_awaited_once()
|
||||
assert service._deferred_function_calls == []
|
||||
61
api/tests/test_run_pipeline_realtime_turn_config.py
Normal file
61
api/tests/test_run_pipeline_realtime_turn_config.py
Normal file
|
|
@ -0,0 +1,61 @@
|
|||
from pipecat.audio.vad.silero import SileroVADAnalyzer
|
||||
from pipecat.turns.user_start import (
|
||||
ExternalUserTurnStartStrategy,
|
||||
)
|
||||
from pipecat.turns.user_start.vad_user_turn_start_strategy import (
|
||||
VADUserTurnStartStrategy,
|
||||
)
|
||||
from pipecat.turns.user_stop import (
|
||||
ExternalUserTurnStopStrategy,
|
||||
SpeechTimeoutUserTurnStopStrategy,
|
||||
)
|
||||
|
||||
from api.services.configuration.registry import ServiceProviders
|
||||
from api.services.pipecat.run_pipeline import _create_realtime_user_turn_config
|
||||
|
||||
|
||||
def test_gemini_realtime_uses_local_vad_without_local_interruptions():
|
||||
strategies, vad_analyzer = _create_realtime_user_turn_config(
|
||||
ServiceProviders.GOOGLE_REALTIME.value
|
||||
)
|
||||
|
||||
assert isinstance(vad_analyzer, SileroVADAnalyzer)
|
||||
assert len(strategies.start) == 1
|
||||
assert isinstance(strategies.start[0], VADUserTurnStartStrategy)
|
||||
assert strategies.start[0]._enable_interruptions is False
|
||||
assert len(strategies.stop) == 1
|
||||
assert isinstance(strategies.stop[0], SpeechTimeoutUserTurnStopStrategy)
|
||||
|
||||
|
||||
def test_gemini_vertex_realtime_uses_same_turn_config_as_gemini_live():
|
||||
strategies, vad_analyzer = _create_realtime_user_turn_config(
|
||||
ServiceProviders.GOOGLE_VERTEX_REALTIME.value
|
||||
)
|
||||
|
||||
assert isinstance(vad_analyzer, SileroVADAnalyzer)
|
||||
assert len(strategies.start) == 1
|
||||
assert isinstance(strategies.start[0], VADUserTurnStartStrategy)
|
||||
assert strategies.start[0]._enable_interruptions is False
|
||||
|
||||
|
||||
def test_openai_realtime_uses_provider_turn_frames_without_local_vad():
|
||||
strategies, vad_analyzer = _create_realtime_user_turn_config(
|
||||
ServiceProviders.OPENAI_REALTIME.value
|
||||
)
|
||||
|
||||
assert vad_analyzer is None
|
||||
assert len(strategies.start) == 1
|
||||
assert isinstance(strategies.start[0], ExternalUserTurnStartStrategy)
|
||||
assert strategies.start[0]._enable_interruptions is False
|
||||
assert len(strategies.stop) == 1
|
||||
assert isinstance(strategies.stop[0], ExternalUserTurnStopStrategy)
|
||||
|
||||
|
||||
def test_unknown_realtime_providers_keep_local_vad():
|
||||
strategies, vad_analyzer = _create_realtime_user_turn_config("other_realtime")
|
||||
|
||||
assert isinstance(vad_analyzer, SileroVADAnalyzer)
|
||||
assert len(strategies.start) == 1
|
||||
assert isinstance(strategies.start[0], VADUserTurnStartStrategy)
|
||||
assert len(strategies.stop) == 1
|
||||
assert isinstance(strategies.stop[0], SpeechTimeoutUserTurnStopStrategy)
|
||||
|
|
@ -66,7 +66,7 @@ class TestUnregisteredFunctionCall:
|
|||
|
||||
# Pipecat's missing-function handler returns a string error.
|
||||
assert isinstance(result_frame.result, str)
|
||||
assert "not registered" in result_frame.result
|
||||
assert "not currently available" in result_frame.result
|
||||
assert "nonexistent_tool" in result_frame.result
|
||||
|
||||
# In-progress frame should also be emitted before the result so mute
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue