From 6b408e588c61ad2a3d61593f48d19371bf090d7d Mon Sep 17 00:00:00 2001 From: Abhishek Kumar Date: Mon, 26 Jan 2026 12:13:55 +0530 Subject: [PATCH] fix: fix variable extraction during pipeline execution flow --- api/services/campaign/call_dispatcher.py | 4 +- api/services/workflow/pipecat_engine.py | 12 +- .../test_pipecat_engine_context_update.py | 75 ++--- api/tests/test_pipecat_engine_tool_calls.py | 3 +- ...test_pipecat_engine_variable_extraction.py | 315 ++++++++++++++++++ api/tests/test_pipeline_cancellation.py | 2 +- api/tests/test_user_idle_handler.py | 3 +- 7 files changed, 352 insertions(+), 62 deletions(-) create mode 100644 api/tests/test_pipecat_engine_variable_extraction.py diff --git a/api/services/campaign/call_dispatcher.py b/api/services/campaign/call_dispatcher.py index 353a42d..ec85c8d 100644 --- a/api/services/campaign/call_dispatcher.py +++ b/api/services/campaign/call_dispatcher.py @@ -182,7 +182,7 @@ class CampaignCallDispatcher: # Get provider first to determine the mode provider = await self.get_telephony_provider(campaign.organization_id) workflow_run_mode = provider.PROVIDER_NAME - + logger.info(f"Provider name: {provider.PROVIDER_NAME}") logger.info(f"Queued run context: {queued_run.context_variables}") @@ -193,7 +193,7 @@ class CampaignCallDispatcher: "campaign_id": campaign.id, "provider": provider.PROVIDER_NAME, } - + logger.info(f"Final initial_context: {initial_context}") # Create workflow run with queued_run_id tracking diff --git a/api/services/workflow/pipecat_engine.py b/api/services/workflow/pipecat_engine.py index d3b3cc1..8a810dc 100644 --- a/api/services/workflow/pipecat_engine.py +++ b/api/services/workflow/pipecat_engine.py @@ -199,6 +199,13 @@ class PipecatEngine: f"Function: {name} -> transitioning to node: {transition_to_node}" ) logger.info(f"Arguments: {function_call_params.arguments}") + + # Perform variable extraction before transitioning to new node + await self._perform_variable_extraction_if_needed(self._current_node) + + # Set context for the new node, so that when the function call result + # frame is received by LLMContextAggregator and an LLM generation + # is done, we have updated context and functions await self.set_node(transition_to_node) try: @@ -208,11 +215,6 @@ class PipecatEngine: This way, when we do set_node from within this function, and go for LLM completion with updated system prompts, the context is updated with function call result. """ - # Perform variable extraction before transitioning to new node - await self._perform_variable_extraction_if_needed( - self._current_node - ) - # Queue EndFrame if we just transitioned to EndNode if self._current_node.is_end: await self.send_end_task_frame( diff --git a/api/tests/test_pipecat_engine_context_update.py b/api/tests/test_pipecat_engine_context_update.py index 1a67665..e20a7ab 100644 --- a/api/tests/test_pipecat_engine_context_update.py +++ b/api/tests/test_pipecat_engine_context_update.py @@ -258,8 +258,7 @@ async def run_pipeline_and_capture_context( # Create pipeline task task = PipelineTask( - pipeline, - params=PipelineParams(allow_interruptions=False), + pipeline, params=PipelineParams(allow_interruptions=False), enable_rtvi=False ) engine.set_task(task) @@ -311,25 +310,17 @@ class TestContextUpdateBeforeNextCompletion: transition completes. The test verifies the context is still correctly updated. """ # Step 0 (Start node): call collect_info to transition to agent - step_0_chunks = MockLLMService.create_multiple_function_call_chunks( - [ - { - "name": "collect_info", - "arguments": {}, - "tool_call_id": "call_transition_1", - }, - ] + step_0_chunks = MockLLMService.create_function_call_chunks( + function_name="collect_info", + arguments={}, + tool_call_id="call_transition_1", ) # Step 1 (Agent node): call end_call to transition to end - step_1_chunks = MockLLMService.create_multiple_function_call_chunks( - [ - { - "name": "end_call", - "arguments": {}, - "tool_call_id": "call_transition_2", - }, - ] + step_1_chunks = MockLLMService.create_function_call_chunks( + function_name="end_call", + arguments={}, + tool_call_id="call_transition_2", ) # Step 2 (End node): text response (end node has no outgoing edges) @@ -393,25 +384,17 @@ class TestContextUpdateBeforeNextCompletion: is handled correctly. """ # Step 0 (Start node): call collect_info to transition to agent - step_0_chunks = MockLLMService.create_multiple_function_call_chunks( - [ - { - "name": "collect_info", - "arguments": {}, - "tool_call_id": "call_transition_1", - }, - ] + step_0_chunks = MockLLMService.create_function_call_chunks( + function_name="collect_info", + arguments={}, + tool_call_id="call_transition_1", ) # Step 1 (Agent node): call end_call to transition to end - step_1_chunks = MockLLMService.create_multiple_function_call_chunks( - [ - { - "name": "end_call", - "arguments": {}, - "tool_call_id": "call_transition_2", - }, - ] + step_1_chunks = MockLLMService.create_function_call_chunks( + function_name="end_call", + arguments={}, + tool_call_id="call_transition_2", ) # Step 2 (End node): text response @@ -472,25 +455,17 @@ class TestContextUpdateBeforeNextCompletion: - Tool call messages and results """ # Step 0 (Start node): call collect_info to transition to agent - step_0_chunks = MockLLMService.create_multiple_function_call_chunks( - [ - { - "name": "collect_info", - "arguments": {}, - "tool_call_id": "call_transition_1", - }, - ] + step_0_chunks = MockLLMService.create_function_call_chunks( + function_name="collect_info", + arguments={}, + tool_call_id="call_transition_1", ) # Step 1 (Agent node): call end_call to transition to end - step_1_chunks = MockLLMService.create_multiple_function_call_chunks( - [ - { - "name": "end_call", - "arguments": {}, - "tool_call_id": "call_transition_2", - }, - ] + step_1_chunks = MockLLMService.create_function_call_chunks( + function_name="end_call", + arguments={}, + tool_call_id="call_transition_2", ) # Step 2 (End node): text response diff --git a/api/tests/test_pipecat_engine_tool_calls.py b/api/tests/test_pipecat_engine_tool_calls.py index 7225a9b..4fe97b6 100644 --- a/api/tests/test_pipecat_engine_tool_calls.py +++ b/api/tests/test_pipecat_engine_tool_calls.py @@ -98,8 +98,7 @@ async def run_pipeline_with_tool_calls( # Create a real pipeline task task = PipelineTask( - pipeline, - params=PipelineParams(allow_interruptions=False), + pipeline, params=PipelineParams(allow_interruptions=False), enable_rtvi=False ) engine.set_task(task) diff --git a/api/tests/test_pipecat_engine_variable_extraction.py b/api/tests/test_pipecat_engine_variable_extraction.py new file mode 100644 index 0000000..88b5e22 --- /dev/null +++ b/api/tests/test_pipecat_engine_variable_extraction.py @@ -0,0 +1,315 @@ +"""Tests for verifying variable extraction is triggered for the correct node during transitions. + +This module tests that when the LLM calls a node transition function, variable extraction +is performed for the SOURCE node (where the conversation happened), not the TARGET node. + +The key behavior being tested: +1. LLM calls a transition function (e.g., "collect_info") from START node +2. START node has extraction_enabled=True with extraction_variables +3. AGENT node (target) has extraction_enabled=False +4. Variable extraction should be triggered for START node's variables +5. Variable extraction should NOT be triggered for AGENT node +""" + +import asyncio +from typing import Any, Dict, List +from unittest.mock import AsyncMock, patch + +import pytest + +from api.services.workflow.dto import ( + EdgeDataDTO, + ExtractionVariableDTO, + NodeDataDTO, + NodeType, + Position, + ReactFlowDTO, + RFEdgeDTO, + RFNodeDTO, + VariableType, +) +from api.services.workflow.pipecat_engine import PipecatEngine +from api.services.workflow.pipecat_engine_variable_extractor import ( + VariableExtractionManager, +) +from api.services.workflow.workflow import WorkflowGraph +from api.tests.conftest import MockTransportProcessor +from pipecat.frames.frames import LLMContextFrame +from pipecat.pipeline.pipeline import Pipeline +from pipecat.pipeline.runner import PipelineRunner +from pipecat.pipeline.task import PipelineParams, PipelineTask +from pipecat.processors.aggregators.llm_context import LLMContext +from pipecat.processors.aggregators.llm_response import LLMAssistantAggregatorParams +from pipecat.processors.aggregators.llm_response_universal import ( + LLMContextAggregatorPair, +) +from pipecat.tests import MockLLMService, MockTTSService + +# Define prompts for test nodes +START_NODE_PROMPT = "Start Node System Prompt" +AGENT_NODE_PROMPT = "Agent Node System Prompt" +END_NODE_PROMPT = "End Node System Prompt" + + +@pytest.fixture +def three_node_workflow_with_extraction_on_start() -> WorkflowGraph: + """Create a three-node workflow where only the start node has extraction enabled. + + The workflow has: + - Start node with extraction_enabled=True and extraction_variables set + - Agent node with extraction_enabled=False (default) + - End node with extraction_enabled=False (default) + + This is used to test that variable extraction is triggered for the correct node + during transitions. + """ + dto = ReactFlowDTO( + nodes=[ + RFNodeDTO( + id="start", + type=NodeType.startNode, + position=Position(x=0, y=0), + data=NodeDataDTO( + name="Start Call", + prompt=START_NODE_PROMPT, + is_start=True, + allow_interrupt=False, + add_global_prompt=False, + extraction_enabled=True, + extraction_prompt="Extract the user's name from the conversation.", + extraction_variables=[ + ExtractionVariableDTO( + name="user_name", + type=VariableType.string, + prompt="The name the user provided", + ), + ], + ), + ), + RFNodeDTO( + id="agent", + type=NodeType.agentNode, + position=Position(x=0, y=200), + data=NodeDataDTO( + name="Collect Info", + prompt=AGENT_NODE_PROMPT, + allow_interrupt=False, + add_global_prompt=False, + extraction_enabled=False, # Explicitly disabled + ), + ), + RFNodeDTO( + id="end", + type=NodeType.endNode, + position=Position(x=0, y=400), + data=NodeDataDTO( + name="End Call", + prompt=END_NODE_PROMPT, + is_end=True, + allow_interrupt=False, + add_global_prompt=False, + extraction_enabled=False, # Explicitly disabled + ), + ), + ], + edges=[ + RFEdgeDTO( + id="start-agent", + source="start", + target="agent", + data=EdgeDataDTO( + label="Collect Info", + condition="When user has been greeted, proceed to collect information", + ), + ), + RFEdgeDTO( + id="agent-end", + source="agent", + target="end", + data=EdgeDataDTO( + label="End Call", + condition="When information collection is complete, end the call", + ), + ), + ], + ) + return WorkflowGraph(dto) + + +class TestVariableExtractionDuringTransitions: + """Test that variable extraction is triggered for the correct node during transitions.""" + + @pytest.mark.asyncio + async def test_extraction_called_for_source_node_not_target_node( + self, three_node_workflow_with_extraction_on_start: WorkflowGraph + ): + """Test that when transitioning from START to AGENT, extraction is called for START node. + + Scenario: + 1. Start node has extraction_enabled=True with extraction_variables + 2. Agent node has extraction_enabled=False + 3. LLM calls transition function to move from START to AGENT + 4. VERIFY: Variable extraction should be called for START node's variables + 5. VERIFY: Variable extraction should NOT be called for AGENT node + + This test verifies that extraction happens for the SOURCE node of a transition, + which is the node where the conversation context that needs extraction occurred. + """ + # Track which nodes had extraction performed + extraction_calls: List[Dict[str, Any]] = [] + + # Step 0 (Start node): call collect_info to transition to agent + step_0_chunks = MockLLMService.create_function_call_chunks( + function_name="collect_info", + arguments={}, + tool_call_id="call_transition_1", + ) + + # Step 1 (Agent node): call end_call to transition to end + step_1_chunks = MockLLMService.create_function_call_chunks( + function_name="end_call", + arguments={}, + tool_call_id="call_transition_2", + ) + + # Step 2 (End node): text response + step_2_chunks = MockLLMService.create_text_chunks("Goodbye!") + + mock_steps = [step_0_chunks, step_1_chunks, step_2_chunks] + + # Create mock LLM + llm = MockLLMService(mock_steps=mock_steps, chunk_delay=0.001) + + # Create MockTTSService + tts = MockTTSService(mock_audio_duration_ms=10, frame_delay=0) + + mock_transport_emulator = MockTransportProcessor(emit_bot_speaking=False) + + # Create LLM context + context = LLMContext() + + # Add assistant context aggregator + assistant_params = LLMAssistantAggregatorParams(expect_stripped_words=True) + context_aggregator = LLMContextAggregatorPair( + context, assistant_params=assistant_params + ) + assistant_context_aggregator = context_aggregator.assistant() + + workflow = three_node_workflow_with_extraction_on_start + + # Create PipecatEngine with the workflow + engine = PipecatEngine( + llm=llm, + context=context, + workflow=workflow, + call_context_vars={"customer_name": "Test User"}, + workflow_run_id=1, + ) + + # Patch _perform_variable_extraction_if_needed to track calls + original_perform_extraction = engine._perform_variable_extraction_if_needed + + async def tracked_perform_extraction(node): + extraction_calls.append( + { + "node_id": node.id if node else None, + "node_name": node.name if node else None, + "extraction_enabled": node.extraction_enabled if node else None, + "extraction_variables": node.extraction_variables if node else None, + } + ) + # Call original to maintain behavior + await original_perform_extraction(node) + + engine._perform_variable_extraction_if_needed = tracked_perform_extraction + + # Create the pipeline + pipeline = Pipeline( + [ + llm, + tts, + mock_transport_emulator, + assistant_context_aggregator, + ] + ) + + # Create pipeline task + task = PipelineTask( + pipeline, + params=PipelineParams(allow_interruptions=False), + enable_rtvi=False, + ) + + engine.set_task(task) + + # Patch DB calls and extraction manager + with patch( + "api.services.workflow.pipecat_engine.get_organization_id_from_workflow_run", + new_callable=AsyncMock, + return_value=1, + ): + with patch( + "api.services.workflow.pipecat_engine.apply_disposition_mapping", + new_callable=AsyncMock, + return_value="completed", + ): + # Mock the actual extraction to avoid needing a real LLM + with patch.object( + VariableExtractionManager, + "_perform_extraction", + new_callable=AsyncMock, + return_value={"user_name": "John Doe"}, + ): + runner = PipelineRunner() + + async def run_pipeline(): + await runner.run(task) + + async def initialize_engine(): + await asyncio.sleep(0.01) + await engine.initialize() + await engine.llm.queue_frame(LLMContextFrame(engine.context)) + + await asyncio.gather(run_pipeline(), initialize_engine()) + + # Should have 3 LLM generations + assert llm.get_current_step() == 3 + + # Verify extraction was called during transitions + # The key assertion: when transitioning from START to AGENT, + # the extraction should be for START node (which has extraction enabled) + + # Filter to only calls where extraction was actually attempted + # (node has extraction_enabled=True and extraction_variables) + extraction_enabled_calls = [ + call + for call in extraction_calls + if call["extraction_enabled"] and call["extraction_variables"] + ] + + # START node has extraction enabled, so when transitioning FROM start, + # extraction should be triggered for START's variables + assert len(extraction_enabled_calls) >= 1, ( + f"Expected at least 1 extraction call for start node, got {len(extraction_enabled_calls)}. " + f"All calls: {extraction_calls}" + ) + + # Verify the extraction was called for the START node + start_extraction_calls = [ + call for call in extraction_enabled_calls if call["node_id"] == "start" + ] + assert len(start_extraction_calls) >= 1, ( + f"Expected extraction to be called for START node (which has extraction enabled), " + f"but got calls for: {[c['node_id'] for c in extraction_enabled_calls]}" + ) + + # Verify extraction was NOT called for AGENT node + agent_extraction_calls = [ + call + for call in extraction_calls + if call["node_id"] == "agent" and call["extraction_enabled"] + ] + assert len(agent_extraction_calls) == 0, ( + f"Expected NO extraction calls for AGENT node (extraction disabled), " + f"but got {len(agent_extraction_calls)} calls" + ) diff --git a/api/tests/test_pipeline_cancellation.py b/api/tests/test_pipeline_cancellation.py index 45f04c5..02ea5f4 100644 --- a/api/tests/test_pipeline_cancellation.py +++ b/api/tests/test_pipeline_cancellation.py @@ -50,7 +50,7 @@ async def test_interruption_with_blocked_end_frame(): transport = MockTransport() pipeline = Pipeline([transport, busy_wait_processor]) - task = PipelineTask(pipeline) + task = PipelineTask(pipeline, enable_rtvi=False) async def run_pipeline(): loop = asyncio.get_running_loop() diff --git a/api/tests/test_user_idle_handler.py b/api/tests/test_user_idle_handler.py index ac5252c..73e9cba 100644 --- a/api/tests/test_user_idle_handler.py +++ b/api/tests/test_user_idle_handler.py @@ -106,8 +106,7 @@ async def run_pipeline_with_user_idle( # Create pipeline task task = PipelineTask( - pipeline, - params=PipelineParams(allow_interruptions=False), + pipeline, params=PipelineParams(allow_interruptions=False), enable_rtvi=False ) engine.set_task(task)