fix: await pending variable extraction tasks before pipeline finishes

This commit is contained in:
Sabiha Khan 2026-03-23 09:40:20 +05:30
parent c61a3843a5
commit 563e631522
2 changed files with 53 additions and 4 deletions

View file

@ -145,7 +145,7 @@ def register_event_handlers(
logger.debug(f"Added trace URL to gathered_context: {trace_url}")
# also consider existing gathered context in workflow_run
gathered_context = {**gathered_context, **workflow_run.gathered_context}
gathered_context = {**workflow_run.gathered_context, **gathered_context}
# Set user_speech call tag
call_tags = gathered_context.get("call_tags", [])

View file

@ -86,6 +86,7 @@ class PipecatEngine:
self._current_node: Optional[Node] = None
self._gathered_context: dict = {}
self._user_response_timeout_task: Optional[asyncio.Task] = None
self._pending_extraction_tasks: set[asyncio.Task] = set()
# Will be set later in initialize() when we have
# access to _context
@ -433,6 +434,9 @@ class PipecatEngine:
async def _do_extraction():
try:
logger.debug(
f"Starting variable extraction for node: {node.name}"
)
extracted_data = (
await self._variable_extraction_manager._perform_extraction(
extraction_variables, parent_context, extraction_prompt
@ -440,22 +444,64 @@ class PipecatEngine:
)
self._gathered_context.update(extracted_data)
logger.debug(
f"Variable extraction completed. Extracted: {extracted_data}"
f"Variable extraction completed for node: {node.name}. Extracted: {extracted_data}"
)
except Exception as e:
logger.error(f"Error during variable extraction: {str(e)}")
logger.error(f"Error during variable extraction for node {node.name}: {str(e)}")
if run_in_background:
logger.debug(
f"Scheduling background variable extraction for node: {node.name}"
)
asyncio.create_task(_do_extraction())
task = asyncio.create_task(
_do_extraction(), name=f"variable-extraction:{node.name}"
)
self._pending_extraction_tasks.add(task)
task.add_done_callback(self._pending_extraction_tasks.discard)
else:
logger.debug(
f"Performing synchronous variable extraction for node: {node.name}"
)
await _do_extraction()
async def _await_pending_extractions(self, timeout: float = 5.0) -> None:
    """Wait for in-flight background extraction tasks, bounded by a timeout.

    Tasks that are still running once ``timeout`` elapses are reported by
    name and then cancelled, so none of them keeps running after this
    call returns.

    Args:
        timeout: Maximum seconds to wait for pending extractions.
    """
    # Snapshot the set: every task's done-callback discards it from
    # self._pending_extraction_tasks, so the live set shrinks while we wait
    # and cannot be used afterwards to tell which tasks were unfinished.
    tasks = list(self._pending_extraction_tasks)
    if not tasks:
        return

    task_names = [t.get_name() for t in tasks]
    logger.debug(
        f"Awaiting {len(tasks)} pending extraction task(s): {task_names}"
    )

    loop = asyncio.get_running_loop()
    start_time = loop.time()
    try:
        # asyncio.wait() leaves unfinished tasks untouched on timeout and
        # returns them, which lets us name exactly what did not complete.
        # (wait_for(gather(...)) cancels the children before raising, so
        # nothing is left to inspect by the time TimeoutError is seen.)
        _, pending = await asyncio.wait(tasks, timeout=timeout)
    except asyncio.CancelledError:
        # The caller is being cancelled: do not leave extractions orphaned.
        for task in tasks:
            task.cancel()
        raise
    elapsed = loop.time() - start_time

    # Report outcomes in the same order as the names logged above.
    for task in tasks:
        if task in pending:
            continue
        if task.cancelled():
            # CancelledError is a BaseException, so it needs its own check.
            logger.warning(
                f"Pending extraction task '{task.get_name()}' was cancelled"
            )
            continue
        # Reading the exception also marks it as retrieved for asyncio.
        error = task.exception()
        if error is not None:
            logger.error(
                f"Pending extraction task '{task.get_name()}' failed: {error}"
            )

    if not pending:
        logger.debug(
            f"All pending extraction tasks completed in {elapsed:.2f}s"
        )
        return

    incomplete = [t.get_name() for t in tasks if t in pending]
    logger.warning(
        f"Timed out waiting for pending extraction tasks after {timeout}s. "
        f"Incomplete: {incomplete}"
    )
    for task in pending:
        task.cancel()
    # Let the cancellations land before returning so no extraction task is
    # still running once the caller proceeds.
    await asyncio.gather(*pending, return_exceptions=True)
async def _setup_llm_context(self, node: Node) -> None:
"""Common method to set up LLM context"""
# Set OTel span name for tracing
@ -598,6 +644,9 @@ class PipecatEngine:
EndTaskReason.PIPELINE_ERROR.value,
EndTaskReason.VOICEMAIL_DETECTED.value,
):
# Await any in-flight background extractions from previous nodes
await self._await_pending_extractions()
# Perform final variable extraction synchronously before ending
await self._perform_variable_extraction_if_needed(
self._current_node, run_in_background=False