dograh/api/tests/test_custom_tools_context_integration.py

390 lines
16 KiB
Python
Raw Permalink Normal View History

"""Integration tests for CustomToolManager with LLM context updates.
This module tests the full flow of:
1. CustomToolManager fetching and converting tool schemas
2. Setting those tools on the LLM context
3. Verifying the context is properly configured for LLM generation
"""
from unittest.mock import AsyncMock, patch
import pytest
from api.services.workflow.pipecat_engine_custom_tools import (
CustomToolManager,
get_function_schema,
)
from api.tests.conftest import MockToolModel
from pipecat.adapters.schemas.function_schema import FunctionSchema
from pipecat.adapters.schemas.tools_schema import ToolsSchema
from pipecat.processors.aggregators.llm_context import LLMContext
def _update_llm_context(context, system_message, functions):
"""Inline helper replicating the update_llm_context logic for tests."""
tools_schema = ToolsSchema(standard_tools=functions)
previous_interactions = context.messages
if previous_interactions and previous_interactions[0]["role"] == "system":
messages = [system_message] + previous_interactions[1:]
else:
messages = [system_message] + previous_interactions
context.set_messages(messages)
if functions:
context.set_tools(tools_schema)
class TestCustomToolManagerContextIntegration:
"""Integration tests for CustomToolManager with LLMContext."""
@pytest.mark.asyncio
async def test_get_tool_schemas_and_update_context(self, mock_engine, sample_tools):
"""Test fetching tool schemas via CustomToolManager and updating LLM context."""
manager = CustomToolManager(mock_engine)
with patch(
"api.services.workflow.pipecat_engine_custom_tools.get_organization_id_from_workflow_run"
) as mock_get_org:
mock_get_org.return_value = 1
with patch(
"api.services.workflow.pipecat_engine_custom_tools.db_client"
) as mock_db:
mock_db.get_tools_by_uuids = AsyncMock(return_value=sample_tools)
# Get tool schemas via CustomToolManager - now returns FunctionSchema objects
tool_uuids = ["weather-uuid-123", "booking-uuid-456", "lookup-uuid-789"]
schemas = await manager.get_tool_schemas(tool_uuids)
# Verify schemas were returned as FunctionSchema objects
assert len(schemas) == 3
assert all(isinstance(s, FunctionSchema) for s in schemas)
# Create context with conversation history
context = LLMContext()
context.set_messages(
[
{"role": "system", "content": "You are a helpful assistant."},
{
"role": "user",
"content": "I need to check the weather and book an appointment.",
},
{
"role": "assistant",
"content": "I can help with both. Where would you like to check the weather?",
},
{"role": "user", "content": "San Francisco"},
]
)
# Update context with new system message and tools
# Now we can pass schemas directly since they're FunctionSchema objects
new_system = {
"role": "system",
"content": "You are a scheduling assistant with access to weather and booking tools.",
}
_update_llm_context(context, new_system, schemas)
# Verify context was updated correctly
messages = context.messages
assert len(messages) == 4
assert (
messages[0]["content"]
== "You are a scheduling assistant with access to weather and booking tools."
)
assert messages[1]["role"] == "user"
assert messages[3]["content"] == "San Francisco"
# Verify tools were set
tools = context.tools
assert tools is not None
assert len(tools.standard_tools) == 3
# Verify tool names
tool_names = {t.name for t in tools.standard_tools}
assert tool_names == {
"get_weather",
"book_appointment",
"customer_lookup",
}
@pytest.mark.asyncio
async def test_tool_schemas_have_correct_properties(
self, mock_engine, sample_tools
):
"""Test that tool schemas from CustomToolManager have correct parameter properties."""
manager = CustomToolManager(mock_engine)
with patch(
"api.services.workflow.pipecat_engine_custom_tools.get_organization_id_from_workflow_run"
) as mock_get_org:
mock_get_org.return_value = 1
with patch(
"api.services.workflow.pipecat_engine_custom_tools.db_client"
) as mock_db:
mock_db.get_tools_by_uuids = AsyncMock(return_value=sample_tools)
schemas = await manager.get_tool_schemas(
["weather-uuid-123", "booking-uuid-456"]
)
# Find the booking schema - now using FunctionSchema attributes
booking_schema = next(
s for s in schemas if s.name == "book_appointment"
)
# Verify parameter properties
assert "customer_name" in booking_schema.properties
assert "date" in booking_schema.properties
assert "time" in booking_schema.properties
assert "notes" in booking_schema.properties
# Verify types
assert booking_schema.properties["customer_name"]["type"] == "string"
assert booking_schema.properties["date"]["type"] == "string"
# Verify required
assert "customer_name" in booking_schema.required
assert "date" in booking_schema.required
assert "time" in booking_schema.required
assert "notes" not in booking_schema.required
@pytest.mark.asyncio
async def test_context_update_with_builtin_and_custom_tools(
self, mock_engine, sample_tools
):
"""Test updating context with both built-in and custom tools."""
manager = CustomToolManager(mock_engine)
with patch(
"api.services.workflow.pipecat_engine_custom_tools.get_organization_id_from_workflow_run"
) as mock_get_org:
mock_get_org.return_value = 1
with patch(
"api.services.workflow.pipecat_engine_custom_tools.db_client"
) as mock_db:
mock_db.get_tools_by_uuids = AsyncMock(
return_value=[sample_tools[0]]
) # Just weather
# Get custom tool schemas - returns FunctionSchema objects
custom_schemas = await manager.get_tool_schemas(["weather-uuid-123"])
# Create built-in function schemas (like calculator, timezone)
builtin_functions = [
get_function_schema(
"safe_calculator",
"Evaluate a mathematical expression safely",
properties={
"expression": {
"type": "string",
"description": "Mathematical expression to evaluate",
}
},
required=["expression"],
),
get_function_schema(
"get_current_time",
"Get the current time in a timezone",
properties={
"timezone": {
"type": "string",
"description": "Timezone name (e.g., America/New_York)",
}
},
required=["timezone"],
),
]
# Combine built-in and custom functions - both are FunctionSchema objects
all_functions = builtin_functions + custom_schemas
# Update context
context = LLMContext()
context.set_messages([{"role": "system", "content": "Old prompt"}])
new_system = {
"role": "system",
"content": "Assistant with calculator and weather tools",
}
_update_llm_context(context, new_system, all_functions)
# Verify all tools are present
tools = context.tools
assert len(tools.standard_tools) == 3
tool_names = {t.name for t in tools.standard_tools}
assert "safe_calculator" in tool_names
assert "get_current_time" in tool_names
assert "get_weather" in tool_names
@pytest.mark.asyncio
async def test_context_preserves_function_call_history(
self, mock_engine, sample_tools
):
"""Test that update_llm_context preserves function call messages in history."""
manager = CustomToolManager(mock_engine)
with patch(
"api.services.workflow.pipecat_engine_custom_tools.get_organization_id_from_workflow_run"
) as mock_get_org:
mock_get_org.return_value = 1
with patch(
"api.services.workflow.pipecat_engine_custom_tools.db_client"
) as mock_db:
mock_db.get_tools_by_uuids = AsyncMock(return_value=[sample_tools[0]])
# Get schemas - returns FunctionSchema objects
schemas = await manager.get_tool_schemas(["weather-uuid-123"])
# Create context with function call history
context = LLMContext()
context.set_messages(
[
{"role": "system", "content": "Old system prompt"},
{"role": "user", "content": "What's the weather in NYC?"},
{
"role": "assistant",
"content": None,
"tool_calls": [
{
"id": "call_123",
"type": "function",
"function": {
"name": "get_weather",
"arguments": '{"location": "New York, NY"}',
},
}
],
},
{
"role": "tool",
"tool_call_id": "call_123",
"content": '{"temperature": 72, "condition": "sunny"}',
},
{
"role": "assistant",
"content": "The weather in NYC is 72°F and sunny!",
},
]
)
new_system = {"role": "system", "content": "Updated weather assistant"}
_update_llm_context(context, new_system, schemas)
messages = context.messages
# System + user + assistant(tool_call) + tool + assistant = 5
assert len(messages) == 5
# Verify function call messages are preserved
tool_call_msg = messages[2]
assert tool_call_msg["role"] == "assistant"
assert "tool_calls" in tool_call_msg
tool_result_msg = messages[3]
assert tool_result_msg["role"] == "tool"
assert tool_result_msg["tool_call_id"] == "call_123"
@pytest.mark.asyncio
async def test_empty_tool_list_does_not_set_tools(self, mock_engine):
"""Test that empty tool list doesn't set tools on context."""
manager = CustomToolManager(mock_engine)
with patch(
"api.services.workflow.pipecat_engine_custom_tools.get_organization_id_from_workflow_run"
) as mock_get_org:
mock_get_org.return_value = 1
with patch(
"api.services.workflow.pipecat_engine_custom_tools.db_client"
) as mock_db:
mock_db.get_tools_by_uuids = AsyncMock(return_value=[])
schemas = await manager.get_tool_schemas([])
assert schemas == []
context = LLMContext()
context.set_messages([{"role": "system", "content": "Old"}])
new_system = {"role": "system", "content": "No tools available"}
_update_llm_context(context, new_system, [])
# Context should have updated message but no tools set
assert context.messages[0]["content"] == "No tools available"
@pytest.mark.asyncio
async def test_numeric_and_boolean_parameter_types(self, mock_engine):
"""Test that numeric and boolean parameter types are correctly handled."""
tool_with_types = MockToolModel(
tool_uuid="order-uuid",
name="Place Order",
description="Place an order for items",
definition={
"schema_version": 1,
"type": "http_api",
"config": {
"method": "POST",
"url": "https://api.example.com/orders",
"parameters": [
{
"name": "item_id",
"type": "string",
"description": "Item identifier",
"required": True,
},
{
"name": "quantity",
"type": "number",
"description": "Number of items",
"required": True,
},
{
"name": "express_shipping",
"type": "boolean",
"description": "Use express shipping",
"required": False,
},
],
},
},
)
manager = CustomToolManager(mock_engine)
with patch(
"api.services.workflow.pipecat_engine_custom_tools.get_organization_id_from_workflow_run"
) as mock_get_org:
mock_get_org.return_value = 1
with patch(
"api.services.workflow.pipecat_engine_custom_tools.db_client"
) as mock_db:
mock_db.get_tools_by_uuids = AsyncMock(return_value=[tool_with_types])
# Get schemas - returns FunctionSchema objects
schemas = await manager.get_tool_schemas(["order-uuid"])
schema = schemas[0]
# Verify types using FunctionSchema attributes
assert schema.properties["item_id"]["type"] == "string"
assert schema.properties["quantity"]["type"] == "number"
assert schema.properties["express_shipping"]["type"] == "boolean"
# Update context - pass schema directly
context = LLMContext()
context.set_messages([{"role": "system", "content": "Old"}])
_update_llm_context(
context, {"role": "system", "content": "Order assistant"}, schemas
)
# Verify tool was set with correct types
tool = context.tools.standard_tools[0]
assert tool.name == "place_order"
assert tool.properties["quantity"]["type"] == "number"
assert tool.properties["express_shipping"]["type"] == "boolean"