mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-06-03 20:05:13 +02:00
parent
a8e437fc7f
commit
6c7af8789d
216 changed files with 31360 additions and 1611 deletions
|
|
@ -757,7 +757,9 @@ Final Answer: {
|
|||
@pytest.mark.asyncio
|
||||
async def test_agent_manager_knowledge_query_collection_integration(self, mock_flow_context):
|
||||
"""Test agent manager integration with KnowledgeQueryImpl collection parameter"""
|
||||
# Arrange
|
||||
import functools
|
||||
|
||||
# Arrange - Use functools.partial like the real service does
|
||||
custom_tools = {
|
||||
"knowledge_query_custom": Tool(
|
||||
name="knowledge_query_custom",
|
||||
|
|
@ -769,7 +771,7 @@ Final Answer: {
|
|||
description="The question to ask"
|
||||
)
|
||||
],
|
||||
implementation=KnowledgeQueryImpl,
|
||||
implementation=functools.partial(KnowledgeQueryImpl, collection="research_papers"),
|
||||
config={"collection": "research_papers"}
|
||||
),
|
||||
"knowledge_query_default": Tool(
|
||||
|
|
@ -813,11 +815,13 @@ Args: {
|
|||
@pytest.mark.asyncio
|
||||
async def test_knowledge_query_multiple_collections(self, mock_flow_context):
|
||||
"""Test multiple KnowledgeQueryImpl instances with different collections"""
|
||||
# Arrange
|
||||
import functools
|
||||
|
||||
# Arrange - Create partial functions like the service does
|
||||
tools = {
|
||||
"general_kb": KnowledgeQueryImpl(mock_flow_context, collection="general"),
|
||||
"technical_kb": KnowledgeQueryImpl(mock_flow_context, collection="technical"),
|
||||
"research_kb": KnowledgeQueryImpl(mock_flow_context, collection="research")
|
||||
"general_kb": functools.partial(KnowledgeQueryImpl, collection="general")(mock_flow_context),
|
||||
"technical_kb": functools.partial(KnowledgeQueryImpl, collection="technical")(mock_flow_context),
|
||||
"research_kb": functools.partial(KnowledgeQueryImpl, collection="research")(mock_flow_context)
|
||||
}
|
||||
|
||||
# Act & Assert for each tool
|
||||
|
|
|
|||
482
tests/integration/test_agent_structured_query_integration.py
Normal file
482
tests/integration/test_agent_structured_query_integration.py
Normal file
|
|
@ -0,0 +1,482 @@
|
|||
"""
|
||||
Integration tests for React Agent with Structured Query Tool
|
||||
|
||||
These tests verify the end-to-end functionality of the React agent
|
||||
using the structured-query tool to query structured data with natural language.
|
||||
Following the TEST_STRATEGY.md approach for integration testing.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import json
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
from trustgraph.schema import (
|
||||
AgentRequest, AgentResponse,
|
||||
StructuredQueryRequest, StructuredQueryResponse,
|
||||
Error
|
||||
)
|
||||
from trustgraph.agent.react.service import Processor
|
||||
|
||||
|
||||
@pytest.mark.integration
class TestAgentStructuredQueryIntegration:
    """Integration tests for React agent with structured query tool.

    Each test loads a tool configuration into the agent processor, hands it a
    mocked request message, and routes the flow's service lookups to mocks so
    that the structured-query call and the responses sent back can be
    inspected.
    """

    # ------------------------------------------------------------------
    # Shared helpers (no ``test_`` prefix, so pytest does not collect them)
    # ------------------------------------------------------------------

    @staticmethod
    def _make_request(question):
        """Build an AgentRequest for *question* with empty state and history."""
        return AgentRequest(
            question=question,
            state="",
            group=None,
            history=[],
            user="test_user",
        )

    @staticmethod
    def _make_message(request, message_id):
        """Wrap *request* in a mock message whose properties carry *message_id*."""
        msg = MagicMock()
        msg.value.return_value = request
        msg.properties.return_value = {"id": message_id}
        return msg

    @staticmethod
    def _make_flow(structured_client, prompt_client, response_producer):
        """Return a callable mock that routes service names to the given mocks.

        Any service name other than the three routed ones yields a fresh
        AsyncMock on every lookup.
        """
        routes = {
            "structured-query-request": structured_client,
            "prompt-request": prompt_client,
            "response": response_producer,
        }

        def flow_context(service_name):
            if service_name in routes:
                return routes[service_name]
            return AsyncMock()

        return MagicMock(side_effect=flow_context)

    @staticmethod
    def _sent_responses(response_producer):
        """Return the first positional argument of every ``send`` call."""
        return [sent.args[0] for sent in response_producer.send.call_args_list]

    @staticmethod
    def _queried_question(structured_client):
        """Return the question given to the latest ``structured_query`` call.

        The ``question`` keyword is checked first; if it is absent the first
        positional argument is used, so the check does not depend on how the
        tool passes the question.
        """
        call_args = structured_client.structured_query.call_args
        question = call_args.kwargs.get("question")
        if question is None and call_args.args:
            question = call_args.args[0]
        assert question is not None, "structured_query was called without a question"
        return question

    async def _run_agent(self, agent_processor, *, question, message_id,
                         query_result, agent_reply):
        """Send one mocked request through ``agent_processor.on_request``.

        Args:
            agent_processor: Processor under test (tool config already loaded).
            question: Question text placed in the AgentRequest.
            message_id: Value of the ``id`` message property.
            query_result: Return value of the mocked ``structured_query``.
            agent_reply: Return value of the mocked ``agent_react`` prompt call.

        Returns:
            Tuple ``(structured_client, response_producer)`` of the mocks that
            were wired into the flow, for assertions by the caller.
        """
        # Mock the structured query client
        structured_client = AsyncMock()
        structured_client.structured_query.return_value = query_result

        # Mock the prompt client that agent calls for reasoning
        prompt_client = AsyncMock()
        prompt_client.agent_react.return_value = agent_reply

        # Mock response producer for the flow
        response_producer = AsyncMock()

        msg = self._make_message(self._make_request(question), message_id)
        consumer = MagicMock()
        flow = self._make_flow(structured_client, prompt_client, response_producer)

        await agent_processor.on_request(msg, consumer, flow)

        return structured_client, response_producer

    # ------------------------------------------------------------------
    # Fixtures
    # ------------------------------------------------------------------

    @pytest.fixture
    def agent_processor(self):
        """Create an agent processor with mocked task group and Pulsar client.

        The tool configuration itself is loaded by each test through
        ``on_tools_config``.
        """
        proc = Processor(
            taskgroup=MagicMock(),
            pulsar_client=AsyncMock(),
            max_iterations=3,
        )

        # Mock the client method for structured query
        proc.client = MagicMock()

        return proc

    @pytest.fixture
    def structured_query_tool_config(self):
        """Configuration for structured-query tool"""
        return {
            "tool": {
                "structured-query": json.dumps({
                    "name": "structured-query",
                    "description": "Query structured data using natural language",
                    "type": "structured-query",
                })
            }
        }

    # ------------------------------------------------------------------
    # Tests
    # ------------------------------------------------------------------

    @pytest.mark.asyncio
    async def test_agent_structured_query_basic_integration(self, agent_processor, structured_query_tool_config):
        """Test basic agent integration with structured query tool"""
        # Arrange - Load tool configuration
        await agent_processor.on_tools_config(structured_query_tool_config, "v1")

        # Mock structured query response
        structured_query_response = {
            "data": json.dumps({
                "customers": [
                    {"id": "1", "name": "John Doe", "email": "john@example.com", "state": "New York"},
                    {"id": "2", "name": "Jane Smith", "email": "jane@example.com", "state": "New York"},
                ]
            }),
            "errors": [],
            "error": None,
        }

        # Act
        structured_client, response_producer = await self._run_agent(
            agent_processor,
            question="I need to find all customers from New York. Use the structured query tool to get this information.",
            message_id="agent-test-001",
            query_result=structured_query_response,
            agent_reply="""Thought: I need to find customers from New York using structured query
Action: structured-query
Args: {
"question": "Find all customers from New York"
}""",
        )

        # Assert
        # Verify structured query was called
        structured_client.structured_query.assert_called_once()
        question_arg = self._queried_question(structured_client)
        assert "customers" in question_arg.lower()
        assert "new york" in question_arg.lower()

        # Verify responses were sent (agent sends multiple responses for thought/observation)
        assert response_producer.send.call_count >= 1

        # Verify at least one response is of correct type and has no error
        responses = self._sent_responses(response_producer)
        assert any(isinstance(resp, AgentResponse) and resp.error is None for resp in responses)

    @pytest.mark.asyncio
    async def test_agent_structured_query_error_handling(self, agent_processor, structured_query_tool_config):
        """Test agent handling of structured query errors"""
        # Arrange
        await agent_processor.on_tools_config(structured_query_tool_config, "v1")

        # Mock structured query error response
        structured_query_error_response = {
            "data": None,
            "errors": ["Table 'nonexistent' not found in schema"],
            "error": {"type": "structured-query-error", "message": "Schema not found"},
        }

        # Act
        structured_client, response_producer = await self._run_agent(
            agent_processor,
            question="Find data from a table that doesn't exist using structured query.",
            message_id="agent-error-test",
            query_result=structured_query_error_response,
            agent_reply="""Thought: I need to query for a table that might not exist
Action: structured-query
Args: {
"question": "Find data from a table that doesn't exist"
}""",
        )

        # Assert
        structured_client.structured_query.assert_called_once()
        assert response_producer.send.call_count >= 1

        # Agent should handle the error gracefully
        responses = self._sent_responses(response_producer)
        assert any(isinstance(resp, AgentResponse) for resp in responses)

        # The question forwarded to the tool should be about the missing table
        question_arg = self._queried_question(structured_client)
        assert "table" in question_arg.lower() or "exist" in question_arg.lower()

    @pytest.mark.asyncio
    async def test_agent_multi_step_structured_query_reasoning(self, agent_processor, structured_query_tool_config):
        """Test agent using structured query in multi-step reasoning"""
        # Arrange
        await agent_processor.on_tools_config(structured_query_tool_config, "v1")

        # Mock structured query response (just one for this test)
        customers_response = {
            "data": json.dumps({
                "customers": [
                    {"id": "101", "name": "Alice Johnson", "state": "California"},
                    {"id": "102", "name": "Bob Wilson", "state": "California"},
                ]
            }),
            "errors": [],
            "error": None,
        }

        # Act
        structured_client, response_producer = await self._run_agent(
            agent_processor,
            question="First find all customers from California, then tell me how many orders they have made.",
            message_id="agent-multi-step-test",
            query_result=customers_response,
            agent_reply="""Thought: I need to find customers from California first
Action: structured-query
Args: {
"question": "Find all customers from California"
}""",
        )

        # Assert
        # Should have made structured query call
        assert structured_client.structured_query.call_count >= 1

        assert response_producer.send.call_count >= 1
        responses = self._sent_responses(response_producer)
        assert any(isinstance(resp, AgentResponse) for resp in responses)

        # Verify the structured query was called with customer-related question
        question_arg = self._queried_question(structured_client)
        assert "california" in question_arg.lower()

    @pytest.mark.asyncio
    async def test_agent_structured_query_with_collection_parameter(self, agent_processor):
        """Test structured query tool with collection parameter"""
        # Arrange - Configure tool with collection
        tool_config_with_collection = {
            "tool": {
                "structured-query": json.dumps({
                    "name": "structured-query",
                    "description": "Query structured data using natural language",
                    "type": "structured-query",
                    "collection": "sales_data",
                })
            }
        }

        await agent_processor.on_tools_config(tool_config_with_collection, "v1")

        # Mock structured query response
        sales_response = {
            "data": json.dumps({
                "transactions": [
                    {"id": "tx1", "amount": 299.99, "date": "2024-01-15"},
                    {"id": "tx2", "amount": 149.50, "date": "2024-01-16"},
                ]
            }),
            "errors": [],
            "error": None,
        }

        # Act
        structured_client, response_producer = await self._run_agent(
            agent_processor,
            question="Query the sales data for recent transactions.",
            message_id="agent-collection-test",
            query_result=sales_response,
            agent_reply="""Thought: I need to query the sales data
Action: structured-query
Args: {
"question": "Query the sales data for recent transactions"
}""",
        )

        # Assert
        structured_client.structured_query.assert_called_once()

        # Verify the tool was configured with collection parameter
        # (Collection parameter is passed to tool constructor, not to query method)
        assert response_producer.send.call_count >= 1
        responses = self._sent_responses(response_producer)
        assert any(isinstance(resp, AgentResponse) for resp in responses)

        # Check the query was about sales/transactions
        question_arg = self._queried_question(structured_client)
        assert "sales" in question_arg.lower() or "transactions" in question_arg.lower()

    @pytest.mark.asyncio
    async def test_agent_structured_query_tool_argument_validation(self, agent_processor, structured_query_tool_config):
        """Test that the structured query tool is registered with one string ``question`` argument"""
        # Arrange
        await agent_processor.on_tools_config(structured_query_tool_config, "v1")

        # Check that the tool was registered with correct arguments
        tools = agent_processor.agent.tools
        assert "structured-query" in tools

        arguments = tools["structured-query"].arguments

        # Verify tool has the expected argument structure
        assert len(arguments) == 1
        question_arg = arguments[0]
        assert question_arg.name == "question"
        assert question_arg.type == "string"
        assert "structured data" in question_arg.description.lower()

    @pytest.mark.asyncio
    async def test_agent_structured_query_json_formatting(self, agent_processor, structured_query_tool_config):
        """Test that structured query results are properly formatted for agent consumption"""
        # Arrange
        await agent_processor.on_tools_config(structured_query_tool_config, "v1")

        # Mock structured query response with complex (nested) data
        complex_response = {
            "data": json.dumps({
                "customers": [
                    {
                        "id": "c1",
                        "name": "Enterprise Corp",
                        "contact": {
                            "email": "contact@enterprise.com",
                            "phone": "555-0123",
                        },
                        "orders": [
                            {"id": "o1", "total": 5000.00, "items": 15},
                            {"id": "o2", "total": 3200.50, "items": 8},
                        ],
                    }
                ]
            }),
            "errors": [],
            "error": None,
        }

        # Act
        structured_client, response_producer = await self._run_agent(
            agent_processor,
            question="Get customer information and format it nicely.",
            message_id="agent-format-test",
            query_result=complex_response,
            agent_reply="""Thought: I need to get customer information
Action: structured-query
Args: {
"question": "Get customer information and format it nicely"
}""",
        )

        # Assert
        structured_client.structured_query.assert_called_once()
        assert response_producer.send.call_count >= 1

        # At least one AgentResponse must have been produced from the tool output
        responses = self._sent_responses(response_producer)
        assert any(isinstance(resp, AgentResponse) for resp in responses)

        # Check that the query was about customer information
        question_arg = self._queried_question(structured_client)
        assert "customer" in question_arg.lower()
|
||||
453
tests/integration/test_cassandra_config_end_to_end.py
Normal file
453
tests/integration/test_cassandra_config_end_to_end.py
Normal file
|
|
@ -0,0 +1,453 @@
|
|||
"""
|
||||
End-to-end integration tests for Cassandra configuration.
|
||||
|
||||
Tests complete configuration flow from environment variables
|
||||
through processors to Cassandra connections.
|
||||
"""
|
||||
|
||||
import os
|
||||
import pytest
|
||||
from unittest.mock import Mock, patch, MagicMock, call
|
||||
from argparse import ArgumentParser
|
||||
|
||||
# Import processors that use Cassandra configuration
|
||||
from trustgraph.storage.triples.cassandra.write import Processor as TriplesWriter
|
||||
from trustgraph.storage.objects.cassandra.write import Processor as ObjectsWriter
|
||||
from trustgraph.query.triples.cassandra.service import Processor as TriplesQuery
|
||||
from trustgraph.storage.knowledge.store import Processor as KgStore
|
||||
|
||||
|
||||
class TestEndToEndConfigurationFlow:
    """Check that CASSANDRA_* environment variables reach the storage layer of each processor."""

    @pytest.mark.asyncio
    @patch('trustgraph.direct.cassandra_kg.Cluster')
    async def test_triples_writer_env_to_connection(self, mock_cluster):
        """Triples writer: hosts and credentials from the environment end up in the Cluster call."""
        cluster = MagicMock()
        cluster.connect.return_value = MagicMock()
        mock_cluster.return_value = cluster

        environment = {
            'CASSANDRA_HOST': 'integration-host1,integration-host2,integration-host3',
            'CASSANDRA_USERNAME': 'integration-user',
            'CASSANDRA_PASSWORD': 'integration-pass',
        }

        with patch.dict(os.environ, environment, clear=True):
            writer = TriplesWriter(taskgroup=MagicMock())

            # A message with no triples, used to trigger TrustGraph creation
            message = MagicMock()
            message.metadata.user = 'test_user'
            message.metadata.collection = 'test_collection'
            message.triples = []

            # Expected to create TrustGraph from the environment configuration
            await writer.store_triples(message)

            # The comma-separated host string must arrive as a list of hosts
            mock_cluster.assert_called_once()
            positional, keyword = mock_cluster.call_args
            assert positional[0] == ['integration-host1', 'integration-host2', 'integration-host3']
            # Credentials were supplied, so an auth provider is expected
            assert 'auth_provider' in keyword

    @patch('trustgraph.storage.objects.cassandra.write.Cluster')
    @patch('trustgraph.storage.objects.cassandra.write.PlainTextAuthProvider')
    def test_objects_writer_env_to_cluster_connection(self, mock_auth_provider, mock_cluster):
        """Objects writer: environment credentials build the auth provider handed to Cluster."""
        auth = MagicMock()
        mock_auth_provider.return_value = auth

        cluster = MagicMock()
        cluster.connect.return_value = MagicMock()
        mock_cluster.return_value = cluster

        environment = {
            'CASSANDRA_HOST': 'obj-host1,obj-host2',
            'CASSANDRA_USERNAME': 'obj-user',
            'CASSANDRA_PASSWORD': 'obj-pass',
        }

        with patch.dict(os.environ, environment, clear=True):
            writer = ObjectsWriter(taskgroup=MagicMock())

            # Trigger Cassandra connection
            writer.connect_cassandra()

            # Auth provider receives the username/password from the environment
            mock_auth_provider.assert_called_once_with(
                username='obj-user',
                password='obj-pass',
            )

            # Cluster receives the split host list and that same auth provider
            mock_cluster.assert_called_once()
            cluster_kwargs = mock_cluster.call_args.kwargs
            assert cluster_kwargs['contact_points'] == ['obj-host1', 'obj-host2']
            assert cluster_kwargs['auth_provider'] == auth

    @pytest.mark.asyncio
    @patch('trustgraph.storage.knowledge.store.KnowledgeTableStore')
    async def test_kg_store_env_to_table_store(self, mock_table_store):
        """Knowledge store: environment settings are forwarded to KnowledgeTableStore."""
        mock_table_store.return_value = MagicMock()

        environment = {
            'CASSANDRA_HOST': 'kg-host1,kg-host2,kg-host3,kg-host4',
            'CASSANDRA_USERNAME': 'kg-user',
            'CASSANDRA_PASSWORD': 'kg-pass',
        }

        with patch.dict(os.environ, environment, clear=True):
            store = KgStore(taskgroup=MagicMock())

            # Constructing the processor alone must create the table store
            expected = {
                'cassandra_host': ['kg-host1', 'kg-host2', 'kg-host3', 'kg-host4'],
                'cassandra_username': 'kg-user',
                'cassandra_password': 'kg-pass',
                'keyspace': 'knowledge',
            }
            mock_table_store.assert_called_once_with(**expected)
|
||||
|
||||
|
||||
class TestConfigurationPriorityEndToEnd:
    """Check the precedence of explicit parameters, environment variables and defaults."""

    @pytest.mark.asyncio
    @patch('trustgraph.direct.cassandra_kg.Cluster')
    async def test_cli_override_env_end_to_end(self, mock_cluster):
        """Explicit constructor parameters win over CASSANDRA_* environment variables."""
        cluster = MagicMock()
        cluster.connect.return_value = MagicMock()
        mock_cluster.return_value = cluster

        environment = {
            'CASSANDRA_HOST': 'env-host',
            'CASSANDRA_USERNAME': 'env-user',
            'CASSANDRA_PASSWORD': 'env-pass',
        }
        cli_settings = {
            'cassandra_host': 'cli-host1,cli-host2',
            'cassandra_username': 'cli-user',
            'cassandra_password': 'cli-pass',
        }

        with patch.dict(os.environ, environment, clear=True):
            # CLI parameters should override environment
            writer = TriplesWriter(taskgroup=MagicMock(), **cli_settings)

            # Trigger TrustGraph creation
            message = MagicMock()
            message.metadata.user = 'test_user'
            message.metadata.collection = 'test_collection'
            message.triples = []

            await writer.store_triples(message)

            # Hosts must come from the CLI values, not from the environment
            mock_cluster.assert_called_once()
            positional, keyword = mock_cluster.call_args
            assert positional[0] == ['cli-host1', 'cli-host2']
            # Credentials were supplied, so an auth provider is expected
            assert 'auth_provider' in keyword

    @pytest.mark.asyncio
    @patch('trustgraph.storage.knowledge.store.KnowledgeTableStore')
    async def test_partial_cli_with_env_fallback_end_to_end(self, mock_table_store):
        """A host given as parameter is combined with credentials from the environment."""
        mock_table_store.return_value = MagicMock()

        environment = {
            'CASSANDRA_HOST': 'fallback-host1,fallback-host2',
            'CASSANDRA_USERNAME': 'fallback-user',
            'CASSANDRA_PASSWORD': 'fallback-pass',
        }

        with patch.dict(os.environ, environment, clear=True):
            # Only the host is passed explicitly; username and password are
            # left out so that they fall back to the environment.
            store = KgStore(
                taskgroup=MagicMock(),
                cassandra_host='partial-host',
            )

            # Verify mixed configuration
            expected = {
                'cassandra_host': ['partial-host'],       # From parameter
                'cassandra_username': 'fallback-user',    # From environment
                'cassandra_password': 'fallback-pass',    # From environment
                'keyspace': 'knowledge',
            }
            mock_table_store.assert_called_once_with(**expected)

    @pytest.mark.asyncio
    @patch('trustgraph.direct.cassandra_kg.Cluster')
    async def test_no_config_defaults_end_to_end(self, mock_cluster):
        """With an empty environment and no parameters the defaults are used."""
        cluster = MagicMock()
        cluster.connect.return_value = MagicMock()
        mock_cluster.return_value = cluster

        with patch.dict(os.environ, {}, clear=True):
            query_processor = TriplesQuery(taskgroup=MagicMock())

            # Mock query to trigger TrustGraph creation
            query = MagicMock()
            query.user = 'default_user'
            query.collection = 'default_collection'
            query.s = query.p = query.o = None
            query.limit = 100

            # Mock the get_all method to return empty list
            graph = MagicMock()
            graph.get_all.return_value = []
            query_processor.tg = graph

            await query_processor.query_triples(query)

            # Should use defaults
            mock_cluster.assert_called_once()
            positional, keyword = mock_cluster.call_args
            assert positional[0] == ['cassandra']      # Default host
            assert 'auth_provider' not in keyword      # No auth with default config
|
||||
|
||||
|
||||
class TestNoBackwardCompatibilityEndToEnd:
    """Test that backward compatibility with old parameter names is removed.

    Every test runs with an emptied ``os.environ`` so that CASSANDRA_*
    variables set in the developer's or CI environment cannot change which
    host or credentials the processors resolve; the assertions below rely on
    defaults being used when only legacy parameter names are supplied.
    """

    @pytest.mark.asyncio
    @patch('trustgraph.direct.cassandra_kg.Cluster')
    async def test_old_graph_params_no_longer_work_end_to_end(self, mock_cluster):
        """Test that old graph_* parameters no longer work end-to-end."""
        mock_cluster_instance = MagicMock()
        mock_cluster_instance.connect.return_value = MagicMock()
        mock_cluster.return_value = mock_cluster_instance

        # Empty environment: the expected defaults must not be shadowed by
        # ambient CASSANDRA_* variables.
        with patch.dict(os.environ, {}, clear=True):
            # Use old parameter names (should be ignored)
            processor = TriplesWriter(
                taskgroup=MagicMock(),
                graph_host='legacy-host',
                graph_username='legacy-user',
                graph_password='legacy-pass',
            )

            # Trigger TrustGraph creation
            mock_message = MagicMock()
            mock_message.metadata.user = 'legacy_user'
            mock_message.metadata.collection = 'legacy_collection'
            mock_message.triples = []

            await processor.store_triples(mock_message)

        # Should use defaults since old parameters are not recognized
        mock_cluster.assert_called_once()
        call_args = mock_cluster.call_args
        assert call_args.args[0] == ['cassandra']  # Default, not legacy-host
        assert 'auth_provider' not in call_args.kwargs  # No auth since no valid credentials

    @patch('trustgraph.storage.knowledge.store.KnowledgeTableStore')
    def test_old_cassandra_user_param_no_longer_works_end_to_end(self, mock_table_store):
        """Test that old cassandra_user parameter no longer works."""
        mock_table_store.return_value = MagicMock()

        # Empty environment: otherwise CASSANDRA_USERNAME could supply the
        # username that this test expects to be None.
        with patch.dict(os.environ, {}, clear=True):
            # Use old cassandra_user parameter (should be ignored)
            processor = KgStore(
                taskgroup=MagicMock(),
                cassandra_host='legacy-kg-host',
                cassandra_user='legacy-kg-user',  # Old parameter name - not supported
                cassandra_password='legacy-kg-pass',
            )

        # cassandra_user should be ignored, only cassandra_username works
        mock_table_store.assert_called_once_with(
            cassandra_host=['legacy-kg-host'],
            cassandra_username=None,  # Should be None since cassandra_user is not recognized
            cassandra_password='legacy-kg-pass',
            keyspace='knowledge',
        )

    @pytest.mark.asyncio
    @patch('trustgraph.direct.cassandra_kg.Cluster')
    async def test_new_params_override_old_params_end_to_end(self, mock_cluster):
        """Test that new parameters override old ones when both are present end-to-end."""
        mock_cluster_instance = MagicMock()
        mock_cluster_instance.connect.return_value = MagicMock()
        mock_cluster.return_value = mock_cluster_instance

        # Empty environment, for the same isolation as the other tests in
        # this class.
        with patch.dict(os.environ, {}, clear=True):
            # Provide both old and new parameters
            processor = TriplesWriter(
                taskgroup=MagicMock(),
                cassandra_host='new-host',
                graph_host='old-host',  # Should be ignored
                cassandra_username='new-user',
                graph_username='old-user',  # Should be ignored
                cassandra_password='new-pass',
                graph_password='old-pass',  # Should be ignored
            )

            # Trigger TrustGraph creation
            mock_message = MagicMock()
            mock_message.metadata.user = 'precedence_user'
            mock_message.metadata.collection = 'precedence_collection'
            mock_message.triples = []

            await processor.store_triples(mock_message)

        # Should use new parameters, not old ones
        mock_cluster.assert_called_once()
        call_args = mock_cluster.call_args
        assert call_args.args[0] == ['new-host']  # New parameter wins
        assert 'auth_provider' in call_args.kwargs  # Should have auth since credentials provided
|
||||
|
||||
|
||||
class TestMultipleHostsHandling:
    """Test multiple Cassandra hosts handling end-to-end."""

    @patch('trustgraph.storage.objects.cassandra.write.Cluster')
    def test_multiple_hosts_passed_to_cluster(self, mock_cluster):
        """Test that multiple hosts are correctly passed to Cassandra cluster."""
        env_vars = {
            'CASSANDRA_HOST': 'host1,host2,host3,host4,host5'
        }

        mock_cluster_instance = MagicMock()
        mock_session = MagicMock()
        mock_cluster_instance.connect.return_value = mock_session
        mock_cluster.return_value = mock_cluster_instance

        # clear=True makes CASSANDRA_HOST the only variable the processor sees
        with patch.dict(os.environ, env_vars, clear=True):
            processor = ObjectsWriter(taskgroup=MagicMock())
            processor.connect_cassandra()

        # Verify all hosts were passed to Cluster
        mock_cluster.assert_called_once()
        call_args = mock_cluster.call_args
        assert call_args.kwargs['contact_points'] == ['host1', 'host2', 'host3', 'host4', 'host5']

    @pytest.mark.asyncio
    @patch('trustgraph.direct.cassandra_kg.Cluster')
    async def test_single_host_converted_to_list(self, mock_cluster):
        """Test that single host is converted to list for TrustGraph."""
        mock_cluster_instance = MagicMock()
        mock_session = MagicMock()
        mock_cluster_instance.connect.return_value = mock_session
        mock_cluster.return_value = mock_cluster_instance

        # Trigger TrustGraph creation
        mock_message = MagicMock()
        mock_message.metadata.user = 'single_user'
        mock_message.metadata.collection = 'single_collection'
        mock_message.triples = []

        # The "no auth_provider" assertion below only holds when no
        # credentials are configured. Sibling tests feed credentials through
        # CASSANDRA_USERNAME / CASSANDRA_PASSWORD, so remove any ambient
        # values while the processor is built and used; patch.dict restores
        # the environment afterwards.
        with patch.dict(os.environ):
            os.environ.pop('CASSANDRA_USERNAME', None)
            os.environ.pop('CASSANDRA_PASSWORD', None)

            processor = TriplesWriter(taskgroup=MagicMock(), cassandra_host='single-host')
            await processor.store_triples(mock_message)

        # Single host should be converted to list
        mock_cluster.assert_called_once()
        call_args = mock_cluster.call_args
        assert call_args.args[0] == ['single-host']  # Converted to list
        assert 'auth_provider' not in call_args.kwargs  # No auth since no credentials provided

    def test_whitespace_handling_in_host_list(self):
        """Test that whitespace in host lists is handled correctly."""
        from trustgraph.base.cassandra_config import resolve_cassandra_config

        # Spaces around the separators
        hosts1, _, _ = resolve_cassandra_config(host='host1, host2 , host3')
        assert hosts1 == ['host1', 'host2', 'host3']

        # Trailing comma must not yield an empty entry
        hosts2, _, _ = resolve_cassandra_config(host='host1,host2,host3,')
        assert hosts2 == ['host1', 'host2', 'host3']

        # Leading and trailing spaces around the whole value
        hosts3, _, _ = resolve_cassandra_config(host=' host1 , host2 ')
        assert hosts3 == ['host1', 'host2']
|
||||
|
||||
|
||||
class TestAuthenticationFlow:
    """Test authentication configuration flow end-to-end."""

    @patch('trustgraph.storage.objects.cassandra.write.Cluster')
    @patch('trustgraph.storage.objects.cassandra.write.PlainTextAuthProvider')
    def test_authentication_enabled_when_both_credentials_provided(self, mock_auth_provider, mock_cluster):
        """Test that authentication is enabled when both username and password are provided."""
        env_vars = {
            'CASSANDRA_HOST': 'auth-host',
            'CASSANDRA_USERNAME': 'auth-user',
            'CASSANDRA_PASSWORD': 'auth-secret'
        }

        mock_auth_instance = MagicMock()
        mock_auth_provider.return_value = mock_auth_instance
        mock_cluster_instance = MagicMock()
        mock_cluster.return_value = mock_cluster_instance

        with patch.dict(os.environ, env_vars, clear=True):
            processor = ObjectsWriter(taskgroup=MagicMock())
            processor.connect_cassandra()

        # Auth provider should be created
        mock_auth_provider.assert_called_once_with(
            username='auth-user',
            password='auth-secret'
        )

        # Cluster should be created with auth provider
        call_args = mock_cluster.call_args
        assert 'auth_provider' in call_args.kwargs
        assert call_args.kwargs['auth_provider'] == mock_auth_instance

    @patch('trustgraph.storage.objects.cassandra.write.Cluster')
    @patch('trustgraph.storage.objects.cassandra.write.PlainTextAuthProvider')
    def test_no_authentication_when_credentials_missing(self, mock_auth_provider, mock_cluster):
        """Test that authentication is not used when credentials are missing."""
        env_vars = {
            'CASSANDRA_HOST': 'no-auth-host'
            # No username/password
        }

        mock_cluster_instance = MagicMock()
        mock_cluster.return_value = mock_cluster_instance

        with patch.dict(os.environ, env_vars, clear=True):
            processor = ObjectsWriter(taskgroup=MagicMock())
            processor.connect_cassandra()

        # Auth provider should not be created
        mock_auth_provider.assert_not_called()

        # Cluster should be created without auth provider
        call_args = mock_cluster.call_args
        assert 'auth_provider' not in call_args.kwargs

    @patch('trustgraph.storage.objects.cassandra.write.Cluster')
    @patch('trustgraph.storage.objects.cassandra.write.PlainTextAuthProvider')
    def test_no_authentication_when_only_username_provided(self, mock_auth_provider, mock_cluster):
        """Test that authentication is not used when only username is provided."""
        mock_cluster_instance = MagicMock()
        mock_cluster.return_value = mock_cluster_instance

        # Run with an empty environment, like the sibling tests: an ambient
        # CASSANDRA_PASSWORD would otherwise complete the credential pair and
        # make this test depend on the machine it runs on.
        with patch.dict(os.environ, {}, clear=True):
            processor = ObjectsWriter(
                taskgroup=MagicMock(),
                cassandra_host='partial-auth-host',
                cassandra_username='partial-user'
                # No password
            )
            processor.connect_cassandra()

        # Auth provider should not be created (needs both username AND password)
        mock_auth_provider.assert_not_called()

        # Cluster should be created without auth provider
        call_args = mock_cluster.call_args
        assert 'auth_provider' not in call_args.kwargs
|
||||
|
|
@ -13,7 +13,7 @@ import time
|
|||
from unittest.mock import MagicMock
|
||||
|
||||
from .cassandra_test_helper import cassandra_container
|
||||
from trustgraph.direct.cassandra import TrustGraph
|
||||
from trustgraph.direct.cassandra_kg import KnowledgeGraph
|
||||
from trustgraph.storage.triples.cassandra.write import Processor as StorageProcessor
|
||||
from trustgraph.query.triples.cassandra.service import Processor as QueryProcessor
|
||||
from trustgraph.schema import Triple, Value, Metadata, Triples, TriplesQueryRequest
|
||||
|
|
@ -62,29 +62,29 @@ class TestCassandraIntegration:
|
|||
print("=" * 60)
|
||||
|
||||
# =====================================================
|
||||
# Test 1: Basic TrustGraph Operations
|
||||
# Test 1: Basic KnowledgeGraph Operations
|
||||
# =====================================================
|
||||
print("\n1. Testing basic TrustGraph operations...")
|
||||
|
||||
client = TrustGraph(
|
||||
print("\n1. Testing basic KnowledgeGraph operations...")
|
||||
|
||||
client = KnowledgeGraph(
|
||||
hosts=[host],
|
||||
keyspace="test_basic",
|
||||
table="test_table"
|
||||
keyspace="test_basic"
|
||||
)
|
||||
self.clients_to_close.append(client)
|
||||
|
||||
# Insert test data
|
||||
client.insert("http://example.org/alice", "knows", "http://example.org/bob")
|
||||
client.insert("http://example.org/alice", "age", "25")
|
||||
client.insert("http://example.org/bob", "age", "30")
|
||||
|
||||
collection = "test_collection"
|
||||
client.insert(collection, "http://example.org/alice", "knows", "http://example.org/bob")
|
||||
client.insert(collection, "http://example.org/alice", "age", "25")
|
||||
client.insert(collection, "http://example.org/bob", "age", "30")
|
||||
|
||||
# Test get_all
|
||||
all_results = list(client.get_all(limit=10))
|
||||
all_results = list(client.get_all(collection, limit=10))
|
||||
assert len(all_results) == 3
|
||||
print(f"✓ Stored and retrieved {len(all_results)} triples")
|
||||
|
||||
# Test get_s (subject query)
|
||||
alice_results = list(client.get_s("http://example.org/alice", limit=10))
|
||||
alice_results = list(client.get_s(collection, "http://example.org/alice", limit=10))
|
||||
assert len(alice_results) == 2
|
||||
alice_predicates = [r.p for r in alice_results]
|
||||
assert "knows" in alice_predicates
|
||||
|
|
@ -110,7 +110,7 @@ class TestCassandraIntegration:
|
|||
keyspace="test_storage",
|
||||
table="test_triples"
|
||||
)
|
||||
# Track the TrustGraph instance that will be created
|
||||
# Track the KnowledgeGraph instance that will be created
|
||||
self.storage_processor = storage_processor
|
||||
|
||||
# Create test message
|
||||
|
|
@ -202,7 +202,7 @@ class TestCassandraIntegration:
|
|||
# Debug: Check what was actually stored
|
||||
print("Debug: Checking what was stored for Alice...")
|
||||
direct_results = list(query_storage_processor.tg.get_s("http://example.org/alice", limit=10))
|
||||
print(f"Direct TrustGraph results: {len(direct_results)}")
|
||||
print(f"Direct KnowledgeGraph results: {len(direct_results)}")
|
||||
for result in direct_results:
|
||||
print(f" S=http://example.org/alice, P={result.p}, O={result.o}")
|
||||
|
||||
|
|
|
|||
470
tests/integration/test_import_export_graceful_shutdown.py
Normal file
470
tests/integration/test_import_export_graceful_shutdown.py
Normal file
|
|
@ -0,0 +1,470 @@
|
|||
"""Integration tests for import/export graceful shutdown functionality."""
|
||||
|
||||
import pytest
|
||||
import asyncio
|
||||
import json
|
||||
import time
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from aiohttp import web, WSMsgType, ClientWebSocketResponse
|
||||
from trustgraph.gateway.dispatch.triples_import import TriplesImport
|
||||
from trustgraph.gateway.dispatch.triples_export import TriplesExport
|
||||
from trustgraph.gateway.running import Running
|
||||
from trustgraph.base.publisher import Publisher
|
||||
from trustgraph.base.subscriber import Subscriber
|
||||
|
||||
|
||||
class MockPulsarMessage:
    """Mock Pulsar message for testing.

    Wraps a payload and exposes it through ``value()``; ``properties()``
    returns a dict carrying the message id under the ``"id"`` key.
    """

    def __init__(self, data, message_id="test-id"):
        self._data = data
        self._message_id = message_id
        # Same dict object is handed out on every properties() call
        self._properties = dict(id=message_id)

    def value(self):
        """Return the payload given at construction."""
        payload = self._data
        return payload

    def properties(self):
        """Return the properties dict (``{"id": message_id}``)."""
        props = self._properties
        return props
|
||||
|
||||
|
||||
class MockWebSocket:
    """Mock WebSocket for testing.

    Records every payload passed to ``send_json`` in ``messages`` and tracks
    whether ``close`` has been awaited.
    """

    def __init__(self):
        self.messages = []
        self.closed = False
        self._close_called = False

    async def send_json(self, data):
        """Record ``data``; raise if the socket is already marked closed."""
        if not self.closed:
            self.messages.append(data)
            return
        raise Exception("WebSocket is closed")

    async def close(self):
        """Mark the socket closed and remember that close() was called."""
        self.closed = True
        self._close_called = True

    def json(self):
        """Mock message json() method."""
        metadata = {
            "id": "test-id",
            "metadata": {},
            "user": "test-user",
            "collection": "test-collection"
        }
        triple = {
            "s": {"v": "subject", "e": False},
            "p": {"v": "predicate", "e": False},
            "o": {"v": "object", "e": False}
        }
        return {"metadata": metadata, "triples": [triple]}
|
||||
|
||||
|
||||
@pytest.fixture
def mock_pulsar_client():
    """Mock Pulsar client for integration testing.

    Returns a MagicMock whose ``create_producer()`` and ``subscribe()`` hand
    back pre-configured producer and consumer mocks.
    """
    # Producer: send / flush / close are plain synchronous mocks
    producer = MagicMock(
        send=MagicMock(),
        flush=MagicMock(),
        close=MagicMock(),
    )

    # Consumer: receive is awaitable, everything else is synchronous
    consumer = MagicMock(
        receive=AsyncMock(),
        acknowledge=MagicMock(),
        negative_acknowledge=MagicMock(),
        pause_message_listener=MagicMock(),
        unsubscribe=MagicMock(),
        close=MagicMock(),
    )

    client = MagicMock()
    client.create_producer.return_value = producer
    client.subscribe.return_value = consumer
    return client
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_import_graceful_shutdown_integration():
    """Test import path handles shutdown gracefully with real message flow.

    Feeds ten messages into a TriplesImport handler backed by a mock Pulsar
    producer, destroys the handler, and checks that every message reached
    ``producer.send`` and that the producer was flushed and closed once.
    """
    mock_client = MagicMock()
    mock_producer = MagicMock()
    mock_client.create_producer.return_value = mock_producer

    # Track sent messages
    sent_messages = []

    def track_send(message, properties=None):
        sent_messages.append((message, properties))

    mock_producer.send.side_effect = track_send

    ws = MockWebSocket()
    running = Running()

    # Create import handler
    import_handler = TriplesImport(
        ws=ws,
        running=running,
        pulsar_client=mock_client,
        queue="test-triples-import"
    )

    await import_handler.start()

    # Send multiple messages rapidly
    for i in range(10):
        msg_data = {
            "metadata": {
                "id": f"msg-{i}",
                "metadata": {},
                "user": "test-user",
                "collection": "test-collection"
            },
            "triples": [{"s": {"v": f"subject-{i}", "e": False}, "p": {"v": "predicate", "e": False}, "o": {"v": f"object-{i}", "e": False}}]
        }

        # Create mock message with json() method
        mock_msg = MagicMock()
        mock_msg.json.return_value = msg_data

        await import_handler.receive(mock_msg)

    # Allow brief processing time
    await asyncio.sleep(0.1)

    # Shutdown while messages may be in flight
    await import_handler.destroy()

    # Verify all messages reached producer
    assert len(sent_messages) == 10

    # Verify proper shutdown order was followed
    mock_producer.flush.assert_called_once()
    mock_producer.close.assert_called_once()

    # Verify messages have correct content, in submission order
    for i, (message, _properties) in enumerate(sent_messages):
        assert message.metadata.id == f"msg-{i}"
        assert len(message.triples) == 1
        assert message.triples[0].s.value == f"subject-{i}"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_export_no_message_loss_integration():
    """Test export path doesn't lose acknowledged messages.

    A mock consumer serves twenty Triples messages and then raises a Pulsar
    TimeoutException. The export handler runs in the background, is destroyed
    mid-flight, and the websocket must end up closed with only well-formed
    messages recorded.
    """
    # Imported once up front rather than on every loop iteration
    from trustgraph.schema import Triples, Metadata
    from trustgraph.gateway.dispatch.serialize import to_subgraph

    mock_client = MagicMock()
    mock_consumer = MagicMock()
    mock_client.subscribe.return_value = mock_consumer

    # Create test messages
    test_messages = []
    for i in range(20):
        msg_data = {
            "metadata": {
                "id": f"export-msg-{i}",
                "metadata": {},
                "user": "test-user",
                "collection": "test-collection"
            },
            "triples": [{"s": {"v": f"export-subject-{i}", "e": False}, "p": {"v": "predicate", "e": False}, "o": {"v": f"export-object-{i}", "e": False}}]
        }
        # Create Triples object instead of raw dict
        triples_obj = Triples(
            metadata=Metadata(
                id=f"export-msg-{i}",
                metadata=to_subgraph(msg_data["metadata"]["metadata"]),
                user=msg_data["metadata"]["user"],
                collection=msg_data["metadata"]["collection"],
            ),
            triples=to_subgraph(msg_data["triples"]),
        )
        test_messages.append(MockPulsarMessage(triples_obj, f"export-msg-{i}"))

    # Mock consumer to provide messages
    message_iter = iter(test_messages)

    def mock_receive(timeout_millis=None):
        try:
            return next(message_iter)
        except StopIteration:
            # Simulate timeout when no more messages
            from pulsar import TimeoutException
            raise TimeoutException("No more messages")

    mock_consumer.receive = mock_receive

    ws = MockWebSocket()
    running = Running()

    # Create export handler
    export_handler = TriplesExport(
        ws=ws,
        running=running,
        pulsar_client=mock_client,
        queue="test-triples-export",
        consumer="test-consumer",
        subscriber="test-subscriber"
    )

    # Start export in background
    export_task = asyncio.create_task(export_handler.run())

    # Allow some messages to be processed
    await asyncio.sleep(0.5)

    # Verify some messages were sent to websocket
    initial_count = len(ws.messages)
    assert initial_count > 0

    # Force shutdown
    await export_handler.destroy()

    # Wait for export task to complete
    try:
        await asyncio.wait_for(export_task, timeout=2.0)
    except asyncio.TimeoutError:
        export_task.cancel()
        # Await the cancelled task so it is fully finished before the test
        # returns, instead of leaving a pending task behind
        try:
            await export_task
        except asyncio.CancelledError:
            pass

    # Verify websocket was closed
    assert ws._close_called is True

    # Verify messages that were acknowledged were actually sent
    final_count = len(ws.messages)
    assert final_count >= initial_count

    # Verify no partial/corrupted messages
    for msg in ws.messages:
        assert "metadata" in msg
        assert "triples" in msg
        assert msg["metadata"]["id"].startswith("export-msg-")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_concurrent_import_export_shutdown():
    """Test concurrent import and export shutdown scenarios."""
    # One mock Pulsar client per direction, each wired to its own endpoint mock
    import_producer = MagicMock()
    export_consumer = MagicMock()
    import_client = MagicMock()
    export_client = MagicMock()
    import_client.create_producer.return_value = import_producer
    export_client.subscribe.return_value = export_consumer

    # Operation logs filled in by the side effects below
    import_operations = []
    export_operations = []

    def track_import_send(message, properties=None):
        import_operations.append(("send", message.metadata.id))

    import_producer.send.side_effect = track_import_send
    import_producer.flush.side_effect = lambda: import_operations.append(("flush",))
    export_consumer.acknowledge.side_effect = (
        lambda msg: export_operations.append(("ack", msg.properties()["id"]))
    )

    # Build the two handlers on independent websockets / running flags
    import_handler = TriplesImport(
        ws=MockWebSocket(),
        running=Running(),
        pulsar_client=import_client,
        queue="concurrent-import"
    )
    export_handler = TriplesExport(
        ws=MockWebSocket(),
        running=Running(),
        pulsar_client=export_client,
        queue="concurrent-export",
        consumer="concurrent-consumer",
        subscriber="concurrent-subscriber"
    )

    # Only the import handler is started; the export handler is just destroyed
    await import_handler.start()

    # Push five messages through the import path
    for i in range(5):
        payload = {
            "metadata": {
                "id": f"concurrent-{i}",
                "metadata": {},
                "user": "test-user",
                "collection": "test-collection"
            },
            "triples": [{"s": {"v": f"concurrent-subject-{i}", "e": False}, "p": {"v": "predicate", "e": False}, "o": {"v": "object", "e": False}}]
        }
        msg = MagicMock()
        msg.json.return_value = payload
        await import_handler.receive(msg)

    # Tear both handlers down at the same time
    await asyncio.gather(import_handler.destroy(), export_handler.destroy())

    # Import side: exactly five sends plus one flush
    assert len(import_operations) == 6  # 5 sends + 1 flush
    assert ("flush",) in import_operations

    # Every import message was handed to the producer
    sends = [entry for entry in import_operations if entry[0] == "send"]
    assert len(sends) == 5
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_websocket_close_during_message_processing():
    """Test graceful handling when websocket closes during active message processing."""
    producer = MagicMock()
    client = MagicMock()
    client.create_producer.return_value = producer

    # Record the id of every message handed to the producer
    processed_messages = []

    def slow_send(message, properties=None):
        # producer.send is synchronous, so no sleep is performed here
        processed_messages.append(message.metadata.id)

    producer.send.side_effect = slow_send

    ws = MockWebSocket()
    import_handler = TriplesImport(
        ws=ws,
        running=Running(),
        pulsar_client=client,
        queue="slow-processing-import"
    )
    await import_handler.start()

    def build_message(i):
        """Return a mock websocket message whose json() yields triple ``i``."""
        msg = MagicMock()
        msg.json.return_value = {
            "metadata": {
                "id": f"slow-msg-{i}",
                "metadata": {},
                "user": "test-user",
                "collection": "test-collection"
            },
            "triples": [{"s": {"v": f"slow-subject-{i}", "e": False}, "p": {"v": "predicate", "e": False}, "o": {"v": "object", "e": False}}]
        }
        return msg

    # Fire ten receives as concurrent tasks
    message_tasks = [
        asyncio.create_task(import_handler.receive(build_message(i)))
        for i in range(10)
    ]

    # Give processing a moment to begin
    await asyncio.sleep(0.2)

    # Mark the websocket closed while work may still be in flight
    ws.closed = True

    # Tear the handler down
    await import_handler.destroy()

    # Collect every receive task, tolerating failures
    await asyncio.gather(*message_tasks, return_exceptions=True)

    # Extra slack for the publisher to drain queued items
    await asyncio.sleep(0.3)

    # Graceful shutdown should have let in-flight messages finish
    assert len(processed_messages) > 0

    # Producer must be flushed and closed exactly once
    producer.flush.assert_called_once()
    producer.close.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_backpressure_during_shutdown():
    """Test graceful shutdown under backpressure conditions.

    The export handler's ``run`` is replaced by a coroutine that pushes a few
    messages through a deliberately slow websocket. The handler is destroyed
    while that coroutine is active; destroy() must return promptly, some
    messages must have been delivered, and the websocket must be closed.
    """
    mock_client = MagicMock()
    mock_consumer = MagicMock()
    mock_client.subscribe.return_value = mock_consumer

    # Mock slow websocket
    class SlowWebSocket(MockWebSocket):
        async def send_json(self, data):
            await asyncio.sleep(0.02)  # Slow send
            await super().send_json(data)

    ws = SlowWebSocket()
    running = Running()

    export_handler = TriplesExport(
        ws=ws,
        running=running,
        pulsar_client=mock_client,
        queue="backpressure-export",
        consumer="backpressure-consumer",
        subscriber="backpressure-subscriber"
    )

    # Mock the run method to avoid hanging issues
    with patch.object(export_handler, 'run') as mock_run:
        # Mock run that simulates processing under backpressure
        async def mock_run_with_backpressure():
            # Simulate slow message processing
            for i in range(5):  # Process a few messages slowly
                try:
                    # Simulate receiving and processing a message
                    msg_data = {
                        "metadata": {"id": f"msg-{i}"},
                        "triples": [{"s": {"v": "subject", "e": False}, "p": {"v": "predicate", "e": False}, "o": {"v": "object", "e": False}}]
                    }
                    await ws.send_json(msg_data)
                    # Check if we should stop
                    if not running.get():
                        break
                    await asyncio.sleep(0.1)  # Simulate slow processing
                except Exception:
                    # send_json raises once the socket is closed; stop the loop
                    break

        mock_run.side_effect = mock_run_with_backpressure

        # Start export task
        export_task = asyncio.create_task(export_handler.run())

        # Allow some processing
        await asyncio.sleep(0.3)

        # Shutdown under backpressure. Use the monotonic clock so the
        # measured duration cannot be skewed by wall-clock adjustments.
        shutdown_start = time.monotonic()
        await export_handler.destroy()
        shutdown_duration = time.monotonic() - shutdown_start

        # Wait for export task to complete
        try:
            await asyncio.wait_for(export_task, timeout=2.0)
        except asyncio.TimeoutError:
            export_task.cancel()
            try:
                await export_task
            except asyncio.CancelledError:
                pass

        # Verify graceful shutdown completed within reasonable time
        assert shutdown_duration < 10.0  # Should not hang indefinitely

        # Verify some messages were processed before shutdown
        assert len(ws.messages) > 0

        # Verify websocket was closed
        assert ws._close_called is True
|
||||
441
tests/integration/test_load_structured_data_integration.py
Normal file
441
tests/integration/test_load_structured_data_integration.py
Normal file
|
|
@ -0,0 +1,441 @@
|
|||
"""
|
||||
Integration tests for tg-load-structured-data with actual TrustGraph instance.
|
||||
Tests end-to-end functionality including WebSocket connections and data storage.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import asyncio
|
||||
import json
|
||||
import tempfile
|
||||
import os
|
||||
import csv
|
||||
import time
|
||||
from unittest.mock import Mock, patch, AsyncMock
|
||||
from websockets.asyncio.client import connect
|
||||
|
||||
from trustgraph.cli.load_structured_data import load_structured_data
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
class TestLoadStructuredDataIntegration:
|
||||
"""Integration tests for complete pipeline"""
|
||||
|
||||
def setup_method(self):
    """Set up test fixtures"""
    self.api_url = "http://localhost:8088"
    self.test_schema_name = "integration_test_schema"

    # Five-record CSV sample with a header row
    self.test_csv_data = """name,email,age,country,status
John Smith,john@email.com,35,US,active
Jane Doe,jane@email.com,28,CA,active
Bob Johnson,bob@company.org,42,UK,inactive
Alice Brown,alice@email.com,31,AU,active
Charlie Davis,charlie@email.com,39,DE,inactive"""

    # Same first three records as JSON objects
    columns = ("name", "email", "age", "country", "status")
    rows = [
        ("John Smith", "john@email.com", 35, "US", "active"),
        ("Jane Doe", "jane@email.com", 28, "CA", "active"),
        ("Bob Johnson", "bob@company.org", 42, "UK", "inactive"),
    ]
    self.test_json_data = [dict(zip(columns, row)) for row in rows]

    # Same three records as XML, one <field name="..."> element per column
    self.test_xml_data = """<?xml version="1.0"?>
<ROOT>
<data>
<record>
<field name="name">John Smith</field>
<field name="email">john@email.com</field>
<field name="age">35</field>
<field name="country">US</field>
<field name="status">active</field>
</record>
<record>
<field name="name">Jane Doe</field>
<field name="email">jane@email.com</field>
<field name="age">28</field>
<field name="country">CA</field>
<field name="status">active</field>
</record>
<record>
<field name="name">Bob Johnson</field>
<field name="email">bob@company.org</field>
<field name="age">42</field>
<field name="country">UK</field>
<field name="status">inactive</field>
</record>
</data>
</ROOT>"""

    def mapping(field, *transform_types):
        """Build one mapping entry: same source/target name, required."""
        return {
            "source_field": field,
            "target_field": field,
            "transforms": [{"type": kind} for kind in transform_types],
            "validation": [{"type": "required"}]
        }

    # CSV descriptor; the XML/JSON tests override only its "format" section
    self.test_descriptor = {
        "version": "1.0",
        "metadata": {
            "name": "IntegrationTest",
            "description": "Test descriptor for integration tests",
            "author": "Test Suite"
        },
        "format": {
            "type": "csv",
            "encoding": "utf-8",
            "options": {
                "header": True,
                "delimiter": ","
            }
        },
        "mappings": [
            mapping("name", "trim"),
            mapping("email", "trim", "lower"),
            mapping("age", "to_int"),
            mapping("country", "trim", "upper"),
            mapping("status", "trim", "lower"),
        ],
        "output": {
            "format": "trustgraph-objects",
            "schema_name": self.test_schema_name,
            "options": {
                "confidence": 0.9,
                "batch_size": 3
            }
        }
    }
|
||||
|
||||
def create_temp_file(self, content, suffix='.txt'):
    """Create a temporary file with given content

    Args:
        content: Text to write to the file.
        suffix: File name suffix (extension) for the temporary file.

    Returns:
        Path of the created file. The file is not deleted automatically;
        the caller is responsible for removing it.
    """
    # Write as UTF-8 explicitly: the descriptors used by these tests declare
    # "encoding": "utf-8", so the fixture must not depend on the platform's
    # default encoding. The context manager closes the file even if the
    # write fails; delete=False keeps it on disk after closing.
    with tempfile.NamedTemporaryFile(
        mode='w', suffix=suffix, delete=False, encoding='utf-8'
    ) as temp_file:
        temp_file.write(content)
    return temp_file.name
|
||||
|
||||
def cleanup_temp_file(self, file_path):
    """Clean up temporary file

    Best-effort removal: a file that is already gone, or cannot be removed,
    is ignored so that cleanup never masks the outcome of the test itself.
    """
    try:
        os.unlink(file_path)
    except OSError:
        # Only filesystem errors are expected here; a bare ``except`` would
        # also swallow KeyboardInterrupt / SystemExit and hide real bugs.
        pass
|
||||
|
||||
# End-to-end Pipeline Tests
|
||||
@pytest.mark.asyncio
async def test_csv_to_trustgraph_pipeline(self):
    """Test complete CSV to TrustGraph pipeline"""
    # Materialise the CSV sample and its descriptor as files on disk
    csv_path = self.create_temp_file(self.test_csv_data, '.csv')
    descriptor_path = self.create_temp_file(json.dumps(self.test_descriptor), '.json')

    try:
        # Dry run only
        outcome = load_structured_data(
            api_url=self.api_url,
            input_file=csv_path,
            descriptor_file=descriptor_path,
            dry_run=True,
            flow='obj-ex'
        )

        # A dry run is expected to finish without error and return None
        assert outcome is None  # dry_run returns None

    finally:
        for path in (csv_path, descriptor_path):
            self.cleanup_temp_file(path)
|
||||
|
||||
@pytest.mark.asyncio
async def test_xml_to_trustgraph_pipeline(self):
    """Test complete XML to TrustGraph pipeline"""
    # Reuse the base descriptor, replacing only its "format" section
    xml_format = {
        "type": "xml",
        "encoding": "utf-8",
        "options": {
            "record_path": "/ROOT/data/record",
            "field_attribute": "name"
        }
    }
    xml_descriptor = dict(self.test_descriptor, format=xml_format)

    xml_path = self.create_temp_file(self.test_xml_data, '.xml')
    descriptor_path = self.create_temp_file(json.dumps(xml_descriptor), '.json')

    try:
        # Dry run only
        outcome = load_structured_data(
            api_url=self.api_url,
            input_file=xml_path,
            descriptor_file=descriptor_path,
            dry_run=True,
            flow='obj-ex'
        )

        assert outcome is None  # dry_run returns None

    finally:
        for path in (xml_path, descriptor_path):
            self.cleanup_temp_file(path)
|
||||
|
||||
@pytest.mark.asyncio
async def test_json_to_trustgraph_pipeline(self):
    """Run the JSON fixture through the loader in dry-run mode."""
    # Shared descriptor with the format section replaced by a JSON one.
    json_descriptor = dict(
        self.test_descriptor,
        format={"type": "json", "encoding": "utf-8"},
    )

    data_path = self.create_temp_file(json.dumps(self.test_json_data), '.json')
    descriptor_path = self.create_temp_file(json.dumps(json_descriptor), '.json')

    try:
        outcome = load_structured_data(
            api_url=self.api_url,
            input_file=data_path,
            descriptor_file=descriptor_path,
            dry_run=True,
            flow='obj-ex',
        )

        # dry_run returns None
        assert outcome is None
    finally:
        for temp_path in (data_path, descriptor_path):
            self.cleanup_temp_file(temp_path)
|
||||
|
||||
# Batching Integration Tests
|
||||
@pytest.mark.asyncio
async def test_large_dataset_batching(self):
    """Dry-run a generated 1000-row CSV and bound the elapsed time."""
    header = "name,email,age,country,status\n"
    rows = (
        f"User{i},user{i}@example.com,{25+i%40},US,active\n"
        for i in range(1000)
    )
    large_csv_data = header + "".join(rows)

    csv_path = self.create_temp_file(large_csv_data, '.csv')
    descriptor_path = self.create_temp_file(
        json.dumps(self.test_descriptor), '.json'
    )

    try:
        started = time.time()
        outcome = load_structured_data(
            api_url=self.api_url,
            input_file=csv_path,
            descriptor_file=descriptor_path,
            dry_run=True,
            flow='obj-ex',
        )
        elapsed = time.time() - started

        # 1000 records should be handled in under 30 seconds.
        assert elapsed < 30
        # dry_run returns None
        assert outcome is None
    finally:
        for temp_path in (csv_path, descriptor_path):
            self.cleanup_temp_file(temp_path)
|
||||
|
||||
@pytest.mark.asyncio
async def test_batch_size_performance(self):
    """Dry-run a 100-row CSV once per batch size and bound each run's time.

    Each iteration writes its own descriptor whose
    ``output.options.batch_size`` is set to the batch size under test, so
    the runs actually differ (previously the same descriptor was reused
    for every iteration and the loop variable was ignored).
    """
    rows = (
        f"User{i},user{i}@example.com,{25+i%40},US,active\n"
        for i in range(100)
    )
    test_csv_data = "name,email,age,country,status\n" + "".join(rows)

    input_file = self.create_temp_file(test_csv_data, '.csv')
    base_output = self.test_descriptor["output"]

    try:
        batch_sizes = [1, 10, 25, 50, 100]
        processing_times = {}

        for batch_size in batch_sizes:
            # Override only the batch size; everything else is inherited
            # from the shared fixture descriptor.
            descriptor = {
                **self.test_descriptor,
                "output": {
                    **base_output,
                    "options": {
                        **base_output["options"],
                        "batch_size": batch_size,
                    },
                },
            }
            descriptor_file = self.create_temp_file(
                json.dumps(descriptor), '.json'
            )

            try:
                start_time = time.time()
                result = load_structured_data(
                    api_url=self.api_url,
                    input_file=input_file,
                    descriptor_file=descriptor_file,
                    dry_run=True,
                    flow='obj-ex'
                )
                processing_times[batch_size] = time.time() - start_time
            finally:
                self.cleanup_temp_file(descriptor_file)

            assert result is None  # dry_run returns None

        # All batch sizes should complete reasonably quickly
        for batch_size, time_taken in processing_times.items():
            assert time_taken < 10, f"Batch size {batch_size} took {time_taken}s"

    finally:
        self.cleanup_temp_file(input_file)
|
||||
|
||||
# Parse-Only Mode Tests
|
||||
@pytest.mark.asyncio
async def test_parse_only_mode(self):
    """Run the loader with parse_only=True and inspect the JSON it writes."""
    csv_path = self.create_temp_file(self.test_csv_data, '.csv')
    descriptor_path = self.create_temp_file(
        json.dumps(self.test_descriptor), '.json'
    )
    # Reserve an output path; the handle is closed immediately so the
    # loader can write to the file by name.
    with tempfile.NamedTemporaryFile(
        mode='w', suffix='.json', delete=False
    ) as reserved:
        output_path = reserved.name

    try:
        load_structured_data(
            api_url=self.api_url,
            input_file=csv_path,
            descriptor_file=descriptor_path,
            parse_only=True,
            output_file=output_path,
        )

        # The output file must exist and hold the parsed records.
        assert os.path.exists(output_path)
        with open(output_path, 'r') as stream:
            parsed_data = json.load(stream)

        assert isinstance(parsed_data, list)
        assert len(parsed_data) == 5  # Should have 5 records
        # Field names may be transformed, so only require a non-empty record.
        assert len(parsed_data[0]) > 0
    finally:
        for temp_path in (csv_path, descriptor_path, output_path):
            self.cleanup_temp_file(temp_path)
|
||||
|
||||
# Schema Suggestion Integration Tests
|
||||
def test_schema_suggestion_integration(self):
    """Placeholder for the schema-suggestion check; always skipped."""
    skip_reason = "Requires running TrustGraph API at localhost:8088"
    pytest.skip(skip_reason)
|
||||
|
||||
# Descriptor Generation Integration Tests
|
||||
def test_descriptor_generation_integration(self):
    """Placeholder for the descriptor-generation check; always skipped."""
    skip_reason = "Requires running TrustGraph API at localhost:8088"
    pytest.skip(skip_reason)
|
||||
|
||||
# Error Handling Integration Tests
|
||||
@pytest.mark.asyncio
async def test_malformed_data_handling(self):
    """Dry-run a CSV with a short row and a non-numeric age value."""
    malformed_csv = """name,email,age
John Smith,john@email.com,35
Jane Doe,jane@email.com # Missing age field
Bob Johnson,bob@company.org,not_a_number"""

    csv_path = self.create_temp_file(malformed_csv, '.csv')
    descriptor_path = self.create_temp_file(
        json.dumps(self.test_descriptor), '.json'
    )

    try:
        outcome = load_structured_data(
            api_url=self.api_url,
            input_file=csv_path,
            descriptor_file=descriptor_path,
            dry_run=True,
        )

        # The run is expected to complete despite the malformed records;
        # dry_run returns None.
        assert outcome is None
    finally:
        for temp_path in (csv_path, descriptor_path):
            self.cleanup_temp_file(temp_path)
|
||||
|
||||
# WebSocket Connection Tests
|
||||
@pytest.mark.asyncio
async def test_websocket_connection_handling(self):
    """Expect an exception when schema suggestion targets an invalid API URL.

    Only the input file is needed: the call passes no descriptor, so the
    previously created (and unused) descriptor temp file has been dropped,
    along with the unused ``result`` binding.
    """
    input_file = self.create_temp_file(self.test_csv_data, '.csv')

    try:
        # Test with invalid API URL (should fail gracefully)
        with pytest.raises(Exception):  # Connection error expected
            load_structured_data(
                api_url="http://invalid-url:9999",
                input_file=input_file,
                # suggest_schema mode is used to trigger the API
                # connection and propagate errors.
                suggest_schema=True,
                flow='obj-ex'
            )
    finally:
        self.cleanup_temp_file(input_file)
|
||||
|
||||
# Flow Parameter Tests
|
||||
@pytest.mark.asyncio
async def test_flow_parameter_integration(self):
    """Dry-run the CSV fixture once for each of several flow names."""
    csv_path = self.create_temp_file(self.test_csv_data, '.csv')
    descriptor_path = self.create_temp_file(
        json.dumps(self.test_descriptor), '.json'
    )

    try:
        for flow_name in ('default', 'obj-ex', 'custom-flow'):
            outcome = load_structured_data(
                api_url=self.api_url,
                input_file=csv_path,
                descriptor_file=descriptor_path,
                dry_run=True,
                flow=flow_name,
            )

            # dry_run returns None
            assert outcome is None
    finally:
        for temp_path in (csv_path, descriptor_path):
            self.cleanup_temp_file(temp_path)
|
||||
|
||||
# Mixed Format Tests
|
||||
@pytest.mark.asyncio
async def test_encoding_variations(self):
    """Dry-run the CSV fixture with a byte-order mark prepended."""
    # U+FEFF in front of the data simulates "UTF-8 with BOM" input.
    bom_prefixed = f"\ufeff{self.test_csv_data}"

    csv_path = self.create_temp_file(bom_prefixed, '.csv')
    descriptor_path = self.create_temp_file(
        json.dumps(self.test_descriptor), '.json'
    )

    try:
        outcome = load_structured_data(
            api_url=self.api_url,
            input_file=csv_path,
            descriptor_file=descriptor_path,
            dry_run=True,
        )

        # The BOM should be handled and the dry run should return None.
        assert outcome is None
    finally:
        for temp_path in (csv_path, descriptor_path):
            self.cleanup_temp_file(temp_path)
|
||||
467
tests/integration/test_load_structured_data_websocket.py
Normal file
467
tests/integration/test_load_structured_data_websocket.py
Normal file
|
|
@ -0,0 +1,467 @@
|
|||
"""
|
||||
WebSocket-specific integration tests for tg-load-structured-data.
|
||||
Tests WebSocket connection handling, message formats, and batching behavior.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import asyncio
|
||||
import json
|
||||
import tempfile
|
||||
import os
|
||||
from unittest.mock import Mock, patch, AsyncMock, MagicMock
|
||||
import websockets
|
||||
from websockets.exceptions import ConnectionClosedError, InvalidHandshake
|
||||
|
||||
from trustgraph.cli.load_structured_data import load_structured_data
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
class TestLoadStructuredDataWebSocket:
|
||||
"""WebSocket-specific integration tests"""
|
||||
|
||||
def setup_method(self):
    """Populate the URLs, sample CSV text and descriptor shared by the tests."""
    self.api_url = "http://localhost:8088"
    self.ws_url = "ws://localhost:8088"

    self.test_csv_data = """name,email,age,country
John Smith,john@email.com,35,US
Jane Doe,jane@email.com,28,CA
Bob Johnson,bob@company.org,42,UK
Alice Brown,alice@email.com,31,AU
Charlie Davis,charlie@email.com,39,DE"""

    def same_name_mapping(field, transform):
        # Each column maps onto a target of the same name with one transform.
        return {
            "source_field": field,
            "target_field": field,
            "transforms": [{"type": transform}],
        }

    csv_format = {
        "type": "csv",
        "encoding": "utf-8",
        "options": {"header": True, "delimiter": ","},
    }
    output_section = {
        "format": "trustgraph-objects",
        "schema_name": "test_customer",
        "options": {"confidence": 0.9, "batch_size": 2},
    }

    self.test_descriptor = {
        "version": "1.0",
        "format": csv_format,
        "mappings": [
            same_name_mapping("name", "trim"),
            same_name_mapping("email", "lower"),
            same_name_mapping("age", "to_int"),
            same_name_mapping("country", "upper"),
        ],
        "output": output_section,
    }
|
||||
|
||||
def create_temp_file(self, content, suffix='.txt'):
    """Write ``content`` to a new temporary file and return its path.

    Args:
        content: Text to store in the file.
        suffix: Filename suffix (extension) for the temporary file.

    Returns:
        The path of the created file. The file is not deleted automatically
        (``delete=False``); callers remove it with ``cleanup_temp_file``.
    """
    # The context manager closes the handle even if write() raises, and the
    # explicit UTF-8 encoding matches the "utf-8" encoding declared in the
    # test descriptor regardless of the platform default.
    with tempfile.NamedTemporaryFile(
        mode='w', suffix=suffix, delete=False, encoding='utf-8'
    ) as temp_file:
        temp_file.write(content)
    return temp_file.name
|
||||
|
||||
def cleanup_temp_file(self, file_path):
    """Delete a temporary file on a best-effort basis.

    Args:
        file_path: Path of the file to remove. Falsy values are ignored.
    """
    if not file_path:
        return
    try:
        os.unlink(file_path)
    except OSError:
        # Best-effort cleanup: a file that is already gone (or cannot be
        # removed) must not fail the test. Only OSError is suppressed so
        # that KeyboardInterrupt/SystemExit still propagate.
        pass
|
||||
|
||||
@pytest.mark.asyncio
async def test_websocket_message_format(self):
    """Check the shape of any WebSocket messages the loader sends.

    The WebSocket client is patched, so no real server is needed. The
    previous version also started a real ``websockets`` server on the fixed
    port 8089 that nothing ever connected to; it has been removed because
    it could only make the test fail when the port was already in use.

    The loader is invoked with ``dry_run=True``; the per-message checks
    below apply to whatever is captured by the mocked ``send`` and are
    skipped when nothing is sent.
    """
    input_file = self.create_temp_file(self.test_csv_data, '.csv')
    descriptor_file = self.create_temp_file(
        json.dumps(self.test_descriptor), '.json'
    )

    try:
        with patch('websockets.asyncio.client.connect') as mock_connect:
            mock_ws = AsyncMock()
            mock_connect.return_value.__aenter__.return_value = mock_ws

            # Capture messages sent
            sent_messages = []
            mock_ws.send = AsyncMock(
                side_effect=lambda msg: sent_messages.append(json.loads(msg))
            )

            result = load_structured_data(
                api_url="http://localhost:8089",
                input_file=input_file,
                descriptor_file=descriptor_file,
                flow='obj-ex',
                dry_run=True
            )

            # Dry run mode completes without errors
            assert result is None

            for message in sent_messages:
                # Check required fields
                for key in ("metadata", "schema_name", "values",
                            "confidence", "source_span"):
                    assert key in message

                # Check metadata structure
                metadata = message["metadata"]
                for key in ("id", "metadata", "user", "collection"):
                    assert key in metadata

                # Check batched values format
                values = message["values"]
                assert isinstance(values, list), "Values should be a list (batched)"
                assert len(values) <= 2, "Batch size should be respected"

                # Check each object in batch
                for obj in values:
                    assert isinstance(obj, dict)
                    for key in ("name", "email", "age", "country"):
                        assert key in obj

                    # Check transformations were applied
                    assert obj["email"].islower(), "Email should be lowercase"
                    assert obj["country"].isupper(), "Country should be uppercase"

    finally:
        self.cleanup_temp_file(input_file)
        self.cleanup_temp_file(descriptor_file)
|
||||
|
||||
@pytest.mark.asyncio
async def test_websocket_connection_retry(self):
    """Dry-run against an address where no server is expected to listen."""
    csv_path = self.create_temp_file(self.test_csv_data, '.csv')
    descriptor_path = self.create_temp_file(
        json.dumps(self.test_descriptor), '.json'
    )

    try:
        # Non-existent server; with dry_run no actual connection is made.
        outcome = load_structured_data(
            api_url="http://localhost:9999",
            input_file=csv_path,
            descriptor_file=descriptor_path,
            flow='obj-ex',
            dry_run=True,
        )

        # The dry run completes regardless of server availability.
        assert outcome is None
    finally:
        for temp_path in (csv_path, descriptor_path):
            self.cleanup_temp_file(temp_path)
|
||||
|
||||
@pytest.mark.asyncio
async def test_websocket_large_message_handling(self):
    """Check size limits on batched messages for a 100-row CSV.

    The batch-size override is placed under ``output.options``, which is
    where ``setup_method`` defines ``batch_size``; previously it was added
    directly under ``output``, leaving the fixture's ``options.batch_size``
    of 2 untouched.
    """
    # Generate larger dataset
    rows = (
        f"User{i},user{i}@example.com,{25+i%40},US\n" for i in range(100)
    )
    large_csv_data = "name,email,age,country\n" + "".join(rows)

    # Create descriptor with larger batch size
    base_output = self.test_descriptor["output"]
    large_batch_descriptor = {
        **self.test_descriptor,
        "output": {
            **base_output,
            "options": {
                **base_output["options"],
                "batch_size": 50  # Large batch size
            }
        }
    }

    input_file = self.create_temp_file(large_csv_data, '.csv')
    descriptor_file = self.create_temp_file(
        json.dumps(large_batch_descriptor), '.json'
    )

    try:
        with patch('websockets.asyncio.client.connect') as mock_connect:
            mock_ws = AsyncMock()
            mock_connect.return_value.__aenter__.return_value = mock_ws

            sent_messages = []
            mock_ws.send = AsyncMock(
                side_effect=lambda msg: sent_messages.append(json.loads(msg))
            )

            result = load_structured_data(
                api_url=self.api_url,
                input_file=input_file,
                descriptor_file=descriptor_file,
                flow='obj-ex',
                dry_run=True
            )

            # Dry run completes without errors
            assert result is None

            # Check message sizes (applies to whatever send() captured)
            for message in sent_messages:
                values = message["values"]
                assert len(values) <= 50

                # Check message is not too large (rough size check)
                message_size = len(json.dumps(message))
                assert message_size < 1024 * 1024  # Less than 1MB per message

    finally:
        self.cleanup_temp_file(input_file)
        self.cleanup_temp_file(descriptor_file)
|
||||
|
||||
@pytest.mark.asyncio
async def test_websocket_connection_interruption(self):
    """Dry-run with a mocked socket whose send fails after the first call."""
    csv_path = self.create_temp_file(self.test_csv_data, '.csv')
    descriptor_path = self.create_temp_file(
        json.dumps(self.test_descriptor), '.json'
    )

    try:
        with patch('websockets.asyncio.client.connect') as connect_patch:
            fake_socket = AsyncMock()
            connect_patch.return_value.__aenter__.return_value = fake_socket

            sends_seen = 0

            def send_with_failure(msg):
                # First send succeeds; every later one simulates the
                # connection being closed mid-transfer.
                nonlocal sends_seen
                sends_seen += 1
                if sends_seen <= 1:
                    return AsyncMock()
                raise ConnectionClosedError(None, None)

            fake_socket.send.side_effect = send_with_failure

            # In dry run mode no actual connection is made.
            outcome = load_structured_data(
                api_url=self.api_url,
                input_file=csv_path,
                descriptor_file=descriptor_path,
                flow='obj-ex',
                dry_run=True,
            )

            # Dry run completes without errors
            assert outcome is None
    finally:
        for temp_path in (csv_path, descriptor_path):
            self.cleanup_temp_file(temp_path)
|
||||
|
||||
@pytest.mark.asyncio
async def test_websocket_url_conversion(self):
    """Dry-run once with an HTTP API URL and once with an HTTPS one."""
    csv_path = self.create_temp_file(self.test_csv_data, '.csv')
    descriptor_path = self.create_temp_file(
        json.dumps(self.test_descriptor), '.json'
    )

    # (api_url, flow) pairs: plain HTTP first, then HTTPS.
    url_cases = [
        ("http://localhost:8088", 'obj-ex'),
        ("https://example.com:8088", 'test-flow'),
    ]

    try:
        with patch('websockets.asyncio.client.connect') as connect_patch:
            fake_socket = AsyncMock()
            connect_patch.return_value.__aenter__.return_value = fake_socket
            fake_socket.send = AsyncMock()

            for position, (api_url, flow_name) in enumerate(url_cases):
                if position > 0:
                    # Clear recorded calls between the two runs.
                    connect_patch.reset_mock()

                outcome = load_structured_data(
                    api_url=api_url,
                    input_file=csv_path,
                    descriptor_file=descriptor_path,
                    flow=flow_name,
                    dry_run=True,
                )

                # Dry run mode - no WebSocket connection made
                assert outcome is None
    finally:
        for temp_path in (csv_path, descriptor_path):
            self.cleanup_temp_file(temp_path)
|
||||
|
||||
@pytest.mark.asyncio
async def test_websocket_batch_ordering(self):
    """Dry-run an ordered 10-row CSV with a batch size of 3.

    The batch-size override is placed under ``output.options``, which is
    where ``setup_method`` defines ``batch_size``; previously it was added
    directly under ``output``, leaving the fixture's ``options.batch_size``
    of 2 untouched.
    """
    # Create ordered test data
    ordered_csv_data = "name,id\n" + "".join(
        f"User{i:02d},{i}\n" for i in range(10)
    )

    input_file = self.create_temp_file(ordered_csv_data, '.csv')

    # Create descriptor for this test
    base_output = self.test_descriptor["output"]
    ordered_descriptor = {
        **self.test_descriptor,
        "mappings": [
            {"source_field": "name", "target_field": "name", "transforms": []},
            {"source_field": "id", "target_field": "id", "transforms": [{"type": "to_int"}]}
        ],
        "output": {
            **base_output,
            "options": {**base_output["options"], "batch_size": 3}
        }
    }
    descriptor_file = self.create_temp_file(
        json.dumps(ordered_descriptor), '.json'
    )

    try:
        with patch('websockets.asyncio.client.connect') as mock_connect:
            mock_ws = AsyncMock()
            mock_connect.return_value.__aenter__.return_value = mock_ws

            sent_messages = []
            mock_ws.send = AsyncMock(
                side_effect=lambda msg: sent_messages.append(json.loads(msg))
            )

            result = load_structured_data(
                api_url=self.api_url,
                input_file=input_file,
                descriptor_file=descriptor_file,
                flow='obj-ex',
                dry_run=True
            )

            # Dry run completes without errors
            assert result is None

            # In dry run mode, no messages are sent, but processing order
            # is maintained internally

    finally:
        self.cleanup_temp_file(input_file)
        self.cleanup_temp_file(descriptor_file)
|
||||
|
||||
@pytest.mark.asyncio
async def test_websocket_authentication_headers(self):
    """Dry-run with a mocked WebSocket client; auth headers are not yet checked."""
    csv_path = self.create_temp_file(self.test_csv_data, '.csv')
    descriptor_path = self.create_temp_file(
        json.dumps(self.test_descriptor), '.json'
    )

    try:
        with patch('websockets.asyncio.client.connect') as connect_patch:
            fake_socket = AsyncMock()
            connect_patch.return_value.__aenter__.return_value = fake_socket
            fake_socket.send = AsyncMock()

            outcome = load_structured_data(
                api_url=self.api_url,
                input_file=csv_path,
                descriptor_file=descriptor_path,
                flow='obj-ex',
                dry_run=True,
            )

            # Dry run mode - no WebSocket connection made
            assert outcome is None

            # In real implementation, could check for auth headers
            # For now, just verify the connection was attempted
    finally:
        for temp_path in (csv_path, descriptor_path):
            self.cleanup_temp_file(temp_path)
|
||||
|
||||
@pytest.mark.asyncio
async def test_websocket_empty_batch_handling(self):
    """Dry-run a CSV containing an invalid row and reject any empty batch."""
    # One row with blank/invalid values followed by one valid row.
    invalid_csv_data = """name,email,age,country
,invalid@email,not_a_number,
Valid User,valid@email.com,25,US"""

    csv_path = self.create_temp_file(invalid_csv_data, '.csv')
    descriptor_path = self.create_temp_file(
        json.dumps(self.test_descriptor), '.json'
    )

    try:
        with patch('websockets.asyncio.client.connect') as connect_patch:
            fake_socket = AsyncMock()
            connect_patch.return_value.__aenter__.return_value = fake_socket

            captured = []

            def record_message(msg):
                captured.append(json.loads(msg))

            fake_socket.send = AsyncMock(side_effect=record_message)

            outcome = load_structured_data(
                api_url=self.api_url,
                input_file=csv_path,
                descriptor_file=descriptor_path,
                flow='obj-ex',
                dry_run=True,
            )

            # Dry run completes without errors
            assert outcome is None

            # Any captured message must carry at least one value.
            for message in captured:
                batch = message["values"]
                assert len(batch) > 0, "Should not send empty batches"
    finally:
        for temp_path in (csv_path, descriptor_path):
            self.cleanup_temp_file(temp_path)
|
||||
|
||||
@pytest.mark.asyncio
async def test_websocket_progress_reporting(self):
    """Dry-run a 50-row CSV in verbose mode with logging and sends mocked."""
    progress_csv_data = "name,email,age\n" + "".join(
        f"User{i},user{i}@example.com,{25+i}\n" for i in range(50)
    )

    csv_path = self.create_temp_file(progress_csv_data, '.csv')
    descriptor_path = self.create_temp_file(
        json.dumps(self.test_descriptor), '.json'
    )

    try:
        with patch('websockets.asyncio.client.connect') as connect_patch:
            fake_socket = AsyncMock()
            connect_patch.return_value.__aenter__.return_value = fake_socket

            sends_seen = 0

            def count_sends(msg):
                # Tally every send made through the mocked socket.
                nonlocal sends_seen
                sends_seen += 1
                return AsyncMock()

            fake_socket.send.side_effect = count_sends

            # Replace logging.getLogger so any logger requested during the
            # run is a Mock.
            with patch('logging.getLogger') as get_logger_patch:
                fake_log = Mock()
                get_logger_patch.return_value = fake_log

                outcome = load_structured_data(
                    api_url=self.api_url,
                    input_file=csv_path,
                    descriptor_file=descriptor_path,
                    flow='obj-ex',
                    verbose=True,
                    dry_run=True,
                )

                # Dry run completes without errors
                assert outcome is None
    finally:
        for temp_path in (csv_path, descriptor_path):
            self.cleanup_temp_file(temp_path)
|
||||
570
tests/integration/test_nlp_query_integration.py
Normal file
570
tests/integration/test_nlp_query_integration.py
Normal file
|
|
@ -0,0 +1,570 @@
|
|||
"""
|
||||
Integration tests for NLP Query Service
|
||||
|
||||
These tests verify the end-to-end functionality of the NLP query service,
|
||||
testing service coordination, prompt service integration, and schema processing.
|
||||
Following the TEST_STRATEGY.md approach for integration testing.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import json
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
from trustgraph.schema import (
|
||||
QuestionToStructuredQueryRequest, QuestionToStructuredQueryResponse,
|
||||
PromptRequest, PromptResponse, Error, RowSchema, Field as SchemaField
|
||||
)
|
||||
from trustgraph.retrieval.nlp_query.service import Processor
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
class TestNLPQueryServiceIntegration:
|
||||
"""Integration tests for NLP query service coordination"""
|
||||
|
||||
@pytest.fixture
def sample_schemas(self):
    """Provide customers, orders and products row schemas keyed by name."""
    customers = RowSchema(
        name="customers",
        description="Customer data with contact information",
        fields=[
            SchemaField(name="id", type="string", primary=True),
            SchemaField(name="name", type="string"),
            SchemaField(name="email", type="string"),
            SchemaField(name="state", type="string"),
            SchemaField(name="phone", type="string"),
        ],
    )

    orders = RowSchema(
        name="orders",
        description="Customer order transactions",
        fields=[
            SchemaField(name="order_id", type="string", primary=True),
            SchemaField(name="customer_id", type="string"),
            SchemaField(name="total", type="float"),
            SchemaField(name="status", type="string"),
            SchemaField(name="order_date", type="datetime"),
        ],
    )

    products = RowSchema(
        name="products",
        description="Product catalog information",
        fields=[
            SchemaField(name="product_id", type="string", primary=True),
            SchemaField(name="name", type="string"),
            SchemaField(name="category", type="string"),
            SchemaField(name="price", type="float"),
            SchemaField(name="in_stock", type="boolean"),
        ],
    )

    return {"customers": customers, "orders": orders, "products": products}
|
||||
|
||||
@pytest.fixture
def integration_processor(self, sample_schemas):
    """Build a Processor with mocked infrastructure and the sample schemas."""
    processor_kwargs = {
        "taskgroup": MagicMock(),
        "pulsar_client": AsyncMock(),
        "config_type": "schema",
        "schema_selection_template": "schema-selection-v1",
        "graphql_generation_template": "graphql-generation-v1",
    }
    proc = Processor(**processor_kwargs)

    # Install the schemas directly and replace the client with a mock.
    proc.schemas = sample_schemas
    proc.client = MagicMock()

    return proc
|
||||
|
||||
@pytest.mark.asyncio
async def test_end_to_end_nlp_query_processing(self, integration_processor):
    """Drive on_message through both prompt phases and check the response."""
    # Arrange - incoming question wrapped in a mocked message
    request = QuestionToStructuredQueryRequest(
        question="Show me customers from California who have placed orders over $500",
        max_results=50
    )

    msg = MagicMock()
    msg.value.return_value = request
    msg.properties.return_value = {"id": "integration-test-001"}

    consumer = MagicMock()
    flow_response = AsyncMock()

    # Phase 1 reply: the schemas selected for the question
    schema_selection_reply = PromptResponse(
        text=json.dumps(["customers", "orders"]),
        error=None
    )

    # Phase 2 reply: the generated GraphQL with variables and confidence
    expected_graphql = """
query GetCaliforniaCustomersWithLargeOrders($min_total: Float!) {
customers(where: {state: {eq: "California"}}) {
id
name
email
state
orders(where: {total: {gt: $min_total}}) {
order_id
total
status
order_date
}
}
}
"""
    generation_payload = {
        "query": expected_graphql.strip(),
        "variables": {"min_total": "500.0"},
        "confidence": 0.92
    }
    graphql_generation_reply = PromptResponse(
        text=json.dumps(generation_payload),
        error=None
    )

    # The prompt service answers the two requests in order.
    prompt_service = AsyncMock()
    prompt_service.request = AsyncMock(
        side_effect=[schema_selection_reply, graphql_generation_reply]
    )

    def resolve_service(service_name):
        # Route flow lookups to the matching mock; anything else gets a
        # fresh AsyncMock.
        if service_name == "prompt-request":
            return prompt_service
        if service_name == "response":
            return flow_response
        return AsyncMock()

    flow = MagicMock(side_effect=resolve_service)

    # Act - Process the message
    await integration_processor.on_message(msg, consumer, flow)

    # Assert - both phases ran and exactly one response was sent
    assert prompt_service.request.call_count == 2
    flow_response.send.assert_called_once()

    sent_args, _sent_kwargs = flow_response.send.call_args
    response = sent_args[0]

    assert isinstance(response, QuestionToStructuredQueryResponse)
    assert response.error is None
    for fragment in ("customers", "orders", "California"):
        assert fragment in response.graphql_query
    assert response.detected_schemas == ["customers", "orders"]
    assert response.confidence == 0.92
    assert response.variables["min_total"] == "500.0"
|
||||
|
||||
@pytest.mark.asyncio
async def test_complex_multi_table_query_integration(self, integration_processor):
    """Process a question spanning products and orders and check the query."""
    # Arrange
    request = QuestionToStructuredQueryRequest(
        question="Find all electronic products under $100 that are in stock, along with any recent orders",
        max_results=25
    )

    msg = MagicMock()
    msg.value.return_value = request
    msg.properties.return_value = {"id": "multi-table-test"}

    consumer = MagicMock()
    flow_response = AsyncMock()

    # Canned prompt-service replies for the two phases
    schema_selection_reply = PromptResponse(
        text=json.dumps(["products", "orders"]),
        error=None
    )

    generation_payload = {
        "query": "query { products(where: {category: {eq: \"Electronics\"}, price: {lt: 100}, in_stock: {eq: true}}) { product_id name price orders { order_id total } } }",
        "variables": {},
        "confidence": 0.88
    }
    graphql_generation_reply = PromptResponse(
        text=json.dumps(generation_payload),
        error=None
    )

    prompt_service = AsyncMock()
    prompt_service.request = AsyncMock(
        side_effect=[schema_selection_reply, graphql_generation_reply]
    )

    def resolve_service(service_name):
        # Route flow lookups to the matching mock; anything else gets a
        # fresh AsyncMock.
        if service_name == "prompt-request":
            return prompt_service
        if service_name == "response":
            return flow_response
        return AsyncMock()

    flow = MagicMock(side_effect=resolve_service)

    # Act
    await integration_processor.on_message(msg, consumer, flow)

    # Assert
    sent_args, _sent_kwargs = flow_response.send.call_args
    response = sent_args[0]

    assert response.detected_schemas == ["products", "orders"]
    expected_fragments = (
        "Electronics",
        "price: {lt: 100}",
        "in_stock: {eq: true}",
    )
    for fragment in expected_fragments:
        assert fragment in response.graphql_query
|
||||
|
||||
@pytest.mark.asyncio
async def test_schema_configuration_integration(self, integration_processor):
    """Push a new schema config, then run a query that targets that schema.

    Checks that the "inventory" schema is present on the processor after
    `on_schema_config` and that the response sent on the "response" flow
    carries the mocked schema selection and GraphQL text.
    """
    # Arrange - New schema configuration
    new_schema_config = {
        "schema": {
            "inventory": json.dumps({
                "name": "inventory",
                "description": "Product inventory tracking",
                "fields": [
                    {"name": "sku", "type": "string", "primary_key": True},
                    {"name": "quantity", "type": "integer"},
                    {"name": "warehouse_location", "type": "string"}
                ]
            })
        }
    }

    # Act - Update configuration
    await integration_processor.on_schema_config(new_schema_config, "v2")

    # Arrange - Test query using new schema
    request = QuestionToStructuredQueryRequest(
        question="Show inventory levels for all products in warehouse A",
        max_results=100
    )

    msg = MagicMock()
    msg.value.return_value = request
    msg.properties.return_value = {"id": "schema-config-test"}

    consumer = MagicMock()
    flow = MagicMock()
    flow_response = AsyncMock()
    flow.return_value = flow_response

    # Mock responses that use the new schema
    phase1_response = PromptResponse(
        text=json.dumps(["inventory"]),
        error=None
    )

    phase2_response = PromptResponse(
        text=json.dumps({
            "query": "query { inventory(where: {warehouse_location: {eq: \"A\"}}) { sku quantity warehouse_location } }",
            "variables": {},
            "confidence": 0.85
        }),
        error=None
    )

    # Mock the flow context to return prompt service responses
    prompt_service = AsyncMock()
    prompt_service.request = AsyncMock(
        side_effect=[phase1_response, phase2_response]
    )
    flow.side_effect = lambda service_name: (
        prompt_service if service_name == "prompt-request" else
        flow_response if service_name == "response" else
        AsyncMock()
    )

    # Act
    await integration_processor.on_message(msg, consumer, flow)

    # Assert
    assert "inventory" in integration_processor.schemas
    # Guard the call_args access below: without a send, call_args is None.
    flow_response.send.assert_called_once()
    response_call = flow_response.send.call_args
    response = response_call[0][0]
    assert response.detected_schemas == ["inventory"]
    assert "inventory" in response.graphql_query
|
||||
|
||||
@pytest.mark.asyncio
async def test_prompt_service_error_recovery_integration(self, integration_processor):
    """Check that a prompt-service error is reported back as a query error.

    The mocked prompt service always returns a response carrying an Error;
    the processor is expected to send exactly one response whose error has
    type "nlp-query-error" and mentions "Prompt service error".
    """
    # Arrange: incoming message wrapping the request
    incoming = MagicMock()
    incoming.value.return_value = QuestionToStructuredQueryRequest(
        question="Show me customer data",
        max_results=10
    )
    incoming.properties.return_value = {"id": "error-recovery-test"}

    consumer = MagicMock()
    flow = MagicMock()
    flow_response = AsyncMock()
    flow.return_value = flow_response

    # Every prompt request yields the same failed phase-1 response
    failed_selection = PromptResponse(
        text="",
        error=Error(type="template-not-found", message="Schema selection template not available")
    )
    prompt_service = AsyncMock()
    prompt_service.request = AsyncMock(return_value=failed_selection)

    def resolve_service(service_name):
        # Route flow lookups to the matching mock; anything else gets a stub
        if service_name == "prompt-request":
            return prompt_service
        if service_name == "response":
            return flow_response
        return AsyncMock()

    flow.side_effect = resolve_service

    # Act
    await integration_processor.on_message(incoming, consumer, flow)

    # Assert - the error is handled and propagated in a single response
    flow_response.send.assert_called_once()
    sent_args = flow_response.send.call_args[0]
    response = sent_args[0]

    assert isinstance(response, QuestionToStructuredQueryResponse)
    assert response.error is not None
    assert response.error.type == "nlp-query-error"
    assert "Prompt service error" in response.error.message
|
||||
|
||||
@pytest.mark.asyncio
async def test_template_parameter_integration(self, sample_schemas):
    """Build a Processor with custom template names and run one message.

    Verifies the template names are stored on the processor and that the
    prompt service was called twice while handling the message.
    """
    # Processor configured with non-default template names
    custom_processor = Processor(
        taskgroup=MagicMock(),
        pulsar_client=AsyncMock(),
        config_type="schema",
        schema_selection_template="custom-schema-selector",
        graphql_generation_template="custom-graphql-generator"
    )
    custom_processor.schemas = sample_schemas
    custom_processor.client = MagicMock()

    # Incoming message
    msg = MagicMock()
    msg.value.return_value = QuestionToStructuredQueryRequest(
        question="Test query",
        max_results=5
    )
    msg.properties.return_value = {"id": "template-test"}

    consumer = MagicMock()
    flow = MagicMock()
    flow_response = AsyncMock()
    flow.return_value = flow_response

    # Canned prompt replies, returned in call order
    generated_query = {
        "query": "query { customers { id name } }",
        "variables": {},
        "confidence": 0.9
    }
    canned_replies = [
        PromptResponse(text=json.dumps(["customers"]), error=None),
        PromptResponse(text=json.dumps(generated_query), error=None),
    ]
    mock_prompt_service = AsyncMock()
    mock_prompt_service.request = AsyncMock(side_effect=canned_replies)

    def resolve_service(service_name):
        # Route flow lookups to the matching mock; anything else gets a stub
        if service_name == "prompt-request":
            return mock_prompt_service
        if service_name == "response":
            return flow_response
        return AsyncMock()

    flow.side_effect = resolve_service

    # Act
    await custom_processor.on_message(msg, consumer, flow)

    # Assert - custom template names are held by the processor
    # NOTE(review): this checks the stored attributes only, not which
    # template ids were placed in the prompt requests - confirm intent.
    assert custom_processor.schema_selection_template == "custom-schema-selector"
    assert custom_processor.graphql_generation_template == "custom-graphql-generator"

    # Both prompt calls were made
    assert mock_prompt_service.request.call_count == 2
|
||||
|
||||
@pytest.mark.asyncio
async def test_large_schema_set_integration(self, integration_processor):
    """Run the query flow with twenty extra schemas registered.

    Adds `table_00` .. `table_19` to the processor's schemas, then checks
    that the response sent on the "response" flow carries the two schema
    names and the GraphQL text returned by the mocked prompt service.
    """
    # Arrange - Add many schemas (one id field plus five string fields each)
    large_schema_set = {
        f"table_{i:02d}": RowSchema(
            name=f"table_{i:02d}",
            description=f"Test table {i} with sample data",
            fields=[
                SchemaField(name="id", type="string", primary=True)
            ] + [SchemaField(name=f"field_{j}", type="string") for j in range(5)]
        )
        for i in range(20)
    }

    integration_processor.schemas.update(large_schema_set)

    request = QuestionToStructuredQueryRequest(
        question="Show me data from table_05 and table_12",
        max_results=20
    )

    msg = MagicMock()
    msg.value.return_value = request
    msg.properties.return_value = {"id": "large-schema-test"}

    consumer = MagicMock()
    flow = MagicMock()
    flow_response = AsyncMock()
    flow.return_value = flow_response

    # Mock responses, consumed in call order by the prompt service mock
    phase1_response = PromptResponse(
        text=json.dumps(["table_05", "table_12"]),
        error=None
    )

    phase2_response = PromptResponse(
        text=json.dumps({
            "query": "query { table_05 { id field_0 } table_12 { id field_1 } }",
            "variables": {},
            "confidence": 0.87
        }),
        error=None
    )

    # Mock the flow context to return prompt service responses
    prompt_service = AsyncMock()
    prompt_service.request = AsyncMock(
        side_effect=[phase1_response, phase2_response]
    )
    flow.side_effect = lambda service_name: (
        prompt_service if service_name == "prompt-request" else
        flow_response if service_name == "response" else
        AsyncMock()
    )

    # Act
    await integration_processor.on_message(msg, consumer, flow)

    # Assert
    # Guard the call_args access below: without a send, call_args is None.
    flow_response.send.assert_called_once()
    response_call = flow_response.send.call_args
    response = response_call[0][0]

    assert response.detected_schemas == ["table_05", "table_12"]
    assert "table_05" in response.graphql_query
    assert "table_12" in response.graphql_query
|
||||
|
||||
@pytest.mark.asyncio
async def test_concurrent_request_handling_integration(self, integration_processor):
    """Process five messages concurrently, each with its own mocked flow.

    Every flow gets a private prompt service (two canned replies) and a
    private response producer, so the final assertions can check that each
    request made its two prompt calls and sent exactly one response.
    """
    import asyncio

    def make_router(prompt_service, flow_response):
        """Return a flow side_effect bound to one request's own mocks."""
        def route(service_name):
            if service_name == "prompt-request":
                return prompt_service
            if service_name == "response":
                return flow_response
            return AsyncMock()
        return route

    # Arrange - Multiple concurrent requests
    messages = []
    flows = []
    prompt_services = []

    for i in range(5):  # 5 concurrent requests
        request = QuestionToStructuredQueryRequest(
            question=f"Query {i}: Show me data",
            max_results=10
        )

        msg = MagicMock()
        msg.value.return_value = request
        msg.properties.return_value = {"id": f"concurrent-test-{i}"}

        flow = MagicMock()
        flow_response = AsyncMock()
        # Kept so the assertions can reach the producer via flow.return_value
        flow.return_value = flow_response

        phase1_response = PromptResponse(
            text=json.dumps(["customers"]),
            error=None
        )
        phase2_response = PromptResponse(
            text=json.dumps({
                "query": "query { customers { id name } }",
                "variables": {},
                "confidence": 0.9
            }),
            error=None
        )

        # Individual prompt service for this request
        prompt_service = AsyncMock()
        prompt_service.request = AsyncMock(
            side_effect=[phase1_response, phase2_response]
        )
        flow.side_effect = make_router(prompt_service, flow_response)

        messages.append(msg)
        flows.append(flow)
        prompt_services.append(prompt_service)

    # Act - Process all messages concurrently
    consumer = MagicMock()
    tasks = [
        integration_processor.on_message(msg, consumer, flow)
        for msg, flow in zip(messages, flows)
    ]
    await asyncio.gather(*tasks)

    # Assert - All requests should be processed
    total_calls = sum(ps.request.call_count for ps in prompt_services)
    assert total_calls == 10  # 2 calls per request (phase1 + phase2)
    for flow in flows:
        flow.return_value.send.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
async def test_performance_timing_integration(self, integration_processor):
    """Check that one message is handled in under a second with mocks.

    Elapsed time is measured around `on_message` only. The test also
    verifies that a single response without an error was sent.
    """
    import time

    # Arrange
    request = QuestionToStructuredQueryRequest(
        question="Performance test query",
        max_results=100
    )

    msg = MagicMock()
    msg.value.return_value = request
    msg.properties.return_value = {"id": "performance-test"}

    consumer = MagicMock()
    flow = MagicMock()
    flow_response = AsyncMock()
    flow.return_value = flow_response

    # Mock fast responses
    phase1_response = PromptResponse(text=json.dumps(["customers"]), error=None)
    phase2_response = PromptResponse(
        text=json.dumps({
            "query": "query { customers { id } }",
            "variables": {},
            "confidence": 0.9
        }),
        error=None
    )

    # Mock the flow context to return prompt service responses
    prompt_service = AsyncMock()
    prompt_service.request = AsyncMock(
        side_effect=[phase1_response, phase2_response]
    )
    flow.side_effect = lambda service_name: (
        prompt_service if service_name == "prompt-request" else
        flow_response if service_name == "response" else
        AsyncMock()
    )

    # Act
    # perf_counter is monotonic, so a wall-clock adjustment during the
    # test cannot produce a negative or inflated duration.
    start_time = time.perf_counter()

    await integration_processor.on_message(msg, consumer, flow)

    execution_time = time.perf_counter() - start_time

    # Assert
    assert execution_time < 1.0  # Should complete quickly with mocked services
    flow_response.send.assert_called_once()
    response_call = flow_response.send.call_args
    response = response_call[0][0]
    assert response.error is None
|
||||
|
|
@ -270,9 +270,9 @@ class TestObjectExtractionServiceIntegration:
|
|||
assert len(customer_calls) == 1
|
||||
customer_obj = customer_calls[0]
|
||||
|
||||
assert customer_obj.values["customer_id"] == "CUST001"
|
||||
assert customer_obj.values["name"] == "John Smith"
|
||||
assert customer_obj.values["email"] == "john.smith@email.com"
|
||||
assert customer_obj.values[0]["customer_id"] == "CUST001"
|
||||
assert customer_obj.values[0]["name"] == "John Smith"
|
||||
assert customer_obj.values[0]["email"] == "john.smith@email.com"
|
||||
assert customer_obj.confidence > 0.5
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -335,10 +335,10 @@ class TestObjectExtractionServiceIntegration:
|
|||
assert len(product_calls) == 1
|
||||
product_obj = product_calls[0]
|
||||
|
||||
assert product_obj.values["product_id"] == "PROD001"
|
||||
assert product_obj.values["name"] == "Gaming Laptop"
|
||||
assert product_obj.values["price"] == "1299.99"
|
||||
assert product_obj.values["category"] == "electronics"
|
||||
assert product_obj.values[0]["product_id"] == "PROD001"
|
||||
assert product_obj.values[0]["name"] == "Gaming Laptop"
|
||||
assert product_obj.values[0]["price"] == "1299.99"
|
||||
assert product_obj.values[0]["category"] == "electronics"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_concurrent_extraction_integration(self, integration_config, mock_integrated_flow):
|
||||
|
|
|
|||
|
|
@ -95,12 +95,12 @@ class TestObjectsCassandraIntegration:
|
|||
metadata=[]
|
||||
),
|
||||
schema_name="customer_records",
|
||||
values={
|
||||
values=[{
|
||||
"customer_id": "CUST001",
|
||||
"name": "John Doe",
|
||||
"email": "john@example.com",
|
||||
"age": "30"
|
||||
},
|
||||
}],
|
||||
confidence=0.95,
|
||||
source_span="Customer: John Doe..."
|
||||
)
|
||||
|
|
@ -183,7 +183,7 @@ class TestObjectsCassandraIntegration:
|
|||
product_obj = ExtractedObject(
|
||||
metadata=Metadata(id="p1", user="shop", collection="catalog", metadata=[]),
|
||||
schema_name="products",
|
||||
values={"product_id": "P001", "name": "Widget", "price": "19.99"},
|
||||
values=[{"product_id": "P001", "name": "Widget", "price": "19.99"}],
|
||||
confidence=0.9,
|
||||
source_span="Product..."
|
||||
)
|
||||
|
|
@ -191,7 +191,7 @@ class TestObjectsCassandraIntegration:
|
|||
order_obj = ExtractedObject(
|
||||
metadata=Metadata(id="o1", user="shop", collection="sales", metadata=[]),
|
||||
schema_name="orders",
|
||||
values={"order_id": "O001", "customer_id": "C001", "total": "59.97"},
|
||||
values=[{"order_id": "O001", "customer_id": "C001", "total": "59.97"}],
|
||||
confidence=0.85,
|
||||
source_span="Order..."
|
||||
)
|
||||
|
|
@ -229,7 +229,7 @@ class TestObjectsCassandraIntegration:
|
|||
test_obj = ExtractedObject(
|
||||
metadata=Metadata(id="t1", user="test", collection="test", metadata=[]),
|
||||
schema_name="test_schema",
|
||||
values={"id": "123"}, # missing required_field
|
||||
values=[{"id": "123"}], # missing required_field
|
||||
confidence=0.8,
|
||||
source_span="Test"
|
||||
)
|
||||
|
|
@ -265,7 +265,7 @@ class TestObjectsCassandraIntegration:
|
|||
test_obj = ExtractedObject(
|
||||
metadata=Metadata(id="e1", user="logger", collection="app_events", metadata=[]),
|
||||
schema_name="events",
|
||||
values={"event_type": "login", "timestamp": "2024-01-01T10:00:00Z"},
|
||||
values=[{"event_type": "login", "timestamp": "2024-01-01T10:00:00Z"}],
|
||||
confidence=1.0,
|
||||
source_span="Event"
|
||||
)
|
||||
|
|
@ -294,8 +294,8 @@ class TestObjectsCassandraIntegration:
|
|||
async def test_authentication_handling(self, processor_with_mocks):
|
||||
"""Test Cassandra authentication"""
|
||||
processor, mock_cluster, mock_session = processor_with_mocks
|
||||
processor.graph_username = "cassandra_user"
|
||||
processor.graph_password = "cassandra_pass"
|
||||
processor.cassandra_username = "cassandra_user"
|
||||
processor.cassandra_password = "cassandra_pass"
|
||||
|
||||
with patch('trustgraph.storage.objects.cassandra.write.Cluster') as mock_cluster_class:
|
||||
with patch('trustgraph.storage.objects.cassandra.write.PlainTextAuthProvider') as mock_auth:
|
||||
|
|
@ -334,7 +334,7 @@ class TestObjectsCassandraIntegration:
|
|||
test_obj = ExtractedObject(
|
||||
metadata=Metadata(id="t1", user="test", collection="test", metadata=[]),
|
||||
schema_name="test",
|
||||
values={"id": "123"},
|
||||
values=[{"id": "123"}],
|
||||
confidence=0.9,
|
||||
source_span="Test"
|
||||
)
|
||||
|
|
@ -364,7 +364,7 @@ class TestObjectsCassandraIntegration:
|
|||
obj = ExtractedObject(
|
||||
metadata=Metadata(id=f"{coll}-1", user="analytics", collection=coll, metadata=[]),
|
||||
schema_name="data",
|
||||
values={"id": f"ID-{coll}"},
|
||||
values=[{"id": f"ID-{coll}"}],
|
||||
confidence=0.9,
|
||||
source_span="Data"
|
||||
)
|
||||
|
|
@ -381,4 +381,170 @@ class TestObjectsCassandraIntegration:
|
|||
# Check each insert has the correct collection
|
||||
for i, call in enumerate(insert_calls):
|
||||
values = call[0][1]
|
||||
assert collections[i] in values
|
||||
assert collections[i] in values
|
||||
|
||||
@pytest.mark.asyncio
async def test_batch_object_processing(self, processor_with_mocks):
    """Store one ExtractedObject carrying three rows and inspect the writes.

    Expects a single CREATE TABLE statement for the schema and one INSERT
    per row, each insert's bound values containing the collection name and
    that row's customer id, name and email.
    """
    processor, mock_cluster, mock_session = processor_with_mocks

    # Rows in the batch, in the order they are expected to be inserted
    customers = [
        ("CUST001", "John Doe", "john@example.com"),
        ("CUST002", "Jane Smith", "jane@example.com"),
        ("CUST003", "Bob Johnson", "bob@example.com"),
    ]

    def executed_with(fragment):
        # Session calls whose string form mentions the given CQL fragment
        return [
            recorded for recorded in mock_session.execute.call_args_list
            if fragment in str(recorded)
        ]

    with patch('trustgraph.storage.objects.cassandra.write.Cluster', return_value=mock_cluster):
        # Configure schema
        schema_definition = {
            "name": "batch_customers",
            "description": "Customer batch data",
            "fields": [
                {"name": "customer_id", "type": "string", "primary_key": True},
                {"name": "name", "type": "string", "required": True},
                {"name": "email", "type": "string", "indexed": True}
            ]
        }
        config = {"schema": {"batch_customers": json.dumps(schema_definition)}}

        await processor.on_schema_config(config, version=1)

        # Batch object with multiple values
        batch_metadata = Metadata(
            id="batch-001",
            user="test_user",
            collection="batch_import",
            metadata=[]
        )
        batch_obj = ExtractedObject(
            metadata=batch_metadata,
            schema_name="batch_customers",
            values=[
                {"customer_id": customer_id, "name": name, "email": email}
                for customer_id, name, email in customers
            ],
            confidence=0.92,
            source_span="Multiple customers extracted from document"
        )

        msg = MagicMock()
        msg.value.return_value = batch_obj

        await processor.on_object(msg, None, None)

        # Table creation happens once
        table_calls = executed_with("CREATE TABLE")
        assert len(table_calls) == 1
        assert "o_batch_customers" in str(table_calls[0])

        # One insert per row in the batch
        insert_calls = executed_with("INSERT INTO")
        assert len(insert_calls) == 3

        # Each insert carries its own row's data
        for recorded, (customer_id, name, email) in zip(insert_calls, customers):
            values = recorded[0][1]
            assert "batch_import" in values  # collection
            assert customer_id in values  # customer_id
            assert name in values
            assert email in values
|
||||
|
||||
@pytest.mark.asyncio
async def test_empty_batch_processing(self, processor_with_mocks):
    """Store an ExtractedObject whose values list is empty.

    Expects the table to be created once and no INSERT to be issued.
    """
    processor, mock_cluster, mock_session = processor_with_mocks

    def executed_with(fragment):
        # Session calls whose string form mentions the given CQL fragment
        return [
            recorded for recorded in mock_session.execute.call_args_list
            if fragment in str(recorded)
        ]

    with patch('trustgraph.storage.objects.cassandra.write.Cluster', return_value=mock_cluster):
        id_field = Field(name="id", type="string", size=50, primary=True)
        processor.schemas["empty_test"] = RowSchema(
            name="empty_test",
            fields=[id_field]
        )

        # Object with an empty batch
        empty_obj = ExtractedObject(
            metadata=Metadata(id="empty-1", user="test", collection="empty", metadata=[]),
            schema_name="empty_test",
            values=[],
            confidence=1.0,
            source_span="No objects found"
        )

        msg = MagicMock()
        msg.value.return_value = empty_obj

        await processor.on_object(msg, None, None)

        # The table is still created
        assert len(executed_with("CREATE TABLE")) == 1

        # No insert statements for an empty batch
        assert len(executed_with("INSERT INTO")) == 0
|
||||
|
||||
@pytest.mark.asyncio
async def test_mixed_single_and_batch_objects(self, processor_with_mocks):
    """Store a one-row object followed by a two-row object.

    Expects three INSERT statements in total across both messages.
    """
    processor, mock_cluster, mock_session = processor_with_mocks

    with patch('trustgraph.storage.objects.cassandra.write.Cluster', return_value=mock_cluster):
        schema_fields = [
            Field(name="id", type="string", size=50, primary=True),
            Field(name="data", type="string", size=100)
        ]
        processor.schemas["mixed_test"] = RowSchema(
            name="mixed_test",
            fields=schema_fields
        )

        # Single object (backward compatibility): an array with one item
        single_obj = ExtractedObject(
            metadata=Metadata(id="single", user="test", collection="mixed", metadata=[]),
            schema_name="mixed_test",
            values=[{"id": "single-1", "data": "single data"}],
            confidence=0.9,
            source_span="Single object"
        )

        # Batch object: an array with two items
        batch_rows = [
            {"id": "batch-1", "data": "batch data 1"},
            {"id": "batch-2", "data": "batch data 2"}
        ]
        batch_obj = ExtractedObject(
            metadata=Metadata(id="batch", user="test", collection="mixed", metadata=[]),
            schema_name="mixed_test",
            values=batch_rows,
            confidence=0.85,
            source_span="Batch objects"
        )

        # Feed both objects through the processor, single first
        for extracted in (single_obj, batch_obj):
            msg = MagicMock()
            msg.value.return_value = extracted
            await processor.on_object(msg, None, None)

        # 1 + 2 rows -> 3 inserts
        insert_calls = [
            recorded for recorded in mock_session.execute.call_args_list
            if "INSERT INTO" in str(recorded)
        ]
        assert len(insert_calls) == 3
|
||||
624
tests/integration/test_objects_graphql_query_integration.py
Normal file
624
tests/integration/test_objects_graphql_query_integration.py
Normal file
|
|
@ -0,0 +1,624 @@
|
|||
"""
|
||||
Integration tests for Objects GraphQL Query Service
|
||||
|
||||
These tests verify end-to-end functionality including:
|
||||
- Real Cassandra database operations
|
||||
- Full GraphQL query execution
|
||||
- Schema generation and configuration handling
|
||||
- Message processing with actual Pulsar schemas
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import json
|
||||
import asyncio
|
||||
from unittest.mock import MagicMock, AsyncMock
|
||||
|
||||
# Check if Docker/testcontainers is available.
# Any failure here (missing package, or the ping raising) marks Docker as
# unavailable; the test class's skipif marker reads DOCKER_AVAILABLE.
try:
    from testcontainers.cassandra import CassandraContainer
    import docker
    # Test Docker connection
    docker.from_env().ping()
    DOCKER_AVAILABLE = True
except Exception:
    DOCKER_AVAILABLE = False
    # Keep the name defined so later references do not raise NameError
    CassandraContainer = None
|
||||
|
||||
from trustgraph.query.objects.cassandra.service import Processor
|
||||
from trustgraph.schema import ObjectsQueryRequest, ObjectsQueryResponse, GraphQLError
|
||||
from trustgraph.schema import RowSchema, Field, ExtractedObject, Metadata
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
@pytest.mark.skipif(not DOCKER_AVAILABLE, reason="Docker/testcontainers not available")
|
||||
class TestObjectsGraphQLQueryIntegration:
|
||||
"""Integration tests with real Cassandra database"""
|
||||
|
||||
@pytest.fixture(scope="class")
def cassandra_container(self):
    """Yield a running `cassandra:3.11` test container, shared per class.

    Skips when Docker/testcontainers was not detected at import time.
    """
    if not DOCKER_AVAILABLE:
        pytest.skip("Docker/testcontainers not available")

    container = CassandraContainer("cassandra:3.11")
    with container as running:
        # Presumably this blocks until Cassandra accepts connections
        # (the original comment says so) - TODO confirm.
        running.get_connection_url()
        yield running
|
||||
|
||||
@pytest.fixture
def processor(self, cassandra_container):
    """Create a Processor pointed at the test container's host.

    The cluster and session attributes are reset to None before the
    processor is returned.
    """
    # Extract host and port from container
    host = cassandra_container.get_container_host_ip()
    # NOTE(review): `port` is read here but never passed to Processor or
    # assigned onto it below - confirm whether the mapped port must be
    # wired in for the connection to reach the container.
    port = cassandra_container.get_exposed_port(9042)

    # Create processor
    processor = Processor(
        id="test-graphql-query",
        graph_host=host,
        # Note: testcontainer typically doesn't require auth
        graph_username=None,
        graph_password=None,
        config_type="schema"
    )

    # Override connection parameters for test container
    processor.graph_host = host
    processor.cluster = None
    processor.session = None

    return processor
|
||||
|
||||
@pytest.fixture
def sample_schema_config(self):
    """Return a schema config dict with JSON-encoded customer and order schemas."""
    customer_fields = [
        {
            "name": "customer_id",
            "type": "string",
            "primary_key": True,
            "required": True,
            "description": "Customer identifier"
        },
        {
            "name": "name",
            "type": "string",
            "required": True,
            "indexed": True,
            "description": "Customer name"
        },
        {
            "name": "email",
            "type": "string",
            "required": True,
            "indexed": True,
            "description": "Customer email"
        },
        {
            "name": "status",
            "type": "string",
            "required": False,
            "indexed": True,
            "enum": ["active", "inactive", "pending"],
            "description": "Customer status"
        },
        {
            "name": "created_date",
            "type": "timestamp",
            "required": False,
            "description": "Registration date"
        }
    ]

    order_fields = [
        {
            "name": "order_id",
            "type": "string",
            "primary_key": True,
            "required": True
        },
        {
            "name": "customer_id",
            "type": "string",
            "required": True,
            "indexed": True,
            "description": "Related customer"
        },
        {
            "name": "total",
            "type": "float",
            "required": True,
            "description": "Order total amount"
        },
        {
            "name": "status",
            "type": "string",
            "indexed": True,
            "enum": ["pending", "processing", "shipped", "delivered"],
            "description": "Order status"
        }
    ]

    # Each schema is stored as a JSON string under its name
    customer_schema = json.dumps({
        "name": "customer",
        "description": "Customer records",
        "fields": customer_fields
    })
    order_schema = json.dumps({
        "name": "order",
        "description": "Order records",
        "fields": order_fields
    })

    return {
        "schema": {
            "customer": customer_schema,
            "order": order_schema
        }
    }
|
||||
|
||||
@pytest.mark.asyncio
async def test_schema_configuration_and_generation(self, processor, sample_schema_config):
    """Load the sample config and check the schemas and GraphQL types built from it."""
    expected_names = ("customer", "order")

    # Load schema configuration
    await processor.on_schema_config(sample_schema_config, version=1)

    # Both schemas were loaded
    assert len(processor.schemas) == 2
    for schema_name in expected_names:
        assert schema_name in processor.schemas

    # Customer schema: name and field count
    customer_schema = processor.schemas["customer"]
    assert customer_schema.name == "customer"
    assert len(customer_schema.fields) == 5

    # First field flagged as primary, if any
    primary_fields = (field for field in customer_schema.fields if field.primary)
    pk_field = next(primary_fields, None)
    assert pk_field is not None
    assert pk_field.name == "customer_id"

    # GraphQL schema and one type per configured schema
    assert processor.graphql_schema is not None
    assert len(processor.graphql_types) == 2
    for schema_name in expected_names:
        assert schema_name in processor.graphql_types
|
||||
|
||||
@pytest.mark.asyncio
async def test_cassandra_connection_and_table_creation(self, processor, sample_schema_config):
    """Connect to Cassandra, create the customer table, and confirm it exists.

    Existence is checked twice: through the processor's own
    `known_keyspaces` / `known_tables` tracking, and by querying
    `system_schema.tables` with the sanitized keyspace and table names.
    """
    # Load schema configuration
    await processor.on_schema_config(sample_schema_config, version=1)

    # Connect to Cassandra
    processor.connect_cassandra()
    assert processor.session is not None

    # Create test keyspace and table
    keyspace = "test_user"
    schema_name = "customer"
    schema = processor.schemas[schema_name]

    # Ensure table creation
    processor.ensure_table(keyspace, schema_name, schema)

    # Verify keyspace and table tracking
    assert keyspace in processor.known_keyspaces
    assert keyspace in processor.known_tables

    # Verify table was created by querying Cassandra system tables
    safe_keyspace = processor.sanitize_name(keyspace)
    safe_table = processor.sanitize_table(schema_name)

    # Check if table exists
    table_query = """
        SELECT table_name FROM system_schema.tables
        WHERE keyspace_name = %s AND table_name = %s
    """
    result = processor.session.execute(table_query, (safe_keyspace, safe_table))
    rows = list(result)
    assert len(rows) == 1
    assert rows[0].table_name == safe_table
|
||||
|
||||
@pytest.mark.asyncio
async def test_data_insertion_and_graphql_query(self, processor, sample_schema_config):
    """Rows written directly to Cassandra are returned by a GraphQL query.

    Inserts three customers into the ``integration_test`` collection through
    the raw session (standing in for the storage processor), queries them
    back with ``customer_objects`` and checks the returned fields.
    """
    await processor.on_schema_config(sample_schema_config, version=1)
    processor.connect_cassandra()

    keyspace = "test_user"
    collection = "integration_test"
    schema_name = "customer"

    # Make sure the target table exists before writing to it.
    processor.ensure_table(keyspace, schema_name, processor.schemas[schema_name])

    ks = processor.sanitize_name(keyspace)
    table = processor.sanitize_table(schema_name)
    insert_query = f"""
        INSERT INTO {ks}.{table}
        (collection, customer_id, name, email, status, created_date)
        VALUES (%s, %s, %s, %s, %s, %s)
    """

    rows_to_insert = [
        (collection, "CUST001", "John Doe", "john@example.com", "active", "2024-01-15"),
        (collection, "CUST002", "Jane Smith", "jane@example.com", "active", "2024-01-16"),
        (collection, "CUST003", "Bob Wilson", "bob@example.com", "inactive", "2024-01-17"),
    ]
    for row in rows_to_insert:
        processor.session.execute(insert_query, row)

    graphql_query = '''
        {
            customer_objects(collection: "integration_test") {
                customer_id
                name
                email
                status
            }
        }
    '''

    result = await processor.execute_graphql_query(
        query=graphql_query,
        variables={},
        operation_name=None,
        user=keyspace,
        collection=collection,
    )

    # The response carries a data section containing the queried field.
    assert "data" in result
    assert "customer_objects" in result["data"]

    customers = result["data"]["customer_objects"]
    assert len(customers) == 3

    # Index by id: three rows plus three distinct expected ids means the
    # mapping is one-to-one.
    by_id = {entry["customer_id"]: entry for entry in customers}
    assert "CUST001" in by_id
    assert "CUST002" in by_id
    assert "CUST003" in by_id

    # Spot-check one customer's field values.
    john = by_id["CUST001"]
    assert john["name"] == "John Doe"
    assert john["email"] == "john@example.com"
    assert john["status"] == "active"
|
||||
|
||||
@pytest.mark.asyncio
async def test_graphql_query_with_filters(self, processor, sample_schema_config):
    """A ``status`` argument restricts ``customer_objects`` to matching rows.

    Inserts two active customers and one inactive customer into the
    ``filter_test`` collection and expects a ``status: "active"`` query to
    return only the two active ones.
    """
    await processor.on_schema_config(sample_schema_config, version=1)
    processor.connect_cassandra()

    keyspace = "test_user"
    collection = "filter_test"
    schema_name = "customer"
    processor.ensure_table(keyspace, schema_name, processor.schemas[schema_name])

    ks = processor.sanitize_name(keyspace)
    table = processor.sanitize_table(schema_name)
    insert_query = f"""
        INSERT INTO {ks}.{table}
        (collection, customer_id, name, email, status)
        VALUES (%s, %s, %s, %s, %s)
    """

    fixture_rows = [
        (collection, "A001", "Active User 1", "active1@test.com", "active"),
        (collection, "A002", "Active User 2", "active2@test.com", "active"),
        (collection, "I001", "Inactive User", "inactive@test.com", "inactive"),
    ]
    for row in fixture_rows:
        processor.session.execute(insert_query, row)

    # status is declared as an indexed field in the sample schema.
    filtered_query = '''
        {
            customer_objects(collection: "filter_test", status: "active") {
                customer_id
                name
                status
            }
        }
    '''

    result = await processor.execute_graphql_query(
        query=filtered_query,
        variables={},
        operation_name=None,
        user=keyspace,
        collection=collection,
    )

    assert "data" in result
    matches = result["data"]["customer_objects"]
    assert len(matches) == 2  # Only active customers

    expected_ids = ["A001", "A002"]
    for match in matches:
        assert match["status"] == "active"
        assert match["customer_id"] in expected_ids
|
||||
|
||||
@pytest.mark.asyncio
async def test_graphql_error_handling(self, processor, sample_schema_config):
    """Querying an unknown field yields a GraphQL error entry.

    The query selects ``nonexistent_field``, which the sample customer
    schema does not declare; the result must carry an ``errors`` list whose
    first entry has a message pointing at the problem.
    """
    await processor.on_schema_config(sample_schema_config, version=1)

    invalid_query = '''
        {
            customer_objects {
                customer_id
                nonexistent_field
            }
        }
    '''

    result = await processor.execute_graphql_query(
        query=invalid_query,
        variables={},
        operation_name=None,
        user="test_user",
        collection="test_collection",
    )

    # At least one error must be reported.
    assert "errors" in result
    reported = result["errors"]
    assert len(reported) > 0

    first_error = reported[0]
    assert "message" in first_error
    # Accept either the field name or the generic validation phrase.
    message = first_error["message"]
    assert "nonexistent_field" in message or "Cannot query field" in message
|
||||
|
||||
@pytest.mark.asyncio
async def test_message_processing_integration(self, processor, sample_schema_config):
    """``on_message`` answers a query request with a JSON-bearing response.

    Wraps an ``ObjectsQueryRequest`` in a mocked message, routes the reply
    to a mocked producer, and checks that exactly one
    ``ObjectsQueryResponse`` is sent with no error and JSON object data.
    """
    await processor.on_schema_config(sample_schema_config, version=1)
    processor.connect_cassandra()

    request = ObjectsQueryRequest(
        user="msg_test_user",
        collection="msg_test_collection",
        query='{ customer_objects { customer_id name } }',
        variables={},
        operation_name="",
    )

    # Message double exposing value() and properties().
    message = MagicMock(**{
        "value.return_value": request,
        "properties.return_value": {"id": "integration-test-123"},
    })

    # Calling the flow returns the producer that captures the response.
    producer = AsyncMock()
    flow = MagicMock(return_value=producer)

    await processor.on_message(message, None, flow)

    # Exactly one response must have been sent.
    producer.send.assert_called_once()
    sent = producer.send.call_args.args[0]
    assert isinstance(sent, ObjectsQueryResponse)

    # Should have no system error (even if no data)
    assert sent.error is None

    # Data should be JSON string (even if empty result)
    assert sent.data is not None
    assert isinstance(sent.data, str)

    # The payload must decode to a JSON object.
    decoded = json.loads(sent.data)
    assert isinstance(decoded, dict)
|
||||
|
||||
@pytest.mark.asyncio
async def test_concurrent_queries(self, processor, sample_schema_config):
    """Test handling multiple concurrent GraphQL queries.

    Launches four queries, each under its own user and collection, through
    ``asyncio.gather`` and requires every one to finish with a GraphQL
    result (``data`` or ``errors``) rather than an exception.
    """
    # Setup
    await processor.on_schema_config(sample_schema_config, version=1)
    processor.connect_cassandra()

    queries = [
        '{ customer_objects { customer_id } }',
        '{ order_objects { order_id } }',
        '{ customer_objects { name email } }',
        '{ order_objects { total status } }',
    ]

    # Build the coroutines up front; gather() schedules them concurrently.
    tasks = [
        processor.execute_graphql_query(
            query=query,
            variables={},
            operation_name=None,
            user=f"concurrent_user_{i}",
            collection=f"concurrent_collection_{i}",
        )
        for i, query in enumerate(queries)
    ]

    # Wait for all queries to complete, collecting failures as values.
    results = await asyncio.gather(*tasks, return_exceptions=True)

    # Verify all queries completed without exceptions
    for i, result in enumerate(results):
        # With return_exceptions=True the result may be any BaseException
        # (e.g. asyncio.CancelledError), not only Exception subclasses, so
        # check the base class before treating the result as a mapping.
        assert not isinstance(result, BaseException), f"Query {i} failed: {result}"
        assert "data" in result or "errors" in result
|
||||
|
||||
@pytest.mark.asyncio
async def test_schema_update_handling(self, processor):
    """A newer schema config replaces and extends the registered schemas.

    Version 1 defines only ``simple``; version 2 adds a field to ``simple``
    and introduces ``complex``. After the update both schemas, the extra
    field and two GraphQL types must be present.
    """
    id_field = {"name": "id", "type": "string", "primary_key": True}

    # Version 1: a single one-field schema.
    initial_config = {
        "schema": {
            "simple": json.dumps({"name": "simple", "fields": [id_field]}),
        }
    }
    await processor.on_schema_config(initial_config, version=1)
    assert len(processor.schemas) == 1
    assert "simple" in processor.schemas

    # Version 2: "simple" gains a field and "complex" is introduced.
    updated_config = {
        "schema": {
            "simple": json.dumps({
                "name": "simple",
                "fields": [
                    id_field,
                    {"name": "name", "type": "string"},  # New field
                ],
            }),
            "complex": json.dumps({
                "name": "complex",
                "fields": [
                    id_field,
                    {"name": "data", "type": "string"},
                ],
            }),
        }
    }
    await processor.on_schema_config(updated_config, version=2)

    # Both schemas are now registered.
    current = processor.schemas
    assert len(current) == 2
    assert "simple" in current
    assert "complex" in current

    # The extra field on "simple" was picked up.
    assert len(current["simple"].fields) == 2

    # GraphQL types follow the schema set.
    assert len(processor.graphql_types) == 2
|
||||
|
||||
@pytest.mark.asyncio
async def test_large_result_set_handling(self, processor, sample_schema_config):
    """A ``limit`` argument caps the number of rows returned.

    Inserts 50 customers into ``large_collection`` and checks that a query
    with ``limit: 10`` returns at most ten of them.
    """
    await processor.on_schema_config(sample_schema_config, version=1)
    processor.connect_cassandra()

    keyspace = "large_test_user"
    collection = "large_collection"
    schema_name = "customer"
    processor.ensure_table(keyspace, schema_name, processor.schemas[schema_name])

    ks = processor.sanitize_name(keyspace)
    table = processor.sanitize_table(schema_name)
    insert_query = f"""
        INSERT INTO {ks}.{table}
        (collection, customer_id, name, email, status)
        VALUES (%s, %s, %s, %s, %s)
    """

    # 50 rows; even-numbered customers are active, odd-numbered inactive.
    for number in range(50):
        status = "active" if number % 2 == 0 else "inactive"
        row = (
            collection,
            f"CUST{number:03d}",
            f"Customer {number}",
            f"customer{number}@test.com",
            status,
        )
        processor.session.execute(insert_query, row)

    limited_query = '''
        {
            customer_objects(collection: "large_collection", limit: 10) {
                customer_id
                name
            }
        }
    '''

    result = await processor.execute_graphql_query(
        query=limited_query,
        variables={},
        operation_name=None,
        user=keyspace,
        collection=collection,
    )

    assert "data" in result
    returned = result["data"]["customer_objects"]
    assert len(returned) <= 10  # Should be limited
|
||||
|
||||
|
||||
@pytest.mark.integration
@pytest.mark.skipif(not DOCKER_AVAILABLE, reason="Docker/testcontainers not available")
class TestObjectsGraphQLQueryPerformance:
    """Performance-focused integration tests"""

    @pytest.mark.asyncio
    async def test_query_execution_timing(self, cassandra_container):
        """Test that a trivial query against an empty schema returns quickly.

        Builds a dedicated processor pointed at the Cassandra test container,
        loads a one-field schema and requires a single query to complete in
        under one second with a well-formed GraphQL result.
        """
        import time

        # Create a dedicated processor pointed at the test container
        host = cassandra_container.get_container_host_ip()

        processor = Processor(
            id="perf-test-graphql-query",
            graph_host=host,
            config_type="schema",
        )

        # Load minimal schema
        schema_config = {
            "schema": {
                "perf_test": json.dumps({
                    "name": "perf_test",
                    "fields": [{"name": "id", "type": "string", "primary_key": True}],
                })
            }
        }

        await processor.on_schema_config(schema_config, version=1)

        # Measure query execution time with a monotonic, high-resolution
        # clock; time.time() can jump if the system clock is adjusted.
        start_time = time.perf_counter()

        result = await processor.execute_graphql_query(
            query='{ perf_test_objects { id } }',
            variables={},
            operation_name=None,
            user="perf_user",
            collection="perf_collection",
        )

        execution_time = time.perf_counter() - start_time

        # Verify reasonable execution time (should be under 1 second for empty result)
        assert execution_time < 1.0

        # Verify result structure
        assert "data" in result or "errors" in result
|
||||
748
tests/integration/test_structured_query_integration.py
Normal file
748
tests/integration/test_structured_query_integration.py
Normal file
|
|
@ -0,0 +1,748 @@
|
|||
"""
|
||||
Integration tests for Structured Query Service
|
||||
|
||||
These tests verify the end-to-end functionality of the structured query service,
|
||||
testing orchestration between nlp-query and objects-query services.
|
||||
Following the TEST_STRATEGY.md approach for integration testing.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import json
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
from trustgraph.schema import (
|
||||
StructuredQueryRequest, StructuredQueryResponse,
|
||||
QuestionToStructuredQueryRequest, QuestionToStructuredQueryResponse,
|
||||
ObjectsQueryRequest, ObjectsQueryResponse,
|
||||
Error, GraphQLError
|
||||
)
|
||||
from trustgraph.retrieval.structured_query.service import Processor
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
class TestStructuredQueryServiceIntegration:
|
||||
"""Integration tests for structured query service orchestration"""
|
||||
|
||||
@pytest.fixture
def integration_processor(self):
    """Provide a structured-query ``Processor`` wired to mock infrastructure.

    The task group, the Pulsar client and the processor's ``client``
    attribute are all replaced with mocks.
    """
    processor = Processor(taskgroup=MagicMock(), pulsar_client=AsyncMock())
    # Replace the client attribute so tests never reach a real service.
    processor.client = MagicMock()
    return processor
|
||||
|
||||
@pytest.mark.asyncio
async def test_end_to_end_structured_query_processing(self, integration_processor):
    """A question flows through NLP translation, object query and response.

    The NLP and objects services are mocked clients returned by the flow.
    The test checks the request handed to each service and the final
    ``StructuredQueryResponse`` sent on the ``response`` producer.
    """
    question = "Show me all customers from California who have made purchases over $500"
    request = StructuredQueryRequest(
        question=question,
        user="trustgraph",
        collection="default",
    )

    msg = MagicMock(**{
        "value.return_value": request,
        "properties.return_value": {"id": "integration-test-001"},
    })
    consumer = MagicMock()
    flow_response = AsyncMock()
    flow = MagicMock(return_value=flow_response)

    # Canned NLP translation of the question.
    nlp_response = QuestionToStructuredQueryResponse(
        error=None,
        graphql_query='''
            query GetCaliforniaCustomersWithLargePurchases($minAmount: String!, $state: String!) {
                customers(where: {state: {eq: $state}}) {
                    id
                    name
                    email
                    orders(where: {total: {gt: $minAmount}}) {
                        id
                        total
                        date
                    }
                }
            }
        ''',
        variables={
            "minAmount": "500.0",
            "state": "California",
        },
        detected_schemas=["customers", "orders"],
        confidence=0.91,
    )

    # Canned objects-query result for that GraphQL.
    objects_response = ObjectsQueryResponse(
        error=None,
        data='{"customers": [{"id": "123", "name": "Alice Johnson", "email": "alice@example.com", "orders": [{"id": "456", "total": 750.0, "date": "2024-01-15"}]}]}',
        errors=None,
        extensions={"execution_time": "150ms", "query_complexity": "8"},
    )

    nlp_client = AsyncMock(**{"request.return_value": nlp_response})
    objects_client = AsyncMock(**{"request.return_value": objects_response})

    # Route each flow lookup to its mock; unknown names get a fresh mock.
    routes = {
        "nlp-query-request": nlp_client,
        "objects-query-request": objects_client,
        "response": flow_response,
    }

    def flow_router(service_name):
        try:
            return routes[service_name]
        except KeyError:
            return AsyncMock()

    flow.side_effect = flow_router

    await integration_processor.on_message(msg, consumer, flow)

    # The NLP service receives the original question with the default limit.
    nlp_client.request.assert_called_once()
    nlp_request = nlp_client.request.call_args.args[0]
    assert isinstance(nlp_request, QuestionToStructuredQueryRequest)
    assert nlp_request.question == question
    assert nlp_request.max_results == 100  # Default max_results

    # The objects service receives the generated query, variables and scope.
    objects_client.request.assert_called_once()
    objects_request = objects_client.request.call_args.args[0]
    assert isinstance(objects_request, ObjectsQueryRequest)
    assert "customers" in objects_request.query
    assert "orders" in objects_request.query
    assert objects_request.variables["minAmount"] == "500.0"  # Converted to string
    assert objects_request.variables["state"] == "California"
    assert objects_request.user == "trustgraph"
    assert objects_request.collection == "default"

    # Exactly one response is sent, carrying the object data.
    flow_response.send.assert_called_once()
    response = flow_response.send.call_args.args[0]

    assert isinstance(response, StructuredQueryResponse)
    assert response.error is None
    assert "Alice Johnson" in response.data
    assert "750.0" in response.data
    assert len(response.errors) == 0
|
||||
|
||||
@pytest.mark.asyncio
async def test_nlp_service_integration_failure(self, integration_processor):
    """An NLP service error is surfaced as a structured-query error.

    The mocked NLP client returns an ``Error``; the processor must still
    send exactly one response whose error wraps the NLP message.
    """
    request = StructuredQueryRequest(
        question="This is an unparseable query ][{}"
    )

    msg = MagicMock(**{
        "value.return_value": request,
        "properties.return_value": {"id": "nlp-failure-test"},
    })
    consumer = MagicMock()
    flow_response = AsyncMock()
    flow = MagicMock(return_value=flow_response)

    # NLP reply reporting a parsing failure.
    nlp_error_response = QuestionToStructuredQueryResponse(
        error=Error(type="nlp-parsing-error", message="Unable to parse natural language query"),
        graphql_query="",
        variables={},
        detected_schemas=[],
        confidence=0.0,
    )
    nlp_client = AsyncMock(**{"request.return_value": nlp_error_response})

    # Only the NLP service and the response producer are routed explicitly.
    routes = {
        "nlp-query-request": nlp_client,
        "response": flow_response,
    }

    def flow_router(service_name):
        try:
            return routes[service_name]
        except KeyError:
            return AsyncMock()

    flow.side_effect = flow_router

    await integration_processor.on_message(msg, consumer, flow)

    # One response, carrying a wrapped error.
    flow_response.send.assert_called_once()
    response = flow_response.send.call_args.args[0]

    assert isinstance(response, StructuredQueryResponse)
    assert response.error is not None
    assert response.error.type == "structured-query-error"
    assert "NLP query service error" in response.error.message
    assert "Unable to parse natural language query" in response.error.message
|
||||
|
||||
@pytest.mark.asyncio
async def test_objects_service_integration_failure(self, integration_processor):
    """An objects service error is surfaced as a structured-query error.

    NLP translation succeeds, but the mocked objects client returns an
    ``Error``; the response must wrap that message.
    """
    request = StructuredQueryRequest(
        question="Show me data from a table that doesn't exist"
    )

    msg = MagicMock(**{
        "value.return_value": request,
        "properties.return_value": {"id": "objects-failure-test"},
    })
    consumer = MagicMock()
    flow_response = AsyncMock()
    flow = MagicMock(return_value=flow_response)

    # NLP succeeds and produces a query against an unknown table.
    nlp_response = QuestionToStructuredQueryResponse(
        error=None,
        graphql_query='query { nonexistent_table { id name } }',
        variables={},
        detected_schemas=["nonexistent_table"],
        confidence=0.7,
    )

    # The objects service rejects that query.
    objects_error_response = ObjectsQueryResponse(
        error=Error(type="graphql-schema-error", message="Table 'nonexistent_table' does not exist in schema"),
        data=None,
        errors=None,
        extensions={},
    )

    nlp_client = AsyncMock(**{"request.return_value": nlp_response})
    objects_client = AsyncMock(**{"request.return_value": objects_error_response})

    routes = {
        "nlp-query-request": nlp_client,
        "objects-query-request": objects_client,
        "response": flow_response,
    }

    def flow_router(service_name):
        try:
            return routes[service_name]
        except KeyError:
            return AsyncMock()

    flow.side_effect = flow_router

    await integration_processor.on_message(msg, consumer, flow)

    # One response, carrying a wrapped error.
    flow_response.send.assert_called_once()
    response = flow_response.send.call_args.args[0]

    assert response.error is not None
    assert response.error.type == "structured-query-error"
    assert "Objects query service error" in response.error.message
    assert "nonexistent_table" in response.error.message
|
||||
|
||||
@pytest.mark.asyncio
async def test_graphql_validation_errors_integration(self, integration_processor):
    """GraphQL validation errors are passed through without a system error.

    The mocked objects client returns two ``GraphQLError`` entries and no
    data; the response must have ``error`` unset and list both errors.
    """
    request = StructuredQueryRequest(
        question="Show me customer invalid_field values"
    )

    msg = MagicMock(**{
        "value.return_value": request,
        "properties.return_value": {"id": "validation-error-test"},
    })
    consumer = MagicMock()
    flow_response = AsyncMock()
    flow = MagicMock(return_value=flow_response)

    # NLP produces a query that selects an undefined field.
    nlp_response = QuestionToStructuredQueryResponse(
        error=None,
        graphql_query='query { customers { id invalid_field } }',
        variables={},
        detected_schemas=["customers"],
        confidence=0.8,
    )

    # The objects service reports the validation errors.
    field_error = GraphQLError(
        message="Cannot query field 'invalid_field' on type 'Customer'",
        path=["customers", "0", "invalid_field"],
        extensions={"code": "VALIDATION_ERROR"},
    )
    schema_error = GraphQLError(
        message="Field 'invalid_field' is not defined in the schema",
        path=["customers", "invalid_field"],
        extensions={"code": "FIELD_NOT_FOUND"},
    )
    objects_response = ObjectsQueryResponse(
        error=None,
        data=None,  # No data when validation fails
        errors=[field_error, schema_error],
        extensions={"validation_errors": "2"},
    )

    nlp_client = AsyncMock(**{"request.return_value": nlp_response})
    objects_client = AsyncMock(**{"request.return_value": objects_response})

    routes = {
        "nlp-query-request": nlp_client,
        "objects-query-request": objects_client,
        "response": flow_response,
    }

    def flow_router(service_name):
        try:
            return routes[service_name]
        except KeyError:
            return AsyncMock()

    flow.side_effect = flow_router

    await integration_processor.on_message(msg, consumer, flow)

    # One response, listing both GraphQL errors.
    flow_response.send.assert_called_once()
    response = flow_response.send.call_args.args[0]

    assert response.error is None  # No system error
    reported = response.errors
    assert len(reported) == 2  # Two GraphQL errors
    assert "Cannot query field 'invalid_field'" in reported[0]
    assert "Field 'invalid_field' is not defined" in reported[1]
    assert "customers" in reported[0]
|
||||
|
||||
@pytest.mark.asyncio
async def test_complex_multi_service_integration(self, integration_processor):
    """Test complex integration scenario with multiple entities and relationships.

    The mocked NLP client returns a three-variable query spanning products,
    orders and customers; the mocked objects client returns nested data.
    The test checks that each service is called exactly once, that the
    variables are forwarded, and that the nested data reaches the response.
    """
    # Arrange
    request = StructuredQueryRequest(
        question="Find all products under $100 that are in stock, along with their recent orders from customers in New York"
    )

    msg = MagicMock()
    msg.value.return_value = request
    msg.properties.return_value = {"id": "complex-integration-test"}

    consumer = MagicMock()
    flow = MagicMock()
    flow_response = AsyncMock()
    flow.return_value = flow_response

    # Mock complex NLP response
    nlp_response = QuestionToStructuredQueryResponse(
        error=None,
        graphql_query='''
            query GetProductsWithCustomerOrders($maxPrice: String!, $inStock: String!, $state: String!) {
                products(where: {price: {lt: $maxPrice}, in_stock: {eq: $inStock}}) {
                    id
                    name
                    price
                    orders {
                        id
                        total
                        customer {
                            id
                            name
                            state
                        }
                    }
                }
            }
        ''',
        variables={
            "maxPrice": "100.0",
            "inStock": "true",
            "state": "New York",
        },
        detected_schemas=["products", "orders", "customers"],
        confidence=0.85,
    )

    # Mock complex Objects response
    complex_data = {
        "products": [
            {
                "id": "prod_123",
                "name": "Widget A",
                "price": 89.99,
                "orders": [
                    {
                        "id": "order_456",
                        "total": 179.98,
                        "customer": {
                            "id": "cust_789",
                            "name": "Bob Smith",
                            "state": "New York",
                        },
                    }
                ],
            },
            {
                "id": "prod_124",
                "name": "Widget B",
                "price": 65.50,
                "orders": [
                    {
                        "id": "order_457",
                        "total": 131.00,
                        "customer": {
                            "id": "cust_790",
                            "name": "Carol Jones",
                            "state": "New York",
                        },
                    }
                ],
            },
        ]
    }

    objects_response = ObjectsQueryResponse(
        error=None,
        data=json.dumps(complex_data),
        errors=None,
        extensions={
            "execution_time": "250ms",
            "query_complexity": "15",
            "data_sources": "products,orders,customers",  # Convert array to comma-separated string
        },
    )

    mock_nlp_client = AsyncMock()
    mock_nlp_client.request.return_value = nlp_response

    mock_objects_client = AsyncMock()
    mock_objects_client.request.return_value = objects_response

    # Mock flow context to route to appropriate services
    def flow_router(service_name):
        if service_name == "nlp-query-request":
            return mock_nlp_client
        elif service_name == "objects-query-request":
            return mock_objects_client
        elif service_name == "response":
            return flow_response
        else:
            return AsyncMock()

    flow.side_effect = flow_router

    # Act
    await integration_processor.on_message(msg, consumer, flow)

    # Assert - Verify complex data integration
    # Confirm each call happened before inspecting call_args; otherwise a
    # missing call surfaces as an opaque TypeError on a None call_args.
    mock_nlp_client.request.assert_called_once()
    nlp_call_args = mock_nlp_client.request.call_args[0][0]
    assert len(nlp_call_args.question) > 50  # Complex question

    # Check Objects service call with variable conversion
    mock_objects_client.request.assert_called_once()
    objects_call_args = mock_objects_client.request.call_args[0][0]
    assert objects_call_args.variables["maxPrice"] == "100.0"
    assert objects_call_args.variables["inStock"] == "true"
    assert objects_call_args.variables["state"] == "New York"

    # Check response contains complex data
    flow_response.send.assert_called_once()
    response_call = flow_response.send.call_args
    response = response_call[0][0]

    assert response.error is None
    assert "Widget A" in response.data
    assert "Widget B" in response.data
    assert "Bob Smith" in response.data
    assert "Carol Jones" in response.data
    assert "New York" in response.data
|
||||
|
||||
@pytest.mark.asyncio
async def test_empty_result_integration(self, integration_processor):
    """Check that an empty result set is passed through without an error.

    The NLP mock returns a GraphQL query and the objects mock returns an
    empty ``customers`` list; the message sent via the ``response``
    producer must carry exactly that payload with ``error`` unset.
    """
    # Arrange - incoming message wrapping the question
    request = StructuredQueryRequest(
        question="Show me customers from Mars"
    )

    msg = MagicMock()
    msg.value.return_value = request
    msg.properties.return_value = {"id": "empty-result-test"}

    consumer = MagicMock()
    flow = MagicMock()
    flow_response = AsyncMock()
    flow.return_value = flow_response

    # Canned NLP translation of the question
    nlp_client = AsyncMock()
    nlp_client.request.return_value = QuestionToStructuredQueryResponse(
        error=None,
        graphql_query='query { customers(where: {planet: {eq: "Mars"}}) { id name planet } }',
        variables={},
        detected_schemas=["customers"],
        confidence=0.9
    )

    # Canned objects-query answer holding an empty result set
    objects_client = AsyncMock()
    objects_client.request.return_value = ObjectsQueryResponse(
        error=None,
        data='{"customers": []}',
        errors=None,
        extensions={"result_count": "0"}
    )

    # Route each service name requested from the flow to its mock
    routes = {
        "nlp-query-request": nlp_client,
        "objects-query-request": objects_client,
        "response": flow_response,
    }

    def flow_router(service_name):
        if service_name in routes:
            return routes[service_name]
        # Any other service gets a fresh throwaway mock
        return AsyncMock()

    flow.side_effect = flow_router

    # Act
    await integration_processor.on_message(msg, consumer, flow)

    # Assert - empty results should be handled gracefully
    sent_args, _ = flow_response.send.call_args
    response = sent_args[0]

    assert response.error is None
    assert response.data == '{"customers": []}'
    assert len(response.errors) == 0
|
||||
|
||||
@pytest.mark.asyncio
async def test_concurrent_requests_integration(self, integration_processor):
    """Process three requests concurrently and check each gets one response.

    Every request has its own flow, NLP mock, objects mock and response
    producer. A shared counter records how often the NLP and objects
    services are looked up through the flows.
    """
    import asyncio

    # Incremented by every router whenever the NLP or objects service is
    # requested from a flow; shared across all three requests.
    service_call_count = 0

    def create_flow_router(nlp_client, objects_client, response_producer):
        """Return a router bound to one request's mocks.

        Binding through parameters (instead of closing over loop
        variables) keeps each router tied to its own request.
        """
        def flow_router(service_name):
            nonlocal service_call_count
            if service_name == "nlp-query-request":
                service_call_count += 1
                return nlp_client
            elif service_name == "objects-query-request":
                service_call_count += 1
                return objects_client
            elif service_name == "response":
                return response_producer
            else:
                return AsyncMock()
        return flow_router

    # Arrange - one message / flow pair per concurrent request
    messages = []
    flows = []

    for i in range(3):
        request = StructuredQueryRequest(
            question=f"Query {i}: Show me data"
        )

        msg = MagicMock()
        msg.value.return_value = request
        msg.properties.return_value = {"id": f"concurrent-test-{i}"}

        flow = MagicMock()
        flow_response = AsyncMock()
        flow.return_value = flow_response

        # NLP and Objects responses specific to this request
        nlp_response = QuestionToStructuredQueryResponse(
            error=None,
            graphql_query=f'query {{ test_{i} {{ id }} }}',
            variables={},
            detected_schemas=[f"test_{i}"],
            confidence=0.9
        )

        objects_response = ObjectsQueryResponse(
            error=None,
            data=f'{{"test_{i}": [{{"id": "{i}"}}]}}',
            errors=None,
            extensions={}
        )

        mock_nlp_client = AsyncMock()
        mock_nlp_client.request.return_value = nlp_response

        mock_objects_client = AsyncMock()
        mock_objects_client.request.return_value = objects_response

        flow.side_effect = create_flow_router(
            mock_nlp_client, mock_objects_client, flow_response
        )

        messages.append(msg)
        flows.append(flow)

    # Act - process all messages concurrently
    consumer = MagicMock()
    await asyncio.gather(*(
        integration_processor.on_message(msg, consumer, flow)
        for msg, flow in zip(messages, flows)
    ))

    # Assert - all requests should be processed
    assert service_call_count == 6  # 2 calls per request (NLP + Objects)
    for flow in flows:
        flow.return_value.send.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
async def test_service_timeout_integration(self, integration_processor):
    """Check that an NLP service failure is reported as an error response.

    The NLP mock raises an exception whose message mentions a timeout; the
    processor is expected to send exactly one response whose error has type
    ``structured-query-error`` and a message mentioning the timeout.
    """
    # Arrange
    request = StructuredQueryRequest(
        question="This query will timeout"
    )

    msg = MagicMock()
    msg.value.return_value = request
    msg.properties.return_value = {"id": "timeout-test"}

    consumer = MagicMock()
    flow = MagicMock()
    flow_response = AsyncMock()
    flow.return_value = flow_response

    # The NLP client fails on every request with a timeout-style exception
    failing_nlp_client = AsyncMock()
    failing_nlp_client.request.side_effect = Exception("Service timeout: Request took longer than 30s")

    # Only the NLP service and the response producer are wired up here
    routes = {
        "nlp-query-request": failing_nlp_client,
        "response": flow_response,
    }

    def flow_router(service_name):
        if service_name in routes:
            return routes[service_name]
        return AsyncMock()

    flow.side_effect = flow_router

    # Act
    await integration_processor.on_message(msg, consumer, flow)

    # Assert - timeout should be handled gracefully
    flow_response.send.assert_called_once()
    sent_args, _ = flow_response.send.call_args
    response = sent_args[0]

    assert response.error is not None
    assert response.error.type == "structured-query-error"
    assert "timeout" in response.error.message.lower()
|
||||
|
||||
@pytest.mark.asyncio
async def test_variable_type_conversion_integration(self, integration_processor):
    """Check that query variables reach the objects service as strings.

    The NLP mock supplies three string-valued variables; the request sent
    to the objects service must expose each of them as a ``str`` with the
    same value, and the final response must carry the objects payload.
    """
    # Arrange
    request = StructuredQueryRequest(
        question="Show me orders with totals between 50.5 and 200.75 from the last 30 days"
    )

    msg = MagicMock()
    msg.value.return_value = request
    msg.properties.return_value = {"id": "variable-conversion-test"}

    consumer = MagicMock()
    flow = MagicMock()
    flow_response = AsyncMock()
    flow.return_value = flow_response

    # NLP answer: numeric-looking variables, already supplied as strings
    nlp_client = AsyncMock()
    nlp_client.request.return_value = QuestionToStructuredQueryResponse(
        error=None,
        graphql_query='query($minTotal: Float!, $maxTotal: Float!, $daysPast: Int!) { orders(filter: {total: {between: [$minTotal, $maxTotal]}, date: {gte: $daysPast}}) { id total date } }',
        variables={
            "minTotal": "50.5",
            "maxTotal": "200.75",
            "daysPast": "30"
        },
        detected_schemas=["orders"],
        confidence=0.88
    )

    # Objects answer containing a single matching order
    objects_client = AsyncMock()
    objects_client.request.return_value = ObjectsQueryResponse(
        error=None,
        data='{"orders": [{"id": "123", "total": 125.50, "date": "2024-01-15"}]}',
        errors=None,
        extensions={}
    )

    # Route each service name requested from the flow to its mock
    routes = {
        "nlp-query-request": nlp_client,
        "objects-query-request": objects_client,
        "response": flow_response,
    }

    def flow_router(service_name):
        if service_name in routes:
            return routes[service_name]
        return AsyncMock()

    flow.side_effect = flow_router

    # Act
    await integration_processor.on_message(msg, consumer, flow)

    # Assert - inspect the request handed to the objects service
    objects_args, _ = objects_client.request.call_args
    sent_variables = objects_args[0].variables

    expected_variables = {
        "minTotal": "50.5",
        "maxTotal": "200.75",
        "daysPast": "30",
    }

    # All variables should be strings for Pulsar schema compatibility
    for name in expected_variables:
        assert isinstance(sent_variables[name], str)

    # Values should be preserved
    for name, value in expected_variables.items():
        assert sent_variables[name] == value

    # Response should contain expected data
    response_args, _ = flow_response.send.call_args
    response = response_args[0]
    assert response.error is None
    assert "125.50" in response.data
|
||||
267
tests/integration/test_tool_group_integration.py
Normal file
267
tests/integration/test_tool_group_integration.py
Normal file
|
|
@ -0,0 +1,267 @@
|
|||
"""
|
||||
Integration tests for the tool group system.
|
||||
|
||||
Tests the complete workflow of tool filtering and execution logic.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import json
|
||||
import sys
|
||||
import os
|
||||
from unittest.mock import Mock, AsyncMock, patch
|
||||
|
||||
# Add trustgraph paths for imports
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..', 'trustgraph-base'))
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..', 'trustgraph-flow'))
|
||||
|
||||
from trustgraph.agent.tool_filter import filter_tools_by_group_and_state, get_next_state, validate_tool_config
|
||||
|
||||
|
||||
@pytest.fixture
def sample_tools():
    """Provide four mock tools whose configs vary in group and state settings.

    Each tool is a ``Mock`` exposing only a ``config`` dict. The configs
    use the keys ``group``, ``state`` and ``applicable-states``; not every
    tool defines all three.
    """
    knowledge_query = Mock(config={
        'group': ['read-only', 'knowledge', 'basic'],
        'state': 'analysis',
        'applicable-states': ['undefined', 'research'],
    })

    # Declares no 'state' key
    graph_update = Mock(config={
        'group': ['write', 'knowledge', 'admin'],
        'applicable-states': ['analysis', 'modification'],
    })

    # No applicable-states = available in all states
    text_completion = Mock(config={
        'group': ['read-only', 'text', 'basic'],
        'state': 'undefined',
    })

    complex_analysis = Mock(config={
        'group': ['advanced', 'compute', 'expensive'],
        'state': 'results',
        'applicable-states': ['analysis'],
    })

    return {
        'knowledge_query': knowledge_query,
        'graph_update': graph_update,
        'text_completion': text_completion,
        'complex_analysis': complex_analysis,
    }
|
||||
|
||||
|
||||
class TestToolGroupFiltering:
    """Integration scenarios for filtering tools by group and current state."""

    def test_basic_group_filtering(self, sample_tools):
        """Tools are kept only if they share a requested group and fit the state."""
        available = filter_tools_by_group_and_state(
            sample_tools,
            ['read-only', 'knowledge'],
            'undefined'
        )

        # knowledge_query: read-only + knowledge, applicable in 'undefined'
        # text_completion: read-only, declares no applicable-states
        for name in ('knowledge_query', 'text_completion'):
            assert name in available

        # graph_update shares the 'knowledge' group but its applicable-states
        # do not include 'undefined'; complex_analysis shares no requested
        # group and is only applicable in 'analysis'
        for name in ('graph_update', 'complex_analysis'):
            assert name not in available

    def test_state_based_filtering(self, sample_tools):
        """The current state restricts which matching tools are returned."""
        available = filter_tools_by_group_and_state(
            sample_tools,
            ['advanced', 'compute'],
            'analysis'
        )

        # The only tool with advanced/compute groups is applicable in 'analysis'
        assert 'complex_analysis' in available

        # knowledge_query is not applicable in 'analysis'; graph_update and
        # text_completion carry neither 'advanced' nor 'compute'
        for name in ('knowledge_query', 'graph_update', 'text_completion'):
            assert name not in available

    def test_state_transition_handling(self, sample_tools):
        """A tool's configured 'state' is reported as the next state."""
        # knowledge_query declares state 'analysis'
        assert get_next_state(sample_tools['knowledge_query'], 'undefined') == 'analysis'

        # text_completion declares state 'undefined', so starting from
        # 'research' the next state is expected to be 'undefined'
        assert get_next_state(sample_tools['text_completion'], 'research') == 'undefined'

    def test_wildcard_group_access(self, sample_tools):
        """The '*' group matches every tool; state filtering still applies."""
        available = filter_tools_by_group_and_state(
            sample_tools,
            ['*'],  # Wildcard access
            'undefined'
        )

        # Usable in 'undefined' (explicitly, or via no applicable-states)
        for name in ('knowledge_query', 'text_completion'):
            assert name in available

        # Their applicable-states do not include 'undefined'
        for name in ('graph_update', 'complex_analysis'):
            assert name not in available

    def test_no_matching_tools(self, sample_tools):
        """Requesting a group no tool belongs to yields an empty result."""
        available = filter_tools_by_group_and_state(
            sample_tools,
            ['nonexistent-group'],
            'undefined'
        )

        assert len(available) == 0

    def test_default_group_behavior(self):
        """Passing no groups selects only tools that declare no group."""
        ungrouped_tool = Mock(config={})  # No group = default group
        admin_tool = Mock(config={'group': ['admin']})
        tools = {
            'default_tool': ungrouped_tool,
            'admin_tool': admin_tool,
        }

        # None for the requested groups (should default to ["default"])
        available = filter_tools_by_group_and_state(tools, None, 'undefined')

        assert 'default_tool' in available
        assert 'admin_tool' not in available
|
||||
|
||||
|
||||
class TestToolConfigurationValidation:
    """Checks for validate_tool_config with group and state metadata."""

    def test_tool_config_validation_invalid(self):
        """A string-valued 'group' field is rejected with a ValueError."""
        bad_config = {
            "name": "invalid_tool",
            "description": "Invalid tool",
            "type": "text-completion",
            "group": "not-a-list",  # a list is expected here
        }

        with pytest.raises(ValueError, match="'group' field must be a list"):
            validate_tool_config(bad_config)

    def test_tool_config_validation_valid(self):
        """A config with list-valued group and state fields validates cleanly."""
        good_config = {
            "name": "valid_tool",
            "description": "Valid tool",
            "type": "text-completion",
            "group": ["read-only", "text"],
            "state": "analysis",
            "applicable-states": ["undefined", "research"],
        }

        # Passing means no exception is raised
        validate_tool_config(good_config)

    def test_kebab_case_field_names(self):
        """The kebab-case 'applicable-states' key is validated and honoured."""
        config = {
            "name": "test_tool",
            "group": ["basic"],
            "applicable-states": ["undefined", "analysis"],  # kebab-case
        }

        # Validation must accept the kebab-case key
        validate_tool_config(config)

        # Filtering must read the same key: 'analysis' is listed, so the
        # tool is expected in the result
        tools = {'test_tool': Mock(config=config)}
        available = filter_tools_by_group_and_state(
            tools,
            ['basic'],
            'analysis'
        )

        assert 'test_tool' in available
|
||||
|
||||
|
||||
class TestCompleteWorkflow:
    """Multi-step scenarios combining filtering with state transitions."""

    def test_research_analysis_workflow(self, sample_tools):
        """Walk through undefined -> analysis -> results using two tools."""
        # Step 1: research phase, starting in the 'undefined' state
        research_tools = filter_tools_by_group_and_state(
            sample_tools,
            ['read-only', 'knowledge'],
            'undefined'
        )

        for name in ('knowledge_query', 'text_completion'):
            assert name in research_tools
        assert 'complex_analysis' not in research_tools  # Not available in undefined

        # Simulate running knowledge_query: its config moves us to 'analysis'
        state_after_research = get_next_state(
            research_tools['knowledge_query'], 'undefined'
        )
        assert state_after_research == 'analysis'

        # Step 2: analysis phase ('text' is requested so text_completion matches)
        analysis_tools = filter_tools_by_group_and_state(
            sample_tools,
            ['advanced', 'compute', 'text'],
            'analysis'
        )

        for name in ('complex_analysis', 'text_completion'):
            assert name in analysis_tools
        assert 'knowledge_query' not in analysis_tools  # Not available in analysis

        # Simulate running complex_analysis: its config moves us to 'results'
        state_after_analysis = get_next_state(
            analysis_tools['complex_analysis'], 'analysis'
        )
        assert state_after_analysis == 'results'

    def test_multi_tenant_scenario(self, sample_tools):
        """Two users with different groups and states see different tools."""
        # User A: read-only permissions in the 'undefined' state
        user_a_tools = filter_tools_by_group_and_state(
            sample_tools,
            ['read-only'],
            'undefined'
        )

        # Both read-only tools are usable in 'undefined'
        for name in ('knowledge_query', 'text_completion'):
            assert name in user_a_tools

        # Neither graph_update nor complex_analysis carries 'read-only'
        for name in ('graph_update', 'complex_analysis'):
            assert name not in user_a_tools

        # User B: write/admin permissions in the 'analysis' state
        user_b_tools = filter_tools_by_group_and_state(
            sample_tools,
            ['write', 'admin'],
            'analysis'
        )

        # graph_update has write + admin and is applicable in 'analysis'
        assert 'graph_update' in user_b_tools

        # complex_analysis and text_completion have neither 'write' nor
        # 'admin'; knowledge_query is also not applicable in 'analysis'
        for name in ('complex_analysis', 'knowledge_query', 'text_completion'):
            assert name not in user_b_tools
|
||||
Loading…
Add table
Add a link
Reference in a new issue