mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-06-07 07:55:16 +02:00
chore: update pipecat to 1.3.0 (#379)
* chore: rename PipelineTask to PipelineWorker * fix: fix tests * chore: update pipecat submodule * fix: fix anyio same task cancellation scope
This commit is contained in:
parent
e695436fb3
commit
5ef3be92b5
26 changed files with 170 additions and 132 deletions
|
|
@ -21,7 +21,7 @@ from api.tasks.function_names import FunctionNames
|
|||
from pipecat.frames.frames import (
|
||||
Frame,
|
||||
)
|
||||
from pipecat.pipeline.task import PipelineTask
|
||||
from pipecat.pipeline.worker import PipelineWorker
|
||||
from pipecat.processors.audio.audio_buffer_processor import AudioBufferProcessor
|
||||
from pipecat.utils.enums import EndTaskReason
|
||||
|
||||
|
|
@ -58,7 +58,7 @@ async def _capture_call_event(
|
|||
|
||||
|
||||
def register_event_handlers(
|
||||
task: PipelineTask,
|
||||
task: PipelineWorker,
|
||||
transport,
|
||||
workflow_run_id: int,
|
||||
engine: PipecatEngine,
|
||||
|
|
@ -184,13 +184,13 @@ def register_event_handlers(
|
|||
)
|
||||
|
||||
@task.event_handler("on_pipeline_started")
|
||||
async def on_pipeline_started(_task: PipelineTask, _frame: Frame):
|
||||
async def on_pipeline_started(_task: PipelineWorker, _frame: Frame):
|
||||
logger.debug("In on_pipeline_started callback handler")
|
||||
ready_state["pipeline_started"] = True
|
||||
await maybe_trigger_initial_response()
|
||||
|
||||
@task.event_handler("on_pipeline_error")
|
||||
async def on_pipeline_error(_task: PipelineTask, frame: Frame):
|
||||
async def on_pipeline_error(_task: PipelineWorker, frame: Frame):
|
||||
logger.warning(f"Pipeline error for workflow run {workflow_run_id}: {frame}")
|
||||
try:
|
||||
workflow_run = await db_client.get_workflow_run_by_id(workflow_run_id)
|
||||
|
|
@ -218,7 +218,7 @@ def register_event_handlers(
|
|||
|
||||
@task.event_handler("on_pipeline_finished")
|
||||
async def on_pipeline_finished(
|
||||
task: PipelineTask,
|
||||
task: PipelineWorker,
|
||||
_frame: Frame,
|
||||
):
|
||||
logger.debug(f"In on_pipeline_finished callback handler")
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ from loguru import logger
|
|||
|
||||
from api.services.pipecat.audio_config import AudioConfig
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
from pipecat.pipeline.worker import PipelineParams, PipelineWorker
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.processors.audio.audio_buffer_processor import AudioBufferProcessor
|
||||
from pipecat.utils.run_context import turn_var
|
||||
|
|
@ -194,7 +194,7 @@ def create_pipeline_task(
|
|||
f"out: {audio_config.transport_out_sample_rate}Hz"
|
||||
)
|
||||
|
||||
task = PipelineTask(
|
||||
task = PipelineWorker(
|
||||
pipeline,
|
||||
params=pipeline_params,
|
||||
enable_tracing=True,
|
||||
|
|
|
|||
|
|
@ -51,6 +51,7 @@ from api.services.pipecat.tracing_config import (
|
|||
ensure_tracing,
|
||||
)
|
||||
from api.services.pipecat.transport_setup import create_webrtc_transport
|
||||
from api.services.pipecat.worker_runner import run_pipeline_worker
|
||||
from api.services.pipecat.ws_sender_registry import get_ws_sender
|
||||
from api.services.telephony import registry as telephony_registry
|
||||
from api.services.workflow.dto import ReactFlowDTO
|
||||
|
|
@ -61,7 +62,6 @@ from pipecat.audio.turn.smart_turn.local_smart_turn_v3 import LocalSmartTurnAnal
|
|||
from pipecat.audio.vad.silero import SileroVADAnalyzer
|
||||
from pipecat.audio.vad.vad_analyzer import VADParams
|
||||
from pipecat.extensions.voicemail.voicemail_detector import VoicemailDetector
|
||||
from pipecat.pipeline.base_task import PipelineTaskParams
|
||||
from pipecat.processors.aggregators.llm_response_universal import (
|
||||
LLMAssistantAggregatorParams,
|
||||
LLMContextAggregatorPair,
|
||||
|
|
@ -821,12 +821,15 @@ async def _run_pipeline(
|
|||
|
||||
try:
|
||||
# Run the pipeline
|
||||
loop = asyncio.get_running_loop()
|
||||
params = PipelineTaskParams(loop=loop)
|
||||
await task.run(params)
|
||||
await run_pipeline_worker(task)
|
||||
logger.info(f"Task completed for run {workflow_run_id}")
|
||||
except asyncio.CancelledError:
|
||||
logger.warning("Received CancelledError in _run_pipeline")
|
||||
finally:
|
||||
# Close MCP sessions here, not in engine.cleanup(). The anyio cancel
|
||||
# scopes opened by MCPClient.start() in engine.initialize() are
|
||||
# task-affine; this finally runs in the same task as initialize(),
|
||||
# whereas engine.cleanup() runs in a pipecat event-handler task.
|
||||
await engine.close_mcp_sessions()
|
||||
await feedback_observer.cleanup()
|
||||
logger.debug(f"Cleaned up context providers for workflow run {workflow_run_id}")
|
||||
|
|
|
|||
36
api/services/pipecat/worker_runner.py
Normal file
36
api/services/pipecat/worker_runner.py
Normal file
|
|
@ -0,0 +1,36 @@
|
|||
import asyncio
|
||||
|
||||
from pipecat.pipeline.worker import PipelineWorker
|
||||
from pipecat.workers.runner import WorkerRunner
|
||||
|
||||
|
||||
async def run_pipeline_worker(
|
||||
worker: PipelineWorker,
|
||||
*,
|
||||
handle_sigint: bool = False,
|
||||
handle_sigterm: bool = False,
|
||||
auto_end: bool = True,
|
||||
) -> None:
|
||||
"""Run a pipeline worker through the v1.3 worker runner lifecycle."""
|
||||
runner = WorkerRunner(handle_sigint=handle_sigint, handle_sigterm=handle_sigterm)
|
||||
await runner.add_workers(worker)
|
||||
await runner.run(auto_end=auto_end)
|
||||
|
||||
|
||||
async def wait_for_pipeline_worker_started(
|
||||
worker: PipelineWorker,
|
||||
*,
|
||||
timeout: float = 3.0,
|
||||
run_task: asyncio.Task | None = None,
|
||||
) -> None:
|
||||
"""Wait until a pipeline worker has fired its stable start lifecycle."""
|
||||
|
||||
async def _wait_until_started():
|
||||
while worker.started_at is None:
|
||||
if run_task and run_task.done():
|
||||
await run_task
|
||||
if worker.has_finished():
|
||||
raise RuntimeError("PipelineWorker finished before starting")
|
||||
await asyncio.sleep(0.01)
|
||||
|
||||
await asyncio.wait_for(_wait_until_started(), timeout=timeout)
|
||||
|
|
@ -79,8 +79,12 @@ class McpToolSession:
|
|||
self.available: bool = False
|
||||
|
||||
async def start(self) -> None:
|
||||
"""Connect, initialize, and cache the tool list. Never raises —
|
||||
on any failure the session is marked unavailable."""
|
||||
"""Connect, initialize, and cache the tool list.
|
||||
|
||||
Never raises on a connect failure — a dead/unreachable MCP server
|
||||
leaves the session marked unavailable (``available = False``). Genuine
|
||||
external cancellation, KeyboardInterrupt, and SystemExit are re-raised
|
||||
(see the CancelledError handling below and ``_degrade``)."""
|
||||
try:
|
||||
params = build_streamable_http_params(
|
||||
url=self._url,
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ from pipecat.frames.frames import (
|
|||
LLMContextFrame,
|
||||
TTSSpeakFrame,
|
||||
)
|
||||
from pipecat.pipeline.task import PipelineTask
|
||||
from pipecat.pipeline.worker import PipelineWorker
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.services.llm_service import FunctionCallParams
|
||||
from pipecat.services.settings import LLMSettings
|
||||
|
|
@ -60,7 +60,7 @@ class PipecatEngine:
|
|||
def __init__(
|
||||
self,
|
||||
*,
|
||||
task: Optional[PipelineTask] = None,
|
||||
task: Optional[PipelineWorker] = None,
|
||||
llm: Optional["LLMService"] = None,
|
||||
inference_llm: Optional["LLMService"] = None,
|
||||
context: Optional[LLMContext] = None,
|
||||
|
|
@ -842,7 +842,7 @@ class PipecatEngine:
|
|||
"""
|
||||
self.context = context
|
||||
|
||||
def set_task(self, task: PipelineTask) -> None:
|
||||
def set_task(self, task: PipelineWorker) -> None:
|
||||
"""Set the pipeline task.
|
||||
|
||||
This allows setting the task after the engine has been created,
|
||||
|
|
@ -955,7 +955,15 @@ class PipecatEngine:
|
|||
exc_info=True,
|
||||
)
|
||||
|
||||
async def _close_mcp_sessions(self) -> None:
|
||||
async def close_mcp_sessions(self) -> None:
|
||||
"""Close all open MCP tool sessions.
|
||||
|
||||
Must run in the same task that ran initialize() (which opened the
|
||||
sessions via _open_mcp_sessions). The MCP client's underlying anyio
|
||||
cancel scopes are task-affine — they must be exited from the task that
|
||||
entered them — so this is invoked from _run_pipeline's finally, not
|
||||
from cleanup() (which runs in a pipecat event-handler task).
|
||||
"""
|
||||
for tool_uuid, session in list(self._mcp_sessions.items()):
|
||||
try:
|
||||
await session.close()
|
||||
|
|
@ -964,7 +972,14 @@ class PipecatEngine:
|
|||
self._mcp_sessions = {}
|
||||
|
||||
async def cleanup(self):
|
||||
"""Clean up engine resources on disconnect."""
|
||||
"""Clean up engine resources on disconnect.
|
||||
|
||||
MCP tool sessions are intentionally NOT closed here — see
|
||||
close_mcp_sessions(). This method runs in a pipecat event-handler task
|
||||
(on_pipeline_finished), a different task than the one that opened the
|
||||
MCP sessions; closing them here raises "Attempted to exit cancel scope
|
||||
in a different task than it was entered in".
|
||||
"""
|
||||
# Cancel any pending timeout tasks
|
||||
if (
|
||||
self._user_response_timeout_task
|
||||
|
|
@ -973,11 +988,5 @@ class PipecatEngine:
|
|||
self._user_response_timeout_task.cancel()
|
||||
|
||||
# Cancel any in-flight background summarization.
|
||||
# MCP sessions are closed in a finally block so they are guaranteed to
|
||||
# run even if the summarization cleanup raises an exception.
|
||||
try:
|
||||
if self._context_summarization_manager:
|
||||
await self._context_summarization_manager.cleanup()
|
||||
finally:
|
||||
# Close any open MCP tool sessions
|
||||
await self._close_mcp_sessions()
|
||||
if self._context_summarization_manager:
|
||||
await self._context_summarization_manager.cleanup()
|
||||
|
|
|
|||
|
|
@ -22,7 +22,6 @@ from pipecat.frames.frames import (
|
|||
TTSStoppedFrame,
|
||||
)
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.processors.aggregators.llm_response_universal import (
|
||||
LLMAssistantAggregatorParams,
|
||||
|
|
@ -45,6 +44,10 @@ from api.services.pipecat.tracing_config import (
|
|||
build_remote_parent_context,
|
||||
get_trace_url,
|
||||
)
|
||||
from api.services.pipecat.worker_runner import (
|
||||
run_pipeline_worker,
|
||||
wait_for_pipeline_worker_started,
|
||||
)
|
||||
from api.services.workflow.dto import ReactFlowDTO
|
||||
from api.services.workflow.pipecat_engine import PipecatEngine
|
||||
from api.services.workflow.workflow_graph import WorkflowGraph
|
||||
|
|
@ -534,8 +537,7 @@ async def execute_text_chat_pending_turn(
|
|||
conversation_type="text",
|
||||
additional_span_attributes=trace_span_attributes,
|
||||
)
|
||||
runner = PipelineRunner(handle_sigint=False, handle_sigterm=False)
|
||||
runner_task = asyncio.create_task(runner.run(task))
|
||||
runner_task = asyncio.create_task(run_pipeline_worker(task))
|
||||
|
||||
engine.set_task(task)
|
||||
engine.set_audio_config(audio_config)
|
||||
|
|
@ -548,7 +550,7 @@ async def execute_text_chat_pending_turn(
|
|||
)
|
||||
|
||||
try:
|
||||
await asyncio.wait_for(task._pipeline_start_event.wait(), timeout=5.0)
|
||||
await wait_for_pipeline_worker_started(task, timeout=5.0, run_task=runner_task)
|
||||
|
||||
await engine.initialize()
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue