From 56953bbd09a73e53b6b3a806edf6aaae5c9c4df7 Mon Sep 17 00:00:00 2001 From: Abhishek Date: Sat, 3 Jan 2026 12:59:18 +0530 Subject: [PATCH] Fix/multiple generation (#104) * fixes #100 * Fix test * fix: fix bad configuration issue --- api/db/user_client.py | 23 +- api/services/pipecat/run_pipeline.py | 1 - api/services/workflow/pipecat_engine.py | 46 +-- .../workflow/pipecat_engine_callbacks.py | 6 +- api/services/workflow/test/__init__.py | 0 .../test/test_aggregation_integration.py | 128 ------- .../test/test_interruption_correction.py | 159 -------- api/tests/conftest.py | 256 +++++++++++++ .../test => tests}/definitions/rf-1.json | 0 .../test => tests}/test_aggregation_fix.py | 8 +- .../test => tests}/test_cost_calculator.py | 0 .../test_custom_tools_context_integration.py | 118 +----- .../workflow/test => tests}/test_dto.py | 2 +- api/tests/test_pipecat_engine_tool_calls.py | 340 ++++++++++++++++++ ui/src/app/tools/[toolUuid]/page.tsx | 12 +- ui/src/components/flow/nodes/WebhookNode.tsx | 12 +- ui/src/components/http/index.ts | 1 + ui/src/components/http/url-input.tsx | 106 ++++++ 18 files changed, 758 insertions(+), 460 deletions(-) delete mode 100644 api/services/workflow/test/__init__.py delete mode 100644 api/services/workflow/test/test_aggregation_integration.py delete mode 100644 api/services/workflow/test/test_interruption_correction.py create mode 100644 api/tests/conftest.py rename api/{services/workflow/test => tests}/definitions/rf-1.json (100%) rename api/{services/workflow/test => tests}/test_aggregation_fix.py (96%) rename api/{services/workflow/test => tests}/test_cost_calculator.py (100%) rename api/{services/workflow/test => tests}/test_dto.py (76%) create mode 100644 api/tests/test_pipecat_engine_tool_calls.py create mode 100644 ui/src/components/http/url-input.tsx diff --git a/api/db/user_client.py b/api/db/user_client.py index c24a85f..ab57f50 100644 --- a/api/db/user_client.py +++ b/api/db/user_client.py @@ -1,5 +1,7 @@ from datetime import 
datetime, timezone +from loguru import logger +from pydantic import ValidationError from sqlalchemy.future import select from api.db.base_client import BaseDBClient @@ -66,12 +68,21 @@ class UserClient(BaseDBClient): if not configuration_obj: return UserConfiguration() - return UserConfiguration.model_validate( - { - **configuration_obj.configuration, - "last_validated_at": configuration_obj.last_validated_at, - } - ) + try: + return UserConfiguration.model_validate( + { + **configuration_obj.configuration, + "last_validated_at": configuration_obj.last_validated_at, + } + ) + except ValidationError as e: + # If configuration contains an unsupported provider, + # return a default configuration without failing + logger.warning( + f"Failed to validate user configuration for user {user_id}: {e}. " + "Returning default configuration." + ) + return UserConfiguration() async def update_user_configuration( self, user_id: int, configuration: UserConfiguration diff --git a/api/services/pipecat/run_pipeline.py b/api/services/pipecat/run_pipeline.py index 82e5c71..f8765ad 100644 --- a/api/services/pipecat/run_pipeline.py +++ b/api/services/pipecat/run_pipeline.py @@ -494,7 +494,6 @@ async def _run_pipeline( max_duration_end_task_callback=engine.create_max_duration_callback(), generation_started_callback=engine.create_generation_started_callback(), llm_text_frame_callback=engine.handle_llm_text_frame, - # Note: speaking event callbacks are now handled by pre-aggregator processor ) pipeline_metrics_aggregator = PipelineMetricsAggregator() diff --git a/api/services/workflow/pipecat_engine.py b/api/services/workflow/pipecat_engine.py index 8084d9a..6430bce 100644 --- a/api/services/workflow/pipecat_engine.py +++ b/api/services/workflow/pipecat_engine.py @@ -13,9 +13,8 @@ from pipecat.frames.frames import ( CancelFrame, EndFrame, FunctionCallResultProperties, + FunctionCallsFromLLMInfoFrame, LLMContextFrame, - LLMFullResponseEndFrame, - LLMFullResponseStartFrame, TTSSpeakFrame, ) 
from pipecat.pipeline.task import PipelineTask @@ -104,7 +103,7 @@ class PipecatEngine: self._builtin_function_schemas: Optional[list[dict]] = None # Track current LLM reference text for TTS aggregation correction - self._current_llm_reference_text: str = "" + self._current_llm_generation_reference_text: str = "" # Custom tool manager (initialized in initialize()) self._custom_tool_manager: Optional[CustomToolManager] = None @@ -173,6 +172,9 @@ class PipecatEngine: await self._register_builtin_functions() await self.set_node(self.workflow.start_node_id) + + # Trigger initial LLM generation + await self.task.queue_frame(LLMContextFrame(self.context)) logger.debug(f"{self.__class__.__name__} initialized") except Exception as e: logger.error(f"Error initializing {self.__class__.__name__}: {e}") @@ -218,7 +220,6 @@ class PipecatEngine: result = {"status": "done"} properties = FunctionCallResultProperties( - run_llm=False, on_context_updated=on_context_updated, ) @@ -256,8 +257,6 @@ class PipecatEngine: """Register built-in functions (calculator and timezone) with the LLM.""" logger.debug("Registering built-in functions with LLM") - properties = FunctionCallResultProperties(run_llm=True) - # Register calculator function async def calculate_func(function_call_params: FunctionCallParams) -> None: logger.info(f"LLM Function Call EXECUTED: safe_calculator") @@ -266,12 +265,10 @@ class PipecatEngine: expr = function_call_params.arguments.get("expression", "") result = safe_calculator(expr) await function_call_params.result_callback( - {"expression": expr, "result": result}, properties=properties + {"expression": expr, "result": result} ) except Exception as e: - await function_call_params.result_callback( - {"error": str(e)}, properties=properties - ) + await function_call_params.result_callback({"error": str(e)}) # Register timezone functions async def get_current_time_func( @@ -282,13 +279,9 @@ class PipecatEngine: try: timezone = 
function_call_params.arguments.get("timezone", "UTC") result = get_current_time(timezone) - await function_call_params.result_callback( - result, properties=properties - ) + await function_call_params.result_callback(result) except Exception as e: - await function_call_params.result_callback( - {"error": str(e)}, properties=properties - ) + await function_call_params.result_callback({"error": str(e)}) async def convert_time_func(function_call_params: FunctionCallParams) -> None: logger.info(f"LLM Function Call EXECUTED: convert_time") @@ -299,29 +292,15 @@ class PipecatEngine: function_call_params.arguments.get("time"), function_call_params.arguments.get("target_timezone"), ) - await function_call_params.result_callback( - result, properties=properties - ) + await function_call_params.result_callback(result) except Exception as e: - await function_call_params.result_callback( - {"error": str(e)}, properties=properties - ) + await function_call_params.result_callback({"error": str(e)}) # Register all built-in functions self.llm.register_function("safe_calculator", calculate_func) self.llm.register_function("get_current_time", get_current_time_func) self.llm.register_function("convert_time", convert_time_func) - async def _queue_tts_response(self, text: str) -> None: - """Queue TTS frames for static text response.""" - await self.task.queue_frames( - [ - LLMFullResponseStartFrame(), - TTSSpeakFrame(text=text), - LLMFullResponseEndFrame(), - ] - ) - async def _perform_variable_extraction_if_needed( self, previous_node: Optional[Node] ) -> None: @@ -384,7 +363,6 @@ class PipecatEngine: functions, ) = await self._compose_system_message_functions_for_node(node) await self._update_llm_context(system_message, functions) - await self.task.queue_frame(LLMContextFrame(self.context)) async def set_node(self, node_id: str): """ @@ -733,7 +711,7 @@ class PipecatEngine: async def handle_llm_text_frame(self, text: str): """Accumulate LLM text frames to build reference text.""" - 
self._current_llm_reference_text += text + self._current_llm_generation_reference_text += text def handle_client_disconnected(self): """Handle client disconnected event.""" diff --git a/api/services/workflow/pipecat_engine_callbacks.py b/api/services/workflow/pipecat_engine_callbacks.py index a7d6b51..52167fd 100644 --- a/api/services/workflow/pipecat_engine_callbacks.py +++ b/api/services/workflow/pipecat_engine_callbacks.py @@ -114,10 +114,10 @@ def create_max_duration_callback(engine: "PipecatEngine"): def create_generation_started_callback(engine: "PipecatEngine"): """Return a callback that resets flags at the start of each LLM generation.""" - async def handle_generation_started(): # noqa: D401 + async def handle_generation_started(): logger.debug("LLM generation started in callback processor") # Clear reference text from previous generation - engine._current_llm_reference_text = "" + engine._current_llm_generation_reference_text = "" return handle_generation_started @@ -184,7 +184,7 @@ def create_aggregation_correction_callback(engine: "PipecatEngine"): return "".join(out_chars) def correct_aggregation(corrupted: str) -> str: - reference = engine._current_llm_reference_text + reference = engine._current_llm_generation_reference_text if not reference: logger.warning("No reference text available for aggregation correction") diff --git a/api/services/workflow/test/__init__.py b/api/services/workflow/test/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/api/services/workflow/test/test_aggregation_integration.py b/api/services/workflow/test/test_aggregation_integration.py deleted file mode 100644 index 92e5faa..0000000 --- a/api/services/workflow/test/test_aggregation_integration.py +++ /dev/null @@ -1,128 +0,0 @@ -from unittest.mock import Mock - -import pytest -from pipecat.processors.aggregators.llm_response import LLMAssistantAggregatorParams -from pipecat.services.openai.llm import OpenAILLMContext - -from 
api.services.workflow.pipecat_engine import PipecatEngine -from api.services.workflow.pipecat_engine_callbacks import ( - create_generation_started_callback, -) - - -class TestAggregationIntegration: - """Integration tests for the TTS aggregation correction flow.""" - - @pytest.mark.asyncio - async def test_engine_reference_text_tracking(self): - """Test that the engine properly tracks LLM reference text.""" - # Create mock dependencies - mock_task = Mock() - mock_llm = Mock() - mock_context = Mock(spec=OpenAILLMContext) - mock_tts = Mock() - mock_workflow = Mock() - mock_workflow.start_node_id = "start" - mock_workflow.nodes = { - "start": Mock(is_start=True, is_static=True, is_end=False, out_edges=[]) - } - - # Create engine - engine = PipecatEngine( - task=mock_task, - llm=mock_llm, - context=mock_context, - tts=mock_tts, - workflow=mock_workflow, - call_context_vars={}, - workflow_run_id=1, - ) - - # Test initial state - assert engine._current_llm_reference_text == "" - - # Test accumulating LLM text - await engine.handle_llm_text_frame("Hello ") - assert engine._current_llm_reference_text == "Hello " - - await engine.handle_llm_text_frame("world!") - assert engine._current_llm_reference_text == "Hello world!" - - # Test generation started callback clears reference text - callback = create_generation_started_callback(engine) - await callback() - assert engine._current_llm_reference_text == "" - - @pytest.mark.asyncio - async def test_aggregation_correction_callback_creation(self): - """Test creating the aggregation correction callback.""" - # Create mock engine - mock_task = Mock() - mock_llm = Mock() - mock_context = Mock(spec=OpenAILLMContext) - mock_workflow = Mock() - - engine = PipecatEngine( - task=mock_task, - llm=mock_llm, - context=mock_context, - workflow=mock_workflow, - call_context_vars={}, - workflow_run_id=1, - ) - - # Set reference text - engine._current_llm_reference_text = "Hello, world! How are you?" 
- - # Create correction callback - callback = engine.create_aggregation_correction_callback() - - # Test correction - note that trailing punctuation might be stripped if not in corrupted text - corrected = callback("Hello world How are you") - assert corrected == "Hello, world! How are you" - - def test_llm_assistant_aggregator_params_with_callback(self): - """Test that LLMAssistantAggregatorParams accepts correction callback.""" - - def mock_callback(text: str) -> str: - return text.upper() - - params = LLMAssistantAggregatorParams( - expect_stripped_words=True, correct_aggregation_callback=mock_callback - ) - - assert params.expect_stripped_words is True - assert params.correct_aggregation_callback is not None - assert params.correct_aggregation_callback("hello") == "HELLO" - - @pytest.mark.asyncio - async def test_pipeline_callbacks_processor_llm_text_frame(self): - """Test that PipelineEngineCallbacksProcessor handles LLMTextFrame.""" - from pipecat.frames.frames import LLMTextFrame - from pipecat.processors.frame_processor import FrameDirection - - from api.services.pipecat.pipeline_engine_callbacks_processor import ( - PipelineEngineCallbacksProcessor, - ) - - # Track callback invocations - callback_invoked = False - callback_text = None - - async def mock_llm_text_callback(text: str): - nonlocal callback_invoked, callback_text - callback_invoked = True - callback_text = text - - # Create processor with callback - processor = PipelineEngineCallbacksProcessor( - llm_text_frame_callback=mock_llm_text_callback - ) - - # Process LLMTextFrame - frame = LLMTextFrame(text="Hello world") - await processor.process_frame(frame, FrameDirection.DOWNSTREAM) - - # Verify callback was invoked - assert callback_invoked is True - assert callback_text == "Hello world" diff --git a/api/services/workflow/test/test_interruption_correction.py b/api/services/workflow/test/test_interruption_correction.py deleted file mode 100644 index 21da54e..0000000 --- 
a/api/services/workflow/test/test_interruption_correction.py +++ /dev/null @@ -1,159 +0,0 @@ -from unittest.mock import AsyncMock, Mock - -import pytest -from pipecat.frames.frames import StartInterruptionFrame -from pipecat.processors.aggregators.llm_response import LLMAssistantAggregatorParams -from pipecat.services.openai.llm import ( - OpenAIAssistantContextAggregator, - OpenAILLMContext, -) - - -class TestInterruptionCorrection: - """Test that TTS aggregation correction works during interruptions.""" - - @pytest.mark.asyncio - async def test_openai_interruption_with_correction(self): - """Test OpenAI assistant context aggregator applies correction during interruption.""" - # Create mock context - mock_context = Mock(spec=OpenAILLMContext) - mock_context.get_messages.return_value = [] - mock_context.add_message = Mock() - - # Create correction callback - def correction_callback(text: str) -> str: - # Simulate fixing corrupted text - if text == "Hello world how are you": - return "Hello world, how are you" - return text - - # Create aggregator with correction callback - params = LLMAssistantAggregatorParams( - expect_stripped_words=True, correct_aggregation_callback=correction_callback - ) - - aggregator = OpenAIAssistantContextAggregator( - context=mock_context, params=params - ) - - # Set up aggregation state - aggregator._aggregation = "Hello world how are you" - aggregator._current_llm_response_id = "test-id" - aggregator._response_function_messages = {} - aggregator._function_calls_in_progress = {} - aggregator._started = 1 - - # Mock push_context_frame and reset methods - aggregator.push_context_frame = AsyncMock() - aggregator.reset = AsyncMock() - - # Process interruption - interruption_frame = StartInterruptionFrame() - await aggregator._handle_interruptions(interruption_frame) - - # Verify the corrected text was added to context - mock_context.add_message.assert_called_once() - added_message = mock_context.add_message.call_args[0][0] - assert 
added_message["role"] == "assistant" - assert ( - added_message["content"] - == "Hello world, how are you <>" - ) - - @pytest.mark.asyncio - async def test_google_interruption_with_correction(self): - """Test Google assistant context aggregator applies correction during interruption.""" - from pipecat.services.google.llm import ( - Content, - GoogleAssistantContextAggregator, - ) - - # Create mock context - mock_context = Mock(spec=OpenAILLMContext) - mock_context.get_messages.return_value = [] - mock_context.add_message = Mock() - - # Create correction callback - def correction_callback(text: str) -> str: - # Simulate fixing corrupted text - if text == "I am here to help": - return "I am here to help" - return text - - # Create aggregator with correction callback - params = LLMAssistantAggregatorParams( - expect_stripped_words=True, correct_aggregation_callback=correction_callback - ) - - aggregator = GoogleAssistantContextAggregator( - context=mock_context, params=params - ) - - # Set up aggregation state - aggregator._aggregation = "I am here to help" - aggregator._current_llm_response_id = "test-id" - aggregator._response_function_messages = {} - aggregator._function_calls_in_progress = {} - aggregator._started = 1 - - # Mock push_context_frame and reset methods - aggregator.push_context_frame = AsyncMock() - aggregator.reset = AsyncMock() - - # Process interruption - interruption_frame = StartInterruptionFrame() - await aggregator._handle_interruptions(interruption_frame) - - # Verify the corrected text was added to context - mock_context.add_message.assert_called_once() - added_content = mock_context.add_message.call_args[0][0] - - # Google uses Content objects - assert isinstance(added_content, Content) - assert added_content.role == "model" - assert len(added_content.parts) == 1 - assert ( - added_content.parts[0].text == "I am here to help <>" - ) - - @pytest.mark.asyncio - async def test_interruption_correction_error_handling(self): - """Test that 
interruption handling continues even if correction callback fails.""" - # Create mock context - mock_context = Mock(spec=OpenAILLMContext) - mock_context.get_messages.return_value = [] - mock_context.add_message = Mock() - - # Create correction callback that raises error - def failing_callback(text: str) -> str: - raise ValueError("Correction failed") - - # Create aggregator with failing callback - params = LLMAssistantAggregatorParams( - expect_stripped_words=True, correct_aggregation_callback=failing_callback - ) - - aggregator = OpenAIAssistantContextAggregator( - context=mock_context, params=params - ) - - # Set up aggregation state - aggregator._aggregation = "Some text" - aggregator._current_llm_response_id = "test-id" - aggregator._response_function_messages = {} - aggregator._function_calls_in_progress = {} - aggregator._started = 1 - - # Mock push_context_frame and reset methods - aggregator.push_context_frame = AsyncMock() - aggregator.reset = AsyncMock() - - # Process interruption - should not raise - interruption_frame = StartInterruptionFrame() - await aggregator._handle_interruptions(interruption_frame) - - # Verify the original text was still added (fallback behavior) - mock_context.add_message.assert_called_once() - added_message = mock_context.add_message.call_args[0][0] - assert added_message["role"] == "assistant" - assert added_message["content"] == "Some text <>" diff --git a/api/tests/conftest.py b/api/tests/conftest.py new file mode 100644 index 0000000..236f710 --- /dev/null +++ b/api/tests/conftest.py @@ -0,0 +1,256 @@ +from dataclasses import dataclass +from typing import Any, Dict +from unittest.mock import Mock + +import pytest + +from api.services.workflow.dto import ( + EdgeDataDTO, + NodeDataDTO, + NodeType, + Position, + ReactFlowDTO, + RFEdgeDTO, + RFNodeDTO, +) +from api.services.workflow.workflow import WorkflowGraph + +START_CALL_SYSTEM_PROMPT = "start_call_system_prompt" +END_CALL_SYSTEM_PROMPT = "end_call_system_prompt" + + 
+@dataclass +class MockToolModel: + """Mock tool model for testing.""" + + tool_uuid: str + name: str + description: str + definition: Dict[str, Any] + + +@pytest.fixture +def mock_engine(): + """Create a mock PipecatEngine.""" + engine = Mock() + engine._workflow_run_id = 1 + engine._call_context_vars = {"customer_name": "John Doe"} + engine.llm = Mock() + engine.llm.register_function = Mock() + return engine + + +@pytest.fixture +def sample_tools(): + """Create sample mock tools for testing.""" + return [ + MockToolModel( + tool_uuid="weather-uuid-123", + name="Get Weather", + description="Get current weather for a location", + definition={ + "schema_version": 1, + "type": "http_api", + "config": { + "method": "GET", + "url": "https://api.weather.com/current", + "parameters": [ + { + "name": "location", + "type": "string", + "description": "City name (e.g., San Francisco, CA)", + "required": True, + }, + { + "name": "units", + "type": "string", + "description": "Temperature units: celsius or fahrenheit", + "required": False, + }, + ], + }, + }, + ), + MockToolModel( + tool_uuid="booking-uuid-456", + name="Book Appointment", + description="Book an appointment for the customer", + definition={ + "schema_version": 1, + "type": "http_api", + "config": { + "method": "POST", + "url": "https://api.example.com/appointments", + "parameters": [ + { + "name": "customer_name", + "type": "string", + "description": "Customer's full name", + "required": True, + }, + { + "name": "date", + "type": "string", + "description": "Appointment date (YYYY-MM-DD)", + "required": True, + }, + { + "name": "time", + "type": "string", + "description": "Appointment time (HH:MM)", + "required": True, + }, + { + "name": "notes", + "type": "string", + "description": "Additional notes", + "required": False, + }, + ], + }, + }, + ), + MockToolModel( + tool_uuid="lookup-uuid-789", + name="Customer Lookup", + description="Look up customer information by phone number", + definition={ + 
"schema_version": 1, + "type": "http_api", + "config": { + "method": "GET", + "url": "https://api.example.com/customers/lookup", + "parameters": [ + { + "name": "phone", + "type": "string", + "description": "Customer phone number", + "required": True, + }, + ], + }, + }, + ), + ] + + +@pytest.fixture +def simple_workflow() -> WorkflowGraph: + """Create a simple two-node workflow for testing. + + The workflow has: + - Start node with a prompt + - End node with a prompt + - One edge connecting them with label "End Call" + """ + dto = ReactFlowDTO( + nodes=[ + RFNodeDTO( + id="1", + type=NodeType.startNode, + position=Position(x=0, y=0), + data=NodeDataDTO( + name="Start Call", + prompt=START_CALL_SYSTEM_PROMPT, + is_start=True, + allow_interrupt=False, + add_global_prompt=False, + ), + ), + RFNodeDTO( + id="2", + type=NodeType.endNode, + position=Position(x=0, y=200), + data=NodeDataDTO( + name="End Call", + prompt=END_CALL_SYSTEM_PROMPT, + is_end=True, + allow_interrupt=False, + add_global_prompt=False, + ), + ), + ], + edges=[ + RFEdgeDTO( + id="1-2", + source="1", + target="2", + data=EdgeDataDTO( + label="End Call", + condition="When the user says to end the call, end the call", + ), + ), + ], + ) + return WorkflowGraph(dto) + + +@pytest.fixture +def three_node_workflow() -> WorkflowGraph: + """Create a three-node workflow for testing with an intermediate agent node. + + The workflow has: + - Start node + - Agent node (for collecting information) + - End node + """ + dto = ReactFlowDTO( + nodes=[ + RFNodeDTO( + id="1", + type=NodeType.startNode, + position=Position(x=0, y=0), + data=NodeDataDTO( + name="Start Call", + prompt=START_CALL_SYSTEM_PROMPT, + is_start=True, + allow_interrupt=True, + add_global_prompt=False, + ), + ), + RFNodeDTO( + id="2", + type=NodeType.agentNode, + position=Position(x=0, y=200), + data=NodeDataDTO( + name="Collect Info", + prompt="Help the user with their request. 
Ask clarifying questions if needed.", + allow_interrupt=True, + add_global_prompt=False, + ), + ), + RFNodeDTO( + id="3", + type=NodeType.endNode, + position=Position(x=0, y=400), + data=NodeDataDTO( + name="End Call", + prompt=END_CALL_SYSTEM_PROMPT, + is_end=True, + allow_interrupt=False, + add_global_prompt=False, + ), + ), + ], + edges=[ + RFEdgeDTO( + id="1-2", + source="1", + target="2", + data=EdgeDataDTO( + label="Collect Info", + condition="When the user wants help, collect their information", + ), + ), + RFEdgeDTO( + id="2-3", + source="2", + target="3", + data=EdgeDataDTO( + label="End Call", + condition="When the user is done or wants to end the call", + ), + ), + ], + ) + return WorkflowGraph(dto) diff --git a/api/services/workflow/test/definitions/rf-1.json b/api/tests/definitions/rf-1.json similarity index 100% rename from api/services/workflow/test/definitions/rf-1.json rename to api/tests/definitions/rf-1.json diff --git a/api/services/workflow/test/test_aggregation_fix.py b/api/tests/test_aggregation_fix.py similarity index 96% rename from api/services/workflow/test/test_aggregation_fix.py rename to api/tests/test_aggregation_fix.py index 373ec1d..001cc41 100644 --- a/api/services/workflow/test/test_aggregation_fix.py +++ b/api/tests/test_aggregation_fix.py @@ -10,14 +10,14 @@ def test_aggregation_fixer(): creates a fresh callback for every (reference, corrupted) pair. The production callback now needs a PipecatEngine instance with the - `_current_llm_reference_text` set. For test-friendliness we mock a bare + `_current_llm_generation_reference_text` set. For test-friendliness we mock a bare object providing just that attribute for each assertion so the original two-argument test cases remain unchanged. 
""" def fixer(reference: str, corrupted: str) -> str: # noqa: D401 mock_engine = Mock() - mock_engine._current_llm_reference_text = reference + mock_engine._current_llm_generation_reference_text = reference return create_aggregation_correction_callback(mock_engine)(corrupted) ##### Trailing extra Chars ##### @@ -172,7 +172,7 @@ def test_create_aggregation_correction_callback(): """Test the new aggregation correction callback creator.""" # Mock engine with reference text mock_engine = Mock() - mock_engine._current_llm_reference_text = "Good Morning Mr NARGES, My name is Alex and I am calling you from Consumer Services." + mock_engine._current_llm_generation_reference_text = "Good Morning Mr NARGES, My name is Alex and I am calling you from Consumer Services." # Create callback callback = create_aggregation_correction_callback(mock_engine) @@ -187,6 +187,6 @@ def test_create_aggregation_correction_callback(): ) # Test with no reference text - mock_engine._current_llm_reference_text = "" + mock_engine._current_llm_generation_reference_text = "" corrected = callback("Some corrupted text") assert corrected == "Some corrupted text" # Should return as-is when no reference diff --git a/api/services/workflow/test/test_cost_calculator.py b/api/tests/test_cost_calculator.py similarity index 100% rename from api/services/workflow/test/test_cost_calculator.py rename to api/tests/test_cost_calculator.py diff --git a/api/tests/test_custom_tools_context_integration.py b/api/tests/test_custom_tools_context_integration.py index afdee4b..5cfa703 100644 --- a/api/tests/test_custom_tools_context_integration.py +++ b/api/tests/test_custom_tools_context_integration.py @@ -6,9 +6,7 @@ This module tests the full flow of: 3. 
Verifying the context is properly configured for LLM generation """ -from dataclasses import dataclass -from typing import Any, Dict -from unittest.mock import AsyncMock, Mock, patch +from unittest.mock import AsyncMock, patch import pytest @@ -17,126 +15,14 @@ from api.services.workflow.pipecat_engine_utils import ( get_function_schema, update_llm_context, ) +from api.tests.conftest import MockToolModel from pipecat.adapters.schemas.function_schema import FunctionSchema from pipecat.processors.aggregators.llm_context import LLMContext -@dataclass -class MockToolModel: - """Mock tool model for testing.""" - - tool_uuid: str - name: str - description: str - definition: Dict[str, Any] - - class TestCustomToolManagerContextIntegration: """Integration tests for CustomToolManager with LLMContext.""" - @pytest.fixture - def mock_engine(self): - """Create a mock PipecatEngine.""" - engine = Mock() - engine._workflow_run_id = 1 - engine._call_context_vars = {"customer_name": "John Doe"} - engine.llm = Mock() - engine.llm.register_function = Mock() - return engine - - @pytest.fixture - def sample_tools(self): - """Create sample mock tools for testing.""" - return [ - MockToolModel( - tool_uuid="weather-uuid-123", - name="Get Weather", - description="Get current weather for a location", - definition={ - "schema_version": 1, - "type": "http_api", - "config": { - "method": "GET", - "url": "https://api.weather.com/current", - "parameters": [ - { - "name": "location", - "type": "string", - "description": "City name (e.g., San Francisco, CA)", - "required": True, - }, - { - "name": "units", - "type": "string", - "description": "Temperature units: celsius or fahrenheit", - "required": False, - }, - ], - }, - }, - ), - MockToolModel( - tool_uuid="booking-uuid-456", - name="Book Appointment", - description="Book an appointment for the customer", - definition={ - "schema_version": 1, - "type": "http_api", - "config": { - "method": "POST", - "url": 
"https://api.example.com/appointments", - "parameters": [ - { - "name": "customer_name", - "type": "string", - "description": "Customer's full name", - "required": True, - }, - { - "name": "date", - "type": "string", - "description": "Appointment date (YYYY-MM-DD)", - "required": True, - }, - { - "name": "time", - "type": "string", - "description": "Appointment time (HH:MM)", - "required": True, - }, - { - "name": "notes", - "type": "string", - "description": "Additional notes", - "required": False, - }, - ], - }, - }, - ), - MockToolModel( - tool_uuid="lookup-uuid-789", - name="Customer Lookup", - description="Look up customer information by phone number", - definition={ - "schema_version": 1, - "type": "http_api", - "config": { - "method": "GET", - "url": "https://api.example.com/customers/lookup", - "parameters": [ - { - "name": "phone", - "type": "string", - "description": "Customer phone number", - "required": True, - }, - ], - }, - }, - ), - ] - @pytest.mark.asyncio async def test_get_tool_schemas_and_update_context(self, mock_engine, sample_tools): """Test fetching tool schemas via CustomToolManager and updating LLM context.""" diff --git a/api/services/workflow/test/test_dto.py b/api/tests/test_dto.py similarity index 76% rename from api/services/workflow/test/test_dto.py rename to api/tests/test_dto.py index 85655cf..5ded56e 100644 --- a/api/services/workflow/test/test_dto.py +++ b/api/tests/test_dto.py @@ -6,6 +6,6 @@ from api.services.workflow.dto import ReactFlowDTO @pytest.mark.asyncio async def test_dto(): # assert no exceptions are raised - with open("services/workflow/test/definitions/rf-1.json", "r") as f: + with open("tests/definitions/rf-1.json", "r") as f: dto = ReactFlowDTO.model_validate_json(f.read()) assert dto is not None diff --git a/api/tests/test_pipecat_engine_tool_calls.py b/api/tests/test_pipecat_engine_tool_calls.py new file mode 100644 index 0000000..d61e1ec --- /dev/null +++ b/api/tests/test_pipecat_engine_tool_calls.py @@ -0,0 
+1,340 @@ +"""Tests for tool calls with PipecatEngine and MockLLM. + +This module tests the behavior when the LLM generates tool calls (single or parallel), +using PipecatEngine's actual function registration and execution logic. +""" + +import asyncio +from typing import Any, Dict, List +from unittest.mock import AsyncMock, patch + +import pytest + +from api.services.pipecat.pipeline_engine_callbacks_processor import ( + PipelineEngineCallbacksProcessor, +) +from api.services.workflow.pipecat_engine import PipecatEngine +from api.services.workflow.workflow import WorkflowGraph +from api.tests.conftest import END_CALL_SYSTEM_PROMPT +from pipecat.frames.frames import ( + BotStartedSpeakingFrame, + BotStoppedSpeakingFrame, + Frame, + TextFrame, +) +from pipecat.pipeline.pipeline import Pipeline +from pipecat.pipeline.runner import PipelineRunner +from pipecat.pipeline.task import PipelineParams, PipelineTask +from pipecat.processors.aggregators.llm_context import LLMContext +from pipecat.processors.aggregators.llm_response import LLMAssistantAggregatorParams +from pipecat.processors.aggregators.llm_response_universal import ( + LLMContextAggregatorPair, +) +from pipecat.processors.frame_processor import FrameDirection, FrameProcessor +from pipecat.tests import MockLLMService + + +class MockBotStoppedSpeakingOnLLMTextFrameProcessor(FrameProcessor): + """ + Mocking the transport, where transport sends BotStartedSpeakingFrame + and BotStoppedSpeakingFrame when it encounters a LLMTextFrame. 
+ """ + + async def process_frame(self, frame: Frame, direction: FrameDirection): + await super().process_frame(frame, direction) + + if isinstance(frame, TextFrame): + await self.push_frame(BotStartedSpeakingFrame()) + await self.push_frame( + BotStartedSpeakingFrame(), direction=FrameDirection.UPSTREAM + ) + + await asyncio.sleep(0.1) + + await self.push_frame(BotStoppedSpeakingFrame()) + await self.push_frame( + BotStoppedSpeakingFrame(), direction=FrameDirection.UPSTREAM + ) + + await self.push_frame(frame, direction) + + +async def run_pipeline_with_tool_calls( + workflow: WorkflowGraph, + functions: List[Dict[str, Any]], + text: str | None = None, + num_text_steps: int = 1, +) -> tuple[MockLLMService, LLMContext]: + """Run a pipeline with mock tool calls and return the LLM and context for assertions. + + Args: + workflow: The workflow graph to use. + functions: List of function call definitions with name, arguments, and tool_call_id. + text: Text to add to the first step (streamed before the tool calls). + num_text_steps: Number of text response steps after the tool calls. + + Returns: + A tuple of the MockLLMService instance and the LLMContext, for making assertions. 
+ """ + # Create first step chunks + if text: + # Create text chunks (without final chunk) followed by function call chunks + text_chunks = MockLLMService.create_text_chunks(text) + func_chunks = MockLLMService.create_multiple_function_call_chunks(functions) + # Exclude the final chunk from text_chunks (which has finish_reason="stop") + first_step_chunks = text_chunks[:-1] + func_chunks + else: + first_step_chunks = MockLLMService.create_multiple_function_call_chunks( + functions + ) + + # Create multi-step responses + mock_steps = MockLLMService.create_multi_step_responses( + first_step_chunks, num_text_steps=num_text_steps, step_prefix="Response" + ) + + # Create MockLLMService with multi-step support + llm = MockLLMService(mock_steps=mock_steps, chunk_delay=0.001) + + mock_transport_emulator = MockBotStoppedSpeakingOnLLMTextFrameProcessor() + + # Create LLM context + context = LLMContext() + + # Add assistant context aggregator + assistant_params = LLMAssistantAggregatorParams(expect_stripped_words=True) + context_aggregator = LLMContextAggregatorPair( + context, assistant_params=assistant_params + ) + assistant_context_aggregator = context_aggregator.assistant() + + # Create PipecatEngine with the workflow + engine = PipecatEngine( + llm=llm, + context=context, + workflow=workflow, + call_context_vars={"customer_name": "Test User"}, + workflow_run_id=1, + ) + + # Create the pipeline with the mock LLM + pipeline = Pipeline( + [ + llm, + mock_transport_emulator, + assistant_context_aggregator, + ] + ) + + # Create a real pipeline task + task = PipelineTask( + pipeline, + params=PipelineParams(allow_interruptions=False), + ) + + engine.set_task(task) + + # Patch DB calls to avoid actual database access + with patch( + "api.services.workflow.pipecat_engine.get_organization_id_from_workflow_run", + new_callable=AsyncMock, + return_value=1, + ): + with patch( + "api.services.workflow.pipecat_engine.apply_disposition_mapping", + new_callable=AsyncMock, + 
return_value="completed", + ): + runner = PipelineRunner() + + async def run_pipeline(): + await runner.run(task) + + async def initialize_engine(): + # Small delay to let runner start + await asyncio.sleep(0.01) + await engine.initialize() + + # Run both concurrently + await asyncio.gather(run_pipeline(), initialize_engine()) + + return llm, context + + +class TestPipecatEngineToolCalls: + """Test tool calls through PipecatEngine.""" + + @pytest.mark.asyncio + async def test_parallel_builtin_and_transition_calls_through_engine( + self, simple_workflow: WorkflowGraph + ): + """Test parallel function calls using PipecatEngine's actual handlers. + + This test verifies that when the LLM generates parallel tool calls for: + 1. A built-in function (safe_calculator) - registered by _register_builtin_functions + 2. A transition function (end_call) - registered by _register_transition_function_with_llm + + Both functions are properly executed through the engine's handlers and + the transition correctly moves to the end node. 
+ + The test uses multi-step mock responses: + - Step 1: Parallel tool calls (safe_calculator + end_call) + - Step 2+: Text responses for subsequent node prompts + """ + functions = [ + { + "name": "end_call", + "arguments": {}, + "tool_call_id": "call_transition", + }, + { + "name": "safe_calculator", + "arguments": {"expression": "25 * 4"}, + "tool_call_id": "call_calc", + }, + ] + + llm, context = await run_pipeline_with_tool_calls( + workflow=simple_workflow, + functions=functions, + num_text_steps=2, + ) + + # Assert that the LLM generation was called a total of 2 times, + # 1st time when StartNode was executed, and second time + # when EndCall generation happened + assert llm.get_current_step() == 2, ( + "LLM generation should have happened 2 times" + ) + + # Assert that the context was updated with END_CALL_SYSTEM_PROMPT + assert context.messages[0]["content"] == END_CALL_SYSTEM_PROMPT + + @pytest.mark.asyncio + async def test_parallel_builtin_and_transition_calls_through_engine_1( + self, simple_workflow: WorkflowGraph + ): + """Test parallel function calls using PipecatEngine's actual handlers. + + This test verifies that when the LLM generates parallel tool calls for: + 1. A built-in function (safe_calculator) - registered by _register_builtin_functions + 2. A transition function (end_call) - registered by _register_transition_function_with_llm + + Both functions are properly executed through the engine's handlers and + the transition correctly moves to the end node. 
+ + The test uses multi-step mock responses: + - Step 1: Parallel tool calls (safe_calculator + end_call) + - Step 2+: Text responses for subsequent node prompts + """ + functions = [ + { + "name": "safe_calculator", + "arguments": {"expression": "25 * 4"}, + "tool_call_id": "call_calc", + }, + { + "name": "end_call", + "arguments": {}, + "tool_call_id": "call_transition", + }, + ] + + llm, context = await run_pipeline_with_tool_calls( + workflow=simple_workflow, + functions=functions, + num_text_steps=2, + ) + + # Assert that the LLM generation was called a total of 2 times, + # 1st time when StartNode was executed, and second time + # when EndCall generation happened. The tool should not invoke + # an LLM generation + assert llm.get_current_step() == 2, ( + "LLM generation should have happened 2 times" + ) + + # Assert that the context was updated with END_CALL_SYSTEM_PROMPT + assert context.messages[0]["content"] == END_CALL_SYSTEM_PROMPT + + @pytest.mark.asyncio + async def test_parallel_builtin_and_transition_calls_through_engine_with_text( + self, simple_workflow: WorkflowGraph + ): + """Test parallel function calls using PipecatEngine's actual handlers. + + This test verifies that when the LLM generates parallel tool calls for: + 1. A built-in function (safe_calculator) - registered by _register_builtin_functions + 2. A transition function (end_call) - registered by _register_transition_function_with_llm + + Both functions are properly executed through the engine's handlers and + the transition correctly moves to the end node. 
+ + The test uses multi-step mock responses: + - Step 1: Parallel tool calls (safe_calculator + end_call) + - Step 2+: Text responses for subsequent node prompts + """ + functions = [ + { + "name": "end_call", + "arguments": {}, + "tool_call_id": "call_transition", + }, + { + "name": "safe_calculator", + "arguments": {"expression": "25 * 4"}, + "tool_call_id": "call_calc", + }, + ] + + llm, context = await run_pipeline_with_tool_calls( + workflow=simple_workflow, + functions=functions, + text="Hello There!", + num_text_steps=2, + ) + + # Assert that the LLM generation was called a total of 2 times, + # 1st time when StartNode was executed, and second time + # when EndCall generation happened. The tool should not invoke + # an LLM generation + assert llm.get_current_step() == 2, ( + "LLM generation should have happened 2 times" + ) + + # Assert that the context was updated with END_CALL_SYSTEM_PROMPT + assert context.messages[0]["content"] == END_CALL_SYSTEM_PROMPT + + @pytest.mark.asyncio + async def test_single_transition_call_through_engine( + self, simple_workflow: WorkflowGraph + ): + """Test a single transition function call (end_call) through PipecatEngine. + + This test verifies that when the LLM generates only a transition tool call, + the engine properly executes it and transitions to the end node. + Since end_call transitions to the end node which triggers another LLM + generation, the LLM is called twice: once for the StartNode and once for the end node. 
+ """ + functions = [ + { + "name": "end_call", + "arguments": {}, + "tool_call_id": "call_transition", + }, + ] + + llm, context = await run_pipeline_with_tool_calls( + workflow=simple_workflow, + functions=functions, + num_text_steps=1, + ) + + # LLM is called once for the StartNode, then end_call transitions to EndNode + # which triggers a second generation + assert llm.get_current_step() == 2, ( + "LLM generation should have happened 2 times" + ) + + # Assert that the context was updated with END_CALL_SYSTEM_PROMPT + assert context.messages[0]["content"] == END_CALL_SYSTEM_PROMPT diff --git a/ui/src/app/tools/[toolUuid]/page.tsx b/ui/src/app/tools/[toolUuid]/page.tsx index 9e9698f..e4bf907 100644 --- a/ui/src/app/tools/[toolUuid]/page.tsx +++ b/ui/src/app/tools/[toolUuid]/page.tsx @@ -27,6 +27,8 @@ import { type KeyValueItem, ParameterEditor, type ToolParameter, + UrlInput, + validateUrl, } from "@/components/http"; import { Button } from "@/components/ui/button"; import { @@ -151,8 +153,9 @@ export default function ToolDetailPage() { const handleSave = async () => { // Validate URL - if (!url.trim()) { - setError("URL is required"); + const urlValidation = validateUrl(url); + if (!urlValidation.valid) { + setError(urlValidation.error || "Invalid URL"); return; } @@ -431,10 +434,11 @@ const data = await response.json();`;
- setUrl(e.target.value)} + onChange={setUrl} placeholder="https://api.example.com/appointments" + showValidation />
diff --git a/ui/src/components/flow/nodes/WebhookNode.tsx b/ui/src/components/flow/nodes/WebhookNode.tsx index 58c9470..6f1083f 100644 --- a/ui/src/components/flow/nodes/WebhookNode.tsx +++ b/ui/src/components/flow/nodes/WebhookNode.tsx @@ -10,6 +10,8 @@ import { HttpMethodSelector, KeyValueEditor, type KeyValueItem, + UrlInput, + validateUrl, } from "@/components/http"; import { Button } from "@/components/ui/button"; import { Input } from "@/components/ui/input"; @@ -57,8 +59,9 @@ export const WebhookNode = memo(({ data, selected, id }: WebhookNodeProps) => { const handleSave = async () => { // Validate endpoint URL - if (!endpointUrl.trim()) { - setEndpointError('Endpoint URL is required'); + const urlValidation = validateUrl(endpointUrl); + if (!urlValidation.valid) { + setEndpointError(urlValidation.error || 'Invalid URL'); return; } setEndpointError(null); @@ -284,10 +287,11 @@ const WebhookNodeEditForm = ({ - setEndpointUrl(e.target.value)} + onChange={setEndpointUrl} placeholder="https://api.example.com/webhook" + showValidation /> diff --git a/ui/src/components/http/index.ts b/ui/src/components/http/index.ts index 9922393..47b8956 100644 --- a/ui/src/components/http/index.ts +++ b/ui/src/components/http/index.ts @@ -3,3 +3,4 @@ export { CredentialSelector } from "./credential-selector"; export { type HttpMethod, HttpMethodSelector } from "./http-method-selector"; export { KeyValueEditor, type KeyValueItem } from "./key-value-editor"; export { ParameterEditor, type ParameterType,type ToolParameter } from "./parameter-editor"; +export { UrlInput, type UrlValidationResult,validateUrl } from "./url-input"; diff --git a/ui/src/components/http/url-input.tsx b/ui/src/components/http/url-input.tsx new file mode 100644 index 0000000..64a7611 --- /dev/null +++ b/ui/src/components/http/url-input.tsx @@ -0,0 +1,106 @@ +"use client"; + +import { useCallback, useState } from "react"; + +import { Input } from "@/components/ui/input"; +import { cn } from "@/lib/utils"; + 
+// URL regex pattern that validates: +// - http:// or https:// protocol (required) +// - Optional username:password@ +// - Domain name or IP address +// - Optional port number +// - Optional path, query string, and fragment +const URL_REGEX = + /^https?:\/\/(?:[\w-]+(?::[\w-]+)?@)?(?:[\w-]+\.)*[\w-]+(?::\d{1,5})?(?:\/[^\s]*)?$/i; + +export interface UrlValidationResult { + valid: boolean; + error?: string; +} + +export function validateUrl(url: string): UrlValidationResult { + const trimmedUrl = url.trim(); + + if (!trimmedUrl) { + return { valid: false, error: "URL is required" }; + } + + if (!URL_REGEX.test(trimmedUrl)) { + return { + valid: false, + error: "Invalid URL format. Must start with http:// or https://", + }; + } + + return { valid: true }; +} + +interface UrlInputProps { + value: string; + onChange: (value: string) => void; + placeholder?: string; + disabled?: boolean; + className?: string; + /** Show validation error styling and message inline */ + showValidation?: boolean; + /** Called when validation state changes */ + onValidationChange?: (result: UrlValidationResult) => void; +} + +export function UrlInput({ + value, + onChange, + placeholder = "https://api.example.com/endpoint", + disabled = false, + className, + showValidation = false, + onValidationChange, +}: UrlInputProps) { + const [touched, setTouched] = useState(false); + + const handleChange = useCallback( + (e: React.ChangeEvent) => { + const newValue = e.target.value; + onChange(newValue); + + if (onValidationChange && (touched || newValue)) { + onValidationChange(validateUrl(newValue)); + } + }, + [onChange, onValidationChange, touched] + ); + + const handleBlur = useCallback(() => { + setTouched(true); + const trimmedValue = value.trim(); + if (trimmedValue !== value) { + onChange(trimmedValue); + } + if (onValidationChange && trimmedValue) { + onValidationChange(validateUrl(trimmedValue)); + } + }, [onChange, onValidationChange, value]); + + const validation = validateUrl(value); + 
const showError = showValidation && touched && !validation.valid && value; + + return ( +
+ + {showError && ( +

{validation.error}

+ )} +
+ ); +}