Commit 2f7fddd206 (parent 9c7a070681): Test suite executed from CI pipeline (#433)

* Test strategy & test cases
* Unit tests
* Integration tests

101 changed files with 17,811 additions and 1 deletion

tests/integration/README.md (new file, 269 lines)
@@ -0,0 +1,269 @@
# Integration Test Pattern for TrustGraph

This directory contains integration tests that verify the coordination between multiple TrustGraph services and components, following the patterns outlined in [TEST_STRATEGY.md](../../TEST_STRATEGY.md).

## Integration Test Approach

Integration tests focus on **service-to-service communication patterns** and **end-to-end message flows** while still using mocks for external infrastructure.

### Key Principles

1. **Test Service Coordination**: Verify that services work together correctly
2. **Mock External Dependencies**: Use mocks for databases, APIs, and infrastructure
3. **Real Business Logic**: Exercise actual service logic and data transformations
4. **Error Propagation**: Test how errors flow through the system
5. **Configuration Testing**: Verify services respond correctly to different configurations

## Test Structure

### Fixtures (conftest.py)

Common fixtures for integration tests:

- `mock_pulsar_client`: Mock Pulsar messaging client
- `mock_flow_context`: Mock flow context for service coordination
- `integration_config`: Standard configuration for integration tests
- `sample_documents`: Test document collections
- `sample_embeddings`: Test embedding vectors
- `sample_queries`: Test query sets
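
As a minimal sketch of how these fixtures compose in a test (illustrative only; the assertions target the fixture data itself rather than a real service):

```python
import pytest

@pytest.mark.integration
@pytest.mark.asyncio
async def test_uses_shared_fixtures(integration_config, sample_documents, sample_queries):
    """Illustrative: exercise the shared fixtures directly."""
    # The config fixture drives scenario parameters such as doc_limit
    assert integration_config["doc_limit"] > 0

    # Data fixtures provide realistic, reusable inputs
    assert all(doc["collection"] == "ml_knowledge" for doc in sample_documents)
    assert len(sample_queries) > 0
```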

### Test Patterns

#### 1. End-to-End Flow Testing

```python
@pytest.mark.integration
@pytest.mark.asyncio
async def test_service_end_to_end_flow(self, service_instance, mock_clients):
    """Test complete service pipeline from input to output"""
    # Arrange - Set up realistic test data
    # Act - Execute the full service workflow
    # Assert - Verify coordination between all components
```

#### 2. Error Propagation Testing

```python
@pytest.mark.integration
@pytest.mark.asyncio
async def test_service_error_handling(self, service_instance, mock_clients):
    """Test how errors propagate through service coordination"""
    # Arrange - Set up failure scenarios
    # Act - Execute service with failing dependency
    # Assert - Verify proper error handling and cleanup
```

#### 3. Configuration Testing

```python
@pytest.mark.integration
@pytest.mark.asyncio
async def test_service_configuration_scenarios(self, service_instance):
    """Test service behavior with different configurations"""
    # Test multiple configuration scenarios
    # Verify service adapts correctly to each configuration
```

## Running Integration Tests

### Run All Integration Tests

```bash
pytest tests/integration/ -m integration
```

### Run a Specific Test

```bash
pytest tests/integration/test_document_rag_integration.py::TestDocumentRagIntegration::test_document_rag_end_to_end_flow -v
```

### Run with Coverage (Coverage Threshold Disabled)

```bash
pytest tests/integration/ -m integration --cov=trustgraph --cov-fail-under=0
```

### Run Slow Tests

```bash
pytest tests/integration/ -m "integration and slow"
```

### Skip Slow Tests

```bash
pytest tests/integration/ -m "integration and not slow"
```

## Examples: Integration Test Implementations

### 1. Document RAG Integration Test

The `test_document_rag_integration.py` file demonstrates the core integration test pattern.

#### What It Tests

- **Service Coordination**: Embeddings → Document Retrieval → Prompt Generation
- **Error Handling**: Failure scenarios for each service dependency
- **Configuration**: Different document limits, users, and collections
- **Performance**: Large document set handling

#### Key Features

- **Realistic Data Flow**: Uses actual service logic with mocked dependencies
- **Multiple Scenarios**: Success, failure, and edge cases
- **Verbose Logging**: Tests logging functionality
- **Multi-User Support**: Tests user and collection isolation

#### Test Coverage

- ✅ End-to-end happy path
- ✅ No documents found scenario
- ✅ Service failure scenarios (embeddings, documents, prompt)
- ✅ Configuration variations
- ✅ Multi-user isolation
- ✅ Performance testing
- ✅ Verbose logging

### 2. Text Completion Integration Test

The `test_text_completion_integration.py` file demonstrates external API integration testing.

#### What It Tests

- **External API Integration**: OpenAI API connectivity and authentication
- **Rate Limiting**: Proper handling of API rate limits and retries
- **Error Handling**: API failures, connection timeouts, and error propagation
- **Token Tracking**: Accurate input/output token counting and metrics
- **Configuration**: Different model parameters and settings
- **Concurrency**: Multiple simultaneous API requests

#### Key Features

- **Realistic Mock Responses**: Uses actual OpenAI API response structures
- **Authentication Testing**: API key validation and base URL configuration
- **Error Scenarios**: Rate limits, connection failures, invalid requests
- **Performance Metrics**: Timing and token usage validation
- **Model Flexibility**: Tests different GPT models and parameters

#### Test Coverage

- ✅ Successful text completion generation
- ✅ Multiple model configurations (GPT-3.5, GPT-4, GPT-4-turbo)
- ✅ Rate limit handling (RateLimitError → TooManyRequests)
- ✅ API error handling and propagation
- ✅ Token counting accuracy
- ✅ Prompt construction and parameter validation
- ✅ Authentication patterns and API key validation
- ✅ Concurrent request processing
- ✅ Response content extraction and validation
- ✅ Performance timing measurements
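
As a sketch of how a test can reuse the `mock_openai_response` fixture from `conftest.py` (the `chat_completion` method name here is a hypothetical stand-in, not the project's actual client API):

```python
import pytest
from unittest.mock import AsyncMock

@pytest.mark.integration
@pytest.mark.asyncio
async def test_completion_uses_canned_response(mock_openai_response):
    """Illustrative: a mocked client returns the realistic response structure."""
    client = AsyncMock()
    client.chat_completion.return_value = mock_openai_response  # hypothetical method

    response = await client.chat_completion(prompt="What is AI?")

    # Assertions can then target the realistic OpenAI response shape
    assert response["choices"][0]["message"]["content"]
    assert response["usage"]["total_tokens"] == 150
```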

### 3. Agent Manager Integration Test

The `test_agent_manager_integration.py` file demonstrates complex service coordination testing.

#### What It Tests

- **ReAct Pattern**: Think-Act-Observe cycles with multi-step reasoning
- **Tool Coordination**: Selection and execution of different tools (knowledge query, text completion, MCP tools)
- **Conversation State**: Management of conversation history and context
- **Multi-Service Integration**: Coordination between prompt, graph RAG, and tool services
- **Error Handling**: Tool failures, unknown tools, and error propagation
- **Configuration Management**: Dynamic tool loading and configuration

#### Key Features

- **Complex Coordination**: Tests agent reasoning with multiple tool options
- **Stateful Processing**: Maintains conversation history across interactions
- **Dynamic Tool Selection**: Tests tool selection based on context and reasoning
- **Callback Pattern**: Tests think/observe callback mechanisms
- **JSON Serialization**: Handles complex data structures in prompts
- **Performance Testing**: Large conversation history handling

#### Test Coverage

- ✅ Basic reasoning cycle with tool selection
- ✅ Final answer generation (ending the ReAct cycle)
- ✅ Full ReAct cycle with tool execution
- ✅ Conversation history management
- ✅ Multiple tool coordination and selection
- ✅ Tool argument validation and processing
- ✅ Error handling (unknown tools, execution failures)
- ✅ Context integration and additional prompting
- ✅ Empty tool configuration handling
- ✅ Tool response processing and cleanup
- ✅ Performance with large conversation history
- ✅ JSON serialization in complex prompts
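
For orientation, one iteration of the Think-Act-Observe cycle these tests exercise looks roughly like the sketch below (`run_tool` is a hypothetical dispatch helper; the real entry point is `AgentManager.react`, whose exact signature appears in the test file):

```python
# Sketch of a single ReAct iteration, not the actual implementation
async def react_once(agent, question, history, think, observe, context):
    step = await agent.reason(question, history, context)   # Think
    await think(step.thought)
    if isinstance(step, Final):                             # a final answer ends the cycle
        return step
    step.observation = await run_tool(step, context)        # Act (hypothetical dispatch)
    await observe(step.observation)                         # Observe
    return step
```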

### 4. Knowledge Graph Extract → Store Pipeline Integration Test

The `test_kg_extract_store_integration.py` file demonstrates multi-stage pipeline testing.

#### What It Tests

- **Text-to-Graph Transformation**: Complete pipeline from text chunks to graph triples
- **Entity Extraction**: Definition extraction with proper URI generation
- **Relationship Extraction**: Subject-predicate-object relationship extraction
- **Graph Database Integration**: Storage coordination with the Cassandra knowledge store
- **Data Validation**: Entity filtering, validation, and consistency checks
- **Pipeline Coordination**: Multi-stage processing with proper data flow

#### Key Features

- **Multi-Stage Pipeline**: Tests definitions → relationships → storage coordination
- **Graph Data Structures**: RDF triples, entity contexts, and graph embeddings
- **URI Generation**: Consistent entity URI creation across pipeline stages
- **Data Transformation**: Complex text analysis to structured graph data
- **Batch Processing**: Large document set processing performance
- **Error Resilience**: Graceful handling of extraction failures

#### Test Coverage

- ✅ Definitions extraction pipeline (text → entities + definitions)
- ✅ Relationships extraction pipeline (text → subject-predicate-object)
- ✅ URI generation consistency between processors
- ✅ Triple generation from definitions and relationships
- ✅ Knowledge store integration (triples and embeddings storage)
- ✅ End-to-end pipeline coordination
- ✅ Error handling in extraction services
- ✅ Empty and invalid extraction results handling
- ✅ Entity filtering and validation
- ✅ Large batch processing performance
- ✅ Metadata propagation through pipeline stages

## Best Practices

### Test Organization

- Group related tests in classes
- Use descriptive test names that explain the scenario
- Follow the Arrange-Act-Assert pattern
- Use appropriate pytest markers (`@pytest.mark.integration`, `@pytest.mark.slow`)

### Mock Strategy

- Mock external services (databases, APIs, message brokers)
- Use real service logic and data transformations
- Create realistic mock responses that match actual service behavior
- Reset mocks between tests to ensure isolation (see the sketch below)
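
One way to enforce that last point is an autouse fixture that clears call history around every test. A minimal sketch, assuming a mock shared at broader than function scope (function-scoped fixtures such as the ones in this directory's `conftest.py` are recreated per test anyway):

```python
import pytest

@pytest.fixture(autouse=True)
def reset_shared_mocks(mock_pulsar_client):
    """Illustrative: reset call history on shared mocks after each test."""
    yield
    # reset_mock() clears recorded calls so assertions in one test
    # cannot leak into the next
    mock_pulsar_client.reset_mock()
```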

### Test Data

- Use realistic test data that reflects actual usage patterns
- Create reusable fixtures for common test scenarios
- Test with various data sizes and edge cases
- Include both success and failure scenarios

### Error Testing

- Test each dependency failure scenario
- Verify proper error propagation and cleanup
- Test timeout and retry mechanisms
- Validate error response formats

### Performance Testing

- Mark performance tests with `@pytest.mark.slow` (see the timing sketch below)
- Test with realistic data volumes
- Set reasonable performance expectations
- Monitor resource usage during tests
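
A minimal sketch of the timing pattern the performance tests in this directory use (`service_under_test` is a hypothetical fixture standing in for whatever service the test drives):

```python
import time
import pytest

@pytest.mark.integration
@pytest.mark.slow
@pytest.mark.asyncio
async def test_large_input_completes_quickly(service_under_test):
    start = time.time()
    await service_under_test.run()   # hypothetical entry point
    elapsed = time.time() - start
    assert elapsed < 5.0             # generous bound to avoid flaky CI
```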

## Adding New Integration Tests

1. **Identify Service Dependencies**: Map out which services your target service coordinates with
2. **Create Mock Fixtures**: Set up mocks for each dependency in conftest.py
3. **Design Test Scenarios**: Plan happy path, error cases, and edge conditions
4. **Implement Tests**: Follow the established patterns in this directory
5. **Add Documentation**: Update this README with your new test patterns

## Test Markers

- `@pytest.mark.integration`: Marks tests as integration tests
- `@pytest.mark.slow`: Marks tests that take longer to run
- `@pytest.mark.asyncio`: Required for async test functions
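
For pytest to accept the custom `integration` and `slow` markers without warnings, they have to be registered. A sketch of doing so programmatically (this repo may instead register them in `pytest.ini` or `pyproject.toml`; that detail is an assumption here):

```python
# conftest.py (illustrative registration hook)
def pytest_configure(config):
    config.addinivalue_line("markers", "integration: marks tests as integration tests")
    config.addinivalue_line("markers", "slow: marks tests that take longer to run")
```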

## Future Enhancements

- Add tests with real test containers for database integration
- Implement contract testing for service interfaces
- Add performance benchmarking for critical paths
- Create integration test templates for common service patterns

tests/integration/__init__.py (new empty file)

tests/integration/conftest.py (new file, 386 lines)
@@ -0,0 +1,386 @@
"""
|
||||
Shared fixtures and configuration for integration tests
|
||||
|
||||
This file provides common fixtures and test configuration for integration tests.
|
||||
Following the TEST_STRATEGY.md patterns for integration testing.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
|
||||


@pytest.fixture
def mock_pulsar_client():
    """Mock Pulsar client for integration tests"""
    client = MagicMock()
    client.create_producer.return_value = AsyncMock()
    client.subscribe.return_value = AsyncMock()
    return client


@pytest.fixture
def mock_flow_context():
    """Mock flow context for testing service coordination"""
    context = MagicMock()

    # Mock flow producers/consumers on the object returned when the context is called
    context.return_value.send = AsyncMock()
    context.return_value.receive = AsyncMock()

    return context


@pytest.fixture
def integration_config():
    """Common configuration for integration tests"""
    return {
        "pulsar_host": "localhost",
        "pulsar_port": 6650,
        "test_timeout": 30.0,
        "max_retries": 3,
        "doc_limit": 10,
        "embedding_dim": 5,
    }


@pytest.fixture
def sample_documents():
    """Sample document collection for testing"""
    return [
        {
            "id": "doc1",
            "content": "Machine learning is a subset of artificial intelligence that focuses on algorithms that learn from data.",
            "collection": "ml_knowledge",
            "user": "test_user"
        },
        {
            "id": "doc2",
            "content": "Deep learning uses neural networks with multiple layers to model complex patterns in data.",
            "collection": "ml_knowledge",
            "user": "test_user"
        },
        {
            "id": "doc3",
            "content": "Supervised learning algorithms learn from labeled training data to make predictions on new data.",
            "collection": "ml_knowledge",
            "user": "test_user"
        }
    ]


@pytest.fixture
def sample_embeddings():
    """Sample embedding vectors for testing"""
    return [
        [0.1, 0.2, 0.3, 0.4, 0.5],
        [0.6, 0.7, 0.8, 0.9, 1.0],
        [0.2, 0.3, 0.4, 0.5, 0.6],
        [0.7, 0.8, 0.9, 1.0, 0.1],
        [0.3, 0.4, 0.5, 0.6, 0.7]
    ]
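

# Illustrative helper, not part of the original fixtures: retrieval services
# rank documents by vector similarity, and plain cosine similarity like the
# function below is the usual measure. This is a sketch for orientation only;
# the real TrustGraph retrieval code may compute similarity elsewhere (e.g. in
# the vector store itself).
def cosine_similarity(a, b):
    """Cosine similarity between two equal-length, non-zero vectors."""
    import math
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm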


@pytest.fixture
def sample_queries():
    """Sample queries for testing"""
    return [
        "What is machine learning?",
        "How does deep learning work?",
        "Explain supervised learning",
        "What are neural networks?",
        "How do algorithms learn from data?"
    ]


@pytest.fixture
def sample_text_completion_requests():
    """Sample text completion requests for testing"""
    return [
        {
            "system": "You are a helpful assistant.",
            "prompt": "What is artificial intelligence?",
            "expected_keywords": ["artificial intelligence", "AI", "machine learning"]
        },
        {
            "system": "You are a technical expert.",
            "prompt": "Explain neural networks",
            "expected_keywords": ["neural networks", "neurons", "layers"]
        },
        {
            "system": "You are a teacher.",
            "prompt": "What is supervised learning?",
            "expected_keywords": ["supervised learning", "training", "labels"]
        }
    ]


@pytest.fixture
def mock_openai_response():
    """Mock OpenAI API response structure"""
    return {
        "id": "chatcmpl-test123",
        "object": "chat.completion",
        "created": 1234567890,
        "model": "gpt-3.5-turbo",
        "choices": [
            {
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": "This is a test response from the AI model."
                },
                "finish_reason": "stop"
            }
        ],
        "usage": {
            "prompt_tokens": 50,
            "completion_tokens": 100,
            "total_tokens": 150
        }
    }


@pytest.fixture
def text_completion_configs():
    """Various text completion configurations for testing"""
    return [
        {
            "model": "gpt-3.5-turbo",
            "temperature": 0.0,
            "max_output": 1024,
            "description": "Conservative settings"
        },
        {
            "model": "gpt-4",
            "temperature": 0.7,
            "max_output": 2048,
            "description": "Balanced settings"
        },
        {
            "model": "gpt-4-turbo",
            "temperature": 1.0,
            "max_output": 4096,
            "description": "Creative settings"
        }
    ]


@pytest.fixture
def sample_agent_tools():
    """Sample agent tools configuration for testing"""
    return {
        "knowledge_query": {
            "name": "knowledge_query",
            "description": "Query the knowledge graph for information",
            "type": "knowledge-query",
            "arguments": [
                {
                    "name": "question",
                    "type": "string",
                    "description": "The question to ask the knowledge graph"
                }
            ]
        },
        "text_completion": {
            "name": "text_completion",
            "description": "Generate text completion using LLM",
            "type": "text-completion",
            "arguments": [
                {
                    "name": "question",
                    "type": "string",
                    "description": "The question to ask the LLM"
                }
            ]
        },
        "web_search": {
            "name": "web_search",
            "description": "Search the web for information",
            "type": "mcp-tool",
            "arguments": [
                {
                    "name": "query",
                    "type": "string",
                    "description": "The search query"
                }
            ]
        }
    }


@pytest.fixture
def sample_agent_requests():
    """Sample agent requests for testing"""
    return [
        {
            "question": "What is machine learning?",
            "plan": "",
            "state": "",
            "history": [],
            "expected_tool": "knowledge_query"
        },
        {
            "question": "Can you explain neural networks in simple terms?",
            "plan": "",
            "state": "",
            "history": [],
            "expected_tool": "text_completion"
        },
        {
            "question": "Search for the latest AI research papers",
            "plan": "",
            "state": "",
            "history": [],
            "expected_tool": "web_search"
        }
    ]


@pytest.fixture
def sample_agent_responses():
    """Sample agent responses for testing"""
    return [
        {
            "thought": "I need to search for information about machine learning",
            "action": "knowledge_query",
            "arguments": {"question": "What is machine learning?"}
        },
        {
            "thought": "I can provide a direct answer about neural networks",
            "final-answer": "Neural networks are computing systems inspired by biological neural networks."
        },
        {
            "thought": "I should search the web for recent research",
            "action": "web_search",
            "arguments": {"query": "latest AI research papers 2024"}
        }
    ]


@pytest.fixture
def sample_conversation_history():
    """Sample conversation history for testing"""
    return [
        {
            "thought": "I need to search for basic information first",
            "action": "knowledge_query",
            "arguments": {"question": "What is artificial intelligence?"},
            "observation": "AI is the simulation of human intelligence in machines."
        },
        {
            "thought": "Now I can provide more specific information",
            "action": "text_completion",
            "arguments": {"question": "Explain machine learning within AI"},
            "observation": "Machine learning is a subset of AI that enables computers to learn from data."
        }
    ]


@pytest.fixture
def sample_kg_extraction_data():
    """Sample knowledge graph extraction data for testing"""
    return {
        "text_chunks": [
            "Machine Learning is a subset of Artificial Intelligence that enables computers to learn from data.",
            "Neural Networks are computing systems inspired by biological neural networks.",
            "Deep Learning uses neural networks with multiple layers to model complex patterns."
        ],
        "expected_entities": [
            "Machine Learning",
            "Artificial Intelligence",
            "Neural Networks",
            "Deep Learning"
        ],
        "expected_relationships": [
            {
                "subject": "Machine Learning",
                "predicate": "is_subset_of",
                "object": "Artificial Intelligence"
            },
            {
                "subject": "Deep Learning",
                "predicate": "uses",
                "object": "Neural Networks"
            }
        ]
    }


@pytest.fixture
def sample_kg_definitions():
    """Sample knowledge graph definitions for testing"""
    return [
        {
            "entity": "Machine Learning",
            "definition": "A subset of artificial intelligence that enables computers to learn from data without explicit programming."
        },
        {
            "entity": "Artificial Intelligence",
            "definition": "The simulation of human intelligence in machines that are programmed to think and act like humans."
        },
        {
            "entity": "Neural Networks",
            "definition": "Computing systems inspired by biological neural networks that process information using interconnected nodes."
        },
        {
            "entity": "Deep Learning",
            "definition": "A subset of machine learning that uses neural networks with multiple layers to model complex patterns in data."
        }
    ]


@pytest.fixture
def sample_kg_relationships():
    """Sample knowledge graph relationships for testing"""
    return [
        {
            "subject": "Machine Learning",
            "predicate": "is_subset_of",
            "object": "Artificial Intelligence",
            "object-entity": True
        },
        {
            "subject": "Deep Learning",
            "predicate": "is_subset_of",
            "object": "Machine Learning",
            "object-entity": True
        },
        {
            "subject": "Neural Networks",
            "predicate": "is_used_in",
            "object": "Deep Learning",
            "object-entity": True
        },
        {
            "subject": "Machine Learning",
            "predicate": "processes",
            "object": "data patterns",
            "object-entity": False
        }
    ]


@pytest.fixture
def sample_kg_triples():
    """Sample knowledge graph triples for testing"""
    return [
        {
            "subject": "http://trustgraph.ai/e/machine-learning",
            "predicate": "http://www.w3.org/2000/01/rdf-schema#label",
            "object": "Machine Learning"
        },
        {
            "subject": "http://trustgraph.ai/e/machine-learning",
            "predicate": "http://trustgraph.ai/definition",
            "object": "A subset of artificial intelligence that enables computers to learn from data."
        },
        {
            "subject": "http://trustgraph.ai/e/machine-learning",
            "predicate": "http://trustgraph.ai/e/is_subset_of",
            "object": "http://trustgraph.ai/e/artificial-intelligence"
        }
    ]
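

# Illustrative helper, not part of the original fixtures: the triple URIs
# above follow a simple slug convention ("Machine Learning" ->
# "http://trustgraph.ai/e/machine-learning"). This is a sketch of the kind of
# helper that would produce them; the actual URI logic lives in the extraction
# processors and may differ.
def entity_uri(label):
    """Slugify an entity label into a trustgraph.ai entity URI."""
    import urllib.parse
    slug = label.strip().lower().replace(" ", "-")
    return "http://trustgraph.ai/e/" + urllib.parse.quote(slug)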


# Test markers for integration tests.
# Note: pytest only honours a module-level ``pytestmark`` in test modules, not
# in conftest.py, which is why each test module below also applies
# @pytest.mark.integration explicitly.
pytestmark = pytest.mark.integration

tests/integration/test_agent_manager_integration.py (new file, 532 lines)
@@ -0,0 +1,532 @@
"""
|
||||
Integration tests for Agent Manager (ReAct Pattern) Service
|
||||
|
||||
These tests verify the end-to-end functionality of the Agent Manager service,
|
||||
testing the ReAct pattern (Think-Act-Observe), tool coordination, multi-step reasoning,
|
||||
and conversation state management.
|
||||
Following the TEST_STRATEGY.md approach for integration testing.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import json
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
from trustgraph.agent.react.agent_manager import AgentManager
|
||||
from trustgraph.agent.react.tools import KnowledgeQueryImpl, TextCompletionImpl, McpToolImpl
|
||||
from trustgraph.agent.react.types import Action, Final, Tool, Argument
|
||||
from trustgraph.schema import AgentRequest, AgentResponse, AgentStep, Error
|
||||
|
||||
|
||||


@pytest.mark.integration
class TestAgentManagerIntegration:
    """Integration tests for Agent Manager ReAct pattern coordination"""

    @pytest.fixture
    def mock_flow_context(self):
        """Mock flow context for service coordination"""
        context = MagicMock()

        # Mock prompt client; the "prompt-request" service serves both the
        # agent_react reasoning calls and text-completion question calls, so
        # no separate text completion client is routed here
        prompt_client = AsyncMock()
        prompt_client.agent_react.return_value = {
            "thought": "I need to search for information about machine learning",
            "action": "knowledge_query",
            "arguments": {"question": "What is machine learning?"}
        }
        prompt_client.question.return_value = "Machine learning involves algorithms that improve through experience."

        # Mock graph RAG client
        graph_rag_client = AsyncMock()
        graph_rag_client.rag.return_value = "Machine learning is a subset of AI that enables computers to learn from data."

        # Mock MCP tool client
        mcp_tool_client = AsyncMock()
        mcp_tool_client.invoke.return_value = "Tool execution successful"

        # Configure context to return the appropriate client for each service
        def context_router(service_name):
            if service_name == "prompt-request":
                return prompt_client
            elif service_name == "graph-rag-request":
                return graph_rag_client
            elif service_name == "mcp-tool-request":
                return mcp_tool_client
            else:
                return AsyncMock()

        context.side_effect = context_router
        return context

    @pytest.fixture
    def sample_tools(self):
        """Sample tool configuration for testing"""
        return {
            "knowledge_query": Tool(
                name="knowledge_query",
                description="Query the knowledge graph for information",
                arguments={
                    "question": Argument(
                        name="question",
                        type="string",
                        description="The question to ask the knowledge graph"
                    )
                },
                implementation=KnowledgeQueryImpl,
                config={}
            ),
            "text_completion": Tool(
                name="text_completion",
                description="Generate text completion using LLM",
                arguments={
                    "question": Argument(
                        name="question",
                        type="string",
                        description="The question to ask the LLM"
                    )
                },
                implementation=TextCompletionImpl,
                config={}
            ),
            "web_search": Tool(
                name="web_search",
                description="Search the web for information",
                arguments={
                    "query": Argument(
                        name="query",
                        type="string",
                        description="The search query"
                    )
                },
                implementation=lambda context: AsyncMock(invoke=AsyncMock(return_value="Web search results")),
                config={}
            )
        }

    @pytest.fixture
    def agent_manager(self, sample_tools):
        """Create agent manager with sample tools"""
        return AgentManager(
            tools=sample_tools,
            additional_context="You are a helpful AI assistant with access to knowledge and tools."
        )

    @pytest.mark.asyncio
    async def test_agent_manager_reasoning_cycle(self, agent_manager, mock_flow_context):
        """Test basic reasoning cycle with tool selection"""
        # Arrange
        question = "What is machine learning?"
        history = []

        # Act
        action = await agent_manager.reason(question, history, mock_flow_context)

        # Assert
        assert isinstance(action, Action)
        assert action.thought == "I need to search for information about machine learning"
        assert action.name == "knowledge_query"
        assert action.arguments == {"question": "What is machine learning?"}
        assert action.observation == ""

        # Verify prompt client was called correctly
        prompt_client = mock_flow_context("prompt-request")
        prompt_client.agent_react.assert_called_once()

        # Verify the prompt variables passed to agent_react
        call_args = prompt_client.agent_react.call_args
        variables = call_args[0][0]
        assert variables["question"] == question
        assert len(variables["tools"]) == 3  # knowledge_query, text_completion, web_search
        assert variables["context"] == "You are a helpful AI assistant with access to knowledge and tools."

    @pytest.mark.asyncio
    async def test_agent_manager_final_answer(self, agent_manager, mock_flow_context):
        """Test agent manager returning a final answer"""
        # Arrange
        mock_flow_context("prompt-request").agent_react.return_value = {
            "thought": "I have enough information to answer the question",
            "final-answer": "Machine learning is a field of AI that enables computers to learn from data."
        }

        question = "What is machine learning?"
        history = []

        # Act
        action = await agent_manager.reason(question, history, mock_flow_context)

        # Assert
        assert isinstance(action, Final)
        assert action.thought == "I have enough information to answer the question"
        assert action.final == "Machine learning is a field of AI that enables computers to learn from data."

    @pytest.mark.asyncio
    async def test_agent_manager_react_with_tool_execution(self, agent_manager, mock_flow_context):
        """Test full ReAct cycle with tool execution"""
        # Arrange
        question = "What is machine learning?"
        history = []

        think_callback = AsyncMock()
        observe_callback = AsyncMock()

        # Act
        action = await agent_manager.react(question, history, think_callback, observe_callback, mock_flow_context)

        # Assert
        assert isinstance(action, Action)
        assert action.thought == "I need to search for information about machine learning"
        assert action.name == "knowledge_query"
        assert action.arguments == {"question": "What is machine learning?"}
        assert action.observation == "Machine learning is a subset of AI that enables computers to learn from data."

        # Verify callbacks were called
        think_callback.assert_called_once_with("I need to search for information about machine learning")
        observe_callback.assert_called_once_with("Machine learning is a subset of AI that enables computers to learn from data.")

        # Verify the tool was executed
        graph_rag_client = mock_flow_context("graph-rag-request")
        graph_rag_client.rag.assert_called_once_with("What is machine learning?")

    @pytest.mark.asyncio
    async def test_agent_manager_react_with_final_answer(self, agent_manager, mock_flow_context):
        """Test ReAct cycle ending with a final answer"""
        # Arrange
        mock_flow_context("prompt-request").agent_react.return_value = {
            "thought": "I can provide a direct answer",
            "final-answer": "Machine learning is a branch of artificial intelligence."
        }

        question = "What is machine learning?"
        history = []

        think_callback = AsyncMock()
        observe_callback = AsyncMock()

        # Act
        action = await agent_manager.react(question, history, think_callback, observe_callback, mock_flow_context)

        # Assert
        assert isinstance(action, Final)
        assert action.thought == "I can provide a direct answer"
        assert action.final == "Machine learning is a branch of artificial intelligence."

        # Verify only the think callback was called (no observation for a final answer)
        think_callback.assert_called_once_with("I can provide a direct answer")
        observe_callback.assert_not_called()

    @pytest.mark.asyncio
    async def test_agent_manager_with_conversation_history(self, agent_manager, mock_flow_context):
        """Test agent manager with conversation history"""
        # Arrange
        question = "Can you tell me more about neural networks?"
        history = [
            Action(
                thought="I need to search for information about machine learning",
                name="knowledge_query",
                arguments={"question": "What is machine learning?"},
                observation="Machine learning is a subset of AI that enables computers to learn from data."
            )
        ]

        # Act
        action = await agent_manager.reason(question, history, mock_flow_context)

        # Assert
        assert isinstance(action, Action)

        # Verify history was included in the prompt variables
        prompt_client = mock_flow_context("prompt-request")
        call_args = prompt_client.agent_react.call_args
        variables = call_args[0][0]
        assert len(variables["history"]) == 1
        assert variables["history"][0]["thought"] == "I need to search for information about machine learning"
        assert variables["history"][0]["action"] == "knowledge_query"
        assert variables["history"][0]["observation"] == "Machine learning is a subset of AI that enables computers to learn from data."

    @pytest.mark.asyncio
    async def test_agent_manager_tool_selection(self, agent_manager, mock_flow_context):
        """Test agent manager selecting different tools"""
        # Test different tool selections
        tool_scenarios = [
            ("knowledge_query", "graph-rag-request"),
            ("text_completion", "prompt-request"),
        ]

        for tool_name, expected_service in tool_scenarios:
            # Arrange
            mock_flow_context("prompt-request").agent_react.return_value = {
                "thought": f"I need to use {tool_name}",
                "action": tool_name,
                "arguments": {"question": "test question"}
            }

            think_callback = AsyncMock()
            observe_callback = AsyncMock()

            # Act
            action = await agent_manager.react("test question", [], think_callback, observe_callback, mock_flow_context)

            # Assert
            assert isinstance(action, Action)
            assert action.name == tool_name

            # Verify the correct service was called
            if tool_name == "knowledge_query":
                mock_flow_context("graph-rag-request").rag.assert_called()
            elif tool_name == "text_completion":
                mock_flow_context("prompt-request").question.assert_called()

            # Reset each routed client's mock for the next iteration
            for service in ["prompt-request", "graph-rag-request", "mcp-tool-request"]:
                mock_flow_context(service).reset_mock()

    @pytest.mark.asyncio
    async def test_agent_manager_unknown_tool_error(self, agent_manager, mock_flow_context):
        """Test agent manager error handling for an unknown tool"""
        # Arrange
        mock_flow_context("prompt-request").agent_react.return_value = {
            "thought": "I need to use an unknown tool",
            "action": "unknown_tool",
            "arguments": {"param": "value"}
        }

        think_callback = AsyncMock()
        observe_callback = AsyncMock()

        # Act & Assert
        with pytest.raises(RuntimeError) as exc_info:
            await agent_manager.react("test question", [], think_callback, observe_callback, mock_flow_context)

        assert "No action for unknown_tool!" in str(exc_info.value)

    @pytest.mark.asyncio
    async def test_agent_manager_tool_execution_error(self, agent_manager, mock_flow_context):
        """Test agent manager handling tool execution errors"""
        # Arrange
        mock_flow_context("graph-rag-request").rag.side_effect = Exception("Tool execution failed")

        think_callback = AsyncMock()
        observe_callback = AsyncMock()

        # Act & Assert
        with pytest.raises(Exception) as exc_info:
            await agent_manager.react("test question", [], think_callback, observe_callback, mock_flow_context)

        assert "Tool execution failed" in str(exc_info.value)

    @pytest.mark.asyncio
    async def test_agent_manager_multiple_tools_coordination(self, agent_manager, mock_flow_context):
        """Test agent manager coordination with multiple available tools"""
        # Arrange
        question = "Find information about AI and summarize it"

        # Mock multi-step reasoning
        mock_flow_context("prompt-request").agent_react.return_value = {
            "thought": "I need to search for AI information first",
            "action": "knowledge_query",
            "arguments": {"question": "What is artificial intelligence?"}
        }

        # Act
        action = await agent_manager.reason(question, [], mock_flow_context)

        # Assert
        assert isinstance(action, Action)
        assert action.name == "knowledge_query"

        # Verify tool information was passed to the prompt
        prompt_client = mock_flow_context("prompt-request")
        call_args = prompt_client.agent_react.call_args
        variables = call_args[0][0]

        # All 3 tools should be available
        tool_names = [tool["name"] for tool in variables["tools"]]
        assert "knowledge_query" in tool_names
        assert "text_completion" in tool_names
        assert "web_search" in tool_names

    @pytest.mark.asyncio
    async def test_agent_manager_tool_argument_validation(self, agent_manager, mock_flow_context):
        """Test agent manager with various tool argument patterns"""
        # Arrange
        test_cases = [
            {
                "action": "knowledge_query",
                "arguments": {"question": "What is deep learning?"},
                "expected_service": "graph-rag-request"
            },
            {
                "action": "text_completion",
                "arguments": {"question": "Explain neural networks"},
                "expected_service": "prompt-request"
            },
            {
                "action": "web_search",
                "arguments": {"query": "latest AI research"},
                "expected_service": None  # Custom mock
            }
        ]

        for test_case in test_cases:
            # Arrange
            mock_flow_context("prompt-request").agent_react.return_value = {
                "thought": f"Using {test_case['action']}",
                "action": test_case['action'],
                "arguments": test_case['arguments']
            }

            think_callback = AsyncMock()
            observe_callback = AsyncMock()

            # Act
            action = await agent_manager.react("test", [], think_callback, observe_callback, mock_flow_context)

            # Assert
            assert isinstance(action, Action)
            assert action.name == test_case['action']
            assert action.arguments == test_case['arguments']

            # Reset each routed client's mock
            for service in ["prompt-request", "graph-rag-request", "mcp-tool-request"]:
                mock_flow_context(service).reset_mock()

    @pytest.mark.asyncio
    async def test_agent_manager_context_integration(self, agent_manager, mock_flow_context):
        """Test agent manager integration with additional context"""
        # Arrange
        agent_with_context = AgentManager(
            tools={"knowledge_query": agent_manager.tools["knowledge_query"]},
            additional_context="You are an expert in machine learning research."
        )

        question = "What are the latest developments in AI?"

        # Act
        action = await agent_with_context.reason(question, [], mock_flow_context)

        # Assert
        prompt_client = mock_flow_context("prompt-request")
        call_args = prompt_client.agent_react.call_args
        variables = call_args[0][0]

        assert variables["context"] == "You are an expert in machine learning research."
        assert variables["question"] == question

    @pytest.mark.asyncio
    async def test_agent_manager_empty_tools(self, mock_flow_context):
        """Test agent manager with no tools available"""
        # Arrange
        agent_no_tools = AgentManager(tools={}, additional_context="")

        question = "What is machine learning?"

        # Act
        action = await agent_no_tools.reason(question, [], mock_flow_context)

        # Assert
        prompt_client = mock_flow_context("prompt-request")
        call_args = prompt_client.agent_react.call_args
        variables = call_args[0][0]

        assert len(variables["tools"]) == 0
        assert variables["tool_names"] == ""

    @pytest.mark.asyncio
    async def test_agent_manager_tool_response_processing(self, agent_manager, mock_flow_context):
        """Test agent manager processing different tool response types"""
        # Arrange
        response_scenarios = [
            "Simple text response",
            "Multi-line response\nwith several lines\nof information",
            "Response with special characters: @#$%^&*()_+-=[]{}|;':\",./<>?",
            "   Response with whitespace   ",
            ""  # Empty response
        ]

        for expected_response in response_scenarios:
            # Set up the mock response
            mock_flow_context("graph-rag-request").rag.return_value = expected_response

            think_callback = AsyncMock()
            observe_callback = AsyncMock()

            # Act
            action = await agent_manager.react("test question", [], think_callback, observe_callback, mock_flow_context)

            # Assert
            assert isinstance(action, Action)
            assert action.observation == expected_response.strip()
            observe_callback.assert_called_with(expected_response.strip())

            # Reset mocks
            mock_flow_context("graph-rag-request").reset_mock()

    @pytest.mark.asyncio
    @pytest.mark.slow
    async def test_agent_manager_performance_with_large_history(self, agent_manager, mock_flow_context):
        """Test agent manager performance with large conversation history"""
        # Arrange
        large_history = [
            Action(
                thought=f"Step {i} thinking",
                name="knowledge_query",
                arguments={"question": f"Question {i}"},
                observation=f"Observation {i}"
            )
            for i in range(50)  # Large history
        ]

        question = "Final question"

        # Act
        import time
        start_time = time.time()

        action = await agent_manager.reason(question, large_history, mock_flow_context)

        end_time = time.time()
        execution_time = end_time - start_time

        # Assert
        assert isinstance(action, Action)
        assert execution_time < 5.0  # Should complete within a reasonable time

        # Verify history was processed correctly
        prompt_client = mock_flow_context("prompt-request")
        call_args = prompt_client.agent_react.call_args
        variables = call_args[0][0]
        assert len(variables["history"]) == 50

    @pytest.mark.asyncio
    async def test_agent_manager_json_serialization(self, agent_manager, mock_flow_context):
        """Test agent manager handling of JSON serialization in prompts"""
        # Arrange
        complex_history = [
            Action(
                thought="Complex thinking with special characters: \"quotes\", 'apostrophes', and symbols",
                name="knowledge_query",
                arguments={"question": "What about JSON serialization?", "complex": {"nested": "value"}},
                observation="Response with JSON: {\"key\": \"value\"}"
            )
        ]

        question = "Handle JSON properly"

        # Act
        action = await agent_manager.reason(question, complex_history, mock_flow_context)

        # Assert
        assert isinstance(action, Action)

        # Verify JSON was properly serialized in the prompt
        prompt_client = mock_flow_context("prompt-request")
        call_args = prompt_client.agent_react.call_args
        variables = call_args[0][0]

        # Should not raise JSON serialization errors
        json_str = json.dumps(variables, indent=4)
        assert len(json_str) > 0

tests/integration/test_document_rag_integration.py (new file, 309 lines)
@@ -0,0 +1,309 @@
"""
|
||||
Integration tests for DocumentRAG retrieval system
|
||||
|
||||
These tests verify the end-to-end functionality of the DocumentRAG system,
|
||||
testing the coordination between embeddings, document retrieval, and prompt services.
|
||||
Following the TEST_STRATEGY.md approach for integration testing.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
from testcontainers.compose import DockerCompose
|
||||
from trustgraph.retrieval.document_rag.document_rag import DocumentRag
|
||||
|
||||
|
||||


@pytest.mark.integration
class TestDocumentRagIntegration:
    """Integration tests for DocumentRAG system coordination"""

    @pytest.fixture
    def mock_embeddings_client(self):
        """Mock embeddings client that returns realistic vector embeddings"""
        client = AsyncMock()
        client.embed.return_value = [
            [0.1, 0.2, 0.3, 0.4, 0.5],  # Realistic 5-dimensional embedding
            [0.6, 0.7, 0.8, 0.9, 1.0]   # Second embedding for testing
        ]
        return client

    @pytest.fixture
    def mock_doc_embeddings_client(self):
        """Mock document embeddings client that returns realistic document chunks"""
        client = AsyncMock()
        client.query.return_value = [
            "Machine learning is a subset of artificial intelligence that focuses on algorithms that learn from data.",
            "Deep learning uses neural networks with multiple layers to model complex patterns in data.",
            "Supervised learning algorithms learn from labeled training data to make predictions on new data."
        ]
        return client

    @pytest.fixture
    def mock_prompt_client(self):
        """Mock prompt client that generates realistic responses"""
        client = AsyncMock()
        client.document_prompt.return_value = (
            "Machine learning is a field of artificial intelligence that enables computers to learn "
            "and improve from experience without being explicitly programmed. It uses algorithms "
            "to find patterns in data and make predictions or decisions."
        )
        return client

    @pytest.fixture
    def document_rag(self, mock_embeddings_client, mock_doc_embeddings_client, mock_prompt_client):
        """Create DocumentRag instance with mocked dependencies"""
        return DocumentRag(
            embeddings_client=mock_embeddings_client,
            doc_embeddings_client=mock_doc_embeddings_client,
            prompt_client=mock_prompt_client,
            verbose=True
        )

    @pytest.mark.asyncio
    async def test_document_rag_end_to_end_flow(self, document_rag, mock_embeddings_client,
                                                mock_doc_embeddings_client, mock_prompt_client):
        """Test complete DocumentRAG pipeline from query to response"""
        # Arrange
        query = "What is machine learning?"
        user = "test_user"
        collection = "ml_knowledge"
        doc_limit = 10

        # Act
        result = await document_rag.query(
            query=query,
            user=user,
            collection=collection,
            doc_limit=doc_limit
        )

        # Assert - Verify service coordination
        mock_embeddings_client.embed.assert_called_once_with(query)

        mock_doc_embeddings_client.query.assert_called_once_with(
            [[0.1, 0.2, 0.3, 0.4, 0.5], [0.6, 0.7, 0.8, 0.9, 1.0]],
            limit=doc_limit,
            user=user,
            collection=collection
        )

        mock_prompt_client.document_prompt.assert_called_once_with(
            query=query,
            documents=[
                "Machine learning is a subset of artificial intelligence that focuses on algorithms that learn from data.",
                "Deep learning uses neural networks with multiple layers to model complex patterns in data.",
                "Supervised learning algorithms learn from labeled training data to make predictions on new data."
            ]
        )

        # Verify the final response
        assert result is not None
        assert isinstance(result, str)
        assert "machine learning" in result.lower()
        assert "artificial intelligence" in result.lower()

    @pytest.mark.asyncio
    async def test_document_rag_with_no_documents_found(self, mock_embeddings_client,
                                                        mock_doc_embeddings_client, mock_prompt_client):
        """Test DocumentRAG behavior when no documents are retrieved"""
        # Arrange
        mock_doc_embeddings_client.query.return_value = []  # No documents found
        mock_prompt_client.document_prompt.return_value = "I couldn't find any relevant documents for your query."

        document_rag = DocumentRag(
            embeddings_client=mock_embeddings_client,
            doc_embeddings_client=mock_doc_embeddings_client,
            prompt_client=mock_prompt_client,
            verbose=False
        )

        # Act
        result = await document_rag.query("very obscure query")

        # Assert
        mock_embeddings_client.embed.assert_called_once()
        mock_doc_embeddings_client.query.assert_called_once()
        mock_prompt_client.document_prompt.assert_called_once_with(
            query="very obscure query",
            documents=[]
        )

        assert result == "I couldn't find any relevant documents for your query."

    @pytest.mark.asyncio
    async def test_document_rag_embeddings_service_failure(self, mock_embeddings_client,
                                                           mock_doc_embeddings_client, mock_prompt_client):
        """Test DocumentRAG error handling when the embeddings service fails"""
        # Arrange
        mock_embeddings_client.embed.side_effect = Exception("Embeddings service unavailable")

        document_rag = DocumentRag(
            embeddings_client=mock_embeddings_client,
            doc_embeddings_client=mock_doc_embeddings_client,
            prompt_client=mock_prompt_client,
            verbose=False
        )

        # Act & Assert
        with pytest.raises(Exception) as exc_info:
            await document_rag.query("test query")

        assert "Embeddings service unavailable" in str(exc_info.value)
        mock_embeddings_client.embed.assert_called_once()
        mock_doc_embeddings_client.query.assert_not_called()
        mock_prompt_client.document_prompt.assert_not_called()

    @pytest.mark.asyncio
    async def test_document_rag_document_service_failure(self, mock_embeddings_client,
                                                         mock_doc_embeddings_client, mock_prompt_client):
        """Test DocumentRAG error handling when the document service fails"""
        # Arrange
        mock_doc_embeddings_client.query.side_effect = Exception("Document service connection failed")

        document_rag = DocumentRag(
            embeddings_client=mock_embeddings_client,
            doc_embeddings_client=mock_doc_embeddings_client,
            prompt_client=mock_prompt_client,
            verbose=False
        )

        # Act & Assert
        with pytest.raises(Exception) as exc_info:
            await document_rag.query("test query")

        assert "Document service connection failed" in str(exc_info.value)
        mock_embeddings_client.embed.assert_called_once()
        mock_doc_embeddings_client.query.assert_called_once()
        mock_prompt_client.document_prompt.assert_not_called()

    @pytest.mark.asyncio
    async def test_document_rag_prompt_service_failure(self, mock_embeddings_client,
                                                       mock_doc_embeddings_client, mock_prompt_client):
        """Test DocumentRAG error handling when the prompt service fails"""
        # Arrange
        mock_prompt_client.document_prompt.side_effect = Exception("LLM service rate limited")

        document_rag = DocumentRag(
            embeddings_client=mock_embeddings_client,
            doc_embeddings_client=mock_doc_embeddings_client,
            prompt_client=mock_prompt_client,
            verbose=False
        )

        # Act & Assert
        with pytest.raises(Exception) as exc_info:
            await document_rag.query("test query")

        assert "LLM service rate limited" in str(exc_info.value)
        mock_embeddings_client.embed.assert_called_once()
        mock_doc_embeddings_client.query.assert_called_once()
        mock_prompt_client.document_prompt.assert_called_once()
@pytest.mark.asyncio
|
||||
async def test_document_rag_with_different_document_limits(self, document_rag,
|
||||
mock_doc_embeddings_client):
|
||||
"""Test DocumentRAG with various document limit configurations"""
|
||||
# Test different document limits
|
||||
test_cases = [1, 5, 10, 25, 50]
|
||||
|
||||
for limit in test_cases:
|
||||
# Reset mock call history
|
||||
mock_doc_embeddings_client.reset_mock()
|
||||
|
||||
# Act
|
||||
await document_rag.query(f"query with limit {limit}", doc_limit=limit)
|
||||
|
||||
# Assert
|
||||
mock_doc_embeddings_client.query.assert_called_once()
|
||||
call_args = mock_doc_embeddings_client.query.call_args
|
||||
assert call_args.kwargs['limit'] == limit
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_document_rag_multi_user_isolation(self, document_rag, mock_doc_embeddings_client):
|
||||
"""Test DocumentRAG properly isolates queries by user and collection"""
|
||||
# Arrange
|
||||
test_scenarios = [
|
||||
("user1", "collection1"),
|
||||
("user2", "collection2"),
|
||||
("user1", "collection2"), # Same user, different collection
|
||||
("user2", "collection1"), # Different user, same collection
|
||||
]
|
||||
|
||||
for user, collection in test_scenarios:
|
||||
# Reset mock call history
|
||||
mock_doc_embeddings_client.reset_mock()
|
||||
|
||||
# Act
|
||||
await document_rag.query(
|
||||
f"query from {user} in {collection}",
|
||||
user=user,
|
||||
collection=collection
|
||||
)
|
||||
|
||||
# Assert
|
||||
mock_doc_embeddings_client.query.assert_called_once()
|
||||
call_args = mock_doc_embeddings_client.query.call_args
|
||||
assert call_args.kwargs['user'] == user
|
||||
assert call_args.kwargs['collection'] == collection
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_document_rag_verbose_logging(self, mock_embeddings_client,
|
||||
mock_doc_embeddings_client, mock_prompt_client,
|
||||
capsys):
|
||||
"""Test DocumentRAG verbose logging functionality"""
|
||||
# Arrange
|
||||
document_rag = DocumentRag(
|
||||
embeddings_client=mock_embeddings_client,
|
||||
doc_embeddings_client=mock_doc_embeddings_client,
|
||||
prompt_client=mock_prompt_client,
|
||||
verbose=True
|
||||
)
|
||||
|
||||
# Act
|
||||
await document_rag.query("test query for verbose logging")
|
||||
|
||||
# Assert
|
||||
captured = capsys.readouterr()
|
||||
assert "Initialised" in captured.out
|
||||
assert "Construct prompt..." in captured.out
|
||||
assert "Compute embeddings..." in captured.out
|
||||
assert "Get docs..." in captured.out
|
||||
assert "Invoke LLM..." in captured.out
|
||||
assert "Done" in captured.out
|
||||
|
||||
    @pytest.mark.asyncio
    @pytest.mark.slow
    async def test_document_rag_performance_with_large_document_set(
            self, document_rag, mock_doc_embeddings_client):
        """Test DocumentRAG performance with large document retrieval"""
        # Arrange - Mock large document set (100 documents)
        large_doc_set = [f"Document {i} content about machine learning and AI" for i in range(100)]
        mock_doc_embeddings_client.query.return_value = large_doc_set

        # Act
        import time
        start_time = time.time()

        result = await document_rag.query("performance test query", doc_limit=100)

        end_time = time.time()
        execution_time = end_time - start_time

        # Assert
        assert result is not None
        assert execution_time < 5.0  # Should complete within 5 seconds
        mock_doc_embeddings_client.query.assert_called_once()
        call_args = mock_doc_embeddings_client.query.call_args
        assert call_args.kwargs['limit'] == 100

    @pytest.mark.asyncio
    async def test_document_rag_default_parameters(self, document_rag, mock_doc_embeddings_client):
        """Test DocumentRAG uses correct default parameters"""
        # Act
        await document_rag.query("test query with defaults")

        # Assert
        mock_doc_embeddings_client.query.assert_called_once()
        call_args = mock_doc_embeddings_client.query.call_args
        assert call_args.kwargs['user'] == "trustgraph"
        assert call_args.kwargs['collection'] == "default"
        assert call_args.kwargs['limit'] == 20
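

# Standalone sketch of the flow the tests above exercise, kept at module level
# so nothing runs at import time. It is illustrative only: it assumes AsyncMock
# and DocumentRag are imported at the top of this module (the tests above use
# both), and the mocked return shapes (embedding vectors, document strings,
# prompt response) are assumptions rather than values from the real services.
async def _document_rag_demo():
    embeddings = AsyncMock()
    embeddings.embed.return_value = [[0.1, 0.2, 0.3]]  # assumed vector shape

    doc_embeddings = AsyncMock()
    doc_embeddings.query.return_value = ["Doc 1 text", "Doc 2 text"]

    prompt = AsyncMock()
    prompt.document_prompt.return_value = "LLM answer"

    rag = DocumentRag(
        embeddings_client=embeddings,
        doc_embeddings_client=doc_embeddings,
        prompt_client=prompt,
    )

    # With no explicit arguments, the defaults asserted above apply:
    # user="trustgraph", collection="default", limit=20.
    return await rag.query("What is machine learning?")
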
642
tests/integration/test_kg_extract_store_integration.py
Normal file

@@ -0,0 +1,642 @@
"""
|
||||
Integration tests for Knowledge Graph Extract → Store Pipeline
|
||||
|
||||
These tests verify the end-to-end functionality of the knowledge graph extraction
|
||||
and storage pipeline, testing text-to-graph transformation, entity extraction,
|
||||
relationship extraction, and graph database storage.
|
||||
Following the TEST_STRATEGY.md approach for integration testing.
|
||||
"""
|
||||
|
||||
import pytest
import json
import urllib.parse
from unittest.mock import AsyncMock, MagicMock, patch

from trustgraph.extract.kg.definitions.extract import Processor as DefinitionsProcessor
from trustgraph.extract.kg.relationships.extract import Processor as RelationshipsProcessor
from trustgraph.storage.knowledge.store import Processor as KnowledgeStoreProcessor
from trustgraph.schema import Chunk, Triple, Triples, Metadata, Value, Error
from trustgraph.schema import EntityContext, EntityContexts, GraphEmbeddings
from trustgraph.rdf import TRUSTGRAPH_ENTITIES, DEFINITION, RDF_LABEL, SUBJECT_OF

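# Illustrative sketch of the URI convention asserted in
# test_uri_generation_consistency below. It mirrors the test's assertions and
# is an assumption about Processor.to_uri, not its actual source: entities are
# lower-cased, spaces become hyphens, and the result is percent-encoded and
# appended to the TRUSTGRAPH_ENTITIES prefix.
def _to_uri_sketch(entity: str) -> str:
    part = urllib.parse.quote(entity.replace(" ", "-").lower().encode("utf-8"))
    return TRUSTGRAPH_ENTITIES + part
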

@pytest.mark.integration
class TestKnowledgeGraphPipelineIntegration:
    """Integration tests for Knowledge Graph Extract → Store Pipeline"""

    @pytest.fixture
    def mock_flow_context(self):
        """Mock flow context for service coordination"""
        context = MagicMock()

        # Mock prompt client for definitions extraction
        prompt_client = AsyncMock()
        prompt_client.extract_definitions.return_value = [
            {
                "entity": "Machine Learning",
                "definition": "A subset of artificial intelligence that enables computers to learn from data without explicit programming."
            },
            {
                "entity": "Neural Networks",
                "definition": "Computing systems inspired by biological neural networks that process information."
            }
        ]

        # Mock prompt client for relationships extraction
        prompt_client.extract_relationships.return_value = [
            {
                "subject": "Machine Learning",
                "predicate": "is_subset_of",
                "object": "Artificial Intelligence",
                "object-entity": True
            },
            {
                "subject": "Neural Networks",
                "predicate": "is_used_in",
                "object": "Machine Learning",
                "object-entity": True
            }
        ]

        # Mock producers for output streams
        triples_producer = AsyncMock()
        entity_contexts_producer = AsyncMock()

        # Configure context routing
        def context_router(service_name):
            if service_name == "prompt-request":
                return prompt_client
            elif service_name == "triples":
                return triples_producer
            elif service_name == "entity-contexts":
                return entity_contexts_producer
            else:
                return AsyncMock()

        context.side_effect = context_router
        return context

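    # Note on the fixture above: because context_router is assigned to
    # side_effect, every call such as mock_flow_context("triples") is routed
    # through the router and returns the same producer/client mock each time,
    # so tests can re-fetch those mocks later purely to make assertions on them.
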
    @pytest.fixture
    def mock_cassandra_store(self):
        """Mock Cassandra knowledge table store"""
        store = AsyncMock()
        store.add_triples.return_value = None
        store.add_graph_embeddings.return_value = None
        return store

    @pytest.fixture
    def sample_chunk(self):
        """Sample text chunk for processing"""
        return Chunk(
            metadata=Metadata(
                id="doc-123",
                user="test_user",
                collection="test_collection",
                metadata=[]
            ),
            chunk=b"Machine Learning is a subset of Artificial Intelligence. Neural Networks are used in Machine Learning to process complex patterns."
        )

    @pytest.fixture
    def sample_definitions_response(self):
        """Sample definitions extraction response"""
        return [
            {
                "entity": "Machine Learning",
                "definition": "A subset of artificial intelligence that enables computers to learn from data."
            },
            {
                "entity": "Artificial Intelligence",
                "definition": "The simulation of human intelligence in machines."
            },
            {
                "entity": "Neural Networks",
                "definition": "Computing systems inspired by biological neural networks."
            }
        ]

    @pytest.fixture
    def sample_relationships_response(self):
        """Sample relationships extraction response"""
        return [
            {
                "subject": "Machine Learning",
                "predicate": "is_subset_of",
                "object": "Artificial Intelligence",
                "object-entity": True
            },
            {
                "subject": "Neural Networks",
                "predicate": "is_used_in",
                "object": "Machine Learning",
                "object-entity": True
            },
            {
                "subject": "Machine Learning",
                "predicate": "processes",
                "object": "data patterns",
                "object-entity": False
            }
        ]

    @pytest.fixture
    def definitions_processor(self):
        """Create definitions processor with minimal configuration"""
        processor = MagicMock()
        processor.to_uri = DefinitionsProcessor.to_uri.__get__(processor, DefinitionsProcessor)
        processor.emit_triples = DefinitionsProcessor.emit_triples.__get__(processor, DefinitionsProcessor)
        processor.emit_ecs = DefinitionsProcessor.emit_ecs.__get__(processor, DefinitionsProcessor)
        processor.on_message = DefinitionsProcessor.on_message.__get__(processor, DefinitionsProcessor)
        return processor

    @pytest.fixture
    def relationships_processor(self):
        """Create relationships processor with minimal configuration"""
        processor = MagicMock()
        processor.to_uri = RelationshipsProcessor.to_uri.__get__(processor, RelationshipsProcessor)
        processor.emit_triples = RelationshipsProcessor.emit_triples.__get__(processor, RelationshipsProcessor)
        processor.on_message = RelationshipsProcessor.on_message.__get__(processor, RelationshipsProcessor)
        return processor

    @pytest.mark.asyncio
    async def test_definitions_extraction_pipeline(self, definitions_processor, mock_flow_context, sample_chunk):
        """Test definitions extraction from text chunk to graph triples"""
        # Arrange
        mock_msg = MagicMock()
        mock_msg.value.return_value = sample_chunk
        mock_consumer = MagicMock()

        # Act
        await definitions_processor.on_message(mock_msg, mock_consumer, mock_flow_context)

        # Assert
        # Verify prompt client was called for definitions extraction
        prompt_client = mock_flow_context("prompt-request")
        prompt_client.extract_definitions.assert_called_once()
        call_args = prompt_client.extract_definitions.call_args
        assert "Machine Learning" in call_args.kwargs['text']
        assert "Neural Networks" in call_args.kwargs['text']

        # Verify triples producer was called
        triples_producer = mock_flow_context("triples")
        triples_producer.send.assert_called_once()

        # Verify entity contexts producer was called
        entity_contexts_producer = mock_flow_context("entity-contexts")
        entity_contexts_producer.send.assert_called_once()

    @pytest.mark.asyncio
    async def test_relationships_extraction_pipeline(self, relationships_processor, mock_flow_context, sample_chunk):
        """Test relationships extraction from text chunk to graph triples"""
        # Arrange
        mock_msg = MagicMock()
        mock_msg.value.return_value = sample_chunk
        mock_consumer = MagicMock()

        # Act
        await relationships_processor.on_message(mock_msg, mock_consumer, mock_flow_context)

        # Assert
        # Verify prompt client was called for relationships extraction
        prompt_client = mock_flow_context("prompt-request")
        prompt_client.extract_relationships.assert_called_once()
        call_args = prompt_client.extract_relationships.call_args
        assert "Machine Learning" in call_args.kwargs['text']

        # Verify triples producer was called
        triples_producer = mock_flow_context("triples")
        triples_producer.send.assert_called_once()

    @pytest.mark.asyncio
    async def test_uri_generation_consistency(self, definitions_processor, relationships_processor):
        """Test URI generation consistency between processors"""
        # Arrange
        test_entities = [
            "Machine Learning",
            "Artificial Intelligence",
            "Neural Networks",
            "Deep Learning",
            "Natural Language Processing"
        ]

        # Act & Assert
        for entity in test_entities:
            def_uri = definitions_processor.to_uri(entity)
            rel_uri = relationships_processor.to_uri(entity)

            # URIs should be identical between processors
            assert def_uri == rel_uri

            # URI should be properly encoded
            assert def_uri.startswith(TRUSTGRAPH_ENTITIES)
            assert " " not in def_uri
            assert def_uri.endswith(urllib.parse.quote(entity.replace(" ", "-").lower().encode("utf-8")))

    @pytest.mark.asyncio
    async def test_definitions_triple_generation(self, definitions_processor, sample_definitions_response):
        """Test triple generation from definitions extraction"""
        # Arrange
        metadata = Metadata(
            id="test-doc",
            user="test_user",
            collection="test_collection",
            metadata=[]
        )

        # Act
        triples = []
        entities = []

        for defn in sample_definitions_response:
            s = defn["entity"]
            o = defn["definition"]

            if s and o:
                s_uri = definitions_processor.to_uri(s)
                s_value = Value(value=str(s_uri), is_uri=True)
                o_value = Value(value=str(o), is_uri=False)

                # Generate triples as the processor would
                triples.append(Triple(
                    s=s_value,
                    p=Value(value=RDF_LABEL, is_uri=True),
                    o=Value(value=s, is_uri=False)
                ))

                triples.append(Triple(
                    s=s_value,
                    p=Value(value=DEFINITION, is_uri=True),
                    o=o_value
                ))

                entities.append(EntityContext(
                    entity=s_value,
                    context=defn["definition"]
                ))

        # Assert
        assert len(triples) == 6   # 2 triples per entity * 3 entities
        assert len(entities) == 3  # 1 entity context per entity

        # Verify triple structure
        label_triples = [t for t in triples if t.p.value == RDF_LABEL]
        definition_triples = [t for t in triples if t.p.value == DEFINITION]

        assert len(label_triples) == 3
        assert len(definition_triples) == 3

        # Verify entity contexts
        for entity in entities:
            assert entity.entity.is_uri is True
            assert entity.entity.value.startswith(TRUSTGRAPH_ENTITIES)
            assert len(entity.context) > 0

    @pytest.mark.asyncio
    async def test_relationships_triple_generation(self, relationships_processor, sample_relationships_response):
        """Test triple generation from relationships extraction"""
        # Arrange
        metadata = Metadata(
            id="test-doc",
            user="test_user",
            collection="test_collection",
            metadata=[]
        )

        # Act
        triples = []

        for rel in sample_relationships_response:
            s = rel["subject"]
            p = rel["predicate"]
            o = rel["object"]

            if s and p and o:
                s_uri = relationships_processor.to_uri(s)
                s_value = Value(value=str(s_uri), is_uri=True)

                p_uri = relationships_processor.to_uri(p)
                p_value = Value(value=str(p_uri), is_uri=True)

                if rel["object-entity"]:
                    o_uri = relationships_processor.to_uri(o)
                    o_value = Value(value=str(o_uri), is_uri=True)
                else:
                    o_value = Value(value=str(o), is_uri=False)

                # Main relationship triple
                triples.append(Triple(s=s_value, p=p_value, o=o_value))

                # Label triples
                triples.append(Triple(
                    s=s_value,
                    p=Value(value=RDF_LABEL, is_uri=True),
                    o=Value(value=str(s), is_uri=False)
                ))

                triples.append(Triple(
                    s=p_value,
                    p=Value(value=RDF_LABEL, is_uri=True),
                    o=Value(value=str(p), is_uri=False)
                ))

                if rel["object-entity"]:
                    triples.append(Triple(
                        s=o_value,
                        p=Value(value=RDF_LABEL, is_uri=True),
                        o=Value(value=str(o), is_uri=False)
                    ))

        # Assert
        assert len(triples) > 0

        # Verify relationship triples exist
        relationship_triples = [t for t in triples if t.p.value.endswith("is_subset_of") or t.p.value.endswith("is_used_in")]
        assert len(relationship_triples) >= 2

        # Verify label triples
        label_triples = [t for t in triples if t.p.value == RDF_LABEL]
        assert len(label_triples) > 0

    @pytest.mark.asyncio
    async def test_knowledge_store_triples_storage(self, mock_cassandra_store):
        """Test knowledge store triples storage integration"""
        # Arrange
        processor = MagicMock()
        processor.table_store = mock_cassandra_store
        processor.on_triples = KnowledgeStoreProcessor.on_triples.__get__(processor, KnowledgeStoreProcessor)

        sample_triples = Triples(
            metadata=Metadata(
                id="test-doc",
                user="test_user",
                collection="test_collection",
                metadata=[]
            ),
            triples=[
                Triple(
                    s=Value(value="http://trustgraph.ai/e/machine-learning", is_uri=True),
                    p=Value(value=DEFINITION, is_uri=True),
                    o=Value(value="A subset of AI", is_uri=False)
                )
            ]
        )

        mock_msg = MagicMock()
        mock_msg.value.return_value = sample_triples

        # Act
        await processor.on_triples(mock_msg, None, None)

        # Assert
        mock_cassandra_store.add_triples.assert_called_once_with(sample_triples)

    @pytest.mark.asyncio
    async def test_knowledge_store_graph_embeddings_storage(self, mock_cassandra_store):
        """Test knowledge store graph embeddings storage integration"""
        # Arrange
        processor = MagicMock()
        processor.table_store = mock_cassandra_store
        processor.on_graph_embeddings = KnowledgeStoreProcessor.on_graph_embeddings.__get__(processor, KnowledgeStoreProcessor)

        sample_embeddings = GraphEmbeddings(
            metadata=Metadata(
                id="test-doc",
                user="test_user",
                collection="test_collection",
                metadata=[]
            ),
            entities=[]
        )

        mock_msg = MagicMock()
        mock_msg.value.return_value = sample_embeddings

        # Act
        await processor.on_graph_embeddings(mock_msg, None, None)

        # Assert
        mock_cassandra_store.add_graph_embeddings.assert_called_once_with(sample_embeddings)

    @pytest.mark.asyncio
    async def test_end_to_end_pipeline_coordination(self, definitions_processor, relationships_processor,
                                                    mock_flow_context, sample_chunk):
        """Test end-to-end pipeline coordination from chunk to storage"""
        # Arrange
        mock_msg = MagicMock()
        mock_msg.value.return_value = sample_chunk
        mock_consumer = MagicMock()

        # Act - Process through definitions extractor
        await definitions_processor.on_message(mock_msg, mock_consumer, mock_flow_context)

        # Act - Process through relationships extractor
        await relationships_processor.on_message(mock_msg, mock_consumer, mock_flow_context)

        # Assert
        # Verify both extractors called prompt service
        prompt_client = mock_flow_context("prompt-request")
        prompt_client.extract_definitions.assert_called_once()
        prompt_client.extract_relationships.assert_called_once()

        # Verify triples were produced from both extractors
        triples_producer = mock_flow_context("triples")
        assert triples_producer.send.call_count == 2  # One from each extractor

        # Verify entity contexts were produced from definitions extractor
        entity_contexts_producer = mock_flow_context("entity-contexts")
        entity_contexts_producer.send.assert_called_once()

    @pytest.mark.asyncio
    async def test_error_handling_in_definitions_extraction(self, definitions_processor, mock_flow_context, sample_chunk):
        """Test error handling in definitions extraction"""
        # Arrange
        mock_flow_context("prompt-request").extract_definitions.side_effect = Exception("Prompt service unavailable")

        mock_msg = MagicMock()
        mock_msg.value.return_value = sample_chunk
        mock_consumer = MagicMock()

        # Act & Assert
        # Should not raise exception, but should handle it gracefully
        await definitions_processor.on_message(mock_msg, mock_consumer, mock_flow_context)

        # Verify prompt was attempted
        prompt_client = mock_flow_context("prompt-request")
        prompt_client.extract_definitions.assert_called_once()

    @pytest.mark.asyncio
    async def test_error_handling_in_relationships_extraction(self, relationships_processor, mock_flow_context, sample_chunk):
        """Test error handling in relationships extraction"""
        # Arrange
        mock_flow_context("prompt-request").extract_relationships.side_effect = Exception("Prompt service unavailable")

        mock_msg = MagicMock()
        mock_msg.value.return_value = sample_chunk
        mock_consumer = MagicMock()

        # Act & Assert
        # Should not raise exception, but should handle it gracefully
        await relationships_processor.on_message(mock_msg, mock_consumer, mock_flow_context)

        # Verify prompt was attempted
        prompt_client = mock_flow_context("prompt-request")
        prompt_client.extract_relationships.assert_called_once()

    @pytest.mark.asyncio
    async def test_empty_extraction_results_handling(self, definitions_processor, mock_flow_context, sample_chunk):
        """Test handling of empty extraction results"""
        # Arrange
        mock_flow_context("prompt-request").extract_definitions.return_value = []

        mock_msg = MagicMock()
        mock_msg.value.return_value = sample_chunk
        mock_consumer = MagicMock()

        # Act
        await definitions_processor.on_message(mock_msg, mock_consumer, mock_flow_context)

        # Assert
        # Should still call producers but with empty results
        triples_producer = mock_flow_context("triples")
        entity_contexts_producer = mock_flow_context("entity-contexts")

        triples_producer.send.assert_called_once()
        entity_contexts_producer.send.assert_called_once()

    @pytest.mark.asyncio
    async def test_invalid_extraction_format_handling(self, definitions_processor, mock_flow_context, sample_chunk):
        """Test handling of invalid extraction response format"""
        # Arrange
        mock_flow_context("prompt-request").extract_definitions.return_value = "invalid format"  # Should be a list

        mock_msg = MagicMock()
        mock_msg.value.return_value = sample_chunk
        mock_consumer = MagicMock()

        # Act & Assert
        # Should handle invalid format gracefully
        await definitions_processor.on_message(mock_msg, mock_consumer, mock_flow_context)

        # Verify prompt was attempted
        prompt_client = mock_flow_context("prompt-request")
        prompt_client.extract_definitions.assert_called_once()

    @pytest.mark.asyncio
    async def test_entity_filtering_and_validation(self, definitions_processor, mock_flow_context):
        """Test entity filtering and validation in extraction"""
        # Arrange
        mock_flow_context("prompt-request").extract_definitions.return_value = [
            {"entity": "Valid Entity", "definition": "Valid definition"},
            {"entity": "", "definition": "Empty entity"},            # Should be filtered
            {"entity": "Valid Entity 2", "definition": ""},          # Should be filtered
            {"entity": None, "definition": "None entity"},           # Should be filtered
            {"entity": "Valid Entity 3", "definition": None},        # Should be filtered
        ]

        sample_chunk = Chunk(
            metadata=Metadata(id="test", user="user", collection="collection", metadata=[]),
            chunk=b"Test chunk"
        )

        mock_msg = MagicMock()
        mock_msg.value.return_value = sample_chunk
        mock_consumer = MagicMock()

        # Act
        await definitions_processor.on_message(mock_msg, mock_consumer, mock_flow_context)

        # Assert
        # Should only process valid entities
        triples_producer = mock_flow_context("triples")
        entity_contexts_producer = mock_flow_context("entity-contexts")

        triples_producer.send.assert_called_once()
        entity_contexts_producer.send.assert_called_once()

    @pytest.mark.asyncio
    @pytest.mark.slow
    async def test_large_batch_processing_performance(self, definitions_processor, relationships_processor,
                                                      mock_flow_context):
        """Test performance with large batch of chunks"""
        # Arrange
        large_chunk_batch = [
            Chunk(
                metadata=Metadata(id=f"doc-{i}", user="user", collection="collection", metadata=[]),
                chunk=f"Document {i} contains machine learning and AI content.".encode("utf-8")
            )
            for i in range(100)  # Large batch
        ]

        mock_consumer = MagicMock()

        # Act
        import time
        start_time = time.time()

        for chunk in large_chunk_batch:
            mock_msg = MagicMock()
            mock_msg.value.return_value = chunk

            # Process through both extractors
            await definitions_processor.on_message(mock_msg, mock_consumer, mock_flow_context)
            await relationships_processor.on_message(mock_msg, mock_consumer, mock_flow_context)

        end_time = time.time()
        execution_time = end_time - start_time

        # Assert
        assert execution_time < 30.0  # Should complete within reasonable time

        # Verify all chunks were processed
        prompt_client = mock_flow_context("prompt-request")
        assert prompt_client.extract_definitions.call_count == 100
        assert prompt_client.extract_relationships.call_count == 100

    @pytest.mark.asyncio
    async def test_metadata_propagation_through_pipeline(self, definitions_processor, mock_flow_context):
        """Test metadata propagation through the pipeline"""
        # Arrange
        original_metadata = Metadata(
            id="test-doc-123",
            user="test_user",
            collection="test_collection",
            metadata=[
                Triple(
                    s=Value(value="doc:test", is_uri=True),
                    p=Value(value="dc:title", is_uri=True),
                    o=Value(value="Test Document", is_uri=False)
                )
            ]
        )

        sample_chunk = Chunk(
            metadata=original_metadata,
            chunk=b"Test content for metadata propagation"
        )

        mock_msg = MagicMock()
        mock_msg.value.return_value = sample_chunk
        mock_consumer = MagicMock()

        # Act
        await definitions_processor.on_message(mock_msg, mock_consumer, mock_flow_context)

        # Assert
        # Verify metadata was propagated to output
        triples_producer = mock_flow_context("triples")
        entity_contexts_producer = mock_flow_context("entity-contexts")

        triples_producer.send.assert_called_once()
        entity_contexts_producer.send.assert_called_once()

        # Check that metadata was included in the calls
        triples_call = triples_producer.send.call_args[0][0]
        entity_contexts_call = entity_contexts_producer.send.call_args[0][0]

        assert triples_call.metadata.id == "test-doc-123"
        assert triples_call.metadata.user == "test_user"
        assert triples_call.metadata.collection == "test_collection"

        assert entity_contexts_call.metadata.id == "test-doc-123"
        assert entity_contexts_call.metadata.user == "test_user"
        assert entity_contexts_call.metadata.collection == "test_collection"
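

# Illustrative helper mirroring the inline expansion logic in
# test_relationships_triple_generation above. It sketches the expected
# behaviour (one main triple plus label triples, with an extra label triple
# when the object is itself an entity); it is not the extractor's actual
# source code.
def _relationship_to_triples_sketch(rel, to_uri):
    s_value = Value(value=str(to_uri(rel["subject"])), is_uri=True)
    p_value = Value(value=str(to_uri(rel["predicate"])), is_uri=True)

    if rel["object-entity"]:
        o_value = Value(value=str(to_uri(rel["object"])), is_uri=True)
    else:
        o_value = Value(value=str(rel["object"]), is_uri=False)

    triples = [
        # Main relationship triple
        Triple(s=s_value, p=p_value, o=o_value),
        # Human-readable labels for subject and predicate
        Triple(s=s_value, p=Value(value=RDF_LABEL, is_uri=True),
               o=Value(value=str(rel["subject"]), is_uri=False)),
        Triple(s=p_value, p=Value(value=RDF_LABEL, is_uri=True),
               o=Value(value=str(rel["predicate"]), is_uri=False)),
    ]
    if rel["object-entity"]:
        triples.append(Triple(s=o_value, p=Value(value=RDF_LABEL, is_uri=True),
                              o=Value(value=str(rel["object"]), is_uri=False)))
    return triples
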
429
tests/integration/test_text_completion_integration.py
Normal file

@@ -0,0 +1,429 @@
"""
|
||||
Integration tests for Text Completion Service (OpenAI)
|
||||
|
||||
These tests verify the end-to-end functionality of the OpenAI text completion service,
|
||||
testing API connectivity, authentication, rate limiting, error handling, and token tracking.
|
||||
Following the TEST_STRATEGY.md approach for integration testing.
|
||||
"""
|
||||
|
||||
import pytest
import os
from unittest.mock import AsyncMock, MagicMock, patch
from openai import OpenAI, RateLimitError
from openai.types.chat import ChatCompletion, ChatCompletionMessage
from openai.types.chat.chat_completion import Choice
from openai.types.completion_usage import CompletionUsage

from trustgraph.model.text_completion.openai.llm import Processor
from trustgraph.exceptions import TooManyRequests
from trustgraph.base import LlmResult
from trustgraph.schema import TextCompletionRequest, TextCompletionResponse, Error

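# Illustrative sketch of the message payload shape the tests below assert:
# a single "user" message whose content is a list with one text part that
# combines the system and user prompts. The exact separator between the two
# prompts is an assumption; the tests only require that the combined text
# starts with the system prompt and contains the user prompt. This is not
# the actual Processor source.
def _build_messages_sketch(system: str, prompt: str) -> list:
    return [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": f"{system}\n\n{prompt}"},
            ],
        }
    ]
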

@pytest.mark.integration
class TestTextCompletionIntegration:
    """Integration tests for OpenAI text completion service coordination"""

    @pytest.fixture
    def mock_openai_client(self):
        """Mock OpenAI client that returns realistic responses"""
        client = MagicMock(spec=OpenAI)

        # Mock chat completion response
        usage = CompletionUsage(prompt_tokens=50, completion_tokens=100, total_tokens=150)
        message = ChatCompletionMessage(role="assistant", content="This is a test response from the AI model.")
        choice = Choice(index=0, message=message, finish_reason="stop")

        completion = ChatCompletion(
            id="chatcmpl-test123",
            choices=[choice],
            created=1234567890,
            model="gpt-3.5-turbo",
            object="chat.completion",
            usage=usage
        )

        client.chat.completions.create.return_value = completion
        return client

    @pytest.fixture
    def processor_config(self):
        """Configuration for processor testing"""
        return {
            "model": "gpt-3.5-turbo",
            "temperature": 0.7,
            "max_output": 1024,
        }

    @pytest.fixture
    def text_completion_processor(self, processor_config):
        """Create text completion processor with test configuration"""
        # Create a minimal processor instance for testing generate_content
        processor = MagicMock()
        processor.model = processor_config["model"]
        processor.temperature = processor_config["temperature"]
        processor.max_output = processor_config["max_output"]

        # Add the actual generate_content method from the Processor class
        processor.generate_content = Processor.generate_content.__get__(processor, Processor)

        return processor

    @pytest.mark.asyncio
    async def test_text_completion_successful_generation(self, text_completion_processor, mock_openai_client):
        """Test successful text completion generation"""
        # Arrange
        text_completion_processor.openai = mock_openai_client
        system_prompt = "You are a helpful assistant."
        user_prompt = "What is machine learning?"

        # Act
        result = await text_completion_processor.generate_content(system_prompt, user_prompt)

        # Assert
        assert isinstance(result, LlmResult)
        assert result.text == "This is a test response from the AI model."
        assert result.in_token == 50
        assert result.out_token == 100
        assert result.model == "gpt-3.5-turbo"

        # Verify OpenAI API was called correctly
        mock_openai_client.chat.completions.create.assert_called_once()
        call_args = mock_openai_client.chat.completions.create.call_args

        assert call_args.kwargs['model'] == "gpt-3.5-turbo"
        assert call_args.kwargs['temperature'] == 0.7
        assert call_args.kwargs['max_tokens'] == 1024
        assert len(call_args.kwargs['messages']) == 1
        assert call_args.kwargs['messages'][0]['role'] == "user"
        assert "You are a helpful assistant." in call_args.kwargs['messages'][0]['content'][0]['text']
        assert "What is machine learning?" in call_args.kwargs['messages'][0]['content'][0]['text']

    @pytest.mark.asyncio
    async def test_text_completion_with_different_configurations(self, mock_openai_client):
        """Test text completion with various configuration parameters"""
        # Test different configurations
        test_configs = [
            {"model": "gpt-4", "temperature": 0.0, "max_output": 512},
            {"model": "gpt-3.5-turbo", "temperature": 1.0, "max_output": 2048},
            {"model": "gpt-4-turbo", "temperature": 0.5, "max_output": 4096}
        ]

        for config in test_configs:
            # Arrange - Create minimal processor mock
            processor = MagicMock()
            processor.model = config['model']
            processor.temperature = config['temperature']
            processor.max_output = config['max_output']
            processor.openai = mock_openai_client

            # Add the actual generate_content method
            processor.generate_content = Processor.generate_content.__get__(processor, Processor)

            # Act
            result = await processor.generate_content("System prompt", "User prompt")

            # Assert
            assert isinstance(result, LlmResult)
            assert result.text == "This is a test response from the AI model."
            assert result.in_token == 50
            assert result.out_token == 100
            # Note: result.model comes from the mock response, not the processor config

            # Verify configuration was applied
            call_args = mock_openai_client.chat.completions.create.call_args
            assert call_args.kwargs['model'] == config['model']
            assert call_args.kwargs['temperature'] == config['temperature']
            assert call_args.kwargs['max_tokens'] == config['max_output']

            # Reset mock for next iteration
            mock_openai_client.reset_mock()

    @pytest.mark.asyncio
    async def test_text_completion_rate_limit_handling(self, text_completion_processor, mock_openai_client):
        """Test proper rate limit error handling"""
        # Arrange
        mock_openai_client.chat.completions.create.side_effect = RateLimitError(
            "Rate limit exceeded",
            response=MagicMock(status_code=429),
            body={}
        )
        text_completion_processor.openai = mock_openai_client

        # Act & Assert
        with pytest.raises(TooManyRequests):
            await text_completion_processor.generate_content("System prompt", "User prompt")

        # Verify OpenAI API was called
        mock_openai_client.chat.completions.create.assert_called_once()

    @pytest.mark.asyncio
    async def test_text_completion_api_error_handling(self, text_completion_processor, mock_openai_client):
        """Test handling of general API errors"""
        # Arrange
        mock_openai_client.chat.completions.create.side_effect = Exception("API connection failed")
        text_completion_processor.openai = mock_openai_client

        # Act & Assert
        with pytest.raises(Exception) as exc_info:
            await text_completion_processor.generate_content("System prompt", "User prompt")

        assert "API connection failed" in str(exc_info.value)
        mock_openai_client.chat.completions.create.assert_called_once()

    @pytest.mark.asyncio
    async def test_text_completion_token_tracking(self, text_completion_processor, mock_openai_client):
        """Test accurate token counting and tracking"""
        # Arrange - Different token counts for multiple requests
        test_cases = [
            (25, 75),     # Small request
            (100, 200),   # Medium request
            (500, 1000)   # Large request
        ]

        for input_tokens, output_tokens in test_cases:
            # Update mock response with different token counts
            usage = CompletionUsage(
                prompt_tokens=input_tokens,
                completion_tokens=output_tokens,
                total_tokens=input_tokens + output_tokens
            )
            message = ChatCompletionMessage(role="assistant", content="Test response")
            choice = Choice(index=0, message=message, finish_reason="stop")

            completion = ChatCompletion(
                id="chatcmpl-test123",
                choices=[choice],
                created=1234567890,
                model="gpt-3.5-turbo",
                object="chat.completion",
                usage=usage
            )

            mock_openai_client.chat.completions.create.return_value = completion
            text_completion_processor.openai = mock_openai_client

            # Act
            result = await text_completion_processor.generate_content("System", "Prompt")

            # Assert
            assert result.in_token == input_tokens
            assert result.out_token == output_tokens
            assert result.model == "gpt-3.5-turbo"

            # Reset mock for next iteration
            mock_openai_client.reset_mock()

    @pytest.mark.asyncio
    async def test_text_completion_prompt_construction(self, text_completion_processor, mock_openai_client):
        """Test proper prompt construction with system and user prompts"""
        # Arrange
        text_completion_processor.openai = mock_openai_client
        system_prompt = "You are an expert in artificial intelligence."
        user_prompt = "Explain neural networks in simple terms."

        # Act
        result = await text_completion_processor.generate_content(system_prompt, user_prompt)

        # Assert
        call_args = mock_openai_client.chat.completions.create.call_args
        sent_message = call_args.kwargs['messages'][0]['content'][0]['text']

        # Verify system and user prompts are combined correctly, with the
        # system prompt leading and the user prompt following it
        assert system_prompt in sent_message
        assert user_prompt in sent_message
        assert sent_message.startswith(system_prompt)
        assert sent_message.index(system_prompt) < sent_message.index(user_prompt)

    @pytest.mark.asyncio
    async def test_text_completion_concurrent_requests(self, processor_config, mock_openai_client):
        """Test handling of concurrent requests"""
        # Arrange
        processors = []
        for i in range(5):
            processor = MagicMock()
            processor.model = processor_config["model"]
            processor.temperature = processor_config["temperature"]
            processor.max_output = processor_config["max_output"]
            processor.openai = mock_openai_client
            processor.generate_content = Processor.generate_content.__get__(processor, Processor)
            processors.append(processor)

        # Simulate multiple concurrent requests
        tasks = []
        for i, processor in enumerate(processors):
            task = processor.generate_content(f"System {i}", f"Prompt {i}")
            tasks.append(task)

        # Act
        import asyncio
        results = await asyncio.gather(*tasks)

        # Assert
        assert len(results) == 5
        for result in results:
            assert isinstance(result, LlmResult)
            assert result.text == "This is a test response from the AI model."
            assert result.in_token == 50
            assert result.out_token == 100

        # Verify all requests were processed
        assert mock_openai_client.chat.completions.create.call_count == 5

    @pytest.mark.asyncio
    async def test_text_completion_response_format_validation(self, text_completion_processor, mock_openai_client):
        """Test response format and structure validation"""
        # Arrange
        text_completion_processor.openai = mock_openai_client

        # Act
        result = await text_completion_processor.generate_content("System", "Prompt")

        # Assert
        # Verify OpenAI API call parameters
        call_args = mock_openai_client.chat.completions.create.call_args
        assert call_args.kwargs['response_format'] == {"type": "text"}
        assert call_args.kwargs['top_p'] == 1
        assert call_args.kwargs['frequency_penalty'] == 0
        assert call_args.kwargs['presence_penalty'] == 0

        # Verify result structure
        assert hasattr(result, 'text')
        assert hasattr(result, 'in_token')
        assert hasattr(result, 'out_token')
        assert hasattr(result, 'model')

    @pytest.mark.asyncio
    async def test_text_completion_authentication_patterns(self):
        """Test different authentication configurations"""
        # Test missing API key first (this should fail early)
        with pytest.raises(RuntimeError) as exc_info:
            Processor(id="test-no-key", api_key=None)
        assert "OpenAI API key not specified" in str(exc_info.value)

        # Test the authentication pattern by examining the initialisation logic.
        # Since we can't fully instantiate due to taskgroup requirements,
        # we test the authentication configuration directly.
        from trustgraph.model.text_completion.openai.llm import default_api_key, default_base_url

        # Test default values
        assert default_base_url == "https://api.openai.com/v1"

        # Test configuration parameters
        test_configs = [
            {"api_key": "test-key-1", "url": "https://api.openai.com/v1"},
            {"api_key": "test-key-2", "url": "https://custom.openai.com/v1"},
        ]

        for config in test_configs:
            # We can't fully test instantiation due to the taskgroup,
            # but we can verify the authentication configuration is well-formed
            assert config["api_key"] is not None
            assert config["url"] is not None

    @pytest.mark.asyncio
    async def test_text_completion_error_propagation(self, text_completion_processor, mock_openai_client):
        """Test error propagation through the service"""
        # Test different error types
        error_cases = [
            (RateLimitError("Rate limit", response=MagicMock(status_code=429), body={}), TooManyRequests),
            (Exception("Connection timeout"), Exception),
            (ValueError("Invalid request"), ValueError),
        ]

        for error_input, expected_error in error_cases:
            # Arrange
            mock_openai_client.chat.completions.create.side_effect = error_input
            text_completion_processor.openai = mock_openai_client

            # Act & Assert
            with pytest.raises(expected_error):
                await text_completion_processor.generate_content("System", "Prompt")

            # Reset mock for next iteration
            mock_openai_client.reset_mock()

    @pytest.mark.asyncio
    async def test_text_completion_model_parameter_validation(self, mock_openai_client):
        """Test that model parameters are correctly passed to the OpenAI API"""
        # Arrange
        processor = MagicMock()
        processor.model = "gpt-4"
        processor.temperature = 0.8
        processor.max_output = 2048
        processor.openai = mock_openai_client
        processor.generate_content = Processor.generate_content.__get__(processor, Processor)

        # Act
        await processor.generate_content("System prompt", "User prompt")

        # Assert
        call_args = mock_openai_client.chat.completions.create.call_args
        assert call_args.kwargs['model'] == "gpt-4"
        assert call_args.kwargs['temperature'] == 0.8
        assert call_args.kwargs['max_tokens'] == 2048
        assert call_args.kwargs['top_p'] == 1
        assert call_args.kwargs['frequency_penalty'] == 0
        assert call_args.kwargs['presence_penalty'] == 0

    @pytest.mark.asyncio
    @pytest.mark.slow
    async def test_text_completion_performance_timing(self, text_completion_processor, mock_openai_client):
        """Test performance timing for text completion"""
        # Arrange
        text_completion_processor.openai = mock_openai_client

        # Act
        import time
        start_time = time.time()

        result = await text_completion_processor.generate_content("System", "Prompt")

        end_time = time.time()
        execution_time = end_time - start_time

        # Assert
        assert isinstance(result, LlmResult)
        assert execution_time < 1.0  # Should complete quickly with mocked API
        mock_openai_client.chat.completions.create.assert_called_once()

    @pytest.mark.asyncio
    async def test_text_completion_response_content_extraction(self, text_completion_processor, mock_openai_client):
        """Test proper extraction of response content from the OpenAI API"""
        # Arrange
        test_responses = [
            "This is a simple response.",
            "This is a multi-line response.\nWith multiple lines.\nAnd more content.",
            "Response with special characters: @#$%^&*()_+-=[]{}|;':\",./<>?",
            ""  # Empty response
        ]

        for test_content in test_responses:
            # Update mock response
            usage = CompletionUsage(prompt_tokens=10, completion_tokens=20, total_tokens=30)
            message = ChatCompletionMessage(role="assistant", content=test_content)
            choice = Choice(index=0, message=message, finish_reason="stop")

            completion = ChatCompletion(
                id="chatcmpl-test123",
                choices=[choice],
                created=1234567890,
                model="gpt-3.5-turbo",
                object="chat.completion",
                usage=usage
            )

            mock_openai_client.chat.completions.create.return_value = completion
            text_completion_processor.openai = mock_openai_client

            # Act
            result = await text_completion_processor.generate_content("System", "Prompt")

            # Assert
            assert result.text == test_content
            assert result.in_token == 10
            assert result.out_token == 20
            assert result.model == "gpt-3.5-turbo"

            # Reset mock for next iteration
            mock_openai_client.reset_mock()
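

# Minimal sketch of the error-mapping behaviour the rate-limit tests above
# assert, assuming generate_content wraps the OpenAI call roughly like this
# (illustrative, not the actual Processor source):
#
#     try:
#         resp = self.openai.chat.completions.create(...)
#     except RateLimitError:
#         raise TooManyRequests()
#
# Other exceptions propagate unchanged, which is what
# test_text_completion_error_propagation verifies.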