fix: fix variable extraction during pipeline execution flow

This commit is contained in:
Abhishek Kumar 2026-01-26 12:13:55 +05:30
parent 0e70a77f17
commit 6b408e588c
7 changed files with 352 additions and 62 deletions

View file

@ -182,7 +182,7 @@ class CampaignCallDispatcher:
# Get provider first to determine the mode
provider = await self.get_telephony_provider(campaign.organization_id)
workflow_run_mode = provider.PROVIDER_NAME
logger.info(f"Provider name: {provider.PROVIDER_NAME}")
logger.info(f"Queued run context: {queued_run.context_variables}")
@ -193,7 +193,7 @@ class CampaignCallDispatcher:
"campaign_id": campaign.id,
"provider": provider.PROVIDER_NAME,
}
logger.info(f"Final initial_context: {initial_context}")
# Create workflow run with queued_run_id tracking

View file

@ -199,6 +199,13 @@ class PipecatEngine:
f"Function: {name} -> transitioning to node: {transition_to_node}"
)
logger.info(f"Arguments: {function_call_params.arguments}")
# Perform variable extraction before transitioning to new node
await self._perform_variable_extraction_if_needed(self._current_node)
# Set context for the new node, so that when the function call result
# frame is received by LLMContextAggregator and an LLM generation
# is done, we have updated context and functions
await self.set_node(transition_to_node)
try:
@ -208,11 +215,6 @@ class PipecatEngine:
This way, when we do set_node from within this function, and go for LLM completion with updated
system prompts, the context is updated with function call result.
"""
# Perform variable extraction before transitioning to new node
await self._perform_variable_extraction_if_needed(
self._current_node
)
# Queue EndFrame if we just transitioned to EndNode
if self._current_node.is_end:
await self.send_end_task_frame(

View file

@ -258,8 +258,7 @@ async def run_pipeline_and_capture_context(
# Create pipeline task
task = PipelineTask(
pipeline,
params=PipelineParams(allow_interruptions=False),
pipeline, params=PipelineParams(allow_interruptions=False), enable_rtvi=False
)
engine.set_task(task)
@ -311,25 +310,17 @@ class TestContextUpdateBeforeNextCompletion:
transition completes. The test verifies the context is still correctly updated.
"""
# Step 0 (Start node): call collect_info to transition to agent
step_0_chunks = MockLLMService.create_multiple_function_call_chunks(
[
{
"name": "collect_info",
"arguments": {},
"tool_call_id": "call_transition_1",
},
]
step_0_chunks = MockLLMService.create_function_call_chunks(
function_name="collect_info",
arguments={},
tool_call_id="call_transition_1",
)
# Step 1 (Agent node): call end_call to transition to end
step_1_chunks = MockLLMService.create_multiple_function_call_chunks(
[
{
"name": "end_call",
"arguments": {},
"tool_call_id": "call_transition_2",
},
]
step_1_chunks = MockLLMService.create_function_call_chunks(
function_name="end_call",
arguments={},
tool_call_id="call_transition_2",
)
# Step 2 (End node): text response (end node has no outgoing edges)
@ -393,25 +384,17 @@ class TestContextUpdateBeforeNextCompletion:
is handled correctly.
"""
# Step 0 (Start node): call collect_info to transition to agent
step_0_chunks = MockLLMService.create_multiple_function_call_chunks(
[
{
"name": "collect_info",
"arguments": {},
"tool_call_id": "call_transition_1",
},
]
step_0_chunks = MockLLMService.create_function_call_chunks(
function_name="collect_info",
arguments={},
tool_call_id="call_transition_1",
)
# Step 1 (Agent node): call end_call to transition to end
step_1_chunks = MockLLMService.create_multiple_function_call_chunks(
[
{
"name": "end_call",
"arguments": {},
"tool_call_id": "call_transition_2",
},
]
step_1_chunks = MockLLMService.create_function_call_chunks(
function_name="end_call",
arguments={},
tool_call_id="call_transition_2",
)
# Step 2 (End node): text response
@ -472,25 +455,17 @@ class TestContextUpdateBeforeNextCompletion:
- Tool call messages and results
"""
# Step 0 (Start node): call collect_info to transition to agent
step_0_chunks = MockLLMService.create_multiple_function_call_chunks(
[
{
"name": "collect_info",
"arguments": {},
"tool_call_id": "call_transition_1",
},
]
step_0_chunks = MockLLMService.create_function_call_chunks(
function_name="collect_info",
arguments={},
tool_call_id="call_transition_1",
)
# Step 1 (Agent node): call end_call to transition to end
step_1_chunks = MockLLMService.create_multiple_function_call_chunks(
[
{
"name": "end_call",
"arguments": {},
"tool_call_id": "call_transition_2",
},
]
step_1_chunks = MockLLMService.create_function_call_chunks(
function_name="end_call",
arguments={},
tool_call_id="call_transition_2",
)
# Step 2 (End node): text response

View file

@ -98,8 +98,7 @@ async def run_pipeline_with_tool_calls(
# Create a real pipeline task
task = PipelineTask(
pipeline,
params=PipelineParams(allow_interruptions=False),
pipeline, params=PipelineParams(allow_interruptions=False), enable_rtvi=False
)
engine.set_task(task)

View file

@ -0,0 +1,315 @@
"""Tests for verifying variable extraction is triggered for the correct node during transitions.
This module tests that when the LLM calls a node transition function, variable extraction
is performed for the SOURCE node (where the conversation happened), not the TARGET node.
The key behavior being tested:
1. LLM calls a transition function (e.g., "collect_info") from START node
2. START node has extraction_enabled=True with extraction_variables
3. AGENT node (target) has extraction_enabled=False
4. Variable extraction should be triggered for START node's variables
5. Variable extraction should NOT be triggered for AGENT node
"""
import asyncio
from typing import Any, Dict, List
from unittest.mock import AsyncMock, patch
import pytest
from api.services.workflow.dto import (
EdgeDataDTO,
ExtractionVariableDTO,
NodeDataDTO,
NodeType,
Position,
ReactFlowDTO,
RFEdgeDTO,
RFNodeDTO,
VariableType,
)
from api.services.workflow.pipecat_engine import PipecatEngine
from api.services.workflow.pipecat_engine_variable_extractor import (
VariableExtractionManager,
)
from api.services.workflow.workflow import WorkflowGraph
from api.tests.conftest import MockTransportProcessor
from pipecat.frames.frames import LLMContextFrame
from pipecat.pipeline.pipeline import Pipeline
from pipecat.pipeline.runner import PipelineRunner
from pipecat.pipeline.task import PipelineParams, PipelineTask
from pipecat.processors.aggregators.llm_context import LLMContext
from pipecat.processors.aggregators.llm_response import LLMAssistantAggregatorParams
from pipecat.processors.aggregators.llm_response_universal import (
LLMContextAggregatorPair,
)
from pipecat.tests import MockLLMService, MockTTSService
# System prompts assigned to the start, agent and end nodes of the test
# workflow built by the fixture below (passed as each node's `prompt=`).
# They are distinct strings, one per node.
START_NODE_PROMPT = "Start Node System Prompt"
AGENT_NODE_PROMPT = "Agent Node System Prompt"
END_NODE_PROMPT = "End Node System Prompt"
@pytest.fixture
def three_node_workflow_with_extraction_on_start() -> WorkflowGraph:
    """Build a start -> agent -> end workflow where only the start node extracts.

    Node configuration:
    - "start": extraction_enabled=True, with one string variable ("user_name")
    - "agent": extraction_enabled=False
    - "end":   extraction_enabled=False

    Edges connect start -> agent and agent -> end. Tests use this graph to
    check which node variable extraction is triggered for during transitions.
    """
    # The only extraction variable defined anywhere in this workflow.
    user_name_variable = ExtractionVariableDTO(
        name="user_name",
        type=VariableType.string,
        prompt="The name the user provided",
    )

    start_node = RFNodeDTO(
        id="start",
        type=NodeType.startNode,
        position=Position(x=0, y=0),
        data=NodeDataDTO(
            name="Start Call",
            prompt=START_NODE_PROMPT,
            is_start=True,
            allow_interrupt=False,
            add_global_prompt=False,
            extraction_enabled=True,
            extraction_prompt="Extract the user's name from the conversation.",
            extraction_variables=[user_name_variable],
        ),
    )

    agent_node = RFNodeDTO(
        id="agent",
        type=NodeType.agentNode,
        position=Position(x=0, y=200),
        data=NodeDataDTO(
            name="Collect Info",
            prompt=AGENT_NODE_PROMPT,
            allow_interrupt=False,
            add_global_prompt=False,
            # Extraction is switched off explicitly on the transition target.
            extraction_enabled=False,
        ),
    )

    end_node = RFNodeDTO(
        id="end",
        type=NodeType.endNode,
        position=Position(x=0, y=400),
        data=NodeDataDTO(
            name="End Call",
            prompt=END_NODE_PROMPT,
            is_end=True,
            allow_interrupt=False,
            add_global_prompt=False,
            # Extraction is switched off explicitly here as well.
            extraction_enabled=False,
        ),
    )

    start_to_agent = RFEdgeDTO(
        id="start-agent",
        source="start",
        target="agent",
        data=EdgeDataDTO(
            label="Collect Info",
            condition="When user has been greeted, proceed to collect information",
        ),
    )

    agent_to_end = RFEdgeDTO(
        id="agent-end",
        source="agent",
        target="end",
        data=EdgeDataDTO(
            label="End Call",
            condition="When information collection is complete, end the call",
        ),
    )

    flow = ReactFlowDTO(
        nodes=[start_node, agent_node, end_node],
        edges=[start_to_agent, agent_to_end],
    )
    return WorkflowGraph(flow)
class TestVariableExtractionDuringTransitions:
    """Checks which node variable extraction is invoked for when the LLM triggers a transition."""

    @pytest.mark.asyncio
    async def test_extraction_called_for_source_node_not_target_node(
        self, three_node_workflow_with_extraction_on_start: WorkflowGraph
    ):
        """A START -> AGENT transition must run extraction against the START node.

        Scenario:
        1. The start node has extraction_enabled=True with extraction_variables.
        2. The agent node has extraction_enabled=False.
        3. The mock LLM calls the transition function that moves START -> AGENT,
           then the one that moves AGENT -> END, then emits a text reply.

        Expectations:
        - the engine's extraction hook is invoked at least once with the START
          node, i.e. the SOURCE of the transition, where the conversation that
          needs extracting took place;
        - no recorded invocation for the AGENT node has extraction enabled.
        """
        # One record per invocation of the engine's extraction hook, in call order.
        extraction_calls: List[Dict[str, Any]] = []

        # Scripted LLM output, one entry per generation:
        #   0. start node -> calls collect_info (transition to agent)
        #   1. agent node -> calls end_call (transition to end)
        #   2. end node   -> plain text response
        scripted_generations = [
            MockLLMService.create_function_call_chunks(
                function_name="collect_info",
                arguments={},
                tool_call_id="call_transition_1",
            ),
            MockLLMService.create_function_call_chunks(
                function_name="end_call",
                arguments={},
                tool_call_id="call_transition_2",
            ),
            MockLLMService.create_text_chunks("Goodbye!"),
        ]
        llm = MockLLMService(mock_steps=scripted_generations, chunk_delay=0.001)
        tts_service = MockTTSService(mock_audio_duration_ms=10, frame_delay=0)
        transport_stub = MockTransportProcessor(emit_bot_speaking=False)

        # Shared LLM context plus the assistant-side aggregator that feeds it.
        context = LLMContext()
        aggregator_pair = LLMContextAggregatorPair(
            context,
            assistant_params=LLMAssistantAggregatorParams(expect_stripped_words=True),
        )
        assistant_aggregator = aggregator_pair.assistant()

        engine = PipecatEngine(
            llm=llm,
            context=context,
            workflow=three_node_workflow_with_extraction_on_start,
            call_context_vars={"customer_name": "Test User"},
            workflow_run_id=1,
        )

        # Wrap the engine's extraction hook so every invocation is recorded
        # before being forwarded to the real implementation.
        real_perform_extraction = engine._perform_variable_extraction_if_needed

        async def recording_perform_extraction(node):
            if node:
                record = {
                    "node_id": node.id,
                    "node_name": node.name,
                    "extraction_enabled": node.extraction_enabled,
                    "extraction_variables": node.extraction_variables,
                }
            else:
                record = {
                    "node_id": None,
                    "node_name": None,
                    "extraction_enabled": None,
                    "extraction_variables": None,
                }
            extraction_calls.append(record)
            # Forward to the original so engine behavior is unchanged.
            await real_perform_extraction(node)

        engine._perform_variable_extraction_if_needed = recording_perform_extraction

        pipeline = Pipeline(
            [llm, tts_service, transport_stub, assistant_aggregator]
        )
        task = PipelineTask(
            pipeline,
            params=PipelineParams(allow_interruptions=False),
            enable_rtvi=False,
        )
        engine.set_task(task)

        # Stub out the DB-backed helpers and the extraction call itself so the
        # test needs neither a database nor a real LLM.
        with patch(
            "api.services.workflow.pipecat_engine.get_organization_id_from_workflow_run",
            new_callable=AsyncMock,
            return_value=1,
        ), patch(
            "api.services.workflow.pipecat_engine.apply_disposition_mapping",
            new_callable=AsyncMock,
            return_value="completed",
        ), patch.object(
            VariableExtractionManager,
            "_perform_extraction",
            new_callable=AsyncMock,
            return_value={"user_name": "John Doe"},
        ):
            runner = PipelineRunner()

            async def drive_pipeline():
                await runner.run(task)

            async def start_engine():
                # Short pause before initializing the engine and queuing the
                # first context frame, while the runner coroutine is started.
                await asyncio.sleep(0.01)
                await engine.initialize()
                await engine.llm.queue_frame(LLMContextFrame(engine.context))

            await asyncio.gather(drive_pipeline(), start_engine())

        # All three scripted generations must have been consumed.
        assert llm.get_current_step() == 3

        # Keep only the invocations where extraction could actually run:
        # the node has extraction enabled and has variables configured.
        extraction_enabled_calls = [
            entry
            for entry in extraction_calls
            if entry["extraction_enabled"] and entry["extraction_variables"]
        ]

        # Leaving the START node must have produced at least one such call.
        assert extraction_enabled_calls, (
            f"Expected at least 1 extraction call for start node, got {len(extraction_enabled_calls)}. "
            f"All calls: {extraction_calls}"
        )

        # ...and at least one of them must be for the START node itself.
        calls_for_start = [
            entry for entry in extraction_enabled_calls if entry["node_id"] == "start"
        ]
        assert calls_for_start, (
            f"Expected extraction to be called for START node (which has extraction enabled), "
            f"but got calls for: {[c['node_id'] for c in extraction_enabled_calls]}"
        )

        # No recorded call for the AGENT node may have extraction enabled.
        # NOTE(review): the agent node is configured with extraction_enabled=False
        # in this scenario, so this filter looks like it can never match and the
        # assertion may be unable to fail - confirm whether a stricter check
        # (e.g. on the mocked extraction call) was intended.
        agent_extraction_calls = [
            entry
            for entry in extraction_calls
            if entry["node_id"] == "agent" and entry["extraction_enabled"]
        ]
        assert not agent_extraction_calls, (
            f"Expected NO extraction calls for AGENT node (extraction disabled), "
            f"but got {len(agent_extraction_calls)} calls"
        )

View file

@ -50,7 +50,7 @@ async def test_interruption_with_blocked_end_frame():
transport = MockTransport()
pipeline = Pipeline([transport, busy_wait_processor])
task = PipelineTask(pipeline)
task = PipelineTask(pipeline, enable_rtvi=False)
async def run_pipeline():
loop = asyncio.get_running_loop()

View file

@ -106,8 +106,7 @@ async def run_pipeline_with_user_idle(
# Create pipeline task
task = PipelineTask(
pipeline,
params=PipelineParams(allow_interruptions=False),
pipeline, params=PipelineParams(allow_interruptions=False), enable_rtvi=False
)
engine.set_task(task)