mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-06-25 08:48:13 +02:00
feat: simplify pipecat engine execution (#54)
This commit is contained in:
parent
99a768f291
commit
6ce25a589c
20 changed files with 52 additions and 1405 deletions
|
|
@ -1,69 +0,0 @@
|
|||
"""Engine Pre-Aggregator Processor
|
||||
|
||||
This processor sits before the user context aggregator in the pipeline and handles
|
||||
engine-specific callbacks for frames that need to be processed before aggregation.
|
||||
This ensures the engine can update context before the aggregator generates LLM frames.
|
||||
"""
|
||||
|
||||
from typing import Awaitable, Callable, Optional
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from api.services.pipecat.exceptions import VoicemailDetectedException
|
||||
from pipecat.frames.frames import (
|
||||
Frame,
|
||||
UserStartedSpeakingFrame,
|
||||
UserStoppedSpeakingFrame,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
|
||||
|
||||
class EnginePreAggregatorProcessor(FrameProcessor):
    """Frame processor that fires engine speaking callbacks ahead of aggregation.

    Intended to be placed in the pipeline before the user context aggregator
    so the engine gets a chance to update the LLM context before the
    aggregator sees the same frame.
    """

    def __init__(
        self,
        user_started_speaking_callback: Optional[Callable[[], Awaitable[None]]] = None,
        user_stopped_speaking_callback: Optional[Callable[[], Awaitable[None]]] = None,
        **kwargs,
    ):
        """Remember the optional speaking callbacks.

        Args:
            user_started_speaking_callback: Awaited when a
                ``UserStartedSpeakingFrame`` is processed; skipped when falsy.
            user_stopped_speaking_callback: Awaited when a
                ``UserStoppedSpeakingFrame`` is processed; skipped when falsy.
            **kwargs: Passed through unchanged to the base class initializer.
        """
        super().__init__(**kwargs)
        self._user_stopped_speaking_callback = user_stopped_speaking_callback
        self._user_started_speaking_callback = user_started_speaking_callback

    async def process_frame(self, frame: Frame, direction: FrameDirection):
        """Run the matching engine callback, then forward the frame.

        Every frame is pushed onward in the same direction, with one
        exception: a ``UserStoppedSpeakingFrame`` whose callback raises
        ``VoicemailDetectedException`` is dropped, so that the downstream
        user context aggregator does not act on it.
        """
        await super().process_frame(frame, direction)

        drop_frame = False
        if isinstance(frame, UserStartedSpeakingFrame):
            await self._handle_user_started_speaking()
        elif isinstance(frame, UserStoppedSpeakingFrame):
            try:
                await self._handle_user_stopped_speaking()
            except VoicemailDetectedException:
                # Voicemail was detected: withhold the stop-speaking frame so
                # the user context aggregator does not issue an LLM call.
                logger.debug("Voicemail detected, not pushing UserStoppedSpeakingFrame")
                drop_frame = True

        if not drop_frame:
            await self.push_frame(frame, direction)

    async def _handle_user_started_speaking(self):
        """Await the user-started-speaking callback when one is configured."""
        callback = self._user_started_speaking_callback
        if not callback:
            return
        await callback()

    async def _handle_user_stopped_speaking(self):
        """Await the user-stopped-speaking callback when one is configured.

        Any exception raised by the callback propagates to the caller.
        """
        callback = self._user_stopped_speaking_callback
        if not callback:
            return
        await callback()
|
||||
|
|
@ -9,7 +9,7 @@ from api.constants import (
|
|||
from api.services.pipecat.audio_config import AudioConfig
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.processors.audio.audio_buffer_processor import AudioBuffer
|
||||
from pipecat.processors.audio.audio_synchronizer import AudioSynchronizer
|
||||
from pipecat.processors.transcript_processor import TranscriptProcessor
|
||||
|
|
@ -39,7 +39,7 @@ def create_pipeline_components(audio_config: AudioConfig, engine: "PipecatEngine
|
|||
assistant_correct_aggregation_callback=engine.create_aggregation_correction_callback()
|
||||
)
|
||||
|
||||
context = OpenAILLMContext()
|
||||
context = LLMContext()
|
||||
|
||||
return audio_buffer, audio_synchronizer, transcript, context
|
||||
|
||||
|
|
@ -58,7 +58,6 @@ def build_pipeline(
|
|||
stt_mute_filter,
|
||||
pipeline_metrics_aggregator,
|
||||
user_idle_disconnect,
|
||||
engine_pre_aggregator_processor=None,
|
||||
):
|
||||
"""Build the main pipeline with all components"""
|
||||
# Register processors with synchronizer for merged audio
|
||||
|
|
@ -69,16 +68,12 @@ def build_pipeline(
|
|||
processors = [
|
||||
transport.input(), # Transport user input
|
||||
audio_buffer.input(), # Record input audio (only processes InputAudioRawFrame)
|
||||
stt_mute_filter,
|
||||
stt, # STT can now have audio_passthrough=False
|
||||
stt_mute_filter, # STTMuteFilters don't let VAD related events pass through if muted
|
||||
user_idle_disconnect,
|
||||
transcript.user(),
|
||||
]
|
||||
|
||||
# Insert engine pre-aggregator processor if provided (before user aggregator)
|
||||
if engine_pre_aggregator_processor:
|
||||
processors.append(engine_pre_aggregator_processor)
|
||||
|
||||
processors.extend(
|
||||
[
|
||||
user_context_aggregator,
|
||||
|
|
|
|||
|
|
@ -7,7 +7,6 @@ from pipecat.frames.frames import (
|
|||
Frame,
|
||||
HeartbeatFrame,
|
||||
LLMFullResponseStartFrame,
|
||||
LLMGeneratedTextFrame,
|
||||
LLMTextFrame,
|
||||
StartFrame,
|
||||
TTSSpeakFrame,
|
||||
|
|
@ -26,7 +25,6 @@ class PipelineEngineCallbacksProcessor(FrameProcessor):
|
|||
self,
|
||||
max_call_duration_seconds: int = 300,
|
||||
max_duration_end_task_callback: Optional[Callable[[], Awaitable[None]]] = None,
|
||||
llm_generated_text_callback: Optional[Callable[[], Awaitable[None]]] = None,
|
||||
generation_started_callback: Optional[Callable[[], Awaitable[None]]] = None,
|
||||
llm_text_frame_callback: Optional[Callable[[str], Awaitable[None]]] = None,
|
||||
):
|
||||
|
|
@ -34,7 +32,6 @@ class PipelineEngineCallbacksProcessor(FrameProcessor):
|
|||
self._start_time = None
|
||||
self._max_call_duration_seconds = max_call_duration_seconds
|
||||
self._max_duration_end_task_callback = max_duration_end_task_callback
|
||||
self._llm_generated_text_callback = llm_generated_text_callback
|
||||
self._generation_started_callback = generation_started_callback
|
||||
self._llm_text_frame_callback = llm_text_frame_callback
|
||||
self._end_task_frame_pushed = False
|
||||
|
|
@ -46,8 +43,6 @@ class PipelineEngineCallbacksProcessor(FrameProcessor):
|
|||
await self._start(frame)
|
||||
elif isinstance(frame, HeartbeatFrame):
|
||||
await self._check_call_duration()
|
||||
elif isinstance(frame, LLMGeneratedTextFrame):
|
||||
await self._generated_text_frame(frame)
|
||||
elif isinstance(frame, LLMFullResponseStartFrame):
|
||||
await self._generation_started()
|
||||
elif (
|
||||
|
|
@ -74,11 +69,6 @@ class PipelineEngineCallbacksProcessor(FrameProcessor):
|
|||
"Max call duration exceeded. Skipping EndTaskFrame since already sent"
|
||||
)
|
||||
|
||||
async def _generated_text_frame(self, _: LLMGeneratedTextFrame):
    """Notify the engine that an LLMGeneratedTextFrame arrived.

    The frame's contents are not inspected; the stored callback is awaited
    with no arguments, and nothing happens when no callback was supplied.
    """
    callback = self._llm_generated_text_callback
    if callback is None:
        return
    await callback()
|
||||
|
||||
async def _generation_started(self):
|
||||
if self._generation_started_callback:
|
||||
await self._generation_started_callback()
|
||||
|
|
|
|||
|
|
@ -7,9 +7,6 @@ from api.db import db_client
|
|||
from api.db.models import WorkflowModel
|
||||
from api.enums import WorkflowRunMode
|
||||
from api.services.pipecat.audio_config import AudioConfig, create_audio_config
|
||||
from api.services.pipecat.engine_pre_aggregator_processor import (
|
||||
EnginePreAggregatorProcessor,
|
||||
)
|
||||
from api.services.pipecat.event_handlers import (
|
||||
register_audio_data_handler,
|
||||
register_task_event_handler,
|
||||
|
|
@ -43,6 +40,9 @@ from api.services.workflow.pipecat_engine import PipecatEngine
|
|||
from api.services.workflow.workflow import WorkflowGraph
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.processors.aggregators.llm_response import LLMAssistantAggregatorParams
|
||||
from pipecat.processors.aggregators.llm_response_universal import (
|
||||
LLMContextAggregatorPair,
|
||||
)
|
||||
from pipecat.processors.filters.stt_mute_filter import (
|
||||
STTMuteConfig,
|
||||
STTMuteFilter,
|
||||
|
|
@ -357,21 +357,14 @@ async def _run_pipeline(
|
|||
expect_stripped_words=True,
|
||||
correct_aggregation_callback=engine.create_aggregation_correction_callback(),
|
||||
)
|
||||
context_aggregator = llm.create_context_aggregator(
|
||||
context_aggregator = LLMContextAggregatorPair(
|
||||
context, assistant_params=assistant_params
|
||||
)
|
||||
|
||||
# Create engine pre-aggregator processor for speaking events
|
||||
engine_pre_aggregator_processor = EnginePreAggregatorProcessor(
|
||||
user_started_speaking_callback=engine.create_user_started_speaking_callback(),
|
||||
user_stopped_speaking_callback=engine.create_user_stopped_speaking_callback(),
|
||||
)
|
||||
|
||||
# Create the engine callbacks processor, wiring in the engine's callbacks
|
||||
pipeline_engine_callback_processor = PipelineEngineCallbacksProcessor(
|
||||
max_call_duration_seconds=max_call_duration_seconds,
|
||||
max_duration_end_task_callback=engine.create_max_duration_callback(),
|
||||
llm_generated_text_callback=engine.create_llm_generated_text_callback(),
|
||||
generation_started_callback=engine.create_generation_started_callback(),
|
||||
llm_text_frame_callback=engine.handle_llm_text_frame,
|
||||
# Note: speaking event callbacks are now handled by pre-aggregator processor
|
||||
|
|
@ -398,11 +391,6 @@ async def _run_pipeline(
|
|||
user_context_aggregator = context_aggregator.user()
|
||||
assistant_context_aggregator = context_aggregator.assistant()
|
||||
|
||||
@assistant_context_aggregator.event_handler("on_push_aggregation")
|
||||
async def on_assistant_aggregator_push_context(_aggregator):
|
||||
logger.debug("Assistant aggregator push context – flushing pending transitions")
|
||||
await engine.flush_pending_transitions(source="context_push")
|
||||
|
||||
# Build the pipeline with the STT mute filter and context controller
|
||||
pipeline = build_pipeline(
|
||||
transport,
|
||||
|
|
@ -418,7 +406,6 @@ async def _run_pipeline(
|
|||
stt_mute_filter,
|
||||
pipeline_metrics_aggregator,
|
||||
user_idle_disconnect,
|
||||
engine_pre_aggregator_processor=engine_pre_aggregator_processor,
|
||||
)
|
||||
|
||||
# Create pipeline task with audio configuration
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue