mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-04-25 00:16:23 +02:00
Test suite executed from CI pipeline (#433)
* Test strategy & test cases * Unit tests * Integration tests
This commit is contained in:
parent
9c7a070681
commit
2f7fddd206
101 changed files with 17811 additions and 1 deletions
3
tests/unit/__init__.py
Normal file
3
tests/unit/__init__.py
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
"""
|
||||
Unit tests for TrustGraph services
|
||||
"""
|
||||
58
tests/unit/test_base/test_async_processor.py
Normal file
58
tests/unit/test_base/test_async_processor.py
Normal file
|
|
@ -0,0 +1,58 @@
|
|||
"""
|
||||
Unit tests for trustgraph.base.async_processor
|
||||
Starting small with a single test to verify basic functionality
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from unittest import IsolatedAsyncioTestCase
|
||||
|
||||
# Import the service under test
|
||||
from trustgraph.base.async_processor import AsyncProcessor
|
||||
|
||||
|
||||
class TestAsyncProcessorSimple(IsolatedAsyncioTestCase):
|
||||
"""Test AsyncProcessor base class functionality"""
|
||||
|
||||
@patch('trustgraph.base.async_processor.PulsarClient')
|
||||
@patch('trustgraph.base.async_processor.Consumer')
|
||||
@patch('trustgraph.base.async_processor.ProcessorMetrics')
|
||||
@patch('trustgraph.base.async_processor.ConsumerMetrics')
|
||||
async def test_async_processor_initialization_basic(self, mock_consumer_metrics, mock_processor_metrics,
|
||||
mock_consumer, mock_pulsar_client):
|
||||
"""Test basic AsyncProcessor initialization"""
|
||||
# Arrange
|
||||
mock_pulsar_client.return_value = MagicMock()
|
||||
mock_consumer.return_value = MagicMock()
|
||||
mock_processor_metrics.return_value = MagicMock()
|
||||
mock_consumer_metrics.return_value = MagicMock()
|
||||
|
||||
config = {
|
||||
'id': 'test-async-processor',
|
||||
'taskgroup': AsyncMock()
|
||||
}
|
||||
|
||||
# Act
|
||||
processor = AsyncProcessor(**config)
|
||||
|
||||
# Assert
|
||||
# Verify basic attributes are set
|
||||
assert processor.id == 'test-async-processor'
|
||||
assert processor.taskgroup == config['taskgroup']
|
||||
assert processor.running == True
|
||||
assert hasattr(processor, 'config_handlers')
|
||||
assert processor.config_handlers == []
|
||||
|
||||
# Verify PulsarClient was created
|
||||
mock_pulsar_client.assert_called_once_with(**config)
|
||||
|
||||
# Verify metrics were initialized
|
||||
mock_processor_metrics.assert_called_once()
|
||||
mock_consumer_metrics.assert_called_once()
|
||||
|
||||
# Verify Consumer was created for config subscription
|
||||
mock_consumer.assert_called_once()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pytest.main([__file__])
|
||||
347
tests/unit/test_base/test_flow_processor.py
Normal file
347
tests/unit/test_base/test_flow_processor.py
Normal file
|
|
@ -0,0 +1,347 @@
|
|||
"""
|
||||
Unit tests for trustgraph.base.flow_processor
|
||||
Starting small with a single test to verify basic functionality
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from unittest import IsolatedAsyncioTestCase
|
||||
|
||||
# Import the service under test
|
||||
from trustgraph.base.flow_processor import FlowProcessor
|
||||
|
||||
|
||||
class TestFlowProcessorSimple(IsolatedAsyncioTestCase):
|
||||
"""Test FlowProcessor base class functionality"""
|
||||
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.register_config_handler')
|
||||
async def test_flow_processor_initialization_basic(self, mock_register_config, mock_async_init):
|
||||
"""Test basic FlowProcessor initialization"""
|
||||
# Arrange
|
||||
mock_async_init.return_value = None
|
||||
mock_register_config.return_value = None
|
||||
|
||||
config = {
|
||||
'id': 'test-flow-processor',
|
||||
'taskgroup': AsyncMock()
|
||||
}
|
||||
|
||||
# Act
|
||||
processor = FlowProcessor(**config)
|
||||
|
||||
# Assert
|
||||
# Verify AsyncProcessor.__init__ was called
|
||||
mock_async_init.assert_called_once()
|
||||
|
||||
# Verify register_config_handler was called with the correct handler
|
||||
mock_register_config.assert_called_once_with(processor.on_configure_flows)
|
||||
|
||||
# Verify FlowProcessor-specific initialization
|
||||
assert hasattr(processor, 'flows')
|
||||
assert processor.flows == {}
|
||||
assert hasattr(processor, 'specifications')
|
||||
assert processor.specifications == []
|
||||
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.register_config_handler')
|
||||
async def test_register_specification(self, mock_register_config, mock_async_init):
|
||||
"""Test registering a specification"""
|
||||
# Arrange
|
||||
mock_async_init.return_value = None
|
||||
mock_register_config.return_value = None
|
||||
|
||||
config = {
|
||||
'id': 'test-flow-processor',
|
||||
'taskgroup': AsyncMock()
|
||||
}
|
||||
|
||||
processor = FlowProcessor(**config)
|
||||
mock_spec = MagicMock()
|
||||
mock_spec.name = 'test-spec'
|
||||
|
||||
# Act
|
||||
processor.register_specification(mock_spec)
|
||||
|
||||
# Assert
|
||||
assert len(processor.specifications) == 1
|
||||
assert processor.specifications[0] == mock_spec
|
||||
|
||||
@patch('trustgraph.base.flow_processor.Flow')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.register_config_handler')
|
||||
async def test_start_flow(self, mock_register_config, mock_async_init, mock_flow_class):
|
||||
"""Test starting a flow"""
|
||||
# Arrange
|
||||
mock_async_init.return_value = None
|
||||
mock_register_config.return_value = None
|
||||
|
||||
config = {
|
||||
'id': 'test-flow-processor',
|
||||
'taskgroup': AsyncMock()
|
||||
}
|
||||
|
||||
processor = FlowProcessor(**config)
|
||||
processor.id = 'test-processor' # Set id for Flow creation
|
||||
|
||||
mock_flow = AsyncMock()
|
||||
mock_flow_class.return_value = mock_flow
|
||||
|
||||
flow_name = 'test-flow'
|
||||
flow_defn = {'config': 'test-config'}
|
||||
|
||||
# Act
|
||||
await processor.start_flow(flow_name, flow_defn)
|
||||
|
||||
# Assert
|
||||
assert flow_name in processor.flows
|
||||
# Verify Flow was created with correct parameters
|
||||
mock_flow_class.assert_called_once_with('test-processor', flow_name, processor, flow_defn)
|
||||
# Verify the flow's start method was called
|
||||
mock_flow.start.assert_called_once()
|
||||
|
||||
@patch('trustgraph.base.flow_processor.Flow')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.register_config_handler')
|
||||
async def test_stop_flow(self, mock_register_config, mock_async_init, mock_flow_class):
|
||||
"""Test stopping a flow"""
|
||||
# Arrange
|
||||
mock_async_init.return_value = None
|
||||
mock_register_config.return_value = None
|
||||
|
||||
config = {
|
||||
'id': 'test-flow-processor',
|
||||
'taskgroup': AsyncMock()
|
||||
}
|
||||
|
||||
processor = FlowProcessor(**config)
|
||||
processor.id = 'test-processor'
|
||||
|
||||
mock_flow = AsyncMock()
|
||||
mock_flow_class.return_value = mock_flow
|
||||
|
||||
flow_name = 'test-flow'
|
||||
flow_defn = {'config': 'test-config'}
|
||||
|
||||
# Start a flow first
|
||||
await processor.start_flow(flow_name, flow_defn)
|
||||
|
||||
# Act
|
||||
await processor.stop_flow(flow_name)
|
||||
|
||||
# Assert
|
||||
assert flow_name not in processor.flows
|
||||
mock_flow.stop.assert_called_once()
|
||||
|
||||
@patch('trustgraph.base.flow_processor.Flow')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.register_config_handler')
|
||||
async def test_stop_flow_not_exists(self, mock_register_config, mock_async_init, mock_flow_class):
|
||||
"""Test stopping a flow that doesn't exist"""
|
||||
# Arrange
|
||||
mock_async_init.return_value = None
|
||||
mock_register_config.return_value = None
|
||||
|
||||
config = {
|
||||
'id': 'test-flow-processor',
|
||||
'taskgroup': AsyncMock()
|
||||
}
|
||||
|
||||
processor = FlowProcessor(**config)
|
||||
|
||||
# Act - should not raise an exception
|
||||
await processor.stop_flow('non-existent-flow')
|
||||
|
||||
# Assert - flows dict should still be empty
|
||||
assert processor.flows == {}
|
||||
|
||||
@patch('trustgraph.base.flow_processor.Flow')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.register_config_handler')
|
||||
async def test_on_configure_flows_basic(self, mock_register_config, mock_async_init, mock_flow_class):
|
||||
"""Test basic flow configuration handling"""
|
||||
# Arrange
|
||||
mock_async_init.return_value = None
|
||||
mock_register_config.return_value = None
|
||||
|
||||
config = {
|
||||
'id': 'test-flow-processor',
|
||||
'taskgroup': AsyncMock()
|
||||
}
|
||||
|
||||
processor = FlowProcessor(**config)
|
||||
processor.id = 'test-processor'
|
||||
|
||||
mock_flow = AsyncMock()
|
||||
mock_flow_class.return_value = mock_flow
|
||||
|
||||
# Configuration with flows for this processor
|
||||
flow_config = {
|
||||
'test-flow': {'config': 'test-config'}
|
||||
}
|
||||
config_data = {
|
||||
'flows-active': {
|
||||
'test-processor': '{"test-flow": {"config": "test-config"}}'
|
||||
}
|
||||
}
|
||||
|
||||
# Act
|
||||
await processor.on_configure_flows(config_data, version=1)
|
||||
|
||||
# Assert
|
||||
assert 'test-flow' in processor.flows
|
||||
mock_flow_class.assert_called_once_with('test-processor', 'test-flow', processor, {'config': 'test-config'})
|
||||
mock_flow.start.assert_called_once()
|
||||
|
||||
@patch('trustgraph.base.flow_processor.Flow')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.register_config_handler')
|
||||
async def test_on_configure_flows_no_config(self, mock_register_config, mock_async_init, mock_flow_class):
|
||||
"""Test flow configuration handling when no config exists for this processor"""
|
||||
# Arrange
|
||||
mock_async_init.return_value = None
|
||||
mock_register_config.return_value = None
|
||||
|
||||
config = {
|
||||
'id': 'test-flow-processor',
|
||||
'taskgroup': AsyncMock()
|
||||
}
|
||||
|
||||
processor = FlowProcessor(**config)
|
||||
processor.id = 'test-processor'
|
||||
|
||||
# Configuration without flows for this processor
|
||||
config_data = {
|
||||
'flows-active': {
|
||||
'other-processor': '{"other-flow": {"config": "other-config"}}'
|
||||
}
|
||||
}
|
||||
|
||||
# Act
|
||||
await processor.on_configure_flows(config_data, version=1)
|
||||
|
||||
# Assert
|
||||
assert processor.flows == {}
|
||||
mock_flow_class.assert_not_called()
|
||||
|
||||
@patch('trustgraph.base.flow_processor.Flow')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.register_config_handler')
|
||||
async def test_on_configure_flows_invalid_config(self, mock_register_config, mock_async_init, mock_flow_class):
|
||||
"""Test flow configuration handling with invalid config format"""
|
||||
# Arrange
|
||||
mock_async_init.return_value = None
|
||||
mock_register_config.return_value = None
|
||||
|
||||
config = {
|
||||
'id': 'test-flow-processor',
|
||||
'taskgroup': AsyncMock()
|
||||
}
|
||||
|
||||
processor = FlowProcessor(**config)
|
||||
processor.id = 'test-processor'
|
||||
|
||||
# Configuration without flows-active key
|
||||
config_data = {
|
||||
'other-data': 'some-value'
|
||||
}
|
||||
|
||||
# Act
|
||||
await processor.on_configure_flows(config_data, version=1)
|
||||
|
||||
# Assert
|
||||
assert processor.flows == {}
|
||||
mock_flow_class.assert_not_called()
|
||||
|
||||
@patch('trustgraph.base.flow_processor.Flow')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.register_config_handler')
|
||||
async def test_on_configure_flows_start_and_stop(self, mock_register_config, mock_async_init, mock_flow_class):
|
||||
"""Test flow configuration handling with starting and stopping flows"""
|
||||
# Arrange
|
||||
mock_async_init.return_value = None
|
||||
mock_register_config.return_value = None
|
||||
|
||||
config = {
|
||||
'id': 'test-flow-processor',
|
||||
'taskgroup': AsyncMock()
|
||||
}
|
||||
|
||||
processor = FlowProcessor(**config)
|
||||
processor.id = 'test-processor'
|
||||
|
||||
mock_flow1 = AsyncMock()
|
||||
mock_flow2 = AsyncMock()
|
||||
mock_flow_class.side_effect = [mock_flow1, mock_flow2]
|
||||
|
||||
# First configuration - start flow1
|
||||
config_data1 = {
|
||||
'flows-active': {
|
||||
'test-processor': '{"flow1": {"config": "config1"}}'
|
||||
}
|
||||
}
|
||||
|
||||
await processor.on_configure_flows(config_data1, version=1)
|
||||
|
||||
# Second configuration - stop flow1, start flow2
|
||||
config_data2 = {
|
||||
'flows-active': {
|
||||
'test-processor': '{"flow2": {"config": "config2"}}'
|
||||
}
|
||||
}
|
||||
|
||||
# Act
|
||||
await processor.on_configure_flows(config_data2, version=2)
|
||||
|
||||
# Assert
|
||||
# flow1 should be stopped and removed
|
||||
assert 'flow1' not in processor.flows
|
||||
mock_flow1.stop.assert_called_once()
|
||||
|
||||
# flow2 should be started and added
|
||||
assert 'flow2' in processor.flows
|
||||
mock_flow2.start.assert_called_once()
|
||||
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.register_config_handler')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.start')
|
||||
async def test_start_calls_parent(self, mock_parent_start, mock_register_config, mock_async_init):
|
||||
"""Test that start() calls parent start method"""
|
||||
# Arrange
|
||||
mock_async_init.return_value = None
|
||||
mock_register_config.return_value = None
|
||||
mock_parent_start.return_value = None
|
||||
|
||||
config = {
|
||||
'id': 'test-flow-processor',
|
||||
'taskgroup': AsyncMock()
|
||||
}
|
||||
|
||||
processor = FlowProcessor(**config)
|
||||
|
||||
# Act
|
||||
await processor.start()
|
||||
|
||||
# Assert
|
||||
mock_parent_start.assert_called_once()
|
||||
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.register_config_handler')
|
||||
async def test_add_args_calls_parent(self, mock_register_config, mock_async_init):
|
||||
"""Test that add_args() calls parent add_args method"""
|
||||
# Arrange
|
||||
mock_async_init.return_value = None
|
||||
mock_register_config.return_value = None
|
||||
|
||||
mock_parser = MagicMock()
|
||||
|
||||
# Act
|
||||
with patch('trustgraph.base.async_processor.AsyncProcessor.add_args') as mock_parent_add_args:
|
||||
FlowProcessor.add_args(mock_parser)
|
||||
|
||||
# Assert
|
||||
mock_parent_add_args.assert_called_once_with(mock_parser)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pytest.main([__file__])
|
||||
69
tests/unit/test_gateway/test_auth.py
Normal file
69
tests/unit/test_gateway/test_auth.py
Normal file
|
|
@ -0,0 +1,69 @@
|
|||
"""
|
||||
Tests for Gateway Authentication
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from trustgraph.gateway.auth import Authenticator
|
||||
|
||||
|
||||
class TestAuthenticator:
|
||||
"""Test cases for Authenticator class"""
|
||||
|
||||
def test_authenticator_initialization_with_token(self):
|
||||
"""Test Authenticator initialization with valid token"""
|
||||
auth = Authenticator(token="test-token-123")
|
||||
|
||||
assert auth.token == "test-token-123"
|
||||
assert auth.allow_all is False
|
||||
|
||||
def test_authenticator_initialization_with_allow_all(self):
|
||||
"""Test Authenticator initialization with allow_all=True"""
|
||||
auth = Authenticator(allow_all=True)
|
||||
|
||||
assert auth.token is None
|
||||
assert auth.allow_all is True
|
||||
|
||||
def test_authenticator_initialization_without_token_raises_error(self):
|
||||
"""Test Authenticator initialization without token raises RuntimeError"""
|
||||
with pytest.raises(RuntimeError, match="Need a token"):
|
||||
Authenticator()
|
||||
|
||||
def test_authenticator_initialization_with_empty_token_raises_error(self):
|
||||
"""Test Authenticator initialization with empty token raises RuntimeError"""
|
||||
with pytest.raises(RuntimeError, match="Need a token"):
|
||||
Authenticator(token="")
|
||||
|
||||
def test_permitted_with_allow_all_returns_true(self):
|
||||
"""Test permitted method returns True when allow_all is enabled"""
|
||||
auth = Authenticator(allow_all=True)
|
||||
|
||||
# Should return True regardless of token or roles
|
||||
assert auth.permitted("any-token", []) is True
|
||||
assert auth.permitted("different-token", ["admin"]) is True
|
||||
assert auth.permitted(None, ["user"]) is True
|
||||
|
||||
def test_permitted_with_matching_token_returns_true(self):
|
||||
"""Test permitted method returns True with matching token"""
|
||||
auth = Authenticator(token="secret-token")
|
||||
|
||||
# Should return True when tokens match
|
||||
assert auth.permitted("secret-token", []) is True
|
||||
assert auth.permitted("secret-token", ["admin", "user"]) is True
|
||||
|
||||
def test_permitted_with_non_matching_token_returns_false(self):
|
||||
"""Test permitted method returns False with non-matching token"""
|
||||
auth = Authenticator(token="secret-token")
|
||||
|
||||
# Should return False when tokens don't match
|
||||
assert auth.permitted("wrong-token", []) is False
|
||||
assert auth.permitted("different-token", ["admin"]) is False
|
||||
assert auth.permitted(None, ["user"]) is False
|
||||
|
||||
def test_permitted_with_token_and_allow_all_returns_true(self):
|
||||
"""Test permitted method with both token and allow_all set"""
|
||||
auth = Authenticator(token="test-token", allow_all=True)
|
||||
|
||||
# allow_all should take precedence
|
||||
assert auth.permitted("any-token", []) is True
|
||||
assert auth.permitted("wrong-token", ["admin"]) is True
|
||||
408
tests/unit/test_gateway/test_config_receiver.py
Normal file
408
tests/unit/test_gateway/test_config_receiver.py
Normal file
|
|
@ -0,0 +1,408 @@
|
|||
"""
|
||||
Tests for Gateway Config Receiver
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import asyncio
|
||||
import json
|
||||
from unittest.mock import Mock, patch, Mock, MagicMock
|
||||
import uuid
|
||||
|
||||
from trustgraph.gateway.config.receiver import ConfigReceiver
|
||||
|
||||
# Save the real method before patching
|
||||
_real_config_loader = ConfigReceiver.config_loader
|
||||
|
||||
# Patch async methods at module level to prevent coroutine warnings
|
||||
ConfigReceiver.config_loader = Mock()
|
||||
|
||||
|
||||
class TestConfigReceiver:
|
||||
"""Test cases for ConfigReceiver class"""
|
||||
|
||||
def test_config_receiver_initialization(self):
|
||||
"""Test ConfigReceiver initialization"""
|
||||
mock_pulsar_client = Mock()
|
||||
|
||||
config_receiver = ConfigReceiver(mock_pulsar_client)
|
||||
|
||||
assert config_receiver.pulsar_client == mock_pulsar_client
|
||||
assert config_receiver.flow_handlers == []
|
||||
assert config_receiver.flows == {}
|
||||
|
||||
def test_add_handler(self):
|
||||
"""Test adding flow handlers"""
|
||||
mock_pulsar_client = Mock()
|
||||
config_receiver = ConfigReceiver(mock_pulsar_client)
|
||||
|
||||
handler1 = Mock()
|
||||
handler2 = Mock()
|
||||
|
||||
config_receiver.add_handler(handler1)
|
||||
config_receiver.add_handler(handler2)
|
||||
|
||||
assert len(config_receiver.flow_handlers) == 2
|
||||
assert handler1 in config_receiver.flow_handlers
|
||||
assert handler2 in config_receiver.flow_handlers
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_config_with_new_flows(self):
|
||||
"""Test on_config method with new flows"""
|
||||
mock_pulsar_client = Mock()
|
||||
config_receiver = ConfigReceiver(mock_pulsar_client)
|
||||
|
||||
# Track calls manually instead of using AsyncMock
|
||||
start_flow_calls = []
|
||||
|
||||
async def mock_start_flow(*args):
|
||||
start_flow_calls.append(args)
|
||||
|
||||
config_receiver.start_flow = mock_start_flow
|
||||
|
||||
# Create mock message with flows
|
||||
mock_msg = Mock()
|
||||
mock_msg.value.return_value = Mock(
|
||||
version="1.0",
|
||||
config={
|
||||
"flows": {
|
||||
"flow1": '{"name": "test_flow_1", "steps": []}',
|
||||
"flow2": '{"name": "test_flow_2", "steps": []}'
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
await config_receiver.on_config(mock_msg, None, None)
|
||||
|
||||
# Verify flows were added
|
||||
assert "flow1" in config_receiver.flows
|
||||
assert "flow2" in config_receiver.flows
|
||||
assert config_receiver.flows["flow1"] == {"name": "test_flow_1", "steps": []}
|
||||
assert config_receiver.flows["flow2"] == {"name": "test_flow_2", "steps": []}
|
||||
|
||||
# Verify start_flow was called for each new flow
|
||||
assert len(start_flow_calls) == 2
|
||||
assert ("flow1", {"name": "test_flow_1", "steps": []}) in start_flow_calls
|
||||
assert ("flow2", {"name": "test_flow_2", "steps": []}) in start_flow_calls
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_config_with_removed_flows(self):
|
||||
"""Test on_config method with removed flows"""
|
||||
mock_pulsar_client = Mock()
|
||||
config_receiver = ConfigReceiver(mock_pulsar_client)
|
||||
|
||||
# Pre-populate with existing flows
|
||||
config_receiver.flows = {
|
||||
"flow1": {"name": "test_flow_1", "steps": []},
|
||||
"flow2": {"name": "test_flow_2", "steps": []}
|
||||
}
|
||||
|
||||
# Track calls manually instead of using AsyncMock
|
||||
stop_flow_calls = []
|
||||
|
||||
async def mock_stop_flow(*args):
|
||||
stop_flow_calls.append(args)
|
||||
|
||||
config_receiver.stop_flow = mock_stop_flow
|
||||
|
||||
# Create mock message with only flow1 (flow2 removed)
|
||||
mock_msg = Mock()
|
||||
mock_msg.value.return_value = Mock(
|
||||
version="1.0",
|
||||
config={
|
||||
"flows": {
|
||||
"flow1": '{"name": "test_flow_1", "steps": []}'
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
await config_receiver.on_config(mock_msg, None, None)
|
||||
|
||||
# Verify flow2 was removed
|
||||
assert "flow1" in config_receiver.flows
|
||||
assert "flow2" not in config_receiver.flows
|
||||
|
||||
# Verify stop_flow was called for removed flow
|
||||
assert len(stop_flow_calls) == 1
|
||||
assert stop_flow_calls[0] == ("flow2", {"name": "test_flow_2", "steps": []})
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_config_with_no_flows(self):
|
||||
"""Test on_config method with no flows in config"""
|
||||
mock_pulsar_client = Mock()
|
||||
config_receiver = ConfigReceiver(mock_pulsar_client)
|
||||
|
||||
# Mock the start_flow and stop_flow methods with async functions
|
||||
async def mock_start_flow(*args):
|
||||
pass
|
||||
async def mock_stop_flow(*args):
|
||||
pass
|
||||
config_receiver.start_flow = mock_start_flow
|
||||
config_receiver.stop_flow = mock_stop_flow
|
||||
|
||||
# Create mock message without flows
|
||||
mock_msg = Mock()
|
||||
mock_msg.value.return_value = Mock(
|
||||
version="1.0",
|
||||
config={}
|
||||
)
|
||||
|
||||
await config_receiver.on_config(mock_msg, None, None)
|
||||
|
||||
# Verify no flows were added
|
||||
assert config_receiver.flows == {}
|
||||
|
||||
# Since no flows were in the config, the flow methods shouldn't be called
|
||||
# (We can't easily assert this with simple async functions, but the test
|
||||
# passes if no exceptions are thrown)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_config_exception_handling(self):
|
||||
"""Test on_config method handles exceptions gracefully"""
|
||||
mock_pulsar_client = Mock()
|
||||
config_receiver = ConfigReceiver(mock_pulsar_client)
|
||||
|
||||
# Create mock message that will cause an exception
|
||||
mock_msg = Mock()
|
||||
mock_msg.value.side_effect = Exception("Test exception")
|
||||
|
||||
# This should not raise an exception
|
||||
await config_receiver.on_config(mock_msg, None, None)
|
||||
|
||||
# Verify flows remain empty
|
||||
assert config_receiver.flows == {}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_flow_with_handlers(self):
|
||||
"""Test start_flow method with multiple handlers"""
|
||||
mock_pulsar_client = Mock()
|
||||
config_receiver = ConfigReceiver(mock_pulsar_client)
|
||||
|
||||
# Add mock handlers
|
||||
handler1 = Mock()
|
||||
handler1.start_flow = Mock()
|
||||
handler2 = Mock()
|
||||
handler2.start_flow = Mock()
|
||||
|
||||
config_receiver.add_handler(handler1)
|
||||
config_receiver.add_handler(handler2)
|
||||
|
||||
flow_data = {"name": "test_flow", "steps": []}
|
||||
|
||||
await config_receiver.start_flow("flow1", flow_data)
|
||||
|
||||
# Verify all handlers were called
|
||||
handler1.start_flow.assert_called_once_with("flow1", flow_data)
|
||||
handler2.start_flow.assert_called_once_with("flow1", flow_data)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_flow_with_handler_exception(self):
|
||||
"""Test start_flow method handles handler exceptions"""
|
||||
mock_pulsar_client = Mock()
|
||||
config_receiver = ConfigReceiver(mock_pulsar_client)
|
||||
|
||||
# Add mock handler that raises exception
|
||||
handler = Mock()
|
||||
handler.start_flow = Mock(side_effect=Exception("Handler error"))
|
||||
|
||||
config_receiver.add_handler(handler)
|
||||
|
||||
flow_data = {"name": "test_flow", "steps": []}
|
||||
|
||||
# This should not raise an exception
|
||||
await config_receiver.start_flow("flow1", flow_data)
|
||||
|
||||
# Verify handler was called
|
||||
handler.start_flow.assert_called_once_with("flow1", flow_data)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stop_flow_with_handlers(self):
|
||||
"""Test stop_flow method with multiple handlers"""
|
||||
mock_pulsar_client = Mock()
|
||||
config_receiver = ConfigReceiver(mock_pulsar_client)
|
||||
|
||||
# Add mock handlers
|
||||
handler1 = Mock()
|
||||
handler1.stop_flow = Mock()
|
||||
handler2 = Mock()
|
||||
handler2.stop_flow = Mock()
|
||||
|
||||
config_receiver.add_handler(handler1)
|
||||
config_receiver.add_handler(handler2)
|
||||
|
||||
flow_data = {"name": "test_flow", "steps": []}
|
||||
|
||||
await config_receiver.stop_flow("flow1", flow_data)
|
||||
|
||||
# Verify all handlers were called
|
||||
handler1.stop_flow.assert_called_once_with("flow1", flow_data)
|
||||
handler2.stop_flow.assert_called_once_with("flow1", flow_data)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stop_flow_with_handler_exception(self):
|
||||
"""Test stop_flow method handles handler exceptions"""
|
||||
mock_pulsar_client = Mock()
|
||||
config_receiver = ConfigReceiver(mock_pulsar_client)
|
||||
|
||||
# Add mock handler that raises exception
|
||||
handler = Mock()
|
||||
handler.stop_flow = Mock(side_effect=Exception("Handler error"))
|
||||
|
||||
config_receiver.add_handler(handler)
|
||||
|
||||
flow_data = {"name": "test_flow", "steps": []}
|
||||
|
||||
# This should not raise an exception
|
||||
await config_receiver.stop_flow("flow1", flow_data)
|
||||
|
||||
# Verify handler was called
|
||||
handler.stop_flow.assert_called_once_with("flow1", flow_data)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_config_loader_creates_consumer(self):
|
||||
"""Test config_loader method creates Pulsar consumer"""
|
||||
mock_pulsar_client = Mock()
|
||||
|
||||
config_receiver = ConfigReceiver(mock_pulsar_client)
|
||||
# Temporarily restore the real config_loader for this test
|
||||
config_receiver.config_loader = _real_config_loader.__get__(config_receiver)
|
||||
|
||||
# Mock Consumer class
|
||||
with patch('trustgraph.gateway.config.receiver.Consumer') as mock_consumer_class, \
|
||||
patch('uuid.uuid4') as mock_uuid:
|
||||
|
||||
mock_uuid.return_value = "test-uuid"
|
||||
mock_consumer = Mock()
|
||||
async def mock_start():
|
||||
pass
|
||||
mock_consumer.start = mock_start
|
||||
mock_consumer_class.return_value = mock_consumer
|
||||
|
||||
# Create a task that will complete quickly
|
||||
async def quick_task():
|
||||
await config_receiver.config_loader()
|
||||
|
||||
# Run the task with a timeout to prevent hanging
|
||||
try:
|
||||
await asyncio.wait_for(quick_task(), timeout=0.1)
|
||||
except asyncio.TimeoutError:
|
||||
# This is expected since the method runs indefinitely
|
||||
pass
|
||||
|
||||
# Verify Consumer was created with correct parameters
|
||||
mock_consumer_class.assert_called_once()
|
||||
call_args = mock_consumer_class.call_args
|
||||
|
||||
assert call_args[1]['client'] == mock_pulsar_client
|
||||
assert call_args[1]['subscriber'] == "gateway-test-uuid"
|
||||
assert call_args[1]['handler'] == config_receiver.on_config
|
||||
assert call_args[1]['start_of_messages'] is True
|
||||
|
||||
@patch('asyncio.create_task')
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_creates_config_loader_task(self, mock_create_task):
|
||||
"""Test start method creates config loader task"""
|
||||
mock_pulsar_client = Mock()
|
||||
config_receiver = ConfigReceiver(mock_pulsar_client)
|
||||
|
||||
# Mock create_task to avoid actually creating tasks with real coroutines
|
||||
mock_task = Mock()
|
||||
mock_create_task.return_value = mock_task
|
||||
|
||||
await config_receiver.start()
|
||||
|
||||
# Verify task was created
|
||||
mock_create_task.assert_called_once()
|
||||
|
||||
# Verify the argument passed to create_task is a coroutine
|
||||
call_args = mock_create_task.call_args[0]
|
||||
assert len(call_args) == 1 # Should have one argument (the coroutine)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_config_mixed_flow_operations(self):
|
||||
"""Test on_config with mixed add/remove operations"""
|
||||
mock_pulsar_client = Mock()
|
||||
config_receiver = ConfigReceiver(mock_pulsar_client)
|
||||
|
||||
# Pre-populate with existing flows
|
||||
config_receiver.flows = {
|
||||
"flow1": {"name": "test_flow_1", "steps": []},
|
||||
"flow2": {"name": "test_flow_2", "steps": []}
|
||||
}
|
||||
|
||||
# Track calls manually instead of using Mock
|
||||
start_flow_calls = []
|
||||
stop_flow_calls = []
|
||||
|
||||
async def mock_start_flow(*args):
|
||||
start_flow_calls.append(args)
|
||||
|
||||
async def mock_stop_flow(*args):
|
||||
stop_flow_calls.append(args)
|
||||
|
||||
# Directly assign to avoid patch.object detecting async methods
|
||||
original_start_flow = config_receiver.start_flow
|
||||
original_stop_flow = config_receiver.stop_flow
|
||||
config_receiver.start_flow = mock_start_flow
|
||||
config_receiver.stop_flow = mock_stop_flow
|
||||
|
||||
try:
|
||||
|
||||
# Create mock message with flow1 removed and flow3 added
|
||||
mock_msg = Mock()
|
||||
mock_msg.value.return_value = Mock(
|
||||
version="1.0",
|
||||
config={
|
||||
"flows": {
|
||||
"flow2": '{"name": "test_flow_2", "steps": []}',
|
||||
"flow3": '{"name": "test_flow_3", "steps": []}'
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
await config_receiver.on_config(mock_msg, None, None)
|
||||
|
||||
# Verify final state
|
||||
assert "flow1" not in config_receiver.flows
|
||||
assert "flow2" in config_receiver.flows
|
||||
assert "flow3" in config_receiver.flows
|
||||
|
||||
# Verify operations
|
||||
assert len(start_flow_calls) == 1
|
||||
assert start_flow_calls[0] == ("flow3", {"name": "test_flow_3", "steps": []})
|
||||
assert len(stop_flow_calls) == 1
|
||||
assert stop_flow_calls[0] == ("flow1", {"name": "test_flow_1", "steps": []})
|
||||
|
||||
finally:
|
||||
# Restore original methods
|
||||
config_receiver.start_flow = original_start_flow
|
||||
config_receiver.stop_flow = original_stop_flow
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_config_invalid_json_flow_data(self):
|
||||
"""Test on_config handles invalid JSON in flow data"""
|
||||
mock_pulsar_client = Mock()
|
||||
config_receiver = ConfigReceiver(mock_pulsar_client)
|
||||
|
||||
# Mock the start_flow method with an async function
|
||||
async def mock_start_flow(*args):
|
||||
pass
|
||||
config_receiver.start_flow = mock_start_flow
|
||||
|
||||
# Create mock message with invalid JSON
|
||||
mock_msg = Mock()
|
||||
mock_msg.value.return_value = Mock(
|
||||
version="1.0",
|
||||
config={
|
||||
"flows": {
|
||||
"flow1": '{"invalid": json}', # Invalid JSON
|
||||
"flow2": '{"name": "valid_flow", "steps": []}' # Valid JSON
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
# This should handle the exception gracefully
|
||||
await config_receiver.on_config(mock_msg, None, None)
|
||||
|
||||
# The entire operation should fail due to JSON parsing error
|
||||
# So no flows should be added
|
||||
assert config_receiver.flows == {}
|
||||
93
tests/unit/test_gateway/test_dispatch_config.py
Normal file
93
tests/unit/test_gateway/test_dispatch_config.py
Normal file
|
|
@ -0,0 +1,93 @@
|
|||
"""
|
||||
Tests for Gateway Config Dispatch
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch, AsyncMock, Mock
|
||||
|
||||
from trustgraph.gateway.dispatch.config import ConfigRequestor
|
||||
|
||||
# Import parent class for local patching
|
||||
from trustgraph.gateway.dispatch.requestor import ServiceRequestor
|
||||
|
||||
|
||||
class TestConfigRequestor:
|
||||
"""Test cases for ConfigRequestor class"""
|
||||
|
||||
@patch('trustgraph.gateway.dispatch.config.TranslatorRegistry')
|
||||
def test_config_requestor_initialization(self, mock_translator_registry):
|
||||
"""Test ConfigRequestor initialization"""
|
||||
# Mock translators
|
||||
mock_request_translator = Mock()
|
||||
mock_response_translator = Mock()
|
||||
mock_translator_registry.get_request_translator.return_value = mock_request_translator
|
||||
mock_translator_registry.get_response_translator.return_value = mock_response_translator
|
||||
|
||||
# Mock dependencies
|
||||
mock_pulsar_client = Mock()
|
||||
|
||||
requestor = ConfigRequestor(
|
||||
pulsar_client=mock_pulsar_client,
|
||||
consumer="test-consumer",
|
||||
subscriber="test-subscriber",
|
||||
timeout=60
|
||||
)
|
||||
|
||||
# Verify translator setup
|
||||
mock_translator_registry.get_request_translator.assert_called_once_with("config")
|
||||
mock_translator_registry.get_response_translator.assert_called_once_with("config")
|
||||
|
||||
assert requestor.request_translator == mock_request_translator
|
||||
assert requestor.response_translator == mock_response_translator
|
||||
|
||||
@patch('trustgraph.gateway.dispatch.config.TranslatorRegistry')
|
||||
def test_config_requestor_to_request(self, mock_translator_registry):
|
||||
"""Test ConfigRequestor to_request method"""
|
||||
# Mock translators
|
||||
mock_request_translator = Mock()
|
||||
mock_translator_registry.get_request_translator.return_value = mock_request_translator
|
||||
mock_translator_registry.get_response_translator.return_value = Mock()
|
||||
|
||||
# Setup translator response
|
||||
mock_request_translator.to_pulsar.return_value = "translated_request"
|
||||
|
||||
# Patch ServiceRequestor async methods with regular mocks (not AsyncMock)
|
||||
with patch.object(ServiceRequestor, 'start', return_value=None), \
|
||||
patch.object(ServiceRequestor, 'process', return_value=None):
|
||||
requestor = ConfigRequestor(
|
||||
pulsar_client=Mock(),
|
||||
consumer="test-consumer",
|
||||
subscriber="test-subscriber"
|
||||
)
|
||||
|
||||
# Call to_request
|
||||
result = requestor.to_request({"test": "body"})
|
||||
|
||||
# Verify translator was called correctly
|
||||
mock_request_translator.to_pulsar.assert_called_once_with({"test": "body"})
|
||||
assert result == "translated_request"
|
||||
|
||||
@patch('trustgraph.gateway.dispatch.config.TranslatorRegistry')
|
||||
def test_config_requestor_from_response(self, mock_translator_registry):
|
||||
"""Test ConfigRequestor from_response method"""
|
||||
# Mock translators
|
||||
mock_response_translator = Mock()
|
||||
mock_translator_registry.get_request_translator.return_value = Mock()
|
||||
mock_translator_registry.get_response_translator.return_value = mock_response_translator
|
||||
|
||||
# Setup translator response
|
||||
mock_response_translator.from_response_with_completion.return_value = "translated_response"
|
||||
|
||||
requestor = ConfigRequestor(
|
||||
pulsar_client=Mock(),
|
||||
consumer="test-consumer",
|
||||
subscriber="test-subscriber"
|
||||
)
|
||||
|
||||
# Call from_response
|
||||
mock_message = Mock()
|
||||
result = requestor.from_response(mock_message)
|
||||
|
||||
# Verify translator was called correctly
|
||||
mock_response_translator.from_response_with_completion.assert_called_once_with(mock_message)
|
||||
assert result == "translated_response"
|
||||
558
tests/unit/test_gateway/test_dispatch_manager.py
Normal file
558
tests/unit/test_gateway/test_dispatch_manager.py
Normal file
|
|
@ -0,0 +1,558 @@
|
|||
"""
|
||||
Tests for Gateway Dispatcher Manager
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import asyncio
|
||||
from unittest.mock import Mock, patch, AsyncMock, MagicMock
|
||||
import uuid
|
||||
|
||||
from trustgraph.gateway.dispatch.manager import DispatcherManager, DispatcherWrapper
|
||||
|
||||
# Keep the real methods intact for proper testing
|
||||
|
||||
|
||||
class TestDispatcherWrapper:
|
||||
"""Test cases for DispatcherWrapper class"""
|
||||
|
||||
def test_dispatcher_wrapper_initialization(self):
|
||||
"""Test DispatcherWrapper initialization"""
|
||||
mock_handler = Mock()
|
||||
wrapper = DispatcherWrapper(mock_handler)
|
||||
|
||||
assert wrapper.handler == mock_handler
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dispatcher_wrapper_process(self):
|
||||
"""Test DispatcherWrapper process method"""
|
||||
mock_handler = AsyncMock()
|
||||
wrapper = DispatcherWrapper(mock_handler)
|
||||
|
||||
result = await wrapper.process("arg1", "arg2")
|
||||
|
||||
mock_handler.assert_called_once_with("arg1", "arg2")
|
||||
assert result == mock_handler.return_value
|
||||
|
||||
|
||||
class TestDispatcherManager:
|
||||
"""Test cases for DispatcherManager class"""
|
||||
|
||||
def test_dispatcher_manager_initialization(self):
|
||||
"""Test DispatcherManager initialization"""
|
||||
mock_pulsar_client = Mock()
|
||||
mock_config_receiver = Mock()
|
||||
|
||||
manager = DispatcherManager(mock_pulsar_client, mock_config_receiver)
|
||||
|
||||
assert manager.pulsar_client == mock_pulsar_client
|
||||
assert manager.config_receiver == mock_config_receiver
|
||||
assert manager.prefix == "api-gateway" # default prefix
|
||||
assert manager.flows == {}
|
||||
assert manager.dispatchers == {}
|
||||
|
||||
# Verify manager was added as handler to config receiver
|
||||
mock_config_receiver.add_handler.assert_called_once_with(manager)
|
||||
|
||||
def test_dispatcher_manager_initialization_with_custom_prefix(self):
|
||||
"""Test DispatcherManager initialization with custom prefix"""
|
||||
mock_pulsar_client = Mock()
|
||||
mock_config_receiver = Mock()
|
||||
|
||||
manager = DispatcherManager(mock_pulsar_client, mock_config_receiver, prefix="custom-prefix")
|
||||
|
||||
assert manager.prefix == "custom-prefix"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_flow(self):
|
||||
"""Test start_flow method"""
|
||||
mock_pulsar_client = Mock()
|
||||
mock_config_receiver = Mock()
|
||||
manager = DispatcherManager(mock_pulsar_client, mock_config_receiver)
|
||||
|
||||
flow_data = {"name": "test_flow", "steps": []}
|
||||
|
||||
await manager.start_flow("flow1", flow_data)
|
||||
|
||||
assert "flow1" in manager.flows
|
||||
assert manager.flows["flow1"] == flow_data
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stop_flow(self):
|
||||
"""Test stop_flow method"""
|
||||
mock_pulsar_client = Mock()
|
||||
mock_config_receiver = Mock()
|
||||
manager = DispatcherManager(mock_pulsar_client, mock_config_receiver)
|
||||
|
||||
# Pre-populate with a flow
|
||||
flow_data = {"name": "test_flow", "steps": []}
|
||||
manager.flows["flow1"] = flow_data
|
||||
|
||||
await manager.stop_flow("flow1", flow_data)
|
||||
|
||||
assert "flow1" not in manager.flows
|
||||
|
||||
def test_dispatch_global_service_returns_wrapper(self):
|
||||
"""Test dispatch_global_service returns DispatcherWrapper"""
|
||||
mock_pulsar_client = Mock()
|
||||
mock_config_receiver = Mock()
|
||||
manager = DispatcherManager(mock_pulsar_client, mock_config_receiver)
|
||||
|
||||
wrapper = manager.dispatch_global_service()
|
||||
|
||||
assert isinstance(wrapper, DispatcherWrapper)
|
||||
assert wrapper.handler == manager.process_global_service
|
||||
|
||||
def test_dispatch_core_export_returns_wrapper(self):
|
||||
"""Test dispatch_core_export returns DispatcherWrapper"""
|
||||
mock_pulsar_client = Mock()
|
||||
mock_config_receiver = Mock()
|
||||
manager = DispatcherManager(mock_pulsar_client, mock_config_receiver)
|
||||
|
||||
wrapper = manager.dispatch_core_export()
|
||||
|
||||
assert isinstance(wrapper, DispatcherWrapper)
|
||||
assert wrapper.handler == manager.process_core_export
|
||||
|
||||
def test_dispatch_core_import_returns_wrapper(self):
|
||||
"""Test dispatch_core_import returns DispatcherWrapper"""
|
||||
mock_pulsar_client = Mock()
|
||||
mock_config_receiver = Mock()
|
||||
manager = DispatcherManager(mock_pulsar_client, mock_config_receiver)
|
||||
|
||||
wrapper = manager.dispatch_core_import()
|
||||
|
||||
assert isinstance(wrapper, DispatcherWrapper)
|
||||
assert wrapper.handler == manager.process_core_import
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_core_import(self):
|
||||
"""Test process_core_import method"""
|
||||
mock_pulsar_client = Mock()
|
||||
mock_config_receiver = Mock()
|
||||
manager = DispatcherManager(mock_pulsar_client, mock_config_receiver)
|
||||
|
||||
with patch('trustgraph.gateway.dispatch.manager.CoreImport') as mock_core_import:
|
||||
mock_importer = Mock()
|
||||
mock_importer.process = AsyncMock(return_value="import_result")
|
||||
mock_core_import.return_value = mock_importer
|
||||
|
||||
result = await manager.process_core_import("data", "error", "ok", "request")
|
||||
|
||||
mock_core_import.assert_called_once_with(mock_pulsar_client)
|
||||
mock_importer.process.assert_called_once_with("data", "error", "ok", "request")
|
||||
assert result == "import_result"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_core_export(self):
|
||||
"""Test process_core_export method"""
|
||||
mock_pulsar_client = Mock()
|
||||
mock_config_receiver = Mock()
|
||||
manager = DispatcherManager(mock_pulsar_client, mock_config_receiver)
|
||||
|
||||
with patch('trustgraph.gateway.dispatch.manager.CoreExport') as mock_core_export:
|
||||
mock_exporter = Mock()
|
||||
mock_exporter.process = AsyncMock(return_value="export_result")
|
||||
mock_core_export.return_value = mock_exporter
|
||||
|
||||
result = await manager.process_core_export("data", "error", "ok", "request")
|
||||
|
||||
mock_core_export.assert_called_once_with(mock_pulsar_client)
|
||||
mock_exporter.process.assert_called_once_with("data", "error", "ok", "request")
|
||||
assert result == "export_result"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_global_service(self):
|
||||
"""Test process_global_service method"""
|
||||
mock_pulsar_client = Mock()
|
||||
mock_config_receiver = Mock()
|
||||
manager = DispatcherManager(mock_pulsar_client, mock_config_receiver)
|
||||
|
||||
manager.invoke_global_service = AsyncMock(return_value="global_result")
|
||||
|
||||
params = {"kind": "test_kind"}
|
||||
result = await manager.process_global_service("data", "responder", params)
|
||||
|
||||
manager.invoke_global_service.assert_called_once_with("data", "responder", "test_kind")
|
||||
assert result == "global_result"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invoke_global_service_with_existing_dispatcher(self):
|
||||
"""Test invoke_global_service with existing dispatcher"""
|
||||
mock_pulsar_client = Mock()
|
||||
mock_config_receiver = Mock()
|
||||
manager = DispatcherManager(mock_pulsar_client, mock_config_receiver)
|
||||
|
||||
# Pre-populate with existing dispatcher
|
||||
mock_dispatcher = Mock()
|
||||
mock_dispatcher.process = AsyncMock(return_value="cached_result")
|
||||
manager.dispatchers[(None, "config")] = mock_dispatcher
|
||||
|
||||
result = await manager.invoke_global_service("data", "responder", "config")
|
||||
|
||||
mock_dispatcher.process.assert_called_once_with("data", "responder")
|
||||
assert result == "cached_result"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invoke_global_service_creates_new_dispatcher(self):
|
||||
"""Test invoke_global_service creates new dispatcher"""
|
||||
mock_pulsar_client = Mock()
|
||||
mock_config_receiver = Mock()
|
||||
manager = DispatcherManager(mock_pulsar_client, mock_config_receiver)
|
||||
|
||||
with patch('trustgraph.gateway.dispatch.manager.global_dispatchers') as mock_dispatchers:
|
||||
mock_dispatcher_class = Mock()
|
||||
mock_dispatcher = Mock()
|
||||
mock_dispatcher.start = AsyncMock()
|
||||
mock_dispatcher.process = AsyncMock(return_value="new_result")
|
||||
mock_dispatcher_class.return_value = mock_dispatcher
|
||||
mock_dispatchers.__getitem__.return_value = mock_dispatcher_class
|
||||
|
||||
result = await manager.invoke_global_service("data", "responder", "config")
|
||||
|
||||
# Verify dispatcher was created with correct parameters
|
||||
mock_dispatcher_class.assert_called_once_with(
|
||||
pulsar_client=mock_pulsar_client,
|
||||
timeout=120,
|
||||
consumer="api-gateway-config-request",
|
||||
subscriber="api-gateway-config-request"
|
||||
)
|
||||
mock_dispatcher.start.assert_called_once()
|
||||
mock_dispatcher.process.assert_called_once_with("data", "responder")
|
||||
|
||||
# Verify dispatcher was cached
|
||||
assert manager.dispatchers[(None, "config")] == mock_dispatcher
|
||||
assert result == "new_result"
|
||||
|
||||
def test_dispatch_flow_import_returns_method(self):
|
||||
"""Test dispatch_flow_import returns correct method"""
|
||||
mock_pulsar_client = Mock()
|
||||
mock_config_receiver = Mock()
|
||||
manager = DispatcherManager(mock_pulsar_client, mock_config_receiver)
|
||||
|
||||
result = manager.dispatch_flow_import()
|
||||
|
||||
assert result == manager.process_flow_import
|
||||
|
||||
def test_dispatch_flow_export_returns_method(self):
|
||||
"""Test dispatch_flow_export returns correct method"""
|
||||
mock_pulsar_client = Mock()
|
||||
mock_config_receiver = Mock()
|
||||
manager = DispatcherManager(mock_pulsar_client, mock_config_receiver)
|
||||
|
||||
result = manager.dispatch_flow_export()
|
||||
|
||||
assert result == manager.process_flow_export
|
||||
|
||||
def test_dispatch_socket_returns_method(self):
|
||||
"""Test dispatch_socket returns correct method"""
|
||||
mock_pulsar_client = Mock()
|
||||
mock_config_receiver = Mock()
|
||||
manager = DispatcherManager(mock_pulsar_client, mock_config_receiver)
|
||||
|
||||
result = manager.dispatch_socket()
|
||||
|
||||
assert result == manager.process_socket
|
||||
|
||||
def test_dispatch_flow_service_returns_wrapper(self):
|
||||
"""Test dispatch_flow_service returns DispatcherWrapper"""
|
||||
mock_pulsar_client = Mock()
|
||||
mock_config_receiver = Mock()
|
||||
manager = DispatcherManager(mock_pulsar_client, mock_config_receiver)
|
||||
|
||||
wrapper = manager.dispatch_flow_service()
|
||||
|
||||
assert isinstance(wrapper, DispatcherWrapper)
|
||||
assert wrapper.handler == manager.process_flow_service
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_flow_import_with_valid_flow_and_kind(self):
|
||||
"""Test process_flow_import with valid flow and kind"""
|
||||
mock_pulsar_client = Mock()
|
||||
mock_config_receiver = Mock()
|
||||
manager = DispatcherManager(mock_pulsar_client, mock_config_receiver)
|
||||
|
||||
# Setup test flow
|
||||
manager.flows["test_flow"] = {
|
||||
"interfaces": {
|
||||
"triples-store": {"queue": "test_queue"}
|
||||
}
|
||||
}
|
||||
|
||||
with patch('trustgraph.gateway.dispatch.manager.import_dispatchers') as mock_dispatchers, \
|
||||
patch('uuid.uuid4') as mock_uuid:
|
||||
mock_uuid.return_value = "test-uuid"
|
||||
mock_dispatcher_class = Mock()
|
||||
mock_dispatcher = Mock()
|
||||
mock_dispatcher.start = AsyncMock()
|
||||
mock_dispatcher_class.return_value = mock_dispatcher
|
||||
mock_dispatchers.__getitem__.return_value = mock_dispatcher_class
|
||||
mock_dispatchers.__contains__.return_value = True
|
||||
|
||||
params = {"flow": "test_flow", "kind": "triples"}
|
||||
result = await manager.process_flow_import("ws", "running", params)
|
||||
|
||||
mock_dispatcher_class.assert_called_once_with(
|
||||
pulsar_client=mock_pulsar_client,
|
||||
ws="ws",
|
||||
running="running",
|
||||
queue={"queue": "test_queue"}
|
||||
)
|
||||
mock_dispatcher.start.assert_called_once()
|
||||
assert result == mock_dispatcher
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_flow_import_with_invalid_flow(self):
|
||||
"""Test process_flow_import with invalid flow"""
|
||||
mock_pulsar_client = Mock()
|
||||
mock_config_receiver = Mock()
|
||||
manager = DispatcherManager(mock_pulsar_client, mock_config_receiver)
|
||||
|
||||
params = {"flow": "invalid_flow", "kind": "triples"}
|
||||
|
||||
with pytest.raises(RuntimeError, match="Invalid flow"):
|
||||
await manager.process_flow_import("ws", "running", params)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_flow_import_with_invalid_kind(self):
|
||||
"""Test process_flow_import with invalid kind"""
|
||||
import warnings
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore", RuntimeWarning)
|
||||
mock_pulsar_client = Mock()
|
||||
mock_config_receiver = Mock()
|
||||
manager = DispatcherManager(mock_pulsar_client, mock_config_receiver)
|
||||
|
||||
# Setup test flow
|
||||
manager.flows["test_flow"] = {
|
||||
"interfaces": {
|
||||
"triples-store": {"queue": "test_queue"}
|
||||
}
|
||||
}
|
||||
|
||||
with patch('trustgraph.gateway.dispatch.manager.import_dispatchers') as mock_dispatchers:
|
||||
mock_dispatchers.__contains__.return_value = False
|
||||
|
||||
params = {"flow": "test_flow", "kind": "invalid_kind"}
|
||||
|
||||
with pytest.raises(RuntimeError, match="Invalid kind"):
|
||||
await manager.process_flow_import("ws", "running", params)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_flow_export_with_valid_flow_and_kind(self):
|
||||
"""Test process_flow_export with valid flow and kind"""
|
||||
mock_pulsar_client = Mock()
|
||||
mock_config_receiver = Mock()
|
||||
manager = DispatcherManager(mock_pulsar_client, mock_config_receiver)
|
||||
|
||||
# Setup test flow
|
||||
manager.flows["test_flow"] = {
|
||||
"interfaces": {
|
||||
"triples-store": {"queue": "test_queue"}
|
||||
}
|
||||
}
|
||||
|
||||
with patch('trustgraph.gateway.dispatch.manager.export_dispatchers') as mock_dispatchers, \
|
||||
patch('uuid.uuid4') as mock_uuid:
|
||||
mock_uuid.return_value = "test-uuid"
|
||||
mock_dispatcher_class = Mock()
|
||||
mock_dispatcher = Mock()
|
||||
mock_dispatcher_class.return_value = mock_dispatcher
|
||||
mock_dispatchers.__getitem__.return_value = mock_dispatcher_class
|
||||
mock_dispatchers.__contains__.return_value = True
|
||||
|
||||
params = {"flow": "test_flow", "kind": "triples"}
|
||||
result = await manager.process_flow_export("ws", "running", params)
|
||||
|
||||
mock_dispatcher_class.assert_called_once_with(
|
||||
pulsar_client=mock_pulsar_client,
|
||||
ws="ws",
|
||||
running="running",
|
||||
queue={"queue": "test_queue"},
|
||||
consumer="api-gateway-test-uuid",
|
||||
subscriber="api-gateway-test-uuid"
|
||||
)
|
||||
assert result == mock_dispatcher
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_socket(self):
|
||||
"""Test process_socket method"""
|
||||
mock_pulsar_client = Mock()
|
||||
mock_config_receiver = Mock()
|
||||
manager = DispatcherManager(mock_pulsar_client, mock_config_receiver)
|
||||
|
||||
with patch('trustgraph.gateway.dispatch.manager.Mux') as mock_mux:
|
||||
mock_mux_instance = Mock()
|
||||
mock_mux.return_value = mock_mux_instance
|
||||
|
||||
result = await manager.process_socket("ws", "running", {})
|
||||
|
||||
mock_mux.assert_called_once_with(manager, "ws", "running")
|
||||
assert result == mock_mux_instance
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_flow_service(self):
|
||||
"""Test process_flow_service method"""
|
||||
mock_pulsar_client = Mock()
|
||||
mock_config_receiver = Mock()
|
||||
manager = DispatcherManager(mock_pulsar_client, mock_config_receiver)
|
||||
|
||||
manager.invoke_flow_service = AsyncMock(return_value="flow_result")
|
||||
|
||||
params = {"flow": "test_flow", "kind": "agent"}
|
||||
result = await manager.process_flow_service("data", "responder", params)
|
||||
|
||||
manager.invoke_flow_service.assert_called_once_with("data", "responder", "test_flow", "agent")
|
||||
assert result == "flow_result"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invoke_flow_service_with_existing_dispatcher(self):
|
||||
"""Test invoke_flow_service with existing dispatcher"""
|
||||
mock_pulsar_client = Mock()
|
||||
mock_config_receiver = Mock()
|
||||
manager = DispatcherManager(mock_pulsar_client, mock_config_receiver)
|
||||
|
||||
# Add flow to the flows dictionary
|
||||
manager.flows["test_flow"] = {"services": {"agent": {}}}
|
||||
|
||||
# Pre-populate with existing dispatcher
|
||||
mock_dispatcher = Mock()
|
||||
mock_dispatcher.process = AsyncMock(return_value="cached_result")
|
||||
manager.dispatchers[("test_flow", "agent")] = mock_dispatcher
|
||||
|
||||
result = await manager.invoke_flow_service("data", "responder", "test_flow", "agent")
|
||||
|
||||
mock_dispatcher.process.assert_called_once_with("data", "responder")
|
||||
assert result == "cached_result"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invoke_flow_service_creates_request_response_dispatcher(self):
|
||||
"""Test invoke_flow_service creates request-response dispatcher"""
|
||||
mock_pulsar_client = Mock()
|
||||
mock_config_receiver = Mock()
|
||||
manager = DispatcherManager(mock_pulsar_client, mock_config_receiver)
|
||||
|
||||
# Setup test flow
|
||||
manager.flows["test_flow"] = {
|
||||
"interfaces": {
|
||||
"agent": {
|
||||
"request": "agent_request_queue",
|
||||
"response": "agent_response_queue"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
with patch('trustgraph.gateway.dispatch.manager.request_response_dispatchers') as mock_dispatchers:
|
||||
mock_dispatcher_class = Mock()
|
||||
mock_dispatcher = Mock()
|
||||
mock_dispatcher.start = AsyncMock()
|
||||
mock_dispatcher.process = AsyncMock(return_value="new_result")
|
||||
mock_dispatcher_class.return_value = mock_dispatcher
|
||||
mock_dispatchers.__getitem__.return_value = mock_dispatcher_class
|
||||
mock_dispatchers.__contains__.return_value = True
|
||||
|
||||
result = await manager.invoke_flow_service("data", "responder", "test_flow", "agent")
|
||||
|
||||
# Verify dispatcher was created with correct parameters
|
||||
mock_dispatcher_class.assert_called_once_with(
|
||||
pulsar_client=mock_pulsar_client,
|
||||
request_queue="agent_request_queue",
|
||||
response_queue="agent_response_queue",
|
||||
timeout=120,
|
||||
consumer="api-gateway-test_flow-agent-request",
|
||||
subscriber="api-gateway-test_flow-agent-request"
|
||||
)
|
||||
mock_dispatcher.start.assert_called_once()
|
||||
mock_dispatcher.process.assert_called_once_with("data", "responder")
|
||||
|
||||
# Verify dispatcher was cached
|
||||
assert manager.dispatchers[("test_flow", "agent")] == mock_dispatcher
|
||||
assert result == "new_result"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invoke_flow_service_creates_sender_dispatcher(self):
|
||||
"""Test invoke_flow_service creates sender dispatcher"""
|
||||
mock_pulsar_client = Mock()
|
||||
mock_config_receiver = Mock()
|
||||
manager = DispatcherManager(mock_pulsar_client, mock_config_receiver)
|
||||
|
||||
# Setup test flow
|
||||
manager.flows["test_flow"] = {
|
||||
"interfaces": {
|
||||
"text-load": {"queue": "text_load_queue"}
|
||||
}
|
||||
}
|
||||
|
||||
with patch('trustgraph.gateway.dispatch.manager.request_response_dispatchers') as mock_rr_dispatchers, \
|
||||
patch('trustgraph.gateway.dispatch.manager.sender_dispatchers') as mock_sender_dispatchers:
|
||||
mock_rr_dispatchers.__contains__.return_value = False
|
||||
mock_sender_dispatchers.__contains__.return_value = True
|
||||
|
||||
mock_dispatcher_class = Mock()
|
||||
mock_dispatcher = Mock()
|
||||
mock_dispatcher.start = AsyncMock()
|
||||
mock_dispatcher.process = AsyncMock(return_value="sender_result")
|
||||
mock_dispatcher_class.return_value = mock_dispatcher
|
||||
mock_sender_dispatchers.__getitem__.return_value = mock_dispatcher_class
|
||||
|
||||
result = await manager.invoke_flow_service("data", "responder", "test_flow", "text-load")
|
||||
|
||||
# Verify dispatcher was created with correct parameters
|
||||
mock_dispatcher_class.assert_called_once_with(
|
||||
pulsar_client=mock_pulsar_client,
|
||||
queue={"queue": "text_load_queue"}
|
||||
)
|
||||
mock_dispatcher.start.assert_called_once()
|
||||
mock_dispatcher.process.assert_called_once_with("data", "responder")
|
||||
|
||||
# Verify dispatcher was cached
|
||||
assert manager.dispatchers[("test_flow", "text-load")] == mock_dispatcher
|
||||
assert result == "sender_result"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invoke_flow_service_invalid_flow(self):
|
||||
"""Test invoke_flow_service with invalid flow"""
|
||||
mock_pulsar_client = Mock()
|
||||
mock_config_receiver = Mock()
|
||||
manager = DispatcherManager(mock_pulsar_client, mock_config_receiver)
|
||||
|
||||
with pytest.raises(RuntimeError, match="Invalid flow"):
|
||||
await manager.invoke_flow_service("data", "responder", "invalid_flow", "agent")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invoke_flow_service_unsupported_kind_by_flow(self):
|
||||
"""Test invoke_flow_service with kind not supported by flow"""
|
||||
mock_pulsar_client = Mock()
|
||||
mock_config_receiver = Mock()
|
||||
manager = DispatcherManager(mock_pulsar_client, mock_config_receiver)
|
||||
|
||||
# Setup test flow without agent interface
|
||||
manager.flows["test_flow"] = {
|
||||
"interfaces": {
|
||||
"text-completion": {"request": "req", "response": "resp"}
|
||||
}
|
||||
}
|
||||
|
||||
with pytest.raises(RuntimeError, match="This kind not supported by flow"):
|
||||
await manager.invoke_flow_service("data", "responder", "test_flow", "agent")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invoke_flow_service_invalid_kind(self):
|
||||
"""Test invoke_flow_service with invalid kind"""
|
||||
mock_pulsar_client = Mock()
|
||||
mock_config_receiver = Mock()
|
||||
manager = DispatcherManager(mock_pulsar_client, mock_config_receiver)
|
||||
|
||||
# Setup test flow with interface but unsupported kind
|
||||
manager.flows["test_flow"] = {
|
||||
"interfaces": {
|
||||
"invalid-kind": {"request": "req", "response": "resp"}
|
||||
}
|
||||
}
|
||||
|
||||
with patch('trustgraph.gateway.dispatch.manager.request_response_dispatchers') as mock_rr_dispatchers, \
|
||||
patch('trustgraph.gateway.dispatch.manager.sender_dispatchers') as mock_sender_dispatchers:
|
||||
mock_rr_dispatchers.__contains__.return_value = False
|
||||
mock_sender_dispatchers.__contains__.return_value = False
|
||||
|
||||
with pytest.raises(RuntimeError, match="Invalid kind"):
|
||||
await manager.invoke_flow_service("data", "responder", "test_flow", "invalid-kind")
|
||||
171
tests/unit/test_gateway/test_dispatch_mux.py
Normal file
171
tests/unit/test_gateway/test_dispatch_mux.py
Normal file
|
|
@ -0,0 +1,171 @@
|
|||
"""
|
||||
Tests for Gateway Dispatch Mux
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import asyncio
|
||||
from unittest.mock import MagicMock, AsyncMock
|
||||
|
||||
from trustgraph.gateway.dispatch.mux import Mux, MAX_QUEUE_SIZE
|
||||
|
||||
|
||||
class TestMux:
|
||||
"""Test cases for Mux class"""
|
||||
|
||||
def test_mux_initialization(self):
|
||||
"""Test Mux initialization"""
|
||||
mock_dispatcher_manager = MagicMock()
|
||||
mock_ws = MagicMock()
|
||||
mock_running = MagicMock()
|
||||
|
||||
mux = Mux(
|
||||
dispatcher_manager=mock_dispatcher_manager,
|
||||
ws=mock_ws,
|
||||
running=mock_running
|
||||
)
|
||||
|
||||
assert mux.dispatcher_manager == mock_dispatcher_manager
|
||||
assert mux.ws == mock_ws
|
||||
assert mux.running == mock_running
|
||||
assert isinstance(mux.q, asyncio.Queue)
|
||||
assert mux.q.maxsize == MAX_QUEUE_SIZE
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mux_destroy_with_websocket(self):
|
||||
"""Test Mux destroy method with websocket"""
|
||||
mock_dispatcher_manager = MagicMock()
|
||||
mock_ws = AsyncMock()
|
||||
mock_running = MagicMock()
|
||||
|
||||
mux = Mux(
|
||||
dispatcher_manager=mock_dispatcher_manager,
|
||||
ws=mock_ws,
|
||||
running=mock_running
|
||||
)
|
||||
|
||||
# Call destroy
|
||||
await mux.destroy()
|
||||
|
||||
# Verify running.stop was called
|
||||
mock_running.stop.assert_called_once()
|
||||
|
||||
# Verify websocket close was called
|
||||
mock_ws.close.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mux_destroy_without_websocket(self):
|
||||
"""Test Mux destroy method without websocket"""
|
||||
mock_dispatcher_manager = MagicMock()
|
||||
mock_running = MagicMock()
|
||||
|
||||
mux = Mux(
|
||||
dispatcher_manager=mock_dispatcher_manager,
|
||||
ws=None,
|
||||
running=mock_running
|
||||
)
|
||||
|
||||
# Call destroy
|
||||
await mux.destroy()
|
||||
|
||||
# Verify running.stop was called
|
||||
mock_running.stop.assert_called_once()
|
||||
# No websocket to close
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mux_receive_valid_message(self):
|
||||
"""Test Mux receive method with valid message"""
|
||||
mock_dispatcher_manager = MagicMock()
|
||||
mock_ws = AsyncMock()
|
||||
mock_running = MagicMock()
|
||||
|
||||
mux = Mux(
|
||||
dispatcher_manager=mock_dispatcher_manager,
|
||||
ws=mock_ws,
|
||||
running=mock_running
|
||||
)
|
||||
|
||||
# Mock message with valid JSON
|
||||
mock_msg = MagicMock()
|
||||
mock_msg.json.return_value = {
|
||||
"request": {"type": "test"},
|
||||
"id": "test-id-123",
|
||||
"service": "test-service"
|
||||
}
|
||||
|
||||
# Call receive
|
||||
await mux.receive(mock_msg)
|
||||
|
||||
# Verify json was called
|
||||
mock_msg.json.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mux_receive_message_without_request(self):
|
||||
"""Test Mux receive method with message missing request field"""
|
||||
mock_dispatcher_manager = MagicMock()
|
||||
mock_ws = AsyncMock()
|
||||
mock_running = MagicMock()
|
||||
|
||||
mux = Mux(
|
||||
dispatcher_manager=mock_dispatcher_manager,
|
||||
ws=mock_ws,
|
||||
running=mock_running
|
||||
)
|
||||
|
||||
# Mock message without request field
|
||||
mock_msg = MagicMock()
|
||||
mock_msg.json.return_value = {
|
||||
"id": "test-id-123"
|
||||
}
|
||||
|
||||
# receive method should handle the RuntimeError internally
|
||||
# Based on the code, it seems to catch exceptions
|
||||
await mux.receive(mock_msg)
|
||||
|
||||
mock_ws.send_json.assert_called_once_with({"error": "Bad message"})
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mux_receive_message_without_id(self):
|
||||
"""Test Mux receive method with message missing id field"""
|
||||
mock_dispatcher_manager = MagicMock()
|
||||
mock_ws = AsyncMock()
|
||||
mock_running = MagicMock()
|
||||
|
||||
mux = Mux(
|
||||
dispatcher_manager=mock_dispatcher_manager,
|
||||
ws=mock_ws,
|
||||
running=mock_running
|
||||
)
|
||||
|
||||
# Mock message without id field
|
||||
mock_msg = MagicMock()
|
||||
mock_msg.json.return_value = {
|
||||
"request": {"type": "test"}
|
||||
}
|
||||
|
||||
# receive method should handle the RuntimeError internally
|
||||
await mux.receive(mock_msg)
|
||||
|
||||
mock_ws.send_json.assert_called_once_with({"error": "Bad message"})
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mux_receive_invalid_json(self):
|
||||
"""Test Mux receive method with invalid JSON"""
|
||||
mock_dispatcher_manager = MagicMock()
|
||||
mock_ws = AsyncMock()
|
||||
mock_running = MagicMock()
|
||||
|
||||
mux = Mux(
|
||||
dispatcher_manager=mock_dispatcher_manager,
|
||||
ws=mock_ws,
|
||||
running=mock_running
|
||||
)
|
||||
|
||||
# Mock message with invalid JSON
|
||||
mock_msg = MagicMock()
|
||||
mock_msg.json.side_effect = ValueError("Invalid JSON")
|
||||
|
||||
# receive method should handle the ValueError internally
|
||||
await mux.receive(mock_msg)
|
||||
|
||||
mock_msg.json.assert_called_once()
|
||||
mock_ws.send_json.assert_called_once_with({"error": "Invalid JSON"})
|
||||
118
tests/unit/test_gateway/test_dispatch_requestor.py
Normal file
118
tests/unit/test_gateway/test_dispatch_requestor.py
Normal file
|
|
@ -0,0 +1,118 @@
|
|||
"""
|
||||
Tests for Gateway Service Requestor
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, AsyncMock, patch
|
||||
|
||||
from trustgraph.gateway.dispatch.requestor import ServiceRequestor
|
||||
|
||||
|
||||
class TestServiceRequestor:
|
||||
"""Test cases for ServiceRequestor class"""
|
||||
|
||||
@patch('trustgraph.gateway.dispatch.requestor.Publisher')
|
||||
@patch('trustgraph.gateway.dispatch.requestor.Subscriber')
|
||||
def test_service_requestor_initialization(self, mock_subscriber, mock_publisher):
|
||||
"""Test ServiceRequestor initialization"""
|
||||
mock_pulsar_client = MagicMock()
|
||||
mock_request_schema = MagicMock()
|
||||
mock_response_schema = MagicMock()
|
||||
|
||||
requestor = ServiceRequestor(
|
||||
pulsar_client=mock_pulsar_client,
|
||||
request_queue="test-request-queue",
|
||||
request_schema=mock_request_schema,
|
||||
response_queue="test-response-queue",
|
||||
response_schema=mock_response_schema,
|
||||
subscription="test-subscription",
|
||||
consumer_name="test-consumer",
|
||||
timeout=300
|
||||
)
|
||||
|
||||
# Verify Publisher was created correctly
|
||||
mock_publisher.assert_called_once_with(
|
||||
mock_pulsar_client, "test-request-queue", schema=mock_request_schema
|
||||
)
|
||||
|
||||
# Verify Subscriber was created correctly
|
||||
mock_subscriber.assert_called_once_with(
|
||||
mock_pulsar_client, "test-response-queue",
|
||||
"test-subscription", "test-consumer", mock_response_schema
|
||||
)
|
||||
|
||||
assert requestor.timeout == 300
|
||||
assert requestor.running is True
|
||||
|
||||
@patch('trustgraph.gateway.dispatch.requestor.Publisher')
|
||||
@patch('trustgraph.gateway.dispatch.requestor.Subscriber')
|
||||
def test_service_requestor_with_defaults(self, mock_subscriber, mock_publisher):
|
||||
"""Test ServiceRequestor initialization with default parameters"""
|
||||
mock_pulsar_client = MagicMock()
|
||||
mock_request_schema = MagicMock()
|
||||
mock_response_schema = MagicMock()
|
||||
|
||||
requestor = ServiceRequestor(
|
||||
pulsar_client=mock_pulsar_client,
|
||||
request_queue="test-queue",
|
||||
request_schema=mock_request_schema,
|
||||
response_queue="response-queue",
|
||||
response_schema=mock_response_schema
|
||||
)
|
||||
|
||||
# Verify default values
|
||||
mock_subscriber.assert_called_once_with(
|
||||
mock_pulsar_client, "response-queue",
|
||||
"api-gateway", "api-gateway", mock_response_schema
|
||||
)
|
||||
assert requestor.timeout == 600 # Default timeout
|
||||
|
||||
@patch('trustgraph.gateway.dispatch.requestor.Publisher')
|
||||
@patch('trustgraph.gateway.dispatch.requestor.Subscriber')
|
||||
@pytest.mark.asyncio
|
||||
async def test_service_requestor_start(self, mock_subscriber, mock_publisher):
|
||||
"""Test ServiceRequestor start method"""
|
||||
mock_pulsar_client = MagicMock()
|
||||
mock_sub_instance = AsyncMock()
|
||||
mock_pub_instance = AsyncMock()
|
||||
mock_subscriber.return_value = mock_sub_instance
|
||||
mock_publisher.return_value = mock_pub_instance
|
||||
|
||||
requestor = ServiceRequestor(
|
||||
pulsar_client=mock_pulsar_client,
|
||||
request_queue="test-queue",
|
||||
request_schema=MagicMock(),
|
||||
response_queue="response-queue",
|
||||
response_schema=MagicMock()
|
||||
)
|
||||
|
||||
# Call start
|
||||
await requestor.start()
|
||||
|
||||
# Verify both subscriber and publisher start were called
|
||||
mock_sub_instance.start.assert_called_once()
|
||||
mock_pub_instance.start.assert_called_once()
|
||||
assert requestor.running is True
|
||||
|
||||
@patch('trustgraph.gateway.dispatch.requestor.Publisher')
|
||||
@patch('trustgraph.gateway.dispatch.requestor.Subscriber')
|
||||
def test_service_requestor_attributes(self, mock_subscriber, mock_publisher):
|
||||
"""Test ServiceRequestor has correct attributes"""
|
||||
mock_pulsar_client = MagicMock()
|
||||
mock_pub_instance = AsyncMock()
|
||||
mock_sub_instance = AsyncMock()
|
||||
mock_publisher.return_value = mock_pub_instance
|
||||
mock_subscriber.return_value = mock_sub_instance
|
||||
|
||||
requestor = ServiceRequestor(
|
||||
pulsar_client=mock_pulsar_client,
|
||||
request_queue="test-queue",
|
||||
request_schema=MagicMock(),
|
||||
response_queue="response-queue",
|
||||
response_schema=MagicMock()
|
||||
)
|
||||
|
||||
# Verify attributes are set correctly
|
||||
assert requestor.pub == mock_pub_instance
|
||||
assert requestor.sub == mock_sub_instance
|
||||
assert requestor.running is True
|
||||
120
tests/unit/test_gateway/test_dispatch_sender.py
Normal file
120
tests/unit/test_gateway/test_dispatch_sender.py
Normal file
|
|
@ -0,0 +1,120 @@
|
|||
"""
|
||||
Tests for Gateway Service Sender
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, AsyncMock, patch
|
||||
|
||||
from trustgraph.gateway.dispatch.sender import ServiceSender
|
||||
|
||||
|
||||
class TestServiceSender:
|
||||
"""Test cases for ServiceSender class"""
|
||||
|
||||
@patch('trustgraph.gateway.dispatch.sender.Publisher')
|
||||
def test_service_sender_initialization(self, mock_publisher):
|
||||
"""Test ServiceSender initialization"""
|
||||
mock_pulsar_client = MagicMock()
|
||||
mock_schema = MagicMock()
|
||||
|
||||
sender = ServiceSender(
|
||||
pulsar_client=mock_pulsar_client,
|
||||
queue="test-queue",
|
||||
schema=mock_schema
|
||||
)
|
||||
|
||||
# Verify Publisher was created correctly
|
||||
mock_publisher.assert_called_once_with(
|
||||
mock_pulsar_client, "test-queue", schema=mock_schema
|
||||
)
|
||||
|
||||
@patch('trustgraph.gateway.dispatch.sender.Publisher')
|
||||
@pytest.mark.asyncio
|
||||
async def test_service_sender_start(self, mock_publisher):
|
||||
"""Test ServiceSender start method"""
|
||||
mock_pub_instance = AsyncMock()
|
||||
mock_publisher.return_value = mock_pub_instance
|
||||
|
||||
sender = ServiceSender(
|
||||
pulsar_client=MagicMock(),
|
||||
queue="test-queue",
|
||||
schema=MagicMock()
|
||||
)
|
||||
|
||||
# Call start
|
||||
await sender.start()
|
||||
|
||||
# Verify publisher start was called
|
||||
mock_pub_instance.start.assert_called_once()
|
||||
|
||||
@patch('trustgraph.gateway.dispatch.sender.Publisher')
|
||||
@pytest.mark.asyncio
|
||||
async def test_service_sender_stop(self, mock_publisher):
|
||||
"""Test ServiceSender stop method"""
|
||||
mock_pub_instance = AsyncMock()
|
||||
mock_publisher.return_value = mock_pub_instance
|
||||
|
||||
sender = ServiceSender(
|
||||
pulsar_client=MagicMock(),
|
||||
queue="test-queue",
|
||||
schema=MagicMock()
|
||||
)
|
||||
|
||||
# Call stop
|
||||
await sender.stop()
|
||||
|
||||
# Verify publisher stop was called
|
||||
mock_pub_instance.stop.assert_called_once()
|
||||
|
||||
@patch('trustgraph.gateway.dispatch.sender.Publisher')
|
||||
def test_service_sender_to_request_not_implemented(self, mock_publisher):
|
||||
"""Test ServiceSender to_request method raises RuntimeError"""
|
||||
sender = ServiceSender(
|
||||
pulsar_client=MagicMock(),
|
||||
queue="test-queue",
|
||||
schema=MagicMock()
|
||||
)
|
||||
|
||||
with pytest.raises(RuntimeError, match="Not defined"):
|
||||
sender.to_request({"test": "request"})
|
||||
|
||||
@patch('trustgraph.gateway.dispatch.sender.Publisher')
|
||||
@pytest.mark.asyncio
|
||||
async def test_service_sender_process(self, mock_publisher):
|
||||
"""Test ServiceSender process method"""
|
||||
mock_pub_instance = AsyncMock()
|
||||
mock_publisher.return_value = mock_pub_instance
|
||||
|
||||
# Create a concrete sender that implements to_request
|
||||
class ConcreteSender(ServiceSender):
|
||||
def to_request(self, request):
|
||||
return {"processed": request}
|
||||
|
||||
sender = ConcreteSender(
|
||||
pulsar_client=MagicMock(),
|
||||
queue="test-queue",
|
||||
schema=MagicMock()
|
||||
)
|
||||
|
||||
test_request = {"test": "data"}
|
||||
|
||||
# Call process
|
||||
await sender.process(test_request)
|
||||
|
||||
# Verify publisher send was called with processed request
|
||||
mock_pub_instance.send.assert_called_once_with(None, {"processed": test_request})
|
||||
|
||||
@patch('trustgraph.gateway.dispatch.sender.Publisher')
|
||||
def test_service_sender_attributes(self, mock_publisher):
|
||||
"""Test ServiceSender has correct attributes"""
|
||||
mock_pub_instance = MagicMock()
|
||||
mock_publisher.return_value = mock_pub_instance
|
||||
|
||||
sender = ServiceSender(
|
||||
pulsar_client=MagicMock(),
|
||||
queue="test-queue",
|
||||
schema=MagicMock()
|
||||
)
|
||||
|
||||
# Verify attributes are set correctly
|
||||
assert sender.pub == mock_pub_instance
|
||||
89
tests/unit/test_gateway/test_dispatch_serialize.py
Normal file
89
tests/unit/test_gateway/test_dispatch_serialize.py
Normal file
|
|
@ -0,0 +1,89 @@
|
|||
"""
|
||||
Tests for Gateway Dispatch Serialization
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from trustgraph.gateway.dispatch.serialize import to_value, to_subgraph, serialize_value
|
||||
from trustgraph.schema import Value, Triple
|
||||
|
||||
|
||||
class TestDispatchSerialize:
|
||||
"""Test cases for dispatch serialization functions"""
|
||||
|
||||
def test_to_value_with_uri(self):
|
||||
"""Test to_value function with URI"""
|
||||
input_data = {"v": "http://example.com/resource", "e": True}
|
||||
|
||||
result = to_value(input_data)
|
||||
|
||||
assert isinstance(result, Value)
|
||||
assert result.value == "http://example.com/resource"
|
||||
assert result.is_uri is True
|
||||
|
||||
def test_to_value_with_literal(self):
|
||||
"""Test to_value function with literal value"""
|
||||
input_data = {"v": "literal string", "e": False}
|
||||
|
||||
result = to_value(input_data)
|
||||
|
||||
assert isinstance(result, Value)
|
||||
assert result.value == "literal string"
|
||||
assert result.is_uri is False
|
||||
|
||||
def test_to_subgraph_with_multiple_triples(self):
|
||||
"""Test to_subgraph function with multiple triples"""
|
||||
input_data = [
|
||||
{
|
||||
"s": {"v": "subject1", "e": True},
|
||||
"p": {"v": "predicate1", "e": True},
|
||||
"o": {"v": "object1", "e": False}
|
||||
},
|
||||
{
|
||||
"s": {"v": "subject2", "e": False},
|
||||
"p": {"v": "predicate2", "e": True},
|
||||
"o": {"v": "object2", "e": True}
|
||||
}
|
||||
]
|
||||
|
||||
result = to_subgraph(input_data)
|
||||
|
||||
assert len(result) == 2
|
||||
assert all(isinstance(triple, Triple) for triple in result)
|
||||
|
||||
# Check first triple
|
||||
assert result[0].s.value == "subject1"
|
||||
assert result[0].s.is_uri is True
|
||||
assert result[0].p.value == "predicate1"
|
||||
assert result[0].p.is_uri is True
|
||||
assert result[0].o.value == "object1"
|
||||
assert result[0].o.is_uri is False
|
||||
|
||||
# Check second triple
|
||||
assert result[1].s.value == "subject2"
|
||||
assert result[1].s.is_uri is False
|
||||
|
||||
def test_to_subgraph_with_empty_list(self):
|
||||
"""Test to_subgraph function with empty input"""
|
||||
input_data = []
|
||||
|
||||
result = to_subgraph(input_data)
|
||||
|
||||
assert result == []
|
||||
|
||||
def test_serialize_value_with_uri(self):
|
||||
"""Test serialize_value function with URI value"""
|
||||
value = Value(value="http://example.com/test", is_uri=True)
|
||||
|
||||
result = serialize_value(value)
|
||||
|
||||
assert result == {"v": "http://example.com/test", "e": True}
|
||||
|
||||
def test_serialize_value_with_literal(self):
|
||||
"""Test serialize_value function with literal value"""
|
||||
value = Value(value="test literal", is_uri=False)
|
||||
|
||||
result = serialize_value(value)
|
||||
|
||||
assert result == {"v": "test literal", "e": False}
|
||||
55
tests/unit/test_gateway/test_endpoint_constant.py
Normal file
55
tests/unit/test_gateway/test_endpoint_constant.py
Normal file
|
|
@ -0,0 +1,55 @@
|
|||
"""
|
||||
Tests for Gateway Constant Endpoint
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, AsyncMock
|
||||
from aiohttp import web
|
||||
|
||||
from trustgraph.gateway.endpoint.constant_endpoint import ConstantEndpoint
|
||||
|
||||
|
||||
class TestConstantEndpoint:
|
||||
"""Test cases for ConstantEndpoint class"""
|
||||
|
||||
def test_constant_endpoint_initialization(self):
|
||||
"""Test ConstantEndpoint initialization"""
|
||||
mock_auth = MagicMock()
|
||||
mock_dispatcher = MagicMock()
|
||||
|
||||
endpoint = ConstantEndpoint(
|
||||
endpoint_path="/api/test",
|
||||
auth=mock_auth,
|
||||
dispatcher=mock_dispatcher
|
||||
)
|
||||
|
||||
assert endpoint.path == "/api/test"
|
||||
assert endpoint.auth == mock_auth
|
||||
assert endpoint.dispatcher == mock_dispatcher
|
||||
assert endpoint.operation == "service"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_constant_endpoint_start_method(self):
|
||||
"""Test ConstantEndpoint start method (should be no-op)"""
|
||||
mock_auth = MagicMock()
|
||||
mock_dispatcher = MagicMock()
|
||||
|
||||
endpoint = ConstantEndpoint("/api/test", mock_auth, mock_dispatcher)
|
||||
|
||||
# start() should complete without error
|
||||
await endpoint.start()
|
||||
|
||||
def test_add_routes_registers_post_handler(self):
|
||||
"""Test add_routes method registers POST route"""
|
||||
mock_auth = MagicMock()
|
||||
mock_dispatcher = MagicMock()
|
||||
mock_app = MagicMock()
|
||||
|
||||
endpoint = ConstantEndpoint("/api/test", mock_auth, mock_dispatcher)
|
||||
endpoint.add_routes(mock_app)
|
||||
|
||||
# Verify add_routes was called with POST route
|
||||
mock_app.add_routes.assert_called_once()
|
||||
# The call should include web.post with the path and handler
|
||||
call_args = mock_app.add_routes.call_args[0][0]
|
||||
assert len(call_args) == 1 # One route added
|
||||
89
tests/unit/test_gateway/test_endpoint_manager.py
Normal file
89
tests/unit/test_gateway/test_endpoint_manager.py
Normal file
|
|
@ -0,0 +1,89 @@
|
|||
"""
|
||||
Tests for Gateway Endpoint Manager
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from trustgraph.gateway.endpoint.manager import EndpointManager
|
||||
|
||||
|
||||
class TestEndpointManager:
|
||||
"""Test cases for EndpointManager class"""
|
||||
|
||||
def test_endpoint_manager_initialization(self):
|
||||
"""Test EndpointManager initialization creates all endpoints"""
|
||||
mock_dispatcher_manager = MagicMock()
|
||||
mock_auth = MagicMock()
|
||||
|
||||
# Mock dispatcher methods
|
||||
mock_dispatcher_manager.dispatch_global_service.return_value = MagicMock()
|
||||
mock_dispatcher_manager.dispatch_socket.return_value = MagicMock()
|
||||
mock_dispatcher_manager.dispatch_flow_service.return_value = MagicMock()
|
||||
mock_dispatcher_manager.dispatch_flow_import.return_value = MagicMock()
|
||||
mock_dispatcher_manager.dispatch_flow_export.return_value = MagicMock()
|
||||
mock_dispatcher_manager.dispatch_core_import.return_value = MagicMock()
|
||||
mock_dispatcher_manager.dispatch_core_export.return_value = MagicMock()
|
||||
|
||||
manager = EndpointManager(
|
||||
dispatcher_manager=mock_dispatcher_manager,
|
||||
auth=mock_auth,
|
||||
prometheus_url="http://prometheus:9090",
|
||||
timeout=300
|
||||
)
|
||||
|
||||
assert manager.dispatcher_manager == mock_dispatcher_manager
|
||||
assert manager.timeout == 300
|
||||
assert manager.services == {}
|
||||
assert len(manager.endpoints) > 0 # Should have multiple endpoints
|
||||
|
||||
def test_endpoint_manager_with_default_timeout(self):
|
||||
"""Test EndpointManager with default timeout value"""
|
||||
mock_dispatcher_manager = MagicMock()
|
||||
mock_auth = MagicMock()
|
||||
|
||||
# Mock dispatcher methods
|
||||
mock_dispatcher_manager.dispatch_global_service.return_value = MagicMock()
|
||||
mock_dispatcher_manager.dispatch_socket.return_value = MagicMock()
|
||||
mock_dispatcher_manager.dispatch_flow_service.return_value = MagicMock()
|
||||
mock_dispatcher_manager.dispatch_flow_import.return_value = MagicMock()
|
||||
mock_dispatcher_manager.dispatch_flow_export.return_value = MagicMock()
|
||||
mock_dispatcher_manager.dispatch_core_import.return_value = MagicMock()
|
||||
mock_dispatcher_manager.dispatch_core_export.return_value = MagicMock()
|
||||
|
||||
manager = EndpointManager(
|
||||
dispatcher_manager=mock_dispatcher_manager,
|
||||
auth=mock_auth,
|
||||
prometheus_url="http://prometheus:9090"
|
||||
)
|
||||
|
||||
assert manager.timeout == 600 # Default value
|
||||
|
||||
def test_endpoint_manager_dispatcher_calls(self):
|
||||
"""Test EndpointManager calls all required dispatcher methods"""
|
||||
mock_dispatcher_manager = MagicMock()
|
||||
mock_auth = MagicMock()
|
||||
|
||||
# Mock dispatcher methods that are actually called
|
||||
mock_dispatcher_manager.dispatch_global_service.return_value = MagicMock()
|
||||
mock_dispatcher_manager.dispatch_socket.return_value = MagicMock()
|
||||
mock_dispatcher_manager.dispatch_flow_service.return_value = MagicMock()
|
||||
mock_dispatcher_manager.dispatch_flow_import.return_value = MagicMock()
|
||||
mock_dispatcher_manager.dispatch_flow_export.return_value = MagicMock()
|
||||
mock_dispatcher_manager.dispatch_core_import.return_value = MagicMock()
|
||||
mock_dispatcher_manager.dispatch_core_export.return_value = MagicMock()
|
||||
|
||||
EndpointManager(
|
||||
dispatcher_manager=mock_dispatcher_manager,
|
||||
auth=mock_auth,
|
||||
prometheus_url="http://test:9090"
|
||||
)
|
||||
|
||||
# Verify all dispatcher methods were called during initialization
|
||||
mock_dispatcher_manager.dispatch_global_service.assert_called_once()
|
||||
mock_dispatcher_manager.dispatch_socket.assert_called() # Called twice
|
||||
mock_dispatcher_manager.dispatch_flow_service.assert_called_once()
|
||||
mock_dispatcher_manager.dispatch_flow_import.assert_called_once()
|
||||
mock_dispatcher_manager.dispatch_flow_export.assert_called_once()
|
||||
mock_dispatcher_manager.dispatch_core_import.assert_called_once()
|
||||
mock_dispatcher_manager.dispatch_core_export.assert_called_once()
|
||||
60
tests/unit/test_gateway/test_endpoint_metrics.py
Normal file
60
tests/unit/test_gateway/test_endpoint_metrics.py
Normal file
|
|
@ -0,0 +1,60 @@
|
|||
"""
|
||||
Tests for Gateway Metrics Endpoint
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from trustgraph.gateway.endpoint.metrics import MetricsEndpoint
|
||||
|
||||
|
||||
class TestMetricsEndpoint:
|
||||
"""Test cases for MetricsEndpoint class"""
|
||||
|
||||
def test_metrics_endpoint_initialization(self):
|
||||
"""Test MetricsEndpoint initialization"""
|
||||
mock_auth = MagicMock()
|
||||
|
||||
endpoint = MetricsEndpoint(
|
||||
prometheus_url="http://prometheus:9090",
|
||||
endpoint_path="/metrics",
|
||||
auth=mock_auth
|
||||
)
|
||||
|
||||
assert endpoint.prometheus_url == "http://prometheus:9090"
|
||||
assert endpoint.path == "/metrics"
|
||||
assert endpoint.auth == mock_auth
|
||||
assert endpoint.operation == "service"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_metrics_endpoint_start_method(self):
|
||||
"""Test MetricsEndpoint start method (should be no-op)"""
|
||||
mock_auth = MagicMock()
|
||||
|
||||
endpoint = MetricsEndpoint(
|
||||
prometheus_url="http://localhost:9090",
|
||||
endpoint_path="/metrics",
|
||||
auth=mock_auth
|
||||
)
|
||||
|
||||
# start() should complete without error
|
||||
await endpoint.start()
|
||||
|
||||
def test_add_routes_registers_get_handler(self):
|
||||
"""Test add_routes method registers GET route with wildcard path"""
|
||||
mock_auth = MagicMock()
|
||||
mock_app = MagicMock()
|
||||
|
||||
endpoint = MetricsEndpoint(
|
||||
prometheus_url="http://prometheus:9090",
|
||||
endpoint_path="/metrics",
|
||||
auth=mock_auth
|
||||
)
|
||||
|
||||
endpoint.add_routes(mock_app)
|
||||
|
||||
# Verify add_routes was called with GET route
|
||||
mock_app.add_routes.assert_called_once()
|
||||
# The call should include web.get with wildcard path pattern
|
||||
call_args = mock_app.add_routes.call_args[0][0]
|
||||
assert len(call_args) == 1 # One route added
|
||||
133
tests/unit/test_gateway/test_endpoint_socket.py
Normal file
133
tests/unit/test_gateway/test_endpoint_socket.py
Normal file
|
|
@ -0,0 +1,133 @@
|
|||
"""
|
||||
Tests for Gateway Socket Endpoint
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, AsyncMock
|
||||
from aiohttp import WSMsgType
|
||||
|
||||
from trustgraph.gateway.endpoint.socket import SocketEndpoint
|
||||
|
||||
|
||||
class TestSocketEndpoint:
|
||||
"""Test cases for SocketEndpoint class"""
|
||||
|
||||
def test_socket_endpoint_initialization(self):
|
||||
"""Test SocketEndpoint initialization"""
|
||||
mock_auth = MagicMock()
|
||||
mock_dispatcher = MagicMock()
|
||||
|
||||
endpoint = SocketEndpoint(
|
||||
endpoint_path="/api/socket",
|
||||
auth=mock_auth,
|
||||
dispatcher=mock_dispatcher
|
||||
)
|
||||
|
||||
assert endpoint.path == "/api/socket"
|
||||
assert endpoint.auth == mock_auth
|
||||
assert endpoint.dispatcher == mock_dispatcher
|
||||
assert endpoint.operation == "socket"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_worker_method(self):
|
||||
"""Test SocketEndpoint worker method"""
|
||||
mock_auth = MagicMock()
|
||||
mock_dispatcher = AsyncMock()
|
||||
|
||||
endpoint = SocketEndpoint("/api/socket", mock_auth, mock_dispatcher)
|
||||
|
||||
mock_ws = MagicMock()
|
||||
mock_running = MagicMock()
|
||||
|
||||
# Call worker method
|
||||
await endpoint.worker(mock_ws, mock_dispatcher, mock_running)
|
||||
|
||||
# Verify dispatcher.run was called
|
||||
mock_dispatcher.run.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_listener_method_with_text_message(self):
|
||||
"""Test SocketEndpoint listener method with text message"""
|
||||
mock_auth = MagicMock()
|
||||
mock_dispatcher = AsyncMock()
|
||||
|
||||
endpoint = SocketEndpoint("/api/socket", mock_auth, mock_dispatcher)
|
||||
|
||||
# Mock websocket with text message
|
||||
mock_msg = MagicMock()
|
||||
mock_msg.type = WSMsgType.TEXT
|
||||
|
||||
# Create async iterator for websocket
|
||||
async def async_iter():
|
||||
yield mock_msg
|
||||
|
||||
mock_ws = AsyncMock()
|
||||
mock_ws.__aiter__ = lambda self: async_iter()
|
||||
mock_running = MagicMock()
|
||||
|
||||
# Call listener method
|
||||
await endpoint.listener(mock_ws, mock_dispatcher, mock_running)
|
||||
|
||||
# Verify dispatcher.receive was called with the message
|
||||
mock_dispatcher.receive.assert_called_once_with(mock_msg)
|
||||
# Verify cleanup methods were called
|
||||
mock_running.stop.assert_called_once()
|
||||
mock_ws.close.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_listener_method_with_binary_message(self):
|
||||
"""Test SocketEndpoint listener method with binary message"""
|
||||
mock_auth = MagicMock()
|
||||
mock_dispatcher = AsyncMock()
|
||||
|
||||
endpoint = SocketEndpoint("/api/socket", mock_auth, mock_dispatcher)
|
||||
|
||||
# Mock websocket with binary message
|
||||
mock_msg = MagicMock()
|
||||
mock_msg.type = WSMsgType.BINARY
|
||||
|
||||
# Create async iterator for websocket
|
||||
async def async_iter():
|
||||
yield mock_msg
|
||||
|
||||
mock_ws = AsyncMock()
|
||||
mock_ws.__aiter__ = lambda self: async_iter()
|
||||
mock_running = MagicMock()
|
||||
|
||||
# Call listener method
|
||||
await endpoint.listener(mock_ws, mock_dispatcher, mock_running)
|
||||
|
||||
# Verify dispatcher.receive was called with the message
|
||||
mock_dispatcher.receive.assert_called_once_with(mock_msg)
|
||||
# Verify cleanup methods were called
|
||||
mock_running.stop.assert_called_once()
|
||||
mock_ws.close.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_listener_method_with_close_message(self):
|
||||
"""Test SocketEndpoint listener method with close message"""
|
||||
mock_auth = MagicMock()
|
||||
mock_dispatcher = AsyncMock()
|
||||
|
||||
endpoint = SocketEndpoint("/api/socket", mock_auth, mock_dispatcher)
|
||||
|
||||
# Mock websocket with close message
|
||||
mock_msg = MagicMock()
|
||||
mock_msg.type = WSMsgType.CLOSE
|
||||
|
||||
# Create async iterator for websocket
|
||||
async def async_iter():
|
||||
yield mock_msg
|
||||
|
||||
mock_ws = AsyncMock()
|
||||
mock_ws.__aiter__ = lambda self: async_iter()
|
||||
mock_running = MagicMock()
|
||||
|
||||
# Call listener method
|
||||
await endpoint.listener(mock_ws, mock_dispatcher, mock_running)
|
||||
|
||||
# Verify dispatcher.receive was NOT called for close message
|
||||
mock_dispatcher.receive.assert_not_called()
|
||||
# Verify cleanup methods were called after break
|
||||
mock_running.stop.assert_called_once()
|
||||
mock_ws.close.assert_called_once()
|
||||
124
tests/unit/test_gateway/test_endpoint_stream.py
Normal file
124
tests/unit/test_gateway/test_endpoint_stream.py
Normal file
|
|
@ -0,0 +1,124 @@
|
|||
"""
|
||||
Tests for Gateway Stream Endpoint
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from trustgraph.gateway.endpoint.stream_endpoint import StreamEndpoint
|
||||
|
||||
|
||||
class TestStreamEndpoint:
|
||||
"""Test cases for StreamEndpoint class"""
|
||||
|
||||
def test_stream_endpoint_initialization_with_post(self):
|
||||
"""Test StreamEndpoint initialization with POST method"""
|
||||
mock_auth = MagicMock()
|
||||
mock_dispatcher = MagicMock()
|
||||
|
||||
endpoint = StreamEndpoint(
|
||||
endpoint_path="/api/stream",
|
||||
auth=mock_auth,
|
||||
dispatcher=mock_dispatcher,
|
||||
method="POST"
|
||||
)
|
||||
|
||||
assert endpoint.path == "/api/stream"
|
||||
assert endpoint.auth == mock_auth
|
||||
assert endpoint.dispatcher == mock_dispatcher
|
||||
assert endpoint.operation == "service"
|
||||
assert endpoint.method == "POST"
|
||||
|
||||
def test_stream_endpoint_initialization_with_get(self):
|
||||
"""Test StreamEndpoint initialization with GET method"""
|
||||
mock_auth = MagicMock()
|
||||
mock_dispatcher = MagicMock()
|
||||
|
||||
endpoint = StreamEndpoint(
|
||||
endpoint_path="/api/stream",
|
||||
auth=mock_auth,
|
||||
dispatcher=mock_dispatcher,
|
||||
method="GET"
|
||||
)
|
||||
|
||||
assert endpoint.method == "GET"
|
||||
|
||||
def test_stream_endpoint_initialization_default_method(self):
|
||||
"""Test StreamEndpoint initialization with default POST method"""
|
||||
mock_auth = MagicMock()
|
||||
mock_dispatcher = MagicMock()
|
||||
|
||||
endpoint = StreamEndpoint(
|
||||
endpoint_path="/api/stream",
|
||||
auth=mock_auth,
|
||||
dispatcher=mock_dispatcher
|
||||
)
|
||||
|
||||
assert endpoint.method == "POST" # Default value
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stream_endpoint_start_method(self):
|
||||
"""Test StreamEndpoint start method (should be no-op)"""
|
||||
mock_auth = MagicMock()
|
||||
mock_dispatcher = MagicMock()
|
||||
|
||||
endpoint = StreamEndpoint("/api/stream", mock_auth, mock_dispatcher)
|
||||
|
||||
# start() should complete without error
|
||||
await endpoint.start()
|
||||
|
||||
def test_add_routes_with_post_method(self):
|
||||
"""Test add_routes method with POST method"""
|
||||
mock_auth = MagicMock()
|
||||
mock_dispatcher = MagicMock()
|
||||
mock_app = MagicMock()
|
||||
|
||||
endpoint = StreamEndpoint(
|
||||
endpoint_path="/api/stream",
|
||||
auth=mock_auth,
|
||||
dispatcher=mock_dispatcher,
|
||||
method="POST"
|
||||
)
|
||||
|
||||
endpoint.add_routes(mock_app)
|
||||
|
||||
# Verify add_routes was called with POST route
|
||||
mock_app.add_routes.assert_called_once()
|
||||
call_args = mock_app.add_routes.call_args[0][0]
|
||||
assert len(call_args) == 1 # One route added
|
||||
|
||||
def test_add_routes_with_get_method(self):
|
||||
"""Test add_routes method with GET method"""
|
||||
mock_auth = MagicMock()
|
||||
mock_dispatcher = MagicMock()
|
||||
mock_app = MagicMock()
|
||||
|
||||
endpoint = StreamEndpoint(
|
||||
endpoint_path="/api/stream",
|
||||
auth=mock_auth,
|
||||
dispatcher=mock_dispatcher,
|
||||
method="GET"
|
||||
)
|
||||
|
||||
endpoint.add_routes(mock_app)
|
||||
|
||||
# Verify add_routes was called with GET route
|
||||
mock_app.add_routes.assert_called_once()
|
||||
call_args = mock_app.add_routes.call_args[0][0]
|
||||
assert len(call_args) == 1 # One route added
|
||||
|
||||
def test_add_routes_with_invalid_method_raises_error(self):
|
||||
"""Test add_routes method with invalid method raises RuntimeError"""
|
||||
mock_auth = MagicMock()
|
||||
mock_dispatcher = MagicMock()
|
||||
mock_app = MagicMock()
|
||||
|
||||
endpoint = StreamEndpoint(
|
||||
endpoint_path="/api/stream",
|
||||
auth=mock_auth,
|
||||
dispatcher=mock_dispatcher,
|
||||
method="INVALID"
|
||||
)
|
||||
|
||||
with pytest.raises(RuntimeError, match="Bad method"):
|
||||
endpoint.add_routes(mock_app)
|
||||
53
tests/unit/test_gateway/test_endpoint_variable.py
Normal file
53
tests/unit/test_gateway/test_endpoint_variable.py
Normal file
|
|
@ -0,0 +1,53 @@
|
|||
"""
|
||||
Tests for Gateway Variable Endpoint
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from trustgraph.gateway.endpoint.variable_endpoint import VariableEndpoint
|
||||
|
||||
|
||||
class TestVariableEndpoint:
|
||||
"""Test cases for VariableEndpoint class"""
|
||||
|
||||
def test_variable_endpoint_initialization(self):
|
||||
"""Test VariableEndpoint initialization"""
|
||||
mock_auth = MagicMock()
|
||||
mock_dispatcher = MagicMock()
|
||||
|
||||
endpoint = VariableEndpoint(
|
||||
endpoint_path="/api/variable",
|
||||
auth=mock_auth,
|
||||
dispatcher=mock_dispatcher
|
||||
)
|
||||
|
||||
assert endpoint.path == "/api/variable"
|
||||
assert endpoint.auth == mock_auth
|
||||
assert endpoint.dispatcher == mock_dispatcher
|
||||
assert endpoint.operation == "service"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_variable_endpoint_start_method(self):
|
||||
"""Test VariableEndpoint start method (should be no-op)"""
|
||||
mock_auth = MagicMock()
|
||||
mock_dispatcher = MagicMock()
|
||||
|
||||
endpoint = VariableEndpoint("/api/var", mock_auth, mock_dispatcher)
|
||||
|
||||
# start() should complete without error
|
||||
await endpoint.start()
|
||||
|
||||
def test_add_routes_registers_post_handler(self):
|
||||
"""Test add_routes method registers POST route"""
|
||||
mock_auth = MagicMock()
|
||||
mock_dispatcher = MagicMock()
|
||||
mock_app = MagicMock()
|
||||
|
||||
endpoint = VariableEndpoint("/api/variable", mock_auth, mock_dispatcher)
|
||||
endpoint.add_routes(mock_app)
|
||||
|
||||
# Verify add_routes was called with POST route
|
||||
mock_app.add_routes.assert_called_once()
|
||||
call_args = mock_app.add_routes.call_args[0][0]
|
||||
assert len(call_args) == 1 # One route added
|
||||
90
tests/unit/test_gateway/test_running.py
Normal file
90
tests/unit/test_gateway/test_running.py
Normal file
|
|
@ -0,0 +1,90 @@
|
|||
"""
|
||||
Tests for Gateway Running utility class
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from trustgraph.gateway.running import Running
|
||||
|
||||
|
||||
class TestRunning:
|
||||
"""Test cases for Running class"""
|
||||
|
||||
def test_running_initialization(self):
|
||||
"""Test Running class initialization"""
|
||||
running = Running()
|
||||
|
||||
# Should start with running = True
|
||||
assert running.running is True
|
||||
|
||||
def test_running_get_method(self):
|
||||
"""Test Running.get() method returns current state"""
|
||||
running = Running()
|
||||
|
||||
# Should return True initially
|
||||
assert running.get() is True
|
||||
|
||||
# Should return False after stopping
|
||||
running.stop()
|
||||
assert running.get() is False
|
||||
|
||||
def test_running_stop_method(self):
|
||||
"""Test Running.stop() method sets running to False"""
|
||||
running = Running()
|
||||
|
||||
# Initially should be True
|
||||
assert running.running is True
|
||||
|
||||
# After calling stop(), should be False
|
||||
running.stop()
|
||||
assert running.running is False
|
||||
|
||||
def test_running_stop_is_idempotent(self):
|
||||
"""Test that calling stop() multiple times is safe"""
|
||||
running = Running()
|
||||
|
||||
# Stop multiple times
|
||||
running.stop()
|
||||
assert running.running is False
|
||||
|
||||
running.stop()
|
||||
assert running.running is False
|
||||
|
||||
# get() should still return False
|
||||
assert running.get() is False
|
||||
|
||||
def test_running_state_transitions(self):
|
||||
"""Test the complete state transition from running to stopped"""
|
||||
running = Running()
|
||||
|
||||
# Initial state: running
|
||||
assert running.get() is True
|
||||
assert running.running is True
|
||||
|
||||
# Transition to stopped
|
||||
running.stop()
|
||||
assert running.get() is False
|
||||
assert running.running is False
|
||||
|
||||
def test_running_multiple_instances_independent(self):
|
||||
"""Test that multiple Running instances are independent"""
|
||||
running1 = Running()
|
||||
running2 = Running()
|
||||
|
||||
# Both should start as running
|
||||
assert running1.get() is True
|
||||
assert running2.get() is True
|
||||
|
||||
# Stop only one
|
||||
running1.stop()
|
||||
|
||||
# States should be independent
|
||||
assert running1.get() is False
|
||||
assert running2.get() is True
|
||||
|
||||
# Stop the other
|
||||
running2.stop()
|
||||
|
||||
# Both should now be stopped
|
||||
assert running1.get() is False
|
||||
assert running2.get() is False
|
||||
360
tests/unit/test_gateway/test_service.py
Normal file
360
tests/unit/test_gateway/test_service.py
Normal file
|
|
@ -0,0 +1,360 @@
|
|||
"""
|
||||
Tests for Gateway Service API
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import asyncio
|
||||
from unittest.mock import Mock, patch, MagicMock, AsyncMock
|
||||
from aiohttp import web
|
||||
import pulsar
|
||||
|
||||
from trustgraph.gateway.service import Api, run, default_pulsar_host, default_prometheus_url, default_timeout, default_port, default_api_token
|
||||
|
||||
# Tests for Gateway Service API
|
||||
|
||||
|
||||
class TestApi:
|
||||
"""Test cases for Api class"""
|
||||
|
||||
|
||||
def test_api_initialization_with_defaults(self):
|
||||
"""Test Api initialization with default values"""
|
||||
with patch('pulsar.Client') as mock_client:
|
||||
mock_client.return_value = Mock()
|
||||
|
||||
api = Api()
|
||||
|
||||
assert api.port == default_port
|
||||
assert api.timeout == default_timeout
|
||||
assert api.pulsar_host == default_pulsar_host
|
||||
assert api.pulsar_api_key is None
|
||||
assert api.prometheus_url == default_prometheus_url + "/"
|
||||
assert api.auth.allow_all is True
|
||||
|
||||
# Verify Pulsar client was created without API key
|
||||
mock_client.assert_called_once_with(
|
||||
default_pulsar_host,
|
||||
listener_name=None
|
||||
)
|
||||
|
||||
def test_api_initialization_with_custom_config(self):
|
||||
"""Test Api initialization with custom configuration"""
|
||||
config = {
|
||||
"port": 9000,
|
||||
"timeout": 300,
|
||||
"pulsar_host": "pulsar://custom-host:6650",
|
||||
"pulsar_api_key": "test-api-key",
|
||||
"pulsar_listener": "custom-listener",
|
||||
"prometheus_url": "http://custom-prometheus:9090",
|
||||
"api_token": "secret-token"
|
||||
}
|
||||
|
||||
with patch('pulsar.Client') as mock_client, \
|
||||
patch('pulsar.AuthenticationToken') as mock_auth:
|
||||
mock_client.return_value = Mock()
|
||||
mock_auth.return_value = Mock()
|
||||
|
||||
api = Api(**config)
|
||||
|
||||
assert api.port == 9000
|
||||
assert api.timeout == 300
|
||||
assert api.pulsar_host == "pulsar://custom-host:6650"
|
||||
assert api.pulsar_api_key == "test-api-key"
|
||||
assert api.prometheus_url == "http://custom-prometheus:9090/"
|
||||
assert api.auth.token == "secret-token"
|
||||
assert api.auth.allow_all is False
|
||||
|
||||
# Verify Pulsar client was created with API key
|
||||
mock_auth.assert_called_once_with("test-api-key")
|
||||
mock_client.assert_called_once_with(
|
||||
"pulsar://custom-host:6650",
|
||||
listener_name="custom-listener",
|
||||
authentication=mock_auth.return_value
|
||||
)
|
||||
|
||||
def test_api_initialization_with_pulsar_api_key(self):
|
||||
"""Test Api initialization with Pulsar API key authentication"""
|
||||
with patch('pulsar.Client') as mock_client, \
|
||||
patch('pulsar.AuthenticationToken') as mock_auth:
|
||||
mock_client.return_value = Mock()
|
||||
mock_auth.return_value = Mock()
|
||||
|
||||
api = Api(pulsar_api_key="test-key")
|
||||
|
||||
mock_auth.assert_called_once_with("test-key")
|
||||
mock_client.assert_called_once_with(
|
||||
default_pulsar_host,
|
||||
listener_name=None,
|
||||
authentication=mock_auth.return_value
|
||||
)
|
||||
|
||||
def test_api_initialization_prometheus_url_normalization(self):
|
||||
"""Test that prometheus_url gets normalized with trailing slash"""
|
||||
with patch('pulsar.Client') as mock_client:
|
||||
mock_client.return_value = Mock()
|
||||
|
||||
# Test URL without trailing slash
|
||||
api = Api(prometheus_url="http://prometheus:9090")
|
||||
assert api.prometheus_url == "http://prometheus:9090/"
|
||||
|
||||
# Test URL with trailing slash
|
||||
api = Api(prometheus_url="http://prometheus:9090/")
|
||||
assert api.prometheus_url == "http://prometheus:9090/"
|
||||
|
||||
def test_api_initialization_empty_api_token_means_no_auth(self):
|
||||
"""Test that empty API token results in allow_all authentication"""
|
||||
with patch('pulsar.Client') as mock_client:
|
||||
mock_client.return_value = Mock()
|
||||
|
||||
api = Api(api_token="")
|
||||
assert api.auth.allow_all is True
|
||||
|
||||
def test_api_initialization_none_api_token_means_no_auth(self):
|
||||
"""Test that None API token results in allow_all authentication"""
|
||||
with patch('pulsar.Client') as mock_client:
|
||||
mock_client.return_value = Mock()
|
||||
|
||||
api = Api(api_token=None)
|
||||
assert api.auth.allow_all is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_app_factory_creates_application(self):
|
||||
"""Test that app_factory creates aiohttp application"""
|
||||
with patch('pulsar.Client') as mock_client:
|
||||
mock_client.return_value = Mock()
|
||||
|
||||
api = Api()
|
||||
|
||||
# Mock the dependencies
|
||||
api.config_receiver = Mock()
|
||||
api.config_receiver.start = AsyncMock()
|
||||
api.endpoint_manager = Mock()
|
||||
api.endpoint_manager.add_routes = Mock()
|
||||
api.endpoint_manager.start = AsyncMock()
|
||||
|
||||
app = await api.app_factory()
|
||||
|
||||
assert isinstance(app, web.Application)
|
||||
assert app._client_max_size == 256 * 1024 * 1024
|
||||
|
||||
# Verify that config receiver was started
|
||||
api.config_receiver.start.assert_called_once()
|
||||
|
||||
# Verify that endpoint manager was configured
|
||||
api.endpoint_manager.add_routes.assert_called_once_with(app)
|
||||
api.endpoint_manager.start.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_app_factory_with_custom_endpoints(self):
|
||||
"""Test app_factory with custom endpoints"""
|
||||
with patch('pulsar.Client') as mock_client:
|
||||
mock_client.return_value = Mock()
|
||||
|
||||
api = Api()
|
||||
|
||||
# Mock custom endpoints
|
||||
mock_endpoint1 = Mock()
|
||||
mock_endpoint1.add_routes = Mock()
|
||||
mock_endpoint1.start = AsyncMock()
|
||||
|
||||
mock_endpoint2 = Mock()
|
||||
mock_endpoint2.add_routes = Mock()
|
||||
mock_endpoint2.start = AsyncMock()
|
||||
|
||||
api.endpoints = [mock_endpoint1, mock_endpoint2]
|
||||
|
||||
# Mock the dependencies
|
||||
api.config_receiver = Mock()
|
||||
api.config_receiver.start = AsyncMock()
|
||||
api.endpoint_manager = Mock()
|
||||
api.endpoint_manager.add_routes = Mock()
|
||||
api.endpoint_manager.start = AsyncMock()
|
||||
|
||||
app = await api.app_factory()
|
||||
|
||||
# Verify custom endpoints were configured
|
||||
mock_endpoint1.add_routes.assert_called_once_with(app)
|
||||
mock_endpoint1.start.assert_called_once()
|
||||
mock_endpoint2.add_routes.assert_called_once_with(app)
|
||||
mock_endpoint2.start.assert_called_once()
|
||||
|
||||
def test_run_method_calls_web_run_app(self):
|
||||
"""Test that run method calls web.run_app"""
|
||||
with patch('pulsar.Client') as mock_client, \
|
||||
patch('aiohttp.web.run_app') as mock_run_app:
|
||||
mock_client.return_value = Mock()
|
||||
|
||||
api = Api(port=8080)
|
||||
api.run()
|
||||
|
||||
# Verify run_app was called once with the correct port
|
||||
mock_run_app.assert_called_once()
|
||||
args, kwargs = mock_run_app.call_args
|
||||
assert len(args) == 1 # Should have one positional arg (the coroutine)
|
||||
assert kwargs == {'port': 8080} # Should have port keyword arg
|
||||
|
||||
def test_api_components_initialization(self):
|
||||
"""Test that all API components are properly initialized"""
|
||||
with patch('pulsar.Client') as mock_client:
|
||||
mock_client.return_value = Mock()
|
||||
|
||||
api = Api()
|
||||
|
||||
# Verify all components are initialized
|
||||
assert api.config_receiver is not None
|
||||
assert api.dispatcher_manager is not None
|
||||
assert api.endpoint_manager is not None
|
||||
assert api.endpoints == []
|
||||
|
||||
# Verify component relationships
|
||||
assert api.dispatcher_manager.pulsar_client == api.pulsar_client
|
||||
assert api.dispatcher_manager.config_receiver == api.config_receiver
|
||||
assert api.endpoint_manager.dispatcher_manager == api.dispatcher_manager
|
||||
# EndpointManager doesn't store auth directly, it passes it to individual endpoints
|
||||
|
||||
|
||||
class TestRunFunction:
|
||||
"""Test cases for the run() function"""
|
||||
|
||||
def test_run_function_with_metrics_enabled(self):
|
||||
"""Test run function with metrics enabled"""
|
||||
import warnings
|
||||
# Suppress the specific async warning with a broader pattern
|
||||
warnings.filterwarnings("ignore", message=".*Api.app_factory.*was never awaited", category=RuntimeWarning)
|
||||
|
||||
with patch('argparse.ArgumentParser.parse_args') as mock_parse_args, \
|
||||
patch('trustgraph.gateway.service.start_http_server') as mock_start_http_server:
|
||||
|
||||
# Mock command line arguments
|
||||
mock_args = Mock()
|
||||
mock_args.metrics = True
|
||||
mock_args.metrics_port = 8000
|
||||
mock_parse_args.return_value = mock_args
|
||||
|
||||
# Create a simple mock instance without any async methods
|
||||
mock_api_instance = Mock()
|
||||
mock_api_instance.run = Mock()
|
||||
|
||||
# Create a mock Api class without importing the real one
|
||||
mock_api = Mock(return_value=mock_api_instance)
|
||||
|
||||
# Patch using context manager to avoid importing the real Api class
|
||||
with patch('trustgraph.gateway.service.Api', mock_api):
|
||||
# Mock vars() to return a dict
|
||||
with patch('builtins.vars') as mock_vars:
|
||||
mock_vars.return_value = {
|
||||
'metrics': True,
|
||||
'metrics_port': 8000,
|
||||
'pulsar_host': default_pulsar_host,
|
||||
'timeout': default_timeout
|
||||
}
|
||||
|
||||
run()
|
||||
|
||||
# Verify metrics server was started
|
||||
mock_start_http_server.assert_called_once_with(8000)
|
||||
|
||||
# Verify Api was created and run was called
|
||||
mock_api.assert_called_once()
|
||||
mock_api_instance.run.assert_called_once()
|
||||
|
||||
@patch('trustgraph.gateway.service.start_http_server')
|
||||
@patch('argparse.ArgumentParser.parse_args')
|
||||
def test_run_function_with_metrics_disabled(self, mock_parse_args, mock_start_http_server):
|
||||
"""Test run function with metrics disabled"""
|
||||
# Mock command line arguments
|
||||
mock_args = Mock()
|
||||
mock_args.metrics = False
|
||||
mock_parse_args.return_value = mock_args
|
||||
|
||||
# Create a simple mock instance without any async methods
|
||||
mock_api_instance = Mock()
|
||||
mock_api_instance.run = Mock()
|
||||
|
||||
# Patch the Api class inside the test without using decorators
|
||||
with patch('trustgraph.gateway.service.Api') as mock_api:
|
||||
mock_api.return_value = mock_api_instance
|
||||
|
||||
# Mock vars() to return a dict
|
||||
with patch('builtins.vars') as mock_vars:
|
||||
mock_vars.return_value = {
|
||||
'metrics': False,
|
||||
'metrics_port': 8000,
|
||||
'pulsar_host': default_pulsar_host,
|
||||
'timeout': default_timeout
|
||||
}
|
||||
|
||||
run()
|
||||
|
||||
# Verify metrics server was NOT started
|
||||
mock_start_http_server.assert_not_called()
|
||||
|
||||
# Verify Api was created and run was called
|
||||
mock_api.assert_called_once()
|
||||
mock_api_instance.run.assert_called_once()
|
||||
|
||||
@patch('argparse.ArgumentParser.parse_args')
|
||||
def test_run_function_argument_parsing(self, mock_parse_args):
|
||||
"""Test that run function properly parses command line arguments"""
|
||||
# Mock command line arguments
|
||||
mock_args = Mock()
|
||||
mock_args.metrics = False
|
||||
mock_parse_args.return_value = mock_args
|
||||
|
||||
# Create a simple mock instance without any async methods
|
||||
mock_api_instance = Mock()
|
||||
mock_api_instance.run = Mock()
|
||||
|
||||
# Mock vars() to return a dict with all expected arguments
|
||||
expected_args = {
|
||||
'pulsar_host': 'pulsar://test:6650',
|
||||
'pulsar_api_key': 'test-key',
|
||||
'pulsar_listener': 'test-listener',
|
||||
'prometheus_url': 'http://test-prometheus:9090',
|
||||
'port': 9000,
|
||||
'timeout': 300,
|
||||
'api_token': 'secret',
|
||||
'log_level': 'INFO',
|
||||
'metrics': False,
|
||||
'metrics_port': 8001
|
||||
}
|
||||
|
||||
# Patch the Api class inside the test without using decorators
|
||||
with patch('trustgraph.gateway.service.Api') as mock_api:
|
||||
mock_api.return_value = mock_api_instance
|
||||
|
||||
with patch('builtins.vars') as mock_vars:
|
||||
mock_vars.return_value = expected_args
|
||||
|
||||
run()
|
||||
|
||||
# Verify Api was created with the parsed arguments
|
||||
mock_api.assert_called_once_with(**expected_args)
|
||||
mock_api_instance.run.assert_called_once()
|
||||
|
||||
def test_run_function_creates_argument_parser(self):
|
||||
"""Test that run function creates argument parser with correct arguments"""
|
||||
with patch('argparse.ArgumentParser') as mock_parser_class:
|
||||
mock_parser = Mock()
|
||||
mock_parser_class.return_value = mock_parser
|
||||
mock_parser.parse_args.return_value = Mock(metrics=False)
|
||||
|
||||
with patch('trustgraph.gateway.service.Api') as mock_api, \
|
||||
patch('builtins.vars') as mock_vars:
|
||||
mock_vars.return_value = {'metrics': False}
|
||||
mock_api.return_value = Mock()
|
||||
|
||||
run()
|
||||
|
||||
# Verify ArgumentParser was created
|
||||
mock_parser_class.assert_called_once()
|
||||
|
||||
# Verify add_argument was called for each expected argument
|
||||
expected_arguments = [
|
||||
'pulsar-host', 'pulsar-api-key', 'pulsar-listener',
|
||||
'prometheus-url', 'port', 'timeout', 'api-token',
|
||||
'log-level', 'metrics', 'metrics-port'
|
||||
]
|
||||
|
||||
# Check that add_argument was called multiple times (once for each arg)
|
||||
assert mock_parser.add_argument.call_count >= len(expected_arguments)
|
||||
148
tests/unit/test_query/conftest.py
Normal file
148
tests/unit/test_query/conftest.py
Normal file
|
|
@ -0,0 +1,148 @@
|
|||
"""
|
||||
Shared fixtures for query tests
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def base_query_config():
|
||||
"""Base configuration for query processors"""
|
||||
return {
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-query-processor'
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def qdrant_query_config(base_query_config):
|
||||
"""Configuration for Qdrant query processors"""
|
||||
return base_query_config | {
|
||||
'store_uri': 'http://localhost:6333',
|
||||
'api_key': 'test-api-key'
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_qdrant_client():
|
||||
"""Mock Qdrant client"""
|
||||
mock_client = MagicMock()
|
||||
mock_client.query_points.return_value = []
|
||||
return mock_client
|
||||
|
||||
|
||||
# Graph embeddings query fixtures
|
||||
@pytest.fixture
|
||||
def mock_graph_embeddings_request():
|
||||
"""Mock graph embeddings request message"""
|
||||
mock_message = MagicMock()
|
||||
mock_message.vectors = [[0.1, 0.2, 0.3]]
|
||||
mock_message.limit = 5
|
||||
mock_message.user = 'test_user'
|
||||
mock_message.collection = 'test_collection'
|
||||
return mock_message
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_graph_embeddings_multiple_vectors():
|
||||
"""Mock graph embeddings request with multiple vectors"""
|
||||
mock_message = MagicMock()
|
||||
mock_message.vectors = [[0.1, 0.2], [0.3, 0.4]]
|
||||
mock_message.limit = 3
|
||||
mock_message.user = 'multi_user'
|
||||
mock_message.collection = 'multi_collection'
|
||||
return mock_message
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_graph_embeddings_query_response():
|
||||
"""Mock graph embeddings query response from Qdrant"""
|
||||
mock_point1 = MagicMock()
|
||||
mock_point1.payload = {'entity': 'entity1'}
|
||||
mock_point2 = MagicMock()
|
||||
mock_point2.payload = {'entity': 'entity2'}
|
||||
return [mock_point1, mock_point2]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_graph_embeddings_uri_response():
|
||||
"""Mock graph embeddings query response with URIs"""
|
||||
mock_point1 = MagicMock()
|
||||
mock_point1.payload = {'entity': 'http://example.com/entity1'}
|
||||
mock_point2 = MagicMock()
|
||||
mock_point2.payload = {'entity': 'https://secure.example.com/entity2'}
|
||||
mock_point3 = MagicMock()
|
||||
mock_point3.payload = {'entity': 'regular entity'}
|
||||
return [mock_point1, mock_point2, mock_point3]
|
||||
|
||||
|
||||
# Document embeddings query fixtures
|
||||
@pytest.fixture
|
||||
def mock_document_embeddings_request():
|
||||
"""Mock document embeddings request message"""
|
||||
mock_message = MagicMock()
|
||||
mock_message.vectors = [[0.1, 0.2, 0.3]]
|
||||
mock_message.limit = 5
|
||||
mock_message.user = 'test_user'
|
||||
mock_message.collection = 'test_collection'
|
||||
return mock_message
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_document_embeddings_multiple_vectors():
|
||||
"""Mock document embeddings request with multiple vectors"""
|
||||
mock_message = MagicMock()
|
||||
mock_message.vectors = [[0.1, 0.2], [0.3, 0.4]]
|
||||
mock_message.limit = 3
|
||||
mock_message.user = 'multi_user'
|
||||
mock_message.collection = 'multi_collection'
|
||||
return mock_message
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_document_embeddings_query_response():
|
||||
"""Mock document embeddings query response from Qdrant"""
|
||||
mock_point1 = MagicMock()
|
||||
mock_point1.payload = {'doc': 'first document chunk'}
|
||||
mock_point2 = MagicMock()
|
||||
mock_point2.payload = {'doc': 'second document chunk'}
|
||||
return [mock_point1, mock_point2]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_document_embeddings_utf8_response():
|
||||
"""Mock document embeddings query response with UTF-8 content"""
|
||||
mock_point1 = MagicMock()
|
||||
mock_point1.payload = {'doc': 'Document with UTF-8: café, naïve, résumé'}
|
||||
mock_point2 = MagicMock()
|
||||
mock_point2.payload = {'doc': 'Chinese text: 你好世界'}
|
||||
return [mock_point1, mock_point2]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_empty_query_response():
|
||||
"""Mock empty query response"""
|
||||
return []
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_large_query_response():
|
||||
"""Mock large query response with many results"""
|
||||
mock_points = []
|
||||
for i in range(10):
|
||||
mock_point = MagicMock()
|
||||
mock_point.payload = {'doc': f'document chunk {i}'}
|
||||
mock_points.append(mock_point)
|
||||
return mock_points
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_mixed_dimension_vectors():
|
||||
"""Mock request with vectors of different dimensions"""
|
||||
mock_message = MagicMock()
|
||||
mock_message.vectors = [[0.1, 0.2], [0.3, 0.4, 0.5]] # 2D and 3D
|
||||
mock_message.limit = 5
|
||||
mock_message.user = 'dim_user'
|
||||
mock_message.collection = 'dim_collection'
|
||||
return mock_message
|
||||
542
tests/unit/test_query/test_doc_embeddings_qdrant_query.py
Normal file
542
tests/unit/test_query/test_doc_embeddings_qdrant_query.py
Normal file
|
|
@ -0,0 +1,542 @@
|
|||
"""
|
||||
Unit tests for trustgraph.query.doc_embeddings.qdrant.service
|
||||
Testing document embeddings query functionality
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from unittest import IsolatedAsyncioTestCase
|
||||
|
||||
# Import the service under test
|
||||
from trustgraph.query.doc_embeddings.qdrant.service import Processor
|
||||
|
||||
|
||||
class TestQdrantDocEmbeddingsQuery(IsolatedAsyncioTestCase):
|
||||
"""Test Qdrant document embeddings query functionality"""
|
||||
|
||||
@patch('trustgraph.query.doc_embeddings.qdrant.service.QdrantClient')
|
||||
@patch('trustgraph.base.DocumentEmbeddingsQueryService.__init__')
|
||||
async def test_processor_initialization_basic(self, mock_base_init, mock_qdrant_client):
|
||||
"""Test basic Qdrant processor initialization"""
|
||||
# Arrange
|
||||
mock_base_init.return_value = None
|
||||
mock_qdrant_instance = MagicMock()
|
||||
mock_qdrant_client.return_value = mock_qdrant_instance
|
||||
|
||||
config = {
|
||||
'store_uri': 'http://localhost:6333',
|
||||
'api_key': 'test-api-key',
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-doc-query-processor'
|
||||
}
|
||||
|
||||
# Act
|
||||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
# Verify base class initialization was called
|
||||
mock_base_init.assert_called_once()
|
||||
|
||||
# Verify QdrantClient was created with correct parameters
|
||||
mock_qdrant_client.assert_called_once_with(url='http://localhost:6333', api_key='test-api-key')
|
||||
|
||||
# Verify processor attributes
|
||||
assert hasattr(processor, 'qdrant')
|
||||
assert processor.qdrant == mock_qdrant_instance
|
||||
|
||||
@patch('trustgraph.query.doc_embeddings.qdrant.service.QdrantClient')
|
||||
@patch('trustgraph.base.DocumentEmbeddingsQueryService.__init__')
|
||||
async def test_processor_initialization_with_defaults(self, mock_base_init, mock_qdrant_client):
|
||||
"""Test processor initialization with default values"""
|
||||
# Arrange
|
||||
mock_base_init.return_value = None
|
||||
mock_qdrant_instance = MagicMock()
|
||||
mock_qdrant_client.return_value = mock_qdrant_instance
|
||||
|
||||
config = {
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-doc-query-processor'
|
||||
# No store_uri or api_key provided - should use defaults
|
||||
}
|
||||
|
||||
# Act
|
||||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
# Verify QdrantClient was created with default URI and None API key
|
||||
mock_qdrant_client.assert_called_once_with(url='http://localhost:6333', api_key=None)
|
||||
|
||||
@patch('trustgraph.query.doc_embeddings.qdrant.service.QdrantClient')
|
||||
@patch('trustgraph.base.DocumentEmbeddingsQueryService.__init__')
|
||||
async def test_query_document_embeddings_single_vector(self, mock_base_init, mock_qdrant_client):
|
||||
"""Test querying document embeddings with single vector"""
|
||||
# Arrange
|
||||
mock_base_init.return_value = None
|
||||
mock_qdrant_instance = MagicMock()
|
||||
mock_qdrant_client.return_value = mock_qdrant_instance
|
||||
|
||||
# Mock query response
|
||||
mock_point1 = MagicMock()
|
||||
mock_point1.payload = {'doc': 'first document chunk'}
|
||||
mock_point2 = MagicMock()
|
||||
mock_point2.payload = {'doc': 'second document chunk'}
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.points = [mock_point1, mock_point2]
|
||||
mock_qdrant_instance.query_points.return_value = mock_response
|
||||
|
||||
config = {
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Create mock message
|
||||
mock_message = MagicMock()
|
||||
mock_message.vectors = [[0.1, 0.2, 0.3]]
|
||||
mock_message.limit = 5
|
||||
mock_message.user = 'test_user'
|
||||
mock_message.collection = 'test_collection'
|
||||
|
||||
# Act
|
||||
result = await processor.query_document_embeddings(mock_message)
|
||||
|
||||
# Assert
|
||||
# Verify query was called with correct parameters
|
||||
expected_collection = 'd_test_user_test_collection_3'
|
||||
mock_qdrant_instance.query_points.assert_called_once_with(
|
||||
collection_name=expected_collection,
|
||||
query=[0.1, 0.2, 0.3],
|
||||
limit=5, # Direct limit, no multiplication
|
||||
with_payload=True
|
||||
)
|
||||
|
||||
# Verify result contains expected documents
|
||||
assert len(result) == 2
|
||||
# Results should be strings (document chunks)
|
||||
assert isinstance(result[0], str)
|
||||
assert isinstance(result[1], str)
|
||||
# Verify content
|
||||
assert result[0] == 'first document chunk'
|
||||
assert result[1] == 'second document chunk'
|
||||
|
||||
@patch('trustgraph.query.doc_embeddings.qdrant.service.QdrantClient')
|
||||
@patch('trustgraph.base.DocumentEmbeddingsQueryService.__init__')
|
||||
async def test_query_document_embeddings_multiple_vectors(self, mock_base_init, mock_qdrant_client):
|
||||
"""Test querying document embeddings with multiple vectors"""
|
||||
# Arrange
|
||||
mock_base_init.return_value = None
|
||||
mock_qdrant_instance = MagicMock()
|
||||
mock_qdrant_client.return_value = mock_qdrant_instance
|
||||
|
||||
# Mock query responses for different vectors
|
||||
mock_point1 = MagicMock()
|
||||
mock_point1.payload = {'doc': 'document from vector 1'}
|
||||
mock_point2 = MagicMock()
|
||||
mock_point2.payload = {'doc': 'document from vector 2'}
|
||||
mock_point3 = MagicMock()
|
||||
mock_point3.payload = {'doc': 'another document from vector 2'}
|
||||
|
||||
mock_response1 = MagicMock()
|
||||
mock_response1.points = [mock_point1]
|
||||
mock_response2 = MagicMock()
|
||||
mock_response2.points = [mock_point2, mock_point3]
|
||||
mock_qdrant_instance.query_points.side_effect = [mock_response1, mock_response2]
|
||||
|
||||
config = {
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Create mock message with multiple vectors
|
||||
mock_message = MagicMock()
|
||||
mock_message.vectors = [[0.1, 0.2], [0.3, 0.4]]
|
||||
mock_message.limit = 3
|
||||
mock_message.user = 'multi_user'
|
||||
mock_message.collection = 'multi_collection'
|
||||
|
||||
# Act
|
||||
result = await processor.query_document_embeddings(mock_message)
|
||||
|
||||
# Assert
|
||||
# Verify query was called twice
|
||||
assert mock_qdrant_instance.query_points.call_count == 2
|
||||
|
||||
# Verify both collections were queried
|
||||
expected_collection = 'd_multi_user_multi_collection_2'
|
||||
calls = mock_qdrant_instance.query_points.call_args_list
|
||||
assert calls[0][1]['collection_name'] == expected_collection
|
||||
assert calls[1][1]['collection_name'] == expected_collection
|
||||
assert calls[0][1]['query'] == [0.1, 0.2]
|
||||
assert calls[1][1]['query'] == [0.3, 0.4]
|
||||
|
||||
# Verify results from both vectors are combined
|
||||
assert len(result) == 3
|
||||
assert 'document from vector 1' in result
|
||||
assert 'document from vector 2' in result
|
||||
assert 'another document from vector 2' in result
|
||||
|
||||
@patch('trustgraph.query.doc_embeddings.qdrant.service.QdrantClient')
|
||||
@patch('trustgraph.base.DocumentEmbeddingsQueryService.__init__')
|
||||
async def test_query_document_embeddings_with_limit(self, mock_base_init, mock_qdrant_client):
|
||||
"""Test querying document embeddings respects limit parameter"""
|
||||
# Arrange
|
||||
mock_base_init.return_value = None
|
||||
mock_qdrant_instance = MagicMock()
|
||||
mock_qdrant_client.return_value = mock_qdrant_instance
|
||||
|
||||
# Mock query response with many results
|
||||
mock_points = []
|
||||
for i in range(10):
|
||||
mock_point = MagicMock()
|
||||
mock_point.payload = {'doc': f'document chunk {i}'}
|
||||
mock_points.append(mock_point)
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.points = mock_points
|
||||
mock_qdrant_instance.query_points.return_value = mock_response
|
||||
|
||||
config = {
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Create mock message with limit
|
||||
mock_message = MagicMock()
|
||||
mock_message.vectors = [[0.1, 0.2, 0.3]]
|
||||
mock_message.limit = 3 # Should only return 3 results
|
||||
mock_message.user = 'limit_user'
|
||||
mock_message.collection = 'limit_collection'
|
||||
|
||||
# Act
|
||||
result = await processor.query_document_embeddings(mock_message)
|
||||
|
||||
# Assert
|
||||
# Verify query was called with exact limit (no multiplication)
|
||||
mock_qdrant_instance.query_points.assert_called_once()
|
||||
call_args = mock_qdrant_instance.query_points.call_args
|
||||
assert call_args[1]['limit'] == 3 # Direct limit
|
||||
|
||||
# Verify result contains all returned documents (limit applied by Qdrant)
|
||||
assert len(result) == 10 # All results returned by mock
|
||||
|
||||
@patch('trustgraph.query.doc_embeddings.qdrant.service.QdrantClient')
|
||||
@patch('trustgraph.base.DocumentEmbeddingsQueryService.__init__')
|
||||
async def test_query_document_embeddings_empty_results(self, mock_base_init, mock_qdrant_client):
|
||||
"""Test querying document embeddings with empty results"""
|
||||
# Arrange
|
||||
mock_base_init.return_value = None
|
||||
mock_qdrant_instance = MagicMock()
|
||||
mock_qdrant_client.return_value = mock_qdrant_instance
|
||||
|
||||
# Mock empty query response
|
||||
mock_response = MagicMock()
|
||||
mock_response.points = []
|
||||
mock_qdrant_instance.query_points.return_value = mock_response
|
||||
|
||||
config = {
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Create mock message
|
||||
mock_message = MagicMock()
|
||||
mock_message.vectors = [[0.1, 0.2]]
|
||||
mock_message.limit = 5
|
||||
mock_message.user = 'empty_user'
|
||||
mock_message.collection = 'empty_collection'
|
||||
|
||||
# Act
|
||||
result = await processor.query_document_embeddings(mock_message)
|
||||
|
||||
# Assert
|
||||
assert result == []
|
||||
|
||||
@patch('trustgraph.query.doc_embeddings.qdrant.service.QdrantClient')
|
||||
@patch('trustgraph.base.DocumentEmbeddingsQueryService.__init__')
|
||||
async def test_query_document_embeddings_different_dimensions(self, mock_base_init, mock_qdrant_client):
|
||||
"""Test querying document embeddings with different vector dimensions"""
|
||||
# Arrange
|
||||
mock_base_init.return_value = None
|
||||
mock_qdrant_instance = MagicMock()
|
||||
mock_qdrant_client.return_value = mock_qdrant_instance
|
||||
|
||||
# Mock query responses
|
||||
mock_point1 = MagicMock()
|
||||
mock_point1.payload = {'doc': 'document from 2D vector'}
|
||||
mock_point2 = MagicMock()
|
||||
mock_point2.payload = {'doc': 'document from 3D vector'}
|
||||
|
||||
mock_response1 = MagicMock()
|
||||
mock_response1.points = [mock_point1]
|
||||
mock_response2 = MagicMock()
|
||||
mock_response2.points = [mock_point2]
|
||||
mock_qdrant_instance.query_points.side_effect = [mock_response1, mock_response2]
|
||||
|
||||
config = {
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Create mock message with different dimension vectors
|
||||
mock_message = MagicMock()
|
||||
mock_message.vectors = [[0.1, 0.2], [0.3, 0.4, 0.5]] # 2D and 3D
|
||||
mock_message.limit = 5
|
||||
mock_message.user = 'dim_user'
|
||||
mock_message.collection = 'dim_collection'
|
||||
|
||||
# Act
|
||||
result = await processor.query_document_embeddings(mock_message)
|
||||
|
||||
# Assert
|
||||
# Verify query was called twice with different collections
|
||||
assert mock_qdrant_instance.query_points.call_count == 2
|
||||
calls = mock_qdrant_instance.query_points.call_args_list
|
||||
|
||||
# First call should use 2D collection
|
||||
assert calls[0][1]['collection_name'] == 'd_dim_user_dim_collection_2'
|
||||
assert calls[0][1]['query'] == [0.1, 0.2]
|
||||
|
||||
# Second call should use 3D collection
|
||||
assert calls[1][1]['collection_name'] == 'd_dim_user_dim_collection_3'
|
||||
assert calls[1][1]['query'] == [0.3, 0.4, 0.5]
|
||||
|
||||
# Verify results
|
||||
assert len(result) == 2
|
||||
assert 'document from 2D vector' in result
|
||||
assert 'document from 3D vector' in result
|
||||
|
||||
@patch('trustgraph.query.doc_embeddings.qdrant.service.QdrantClient')
|
||||
@patch('trustgraph.base.DocumentEmbeddingsQueryService.__init__')
|
||||
async def test_query_document_embeddings_utf8_encoding(self, mock_base_init, mock_qdrant_client):
|
||||
"""Test querying document embeddings with UTF-8 content"""
|
||||
# Arrange
|
||||
mock_base_init.return_value = None
|
||||
mock_qdrant_instance = MagicMock()
|
||||
mock_qdrant_client.return_value = mock_qdrant_instance
|
||||
|
||||
# Mock query response with UTF-8 content
|
||||
mock_point1 = MagicMock()
|
||||
mock_point1.payload = {'doc': 'Document with UTF-8: café, naïve, résumé'}
|
||||
mock_point2 = MagicMock()
|
||||
mock_point2.payload = {'doc': 'Chinese text: 你好世界'}
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.points = [mock_point1, mock_point2]
|
||||
mock_qdrant_instance.query_points.return_value = mock_response
|
||||
|
||||
config = {
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Create mock message
|
||||
mock_message = MagicMock()
|
||||
mock_message.vectors = [[0.1, 0.2]]
|
||||
mock_message.limit = 5
|
||||
mock_message.user = 'utf8_user'
|
||||
mock_message.collection = 'utf8_collection'
|
||||
|
||||
# Act
|
||||
result = await processor.query_document_embeddings(mock_message)
|
||||
|
||||
# Assert
|
||||
assert len(result) == 2
|
||||
|
||||
# Verify UTF-8 content works correctly
|
||||
assert 'Document with UTF-8: café, naïve, résumé' in result
|
||||
assert 'Chinese text: 你好世界' in result
|
||||
|
||||
@patch('trustgraph.query.doc_embeddings.qdrant.service.QdrantClient')
|
||||
@patch('trustgraph.base.DocumentEmbeddingsQueryService.__init__')
|
||||
async def test_query_document_embeddings_qdrant_error(self, mock_base_init, mock_qdrant_client):
|
||||
"""Test querying document embeddings handles Qdrant errors"""
|
||||
# Arrange
|
||||
mock_base_init.return_value = None
|
||||
mock_qdrant_instance = MagicMock()
|
||||
mock_qdrant_client.return_value = mock_qdrant_instance
|
||||
|
||||
# Mock Qdrant error
|
||||
mock_qdrant_instance.query_points.side_effect = Exception("Qdrant connection failed")
|
||||
|
||||
config = {
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Create mock message
|
||||
mock_message = MagicMock()
|
||||
mock_message.vectors = [[0.1, 0.2]]
|
||||
mock_message.limit = 5
|
||||
mock_message.user = 'error_user'
|
||||
mock_message.collection = 'error_collection'
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(Exception, match="Qdrant connection failed"):
|
||||
await processor.query_document_embeddings(mock_message)
|
||||
|
||||
@patch('trustgraph.query.doc_embeddings.qdrant.service.QdrantClient')
|
||||
@patch('trustgraph.base.DocumentEmbeddingsQueryService.__init__')
|
||||
async def test_query_document_embeddings_zero_limit(self, mock_base_init, mock_qdrant_client):
|
||||
"""Test querying document embeddings with zero limit"""
|
||||
# Arrange
|
||||
mock_base_init.return_value = None
|
||||
mock_qdrant_instance = MagicMock()
|
||||
mock_qdrant_client.return_value = mock_qdrant_instance
|
||||
|
||||
# Mock query response
|
||||
mock_point = MagicMock()
|
||||
mock_point.payload = {'doc': 'document chunk'}
|
||||
mock_response = MagicMock()
|
||||
mock_response.points = [mock_point]
|
||||
mock_qdrant_instance.query_points.return_value = mock_response
|
||||
|
||||
config = {
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Create mock message with zero limit
|
||||
mock_message = MagicMock()
|
||||
mock_message.vectors = [[0.1, 0.2]]
|
||||
mock_message.limit = 0
|
||||
mock_message.user = 'zero_user'
|
||||
mock_message.collection = 'zero_collection'
|
||||
|
||||
# Act
|
||||
result = await processor.query_document_embeddings(mock_message)
|
||||
|
||||
# Assert
|
||||
# Should still query (with limit 0)
|
||||
mock_qdrant_instance.query_points.assert_called_once()
|
||||
call_args = mock_qdrant_instance.query_points.call_args
|
||||
assert call_args[1]['limit'] == 0
|
||||
|
||||
# Result should contain all returned documents
|
||||
assert len(result) == 1
|
||||
assert result[0] == 'document chunk'
|
||||
|
||||
@patch('trustgraph.query.doc_embeddings.qdrant.service.QdrantClient')
|
||||
@patch('trustgraph.base.DocumentEmbeddingsQueryService.__init__')
|
||||
async def test_query_document_embeddings_large_limit(self, mock_base_init, mock_qdrant_client):
|
||||
"""Test querying document embeddings with large limit"""
|
||||
# Arrange
|
||||
mock_base_init.return_value = None
|
||||
mock_qdrant_instance = MagicMock()
|
||||
mock_qdrant_client.return_value = mock_qdrant_instance
|
||||
|
||||
# Mock query response with fewer results than limit
|
||||
mock_point1 = MagicMock()
|
||||
mock_point1.payload = {'doc': 'document 1'}
|
||||
mock_point2 = MagicMock()
|
||||
mock_point2.payload = {'doc': 'document 2'}
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.points = [mock_point1, mock_point2]
|
||||
mock_qdrant_instance.query_points.return_value = mock_response
|
||||
|
||||
config = {
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Create mock message with large limit
|
||||
mock_message = MagicMock()
|
||||
mock_message.vectors = [[0.1, 0.2]]
|
||||
mock_message.limit = 1000 # Large limit
|
||||
mock_message.user = 'large_user'
|
||||
mock_message.collection = 'large_collection'
|
||||
|
||||
# Act
|
||||
result = await processor.query_document_embeddings(mock_message)
|
||||
|
||||
# Assert
|
||||
# Should query with full limit
|
||||
mock_qdrant_instance.query_points.assert_called_once()
|
||||
call_args = mock_qdrant_instance.query_points.call_args
|
||||
assert call_args[1]['limit'] == 1000
|
||||
|
||||
# Result should contain all available documents
|
||||
assert len(result) == 2
|
||||
assert 'document 1' in result
|
||||
assert 'document 2' in result
|
||||
|
||||
@patch('trustgraph.query.doc_embeddings.qdrant.service.QdrantClient')
|
||||
@patch('trustgraph.base.DocumentEmbeddingsQueryService.__init__')
|
||||
async def test_query_document_embeddings_missing_payload(self, mock_base_init, mock_qdrant_client):
|
||||
"""Test querying document embeddings with missing payload data"""
|
||||
# Arrange
|
||||
mock_base_init.return_value = None
|
||||
mock_qdrant_instance = MagicMock()
|
||||
mock_qdrant_client.return_value = mock_qdrant_instance
|
||||
|
||||
# Mock query response with missing 'doc' key
|
||||
mock_point1 = MagicMock()
|
||||
mock_point1.payload = {'doc': 'valid document'}
|
||||
mock_point2 = MagicMock()
|
||||
mock_point2.payload = {} # Missing 'doc' key
|
||||
mock_point3 = MagicMock()
|
||||
mock_point3.payload = {'other_key': 'invalid'} # Wrong key
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.points = [mock_point1, mock_point2, mock_point3]
|
||||
mock_qdrant_instance.query_points.return_value = mock_response
|
||||
|
||||
config = {
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Create mock message
|
||||
mock_message = MagicMock()
|
||||
mock_message.vectors = [[0.1, 0.2]]
|
||||
mock_message.limit = 5
|
||||
mock_message.user = 'payload_user'
|
||||
mock_message.collection = 'payload_collection'
|
||||
|
||||
# Act & Assert
|
||||
# This should raise a KeyError when trying to access payload['doc']
|
||||
with pytest.raises(KeyError):
|
||||
await processor.query_document_embeddings(mock_message)
|
||||
|
||||
@patch('trustgraph.query.doc_embeddings.qdrant.service.QdrantClient')
|
||||
@patch('trustgraph.base.DocumentEmbeddingsQueryService.__init__')
|
||||
async def test_add_args_calls_parent(self, mock_base_init, mock_qdrant_client):
|
||||
"""Test that add_args() calls parent add_args method"""
|
||||
# Arrange
|
||||
mock_base_init.return_value = None
|
||||
mock_qdrant_client.return_value = MagicMock()
|
||||
mock_parser = MagicMock()
|
||||
|
||||
# Act
|
||||
with patch('trustgraph.base.DocumentEmbeddingsQueryService.add_args') as mock_parent_add_args:
|
||||
Processor.add_args(mock_parser)
|
||||
|
||||
# Assert
|
||||
mock_parent_add_args.assert_called_once_with(mock_parser)
|
||||
|
||||
# Verify processor-specific arguments were added
|
||||
assert mock_parser.add_argument.call_count >= 2 # At least store-uri and api-key
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pytest.main([__file__])
|
||||
537
tests/unit/test_query/test_graph_embeddings_qdrant_query.py
Normal file
537
tests/unit/test_query/test_graph_embeddings_qdrant_query.py
Normal file
|
|
@ -0,0 +1,537 @@
|
|||
"""
|
||||
Unit tests for trustgraph.query.graph_embeddings.qdrant.service
|
||||
Testing graph embeddings query functionality
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from unittest import IsolatedAsyncioTestCase
|
||||
|
||||
# Import the service under test
|
||||
from trustgraph.query.graph_embeddings.qdrant.service import Processor
|
||||
|
||||
|
||||
class TestQdrantGraphEmbeddingsQuery(IsolatedAsyncioTestCase):
|
||||
"""Test Qdrant graph embeddings query functionality"""
|
||||
|
||||
@patch('trustgraph.query.graph_embeddings.qdrant.service.QdrantClient')
|
||||
@patch('trustgraph.base.GraphEmbeddingsQueryService.__init__')
|
||||
async def test_processor_initialization_basic(self, mock_base_init, mock_qdrant_client):
|
||||
"""Test basic Qdrant processor initialization"""
|
||||
# Arrange
|
||||
mock_base_init.return_value = None
|
||||
mock_qdrant_instance = MagicMock()
|
||||
mock_qdrant_client.return_value = mock_qdrant_instance
|
||||
|
||||
config = {
|
||||
'store_uri': 'http://localhost:6333',
|
||||
'api_key': 'test-api-key',
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-graph-query-processor'
|
||||
}
|
||||
|
||||
# Act
|
||||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
# Verify base class initialization was called
|
||||
mock_base_init.assert_called_once()
|
||||
|
||||
# Verify QdrantClient was created with correct parameters
|
||||
mock_qdrant_client.assert_called_once_with(url='http://localhost:6333', api_key='test-api-key')
|
||||
|
||||
# Verify processor attributes
|
||||
assert hasattr(processor, 'qdrant')
|
||||
assert processor.qdrant == mock_qdrant_instance
|
||||
|
||||
@patch('trustgraph.query.graph_embeddings.qdrant.service.QdrantClient')
|
||||
@patch('trustgraph.base.GraphEmbeddingsQueryService.__init__')
|
||||
async def test_processor_initialization_with_defaults(self, mock_base_init, mock_qdrant_client):
|
||||
"""Test processor initialization with default values"""
|
||||
# Arrange
|
||||
mock_base_init.return_value = None
|
||||
mock_qdrant_instance = MagicMock()
|
||||
mock_qdrant_client.return_value = mock_qdrant_instance
|
||||
|
||||
config = {
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-graph-query-processor'
|
||||
# No store_uri or api_key provided - should use defaults
|
||||
}
|
||||
|
||||
# Act
|
||||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
# Verify QdrantClient was created with default URI and None API key
|
||||
mock_qdrant_client.assert_called_once_with(url='http://localhost:6333', api_key=None)
|
||||
|
||||
@patch('trustgraph.query.graph_embeddings.qdrant.service.QdrantClient')
|
||||
@patch('trustgraph.base.GraphEmbeddingsQueryService.__init__')
|
||||
async def test_create_value_http_uri(self, mock_base_init, mock_qdrant_client):
|
||||
"""Test create_value with HTTP URI"""
|
||||
# Arrange
|
||||
mock_base_init.return_value = None
|
||||
mock_qdrant_client.return_value = MagicMock()
|
||||
|
||||
config = {
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act
|
||||
value = processor.create_value('http://example.com/entity')
|
||||
|
||||
# Assert
|
||||
assert hasattr(value, 'value')
|
||||
assert value.value == 'http://example.com/entity'
|
||||
assert hasattr(value, 'is_uri')
|
||||
assert value.is_uri == True
|
||||
|
||||
@patch('trustgraph.query.graph_embeddings.qdrant.service.QdrantClient')
|
||||
@patch('trustgraph.base.GraphEmbeddingsQueryService.__init__')
|
||||
async def test_create_value_https_uri(self, mock_base_init, mock_qdrant_client):
|
||||
"""Test create_value with HTTPS URI"""
|
||||
# Arrange
|
||||
mock_base_init.return_value = None
|
||||
mock_qdrant_client.return_value = MagicMock()
|
||||
|
||||
config = {
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act
|
||||
value = processor.create_value('https://secure.example.com/entity')
|
||||
|
||||
# Assert
|
||||
assert hasattr(value, 'value')
|
||||
assert value.value == 'https://secure.example.com/entity'
|
||||
assert hasattr(value, 'is_uri')
|
||||
assert value.is_uri == True
|
||||
|
||||
@patch('trustgraph.query.graph_embeddings.qdrant.service.QdrantClient')
|
||||
@patch('trustgraph.base.GraphEmbeddingsQueryService.__init__')
|
||||
async def test_create_value_regular_string(self, mock_base_init, mock_qdrant_client):
|
||||
"""Test create_value with regular string (non-URI)"""
|
||||
# Arrange
|
||||
mock_base_init.return_value = None
|
||||
mock_qdrant_client.return_value = MagicMock()
|
||||
|
||||
config = {
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act
|
||||
value = processor.create_value('regular entity name')
|
||||
|
||||
# Assert
|
||||
assert hasattr(value, 'value')
|
||||
assert value.value == 'regular entity name'
|
||||
assert hasattr(value, 'is_uri')
|
||||
assert value.is_uri == False
|
||||
|
||||
@patch('trustgraph.query.graph_embeddings.qdrant.service.QdrantClient')
|
||||
@patch('trustgraph.base.GraphEmbeddingsQueryService.__init__')
|
||||
async def test_query_graph_embeddings_single_vector(self, mock_base_init, mock_qdrant_client):
|
||||
"""Test querying graph embeddings with single vector"""
|
||||
# Arrange
|
||||
mock_base_init.return_value = None
|
||||
mock_qdrant_instance = MagicMock()
|
||||
mock_qdrant_client.return_value = mock_qdrant_instance
|
||||
|
||||
# Mock query response
|
||||
mock_point1 = MagicMock()
|
||||
mock_point1.payload = {'entity': 'entity1'}
|
||||
mock_point2 = MagicMock()
|
||||
mock_point2.payload = {'entity': 'entity2'}
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.points = [mock_point1, mock_point2]
|
||||
mock_qdrant_instance.query_points.return_value = mock_response
|
||||
|
||||
config = {
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Create mock message
|
||||
mock_message = MagicMock()
|
||||
mock_message.vectors = [[0.1, 0.2, 0.3]]
|
||||
mock_message.limit = 5
|
||||
mock_message.user = 'test_user'
|
||||
mock_message.collection = 'test_collection'
|
||||
|
||||
# Act
|
||||
result = await processor.query_graph_embeddings(mock_message)
|
||||
|
||||
# Assert
|
||||
# Verify query was called with correct parameters
|
||||
expected_collection = 't_test_user_test_collection_3'
|
||||
mock_qdrant_instance.query_points.assert_called_once_with(
|
||||
collection_name=expected_collection,
|
||||
query=[0.1, 0.2, 0.3],
|
||||
limit=10, # limit * 2 for deduplication
|
||||
with_payload=True
|
||||
)
|
||||
|
||||
# Verify result contains expected entities
|
||||
assert len(result) == 2
|
||||
assert all(hasattr(entity, 'value') for entity in result)
|
||||
entity_values = [entity.value for entity in result]
|
||||
assert 'entity1' in entity_values
|
||||
assert 'entity2' in entity_values
|
||||
|
||||
@patch('trustgraph.query.graph_embeddings.qdrant.service.QdrantClient')
|
||||
@patch('trustgraph.base.GraphEmbeddingsQueryService.__init__')
|
||||
async def test_query_graph_embeddings_multiple_vectors(self, mock_base_init, mock_qdrant_client):
|
||||
"""Test querying graph embeddings with multiple vectors"""
|
||||
# Arrange
|
||||
mock_base_init.return_value = None
|
||||
mock_qdrant_instance = MagicMock()
|
||||
mock_qdrant_client.return_value = mock_qdrant_instance
|
||||
|
||||
# Mock query responses for different vectors
|
||||
mock_point1 = MagicMock()
|
||||
mock_point1.payload = {'entity': 'entity1'}
|
||||
mock_point2 = MagicMock()
|
||||
mock_point2.payload = {'entity': 'entity2'}
|
||||
mock_point3 = MagicMock()
|
||||
mock_point3.payload = {'entity': 'entity3'}
|
||||
|
||||
mock_response1 = MagicMock()
|
||||
mock_response1.points = [mock_point1, mock_point2]
|
||||
mock_response2 = MagicMock()
|
||||
mock_response2.points = [mock_point2, mock_point3]
|
||||
mock_qdrant_instance.query_points.side_effect = [mock_response1, mock_response2]
|
||||
|
||||
config = {
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Create mock message with multiple vectors
|
||||
mock_message = MagicMock()
|
||||
mock_message.vectors = [[0.1, 0.2], [0.3, 0.4]]
|
||||
mock_message.limit = 3
|
||||
mock_message.user = 'multi_user'
|
||||
mock_message.collection = 'multi_collection'
|
||||
|
||||
# Act
|
||||
result = await processor.query_graph_embeddings(mock_message)
|
||||
|
||||
# Assert
|
||||
# Verify query was called twice
|
||||
assert mock_qdrant_instance.query_points.call_count == 2
|
||||
|
||||
# Verify both collections were queried
|
||||
expected_collection = 't_multi_user_multi_collection_2'
|
||||
calls = mock_qdrant_instance.query_points.call_args_list
|
||||
assert calls[0][1]['collection_name'] == expected_collection
|
||||
assert calls[1][1]['collection_name'] == expected_collection
|
||||
assert calls[0][1]['query'] == [0.1, 0.2]
|
||||
assert calls[1][1]['query'] == [0.3, 0.4]
|
||||
|
||||
# Verify deduplication - entity2 appears in both results but should only appear once
|
||||
entity_values = [entity.value for entity in result]
|
||||
assert len(set(entity_values)) == len(entity_values) # All unique
|
||||
assert 'entity1' in entity_values
|
||||
assert 'entity2' in entity_values
|
||||
assert 'entity3' in entity_values
|
||||
|
||||
@patch('trustgraph.query.graph_embeddings.qdrant.service.QdrantClient')
|
||||
@patch('trustgraph.base.GraphEmbeddingsQueryService.__init__')
|
||||
async def test_query_graph_embeddings_with_limit(self, mock_base_init, mock_qdrant_client):
|
||||
"""Test querying graph embeddings respects limit parameter"""
|
||||
# Arrange
|
||||
mock_base_init.return_value = None
|
||||
mock_qdrant_instance = MagicMock()
|
||||
mock_qdrant_client.return_value = mock_qdrant_instance
|
||||
|
||||
# Mock query response with more results than limit
|
||||
mock_points = []
|
||||
for i in range(10):
|
||||
mock_point = MagicMock()
|
||||
mock_point.payload = {'entity': f'entity{i}'}
|
||||
mock_points.append(mock_point)
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.points = mock_points
|
||||
mock_qdrant_instance.query_points.return_value = mock_response
|
||||
|
||||
config = {
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Create mock message with limit
|
||||
mock_message = MagicMock()
|
||||
mock_message.vectors = [[0.1, 0.2, 0.3]]
|
||||
mock_message.limit = 3 # Should only return 3 results
|
||||
mock_message.user = 'limit_user'
|
||||
mock_message.collection = 'limit_collection'
|
||||
|
||||
# Act
|
||||
result = await processor.query_graph_embeddings(mock_message)
|
||||
|
||||
# Assert
|
||||
# Verify query was called with limit * 2
|
||||
mock_qdrant_instance.query_points.assert_called_once()
|
||||
call_args = mock_qdrant_instance.query_points.call_args
|
||||
assert call_args[1]['limit'] == 6 # 3 * 2
|
||||
|
||||
# Verify result is limited to requested limit
|
||||
assert len(result) == 3
|
||||
|
||||
@patch('trustgraph.query.graph_embeddings.qdrant.service.QdrantClient')
|
||||
@patch('trustgraph.base.GraphEmbeddingsQueryService.__init__')
|
||||
async def test_query_graph_embeddings_empty_results(self, mock_base_init, mock_qdrant_client):
|
||||
"""Test querying graph embeddings with empty results"""
|
||||
# Arrange
|
||||
mock_base_init.return_value = None
|
||||
mock_qdrant_instance = MagicMock()
|
||||
mock_qdrant_client.return_value = mock_qdrant_instance
|
||||
|
||||
# Mock empty query response
|
||||
mock_response = MagicMock()
|
||||
mock_response.points = []
|
||||
mock_qdrant_instance.query_points.return_value = mock_response
|
||||
|
||||
config = {
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Create mock message
|
||||
mock_message = MagicMock()
|
||||
mock_message.vectors = [[0.1, 0.2]]
|
||||
mock_message.limit = 5
|
||||
mock_message.user = 'empty_user'
|
||||
mock_message.collection = 'empty_collection'
|
||||
|
||||
# Act
|
||||
result = await processor.query_graph_embeddings(mock_message)
|
||||
|
||||
# Assert
|
||||
assert result == []
|
||||
|
||||
@patch('trustgraph.query.graph_embeddings.qdrant.service.QdrantClient')
|
||||
@patch('trustgraph.base.GraphEmbeddingsQueryService.__init__')
|
||||
async def test_query_graph_embeddings_different_dimensions(self, mock_base_init, mock_qdrant_client):
|
||||
"""Test querying graph embeddings with different vector dimensions"""
|
||||
# Arrange
|
||||
mock_base_init.return_value = None
|
||||
mock_qdrant_instance = MagicMock()
|
||||
mock_qdrant_client.return_value = mock_qdrant_instance
|
||||
|
||||
# Mock query responses
|
||||
mock_point1 = MagicMock()
|
||||
mock_point1.payload = {'entity': 'entity2d'}
|
||||
mock_point2 = MagicMock()
|
||||
mock_point2.payload = {'entity': 'entity3d'}
|
||||
|
||||
mock_response1 = MagicMock()
|
||||
mock_response1.points = [mock_point1]
|
||||
mock_response2 = MagicMock()
|
||||
mock_response2.points = [mock_point2]
|
||||
mock_qdrant_instance.query_points.side_effect = [mock_response1, mock_response2]
|
||||
|
||||
config = {
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Create mock message with different dimension vectors
|
||||
mock_message = MagicMock()
|
||||
mock_message.vectors = [[0.1, 0.2], [0.3, 0.4, 0.5]] # 2D and 3D
|
||||
mock_message.limit = 5
|
||||
mock_message.user = 'dim_user'
|
||||
mock_message.collection = 'dim_collection'
|
||||
|
||||
# Act
|
||||
result = await processor.query_graph_embeddings(mock_message)
|
||||
|
||||
# Assert
|
||||
# Verify query was called twice with different collections
|
||||
assert mock_qdrant_instance.query_points.call_count == 2
|
||||
calls = mock_qdrant_instance.query_points.call_args_list
|
||||
|
||||
# First call should use 2D collection
|
||||
assert calls[0][1]['collection_name'] == 't_dim_user_dim_collection_2'
|
||||
assert calls[0][1]['query'] == [0.1, 0.2]
|
||||
|
||||
# Second call should use 3D collection
|
||||
assert calls[1][1]['collection_name'] == 't_dim_user_dim_collection_3'
|
||||
assert calls[1][1]['query'] == [0.3, 0.4, 0.5]
|
||||
|
||||
# Verify results
|
||||
entity_values = [entity.value for entity in result]
|
||||
assert 'entity2d' in entity_values
|
||||
assert 'entity3d' in entity_values
|
||||
|
||||
@patch('trustgraph.query.graph_embeddings.qdrant.service.QdrantClient')
|
||||
@patch('trustgraph.base.GraphEmbeddingsQueryService.__init__')
|
||||
async def test_query_graph_embeddings_uri_detection(self, mock_base_init, mock_qdrant_client):
|
||||
"""Test querying graph embeddings with URI detection"""
|
||||
# Arrange
|
||||
mock_base_init.return_value = None
|
||||
mock_qdrant_instance = MagicMock()
|
||||
mock_qdrant_client.return_value = mock_qdrant_instance
|
||||
|
||||
# Mock query response with URIs and regular strings
|
||||
mock_point1 = MagicMock()
|
||||
mock_point1.payload = {'entity': 'http://example.com/entity1'}
|
||||
mock_point2 = MagicMock()
|
||||
mock_point2.payload = {'entity': 'https://secure.example.com/entity2'}
|
||||
mock_point3 = MagicMock()
|
||||
mock_point3.payload = {'entity': 'regular entity'}
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.points = [mock_point1, mock_point2, mock_point3]
|
||||
mock_qdrant_instance.query_points.return_value = mock_response
|
||||
|
||||
config = {
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Create mock message
|
||||
mock_message = MagicMock()
|
||||
mock_message.vectors = [[0.1, 0.2]]
|
||||
mock_message.limit = 5
|
||||
mock_message.user = 'uri_user'
|
||||
mock_message.collection = 'uri_collection'
|
||||
|
||||
# Act
|
||||
result = await processor.query_graph_embeddings(mock_message)
|
||||
|
||||
# Assert
|
||||
assert len(result) == 3
|
||||
|
||||
# Check URI entities
|
||||
uri_entities = [entity for entity in result if hasattr(entity, 'is_uri') and entity.is_uri]
|
||||
assert len(uri_entities) == 2
|
||||
uri_values = [entity.value for entity in uri_entities]
|
||||
assert 'http://example.com/entity1' in uri_values
|
||||
assert 'https://secure.example.com/entity2' in uri_values
|
||||
|
||||
# Check regular entities
|
||||
regular_entities = [entity for entity in result if hasattr(entity, 'is_uri') and not entity.is_uri]
|
||||
assert len(regular_entities) == 1
|
||||
assert regular_entities[0].value == 'regular entity'
|
||||
|
||||
@patch('trustgraph.query.graph_embeddings.qdrant.service.QdrantClient')
|
||||
@patch('trustgraph.base.GraphEmbeddingsQueryService.__init__')
|
||||
async def test_query_graph_embeddings_qdrant_error(self, mock_base_init, mock_qdrant_client):
|
||||
"""Test querying graph embeddings handles Qdrant errors"""
|
||||
# Arrange
|
||||
mock_base_init.return_value = None
|
||||
mock_qdrant_instance = MagicMock()
|
||||
mock_qdrant_client.return_value = mock_qdrant_instance
|
||||
|
||||
# Mock Qdrant error
|
||||
mock_qdrant_instance.query_points.side_effect = Exception("Qdrant connection failed")
|
||||
|
||||
config = {
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Create mock message
|
||||
mock_message = MagicMock()
|
||||
mock_message.vectors = [[0.1, 0.2]]
|
||||
mock_message.limit = 5
|
||||
mock_message.user = 'error_user'
|
||||
mock_message.collection = 'error_collection'
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(Exception, match="Qdrant connection failed"):
|
||||
await processor.query_graph_embeddings(mock_message)
|
||||
|
||||
@patch('trustgraph.query.graph_embeddings.qdrant.service.QdrantClient')
|
||||
@patch('trustgraph.base.GraphEmbeddingsQueryService.__init__')
|
||||
async def test_query_graph_embeddings_zero_limit(self, mock_base_init, mock_qdrant_client):
|
||||
"""Test querying graph embeddings with zero limit"""
|
||||
# Arrange
|
||||
mock_base_init.return_value = None
|
||||
mock_qdrant_instance = MagicMock()
|
||||
mock_qdrant_client.return_value = mock_qdrant_instance
|
||||
|
||||
# Mock query response - even with zero limit, Qdrant might return results
|
||||
mock_point = MagicMock()
|
||||
mock_point.payload = {'entity': 'entity1'}
|
||||
mock_response = MagicMock()
|
||||
mock_response.points = [mock_point]
|
||||
mock_qdrant_instance.query_points.return_value = mock_response
|
||||
|
||||
config = {
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Create mock message with zero limit
|
||||
mock_message = MagicMock()
|
||||
mock_message.vectors = [[0.1, 0.2]]
|
||||
mock_message.limit = 0
|
||||
mock_message.user = 'zero_user'
|
||||
mock_message.collection = 'zero_collection'
|
||||
|
||||
# Act
|
||||
result = await processor.query_graph_embeddings(mock_message)
|
||||
|
||||
# Assert
|
||||
# Should still query (with limit 0)
|
||||
mock_qdrant_instance.query_points.assert_called_once()
|
||||
call_args = mock_qdrant_instance.query_points.call_args
|
||||
assert call_args[1]['limit'] == 0 # 0 * 2 = 0
|
||||
|
||||
# With zero limit, the logic still adds one entity before checking the limit
|
||||
# So it returns one result (current behavior, not ideal but actual)
|
||||
assert len(result) == 1
|
||||
assert result[0].value == 'entity1'
|
||||
|
||||
@patch('trustgraph.query.graph_embeddings.qdrant.service.QdrantClient')
|
||||
@patch('trustgraph.base.GraphEmbeddingsQueryService.__init__')
|
||||
async def test_add_args_calls_parent(self, mock_base_init, mock_qdrant_client):
|
||||
"""Test that add_args() calls parent add_args method"""
|
||||
# Arrange
|
||||
mock_base_init.return_value = None
|
||||
mock_qdrant_client.return_value = MagicMock()
|
||||
mock_parser = MagicMock()
|
||||
|
||||
# Act
|
||||
with patch('trustgraph.base.GraphEmbeddingsQueryService.add_args') as mock_parent_add_args:
|
||||
Processor.add_args(mock_parser)
|
||||
|
||||
# Assert
|
||||
mock_parent_add_args.assert_called_once_with(mock_parser)
|
||||
|
||||
# Verify processor-specific arguments were added
|
||||
assert mock_parser.add_argument.call_count >= 2 # At least store-uri and api-key
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pytest.main([__file__])
|
||||
539
tests/unit/test_query/test_triples_cassandra_query.py
Normal file
539
tests/unit/test_query/test_triples_cassandra_query.py
Normal file
|
|
@ -0,0 +1,539 @@
|
|||
"""
|
||||
Tests for Cassandra triples query service
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from trustgraph.query.triples.cassandra.service import Processor
|
||||
from trustgraph.schema import Value
|
||||
|
||||
|
||||
class TestCassandraQueryProcessor:
|
||||
"""Test cases for Cassandra query processor"""
|
||||
|
||||
@pytest.fixture
|
||||
def processor(self):
|
||||
"""Create a processor instance for testing"""
|
||||
return Processor(
|
||||
taskgroup=MagicMock(),
|
||||
id='test-cassandra-query',
|
||||
graph_host='localhost'
|
||||
)
|
||||
|
||||
def test_create_value_with_http_uri(self, processor):
|
||||
"""Test create_value with HTTP URI"""
|
||||
result = processor.create_value("http://example.com/resource")
|
||||
|
||||
assert isinstance(result, Value)
|
||||
assert result.value == "http://example.com/resource"
|
||||
assert result.is_uri is True
|
||||
|
||||
def test_create_value_with_https_uri(self, processor):
|
||||
"""Test create_value with HTTPS URI"""
|
||||
result = processor.create_value("https://example.com/resource")
|
||||
|
||||
assert isinstance(result, Value)
|
||||
assert result.value == "https://example.com/resource"
|
||||
assert result.is_uri is True
|
||||
|
||||
def test_create_value_with_literal(self, processor):
|
||||
"""Test create_value with literal value"""
|
||||
result = processor.create_value("just a literal string")
|
||||
|
||||
assert isinstance(result, Value)
|
||||
assert result.value == "just a literal string"
|
||||
assert result.is_uri is False
|
||||
|
||||
def test_create_value_with_empty_string(self, processor):
|
||||
"""Test create_value with empty string"""
|
||||
result = processor.create_value("")
|
||||
|
||||
assert isinstance(result, Value)
|
||||
assert result.value == ""
|
||||
assert result.is_uri is False
|
||||
|
||||
def test_create_value_with_partial_uri(self, processor):
|
||||
"""Test create_value with string that looks like URI but isn't complete"""
|
||||
result = processor.create_value("http")
|
||||
|
||||
assert isinstance(result, Value)
|
||||
assert result.value == "http"
|
||||
assert result.is_uri is False
|
||||
|
||||
def test_create_value_with_ftp_uri(self, processor):
|
||||
"""Test create_value with FTP URI (should not be detected as URI)"""
|
||||
result = processor.create_value("ftp://example.com/file")
|
||||
|
||||
assert isinstance(result, Value)
|
||||
assert result.value == "ftp://example.com/file"
|
||||
assert result.is_uri is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('trustgraph.query.triples.cassandra.service.TrustGraph')
|
||||
async def test_query_triples_spo_query(self, mock_trustgraph):
|
||||
"""Test querying triples with subject, predicate, and object specified"""
|
||||
from trustgraph.schema import TriplesQueryRequest, Value
|
||||
|
||||
# Setup mock TrustGraph
|
||||
mock_tg_instance = MagicMock()
|
||||
mock_trustgraph.return_value = mock_tg_instance
|
||||
mock_tg_instance.get_spo.return_value = None # SPO query returns None if found
|
||||
|
||||
processor = Processor(
|
||||
taskgroup=MagicMock(),
|
||||
id='test-cassandra-query',
|
||||
graph_host='localhost'
|
||||
)
|
||||
|
||||
# Create query request with all SPO values
|
||||
query = TriplesQueryRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
s=Value(value='test_subject', is_uri=False),
|
||||
p=Value(value='test_predicate', is_uri=False),
|
||||
o=Value(value='test_object', is_uri=False),
|
||||
limit=100
|
||||
)
|
||||
|
||||
result = await processor.query_triples(query)
|
||||
|
||||
# Verify TrustGraph was created with correct parameters
|
||||
mock_trustgraph.assert_called_once_with(
|
||||
hosts=['localhost'],
|
||||
keyspace='test_user',
|
||||
table='test_collection'
|
||||
)
|
||||
|
||||
# Verify get_spo was called with correct parameters
|
||||
mock_tg_instance.get_spo.assert_called_once_with(
|
||||
'test_subject', 'test_predicate', 'test_object', limit=100
|
||||
)
|
||||
|
||||
# Verify result contains the queried triple
|
||||
assert len(result) == 1
|
||||
assert result[0].s.value == 'test_subject'
|
||||
assert result[0].p.value == 'test_predicate'
|
||||
assert result[0].o.value == 'test_object'
|
||||
|
||||
def test_processor_initialization_with_defaults(self):
|
||||
"""Test processor initialization with default parameters"""
|
||||
taskgroup_mock = MagicMock()
|
||||
|
||||
processor = Processor(taskgroup=taskgroup_mock)
|
||||
|
||||
assert processor.graph_host == ['localhost']
|
||||
assert processor.username is None
|
||||
assert processor.password is None
|
||||
assert processor.table is None
|
||||
|
||||
def test_processor_initialization_with_custom_params(self):
|
||||
"""Test processor initialization with custom parameters"""
|
||||
taskgroup_mock = MagicMock()
|
||||
|
||||
processor = Processor(
|
||||
taskgroup=taskgroup_mock,
|
||||
graph_host='cassandra.example.com',
|
||||
graph_username='queryuser',
|
||||
graph_password='querypass'
|
||||
)
|
||||
|
||||
assert processor.graph_host == ['cassandra.example.com']
|
||||
assert processor.username == 'queryuser'
|
||||
assert processor.password == 'querypass'
|
||||
assert processor.table is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('trustgraph.query.triples.cassandra.service.TrustGraph')
|
||||
async def test_query_triples_sp_pattern(self, mock_trustgraph):
|
||||
"""Test SP query pattern (subject and predicate, no object)"""
|
||||
from trustgraph.schema import TriplesQueryRequest, Value
|
||||
|
||||
# Setup mock TrustGraph and response
|
||||
mock_tg_instance = MagicMock()
|
||||
mock_trustgraph.return_value = mock_tg_instance
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.o = 'result_object'
|
||||
mock_tg_instance.get_sp.return_value = [mock_result]
|
||||
|
||||
processor = Processor(taskgroup=MagicMock())
|
||||
|
||||
query = TriplesQueryRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
s=Value(value='test_subject', is_uri=False),
|
||||
p=Value(value='test_predicate', is_uri=False),
|
||||
o=None,
|
||||
limit=50
|
||||
)
|
||||
|
||||
result = await processor.query_triples(query)
|
||||
|
||||
mock_tg_instance.get_sp.assert_called_once_with('test_subject', 'test_predicate', limit=50)
|
||||
assert len(result) == 1
|
||||
assert result[0].s.value == 'test_subject'
|
||||
assert result[0].p.value == 'test_predicate'
|
||||
assert result[0].o.value == 'result_object'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('trustgraph.query.triples.cassandra.service.TrustGraph')
|
||||
async def test_query_triples_s_pattern(self, mock_trustgraph):
|
||||
"""Test S query pattern (subject only)"""
|
||||
from trustgraph.schema import TriplesQueryRequest, Value
|
||||
|
||||
mock_tg_instance = MagicMock()
|
||||
mock_trustgraph.return_value = mock_tg_instance
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.p = 'result_predicate'
|
||||
mock_result.o = 'result_object'
|
||||
mock_tg_instance.get_s.return_value = [mock_result]
|
||||
|
||||
processor = Processor(taskgroup=MagicMock())
|
||||
|
||||
query = TriplesQueryRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
s=Value(value='test_subject', is_uri=False),
|
||||
p=None,
|
||||
o=None,
|
||||
limit=25
|
||||
)
|
||||
|
||||
result = await processor.query_triples(query)
|
||||
|
||||
mock_tg_instance.get_s.assert_called_once_with('test_subject', limit=25)
|
||||
assert len(result) == 1
|
||||
assert result[0].s.value == 'test_subject'
|
||||
assert result[0].p.value == 'result_predicate'
|
||||
assert result[0].o.value == 'result_object'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('trustgraph.query.triples.cassandra.service.TrustGraph')
|
||||
async def test_query_triples_p_pattern(self, mock_trustgraph):
|
||||
"""Test P query pattern (predicate only)"""
|
||||
from trustgraph.schema import TriplesQueryRequest, Value
|
||||
|
||||
mock_tg_instance = MagicMock()
|
||||
mock_trustgraph.return_value = mock_tg_instance
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.s = 'result_subject'
|
||||
mock_result.o = 'result_object'
|
||||
mock_tg_instance.get_p.return_value = [mock_result]
|
||||
|
||||
processor = Processor(taskgroup=MagicMock())
|
||||
|
||||
query = TriplesQueryRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
s=None,
|
||||
p=Value(value='test_predicate', is_uri=False),
|
||||
o=None,
|
||||
limit=10
|
||||
)
|
||||
|
||||
result = await processor.query_triples(query)
|
||||
|
||||
mock_tg_instance.get_p.assert_called_once_with('test_predicate', limit=10)
|
||||
assert len(result) == 1
|
||||
assert result[0].s.value == 'result_subject'
|
||||
assert result[0].p.value == 'test_predicate'
|
||||
assert result[0].o.value == 'result_object'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('trustgraph.query.triples.cassandra.service.TrustGraph')
|
||||
async def test_query_triples_o_pattern(self, mock_trustgraph):
|
||||
"""Test O query pattern (object only)"""
|
||||
from trustgraph.schema import TriplesQueryRequest, Value
|
||||
|
||||
mock_tg_instance = MagicMock()
|
||||
mock_trustgraph.return_value = mock_tg_instance
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.s = 'result_subject'
|
||||
mock_result.p = 'result_predicate'
|
||||
mock_tg_instance.get_o.return_value = [mock_result]
|
||||
|
||||
processor = Processor(taskgroup=MagicMock())
|
||||
|
||||
query = TriplesQueryRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
s=None,
|
||||
p=None,
|
||||
o=Value(value='test_object', is_uri=False),
|
||||
limit=75
|
||||
)
|
||||
|
||||
result = await processor.query_triples(query)
|
||||
|
||||
mock_tg_instance.get_o.assert_called_once_with('test_object', limit=75)
|
||||
assert len(result) == 1
|
||||
assert result[0].s.value == 'result_subject'
|
||||
assert result[0].p.value == 'result_predicate'
|
||||
assert result[0].o.value == 'test_object'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('trustgraph.query.triples.cassandra.service.TrustGraph')
|
||||
async def test_query_triples_get_all_pattern(self, mock_trustgraph):
|
||||
"""Test query pattern with no constraints (get all)"""
|
||||
from trustgraph.schema import TriplesQueryRequest
|
||||
|
||||
mock_tg_instance = MagicMock()
|
||||
mock_trustgraph.return_value = mock_tg_instance
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.s = 'all_subject'
|
||||
mock_result.p = 'all_predicate'
|
||||
mock_result.o = 'all_object'
|
||||
mock_tg_instance.get_all.return_value = [mock_result]
|
||||
|
||||
processor = Processor(taskgroup=MagicMock())
|
||||
|
||||
query = TriplesQueryRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
s=None,
|
||||
p=None,
|
||||
o=None,
|
||||
limit=1000
|
||||
)
|
||||
|
||||
result = await processor.query_triples(query)
|
||||
|
||||
mock_tg_instance.get_all.assert_called_once_with(limit=1000)
|
||||
assert len(result) == 1
|
||||
assert result[0].s.value == 'all_subject'
|
||||
assert result[0].p.value == 'all_predicate'
|
||||
assert result[0].o.value == 'all_object'
|
||||
|
||||
def test_add_args_method(self):
|
||||
"""Test that add_args properly configures argument parser"""
|
||||
from argparse import ArgumentParser
|
||||
|
||||
parser = ArgumentParser()
|
||||
|
||||
# Mock the parent class add_args method
|
||||
with patch('trustgraph.query.triples.cassandra.service.TriplesQueryService.add_args') as mock_parent_add_args:
|
||||
Processor.add_args(parser)
|
||||
|
||||
# Verify parent add_args was called
|
||||
mock_parent_add_args.assert_called_once_with(parser)
|
||||
|
||||
# Verify our specific arguments were added
|
||||
args = parser.parse_args([])
|
||||
|
||||
assert hasattr(args, 'graph_host')
|
||||
assert args.graph_host == 'localhost'
|
||||
assert hasattr(args, 'graph_username')
|
||||
assert args.graph_username is None
|
||||
assert hasattr(args, 'graph_password')
|
||||
assert args.graph_password is None
|
||||
|
||||
def test_add_args_with_custom_values(self):
|
||||
"""Test add_args with custom command line values"""
|
||||
from argparse import ArgumentParser
|
||||
|
||||
parser = ArgumentParser()
|
||||
|
||||
with patch('trustgraph.query.triples.cassandra.service.TriplesQueryService.add_args'):
|
||||
Processor.add_args(parser)
|
||||
|
||||
# Test parsing with custom values
|
||||
args = parser.parse_args([
|
||||
'--graph-host', 'query.cassandra.com',
|
||||
'--graph-username', 'queryuser',
|
||||
'--graph-password', 'querypass'
|
||||
])
|
||||
|
||||
assert args.graph_host == 'query.cassandra.com'
|
||||
assert args.graph_username == 'queryuser'
|
||||
assert args.graph_password == 'querypass'
|
||||
|
||||
def test_add_args_short_form(self):
|
||||
"""Test add_args with short form arguments"""
|
||||
from argparse import ArgumentParser
|
||||
|
||||
parser = ArgumentParser()
|
||||
|
||||
with patch('trustgraph.query.triples.cassandra.service.TriplesQueryService.add_args'):
|
||||
Processor.add_args(parser)
|
||||
|
||||
# Test parsing with short form
|
||||
args = parser.parse_args(['-g', 'short.query.com'])
|
||||
|
||||
assert args.graph_host == 'short.query.com'
|
||||
|
||||
@patch('trustgraph.query.triples.cassandra.service.Processor.launch')
|
||||
def test_run_function(self, mock_launch):
|
||||
"""Test the run function calls Processor.launch with correct parameters"""
|
||||
from trustgraph.query.triples.cassandra.service import run, default_ident
|
||||
|
||||
run()
|
||||
|
||||
mock_launch.assert_called_once_with(default_ident, '\nTriples query service. Input is a (s, p, o) triple, some values may be\nnull. Output is a list of triples.\n')
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('trustgraph.query.triples.cassandra.service.TrustGraph')
|
||||
async def test_query_triples_with_authentication(self, mock_trustgraph):
|
||||
"""Test querying with username and password authentication"""
|
||||
from trustgraph.schema import TriplesQueryRequest, Value
|
||||
|
||||
mock_tg_instance = MagicMock()
|
||||
mock_trustgraph.return_value = mock_tg_instance
|
||||
mock_tg_instance.get_spo.return_value = None
|
||||
|
||||
processor = Processor(
|
||||
taskgroup=MagicMock(),
|
||||
graph_username='authuser',
|
||||
graph_password='authpass'
|
||||
)
|
||||
|
||||
query = TriplesQueryRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
s=Value(value='test_subject', is_uri=False),
|
||||
p=Value(value='test_predicate', is_uri=False),
|
||||
o=Value(value='test_object', is_uri=False),
|
||||
limit=100
|
||||
)
|
||||
|
||||
await processor.query_triples(query)
|
||||
|
||||
# Verify TrustGraph was created with authentication
|
||||
mock_trustgraph.assert_called_once_with(
|
||||
hosts=['localhost'],
|
||||
keyspace='test_user',
|
||||
table='test_collection',
|
||||
username='authuser',
|
||||
password='authpass'
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('trustgraph.query.triples.cassandra.service.TrustGraph')
|
||||
async def test_query_triples_table_reuse(self, mock_trustgraph):
|
||||
"""Test that TrustGraph is reused for same table"""
|
||||
from trustgraph.schema import TriplesQueryRequest, Value
|
||||
|
||||
mock_tg_instance = MagicMock()
|
||||
mock_trustgraph.return_value = mock_tg_instance
|
||||
mock_tg_instance.get_spo.return_value = None
|
||||
|
||||
processor = Processor(taskgroup=MagicMock())
|
||||
|
||||
query = TriplesQueryRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
s=Value(value='test_subject', is_uri=False),
|
||||
p=Value(value='test_predicate', is_uri=False),
|
||||
o=Value(value='test_object', is_uri=False),
|
||||
limit=100
|
||||
)
|
||||
|
||||
# First query should create TrustGraph
|
||||
await processor.query_triples(query)
|
||||
assert mock_trustgraph.call_count == 1
|
||||
|
||||
# Second query with same table should reuse TrustGraph
|
||||
await processor.query_triples(query)
|
||||
assert mock_trustgraph.call_count == 1 # Should not increase
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('trustgraph.query.triples.cassandra.service.TrustGraph')
|
||||
async def test_query_triples_table_switching(self, mock_trustgraph):
|
||||
"""Test table switching creates new TrustGraph"""
|
||||
from trustgraph.schema import TriplesQueryRequest, Value
|
||||
|
||||
mock_tg_instance1 = MagicMock()
|
||||
mock_tg_instance2 = MagicMock()
|
||||
mock_trustgraph.side_effect = [mock_tg_instance1, mock_tg_instance2]
|
||||
|
||||
processor = Processor(taskgroup=MagicMock())
|
||||
|
||||
# First query
|
||||
query1 = TriplesQueryRequest(
|
||||
user='user1',
|
||||
collection='collection1',
|
||||
s=Value(value='test_subject', is_uri=False),
|
||||
p=None,
|
||||
o=None,
|
||||
limit=100
|
||||
)
|
||||
|
||||
await processor.query_triples(query1)
|
||||
assert processor.table == ('user1', 'collection1')
|
||||
|
||||
# Second query with different table
|
||||
query2 = TriplesQueryRequest(
|
||||
user='user2',
|
||||
collection='collection2',
|
||||
s=Value(value='test_subject', is_uri=False),
|
||||
p=None,
|
||||
o=None,
|
||||
limit=100
|
||||
)
|
||||
|
||||
await processor.query_triples(query2)
|
||||
assert processor.table == ('user2', 'collection2')
|
||||
|
||||
# Verify TrustGraph was created twice
|
||||
assert mock_trustgraph.call_count == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('trustgraph.query.triples.cassandra.service.TrustGraph')
|
||||
async def test_query_triples_exception_handling(self, mock_trustgraph):
|
||||
"""Test exception handling during query execution"""
|
||||
from trustgraph.schema import TriplesQueryRequest, Value
|
||||
|
||||
mock_tg_instance = MagicMock()
|
||||
mock_trustgraph.return_value = mock_tg_instance
|
||||
mock_tg_instance.get_spo.side_effect = Exception("Query failed")
|
||||
|
||||
processor = Processor(taskgroup=MagicMock())
|
||||
|
||||
query = TriplesQueryRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
s=Value(value='test_subject', is_uri=False),
|
||||
p=Value(value='test_predicate', is_uri=False),
|
||||
o=Value(value='test_object', is_uri=False),
|
||||
limit=100
|
||||
)
|
||||
|
||||
with pytest.raises(Exception, match="Query failed"):
|
||||
await processor.query_triples(query)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('trustgraph.query.triples.cassandra.service.TrustGraph')
|
||||
async def test_query_triples_multiple_results(self, mock_trustgraph):
|
||||
"""Test query returning multiple results"""
|
||||
from trustgraph.schema import TriplesQueryRequest, Value
|
||||
|
||||
mock_tg_instance = MagicMock()
|
||||
mock_trustgraph.return_value = mock_tg_instance
|
||||
|
||||
# Mock multiple results
|
||||
mock_result1 = MagicMock()
|
||||
mock_result1.o = 'object1'
|
||||
mock_result2 = MagicMock()
|
||||
mock_result2.o = 'object2'
|
||||
mock_tg_instance.get_sp.return_value = [mock_result1, mock_result2]
|
||||
|
||||
processor = Processor(taskgroup=MagicMock())
|
||||
|
||||
query = TriplesQueryRequest(
|
||||
user='test_user',
|
||||
collection='test_collection',
|
||||
s=Value(value='test_subject', is_uri=False),
|
||||
p=Value(value='test_predicate', is_uri=False),
|
||||
o=None,
|
||||
limit=100
|
||||
)
|
||||
|
||||
result = await processor.query_triples(query)
|
||||
|
||||
assert len(result) == 2
|
||||
assert result[0].o.value == 'object1'
|
||||
assert result[1].o.value == 'object2'
|
||||
475
tests/unit/test_retrieval/test_document_rag.py
Normal file
475
tests/unit/test_retrieval/test_document_rag.py
Normal file
|
|
@ -0,0 +1,475 @@
|
|||
"""
|
||||
Tests for DocumentRAG retrieval implementation
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, AsyncMock
|
||||
|
||||
from trustgraph.retrieval.document_rag.document_rag import DocumentRag, Query
|
||||
|
||||
|
||||
class TestDocumentRag:
|
||||
"""Test cases for DocumentRag class"""
|
||||
|
||||
def test_document_rag_initialization_with_defaults(self):
|
||||
"""Test DocumentRag initialization with default verbose setting"""
|
||||
# Create mock clients
|
||||
mock_prompt_client = MagicMock()
|
||||
mock_embeddings_client = MagicMock()
|
||||
mock_doc_embeddings_client = MagicMock()
|
||||
|
||||
# Initialize DocumentRag
|
||||
document_rag = DocumentRag(
|
||||
prompt_client=mock_prompt_client,
|
||||
embeddings_client=mock_embeddings_client,
|
||||
doc_embeddings_client=mock_doc_embeddings_client
|
||||
)
|
||||
|
||||
# Verify initialization
|
||||
assert document_rag.prompt_client == mock_prompt_client
|
||||
assert document_rag.embeddings_client == mock_embeddings_client
|
||||
assert document_rag.doc_embeddings_client == mock_doc_embeddings_client
|
||||
assert document_rag.verbose is False # Default value
|
||||
|
||||
def test_document_rag_initialization_with_verbose(self):
|
||||
"""Test DocumentRag initialization with verbose enabled"""
|
||||
# Create mock clients
|
||||
mock_prompt_client = MagicMock()
|
||||
mock_embeddings_client = MagicMock()
|
||||
mock_doc_embeddings_client = MagicMock()
|
||||
|
||||
# Initialize DocumentRag with verbose=True
|
||||
document_rag = DocumentRag(
|
||||
prompt_client=mock_prompt_client,
|
||||
embeddings_client=mock_embeddings_client,
|
||||
doc_embeddings_client=mock_doc_embeddings_client,
|
||||
verbose=True
|
||||
)
|
||||
|
||||
# Verify initialization
|
||||
assert document_rag.prompt_client == mock_prompt_client
|
||||
assert document_rag.embeddings_client == mock_embeddings_client
|
||||
assert document_rag.doc_embeddings_client == mock_doc_embeddings_client
|
||||
assert document_rag.verbose is True
|
||||
|
||||
|
||||
class TestQuery:
|
||||
"""Test cases for Query class"""
|
||||
|
||||
def test_query_initialization_with_defaults(self):
|
||||
"""Test Query initialization with default parameters"""
|
||||
# Create mock DocumentRag
|
||||
mock_rag = MagicMock()
|
||||
|
||||
# Initialize Query with defaults
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
verbose=False
|
||||
)
|
||||
|
||||
# Verify initialization
|
||||
assert query.rag == mock_rag
|
||||
assert query.user == "test_user"
|
||||
assert query.collection == "test_collection"
|
||||
assert query.verbose is False
|
||||
assert query.doc_limit == 20 # Default value
|
||||
|
||||
def test_query_initialization_with_custom_doc_limit(self):
|
||||
"""Test Query initialization with custom doc_limit"""
|
||||
# Create mock DocumentRag
|
||||
mock_rag = MagicMock()
|
||||
|
||||
# Initialize Query with custom doc_limit
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
user="custom_user",
|
||||
collection="custom_collection",
|
||||
verbose=True,
|
||||
doc_limit=50
|
||||
)
|
||||
|
||||
# Verify initialization
|
||||
assert query.rag == mock_rag
|
||||
assert query.user == "custom_user"
|
||||
assert query.collection == "custom_collection"
|
||||
assert query.verbose is True
|
||||
assert query.doc_limit == 50
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_vector_method(self):
|
||||
"""Test Query.get_vector method calls embeddings client correctly"""
|
||||
# Create mock DocumentRag with embeddings client
|
||||
mock_rag = MagicMock()
|
||||
mock_embeddings_client = AsyncMock()
|
||||
mock_rag.embeddings_client = mock_embeddings_client
|
||||
|
||||
# Mock the embed method to return test vectors
|
||||
expected_vectors = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
|
||||
mock_embeddings_client.embed.return_value = expected_vectors
|
||||
|
||||
# Initialize Query
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
verbose=False
|
||||
)
|
||||
|
||||
# Call get_vector
|
||||
test_query = "What documents are relevant?"
|
||||
result = await query.get_vector(test_query)
|
||||
|
||||
# Verify embeddings client was called correctly
|
||||
mock_embeddings_client.embed.assert_called_once_with(test_query)
|
||||
|
||||
# Verify result matches expected vectors
|
||||
assert result == expected_vectors
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_docs_method(self):
|
||||
"""Test Query.get_docs method retrieves documents correctly"""
|
||||
# Create mock DocumentRag with clients
|
||||
mock_rag = MagicMock()
|
||||
mock_embeddings_client = AsyncMock()
|
||||
mock_doc_embeddings_client = AsyncMock()
|
||||
mock_rag.embeddings_client = mock_embeddings_client
|
||||
mock_rag.doc_embeddings_client = mock_doc_embeddings_client
|
||||
|
||||
# Mock the embedding and document query responses
|
||||
test_vectors = [[0.1, 0.2, 0.3]]
|
||||
mock_embeddings_client.embed.return_value = test_vectors
|
||||
|
||||
# Mock document results
|
||||
test_docs = ["Document 1 content", "Document 2 content"]
|
||||
mock_doc_embeddings_client.query.return_value = test_docs
|
||||
|
||||
# Initialize Query
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
verbose=False,
|
||||
doc_limit=15
|
||||
)
|
||||
|
||||
# Call get_docs
|
||||
test_query = "Find relevant documents"
|
||||
result = await query.get_docs(test_query)
|
||||
|
||||
# Verify embeddings client was called
|
||||
mock_embeddings_client.embed.assert_called_once_with(test_query)
|
||||
|
||||
# Verify doc embeddings client was called correctly
|
||||
mock_doc_embeddings_client.query.assert_called_once_with(
|
||||
test_vectors,
|
||||
limit=15,
|
||||
user="test_user",
|
||||
collection="test_collection"
|
||||
)
|
||||
|
||||
# Verify result is list of documents
|
||||
assert result == test_docs
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_document_rag_query_method(self):
|
||||
"""Test DocumentRag.query method orchestrates full document RAG pipeline"""
|
||||
# Create mock clients
|
||||
mock_prompt_client = AsyncMock()
|
||||
mock_embeddings_client = AsyncMock()
|
||||
mock_doc_embeddings_client = AsyncMock()
|
||||
|
||||
# Mock embeddings and document responses
|
||||
test_vectors = [[0.1, 0.2, 0.3]]
|
||||
test_docs = ["Relevant document content", "Another document"]
|
||||
expected_response = "This is the document RAG response"
|
||||
|
||||
mock_embeddings_client.embed.return_value = test_vectors
|
||||
mock_doc_embeddings_client.query.return_value = test_docs
|
||||
mock_prompt_client.document_prompt.return_value = expected_response
|
||||
|
||||
# Initialize DocumentRag
|
||||
document_rag = DocumentRag(
|
||||
prompt_client=mock_prompt_client,
|
||||
embeddings_client=mock_embeddings_client,
|
||||
doc_embeddings_client=mock_doc_embeddings_client,
|
||||
verbose=False
|
||||
)
|
||||
|
||||
# Call DocumentRag.query
|
||||
result = await document_rag.query(
|
||||
query="test query",
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
doc_limit=10
|
||||
)
|
||||
|
||||
# Verify embeddings client was called
|
||||
mock_embeddings_client.embed.assert_called_once_with("test query")
|
||||
|
||||
# Verify doc embeddings client was called
|
||||
mock_doc_embeddings_client.query.assert_called_once_with(
|
||||
test_vectors,
|
||||
limit=10,
|
||||
user="test_user",
|
||||
collection="test_collection"
|
||||
)
|
||||
|
||||
# Verify prompt client was called with documents and query
|
||||
mock_prompt_client.document_prompt.assert_called_once_with(
|
||||
query="test query",
|
||||
documents=test_docs
|
||||
)
|
||||
|
||||
# Verify result
|
||||
assert result == expected_response
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_document_rag_query_with_defaults(self):
|
||||
"""Test DocumentRag.query method with default parameters"""
|
||||
# Create mock clients
|
||||
mock_prompt_client = AsyncMock()
|
||||
mock_embeddings_client = AsyncMock()
|
||||
mock_doc_embeddings_client = AsyncMock()
|
||||
|
||||
# Mock responses
|
||||
mock_embeddings_client.embed.return_value = [[0.1, 0.2]]
|
||||
mock_doc_embeddings_client.query.return_value = ["Default doc"]
|
||||
mock_prompt_client.document_prompt.return_value = "Default response"
|
||||
|
||||
# Initialize DocumentRag
|
||||
document_rag = DocumentRag(
|
||||
prompt_client=mock_prompt_client,
|
||||
embeddings_client=mock_embeddings_client,
|
||||
doc_embeddings_client=mock_doc_embeddings_client
|
||||
)
|
||||
|
||||
# Call DocumentRag.query with minimal parameters
|
||||
result = await document_rag.query("simple query")
|
||||
|
||||
# Verify default parameters were used
|
||||
mock_doc_embeddings_client.query.assert_called_once_with(
|
||||
[[0.1, 0.2]],
|
||||
limit=20, # Default doc_limit
|
||||
user="trustgraph", # Default user
|
||||
collection="default" # Default collection
|
||||
)
|
||||
|
||||
assert result == "Default response"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_docs_with_verbose_output(self):
|
||||
"""Test Query.get_docs method with verbose logging"""
|
||||
# Create mock DocumentRag with clients
|
||||
mock_rag = MagicMock()
|
||||
mock_embeddings_client = AsyncMock()
|
||||
mock_doc_embeddings_client = AsyncMock()
|
||||
mock_rag.embeddings_client = mock_embeddings_client
|
||||
mock_rag.doc_embeddings_client = mock_doc_embeddings_client
|
||||
|
||||
# Mock responses
|
||||
mock_embeddings_client.embed.return_value = [[0.7, 0.8]]
|
||||
mock_doc_embeddings_client.query.return_value = ["Verbose test doc"]
|
||||
|
||||
# Initialize Query with verbose=True
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
verbose=True,
|
||||
doc_limit=5
|
||||
)
|
||||
|
||||
# Call get_docs
|
||||
result = await query.get_docs("verbose test")
|
||||
|
||||
# Verify calls were made
|
||||
mock_embeddings_client.embed.assert_called_once_with("verbose test")
|
||||
mock_doc_embeddings_client.query.assert_called_once()
|
||||
|
||||
# Verify result
|
||||
assert result == ["Verbose test doc"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_document_rag_query_with_verbose(self):
|
||||
"""Test DocumentRag.query method with verbose logging enabled"""
|
||||
# Create mock clients
|
||||
mock_prompt_client = AsyncMock()
|
||||
mock_embeddings_client = AsyncMock()
|
||||
mock_doc_embeddings_client = AsyncMock()
|
||||
|
||||
# Mock responses
|
||||
mock_embeddings_client.embed.return_value = [[0.3, 0.4]]
|
||||
mock_doc_embeddings_client.query.return_value = ["Verbose doc content"]
|
||||
mock_prompt_client.document_prompt.return_value = "Verbose RAG response"
|
||||
|
||||
# Initialize DocumentRag with verbose=True
|
||||
document_rag = DocumentRag(
|
||||
prompt_client=mock_prompt_client,
|
||||
embeddings_client=mock_embeddings_client,
|
||||
doc_embeddings_client=mock_doc_embeddings_client,
|
||||
verbose=True
|
||||
)
|
||||
|
||||
# Call DocumentRag.query
|
||||
result = await document_rag.query("verbose query test")
|
||||
|
||||
# Verify all clients were called
|
||||
mock_embeddings_client.embed.assert_called_once_with("verbose query test")
|
||||
mock_doc_embeddings_client.query.assert_called_once()
|
||||
mock_prompt_client.document_prompt.assert_called_once_with(
|
||||
query="verbose query test",
|
||||
documents=["Verbose doc content"]
|
||||
)
|
||||
|
||||
assert result == "Verbose RAG response"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_docs_with_empty_results(self):
|
||||
"""Test Query.get_docs method when no documents are found"""
|
||||
# Create mock DocumentRag with clients
|
||||
mock_rag = MagicMock()
|
||||
mock_embeddings_client = AsyncMock()
|
||||
mock_doc_embeddings_client = AsyncMock()
|
||||
mock_rag.embeddings_client = mock_embeddings_client
|
||||
mock_rag.doc_embeddings_client = mock_doc_embeddings_client
|
||||
|
||||
# Mock responses - empty document list
|
||||
mock_embeddings_client.embed.return_value = [[0.1, 0.2]]
|
||||
mock_doc_embeddings_client.query.return_value = [] # No documents found
|
||||
|
||||
# Initialize Query
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
verbose=False
|
||||
)
|
||||
|
||||
# Call get_docs
|
||||
result = await query.get_docs("query with no results")
|
||||
|
||||
# Verify calls were made
|
||||
mock_embeddings_client.embed.assert_called_once_with("query with no results")
|
||||
mock_doc_embeddings_client.query.assert_called_once()
|
||||
|
||||
# Verify empty result is returned
|
||||
assert result == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_document_rag_query_with_empty_documents(self):
|
||||
"""Test DocumentRag.query method when no documents are retrieved"""
|
||||
# Create mock clients
|
||||
mock_prompt_client = AsyncMock()
|
||||
mock_embeddings_client = AsyncMock()
|
||||
mock_doc_embeddings_client = AsyncMock()
|
||||
|
||||
# Mock responses - no documents found
|
||||
mock_embeddings_client.embed.return_value = [[0.5, 0.6]]
|
||||
mock_doc_embeddings_client.query.return_value = [] # Empty document list
|
||||
mock_prompt_client.document_prompt.return_value = "No documents found response"
|
||||
|
||||
# Initialize DocumentRag
|
||||
document_rag = DocumentRag(
|
||||
prompt_client=mock_prompt_client,
|
||||
embeddings_client=mock_embeddings_client,
|
||||
doc_embeddings_client=mock_doc_embeddings_client,
|
||||
verbose=False
|
||||
)
|
||||
|
||||
# Call DocumentRag.query
|
||||
result = await document_rag.query("query with no matching docs")
|
||||
|
||||
# Verify prompt client was called with empty document list
|
||||
mock_prompt_client.document_prompt.assert_called_once_with(
|
||||
query="query with no matching docs",
|
||||
documents=[]
|
||||
)
|
||||
|
||||
assert result == "No documents found response"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_vector_with_verbose(self):
|
||||
"""Test Query.get_vector method with verbose logging"""
|
||||
# Create mock DocumentRag with embeddings client
|
||||
mock_rag = MagicMock()
|
||||
mock_embeddings_client = AsyncMock()
|
||||
mock_rag.embeddings_client = mock_embeddings_client
|
||||
|
||||
# Mock the embed method
|
||||
expected_vectors = [[0.9, 1.0, 1.1]]
|
||||
mock_embeddings_client.embed.return_value = expected_vectors
|
||||
|
||||
# Initialize Query with verbose=True
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
verbose=True
|
||||
)
|
||||
|
||||
# Call get_vector
|
||||
result = await query.get_vector("verbose vector test")
|
||||
|
||||
# Verify embeddings client was called
|
||||
mock_embeddings_client.embed.assert_called_once_with("verbose vector test")
|
||||
|
||||
# Verify result
|
||||
assert result == expected_vectors
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_document_rag_integration_flow(self):
|
||||
"""Test complete DocumentRag integration with realistic data flow"""
|
||||
# Create mock clients
|
||||
mock_prompt_client = AsyncMock()
|
||||
mock_embeddings_client = AsyncMock()
|
||||
mock_doc_embeddings_client = AsyncMock()
|
||||
|
||||
# Mock realistic responses
|
||||
query_text = "What is machine learning?"
|
||||
query_vectors = [[0.1, 0.2, 0.3, 0.4, 0.5]]
|
||||
retrieved_docs = [
|
||||
"Machine learning is a subset of artificial intelligence...",
|
||||
"ML algorithms learn patterns from data to make predictions...",
|
||||
"Common ML techniques include supervised and unsupervised learning..."
|
||||
]
|
||||
final_response = "Machine learning is a field of AI that enables computers to learn and improve from experience without being explicitly programmed."
|
||||
|
||||
mock_embeddings_client.embed.return_value = query_vectors
|
||||
mock_doc_embeddings_client.query.return_value = retrieved_docs
|
||||
mock_prompt_client.document_prompt.return_value = final_response
|
||||
|
||||
# Initialize DocumentRag
|
||||
document_rag = DocumentRag(
|
||||
prompt_client=mock_prompt_client,
|
||||
embeddings_client=mock_embeddings_client,
|
||||
doc_embeddings_client=mock_doc_embeddings_client,
|
||||
verbose=False
|
||||
)
|
||||
|
||||
# Execute full pipeline
|
||||
result = await document_rag.query(
|
||||
query=query_text,
|
||||
user="research_user",
|
||||
collection="ml_knowledge",
|
||||
doc_limit=25
|
||||
)
|
||||
|
||||
# Verify complete pipeline execution
|
||||
mock_embeddings_client.embed.assert_called_once_with(query_text)
|
||||
|
||||
mock_doc_embeddings_client.query.assert_called_once_with(
|
||||
query_vectors,
|
||||
limit=25,
|
||||
user="research_user",
|
||||
collection="ml_knowledge"
|
||||
)
|
||||
|
||||
mock_prompt_client.document_prompt.assert_called_once_with(
|
||||
query=query_text,
|
||||
documents=retrieved_docs
|
||||
)
|
||||
|
||||
# Verify final result
|
||||
assert result == final_response
|
||||
595
tests/unit/test_retrieval/test_graph_rag.py
Normal file
595
tests/unit/test_retrieval/test_graph_rag.py
Normal file
|
|
@ -0,0 +1,595 @@
|
|||
"""
|
||||
Tests for GraphRAG retrieval implementation
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import unittest.mock
|
||||
from unittest.mock import MagicMock, AsyncMock
|
||||
|
||||
from trustgraph.retrieval.graph_rag.graph_rag import GraphRag, Query
|
||||
|
||||
|
||||
class TestGraphRag:
|
||||
"""Test cases for GraphRag class"""
|
||||
|
||||
def test_graph_rag_initialization_with_defaults(self):
|
||||
"""Test GraphRag initialization with default verbose setting"""
|
||||
# Create mock clients
|
||||
mock_prompt_client = MagicMock()
|
||||
mock_embeddings_client = MagicMock()
|
||||
mock_graph_embeddings_client = MagicMock()
|
||||
mock_triples_client = MagicMock()
|
||||
|
||||
# Initialize GraphRag
|
||||
graph_rag = GraphRag(
|
||||
prompt_client=mock_prompt_client,
|
||||
embeddings_client=mock_embeddings_client,
|
||||
graph_embeddings_client=mock_graph_embeddings_client,
|
||||
triples_client=mock_triples_client
|
||||
)
|
||||
|
||||
# Verify initialization
|
||||
assert graph_rag.prompt_client == mock_prompt_client
|
||||
assert graph_rag.embeddings_client == mock_embeddings_client
|
||||
assert graph_rag.graph_embeddings_client == mock_graph_embeddings_client
|
||||
assert graph_rag.triples_client == mock_triples_client
|
||||
assert graph_rag.verbose is False # Default value
|
||||
assert graph_rag.label_cache == {} # Empty cache initially
|
||||
|
||||
def test_graph_rag_initialization_with_verbose(self):
|
||||
"""Test GraphRag initialization with verbose enabled"""
|
||||
# Create mock clients
|
||||
mock_prompt_client = MagicMock()
|
||||
mock_embeddings_client = MagicMock()
|
||||
mock_graph_embeddings_client = MagicMock()
|
||||
mock_triples_client = MagicMock()
|
||||
|
||||
# Initialize GraphRag with verbose=True
|
||||
graph_rag = GraphRag(
|
||||
prompt_client=mock_prompt_client,
|
||||
embeddings_client=mock_embeddings_client,
|
||||
graph_embeddings_client=mock_graph_embeddings_client,
|
||||
triples_client=mock_triples_client,
|
||||
verbose=True
|
||||
)
|
||||
|
||||
# Verify initialization
|
||||
assert graph_rag.prompt_client == mock_prompt_client
|
||||
assert graph_rag.embeddings_client == mock_embeddings_client
|
||||
assert graph_rag.graph_embeddings_client == mock_graph_embeddings_client
|
||||
assert graph_rag.triples_client == mock_triples_client
|
||||
assert graph_rag.verbose is True
|
||||
assert graph_rag.label_cache == {} # Empty cache initially
|
||||
|
||||
|
||||
class TestQuery:
|
||||
"""Test cases for Query class"""
|
||||
|
||||
def test_query_initialization_with_defaults(self):
|
||||
"""Test Query initialization with default parameters"""
|
||||
# Create mock GraphRag
|
||||
mock_rag = MagicMock()
|
||||
|
||||
# Initialize Query with defaults
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
verbose=False
|
||||
)
|
||||
|
||||
# Verify initialization
|
||||
assert query.rag == mock_rag
|
||||
assert query.user == "test_user"
|
||||
assert query.collection == "test_collection"
|
||||
assert query.verbose is False
|
||||
assert query.entity_limit == 50 # Default value
|
||||
assert query.triple_limit == 30 # Default value
|
||||
assert query.max_subgraph_size == 1000 # Default value
|
||||
assert query.max_path_length == 2 # Default value
|
||||
|
||||
def test_query_initialization_with_custom_params(self):
|
||||
"""Test Query initialization with custom parameters"""
|
||||
# Create mock GraphRag
|
||||
mock_rag = MagicMock()
|
||||
|
||||
# Initialize Query with custom parameters
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
user="custom_user",
|
||||
collection="custom_collection",
|
||||
verbose=True,
|
||||
entity_limit=100,
|
||||
triple_limit=60,
|
||||
max_subgraph_size=2000,
|
||||
max_path_length=3
|
||||
)
|
||||
|
||||
# Verify initialization
|
||||
assert query.rag == mock_rag
|
||||
assert query.user == "custom_user"
|
||||
assert query.collection == "custom_collection"
|
||||
assert query.verbose is True
|
||||
assert query.entity_limit == 100
|
||||
assert query.triple_limit == 60
|
||||
assert query.max_subgraph_size == 2000
|
||||
assert query.max_path_length == 3
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_vector_method(self):
|
||||
"""Test Query.get_vector method calls embeddings client correctly"""
|
||||
# Create mock GraphRag with embeddings client
|
||||
mock_rag = MagicMock()
|
||||
mock_embeddings_client = AsyncMock()
|
||||
mock_rag.embeddings_client = mock_embeddings_client
|
||||
|
||||
# Mock the embed method to return test vectors
|
||||
expected_vectors = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
|
||||
mock_embeddings_client.embed.return_value = expected_vectors
|
||||
|
||||
# Initialize Query
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
verbose=False
|
||||
)
|
||||
|
||||
# Call get_vector
|
||||
test_query = "What is the capital of France?"
|
||||
result = await query.get_vector(test_query)
|
||||
|
||||
# Verify embeddings client was called correctly
|
||||
mock_embeddings_client.embed.assert_called_once_with(test_query)
|
||||
|
||||
# Verify result matches expected vectors
|
||||
assert result == expected_vectors
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_vector_method_with_verbose(self):
|
||||
"""Test Query.get_vector method with verbose output"""
|
||||
# Create mock GraphRag with embeddings client
|
||||
mock_rag = MagicMock()
|
||||
mock_embeddings_client = AsyncMock()
|
||||
mock_rag.embeddings_client = mock_embeddings_client
|
||||
|
||||
# Mock the embed method
|
||||
expected_vectors = [[0.7, 0.8, 0.9]]
|
||||
mock_embeddings_client.embed.return_value = expected_vectors
|
||||
|
||||
# Initialize Query with verbose=True
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
verbose=True
|
||||
)
|
||||
|
||||
# Call get_vector
|
||||
test_query = "Test query for embeddings"
|
||||
result = await query.get_vector(test_query)
|
||||
|
||||
# Verify embeddings client was called correctly
|
||||
mock_embeddings_client.embed.assert_called_once_with(test_query)
|
||||
|
||||
# Verify result matches expected vectors
|
||||
assert result == expected_vectors
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_entities_method(self):
|
||||
"""Test Query.get_entities method retrieves entities correctly"""
|
||||
# Create mock GraphRag with clients
|
||||
mock_rag = MagicMock()
|
||||
mock_embeddings_client = AsyncMock()
|
||||
mock_graph_embeddings_client = AsyncMock()
|
||||
mock_rag.embeddings_client = mock_embeddings_client
|
||||
mock_rag.graph_embeddings_client = mock_graph_embeddings_client
|
||||
|
||||
# Mock the embedding and entity query responses
|
||||
test_vectors = [[0.1, 0.2, 0.3]]
|
||||
mock_embeddings_client.embed.return_value = test_vectors
|
||||
|
||||
# Mock entity objects that have string representation
|
||||
mock_entity1 = MagicMock()
|
||||
mock_entity1.__str__ = MagicMock(return_value="entity1")
|
||||
mock_entity2 = MagicMock()
|
||||
mock_entity2.__str__ = MagicMock(return_value="entity2")
|
||||
mock_graph_embeddings_client.query.return_value = [mock_entity1, mock_entity2]
|
||||
|
||||
# Initialize Query
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
verbose=False,
|
||||
entity_limit=25
|
||||
)
|
||||
|
||||
# Call get_entities
|
||||
test_query = "Find related entities"
|
||||
result = await query.get_entities(test_query)
|
||||
|
||||
# Verify embeddings client was called
|
||||
mock_embeddings_client.embed.assert_called_once_with(test_query)
|
||||
|
||||
# Verify graph embeddings client was called correctly
|
||||
mock_graph_embeddings_client.query.assert_called_once_with(
|
||||
vectors=test_vectors,
|
||||
limit=25,
|
||||
user="test_user",
|
||||
collection="test_collection"
|
||||
)
|
||||
|
||||
# Verify result is list of entity strings
|
||||
assert result == ["entity1", "entity2"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_maybe_label_with_cached_label(self):
|
||||
"""Test Query.maybe_label method with cached label"""
|
||||
# Create mock GraphRag with label cache
|
||||
mock_rag = MagicMock()
|
||||
mock_rag.label_cache = {"entity1": "Entity One Label"}
|
||||
|
||||
# Initialize Query
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
verbose=False
|
||||
)
|
||||
|
||||
# Call maybe_label with cached entity
|
||||
result = await query.maybe_label("entity1")
|
||||
|
||||
# Verify cached label is returned
|
||||
assert result == "Entity One Label"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_maybe_label_with_label_lookup(self):
|
||||
"""Test Query.maybe_label method with database label lookup"""
|
||||
# Create mock GraphRag with triples client
|
||||
mock_rag = MagicMock()
|
||||
mock_rag.label_cache = {} # Empty cache
|
||||
mock_triples_client = AsyncMock()
|
||||
mock_rag.triples_client = mock_triples_client
|
||||
|
||||
# Mock triple result with label
|
||||
mock_triple = MagicMock()
|
||||
mock_triple.o = "Human Readable Label"
|
||||
mock_triples_client.query.return_value = [mock_triple]
|
||||
|
||||
# Initialize Query
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
verbose=False
|
||||
)
|
||||
|
||||
# Call maybe_label
|
||||
result = await query.maybe_label("http://example.com/entity")
|
||||
|
||||
# Verify triples client was called correctly
|
||||
mock_triples_client.query.assert_called_once_with(
|
||||
s="http://example.com/entity",
|
||||
p="http://www.w3.org/2000/01/rdf-schema#label",
|
||||
o=None,
|
||||
limit=1,
|
||||
user="test_user",
|
||||
collection="test_collection"
|
||||
)
|
||||
|
||||
# Verify result and cache update
|
||||
assert result == "Human Readable Label"
|
||||
assert mock_rag.label_cache["http://example.com/entity"] == "Human Readable Label"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_maybe_label_with_no_label_found(self):
|
||||
"""Test Query.maybe_label method when no label is found"""
|
||||
# Create mock GraphRag with triples client
|
||||
mock_rag = MagicMock()
|
||||
mock_rag.label_cache = {} # Empty cache
|
||||
mock_triples_client = AsyncMock()
|
||||
mock_rag.triples_client = mock_triples_client
|
||||
|
||||
# Mock empty result (no label found)
|
||||
mock_triples_client.query.return_value = []
|
||||
|
||||
# Initialize Query
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
verbose=False
|
||||
)
|
||||
|
||||
# Call maybe_label
|
||||
result = await query.maybe_label("unlabeled_entity")
|
||||
|
||||
# Verify triples client was called
|
||||
mock_triples_client.query.assert_called_once_with(
|
||||
s="unlabeled_entity",
|
||||
p="http://www.w3.org/2000/01/rdf-schema#label",
|
||||
o=None,
|
||||
limit=1,
|
||||
user="test_user",
|
||||
collection="test_collection"
|
||||
)
|
||||
|
||||
# Verify result is entity itself and cache is updated
|
||||
assert result == "unlabeled_entity"
|
||||
assert mock_rag.label_cache["unlabeled_entity"] == "unlabeled_entity"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_follow_edges_basic_functionality(self):
|
||||
"""Test Query.follow_edges method basic triple discovery"""
|
||||
# Create mock GraphRag with triples client
|
||||
mock_rag = MagicMock()
|
||||
mock_triples_client = AsyncMock()
|
||||
mock_rag.triples_client = mock_triples_client
|
||||
|
||||
# Mock triple results for different query patterns
|
||||
mock_triple1 = MagicMock()
|
||||
mock_triple1.s, mock_triple1.p, mock_triple1.o = "entity1", "predicate1", "object1"
|
||||
|
||||
mock_triple2 = MagicMock()
|
||||
mock_triple2.s, mock_triple2.p, mock_triple2.o = "subject2", "entity1", "object2"
|
||||
|
||||
mock_triple3 = MagicMock()
|
||||
mock_triple3.s, mock_triple3.p, mock_triple3.o = "subject3", "predicate3", "entity1"
|
||||
|
||||
# Setup query responses for s=ent, p=ent, o=ent patterns
|
||||
mock_triples_client.query.side_effect = [
|
||||
[mock_triple1], # s=ent, p=None, o=None
|
||||
[mock_triple2], # s=None, p=ent, o=None
|
||||
[mock_triple3], # s=None, p=None, o=ent
|
||||
]
|
||||
|
||||
# Initialize Query
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
verbose=False,
|
||||
triple_limit=10
|
||||
)
|
||||
|
||||
# Call follow_edges
|
||||
subgraph = set()
|
||||
await query.follow_edges("entity1", subgraph, path_length=1)
|
||||
|
||||
# Verify all three query patterns were called
|
||||
assert mock_triples_client.query.call_count == 3
|
||||
|
||||
# Verify query calls
|
||||
mock_triples_client.query.assert_any_call(
|
||||
s="entity1", p=None, o=None, limit=10,
|
||||
user="test_user", collection="test_collection"
|
||||
)
|
||||
mock_triples_client.query.assert_any_call(
|
||||
s=None, p="entity1", o=None, limit=10,
|
||||
user="test_user", collection="test_collection"
|
||||
)
|
||||
mock_triples_client.query.assert_any_call(
|
||||
s=None, p=None, o="entity1", limit=10,
|
||||
user="test_user", collection="test_collection"
|
||||
)
|
||||
|
||||
# Verify subgraph contains discovered triples
|
||||
expected_subgraph = {
|
||||
("entity1", "predicate1", "object1"),
|
||||
("subject2", "entity1", "object2"),
|
||||
("subject3", "predicate3", "entity1")
|
||||
}
|
||||
assert subgraph == expected_subgraph
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_follow_edges_with_path_length_zero(self):
|
||||
"""Test Query.follow_edges method with path_length=0"""
|
||||
# Create mock GraphRag
|
||||
mock_rag = MagicMock()
|
||||
mock_triples_client = AsyncMock()
|
||||
mock_rag.triples_client = mock_triples_client
|
||||
|
||||
# Initialize Query
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
verbose=False
|
||||
)
|
||||
|
||||
# Call follow_edges with path_length=0
|
||||
subgraph = set()
|
||||
await query.follow_edges("entity1", subgraph, path_length=0)
|
||||
|
||||
# Verify no queries were made
|
||||
mock_triples_client.query.assert_not_called()
|
||||
|
||||
# Verify subgraph remains empty
|
||||
assert subgraph == set()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_follow_edges_with_max_subgraph_size_limit(self):
|
||||
"""Test Query.follow_edges method respects max_subgraph_size"""
|
||||
# Create mock GraphRag
|
||||
mock_rag = MagicMock()
|
||||
mock_triples_client = AsyncMock()
|
||||
mock_rag.triples_client = mock_triples_client
|
||||
|
||||
# Initialize Query with small max_subgraph_size
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
verbose=False,
|
||||
max_subgraph_size=2
|
||||
)
|
||||
|
||||
# Pre-populate subgraph to exceed limit
|
||||
subgraph = {("s1", "p1", "o1"), ("s2", "p2", "o2"), ("s3", "p3", "o3")}
|
||||
|
||||
# Call follow_edges
|
||||
await query.follow_edges("entity1", subgraph, path_length=1)
|
||||
|
||||
# Verify no queries were made due to size limit
|
||||
mock_triples_client.query.assert_not_called()
|
||||
|
||||
# Verify subgraph unchanged
|
||||
assert len(subgraph) == 3
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_subgraph_method(self):
|
||||
"""Test Query.get_subgraph method orchestrates entity and edge discovery"""
|
||||
# Create mock Query that patches get_entities and follow_edges
|
||||
mock_rag = MagicMock()
|
||||
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
verbose=False,
|
||||
max_path_length=1
|
||||
)
|
||||
|
||||
# Mock get_entities to return test entities
|
||||
query.get_entities = AsyncMock(return_value=["entity1", "entity2"])
|
||||
|
||||
# Mock follow_edges to add triples to subgraph
|
||||
async def mock_follow_edges(ent, subgraph, path_length):
|
||||
subgraph.add((ent, "predicate", "object"))
|
||||
|
||||
query.follow_edges = AsyncMock(side_effect=mock_follow_edges)
|
||||
|
||||
# Call get_subgraph
|
||||
result = await query.get_subgraph("test query")
|
||||
|
||||
# Verify get_entities was called
|
||||
query.get_entities.assert_called_once_with("test query")
|
||||
|
||||
# Verify follow_edges was called for each entity
|
||||
assert query.follow_edges.call_count == 2
|
||||
query.follow_edges.assert_any_call("entity1", unittest.mock.ANY, 1)
|
||||
query.follow_edges.assert_any_call("entity2", unittest.mock.ANY, 1)
|
||||
|
||||
# Verify result is list format
|
||||
assert isinstance(result, list)
|
||||
assert len(result) == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_labelgraph_method(self):
|
||||
"""Test Query.get_labelgraph method converts entities to labels"""
|
||||
# Create mock Query
|
||||
mock_rag = MagicMock()
|
||||
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
verbose=False,
|
||||
max_subgraph_size=100
|
||||
)
|
||||
|
||||
# Mock get_subgraph to return test triples
|
||||
test_subgraph = [
|
||||
("entity1", "predicate1", "object1"),
|
||||
("subject2", "http://www.w3.org/2000/01/rdf-schema#label", "Label Value"), # Should be filtered
|
||||
("entity3", "predicate3", "object3")
|
||||
]
|
||||
query.get_subgraph = AsyncMock(return_value=test_subgraph)
|
||||
|
||||
# Mock maybe_label to return human-readable labels
|
||||
async def mock_maybe_label(entity):
|
||||
label_map = {
|
||||
"entity1": "Human Entity One",
|
||||
"predicate1": "Human Predicate One",
|
||||
"object1": "Human Object One",
|
||||
"entity3": "Human Entity Three",
|
||||
"predicate3": "Human Predicate Three",
|
||||
"object3": "Human Object Three"
|
||||
}
|
||||
return label_map.get(entity, entity)
|
||||
|
||||
query.maybe_label = AsyncMock(side_effect=mock_maybe_label)
|
||||
|
||||
# Call get_labelgraph
|
||||
result = await query.get_labelgraph("test query")
|
||||
|
||||
# Verify get_subgraph was called
|
||||
query.get_subgraph.assert_called_once_with("test query")
|
||||
|
||||
# Verify label triples are filtered out
|
||||
assert len(result) == 2 # Label triple should be excluded
|
||||
|
||||
# Verify maybe_label was called for non-label triples
|
||||
expected_calls = [
|
||||
(("entity1",), {}), (("predicate1",), {}), (("object1",), {}),
|
||||
(("entity3",), {}), (("predicate3",), {}), (("object3",), {})
|
||||
]
|
||||
assert query.maybe_label.call_count == 6
|
||||
|
||||
# Verify result contains human-readable labels
|
||||
expected_result = [
|
||||
("Human Entity One", "Human Predicate One", "Human Object One"),
|
||||
("Human Entity Three", "Human Predicate Three", "Human Object Three")
|
||||
]
|
||||
assert result == expected_result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_graph_rag_query_method(self):
|
||||
"""Test GraphRag.query method orchestrates full RAG pipeline"""
|
||||
# Create mock clients
|
||||
mock_prompt_client = AsyncMock()
|
||||
mock_embeddings_client = AsyncMock()
|
||||
mock_graph_embeddings_client = AsyncMock()
|
||||
mock_triples_client = AsyncMock()
|
||||
|
||||
# Mock prompt client response
|
||||
expected_response = "This is the RAG response"
|
||||
mock_prompt_client.kg_prompt.return_value = expected_response
|
||||
|
||||
# Initialize GraphRag
|
||||
graph_rag = GraphRag(
|
||||
prompt_client=mock_prompt_client,
|
||||
embeddings_client=mock_embeddings_client,
|
||||
graph_embeddings_client=mock_graph_embeddings_client,
|
||||
triples_client=mock_triples_client,
|
||||
verbose=False
|
||||
)
|
||||
|
||||
# Mock the Query class behavior by patching get_labelgraph
|
||||
test_labelgraph = [("Subject", "Predicate", "Object")]
|
||||
|
||||
# We need to patch the Query class's get_labelgraph method
|
||||
original_query_init = Query.__init__
|
||||
original_get_labelgraph = Query.get_labelgraph
|
||||
|
||||
def mock_query_init(self, *args, **kwargs):
|
||||
original_query_init(self, *args, **kwargs)
|
||||
|
||||
async def mock_get_labelgraph(self, query_text):
|
||||
return test_labelgraph
|
||||
|
||||
Query.__init__ = mock_query_init
|
||||
Query.get_labelgraph = mock_get_labelgraph
|
||||
|
||||
try:
|
||||
# Call GraphRag.query
|
||||
result = await graph_rag.query(
|
||||
query="test query",
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
entity_limit=25,
|
||||
triple_limit=15
|
||||
)
|
||||
|
||||
# Verify prompt client was called with knowledge graph and query
|
||||
mock_prompt_client.kg_prompt.assert_called_once_with("test query", test_labelgraph)
|
||||
|
||||
# Verify result
|
||||
assert result == expected_response
|
||||
|
||||
finally:
|
||||
# Restore original methods
|
||||
Query.__init__ = original_query_init
|
||||
Query.get_labelgraph = original_get_labelgraph
|
||||
277
tests/unit/test_rev_gateway/test_dispatcher.py
Normal file
277
tests/unit/test_rev_gateway/test_dispatcher.py
Normal file
|
|
@ -0,0 +1,277 @@
|
|||
"""
|
||||
Tests for Reverse Gateway Dispatcher
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, AsyncMock, patch
|
||||
|
||||
from trustgraph.rev_gateway.dispatcher import WebSocketResponder, MessageDispatcher
|
||||
|
||||
|
||||
class TestWebSocketResponder:
|
||||
"""Test cases for WebSocketResponder class"""
|
||||
|
||||
def test_websocket_responder_initialization(self):
|
||||
"""Test WebSocketResponder initialization"""
|
||||
responder = WebSocketResponder()
|
||||
|
||||
assert responder.response is None
|
||||
assert responder.completed is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_websocket_responder_send_method(self):
|
||||
"""Test WebSocketResponder send method"""
|
||||
responder = WebSocketResponder()
|
||||
|
||||
test_response = {"data": "test response"}
|
||||
|
||||
# Call send method
|
||||
await responder.send(test_response)
|
||||
|
||||
# Verify response was stored
|
||||
assert responder.response == test_response
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_websocket_responder_call_method(self):
|
||||
"""Test WebSocketResponder __call__ method"""
|
||||
responder = WebSocketResponder()
|
||||
|
||||
test_response = {"result": "success"}
|
||||
test_completed = True
|
||||
|
||||
# Call the responder
|
||||
await responder(test_response, test_completed)
|
||||
|
||||
# Verify response and completed status were set
|
||||
assert responder.response == test_response
|
||||
assert responder.completed == test_completed
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_websocket_responder_call_method_with_false_completion(self):
|
||||
"""Test WebSocketResponder __call__ method with incomplete response"""
|
||||
responder = WebSocketResponder()
|
||||
|
||||
test_response = {"partial": "data"}
|
||||
test_completed = False
|
||||
|
||||
# Call the responder
|
||||
await responder(test_response, test_completed)
|
||||
|
||||
# Verify response was set and completed is True (since send() always sets completed=True)
|
||||
assert responder.response == test_response
|
||||
assert responder.completed is True
|
||||
|
||||
|
||||
class TestMessageDispatcher:
|
||||
"""Test cases for MessageDispatcher class"""
|
||||
|
||||
def test_message_dispatcher_initialization_with_defaults(self):
|
||||
"""Test MessageDispatcher initialization with default parameters"""
|
||||
dispatcher = MessageDispatcher()
|
||||
|
||||
assert dispatcher.max_workers == 10
|
||||
assert dispatcher.semaphore._value == 10
|
||||
assert dispatcher.active_tasks == set()
|
||||
assert dispatcher.pulsar_client is None
|
||||
assert dispatcher.dispatcher_manager is None
|
||||
assert len(dispatcher.service_mapping) > 0
|
||||
|
||||
def test_message_dispatcher_initialization_with_custom_workers(self):
|
||||
"""Test MessageDispatcher initialization with custom max_workers"""
|
||||
dispatcher = MessageDispatcher(max_workers=5)
|
||||
|
||||
assert dispatcher.max_workers == 5
|
||||
assert dispatcher.semaphore._value == 5
|
||||
|
||||
@patch('trustgraph.rev_gateway.dispatcher.DispatcherManager')
|
||||
def test_message_dispatcher_initialization_with_pulsar_client(self, mock_dispatcher_manager):
|
||||
"""Test MessageDispatcher initialization with pulsar_client and config_receiver"""
|
||||
mock_pulsar_client = MagicMock()
|
||||
mock_config_receiver = MagicMock()
|
||||
mock_dispatcher_instance = MagicMock()
|
||||
mock_dispatcher_manager.return_value = mock_dispatcher_instance
|
||||
|
||||
dispatcher = MessageDispatcher(
|
||||
max_workers=8,
|
||||
config_receiver=mock_config_receiver,
|
||||
pulsar_client=mock_pulsar_client
|
||||
)
|
||||
|
||||
assert dispatcher.max_workers == 8
|
||||
assert dispatcher.pulsar_client == mock_pulsar_client
|
||||
assert dispatcher.dispatcher_manager == mock_dispatcher_instance
|
||||
mock_dispatcher_manager.assert_called_once_with(
|
||||
mock_pulsar_client, mock_config_receiver, prefix="rev-gateway"
|
||||
)
|
||||
|
||||
def test_message_dispatcher_service_mapping(self):
|
||||
"""Test MessageDispatcher service mapping contains expected services"""
|
||||
dispatcher = MessageDispatcher()
|
||||
|
||||
expected_services = [
|
||||
"text-completion", "graph-rag", "agent", "embeddings",
|
||||
"graph-embeddings", "triples", "document-load", "text-load",
|
||||
"flow", "knowledge", "config", "librarian", "document-rag"
|
||||
]
|
||||
|
||||
for service in expected_services:
|
||||
assert service in dispatcher.service_mapping
|
||||
|
||||
# Test specific mappings
|
||||
assert dispatcher.service_mapping["text-completion"] == "text-completion"
|
||||
assert dispatcher.service_mapping["document-load"] == "document"
|
||||
assert dispatcher.service_mapping["text-load"] == "text-document"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_message_dispatcher_handle_message_without_dispatcher_manager(self):
|
||||
"""Test MessageDispatcher handle_message without dispatcher manager"""
|
||||
dispatcher = MessageDispatcher()
|
||||
|
||||
test_message = {
|
||||
"id": "test-123",
|
||||
"service": "test-service",
|
||||
"request": {"data": "test"}
|
||||
}
|
||||
|
||||
result = await dispatcher.handle_message(test_message)
|
||||
|
||||
assert result["id"] == "test-123"
|
||||
assert "error" in result["response"]
|
||||
assert "DispatcherManager not available" in result["response"]["error"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_message_dispatcher_handle_message_with_exception(self):
|
||||
"""Test MessageDispatcher handle_message with exception during processing"""
|
||||
mock_dispatcher_manager = MagicMock()
|
||||
mock_dispatcher_manager.invoke_global_service = AsyncMock(side_effect=Exception("Test error"))
|
||||
|
||||
dispatcher = MessageDispatcher()
|
||||
dispatcher.dispatcher_manager = mock_dispatcher_manager
|
||||
|
||||
test_message = {
|
||||
"id": "test-456",
|
||||
"service": "text-completion",
|
||||
"request": {"prompt": "test"}
|
||||
}
|
||||
|
||||
with patch('trustgraph.gateway.dispatch.manager.global_dispatchers', {"text-completion": True}):
|
||||
result = await dispatcher.handle_message(test_message)
|
||||
|
||||
assert result["id"] == "test-456"
|
||||
assert "error" in result["response"]
|
||||
assert "Test error" in result["response"]["error"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_message_dispatcher_handle_message_global_service(self):
|
||||
"""Test MessageDispatcher handle_message with global service"""
|
||||
mock_dispatcher_manager = MagicMock()
|
||||
mock_dispatcher_manager.invoke_global_service = AsyncMock()
|
||||
mock_responder = MagicMock()
|
||||
mock_responder.completed = True
|
||||
mock_responder.response = {"result": "success"}
|
||||
|
||||
dispatcher = MessageDispatcher()
|
||||
dispatcher.dispatcher_manager = mock_dispatcher_manager
|
||||
|
||||
test_message = {
|
||||
"id": "test-789",
|
||||
"service": "text-completion",
|
||||
"request": {"prompt": "hello"}
|
||||
}
|
||||
|
||||
with patch('trustgraph.gateway.dispatch.manager.global_dispatchers', {"text-completion": True}):
|
||||
with patch('trustgraph.rev_gateway.dispatcher.WebSocketResponder', return_value=mock_responder):
|
||||
result = await dispatcher.handle_message(test_message)
|
||||
|
||||
assert result["id"] == "test-789"
|
||||
assert result["response"] == {"result": "success"}
|
||||
mock_dispatcher_manager.invoke_global_service.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_message_dispatcher_handle_message_flow_service(self):
|
||||
"""Test MessageDispatcher handle_message with flow service"""
|
||||
mock_dispatcher_manager = MagicMock()
|
||||
mock_dispatcher_manager.invoke_flow_service = AsyncMock()
|
||||
mock_responder = MagicMock()
|
||||
mock_responder.completed = True
|
||||
mock_responder.response = {"data": "flow_result"}
|
||||
|
||||
dispatcher = MessageDispatcher()
|
||||
dispatcher.dispatcher_manager = mock_dispatcher_manager
|
||||
|
||||
test_message = {
|
||||
"id": "test-flow-123",
|
||||
"service": "document-rag",
|
||||
"request": {"query": "test"},
|
||||
"flow": "custom-flow"
|
||||
}
|
||||
|
||||
with patch('trustgraph.gateway.dispatch.manager.global_dispatchers', {}):
|
||||
with patch('trustgraph.rev_gateway.dispatcher.WebSocketResponder', return_value=mock_responder):
|
||||
result = await dispatcher.handle_message(test_message)
|
||||
|
||||
assert result["id"] == "test-flow-123"
|
||||
assert result["response"] == {"data": "flow_result"}
|
||||
mock_dispatcher_manager.invoke_flow_service.assert_called_once_with(
|
||||
{"query": "test"}, mock_responder, "custom-flow", "document-rag"
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_message_dispatcher_handle_message_incomplete_response(self):
|
||||
"""Test MessageDispatcher handle_message with incomplete response"""
|
||||
mock_dispatcher_manager = MagicMock()
|
||||
mock_dispatcher_manager.invoke_flow_service = AsyncMock()
|
||||
mock_responder = MagicMock()
|
||||
mock_responder.completed = False
|
||||
mock_responder.response = None
|
||||
|
||||
dispatcher = MessageDispatcher()
|
||||
dispatcher.dispatcher_manager = mock_dispatcher_manager
|
||||
|
||||
test_message = {
|
||||
"id": "test-incomplete",
|
||||
"service": "agent",
|
||||
"request": {"input": "test"}
|
||||
}
|
||||
|
||||
with patch('trustgraph.gateway.dispatch.manager.global_dispatchers', {}):
|
||||
with patch('trustgraph.rev_gateway.dispatcher.WebSocketResponder', return_value=mock_responder):
|
||||
result = await dispatcher.handle_message(test_message)
|
||||
|
||||
assert result["id"] == "test-incomplete"
|
||||
assert result["response"] == {"error": "No response received"}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_message_dispatcher_shutdown(self):
|
||||
"""Test MessageDispatcher shutdown method"""
|
||||
import asyncio
|
||||
|
||||
dispatcher = MessageDispatcher()
|
||||
|
||||
# Create actual async tasks
|
||||
async def dummy_task():
|
||||
await asyncio.sleep(0.01)
|
||||
return "done"
|
||||
|
||||
task1 = asyncio.create_task(dummy_task())
|
||||
task2 = asyncio.create_task(dummy_task())
|
||||
dispatcher.active_tasks = {task1, task2}
|
||||
|
||||
# Call shutdown
|
||||
await dispatcher.shutdown()
|
||||
|
||||
# Verify tasks were completed
|
||||
assert task1.done()
|
||||
assert task2.done()
|
||||
assert len(dispatcher.active_tasks) == 2 # Tasks remain in set but are completed
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_message_dispatcher_shutdown_with_no_tasks(self):
|
||||
"""Test MessageDispatcher shutdown with no active tasks"""
|
||||
dispatcher = MessageDispatcher()
|
||||
|
||||
# Call shutdown with no active tasks
|
||||
await dispatcher.shutdown()
|
||||
|
||||
# Should complete without error
|
||||
assert dispatcher.active_tasks == set()
|
||||
545
tests/unit/test_rev_gateway/test_rev_gateway_service.py
Normal file
545
tests/unit/test_rev_gateway/test_rev_gateway_service.py
Normal file
|
|
@ -0,0 +1,545 @@
|
|||
"""
|
||||
Tests for Reverse Gateway Service
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import asyncio
|
||||
from unittest.mock import MagicMock, AsyncMock, patch, Mock
|
||||
from aiohttp import WSMsgType, ClientWebSocketResponse
|
||||
import json
|
||||
|
||||
from trustgraph.rev_gateway.service import ReverseGateway, parse_args, run
|
||||
|
||||
|
||||
class TestReverseGateway:
|
||||
"""Test cases for ReverseGateway class"""
|
||||
|
||||
@patch('trustgraph.rev_gateway.service.ConfigReceiver')
|
||||
@patch('trustgraph.rev_gateway.service.MessageDispatcher')
|
||||
@patch('pulsar.Client')
|
||||
def test_reverse_gateway_initialization_defaults(self, mock_pulsar_client, mock_dispatcher, mock_config_receiver):
|
||||
"""Test ReverseGateway initialization with default parameters"""
|
||||
mock_client_instance = MagicMock()
|
||||
mock_pulsar_client.return_value = mock_client_instance
|
||||
|
||||
gateway = ReverseGateway()
|
||||
|
||||
assert gateway.websocket_uri == "ws://localhost:7650/out"
|
||||
assert gateway.host == "localhost"
|
||||
assert gateway.port == 7650
|
||||
assert gateway.scheme == "ws"
|
||||
assert gateway.path == "/out"
|
||||
assert gateway.url == "ws://localhost:7650/out"
|
||||
assert gateway.max_workers == 10
|
||||
assert gateway.running is False
|
||||
assert gateway.reconnect_delay == 3.0
|
||||
assert gateway.pulsar_host == "pulsar://pulsar:6650"
|
||||
assert gateway.pulsar_api_key is None
|
||||
|
||||
@patch('trustgraph.rev_gateway.service.ConfigReceiver')
|
||||
@patch('trustgraph.rev_gateway.service.MessageDispatcher')
|
||||
@patch('pulsar.Client')
|
||||
def test_reverse_gateway_initialization_custom_params(self, mock_pulsar_client, mock_dispatcher, mock_config_receiver):
|
||||
"""Test ReverseGateway initialization with custom parameters"""
|
||||
mock_client_instance = MagicMock()
|
||||
mock_pulsar_client.return_value = mock_client_instance
|
||||
|
||||
gateway = ReverseGateway(
|
||||
websocket_uri="wss://example.com:8080/websocket",
|
||||
max_workers=20,
|
||||
pulsar_host="pulsar://custom:6650",
|
||||
pulsar_api_key="test-key",
|
||||
pulsar_listener="test-listener"
|
||||
)
|
||||
|
||||
assert gateway.websocket_uri == "wss://example.com:8080/websocket"
|
||||
assert gateway.host == "example.com"
|
||||
assert gateway.port == 8080
|
||||
assert gateway.scheme == "wss"
|
||||
assert gateway.path == "/websocket"
|
||||
assert gateway.url == "wss://example.com:8080/websocket"
|
||||
assert gateway.max_workers == 20
|
||||
assert gateway.pulsar_host == "pulsar://custom:6650"
|
||||
assert gateway.pulsar_api_key == "test-key"
|
||||
assert gateway.pulsar_listener == "test-listener"
|
||||
|
||||
@patch('trustgraph.rev_gateway.service.ConfigReceiver')
|
||||
@patch('trustgraph.rev_gateway.service.MessageDispatcher')
|
||||
@patch('pulsar.Client')
|
||||
def test_reverse_gateway_initialization_with_missing_path(self, mock_pulsar_client, mock_dispatcher, mock_config_receiver):
|
||||
"""Test ReverseGateway initialization with WebSocket URI missing path"""
|
||||
mock_client_instance = MagicMock()
|
||||
mock_pulsar_client.return_value = mock_client_instance
|
||||
|
||||
gateway = ReverseGateway(websocket_uri="ws://example.com")
|
||||
|
||||
assert gateway.path == "/ws"
|
||||
assert gateway.url == "ws://example.com/ws"
|
||||
|
||||
@patch('trustgraph.rev_gateway.service.ConfigReceiver')
|
||||
@patch('trustgraph.rev_gateway.service.MessageDispatcher')
|
||||
@patch('pulsar.Client')
|
||||
def test_reverse_gateway_initialization_invalid_scheme(self, mock_pulsar_client, mock_dispatcher, mock_config_receiver):
|
||||
"""Test ReverseGateway initialization with invalid WebSocket scheme"""
|
||||
with pytest.raises(ValueError, match="WebSocket URI must use ws:// or wss:// scheme"):
|
||||
ReverseGateway(websocket_uri="http://example.com")
|
||||
|
||||
@patch('trustgraph.rev_gateway.service.ConfigReceiver')
|
||||
@patch('trustgraph.rev_gateway.service.MessageDispatcher')
|
||||
@patch('pulsar.Client')
|
||||
def test_reverse_gateway_initialization_missing_hostname(self, mock_pulsar_client, mock_dispatcher, mock_config_receiver):
|
||||
"""Test ReverseGateway initialization with missing hostname"""
|
||||
with pytest.raises(ValueError, match="WebSocket URI must include hostname"):
|
||||
ReverseGateway(websocket_uri="ws://")
|
||||
|
||||
@patch('trustgraph.rev_gateway.service.ConfigReceiver')
|
||||
@patch('trustgraph.rev_gateway.service.MessageDispatcher')
|
||||
@patch('pulsar.Client')
|
||||
def test_reverse_gateway_pulsar_client_with_auth(self, mock_pulsar_client, mock_dispatcher, mock_config_receiver):
|
||||
"""Test ReverseGateway creates Pulsar client with authentication"""
|
||||
mock_client_instance = MagicMock()
|
||||
mock_pulsar_client.return_value = mock_client_instance
|
||||
|
||||
with patch('pulsar.AuthenticationToken') as mock_auth:
|
||||
mock_auth_instance = MagicMock()
|
||||
mock_auth.return_value = mock_auth_instance
|
||||
|
||||
gateway = ReverseGateway(
|
||||
pulsar_api_key="test-key",
|
||||
pulsar_listener="test-listener"
|
||||
)
|
||||
|
||||
mock_auth.assert_called_once_with("test-key")
|
||||
mock_pulsar_client.assert_called_once_with(
|
||||
"pulsar://pulsar:6650",
|
||||
listener_name="test-listener",
|
||||
authentication=mock_auth_instance
|
||||
)
|
||||
|
||||
@patch('trustgraph.rev_gateway.service.ConfigReceiver')
|
||||
@patch('trustgraph.rev_gateway.service.MessageDispatcher')
|
||||
@patch('pulsar.Client')
|
||||
@patch('trustgraph.rev_gateway.service.ClientSession')
|
||||
@pytest.mark.asyncio
|
||||
async def test_reverse_gateway_connect_success(self, mock_session_class, mock_pulsar_client, mock_dispatcher, mock_config_receiver):
|
||||
"""Test ReverseGateway successful connection"""
|
||||
mock_client_instance = MagicMock()
|
||||
mock_pulsar_client.return_value = mock_client_instance
|
||||
|
||||
mock_session = AsyncMock()
|
||||
mock_ws = AsyncMock()
|
||||
mock_session.ws_connect.return_value = mock_ws
|
||||
mock_session_class.return_value = mock_session
|
||||
|
||||
gateway = ReverseGateway()
|
||||
|
||||
result = await gateway.connect()
|
||||
|
||||
assert result is True
|
||||
assert gateway.session == mock_session
|
||||
assert gateway.ws == mock_ws
|
||||
mock_session.ws_connect.assert_called_once_with(gateway.url)
|
||||
|
||||
@patch('trustgraph.rev_gateway.service.ConfigReceiver')
|
||||
@patch('trustgraph.rev_gateway.service.MessageDispatcher')
|
||||
@patch('pulsar.Client')
|
||||
@patch('trustgraph.rev_gateway.service.ClientSession')
|
||||
@pytest.mark.asyncio
|
||||
async def test_reverse_gateway_connect_failure(self, mock_session_class, mock_pulsar_client, mock_dispatcher, mock_config_receiver):
|
||||
"""Test ReverseGateway connection failure"""
|
||||
mock_client_instance = MagicMock()
|
||||
mock_pulsar_client.return_value = mock_client_instance
|
||||
|
||||
mock_session = AsyncMock()
|
||||
mock_session.ws_connect.side_effect = Exception("Connection failed")
|
||||
mock_session_class.return_value = mock_session
|
||||
|
||||
gateway = ReverseGateway()
|
||||
|
||||
result = await gateway.connect()
|
||||
|
||||
assert result is False
|
||||
|
||||
@patch('trustgraph.rev_gateway.service.ConfigReceiver')
|
||||
@patch('trustgraph.rev_gateway.service.MessageDispatcher')
|
||||
@patch('pulsar.Client')
|
||||
@pytest.mark.asyncio
|
||||
async def test_reverse_gateway_disconnect(self, mock_pulsar_client, mock_dispatcher, mock_config_receiver):
|
||||
"""Test ReverseGateway disconnect"""
|
||||
mock_client_instance = MagicMock()
|
||||
mock_pulsar_client.return_value = mock_client_instance
|
||||
|
||||
gateway = ReverseGateway()
|
||||
|
||||
# Mock websocket and session
|
||||
mock_ws = AsyncMock()
|
||||
mock_ws.closed = False
|
||||
mock_session = AsyncMock()
|
||||
mock_session.closed = False
|
||||
|
||||
gateway.ws = mock_ws
|
||||
gateway.session = mock_session
|
||||
|
||||
await gateway.disconnect()
|
||||
|
||||
mock_ws.close.assert_called_once()
|
||||
mock_session.close.assert_called_once()
|
||||
assert gateway.ws is None
|
||||
assert gateway.session is None
|
||||
|
||||
@patch('trustgraph.rev_gateway.service.ConfigReceiver')
|
||||
@patch('trustgraph.rev_gateway.service.MessageDispatcher')
|
||||
@patch('pulsar.Client')
|
||||
@pytest.mark.asyncio
|
||||
async def test_reverse_gateway_send_message(self, mock_pulsar_client, mock_dispatcher, mock_config_receiver):
|
||||
"""Test ReverseGateway send message"""
|
||||
mock_client_instance = MagicMock()
|
||||
mock_pulsar_client.return_value = mock_client_instance
|
||||
|
||||
gateway = ReverseGateway()
|
||||
|
||||
# Mock websocket
|
||||
mock_ws = AsyncMock()
|
||||
mock_ws.closed = False
|
||||
gateway.ws = mock_ws
|
||||
|
||||
test_message = {"id": "test", "data": "hello"}
|
||||
|
||||
await gateway.send_message(test_message)
|
||||
|
||||
mock_ws.send_str.assert_called_once_with(json.dumps(test_message))
|
||||
|
||||
@patch('trustgraph.rev_gateway.service.ConfigReceiver')
|
||||
@patch('trustgraph.rev_gateway.service.MessageDispatcher')
|
||||
@patch('pulsar.Client')
|
||||
@pytest.mark.asyncio
|
||||
async def test_reverse_gateway_send_message_closed_connection(self, mock_pulsar_client, mock_dispatcher, mock_config_receiver):
|
||||
"""Test ReverseGateway send message with closed connection"""
|
||||
mock_client_instance = MagicMock()
|
||||
mock_pulsar_client.return_value = mock_client_instance
|
||||
|
||||
gateway = ReverseGateway()
|
||||
|
||||
# Mock closed websocket
|
||||
mock_ws = AsyncMock()
|
||||
mock_ws.closed = True
|
||||
gateway.ws = mock_ws
|
||||
|
||||
test_message = {"id": "test", "data": "hello"}
|
||||
|
||||
await gateway.send_message(test_message)
|
||||
|
||||
# Should not call send_str on closed connection
|
||||
mock_ws.send_str.assert_not_called()
|
||||
|
||||
@patch('trustgraph.rev_gateway.service.ConfigReceiver')
|
||||
@patch('trustgraph.rev_gateway.service.MessageDispatcher')
|
||||
@patch('pulsar.Client')
|
||||
@pytest.mark.asyncio
|
||||
async def test_reverse_gateway_handle_message(self, mock_pulsar_client, mock_dispatcher, mock_config_receiver):
|
||||
"""Test ReverseGateway handle message"""
|
||||
mock_client_instance = MagicMock()
|
||||
mock_pulsar_client.return_value = mock_client_instance
|
||||
|
||||
mock_dispatcher_instance = AsyncMock()
|
||||
mock_dispatcher_instance.handle_message.return_value = {"response": "success"}
|
||||
mock_dispatcher.return_value = mock_dispatcher_instance
|
||||
|
||||
gateway = ReverseGateway()
|
||||
|
||||
# Mock send_message
|
||||
gateway.send_message = AsyncMock()
|
||||
|
||||
test_message = '{"id": "test", "service": "test-service", "request": {"data": "test"}}'
|
||||
|
||||
await gateway.handle_message(test_message)
|
||||
|
||||
mock_dispatcher_instance.handle_message.assert_called_once_with({
|
||||
"id": "test",
|
||||
"service": "test-service",
|
||||
"request": {"data": "test"}
|
||||
})
|
||||
gateway.send_message.assert_called_once_with({"response": "success"})
|
||||
|
||||
@patch('trustgraph.rev_gateway.service.ConfigReceiver')
|
||||
@patch('trustgraph.rev_gateway.service.MessageDispatcher')
|
||||
@patch('pulsar.Client')
|
||||
@pytest.mark.asyncio
|
||||
async def test_reverse_gateway_handle_message_invalid_json(self, mock_pulsar_client, mock_dispatcher, mock_config_receiver):
|
||||
"""Test ReverseGateway handle message with invalid JSON"""
|
||||
mock_client_instance = MagicMock()
|
||||
mock_pulsar_client.return_value = mock_client_instance
|
||||
|
||||
gateway = ReverseGateway()
|
||||
|
||||
# Mock send_message
|
||||
gateway.send_message = AsyncMock()
|
||||
|
||||
test_message = 'invalid json'
|
||||
|
||||
# Should not raise exception
|
||||
await gateway.handle_message(test_message)
|
||||
|
||||
# Should not call send_message due to error
|
||||
gateway.send_message.assert_not_called()
|
||||
|
||||
@patch('trustgraph.rev_gateway.service.ConfigReceiver')
|
||||
@patch('trustgraph.rev_gateway.service.MessageDispatcher')
|
||||
@patch('pulsar.Client')
|
||||
@pytest.mark.asyncio
|
||||
async def test_reverse_gateway_listen_text_message(self, mock_pulsar_client, mock_dispatcher, mock_config_receiver):
|
||||
"""Test ReverseGateway listen with text message"""
|
||||
mock_client_instance = MagicMock()
|
||||
mock_pulsar_client.return_value = mock_client_instance
|
||||
|
||||
gateway = ReverseGateway()
|
||||
gateway.running = True
|
||||
|
||||
# Mock websocket
|
||||
mock_ws = AsyncMock()
|
||||
mock_ws.closed = False
|
||||
gateway.ws = mock_ws
|
||||
|
||||
# Mock handle_message
|
||||
gateway.handle_message = AsyncMock()
|
||||
|
||||
# Mock message
|
||||
mock_msg = MagicMock()
|
||||
mock_msg.type = WSMsgType.TEXT
|
||||
mock_msg.data = '{"test": "message"}'
|
||||
|
||||
# Mock receive to return message once, then raise exception to stop loop
|
||||
mock_ws.receive.side_effect = [mock_msg, Exception("Test stop")]
|
||||
|
||||
# listen() catches exceptions and breaks, so no exception should be raised
|
||||
await gateway.listen()
|
||||
|
||||
gateway.handle_message.assert_called_once_with('{"test": "message"}')
|
||||
|
||||
@patch('trustgraph.rev_gateway.service.ConfigReceiver')
|
||||
@patch('trustgraph.rev_gateway.service.MessageDispatcher')
|
||||
@patch('pulsar.Client')
|
||||
@pytest.mark.asyncio
|
||||
async def test_reverse_gateway_listen_binary_message(self, mock_pulsar_client, mock_dispatcher, mock_config_receiver):
|
||||
"""Test ReverseGateway listen with binary message"""
|
||||
mock_client_instance = MagicMock()
|
||||
mock_pulsar_client.return_value = mock_client_instance
|
||||
|
||||
gateway = ReverseGateway()
|
||||
gateway.running = True
|
||||
|
||||
# Mock websocket
|
||||
mock_ws = AsyncMock()
|
||||
mock_ws.closed = False
|
||||
gateway.ws = mock_ws
|
||||
|
||||
# Mock handle_message
|
||||
gateway.handle_message = AsyncMock()
|
||||
|
||||
# Mock message
|
||||
mock_msg = MagicMock()
|
||||
mock_msg.type = WSMsgType.BINARY
|
||||
mock_msg.data = b'{"test": "binary"}'
|
||||
|
||||
# Mock receive to return message once, then raise exception to stop loop
|
||||
mock_ws.receive.side_effect = [mock_msg, Exception("Test stop")]
|
||||
|
||||
# listen() catches exceptions and breaks, so no exception should be raised
|
||||
await gateway.listen()
|
||||
|
||||
gateway.handle_message.assert_called_once_with('{"test": "binary"}')
|
||||
|
||||
@patch('trustgraph.rev_gateway.service.ConfigReceiver')
|
||||
@patch('trustgraph.rev_gateway.service.MessageDispatcher')
|
||||
@patch('pulsar.Client')
|
||||
@pytest.mark.asyncio
|
||||
async def test_reverse_gateway_listen_close_message(self, mock_pulsar_client, mock_dispatcher, mock_config_receiver):
|
||||
"""Test ReverseGateway listen with close message"""
|
||||
mock_client_instance = MagicMock()
|
||||
mock_pulsar_client.return_value = mock_client_instance
|
||||
|
||||
gateway = ReverseGateway()
|
||||
gateway.running = True
|
||||
|
||||
# Mock websocket
|
||||
mock_ws = AsyncMock()
|
||||
mock_ws.closed = False
|
||||
gateway.ws = mock_ws
|
||||
|
||||
# Mock handle_message
|
||||
gateway.handle_message = AsyncMock()
|
||||
|
||||
# Mock message
|
||||
mock_msg = MagicMock()
|
||||
mock_msg.type = WSMsgType.CLOSE
|
||||
|
||||
# Mock receive to return close message
|
||||
mock_ws.receive.return_value = mock_msg
|
||||
|
||||
await gateway.listen()
|
||||
|
||||
# Should not call handle_message for close message
|
||||
gateway.handle_message.assert_not_called()
|
||||
|
||||
@patch('trustgraph.rev_gateway.service.ConfigReceiver')
|
||||
@patch('trustgraph.rev_gateway.service.MessageDispatcher')
|
||||
@patch('pulsar.Client')
|
||||
@pytest.mark.asyncio
|
||||
async def test_reverse_gateway_shutdown(self, mock_pulsar_client, mock_dispatcher, mock_config_receiver):
|
||||
"""Test ReverseGateway shutdown"""
|
||||
mock_client_instance = MagicMock()
|
||||
mock_pulsar_client.return_value = mock_client_instance
|
||||
|
||||
mock_dispatcher_instance = AsyncMock()
|
||||
mock_dispatcher.return_value = mock_dispatcher_instance
|
||||
|
||||
gateway = ReverseGateway()
|
||||
gateway.running = True
|
||||
|
||||
# Mock disconnect
|
||||
gateway.disconnect = AsyncMock()
|
||||
|
||||
await gateway.shutdown()
|
||||
|
||||
assert gateway.running is False
|
||||
mock_dispatcher_instance.shutdown.assert_called_once()
|
||||
gateway.disconnect.assert_called_once()
|
||||
mock_client_instance.close.assert_called_once()
|
||||
|
||||
@patch('trustgraph.rev_gateway.service.ConfigReceiver')
|
||||
@patch('trustgraph.rev_gateway.service.MessageDispatcher')
|
||||
@patch('pulsar.Client')
|
||||
def test_reverse_gateway_stop(self, mock_pulsar_client, mock_dispatcher, mock_config_receiver):
|
||||
"""Test ReverseGateway stop"""
|
||||
mock_client_instance = MagicMock()
|
||||
mock_pulsar_client.return_value = mock_client_instance
|
||||
|
||||
gateway = ReverseGateway()
|
||||
gateway.running = True
|
||||
|
||||
gateway.stop()
|
||||
|
||||
assert gateway.running is False
|
||||
|
||||
|
||||
class TestReverseGatewayRun:
|
||||
"""Test cases for ReverseGateway run method"""
|
||||
|
||||
@patch('trustgraph.rev_gateway.service.ConfigReceiver')
|
||||
@patch('trustgraph.rev_gateway.service.MessageDispatcher')
|
||||
@patch('pulsar.Client')
|
||||
@pytest.mark.asyncio
|
||||
async def test_reverse_gateway_run_successful_cycle(self, mock_pulsar_client, mock_dispatcher, mock_config_receiver):
|
||||
"""Test ReverseGateway run method with successful connect/listen cycle"""
|
||||
mock_client_instance = MagicMock()
|
||||
mock_pulsar_client.return_value = mock_client_instance
|
||||
|
||||
mock_config_receiver_instance = AsyncMock()
|
||||
mock_config_receiver.return_value = mock_config_receiver_instance
|
||||
|
||||
gateway = ReverseGateway()
|
||||
|
||||
# Mock methods
|
||||
gateway.connect = AsyncMock(return_value=True)
|
||||
gateway.listen = AsyncMock()
|
||||
gateway.disconnect = AsyncMock()
|
||||
gateway.shutdown = AsyncMock()
|
||||
|
||||
# Stop after one iteration
|
||||
call_count = 0
|
||||
async def mock_connect():
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count == 1:
|
||||
return True
|
||||
else:
|
||||
gateway.running = False
|
||||
return False
|
||||
|
||||
gateway.connect = mock_connect
|
||||
|
||||
await gateway.run()
|
||||
|
||||
mock_config_receiver_instance.start.assert_called_once()
|
||||
gateway.listen.assert_called_once()
|
||||
# disconnect is called twice: once in the main loop, once in shutdown
|
||||
assert gateway.disconnect.call_count == 2
|
||||
gateway.shutdown.assert_called_once()
|
||||
|
||||
|
||||
class TestReverseGatewayArgs:
|
||||
"""Test cases for argument parsing and run function"""
|
||||
|
||||
def test_parse_args_defaults(self):
|
||||
"""Test parse_args with default values"""
|
||||
import sys
|
||||
|
||||
# Mock sys.argv
|
||||
original_argv = sys.argv
|
||||
sys.argv = ['reverse-gateway']
|
||||
|
||||
try:
|
||||
args = parse_args()
|
||||
|
||||
assert args.websocket_uri is None
|
||||
assert args.max_workers == 10
|
||||
assert args.pulsar_host is None
|
||||
assert args.pulsar_api_key is None
|
||||
assert args.pulsar_listener is None
|
||||
finally:
|
||||
sys.argv = original_argv
|
||||
|
||||
def test_parse_args_custom_values(self):
|
||||
"""Test parse_args with custom values"""
|
||||
import sys
|
||||
|
||||
# Mock sys.argv
|
||||
original_argv = sys.argv
|
||||
sys.argv = [
|
||||
'reverse-gateway',
|
||||
'--websocket-uri', 'ws://custom:8080/ws',
|
||||
'--max-workers', '20',
|
||||
'--pulsar-host', 'pulsar://custom:6650',
|
||||
'--pulsar-api-key', 'test-key',
|
||||
'--pulsar-listener', 'test-listener'
|
||||
]
|
||||
|
||||
try:
|
||||
args = parse_args()
|
||||
|
||||
assert args.websocket_uri == 'ws://custom:8080/ws'
|
||||
assert args.max_workers == 20
|
||||
assert args.pulsar_host == 'pulsar://custom:6650'
|
||||
assert args.pulsar_api_key == 'test-key'
|
||||
assert args.pulsar_listener == 'test-listener'
|
||||
finally:
|
||||
sys.argv = original_argv
|
||||
|
||||
@patch('trustgraph.rev_gateway.service.ReverseGateway')
|
||||
@patch('asyncio.run')
|
||||
def test_run_function(self, mock_asyncio_run, mock_gateway_class):
|
||||
"""Test run function"""
|
||||
import sys
|
||||
|
||||
# Mock sys.argv
|
||||
original_argv = sys.argv
|
||||
sys.argv = ['reverse-gateway', '--max-workers', '15']
|
||||
|
||||
try:
|
||||
mock_gateway_instance = MagicMock()
|
||||
mock_gateway_instance.url = "ws://localhost:7650/out"
|
||||
mock_gateway_instance.pulsar_host = "pulsar://pulsar:6650"
|
||||
mock_gateway_class.return_value = mock_gateway_instance
|
||||
|
||||
run()
|
||||
|
||||
mock_gateway_class.assert_called_once_with(
|
||||
websocket_uri=None,
|
||||
max_workers=15,
|
||||
pulsar_host=None,
|
||||
pulsar_api_key=None,
|
||||
pulsar_listener=None
|
||||
)
|
||||
mock_asyncio_run.assert_called_once_with(mock_gateway_instance.run())
|
||||
finally:
|
||||
sys.argv = original_argv
|
||||
162
tests/unit/test_storage/conftest.py
Normal file
162
tests/unit/test_storage/conftest.py
Normal file
|
|
@ -0,0 +1,162 @@
|
|||
"""
|
||||
Shared fixtures for storage tests
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def base_storage_config():
|
||||
"""Base configuration for storage processors"""
|
||||
return {
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-storage-processor'
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def qdrant_storage_config(base_storage_config):
|
||||
"""Configuration for Qdrant storage processors"""
|
||||
return base_storage_config | {
|
||||
'store_uri': 'http://localhost:6333',
|
||||
'api_key': 'test-api-key'
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_qdrant_client():
|
||||
"""Mock Qdrant client"""
|
||||
mock_client = MagicMock()
|
||||
mock_client.collection_exists.return_value = True
|
||||
mock_client.create_collection.return_value = None
|
||||
mock_client.upsert.return_value = None
|
||||
return mock_client
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_uuid():
|
||||
"""Mock UUID generation"""
|
||||
mock_uuid = MagicMock()
|
||||
mock_uuid.uuid4.return_value = MagicMock()
|
||||
mock_uuid.uuid4.return_value.__str__ = MagicMock(return_value='test-uuid-123')
|
||||
return mock_uuid
|
||||
|
||||
|
||||
# Document embeddings fixtures
|
||||
@pytest.fixture
|
||||
def mock_document_embeddings_message():
|
||||
"""Mock document embeddings message"""
|
||||
mock_message = MagicMock()
|
||||
mock_message.metadata.user = 'test_user'
|
||||
mock_message.metadata.collection = 'test_collection'
|
||||
|
||||
mock_chunk = MagicMock()
|
||||
mock_chunk.chunk.decode.return_value = 'test document chunk'
|
||||
mock_chunk.vectors = [[0.1, 0.2, 0.3]]
|
||||
|
||||
mock_message.chunks = [mock_chunk]
|
||||
return mock_message
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_document_embeddings_multiple_chunks():
|
||||
"""Mock document embeddings message with multiple chunks"""
|
||||
mock_message = MagicMock()
|
||||
mock_message.metadata.user = 'multi_user'
|
||||
mock_message.metadata.collection = 'multi_collection'
|
||||
|
||||
mock_chunk1 = MagicMock()
|
||||
mock_chunk1.chunk.decode.return_value = 'first document chunk'
|
||||
mock_chunk1.vectors = [[0.1, 0.2]]
|
||||
|
||||
mock_chunk2 = MagicMock()
|
||||
mock_chunk2.chunk.decode.return_value = 'second document chunk'
|
||||
mock_chunk2.vectors = [[0.3, 0.4]]
|
||||
|
||||
mock_message.chunks = [mock_chunk1, mock_chunk2]
|
||||
return mock_message
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_document_embeddings_multiple_vectors():
|
||||
"""Mock document embeddings message with multiple vectors per chunk"""
|
||||
mock_message = MagicMock()
|
||||
mock_message.metadata.user = 'vector_user'
|
||||
mock_message.metadata.collection = 'vector_collection'
|
||||
|
||||
mock_chunk = MagicMock()
|
||||
mock_chunk.chunk.decode.return_value = 'multi-vector document chunk'
|
||||
mock_chunk.vectors = [
|
||||
[0.1, 0.2, 0.3],
|
||||
[0.4, 0.5, 0.6],
|
||||
[0.7, 0.8, 0.9]
|
||||
]
|
||||
|
||||
mock_message.chunks = [mock_chunk]
|
||||
return mock_message
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_document_embeddings_empty_chunk():
|
||||
"""Mock document embeddings message with empty chunk"""
|
||||
mock_message = MagicMock()
|
||||
mock_message.metadata.user = 'empty_user'
|
||||
mock_message.metadata.collection = 'empty_collection'
|
||||
|
||||
mock_chunk = MagicMock()
|
||||
mock_chunk.chunk.decode.return_value = "" # Empty string
|
||||
mock_chunk.vectors = [[0.1, 0.2]]
|
||||
|
||||
mock_message.chunks = [mock_chunk]
|
||||
return mock_message
|
||||
|
||||
|
||||
# Graph embeddings fixtures
|
||||
@pytest.fixture
|
||||
def mock_graph_embeddings_message():
|
||||
"""Mock graph embeddings message"""
|
||||
mock_message = MagicMock()
|
||||
mock_message.metadata.user = 'test_user'
|
||||
mock_message.metadata.collection = 'test_collection'
|
||||
|
||||
mock_entity = MagicMock()
|
||||
mock_entity.entity.value = 'test_entity'
|
||||
mock_entity.vectors = [[0.1, 0.2, 0.3]]
|
||||
|
||||
mock_message.entities = [mock_entity]
|
||||
return mock_message
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_graph_embeddings_multiple_entities():
|
||||
"""Mock graph embeddings message with multiple entities"""
|
||||
mock_message = MagicMock()
|
||||
mock_message.metadata.user = 'multi_user'
|
||||
mock_message.metadata.collection = 'multi_collection'
|
||||
|
||||
mock_entity1 = MagicMock()
|
||||
mock_entity1.entity.value = 'entity_one'
|
||||
mock_entity1.vectors = [[0.1, 0.2]]
|
||||
|
||||
mock_entity2 = MagicMock()
|
||||
mock_entity2.entity.value = 'entity_two'
|
||||
mock_entity2.vectors = [[0.3, 0.4]]
|
||||
|
||||
mock_message.entities = [mock_entity1, mock_entity2]
|
||||
return mock_message
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_graph_embeddings_empty_entity():
|
||||
"""Mock graph embeddings message with empty entity"""
|
||||
mock_message = MagicMock()
|
||||
mock_message.metadata.user = 'empty_user'
|
||||
mock_message.metadata.collection = 'empty_collection'
|
||||
|
||||
mock_entity = MagicMock()
|
||||
mock_entity.entity.value = "" # Empty string
|
||||
mock_entity.vectors = [[0.1, 0.2]]
|
||||
|
||||
mock_message.entities = [mock_entity]
|
||||
return mock_message
|
||||
569
tests/unit/test_storage/test_doc_embeddings_qdrant_storage.py
Normal file
569
tests/unit/test_storage/test_doc_embeddings_qdrant_storage.py
Normal file
|
|
@ -0,0 +1,569 @@
|
|||
"""
|
||||
Unit tests for trustgraph.storage.doc_embeddings.qdrant.write
|
||||
Testing document embeddings storage functionality
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from unittest import IsolatedAsyncioTestCase
|
||||
|
||||
# Import the service under test
|
||||
from trustgraph.storage.doc_embeddings.qdrant.write import Processor
|
||||
|
||||
|
||||
class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
|
||||
"""Test Qdrant document embeddings storage functionality"""
|
||||
|
||||
@patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient')
|
||||
@patch('trustgraph.base.DocumentEmbeddingsStoreService.__init__')
|
||||
async def test_processor_initialization_basic(self, mock_base_init, mock_qdrant_client):
|
||||
"""Test basic Qdrant processor initialization"""
|
||||
# Arrange
|
||||
mock_base_init.return_value = None
|
||||
mock_qdrant_instance = MagicMock()
|
||||
mock_qdrant_client.return_value = mock_qdrant_instance
|
||||
|
||||
config = {
|
||||
'store_uri': 'http://localhost:6333',
|
||||
'api_key': 'test-api-key',
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-doc-qdrant-processor'
|
||||
}
|
||||
|
||||
# Act
|
||||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
# Verify base class initialization was called
|
||||
mock_base_init.assert_called_once()
|
||||
|
||||
# Verify QdrantClient was created with correct parameters
|
||||
mock_qdrant_client.assert_called_once_with(url='http://localhost:6333', api_key='test-api-key')
|
||||
|
||||
# Verify processor attributes
|
||||
assert hasattr(processor, 'qdrant')
|
||||
assert processor.qdrant == mock_qdrant_instance
|
||||
assert hasattr(processor, 'last_collection')
|
||||
assert processor.last_collection is None
|
||||
|
||||
@patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient')
|
||||
@patch('trustgraph.base.DocumentEmbeddingsStoreService.__init__')
|
||||
async def test_processor_initialization_with_defaults(self, mock_base_init, mock_qdrant_client):
|
||||
"""Test processor initialization with default values"""
|
||||
# Arrange
|
||||
mock_base_init.return_value = None
|
||||
mock_qdrant_instance = MagicMock()
|
||||
mock_qdrant_client.return_value = mock_qdrant_instance
|
||||
|
||||
config = {
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-doc-qdrant-processor'
|
||||
# No store_uri or api_key provided - should use defaults
|
||||
}
|
||||
|
||||
# Act
|
||||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
# Verify QdrantClient was created with default URI and None API key
|
||||
mock_qdrant_client.assert_called_once_with(url='http://localhost:6333', api_key=None)
|
||||
|
||||
@patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient')
|
||||
@patch('trustgraph.storage.doc_embeddings.qdrant.write.uuid')
|
||||
@patch('trustgraph.base.DocumentEmbeddingsStoreService.__init__')
|
||||
async def test_store_document_embeddings_basic(self, mock_base_init, mock_uuid, mock_qdrant_client):
|
||||
"""Test storing document embeddings with basic message"""
|
||||
# Arrange
|
||||
mock_base_init.return_value = None
|
||||
mock_qdrant_instance = MagicMock()
|
||||
mock_qdrant_instance.collection_exists.return_value = True # Collection already exists
|
||||
mock_qdrant_client.return_value = mock_qdrant_instance
|
||||
mock_uuid.uuid4.return_value = MagicMock()
|
||||
mock_uuid.uuid4.return_value.__str__ = MagicMock(return_value='test-uuid-123')
|
||||
|
||||
config = {
|
||||
'store_uri': 'http://localhost:6333',
|
||||
'api_key': 'test-api-key',
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-doc-qdrant-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Create mock message with chunks and vectors
|
||||
mock_message = MagicMock()
|
||||
mock_message.metadata.user = 'test_user'
|
||||
mock_message.metadata.collection = 'test_collection'
|
||||
|
||||
mock_chunk = MagicMock()
|
||||
mock_chunk.chunk.decode.return_value = 'test document chunk'
|
||||
mock_chunk.vectors = [[0.1, 0.2, 0.3]] # Single vector with 3 dimensions
|
||||
|
||||
mock_message.chunks = [mock_chunk]
|
||||
|
||||
# Act
|
||||
await processor.store_document_embeddings(mock_message)
|
||||
|
||||
# Assert
|
||||
# Verify collection existence was checked
|
||||
expected_collection = 'd_test_user_test_collection_3'
|
||||
mock_qdrant_instance.collection_exists.assert_called_once_with(expected_collection)
|
||||
|
||||
# Verify upsert was called
|
||||
mock_qdrant_instance.upsert.assert_called_once()
|
||||
|
||||
# Verify upsert parameters
|
||||
upsert_call_args = mock_qdrant_instance.upsert.call_args
|
||||
assert upsert_call_args[1]['collection_name'] == expected_collection
|
||||
assert len(upsert_call_args[1]['points']) == 1
|
||||
|
||||
point = upsert_call_args[1]['points'][0]
|
||||
assert point.vector == [0.1, 0.2, 0.3]
|
||||
assert point.payload['doc'] == 'test document chunk'
|
||||
|
||||
@patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient')
|
||||
@patch('trustgraph.storage.doc_embeddings.qdrant.write.uuid')
|
||||
@patch('trustgraph.base.DocumentEmbeddingsStoreService.__init__')
|
||||
async def test_store_document_embeddings_multiple_chunks(self, mock_base_init, mock_uuid, mock_qdrant_client):
|
||||
"""Test storing document embeddings with multiple chunks"""
|
||||
# Arrange
|
||||
mock_base_init.return_value = None
|
||||
mock_qdrant_instance = MagicMock()
|
||||
mock_qdrant_instance.collection_exists.return_value = True
|
||||
mock_qdrant_client.return_value = mock_qdrant_instance
|
||||
mock_uuid.uuid4.return_value = MagicMock()
|
||||
mock_uuid.uuid4.return_value.__str__ = MagicMock(return_value='test-uuid')
|
||||
|
||||
config = {
|
||||
'store_uri': 'http://localhost:6333',
|
||||
'api_key': 'test-api-key',
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-doc-qdrant-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Create mock message with multiple chunks
|
||||
mock_message = MagicMock()
|
||||
mock_message.metadata.user = 'multi_user'
|
||||
mock_message.metadata.collection = 'multi_collection'
|
||||
|
||||
mock_chunk1 = MagicMock()
|
||||
mock_chunk1.chunk.decode.return_value = 'first document chunk'
|
||||
mock_chunk1.vectors = [[0.1, 0.2]]
|
||||
|
||||
mock_chunk2 = MagicMock()
|
||||
mock_chunk2.chunk.decode.return_value = 'second document chunk'
|
||||
mock_chunk2.vectors = [[0.3, 0.4]]
|
||||
|
||||
mock_message.chunks = [mock_chunk1, mock_chunk2]
|
||||
|
||||
# Act
|
||||
await processor.store_document_embeddings(mock_message)
|
||||
|
||||
# Assert
|
||||
# Should be called twice (once per chunk)
|
||||
assert mock_qdrant_instance.upsert.call_count == 2
|
||||
|
||||
# Verify both chunks were processed
|
||||
upsert_calls = mock_qdrant_instance.upsert.call_args_list
|
||||
|
||||
# First chunk
|
||||
first_call = upsert_calls[0]
|
||||
first_point = first_call[1]['points'][0]
|
||||
assert first_point.vector == [0.1, 0.2]
|
||||
assert first_point.payload['doc'] == 'first document chunk'
|
||||
|
||||
# Second chunk
|
||||
second_call = upsert_calls[1]
|
||||
second_point = second_call[1]['points'][0]
|
||||
assert second_point.vector == [0.3, 0.4]
|
||||
assert second_point.payload['doc'] == 'second document chunk'
|
||||
|
||||
@patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient')
|
||||
@patch('trustgraph.storage.doc_embeddings.qdrant.write.uuid')
|
||||
@patch('trustgraph.base.DocumentEmbeddingsStoreService.__init__')
|
||||
async def test_store_document_embeddings_multiple_vectors_per_chunk(self, mock_base_init, mock_uuid, mock_qdrant_client):
|
||||
"""Test storing document embeddings with multiple vectors per chunk"""
|
||||
# Arrange
|
||||
mock_base_init.return_value = None
|
||||
mock_qdrant_instance = MagicMock()
|
||||
mock_qdrant_instance.collection_exists.return_value = True
|
||||
mock_qdrant_client.return_value = mock_qdrant_instance
|
||||
mock_uuid.uuid4.return_value = MagicMock()
|
||||
mock_uuid.uuid4.return_value.__str__ = MagicMock(return_value='test-uuid')
|
||||
|
||||
config = {
|
||||
'store_uri': 'http://localhost:6333',
|
||||
'api_key': 'test-api-key',
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-doc-qdrant-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Create mock message with chunk having multiple vectors
|
||||
mock_message = MagicMock()
|
||||
mock_message.metadata.user = 'vector_user'
|
||||
mock_message.metadata.collection = 'vector_collection'
|
||||
|
||||
mock_chunk = MagicMock()
|
||||
mock_chunk.chunk.decode.return_value = 'multi-vector document chunk'
|
||||
mock_chunk.vectors = [
|
||||
[0.1, 0.2, 0.3],
|
||||
[0.4, 0.5, 0.6],
|
||||
[0.7, 0.8, 0.9]
|
||||
]
|
||||
|
||||
mock_message.chunks = [mock_chunk]
|
||||
|
||||
# Act
|
||||
await processor.store_document_embeddings(mock_message)
|
||||
|
||||
# Assert
|
||||
# Should be called 3 times (once per vector)
|
||||
assert mock_qdrant_instance.upsert.call_count == 3
|
||||
|
||||
# Verify all vectors were processed
|
||||
upsert_calls = mock_qdrant_instance.upsert.call_args_list
|
||||
|
||||
expected_vectors = [
|
||||
[0.1, 0.2, 0.3],
|
||||
[0.4, 0.5, 0.6],
|
||||
[0.7, 0.8, 0.9]
|
||||
]
|
||||
|
||||
for i, call in enumerate(upsert_calls):
|
||||
point = call[1]['points'][0]
|
||||
assert point.vector == expected_vectors[i]
|
||||
assert point.payload['doc'] == 'multi-vector document chunk'
|
||||
|
||||
@patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient')
|
||||
@patch('trustgraph.base.DocumentEmbeddingsStoreService.__init__')
|
||||
async def test_store_document_embeddings_empty_chunk(self, mock_base_init, mock_qdrant_client):
|
||||
"""Test storing document embeddings skips empty chunks"""
|
||||
# Arrange
|
||||
mock_base_init.return_value = None
|
||||
mock_qdrant_instance = MagicMock()
|
||||
mock_qdrant_client.return_value = mock_qdrant_instance
|
||||
|
||||
config = {
|
||||
'store_uri': 'http://localhost:6333',
|
||||
'api_key': 'test-api-key',
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-doc-qdrant-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Create mock message with empty chunk
|
||||
mock_message = MagicMock()
|
||||
mock_message.metadata.user = 'empty_user'
|
||||
mock_message.metadata.collection = 'empty_collection'
|
||||
|
||||
mock_chunk_empty = MagicMock()
|
||||
mock_chunk_empty.chunk.decode.return_value = "" # Empty string
|
||||
mock_chunk_empty.vectors = [[0.1, 0.2]]
|
||||
|
||||
mock_message.chunks = [mock_chunk_empty]
|
||||
|
||||
# Act
|
||||
await processor.store_document_embeddings(mock_message)
|
||||
|
||||
# Assert
|
||||
# Should not call upsert for empty chunks
|
||||
mock_qdrant_instance.upsert.assert_not_called()
|
||||
mock_qdrant_instance.collection_exists.assert_not_called()
|
||||
|
||||
@patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient')
|
||||
@patch('trustgraph.base.DocumentEmbeddingsStoreService.__init__')
|
||||
async def test_collection_creation_when_not_exists(self, mock_base_init, mock_qdrant_client):
|
||||
"""Test collection creation when it doesn't exist"""
|
||||
# Arrange
|
||||
mock_base_init.return_value = None
|
||||
mock_qdrant_instance = MagicMock()
|
||||
mock_qdrant_instance.collection_exists.return_value = False # Collection doesn't exist
|
||||
mock_qdrant_client.return_value = mock_qdrant_instance
|
||||
|
||||
config = {
|
||||
'store_uri': 'http://localhost:6333',
|
||||
'api_key': 'test-api-key',
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-doc-qdrant-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Create mock message
|
||||
mock_message = MagicMock()
|
||||
mock_message.metadata.user = 'new_user'
|
||||
mock_message.metadata.collection = 'new_collection'
|
||||
|
||||
mock_chunk = MagicMock()
|
||||
mock_chunk.chunk.decode.return_value = 'test chunk'
|
||||
mock_chunk.vectors = [[0.1, 0.2, 0.3, 0.4, 0.5]] # 5 dimensions
|
||||
|
||||
mock_message.chunks = [mock_chunk]
|
||||
|
||||
# Act
|
||||
await processor.store_document_embeddings(mock_message)
|
||||
|
||||
# Assert
|
||||
expected_collection = 'd_new_user_new_collection_5'
|
||||
|
||||
# Verify collection existence check and creation
|
||||
mock_qdrant_instance.collection_exists.assert_called_once_with(expected_collection)
|
||||
mock_qdrant_instance.create_collection.assert_called_once()
|
||||
|
||||
# Verify create_collection was called with correct parameters
|
||||
create_call_args = mock_qdrant_instance.create_collection.call_args
|
||||
assert create_call_args[1]['collection_name'] == expected_collection
|
||||
|
||||
# Verify upsert was still called after collection creation
|
||||
mock_qdrant_instance.upsert.assert_called_once()
|
||||
|
||||
@patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient')
|
||||
@patch('trustgraph.base.DocumentEmbeddingsStoreService.__init__')
|
||||
async def test_collection_creation_exception(self, mock_base_init, mock_qdrant_client):
|
||||
"""Test collection creation handles exceptions"""
|
||||
# Arrange
|
||||
mock_base_init.return_value = None
|
||||
mock_qdrant_instance = MagicMock()
|
||||
mock_qdrant_instance.collection_exists.return_value = False
|
||||
mock_qdrant_instance.create_collection.side_effect = Exception("Qdrant connection failed")
|
||||
mock_qdrant_client.return_value = mock_qdrant_instance
|
||||
|
||||
config = {
|
||||
'store_uri': 'http://localhost:6333',
|
||||
'api_key': 'test-api-key',
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-doc-qdrant-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Create mock message
|
||||
mock_message = MagicMock()
|
||||
mock_message.metadata.user = 'error_user'
|
||||
mock_message.metadata.collection = 'error_collection'
|
||||
|
||||
mock_chunk = MagicMock()
|
||||
mock_chunk.chunk.decode.return_value = 'test chunk'
|
||||
mock_chunk.vectors = [[0.1, 0.2]]
|
||||
|
||||
mock_message.chunks = [mock_chunk]
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(Exception, match="Qdrant connection failed"):
|
||||
await processor.store_document_embeddings(mock_message)
|
||||
|
||||
@patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient')
|
||||
@patch('trustgraph.base.DocumentEmbeddingsStoreService.__init__')
|
||||
async def test_collection_caching_behavior(self, mock_base_init, mock_qdrant_client):
|
||||
"""Test collection caching with last_collection"""
|
||||
# Arrange
|
||||
mock_base_init.return_value = None
|
||||
mock_qdrant_instance = MagicMock()
|
||||
mock_qdrant_instance.collection_exists.return_value = True
|
||||
mock_qdrant_client.return_value = mock_qdrant_instance
|
||||
|
||||
config = {
|
||||
'store_uri': 'http://localhost:6333',
|
||||
'api_key': 'test-api-key',
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-doc-qdrant-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Create first mock message
|
||||
mock_message1 = MagicMock()
|
||||
mock_message1.metadata.user = 'cache_user'
|
||||
mock_message1.metadata.collection = 'cache_collection'
|
||||
|
||||
mock_chunk1 = MagicMock()
|
||||
mock_chunk1.chunk.decode.return_value = 'first chunk'
|
||||
mock_chunk1.vectors = [[0.1, 0.2, 0.3]]
|
||||
|
||||
mock_message1.chunks = [mock_chunk1]
|
||||
|
||||
# First call
|
||||
await processor.store_document_embeddings(mock_message1)
|
||||
|
||||
# Reset mock to track second call
|
||||
mock_qdrant_instance.reset_mock()
|
||||
|
||||
# Create second mock message with same dimensions
|
||||
mock_message2 = MagicMock()
|
||||
mock_message2.metadata.user = 'cache_user'
|
||||
mock_message2.metadata.collection = 'cache_collection'
|
||||
|
||||
mock_chunk2 = MagicMock()
|
||||
mock_chunk2.chunk.decode.return_value = 'second chunk'
|
||||
mock_chunk2.vectors = [[0.4, 0.5, 0.6]] # Same dimension (3)
|
||||
|
||||
mock_message2.chunks = [mock_chunk2]
|
||||
|
||||
# Act - Second call with same collection
|
||||
await processor.store_document_embeddings(mock_message2)
|
||||
|
||||
# Assert
|
||||
expected_collection = 'd_cache_user_cache_collection_3'
|
||||
assert processor.last_collection == expected_collection
|
||||
|
||||
# Verify second call skipped existence check (cached)
|
||||
mock_qdrant_instance.collection_exists.assert_not_called()
|
||||
mock_qdrant_instance.create_collection.assert_not_called()
|
||||
|
||||
# But upsert should still be called
|
||||
mock_qdrant_instance.upsert.assert_called_once()
|
||||
|
||||
@patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient')
|
||||
@patch('trustgraph.base.DocumentEmbeddingsStoreService.__init__')
|
||||
async def test_different_dimensions_different_collections(self, mock_base_init, mock_qdrant_client):
|
||||
"""Test that different vector dimensions create different collections"""
|
||||
# Arrange
|
||||
mock_base_init.return_value = None
|
||||
mock_qdrant_instance = MagicMock()
|
||||
mock_qdrant_instance.collection_exists.return_value = True
|
||||
mock_qdrant_client.return_value = mock_qdrant_instance
|
||||
|
||||
config = {
|
||||
'store_uri': 'http://localhost:6333',
|
||||
'api_key': 'test-api-key',
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-doc-qdrant-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Create mock message with different dimension vectors
|
||||
mock_message = MagicMock()
|
||||
mock_message.metadata.user = 'dim_user'
|
||||
mock_message.metadata.collection = 'dim_collection'
|
||||
|
||||
mock_chunk = MagicMock()
|
||||
mock_chunk.chunk.decode.return_value = 'dimension test chunk'
|
||||
mock_chunk.vectors = [
|
||||
[0.1, 0.2], # 2 dimensions
|
||||
[0.3, 0.4, 0.5] # 3 dimensions
|
||||
]
|
||||
|
||||
mock_message.chunks = [mock_chunk]
|
||||
|
||||
# Act
|
||||
await processor.store_document_embeddings(mock_message)
|
||||
|
||||
# Assert
|
||||
# Should check existence of both collections
|
||||
expected_collections = ['d_dim_user_dim_collection_2', 'd_dim_user_dim_collection_3']
|
||||
actual_calls = [call.args[0] for call in mock_qdrant_instance.collection_exists.call_args_list]
|
||||
assert actual_calls == expected_collections
|
||||
|
||||
# Should upsert to both collections
|
||||
assert mock_qdrant_instance.upsert.call_count == 2
|
||||
|
||||
upsert_calls = mock_qdrant_instance.upsert.call_args_list
|
||||
assert upsert_calls[0][1]['collection_name'] == 'd_dim_user_dim_collection_2'
|
||||
assert upsert_calls[1][1]['collection_name'] == 'd_dim_user_dim_collection_3'
|
||||
|
||||
@patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient')
|
||||
@patch('trustgraph.base.DocumentEmbeddingsStoreService.__init__')
|
||||
async def test_add_args_calls_parent(self, mock_base_init, mock_qdrant_client):
|
||||
"""Test that add_args() calls parent add_args method"""
|
||||
# Arrange
|
||||
mock_base_init.return_value = None
|
||||
mock_qdrant_client.return_value = MagicMock()
|
||||
mock_parser = MagicMock()
|
||||
|
||||
# Act
|
||||
with patch('trustgraph.base.DocumentEmbeddingsStoreService.add_args') as mock_parent_add_args:
|
||||
Processor.add_args(mock_parser)
|
||||
|
||||
# Assert
|
||||
mock_parent_add_args.assert_called_once_with(mock_parser)
|
||||
|
||||
# Verify processor-specific arguments were added
|
||||
assert mock_parser.add_argument.call_count >= 2 # At least store-uri and api-key
|
||||
|
||||
@patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient')
|
||||
@patch('trustgraph.storage.doc_embeddings.qdrant.write.uuid')
|
||||
@patch('trustgraph.base.DocumentEmbeddingsStoreService.__init__')
|
||||
async def test_utf8_decoding_handling(self, mock_base_init, mock_uuid, mock_qdrant_client):
|
||||
"""Test proper UTF-8 decoding of chunk text"""
|
||||
# Arrange
|
||||
mock_base_init.return_value = None
|
||||
mock_qdrant_instance = MagicMock()
|
||||
mock_qdrant_instance.collection_exists.return_value = True
|
||||
mock_qdrant_client.return_value = mock_qdrant_instance
|
||||
mock_uuid.uuid4.return_value = MagicMock()
|
||||
mock_uuid.uuid4.return_value.__str__ = MagicMock(return_value='test-uuid')
|
||||
|
||||
config = {
|
||||
'store_uri': 'http://localhost:6333',
|
||||
'api_key': 'test-api-key',
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-doc-qdrant-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Create mock message with UTF-8 encoded text
|
||||
mock_message = MagicMock()
|
||||
mock_message.metadata.user = 'utf8_user'
|
||||
mock_message.metadata.collection = 'utf8_collection'
|
||||
|
||||
mock_chunk = MagicMock()
|
||||
mock_chunk.chunk.decode.return_value = 'UTF-8 text with special chars: café, naïve, résumé'
|
||||
mock_chunk.vectors = [[0.1, 0.2]]
|
||||
|
||||
mock_message.chunks = [mock_chunk]
|
||||
|
||||
# Act
|
||||
await processor.store_document_embeddings(mock_message)
|
||||
|
||||
# Assert
|
||||
# Verify chunk.decode was called with 'utf-8'
|
||||
mock_chunk.chunk.decode.assert_called_with('utf-8')
|
||||
|
||||
# Verify the decoded text was stored in payload
|
||||
upsert_call_args = mock_qdrant_instance.upsert.call_args
|
||||
point = upsert_call_args[1]['points'][0]
|
||||
assert point.payload['doc'] == 'UTF-8 text with special chars: café, naïve, résumé'
|
||||
|
||||
@patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient')
|
||||
@patch('trustgraph.base.DocumentEmbeddingsStoreService.__init__')
|
||||
async def test_chunk_decode_exception_handling(self, mock_base_init, mock_qdrant_client):
|
||||
"""Test handling of chunk decode exceptions"""
|
||||
# Arrange
|
||||
mock_base_init.return_value = None
|
||||
mock_qdrant_instance = MagicMock()
|
||||
mock_qdrant_client.return_value = mock_qdrant_instance
|
||||
|
||||
config = {
|
||||
'store_uri': 'http://localhost:6333',
|
||||
'api_key': 'test-api-key',
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-doc-qdrant-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Create mock message with decode error
|
||||
mock_message = MagicMock()
|
||||
mock_message.metadata.user = 'decode_user'
|
||||
mock_message.metadata.collection = 'decode_collection'
|
||||
|
||||
mock_chunk = MagicMock()
|
||||
mock_chunk.chunk.decode.side_effect = UnicodeDecodeError('utf-8', b'', 0, 1, 'invalid start byte')
|
||||
mock_chunk.vectors = [[0.1, 0.2]]
|
||||
|
||||
mock_message.chunks = [mock_chunk]
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(UnicodeDecodeError):
|
||||
await processor.store_document_embeddings(mock_message)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pytest.main([__file__])
|
||||
428
tests/unit/test_storage/test_graph_embeddings_qdrant_storage.py
Normal file
428
tests/unit/test_storage/test_graph_embeddings_qdrant_storage.py
Normal file
|
|
@ -0,0 +1,428 @@
|
|||
"""
|
||||
Unit tests for trustgraph.storage.graph_embeddings.qdrant.write
|
||||
Starting small with a single test to verify basic functionality
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from unittest import IsolatedAsyncioTestCase
|
||||
|
||||
# Import the service under test
|
||||
from trustgraph.storage.graph_embeddings.qdrant.write import Processor
|
||||
|
||||
|
||||
class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase):
|
||||
"""Test Qdrant graph embeddings storage functionality"""
|
||||
|
||||
@patch('trustgraph.storage.graph_embeddings.qdrant.write.QdrantClient')
|
||||
@patch('trustgraph.base.GraphEmbeddingsStoreService.__init__')
|
||||
async def test_processor_initialization_basic(self, mock_base_init, mock_qdrant_client):
|
||||
"""Test basic Qdrant processor initialization"""
|
||||
# Arrange
|
||||
mock_base_init.return_value = None
|
||||
mock_qdrant_instance = MagicMock()
|
||||
mock_qdrant_client.return_value = mock_qdrant_instance
|
||||
|
||||
config = {
|
||||
'store_uri': 'http://localhost:6333',
|
||||
'api_key': 'test-api-key',
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-qdrant-processor'
|
||||
}
|
||||
|
||||
# Act
|
||||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
# Verify base class initialization was called
|
||||
mock_base_init.assert_called_once()
|
||||
|
||||
# Verify QdrantClient was created with correct parameters
|
||||
mock_qdrant_client.assert_called_once_with(url='http://localhost:6333', api_key='test-api-key')
|
||||
|
||||
# Verify processor attributes
|
||||
assert hasattr(processor, 'qdrant')
|
||||
assert processor.qdrant == mock_qdrant_instance
|
||||
assert hasattr(processor, 'last_collection')
|
||||
assert processor.last_collection is None
|
||||
|
||||
@patch('trustgraph.storage.graph_embeddings.qdrant.write.QdrantClient')
|
||||
@patch('trustgraph.base.GraphEmbeddingsStoreService.__init__')
|
||||
async def test_get_collection_creates_new_collection(self, mock_base_init, mock_qdrant_client):
|
||||
"""Test get_collection creates a new collection when it doesn't exist"""
|
||||
# Arrange
|
||||
mock_base_init.return_value = None
|
||||
mock_qdrant_instance = MagicMock()
|
||||
mock_qdrant_instance.collection_exists.return_value = False
|
||||
mock_qdrant_client.return_value = mock_qdrant_instance
|
||||
|
||||
config = {
|
||||
'store_uri': 'http://localhost:6333',
|
||||
'api_key': 'test-api-key',
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-qdrant-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act
|
||||
collection_name = processor.get_collection(dim=512, user='test_user', collection='test_collection')
|
||||
|
||||
# Assert
|
||||
expected_name = 't_test_user_test_collection_512'
|
||||
assert collection_name == expected_name
|
||||
assert processor.last_collection == expected_name
|
||||
|
||||
# Verify collection existence check and creation
|
||||
mock_qdrant_instance.collection_exists.assert_called_once_with(expected_name)
|
||||
mock_qdrant_instance.create_collection.assert_called_once()
|
||||
|
||||
# Verify create_collection was called with correct parameters
|
||||
create_call_args = mock_qdrant_instance.create_collection.call_args
|
||||
assert create_call_args[1]['collection_name'] == expected_name
|
||||
|
||||
@patch('trustgraph.storage.graph_embeddings.qdrant.write.QdrantClient')
|
||||
@patch('trustgraph.storage.graph_embeddings.qdrant.write.uuid')
|
||||
@patch('trustgraph.base.GraphEmbeddingsStoreService.__init__')
|
||||
async def test_store_graph_embeddings_basic(self, mock_base_init, mock_uuid, mock_qdrant_client):
|
||||
"""Test storing graph embeddings with basic message"""
|
||||
# Arrange
|
||||
mock_base_init.return_value = None
|
||||
mock_qdrant_instance = MagicMock()
|
||||
mock_qdrant_instance.collection_exists.return_value = True # Collection already exists
|
||||
mock_qdrant_client.return_value = mock_qdrant_instance
|
||||
mock_uuid.uuid4.return_value.return_value = 'test-uuid-123'
|
||||
|
||||
config = {
|
||||
'store_uri': 'http://localhost:6333',
|
||||
'api_key': 'test-api-key',
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-qdrant-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Create mock message with entities and vectors
|
||||
mock_message = MagicMock()
|
||||
mock_message.metadata.user = 'test_user'
|
||||
mock_message.metadata.collection = 'test_collection'
|
||||
|
||||
mock_entity = MagicMock()
|
||||
mock_entity.entity.value = 'test_entity'
|
||||
mock_entity.vectors = [[0.1, 0.2, 0.3]] # Single vector with 3 dimensions
|
||||
|
||||
mock_message.entities = [mock_entity]
|
||||
|
||||
# Act
|
||||
await processor.store_graph_embeddings(mock_message)
|
||||
|
||||
# Assert
|
||||
# Verify collection existence was checked
|
||||
expected_collection = 't_test_user_test_collection_3'
|
||||
mock_qdrant_instance.collection_exists.assert_called_once_with(expected_collection)
|
||||
|
||||
# Verify upsert was called
|
||||
mock_qdrant_instance.upsert.assert_called_once()
|
||||
|
||||
# Verify upsert parameters
|
||||
upsert_call_args = mock_qdrant_instance.upsert.call_args
|
||||
assert upsert_call_args[1]['collection_name'] == expected_collection
|
||||
assert len(upsert_call_args[1]['points']) == 1
|
||||
|
||||
point = upsert_call_args[1]['points'][0]
|
||||
assert point.vector == [0.1, 0.2, 0.3]
|
||||
assert point.payload['entity'] == 'test_entity'
|
||||
|
||||
@patch('trustgraph.storage.graph_embeddings.qdrant.write.QdrantClient')
|
||||
@patch('trustgraph.base.GraphEmbeddingsStoreService.__init__')
|
||||
async def test_get_collection_uses_existing_collection(self, mock_base_init, mock_qdrant_client):
|
||||
"""Test get_collection uses existing collection without creating new one"""
|
||||
# Arrange
|
||||
mock_base_init.return_value = None
|
||||
mock_qdrant_instance = MagicMock()
|
||||
mock_qdrant_instance.collection_exists.return_value = True # Collection exists
|
||||
mock_qdrant_client.return_value = mock_qdrant_instance
|
||||
|
||||
config = {
|
||||
'store_uri': 'http://localhost:6333',
|
||||
'api_key': 'test-api-key',
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-qdrant-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act
|
||||
collection_name = processor.get_collection(dim=256, user='existing_user', collection='existing_collection')
|
||||
|
||||
# Assert
|
||||
expected_name = 't_existing_user_existing_collection_256'
|
||||
assert collection_name == expected_name
|
||||
assert processor.last_collection == expected_name
|
||||
|
||||
# Verify collection existence check was performed
|
||||
mock_qdrant_instance.collection_exists.assert_called_once_with(expected_name)
|
||||
# Verify create_collection was NOT called
|
||||
mock_qdrant_instance.create_collection.assert_not_called()
|
||||
|
||||
@patch('trustgraph.storage.graph_embeddings.qdrant.write.QdrantClient')
|
||||
@patch('trustgraph.base.GraphEmbeddingsStoreService.__init__')
|
||||
async def test_get_collection_caches_last_collection(self, mock_base_init, mock_qdrant_client):
|
||||
"""Test get_collection skips checks when using same collection"""
|
||||
# Arrange
|
||||
mock_base_init.return_value = None
|
||||
mock_qdrant_instance = MagicMock()
|
||||
mock_qdrant_instance.collection_exists.return_value = True
|
||||
mock_qdrant_client.return_value = mock_qdrant_instance
|
||||
|
||||
config = {
|
||||
'store_uri': 'http://localhost:6333',
|
||||
'api_key': 'test-api-key',
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-qdrant-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# First call
|
||||
collection_name1 = processor.get_collection(dim=128, user='cache_user', collection='cache_collection')
|
||||
|
||||
# Reset mock to track second call
|
||||
mock_qdrant_instance.reset_mock()
|
||||
|
||||
# Act - Second call with same parameters
|
||||
collection_name2 = processor.get_collection(dim=128, user='cache_user', collection='cache_collection')
|
||||
|
||||
# Assert
|
||||
expected_name = 't_cache_user_cache_collection_128'
|
||||
assert collection_name1 == expected_name
|
||||
assert collection_name2 == expected_name
|
||||
|
||||
# Verify second call skipped existence check (cached)
|
||||
mock_qdrant_instance.collection_exists.assert_not_called()
|
||||
mock_qdrant_instance.create_collection.assert_not_called()
|
||||
|
||||
@patch('trustgraph.storage.graph_embeddings.qdrant.write.QdrantClient')
|
||||
@patch('trustgraph.base.GraphEmbeddingsStoreService.__init__')
|
||||
async def test_get_collection_creation_exception(self, mock_base_init, mock_qdrant_client):
|
||||
"""Test get_collection handles collection creation exceptions"""
|
||||
# Arrange
|
||||
mock_base_init.return_value = None
|
||||
mock_qdrant_instance = MagicMock()
|
||||
mock_qdrant_instance.collection_exists.return_value = False
|
||||
mock_qdrant_instance.create_collection.side_effect = Exception("Qdrant connection failed")
|
||||
mock_qdrant_client.return_value = mock_qdrant_instance
|
||||
|
||||
config = {
|
||||
'store_uri': 'http://localhost:6333',
|
||||
'api_key': 'test-api-key',
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-qdrant-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(Exception, match="Qdrant connection failed"):
|
||||
processor.get_collection(dim=512, user='error_user', collection='error_collection')
|
||||
|
||||
@patch('trustgraph.storage.graph_embeddings.qdrant.write.QdrantClient')
|
||||
@patch('trustgraph.storage.graph_embeddings.qdrant.write.uuid')
|
||||
@patch('trustgraph.base.GraphEmbeddingsStoreService.__init__')
|
||||
async def test_store_graph_embeddings_multiple_entities(self, mock_base_init, mock_uuid, mock_qdrant_client):
|
||||
"""Test storing graph embeddings with multiple entities"""
|
||||
# Arrange
|
||||
mock_base_init.return_value = None
|
||||
mock_qdrant_instance = MagicMock()
|
||||
mock_qdrant_instance.collection_exists.return_value = True
|
||||
mock_qdrant_client.return_value = mock_qdrant_instance
|
||||
mock_uuid.uuid4.return_value.return_value = 'test-uuid'
|
||||
|
||||
config = {
|
||||
'store_uri': 'http://localhost:6333',
|
||||
'api_key': 'test-api-key',
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-qdrant-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Create mock message with multiple entities
|
||||
mock_message = MagicMock()
|
||||
mock_message.metadata.user = 'multi_user'
|
||||
mock_message.metadata.collection = 'multi_collection'
|
||||
|
||||
mock_entity1 = MagicMock()
|
||||
mock_entity1.entity.value = 'entity_one'
|
||||
mock_entity1.vectors = [[0.1, 0.2]]
|
||||
|
||||
mock_entity2 = MagicMock()
|
||||
mock_entity2.entity.value = 'entity_two'
|
||||
mock_entity2.vectors = [[0.3, 0.4]]
|
||||
|
||||
mock_message.entities = [mock_entity1, mock_entity2]
|
||||
|
||||
# Act
|
||||
await processor.store_graph_embeddings(mock_message)
|
||||
|
||||
# Assert
|
||||
# Should be called twice (once per entity)
|
||||
assert mock_qdrant_instance.upsert.call_count == 2
|
||||
|
||||
# Verify both entities were processed
|
||||
upsert_calls = mock_qdrant_instance.upsert.call_args_list
|
||||
|
||||
# First entity
|
||||
first_call = upsert_calls[0]
|
||||
first_point = first_call[1]['points'][0]
|
||||
assert first_point.vector == [0.1, 0.2]
|
||||
assert first_point.payload['entity'] == 'entity_one'
|
||||
|
||||
# Second entity
|
||||
second_call = upsert_calls[1]
|
||||
second_point = second_call[1]['points'][0]
|
||||
assert second_point.vector == [0.3, 0.4]
|
||||
assert second_point.payload['entity'] == 'entity_two'
|
||||
|
||||
@patch('trustgraph.storage.graph_embeddings.qdrant.write.QdrantClient')
|
||||
@patch('trustgraph.storage.graph_embeddings.qdrant.write.uuid')
|
||||
@patch('trustgraph.base.GraphEmbeddingsStoreService.__init__')
|
||||
async def test_store_graph_embeddings_multiple_vectors_per_entity(self, mock_base_init, mock_uuid, mock_qdrant_client):
|
||||
"""Test storing graph embeddings with multiple vectors per entity"""
|
||||
# Arrange
|
||||
mock_base_init.return_value = None
|
||||
mock_qdrant_instance = MagicMock()
|
||||
mock_qdrant_instance.collection_exists.return_value = True
|
||||
mock_qdrant_client.return_value = mock_qdrant_instance
|
||||
mock_uuid.uuid4.return_value.return_value = 'test-uuid'
|
||||
|
||||
config = {
|
||||
'store_uri': 'http://localhost:6333',
|
||||
'api_key': 'test-api-key',
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-qdrant-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Create mock message with entity having multiple vectors
|
||||
mock_message = MagicMock()
|
||||
mock_message.metadata.user = 'vector_user'
|
||||
mock_message.metadata.collection = 'vector_collection'
|
||||
|
||||
mock_entity = MagicMock()
|
||||
mock_entity.entity.value = 'multi_vector_entity'
|
||||
mock_entity.vectors = [
|
||||
[0.1, 0.2, 0.3],
|
||||
[0.4, 0.5, 0.6],
|
||||
[0.7, 0.8, 0.9]
|
||||
]
|
||||
|
||||
mock_message.entities = [mock_entity]
|
||||
|
||||
# Act
|
||||
await processor.store_graph_embeddings(mock_message)
|
||||
|
||||
# Assert
|
||||
# Should be called 3 times (once per vector)
|
||||
assert mock_qdrant_instance.upsert.call_count == 3
|
||||
|
||||
# Verify all vectors were processed
|
||||
upsert_calls = mock_qdrant_instance.upsert.call_args_list
|
||||
|
||||
expected_vectors = [
|
||||
[0.1, 0.2, 0.3],
|
||||
[0.4, 0.5, 0.6],
|
||||
[0.7, 0.8, 0.9]
|
||||
]
|
||||
|
||||
for i, call in enumerate(upsert_calls):
|
||||
point = call[1]['points'][0]
|
||||
assert point.vector == expected_vectors[i]
|
||||
assert point.payload['entity'] == 'multi_vector_entity'
|
||||
|
||||
@patch('trustgraph.storage.graph_embeddings.qdrant.write.QdrantClient')
|
||||
@patch('trustgraph.base.GraphEmbeddingsStoreService.__init__')
|
||||
async def test_store_graph_embeddings_empty_entity_value(self, mock_base_init, mock_qdrant_client):
|
||||
"""Test storing graph embeddings skips empty entity values"""
|
||||
# Arrange
|
||||
mock_base_init.return_value = None
|
||||
mock_qdrant_instance = MagicMock()
|
||||
mock_qdrant_client.return_value = mock_qdrant_instance
|
||||
|
||||
config = {
|
||||
'store_uri': 'http://localhost:6333',
|
||||
'api_key': 'test-api-key',
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-qdrant-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Create mock message with empty entity value
|
||||
mock_message = MagicMock()
|
||||
mock_message.metadata.user = 'empty_user'
|
||||
mock_message.metadata.collection = 'empty_collection'
|
||||
|
||||
mock_entity_empty = MagicMock()
|
||||
mock_entity_empty.entity.value = "" # Empty string
|
||||
mock_entity_empty.vectors = [[0.1, 0.2]]
|
||||
|
||||
mock_entity_none = MagicMock()
|
||||
mock_entity_none.entity.value = None # None value
|
||||
mock_entity_none.vectors = [[0.3, 0.4]]
|
||||
|
||||
mock_message.entities = [mock_entity_empty, mock_entity_none]
|
||||
|
||||
# Act
|
||||
await processor.store_graph_embeddings(mock_message)
|
||||
|
||||
# Assert
|
||||
# Should not call upsert for empty entities
|
||||
mock_qdrant_instance.upsert.assert_not_called()
|
||||
mock_qdrant_instance.collection_exists.assert_not_called()
|
||||
|
||||
@patch('trustgraph.storage.graph_embeddings.qdrant.write.QdrantClient')
|
||||
@patch('trustgraph.base.GraphEmbeddingsStoreService.__init__')
|
||||
async def test_processor_initialization_with_defaults(self, mock_base_init, mock_qdrant_client):
|
||||
"""Test processor initialization with default values"""
|
||||
# Arrange
|
||||
mock_base_init.return_value = None
|
||||
mock_qdrant_instance = MagicMock()
|
||||
mock_qdrant_client.return_value = mock_qdrant_instance
|
||||
|
||||
config = {
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-qdrant-processor'
|
||||
# No store_uri or api_key provided - should use defaults
|
||||
}
|
||||
|
||||
# Act
|
||||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
# Verify QdrantClient was created with default URI and None API key
|
||||
mock_qdrant_client.assert_called_once_with(url='http://localhost:6333', api_key=None)
|
||||
|
||||
@patch('trustgraph.storage.graph_embeddings.qdrant.write.QdrantClient')
|
||||
@patch('trustgraph.base.GraphEmbeddingsStoreService.__init__')
|
||||
async def test_add_args_calls_parent(self, mock_base_init, mock_qdrant_client):
|
||||
"""Test that add_args() calls parent add_args method"""
|
||||
# Arrange
|
||||
mock_base_init.return_value = None
|
||||
mock_qdrant_client.return_value = MagicMock()
|
||||
mock_parser = MagicMock()
|
||||
|
||||
# Act
|
||||
with patch('trustgraph.base.GraphEmbeddingsStoreService.add_args') as mock_parent_add_args:
|
||||
Processor.add_args(mock_parser)
|
||||
|
||||
# Assert
|
||||
mock_parent_add_args.assert_called_once_with(mock_parser)
|
||||
|
||||
# Verify processor-specific arguments were added
|
||||
assert mock_parser.add_argument.call_count >= 2 # At least store-uri and api-key
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pytest.main([__file__])
|
||||
373
tests/unit/test_storage/test_triples_cassandra_storage.py
Normal file
373
tests/unit/test_storage/test_triples_cassandra_storage.py
Normal file
|
|
@ -0,0 +1,373 @@
|
|||
"""
|
||||
Tests for Cassandra triples storage service
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch, AsyncMock
|
||||
|
||||
from trustgraph.storage.triples.cassandra.write import Processor
|
||||
from trustgraph.schema import Value, Triple
|
||||
|
||||
|
||||
class TestCassandraStorageProcessor:
|
||||
"""Test cases for Cassandra storage processor"""
|
||||
|
||||
def test_processor_initialization_with_defaults(self):
|
||||
"""Test processor initialization with default parameters"""
|
||||
taskgroup_mock = MagicMock()
|
||||
|
||||
processor = Processor(taskgroup=taskgroup_mock)
|
||||
|
||||
assert processor.graph_host == ['localhost']
|
||||
assert processor.username is None
|
||||
assert processor.password is None
|
||||
assert processor.table is None
|
||||
|
||||
def test_processor_initialization_with_custom_params(self):
|
||||
"""Test processor initialization with custom parameters"""
|
||||
taskgroup_mock = MagicMock()
|
||||
|
||||
processor = Processor(
|
||||
taskgroup=taskgroup_mock,
|
||||
id='custom-storage',
|
||||
graph_host='cassandra.example.com',
|
||||
graph_username='testuser',
|
||||
graph_password='testpass'
|
||||
)
|
||||
|
||||
assert processor.graph_host == ['cassandra.example.com']
|
||||
assert processor.username == 'testuser'
|
||||
assert processor.password == 'testpass'
|
||||
assert processor.table is None
|
||||
|
||||
def test_processor_initialization_with_partial_auth(self):
|
||||
"""Test processor initialization with only username (no password)"""
|
||||
taskgroup_mock = MagicMock()
|
||||
|
||||
processor = Processor(
|
||||
taskgroup=taskgroup_mock,
|
||||
graph_username='testuser'
|
||||
)
|
||||
|
||||
assert processor.username == 'testuser'
|
||||
assert processor.password is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('trustgraph.storage.triples.cassandra.write.TrustGraph')
|
||||
async def test_table_switching_with_auth(self, mock_trustgraph):
|
||||
"""Test table switching logic when authentication is provided"""
|
||||
taskgroup_mock = MagicMock()
|
||||
mock_tg_instance = MagicMock()
|
||||
mock_trustgraph.return_value = mock_tg_instance
|
||||
|
||||
processor = Processor(
|
||||
taskgroup=taskgroup_mock,
|
||||
graph_username='testuser',
|
||||
graph_password='testpass'
|
||||
)
|
||||
|
||||
# Create mock message
|
||||
mock_message = MagicMock()
|
||||
mock_message.metadata.user = 'user1'
|
||||
mock_message.metadata.collection = 'collection1'
|
||||
mock_message.triples = []
|
||||
|
||||
await processor.store_triples(mock_message)
|
||||
|
||||
# Verify TrustGraph was called with auth parameters
|
||||
mock_trustgraph.assert_called_once_with(
|
||||
hosts=['localhost'],
|
||||
keyspace='user1',
|
||||
table='collection1',
|
||||
username='testuser',
|
||||
password='testpass'
|
||||
)
|
||||
assert processor.table == ('user1', 'collection1')
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('trustgraph.storage.triples.cassandra.write.TrustGraph')
|
||||
async def test_table_switching_without_auth(self, mock_trustgraph):
|
||||
"""Test table switching logic when no authentication is provided"""
|
||||
taskgroup_mock = MagicMock()
|
||||
mock_tg_instance = MagicMock()
|
||||
mock_trustgraph.return_value = mock_tg_instance
|
||||
|
||||
processor = Processor(taskgroup=taskgroup_mock)
|
||||
|
||||
# Create mock message
|
||||
mock_message = MagicMock()
|
||||
mock_message.metadata.user = 'user2'
|
||||
mock_message.metadata.collection = 'collection2'
|
||||
mock_message.triples = []
|
||||
|
||||
await processor.store_triples(mock_message)
|
||||
|
||||
# Verify TrustGraph was called without auth parameters
|
||||
mock_trustgraph.assert_called_once_with(
|
||||
hosts=['localhost'],
|
||||
keyspace='user2',
|
||||
table='collection2'
|
||||
)
|
||||
assert processor.table == ('user2', 'collection2')
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('trustgraph.storage.triples.cassandra.write.TrustGraph')
|
||||
async def test_table_reuse_when_same(self, mock_trustgraph):
|
||||
"""Test that TrustGraph is not recreated when table hasn't changed"""
|
||||
taskgroup_mock = MagicMock()
|
||||
mock_tg_instance = MagicMock()
|
||||
mock_trustgraph.return_value = mock_tg_instance
|
||||
|
||||
processor = Processor(taskgroup=taskgroup_mock)
|
||||
|
||||
# Create mock message
|
||||
mock_message = MagicMock()
|
||||
mock_message.metadata.user = 'user1'
|
||||
mock_message.metadata.collection = 'collection1'
|
||||
mock_message.triples = []
|
||||
|
||||
# First call should create TrustGraph
|
||||
await processor.store_triples(mock_message)
|
||||
assert mock_trustgraph.call_count == 1
|
||||
|
||||
# Second call with same table should reuse TrustGraph
|
||||
await processor.store_triples(mock_message)
|
||||
assert mock_trustgraph.call_count == 1 # Should not increase
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('trustgraph.storage.triples.cassandra.write.TrustGraph')
|
||||
async def test_triple_insertion(self, mock_trustgraph):
|
||||
"""Test that triples are properly inserted into Cassandra"""
|
||||
taskgroup_mock = MagicMock()
|
||||
mock_tg_instance = MagicMock()
|
||||
mock_trustgraph.return_value = mock_tg_instance
|
||||
|
||||
processor = Processor(taskgroup=taskgroup_mock)
|
||||
|
||||
# Create mock triples
|
||||
triple1 = MagicMock()
|
||||
triple1.s.value = 'subject1'
|
||||
triple1.p.value = 'predicate1'
|
||||
triple1.o.value = 'object1'
|
||||
|
||||
triple2 = MagicMock()
|
||||
triple2.s.value = 'subject2'
|
||||
triple2.p.value = 'predicate2'
|
||||
triple2.o.value = 'object2'
|
||||
|
||||
# Create mock message
|
||||
mock_message = MagicMock()
|
||||
mock_message.metadata.user = 'user1'
|
||||
mock_message.metadata.collection = 'collection1'
|
||||
mock_message.triples = [triple1, triple2]
|
||||
|
||||
await processor.store_triples(mock_message)
|
||||
|
||||
# Verify both triples were inserted
|
||||
assert mock_tg_instance.insert.call_count == 2
|
||||
mock_tg_instance.insert.assert_any_call('subject1', 'predicate1', 'object1')
|
||||
mock_tg_instance.insert.assert_any_call('subject2', 'predicate2', 'object2')
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('trustgraph.storage.triples.cassandra.write.TrustGraph')
|
||||
async def test_triple_insertion_with_empty_list(self, mock_trustgraph):
|
||||
"""Test behavior when message has no triples"""
|
||||
taskgroup_mock = MagicMock()
|
||||
mock_tg_instance = MagicMock()
|
||||
mock_trustgraph.return_value = mock_tg_instance
|
||||
|
||||
processor = Processor(taskgroup=taskgroup_mock)
|
||||
|
||||
# Create mock message with empty triples
|
||||
mock_message = MagicMock()
|
||||
mock_message.metadata.user = 'user1'
|
||||
mock_message.metadata.collection = 'collection1'
|
||||
mock_message.triples = []
|
||||
|
||||
await processor.store_triples(mock_message)
|
||||
|
||||
# Verify no triples were inserted
|
||||
mock_tg_instance.insert.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('trustgraph.storage.triples.cassandra.write.TrustGraph')
|
||||
@patch('trustgraph.storage.triples.cassandra.write.time.sleep')
|
||||
async def test_exception_handling_with_retry(self, mock_sleep, mock_trustgraph):
|
||||
"""Test exception handling during TrustGraph creation"""
|
||||
taskgroup_mock = MagicMock()
|
||||
mock_trustgraph.side_effect = Exception("Connection failed")
|
||||
|
||||
processor = Processor(taskgroup=taskgroup_mock)
|
||||
|
||||
# Create mock message
|
||||
mock_message = MagicMock()
|
||||
mock_message.metadata.user = 'user1'
|
||||
mock_message.metadata.collection = 'collection1'
|
||||
mock_message.triples = []
|
||||
|
||||
with pytest.raises(Exception, match="Connection failed"):
|
||||
await processor.store_triples(mock_message)
|
||||
|
||||
# Verify sleep was called before re-raising
|
||||
mock_sleep.assert_called_once_with(1)
|
||||
|
||||
def test_add_args_method(self):
|
||||
"""Test that add_args properly configures argument parser"""
|
||||
from argparse import ArgumentParser
|
||||
from unittest.mock import patch
|
||||
|
||||
parser = ArgumentParser()
|
||||
|
||||
# Mock the parent class add_args method
|
||||
with patch('trustgraph.storage.triples.cassandra.write.TriplesStoreService.add_args') as mock_parent_add_args:
|
||||
Processor.add_args(parser)
|
||||
|
||||
# Verify parent add_args was called
|
||||
mock_parent_add_args.assert_called_once_with(parser)
|
||||
|
||||
# Verify our specific arguments were added
|
||||
# Parse empty args to check defaults
|
||||
args = parser.parse_args([])
|
||||
|
||||
assert hasattr(args, 'graph_host')
|
||||
assert args.graph_host == 'localhost'
|
||||
assert hasattr(args, 'graph_username')
|
||||
assert args.graph_username is None
|
||||
assert hasattr(args, 'graph_password')
|
||||
assert args.graph_password is None
|
||||
|
||||
def test_add_args_with_custom_values(self):
|
||||
"""Test add_args with custom command line values"""
|
||||
from argparse import ArgumentParser
|
||||
from unittest.mock import patch
|
||||
|
||||
parser = ArgumentParser()
|
||||
|
||||
with patch('trustgraph.storage.triples.cassandra.write.TriplesStoreService.add_args'):
|
||||
Processor.add_args(parser)
|
||||
|
||||
# Test parsing with custom values
|
||||
args = parser.parse_args([
|
||||
'--graph-host', 'cassandra.example.com',
|
||||
'--graph-username', 'testuser',
|
||||
'--graph-password', 'testpass'
|
||||
])
|
||||
|
||||
assert args.graph_host == 'cassandra.example.com'
|
||||
assert args.graph_username == 'testuser'
|
||||
assert args.graph_password == 'testpass'
|
||||
|
||||
def test_add_args_short_form(self):
|
||||
"""Test add_args with short form arguments"""
|
||||
from argparse import ArgumentParser
|
||||
from unittest.mock import patch
|
||||
|
||||
parser = ArgumentParser()
|
||||
|
||||
with patch('trustgraph.storage.triples.cassandra.write.TriplesStoreService.add_args'):
|
||||
Processor.add_args(parser)
|
||||
|
||||
# Test parsing with short form
|
||||
args = parser.parse_args(['-g', 'short.example.com'])
|
||||
|
||||
assert args.graph_host == 'short.example.com'
|
||||
|
||||
@patch('trustgraph.storage.triples.cassandra.write.Processor.launch')
|
||||
def test_run_function(self, mock_launch):
|
||||
"""Test the run function calls Processor.launch with correct parameters"""
|
||||
from trustgraph.storage.triples.cassandra.write import run, default_ident
|
||||
|
||||
run()
|
||||
|
||||
mock_launch.assert_called_once_with(default_ident, '\nGraph writer. Input is graph edge. Writes edges to Cassandra graph.\n')
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('trustgraph.storage.triples.cassandra.write.TrustGraph')
|
||||
async def test_store_triples_table_switching_between_different_tables(self, mock_trustgraph):
|
||||
"""Test table switching when different tables are used in sequence"""
|
||||
taskgroup_mock = MagicMock()
|
||||
mock_tg_instance1 = MagicMock()
|
||||
mock_tg_instance2 = MagicMock()
|
||||
mock_trustgraph.side_effect = [mock_tg_instance1, mock_tg_instance2]
|
||||
|
||||
processor = Processor(taskgroup=taskgroup_mock)
|
||||
|
||||
# First message with table1
|
||||
mock_message1 = MagicMock()
|
||||
mock_message1.metadata.user = 'user1'
|
||||
mock_message1.metadata.collection = 'collection1'
|
||||
mock_message1.triples = []
|
||||
|
||||
await processor.store_triples(mock_message1)
|
||||
assert processor.table == ('user1', 'collection1')
|
||||
assert processor.tg == mock_tg_instance1
|
||||
|
||||
# Second message with different table
|
||||
mock_message2 = MagicMock()
|
||||
mock_message2.metadata.user = 'user2'
|
||||
mock_message2.metadata.collection = 'collection2'
|
||||
mock_message2.triples = []
|
||||
|
||||
await processor.store_triples(mock_message2)
|
||||
assert processor.table == ('user2', 'collection2')
|
||||
assert processor.tg == mock_tg_instance2
|
||||
|
||||
# Verify TrustGraph was created twice for different tables
|
||||
assert mock_trustgraph.call_count == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('trustgraph.storage.triples.cassandra.write.TrustGraph')
|
||||
async def test_store_triples_with_special_characters_in_values(self, mock_trustgraph):
|
||||
"""Test storing triples with special characters and unicode"""
|
||||
taskgroup_mock = MagicMock()
|
||||
mock_tg_instance = MagicMock()
|
||||
mock_trustgraph.return_value = mock_tg_instance
|
||||
|
||||
processor = Processor(taskgroup=taskgroup_mock)
|
||||
|
||||
# Create triple with special characters
|
||||
triple = MagicMock()
|
||||
triple.s.value = 'subject with spaces & symbols'
|
||||
triple.p.value = 'predicate:with/colons'
|
||||
triple.o.value = 'object with "quotes" and unicode: ñáéíóú'
|
||||
|
||||
mock_message = MagicMock()
|
||||
mock_message.metadata.user = 'test_user'
|
||||
mock_message.metadata.collection = 'test_collection'
|
||||
mock_message.triples = [triple]
|
||||
|
||||
await processor.store_triples(mock_message)
|
||||
|
||||
# Verify the triple was inserted with special characters preserved
|
||||
mock_tg_instance.insert.assert_called_once_with(
|
||||
'subject with spaces & symbols',
|
||||
'predicate:with/colons',
|
||||
'object with "quotes" and unicode: ñáéíóú'
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('trustgraph.storage.triples.cassandra.write.TrustGraph')
|
||||
async def test_store_triples_preserves_old_table_on_exception(self, mock_trustgraph):
|
||||
"""Test that table remains unchanged when TrustGraph creation fails"""
|
||||
taskgroup_mock = MagicMock()
|
||||
|
||||
processor = Processor(taskgroup=taskgroup_mock)
|
||||
|
||||
# Set an initial table
|
||||
processor.table = ('old_user', 'old_collection')
|
||||
|
||||
# Mock TrustGraph to raise exception
|
||||
mock_trustgraph.side_effect = Exception("Connection failed")
|
||||
|
||||
mock_message = MagicMock()
|
||||
mock_message.metadata.user = 'new_user'
|
||||
mock_message.metadata.collection = 'new_collection'
|
||||
mock_message.triples = []
|
||||
|
||||
with pytest.raises(Exception, match="Connection failed"):
|
||||
await processor.store_triples(mock_message)
|
||||
|
||||
# Table should remain unchanged since self.table = table happens after try/except
|
||||
assert processor.table == ('old_user', 'old_collection')
|
||||
# TrustGraph should be set to None though
|
||||
assert processor.tg is None
|
||||
3
tests/unit/test_text_completion/__init__.py
Normal file
3
tests/unit/test_text_completion/__init__.py
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
"""
|
||||
Unit tests for text completion services
|
||||
"""
|
||||
3
tests/unit/test_text_completion/common/__init__.py
Normal file
3
tests/unit/test_text_completion/common/__init__.py
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
"""
|
||||
Common utilities for text completion tests
|
||||
"""
|
||||
69
tests/unit/test_text_completion/common/base_test_cases.py
Normal file
69
tests/unit/test_text_completion/common/base_test_cases.py
Normal file
|
|
@ -0,0 +1,69 @@
|
|||
"""
|
||||
Base test patterns that can be reused across different text completion models
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from unittest import IsolatedAsyncioTestCase
|
||||
|
||||
|
||||
class BaseTextCompletionTestCase(IsolatedAsyncioTestCase, ABC):
|
||||
"""
|
||||
Base test class for text completion processors
|
||||
Provides common test patterns that can be reused
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def get_processor_class(self):
|
||||
"""Return the processor class to test"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_base_config(self):
|
||||
"""Return base configuration for the processor"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_mock_patches(self):
|
||||
"""Return list of patch decorators for mocking dependencies"""
|
||||
pass
|
||||
|
||||
def create_base_config(self, **overrides):
|
||||
"""Create base config with optional overrides"""
|
||||
config = self.get_base_config()
|
||||
config.update(overrides)
|
||||
return config
|
||||
|
||||
def create_mock_llm_result(self, text="Test response", in_token=10, out_token=5):
|
||||
"""Create a mock LLM result"""
|
||||
from trustgraph.base import LlmResult
|
||||
return LlmResult(text=text, in_token=in_token, out_token=out_token)
|
||||
|
||||
|
||||
class CommonTestPatterns:
|
||||
"""
|
||||
Common test patterns that can be used across different models
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def basic_initialization_test_pattern(test_instance):
|
||||
"""
|
||||
Test pattern for basic processor initialization
|
||||
test_instance should be a BaseTextCompletionTestCase
|
||||
"""
|
||||
# This would contain the common pattern for initialization testing
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def successful_generation_test_pattern(test_instance):
|
||||
"""
|
||||
Test pattern for successful content generation
|
||||
"""
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def error_handling_test_pattern(test_instance):
|
||||
"""
|
||||
Test pattern for error handling
|
||||
"""
|
||||
pass
|
||||
53
tests/unit/test_text_completion/common/mock_helpers.py
Normal file
53
tests/unit/test_text_completion/common/mock_helpers.py
Normal file
|
|
@ -0,0 +1,53 @@
|
|||
"""
|
||||
Common mocking utilities for text completion tests
|
||||
"""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
|
||||
class CommonMocks:
|
||||
"""Common mock objects used across text completion tests"""
|
||||
|
||||
@staticmethod
|
||||
def create_mock_async_processor_init():
|
||||
"""Create mock for AsyncProcessor.__init__"""
|
||||
mock = MagicMock()
|
||||
mock.return_value = None
|
||||
return mock
|
||||
|
||||
@staticmethod
|
||||
def create_mock_llm_service_init():
|
||||
"""Create mock for LlmService.__init__"""
|
||||
mock = MagicMock()
|
||||
mock.return_value = None
|
||||
return mock
|
||||
|
||||
@staticmethod
|
||||
def create_mock_response(text="Test response", prompt_tokens=10, completion_tokens=5):
|
||||
"""Create a mock response object"""
|
||||
response = MagicMock()
|
||||
response.text = text
|
||||
response.usage_metadata.prompt_token_count = prompt_tokens
|
||||
response.usage_metadata.candidates_token_count = completion_tokens
|
||||
return response
|
||||
|
||||
@staticmethod
|
||||
def create_basic_config():
|
||||
"""Create basic config with required fields"""
|
||||
return {
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
|
||||
class MockPatches:
|
||||
"""Common patch decorators for different services"""
|
||||
|
||||
@staticmethod
|
||||
def get_base_patches():
|
||||
"""Get patches that are common to all processors"""
|
||||
return [
|
||||
'trustgraph.base.async_processor.AsyncProcessor.__init__',
|
||||
'trustgraph.base.llm_service.LlmService.__init__'
|
||||
]
|
||||
499
tests/unit/test_text_completion/conftest.py
Normal file
499
tests/unit/test_text_completion/conftest.py
Normal file
|
|
@ -0,0 +1,499 @@
|
|||
"""
|
||||
Pytest configuration and fixtures for text completion tests
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, AsyncMock
|
||||
from trustgraph.base import LlmResult
|
||||
|
||||
|
||||
# === Common Fixtures for All Text Completion Models ===
|
||||
|
||||
@pytest.fixture
|
||||
def base_processor_config():
|
||||
"""Base configuration required by all processors"""
|
||||
return {
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_llm_result():
|
||||
"""Sample LlmResult for testing"""
|
||||
return LlmResult(
|
||||
text="Test response",
|
||||
in_token=10,
|
||||
out_token=5
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_async_processor_init():
|
||||
"""Mock AsyncProcessor.__init__ to avoid infrastructure requirements"""
|
||||
mock = MagicMock()
|
||||
mock.return_value = None
|
||||
return mock
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_llm_service_init():
|
||||
"""Mock LlmService.__init__ to avoid infrastructure requirements"""
|
||||
mock = MagicMock()
|
||||
mock.return_value = None
|
||||
return mock
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_prometheus_metrics():
|
||||
"""Mock Prometheus metrics"""
|
||||
mock_metric = MagicMock()
|
||||
mock_metric.labels.return_value.time.return_value = MagicMock()
|
||||
return mock_metric
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_pulsar_consumer():
|
||||
"""Mock Pulsar consumer for integration testing"""
|
||||
return AsyncMock()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_pulsar_producer():
|
||||
"""Mock Pulsar producer for integration testing"""
|
||||
return AsyncMock()
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_env_vars(monkeypatch):
|
||||
"""Mock environment variables for testing"""
|
||||
monkeypatch.setenv("GOOGLE_CLOUD_PROJECT", "test-project")
|
||||
monkeypatch.setenv("GOOGLE_APPLICATION_CREDENTIALS", "/path/to/test-credentials.json")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_async_context_manager():
|
||||
"""Mock async context manager for testing"""
|
||||
class MockAsyncContextManager:
|
||||
def __init__(self, return_value):
|
||||
self.return_value = return_value
|
||||
|
||||
async def __aenter__(self):
|
||||
return self.return_value
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
pass
|
||||
|
||||
return MockAsyncContextManager
|
||||
|
||||
|
||||
# === VertexAI Specific Fixtures ===
|
||||
|
||||
@pytest.fixture
|
||||
def mock_vertexai_credentials():
|
||||
"""Mock Google Cloud service account credentials"""
|
||||
return MagicMock()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_vertexai_model():
|
||||
"""Mock VertexAI GenerativeModel"""
|
||||
mock_model = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.text = "Test response"
|
||||
mock_response.usage_metadata.prompt_token_count = 10
|
||||
mock_response.usage_metadata.candidates_token_count = 5
|
||||
mock_model.generate_content.return_value = mock_response
|
||||
return mock_model
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def vertexai_processor_config(base_processor_config):
|
||||
"""Default configuration for VertexAI processor"""
|
||||
config = base_processor_config.copy()
|
||||
config.update({
|
||||
'region': 'us-central1',
|
||||
'model': 'gemini-2.0-flash-001',
|
||||
'temperature': 0.0,
|
||||
'max_output': 8192,
|
||||
'private_key': 'private.json'
|
||||
})
|
||||
return config
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_safety_settings():
|
||||
"""Mock safety settings for VertexAI"""
|
||||
safety_settings = []
|
||||
for i in range(4): # 4 safety categories
|
||||
setting = MagicMock()
|
||||
setting.category = f"HARM_CATEGORY_{i}"
|
||||
setting.threshold = "BLOCK_MEDIUM_AND_ABOVE"
|
||||
safety_settings.append(setting)
|
||||
|
||||
return safety_settings
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_generation_config():
|
||||
"""Mock generation configuration for VertexAI"""
|
||||
config = MagicMock()
|
||||
config.temperature = 0.0
|
||||
config.max_output_tokens = 8192
|
||||
config.top_p = 1.0
|
||||
config.top_k = 10
|
||||
config.candidate_count = 1
|
||||
return config
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_vertexai_exception():
|
||||
"""Mock VertexAI exceptions"""
|
||||
from google.api_core.exceptions import ResourceExhausted
|
||||
return ResourceExhausted("Test resource exhausted error")
|
||||
|
||||
|
||||
# === Ollama Specific Fixtures ===
|
||||
|
||||
@pytest.fixture
|
||||
def ollama_processor_config(base_processor_config):
|
||||
"""Default configuration for Ollama processor"""
|
||||
config = base_processor_config.copy()
|
||||
config.update({
|
||||
'model': 'llama2',
|
||||
'temperature': 0.0,
|
||||
'max_output': 8192,
|
||||
'host': 'localhost',
|
||||
'port': 11434
|
||||
})
|
||||
return config
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_ollama_client():
|
||||
"""Mock Ollama client"""
|
||||
mock_client = MagicMock()
|
||||
mock_response = {
|
||||
'response': 'Test response from Ollama',
|
||||
'done': True,
|
||||
'eval_count': 5,
|
||||
'prompt_eval_count': 10
|
||||
}
|
||||
mock_client.generate.return_value = mock_response
|
||||
return mock_client
|
||||
|
||||
|
||||
# === OpenAI Specific Fixtures ===
|
||||
|
||||
@pytest.fixture
|
||||
def openai_processor_config(base_processor_config):
|
||||
"""Default configuration for OpenAI processor"""
|
||||
config = base_processor_config.copy()
|
||||
config.update({
|
||||
'model': 'gpt-3.5-turbo',
|
||||
'api_key': 'test-api-key',
|
||||
'url': 'https://api.openai.com/v1',
|
||||
'temperature': 0.0,
|
||||
'max_output': 4096
|
||||
})
|
||||
return config
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_openai_client():
|
||||
"""Mock OpenAI client"""
|
||||
mock_client = MagicMock()
|
||||
|
||||
# Mock the response structure
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices = [MagicMock()]
|
||||
mock_response.choices[0].message.content = "Test response from OpenAI"
|
||||
mock_response.usage.prompt_tokens = 15
|
||||
mock_response.usage.completion_tokens = 8
|
||||
|
||||
mock_client.chat.completions.create.return_value = mock_response
|
||||
return mock_client
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_openai_rate_limit_error():
|
||||
"""Mock OpenAI rate limit error"""
|
||||
from openai import RateLimitError
|
||||
return RateLimitError("Rate limit exceeded", response=MagicMock(), body=None)
|
||||
|
||||
|
||||
# === Azure OpenAI Specific Fixtures ===
|
||||
|
||||
@pytest.fixture
|
||||
def azure_openai_processor_config(base_processor_config):
|
||||
"""Default configuration for Azure OpenAI processor"""
|
||||
config = base_processor_config.copy()
|
||||
config.update({
|
||||
'model': 'gpt-4',
|
||||
'endpoint': 'https://test.openai.azure.com/',
|
||||
'token': 'test-token',
|
||||
'api_version': '2024-12-01-preview',
|
||||
'temperature': 0.0,
|
||||
'max_output': 4192
|
||||
})
|
||||
return config
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_azure_openai_client():
|
||||
"""Mock Azure OpenAI client"""
|
||||
mock_client = MagicMock()
|
||||
|
||||
# Mock the response structure
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices = [MagicMock()]
|
||||
mock_response.choices[0].message.content = "Test response from Azure OpenAI"
|
||||
mock_response.usage.prompt_tokens = 20
|
||||
mock_response.usage.completion_tokens = 10
|
||||
|
||||
mock_client.chat.completions.create.return_value = mock_response
|
||||
return mock_client
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_azure_openai_rate_limit_error():
|
||||
"""Mock Azure OpenAI rate limit error"""
|
||||
from openai import RateLimitError
|
||||
return RateLimitError("Rate limit exceeded", response=MagicMock(), body=None)
|
||||
|
||||
|
||||
# === Azure Specific Fixtures ===
|
||||
|
||||
@pytest.fixture
|
||||
def azure_processor_config(base_processor_config):
|
||||
"""Default configuration for Azure processor"""
|
||||
config = base_processor_config.copy()
|
||||
config.update({
|
||||
'endpoint': 'https://test.inference.ai.azure.com/v1/chat/completions',
|
||||
'token': 'test-token',
|
||||
'temperature': 0.0,
|
||||
'max_output': 4192
|
||||
})
|
||||
return config
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_azure_requests():
|
||||
"""Mock requests for Azure processor"""
|
||||
mock_requests = MagicMock()
|
||||
|
||||
# Mock successful response
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {
|
||||
'choices': [{
|
||||
'message': {
|
||||
'content': 'Test response from Azure'
|
||||
}
|
||||
}],
|
||||
'usage': {
|
||||
'prompt_tokens': 18,
|
||||
'completion_tokens': 9
|
||||
}
|
||||
}
|
||||
mock_requests.post.return_value = mock_response
|
||||
return mock_requests
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_azure_rate_limit_response():
|
||||
"""Mock Azure rate limit response"""
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 429
|
||||
return mock_response
|
||||
|
||||
|
||||
# === Claude Specific Fixtures ===
|
||||
|
||||
@pytest.fixture
|
||||
def claude_processor_config(base_processor_config):
|
||||
"""Default configuration for Claude processor"""
|
||||
config = base_processor_config.copy()
|
||||
config.update({
|
||||
'model': 'claude-3-5-sonnet-20240620',
|
||||
'api_key': 'test-api-key',
|
||||
'temperature': 0.0,
|
||||
'max_output': 8192
|
||||
})
|
||||
return config
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_claude_client():
|
||||
"""Mock Claude (Anthropic) client"""
|
||||
mock_client = MagicMock()
|
||||
|
||||
# Mock the response structure
|
||||
mock_response = MagicMock()
|
||||
mock_response.content = [MagicMock()]
|
||||
mock_response.content[0].text = "Test response from Claude"
|
||||
mock_response.usage.input_tokens = 22
|
||||
mock_response.usage.output_tokens = 12
|
||||
|
||||
mock_client.messages.create.return_value = mock_response
|
||||
return mock_client
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_claude_rate_limit_error():
|
||||
"""Mock Claude rate limit error"""
|
||||
import anthropic
|
||||
return anthropic.RateLimitError("Rate limit exceeded", response=MagicMock(), body=None)
|
||||
|
||||
|
||||
# === vLLM Specific Fixtures ===
|
||||
|
||||
@pytest.fixture
|
||||
def vllm_processor_config(base_processor_config):
|
||||
"""Default configuration for vLLM processor"""
|
||||
config = base_processor_config.copy()
|
||||
config.update({
|
||||
'model': 'TheBloke/Mistral-7B-v0.1-AWQ',
|
||||
'url': 'http://vllm-service:8899/v1',
|
||||
'temperature': 0.0,
|
||||
'max_output': 2048
|
||||
})
|
||||
return config
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_vllm_session():
|
||||
"""Mock aiohttp ClientSession for vLLM"""
|
||||
mock_session = MagicMock()
|
||||
|
||||
# Mock successful response
|
||||
mock_response = MagicMock()
|
||||
mock_response.status = 200
|
||||
mock_response.json = AsyncMock(return_value={
|
||||
'choices': [{
|
||||
'text': 'Test response from vLLM'
|
||||
}],
|
||||
'usage': {
|
||||
'prompt_tokens': 16,
|
||||
'completion_tokens': 8
|
||||
}
|
||||
})
|
||||
|
||||
# Mock the async context manager
|
||||
mock_session.post.return_value.__aenter__.return_value = mock_response
|
||||
mock_session.post.return_value.__aexit__.return_value = None
|
||||
|
||||
return mock_session
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_vllm_error_response():
|
||||
"""Mock vLLM error response"""
|
||||
mock_response = MagicMock()
|
||||
mock_response.status = 500
|
||||
return mock_response
|
||||
|
||||
|
||||
# === Cohere Specific Fixtures ===
|
||||
|
||||
@pytest.fixture
|
||||
def cohere_processor_config(base_processor_config):
|
||||
"""Default configuration for Cohere processor"""
|
||||
config = base_processor_config.copy()
|
||||
config.update({
|
||||
'model': 'c4ai-aya-23-8b',
|
||||
'api_key': 'test-api-key',
|
||||
'temperature': 0.0
|
||||
})
|
||||
return config
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_cohere_client():
|
||||
"""Mock Cohere client"""
|
||||
mock_client = MagicMock()
|
||||
|
||||
# Mock the response structure
|
||||
mock_output = MagicMock()
|
||||
mock_output.text = "Test response from Cohere"
|
||||
mock_output.meta.billed_units.input_tokens = 18
|
||||
mock_output.meta.billed_units.output_tokens = 10
|
||||
|
||||
mock_client.chat.return_value = mock_output
|
||||
return mock_client
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_cohere_rate_limit_error():
|
||||
"""Mock Cohere rate limit error"""
|
||||
import cohere
|
||||
return cohere.TooManyRequestsError("Rate limit exceeded")
|
||||
|
||||
|
||||
# === Google AI Studio Specific Fixtures ===
|
||||
|
||||
@pytest.fixture
|
||||
def googleaistudio_processor_config(base_processor_config):
|
||||
"""Default configuration for Google AI Studio processor"""
|
||||
config = base_processor_config.copy()
|
||||
config.update({
|
||||
'model': 'gemini-2.0-flash-001',
|
||||
'api_key': 'test-api-key',
|
||||
'temperature': 0.0,
|
||||
'max_output': 8192
|
||||
})
|
||||
return config
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_googleaistudio_client():
|
||||
"""Mock Google AI Studio client"""
|
||||
mock_client = MagicMock()
|
||||
|
||||
# Mock the response structure
|
||||
mock_response = MagicMock()
|
||||
mock_response.text = "Test response from Google AI Studio"
|
||||
mock_response.usage_metadata.prompt_token_count = 20
|
||||
mock_response.usage_metadata.candidates_token_count = 12
|
||||
|
||||
mock_client.models.generate_content.return_value = mock_response
|
||||
return mock_client
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_googleaistudio_rate_limit_error():
|
||||
"""Mock Google AI Studio rate limit error"""
|
||||
from google.api_core.exceptions import ResourceExhausted
|
||||
return ResourceExhausted("Rate limit exceeded")
|
||||
|
||||
|
||||
# === LlamaFile Specific Fixtures ===
|
||||
|
||||
@pytest.fixture
|
||||
def llamafile_processor_config(base_processor_config):
|
||||
"""Default configuration for LlamaFile processor"""
|
||||
config = base_processor_config.copy()
|
||||
config.update({
|
||||
'model': 'LLaMA_CPP',
|
||||
'llamafile': 'http://localhost:8080/v1',
|
||||
'temperature': 0.0,
|
||||
'max_output': 4096
|
||||
})
|
||||
return config
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_llamafile_client():
|
||||
"""Mock OpenAI client for LlamaFile"""
|
||||
mock_client = MagicMock()
|
||||
|
||||
# Mock the response structure
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices = [MagicMock()]
|
||||
mock_response.choices[0].message.content = "Test response from LlamaFile"
|
||||
mock_response.usage.prompt_tokens = 14
|
||||
mock_response.usage.completion_tokens = 8
|
||||
|
||||
mock_client.chat.completions.create.return_value = mock_response
|
||||
return mock_client
|
||||
407
tests/unit/test_text_completion/test_azure_openai_processor.py
Normal file
407
tests/unit/test_text_completion/test_azure_openai_processor.py
Normal file
|
|
@ -0,0 +1,407 @@
|
|||
"""
|
||||
Unit tests for trustgraph.model.text_completion.azure_openai
|
||||
Following the same successful pattern as previous tests
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from unittest import IsolatedAsyncioTestCase
|
||||
|
||||
# Import the service under test
|
||||
from trustgraph.model.text_completion.azure_openai.llm import Processor
|
||||
from trustgraph.base import LlmResult
|
||||
from trustgraph.exceptions import TooManyRequests
|
||||
|
||||
|
||||
class TestAzureOpenAIProcessorSimple(IsolatedAsyncioTestCase):
|
||||
"""Test Azure OpenAI processor functionality"""
|
||||
|
||||
@patch('trustgraph.model.text_completion.azure_openai.llm.AzureOpenAI')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_processor_initialization_basic(self, mock_llm_init, mock_async_init, mock_azure_openai_class):
|
||||
"""Test basic processor initialization"""
|
||||
# Arrange
|
||||
mock_azure_client = MagicMock()
|
||||
mock_azure_openai_class.return_value = mock_azure_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'gpt-4',
|
||||
'endpoint': 'https://test.openai.azure.com/',
|
||||
'token': 'test-token',
|
||||
'api_version': '2024-12-01-preview',
|
||||
'temperature': 0.0,
|
||||
'max_output': 4192,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
# Act
|
||||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
assert processor.model == 'gpt-4'
|
||||
assert processor.temperature == 0.0
|
||||
assert processor.max_output == 4192
|
||||
assert hasattr(processor, 'openai')
|
||||
mock_azure_openai_class.assert_called_once_with(
|
||||
api_key='test-token',
|
||||
api_version='2024-12-01-preview',
|
||||
azure_endpoint='https://test.openai.azure.com/'
|
||||
)
|
||||
|
||||
@patch('trustgraph.model.text_completion.azure_openai.llm.AzureOpenAI')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_success(self, mock_llm_init, mock_async_init, mock_azure_openai_class):
|
||||
"""Test successful content generation"""
|
||||
# Arrange
|
||||
mock_azure_client = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices = [MagicMock()]
|
||||
mock_response.choices[0].message.content = "Generated response from Azure OpenAI"
|
||||
mock_response.usage.prompt_tokens = 25
|
||||
mock_response.usage.completion_tokens = 15
|
||||
|
||||
mock_azure_client.chat.completions.create.return_value = mock_response
|
||||
mock_azure_openai_class.return_value = mock_azure_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'gpt-4',
|
||||
'endpoint': 'https://test.openai.azure.com/',
|
||||
'token': 'test-token',
|
||||
'api_version': '2024-12-01-preview',
|
||||
'temperature': 0.0,
|
||||
'max_output': 4192,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act
|
||||
result = await processor.generate_content("System prompt", "User prompt")
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Generated response from Azure OpenAI"
|
||||
assert result.in_token == 25
|
||||
assert result.out_token == 15
|
||||
assert result.model == 'gpt-4'
|
||||
|
||||
# Verify the Azure OpenAI API call
|
||||
mock_azure_client.chat.completions.create.assert_called_once_with(
|
||||
model='gpt-4',
|
||||
messages=[{
|
||||
"role": "user",
|
||||
"content": [{
|
||||
"type": "text",
|
||||
"text": "System prompt\n\nUser prompt"
|
||||
}]
|
||||
}],
|
||||
temperature=0.0,
|
||||
max_tokens=4192,
|
||||
top_p=1
|
||||
)
|
||||
|
||||
@patch('trustgraph.model.text_completion.azure_openai.llm.AzureOpenAI')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_rate_limit_error(self, mock_llm_init, mock_async_init, mock_azure_openai_class):
|
||||
"""Test rate limit error handling"""
|
||||
# Arrange
|
||||
from openai import RateLimitError
|
||||
|
||||
mock_azure_client = MagicMock()
|
||||
mock_azure_client.chat.completions.create.side_effect = RateLimitError("Rate limit exceeded", response=MagicMock(), body=None)
|
||||
mock_azure_openai_class.return_value = mock_azure_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'gpt-4',
|
||||
'endpoint': 'https://test.openai.azure.com/',
|
||||
'token': 'test-token',
|
||||
'api_version': '2024-12-01-preview',
|
||||
'temperature': 0.0,
|
||||
'max_output': 4192,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(TooManyRequests):
|
||||
await processor.generate_content("System prompt", "User prompt")
|
||||
|
||||
@patch('trustgraph.model.text_completion.azure_openai.llm.AzureOpenAI')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_generic_exception(self, mock_llm_init, mock_async_init, mock_azure_openai_class):
|
||||
"""Test handling of generic exceptions"""
|
||||
# Arrange
|
||||
mock_azure_client = MagicMock()
|
||||
mock_azure_client.chat.completions.create.side_effect = Exception("Azure API connection error")
|
||||
mock_azure_openai_class.return_value = mock_azure_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'gpt-4',
|
||||
'endpoint': 'https://test.openai.azure.com/',
|
||||
'token': 'test-token',
|
||||
'api_version': '2024-12-01-preview',
|
||||
'temperature': 0.0,
|
||||
'max_output': 4192,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(Exception, match="Azure API connection error"):
|
||||
await processor.generate_content("System prompt", "User prompt")
|
||||
|
||||
@patch('trustgraph.model.text_completion.azure_openai.llm.AzureOpenAI')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_processor_initialization_without_endpoint(self, mock_llm_init, mock_async_init, mock_azure_openai_class):
|
||||
"""Test processor initialization without endpoint (should fail)"""
|
||||
# Arrange
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'gpt-4',
|
||||
'endpoint': None, # No endpoint provided
|
||||
'token': 'test-token',
|
||||
'api_version': '2024-12-01-preview',
|
||||
'temperature': 0.0,
|
||||
'max_output': 4192,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(RuntimeError, match="Azure endpoint not specified"):
|
||||
processor = Processor(**config)
|
||||
|
||||
@patch('trustgraph.model.text_completion.azure_openai.llm.AzureOpenAI')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_processor_initialization_without_token(self, mock_llm_init, mock_async_init, mock_azure_openai_class):
|
||||
"""Test processor initialization without token (should fail)"""
|
||||
# Arrange
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'gpt-4',
|
||||
'endpoint': 'https://test.openai.azure.com/',
|
||||
'token': None, # No token provided
|
||||
'api_version': '2024-12-01-preview',
|
||||
'temperature': 0.0,
|
||||
'max_output': 4192,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(RuntimeError, match="Azure token not specified"):
|
||||
processor = Processor(**config)
|
||||
|
||||
@patch('trustgraph.model.text_completion.azure_openai.llm.AzureOpenAI')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_processor_initialization_with_custom_parameters(self, mock_llm_init, mock_async_init, mock_azure_openai_class):
|
||||
"""Test processor initialization with custom parameters"""
|
||||
# Arrange
|
||||
mock_azure_client = MagicMock()
|
||||
mock_azure_openai_class.return_value = mock_azure_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'gpt-35-turbo',
|
||||
'endpoint': 'https://custom.openai.azure.com/',
|
||||
'token': 'custom-token',
|
||||
'api_version': '2023-05-15',
|
||||
'temperature': 0.7,
|
||||
'max_output': 2048,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
# Act
|
||||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
assert processor.model == 'gpt-35-turbo'
|
||||
assert processor.temperature == 0.7
|
||||
assert processor.max_output == 2048
|
||||
mock_azure_openai_class.assert_called_once_with(
|
||||
api_key='custom-token',
|
||||
api_version='2023-05-15',
|
||||
azure_endpoint='https://custom.openai.azure.com/'
|
||||
)
|
||||
|
||||
@patch('trustgraph.model.text_completion.azure_openai.llm.AzureOpenAI')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_processor_initialization_with_defaults(self, mock_llm_init, mock_async_init, mock_azure_openai_class):
|
||||
"""Test processor initialization with default values"""
|
||||
# Arrange
|
||||
mock_azure_client = MagicMock()
|
||||
mock_azure_openai_class.return_value = mock_azure_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
# Only provide required fields, should use defaults
|
||||
config = {
|
||||
'endpoint': 'https://test.openai.azure.com/',
|
||||
'token': 'test-token',
|
||||
'model': 'gpt-4', # Required for Azure
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
# Act
|
||||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
assert processor.model == 'gpt-4'
|
||||
assert processor.temperature == 0.0 # default_temperature
|
||||
assert processor.max_output == 4192 # default_max_output
|
||||
mock_azure_openai_class.assert_called_once_with(
|
||||
api_key='test-token',
|
||||
api_version='2024-12-01-preview', # default_api
|
||||
azure_endpoint='https://test.openai.azure.com/'
|
||||
)
|
||||
|
||||
@patch('trustgraph.model.text_completion.azure_openai.llm.AzureOpenAI')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_empty_prompts(self, mock_llm_init, mock_async_init, mock_azure_openai_class):
|
||||
"""Test content generation with empty prompts"""
|
||||
# Arrange
|
||||
mock_azure_client = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices = [MagicMock()]
|
||||
mock_response.choices[0].message.content = "Default response"
|
||||
mock_response.usage.prompt_tokens = 2
|
||||
mock_response.usage.completion_tokens = 3
|
||||
|
||||
mock_azure_client.chat.completions.create.return_value = mock_response
|
||||
mock_azure_openai_class.return_value = mock_azure_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'gpt-4',
|
||||
'endpoint': 'https://test.openai.azure.com/',
|
||||
'token': 'test-token',
|
||||
'api_version': '2024-12-01-preview',
|
||||
'temperature': 0.0,
|
||||
'max_output': 4192,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act
|
||||
result = await processor.generate_content("", "")
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Default response"
|
||||
assert result.in_token == 2
|
||||
assert result.out_token == 3
|
||||
assert result.model == 'gpt-4'
|
||||
|
||||
# Verify the combined prompt is sent correctly
|
||||
call_args = mock_azure_client.chat.completions.create.call_args
|
||||
expected_prompt = "\n\n" # Empty system + "\n\n" + empty user
|
||||
assert call_args[1]['messages'][0]['content'][0]['text'] == expected_prompt
|
||||
|
||||
@patch('trustgraph.model.text_completion.azure_openai.llm.AzureOpenAI')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_message_structure(self, mock_llm_init, mock_async_init, mock_azure_openai_class):
|
||||
"""Test that Azure OpenAI messages are structured correctly"""
|
||||
# Arrange
|
||||
mock_azure_client = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices = [MagicMock()]
|
||||
mock_response.choices[0].message.content = "Response with proper structure"
|
||||
mock_response.usage.prompt_tokens = 30
|
||||
mock_response.usage.completion_tokens = 20
|
||||
|
||||
mock_azure_client.chat.completions.create.return_value = mock_response
|
||||
mock_azure_openai_class.return_value = mock_azure_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'gpt-4',
|
||||
'endpoint': 'https://test.openai.azure.com/',
|
||||
'token': 'test-token',
|
||||
'api_version': '2024-12-01-preview',
|
||||
'temperature': 0.5,
|
||||
'max_output': 1024,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act
|
||||
result = await processor.generate_content("You are a helpful assistant", "What is AI?")
|
||||
|
||||
# Assert
|
||||
assert result.text == "Response with proper structure"
|
||||
assert result.in_token == 30
|
||||
assert result.out_token == 20
|
||||
|
||||
# Verify the message structure matches Azure OpenAI Chat API format
|
||||
call_args = mock_azure_client.chat.completions.create.call_args
|
||||
messages = call_args[1]['messages']
|
||||
|
||||
assert len(messages) == 1
|
||||
assert messages[0]['role'] == 'user'
|
||||
assert messages[0]['content'][0]['type'] == 'text'
|
||||
assert messages[0]['content'][0]['text'] == "You are a helpful assistant\n\nWhat is AI?"
|
||||
|
||||
# Verify other parameters
|
||||
assert call_args[1]['model'] == 'gpt-4'
|
||||
assert call_args[1]['temperature'] == 0.5
|
||||
assert call_args[1]['max_tokens'] == 1024
|
||||
assert call_args[1]['top_p'] == 1
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pytest.main([__file__])
|
||||
463
tests/unit/test_text_completion/test_azure_processor.py
Normal file
463
tests/unit/test_text_completion/test_azure_processor.py
Normal file
|
|
@ -0,0 +1,463 @@
|
|||
"""
|
||||
Unit tests for trustgraph.model.text_completion.azure
|
||||
Following the same successful pattern as previous tests
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from unittest import IsolatedAsyncioTestCase
|
||||
|
||||
# Import the service under test
|
||||
from trustgraph.model.text_completion.azure.llm import Processor
|
||||
from trustgraph.base import LlmResult
|
||||
from trustgraph.exceptions import TooManyRequests
|
||||
|
||||
|
||||
class TestAzureProcessorSimple(IsolatedAsyncioTestCase):
|
||||
"""Test Azure processor functionality"""
|
||||
|
||||
@patch('trustgraph.model.text_completion.azure.llm.requests')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_processor_initialization_basic(self, mock_llm_init, mock_async_init, mock_requests):
|
||||
"""Test basic processor initialization"""
|
||||
# Arrange
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'endpoint': 'https://test.inference.ai.azure.com/v1/chat/completions',
|
||||
'token': 'test-token',
|
||||
'temperature': 0.0,
|
||||
'max_output': 4192,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
# Act
|
||||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
assert processor.endpoint == 'https://test.inference.ai.azure.com/v1/chat/completions'
|
||||
assert processor.token == 'test-token'
|
||||
assert processor.temperature == 0.0
|
||||
assert processor.max_output == 4192
|
||||
assert processor.model == 'AzureAI'
|
||||
|
||||
@patch('trustgraph.model.text_completion.azure.llm.requests')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_success(self, mock_llm_init, mock_async_init, mock_requests):
|
||||
"""Test successful content generation"""
|
||||
# Arrange
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {
|
||||
'choices': [{
|
||||
'message': {
|
||||
'content': 'Generated response from Azure'
|
||||
}
|
||||
}],
|
||||
'usage': {
|
||||
'prompt_tokens': 20,
|
||||
'completion_tokens': 12
|
||||
}
|
||||
}
|
||||
mock_requests.post.return_value = mock_response
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'endpoint': 'https://test.inference.ai.azure.com/v1/chat/completions',
|
||||
'token': 'test-token',
|
||||
'temperature': 0.0,
|
||||
'max_output': 4192,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act
|
||||
result = await processor.generate_content("System prompt", "User prompt")
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Generated response from Azure"
|
||||
assert result.in_token == 20
|
||||
assert result.out_token == 12
|
||||
assert result.model == 'AzureAI'
|
||||
|
||||
# Verify the API call was made correctly
|
||||
mock_requests.post.assert_called_once()
|
||||
call_args = mock_requests.post.call_args
|
||||
|
||||
# Check URL
|
||||
assert call_args[0][0] == 'https://test.inference.ai.azure.com/v1/chat/completions'
|
||||
|
||||
# Check headers
|
||||
headers = call_args[1]['headers']
|
||||
assert headers['Content-Type'] == 'application/json'
|
||||
assert headers['Authorization'] == 'Bearer test-token'
|
||||
|
||||
@patch('trustgraph.model.text_completion.azure.llm.requests')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_rate_limit_error(self, mock_llm_init, mock_async_init, mock_requests):
|
||||
"""Test rate limit error handling"""
|
||||
# Arrange
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 429
|
||||
mock_requests.post.return_value = mock_response
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'endpoint': 'https://test.inference.ai.azure.com/v1/chat/completions',
|
||||
'token': 'test-token',
|
||||
'temperature': 0.0,
|
||||
'max_output': 4192,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(TooManyRequests):
|
||||
await processor.generate_content("System prompt", "User prompt")
|
||||
|
||||
@patch('trustgraph.model.text_completion.azure.llm.requests')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_http_error(self, mock_llm_init, mock_async_init, mock_requests):
|
||||
"""Test HTTP error handling"""
|
||||
# Arrange
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 500
|
||||
mock_requests.post.return_value = mock_response
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'endpoint': 'https://test.inference.ai.azure.com/v1/chat/completions',
|
||||
'token': 'test-token',
|
||||
'temperature': 0.0,
|
||||
'max_output': 4192,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(RuntimeError, match="LLM failure"):
|
||||
await processor.generate_content("System prompt", "User prompt")
|
||||
|
||||
@patch('trustgraph.model.text_completion.azure.llm.requests')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_generic_exception(self, mock_llm_init, mock_async_init, mock_requests):
|
||||
"""Test handling of generic exceptions"""
|
||||
# Arrange
|
||||
mock_requests.post.side_effect = Exception("Connection error")
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'endpoint': 'https://test.inference.ai.azure.com/v1/chat/completions',
|
||||
'token': 'test-token',
|
||||
'temperature': 0.0,
|
||||
'max_output': 4192,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(Exception, match="Connection error"):
|
||||
await processor.generate_content("System prompt", "User prompt")
|
||||
|
||||
@patch('trustgraph.model.text_completion.azure.llm.requests')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_processor_initialization_without_endpoint(self, mock_llm_init, mock_async_init, mock_requests):
|
||||
"""Test processor initialization without endpoint (should fail)"""
|
||||
# Arrange
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'endpoint': None, # No endpoint provided
|
||||
'token': 'test-token',
|
||||
'temperature': 0.0,
|
||||
'max_output': 4192,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(RuntimeError, match="Azure endpoint not specified"):
|
||||
processor = Processor(**config)
|
||||
|
||||
@patch('trustgraph.model.text_completion.azure.llm.requests')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_processor_initialization_without_token(self, mock_llm_init, mock_async_init, mock_requests):
|
||||
"""Test processor initialization without token (should fail)"""
|
||||
# Arrange
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'endpoint': 'https://test.inference.ai.azure.com/v1/chat/completions',
|
||||
'token': None, # No token provided
|
||||
'temperature': 0.0,
|
||||
'max_output': 4192,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(RuntimeError, match="Azure token not specified"):
|
||||
processor = Processor(**config)
|
||||
|
||||
@patch('trustgraph.model.text_completion.azure.llm.requests')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_processor_initialization_with_custom_parameters(self, mock_llm_init, mock_async_init, mock_requests):
|
||||
"""Test processor initialization with custom parameters"""
|
||||
# Arrange
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'endpoint': 'https://custom.inference.ai.azure.com/v1/chat/completions',
|
||||
'token': 'custom-token',
|
||||
'temperature': 0.7,
|
||||
'max_output': 2048,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
# Act
|
||||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
assert processor.endpoint == 'https://custom.inference.ai.azure.com/v1/chat/completions'
|
||||
assert processor.token == 'custom-token'
|
||||
assert processor.temperature == 0.7
|
||||
assert processor.max_output == 2048
|
||||
assert processor.model == 'AzureAI'
|
||||
|
||||
@patch('trustgraph.model.text_completion.azure.llm.requests')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_processor_initialization_with_defaults(self, mock_llm_init, mock_async_init, mock_requests):
|
||||
"""Test processor initialization with default values"""
|
||||
# Arrange
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
# Only provide required fields, should use defaults
|
||||
config = {
|
||||
'endpoint': 'https://test.inference.ai.azure.com/v1/chat/completions',
|
||||
'token': 'test-token',
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
# Act
|
||||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
assert processor.endpoint == 'https://test.inference.ai.azure.com/v1/chat/completions'
|
||||
assert processor.token == 'test-token'
|
||||
assert processor.temperature == 0.0 # default_temperature
|
||||
assert processor.max_output == 4192 # default_max_output
|
||||
assert processor.model == 'AzureAI' # default_model
|
||||
|
||||
@patch('trustgraph.model.text_completion.azure.llm.requests')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_empty_prompts(self, mock_llm_init, mock_async_init, mock_requests):
|
||||
"""Test content generation with empty prompts"""
|
||||
# Arrange
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {
|
||||
'choices': [{
|
||||
'message': {
|
||||
'content': 'Default response'
|
||||
}
|
||||
}],
|
||||
'usage': {
|
||||
'prompt_tokens': 2,
|
||||
'completion_tokens': 3
|
||||
}
|
||||
}
|
||||
mock_requests.post.return_value = mock_response
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'endpoint': 'https://test.inference.ai.azure.com/v1/chat/completions',
|
||||
'token': 'test-token',
|
||||
'temperature': 0.0,
|
||||
'max_output': 4192,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act
|
||||
result = await processor.generate_content("", "")
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Default response"
|
||||
assert result.in_token == 2
|
||||
assert result.out_token == 3
|
||||
assert result.model == 'AzureAI'
|
||||
|
||||
@patch('trustgraph.model.text_completion.azure.llm.requests')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_build_prompt_structure(self, mock_llm_init, mock_async_init, mock_requests):
|
||||
"""Test that build_prompt creates correct message structure"""
|
||||
# Arrange
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {
|
||||
'choices': [{
|
||||
'message': {
|
||||
'content': 'Response with proper structure'
|
||||
}
|
||||
}],
|
||||
'usage': {
|
||||
'prompt_tokens': 25,
|
||||
'completion_tokens': 15
|
||||
}
|
||||
}
|
||||
mock_requests.post.return_value = mock_response
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'endpoint': 'https://test.inference.ai.azure.com/v1/chat/completions',
|
||||
'token': 'test-token',
|
||||
'temperature': 0.5,
|
||||
'max_output': 1024,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act
|
||||
result = await processor.generate_content("You are a helpful assistant", "What is AI?")
|
||||
|
||||
# Assert
|
||||
assert result.text == "Response with proper structure"
|
||||
assert result.in_token == 25
|
||||
assert result.out_token == 15
|
||||
|
||||
# Verify the request structure
|
||||
mock_requests.post.assert_called_once()
|
||||
call_args = mock_requests.post.call_args
|
||||
|
||||
# Parse the request body
|
||||
import json
|
||||
request_body = json.loads(call_args[1]['data'])
|
||||
|
||||
# Verify message structure
|
||||
assert 'messages' in request_body
|
||||
assert len(request_body['messages']) == 2
|
||||
|
||||
# Check system message
|
||||
assert request_body['messages'][0]['role'] == 'system'
|
||||
assert request_body['messages'][0]['content'] == 'You are a helpful assistant'
|
||||
|
||||
# Check user message
|
||||
assert request_body['messages'][1]['role'] == 'user'
|
||||
assert request_body['messages'][1]['content'] == 'What is AI?'
|
||||
|
||||
# Check parameters
|
||||
assert request_body['temperature'] == 0.5
|
||||
assert request_body['max_tokens'] == 1024
|
||||
assert request_body['top_p'] == 1
|
||||
|
||||
@patch('trustgraph.model.text_completion.azure.llm.requests')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_call_llm_method(self, mock_llm_init, mock_async_init, mock_requests):
|
||||
"""Test the call_llm method directly"""
|
||||
# Arrange
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {
|
||||
'choices': [{
|
||||
'message': {
|
||||
'content': 'Test response'
|
||||
}
|
||||
}],
|
||||
'usage': {
|
||||
'prompt_tokens': 10,
|
||||
'completion_tokens': 5
|
||||
}
|
||||
}
|
||||
mock_requests.post.return_value = mock_response
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'endpoint': 'https://test.inference.ai.azure.com/v1/chat/completions',
|
||||
'token': 'test-token',
|
||||
'temperature': 0.0,
|
||||
'max_output': 4192,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act
|
||||
result = processor.call_llm('{"test": "body"}')
|
||||
|
||||
# Assert
|
||||
assert result == mock_response.json.return_value
|
||||
|
||||
# Verify the request was made correctly
|
||||
mock_requests.post.assert_called_once_with(
|
||||
'https://test.inference.ai.azure.com/v1/chat/completions',
|
||||
data='{"test": "body"}',
|
||||
headers={
|
||||
'Content-Type': 'application/json',
|
||||
'Authorization': 'Bearer test-token'
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pytest.main([__file__])
|
||||
440
tests/unit/test_text_completion/test_claude_processor.py
Normal file
440
tests/unit/test_text_completion/test_claude_processor.py
Normal file
|
|
@ -0,0 +1,440 @@
|
|||
"""
|
||||
Unit tests for trustgraph.model.text_completion.claude
|
||||
Following the same successful pattern as previous tests
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from unittest import IsolatedAsyncioTestCase
|
||||
|
||||
# Import the service under test
|
||||
from trustgraph.model.text_completion.claude.llm import Processor
|
||||
from trustgraph.base import LlmResult
|
||||
from trustgraph.exceptions import TooManyRequests
|
||||
|
||||
|
||||
class TestClaudeProcessorSimple(IsolatedAsyncioTestCase):
|
||||
"""Test Claude processor functionality"""
|
||||
|
||||
@patch('trustgraph.model.text_completion.claude.llm.anthropic.Anthropic')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_processor_initialization_basic(self, mock_llm_init, mock_async_init, mock_anthropic_class):
|
||||
"""Test basic processor initialization"""
|
||||
# Arrange
|
||||
mock_claude_client = MagicMock()
|
||||
mock_anthropic_class.return_value = mock_claude_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'claude-3-5-sonnet-20240620',
|
||||
'api_key': 'test-api-key',
|
||||
'temperature': 0.0,
|
||||
'max_output': 8192,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
# Act
|
||||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
assert processor.model == 'claude-3-5-sonnet-20240620'
|
||||
assert processor.temperature == 0.0
|
||||
assert processor.max_output == 8192
|
||||
assert hasattr(processor, 'claude')
|
||||
mock_anthropic_class.assert_called_once_with(api_key='test-api-key')
|
||||
|
||||
@patch('trustgraph.model.text_completion.claude.llm.anthropic.Anthropic')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_success(self, mock_llm_init, mock_async_init, mock_anthropic_class):
|
||||
"""Test successful content generation"""
|
||||
# Arrange
|
||||
mock_claude_client = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.content = [MagicMock()]
|
||||
mock_response.content[0].text = "Generated response from Claude"
|
||||
mock_response.usage.input_tokens = 25
|
||||
mock_response.usage.output_tokens = 15
|
||||
|
||||
mock_claude_client.messages.create.return_value = mock_response
|
||||
mock_anthropic_class.return_value = mock_claude_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'claude-3-5-sonnet-20240620',
|
||||
'api_key': 'test-api-key',
|
||||
'temperature': 0.0,
|
||||
'max_output': 8192,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act
|
||||
result = await processor.generate_content("System prompt", "User prompt")
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Generated response from Claude"
|
||||
assert result.in_token == 25
|
||||
assert result.out_token == 15
|
||||
assert result.model == 'claude-3-5-sonnet-20240620'
|
||||
|
||||
# Verify the Claude API call
|
||||
mock_claude_client.messages.create.assert_called_once_with(
|
||||
model='claude-3-5-sonnet-20240620',
|
||||
max_tokens=8192,
|
||||
temperature=0.0,
|
||||
system="System prompt",
|
||||
messages=[{
|
||||
"role": "user",
|
||||
"content": [{
|
||||
"type": "text",
|
||||
"text": "User prompt"
|
||||
}]
|
||||
}]
|
||||
)
|
||||
|
||||
@patch('trustgraph.model.text_completion.claude.llm.anthropic.Anthropic')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_rate_limit_error(self, mock_llm_init, mock_async_init, mock_anthropic_class):
|
||||
"""Test rate limit error handling"""
|
||||
# Arrange
|
||||
import anthropic
|
||||
|
||||
mock_claude_client = MagicMock()
|
||||
mock_claude_client.messages.create.side_effect = anthropic.RateLimitError(
|
||||
"Rate limit exceeded",
|
||||
response=MagicMock(),
|
||||
body=None
|
||||
)
|
||||
mock_anthropic_class.return_value = mock_claude_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'claude-3-5-sonnet-20240620',
|
||||
'api_key': 'test-api-key',
|
||||
'temperature': 0.0,
|
||||
'max_output': 8192,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(TooManyRequests):
|
||||
await processor.generate_content("System prompt", "User prompt")
|
||||
|
||||
@patch('trustgraph.model.text_completion.claude.llm.anthropic.Anthropic')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_generic_exception(self, mock_llm_init, mock_async_init, mock_anthropic_class):
|
||||
"""Test handling of generic exceptions"""
|
||||
# Arrange
|
||||
mock_claude_client = MagicMock()
|
||||
mock_claude_client.messages.create.side_effect = Exception("API connection error")
|
||||
mock_anthropic_class.return_value = mock_claude_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'claude-3-5-sonnet-20240620',
|
||||
'api_key': 'test-api-key',
|
||||
'temperature': 0.0,
|
||||
'max_output': 8192,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(Exception, match="API connection error"):
|
||||
await processor.generate_content("System prompt", "User prompt")
|
||||
|
||||
@patch('trustgraph.model.text_completion.claude.llm.anthropic.Anthropic')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_processor_initialization_without_api_key(self, mock_llm_init, mock_async_init, mock_anthropic_class):
|
||||
"""Test processor initialization without API key (should fail)"""
|
||||
# Arrange
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'claude-3-5-sonnet-20240620',
|
||||
'api_key': None, # No API key provided
|
||||
'temperature': 0.0,
|
||||
'max_output': 8192,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(RuntimeError, match="Claude API key not specified"):
|
||||
processor = Processor(**config)
|
||||
|
||||
@patch('trustgraph.model.text_completion.claude.llm.anthropic.Anthropic')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_processor_initialization_with_custom_parameters(self, mock_llm_init, mock_async_init, mock_anthropic_class):
|
||||
"""Test processor initialization with custom parameters"""
|
||||
# Arrange
|
||||
mock_claude_client = MagicMock()
|
||||
mock_anthropic_class.return_value = mock_claude_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'claude-3-haiku-20240307',
|
||||
'api_key': 'custom-api-key',
|
||||
'temperature': 0.7,
|
||||
'max_output': 4096,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
# Act
|
||||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
assert processor.model == 'claude-3-haiku-20240307'
|
||||
assert processor.temperature == 0.7
|
||||
assert processor.max_output == 4096
|
||||
mock_anthropic_class.assert_called_once_with(api_key='custom-api-key')
|
||||
|
||||
@patch('trustgraph.model.text_completion.claude.llm.anthropic.Anthropic')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_processor_initialization_with_defaults(self, mock_llm_init, mock_async_init, mock_anthropic_class):
|
||||
"""Test processor initialization with default values"""
|
||||
# Arrange
|
||||
mock_claude_client = MagicMock()
|
||||
mock_anthropic_class.return_value = mock_claude_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
# Only provide required fields, should use defaults
|
||||
config = {
|
||||
'api_key': 'test-api-key',
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
# Act
|
||||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
assert processor.model == 'claude-3-5-sonnet-20240620' # default_model
|
||||
assert processor.temperature == 0.0 # default_temperature
|
||||
assert processor.max_output == 8192 # default_max_output
|
||||
mock_anthropic_class.assert_called_once_with(api_key='test-api-key')
|
||||
|
||||
@patch('trustgraph.model.text_completion.claude.llm.anthropic.Anthropic')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_empty_prompts(self, mock_llm_init, mock_async_init, mock_anthropic_class):
|
||||
"""Test content generation with empty prompts"""
|
||||
# Arrange
|
||||
mock_claude_client = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.content = [MagicMock()]
|
||||
mock_response.content[0].text = "Default response"
|
||||
mock_response.usage.input_tokens = 2
|
||||
mock_response.usage.output_tokens = 3
|
||||
|
||||
mock_claude_client.messages.create.return_value = mock_response
|
||||
mock_anthropic_class.return_value = mock_claude_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'claude-3-5-sonnet-20240620',
|
||||
'api_key': 'test-api-key',
|
||||
'temperature': 0.0,
|
||||
'max_output': 8192,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act
|
||||
result = await processor.generate_content("", "")
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Default response"
|
||||
assert result.in_token == 2
|
||||
assert result.out_token == 3
|
||||
assert result.model == 'claude-3-5-sonnet-20240620'
|
||||
|
||||
# Verify the system prompt and user content are handled correctly
|
||||
call_args = mock_claude_client.messages.create.call_args
|
||||
assert call_args[1]['system'] == ""
|
||||
assert call_args[1]['messages'][0]['content'][0]['text'] == ""
|
||||
|
||||
@patch('trustgraph.model.text_completion.claude.llm.anthropic.Anthropic')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_message_structure(self, mock_llm_init, mock_async_init, mock_anthropic_class):
|
||||
"""Test that Claude messages are structured correctly"""
|
||||
# Arrange
|
||||
mock_claude_client = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.content = [MagicMock()]
|
||||
mock_response.content[0].text = "Response with proper structure"
|
||||
mock_response.usage.input_tokens = 30
|
||||
mock_response.usage.output_tokens = 20
|
||||
|
||||
mock_claude_client.messages.create.return_value = mock_response
|
||||
mock_anthropic_class.return_value = mock_claude_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'claude-3-5-sonnet-20240620',
|
||||
'api_key': 'test-api-key',
|
||||
'temperature': 0.5,
|
||||
'max_output': 1024,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act
|
||||
result = await processor.generate_content("You are a helpful assistant", "What is AI?")
|
||||
|
||||
# Assert
|
||||
assert result.text == "Response with proper structure"
|
||||
assert result.in_token == 30
|
||||
assert result.out_token == 20
|
||||
|
||||
# Verify the message structure matches Claude API format
|
||||
call_args = mock_claude_client.messages.create.call_args
|
||||
|
||||
# Check system prompt
|
||||
assert call_args[1]['system'] == "You are a helpful assistant"
|
||||
|
||||
# Check user message structure
|
||||
messages = call_args[1]['messages']
|
||||
assert len(messages) == 1
|
||||
assert messages[0]['role'] == 'user'
|
||||
assert messages[0]['content'][0]['type'] == 'text'
|
||||
assert messages[0]['content'][0]['text'] == "What is AI?"
|
||||
|
||||
# Verify other parameters
|
||||
assert call_args[1]['model'] == 'claude-3-5-sonnet-20240620'
|
||||
assert call_args[1]['temperature'] == 0.5
|
||||
assert call_args[1]['max_tokens'] == 1024
|
||||
|
||||
@patch('trustgraph.model.text_completion.claude.llm.anthropic.Anthropic')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_multiple_content_blocks(self, mock_llm_init, mock_async_init, mock_anthropic_class):
|
||||
"""Test handling of multiple content blocks in response"""
|
||||
# Arrange
|
||||
mock_claude_client = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
|
||||
# Mock multiple content blocks (Claude can return multiple)
|
||||
mock_content_1 = MagicMock()
|
||||
mock_content_1.text = "First part of response"
|
||||
mock_content_2 = MagicMock()
|
||||
mock_content_2.text = "Second part of response"
|
||||
mock_response.content = [mock_content_1, mock_content_2]
|
||||
|
||||
mock_response.usage.input_tokens = 40
|
||||
mock_response.usage.output_tokens = 30
|
||||
|
||||
mock_claude_client.messages.create.return_value = mock_response
|
||||
mock_anthropic_class.return_value = mock_claude_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'claude-3-5-sonnet-20240620',
|
||||
'api_key': 'test-api-key',
|
||||
'temperature': 0.0,
|
||||
'max_output': 8192,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act
|
||||
result = await processor.generate_content("System prompt", "User prompt")
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
# Should take the first content block
|
||||
assert result.text == "First part of response"
|
||||
assert result.in_token == 40
|
||||
assert result.out_token == 30
|
||||
assert result.model == 'claude-3-5-sonnet-20240620'
|
||||
|
||||
@patch('trustgraph.model.text_completion.claude.llm.anthropic.Anthropic')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_claude_client_initialization(self, mock_llm_init, mock_async_init, mock_anthropic_class):
|
||||
"""Test that Claude client is initialized correctly"""
|
||||
# Arrange
|
||||
mock_claude_client = MagicMock()
|
||||
mock_anthropic_class.return_value = mock_claude_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'claude-3-opus-20240229',
|
||||
'api_key': 'sk-ant-test-key',
|
||||
'temperature': 0.0,
|
||||
'max_output': 8192,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
# Act
|
||||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
# Verify Anthropic client was called with correct API key
|
||||
mock_anthropic_class.assert_called_once_with(api_key='sk-ant-test-key')
|
||||
|
||||
# Verify processor has the client
|
||||
assert processor.claude == mock_claude_client
|
||||
assert processor.model == 'claude-3-opus-20240229'
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pytest.main([__file__])
|
||||
447
tests/unit/test_text_completion/test_cohere_processor.py
Normal file
447
tests/unit/test_text_completion/test_cohere_processor.py
Normal file
|
|
@ -0,0 +1,447 @@
|
|||
"""
|
||||
Unit tests for trustgraph.model.text_completion.cohere
|
||||
Following the same successful pattern as previous tests
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from unittest import IsolatedAsyncioTestCase
|
||||
|
||||
# Import the service under test
|
||||
from trustgraph.model.text_completion.cohere.llm import Processor
|
||||
from trustgraph.base import LlmResult
|
||||
from trustgraph.exceptions import TooManyRequests
|
||||
|
||||
|
||||
class TestCohereProcessorSimple(IsolatedAsyncioTestCase):
|
||||
"""Test Cohere processor functionality"""
|
||||
|
||||
@patch('trustgraph.model.text_completion.cohere.llm.cohere.Client')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_processor_initialization_basic(self, mock_llm_init, mock_async_init, mock_cohere_class):
|
||||
"""Test basic processor initialization"""
|
||||
# Arrange
|
||||
mock_cohere_client = MagicMock()
|
||||
mock_cohere_class.return_value = mock_cohere_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'c4ai-aya-23-8b',
|
||||
'api_key': 'test-api-key',
|
||||
'temperature': 0.0,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
# Act
|
||||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
assert processor.model == 'c4ai-aya-23-8b'
|
||||
assert processor.temperature == 0.0
|
||||
assert hasattr(processor, 'cohere')
|
||||
mock_cohere_class.assert_called_once_with(api_key='test-api-key')
|
||||
|
||||
@patch('trustgraph.model.text_completion.cohere.llm.cohere.Client')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_success(self, mock_llm_init, mock_async_init, mock_cohere_class):
|
||||
"""Test successful content generation"""
|
||||
# Arrange
|
||||
mock_cohere_client = MagicMock()
|
||||
mock_output = MagicMock()
|
||||
mock_output.text = "Generated response from Cohere"
|
||||
mock_output.meta.billed_units.input_tokens = 25
|
||||
mock_output.meta.billed_units.output_tokens = 15
|
||||
|
||||
mock_cohere_client.chat.return_value = mock_output
|
||||
mock_cohere_class.return_value = mock_cohere_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'c4ai-aya-23-8b',
|
||||
'api_key': 'test-api-key',
|
||||
'temperature': 0.0,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act
|
||||
result = await processor.generate_content("System prompt", "User prompt")
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Generated response from Cohere"
|
||||
assert result.in_token == 25
|
||||
assert result.out_token == 15
|
||||
assert result.model == 'c4ai-aya-23-8b'
|
||||
|
||||
# Verify the Cohere API call
|
||||
mock_cohere_client.chat.assert_called_once_with(
|
||||
model='c4ai-aya-23-8b',
|
||||
message="User prompt",
|
||||
preamble="System prompt",
|
||||
temperature=0.0,
|
||||
chat_history=[],
|
||||
prompt_truncation='auto',
|
||||
connectors=[]
|
||||
)
|
||||
|
||||
@patch('trustgraph.model.text_completion.cohere.llm.cohere.Client')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_rate_limit_error(self, mock_llm_init, mock_async_init, mock_cohere_class):
|
||||
"""Test rate limit error handling"""
|
||||
# Arrange
|
||||
import cohere
|
||||
|
||||
mock_cohere_client = MagicMock()
|
||||
mock_cohere_client.chat.side_effect = cohere.TooManyRequestsError("Rate limit exceeded")
|
||||
mock_cohere_class.return_value = mock_cohere_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'c4ai-aya-23-8b',
|
||||
'api_key': 'test-api-key',
|
||||
'temperature': 0.0,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(TooManyRequests):
|
||||
await processor.generate_content("System prompt", "User prompt")
|
||||
|
||||
@patch('trustgraph.model.text_completion.cohere.llm.cohere.Client')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_generic_exception(self, mock_llm_init, mock_async_init, mock_cohere_class):
|
||||
"""Test handling of generic exceptions"""
|
||||
# Arrange
|
||||
mock_cohere_client = MagicMock()
|
||||
mock_cohere_client.chat.side_effect = Exception("API connection error")
|
||||
mock_cohere_class.return_value = mock_cohere_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'c4ai-aya-23-8b',
|
||||
'api_key': 'test-api-key',
|
||||
'temperature': 0.0,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(Exception, match="API connection error"):
|
||||
await processor.generate_content("System prompt", "User prompt")
|
||||
|
||||
@patch('trustgraph.model.text_completion.cohere.llm.cohere.Client')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_processor_initialization_without_api_key(self, mock_llm_init, mock_async_init, mock_cohere_class):
|
||||
"""Test processor initialization without API key (should fail)"""
|
||||
# Arrange
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'c4ai-aya-23-8b',
|
||||
'api_key': None, # No API key provided
|
||||
'temperature': 0.0,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(RuntimeError, match="Cohere API key not specified"):
|
||||
processor = Processor(**config)
|
||||
|
||||
@patch('trustgraph.model.text_completion.cohere.llm.cohere.Client')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_processor_initialization_with_custom_parameters(self, mock_llm_init, mock_async_init, mock_cohere_class):
|
||||
"""Test processor initialization with custom parameters"""
|
||||
# Arrange
|
||||
mock_cohere_client = MagicMock()
|
||||
mock_cohere_class.return_value = mock_cohere_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'command-light',
|
||||
'api_key': 'custom-api-key',
|
||||
'temperature': 0.7,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
# Act
|
||||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
assert processor.model == 'command-light'
|
||||
assert processor.temperature == 0.7
|
||||
mock_cohere_class.assert_called_once_with(api_key='custom-api-key')
|
||||
|
||||
@patch('trustgraph.model.text_completion.cohere.llm.cohere.Client')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_processor_initialization_with_defaults(self, mock_llm_init, mock_async_init, mock_cohere_class):
|
||||
"""Test processor initialization with default values"""
|
||||
# Arrange
|
||||
mock_cohere_client = MagicMock()
|
||||
mock_cohere_class.return_value = mock_cohere_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
# Only provide required fields, should use defaults
|
||||
config = {
|
||||
'api_key': 'test-api-key',
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
# Act
|
||||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
assert processor.model == 'c4ai-aya-23-8b' # default_model
|
||||
assert processor.temperature == 0.0 # default_temperature
|
||||
mock_cohere_class.assert_called_once_with(api_key='test-api-key')
|
||||
|
||||
@patch('trustgraph.model.text_completion.cohere.llm.cohere.Client')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_empty_prompts(self, mock_llm_init, mock_async_init, mock_cohere_class):
|
||||
"""Test content generation with empty prompts"""
|
||||
# Arrange
|
||||
mock_cohere_client = MagicMock()
|
||||
mock_output = MagicMock()
|
||||
mock_output.text = "Default response"
|
||||
mock_output.meta.billed_units.input_tokens = 2
|
||||
mock_output.meta.billed_units.output_tokens = 3
|
||||
|
||||
mock_cohere_client.chat.return_value = mock_output
|
||||
mock_cohere_class.return_value = mock_cohere_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'c4ai-aya-23-8b',
|
||||
'api_key': 'test-api-key',
|
||||
'temperature': 0.0,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act
|
||||
result = await processor.generate_content("", "")
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Default response"
|
||||
assert result.in_token == 2
|
||||
assert result.out_token == 3
|
||||
assert result.model == 'c4ai-aya-23-8b'
|
||||
|
||||
# Verify the preamble and message are handled correctly
|
||||
call_args = mock_cohere_client.chat.call_args
|
||||
assert call_args[1]['preamble'] == ""
|
||||
assert call_args[1]['message'] == ""
|
||||
|
||||
@patch('trustgraph.model.text_completion.cohere.llm.cohere.Client')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_chat_structure(self, mock_llm_init, mock_async_init, mock_cohere_class):
|
||||
"""Test that Cohere chat is structured correctly"""
|
||||
# Arrange
|
||||
mock_cohere_client = MagicMock()
|
||||
mock_output = MagicMock()
|
||||
mock_output.text = "Response with proper structure"
|
||||
mock_output.meta.billed_units.input_tokens = 30
|
||||
mock_output.meta.billed_units.output_tokens = 20
|
||||
|
||||
mock_cohere_client.chat.return_value = mock_output
|
||||
mock_cohere_class.return_value = mock_cohere_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'c4ai-aya-23-8b',
|
||||
'api_key': 'test-api-key',
|
||||
'temperature': 0.5,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act
|
||||
result = await processor.generate_content("You are a helpful assistant", "What is AI?")
|
||||
|
||||
# Assert
|
||||
assert result.text == "Response with proper structure"
|
||||
assert result.in_token == 30
|
||||
assert result.out_token == 20
|
||||
|
||||
# Verify the chat structure matches Cohere API format
|
||||
call_args = mock_cohere_client.chat.call_args
|
||||
|
||||
# Check parameters
|
||||
assert call_args[1]['model'] == 'c4ai-aya-23-8b'
|
||||
assert call_args[1]['message'] == "What is AI?"
|
||||
assert call_args[1]['preamble'] == "You are a helpful assistant"
|
||||
assert call_args[1]['temperature'] == 0.5
|
||||
assert call_args[1]['chat_history'] == []
|
||||
assert call_args[1]['prompt_truncation'] == 'auto'
|
||||
assert call_args[1]['connectors'] == []
|
||||
|
||||
@patch('trustgraph.model.text_completion.cohere.llm.cohere.Client')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_token_parsing(self, mock_llm_init, mock_async_init, mock_cohere_class):
|
||||
"""Test token parsing from Cohere response"""
|
||||
# Arrange
|
||||
mock_cohere_client = MagicMock()
|
||||
mock_output = MagicMock()
|
||||
mock_output.text = "Token parsing test"
|
||||
mock_output.meta.billed_units.input_tokens = 50
|
||||
mock_output.meta.billed_units.output_tokens = 25
|
||||
|
||||
mock_cohere_client.chat.return_value = mock_output
|
||||
mock_cohere_class.return_value = mock_cohere_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'c4ai-aya-23-8b',
|
||||
'api_key': 'test-api-key',
|
||||
'temperature': 0.0,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act
|
||||
result = await processor.generate_content("System", "User query")
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Token parsing test"
|
||||
assert result.in_token == 50
|
||||
assert result.out_token == 25
|
||||
assert result.model == 'c4ai-aya-23-8b'
|
||||
|
||||
@patch('trustgraph.model.text_completion.cohere.llm.cohere.Client')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_cohere_client_initialization(self, mock_llm_init, mock_async_init, mock_cohere_class):
|
||||
"""Test that Cohere client is initialized correctly"""
|
||||
# Arrange
|
||||
mock_cohere_client = MagicMock()
|
||||
mock_cohere_class.return_value = mock_cohere_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'command-r',
|
||||
'api_key': 'co-test-key',
|
||||
'temperature': 0.0,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
# Act
|
||||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
# Verify Cohere client was called with correct API key
|
||||
mock_cohere_class.assert_called_once_with(api_key='co-test-key')
|
||||
|
||||
# Verify processor has the client
|
||||
assert processor.cohere == mock_cohere_client
|
||||
assert processor.model == 'command-r'
|
||||
|
||||
@patch('trustgraph.model.text_completion.cohere.llm.cohere.Client')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_chat_parameters(self, mock_llm_init, mock_async_init, mock_cohere_class):
|
||||
"""Test that all chat parameters are passed correctly"""
|
||||
# Arrange
|
||||
mock_cohere_client = MagicMock()
|
||||
mock_output = MagicMock()
|
||||
mock_output.text = "Chat parameter test"
|
||||
mock_output.meta.billed_units.input_tokens = 20
|
||||
mock_output.meta.billed_units.output_tokens = 10
|
||||
|
||||
mock_cohere_client.chat.return_value = mock_output
|
||||
mock_cohere_class.return_value = mock_cohere_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'c4ai-aya-23-8b',
|
||||
'api_key': 'test-api-key',
|
||||
'temperature': 0.3,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act
|
||||
result = await processor.generate_content("System instructions", "User question")
|
||||
|
||||
# Assert
|
||||
assert result.text == "Chat parameter test"
|
||||
|
||||
# Verify all parameters are passed correctly
|
||||
call_args = mock_cohere_client.chat.call_args
|
||||
assert call_args[1]['model'] == 'c4ai-aya-23-8b'
|
||||
assert call_args[1]['message'] == "User question"
|
||||
assert call_args[1]['preamble'] == "System instructions"
|
||||
assert call_args[1]['temperature'] == 0.3
|
||||
assert call_args[1]['chat_history'] == []
|
||||
assert call_args[1]['prompt_truncation'] == 'auto'
|
||||
assert call_args[1]['connectors'] == []
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pytest.main([__file__])
|
||||
482
tests/unit/test_text_completion/test_googleaistudio_processor.py
Normal file
482
tests/unit/test_text_completion/test_googleaistudio_processor.py
Normal file
|
|
@ -0,0 +1,482 @@
|
|||
"""
|
||||
Unit tests for trustgraph.model.text_completion.googleaistudio
|
||||
Following the same successful pattern as previous tests
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from unittest import IsolatedAsyncioTestCase
|
||||
|
||||
# Import the service under test
|
||||
from trustgraph.model.text_completion.googleaistudio.llm import Processor
|
||||
from trustgraph.base import LlmResult
|
||||
from trustgraph.exceptions import TooManyRequests
|
||||
|
||||
|
||||
class TestGoogleAIStudioProcessorSimple(IsolatedAsyncioTestCase):
|
||||
"""Test Google AI Studio processor functionality"""
|
||||
|
||||
@patch('trustgraph.model.text_completion.googleaistudio.llm.genai.Client')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_processor_initialization_basic(self, mock_llm_init, mock_async_init, mock_genai_class):
|
||||
"""Test basic processor initialization"""
|
||||
# Arrange
|
||||
mock_genai_client = MagicMock()
|
||||
mock_genai_class.return_value = mock_genai_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'gemini-2.0-flash-001',
|
||||
'api_key': 'test-api-key',
|
||||
'temperature': 0.0,
|
||||
'max_output': 8192,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
# Act
|
||||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
assert processor.model == 'gemini-2.0-flash-001'
|
||||
assert processor.temperature == 0.0
|
||||
assert processor.max_output == 8192
|
||||
assert hasattr(processor, 'client')
|
||||
assert hasattr(processor, 'safety_settings')
|
||||
assert len(processor.safety_settings) == 4 # 4 safety categories
|
||||
mock_genai_class.assert_called_once_with(api_key='test-api-key')
|
||||
|
||||
@patch('trustgraph.model.text_completion.googleaistudio.llm.genai.Client')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_success(self, mock_llm_init, mock_async_init, mock_genai_class):
|
||||
"""Test successful content generation"""
|
||||
# Arrange
|
||||
mock_genai_client = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.text = "Generated response from Google AI Studio"
|
||||
mock_response.usage_metadata.prompt_token_count = 25
|
||||
mock_response.usage_metadata.candidates_token_count = 15
|
||||
|
||||
mock_genai_client.models.generate_content.return_value = mock_response
|
||||
mock_genai_class.return_value = mock_genai_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'gemini-2.0-flash-001',
|
||||
'api_key': 'test-api-key',
|
||||
'temperature': 0.0,
|
||||
'max_output': 8192,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act
|
||||
result = await processor.generate_content("System prompt", "User prompt")
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Generated response from Google AI Studio"
|
||||
assert result.in_token == 25
|
||||
assert result.out_token == 15
|
||||
assert result.model == 'gemini-2.0-flash-001'
|
||||
|
||||
# Verify the Google AI Studio API call
|
||||
mock_genai_client.models.generate_content.assert_called_once()
|
||||
call_args = mock_genai_client.models.generate_content.call_args
|
||||
assert call_args[1]['model'] == 'gemini-2.0-flash-001'
|
||||
assert call_args[1]['contents'] == "User prompt"
|
||||
|
||||
@patch('trustgraph.model.text_completion.googleaistudio.llm.genai.Client')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_rate_limit_error(self, mock_llm_init, mock_async_init, mock_genai_class):
|
||||
"""Test rate limit error handling"""
|
||||
# Arrange
|
||||
from google.api_core.exceptions import ResourceExhausted
|
||||
|
||||
mock_genai_client = MagicMock()
|
||||
mock_genai_client.models.generate_content.side_effect = ResourceExhausted("Rate limit exceeded")
|
||||
mock_genai_class.return_value = mock_genai_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'gemini-2.0-flash-001',
|
||||
'api_key': 'test-api-key',
|
||||
'temperature': 0.0,
|
||||
'max_output': 8192,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(TooManyRequests):
|
||||
await processor.generate_content("System prompt", "User prompt")
|
||||
|
||||
@patch('trustgraph.model.text_completion.googleaistudio.llm.genai.Client')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_generic_exception(self, mock_llm_init, mock_async_init, mock_genai_class):
|
||||
"""Test handling of generic exceptions"""
|
||||
# Arrange
|
||||
mock_genai_client = MagicMock()
|
||||
mock_genai_client.models.generate_content.side_effect = Exception("API connection error")
|
||||
mock_genai_class.return_value = mock_genai_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'gemini-2.0-flash-001',
|
||||
'api_key': 'test-api-key',
|
||||
'temperature': 0.0,
|
||||
'max_output': 8192,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(Exception, match="API connection error"):
|
||||
await processor.generate_content("System prompt", "User prompt")
|
||||
|
||||
@patch('trustgraph.model.text_completion.googleaistudio.llm.genai.Client')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_processor_initialization_without_api_key(self, mock_llm_init, mock_async_init, mock_genai_class):
|
||||
"""Test processor initialization without API key (should fail)"""
|
||||
# Arrange
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'gemini-2.0-flash-001',
|
||||
'api_key': None, # No API key provided
|
||||
'temperature': 0.0,
|
||||
'max_output': 8192,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(RuntimeError, match="Google AI Studio API key not specified"):
|
||||
processor = Processor(**config)
|
||||
|
||||
@patch('trustgraph.model.text_completion.googleaistudio.llm.genai.Client')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_processor_initialization_with_custom_parameters(self, mock_llm_init, mock_async_init, mock_genai_class):
|
||||
"""Test processor initialization with custom parameters"""
|
||||
# Arrange
|
||||
mock_genai_client = MagicMock()
|
||||
mock_genai_class.return_value = mock_genai_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'gemini-1.5-pro',
|
||||
'api_key': 'custom-api-key',
|
||||
'temperature': 0.7,
|
||||
'max_output': 4096,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
# Act
|
||||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
assert processor.model == 'gemini-1.5-pro'
|
||||
assert processor.temperature == 0.7
|
||||
assert processor.max_output == 4096
|
||||
mock_genai_class.assert_called_once_with(api_key='custom-api-key')
|
||||
|
||||
@patch('trustgraph.model.text_completion.googleaistudio.llm.genai.Client')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_processor_initialization_with_defaults(self, mock_llm_init, mock_async_init, mock_genai_class):
|
||||
"""Test processor initialization with default values"""
|
||||
# Arrange
|
||||
mock_genai_client = MagicMock()
|
||||
mock_genai_class.return_value = mock_genai_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
# Only provide required fields, should use defaults
|
||||
config = {
|
||||
'api_key': 'test-api-key',
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
# Act
|
||||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
assert processor.model == 'gemini-2.0-flash-001' # default_model
|
||||
assert processor.temperature == 0.0 # default_temperature
|
||||
assert processor.max_output == 8192 # default_max_output
|
||||
mock_genai_class.assert_called_once_with(api_key='test-api-key')
|
||||
|
||||
@patch('trustgraph.model.text_completion.googleaistudio.llm.genai.Client')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_empty_prompts(self, mock_llm_init, mock_async_init, mock_genai_class):
|
||||
"""Test content generation with empty prompts"""
|
||||
# Arrange
|
||||
mock_genai_client = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.text = "Default response"
|
||||
mock_response.usage_metadata.prompt_token_count = 2
|
||||
mock_response.usage_metadata.candidates_token_count = 3
|
||||
|
||||
mock_genai_client.models.generate_content.return_value = mock_response
|
||||
mock_genai_class.return_value = mock_genai_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'gemini-2.0-flash-001',
|
||||
'api_key': 'test-api-key',
|
||||
'temperature': 0.0,
|
||||
'max_output': 8192,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act
|
||||
result = await processor.generate_content("", "")
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Default response"
|
||||
assert result.in_token == 2
|
||||
assert result.out_token == 3
|
||||
assert result.model == 'gemini-2.0-flash-001'
|
||||
|
||||
# Verify the system instruction and content are handled correctly
|
||||
call_args = mock_genai_client.models.generate_content.call_args
|
||||
assert call_args[1]['contents'] == ""
|
||||
|
||||
@patch('trustgraph.model.text_completion.googleaistudio.llm.genai.Client')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_configuration_structure(self, mock_llm_init, mock_async_init, mock_genai_class):
|
||||
"""Test that generation configuration is structured correctly"""
|
||||
# Arrange
|
||||
mock_genai_client = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.text = "Response with proper structure"
|
||||
mock_response.usage_metadata.prompt_token_count = 30
|
||||
mock_response.usage_metadata.candidates_token_count = 20
|
||||
|
||||
mock_genai_client.models.generate_content.return_value = mock_response
|
||||
mock_genai_class.return_value = mock_genai_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'gemini-2.0-flash-001',
|
||||
'api_key': 'test-api-key',
|
||||
'temperature': 0.5,
|
||||
'max_output': 1024,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act
|
||||
result = await processor.generate_content("You are a helpful assistant", "What is AI?")
|
||||
|
||||
# Assert
|
||||
assert result.text == "Response with proper structure"
|
||||
assert result.in_token == 30
|
||||
assert result.out_token == 20
|
||||
|
||||
# Verify the generation configuration
|
||||
call_args = mock_genai_client.models.generate_content.call_args
|
||||
config_arg = call_args[1]['config']
|
||||
|
||||
# Check that the configuration has the right structure
|
||||
assert call_args[1]['model'] == 'gemini-2.0-flash-001'
|
||||
assert call_args[1]['contents'] == "What is AI?"
|
||||
# Config should be a GenerateContentConfig object
|
||||
|
||||
@patch('trustgraph.model.text_completion.googleaistudio.llm.genai.Client')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_safety_settings_initialization(self, mock_llm_init, mock_async_init, mock_genai_class):
|
||||
"""Test that safety settings are initialized correctly"""
|
||||
# Arrange
|
||||
mock_genai_client = MagicMock()
|
||||
mock_genai_class.return_value = mock_genai_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'gemini-2.0-flash-001',
|
||||
'api_key': 'test-api-key',
|
||||
'temperature': 0.0,
|
||||
'max_output': 8192,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
# Act
|
||||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
assert hasattr(processor, 'safety_settings')
|
||||
assert len(processor.safety_settings) == 4
|
||||
# Should have 4 safety categories: hate speech, harassment, sexually explicit, dangerous content
|
||||
|
||||
@patch('trustgraph.model.text_completion.googleaistudio.llm.genai.Client')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_token_parsing(self, mock_llm_init, mock_async_init, mock_genai_class):
|
||||
"""Test token parsing from Google AI Studio response"""
|
||||
# Arrange
|
||||
mock_genai_client = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.text = "Token parsing test"
|
||||
mock_response.usage_metadata.prompt_token_count = 50
|
||||
mock_response.usage_metadata.candidates_token_count = 25
|
||||
|
||||
mock_genai_client.models.generate_content.return_value = mock_response
|
||||
mock_genai_class.return_value = mock_genai_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'gemini-2.0-flash-001',
|
||||
'api_key': 'test-api-key',
|
||||
'temperature': 0.0,
|
||||
'max_output': 8192,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act
|
||||
result = await processor.generate_content("System", "User query")
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Token parsing test"
|
||||
assert result.in_token == 50
|
||||
assert result.out_token == 25
|
||||
assert result.model == 'gemini-2.0-flash-001'
|
||||
|
||||
@patch('trustgraph.model.text_completion.googleaistudio.llm.genai.Client')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_genai_client_initialization(self, mock_llm_init, mock_async_init, mock_genai_class):
|
||||
"""Test that Google AI Studio client is initialized correctly"""
|
||||
# Arrange
|
||||
mock_genai_client = MagicMock()
|
||||
mock_genai_class.return_value = mock_genai_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'gemini-1.5-flash',
|
||||
'api_key': 'gai-test-key',
|
||||
'temperature': 0.0,
|
||||
'max_output': 8192,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
# Act
|
||||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
# Verify Google AI Studio client was called with correct API key
|
||||
mock_genai_class.assert_called_once_with(api_key='gai-test-key')
|
||||
|
||||
# Verify processor has the client
|
||||
assert processor.client == mock_genai_client
|
||||
assert processor.model == 'gemini-1.5-flash'
|
||||
|
||||
@patch('trustgraph.model.text_completion.googleaistudio.llm.genai.Client')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_system_instruction(self, mock_llm_init, mock_async_init, mock_genai_class):
|
||||
"""Test that system instruction is handled correctly"""
|
||||
# Arrange
|
||||
mock_genai_client = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.text = "System instruction test"
|
||||
mock_response.usage_metadata.prompt_token_count = 35
|
||||
mock_response.usage_metadata.candidates_token_count = 25
|
||||
|
||||
mock_genai_client.models.generate_content.return_value = mock_response
|
||||
mock_genai_class.return_value = mock_genai_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'gemini-2.0-flash-001',
|
||||
'api_key': 'test-api-key',
|
||||
'temperature': 0.0,
|
||||
'max_output': 8192,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act
|
||||
result = await processor.generate_content("Be helpful and concise", "Explain quantum computing")
|
||||
|
||||
# Assert
|
||||
assert result.text == "System instruction test"
|
||||
assert result.in_token == 35
|
||||
assert result.out_token == 25
|
||||
|
||||
# Verify the system instruction is passed in the config
|
||||
call_args = mock_genai_client.models.generate_content.call_args
|
||||
config_arg = call_args[1]['config']
|
||||
# The system instruction should be in the config object
|
||||
assert call_args[1]['contents'] == "Explain quantum computing"
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pytest.main([__file__])
|
||||
454
tests/unit/test_text_completion/test_llamafile_processor.py
Normal file
454
tests/unit/test_text_completion/test_llamafile_processor.py
Normal file
|
|
@ -0,0 +1,454 @@
|
|||
"""
|
||||
Unit tests for trustgraph.model.text_completion.llamafile
|
||||
Following the same successful pattern as previous tests
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from unittest import IsolatedAsyncioTestCase
|
||||
|
||||
# Import the service under test
|
||||
from trustgraph.model.text_completion.llamafile.llm import Processor
|
||||
from trustgraph.base import LlmResult
|
||||
from trustgraph.exceptions import TooManyRequests
|
||||
|
||||
|
||||
class TestLlamaFileProcessorSimple(IsolatedAsyncioTestCase):
|
||||
"""Test LlamaFile processor functionality"""
|
||||
|
||||
@patch('trustgraph.model.text_completion.llamafile.llm.OpenAI')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_processor_initialization_basic(self, mock_llm_init, mock_async_init, mock_openai_class):
|
||||
"""Test basic processor initialization"""
|
||||
# Arrange
|
||||
mock_openai_client = MagicMock()
|
||||
mock_openai_class.return_value = mock_openai_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'LLaMA_CPP',
|
||||
'llamafile': 'http://localhost:8080/v1',
|
||||
'temperature': 0.0,
|
||||
'max_output': 4096,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
# Act
|
||||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
assert processor.model == 'LLaMA_CPP'
|
||||
assert processor.llamafile == 'http://localhost:8080/v1'
|
||||
assert processor.temperature == 0.0
|
||||
assert processor.max_output == 4096
|
||||
assert hasattr(processor, 'openai')
|
||||
mock_openai_class.assert_called_once_with(
|
||||
base_url='http://localhost:8080/v1',
|
||||
api_key='sk-no-key-required'
|
||||
)
|
||||
|
||||
@patch('trustgraph.model.text_completion.llamafile.llm.OpenAI')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_success(self, mock_llm_init, mock_async_init, mock_openai_class):
|
||||
"""Test successful content generation"""
|
||||
# Arrange
|
||||
mock_openai_client = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices = [MagicMock()]
|
||||
mock_response.choices[0].message.content = "Generated response from LlamaFile"
|
||||
mock_response.usage.prompt_tokens = 20
|
||||
mock_response.usage.completion_tokens = 12
|
||||
|
||||
mock_openai_client.chat.completions.create.return_value = mock_response
|
||||
mock_openai_class.return_value = mock_openai_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'LLaMA_CPP',
|
||||
'llamafile': 'http://localhost:8080/v1',
|
||||
'temperature': 0.0,
|
||||
'max_output': 4096,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act
|
||||
result = await processor.generate_content("System prompt", "User prompt")
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Generated response from LlamaFile"
|
||||
assert result.in_token == 20
|
||||
assert result.out_token == 12
|
||||
assert result.model == 'llama.cpp' # Note: model in result is hardcoded to 'llama.cpp'
|
||||
|
||||
# Verify the OpenAI API call structure
|
||||
mock_openai_client.chat.completions.create.assert_called_once_with(
|
||||
model='LLaMA_CPP',
|
||||
messages=[{
|
||||
"role": "user",
|
||||
"content": "System prompt\n\nUser prompt"
|
||||
}]
|
||||
)
|
||||
|
||||
@patch('trustgraph.model.text_completion.llamafile.llm.OpenAI')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_generic_exception(self, mock_llm_init, mock_async_init, mock_openai_class):
|
||||
"""Test handling of generic exceptions"""
|
||||
# Arrange
|
||||
mock_openai_client = MagicMock()
|
||||
mock_openai_client.chat.completions.create.side_effect = Exception("Connection error")
|
||||
mock_openai_class.return_value = mock_openai_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'LLaMA_CPP',
|
||||
'llamafile': 'http://localhost:8080/v1',
|
||||
'temperature': 0.0,
|
||||
'max_output': 4096,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(Exception, match="Connection error"):
|
||||
await processor.generate_content("System prompt", "User prompt")
|
||||
|
||||
@patch('trustgraph.model.text_completion.llamafile.llm.OpenAI')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_processor_initialization_with_custom_parameters(self, mock_llm_init, mock_async_init, mock_openai_class):
|
||||
"""Test processor initialization with custom parameters"""
|
||||
# Arrange
|
||||
mock_openai_client = MagicMock()
|
||||
mock_openai_class.return_value = mock_openai_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'custom-llama',
|
||||
'llamafile': 'http://custom-host:8080/v1',
|
||||
'temperature': 0.7,
|
||||
'max_output': 2048,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
# Act
|
||||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
assert processor.model == 'custom-llama'
|
||||
assert processor.llamafile == 'http://custom-host:8080/v1'
|
||||
assert processor.temperature == 0.7
|
||||
assert processor.max_output == 2048
|
||||
mock_openai_class.assert_called_once_with(
|
||||
base_url='http://custom-host:8080/v1',
|
||||
api_key='sk-no-key-required'
|
||||
)
|
||||
|
||||
@patch('trustgraph.model.text_completion.llamafile.llm.OpenAI')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_processor_initialization_with_defaults(self, mock_llm_init, mock_async_init, mock_openai_class):
|
||||
"""Test processor initialization with default values"""
|
||||
# Arrange
|
||||
mock_openai_client = MagicMock()
|
||||
mock_openai_class.return_value = mock_openai_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
# Only provide required fields, should use defaults
|
||||
config = {
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
# Act
|
||||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
assert processor.model == 'LLaMA_CPP' # default_model
|
||||
assert processor.llamafile == 'http://localhost:8080/v1' # default_llamafile
|
||||
assert processor.temperature == 0.0 # default_temperature
|
||||
assert processor.max_output == 4096 # default_max_output
|
||||
mock_openai_class.assert_called_once_with(
|
||||
base_url='http://localhost:8080/v1',
|
||||
api_key='sk-no-key-required'
|
||||
)
|
||||
|
||||
@patch('trustgraph.model.text_completion.llamafile.llm.OpenAI')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_empty_prompts(self, mock_llm_init, mock_async_init, mock_openai_class):
|
||||
"""Test content generation with empty prompts"""
|
||||
# Arrange
|
||||
mock_openai_client = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices = [MagicMock()]
|
||||
mock_response.choices[0].message.content = "Default response"
|
||||
mock_response.usage.prompt_tokens = 2
|
||||
mock_response.usage.completion_tokens = 3
|
||||
|
||||
mock_openai_client.chat.completions.create.return_value = mock_response
|
||||
mock_openai_class.return_value = mock_openai_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'LLaMA_CPP',
|
||||
'llamafile': 'http://localhost:8080/v1',
|
||||
'temperature': 0.0,
|
||||
'max_output': 4096,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act
|
||||
result = await processor.generate_content("", "")
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Default response"
|
||||
assert result.in_token == 2
|
||||
assert result.out_token == 3
|
||||
assert result.model == 'llama.cpp'
|
||||
|
||||
# Verify the combined prompt is sent correctly
|
||||
call_args = mock_openai_client.chat.completions.create.call_args
|
||||
expected_prompt = "\n\n" # Empty system + "\n\n" + empty user
|
||||
assert call_args[1]['messages'][0]['content'] == expected_prompt
|
||||
|
||||
@patch('trustgraph.model.text_completion.llamafile.llm.OpenAI')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_message_structure(self, mock_llm_init, mock_async_init, mock_openai_class):
|
||||
"""Test that LlamaFile messages are structured correctly"""
|
||||
# Arrange
|
||||
mock_openai_client = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices = [MagicMock()]
|
||||
mock_response.choices[0].message.content = "Response with proper structure"
|
||||
mock_response.usage.prompt_tokens = 25
|
||||
mock_response.usage.completion_tokens = 15
|
||||
|
||||
mock_openai_client.chat.completions.create.return_value = mock_response
|
||||
mock_openai_class.return_value = mock_openai_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'LLaMA_CPP',
|
||||
'llamafile': 'http://localhost:8080/v1',
|
||||
'temperature': 0.0,
|
||||
'max_output': 4096,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act
|
||||
result = await processor.generate_content("You are a helpful assistant", "What is AI?")
|
||||
|
||||
# Assert
|
||||
assert result.text == "Response with proper structure"
|
||||
assert result.in_token == 25
|
||||
assert result.out_token == 15
|
||||
|
||||
# Verify the message structure
|
||||
call_args = mock_openai_client.chat.completions.create.call_args
|
||||
messages = call_args[1]['messages']
|
||||
|
||||
assert len(messages) == 1
|
||||
assert messages[0]['role'] == 'user'
|
||||
assert messages[0]['content'] == "You are a helpful assistant\n\nWhat is AI?"
|
||||
|
||||
# Verify model parameter
|
||||
assert call_args[1]['model'] == 'LLaMA_CPP'
|
||||
|
||||
@patch('trustgraph.model.text_completion.llamafile.llm.OpenAI')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_openai_client_initialization(self, mock_llm_init, mock_async_init, mock_openai_class):
|
||||
"""Test that OpenAI client is initialized correctly for LlamaFile"""
|
||||
# Arrange
|
||||
mock_openai_client = MagicMock()
|
||||
mock_openai_class.return_value = mock_openai_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'llama-custom',
|
||||
'llamafile': 'http://llamafile-server:8080/v1',
|
||||
'temperature': 0.0,
|
||||
'max_output': 4096,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
# Act
|
||||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
# Verify OpenAI client was called with correct parameters
|
||||
mock_openai_class.assert_called_once_with(
|
||||
base_url='http://llamafile-server:8080/v1',
|
||||
api_key='sk-no-key-required'
|
||||
)
|
||||
|
||||
# Verify processor has the client
|
||||
assert processor.openai == mock_openai_client
|
||||
|
||||
@patch('trustgraph.model.text_completion.llamafile.llm.OpenAI')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_prompt_construction(self, mock_llm_init, mock_async_init, mock_openai_class):
|
||||
"""Test prompt construction with system and user prompts"""
|
||||
# Arrange
|
||||
mock_openai_client = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices = [MagicMock()]
|
||||
mock_response.choices[0].message.content = "Response with system instructions"
|
||||
mock_response.usage.prompt_tokens = 30
|
||||
mock_response.usage.completion_tokens = 20
|
||||
|
||||
mock_openai_client.chat.completions.create.return_value = mock_response
|
||||
mock_openai_class.return_value = mock_openai_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'LLaMA_CPP',
|
||||
'llamafile': 'http://localhost:8080/v1',
|
||||
'temperature': 0.0,
|
||||
'max_output': 4096,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act
|
||||
result = await processor.generate_content("You are a helpful assistant", "What is machine learning?")
|
||||
|
||||
# Assert
|
||||
assert result.text == "Response with system instructions"
|
||||
assert result.in_token == 30
|
||||
assert result.out_token == 20
|
||||
|
||||
# Verify the combined prompt
|
||||
call_args = mock_openai_client.chat.completions.create.call_args
|
||||
expected_prompt = "You are a helpful assistant\n\nWhat is machine learning?"
|
||||
assert call_args[1]['messages'][0]['content'] == expected_prompt
|
||||
|
||||
@patch('trustgraph.model.text_completion.llamafile.llm.OpenAI')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_hardcoded_model_response(self, mock_llm_init, mock_async_init, mock_openai_class):
|
||||
"""Test that response model is hardcoded to 'llama.cpp'"""
|
||||
# Arrange
|
||||
mock_openai_client = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices = [MagicMock()]
|
||||
mock_response.choices[0].message.content = "Test response"
|
||||
mock_response.usage.prompt_tokens = 15
|
||||
mock_response.usage.completion_tokens = 10
|
||||
|
||||
mock_openai_client.chat.completions.create.return_value = mock_response
|
||||
mock_openai_class.return_value = mock_openai_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'custom-model-name', # This should be ignored in response
|
||||
'llamafile': 'http://localhost:8080/v1',
|
||||
'temperature': 0.0,
|
||||
'max_output': 4096,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act
|
||||
result = await processor.generate_content("System", "User")
|
||||
|
||||
# Assert
|
||||
assert result.model == 'llama.cpp' # Should always be 'llama.cpp', not 'custom-model-name'
|
||||
assert processor.model == 'custom-model-name' # But processor.model should still be custom
|
||||
|
||||
@patch('trustgraph.model.text_completion.llamafile.llm.OpenAI')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_no_rate_limiting(self, mock_llm_init, mock_async_init, mock_openai_class):
|
||||
"""Test that no rate limiting is implemented (SLM assumption)"""
|
||||
# Arrange
|
||||
mock_openai_client = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices = [MagicMock()]
|
||||
mock_response.choices[0].message.content = "No rate limiting test"
|
||||
mock_response.usage.prompt_tokens = 10
|
||||
mock_response.usage.completion_tokens = 5
|
||||
|
||||
mock_openai_client.chat.completions.create.return_value = mock_response
|
||||
mock_openai_class.return_value = mock_openai_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'LLaMA_CPP',
|
||||
'llamafile': 'http://localhost:8080/v1',
|
||||
'temperature': 0.0,
|
||||
'max_output': 4096,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act
|
||||
result = await processor.generate_content("System", "User")
|
||||
|
||||
# Assert
|
||||
assert result.text == "No rate limiting test"
|
||||
# No specific rate limit error handling tested since SLM presumably has no rate limits
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pytest.main([__file__])
|
||||
317
tests/unit/test_text_completion/test_ollama_processor.py
Normal file
317
tests/unit/test_text_completion/test_ollama_processor.py
Normal file
|
|
@ -0,0 +1,317 @@
|
|||
"""
|
||||
Unit tests for trustgraph.model.text_completion.ollama
|
||||
Following the same successful pattern as VertexAI tests
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from unittest import IsolatedAsyncioTestCase
|
||||
|
||||
# Import the service under test
|
||||
from trustgraph.model.text_completion.ollama.llm import Processor
|
||||
from trustgraph.base import LlmResult
|
||||
|
||||
|
||||
class TestOllamaProcessorSimple(IsolatedAsyncioTestCase):
|
||||
"""Test Ollama processor functionality"""
|
||||
|
||||
@patch('trustgraph.model.text_completion.ollama.llm.Client')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_processor_initialization_basic(self, mock_llm_init, mock_async_init, mock_client_class):
|
||||
"""Test basic processor initialization"""
|
||||
# Arrange
|
||||
mock_client = MagicMock()
|
||||
mock_client_class.return_value = mock_client
|
||||
|
||||
# Mock the parent class initialization
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'llama2',
|
||||
'ollama': 'http://localhost:11434',
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
# Act
|
||||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
assert processor.model == 'llama2'
|
||||
assert hasattr(processor, 'llm')
|
||||
mock_client_class.assert_called_once_with(host='http://localhost:11434')
|
||||
|
||||
@patch('trustgraph.model.text_completion.ollama.llm.Client')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_success(self, mock_llm_init, mock_async_init, mock_client_class):
|
||||
"""Test successful content generation"""
|
||||
# Arrange
|
||||
mock_client = MagicMock()
|
||||
mock_response = {
|
||||
'response': 'Generated response from Ollama',
|
||||
'prompt_eval_count': 15,
|
||||
'eval_count': 8
|
||||
}
|
||||
mock_client.generate.return_value = mock_response
|
||||
mock_client_class.return_value = mock_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'llama2',
|
||||
'ollama': 'http://localhost:11434',
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act
|
||||
result = await processor.generate_content("System prompt", "User prompt")
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Generated response from Ollama"
|
||||
assert result.in_token == 15
|
||||
assert result.out_token == 8
|
||||
assert result.model == 'llama2'
|
||||
mock_client.generate.assert_called_once_with('llama2', "System prompt\n\nUser prompt")
|
||||
|
||||
@patch('trustgraph.model.text_completion.ollama.llm.Client')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_generic_exception(self, mock_llm_init, mock_async_init, mock_client_class):
|
||||
"""Test handling of generic exceptions"""
|
||||
# Arrange
|
||||
mock_client = MagicMock()
|
||||
mock_client.generate.side_effect = Exception("Connection error")
|
||||
mock_client_class.return_value = mock_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'llama2',
|
||||
'ollama': 'http://localhost:11434',
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(Exception, match="Connection error"):
|
||||
await processor.generate_content("System prompt", "User prompt")
|
||||
|
||||
@patch('trustgraph.model.text_completion.ollama.llm.Client')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_processor_initialization_with_custom_parameters(self, mock_llm_init, mock_async_init, mock_client_class):
|
||||
"""Test processor initialization with custom parameters"""
|
||||
# Arrange
|
||||
mock_client = MagicMock()
|
||||
mock_client_class.return_value = mock_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'mistral',
|
||||
'ollama': 'http://192.168.1.100:11434',
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
# Act
|
||||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
assert processor.model == 'mistral'
|
||||
mock_client_class.assert_called_once_with(host='http://192.168.1.100:11434')
|
||||
|
||||
@patch('trustgraph.model.text_completion.ollama.llm.Client')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_processor_initialization_with_defaults(self, mock_llm_init, mock_async_init, mock_client_class):
|
||||
"""Test processor initialization with default values"""
|
||||
# Arrange
|
||||
mock_client = MagicMock()
|
||||
mock_client_class.return_value = mock_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
# Don't provide model or ollama - should use defaults
|
||||
config = {
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
# Act
|
||||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
assert processor.model == 'gemma2:9b' # default_model
|
||||
# Should use default_ollama (http://localhost:11434 or from OLLAMA_HOST env)
|
||||
mock_client_class.assert_called_once()
|
||||
|
||||
@patch('trustgraph.model.text_completion.ollama.llm.Client')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_empty_prompts(self, mock_llm_init, mock_async_init, mock_client_class):
|
||||
"""Test content generation with empty prompts"""
|
||||
# Arrange
|
||||
mock_client = MagicMock()
|
||||
mock_response = {
|
||||
'response': 'Default response',
|
||||
'prompt_eval_count': 2,
|
||||
'eval_count': 3
|
||||
}
|
||||
mock_client.generate.return_value = mock_response
|
||||
mock_client_class.return_value = mock_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'llama2',
|
||||
'ollama': 'http://localhost:11434',
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act
|
||||
result = await processor.generate_content("", "")
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Default response"
|
||||
assert result.in_token == 2
|
||||
assert result.out_token == 3
|
||||
assert result.model == 'llama2'
|
||||
|
||||
# The prompt should be "" + "\n\n" + "" = "\n\n"
|
||||
mock_client.generate.assert_called_once_with('llama2', "\n\n")
|
||||
|
||||
@patch('trustgraph.model.text_completion.ollama.llm.Client')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_token_counting(self, mock_llm_init, mock_async_init, mock_client_class):
|
||||
"""Test token counting from Ollama response"""
|
||||
# Arrange
|
||||
mock_client = MagicMock()
|
||||
mock_response = {
|
||||
'response': 'Test response',
|
||||
'prompt_eval_count': 50,
|
||||
'eval_count': 25
|
||||
}
|
||||
mock_client.generate.return_value = mock_response
|
||||
mock_client_class.return_value = mock_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'llama2',
|
||||
'ollama': 'http://localhost:11434',
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act
|
||||
result = await processor.generate_content("System", "User prompt")
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Test response"
|
||||
assert result.in_token == 50
|
||||
assert result.out_token == 25
|
||||
assert result.model == 'llama2'
|
||||
|
||||
@patch('trustgraph.model.text_completion.ollama.llm.Client')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_ollama_client_initialization(self, mock_llm_init, mock_async_init, mock_client_class):
|
||||
"""Test that Ollama client is initialized correctly"""
|
||||
# Arrange
|
||||
mock_client = MagicMock()
|
||||
mock_client_class.return_value = mock_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'codellama',
|
||||
'ollama': 'http://ollama-server:11434',
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
# Act
|
||||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
# Verify Client was called with correct host
|
||||
mock_client_class.assert_called_once_with(host='http://ollama-server:11434')
|
||||
|
||||
# Verify processor has the client
|
||||
assert processor.llm == mock_client
|
||||
|
||||
@patch('trustgraph.model.text_completion.ollama.llm.Client')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_prompt_construction(self, mock_llm_init, mock_async_init, mock_client_class):
|
||||
"""Test prompt construction with system and user prompts"""
|
||||
# Arrange
|
||||
mock_client = MagicMock()
|
||||
mock_response = {
|
||||
'response': 'Response with system instructions',
|
||||
'prompt_eval_count': 25,
|
||||
'eval_count': 15
|
||||
}
|
||||
mock_client.generate.return_value = mock_response
|
||||
mock_client_class.return_value = mock_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'llama2',
|
||||
'ollama': 'http://localhost:11434',
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act
|
||||
result = await processor.generate_content("You are a helpful assistant", "What is AI?")
|
||||
|
||||
# Assert
|
||||
assert result.text == "Response with system instructions"
|
||||
assert result.in_token == 25
|
||||
assert result.out_token == 15
|
||||
|
||||
# Verify the combined prompt
|
||||
mock_client.generate.assert_called_once_with('llama2', "You are a helpful assistant\n\nWhat is AI?")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pytest.main([__file__])
|
||||
395
tests/unit/test_text_completion/test_openai_processor.py
Normal file
395
tests/unit/test_text_completion/test_openai_processor.py
Normal file
|
|
@ -0,0 +1,395 @@
|
|||
"""
|
||||
Unit tests for trustgraph.model.text_completion.openai
|
||||
Following the same successful pattern as VertexAI and Ollama tests
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from unittest import IsolatedAsyncioTestCase
|
||||
|
||||
# Import the service under test
|
||||
from trustgraph.model.text_completion.openai.llm import Processor
|
||||
from trustgraph.base import LlmResult
|
||||
from trustgraph.exceptions import TooManyRequests
|
||||
|
||||
|
||||
class TestOpenAIProcessorSimple(IsolatedAsyncioTestCase):
|
||||
"""Test OpenAI processor functionality"""
|
||||
|
||||
@patch('trustgraph.model.text_completion.openai.llm.OpenAI')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_processor_initialization_basic(self, mock_llm_init, mock_async_init, mock_openai_class):
|
||||
"""Test basic processor initialization"""
|
||||
# Arrange
|
||||
mock_openai_client = MagicMock()
|
||||
mock_openai_class.return_value = mock_openai_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'gpt-3.5-turbo',
|
||||
'api_key': 'test-api-key',
|
||||
'url': 'https://api.openai.com/v1',
|
||||
'temperature': 0.0,
|
||||
'max_output': 4096,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
# Act
|
||||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
assert processor.model == 'gpt-3.5-turbo'
|
||||
assert processor.temperature == 0.0
|
||||
assert processor.max_output == 4096
|
||||
assert hasattr(processor, 'openai')
|
||||
mock_openai_class.assert_called_once_with(base_url='https://api.openai.com/v1', api_key='test-api-key')
|
||||
|
||||
@patch('trustgraph.model.text_completion.openai.llm.OpenAI')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_success(self, mock_llm_init, mock_async_init, mock_openai_class):
|
||||
"""Test successful content generation"""
|
||||
# Arrange
|
||||
mock_openai_client = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices = [MagicMock()]
|
||||
mock_response.choices[0].message.content = "Generated response from OpenAI"
|
||||
mock_response.usage.prompt_tokens = 20
|
||||
mock_response.usage.completion_tokens = 12
|
||||
|
||||
mock_openai_client.chat.completions.create.return_value = mock_response
|
||||
mock_openai_class.return_value = mock_openai_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'gpt-3.5-turbo',
|
||||
'api_key': 'test-api-key',
|
||||
'url': 'https://api.openai.com/v1',
|
||||
'temperature': 0.0,
|
||||
'max_output': 4096,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act
|
||||
result = await processor.generate_content("System prompt", "User prompt")
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Generated response from OpenAI"
|
||||
assert result.in_token == 20
|
||||
assert result.out_token == 12
|
||||
assert result.model == 'gpt-3.5-turbo'
|
||||
|
||||
# Verify the OpenAI API call
|
||||
mock_openai_client.chat.completions.create.assert_called_once_with(
|
||||
model='gpt-3.5-turbo',
|
||||
messages=[{
|
||||
"role": "user",
|
||||
"content": [{
|
||||
"type": "text",
|
||||
"text": "System prompt\n\nUser prompt"
|
||||
}]
|
||||
}],
|
||||
temperature=0.0,
|
||||
max_tokens=4096,
|
||||
top_p=1,
|
||||
frequency_penalty=0,
|
||||
presence_penalty=0,
|
||||
response_format={"type": "text"}
|
||||
)
|
||||
|
||||
@patch('trustgraph.model.text_completion.openai.llm.OpenAI')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_rate_limit_error(self, mock_llm_init, mock_async_init, mock_openai_class):
|
||||
"""Test rate limit error handling"""
|
||||
# Arrange
|
||||
from openai import RateLimitError
|
||||
|
||||
mock_openai_client = MagicMock()
|
||||
mock_openai_client.chat.completions.create.side_effect = RateLimitError("Rate limit exceeded", response=MagicMock(), body=None)
|
||||
mock_openai_class.return_value = mock_openai_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'gpt-3.5-turbo',
|
||||
'api_key': 'test-api-key',
|
||||
'url': 'https://api.openai.com/v1',
|
||||
'temperature': 0.0,
|
||||
'max_output': 4096,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(TooManyRequests):
|
||||
await processor.generate_content("System prompt", "User prompt")
|
||||
|
||||
@patch('trustgraph.model.text_completion.openai.llm.OpenAI')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_generic_exception(self, mock_llm_init, mock_async_init, mock_openai_class):
|
||||
"""Test handling of generic exceptions"""
|
||||
# Arrange
|
||||
mock_openai_client = MagicMock()
|
||||
mock_openai_client.chat.completions.create.side_effect = Exception("API connection error")
|
||||
mock_openai_class.return_value = mock_openai_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'gpt-3.5-turbo',
|
||||
'api_key': 'test-api-key',
|
||||
'url': 'https://api.openai.com/v1',
|
||||
'temperature': 0.0,
|
||||
'max_output': 4096,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(Exception, match="API connection error"):
|
||||
await processor.generate_content("System prompt", "User prompt")
|
||||
|
||||
@patch('trustgraph.model.text_completion.openai.llm.OpenAI')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_processor_initialization_without_api_key(self, mock_llm_init, mock_async_init, mock_openai_class):
|
||||
"""Test processor initialization without API key (should fail)"""
|
||||
# Arrange
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'gpt-3.5-turbo',
|
||||
'api_key': None, # No API key provided
|
||||
'url': 'https://api.openai.com/v1',
|
||||
'temperature': 0.0,
|
||||
'max_output': 4096,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(RuntimeError, match="OpenAI API key not specified"):
|
||||
processor = Processor(**config)
|
||||
|
||||
@patch('trustgraph.model.text_completion.openai.llm.OpenAI')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_processor_initialization_with_custom_parameters(self, mock_llm_init, mock_async_init, mock_openai_class):
|
||||
"""Test processor initialization with custom parameters"""
|
||||
# Arrange
|
||||
mock_openai_client = MagicMock()
|
||||
mock_openai_class.return_value = mock_openai_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'gpt-4',
|
||||
'api_key': 'custom-api-key',
|
||||
'url': 'https://custom-openai-url.com/v1',
|
||||
'temperature': 0.7,
|
||||
'max_output': 2048,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
# Act
|
||||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
assert processor.model == 'gpt-4'
|
||||
assert processor.temperature == 0.7
|
||||
assert processor.max_output == 2048
|
||||
mock_openai_class.assert_called_once_with(base_url='https://custom-openai-url.com/v1', api_key='custom-api-key')
|
||||
|
||||
@patch('trustgraph.model.text_completion.openai.llm.OpenAI')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_processor_initialization_with_defaults(self, mock_llm_init, mock_async_init, mock_openai_class):
|
||||
"""Test processor initialization with default values"""
|
||||
# Arrange
|
||||
mock_openai_client = MagicMock()
|
||||
mock_openai_class.return_value = mock_openai_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
# Only provide required fields, should use defaults
|
||||
config = {
|
||||
'api_key': 'test-api-key',
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
# Act
|
||||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
assert processor.model == 'gpt-3.5-turbo' # default_model
|
||||
assert processor.temperature == 0.0 # default_temperature
|
||||
assert processor.max_output == 4096 # default_max_output
|
||||
mock_openai_class.assert_called_once_with(base_url='https://api.openai.com/v1', api_key='test-api-key')
|
||||
|
||||
@patch('trustgraph.model.text_completion.openai.llm.OpenAI')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_empty_prompts(self, mock_llm_init, mock_async_init, mock_openai_class):
|
||||
"""Test content generation with empty prompts"""
|
||||
# Arrange
|
||||
mock_openai_client = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices = [MagicMock()]
|
||||
mock_response.choices[0].message.content = "Default response"
|
||||
mock_response.usage.prompt_tokens = 2
|
||||
mock_response.usage.completion_tokens = 3
|
||||
|
||||
mock_openai_client.chat.completions.create.return_value = mock_response
|
||||
mock_openai_class.return_value = mock_openai_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'gpt-3.5-turbo',
|
||||
'api_key': 'test-api-key',
|
||||
'url': 'https://api.openai.com/v1',
|
||||
'temperature': 0.0,
|
||||
'max_output': 4096,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act
|
||||
result = await processor.generate_content("", "")
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Default response"
|
||||
assert result.in_token == 2
|
||||
assert result.out_token == 3
|
||||
assert result.model == 'gpt-3.5-turbo'
|
||||
|
||||
# Verify the combined prompt is sent correctly
|
||||
call_args = mock_openai_client.chat.completions.create.call_args
|
||||
expected_prompt = "\n\n" # Empty system + "\n\n" + empty user
|
||||
assert call_args[1]['messages'][0]['content'][0]['text'] == expected_prompt
|
||||
|
||||
@patch('trustgraph.model.text_completion.openai.llm.OpenAI')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_openai_client_initialization_without_base_url(self, mock_llm_init, mock_async_init, mock_openai_class):
|
||||
"""Test OpenAI client initialization without base_url"""
|
||||
# Arrange
|
||||
mock_openai_client = MagicMock()
|
||||
mock_openai_class.return_value = mock_openai_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'gpt-3.5-turbo',
|
||||
'api_key': 'test-api-key',
|
||||
'url': None, # No base URL
|
||||
'temperature': 0.0,
|
||||
'max_output': 4096,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
# Act
|
||||
processor = Processor(**config)
|
||||
|
||||
# Assert - should be called without base_url when it's None
|
||||
mock_openai_class.assert_called_once_with(api_key='test-api-key')
|
||||
|
||||
@patch('trustgraph.model.text_completion.openai.llm.OpenAI')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_message_structure(self, mock_llm_init, mock_async_init, mock_openai_class):
|
||||
"""Test that OpenAI messages are structured correctly"""
|
||||
# Arrange
|
||||
mock_openai_client = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices = [MagicMock()]
|
||||
mock_response.choices[0].message.content = "Response with proper structure"
|
||||
mock_response.usage.prompt_tokens = 25
|
||||
mock_response.usage.completion_tokens = 15
|
||||
|
||||
mock_openai_client.chat.completions.create.return_value = mock_response
|
||||
mock_openai_class.return_value = mock_openai_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'gpt-3.5-turbo',
|
||||
'api_key': 'test-api-key',
|
||||
'url': 'https://api.openai.com/v1',
|
||||
'temperature': 0.5,
|
||||
'max_output': 1024,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act
|
||||
result = await processor.generate_content("You are a helpful assistant", "What is AI?")
|
||||
|
||||
# Assert
|
||||
assert result.text == "Response with proper structure"
|
||||
assert result.in_token == 25
|
||||
assert result.out_token == 15
|
||||
|
||||
# Verify the message structure matches OpenAI Chat API format
|
||||
call_args = mock_openai_client.chat.completions.create.call_args
|
||||
messages = call_args[1]['messages']
|
||||
|
||||
assert len(messages) == 1
|
||||
assert messages[0]['role'] == 'user'
|
||||
assert messages[0]['content'][0]['type'] == 'text'
|
||||
assert messages[0]['content'][0]['text'] == "You are a helpful assistant\n\nWhat is AI?"
|
||||
|
||||
# Verify other parameters
|
||||
assert call_args[1]['model'] == 'gpt-3.5-turbo'
|
||||
assert call_args[1]['temperature'] == 0.5
|
||||
assert call_args[1]['max_tokens'] == 1024
|
||||
assert call_args[1]['top_p'] == 1
|
||||
assert call_args[1]['frequency_penalty'] == 0
|
||||
assert call_args[1]['presence_penalty'] == 0
|
||||
assert call_args[1]['response_format'] == {"type": "text"}
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pytest.main([__file__])
|
||||
397
tests/unit/test_text_completion/test_vertexai_processor.py
Normal file
397
tests/unit/test_text_completion/test_vertexai_processor.py
Normal file
|
|
@ -0,0 +1,397 @@
|
|||
"""
|
||||
Unit tests for trustgraph.model.text_completion.vertexai
|
||||
Starting simple with one test to get the basics working
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from unittest import IsolatedAsyncioTestCase
|
||||
|
||||
# Import the service under test
|
||||
from trustgraph.model.text_completion.vertexai.llm import Processor
|
||||
from trustgraph.base import LlmResult
|
||||
|
||||
|
||||
class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase):
|
||||
"""Simple test for processor initialization"""
|
||||
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.service_account')
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.vertexai')
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.GenerativeModel')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_processor_initialization_basic(self, mock_llm_init, mock_async_init, mock_generative_model, mock_vertexai, mock_service_account):
|
||||
"""Test basic processor initialization with mocked dependencies"""
|
||||
# Arrange
|
||||
mock_credentials = MagicMock()
|
||||
mock_service_account.Credentials.from_service_account_file.return_value = mock_credentials
|
||||
mock_model = MagicMock()
|
||||
mock_generative_model.return_value = mock_model
|
||||
|
||||
# Mock the parent class initialization to avoid taskgroup requirement
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'region': 'us-central1',
|
||||
'model': 'gemini-2.0-flash-001',
|
||||
'temperature': 0.0,
|
||||
'max_output': 8192,
|
||||
'private_key': 'private.json',
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(), # Required by AsyncProcessor
|
||||
'id': 'test-processor' # Required by AsyncProcessor
|
||||
}
|
||||
|
||||
# Act
|
||||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
assert processor.model == 'gemini-2.0-flash-001' # It's stored as 'model', not 'model_name'
|
||||
assert hasattr(processor, 'generation_config')
|
||||
assert hasattr(processor, 'safety_settings')
|
||||
assert hasattr(processor, 'llm')
|
||||
mock_service_account.Credentials.from_service_account_file.assert_called_once_with('private.json')
|
||||
mock_vertexai.init.assert_called_once()
|
||||
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.service_account')
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.vertexai')
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.GenerativeModel')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_success(self, mock_llm_init, mock_async_init, mock_generative_model, mock_vertexai, mock_service_account):
|
||||
"""Test successful content generation"""
|
||||
# Arrange
|
||||
mock_credentials = MagicMock()
|
||||
mock_service_account.Credentials.from_service_account_file.return_value = mock_credentials
|
||||
|
||||
mock_model = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.text = "Generated response from Gemini"
|
||||
mock_response.usage_metadata.prompt_token_count = 15
|
||||
mock_response.usage_metadata.candidates_token_count = 8
|
||||
mock_model.generate_content.return_value = mock_response
|
||||
mock_generative_model.return_value = mock_model
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'region': 'us-central1',
|
||||
'model': 'gemini-2.0-flash-001',
|
||||
'temperature': 0.0,
|
||||
'max_output': 8192,
|
||||
'private_key': 'private.json',
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act
|
||||
result = await processor.generate_content("System prompt", "User prompt")
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Generated response from Gemini"
|
||||
assert result.in_token == 15
|
||||
assert result.out_token == 8
|
||||
assert result.model == 'gemini-2.0-flash-001'
|
||||
# Check that the method was called (actual prompt format may vary)
|
||||
mock_model.generate_content.assert_called_once()
|
||||
# Verify the call was made with the expected parameters
|
||||
call_args = mock_model.generate_content.call_args
|
||||
assert call_args[1]['generation_config'] == processor.generation_config
|
||||
assert call_args[1]['safety_settings'] == processor.safety_settings
|
||||
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.service_account')
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.vertexai')
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.GenerativeModel')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_rate_limit_error(self, mock_llm_init, mock_async_init, mock_generative_model, mock_vertexai, mock_service_account):
|
||||
"""Test rate limit error handling"""
|
||||
# Arrange
|
||||
from google.api_core.exceptions import ResourceExhausted
|
||||
from trustgraph.exceptions import TooManyRequests
|
||||
|
||||
mock_credentials = MagicMock()
|
||||
mock_service_account.Credentials.from_service_account_file.return_value = mock_credentials
|
||||
|
||||
mock_model = MagicMock()
|
||||
mock_model.generate_content.side_effect = ResourceExhausted("Rate limit exceeded")
|
||||
mock_generative_model.return_value = mock_model
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'region': 'us-central1',
|
||||
'model': 'gemini-2.0-flash-001',
|
||||
'temperature': 0.0,
|
||||
'max_output': 8192,
|
||||
'private_key': 'private.json',
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(TooManyRequests):
|
||||
await processor.generate_content("System prompt", "User prompt")
|
||||
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.service_account')
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.vertexai')
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.GenerativeModel')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_blocked_response(self, mock_llm_init, mock_async_init, mock_generative_model, mock_vertexai, mock_service_account):
|
||||
"""Test handling of blocked content (safety filters)"""
|
||||
# Arrange
|
||||
mock_credentials = MagicMock()
|
||||
mock_service_account.Credentials.from_service_account_file.return_value = mock_credentials
|
||||
|
||||
mock_model = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.text = None # Blocked content returns None
|
||||
mock_response.usage_metadata.prompt_token_count = 10
|
||||
mock_response.usage_metadata.candidates_token_count = 0
|
||||
mock_model.generate_content.return_value = mock_response
|
||||
mock_generative_model.return_value = mock_model
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'region': 'us-central1',
|
||||
'model': 'gemini-2.0-flash-001',
|
||||
'temperature': 0.0,
|
||||
'max_output': 8192,
|
||||
'private_key': 'private.json',
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act
|
||||
result = await processor.generate_content("System prompt", "Blocked content")
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text is None # Should preserve None for blocked content
|
||||
assert result.in_token == 10
|
||||
assert result.out_token == 0
|
||||
assert result.model == 'gemini-2.0-flash-001'
|
||||
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.service_account')
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.vertexai')
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.GenerativeModel')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_processor_initialization_without_private_key(self, mock_llm_init, mock_async_init, mock_generative_model, mock_vertexai, mock_service_account):
|
||||
"""Test processor initialization without private key (should fail)"""
|
||||
# Arrange
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'region': 'us-central1',
|
||||
'model': 'gemini-2.0-flash-001',
|
||||
'temperature': 0.0,
|
||||
'max_output': 8192,
|
||||
'private_key': None, # No private key provided
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(RuntimeError, match="Private key file not specified"):
|
||||
processor = Processor(**config)
|
||||
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.service_account')
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.vertexai')
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.GenerativeModel')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_generic_exception(self, mock_llm_init, mock_async_init, mock_generative_model, mock_vertexai, mock_service_account):
|
||||
"""Test handling of generic exceptions"""
|
||||
# Arrange
|
||||
mock_credentials = MagicMock()
|
||||
mock_service_account.Credentials.from_service_account_file.return_value = mock_credentials
|
||||
|
||||
mock_model = MagicMock()
|
||||
mock_model.generate_content.side_effect = Exception("Network error")
|
||||
mock_generative_model.return_value = mock_model
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'region': 'us-central1',
|
||||
'model': 'gemini-2.0-flash-001',
|
||||
'temperature': 0.0,
|
||||
'max_output': 8192,
|
||||
'private_key': 'private.json',
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(Exception, match="Network error"):
|
||||
await processor.generate_content("System prompt", "User prompt")
|
||||
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.service_account')
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.vertexai')
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.GenerativeModel')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_processor_initialization_with_custom_parameters(self, mock_llm_init, mock_async_init, mock_generative_model, mock_vertexai, mock_service_account):
|
||||
"""Test processor initialization with custom parameters"""
|
||||
# Arrange
|
||||
mock_credentials = MagicMock()
|
||||
mock_service_account.Credentials.from_service_account_file.return_value = mock_credentials
|
||||
mock_model = MagicMock()
|
||||
mock_generative_model.return_value = mock_model
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'region': 'us-west1',
|
||||
'model': 'gemini-1.5-pro',
|
||||
'temperature': 0.7,
|
||||
'max_output': 4096,
|
||||
'private_key': 'custom-key.json',
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
# Act
|
||||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
assert processor.model == 'gemini-1.5-pro'
|
||||
|
||||
# Verify that generation_config object exists (can't easily check internal values)
|
||||
assert hasattr(processor, 'generation_config')
|
||||
assert processor.generation_config is not None
|
||||
|
||||
# Verify that safety settings are configured
|
||||
assert len(processor.safety_settings) == 4
|
||||
|
||||
# Verify service account was called with custom key
|
||||
mock_service_account.Credentials.from_service_account_file.assert_called_once_with('custom-key.json')
|
||||
|
||||
# Verify that parameters dict has the correct values (this is accessible)
|
||||
assert processor.parameters["temperature"] == 0.7
|
||||
assert processor.parameters["max_output_tokens"] == 4096
|
||||
assert processor.parameters["top_p"] == 1.0
|
||||
assert processor.parameters["top_k"] == 32
|
||||
assert processor.parameters["candidate_count"] == 1
|
||||
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.service_account')
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.vertexai')
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.GenerativeModel')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_vertexai_initialization_with_credentials(self, mock_llm_init, mock_async_init, mock_generative_model, mock_vertexai, mock_service_account):
|
||||
"""Test that VertexAI is initialized correctly with credentials"""
|
||||
# Arrange
|
||||
mock_credentials = MagicMock()
|
||||
mock_credentials.project_id = "test-project-123"
|
||||
mock_service_account.Credentials.from_service_account_file.return_value = mock_credentials
|
||||
mock_model = MagicMock()
|
||||
mock_generative_model.return_value = mock_model
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'region': 'europe-west1',
|
||||
'model': 'gemini-2.0-flash-001',
|
||||
'temperature': 0.0,
|
||||
'max_output': 8192,
|
||||
'private_key': 'service-account.json',
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
# Act
|
||||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
# Verify VertexAI init was called with correct parameters
|
||||
mock_vertexai.init.assert_called_once_with(
|
||||
location='europe-west1',
|
||||
credentials=mock_credentials,
|
||||
project='test-project-123'
|
||||
)
|
||||
|
||||
# Verify GenerativeModel was created with the right model name
|
||||
mock_generative_model.assert_called_once_with('gemini-2.0-flash-001')
|
||||
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.service_account')
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.vertexai')
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.GenerativeModel')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_empty_prompts(self, mock_llm_init, mock_async_init, mock_generative_model, mock_vertexai, mock_service_account):
|
||||
"""Test content generation with empty prompts"""
|
||||
# Arrange
|
||||
mock_credentials = MagicMock()
|
||||
mock_service_account.Credentials.from_service_account_file.return_value = mock_credentials
|
||||
|
||||
mock_model = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.text = "Default response"
|
||||
mock_response.usage_metadata.prompt_token_count = 2
|
||||
mock_response.usage_metadata.candidates_token_count = 3
|
||||
mock_model.generate_content.return_value = mock_response
|
||||
mock_generative_model.return_value = mock_model
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'region': 'us-central1',
|
||||
'model': 'gemini-2.0-flash-001',
|
||||
'temperature': 0.0,
|
||||
'max_output': 8192,
|
||||
'private_key': 'private.json',
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act
|
||||
result = await processor.generate_content("", "")
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Default response"
|
||||
assert result.in_token == 2
|
||||
assert result.out_token == 3
|
||||
assert result.model == 'gemini-2.0-flash-001'
|
||||
|
||||
# Verify the model was called with the combined empty prompts
|
||||
mock_model.generate_content.assert_called_once()
|
||||
call_args = mock_model.generate_content.call_args
|
||||
# The prompt should be "" + "\n\n" + "" = "\n\n"
|
||||
assert call_args[0][0] == "\n\n"
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pytest.main([__file__])
|
||||
489
tests/unit/test_text_completion/test_vllm_processor.py
Normal file
489
tests/unit/test_text_completion/test_vllm_processor.py
Normal file
|
|
@ -0,0 +1,489 @@
|
|||
"""
|
||||
Unit tests for trustgraph.model.text_completion.vllm
|
||||
Following the same successful pattern as previous tests
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from unittest import IsolatedAsyncioTestCase
|
||||
|
||||
# Import the service under test
|
||||
from trustgraph.model.text_completion.vllm.llm import Processor
|
||||
from trustgraph.base import LlmResult
|
||||
from trustgraph.exceptions import TooManyRequests
|
||||
|
||||
|
||||
class TestVLLMProcessorSimple(IsolatedAsyncioTestCase):
|
||||
"""Test vLLM processor functionality"""
|
||||
|
||||
@patch('trustgraph.model.text_completion.vllm.llm.aiohttp.ClientSession')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_processor_initialization_basic(self, mock_llm_init, mock_async_init, mock_session_class):
|
||||
"""Test basic processor initialization"""
|
||||
# Arrange
|
||||
mock_session = MagicMock()
|
||||
mock_session_class.return_value = mock_session
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'TheBloke/Mistral-7B-v0.1-AWQ',
|
||||
'url': 'http://vllm-service:8899/v1',
|
||||
'temperature': 0.0,
|
||||
'max_output': 2048,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
# Act
|
||||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
assert processor.model == 'TheBloke/Mistral-7B-v0.1-AWQ'
|
||||
assert processor.base_url == 'http://vllm-service:8899/v1'
|
||||
assert processor.temperature == 0.0
|
||||
assert processor.max_output == 2048
|
||||
assert hasattr(processor, 'session')
|
||||
mock_session_class.assert_called_once()
|
||||
|
||||
@patch('trustgraph.model.text_completion.vllm.llm.aiohttp.ClientSession')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_success(self, mock_llm_init, mock_async_init, mock_session_class):
|
||||
"""Test successful content generation"""
|
||||
# Arrange
|
||||
mock_session = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.status = 200
|
||||
mock_response.json = AsyncMock(return_value={
|
||||
'choices': [{
|
||||
'text': 'Generated response from vLLM'
|
||||
}],
|
||||
'usage': {
|
||||
'prompt_tokens': 20,
|
||||
'completion_tokens': 12
|
||||
}
|
||||
})
|
||||
|
||||
# Mock the async context manager
|
||||
mock_session.post.return_value.__aenter__.return_value = mock_response
|
||||
mock_session.post.return_value.__aexit__.return_value = None
|
||||
mock_session_class.return_value = mock_session
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'TheBloke/Mistral-7B-v0.1-AWQ',
|
||||
'url': 'http://vllm-service:8899/v1',
|
||||
'temperature': 0.0,
|
||||
'max_output': 2048,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act
|
||||
result = await processor.generate_content("System prompt", "User prompt")
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Generated response from vLLM"
|
||||
assert result.in_token == 20
|
||||
assert result.out_token == 12
|
||||
assert result.model == 'TheBloke/Mistral-7B-v0.1-AWQ'
|
||||
|
||||
# Verify the vLLM API call
|
||||
mock_session.post.assert_called_once_with(
|
||||
'http://vllm-service:8899/v1/completions',
|
||||
headers={'Content-Type': 'application/json'},
|
||||
json={
|
||||
'model': 'TheBloke/Mistral-7B-v0.1-AWQ',
|
||||
'prompt': 'System prompt\n\nUser prompt',
|
||||
'max_tokens': 2048,
|
||||
'temperature': 0.0
|
||||
}
|
||||
)
|
||||
|
||||
@patch('trustgraph.model.text_completion.vllm.llm.aiohttp.ClientSession')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_http_error(self, mock_llm_init, mock_async_init, mock_session_class):
|
||||
"""Test HTTP error handling"""
|
||||
# Arrange
|
||||
mock_session = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.status = 500
|
||||
|
||||
# Mock the async context manager
|
||||
mock_session.post.return_value.__aenter__.return_value = mock_response
|
||||
mock_session.post.return_value.__aexit__.return_value = None
|
||||
mock_session_class.return_value = mock_session
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'TheBloke/Mistral-7B-v0.1-AWQ',
|
||||
'url': 'http://vllm-service:8899/v1',
|
||||
'temperature': 0.0,
|
||||
'max_output': 2048,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(RuntimeError, match="Bad status: 500"):
|
||||
await processor.generate_content("System prompt", "User prompt")
|
||||
|
||||
@patch('trustgraph.model.text_completion.vllm.llm.aiohttp.ClientSession')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_generic_exception(self, mock_llm_init, mock_async_init, mock_session_class):
|
||||
"""Test handling of generic exceptions"""
|
||||
# Arrange
|
||||
mock_session = MagicMock()
|
||||
mock_session.post.side_effect = Exception("Connection error")
|
||||
mock_session_class.return_value = mock_session
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'TheBloke/Mistral-7B-v0.1-AWQ',
|
||||
'url': 'http://vllm-service:8899/v1',
|
||||
'temperature': 0.0,
|
||||
'max_output': 2048,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(Exception, match="Connection error"):
|
||||
await processor.generate_content("System prompt", "User prompt")
|
||||
|
||||
@patch('trustgraph.model.text_completion.vllm.llm.aiohttp.ClientSession')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_processor_initialization_with_custom_parameters(self, mock_llm_init, mock_async_init, mock_session_class):
|
||||
"""Test processor initialization with custom parameters"""
|
||||
# Arrange
|
||||
mock_session = MagicMock()
|
||||
mock_session_class.return_value = mock_session
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'custom-model',
|
||||
'url': 'http://custom-vllm:8080/v1',
|
||||
'temperature': 0.7,
|
||||
'max_output': 1024,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
# Act
|
||||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
assert processor.model == 'custom-model'
|
||||
assert processor.base_url == 'http://custom-vllm:8080/v1'
|
||||
assert processor.temperature == 0.7
|
||||
assert processor.max_output == 1024
|
||||
mock_session_class.assert_called_once()
|
||||
|
||||
@patch('trustgraph.model.text_completion.vllm.llm.aiohttp.ClientSession')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_processor_initialization_with_defaults(self, mock_llm_init, mock_async_init, mock_session_class):
|
||||
"""Test processor initialization with default values"""
|
||||
# Arrange
|
||||
mock_session = MagicMock()
|
||||
mock_session_class.return_value = mock_session
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
# Only provide required fields, should use defaults
|
||||
config = {
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
# Act
|
||||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
assert processor.model == 'TheBloke/Mistral-7B-v0.1-AWQ' # default_model
|
||||
assert processor.base_url == 'http://vllm-service:8899/v1' # default_base_url
|
||||
assert processor.temperature == 0.0 # default_temperature
|
||||
assert processor.max_output == 2048 # default_max_output
|
||||
mock_session_class.assert_called_once()
|
||||
|
||||
@patch('trustgraph.model.text_completion.vllm.llm.aiohttp.ClientSession')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_empty_prompts(self, mock_llm_init, mock_async_init, mock_session_class):
|
||||
"""Test content generation with empty prompts"""
|
||||
# Arrange
|
||||
mock_session = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.status = 200
|
||||
mock_response.json = AsyncMock(return_value={
|
||||
'choices': [{
|
||||
'text': 'Default response'
|
||||
}],
|
||||
'usage': {
|
||||
'prompt_tokens': 2,
|
||||
'completion_tokens': 3
|
||||
}
|
||||
})
|
||||
|
||||
# Mock the async context manager
|
||||
mock_session.post.return_value.__aenter__.return_value = mock_response
|
||||
mock_session.post.return_value.__aexit__.return_value = None
|
||||
mock_session_class.return_value = mock_session
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'TheBloke/Mistral-7B-v0.1-AWQ',
|
||||
'url': 'http://vllm-service:8899/v1',
|
||||
'temperature': 0.0,
|
||||
'max_output': 2048,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act
|
||||
result = await processor.generate_content("", "")
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Default response"
|
||||
assert result.in_token == 2
|
||||
assert result.out_token == 3
|
||||
assert result.model == 'TheBloke/Mistral-7B-v0.1-AWQ'
|
||||
|
||||
# Verify the combined prompt is sent correctly
|
||||
call_args = mock_session.post.call_args
|
||||
expected_prompt = "\n\n" # Empty system + "\n\n" + empty user
|
||||
assert call_args[1]['json']['prompt'] == expected_prompt
|
||||
|
||||
@patch('trustgraph.model.text_completion.vllm.llm.aiohttp.ClientSession')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_request_structure(self, mock_llm_init, mock_async_init, mock_session_class):
|
||||
"""Test that vLLM request is structured correctly"""
|
||||
# Arrange
|
||||
mock_session = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.status = 200
|
||||
mock_response.json = AsyncMock(return_value={
|
||||
'choices': [{
|
||||
'text': 'Response with proper structure'
|
||||
}],
|
||||
'usage': {
|
||||
'prompt_tokens': 25,
|
||||
'completion_tokens': 15
|
||||
}
|
||||
})
|
||||
|
||||
# Mock the async context manager
|
||||
mock_session.post.return_value.__aenter__.return_value = mock_response
|
||||
mock_session.post.return_value.__aexit__.return_value = None
|
||||
mock_session_class.return_value = mock_session
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'TheBloke/Mistral-7B-v0.1-AWQ',
|
||||
'url': 'http://vllm-service:8899/v1',
|
||||
'temperature': 0.5,
|
||||
'max_output': 1024,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act
|
||||
result = await processor.generate_content("You are a helpful assistant", "What is AI?")
|
||||
|
||||
# Assert
|
||||
assert result.text == "Response with proper structure"
|
||||
assert result.in_token == 25
|
||||
assert result.out_token == 15
|
||||
|
||||
# Verify the request structure
|
||||
call_args = mock_session.post.call_args
|
||||
|
||||
# Check URL
|
||||
assert call_args[0][0] == 'http://vllm-service:8899/v1/completions'
|
||||
|
||||
# Check headers
|
||||
assert call_args[1]['headers']['Content-Type'] == 'application/json'
|
||||
|
||||
# Check request body
|
||||
request_data = call_args[1]['json']
|
||||
assert request_data['model'] == 'TheBloke/Mistral-7B-v0.1-AWQ'
|
||||
assert request_data['prompt'] == "You are a helpful assistant\n\nWhat is AI?"
|
||||
assert request_data['temperature'] == 0.5
|
||||
assert request_data['max_tokens'] == 1024
|
||||
|
||||
@patch('trustgraph.model.text_completion.vllm.llm.aiohttp.ClientSession')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_vllm_session_initialization(self, mock_llm_init, mock_async_init, mock_session_class):
|
||||
"""Test that aiohttp session is initialized correctly"""
|
||||
# Arrange
|
||||
mock_session = MagicMock()
|
||||
mock_session_class.return_value = mock_session
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'test-model',
|
||||
'url': 'http://test-vllm:8899/v1',
|
||||
'temperature': 0.0,
|
||||
'max_output': 2048,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
# Act
|
||||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
# Verify ClientSession was created
|
||||
mock_session_class.assert_called_once()
|
||||
|
||||
# Verify processor has the session
|
||||
assert processor.session == mock_session
|
||||
|
||||
@patch('trustgraph.model.text_completion.vllm.llm.aiohttp.ClientSession')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_response_parsing(self, mock_llm_init, mock_async_init, mock_session_class):
|
||||
"""Test response parsing from vLLM API"""
|
||||
# Arrange
|
||||
mock_session = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.status = 200
|
||||
mock_response.json = AsyncMock(return_value={
|
||||
'choices': [{
|
||||
'text': 'Parsed response text'
|
||||
}],
|
||||
'usage': {
|
||||
'prompt_tokens': 35,
|
||||
'completion_tokens': 25
|
||||
}
|
||||
})
|
||||
|
||||
# Mock the async context manager
|
||||
mock_session.post.return_value.__aenter__.return_value = mock_response
|
||||
mock_session.post.return_value.__aexit__.return_value = None
|
||||
mock_session_class.return_value = mock_session
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'TheBloke/Mistral-7B-v0.1-AWQ',
|
||||
'url': 'http://vllm-service:8899/v1',
|
||||
'temperature': 0.0,
|
||||
'max_output': 2048,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act
|
||||
result = await processor.generate_content("System", "User query")
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Parsed response text"
|
||||
assert result.in_token == 35
|
||||
assert result.out_token == 25
|
||||
assert result.model == 'TheBloke/Mistral-7B-v0.1-AWQ'
|
||||
|
||||
@patch('trustgraph.model.text_completion.vllm.llm.aiohttp.ClientSession')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_prompt_construction(self, mock_llm_init, mock_async_init, mock_session_class):
|
||||
"""Test prompt construction with system and user prompts"""
|
||||
# Arrange
|
||||
mock_session = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.status = 200
|
||||
mock_response.json = AsyncMock(return_value={
|
||||
'choices': [{
|
||||
'text': 'Response with system instructions'
|
||||
}],
|
||||
'usage': {
|
||||
'prompt_tokens': 40,
|
||||
'completion_tokens': 30
|
||||
}
|
||||
})
|
||||
|
||||
# Mock the async context manager
|
||||
mock_session.post.return_value.__aenter__.return_value = mock_response
|
||||
mock_session.post.return_value.__aexit__.return_value = None
|
||||
mock_session_class.return_value = mock_session
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'TheBloke/Mistral-7B-v0.1-AWQ',
|
||||
'url': 'http://vllm-service:8899/v1',
|
||||
'temperature': 0.0,
|
||||
'max_output': 2048,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act
|
||||
result = await processor.generate_content("You are a helpful assistant", "What is machine learning?")
|
||||
|
||||
# Assert
|
||||
assert result.text == "Response with system instructions"
|
||||
assert result.in_token == 40
|
||||
assert result.out_token == 30
|
||||
|
||||
# Verify the combined prompt
|
||||
call_args = mock_session.post.call_args
|
||||
expected_prompt = "You are a helpful assistant\n\nWhat is machine learning?"
|
||||
assert call_args[1]['json']['prompt'] == expected_prompt
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pytest.main([__file__])
|
||||
Loading…
Add table
Add a link
Reference in a new issue