dograh/api/tests/test_pipeline_cancellation.py
Abhishek 5ef3be92b5
chore: update pipecat to 1.3.0 (#379)
* chore: rename PipelineTask to PipelineWorker

* fix: fix tests

* chore: update pipecat submodule

* fix: fix anyio same task cancellation scope
2026-05-29 16:19:42 +05:30

100 lines
3.2 KiB
Python

import asyncio
import pytest
from loguru import logger
from pipecat.frames.frames import (
EndTaskFrame,
Frame,
InterruptionTaskFrame,
LLMRunFrame,
)
from pipecat.pipeline.pipeline import Pipeline
from pipecat.pipeline.worker import PipelineWorker
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
from api.services.pipecat.worker_runner import run_pipeline_worker
class MockTransport(FrameProcessor):
def __init__(self, **kwargs):
super().__init__(**kwargs)
async def process_frame(self, frame: Frame, direction: FrameDirection):
await super().process_frame(frame, direction)
await self.push_frame(frame, direction)
class BusyWaitProcessor(FrameProcessor):
def __init__(self, wait_time=5.0, **kwargs):
super().__init__(**kwargs)
self._wait_time = wait_time
async def process_frame(self, frame: Frame, direction: FrameDirection):
await super().process_frame(frame, direction)
if isinstance(frame, LLMRunFrame):
# Simulate a delay, which can happen sometimes due to slow LLM Inferencing or
# other reasons
try:
logger.debug(
f"{self} sleeping with frame: {frame} for {self._wait_time} seconds"
)
await asyncio.sleep(self._wait_time)
logger.debug(f"{self} woke up with frame: {frame}")
except asyncio.CancelledError:
logger.debug(f"{self} was cancelled")
raise
await self.push_frame(frame, direction)
@pytest.mark.asyncio
async def test_interruption_with_blocked_end_frame():
busy_wait_processor = BusyWaitProcessor(wait_time=5.0)
transport = MockTransport()
pipeline = Pipeline([transport, busy_wait_processor])
task = PipelineWorker(pipeline, enable_rtvi=False)
async def run_pipeline():
await run_pipeline_worker(task)
async def queue_frame():
await task.queue_frames([LLMRunFrame()])
# Send EndTaskFrame to simulate EndFrame
await asyncio.sleep(0.1)
await transport.queue_frame(EndTaskFrame(), direction=FrameDirection.UPSTREAM)
# Simulate an Interruption, which can happen if the user
# has started to speak
await asyncio.sleep(0.1)
await transport.queue_frame(
InterruptionTaskFrame(), direction=FrameDirection.UPSTREAM
)
# Create tasks explicitly for better control
pipeline_task = asyncio.create_task(run_pipeline())
queue_task = asyncio.create_task(queue_frame())
# Wait with timeout
done, pending = await asyncio.wait(
[pipeline_task, queue_task],
timeout=2.0,
return_when=asyncio.ALL_COMPLETED,
)
# If there are pending tasks, we timed out
if pending:
# Cancel all pending tasks
for t in pending:
t.cancel()
# Give limited time for cleanup, then move on regardless
try:
await asyncio.wait_for(
asyncio.gather(*pending, return_exceptions=True),
timeout=2.0,
)
except asyncio.TimeoutError:
pass # Cleanup took too long, continue anyway
pytest.fail("Test timed out after 2 second")