mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-04-30 02:46:23 +02:00
Merge branch 'release/v1.2'
This commit is contained in:
commit
0bff629f87
28 changed files with 3881 additions and 111 deletions
|
|
@ -305,6 +305,639 @@ Answer: The capital of France is Paris."""
|
|||
assert reasoning_plan[1]["action"] == "find_population"
|
||||
assert all("step" in step for step in reasoning_plan)
|
||||
|
||||
def test_multi_iteration_react_execution(self):
|
||||
"""Test complete multi-iteration ReACT cycle with sequential tool invocations
|
||||
|
||||
This test simulates a complex query that requires:
|
||||
1. Tool #1: Search for initial information
|
||||
2. Tool #2: Analyze/refine based on Tool #1's output
|
||||
3. Tool #3: Generate final answer using accumulated context
|
||||
|
||||
Each iteration includes Think -> Act -> Observe phases with
|
||||
observations feeding into subsequent thinking phases.
|
||||
"""
|
||||
# Arrange
|
||||
question = "Find the GDP of the capital of Japan and compare it to Tokyo's population"
|
||||
|
||||
# Mock tools that build on each other's outputs
|
||||
tool_invocation_log = []
|
||||
|
||||
def mock_geo_search(query):
|
||||
"""Tool 1: Geographic information search"""
|
||||
tool_invocation_log.append(("geo_search", query))
|
||||
if "capital" in query.lower() and "japan" in query.lower():
|
||||
return {"city": "Tokyo", "country": "Japan", "is_capital": True}
|
||||
return {"error": "Location not found"}
|
||||
|
||||
def mock_economic_data(query, context=None):
|
||||
"""Tool 2: Economic data retrieval (uses context from Tool 1)"""
|
||||
tool_invocation_log.append(("economic_data", query, context))
|
||||
if context and context.get("city") == "Tokyo":
|
||||
return {"city": "Tokyo", "gdp_trillion_yen": 115.7, "year": 2023}
|
||||
return {"error": "Economic data not available"}
|
||||
|
||||
def mock_demographic_data(query, context=None):
|
||||
"""Tool 3: Demographic data and comparison (uses context from Tools 1 & 2)"""
|
||||
tool_invocation_log.append(("demographic_data", query, context))
|
||||
if context and context.get("city") == "Tokyo":
|
||||
population_millions = 14.0
|
||||
gdp_from_context = context.get("gdp_trillion_yen", 0)
|
||||
return {
|
||||
"city": "Tokyo",
|
||||
"population_millions": population_millions,
|
||||
"gdp_trillion_yen": gdp_from_context,
|
||||
"gdp_per_capita_million_yen": round(gdp_from_context / population_millions, 2) if population_millions > 0 else 0
|
||||
}
|
||||
return {"error": "Demographic data not available"}
|
||||
|
||||
# Execute multi-iteration ReACT cycle
|
||||
def execute_multi_iteration_react(question, tools):
|
||||
"""Execute a complete multi-iteration ReACT cycle"""
|
||||
iterations = []
|
||||
context = {}
|
||||
|
||||
# Iteration 1: Initial geographic search
|
||||
iteration_1 = {
|
||||
"iteration": 1,
|
||||
"think": "I need to first identify the capital of Japan to get its GDP",
|
||||
"act": {"tool": "geo_search", "query": "capital of Japan"},
|
||||
"observe": None
|
||||
}
|
||||
result_1 = tools["geo_search"](iteration_1["act"]["query"])
|
||||
iteration_1["observe"] = f"Found that {result_1['city']} is the capital of {result_1['country']}"
|
||||
context.update(result_1)
|
||||
iterations.append(iteration_1)
|
||||
|
||||
# Iteration 2: Get economic data using context from iteration 1
|
||||
iteration_2 = {
|
||||
"iteration": 2,
|
||||
"think": f"Now I know {context['city']} is the capital. I need to get its GDP data",
|
||||
"act": {"tool": "economic_data", "query": f"GDP of {context['city']}"},
|
||||
"observe": None
|
||||
}
|
||||
result_2 = tools["economic_data"](iteration_2["act"]["query"], context)
|
||||
iteration_2["observe"] = f"Retrieved GDP data: {result_2['gdp_trillion_yen']} trillion yen for {result_2['year']}"
|
||||
context.update(result_2)
|
||||
iterations.append(iteration_2)
|
||||
|
||||
# Iteration 3: Get demographic data and compare using accumulated context
|
||||
iteration_3 = {
|
||||
"iteration": 3,
|
||||
"think": f"I have the GDP ({context['gdp_trillion_yen']} trillion yen). Now I need population data to compare",
|
||||
"act": {"tool": "demographic_data", "query": f"population of {context['city']}"},
|
||||
"observe": None
|
||||
}
|
||||
result_3 = tools["demographic_data"](iteration_3["act"]["query"], context)
|
||||
iteration_3["observe"] = f"Population is {result_3['population_millions']} million. GDP per capita is {result_3['gdp_per_capita_million_yen']} million yen"
|
||||
context.update(result_3)
|
||||
iterations.append(iteration_3)
|
||||
|
||||
# Final answer synthesis
|
||||
final_answer = {
|
||||
"think": "I now have all the information needed to answer the question",
|
||||
"answer": f"Tokyo, the capital of Japan, has a GDP of {context['gdp_trillion_yen']} trillion yen and a population of {context['population_millions']} million people, resulting in a GDP per capita of {context['gdp_per_capita_million_yen']} million yen."
|
||||
}
|
||||
|
||||
return {
|
||||
"iterations": iterations,
|
||||
"final_answer": final_answer,
|
||||
"context": context,
|
||||
"tool_invocations": len(tool_invocation_log)
|
||||
}
|
||||
|
||||
tools = {
|
||||
"geo_search": mock_geo_search,
|
||||
"economic_data": mock_economic_data,
|
||||
"demographic_data": mock_demographic_data
|
||||
}
|
||||
|
||||
# Act
|
||||
result = execute_multi_iteration_react(question, tools)
|
||||
|
||||
# Assert - Verify complete multi-iteration execution
|
||||
assert len(result["iterations"]) == 3, "Should have exactly 3 iterations"
|
||||
|
||||
# Verify each iteration has complete Think-Act-Observe cycle
|
||||
for i, iteration in enumerate(result["iterations"], 1):
|
||||
assert iteration["iteration"] == i
|
||||
assert "think" in iteration and len(iteration["think"]) > 0
|
||||
assert "act" in iteration and "tool" in iteration["act"]
|
||||
assert "observe" in iteration and iteration["observe"] is not None
|
||||
|
||||
# Verify sequential tool invocations
|
||||
assert tool_invocation_log[0][0] == "geo_search"
|
||||
assert tool_invocation_log[1][0] == "economic_data"
|
||||
assert tool_invocation_log[2][0] == "demographic_data"
|
||||
|
||||
# Verify context accumulation across iterations
|
||||
assert "Tokyo" in tool_invocation_log[1][1], "Iteration 2 should use data from iteration 1"
|
||||
assert tool_invocation_log[2][2].get("gdp_trillion_yen") == 115.7, "Iteration 3 should have accumulated GDP data"
|
||||
|
||||
# Verify observations feed into subsequent thinking
|
||||
assert "Tokyo" in result["iterations"][1]["think"], "Iteration 2 thinking should reference observation from iteration 1"
|
||||
assert "115.7" in result["iterations"][2]["think"], "Iteration 3 thinking should reference GDP from iteration 2"
|
||||
|
||||
# Verify final answer synthesis
|
||||
assert "Tokyo" in result["final_answer"]["answer"]
|
||||
assert "115.7" in result["final_answer"]["answer"]
|
||||
assert "14.0" in result["final_answer"]["answer"]
|
||||
assert "8.26" in result["final_answer"]["answer"], "Should include calculated GDP per capita"
|
||||
|
||||
# Verify all 3 tools were invoked in sequence
|
||||
assert result["tool_invocations"] == 3
|
||||
|
||||
def test_multi_iteration_with_dynamic_tool_selection(self):
|
||||
"""Test multi-iteration ReACT with mocked LLM reasoning dynamically selecting tools
|
||||
|
||||
This test simulates how an LLM would dynamically choose tools based on:
|
||||
1. The original question
|
||||
2. Previous observations
|
||||
3. Accumulated context
|
||||
|
||||
The mocked LLM reasoning adapts its tool selection based on what it has learned
|
||||
in previous iterations, mimicking real agent behavior.
|
||||
"""
|
||||
# Arrange
|
||||
question = "What are the main exports of the largest city in Brazil by population?"
|
||||
|
||||
# Track reasoning and tool selection
|
||||
reasoning_log = []
|
||||
tool_invocation_log = []
|
||||
|
||||
def mock_llm_reasoning(question, history, available_tools):
|
||||
"""Mock LLM that reasons about tool selection based on context"""
|
||||
# Analyze what we know from history
|
||||
context = {}
|
||||
for step in history:
|
||||
if "observation" in step:
|
||||
# Extract information from observations
|
||||
obs = step["observation"]
|
||||
if "São Paulo" in obs:
|
||||
context["city"] = "São Paulo"
|
||||
if "largest city" in obs:
|
||||
context["is_largest"] = True
|
||||
if "million" in obs and "population" in obs:
|
||||
context["has_population"] = True
|
||||
if "exports" in obs:
|
||||
context["has_exports"] = True
|
||||
|
||||
# Decide next action based on what we know
|
||||
if not context.get("city"):
|
||||
# Step 1: Need to find the largest city
|
||||
reasoning = "I need to find the largest city in Brazil by population"
|
||||
tool = "geo_search"
|
||||
args = {"query": "largest city Brazil population"}
|
||||
elif not context.get("has_population"):
|
||||
# Step 2: Confirm population data
|
||||
reasoning = f"I found {context['city']}. Now I need to verify it's the largest by checking population"
|
||||
tool = "demographic_data"
|
||||
args = {"query": f"population {context['city']} Brazil"}
|
||||
elif not context.get("has_exports"):
|
||||
# Step 3: Get export information
|
||||
reasoning = f"Confirmed {context['city']} is the largest. Now I need export information"
|
||||
tool = "economic_data"
|
||||
args = {"query": f"main exports {context['city']} Brazil"}
|
||||
else:
|
||||
# Final: Have all information
|
||||
reasoning = "I have all the information needed to answer"
|
||||
tool = "final_answer"
|
||||
args = None
|
||||
|
||||
reasoning_log.append({"reasoning": reasoning, "tool": tool, "context": context.copy()})
|
||||
return reasoning, tool, args
|
||||
|
||||
def mock_geo_search(query):
|
||||
"""Mock geographic search tool"""
|
||||
tool_invocation_log.append(("geo_search", query))
|
||||
if "largest city brazil" in query.lower():
|
||||
return {
|
||||
"result": "São Paulo is the largest city in Brazil",
|
||||
"details": {"city": "São Paulo", "country": "Brazil", "rank": 1}
|
||||
}
|
||||
return {"error": "No results found"}
|
||||
|
||||
def mock_demographic_data(query):
|
||||
"""Mock demographic data tool"""
|
||||
tool_invocation_log.append(("demographic_data", query))
|
||||
if "são paulo" in query.lower():
|
||||
return {
|
||||
"result": "São Paulo has a population of 12.4 million in the city proper, 22.8 million in the metro area",
|
||||
"details": {"city_population": 12.4, "metro_population": 22.8, "unit": "million"}
|
||||
}
|
||||
return {"error": "No demographic data found"}
|
||||
|
||||
def mock_economic_data(query):
|
||||
"""Mock economic data tool"""
|
||||
tool_invocation_log.append(("economic_data", query))
|
||||
if "são paulo" in query.lower() and "export" in query.lower():
|
||||
return {
|
||||
"result": "São Paulo's main exports include aircraft, vehicles, machinery, coffee, and soybeans",
|
||||
"details": {
|
||||
"top_exports": ["aircraft", "vehicles", "machinery", "coffee", "soybeans"],
|
||||
"export_value_billions_usd": 65.2
|
||||
}
|
||||
}
|
||||
return {"error": "No economic data found"}
|
||||
|
||||
# Execute multi-iteration ReACT with dynamic tool selection
|
||||
def execute_dynamic_react(question, tools, llm_reasoner):
|
||||
"""Execute ReACT with dynamic LLM-based tool selection"""
|
||||
iterations = []
|
||||
history = []
|
||||
available_tools = list(tools.keys())
|
||||
|
||||
max_iterations = 4
|
||||
for i in range(max_iterations):
|
||||
# LLM reasons about next action
|
||||
reasoning, tool_name, args = llm_reasoner(question, history, available_tools)
|
||||
|
||||
if tool_name == "final_answer":
|
||||
# Agent has decided it has enough information
|
||||
final_answer = {
|
||||
"reasoning": reasoning,
|
||||
"answer": "São Paulo, Brazil's largest city with 12.4 million people, " +
|
||||
"has main exports including aircraft, vehicles, machinery, coffee, and soybeans."
|
||||
}
|
||||
break
|
||||
|
||||
# Execute selected tool
|
||||
iteration = {
|
||||
"iteration": i + 1,
|
||||
"think": reasoning,
|
||||
"act": {"tool": tool_name, "args": args},
|
||||
"observe": None
|
||||
}
|
||||
|
||||
# Get tool result
|
||||
if tool_name in tools:
|
||||
result = tools[tool_name](args["query"])
|
||||
iteration["observe"] = result.get("result", "No information found")
|
||||
else:
|
||||
iteration["observe"] = f"Tool {tool_name} not available"
|
||||
|
||||
iterations.append(iteration)
|
||||
|
||||
# Add to history for next iteration
|
||||
history.append({
|
||||
"thought": reasoning,
|
||||
"action": tool_name,
|
||||
"args": args,
|
||||
"observation": iteration["observe"]
|
||||
})
|
||||
|
||||
return {
|
||||
"iterations": iterations,
|
||||
"final_answer": final_answer if 'final_answer' in locals() else None,
|
||||
"reasoning_log": reasoning_log,
|
||||
"tool_invocations": len(tool_invocation_log)
|
||||
}
|
||||
|
||||
tools = {
|
||||
"geo_search": mock_geo_search,
|
||||
"demographic_data": mock_demographic_data,
|
||||
"economic_data": mock_economic_data
|
||||
}
|
||||
|
||||
# Act
|
||||
result = execute_dynamic_react(question, tools, mock_llm_reasoning)
|
||||
|
||||
# Assert - Verify dynamic multi-iteration execution
|
||||
assert len(result["iterations"]) == 3, "Should have 3 iterations before final answer"
|
||||
|
||||
# Verify reasoning adapts based on observations
|
||||
assert len(reasoning_log) == 4, "Should have 4 reasoning steps (3 tools + final)"
|
||||
|
||||
# Verify first iteration searches for largest city
|
||||
assert reasoning_log[0]["tool"] == "geo_search"
|
||||
assert "largest city" in reasoning_log[0]["reasoning"].lower()
|
||||
assert not reasoning_log[0]["context"].get("city")
|
||||
|
||||
# Verify second iteration uses city name from first observation
|
||||
assert reasoning_log[1]["tool"] == "demographic_data"
|
||||
assert "São Paulo" in reasoning_log[1]["reasoning"]
|
||||
assert reasoning_log[1]["context"]["city"] == "São Paulo"
|
||||
|
||||
# Verify third iteration builds on previous knowledge
|
||||
assert reasoning_log[2]["tool"] == "economic_data"
|
||||
assert "export" in reasoning_log[2]["reasoning"].lower()
|
||||
assert reasoning_log[2]["context"]["has_population"] is True
|
||||
|
||||
# Verify final reasoning has all information
|
||||
assert reasoning_log[3]["tool"] == "final_answer"
|
||||
assert reasoning_log[3]["context"]["has_exports"] is True
|
||||
|
||||
# Verify tool invocation sequence
|
||||
assert tool_invocation_log[0][0] == "geo_search"
|
||||
assert tool_invocation_log[1][0] == "demographic_data"
|
||||
assert tool_invocation_log[2][0] == "economic_data"
|
||||
|
||||
# Verify observations influence subsequent tool selection
|
||||
assert "São Paulo" in result["iterations"][1]["act"]["args"]["query"]
|
||||
assert "São Paulo" in result["iterations"][2]["act"]["args"]["query"]
|
||||
|
||||
# Verify final answer synthesizes all gathered information
|
||||
assert result["final_answer"] is not None
|
||||
assert "São Paulo" in result["final_answer"]["answer"]
|
||||
assert "12.4 million" in result["final_answer"]["answer"]
|
||||
assert "aircraft" in result["final_answer"]["answer"]
|
||||
assert "vehicles" in result["final_answer"]["answer"]
|
||||
|
||||
def test_action_name_with_quotes_handling(self):
|
||||
"""Test that action names with quotes are properly stripped
|
||||
|
||||
This test verifies the fix for when LLMs output action names wrapped
|
||||
in quotes, e.g., Action: "get_bank_balance" instead of Action: get_bank_balance
|
||||
"""
|
||||
# Arrange
|
||||
def parse_react_output(text):
|
||||
"""Parse ReAct format output into structured steps"""
|
||||
steps = []
|
||||
lines = text.strip().split('\n')
|
||||
|
||||
thought = None
|
||||
action = None
|
||||
args = None
|
||||
|
||||
for line in lines:
|
||||
line = line.strip()
|
||||
if line.startswith('Think:') or line.startswith('Thought:'):
|
||||
thought = line.split(':', 1)[1].strip()
|
||||
elif line.startswith('Action:'):
|
||||
action = line[7:].strip()
|
||||
# Strip quotes from action name - this is the fix being tested
|
||||
while action and action[0] == '"':
|
||||
action = action[1:]
|
||||
while action and action[-1] == '"':
|
||||
action = action[:-1]
|
||||
elif line.startswith('Args:'):
|
||||
# Simple args parsing for test
|
||||
args_text = line[5:].strip()
|
||||
if args_text:
|
||||
import json
|
||||
try:
|
||||
args = json.loads(args_text)
|
||||
except:
|
||||
args = {"raw": args_text}
|
||||
|
||||
return {
|
||||
"thought": thought,
|
||||
"action": action,
|
||||
"args": args
|
||||
}
|
||||
|
||||
# Test cases with various quote patterns
|
||||
test_cases = [
|
||||
# Normal case without quotes
|
||||
(
|
||||
'Thought: I need to check the bank balance\nAction: get_bank_balance\nArgs: {"account": "12345"}',
|
||||
"get_bank_balance"
|
||||
),
|
||||
# Single quotes around action name
|
||||
(
|
||||
'Thought: I need to check the bank balance\nAction: "get_bank_balance"\nArgs: {"account": "12345"}',
|
||||
"get_bank_balance"
|
||||
),
|
||||
# Multiple quotes (nested)
|
||||
(
|
||||
'Thought: I need to check the bank balance\nAction: ""get_bank_balance""\nArgs: {"account": "12345"}',
|
||||
"get_bank_balance"
|
||||
),
|
||||
# Action with underscores and quotes
|
||||
(
|
||||
'Thought: I need to search\nAction: "search_knowledge_base"\nArgs: {"query": "test"}',
|
||||
"search_knowledge_base"
|
||||
),
|
||||
# Action with hyphens and quotes
|
||||
(
|
||||
'Thought: I need to search\nAction: "search-knowledge-base"\nArgs: {"query": "test"}',
|
||||
"search-knowledge-base"
|
||||
),
|
||||
# Edge case: just quotes (should result in empty string)
|
||||
(
|
||||
'Thought: Error case\nAction: ""\nArgs: {}',
|
||||
""
|
||||
),
|
||||
# Mixed quotes at start and end
|
||||
(
|
||||
'Thought: Processing\nAction: """complex_tool"""\nArgs: {}',
|
||||
"complex_tool"
|
||||
),
|
||||
]
|
||||
|
||||
# Act & Assert
|
||||
for llm_output, expected_action in test_cases:
|
||||
result = parse_react_output(llm_output)
|
||||
assert result["action"] == expected_action, \
|
||||
f"Failed to parse action correctly from: {llm_output}\nExpected: {expected_action}, Got: {result['action']}"
|
||||
|
||||
# Test with actual tool matching
|
||||
tools = {
|
||||
"get_bank_balance": {"description": "Get bank balance"},
|
||||
"search_knowledge_base": {"description": "Search knowledge"},
|
||||
"complex_tool": {"description": "Complex operations"}
|
||||
}
|
||||
|
||||
# Simulate tool lookup with quoted action names
|
||||
quoted_actions = [
|
||||
'"get_bank_balance"',
|
||||
'""search_knowledge_base""',
|
||||
'complex_tool', # without quotes
|
||||
'"complex_tool"'
|
||||
]
|
||||
|
||||
for quoted_action in quoted_actions:
|
||||
# Strip quotes as the fix does
|
||||
clean_action = quoted_action
|
||||
while clean_action and clean_action[0] == '"':
|
||||
clean_action = clean_action[1:]
|
||||
while clean_action and clean_action[-1] == '"':
|
||||
clean_action = clean_action[:-1]
|
||||
|
||||
# Verify the cleaned action exists in tools (except empty string case)
|
||||
if clean_action:
|
||||
assert clean_action in tools, \
|
||||
f"Cleaned action '{clean_action}' from '{quoted_action}' should be in tools"
|
||||
|
||||
def test_mcp_tool_arguments_support(self):
|
||||
"""Test that MCP tools can be configured with arguments and expose them correctly
|
||||
|
||||
This test verifies the MCP tool arguments feature where:
|
||||
1. MCP tool configurations can specify arguments
|
||||
2. Configuration parsing extracts arguments correctly
|
||||
3. Arguments are structured properly for tool use
|
||||
"""
|
||||
# Define a simple Argument class for testing (mimics the real one)
|
||||
class TestArgument:
|
||||
def __init__(self, name, type, description):
|
||||
self.name = name
|
||||
self.type = type
|
||||
self.description = description
|
||||
|
||||
# Define a mock McpToolImpl that mimics the new functionality
|
||||
class MockMcpToolImpl:
|
||||
def __init__(self, context, mcp_tool_id, arguments=None):
|
||||
self.context = context
|
||||
self.mcp_tool_id = mcp_tool_id
|
||||
self.arguments = arguments or []
|
||||
|
||||
def get_arguments(self):
|
||||
return self.arguments
|
||||
|
||||
# Test 1: MCP tool with arguments
|
||||
test_arguments = [
|
||||
TestArgument(
|
||||
name="account_id",
|
||||
type="string",
|
||||
description="Bank account identifier"
|
||||
),
|
||||
TestArgument(
|
||||
name="date",
|
||||
type="string",
|
||||
description="Date for balance query (optional, format: YYYY-MM-DD)"
|
||||
)
|
||||
]
|
||||
|
||||
context_mock = lambda service_name: None
|
||||
mcp_tool_with_args = MockMcpToolImpl(
|
||||
context=context_mock,
|
||||
mcp_tool_id="get_bank_balance",
|
||||
arguments=test_arguments
|
||||
)
|
||||
|
||||
returned_args = mcp_tool_with_args.get_arguments()
|
||||
|
||||
# Verify arguments are stored and returned correctly
|
||||
assert len(returned_args) == 2
|
||||
assert returned_args[0].name == "account_id"
|
||||
assert returned_args[0].type == "string"
|
||||
assert returned_args[0].description == "Bank account identifier"
|
||||
assert returned_args[1].name == "date"
|
||||
assert returned_args[1].type == "string"
|
||||
assert "optional" in returned_args[1].description.lower()
|
||||
|
||||
# Test 2: MCP tool without arguments (backward compatibility)
|
||||
mcp_tool_no_args = MockMcpToolImpl(
|
||||
context=context_mock,
|
||||
mcp_tool_id="simple_tool"
|
||||
)
|
||||
|
||||
returned_args_empty = mcp_tool_no_args.get_arguments()
|
||||
assert len(returned_args_empty) == 0
|
||||
assert returned_args_empty == []
|
||||
|
||||
# Test 3: MCP tool with empty arguments list
|
||||
mcp_tool_empty_args = MockMcpToolImpl(
|
||||
context=context_mock,
|
||||
mcp_tool_id="another_tool",
|
||||
arguments=[]
|
||||
)
|
||||
|
||||
returned_args_explicit_empty = mcp_tool_empty_args.get_arguments()
|
||||
assert len(returned_args_explicit_empty) == 0
|
||||
assert returned_args_explicit_empty == []
|
||||
|
||||
# Test 4: Configuration parsing simulation
|
||||
def simulate_config_parsing(config_data):
|
||||
"""Simulate how service.py parses MCP tool configuration"""
|
||||
config_args = config_data.get("arguments", [])
|
||||
arguments = [
|
||||
TestArgument(
|
||||
name=arg.get("name"),
|
||||
type=arg.get("type"),
|
||||
description=arg.get("description")
|
||||
)
|
||||
for arg in config_args
|
||||
]
|
||||
return arguments
|
||||
|
||||
# Test configuration with arguments
|
||||
config_with_args = {
|
||||
"type": "mcp-tool",
|
||||
"name": "get_bank_balance",
|
||||
"description": "Get bank account balance",
|
||||
"mcp-tool": "get_bank_balance",
|
||||
"arguments": [
|
||||
{
|
||||
"name": "account_id",
|
||||
"type": "string",
|
||||
"description": "Bank account identifier"
|
||||
},
|
||||
{
|
||||
"name": "date",
|
||||
"type": "string",
|
||||
"description": "Date for balance query (optional)"
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
parsed_args = simulate_config_parsing(config_with_args)
|
||||
assert len(parsed_args) == 2
|
||||
assert parsed_args[0].name == "account_id"
|
||||
assert parsed_args[1].name == "date"
|
||||
|
||||
# Test configuration without arguments
|
||||
config_without_args = {
|
||||
"type": "mcp-tool",
|
||||
"name": "simple_tool",
|
||||
"description": "Simple MCP tool",
|
||||
"mcp-tool": "simple_tool"
|
||||
}
|
||||
|
||||
parsed_args_empty = simulate_config_parsing(config_without_args)
|
||||
assert len(parsed_args_empty) == 0
|
||||
|
||||
# Test 5: Argument structure validation
|
||||
def validate_argument_structure(arg):
|
||||
"""Validate that an argument has required fields"""
|
||||
required_fields = ['name', 'type', 'description']
|
||||
return all(hasattr(arg, field) and getattr(arg, field) for field in required_fields)
|
||||
|
||||
# Validate all parsed arguments have proper structure
|
||||
for arg in parsed_args:
|
||||
assert validate_argument_structure(arg), f"Argument {arg.name} missing required fields"
|
||||
|
||||
# Test 6: Prompt template integration simulation
|
||||
def simulate_prompt_template_rendering(tools):
|
||||
"""Simulate how agent prompts include tool arguments"""
|
||||
tool_descriptions = []
|
||||
|
||||
for tool in tools:
|
||||
tool_desc = f"- **{tool.name}**: {tool.description}"
|
||||
|
||||
# Add argument details if present
|
||||
for arg in tool.arguments:
|
||||
tool_desc += f"\n - Required: `\"{arg.name}\"` ({arg.type}): {arg.description}"
|
||||
|
||||
tool_descriptions.append(tool_desc)
|
||||
|
||||
return "\n".join(tool_descriptions)
|
||||
|
||||
# Create mock tools with our MCP tool
|
||||
class MockTool:
|
||||
def __init__(self, name, description, arguments):
|
||||
self.name = name
|
||||
self.description = description
|
||||
self.arguments = arguments
|
||||
|
||||
mock_tools = [
|
||||
MockTool("search", "Search the web", []), # Tool without arguments
|
||||
MockTool("get_bank_balance", "Get bank account balance", parsed_args) # MCP tool with arguments
|
||||
]
|
||||
|
||||
prompt_section = simulate_prompt_template_rendering(mock_tools)
|
||||
|
||||
# Verify the prompt includes MCP tool arguments
|
||||
assert "get_bank_balance" in prompt_section
|
||||
assert "account_id" in prompt_section
|
||||
assert "Bank account identifier" in prompt_section
|
||||
assert "date" in prompt_section
|
||||
assert "(string)" in prompt_section
|
||||
assert "Required:" in prompt_section
|
||||
|
||||
# Verify tools without arguments still work
|
||||
assert "search" in prompt_section
|
||||
assert "Search the web" in prompt_section
|
||||
|
||||
def test_error_handling_in_react_cycle(self):
|
||||
"""Test error handling during ReAct execution"""
|
||||
# Arrange
|
||||
|
|
@ -474,4 +1107,4 @@ Answer: The capital of France is Paris."""
|
|||
assert "error" in result
|
||||
else:
|
||||
assert result["tool_name"] == expected_tool
|
||||
assert result["parameters"] == expected_params
|
||||
assert result["parameters"] == expected_params
|
||||
|
|
|
|||
458
tests/unit/test_cli/test_config_commands.py
Normal file
458
tests/unit/test_cli/test_config_commands.py
Normal file
|
|
@ -0,0 +1,458 @@
|
|||
"""
|
||||
Unit tests for CLI configuration commands.
|
||||
|
||||
Tests the business logic of list/get/put/delete config item commands
|
||||
while mocking the Config API.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import json
|
||||
import sys
|
||||
from unittest.mock import Mock, patch, MagicMock
|
||||
from io import StringIO
|
||||
|
||||
from trustgraph.cli.list_config_items import list_config_items, main as list_main
|
||||
from trustgraph.cli.get_config_item import get_config_item, main as get_main
|
||||
from trustgraph.cli.put_config_item import put_config_item, main as put_main
|
||||
from trustgraph.cli.delete_config_item import delete_config_item, main as delete_main
|
||||
from trustgraph.api.types import ConfigKey, ConfigValue
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_api():
|
||||
"""Mock Api instance with config() method."""
|
||||
mock_api_instance = Mock()
|
||||
mock_config = Mock()
|
||||
mock_api_instance.config.return_value = mock_config
|
||||
return mock_api_instance, mock_config
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_config_keys():
|
||||
"""Sample configuration keys."""
|
||||
return ["template-1", "template-2", "system-prompt"]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_config_value():
|
||||
"""Sample configuration value."""
|
||||
return ConfigValue(
|
||||
type="prompt",
|
||||
key="template-1",
|
||||
value="You are a helpful assistant. Please respond to: {query}"
|
||||
)
|
||||
|
||||
|
||||
class TestListConfigItems:
|
||||
"""Test the list_config_items function."""
|
||||
|
||||
@patch('trustgraph.cli.list_config_items.Api')
|
||||
def test_list_config_items_text_format(self, mock_api_class, mock_api, sample_config_keys, capsys):
|
||||
"""Test listing config items in text format."""
|
||||
mock_api_class.return_value, mock_config = mock_api
|
||||
mock_config.list.return_value = sample_config_keys
|
||||
|
||||
list_config_items("http://test.com", "prompt", "text")
|
||||
|
||||
captured = capsys.readouterr()
|
||||
output_lines = captured.out.strip().split('\n')
|
||||
|
||||
assert len(output_lines) == 3
|
||||
assert "template-1" in output_lines
|
||||
assert "template-2" in output_lines
|
||||
assert "system-prompt" in output_lines
|
||||
|
||||
mock_config.list.assert_called_once_with("prompt")
|
||||
|
||||
@patch('trustgraph.cli.list_config_items.Api')
|
||||
def test_list_config_items_json_format(self, mock_api_class, mock_api, sample_config_keys, capsys):
|
||||
"""Test listing config items in JSON format."""
|
||||
mock_api_class.return_value, mock_config = mock_api
|
||||
mock_config.list.return_value = sample_config_keys
|
||||
|
||||
list_config_items("http://test.com", "prompt", "json")
|
||||
|
||||
captured = capsys.readouterr()
|
||||
output = json.loads(captured.out.strip())
|
||||
|
||||
assert output == sample_config_keys
|
||||
mock_config.list.assert_called_once_with("prompt")
|
||||
|
||||
@patch('trustgraph.cli.list_config_items.Api')
|
||||
def test_list_config_items_empty_list(self, mock_api_class, mock_api, capsys):
|
||||
"""Test listing when no config items exist."""
|
||||
mock_api_class.return_value, mock_config = mock_api
|
||||
mock_config.list.return_value = []
|
||||
|
||||
list_config_items("http://test.com", "nonexistent", "text")
|
||||
|
||||
captured = capsys.readouterr()
|
||||
assert captured.out.strip() == ""
|
||||
|
||||
mock_config.list.assert_called_once_with("nonexistent")
|
||||
|
||||
def test_list_main_parses_args_correctly(self):
|
||||
"""Test that list main() parses arguments correctly."""
|
||||
test_args = [
|
||||
'tg-list-config-items',
|
||||
'--type', 'prompt',
|
||||
'--format', 'json',
|
||||
'--api-url', 'http://custom.com'
|
||||
]
|
||||
|
||||
with patch('sys.argv', test_args), \
|
||||
patch('trustgraph.cli.list_config_items.list_config_items') as mock_list:
|
||||
|
||||
list_main()
|
||||
|
||||
mock_list.assert_called_once_with(
|
||||
url='http://custom.com',
|
||||
config_type='prompt',
|
||||
format_type='json'
|
||||
)
|
||||
|
||||
def test_list_main_uses_defaults(self):
|
||||
"""Test that list main() uses default values."""
|
||||
test_args = [
|
||||
'tg-list-config-items',
|
||||
'--type', 'prompt'
|
||||
]
|
||||
|
||||
with patch('sys.argv', test_args), \
|
||||
patch('trustgraph.cli.list_config_items.list_config_items') as mock_list:
|
||||
|
||||
list_main()
|
||||
|
||||
mock_list.assert_called_once_with(
|
||||
url='http://localhost:8088/',
|
||||
config_type='prompt',
|
||||
format_type='text'
|
||||
)
|
||||
|
||||
|
||||
class TestGetConfigItem:
|
||||
"""Test the get_config_item function."""
|
||||
|
||||
@patch('trustgraph.cli.get_config_item.Api')
|
||||
def test_get_config_item_text_format(self, mock_api_class, mock_api, sample_config_value, capsys):
|
||||
"""Test getting config item in text format."""
|
||||
mock_api_class.return_value, mock_config = mock_api
|
||||
mock_config.get.return_value = [sample_config_value]
|
||||
|
||||
get_config_item("http://test.com", "prompt", "template-1", "text")
|
||||
|
||||
captured = capsys.readouterr()
|
||||
assert captured.out.strip() == sample_config_value.value
|
||||
|
||||
# Verify ConfigKey was constructed correctly
|
||||
call_args = mock_config.get.call_args[0][0]
|
||||
assert len(call_args) == 1
|
||||
config_key = call_args[0]
|
||||
assert config_key.type == "prompt"
|
||||
assert config_key.key == "template-1"
|
||||
|
||||
@patch('trustgraph.cli.get_config_item.Api')
|
||||
def test_get_config_item_json_format(self, mock_api_class, mock_api, sample_config_value, capsys):
|
||||
"""Test getting config item in JSON format."""
|
||||
mock_api_class.return_value, mock_config = mock_api
|
||||
mock_config.get.return_value = [sample_config_value]
|
||||
|
||||
get_config_item("http://test.com", "prompt", "template-1", "json")
|
||||
|
||||
captured = capsys.readouterr()
|
||||
output = json.loads(captured.out.strip())
|
||||
|
||||
assert output == sample_config_value.value
|
||||
mock_config.get.assert_called_once()
|
||||
|
||||
@patch('trustgraph.cli.get_config_item.Api')
|
||||
def test_get_config_item_not_found(self, mock_api_class, mock_api):
|
||||
"""Test getting non-existent config item raises exception."""
|
||||
mock_api_class.return_value, mock_config = mock_api
|
||||
mock_config.get.return_value = []
|
||||
|
||||
with pytest.raises(Exception, match="Configuration item not found"):
|
||||
get_config_item("http://test.com", "prompt", "nonexistent", "text")
|
||||
|
||||
def test_get_main_parses_args_correctly(self):
|
||||
"""Test that get main() parses arguments correctly."""
|
||||
test_args = [
|
||||
'tg-get-config-item',
|
||||
'--type', 'prompt',
|
||||
'--key', 'template-1',
|
||||
'--format', 'json',
|
||||
'--api-url', 'http://custom.com'
|
||||
]
|
||||
|
||||
with patch('sys.argv', test_args), \
|
||||
patch('trustgraph.cli.get_config_item.get_config_item') as mock_get:
|
||||
|
||||
get_main()
|
||||
|
||||
mock_get.assert_called_once_with(
|
||||
url='http://custom.com',
|
||||
config_type='prompt',
|
||||
key='template-1',
|
||||
format_type='json'
|
||||
)
|
||||
|
||||
|
||||
class TestPutConfigItem:
|
||||
"""Test the put_config_item function."""
|
||||
|
||||
@patch('trustgraph.cli.put_config_item.Api')
|
||||
def test_put_config_item_with_value(self, mock_api_class, mock_api, capsys):
|
||||
"""Test putting config item with command line value."""
|
||||
mock_api_class.return_value, mock_config = mock_api
|
||||
|
||||
put_config_item("http://test.com", "prompt", "new-template", "Custom prompt: {input}")
|
||||
|
||||
captured = capsys.readouterr()
|
||||
assert "Configuration item set: prompt/new-template" in captured.out
|
||||
|
||||
# Verify ConfigValue was constructed correctly
|
||||
call_args = mock_config.put.call_args[0][0]
|
||||
assert len(call_args) == 1
|
||||
config_value = call_args[0]
|
||||
assert config_value.type == "prompt"
|
||||
assert config_value.key == "new-template"
|
||||
assert config_value.value == "Custom prompt: {input}"
|
||||
|
||||
@patch('trustgraph.cli.put_config_item.Api')
|
||||
def test_put_config_item_multiline_value(self, mock_api_class, mock_api):
|
||||
"""Test putting config item with multiline value."""
|
||||
mock_api_class.return_value, mock_config = mock_api
|
||||
|
||||
multiline_value = "Line 1\nLine 2\nLine 3"
|
||||
put_config_item("http://test.com", "prompt", "multiline-template", multiline_value)
|
||||
|
||||
call_args = mock_config.put.call_args[0][0]
|
||||
config_value = call_args[0]
|
||||
assert config_value.value == multiline_value
|
||||
|
||||
def test_put_main_with_value_arg(self):
|
||||
"""Test put main() with --value argument."""
|
||||
test_args = [
|
||||
'tg-put-config-item',
|
||||
'--type', 'prompt',
|
||||
'--key', 'new-template',
|
||||
'--value', 'Custom prompt: {input}',
|
||||
'--api-url', 'http://custom.com'
|
||||
]
|
||||
|
||||
with patch('sys.argv', test_args), \
|
||||
patch('trustgraph.cli.put_config_item.put_config_item') as mock_put:
|
||||
|
||||
put_main()
|
||||
|
||||
mock_put.assert_called_once_with(
|
||||
url='http://custom.com',
|
||||
config_type='prompt',
|
||||
key='new-template',
|
||||
value='Custom prompt: {input}'
|
||||
)
|
||||
|
||||
def test_put_main_with_stdin_arg(self):
|
||||
"""Test put main() with --stdin argument."""
|
||||
test_args = [
|
||||
'tg-put-config-item',
|
||||
'--type', 'prompt',
|
||||
'--key', 'stdin-template',
|
||||
'--stdin'
|
||||
]
|
||||
|
||||
stdin_content = "Content from stdin\nMultiple lines"
|
||||
|
||||
with patch('sys.argv', test_args), \
|
||||
patch('sys.stdin', StringIO(stdin_content)), \
|
||||
patch('trustgraph.cli.put_config_item.put_config_item') as mock_put:
|
||||
|
||||
put_main()
|
||||
|
||||
mock_put.assert_called_once_with(
|
||||
url='http://localhost:8088/',
|
||||
config_type='prompt',
|
||||
key='stdin-template',
|
||||
value=stdin_content
|
||||
)
|
||||
|
||||
def test_put_main_mutually_exclusive_args(self):
|
||||
"""Test that --value and --stdin are mutually exclusive."""
|
||||
test_args = [
|
||||
'tg-put-config-item',
|
||||
'--type', 'prompt',
|
||||
'--key', 'template',
|
||||
'--value', 'test',
|
||||
'--stdin'
|
||||
]
|
||||
|
||||
with patch('sys.argv', test_args):
|
||||
with pytest.raises(SystemExit):
|
||||
put_main()
|
||||
|
||||
|
||||
class TestDeleteConfigItem:
|
||||
"""Test the delete_config_item function."""
|
||||
|
||||
@patch('trustgraph.cli.delete_config_item.Api')
|
||||
def test_delete_config_item(self, mock_api_class, mock_api, capsys):
|
||||
"""Test deleting config item."""
|
||||
mock_api_class.return_value, mock_config = mock_api
|
||||
|
||||
delete_config_item("http://test.com", "prompt", "old-template")
|
||||
|
||||
captured = capsys.readouterr()
|
||||
assert "Configuration item deleted: prompt/old-template" in captured.out
|
||||
|
||||
# Verify ConfigKey was constructed correctly
|
||||
call_args = mock_config.delete.call_args[0][0]
|
||||
assert len(call_args) == 1
|
||||
config_key = call_args[0]
|
||||
assert config_key.type == "prompt"
|
||||
assert config_key.key == "old-template"
|
||||
|
||||
def test_delete_main_parses_args_correctly(self):
|
||||
"""Test that delete main() parses arguments correctly."""
|
||||
test_args = [
|
||||
'tg-delete-config-item',
|
||||
'--type', 'prompt',
|
||||
'--key', 'old-template',
|
||||
'--api-url', 'http://custom.com'
|
||||
]
|
||||
|
||||
with patch('sys.argv', test_args), \
|
||||
patch('trustgraph.cli.delete_config_item.delete_config_item') as mock_delete:
|
||||
|
||||
delete_main()
|
||||
|
||||
mock_delete.assert_called_once_with(
|
||||
url='http://custom.com',
|
||||
config_type='prompt',
|
||||
key='old-template'
|
||||
)
|
||||
|
||||
|
||||
class TestErrorHandling:
|
||||
"""Test error handling scenarios."""
|
||||
|
||||
@patch('trustgraph.cli.list_config_items.Api')
|
||||
def test_list_handles_api_exception(self, mock_api_class, capsys):
|
||||
"""Test that list command handles API exceptions."""
|
||||
mock_api_class.side_effect = Exception("API connection failed")
|
||||
|
||||
list_main_with_args(['--type', 'prompt'])
|
||||
|
||||
captured = capsys.readouterr()
|
||||
assert "Exception: API connection failed" in captured.out
|
||||
|
||||
@patch('trustgraph.cli.get_config_item.Api')
|
||||
def test_get_handles_api_exception(self, mock_api_class, capsys):
|
||||
"""Test that get command handles API exceptions."""
|
||||
mock_api_class.side_effect = Exception("API connection failed")
|
||||
|
||||
get_main_with_args(['--type', 'prompt', '--key', 'test'])
|
||||
|
||||
captured = capsys.readouterr()
|
||||
assert "Exception: API connection failed" in captured.out
|
||||
|
||||
@patch('trustgraph.cli.put_config_item.Api')
|
||||
def test_put_handles_api_exception(self, mock_api_class, capsys):
|
||||
"""Test that put command handles API exceptions."""
|
||||
mock_api_class.side_effect = Exception("API connection failed")
|
||||
|
||||
put_main_with_args(['--type', 'prompt', '--key', 'test', '--value', 'test'])
|
||||
|
||||
captured = capsys.readouterr()
|
||||
assert "Exception: API connection failed" in captured.out
|
||||
|
||||
@patch('trustgraph.cli.delete_config_item.Api')
|
||||
def test_delete_handles_api_exception(self, mock_api_class, capsys):
|
||||
"""Test that delete command handles API exceptions."""
|
||||
mock_api_class.side_effect = Exception("API connection failed")
|
||||
|
||||
delete_main_with_args(['--type', 'prompt', '--key', 'test'])
|
||||
|
||||
captured = capsys.readouterr()
|
||||
assert "Exception: API connection failed" in captured.out
|
||||
|
||||
|
||||
class TestDataValidation:
|
||||
"""Test data validation and edge cases."""
|
||||
|
||||
@patch('trustgraph.cli.get_config_item.Api')
|
||||
def test_get_empty_string_value(self, mock_api_class, mock_api, capsys):
|
||||
"""Test getting config item with empty string value."""
|
||||
mock_api_class.return_value, mock_config = mock_api
|
||||
empty_value = ConfigValue(type="prompt", key="empty", value="")
|
||||
mock_config.get.return_value = [empty_value]
|
||||
|
||||
get_config_item("http://test.com", "prompt", "empty", "text")
|
||||
|
||||
captured = capsys.readouterr()
|
||||
assert captured.out == "\n" # Just a newline from print()
|
||||
|
||||
@patch('trustgraph.cli.put_config_item.Api')
|
||||
def test_put_empty_string_value(self, mock_api_class, mock_api):
|
||||
"""Test putting config item with empty string value."""
|
||||
mock_api_class.return_value, mock_config = mock_api
|
||||
|
||||
put_config_item("http://test.com", "prompt", "empty", "")
|
||||
|
||||
call_args = mock_config.put.call_args[0][0]
|
||||
config_value = call_args[0]
|
||||
assert config_value.value == ""
|
||||
|
||||
@patch('trustgraph.cli.get_config_item.Api')
|
||||
def test_get_special_characters_value(self, mock_api_class, mock_api, capsys):
|
||||
"""Test getting config item with special characters."""
|
||||
mock_api_class.return_value, mock_config = mock_api
|
||||
special_value = ConfigValue(
|
||||
type="prompt",
|
||||
key="special",
|
||||
value="Special chars: äöü 中文 🌟 \"quotes\" 'apostrophes'"
|
||||
)
|
||||
mock_config.get.return_value = [special_value]
|
||||
|
||||
get_config_item("http://test.com", "prompt", "special", "text")
|
||||
|
||||
captured = capsys.readouterr()
|
||||
assert "äöü 中文 🌟" in captured.out
|
||||
assert '"quotes"' in captured.out
|
||||
|
||||
|
||||
# Helper functions for testing main() with custom args
|
||||
def list_main_with_args(args):
|
||||
"""Helper to test list_main with custom arguments."""
|
||||
test_args = ['tg-list-config-items'] + args
|
||||
with patch('sys.argv', test_args):
|
||||
try:
|
||||
list_main()
|
||||
except SystemExit:
|
||||
pass
|
||||
|
||||
def get_main_with_args(args):
|
||||
"""Helper to test get_main with custom arguments."""
|
||||
test_args = ['tg-get-config-item'] + args
|
||||
with patch('sys.argv', test_args):
|
||||
try:
|
||||
get_main()
|
||||
except SystemExit:
|
||||
pass
|
||||
|
||||
def put_main_with_args(args):
|
||||
"""Helper to test put_main with custom arguments."""
|
||||
test_args = ['tg-put-config-item'] + args
|
||||
with patch('sys.argv', test_args):
|
||||
try:
|
||||
put_main()
|
||||
except SystemExit:
|
||||
pass
|
||||
|
||||
def delete_main_with_args(args):
|
||||
"""Helper to test delete_main with custom arguments."""
|
||||
test_args = ['tg-delete-config-item'] + args
|
||||
with patch('sys.argv', test_args):
|
||||
try:
|
||||
delete_main()
|
||||
except SystemExit:
|
||||
pass
|
||||
272
tests/unit/test_knowledge_graph/test_object_validation.py
Normal file
272
tests/unit/test_knowledge_graph/test_object_validation.py
Normal file
|
|
@ -0,0 +1,272 @@
|
|||
"""
|
||||
Unit tests for Object Validation Logic
|
||||
|
||||
Tests the validation logic for extracted objects against schemas,
|
||||
including handling of nested JSON format issues and field validation.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import json
|
||||
from trustgraph.schema import RowSchema, Field
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def cities_schema():
|
||||
"""Cities schema matching the production schema"""
|
||||
fields = []
|
||||
|
||||
# Create fields with proper attribute assignment
|
||||
f1 = Field()
|
||||
f1.name = "city"
|
||||
f1.type = "string"
|
||||
f1.primary = True
|
||||
f1.required = True
|
||||
f1.description = "City name"
|
||||
fields.append(f1)
|
||||
|
||||
f2 = Field()
|
||||
f2.name = "country"
|
||||
f2.type = "string"
|
||||
f2.primary = True
|
||||
f2.required = True
|
||||
f2.description = "Country name"
|
||||
fields.append(f2)
|
||||
|
||||
f3 = Field()
|
||||
f3.name = "population"
|
||||
f3.type = "integer"
|
||||
f3.primary = False
|
||||
f3.required = True
|
||||
f3.description = "Population count"
|
||||
fields.append(f3)
|
||||
|
||||
f4 = Field()
|
||||
f4.name = "climate"
|
||||
f4.type = "string"
|
||||
f4.primary = False
|
||||
f4.required = True
|
||||
f4.description = "Climate type"
|
||||
fields.append(f4)
|
||||
|
||||
f5 = Field()
|
||||
f5.name = "primary_language"
|
||||
f5.type = "string"
|
||||
f5.primary = False
|
||||
f5.required = True
|
||||
f5.description = "Primary language spoken"
|
||||
fields.append(f5)
|
||||
|
||||
f6 = Field()
|
||||
f6.name = "currency"
|
||||
f6.type = "string"
|
||||
f6.primary = False
|
||||
f6.required = True
|
||||
f6.description = "Currency used"
|
||||
fields.append(f6)
|
||||
|
||||
schema = RowSchema()
|
||||
schema.name = "Cities"
|
||||
schema.description = "City demographics"
|
||||
schema.fields = fields
|
||||
|
||||
return schema
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def validator():
|
||||
"""Create a mock processor with just the validation method"""
|
||||
from unittest.mock import MagicMock
|
||||
from trustgraph.extract.kg.objects.processor import Processor
|
||||
|
||||
# Create a mock processor
|
||||
mock_processor = MagicMock()
|
||||
|
||||
# Bind the validate_object method to the mock
|
||||
mock_processor.validate_object = Processor.validate_object.__get__(mock_processor, Processor)
|
||||
|
||||
return mock_processor
|
||||
|
||||
|
||||
class TestObjectValidation:
|
||||
"""Test cases for object validation logic"""
|
||||
|
||||
def test_valid_object_passes_validation(self, validator, cities_schema):
|
||||
"""Test that a valid object passes validation"""
|
||||
valid_obj = {
|
||||
"city": "Shanghai",
|
||||
"country": "China",
|
||||
"population": "30482140",
|
||||
"climate": "Humid subtropical",
|
||||
"primary_language": "Mandarin Chinese",
|
||||
"currency": "Chinese Yuan (CNY)"
|
||||
}
|
||||
|
||||
result = validator.validate_object(valid_obj, cities_schema, "Cities")
|
||||
assert result is True
|
||||
|
||||
def test_nested_json_format_fails_validation(self, validator, cities_schema):
|
||||
"""Test that nested JSON format is detected and fails validation"""
|
||||
nested_obj = {
|
||||
"Cities": '{"city": "Jakarta", "country": "Indonesia", "population": 11634078, "climate": "Tropical monsoon", "primary_language": "Indonesian", "currency": "Indonesian Rupiah (IDR)"}'
|
||||
}
|
||||
|
||||
result = validator.validate_object( nested_obj, cities_schema, "Cities")
|
||||
assert result is False
|
||||
|
||||
def test_missing_required_field_fails_validation(self, validator, cities_schema):
|
||||
"""Test that missing required field fails validation"""
|
||||
missing_field_obj = {
|
||||
"city": "London",
|
||||
"country": "UK",
|
||||
"population": "9000000",
|
||||
"climate": "Temperate",
|
||||
# Missing primary_language (required)
|
||||
"currency": "GBP"
|
||||
}
|
||||
|
||||
result = validator.validate_object( missing_field_obj, cities_schema, "Cities")
|
||||
assert result is False
|
||||
|
||||
def test_null_primary_key_fails_validation(self, validator, cities_schema):
|
||||
"""Test that null primary key field fails validation"""
|
||||
null_primary_obj = {
|
||||
"city": None, # Primary key is null
|
||||
"country": "France",
|
||||
"population": "2000000",
|
||||
"climate": "Mediterranean",
|
||||
"primary_language": "French",
|
||||
"currency": "EUR"
|
||||
}
|
||||
|
||||
result = validator.validate_object( null_primary_obj, cities_schema, "Cities")
|
||||
assert result is False
|
||||
|
||||
def test_missing_primary_key_fails_validation(self, validator, cities_schema):
|
||||
"""Test that missing primary key field fails validation"""
|
||||
missing_primary_obj = {
|
||||
# Missing city (primary key)
|
||||
"country": "Spain",
|
||||
"population": "3000000",
|
||||
"climate": "Mediterranean",
|
||||
"primary_language": "Spanish",
|
||||
"currency": "EUR"
|
||||
}
|
||||
|
||||
result = validator.validate_object( missing_primary_obj, cities_schema, "Cities")
|
||||
assert result is False
|
||||
|
||||
def test_invalid_integer_type_fails_validation(self, validator, cities_schema):
|
||||
"""Test that invalid integer value fails validation"""
|
||||
invalid_type_obj = {
|
||||
"city": "Tokyo",
|
||||
"country": "Japan",
|
||||
"population": "not_a_number", # Invalid integer
|
||||
"climate": "Humid subtropical",
|
||||
"primary_language": "Japanese",
|
||||
"currency": "JPY"
|
||||
}
|
||||
|
||||
result = validator.validate_object( invalid_type_obj, cities_schema, "Cities")
|
||||
assert result is False
|
||||
|
||||
def test_numeric_string_for_integer_passes_validation(self, validator, cities_schema):
|
||||
"""Test that numeric string for integer field passes validation"""
|
||||
numeric_string_obj = {
|
||||
"city": "Beijing",
|
||||
"country": "China",
|
||||
"population": "21540000", # String that can be converted to int
|
||||
"climate": "Continental",
|
||||
"primary_language": "Mandarin",
|
||||
"currency": "CNY"
|
||||
}
|
||||
|
||||
result = validator.validate_object( numeric_string_obj, cities_schema, "Cities")
|
||||
assert result is True
|
||||
|
||||
def test_integer_value_for_integer_field_passes_validation(self, validator, cities_schema):
|
||||
"""Test that actual integer value for integer field passes validation"""
|
||||
integer_obj = {
|
||||
"city": "Mumbai",
|
||||
"country": "India",
|
||||
"population": 20185064, # Actual integer
|
||||
"climate": "Tropical",
|
||||
"primary_language": "Hindi",
|
||||
"currency": "INR"
|
||||
}
|
||||
|
||||
result = validator.validate_object( integer_obj, cities_schema, "Cities")
|
||||
assert result is True
|
||||
|
||||
def test_non_dict_object_fails_validation(self, validator, cities_schema):
|
||||
"""Test that non-dictionary object fails validation"""
|
||||
non_dict_obj = "This is not a dictionary"
|
||||
|
||||
result = validator.validate_object( non_dict_obj, cities_schema, "Cities")
|
||||
assert result is False
|
||||
|
||||
def test_optional_field_missing_passes_validation(self, validator):
|
||||
"""Test that missing optional field passes validation"""
|
||||
# Create schema with optional field
|
||||
fields = [
|
||||
Field(name="id", type="string", primary=True, required=True),
|
||||
Field(name="name", type="string", required=True),
|
||||
Field(name="description", type="string", required=False), # Optional
|
||||
]
|
||||
schema = RowSchema(name="TestSchema", fields=fields)
|
||||
|
||||
obj = {
|
||||
"id": "123",
|
||||
"name": "Test Name",
|
||||
# description is missing but optional
|
||||
}
|
||||
|
||||
result = validator.validate_object( obj, schema, "TestSchema")
|
||||
assert result is True
|
||||
|
||||
def test_float_type_validation(self, validator):
|
||||
"""Test float type validation"""
|
||||
fields = [
|
||||
Field(name="id", type="string", primary=True, required=True),
|
||||
Field(name="price", type="float", required=True),
|
||||
]
|
||||
schema = RowSchema(name="Product", fields=fields)
|
||||
|
||||
# Valid float as string
|
||||
obj1 = {"id": "1", "price": "19.99"}
|
||||
assert validator.validate_object( obj1, schema, "Product") is True
|
||||
|
||||
# Valid float
|
||||
obj2 = {"id": "2", "price": 19.99}
|
||||
assert validator.validate_object( obj2, schema, "Product") is True
|
||||
|
||||
# Valid integer (can be float)
|
||||
obj3 = {"id": "3", "price": 20}
|
||||
assert validator.validate_object( obj3, schema, "Product") is True
|
||||
|
||||
# Invalid float
|
||||
obj4 = {"id": "4", "price": "not_a_float"}
|
||||
assert validator.validate_object( obj4, schema, "Product") is False
|
||||
|
||||
def test_boolean_type_validation(self, validator):
|
||||
"""Test boolean type validation"""
|
||||
fields = [
|
||||
Field(name="id", type="string", primary=True, required=True),
|
||||
Field(name="active", type="boolean", required=True),
|
||||
]
|
||||
schema = RowSchema(name="User", fields=fields)
|
||||
|
||||
# Valid boolean
|
||||
obj1 = {"id": "1", "active": True}
|
||||
assert validator.validate_object( obj1, schema, "User") is True
|
||||
|
||||
# Valid boolean as string
|
||||
obj2 = {"id": "2", "active": "true"}
|
||||
assert validator.validate_object( obj2, schema, "User") is True
|
||||
|
||||
# Valid boolean as integer
|
||||
obj3 = {"id": "3", "active": 1}
|
||||
assert validator.validate_object( obj3, schema, "User") is True
|
||||
|
||||
# Invalid boolean type
|
||||
obj4 = {"id": "4", "active": []}
|
||||
assert validator.validate_object( obj4, schema, "User") is False
|
||||
|
|
@ -188,16 +188,36 @@ class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase):
|
|||
assert result.out_token == 0
|
||||
assert result.model == 'gemini-2.0-flash-001'
|
||||
|
||||
<<<<<<< HEAD
|
||||
=======
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.google.auth.default')
|
||||
>>>>>>> release/v1.2
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.service_account')
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.vertexai')
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.GenerativeModel')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
<<<<<<< HEAD
|
||||
async def test_processor_initialization_without_private_key(self, mock_llm_init, mock_async_init, mock_generative_model, mock_vertexai, mock_service_account):
|
||||
"""Test processor initialization without private key (should fail)"""
|
||||
# Arrange
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
=======
|
||||
async def test_processor_initialization_without_private_key(self, mock_llm_init, mock_async_init, mock_generative_model, mock_vertexai, mock_service_account, mock_auth_default):
|
||||
"""Test processor initialization without private key (uses default credentials)"""
|
||||
# Arrange
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
# Mock google.auth.default() to return credentials and project ID
|
||||
mock_credentials = MagicMock()
|
||||
mock_auth_default.return_value = (mock_credentials, "test-project-123")
|
||||
|
||||
# Mock GenerativeModel
|
||||
mock_model = MagicMock()
|
||||
mock_generative_model.return_value = mock_model
|
||||
>>>>>>> release/v1.2
|
||||
|
||||
config = {
|
||||
'region': 'us-central1',
|
||||
|
|
@ -210,9 +230,22 @@ class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase):
|
|||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
<<<<<<< HEAD
|
||||
# Act & Assert
|
||||
with pytest.raises(RuntimeError, match="Private key file not specified"):
|
||||
processor = Processor(**config)
|
||||
=======
|
||||
# Act
|
||||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
assert processor.model == 'gemini-2.0-flash-001'
|
||||
mock_auth_default.assert_called_once()
|
||||
mock_vertexai.init.assert_called_once_with(
|
||||
location='us-central1',
|
||||
project='test-project-123'
|
||||
)
|
||||
>>>>>>> release/v1.2
|
||||
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.service_account')
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.vertexai')
|
||||
|
|
@ -292,12 +325,20 @@ class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase):
|
|||
# Verify service account was called with custom key
|
||||
mock_service_account.Credentials.from_service_account_file.assert_called_once_with('custom-key.json')
|
||||
|
||||
<<<<<<< HEAD
|
||||
# Verify that parameters dict has the correct values (this is accessible)
|
||||
assert processor.parameters["temperature"] == 0.7
|
||||
assert processor.parameters["max_output_tokens"] == 4096
|
||||
assert processor.parameters["top_p"] == 1.0
|
||||
assert processor.parameters["top_k"] == 32
|
||||
assert processor.parameters["candidate_count"] == 1
|
||||
=======
|
||||
# Verify that api_params dict has the correct values (this is accessible)
|
||||
assert processor.api_params["temperature"] == 0.7
|
||||
assert processor.api_params["max_output_tokens"] == 4096
|
||||
assert processor.api_params["top_p"] == 1.0
|
||||
assert processor.api_params["top_k"] == 32
|
||||
>>>>>>> release/v1.2
|
||||
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.service_account')
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.vertexai')
|
||||
|
|
@ -392,6 +433,61 @@ class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase):
|
|||
# The prompt should be "" + "\n\n" + "" = "\n\n"
|
||||
assert call_args[0][0] == "\n\n"
|
||||
|
||||
<<<<<<< HEAD
|
||||
=======
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.AnthropicVertex')
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.service_account')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_anthropic_processor_initialization_with_private_key(self, mock_llm_init, mock_async_init, mock_service_account, mock_anthropic_vertex):
|
||||
"""Test Anthropic processor initialization with private key credentials"""
|
||||
# Arrange
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
mock_credentials = MagicMock()
|
||||
mock_credentials.project_id = "test-project-456"
|
||||
mock_service_account.Credentials.from_service_account_file.return_value = mock_credentials
|
||||
|
||||
# Mock AnthropicVertex
|
||||
mock_anthropic_client = MagicMock()
|
||||
mock_anthropic_vertex.return_value = mock_anthropic_client
|
||||
|
||||
config = {
|
||||
'region': 'us-west1',
|
||||
'model': 'claude-3-sonnet@20240229', # Anthropic model
|
||||
'temperature': 0.5,
|
||||
'max_output': 2048,
|
||||
'private_key': 'anthropic-key.json',
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-anthropic-processor'
|
||||
}
|
||||
|
||||
# Act
|
||||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
assert processor.model == 'claude-3-sonnet@20240229'
|
||||
assert processor.is_anthropic == True
|
||||
|
||||
# Verify service account was called with private key
|
||||
mock_service_account.Credentials.from_service_account_file.assert_called_once_with('anthropic-key.json')
|
||||
|
||||
# Verify AnthropicVertex was initialized with credentials
|
||||
mock_anthropic_vertex.assert_called_once_with(
|
||||
region='us-west1',
|
||||
project_id='test-project-456',
|
||||
credentials=mock_credentials
|
||||
)
|
||||
|
||||
# Verify api_params are set correctly
|
||||
assert processor.api_params["temperature"] == 0.5
|
||||
assert processor.api_params["max_output_tokens"] == 2048
|
||||
assert processor.api_params["top_p"] == 1.0
|
||||
assert processor.api_params["top_k"] == 32
|
||||
|
||||
>>>>>>> release/v1.2
|
||||
|
||||
if __name__ == '__main__':
|
||||
pytest.main([__file__])
|
||||
Loading…
Add table
Add a link
Reference in a new issue