mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-06-07 07:55:16 +02:00
* chore: rename PipelineTask to PipelineWorker * fix: fix tests * chore: update pipecat submodule * fix: fix anyio same task cancellation scope
100 lines
3.2 KiB
Python
100 lines
3.2 KiB
Python
import asyncio
|
|
|
|
import pytest
|
|
from loguru import logger
|
|
from pipecat.frames.frames import (
|
|
EndTaskFrame,
|
|
Frame,
|
|
InterruptionTaskFrame,
|
|
LLMRunFrame,
|
|
)
|
|
from pipecat.pipeline.pipeline import Pipeline
|
|
from pipecat.pipeline.worker import PipelineWorker
|
|
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
|
|
|
from api.services.pipecat.worker_runner import run_pipeline_worker
|
|
|
|
|
|
class MockTransport(FrameProcessor):
|
|
def __init__(self, **kwargs):
|
|
super().__init__(**kwargs)
|
|
|
|
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
|
await super().process_frame(frame, direction)
|
|
await self.push_frame(frame, direction)
|
|
|
|
|
|
class BusyWaitProcessor(FrameProcessor):
|
|
def __init__(self, wait_time=5.0, **kwargs):
|
|
super().__init__(**kwargs)
|
|
self._wait_time = wait_time
|
|
|
|
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
|
await super().process_frame(frame, direction)
|
|
if isinstance(frame, LLMRunFrame):
|
|
# Simulate a delay, which can happen sometimes due to slow LLM Inferencing or
|
|
# other reasons
|
|
try:
|
|
logger.debug(
|
|
f"{self} sleeping with frame: {frame} for {self._wait_time} seconds"
|
|
)
|
|
await asyncio.sleep(self._wait_time)
|
|
logger.debug(f"{self} woke up with frame: {frame}")
|
|
except asyncio.CancelledError:
|
|
logger.debug(f"{self} was cancelled")
|
|
raise
|
|
await self.push_frame(frame, direction)
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_interruption_with_blocked_end_frame():
|
|
busy_wait_processor = BusyWaitProcessor(wait_time=5.0)
|
|
transport = MockTransport()
|
|
pipeline = Pipeline([transport, busy_wait_processor])
|
|
|
|
task = PipelineWorker(pipeline, enable_rtvi=False)
|
|
|
|
async def run_pipeline():
|
|
await run_pipeline_worker(task)
|
|
|
|
async def queue_frame():
|
|
await task.queue_frames([LLMRunFrame()])
|
|
|
|
# Send EndTaskFrame to simulate EndFrame
|
|
await asyncio.sleep(0.1)
|
|
await transport.queue_frame(EndTaskFrame(), direction=FrameDirection.UPSTREAM)
|
|
|
|
# Simulate an Interruption, which can happen if the user
|
|
# has started to speak
|
|
await asyncio.sleep(0.1)
|
|
await transport.queue_frame(
|
|
InterruptionTaskFrame(), direction=FrameDirection.UPSTREAM
|
|
)
|
|
|
|
# Create tasks explicitly for better control
|
|
pipeline_task = asyncio.create_task(run_pipeline())
|
|
queue_task = asyncio.create_task(queue_frame())
|
|
|
|
# Wait with timeout
|
|
done, pending = await asyncio.wait(
|
|
[pipeline_task, queue_task],
|
|
timeout=2.0,
|
|
return_when=asyncio.ALL_COMPLETED,
|
|
)
|
|
|
|
# If there are pending tasks, we timed out
|
|
if pending:
|
|
# Cancel all pending tasks
|
|
for t in pending:
|
|
t.cancel()
|
|
|
|
# Give limited time for cleanup, then move on regardless
|
|
try:
|
|
await asyncio.wait_for(
|
|
asyncio.gather(*pending, return_exceptions=True),
|
|
timeout=2.0,
|
|
)
|
|
except asyncio.TimeoutError:
|
|
pass # Cleanup took too long, continue anyway
|
|
|
|
pytest.fail("Test timed out after 2 second")
|