mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-06-22 08:38:13 +02:00
Merge remote-tracking branch 'origin/main' into pr-381
This commit is contained in:
commit
858c474139
119 changed files with 5057 additions and 1018 deletions
|
|
@ -15,7 +15,7 @@ Provided here:
|
|||
- ``NoopFeedbackObserver``: a ``RealtimeFeedbackObserver`` stand-in with
|
||||
no WebSocket / clock-task side effects.
|
||||
- ``patch_run_pipeline_externals``: ``contextmanager`` that applies the
|
||||
full patch set and captures the constructed ``PipelineTask`` for the
|
||||
full patch set and captures the constructed ``PipelineWorker`` for the
|
||||
caller. Optional ``llm`` / ``tts`` arguments inject preconfigured
|
||||
mocks; otherwise blank ``MockLLMService`` / ``MockTTSService``
|
||||
instances are constructed per-call.
|
||||
|
|
@ -84,10 +84,10 @@ def patch_run_pipeline_externals(
|
|||
tts: MockTTSService | None = None,
|
||||
):
|
||||
"""Patch the externally-talking pieces of ``_run_pipeline`` and capture
|
||||
the constructed ``PipelineTask`` so tests can drive it from outside.
|
||||
the constructed ``PipelineWorker`` so tests can drive it from outside.
|
||||
|
||||
Args:
|
||||
captured_task: A list the constructed ``PipelineTask`` is appended
|
||||
captured_task: A list the constructed ``PipelineWorker`` is appended
|
||||
to. Tests read ``captured_task[0]`` to get a handle on the task
|
||||
(to wait on its start event, queue frames, cancel it, etc.).
|
||||
llm: Optional pre-built ``MockLLMService``. When given, every call
|
||||
|
|
@ -168,7 +168,7 @@ def patch_run_pipeline_externals(
|
|||
return_value="completed",
|
||||
)
|
||||
)
|
||||
# Capture the PipelineTask so the test can drive it from outside.
|
||||
# Capture the PipelineWorker so the test can drive it from outside.
|
||||
stack.enter_context(
|
||||
patch(
|
||||
"api.services.pipecat.run_pipeline.create_pipeline_task",
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@
|
|||
|
||||
Drives the actual ``_run_pipeline`` against the test database with real
|
||||
DB rows (organization, user, user configuration, workflow, workflow run)
|
||||
and pipecat's real ``MockTransport`` / ``Pipeline`` / ``PipelineTask``.
|
||||
and pipecat's real ``MockTransport`` / ``Pipeline`` / ``PipelineWorker``.
|
||||
The only patches are for things that talk to genuinely external systems;
|
||||
those are applied via ``patch_run_pipeline_externals`` from the shared
|
||||
helpers module.
|
||||
|
|
@ -23,6 +23,7 @@ from pipecat.transports.base_transport import TransportParams
|
|||
from api.enums import WorkflowRunMode, WorkflowRunState
|
||||
from api.services.pipecat.audio_config import create_audio_config
|
||||
from api.services.pipecat.run_pipeline import _run_pipeline
|
||||
from api.services.pipecat.worker_runner import wait_for_pipeline_worker_started
|
||||
from api.tests.integrations._run_pipeline_helpers import (
|
||||
create_workflow_run_rows,
|
||||
patch_run_pipeline_externals,
|
||||
|
|
@ -116,7 +117,9 @@ async def test_run_pipeline_fires_initial_response_and_completes_run(
|
|||
run_task.result() # re-raise the failure
|
||||
assert captured_task, "create_pipeline_task was never invoked"
|
||||
pipeline_task = captured_task[0]
|
||||
await asyncio.wait_for(pipeline_task._pipeline_start_event.wait(), timeout=3.0)
|
||||
await wait_for_pipeline_worker_started(
|
||||
pipeline_task, timeout=3.0, run_task=run_task
|
||||
)
|
||||
# Let the initial response handler (set_node, queue LLMContextFrame)
|
||||
# complete before tearing things down.
|
||||
await asyncio.sleep(0.1)
|
||||
|
|
|
|||
|
|
@ -36,6 +36,7 @@ from pipecat.utils.time import time_now_iso8601
|
|||
from api.enums import WorkflowRunMode, WorkflowRunState
|
||||
from api.services.pipecat.audio_config import create_audio_config
|
||||
from api.services.pipecat.run_pipeline import _run_pipeline
|
||||
from api.services.pipecat.worker_runner import wait_for_pipeline_worker_started
|
||||
from api.tests.integrations._run_pipeline_helpers import (
|
||||
create_workflow_run_rows,
|
||||
patch_run_pipeline_externals,
|
||||
|
|
@ -186,12 +187,12 @@ async def _run_test_body(workflow_run_setup, db_session) -> None:
|
|||
assert captured_task, "create_pipeline_task was never invoked"
|
||||
pipeline_task = captured_task[0]
|
||||
|
||||
await asyncio.wait_for(
|
||||
pipeline_task._pipeline_start_event.wait(), timeout=3.0
|
||||
await wait_for_pipeline_worker_started(
|
||||
pipeline_task, timeout=3.0, run_task=run_task
|
||||
)
|
||||
|
||||
# Locate the assistant aggregator's LLM context (downstream of TTS).
|
||||
# The PipelineTask wraps the user's pipeline inside another Pipeline,
|
||||
# The PipelineWorker wraps the user's pipeline inside another Pipeline,
|
||||
# so we walk the tree recursively.
|
||||
assistant_aggregator = _find_processor_by_class_name(
|
||||
pipeline_task, "LLMAssistantAggregator"
|
||||
|
|
|
|||
|
|
@ -21,6 +21,7 @@ from pipecat.frames.frames import (
|
|||
LLMContextFrame,
|
||||
LLMFullResponseEndFrame,
|
||||
LLMFullResponseStartFrame,
|
||||
UserTurnInferenceCompletedFrame,
|
||||
)
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
|
|
@ -28,6 +29,7 @@ from pipecat.services.llm_service import FunctionCallParams
|
|||
|
||||
from api.services.workflow.pipecat_engine_custom_tools import get_function_schema
|
||||
from api.services.workflow.tools.custom_tool import (
|
||||
_coerce_parameter_value,
|
||||
execute_http_tool,
|
||||
tool_to_function_schema,
|
||||
)
|
||||
|
|
@ -140,6 +142,51 @@ class TestToolToFunctionSchema:
|
|||
assert "duration_minutes" in required
|
||||
assert "is_priority" not in required
|
||||
|
||||
def test_tool_with_object_and_array_parameters(self):
|
||||
"""Test converting a tool with object and array parameters."""
|
||||
tool = MockToolModel(
|
||||
tool_uuid="test-uuid-nested",
|
||||
name="Create Booking",
|
||||
description="Create a booking with nested details",
|
||||
category="http_api",
|
||||
definition={
|
||||
"schema_version": 1,
|
||||
"type": "http_api",
|
||||
"config": {
|
||||
"method": "POST",
|
||||
"url": "https://api.example.com/bookings",
|
||||
"parameters": [
|
||||
{
|
||||
"name": "booking",
|
||||
"type": "object",
|
||||
"description": "Nested booking payload",
|
||||
"required": True,
|
||||
},
|
||||
{
|
||||
"name": "attendees",
|
||||
"type": "array",
|
||||
"description": "Booking attendees",
|
||||
"required": False,
|
||||
},
|
||||
],
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
schema = tool_to_function_schema(tool)
|
||||
|
||||
props = schema["function"]["parameters"]["properties"]
|
||||
assert props["booking"] == {
|
||||
"type": "object",
|
||||
"additionalProperties": True,
|
||||
"description": "Nested booking payload",
|
||||
}
|
||||
assert props["attendees"] == {
|
||||
"type": "array",
|
||||
"items": {},
|
||||
"description": "Booking attendees",
|
||||
}
|
||||
|
||||
def test_preset_parameters_are_not_exposed_to_llm_schema(self):
|
||||
"""Test that preset parameters are injected at runtime, not shown to the LLM."""
|
||||
tool = MockToolModel(
|
||||
|
|
@ -294,6 +341,51 @@ class TestExecuteHttpTool:
|
|||
assert result["status_code"] == 201
|
||||
assert result["data"]["id"] == 123
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_post_request_sends_nested_json_body(self):
|
||||
"""Test that POST requests preserve nested arguments in the JSON body."""
|
||||
tool = MockToolModel(
|
||||
tool_uuid="test-uuid-nested",
|
||||
name="Create Booking",
|
||||
description="Create a nested booking",
|
||||
category="http_api",
|
||||
definition={
|
||||
"schema_version": 1,
|
||||
"type": "http_api",
|
||||
"config": {
|
||||
"method": "POST",
|
||||
"url": "https://api.example.com/bookings",
|
||||
"timeout_ms": 5000,
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
arguments = {
|
||||
"booking": {
|
||||
"start": "2026-05-28T10:00:00Z",
|
||||
"attendee": {"name": "Jane", "email": "jane@example.com"},
|
||||
"metadata": {"source": "voice"},
|
||||
}
|
||||
}
|
||||
|
||||
with patch(
|
||||
"api.services.workflow.tools.custom_tool.httpx.AsyncClient"
|
||||
) as mock_client_class:
|
||||
mock_client = AsyncMock()
|
||||
mock_response = Mock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {"bookingId": "booking-123"}
|
||||
mock_client.request.return_value = mock_response
|
||||
mock_client_class.return_value.__aenter__.return_value = mock_client
|
||||
|
||||
result = await execute_http_tool(tool, arguments)
|
||||
|
||||
call_kwargs = mock_client.request.call_args.kwargs
|
||||
assert call_kwargs["json"] == arguments
|
||||
assert isinstance(call_kwargs["json"]["booking"], dict)
|
||||
assert isinstance(call_kwargs["json"]["booking"]["attendee"], dict)
|
||||
assert result["status"] == "success"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_post_request_injects_preset_parameters(self):
|
||||
"""Test that preset parameters are resolved from runtime context."""
|
||||
|
|
@ -468,7 +560,7 @@ class TestExecuteHttpTool:
|
|||
mock_client.request.return_value = mock_response
|
||||
mock_client_class.return_value.__aenter__.return_value = mock_client
|
||||
|
||||
result = await execute_http_tool(tool, arguments)
|
||||
await execute_http_tool(tool, arguments)
|
||||
|
||||
call_kwargs = mock_client.request.call_args.kwargs
|
||||
assert call_kwargs["method"] == "DELETE"
|
||||
|
|
@ -639,6 +731,51 @@ class TestExecuteHttpTool:
|
|||
mock_db.get_credential_by_uuid.assert_not_called()
|
||||
|
||||
|
||||
class TestCoerceParameterValue:
|
||||
"""Tests for _coerce_parameter_value function."""
|
||||
|
||||
def test_object_value_returns_dict_unchanged(self):
|
||||
"""Test that object parameters preserve dict values."""
|
||||
value = {"attendee": {"name": "Jane"}}
|
||||
|
||||
assert _coerce_parameter_value(value, "object") is value
|
||||
|
||||
def test_object_value_parses_json_string(self):
|
||||
"""Test that object parameters parse JSON string values."""
|
||||
value = '{"attendee": {"name": "Jane"}}'
|
||||
|
||||
assert _coerce_parameter_value(value, "object") == {
|
||||
"attendee": {"name": "Jane"}
|
||||
}
|
||||
|
||||
def test_array_value_returns_list_unchanged(self):
|
||||
"""Test that array parameters preserve list values."""
|
||||
value = [{"name": "Jane"}, {"name": "Sam"}]
|
||||
|
||||
assert _coerce_parameter_value(value, "array") is value
|
||||
|
||||
def test_array_value_parses_json_string(self):
|
||||
"""Test that array parameters parse JSON string values."""
|
||||
value = '[{"name": "Jane"}, {"name": "Sam"}]'
|
||||
|
||||
assert _coerce_parameter_value(value, "array") == [
|
||||
{"name": "Jane"},
|
||||
{"name": "Sam"},
|
||||
]
|
||||
|
||||
@pytest.mark.parametrize("value", ["not json", "[]", "null"])
|
||||
def test_object_value_rejects_invalid_or_wrong_shape(self, value):
|
||||
"""Test that object parameters require a JSON object."""
|
||||
with pytest.raises(ValueError, match="Cannot convert"):
|
||||
_coerce_parameter_value(value, "object")
|
||||
|
||||
@pytest.mark.parametrize("value", ["not json", "{}", "null"])
|
||||
def test_array_value_rejects_invalid_or_wrong_shape(self, value):
|
||||
"""Test that array parameters require a JSON array."""
|
||||
with pytest.raises(ValueError, match="Cannot convert"):
|
||||
_coerce_parameter_value(value, "array")
|
||||
|
||||
|
||||
class TestAuthHeaders:
|
||||
"""Tests for auth header building utilities."""
|
||||
|
||||
|
|
@ -793,6 +930,7 @@ class TestCustomToolManagerIntegration:
|
|||
expected_down_frames=[
|
||||
LLMFullResponseStartFrame,
|
||||
FunctionCallsFromLLMInfoFrame,
|
||||
UserTurnInferenceCompletedFrame,
|
||||
FunctionCallsStartedFrame,
|
||||
LLMFullResponseEndFrame,
|
||||
FunctionCallInProgressFrame,
|
||||
|
|
|
|||
|
|
@ -3,7 +3,9 @@ from types import SimpleNamespace
|
|||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
from pipecat.frames.frames import TranscriptionFrame
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
|
||||
from api.services.pipecat.realtime.gemini_live import DograhGeminiLiveLLMService
|
||||
|
||||
|
|
@ -84,3 +86,25 @@ async def test_disconnect_does_not_forget_previously_delivered_tool_results():
|
|||
|
||||
service._tool_result.assert_not_awaited()
|
||||
assert service._completed_tool_calls == {"call-transition"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_user_transcription_matches_upstream_upstream_push_behavior():
|
||||
service = _make_service()
|
||||
service._handle_user_transcription = AsyncMock()
|
||||
service.push_frame = AsyncMock()
|
||||
service.broadcast_frame = AsyncMock()
|
||||
|
||||
await service._push_user_transcription("Hi there")
|
||||
|
||||
service._handle_user_transcription.assert_awaited_once_with(
|
||||
"Hi there", True, service._settings.language
|
||||
)
|
||||
service.broadcast_frame.assert_not_awaited()
|
||||
service.push_frame.assert_awaited_once()
|
||||
|
||||
frame, direction = service.push_frame.await_args.args
|
||||
assert isinstance(frame, TranscriptionFrame)
|
||||
assert frame.text == "Hi there"
|
||||
assert frame.finalized is False
|
||||
assert direction == FrameDirection.UPSTREAM
|
||||
|
|
|
|||
|
|
@ -51,7 +51,7 @@ async def test_engine_opens_and_closes_mcp_sessions(monkeypatch):
|
|||
assert sess.available is True
|
||||
assert len(sess.function_schemas()) == 2
|
||||
finally:
|
||||
await engine._close_mcp_sessions()
|
||||
await engine.close_mcp_sessions()
|
||||
assert engine._mcp_sessions == {}
|
||||
|
||||
|
||||
|
|
|
|||
164
api/tests/test_mcp_tool_creation.py
Normal file
164
api/tests/test_mcp_tool_creation.py
Normal file
|
|
@ -0,0 +1,164 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from datetime import UTC, datetime
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from fastapi.openapi.utils import get_openapi
|
||||
|
||||
from api.app import app
|
||||
from api.mcp_server.server import mcp
|
||||
from api.mcp_server.tools.tool_creation import create_tool
|
||||
from api.schemas.tool import CreateToolRequest
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def authed_user() -> MagicMock:
|
||||
user = MagicMock()
|
||||
user.id = 11
|
||||
user.provider_id = "provider-11"
|
||||
user.selected_organization_id = 22
|
||||
return user
|
||||
|
||||
|
||||
def _tool_model(**overrides):
|
||||
now = datetime.now(UTC)
|
||||
values = {
|
||||
"id": 3,
|
||||
"tool_uuid": "tool-uuid-3",
|
||||
"name": "Lookup Account",
|
||||
"description": "Lookup an account by phone number",
|
||||
"category": "http_api",
|
||||
"icon": "globe",
|
||||
"icon_color": "#3B82F6",
|
||||
"status": "active",
|
||||
"definition": {
|
||||
"schema_version": 1,
|
||||
"type": "http_api",
|
||||
"config": {"method": "POST", "url": "https://api.example.com/lookup"},
|
||||
},
|
||||
"created_at": now,
|
||||
"updated_at": now,
|
||||
}
|
||||
values.update(overrides)
|
||||
return SimpleNamespace(**values)
|
||||
|
||||
|
||||
def _http_tool_request(**config_overrides) -> CreateToolRequest:
|
||||
config = {"method": "post", "url": "https://api.example.com/lookup"}
|
||||
config.update(config_overrides)
|
||||
return CreateToolRequest(
|
||||
name="Lookup Account",
|
||||
description="Lookup an account by phone number",
|
||||
definition={
|
||||
"schema_version": 1,
|
||||
"type": "http_api",
|
||||
"config": config,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mcp_create_tool_creates_reusable_tool(authed_user: MagicMock):
|
||||
create_tool_mock = AsyncMock(return_value=_tool_model())
|
||||
|
||||
with (
|
||||
patch(
|
||||
"api.mcp_server.tools.tool_creation.authenticate_mcp_request",
|
||||
AsyncMock(return_value=authed_user),
|
||||
),
|
||||
patch(
|
||||
"api.services.tool_management.db_client.create_tool",
|
||||
create_tool_mock,
|
||||
),
|
||||
patch("api.services.tool_management.capture_event") as capture_event_mock,
|
||||
):
|
||||
result = await create_tool(_http_tool_request())
|
||||
|
||||
assert result["created"] is True
|
||||
assert result["tool_uuid"] == "tool-uuid-3"
|
||||
assert result["category"] == "http_api"
|
||||
create_tool_mock.assert_awaited_once()
|
||||
assert create_tool_mock.call_args.kwargs["organization_id"] == 22
|
||||
assert create_tool_mock.call_args.kwargs["user_id"] == 11
|
||||
assert create_tool_mock.call_args.kwargs["definition"]["config"]["method"] == "POST"
|
||||
capture_event_mock.assert_called_once()
|
||||
assert capture_event_mock.call_args.kwargs["properties"]["source"] == "mcp"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mcp_create_tool_rejects_unknown_credential(authed_user: MagicMock):
|
||||
create_tool_mock = AsyncMock()
|
||||
|
||||
with (
|
||||
patch(
|
||||
"api.mcp_server.tools.tool_creation.authenticate_mcp_request",
|
||||
AsyncMock(return_value=authed_user),
|
||||
),
|
||||
patch(
|
||||
"api.services.tool_management.db_client.get_credential_by_uuid",
|
||||
AsyncMock(return_value=None),
|
||||
),
|
||||
patch(
|
||||
"api.services.tool_management.db_client.create_tool",
|
||||
create_tool_mock,
|
||||
),
|
||||
):
|
||||
result = await create_tool(_http_tool_request(credential_uuid="cred-missing"))
|
||||
|
||||
assert result["created"] is False
|
||||
assert result["error_code"] == "credential_not_found"
|
||||
create_tool_mock.assert_not_awaited()
|
||||
|
||||
|
||||
def test_sdk_openapi_exposes_create_tool_schema_and_llm_hints():
|
||||
sdk_routes = [
|
||||
r
|
||||
for r in app.routes
|
||||
if getattr(r, "openapi_extra", None)
|
||||
and "x-sdk-method" in (r.openapi_extra or {})
|
||||
]
|
||||
spec = get_openapi(title=app.title, version=app.version, routes=sdk_routes)
|
||||
operations = [
|
||||
op
|
||||
for path_item in spec["paths"].values()
|
||||
for op in path_item.values()
|
||||
if isinstance(op, dict)
|
||||
]
|
||||
assert any(op.get("x-sdk-method") == "create_tool" for op in operations)
|
||||
|
||||
credential_schema = spec["components"]["schemas"]["HttpApiConfig"]["properties"][
|
||||
"credential_uuid"
|
||||
]
|
||||
assert "list_credentials" in credential_schema["llm_hint"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mcp_create_tool_schema_includes_validation_and_llm_hints():
|
||||
tools = await mcp.list_tools()
|
||||
create_tool_spec = next(t for t in tools if t.name == "create_tool")
|
||||
|
||||
request_schema = create_tool_spec.parameters["properties"]["request"]
|
||||
definition_schema = request_schema["properties"]["definition"]
|
||||
http_config = definition_schema["oneOf"][0]["properties"]["config"]
|
||||
|
||||
assert request_schema["properties"]["category"]["enum"] == [
|
||||
"http_api",
|
||||
"end_call",
|
||||
"transfer_call",
|
||||
"calculator",
|
||||
"native",
|
||||
"integration",
|
||||
"mcp",
|
||||
]
|
||||
assert http_config["properties"]["method"]["enum"] == [
|
||||
"GET",
|
||||
"POST",
|
||||
"PUT",
|
||||
"PATCH",
|
||||
"DELETE",
|
||||
]
|
||||
assert (
|
||||
"list_credentials" in http_config["properties"]["credential_uuid"]["llm_hint"]
|
||||
)
|
||||
|
|
@ -16,10 +16,20 @@ Test coverage:
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
from fastapi import HTTPException
|
||||
from pydantic import ValidationError
|
||||
|
||||
from api.routes.tool import CreateToolRequest, McpToolDefinition, UpdateToolRequest
|
||||
from api.routes.tool import (
|
||||
CreateToolRequest,
|
||||
McpToolConfig,
|
||||
McpToolDefinition,
|
||||
UpdateToolRequest,
|
||||
_populate_discovered_tools,
|
||||
refresh_mcp_tools,
|
||||
)
|
||||
from api.services.workflow.tools.mcp_tool import (
|
||||
validate_mcp_definition,
|
||||
)
|
||||
|
|
@ -70,6 +80,53 @@ def test_update_tool_request_accepts_mcp_definition():
|
|||
assert req.definition.config.url == "https://x/mcp"
|
||||
|
||||
|
||||
def test_update_tool_request_accepts_http_api_complex_parameter_types():
|
||||
"""HTTP API tools may accept structured JSON parameters."""
|
||||
req = UpdateToolRequest(
|
||||
name="Check Availability New Multi",
|
||||
description="Check Availability when asked for it.",
|
||||
definition={
|
||||
"schema_version": 1,
|
||||
"type": "http_api",
|
||||
"config": {
|
||||
"method": "POST",
|
||||
"url": "https://automation.dograh.com/webhook/example",
|
||||
"parameters": [
|
||||
{
|
||||
"name": "params",
|
||||
"type": "object",
|
||||
"description": (
|
||||
"An object containing the name and datetime in ISO format"
|
||||
),
|
||||
"required": True,
|
||||
},
|
||||
{
|
||||
"name": "slots",
|
||||
"type": "array",
|
||||
"description": "Candidate availability slots.",
|
||||
"required": False,
|
||||
},
|
||||
],
|
||||
"preset_parameters": [
|
||||
{
|
||||
"name": "phone_number",
|
||||
"type": "string",
|
||||
"value_template": "{{initial_context.phone_number}}",
|
||||
"required": True,
|
||||
}
|
||||
],
|
||||
"timeout_ms": 5000,
|
||||
"customMessageType": "text",
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
assert req.definition.type == "http_api"
|
||||
parameters = req.definition.config.parameters
|
||||
assert parameters[0].type == "object"
|
||||
assert parameters[1].type == "array"
|
||||
|
||||
|
||||
def test_create_tool_request_accepts_mcp_with_all_fields():
|
||||
"""All optional MCP config fields are accepted and preserved."""
|
||||
req = CreateToolRequest(
|
||||
|
|
@ -279,10 +336,6 @@ async def test_post_tool_mcp_invalid_url_returns_422(test_client_factory, db_ses
|
|||
|
||||
# ── Task 6: discovered_tools field and _populate_discovered_tools helper ──────
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
from api.routes.tool import McpToolConfig, _populate_discovered_tools
|
||||
|
||||
|
||||
def test_mcp_config_accepts_discovered_tools():
|
||||
cfg = McpToolConfig(
|
||||
|
|
@ -296,10 +349,10 @@ def test_mcp_config_accepts_discovered_tools():
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_populate_discovered_tools_overwrites_cache(monkeypatch):
|
||||
import api.routes.tool as tool_mod
|
||||
import api.services.tool_management as tool_svc
|
||||
|
||||
monkeypatch.setattr(
|
||||
tool_mod,
|
||||
tool_svc,
|
||||
"discover_mcp_tools",
|
||||
AsyncMock(return_value=[{"name": "echo", "description": "Echo"}]),
|
||||
)
|
||||
|
|
@ -327,10 +380,10 @@ async def test_populate_discovered_tools_non_mcp_is_noop():
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_populate_discovered_tools_server_down_sets_empty(monkeypatch):
|
||||
import api.routes.tool as tool_mod
|
||||
import api.services.tool_management as tool_svc
|
||||
|
||||
monkeypatch.setattr(
|
||||
tool_mod,
|
||||
tool_svc,
|
||||
"discover_mcp_tools",
|
||||
AsyncMock(side_effect=RuntimeError("connection refused")),
|
||||
)
|
||||
|
|
@ -345,10 +398,6 @@ async def test_populate_discovered_tools_server_down_sets_empty(monkeypatch):
|
|||
|
||||
# ── Task 7: POST /{tool_uuid}/mcp/refresh ─────────────────────────────────────
|
||||
|
||||
from fastapi import HTTPException
|
||||
|
||||
from api.routes.tool import refresh_mcp_tools
|
||||
|
||||
|
||||
def _fake_user(org_id=1):
|
||||
u = MagicMock()
|
||||
|
|
@ -373,19 +422,19 @@ def _mcp_tool_model(org_id=1):
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_refresh_success(monkeypatch):
|
||||
import api.routes.tool as tool_mod
|
||||
import api.services.tool_management as tool_svc
|
||||
|
||||
tool = _mcp_tool_model()
|
||||
monkeypatch.setattr(
|
||||
tool_mod.db_client, "get_tool_by_uuid", AsyncMock(return_value=tool)
|
||||
tool_svc.db_client, "get_tool_by_uuid", AsyncMock(return_value=tool)
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
tool_mod.db_client,
|
||||
tool_svc.db_client,
|
||||
"update_tool",
|
||||
AsyncMock(return_value=tool),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
tool_mod,
|
||||
tool_svc,
|
||||
"discover_mcp_tools",
|
||||
AsyncMock(return_value=[{"name": "echo", "description": "Echo"}]),
|
||||
)
|
||||
|
|
@ -396,29 +445,29 @@ async def test_refresh_success(monkeypatch):
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_refresh_server_down_returns_200_with_error(monkeypatch):
|
||||
import api.routes.tool as tool_mod
|
||||
import api.services.tool_management as tool_svc
|
||||
|
||||
tool = _mcp_tool_model()
|
||||
monkeypatch.setattr(
|
||||
tool_mod.db_client, "get_tool_by_uuid", AsyncMock(return_value=tool)
|
||||
tool_svc.db_client, "get_tool_by_uuid", AsyncMock(return_value=tool)
|
||||
)
|
||||
monkeypatch.setattr(tool_mod.db_client, "update_tool", AsyncMock(return_value=tool))
|
||||
monkeypatch.setattr(tool_mod, "discover_mcp_tools", AsyncMock(return_value=[]))
|
||||
monkeypatch.setattr(tool_svc.db_client, "update_tool", AsyncMock(return_value=tool))
|
||||
monkeypatch.setattr(tool_svc, "discover_mcp_tools", AsyncMock(return_value=[]))
|
||||
resp = await refresh_mcp_tools("tu-mcp", user=_fake_user())
|
||||
assert resp.discovered_tools == []
|
||||
assert resp.error # non-empty human-readable message
|
||||
# update_tool should NOT be called when discovery returns empty
|
||||
tool_mod.db_client.update_tool.assert_not_called()
|
||||
tool_svc.db_client.update_tool.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_refresh_non_mcp_is_400(monkeypatch):
|
||||
import api.routes.tool as tool_mod
|
||||
import api.services.tool_management as tool_svc
|
||||
|
||||
tool = _mcp_tool_model()
|
||||
tool.category = "http_api"
|
||||
monkeypatch.setattr(
|
||||
tool_mod.db_client, "get_tool_by_uuid", AsyncMock(return_value=tool)
|
||||
tool_svc.db_client, "get_tool_by_uuid", AsyncMock(return_value=tool)
|
||||
)
|
||||
with pytest.raises(HTTPException) as ei:
|
||||
await refresh_mcp_tools("tu-mcp", user=_fake_user())
|
||||
|
|
@ -427,10 +476,10 @@ async def test_refresh_non_mcp_is_400(monkeypatch):
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_refresh_not_found_is_404(monkeypatch):
|
||||
import api.routes.tool as tool_mod
|
||||
import api.services.tool_management as tool_svc
|
||||
|
||||
monkeypatch.setattr(
|
||||
tool_mod.db_client, "get_tool_by_uuid", AsyncMock(return_value=None)
|
||||
tool_svc.db_client, "get_tool_by_uuid", AsyncMock(return_value=None)
|
||||
)
|
||||
with pytest.raises(HTTPException) as ei:
|
||||
await refresh_mcp_tools("nope", user=_fake_user())
|
||||
|
|
|
|||
19
api/tests/test_pipecat_engine_callbacks.py
Normal file
19
api/tests/test_pipecat_engine_callbacks.py
Normal file
|
|
@ -0,0 +1,19 @@
|
|||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
from pipecat.utils.enums import EndTaskReason
|
||||
|
||||
from api.services.workflow.pipecat_engine_callbacks import create_max_duration_callback
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_max_duration_callback_aborts_immediately():
|
||||
engine = AsyncMock()
|
||||
|
||||
callback = create_max_duration_callback(engine)
|
||||
await callback()
|
||||
|
||||
engine.end_call_with_reason.assert_awaited_once_with(
|
||||
EndTaskReason.CALL_DURATION_EXCEEDED.value,
|
||||
abort_immediately=True,
|
||||
)
|
||||
|
|
@ -20,8 +20,7 @@ from unittest.mock import AsyncMock, patch
|
|||
import pytest
|
||||
from pipecat.frames.frames import LLMContextFrame
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
from pipecat.pipeline.worker import PipelineParams, PipelineWorker
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.processors.aggregators.llm_response_universal import (
|
||||
LLMAssistantAggregatorParams,
|
||||
|
|
@ -30,6 +29,7 @@ from pipecat.processors.aggregators.llm_response_universal import (
|
|||
from pipecat.tests.mock_transport import MockTransport
|
||||
from pipecat.transports.base_transport import TransportParams
|
||||
|
||||
from api.services.pipecat.worker_runner import run_pipeline_worker
|
||||
from api.services.workflow.pipecat_engine import PipecatEngine
|
||||
from api.services.workflow.workflow_graph import WorkflowGraph
|
||||
from api.tests.conftest import (
|
||||
|
|
@ -116,7 +116,7 @@ async def run_pipeline_and_capture_context(
|
|||
)
|
||||
|
||||
# Create pipeline task
|
||||
task = PipelineTask(pipeline, params=PipelineParams(), enable_rtvi=False)
|
||||
task = PipelineWorker(pipeline, params=PipelineParams(), enable_rtvi=False)
|
||||
|
||||
engine.set_task(task)
|
||||
|
||||
|
|
@ -131,10 +131,9 @@ async def run_pipeline_and_capture_context(
|
|||
new_callable=AsyncMock,
|
||||
return_value="completed",
|
||||
):
|
||||
runner = PipelineRunner()
|
||||
|
||||
async def run_pipeline():
|
||||
await runner.run(task)
|
||||
await run_pipeline_worker(task)
|
||||
|
||||
async def initialize_engine():
|
||||
await asyncio.sleep(0.01)
|
||||
|
|
|
|||
|
|
@ -25,8 +25,7 @@ from unittest.mock import AsyncMock, patch
|
|||
import pytest
|
||||
from pipecat.frames.frames import Frame, LLMContextFrame
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
from pipecat.pipeline.worker import PipelineParams, PipelineWorker
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.processors.aggregators.llm_response_universal import (
|
||||
LLMAssistantAggregatorParams,
|
||||
|
|
@ -42,6 +41,7 @@ from pipecat.turns.user_mute import (
|
|||
from pipecat.utils.enums import EndTaskReason
|
||||
|
||||
from api.enums import ToolCategory
|
||||
from api.services.pipecat.worker_runner import run_pipeline_worker
|
||||
from api.services.workflow.dto import (
|
||||
EdgeDataDTO,
|
||||
EndCallNodeData,
|
||||
|
|
@ -112,7 +112,7 @@ async def create_engine_with_tracking(
|
|||
mock_llm: MockLLMService,
|
||||
test_helper: EndCallTestHelper,
|
||||
generate_audio: bool = True,
|
||||
) -> tuple[PipecatEngine, MockTTSService, MockTransport, PipelineTask]:
|
||||
) -> tuple[PipecatEngine, MockTTSService, MockTransport, PipelineWorker]:
|
||||
"""Create a PipecatEngine with tracking for end call behavior.
|
||||
|
||||
Args:
|
||||
|
|
@ -222,7 +222,7 @@ async def create_engine_with_tracking(
|
|||
)
|
||||
|
||||
# Create pipeline task
|
||||
task = PipelineTask(pipeline, params=PipelineParams(), enable_rtvi=False)
|
||||
task = PipelineWorker(pipeline, params=PipelineParams(), enable_rtvi=False)
|
||||
|
||||
engine.set_task(task)
|
||||
|
||||
|
|
@ -279,10 +279,9 @@ class TestEndCallViaNodeTransition:
|
|||
new_callable=AsyncMock,
|
||||
return_value={"user_intent": "end call"},
|
||||
):
|
||||
runner = PipelineRunner()
|
||||
|
||||
async def run_pipeline():
|
||||
await runner.run(task)
|
||||
await run_pipeline_worker(task)
|
||||
|
||||
async def initialize_engine():
|
||||
await asyncio.sleep(0.01)
|
||||
|
|
@ -383,10 +382,9 @@ class TestEndCallViaNodeTransition:
|
|||
new_callable=AsyncMock,
|
||||
return_value={"greeting_type": "formal", "user_name": "John"},
|
||||
):
|
||||
runner = PipelineRunner()
|
||||
|
||||
async def run_pipeline():
|
||||
await runner.run(task)
|
||||
await run_pipeline_worker(task)
|
||||
|
||||
async def initialize_engine():
|
||||
await asyncio.sleep(0.01)
|
||||
|
|
@ -482,10 +480,9 @@ class TestEndCallViaCustomTool:
|
|||
new_callable=AsyncMock,
|
||||
return_value={"user_intent": "end"},
|
||||
):
|
||||
runner = PipelineRunner()
|
||||
|
||||
async def run_pipeline():
|
||||
await runner.run(task)
|
||||
await run_pipeline_worker(task)
|
||||
|
||||
async def initialize_engine():
|
||||
await asyncio.sleep(0.01)
|
||||
|
|
@ -574,10 +571,9 @@ class TestEndCallViaCustomTool:
|
|||
new_callable=AsyncMock,
|
||||
return_value={"user_intent": "end"},
|
||||
):
|
||||
runner = PipelineRunner()
|
||||
|
||||
async def run_pipeline():
|
||||
await runner.run(task)
|
||||
await run_pipeline_worker(task)
|
||||
|
||||
async def initialize_engine():
|
||||
await asyncio.sleep(0.01)
|
||||
|
|
@ -652,10 +648,9 @@ class TestEndCallViaClientDisconnect:
|
|||
new_callable=AsyncMock,
|
||||
return_value={"user_intent": "disconnected"},
|
||||
):
|
||||
runner = PipelineRunner()
|
||||
|
||||
async def run_pipeline():
|
||||
await runner.run(task)
|
||||
await run_pipeline_worker(task)
|
||||
|
||||
async def initialize_and_disconnect():
|
||||
await asyncio.sleep(0.01)
|
||||
|
|
@ -743,10 +738,9 @@ class TestEndCallRaceConditions:
|
|||
new_callable=AsyncMock,
|
||||
return_value={"user_intent": "end"},
|
||||
):
|
||||
runner = PipelineRunner()
|
||||
|
||||
async def run_pipeline():
|
||||
await runner.run(task)
|
||||
await run_pipeline_worker(task)
|
||||
|
||||
async def initialize_and_race():
|
||||
await asyncio.sleep(0.01)
|
||||
|
|
@ -855,10 +849,9 @@ class TestEndCallRaceConditions:
|
|||
new_callable=AsyncMock,
|
||||
return_value={"user_intent": "end"},
|
||||
):
|
||||
runner = PipelineRunner()
|
||||
|
||||
async def run_pipeline():
|
||||
await runner.run(task)
|
||||
await run_pipeline_worker(task)
|
||||
|
||||
async def initialize_and_race_disconnect():
|
||||
nonlocal disconnect_called
|
||||
|
|
@ -950,10 +943,9 @@ class TestEndCallExtractionBehavior:
|
|||
"_perform_extraction",
|
||||
side_effect=mock_extraction,
|
||||
):
|
||||
runner = PipelineRunner()
|
||||
|
||||
async def run_pipeline():
|
||||
await runner.run(task)
|
||||
await run_pipeline_worker(task)
|
||||
|
||||
async def initialize_and_end():
|
||||
await asyncio.sleep(0.01)
|
||||
|
|
@ -1076,10 +1068,9 @@ class TestEndCallExtractionBehavior:
|
|||
"_perform_extraction",
|
||||
extraction_mock,
|
||||
):
|
||||
runner = PipelineRunner()
|
||||
|
||||
async def run_pipeline():
|
||||
await runner.run(task)
|
||||
await run_pipeline_worker(task)
|
||||
|
||||
async def initialize_and_end():
|
||||
await asyncio.sleep(0.01)
|
||||
|
|
|
|||
|
|
@ -24,8 +24,7 @@ from pipecat.frames.frames import (
|
|||
UserStoppedSpeakingFrame,
|
||||
)
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
from pipecat.pipeline.worker import PipelineParams, PipelineWorker
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.processors.aggregators.llm_response_universal import (
|
||||
LLMAssistantAggregatorParams,
|
||||
|
|
@ -48,6 +47,7 @@ from pipecat.turns.user_stop import (
|
|||
from pipecat.turns.user_turn_strategies import UserTurnStrategies
|
||||
from pipecat.utils.time import time_now_iso8601
|
||||
|
||||
from api.services.pipecat.worker_runner import run_pipeline_worker
|
||||
from api.services.workflow.pipecat_engine import PipecatEngine
|
||||
from api.services.workflow.workflow_graph import WorkflowGraph
|
||||
from pipecat.tests import MockLLMService, MockTTSService
|
||||
|
|
@ -119,7 +119,7 @@ async def create_test_pipeline(
|
|||
workflow: WorkflowGraph,
|
||||
mock_llm: MockLLMService,
|
||||
user_speech_initial_delay: float = 0.01,
|
||||
) -> tuple[PipecatEngine, MockTransport, PipelineTask]:
|
||||
) -> tuple[PipecatEngine, MockTransport, PipelineWorker]:
|
||||
"""Create a PipecatEngine with full pipeline for testing node switch scenarios.
|
||||
|
||||
The pipeline includes a UserSpeechInjector processor that injects
|
||||
|
|
@ -208,7 +208,7 @@ async def create_test_pipeline(
|
|||
)
|
||||
|
||||
# Create pipeline task
|
||||
task = PipelineTask(pipeline, params=PipelineParams(), enable_rtvi=False)
|
||||
task = PipelineWorker(pipeline, params=PipelineParams(), enable_rtvi=False)
|
||||
|
||||
engine.set_task(task)
|
||||
|
||||
|
|
@ -286,10 +286,9 @@ class TestNodeSwitchWithUserSpeech:
|
|||
new_callable=AsyncMock,
|
||||
return_value="completed",
|
||||
):
|
||||
runner = PipelineRunner()
|
||||
|
||||
async def run_pipeline():
|
||||
await runner.run(task)
|
||||
await run_pipeline_worker(task)
|
||||
|
||||
async def initialize_engine():
|
||||
await asyncio.sleep(0.01)
|
||||
|
|
|
|||
|
|
@ -11,8 +11,7 @@ from unittest.mock import AsyncMock, patch
|
|||
import pytest
|
||||
from pipecat.frames.frames import LLMContextFrame
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
from pipecat.pipeline.worker import PipelineParams, PipelineWorker
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.processors.aggregators.llm_response_universal import (
|
||||
LLMAssistantAggregatorParams,
|
||||
|
|
@ -21,6 +20,7 @@ from pipecat.processors.aggregators.llm_response_universal import (
|
|||
from pipecat.tests.mock_transport import MockTransport
|
||||
from pipecat.transports.base_transport import TransportParams
|
||||
|
||||
from api.services.pipecat.worker_runner import run_pipeline_worker
|
||||
from api.services.workflow.pipecat_engine import PipecatEngine
|
||||
from api.services.workflow.workflow_graph import WorkflowGraph
|
||||
from api.tests.conftest import END_CALL_SYSTEM_PROMPT
|
||||
|
|
@ -107,7 +107,7 @@ async def run_pipeline_with_tool_calls(
|
|||
)
|
||||
|
||||
# Create a real pipeline task
|
||||
task = PipelineTask(pipeline, params=PipelineParams(), enable_rtvi=False)
|
||||
task = PipelineWorker(pipeline, params=PipelineParams(), enable_rtvi=False)
|
||||
|
||||
engine.set_task(task)
|
||||
|
||||
|
|
@ -122,10 +122,9 @@ async def run_pipeline_with_tool_calls(
|
|||
new_callable=AsyncMock,
|
||||
return_value="completed",
|
||||
):
|
||||
runner = PipelineRunner()
|
||||
|
||||
async def run_pipeline():
|
||||
await runner.run(task)
|
||||
await run_pipeline_worker(task)
|
||||
|
||||
async def initialize_engine():
|
||||
# Small delay to let runner start
|
||||
|
|
|
|||
|
|
@ -15,8 +15,7 @@ from unittest.mock import AsyncMock, patch
|
|||
import pytest
|
||||
from pipecat.frames.frames import LLMContextFrame
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
from pipecat.pipeline.worker import PipelineParams, PipelineWorker
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.processors.aggregators.llm_response_universal import (
|
||||
LLMAssistantAggregatorParams,
|
||||
|
|
@ -31,6 +30,7 @@ from pipecat.turns.user_mute import (
|
|||
MuteUntilFirstBotCompleteUserMuteStrategy,
|
||||
)
|
||||
|
||||
from api.services.pipecat.worker_runner import run_pipeline_worker
|
||||
from api.services.workflow.pipecat_engine import PipecatEngine
|
||||
from api.services.workflow.pipecat_engine_variable_extractor import (
|
||||
VariableExtractionManager,
|
||||
|
|
@ -99,7 +99,7 @@ async def _build_engine_and_pipeline(
|
|||
]
|
||||
)
|
||||
|
||||
task = PipelineTask(pipeline, params=PipelineParams(), enable_rtvi=False)
|
||||
task = PipelineWorker(pipeline, params=PipelineParams(), enable_rtvi=False)
|
||||
engine.set_task(task)
|
||||
|
||||
return engine, task, function_call_mute_strategy, user_context_aggregator
|
||||
|
|
@ -182,10 +182,9 @@ class TestTransitionFunctionMutesUser:
|
|||
new_callable=AsyncMock,
|
||||
return_value={"user_intent": "end call"},
|
||||
):
|
||||
runner = PipelineRunner()
|
||||
|
||||
async def run_pipeline():
|
||||
await runner.run(task)
|
||||
await run_pipeline_worker(task)
|
||||
|
||||
async def initialize_engine():
|
||||
await asyncio.sleep(0.01)
|
||||
|
|
@ -257,10 +256,9 @@ class TestTransitionFunctionMutesUser:
|
|||
new_callable=AsyncMock,
|
||||
return_value={"user_intent": "end call"},
|
||||
):
|
||||
runner = PipelineRunner()
|
||||
|
||||
async def run_pipeline():
|
||||
await runner.run(task)
|
||||
await run_pipeline_worker(task)
|
||||
|
||||
async def initialize_engine():
|
||||
await asyncio.sleep(0.01)
|
||||
|
|
|
|||
|
|
@ -18,8 +18,7 @@ from unittest.mock import AsyncMock, patch
|
|||
import pytest
|
||||
from pipecat.frames.frames import LLMContextFrame
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
from pipecat.pipeline.worker import PipelineParams, PipelineWorker
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.processors.aggregators.llm_response_universal import (
|
||||
LLMAssistantAggregatorParams,
|
||||
|
|
@ -28,6 +27,7 @@ from pipecat.processors.aggregators.llm_response_universal import (
|
|||
from pipecat.tests.mock_transport import MockTransport
|
||||
from pipecat.transports.base_transport import TransportParams
|
||||
|
||||
from api.services.pipecat.worker_runner import run_pipeline_worker
|
||||
from api.services.workflow.pipecat_engine import PipecatEngine
|
||||
from api.services.workflow.pipecat_engine_variable_extractor import (
|
||||
VariableExtractionManager,
|
||||
|
|
@ -142,7 +142,7 @@ class TestVariableExtractionDuringTransitions:
|
|||
)
|
||||
|
||||
# Create pipeline task
|
||||
task = PipelineTask(
|
||||
task = PipelineWorker(
|
||||
pipeline,
|
||||
params=PipelineParams(),
|
||||
enable_rtvi=False,
|
||||
|
|
@ -168,10 +168,9 @@ class TestVariableExtractionDuringTransitions:
|
|||
new_callable=AsyncMock,
|
||||
return_value={"user_name": "John Doe"},
|
||||
):
|
||||
runner = PipelineRunner()
|
||||
|
||||
async def run_pipeline():
|
||||
await runner.run(task)
|
||||
await run_pipeline_worker(task)
|
||||
|
||||
async def initialize_engine():
|
||||
await asyncio.sleep(0.01)
|
||||
|
|
|
|||
|
|
@ -8,11 +8,12 @@ from pipecat.frames.frames import (
|
|||
InterruptionTaskFrame,
|
||||
LLMRunFrame,
|
||||
)
|
||||
from pipecat.pipeline.base_task import PipelineTaskParams
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.task import PipelineTask
|
||||
from pipecat.pipeline.worker import PipelineWorker
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
|
||||
from api.services.pipecat.worker_runner import run_pipeline_worker
|
||||
|
||||
|
||||
class MockTransport(FrameProcessor):
|
||||
def __init__(self, **kwargs):
|
||||
|
|
@ -51,12 +52,10 @@ async def test_interruption_with_blocked_end_frame():
|
|||
transport = MockTransport()
|
||||
pipeline = Pipeline([transport, busy_wait_processor])
|
||||
|
||||
task = PipelineTask(pipeline, enable_rtvi=False)
|
||||
task = PipelineWorker(pipeline, enable_rtvi=False)
|
||||
|
||||
async def run_pipeline():
|
||||
loop = asyncio.get_running_loop()
|
||||
params = PipelineTaskParams(loop=loop)
|
||||
await task.run(params=params)
|
||||
await run_pipeline_worker(task)
|
||||
|
||||
async def queue_frame():
|
||||
await task.queue_frames([LLMRunFrame()])
|
||||
|
|
|
|||
100
api/tests/test_realtime_feedback_observer.py
Normal file
100
api/tests/test_realtime_feedback_observer.py
Normal file
|
|
@ -0,0 +1,100 @@
|
|||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
from pipecat.frames.frames import TranscriptionFrame, TTSTextFrame
|
||||
from pipecat.observers.base_observer import FramePushed
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.transports.base_output import BaseOutputTransport
|
||||
from pipecat.transports.base_transport import TransportParams
|
||||
|
||||
from api.services.pipecat.realtime_feedback_observer import RealtimeFeedbackObserver
|
||||
|
||||
|
||||
def _frame_pushed(frame, direction, *, source=None):
|
||||
return FramePushed(
|
||||
source=source or SimpleNamespace(),
|
||||
destination=SimpleNamespace(),
|
||||
frame=frame,
|
||||
direction=direction,
|
||||
timestamp=0,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_observer_streams_upstream_only_transcription_frames():
|
||||
messages = []
|
||||
|
||||
async def ws_sender(message):
|
||||
messages.append(message)
|
||||
|
||||
observer = RealtimeFeedbackObserver(ws_sender=ws_sender)
|
||||
frame = TranscriptionFrame(
|
||||
"Hi there",
|
||||
user_id="user-1",
|
||||
timestamp="2026-01-01T00:00:00+00:00",
|
||||
)
|
||||
|
||||
await observer.on_push_frame(_frame_pushed(frame, FrameDirection.UPSTREAM))
|
||||
|
||||
assert messages == [
|
||||
{
|
||||
"type": "rtf-user-transcription",
|
||||
"payload": {
|
||||
"text": "Hi there",
|
||||
"final": True,
|
||||
"timestamp": "2026-01-01T00:00:00+00:00",
|
||||
"user_id": "user-1",
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_observer_ignores_upstream_broadcast_transcription_sibling():
|
||||
messages = []
|
||||
|
||||
async def ws_sender(message):
|
||||
messages.append(message)
|
||||
|
||||
observer = RealtimeFeedbackObserver(ws_sender=ws_sender)
|
||||
frame = TranscriptionFrame(
|
||||
"Hi there",
|
||||
user_id="user-1",
|
||||
timestamp="2026-01-01T00:00:00+00:00",
|
||||
)
|
||||
frame.broadcast_sibling_id = 1234
|
||||
|
||||
await observer.on_push_frame(_frame_pushed(frame, FrameDirection.UPSTREAM))
|
||||
|
||||
assert messages == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_observer_waits_for_tts_text_from_output_transport():
|
||||
messages = []
|
||||
|
||||
async def ws_sender(message):
|
||||
messages.append(message)
|
||||
|
||||
observer = RealtimeFeedbackObserver(ws_sender=ws_sender)
|
||||
frame = TTSTextFrame("Hello", aggregated_by="word")
|
||||
frame.pts = 123
|
||||
|
||||
await observer.on_push_frame(_frame_pushed(frame, FrameDirection.DOWNSTREAM))
|
||||
assert messages == []
|
||||
|
||||
output_transport = BaseOutputTransport(TransportParams())
|
||||
await observer.on_push_frame(
|
||||
_frame_pushed(
|
||||
frame,
|
||||
FrameDirection.DOWNSTREAM,
|
||||
source=output_transport,
|
||||
)
|
||||
)
|
||||
|
||||
assert messages == [
|
||||
{
|
||||
"type": "rtf-bot-text",
|
||||
"payload": {"text": "Hello"},
|
||||
}
|
||||
]
|
||||
23
api/tests/test_run_usage_response.py
Normal file
23
api/tests/test_run_usage_response.py
Normal file
|
|
@ -0,0 +1,23 @@
|
|||
from api.services.pricing.run_usage_response import format_public_usage_info
|
||||
|
||||
|
||||
def test_format_public_usage_info():
|
||||
usage_info = {
|
||||
"llm": {
|
||||
"SarvamLLMService#0|||sarvam-30b": {
|
||||
"prompt_tokens": 100,
|
||||
"completion_tokens": 50,
|
||||
"total_tokens": 150,
|
||||
}
|
||||
},
|
||||
"tts": {"ElevenLabsTTSService#0|||eleven_flash_v2_5": 42},
|
||||
"stt": {},
|
||||
"call_duration_seconds": 12.4,
|
||||
}
|
||||
|
||||
result = format_public_usage_info(usage_info)
|
||||
|
||||
assert result["llm"]["SarvamLLMService#0|||sarvam-30b"]["prompt_tokens"] == 100
|
||||
assert result["tts"]["ElevenLabsTTSService#0|||eleven_flash_v2_5"] == 42
|
||||
assert result["stt"] == {}
|
||||
assert result["call_duration_seconds"] == 12.4
|
||||
114
api/tests/test_sarvam_service_factory.py
Normal file
114
api/tests/test_sarvam_service_factory.py
Normal file
|
|
@ -0,0 +1,114 @@
|
|||
from types import SimpleNamespace
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from pipecat.services.sarvam.llm import SarvamLLMService as RealSarvamLLMService
|
||||
from pipecat.transcriptions.language import Language
|
||||
|
||||
from api.services.configuration.registry import (
|
||||
SarvamLLMConfiguration,
|
||||
ServiceProviders,
|
||||
)
|
||||
from api.services.pipecat.audio_config import AudioConfig
|
||||
from api.services.pipecat.service_factory import (
|
||||
create_llm_service,
|
||||
create_llm_service_from_provider,
|
||||
create_stt_service,
|
||||
)
|
||||
|
||||
|
||||
class TestSarvamLLMConfiguration:
|
||||
def test_default_values(self):
|
||||
config = SarvamLLMConfiguration(api_key="test-key")
|
||||
assert config.provider == ServiceProviders.SARVAM
|
||||
assert config.model == "sarvam-30b"
|
||||
assert config.temperature == 0.5
|
||||
|
||||
def test_custom_model(self):
|
||||
config = SarvamLLMConfiguration(api_key="test-key", model="sarvam-105b")
|
||||
assert config.model == "sarvam-105b"
|
||||
|
||||
|
||||
class TestSarvamLLMServiceFactory:
|
||||
def test_create_sarvam_llm_service(self):
|
||||
with patch(
|
||||
"api.services.pipecat.service_factory.SarvamLLMService"
|
||||
) as mock_service:
|
||||
mock_service.Settings = RealSarvamLLMService.Settings
|
||||
create_llm_service_from_provider(
|
||||
provider=ServiceProviders.SARVAM.value,
|
||||
model="sarvam-30b",
|
||||
api_key="test-key",
|
||||
)
|
||||
|
||||
assert mock_service.call_count == 1
|
||||
kwargs = mock_service.call_args.kwargs
|
||||
assert kwargs["api_key"] == "test-key"
|
||||
assert kwargs["settings"].model == "sarvam-30b"
|
||||
assert kwargs["settings"].temperature == 0.5
|
||||
|
||||
def test_create_sarvam_llm_service_passes_user_temperature(self):
|
||||
with patch(
|
||||
"api.services.pipecat.service_factory.SarvamLLMService"
|
||||
) as mock_service:
|
||||
mock_service.Settings = RealSarvamLLMService.Settings
|
||||
create_llm_service_from_provider(
|
||||
provider=ServiceProviders.SARVAM.value,
|
||||
model="sarvam-30b",
|
||||
api_key="test-key",
|
||||
temperature=0.8,
|
||||
)
|
||||
|
||||
kwargs = mock_service.call_args.kwargs
|
||||
assert kwargs["settings"].temperature == 0.8
|
||||
|
||||
def test_create_llm_service_extracts_sarvam_temperature(self):
|
||||
user_config = SimpleNamespace(
|
||||
llm=SimpleNamespace(
|
||||
provider=ServiceProviders.SARVAM.value,
|
||||
model="sarvam-30b",
|
||||
api_key="test-key",
|
||||
temperature=0.7,
|
||||
)
|
||||
)
|
||||
|
||||
with patch(
|
||||
"api.services.pipecat.service_factory.SarvamLLMService"
|
||||
) as mock_service:
|
||||
mock_service.Settings = RealSarvamLLMService.Settings
|
||||
create_llm_service(user_config)
|
||||
|
||||
kwargs = mock_service.call_args.kwargs
|
||||
assert kwargs["settings"].temperature == 0.7
|
||||
|
||||
|
||||
class TestSarvamSTTServiceFactory:
|
||||
@pytest.mark.parametrize(
|
||||
"input_language,expected_language",
|
||||
[
|
||||
("unknown", None),
|
||||
(None, None),
|
||||
("hi-IN", Language.HI_IN),
|
||||
("ne-IN", "ne-IN"),
|
||||
],
|
||||
)
|
||||
def test_stt_language_mapping(self, input_language, expected_language):
|
||||
user_config = SimpleNamespace(
|
||||
stt=SimpleNamespace(
|
||||
provider=ServiceProviders.SARVAM.value,
|
||||
model="saaras:v3",
|
||||
api_key="test-key",
|
||||
language=input_language,
|
||||
)
|
||||
)
|
||||
audio_config = AudioConfig(
|
||||
transport_in_sample_rate=16000, transport_out_sample_rate=16000
|
||||
)
|
||||
|
||||
with patch(
|
||||
"api.services.pipecat.service_factory.SarvamSTTService"
|
||||
) as mock_service:
|
||||
create_stt_service(user_config, audio_config)
|
||||
|
||||
kwargs = mock_service.call_args.kwargs
|
||||
assert kwargs["settings"].language == expected_language
|
||||
|
|
@ -20,8 +20,7 @@ from pipecat.frames.frames import (
|
|||
TTSStoppedFrame,
|
||||
)
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
from pipecat.pipeline.worker import PipelineParams, PipelineWorker
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.processors.aggregators.llm_response_universal import (
|
||||
LLMAssistantAggregatorParams,
|
||||
|
|
@ -31,6 +30,7 @@ from pipecat.tests.mock_transport import MockTransport
|
|||
from pipecat.transports.base_transport import TransportParams
|
||||
|
||||
from api.services.pipecat.recording_audio_cache import RecordingAudio
|
||||
from api.services.pipecat.worker_runner import run_pipeline_worker
|
||||
from api.services.workflow.dto import (
|
||||
EdgeDataDTO,
|
||||
EndCallNodeData,
|
||||
|
|
@ -212,7 +212,7 @@ async def run_pipeline_and_capture_frames(
|
|||
engine.set_transport_output(transport_output)
|
||||
|
||||
pipeline = Pipeline([llm, tts, transport_output, context_aggregator.assistant()])
|
||||
task = PipelineTask(pipeline, params=PipelineParams(), enable_rtvi=False)
|
||||
task = PipelineWorker(pipeline, params=PipelineParams(), enable_rtvi=False)
|
||||
engine.set_task(task)
|
||||
|
||||
# Spy on task.queue_frame and transport_output.queue_frame to capture
|
||||
|
|
@ -247,10 +247,9 @@ async def run_pipeline_and_capture_frames(
|
|||
return_value="completed",
|
||||
),
|
||||
):
|
||||
runner = PipelineRunner()
|
||||
|
||||
async def run():
|
||||
await runner.run(task)
|
||||
await run_pipeline_worker(task)
|
||||
|
||||
async def initialize():
|
||||
await asyncio.sleep(0.01)
|
||||
|
|
|
|||
|
|
@ -34,8 +34,7 @@ from unittest.mock import AsyncMock, patch
|
|||
import pytest
|
||||
from pipecat.frames.frames import LLMContextFrame
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
from pipecat.pipeline.worker import PipelineParams, PipelineWorker
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.processors.aggregators.llm_response_universal import (
|
||||
LLMAssistantAggregatorParams,
|
||||
|
|
@ -50,6 +49,7 @@ from pipecat.turns.user_mute import (
|
|||
)
|
||||
from pipecat.utils.enums import EndTaskReason
|
||||
|
||||
from api.services.pipecat.worker_runner import run_pipeline_worker
|
||||
from api.services.workflow.pipecat_engine import PipecatEngine
|
||||
from api.services.workflow.pipecat_engine_variable_extractor import (
|
||||
VariableExtractionManager,
|
||||
|
|
@ -62,7 +62,7 @@ async def create_test_pipeline_with_failing_transport(
|
|||
workflow: WorkflowGraph,
|
||||
mock_llm: MockLLMService,
|
||||
fail_after_n_frames: int = 0,
|
||||
) -> tuple[PipecatEngine, MockTTSService, MockTransport, PipelineTask]:
|
||||
) -> tuple[PipecatEngine, MockTTSService, MockTransport, PipelineWorker]:
|
||||
"""Create a PipecatEngine with failing output transport for testing.
|
||||
|
||||
Uses the real MockTransport which now extends BaseOutputTransport and uses
|
||||
|
|
@ -152,7 +152,7 @@ async def create_test_pipeline_with_failing_transport(
|
|||
)
|
||||
|
||||
# Create pipeline task
|
||||
task = PipelineTask(pipeline, params=PipelineParams(), enable_rtvi=False)
|
||||
task = PipelineWorker(pipeline, params=PipelineParams(), enable_rtvi=False)
|
||||
|
||||
engine.set_task(task)
|
||||
|
||||
|
|
@ -219,10 +219,9 @@ class TestTTSPauseWithAudioWriteFailure:
|
|||
new_callable=AsyncMock,
|
||||
return_value={},
|
||||
):
|
||||
runner = PipelineRunner()
|
||||
|
||||
async def run_pipeline():
|
||||
await runner.run(task)
|
||||
await run_pipeline_worker(task)
|
||||
|
||||
async def initialize_and_end_call():
|
||||
await asyncio.sleep(0.01)
|
||||
|
|
@ -339,10 +338,9 @@ class TestTTSPauseWithAudioWriteFailure:
|
|||
new_callable=AsyncMock,
|
||||
return_value={},
|
||||
):
|
||||
runner = PipelineRunner()
|
||||
|
||||
async def run_pipeline():
|
||||
await runner.run(task)
|
||||
await run_pipeline_worker(task)
|
||||
|
||||
async def initialize_and_observe():
|
||||
await asyncio.sleep(0.01)
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ from pipecat.frames.frames import (
|
|||
LLMContextFrame,
|
||||
LLMFullResponseEndFrame,
|
||||
LLMFullResponseStartFrame,
|
||||
UserTurnInferenceCompletedFrame,
|
||||
)
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
|
|
@ -45,6 +46,7 @@ class TestUnregisteredFunctionCall:
|
|||
expected_down_frames=[
|
||||
LLMFullResponseStartFrame,
|
||||
FunctionCallsFromLLMInfoFrame,
|
||||
UserTurnInferenceCompletedFrame,
|
||||
FunctionCallsStartedFrame,
|
||||
LLMFullResponseEndFrame,
|
||||
FunctionCallInProgressFrame,
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@ from api.services.configuration.registry import (
|
|||
from api.services.gen_ai.embedding.openai_service import OpenAIEmbeddingService
|
||||
from api.services.pipecat.service_factory import (
|
||||
create_llm_service_from_provider,
|
||||
create_stt_service,
|
||||
create_tts_service,
|
||||
)
|
||||
from api.utils.url_security import validate_user_configured_service_url
|
||||
|
|
@ -214,6 +215,80 @@ def test_runtime_blocks_elevenlabs_local_tts_base_url_in_saas(monkeypatch):
|
|||
assert "localhost" in exc_info.value.detail
|
||||
|
||||
|
||||
def test_runtime_blocks_openai_stt_private_base_url_in_saas(monkeypatch):
|
||||
monkeypatch.setattr("api.utils.url_security.DEPLOYMENT_MODE", "saas")
|
||||
user_config = SimpleNamespace(
|
||||
stt=SimpleNamespace(
|
||||
provider=ServiceProviders.OPENAI.value,
|
||||
api_key="test-key",
|
||||
model="gpt-4o-transcribe",
|
||||
base_url="http://10.0.0.10/v1",
|
||||
)
|
||||
)
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
create_stt_service(user_config, audio_config=None)
|
||||
|
||||
assert exc_info.value.status_code == 400
|
||||
assert "public IP" in exc_info.value.detail
|
||||
|
||||
|
||||
def test_runtime_blocks_openai_stt_localhost_base_url_in_saas(monkeypatch):
|
||||
monkeypatch.setattr("api.utils.url_security.DEPLOYMENT_MODE", "saas")
|
||||
user_config = SimpleNamespace(
|
||||
stt=SimpleNamespace(
|
||||
provider=ServiceProviders.OPENAI.value,
|
||||
api_key="test-key",
|
||||
model="gpt-4o-transcribe",
|
||||
base_url="http://localhost:8000/v1",
|
||||
)
|
||||
)
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
create_stt_service(user_config, audio_config=None)
|
||||
|
||||
assert exc_info.value.status_code == 400
|
||||
assert "localhost" in exc_info.value.detail
|
||||
|
||||
|
||||
def test_runtime_blocks_openai_tts_private_base_url_in_saas(monkeypatch):
|
||||
monkeypatch.setattr("api.utils.url_security.DEPLOYMENT_MODE", "saas")
|
||||
user_config = SimpleNamespace(
|
||||
tts=SimpleNamespace(
|
||||
provider=ServiceProviders.OPENAI.value,
|
||||
api_key="test-key",
|
||||
model="gpt-4o-mini-tts",
|
||||
voice="alloy",
|
||||
base_url="http://10.0.0.10/v1",
|
||||
)
|
||||
)
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
create_tts_service(user_config, audio_config=None)
|
||||
|
||||
assert exc_info.value.status_code == 400
|
||||
assert "public IP" in exc_info.value.detail
|
||||
|
||||
|
||||
def test_runtime_blocks_openai_tts_localhost_base_url_in_saas(monkeypatch):
|
||||
monkeypatch.setattr("api.utils.url_security.DEPLOYMENT_MODE", "saas")
|
||||
user_config = SimpleNamespace(
|
||||
tts=SimpleNamespace(
|
||||
provider=ServiceProviders.OPENAI.value,
|
||||
api_key="test-key",
|
||||
model="gpt-4o-mini-tts",
|
||||
voice="alloy",
|
||||
base_url="http://localhost:8000/v1",
|
||||
)
|
||||
)
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
create_tts_service(user_config, audio_config=None)
|
||||
|
||||
assert exc_info.value.status_code == 400
|
||||
assert "localhost" in exc_info.value.detail
|
||||
|
||||
|
||||
def test_embedding_service_blocks_private_base_url_in_saas(monkeypatch):
|
||||
monkeypatch.setattr("api.utils.url_security.DEPLOYMENT_MODE", "saas")
|
||||
|
||||
|
|
|
|||
|
|
@ -23,8 +23,7 @@ from pipecat.frames.frames import (
|
|||
UserStoppedSpeakingFrame,
|
||||
)
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
from pipecat.pipeline.worker import PipelineParams, PipelineWorker
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.processors.aggregators.llm_response_universal import (
|
||||
LLMAssistantAggregatorParams,
|
||||
|
|
@ -43,6 +42,7 @@ from pipecat.turns.user_stop import ExternalUserTurnStopStrategy
|
|||
from pipecat.turns.user_turn_strategies import UserTurnStrategies
|
||||
from pipecat.utils.time import time_now_iso8601
|
||||
|
||||
from api.services.pipecat.worker_runner import run_pipeline_worker
|
||||
from api.services.workflow.pipecat_engine import PipecatEngine
|
||||
from api.services.workflow.workflow_graph import WorkflowGraph
|
||||
from pipecat.tests import MockLLMService, MockTTSService
|
||||
|
|
@ -100,7 +100,7 @@ async def create_pipeline_with_speech_injection(
|
|||
speeches: list[str],
|
||||
user_idle_timeout: float = 0.2,
|
||||
mock_audio_duration_ms: int = 400,
|
||||
) -> tuple[PipecatEngine, PipelineTask, object]:
|
||||
) -> tuple[PipecatEngine, PipelineWorker, object]:
|
||||
"""Create a pipeline with user speech injection and idle handling.
|
||||
|
||||
Sets up a realistic pipeline with:
|
||||
|
|
@ -194,7 +194,7 @@ async def create_pipeline_with_speech_injection(
|
|||
]
|
||||
)
|
||||
|
||||
task = PipelineTask(pipeline, params=PipelineParams(), enable_rtvi=False)
|
||||
task = PipelineWorker(pipeline, params=PipelineParams(), enable_rtvi=False)
|
||||
engine.set_task(task)
|
||||
|
||||
return engine, task, user_idle_handler
|
||||
|
|
@ -266,10 +266,9 @@ class TestUserIdleHandler:
|
|||
new_callable=AsyncMock,
|
||||
return_value="completed",
|
||||
):
|
||||
runner = PipelineRunner()
|
||||
|
||||
async def run_pipeline():
|
||||
await runner.run(task)
|
||||
await run_pipeline_worker(task)
|
||||
|
||||
async def initialize_engine():
|
||||
await asyncio.sleep(0.01)
|
||||
|
|
|
|||
|
|
@ -25,8 +25,7 @@ from pipecat.frames.frames import (
|
|||
UserStoppedSpeakingFrame,
|
||||
)
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
from pipecat.pipeline.worker import PipelineParams, PipelineWorker
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.processors.aggregators.llm_response_universal import (
|
||||
LLMAssistantAggregatorParams,
|
||||
|
|
@ -44,6 +43,7 @@ from pipecat.turns.user_mute import (
|
|||
from pipecat.turns.user_turn_strategies import ExternalUserTurnStrategies
|
||||
from pipecat.utils.time import time_now_iso8601
|
||||
|
||||
from api.services.pipecat.worker_runner import run_pipeline_worker
|
||||
from api.services.workflow.pipecat_engine import PipecatEngine
|
||||
from api.services.workflow.pipecat_engine_variable_extractor import (
|
||||
VariableExtractionManager,
|
||||
|
|
@ -125,7 +125,7 @@ async def create_engine_for_mute_test(
|
|||
PipecatEngine,
|
||||
MockTTSService,
|
||||
MockTransport,
|
||||
PipelineTask,
|
||||
PipelineWorker,
|
||||
LLMUserAggregator,
|
||||
BotSpeakingObserverProcessor,
|
||||
]:
|
||||
|
|
@ -196,7 +196,7 @@ async def create_engine_for_mute_test(
|
|||
]
|
||||
)
|
||||
|
||||
task = PipelineTask(pipeline, params=PipelineParams(), enable_rtvi=False)
|
||||
task = PipelineWorker(pipeline, params=PipelineParams(), enable_rtvi=False)
|
||||
engine.set_task(task)
|
||||
|
||||
return engine, tts, mock_transport, task, user_context_aggregator, observer
|
||||
|
|
@ -258,10 +258,9 @@ class TestUserMutingDuringBotSpeech:
|
|||
new_callable=AsyncMock,
|
||||
return_value={},
|
||||
):
|
||||
runner = PipelineRunner()
|
||||
|
||||
async def run_pipeline():
|
||||
await runner.run(task)
|
||||
await run_pipeline_worker(task)
|
||||
|
||||
async def run_test():
|
||||
await asyncio.sleep(0.01)
|
||||
|
|
@ -349,10 +348,9 @@ class TestUserMutingDuringBotSpeech:
|
|||
new_callable=AsyncMock,
|
||||
return_value={},
|
||||
):
|
||||
runner = PipelineRunner()
|
||||
|
||||
async def run_pipeline():
|
||||
await runner.run(task)
|
||||
await run_pipeline_worker(task)
|
||||
|
||||
async def run_test():
|
||||
await asyncio.sleep(0.01)
|
||||
|
|
@ -445,10 +443,9 @@ class TestUserMutingDuringBotSpeech:
|
|||
new_callable=AsyncMock,
|
||||
return_value={},
|
||||
):
|
||||
runner = PipelineRunner()
|
||||
|
||||
async def run_pipeline():
|
||||
await runner.run(task)
|
||||
await run_pipeline_worker(task)
|
||||
|
||||
async def run_test():
|
||||
await asyncio.sleep(0.01)
|
||||
|
|
|
|||
|
|
@ -17,8 +17,7 @@ from pipecat.frames.frames import (
|
|||
UserStoppedSpeakingFrame,
|
||||
)
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
from pipecat.pipeline.worker import PipelineParams, PipelineWorker
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.processors.aggregators.llm_response_universal import (
|
||||
LLMAssistantAggregatorParams,
|
||||
|
|
@ -36,6 +35,7 @@ from pipecat.turns.user_stop import (
|
|||
from pipecat.turns.user_turn_strategies import UserTurnStrategies
|
||||
from pipecat.utils.time import time_now_iso8601
|
||||
|
||||
from api.services.pipecat.worker_runner import run_pipeline_worker
|
||||
from pipecat.tests import MockLLMService
|
||||
|
||||
|
||||
|
|
@ -161,11 +161,10 @@ class TestVoicemailDetectorWithUserAggregator:
|
|||
]
|
||||
)
|
||||
|
||||
task = PipelineTask(pipeline, params=PipelineParams(), enable_rtvi=False)
|
||||
runner = PipelineRunner()
|
||||
task = PipelineWorker(pipeline, params=PipelineParams(), enable_rtvi=False)
|
||||
|
||||
async def run_pipeline():
|
||||
await runner.run(task)
|
||||
await run_pipeline_worker(task)
|
||||
|
||||
async def inject_frames():
|
||||
await asyncio.sleep(0.05)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue