chore: bump pipecat version and fix tests (#263)

* chore: bump pipecat version and fix tests

* chore: add github workflow to run tests

* fix: install reqirements.dev.txt in test script

* fix: fix api-test action

* feat: add integration test

* test: add integration tests

* test: add test for function call mute strategy
This commit is contained in:
Abhishek 2026-05-04 21:35:37 +05:30 committed by GitHub
parent d256c6005c
commit 0e12c41fc7
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
76 changed files with 1776 additions and 670 deletions

View file

@ -9,7 +9,7 @@ the root api/conftest.py. This module provides lightweight, non-DB fixtures:
from dataclasses import dataclass, field
from typing import Any, Dict, Optional
from unittest.mock import Mock
from unittest.mock import AsyncMock, Mock, patch
import pytest
@ -123,13 +123,28 @@ class MockToolModel:
@pytest.fixture
def mock_engine():
"""Create a mock PipecatEngine."""
"""Create a mock PipecatEngine.
Binds the real `_get_organization_id` method so the fetch + cache logic
runs against a patched `db_client.get_organization_id_by_workflow_run_id`
(returns org_id=1) for the duration of the fixture.
"""
from api.services.workflow.pipecat_engine import PipecatEngine
engine = Mock()
engine._workflow_run_id = 1
engine._call_context_vars = {"customer_name": "John Doe"}
engine._organization_id = None
engine._get_organization_id = PipecatEngine._get_organization_id.__get__(engine)
engine.llm = Mock()
engine.llm.register_function = Mock()
return engine
with patch(
"api.db:db_client.get_organization_id_by_workflow_run_id",
new_callable=AsyncMock,
return_value=1,
):
yield engine
@pytest.fixture

View file

View file

@ -0,0 +1,249 @@
"""Shared scaffolding for ``_run_pipeline`` integration tests.
Both ``test_run_pipeline.py`` and ``test_run_pipeline_text_greeting.py``
drive the real ``_run_pipeline`` end-to-end with the same set of external
boundaries patched out (STT/LLM/TTS factories, S3 recording fetcher,
PostHog publisher, ARQ enqueuer, real-time feedback observer). This
module centralises that scaffolding so each test only declares the bits
that differ its workflow definition and any preconfigured mocks.
Provided here:
- ``USER_CONFIGURATION``: a minimal user-configuration dict with valid
provider/model values; the keys themselves are dummy.
- ``PassthroughProcessor``: an STT stand-in that forwards frames as-is.
- ``NoopFeedbackObserver``: a ``RealtimeFeedbackObserver`` stand-in with
no WebSocket / clock-task side effects.
- ``patch_run_pipeline_externals``: ``contextmanager`` that applies the
full patch set and captures the constructed ``PipelineTask`` for the
caller. Optional ``llm`` / ``tts`` arguments inject preconfigured
mocks; otherwise blank ``MockLLMService`` / ``MockTTSService``
instances are constructed per-call.
- ``create_workflow_run_rows``: helper that creates the org / user /
user-configuration / workflow / workflow-run rows for an integration
test. Each test wires this through its own thin fixture so the
workflow definition stays local to the test.
"""
from contextlib import ExitStack, contextmanager
from typing import Any
from unittest.mock import AsyncMock, patch
from pipecat.frames.frames import Frame
from pipecat.observers.base_observer import BaseObserver
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
from api.db.models import OrganizationModel, UserModel
from api.enums import WorkflowRunMode
from pipecat.tests import MockLLMService, MockTTSService
USER_CONFIGURATION: dict[str, Any] = {
"is_realtime": False,
"stt": {
"provider": "deepgram",
"model": "nova-3",
"api_key": "test-key",
},
"tts": {
"provider": "cartesia",
"model": "sonic-2",
"api_key": "test-key",
"voice_id": "test-voice",
},
"llm": {
"provider": "openai",
"model": "gpt-4.1",
"api_key": "test-key",
},
}
class PassthroughProcessor(FrameProcessor):
"""Stand-in for the STT processor: forwards every frame untouched."""
async def process_frame(self, frame: Frame, direction: FrameDirection):
await super().process_frame(frame, direction)
await self.push_frame(frame, direction)
class NoopFeedbackObserver(BaseObserver):
"""Stand-in for ``RealtimeFeedbackObserver``: no WS / no clock task."""
def __init__(self, *_args, **_kwargs):
super().__init__()
async def cleanup(self):
pass
@contextmanager
def patch_run_pipeline_externals(
captured_task: list,
*,
llm: MockLLMService | None = None,
tts: MockTTSService | None = None,
):
"""Patch the externally-talking pieces of ``_run_pipeline`` and capture
the constructed ``PipelineTask`` so tests can drive it from outside.
Args:
captured_task: A list the constructed ``PipelineTask`` is appended
to. Tests read ``captured_task[0]`` to get a handle on the task
(to wait on its start event, queue frames, cancel it, etc.).
llm: Optional pre-built ``MockLLMService``. When given, every call
to ``create_llm_service`` returns this same instance (so the
test can inspect its ``mock_steps`` / ``current_step``).
When ``None``, a blank ``MockLLMService`` is constructed.
tts: Optional pre-built ``MockTTSService``. Same semantics as
``llm``: pass an instance to share state with the test, or
``None`` to use a fresh one.
"""
from api.services.pipecat import pipeline_builder as _pipeline_builder
original_create_task = _pipeline_builder.create_pipeline_task
def _capture_task(*args, **kwargs):
task = original_create_task(*args, **kwargs)
captured_task.append(task)
return task
def _llm_factory(*_args, **_kwargs):
return llm if llm is not None else MockLLMService(api_key="test")
def _tts_factory(*_args, **_kwargs):
return tts if tts is not None else MockTTSService()
with ExitStack() as stack:
# Replace service factories with in-process test doubles.
stack.enter_context(
patch(
"api.services.pipecat.run_pipeline.create_llm_service",
_llm_factory,
)
)
stack.enter_context(
patch(
"api.services.pipecat.run_pipeline.create_stt_service",
lambda *_args, **_kwargs: PassthroughProcessor(),
)
)
stack.enter_context(
patch(
"api.services.pipecat.run_pipeline.create_tts_service",
_tts_factory,
)
)
# S3 — the recording fetcher would otherwise resolve org-scoped recordings.
stack.enter_context(
patch(
"api.services.pipecat.run_pipeline.create_recording_audio_fetcher",
lambda *_args, **_kwargs: AsyncMock(return_value=None),
)
)
# External fire-and-forget integrations.
stack.enter_context(
patch(
"api.services.pipecat.event_handlers._capture_call_event",
new=AsyncMock(),
)
)
stack.enter_context(
patch(
"api.services.pipecat.event_handlers.enqueue_job",
new=AsyncMock(),
)
)
# Skip the real-time feedback observer (WebSocket / log-buffer streaming).
stack.enter_context(
patch(
"api.services.pipecat.run_pipeline.RealtimeFeedbackObserver",
NoopFeedbackObserver,
)
)
# Disposition mapper would otherwise call out to the LLM.
stack.enter_context(
patch(
"api.services.workflow.pipecat_engine.apply_disposition_mapping",
new_callable=AsyncMock,
return_value="completed",
)
)
# Capture the PipelineTask so the test can drive it from outside.
stack.enter_context(
patch(
"api.services.pipecat.run_pipeline.create_pipeline_task",
side_effect=_capture_task,
)
)
yield
async def create_workflow_run_rows(
db_session,
async_session,
*,
workflow_definition: dict,
name_prefix: str,
provider_id_suffix: str,
):
"""Create org / user / user-configuration / workflow / workflow-run rows
in the test database for a ``_run_pipeline`` integration test.
Args:
db_session: The patched ``DBClient`` from the ``db_session`` fixture.
async_session: The raw ``AsyncSession`` from the ``async_session``
fixture (used to add the org/user rows directly).
workflow_definition: The dict that becomes
``WorkflowModel.workflow_definition`` and the V1 workflow_json.
name_prefix: Used to build human-readable workflow / run names.
provider_id_suffix: Used to generate unique ``provider_id`` values
for the org and user rows so concurrent or repeated test runs
don't collide.
Returns:
Tuple of (workflow_run, user, workflow).
"""
from api.schemas.user_configuration import UserConfiguration
org = OrganizationModel(provider_id=f"test-org-{provider_id_suffix}")
async_session.add(org)
await async_session.flush()
user = UserModel(
provider_id=f"test-user-{provider_id_suffix}",
selected_organization_id=org.id,
)
async_session.add(user)
await async_session.flush()
await db_session.update_user_configuration(
user_id=user.id,
configuration=UserConfiguration.model_validate(USER_CONFIGURATION),
)
workflow = await db_session.create_workflow(
name=f"{name_prefix} Workflow",
workflow_definition=workflow_definition,
user_id=user.id,
organization_id=org.id,
)
workflow_run = await db_session.create_workflow_run(
name=f"{name_prefix} Run",
workflow_id=workflow.id,
mode=WorkflowRunMode.SMALLWEBRTC.value,
user_id=user.id,
)
return workflow_run, user, workflow
# Keep the module's public surface explicit so ``import *`` doesn't grab
# transitive imports.
__all__ = [
"USER_CONFIGURATION",
"PassthroughProcessor",
"NoopFeedbackObserver",
"patch_run_pipeline_externals",
"create_workflow_run_rows",
]

View file

@ -0,0 +1,134 @@
"""Integration tests for ``api.services.pipecat.run_pipeline._run_pipeline``.
Drives the actual ``_run_pipeline`` against the test database with real
DB rows (organization, user, user configuration, workflow, workflow run)
and pipecat's real ``MockTransport`` / ``Pipeline`` / ``PipelineTask``.
The only patches are for things that talk to genuinely external systems;
those are applied via ``patch_run_pipeline_externals`` from the shared
helpers module.
Verifies that the wiring done by ``_run_pipeline`` (in particular
``register_event_handlers``) produces the right behaviour end-to-end:
``maybe_trigger_initial_response`` fires (``engine.set_node`` runs), and
on shutdown the workflow run is persisted with the expected state,
completion flag, and ``gathered_context`` entries.
"""
import asyncio
import pytest
from pipecat.tests.mock_transport import MockTransport
from pipecat.transports.base_transport import TransportParams
from api.enums import WorkflowRunMode, WorkflowRunState
from api.services.pipecat.audio_config import create_audio_config
from api.services.pipecat.run_pipeline import _run_pipeline
from api.tests.integrations._run_pipeline_helpers import (
create_workflow_run_rows,
patch_run_pipeline_externals,
)
WORKFLOW_DEFINITION = {
"nodes": [
{
"id": "start",
"type": "startCall",
"position": {"x": 0, "y": 0},
"data": {
"name": "Start",
"prompt": "You are a helpful assistant. Greet the user briefly.",
"is_start": True,
"allow_interrupt": False,
"add_global_prompt": False,
},
},
{
"id": "end",
"type": "endCall",
"position": {"x": 0, "y": 200},
"data": {
"name": "End",
"prompt": "End the call politely.",
"is_end": True,
"allow_interrupt": False,
"add_global_prompt": False,
},
},
],
"edges": [
{
"id": "start-end",
"source": "start",
"target": "end",
"data": {"label": "End", "condition": "When the user wants to end."},
}
],
}
@pytest.fixture
async def workflow_run_setup(db_session, async_session):
"""Create org/user/user_configuration/workflow/workflow_run rows in the
test database. Returns (workflow_run, user, workflow)."""
return await create_workflow_run_rows(
db_session,
async_session,
workflow_definition=WORKFLOW_DEFINITION,
name_prefix="Event Handler Integration",
provider_id_suffix="event-handlers",
)
@pytest.mark.asyncio
async def test_run_pipeline_fires_initial_response_and_completes_run(
workflow_run_setup, db_session
):
"""End-to-end: _run_pipeline boots, register_event_handlers wires up,
on_pipeline_started + on_client_connected both fire, the initial
response is triggered (set_node), and on_pipeline_finished updates
the workflow_run row to COMPLETED."""
workflow_run, user, workflow = workflow_run_setup
transport = MockTransport(
TransportParams(audio_in_enabled=True, audio_out_enabled=True)
)
captured_task: list = []
audio_config = create_audio_config(WorkflowRunMode.SMALLWEBRTC.value)
with patch_run_pipeline_externals(captured_task):
run_coro = _run_pipeline(
transport=transport,
workflow_id=workflow.id,
workflow_run_id=workflow_run.id,
user_id=user.id,
audio_config=audio_config,
user_provider_id=user.provider_id,
)
run_task = asyncio.create_task(run_coro)
# Wait until create_pipeline_task is invoked. Surface any
# exception from _run_pipeline immediately rather than swallowing
# it during the wait loop.
for _ in range(60):
if captured_task or run_task.done():
break
await asyncio.sleep(0.05)
if run_task.done() and not captured_task:
run_task.result() # re-raise the failure
assert captured_task, "create_pipeline_task was never invoked"
pipeline_task = captured_task[0]
await asyncio.wait_for(pipeline_task._pipeline_start_event.wait(), timeout=3.0)
# Let the initial response handler (set_node, queue LLMContextFrame)
# complete before tearing things down.
await asyncio.sleep(0.1)
await pipeline_task.cancel()
await asyncio.wait_for(run_task, timeout=5.0)
# Verify the run was completed end-to-end via the real on_pipeline_finished
# handler — DB side effects, not mock assertions.
refreshed = await db_session.get_workflow_run_by_id(workflow_run.id)
assert refreshed.is_completed is True
assert refreshed.state == WorkflowRunState.COMPLETED.value
# set_node("start") populates "nodes_visited" via _gathered_context, and
# on_pipeline_finished merges call_tags into gathered_context.
assert "Start" in refreshed.gathered_context.get("nodes_visited", [])
assert "call_tags" in refreshed.gathered_context

View file

@ -0,0 +1,289 @@
"""Integration test for the text-greeting flow through ``_run_pipeline``.
Drives the full pipeline produced by ``_run_pipeline`` against the test
database with a workflow whose start node has a text greeting configured.
The flow under test:
1. ``maybe_trigger_initial_response`` (in ``event_handlers.py``) sees a
text greeting and queues ``TTSSpeakFrame(greeting)``.
2. ``MockTTSService`` synthesises audio for the greeting; the real
``MediaSender`` machinery in ``MockOutputTransport`` emits
``BotStartedSpeakingFrame`` and ``BotStoppedSpeakingFrame``.
3. The TTS service emits an ``LLMAssistantPushAggregationFrame`` after
``TTSStoppedFrame``, so the greeting is appended to the assistant
context by ``LLMAssistantAggregator``.
4. We then push a ``TranscriptionFrame`` into the pipeline. After the
user-turn-stop timeout, ``LLMUserAggregator`` pushes a context frame
to the LLM, ``MockLLMService`` returns an ``end_call`` tool call, and
the engine's transition function moves to the end node and calls
``end_call_with_reason``.
5. ``on_pipeline_finished`` records the run as COMPLETED.
External boundaries are patched via ``patch_run_pipeline_externals``
from the shared helpers module. Preconfigured ``MockLLMService`` /
``MockTTSService`` instances are passed in so the end_call response is
deterministic and the synthesised audio length is short.
"""
import asyncio
import pytest
from pipecat.frames.frames import TranscriptionFrame
from pipecat.tests.mock_transport import MockTransport
from pipecat.transports.base_transport import TransportParams
from pipecat.utils.time import time_now_iso8601
from api.enums import WorkflowRunMode, WorkflowRunState
from api.services.pipecat.audio_config import create_audio_config
from api.services.pipecat.run_pipeline import _run_pipeline
from api.tests.integrations._run_pipeline_helpers import (
create_workflow_run_rows,
patch_run_pipeline_externals,
)
from pipecat.tests import MockLLMService, MockTTSService
GREETING_TEXT = (
"Thanks for calling Happy Feet, this is Sarah. How can I help you today?"
)
WORKFLOW_DEFINITION = {
"nodes": [
{
"id": "start",
"type": "startCall",
"position": {"x": 0, "y": 0},
"data": {
"name": "Start",
"prompt": "You are Sarah. Help the caller and end the call when they ask.",
"is_start": True,
"allow_interrupt": False,
"add_global_prompt": False,
"greeting": GREETING_TEXT,
"greeting_type": "text",
},
},
{
"id": "end",
"type": "endCall",
"position": {"x": 0, "y": 200},
"data": {
"name": "End",
"prompt": "End the call politely.",
"is_end": True,
"allow_interrupt": False,
"add_global_prompt": False,
},
},
],
"edges": [
{
"id": "start-end",
"source": "start",
"target": "end",
"data": {"label": "End Call", "condition": "When the user wants to end."},
}
],
}
# Hard cap on the entire test. Without this, a hung pipeline would keep the
# pytest worker alive indefinitely (the harness has no pytest-timeout plugin).
TEST_HARD_TIMEOUT_SECONDS = 25.0
@pytest.fixture
async def workflow_run_setup(db_session, async_session):
"""Create org/user/user_configuration/workflow/workflow_run rows. The
workflow's start node is configured with a text greeting."""
return await create_workflow_run_rows(
db_session,
async_session,
workflow_definition=WORKFLOW_DEFINITION,
name_prefix="Text Greeting Integration",
provider_id_suffix="text-greeting",
)
def _greeting_in_assistant_context(context) -> bool:
"""Return True if the greeting text has been appended to the assistant context."""
for message in context.get_messages():
if isinstance(message, dict) and message.get("role") == "assistant":
content = message.get("content") or ""
if GREETING_TEXT in content:
return True
return False
def _find_processor_by_class_name(pipeline_task, class_name: str):
"""Walk every processor reachable from the task's pipeline (including nested
sub-pipelines) and return the first one whose class name matches."""
visited: set[int] = set()
stack = [pipeline_task._pipeline]
while stack:
processor = stack.pop()
if id(processor) in visited:
continue
visited.add(id(processor))
if processor.__class__.__name__ == class_name:
return processor
sub = getattr(processor, "_processors", None)
if sub:
stack.extend(sub)
return None
async def _wait_for(predicate, *, timeout: float, interval: float = 0.05) -> bool:
"""Poll ``predicate`` (sync callable returning bool) until it returns True
or the timeout elapses. Returns the final predicate value."""
deadline = asyncio.get_event_loop().time() + timeout
while asyncio.get_event_loop().time() < deadline:
if predicate():
return True
await asyncio.sleep(interval)
return predicate()
async def _run_test_body(workflow_run_setup, db_session) -> None:
workflow_run, user, workflow = workflow_run_setup
# Prepare the LLM with one step: the end_call function call.
# Edge label "End Call" maps to function name "end_call".
end_call_chunks = MockLLMService.create_function_call_chunks(
function_name="end_call",
arguments={},
tool_call_id="call_end_1",
)
llm = MockLLMService(mock_steps=[end_call_chunks], chunk_delay=0.001)
# Short audio greeting so the bot finishes speaking quickly in tests.
tts = MockTTSService(mock_audio_duration_ms=50, frame_delay=0)
transport = MockTransport(
TransportParams(audio_in_enabled=True, audio_out_enabled=True)
)
captured_task: list = []
audio_config = create_audio_config(WorkflowRunMode.SMALLWEBRTC.value)
pipeline_task = None
try:
with patch_run_pipeline_externals(captured_task, llm=llm, tts=tts):
run_coro = _run_pipeline(
transport=transport,
workflow_id=workflow.id,
workflow_run_id=workflow_run.id,
user_id=user.id,
audio_config=audio_config,
user_provider_id=user.provider_id,
)
run_task = asyncio.create_task(run_coro)
for _ in range(60):
if captured_task or run_task.done():
break
await asyncio.sleep(0.05)
if run_task.done() and not captured_task:
run_task.result()
assert captured_task, "create_pipeline_task was never invoked"
pipeline_task = captured_task[0]
await asyncio.wait_for(
pipeline_task._pipeline_start_event.wait(), timeout=3.0
)
# Locate the assistant aggregator's LLM context (downstream of TTS).
# The PipelineTask wraps the user's pipeline inside another Pipeline,
# so we walk the tree recursively.
assistant_aggregator = _find_processor_by_class_name(
pipeline_task, "LLMAssistantAggregator"
)
assert assistant_aggregator is not None, (
"LLMAssistantAggregator not found in pipeline"
)
context = assistant_aggregator.context
# Wait for the greeting to be appended to the assistant context. The
# TTSSpeakFrame -> audio frames -> BotStoppedSpeaking -> assistant
# aggregation push chain runs through the real pipeline.
appeared = await _wait_for(
lambda: _greeting_in_assistant_context(context), timeout=5.0
)
assert appeared, (
"Greeting was not appended to the assistant context. "
f"Messages: {context.get_messages()}"
)
# The LLM must not have been invoked yet — the greeting bypasses
# the LLM entirely (goes straight to TTS via TTSSpeakFrame).
assert llm.get_current_step() == 0, (
f"LLM should not have run yet; current_step={llm.get_current_step()}"
)
# Now simulate the user replying. SpeechTimeoutUserTurnStopStrategy
# (default 0.6s) ends the user turn, which triggers an LLM run;
# the LLM returns end_call; the transition function moves to the
# end node and ends the call.
await pipeline_task.queue_frame(
TranscriptionFrame(
text="I want to end the call now please.",
user_id="test-user",
timestamp=time_now_iso8601(),
)
)
# Wait for the run to complete.
await asyncio.wait_for(run_task, timeout=10.0)
# Outside the patch ctx so the assertions exercise real DB state.
# The first LLM run produces the end_call; the engine then transitions
# to the End node and triggers a second generation (which is empty —
# mock_steps[1] is unset). What matters is that at least one run
# happened, i.e. the user transcript actually drove the LLM.
assert llm.get_current_step() >= 1, (
f"Expected at least one LLM generation; got step={llm.get_current_step()}"
)
refreshed = await db_session.get_workflow_run_by_id(workflow_run.id)
assert refreshed.is_completed is True
assert refreshed.state == WorkflowRunState.COMPLETED.value
nodes_visited = refreshed.gathered_context.get("nodes_visited", [])
assert "Start" in nodes_visited
assert "End" in nodes_visited
finally:
# Best-effort cleanup so a partially-run pipeline doesn't leak tasks
# past the test boundary.
if pipeline_task is not None and not pipeline_task.has_finished():
try:
await asyncio.wait_for(pipeline_task.cancel(), timeout=3.0)
except Exception:
pass
@pytest.mark.asyncio
async def test_text_greeting_speaks_then_user_transcript_triggers_end_call(
workflow_run_setup, db_session
):
"""End-to-end:
- ``maybe_trigger_initial_response`` queues ``TTSSpeakFrame`` for the
start-node text greeting.
- ``MockTTSService`` synthesises audio; ``MockOutputTransport`` emits
bot speaking events; the assistant aggregator appends the greeting
to the context after the TTS turn ends.
- We push a ``TranscriptionFrame`` into the pipeline. After the user
turn stop timeout, ``MockLLMService`` returns an ``end_call`` tool
call which transitions to the end node and ends the run.
The whole body is bounded by ``TEST_HARD_TIMEOUT_SECONDS`` so a hung
pipeline fails the test rather than wedging the test runner.
"""
try:
await asyncio.wait_for(
_run_test_body(workflow_run_setup, db_session),
timeout=TEST_HARD_TIMEOUT_SECONDS,
)
except asyncio.TimeoutError as e:
raise AssertionError(
f"Test exceeded hard timeout of {TEST_HARD_TIMEOUT_SECONDS}s — "
"pipeline likely hung. Check earlier debug logs for the last frame "
"to reach the pipeline."
) from e

View file

@ -12,12 +12,6 @@ from typing import Any, Dict
from unittest.mock import AsyncMock, Mock, patch
import pytest
from api.services.workflow.pipecat_engine_custom_tools import get_function_schema
from api.services.workflow.tools.custom_tool import (
execute_http_tool,
tool_to_function_schema,
)
from pipecat.adapters.schemas.tools_schema import ToolsSchema
from pipecat.frames.frames import (
FunctionCallInProgressFrame,
@ -31,6 +25,12 @@ from pipecat.frames.frames import (
from pipecat.pipeline.pipeline import Pipeline
from pipecat.processors.aggregators.llm_context import LLMContext
from pipecat.services.llm_service import FunctionCallParams
from api.services.workflow.pipecat_engine_custom_tools import get_function_schema
from api.services.workflow.tools.custom_tool import (
execute_http_tool,
tool_to_function_schema,
)
from pipecat.tests import MockLLMService, run_test
@ -720,13 +720,19 @@ class TestCustomToolManagerUnit:
@pytest.mark.asyncio
async def test_get_tool_schemas_returns_correct_format(self):
"""Test that get_tool_schemas returns FunctionSchema objects."""
from api.services.workflow.pipecat_engine_custom_tools import CustomToolManager
from pipecat.adapters.schemas.function_schema import FunctionSchema
# Create a mock engine
from api.services.workflow.pipecat_engine import PipecatEngine
from api.services.workflow.pipecat_engine_custom_tools import CustomToolManager
mock_engine = Mock()
mock_engine._workflow_run_id = 1
mock_engine._call_context_vars = {}
mock_engine._organization_id = None
mock_engine._get_organization_id = PipecatEngine._get_organization_id.__get__(
mock_engine
)
manager = CustomToolManager(mock_engine)
@ -754,29 +760,31 @@ class TestCustomToolManagerUnit:
},
)
with patch(
"api.services.workflow.pipecat_engine_custom_tools.get_organization_id_from_workflow_run"
) as mock_get_org:
mock_get_org.return_value = 1
with patch(
with (
patch(
"api.services.workflow.pipecat_engine_custom_tools.db_client"
) as mock_db:
mock_db.get_tools_by_uuids = AsyncMock(return_value=[mock_tool])
) as mock_db,
patch(
"api.db:db_client.get_organization_id_by_workflow_run_id",
new_callable=AsyncMock,
return_value=1,
),
):
mock_db.get_tools_by_uuids = AsyncMock(return_value=[mock_tool])
schemas = await manager.get_tool_schemas(["uuid-1"])
schemas = await manager.get_tool_schemas(["uuid-1"])
assert len(schemas) == 1
schema = schemas[0]
assert len(schemas) == 1
schema = schemas[0]
# Schema should be a FunctionSchema object
assert isinstance(schema, FunctionSchema)
# Schema should be a FunctionSchema object
assert isinstance(schema, FunctionSchema)
# FunctionSchema should have correct attributes
assert schema.name == "test_tool"
assert "param1" in schema.properties
assert schema.properties["param1"]["type"] == "string"
assert "param1" in schema.required
# FunctionSchema should have correct attributes
assert schema.name == "test_tool"
assert "param1" in schema.properties
assert schema.properties["param1"]["type"] == "string"
assert "param1" in schema.required
@pytest.mark.asyncio
async def test_register_handlers_creates_working_handler(self):
@ -792,9 +800,15 @@ class TestCustomToolManagerUnit:
mock_llm.register_function = capture_register
from api.services.workflow.pipecat_engine import PipecatEngine
mock_engine = Mock()
mock_engine._workflow_run_id = 1
mock_engine._call_context_vars = {}
mock_engine._organization_id = None
mock_engine._get_organization_id = PipecatEngine._get_organization_id.__get__(
mock_engine
)
mock_engine.llm = mock_llm
manager = CustomToolManager(mock_engine)
@ -815,20 +829,22 @@ class TestCustomToolManagerUnit:
},
)
with patch(
"api.services.workflow.pipecat_engine_custom_tools.get_organization_id_from_workflow_run"
) as mock_get_org:
mock_get_org.return_value = 1
with patch(
with (
patch(
"api.services.workflow.pipecat_engine_custom_tools.db_client"
) as mock_db:
mock_db.get_tools_by_uuids = AsyncMock(return_value=[mock_tool])
) as mock_db,
patch(
"api.db:db_client.get_organization_id_by_workflow_run_id",
new_callable=AsyncMock,
return_value=1,
),
):
mock_db.get_tools_by_uuids = AsyncMock(return_value=[mock_tool])
await manager.register_handlers(["uuid-1"])
await manager.register_handlers(["uuid-1"])
# Verify handler was registered
assert "api_call" in registered_handlers
# Verify handler was registered
assert "api_call" in registered_handlers
# Now test that the handler works
handler = registered_handlers["api_call"]

View file

@ -9,15 +9,15 @@ This module tests the full flow of:
from unittest.mock import AsyncMock, patch
import pytest
from pipecat.adapters.schemas.function_schema import FunctionSchema
from pipecat.adapters.schemas.tools_schema import ToolsSchema
from pipecat.processors.aggregators.llm_context import LLMContext
from api.services.workflow.pipecat_engine_custom_tools import (
CustomToolManager,
get_function_schema,
)
from api.tests.conftest import MockToolModel
from pipecat.adapters.schemas.function_schema import FunctionSchema
from pipecat.adapters.schemas.tools_schema import ToolsSchema
from pipecat.processors.aggregators.llm_context import LLMContext
def _update_llm_context(context, system_message, functions):
@ -45,70 +45,65 @@ class TestCustomToolManagerContextIntegration:
manager = CustomToolManager(mock_engine)
with patch(
"api.services.workflow.pipecat_engine_custom_tools.get_organization_id_from_workflow_run"
) as mock_get_org:
mock_get_org.return_value = 1
"api.services.workflow.pipecat_engine_custom_tools.db_client"
) as mock_db:
mock_db.get_tools_by_uuids = AsyncMock(return_value=sample_tools)
with patch(
"api.services.workflow.pipecat_engine_custom_tools.db_client"
) as mock_db:
mock_db.get_tools_by_uuids = AsyncMock(return_value=sample_tools)
# Get tool schemas via CustomToolManager - now returns FunctionSchema objects
tool_uuids = ["weather-uuid-123", "booking-uuid-456", "lookup-uuid-789"]
schemas = await manager.get_tool_schemas(tool_uuids)
# Get tool schemas via CustomToolManager - now returns FunctionSchema objects
tool_uuids = ["weather-uuid-123", "booking-uuid-456", "lookup-uuid-789"]
schemas = await manager.get_tool_schemas(tool_uuids)
# Verify schemas were returned as FunctionSchema objects
assert len(schemas) == 3
assert all(isinstance(s, FunctionSchema) for s in schemas)
# Verify schemas were returned as FunctionSchema objects
assert len(schemas) == 3
assert all(isinstance(s, FunctionSchema) for s in schemas)
# Create context with conversation history
context = LLMContext()
context.set_messages(
[
{"role": "system", "content": "You are a helpful assistant."},
{
"role": "user",
"content": "I need to check the weather and book an appointment.",
},
{
"role": "assistant",
"content": "I can help with both. Where would you like to check the weather?",
},
{"role": "user", "content": "San Francisco"},
]
)
# Create context with conversation history
context = LLMContext()
context.set_messages(
[
{"role": "system", "content": "You are a helpful assistant."},
{
"role": "user",
"content": "I need to check the weather and book an appointment.",
},
{
"role": "assistant",
"content": "I can help with both. Where would you like to check the weather?",
},
{"role": "user", "content": "San Francisco"},
]
)
# Update context with new system message and tools
# Now we can pass schemas directly since they're FunctionSchema objects
new_system = {
"role": "system",
"content": "You are a scheduling assistant with access to weather and booking tools.",
}
_update_llm_context(context, new_system, schemas)
# Update context with new system message and tools
# Now we can pass schemas directly since they're FunctionSchema objects
new_system = {
"role": "system",
"content": "You are a scheduling assistant with access to weather and booking tools.",
}
_update_llm_context(context, new_system, schemas)
# Verify context was updated correctly
messages = context.messages
assert len(messages) == 4
assert (
messages[0]["content"]
== "You are a scheduling assistant with access to weather and booking tools."
)
assert messages[1]["role"] == "user"
assert messages[3]["content"] == "San Francisco"
# Verify context was updated correctly
messages = context.messages
assert len(messages) == 4
assert (
messages[0]["content"]
== "You are a scheduling assistant with access to weather and booking tools."
)
assert messages[1]["role"] == "user"
assert messages[3]["content"] == "San Francisco"
# Verify tools were set
tools = context.tools
assert tools is not None
assert len(tools.standard_tools) == 3
# Verify tools were set
tools = context.tools
assert tools is not None
assert len(tools.standard_tools) == 3
# Verify tool names
tool_names = {t.name for t in tools.standard_tools}
assert tool_names == {
"get_weather",
"book_appointment",
"customer_lookup",
}
# Verify tool names
tool_names = {t.name for t in tools.standard_tools}
assert tool_names == {
"get_weather",
"book_appointment",
"customer_lookup",
}
@pytest.mark.asyncio
async def test_tool_schemas_have_correct_properties(
@ -118,39 +113,32 @@ class TestCustomToolManagerContextIntegration:
manager = CustomToolManager(mock_engine)
with patch(
"api.services.workflow.pipecat_engine_custom_tools.get_organization_id_from_workflow_run"
) as mock_get_org:
mock_get_org.return_value = 1
"api.services.workflow.pipecat_engine_custom_tools.db_client"
) as mock_db:
mock_db.get_tools_by_uuids = AsyncMock(return_value=sample_tools)
with patch(
"api.services.workflow.pipecat_engine_custom_tools.db_client"
) as mock_db:
mock_db.get_tools_by_uuids = AsyncMock(return_value=sample_tools)
schemas = await manager.get_tool_schemas(
["weather-uuid-123", "booking-uuid-456"]
)
schemas = await manager.get_tool_schemas(
["weather-uuid-123", "booking-uuid-456"]
)
# Find the booking schema - now using FunctionSchema attributes
booking_schema = next(s for s in schemas if s.name == "book_appointment")
# Find the booking schema - now using FunctionSchema attributes
booking_schema = next(
s for s in schemas if s.name == "book_appointment"
)
# Verify parameter properties
assert "customer_name" in booking_schema.properties
assert "date" in booking_schema.properties
assert "time" in booking_schema.properties
assert "notes" in booking_schema.properties
# Verify parameter properties
assert "customer_name" in booking_schema.properties
assert "date" in booking_schema.properties
assert "time" in booking_schema.properties
assert "notes" in booking_schema.properties
# Verify types
assert booking_schema.properties["customer_name"]["type"] == "string"
assert booking_schema.properties["date"]["type"] == "string"
# Verify types
assert booking_schema.properties["customer_name"]["type"] == "string"
assert booking_schema.properties["date"]["type"] == "string"
# Verify required
assert "customer_name" in booking_schema.required
assert "date" in booking_schema.required
assert "time" in booking_schema.required
assert "notes" not in booking_schema.required
# Verify required
assert "customer_name" in booking_schema.required
assert "date" in booking_schema.required
assert "time" in booking_schema.required
assert "notes" not in booking_schema.required
@pytest.mark.asyncio
async def test_context_update_with_builtin_and_custom_tools(
@ -160,67 +148,62 @@ class TestCustomToolManagerContextIntegration:
manager = CustomToolManager(mock_engine)
with patch(
"api.services.workflow.pipecat_engine_custom_tools.get_organization_id_from_workflow_run"
) as mock_get_org:
mock_get_org.return_value = 1
"api.services.workflow.pipecat_engine_custom_tools.db_client"
) as mock_db:
mock_db.get_tools_by_uuids = AsyncMock(
return_value=[sample_tools[0]]
) # Just weather
with patch(
"api.services.workflow.pipecat_engine_custom_tools.db_client"
) as mock_db:
mock_db.get_tools_by_uuids = AsyncMock(
return_value=[sample_tools[0]]
) # Just weather
# Get custom tool schemas - returns FunctionSchema objects
custom_schemas = await manager.get_tool_schemas(["weather-uuid-123"])
# Get custom tool schemas - returns FunctionSchema objects
custom_schemas = await manager.get_tool_schemas(["weather-uuid-123"])
# Create built-in function schemas (like calculator, timezone)
builtin_functions = [
get_function_schema(
"safe_calculator",
"Evaluate a mathematical expression safely",
properties={
"expression": {
"type": "string",
"description": "Mathematical expression to evaluate",
}
},
required=["expression"],
),
get_function_schema(
"get_current_time",
"Get the current time in a timezone",
properties={
"timezone": {
"type": "string",
"description": "Timezone name (e.g., America/New_York)",
}
},
required=["timezone"],
),
]
# Create built-in function schemas (like calculator, timezone)
builtin_functions = [
get_function_schema(
"safe_calculator",
"Evaluate a mathematical expression safely",
properties={
"expression": {
"type": "string",
"description": "Mathematical expression to evaluate",
}
},
required=["expression"],
),
get_function_schema(
"get_current_time",
"Get the current time in a timezone",
properties={
"timezone": {
"type": "string",
"description": "Timezone name (e.g., America/New_York)",
}
},
required=["timezone"],
),
]
# Combine built-in and custom functions - both are FunctionSchema objects
all_functions = builtin_functions + custom_schemas
# Combine built-in and custom functions - both are FunctionSchema objects
all_functions = builtin_functions + custom_schemas
# Update context
context = LLMContext()
context.set_messages([{"role": "system", "content": "Old prompt"}])
# Update context
context = LLMContext()
context.set_messages([{"role": "system", "content": "Old prompt"}])
new_system = {
"role": "system",
"content": "Assistant with calculator and weather tools",
}
_update_llm_context(context, new_system, all_functions)
new_system = {
"role": "system",
"content": "Assistant with calculator and weather tools",
}
_update_llm_context(context, new_system, all_functions)
# Verify all tools are present
tools = context.tools
assert len(tools.standard_tools) == 3
# Verify all tools are present
tools = context.tools
assert len(tools.standard_tools) == 3
tool_names = {t.name for t in tools.standard_tools}
assert "safe_calculator" in tool_names
assert "get_current_time" in tool_names
assert "get_weather" in tool_names
tool_names = {t.name for t in tools.standard_tools}
assert "safe_calculator" in tool_names
assert "get_current_time" in tool_names
assert "get_weather" in tool_names
@pytest.mark.asyncio
async def test_context_preserves_function_call_history(
@ -230,65 +213,60 @@ class TestCustomToolManagerContextIntegration:
manager = CustomToolManager(mock_engine)
with patch(
"api.services.workflow.pipecat_engine_custom_tools.get_organization_id_from_workflow_run"
) as mock_get_org:
mock_get_org.return_value = 1
"api.services.workflow.pipecat_engine_custom_tools.db_client"
) as mock_db:
mock_db.get_tools_by_uuids = AsyncMock(return_value=[sample_tools[0]])
with patch(
"api.services.workflow.pipecat_engine_custom_tools.db_client"
) as mock_db:
mock_db.get_tools_by_uuids = AsyncMock(return_value=[sample_tools[0]])
# Get schemas - returns FunctionSchema objects
schemas = await manager.get_tool_schemas(["weather-uuid-123"])
# Get schemas - returns FunctionSchema objects
schemas = await manager.get_tool_schemas(["weather-uuid-123"])
# Create context with function call history
context = LLMContext()
context.set_messages(
[
{"role": "system", "content": "Old system prompt"},
{"role": "user", "content": "What's the weather in NYC?"},
{
"role": "assistant",
"content": None,
"tool_calls": [
{
"id": "call_123",
"type": "function",
"function": {
"name": "get_weather",
"arguments": '{"location": "New York, NY"}',
},
}
],
},
{
"role": "tool",
"tool_call_id": "call_123",
"content": '{"temperature": 72, "condition": "sunny"}',
},
{
"role": "assistant",
"content": "The weather in NYC is 72°F and sunny!",
},
]
)
# Create context with function call history
context = LLMContext()
context.set_messages(
[
{"role": "system", "content": "Old system prompt"},
{"role": "user", "content": "What's the weather in NYC?"},
{
"role": "assistant",
"content": None,
"tool_calls": [
{
"id": "call_123",
"type": "function",
"function": {
"name": "get_weather",
"arguments": '{"location": "New York, NY"}',
},
}
],
},
{
"role": "tool",
"tool_call_id": "call_123",
"content": '{"temperature": 72, "condition": "sunny"}',
},
{
"role": "assistant",
"content": "The weather in NYC is 72°F and sunny!",
},
]
)
new_system = {"role": "system", "content": "Updated weather assistant"}
_update_llm_context(context, new_system, schemas)
new_system = {"role": "system", "content": "Updated weather assistant"}
_update_llm_context(context, new_system, schemas)
messages = context.messages
# System + user + assistant(tool_call) + tool + assistant = 5
assert len(messages) == 5
messages = context.messages
# System + user + assistant(tool_call) + tool + assistant = 5
assert len(messages) == 5
# Verify function call messages are preserved
tool_call_msg = messages[2]
assert tool_call_msg["role"] == "assistant"
assert "tool_calls" in tool_call_msg
# Verify function call messages are preserved
tool_call_msg = messages[2]
assert tool_call_msg["role"] == "assistant"
assert "tool_calls" in tool_call_msg
tool_result_msg = messages[3]
assert tool_result_msg["role"] == "tool"
assert tool_result_msg["tool_call_id"] == "call_123"
tool_result_msg = messages[3]
assert tool_result_msg["role"] == "tool"
assert tool_result_msg["tool_call_id"] == "call_123"
@pytest.mark.asyncio
async def test_empty_tool_list_does_not_set_tools(self, mock_engine):
@ -296,26 +274,21 @@ class TestCustomToolManagerContextIntegration:
manager = CustomToolManager(mock_engine)
with patch(
"api.services.workflow.pipecat_engine_custom_tools.get_organization_id_from_workflow_run"
) as mock_get_org:
mock_get_org.return_value = 1
"api.services.workflow.pipecat_engine_custom_tools.db_client"
) as mock_db:
mock_db.get_tools_by_uuids = AsyncMock(return_value=[])
with patch(
"api.services.workflow.pipecat_engine_custom_tools.db_client"
) as mock_db:
mock_db.get_tools_by_uuids = AsyncMock(return_value=[])
schemas = await manager.get_tool_schemas([])
assert schemas == []
schemas = await manager.get_tool_schemas([])
assert schemas == []
context = LLMContext()
context.set_messages([{"role": "system", "content": "Old"}])
context = LLMContext()
context.set_messages([{"role": "system", "content": "Old"}])
new_system = {"role": "system", "content": "No tools available"}
_update_llm_context(context, new_system, [])
new_system = {"role": "system", "content": "No tools available"}
_update_llm_context(context, new_system, [])
# Context should have updated message but no tools set
assert context.messages[0]["content"] == "No tools available"
# Context should have updated message but no tools set
assert context.messages[0]["content"] == "No tools available"
@pytest.mark.asyncio
async def test_numeric_and_boolean_parameter_types(self, mock_engine):
@ -357,33 +330,28 @@ class TestCustomToolManagerContextIntegration:
manager = CustomToolManager(mock_engine)
with patch(
"api.services.workflow.pipecat_engine_custom_tools.get_organization_id_from_workflow_run"
) as mock_get_org:
mock_get_org.return_value = 1
"api.services.workflow.pipecat_engine_custom_tools.db_client"
) as mock_db:
mock_db.get_tools_by_uuids = AsyncMock(return_value=[tool_with_types])
with patch(
"api.services.workflow.pipecat_engine_custom_tools.db_client"
) as mock_db:
mock_db.get_tools_by_uuids = AsyncMock(return_value=[tool_with_types])
# Get schemas - returns FunctionSchema objects
schemas = await manager.get_tool_schemas(["order-uuid"])
schema = schemas[0]
# Get schemas - returns FunctionSchema objects
schemas = await manager.get_tool_schemas(["order-uuid"])
schema = schemas[0]
# Verify types using FunctionSchema attributes
assert schema.properties["item_id"]["type"] == "string"
assert schema.properties["quantity"]["type"] == "number"
assert schema.properties["express_shipping"]["type"] == "boolean"
# Verify types using FunctionSchema attributes
assert schema.properties["item_id"]["type"] == "string"
assert schema.properties["quantity"]["type"] == "number"
assert schema.properties["express_shipping"]["type"] == "boolean"
# Update context - pass schema directly
context = LLMContext()
context.set_messages([{"role": "system", "content": "Old"}])
_update_llm_context(
context, {"role": "system", "content": "Order assistant"}, schemas
)
# Update context - pass schema directly
context = LLMContext()
context.set_messages([{"role": "system", "content": "Old"}])
_update_llm_context(
context, {"role": "system", "content": "Order assistant"}, schemas
)
# Verify tool was set with correct types
tool = context.tools.standard_tools[0]
assert tool.name == "place_order"
assert tool.properties["quantity"]["type"] == "number"
assert tool.properties["express_shipping"]["type"] == "boolean"
# Verify tool was set with correct types
tool = context.tools.standard_tools[0]
assert tool.name == "place_order"
assert tool.properties["quantity"]["type"] == "number"
assert tool.properties["express_shipping"]["type"] == "boolean"

View file

@ -14,18 +14,10 @@ result in the context when generating the next response.
"""
import asyncio
from typing import Any, Dict, List
from typing import List
from unittest.mock import AsyncMock, patch
import pytest
from api.services.workflow.pipecat_engine import PipecatEngine
from api.services.workflow.workflow import WorkflowGraph
from api.tests.conftest import (
AGENT_SYSTEM_PROMPT,
END_CALL_SYSTEM_PROMPT,
START_CALL_SYSTEM_PROMPT,
)
from pipecat.frames.frames import LLMContextFrame
from pipecat.pipeline.pipeline import Pipeline
from pipecat.pipeline.runner import PipelineRunner
@ -35,75 +27,21 @@ from pipecat.processors.aggregators.llm_response_universal import (
LLMAssistantAggregatorParams,
LLMContextAggregatorPair,
)
from pipecat.tests import MockLLMService, MockTTSService
from pipecat.tests.mock_transport import MockTransport
from pipecat.transports.base_transport import TransportParams
class ContextCapturingMockLLM(MockLLMService):
"""A MockLLMService that captures the context state at each generation.
This allows us to verify that tool call results are present in the context
when the next LLM generation is triggered.
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.captured_contexts: List[Dict[str, Any]] = []
async def _stream_chat_completions_universal_context(self, context):
"""Override to capture context state before streaming chunks."""
# Deep copy the messages to avoid mutation issues
messages_snapshot = []
for msg in context.messages:
msg_copy = dict(msg)
# Copy content to avoid reference issues
if "content" in msg_copy:
msg_copy["content"] = (
str(msg_copy["content"]) if msg_copy["content"] else None
)
messages_snapshot.append(msg_copy)
self.captured_contexts.append(
{
"step": self._current_step,
"messages": messages_snapshot,
"system_prompt": self._settings.system_instruction,
}
)
# Call parent implementation to stream the mock chunks
return await super()._stream_chat_completions_universal_context(context)
def get_context_at_step(self, step: int) -> Dict[str, Any]:
"""Get the captured context at a specific step (0-indexed)."""
for ctx in self.captured_contexts:
if ctx["step"] == step:
return ctx
return None
def has_tool_call_result_at_step(self, step: int, function_name: str) -> bool:
"""Check if a tool call result for the given function exists in context at step."""
ctx = self.get_context_at_step(step)
if not ctx:
return False
for msg in ctx["messages"]:
# Check for tool/function role messages
if msg.get("role") == "tool" and msg.get("name") == function_name:
return True
# Also check for tool_call_id which indicates a tool response
if msg.get("tool_call_id") and function_name in str(msg.get("name", "")):
return True
return False
def get_system_prompt_at_step(self, step: int) -> str:
"""Get the system prompt from settings at a specific step."""
ctx = self.get_context_at_step(step)
if ctx:
return ctx.get("system_prompt") or ""
return ""
from api.services.workflow.pipecat_engine import PipecatEngine
from api.services.workflow.workflow import WorkflowGraph
from api.tests.conftest import (
AGENT_SYSTEM_PROMPT,
END_CALL_SYSTEM_PROMPT,
START_CALL_SYSTEM_PROMPT,
)
from pipecat.tests import (
ContextCapturingMockLLM,
MockLLMService,
MockTTSService,
)
async def run_pipeline_and_capture_context(
@ -142,7 +80,7 @@ async def run_pipeline_and_capture_context(
context = LLMContext()
# Add assistant context aggregator
assistant_params = LLMAssistantAggregatorParams(expect_stripped_words=True)
assistant_params = LLMAssistantAggregatorParams()
context_aggregator = LLMContextAggregatorPair(
context, assistant_params=assistant_params
)
@ -184,7 +122,7 @@ async def run_pipeline_and_capture_context(
# Patch DB calls
with patch(
"api.services.workflow.pipecat_engine.get_organization_id_from_workflow_run",
"api.db:db_client.get_organization_id_by_workflow_run_id",
new_callable=AsyncMock,
return_value=1,
):

View file

@ -23,6 +23,23 @@ from typing import Any, Dict, List
from unittest.mock import AsyncMock, patch
import pytest
from pipecat.frames.frames import Frame, LLMContextFrame
from pipecat.pipeline.pipeline import Pipeline
from pipecat.pipeline.runner import PipelineRunner
from pipecat.pipeline.task import PipelineParams, PipelineTask
from pipecat.processors.aggregators.llm_context import LLMContext
from pipecat.processors.aggregators.llm_response_universal import (
LLMAssistantAggregatorParams,
LLMContextAggregatorPair,
LLMUserAggregatorParams,
)
from pipecat.tests.mock_transport import MockTransport
from pipecat.transports.base_transport import TransportParams
from pipecat.turns.user_mute import (
CallbackUserMuteStrategy,
MuteUntilFirstBotCompleteUserMuteStrategy,
)
from pipecat.utils.enums import EndTaskReason
from api.enums import ToolCategory
from api.services.workflow.dto import (
@ -42,24 +59,7 @@ from api.services.workflow.pipecat_engine_variable_extractor import (
)
from api.services.workflow.workflow import WorkflowGraph
from api.tests.conftest import END_CALL_SYSTEM_PROMPT, START_CALL_SYSTEM_PROMPT
from pipecat.frames.frames import Frame, LLMContextFrame
from pipecat.pipeline.pipeline import Pipeline
from pipecat.pipeline.runner import PipelineRunner
from pipecat.pipeline.task import PipelineParams, PipelineTask
from pipecat.processors.aggregators.llm_context import LLMContext
from pipecat.processors.aggregators.llm_response_universal import (
LLMAssistantAggregatorParams,
LLMContextAggregatorPair,
LLMUserAggregatorParams,
)
from pipecat.tests import MockLLMService, MockTTSService
from pipecat.tests.mock_transport import MockTransport
from pipecat.transports.base_transport import TransportParams
from pipecat.turns.user_mute import (
CallbackUserMuteStrategy,
MuteUntilFirstBotCompleteUserMuteStrategy,
)
from pipecat.utils.enums import EndTaskReason
class EndCallTestHelper:
@ -182,7 +182,7 @@ async def create_engine_with_tracking(
engine.end_call_with_reason = tracked_end_call
# Create context aggregator with user mute strategies (after engine so we can use its callback)
assistant_params = LLMAssistantAggregatorParams(expect_stripped_words=True)
assistant_params = LLMAssistantAggregatorParams()
# Wrap should_mute_user to track calls
original_should_mute_user = engine.should_mute_user
@ -265,7 +265,7 @@ class TestEndCallViaNodeTransition:
# Patch DB calls and extraction manager
with patch(
"api.services.workflow.pipecat_engine.get_organization_id_from_workflow_run",
"api.db:db_client.get_organization_id_by_workflow_run_id",
new_callable=AsyncMock,
return_value=1,
):
@ -369,7 +369,7 @@ class TestEndCallViaNodeTransition:
# Patch DB calls and extraction manager
with patch(
"api.services.workflow.pipecat_engine.get_organization_id_from_workflow_run",
"api.db:db_client.get_organization_id_by_workflow_run_id",
new_callable=AsyncMock,
return_value=1,
):
@ -468,7 +468,7 @@ class TestEndCallViaCustomTool:
# Patch DB calls and extraction manager
with patch(
"api.services.workflow.pipecat_engine.get_organization_id_from_workflow_run",
"api.db:db_client.get_organization_id_by_workflow_run_id",
new_callable=AsyncMock,
return_value=1,
):
@ -560,7 +560,7 @@ class TestEndCallViaCustomTool:
# Patch DB calls and extraction manager
with patch(
"api.services.workflow.pipecat_engine.get_organization_id_from_workflow_run",
"api.db:db_client.get_organization_id_by_workflow_run_id",
new_callable=AsyncMock,
return_value=1,
):
@ -638,7 +638,7 @@ class TestEndCallViaClientDisconnect:
# Patch DB calls and extraction manager
with patch(
"api.services.workflow.pipecat_engine.get_organization_id_from_workflow_run",
"api.db:db_client.get_organization_id_by_workflow_run_id",
new_callable=AsyncMock,
return_value=1,
):
@ -729,7 +729,7 @@ class TestEndCallRaceConditions:
# Patch DB calls and extraction manager
with patch(
"api.services.workflow.pipecat_engine.get_organization_id_from_workflow_run",
"api.db:db_client.get_organization_id_by_workflow_run_id",
new_callable=AsyncMock,
return_value=1,
):
@ -841,7 +841,7 @@ class TestEndCallRaceConditions:
# Patch DB calls and extraction manager
with patch(
"api.services.workflow.pipecat_engine.get_organization_id_from_workflow_run",
"api.db:db_client.get_organization_id_by_workflow_run_id",
new_callable=AsyncMock,
return_value=1,
):
@ -937,7 +937,7 @@ class TestEndCallExtractionBehavior:
# Patch DB calls and extraction manager
with patch(
"api.services.workflow.pipecat_engine.get_organization_id_from_workflow_run",
"api.db:db_client.get_organization_id_by_workflow_run_id",
new_callable=AsyncMock,
return_value=1,
):
@ -1061,7 +1061,7 @@ class TestEndCallExtractionBehavior:
# Patch DB calls and extraction manager
with patch(
"api.services.workflow.pipecat_engine.get_organization_id_from_workflow_run",
"api.db:db_client.get_organization_id_by_workflow_run_id",
new_callable=AsyncMock,
return_value=1,
):

View file

@ -15,9 +15,6 @@ import asyncio
from unittest.mock import AsyncMock, patch
import pytest
from api.services.workflow.pipecat_engine import PipecatEngine
from api.services.workflow.workflow import WorkflowGraph
from pipecat.frames.frames import (
Frame,
FunctionCallResultFrame,
@ -36,7 +33,6 @@ from pipecat.processors.aggregators.llm_response_universal import (
LLMUserAggregatorParams,
)
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
from pipecat.tests import MockLLMService, MockTTSService
from pipecat.tests.mock_transport import MockTransport
from pipecat.transports.base_transport import TransportParams
from pipecat.turns.user_mute import (
@ -52,6 +48,10 @@ from pipecat.turns.user_stop import (
from pipecat.turns.user_turn_strategies import UserTurnStrategies
from pipecat.utils.time import time_now_iso8601
from api.services.workflow.pipecat_engine import PipecatEngine
from api.services.workflow.workflow import WorkflowGraph
from pipecat.tests import MockLLMService, MockTTSService
class UserSpeechInjector(FrameProcessor):
"""Processor that injects user speaking frames on FunctionCallResultFrame.
@ -183,7 +183,7 @@ async def create_test_pipeline(
)
# Create context aggregator with user and assistant params
assistant_params = LLMAssistantAggregatorParams(expect_stripped_words=True)
assistant_params = LLMAssistantAggregatorParams()
context_aggregator = LLMContextAggregatorPair(
context, assistant_params=assistant_params, user_params=user_params
@ -277,7 +277,7 @@ class TestNodeSwitchWithUserSpeech:
# Patch DB calls
with patch(
"api.services.workflow.pipecat_engine.get_organization_id_from_workflow_run",
"api.db:db_client.get_organization_id_by_workflow_run_id",
new_callable=AsyncMock,
return_value=1,
):

View file

@ -9,10 +9,6 @@ from typing import Any, Dict, List
from unittest.mock import AsyncMock, patch
import pytest
from api.services.workflow.pipecat_engine import PipecatEngine
from api.services.workflow.workflow import WorkflowGraph
from api.tests.conftest import END_CALL_SYSTEM_PROMPT
from pipecat.frames.frames import LLMContextFrame
from pipecat.pipeline.pipeline import Pipeline
from pipecat.pipeline.runner import PipelineRunner
@ -22,10 +18,14 @@ from pipecat.processors.aggregators.llm_response_universal import (
LLMAssistantAggregatorParams,
LLMContextAggregatorPair,
)
from pipecat.tests import MockLLMService, MockTTSService
from pipecat.tests.mock_transport import MockTransport
from pipecat.transports.base_transport import TransportParams
from api.services.workflow.pipecat_engine import PipecatEngine
from api.services.workflow.workflow import WorkflowGraph
from api.tests.conftest import END_CALL_SYSTEM_PROMPT
from pipecat.tests import MockLLMService, MockTTSService
async def run_pipeline_with_tool_calls(
workflow: WorkflowGraph,
@ -81,7 +81,7 @@ async def run_pipeline_with_tool_calls(
context = LLMContext()
# Add assistant context aggregator
assistant_params = LLMAssistantAggregatorParams(expect_stripped_words=True)
assistant_params = LLMAssistantAggregatorParams()
context_aggregator = LLMContextAggregatorPair(
context, assistant_params=assistant_params
)
@ -113,7 +113,7 @@ async def run_pipeline_with_tool_calls(
# Patch DB calls to avoid actual database access
with patch(
"api.services.workflow.pipecat_engine.get_organization_id_from_workflow_run",
"api.db:db_client.get_organization_id_by_workflow_run_id",
new_callable=AsyncMock,
return_value=1,
):

View file

@ -0,0 +1,280 @@
"""Tests verifying user is muted while a transition function is executing.
When the LLM calls a transition function (registered via
``_register_transition_function_with_llm``), pipecat broadcasts a
``FunctionCallsStartedFrame`` that ``FunctionCallUserMuteStrategy`` uses to
mute the user until a ``FunctionCallResultFrame`` arrives. These tests assert
that mute behavior holds end-to-end through the engine's transition flow,
so that user audio doesn't race the node switch / extraction / context update
that runs inside the transition function.
"""
import asyncio
from unittest.mock import AsyncMock, patch
import pytest
from pipecat.frames.frames import LLMContextFrame
from pipecat.pipeline.pipeline import Pipeline
from pipecat.pipeline.runner import PipelineRunner
from pipecat.pipeline.task import PipelineParams, PipelineTask
from pipecat.processors.aggregators.llm_context import LLMContext
from pipecat.processors.aggregators.llm_response_universal import (
LLMAssistantAggregatorParams,
LLMContextAggregatorPair,
LLMUserAggregatorParams,
)
from pipecat.tests.mock_transport import MockTransport
from pipecat.transports.base_transport import TransportParams
from pipecat.turns.user_mute import (
CallbackUserMuteStrategy,
FunctionCallUserMuteStrategy,
MuteUntilFirstBotCompleteUserMuteStrategy,
)
from api.services.workflow.pipecat_engine import PipecatEngine
from api.services.workflow.pipecat_engine_variable_extractor import (
VariableExtractionManager,
)
from api.services.workflow.workflow import WorkflowGraph
from pipecat.tests import MockLLMService, MockTTSService
async def _build_engine_and_pipeline(
workflow: WorkflowGraph,
mock_llm: MockLLMService,
):
"""Set up engine + pipeline mirroring the non-realtime production wiring.
Returns (engine, task, function_call_mute_strategy, user_context_aggregator).
"""
tts = MockTTSService(mock_audio_duration_ms=40, frame_delay=0)
transport = MockTransport(
params=TransportParams(
audio_in_enabled=True,
audio_out_enabled=True,
audio_in_sample_rate=16000,
audio_out_sample_rate=16000,
),
)
context = LLMContext()
engine = PipecatEngine(
llm=mock_llm,
context=context,
workflow=workflow,
call_context_vars={"customer_name": "Test User"},
workflow_run_id=1,
)
# Hold a reference so the test can introspect the in-progress set.
function_call_mute_strategy = FunctionCallUserMuteStrategy()
# Match run_pipeline.py's non-realtime mute-strategy stack so the test
# exercises the same wiring that would be active in a real call.
user_mute_strategies = [
MuteUntilFirstBotCompleteUserMuteStrategy(),
function_call_mute_strategy,
CallbackUserMuteStrategy(should_mute_callback=engine.should_mute_user),
]
user_params = LLMUserAggregatorParams(user_mute_strategies=user_mute_strategies)
assistant_params = LLMAssistantAggregatorParams()
context_aggregator = LLMContextAggregatorPair(
context, assistant_params=assistant_params, user_params=user_params
)
user_context_aggregator = context_aggregator.user()
assistant_context_aggregator = context_aggregator.assistant()
pipeline = Pipeline(
[
transport.input(),
user_context_aggregator,
mock_llm,
tts,
transport.output(),
assistant_context_aggregator,
]
)
task = PipelineTask(pipeline, params=PipelineParams(), enable_rtvi=False)
engine.set_task(task)
return engine, task, function_call_mute_strategy, user_context_aggregator
class TestTransitionFunctionMutesUser:
"""Verify the user is muted while transition functions execute."""
@pytest.mark.asyncio
async def test_user_is_muted_during_transition_function(
self, simple_workflow: WorkflowGraph
):
"""The user must be muted from the moment a transition function starts
until its result is delivered.
Scenario:
1. LLM calls the ``end_call`` transition function (start end edge).
2. Wrap the registered handler so we can read mute state from inside it.
3. VERIFY: the function-call mute strategy has the call in flight.
4. VERIFY: the user aggregator's ``_user_is_muted`` flag is True.
"""
step_0_chunks = MockLLMService.create_function_call_chunks(
function_name="end_call",
arguments={},
tool_call_id="call_end_1",
)
llm = MockLLMService(mock_steps=[step_0_chunks], chunk_delay=0.001)
(
engine,
task,
function_call_mute_strategy,
user_context_aggregator,
) = await _build_engine_and_pipeline(simple_workflow, llm)
captured_states: list[dict] = []
# Wrap register_function so we can introspect mute state from inside
# the transition handler. We must wrap *after* the engine is created
# but *before* set_node registers the transition functions.
original_register_function = llm.register_function
def wrapping_register_function(name, func, *args, **kwargs):
async def wrapped(function_call_params):
# Yield once so the user aggregator has a chance to drain
# the broadcasted FunctionCallsStartedFrame and update its
# mute state before we sample it.
await asyncio.sleep(0.02)
captured_states.append(
{
"name": name,
"function_call_in_progress": bool(
function_call_mute_strategy._function_call_in_progress
),
"user_is_muted": user_context_aggregator._user_is_muted,
"tool_call_ids": set(
function_call_mute_strategy._function_call_in_progress
),
}
)
return await func(function_call_params)
return original_register_function(name, wrapped, *args, **kwargs)
llm.register_function = wrapping_register_function
with patch(
"api.db:db_client.get_organization_id_by_workflow_run_id",
new_callable=AsyncMock,
return_value=1,
):
with patch(
"api.services.workflow.pipecat_engine.apply_disposition_mapping",
new_callable=AsyncMock,
return_value="completed",
):
with patch.object(
VariableExtractionManager,
"_perform_extraction",
new_callable=AsyncMock,
return_value={"user_intent": "end call"},
):
runner = PipelineRunner()
async def run_pipeline():
await runner.run(task)
async def initialize_engine():
await asyncio.sleep(0.01)
await engine.initialize()
await engine.set_node(engine.workflow.start_node_id)
await engine.llm.queue_frame(LLMContextFrame(engine.context))
await asyncio.wait_for(
asyncio.gather(run_pipeline(), initialize_engine()),
timeout=10.0,
)
assert len(captured_states) == 1, (
f"Expected the transition function to be invoked exactly once, "
f"got {len(captured_states)}: {captured_states}"
)
state = captured_states[0]
assert state["name"] == "end_call"
assert state["function_call_in_progress"], (
"FunctionCallUserMuteStrategy should have the transition call in "
f"progress while the handler runs (state={state})"
)
assert "call_end_1" in state["tool_call_ids"], (
f"Expected tool_call_id 'call_end_1' to be tracked, got {state['tool_call_ids']}"
)
assert state["user_is_muted"], (
"User aggregator's _user_is_muted should be True during the "
f"transition function (state={state})"
)
@pytest.mark.asyncio
async def test_user_is_unmuted_after_transition_function_returns(
self, simple_workflow: WorkflowGraph
):
"""After the transition function's result is delivered, the function-call
mute strategy should clear its in-progress set. Other strategies in the
stack (CallbackUserMuteStrategy via engine.should_mute_user) may still
keep the pipeline muted because end_call_with_reason fires when the
engine reaches the End node, but the function-call strategy itself
must release its hold.
"""
step_0_chunks = MockLLMService.create_function_call_chunks(
function_name="end_call",
arguments={},
tool_call_id="call_end_1",
)
llm = MockLLMService(mock_steps=[step_0_chunks], chunk_delay=0.001)
(
engine,
task,
function_call_mute_strategy,
_user_context_aggregator,
) = await _build_engine_and_pipeline(simple_workflow, llm)
with patch(
"api.db:db_client.get_organization_id_by_workflow_run_id",
new_callable=AsyncMock,
return_value=1,
):
with patch(
"api.services.workflow.pipecat_engine.apply_disposition_mapping",
new_callable=AsyncMock,
return_value="completed",
):
with patch.object(
VariableExtractionManager,
"_perform_extraction",
new_callable=AsyncMock,
return_value={"user_intent": "end call"},
):
runner = PipelineRunner()
async def run_pipeline():
await runner.run(task)
async def initialize_engine():
await asyncio.sleep(0.01)
await engine.initialize()
await engine.set_node(engine.workflow.start_node_id)
await engine.llm.queue_frame(LLMContextFrame(engine.context))
await asyncio.wait_for(
asyncio.gather(run_pipeline(), initialize_engine()),
timeout=10.0,
)
assert function_call_mute_strategy._function_call_in_progress == set(), (
"FunctionCallUserMuteStrategy should have cleared its in-progress "
"set after the transition function's result was delivered, got "
f"{function_call_mute_strategy._function_call_in_progress}"
)

View file

@ -16,12 +16,6 @@ from typing import Any, Dict, List
from unittest.mock import AsyncMock, patch
import pytest
from api.services.workflow.pipecat_engine import PipecatEngine
from api.services.workflow.pipecat_engine_variable_extractor import (
VariableExtractionManager,
)
from api.services.workflow.workflow import WorkflowGraph
from pipecat.frames.frames import LLMContextFrame
from pipecat.pipeline.pipeline import Pipeline
from pipecat.pipeline.runner import PipelineRunner
@ -31,10 +25,16 @@ from pipecat.processors.aggregators.llm_response_universal import (
LLMAssistantAggregatorParams,
LLMContextAggregatorPair,
)
from pipecat.tests import MockLLMService, MockTTSService
from pipecat.tests.mock_transport import MockTransport
from pipecat.transports.base_transport import TransportParams
from api.services.workflow.pipecat_engine import PipecatEngine
from api.services.workflow.pipecat_engine_variable_extractor import (
VariableExtractionManager,
)
from api.services.workflow.workflow import WorkflowGraph
from pipecat.tests import MockLLMService, MockTTSService
class TestVariableExtractionDuringTransitions:
"""Test that variable extraction is triggered for the correct node during transitions."""
@ -97,7 +97,7 @@ class TestVariableExtractionDuringTransitions:
context = LLMContext()
# Add assistant context aggregator
assistant_params = LLMAssistantAggregatorParams(expect_stripped_words=True)
assistant_params = LLMAssistantAggregatorParams()
context_aggregator = LLMContextAggregatorPair(
context, assistant_params=assistant_params
)
@ -152,7 +152,7 @@ class TestVariableExtractionDuringTransitions:
# Patch DB calls and extraction manager
with patch(
"api.services.workflow.pipecat_engine.get_organization_id_from_workflow_run",
"api.db:db_client.get_organization_id_by_workflow_run_id",
new_callable=AsyncMock,
return_value=1,
):

View file

@ -2,7 +2,6 @@ import asyncio
import pytest
from loguru import logger
from pipecat.frames.frames import (
EndTaskFrame,
Frame,
@ -35,8 +34,10 @@ class BusyWaitProcessor(FrameProcessor):
# Simulate a delay, which can happen sometimes due to slow LLM Inferencing or
# other reasons
try:
logger.debug(f"{self} sleeping with frame: {frame}")
await asyncio.sleep(5)
logger.debug(
f"{self} sleeping with frame: {frame} for {self._wait_time} seconds"
)
await asyncio.sleep(self._wait_time)
logger.debug(f"{self} woke up with frame: {frame}")
except asyncio.CancelledError:
logger.debug(f"{self} was cancelled")
@ -46,7 +47,7 @@ class BusyWaitProcessor(FrameProcessor):
@pytest.mark.asyncio
async def test_interruption_with_blocked_end_frame():
busy_wait_processor = BusyWaitProcessor(wait_time=5)
busy_wait_processor = BusyWaitProcessor(wait_time=0.5)
transport = MockTransport()
pipeline = Pipeline([transport, busy_wait_processor])
@ -78,11 +79,13 @@ async def test_interruption_with_blocked_end_frame():
# Wait with timeout
done, pending = await asyncio.wait(
[pipeline_task, queue_task],
timeout=1.0,
timeout=2.0,
return_when=asyncio.ALL_COMPLETED,
)
# If there are pending tasks, we timed out
# FIXME: Currently I have creaetd an issue on pipecat which talks about
# how this behaviour is not good. https://github.com/pipecat-ai/pipecat/issues/4412
if pending:
# Cancel all pending tasks
for t in pending:
@ -92,9 +95,9 @@ async def test_interruption_with_blocked_end_frame():
try:
await asyncio.wait_for(
asyncio.gather(*pending, return_exceptions=True),
timeout=1.0,
timeout=2.0,
)
except asyncio.TimeoutError:
pass # Cleanup took too long, continue anyway
pytest.fail("Test timed out after 1 second")
pytest.fail("Test timed out after 2 second")

View file

@ -12,6 +12,14 @@ and inspect what arrives downstream.
from typing import Optional
import pytest
from pipecat.frames.frames import (
LLMFullResponseEndFrame,
LLMTextFrame,
TTSAudioRawFrame,
TTSStartedFrame,
TTSStoppedFrame,
TTSTextFrame,
)
from api.services.pipecat.recording_audio_cache import RecordingAudio
from api.services.pipecat.recording_router_processor import (
@ -21,14 +29,6 @@ from api.services.workflow.pipecat_engine_context_composer import (
RECORDING_MARKER,
TTS_MARKER,
)
from pipecat.frames.frames import (
LLMFullResponseEndFrame,
LLMTextFrame,
TTSAudioRawFrame,
TTSStartedFrame,
TTSStoppedFrame,
TTSTextFrame,
)
from pipecat.tests import run_test
# ---------------------------------------------------------------------------

View file

@ -11,21 +11,6 @@ from typing import Any, Dict, List
from unittest.mock import AsyncMock, Mock, patch
import pytest
from api.services.pipecat.recording_audio_cache import RecordingAudio
from api.services.workflow.dto import (
EdgeDataDTO,
EndCallNodeData,
EndCallRFNode,
Position,
ReactFlowDTO,
RFEdgeDTO,
StartCallNodeData,
StartCallRFNode,
)
from api.services.workflow.pipecat_engine import PipecatEngine
from api.services.workflow.pipecat_engine_custom_tools import CustomToolManager
from api.services.workflow.workflow import WorkflowGraph
from pipecat.frames.frames import (
Frame,
LLMContextFrame,
@ -42,10 +27,25 @@ from pipecat.processors.aggregators.llm_response_universal import (
LLMAssistantAggregatorParams,
LLMContextAggregatorPair,
)
from pipecat.tests import MockLLMService, MockTTSService
from pipecat.tests.mock_transport import MockTransport
from pipecat.transports.base_transport import TransportParams
from api.services.pipecat.recording_audio_cache import RecordingAudio
from api.services.workflow.dto import (
EdgeDataDTO,
EndCallNodeData,
EndCallRFNode,
Position,
ReactFlowDTO,
RFEdgeDTO,
StartCallNodeData,
StartCallRFNode,
)
from api.services.workflow.pipecat_engine import PipecatEngine
from api.services.workflow.pipecat_engine_custom_tools import CustomToolManager
from api.services.workflow.workflow import WorkflowGraph
from pipecat.tests import MockLLMService, MockTTSService
# ─── Constants ──────────────────────────────────────────────────
START_PROMPT = "Start Call System Prompt"
@ -189,7 +189,7 @@ async def run_pipeline_and_capture_frames(
)
context = LLMContext()
assistant_params = LLMAssistantAggregatorParams(expect_stripped_words=True)
assistant_params = LLMAssistantAggregatorParams()
context_aggregator = LLMContextAggregatorPair(
context, assistant_params=assistant_params
)
@ -234,7 +234,7 @@ async def run_pipeline_and_capture_frames(
with (
patch(
"api.services.workflow.pipecat_engine.get_organization_id_from_workflow_run",
"api.db:db_client.get_organization_id_by_workflow_run_id",
new_callable=AsyncMock,
return_value=1,
),

View file

@ -32,12 +32,6 @@ import asyncio
from unittest.mock import AsyncMock, patch
import pytest
from api.services.workflow.pipecat_engine import PipecatEngine
from api.services.workflow.pipecat_engine_variable_extractor import (
VariableExtractionManager,
)
from api.services.workflow.workflow import WorkflowGraph
from pipecat.frames.frames import LLMContextFrame
from pipecat.pipeline.pipeline import Pipeline
from pipecat.pipeline.runner import PipelineRunner
@ -48,7 +42,6 @@ from pipecat.processors.aggregators.llm_response_universal import (
LLMContextAggregatorPair,
LLMUserAggregatorParams,
)
from pipecat.tests import MockLLMService, MockTTSService
from pipecat.tests.mock_transport import MockTransport
from pipecat.transports.base_transport import TransportParams
from pipecat.turns.user_mute import (
@ -57,6 +50,13 @@ from pipecat.turns.user_mute import (
)
from pipecat.utils.enums import EndTaskReason
from api.services.workflow.pipecat_engine import PipecatEngine
from api.services.workflow.pipecat_engine_variable_extractor import (
VariableExtractionManager,
)
from api.services.workflow.workflow import WorkflowGraph
from pipecat.tests import MockLLMService, MockTTSService
async def create_test_pipeline_with_failing_transport(
workflow: WorkflowGraph,
@ -131,7 +131,7 @@ async def create_test_pipeline_with_failing_transport(
user_mute_strategies=user_mute_strategies,
)
assistant_params = LLMAssistantAggregatorParams(expect_stripped_words=True)
assistant_params = LLMAssistantAggregatorParams()
context_aggregator = LLMContextAggregatorPair(
context, assistant_params=assistant_params, user_params=user_params
@ -204,7 +204,7 @@ class TestTTSPauseWithAudioWriteFailure:
# Patch DB calls
with patch(
"api.services.workflow.pipecat_engine.get_organization_id_from_workflow_run",
"api.db:db_client.get_organization_id_by_workflow_run_id",
new_callable=AsyncMock,
return_value=1,
):
@ -324,7 +324,7 @@ class TestTTSPauseWithAudioWriteFailure:
)
with patch(
"api.services.workflow.pipecat_engine.get_organization_id_from_workflow_run",
"api.db:db_client.get_organization_id_by_workflow_run_id",
new_callable=AsyncMock,
return_value=1,
):

View file

@ -0,0 +1,81 @@
"""Tests for LLM behavior when calling an unregistered function."""
import pytest
from pipecat.frames.frames import (
FunctionCallInProgressFrame,
FunctionCallResultFrame,
FunctionCallsFromLLMInfoFrame,
FunctionCallsStartedFrame,
LLMContextFrame,
LLMFullResponseEndFrame,
LLMFullResponseStartFrame,
)
from pipecat.pipeline.pipeline import Pipeline
from pipecat.processors.aggregators.llm_context import LLMContext
from pipecat.tests import MockLLMService, run_test
class TestUnregisteredFunctionCall:
"""Tests for LLM behavior when generating a tool call for an unregistered function."""
@pytest.mark.asyncio
async def test_unregistered_function_emits_error_result(self):
"""LLM calling an unregistered function should still terminate with a
FunctionCallResultFrame whose result is an error string, instead of
crashing the pipeline."""
chunks = MockLLMService.create_function_call_chunks(
function_name="nonexistent_tool",
arguments={"foo": "bar"},
tool_call_id="call_missing_1",
)
llm = MockLLMService(mock_chunks=chunks, chunk_delay=0.001)
# Intentionally do NOT register any handler for "nonexistent_tool".
messages = [{"role": "user", "content": "Please use a tool I never registered"}]
context = LLMContext(messages)
pipeline = Pipeline([llm])
received_down_frames, _ = await run_test(
pipeline,
frames_to_send=[LLMContextFrame(context)],
expected_down_frames=[
LLMFullResponseStartFrame,
FunctionCallsFromLLMInfoFrame,
FunctionCallsStartedFrame,
LLMFullResponseEndFrame,
FunctionCallInProgressFrame,
FunctionCallResultFrame,
],
)
result_frames = [
f for f in received_down_frames if isinstance(f, FunctionCallResultFrame)
]
assert len(result_frames) == 1, (
"Expected exactly one FunctionCallResultFrame for the unregistered call"
)
result_frame = result_frames[0]
assert result_frame.function_name == "nonexistent_tool"
assert result_frame.tool_call_id == "call_missing_1"
assert result_frame.arguments == {"foo": "bar"}
# Pipecat's missing-function handler returns a string error.
assert isinstance(result_frame.result, str)
assert "not registered" in result_frame.result
assert "nonexistent_tool" in result_frame.result
# In-progress frame should also be emitted before the result so mute
# strategies can release the tool_call_id.
in_progress_frames = [
f
for f in received_down_frames
if isinstance(f, FunctionCallInProgressFrame)
]
assert len(in_progress_frames) == 1
assert in_progress_frames[0].function_name == "nonexistent_tool"
assert in_progress_frames[0].tool_call_id == "call_missing_1"

View file

@ -13,9 +13,6 @@ import asyncio
from unittest.mock import AsyncMock, patch
import pytest
from api.services.workflow.pipecat_engine import PipecatEngine
from api.services.workflow.workflow import WorkflowGraph
from pipecat.frames.frames import (
BotStoppedSpeakingFrame,
Frame,
@ -35,7 +32,6 @@ from pipecat.processors.aggregators.llm_response_universal import (
LLMUserAggregatorParams,
)
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
from pipecat.tests import MockLLMService, MockTTSService
from pipecat.tests.mock_transport import MockTransport
from pipecat.transports.base_transport import TransportParams
from pipecat.turns.user_mute import (
@ -47,6 +43,10 @@ from pipecat.turns.user_stop import ExternalUserTurnStopStrategy
from pipecat.turns.user_turn_strategies import UserTurnStrategies
from pipecat.utils.time import time_now_iso8601
from api.services.workflow.pipecat_engine import PipecatEngine
from api.services.workflow.workflow import WorkflowGraph
from pipecat.tests import MockLLMService, MockTTSService
class UserSpeechInjector(FrameProcessor):
"""Processor that injects user speaking frames after the bot finishes speaking.
@ -161,7 +161,7 @@ async def create_pipeline_with_speech_injection(
user_idle_timeout=user_idle_timeout,
)
assistant_params = LLMAssistantAggregatorParams(expect_stripped_words=True)
assistant_params = LLMAssistantAggregatorParams()
context_aggregator = LLMContextAggregatorPair(
context, assistant_params=assistant_params, user_params=user_params
@ -257,7 +257,7 @@ class TestUserIdleHandler:
)
with patch(
"api.services.workflow.pipecat_engine.get_organization_id_from_workflow_run",
"api.db:db_client.get_organization_id_by_workflow_run_id",
new_callable=AsyncMock,
return_value=1,
):

View file

@ -15,12 +15,6 @@ from typing import List
from unittest.mock import AsyncMock, patch
import pytest
from api.services.workflow.pipecat_engine import PipecatEngine
from api.services.workflow.pipecat_engine_variable_extractor import (
VariableExtractionManager,
)
from api.services.workflow.workflow import WorkflowGraph
from pipecat.frames.frames import (
BotStartedSpeakingFrame,
BotStoppedSpeakingFrame,
@ -41,7 +35,6 @@ from pipecat.processors.aggregators.llm_response_universal import (
LLMUserAggregatorParams,
)
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
from pipecat.tests import MockLLMService, MockTTSService
from pipecat.tests.mock_transport import MockTransport
from pipecat.transports.base_transport import TransportParams
from pipecat.turns.user_mute import (
@ -51,6 +44,13 @@ from pipecat.turns.user_mute import (
from pipecat.turns.user_turn_strategies import ExternalUserTurnStrategies
from pipecat.utils.time import time_now_iso8601
from api.services.workflow.pipecat_engine import PipecatEngine
from api.services.workflow.pipecat_engine_variable_extractor import (
VariableExtractionManager,
)
from api.services.workflow.workflow import WorkflowGraph
from pipecat.tests import MockLLMService, MockTTSService
class BotSpeakingObserverProcessor(FrameProcessor):
"""Observer that records mute status when bot speaking events flow upstream.
@ -160,7 +160,7 @@ async def create_engine_for_mute_test(
)
# Create context aggregator with user mute strategies
assistant_params = LLMAssistantAggregatorParams(expect_stripped_words=True)
assistant_params = LLMAssistantAggregatorParams()
user_mute_strategies = [
MuteUntilFirstBotCompleteUserMuteStrategy(),
@ -243,7 +243,7 @@ class TestUserMutingDuringBotSpeech:
) = await create_engine_for_mute_test(simple_workflow, llm, tts_duration_ms=50)
with patch(
"api.services.workflow.pipecat_engine.get_organization_id_from_workflow_run",
"api.db:db_client.get_organization_id_by_workflow_run_id",
new_callable=AsyncMock,
return_value=1,
):
@ -334,7 +334,7 @@ class TestUserMutingDuringBotSpeech:
) = await create_engine_for_mute_test(simple_workflow, llm, tts_duration_ms=50)
with patch(
"api.services.workflow.pipecat_engine.get_organization_id_from_workflow_run",
"api.db:db_client.get_organization_id_by_workflow_run_id",
new_callable=AsyncMock,
return_value=1,
):
@ -430,7 +430,7 @@ class TestUserMutingDuringBotSpeech:
) = await create_engine_for_mute_test(simple_workflow, llm, tts_duration_ms=50)
with patch(
"api.services.workflow.pipecat_engine.get_organization_id_from_workflow_run",
"api.db:db_client.get_organization_id_by_workflow_run_id",
new_callable=AsyncMock,
return_value=1,
):

View file

@ -8,7 +8,6 @@ incoming speech as CONVERSATION or VOICEMAIL and how the main LLM responds.
import asyncio
import pytest
from pipecat.extensions.voicemail.voicemail_detector import VoicemailDetector
from pipecat.frames.frames import (
EndTaskFrame,
@ -27,7 +26,6 @@ from pipecat.processors.aggregators.llm_response_universal import (
LLMUserAggregatorParams,
)
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
from pipecat.tests import MockLLMService
from pipecat.turns.user_start import (
TranscriptionUserTurnStartStrategy,
VADUserTurnStartStrategy,
@ -38,6 +36,8 @@ from pipecat.turns.user_stop import (
from pipecat.turns.user_turn_strategies import UserTurnStrategies
from pipecat.utils.time import time_now_iso8601
from pipecat.tests import MockLLMService
class FrameInjector(FrameProcessor):
"""Simple processor that can inject frames into the pipeline."""
@ -110,7 +110,7 @@ class TestVoicemailDetectorWithUserAggregator:
user_turn_strategies=user_turn_strategies,
)
assistant_params = LLMAssistantAggregatorParams(expect_stripped_words=True)
assistant_params = LLMAssistantAggregatorParams()
context_aggregator = LLMContextAggregatorPair(
context, assistant_params=assistant_params, user_params=user_params