mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-04-30 02:46:23 +02:00
Test suite executed from CI pipeline (#433)
* Test strategy & test cases * Unit tests * Integration tests
This commit is contained in:
parent
9c7a070681
commit
2f7fddd206
101 changed files with 17811 additions and 1 deletions
|
|
@ -1,27 +0,0 @@
|
|||
|
||||
test-prompt-... is tested with this prompt set...
|
||||
|
||||
prompt-template \
|
||||
-p pulsar://localhost:6650 \
|
||||
--system-prompt 'You are a {{attitude}}, you are called {{name}}' \
|
||||
--global-term \
|
||||
'name=Craig' \
|
||||
'attitude=LOUD, SHOUTY ANNOYING BOT' \
|
||||
--prompt \
|
||||
'question={{question}}' \
|
||||
'french-question={{question}}' \
|
||||
"analyze=Find the name and age in this text, and output a JSON structure containing just the name and age fields: {{description}}. Don't add markup, just output the raw JSON object." \
|
||||
"graph-query=Study the following knowledge graph, and then answer the question.\\n\nGraph:\\n{% for edge in knowledge %}({{edge.0}})-[{{edge.1}}]->({{edge.2}})\\n{%endfor%}\\nQuestion:\\n{{question}}" \
|
||||
"extract-definition=Analyse the text provided, and then return a list of terms and definitions. The output should be a JSON array, each item in the array is an object with fields 'term' and 'definition'.Don't add markup, just output the raw JSON object. Here is the text:\\n{{text}}" \
|
||||
--prompt-response-type \
|
||||
'question=text' \
|
||||
'analyze=json' \
|
||||
'graph-query=text' \
|
||||
'extract-definition=json' \
|
||||
--prompt-term \
|
||||
'question=name:Bonny' \
|
||||
'french-question=attitude:French-speaking bot' \
|
||||
--prompt-schema \
|
||||
'analyze={ "type" : "object", "properties" : { "age": { "type" : "number" }, "name": { "type" : "string" } } }' \
|
||||
'extract-definition={ "type": "array", "items": { "type": "object", "properties": { "term": { "type": "string" }, "definition": { "type": "string" } }, "required": [ "term", "definition" ] } }'
|
||||
|
||||
3
tests/__init__.py
Normal file
3
tests/__init__.py
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
"""
|
||||
TrustGraph test suite
|
||||
"""
|
||||
269
tests/integration/README.md
Normal file
269
tests/integration/README.md
Normal file
|
|
@ -0,0 +1,269 @@
|
|||
# Integration Test Pattern for TrustGraph
|
||||
|
||||
This directory contains integration tests that verify the coordination between multiple TrustGraph services and components, following the patterns outlined in [TEST_STRATEGY.md](../../TEST_STRATEGY.md).
|
||||
|
||||
## Integration Test Approach
|
||||
|
||||
Integration tests focus on **service-to-service communication patterns** and **end-to-end message flows** while still using mocks for external infrastructure.
|
||||
|
||||
### Key Principles
|
||||
|
||||
1. **Test Service Coordination**: Verify that services work together correctly
|
||||
2. **Mock External Dependencies**: Use mocks for databases, APIs, and infrastructure
|
||||
3. **Real Business Logic**: Exercise actual service logic and data transformations
|
||||
4. **Error Propagation**: Test how errors flow through the system
|
||||
5. **Configuration Testing**: Verify services respond correctly to different configurations
|
||||
|
||||
## Test Structure
|
||||
|
||||
### Fixtures (conftest.py)
|
||||
|
||||
Common fixtures for integration tests:
|
||||
- `mock_pulsar_client`: Mock Pulsar messaging client
|
||||
- `mock_flow_context`: Mock flow context for service coordination
|
||||
- `integration_config`: Standard configuration for integration tests
|
||||
- `sample_documents`: Test document collections
|
||||
- `sample_embeddings`: Test embedding vectors
|
||||
- `sample_queries`: Test query sets
|
||||
|
||||
### Test Patterns
|
||||
|
||||
#### 1. End-to-End Flow Testing
|
||||
|
||||
```python
|
||||
@pytest.mark.integration
|
||||
@pytest.mark.asyncio
|
||||
async def test_service_end_to_end_flow(self, service_instance, mock_clients):
|
||||
"""Test complete service pipeline from input to output"""
|
||||
# Arrange - Set up realistic test data
|
||||
# Act - Execute the full service workflow
|
||||
# Assert - Verify coordination between all components
|
||||
```
|
||||
|
||||
#### 2. Error Propagation Testing
|
||||
|
||||
```python
|
||||
@pytest.mark.integration
|
||||
@pytest.mark.asyncio
|
||||
async def test_service_error_handling(self, service_instance, mock_clients):
|
||||
"""Test how errors propagate through service coordination"""
|
||||
# Arrange - Set up failure scenarios
|
||||
# Act - Execute service with failing dependency
|
||||
# Assert - Verify proper error handling and cleanup
|
||||
```
|
||||
|
||||
#### 3. Configuration Testing
|
||||
|
||||
```python
|
||||
@pytest.mark.integration
|
||||
@pytest.mark.asyncio
|
||||
async def test_service_configuration_scenarios(self, service_instance):
|
||||
"""Test service behavior with different configurations"""
|
||||
# Test multiple configuration scenarios
|
||||
# Verify service adapts correctly to each configuration
|
||||
```
|
||||
|
||||
## Running Integration Tests
|
||||
|
||||
### Run All Integration Tests
|
||||
```bash
|
||||
pytest tests/integration/ -m integration
|
||||
```
|
||||
|
||||
### Run Specific Test
|
||||
```bash
|
||||
pytest tests/integration/test_document_rag_integration.py::TestDocumentRagIntegration::test_document_rag_end_to_end_flow -v
|
||||
```
|
||||
|
||||
### Run with Coverage (Skip Coverage Requirement)
|
||||
```bash
|
||||
pytest tests/integration/ -m integration --cov=trustgraph --cov-fail-under=0
|
||||
```
|
||||
|
||||
### Run Slow Tests
|
||||
```bash
|
||||
pytest tests/integration/ -m "integration and slow"
|
||||
```
|
||||
|
||||
### Skip Slow Tests
|
||||
```bash
|
||||
pytest tests/integration/ -m "integration and not slow"
|
||||
```
|
||||
|
||||
## Examples: Integration Test Implementations
|
||||
|
||||
### 1. Document RAG Integration Test
|
||||
|
||||
The `test_document_rag_integration.py` demonstrates the integration test pattern:
|
||||
|
||||
### What It Tests
|
||||
- **Service Coordination**: Embeddings → Document Retrieval → Prompt Generation
|
||||
- **Error Handling**: Failure scenarios for each service dependency
|
||||
- **Configuration**: Different document limits, users, and collections
|
||||
- **Performance**: Large document set handling
|
||||
|
||||
### Key Features
|
||||
- **Realistic Data Flow**: Uses actual service logic with mocked dependencies
|
||||
- **Multiple Scenarios**: Success, failure, and edge cases
|
||||
- **Verbose Logging**: Tests logging functionality
|
||||
- **Multi-User Support**: Tests user and collection isolation
|
||||
|
||||
### Test Coverage
|
||||
- ✅ End-to-end happy path
|
||||
- ✅ No documents found scenario
|
||||
- ✅ Service failure scenarios (embeddings, documents, prompt)
|
||||
- ✅ Configuration variations
|
||||
- ✅ Multi-user isolation
|
||||
- ✅ Performance testing
|
||||
- ✅ Verbose logging
|
||||
|
||||
### 2. Text Completion Integration Test
|
||||
|
||||
The `test_text_completion_integration.py` demonstrates external API integration testing:
|
||||
|
||||
### What It Tests
|
||||
- **External API Integration**: OpenAI API connectivity and authentication
|
||||
- **Rate Limiting**: Proper handling of API rate limits and retries
|
||||
- **Error Handling**: API failures, connection timeouts, and error propagation
|
||||
- **Token Tracking**: Accurate input/output token counting and metrics
|
||||
- **Configuration**: Different model parameters and settings
|
||||
- **Concurrency**: Multiple simultaneous API requests
|
||||
|
||||
### Key Features
|
||||
- **Realistic Mock Responses**: Uses actual OpenAI API response structures
|
||||
- **Authentication Testing**: API key validation and base URL configuration
|
||||
- **Error Scenarios**: Rate limits, connection failures, invalid requests
|
||||
- **Performance Metrics**: Timing and token usage validation
|
||||
- **Model Flexibility**: Tests different GPT models and parameters
|
||||
|
||||
### Test Coverage
|
||||
- ✅ Successful text completion generation
|
||||
- ✅ Multiple model configurations (GPT-3.5, GPT-4, GPT-4-turbo)
|
||||
- ✅ Rate limit handling (RateLimitError → TooManyRequests)
|
||||
- ✅ API error handling and propagation
|
||||
- ✅ Token counting accuracy
|
||||
- ✅ Prompt construction and parameter validation
|
||||
- ✅ Authentication patterns and API key validation
|
||||
- ✅ Concurrent request processing
|
||||
- ✅ Response content extraction and validation
|
||||
- ✅ Performance timing measurements
|
||||
|
||||
### 3. Agent Manager Integration Test
|
||||
|
||||
The `test_agent_manager_integration.py` demonstrates complex service coordination testing:
|
||||
|
||||
### What It Tests
|
||||
- **ReAct Pattern**: Think-Act-Observe cycles with multi-step reasoning
|
||||
- **Tool Coordination**: Selection and execution of different tools (knowledge query, text completion, MCP tools)
|
||||
- **Conversation State**: Management of conversation history and context
|
||||
- **Multi-Service Integration**: Coordination between prompt, graph RAG, and tool services
|
||||
- **Error Handling**: Tool failures, unknown tools, and error propagation
|
||||
- **Configuration Management**: Dynamic tool loading and configuration
|
||||
|
||||
### Key Features
|
||||
- **Complex Coordination**: Tests agent reasoning with multiple tool options
|
||||
- **Stateful Processing**: Maintains conversation history across interactions
|
||||
- **Dynamic Tool Selection**: Tests tool selection based on context and reasoning
|
||||
- **Callback Pattern**: Tests think/observe callback mechanisms
|
||||
- **JSON Serialization**: Handles complex data structures in prompts
|
||||
- **Performance Testing**: Large conversation history handling
|
||||
|
||||
### Test Coverage
|
||||
- ✅ Basic reasoning cycle with tool selection
|
||||
- ✅ Final answer generation (ending ReAct cycle)
|
||||
- ✅ Full ReAct cycle with tool execution
|
||||
- ✅ Conversation history management
|
||||
- ✅ Multiple tool coordination and selection
|
||||
- ✅ Tool argument validation and processing
|
||||
- ✅ Error handling (unknown tools, execution failures)
|
||||
- ✅ Context integration and additional prompting
|
||||
- ✅ Empty tool configuration handling
|
||||
- ✅ Tool response processing and cleanup
|
||||
- ✅ Performance with large conversation history
|
||||
- ✅ JSON serialization in complex prompts
|
||||
|
||||
### 4. Knowledge Graph Extract → Store Pipeline Integration Test
|
||||
|
||||
The `test_kg_extract_store_integration.py` demonstrates multi-stage pipeline testing:
|
||||
|
||||
### What It Tests
|
||||
- **Text-to-Graph Transformation**: Complete pipeline from text chunks to graph triples
|
||||
- **Entity Extraction**: Definition extraction with proper URI generation
|
||||
- **Relationship Extraction**: Subject-predicate-object relationship extraction
|
||||
- **Graph Database Integration**: Storage coordination with Cassandra knowledge store
|
||||
- **Data Validation**: Entity filtering, validation, and consistency checks
|
||||
- **Pipeline Coordination**: Multi-stage processing with proper data flow
|
||||
|
||||
### Key Features
|
||||
- **Multi-Stage Pipeline**: Tests definitions → relationships → storage coordination
|
||||
- **Graph Data Structures**: RDF triples, entity contexts, and graph embeddings
|
||||
- **URI Generation**: Consistent entity URI creation across pipeline stages
|
||||
- **Data Transformation**: Complex text analysis to structured graph data
|
||||
- **Batch Processing**: Large document set processing performance
|
||||
- **Error Resilience**: Graceful handling of extraction failures
|
||||
|
||||
### Test Coverage
|
||||
- ✅ Definitions extraction pipeline (text → entities + definitions)
|
||||
- ✅ Relationships extraction pipeline (text → subject-predicate-object)
|
||||
- ✅ URI generation consistency between processors
|
||||
- ✅ Triple generation from definitions and relationships
|
||||
- ✅ Knowledge store integration (triples and embeddings storage)
|
||||
- ✅ End-to-end pipeline coordination
|
||||
- ✅ Error handling in extraction services
|
||||
- ✅ Empty and invalid extraction results handling
|
||||
- ✅ Entity filtering and validation
|
||||
- ✅ Large batch processing performance
|
||||
- ✅ Metadata propagation through pipeline stages
|
||||
|
||||
## Best Practices
|
||||
|
||||
### Test Organization
|
||||
- Group related tests in classes
|
||||
- Use descriptive test names that explain the scenario
|
||||
- Follow the Arrange-Act-Assert pattern
|
||||
- Use appropriate pytest markers (`@pytest.mark.integration`, `@pytest.mark.slow`)
|
||||
|
||||
### Mock Strategy
|
||||
- Mock external services (databases, APIs, message brokers)
|
||||
- Use real service logic and data transformations
|
||||
- Create realistic mock responses that match actual service behavior
|
||||
- Reset mocks between tests to ensure isolation
|
||||
|
||||
### Test Data
|
||||
- Use realistic test data that reflects actual usage patterns
|
||||
- Create reusable fixtures for common test scenarios
|
||||
- Test with various data sizes and edge cases
|
||||
- Include both success and failure scenarios
|
||||
|
||||
### Error Testing
|
||||
- Test each dependency failure scenario
|
||||
- Verify proper error propagation and cleanup
|
||||
- Test timeout and retry mechanisms
|
||||
- Validate error response formats
|
||||
|
||||
### Performance Testing
|
||||
- Mark performance tests with `@pytest.mark.slow`
|
||||
- Test with realistic data volumes
|
||||
- Set reasonable performance expectations
|
||||
- Monitor resource usage during tests
|
||||
|
||||
## Adding New Integration Tests
|
||||
|
||||
1. **Identify Service Dependencies**: Map out which services your target service coordinates with
|
||||
2. **Create Mock Fixtures**: Set up mocks for each dependency in conftest.py
|
||||
3. **Design Test Scenarios**: Plan happy path, error cases, and edge conditions
|
||||
4. **Implement Tests**: Follow the established patterns in this directory
|
||||
5. **Add Documentation**: Update this README with your new test patterns
|
||||
|
||||
## Test Markers
|
||||
|
||||
- `@pytest.mark.integration`: Marks tests as integration tests
|
||||
- `@pytest.mark.slow`: Marks tests that take longer to run
|
||||
- `@pytest.mark.asyncio`: Required for async test functions
|
||||
|
||||
## Future Enhancements
|
||||
|
||||
- Add tests with real test containers for database integration
|
||||
- Implement contract testing for service interfaces
|
||||
- Add performance benchmarking for critical paths
|
||||
- Create integration test templates for common service patterns
|
||||
0
tests/integration/__init__.py
Normal file
0
tests/integration/__init__.py
Normal file
386
tests/integration/conftest.py
Normal file
386
tests/integration/conftest.py
Normal file
|
|
@ -0,0 +1,386 @@
|
|||
"""
|
||||
Shared fixtures and configuration for integration tests
|
||||
|
||||
This file provides common fixtures and test configuration for integration tests.
|
||||
Following the TEST_STRATEGY.md patterns for integration testing.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_pulsar_client():
|
||||
"""Mock Pulsar client for integration tests"""
|
||||
client = MagicMock()
|
||||
client.create_producer.return_value = AsyncMock()
|
||||
client.subscribe.return_value = AsyncMock()
|
||||
return client
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_flow_context():
|
||||
"""Mock flow context for testing service coordination"""
|
||||
context = MagicMock()
|
||||
|
||||
# Mock flow producers/consumers
|
||||
context.return_value.send = AsyncMock()
|
||||
context.return_value.receive = AsyncMock()
|
||||
|
||||
return context
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def integration_config():
|
||||
"""Common configuration for integration tests"""
|
||||
return {
|
||||
"pulsar_host": "localhost",
|
||||
"pulsar_port": 6650,
|
||||
"test_timeout": 30.0,
|
||||
"max_retries": 3,
|
||||
"doc_limit": 10,
|
||||
"embedding_dim": 5,
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_documents():
|
||||
"""Sample document collection for testing"""
|
||||
return [
|
||||
{
|
||||
"id": "doc1",
|
||||
"content": "Machine learning is a subset of artificial intelligence that focuses on algorithms that learn from data.",
|
||||
"collection": "ml_knowledge",
|
||||
"user": "test_user"
|
||||
},
|
||||
{
|
||||
"id": "doc2",
|
||||
"content": "Deep learning uses neural networks with multiple layers to model complex patterns in data.",
|
||||
"collection": "ml_knowledge",
|
||||
"user": "test_user"
|
||||
},
|
||||
{
|
||||
"id": "doc3",
|
||||
"content": "Supervised learning algorithms learn from labeled training data to make predictions on new data.",
|
||||
"collection": "ml_knowledge",
|
||||
"user": "test_user"
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_embeddings():
|
||||
"""Sample embedding vectors for testing"""
|
||||
return [
|
||||
[0.1, 0.2, 0.3, 0.4, 0.5],
|
||||
[0.6, 0.7, 0.8, 0.9, 1.0],
|
||||
[0.2, 0.3, 0.4, 0.5, 0.6],
|
||||
[0.7, 0.8, 0.9, 1.0, 0.1],
|
||||
[0.3, 0.4, 0.5, 0.6, 0.7]
|
||||
]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_queries():
|
||||
"""Sample queries for testing"""
|
||||
return [
|
||||
"What is machine learning?",
|
||||
"How does deep learning work?",
|
||||
"Explain supervised learning",
|
||||
"What are neural networks?",
|
||||
"How do algorithms learn from data?"
|
||||
]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_text_completion_requests():
|
||||
"""Sample text completion requests for testing"""
|
||||
return [
|
||||
{
|
||||
"system": "You are a helpful assistant.",
|
||||
"prompt": "What is artificial intelligence?",
|
||||
"expected_keywords": ["artificial intelligence", "AI", "machine learning"]
|
||||
},
|
||||
{
|
||||
"system": "You are a technical expert.",
|
||||
"prompt": "Explain neural networks",
|
||||
"expected_keywords": ["neural networks", "neurons", "layers"]
|
||||
},
|
||||
{
|
||||
"system": "You are a teacher.",
|
||||
"prompt": "What is supervised learning?",
|
||||
"expected_keywords": ["supervised learning", "training", "labels"]
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_openai_response():
|
||||
"""Mock OpenAI API response structure"""
|
||||
return {
|
||||
"id": "chatcmpl-test123",
|
||||
"object": "chat.completion",
|
||||
"created": 1234567890,
|
||||
"model": "gpt-3.5-turbo",
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": "This is a test response from the AI model."
|
||||
},
|
||||
"finish_reason": "stop"
|
||||
}
|
||||
],
|
||||
"usage": {
|
||||
"prompt_tokens": 50,
|
||||
"completion_tokens": 100,
|
||||
"total_tokens": 150
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def text_completion_configs():
|
||||
"""Various text completion configurations for testing"""
|
||||
return [
|
||||
{
|
||||
"model": "gpt-3.5-turbo",
|
||||
"temperature": 0.0,
|
||||
"max_output": 1024,
|
||||
"description": "Conservative settings"
|
||||
},
|
||||
{
|
||||
"model": "gpt-4",
|
||||
"temperature": 0.7,
|
||||
"max_output": 2048,
|
||||
"description": "Balanced settings"
|
||||
},
|
||||
{
|
||||
"model": "gpt-4-turbo",
|
||||
"temperature": 1.0,
|
||||
"max_output": 4096,
|
||||
"description": "Creative settings"
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_agent_tools():
|
||||
"""Sample agent tools configuration for testing"""
|
||||
return {
|
||||
"knowledge_query": {
|
||||
"name": "knowledge_query",
|
||||
"description": "Query the knowledge graph for information",
|
||||
"type": "knowledge-query",
|
||||
"arguments": [
|
||||
{
|
||||
"name": "question",
|
||||
"type": "string",
|
||||
"description": "The question to ask the knowledge graph"
|
||||
}
|
||||
]
|
||||
},
|
||||
"text_completion": {
|
||||
"name": "text_completion",
|
||||
"description": "Generate text completion using LLM",
|
||||
"type": "text-completion",
|
||||
"arguments": [
|
||||
{
|
||||
"name": "question",
|
||||
"type": "string",
|
||||
"description": "The question to ask the LLM"
|
||||
}
|
||||
]
|
||||
},
|
||||
"web_search": {
|
||||
"name": "web_search",
|
||||
"description": "Search the web for information",
|
||||
"type": "mcp-tool",
|
||||
"arguments": [
|
||||
{
|
||||
"name": "query",
|
||||
"type": "string",
|
||||
"description": "The search query"
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_agent_requests():
|
||||
"""Sample agent requests for testing"""
|
||||
return [
|
||||
{
|
||||
"question": "What is machine learning?",
|
||||
"plan": "",
|
||||
"state": "",
|
||||
"history": [],
|
||||
"expected_tool": "knowledge_query"
|
||||
},
|
||||
{
|
||||
"question": "Can you explain neural networks in simple terms?",
|
||||
"plan": "",
|
||||
"state": "",
|
||||
"history": [],
|
||||
"expected_tool": "text_completion"
|
||||
},
|
||||
{
|
||||
"question": "Search for the latest AI research papers",
|
||||
"plan": "",
|
||||
"state": "",
|
||||
"history": [],
|
||||
"expected_tool": "web_search"
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_agent_responses():
|
||||
"""Sample agent responses for testing"""
|
||||
return [
|
||||
{
|
||||
"thought": "I need to search for information about machine learning",
|
||||
"action": "knowledge_query",
|
||||
"arguments": {"question": "What is machine learning?"}
|
||||
},
|
||||
{
|
||||
"thought": "I can provide a direct answer about neural networks",
|
||||
"final-answer": "Neural networks are computing systems inspired by biological neural networks."
|
||||
},
|
||||
{
|
||||
"thought": "I should search the web for recent research",
|
||||
"action": "web_search",
|
||||
"arguments": {"query": "latest AI research papers 2024"}
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_conversation_history():
|
||||
"""Sample conversation history for testing"""
|
||||
return [
|
||||
{
|
||||
"thought": "I need to search for basic information first",
|
||||
"action": "knowledge_query",
|
||||
"arguments": {"question": "What is artificial intelligence?"},
|
||||
"observation": "AI is the simulation of human intelligence in machines."
|
||||
},
|
||||
{
|
||||
"thought": "Now I can provide more specific information",
|
||||
"action": "text_completion",
|
||||
"arguments": {"question": "Explain machine learning within AI"},
|
||||
"observation": "Machine learning is a subset of AI that enables computers to learn from data."
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_kg_extraction_data():
|
||||
"""Sample knowledge graph extraction data for testing"""
|
||||
return {
|
||||
"text_chunks": [
|
||||
"Machine Learning is a subset of Artificial Intelligence that enables computers to learn from data.",
|
||||
"Neural Networks are computing systems inspired by biological neural networks.",
|
||||
"Deep Learning uses neural networks with multiple layers to model complex patterns."
|
||||
],
|
||||
"expected_entities": [
|
||||
"Machine Learning",
|
||||
"Artificial Intelligence",
|
||||
"Neural Networks",
|
||||
"Deep Learning"
|
||||
],
|
||||
"expected_relationships": [
|
||||
{
|
||||
"subject": "Machine Learning",
|
||||
"predicate": "is_subset_of",
|
||||
"object": "Artificial Intelligence"
|
||||
},
|
||||
{
|
||||
"subject": "Deep Learning",
|
||||
"predicate": "uses",
|
||||
"object": "Neural Networks"
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_kg_definitions():
|
||||
"""Sample knowledge graph definitions for testing"""
|
||||
return [
|
||||
{
|
||||
"entity": "Machine Learning",
|
||||
"definition": "A subset of artificial intelligence that enables computers to learn from data without explicit programming."
|
||||
},
|
||||
{
|
||||
"entity": "Artificial Intelligence",
|
||||
"definition": "The simulation of human intelligence in machines that are programmed to think and act like humans."
|
||||
},
|
||||
{
|
||||
"entity": "Neural Networks",
|
||||
"definition": "Computing systems inspired by biological neural networks that process information using interconnected nodes."
|
||||
},
|
||||
{
|
||||
"entity": "Deep Learning",
|
||||
"definition": "A subset of machine learning that uses neural networks with multiple layers to model complex patterns in data."
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_kg_relationships():
|
||||
"""Sample knowledge graph relationships for testing"""
|
||||
return [
|
||||
{
|
||||
"subject": "Machine Learning",
|
||||
"predicate": "is_subset_of",
|
||||
"object": "Artificial Intelligence",
|
||||
"object-entity": True
|
||||
},
|
||||
{
|
||||
"subject": "Deep Learning",
|
||||
"predicate": "is_subset_of",
|
||||
"object": "Machine Learning",
|
||||
"object-entity": True
|
||||
},
|
||||
{
|
||||
"subject": "Neural Networks",
|
||||
"predicate": "is_used_in",
|
||||
"object": "Deep Learning",
|
||||
"object-entity": True
|
||||
},
|
||||
{
|
||||
"subject": "Machine Learning",
|
||||
"predicate": "processes",
|
||||
"object": "data patterns",
|
||||
"object-entity": False
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_kg_triples():
|
||||
"""Sample knowledge graph triples for testing"""
|
||||
return [
|
||||
{
|
||||
"subject": "http://trustgraph.ai/e/machine-learning",
|
||||
"predicate": "http://www.w3.org/2000/01/rdf-schema#label",
|
||||
"object": "Machine Learning"
|
||||
},
|
||||
{
|
||||
"subject": "http://trustgraph.ai/e/machine-learning",
|
||||
"predicate": "http://trustgraph.ai/definition",
|
||||
"object": "A subset of artificial intelligence that enables computers to learn from data."
|
||||
},
|
||||
{
|
||||
"subject": "http://trustgraph.ai/e/machine-learning",
|
||||
"predicate": "http://trustgraph.ai/e/is_subset_of",
|
||||
"object": "http://trustgraph.ai/e/artificial-intelligence"
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
# Test markers for integration tests
|
||||
pytestmark = pytest.mark.integration
|
||||
532
tests/integration/test_agent_manager_integration.py
Normal file
532
tests/integration/test_agent_manager_integration.py
Normal file
|
|
@ -0,0 +1,532 @@
|
|||
"""
|
||||
Integration tests for Agent Manager (ReAct Pattern) Service
|
||||
|
||||
These tests verify the end-to-end functionality of the Agent Manager service,
|
||||
testing the ReAct pattern (Think-Act-Observe), tool coordination, multi-step reasoning,
|
||||
and conversation state management.
|
||||
Following the TEST_STRATEGY.md approach for integration testing.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import json
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
from trustgraph.agent.react.agent_manager import AgentManager
|
||||
from trustgraph.agent.react.tools import KnowledgeQueryImpl, TextCompletionImpl, McpToolImpl
|
||||
from trustgraph.agent.react.types import Action, Final, Tool, Argument
|
||||
from trustgraph.schema import AgentRequest, AgentResponse, AgentStep, Error
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
class TestAgentManagerIntegration:
|
||||
"""Integration tests for Agent Manager ReAct pattern coordination"""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_flow_context(self):
|
||||
"""Mock flow context for service coordination"""
|
||||
context = MagicMock()
|
||||
|
||||
# Mock prompt client
|
||||
prompt_client = AsyncMock()
|
||||
prompt_client.agent_react.return_value = {
|
||||
"thought": "I need to search for information about machine learning",
|
||||
"action": "knowledge_query",
|
||||
"arguments": {"question": "What is machine learning?"}
|
||||
}
|
||||
|
||||
# Mock graph RAG client
|
||||
graph_rag_client = AsyncMock()
|
||||
graph_rag_client.rag.return_value = "Machine learning is a subset of AI that enables computers to learn from data."
|
||||
|
||||
# Mock text completion client
|
||||
text_completion_client = AsyncMock()
|
||||
text_completion_client.question.return_value = "Machine learning involves algorithms that improve through experience."
|
||||
|
||||
# Mock MCP tool client
|
||||
mcp_tool_client = AsyncMock()
|
||||
mcp_tool_client.invoke.return_value = "Tool execution successful"
|
||||
|
||||
# Configure context to return appropriate clients
|
||||
def context_router(service_name):
|
||||
if service_name == "prompt-request":
|
||||
return prompt_client
|
||||
elif service_name == "graph-rag-request":
|
||||
return graph_rag_client
|
||||
elif service_name == "prompt-request":
|
||||
return text_completion_client
|
||||
elif service_name == "mcp-tool-request":
|
||||
return mcp_tool_client
|
||||
else:
|
||||
return AsyncMock()
|
||||
|
||||
context.side_effect = context_router
|
||||
return context
|
||||
|
||||
@pytest.fixture
|
||||
def sample_tools(self):
|
||||
"""Sample tool configuration for testing"""
|
||||
return {
|
||||
"knowledge_query": Tool(
|
||||
name="knowledge_query",
|
||||
description="Query the knowledge graph for information",
|
||||
arguments={
|
||||
"question": Argument(
|
||||
name="question",
|
||||
type="string",
|
||||
description="The question to ask the knowledge graph"
|
||||
)
|
||||
},
|
||||
implementation=KnowledgeQueryImpl,
|
||||
config={}
|
||||
),
|
||||
"text_completion": Tool(
|
||||
name="text_completion",
|
||||
description="Generate text completion using LLM",
|
||||
arguments={
|
||||
"question": Argument(
|
||||
name="question",
|
||||
type="string",
|
||||
description="The question to ask the LLM"
|
||||
)
|
||||
},
|
||||
implementation=TextCompletionImpl,
|
||||
config={}
|
||||
),
|
||||
"web_search": Tool(
|
||||
name="web_search",
|
||||
description="Search the web for information",
|
||||
arguments={
|
||||
"query": Argument(
|
||||
name="query",
|
||||
type="string",
|
||||
description="The search query"
|
||||
)
|
||||
},
|
||||
implementation=lambda context: AsyncMock(invoke=AsyncMock(return_value="Web search results")),
|
||||
config={}
|
||||
)
|
||||
}
|
||||
|
||||
@pytest.fixture
|
||||
def agent_manager(self, sample_tools):
|
||||
"""Create agent manager with sample tools"""
|
||||
return AgentManager(
|
||||
tools=sample_tools,
|
||||
additional_context="You are a helpful AI assistant with access to knowledge and tools."
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_agent_manager_reasoning_cycle(self, agent_manager, mock_flow_context):
|
||||
"""Test basic reasoning cycle with tool selection"""
|
||||
# Arrange
|
||||
question = "What is machine learning?"
|
||||
history = []
|
||||
|
||||
# Act
|
||||
action = await agent_manager.reason(question, history, mock_flow_context)
|
||||
|
||||
# Assert
|
||||
assert isinstance(action, Action)
|
||||
assert action.thought == "I need to search for information about machine learning"
|
||||
assert action.name == "knowledge_query"
|
||||
assert action.arguments == {"question": "What is machine learning?"}
|
||||
assert action.observation == ""
|
||||
|
||||
# Verify prompt client was called correctly
|
||||
prompt_client = mock_flow_context("prompt-request")
|
||||
prompt_client.agent_react.assert_called_once()
|
||||
|
||||
# Verify the prompt variables passed to agent_react
|
||||
call_args = prompt_client.agent_react.call_args
|
||||
variables = call_args[0][0]
|
||||
assert variables["question"] == question
|
||||
assert len(variables["tools"]) == 3 # knowledge_query, text_completion, web_search
|
||||
assert variables["context"] == "You are a helpful AI assistant with access to knowledge and tools."
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_agent_manager_final_answer(self, agent_manager, mock_flow_context):
|
||||
"""Test agent manager returning final answer"""
|
||||
# Arrange
|
||||
mock_flow_context("prompt-request").agent_react.return_value = {
|
||||
"thought": "I have enough information to answer the question",
|
||||
"final-answer": "Machine learning is a field of AI that enables computers to learn from data."
|
||||
}
|
||||
|
||||
question = "What is machine learning?"
|
||||
history = []
|
||||
|
||||
# Act
|
||||
action = await agent_manager.reason(question, history, mock_flow_context)
|
||||
|
||||
# Assert
|
||||
assert isinstance(action, Final)
|
||||
assert action.thought == "I have enough information to answer the question"
|
||||
assert action.final == "Machine learning is a field of AI that enables computers to learn from data."
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_agent_manager_react_with_tool_execution(self, agent_manager, mock_flow_context):
|
||||
"""Test full ReAct cycle with tool execution"""
|
||||
# Arrange
|
||||
question = "What is machine learning?"
|
||||
history = []
|
||||
|
||||
think_callback = AsyncMock()
|
||||
observe_callback = AsyncMock()
|
||||
|
||||
# Act
|
||||
action = await agent_manager.react(question, history, think_callback, observe_callback, mock_flow_context)
|
||||
|
||||
# Assert
|
||||
assert isinstance(action, Action)
|
||||
assert action.thought == "I need to search for information about machine learning"
|
||||
assert action.name == "knowledge_query"
|
||||
assert action.arguments == {"question": "What is machine learning?"}
|
||||
assert action.observation == "Machine learning is a subset of AI that enables computers to learn from data."
|
||||
|
||||
# Verify callbacks were called
|
||||
think_callback.assert_called_once_with("I need to search for information about machine learning")
|
||||
observe_callback.assert_called_once_with("Machine learning is a subset of AI that enables computers to learn from data.")
|
||||
|
||||
# Verify tool was executed
|
||||
graph_rag_client = mock_flow_context("graph-rag-request")
|
||||
graph_rag_client.rag.assert_called_once_with("What is machine learning?")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_agent_manager_react_with_final_answer(self, agent_manager, mock_flow_context):
|
||||
"""Test ReAct cycle ending with final answer"""
|
||||
# Arrange
|
||||
mock_flow_context("prompt-request").agent_react.return_value = {
|
||||
"thought": "I can provide a direct answer",
|
||||
"final-answer": "Machine learning is a branch of artificial intelligence."
|
||||
}
|
||||
|
||||
question = "What is machine learning?"
|
||||
history = []
|
||||
|
||||
think_callback = AsyncMock()
|
||||
observe_callback = AsyncMock()
|
||||
|
||||
# Act
|
||||
action = await agent_manager.react(question, history, think_callback, observe_callback, mock_flow_context)
|
||||
|
||||
# Assert
|
||||
assert isinstance(action, Final)
|
||||
assert action.thought == "I can provide a direct answer"
|
||||
assert action.final == "Machine learning is a branch of artificial intelligence."
|
||||
|
||||
# Verify only think callback was called (no observation for final answer)
|
||||
think_callback.assert_called_once_with("I can provide a direct answer")
|
||||
observe_callback.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_agent_manager_with_conversation_history(self, agent_manager, mock_flow_context):
|
||||
"""Test agent manager with conversation history"""
|
||||
# Arrange
|
||||
question = "Can you tell me more about neural networks?"
|
||||
history = [
|
||||
Action(
|
||||
thought="I need to search for information about machine learning",
|
||||
name="knowledge_query",
|
||||
arguments={"question": "What is machine learning?"},
|
||||
observation="Machine learning is a subset of AI that enables computers to learn from data."
|
||||
)
|
||||
]
|
||||
|
||||
# Act
|
||||
action = await agent_manager.reason(question, history, mock_flow_context)
|
||||
|
||||
# Assert
|
||||
assert isinstance(action, Action)
|
||||
|
||||
# Verify history was included in prompt variables
|
||||
prompt_client = mock_flow_context("prompt-request")
|
||||
call_args = prompt_client.agent_react.call_args
|
||||
variables = call_args[0][0]
|
||||
assert len(variables["history"]) == 1
|
||||
assert variables["history"][0]["thought"] == "I need to search for information about machine learning"
|
||||
assert variables["history"][0]["action"] == "knowledge_query"
|
||||
assert variables["history"][0]["observation"] == "Machine learning is a subset of AI that enables computers to learn from data."
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_agent_manager_tool_selection(self, agent_manager, mock_flow_context):
|
||||
"""Test agent manager selecting different tools"""
|
||||
# Test different tool selections
|
||||
tool_scenarios = [
|
||||
("knowledge_query", "graph-rag-request"),
|
||||
("text_completion", "prompt-request"),
|
||||
]
|
||||
|
||||
for tool_name, expected_service in tool_scenarios:
|
||||
# Arrange
|
||||
mock_flow_context("prompt-request").agent_react.return_value = {
|
||||
"thought": f"I need to use {tool_name}",
|
||||
"action": tool_name,
|
||||
"arguments": {"question": "test question"}
|
||||
}
|
||||
|
||||
think_callback = AsyncMock()
|
||||
observe_callback = AsyncMock()
|
||||
|
||||
# Act
|
||||
action = await agent_manager.react("test question", [], think_callback, observe_callback, mock_flow_context)
|
||||
|
||||
# Assert
|
||||
assert isinstance(action, Action)
|
||||
assert action.name == tool_name
|
||||
|
||||
# Verify correct service was called
|
||||
if tool_name == "knowledge_query":
|
||||
mock_flow_context("graph-rag-request").rag.assert_called()
|
||||
elif tool_name == "text_completion":
|
||||
mock_flow_context("prompt-request").question.assert_called()
|
||||
|
||||
# Reset mocks for next iteration
|
||||
for service in ["prompt-request", "graph-rag-request", "prompt-request"]:
|
||||
mock_flow_context(service).reset_mock()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_agent_manager_unknown_tool_error(self, agent_manager, mock_flow_context):
|
||||
"""Test agent manager error handling for unknown tool"""
|
||||
# Arrange
|
||||
mock_flow_context("prompt-request").agent_react.return_value = {
|
||||
"thought": "I need to use an unknown tool",
|
||||
"action": "unknown_tool",
|
||||
"arguments": {"param": "value"}
|
||||
}
|
||||
|
||||
think_callback = AsyncMock()
|
||||
observe_callback = AsyncMock()
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(RuntimeError) as exc_info:
|
||||
await agent_manager.react("test question", [], think_callback, observe_callback, mock_flow_context)
|
||||
|
||||
assert "No action for unknown_tool!" in str(exc_info.value)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_agent_manager_tool_execution_error(self, agent_manager, mock_flow_context):
|
||||
"""Test agent manager handling tool execution errors"""
|
||||
# Arrange
|
||||
mock_flow_context("graph-rag-request").rag.side_effect = Exception("Tool execution failed")
|
||||
|
||||
think_callback = AsyncMock()
|
||||
observe_callback = AsyncMock()
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
await agent_manager.react("test question", [], think_callback, observe_callback, mock_flow_context)
|
||||
|
||||
assert "Tool execution failed" in str(exc_info.value)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_agent_manager_multiple_tools_coordination(self, agent_manager, mock_flow_context):
|
||||
"""Test agent manager coordination with multiple available tools"""
|
||||
# Arrange
|
||||
question = "Find information about AI and summarize it"
|
||||
|
||||
# Mock multi-step reasoning
|
||||
mock_flow_context("prompt-request").agent_react.return_value = {
|
||||
"thought": "I need to search for AI information first",
|
||||
"action": "knowledge_query",
|
||||
"arguments": {"question": "What is artificial intelligence?"}
|
||||
}
|
||||
|
||||
# Act
|
||||
action = await agent_manager.reason(question, [], mock_flow_context)
|
||||
|
||||
# Assert
|
||||
assert isinstance(action, Action)
|
||||
assert action.name == "knowledge_query"
|
||||
|
||||
# Verify tool information was passed to prompt
|
||||
prompt_client = mock_flow_context("prompt-request")
|
||||
call_args = prompt_client.agent_react.call_args
|
||||
variables = call_args[0][0]
|
||||
|
||||
# Should have all 3 tools available
|
||||
tool_names = [tool["name"] for tool in variables["tools"]]
|
||||
assert "knowledge_query" in tool_names
|
||||
assert "text_completion" in tool_names
|
||||
assert "web_search" in tool_names
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_agent_manager_tool_argument_validation(self, agent_manager, mock_flow_context):
|
||||
"""Test agent manager with various tool argument patterns"""
|
||||
# Arrange
|
||||
test_cases = [
|
||||
{
|
||||
"action": "knowledge_query",
|
||||
"arguments": {"question": "What is deep learning?"},
|
||||
"expected_service": "graph-rag-request"
|
||||
},
|
||||
{
|
||||
"action": "text_completion",
|
||||
"arguments": {"question": "Explain neural networks"},
|
||||
"expected_service": "prompt-request"
|
||||
},
|
||||
{
|
||||
"action": "web_search",
|
||||
"arguments": {"query": "latest AI research"},
|
||||
"expected_service": None # Custom mock
|
||||
}
|
||||
]
|
||||
|
||||
for test_case in test_cases:
|
||||
# Arrange
|
||||
mock_flow_context("prompt-request").agent_react.return_value = {
|
||||
"thought": f"Using {test_case['action']}",
|
||||
"action": test_case['action'],
|
||||
"arguments": test_case['arguments']
|
||||
}
|
||||
|
||||
think_callback = AsyncMock()
|
||||
observe_callback = AsyncMock()
|
||||
|
||||
# Act
|
||||
action = await agent_manager.react("test", [], think_callback, observe_callback, mock_flow_context)
|
||||
|
||||
# Assert
|
||||
assert isinstance(action, Action)
|
||||
assert action.name == test_case['action']
|
||||
assert action.arguments == test_case['arguments']
|
||||
|
||||
# Reset mocks
|
||||
for service in ["prompt-request", "graph-rag-request", "prompt-request"]:
|
||||
mock_flow_context(service).reset_mock()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_agent_manager_context_integration(self, agent_manager, mock_flow_context):
|
||||
"""Test agent manager integration with additional context"""
|
||||
# Arrange
|
||||
agent_with_context = AgentManager(
|
||||
tools={"knowledge_query": agent_manager.tools["knowledge_query"]},
|
||||
additional_context="You are an expert in machine learning research."
|
||||
)
|
||||
|
||||
question = "What are the latest developments in AI?"
|
||||
|
||||
# Act
|
||||
action = await agent_with_context.reason(question, [], mock_flow_context)
|
||||
|
||||
# Assert
|
||||
prompt_client = mock_flow_context("prompt-request")
|
||||
call_args = prompt_client.agent_react.call_args
|
||||
variables = call_args[0][0]
|
||||
|
||||
assert variables["context"] == "You are an expert in machine learning research."
|
||||
assert variables["question"] == question
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_agent_manager_empty_tools(self, mock_flow_context):
|
||||
"""Test agent manager with no tools available"""
|
||||
# Arrange
|
||||
agent_no_tools = AgentManager(tools={}, additional_context="")
|
||||
|
||||
question = "What is machine learning?"
|
||||
|
||||
# Act
|
||||
action = await agent_no_tools.reason(question, [], mock_flow_context)
|
||||
|
||||
# Assert
|
||||
prompt_client = mock_flow_context("prompt-request")
|
||||
call_args = prompt_client.agent_react.call_args
|
||||
variables = call_args[0][0]
|
||||
|
||||
assert len(variables["tools"]) == 0
|
||||
assert variables["tool_names"] == ""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_agent_manager_tool_response_processing(self, agent_manager, mock_flow_context):
|
||||
"""Test agent manager processing different tool response types"""
|
||||
# Arrange
|
||||
response_scenarios = [
|
||||
"Simple text response",
|
||||
"Multi-line response\nwith several lines\nof information",
|
||||
"Response with special characters: @#$%^&*()_+-=[]{}|;':\",./<>?",
|
||||
" Response with whitespace ",
|
||||
"" # Empty response
|
||||
]
|
||||
|
||||
for expected_response in response_scenarios:
|
||||
# Set up mock response
|
||||
mock_flow_context("graph-rag-request").rag.return_value = expected_response
|
||||
|
||||
think_callback = AsyncMock()
|
||||
observe_callback = AsyncMock()
|
||||
|
||||
# Act
|
||||
action = await agent_manager.react("test question", [], think_callback, observe_callback, mock_flow_context)
|
||||
|
||||
# Assert
|
||||
assert isinstance(action, Action)
|
||||
assert action.observation == expected_response.strip()
|
||||
observe_callback.assert_called_with(expected_response.strip())
|
||||
|
||||
# Reset mocks
|
||||
mock_flow_context("graph-rag-request").reset_mock()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.slow
|
||||
async def test_agent_manager_performance_with_large_history(self, agent_manager, mock_flow_context):
|
||||
"""Test agent manager performance with large conversation history"""
|
||||
# Arrange
|
||||
large_history = [
|
||||
Action(
|
||||
thought=f"Step {i} thinking",
|
||||
name="knowledge_query",
|
||||
arguments={"question": f"Question {i}"},
|
||||
observation=f"Observation {i}"
|
||||
)
|
||||
for i in range(50) # Large history
|
||||
]
|
||||
|
||||
question = "Final question"
|
||||
|
||||
# Act
|
||||
import time
|
||||
start_time = time.time()
|
||||
|
||||
action = await agent_manager.reason(question, large_history, mock_flow_context)
|
||||
|
||||
end_time = time.time()
|
||||
execution_time = end_time - start_time
|
||||
|
||||
# Assert
|
||||
assert isinstance(action, Action)
|
||||
assert execution_time < 5.0 # Should complete within reasonable time
|
||||
|
||||
# Verify history was processed correctly
|
||||
prompt_client = mock_flow_context("prompt-request")
|
||||
call_args = prompt_client.agent_react.call_args
|
||||
variables = call_args[0][0]
|
||||
assert len(variables["history"]) == 50
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_agent_manager_json_serialization(self, agent_manager, mock_flow_context):
|
||||
"""Test agent manager handling of JSON serialization in prompts"""
|
||||
# Arrange
|
||||
complex_history = [
|
||||
Action(
|
||||
thought="Complex thinking with special characters: \"quotes\", 'apostrophes', and symbols",
|
||||
name="knowledge_query",
|
||||
arguments={"question": "What about JSON serialization?", "complex": {"nested": "value"}},
|
||||
observation="Response with JSON: {\"key\": \"value\"}"
|
||||
)
|
||||
]
|
||||
|
||||
question = "Handle JSON properly"
|
||||
|
||||
# Act
|
||||
action = await agent_manager.reason(question, complex_history, mock_flow_context)
|
||||
|
||||
# Assert
|
||||
assert isinstance(action, Action)
|
||||
|
||||
# Verify JSON was properly serialized in prompt
|
||||
prompt_client = mock_flow_context("prompt-request")
|
||||
call_args = prompt_client.agent_react.call_args
|
||||
variables = call_args[0][0]
|
||||
|
||||
# Should not raise JSON serialization errors
|
||||
json_str = json.dumps(variables, indent=4)
|
||||
assert len(json_str) > 0
|
||||
309
tests/integration/test_document_rag_integration.py
Normal file
309
tests/integration/test_document_rag_integration.py
Normal file
|
|
@ -0,0 +1,309 @@
|
|||
"""
|
||||
Integration tests for DocumentRAG retrieval system
|
||||
|
||||
These tests verify the end-to-end functionality of the DocumentRAG system,
|
||||
testing the coordination between embeddings, document retrieval, and prompt services.
|
||||
Following the TEST_STRATEGY.md approach for integration testing.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
from testcontainers.compose import DockerCompose
|
||||
from trustgraph.retrieval.document_rag.document_rag import DocumentRag
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
class TestDocumentRagIntegration:
|
||||
"""Integration tests for DocumentRAG system coordination"""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_embeddings_client(self):
|
||||
"""Mock embeddings client that returns realistic vector embeddings"""
|
||||
client = AsyncMock()
|
||||
client.embed.return_value = [
|
||||
[0.1, 0.2, 0.3, 0.4, 0.5], # Realistic 5-dimensional embedding
|
||||
[0.6, 0.7, 0.8, 0.9, 1.0] # Second embedding for testing
|
||||
]
|
||||
return client
|
||||
|
||||
@pytest.fixture
|
||||
def mock_doc_embeddings_client(self):
|
||||
"""Mock document embeddings client that returns realistic document chunks"""
|
||||
client = AsyncMock()
|
||||
client.query.return_value = [
|
||||
"Machine learning is a subset of artificial intelligence that focuses on algorithms that learn from data.",
|
||||
"Deep learning uses neural networks with multiple layers to model complex patterns in data.",
|
||||
"Supervised learning algorithms learn from labeled training data to make predictions on new data."
|
||||
]
|
||||
return client
|
||||
|
||||
@pytest.fixture
|
||||
def mock_prompt_client(self):
|
||||
"""Mock prompt client that generates realistic responses"""
|
||||
client = AsyncMock()
|
||||
client.document_prompt.return_value = (
|
||||
"Machine learning is a field of artificial intelligence that enables computers to learn "
|
||||
"and improve from experience without being explicitly programmed. It uses algorithms "
|
||||
"to find patterns in data and make predictions or decisions."
|
||||
)
|
||||
return client
|
||||
|
||||
@pytest.fixture
|
||||
def document_rag(self, mock_embeddings_client, mock_doc_embeddings_client, mock_prompt_client):
|
||||
"""Create DocumentRag instance with mocked dependencies"""
|
||||
return DocumentRag(
|
||||
embeddings_client=mock_embeddings_client,
|
||||
doc_embeddings_client=mock_doc_embeddings_client,
|
||||
prompt_client=mock_prompt_client,
|
||||
verbose=True
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_document_rag_end_to_end_flow(self, document_rag, mock_embeddings_client,
|
||||
mock_doc_embeddings_client, mock_prompt_client):
|
||||
"""Test complete DocumentRAG pipeline from query to response"""
|
||||
# Arrange
|
||||
query = "What is machine learning?"
|
||||
user = "test_user"
|
||||
collection = "ml_knowledge"
|
||||
doc_limit = 10
|
||||
|
||||
# Act
|
||||
result = await document_rag.query(
|
||||
query=query,
|
||||
user=user,
|
||||
collection=collection,
|
||||
doc_limit=doc_limit
|
||||
)
|
||||
|
||||
# Assert - Verify service coordination
|
||||
mock_embeddings_client.embed.assert_called_once_with(query)
|
||||
|
||||
mock_doc_embeddings_client.query.assert_called_once_with(
|
||||
[[0.1, 0.2, 0.3, 0.4, 0.5], [0.6, 0.7, 0.8, 0.9, 1.0]],
|
||||
limit=doc_limit,
|
||||
user=user,
|
||||
collection=collection
|
||||
)
|
||||
|
||||
mock_prompt_client.document_prompt.assert_called_once_with(
|
||||
query=query,
|
||||
documents=[
|
||||
"Machine learning is a subset of artificial intelligence that focuses on algorithms that learn from data.",
|
||||
"Deep learning uses neural networks with multiple layers to model complex patterns in data.",
|
||||
"Supervised learning algorithms learn from labeled training data to make predictions on new data."
|
||||
]
|
||||
)
|
||||
|
||||
# Verify final response
|
||||
assert result is not None
|
||||
assert isinstance(result, str)
|
||||
assert "machine learning" in result.lower()
|
||||
assert "artificial intelligence" in result.lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_document_rag_with_no_documents_found(self, mock_embeddings_client,
|
||||
mock_doc_embeddings_client, mock_prompt_client):
|
||||
"""Test DocumentRAG behavior when no documents are retrieved"""
|
||||
# Arrange
|
||||
mock_doc_embeddings_client.query.return_value = [] # No documents found
|
||||
mock_prompt_client.document_prompt.return_value = "I couldn't find any relevant documents for your query."
|
||||
|
||||
document_rag = DocumentRag(
|
||||
embeddings_client=mock_embeddings_client,
|
||||
doc_embeddings_client=mock_doc_embeddings_client,
|
||||
prompt_client=mock_prompt_client,
|
||||
verbose=False
|
||||
)
|
||||
|
||||
# Act
|
||||
result = await document_rag.query("very obscure query")
|
||||
|
||||
# Assert
|
||||
mock_embeddings_client.embed.assert_called_once()
|
||||
mock_doc_embeddings_client.query.assert_called_once()
|
||||
mock_prompt_client.document_prompt.assert_called_once_with(
|
||||
query="very obscure query",
|
||||
documents=[]
|
||||
)
|
||||
|
||||
assert result == "I couldn't find any relevant documents for your query."
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_document_rag_embeddings_service_failure(self, mock_embeddings_client,
|
||||
mock_doc_embeddings_client, mock_prompt_client):
|
||||
"""Test DocumentRAG error handling when embeddings service fails"""
|
||||
# Arrange
|
||||
mock_embeddings_client.embed.side_effect = Exception("Embeddings service unavailable")
|
||||
|
||||
document_rag = DocumentRag(
|
||||
embeddings_client=mock_embeddings_client,
|
||||
doc_embeddings_client=mock_doc_embeddings_client,
|
||||
prompt_client=mock_prompt_client,
|
||||
verbose=False
|
||||
)
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
await document_rag.query("test query")
|
||||
|
||||
assert "Embeddings service unavailable" in str(exc_info.value)
|
||||
mock_embeddings_client.embed.assert_called_once()
|
||||
mock_doc_embeddings_client.query.assert_not_called()
|
||||
mock_prompt_client.document_prompt.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_document_rag_document_service_failure(self, mock_embeddings_client,
|
||||
mock_doc_embeddings_client, mock_prompt_client):
|
||||
"""Test DocumentRAG error handling when document service fails"""
|
||||
# Arrange
|
||||
mock_doc_embeddings_client.query.side_effect = Exception("Document service connection failed")
|
||||
|
||||
document_rag = DocumentRag(
|
||||
embeddings_client=mock_embeddings_client,
|
||||
doc_embeddings_client=mock_doc_embeddings_client,
|
||||
prompt_client=mock_prompt_client,
|
||||
verbose=False
|
||||
)
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
await document_rag.query("test query")
|
||||
|
||||
assert "Document service connection failed" in str(exc_info.value)
|
||||
mock_embeddings_client.embed.assert_called_once()
|
||||
mock_doc_embeddings_client.query.assert_called_once()
|
||||
mock_prompt_client.document_prompt.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_document_rag_prompt_service_failure(self, mock_embeddings_client,
|
||||
mock_doc_embeddings_client, mock_prompt_client):
|
||||
"""Test DocumentRAG error handling when prompt service fails"""
|
||||
# Arrange
|
||||
mock_prompt_client.document_prompt.side_effect = Exception("LLM service rate limited")
|
||||
|
||||
document_rag = DocumentRag(
|
||||
embeddings_client=mock_embeddings_client,
|
||||
doc_embeddings_client=mock_doc_embeddings_client,
|
||||
prompt_client=mock_prompt_client,
|
||||
verbose=False
|
||||
)
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
await document_rag.query("test query")
|
||||
|
||||
assert "LLM service rate limited" in str(exc_info.value)
|
||||
mock_embeddings_client.embed.assert_called_once()
|
||||
mock_doc_embeddings_client.query.assert_called_once()
|
||||
mock_prompt_client.document_prompt.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_document_rag_with_different_document_limits(self, document_rag,
|
||||
mock_doc_embeddings_client):
|
||||
"""Test DocumentRAG with various document limit configurations"""
|
||||
# Test different document limits
|
||||
test_cases = [1, 5, 10, 25, 50]
|
||||
|
||||
for limit in test_cases:
|
||||
# Reset mock call history
|
||||
mock_doc_embeddings_client.reset_mock()
|
||||
|
||||
# Act
|
||||
await document_rag.query(f"query with limit {limit}", doc_limit=limit)
|
||||
|
||||
# Assert
|
||||
mock_doc_embeddings_client.query.assert_called_once()
|
||||
call_args = mock_doc_embeddings_client.query.call_args
|
||||
assert call_args.kwargs['limit'] == limit
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_document_rag_multi_user_isolation(self, document_rag, mock_doc_embeddings_client):
|
||||
"""Test DocumentRAG properly isolates queries by user and collection"""
|
||||
# Arrange
|
||||
test_scenarios = [
|
||||
("user1", "collection1"),
|
||||
("user2", "collection2"),
|
||||
("user1", "collection2"), # Same user, different collection
|
||||
("user2", "collection1"), # Different user, same collection
|
||||
]
|
||||
|
||||
for user, collection in test_scenarios:
|
||||
# Reset mock call history
|
||||
mock_doc_embeddings_client.reset_mock()
|
||||
|
||||
# Act
|
||||
await document_rag.query(
|
||||
f"query from {user} in {collection}",
|
||||
user=user,
|
||||
collection=collection
|
||||
)
|
||||
|
||||
# Assert
|
||||
mock_doc_embeddings_client.query.assert_called_once()
|
||||
call_args = mock_doc_embeddings_client.query.call_args
|
||||
assert call_args.kwargs['user'] == user
|
||||
assert call_args.kwargs['collection'] == collection
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_document_rag_verbose_logging(self, mock_embeddings_client,
|
||||
mock_doc_embeddings_client, mock_prompt_client,
|
||||
capsys):
|
||||
"""Test DocumentRAG verbose logging functionality"""
|
||||
# Arrange
|
||||
document_rag = DocumentRag(
|
||||
embeddings_client=mock_embeddings_client,
|
||||
doc_embeddings_client=mock_doc_embeddings_client,
|
||||
prompt_client=mock_prompt_client,
|
||||
verbose=True
|
||||
)
|
||||
|
||||
# Act
|
||||
await document_rag.query("test query for verbose logging")
|
||||
|
||||
# Assert
|
||||
captured = capsys.readouterr()
|
||||
assert "Initialised" in captured.out
|
||||
assert "Construct prompt..." in captured.out
|
||||
assert "Compute embeddings..." in captured.out
|
||||
assert "Get docs..." in captured.out
|
||||
assert "Invoke LLM..." in captured.out
|
||||
assert "Done" in captured.out
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.slow
|
||||
async def test_document_rag_performance_with_large_document_set(self, document_rag,
|
||||
mock_doc_embeddings_client):
|
||||
"""Test DocumentRAG performance with large document retrieval"""
|
||||
# Arrange - Mock large document set (100 documents)
|
||||
large_doc_set = [f"Document {i} content about machine learning and AI" for i in range(100)]
|
||||
mock_doc_embeddings_client.query.return_value = large_doc_set
|
||||
|
||||
# Act
|
||||
import time
|
||||
start_time = time.time()
|
||||
|
||||
result = await document_rag.query("performance test query", doc_limit=100)
|
||||
|
||||
end_time = time.time()
|
||||
execution_time = end_time - start_time
|
||||
|
||||
# Assert
|
||||
assert result is not None
|
||||
assert execution_time < 5.0 # Should complete within 5 seconds
|
||||
mock_doc_embeddings_client.query.assert_called_once()
|
||||
call_args = mock_doc_embeddings_client.query.call_args
|
||||
assert call_args.kwargs['limit'] == 100
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_document_rag_default_parameters(self, document_rag, mock_doc_embeddings_client):
|
||||
"""Test DocumentRAG uses correct default parameters"""
|
||||
# Act
|
||||
await document_rag.query("test query with defaults")
|
||||
|
||||
# Assert
|
||||
mock_doc_embeddings_client.query.assert_called_once()
|
||||
call_args = mock_doc_embeddings_client.query.call_args
|
||||
assert call_args.kwargs['user'] == "trustgraph"
|
||||
assert call_args.kwargs['collection'] == "default"
|
||||
assert call_args.kwargs['limit'] == 20
|
||||
642
tests/integration/test_kg_extract_store_integration.py
Normal file
642
tests/integration/test_kg_extract_store_integration.py
Normal file
|
|
@ -0,0 +1,642 @@
|
|||
"""
|
||||
Integration tests for Knowledge Graph Extract → Store Pipeline
|
||||
|
||||
These tests verify the end-to-end functionality of the knowledge graph extraction
|
||||
and storage pipeline, testing text-to-graph transformation, entity extraction,
|
||||
relationship extraction, and graph database storage.
|
||||
Following the TEST_STRATEGY.md approach for integration testing.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import json
|
||||
import urllib.parse
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
from trustgraph.extract.kg.definitions.extract import Processor as DefinitionsProcessor
|
||||
from trustgraph.extract.kg.relationships.extract import Processor as RelationshipsProcessor
|
||||
from trustgraph.storage.knowledge.store import Processor as KnowledgeStoreProcessor
|
||||
from trustgraph.schema import Chunk, Triple, Triples, Metadata, Value, Error
|
||||
from trustgraph.schema import EntityContext, EntityContexts, GraphEmbeddings
|
||||
from trustgraph.rdf import TRUSTGRAPH_ENTITIES, DEFINITION, RDF_LABEL, SUBJECT_OF
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
class TestKnowledgeGraphPipelineIntegration:
|
||||
"""Integration tests for Knowledge Graph Extract → Store Pipeline"""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_flow_context(self):
|
||||
"""Mock flow context for service coordination"""
|
||||
context = MagicMock()
|
||||
|
||||
# Mock prompt client for definitions extraction
|
||||
prompt_client = AsyncMock()
|
||||
prompt_client.extract_definitions.return_value = [
|
||||
{
|
||||
"entity": "Machine Learning",
|
||||
"definition": "A subset of artificial intelligence that enables computers to learn from data without explicit programming."
|
||||
},
|
||||
{
|
||||
"entity": "Neural Networks",
|
||||
"definition": "Computing systems inspired by biological neural networks that process information."
|
||||
}
|
||||
]
|
||||
|
||||
# Mock prompt client for relationships extraction
|
||||
prompt_client.extract_relationships.return_value = [
|
||||
{
|
||||
"subject": "Machine Learning",
|
||||
"predicate": "is_subset_of",
|
||||
"object": "Artificial Intelligence",
|
||||
"object-entity": True
|
||||
},
|
||||
{
|
||||
"subject": "Neural Networks",
|
||||
"predicate": "is_used_in",
|
||||
"object": "Machine Learning",
|
||||
"object-entity": True
|
||||
}
|
||||
]
|
||||
|
||||
# Mock producers for output streams
|
||||
triples_producer = AsyncMock()
|
||||
entity_contexts_producer = AsyncMock()
|
||||
|
||||
# Configure context routing
|
||||
def context_router(service_name):
|
||||
if service_name == "prompt-request":
|
||||
return prompt_client
|
||||
elif service_name == "triples":
|
||||
return triples_producer
|
||||
elif service_name == "entity-contexts":
|
||||
return entity_contexts_producer
|
||||
else:
|
||||
return AsyncMock()
|
||||
|
||||
context.side_effect = context_router
|
||||
return context
|
||||
|
||||
@pytest.fixture
|
||||
def mock_cassandra_store(self):
|
||||
"""Mock Cassandra knowledge table store"""
|
||||
store = AsyncMock()
|
||||
store.add_triples.return_value = None
|
||||
store.add_graph_embeddings.return_value = None
|
||||
return store
|
||||
|
||||
@pytest.fixture
|
||||
def sample_chunk(self):
|
||||
"""Sample text chunk for processing"""
|
||||
return Chunk(
|
||||
metadata=Metadata(
|
||||
id="doc-123",
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
metadata=[]
|
||||
),
|
||||
chunk=b"Machine Learning is a subset of Artificial Intelligence. Neural Networks are used in Machine Learning to process complex patterns."
|
||||
)
|
||||
|
||||
@pytest.fixture
|
||||
def sample_definitions_response(self):
|
||||
"""Sample definitions extraction response"""
|
||||
return [
|
||||
{
|
||||
"entity": "Machine Learning",
|
||||
"definition": "A subset of artificial intelligence that enables computers to learn from data."
|
||||
},
|
||||
{
|
||||
"entity": "Artificial Intelligence",
|
||||
"definition": "The simulation of human intelligence in machines."
|
||||
},
|
||||
{
|
||||
"entity": "Neural Networks",
|
||||
"definition": "Computing systems inspired by biological neural networks."
|
||||
}
|
||||
]
|
||||
|
||||
@pytest.fixture
|
||||
def sample_relationships_response(self):
|
||||
"""Sample relationships extraction response"""
|
||||
return [
|
||||
{
|
||||
"subject": "Machine Learning",
|
||||
"predicate": "is_subset_of",
|
||||
"object": "Artificial Intelligence",
|
||||
"object-entity": True
|
||||
},
|
||||
{
|
||||
"subject": "Neural Networks",
|
||||
"predicate": "is_used_in",
|
||||
"object": "Machine Learning",
|
||||
"object-entity": True
|
||||
},
|
||||
{
|
||||
"subject": "Machine Learning",
|
||||
"predicate": "processes",
|
||||
"object": "data patterns",
|
||||
"object-entity": False
|
||||
}
|
||||
]
|
||||
|
||||
@pytest.fixture
|
||||
def definitions_processor(self):
|
||||
"""Create definitions processor with minimal configuration"""
|
||||
processor = MagicMock()
|
||||
processor.to_uri = DefinitionsProcessor.to_uri.__get__(processor, DefinitionsProcessor)
|
||||
processor.emit_triples = DefinitionsProcessor.emit_triples.__get__(processor, DefinitionsProcessor)
|
||||
processor.emit_ecs = DefinitionsProcessor.emit_ecs.__get__(processor, DefinitionsProcessor)
|
||||
processor.on_message = DefinitionsProcessor.on_message.__get__(processor, DefinitionsProcessor)
|
||||
return processor
|
||||
|
||||
@pytest.fixture
|
||||
def relationships_processor(self):
|
||||
"""Create relationships processor with minimal configuration"""
|
||||
processor = MagicMock()
|
||||
processor.to_uri = RelationshipsProcessor.to_uri.__get__(processor, RelationshipsProcessor)
|
||||
processor.emit_triples = RelationshipsProcessor.emit_triples.__get__(processor, RelationshipsProcessor)
|
||||
processor.on_message = RelationshipsProcessor.on_message.__get__(processor, RelationshipsProcessor)
|
||||
return processor
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_definitions_extraction_pipeline(self, definitions_processor, mock_flow_context, sample_chunk):
|
||||
"""Test definitions extraction from text chunk to graph triples"""
|
||||
# Arrange
|
||||
mock_msg = MagicMock()
|
||||
mock_msg.value.return_value = sample_chunk
|
||||
mock_consumer = MagicMock()
|
||||
|
||||
# Act
|
||||
await definitions_processor.on_message(mock_msg, mock_consumer, mock_flow_context)
|
||||
|
||||
# Assert
|
||||
# Verify prompt client was called for definitions extraction
|
||||
prompt_client = mock_flow_context("prompt-request")
|
||||
prompt_client.extract_definitions.assert_called_once()
|
||||
call_args = prompt_client.extract_definitions.call_args
|
||||
assert "Machine Learning" in call_args.kwargs['text']
|
||||
assert "Neural Networks" in call_args.kwargs['text']
|
||||
|
||||
# Verify triples producer was called
|
||||
triples_producer = mock_flow_context("triples")
|
||||
triples_producer.send.assert_called_once()
|
||||
|
||||
# Verify entity contexts producer was called
|
||||
entity_contexts_producer = mock_flow_context("entity-contexts")
|
||||
entity_contexts_producer.send.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_relationships_extraction_pipeline(self, relationships_processor, mock_flow_context, sample_chunk):
|
||||
"""Test relationships extraction from text chunk to graph triples"""
|
||||
# Arrange
|
||||
mock_msg = MagicMock()
|
||||
mock_msg.value.return_value = sample_chunk
|
||||
mock_consumer = MagicMock()
|
||||
|
||||
# Act
|
||||
await relationships_processor.on_message(mock_msg, mock_consumer, mock_flow_context)
|
||||
|
||||
# Assert
|
||||
# Verify prompt client was called for relationships extraction
|
||||
prompt_client = mock_flow_context("prompt-request")
|
||||
prompt_client.extract_relationships.assert_called_once()
|
||||
call_args = prompt_client.extract_relationships.call_args
|
||||
assert "Machine Learning" in call_args.kwargs['text']
|
||||
|
||||
# Verify triples producer was called
|
||||
triples_producer = mock_flow_context("triples")
|
||||
triples_producer.send.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_uri_generation_consistency(self, definitions_processor, relationships_processor):
|
||||
"""Test URI generation consistency between processors"""
|
||||
# Arrange
|
||||
test_entities = [
|
||||
"Machine Learning",
|
||||
"Artificial Intelligence",
|
||||
"Neural Networks",
|
||||
"Deep Learning",
|
||||
"Natural Language Processing"
|
||||
]
|
||||
|
||||
# Act & Assert
|
||||
for entity in test_entities:
|
||||
def_uri = definitions_processor.to_uri(entity)
|
||||
rel_uri = relationships_processor.to_uri(entity)
|
||||
|
||||
# URIs should be identical between processors
|
||||
assert def_uri == rel_uri
|
||||
|
||||
# URI should be properly encoded
|
||||
assert def_uri.startswith(TRUSTGRAPH_ENTITIES)
|
||||
assert " " not in def_uri
|
||||
assert def_uri.endswith(urllib.parse.quote(entity.replace(" ", "-").lower().encode("utf-8")))
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_definitions_triple_generation(self, definitions_processor, sample_definitions_response):
|
||||
"""Test triple generation from definitions extraction"""
|
||||
# Arrange
|
||||
metadata = Metadata(
|
||||
id="test-doc",
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
metadata=[]
|
||||
)
|
||||
|
||||
# Act
|
||||
triples = []
|
||||
entities = []
|
||||
|
||||
for defn in sample_definitions_response:
|
||||
s = defn["entity"]
|
||||
o = defn["definition"]
|
||||
|
||||
if s and o:
|
||||
s_uri = definitions_processor.to_uri(s)
|
||||
s_value = Value(value=str(s_uri), is_uri=True)
|
||||
o_value = Value(value=str(o), is_uri=False)
|
||||
|
||||
# Generate triples as the processor would
|
||||
triples.append(Triple(
|
||||
s=s_value,
|
||||
p=Value(value=RDF_LABEL, is_uri=True),
|
||||
o=Value(value=s, is_uri=False)
|
||||
))
|
||||
|
||||
triples.append(Triple(
|
||||
s=s_value,
|
||||
p=Value(value=DEFINITION, is_uri=True),
|
||||
o=o_value
|
||||
))
|
||||
|
||||
entities.append(EntityContext(
|
||||
entity=s_value,
|
||||
context=defn["definition"]
|
||||
))
|
||||
|
||||
# Assert
|
||||
assert len(triples) == 6 # 2 triples per entity * 3 entities
|
||||
assert len(entities) == 3 # 1 entity context per entity
|
||||
|
||||
# Verify triple structure
|
||||
label_triples = [t for t in triples if t.p.value == RDF_LABEL]
|
||||
definition_triples = [t for t in triples if t.p.value == DEFINITION]
|
||||
|
||||
assert len(label_triples) == 3
|
||||
assert len(definition_triples) == 3
|
||||
|
||||
# Verify entity contexts
|
||||
for entity in entities:
|
||||
assert entity.entity.is_uri is True
|
||||
assert entity.entity.value.startswith(TRUSTGRAPH_ENTITIES)
|
||||
assert len(entity.context) > 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_relationships_triple_generation(self, relationships_processor, sample_relationships_response):
|
||||
"""Test triple generation from relationships extraction"""
|
||||
# Arrange
|
||||
metadata = Metadata(
|
||||
id="test-doc",
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
metadata=[]
|
||||
)
|
||||
|
||||
# Act
|
||||
triples = []
|
||||
|
||||
for rel in sample_relationships_response:
|
||||
s = rel["subject"]
|
||||
p = rel["predicate"]
|
||||
o = rel["object"]
|
||||
|
||||
if s and p and o:
|
||||
s_uri = relationships_processor.to_uri(s)
|
||||
s_value = Value(value=str(s_uri), is_uri=True)
|
||||
|
||||
p_uri = relationships_processor.to_uri(p)
|
||||
p_value = Value(value=str(p_uri), is_uri=True)
|
||||
|
||||
if rel["object-entity"]:
|
||||
o_uri = relationships_processor.to_uri(o)
|
||||
o_value = Value(value=str(o_uri), is_uri=True)
|
||||
else:
|
||||
o_value = Value(value=str(o), is_uri=False)
|
||||
|
||||
# Main relationship triple
|
||||
triples.append(Triple(s=s_value, p=p_value, o=o_value))
|
||||
|
||||
# Label triples
|
||||
triples.append(Triple(
|
||||
s=s_value,
|
||||
p=Value(value=RDF_LABEL, is_uri=True),
|
||||
o=Value(value=str(s), is_uri=False)
|
||||
))
|
||||
|
||||
triples.append(Triple(
|
||||
s=p_value,
|
||||
p=Value(value=RDF_LABEL, is_uri=True),
|
||||
o=Value(value=str(p), is_uri=False)
|
||||
))
|
||||
|
||||
if rel["object-entity"]:
|
||||
triples.append(Triple(
|
||||
s=o_value,
|
||||
p=Value(value=RDF_LABEL, is_uri=True),
|
||||
o=Value(value=str(o), is_uri=False)
|
||||
))
|
||||
|
||||
# Assert
|
||||
assert len(triples) > 0
|
||||
|
||||
# Verify relationship triples exist
|
||||
relationship_triples = [t for t in triples if t.p.value.endswith("is_subset_of") or t.p.value.endswith("is_used_in")]
|
||||
assert len(relationship_triples) >= 2
|
||||
|
||||
# Verify label triples
|
||||
label_triples = [t for t in triples if t.p.value == RDF_LABEL]
|
||||
assert len(label_triples) > 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_knowledge_store_triples_storage(self, mock_cassandra_store):
|
||||
"""Test knowledge store triples storage integration"""
|
||||
# Arrange
|
||||
processor = MagicMock()
|
||||
processor.table_store = mock_cassandra_store
|
||||
processor.on_triples = KnowledgeStoreProcessor.on_triples.__get__(processor, KnowledgeStoreProcessor)
|
||||
|
||||
sample_triples = Triples(
|
||||
metadata=Metadata(
|
||||
id="test-doc",
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
metadata=[]
|
||||
),
|
||||
triples=[
|
||||
Triple(
|
||||
s=Value(value="http://trustgraph.ai/e/machine-learning", is_uri=True),
|
||||
p=Value(value=DEFINITION, is_uri=True),
|
||||
o=Value(value="A subset of AI", is_uri=False)
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
mock_msg = MagicMock()
|
||||
mock_msg.value.return_value = sample_triples
|
||||
|
||||
# Act
|
||||
await processor.on_triples(mock_msg, None, None)
|
||||
|
||||
# Assert
|
||||
mock_cassandra_store.add_triples.assert_called_once_with(sample_triples)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_knowledge_store_graph_embeddings_storage(self, mock_cassandra_store):
|
||||
"""Test knowledge store graph embeddings storage integration"""
|
||||
# Arrange
|
||||
processor = MagicMock()
|
||||
processor.table_store = mock_cassandra_store
|
||||
processor.on_graph_embeddings = KnowledgeStoreProcessor.on_graph_embeddings.__get__(processor, KnowledgeStoreProcessor)
|
||||
|
||||
sample_embeddings = GraphEmbeddings(
|
||||
metadata=Metadata(
|
||||
id="test-doc",
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
metadata=[]
|
||||
),
|
||||
entities=[]
|
||||
)
|
||||
|
||||
mock_msg = MagicMock()
|
||||
mock_msg.value.return_value = sample_embeddings
|
||||
|
||||
# Act
|
||||
await processor.on_graph_embeddings(mock_msg, None, None)
|
||||
|
||||
# Assert
|
||||
mock_cassandra_store.add_graph_embeddings.assert_called_once_with(sample_embeddings)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_end_to_end_pipeline_coordination(self, definitions_processor, relationships_processor,
|
||||
mock_flow_context, sample_chunk):
|
||||
"""Test end-to-end pipeline coordination from chunk to storage"""
|
||||
# Arrange
|
||||
mock_msg = MagicMock()
|
||||
mock_msg.value.return_value = sample_chunk
|
||||
mock_consumer = MagicMock()
|
||||
|
||||
# Act - Process through definitions extractor
|
||||
await definitions_processor.on_message(mock_msg, mock_consumer, mock_flow_context)
|
||||
|
||||
# Act - Process through relationships extractor
|
||||
await relationships_processor.on_message(mock_msg, mock_consumer, mock_flow_context)
|
||||
|
||||
# Assert
|
||||
# Verify both extractors called prompt service
|
||||
prompt_client = mock_flow_context("prompt-request")
|
||||
prompt_client.extract_definitions.assert_called_once()
|
||||
prompt_client.extract_relationships.assert_called_once()
|
||||
|
||||
# Verify triples were produced from both extractors
|
||||
triples_producer = mock_flow_context("triples")
|
||||
assert triples_producer.send.call_count == 2 # One from each extractor
|
||||
|
||||
# Verify entity contexts were produced from definitions extractor
|
||||
entity_contexts_producer = mock_flow_context("entity-contexts")
|
||||
entity_contexts_producer.send.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_error_handling_in_definitions_extraction(self, definitions_processor, mock_flow_context, sample_chunk):
|
||||
"""Test error handling in definitions extraction"""
|
||||
# Arrange
|
||||
mock_flow_context("prompt-request").extract_definitions.side_effect = Exception("Prompt service unavailable")
|
||||
|
||||
mock_msg = MagicMock()
|
||||
mock_msg.value.return_value = sample_chunk
|
||||
mock_consumer = MagicMock()
|
||||
|
||||
# Act & Assert
|
||||
# Should not raise exception, but should handle it gracefully
|
||||
await definitions_processor.on_message(mock_msg, mock_consumer, mock_flow_context)
|
||||
|
||||
# Verify prompt was attempted
|
||||
prompt_client = mock_flow_context("prompt-request")
|
||||
prompt_client.extract_definitions.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_error_handling_in_relationships_extraction(self, relationships_processor, mock_flow_context, sample_chunk):
|
||||
"""Test error handling in relationships extraction"""
|
||||
# Arrange
|
||||
mock_flow_context("prompt-request").extract_relationships.side_effect = Exception("Prompt service unavailable")
|
||||
|
||||
mock_msg = MagicMock()
|
||||
mock_msg.value.return_value = sample_chunk
|
||||
mock_consumer = MagicMock()
|
||||
|
||||
# Act & Assert
|
||||
# Should not raise exception, but should handle it gracefully
|
||||
await relationships_processor.on_message(mock_msg, mock_consumer, mock_flow_context)
|
||||
|
||||
# Verify prompt was attempted
|
||||
prompt_client = mock_flow_context("prompt-request")
|
||||
prompt_client.extract_relationships.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_extraction_results_handling(self, definitions_processor, mock_flow_context, sample_chunk):
|
||||
"""Test handling of empty extraction results"""
|
||||
# Arrange
|
||||
mock_flow_context("prompt-request").extract_definitions.return_value = []
|
||||
|
||||
mock_msg = MagicMock()
|
||||
mock_msg.value.return_value = sample_chunk
|
||||
mock_consumer = MagicMock()
|
||||
|
||||
# Act
|
||||
await definitions_processor.on_message(mock_msg, mock_consumer, mock_flow_context)
|
||||
|
||||
# Assert
|
||||
# Should still call producers but with empty results
|
||||
triples_producer = mock_flow_context("triples")
|
||||
entity_contexts_producer = mock_flow_context("entity-contexts")
|
||||
|
||||
triples_producer.send.assert_called_once()
|
||||
entity_contexts_producer.send.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invalid_extraction_format_handling(self, definitions_processor, mock_flow_context, sample_chunk):
|
||||
"""Test handling of invalid extraction response format"""
|
||||
# Arrange
|
||||
mock_flow_context("prompt-request").extract_definitions.return_value = "invalid format" # Should be list
|
||||
|
||||
mock_msg = MagicMock()
|
||||
mock_msg.value.return_value = sample_chunk
|
||||
mock_consumer = MagicMock()
|
||||
|
||||
# Act & Assert
|
||||
# Should handle invalid format gracefully
|
||||
await definitions_processor.on_message(mock_msg, mock_consumer, mock_flow_context)
|
||||
|
||||
# Verify prompt was attempted
|
||||
prompt_client = mock_flow_context("prompt-request")
|
||||
prompt_client.extract_definitions.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_entity_filtering_and_validation(self, definitions_processor, mock_flow_context):
|
||||
"""Test entity filtering and validation in extraction"""
|
||||
# Arrange
|
||||
mock_flow_context("prompt-request").extract_definitions.return_value = [
|
||||
{"entity": "Valid Entity", "definition": "Valid definition"},
|
||||
{"entity": "", "definition": "Empty entity"}, # Should be filtered
|
||||
{"entity": "Valid Entity 2", "definition": ""}, # Should be filtered
|
||||
{"entity": None, "definition": "None entity"}, # Should be filtered
|
||||
{"entity": "Valid Entity 3", "definition": None}, # Should be filtered
|
||||
]
|
||||
|
||||
sample_chunk = Chunk(
|
||||
metadata=Metadata(id="test", user="user", collection="collection", metadata=[]),
|
||||
chunk=b"Test chunk"
|
||||
)
|
||||
|
||||
mock_msg = MagicMock()
|
||||
mock_msg.value.return_value = sample_chunk
|
||||
mock_consumer = MagicMock()
|
||||
|
||||
# Act
|
||||
await definitions_processor.on_message(mock_msg, mock_consumer, mock_flow_context)
|
||||
|
||||
# Assert
|
||||
# Should only process valid entities
|
||||
triples_producer = mock_flow_context("triples")
|
||||
entity_contexts_producer = mock_flow_context("entity-contexts")
|
||||
|
||||
triples_producer.send.assert_called_once()
|
||||
entity_contexts_producer.send.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.slow
|
||||
async def test_large_batch_processing_performance(self, definitions_processor, relationships_processor,
|
||||
mock_flow_context):
|
||||
"""Test performance with large batch of chunks"""
|
||||
# Arrange
|
||||
large_chunk_batch = [
|
||||
Chunk(
|
||||
metadata=Metadata(id=f"doc-{i}", user="user", collection="collection", metadata=[]),
|
||||
chunk=f"Document {i} contains machine learning and AI content.".encode("utf-8")
|
||||
)
|
||||
for i in range(100) # Large batch
|
||||
]
|
||||
|
||||
mock_consumer = MagicMock()
|
||||
|
||||
# Act
|
||||
import time
|
||||
start_time = time.time()
|
||||
|
||||
for chunk in large_chunk_batch:
|
||||
mock_msg = MagicMock()
|
||||
mock_msg.value.return_value = chunk
|
||||
|
||||
# Process through both extractors
|
||||
await definitions_processor.on_message(mock_msg, mock_consumer, mock_flow_context)
|
||||
await relationships_processor.on_message(mock_msg, mock_consumer, mock_flow_context)
|
||||
|
||||
end_time = time.time()
|
||||
execution_time = end_time - start_time
|
||||
|
||||
# Assert
|
||||
assert execution_time < 30.0 # Should complete within reasonable time
|
||||
|
||||
# Verify all chunks were processed
|
||||
prompt_client = mock_flow_context("prompt-request")
|
||||
assert prompt_client.extract_definitions.call_count == 100
|
||||
assert prompt_client.extract_relationships.call_count == 100
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_metadata_propagation_through_pipeline(self, definitions_processor, mock_flow_context):
|
||||
"""Test metadata propagation through the pipeline"""
|
||||
# Arrange
|
||||
original_metadata = Metadata(
|
||||
id="test-doc-123",
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
metadata=[
|
||||
Triple(
|
||||
s=Value(value="doc:test", is_uri=True),
|
||||
p=Value(value="dc:title", is_uri=True),
|
||||
o=Value(value="Test Document", is_uri=False)
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
sample_chunk = Chunk(
|
||||
metadata=original_metadata,
|
||||
chunk=b"Test content for metadata propagation"
|
||||
)
|
||||
|
||||
mock_msg = MagicMock()
|
||||
mock_msg.value.return_value = sample_chunk
|
||||
mock_consumer = MagicMock()
|
||||
|
||||
# Act
|
||||
await definitions_processor.on_message(mock_msg, mock_consumer, mock_flow_context)
|
||||
|
||||
# Assert
|
||||
# Verify metadata was propagated to output
|
||||
triples_producer = mock_flow_context("triples")
|
||||
entity_contexts_producer = mock_flow_context("entity-contexts")
|
||||
|
||||
triples_producer.send.assert_called_once()
|
||||
entity_contexts_producer.send.assert_called_once()
|
||||
|
||||
# Check that metadata was included in the calls
|
||||
triples_call = triples_producer.send.call_args[0][0]
|
||||
entity_contexts_call = entity_contexts_producer.send.call_args[0][0]
|
||||
|
||||
assert triples_call.metadata.id == "test-doc-123"
|
||||
assert triples_call.metadata.user == "test_user"
|
||||
assert triples_call.metadata.collection == "test_collection"
|
||||
|
||||
assert entity_contexts_call.metadata.id == "test-doc-123"
|
||||
assert entity_contexts_call.metadata.user == "test_user"
|
||||
assert entity_contexts_call.metadata.collection == "test_collection"
|
||||
429
tests/integration/test_text_completion_integration.py
Normal file
429
tests/integration/test_text_completion_integration.py
Normal file
|
|
@ -0,0 +1,429 @@
|
|||
"""
|
||||
Integration tests for Text Completion Service (OpenAI)
|
||||
|
||||
These tests verify the end-to-end functionality of the OpenAI text completion service,
|
||||
testing API connectivity, authentication, rate limiting, error handling, and token tracking.
|
||||
Following the TEST_STRATEGY.md approach for integration testing.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import os
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from openai import OpenAI, RateLimitError
|
||||
from openai.types.chat import ChatCompletion, ChatCompletionMessage
|
||||
from openai.types.chat.chat_completion import Choice
|
||||
from openai.types.completion_usage import CompletionUsage
|
||||
|
||||
from trustgraph.model.text_completion.openai.llm import Processor
|
||||
from trustgraph.exceptions import TooManyRequests
|
||||
from trustgraph.base import LlmResult
|
||||
from trustgraph.schema import TextCompletionRequest, TextCompletionResponse, Error
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
class TestTextCompletionIntegration:
|
||||
"""Integration tests for OpenAI text completion service coordination"""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_openai_client(self):
|
||||
"""Mock OpenAI client that returns realistic responses"""
|
||||
client = MagicMock(spec=OpenAI)
|
||||
|
||||
# Mock chat completion response
|
||||
usage = CompletionUsage(prompt_tokens=50, completion_tokens=100, total_tokens=150)
|
||||
message = ChatCompletionMessage(role="assistant", content="This is a test response from the AI model.")
|
||||
choice = Choice(index=0, message=message, finish_reason="stop")
|
||||
|
||||
completion = ChatCompletion(
|
||||
id="chatcmpl-test123",
|
||||
choices=[choice],
|
||||
created=1234567890,
|
||||
model="gpt-3.5-turbo",
|
||||
object="chat.completion",
|
||||
usage=usage
|
||||
)
|
||||
|
||||
client.chat.completions.create.return_value = completion
|
||||
return client
|
||||
|
||||
@pytest.fixture
|
||||
def processor_config(self):
|
||||
"""Configuration for processor testing"""
|
||||
return {
|
||||
"model": "gpt-3.5-turbo",
|
||||
"temperature": 0.7,
|
||||
"max_output": 1024,
|
||||
}
|
||||
|
||||
@pytest.fixture
|
||||
def text_completion_processor(self, processor_config):
|
||||
"""Create text completion processor with test configuration"""
|
||||
# Create a minimal processor instance for testing generate_content
|
||||
processor = MagicMock()
|
||||
processor.model = processor_config["model"]
|
||||
processor.temperature = processor_config["temperature"]
|
||||
processor.max_output = processor_config["max_output"]
|
||||
|
||||
# Add the actual generate_content method from Processor class
|
||||
processor.generate_content = Processor.generate_content.__get__(processor, Processor)
|
||||
|
||||
return processor
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_text_completion_successful_generation(self, text_completion_processor, mock_openai_client):
|
||||
"""Test successful text completion generation"""
|
||||
# Arrange
|
||||
text_completion_processor.openai = mock_openai_client
|
||||
system_prompt = "You are a helpful assistant."
|
||||
user_prompt = "What is machine learning?"
|
||||
|
||||
# Act
|
||||
result = await text_completion_processor.generate_content(system_prompt, user_prompt)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "This is a test response from the AI model."
|
||||
assert result.in_token == 50
|
||||
assert result.out_token == 100
|
||||
assert result.model == "gpt-3.5-turbo"
|
||||
|
||||
# Verify OpenAI API was called correctly
|
||||
mock_openai_client.chat.completions.create.assert_called_once()
|
||||
call_args = mock_openai_client.chat.completions.create.call_args
|
||||
|
||||
assert call_args.kwargs['model'] == "gpt-3.5-turbo"
|
||||
assert call_args.kwargs['temperature'] == 0.7
|
||||
assert call_args.kwargs['max_tokens'] == 1024
|
||||
assert len(call_args.kwargs['messages']) == 1
|
||||
assert call_args.kwargs['messages'][0]['role'] == "user"
|
||||
assert "You are a helpful assistant." in call_args.kwargs['messages'][0]['content'][0]['text']
|
||||
assert "What is machine learning?" in call_args.kwargs['messages'][0]['content'][0]['text']
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_text_completion_with_different_configurations(self, mock_openai_client):
|
||||
"""Test text completion with various configuration parameters"""
|
||||
# Test different configurations
|
||||
test_configs = [
|
||||
{"model": "gpt-4", "temperature": 0.0, "max_output": 512},
|
||||
{"model": "gpt-3.5-turbo", "temperature": 1.0, "max_output": 2048},
|
||||
{"model": "gpt-4-turbo", "temperature": 0.5, "max_output": 4096}
|
||||
]
|
||||
|
||||
for config in test_configs:
|
||||
# Arrange - Create minimal processor mock
|
||||
processor = MagicMock()
|
||||
processor.model = config['model']
|
||||
processor.temperature = config['temperature']
|
||||
processor.max_output = config['max_output']
|
||||
processor.openai = mock_openai_client
|
||||
|
||||
# Add the actual generate_content method
|
||||
processor.generate_content = Processor.generate_content.__get__(processor, Processor)
|
||||
|
||||
# Act
|
||||
result = await processor.generate_content("System prompt", "User prompt")
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "This is a test response from the AI model."
|
||||
assert result.in_token == 50
|
||||
assert result.out_token == 100
|
||||
# Note: result.model comes from mock response, not processor config
|
||||
|
||||
# Verify configuration was applied
|
||||
call_args = mock_openai_client.chat.completions.create.call_args
|
||||
assert call_args.kwargs['model'] == config['model']
|
||||
assert call_args.kwargs['temperature'] == config['temperature']
|
||||
assert call_args.kwargs['max_tokens'] == config['max_output']
|
||||
|
||||
# Reset mock for next iteration
|
||||
mock_openai_client.reset_mock()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_text_completion_rate_limit_handling(self, text_completion_processor, mock_openai_client):
|
||||
"""Test proper rate limit error handling"""
|
||||
# Arrange
|
||||
mock_openai_client.chat.completions.create.side_effect = RateLimitError(
|
||||
"Rate limit exceeded",
|
||||
response=MagicMock(status_code=429),
|
||||
body={}
|
||||
)
|
||||
text_completion_processor.openai = mock_openai_client
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(TooManyRequests):
|
||||
await text_completion_processor.generate_content("System prompt", "User prompt")
|
||||
|
||||
# Verify OpenAI API was called
|
||||
mock_openai_client.chat.completions.create.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_text_completion_api_error_handling(self, text_completion_processor, mock_openai_client):
|
||||
"""Test handling of general API errors"""
|
||||
# Arrange
|
||||
mock_openai_client.chat.completions.create.side_effect = Exception("API connection failed")
|
||||
text_completion_processor.openai = mock_openai_client
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
await text_completion_processor.generate_content("System prompt", "User prompt")
|
||||
|
||||
assert "API connection failed" in str(exc_info.value)
|
||||
mock_openai_client.chat.completions.create.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_text_completion_token_tracking(self, text_completion_processor, mock_openai_client):
|
||||
"""Test accurate token counting and tracking"""
|
||||
# Arrange - Different token counts for multiple requests
|
||||
test_cases = [
|
||||
(25, 75), # Small request
|
||||
(100, 200), # Medium request
|
||||
(500, 1000) # Large request
|
||||
]
|
||||
|
||||
for input_tokens, output_tokens in test_cases:
|
||||
# Update mock response with different token counts
|
||||
usage = CompletionUsage(
|
||||
prompt_tokens=input_tokens,
|
||||
completion_tokens=output_tokens,
|
||||
total_tokens=input_tokens + output_tokens
|
||||
)
|
||||
message = ChatCompletionMessage(role="assistant", content="Test response")
|
||||
choice = Choice(index=0, message=message, finish_reason="stop")
|
||||
|
||||
completion = ChatCompletion(
|
||||
id="chatcmpl-test123",
|
||||
choices=[choice],
|
||||
created=1234567890,
|
||||
model="gpt-3.5-turbo",
|
||||
object="chat.completion",
|
||||
usage=usage
|
||||
)
|
||||
|
||||
mock_openai_client.chat.completions.create.return_value = completion
|
||||
text_completion_processor.openai = mock_openai_client
|
||||
|
||||
# Act
|
||||
result = await text_completion_processor.generate_content("System", "Prompt")
|
||||
|
||||
# Assert
|
||||
assert result.in_token == input_tokens
|
||||
assert result.out_token == output_tokens
|
||||
assert result.model == "gpt-3.5-turbo"
|
||||
|
||||
# Reset mock for next iteration
|
||||
mock_openai_client.reset_mock()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_text_completion_prompt_construction(self, text_completion_processor, mock_openai_client):
|
||||
"""Test proper prompt construction with system and user prompts"""
|
||||
# Arrange
|
||||
text_completion_processor.openai = mock_openai_client
|
||||
system_prompt = "You are an expert in artificial intelligence."
|
||||
user_prompt = "Explain neural networks in simple terms."
|
||||
|
||||
# Act
|
||||
result = await text_completion_processor.generate_content(system_prompt, user_prompt)
|
||||
|
||||
# Assert
|
||||
call_args = mock_openai_client.chat.completions.create.call_args
|
||||
sent_message = call_args.kwargs['messages'][0]['content'][0]['text']
|
||||
|
||||
# Verify system and user prompts are combined correctly
|
||||
assert system_prompt in sent_message
|
||||
assert user_prompt in sent_message
|
||||
assert sent_message.startswith(system_prompt)
|
||||
assert user_prompt in sent_message
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_text_completion_concurrent_requests(self, processor_config, mock_openai_client):
|
||||
"""Test handling of concurrent requests"""
|
||||
# Arrange
|
||||
processors = []
|
||||
for i in range(5):
|
||||
processor = MagicMock()
|
||||
processor.model = processor_config["model"]
|
||||
processor.temperature = processor_config["temperature"]
|
||||
processor.max_output = processor_config["max_output"]
|
||||
processor.openai = mock_openai_client
|
||||
processor.generate_content = Processor.generate_content.__get__(processor, Processor)
|
||||
processors.append(processor)
|
||||
|
||||
# Simulate multiple concurrent requests
|
||||
tasks = []
|
||||
for i, processor in enumerate(processors):
|
||||
task = processor.generate_content(f"System {i}", f"Prompt {i}")
|
||||
tasks.append(task)
|
||||
|
||||
# Act
|
||||
import asyncio
|
||||
results = await asyncio.gather(*tasks)
|
||||
|
||||
# Assert
|
||||
assert len(results) == 5
|
||||
for result in results:
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "This is a test response from the AI model."
|
||||
assert result.in_token == 50
|
||||
assert result.out_token == 100
|
||||
|
||||
# Verify all requests were processed
|
||||
assert mock_openai_client.chat.completions.create.call_count == 5
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_text_completion_response_format_validation(self, text_completion_processor, mock_openai_client):
|
||||
"""Test response format and structure validation"""
|
||||
# Arrange
|
||||
text_completion_processor.openai = mock_openai_client
|
||||
|
||||
# Act
|
||||
result = await text_completion_processor.generate_content("System", "Prompt")
|
||||
|
||||
# Assert
|
||||
# Verify OpenAI API call parameters
|
||||
call_args = mock_openai_client.chat.completions.create.call_args
|
||||
assert call_args.kwargs['response_format'] == {"type": "text"}
|
||||
assert call_args.kwargs['top_p'] == 1
|
||||
assert call_args.kwargs['frequency_penalty'] == 0
|
||||
assert call_args.kwargs['presence_penalty'] == 0
|
||||
|
||||
# Verify result structure
|
||||
assert hasattr(result, 'text')
|
||||
assert hasattr(result, 'in_token')
|
||||
assert hasattr(result, 'out_token')
|
||||
assert hasattr(result, 'model')
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_text_completion_authentication_patterns(self):
|
||||
"""Test different authentication configurations"""
|
||||
# Test missing API key first (this should fail early)
|
||||
with pytest.raises(RuntimeError) as exc_info:
|
||||
Processor(id="test-no-key", api_key=None)
|
||||
assert "OpenAI API key not specified" in str(exc_info.value)
|
||||
|
||||
# Test authentication pattern by examining the initialization logic
|
||||
# Since we can't fully instantiate due to taskgroup requirements,
|
||||
# we'll test the authentication logic directly
|
||||
from trustgraph.model.text_completion.openai.llm import default_api_key, default_base_url
|
||||
|
||||
# Test default values
|
||||
assert default_base_url == "https://api.openai.com/v1"
|
||||
|
||||
# Test configuration parameters
|
||||
test_configs = [
|
||||
{"api_key": "test-key-1", "url": "https://api.openai.com/v1"},
|
||||
{"api_key": "test-key-2", "url": "https://custom.openai.com/v1"},
|
||||
]
|
||||
|
||||
for config in test_configs:
|
||||
# We can't fully test instantiation due to taskgroup,
|
||||
# but we can verify the authentication logic would work
|
||||
assert config["api_key"] is not None
|
||||
assert config["url"] is not None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_text_completion_error_propagation(self, text_completion_processor, mock_openai_client):
|
||||
"""Test error propagation through the service"""
|
||||
# Test different error types
|
||||
error_cases = [
|
||||
(RateLimitError("Rate limit", response=MagicMock(status_code=429), body={}), TooManyRequests),
|
||||
(Exception("Connection timeout"), Exception),
|
||||
(ValueError("Invalid request"), ValueError),
|
||||
]
|
||||
|
||||
for error_input, expected_error in error_cases:
|
||||
# Arrange
|
||||
mock_openai_client.chat.completions.create.side_effect = error_input
|
||||
text_completion_processor.openai = mock_openai_client
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(expected_error):
|
||||
await text_completion_processor.generate_content("System", "Prompt")
|
||||
|
||||
# Reset mock for next iteration
|
||||
mock_openai_client.reset_mock()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_text_completion_model_parameter_validation(self, mock_openai_client):
|
||||
"""Test that model parameters are correctly passed to OpenAI API"""
|
||||
# Arrange
|
||||
processor = MagicMock()
|
||||
processor.model = "gpt-4"
|
||||
processor.temperature = 0.8
|
||||
processor.max_output = 2048
|
||||
processor.openai = mock_openai_client
|
||||
processor.generate_content = Processor.generate_content.__get__(processor, Processor)
|
||||
|
||||
# Act
|
||||
await processor.generate_content("System prompt", "User prompt")
|
||||
|
||||
# Assert
|
||||
call_args = mock_openai_client.chat.completions.create.call_args
|
||||
assert call_args.kwargs['model'] == "gpt-4"
|
||||
assert call_args.kwargs['temperature'] == 0.8
|
||||
assert call_args.kwargs['max_tokens'] == 2048
|
||||
assert call_args.kwargs['top_p'] == 1
|
||||
assert call_args.kwargs['frequency_penalty'] == 0
|
||||
assert call_args.kwargs['presence_penalty'] == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.slow
|
||||
async def test_text_completion_performance_timing(self, text_completion_processor, mock_openai_client):
|
||||
"""Test performance timing for text completion"""
|
||||
# Arrange
|
||||
text_completion_processor.openai = mock_openai_client
|
||||
|
||||
# Act
|
||||
import time
|
||||
start_time = time.time()
|
||||
|
||||
result = await text_completion_processor.generate_content("System", "Prompt")
|
||||
|
||||
end_time = time.time()
|
||||
execution_time = end_time - start_time
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert execution_time < 1.0 # Should complete quickly with mocked API
|
||||
mock_openai_client.chat.completions.create.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_text_completion_response_content_extraction(self, text_completion_processor, mock_openai_client):
|
||||
"""Test proper extraction of response content from OpenAI API"""
|
||||
# Arrange
|
||||
test_responses = [
|
||||
"This is a simple response.",
|
||||
"This is a multi-line response.\nWith multiple lines.\nAnd more content.",
|
||||
"Response with special characters: @#$%^&*()_+-=[]{}|;':\",./<>?",
|
||||
"" # Empty response
|
||||
]
|
||||
|
||||
for test_content in test_responses:
|
||||
# Update mock response
|
||||
usage = CompletionUsage(prompt_tokens=10, completion_tokens=20, total_tokens=30)
|
||||
message = ChatCompletionMessage(role="assistant", content=test_content)
|
||||
choice = Choice(index=0, message=message, finish_reason="stop")
|
||||
|
||||
completion = ChatCompletion(
|
||||
id="chatcmpl-test123",
|
||||
choices=[choice],
|
||||
created=1234567890,
|
||||
model="gpt-3.5-turbo",
|
||||
object="chat.completion",
|
||||
usage=usage
|
||||
)
|
||||
|
||||
mock_openai_client.chat.completions.create.return_value = completion
|
||||
text_completion_processor.openai = mock_openai_client
|
||||
|
||||
# Act
|
||||
result = await text_completion_processor.generate_content("System", "Prompt")
|
||||
|
||||
# Assert
|
||||
assert result.text == test_content
|
||||
assert result.in_token == 10
|
||||
assert result.out_token == 20
|
||||
assert result.model == "gpt-3.5-turbo"
|
||||
|
||||
# Reset mock for next iteration
|
||||
mock_openai_client.reset_mock()
|
||||
21
tests/pytest.ini
Normal file
21
tests/pytest.ini
Normal file
|
|
@ -0,0 +1,21 @@
|
|||
[pytest]
|
||||
testpaths = tests
|
||||
python_paths = .
|
||||
python_files = test_*.py
|
||||
python_classes = Test*
|
||||
python_functions = test_*
|
||||
addopts =
|
||||
-v
|
||||
--tb=short
|
||||
--strict-markers
|
||||
--disable-warnings
|
||||
--cov=trustgraph
|
||||
--cov-report=html
|
||||
--cov-report=term-missing
|
||||
# --cov-fail-under=80
|
||||
asyncio_mode = auto
|
||||
markers =
|
||||
slow: marks tests as slow (deselect with '-m "not slow"')
|
||||
integration: marks tests as integration tests
|
||||
unit: marks tests as unit tests
|
||||
vertexai: marks tests as vertex ai specific tests
|
||||
21
tests/query
21
tests/query
|
|
@ -1,21 +0,0 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
from trustgraph.graph_rag import GraphRag
|
||||
import sys
|
||||
|
||||
query = " ".join(sys.argv[1:])
|
||||
|
||||
gr = GraphRag(
|
||||
verbose=True,
|
||||
pulsar_host="pulsar://localhost:6650",
|
||||
pr_request_queue="non-persistent://tg/request/prompt",
|
||||
pr_response_queue="non-persistent://tg/response/prompt-response",
|
||||
)
|
||||
|
||||
if query == "":
|
||||
query="""This knowledge graph describes the Space Shuttle disaster.
|
||||
Present 20 facts which are present in the knowledge graph."""
|
||||
|
||||
resp = gr.query(query)
|
||||
print(resp)
|
||||
|
||||
|
|
@ -1,100 +0,0 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
"""
|
||||
Accepts entity/vector pairs and writes them to a Milvus store.
|
||||
"""
|
||||
|
||||
from trustgraph.schema import Chunk
|
||||
from trustgraph.schema import chunk_ingest_queue
|
||||
from trustgraph.log_level import LogLevel
|
||||
from trustgraph.base import Consumer
|
||||
from threading import Thread, Lock
|
||||
import time
|
||||
|
||||
module = "test-chunk-size"
|
||||
|
||||
default_input_queue = chunk_ingest_queue
|
||||
default_subscriber = module
|
||||
default_store_uri = 'http://localhost:19530'
|
||||
|
||||
class Processor(Consumer):
|
||||
|
||||
def __init__(self, **params):
|
||||
|
||||
input_queue = params.get("input_queue", default_input_queue)
|
||||
subscriber = params.get("subscriber", default_subscriber)
|
||||
width = params.get("width", 200)
|
||||
|
||||
super(Processor, self).__init__(
|
||||
**params | {
|
||||
"input_queue": input_queue,
|
||||
"subscriber": subscriber,
|
||||
"input_schema": Chunk,
|
||||
}
|
||||
)
|
||||
|
||||
self.sizes = {}
|
||||
self.width = width
|
||||
self.lock = Lock()
|
||||
|
||||
Thread(target=self.report).start()
|
||||
|
||||
def report(self):
|
||||
|
||||
while True:
|
||||
time.sleep(1)
|
||||
|
||||
print()
|
||||
|
||||
with self.lock:
|
||||
tot = 0
|
||||
for i in range(0, 20000, self.width):
|
||||
k = (i, i + self.width)
|
||||
if k in self.sizes:
|
||||
print(f"{i:5d} ..{i+self.width:5d}: {self.sizes[k]}")
|
||||
tot += self.sizes[k]
|
||||
print(f"{'Total':13s}: {tot}")
|
||||
|
||||
|
||||
|
||||
|
||||
def handle(self, msg):
|
||||
|
||||
v = msg.value()
|
||||
|
||||
chunk = v.chunk.decode("utf-8")
|
||||
|
||||
l = len(chunk)
|
||||
|
||||
|
||||
low = int(l / self.width) * self.width
|
||||
high = low + self.width
|
||||
key = (low, high)
|
||||
|
||||
with self.lock:
|
||||
|
||||
if key not in self.sizes:
|
||||
self.sizes[key] = 0
|
||||
|
||||
self.sizes[key] += 1
|
||||
|
||||
@staticmethod
|
||||
def add_args(parser):
|
||||
|
||||
Consumer.add_args(
|
||||
parser, default_input_queue, default_subscriber,
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--width',
|
||||
type=int,
|
||||
default=200,
|
||||
help=f'Histogram width (default: 200)',
|
||||
)
|
||||
|
||||
def run():
|
||||
|
||||
Processor.start(module, __doc__)
|
||||
|
||||
run()
|
||||
|
||||
9
tests/requirements.txt
Normal file
9
tests/requirements.txt
Normal file
|
|
@ -0,0 +1,9 @@
|
|||
pytest>=7.0.0
|
||||
pytest-asyncio>=0.21.0
|
||||
pytest-mock>=3.10.0
|
||||
pytest-cov>=4.0.0
|
||||
google-cloud-aiplatform>=1.25.0
|
||||
google-auth>=2.17.0
|
||||
google-api-core>=2.11.0
|
||||
pulsar-client>=3.0.0
|
||||
prometheus-client>=0.16.0
|
||||
|
|
@ -1,48 +0,0 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
import json
|
||||
import textwrap
|
||||
from trustgraph.clients.agent_client import AgentClient
|
||||
|
||||
def wrap(text, width=75):
|
||||
|
||||
if text is None: text = "n/a"
|
||||
|
||||
out = textwrap.wrap(
|
||||
text, width=width
|
||||
)
|
||||
return "\n".join(out)
|
||||
|
||||
def output(text, prefix="> ", width=78):
|
||||
|
||||
out = textwrap.indent(
|
||||
text, prefix=prefix
|
||||
)
|
||||
print(out)
|
||||
|
||||
p = AgentClient(
|
||||
pulsar_host="pulsar://pulsar:6650",
|
||||
input_queue = "non-persistent://tg/request/agent:0000",
|
||||
output_queue = "non-persistent://tg/response/agent:0000",
|
||||
)
|
||||
|
||||
q = "How many cats does Mark have? Calculate that number raised to 0.4 power. Is that number lower than the numeric part of the mission identifier of the Space Shuttle Challenger on its last mission? If so, give me an apple pie recipe, otherwise return a poem about cheese."
|
||||
|
||||
output(wrap(q), "\U00002753 ")
|
||||
print()
|
||||
|
||||
def think(x):
|
||||
output(wrap(x), "\U0001f914 ")
|
||||
print()
|
||||
|
||||
def observe(x):
|
||||
output(wrap(x), "\U0001f4a1 ")
|
||||
print()
|
||||
|
||||
resp = p.request(
|
||||
question=q, think=think, observe=observe,
|
||||
)
|
||||
|
||||
output(resp, "\U0001f4ac ")
|
||||
print()
|
||||
|
||||
|
|
@ -1,2 +0,0 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
|
|
@ -1,22 +0,0 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
import pulsar
|
||||
from trustgraph.clients.document_embeddings_client import DocumentEmbeddingsClient
|
||||
from trustgraph.clients.embeddings_client import EmbeddingsClient
|
||||
|
||||
ec = EmbeddingsClient(pulsar_host="pulsar://localhost:6650")
|
||||
|
||||
vectors = ec.request("What caused the space shuttle to explode?")
|
||||
|
||||
print(vectors)
|
||||
|
||||
llm = DocumentEmbeddingsClient(pulsar_host="pulsar://localhost:6650")
|
||||
|
||||
limit=10
|
||||
|
||||
resp = llm.request(vectors, limit)
|
||||
|
||||
print("Response...")
|
||||
for val in resp:
|
||||
print(val)
|
||||
|
||||
|
|
@ -1,25 +0,0 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
import pulsar
|
||||
from trustgraph.clients.prompt_client import PromptClient
|
||||
|
||||
p = PromptClient(pulsar_host="pulsar://localhost:6650")
|
||||
|
||||
docs = [
|
||||
"In our house there is a big cat and a small cat.",
|
||||
"The small cat is black.",
|
||||
"The big cat is called Fred.",
|
||||
"The orange stripey cat is big.",
|
||||
"The black cat pounces on the big cat.",
|
||||
"The black cat is called Hope."
|
||||
]
|
||||
|
||||
query="What is the name of the cat who pounces on Fred? Provide a full explanation."
|
||||
|
||||
resp = p.request_document_prompt(
|
||||
query=query,
|
||||
documents=docs,
|
||||
)
|
||||
|
||||
print(resp)
|
||||
|
||||
|
|
@ -1,19 +0,0 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
import pulsar
|
||||
from trustgraph.clients.document_rag_client import DocumentRagClient
|
||||
|
||||
rag = DocumentRagClient(
|
||||
pulsar_host="pulsar://localhost:6650",
|
||||
subscriber="test1",
|
||||
input_queue = "non-persistent://tg/request/document-rag:default",
|
||||
output_queue = "non-persistent://tg/response/document-rag:default",
|
||||
)
|
||||
|
||||
query="""
|
||||
What was the cause of the space shuttle disaster?"""
|
||||
|
||||
resp = rag.request(query)
|
||||
|
||||
print(resp)
|
||||
|
||||
|
|
@ -1,18 +0,0 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
import pulsar
|
||||
from trustgraph.clients.embeddings_client import EmbeddingsClient
|
||||
|
||||
embed = EmbeddingsClient(
|
||||
pulsar_host="pulsar://pulsar:6650",
|
||||
input_queue="non-persistent://tg/request/embeddings:default",
|
||||
output_queue="non-persistent://tg/response/embeddings:default",
|
||||
subscriber="test1",
|
||||
)
|
||||
|
||||
prompt="Write a funny limerick about a llama"
|
||||
|
||||
resp = embed.request(prompt)
|
||||
|
||||
print(resp)
|
||||
|
||||
|
|
@ -1,92 +0,0 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
import requests
|
||||
|
||||
url = "http://localhost:8088/"
|
||||
|
||||
resp = requests.post(
|
||||
f"{url}/api/v1/flow",
|
||||
json={
|
||||
"operation": "list-classes",
|
||||
}
|
||||
)
|
||||
|
||||
print(resp)
|
||||
print(resp.text)
|
||||
|
||||
resp = requests.post(
|
||||
f"{url}/api/v1/flow",
|
||||
json={
|
||||
"operation": "get-class",
|
||||
"class-name": "default",
|
||||
}
|
||||
)
|
||||
|
||||
print(resp)
|
||||
print(resp.text)
|
||||
|
||||
resp = requests.post(
|
||||
f"{url}/api/v1/flow",
|
||||
json={
|
||||
"operation": "put-class",
|
||||
"class-name": "bunch",
|
||||
"class-definition": "{}",
|
||||
}
|
||||
)
|
||||
|
||||
print(resp)
|
||||
print(resp.text)
|
||||
|
||||
resp = requests.post(
|
||||
f"{url}/api/v1/flow",
|
||||
json={
|
||||
"operation": "get-class",
|
||||
"class-name": "bunch",
|
||||
}
|
||||
)
|
||||
|
||||
print(resp)
|
||||
print(resp.text)
|
||||
|
||||
resp = requests.post(
|
||||
f"{url}/api/v1/flow",
|
||||
json={
|
||||
"operation": "list-classes",
|
||||
}
|
||||
)
|
||||
|
||||
print(resp)
|
||||
print(resp.text)
|
||||
|
||||
|
||||
resp = requests.post(
|
||||
f"{url}/api/v1/flow",
|
||||
json={
|
||||
"operation": "delete-class",
|
||||
"class-name": "bunch",
|
||||
}
|
||||
)
|
||||
|
||||
print(resp)
|
||||
print(resp.text)
|
||||
|
||||
|
||||
resp = requests.post(
|
||||
f"{url}/api/v1/flow",
|
||||
json={
|
||||
"operation": "list-classes",
|
||||
}
|
||||
)
|
||||
|
||||
print(resp)
|
||||
print(resp.text)
|
||||
|
||||
resp = requests.post(
|
||||
f"{url}/api/v1/flow",
|
||||
json={
|
||||
"operation": "list-flows",
|
||||
}
|
||||
)
|
||||
|
||||
print(resp)
|
||||
print(resp.text)
|
||||
|
|
@ -1,19 +0,0 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
import requests
|
||||
|
||||
url = "http://localhost:8088/"
|
||||
|
||||
resp = requests.post(
|
||||
f"{url}/api/v1/flow",
|
||||
json={
|
||||
"operation": "get-class",
|
||||
"class-name": "default",
|
||||
}
|
||||
)
|
||||
|
||||
resp = resp.json()
|
||||
|
||||
print(resp["class-definition"])
|
||||
|
||||
|
||||
File diff suppressed because one or more lines are too long
|
|
@ -1,23 +0,0 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
import requests
|
||||
import json
|
||||
|
||||
url = "http://localhost:8088/"
|
||||
|
||||
resp = requests.post(
|
||||
f"{url}/api/v1/flow",
|
||||
json={
|
||||
"operation": "start-flow",
|
||||
"flow-id": "0003",
|
||||
"class-name": "default",
|
||||
}
|
||||
)
|
||||
|
||||
print(resp)
|
||||
print(resp.text)
|
||||
resp = resp.json()
|
||||
|
||||
|
||||
print(resp)
|
||||
|
||||
|
|
@ -1,22 +0,0 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
import requests
|
||||
import json
|
||||
|
||||
url = "http://localhost:8088/"
|
||||
|
||||
resp = requests.post(
|
||||
f"{url}/api/v1/flow",
|
||||
json={
|
||||
"operation": "stop-flow",
|
||||
"flow-id": "0003",
|
||||
}
|
||||
)
|
||||
|
||||
print(resp)
|
||||
print(resp.text)
|
||||
resp = resp.json()
|
||||
|
||||
|
||||
print(resp)
|
||||
|
||||
|
|
@ -1,11 +0,0 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
import pulsar
|
||||
from trustgraph.clients.config_client import ConfigClient
|
||||
|
||||
cli = ConfigClient(pulsar_host="pulsar://localhost:6650")
|
||||
|
||||
resp = cli.request_config()
|
||||
|
||||
print(resp)
|
||||
|
||||
|
|
@ -1,22 +0,0 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
import pulsar
|
||||
from trustgraph.clients.graph_embeddings_client import GraphEmbeddingsClient
|
||||
from trustgraph.clients.embeddings_client import EmbeddingsClient
|
||||
|
||||
ec = EmbeddingsClient(pulsar_host="pulsar://localhost:6650")
|
||||
|
||||
vectors = ec.request("What caused the space shuttle to explode?")
|
||||
|
||||
print(vectors)
|
||||
|
||||
llm = GraphEmbeddingsClient(pulsar_host="pulsar://localhost:6650")
|
||||
|
||||
limit=10
|
||||
|
||||
resp = llm.request(vectors, limit)
|
||||
|
||||
print("Response...")
|
||||
for val in resp:
|
||||
print(val.value)
|
||||
|
||||
|
|
@ -1,22 +0,0 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
import pulsar
|
||||
from trustgraph.clients.graph_rag_client import GraphRagClient
|
||||
|
||||
rag = GraphRagClient(
|
||||
pulsar_host="pulsar://localhost:6650",
|
||||
subscriber="test1",
|
||||
input_queue = "non-persistent://tg/request/graph-rag:default",
|
||||
output_queue = "non-persistent://tg/response/graph-rag:default",
|
||||
)
|
||||
|
||||
#query="""
|
||||
#This knowledge graph describes the Space Shuttle disaster.
|
||||
#Present 20 facts which are present in the knowledge graph."""
|
||||
|
||||
query = "How many cats does Mark have?"
|
||||
|
||||
resp = rag.request(query)
|
||||
|
||||
print(resp)
|
||||
|
||||
|
|
@ -1,14 +0,0 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
import pulsar
|
||||
from trustgraph.clients.graph_rag_client import GraphRagClient
|
||||
|
||||
rag = GraphRagClient(pulsar_host="pulsar://localhost:6650")
|
||||
|
||||
query="""List 20 key points to describe the research that led to the discovery of Leo VI.
|
||||
"""
|
||||
|
||||
resp = rag.request(query)
|
||||
|
||||
print(resp)
|
||||
|
||||
|
|
@ -1,24 +0,0 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
import pulsar
|
||||
from trustgraph.clients.prompt_client import PromptClient
|
||||
|
||||
p = PromptClient(pulsar_host="pulsar://localhost:6650")
|
||||
|
||||
chunk = """I noticed a cat in my garden. It is a four-legged animal
|
||||
which is a mammal and can be tame or wild. I wonder if it will be friends
|
||||
with me. I think the cat's name is Fred and it has 4 legs.
|
||||
|
||||
A cat is a small mammal.
|
||||
|
||||
A grapefruit is a citrus fruit.
|
||||
|
||||
"""
|
||||
|
||||
resp = p.request_definitions(
|
||||
chunk=chunk,
|
||||
)
|
||||
|
||||
for d in resp:
|
||||
print(d.name, ":", d.definition)
|
||||
|
||||
|
|
@ -1,72 +0,0 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
import pulsar
|
||||
from trustgraph.clients.prompt_client import PromptClient
|
||||
|
||||
p = PromptClient(pulsar_host="pulsar://localhost:6650")
|
||||
|
||||
facts = [
|
||||
("accident", "evoked", "a wide range of deeply felt public responses"),
|
||||
("Space Shuttle concept", "had", "genesis"),
|
||||
("Commission", "had", "a mandate to develop recommendations for corrective or other action based upon the Commission's findings and determinations"),
|
||||
("Commission", "established", "teams of persons"),
|
||||
("Space Shuttle Challenger", "http://www.w3.org/2004/02/skos/core#definition", "A space shuttle that was destroyed in an accident during mission 51-L."),
|
||||
("The mid fuselage", "contains", "the payload bay"),
|
||||
("Volume I", "contains", "Chapter IX"),
|
||||
("accident", "resulted in", "firm national resolve that those men and women be forever enshrined in the annals of American heroes"),
|
||||
("Volume I", "contains", "Chapter IV"),
|
||||
("Volume I", "contains", "Appendix A"),
|
||||
("Volume I", "contains", "Appendix B"),
|
||||
("Volume I", "contains", "The Staff"),
|
||||
("Commission", "required", "detailed investigation"),
|
||||
("Commission", "focused", "safety aspects of future flights"),
|
||||
("Commission", "http://www.w3.org/2004/02/skos/core#definition", "An independent group appointed to investigate the Space Shuttle Challenger accident."),
|
||||
("Commission", "moved forward with", "its investigation"),
|
||||
("President", "appointed", "an independent Commission"),
|
||||
("accident", "interrupted", "one of the most productive engineering, scientific and exploratory programs in history"),
|
||||
("Volume I", "contains", "Preface"),
|
||||
("Commission", "believes", "investigation"),
|
||||
("Volume I", "contains", "Chapter I"),
|
||||
("President", "was moved and troubled", "by this accident in a very personal way"),
|
||||
("PRESIDENTIAL COMMISSION", "Report to", "President"),
|
||||
("Volume I", "contains", "Chapter VI"),
|
||||
("Commission", "held", "public hearings dealing with the facts leading up to the accident"),
|
||||
("Volume I", "http://www.w3.org/2004/02/skos/core#definition", "The first volume of a multi-volume publication."),
|
||||
("Space Shuttle Challenger", "was involved in", "an accident"),
|
||||
("Volume I", "contains", "Chapter VII"),
|
||||
("Volume I", "contains", "Chapter II"),
|
||||
("Volume I", "contains", "Chapter V"),
|
||||
("Commission", "believes", "its investigation and report have been responsive to the request of the President and hopes that they will serve the best interests of the nation in restoring the United States space program to its preeminent position in the world"),
|
||||
("Commission", "supported", "panels"),
|
||||
("Volume I", "contains", "Chapter VIII"),
|
||||
("NASA", "cooperated", "Commission"),
|
||||
("liquid oxygen tank", "contains", "oxidizer"),
|
||||
("President", "http://www.w3.org/2004/02/skos/core#definition", "The head of state of the United States."),
|
||||
("Volume I", "contains", "Chapter III"),
|
||||
("Apollo lunar landing spacecraft", "had", "not yet flown"),
|
||||
("Commission", "construe", "mandate"),
|
||||
("accident", "became", "a milestone on the way to achieving the full potential that space offers to mankind"),
|
||||
("Volume I", "contains", "The Commission"),
|
||||
("Commission", "focused", "attention"),
|
||||
("Commission", "learned", "lessons"),
|
||||
("Commission", "required", "interfere with or supersede Congress"),
|
||||
("Commission", "was made up of", "persons not connected with the mission"),
|
||||
("Commission", "required", "review budgetary matters"),
|
||||
("Space Shuttle", "became", "focus of NASA's near-term future"),
|
||||
("Volume I", "contains", "Appendix C"),
|
||||
("accident", "caused", "grief and sadness for the loss of seven brave members of the crew"),
|
||||
("Commission", "http://www.w3.org/2004/02/skos/core#definition", "A group established to investigate the space shuttle accident"),
|
||||
("Volume I", "contains", "Appendix D"),
|
||||
("Commission", "had", "a mandate to review the circumstances surrounding the accident to establish the probable cause or causes of the accident"),
|
||||
("Volume I", "contains", "Recommendations")
|
||||
]
|
||||
|
||||
query="Present 20 facts which are present in the knowledge graph."
|
||||
|
||||
resp = p.request_kg_prompt(
|
||||
query=query,
|
||||
kg=facts,
|
||||
)
|
||||
|
||||
print(resp)
|
||||
|
||||
|
|
@ -1,21 +0,0 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
import pulsar
|
||||
from trustgraph.clients.prompt_client import PromptClient
|
||||
|
||||
p = PromptClient(pulsar_host="pulsar://localhost:6650")
|
||||
|
||||
chunk = """I noticed a cat in my garden. It is a four-legged animal
|
||||
which is a mammal and can be tame or wild. I wonder if it will be friends
|
||||
with me. I think the cat's name is Fred and it has 4 legs"""
|
||||
|
||||
resp = p.request_relationships(
|
||||
chunk=chunk,
|
||||
)
|
||||
|
||||
for d in resp:
|
||||
print(d.s)
|
||||
print(" ", d.p)
|
||||
print(" ", d.o)
|
||||
print(" ", d.o_entity)
|
||||
|
||||
|
|
@ -1,19 +0,0 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
import pulsar
|
||||
from trustgraph.clients.prompt_client import PromptClient
|
||||
|
||||
p = PromptClient(pulsar_host="pulsar://localhost:6650")
|
||||
|
||||
chunk = """I noticed a cat in my garden. It is a four-legged animal
|
||||
which is a mammal and can be tame or wild. I wonder if it will be friends
|
||||
with me. I think the cat's name is Fred and it has 4 legs"""
|
||||
|
||||
resp = p.request_topics(
|
||||
chunk=chunk,
|
||||
)
|
||||
|
||||
for d in resp:
|
||||
print(d.topic)
|
||||
print(" ", d.definition)
|
||||
|
||||
|
|
@ -1,19 +0,0 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
import pulsar
|
||||
from trustgraph.clients.llm_client import LlmClient
|
||||
|
||||
llm = LlmClient(
|
||||
pulsar_host="pulsar://pulsar:6650",
|
||||
input_queue="non-persistent://tg/request/text-completion:default",
|
||||
output_queue="non-persistent://tg/response/text-completion:default",
|
||||
subscriber="test1",
|
||||
)
|
||||
|
||||
system = "You are a lovely assistant."
|
||||
prompt="what is 2 + 2 == 5"
|
||||
|
||||
resp = llm.request(system, prompt)
|
||||
|
||||
print(resp)
|
||||
|
||||
|
|
@ -1,15 +0,0 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
import pulsar
|
||||
from trustgraph.clients.llm_client import LlmClient
|
||||
|
||||
llm = LlmClient(pulsar_host="pulsar://localhost:6650")
|
||||
|
||||
prompt="What is 2 + 12?"
|
||||
|
||||
try:
|
||||
resp = llm.request(prompt)
|
||||
print(resp)
|
||||
except Exception as e:
|
||||
print(f"{e.__class__.__name__}: {e}")
|
||||
|
||||
|
|
@ -1,15 +0,0 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
import pulsar
|
||||
from trustgraph.clients.llm_client import LlmClient
|
||||
|
||||
llm = LlmClient(pulsar_host="pulsar://localhost:6650")
|
||||
|
||||
prompt="What is 2 + 12?"
|
||||
|
||||
try:
|
||||
resp = llm.request(prompt)
|
||||
print(resp)
|
||||
except Exception as e:
|
||||
print(f"{e.__class__.__name__}: {e}")
|
||||
|
||||
|
|
@ -1,36 +0,0 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
import pulsar
|
||||
from pulsar.schema import JsonSchema
|
||||
import base64
|
||||
|
||||
from trustgraph.schema import Document, Metadata
|
||||
|
||||
client = pulsar.Client("pulsar://localhost:6650", listener_name="localhost")
|
||||
|
||||
prod = client.create_producer(
|
||||
topic="persistent://tg/flow/document-load:0000",
|
||||
schema=JsonSchema(Document),
|
||||
chunking_enabled=True,
|
||||
)
|
||||
|
||||
path = "../sources/Challenger-Report-Vol1.pdf"
|
||||
|
||||
with open(path, "rb") as f:
|
||||
blob = base64.b64encode(f.read()).decode("utf-8")
|
||||
|
||||
message = Document(
|
||||
metadata = Metadata(
|
||||
id = "00001",
|
||||
metadata = [],
|
||||
user="trustgraph",
|
||||
collection="default",
|
||||
),
|
||||
data=blob
|
||||
)
|
||||
|
||||
prod.send(message)
|
||||
|
||||
prod.close()
|
||||
client.close()
|
||||
|
||||
|
|
@ -1,37 +0,0 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
import pulsar
|
||||
from pulsar.schema import JsonSchema
|
||||
import base64
|
||||
|
||||
from trustgraph.schema import TextDocument, Metadata
|
||||
|
||||
client = pulsar.Client("pulsar://localhost:6650", listener_name="localhost")
|
||||
|
||||
prod = client.create_producer(
|
||||
topic="persistent://tg/flow/text-document-load:0000",
|
||||
schema=JsonSchema(TextDocument),
|
||||
chunking_enabled=True,
|
||||
)
|
||||
|
||||
path = "../trustgraph/docs/README.cats"
|
||||
|
||||
with open(path, "r") as f:
|
||||
# blob = base64.b64encode(f.read()).decode("utf-8")
|
||||
blob = f.read()
|
||||
|
||||
message = TextDocument(
|
||||
metadata = Metadata(
|
||||
id = "00001",
|
||||
metadata = [],
|
||||
user="trustgraph",
|
||||
collection="default",
|
||||
),
|
||||
text=blob
|
||||
)
|
||||
|
||||
prod.send(message)
|
||||
|
||||
prod.close()
|
||||
client.close()
|
||||
|
||||
|
|
@ -1,35 +0,0 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
from langchain_huggingface import HuggingFaceEmbeddings
|
||||
|
||||
from trustgraph.direct.milvus import TripleVectors
|
||||
|
||||
client = TripleVectors()
|
||||
|
||||
embeddings = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
|
||||
|
||||
text="""A cat is a small animal. A dog is a large animal.
|
||||
Cats say miaow. Dogs go woof.
|
||||
"""
|
||||
|
||||
embeds = embeddings.embed_documents([text])[0]
|
||||
|
||||
text2="""If you couldn't download the model due to network issues, as a walkaround, you can use random vectors to represent the text and still finish the example. Just note that the search result won't reflect semantic similarity as the vectors are fake ones.
|
||||
"""
|
||||
|
||||
embeds2 = embeddings.embed_documents([text2])[0]
|
||||
|
||||
client.insert(embeds, "animals")
|
||||
client.insert(embeds, "vectors")
|
||||
|
||||
query="""What noise does a cat make?"""
|
||||
|
||||
qembeds = embeddings.embed_documents([query])[0]
|
||||
|
||||
res = client.search(
|
||||
qembeds,
|
||||
limit=2
|
||||
)
|
||||
|
||||
print(res)
|
||||
|
||||
|
|
@ -1,18 +0,0 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
import json
|
||||
from trustgraph.clients.prompt_client import PromptClient
|
||||
|
||||
p = PromptClient(pulsar_host="pulsar://localhost:6650")
|
||||
|
||||
description = """Fred is a 4-legged cat who is 12 years old"""
|
||||
|
||||
resp = p.request(
|
||||
id="analyze",
|
||||
terms = {
|
||||
"description": description,
|
||||
}
|
||||
)
|
||||
|
||||
print(json.dumps(resp, indent=4))
|
||||
|
||||
|
|
@ -1,51 +0,0 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
import json
|
||||
from trustgraph.clients.prompt_client import PromptClient
|
||||
|
||||
p = PromptClient(
|
||||
pulsar_host="pulsar://localhost:6650",
|
||||
input_queue="non-persistent://tg/request/prompt:default",
|
||||
output_queue="non-persistent://tg/response/prompt:default",
|
||||
subscriber="test1",
|
||||
)
|
||||
|
||||
chunk="""
|
||||
The Space Shuttle was a reusable spacecraft that transported astronauts and cargo to and from Earth's orbit. It was designed to launch like a rocket, maneuver in orbit like a spacecraft, and land like an airplane. The Space Shuttle was NASA's space transportation system and was used for many purposes, including:
|
||||
|
||||
Carrying astronauts
|
||||
The Space Shuttle could carry up to seven astronauts at a time.
|
||||
|
||||
Launching, recovering, and repairing satellites
|
||||
The Space Shuttle could launch satellites into orbit, recover them, and repair them.
|
||||
Building the International Space Station
|
||||
The Space Shuttle carried large parts into space to build the International Space Station.
|
||||
Conducting research
|
||||
Astronauts conducted experiments in the Space Shuttle, which was like a science lab in space.
|
||||
|
||||
The Space Shuttle was retired in 2011 after the Columbia accident in 2003. The Columbia Accident Investigation Board report found that the Space Shuttle was unsafe and expensive to make safe.
|
||||
Here are some other facts about the Space Shuttle:
|
||||
|
||||
The Space Shuttle was 184 ft tall and had a diameter of 29 ft.
|
||||
|
||||
The Space Shuttle had a mass of 4,480,000 lb.
|
||||
The Space Shuttle's first flight was on April 12, 1981.
|
||||
The Space Shuttle's last mission was in 2011.
|
||||
"""
|
||||
|
||||
q = "Tell me some facts in the knowledge graph"
|
||||
|
||||
resp = p.request(
|
||||
id="extract-definitions",
|
||||
variables = {
|
||||
"text": chunk,
|
||||
}
|
||||
)
|
||||
|
||||
print(resp)
|
||||
|
||||
for fact in resp:
|
||||
print(fact["entity"], "::")
|
||||
print(fact["definition"])
|
||||
print()
|
||||
|
||||
|
|
@ -1,18 +0,0 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
import pulsar
|
||||
from trustgraph.clients.prompt_client import PromptClient
|
||||
|
||||
p = PromptClient(pulsar_host="pulsar://localhost:6650")
|
||||
|
||||
question = """What is the square root of 16?"""
|
||||
|
||||
resp = p.request(
|
||||
id="french-question",
|
||||
terms = {
|
||||
"question": question
|
||||
}
|
||||
)
|
||||
|
||||
print(resp)
|
||||
|
||||
|
|
@ -1,44 +0,0 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
import json
|
||||
from trustgraph.clients.prompt_client import PromptClient
|
||||
|
||||
p = PromptClient(pulsar_host="pulsar://localhost:6650")
|
||||
|
||||
knowledge = [
|
||||
("accident", "evoked", "a wide range of deeply felt public responses"),
|
||||
("Space Shuttle concept", "had", "genesis"),
|
||||
("Commission", "had", "a mandate to develop recommendations for corrective or other action based upon the Commission's findings and determinations"),
|
||||
("Commission", "established", "teams of persons"),
|
||||
("Space Shuttle Challenger", "http://www.w3.org/2004/02/skos/core#definition", "A space shuttle that was destroyed in an accident during mission 51-L."),
|
||||
("The mid fuselage", "contains", "the payload bay"),
|
||||
("Volume I", "contains", "Chapter IX"),
|
||||
("accident", "resulted in", "firm national resolve that those men and women be forever enshrined in the annals of American heroes"),
|
||||
("Volume I", "contains", "Chapter VII"),
|
||||
("Volume I", "contains", "Chapter II"),
|
||||
("Volume I", "contains", "Chapter V"),
|
||||
("Commission", "believes", "its investigation and report have been responsive to the request of the President and hopes that they will serve the best interests of the nation in restoring the United States space program to its preeminent position in the world"),
|
||||
("Commission", "construe", "mandate"),
|
||||
("accident", "became", "a milestone on the way to achieving the full potential that space offers to mankind"),
|
||||
("Volume I", "contains", "The Commission"),
|
||||
("Commission", "http://www.w3.org/2004/02/skos/core#definition", "A group established to investigate the space shuttle accident"),
|
||||
("Volume I", "contains", "Appendix D"),
|
||||
("Commission", "had", "a mandate to review the circumstances surrounding the accident to establish the probable cause or causes of the accident"),
|
||||
("Volume I", "contains", "Recommendations")
|
||||
]
|
||||
|
||||
q = "Tell me some facts in the knowledge graph"
|
||||
|
||||
resp = p.request(
|
||||
id="graph-query",
|
||||
terms = {
|
||||
"name": "Jayney",
|
||||
"knowledge": knowledge,
|
||||
"question": q
|
||||
}
|
||||
)
|
||||
|
||||
print(resp)
|
||||
|
||||
|
||||
|
||||
|
|
@ -1,23 +0,0 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
import pulsar
|
||||
from trustgraph.clients.prompt_client import PromptClient
|
||||
|
||||
p = PromptClient(
|
||||
pulsar_host="pulsar://localhost:6650",
|
||||
input_queue="non-persistent://tg/request/prompt:default",
|
||||
output_queue="non-persistent://tg/response/prompt:default",
|
||||
subscriber="test1",
|
||||
)
|
||||
|
||||
question = """What is the square root of 16?"""
|
||||
|
||||
resp = p.request(
|
||||
id="question",
|
||||
variables = {
|
||||
"question": question
|
||||
}
|
||||
)
|
||||
|
||||
print(resp)
|
||||
|
||||
|
|
@ -1,19 +0,0 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
import pulsar
|
||||
from trustgraph.clients.prompt_client import PromptClient
|
||||
|
||||
p = PromptClient(pulsar_host="pulsar://localhost:6650")
|
||||
|
||||
question = """What is the square root of 16?"""
|
||||
|
||||
resp = p.request(
|
||||
id="question",
|
||||
terms = {
|
||||
"question": question,
|
||||
"attitude": "Spanish-speaking bot"
|
||||
}
|
||||
)
|
||||
|
||||
print(resp)
|
||||
|
||||
|
|
@ -1,51 +0,0 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
import pulsar
|
||||
from trustgraph.clients.prompt_client import PromptClient
|
||||
from trustgraph.objects.object import Schema
|
||||
from trustgraph.objects.field import Field, FieldType
|
||||
|
||||
schema = Schema(
|
||||
name="actors",
|
||||
description="actors in this story",
|
||||
fields=[
|
||||
Field(
|
||||
name="name", type=FieldType.STRING,
|
||||
description="Name of the animal or person in the story"
|
||||
),
|
||||
Field(
|
||||
name="legs", type=FieldType.INT,
|
||||
description="Number of legs of the animal or person"
|
||||
),
|
||||
Field(
|
||||
name="notes", type=FieldType.STRING,
|
||||
description="Additional notes or observations about this animal or person"
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
chunk = """I noticed a cat in my garden. It is a four-legged animal
|
||||
which is a mammal and can be tame or wild. I wonder if it will be friends
|
||||
with me? I think the cat's name is Fred and it has 4 legs.
|
||||
There is also a dog barking outside. The dog has 4 legs also.
|
||||
The dog comes to my call when I shout "Come here, Bernard".
|
||||
|
||||
I am also standing in the garden, my name is Steve and I have 2 legs.
|
||||
|
||||
My friend Clifford is coming to visit shortly, he has 3 legs due to
|
||||
a freak accident at birth.
|
||||
"""
|
||||
|
||||
p = PromptClient(pulsar_host="pulsar://localhost:6650")
|
||||
|
||||
resp = p.request_rows(
|
||||
schema=schema,
|
||||
chunk=chunk,
|
||||
)
|
||||
|
||||
for d in resp:
|
||||
print(f"Name: {d['name']}")
|
||||
print(f" No. of legs: {d['legs']}")
|
||||
print(f" Notes: {d['notes']}")
|
||||
print()
|
||||
|
||||
|
|
@ -1,13 +0,0 @@
|
|||
|
||||
scripts/object-extract-row \
|
||||
-p pulsar://localhost:6650 \
|
||||
--field 'name:string:100:pri:Name of the person in the story' \
|
||||
--field 'job:string:100::Job title or role' \
|
||||
--field 'date:string:20::Date entered into role if known' \
|
||||
--field 'supervisor:string:100::Supervisor or manager of this person, if known' \
|
||||
--field 'location:string:100::Main base or location of work, if known' \
|
||||
--field 'notes:string:1000::Additional notes or observations about this animal or person' \
|
||||
--no-metrics \
|
||||
--name actors \
|
||||
--description 'Relevant people'
|
||||
|
||||
|
|
@ -1,72 +0,0 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
import pulsar
|
||||
from trustgraph.clients.triples_query_client import TriplesQueryClient
|
||||
|
||||
tq = TriplesQueryClient(
|
||||
pulsar_host="pulsar://localhost:6650",
|
||||
)
|
||||
|
||||
e = "http://trustgraph.ai/e/shuttle"
|
||||
|
||||
limit=3
|
||||
|
||||
def dump(resp):
|
||||
print("Response...")
|
||||
for t in resp:
|
||||
print(t.s.value, t.p.value, t.o.value)
|
||||
|
||||
print("-- * ---------------------------")
|
||||
|
||||
resp = tq.request(None, None, None, limit)
|
||||
dump(resp)
|
||||
|
||||
print("-- s ---------------------------")
|
||||
|
||||
resp = tq.request("http://trustgraph.ai/e/shuttle", None, None, limit)
|
||||
dump(resp)
|
||||
|
||||
print("-- p ---------------------------")
|
||||
|
||||
resp = tq.request(None, "http://trustgraph.ai/e/landed", None, limit)
|
||||
dump(resp)
|
||||
|
||||
print("-- o ---------------------------")
|
||||
|
||||
resp = tq.request(None, None, "President", limit)
|
||||
dump(resp)
|
||||
|
||||
print("-- sp ---------------------------")
|
||||
|
||||
resp = tq.request(
|
||||
"http://trustgraph.ai/e/shuttle", "http://trustgraph.ai/e/landed", None,
|
||||
limit
|
||||
)
|
||||
dump(resp)
|
||||
|
||||
print("-- so ---------------------------")
|
||||
|
||||
resp = tq.request(
|
||||
"http://trustgraph.ai/e/shuttle", None, "the tower",
|
||||
limit
|
||||
)
|
||||
dump(resp)
|
||||
|
||||
print("-- po ---------------------------")
|
||||
|
||||
resp = tq.request(
|
||||
None, "http://trustgraph.ai/e/landed",
|
||||
"on the concrete runway at Kennedy Space Center",
|
||||
limit
|
||||
)
|
||||
dump(resp)
|
||||
|
||||
print("-- spo ---------------------------")
|
||||
|
||||
resp = tq.request(
|
||||
"http://trustgraph.ai/e/shuttle", "http://trustgraph.ai/e/landed",
|
||||
"on the concrete runway at Kennedy Space Center",
|
||||
limit
|
||||
)
|
||||
dump(resp)
|
||||
|
||||
3
tests/unit/__init__.py
Normal file
3
tests/unit/__init__.py
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
"""
|
||||
Unit tests for TrustGraph services
|
||||
"""
|
||||
58
tests/unit/test_base/test_async_processor.py
Normal file
58
tests/unit/test_base/test_async_processor.py
Normal file
|
|
@ -0,0 +1,58 @@
|
|||
"""
|
||||
Unit tests for trustgraph.base.async_processor
|
||||
Starting small with a single test to verify basic functionality
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from unittest import IsolatedAsyncioTestCase
|
||||
|
||||
# Import the service under test
|
||||
from trustgraph.base.async_processor import AsyncProcessor
|
||||
|
||||
|
||||
class TestAsyncProcessorSimple(IsolatedAsyncioTestCase):
|
||||
"""Test AsyncProcessor base class functionality"""
|
||||
|
||||
@patch('trustgraph.base.async_processor.PulsarClient')
|
||||
@patch('trustgraph.base.async_processor.Consumer')
|
||||
@patch('trustgraph.base.async_processor.ProcessorMetrics')
|
||||
@patch('trustgraph.base.async_processor.ConsumerMetrics')
|
||||
async def test_async_processor_initialization_basic(self, mock_consumer_metrics, mock_processor_metrics,
|
||||
mock_consumer, mock_pulsar_client):
|
||||
"""Test basic AsyncProcessor initialization"""
|
||||
# Arrange
|
||||
mock_pulsar_client.return_value = MagicMock()
|
||||
mock_consumer.return_value = MagicMock()
|
||||
mock_processor_metrics.return_value = MagicMock()
|
||||
mock_consumer_metrics.return_value = MagicMock()
|
||||
|
||||
config = {
|
||||
'id': 'test-async-processor',
|
||||
'taskgroup': AsyncMock()
|
||||
}
|
||||
|
||||
# Act
|
||||
processor = AsyncProcessor(**config)
|
||||
|
||||
# Assert
|
||||
# Verify basic attributes are set
|
||||
assert processor.id == 'test-async-processor'
|
||||
assert processor.taskgroup == config['taskgroup']
|
||||
assert processor.running == True
|
||||
assert hasattr(processor, 'config_handlers')
|
||||
assert processor.config_handlers == []
|
||||
|
||||
# Verify PulsarClient was created
|
||||
mock_pulsar_client.assert_called_once_with(**config)
|
||||
|
||||
# Verify metrics were initialized
|
||||
mock_processor_metrics.assert_called_once()
|
||||
mock_consumer_metrics.assert_called_once()
|
||||
|
||||
# Verify Consumer was created for config subscription
|
||||
mock_consumer.assert_called_once()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pytest.main([__file__])
|
||||
347
tests/unit/test_base/test_flow_processor.py
Normal file
347
tests/unit/test_base/test_flow_processor.py
Normal file
|
|
@ -0,0 +1,347 @@
|
|||
"""
|
||||
Unit tests for trustgraph.base.flow_processor
|
||||
Starting small with a single test to verify basic functionality
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from unittest import IsolatedAsyncioTestCase
|
||||
|
||||
# Import the service under test
|
||||
from trustgraph.base.flow_processor import FlowProcessor
|
||||
|
||||
|
||||
class TestFlowProcessorSimple(IsolatedAsyncioTestCase):
|
||||
"""Test FlowProcessor base class functionality"""
|
||||
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.register_config_handler')
|
||||
async def test_flow_processor_initialization_basic(self, mock_register_config, mock_async_init):
|
||||
"""Test basic FlowProcessor initialization"""
|
||||
# Arrange
|
||||
mock_async_init.return_value = None
|
||||
mock_register_config.return_value = None
|
||||
|
||||
config = {
|
||||
'id': 'test-flow-processor',
|
||||
'taskgroup': AsyncMock()
|
||||
}
|
||||
|
||||
# Act
|
||||
processor = FlowProcessor(**config)
|
||||
|
||||
# Assert
|
||||
# Verify AsyncProcessor.__init__ was called
|
||||
mock_async_init.assert_called_once()
|
||||
|
||||
# Verify register_config_handler was called with the correct handler
|
||||
mock_register_config.assert_called_once_with(processor.on_configure_flows)
|
||||
|
||||
# Verify FlowProcessor-specific initialization
|
||||
assert hasattr(processor, 'flows')
|
||||
assert processor.flows == {}
|
||||
assert hasattr(processor, 'specifications')
|
||||
assert processor.specifications == []
|
||||
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.register_config_handler')
|
||||
async def test_register_specification(self, mock_register_config, mock_async_init):
|
||||
"""Test registering a specification"""
|
||||
# Arrange
|
||||
mock_async_init.return_value = None
|
||||
mock_register_config.return_value = None
|
||||
|
||||
config = {
|
||||
'id': 'test-flow-processor',
|
||||
'taskgroup': AsyncMock()
|
||||
}
|
||||
|
||||
processor = FlowProcessor(**config)
|
||||
mock_spec = MagicMock()
|
||||
mock_spec.name = 'test-spec'
|
||||
|
||||
# Act
|
||||
processor.register_specification(mock_spec)
|
||||
|
||||
# Assert
|
||||
assert len(processor.specifications) == 1
|
||||
assert processor.specifications[0] == mock_spec
|
||||
|
||||
@patch('trustgraph.base.flow_processor.Flow')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.register_config_handler')
|
||||
async def test_start_flow(self, mock_register_config, mock_async_init, mock_flow_class):
|
||||
"""Test starting a flow"""
|
||||
# Arrange
|
||||
mock_async_init.return_value = None
|
||||
mock_register_config.return_value = None
|
||||
|
||||
config = {
|
||||
'id': 'test-flow-processor',
|
||||
'taskgroup': AsyncMock()
|
||||
}
|
||||
|
||||
processor = FlowProcessor(**config)
|
||||
processor.id = 'test-processor' # Set id for Flow creation
|
||||
|
||||
mock_flow = AsyncMock()
|
||||
mock_flow_class.return_value = mock_flow
|
||||
|
||||
flow_name = 'test-flow'
|
||||
flow_defn = {'config': 'test-config'}
|
||||
|
||||
# Act
|
||||
await processor.start_flow(flow_name, flow_defn)
|
||||
|
||||
# Assert
|
||||
assert flow_name in processor.flows
|
||||
# Verify Flow was created with correct parameters
|
||||
mock_flow_class.assert_called_once_with('test-processor', flow_name, processor, flow_defn)
|
||||
# Verify the flow's start method was called
|
||||
mock_flow.start.assert_called_once()
|
||||
|
||||
@patch('trustgraph.base.flow_processor.Flow')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.register_config_handler')
|
||||
async def test_stop_flow(self, mock_register_config, mock_async_init, mock_flow_class):
|
||||
"""Test stopping a flow"""
|
||||
# Arrange
|
||||
mock_async_init.return_value = None
|
||||
mock_register_config.return_value = None
|
||||
|
||||
config = {
|
||||
'id': 'test-flow-processor',
|
||||
'taskgroup': AsyncMock()
|
||||
}
|
||||
|
||||
processor = FlowProcessor(**config)
|
||||
processor.id = 'test-processor'
|
||||
|
||||
mock_flow = AsyncMock()
|
||||
mock_flow_class.return_value = mock_flow
|
||||
|
||||
flow_name = 'test-flow'
|
||||
flow_defn = {'config': 'test-config'}
|
||||
|
||||
# Start a flow first
|
||||
await processor.start_flow(flow_name, flow_defn)
|
||||
|
||||
# Act
|
||||
await processor.stop_flow(flow_name)
|
||||
|
||||
# Assert
|
||||
assert flow_name not in processor.flows
|
||||
mock_flow.stop.assert_called_once()
|
||||
|
||||
@patch('trustgraph.base.flow_processor.Flow')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.register_config_handler')
|
||||
async def test_stop_flow_not_exists(self, mock_register_config, mock_async_init, mock_flow_class):
|
||||
"""Test stopping a flow that doesn't exist"""
|
||||
# Arrange
|
||||
mock_async_init.return_value = None
|
||||
mock_register_config.return_value = None
|
||||
|
||||
config = {
|
||||
'id': 'test-flow-processor',
|
||||
'taskgroup': AsyncMock()
|
||||
}
|
||||
|
||||
processor = FlowProcessor(**config)
|
||||
|
||||
# Act - should not raise an exception
|
||||
await processor.stop_flow('non-existent-flow')
|
||||
|
||||
# Assert - flows dict should still be empty
|
||||
assert processor.flows == {}
|
||||
|
||||
@patch('trustgraph.base.flow_processor.Flow')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.register_config_handler')
|
||||
async def test_on_configure_flows_basic(self, mock_register_config, mock_async_init, mock_flow_class):
|
||||
"""Test basic flow configuration handling"""
|
||||
# Arrange
|
||||
mock_async_init.return_value = None
|
||||
mock_register_config.return_value = None
|
||||
|
||||
config = {
|
||||
'id': 'test-flow-processor',
|
||||
'taskgroup': AsyncMock()
|
||||
}
|
||||
|
||||
processor = FlowProcessor(**config)
|
||||
processor.id = 'test-processor'
|
||||
|
||||
mock_flow = AsyncMock()
|
||||
mock_flow_class.return_value = mock_flow
|
||||
|
||||
# Configuration with flows for this processor
|
||||
flow_config = {
|
||||
'test-flow': {'config': 'test-config'}
|
||||
}
|
||||
config_data = {
|
||||
'flows-active': {
|
||||
'test-processor': '{"test-flow": {"config": "test-config"}}'
|
||||
}
|
||||
}
|
||||
|
||||
# Act
|
||||
await processor.on_configure_flows(config_data, version=1)
|
||||
|
||||
# Assert
|
||||
assert 'test-flow' in processor.flows
|
||||
mock_flow_class.assert_called_once_with('test-processor', 'test-flow', processor, {'config': 'test-config'})
|
||||
mock_flow.start.assert_called_once()
|
||||
|
||||
@patch('trustgraph.base.flow_processor.Flow')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.register_config_handler')
|
||||
async def test_on_configure_flows_no_config(self, mock_register_config, mock_async_init, mock_flow_class):
|
||||
"""Test flow configuration handling when no config exists for this processor"""
|
||||
# Arrange
|
||||
mock_async_init.return_value = None
|
||||
mock_register_config.return_value = None
|
||||
|
||||
config = {
|
||||
'id': 'test-flow-processor',
|
||||
'taskgroup': AsyncMock()
|
||||
}
|
||||
|
||||
processor = FlowProcessor(**config)
|
||||
processor.id = 'test-processor'
|
||||
|
||||
# Configuration without flows for this processor
|
||||
config_data = {
|
||||
'flows-active': {
|
||||
'other-processor': '{"other-flow": {"config": "other-config"}}'
|
||||
}
|
||||
}
|
||||
|
||||
# Act
|
||||
await processor.on_configure_flows(config_data, version=1)
|
||||
|
||||
# Assert
|
||||
assert processor.flows == {}
|
||||
mock_flow_class.assert_not_called()
|
||||
|
||||
@patch('trustgraph.base.flow_processor.Flow')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.register_config_handler')
|
||||
async def test_on_configure_flows_invalid_config(self, mock_register_config, mock_async_init, mock_flow_class):
|
||||
"""Test flow configuration handling with invalid config format"""
|
||||
# Arrange
|
||||
mock_async_init.return_value = None
|
||||
mock_register_config.return_value = None
|
||||
|
||||
config = {
|
||||
'id': 'test-flow-processor',
|
||||
'taskgroup': AsyncMock()
|
||||
}
|
||||
|
||||
processor = FlowProcessor(**config)
|
||||
processor.id = 'test-processor'
|
||||
|
||||
# Configuration without flows-active key
|
||||
config_data = {
|
||||
'other-data': 'some-value'
|
||||
}
|
||||
|
||||
# Act
|
||||
await processor.on_configure_flows(config_data, version=1)
|
||||
|
||||
# Assert
|
||||
assert processor.flows == {}
|
||||
mock_flow_class.assert_not_called()
|
||||
|
||||
@patch('trustgraph.base.flow_processor.Flow')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.register_config_handler')
|
||||
async def test_on_configure_flows_start_and_stop(self, mock_register_config, mock_async_init, mock_flow_class):
|
||||
"""Test flow configuration handling with starting and stopping flows"""
|
||||
# Arrange
|
||||
mock_async_init.return_value = None
|
||||
mock_register_config.return_value = None
|
||||
|
||||
config = {
|
||||
'id': 'test-flow-processor',
|
||||
'taskgroup': AsyncMock()
|
||||
}
|
||||
|
||||
processor = FlowProcessor(**config)
|
||||
processor.id = 'test-processor'
|
||||
|
||||
mock_flow1 = AsyncMock()
|
||||
mock_flow2 = AsyncMock()
|
||||
mock_flow_class.side_effect = [mock_flow1, mock_flow2]
|
||||
|
||||
# First configuration - start flow1
|
||||
config_data1 = {
|
||||
'flows-active': {
|
||||
'test-processor': '{"flow1": {"config": "config1"}}'
|
||||
}
|
||||
}
|
||||
|
||||
await processor.on_configure_flows(config_data1, version=1)
|
||||
|
||||
# Second configuration - stop flow1, start flow2
|
||||
config_data2 = {
|
||||
'flows-active': {
|
||||
'test-processor': '{"flow2": {"config": "config2"}}'
|
||||
}
|
||||
}
|
||||
|
||||
# Act
|
||||
await processor.on_configure_flows(config_data2, version=2)
|
||||
|
||||
# Assert
|
||||
# flow1 should be stopped and removed
|
||||
assert 'flow1' not in processor.flows
|
||||
mock_flow1.stop.assert_called_once()
|
||||
|
||||
# flow2 should be started and added
|
||||
assert 'flow2' in processor.flows
|
||||
mock_flow2.start.assert_called_once()
|
||||
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.register_config_handler')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.start')
|
||||
async def test_start_calls_parent(self, mock_parent_start, mock_register_config, mock_async_init):
|
||||
"""Test that start() calls parent start method"""
|
||||
# Arrange
|
||||
mock_async_init.return_value = None
|
||||
mock_register_config.return_value = None
|
||||
mock_parent_start.return_value = None
|
||||
|
||||
config = {
|
||||
'id': 'test-flow-processor',
|
||||
'taskgroup': AsyncMock()
|
||||
}
|
||||
|
||||
processor = FlowProcessor(**config)
|
||||
|
||||
# Act
|
||||
await processor.start()
|
||||
|
||||
# Assert
|
||||
mock_parent_start.assert_called_once()
|
||||
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.register_config_handler')
|
||||
async def test_add_args_calls_parent(self, mock_register_config, mock_async_init):
|
||||
"""Test that add_args() calls parent add_args method"""
|
||||
# Arrange
|
||||
mock_async_init.return_value = None
|
||||
mock_register_config.return_value = None
|
||||
|
||||
mock_parser = MagicMock()
|
||||
|
||||
# Act
|
||||
with patch('trustgraph.base.async_processor.AsyncProcessor.add_args') as mock_parent_add_args:
|
||||
FlowProcessor.add_args(mock_parser)
|
||||
|
||||
# Assert
|
||||
mock_parent_add_args.assert_called_once_with(mock_parser)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pytest.main([__file__])
|
||||
69
tests/unit/test_gateway/test_auth.py
Normal file
69
tests/unit/test_gateway/test_auth.py
Normal file
|
|
@ -0,0 +1,69 @@
|
|||
"""
|
||||
Tests for Gateway Authentication
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from trustgraph.gateway.auth import Authenticator
|
||||
|
||||
|
||||
class TestAuthenticator:
|
||||
"""Test cases for Authenticator class"""
|
||||
|
||||
def test_authenticator_initialization_with_token(self):
|
||||
"""Test Authenticator initialization with valid token"""
|
||||
auth = Authenticator(token="test-token-123")
|
||||
|
||||
assert auth.token == "test-token-123"
|
||||
assert auth.allow_all is False
|
||||
|
||||
def test_authenticator_initialization_with_allow_all(self):
|
||||
"""Test Authenticator initialization with allow_all=True"""
|
||||
auth = Authenticator(allow_all=True)
|
||||
|
||||
assert auth.token is None
|
||||
assert auth.allow_all is True
|
||||
|
||||
def test_authenticator_initialization_without_token_raises_error(self):
|
||||
"""Test Authenticator initialization without token raises RuntimeError"""
|
||||
with pytest.raises(RuntimeError, match="Need a token"):
|
||||
Authenticator()
|
||||
|
||||
def test_authenticator_initialization_with_empty_token_raises_error(self):
|
||||
"""Test Authenticator initialization with empty token raises RuntimeError"""
|
||||
with pytest.raises(RuntimeError, match="Need a token"):
|
||||
Authenticator(token="")
|
||||
|
||||
def test_permitted_with_allow_all_returns_true(self):
|
||||
"""Test permitted method returns True when allow_all is enabled"""
|
||||
auth = Authenticator(allow_all=True)
|
||||
|
||||
# Should return True regardless of token or roles
|
||||
assert auth.permitted("any-token", []) is True
|
||||
assert auth.permitted("different-token", ["admin"]) is True
|
||||
assert auth.permitted(None, ["user"]) is True
|
||||
|
||||
def test_permitted_with_matching_token_returns_true(self):
|
||||
"""Test permitted method returns True with matching token"""
|
||||
auth = Authenticator(token="secret-token")
|
||||
|
||||
# Should return True when tokens match
|
||||
assert auth.permitted("secret-token", []) is True
|
||||
assert auth.permitted("secret-token", ["admin", "user"]) is True
|
||||
|
||||
def test_permitted_with_non_matching_token_returns_false(self):
|
||||
"""Test permitted method returns False with non-matching token"""
|
||||
auth = Authenticator(token="secret-token")
|
||||
|
||||
# Should return False when tokens don't match
|
||||
assert auth.permitted("wrong-token", []) is False
|
||||
assert auth.permitted("different-token", ["admin"]) is False
|
||||
assert auth.permitted(None, ["user"]) is False
|
||||
|
||||
def test_permitted_with_token_and_allow_all_returns_true(self):
|
||||
"""Test permitted method with both token and allow_all set"""
|
||||
auth = Authenticator(token="test-token", allow_all=True)
|
||||
|
||||
# allow_all should take precedence
|
||||
assert auth.permitted("any-token", []) is True
|
||||
assert auth.permitted("wrong-token", ["admin"]) is True
|
||||
408
tests/unit/test_gateway/test_config_receiver.py
Normal file
408
tests/unit/test_gateway/test_config_receiver.py
Normal file
|
|
@ -0,0 +1,408 @@
|
|||
"""
|
||||
Tests for Gateway Config Receiver
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import asyncio
|
||||
import json
|
||||
from unittest.mock import Mock, patch, Mock, MagicMock
|
||||
import uuid
|
||||
|
||||
from trustgraph.gateway.config.receiver import ConfigReceiver
|
||||
|
||||
# Save the real method before patching
|
||||
_real_config_loader = ConfigReceiver.config_loader
|
||||
|
||||
# Patch async methods at module level to prevent coroutine warnings
|
||||
ConfigReceiver.config_loader = Mock()
|
||||
|
||||
|
||||
class TestConfigReceiver:
|
||||
"""Test cases for ConfigReceiver class"""
|
||||
|
||||
def test_config_receiver_initialization(self):
|
||||
"""Test ConfigReceiver initialization"""
|
||||
mock_pulsar_client = Mock()
|
||||
|
||||
config_receiver = ConfigReceiver(mock_pulsar_client)
|
||||
|
||||
assert config_receiver.pulsar_client == mock_pulsar_client
|
||||
assert config_receiver.flow_handlers == []
|
||||
assert config_receiver.flows == {}
|
||||
|
||||
def test_add_handler(self):
|
||||
"""Test adding flow handlers"""
|
||||
mock_pulsar_client = Mock()
|
||||
config_receiver = ConfigReceiver(mock_pulsar_client)
|
||||
|
||||
handler1 = Mock()
|
||||
handler2 = Mock()
|
||||
|
||||
config_receiver.add_handler(handler1)
|
||||
config_receiver.add_handler(handler2)
|
||||
|
||||
assert len(config_receiver.flow_handlers) == 2
|
||||
assert handler1 in config_receiver.flow_handlers
|
||||
assert handler2 in config_receiver.flow_handlers
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_config_with_new_flows(self):
|
||||
"""Test on_config method with new flows"""
|
||||
mock_pulsar_client = Mock()
|
||||
config_receiver = ConfigReceiver(mock_pulsar_client)
|
||||
|
||||
# Track calls manually instead of using AsyncMock
|
||||
start_flow_calls = []
|
||||
|
||||
async def mock_start_flow(*args):
|
||||
start_flow_calls.append(args)
|
||||
|
||||
config_receiver.start_flow = mock_start_flow
|
||||
|
||||
# Create mock message with flows
|
||||
mock_msg = Mock()
|
||||
mock_msg.value.return_value = Mock(
|
||||
version="1.0",
|
||||
config={
|
||||
"flows": {
|
||||
"flow1": '{"name": "test_flow_1", "steps": []}',
|
||||
"flow2": '{"name": "test_flow_2", "steps": []}'
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
await config_receiver.on_config(mock_msg, None, None)
|
||||
|
||||
# Verify flows were added
|
||||
assert "flow1" in config_receiver.flows
|
||||
assert "flow2" in config_receiver.flows
|
||||
assert config_receiver.flows["flow1"] == {"name": "test_flow_1", "steps": []}
|
||||
assert config_receiver.flows["flow2"] == {"name": "test_flow_2", "steps": []}
|
||||
|
||||
# Verify start_flow was called for each new flow
|
||||
assert len(start_flow_calls) == 2
|
||||
assert ("flow1", {"name": "test_flow_1", "steps": []}) in start_flow_calls
|
||||
assert ("flow2", {"name": "test_flow_2", "steps": []}) in start_flow_calls
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_config_with_removed_flows(self):
|
||||
"""Test on_config method with removed flows"""
|
||||
mock_pulsar_client = Mock()
|
||||
config_receiver = ConfigReceiver(mock_pulsar_client)
|
||||
|
||||
# Pre-populate with existing flows
|
||||
config_receiver.flows = {
|
||||
"flow1": {"name": "test_flow_1", "steps": []},
|
||||
"flow2": {"name": "test_flow_2", "steps": []}
|
||||
}
|
||||
|
||||
# Track calls manually instead of using AsyncMock
|
||||
stop_flow_calls = []
|
||||
|
||||
async def mock_stop_flow(*args):
|
||||
stop_flow_calls.append(args)
|
||||
|
||||
config_receiver.stop_flow = mock_stop_flow
|
||||
|
||||
# Create mock message with only flow1 (flow2 removed)
|
||||
mock_msg = Mock()
|
||||
mock_msg.value.return_value = Mock(
|
||||
version="1.0",
|
||||
config={
|
||||
"flows": {
|
||||
"flow1": '{"name": "test_flow_1", "steps": []}'
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
await config_receiver.on_config(mock_msg, None, None)
|
||||
|
||||
# Verify flow2 was removed
|
||||
assert "flow1" in config_receiver.flows
|
||||
assert "flow2" not in config_receiver.flows
|
||||
|
||||
# Verify stop_flow was called for removed flow
|
||||
assert len(stop_flow_calls) == 1
|
||||
assert stop_flow_calls[0] == ("flow2", {"name": "test_flow_2", "steps": []})
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_config_with_no_flows(self):
|
||||
"""Test on_config method with no flows in config"""
|
||||
mock_pulsar_client = Mock()
|
||||
config_receiver = ConfigReceiver(mock_pulsar_client)
|
||||
|
||||
# Mock the start_flow and stop_flow methods with async functions
|
||||
async def mock_start_flow(*args):
|
||||
pass
|
||||
async def mock_stop_flow(*args):
|
||||
pass
|
||||
config_receiver.start_flow = mock_start_flow
|
||||
config_receiver.stop_flow = mock_stop_flow
|
||||
|
||||
# Create mock message without flows
|
||||
mock_msg = Mock()
|
||||
mock_msg.value.return_value = Mock(
|
||||
version="1.0",
|
||||
config={}
|
||||
)
|
||||
|
||||
await config_receiver.on_config(mock_msg, None, None)
|
||||
|
||||
# Verify no flows were added
|
||||
assert config_receiver.flows == {}
|
||||
|
||||
# Since no flows were in the config, the flow methods shouldn't be called
|
||||
# (We can't easily assert this with simple async functions, but the test
|
||||
# passes if no exceptions are thrown)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_config_exception_handling(self):
|
||||
"""Test on_config method handles exceptions gracefully"""
|
||||
mock_pulsar_client = Mock()
|
||||
config_receiver = ConfigReceiver(mock_pulsar_client)
|
||||
|
||||
# Create mock message that will cause an exception
|
||||
mock_msg = Mock()
|
||||
mock_msg.value.side_effect = Exception("Test exception")
|
||||
|
||||
# This should not raise an exception
|
||||
await config_receiver.on_config(mock_msg, None, None)
|
||||
|
||||
# Verify flows remain empty
|
||||
assert config_receiver.flows == {}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_flow_with_handlers(self):
|
||||
"""Test start_flow method with multiple handlers"""
|
||||
mock_pulsar_client = Mock()
|
||||
config_receiver = ConfigReceiver(mock_pulsar_client)
|
||||
|
||||
# Add mock handlers
|
||||
handler1 = Mock()
|
||||
handler1.start_flow = Mock()
|
||||
handler2 = Mock()
|
||||
handler2.start_flow = Mock()
|
||||
|
||||
config_receiver.add_handler(handler1)
|
||||
config_receiver.add_handler(handler2)
|
||||
|
||||
flow_data = {"name": "test_flow", "steps": []}
|
||||
|
||||
await config_receiver.start_flow("flow1", flow_data)
|
||||
|
||||
# Verify all handlers were called
|
||||
handler1.start_flow.assert_called_once_with("flow1", flow_data)
|
||||
handler2.start_flow.assert_called_once_with("flow1", flow_data)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_flow_with_handler_exception(self):
|
||||
"""Test start_flow method handles handler exceptions"""
|
||||
mock_pulsar_client = Mock()
|
||||
config_receiver = ConfigReceiver(mock_pulsar_client)
|
||||
|
||||
# Add mock handler that raises exception
|
||||
handler = Mock()
|
||||
handler.start_flow = Mock(side_effect=Exception("Handler error"))
|
||||
|
||||
config_receiver.add_handler(handler)
|
||||
|
||||
flow_data = {"name": "test_flow", "steps": []}
|
||||
|
||||
# This should not raise an exception
|
||||
await config_receiver.start_flow("flow1", flow_data)
|
||||
|
||||
# Verify handler was called
|
||||
handler.start_flow.assert_called_once_with("flow1", flow_data)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stop_flow_with_handlers(self):
|
||||
"""Test stop_flow method with multiple handlers"""
|
||||
mock_pulsar_client = Mock()
|
||||
config_receiver = ConfigReceiver(mock_pulsar_client)
|
||||
|
||||
# Add mock handlers
|
||||
handler1 = Mock()
|
||||
handler1.stop_flow = Mock()
|
||||
handler2 = Mock()
|
||||
handler2.stop_flow = Mock()
|
||||
|
||||
config_receiver.add_handler(handler1)
|
||||
config_receiver.add_handler(handler2)
|
||||
|
||||
flow_data = {"name": "test_flow", "steps": []}
|
||||
|
||||
await config_receiver.stop_flow("flow1", flow_data)
|
||||
|
||||
# Verify all handlers were called
|
||||
handler1.stop_flow.assert_called_once_with("flow1", flow_data)
|
||||
handler2.stop_flow.assert_called_once_with("flow1", flow_data)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stop_flow_with_handler_exception(self):
|
||||
"""Test stop_flow method handles handler exceptions"""
|
||||
mock_pulsar_client = Mock()
|
||||
config_receiver = ConfigReceiver(mock_pulsar_client)
|
||||
|
||||
# Add mock handler that raises exception
|
||||
handler = Mock()
|
||||
handler.stop_flow = Mock(side_effect=Exception("Handler error"))
|
||||
|
||||
config_receiver.add_handler(handler)
|
||||
|
||||
flow_data = {"name": "test_flow", "steps": []}
|
||||
|
||||
# This should not raise an exception
|
||||
await config_receiver.stop_flow("flow1", flow_data)
|
||||
|
||||
# Verify handler was called
|
||||
handler.stop_flow.assert_called_once_with("flow1", flow_data)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_config_loader_creates_consumer(self):
|
||||
"""Test config_loader method creates Pulsar consumer"""
|
||||
mock_pulsar_client = Mock()
|
||||
|
||||
config_receiver = ConfigReceiver(mock_pulsar_client)
|
||||
# Temporarily restore the real config_loader for this test
|
||||
config_receiver.config_loader = _real_config_loader.__get__(config_receiver)
|
||||
|
||||
# Mock Consumer class
|
||||
with patch('trustgraph.gateway.config.receiver.Consumer') as mock_consumer_class, \
|
||||
patch('uuid.uuid4') as mock_uuid:
|
||||
|
||||
mock_uuid.return_value = "test-uuid"
|
||||
mock_consumer = Mock()
|
||||
async def mock_start():
|
||||
pass
|
||||
mock_consumer.start = mock_start
|
||||
mock_consumer_class.return_value = mock_consumer
|
||||
|
||||
# Create a task that will complete quickly
|
||||
async def quick_task():
|
||||
await config_receiver.config_loader()
|
||||
|
||||
# Run the task with a timeout to prevent hanging
|
||||
try:
|
||||
await asyncio.wait_for(quick_task(), timeout=0.1)
|
||||
except asyncio.TimeoutError:
|
||||
# This is expected since the method runs indefinitely
|
||||
pass
|
||||
|
||||
# Verify Consumer was created with correct parameters
|
||||
mock_consumer_class.assert_called_once()
|
||||
call_args = mock_consumer_class.call_args
|
||||
|
||||
assert call_args[1]['client'] == mock_pulsar_client
|
||||
assert call_args[1]['subscriber'] == "gateway-test-uuid"
|
||||
assert call_args[1]['handler'] == config_receiver.on_config
|
||||
assert call_args[1]['start_of_messages'] is True
|
||||
|
||||
@patch('asyncio.create_task')
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_creates_config_loader_task(self, mock_create_task):
|
||||
"""Test start method creates config loader task"""
|
||||
mock_pulsar_client = Mock()
|
||||
config_receiver = ConfigReceiver(mock_pulsar_client)
|
||||
|
||||
# Mock create_task to avoid actually creating tasks with real coroutines
|
||||
mock_task = Mock()
|
||||
mock_create_task.return_value = mock_task
|
||||
|
||||
await config_receiver.start()
|
||||
|
||||
# Verify task was created
|
||||
mock_create_task.assert_called_once()
|
||||
|
||||
# Verify the argument passed to create_task is a coroutine
|
||||
call_args = mock_create_task.call_args[0]
|
||||
assert len(call_args) == 1 # Should have one argument (the coroutine)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_config_mixed_flow_operations(self):
|
||||
"""Test on_config with mixed add/remove operations"""
|
||||
mock_pulsar_client = Mock()
|
||||
config_receiver = ConfigReceiver(mock_pulsar_client)
|
||||
|
||||
# Pre-populate with existing flows
|
||||
config_receiver.flows = {
|
||||
"flow1": {"name": "test_flow_1", "steps": []},
|
||||
"flow2": {"name": "test_flow_2", "steps": []}
|
||||
}
|
||||
|
||||
# Track calls manually instead of using Mock
|
||||
start_flow_calls = []
|
||||
stop_flow_calls = []
|
||||
|
||||
async def mock_start_flow(*args):
|
||||
start_flow_calls.append(args)
|
||||
|
||||
async def mock_stop_flow(*args):
|
||||
stop_flow_calls.append(args)
|
||||
|
||||
# Directly assign to avoid patch.object detecting async methods
|
||||
original_start_flow = config_receiver.start_flow
|
||||
original_stop_flow = config_receiver.stop_flow
|
||||
config_receiver.start_flow = mock_start_flow
|
||||
config_receiver.stop_flow = mock_stop_flow
|
||||
|
||||
try:
|
||||
|
||||
# Create mock message with flow1 removed and flow3 added
|
||||
mock_msg = Mock()
|
||||
mock_msg.value.return_value = Mock(
|
||||
version="1.0",
|
||||
config={
|
||||
"flows": {
|
||||
"flow2": '{"name": "test_flow_2", "steps": []}',
|
||||
"flow3": '{"name": "test_flow_3", "steps": []}'
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
await config_receiver.on_config(mock_msg, None, None)
|
||||
|
||||
# Verify final state
|
||||
assert "flow1" not in config_receiver.flows
|
||||
assert "flow2" in config_receiver.flows
|
||||
assert "flow3" in config_receiver.flows
|
||||
|
||||
# Verify operations
|
||||
assert len(start_flow_calls) == 1
|
||||
assert start_flow_calls[0] == ("flow3", {"name": "test_flow_3", "steps": []})
|
||||
assert len(stop_flow_calls) == 1
|
||||
assert stop_flow_calls[0] == ("flow1", {"name": "test_flow_1", "steps": []})
|
||||
|
||||
finally:
|
||||
# Restore original methods
|
||||
config_receiver.start_flow = original_start_flow
|
||||
config_receiver.stop_flow = original_stop_flow
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_config_invalid_json_flow_data(self):
|
||||
"""Test on_config handles invalid JSON in flow data"""
|
||||
mock_pulsar_client = Mock()
|
||||
config_receiver = ConfigReceiver(mock_pulsar_client)
|
||||
|
||||
# Mock the start_flow method with an async function
|
||||
async def mock_start_flow(*args):
|
||||
pass
|
||||
config_receiver.start_flow = mock_start_flow
|
||||
|
||||
# Create mock message with invalid JSON
|
||||
mock_msg = Mock()
|
||||
mock_msg.value.return_value = Mock(
|
||||
version="1.0",
|
||||
config={
|
||||
"flows": {
|
||||
"flow1": '{"invalid": json}', # Invalid JSON
|
||||
"flow2": '{"name": "valid_flow", "steps": []}' # Valid JSON
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
# This should handle the exception gracefully
|
||||
await config_receiver.on_config(mock_msg, None, None)
|
||||
|
||||
# The entire operation should fail due to JSON parsing error
|
||||
# So no flows should be added
|
||||
assert config_receiver.flows == {}
|
||||
93
tests/unit/test_gateway/test_dispatch_config.py
Normal file
93
tests/unit/test_gateway/test_dispatch_config.py
Normal file
|
|
@ -0,0 +1,93 @@
|
|||
"""
|
||||
Tests for Gateway Config Dispatch
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch, AsyncMock, Mock
|
||||
|
||||
from trustgraph.gateway.dispatch.config import ConfigRequestor
|
||||
|
||||
# Import parent class for local patching
|
||||
from trustgraph.gateway.dispatch.requestor import ServiceRequestor
|
||||
|
||||
|
||||
class TestConfigRequestor:
|
||||
"""Test cases for ConfigRequestor class"""
|
||||
|
||||
@patch('trustgraph.gateway.dispatch.config.TranslatorRegistry')
|
||||
def test_config_requestor_initialization(self, mock_translator_registry):
|
||||
"""Test ConfigRequestor initialization"""
|
||||
# Mock translators
|
||||
mock_request_translator = Mock()
|
||||
mock_response_translator = Mock()
|
||||
mock_translator_registry.get_request_translator.return_value = mock_request_translator
|
||||
mock_translator_registry.get_response_translator.return_value = mock_response_translator
|
||||
|
||||
# Mock dependencies
|
||||
mock_pulsar_client = Mock()
|
||||
|
||||
requestor = ConfigRequestor(
|
||||
pulsar_client=mock_pulsar_client,
|
||||
consumer="test-consumer",
|
||||
subscriber="test-subscriber",
|
||||
timeout=60
|
||||
)
|
||||
|
||||
# Verify translator setup
|
||||
mock_translator_registry.get_request_translator.assert_called_once_with("config")
|
||||
mock_translator_registry.get_response_translator.assert_called_once_with("config")
|
||||
|
||||
assert requestor.request_translator == mock_request_translator
|
||||
assert requestor.response_translator == mock_response_translator
|
||||
|
||||
@patch('trustgraph.gateway.dispatch.config.TranslatorRegistry')
|
||||
def test_config_requestor_to_request(self, mock_translator_registry):
|
||||
"""Test ConfigRequestor to_request method"""
|
||||
# Mock translators
|
||||
mock_request_translator = Mock()
|
||||
mock_translator_registry.get_request_translator.return_value = mock_request_translator
|
||||
mock_translator_registry.get_response_translator.return_value = Mock()
|
||||
|
||||
# Setup translator response
|
||||
mock_request_translator.to_pulsar.return_value = "translated_request"
|
||||
|
||||
# Patch ServiceRequestor async methods with regular mocks (not AsyncMock)
|
||||
with patch.object(ServiceRequestor, 'start', return_value=None), \
|
||||
patch.object(ServiceRequestor, 'process', return_value=None):
|
||||
requestor = ConfigRequestor(
|
||||
pulsar_client=Mock(),
|
||||
consumer="test-consumer",
|
||||
subscriber="test-subscriber"
|
||||
)
|
||||
|
||||
# Call to_request
|
||||
result = requestor.to_request({"test": "body"})
|
||||
|
||||
# Verify translator was called correctly
|
||||
mock_request_translator.to_pulsar.assert_called_once_with({"test": "body"})
|
||||
assert result == "translated_request"
|
||||
|
||||
@patch('trustgraph.gateway.dispatch.config.TranslatorRegistry')
|
||||
def test_config_requestor_from_response(self, mock_translator_registry):
|
||||
"""Test ConfigRequestor from_response method"""
|
||||
# Mock translators
|
||||
mock_response_translator = Mock()
|
||||
mock_translator_registry.get_request_translator.return_value = Mock()
|
||||
mock_translator_registry.get_response_translator.return_value = mock_response_translator
|
||||
|
||||
# Setup translator response
|
||||
mock_response_translator.from_response_with_completion.return_value = "translated_response"
|
||||
|
||||
requestor = ConfigRequestor(
|
||||
pulsar_client=Mock(),
|
||||
consumer="test-consumer",
|
||||
subscriber="test-subscriber"
|
||||
)
|
||||
|
||||
# Call from_response
|
||||
mock_message = Mock()
|
||||
result = requestor.from_response(mock_message)
|
||||
|
||||
# Verify translator was called correctly
|
||||
mock_response_translator.from_response_with_completion.assert_called_once_with(mock_message)
|
||||
assert result == "translated_response"
|
||||
558
tests/unit/test_gateway/test_dispatch_manager.py
Normal file
558
tests/unit/test_gateway/test_dispatch_manager.py
Normal file
|
|
@ -0,0 +1,558 @@
|
|||
"""
|
||||
Tests for Gateway Dispatcher Manager
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import asyncio
|
||||
from unittest.mock import Mock, patch, AsyncMock, MagicMock
|
||||
import uuid
|
||||
|
||||
from trustgraph.gateway.dispatch.manager import DispatcherManager, DispatcherWrapper
|
||||
|
||||
# Keep the real methods intact for proper testing
|
||||
|
||||
|
||||
class TestDispatcherWrapper:
|
||||
"""Test cases for DispatcherWrapper class"""
|
||||
|
||||
def test_dispatcher_wrapper_initialization(self):
|
||||
"""Test DispatcherWrapper initialization"""
|
||||
mock_handler = Mock()
|
||||
wrapper = DispatcherWrapper(mock_handler)
|
||||
|
||||
assert wrapper.handler == mock_handler
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dispatcher_wrapper_process(self):
|
||||
"""Test DispatcherWrapper process method"""
|
||||
mock_handler = AsyncMock()
|
||||
wrapper = DispatcherWrapper(mock_handler)
|
||||
|
||||
result = await wrapper.process("arg1", "arg2")
|
||||
|
||||
mock_handler.assert_called_once_with("arg1", "arg2")
|
||||
assert result == mock_handler.return_value
|
||||
|
||||
|
||||
class TestDispatcherManager:
|
||||
"""Test cases for DispatcherManager class"""
|
||||
|
||||
def test_dispatcher_manager_initialization(self):
|
||||
"""Test DispatcherManager initialization"""
|
||||
mock_pulsar_client = Mock()
|
||||
mock_config_receiver = Mock()
|
||||
|
||||
manager = DispatcherManager(mock_pulsar_client, mock_config_receiver)
|
||||
|
||||
assert manager.pulsar_client == mock_pulsar_client
|
||||
assert manager.config_receiver == mock_config_receiver
|
||||
assert manager.prefix == "api-gateway" # default prefix
|
||||
assert manager.flows == {}
|
||||
assert manager.dispatchers == {}
|
||||
|
||||
# Verify manager was added as handler to config receiver
|
||||
mock_config_receiver.add_handler.assert_called_once_with(manager)
|
||||
|
||||
def test_dispatcher_manager_initialization_with_custom_prefix(self):
|
||||
"""Test DispatcherManager initialization with custom prefix"""
|
||||
mock_pulsar_client = Mock()
|
||||
mock_config_receiver = Mock()
|
||||
|
||||
manager = DispatcherManager(mock_pulsar_client, mock_config_receiver, prefix="custom-prefix")
|
||||
|
||||
assert manager.prefix == "custom-prefix"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_flow(self):
|
||||
"""Test start_flow method"""
|
||||
mock_pulsar_client = Mock()
|
||||
mock_config_receiver = Mock()
|
||||
manager = DispatcherManager(mock_pulsar_client, mock_config_receiver)
|
||||
|
||||
flow_data = {"name": "test_flow", "steps": []}
|
||||
|
||||
await manager.start_flow("flow1", flow_data)
|
||||
|
||||
assert "flow1" in manager.flows
|
||||
assert manager.flows["flow1"] == flow_data
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stop_flow(self):
|
||||
"""Test stop_flow method"""
|
||||
mock_pulsar_client = Mock()
|
||||
mock_config_receiver = Mock()
|
||||
manager = DispatcherManager(mock_pulsar_client, mock_config_receiver)
|
||||
|
||||
# Pre-populate with a flow
|
||||
flow_data = {"name": "test_flow", "steps": []}
|
||||
manager.flows["flow1"] = flow_data
|
||||
|
||||
await manager.stop_flow("flow1", flow_data)
|
||||
|
||||
assert "flow1" not in manager.flows
|
||||
|
||||
def test_dispatch_global_service_returns_wrapper(self):
|
||||
"""Test dispatch_global_service returns DispatcherWrapper"""
|
||||
mock_pulsar_client = Mock()
|
||||
mock_config_receiver = Mock()
|
||||
manager = DispatcherManager(mock_pulsar_client, mock_config_receiver)
|
||||
|
||||
wrapper = manager.dispatch_global_service()
|
||||
|
||||
assert isinstance(wrapper, DispatcherWrapper)
|
||||
assert wrapper.handler == manager.process_global_service
|
||||
|
||||
def test_dispatch_core_export_returns_wrapper(self):
|
||||
"""Test dispatch_core_export returns DispatcherWrapper"""
|
||||
mock_pulsar_client = Mock()
|
||||
mock_config_receiver = Mock()
|
||||
manager = DispatcherManager(mock_pulsar_client, mock_config_receiver)
|
||||
|
||||
wrapper = manager.dispatch_core_export()
|
||||
|
||||
assert isinstance(wrapper, DispatcherWrapper)
|
||||
assert wrapper.handler == manager.process_core_export
|
||||
|
||||
def test_dispatch_core_import_returns_wrapper(self):
|
||||
"""Test dispatch_core_import returns DispatcherWrapper"""
|
||||
mock_pulsar_client = Mock()
|
||||
mock_config_receiver = Mock()
|
||||
manager = DispatcherManager(mock_pulsar_client, mock_config_receiver)
|
||||
|
||||
wrapper = manager.dispatch_core_import()
|
||||
|
||||
assert isinstance(wrapper, DispatcherWrapper)
|
||||
assert wrapper.handler == manager.process_core_import
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_core_import(self):
|
||||
"""Test process_core_import method"""
|
||||
mock_pulsar_client = Mock()
|
||||
mock_config_receiver = Mock()
|
||||
manager = DispatcherManager(mock_pulsar_client, mock_config_receiver)
|
||||
|
||||
with patch('trustgraph.gateway.dispatch.manager.CoreImport') as mock_core_import:
|
||||
mock_importer = Mock()
|
||||
mock_importer.process = AsyncMock(return_value="import_result")
|
||||
mock_core_import.return_value = mock_importer
|
||||
|
||||
result = await manager.process_core_import("data", "error", "ok", "request")
|
||||
|
||||
mock_core_import.assert_called_once_with(mock_pulsar_client)
|
||||
mock_importer.process.assert_called_once_with("data", "error", "ok", "request")
|
||||
assert result == "import_result"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_core_export(self):
|
||||
"""Test process_core_export method"""
|
||||
mock_pulsar_client = Mock()
|
||||
mock_config_receiver = Mock()
|
||||
manager = DispatcherManager(mock_pulsar_client, mock_config_receiver)
|
||||
|
||||
with patch('trustgraph.gateway.dispatch.manager.CoreExport') as mock_core_export:
|
||||
mock_exporter = Mock()
|
||||
mock_exporter.process = AsyncMock(return_value="export_result")
|
||||
mock_core_export.return_value = mock_exporter
|
||||
|
||||
result = await manager.process_core_export("data", "error", "ok", "request")
|
||||
|
||||
mock_core_export.assert_called_once_with(mock_pulsar_client)
|
||||
mock_exporter.process.assert_called_once_with("data", "error", "ok", "request")
|
||||
assert result == "export_result"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_global_service(self):
|
||||
"""Test process_global_service method"""
|
||||
mock_pulsar_client = Mock()
|
||||
mock_config_receiver = Mock()
|
||||
manager = DispatcherManager(mock_pulsar_client, mock_config_receiver)
|
||||
|
||||
manager.invoke_global_service = AsyncMock(return_value="global_result")
|
||||
|
||||
params = {"kind": "test_kind"}
|
||||
result = await manager.process_global_service("data", "responder", params)
|
||||
|
||||
manager.invoke_global_service.assert_called_once_with("data", "responder", "test_kind")
|
||||
assert result == "global_result"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invoke_global_service_with_existing_dispatcher(self):
|
||||
"""Test invoke_global_service with existing dispatcher"""
|
||||
mock_pulsar_client = Mock()
|
||||
mock_config_receiver = Mock()
|
||||
manager = DispatcherManager(mock_pulsar_client, mock_config_receiver)
|
||||
|
||||
# Pre-populate with existing dispatcher
|
||||
mock_dispatcher = Mock()
|
||||
mock_dispatcher.process = AsyncMock(return_value="cached_result")
|
||||
manager.dispatchers[(None, "config")] = mock_dispatcher
|
||||
|
||||
result = await manager.invoke_global_service("data", "responder", "config")
|
||||
|
||||
mock_dispatcher.process.assert_called_once_with("data", "responder")
|
||||
assert result == "cached_result"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invoke_global_service_creates_new_dispatcher(self):
|
||||
"""Test invoke_global_service creates new dispatcher"""
|
||||
mock_pulsar_client = Mock()
|
||||
mock_config_receiver = Mock()
|
||||
manager = DispatcherManager(mock_pulsar_client, mock_config_receiver)
|
||||
|
||||
with patch('trustgraph.gateway.dispatch.manager.global_dispatchers') as mock_dispatchers:
|
||||
mock_dispatcher_class = Mock()
|
||||
mock_dispatcher = Mock()
|
||||
mock_dispatcher.start = AsyncMock()
|
||||
mock_dispatcher.process = AsyncMock(return_value="new_result")
|
||||
mock_dispatcher_class.return_value = mock_dispatcher
|
||||
mock_dispatchers.__getitem__.return_value = mock_dispatcher_class
|
||||
|
||||
result = await manager.invoke_global_service("data", "responder", "config")
|
||||
|
||||
# Verify dispatcher was created with correct parameters
|
||||
mock_dispatcher_class.assert_called_once_with(
|
||||
pulsar_client=mock_pulsar_client,
|
||||
timeout=120,
|
||||
consumer="api-gateway-config-request",
|
||||
subscriber="api-gateway-config-request"
|
||||
)
|
||||
mock_dispatcher.start.assert_called_once()
|
||||
mock_dispatcher.process.assert_called_once_with("data", "responder")
|
||||
|
||||
# Verify dispatcher was cached
|
||||
assert manager.dispatchers[(None, "config")] == mock_dispatcher
|
||||
assert result == "new_result"
|
||||
|
||||
def test_dispatch_flow_import_returns_method(self):
|
||||
"""Test dispatch_flow_import returns correct method"""
|
||||
mock_pulsar_client = Mock()
|
||||
mock_config_receiver = Mock()
|
||||
manager = DispatcherManager(mock_pulsar_client, mock_config_receiver)
|
||||
|
||||
result = manager.dispatch_flow_import()
|
||||
|
||||
assert result == manager.process_flow_import
|
||||
|
||||
def test_dispatch_flow_export_returns_method(self):
|
||||
"""Test dispatch_flow_export returns correct method"""
|
||||
mock_pulsar_client = Mock()
|
||||
mock_config_receiver = Mock()
|
||||
manager = DispatcherManager(mock_pulsar_client, mock_config_receiver)
|
||||
|
||||
result = manager.dispatch_flow_export()
|
||||
|
||||
assert result == manager.process_flow_export
|
||||
|
||||
def test_dispatch_socket_returns_method(self):
|
||||
"""Test dispatch_socket returns correct method"""
|
||||
mock_pulsar_client = Mock()
|
||||
mock_config_receiver = Mock()
|
||||
manager = DispatcherManager(mock_pulsar_client, mock_config_receiver)
|
||||
|
||||
result = manager.dispatch_socket()
|
||||
|
||||
assert result == manager.process_socket
|
||||
|
||||
def test_dispatch_flow_service_returns_wrapper(self):
|
||||
"""Test dispatch_flow_service returns DispatcherWrapper"""
|
||||
mock_pulsar_client = Mock()
|
||||
mock_config_receiver = Mock()
|
||||
manager = DispatcherManager(mock_pulsar_client, mock_config_receiver)
|
||||
|
||||
wrapper = manager.dispatch_flow_service()
|
||||
|
||||
assert isinstance(wrapper, DispatcherWrapper)
|
||||
assert wrapper.handler == manager.process_flow_service
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_flow_import_with_valid_flow_and_kind(self):
|
||||
"""Test process_flow_import with valid flow and kind"""
|
||||
mock_pulsar_client = Mock()
|
||||
mock_config_receiver = Mock()
|
||||
manager = DispatcherManager(mock_pulsar_client, mock_config_receiver)
|
||||
|
||||
# Setup test flow
|
||||
manager.flows["test_flow"] = {
|
||||
"interfaces": {
|
||||
"triples-store": {"queue": "test_queue"}
|
||||
}
|
||||
}
|
||||
|
||||
with patch('trustgraph.gateway.dispatch.manager.import_dispatchers') as mock_dispatchers, \
|
||||
patch('uuid.uuid4') as mock_uuid:
|
||||
mock_uuid.return_value = "test-uuid"
|
||||
mock_dispatcher_class = Mock()
|
||||
mock_dispatcher = Mock()
|
||||
mock_dispatcher.start = AsyncMock()
|
||||
mock_dispatcher_class.return_value = mock_dispatcher
|
||||
mock_dispatchers.__getitem__.return_value = mock_dispatcher_class
|
||||
mock_dispatchers.__contains__.return_value = True
|
||||
|
||||
params = {"flow": "test_flow", "kind": "triples"}
|
||||
result = await manager.process_flow_import("ws", "running", params)
|
||||
|
||||
mock_dispatcher_class.assert_called_once_with(
|
||||
pulsar_client=mock_pulsar_client,
|
||||
ws="ws",
|
||||
running="running",
|
||||
queue={"queue": "test_queue"}
|
||||
)
|
||||
mock_dispatcher.start.assert_called_once()
|
||||
assert result == mock_dispatcher
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_flow_import_with_invalid_flow(self):
|
||||
"""Test process_flow_import with invalid flow"""
|
||||
mock_pulsar_client = Mock()
|
||||
mock_config_receiver = Mock()
|
||||
manager = DispatcherManager(mock_pulsar_client, mock_config_receiver)
|
||||
|
||||
params = {"flow": "invalid_flow", "kind": "triples"}
|
||||
|
||||
with pytest.raises(RuntimeError, match="Invalid flow"):
|
||||
await manager.process_flow_import("ws", "running", params)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_flow_import_with_invalid_kind(self):
|
||||
"""Test process_flow_import with invalid kind"""
|
||||
import warnings
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore", RuntimeWarning)
|
||||
mock_pulsar_client = Mock()
|
||||
mock_config_receiver = Mock()
|
||||
manager = DispatcherManager(mock_pulsar_client, mock_config_receiver)
|
||||
|
||||
# Setup test flow
|
||||
manager.flows["test_flow"] = {
|
||||
"interfaces": {
|
||||
"triples-store": {"queue": "test_queue"}
|
||||
}
|
||||
}
|
||||
|
||||
with patch('trustgraph.gateway.dispatch.manager.import_dispatchers') as mock_dispatchers:
|
||||
mock_dispatchers.__contains__.return_value = False
|
||||
|
||||
params = {"flow": "test_flow", "kind": "invalid_kind"}
|
||||
|
||||
with pytest.raises(RuntimeError, match="Invalid kind"):
|
||||
await manager.process_flow_import("ws", "running", params)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_flow_export_with_valid_flow_and_kind(self):
|
||||
"""Test process_flow_export with valid flow and kind"""
|
||||
mock_pulsar_client = Mock()
|
||||
mock_config_receiver = Mock()
|
||||
manager = DispatcherManager(mock_pulsar_client, mock_config_receiver)
|
||||
|
||||
# Setup test flow
|
||||
manager.flows["test_flow"] = {
|
||||
"interfaces": {
|
||||
"triples-store": {"queue": "test_queue"}
|
||||
}
|
||||
}
|
||||
|
||||
with patch('trustgraph.gateway.dispatch.manager.export_dispatchers') as mock_dispatchers, \
|
||||
patch('uuid.uuid4') as mock_uuid:
|
||||
mock_uuid.return_value = "test-uuid"
|
||||
mock_dispatcher_class = Mock()
|
||||
mock_dispatcher = Mock()
|
||||
mock_dispatcher_class.return_value = mock_dispatcher
|
||||
mock_dispatchers.__getitem__.return_value = mock_dispatcher_class
|
||||
mock_dispatchers.__contains__.return_value = True
|
||||
|
||||
params = {"flow": "test_flow", "kind": "triples"}
|
||||
result = await manager.process_flow_export("ws", "running", params)
|
||||
|
||||
mock_dispatcher_class.assert_called_once_with(
|
||||
pulsar_client=mock_pulsar_client,
|
||||
ws="ws",
|
||||
running="running",
|
||||
queue={"queue": "test_queue"},
|
||||
consumer="api-gateway-test-uuid",
|
||||
subscriber="api-gateway-test-uuid"
|
||||
)
|
||||
assert result == mock_dispatcher
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_socket(self):
|
||||
"""Test process_socket method"""
|
||||
mock_pulsar_client = Mock()
|
||||
mock_config_receiver = Mock()
|
||||
manager = DispatcherManager(mock_pulsar_client, mock_config_receiver)
|
||||
|
||||
with patch('trustgraph.gateway.dispatch.manager.Mux') as mock_mux:
|
||||
mock_mux_instance = Mock()
|
||||
mock_mux.return_value = mock_mux_instance
|
||||
|
||||
result = await manager.process_socket("ws", "running", {})
|
||||
|
||||
mock_mux.assert_called_once_with(manager, "ws", "running")
|
||||
assert result == mock_mux_instance
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_flow_service(self):
|
||||
"""Test process_flow_service method"""
|
||||
mock_pulsar_client = Mock()
|
||||
mock_config_receiver = Mock()
|
||||
manager = DispatcherManager(mock_pulsar_client, mock_config_receiver)
|
||||
|
||||
manager.invoke_flow_service = AsyncMock(return_value="flow_result")
|
||||
|
||||
params = {"flow": "test_flow", "kind": "agent"}
|
||||
result = await manager.process_flow_service("data", "responder", params)
|
||||
|
||||
manager.invoke_flow_service.assert_called_once_with("data", "responder", "test_flow", "agent")
|
||||
assert result == "flow_result"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invoke_flow_service_with_existing_dispatcher(self):
|
||||
"""Test invoke_flow_service with existing dispatcher"""
|
||||
mock_pulsar_client = Mock()
|
||||
mock_config_receiver = Mock()
|
||||
manager = DispatcherManager(mock_pulsar_client, mock_config_receiver)
|
||||
|
||||
# Add flow to the flows dictionary
|
||||
manager.flows["test_flow"] = {"services": {"agent": {}}}
|
||||
|
||||
# Pre-populate with existing dispatcher
|
||||
mock_dispatcher = Mock()
|
||||
mock_dispatcher.process = AsyncMock(return_value="cached_result")
|
||||
manager.dispatchers[("test_flow", "agent")] = mock_dispatcher
|
||||
|
||||
result = await manager.invoke_flow_service("data", "responder", "test_flow", "agent")
|
||||
|
||||
mock_dispatcher.process.assert_called_once_with("data", "responder")
|
||||
assert result == "cached_result"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invoke_flow_service_creates_request_response_dispatcher(self):
|
||||
"""Test invoke_flow_service creates request-response dispatcher"""
|
||||
mock_pulsar_client = Mock()
|
||||
mock_config_receiver = Mock()
|
||||
manager = DispatcherManager(mock_pulsar_client, mock_config_receiver)
|
||||
|
||||
# Setup test flow
|
||||
manager.flows["test_flow"] = {
|
||||
"interfaces": {
|
||||
"agent": {
|
||||
"request": "agent_request_queue",
|
||||
"response": "agent_response_queue"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
with patch('trustgraph.gateway.dispatch.manager.request_response_dispatchers') as mock_dispatchers:
|
||||
mock_dispatcher_class = Mock()
|
||||
mock_dispatcher = Mock()
|
||||
mock_dispatcher.start = AsyncMock()
|
||||
mock_dispatcher.process = AsyncMock(return_value="new_result")
|
||||
mock_dispatcher_class.return_value = mock_dispatcher
|
||||
mock_dispatchers.__getitem__.return_value = mock_dispatcher_class
|
||||
mock_dispatchers.__contains__.return_value = True
|
||||
|
||||
result = await manager.invoke_flow_service("data", "responder", "test_flow", "agent")
|
||||
|
||||
# Verify dispatcher was created with correct parameters
|
||||
mock_dispatcher_class.assert_called_once_with(
|
||||
pulsar_client=mock_pulsar_client,
|
||||
request_queue="agent_request_queue",
|
||||
response_queue="agent_response_queue",
|
||||
timeout=120,
|
||||
consumer="api-gateway-test_flow-agent-request",
|
||||
subscriber="api-gateway-test_flow-agent-request"
|
||||
)
|
||||
mock_dispatcher.start.assert_called_once()
|
||||
mock_dispatcher.process.assert_called_once_with("data", "responder")
|
||||
|
||||
# Verify dispatcher was cached
|
||||
assert manager.dispatchers[("test_flow", "agent")] == mock_dispatcher
|
||||
assert result == "new_result"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invoke_flow_service_creates_sender_dispatcher(self):
|
||||
"""Test invoke_flow_service creates sender dispatcher"""
|
||||
mock_pulsar_client = Mock()
|
||||
mock_config_receiver = Mock()
|
||||
manager = DispatcherManager(mock_pulsar_client, mock_config_receiver)
|
||||
|
||||
# Setup test flow
|
||||
manager.flows["test_flow"] = {
|
||||
"interfaces": {
|
||||
"text-load": {"queue": "text_load_queue"}
|
||||
}
|
||||
}
|
||||
|
||||
with patch('trustgraph.gateway.dispatch.manager.request_response_dispatchers') as mock_rr_dispatchers, \
|
||||
patch('trustgraph.gateway.dispatch.manager.sender_dispatchers') as mock_sender_dispatchers:
|
||||
mock_rr_dispatchers.__contains__.return_value = False
|
||||
mock_sender_dispatchers.__contains__.return_value = True
|
||||
|
||||
mock_dispatcher_class = Mock()
|
||||
mock_dispatcher = Mock()
|
||||
mock_dispatcher.start = AsyncMock()
|
||||
mock_dispatcher.process = AsyncMock(return_value="sender_result")
|
||||
mock_dispatcher_class.return_value = mock_dispatcher
|
||||
mock_sender_dispatchers.__getitem__.return_value = mock_dispatcher_class
|
||||
|
||||
result = await manager.invoke_flow_service("data", "responder", "test_flow", "text-load")
|
||||
|
||||
# Verify dispatcher was created with correct parameters
|
||||
mock_dispatcher_class.assert_called_once_with(
|
||||
pulsar_client=mock_pulsar_client,
|
||||
queue={"queue": "text_load_queue"}
|
||||
)
|
||||
mock_dispatcher.start.assert_called_once()
|
||||
mock_dispatcher.process.assert_called_once_with("data", "responder")
|
||||
|
||||
# Verify dispatcher was cached
|
||||
assert manager.dispatchers[("test_flow", "text-load")] == mock_dispatcher
|
||||
assert result == "sender_result"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invoke_flow_service_invalid_flow(self):
|
||||
"""Test invoke_flow_service with invalid flow"""
|
||||
mock_pulsar_client = Mock()
|
||||
mock_config_receiver = Mock()
|
||||
manager = DispatcherManager(mock_pulsar_client, mock_config_receiver)
|
||||
|
||||
with pytest.raises(RuntimeError, match="Invalid flow"):
|
||||
await manager.invoke_flow_service("data", "responder", "invalid_flow", "agent")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invoke_flow_service_unsupported_kind_by_flow(self):
|
||||
"""Test invoke_flow_service with kind not supported by flow"""
|
||||
mock_pulsar_client = Mock()
|
||||
mock_config_receiver = Mock()
|
||||
manager = DispatcherManager(mock_pulsar_client, mock_config_receiver)
|
||||
|
||||
# Setup test flow without agent interface
|
||||
manager.flows["test_flow"] = {
|
||||
"interfaces": {
|
||||
"text-completion": {"request": "req", "response": "resp"}
|
||||
}
|
||||
}
|
||||
|
||||
with pytest.raises(RuntimeError, match="This kind not supported by flow"):
|
||||
await manager.invoke_flow_service("data", "responder", "test_flow", "agent")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invoke_flow_service_invalid_kind(self):
|
||||
"""Test invoke_flow_service with invalid kind"""
|
||||
mock_pulsar_client = Mock()
|
||||
mock_config_receiver = Mock()
|
||||
manager = DispatcherManager(mock_pulsar_client, mock_config_receiver)
|
||||
|
||||
# Setup test flow with interface but unsupported kind
|
||||
manager.flows["test_flow"] = {
|
||||
"interfaces": {
|
||||
"invalid-kind": {"request": "req", "response": "resp"}
|
||||
}
|
||||
}
|
||||
|
||||
with patch('trustgraph.gateway.dispatch.manager.request_response_dispatchers') as mock_rr_dispatchers, \
|
||||
patch('trustgraph.gateway.dispatch.manager.sender_dispatchers') as mock_sender_dispatchers:
|
||||
mock_rr_dispatchers.__contains__.return_value = False
|
||||
mock_sender_dispatchers.__contains__.return_value = False
|
||||
|
||||
with pytest.raises(RuntimeError, match="Invalid kind"):
|
||||
await manager.invoke_flow_service("data", "responder", "test_flow", "invalid-kind")
|
||||
171
tests/unit/test_gateway/test_dispatch_mux.py
Normal file
171
tests/unit/test_gateway/test_dispatch_mux.py
Normal file
|
|
@ -0,0 +1,171 @@
|
|||
"""
|
||||
Tests for Gateway Dispatch Mux
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import asyncio
|
||||
from unittest.mock import MagicMock, AsyncMock
|
||||
|
||||
from trustgraph.gateway.dispatch.mux import Mux, MAX_QUEUE_SIZE
|
||||
|
||||
|
||||
class TestMux:
|
||||
"""Test cases for Mux class"""
|
||||
|
||||
def test_mux_initialization(self):
|
||||
"""Test Mux initialization"""
|
||||
mock_dispatcher_manager = MagicMock()
|
||||
mock_ws = MagicMock()
|
||||
mock_running = MagicMock()
|
||||
|
||||
mux = Mux(
|
||||
dispatcher_manager=mock_dispatcher_manager,
|
||||
ws=mock_ws,
|
||||
running=mock_running
|
||||
)
|
||||
|
||||
assert mux.dispatcher_manager == mock_dispatcher_manager
|
||||
assert mux.ws == mock_ws
|
||||
assert mux.running == mock_running
|
||||
assert isinstance(mux.q, asyncio.Queue)
|
||||
assert mux.q.maxsize == MAX_QUEUE_SIZE
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mux_destroy_with_websocket(self):
|
||||
"""Test Mux destroy method with websocket"""
|
||||
mock_dispatcher_manager = MagicMock()
|
||||
mock_ws = AsyncMock()
|
||||
mock_running = MagicMock()
|
||||
|
||||
mux = Mux(
|
||||
dispatcher_manager=mock_dispatcher_manager,
|
||||
ws=mock_ws,
|
||||
running=mock_running
|
||||
)
|
||||
|
||||
# Call destroy
|
||||
await mux.destroy()
|
||||
|
||||
# Verify running.stop was called
|
||||
mock_running.stop.assert_called_once()
|
||||
|
||||
# Verify websocket close was called
|
||||
mock_ws.close.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mux_destroy_without_websocket(self):
|
||||
"""Test Mux destroy method without websocket"""
|
||||
mock_dispatcher_manager = MagicMock()
|
||||
mock_running = MagicMock()
|
||||
|
||||
mux = Mux(
|
||||
dispatcher_manager=mock_dispatcher_manager,
|
||||
ws=None,
|
||||
running=mock_running
|
||||
)
|
||||
|
||||
# Call destroy
|
||||
await mux.destroy()
|
||||
|
||||
# Verify running.stop was called
|
||||
mock_running.stop.assert_called_once()
|
||||
# No websocket to close
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mux_receive_valid_message(self):
|
||||
"""Test Mux receive method with valid message"""
|
||||
mock_dispatcher_manager = MagicMock()
|
||||
mock_ws = AsyncMock()
|
||||
mock_running = MagicMock()
|
||||
|
||||
mux = Mux(
|
||||
dispatcher_manager=mock_dispatcher_manager,
|
||||
ws=mock_ws,
|
||||
running=mock_running
|
||||
)
|
||||
|
||||
# Mock message with valid JSON
|
||||
mock_msg = MagicMock()
|
||||
mock_msg.json.return_value = {
|
||||
"request": {"type": "test"},
|
||||
"id": "test-id-123",
|
||||
"service": "test-service"
|
||||
}
|
||||
|
||||
# Call receive
|
||||
await mux.receive(mock_msg)
|
||||
|
||||
# Verify json was called
|
||||
mock_msg.json.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mux_receive_message_without_request(self):
|
||||
"""Test Mux receive method with message missing request field"""
|
||||
mock_dispatcher_manager = MagicMock()
|
||||
mock_ws = AsyncMock()
|
||||
mock_running = MagicMock()
|
||||
|
||||
mux = Mux(
|
||||
dispatcher_manager=mock_dispatcher_manager,
|
||||
ws=mock_ws,
|
||||
running=mock_running
|
||||
)
|
||||
|
||||
# Mock message without request field
|
||||
mock_msg = MagicMock()
|
||||
mock_msg.json.return_value = {
|
||||
"id": "test-id-123"
|
||||
}
|
||||
|
||||
# receive method should handle the RuntimeError internally
|
||||
# Based on the code, it seems to catch exceptions
|
||||
await mux.receive(mock_msg)
|
||||
|
||||
mock_ws.send_json.assert_called_once_with({"error": "Bad message"})
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mux_receive_message_without_id(self):
|
||||
"""Test Mux receive method with message missing id field"""
|
||||
mock_dispatcher_manager = MagicMock()
|
||||
mock_ws = AsyncMock()
|
||||
mock_running = MagicMock()
|
||||
|
||||
mux = Mux(
|
||||
dispatcher_manager=mock_dispatcher_manager,
|
||||
ws=mock_ws,
|
||||
running=mock_running
|
||||
)
|
||||
|
||||
# Mock message without id field
|
||||
mock_msg = MagicMock()
|
||||
mock_msg.json.return_value = {
|
||||
"request": {"type": "test"}
|
||||
}
|
||||
|
||||
# receive method should handle the RuntimeError internally
|
||||
await mux.receive(mock_msg)
|
||||
|
||||
mock_ws.send_json.assert_called_once_with({"error": "Bad message"})
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mux_receive_invalid_json(self):
|
||||
"""Test Mux receive method with invalid JSON"""
|
||||
mock_dispatcher_manager = MagicMock()
|
||||
mock_ws = AsyncMock()
|
||||
mock_running = MagicMock()
|
||||
|
||||
mux = Mux(
|
||||
dispatcher_manager=mock_dispatcher_manager,
|
||||
ws=mock_ws,
|
||||
running=mock_running
|
||||
)
|
||||
|
||||
# Mock message with invalid JSON
|
||||
mock_msg = MagicMock()
|
||||
mock_msg.json.side_effect = ValueError("Invalid JSON")
|
||||
|
||||
# receive method should handle the ValueError internally
|
||||
await mux.receive(mock_msg)
|
||||
|
||||
mock_msg.json.assert_called_once()
|
||||
mock_ws.send_json.assert_called_once_with({"error": "Invalid JSON"})
|
||||
118
tests/unit/test_gateway/test_dispatch_requestor.py
Normal file
118
tests/unit/test_gateway/test_dispatch_requestor.py
Normal file
|
|
@ -0,0 +1,118 @@
|
|||
"""
|
||||
Tests for Gateway Service Requestor
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, AsyncMock, patch
|
||||
|
||||
from trustgraph.gateway.dispatch.requestor import ServiceRequestor
|
||||
|
||||
|
||||
class TestServiceRequestor:
|
||||
"""Test cases for ServiceRequestor class"""
|
||||
|
||||
@patch('trustgraph.gateway.dispatch.requestor.Publisher')
|
||||
@patch('trustgraph.gateway.dispatch.requestor.Subscriber')
|
||||
def test_service_requestor_initialization(self, mock_subscriber, mock_publisher):
|
||||
"""Test ServiceRequestor initialization"""
|
||||
mock_pulsar_client = MagicMock()
|
||||
mock_request_schema = MagicMock()
|
||||
mock_response_schema = MagicMock()
|
||||
|
||||
requestor = ServiceRequestor(
|
||||
pulsar_client=mock_pulsar_client,
|
||||
request_queue="test-request-queue",
|
||||
request_schema=mock_request_schema,
|
||||
response_queue="test-response-queue",
|
||||
response_schema=mock_response_schema,
|
||||
subscription="test-subscription",
|
||||
consumer_name="test-consumer",
|
||||
timeout=300
|
||||
)
|
||||
|
||||
# Verify Publisher was created correctly
|
||||
mock_publisher.assert_called_once_with(
|
||||
mock_pulsar_client, "test-request-queue", schema=mock_request_schema
|
||||
)
|
||||
|
||||
# Verify Subscriber was created correctly
|
||||
mock_subscriber.assert_called_once_with(
|
||||
mock_pulsar_client, "test-response-queue",
|
||||
"test-subscription", "test-consumer", mock_response_schema
|
||||
)
|
||||
|
||||
assert requestor.timeout == 300
|
||||
assert requestor.running is True
|
||||
|
||||
@patch('trustgraph.gateway.dispatch.requestor.Publisher')
|
||||
@patch('trustgraph.gateway.dispatch.requestor.Subscriber')
|
||||
def test_service_requestor_with_defaults(self, mock_subscriber, mock_publisher):
|
||||
"""Test ServiceRequestor initialization with default parameters"""
|
||||
mock_pulsar_client = MagicMock()
|
||||
mock_request_schema = MagicMock()
|
||||
mock_response_schema = MagicMock()
|
||||
|
||||
requestor = ServiceRequestor(
|
||||
pulsar_client=mock_pulsar_client,
|
||||
request_queue="test-queue",
|
||||
request_schema=mock_request_schema,
|
||||
response_queue="response-queue",
|
||||
response_schema=mock_response_schema
|
||||
)
|
||||
|
||||
# Verify default values
|
||||
mock_subscriber.assert_called_once_with(
|
||||
mock_pulsar_client, "response-queue",
|
||||
"api-gateway", "api-gateway", mock_response_schema
|
||||
)
|
||||
assert requestor.timeout == 600 # Default timeout
|
||||
|
||||
@patch('trustgraph.gateway.dispatch.requestor.Publisher')
|
||||
@patch('trustgraph.gateway.dispatch.requestor.Subscriber')
|
||||
@pytest.mark.asyncio
|
||||
async def test_service_requestor_start(self, mock_subscriber, mock_publisher):
|
||||
"""Test ServiceRequestor start method"""
|
||||
mock_pulsar_client = MagicMock()
|
||||
mock_sub_instance = AsyncMock()
|
||||
mock_pub_instance = AsyncMock()
|
||||
mock_subscriber.return_value = mock_sub_instance
|
||||
mock_publisher.return_value = mock_pub_instance
|
||||
|
||||
requestor = ServiceRequestor(
|
||||
pulsar_client=mock_pulsar_client,
|
||||
request_queue="test-queue",
|
||||
request_schema=MagicMock(),
|
||||
response_queue="response-queue",
|
||||
response_schema=MagicMock()
|
||||
)
|
||||
|
||||
# Call start
|
||||
await requestor.start()
|
||||
|
||||
# Verify both subscriber and publisher start were called
|
||||
mock_sub_instance.start.assert_called_once()
|
||||
mock_pub_instance.start.assert_called_once()
|
||||
assert requestor.running is True
|
||||
|
||||
@patch('trustgraph.gateway.dispatch.requestor.Publisher')
|
||||
@patch('trustgraph.gateway.dispatch.requestor.Subscriber')
|
||||
def test_service_requestor_attributes(self, mock_subscriber, mock_publisher):
|
||||
"""Test ServiceRequestor has correct attributes"""
|
||||
mock_pulsar_client = MagicMock()
|
||||
mock_pub_instance = AsyncMock()
|
||||
mock_sub_instance = AsyncMock()
|
||||
mock_publisher.return_value = mock_pub_instance
|
||||
mock_subscriber.return_value = mock_sub_instance
|
||||
|
||||
requestor = ServiceRequestor(
|
||||
pulsar_client=mock_pulsar_client,
|
||||
request_queue="test-queue",
|
||||
request_schema=MagicMock(),
|
||||
response_queue="response-queue",
|
||||
response_schema=MagicMock()
|
||||
)
|
||||
|
||||
# Verify attributes are set correctly
|
||||
assert requestor.pub == mock_pub_instance
|
||||
assert requestor.sub == mock_sub_instance
|
||||
assert requestor.running is True
|
||||
120
tests/unit/test_gateway/test_dispatch_sender.py
Normal file
120
tests/unit/test_gateway/test_dispatch_sender.py
Normal file
|
|
@ -0,0 +1,120 @@
|
|||
"""
|
||||
Tests for Gateway Service Sender
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, AsyncMock, patch
|
||||
|
||||
from trustgraph.gateway.dispatch.sender import ServiceSender
|
||||
|
||||
|
||||
class TestServiceSender:
|
||||
"""Test cases for ServiceSender class"""
|
||||
|
||||
@patch('trustgraph.gateway.dispatch.sender.Publisher')
|
||||
def test_service_sender_initialization(self, mock_publisher):
|
||||
"""Test ServiceSender initialization"""
|
||||
mock_pulsar_client = MagicMock()
|
||||
mock_schema = MagicMock()
|
||||
|
||||
sender = ServiceSender(
|
||||
pulsar_client=mock_pulsar_client,
|
||||
queue="test-queue",
|
||||
schema=mock_schema
|
||||
)
|
||||
|
||||
# Verify Publisher was created correctly
|
||||
mock_publisher.assert_called_once_with(
|
||||
mock_pulsar_client, "test-queue", schema=mock_schema
|
||||
)
|
||||
|
||||
@patch('trustgraph.gateway.dispatch.sender.Publisher')
|
||||
@pytest.mark.asyncio
|
||||
async def test_service_sender_start(self, mock_publisher):
|
||||
"""Test ServiceSender start method"""
|
||||
mock_pub_instance = AsyncMock()
|
||||
mock_publisher.return_value = mock_pub_instance
|
||||
|
||||
sender = ServiceSender(
|
||||
pulsar_client=MagicMock(),
|
||||
queue="test-queue",
|
||||
schema=MagicMock()
|
||||
)
|
||||
|
||||
# Call start
|
||||
await sender.start()
|
||||
|
||||
# Verify publisher start was called
|
||||
mock_pub_instance.start.assert_called_once()
|
||||
|
||||
@patch('trustgraph.gateway.dispatch.sender.Publisher')
|
||||
@pytest.mark.asyncio
|
||||
async def test_service_sender_stop(self, mock_publisher):
|
||||
"""Test ServiceSender stop method"""
|
||||
mock_pub_instance = AsyncMock()
|
||||
mock_publisher.return_value = mock_pub_instance
|
||||
|
||||
sender = ServiceSender(
|
||||
pulsar_client=MagicMock(),
|
||||
queue="test-queue",
|
||||
schema=MagicMock()
|
||||
)
|
||||
|
||||
# Call stop
|
||||
await sender.stop()
|
||||
|
||||
# Verify publisher stop was called
|
||||
mock_pub_instance.stop.assert_called_once()
|
||||
|
||||
@patch('trustgraph.gateway.dispatch.sender.Publisher')
|
||||
def test_service_sender_to_request_not_implemented(self, mock_publisher):
|
||||
"""Test ServiceSender to_request method raises RuntimeError"""
|
||||
sender = ServiceSender(
|
||||
pulsar_client=MagicMock(),
|
||||
queue="test-queue",
|
||||
schema=MagicMock()
|
||||
)
|
||||
|
||||
with pytest.raises(RuntimeError, match="Not defined"):
|
||||
sender.to_request({"test": "request"})
|
||||
|
||||
@patch('trustgraph.gateway.dispatch.sender.Publisher')
|
||||
@pytest.mark.asyncio
|
||||
async def test_service_sender_process(self, mock_publisher):
|
||||
"""Test ServiceSender process method"""
|
||||
mock_pub_instance = AsyncMock()
|
||||
mock_publisher.return_value = mock_pub_instance
|
||||
|
||||
# Create a concrete sender that implements to_request
|
||||
class ConcreteSender(ServiceSender):
|
||||
def to_request(self, request):
|
||||
return {"processed": request}
|
||||
|
||||
sender = ConcreteSender(
|
||||
pulsar_client=MagicMock(),
|
||||
queue="test-queue",
|
||||
schema=MagicMock()
|
||||
)
|
||||
|
||||
test_request = {"test": "data"}
|
||||
|
||||
# Call process
|
||||
await sender.process(test_request)
|
||||
|
||||
# Verify publisher send was called with processed request
|
||||
mock_pub_instance.send.assert_called_once_with(None, {"processed": test_request})
|
||||
|
||||
@patch('trustgraph.gateway.dispatch.sender.Publisher')
|
||||
def test_service_sender_attributes(self, mock_publisher):
|
||||
"""Test ServiceSender has correct attributes"""
|
||||
mock_pub_instance = MagicMock()
|
||||
mock_publisher.return_value = mock_pub_instance
|
||||
|
||||
sender = ServiceSender(
|
||||
pulsar_client=MagicMock(),
|
||||
queue="test-queue",
|
||||
schema=MagicMock()
|
||||
)
|
||||
|
||||
# Verify attributes are set correctly
|
||||
assert sender.pub == mock_pub_instance
|
||||
89
tests/unit/test_gateway/test_dispatch_serialize.py
Normal file
89
tests/unit/test_gateway/test_dispatch_serialize.py
Normal file
|
|
@ -0,0 +1,89 @@
|
|||
"""
|
||||
Tests for Gateway Dispatch Serialization
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from trustgraph.gateway.dispatch.serialize import to_value, to_subgraph, serialize_value
|
||||
from trustgraph.schema import Value, Triple
|
||||
|
||||
|
||||
class TestDispatchSerialize:
|
||||
"""Test cases for dispatch serialization functions"""
|
||||
|
||||
def test_to_value_with_uri(self):
|
||||
"""Test to_value function with URI"""
|
||||
input_data = {"v": "http://example.com/resource", "e": True}
|
||||
|
||||
result = to_value(input_data)
|
||||
|
||||
assert isinstance(result, Value)
|
||||
assert result.value == "http://example.com/resource"
|
||||
assert result.is_uri is True
|
||||
|
||||
def test_to_value_with_literal(self):
|
||||
"""Test to_value function with literal value"""
|
||||
input_data = {"v": "literal string", "e": False}
|
||||
|
||||
result = to_value(input_data)
|
||||
|
||||
assert isinstance(result, Value)
|
||||
assert result.value == "literal string"
|
||||
assert result.is_uri is False
|
||||
|
||||
def test_to_subgraph_with_multiple_triples(self):
|
||||
"""Test to_subgraph function with multiple triples"""
|
||||
input_data = [
|
||||
{
|
||||
"s": {"v": "subject1", "e": True},
|
||||
"p": {"v": "predicate1", "e": True},
|
||||
"o": {"v": "object1", "e": False}
|
||||
},
|
||||
{
|
||||
"s": {"v": "subject2", "e": False},
|
||||
"p": {"v": "predicate2", "e": True},
|
||||
"o": {"v": "object2", "e": True}
|
||||
}
|
||||
]
|
||||
|
||||
result = to_subgraph(input_data)
|
||||
|
||||
assert len(result) == 2
|
||||
assert all(isinstance(triple, Triple) for triple in result)
|
||||
|
||||
# Check first triple
|
||||
assert result[0].s.value == "subject1"
|
||||
assert result[0].s.is_uri is True
|
||||
assert result[0].p.value == "predicate1"
|
||||
assert result[0].p.is_uri is True
|
||||
assert result[0].o.value == "object1"
|
||||
assert result[0].o.is_uri is False
|
||||
|
||||
# Check second triple
|
||||
assert result[1].s.value == "subject2"
|
||||
assert result[1].s.is_uri is False
|
||||
|
||||
def test_to_subgraph_with_empty_list(self):
|
||||
"""Test to_subgraph function with empty input"""
|
||||
input_data = []
|
||||
|
||||
result = to_subgraph(input_data)
|
||||
|
||||
assert result == []
|
||||
|
||||
def test_serialize_value_with_uri(self):
|
||||
"""Test serialize_value function with URI value"""
|
||||
value = Value(value="http://example.com/test", is_uri=True)
|
||||
|
||||
result = serialize_value(value)
|
||||
|
||||
assert result == {"v": "http://example.com/test", "e": True}
|
||||
|
||||
def test_serialize_value_with_literal(self):
|
||||
"""Test serialize_value function with literal value"""
|
||||
value = Value(value="test literal", is_uri=False)
|
||||
|
||||
result = serialize_value(value)
|
||||
|
||||
assert result == {"v": "test literal", "e": False}
|
||||
55
tests/unit/test_gateway/test_endpoint_constant.py
Normal file
55
tests/unit/test_gateway/test_endpoint_constant.py
Normal file
|
|
@ -0,0 +1,55 @@
|
|||
"""
|
||||
Tests for Gateway Constant Endpoint
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, AsyncMock
|
||||
from aiohttp import web
|
||||
|
||||
from trustgraph.gateway.endpoint.constant_endpoint import ConstantEndpoint
|
||||
|
||||
|
||||
class TestConstantEndpoint:
|
||||
"""Test cases for ConstantEndpoint class"""
|
||||
|
||||
def test_constant_endpoint_initialization(self):
|
||||
"""Test ConstantEndpoint initialization"""
|
||||
mock_auth = MagicMock()
|
||||
mock_dispatcher = MagicMock()
|
||||
|
||||
endpoint = ConstantEndpoint(
|
||||
endpoint_path="/api/test",
|
||||
auth=mock_auth,
|
||||
dispatcher=mock_dispatcher
|
||||
)
|
||||
|
||||
assert endpoint.path == "/api/test"
|
||||
assert endpoint.auth == mock_auth
|
||||
assert endpoint.dispatcher == mock_dispatcher
|
||||
assert endpoint.operation == "service"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_constant_endpoint_start_method(self):
|
||||
"""Test ConstantEndpoint start method (should be no-op)"""
|
||||
mock_auth = MagicMock()
|
||||
mock_dispatcher = MagicMock()
|
||||
|
||||
endpoint = ConstantEndpoint("/api/test", mock_auth, mock_dispatcher)
|
||||
|
||||
# start() should complete without error
|
||||
await endpoint.start()
|
||||
|
||||
def test_add_routes_registers_post_handler(self):
|
||||
"""Test add_routes method registers POST route"""
|
||||
mock_auth = MagicMock()
|
||||
mock_dispatcher = MagicMock()
|
||||
mock_app = MagicMock()
|
||||
|
||||
endpoint = ConstantEndpoint("/api/test", mock_auth, mock_dispatcher)
|
||||
endpoint.add_routes(mock_app)
|
||||
|
||||
# Verify add_routes was called with POST route
|
||||
mock_app.add_routes.assert_called_once()
|
||||
# The call should include web.post with the path and handler
|
||||
call_args = mock_app.add_routes.call_args[0][0]
|
||||
assert len(call_args) == 1 # One route added
|
||||
89
tests/unit/test_gateway/test_endpoint_manager.py
Normal file
89
tests/unit/test_gateway/test_endpoint_manager.py
Normal file
|
|
@ -0,0 +1,89 @@
|
|||
"""
|
||||
Tests for Gateway Endpoint Manager
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from trustgraph.gateway.endpoint.manager import EndpointManager
|
||||
|
||||
|
||||
class TestEndpointManager:
|
||||
"""Test cases for EndpointManager class"""
|
||||
|
||||
def test_endpoint_manager_initialization(self):
|
||||
"""Test EndpointManager initialization creates all endpoints"""
|
||||
mock_dispatcher_manager = MagicMock()
|
||||
mock_auth = MagicMock()
|
||||
|
||||
# Mock dispatcher methods
|
||||
mock_dispatcher_manager.dispatch_global_service.return_value = MagicMock()
|
||||
mock_dispatcher_manager.dispatch_socket.return_value = MagicMock()
|
||||
mock_dispatcher_manager.dispatch_flow_service.return_value = MagicMock()
|
||||
mock_dispatcher_manager.dispatch_flow_import.return_value = MagicMock()
|
||||
mock_dispatcher_manager.dispatch_flow_export.return_value = MagicMock()
|
||||
mock_dispatcher_manager.dispatch_core_import.return_value = MagicMock()
|
||||
mock_dispatcher_manager.dispatch_core_export.return_value = MagicMock()
|
||||
|
||||
manager = EndpointManager(
|
||||
dispatcher_manager=mock_dispatcher_manager,
|
||||
auth=mock_auth,
|
||||
prometheus_url="http://prometheus:9090",
|
||||
timeout=300
|
||||
)
|
||||
|
||||
assert manager.dispatcher_manager == mock_dispatcher_manager
|
||||
assert manager.timeout == 300
|
||||
assert manager.services == {}
|
||||
assert len(manager.endpoints) > 0 # Should have multiple endpoints
|
||||
|
||||
def test_endpoint_manager_with_default_timeout(self):
|
||||
"""Test EndpointManager with default timeout value"""
|
||||
mock_dispatcher_manager = MagicMock()
|
||||
mock_auth = MagicMock()
|
||||
|
||||
# Mock dispatcher methods
|
||||
mock_dispatcher_manager.dispatch_global_service.return_value = MagicMock()
|
||||
mock_dispatcher_manager.dispatch_socket.return_value = MagicMock()
|
||||
mock_dispatcher_manager.dispatch_flow_service.return_value = MagicMock()
|
||||
mock_dispatcher_manager.dispatch_flow_import.return_value = MagicMock()
|
||||
mock_dispatcher_manager.dispatch_flow_export.return_value = MagicMock()
|
||||
mock_dispatcher_manager.dispatch_core_import.return_value = MagicMock()
|
||||
mock_dispatcher_manager.dispatch_core_export.return_value = MagicMock()
|
||||
|
||||
manager = EndpointManager(
|
||||
dispatcher_manager=mock_dispatcher_manager,
|
||||
auth=mock_auth,
|
||||
prometheus_url="http://prometheus:9090"
|
||||
)
|
||||
|
||||
assert manager.timeout == 600 # Default value
|
||||
|
||||
def test_endpoint_manager_dispatcher_calls(self):
|
||||
"""Test EndpointManager calls all required dispatcher methods"""
|
||||
mock_dispatcher_manager = MagicMock()
|
||||
mock_auth = MagicMock()
|
||||
|
||||
# Mock dispatcher methods that are actually called
|
||||
mock_dispatcher_manager.dispatch_global_service.return_value = MagicMock()
|
||||
mock_dispatcher_manager.dispatch_socket.return_value = MagicMock()
|
||||
mock_dispatcher_manager.dispatch_flow_service.return_value = MagicMock()
|
||||
mock_dispatcher_manager.dispatch_flow_import.return_value = MagicMock()
|
||||
mock_dispatcher_manager.dispatch_flow_export.return_value = MagicMock()
|
||||
mock_dispatcher_manager.dispatch_core_import.return_value = MagicMock()
|
||||
mock_dispatcher_manager.dispatch_core_export.return_value = MagicMock()
|
||||
|
||||
EndpointManager(
|
||||
dispatcher_manager=mock_dispatcher_manager,
|
||||
auth=mock_auth,
|
||||
prometheus_url="http://test:9090"
|
||||
)
|
||||
|
||||
# Verify all dispatcher methods were called during initialization
|
||||
mock_dispatcher_manager.dispatch_global_service.assert_called_once()
|
||||
mock_dispatcher_manager.dispatch_socket.assert_called() # Called twice
|
||||
mock_dispatcher_manager.dispatch_flow_service.assert_called_once()
|
||||
mock_dispatcher_manager.dispatch_flow_import.assert_called_once()
|
||||
mock_dispatcher_manager.dispatch_flow_export.assert_called_once()
|
||||
mock_dispatcher_manager.dispatch_core_import.assert_called_once()
|
||||
mock_dispatcher_manager.dispatch_core_export.assert_called_once()
|
||||
60
tests/unit/test_gateway/test_endpoint_metrics.py
Normal file
60
tests/unit/test_gateway/test_endpoint_metrics.py
Normal file
|
|
@ -0,0 +1,60 @@
|
|||
"""
|
||||
Tests for Gateway Metrics Endpoint
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from trustgraph.gateway.endpoint.metrics import MetricsEndpoint
|
||||
|
||||
|
||||
class TestMetricsEndpoint:
|
||||
"""Test cases for MetricsEndpoint class"""
|
||||
|
||||
def test_metrics_endpoint_initialization(self):
|
||||
"""Test MetricsEndpoint initialization"""
|
||||
mock_auth = MagicMock()
|
||||
|
||||
endpoint = MetricsEndpoint(
|
||||
prometheus_url="http://prometheus:9090",
|
||||
endpoint_path="/metrics",
|
||||
auth=mock_auth
|
||||
)
|
||||
|
||||
assert endpoint.prometheus_url == "http://prometheus:9090"
|
||||
assert endpoint.path == "/metrics"
|
||||
assert endpoint.auth == mock_auth
|
||||
assert endpoint.operation == "service"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_metrics_endpoint_start_method(self):
|
||||
"""Test MetricsEndpoint start method (should be no-op)"""
|
||||
mock_auth = MagicMock()
|
||||
|
||||
endpoint = MetricsEndpoint(
|
||||
prometheus_url="http://localhost:9090",
|
||||
endpoint_path="/metrics",
|
||||
auth=mock_auth
|
||||
)
|
||||
|
||||
# start() should complete without error
|
||||
await endpoint.start()
|
||||
|
||||
def test_add_routes_registers_get_handler(self):
|
||||
"""Test add_routes method registers GET route with wildcard path"""
|
||||
mock_auth = MagicMock()
|
||||
mock_app = MagicMock()
|
||||
|
||||
endpoint = MetricsEndpoint(
|
||||
prometheus_url="http://prometheus:9090",
|
||||
endpoint_path="/metrics",
|
||||
auth=mock_auth
|
||||
)
|
||||
|
||||
endpoint.add_routes(mock_app)
|
||||
|
||||
# Verify add_routes was called with GET route
|
||||
mock_app.add_routes.assert_called_once()
|
||||
# The call should include web.get with wildcard path pattern
|
||||
call_args = mock_app.add_routes.call_args[0][0]
|
||||
assert len(call_args) == 1 # One route added
|
||||
133
tests/unit/test_gateway/test_endpoint_socket.py
Normal file
133
tests/unit/test_gateway/test_endpoint_socket.py
Normal file
|
|
@ -0,0 +1,133 @@
|
|||
"""
|
||||
Tests for Gateway Socket Endpoint
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, AsyncMock
|
||||
from aiohttp import WSMsgType
|
||||
|
||||
from trustgraph.gateway.endpoint.socket import SocketEndpoint
|
||||
|
||||
|
||||
class TestSocketEndpoint:
|
||||
"""Test cases for SocketEndpoint class"""
|
||||
|
||||
def test_socket_endpoint_initialization(self):
|
||||
"""Test SocketEndpoint initialization"""
|
||||
mock_auth = MagicMock()
|
||||
mock_dispatcher = MagicMock()
|
||||
|
||||
endpoint = SocketEndpoint(
|
||||
endpoint_path="/api/socket",
|
||||
auth=mock_auth,
|
||||
dispatcher=mock_dispatcher
|
||||
)
|
||||
|
||||
assert endpoint.path == "/api/socket"
|
||||
assert endpoint.auth == mock_auth
|
||||
assert endpoint.dispatcher == mock_dispatcher
|
||||
assert endpoint.operation == "socket"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_worker_method(self):
|
||||
"""Test SocketEndpoint worker method"""
|
||||
mock_auth = MagicMock()
|
||||
mock_dispatcher = AsyncMock()
|
||||
|
||||
endpoint = SocketEndpoint("/api/socket", mock_auth, mock_dispatcher)
|
||||
|
||||
mock_ws = MagicMock()
|
||||
mock_running = MagicMock()
|
||||
|
||||
# Call worker method
|
||||
await endpoint.worker(mock_ws, mock_dispatcher, mock_running)
|
||||
|
||||
# Verify dispatcher.run was called
|
||||
mock_dispatcher.run.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_listener_method_with_text_message(self):
|
||||
"""Test SocketEndpoint listener method with text message"""
|
||||
mock_auth = MagicMock()
|
||||
mock_dispatcher = AsyncMock()
|
||||
|
||||
endpoint = SocketEndpoint("/api/socket", mock_auth, mock_dispatcher)
|
||||
|
||||
# Mock websocket with text message
|
||||
mock_msg = MagicMock()
|
||||
mock_msg.type = WSMsgType.TEXT
|
||||
|
||||
# Create async iterator for websocket
|
||||
async def async_iter():
|
||||
yield mock_msg
|
||||
|
||||
mock_ws = AsyncMock()
|
||||
mock_ws.__aiter__ = lambda self: async_iter()
|
||||
mock_running = MagicMock()
|
||||
|
||||
# Call listener method
|
||||
await endpoint.listener(mock_ws, mock_dispatcher, mock_running)
|
||||
|
||||
# Verify dispatcher.receive was called with the message
|
||||
mock_dispatcher.receive.assert_called_once_with(mock_msg)
|
||||
# Verify cleanup methods were called
|
||||
mock_running.stop.assert_called_once()
|
||||
mock_ws.close.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_listener_method_with_binary_message(self):
|
||||
"""Test SocketEndpoint listener method with binary message"""
|
||||
mock_auth = MagicMock()
|
||||
mock_dispatcher = AsyncMock()
|
||||
|
||||
endpoint = SocketEndpoint("/api/socket", mock_auth, mock_dispatcher)
|
||||
|
||||
# Mock websocket with binary message
|
||||
mock_msg = MagicMock()
|
||||
mock_msg.type = WSMsgType.BINARY
|
||||
|
||||
# Create async iterator for websocket
|
||||
async def async_iter():
|
||||
yield mock_msg
|
||||
|
||||
mock_ws = AsyncMock()
|
||||
mock_ws.__aiter__ = lambda self: async_iter()
|
||||
mock_running = MagicMock()
|
||||
|
||||
# Call listener method
|
||||
await endpoint.listener(mock_ws, mock_dispatcher, mock_running)
|
||||
|
||||
# Verify dispatcher.receive was called with the message
|
||||
mock_dispatcher.receive.assert_called_once_with(mock_msg)
|
||||
# Verify cleanup methods were called
|
||||
mock_running.stop.assert_called_once()
|
||||
mock_ws.close.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_listener_method_with_close_message(self):
|
||||
"""Test SocketEndpoint listener method with close message"""
|
||||
mock_auth = MagicMock()
|
||||
mock_dispatcher = AsyncMock()
|
||||
|
||||
endpoint = SocketEndpoint("/api/socket", mock_auth, mock_dispatcher)
|
||||
|
||||
# Mock websocket with close message
|
||||
mock_msg = MagicMock()
|
||||
mock_msg.type = WSMsgType.CLOSE
|
||||
|
||||
# Create async iterator for websocket
|
||||
async def async_iter():
|
||||
yield mock_msg
|
||||
|
||||
mock_ws = AsyncMock()
|
||||
mock_ws.__aiter__ = lambda self: async_iter()
|
||||
mock_running = MagicMock()
|
||||
|
||||
# Call listener method
|
||||
await endpoint.listener(mock_ws, mock_dispatcher, mock_running)
|
||||
|
||||
# Verify dispatcher.receive was NOT called for close message
|
||||
mock_dispatcher.receive.assert_not_called()
|
||||
# Verify cleanup methods were called after break
|
||||
mock_running.stop.assert_called_once()
|
||||
mock_ws.close.assert_called_once()
|
||||
124
tests/unit/test_gateway/test_endpoint_stream.py
Normal file
124
tests/unit/test_gateway/test_endpoint_stream.py
Normal file
|
|
@ -0,0 +1,124 @@
|
|||
"""
|
||||
Tests for Gateway Stream Endpoint
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from trustgraph.gateway.endpoint.stream_endpoint import StreamEndpoint
|
||||
|
||||
|
||||
class TestStreamEndpoint:
|
||||
"""Test cases for StreamEndpoint class"""
|
||||
|
||||
def test_stream_endpoint_initialization_with_post(self):
|
||||
"""Test StreamEndpoint initialization with POST method"""
|
||||
mock_auth = MagicMock()
|
||||
mock_dispatcher = MagicMock()
|
||||
|
||||
endpoint = StreamEndpoint(
|
||||
endpoint_path="/api/stream",
|
||||
auth=mock_auth,
|
||||
dispatcher=mock_dispatcher,
|
||||
method="POST"
|
||||
)
|
||||
|
||||
assert endpoint.path == "/api/stream"
|
||||
assert endpoint.auth == mock_auth
|
||||
assert endpoint.dispatcher == mock_dispatcher
|
||||
assert endpoint.operation == "service"
|
||||
assert endpoint.method == "POST"
|
||||
|
||||
def test_stream_endpoint_initialization_with_get(self):
|
||||
"""Test StreamEndpoint initialization with GET method"""
|
||||
mock_auth = MagicMock()
|
||||
mock_dispatcher = MagicMock()
|
||||
|
||||
endpoint = StreamEndpoint(
|
||||
endpoint_path="/api/stream",
|
||||
auth=mock_auth,
|
||||
dispatcher=mock_dispatcher,
|
||||
method="GET"
|
||||
)
|
||||
|
||||
assert endpoint.method == "GET"
|
||||
|
||||
def test_stream_endpoint_initialization_default_method(self):
|
||||
"""Test StreamEndpoint initialization with default POST method"""
|
||||
mock_auth = MagicMock()
|
||||
mock_dispatcher = MagicMock()
|
||||
|
||||
endpoint = StreamEndpoint(
|
||||
endpoint_path="/api/stream",
|
||||
auth=mock_auth,
|
||||
dispatcher=mock_dispatcher
|
||||
)
|
||||
|
||||
assert endpoint.method == "POST" # Default value
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stream_endpoint_start_method(self):
|
||||
"""Test StreamEndpoint start method (should be no-op)"""
|
||||
mock_auth = MagicMock()
|
||||
mock_dispatcher = MagicMock()
|
||||
|
||||
endpoint = StreamEndpoint("/api/stream", mock_auth, mock_dispatcher)
|
||||
|
||||
# start() should complete without error
|
||||
await endpoint.start()
|
||||
|
||||
def test_add_routes_with_post_method(self):
|
||||
"""Test add_routes method with POST method"""
|
||||
mock_auth = MagicMock()
|
||||
mock_dispatcher = MagicMock()
|
||||
mock_app = MagicMock()
|
||||
|
||||
endpoint = StreamEndpoint(
|
||||
endpoint_path="/api/stream",
|
||||
auth=mock_auth,
|
||||
dispatcher=mock_dispatcher,
|
||||
method="POST"
|
||||
)
|
||||
|
||||
endpoint.add_routes(mock_app)
|
||||
|
||||
# Verify add_routes was called with POST route
|
||||
mock_app.add_routes.assert_called_once()
|
||||
call_args = mock_app.add_routes.call_args[0][0]
|
||||
assert len(call_args) == 1 # One route added
|
||||
|
||||
def test_add_routes_with_get_method(self):
|
||||
"""Test add_routes method with GET method"""
|
||||
mock_auth = MagicMock()
|
||||
mock_dispatcher = MagicMock()
|
||||
mock_app = MagicMock()
|
||||
|
||||
endpoint = StreamEndpoint(
|
||||
endpoint_path="/api/stream",
|
||||
auth=mock_auth,
|
||||
dispatcher=mock_dispatcher,
|
||||
method="GET"
|
||||
)
|
||||
|
||||
endpoint.add_routes(mock_app)
|
||||
|
||||
# Verify add_routes was called with GET route
|
||||
mock_app.add_routes.assert_called_once()
|
||||
call_args = mock_app.add_routes.call_args[0][0]
|
||||
assert len(call_args) == 1 # One route added
|
||||
|
||||
def test_add_routes_with_invalid_method_raises_error(self):
|
||||
"""Test add_routes method with invalid method raises RuntimeError"""
|
||||
mock_auth = MagicMock()
|
||||
mock_dispatcher = MagicMock()
|
||||
mock_app = MagicMock()
|
||||
|
||||
endpoint = StreamEndpoint(
|
||||
endpoint_path="/api/stream",
|
||||
auth=mock_auth,
|
||||
dispatcher=mock_dispatcher,
|
||||
method="INVALID"
|
||||
)
|
||||
|
||||
with pytest.raises(RuntimeError, match="Bad method"):
|
||||
endpoint.add_routes(mock_app)
|
||||
53
tests/unit/test_gateway/test_endpoint_variable.py
Normal file
53
tests/unit/test_gateway/test_endpoint_variable.py
Normal file
|
|
@ -0,0 +1,53 @@
|
|||
"""
|
||||
Tests for Gateway Variable Endpoint
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from trustgraph.gateway.endpoint.variable_endpoint import VariableEndpoint
|
||||
|
||||
|
||||
class TestVariableEndpoint:
|
||||
"""Test cases for VariableEndpoint class"""
|
||||
|
||||
def test_variable_endpoint_initialization(self):
|
||||
"""Test VariableEndpoint initialization"""
|
||||
mock_auth = MagicMock()
|
||||
mock_dispatcher = MagicMock()
|
||||
|
||||
endpoint = VariableEndpoint(
|
||||
endpoint_path="/api/variable",
|
||||
auth=mock_auth,
|
||||
dispatcher=mock_dispatcher
|
||||
)
|
||||
|
||||
assert endpoint.path == "/api/variable"
|
||||
assert endpoint.auth == mock_auth
|
||||
assert endpoint.dispatcher == mock_dispatcher
|
||||
assert endpoint.operation == "service"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_variable_endpoint_start_method(self):
|
||||
"""Test VariableEndpoint start method (should be no-op)"""
|
||||
mock_auth = MagicMock()
|
||||
mock_dispatcher = MagicMock()
|
||||
|
||||
endpoint = VariableEndpoint("/api/var", mock_auth, mock_dispatcher)
|
||||
|
||||
# start() should complete without error
|
||||
await endpoint.start()
|
||||
|
||||
def test_add_routes_registers_post_handler(self):
|
||||
"""Test add_routes method registers POST route"""
|
||||
mock_auth = MagicMock()
|
||||
mock_dispatcher = MagicMock()
|
||||
mock_app = MagicMock()
|
||||
|
||||
endpoint = VariableEndpoint("/api/variable", mock_auth, mock_dispatcher)
|
||||
endpoint.add_routes(mock_app)
|
||||
|
||||
# Verify add_routes was called with POST route
|
||||
mock_app.add_routes.assert_called_once()
|
||||
call_args = mock_app.add_routes.call_args[0][0]
|
||||
assert len(call_args) == 1 # One route added
|
||||
90
tests/unit/test_gateway/test_running.py
Normal file
90
tests/unit/test_gateway/test_running.py
Normal file
|
|
@ -0,0 +1,90 @@
|
|||
"""
|
||||
Tests for Gateway Running utility class
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from trustgraph.gateway.running import Running
|
||||
|
||||
|
||||
class TestRunning:
|
||||
"""Test cases for Running class"""
|
||||
|
||||
def test_running_initialization(self):
|
||||
"""Test Running class initialization"""
|
||||
running = Running()
|
||||
|
||||
# Should start with running = True
|
||||
assert running.running is True
|
||||
|
||||
def test_running_get_method(self):
|
||||
"""Test Running.get() method returns current state"""
|
||||
running = Running()
|
||||
|
||||
# Should return True initially
|
||||
assert running.get() is True
|
||||
|
||||
# Should return False after stopping
|
||||
running.stop()
|
||||
assert running.get() is False
|
||||
|
||||
def test_running_stop_method(self):
|
||||
"""Test Running.stop() method sets running to False"""
|
||||
running = Running()
|
||||
|
||||
# Initially should be True
|
||||
assert running.running is True
|
||||
|
||||
# After calling stop(), should be False
|
||||
running.stop()
|
||||
assert running.running is False
|
||||
|
||||
def test_running_stop_is_idempotent(self):
|
||||
"""Test that calling stop() multiple times is safe"""
|
||||
running = Running()
|
||||
|
||||
# Stop multiple times
|
||||
running.stop()
|
||||
assert running.running is False
|
||||
|
||||
running.stop()
|
||||
assert running.running is False
|
||||
|
||||
# get() should still return False
|
||||
assert running.get() is False
|
||||
|
||||
def test_running_state_transitions(self):
|
||||
"""Test the complete state transition from running to stopped"""
|
||||
running = Running()
|
||||
|
||||
# Initial state: running
|
||||
assert running.get() is True
|
||||
assert running.running is True
|
||||
|
||||
# Transition to stopped
|
||||
running.stop()
|
||||
assert running.get() is False
|
||||
assert running.running is False
|
||||
|
||||
def test_running_multiple_instances_independent(self):
|
||||
"""Test that multiple Running instances are independent"""
|
||||
running1 = Running()
|
||||
running2 = Running()
|
||||
|
||||
# Both should start as running
|
||||
assert running1.get() is True
|
||||
assert running2.get() is True
|
||||
|
||||
# Stop only one
|
||||
running1.stop()
|
||||
|
||||
# States should be independent
|
||||
assert running1.get() is False
|
||||
assert running2.get() is True
|
||||
|
||||
# Stop the other
|
||||
running2.stop()
|
||||
|
||||
# Both should now be stopped
|
||||
assert running1.get() is False
|
||||
assert running2.get() is False
|
||||
360
tests/unit/test_gateway/test_service.py
Normal file
360
tests/unit/test_gateway/test_service.py
Normal file
|
|
@ -0,0 +1,360 @@
|
|||
"""
|
||||
Tests for Gateway Service API
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import asyncio
|
||||
from unittest.mock import Mock, patch, MagicMock, AsyncMock
|
||||
from aiohttp import web
|
||||
import pulsar
|
||||
|
||||
from trustgraph.gateway.service import Api, run, default_pulsar_host, default_prometheus_url, default_timeout, default_port, default_api_token
|
||||
|
||||
# Tests for Gateway Service API
|
||||
|
||||
|
||||
class TestApi:
|
||||
"""Test cases for Api class"""
|
||||
|
||||
|
||||
def test_api_initialization_with_defaults(self):
|
||||
"""Test Api initialization with default values"""
|
||||
with patch('pulsar.Client') as mock_client:
|
||||
mock_client.return_value = Mock()
|
||||
|
||||
api = Api()
|
||||
|
||||
assert api.port == default_port
|
||||
assert api.timeout == default_timeout
|
||||
assert api.pulsar_host == default_pulsar_host
|
||||
assert api.pulsar_api_key is None
|
||||
assert api.prometheus_url == default_prometheus_url + "/"
|
||||
assert api.auth.allow_all is True
|
||||
|
||||
# Verify Pulsar client was created without API key
|
||||
mock_client.assert_called_once_with(
|
||||
default_pulsar_host,
|
||||
listener_name=None
|
||||
)
|
||||
|
||||
def test_api_initialization_with_custom_config(self):
|
||||
"""Test Api initialization with custom configuration"""
|
||||
config = {
|
||||
"port": 9000,
|
||||
"timeout": 300,
|
||||
"pulsar_host": "pulsar://custom-host:6650",
|
||||
"pulsar_api_key": "test-api-key",
|
||||
"pulsar_listener": "custom-listener",
|
||||
"prometheus_url": "http://custom-prometheus:9090",
|
||||
"api_token": "secret-token"
|
||||
}
|
||||
|
||||
with patch('pulsar.Client') as mock_client, \
|
||||
patch('pulsar.AuthenticationToken') as mock_auth:
|
||||
mock_client.return_value = Mock()
|
||||
mock_auth.return_value = Mock()
|
||||
|
||||
api = Api(**config)
|
||||
|
||||
assert api.port == 9000
|
||||
assert api.timeout == 300
|
||||
assert api.pulsar_host == "pulsar://custom-host:6650"
|
||||
assert api.pulsar_api_key == "test-api-key"
|
||||
assert api.prometheus_url == "http://custom-prometheus:9090/"
|
||||
assert api.auth.token == "secret-token"
|
||||
assert api.auth.allow_all is False
|
||||
|
||||
# Verify Pulsar client was created with API key
|
||||
mock_auth.assert_called_once_with("test-api-key")
|
||||
mock_client.assert_called_once_with(
|
||||
"pulsar://custom-host:6650",
|
||||
listener_name="custom-listener",
|
||||
authentication=mock_auth.return_value
|
||||
)
|
||||
|
||||
def test_api_initialization_with_pulsar_api_key(self):
|
||||
"""Test Api initialization with Pulsar API key authentication"""
|
||||
with patch('pulsar.Client') as mock_client, \
|
||||
patch('pulsar.AuthenticationToken') as mock_auth:
|
||||
mock_client.return_value = Mock()
|
||||
mock_auth.return_value = Mock()
|
||||
|
||||
api = Api(pulsar_api_key="test-key")
|
||||
|
||||
mock_auth.assert_called_once_with("test-key")
|
||||
mock_client.assert_called_once_with(
|
||||
default_pulsar_host,
|
||||
listener_name=None,
|
||||
authentication=mock_auth.return_value
|
||||
)
|
||||
|
||||
def test_api_initialization_prometheus_url_normalization(self):
|
||||
"""Test that prometheus_url gets normalized with trailing slash"""
|
||||
with patch('pulsar.Client') as mock_client:
|
||||
mock_client.return_value = Mock()
|
||||
|
||||
# Test URL without trailing slash
|
||||
api = Api(prometheus_url="http://prometheus:9090")
|
||||
assert api.prometheus_url == "http://prometheus:9090/"
|
||||
|
||||
# Test URL with trailing slash
|
||||
api = Api(prometheus_url="http://prometheus:9090/")
|
||||
assert api.prometheus_url == "http://prometheus:9090/"
|
||||
|
||||
def test_api_initialization_empty_api_token_means_no_auth(self):
|
||||
"""Test that empty API token results in allow_all authentication"""
|
||||
with patch('pulsar.Client') as mock_client:
|
||||
mock_client.return_value = Mock()
|
||||
|
||||
api = Api(api_token="")
|
||||
assert api.auth.allow_all is True
|
||||
|
||||
def test_api_initialization_none_api_token_means_no_auth(self):
|
||||
"""Test that None API token results in allow_all authentication"""
|
||||
with patch('pulsar.Client') as mock_client:
|
||||
mock_client.return_value = Mock()
|
||||
|
||||
api = Api(api_token=None)
|
||||
assert api.auth.allow_all is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_app_factory_creates_application(self):
|
||||
"""Test that app_factory creates aiohttp application"""
|
||||
with patch('pulsar.Client') as mock_client:
|
||||
mock_client.return_value = Mock()
|
||||
|
||||
api = Api()
|
||||
|
||||
# Mock the dependencies
|
||||
api.config_receiver = Mock()
|
||||
api.config_receiver.start = AsyncMock()
|
||||
api.endpoint_manager = Mock()
|
||||
api.endpoint_manager.add_routes = Mock()
|
||||
api.endpoint_manager.start = AsyncMock()
|
||||
|
||||
app = await api.app_factory()
|
||||
|
||||
assert isinstance(app, web.Application)
|
||||
assert app._client_max_size == 256 * 1024 * 1024
|
||||
|
||||
# Verify that config receiver was started
|
||||
api.config_receiver.start.assert_called_once()
|
||||
|
||||
# Verify that endpoint manager was configured
|
||||
api.endpoint_manager.add_routes.assert_called_once_with(app)
|
||||
api.endpoint_manager.start.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_app_factory_with_custom_endpoints(self):
|
||||
"""Test app_factory with custom endpoints"""
|
||||
with patch('pulsar.Client') as mock_client:
|
||||
mock_client.return_value = Mock()
|
||||
|
||||
api = Api()
|
||||
|
||||
# Mock custom endpoints
|
||||
mock_endpoint1 = Mock()
|
||||
mock_endpoint1.add_routes = Mock()
|
||||
mock_endpoint1.start = AsyncMock()
|
||||
|
||||
mock_endpoint2 = Mock()
|
||||
mock_endpoint2.add_routes = Mock()
|
||||
mock_endpoint2.start = AsyncMock()
|
||||
|
||||
api.endpoints = [mock_endpoint1, mock_endpoint2]
|
||||
|
||||
# Mock the dependencies
|
||||
api.config_receiver = Mock()
|
||||
api.config_receiver.start = AsyncMock()
|
||||
api.endpoint_manager = Mock()
|
||||
api.endpoint_manager.add_routes = Mock()
|
||||
api.endpoint_manager.start = AsyncMock()
|
||||
|
||||
app = await api.app_factory()
|
||||
|
||||
# Verify custom endpoints were configured
|
||||
mock_endpoint1.add_routes.assert_called_once_with(app)
|
||||
mock_endpoint1.start.assert_called_once()
|
||||
mock_endpoint2.add_routes.assert_called_once_with(app)
|
||||
mock_endpoint2.start.assert_called_once()
|
||||
|
||||
def test_run_method_calls_web_run_app(self):
|
||||
"""Test that run method calls web.run_app"""
|
||||
with patch('pulsar.Client') as mock_client, \
|
||||
patch('aiohttp.web.run_app') as mock_run_app:
|
||||
mock_client.return_value = Mock()
|
||||
|
||||
api = Api(port=8080)
|
||||
api.run()
|
||||
|
||||
# Verify run_app was called once with the correct port
|
||||
mock_run_app.assert_called_once()
|
||||
args, kwargs = mock_run_app.call_args
|
||||
assert len(args) == 1 # Should have one positional arg (the coroutine)
|
||||
assert kwargs == {'port': 8080} # Should have port keyword arg
|
||||
|
||||
def test_api_components_initialization(self):
|
||||
"""Test that all API components are properly initialized"""
|
||||
with patch('pulsar.Client') as mock_client:
|
||||
mock_client.return_value = Mock()
|
||||
|
||||
api = Api()
|
||||
|
||||
# Verify all components are initialized
|
||||
assert api.config_receiver is not None
|
||||
assert api.dispatcher_manager is not None
|
||||
assert api.endpoint_manager is not None
|
||||
assert api.endpoints == []
|
||||
|
||||
# Verify component relationships
|
||||
assert api.dispatcher_manager.pulsar_client == api.pulsar_client
|
||||
assert api.dispatcher_manager.config_receiver == api.config_receiver
|
||||
assert api.endpoint_manager.dispatcher_manager == api.dispatcher_manager
|
||||
# EndpointManager doesn't store auth directly, it passes it to individual endpoints
|
||||
|
||||
|
||||
class TestRunFunction:
|
||||
"""Test cases for the run() function"""
|
||||
|
||||
def test_run_function_with_metrics_enabled(self):
|
||||
"""Test run function with metrics enabled"""
|
||||
import warnings
|
||||
# Suppress the specific async warning with a broader pattern
|
||||
warnings.filterwarnings("ignore", message=".*Api.app_factory.*was never awaited", category=RuntimeWarning)
|
||||
|
||||
with patch('argparse.ArgumentParser.parse_args') as mock_parse_args, \
|
||||
patch('trustgraph.gateway.service.start_http_server') as mock_start_http_server:
|
||||
|
||||
# Mock command line arguments
|
||||
mock_args = Mock()
|
||||
mock_args.metrics = True
|
||||
mock_args.metrics_port = 8000
|
||||
mock_parse_args.return_value = mock_args
|
||||
|
||||
# Create a simple mock instance without any async methods
|
||||
mock_api_instance = Mock()
|
||||
mock_api_instance.run = Mock()
|
||||
|
||||
# Create a mock Api class without importing the real one
|
||||
mock_api = Mock(return_value=mock_api_instance)
|
||||
|
||||
# Patch using context manager to avoid importing the real Api class
|
||||
with patch('trustgraph.gateway.service.Api', mock_api):
|
||||
# Mock vars() to return a dict
|
||||
with patch('builtins.vars') as mock_vars:
|
||||
mock_vars.return_value = {
|
||||
'metrics': True,
|
||||
'metrics_port': 8000,
|
||||
'pulsar_host': default_pulsar_host,
|
||||
'timeout': default_timeout
|
||||
}
|
||||
|
||||
run()
|
||||
|
||||
# Verify metrics server was started
|
||||
mock_start_http_server.assert_called_once_with(8000)
|
||||
|
||||
# Verify Api was created and run was called
|
||||
mock_api.assert_called_once()
|
||||
mock_api_instance.run.assert_called_once()
|
||||
|
||||
@patch('trustgraph.gateway.service.start_http_server')
|
||||
@patch('argparse.ArgumentParser.parse_args')
|
||||
def test_run_function_with_metrics_disabled(self, mock_parse_args, mock_start_http_server):
|
||||
"""Test run function with metrics disabled"""
|
||||
# Mock command line arguments
|
||||
mock_args = Mock()
|
||||
mock_args.metrics = False
|
||||
mock_parse_args.return_value = mock_args
|
||||
|
||||
# Create a simple mock instance without any async methods
|
||||
mock_api_instance = Mock()
|
||||
mock_api_instance.run = Mock()
|
||||
|
||||
# Patch the Api class inside the test without using decorators
|
||||
with patch('trustgraph.gateway.service.Api') as mock_api:
|
||||
mock_api.return_value = mock_api_instance
|
||||
|
||||
# Mock vars() to return a dict
|
||||
with patch('builtins.vars') as mock_vars:
|
||||
mock_vars.return_value = {
|
||||
'metrics': False,
|
||||
'metrics_port': 8000,
|
||||
'pulsar_host': default_pulsar_host,
|
||||
'timeout': default_timeout
|
||||
}
|
||||
|
||||
run()
|
||||
|
||||
# Verify metrics server was NOT started
|
||||
mock_start_http_server.assert_not_called()
|
||||
|
||||
# Verify Api was created and run was called
|
||||
mock_api.assert_called_once()
|
||||
mock_api_instance.run.assert_called_once()
|
||||
|
||||
@patch('argparse.ArgumentParser.parse_args')
|
||||
def test_run_function_argument_parsing(self, mock_parse_args):
|
||||
"""Test that run function properly parses command line arguments"""
|
||||
# Mock command line arguments
|
||||
mock_args = Mock()
|
||||
mock_args.metrics = False
|
||||
mock_parse_args.return_value = mock_args
|
||||
|
||||
# Create a simple mock instance without any async methods
|
||||
mock_api_instance = Mock()
|
||||
mock_api_instance.run = Mock()
|
||||
|
||||
# Mock vars() to return a dict with all expected arguments
|
||||
expected_args = {
|
||||
'pulsar_host': 'pulsar://test:6650',
|
||||
'pulsar_api_key': 'test-key',
|
||||
'pulsar_listener': 'test-listener',
|
||||
'prometheus_url': 'http://test-prometheus:9090',
|
||||
'port': 9000,
|
||||
'timeout': 300,
|
||||
'api_token': 'secret',
|
||||
'log_level': 'INFO',
|
||||
'metrics': False,
|
||||
'metrics_port': 8001
|
||||
}
|
||||
|
||||
# Patch the Api class inside the test without using decorators
|
||||
with patch('trustgraph.gateway.service.Api') as mock_api:
|
||||
mock_api.return_value = mock_api_instance
|
||||
|
||||
with patch('builtins.vars') as mock_vars:
|
||||
mock_vars.return_value = expected_args
|
||||
|
||||
run()
|
||||
|
||||
# Verify Api was created with the parsed arguments
|
||||
mock_api.assert_called_once_with(**expected_args)
|
||||
mock_api_instance.run.assert_called_once()
|
||||
|
||||
def test_run_function_creates_argument_parser(self):
|
||||
"""Test that run function creates argument parser with correct arguments"""
|
||||
with patch('argparse.ArgumentParser') as mock_parser_class:
|
||||
mock_parser = Mock()
|
||||
mock_parser_class.return_value = mock_parser
|
||||
mock_parser.parse_args.return_value = Mock(metrics=False)
|
||||
|
||||
with patch('trustgraph.gateway.service.Api') as mock_api, \
|
||||
patch('builtins.vars') as mock_vars:
|
||||
mock_vars.return_value = {'metrics': False}
|
||||
mock_api.return_value = Mock()
|
||||
|
||||
run()
|
||||
|
||||
# Verify ArgumentParser was created
|
||||
mock_parser_class.assert_called_once()
|
||||
|
||||
# Verify add_argument was called for each expected argument
|
||||
expected_arguments = [
|
||||
'pulsar-host', 'pulsar-api-key', 'pulsar-listener',
|
||||
'prometheus-url', 'port', 'timeout', 'api-token',
|
||||
'log-level', 'metrics', 'metrics-port'
|
||||
]
|
||||
|
||||
# Check that add_argument was called multiple times (once for each arg)
|
||||
assert mock_parser.add_argument.call_count >= len(expected_arguments)
|
||||
148
tests/unit/test_query/conftest.py
Normal file
148
tests/unit/test_query/conftest.py
Normal file
|
|
@ -0,0 +1,148 @@
|
|||
"""
|
||||
Shared fixtures for query tests
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def base_query_config():
|
||||
"""Base configuration for query processors"""
|
||||
return {
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-query-processor'
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def qdrant_query_config(base_query_config):
|
||||
"""Configuration for Qdrant query processors"""
|
||||
return base_query_config | {
|
||||
'store_uri': 'http://localhost:6333',
|
||||
'api_key': 'test-api-key'
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_qdrant_client():
|
||||
"""Mock Qdrant client"""
|
||||
mock_client = MagicMock()
|
||||
mock_client.query_points.return_value = []
|
||||
return mock_client
|
||||
|
||||
|
||||
# Graph embeddings query fixtures
|
||||
@pytest.fixture
|
||||
def mock_graph_embeddings_request():
|
||||
"""Mock graph embeddings request message"""
|
||||
mock_message = MagicMock()
|
||||
mock_message.vectors = [[0.1, 0.2, 0.3]]
|
||||
mock_message.limit = 5
|
||||
mock_message.user = 'test_user'
|
||||
mock_message.collection = 'test_collection'
|
||||
return mock_message
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_graph_embeddings_multiple_vectors():
|
||||
"""Mock graph embeddings request with multiple vectors"""
|
||||
mock_message = MagicMock()
|
||||
mock_message.vectors = [[0.1, 0.2], [0.3, 0.4]]
|
||||
mock_message.limit = 3
|
||||
mock_message.user = 'multi_user'
|
||||
mock_message.collection = 'multi_collection'
|
||||
return mock_message
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_graph_embeddings_query_response():
|
||||
"""Mock graph embeddings query response from Qdrant"""
|
||||
mock_point1 = MagicMock()
|
||||
mock_point1.payload = {'entity': 'entity1'}
|
||||
mock_point2 = MagicMock()
|
||||
mock_point2.payload = {'entity': 'entity2'}
|
||||
return [mock_point1, mock_point2]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_graph_embeddings_uri_response():
|
||||
"""Mock graph embeddings query response with URIs"""
|
||||
mock_point1 = MagicMock()
|
||||
mock_point1.payload = {'entity': 'http://example.com/entity1'}
|
||||
mock_point2 = MagicMock()
|
||||
mock_point2.payload = {'entity': 'https://secure.example.com/entity2'}
|
||||
mock_point3 = MagicMock()
|
||||
mock_point3.payload = {'entity': 'regular entity'}
|
||||
return [mock_point1, mock_point2, mock_point3]
|
||||
|
||||
|
||||
# Document embeddings query fixtures
|
||||
@pytest.fixture
|
||||
def mock_document_embeddings_request():
|
||||
"""Mock document embeddings request message"""
|
||||
mock_message = MagicMock()
|
||||
mock_message.vectors = [[0.1, 0.2, 0.3]]
|
||||
mock_message.limit = 5
|
||||
mock_message.user = 'test_user'
|
||||
mock_message.collection = 'test_collection'
|
||||
return mock_message
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_document_embeddings_multiple_vectors():
|
||||
"""Mock document embeddings request with multiple vectors"""
|
||||
mock_message = MagicMock()
|
||||
mock_message.vectors = [[0.1, 0.2], [0.3, 0.4]]
|
||||
mock_message.limit = 3
|
||||
mock_message.user = 'multi_user'
|
||||
mock_message.collection = 'multi_collection'
|
||||
return mock_message
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_document_embeddings_query_response():
|
||||
"""Mock document embeddings query response from Qdrant"""
|
||||
mock_point1 = MagicMock()
|
||||
mock_point1.payload = {'doc': 'first document chunk'}
|
||||
mock_point2 = MagicMock()
|
||||
mock_point2.payload = {'doc': 'second document chunk'}
|
||||
return [mock_point1, mock_point2]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_document_embeddings_utf8_response():
|
||||
"""Mock document embeddings query response with UTF-8 content"""
|
||||
mock_point1 = MagicMock()
|
||||
mock_point1.payload = {'doc': 'Document with UTF-8: café, naïve, résumé'}
|
||||
mock_point2 = MagicMock()
|
||||
mock_point2.payload = {'doc': 'Chinese text: 你好世界'}
|
||||
return [mock_point1, mock_point2]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_empty_query_response():
|
||||
"""Mock empty query response"""
|
||||
return []
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_large_query_response():
|
||||
"""Mock large query response with many results"""
|
||||
mock_points = []
|
||||
for i in range(10):
|
||||
mock_point = MagicMock()
|
||||
mock_point.payload = {'doc': f'document chunk {i}'}
|
||||
mock_points.append(mock_point)
|
||||
return mock_points
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_mixed_dimension_vectors():
|
||||
"""Mock request with vectors of different dimensions"""
|
||||
mock_message = MagicMock()
|
||||
mock_message.vectors = [[0.1, 0.2], [0.3, 0.4, 0.5]] # 2D and 3D
|
||||
mock_message.limit = 5
|
||||
mock_message.user = 'dim_user'
|
||||
mock_message.collection = 'dim_collection'
|
||||
return mock_message
|
||||
542
tests/unit/test_query/test_doc_embeddings_qdrant_query.py
Normal file
542
tests/unit/test_query/test_doc_embeddings_qdrant_query.py
Normal file
|
|
@ -0,0 +1,542 @@
|
|||
"""
|
||||
Unit tests for trustgraph.query.doc_embeddings.qdrant.service
|
||||
Testing document embeddings query functionality
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from unittest import IsolatedAsyncioTestCase
|
||||
|
||||
# Import the service under test
|
||||
from trustgraph.query.doc_embeddings.qdrant.service import Processor
|
||||
|
||||
|
||||
class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
|
||||
"""Test Qdrant document embeddings query functionality"""
|
||||
|
||||
@patch('trustgraph.query.doc_embeddings.qdrant.service.QdrantClient')
|
||||
@patch('trustgraph.base.DocumentEmbeddingsQueryService.__init__')
|
||||
async def test_processor_initialization_basic(self, mock_base_init, mock_qdrant_client):
|
||||
"""Test basic Qdrant processor initialization"""
|
||||
# Arrange
|
||||
mock_base_init.return_value = None
|
||||
mock_qdrant_instance = MagicMock()
|
||||
mock_qdrant_client.return_value = mock_qdrant_instance
|
||||
|
||||
config = {
|
||||
'store_uri': 'http://localhost:6333',
|
||||
'api_key': 'test-api-key',
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-doc-query-processor'
|
||||
}
|
||||
|
||||
# Act
|
||||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
# Verify base class initialization was called
|
||||
mock_base_init.assert_called_once()
|
||||
|
||||
# Verify QdrantClient was created with correct parameters
|
||||
mock_qdrant_client.assert_called_once_with(url='http://localhost:6333', api_key='test-api-key')
|
||||
|
||||
# Verify processor attributes
|
||||
assert hasattr(processor, 'qdrant')
|
||||
assert processor.qdrant == mock_qdrant_instance
|
||||
|
||||
@patch('trustgraph.query.doc_embeddings.qdrant.service.QdrantClient')
|
||||
@patch('trustgraph.base.DocumentEmbeddingsQueryService.__init__')
|
||||
async def test_processor_initialization_with_defaults(self, mock_base_init, mock_qdrant_client):
|
||||
"""Test processor initialization with default values"""
|
||||
# Arrange
|
||||
mock_base_init.return_value = None
|
||||
mock_qdrant_instance = MagicMock()
|
||||
mock_qdrant_client.return_value = mock_qdrant_instance
|
||||
|
||||
config = {
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-doc-query-processor'
|
||||
# No store_uri or api_key provided - should use defaults
|
||||
}
|
||||
|
||||
# Act
|
||||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
# Verify QdrantClient was created with default URI and None API key
|
||||
mock_qdrant_client.assert_called_once_with(url='http://localhost:6333', api_key=None)
|
||||
|
||||
@patch('trustgraph.query.doc_embeddings.qdrant.service.QdrantClient')
|
||||
@patch('trustgraph.base.DocumentEmbeddingsQueryService.__init__')
|
||||
async def test_query_document_embeddings_single_vector(self, mock_base_init, mock_qdrant_client):
|
||||
"""Test querying document embeddings with single vector"""
|
||||
# Arrange
|
||||
mock_base_init.return_value = None
|
||||
mock_qdrant_instance = MagicMock()
|
||||
mock_qdrant_client.return_value = mock_qdrant_instance
|
||||
|
||||
# Mock query response
|
||||
mock_point1 = MagicMock()
|
||||
mock_point1.payload = {'doc': 'first document chunk'}
|
||||
mock_point2 = MagicMock()
|
||||
mock_point2.payload = {'doc': 'second document chunk'}
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.points = [mock_point1, mock_point2]
|
||||
mock_qdrant_instance.query_points.return_value = mock_response
|
||||
|
||||
config = {
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Create mock message
|
||||
mock_message = MagicMock()
|
||||
mock_message.vectors = [[0.1, 0.2, 0.3]]
|
||||
mock_message.limit = 5
|
||||
mock_message.user = 'test_user'
|
||||
mock_message.collection = 'test_collection'
|
||||
|
||||
# Act
|
||||
result = await processor.query_document_embeddings(mock_message)
|
||||
|
||||
# Assert
|
||||
# Verify query was called with correct parameters
|
||||
expected_collection = 'd_test_user_test_collection_3'
|
||||
mock_qdrant_instance.query_points.assert_called_once_with(
|
||||
collection_name=expected_collection,
|
||||
query=[0.1, 0.2, 0.3],
|
||||
limit=5, # Direct limit, no multiplication
|
||||
with_payload=True
|
||||
)
|
||||
|
||||
# Verify result contains expected documents
|
||||
assert len(result) == 2
|
||||
# Results should be strings (document chunks)
|
||||
assert isinstance(result[0], str)
|
||||
assert isinstance(result[1], str)
|
||||
# Verify content
|
||||
assert result[0] == 'first document chunk'
|
||||
assert result[1] == 'second document chunk'
|
||||
|
||||
@patch('trustgraph.query.doc_embeddings.qdrant.service.QdrantClient')
|
||||
@patch('trustgraph.base.DocumentEmbeddingsQueryService.__init__')
|
||||
async def test_query_document_embeddings_multiple_vectors(self, mock_base_init, mock_qdrant_client):
|
||||
"""Test querying document embeddings with multiple vectors"""
|
||||
# Arrange
|
||||
mock_base_init.return_value = None
|
||||
mock_qdrant_instance = MagicMock()
|
||||
mock_qdrant_client.return_value = mock_qdrant_instance
|
||||
|
||||
# Mock query responses for different vectors
|
||||
mock_point1 = MagicMock()
|
||||
mock_point1.payload = {'doc': 'document from vector 1'}
|
||||
mock_point2 = MagicMock()
|
||||
mock_point2.payload = {'doc': 'document from vector 2'}
|
||||
mock_point3 = MagicMock()
|
||||
mock_point3.payload = {'doc': 'another document from vector 2'}
|
||||
|
||||
mock_response1 = MagicMock()
|
||||
mock_response1.points = [mock_point1]
|
||||
mock_response2 = MagicMock()
|
||||
mock_response2.points = [mock_point2, mock_point3]
|
||||
mock_qdrant_instance.query_points.side_effect = [mock_response1, mock_response2]
|
||||
|
||||
config = {
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Create mock message with multiple vectors
|
||||
mock_message = MagicMock()
|
||||
mock_message.vectors = [[0.1, 0.2], [0.3, 0.4]]
|
||||
mock_message.limit = 3
|
||||
mock_message.user = 'multi_user'
|
||||
mock_message.collection = 'multi_collection'
|
||||
|
||||
# Act
|
||||
result = await processor.query_document_embeddings(mock_message)
|
||||
|
||||
# Assert
|
||||
# Verify query was called twice
|
||||
assert mock_qdrant_instance.query_points.call_count == 2
|
||||
|
||||
# Verify both collections were queried
|
||||
expected_collection = 'd_multi_user_multi_collection_2'
|
||||
calls = mock_qdrant_instance.query_points.call_args_list
|
||||
assert calls[0][1]['collection_name'] == expected_collection
|
||||
assert calls[1][1]['collection_name'] == expected_collection
|
||||
assert calls[0][1]['query'] == [0.1, 0.2]
|
||||
assert calls[1][1]['query'] == [0.3, 0.4]
|
||||
|
||||
# Verify results from both vectors are combined
|
||||
assert len(result) == 3
|
||||
assert 'document from vector 1' in result
|
||||
assert 'document from vector 2' in result
|
||||
assert 'another document from vector 2' in result
|
||||
|
||||
@patch('trustgraph.query.doc_embeddings.qdrant.service.QdrantClient')
|
||||
@patch('trustgraph.base.DocumentEmbeddingsQueryService.__init__')
|
||||
async def test_query_document_embeddings_with_limit(self, mock_base_init, mock_qdrant_client):
|
||||
"""Test querying document embeddings respects limit parameter"""
|
||||
# Arrange
|
||||
mock_base_init.return_value = None
|
||||
mock_qdrant_instance = MagicMock()
|
||||
mock_qdrant_client.return_value = mock_qdrant_instance
|
||||
|
||||
# Mock query response with many results
|
||||
mock_points = []
|
||||
for i in range(10):
|
||||
mock_point = MagicMock()
|
||||
mock_point.payload = {'doc': f'document chunk {i}'}
|
||||
mock_points.append(mock_point)
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.points = mock_points
|
||||
mock_qdrant_instance.query_points.return_value = mock_response
|
||||
|
||||
config = {
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Create mock message with limit
|
||||
mock_message = MagicMock()
|
||||
mock_message.vectors = [[0.1, 0.2, 0.3]]
|
||||
mock_message.limit = 3 # Should only return 3 results
|
||||
mock_message.user = 'limit_user'
|
||||
mock_message.collection = 'limit_collection'
|
||||
|
||||
# Act
|
||||
result = await processor.query_document_embeddings(mock_message)
|
||||
|
||||
# Assert
|
||||
# Verify query was called with exact limit (no multiplication)
|
||||
mock_qdrant_instance.query_points.assert_called_once()
|
||||
call_args = mock_qdrant_instance.query_points.call_args
|
||||
assert call_args[1]['limit'] == 3 # Direct limit
|
||||
|
||||
# Verify result contains all returned documents (limit applied by Qdrant)
|
||||
assert len(result) == 10 # All results returned by mock
|
||||
|
||||
@patch('trustgraph.query.doc_embeddings.qdrant.service.QdrantClient')
|
||||
@patch('trustgraph.base.DocumentEmbeddingsQueryService.__init__')
|
||||
async def test_query_document_embeddings_empty_results(self, mock_base_init, mock_qdrant_client):
|
||||
"""Test querying document embeddings with empty results"""
|
||||
# Arrange
|
||||
mock_base_init.return_value = None
|
||||
mock_qdrant_instance = MagicMock()
|
||||
mock_qdrant_client.return_value = mock_qdrant_instance
|
||||
|
||||
# Mock empty query response
|
||||
mock_response = MagicMock()
|
||||
mock_response.points = []
|
||||
mock_qdrant_instance.query_points.return_value = mock_response
|
||||
|
||||
config = {
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Create mock message
|
||||
mock_message = MagicMock()
|
||||
mock_message.vectors = [[0.1, 0.2]]
|
||||
mock_message.limit = 5
|
||||
mock_message.user = 'empty_user'
|
||||
mock_message.collection = 'empty_collection'
|
||||
|
||||
# Act
|
||||
result = await processor.query_document_embeddings(mock_message)
|
||||
|
||||
# Assert
|
||||
assert result == []
|
||||
|
||||
@patch('trustgraph.query.doc_embeddings.qdrant.service.QdrantClient')
|
||||
@patch('trustgraph.base.DocumentEmbeddingsQueryService.__init__')
|
||||
async def test_query_document_embeddings_different_dimensions(self, mock_base_init, mock_qdrant_client):
|
||||
"""Test querying document embeddings with different vector dimensions"""
|
||||
# Arrange
|
||||
mock_base_init.return_value = None
|
||||
mock_qdrant_instance = MagicMock()
|
||||
mock_qdrant_client.return_value = mock_qdrant_instance
|
||||
|
||||
# Mock query responses
|
||||
mock_point1 = MagicMock()
|
||||
mock_point1.payload = {'doc': 'document from 2D vector'}
|
||||
mock_point2 = MagicMock()
|
||||
mock_point2.payload = {'doc': 'document from 3D vector'}
|
||||
|
||||
mock_response1 = MagicMock()
|
||||
mock_response1.points = [mock_point1]
|
||||
mock_response2 = MagicMock()
|
||||
mock_response2.points = [mock_point2]
|
||||
mock_qdrant_instance.query_points.side_effect = [mock_response1, mock_response2]
|
||||
|
||||
config = {
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Create mock message with different dimension vectors
|
||||
mock_message = MagicMock()
|
||||
mock_message.vectors = [[0.1, 0.2], [0.3, 0.4, 0.5]] # 2D and 3D
|
||||
mock_message.limit = 5
|
||||
mock_message.user = 'dim_user'
|
||||
mock_message.collection = 'dim_collection'
|
||||
|
||||
# Act
|
||||
result = await processor.query_document_embeddings(mock_message)
|
||||
|
||||
# Assert
|
||||
# Verify query was called twice with different collections
|
||||
assert mock_qdrant_instance.query_points.call_count == 2
|
||||
calls = mock_qdrant_instance.query_points.call_args_list
|
||||
|
||||
# First call should use 2D collection
|
||||
assert calls[0][1]['collection_name'] == 'd_dim_user_dim_collection_2'
|
||||
assert calls[0][1]['query'] == [0.1, 0.2]
|
||||
|
||||
# Second call should use 3D collection
|
||||
assert calls[1][1]['collection_name'] == 'd_dim_user_dim_collection_3'
|
||||
assert calls[1][1]['query'] == [0.3, 0.4, 0.5]
|
||||
|
||||
# Verify results
|
||||
assert len(result) == 2
|
||||
assert 'document from 2D vector' in result
|
||||
assert 'document from 3D vector' in result
|
||||
|
||||
@patch('trustgraph.query.doc_embeddings.qdrant.service.QdrantClient')
|
||||
@patch('trustgraph.base.DocumentEmbeddingsQueryService.__init__')
|
||||
async def test_query_document_embeddings_utf8_encoding(self, mock_base_init, mock_qdrant_client):
|
||||
"""Test querying document embeddings with UTF-8 content"""
|
||||
# Arrange
|
||||
mock_base_init.return_value = None
|
||||
mock_qdrant_instance = MagicMock()
|
||||
mock_qdrant_client.return_value = mock_qdrant_instance
|
||||
|
||||
# Mock query response with UTF-8 content
|
||||
mock_point1 = MagicMock()
|
||||
mock_point1.payload = {'doc': 'Document with UTF-8: café, naïve, résumé'}
|
||||
mock_point2 = MagicMock()
|
||||
mock_point2.payload = {'doc': 'Chinese text: 你好世界'}
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.points = [mock_point1, mock_point2]
|
||||
mock_qdrant_instance.query_points.return_value = mock_response
|
||||
|
||||
config = {
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Create mock message
|
||||
mock_message = MagicMock()
|
||||
mock_message.vectors = [[0.1, 0.2]]
|
||||
mock_message.limit = 5
|
||||
mock_message.user = 'utf8_user'
|
||||
mock_message.collection = 'utf8_collection'
|
||||
|
||||
# Act
|
||||
result = await processor.query_document_embeddings(mock_message)
|
||||
|
||||
# Assert
|
||||
assert len(result) == 2
|
||||
|
||||
# Verify UTF-8 content works correctly
|
||||
assert 'Document with UTF-8: café, naïve, résumé' in result
|
||||
assert 'Chinese text: 你好世界' in result
|
||||
|
||||
@patch('trustgraph.query.doc_embeddings.qdrant.service.QdrantClient')
|
||||
@patch('trustgraph.base.DocumentEmbeddingsQueryService.__init__')
|
||||
async def test_query_document_embeddings_qdrant_error(self, mock_base_init, mock_qdrant_client):
|
||||
"""Test querying document embeddings handles Qdrant errors"""
|
||||
# Arrange
|
||||
mock_base_init.return_value = None
|
||||
mock_qdrant_instance = MagicMock()
|
||||
mock_qdrant_client.return_value = mock_qdrant_instance
|
||||
|
||||
# Mock Qdrant error
|
||||
mock_qdrant_instance.query_points.side_effect = Exception("Qdrant connection failed")
|
||||
|
||||
config = {
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Create mock message
|
||||
mock_message = MagicMock()
|
||||
mock_message.vectors = [[0.1, 0.2]]
|
||||
mock_message.limit = 5
|
||||
mock_message.user = 'error_user'
|
||||
mock_message.collection = 'error_collection'
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(Exception, match="Qdrant connection failed"):
|
||||
await processor.query_document_embeddings(mock_message)
|
||||
|
||||
@patch('trustgraph.query.doc_embeddings.qdrant.service.QdrantClient')
|
||||
@patch('trustgraph.base.DocumentEmbeddingsQueryService.__init__')
|
||||
async def test_query_document_embeddings_zero_limit(self, mock_base_init, mock_qdrant_client):
|
||||
"""Test querying document embeddings with zero limit"""
|
||||
# Arrange
|
||||
mock_base_init.return_value = None
|
||||
mock_qdrant_instance = MagicMock()
|
||||
mock_qdrant_client.return_value = mock_qdrant_instance
|
||||
|
||||
# Mock query response
|
||||
mock_point = MagicMock()
|
||||
mock_point.payload = {'doc': 'document chunk'}
|
||||
mock_response = MagicMock()
|
||||
mock_response.points = [mock_point]
|
||||
mock_qdrant_instance.query_points.return_value = mock_response
|
||||
|
||||
config = {
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Create mock message with zero limit
|
||||
mock_message = MagicMock()
|
||||
mock_message.vectors = [[0.1, 0.2]]
|
||||
mock_message.limit = 0
|
||||
mock_message.user = 'zero_user'
|
||||
mock_message.collection = 'zero_collection'
|
||||
|
||||
# Act
|
||||
result = await processor.query_document_embeddings(mock_message)
|
||||
|
||||
# Assert
|
||||
# Should still query (with limit 0)
|
||||
mock_qdrant_instance.query_points.assert_called_once()
|
||||
call_args = mock_qdrant_instance.query_points.call_args
|
||||
assert call_args[1]['limit'] == 0
|
||||
|
||||
# Result should contain all returned documents
|
||||
assert len(result) == 1
|
||||
assert result[0] == 'document chunk'
|
||||
|
||||
@patch('trustgraph.query.doc_embeddings.qdrant.service.QdrantClient')
|
||||
@patch('trustgraph.base.DocumentEmbeddingsQueryService.__init__')
|
||||
async def test_query_document_embeddings_large_limit(self, mock_base_init, mock_qdrant_client):
|
||||
"""Test querying document embeddings with large limit"""
|
||||
# Arrange
|
||||
mock_base_init.return_value = None
|
||||
mock_qdrant_instance = MagicMock()
|
||||
mock_qdrant_client.return_value = mock_qdrant_instance
|
||||
|
||||
# Mock query response with fewer results than limit
|
||||
mock_point1 = MagicMock()
|
||||
mock_point1.payload = {'doc': 'document 1'}
|
||||
mock_point2 = MagicMock()
|
||||
mock_point2.payload = {'doc': 'document 2'}
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.points = [mock_point1, mock_point2]
|
||||
mock_qdrant_instance.query_points.return_value = mock_response
|
||||
|
||||
config = {
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Create mock message with large limit
|
||||
mock_message = MagicMock()
|
||||
mock_message.vectors = [[0.1, 0.2]]
|
||||
mock_message.limit = 1000 # Large limit
|
||||
mock_message.user = 'large_user'
|
||||
mock_message.collection = 'large_collection'
|
||||
|
||||
# Act
|
||||
result = await processor.query_document_embeddings(mock_message)
|
||||
|
||||
# Assert
|
||||
# Should query with full limit
|
||||
mock_qdrant_instance.query_points.assert_called_once()
|
||||
call_args = mock_qdrant_instance.query_points.call_args
|
||||
assert call_args[1]['limit'] == 1000
|
||||
|
||||
# Result should contain all available documents
|
||||
assert len(result) == 2
|
||||
assert 'document 1' in result
|
||||
assert 'document 2' in result
|
||||
|
||||
@patch('trustgraph.query.doc_embeddings.qdrant.service.QdrantClient')
|
||||
@patch('trustgraph.base.DocumentEmbeddingsQueryService.__init__')
|
||||
async def test_query_document_embeddings_missing_payload(self, mock_base_init, mock_qdrant_client):
|
||||
"""Test querying document embeddings with missing payload data"""
|
||||
# Arrange
|
||||
mock_base_init.return_value = None
|
||||
mock_qdrant_instance = MagicMock()
|
||||
mock_qdrant_client.return_value = mock_qdrant_instance
|
||||
|
||||
# Mock query response with missing 'doc' key
|
||||
mock_point1 = MagicMock()
|
||||
mock_point1.payload = {'doc': 'valid document'}
|
||||
mock_point2 = MagicMock()
|
||||
mock_point2.payload = {} # Missing 'doc' key
|
||||
mock_point3 = MagicMock()
|
||||
mock_point3.payload = {'other_key': 'invalid'} # Wrong key
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.points = [mock_point1, mock_point2, mock_point3]
|
||||
mock_qdrant_instance.query_points.return_value = mock_response
|
||||
|
||||
config = {
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Create mock message
|
||||
mock_message = MagicMock()
|
||||
mock_message.vectors = [[0.1, 0.2]]
|
||||
mock_message.limit = 5
|
||||
mock_message.user = 'payload_user'
|
||||
mock_message.collection = 'payload_collection'
|
||||
|
||||
# Act & Assert
|
||||
# This should raise a KeyError when trying to access payload['doc']
|
||||
with pytest.raises(KeyError):
|
||||
await processor.query_document_embeddings(mock_message)
|
||||
|
||||
@patch('trustgraph.query.doc_embeddings.qdrant.service.QdrantClient')
|
||||
@patch('trustgraph.base.DocumentEmbeddingsQueryService.__init__')
|
||||
async def test_add_args_calls_parent(self, mock_base_init, mock_qdrant_client):
|
||||
"""Test that add_args() calls parent add_args method"""
|
||||
# Arrange
|
||||
mock_base_init.return_value = None
|
||||
mock_qdrant_client.return_value = MagicMock()
|
||||
mock_parser = MagicMock()
|
||||
|
||||
# Act
|
||||
with patch('trustgraph.base.DocumentEmbeddingsQueryService.add_args') as mock_parent_add_args:
|
||||
Processor.add_args(mock_parser)
|
||||
|
||||
# Assert
|
||||
mock_parent_add_args.assert_called_once_with(mock_parser)
|
||||
|
||||
# Verify processor-specific arguments were added
|
||||
assert mock_parser.add_argument.call_count >= 2 # At least store-uri and api-key
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pytest.main([__file__])
|
||||
537
tests/unit/test_query/test_graph_embeddings_qdrant_query.py
Normal file
537
tests/unit/test_query/test_graph_embeddings_qdrant_query.py
Normal file
|
|
@ -0,0 +1,537 @@
|
|||
"""
|
||||
Unit tests for trustgraph.query.graph_embeddings.qdrant.service
|
||||
Testing graph embeddings query functionality
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from unittest import IsolatedAsyncioTestCase
|
||||
|
||||
# Import the service under test
|
||||
from trustgraph.query.graph_embeddings.qdrant.service import Processor
|
||||
|
||||
|
||||
class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
|
||||
"""Test Qdrant graph embeddings query functionality"""
|
||||
|
||||
@patch('trustgraph.query.graph_embeddings.qdrant.service.QdrantClient')
|
||||
@patch('trustgraph.base.GraphEmbeddingsQueryService.__init__')
|
||||
async def test_processor_initialization_basic(self, mock_base_init, mock_qdrant_client):
|
||||
"""Test basic Qdrant processor initialization"""
|
||||
# Arrange
|
||||
mock_base_init.return_value = None
|
||||
mock_qdrant_instance = MagicMock()
|
||||
mock_qdrant_client.return_value = mock_qdrant_instance
|
||||
|
||||
config = {
|
||||
'store_uri': 'http://localhost:6333',
|
||||
'api_key': 'test-api-key',
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-graph-query-processor'
|
||||
}
|
||||
|
||||
# Act
|
||||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
# Verify base class initialization was called
|
||||
mock_base_init.assert_called_once()
|
||||
|
||||
# Verify QdrantClient was created with correct parameters
|
||||
mock_qdrant_client.assert_called_once_with(url='http://localhost:6333', api_key='test-api-key')
|
||||
|
||||
# Verify processor attributes
|
||||
assert hasattr(processor, 'qdrant')
|
||||
assert processor.qdrant == mock_qdrant_instance
|
||||
|
||||
@patch('trustgraph.query.graph_embeddings.qdrant.service.QdrantClient')
|
||||
@patch('trustgraph.base.GraphEmbeddingsQueryService.__init__')
|
||||
async def test_processor_initialization_with_defaults(self, mock_base_init, mock_qdrant_client):
|
||||
"""Test processor initialization with default values"""
|
||||
# Arrange
|
||||
mock_base_init.return_value = None
|
||||
mock_qdrant_instance = MagicMock()
|
||||
mock_qdrant_client.return_value = mock_qdrant_instance
|
||||
|
||||
config = {
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-graph-query-processor'
|
||||
# No store_uri or api_key provided - should use defaults
|
||||
}
|
||||
|
||||
# Act
|
||||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
# Verify QdrantClient was created with default URI and None API key
|
||||
mock_qdrant_client.assert_called_once_with(url='http://localhost:6333', api_key=None)
|
||||
|
||||
@patch('trustgraph.query.graph_embeddings.qdrant.service.QdrantClient')
|
||||
@patch('trustgraph.base.GraphEmbeddingsQueryService.__init__')
|
||||
async def test_create_value_http_uri(self, mock_base_init, mock_qdrant_client):
|
||||
"""Test create_value with HTTP URI"""
|
||||
# Arrange
|
||||
mock_base_init.return_value = None
|
||||
mock_qdrant_client.return_value = MagicMock()
|
||||
|
||||
config = {
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act
|
||||
value = processor.create_value('http://example.com/entity')
|
||||
|
||||
# Assert
|
||||
assert hasattr(value, 'value')
|
||||
assert value.value == 'http://example.com/entity'
|
||||
assert hasattr(value, 'is_uri')
|
||||
assert value.is_uri == True
|
||||
|
||||
@patch('trustgraph.query.graph_embeddings.qdrant.service.QdrantClient')
|
||||
@patch('trustgraph.base.GraphEmbeddingsQueryService.__init__')
|
||||
async def test_create_value_https_uri(self, mock_base_init, mock_qdrant_client):
|
||||
"""Test create_value with HTTPS URI"""
|
||||
# Arrange
|
||||
mock_base_init.return_value = None
|
||||
mock_qdrant_client.return_value = MagicMock()
|
||||
|
||||
config = {
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act
|
||||
value = processor.create_value('https://secure.example.com/entity')
|
||||
|
||||
# Assert
|
||||
assert hasattr(value, 'value')
|
||||
assert value.value == 'https://secure.example.com/entity'
|
||||
assert hasattr(value, 'is_uri')
|
||||
assert value.is_uri == True
|
||||
|
||||
@patch('trustgraph.query.graph_embeddings.qdrant.service.QdrantClient')
|
||||
@patch('trustgraph.base.GraphEmbeddingsQueryService.__init__')
|
||||
async def test_create_value_regular_string(self, mock_base_init, mock_qdrant_client):
|
||||
"""Test create_value with regular string (non-URI)"""
|
||||
# Arrange
|
||||
mock_base_init.return_value = None
|
||||
mock_qdrant_client.return_value = MagicMock()
|
||||
|
||||
config = {
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act
|
||||
value = processor.create_value('regular entity name')
|
||||
|
||||
# Assert
|
||||
assert hasattr(value, 'value')
|
||||
assert value.value == 'regular entity name'
|
||||
assert hasattr(value, 'is_uri')
|
||||
assert value.is_uri == False
|
||||
|
||||
@patch('trustgraph.query.graph_embeddings.qdrant.service.QdrantClient')
|
||||
@patch('trustgraph.base.GraphEmbeddingsQueryService.__init__')
|
||||
async def test_query_graph_embeddings_single_vector(self, mock_base_init, mock_qdrant_client):
|
||||
"""Test querying graph embeddings with single vector"""
|
||||
# Arrange
|
||||
mock_base_init.return_value = None
|
||||
mock_qdrant_instance = MagicMock()
|
||||
mock_qdrant_client.return_value = mock_qdrant_instance
|
||||
|
||||
# Mock query response
|
||||
mock_point1 = MagicMock()
|
||||
mock_point1.payload = {'entity': 'entity1'}
|
||||
mock_point2 = MagicMock()
|
||||
mock_point2.payload = {'entity': 'entity2'}
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.points = [mock_point1, mock_point2]
|
||||
mock_qdrant_instance.query_points.return_value = mock_response
|
||||
|
||||
config = {
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Create mock message
|
||||
mock_message = MagicMock()
|
||||
mock_message.vectors = [[0.1, 0.2, 0.3]]
|
||||
mock_message.limit = 5
|
||||
mock_message.user = 'test_user'
|
||||
mock_message.collection = 'test_collection'
|
||||
|
||||
# Act
|
||||
result = await processor.query_graph_embeddings(mock_message)
|
||||
|
||||
# Assert
|
||||
# Verify query was called with correct parameters
|
||||
expected_collection = 't_test_user_test_collection_3'
|
||||
mock_qdrant_instance.query_points.assert_called_once_with(
|
||||
collection_name=expected_collection,
|
||||
query=[0.1, 0.2, 0.3],
|
||||
limit=10, # limit * 2 for deduplication
|
||||
with_payload=True
|
||||
)
|
||||
|
||||
# Verify result contains expected entities
|
||||
assert len(result) == 2
|
||||
assert all(hasattr(entity, 'value') for entity in result)
|
||||
entity_values = [entity.value for entity in result]
|
||||
assert 'entity1' in entity_values
|
||||
assert 'entity2' in entity_values
|
||||
|
||||
@patch('trustgraph.query.graph_embeddings.qdrant.service.QdrantClient')
|
||||
@patch('trustgraph.base.GraphEmbeddingsQueryService.__init__')
|
||||
async def test_query_graph_embeddings_multiple_vectors(self, mock_base_init, mock_qdrant_client):
|
||||
"""Test querying graph embeddings with multiple vectors"""
|
||||
# Arrange
|
||||
mock_base_init.return_value = None
|
||||
mock_qdrant_instance = MagicMock()
|
||||
mock_qdrant_client.return_value = mock_qdrant_instance
|
||||
|
||||
# Mock query responses for different vectors
|
||||
mock_point1 = MagicMock()
|
||||
mock_point1.payload = {'entity': 'entity1'}
|
||||
mock_point2 = MagicMock()
|
||||
mock_point2.payload = {'entity': 'entity2'}
|
||||
mock_point3 = MagicMock()
|
||||
mock_point3.payload = {'entity': 'entity3'}
|
||||
|
||||
mock_response1 = MagicMock()
|
||||
mock_response1.points = [mock_point1, mock_point2]
|
||||
mock_response2 = MagicMock()
|
||||
mock_response2.points = [mock_point2, mock_point3]
|
||||
mock_qdrant_instance.query_points.side_effect = [mock_response1, mock_response2]
|
||||
|
||||
config = {
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Create mock message with multiple vectors
|
||||
mock_message = MagicMock()
|
||||
mock_message.vectors = [[0.1, 0.2], [0.3, 0.4]]
|
||||
mock_message.limit = 3
|
||||
mock_message.user = 'multi_user'
|
||||
mock_message.collection = 'multi_collection'
|
||||
|
||||
# Act
|
||||
result = await processor.query_graph_embeddings(mock_message)
|
||||
|
||||
# Assert
|
||||
# Verify query was called twice
|
||||
assert mock_qdrant_instance.query_points.call_count == 2
|
||||
|
||||
# Verify both collections were queried
|
||||
expected_collection = 't_multi_user_multi_collection_2'
|
||||
calls = mock_qdrant_instance.query_points.call_args_list
|
||||
assert calls[0][1]['collection_name'] == expected_collection
|
||||
assert calls[1][1]['collection_name'] == expected_collection
|
||||
assert calls[0][1]['query'] == [0.1, 0.2]
|
||||
assert calls[1][1]['query'] == [0.3, 0.4]
|
||||
|
||||
# Verify deduplication - entity2 appears in both results but should only appear once
|
||||
entity_values = [entity.value for entity in result]
|
||||
assert len(set(entity_values)) == len(entity_values) # All unique
|
||||
assert 'entity1' in entity_values
|
||||
assert 'entity2' in entity_values
|
||||
assert 'entity3' in entity_values
|
||||
|
||||
@patch('trustgraph.query.graph_embeddings.qdrant.service.QdrantClient')
|
||||
@patch('trustgraph.base.GraphEmbeddingsQueryService.__init__')
|
||||
async def test_query_graph_embeddings_with_limit(self, mock_base_init, mock_qdrant_client):
|
||||
"""Test querying graph embeddings respects limit parameter"""
|
||||
# Arrange
|
||||
mock_base_init.return_value = None
|
||||
mock_qdrant_instance = MagicMock()
|
||||
mock_qdrant_client.return_value = mock_qdrant_instance
|
||||
|
||||
# Mock query response with more results than limit
|
||||
mock_points = []
|
||||
for i in range(10):
|
||||
mock_point = MagicMock()
|
||||
mock_point.payload = {'entity': f'entity{i}'}
|
||||
mock_points.append(mock_point)
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.points = mock_points
|
||||
mock_qdrant_instance.query_points.return_value = mock_response
|
||||
|
||||
config = {
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Create mock message with limit
|
||||
mock_message = MagicMock()
|
||||
mock_message.vectors = [[0.1, 0.2, 0.3]]
|
||||
mock_message.limit = 3 # Should only return 3 results
|
||||
mock_message.user = 'limit_user'
|
||||
mock_message.collection = 'limit_collection'
|
||||
|
||||
# Act
|
||||
result = await processor.query_graph_embeddings(mock_message)
|
||||
|
||||
# Assert
|
||||
# Verify query was called with limit * 2
|
||||
mock_qdrant_instance.query_points.assert_called_once()
|
||||
call_args = mock_qdrant_instance.query_points.call_args
|
||||
assert call_args[1]['limit'] == 6 # 3 * 2
|
||||
|
||||
# Verify result is limited to requested limit
|
||||
assert len(result) == 3
|
||||
|
||||
@patch('trustgraph.query.graph_embeddings.qdrant.service.QdrantClient')
|
||||
@patch('trustgraph.base.GraphEmbeddingsQueryService.__init__')
|
||||
async def test_query_graph_embeddings_empty_results(self, mock_base_init, mock_qdrant_client):
|
||||
"""Test querying graph embeddings with empty results"""
|
||||
# Arrange
|
||||
mock_base_init.return_value = None
|
||||
mock_qdrant_instance = MagicMock()
|
||||
mock_qdrant_client.return_value = mock_qdrant_instance
|
||||
|
||||
# Mock empty query response
|
||||
mock_response = MagicMock()
|
||||
mock_response.points = []
|
||||
mock_qdrant_instance.query_points.return_value = mock_response
|
||||
|
||||
config = {
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Create mock message
|
||||
mock_message = MagicMock()
|
||||
mock_message.vectors = [[0.1, 0.2]]
|
||||
mock_message.limit = 5
|
||||
mock_message.user = 'empty_user'
|
||||
mock_message.collection = 'empty_collection'
|
||||
|
||||
# Act
|
||||
result = await processor.query_graph_embeddings(mock_message)
|
||||
|
||||
# Assert
|
||||
assert result == []
|
||||
|
||||
@patch('trustgraph.query.graph_embeddings.qdrant.service.QdrantClient')
|
||||
@patch('trustgraph.base.GraphEmbeddingsQueryService.__init__')
|
||||
async def test_query_graph_embeddings_different_dimensions(self, mock_base_init, mock_qdrant_client):
|
||||
"""Test querying graph embeddings with different vector dimensions"""
|
||||
# Arrange
|
||||
mock_base_init.return_value = None
|
||||
mock_qdrant_instance = MagicMock()
|
||||
mock_qdrant_client.return_value = mock_qdrant_instance
|
||||
|
||||
# Mock query responses
|
||||
mock_point1 = MagicMock()
|
||||
mock_point1.payload = {'entity': 'entity2d'}
|
||||
mock_point2 = MagicMock()
|
||||
mock_point2.payload = {'entity': 'entity3d'}
|
||||
|
||||
mock_response1 = MagicMock()
|
||||
mock_response1.points = [mock_point1]
|
||||
mock_response2 = MagicMock()
|
||||
mock_response2.points = [mock_point2]
|
||||
mock_qdrant_instance.query_points.side_effect = [mock_response1, mock_response2]
|
||||
|
||||
config = {
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Create mock message with different dimension vectors
|
||||
mock_message = MagicMock()
|
||||
mock_message.vectors = [[0.1, 0.2], [0.3, 0.4, 0.5]] # 2D and 3D
|
||||
mock_message.limit = 5
|
||||
mock_message.user = 'dim_user'
|
||||
mock_message.collection = 'dim_collection'
|
||||
|
||||
# Act
|
||||
result = await processor.query_graph_embeddings(mock_message)
|
||||
|
||||
# Assert
|
||||
# Verify query was called twice with different collections
|
||||
assert mock_qdrant_instance.query_points.call_count == 2
|
||||
calls = mock_qdrant_instance.query_points.call_args_list
|
||||
|
||||
# First call should use 2D collection
|
||||
assert calls[0][1]['collection_name'] == 't_dim_user_dim_collection_2'
|
||||
assert calls[0][1]['query'] == [0.1, 0.2]
|
||||
|
||||
# Second call should use 3D collection
|
||||
assert calls[1][1]['collection_name'] == 't_dim_user_dim_collection_3'
|
||||
assert calls[1][1]['query'] == [0.3, 0.4, 0.5]
|
||||
|
||||
# Verify results
|
||||
entity_values = [entity.value for entity in result]
|
||||
assert 'entity2d' in entity_values
|
||||
assert 'entity3d' in entity_values
|
||||
|
||||
@patch('trustgraph.query.graph_embeddings.qdrant.service.QdrantClient')
|
||||
@patch('trustgraph.base.GraphEmbeddingsQueryService.__init__')
|
||||
async def test_query_graph_embeddings_uri_detection(self, mock_base_init, mock_qdrant_client):
|
||||
"""Test querying graph embeddings with URI detection"""
|
||||
# Arrange
|
||||
mock_base_init.return_value = None
|
||||
mock_qdrant_instance = MagicMock()
|
||||
mock_qdrant_client.return_value = mock_qdrant_instance
|
||||
|
||||
# Mock query response with URIs and regular strings
|
||||
mock_point1 = MagicMock()
|
||||
mock_point1.payload = {'entity': 'http://example.com/entity1'}
|
||||
mock_point2 = MagicMock()
|
||||
mock_point2.payload = {'entity': 'https://secure.example.com/entity2'}
|
||||
mock_point3 = MagicMock()
|
||||
mock_point3.payload = {'entity': 'regular entity'}
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.points = [mock_point1, mock_point2, mock_point3]
|
||||
mock_qdrant_instance.query_points.return_value = mock_response
|
||||
|
||||
config = {
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Create mock message
|
||||
mock_message = MagicMock()
|
||||
mock_message.vectors = [[0.1, 0.2]]
|
||||
mock_message.limit = 5
|
||||
mock_message.user = 'uri_user'
|
||||
mock_message.collection = 'uri_collection'
|
||||
|
||||
# Act
|
||||
result = await processor.query_graph_embeddings(mock_message)
|
||||
|
||||
# Assert
|
||||
assert len(result) == 3
|
||||
|
||||
# Check URI entities
|
||||
uri_entities = [entity for entity in result if hasattr(entity, 'is_uri') and entity.is_uri]
|
||||
assert len(uri_entities) == 2
|
||||
uri_values = [entity.value for entity in uri_entities]
|
||||
assert 'http://example.com/entity1' in uri_values
|
||||
assert 'https://secure.example.com/entity2' in uri_values
|
||||
|
||||
# Check regular entities
|
||||
regular_entities = [entity for entity in result if hasattr(entity, 'is_uri') and not entity.is_uri]
|
||||
assert len(regular_entities) == 1
|
||||
assert regular_entities[0].value == 'regular entity'
|
||||
|
||||
@patch('trustgraph.query.graph_embeddings.qdrant.service.QdrantClient')
|
||||
@patch('trustgraph.base.GraphEmbeddingsQueryService.__init__')
|
||||
async def test_query_graph_embeddings_qdrant_error(self, mock_base_init, mock_qdrant_client):
|
||||
"""Test querying graph embeddings handles Qdrant errors"""
|
||||
# Arrange
|
||||
mock_base_init.return_value = None
|
||||
mock_qdrant_instance = MagicMock()
|
||||
mock_qdrant_client.return_value = mock_qdrant_instance
|
||||
|
||||
# Mock Qdrant error
|
||||
mock_qdrant_instance.query_points.side_effect = Exception("Qdrant connection failed")
|
||||
|
||||
config = {
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Create mock message
|
||||
mock_message = MagicMock()
|
||||
mock_message.vectors = [[0.1, 0.2]]
|
||||
mock_message.limit = 5
|
||||
mock_message.user = 'error_user'
|
||||
mock_message.collection = 'error_collection'
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(Exception, match="Qdrant connection failed"):
|
||||
await processor.query_graph_embeddings(mock_message)
|
||||
|
||||
@patch('trustgraph.query.graph_embeddings.qdrant.service.QdrantClient')
|
||||
@patch('trustgraph.base.GraphEmbeddingsQueryService.__init__')
|
||||
async def test_query_graph_embeddings_zero_limit(self, mock_base_init, mock_qdrant_client):
|
||||
"""Test querying graph embeddings with zero limit"""
|
||||
# Arrange
|
||||
mock_base_init.return_value = None
|
||||
mock_qdrant_instance = MagicMock()
|
||||
mock_qdrant_client.return_value = mock_qdrant_instance
|
||||
|
||||
# Mock query response - even with zero limit, Qdrant might return results
|
||||
mock_point = MagicMock()
|
||||
mock_point.payload = {'entity': 'entity1'}
|
||||
mock_response = MagicMock()
|
||||
mock_response.points = [mock_point]
|
||||
mock_qdrant_instance.query_points.return_value = mock_response
|
||||
|
||||
config = {
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Create mock message with zero limit
|
||||
mock_message = MagicMock()
|
||||
mock_message.vectors = [[0.1, 0.2]]
|
||||
mock_message.limit = 0
|
||||
mock_message.user = 'zero_user'
|
||||
mock_message.collection = 'zero_collection'
|
||||
|
||||
# Act
|
||||
result = await processor.query_graph_embeddings(mock_message)
|
||||
|
||||
# Assert
|
||||
# Should still query (with limit 0)
|
||||
mock_qdrant_instance.query_points.assert_called_once()
|
||||
call_args = mock_qdrant_instance.query_points.call_args
|
||||
assert call_args[1]['limit'] == 0 # 0 * 2 = 0
|
||||
|
||||
# With zero limit, the logic still adds one entity before checking the limit
|
||||
# So it returns one result (current behavior, not ideal but actual)
|
||||
assert len(result) == 1
|
||||
assert result[0].value == 'entity1'
|
||||
|
||||
@patch('trustgraph.query.graph_embeddings.qdrant.service.QdrantClient')
|
||||
@patch('trustgraph.base.GraphEmbeddingsQueryService.__init__')
|
||||
async def test_add_args_calls_parent(self, mock_base_init, mock_qdrant_client):
|
||||
"""Test that add_args() calls parent add_args method"""
|
||||
# Arrange
|
||||
mock_base_init.return_value = None
|
||||
mock_qdrant_client.return_value = MagicMock()
|
||||
mock_parser = MagicMock()
|
||||
|
||||
# Act
|
||||
with patch('trustgraph.base.GraphEmbeddingsQueryService.add_args') as mock_parent_add_args:
|
||||
Processor.add_args(mock_parser)
|
||||
|
||||
# Assert
|
||||
mock_parent_add_args.assert_called_once_with(mock_parser)
|
||||
|
||||
# Verify processor-specific arguments were added
|
||||
assert mock_parser.add_argument.call_count >= 2 # At least store-uri and api-key
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pytest.main([__file__])
|
||||
539
tests/unit/test_query/test_triples_cassandra_query.py
Normal file
539
tests/unit/test_query/test_triples_cassandra_query.py
Normal file
|
|
@ -0,0 +1,539 @@
|
|||
"""
|
||||
Tests for Cassandra triples query service
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from trustgraph.query.triples.cassandra.service import Processor
|
||||
from trustgraph.schema import Value
|
||||
|
||||
|
||||
class TestCassandraQueryProcessor:
|
||||
"""Test cases for Cassandra query processor"""
|
||||
|
||||
@pytest.fixture
|
||||
def processor(self):
|
||||
"""Create a processor instance for testing"""
|
||||
return Processor(
|
||||
taskgroup=MagicMock(),
|
||||
id='test-cassandra-query',
|
||||
graph_host='localhost'
|
||||
)
|
||||
|
||||
def test_create_value_with_http_uri(self, processor):
|
||||
"""Test create_value with HTTP URI"""
|
||||
result = processor.create_value("http://example.com/resource")
|
||||
|
||||
assert isinstance(result, Value)
|
||||
assert result.value == "http://example.com/resource"
|
||||
assert result.is_uri is True
|
||||
|
||||
def test_create_value_with_https_uri(self, processor):
|
||||
"""Test create_value with HTTPS URI"""
|
||||
result = processor.create_value("https://example.com/resource")
|
||||
|
||||
assert isinstance(result, Value)
|
||||
assert result.value == "https://example.com/resource"
|
||||
assert result.is_uri is True
|
||||
|
||||
def test_create_value_with_literal(self, processor):
|
||||
"""Test create_value with literal value"""
|
||||
result = processor.create_value("just a literal string")
|
||||
|
||||
assert isinstance(result, Value)
|
||||
assert result.value == "just a literal string"
|
||||
assert result.is_uri is False
|
||||
|
||||
def test_create_value_with_empty_string(self, processor):
|
||||
"""Test create_value with empty string"""
|
||||
result = processor.create_value("")
|
||||
|
||||
assert isinstance(result, Value)
|
||||
assert result.value == ""
|
||||
assert result.is_uri is False
|
||||
|
||||
def test_create_value_with_partial_uri(self, processor):
|
||||
"""Test create_value with string that looks like URI but isn't complete"""
|
||||
result = processor.create_value("http")
|
||||
|
||||
assert isinstance(result, Value)
|
||||
assert result.value == "http"
|
||||
assert result.is_uri is False
|
||||
|
||||
def test_create_value_with_ftp_uri(self, processor):
|
||||
"""Test create_value with FTP URI (should not be detected as URI)"""
|
||||
result = processor.create_value("ftp://example.com/file")
|
||||
|
||||
assert isinstance(result, Value)
|
||||
assert result.value == "ftp://example.com/file"
|
||||
assert result.is_uri is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('trustgraph.query.triples.cassandra.service.TrustGraph')
|
||||
async def test_query_triples_spo_query(self, mock_trustgraph):
|
||||
"""Test querying triples with subject, predicate, and object specified"""
|
||||
from trustgraph.schema import TriplesQueryRequest, Value
|
||||
|
||||
# Setup mock TrustGraph
|
||||
mock_tg_instance = MagicMock()
|
||||
mock_trustgraph.return_value = mock_tg_instance
|
||||
mock_tg_instance.get_spo.return_value = None # SPO query returns None if found
|
||||
|
||||
processor = Processor(
|
||||
taskgroup=MagicMock(),
|
||||
id='test-cassandra-query',
|
||||
graph_host='localhost'
|
||||
)
|
||||
|
||||
# Create query request with all SPO values
|
||||
query = TriplesQueryRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
s=Value(value='test_subject', is_uri=False),
|
||||
p=Value(value='test_predicate', is_uri=False),
|
||||
o=Value(value='test_object', is_uri=False),
|
||||
limit=100
|
||||
)
|
||||
|
||||
result = await processor.query_triples(query)
|
||||
|
||||
# Verify TrustGraph was created with correct parameters
|
||||
mock_trustgraph.assert_called_once_with(
|
||||
hosts=['localhost'],
|
||||
keyspace='test_user',
|
||||
table='test_collection'
|
||||
)
|
||||
|
||||
# Verify get_spo was called with correct parameters
|
||||
mock_tg_instance.get_spo.assert_called_once_with(
|
||||
'test_subject', 'test_predicate', 'test_object', limit=100
|
||||
)
|
||||
|
||||
# Verify result contains the queried triple
|
||||
assert len(result) == 1
|
||||
assert result[0].s.value == 'test_subject'
|
||||
assert result[0].p.value == 'test_predicate'
|
||||
assert result[0].o.value == 'test_object'
|
||||
|
||||
def test_processor_initialization_with_defaults(self):
|
||||
"""Test processor initialization with default parameters"""
|
||||
taskgroup_mock = MagicMock()
|
||||
|
||||
processor = Processor(taskgroup=taskgroup_mock)
|
||||
|
||||
assert processor.graph_host == ['localhost']
|
||||
assert processor.username is None
|
||||
assert processor.password is None
|
||||
assert processor.table is None
|
||||
|
||||
def test_processor_initialization_with_custom_params(self):
|
||||
"""Test processor initialization with custom parameters"""
|
||||
taskgroup_mock = MagicMock()
|
||||
|
||||
processor = Processor(
|
||||
taskgroup=taskgroup_mock,
|
||||
graph_host='cassandra.example.com',
|
||||
graph_username='queryuser',
|
||||
graph_password='querypass'
|
||||
)
|
||||
|
||||
assert processor.graph_host == ['cassandra.example.com']
|
||||
assert processor.username == 'queryuser'
|
||||
assert processor.password == 'querypass'
|
||||
assert processor.table is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('trustgraph.query.triples.cassandra.service.TrustGraph')
|
||||
async def test_query_triples_sp_pattern(self, mock_trustgraph):
|
||||
"""Test SP query pattern (subject and predicate, no object)"""
|
||||
from trustgraph.schema import TriplesQueryRequest, Value
|
||||
|
||||
# Setup mock TrustGraph and response
|
||||
mock_tg_instance = MagicMock()
|
||||
mock_trustgraph.return_value = mock_tg_instance
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.o = 'result_object'
|
||||
mock_tg_instance.get_sp.return_value = [mock_result]
|
||||
|
||||
processor = Processor(taskgroup=MagicMock())
|
||||
|
||||
query = TriplesQueryRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
s=Value(value='test_subject', is_uri=False),
|
||||
p=Value(value='test_predicate', is_uri=False),
|
||||
o=None,
|
||||
limit=50
|
||||
)
|
||||
|
||||
result = await processor.query_triples(query)
|
||||
|
||||
mock_tg_instance.get_sp.assert_called_once_with('test_subject', 'test_predicate', limit=50)
|
||||
assert len(result) == 1
|
||||
assert result[0].s.value == 'test_subject'
|
||||
assert result[0].p.value == 'test_predicate'
|
||||
assert result[0].o.value == 'result_object'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('trustgraph.query.triples.cassandra.service.TrustGraph')
|
||||
async def test_query_triples_s_pattern(self, mock_trustgraph):
|
||||
"""Test S query pattern (subject only)"""
|
||||
from trustgraph.schema import TriplesQueryRequest, Value
|
||||
|
||||
mock_tg_instance = MagicMock()
|
||||
mock_trustgraph.return_value = mock_tg_instance
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.p = 'result_predicate'
|
||||
mock_result.o = 'result_object'
|
||||
mock_tg_instance.get_s.return_value = [mock_result]
|
||||
|
||||
processor = Processor(taskgroup=MagicMock())
|
||||
|
||||
query = TriplesQueryRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
s=Value(value='test_subject', is_uri=False),
|
||||
p=None,
|
||||
o=None,
|
||||
limit=25
|
||||
)
|
||||
|
||||
result = await processor.query_triples(query)
|
||||
|
||||
mock_tg_instance.get_s.assert_called_once_with('test_subject', limit=25)
|
||||
assert len(result) == 1
|
||||
assert result[0].s.value == 'test_subject'
|
||||
assert result[0].p.value == 'result_predicate'
|
||||
assert result[0].o.value == 'result_object'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('trustgraph.query.triples.cassandra.service.TrustGraph')
|
||||
async def test_query_triples_p_pattern(self, mock_trustgraph):
|
||||
"""Test P query pattern (predicate only)"""
|
||||
from trustgraph.schema import TriplesQueryRequest, Value
|
||||
|
||||
mock_tg_instance = MagicMock()
|
||||
mock_trustgraph.return_value = mock_tg_instance
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.s = 'result_subject'
|
||||
mock_result.o = 'result_object'
|
||||
mock_tg_instance.get_p.return_value = [mock_result]
|
||||
|
||||
processor = Processor(taskgroup=MagicMock())
|
||||
|
||||
query = TriplesQueryRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
s=None,
|
||||
p=Value(value='test_predicate', is_uri=False),
|
||||
o=None,
|
||||
limit=10
|
||||
)
|
||||
|
||||
result = await processor.query_triples(query)
|
||||
|
||||
mock_tg_instance.get_p.assert_called_once_with('test_predicate', limit=10)
|
||||
assert len(result) == 1
|
||||
assert result[0].s.value == 'result_subject'
|
||||
assert result[0].p.value == 'test_predicate'
|
||||
assert result[0].o.value == 'result_object'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('trustgraph.query.triples.cassandra.service.TrustGraph')
|
||||
async def test_query_triples_o_pattern(self, mock_trustgraph):
|
||||
"""Test O query pattern (object only)"""
|
||||
from trustgraph.schema import TriplesQueryRequest, Value
|
||||
|
||||
mock_tg_instance = MagicMock()
|
||||
mock_trustgraph.return_value = mock_tg_instance
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.s = 'result_subject'
|
||||
mock_result.p = 'result_predicate'
|
||||
mock_tg_instance.get_o.return_value = [mock_result]
|
||||
|
||||
processor = Processor(taskgroup=MagicMock())
|
||||
|
||||
query = TriplesQueryRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
s=None,
|
||||
p=None,
|
||||
o=Value(value='test_object', is_uri=False),
|
||||
limit=75
|
||||
)
|
||||
|
||||
result = await processor.query_triples(query)
|
||||
|
||||
mock_tg_instance.get_o.assert_called_once_with('test_object', limit=75)
|
||||
assert len(result) == 1
|
||||
assert result[0].s.value == 'result_subject'
|
||||
assert result[0].p.value == 'result_predicate'
|
||||
assert result[0].o.value == 'test_object'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('trustgraph.query.triples.cassandra.service.TrustGraph')
|
||||
async def test_query_triples_get_all_pattern(self, mock_trustgraph):
|
||||
"""Test query pattern with no constraints (get all)"""
|
||||
from trustgraph.schema import TriplesQueryRequest
|
||||
|
||||
mock_tg_instance = MagicMock()
|
||||
mock_trustgraph.return_value = mock_tg_instance
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.s = 'all_subject'
|
||||
mock_result.p = 'all_predicate'
|
||||
mock_result.o = 'all_object'
|
||||
mock_tg_instance.get_all.return_value = [mock_result]
|
||||
|
||||
processor = Processor(taskgroup=MagicMock())
|
||||
|
||||
query = TriplesQueryRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
s=None,
|
||||
p=None,
|
||||
o=None,
|
||||
limit=1000
|
||||
)
|
||||
|
||||
result = await processor.query_triples(query)
|
||||
|
||||
mock_tg_instance.get_all.assert_called_once_with(limit=1000)
|
||||
assert len(result) == 1
|
||||
assert result[0].s.value == 'all_subject'
|
||||
assert result[0].p.value == 'all_predicate'
|
||||
assert result[0].o.value == 'all_object'
|
||||
|
||||
def test_add_args_method(self):
|
||||
"""Test that add_args properly configures argument parser"""
|
||||
from argparse import ArgumentParser
|
||||
|
||||
parser = ArgumentParser()
|
||||
|
||||
# Mock the parent class add_args method
|
||||
with patch('trustgraph.query.triples.cassandra.service.TriplesQueryService.add_args') as mock_parent_add_args:
|
||||
Processor.add_args(parser)
|
||||
|
||||
# Verify parent add_args was called
|
||||
mock_parent_add_args.assert_called_once_with(parser)
|
||||
|
||||
# Verify our specific arguments were added
|
||||
args = parser.parse_args([])
|
||||
|
||||
assert hasattr(args, 'graph_host')
|
||||
assert args.graph_host == 'localhost'
|
||||
assert hasattr(args, 'graph_username')
|
||||
assert args.graph_username is None
|
||||
assert hasattr(args, 'graph_password')
|
||||
assert args.graph_password is None
|
||||
|
||||
def test_add_args_with_custom_values(self):
|
||||
"""Test add_args with custom command line values"""
|
||||
from argparse import ArgumentParser
|
||||
|
||||
parser = ArgumentParser()
|
||||
|
||||
with patch('trustgraph.query.triples.cassandra.service.TriplesQueryService.add_args'):
|
||||
Processor.add_args(parser)
|
||||
|
||||
# Test parsing with custom values
|
||||
args = parser.parse_args([
|
||||
'--graph-host', 'query.cassandra.com',
|
||||
'--graph-username', 'queryuser',
|
||||
'--graph-password', 'querypass'
|
||||
])
|
||||
|
||||
assert args.graph_host == 'query.cassandra.com'
|
||||
assert args.graph_username == 'queryuser'
|
||||
assert args.graph_password == 'querypass'
|
||||
|
||||
def test_add_args_short_form(self):
|
||||
"""Test add_args with short form arguments"""
|
||||
from argparse import ArgumentParser
|
||||
|
||||
parser = ArgumentParser()
|
||||
|
||||
with patch('trustgraph.query.triples.cassandra.service.TriplesQueryService.add_args'):
|
||||
Processor.add_args(parser)
|
||||
|
||||
# Test parsing with short form
|
||||
args = parser.parse_args(['-g', 'short.query.com'])
|
||||
|
||||
assert args.graph_host == 'short.query.com'
|
||||
|
||||
@patch('trustgraph.query.triples.cassandra.service.Processor.launch')
|
||||
def test_run_function(self, mock_launch):
|
||||
"""Test the run function calls Processor.launch with correct parameters"""
|
||||
from trustgraph.query.triples.cassandra.service import run, default_ident
|
||||
|
||||
run()
|
||||
|
||||
mock_launch.assert_called_once_with(default_ident, '\nTriples query service. Input is a (s, p, o) triple, some values may be\nnull. Output is a list of triples.\n')
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('trustgraph.query.triples.cassandra.service.TrustGraph')
|
||||
async def test_query_triples_with_authentication(self, mock_trustgraph):
|
||||
"""Test querying with username and password authentication"""
|
||||
from trustgraph.schema import TriplesQueryRequest, Value
|
||||
|
||||
mock_tg_instance = MagicMock()
|
||||
mock_trustgraph.return_value = mock_tg_instance
|
||||
mock_tg_instance.get_spo.return_value = None
|
||||
|
||||
processor = Processor(
|
||||
taskgroup=MagicMock(),
|
||||
graph_username='authuser',
|
||||
graph_password='authpass'
|
||||
)
|
||||
|
||||
query = TriplesQueryRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
s=Value(value='test_subject', is_uri=False),
|
||||
p=Value(value='test_predicate', is_uri=False),
|
||||
o=Value(value='test_object', is_uri=False),
|
||||
limit=100
|
||||
)
|
||||
|
||||
await processor.query_triples(query)
|
||||
|
||||
# Verify TrustGraph was created with authentication
|
||||
mock_trustgraph.assert_called_once_with(
|
||||
hosts=['localhost'],
|
||||
keyspace='test_user',
|
||||
table='test_collection',
|
||||
username='authuser',
|
||||
password='authpass'
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('trustgraph.query.triples.cassandra.service.TrustGraph')
|
||||
async def test_query_triples_table_reuse(self, mock_trustgraph):
|
||||
"""Test that TrustGraph is reused for same table"""
|
||||
from trustgraph.schema import TriplesQueryRequest, Value
|
||||
|
||||
mock_tg_instance = MagicMock()
|
||||
mock_trustgraph.return_value = mock_tg_instance
|
||||
mock_tg_instance.get_spo.return_value = None
|
||||
|
||||
processor = Processor(taskgroup=MagicMock())
|
||||
|
||||
query = TriplesQueryRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
s=Value(value='test_subject', is_uri=False),
|
||||
p=Value(value='test_predicate', is_uri=False),
|
||||
o=Value(value='test_object', is_uri=False),
|
||||
limit=100
|
||||
)
|
||||
|
||||
# First query should create TrustGraph
|
||||
await processor.query_triples(query)
|
||||
assert mock_trustgraph.call_count == 1
|
||||
|
||||
# Second query with same table should reuse TrustGraph
|
||||
await processor.query_triples(query)
|
||||
assert mock_trustgraph.call_count == 1 # Should not increase
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('trustgraph.query.triples.cassandra.service.TrustGraph')
|
||||
async def test_query_triples_table_switching(self, mock_trustgraph):
|
||||
"""Test table switching creates new TrustGraph"""
|
||||
from trustgraph.schema import TriplesQueryRequest, Value
|
||||
|
||||
mock_tg_instance1 = MagicMock()
|
||||
mock_tg_instance2 = MagicMock()
|
||||
mock_trustgraph.side_effect = [mock_tg_instance1, mock_tg_instance2]
|
||||
|
||||
processor = Processor(taskgroup=MagicMock())
|
||||
|
||||
# First query
|
||||
query1 = TriplesQueryRequest(
|
||||
user='user1',
|
||||
collection='collection1',
|
||||
s=Value(value='test_subject', is_uri=False),
|
||||
p=None,
|
||||
o=None,
|
||||
limit=100
|
||||
)
|
||||
|
||||
await processor.query_triples(query1)
|
||||
assert processor.table == ('user1', 'collection1')
|
||||
|
||||
# Second query with different table
|
||||
query2 = TriplesQueryRequest(
|
||||
user='user2',
|
||||
collection='collection2',
|
||||
s=Value(value='test_subject', is_uri=False),
|
||||
p=None,
|
||||
o=None,
|
||||
limit=100
|
||||
)
|
||||
|
||||
await processor.query_triples(query2)
|
||||
assert processor.table == ('user2', 'collection2')
|
||||
|
||||
# Verify TrustGraph was created twice
|
||||
assert mock_trustgraph.call_count == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('trustgraph.query.triples.cassandra.service.TrustGraph')
|
||||
async def test_query_triples_exception_handling(self, mock_trustgraph):
|
||||
"""Test exception handling during query execution"""
|
||||
from trustgraph.schema import TriplesQueryRequest, Value
|
||||
|
||||
mock_tg_instance = MagicMock()
|
||||
mock_trustgraph.return_value = mock_tg_instance
|
||||
mock_tg_instance.get_spo.side_effect = Exception("Query failed")
|
||||
|
||||
processor = Processor(taskgroup=MagicMock())
|
||||
|
||||
query = TriplesQueryRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
s=Value(value='test_subject', is_uri=False),
|
||||
p=Value(value='test_predicate', is_uri=False),
|
||||
o=Value(value='test_object', is_uri=False),
|
||||
limit=100
|
||||
)
|
||||
|
||||
with pytest.raises(Exception, match="Query failed"):
|
||||
await processor.query_triples(query)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('trustgraph.query.triples.cassandra.service.TrustGraph')
|
||||
async def test_query_triples_multiple_results(self, mock_trustgraph):
|
||||
"""Test query returning multiple results"""
|
||||
from trustgraph.schema import TriplesQueryRequest, Value
|
||||
|
||||
mock_tg_instance = MagicMock()
|
||||
mock_trustgraph.return_value = mock_tg_instance
|
||||
|
||||
# Mock multiple results
|
||||
mock_result1 = MagicMock()
|
||||
mock_result1.o = 'object1'
|
||||
mock_result2 = MagicMock()
|
||||
mock_result2.o = 'object2'
|
||||
mock_tg_instance.get_sp.return_value = [mock_result1, mock_result2]
|
||||
|
||||
processor = Processor(taskgroup=MagicMock())
|
||||
|
||||
query = TriplesQueryRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
s=Value(value='test_subject', is_uri=False),
|
||||
p=Value(value='test_predicate', is_uri=False),
|
||||
o=None,
|
||||
limit=100
|
||||
)
|
||||
|
||||
result = await processor.query_triples(query)
|
||||
|
||||
assert len(result) == 2
|
||||
assert result[0].o.value == 'object1'
|
||||
assert result[1].o.value == 'object2'
|
||||
475
tests/unit/test_retrieval/test_document_rag.py
Normal file
475
tests/unit/test_retrieval/test_document_rag.py
Normal file
|
|
@ -0,0 +1,475 @@
|
|||
"""
|
||||
Tests for DocumentRAG retrieval implementation
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, AsyncMock
|
||||
|
||||
from trustgraph.retrieval.document_rag.document_rag import DocumentRag, Query
|
||||
|
||||
|
||||
class TestDocumentRag:
|
||||
"""Test cases for DocumentRag class"""
|
||||
|
||||
def test_document_rag_initialization_with_defaults(self):
|
||||
"""Test DocumentRag initialization with default verbose setting"""
|
||||
# Create mock clients
|
||||
mock_prompt_client = MagicMock()
|
||||
mock_embeddings_client = MagicMock()
|
||||
mock_doc_embeddings_client = MagicMock()
|
||||
|
||||
# Initialize DocumentRag
|
||||
document_rag = DocumentRag(
|
||||
prompt_client=mock_prompt_client,
|
||||
embeddings_client=mock_embeddings_client,
|
||||
doc_embeddings_client=mock_doc_embeddings_client
|
||||
)
|
||||
|
||||
# Verify initialization
|
||||
assert document_rag.prompt_client == mock_prompt_client
|
||||
assert document_rag.embeddings_client == mock_embeddings_client
|
||||
assert document_rag.doc_embeddings_client == mock_doc_embeddings_client
|
||||
assert document_rag.verbose is False # Default value
|
||||
|
||||
def test_document_rag_initialization_with_verbose(self):
|
||||
"""Test DocumentRag initialization with verbose enabled"""
|
||||
# Create mock clients
|
||||
mock_prompt_client = MagicMock()
|
||||
mock_embeddings_client = MagicMock()
|
||||
mock_doc_embeddings_client = MagicMock()
|
||||
|
||||
# Initialize DocumentRag with verbose=True
|
||||
document_rag = DocumentRag(
|
||||
prompt_client=mock_prompt_client,
|
||||
embeddings_client=mock_embeddings_client,
|
||||
doc_embeddings_client=mock_doc_embeddings_client,
|
||||
verbose=True
|
||||
)
|
||||
|
||||
# Verify initialization
|
||||
assert document_rag.prompt_client == mock_prompt_client
|
||||
assert document_rag.embeddings_client == mock_embeddings_client
|
||||
assert document_rag.doc_embeddings_client == mock_doc_embeddings_client
|
||||
assert document_rag.verbose is True
|
||||
|
||||
|
||||
class TestQuery:
|
||||
"""Test cases for Query class"""
|
||||
|
||||
def test_query_initialization_with_defaults(self):
|
||||
"""Test Query initialization with default parameters"""
|
||||
# Create mock DocumentRag
|
||||
mock_rag = MagicMock()
|
||||
|
||||
# Initialize Query with defaults
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
verbose=False
|
||||
)
|
||||
|
||||
# Verify initialization
|
||||
assert query.rag == mock_rag
|
||||
assert query.user == "test_user"
|
||||
assert query.collection == "test_collection"
|
||||
assert query.verbose is False
|
||||
assert query.doc_limit == 20 # Default value
|
||||
|
||||
def test_query_initialization_with_custom_doc_limit(self):
|
||||
"""Test Query initialization with custom doc_limit"""
|
||||
# Create mock DocumentRag
|
||||
mock_rag = MagicMock()
|
||||
|
||||
# Initialize Query with custom doc_limit
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
user="custom_user",
|
||||
collection="custom_collection",
|
||||
verbose=True,
|
||||
doc_limit=50
|
||||
)
|
||||
|
||||
# Verify initialization
|
||||
assert query.rag == mock_rag
|
||||
assert query.user == "custom_user"
|
||||
assert query.collection == "custom_collection"
|
||||
assert query.verbose is True
|
||||
assert query.doc_limit == 50
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_vector_method(self):
|
||||
"""Test Query.get_vector method calls embeddings client correctly"""
|
||||
# Create mock DocumentRag with embeddings client
|
||||
mock_rag = MagicMock()
|
||||
mock_embeddings_client = AsyncMock()
|
||||
mock_rag.embeddings_client = mock_embeddings_client
|
||||
|
||||
# Mock the embed method to return test vectors
|
||||
expected_vectors = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
|
||||
mock_embeddings_client.embed.return_value = expected_vectors
|
||||
|
||||
# Initialize Query
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
verbose=False
|
||||
)
|
||||
|
||||
# Call get_vector
|
||||
test_query = "What documents are relevant?"
|
||||
result = await query.get_vector(test_query)
|
||||
|
||||
# Verify embeddings client was called correctly
|
||||
mock_embeddings_client.embed.assert_called_once_with(test_query)
|
||||
|
||||
# Verify result matches expected vectors
|
||||
assert result == expected_vectors
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_docs_method(self):
|
||||
"""Test Query.get_docs method retrieves documents correctly"""
|
||||
# Create mock DocumentRag with clients
|
||||
mock_rag = MagicMock()
|
||||
mock_embeddings_client = AsyncMock()
|
||||
mock_doc_embeddings_client = AsyncMock()
|
||||
mock_rag.embeddings_client = mock_embeddings_client
|
||||
mock_rag.doc_embeddings_client = mock_doc_embeddings_client
|
||||
|
||||
# Mock the embedding and document query responses
|
||||
test_vectors = [[0.1, 0.2, 0.3]]
|
||||
mock_embeddings_client.embed.return_value = test_vectors
|
||||
|
||||
# Mock document results
|
||||
test_docs = ["Document 1 content", "Document 2 content"]
|
||||
mock_doc_embeddings_client.query.return_value = test_docs
|
||||
|
||||
# Initialize Query
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
verbose=False,
|
||||
doc_limit=15
|
||||
)
|
||||
|
||||
# Call get_docs
|
||||
test_query = "Find relevant documents"
|
||||
result = await query.get_docs(test_query)
|
||||
|
||||
# Verify embeddings client was called
|
||||
mock_embeddings_client.embed.assert_called_once_with(test_query)
|
||||
|
||||
# Verify doc embeddings client was called correctly
|
||||
mock_doc_embeddings_client.query.assert_called_once_with(
|
||||
test_vectors,
|
||||
limit=15,
|
||||
user="test_user",
|
||||
collection="test_collection"
|
||||
)
|
||||
|
||||
# Verify result is list of documents
|
||||
assert result == test_docs
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_document_rag_query_method(self):
|
||||
"""Test DocumentRag.query method orchestrates full document RAG pipeline"""
|
||||
# Create mock clients
|
||||
mock_prompt_client = AsyncMock()
|
||||
mock_embeddings_client = AsyncMock()
|
||||
mock_doc_embeddings_client = AsyncMock()
|
||||
|
||||
# Mock embeddings and document responses
|
||||
test_vectors = [[0.1, 0.2, 0.3]]
|
||||
test_docs = ["Relevant document content", "Another document"]
|
||||
expected_response = "This is the document RAG response"
|
||||
|
||||
mock_embeddings_client.embed.return_value = test_vectors
|
||||
mock_doc_embeddings_client.query.return_value = test_docs
|
||||
mock_prompt_client.document_prompt.return_value = expected_response
|
||||
|
||||
# Initialize DocumentRag
|
||||
document_rag = DocumentRag(
|
||||
prompt_client=mock_prompt_client,
|
||||
embeddings_client=mock_embeddings_client,
|
||||
doc_embeddings_client=mock_doc_embeddings_client,
|
||||
verbose=False
|
||||
)
|
||||
|
||||
# Call DocumentRag.query
|
||||
result = await document_rag.query(
|
||||
query="test query",
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
doc_limit=10
|
||||
)
|
||||
|
||||
# Verify embeddings client was called
|
||||
mock_embeddings_client.embed.assert_called_once_with("test query")
|
||||
|
||||
# Verify doc embeddings client was called
|
||||
mock_doc_embeddings_client.query.assert_called_once_with(
|
||||
test_vectors,
|
||||
limit=10,
|
||||
user="test_user",
|
||||
collection="test_collection"
|
||||
)
|
||||
|
||||
# Verify prompt client was called with documents and query
|
||||
mock_prompt_client.document_prompt.assert_called_once_with(
|
||||
query="test query",
|
||||
documents=test_docs
|
||||
)
|
||||
|
||||
# Verify result
|
||||
assert result == expected_response
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_document_rag_query_with_defaults(self):
|
||||
"""Test DocumentRag.query method with default parameters"""
|
||||
# Create mock clients
|
||||
mock_prompt_client = AsyncMock()
|
||||
mock_embeddings_client = AsyncMock()
|
||||
mock_doc_embeddings_client = AsyncMock()
|
||||
|
||||
# Mock responses
|
||||
mock_embeddings_client.embed.return_value = [[0.1, 0.2]]
|
||||
mock_doc_embeddings_client.query.return_value = ["Default doc"]
|
||||
mock_prompt_client.document_prompt.return_value = "Default response"
|
||||
|
||||
# Initialize DocumentRag
|
||||
document_rag = DocumentRag(
|
||||
prompt_client=mock_prompt_client,
|
||||
embeddings_client=mock_embeddings_client,
|
||||
doc_embeddings_client=mock_doc_embeddings_client
|
||||
)
|
||||
|
||||
# Call DocumentRag.query with minimal parameters
|
||||
result = await document_rag.query("simple query")
|
||||
|
||||
# Verify default parameters were used
|
||||
mock_doc_embeddings_client.query.assert_called_once_with(
|
||||
[[0.1, 0.2]],
|
||||
limit=20, # Default doc_limit
|
||||
user="trustgraph", # Default user
|
||||
collection="default" # Default collection
|
||||
)
|
||||
|
||||
assert result == "Default response"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_docs_with_verbose_output(self):
|
||||
"""Test Query.get_docs method with verbose logging"""
|
||||
# Create mock DocumentRag with clients
|
||||
mock_rag = MagicMock()
|
||||
mock_embeddings_client = AsyncMock()
|
||||
mock_doc_embeddings_client = AsyncMock()
|
||||
mock_rag.embeddings_client = mock_embeddings_client
|
||||
mock_rag.doc_embeddings_client = mock_doc_embeddings_client
|
||||
|
||||
# Mock responses
|
||||
mock_embeddings_client.embed.return_value = [[0.7, 0.8]]
|
||||
mock_doc_embeddings_client.query.return_value = ["Verbose test doc"]
|
||||
|
||||
# Initialize Query with verbose=True
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
verbose=True,
|
||||
doc_limit=5
|
||||
)
|
||||
|
||||
# Call get_docs
|
||||
result = await query.get_docs("verbose test")
|
||||
|
||||
# Verify calls were made
|
||||
mock_embeddings_client.embed.assert_called_once_with("verbose test")
|
||||
mock_doc_embeddings_client.query.assert_called_once()
|
||||
|
||||
# Verify result
|
||||
assert result == ["Verbose test doc"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_document_rag_query_with_verbose(self):
|
||||
"""Test DocumentRag.query method with verbose logging enabled"""
|
||||
# Create mock clients
|
||||
mock_prompt_client = AsyncMock()
|
||||
mock_embeddings_client = AsyncMock()
|
||||
mock_doc_embeddings_client = AsyncMock()
|
||||
|
||||
# Mock responses
|
||||
mock_embeddings_client.embed.return_value = [[0.3, 0.4]]
|
||||
mock_doc_embeddings_client.query.return_value = ["Verbose doc content"]
|
||||
mock_prompt_client.document_prompt.return_value = "Verbose RAG response"
|
||||
|
||||
# Initialize DocumentRag with verbose=True
|
||||
document_rag = DocumentRag(
|
||||
prompt_client=mock_prompt_client,
|
||||
embeddings_client=mock_embeddings_client,
|
||||
doc_embeddings_client=mock_doc_embeddings_client,
|
||||
verbose=True
|
||||
)
|
||||
|
||||
# Call DocumentRag.query
|
||||
result = await document_rag.query("verbose query test")
|
||||
|
||||
# Verify all clients were called
|
||||
mock_embeddings_client.embed.assert_called_once_with("verbose query test")
|
||||
mock_doc_embeddings_client.query.assert_called_once()
|
||||
mock_prompt_client.document_prompt.assert_called_once_with(
|
||||
query="verbose query test",
|
||||
documents=["Verbose doc content"]
|
||||
)
|
||||
|
||||
assert result == "Verbose RAG response"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_docs_with_empty_results(self):
|
||||
"""Test Query.get_docs method when no documents are found"""
|
||||
# Create mock DocumentRag with clients
|
||||
mock_rag = MagicMock()
|
||||
mock_embeddings_client = AsyncMock()
|
||||
mock_doc_embeddings_client = AsyncMock()
|
||||
mock_rag.embeddings_client = mock_embeddings_client
|
||||
mock_rag.doc_embeddings_client = mock_doc_embeddings_client
|
||||
|
||||
# Mock responses - empty document list
|
||||
mock_embeddings_client.embed.return_value = [[0.1, 0.2]]
|
||||
mock_doc_embeddings_client.query.return_value = [] # No documents found
|
||||
|
||||
# Initialize Query
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
verbose=False
|
||||
)
|
||||
|
||||
# Call get_docs
|
||||
result = await query.get_docs("query with no results")
|
||||
|
||||
# Verify calls were made
|
||||
mock_embeddings_client.embed.assert_called_once_with("query with no results")
|
||||
mock_doc_embeddings_client.query.assert_called_once()
|
||||
|
||||
# Verify empty result is returned
|
||||
assert result == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_document_rag_query_with_empty_documents(self):
|
||||
"""Test DocumentRag.query method when no documents are retrieved"""
|
||||
# Create mock clients
|
||||
mock_prompt_client = AsyncMock()
|
||||
mock_embeddings_client = AsyncMock()
|
||||
mock_doc_embeddings_client = AsyncMock()
|
||||
|
||||
# Mock responses - no documents found
|
||||
mock_embeddings_client.embed.return_value = [[0.5, 0.6]]
|
||||
mock_doc_embeddings_client.query.return_value = [] # Empty document list
|
||||
mock_prompt_client.document_prompt.return_value = "No documents found response"
|
||||
|
||||
# Initialize DocumentRag
|
||||
document_rag = DocumentRag(
|
||||
prompt_client=mock_prompt_client,
|
||||
embeddings_client=mock_embeddings_client,
|
||||
doc_embeddings_client=mock_doc_embeddings_client,
|
||||
verbose=False
|
||||
)
|
||||
|
||||
# Call DocumentRag.query
|
||||
result = await document_rag.query("query with no matching docs")
|
||||
|
||||
# Verify prompt client was called with empty document list
|
||||
mock_prompt_client.document_prompt.assert_called_once_with(
|
||||
query="query with no matching docs",
|
||||
documents=[]
|
||||
)
|
||||
|
||||
assert result == "No documents found response"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_vector_with_verbose(self):
|
||||
"""Test Query.get_vector method with verbose logging"""
|
||||
# Create mock DocumentRag with embeddings client
|
||||
mock_rag = MagicMock()
|
||||
mock_embeddings_client = AsyncMock()
|
||||
mock_rag.embeddings_client = mock_embeddings_client
|
||||
|
||||
# Mock the embed method
|
||||
expected_vectors = [[0.9, 1.0, 1.1]]
|
||||
mock_embeddings_client.embed.return_value = expected_vectors
|
||||
|
||||
# Initialize Query with verbose=True
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
verbose=True
|
||||
)
|
||||
|
||||
# Call get_vector
|
||||
result = await query.get_vector("verbose vector test")
|
||||
|
||||
# Verify embeddings client was called
|
||||
mock_embeddings_client.embed.assert_called_once_with("verbose vector test")
|
||||
|
||||
# Verify result
|
||||
assert result == expected_vectors
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_document_rag_integration_flow(self):
|
||||
"""Test complete DocumentRag integration with realistic data flow"""
|
||||
# Create mock clients
|
||||
mock_prompt_client = AsyncMock()
|
||||
mock_embeddings_client = AsyncMock()
|
||||
mock_doc_embeddings_client = AsyncMock()
|
||||
|
||||
# Mock realistic responses
|
||||
query_text = "What is machine learning?"
|
||||
query_vectors = [[0.1, 0.2, 0.3, 0.4, 0.5]]
|
||||
retrieved_docs = [
|
||||
"Machine learning is a subset of artificial intelligence...",
|
||||
"ML algorithms learn patterns from data to make predictions...",
|
||||
"Common ML techniques include supervised and unsupervised learning..."
|
||||
]
|
||||
final_response = "Machine learning is a field of AI that enables computers to learn and improve from experience without being explicitly programmed."
|
||||
|
||||
mock_embeddings_client.embed.return_value = query_vectors
|
||||
mock_doc_embeddings_client.query.return_value = retrieved_docs
|
||||
mock_prompt_client.document_prompt.return_value = final_response
|
||||
|
||||
# Initialize DocumentRag
|
||||
document_rag = DocumentRag(
|
||||
prompt_client=mock_prompt_client,
|
||||
embeddings_client=mock_embeddings_client,
|
||||
doc_embeddings_client=mock_doc_embeddings_client,
|
||||
verbose=False
|
||||
)
|
||||
|
||||
# Execute full pipeline
|
||||
result = await document_rag.query(
|
||||
query=query_text,
|
||||
user="research_user",
|
||||
collection="ml_knowledge",
|
||||
doc_limit=25
|
||||
)
|
||||
|
||||
# Verify complete pipeline execution
|
||||
mock_embeddings_client.embed.assert_called_once_with(query_text)
|
||||
|
||||
mock_doc_embeddings_client.query.assert_called_once_with(
|
||||
query_vectors,
|
||||
limit=25,
|
||||
user="research_user",
|
||||
collection="ml_knowledge"
|
||||
)
|
||||
|
||||
mock_prompt_client.document_prompt.assert_called_once_with(
|
||||
query=query_text,
|
||||
documents=retrieved_docs
|
||||
)
|
||||
|
||||
# Verify final result
|
||||
assert result == final_response
|
||||
595
tests/unit/test_retrieval/test_graph_rag.py
Normal file
595
tests/unit/test_retrieval/test_graph_rag.py
Normal file
|
|
@ -0,0 +1,595 @@
|
|||
"""
|
||||
Tests for GraphRAG retrieval implementation
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import unittest.mock
|
||||
from unittest.mock import MagicMock, AsyncMock
|
||||
|
||||
from trustgraph.retrieval.graph_rag.graph_rag import GraphRag, Query
|
||||
|
||||
|
||||
class TestGraphRag:
|
||||
"""Test cases for GraphRag class"""
|
||||
|
||||
def test_graph_rag_initialization_with_defaults(self):
|
||||
"""Test GraphRag initialization with default verbose setting"""
|
||||
# Create mock clients
|
||||
mock_prompt_client = MagicMock()
|
||||
mock_embeddings_client = MagicMock()
|
||||
mock_graph_embeddings_client = MagicMock()
|
||||
mock_triples_client = MagicMock()
|
||||
|
||||
# Initialize GraphRag
|
||||
graph_rag = GraphRag(
|
||||
prompt_client=mock_prompt_client,
|
||||
embeddings_client=mock_embeddings_client,
|
||||
graph_embeddings_client=mock_graph_embeddings_client,
|
||||
triples_client=mock_triples_client
|
||||
)
|
||||
|
||||
# Verify initialization
|
||||
assert graph_rag.prompt_client == mock_prompt_client
|
||||
assert graph_rag.embeddings_client == mock_embeddings_client
|
||||
assert graph_rag.graph_embeddings_client == mock_graph_embeddings_client
|
||||
assert graph_rag.triples_client == mock_triples_client
|
||||
assert graph_rag.verbose is False # Default value
|
||||
assert graph_rag.label_cache == {} # Empty cache initially
|
||||
|
||||
def test_graph_rag_initialization_with_verbose(self):
|
||||
"""Test GraphRag initialization with verbose enabled"""
|
||||
# Create mock clients
|
||||
mock_prompt_client = MagicMock()
|
||||
mock_embeddings_client = MagicMock()
|
||||
mock_graph_embeddings_client = MagicMock()
|
||||
mock_triples_client = MagicMock()
|
||||
|
||||
# Initialize GraphRag with verbose=True
|
||||
graph_rag = GraphRag(
|
||||
prompt_client=mock_prompt_client,
|
||||
embeddings_client=mock_embeddings_client,
|
||||
graph_embeddings_client=mock_graph_embeddings_client,
|
||||
triples_client=mock_triples_client,
|
||||
verbose=True
|
||||
)
|
||||
|
||||
# Verify initialization
|
||||
assert graph_rag.prompt_client == mock_prompt_client
|
||||
assert graph_rag.embeddings_client == mock_embeddings_client
|
||||
assert graph_rag.graph_embeddings_client == mock_graph_embeddings_client
|
||||
assert graph_rag.triples_client == mock_triples_client
|
||||
assert graph_rag.verbose is True
|
||||
assert graph_rag.label_cache == {} # Empty cache initially
|
||||
|
||||
|
||||
class TestQuery:
|
||||
"""Test cases for Query class"""
|
||||
|
||||
def test_query_initialization_with_defaults(self):
|
||||
"""Test Query initialization with default parameters"""
|
||||
# Create mock GraphRag
|
||||
mock_rag = MagicMock()
|
||||
|
||||
# Initialize Query with defaults
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
verbose=False
|
||||
)
|
||||
|
||||
# Verify initialization
|
||||
assert query.rag == mock_rag
|
||||
assert query.user == "test_user"
|
||||
assert query.collection == "test_collection"
|
||||
assert query.verbose is False
|
||||
assert query.entity_limit == 50 # Default value
|
||||
assert query.triple_limit == 30 # Default value
|
||||
assert query.max_subgraph_size == 1000 # Default value
|
||||
assert query.max_path_length == 2 # Default value
|
||||
|
||||
def test_query_initialization_with_custom_params(self):
|
||||
"""Test Query initialization with custom parameters"""
|
||||
# Create mock GraphRag
|
||||
mock_rag = MagicMock()
|
||||
|
||||
# Initialize Query with custom parameters
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
user="custom_user",
|
||||
collection="custom_collection",
|
||||
verbose=True,
|
||||
entity_limit=100,
|
||||
triple_limit=60,
|
||||
max_subgraph_size=2000,
|
||||
max_path_length=3
|
||||
)
|
||||
|
||||
# Verify initialization
|
||||
assert query.rag == mock_rag
|
||||
assert query.user == "custom_user"
|
||||
assert query.collection == "custom_collection"
|
||||
assert query.verbose is True
|
||||
assert query.entity_limit == 100
|
||||
assert query.triple_limit == 60
|
||||
assert query.max_subgraph_size == 2000
|
||||
assert query.max_path_length == 3
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_vector_method(self):
|
||||
"""Test Query.get_vector method calls embeddings client correctly"""
|
||||
# Create mock GraphRag with embeddings client
|
||||
mock_rag = MagicMock()
|
||||
mock_embeddings_client = AsyncMock()
|
||||
mock_rag.embeddings_client = mock_embeddings_client
|
||||
|
||||
# Mock the embed method to return test vectors
|
||||
expected_vectors = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
|
||||
mock_embeddings_client.embed.return_value = expected_vectors
|
||||
|
||||
# Initialize Query
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
verbose=False
|
||||
)
|
||||
|
||||
# Call get_vector
|
||||
test_query = "What is the capital of France?"
|
||||
result = await query.get_vector(test_query)
|
||||
|
||||
# Verify embeddings client was called correctly
|
||||
mock_embeddings_client.embed.assert_called_once_with(test_query)
|
||||
|
||||
# Verify result matches expected vectors
|
||||
assert result == expected_vectors
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_vector_method_with_verbose(self):
|
||||
"""Test Query.get_vector method with verbose output"""
|
||||
# Create mock GraphRag with embeddings client
|
||||
mock_rag = MagicMock()
|
||||
mock_embeddings_client = AsyncMock()
|
||||
mock_rag.embeddings_client = mock_embeddings_client
|
||||
|
||||
# Mock the embed method
|
||||
expected_vectors = [[0.7, 0.8, 0.9]]
|
||||
mock_embeddings_client.embed.return_value = expected_vectors
|
||||
|
||||
# Initialize Query with verbose=True
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
verbose=True
|
||||
)
|
||||
|
||||
# Call get_vector
|
||||
test_query = "Test query for embeddings"
|
||||
result = await query.get_vector(test_query)
|
||||
|
||||
# Verify embeddings client was called correctly
|
||||
mock_embeddings_client.embed.assert_called_once_with(test_query)
|
||||
|
||||
# Verify result matches expected vectors
|
||||
assert result == expected_vectors
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_entities_method(self):
|
||||
"""Test Query.get_entities method retrieves entities correctly"""
|
||||
# Create mock GraphRag with clients
|
||||
mock_rag = MagicMock()
|
||||
mock_embeddings_client = AsyncMock()
|
||||
mock_graph_embeddings_client = AsyncMock()
|
||||
mock_rag.embeddings_client = mock_embeddings_client
|
||||
mock_rag.graph_embeddings_client = mock_graph_embeddings_client
|
||||
|
||||
# Mock the embedding and entity query responses
|
||||
test_vectors = [[0.1, 0.2, 0.3]]
|
||||
mock_embeddings_client.embed.return_value = test_vectors
|
||||
|
||||
# Mock entity objects that have string representation
|
||||
mock_entity1 = MagicMock()
|
||||
mock_entity1.__str__ = MagicMock(return_value="entity1")
|
||||
mock_entity2 = MagicMock()
|
||||
mock_entity2.__str__ = MagicMock(return_value="entity2")
|
||||
mock_graph_embeddings_client.query.return_value = [mock_entity1, mock_entity2]
|
||||
|
||||
# Initialize Query
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
verbose=False,
|
||||
entity_limit=25
|
||||
)
|
||||
|
||||
# Call get_entities
|
||||
test_query = "Find related entities"
|
||||
result = await query.get_entities(test_query)
|
||||
|
||||
# Verify embeddings client was called
|
||||
mock_embeddings_client.embed.assert_called_once_with(test_query)
|
||||
|
||||
# Verify graph embeddings client was called correctly
|
||||
mock_graph_embeddings_client.query.assert_called_once_with(
|
||||
vectors=test_vectors,
|
||||
limit=25,
|
||||
user="test_user",
|
||||
collection="test_collection"
|
||||
)
|
||||
|
||||
# Verify result is list of entity strings
|
||||
assert result == ["entity1", "entity2"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_maybe_label_with_cached_label(self):
|
||||
"""Test Query.maybe_label method with cached label"""
|
||||
# Create mock GraphRag with label cache
|
||||
mock_rag = MagicMock()
|
||||
mock_rag.label_cache = {"entity1": "Entity One Label"}
|
||||
|
||||
# Initialize Query
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
verbose=False
|
||||
)
|
||||
|
||||
# Call maybe_label with cached entity
|
||||
result = await query.maybe_label("entity1")
|
||||
|
||||
# Verify cached label is returned
|
||||
assert result == "Entity One Label"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_maybe_label_with_label_lookup(self):
|
||||
"""Test Query.maybe_label method with database label lookup"""
|
||||
# Create mock GraphRag with triples client
|
||||
mock_rag = MagicMock()
|
||||
mock_rag.label_cache = {} # Empty cache
|
||||
mock_triples_client = AsyncMock()
|
||||
mock_rag.triples_client = mock_triples_client
|
||||
|
||||
# Mock triple result with label
|
||||
mock_triple = MagicMock()
|
||||
mock_triple.o = "Human Readable Label"
|
||||
mock_triples_client.query.return_value = [mock_triple]
|
||||
|
||||
# Initialize Query
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
verbose=False
|
||||
)
|
||||
|
||||
# Call maybe_label
|
||||
result = await query.maybe_label("http://example.com/entity")
|
||||
|
||||
# Verify triples client was called correctly
|
||||
mock_triples_client.query.assert_called_once_with(
|
||||
s="http://example.com/entity",
|
||||
p="http://www.w3.org/2000/01/rdf-schema#label",
|
||||
o=None,
|
||||
limit=1,
|
||||
user="test_user",
|
||||
collection="test_collection"
|
||||
)
|
||||
|
||||
# Verify result and cache update
|
||||
assert result == "Human Readable Label"
|
||||
assert mock_rag.label_cache["http://example.com/entity"] == "Human Readable Label"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_maybe_label_with_no_label_found(self):
|
||||
"""Test Query.maybe_label method when no label is found"""
|
||||
# Create mock GraphRag with triples client
|
||||
mock_rag = MagicMock()
|
||||
mock_rag.label_cache = {} # Empty cache
|
||||
mock_triples_client = AsyncMock()
|
||||
mock_rag.triples_client = mock_triples_client
|
||||
|
||||
# Mock empty result (no label found)
|
||||
mock_triples_client.query.return_value = []
|
||||
|
||||
# Initialize Query
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
verbose=False
|
||||
)
|
||||
|
||||
# Call maybe_label
|
||||
result = await query.maybe_label("unlabeled_entity")
|
||||
|
||||
# Verify triples client was called
|
||||
mock_triples_client.query.assert_called_once_with(
|
||||
s="unlabeled_entity",
|
||||
p="http://www.w3.org/2000/01/rdf-schema#label",
|
||||
o=None,
|
||||
limit=1,
|
||||
user="test_user",
|
||||
collection="test_collection"
|
||||
)
|
||||
|
||||
# Verify result is entity itself and cache is updated
|
||||
assert result == "unlabeled_entity"
|
||||
assert mock_rag.label_cache["unlabeled_entity"] == "unlabeled_entity"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_follow_edges_basic_functionality(self):
|
||||
"""Test Query.follow_edges method basic triple discovery"""
|
||||
# Create mock GraphRag with triples client
|
||||
mock_rag = MagicMock()
|
||||
mock_triples_client = AsyncMock()
|
||||
mock_rag.triples_client = mock_triples_client
|
||||
|
||||
# Mock triple results for different query patterns
|
||||
mock_triple1 = MagicMock()
|
||||
mock_triple1.s, mock_triple1.p, mock_triple1.o = "entity1", "predicate1", "object1"
|
||||
|
||||
mock_triple2 = MagicMock()
|
||||
mock_triple2.s, mock_triple2.p, mock_triple2.o = "subject2", "entity1", "object2"
|
||||
|
||||
mock_triple3 = MagicMock()
|
||||
mock_triple3.s, mock_triple3.p, mock_triple3.o = "subject3", "predicate3", "entity1"
|
||||
|
||||
# Setup query responses for s=ent, p=ent, o=ent patterns
|
||||
mock_triples_client.query.side_effect = [
|
||||
[mock_triple1], # s=ent, p=None, o=None
|
||||
[mock_triple2], # s=None, p=ent, o=None
|
||||
[mock_triple3], # s=None, p=None, o=ent
|
||||
]
|
||||
|
||||
# Initialize Query
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
verbose=False,
|
||||
triple_limit=10
|
||||
)
|
||||
|
||||
# Call follow_edges
|
||||
subgraph = set()
|
||||
await query.follow_edges("entity1", subgraph, path_length=1)
|
||||
|
||||
# Verify all three query patterns were called
|
||||
assert mock_triples_client.query.call_count == 3
|
||||
|
||||
# Verify query calls
|
||||
mock_triples_client.query.assert_any_call(
|
||||
s="entity1", p=None, o=None, limit=10,
|
||||
user="test_user", collection="test_collection"
|
||||
)
|
||||
mock_triples_client.query.assert_any_call(
|
||||
s=None, p="entity1", o=None, limit=10,
|
||||
user="test_user", collection="test_collection"
|
||||
)
|
||||
mock_triples_client.query.assert_any_call(
|
||||
s=None, p=None, o="entity1", limit=10,
|
||||
user="test_user", collection="test_collection"
|
||||
)
|
||||
|
||||
# Verify subgraph contains discovered triples
|
||||
expected_subgraph = {
|
||||
("entity1", "predicate1", "object1"),
|
||||
("subject2", "entity1", "object2"),
|
||||
("subject3", "predicate3", "entity1")
|
||||
}
|
||||
assert subgraph == expected_subgraph
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_follow_edges_with_path_length_zero(self):
|
||||
"""Test Query.follow_edges method with path_length=0"""
|
||||
# Create mock GraphRag
|
||||
mock_rag = MagicMock()
|
||||
mock_triples_client = AsyncMock()
|
||||
mock_rag.triples_client = mock_triples_client
|
||||
|
||||
# Initialize Query
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
verbose=False
|
||||
)
|
||||
|
||||
# Call follow_edges with path_length=0
|
||||
subgraph = set()
|
||||
await query.follow_edges("entity1", subgraph, path_length=0)
|
||||
|
||||
# Verify no queries were made
|
||||
mock_triples_client.query.assert_not_called()
|
||||
|
||||
# Verify subgraph remains empty
|
||||
assert subgraph == set()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_follow_edges_with_max_subgraph_size_limit(self):
|
||||
"""Test Query.follow_edges method respects max_subgraph_size"""
|
||||
# Create mock GraphRag
|
||||
mock_rag = MagicMock()
|
||||
mock_triples_client = AsyncMock()
|
||||
mock_rag.triples_client = mock_triples_client
|
||||
|
||||
# Initialize Query with small max_subgraph_size
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
verbose=False,
|
||||
max_subgraph_size=2
|
||||
)
|
||||
|
||||
# Pre-populate subgraph to exceed limit
|
||||
subgraph = {("s1", "p1", "o1"), ("s2", "p2", "o2"), ("s3", "p3", "o3")}
|
||||
|
||||
# Call follow_edges
|
||||
await query.follow_edges("entity1", subgraph, path_length=1)
|
||||
|
||||
# Verify no queries were made due to size limit
|
||||
mock_triples_client.query.assert_not_called()
|
||||
|
||||
# Verify subgraph unchanged
|
||||
assert len(subgraph) == 3
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_subgraph_method(self):
|
||||
"""Test Query.get_subgraph method orchestrates entity and edge discovery"""
|
||||
# Create mock Query that patches get_entities and follow_edges
|
||||
mock_rag = MagicMock()
|
||||
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
verbose=False,
|
||||
max_path_length=1
|
||||
)
|
||||
|
||||
# Mock get_entities to return test entities
|
||||
query.get_entities = AsyncMock(return_value=["entity1", "entity2"])
|
||||
|
||||
# Mock follow_edges to add triples to subgraph
|
||||
async def mock_follow_edges(ent, subgraph, path_length):
|
||||
subgraph.add((ent, "predicate", "object"))
|
||||
|
||||
query.follow_edges = AsyncMock(side_effect=mock_follow_edges)
|
||||
|
||||
# Call get_subgraph
|
||||
result = await query.get_subgraph("test query")
|
||||
|
||||
# Verify get_entities was called
|
||||
query.get_entities.assert_called_once_with("test query")
|
||||
|
||||
# Verify follow_edges was called for each entity
|
||||
assert query.follow_edges.call_count == 2
|
||||
query.follow_edges.assert_any_call("entity1", unittest.mock.ANY, 1)
|
||||
query.follow_edges.assert_any_call("entity2", unittest.mock.ANY, 1)
|
||||
|
||||
# Verify result is list format
|
||||
assert isinstance(result, list)
|
||||
assert len(result) == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_labelgraph_method(self):
|
||||
"""Test Query.get_labelgraph method converts entities to labels"""
|
||||
# Create mock Query
|
||||
mock_rag = MagicMock()
|
||||
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
verbose=False,
|
||||
max_subgraph_size=100
|
||||
)
|
||||
|
||||
# Mock get_subgraph to return test triples
|
||||
test_subgraph = [
|
||||
("entity1", "predicate1", "object1"),
|
||||
("subject2", "http://www.w3.org/2000/01/rdf-schema#label", "Label Value"), # Should be filtered
|
||||
("entity3", "predicate3", "object3")
|
||||
]
|
||||
query.get_subgraph = AsyncMock(return_value=test_subgraph)
|
||||
|
||||
# Mock maybe_label to return human-readable labels
|
||||
async def mock_maybe_label(entity):
|
||||
label_map = {
|
||||
"entity1": "Human Entity One",
|
||||
"predicate1": "Human Predicate One",
|
||||
"object1": "Human Object One",
|
||||
"entity3": "Human Entity Three",
|
||||
"predicate3": "Human Predicate Three",
|
||||
"object3": "Human Object Three"
|
||||
}
|
||||
return label_map.get(entity, entity)
|
||||
|
||||
query.maybe_label = AsyncMock(side_effect=mock_maybe_label)
|
||||
|
||||
# Call get_labelgraph
|
||||
result = await query.get_labelgraph("test query")
|
||||
|
||||
# Verify get_subgraph was called
|
||||
query.get_subgraph.assert_called_once_with("test query")
|
||||
|
||||
# Verify label triples are filtered out
|
||||
assert len(result) == 2 # Label triple should be excluded
|
||||
|
||||
# Verify maybe_label was called for non-label triples
|
||||
expected_calls = [
|
||||
(("entity1",), {}), (("predicate1",), {}), (("object1",), {}),
|
||||
(("entity3",), {}), (("predicate3",), {}), (("object3",), {})
|
||||
]
|
||||
assert query.maybe_label.call_count == 6
|
||||
|
||||
# Verify result contains human-readable labels
|
||||
expected_result = [
|
||||
("Human Entity One", "Human Predicate One", "Human Object One"),
|
||||
("Human Entity Three", "Human Predicate Three", "Human Object Three")
|
||||
]
|
||||
assert result == expected_result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_graph_rag_query_method(self):
|
||||
"""Test GraphRag.query method orchestrates full RAG pipeline"""
|
||||
# Create mock clients
|
||||
mock_prompt_client = AsyncMock()
|
||||
mock_embeddings_client = AsyncMock()
|
||||
mock_graph_embeddings_client = AsyncMock()
|
||||
mock_triples_client = AsyncMock()
|
||||
|
||||
# Mock prompt client response
|
||||
expected_response = "This is the RAG response"
|
||||
mock_prompt_client.kg_prompt.return_value = expected_response
|
||||
|
||||
# Initialize GraphRag
|
||||
graph_rag = GraphRag(
|
||||
prompt_client=mock_prompt_client,
|
||||
embeddings_client=mock_embeddings_client,
|
||||
graph_embeddings_client=mock_graph_embeddings_client,
|
||||
triples_client=mock_triples_client,
|
||||
verbose=False
|
||||
)
|
||||
|
||||
# Mock the Query class behavior by patching get_labelgraph
|
||||
test_labelgraph = [("Subject", "Predicate", "Object")]
|
||||
|
||||
# We need to patch the Query class's get_labelgraph method
|
||||
original_query_init = Query.__init__
|
||||
original_get_labelgraph = Query.get_labelgraph
|
||||
|
||||
def mock_query_init(self, *args, **kwargs):
|
||||
original_query_init(self, *args, **kwargs)
|
||||
|
||||
async def mock_get_labelgraph(self, query_text):
|
||||
return test_labelgraph
|
||||
|
||||
Query.__init__ = mock_query_init
|
||||
Query.get_labelgraph = mock_get_labelgraph
|
||||
|
||||
try:
|
||||
# Call GraphRag.query
|
||||
result = await graph_rag.query(
|
||||
query="test query",
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
entity_limit=25,
|
||||
triple_limit=15
|
||||
)
|
||||
|
||||
# Verify prompt client was called with knowledge graph and query
|
||||
mock_prompt_client.kg_prompt.assert_called_once_with("test query", test_labelgraph)
|
||||
|
||||
# Verify result
|
||||
assert result == expected_response
|
||||
|
||||
finally:
|
||||
# Restore original methods
|
||||
Query.__init__ = original_query_init
|
||||
Query.get_labelgraph = original_get_labelgraph
|
||||
277
tests/unit/test_rev_gateway/test_dispatcher.py
Normal file
277
tests/unit/test_rev_gateway/test_dispatcher.py
Normal file
|
|
@ -0,0 +1,277 @@
|
|||
"""
|
||||
Tests for Reverse Gateway Dispatcher
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, AsyncMock, patch
|
||||
|
||||
from trustgraph.rev_gateway.dispatcher import WebSocketResponder, MessageDispatcher
|
||||
|
||||
|
||||
class TestWebSocketResponder:
|
||||
"""Test cases for WebSocketResponder class"""
|
||||
|
||||
def test_websocket_responder_initialization(self):
|
||||
"""Test WebSocketResponder initialization"""
|
||||
responder = WebSocketResponder()
|
||||
|
||||
assert responder.response is None
|
||||
assert responder.completed is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_websocket_responder_send_method(self):
|
||||
"""Test WebSocketResponder send method"""
|
||||
responder = WebSocketResponder()
|
||||
|
||||
test_response = {"data": "test response"}
|
||||
|
||||
# Call send method
|
||||
await responder.send(test_response)
|
||||
|
||||
# Verify response was stored
|
||||
assert responder.response == test_response
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_websocket_responder_call_method(self):
|
||||
"""Test WebSocketResponder __call__ method"""
|
||||
responder = WebSocketResponder()
|
||||
|
||||
test_response = {"result": "success"}
|
||||
test_completed = True
|
||||
|
||||
# Call the responder
|
||||
await responder(test_response, test_completed)
|
||||
|
||||
# Verify response and completed status were set
|
||||
assert responder.response == test_response
|
||||
assert responder.completed == test_completed
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_websocket_responder_call_method_with_false_completion(self):
|
||||
"""Test WebSocketResponder __call__ method with incomplete response"""
|
||||
responder = WebSocketResponder()
|
||||
|
||||
test_response = {"partial": "data"}
|
||||
test_completed = False
|
||||
|
||||
# Call the responder
|
||||
await responder(test_response, test_completed)
|
||||
|
||||
# Verify response was set and completed is True (since send() always sets completed=True)
|
||||
assert responder.response == test_response
|
||||
assert responder.completed is True
|
||||
|
||||
|
||||
class TestMessageDispatcher:
|
||||
"""Test cases for MessageDispatcher class"""
|
||||
|
||||
def test_message_dispatcher_initialization_with_defaults(self):
|
||||
"""Test MessageDispatcher initialization with default parameters"""
|
||||
dispatcher = MessageDispatcher()
|
||||
|
||||
assert dispatcher.max_workers == 10
|
||||
assert dispatcher.semaphore._value == 10
|
||||
assert dispatcher.active_tasks == set()
|
||||
assert dispatcher.pulsar_client is None
|
||||
assert dispatcher.dispatcher_manager is None
|
||||
assert len(dispatcher.service_mapping) > 0
|
||||
|
||||
def test_message_dispatcher_initialization_with_custom_workers(self):
|
||||
"""Test MessageDispatcher initialization with custom max_workers"""
|
||||
dispatcher = MessageDispatcher(max_workers=5)
|
||||
|
||||
assert dispatcher.max_workers == 5
|
||||
assert dispatcher.semaphore._value == 5
|
||||
|
||||
@patch('trustgraph.rev_gateway.dispatcher.DispatcherManager')
|
||||
def test_message_dispatcher_initialization_with_pulsar_client(self, mock_dispatcher_manager):
|
||||
"""Test MessageDispatcher initialization with pulsar_client and config_receiver"""
|
||||
mock_pulsar_client = MagicMock()
|
||||
mock_config_receiver = MagicMock()
|
||||
mock_dispatcher_instance = MagicMock()
|
||||
mock_dispatcher_manager.return_value = mock_dispatcher_instance
|
||||
|
||||
dispatcher = MessageDispatcher(
|
||||
max_workers=8,
|
||||
config_receiver=mock_config_receiver,
|
||||
pulsar_client=mock_pulsar_client
|
||||
)
|
||||
|
||||
assert dispatcher.max_workers == 8
|
||||
assert dispatcher.pulsar_client == mock_pulsar_client
|
||||
assert dispatcher.dispatcher_manager == mock_dispatcher_instance
|
||||
mock_dispatcher_manager.assert_called_once_with(
|
||||
mock_pulsar_client, mock_config_receiver, prefix="rev-gateway"
|
||||
)
|
||||
|
||||
def test_message_dispatcher_service_mapping(self):
|
||||
"""Test MessageDispatcher service mapping contains expected services"""
|
||||
dispatcher = MessageDispatcher()
|
||||
|
||||
expected_services = [
|
||||
"text-completion", "graph-rag", "agent", "embeddings",
|
||||
"graph-embeddings", "triples", "document-load", "text-load",
|
||||
"flow", "knowledge", "config", "librarian", "document-rag"
|
||||
]
|
||||
|
||||
for service in expected_services:
|
||||
assert service in dispatcher.service_mapping
|
||||
|
||||
# Test specific mappings
|
||||
assert dispatcher.service_mapping["text-completion"] == "text-completion"
|
||||
assert dispatcher.service_mapping["document-load"] == "document"
|
||||
assert dispatcher.service_mapping["text-load"] == "text-document"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_message_dispatcher_handle_message_without_dispatcher_manager(self):
|
||||
"""Test MessageDispatcher handle_message without dispatcher manager"""
|
||||
dispatcher = MessageDispatcher()
|
||||
|
||||
test_message = {
|
||||
"id": "test-123",
|
||||
"service": "test-service",
|
||||
"request": {"data": "test"}
|
||||
}
|
||||
|
||||
result = await dispatcher.handle_message(test_message)
|
||||
|
||||
assert result["id"] == "test-123"
|
||||
assert "error" in result["response"]
|
||||
assert "DispatcherManager not available" in result["response"]["error"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_message_dispatcher_handle_message_with_exception(self):
|
||||
"""Test MessageDispatcher handle_message with exception during processing"""
|
||||
mock_dispatcher_manager = MagicMock()
|
||||
mock_dispatcher_manager.invoke_global_service = AsyncMock(side_effect=Exception("Test error"))
|
||||
|
||||
dispatcher = MessageDispatcher()
|
||||
dispatcher.dispatcher_manager = mock_dispatcher_manager
|
||||
|
||||
test_message = {
|
||||
"id": "test-456",
|
||||
"service": "text-completion",
|
||||
"request": {"prompt": "test"}
|
||||
}
|
||||
|
||||
with patch('trustgraph.gateway.dispatch.manager.global_dispatchers', {"text-completion": True}):
|
||||
result = await dispatcher.handle_message(test_message)
|
||||
|
||||
assert result["id"] == "test-456"
|
||||
assert "error" in result["response"]
|
||||
assert "Test error" in result["response"]["error"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_message_dispatcher_handle_message_global_service(self):
|
||||
"""Test MessageDispatcher handle_message with global service"""
|
||||
mock_dispatcher_manager = MagicMock()
|
||||
mock_dispatcher_manager.invoke_global_service = AsyncMock()
|
||||
mock_responder = MagicMock()
|
||||
mock_responder.completed = True
|
||||
mock_responder.response = {"result": "success"}
|
||||
|
||||
dispatcher = MessageDispatcher()
|
||||
dispatcher.dispatcher_manager = mock_dispatcher_manager
|
||||
|
||||
test_message = {
|
||||
"id": "test-789",
|
||||
"service": "text-completion",
|
||||
"request": {"prompt": "hello"}
|
||||
}
|
||||
|
||||
with patch('trustgraph.gateway.dispatch.manager.global_dispatchers', {"text-completion": True}):
|
||||
with patch('trustgraph.rev_gateway.dispatcher.WebSocketResponder', return_value=mock_responder):
|
||||
result = await dispatcher.handle_message(test_message)
|
||||
|
||||
assert result["id"] == "test-789"
|
||||
assert result["response"] == {"result": "success"}
|
||||
mock_dispatcher_manager.invoke_global_service.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_message_dispatcher_handle_message_flow_service(self):
|
||||
"""Test MessageDispatcher handle_message with flow service"""
|
||||
mock_dispatcher_manager = MagicMock()
|
||||
mock_dispatcher_manager.invoke_flow_service = AsyncMock()
|
||||
mock_responder = MagicMock()
|
||||
mock_responder.completed = True
|
||||
mock_responder.response = {"data": "flow_result"}
|
||||
|
||||
dispatcher = MessageDispatcher()
|
||||
dispatcher.dispatcher_manager = mock_dispatcher_manager
|
||||
|
||||
test_message = {
|
||||
"id": "test-flow-123",
|
||||
"service": "document-rag",
|
||||
"request": {"query": "test"},
|
||||
"flow": "custom-flow"
|
||||
}
|
||||
|
||||
with patch('trustgraph.gateway.dispatch.manager.global_dispatchers', {}):
|
||||
with patch('trustgraph.rev_gateway.dispatcher.WebSocketResponder', return_value=mock_responder):
|
||||
result = await dispatcher.handle_message(test_message)
|
||||
|
||||
assert result["id"] == "test-flow-123"
|
||||
assert result["response"] == {"data": "flow_result"}
|
||||
mock_dispatcher_manager.invoke_flow_service.assert_called_once_with(
|
||||
{"query": "test"}, mock_responder, "custom-flow", "document-rag"
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_message_dispatcher_handle_message_incomplete_response(self):
|
||||
"""Test MessageDispatcher handle_message with incomplete response"""
|
||||
mock_dispatcher_manager = MagicMock()
|
||||
mock_dispatcher_manager.invoke_flow_service = AsyncMock()
|
||||
mock_responder = MagicMock()
|
||||
mock_responder.completed = False
|
||||
mock_responder.response = None
|
||||
|
||||
dispatcher = MessageDispatcher()
|
||||
dispatcher.dispatcher_manager = mock_dispatcher_manager
|
||||
|
||||
test_message = {
|
||||
"id": "test-incomplete",
|
||||
"service": "agent",
|
||||
"request": {"input": "test"}
|
||||
}
|
||||
|
||||
with patch('trustgraph.gateway.dispatch.manager.global_dispatchers', {}):
|
||||
with patch('trustgraph.rev_gateway.dispatcher.WebSocketResponder', return_value=mock_responder):
|
||||
result = await dispatcher.handle_message(test_message)
|
||||
|
||||
assert result["id"] == "test-incomplete"
|
||||
assert result["response"] == {"error": "No response received"}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_message_dispatcher_shutdown(self):
|
||||
"""Test MessageDispatcher shutdown method"""
|
||||
import asyncio
|
||||
|
||||
dispatcher = MessageDispatcher()
|
||||
|
||||
# Create actual async tasks
|
||||
async def dummy_task():
|
||||
await asyncio.sleep(0.01)
|
||||
return "done"
|
||||
|
||||
task1 = asyncio.create_task(dummy_task())
|
||||
task2 = asyncio.create_task(dummy_task())
|
||||
dispatcher.active_tasks = {task1, task2}
|
||||
|
||||
# Call shutdown
|
||||
await dispatcher.shutdown()
|
||||
|
||||
# Verify tasks were completed
|
||||
assert task1.done()
|
||||
assert task2.done()
|
||||
assert len(dispatcher.active_tasks) == 2 # Tasks remain in set but are completed
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_message_dispatcher_shutdown_with_no_tasks(self):
|
||||
"""Test MessageDispatcher shutdown with no active tasks"""
|
||||
dispatcher = MessageDispatcher()
|
||||
|
||||
# Call shutdown with no active tasks
|
||||
await dispatcher.shutdown()
|
||||
|
||||
# Should complete without error
|
||||
assert dispatcher.active_tasks == set()
|
||||
545
tests/unit/test_rev_gateway/test_rev_gateway_service.py
Normal file
545
tests/unit/test_rev_gateway/test_rev_gateway_service.py
Normal file
|
|
@ -0,0 +1,545 @@
|
|||
"""
|
||||
Tests for Reverse Gateway Service
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import asyncio
|
||||
from unittest.mock import MagicMock, AsyncMock, patch, Mock
|
||||
from aiohttp import WSMsgType, ClientWebSocketResponse
|
||||
import json
|
||||
|
||||
from trustgraph.rev_gateway.service import ReverseGateway, parse_args, run
|
||||
|
||||
|
||||
class TestReverseGateway:
|
||||
"""Test cases for ReverseGateway class"""
|
||||
|
||||
@patch('trustgraph.rev_gateway.service.ConfigReceiver')
|
||||
@patch('trustgraph.rev_gateway.service.MessageDispatcher')
|
||||
@patch('pulsar.Client')
|
||||
def test_reverse_gateway_initialization_defaults(self, mock_pulsar_client, mock_dispatcher, mock_config_receiver):
|
||||
"""Test ReverseGateway initialization with default parameters"""
|
||||
mock_client_instance = MagicMock()
|
||||
mock_pulsar_client.return_value = mock_client_instance
|
||||
|
||||
gateway = ReverseGateway()
|
||||
|
||||
assert gateway.websocket_uri == "ws://localhost:7650/out"
|
||||
assert gateway.host == "localhost"
|
||||
assert gateway.port == 7650
|
||||
assert gateway.scheme == "ws"
|
||||
assert gateway.path == "/out"
|
||||
assert gateway.url == "ws://localhost:7650/out"
|
||||
assert gateway.max_workers == 10
|
||||
assert gateway.running is False
|
||||
assert gateway.reconnect_delay == 3.0
|
||||
assert gateway.pulsar_host == "pulsar://pulsar:6650"
|
||||
assert gateway.pulsar_api_key is None
|
||||
|
||||
@patch('trustgraph.rev_gateway.service.ConfigReceiver')
|
||||
@patch('trustgraph.rev_gateway.service.MessageDispatcher')
|
||||
@patch('pulsar.Client')
|
||||
def test_reverse_gateway_initialization_custom_params(self, mock_pulsar_client, mock_dispatcher, mock_config_receiver):
|
||||
"""Test ReverseGateway initialization with custom parameters"""
|
||||
mock_client_instance = MagicMock()
|
||||
mock_pulsar_client.return_value = mock_client_instance
|
||||
|
||||
gateway = ReverseGateway(
|
||||
websocket_uri="wss://example.com:8080/websocket",
|
||||
max_workers=20,
|
||||
pulsar_host="pulsar://custom:6650",
|
||||
pulsar_api_key="test-key",
|
||||
pulsar_listener="test-listener"
|
||||
)
|
||||
|
||||
assert gateway.websocket_uri == "wss://example.com:8080/websocket"
|
||||
assert gateway.host == "example.com"
|
||||
assert gateway.port == 8080
|
||||
assert gateway.scheme == "wss"
|
||||
assert gateway.path == "/websocket"
|
||||
assert gateway.url == "wss://example.com:8080/websocket"
|
||||
assert gateway.max_workers == 20
|
||||
assert gateway.pulsar_host == "pulsar://custom:6650"
|
||||
assert gateway.pulsar_api_key == "test-key"
|
||||
assert gateway.pulsar_listener == "test-listener"
|
||||
|
||||
@patch('trustgraph.rev_gateway.service.ConfigReceiver')
|
||||
@patch('trustgraph.rev_gateway.service.MessageDispatcher')
|
||||
@patch('pulsar.Client')
|
||||
def test_reverse_gateway_initialization_with_missing_path(self, mock_pulsar_client, mock_dispatcher, mock_config_receiver):
|
||||
"""Test ReverseGateway initialization with WebSocket URI missing path"""
|
||||
mock_client_instance = MagicMock()
|
||||
mock_pulsar_client.return_value = mock_client_instance
|
||||
|
||||
gateway = ReverseGateway(websocket_uri="ws://example.com")
|
||||
|
||||
assert gateway.path == "/ws"
|
||||
assert gateway.url == "ws://example.com/ws"
|
||||
|
||||
@patch('trustgraph.rev_gateway.service.ConfigReceiver')
|
||||
@patch('trustgraph.rev_gateway.service.MessageDispatcher')
|
||||
@patch('pulsar.Client')
|
||||
def test_reverse_gateway_initialization_invalid_scheme(self, mock_pulsar_client, mock_dispatcher, mock_config_receiver):
|
||||
"""Test ReverseGateway initialization with invalid WebSocket scheme"""
|
||||
with pytest.raises(ValueError, match="WebSocket URI must use ws:// or wss:// scheme"):
|
||||
ReverseGateway(websocket_uri="http://example.com")
|
||||
|
||||
@patch('trustgraph.rev_gateway.service.ConfigReceiver')
|
||||
@patch('trustgraph.rev_gateway.service.MessageDispatcher')
|
||||
@patch('pulsar.Client')
|
||||
def test_reverse_gateway_initialization_missing_hostname(self, mock_pulsar_client, mock_dispatcher, mock_config_receiver):
|
||||
"""Test ReverseGateway initialization with missing hostname"""
|
||||
with pytest.raises(ValueError, match="WebSocket URI must include hostname"):
|
||||
ReverseGateway(websocket_uri="ws://")
|
||||
|
||||
@patch('trustgraph.rev_gateway.service.ConfigReceiver')
|
||||
@patch('trustgraph.rev_gateway.service.MessageDispatcher')
|
||||
@patch('pulsar.Client')
|
||||
def test_reverse_gateway_pulsar_client_with_auth(self, mock_pulsar_client, mock_dispatcher, mock_config_receiver):
|
||||
"""Test ReverseGateway creates Pulsar client with authentication"""
|
||||
mock_client_instance = MagicMock()
|
||||
mock_pulsar_client.return_value = mock_client_instance
|
||||
|
||||
with patch('pulsar.AuthenticationToken') as mock_auth:
|
||||
mock_auth_instance = MagicMock()
|
||||
mock_auth.return_value = mock_auth_instance
|
||||
|
||||
gateway = ReverseGateway(
|
||||
pulsar_api_key="test-key",
|
||||
pulsar_listener="test-listener"
|
||||
)
|
||||
|
||||
mock_auth.assert_called_once_with("test-key")
|
||||
mock_pulsar_client.assert_called_once_with(
|
||||
"pulsar://pulsar:6650",
|
||||
listener_name="test-listener",
|
||||
authentication=mock_auth_instance
|
||||
)
|
||||
|
||||
@patch('trustgraph.rev_gateway.service.ConfigReceiver')
|
||||
@patch('trustgraph.rev_gateway.service.MessageDispatcher')
|
||||
@patch('pulsar.Client')
|
||||
@patch('trustgraph.rev_gateway.service.ClientSession')
|
||||
@pytest.mark.asyncio
|
||||
async def test_reverse_gateway_connect_success(self, mock_session_class, mock_pulsar_client, mock_dispatcher, mock_config_receiver):
|
||||
"""Test ReverseGateway successful connection"""
|
||||
mock_client_instance = MagicMock()
|
||||
mock_pulsar_client.return_value = mock_client_instance
|
||||
|
||||
mock_session = AsyncMock()
|
||||
mock_ws = AsyncMock()
|
||||
mock_session.ws_connect.return_value = mock_ws
|
||||
mock_session_class.return_value = mock_session
|
||||
|
||||
gateway = ReverseGateway()
|
||||
|
||||
result = await gateway.connect()
|
||||
|
||||
assert result is True
|
||||
assert gateway.session == mock_session
|
||||
assert gateway.ws == mock_ws
|
||||
mock_session.ws_connect.assert_called_once_with(gateway.url)
|
||||
|
||||
@patch('trustgraph.rev_gateway.service.ConfigReceiver')
|
||||
@patch('trustgraph.rev_gateway.service.MessageDispatcher')
|
||||
@patch('pulsar.Client')
|
||||
@patch('trustgraph.rev_gateway.service.ClientSession')
|
||||
@pytest.mark.asyncio
|
||||
async def test_reverse_gateway_connect_failure(self, mock_session_class, mock_pulsar_client, mock_dispatcher, mock_config_receiver):
|
||||
"""Test ReverseGateway connection failure"""
|
||||
mock_client_instance = MagicMock()
|
||||
mock_pulsar_client.return_value = mock_client_instance
|
||||
|
||||
mock_session = AsyncMock()
|
||||
mock_session.ws_connect.side_effect = Exception("Connection failed")
|
||||
mock_session_class.return_value = mock_session
|
||||
|
||||
gateway = ReverseGateway()
|
||||
|
||||
result = await gateway.connect()
|
||||
|
||||
assert result is False
|
||||
|
||||
@patch('trustgraph.rev_gateway.service.ConfigReceiver')
|
||||
@patch('trustgraph.rev_gateway.service.MessageDispatcher')
|
||||
@patch('pulsar.Client')
|
||||
@pytest.mark.asyncio
|
||||
async def test_reverse_gateway_disconnect(self, mock_pulsar_client, mock_dispatcher, mock_config_receiver):
|
||||
"""Test ReverseGateway disconnect"""
|
||||
mock_client_instance = MagicMock()
|
||||
mock_pulsar_client.return_value = mock_client_instance
|
||||
|
||||
gateway = ReverseGateway()
|
||||
|
||||
# Mock websocket and session
|
||||
mock_ws = AsyncMock()
|
||||
mock_ws.closed = False
|
||||
mock_session = AsyncMock()
|
||||
mock_session.closed = False
|
||||
|
||||
gateway.ws = mock_ws
|
||||
gateway.session = mock_session
|
||||
|
||||
await gateway.disconnect()
|
||||
|
||||
mock_ws.close.assert_called_once()
|
||||
mock_session.close.assert_called_once()
|
||||
assert gateway.ws is None
|
||||
assert gateway.session is None
|
||||
|
||||
@patch('trustgraph.rev_gateway.service.ConfigReceiver')
|
||||
@patch('trustgraph.rev_gateway.service.MessageDispatcher')
|
||||
@patch('pulsar.Client')
|
||||
@pytest.mark.asyncio
|
||||
async def test_reverse_gateway_send_message(self, mock_pulsar_client, mock_dispatcher, mock_config_receiver):
|
||||
"""Test ReverseGateway send message"""
|
||||
mock_client_instance = MagicMock()
|
||||
mock_pulsar_client.return_value = mock_client_instance
|
||||
|
||||
gateway = ReverseGateway()
|
||||
|
||||
# Mock websocket
|
||||
mock_ws = AsyncMock()
|
||||
mock_ws.closed = False
|
||||
gateway.ws = mock_ws
|
||||
|
||||
test_message = {"id": "test", "data": "hello"}
|
||||
|
||||
await gateway.send_message(test_message)
|
||||
|
||||
mock_ws.send_str.assert_called_once_with(json.dumps(test_message))
|
||||
|
||||
@patch('trustgraph.rev_gateway.service.ConfigReceiver')
|
||||
@patch('trustgraph.rev_gateway.service.MessageDispatcher')
|
||||
@patch('pulsar.Client')
|
||||
@pytest.mark.asyncio
|
||||
async def test_reverse_gateway_send_message_closed_connection(self, mock_pulsar_client, mock_dispatcher, mock_config_receiver):
|
||||
"""Test ReverseGateway send message with closed connection"""
|
||||
mock_client_instance = MagicMock()
|
||||
mock_pulsar_client.return_value = mock_client_instance
|
||||
|
||||
gateway = ReverseGateway()
|
||||
|
||||
# Mock closed websocket
|
||||
mock_ws = AsyncMock()
|
||||
mock_ws.closed = True
|
||||
gateway.ws = mock_ws
|
||||
|
||||
test_message = {"id": "test", "data": "hello"}
|
||||
|
||||
await gateway.send_message(test_message)
|
||||
|
||||
# Should not call send_str on closed connection
|
||||
mock_ws.send_str.assert_not_called()
|
||||
|
||||
@patch('trustgraph.rev_gateway.service.ConfigReceiver')
|
||||
@patch('trustgraph.rev_gateway.service.MessageDispatcher')
|
||||
@patch('pulsar.Client')
|
||||
@pytest.mark.asyncio
|
||||
async def test_reverse_gateway_handle_message(self, mock_pulsar_client, mock_dispatcher, mock_config_receiver):
|
||||
"""Test ReverseGateway handle message"""
|
||||
mock_client_instance = MagicMock()
|
||||
mock_pulsar_client.return_value = mock_client_instance
|
||||
|
||||
mock_dispatcher_instance = AsyncMock()
|
||||
mock_dispatcher_instance.handle_message.return_value = {"response": "success"}
|
||||
mock_dispatcher.return_value = mock_dispatcher_instance
|
||||
|
||||
gateway = ReverseGateway()
|
||||
|
||||
# Mock send_message
|
||||
gateway.send_message = AsyncMock()
|
||||
|
||||
test_message = '{"id": "test", "service": "test-service", "request": {"data": "test"}}'
|
||||
|
||||
await gateway.handle_message(test_message)
|
||||
|
||||
mock_dispatcher_instance.handle_message.assert_called_once_with({
|
||||
"id": "test",
|
||||
"service": "test-service",
|
||||
"request": {"data": "test"}
|
||||
})
|
||||
gateway.send_message.assert_called_once_with({"response": "success"})
|
||||
|
||||
@patch('trustgraph.rev_gateway.service.ConfigReceiver')
|
||||
@patch('trustgraph.rev_gateway.service.MessageDispatcher')
|
||||
@patch('pulsar.Client')
|
||||
@pytest.mark.asyncio
|
||||
async def test_reverse_gateway_handle_message_invalid_json(self, mock_pulsar_client, mock_dispatcher, mock_config_receiver):
|
||||
"""Test ReverseGateway handle message with invalid JSON"""
|
||||
mock_client_instance = MagicMock()
|
||||
mock_pulsar_client.return_value = mock_client_instance
|
||||
|
||||
gateway = ReverseGateway()
|
||||
|
||||
# Mock send_message
|
||||
gateway.send_message = AsyncMock()
|
||||
|
||||
test_message = 'invalid json'
|
||||
|
||||
# Should not raise exception
|
||||
await gateway.handle_message(test_message)
|
||||
|
||||
# Should not call send_message due to error
|
||||
gateway.send_message.assert_not_called()
|
||||
|
||||
@patch('trustgraph.rev_gateway.service.ConfigReceiver')
|
||||
@patch('trustgraph.rev_gateway.service.MessageDispatcher')
|
||||
@patch('pulsar.Client')
|
||||
@pytest.mark.asyncio
|
||||
async def test_reverse_gateway_listen_text_message(self, mock_pulsar_client, mock_dispatcher, mock_config_receiver):
|
||||
"""Test ReverseGateway listen with text message"""
|
||||
mock_client_instance = MagicMock()
|
||||
mock_pulsar_client.return_value = mock_client_instance
|
||||
|
||||
gateway = ReverseGateway()
|
||||
gateway.running = True
|
||||
|
||||
# Mock websocket
|
||||
mock_ws = AsyncMock()
|
||||
mock_ws.closed = False
|
||||
gateway.ws = mock_ws
|
||||
|
||||
# Mock handle_message
|
||||
gateway.handle_message = AsyncMock()
|
||||
|
||||
# Mock message
|
||||
mock_msg = MagicMock()
|
||||
mock_msg.type = WSMsgType.TEXT
|
||||
mock_msg.data = '{"test": "message"}'
|
||||
|
||||
# Mock receive to return message once, then raise exception to stop loop
|
||||
mock_ws.receive.side_effect = [mock_msg, Exception("Test stop")]
|
||||
|
||||
# listen() catches exceptions and breaks, so no exception should be raised
|
||||
await gateway.listen()
|
||||
|
||||
gateway.handle_message.assert_called_once_with('{"test": "message"}')
|
||||
|
||||
@patch('trustgraph.rev_gateway.service.ConfigReceiver')
|
||||
@patch('trustgraph.rev_gateway.service.MessageDispatcher')
|
||||
@patch('pulsar.Client')
|
||||
@pytest.mark.asyncio
|
||||
async def test_reverse_gateway_listen_binary_message(self, mock_pulsar_client, mock_dispatcher, mock_config_receiver):
|
||||
"""Test ReverseGateway listen with binary message"""
|
||||
mock_client_instance = MagicMock()
|
||||
mock_pulsar_client.return_value = mock_client_instance
|
||||
|
||||
gateway = ReverseGateway()
|
||||
gateway.running = True
|
||||
|
||||
# Mock websocket
|
||||
mock_ws = AsyncMock()
|
||||
mock_ws.closed = False
|
||||
gateway.ws = mock_ws
|
||||
|
||||
# Mock handle_message
|
||||
gateway.handle_message = AsyncMock()
|
||||
|
||||
# Mock message
|
||||
mock_msg = MagicMock()
|
||||
mock_msg.type = WSMsgType.BINARY
|
||||
mock_msg.data = b'{"test": "binary"}'
|
||||
|
||||
# Mock receive to return message once, then raise exception to stop loop
|
||||
mock_ws.receive.side_effect = [mock_msg, Exception("Test stop")]
|
||||
|
||||
# listen() catches exceptions and breaks, so no exception should be raised
|
||||
await gateway.listen()
|
||||
|
||||
gateway.handle_message.assert_called_once_with('{"test": "binary"}')
|
||||
|
||||
@patch('trustgraph.rev_gateway.service.ConfigReceiver')
|
||||
@patch('trustgraph.rev_gateway.service.MessageDispatcher')
|
||||
@patch('pulsar.Client')
|
||||
@pytest.mark.asyncio
|
||||
async def test_reverse_gateway_listen_close_message(self, mock_pulsar_client, mock_dispatcher, mock_config_receiver):
|
||||
"""Test ReverseGateway listen with close message"""
|
||||
mock_client_instance = MagicMock()
|
||||
mock_pulsar_client.return_value = mock_client_instance
|
||||
|
||||
gateway = ReverseGateway()
|
||||
gateway.running = True
|
||||
|
||||
# Mock websocket
|
||||
mock_ws = AsyncMock()
|
||||
mock_ws.closed = False
|
||||
gateway.ws = mock_ws
|
||||
|
||||
# Mock handle_message
|
||||
gateway.handle_message = AsyncMock()
|
||||
|
||||
# Mock message
|
||||
mock_msg = MagicMock()
|
||||
mock_msg.type = WSMsgType.CLOSE
|
||||
|
||||
# Mock receive to return close message
|
||||
mock_ws.receive.return_value = mock_msg
|
||||
|
||||
await gateway.listen()
|
||||
|
||||
# Should not call handle_message for close message
|
||||
gateway.handle_message.assert_not_called()
|
||||
|
||||
@patch('trustgraph.rev_gateway.service.ConfigReceiver')
|
||||
@patch('trustgraph.rev_gateway.service.MessageDispatcher')
|
||||
@patch('pulsar.Client')
|
||||
@pytest.mark.asyncio
|
||||
async def test_reverse_gateway_shutdown(self, mock_pulsar_client, mock_dispatcher, mock_config_receiver):
|
||||
"""Test ReverseGateway shutdown"""
|
||||
mock_client_instance = MagicMock()
|
||||
mock_pulsar_client.return_value = mock_client_instance
|
||||
|
||||
mock_dispatcher_instance = AsyncMock()
|
||||
mock_dispatcher.return_value = mock_dispatcher_instance
|
||||
|
||||
gateway = ReverseGateway()
|
||||
gateway.running = True
|
||||
|
||||
# Mock disconnect
|
||||
gateway.disconnect = AsyncMock()
|
||||
|
||||
await gateway.shutdown()
|
||||
|
||||
assert gateway.running is False
|
||||
mock_dispatcher_instance.shutdown.assert_called_once()
|
||||
gateway.disconnect.assert_called_once()
|
||||
mock_client_instance.close.assert_called_once()
|
||||
|
||||
@patch('trustgraph.rev_gateway.service.ConfigReceiver')
|
||||
@patch('trustgraph.rev_gateway.service.MessageDispatcher')
|
||||
@patch('pulsar.Client')
|
||||
def test_reverse_gateway_stop(self, mock_pulsar_client, mock_dispatcher, mock_config_receiver):
|
||||
"""Test ReverseGateway stop"""
|
||||
mock_client_instance = MagicMock()
|
||||
mock_pulsar_client.return_value = mock_client_instance
|
||||
|
||||
gateway = ReverseGateway()
|
||||
gateway.running = True
|
||||
|
||||
gateway.stop()
|
||||
|
||||
assert gateway.running is False
|
||||
|
||||
|
||||
class TestReverseGatewayRun:
|
||||
"""Test cases for ReverseGateway run method"""
|
||||
|
||||
@patch('trustgraph.rev_gateway.service.ConfigReceiver')
|
||||
@patch('trustgraph.rev_gateway.service.MessageDispatcher')
|
||||
@patch('pulsar.Client')
|
||||
@pytest.mark.asyncio
|
||||
async def test_reverse_gateway_run_successful_cycle(self, mock_pulsar_client, mock_dispatcher, mock_config_receiver):
|
||||
"""Test ReverseGateway run method with successful connect/listen cycle"""
|
||||
mock_client_instance = MagicMock()
|
||||
mock_pulsar_client.return_value = mock_client_instance
|
||||
|
||||
mock_config_receiver_instance = AsyncMock()
|
||||
mock_config_receiver.return_value = mock_config_receiver_instance
|
||||
|
||||
gateway = ReverseGateway()
|
||||
|
||||
# Mock methods
|
||||
gateway.connect = AsyncMock(return_value=True)
|
||||
gateway.listen = AsyncMock()
|
||||
gateway.disconnect = AsyncMock()
|
||||
gateway.shutdown = AsyncMock()
|
||||
|
||||
# Stop after one iteration
|
||||
call_count = 0
|
||||
async def mock_connect():
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count == 1:
|
||||
return True
|
||||
else:
|
||||
gateway.running = False
|
||||
return False
|
||||
|
||||
gateway.connect = mock_connect
|
||||
|
||||
await gateway.run()
|
||||
|
||||
mock_config_receiver_instance.start.assert_called_once()
|
||||
gateway.listen.assert_called_once()
|
||||
# disconnect is called twice: once in the main loop, once in shutdown
|
||||
assert gateway.disconnect.call_count == 2
|
||||
gateway.shutdown.assert_called_once()
|
||||
|
||||
|
||||
class TestReverseGatewayArgs:
|
||||
"""Test cases for argument parsing and run function"""
|
||||
|
||||
def test_parse_args_defaults(self):
|
||||
"""Test parse_args with default values"""
|
||||
import sys
|
||||
|
||||
# Mock sys.argv
|
||||
original_argv = sys.argv
|
||||
sys.argv = ['reverse-gateway']
|
||||
|
||||
try:
|
||||
args = parse_args()
|
||||
|
||||
assert args.websocket_uri is None
|
||||
assert args.max_workers == 10
|
||||
assert args.pulsar_host is None
|
||||
assert args.pulsar_api_key is None
|
||||
assert args.pulsar_listener is None
|
||||
finally:
|
||||
sys.argv = original_argv
|
||||
|
||||
def test_parse_args_custom_values(self):
|
||||
"""Test parse_args with custom values"""
|
||||
import sys
|
||||
|
||||
# Mock sys.argv
|
||||
original_argv = sys.argv
|
||||
sys.argv = [
|
||||
'reverse-gateway',
|
||||
'--websocket-uri', 'ws://custom:8080/ws',
|
||||
'--max-workers', '20',
|
||||
'--pulsar-host', 'pulsar://custom:6650',
|
||||
'--pulsar-api-key', 'test-key',
|
||||
'--pulsar-listener', 'test-listener'
|
||||
]
|
||||
|
||||
try:
|
||||
args = parse_args()
|
||||
|
||||
assert args.websocket_uri == 'ws://custom:8080/ws'
|
||||
assert args.max_workers == 20
|
||||
assert args.pulsar_host == 'pulsar://custom:6650'
|
||||
assert args.pulsar_api_key == 'test-key'
|
||||
assert args.pulsar_listener == 'test-listener'
|
||||
finally:
|
||||
sys.argv = original_argv
|
||||
|
||||
@patch('trustgraph.rev_gateway.service.ReverseGateway')
|
||||
@patch('asyncio.run')
|
||||
def test_run_function(self, mock_asyncio_run, mock_gateway_class):
|
||||
"""Test run function"""
|
||||
import sys
|
||||
|
||||
# Mock sys.argv
|
||||
original_argv = sys.argv
|
||||
sys.argv = ['reverse-gateway', '--max-workers', '15']
|
||||
|
||||
try:
|
||||
mock_gateway_instance = MagicMock()
|
||||
mock_gateway_instance.url = "ws://localhost:7650/out"
|
||||
mock_gateway_instance.pulsar_host = "pulsar://pulsar:6650"
|
||||
mock_gateway_class.return_value = mock_gateway_instance
|
||||
|
||||
run()
|
||||
|
||||
mock_gateway_class.assert_called_once_with(
|
||||
websocket_uri=None,
|
||||
max_workers=15,
|
||||
pulsar_host=None,
|
||||
pulsar_api_key=None,
|
||||
pulsar_listener=None
|
||||
)
|
||||
mock_asyncio_run.assert_called_once_with(mock_gateway_instance.run())
|
||||
finally:
|
||||
sys.argv = original_argv
|
||||
162
tests/unit/test_storage/conftest.py
Normal file
162
tests/unit/test_storage/conftest.py
Normal file
|
|
@ -0,0 +1,162 @@
|
|||
"""
|
||||
Shared fixtures for storage tests
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def base_storage_config():
|
||||
"""Base configuration for storage processors"""
|
||||
return {
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-storage-processor'
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def qdrant_storage_config(base_storage_config):
|
||||
"""Configuration for Qdrant storage processors"""
|
||||
return base_storage_config | {
|
||||
'store_uri': 'http://localhost:6333',
|
||||
'api_key': 'test-api-key'
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_qdrant_client():
|
||||
"""Mock Qdrant client"""
|
||||
mock_client = MagicMock()
|
||||
mock_client.collection_exists.return_value = True
|
||||
mock_client.create_collection.return_value = None
|
||||
mock_client.upsert.return_value = None
|
||||
return mock_client
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_uuid():
|
||||
"""Mock UUID generation"""
|
||||
mock_uuid = MagicMock()
|
||||
mock_uuid.uuid4.return_value = MagicMock()
|
||||
mock_uuid.uuid4.return_value.__str__ = MagicMock(return_value='test-uuid-123')
|
||||
return mock_uuid
|
||||
|
||||
|
||||
# Document embeddings fixtures
|
||||
@pytest.fixture
|
||||
def mock_document_embeddings_message():
|
||||
"""Mock document embeddings message"""
|
||||
mock_message = MagicMock()
|
||||
mock_message.metadata.user = 'test_user'
|
||||
mock_message.metadata.collection = 'test_collection'
|
||||
|
||||
mock_chunk = MagicMock()
|
||||
mock_chunk.chunk.decode.return_value = 'test document chunk'
|
||||
mock_chunk.vectors = [[0.1, 0.2, 0.3]]
|
||||
|
||||
mock_message.chunks = [mock_chunk]
|
||||
return mock_message
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_document_embeddings_multiple_chunks():
|
||||
"""Mock document embeddings message with multiple chunks"""
|
||||
mock_message = MagicMock()
|
||||
mock_message.metadata.user = 'multi_user'
|
||||
mock_message.metadata.collection = 'multi_collection'
|
||||
|
||||
mock_chunk1 = MagicMock()
|
||||
mock_chunk1.chunk.decode.return_value = 'first document chunk'
|
||||
mock_chunk1.vectors = [[0.1, 0.2]]
|
||||
|
||||
mock_chunk2 = MagicMock()
|
||||
mock_chunk2.chunk.decode.return_value = 'second document chunk'
|
||||
mock_chunk2.vectors = [[0.3, 0.4]]
|
||||
|
||||
mock_message.chunks = [mock_chunk1, mock_chunk2]
|
||||
return mock_message
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_document_embeddings_multiple_vectors():
|
||||
"""Mock document embeddings message with multiple vectors per chunk"""
|
||||
mock_message = MagicMock()
|
||||
mock_message.metadata.user = 'vector_user'
|
||||
mock_message.metadata.collection = 'vector_collection'
|
||||
|
||||
mock_chunk = MagicMock()
|
||||
mock_chunk.chunk.decode.return_value = 'multi-vector document chunk'
|
||||
mock_chunk.vectors = [
|
||||
[0.1, 0.2, 0.3],
|
||||
[0.4, 0.5, 0.6],
|
||||
[0.7, 0.8, 0.9]
|
||||
]
|
||||
|
||||
mock_message.chunks = [mock_chunk]
|
||||
return mock_message
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_document_embeddings_empty_chunk():
|
||||
"""Mock document embeddings message with empty chunk"""
|
||||
mock_message = MagicMock()
|
||||
mock_message.metadata.user = 'empty_user'
|
||||
mock_message.metadata.collection = 'empty_collection'
|
||||
|
||||
mock_chunk = MagicMock()
|
||||
mock_chunk.chunk.decode.return_value = "" # Empty string
|
||||
mock_chunk.vectors = [[0.1, 0.2]]
|
||||
|
||||
mock_message.chunks = [mock_chunk]
|
||||
return mock_message
|
||||
|
||||
|
||||
# Graph embeddings fixtures
|
||||
@pytest.fixture
|
||||
def mock_graph_embeddings_message():
|
||||
"""Mock graph embeddings message"""
|
||||
mock_message = MagicMock()
|
||||
mock_message.metadata.user = 'test_user'
|
||||
mock_message.metadata.collection = 'test_collection'
|
||||
|
||||
mock_entity = MagicMock()
|
||||
mock_entity.entity.value = 'test_entity'
|
||||
mock_entity.vectors = [[0.1, 0.2, 0.3]]
|
||||
|
||||
mock_message.entities = [mock_entity]
|
||||
return mock_message
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_graph_embeddings_multiple_entities():
|
||||
"""Mock graph embeddings message with multiple entities"""
|
||||
mock_message = MagicMock()
|
||||
mock_message.metadata.user = 'multi_user'
|
||||
mock_message.metadata.collection = 'multi_collection'
|
||||
|
||||
mock_entity1 = MagicMock()
|
||||
mock_entity1.entity.value = 'entity_one'
|
||||
mock_entity1.vectors = [[0.1, 0.2]]
|
||||
|
||||
mock_entity2 = MagicMock()
|
||||
mock_entity2.entity.value = 'entity_two'
|
||||
mock_entity2.vectors = [[0.3, 0.4]]
|
||||
|
||||
mock_message.entities = [mock_entity1, mock_entity2]
|
||||
return mock_message
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_graph_embeddings_empty_entity():
|
||||
"""Mock graph embeddings message with empty entity"""
|
||||
mock_message = MagicMock()
|
||||
mock_message.metadata.user = 'empty_user'
|
||||
mock_message.metadata.collection = 'empty_collection'
|
||||
|
||||
mock_entity = MagicMock()
|
||||
mock_entity.entity.value = "" # Empty string
|
||||
mock_entity.vectors = [[0.1, 0.2]]
|
||||
|
||||
mock_message.entities = [mock_entity]
|
||||
return mock_message
|
||||
569
tests/unit/test_storage/test_doc_embeddings_qdrant_storage.py
Normal file
569
tests/unit/test_storage/test_doc_embeddings_qdrant_storage.py
Normal file
|
|
@ -0,0 +1,569 @@
|
|||
"""
|
||||
Unit tests for trustgraph.storage.doc_embeddings.qdrant.write
|
||||
Testing document embeddings storage functionality
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from unittest import IsolatedAsyncioTestCase
|
||||
|
||||
# Import the service under test
|
||||
from trustgraph.storage.doc_embeddings.qdrant.write import Processor
|
||||
|
||||
|
||||
class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
|
||||
"""Test Qdrant document embeddings storage functionality"""
|
||||
|
||||
@patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient')
|
||||
@patch('trustgraph.base.DocumentEmbeddingsStoreService.__init__')
|
||||
async def test_processor_initialization_basic(self, mock_base_init, mock_qdrant_client):
|
||||
"""Test basic Qdrant processor initialization"""
|
||||
# Arrange
|
||||
mock_base_init.return_value = None
|
||||
mock_qdrant_instance = MagicMock()
|
||||
mock_qdrant_client.return_value = mock_qdrant_instance
|
||||
|
||||
config = {
|
||||
'store_uri': 'http://localhost:6333',
|
||||
'api_key': 'test-api-key',
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-doc-qdrant-processor'
|
||||
}
|
||||
|
||||
# Act
|
||||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
# Verify base class initialization was called
|
||||
mock_base_init.assert_called_once()
|
||||
|
||||
# Verify QdrantClient was created with correct parameters
|
||||
mock_qdrant_client.assert_called_once_with(url='http://localhost:6333', api_key='test-api-key')
|
||||
|
||||
# Verify processor attributes
|
||||
assert hasattr(processor, 'qdrant')
|
||||
assert processor.qdrant == mock_qdrant_instance
|
||||
assert hasattr(processor, 'last_collection')
|
||||
assert processor.last_collection is None
|
||||
|
||||
@patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient')
|
||||
@patch('trustgraph.base.DocumentEmbeddingsStoreService.__init__')
|
||||
async def test_processor_initialization_with_defaults(self, mock_base_init, mock_qdrant_client):
|
||||
"""Test processor initialization with default values"""
|
||||
# Arrange
|
||||
mock_base_init.return_value = None
|
||||
mock_qdrant_instance = MagicMock()
|
||||
mock_qdrant_client.return_value = mock_qdrant_instance
|
||||
|
||||
config = {
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-doc-qdrant-processor'
|
||||
# No store_uri or api_key provided - should use defaults
|
||||
}
|
||||
|
||||
# Act
|
||||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
# Verify QdrantClient was created with default URI and None API key
|
||||
mock_qdrant_client.assert_called_once_with(url='http://localhost:6333', api_key=None)
|
||||
|
||||
@patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient')
|
||||
@patch('trustgraph.storage.doc_embeddings.qdrant.write.uuid')
|
||||
@patch('trustgraph.base.DocumentEmbeddingsStoreService.__init__')
|
||||
async def test_store_document_embeddings_basic(self, mock_base_init, mock_uuid, mock_qdrant_client):
|
||||
"""Test storing document embeddings with basic message"""
|
||||
# Arrange
|
||||
mock_base_init.return_value = None
|
||||
mock_qdrant_instance = MagicMock()
|
||||
mock_qdrant_instance.collection_exists.return_value = True # Collection already exists
|
||||
mock_qdrant_client.return_value = mock_qdrant_instance
|
||||
mock_uuid.uuid4.return_value = MagicMock()
|
||||
mock_uuid.uuid4.return_value.__str__ = MagicMock(return_value='test-uuid-123')
|
||||
|
||||
config = {
|
||||
'store_uri': 'http://localhost:6333',
|
||||
'api_key': 'test-api-key',
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-doc-qdrant-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Create mock message with chunks and vectors
|
||||
mock_message = MagicMock()
|
||||
mock_message.metadata.user = 'test_user'
|
||||
mock_message.metadata.collection = 'test_collection'
|
||||
|
||||
mock_chunk = MagicMock()
|
||||
mock_chunk.chunk.decode.return_value = 'test document chunk'
|
||||
mock_chunk.vectors = [[0.1, 0.2, 0.3]] # Single vector with 3 dimensions
|
||||
|
||||
mock_message.chunks = [mock_chunk]
|
||||
|
||||
# Act
|
||||
await processor.store_document_embeddings(mock_message)
|
||||
|
||||
# Assert
|
||||
# Verify collection existence was checked
|
||||
expected_collection = 'd_test_user_test_collection_3'
|
||||
mock_qdrant_instance.collection_exists.assert_called_once_with(expected_collection)
|
||||
|
||||
# Verify upsert was called
|
||||
mock_qdrant_instance.upsert.assert_called_once()
|
||||
|
||||
# Verify upsert parameters
|
||||
upsert_call_args = mock_qdrant_instance.upsert.call_args
|
||||
assert upsert_call_args[1]['collection_name'] == expected_collection
|
||||
assert len(upsert_call_args[1]['points']) == 1
|
||||
|
||||
point = upsert_call_args[1]['points'][0]
|
||||
assert point.vector == [0.1, 0.2, 0.3]
|
||||
assert point.payload['doc'] == 'test document chunk'
|
||||
|
||||
@patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient')
|
||||
@patch('trustgraph.storage.doc_embeddings.qdrant.write.uuid')
|
||||
@patch('trustgraph.base.DocumentEmbeddingsStoreService.__init__')
|
||||
async def test_store_document_embeddings_multiple_chunks(self, mock_base_init, mock_uuid, mock_qdrant_client):
|
||||
"""Test storing document embeddings with multiple chunks"""
|
||||
# Arrange
|
||||
mock_base_init.return_value = None
|
||||
mock_qdrant_instance = MagicMock()
|
||||
mock_qdrant_instance.collection_exists.return_value = True
|
||||
mock_qdrant_client.return_value = mock_qdrant_instance
|
||||
mock_uuid.uuid4.return_value = MagicMock()
|
||||
mock_uuid.uuid4.return_value.__str__ = MagicMock(return_value='test-uuid')
|
||||
|
||||
config = {
|
||||
'store_uri': 'http://localhost:6333',
|
||||
'api_key': 'test-api-key',
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-doc-qdrant-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Create mock message with multiple chunks
|
||||
mock_message = MagicMock()
|
||||
mock_message.metadata.user = 'multi_user'
|
||||
mock_message.metadata.collection = 'multi_collection'
|
||||
|
||||
mock_chunk1 = MagicMock()
|
||||
mock_chunk1.chunk.decode.return_value = 'first document chunk'
|
||||
mock_chunk1.vectors = [[0.1, 0.2]]
|
||||
|
||||
mock_chunk2 = MagicMock()
|
||||
mock_chunk2.chunk.decode.return_value = 'second document chunk'
|
||||
mock_chunk2.vectors = [[0.3, 0.4]]
|
||||
|
||||
mock_message.chunks = [mock_chunk1, mock_chunk2]
|
||||
|
||||
# Act
|
||||
await processor.store_document_embeddings(mock_message)
|
||||
|
||||
# Assert
|
||||
# Should be called twice (once per chunk)
|
||||
assert mock_qdrant_instance.upsert.call_count == 2
|
||||
|
||||
# Verify both chunks were processed
|
||||
upsert_calls = mock_qdrant_instance.upsert.call_args_list
|
||||
|
||||
# First chunk
|
||||
first_call = upsert_calls[0]
|
||||
first_point = first_call[1]['points'][0]
|
||||
assert first_point.vector == [0.1, 0.2]
|
||||
assert first_point.payload['doc'] == 'first document chunk'
|
||||
|
||||
# Second chunk
|
||||
second_call = upsert_calls[1]
|
||||
second_point = second_call[1]['points'][0]
|
||||
assert second_point.vector == [0.3, 0.4]
|
||||
assert second_point.payload['doc'] == 'second document chunk'
|
||||
|
||||
@patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient')
|
||||
@patch('trustgraph.storage.doc_embeddings.qdrant.write.uuid')
|
||||
@patch('trustgraph.base.DocumentEmbeddingsStoreService.__init__')
|
||||
async def test_store_document_embeddings_multiple_vectors_per_chunk(self, mock_base_init, mock_uuid, mock_qdrant_client):
|
||||
"""Test storing document embeddings with multiple vectors per chunk"""
|
||||
# Arrange
|
||||
mock_base_init.return_value = None
|
||||
mock_qdrant_instance = MagicMock()
|
||||
mock_qdrant_instance.collection_exists.return_value = True
|
||||
mock_qdrant_client.return_value = mock_qdrant_instance
|
||||
mock_uuid.uuid4.return_value = MagicMock()
|
||||
mock_uuid.uuid4.return_value.__str__ = MagicMock(return_value='test-uuid')
|
||||
|
||||
config = {
|
||||
'store_uri': 'http://localhost:6333',
|
||||
'api_key': 'test-api-key',
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-doc-qdrant-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Create mock message with chunk having multiple vectors
|
||||
mock_message = MagicMock()
|
||||
mock_message.metadata.user = 'vector_user'
|
||||
mock_message.metadata.collection = 'vector_collection'
|
||||
|
||||
mock_chunk = MagicMock()
|
||||
mock_chunk.chunk.decode.return_value = 'multi-vector document chunk'
|
||||
mock_chunk.vectors = [
|
||||
[0.1, 0.2, 0.3],
|
||||
[0.4, 0.5, 0.6],
|
||||
[0.7, 0.8, 0.9]
|
||||
]
|
||||
|
||||
mock_message.chunks = [mock_chunk]
|
||||
|
||||
# Act
|
||||
await processor.store_document_embeddings(mock_message)
|
||||
|
||||
# Assert
|
||||
# Should be called 3 times (once per vector)
|
||||
assert mock_qdrant_instance.upsert.call_count == 3
|
||||
|
||||
# Verify all vectors were processed
|
||||
upsert_calls = mock_qdrant_instance.upsert.call_args_list
|
||||
|
||||
expected_vectors = [
|
||||
[0.1, 0.2, 0.3],
|
||||
[0.4, 0.5, 0.6],
|
||||
[0.7, 0.8, 0.9]
|
||||
]
|
||||
|
||||
for i, call in enumerate(upsert_calls):
|
||||
point = call[1]['points'][0]
|
||||
assert point.vector == expected_vectors[i]
|
||||
assert point.payload['doc'] == 'multi-vector document chunk'
|
||||
|
||||
@patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient')
|
||||
@patch('trustgraph.base.DocumentEmbeddingsStoreService.__init__')
|
||||
async def test_store_document_embeddings_empty_chunk(self, mock_base_init, mock_qdrant_client):
|
||||
"""Test storing document embeddings skips empty chunks"""
|
||||
# Arrange
|
||||
mock_base_init.return_value = None
|
||||
mock_qdrant_instance = MagicMock()
|
||||
mock_qdrant_client.return_value = mock_qdrant_instance
|
||||
|
||||
config = {
|
||||
'store_uri': 'http://localhost:6333',
|
||||
'api_key': 'test-api-key',
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-doc-qdrant-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Create mock message with empty chunk
|
||||
mock_message = MagicMock()
|
||||
mock_message.metadata.user = 'empty_user'
|
||||
mock_message.metadata.collection = 'empty_collection'
|
||||
|
||||
mock_chunk_empty = MagicMock()
|
||||
mock_chunk_empty.chunk.decode.return_value = "" # Empty string
|
||||
mock_chunk_empty.vectors = [[0.1, 0.2]]
|
||||
|
||||
mock_message.chunks = [mock_chunk_empty]
|
||||
|
||||
# Act
|
||||
await processor.store_document_embeddings(mock_message)
|
||||
|
||||
# Assert
|
||||
# Should not call upsert for empty chunks
|
||||
mock_qdrant_instance.upsert.assert_not_called()
|
||||
mock_qdrant_instance.collection_exists.assert_not_called()
|
||||
|
||||
@patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient')
|
||||
@patch('trustgraph.base.DocumentEmbeddingsStoreService.__init__')
|
||||
async def test_collection_creation_when_not_exists(self, mock_base_init, mock_qdrant_client):
|
||||
"""Test collection creation when it doesn't exist"""
|
||||
# Arrange
|
||||
mock_base_init.return_value = None
|
||||
mock_qdrant_instance = MagicMock()
|
||||
mock_qdrant_instance.collection_exists.return_value = False # Collection doesn't exist
|
||||
mock_qdrant_client.return_value = mock_qdrant_instance
|
||||
|
||||
config = {
|
||||
'store_uri': 'http://localhost:6333',
|
||||
'api_key': 'test-api-key',
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-doc-qdrant-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Create mock message
|
||||
mock_message = MagicMock()
|
||||
mock_message.metadata.user = 'new_user'
|
||||
mock_message.metadata.collection = 'new_collection'
|
||||
|
||||
mock_chunk = MagicMock()
|
||||
mock_chunk.chunk.decode.return_value = 'test chunk'
|
||||
mock_chunk.vectors = [[0.1, 0.2, 0.3, 0.4, 0.5]] # 5 dimensions
|
||||
|
||||
mock_message.chunks = [mock_chunk]
|
||||
|
||||
# Act
|
||||
await processor.store_document_embeddings(mock_message)
|
||||
|
||||
# Assert
|
||||
expected_collection = 'd_new_user_new_collection_5'
|
||||
|
||||
# Verify collection existence check and creation
|
||||
mock_qdrant_instance.collection_exists.assert_called_once_with(expected_collection)
|
||||
mock_qdrant_instance.create_collection.assert_called_once()
|
||||
|
||||
# Verify create_collection was called with correct parameters
|
||||
create_call_args = mock_qdrant_instance.create_collection.call_args
|
||||
assert create_call_args[1]['collection_name'] == expected_collection
|
||||
|
||||
# Verify upsert was still called after collection creation
|
||||
mock_qdrant_instance.upsert.assert_called_once()
|
||||
|
||||
@patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient')
|
||||
@patch('trustgraph.base.DocumentEmbeddingsStoreService.__init__')
|
||||
async def test_collection_creation_exception(self, mock_base_init, mock_qdrant_client):
|
||||
"""Test collection creation handles exceptions"""
|
||||
# Arrange
|
||||
mock_base_init.return_value = None
|
||||
mock_qdrant_instance = MagicMock()
|
||||
mock_qdrant_instance.collection_exists.return_value = False
|
||||
mock_qdrant_instance.create_collection.side_effect = Exception("Qdrant connection failed")
|
||||
mock_qdrant_client.return_value = mock_qdrant_instance
|
||||
|
||||
config = {
|
||||
'store_uri': 'http://localhost:6333',
|
||||
'api_key': 'test-api-key',
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-doc-qdrant-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Create mock message
|
||||
mock_message = MagicMock()
|
||||
mock_message.metadata.user = 'error_user'
|
||||
mock_message.metadata.collection = 'error_collection'
|
||||
|
||||
mock_chunk = MagicMock()
|
||||
mock_chunk.chunk.decode.return_value = 'test chunk'
|
||||
mock_chunk.vectors = [[0.1, 0.2]]
|
||||
|
||||
mock_message.chunks = [mock_chunk]
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(Exception, match="Qdrant connection failed"):
|
||||
await processor.store_document_embeddings(mock_message)
|
||||
|
||||
@patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient')
|
||||
@patch('trustgraph.base.DocumentEmbeddingsStoreService.__init__')
|
||||
async def test_collection_caching_behavior(self, mock_base_init, mock_qdrant_client):
|
||||
"""Test collection caching with last_collection"""
|
||||
# Arrange
|
||||
mock_base_init.return_value = None
|
||||
mock_qdrant_instance = MagicMock()
|
||||
mock_qdrant_instance.collection_exists.return_value = True
|
||||
mock_qdrant_client.return_value = mock_qdrant_instance
|
||||
|
||||
config = {
|
||||
'store_uri': 'http://localhost:6333',
|
||||
'api_key': 'test-api-key',
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-doc-qdrant-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Create first mock message
|
||||
mock_message1 = MagicMock()
|
||||
mock_message1.metadata.user = 'cache_user'
|
||||
mock_message1.metadata.collection = 'cache_collection'
|
||||
|
||||
mock_chunk1 = MagicMock()
|
||||
mock_chunk1.chunk.decode.return_value = 'first chunk'
|
||||
mock_chunk1.vectors = [[0.1, 0.2, 0.3]]
|
||||
|
||||
mock_message1.chunks = [mock_chunk1]
|
||||
|
||||
# First call
|
||||
await processor.store_document_embeddings(mock_message1)
|
||||
|
||||
# Reset mock to track second call
|
||||
mock_qdrant_instance.reset_mock()
|
||||
|
||||
# Create second mock message with same dimensions
|
||||
mock_message2 = MagicMock()
|
||||
mock_message2.metadata.user = 'cache_user'
|
||||
mock_message2.metadata.collection = 'cache_collection'
|
||||
|
||||
mock_chunk2 = MagicMock()
|
||||
mock_chunk2.chunk.decode.return_value = 'second chunk'
|
||||
mock_chunk2.vectors = [[0.4, 0.5, 0.6]] # Same dimension (3)
|
||||
|
||||
mock_message2.chunks = [mock_chunk2]
|
||||
|
||||
# Act - Second call with same collection
|
||||
await processor.store_document_embeddings(mock_message2)
|
||||
|
||||
# Assert
|
||||
expected_collection = 'd_cache_user_cache_collection_3'
|
||||
assert processor.last_collection == expected_collection
|
||||
|
||||
# Verify second call skipped existence check (cached)
|
||||
mock_qdrant_instance.collection_exists.assert_not_called()
|
||||
mock_qdrant_instance.create_collection.assert_not_called()
|
||||
|
||||
# But upsert should still be called
|
||||
mock_qdrant_instance.upsert.assert_called_once()
|
||||
|
||||
@patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient')
|
||||
@patch('trustgraph.base.DocumentEmbeddingsStoreService.__init__')
|
||||
async def test_different_dimensions_different_collections(self, mock_base_init, mock_qdrant_client):
|
||||
"""Test that different vector dimensions create different collections"""
|
||||
# Arrange
|
||||
mock_base_init.return_value = None
|
||||
mock_qdrant_instance = MagicMock()
|
||||
mock_qdrant_instance.collection_exists.return_value = True
|
||||
mock_qdrant_client.return_value = mock_qdrant_instance
|
||||
|
||||
config = {
|
||||
'store_uri': 'http://localhost:6333',
|
||||
'api_key': 'test-api-key',
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-doc-qdrant-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Create mock message with different dimension vectors
|
||||
mock_message = MagicMock()
|
||||
mock_message.metadata.user = 'dim_user'
|
||||
mock_message.metadata.collection = 'dim_collection'
|
||||
|
||||
mock_chunk = MagicMock()
|
||||
mock_chunk.chunk.decode.return_value = 'dimension test chunk'
|
||||
mock_chunk.vectors = [
|
||||
[0.1, 0.2], # 2 dimensions
|
||||
[0.3, 0.4, 0.5] # 3 dimensions
|
||||
]
|
||||
|
||||
mock_message.chunks = [mock_chunk]
|
||||
|
||||
# Act
|
||||
await processor.store_document_embeddings(mock_message)
|
||||
|
||||
# Assert
|
||||
# Should check existence of both collections
|
||||
expected_collections = ['d_dim_user_dim_collection_2', 'd_dim_user_dim_collection_3']
|
||||
actual_calls = [call.args[0] for call in mock_qdrant_instance.collection_exists.call_args_list]
|
||||
assert actual_calls == expected_collections
|
||||
|
||||
# Should upsert to both collections
|
||||
assert mock_qdrant_instance.upsert.call_count == 2
|
||||
|
||||
upsert_calls = mock_qdrant_instance.upsert.call_args_list
|
||||
assert upsert_calls[0][1]['collection_name'] == 'd_dim_user_dim_collection_2'
|
||||
assert upsert_calls[1][1]['collection_name'] == 'd_dim_user_dim_collection_3'
|
||||
|
||||
@patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient')
|
||||
@patch('trustgraph.base.DocumentEmbeddingsStoreService.__init__')
|
||||
async def test_add_args_calls_parent(self, mock_base_init, mock_qdrant_client):
|
||||
"""Test that add_args() calls parent add_args method"""
|
||||
# Arrange
|
||||
mock_base_init.return_value = None
|
||||
mock_qdrant_client.return_value = MagicMock()
|
||||
mock_parser = MagicMock()
|
||||
|
||||
# Act
|
||||
with patch('trustgraph.base.DocumentEmbeddingsStoreService.add_args') as mock_parent_add_args:
|
||||
Processor.add_args(mock_parser)
|
||||
|
||||
# Assert
|
||||
mock_parent_add_args.assert_called_once_with(mock_parser)
|
||||
|
||||
# Verify processor-specific arguments were added
|
||||
assert mock_parser.add_argument.call_count >= 2 # At least store-uri and api-key
|
||||
|
||||
@patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient')
|
||||
@patch('trustgraph.storage.doc_embeddings.qdrant.write.uuid')
|
||||
@patch('trustgraph.base.DocumentEmbeddingsStoreService.__init__')
|
||||
async def test_utf8_decoding_handling(self, mock_base_init, mock_uuid, mock_qdrant_client):
|
||||
"""Test proper UTF-8 decoding of chunk text"""
|
||||
# Arrange
|
||||
mock_base_init.return_value = None
|
||||
mock_qdrant_instance = MagicMock()
|
||||
mock_qdrant_instance.collection_exists.return_value = True
|
||||
mock_qdrant_client.return_value = mock_qdrant_instance
|
||||
mock_uuid.uuid4.return_value = MagicMock()
|
||||
mock_uuid.uuid4.return_value.__str__ = MagicMock(return_value='test-uuid')
|
||||
|
||||
config = {
|
||||
'store_uri': 'http://localhost:6333',
|
||||
'api_key': 'test-api-key',
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-doc-qdrant-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Create mock message with UTF-8 encoded text
|
||||
mock_message = MagicMock()
|
||||
mock_message.metadata.user = 'utf8_user'
|
||||
mock_message.metadata.collection = 'utf8_collection'
|
||||
|
||||
mock_chunk = MagicMock()
|
||||
mock_chunk.chunk.decode.return_value = 'UTF-8 text with special chars: café, naïve, résumé'
|
||||
mock_chunk.vectors = [[0.1, 0.2]]
|
||||
|
||||
mock_message.chunks = [mock_chunk]
|
||||
|
||||
# Act
|
||||
await processor.store_document_embeddings(mock_message)
|
||||
|
||||
# Assert
|
||||
# Verify chunk.decode was called with 'utf-8'
|
||||
mock_chunk.chunk.decode.assert_called_with('utf-8')
|
||||
|
||||
# Verify the decoded text was stored in payload
|
||||
upsert_call_args = mock_qdrant_instance.upsert.call_args
|
||||
point = upsert_call_args[1]['points'][0]
|
||||
assert point.payload['doc'] == 'UTF-8 text with special chars: café, naïve, résumé'
|
||||
|
||||
@patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient')
|
||||
@patch('trustgraph.base.DocumentEmbeddingsStoreService.__init__')
|
||||
async def test_chunk_decode_exception_handling(self, mock_base_init, mock_qdrant_client):
|
||||
"""Test handling of chunk decode exceptions"""
|
||||
# Arrange
|
||||
mock_base_init.return_value = None
|
||||
mock_qdrant_instance = MagicMock()
|
||||
mock_qdrant_client.return_value = mock_qdrant_instance
|
||||
|
||||
config = {
|
||||
'store_uri': 'http://localhost:6333',
|
||||
'api_key': 'test-api-key',
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-doc-qdrant-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Create mock message with decode error
|
||||
mock_message = MagicMock()
|
||||
mock_message.metadata.user = 'decode_user'
|
||||
mock_message.metadata.collection = 'decode_collection'
|
||||
|
||||
mock_chunk = MagicMock()
|
||||
mock_chunk.chunk.decode.side_effect = UnicodeDecodeError('utf-8', b'', 0, 1, 'invalid start byte')
|
||||
mock_chunk.vectors = [[0.1, 0.2]]
|
||||
|
||||
mock_message.chunks = [mock_chunk]
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(UnicodeDecodeError):
|
||||
await processor.store_document_embeddings(mock_message)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pytest.main([__file__])
|
||||
428
tests/unit/test_storage/test_graph_embeddings_qdrant_storage.py
Normal file
428
tests/unit/test_storage/test_graph_embeddings_qdrant_storage.py
Normal file
|
|
@ -0,0 +1,428 @@
|
|||
"""
|
||||
Unit tests for trustgraph.storage.graph_embeddings.qdrant.write
|
||||
Starting small with a single test to verify basic functionality
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from unittest import IsolatedAsyncioTestCase
|
||||
|
||||
# Import the service under test
|
||||
from trustgraph.storage.graph_embeddings.qdrant.write import Processor
|
||||
|
||||
|
||||
class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase):
|
||||
"""Test Qdrant graph embeddings storage functionality"""
|
||||
|
||||
@patch('trustgraph.storage.graph_embeddings.qdrant.write.QdrantClient')
|
||||
@patch('trustgraph.base.GraphEmbeddingsStoreService.__init__')
|
||||
async def test_processor_initialization_basic(self, mock_base_init, mock_qdrant_client):
|
||||
"""Test basic Qdrant processor initialization"""
|
||||
# Arrange
|
||||
mock_base_init.return_value = None
|
||||
mock_qdrant_instance = MagicMock()
|
||||
mock_qdrant_client.return_value = mock_qdrant_instance
|
||||
|
||||
config = {
|
||||
'store_uri': 'http://localhost:6333',
|
||||
'api_key': 'test-api-key',
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-qdrant-processor'
|
||||
}
|
||||
|
||||
# Act
|
||||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
# Verify base class initialization was called
|
||||
mock_base_init.assert_called_once()
|
||||
|
||||
# Verify QdrantClient was created with correct parameters
|
||||
mock_qdrant_client.assert_called_once_with(url='http://localhost:6333', api_key='test-api-key')
|
||||
|
||||
# Verify processor attributes
|
||||
assert hasattr(processor, 'qdrant')
|
||||
assert processor.qdrant == mock_qdrant_instance
|
||||
assert hasattr(processor, 'last_collection')
|
||||
assert processor.last_collection is None
|
||||
|
||||
@patch('trustgraph.storage.graph_embeddings.qdrant.write.QdrantClient')
|
||||
@patch('trustgraph.base.GraphEmbeddingsStoreService.__init__')
|
||||
async def test_get_collection_creates_new_collection(self, mock_base_init, mock_qdrant_client):
|
||||
"""Test get_collection creates a new collection when it doesn't exist"""
|
||||
# Arrange
|
||||
mock_base_init.return_value = None
|
||||
mock_qdrant_instance = MagicMock()
|
||||
mock_qdrant_instance.collection_exists.return_value = False
|
||||
mock_qdrant_client.return_value = mock_qdrant_instance
|
||||
|
||||
config = {
|
||||
'store_uri': 'http://localhost:6333',
|
||||
'api_key': 'test-api-key',
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-qdrant-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act
|
||||
collection_name = processor.get_collection(dim=512, user='test_user', collection='test_collection')
|
||||
|
||||
# Assert
|
||||
expected_name = 't_test_user_test_collection_512'
|
||||
assert collection_name == expected_name
|
||||
assert processor.last_collection == expected_name
|
||||
|
||||
# Verify collection existence check and creation
|
||||
mock_qdrant_instance.collection_exists.assert_called_once_with(expected_name)
|
||||
mock_qdrant_instance.create_collection.assert_called_once()
|
||||
|
||||
# Verify create_collection was called with correct parameters
|
||||
create_call_args = mock_qdrant_instance.create_collection.call_args
|
||||
assert create_call_args[1]['collection_name'] == expected_name
|
||||
|
||||
@patch('trustgraph.storage.graph_embeddings.qdrant.write.QdrantClient')
|
||||
@patch('trustgraph.storage.graph_embeddings.qdrant.write.uuid')
|
||||
@patch('trustgraph.base.GraphEmbeddingsStoreService.__init__')
|
||||
async def test_store_graph_embeddings_basic(self, mock_base_init, mock_uuid, mock_qdrant_client):
|
||||
"""Test storing graph embeddings with basic message"""
|
||||
# Arrange
|
||||
mock_base_init.return_value = None
|
||||
mock_qdrant_instance = MagicMock()
|
||||
mock_qdrant_instance.collection_exists.return_value = True # Collection already exists
|
||||
mock_qdrant_client.return_value = mock_qdrant_instance
|
||||
mock_uuid.uuid4.return_value.return_value = 'test-uuid-123'
|
||||
|
||||
config = {
|
||||
'store_uri': 'http://localhost:6333',
|
||||
'api_key': 'test-api-key',
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-qdrant-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Create mock message with entities and vectors
|
||||
mock_message = MagicMock()
|
||||
mock_message.metadata.user = 'test_user'
|
||||
mock_message.metadata.collection = 'test_collection'
|
||||
|
||||
mock_entity = MagicMock()
|
||||
mock_entity.entity.value = 'test_entity'
|
||||
mock_entity.vectors = [[0.1, 0.2, 0.3]] # Single vector with 3 dimensions
|
||||
|
||||
mock_message.entities = [mock_entity]
|
||||
|
||||
# Act
|
||||
await processor.store_graph_embeddings(mock_message)
|
||||
|
||||
# Assert
|
||||
# Verify collection existence was checked
|
||||
expected_collection = 't_test_user_test_collection_3'
|
||||
mock_qdrant_instance.collection_exists.assert_called_once_with(expected_collection)
|
||||
|
||||
# Verify upsert was called
|
||||
mock_qdrant_instance.upsert.assert_called_once()
|
||||
|
||||
# Verify upsert parameters
|
||||
upsert_call_args = mock_qdrant_instance.upsert.call_args
|
||||
assert upsert_call_args[1]['collection_name'] == expected_collection
|
||||
assert len(upsert_call_args[1]['points']) == 1
|
||||
|
||||
point = upsert_call_args[1]['points'][0]
|
||||
assert point.vector == [0.1, 0.2, 0.3]
|
||||
assert point.payload['entity'] == 'test_entity'
|
||||
|
||||
@patch('trustgraph.storage.graph_embeddings.qdrant.write.QdrantClient')
|
||||
@patch('trustgraph.base.GraphEmbeddingsStoreService.__init__')
|
||||
async def test_get_collection_uses_existing_collection(self, mock_base_init, mock_qdrant_client):
|
||||
"""Test get_collection uses existing collection without creating new one"""
|
||||
# Arrange
|
||||
mock_base_init.return_value = None
|
||||
mock_qdrant_instance = MagicMock()
|
||||
mock_qdrant_instance.collection_exists.return_value = True # Collection exists
|
||||
mock_qdrant_client.return_value = mock_qdrant_instance
|
||||
|
||||
config = {
|
||||
'store_uri': 'http://localhost:6333',
|
||||
'api_key': 'test-api-key',
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-qdrant-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act
|
||||
collection_name = processor.get_collection(dim=256, user='existing_user', collection='existing_collection')
|
||||
|
||||
# Assert
|
||||
expected_name = 't_existing_user_existing_collection_256'
|
||||
assert collection_name == expected_name
|
||||
assert processor.last_collection == expected_name
|
||||
|
||||
# Verify collection existence check was performed
|
||||
mock_qdrant_instance.collection_exists.assert_called_once_with(expected_name)
|
||||
# Verify create_collection was NOT called
|
||||
mock_qdrant_instance.create_collection.assert_not_called()
|
||||
|
||||
@patch('trustgraph.storage.graph_embeddings.qdrant.write.QdrantClient')
|
||||
@patch('trustgraph.base.GraphEmbeddingsStoreService.__init__')
|
||||
async def test_get_collection_caches_last_collection(self, mock_base_init, mock_qdrant_client):
|
||||
"""Test get_collection skips checks when using same collection"""
|
||||
# Arrange
|
||||
mock_base_init.return_value = None
|
||||
mock_qdrant_instance = MagicMock()
|
||||
mock_qdrant_instance.collection_exists.return_value = True
|
||||
mock_qdrant_client.return_value = mock_qdrant_instance
|
||||
|
||||
config = {
|
||||
'store_uri': 'http://localhost:6333',
|
||||
'api_key': 'test-api-key',
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-qdrant-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# First call
|
||||
collection_name1 = processor.get_collection(dim=128, user='cache_user', collection='cache_collection')
|
||||
|
||||
# Reset mock to track second call
|
||||
mock_qdrant_instance.reset_mock()
|
||||
|
||||
# Act - Second call with same parameters
|
||||
collection_name2 = processor.get_collection(dim=128, user='cache_user', collection='cache_collection')
|
||||
|
||||
# Assert
|
||||
expected_name = 't_cache_user_cache_collection_128'
|
||||
assert collection_name1 == expected_name
|
||||
assert collection_name2 == expected_name
|
||||
|
||||
# Verify second call skipped existence check (cached)
|
||||
mock_qdrant_instance.collection_exists.assert_not_called()
|
||||
mock_qdrant_instance.create_collection.assert_not_called()
|
||||
|
||||
@patch('trustgraph.storage.graph_embeddings.qdrant.write.QdrantClient')
|
||||
@patch('trustgraph.base.GraphEmbeddingsStoreService.__init__')
|
||||
async def test_get_collection_creation_exception(self, mock_base_init, mock_qdrant_client):
|
||||
"""Test get_collection handles collection creation exceptions"""
|
||||
# Arrange
|
||||
mock_base_init.return_value = None
|
||||
mock_qdrant_instance = MagicMock()
|
||||
mock_qdrant_instance.collection_exists.return_value = False
|
||||
mock_qdrant_instance.create_collection.side_effect = Exception("Qdrant connection failed")
|
||||
mock_qdrant_client.return_value = mock_qdrant_instance
|
||||
|
||||
config = {
|
||||
'store_uri': 'http://localhost:6333',
|
||||
'api_key': 'test-api-key',
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-qdrant-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(Exception, match="Qdrant connection failed"):
|
||||
processor.get_collection(dim=512, user='error_user', collection='error_collection')
|
||||
|
||||
@patch('trustgraph.storage.graph_embeddings.qdrant.write.QdrantClient')
|
||||
@patch('trustgraph.storage.graph_embeddings.qdrant.write.uuid')
|
||||
@patch('trustgraph.base.GraphEmbeddingsStoreService.__init__')
|
||||
async def test_store_graph_embeddings_multiple_entities(self, mock_base_init, mock_uuid, mock_qdrant_client):
|
||||
"""Test storing graph embeddings with multiple entities"""
|
||||
# Arrange
|
||||
mock_base_init.return_value = None
|
||||
mock_qdrant_instance = MagicMock()
|
||||
mock_qdrant_instance.collection_exists.return_value = True
|
||||
mock_qdrant_client.return_value = mock_qdrant_instance
|
||||
mock_uuid.uuid4.return_value.return_value = 'test-uuid'
|
||||
|
||||
config = {
|
||||
'store_uri': 'http://localhost:6333',
|
||||
'api_key': 'test-api-key',
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-qdrant-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Create mock message with multiple entities
|
||||
mock_message = MagicMock()
|
||||
mock_message.metadata.user = 'multi_user'
|
||||
mock_message.metadata.collection = 'multi_collection'
|
||||
|
||||
mock_entity1 = MagicMock()
|
||||
mock_entity1.entity.value = 'entity_one'
|
||||
mock_entity1.vectors = [[0.1, 0.2]]
|
||||
|
||||
mock_entity2 = MagicMock()
|
||||
mock_entity2.entity.value = 'entity_two'
|
||||
mock_entity2.vectors = [[0.3, 0.4]]
|
||||
|
||||
mock_message.entities = [mock_entity1, mock_entity2]
|
||||
|
||||
# Act
|
||||
await processor.store_graph_embeddings(mock_message)
|
||||
|
||||
# Assert
|
||||
# Should be called twice (once per entity)
|
||||
assert mock_qdrant_instance.upsert.call_count == 2
|
||||
|
||||
# Verify both entities were processed
|
||||
upsert_calls = mock_qdrant_instance.upsert.call_args_list
|
||||
|
||||
# First entity
|
||||
first_call = upsert_calls[0]
|
||||
first_point = first_call[1]['points'][0]
|
||||
assert first_point.vector == [0.1, 0.2]
|
||||
assert first_point.payload['entity'] == 'entity_one'
|
||||
|
||||
# Second entity
|
||||
second_call = upsert_calls[1]
|
||||
second_point = second_call[1]['points'][0]
|
||||
assert second_point.vector == [0.3, 0.4]
|
||||
assert second_point.payload['entity'] == 'entity_two'
|
||||
|
||||
@patch('trustgraph.storage.graph_embeddings.qdrant.write.QdrantClient')
|
||||
@patch('trustgraph.storage.graph_embeddings.qdrant.write.uuid')
|
||||
@patch('trustgraph.base.GraphEmbeddingsStoreService.__init__')
|
||||
async def test_store_graph_embeddings_multiple_vectors_per_entity(self, mock_base_init, mock_uuid, mock_qdrant_client):
|
||||
"""Test storing graph embeddings with multiple vectors per entity"""
|
||||
# Arrange
|
||||
mock_base_init.return_value = None
|
||||
mock_qdrant_instance = MagicMock()
|
||||
mock_qdrant_instance.collection_exists.return_value = True
|
||||
mock_qdrant_client.return_value = mock_qdrant_instance
|
||||
mock_uuid.uuid4.return_value.return_value = 'test-uuid'
|
||||
|
||||
config = {
|
||||
'store_uri': 'http://localhost:6333',
|
||||
'api_key': 'test-api-key',
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-qdrant-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Create mock message with entity having multiple vectors
|
||||
mock_message = MagicMock()
|
||||
mock_message.metadata.user = 'vector_user'
|
||||
mock_message.metadata.collection = 'vector_collection'
|
||||
|
||||
mock_entity = MagicMock()
|
||||
mock_entity.entity.value = 'multi_vector_entity'
|
||||
mock_entity.vectors = [
|
||||
[0.1, 0.2, 0.3],
|
||||
[0.4, 0.5, 0.6],
|
||||
[0.7, 0.8, 0.9]
|
||||
]
|
||||
|
||||
mock_message.entities = [mock_entity]
|
||||
|
||||
# Act
|
||||
await processor.store_graph_embeddings(mock_message)
|
||||
|
||||
# Assert
|
||||
# Should be called 3 times (once per vector)
|
||||
assert mock_qdrant_instance.upsert.call_count == 3
|
||||
|
||||
# Verify all vectors were processed
|
||||
upsert_calls = mock_qdrant_instance.upsert.call_args_list
|
||||
|
||||
expected_vectors = [
|
||||
[0.1, 0.2, 0.3],
|
||||
[0.4, 0.5, 0.6],
|
||||
[0.7, 0.8, 0.9]
|
||||
]
|
||||
|
||||
for i, call in enumerate(upsert_calls):
|
||||
point = call[1]['points'][0]
|
||||
assert point.vector == expected_vectors[i]
|
||||
assert point.payload['entity'] == 'multi_vector_entity'
|
||||
|
||||
@patch('trustgraph.storage.graph_embeddings.qdrant.write.QdrantClient')
|
||||
@patch('trustgraph.base.GraphEmbeddingsStoreService.__init__')
|
||||
async def test_store_graph_embeddings_empty_entity_value(self, mock_base_init, mock_qdrant_client):
|
||||
"""Test storing graph embeddings skips empty entity values"""
|
||||
# Arrange
|
||||
mock_base_init.return_value = None
|
||||
mock_qdrant_instance = MagicMock()
|
||||
mock_qdrant_client.return_value = mock_qdrant_instance
|
||||
|
||||
config = {
|
||||
'store_uri': 'http://localhost:6333',
|
||||
'api_key': 'test-api-key',
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-qdrant-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Create mock message with empty entity value
|
||||
mock_message = MagicMock()
|
||||
mock_message.metadata.user = 'empty_user'
|
||||
mock_message.metadata.collection = 'empty_collection'
|
||||
|
||||
mock_entity_empty = MagicMock()
|
||||
mock_entity_empty.entity.value = "" # Empty string
|
||||
mock_entity_empty.vectors = [[0.1, 0.2]]
|
||||
|
||||
mock_entity_none = MagicMock()
|
||||
mock_entity_none.entity.value = None # None value
|
||||
mock_entity_none.vectors = [[0.3, 0.4]]
|
||||
|
||||
mock_message.entities = [mock_entity_empty, mock_entity_none]
|
||||
|
||||
# Act
|
||||
await processor.store_graph_embeddings(mock_message)
|
||||
|
||||
# Assert
|
||||
# Should not call upsert for empty entities
|
||||
mock_qdrant_instance.upsert.assert_not_called()
|
||||
mock_qdrant_instance.collection_exists.assert_not_called()
|
||||
|
||||
@patch('trustgraph.storage.graph_embeddings.qdrant.write.QdrantClient')
|
||||
@patch('trustgraph.base.GraphEmbeddingsStoreService.__init__')
|
||||
async def test_processor_initialization_with_defaults(self, mock_base_init, mock_qdrant_client):
|
||||
"""Test processor initialization with default values"""
|
||||
# Arrange
|
||||
mock_base_init.return_value = None
|
||||
mock_qdrant_instance = MagicMock()
|
||||
mock_qdrant_client.return_value = mock_qdrant_instance
|
||||
|
||||
config = {
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-qdrant-processor'
|
||||
# No store_uri or api_key provided - should use defaults
|
||||
}
|
||||
|
||||
# Act
|
||||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
# Verify QdrantClient was created with default URI and None API key
|
||||
mock_qdrant_client.assert_called_once_with(url='http://localhost:6333', api_key=None)
|
||||
|
||||
@patch('trustgraph.storage.graph_embeddings.qdrant.write.QdrantClient')
|
||||
@patch('trustgraph.base.GraphEmbeddingsStoreService.__init__')
|
||||
async def test_add_args_calls_parent(self, mock_base_init, mock_qdrant_client):
|
||||
"""Test that add_args() calls parent add_args method"""
|
||||
# Arrange
|
||||
mock_base_init.return_value = None
|
||||
mock_qdrant_client.return_value = MagicMock()
|
||||
mock_parser = MagicMock()
|
||||
|
||||
# Act
|
||||
with patch('trustgraph.base.GraphEmbeddingsStoreService.add_args') as mock_parent_add_args:
|
||||
Processor.add_args(mock_parser)
|
||||
|
||||
# Assert
|
||||
mock_parent_add_args.assert_called_once_with(mock_parser)
|
||||
|
||||
# Verify processor-specific arguments were added
|
||||
assert mock_parser.add_argument.call_count >= 2 # At least store-uri and api-key
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pytest.main([__file__])
|
||||
373
tests/unit/test_storage/test_triples_cassandra_storage.py
Normal file
373
tests/unit/test_storage/test_triples_cassandra_storage.py
Normal file
|
|
@ -0,0 +1,373 @@
|
|||
"""
|
||||
Tests for Cassandra triples storage service
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch, AsyncMock
|
||||
|
||||
from trustgraph.storage.triples.cassandra.write import Processor
|
||||
from trustgraph.schema import Value, Triple
|
||||
|
||||
|
||||
class TestCassandraStorageProcessor:
|
||||
"""Test cases for Cassandra storage processor"""
|
||||
|
||||
def test_processor_initialization_with_defaults(self):
|
||||
"""Test processor initialization with default parameters"""
|
||||
taskgroup_mock = MagicMock()
|
||||
|
||||
processor = Processor(taskgroup=taskgroup_mock)
|
||||
|
||||
assert processor.graph_host == ['localhost']
|
||||
assert processor.username is None
|
||||
assert processor.password is None
|
||||
assert processor.table is None
|
||||
|
||||
def test_processor_initialization_with_custom_params(self):
|
||||
"""Test processor initialization with custom parameters"""
|
||||
taskgroup_mock = MagicMock()
|
||||
|
||||
processor = Processor(
|
||||
taskgroup=taskgroup_mock,
|
||||
id='custom-storage',
|
||||
graph_host='cassandra.example.com',
|
||||
graph_username='testuser',
|
||||
graph_password='testpass'
|
||||
)
|
||||
|
||||
assert processor.graph_host == ['cassandra.example.com']
|
||||
assert processor.username == 'testuser'
|
||||
assert processor.password == 'testpass'
|
||||
assert processor.table is None
|
||||
|
||||
def test_processor_initialization_with_partial_auth(self):
|
||||
"""Test processor initialization with only username (no password)"""
|
||||
taskgroup_mock = MagicMock()
|
||||
|
||||
processor = Processor(
|
||||
taskgroup=taskgroup_mock,
|
||||
graph_username='testuser'
|
||||
)
|
||||
|
||||
assert processor.username == 'testuser'
|
||||
assert processor.password is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('trustgraph.storage.triples.cassandra.write.TrustGraph')
|
||||
async def test_table_switching_with_auth(self, mock_trustgraph):
|
||||
"""Test table switching logic when authentication is provided"""
|
||||
taskgroup_mock = MagicMock()
|
||||
mock_tg_instance = MagicMock()
|
||||
mock_trustgraph.return_value = mock_tg_instance
|
||||
|
||||
processor = Processor(
|
||||
taskgroup=taskgroup_mock,
|
||||
graph_username='testuser',
|
||||
graph_password='testpass'
|
||||
)
|
||||
|
||||
# Create mock message
|
||||
mock_message = MagicMock()
|
||||
mock_message.metadata.user = 'user1'
|
||||
mock_message.metadata.collection = 'collection1'
|
||||
mock_message.triples = []
|
||||
|
||||
await processor.store_triples(mock_message)
|
||||
|
||||
# Verify TrustGraph was called with auth parameters
|
||||
mock_trustgraph.assert_called_once_with(
|
||||
hosts=['localhost'],
|
||||
keyspace='user1',
|
||||
table='collection1',
|
||||
username='testuser',
|
||||
password='testpass'
|
||||
)
|
||||
assert processor.table == ('user1', 'collection1')
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('trustgraph.storage.triples.cassandra.write.TrustGraph')
|
||||
async def test_table_switching_without_auth(self, mock_trustgraph):
|
||||
"""Test table switching logic when no authentication is provided"""
|
||||
taskgroup_mock = MagicMock()
|
||||
mock_tg_instance = MagicMock()
|
||||
mock_trustgraph.return_value = mock_tg_instance
|
||||
|
||||
processor = Processor(taskgroup=taskgroup_mock)
|
||||
|
||||
# Create mock message
|
||||
mock_message = MagicMock()
|
||||
mock_message.metadata.user = 'user2'
|
||||
mock_message.metadata.collection = 'collection2'
|
||||
mock_message.triples = []
|
||||
|
||||
await processor.store_triples(mock_message)
|
||||
|
||||
# Verify TrustGraph was called without auth parameters
|
||||
mock_trustgraph.assert_called_once_with(
|
||||
hosts=['localhost'],
|
||||
keyspace='user2',
|
||||
table='collection2'
|
||||
)
|
||||
assert processor.table == ('user2', 'collection2')
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('trustgraph.storage.triples.cassandra.write.TrustGraph')
|
||||
async def test_table_reuse_when_same(self, mock_trustgraph):
|
||||
"""Test that TrustGraph is not recreated when table hasn't changed"""
|
||||
taskgroup_mock = MagicMock()
|
||||
mock_tg_instance = MagicMock()
|
||||
mock_trustgraph.return_value = mock_tg_instance
|
||||
|
||||
processor = Processor(taskgroup=taskgroup_mock)
|
||||
|
||||
# Create mock message
|
||||
mock_message = MagicMock()
|
||||
mock_message.metadata.user = 'user1'
|
||||
mock_message.metadata.collection = 'collection1'
|
||||
mock_message.triples = []
|
||||
|
||||
# First call should create TrustGraph
|
||||
await processor.store_triples(mock_message)
|
||||
assert mock_trustgraph.call_count == 1
|
||||
|
||||
# Second call with same table should reuse TrustGraph
|
||||
await processor.store_triples(mock_message)
|
||||
assert mock_trustgraph.call_count == 1 # Should not increase
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('trustgraph.storage.triples.cassandra.write.TrustGraph')
|
||||
async def test_triple_insertion(self, mock_trustgraph):
|
||||
"""Test that triples are properly inserted into Cassandra"""
|
||||
taskgroup_mock = MagicMock()
|
||||
mock_tg_instance = MagicMock()
|
||||
mock_trustgraph.return_value = mock_tg_instance
|
||||
|
||||
processor = Processor(taskgroup=taskgroup_mock)
|
||||
|
||||
# Create mock triples
|
||||
triple1 = MagicMock()
|
||||
triple1.s.value = 'subject1'
|
||||
triple1.p.value = 'predicate1'
|
||||
triple1.o.value = 'object1'
|
||||
|
||||
triple2 = MagicMock()
|
||||
triple2.s.value = 'subject2'
|
||||
triple2.p.value = 'predicate2'
|
||||
triple2.o.value = 'object2'
|
||||
|
||||
# Create mock message
|
||||
mock_message = MagicMock()
|
||||
mock_message.metadata.user = 'user1'
|
||||
mock_message.metadata.collection = 'collection1'
|
||||
mock_message.triples = [triple1, triple2]
|
||||
|
||||
await processor.store_triples(mock_message)
|
||||
|
||||
# Verify both triples were inserted
|
||||
assert mock_tg_instance.insert.call_count == 2
|
||||
mock_tg_instance.insert.assert_any_call('subject1', 'predicate1', 'object1')
|
||||
mock_tg_instance.insert.assert_any_call('subject2', 'predicate2', 'object2')
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('trustgraph.storage.triples.cassandra.write.TrustGraph')
|
||||
async def test_triple_insertion_with_empty_list(self, mock_trustgraph):
|
||||
"""Test behavior when message has no triples"""
|
||||
taskgroup_mock = MagicMock()
|
||||
mock_tg_instance = MagicMock()
|
||||
mock_trustgraph.return_value = mock_tg_instance
|
||||
|
||||
processor = Processor(taskgroup=taskgroup_mock)
|
||||
|
||||
# Create mock message with empty triples
|
||||
mock_message = MagicMock()
|
||||
mock_message.metadata.user = 'user1'
|
||||
mock_message.metadata.collection = 'collection1'
|
||||
mock_message.triples = []
|
||||
|
||||
await processor.store_triples(mock_message)
|
||||
|
||||
# Verify no triples were inserted
|
||||
mock_tg_instance.insert.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('trustgraph.storage.triples.cassandra.write.TrustGraph')
|
||||
@patch('trustgraph.storage.triples.cassandra.write.time.sleep')
|
||||
async def test_exception_handling_with_retry(self, mock_sleep, mock_trustgraph):
|
||||
"""Test exception handling during TrustGraph creation"""
|
||||
taskgroup_mock = MagicMock()
|
||||
mock_trustgraph.side_effect = Exception("Connection failed")
|
||||
|
||||
processor = Processor(taskgroup=taskgroup_mock)
|
||||
|
||||
# Create mock message
|
||||
mock_message = MagicMock()
|
||||
mock_message.metadata.user = 'user1'
|
||||
mock_message.metadata.collection = 'collection1'
|
||||
mock_message.triples = []
|
||||
|
||||
with pytest.raises(Exception, match="Connection failed"):
|
||||
await processor.store_triples(mock_message)
|
||||
|
||||
# Verify sleep was called before re-raising
|
||||
mock_sleep.assert_called_once_with(1)
|
||||
|
||||
def test_add_args_method(self):
|
||||
"""Test that add_args properly configures argument parser"""
|
||||
from argparse import ArgumentParser
|
||||
from unittest.mock import patch
|
||||
|
||||
parser = ArgumentParser()
|
||||
|
||||
# Mock the parent class add_args method
|
||||
with patch('trustgraph.storage.triples.cassandra.write.TriplesStoreService.add_args') as mock_parent_add_args:
|
||||
Processor.add_args(parser)
|
||||
|
||||
# Verify parent add_args was called
|
||||
mock_parent_add_args.assert_called_once_with(parser)
|
||||
|
||||
# Verify our specific arguments were added
|
||||
# Parse empty args to check defaults
|
||||
args = parser.parse_args([])
|
||||
|
||||
assert hasattr(args, 'graph_host')
|
||||
assert args.graph_host == 'localhost'
|
||||
assert hasattr(args, 'graph_username')
|
||||
assert args.graph_username is None
|
||||
assert hasattr(args, 'graph_password')
|
||||
assert args.graph_password is None
|
||||
|
||||
def test_add_args_with_custom_values(self):
|
||||
"""Test add_args with custom command line values"""
|
||||
from argparse import ArgumentParser
|
||||
from unittest.mock import patch
|
||||
|
||||
parser = ArgumentParser()
|
||||
|
||||
with patch('trustgraph.storage.triples.cassandra.write.TriplesStoreService.add_args'):
|
||||
Processor.add_args(parser)
|
||||
|
||||
# Test parsing with custom values
|
||||
args = parser.parse_args([
|
||||
'--graph-host', 'cassandra.example.com',
|
||||
'--graph-username', 'testuser',
|
||||
'--graph-password', 'testpass'
|
||||
])
|
||||
|
||||
assert args.graph_host == 'cassandra.example.com'
|
||||
assert args.graph_username == 'testuser'
|
||||
assert args.graph_password == 'testpass'
|
||||
|
||||
def test_add_args_short_form(self):
|
||||
"""Test add_args with short form arguments"""
|
||||
from argparse import ArgumentParser
|
||||
from unittest.mock import patch
|
||||
|
||||
parser = ArgumentParser()
|
||||
|
||||
with patch('trustgraph.storage.triples.cassandra.write.TriplesStoreService.add_args'):
|
||||
Processor.add_args(parser)
|
||||
|
||||
# Test parsing with short form
|
||||
args = parser.parse_args(['-g', 'short.example.com'])
|
||||
|
||||
assert args.graph_host == 'short.example.com'
|
||||
|
||||
@patch('trustgraph.storage.triples.cassandra.write.Processor.launch')
|
||||
def test_run_function(self, mock_launch):
|
||||
"""Test the run function calls Processor.launch with correct parameters"""
|
||||
from trustgraph.storage.triples.cassandra.write import run, default_ident
|
||||
|
||||
run()
|
||||
|
||||
mock_launch.assert_called_once_with(default_ident, '\nGraph writer. Input is graph edge. Writes edges to Cassandra graph.\n')
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('trustgraph.storage.triples.cassandra.write.TrustGraph')
|
||||
async def test_store_triples_table_switching_between_different_tables(self, mock_trustgraph):
|
||||
"""Test table switching when different tables are used in sequence"""
|
||||
taskgroup_mock = MagicMock()
|
||||
mock_tg_instance1 = MagicMock()
|
||||
mock_tg_instance2 = MagicMock()
|
||||
mock_trustgraph.side_effect = [mock_tg_instance1, mock_tg_instance2]
|
||||
|
||||
processor = Processor(taskgroup=taskgroup_mock)
|
||||
|
||||
# First message with table1
|
||||
mock_message1 = MagicMock()
|
||||
mock_message1.metadata.user = 'user1'
|
||||
mock_message1.metadata.collection = 'collection1'
|
||||
mock_message1.triples = []
|
||||
|
||||
await processor.store_triples(mock_message1)
|
||||
assert processor.table == ('user1', 'collection1')
|
||||
assert processor.tg == mock_tg_instance1
|
||||
|
||||
# Second message with different table
|
||||
mock_message2 = MagicMock()
|
||||
mock_message2.metadata.user = 'user2'
|
||||
mock_message2.metadata.collection = 'collection2'
|
||||
mock_message2.triples = []
|
||||
|
||||
await processor.store_triples(mock_message2)
|
||||
assert processor.table == ('user2', 'collection2')
|
||||
assert processor.tg == mock_tg_instance2
|
||||
|
||||
# Verify TrustGraph was created twice for different tables
|
||||
assert mock_trustgraph.call_count == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('trustgraph.storage.triples.cassandra.write.TrustGraph')
|
||||
async def test_store_triples_with_special_characters_in_values(self, mock_trustgraph):
|
||||
"""Test storing triples with special characters and unicode"""
|
||||
taskgroup_mock = MagicMock()
|
||||
mock_tg_instance = MagicMock()
|
||||
mock_trustgraph.return_value = mock_tg_instance
|
||||
|
||||
processor = Processor(taskgroup=taskgroup_mock)
|
||||
|
||||
# Create triple with special characters
|
||||
triple = MagicMock()
|
||||
triple.s.value = 'subject with spaces & symbols'
|
||||
triple.p.value = 'predicate:with/colons'
|
||||
triple.o.value = 'object with "quotes" and unicode: ñáéíóú'
|
||||
|
||||
mock_message = MagicMock()
|
||||
mock_message.metadata.user = 'test_user'
|
||||
mock_message.metadata.collection = 'test_collection'
|
||||
mock_message.triples = [triple]
|
||||
|
||||
await processor.store_triples(mock_message)
|
||||
|
||||
# Verify the triple was inserted with special characters preserved
|
||||
mock_tg_instance.insert.assert_called_once_with(
|
||||
'subject with spaces & symbols',
|
||||
'predicate:with/colons',
|
||||
'object with "quotes" and unicode: ñáéíóú'
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('trustgraph.storage.triples.cassandra.write.TrustGraph')
|
||||
async def test_store_triples_preserves_old_table_on_exception(self, mock_trustgraph):
|
||||
"""Test that table remains unchanged when TrustGraph creation fails"""
|
||||
taskgroup_mock = MagicMock()
|
||||
|
||||
processor = Processor(taskgroup=taskgroup_mock)
|
||||
|
||||
# Set an initial table
|
||||
processor.table = ('old_user', 'old_collection')
|
||||
|
||||
# Mock TrustGraph to raise exception
|
||||
mock_trustgraph.side_effect = Exception("Connection failed")
|
||||
|
||||
mock_message = MagicMock()
|
||||
mock_message.metadata.user = 'new_user'
|
||||
mock_message.metadata.collection = 'new_collection'
|
||||
mock_message.triples = []
|
||||
|
||||
with pytest.raises(Exception, match="Connection failed"):
|
||||
await processor.store_triples(mock_message)
|
||||
|
||||
# Table should remain unchanged since self.table = table happens after try/except
|
||||
assert processor.table == ('old_user', 'old_collection')
|
||||
# TrustGraph should be set to None though
|
||||
assert processor.tg is None
|
||||
3
tests/unit/test_text_completion/__init__.py
Normal file
3
tests/unit/test_text_completion/__init__.py
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
"""
|
||||
Unit tests for text completion services
|
||||
"""
|
||||
3
tests/unit/test_text_completion/common/__init__.py
Normal file
3
tests/unit/test_text_completion/common/__init__.py
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
"""
|
||||
Common utilities for text completion tests
|
||||
"""
|
||||
69
tests/unit/test_text_completion/common/base_test_cases.py
Normal file
69
tests/unit/test_text_completion/common/base_test_cases.py
Normal file
|
|
@ -0,0 +1,69 @@
|
|||
"""
|
||||
Base test patterns that can be reused across different text completion models
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from unittest import IsolatedAsyncioTestCase
|
||||
|
||||
|
||||
class BaseTextCompletionTestCase(IsolatedAsyncioTestCase, ABC):
|
||||
"""
|
||||
Base test class for text completion processors
|
||||
Provides common test patterns that can be reused
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def get_processor_class(self):
|
||||
"""Return the processor class to test"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_base_config(self):
|
||||
"""Return base configuration for the processor"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_mock_patches(self):
|
||||
"""Return list of patch decorators for mocking dependencies"""
|
||||
pass
|
||||
|
||||
def create_base_config(self, **overrides):
|
||||
"""Create base config with optional overrides"""
|
||||
config = self.get_base_config()
|
||||
config.update(overrides)
|
||||
return config
|
||||
|
||||
def create_mock_llm_result(self, text="Test response", in_token=10, out_token=5):
|
||||
"""Create a mock LLM result"""
|
||||
from trustgraph.base import LlmResult
|
||||
return LlmResult(text=text, in_token=in_token, out_token=out_token)
|
||||
|
||||
|
||||
class CommonTestPatterns:
|
||||
"""
|
||||
Common test patterns that can be used across different models
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def basic_initialization_test_pattern(test_instance):
|
||||
"""
|
||||
Test pattern for basic processor initialization
|
||||
test_instance should be a BaseTextCompletionTestCase
|
||||
"""
|
||||
# This would contain the common pattern for initialization testing
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def successful_generation_test_pattern(test_instance):
|
||||
"""
|
||||
Test pattern for successful content generation
|
||||
"""
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def error_handling_test_pattern(test_instance):
|
||||
"""
|
||||
Test pattern for error handling
|
||||
"""
|
||||
pass
|
||||
53
tests/unit/test_text_completion/common/mock_helpers.py
Normal file
53
tests/unit/test_text_completion/common/mock_helpers.py
Normal file
|
|
@ -0,0 +1,53 @@
|
|||
"""
|
||||
Common mocking utilities for text completion tests
|
||||
"""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
|
||||
class CommonMocks:
|
||||
"""Common mock objects used across text completion tests"""
|
||||
|
||||
@staticmethod
|
||||
def create_mock_async_processor_init():
|
||||
"""Create mock for AsyncProcessor.__init__"""
|
||||
mock = MagicMock()
|
||||
mock.return_value = None
|
||||
return mock
|
||||
|
||||
@staticmethod
|
||||
def create_mock_llm_service_init():
|
||||
"""Create mock for LlmService.__init__"""
|
||||
mock = MagicMock()
|
||||
mock.return_value = None
|
||||
return mock
|
||||
|
||||
@staticmethod
|
||||
def create_mock_response(text="Test response", prompt_tokens=10, completion_tokens=5):
|
||||
"""Create a mock response object"""
|
||||
response = MagicMock()
|
||||
response.text = text
|
||||
response.usage_metadata.prompt_token_count = prompt_tokens
|
||||
response.usage_metadata.candidates_token_count = completion_tokens
|
||||
return response
|
||||
|
||||
@staticmethod
|
||||
def create_basic_config():
|
||||
"""Create basic config with required fields"""
|
||||
return {
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
|
||||
class MockPatches:
|
||||
"""Common patch decorators for different services"""
|
||||
|
||||
@staticmethod
|
||||
def get_base_patches():
|
||||
"""Get patches that are common to all processors"""
|
||||
return [
|
||||
'trustgraph.base.async_processor.AsyncProcessor.__init__',
|
||||
'trustgraph.base.llm_service.LlmService.__init__'
|
||||
]
|
||||
499
tests/unit/test_text_completion/conftest.py
Normal file
499
tests/unit/test_text_completion/conftest.py
Normal file
|
|
@ -0,0 +1,499 @@
|
|||
"""
|
||||
Pytest configuration and fixtures for text completion tests
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, AsyncMock
|
||||
from trustgraph.base import LlmResult
|
||||
|
||||
|
||||
# === Common Fixtures for All Text Completion Models ===
|
||||
|
||||
@pytest.fixture
|
||||
def base_processor_config():
|
||||
"""Base configuration required by all processors"""
|
||||
return {
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_llm_result():
|
||||
"""Sample LlmResult for testing"""
|
||||
return LlmResult(
|
||||
text="Test response",
|
||||
in_token=10,
|
||||
out_token=5
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_async_processor_init():
|
||||
"""Mock AsyncProcessor.__init__ to avoid infrastructure requirements"""
|
||||
mock = MagicMock()
|
||||
mock.return_value = None
|
||||
return mock
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_llm_service_init():
|
||||
"""Mock LlmService.__init__ to avoid infrastructure requirements"""
|
||||
mock = MagicMock()
|
||||
mock.return_value = None
|
||||
return mock
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_prometheus_metrics():
|
||||
"""Mock Prometheus metrics"""
|
||||
mock_metric = MagicMock()
|
||||
mock_metric.labels.return_value.time.return_value = MagicMock()
|
||||
return mock_metric
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_pulsar_consumer():
|
||||
"""Mock Pulsar consumer for integration testing"""
|
||||
return AsyncMock()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_pulsar_producer():
|
||||
"""Mock Pulsar producer for integration testing"""
|
||||
return AsyncMock()
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_env_vars(monkeypatch):
|
||||
"""Mock environment variables for testing"""
|
||||
monkeypatch.setenv("GOOGLE_CLOUD_PROJECT", "test-project")
|
||||
monkeypatch.setenv("GOOGLE_APPLICATION_CREDENTIALS", "/path/to/test-credentials.json")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_async_context_manager():
|
||||
"""Mock async context manager for testing"""
|
||||
class MockAsyncContextManager:
|
||||
def __init__(self, return_value):
|
||||
self.return_value = return_value
|
||||
|
||||
async def __aenter__(self):
|
||||
return self.return_value
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
pass
|
||||
|
||||
return MockAsyncContextManager
|
||||
|
||||
|
||||
# === VertexAI Specific Fixtures ===
|
||||
|
||||
@pytest.fixture
|
||||
def mock_vertexai_credentials():
|
||||
"""Mock Google Cloud service account credentials"""
|
||||
return MagicMock()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_vertexai_model():
|
||||
"""Mock VertexAI GenerativeModel"""
|
||||
mock_model = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.text = "Test response"
|
||||
mock_response.usage_metadata.prompt_token_count = 10
|
||||
mock_response.usage_metadata.candidates_token_count = 5
|
||||
mock_model.generate_content.return_value = mock_response
|
||||
return mock_model
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def vertexai_processor_config(base_processor_config):
|
||||
"""Default configuration for VertexAI processor"""
|
||||
config = base_processor_config.copy()
|
||||
config.update({
|
||||
'region': 'us-central1',
|
||||
'model': 'gemini-2.0-flash-001',
|
||||
'temperature': 0.0,
|
||||
'max_output': 8192,
|
||||
'private_key': 'private.json'
|
||||
})
|
||||
return config
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_safety_settings():
|
||||
"""Mock safety settings for VertexAI"""
|
||||
safety_settings = []
|
||||
for i in range(4): # 4 safety categories
|
||||
setting = MagicMock()
|
||||
setting.category = f"HARM_CATEGORY_{i}"
|
||||
setting.threshold = "BLOCK_MEDIUM_AND_ABOVE"
|
||||
safety_settings.append(setting)
|
||||
|
||||
return safety_settings
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_generation_config():
|
||||
"""Mock generation configuration for VertexAI"""
|
||||
config = MagicMock()
|
||||
config.temperature = 0.0
|
||||
config.max_output_tokens = 8192
|
||||
config.top_p = 1.0
|
||||
config.top_k = 10
|
||||
config.candidate_count = 1
|
||||
return config
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_vertexai_exception():
|
||||
"""Mock VertexAI exceptions"""
|
||||
from google.api_core.exceptions import ResourceExhausted
|
||||
return ResourceExhausted("Test resource exhausted error")
|
||||
|
||||
|
||||
# === Ollama Specific Fixtures ===
|
||||
|
||||
@pytest.fixture
|
||||
def ollama_processor_config(base_processor_config):
|
||||
"""Default configuration for Ollama processor"""
|
||||
config = base_processor_config.copy()
|
||||
config.update({
|
||||
'model': 'llama2',
|
||||
'temperature': 0.0,
|
||||
'max_output': 8192,
|
||||
'host': 'localhost',
|
||||
'port': 11434
|
||||
})
|
||||
return config
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_ollama_client():
|
||||
"""Mock Ollama client"""
|
||||
mock_client = MagicMock()
|
||||
mock_response = {
|
||||
'response': 'Test response from Ollama',
|
||||
'done': True,
|
||||
'eval_count': 5,
|
||||
'prompt_eval_count': 10
|
||||
}
|
||||
mock_client.generate.return_value = mock_response
|
||||
return mock_client
|
||||
|
||||
|
||||
# === OpenAI Specific Fixtures ===
|
||||
|
||||
@pytest.fixture
|
||||
def openai_processor_config(base_processor_config):
|
||||
"""Default configuration for OpenAI processor"""
|
||||
config = base_processor_config.copy()
|
||||
config.update({
|
||||
'model': 'gpt-3.5-turbo',
|
||||
'api_key': 'test-api-key',
|
||||
'url': 'https://api.openai.com/v1',
|
||||
'temperature': 0.0,
|
||||
'max_output': 4096
|
||||
})
|
||||
return config
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_openai_client():
|
||||
"""Mock OpenAI client"""
|
||||
mock_client = MagicMock()
|
||||
|
||||
# Mock the response structure
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices = [MagicMock()]
|
||||
mock_response.choices[0].message.content = "Test response from OpenAI"
|
||||
mock_response.usage.prompt_tokens = 15
|
||||
mock_response.usage.completion_tokens = 8
|
||||
|
||||
mock_client.chat.completions.create.return_value = mock_response
|
||||
return mock_client
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_openai_rate_limit_error():
|
||||
"""Mock OpenAI rate limit error"""
|
||||
from openai import RateLimitError
|
||||
return RateLimitError("Rate limit exceeded", response=MagicMock(), body=None)
|
||||
|
||||
|
||||
# === Azure OpenAI Specific Fixtures ===
|
||||
|
||||
@pytest.fixture
|
||||
def azure_openai_processor_config(base_processor_config):
|
||||
"""Default configuration for Azure OpenAI processor"""
|
||||
config = base_processor_config.copy()
|
||||
config.update({
|
||||
'model': 'gpt-4',
|
||||
'endpoint': 'https://test.openai.azure.com/',
|
||||
'token': 'test-token',
|
||||
'api_version': '2024-12-01-preview',
|
||||
'temperature': 0.0,
|
||||
'max_output': 4192
|
||||
})
|
||||
return config
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_azure_openai_client():
|
||||
"""Mock Azure OpenAI client"""
|
||||
mock_client = MagicMock()
|
||||
|
||||
# Mock the response structure
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices = [MagicMock()]
|
||||
mock_response.choices[0].message.content = "Test response from Azure OpenAI"
|
||||
mock_response.usage.prompt_tokens = 20
|
||||
mock_response.usage.completion_tokens = 10
|
||||
|
||||
mock_client.chat.completions.create.return_value = mock_response
|
||||
return mock_client
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_azure_openai_rate_limit_error():
|
||||
"""Mock Azure OpenAI rate limit error"""
|
||||
from openai import RateLimitError
|
||||
return RateLimitError("Rate limit exceeded", response=MagicMock(), body=None)
|
||||
|
||||
|
||||
# === Azure Specific Fixtures ===
|
||||
|
||||
@pytest.fixture
|
||||
def azure_processor_config(base_processor_config):
|
||||
"""Default configuration for Azure processor"""
|
||||
config = base_processor_config.copy()
|
||||
config.update({
|
||||
'endpoint': 'https://test.inference.ai.azure.com/v1/chat/completions',
|
||||
'token': 'test-token',
|
||||
'temperature': 0.0,
|
||||
'max_output': 4192
|
||||
})
|
||||
return config
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_azure_requests():
|
||||
"""Mock requests for Azure processor"""
|
||||
mock_requests = MagicMock()
|
||||
|
||||
# Mock successful response
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {
|
||||
'choices': [{
|
||||
'message': {
|
||||
'content': 'Test response from Azure'
|
||||
}
|
||||
}],
|
||||
'usage': {
|
||||
'prompt_tokens': 18,
|
||||
'completion_tokens': 9
|
||||
}
|
||||
}
|
||||
mock_requests.post.return_value = mock_response
|
||||
return mock_requests
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_azure_rate_limit_response():
|
||||
"""Mock Azure rate limit response"""
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 429
|
||||
return mock_response
|
||||
|
||||
|
||||
# === Claude Specific Fixtures ===
|
||||
|
||||
@pytest.fixture
|
||||
def claude_processor_config(base_processor_config):
|
||||
"""Default configuration for Claude processor"""
|
||||
config = base_processor_config.copy()
|
||||
config.update({
|
||||
'model': 'claude-3-5-sonnet-20240620',
|
||||
'api_key': 'test-api-key',
|
||||
'temperature': 0.0,
|
||||
'max_output': 8192
|
||||
})
|
||||
return config
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_claude_client():
|
||||
"""Mock Claude (Anthropic) client"""
|
||||
mock_client = MagicMock()
|
||||
|
||||
# Mock the response structure
|
||||
mock_response = MagicMock()
|
||||
mock_response.content = [MagicMock()]
|
||||
mock_response.content[0].text = "Test response from Claude"
|
||||
mock_response.usage.input_tokens = 22
|
||||
mock_response.usage.output_tokens = 12
|
||||
|
||||
mock_client.messages.create.return_value = mock_response
|
||||
return mock_client
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_claude_rate_limit_error():
|
||||
"""Mock Claude rate limit error"""
|
||||
import anthropic
|
||||
return anthropic.RateLimitError("Rate limit exceeded", response=MagicMock(), body=None)
|
||||
|
||||
|
||||
# === vLLM Specific Fixtures ===
|
||||
|
||||
@pytest.fixture
|
||||
def vllm_processor_config(base_processor_config):
|
||||
"""Default configuration for vLLM processor"""
|
||||
config = base_processor_config.copy()
|
||||
config.update({
|
||||
'model': 'TheBloke/Mistral-7B-v0.1-AWQ',
|
||||
'url': 'http://vllm-service:8899/v1',
|
||||
'temperature': 0.0,
|
||||
'max_output': 2048
|
||||
})
|
||||
return config
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_vllm_session():
|
||||
"""Mock aiohttp ClientSession for vLLM"""
|
||||
mock_session = MagicMock()
|
||||
|
||||
# Mock successful response
|
||||
mock_response = MagicMock()
|
||||
mock_response.status = 200
|
||||
mock_response.json = AsyncMock(return_value={
|
||||
'choices': [{
|
||||
'text': 'Test response from vLLM'
|
||||
}],
|
||||
'usage': {
|
||||
'prompt_tokens': 16,
|
||||
'completion_tokens': 8
|
||||
}
|
||||
})
|
||||
|
||||
# Mock the async context manager
|
||||
mock_session.post.return_value.__aenter__.return_value = mock_response
|
||||
mock_session.post.return_value.__aexit__.return_value = None
|
||||
|
||||
return mock_session
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_vllm_error_response():
|
||||
"""Mock vLLM error response"""
|
||||
mock_response = MagicMock()
|
||||
mock_response.status = 500
|
||||
return mock_response
|
||||
|
||||
|
||||
# === Cohere Specific Fixtures ===
|
||||
|
||||
@pytest.fixture
|
||||
def cohere_processor_config(base_processor_config):
|
||||
"""Default configuration for Cohere processor"""
|
||||
config = base_processor_config.copy()
|
||||
config.update({
|
||||
'model': 'c4ai-aya-23-8b',
|
||||
'api_key': 'test-api-key',
|
||||
'temperature': 0.0
|
||||
})
|
||||
return config
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_cohere_client():
|
||||
"""Mock Cohere client"""
|
||||
mock_client = MagicMock()
|
||||
|
||||
# Mock the response structure
|
||||
mock_output = MagicMock()
|
||||
mock_output.text = "Test response from Cohere"
|
||||
mock_output.meta.billed_units.input_tokens = 18
|
||||
mock_output.meta.billed_units.output_tokens = 10
|
||||
|
||||
mock_client.chat.return_value = mock_output
|
||||
return mock_client
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_cohere_rate_limit_error():
|
||||
"""Mock Cohere rate limit error"""
|
||||
import cohere
|
||||
return cohere.TooManyRequestsError("Rate limit exceeded")
|
||||
|
||||
|
||||
# === Google AI Studio Specific Fixtures ===
|
||||
|
||||
@pytest.fixture
|
||||
def googleaistudio_processor_config(base_processor_config):
|
||||
"""Default configuration for Google AI Studio processor"""
|
||||
config = base_processor_config.copy()
|
||||
config.update({
|
||||
'model': 'gemini-2.0-flash-001',
|
||||
'api_key': 'test-api-key',
|
||||
'temperature': 0.0,
|
||||
'max_output': 8192
|
||||
})
|
||||
return config
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_googleaistudio_client():
|
||||
"""Mock Google AI Studio client"""
|
||||
mock_client = MagicMock()
|
||||
|
||||
# Mock the response structure
|
||||
mock_response = MagicMock()
|
||||
mock_response.text = "Test response from Google AI Studio"
|
||||
mock_response.usage_metadata.prompt_token_count = 20
|
||||
mock_response.usage_metadata.candidates_token_count = 12
|
||||
|
||||
mock_client.models.generate_content.return_value = mock_response
|
||||
return mock_client
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_googleaistudio_rate_limit_error():
|
||||
"""Mock Google AI Studio rate limit error"""
|
||||
from google.api_core.exceptions import ResourceExhausted
|
||||
return ResourceExhausted("Rate limit exceeded")
|
||||
|
||||
|
||||
# === LlamaFile Specific Fixtures ===
|
||||
|
||||
@pytest.fixture
|
||||
def llamafile_processor_config(base_processor_config):
|
||||
"""Default configuration for LlamaFile processor"""
|
||||
config = base_processor_config.copy()
|
||||
config.update({
|
||||
'model': 'LLaMA_CPP',
|
||||
'llamafile': 'http://localhost:8080/v1',
|
||||
'temperature': 0.0,
|
||||
'max_output': 4096
|
||||
})
|
||||
return config
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_llamafile_client():
|
||||
"""Mock OpenAI client for LlamaFile"""
|
||||
mock_client = MagicMock()
|
||||
|
||||
# Mock the response structure
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices = [MagicMock()]
|
||||
mock_response.choices[0].message.content = "Test response from LlamaFile"
|
||||
mock_response.usage.prompt_tokens = 14
|
||||
mock_response.usage.completion_tokens = 8
|
||||
|
||||
mock_client.chat.completions.create.return_value = mock_response
|
||||
return mock_client
|
||||
407
tests/unit/test_text_completion/test_azure_openai_processor.py
Normal file
407
tests/unit/test_text_completion/test_azure_openai_processor.py
Normal file
|
|
@ -0,0 +1,407 @@
|
|||
"""
|
||||
Unit tests for trustgraph.model.text_completion.azure_openai
|
||||
Following the same successful pattern as previous tests
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from unittest import IsolatedAsyncioTestCase
|
||||
|
||||
# Import the service under test
|
||||
from trustgraph.model.text_completion.azure_openai.llm import Processor
|
||||
from trustgraph.base import LlmResult
|
||||
from trustgraph.exceptions import TooManyRequests
|
||||
|
||||
|
||||
class TestAzureOpenAIProcessorSimple(IsolatedAsyncioTestCase):
|
||||
"""Test Azure OpenAI processor functionality"""
|
||||
|
||||
@patch('trustgraph.model.text_completion.azure_openai.llm.AzureOpenAI')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_processor_initialization_basic(self, mock_llm_init, mock_async_init, mock_azure_openai_class):
|
||||
"""Test basic processor initialization"""
|
||||
# Arrange
|
||||
mock_azure_client = MagicMock()
|
||||
mock_azure_openai_class.return_value = mock_azure_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'gpt-4',
|
||||
'endpoint': 'https://test.openai.azure.com/',
|
||||
'token': 'test-token',
|
||||
'api_version': '2024-12-01-preview',
|
||||
'temperature': 0.0,
|
||||
'max_output': 4192,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
# Act
|
||||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
assert processor.model == 'gpt-4'
|
||||
assert processor.temperature == 0.0
|
||||
assert processor.max_output == 4192
|
||||
assert hasattr(processor, 'openai')
|
||||
mock_azure_openai_class.assert_called_once_with(
|
||||
api_key='test-token',
|
||||
api_version='2024-12-01-preview',
|
||||
azure_endpoint='https://test.openai.azure.com/'
|
||||
)
|
||||
|
||||
@patch('trustgraph.model.text_completion.azure_openai.llm.AzureOpenAI')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_success(self, mock_llm_init, mock_async_init, mock_azure_openai_class):
|
||||
"""Test successful content generation"""
|
||||
# Arrange
|
||||
mock_azure_client = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices = [MagicMock()]
|
||||
mock_response.choices[0].message.content = "Generated response from Azure OpenAI"
|
||||
mock_response.usage.prompt_tokens = 25
|
||||
mock_response.usage.completion_tokens = 15
|
||||
|
||||
mock_azure_client.chat.completions.create.return_value = mock_response
|
||||
mock_azure_openai_class.return_value = mock_azure_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'gpt-4',
|
||||
'endpoint': 'https://test.openai.azure.com/',
|
||||
'token': 'test-token',
|
||||
'api_version': '2024-12-01-preview',
|
||||
'temperature': 0.0,
|
||||
'max_output': 4192,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act
|
||||
result = await processor.generate_content("System prompt", "User prompt")
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Generated response from Azure OpenAI"
|
||||
assert result.in_token == 25
|
||||
assert result.out_token == 15
|
||||
assert result.model == 'gpt-4'
|
||||
|
||||
# Verify the Azure OpenAI API call
|
||||
mock_azure_client.chat.completions.create.assert_called_once_with(
|
||||
model='gpt-4',
|
||||
messages=[{
|
||||
"role": "user",
|
||||
"content": [{
|
||||
"type": "text",
|
||||
"text": "System prompt\n\nUser prompt"
|
||||
}]
|
||||
}],
|
||||
temperature=0.0,
|
||||
max_tokens=4192,
|
||||
top_p=1
|
||||
)
|
||||
|
||||
@patch('trustgraph.model.text_completion.azure_openai.llm.AzureOpenAI')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_rate_limit_error(self, mock_llm_init, mock_async_init, mock_azure_openai_class):
|
||||
"""Test rate limit error handling"""
|
||||
# Arrange
|
||||
from openai import RateLimitError
|
||||
|
||||
mock_azure_client = MagicMock()
|
||||
mock_azure_client.chat.completions.create.side_effect = RateLimitError("Rate limit exceeded", response=MagicMock(), body=None)
|
||||
mock_azure_openai_class.return_value = mock_azure_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'gpt-4',
|
||||
'endpoint': 'https://test.openai.azure.com/',
|
||||
'token': 'test-token',
|
||||
'api_version': '2024-12-01-preview',
|
||||
'temperature': 0.0,
|
||||
'max_output': 4192,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(TooManyRequests):
|
||||
await processor.generate_content("System prompt", "User prompt")
|
||||
|
||||
@patch('trustgraph.model.text_completion.azure_openai.llm.AzureOpenAI')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_generic_exception(self, mock_llm_init, mock_async_init, mock_azure_openai_class):
|
||||
"""Test handling of generic exceptions"""
|
||||
# Arrange
|
||||
mock_azure_client = MagicMock()
|
||||
mock_azure_client.chat.completions.create.side_effect = Exception("Azure API connection error")
|
||||
mock_azure_openai_class.return_value = mock_azure_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'gpt-4',
|
||||
'endpoint': 'https://test.openai.azure.com/',
|
||||
'token': 'test-token',
|
||||
'api_version': '2024-12-01-preview',
|
||||
'temperature': 0.0,
|
||||
'max_output': 4192,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(Exception, match="Azure API connection error"):
|
||||
await processor.generate_content("System prompt", "User prompt")
|
||||
|
||||
@patch('trustgraph.model.text_completion.azure_openai.llm.AzureOpenAI')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_processor_initialization_without_endpoint(self, mock_llm_init, mock_async_init, mock_azure_openai_class):
|
||||
"""Test processor initialization without endpoint (should fail)"""
|
||||
# Arrange
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'gpt-4',
|
||||
'endpoint': None, # No endpoint provided
|
||||
'token': 'test-token',
|
||||
'api_version': '2024-12-01-preview',
|
||||
'temperature': 0.0,
|
||||
'max_output': 4192,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(RuntimeError, match="Azure endpoint not specified"):
|
||||
processor = Processor(**config)
|
||||
|
||||
@patch('trustgraph.model.text_completion.azure_openai.llm.AzureOpenAI')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_processor_initialization_without_token(self, mock_llm_init, mock_async_init, mock_azure_openai_class):
|
||||
"""Test processor initialization without token (should fail)"""
|
||||
# Arrange
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'gpt-4',
|
||||
'endpoint': 'https://test.openai.azure.com/',
|
||||
'token': None, # No token provided
|
||||
'api_version': '2024-12-01-preview',
|
||||
'temperature': 0.0,
|
||||
'max_output': 4192,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(RuntimeError, match="Azure token not specified"):
|
||||
processor = Processor(**config)
|
||||
|
||||
@patch('trustgraph.model.text_completion.azure_openai.llm.AzureOpenAI')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_processor_initialization_with_custom_parameters(self, mock_llm_init, mock_async_init, mock_azure_openai_class):
|
||||
"""Test processor initialization with custom parameters"""
|
||||
# Arrange
|
||||
mock_azure_client = MagicMock()
|
||||
mock_azure_openai_class.return_value = mock_azure_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'gpt-35-turbo',
|
||||
'endpoint': 'https://custom.openai.azure.com/',
|
||||
'token': 'custom-token',
|
||||
'api_version': '2023-05-15',
|
||||
'temperature': 0.7,
|
||||
'max_output': 2048,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
# Act
|
||||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
assert processor.model == 'gpt-35-turbo'
|
||||
assert processor.temperature == 0.7
|
||||
assert processor.max_output == 2048
|
||||
mock_azure_openai_class.assert_called_once_with(
|
||||
api_key='custom-token',
|
||||
api_version='2023-05-15',
|
||||
azure_endpoint='https://custom.openai.azure.com/'
|
||||
)
|
||||
|
||||
@patch('trustgraph.model.text_completion.azure_openai.llm.AzureOpenAI')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_processor_initialization_with_defaults(self, mock_llm_init, mock_async_init, mock_azure_openai_class):
|
||||
"""Test processor initialization with default values"""
|
||||
# Arrange
|
||||
mock_azure_client = MagicMock()
|
||||
mock_azure_openai_class.return_value = mock_azure_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
# Only provide required fields, should use defaults
|
||||
config = {
|
||||
'endpoint': 'https://test.openai.azure.com/',
|
||||
'token': 'test-token',
|
||||
'model': 'gpt-4', # Required for Azure
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
# Act
|
||||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
assert processor.model == 'gpt-4'
|
||||
assert processor.temperature == 0.0 # default_temperature
|
||||
assert processor.max_output == 4192 # default_max_output
|
||||
mock_azure_openai_class.assert_called_once_with(
|
||||
api_key='test-token',
|
||||
api_version='2024-12-01-preview', # default_api
|
||||
azure_endpoint='https://test.openai.azure.com/'
|
||||
)
|
||||
|
||||
@patch('trustgraph.model.text_completion.azure_openai.llm.AzureOpenAI')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_empty_prompts(self, mock_llm_init, mock_async_init, mock_azure_openai_class):
|
||||
"""Test content generation with empty prompts"""
|
||||
# Arrange
|
||||
mock_azure_client = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices = [MagicMock()]
|
||||
mock_response.choices[0].message.content = "Default response"
|
||||
mock_response.usage.prompt_tokens = 2
|
||||
mock_response.usage.completion_tokens = 3
|
||||
|
||||
mock_azure_client.chat.completions.create.return_value = mock_response
|
||||
mock_azure_openai_class.return_value = mock_azure_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'gpt-4',
|
||||
'endpoint': 'https://test.openai.azure.com/',
|
||||
'token': 'test-token',
|
||||
'api_version': '2024-12-01-preview',
|
||||
'temperature': 0.0,
|
||||
'max_output': 4192,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act
|
||||
result = await processor.generate_content("", "")
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Default response"
|
||||
assert result.in_token == 2
|
||||
assert result.out_token == 3
|
||||
assert result.model == 'gpt-4'
|
||||
|
||||
# Verify the combined prompt is sent correctly
|
||||
call_args = mock_azure_client.chat.completions.create.call_args
|
||||
expected_prompt = "\n\n" # Empty system + "\n\n" + empty user
|
||||
assert call_args[1]['messages'][0]['content'][0]['text'] == expected_prompt
|
||||
|
||||
@patch('trustgraph.model.text_completion.azure_openai.llm.AzureOpenAI')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_message_structure(self, mock_llm_init, mock_async_init, mock_azure_openai_class):
|
||||
"""Test that Azure OpenAI messages are structured correctly"""
|
||||
# Arrange
|
||||
mock_azure_client = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices = [MagicMock()]
|
||||
mock_response.choices[0].message.content = "Response with proper structure"
|
||||
mock_response.usage.prompt_tokens = 30
|
||||
mock_response.usage.completion_tokens = 20
|
||||
|
||||
mock_azure_client.chat.completions.create.return_value = mock_response
|
||||
mock_azure_openai_class.return_value = mock_azure_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'gpt-4',
|
||||
'endpoint': 'https://test.openai.azure.com/',
|
||||
'token': 'test-token',
|
||||
'api_version': '2024-12-01-preview',
|
||||
'temperature': 0.5,
|
||||
'max_output': 1024,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act
|
||||
result = await processor.generate_content("You are a helpful assistant", "What is AI?")
|
||||
|
||||
# Assert
|
||||
assert result.text == "Response with proper structure"
|
||||
assert result.in_token == 30
|
||||
assert result.out_token == 20
|
||||
|
||||
# Verify the message structure matches Azure OpenAI Chat API format
|
||||
call_args = mock_azure_client.chat.completions.create.call_args
|
||||
messages = call_args[1]['messages']
|
||||
|
||||
assert len(messages) == 1
|
||||
assert messages[0]['role'] == 'user'
|
||||
assert messages[0]['content'][0]['type'] == 'text'
|
||||
assert messages[0]['content'][0]['text'] == "You are a helpful assistant\n\nWhat is AI?"
|
||||
|
||||
# Verify other parameters
|
||||
assert call_args[1]['model'] == 'gpt-4'
|
||||
assert call_args[1]['temperature'] == 0.5
|
||||
assert call_args[1]['max_tokens'] == 1024
|
||||
assert call_args[1]['top_p'] == 1
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pytest.main([__file__])
|
||||
463
tests/unit/test_text_completion/test_azure_processor.py
Normal file
463
tests/unit/test_text_completion/test_azure_processor.py
Normal file
|
|
@ -0,0 +1,463 @@
|
|||
"""
|
||||
Unit tests for trustgraph.model.text_completion.azure
|
||||
Following the same successful pattern as previous tests
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from unittest import IsolatedAsyncioTestCase
|
||||
|
||||
# Import the service under test
|
||||
from trustgraph.model.text_completion.azure.llm import Processor
|
||||
from trustgraph.base import LlmResult
|
||||
from trustgraph.exceptions import TooManyRequests
|
||||
|
||||
|
||||
class TestAzureProcessorSimple(IsolatedAsyncioTestCase):
|
||||
"""Test Azure processor functionality"""
|
||||
|
||||
@patch('trustgraph.model.text_completion.azure.llm.requests')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_processor_initialization_basic(self, mock_llm_init, mock_async_init, mock_requests):
|
||||
"""Test basic processor initialization"""
|
||||
# Arrange
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'endpoint': 'https://test.inference.ai.azure.com/v1/chat/completions',
|
||||
'token': 'test-token',
|
||||
'temperature': 0.0,
|
||||
'max_output': 4192,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
# Act
|
||||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
assert processor.endpoint == 'https://test.inference.ai.azure.com/v1/chat/completions'
|
||||
assert processor.token == 'test-token'
|
||||
assert processor.temperature == 0.0
|
||||
assert processor.max_output == 4192
|
||||
assert processor.model == 'AzureAI'
|
||||
|
||||
@patch('trustgraph.model.text_completion.azure.llm.requests')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_success(self, mock_llm_init, mock_async_init, mock_requests):
|
||||
"""Test successful content generation"""
|
||||
# Arrange
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {
|
||||
'choices': [{
|
||||
'message': {
|
||||
'content': 'Generated response from Azure'
|
||||
}
|
||||
}],
|
||||
'usage': {
|
||||
'prompt_tokens': 20,
|
||||
'completion_tokens': 12
|
||||
}
|
||||
}
|
||||
mock_requests.post.return_value = mock_response
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'endpoint': 'https://test.inference.ai.azure.com/v1/chat/completions',
|
||||
'token': 'test-token',
|
||||
'temperature': 0.0,
|
||||
'max_output': 4192,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act
|
||||
result = await processor.generate_content("System prompt", "User prompt")
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Generated response from Azure"
|
||||
assert result.in_token == 20
|
||||
assert result.out_token == 12
|
||||
assert result.model == 'AzureAI'
|
||||
|
||||
# Verify the API call was made correctly
|
||||
mock_requests.post.assert_called_once()
|
||||
call_args = mock_requests.post.call_args
|
||||
|
||||
# Check URL
|
||||
assert call_args[0][0] == 'https://test.inference.ai.azure.com/v1/chat/completions'
|
||||
|
||||
# Check headers
|
||||
headers = call_args[1]['headers']
|
||||
assert headers['Content-Type'] == 'application/json'
|
||||
assert headers['Authorization'] == 'Bearer test-token'
|
||||
|
||||
@patch('trustgraph.model.text_completion.azure.llm.requests')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_rate_limit_error(self, mock_llm_init, mock_async_init, mock_requests):
|
||||
"""Test rate limit error handling"""
|
||||
# Arrange
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 429
|
||||
mock_requests.post.return_value = mock_response
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'endpoint': 'https://test.inference.ai.azure.com/v1/chat/completions',
|
||||
'token': 'test-token',
|
||||
'temperature': 0.0,
|
||||
'max_output': 4192,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(TooManyRequests):
|
||||
await processor.generate_content("System prompt", "User prompt")
|
||||
|
||||
@patch('trustgraph.model.text_completion.azure.llm.requests')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_http_error(self, mock_llm_init, mock_async_init, mock_requests):
|
||||
"""Test HTTP error handling"""
|
||||
# Arrange
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 500
|
||||
mock_requests.post.return_value = mock_response
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'endpoint': 'https://test.inference.ai.azure.com/v1/chat/completions',
|
||||
'token': 'test-token',
|
||||
'temperature': 0.0,
|
||||
'max_output': 4192,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(RuntimeError, match="LLM failure"):
|
||||
await processor.generate_content("System prompt", "User prompt")
|
||||
|
||||
@patch('trustgraph.model.text_completion.azure.llm.requests')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_generic_exception(self, mock_llm_init, mock_async_init, mock_requests):
|
||||
"""Test handling of generic exceptions"""
|
||||
# Arrange
|
||||
mock_requests.post.side_effect = Exception("Connection error")
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'endpoint': 'https://test.inference.ai.azure.com/v1/chat/completions',
|
||||
'token': 'test-token',
|
||||
'temperature': 0.0,
|
||||
'max_output': 4192,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(Exception, match="Connection error"):
|
||||
await processor.generate_content("System prompt", "User prompt")
|
||||
|
||||
@patch('trustgraph.model.text_completion.azure.llm.requests')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_processor_initialization_without_endpoint(self, mock_llm_init, mock_async_init, mock_requests):
|
||||
"""Test processor initialization without endpoint (should fail)"""
|
||||
# Arrange
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'endpoint': None, # No endpoint provided
|
||||
'token': 'test-token',
|
||||
'temperature': 0.0,
|
||||
'max_output': 4192,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(RuntimeError, match="Azure endpoint not specified"):
|
||||
processor = Processor(**config)
|
||||
|
||||
@patch('trustgraph.model.text_completion.azure.llm.requests')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_processor_initialization_without_token(self, mock_llm_init, mock_async_init, mock_requests):
|
||||
"""Test processor initialization without token (should fail)"""
|
||||
# Arrange
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'endpoint': 'https://test.inference.ai.azure.com/v1/chat/completions',
|
||||
'token': None, # No token provided
|
||||
'temperature': 0.0,
|
||||
'max_output': 4192,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(RuntimeError, match="Azure token not specified"):
|
||||
processor = Processor(**config)
|
||||
|
||||
@patch('trustgraph.model.text_completion.azure.llm.requests')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_processor_initialization_with_custom_parameters(self, mock_llm_init, mock_async_init, mock_requests):
|
||||
"""Test processor initialization with custom parameters"""
|
||||
# Arrange
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'endpoint': 'https://custom.inference.ai.azure.com/v1/chat/completions',
|
||||
'token': 'custom-token',
|
||||
'temperature': 0.7,
|
||||
'max_output': 2048,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
# Act
|
||||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
assert processor.endpoint == 'https://custom.inference.ai.azure.com/v1/chat/completions'
|
||||
assert processor.token == 'custom-token'
|
||||
assert processor.temperature == 0.7
|
||||
assert processor.max_output == 2048
|
||||
assert processor.model == 'AzureAI'
|
||||
|
||||
@patch('trustgraph.model.text_completion.azure.llm.requests')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_processor_initialization_with_defaults(self, mock_llm_init, mock_async_init, mock_requests):
|
||||
"""Test processor initialization with default values"""
|
||||
# Arrange
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
# Only provide required fields, should use defaults
|
||||
config = {
|
||||
'endpoint': 'https://test.inference.ai.azure.com/v1/chat/completions',
|
||||
'token': 'test-token',
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
# Act
|
||||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
assert processor.endpoint == 'https://test.inference.ai.azure.com/v1/chat/completions'
|
||||
assert processor.token == 'test-token'
|
||||
assert processor.temperature == 0.0 # default_temperature
|
||||
assert processor.max_output == 4192 # default_max_output
|
||||
assert processor.model == 'AzureAI' # default_model
|
||||
|
||||
@patch('trustgraph.model.text_completion.azure.llm.requests')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_empty_prompts(self, mock_llm_init, mock_async_init, mock_requests):
|
||||
"""Test content generation with empty prompts"""
|
||||
# Arrange
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {
|
||||
'choices': [{
|
||||
'message': {
|
||||
'content': 'Default response'
|
||||
}
|
||||
}],
|
||||
'usage': {
|
||||
'prompt_tokens': 2,
|
||||
'completion_tokens': 3
|
||||
}
|
||||
}
|
||||
mock_requests.post.return_value = mock_response
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'endpoint': 'https://test.inference.ai.azure.com/v1/chat/completions',
|
||||
'token': 'test-token',
|
||||
'temperature': 0.0,
|
||||
'max_output': 4192,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act
|
||||
result = await processor.generate_content("", "")
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Default response"
|
||||
assert result.in_token == 2
|
||||
assert result.out_token == 3
|
||||
assert result.model == 'AzureAI'
|
||||
|
||||
@patch('trustgraph.model.text_completion.azure.llm.requests')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_build_prompt_structure(self, mock_llm_init, mock_async_init, mock_requests):
|
||||
"""Test that build_prompt creates correct message structure"""
|
||||
# Arrange
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {
|
||||
'choices': [{
|
||||
'message': {
|
||||
'content': 'Response with proper structure'
|
||||
}
|
||||
}],
|
||||
'usage': {
|
||||
'prompt_tokens': 25,
|
||||
'completion_tokens': 15
|
||||
}
|
||||
}
|
||||
mock_requests.post.return_value = mock_response
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'endpoint': 'https://test.inference.ai.azure.com/v1/chat/completions',
|
||||
'token': 'test-token',
|
||||
'temperature': 0.5,
|
||||
'max_output': 1024,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act
|
||||
result = await processor.generate_content("You are a helpful assistant", "What is AI?")
|
||||
|
||||
# Assert
|
||||
assert result.text == "Response with proper structure"
|
||||
assert result.in_token == 25
|
||||
assert result.out_token == 15
|
||||
|
||||
# Verify the request structure
|
||||
mock_requests.post.assert_called_once()
|
||||
call_args = mock_requests.post.call_args
|
||||
|
||||
# Parse the request body
|
||||
import json
|
||||
request_body = json.loads(call_args[1]['data'])
|
||||
|
||||
# Verify message structure
|
||||
assert 'messages' in request_body
|
||||
assert len(request_body['messages']) == 2
|
||||
|
||||
# Check system message
|
||||
assert request_body['messages'][0]['role'] == 'system'
|
||||
assert request_body['messages'][0]['content'] == 'You are a helpful assistant'
|
||||
|
||||
# Check user message
|
||||
assert request_body['messages'][1]['role'] == 'user'
|
||||
assert request_body['messages'][1]['content'] == 'What is AI?'
|
||||
|
||||
# Check parameters
|
||||
assert request_body['temperature'] == 0.5
|
||||
assert request_body['max_tokens'] == 1024
|
||||
assert request_body['top_p'] == 1
|
||||
|
||||
@patch('trustgraph.model.text_completion.azure.llm.requests')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_call_llm_method(self, mock_llm_init, mock_async_init, mock_requests):
|
||||
"""Test the call_llm method directly"""
|
||||
# Arrange
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {
|
||||
'choices': [{
|
||||
'message': {
|
||||
'content': 'Test response'
|
||||
}
|
||||
}],
|
||||
'usage': {
|
||||
'prompt_tokens': 10,
|
||||
'completion_tokens': 5
|
||||
}
|
||||
}
|
||||
mock_requests.post.return_value = mock_response
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'endpoint': 'https://test.inference.ai.azure.com/v1/chat/completions',
|
||||
'token': 'test-token',
|
||||
'temperature': 0.0,
|
||||
'max_output': 4192,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act
|
||||
result = processor.call_llm('{"test": "body"}')
|
||||
|
||||
# Assert
|
||||
assert result == mock_response.json.return_value
|
||||
|
||||
# Verify the request was made correctly
|
||||
mock_requests.post.assert_called_once_with(
|
||||
'https://test.inference.ai.azure.com/v1/chat/completions',
|
||||
data='{"test": "body"}',
|
||||
headers={
|
||||
'Content-Type': 'application/json',
|
||||
'Authorization': 'Bearer test-token'
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pytest.main([__file__])
|
||||
440
tests/unit/test_text_completion/test_claude_processor.py
Normal file
440
tests/unit/test_text_completion/test_claude_processor.py
Normal file
|
|
@ -0,0 +1,440 @@
|
|||
"""
|
||||
Unit tests for trustgraph.model.text_completion.claude
|
||||
Following the same successful pattern as previous tests
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from unittest import IsolatedAsyncioTestCase
|
||||
|
||||
# Import the service under test
|
||||
from trustgraph.model.text_completion.claude.llm import Processor
|
||||
from trustgraph.base import LlmResult
|
||||
from trustgraph.exceptions import TooManyRequests
|
||||
|
||||
|
||||
class TestClaudeProcessorSimple(IsolatedAsyncioTestCase):
|
||||
"""Test Claude processor functionality"""
|
||||
|
||||
@patch('trustgraph.model.text_completion.claude.llm.anthropic.Anthropic')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_processor_initialization_basic(self, mock_llm_init, mock_async_init, mock_anthropic_class):
|
||||
"""Test basic processor initialization"""
|
||||
# Arrange
|
||||
mock_claude_client = MagicMock()
|
||||
mock_anthropic_class.return_value = mock_claude_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'claude-3-5-sonnet-20240620',
|
||||
'api_key': 'test-api-key',
|
||||
'temperature': 0.0,
|
||||
'max_output': 8192,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
# Act
|
||||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
assert processor.model == 'claude-3-5-sonnet-20240620'
|
||||
assert processor.temperature == 0.0
|
||||
assert processor.max_output == 8192
|
||||
assert hasattr(processor, 'claude')
|
||||
mock_anthropic_class.assert_called_once_with(api_key='test-api-key')
|
||||
|
||||
@patch('trustgraph.model.text_completion.claude.llm.anthropic.Anthropic')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_success(self, mock_llm_init, mock_async_init, mock_anthropic_class):
|
||||
"""Test successful content generation"""
|
||||
# Arrange
|
||||
mock_claude_client = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.content = [MagicMock()]
|
||||
mock_response.content[0].text = "Generated response from Claude"
|
||||
mock_response.usage.input_tokens = 25
|
||||
mock_response.usage.output_tokens = 15
|
||||
|
||||
mock_claude_client.messages.create.return_value = mock_response
|
||||
mock_anthropic_class.return_value = mock_claude_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'claude-3-5-sonnet-20240620',
|
||||
'api_key': 'test-api-key',
|
||||
'temperature': 0.0,
|
||||
'max_output': 8192,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act
|
||||
result = await processor.generate_content("System prompt", "User prompt")
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Generated response from Claude"
|
||||
assert result.in_token == 25
|
||||
assert result.out_token == 15
|
||||
assert result.model == 'claude-3-5-sonnet-20240620'
|
||||
|
||||
# Verify the Claude API call
|
||||
mock_claude_client.messages.create.assert_called_once_with(
|
||||
model='claude-3-5-sonnet-20240620',
|
||||
max_tokens=8192,
|
||||
temperature=0.0,
|
||||
system="System prompt",
|
||||
messages=[{
|
||||
"role": "user",
|
||||
"content": [{
|
||||
"type": "text",
|
||||
"text": "User prompt"
|
||||
}]
|
||||
}]
|
||||
)
|
||||
|
||||
@patch('trustgraph.model.text_completion.claude.llm.anthropic.Anthropic')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_rate_limit_error(self, mock_llm_init, mock_async_init, mock_anthropic_class):
|
||||
"""Test rate limit error handling"""
|
||||
# Arrange
|
||||
import anthropic
|
||||
|
||||
mock_claude_client = MagicMock()
|
||||
mock_claude_client.messages.create.side_effect = anthropic.RateLimitError(
|
||||
"Rate limit exceeded",
|
||||
response=MagicMock(),
|
||||
body=None
|
||||
)
|
||||
mock_anthropic_class.return_value = mock_claude_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'claude-3-5-sonnet-20240620',
|
||||
'api_key': 'test-api-key',
|
||||
'temperature': 0.0,
|
||||
'max_output': 8192,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(TooManyRequests):
|
||||
await processor.generate_content("System prompt", "User prompt")
|
||||
|
||||
@patch('trustgraph.model.text_completion.claude.llm.anthropic.Anthropic')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_generic_exception(self, mock_llm_init, mock_async_init, mock_anthropic_class):
|
||||
"""Test handling of generic exceptions"""
|
||||
# Arrange
|
||||
mock_claude_client = MagicMock()
|
||||
mock_claude_client.messages.create.side_effect = Exception("API connection error")
|
||||
mock_anthropic_class.return_value = mock_claude_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'claude-3-5-sonnet-20240620',
|
||||
'api_key': 'test-api-key',
|
||||
'temperature': 0.0,
|
||||
'max_output': 8192,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(Exception, match="API connection error"):
|
||||
await processor.generate_content("System prompt", "User prompt")
|
||||
|
||||
@patch('trustgraph.model.text_completion.claude.llm.anthropic.Anthropic')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_processor_initialization_without_api_key(self, mock_llm_init, mock_async_init, mock_anthropic_class):
|
||||
"""Test processor initialization without API key (should fail)"""
|
||||
# Arrange
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'claude-3-5-sonnet-20240620',
|
||||
'api_key': None, # No API key provided
|
||||
'temperature': 0.0,
|
||||
'max_output': 8192,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(RuntimeError, match="Claude API key not specified"):
|
||||
processor = Processor(**config)
|
||||
|
||||
@patch('trustgraph.model.text_completion.claude.llm.anthropic.Anthropic')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_processor_initialization_with_custom_parameters(self, mock_llm_init, mock_async_init, mock_anthropic_class):
|
||||
"""Test processor initialization with custom parameters"""
|
||||
# Arrange
|
||||
mock_claude_client = MagicMock()
|
||||
mock_anthropic_class.return_value = mock_claude_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'claude-3-haiku-20240307',
|
||||
'api_key': 'custom-api-key',
|
||||
'temperature': 0.7,
|
||||
'max_output': 4096,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
# Act
|
||||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
assert processor.model == 'claude-3-haiku-20240307'
|
||||
assert processor.temperature == 0.7
|
||||
assert processor.max_output == 4096
|
||||
mock_anthropic_class.assert_called_once_with(api_key='custom-api-key')
|
||||
|
||||
@patch('trustgraph.model.text_completion.claude.llm.anthropic.Anthropic')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_processor_initialization_with_defaults(self, mock_llm_init, mock_async_init, mock_anthropic_class):
|
||||
"""Test processor initialization with default values"""
|
||||
# Arrange
|
||||
mock_claude_client = MagicMock()
|
||||
mock_anthropic_class.return_value = mock_claude_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
# Only provide required fields, should use defaults
|
||||
config = {
|
||||
'api_key': 'test-api-key',
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
# Act
|
||||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
assert processor.model == 'claude-3-5-sonnet-20240620' # default_model
|
||||
assert processor.temperature == 0.0 # default_temperature
|
||||
assert processor.max_output == 8192 # default_max_output
|
||||
mock_anthropic_class.assert_called_once_with(api_key='test-api-key')
|
||||
|
||||
@patch('trustgraph.model.text_completion.claude.llm.anthropic.Anthropic')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_empty_prompts(self, mock_llm_init, mock_async_init, mock_anthropic_class):
|
||||
"""Test content generation with empty prompts"""
|
||||
# Arrange
|
||||
mock_claude_client = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.content = [MagicMock()]
|
||||
mock_response.content[0].text = "Default response"
|
||||
mock_response.usage.input_tokens = 2
|
||||
mock_response.usage.output_tokens = 3
|
||||
|
||||
mock_claude_client.messages.create.return_value = mock_response
|
||||
mock_anthropic_class.return_value = mock_claude_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'claude-3-5-sonnet-20240620',
|
||||
'api_key': 'test-api-key',
|
||||
'temperature': 0.0,
|
||||
'max_output': 8192,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act
|
||||
result = await processor.generate_content("", "")
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Default response"
|
||||
assert result.in_token == 2
|
||||
assert result.out_token == 3
|
||||
assert result.model == 'claude-3-5-sonnet-20240620'
|
||||
|
||||
# Verify the system prompt and user content are handled correctly
|
||||
call_args = mock_claude_client.messages.create.call_args
|
||||
assert call_args[1]['system'] == ""
|
||||
assert call_args[1]['messages'][0]['content'][0]['text'] == ""
|
||||
|
||||
@patch('trustgraph.model.text_completion.claude.llm.anthropic.Anthropic')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_message_structure(self, mock_llm_init, mock_async_init, mock_anthropic_class):
|
||||
"""Test that Claude messages are structured correctly"""
|
||||
# Arrange
|
||||
mock_claude_client = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.content = [MagicMock()]
|
||||
mock_response.content[0].text = "Response with proper structure"
|
||||
mock_response.usage.input_tokens = 30
|
||||
mock_response.usage.output_tokens = 20
|
||||
|
||||
mock_claude_client.messages.create.return_value = mock_response
|
||||
mock_anthropic_class.return_value = mock_claude_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'claude-3-5-sonnet-20240620',
|
||||
'api_key': 'test-api-key',
|
||||
'temperature': 0.5,
|
||||
'max_output': 1024,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act
|
||||
result = await processor.generate_content("You are a helpful assistant", "What is AI?")
|
||||
|
||||
# Assert
|
||||
assert result.text == "Response with proper structure"
|
||||
assert result.in_token == 30
|
||||
assert result.out_token == 20
|
||||
|
||||
# Verify the message structure matches Claude API format
|
||||
call_args = mock_claude_client.messages.create.call_args
|
||||
|
||||
# Check system prompt
|
||||
assert call_args[1]['system'] == "You are a helpful assistant"
|
||||
|
||||
# Check user message structure
|
||||
messages = call_args[1]['messages']
|
||||
assert len(messages) == 1
|
||||
assert messages[0]['role'] == 'user'
|
||||
assert messages[0]['content'][0]['type'] == 'text'
|
||||
assert messages[0]['content'][0]['text'] == "What is AI?"
|
||||
|
||||
# Verify other parameters
|
||||
assert call_args[1]['model'] == 'claude-3-5-sonnet-20240620'
|
||||
assert call_args[1]['temperature'] == 0.5
|
||||
assert call_args[1]['max_tokens'] == 1024
|
||||
|
||||
@patch('trustgraph.model.text_completion.claude.llm.anthropic.Anthropic')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_multiple_content_blocks(self, mock_llm_init, mock_async_init, mock_anthropic_class):
|
||||
"""Test handling of multiple content blocks in response"""
|
||||
# Arrange
|
||||
mock_claude_client = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
|
||||
# Mock multiple content blocks (Claude can return multiple)
|
||||
mock_content_1 = MagicMock()
|
||||
mock_content_1.text = "First part of response"
|
||||
mock_content_2 = MagicMock()
|
||||
mock_content_2.text = "Second part of response"
|
||||
mock_response.content = [mock_content_1, mock_content_2]
|
||||
|
||||
mock_response.usage.input_tokens = 40
|
||||
mock_response.usage.output_tokens = 30
|
||||
|
||||
mock_claude_client.messages.create.return_value = mock_response
|
||||
mock_anthropic_class.return_value = mock_claude_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'claude-3-5-sonnet-20240620',
|
||||
'api_key': 'test-api-key',
|
||||
'temperature': 0.0,
|
||||
'max_output': 8192,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act
|
||||
result = await processor.generate_content("System prompt", "User prompt")
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
# Should take the first content block
|
||||
assert result.text == "First part of response"
|
||||
assert result.in_token == 40
|
||||
assert result.out_token == 30
|
||||
assert result.model == 'claude-3-5-sonnet-20240620'
|
||||
|
||||
@patch('trustgraph.model.text_completion.claude.llm.anthropic.Anthropic')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_claude_client_initialization(self, mock_llm_init, mock_async_init, mock_anthropic_class):
|
||||
"""Test that Claude client is initialized correctly"""
|
||||
# Arrange
|
||||
mock_claude_client = MagicMock()
|
||||
mock_anthropic_class.return_value = mock_claude_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'claude-3-opus-20240229',
|
||||
'api_key': 'sk-ant-test-key',
|
||||
'temperature': 0.0,
|
||||
'max_output': 8192,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
# Act
|
||||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
# Verify Anthropic client was called with correct API key
|
||||
mock_anthropic_class.assert_called_once_with(api_key='sk-ant-test-key')
|
||||
|
||||
# Verify processor has the client
|
||||
assert processor.claude == mock_claude_client
|
||||
assert processor.model == 'claude-3-opus-20240229'
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pytest.main([__file__])
|
||||
447
tests/unit/test_text_completion/test_cohere_processor.py
Normal file
447
tests/unit/test_text_completion/test_cohere_processor.py
Normal file
|
|
@ -0,0 +1,447 @@
|
|||
"""
|
||||
Unit tests for trustgraph.model.text_completion.cohere
|
||||
Following the same successful pattern as previous tests
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from unittest import IsolatedAsyncioTestCase
|
||||
|
||||
# Import the service under test
|
||||
from trustgraph.model.text_completion.cohere.llm import Processor
|
||||
from trustgraph.base import LlmResult
|
||||
from trustgraph.exceptions import TooManyRequests
|
||||
|
||||
|
||||
class TestCohereProcessorSimple(IsolatedAsyncioTestCase):
|
||||
"""Test Cohere processor functionality"""
|
||||
|
||||
@patch('trustgraph.model.text_completion.cohere.llm.cohere.Client')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_processor_initialization_basic(self, mock_llm_init, mock_async_init, mock_cohere_class):
|
||||
"""Test basic processor initialization"""
|
||||
# Arrange
|
||||
mock_cohere_client = MagicMock()
|
||||
mock_cohere_class.return_value = mock_cohere_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'c4ai-aya-23-8b',
|
||||
'api_key': 'test-api-key',
|
||||
'temperature': 0.0,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
# Act
|
||||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
assert processor.model == 'c4ai-aya-23-8b'
|
||||
assert processor.temperature == 0.0
|
||||
assert hasattr(processor, 'cohere')
|
||||
mock_cohere_class.assert_called_once_with(api_key='test-api-key')
|
||||
|
||||
@patch('trustgraph.model.text_completion.cohere.llm.cohere.Client')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_success(self, mock_llm_init, mock_async_init, mock_cohere_class):
|
||||
"""Test successful content generation"""
|
||||
# Arrange
|
||||
mock_cohere_client = MagicMock()
|
||||
mock_output = MagicMock()
|
||||
mock_output.text = "Generated response from Cohere"
|
||||
mock_output.meta.billed_units.input_tokens = 25
|
||||
mock_output.meta.billed_units.output_tokens = 15
|
||||
|
||||
mock_cohere_client.chat.return_value = mock_output
|
||||
mock_cohere_class.return_value = mock_cohere_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'c4ai-aya-23-8b',
|
||||
'api_key': 'test-api-key',
|
||||
'temperature': 0.0,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act
|
||||
result = await processor.generate_content("System prompt", "User prompt")
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Generated response from Cohere"
|
||||
assert result.in_token == 25
|
||||
assert result.out_token == 15
|
||||
assert result.model == 'c4ai-aya-23-8b'
|
||||
|
||||
# Verify the Cohere API call
|
||||
mock_cohere_client.chat.assert_called_once_with(
|
||||
model='c4ai-aya-23-8b',
|
||||
message="User prompt",
|
||||
preamble="System prompt",
|
||||
temperature=0.0,
|
||||
chat_history=[],
|
||||
prompt_truncation='auto',
|
||||
connectors=[]
|
||||
)
|
||||
|
||||
@patch('trustgraph.model.text_completion.cohere.llm.cohere.Client')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_rate_limit_error(self, mock_llm_init, mock_async_init, mock_cohere_class):
|
||||
"""Test rate limit error handling"""
|
||||
# Arrange
|
||||
import cohere
|
||||
|
||||
mock_cohere_client = MagicMock()
|
||||
mock_cohere_client.chat.side_effect = cohere.TooManyRequestsError("Rate limit exceeded")
|
||||
mock_cohere_class.return_value = mock_cohere_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'c4ai-aya-23-8b',
|
||||
'api_key': 'test-api-key',
|
||||
'temperature': 0.0,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(TooManyRequests):
|
||||
await processor.generate_content("System prompt", "User prompt")
|
||||
|
||||
@patch('trustgraph.model.text_completion.cohere.llm.cohere.Client')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_generic_exception(self, mock_llm_init, mock_async_init, mock_cohere_class):
|
||||
"""Test handling of generic exceptions"""
|
||||
# Arrange
|
||||
mock_cohere_client = MagicMock()
|
||||
mock_cohere_client.chat.side_effect = Exception("API connection error")
|
||||
mock_cohere_class.return_value = mock_cohere_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'c4ai-aya-23-8b',
|
||||
'api_key': 'test-api-key',
|
||||
'temperature': 0.0,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(Exception, match="API connection error"):
|
||||
await processor.generate_content("System prompt", "User prompt")
|
||||
|
||||
@patch('trustgraph.model.text_completion.cohere.llm.cohere.Client')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_processor_initialization_without_api_key(self, mock_llm_init, mock_async_init, mock_cohere_class):
|
||||
"""Test processor initialization without API key (should fail)"""
|
||||
# Arrange
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'c4ai-aya-23-8b',
|
||||
'api_key': None, # No API key provided
|
||||
'temperature': 0.0,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(RuntimeError, match="Cohere API key not specified"):
|
||||
processor = Processor(**config)
|
||||
|
||||
@patch('trustgraph.model.text_completion.cohere.llm.cohere.Client')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_processor_initialization_with_custom_parameters(self, mock_llm_init, mock_async_init, mock_cohere_class):
|
||||
"""Test processor initialization with custom parameters"""
|
||||
# Arrange
|
||||
mock_cohere_client = MagicMock()
|
||||
mock_cohere_class.return_value = mock_cohere_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'command-light',
|
||||
'api_key': 'custom-api-key',
|
||||
'temperature': 0.7,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
# Act
|
||||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
assert processor.model == 'command-light'
|
||||
assert processor.temperature == 0.7
|
||||
mock_cohere_class.assert_called_once_with(api_key='custom-api-key')
|
||||
|
||||
@patch('trustgraph.model.text_completion.cohere.llm.cohere.Client')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_processor_initialization_with_defaults(self, mock_llm_init, mock_async_init, mock_cohere_class):
|
||||
"""Test processor initialization with default values"""
|
||||
# Arrange
|
||||
mock_cohere_client = MagicMock()
|
||||
mock_cohere_class.return_value = mock_cohere_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
# Only provide required fields, should use defaults
|
||||
config = {
|
||||
'api_key': 'test-api-key',
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
# Act
|
||||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
assert processor.model == 'c4ai-aya-23-8b' # default_model
|
||||
assert processor.temperature == 0.0 # default_temperature
|
||||
mock_cohere_class.assert_called_once_with(api_key='test-api-key')
|
||||
|
||||
@patch('trustgraph.model.text_completion.cohere.llm.cohere.Client')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_empty_prompts(self, mock_llm_init, mock_async_init, mock_cohere_class):
|
||||
"""Test content generation with empty prompts"""
|
||||
# Arrange
|
||||
mock_cohere_client = MagicMock()
|
||||
mock_output = MagicMock()
|
||||
mock_output.text = "Default response"
|
||||
mock_output.meta.billed_units.input_tokens = 2
|
||||
mock_output.meta.billed_units.output_tokens = 3
|
||||
|
||||
mock_cohere_client.chat.return_value = mock_output
|
||||
mock_cohere_class.return_value = mock_cohere_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'c4ai-aya-23-8b',
|
||||
'api_key': 'test-api-key',
|
||||
'temperature': 0.0,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act
|
||||
result = await processor.generate_content("", "")
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Default response"
|
||||
assert result.in_token == 2
|
||||
assert result.out_token == 3
|
||||
assert result.model == 'c4ai-aya-23-8b'
|
||||
|
||||
# Verify the preamble and message are handled correctly
|
||||
call_args = mock_cohere_client.chat.call_args
|
||||
assert call_args[1]['preamble'] == ""
|
||||
assert call_args[1]['message'] == ""
|
||||
|
||||
@patch('trustgraph.model.text_completion.cohere.llm.cohere.Client')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_chat_structure(self, mock_llm_init, mock_async_init, mock_cohere_class):
|
||||
"""Test that Cohere chat is structured correctly"""
|
||||
# Arrange
|
||||
mock_cohere_client = MagicMock()
|
||||
mock_output = MagicMock()
|
||||
mock_output.text = "Response with proper structure"
|
||||
mock_output.meta.billed_units.input_tokens = 30
|
||||
mock_output.meta.billed_units.output_tokens = 20
|
||||
|
||||
mock_cohere_client.chat.return_value = mock_output
|
||||
mock_cohere_class.return_value = mock_cohere_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'c4ai-aya-23-8b',
|
||||
'api_key': 'test-api-key',
|
||||
'temperature': 0.5,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act
|
||||
result = await processor.generate_content("You are a helpful assistant", "What is AI?")
|
||||
|
||||
# Assert
|
||||
assert result.text == "Response with proper structure"
|
||||
assert result.in_token == 30
|
||||
assert result.out_token == 20
|
||||
|
||||
# Verify the chat structure matches Cohere API format
|
||||
call_args = mock_cohere_client.chat.call_args
|
||||
|
||||
# Check parameters
|
||||
assert call_args[1]['model'] == 'c4ai-aya-23-8b'
|
||||
assert call_args[1]['message'] == "What is AI?"
|
||||
assert call_args[1]['preamble'] == "You are a helpful assistant"
|
||||
assert call_args[1]['temperature'] == 0.5
|
||||
assert call_args[1]['chat_history'] == []
|
||||
assert call_args[1]['prompt_truncation'] == 'auto'
|
||||
assert call_args[1]['connectors'] == []
|
||||
|
||||
@patch('trustgraph.model.text_completion.cohere.llm.cohere.Client')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_token_parsing(self, mock_llm_init, mock_async_init, mock_cohere_class):
|
||||
"""Test token parsing from Cohere response"""
|
||||
# Arrange
|
||||
mock_cohere_client = MagicMock()
|
||||
mock_output = MagicMock()
|
||||
mock_output.text = "Token parsing test"
|
||||
mock_output.meta.billed_units.input_tokens = 50
|
||||
mock_output.meta.billed_units.output_tokens = 25
|
||||
|
||||
mock_cohere_client.chat.return_value = mock_output
|
||||
mock_cohere_class.return_value = mock_cohere_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'c4ai-aya-23-8b',
|
||||
'api_key': 'test-api-key',
|
||||
'temperature': 0.0,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act
|
||||
result = await processor.generate_content("System", "User query")
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Token parsing test"
|
||||
assert result.in_token == 50
|
||||
assert result.out_token == 25
|
||||
assert result.model == 'c4ai-aya-23-8b'
|
||||
|
||||
@patch('trustgraph.model.text_completion.cohere.llm.cohere.Client')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_cohere_client_initialization(self, mock_llm_init, mock_async_init, mock_cohere_class):
|
||||
"""Test that Cohere client is initialized correctly"""
|
||||
# Arrange
|
||||
mock_cohere_client = MagicMock()
|
||||
mock_cohere_class.return_value = mock_cohere_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'command-r',
|
||||
'api_key': 'co-test-key',
|
||||
'temperature': 0.0,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
# Act
|
||||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
# Verify Cohere client was called with correct API key
|
||||
mock_cohere_class.assert_called_once_with(api_key='co-test-key')
|
||||
|
||||
# Verify processor has the client
|
||||
assert processor.cohere == mock_cohere_client
|
||||
assert processor.model == 'command-r'
|
||||
|
||||
@patch('trustgraph.model.text_completion.cohere.llm.cohere.Client')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_chat_parameters(self, mock_llm_init, mock_async_init, mock_cohere_class):
|
||||
"""Test that all chat parameters are passed correctly"""
|
||||
# Arrange
|
||||
mock_cohere_client = MagicMock()
|
||||
mock_output = MagicMock()
|
||||
mock_output.text = "Chat parameter test"
|
||||
mock_output.meta.billed_units.input_tokens = 20
|
||||
mock_output.meta.billed_units.output_tokens = 10
|
||||
|
||||
mock_cohere_client.chat.return_value = mock_output
|
||||
mock_cohere_class.return_value = mock_cohere_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'c4ai-aya-23-8b',
|
||||
'api_key': 'test-api-key',
|
||||
'temperature': 0.3,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act
|
||||
result = await processor.generate_content("System instructions", "User question")
|
||||
|
||||
# Assert
|
||||
assert result.text == "Chat parameter test"
|
||||
|
||||
# Verify all parameters are passed correctly
|
||||
call_args = mock_cohere_client.chat.call_args
|
||||
assert call_args[1]['model'] == 'c4ai-aya-23-8b'
|
||||
assert call_args[1]['message'] == "User question"
|
||||
assert call_args[1]['preamble'] == "System instructions"
|
||||
assert call_args[1]['temperature'] == 0.3
|
||||
assert call_args[1]['chat_history'] == []
|
||||
assert call_args[1]['prompt_truncation'] == 'auto'
|
||||
assert call_args[1]['connectors'] == []
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pytest.main([__file__])
|
||||
482
tests/unit/test_text_completion/test_googleaistudio_processor.py
Normal file
482
tests/unit/test_text_completion/test_googleaistudio_processor.py
Normal file
|
|
@ -0,0 +1,482 @@
|
|||
"""
|
||||
Unit tests for trustgraph.model.text_completion.googleaistudio
|
||||
Following the same successful pattern as previous tests
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from unittest import IsolatedAsyncioTestCase
|
||||
|
||||
# Import the service under test
|
||||
from trustgraph.model.text_completion.googleaistudio.llm import Processor
|
||||
from trustgraph.base import LlmResult
|
||||
from trustgraph.exceptions import TooManyRequests
|
||||
|
||||
|
||||
class TestGoogleAIStudioProcessorSimple(IsolatedAsyncioTestCase):
|
||||
"""Test Google AI Studio processor functionality"""
|
||||
|
||||
@patch('trustgraph.model.text_completion.googleaistudio.llm.genai.Client')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_processor_initialization_basic(self, mock_llm_init, mock_async_init, mock_genai_class):
|
||||
"""Test basic processor initialization"""
|
||||
# Arrange
|
||||
mock_genai_client = MagicMock()
|
||||
mock_genai_class.return_value = mock_genai_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'gemini-2.0-flash-001',
|
||||
'api_key': 'test-api-key',
|
||||
'temperature': 0.0,
|
||||
'max_output': 8192,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
# Act
|
||||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
assert processor.model == 'gemini-2.0-flash-001'
|
||||
assert processor.temperature == 0.0
|
||||
assert processor.max_output == 8192
|
||||
assert hasattr(processor, 'client')
|
||||
assert hasattr(processor, 'safety_settings')
|
||||
assert len(processor.safety_settings) == 4 # 4 safety categories
|
||||
mock_genai_class.assert_called_once_with(api_key='test-api-key')
|
||||
|
||||
@patch('trustgraph.model.text_completion.googleaistudio.llm.genai.Client')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_success(self, mock_llm_init, mock_async_init, mock_genai_class):
|
||||
"""Test successful content generation"""
|
||||
# Arrange
|
||||
mock_genai_client = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.text = "Generated response from Google AI Studio"
|
||||
mock_response.usage_metadata.prompt_token_count = 25
|
||||
mock_response.usage_metadata.candidates_token_count = 15
|
||||
|
||||
mock_genai_client.models.generate_content.return_value = mock_response
|
||||
mock_genai_class.return_value = mock_genai_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'gemini-2.0-flash-001',
|
||||
'api_key': 'test-api-key',
|
||||
'temperature': 0.0,
|
||||
'max_output': 8192,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act
|
||||
result = await processor.generate_content("System prompt", "User prompt")
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Generated response from Google AI Studio"
|
||||
assert result.in_token == 25
|
||||
assert result.out_token == 15
|
||||
assert result.model == 'gemini-2.0-flash-001'
|
||||
|
||||
# Verify the Google AI Studio API call
|
||||
mock_genai_client.models.generate_content.assert_called_once()
|
||||
call_args = mock_genai_client.models.generate_content.call_args
|
||||
assert call_args[1]['model'] == 'gemini-2.0-flash-001'
|
||||
assert call_args[1]['contents'] == "User prompt"
|
||||
|
||||
@patch('trustgraph.model.text_completion.googleaistudio.llm.genai.Client')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_rate_limit_error(self, mock_llm_init, mock_async_init, mock_genai_class):
|
||||
"""Test rate limit error handling"""
|
||||
# Arrange
|
||||
from google.api_core.exceptions import ResourceExhausted
|
||||
|
||||
mock_genai_client = MagicMock()
|
||||
mock_genai_client.models.generate_content.side_effect = ResourceExhausted("Rate limit exceeded")
|
||||
mock_genai_class.return_value = mock_genai_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'gemini-2.0-flash-001',
|
||||
'api_key': 'test-api-key',
|
||||
'temperature': 0.0,
|
||||
'max_output': 8192,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(TooManyRequests):
|
||||
await processor.generate_content("System prompt", "User prompt")
|
||||
|
||||
@patch('trustgraph.model.text_completion.googleaistudio.llm.genai.Client')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_generic_exception(self, mock_llm_init, mock_async_init, mock_genai_class):
|
||||
"""Test handling of generic exceptions"""
|
||||
# Arrange
|
||||
mock_genai_client = MagicMock()
|
||||
mock_genai_client.models.generate_content.side_effect = Exception("API connection error")
|
||||
mock_genai_class.return_value = mock_genai_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'gemini-2.0-flash-001',
|
||||
'api_key': 'test-api-key',
|
||||
'temperature': 0.0,
|
||||
'max_output': 8192,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(Exception, match="API connection error"):
|
||||
await processor.generate_content("System prompt", "User prompt")
|
||||
|
||||
@patch('trustgraph.model.text_completion.googleaistudio.llm.genai.Client')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_processor_initialization_without_api_key(self, mock_llm_init, mock_async_init, mock_genai_class):
|
||||
"""Test processor initialization without API key (should fail)"""
|
||||
# Arrange
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'gemini-2.0-flash-001',
|
||||
'api_key': None, # No API key provided
|
||||
'temperature': 0.0,
|
||||
'max_output': 8192,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(RuntimeError, match="Google AI Studio API key not specified"):
|
||||
processor = Processor(**config)
|
||||
|
||||
@patch('trustgraph.model.text_completion.googleaistudio.llm.genai.Client')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_processor_initialization_with_custom_parameters(self, mock_llm_init, mock_async_init, mock_genai_class):
|
||||
"""Test processor initialization with custom parameters"""
|
||||
# Arrange
|
||||
mock_genai_client = MagicMock()
|
||||
mock_genai_class.return_value = mock_genai_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'gemini-1.5-pro',
|
||||
'api_key': 'custom-api-key',
|
||||
'temperature': 0.7,
|
||||
'max_output': 4096,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
# Act
|
||||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
assert processor.model == 'gemini-1.5-pro'
|
||||
assert processor.temperature == 0.7
|
||||
assert processor.max_output == 4096
|
||||
mock_genai_class.assert_called_once_with(api_key='custom-api-key')
|
||||
|
||||
@patch('trustgraph.model.text_completion.googleaistudio.llm.genai.Client')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_processor_initialization_with_defaults(self, mock_llm_init, mock_async_init, mock_genai_class):
|
||||
"""Test processor initialization with default values"""
|
||||
# Arrange
|
||||
mock_genai_client = MagicMock()
|
||||
mock_genai_class.return_value = mock_genai_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
# Only provide required fields, should use defaults
|
||||
config = {
|
||||
'api_key': 'test-api-key',
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
# Act
|
||||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
assert processor.model == 'gemini-2.0-flash-001' # default_model
|
||||
assert processor.temperature == 0.0 # default_temperature
|
||||
assert processor.max_output == 8192 # default_max_output
|
||||
mock_genai_class.assert_called_once_with(api_key='test-api-key')
|
||||
|
||||
@patch('trustgraph.model.text_completion.googleaistudio.llm.genai.Client')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_empty_prompts(self, mock_llm_init, mock_async_init, mock_genai_class):
|
||||
"""Test content generation with empty prompts"""
|
||||
# Arrange
|
||||
mock_genai_client = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.text = "Default response"
|
||||
mock_response.usage_metadata.prompt_token_count = 2
|
||||
mock_response.usage_metadata.candidates_token_count = 3
|
||||
|
||||
mock_genai_client.models.generate_content.return_value = mock_response
|
||||
mock_genai_class.return_value = mock_genai_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'gemini-2.0-flash-001',
|
||||
'api_key': 'test-api-key',
|
||||
'temperature': 0.0,
|
||||
'max_output': 8192,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act
|
||||
result = await processor.generate_content("", "")
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Default response"
|
||||
assert result.in_token == 2
|
||||
assert result.out_token == 3
|
||||
assert result.model == 'gemini-2.0-flash-001'
|
||||
|
||||
# Verify the system instruction and content are handled correctly
|
||||
call_args = mock_genai_client.models.generate_content.call_args
|
||||
assert call_args[1]['contents'] == ""
|
||||
|
||||
@patch('trustgraph.model.text_completion.googleaistudio.llm.genai.Client')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_configuration_structure(self, mock_llm_init, mock_async_init, mock_genai_class):
|
||||
"""Test that generation configuration is structured correctly"""
|
||||
# Arrange
|
||||
mock_genai_client = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.text = "Response with proper structure"
|
||||
mock_response.usage_metadata.prompt_token_count = 30
|
||||
mock_response.usage_metadata.candidates_token_count = 20
|
||||
|
||||
mock_genai_client.models.generate_content.return_value = mock_response
|
||||
mock_genai_class.return_value = mock_genai_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'gemini-2.0-flash-001',
|
||||
'api_key': 'test-api-key',
|
||||
'temperature': 0.5,
|
||||
'max_output': 1024,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act
|
||||
result = await processor.generate_content("You are a helpful assistant", "What is AI?")
|
||||
|
||||
# Assert
|
||||
assert result.text == "Response with proper structure"
|
||||
assert result.in_token == 30
|
||||
assert result.out_token == 20
|
||||
|
||||
# Verify the generation configuration
|
||||
call_args = mock_genai_client.models.generate_content.call_args
|
||||
config_arg = call_args[1]['config']
|
||||
|
||||
# Check that the configuration has the right structure
|
||||
assert call_args[1]['model'] == 'gemini-2.0-flash-001'
|
||||
assert call_args[1]['contents'] == "What is AI?"
|
||||
# Config should be a GenerateContentConfig object
|
||||
|
||||
@patch('trustgraph.model.text_completion.googleaistudio.llm.genai.Client')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_safety_settings_initialization(self, mock_llm_init, mock_async_init, mock_genai_class):
|
||||
"""Test that safety settings are initialized correctly"""
|
||||
# Arrange
|
||||
mock_genai_client = MagicMock()
|
||||
mock_genai_class.return_value = mock_genai_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'gemini-2.0-flash-001',
|
||||
'api_key': 'test-api-key',
|
||||
'temperature': 0.0,
|
||||
'max_output': 8192,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
# Act
|
||||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
assert hasattr(processor, 'safety_settings')
|
||||
assert len(processor.safety_settings) == 4
|
||||
# Should have 4 safety categories: hate speech, harassment, sexually explicit, dangerous content
|
||||
|
||||
@patch('trustgraph.model.text_completion.googleaistudio.llm.genai.Client')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_token_parsing(self, mock_llm_init, mock_async_init, mock_genai_class):
|
||||
"""Test token parsing from Google AI Studio response"""
|
||||
# Arrange
|
||||
mock_genai_client = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.text = "Token parsing test"
|
||||
mock_response.usage_metadata.prompt_token_count = 50
|
||||
mock_response.usage_metadata.candidates_token_count = 25
|
||||
|
||||
mock_genai_client.models.generate_content.return_value = mock_response
|
||||
mock_genai_class.return_value = mock_genai_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'gemini-2.0-flash-001',
|
||||
'api_key': 'test-api-key',
|
||||
'temperature': 0.0,
|
||||
'max_output': 8192,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act
|
||||
result = await processor.generate_content("System", "User query")
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Token parsing test"
|
||||
assert result.in_token == 50
|
||||
assert result.out_token == 25
|
||||
assert result.model == 'gemini-2.0-flash-001'
|
||||
|
||||
@patch('trustgraph.model.text_completion.googleaistudio.llm.genai.Client')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_genai_client_initialization(self, mock_llm_init, mock_async_init, mock_genai_class):
|
||||
"""Test that Google AI Studio client is initialized correctly"""
|
||||
# Arrange
|
||||
mock_genai_client = MagicMock()
|
||||
mock_genai_class.return_value = mock_genai_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'gemini-1.5-flash',
|
||||
'api_key': 'gai-test-key',
|
||||
'temperature': 0.0,
|
||||
'max_output': 8192,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
# Act
|
||||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
# Verify Google AI Studio client was called with correct API key
|
||||
mock_genai_class.assert_called_once_with(api_key='gai-test-key')
|
||||
|
||||
# Verify processor has the client
|
||||
assert processor.client == mock_genai_client
|
||||
assert processor.model == 'gemini-1.5-flash'
|
||||
|
||||
@patch('trustgraph.model.text_completion.googleaistudio.llm.genai.Client')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_system_instruction(self, mock_llm_init, mock_async_init, mock_genai_class):
|
||||
"""Test that system instruction is handled correctly"""
|
||||
# Arrange
|
||||
mock_genai_client = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.text = "System instruction test"
|
||||
mock_response.usage_metadata.prompt_token_count = 35
|
||||
mock_response.usage_metadata.candidates_token_count = 25
|
||||
|
||||
mock_genai_client.models.generate_content.return_value = mock_response
|
||||
mock_genai_class.return_value = mock_genai_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'gemini-2.0-flash-001',
|
||||
'api_key': 'test-api-key',
|
||||
'temperature': 0.0,
|
||||
'max_output': 8192,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act
|
||||
result = await processor.generate_content("Be helpful and concise", "Explain quantum computing")
|
||||
|
||||
# Assert
|
||||
assert result.text == "System instruction test"
|
||||
assert result.in_token == 35
|
||||
assert result.out_token == 25
|
||||
|
||||
# Verify the system instruction is passed in the config
|
||||
call_args = mock_genai_client.models.generate_content.call_args
|
||||
config_arg = call_args[1]['config']
|
||||
# The system instruction should be in the config object
|
||||
assert call_args[1]['contents'] == "Explain quantum computing"
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pytest.main([__file__])
|
||||
454
tests/unit/test_text_completion/test_llamafile_processor.py
Normal file
454
tests/unit/test_text_completion/test_llamafile_processor.py
Normal file
|
|
@ -0,0 +1,454 @@
|
|||
"""
|
||||
Unit tests for trustgraph.model.text_completion.llamafile
|
||||
Following the same successful pattern as previous tests
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from unittest import IsolatedAsyncioTestCase
|
||||
|
||||
# Import the service under test
|
||||
from trustgraph.model.text_completion.llamafile.llm import Processor
|
||||
from trustgraph.base import LlmResult
|
||||
from trustgraph.exceptions import TooManyRequests
|
||||
|
||||
|
||||
class TestLlamaFileProcessorSimple(IsolatedAsyncioTestCase):
|
||||
"""Test LlamaFile processor functionality"""
|
||||
|
||||
@patch('trustgraph.model.text_completion.llamafile.llm.OpenAI')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_processor_initialization_basic(self, mock_llm_init, mock_async_init, mock_openai_class):
|
||||
"""Test basic processor initialization"""
|
||||
# Arrange
|
||||
mock_openai_client = MagicMock()
|
||||
mock_openai_class.return_value = mock_openai_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'LLaMA_CPP',
|
||||
'llamafile': 'http://localhost:8080/v1',
|
||||
'temperature': 0.0,
|
||||
'max_output': 4096,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
# Act
|
||||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
assert processor.model == 'LLaMA_CPP'
|
||||
assert processor.llamafile == 'http://localhost:8080/v1'
|
||||
assert processor.temperature == 0.0
|
||||
assert processor.max_output == 4096
|
||||
assert hasattr(processor, 'openai')
|
||||
mock_openai_class.assert_called_once_with(
|
||||
base_url='http://localhost:8080/v1',
|
||||
api_key='sk-no-key-required'
|
||||
)
|
||||
|
||||
@patch('trustgraph.model.text_completion.llamafile.llm.OpenAI')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_success(self, mock_llm_init, mock_async_init, mock_openai_class):
|
||||
"""Test successful content generation"""
|
||||
# Arrange
|
||||
mock_openai_client = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices = [MagicMock()]
|
||||
mock_response.choices[0].message.content = "Generated response from LlamaFile"
|
||||
mock_response.usage.prompt_tokens = 20
|
||||
mock_response.usage.completion_tokens = 12
|
||||
|
||||
mock_openai_client.chat.completions.create.return_value = mock_response
|
||||
mock_openai_class.return_value = mock_openai_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'LLaMA_CPP',
|
||||
'llamafile': 'http://localhost:8080/v1',
|
||||
'temperature': 0.0,
|
||||
'max_output': 4096,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act
|
||||
result = await processor.generate_content("System prompt", "User prompt")
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Generated response from LlamaFile"
|
||||
assert result.in_token == 20
|
||||
assert result.out_token == 12
|
||||
assert result.model == 'llama.cpp' # Note: model in result is hardcoded to 'llama.cpp'
|
||||
|
||||
# Verify the OpenAI API call structure
|
||||
mock_openai_client.chat.completions.create.assert_called_once_with(
|
||||
model='LLaMA_CPP',
|
||||
messages=[{
|
||||
"role": "user",
|
||||
"content": "System prompt\n\nUser prompt"
|
||||
}]
|
||||
)
|
||||
|
||||
@patch('trustgraph.model.text_completion.llamafile.llm.OpenAI')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_generic_exception(self, mock_llm_init, mock_async_init, mock_openai_class):
|
||||
"""Test handling of generic exceptions"""
|
||||
# Arrange
|
||||
mock_openai_client = MagicMock()
|
||||
mock_openai_client.chat.completions.create.side_effect = Exception("Connection error")
|
||||
mock_openai_class.return_value = mock_openai_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'LLaMA_CPP',
|
||||
'llamafile': 'http://localhost:8080/v1',
|
||||
'temperature': 0.0,
|
||||
'max_output': 4096,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(Exception, match="Connection error"):
|
||||
await processor.generate_content("System prompt", "User prompt")
|
||||
|
||||
@patch('trustgraph.model.text_completion.llamafile.llm.OpenAI')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_processor_initialization_with_custom_parameters(self, mock_llm_init, mock_async_init, mock_openai_class):
|
||||
"""Test processor initialization with custom parameters"""
|
||||
# Arrange
|
||||
mock_openai_client = MagicMock()
|
||||
mock_openai_class.return_value = mock_openai_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'custom-llama',
|
||||
'llamafile': 'http://custom-host:8080/v1',
|
||||
'temperature': 0.7,
|
||||
'max_output': 2048,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
# Act
|
||||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
assert processor.model == 'custom-llama'
|
||||
assert processor.llamafile == 'http://custom-host:8080/v1'
|
||||
assert processor.temperature == 0.7
|
||||
assert processor.max_output == 2048
|
||||
mock_openai_class.assert_called_once_with(
|
||||
base_url='http://custom-host:8080/v1',
|
||||
api_key='sk-no-key-required'
|
||||
)
|
||||
|
||||
@patch('trustgraph.model.text_completion.llamafile.llm.OpenAI')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_processor_initialization_with_defaults(self, mock_llm_init, mock_async_init, mock_openai_class):
|
||||
"""Test processor initialization with default values"""
|
||||
# Arrange
|
||||
mock_openai_client = MagicMock()
|
||||
mock_openai_class.return_value = mock_openai_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
# Only provide required fields, should use defaults
|
||||
config = {
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
# Act
|
||||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
assert processor.model == 'LLaMA_CPP' # default_model
|
||||
assert processor.llamafile == 'http://localhost:8080/v1' # default_llamafile
|
||||
assert processor.temperature == 0.0 # default_temperature
|
||||
assert processor.max_output == 4096 # default_max_output
|
||||
mock_openai_class.assert_called_once_with(
|
||||
base_url='http://localhost:8080/v1',
|
||||
api_key='sk-no-key-required'
|
||||
)
|
||||
|
||||
@patch('trustgraph.model.text_completion.llamafile.llm.OpenAI')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_empty_prompts(self, mock_llm_init, mock_async_init, mock_openai_class):
|
||||
"""Test content generation with empty prompts"""
|
||||
# Arrange
|
||||
mock_openai_client = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices = [MagicMock()]
|
||||
mock_response.choices[0].message.content = "Default response"
|
||||
mock_response.usage.prompt_tokens = 2
|
||||
mock_response.usage.completion_tokens = 3
|
||||
|
||||
mock_openai_client.chat.completions.create.return_value = mock_response
|
||||
mock_openai_class.return_value = mock_openai_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'LLaMA_CPP',
|
||||
'llamafile': 'http://localhost:8080/v1',
|
||||
'temperature': 0.0,
|
||||
'max_output': 4096,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act
|
||||
result = await processor.generate_content("", "")
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Default response"
|
||||
assert result.in_token == 2
|
||||
assert result.out_token == 3
|
||||
assert result.model == 'llama.cpp'
|
||||
|
||||
# Verify the combined prompt is sent correctly
|
||||
call_args = mock_openai_client.chat.completions.create.call_args
|
||||
expected_prompt = "\n\n" # Empty system + "\n\n" + empty user
|
||||
assert call_args[1]['messages'][0]['content'] == expected_prompt
|
||||
|
||||
@patch('trustgraph.model.text_completion.llamafile.llm.OpenAI')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_message_structure(self, mock_llm_init, mock_async_init, mock_openai_class):
|
||||
"""Test that LlamaFile messages are structured correctly"""
|
||||
# Arrange
|
||||
mock_openai_client = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices = [MagicMock()]
|
||||
mock_response.choices[0].message.content = "Response with proper structure"
|
||||
mock_response.usage.prompt_tokens = 25
|
||||
mock_response.usage.completion_tokens = 15
|
||||
|
||||
mock_openai_client.chat.completions.create.return_value = mock_response
|
||||
mock_openai_class.return_value = mock_openai_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'LLaMA_CPP',
|
||||
'llamafile': 'http://localhost:8080/v1',
|
||||
'temperature': 0.0,
|
||||
'max_output': 4096,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act
|
||||
result = await processor.generate_content("You are a helpful assistant", "What is AI?")
|
||||
|
||||
# Assert
|
||||
assert result.text == "Response with proper structure"
|
||||
assert result.in_token == 25
|
||||
assert result.out_token == 15
|
||||
|
||||
# Verify the message structure
|
||||
call_args = mock_openai_client.chat.completions.create.call_args
|
||||
messages = call_args[1]['messages']
|
||||
|
||||
assert len(messages) == 1
|
||||
assert messages[0]['role'] == 'user'
|
||||
assert messages[0]['content'] == "You are a helpful assistant\n\nWhat is AI?"
|
||||
|
||||
# Verify model parameter
|
||||
assert call_args[1]['model'] == 'LLaMA_CPP'
|
||||
|
||||
@patch('trustgraph.model.text_completion.llamafile.llm.OpenAI')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_openai_client_initialization(self, mock_llm_init, mock_async_init, mock_openai_class):
|
||||
"""Test that OpenAI client is initialized correctly for LlamaFile"""
|
||||
# Arrange
|
||||
mock_openai_client = MagicMock()
|
||||
mock_openai_class.return_value = mock_openai_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'llama-custom',
|
||||
'llamafile': 'http://llamafile-server:8080/v1',
|
||||
'temperature': 0.0,
|
||||
'max_output': 4096,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
# Act
|
||||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
# Verify OpenAI client was called with correct parameters
|
||||
mock_openai_class.assert_called_once_with(
|
||||
base_url='http://llamafile-server:8080/v1',
|
||||
api_key='sk-no-key-required'
|
||||
)
|
||||
|
||||
# Verify processor has the client
|
||||
assert processor.openai == mock_openai_client
|
||||
|
||||
@patch('trustgraph.model.text_completion.llamafile.llm.OpenAI')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_prompt_construction(self, mock_llm_init, mock_async_init, mock_openai_class):
|
||||
"""Test prompt construction with system and user prompts"""
|
||||
# Arrange
|
||||
mock_openai_client = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices = [MagicMock()]
|
||||
mock_response.choices[0].message.content = "Response with system instructions"
|
||||
mock_response.usage.prompt_tokens = 30
|
||||
mock_response.usage.completion_tokens = 20
|
||||
|
||||
mock_openai_client.chat.completions.create.return_value = mock_response
|
||||
mock_openai_class.return_value = mock_openai_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'LLaMA_CPP',
|
||||
'llamafile': 'http://localhost:8080/v1',
|
||||
'temperature': 0.0,
|
||||
'max_output': 4096,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act
|
||||
result = await processor.generate_content("You are a helpful assistant", "What is machine learning?")
|
||||
|
||||
# Assert
|
||||
assert result.text == "Response with system instructions"
|
||||
assert result.in_token == 30
|
||||
assert result.out_token == 20
|
||||
|
||||
# Verify the combined prompt
|
||||
call_args = mock_openai_client.chat.completions.create.call_args
|
||||
expected_prompt = "You are a helpful assistant\n\nWhat is machine learning?"
|
||||
assert call_args[1]['messages'][0]['content'] == expected_prompt
|
||||
|
||||
@patch('trustgraph.model.text_completion.llamafile.llm.OpenAI')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_hardcoded_model_response(self, mock_llm_init, mock_async_init, mock_openai_class):
|
||||
"""Test that response model is hardcoded to 'llama.cpp'"""
|
||||
# Arrange
|
||||
mock_openai_client = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices = [MagicMock()]
|
||||
mock_response.choices[0].message.content = "Test response"
|
||||
mock_response.usage.prompt_tokens = 15
|
||||
mock_response.usage.completion_tokens = 10
|
||||
|
||||
mock_openai_client.chat.completions.create.return_value = mock_response
|
||||
mock_openai_class.return_value = mock_openai_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'custom-model-name', # This should be ignored in response
|
||||
'llamafile': 'http://localhost:8080/v1',
|
||||
'temperature': 0.0,
|
||||
'max_output': 4096,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act
|
||||
result = await processor.generate_content("System", "User")
|
||||
|
||||
# Assert
|
||||
assert result.model == 'llama.cpp' # Should always be 'llama.cpp', not 'custom-model-name'
|
||||
assert processor.model == 'custom-model-name' # But processor.model should still be custom
|
||||
|
||||
@patch('trustgraph.model.text_completion.llamafile.llm.OpenAI')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_no_rate_limiting(self, mock_llm_init, mock_async_init, mock_openai_class):
|
||||
"""Test that no rate limiting is implemented (SLM assumption)"""
|
||||
# Arrange
|
||||
mock_openai_client = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices = [MagicMock()]
|
||||
mock_response.choices[0].message.content = "No rate limiting test"
|
||||
mock_response.usage.prompt_tokens = 10
|
||||
mock_response.usage.completion_tokens = 5
|
||||
|
||||
mock_openai_client.chat.completions.create.return_value = mock_response
|
||||
mock_openai_class.return_value = mock_openai_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'LLaMA_CPP',
|
||||
'llamafile': 'http://localhost:8080/v1',
|
||||
'temperature': 0.0,
|
||||
'max_output': 4096,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act
|
||||
result = await processor.generate_content("System", "User")
|
||||
|
||||
# Assert
|
||||
assert result.text == "No rate limiting test"
|
||||
# No specific rate limit error handling tested since SLM presumably has no rate limits
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pytest.main([__file__])
|
||||
317
tests/unit/test_text_completion/test_ollama_processor.py
Normal file
317
tests/unit/test_text_completion/test_ollama_processor.py
Normal file
|
|
@ -0,0 +1,317 @@
|
|||
"""
|
||||
Unit tests for trustgraph.model.text_completion.ollama
|
||||
Following the same successful pattern as VertexAI tests
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from unittest import IsolatedAsyncioTestCase
|
||||
|
||||
# Import the service under test
|
||||
from trustgraph.model.text_completion.ollama.llm import Processor
|
||||
from trustgraph.base import LlmResult
|
||||
|
||||
|
||||
class TestOllamaProcessorSimple(IsolatedAsyncioTestCase):
|
||||
"""Test Ollama processor functionality"""
|
||||
|
||||
@patch('trustgraph.model.text_completion.ollama.llm.Client')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_processor_initialization_basic(self, mock_llm_init, mock_async_init, mock_client_class):
|
||||
"""Test basic processor initialization"""
|
||||
# Arrange
|
||||
mock_client = MagicMock()
|
||||
mock_client_class.return_value = mock_client
|
||||
|
||||
# Mock the parent class initialization
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'llama2',
|
||||
'ollama': 'http://localhost:11434',
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
# Act
|
||||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
assert processor.model == 'llama2'
|
||||
assert hasattr(processor, 'llm')
|
||||
mock_client_class.assert_called_once_with(host='http://localhost:11434')
|
||||
|
||||
@patch('trustgraph.model.text_completion.ollama.llm.Client')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_success(self, mock_llm_init, mock_async_init, mock_client_class):
|
||||
"""Test successful content generation"""
|
||||
# Arrange
|
||||
mock_client = MagicMock()
|
||||
mock_response = {
|
||||
'response': 'Generated response from Ollama',
|
||||
'prompt_eval_count': 15,
|
||||
'eval_count': 8
|
||||
}
|
||||
mock_client.generate.return_value = mock_response
|
||||
mock_client_class.return_value = mock_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'llama2',
|
||||
'ollama': 'http://localhost:11434',
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act
|
||||
result = await processor.generate_content("System prompt", "User prompt")
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Generated response from Ollama"
|
||||
assert result.in_token == 15
|
||||
assert result.out_token == 8
|
||||
assert result.model == 'llama2'
|
||||
mock_client.generate.assert_called_once_with('llama2', "System prompt\n\nUser prompt")
|
||||
|
||||
@patch('trustgraph.model.text_completion.ollama.llm.Client')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_generic_exception(self, mock_llm_init, mock_async_init, mock_client_class):
|
||||
"""Test handling of generic exceptions"""
|
||||
# Arrange
|
||||
mock_client = MagicMock()
|
||||
mock_client.generate.side_effect = Exception("Connection error")
|
||||
mock_client_class.return_value = mock_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'llama2',
|
||||
'ollama': 'http://localhost:11434',
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(Exception, match="Connection error"):
|
||||
await processor.generate_content("System prompt", "User prompt")
|
||||
|
||||
@patch('trustgraph.model.text_completion.ollama.llm.Client')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_processor_initialization_with_custom_parameters(self, mock_llm_init, mock_async_init, mock_client_class):
|
||||
"""Test processor initialization with custom parameters"""
|
||||
# Arrange
|
||||
mock_client = MagicMock()
|
||||
mock_client_class.return_value = mock_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'mistral',
|
||||
'ollama': 'http://192.168.1.100:11434',
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
# Act
|
||||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
assert processor.model == 'mistral'
|
||||
mock_client_class.assert_called_once_with(host='http://192.168.1.100:11434')
|
||||
|
||||
@patch('trustgraph.model.text_completion.ollama.llm.Client')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_processor_initialization_with_defaults(self, mock_llm_init, mock_async_init, mock_client_class):
|
||||
"""Test processor initialization with default values"""
|
||||
# Arrange
|
||||
mock_client = MagicMock()
|
||||
mock_client_class.return_value = mock_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
# Don't provide model or ollama - should use defaults
|
||||
config = {
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
# Act
|
||||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
assert processor.model == 'gemma2:9b' # default_model
|
||||
# Should use default_ollama (http://localhost:11434 or from OLLAMA_HOST env)
|
||||
mock_client_class.assert_called_once()
|
||||
|
||||
@patch('trustgraph.model.text_completion.ollama.llm.Client')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_empty_prompts(self, mock_llm_init, mock_async_init, mock_client_class):
|
||||
"""Test content generation with empty prompts"""
|
||||
# Arrange
|
||||
mock_client = MagicMock()
|
||||
mock_response = {
|
||||
'response': 'Default response',
|
||||
'prompt_eval_count': 2,
|
||||
'eval_count': 3
|
||||
}
|
||||
mock_client.generate.return_value = mock_response
|
||||
mock_client_class.return_value = mock_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'llama2',
|
||||
'ollama': 'http://localhost:11434',
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act
|
||||
result = await processor.generate_content("", "")
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Default response"
|
||||
assert result.in_token == 2
|
||||
assert result.out_token == 3
|
||||
assert result.model == 'llama2'
|
||||
|
||||
# The prompt should be "" + "\n\n" + "" = "\n\n"
|
||||
mock_client.generate.assert_called_once_with('llama2', "\n\n")
|
||||
|
||||
@patch('trustgraph.model.text_completion.ollama.llm.Client')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_token_counting(self, mock_llm_init, mock_async_init, mock_client_class):
|
||||
"""Test token counting from Ollama response"""
|
||||
# Arrange
|
||||
mock_client = MagicMock()
|
||||
mock_response = {
|
||||
'response': 'Test response',
|
||||
'prompt_eval_count': 50,
|
||||
'eval_count': 25
|
||||
}
|
||||
mock_client.generate.return_value = mock_response
|
||||
mock_client_class.return_value = mock_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'llama2',
|
||||
'ollama': 'http://localhost:11434',
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act
|
||||
result = await processor.generate_content("System", "User prompt")
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Test response"
|
||||
assert result.in_token == 50
|
||||
assert result.out_token == 25
|
||||
assert result.model == 'llama2'
|
||||
|
||||
@patch('trustgraph.model.text_completion.ollama.llm.Client')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_ollama_client_initialization(self, mock_llm_init, mock_async_init, mock_client_class):
|
||||
"""Test that Ollama client is initialized correctly"""
|
||||
# Arrange
|
||||
mock_client = MagicMock()
|
||||
mock_client_class.return_value = mock_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'codellama',
|
||||
'ollama': 'http://ollama-server:11434',
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
# Act
|
||||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
# Verify Client was called with correct host
|
||||
mock_client_class.assert_called_once_with(host='http://ollama-server:11434')
|
||||
|
||||
# Verify processor has the client
|
||||
assert processor.llm == mock_client
|
||||
|
||||
@patch('trustgraph.model.text_completion.ollama.llm.Client')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_prompt_construction(self, mock_llm_init, mock_async_init, mock_client_class):
|
||||
"""Test prompt construction with system and user prompts"""
|
||||
# Arrange
|
||||
mock_client = MagicMock()
|
||||
mock_response = {
|
||||
'response': 'Response with system instructions',
|
||||
'prompt_eval_count': 25,
|
||||
'eval_count': 15
|
||||
}
|
||||
mock_client.generate.return_value = mock_response
|
||||
mock_client_class.return_value = mock_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'llama2',
|
||||
'ollama': 'http://localhost:11434',
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act
|
||||
result = await processor.generate_content("You are a helpful assistant", "What is AI?")
|
||||
|
||||
# Assert
|
||||
assert result.text == "Response with system instructions"
|
||||
assert result.in_token == 25
|
||||
assert result.out_token == 15
|
||||
|
||||
# Verify the combined prompt
|
||||
mock_client.generate.assert_called_once_with('llama2', "You are a helpful assistant\n\nWhat is AI?")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pytest.main([__file__])
|
||||
395
tests/unit/test_text_completion/test_openai_processor.py
Normal file
395
tests/unit/test_text_completion/test_openai_processor.py
Normal file
|
|
@ -0,0 +1,395 @@
|
|||
"""
|
||||
Unit tests for trustgraph.model.text_completion.openai
|
||||
Following the same successful pattern as VertexAI and Ollama tests
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from unittest import IsolatedAsyncioTestCase
|
||||
|
||||
# Import the service under test
|
||||
from trustgraph.model.text_completion.openai.llm import Processor
|
||||
from trustgraph.base import LlmResult
|
||||
from trustgraph.exceptions import TooManyRequests
|
||||
|
||||
|
||||
class TestOpenAIProcessorSimple(IsolatedAsyncioTestCase):
|
||||
"""Test OpenAI processor functionality"""
|
||||
|
||||
@patch('trustgraph.model.text_completion.openai.llm.OpenAI')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_processor_initialization_basic(self, mock_llm_init, mock_async_init, mock_openai_class):
|
||||
"""Test basic processor initialization"""
|
||||
# Arrange
|
||||
mock_openai_client = MagicMock()
|
||||
mock_openai_class.return_value = mock_openai_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'gpt-3.5-turbo',
|
||||
'api_key': 'test-api-key',
|
||||
'url': 'https://api.openai.com/v1',
|
||||
'temperature': 0.0,
|
||||
'max_output': 4096,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
# Act
|
||||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
assert processor.model == 'gpt-3.5-turbo'
|
||||
assert processor.temperature == 0.0
|
||||
assert processor.max_output == 4096
|
||||
assert hasattr(processor, 'openai')
|
||||
mock_openai_class.assert_called_once_with(base_url='https://api.openai.com/v1', api_key='test-api-key')
|
||||
|
||||
@patch('trustgraph.model.text_completion.openai.llm.OpenAI')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_success(self, mock_llm_init, mock_async_init, mock_openai_class):
|
||||
"""Test successful content generation"""
|
||||
# Arrange
|
||||
mock_openai_client = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices = [MagicMock()]
|
||||
mock_response.choices[0].message.content = "Generated response from OpenAI"
|
||||
mock_response.usage.prompt_tokens = 20
|
||||
mock_response.usage.completion_tokens = 12
|
||||
|
||||
mock_openai_client.chat.completions.create.return_value = mock_response
|
||||
mock_openai_class.return_value = mock_openai_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'gpt-3.5-turbo',
|
||||
'api_key': 'test-api-key',
|
||||
'url': 'https://api.openai.com/v1',
|
||||
'temperature': 0.0,
|
||||
'max_output': 4096,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act
|
||||
result = await processor.generate_content("System prompt", "User prompt")
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Generated response from OpenAI"
|
||||
assert result.in_token == 20
|
||||
assert result.out_token == 12
|
||||
assert result.model == 'gpt-3.5-turbo'
|
||||
|
||||
# Verify the OpenAI API call
|
||||
mock_openai_client.chat.completions.create.assert_called_once_with(
|
||||
model='gpt-3.5-turbo',
|
||||
messages=[{
|
||||
"role": "user",
|
||||
"content": [{
|
||||
"type": "text",
|
||||
"text": "System prompt\n\nUser prompt"
|
||||
}]
|
||||
}],
|
||||
temperature=0.0,
|
||||
max_tokens=4096,
|
||||
top_p=1,
|
||||
frequency_penalty=0,
|
||||
presence_penalty=0,
|
||||
response_format={"type": "text"}
|
||||
)
|
||||
|
||||
@patch('trustgraph.model.text_completion.openai.llm.OpenAI')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_rate_limit_error(self, mock_llm_init, mock_async_init, mock_openai_class):
|
||||
"""Test rate limit error handling"""
|
||||
# Arrange
|
||||
from openai import RateLimitError
|
||||
|
||||
mock_openai_client = MagicMock()
|
||||
mock_openai_client.chat.completions.create.side_effect = RateLimitError("Rate limit exceeded", response=MagicMock(), body=None)
|
||||
mock_openai_class.return_value = mock_openai_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'gpt-3.5-turbo',
|
||||
'api_key': 'test-api-key',
|
||||
'url': 'https://api.openai.com/v1',
|
||||
'temperature': 0.0,
|
||||
'max_output': 4096,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(TooManyRequests):
|
||||
await processor.generate_content("System prompt", "User prompt")
|
||||
|
||||
@patch('trustgraph.model.text_completion.openai.llm.OpenAI')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_generic_exception(self, mock_llm_init, mock_async_init, mock_openai_class):
|
||||
"""Test handling of generic exceptions"""
|
||||
# Arrange
|
||||
mock_openai_client = MagicMock()
|
||||
mock_openai_client.chat.completions.create.side_effect = Exception("API connection error")
|
||||
mock_openai_class.return_value = mock_openai_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'gpt-3.5-turbo',
|
||||
'api_key': 'test-api-key',
|
||||
'url': 'https://api.openai.com/v1',
|
||||
'temperature': 0.0,
|
||||
'max_output': 4096,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(Exception, match="API connection error"):
|
||||
await processor.generate_content("System prompt", "User prompt")
|
||||
|
||||
@patch('trustgraph.model.text_completion.openai.llm.OpenAI')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_processor_initialization_without_api_key(self, mock_llm_init, mock_async_init, mock_openai_class):
|
||||
"""Test processor initialization without API key (should fail)"""
|
||||
# Arrange
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'gpt-3.5-turbo',
|
||||
'api_key': None, # No API key provided
|
||||
'url': 'https://api.openai.com/v1',
|
||||
'temperature': 0.0,
|
||||
'max_output': 4096,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(RuntimeError, match="OpenAI API key not specified"):
|
||||
processor = Processor(**config)
|
||||
|
||||
@patch('trustgraph.model.text_completion.openai.llm.OpenAI')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_processor_initialization_with_custom_parameters(self, mock_llm_init, mock_async_init, mock_openai_class):
|
||||
"""Test processor initialization with custom parameters"""
|
||||
# Arrange
|
||||
mock_openai_client = MagicMock()
|
||||
mock_openai_class.return_value = mock_openai_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'gpt-4',
|
||||
'api_key': 'custom-api-key',
|
||||
'url': 'https://custom-openai-url.com/v1',
|
||||
'temperature': 0.7,
|
||||
'max_output': 2048,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
# Act
|
||||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
assert processor.model == 'gpt-4'
|
||||
assert processor.temperature == 0.7
|
||||
assert processor.max_output == 2048
|
||||
mock_openai_class.assert_called_once_with(base_url='https://custom-openai-url.com/v1', api_key='custom-api-key')
|
||||
|
||||
@patch('trustgraph.model.text_completion.openai.llm.OpenAI')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_processor_initialization_with_defaults(self, mock_llm_init, mock_async_init, mock_openai_class):
|
||||
"""Test processor initialization with default values"""
|
||||
# Arrange
|
||||
mock_openai_client = MagicMock()
|
||||
mock_openai_class.return_value = mock_openai_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
# Only provide required fields, should use defaults
|
||||
config = {
|
||||
'api_key': 'test-api-key',
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
# Act
|
||||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
assert processor.model == 'gpt-3.5-turbo' # default_model
|
||||
assert processor.temperature == 0.0 # default_temperature
|
||||
assert processor.max_output == 4096 # default_max_output
|
||||
mock_openai_class.assert_called_once_with(base_url='https://api.openai.com/v1', api_key='test-api-key')
|
||||
|
||||
@patch('trustgraph.model.text_completion.openai.llm.OpenAI')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_empty_prompts(self, mock_llm_init, mock_async_init, mock_openai_class):
|
||||
"""Test content generation with empty prompts"""
|
||||
# Arrange
|
||||
mock_openai_client = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices = [MagicMock()]
|
||||
mock_response.choices[0].message.content = "Default response"
|
||||
mock_response.usage.prompt_tokens = 2
|
||||
mock_response.usage.completion_tokens = 3
|
||||
|
||||
mock_openai_client.chat.completions.create.return_value = mock_response
|
||||
mock_openai_class.return_value = mock_openai_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'gpt-3.5-turbo',
|
||||
'api_key': 'test-api-key',
|
||||
'url': 'https://api.openai.com/v1',
|
||||
'temperature': 0.0,
|
||||
'max_output': 4096,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act
|
||||
result = await processor.generate_content("", "")
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Default response"
|
||||
assert result.in_token == 2
|
||||
assert result.out_token == 3
|
||||
assert result.model == 'gpt-3.5-turbo'
|
||||
|
||||
# Verify the combined prompt is sent correctly
|
||||
call_args = mock_openai_client.chat.completions.create.call_args
|
||||
expected_prompt = "\n\n" # Empty system + "\n\n" + empty user
|
||||
assert call_args[1]['messages'][0]['content'][0]['text'] == expected_prompt
|
||||
|
||||
@patch('trustgraph.model.text_completion.openai.llm.OpenAI')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_openai_client_initialization_without_base_url(self, mock_llm_init, mock_async_init, mock_openai_class):
|
||||
"""Test OpenAI client initialization without base_url"""
|
||||
# Arrange
|
||||
mock_openai_client = MagicMock()
|
||||
mock_openai_class.return_value = mock_openai_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'gpt-3.5-turbo',
|
||||
'api_key': 'test-api-key',
|
||||
'url': None, # No base URL
|
||||
'temperature': 0.0,
|
||||
'max_output': 4096,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
# Act
|
||||
processor = Processor(**config)
|
||||
|
||||
# Assert - should be called without base_url when it's None
|
||||
mock_openai_class.assert_called_once_with(api_key='test-api-key')
|
||||
|
||||
@patch('trustgraph.model.text_completion.openai.llm.OpenAI')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_message_structure(self, mock_llm_init, mock_async_init, mock_openai_class):
|
||||
"""Test that OpenAI messages are structured correctly"""
|
||||
# Arrange
|
||||
mock_openai_client = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices = [MagicMock()]
|
||||
mock_response.choices[0].message.content = "Response with proper structure"
|
||||
mock_response.usage.prompt_tokens = 25
|
||||
mock_response.usage.completion_tokens = 15
|
||||
|
||||
mock_openai_client.chat.completions.create.return_value = mock_response
|
||||
mock_openai_class.return_value = mock_openai_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'gpt-3.5-turbo',
|
||||
'api_key': 'test-api-key',
|
||||
'url': 'https://api.openai.com/v1',
|
||||
'temperature': 0.5,
|
||||
'max_output': 1024,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act
|
||||
result = await processor.generate_content("You are a helpful assistant", "What is AI?")
|
||||
|
||||
# Assert
|
||||
assert result.text == "Response with proper structure"
|
||||
assert result.in_token == 25
|
||||
assert result.out_token == 15
|
||||
|
||||
# Verify the message structure matches OpenAI Chat API format
|
||||
call_args = mock_openai_client.chat.completions.create.call_args
|
||||
messages = call_args[1]['messages']
|
||||
|
||||
assert len(messages) == 1
|
||||
assert messages[0]['role'] == 'user'
|
||||
assert messages[0]['content'][0]['type'] == 'text'
|
||||
assert messages[0]['content'][0]['text'] == "You are a helpful assistant\n\nWhat is AI?"
|
||||
|
||||
# Verify other parameters
|
||||
assert call_args[1]['model'] == 'gpt-3.5-turbo'
|
||||
assert call_args[1]['temperature'] == 0.5
|
||||
assert call_args[1]['max_tokens'] == 1024
|
||||
assert call_args[1]['top_p'] == 1
|
||||
assert call_args[1]['frequency_penalty'] == 0
|
||||
assert call_args[1]['presence_penalty'] == 0
|
||||
assert call_args[1]['response_format'] == {"type": "text"}
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pytest.main([__file__])
|
||||
397
tests/unit/test_text_completion/test_vertexai_processor.py
Normal file
397
tests/unit/test_text_completion/test_vertexai_processor.py
Normal file
|
|
@ -0,0 +1,397 @@
|
|||
"""
|
||||
Unit tests for trustgraph.model.text_completion.vertexai
|
||||
Starting simple with one test to get the basics working
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from unittest import IsolatedAsyncioTestCase
|
||||
|
||||
# Import the service under test
|
||||
from trustgraph.model.text_completion.vertexai.llm import Processor
|
||||
from trustgraph.base import LlmResult
|
||||
|
||||
|
||||
class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase):
|
||||
"""Simple test for processor initialization"""
|
||||
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.service_account')
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.vertexai')
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.GenerativeModel')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_processor_initialization_basic(self, mock_llm_init, mock_async_init, mock_generative_model, mock_vertexai, mock_service_account):
|
||||
"""Test basic processor initialization with mocked dependencies"""
|
||||
# Arrange
|
||||
mock_credentials = MagicMock()
|
||||
mock_service_account.Credentials.from_service_account_file.return_value = mock_credentials
|
||||
mock_model = MagicMock()
|
||||
mock_generative_model.return_value = mock_model
|
||||
|
||||
# Mock the parent class initialization to avoid taskgroup requirement
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'region': 'us-central1',
|
||||
'model': 'gemini-2.0-flash-001',
|
||||
'temperature': 0.0,
|
||||
'max_output': 8192,
|
||||
'private_key': 'private.json',
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(), # Required by AsyncProcessor
|
||||
'id': 'test-processor' # Required by AsyncProcessor
|
||||
}
|
||||
|
||||
# Act
|
||||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
assert processor.model == 'gemini-2.0-flash-001' # It's stored as 'model', not 'model_name'
|
||||
assert hasattr(processor, 'generation_config')
|
||||
assert hasattr(processor, 'safety_settings')
|
||||
assert hasattr(processor, 'llm')
|
||||
mock_service_account.Credentials.from_service_account_file.assert_called_once_with('private.json')
|
||||
mock_vertexai.init.assert_called_once()
|
||||
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.service_account')
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.vertexai')
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.GenerativeModel')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_success(self, mock_llm_init, mock_async_init, mock_generative_model, mock_vertexai, mock_service_account):
|
||||
"""Test successful content generation"""
|
||||
# Arrange
|
||||
mock_credentials = MagicMock()
|
||||
mock_service_account.Credentials.from_service_account_file.return_value = mock_credentials
|
||||
|
||||
mock_model = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.text = "Generated response from Gemini"
|
||||
mock_response.usage_metadata.prompt_token_count = 15
|
||||
mock_response.usage_metadata.candidates_token_count = 8
|
||||
mock_model.generate_content.return_value = mock_response
|
||||
mock_generative_model.return_value = mock_model
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'region': 'us-central1',
|
||||
'model': 'gemini-2.0-flash-001',
|
||||
'temperature': 0.0,
|
||||
'max_output': 8192,
|
||||
'private_key': 'private.json',
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act
|
||||
result = await processor.generate_content("System prompt", "User prompt")
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Generated response from Gemini"
|
||||
assert result.in_token == 15
|
||||
assert result.out_token == 8
|
||||
assert result.model == 'gemini-2.0-flash-001'
|
||||
# Check that the method was called (actual prompt format may vary)
|
||||
mock_model.generate_content.assert_called_once()
|
||||
# Verify the call was made with the expected parameters
|
||||
call_args = mock_model.generate_content.call_args
|
||||
assert call_args[1]['generation_config'] == processor.generation_config
|
||||
assert call_args[1]['safety_settings'] == processor.safety_settings
|
||||
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.service_account')
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.vertexai')
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.GenerativeModel')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_rate_limit_error(self, mock_llm_init, mock_async_init, mock_generative_model, mock_vertexai, mock_service_account):
|
||||
"""Test rate limit error handling"""
|
||||
# Arrange
|
||||
from google.api_core.exceptions import ResourceExhausted
|
||||
from trustgraph.exceptions import TooManyRequests
|
||||
|
||||
mock_credentials = MagicMock()
|
||||
mock_service_account.Credentials.from_service_account_file.return_value = mock_credentials
|
||||
|
||||
mock_model = MagicMock()
|
||||
mock_model.generate_content.side_effect = ResourceExhausted("Rate limit exceeded")
|
||||
mock_generative_model.return_value = mock_model
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'region': 'us-central1',
|
||||
'model': 'gemini-2.0-flash-001',
|
||||
'temperature': 0.0,
|
||||
'max_output': 8192,
|
||||
'private_key': 'private.json',
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(TooManyRequests):
|
||||
await processor.generate_content("System prompt", "User prompt")
|
||||
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.service_account')
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.vertexai')
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.GenerativeModel')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_blocked_response(self, mock_llm_init, mock_async_init, mock_generative_model, mock_vertexai, mock_service_account):
|
||||
"""Test handling of blocked content (safety filters)"""
|
||||
# Arrange
|
||||
mock_credentials = MagicMock()
|
||||
mock_service_account.Credentials.from_service_account_file.return_value = mock_credentials
|
||||
|
||||
mock_model = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.text = None # Blocked content returns None
|
||||
mock_response.usage_metadata.prompt_token_count = 10
|
||||
mock_response.usage_metadata.candidates_token_count = 0
|
||||
mock_model.generate_content.return_value = mock_response
|
||||
mock_generative_model.return_value = mock_model
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'region': 'us-central1',
|
||||
'model': 'gemini-2.0-flash-001',
|
||||
'temperature': 0.0,
|
||||
'max_output': 8192,
|
||||
'private_key': 'private.json',
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act
|
||||
result = await processor.generate_content("System prompt", "Blocked content")
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text is None # Should preserve None for blocked content
|
||||
assert result.in_token == 10
|
||||
assert result.out_token == 0
|
||||
assert result.model == 'gemini-2.0-flash-001'
|
||||
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.service_account')
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.vertexai')
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.GenerativeModel')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_processor_initialization_without_private_key(self, mock_llm_init, mock_async_init, mock_generative_model, mock_vertexai, mock_service_account):
|
||||
"""Test processor initialization without private key (should fail)"""
|
||||
# Arrange
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'region': 'us-central1',
|
||||
'model': 'gemini-2.0-flash-001',
|
||||
'temperature': 0.0,
|
||||
'max_output': 8192,
|
||||
'private_key': None, # No private key provided
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(RuntimeError, match="Private key file not specified"):
|
||||
processor = Processor(**config)
|
||||
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.service_account')
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.vertexai')
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.GenerativeModel')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_generic_exception(self, mock_llm_init, mock_async_init, mock_generative_model, mock_vertexai, mock_service_account):
|
||||
"""Test handling of generic exceptions"""
|
||||
# Arrange
|
||||
mock_credentials = MagicMock()
|
||||
mock_service_account.Credentials.from_service_account_file.return_value = mock_credentials
|
||||
|
||||
mock_model = MagicMock()
|
||||
mock_model.generate_content.side_effect = Exception("Network error")
|
||||
mock_generative_model.return_value = mock_model
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'region': 'us-central1',
|
||||
'model': 'gemini-2.0-flash-001',
|
||||
'temperature': 0.0,
|
||||
'max_output': 8192,
|
||||
'private_key': 'private.json',
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(Exception, match="Network error"):
|
||||
await processor.generate_content("System prompt", "User prompt")
|
||||
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.service_account')
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.vertexai')
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.GenerativeModel')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_processor_initialization_with_custom_parameters(self, mock_llm_init, mock_async_init, mock_generative_model, mock_vertexai, mock_service_account):
|
||||
"""Test processor initialization with custom parameters"""
|
||||
# Arrange
|
||||
mock_credentials = MagicMock()
|
||||
mock_service_account.Credentials.from_service_account_file.return_value = mock_credentials
|
||||
mock_model = MagicMock()
|
||||
mock_generative_model.return_value = mock_model
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'region': 'us-west1',
|
||||
'model': 'gemini-1.5-pro',
|
||||
'temperature': 0.7,
|
||||
'max_output': 4096,
|
||||
'private_key': 'custom-key.json',
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
# Act
|
||||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
assert processor.model == 'gemini-1.5-pro'
|
||||
|
||||
# Verify that generation_config object exists (can't easily check internal values)
|
||||
assert hasattr(processor, 'generation_config')
|
||||
assert processor.generation_config is not None
|
||||
|
||||
# Verify that safety settings are configured
|
||||
assert len(processor.safety_settings) == 4
|
||||
|
||||
# Verify service account was called with custom key
|
||||
mock_service_account.Credentials.from_service_account_file.assert_called_once_with('custom-key.json')
|
||||
|
||||
# Verify that parameters dict has the correct values (this is accessible)
|
||||
assert processor.parameters["temperature"] == 0.7
|
||||
assert processor.parameters["max_output_tokens"] == 4096
|
||||
assert processor.parameters["top_p"] == 1.0
|
||||
assert processor.parameters["top_k"] == 32
|
||||
assert processor.parameters["candidate_count"] == 1
|
||||
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.service_account')
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.vertexai')
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.GenerativeModel')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_vertexai_initialization_with_credentials(self, mock_llm_init, mock_async_init, mock_generative_model, mock_vertexai, mock_service_account):
|
||||
"""Test that VertexAI is initialized correctly with credentials"""
|
||||
# Arrange
|
||||
mock_credentials = MagicMock()
|
||||
mock_credentials.project_id = "test-project-123"
|
||||
mock_service_account.Credentials.from_service_account_file.return_value = mock_credentials
|
||||
mock_model = MagicMock()
|
||||
mock_generative_model.return_value = mock_model
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'region': 'europe-west1',
|
||||
'model': 'gemini-2.0-flash-001',
|
||||
'temperature': 0.0,
|
||||
'max_output': 8192,
|
||||
'private_key': 'service-account.json',
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
# Act
|
||||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
# Verify VertexAI init was called with correct parameters
|
||||
mock_vertexai.init.assert_called_once_with(
|
||||
location='europe-west1',
|
||||
credentials=mock_credentials,
|
||||
project='test-project-123'
|
||||
)
|
||||
|
||||
# Verify GenerativeModel was created with the right model name
|
||||
mock_generative_model.assert_called_once_with('gemini-2.0-flash-001')
|
||||
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.service_account')
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.vertexai')
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.GenerativeModel')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_empty_prompts(self, mock_llm_init, mock_async_init, mock_generative_model, mock_vertexai, mock_service_account):
|
||||
"""Test content generation with empty prompts"""
|
||||
# Arrange
|
||||
mock_credentials = MagicMock()
|
||||
mock_service_account.Credentials.from_service_account_file.return_value = mock_credentials
|
||||
|
||||
mock_model = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.text = "Default response"
|
||||
mock_response.usage_metadata.prompt_token_count = 2
|
||||
mock_response.usage_metadata.candidates_token_count = 3
|
||||
mock_model.generate_content.return_value = mock_response
|
||||
mock_generative_model.return_value = mock_model
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'region': 'us-central1',
|
||||
'model': 'gemini-2.0-flash-001',
|
||||
'temperature': 0.0,
|
||||
'max_output': 8192,
|
||||
'private_key': 'private.json',
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act
|
||||
result = await processor.generate_content("", "")
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Default response"
|
||||
assert result.in_token == 2
|
||||
assert result.out_token == 3
|
||||
assert result.model == 'gemini-2.0-flash-001'
|
||||
|
||||
# Verify the model was called with the combined empty prompts
|
||||
mock_model.generate_content.assert_called_once()
|
||||
call_args = mock_model.generate_content.call_args
|
||||
# The prompt should be "" + "\n\n" + "" = "\n\n"
|
||||
assert call_args[0][0] == "\n\n"
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pytest.main([__file__])
|
||||
489
tests/unit/test_text_completion/test_vllm_processor.py
Normal file
489
tests/unit/test_text_completion/test_vllm_processor.py
Normal file
|
|
@ -0,0 +1,489 @@
|
|||
"""
|
||||
Unit tests for trustgraph.model.text_completion.vllm
|
||||
Following the same successful pattern as previous tests
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from unittest import IsolatedAsyncioTestCase
|
||||
|
||||
# Import the service under test
|
||||
from trustgraph.model.text_completion.vllm.llm import Processor
|
||||
from trustgraph.base import LlmResult
|
||||
from trustgraph.exceptions import TooManyRequests
|
||||
|
||||
|
||||
class TestVLLMProcessorSimple(IsolatedAsyncioTestCase):
|
||||
"""Test vLLM processor functionality"""
|
||||
|
||||
@patch('trustgraph.model.text_completion.vllm.llm.aiohttp.ClientSession')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_processor_initialization_basic(self, mock_llm_init, mock_async_init, mock_session_class):
|
||||
"""Test basic processor initialization"""
|
||||
# Arrange
|
||||
mock_session = MagicMock()
|
||||
mock_session_class.return_value = mock_session
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'TheBloke/Mistral-7B-v0.1-AWQ',
|
||||
'url': 'http://vllm-service:8899/v1',
|
||||
'temperature': 0.0,
|
||||
'max_output': 2048,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
# Act
|
||||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
assert processor.model == 'TheBloke/Mistral-7B-v0.1-AWQ'
|
||||
assert processor.base_url == 'http://vllm-service:8899/v1'
|
||||
assert processor.temperature == 0.0
|
||||
assert processor.max_output == 2048
|
||||
assert hasattr(processor, 'session')
|
||||
mock_session_class.assert_called_once()
|
||||
|
||||
@patch('trustgraph.model.text_completion.vllm.llm.aiohttp.ClientSession')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_success(self, mock_llm_init, mock_async_init, mock_session_class):
|
||||
"""Test successful content generation"""
|
||||
# Arrange
|
||||
mock_session = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.status = 200
|
||||
mock_response.json = AsyncMock(return_value={
|
||||
'choices': [{
|
||||
'text': 'Generated response from vLLM'
|
||||
}],
|
||||
'usage': {
|
||||
'prompt_tokens': 20,
|
||||
'completion_tokens': 12
|
||||
}
|
||||
})
|
||||
|
||||
# Mock the async context manager
|
||||
mock_session.post.return_value.__aenter__.return_value = mock_response
|
||||
mock_session.post.return_value.__aexit__.return_value = None
|
||||
mock_session_class.return_value = mock_session
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'TheBloke/Mistral-7B-v0.1-AWQ',
|
||||
'url': 'http://vllm-service:8899/v1',
|
||||
'temperature': 0.0,
|
||||
'max_output': 2048,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act
|
||||
result = await processor.generate_content("System prompt", "User prompt")
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Generated response from vLLM"
|
||||
assert result.in_token == 20
|
||||
assert result.out_token == 12
|
||||
assert result.model == 'TheBloke/Mistral-7B-v0.1-AWQ'
|
||||
|
||||
# Verify the vLLM API call
|
||||
mock_session.post.assert_called_once_with(
|
||||
'http://vllm-service:8899/v1/completions',
|
||||
headers={'Content-Type': 'application/json'},
|
||||
json={
|
||||
'model': 'TheBloke/Mistral-7B-v0.1-AWQ',
|
||||
'prompt': 'System prompt\n\nUser prompt',
|
||||
'max_tokens': 2048,
|
||||
'temperature': 0.0
|
||||
}
|
||||
)
|
||||
|
||||
@patch('trustgraph.model.text_completion.vllm.llm.aiohttp.ClientSession')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_http_error(self, mock_llm_init, mock_async_init, mock_session_class):
|
||||
"""Test HTTP error handling"""
|
||||
# Arrange
|
||||
mock_session = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.status = 500
|
||||
|
||||
# Mock the async context manager
|
||||
mock_session.post.return_value.__aenter__.return_value = mock_response
|
||||
mock_session.post.return_value.__aexit__.return_value = None
|
||||
mock_session_class.return_value = mock_session
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'TheBloke/Mistral-7B-v0.1-AWQ',
|
||||
'url': 'http://vllm-service:8899/v1',
|
||||
'temperature': 0.0,
|
||||
'max_output': 2048,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(RuntimeError, match="Bad status: 500"):
|
||||
await processor.generate_content("System prompt", "User prompt")
|
||||
|
||||
@patch('trustgraph.model.text_completion.vllm.llm.aiohttp.ClientSession')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_generic_exception(self, mock_llm_init, mock_async_init, mock_session_class):
|
||||
"""Test handling of generic exceptions"""
|
||||
# Arrange
|
||||
mock_session = MagicMock()
|
||||
mock_session.post.side_effect = Exception("Connection error")
|
||||
mock_session_class.return_value = mock_session
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'TheBloke/Mistral-7B-v0.1-AWQ',
|
||||
'url': 'http://vllm-service:8899/v1',
|
||||
'temperature': 0.0,
|
||||
'max_output': 2048,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(Exception, match="Connection error"):
|
||||
await processor.generate_content("System prompt", "User prompt")
|
||||
|
||||
@patch('trustgraph.model.text_completion.vllm.llm.aiohttp.ClientSession')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_processor_initialization_with_custom_parameters(self, mock_llm_init, mock_async_init, mock_session_class):
|
||||
"""Test processor initialization with custom parameters"""
|
||||
# Arrange
|
||||
mock_session = MagicMock()
|
||||
mock_session_class.return_value = mock_session
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'custom-model',
|
||||
'url': 'http://custom-vllm:8080/v1',
|
||||
'temperature': 0.7,
|
||||
'max_output': 1024,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
# Act
|
||||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
assert processor.model == 'custom-model'
|
||||
assert processor.base_url == 'http://custom-vllm:8080/v1'
|
||||
assert processor.temperature == 0.7
|
||||
assert processor.max_output == 1024
|
||||
mock_session_class.assert_called_once()
|
||||
|
||||
@patch('trustgraph.model.text_completion.vllm.llm.aiohttp.ClientSession')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_processor_initialization_with_defaults(self, mock_llm_init, mock_async_init, mock_session_class):
|
||||
"""Test processor initialization with default values"""
|
||||
# Arrange
|
||||
mock_session = MagicMock()
|
||||
mock_session_class.return_value = mock_session
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
# Only provide required fields, should use defaults
|
||||
config = {
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
# Act
|
||||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
assert processor.model == 'TheBloke/Mistral-7B-v0.1-AWQ' # default_model
|
||||
assert processor.base_url == 'http://vllm-service:8899/v1' # default_base_url
|
||||
assert processor.temperature == 0.0 # default_temperature
|
||||
assert processor.max_output == 2048 # default_max_output
|
||||
mock_session_class.assert_called_once()
|
||||
|
||||
@patch('trustgraph.model.text_completion.vllm.llm.aiohttp.ClientSession')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_empty_prompts(self, mock_llm_init, mock_async_init, mock_session_class):
|
||||
"""Test content generation with empty prompts"""
|
||||
# Arrange
|
||||
mock_session = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.status = 200
|
||||
mock_response.json = AsyncMock(return_value={
|
||||
'choices': [{
|
||||
'text': 'Default response'
|
||||
}],
|
||||
'usage': {
|
||||
'prompt_tokens': 2,
|
||||
'completion_tokens': 3
|
||||
}
|
||||
})
|
||||
|
||||
# Mock the async context manager
|
||||
mock_session.post.return_value.__aenter__.return_value = mock_response
|
||||
mock_session.post.return_value.__aexit__.return_value = None
|
||||
mock_session_class.return_value = mock_session
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'TheBloke/Mistral-7B-v0.1-AWQ',
|
||||
'url': 'http://vllm-service:8899/v1',
|
||||
'temperature': 0.0,
|
||||
'max_output': 2048,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act
|
||||
result = await processor.generate_content("", "")
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Default response"
|
||||
assert result.in_token == 2
|
||||
assert result.out_token == 3
|
||||
assert result.model == 'TheBloke/Mistral-7B-v0.1-AWQ'
|
||||
|
||||
# Verify the combined prompt is sent correctly
|
||||
call_args = mock_session.post.call_args
|
||||
expected_prompt = "\n\n" # Empty system + "\n\n" + empty user
|
||||
assert call_args[1]['json']['prompt'] == expected_prompt
|
||||
|
||||
@patch('trustgraph.model.text_completion.vllm.llm.aiohttp.ClientSession')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_request_structure(self, mock_llm_init, mock_async_init, mock_session_class):
|
||||
"""Test that vLLM request is structured correctly"""
|
||||
# Arrange
|
||||
mock_session = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.status = 200
|
||||
mock_response.json = AsyncMock(return_value={
|
||||
'choices': [{
|
||||
'text': 'Response with proper structure'
|
||||
}],
|
||||
'usage': {
|
||||
'prompt_tokens': 25,
|
||||
'completion_tokens': 15
|
||||
}
|
||||
})
|
||||
|
||||
# Mock the async context manager
|
||||
mock_session.post.return_value.__aenter__.return_value = mock_response
|
||||
mock_session.post.return_value.__aexit__.return_value = None
|
||||
mock_session_class.return_value = mock_session
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'TheBloke/Mistral-7B-v0.1-AWQ',
|
||||
'url': 'http://vllm-service:8899/v1',
|
||||
'temperature': 0.5,
|
||||
'max_output': 1024,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act
|
||||
result = await processor.generate_content("You are a helpful assistant", "What is AI?")
|
||||
|
||||
# Assert
|
||||
assert result.text == "Response with proper structure"
|
||||
assert result.in_token == 25
|
||||
assert result.out_token == 15
|
||||
|
||||
# Verify the request structure
|
||||
call_args = mock_session.post.call_args
|
||||
|
||||
# Check URL
|
||||
assert call_args[0][0] == 'http://vllm-service:8899/v1/completions'
|
||||
|
||||
# Check headers
|
||||
assert call_args[1]['headers']['Content-Type'] == 'application/json'
|
||||
|
||||
# Check request body
|
||||
request_data = call_args[1]['json']
|
||||
assert request_data['model'] == 'TheBloke/Mistral-7B-v0.1-AWQ'
|
||||
assert request_data['prompt'] == "You are a helpful assistant\n\nWhat is AI?"
|
||||
assert request_data['temperature'] == 0.5
|
||||
assert request_data['max_tokens'] == 1024
|
||||
|
||||
@patch('trustgraph.model.text_completion.vllm.llm.aiohttp.ClientSession')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_vllm_session_initialization(self, mock_llm_init, mock_async_init, mock_session_class):
|
||||
"""Test that aiohttp session is initialized correctly"""
|
||||
# Arrange
|
||||
mock_session = MagicMock()
|
||||
mock_session_class.return_value = mock_session
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'test-model',
|
||||
'url': 'http://test-vllm:8899/v1',
|
||||
'temperature': 0.0,
|
||||
'max_output': 2048,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
# Act
|
||||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
# Verify ClientSession was created
|
||||
mock_session_class.assert_called_once()
|
||||
|
||||
# Verify processor has the session
|
||||
assert processor.session == mock_session
|
||||
|
||||
@patch('trustgraph.model.text_completion.vllm.llm.aiohttp.ClientSession')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_response_parsing(self, mock_llm_init, mock_async_init, mock_session_class):
|
||||
"""Test response parsing from vLLM API"""
|
||||
# Arrange
|
||||
mock_session = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.status = 200
|
||||
mock_response.json = AsyncMock(return_value={
|
||||
'choices': [{
|
||||
'text': 'Parsed response text'
|
||||
}],
|
||||
'usage': {
|
||||
'prompt_tokens': 35,
|
||||
'completion_tokens': 25
|
||||
}
|
||||
})
|
||||
|
||||
# Mock the async context manager
|
||||
mock_session.post.return_value.__aenter__.return_value = mock_response
|
||||
mock_session.post.return_value.__aexit__.return_value = None
|
||||
mock_session_class.return_value = mock_session
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'TheBloke/Mistral-7B-v0.1-AWQ',
|
||||
'url': 'http://vllm-service:8899/v1',
|
||||
'temperature': 0.0,
|
||||
'max_output': 2048,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act
|
||||
result = await processor.generate_content("System", "User query")
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Parsed response text"
|
||||
assert result.in_token == 35
|
||||
assert result.out_token == 25
|
||||
assert result.model == 'TheBloke/Mistral-7B-v0.1-AWQ'
|
||||
|
||||
@patch('trustgraph.model.text_completion.vllm.llm.aiohttp.ClientSession')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_prompt_construction(self, mock_llm_init, mock_async_init, mock_session_class):
|
||||
"""Test prompt construction with system and user prompts"""
|
||||
# Arrange
|
||||
mock_session = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.status = 200
|
||||
mock_response.json = AsyncMock(return_value={
|
||||
'choices': [{
|
||||
'text': 'Response with system instructions'
|
||||
}],
|
||||
'usage': {
|
||||
'prompt_tokens': 40,
|
||||
'completion_tokens': 30
|
||||
}
|
||||
})
|
||||
|
||||
# Mock the async context manager
|
||||
mock_session.post.return_value.__aenter__.return_value = mock_response
|
||||
mock_session.post.return_value.__aexit__.return_value = None
|
||||
mock_session_class.return_value = mock_session
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'TheBloke/Mistral-7B-v0.1-AWQ',
|
||||
'url': 'http://vllm-service:8899/v1',
|
||||
'temperature': 0.0,
|
||||
'max_output': 2048,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act
|
||||
result = await processor.generate_content("You are a helpful assistant", "What is machine learning?")
|
||||
|
||||
# Assert
|
||||
assert result.text == "Response with system instructions"
|
||||
assert result.in_token == 40
|
||||
assert result.out_token == 30
|
||||
|
||||
# Verify the combined prompt
|
||||
call_args = mock_session.post.call_args
|
||||
expected_prompt = "You are a helpful assistant\n\nWhat is machine learning?"
|
||||
assert call_args[1]['json']['prompt'] == expected_prompt
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pytest.main([__file__])
|
||||
Loading…
Add table
Add a link
Reference in a new issue