mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-06-07 07:55:16 +02:00
fix: await pending variable extraction tasks before pipeline finishes (#198)
This commit is contained in:
parent
1967a71935
commit
d42c52dc87
2 changed files with 53 additions and 4 deletions
|
|
@ -145,7 +145,7 @@ def register_event_handlers(
|
|||
logger.debug(f"Added trace URL to gathered_context: {trace_url}")
|
||||
|
||||
# also consider existing gathered context in workflow_run
|
||||
gathered_context = {**gathered_context, **workflow_run.gathered_context}
|
||||
gathered_context = {**workflow_run.gathered_context, **gathered_context}
|
||||
|
||||
# Set user_speech call tag
|
||||
call_tags = gathered_context.get("call_tags", [])
|
||||
|
|
|
|||
|
|
@ -86,6 +86,7 @@ class PipecatEngine:
|
|||
self._current_node: Optional[Node] = None
|
||||
self._gathered_context: dict = {}
|
||||
self._user_response_timeout_task: Optional[asyncio.Task] = None
|
||||
self._pending_extraction_tasks: set[asyncio.Task] = set()
|
||||
|
||||
# Will be set later in initialize() when we have
|
||||
# access to _context
|
||||
|
|
@ -433,6 +434,9 @@ class PipecatEngine:
|
|||
|
||||
async def _do_extraction():
|
||||
try:
|
||||
logger.debug(
|
||||
f"Starting variable extraction for node: {node.name}"
|
||||
)
|
||||
extracted_data = (
|
||||
await self._variable_extraction_manager._perform_extraction(
|
||||
extraction_variables, parent_context, extraction_prompt
|
||||
|
|
@ -440,22 +444,64 @@ class PipecatEngine:
|
|||
)
|
||||
self._gathered_context.update(extracted_data)
|
||||
logger.debug(
|
||||
f"Variable extraction completed. Extracted: {extracted_data}"
|
||||
f"Variable extraction completed for node: {node.name}. Extracted: {extracted_data}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error during variable extraction: {str(e)}")
|
||||
logger.error(f"Error during variable extraction for node {node.name}: {str(e)}")
|
||||
|
||||
if run_in_background:
|
||||
logger.debug(
|
||||
f"Scheduling background variable extraction for node: {node.name}"
|
||||
)
|
||||
asyncio.create_task(_do_extraction())
|
||||
task = asyncio.create_task(
|
||||
_do_extraction(), name=f"variable-extraction:{node.name}"
|
||||
)
|
||||
self._pending_extraction_tasks.add(task)
|
||||
task.add_done_callback(self._pending_extraction_tasks.discard)
|
||||
else:
|
||||
logger.debug(
|
||||
f"Performing synchronous variable extraction for node: {node.name}"
|
||||
)
|
||||
await _do_extraction()
|
||||
|
||||
async def _await_pending_extractions(self, timeout: float = 5.0) -> None:
|
||||
"""Await all in-flight background extraction tasks.
|
||||
|
||||
Args:
|
||||
timeout: Maximum seconds to wait for pending extractions.
|
||||
"""
|
||||
if not self._pending_extraction_tasks:
|
||||
return
|
||||
|
||||
task_names = [t.get_name() for t in self._pending_extraction_tasks]
|
||||
logger.debug(
|
||||
f"Awaiting {len(self._pending_extraction_tasks)} pending extraction task(s): {task_names}"
|
||||
)
|
||||
start_time = asyncio.get_event_loop().time()
|
||||
try:
|
||||
results = await asyncio.wait_for(
|
||||
asyncio.gather(*self._pending_extraction_tasks, return_exceptions=True),
|
||||
timeout=timeout,
|
||||
)
|
||||
elapsed = asyncio.get_event_loop().time() - start_time
|
||||
# Log any exceptions returned by gather
|
||||
for task_name, result in zip(task_names, results):
|
||||
if isinstance(result, Exception):
|
||||
logger.error(
|
||||
f"Pending extraction task '{task_name}' failed: {result}"
|
||||
)
|
||||
logger.debug(
|
||||
f"All pending extraction tasks completed in {elapsed:.2f}s"
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
incomplete = [
|
||||
t.get_name() for t in self._pending_extraction_tasks if not t.done()
|
||||
]
|
||||
logger.warning(
|
||||
f"Timed out waiting for pending extraction tasks after {timeout}s. "
|
||||
f"Incomplete: {incomplete}"
|
||||
)
|
||||
|
||||
async def _setup_llm_context(self, node: Node) -> None:
|
||||
"""Common method to set up LLM context"""
|
||||
# Set OTel span name for tracing
|
||||
|
|
@ -602,6 +648,9 @@ class PipecatEngine:
|
|||
EndTaskReason.PIPELINE_ERROR.value,
|
||||
EndTaskReason.VOICEMAIL_DETECTED.value,
|
||||
):
|
||||
# Await any in-flight background extractions from previous nodes
|
||||
await self._await_pending_extractions()
|
||||
|
||||
# Perform final variable extraction synchronously before ending
|
||||
await self._perform_variable_extraction_if_needed(
|
||||
self._current_node, run_in_background=False
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue