mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-04-27 01:16:22 +02:00
release/v1.4 -> master (#548)
This commit is contained in:
parent
3ec2cd54f9
commit
2bd68ed7f4
94 changed files with 8571 additions and 1740 deletions
238
tests/unit/test_base/test_flow_parameter_specs.py
Normal file
238
tests/unit/test_base/test_flow_parameter_specs.py
Normal file
|
|
@ -0,0 +1,238 @@
|
|||
"""
|
||||
Unit tests for Flow Parameter Specification functionality
|
||||
Testing parameter specification registration and handling in flow processors
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from unittest import IsolatedAsyncioTestCase
|
||||
|
||||
from trustgraph.base.flow_processor import FlowProcessor
|
||||
from trustgraph.base import ParameterSpec, ConsumerSpec, ProducerSpec
|
||||
|
||||
|
||||
class MockAsyncProcessor:
|
||||
def __init__(self, **params):
|
||||
self.config_handlers = []
|
||||
self.id = params.get('id', 'test-service')
|
||||
self.specifications = []
|
||||
|
||||
|
||||
class TestFlowParameterSpecs(IsolatedAsyncioTestCase):
|
||||
"""Test flow processor parameter specification functionality"""
|
||||
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
|
||||
def test_parameter_spec_registration(self):
|
||||
"""Test that parameter specs can be registered with flow processors"""
|
||||
# Arrange
|
||||
config = {
|
||||
'id': 'test-flow-processor',
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock()
|
||||
}
|
||||
|
||||
processor = FlowProcessor(**config)
|
||||
|
||||
# Create test parameter specs
|
||||
model_spec = ParameterSpec(name="model")
|
||||
temperature_spec = ParameterSpec(name="temperature")
|
||||
|
||||
# Act
|
||||
processor.register_specification(model_spec)
|
||||
processor.register_specification(temperature_spec)
|
||||
|
||||
# Assert
|
||||
assert len(processor.specifications) >= 2
|
||||
|
||||
param_specs = [spec for spec in processor.specifications
|
||||
if isinstance(spec, ParameterSpec)]
|
||||
assert len(param_specs) >= 2
|
||||
|
||||
param_names = [spec.name for spec in param_specs]
|
||||
assert "model" in param_names
|
||||
assert "temperature" in param_names
|
||||
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
|
||||
def test_mixed_specification_types(self):
|
||||
"""Test registration of mixed specification types (parameters, consumers, producers)"""
|
||||
# Arrange
|
||||
config = {
|
||||
'id': 'test-flow-processor',
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock()
|
||||
}
|
||||
|
||||
processor = FlowProcessor(**config)
|
||||
|
||||
# Create different spec types
|
||||
param_spec = ParameterSpec(name="model")
|
||||
consumer_spec = ConsumerSpec(name="input", schema=MagicMock(), handler=MagicMock())
|
||||
producer_spec = ProducerSpec(name="output", schema=MagicMock())
|
||||
|
||||
# Act
|
||||
processor.register_specification(param_spec)
|
||||
processor.register_specification(consumer_spec)
|
||||
processor.register_specification(producer_spec)
|
||||
|
||||
# Assert
|
||||
assert len(processor.specifications) == 3
|
||||
|
||||
# Count each type
|
||||
param_specs = [s for s in processor.specifications if isinstance(s, ParameterSpec)]
|
||||
consumer_specs = [s for s in processor.specifications if isinstance(s, ConsumerSpec)]
|
||||
producer_specs = [s for s in processor.specifications if isinstance(s, ProducerSpec)]
|
||||
|
||||
assert len(param_specs) == 1
|
||||
assert len(consumer_specs) == 1
|
||||
assert len(producer_specs) == 1
|
||||
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
|
||||
def test_parameter_spec_metadata(self):
|
||||
"""Test parameter specification metadata handling"""
|
||||
# Arrange
|
||||
config = {
|
||||
'id': 'test-flow-processor',
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock()
|
||||
}
|
||||
|
||||
processor = FlowProcessor(**config)
|
||||
|
||||
# Create parameter specs with metadata (if supported)
|
||||
model_spec = ParameterSpec(name="model")
|
||||
temperature_spec = ParameterSpec(name="temperature")
|
||||
|
||||
# Act
|
||||
processor.register_specification(model_spec)
|
||||
processor.register_specification(temperature_spec)
|
||||
|
||||
# Assert
|
||||
param_specs = [spec for spec in processor.specifications
|
||||
if isinstance(spec, ParameterSpec)]
|
||||
|
||||
model_spec_registered = next((s for s in param_specs if s.name == "model"), None)
|
||||
temperature_spec_registered = next((s for s in param_specs if s.name == "temperature"), None)
|
||||
|
||||
assert model_spec_registered is not None
|
||||
assert temperature_spec_registered is not None
|
||||
assert model_spec_registered.name == "model"
|
||||
assert temperature_spec_registered.name == "temperature"
|
||||
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
|
||||
def test_duplicate_parameter_spec_handling(self):
|
||||
"""Test handling of duplicate parameter spec registration"""
|
||||
# Arrange
|
||||
config = {
|
||||
'id': 'test-flow-processor',
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock()
|
||||
}
|
||||
|
||||
processor = FlowProcessor(**config)
|
||||
|
||||
# Create duplicate parameter specs
|
||||
model_spec1 = ParameterSpec(name="model")
|
||||
model_spec2 = ParameterSpec(name="model")
|
||||
|
||||
# Act
|
||||
processor.register_specification(model_spec1)
|
||||
processor.register_specification(model_spec2)
|
||||
|
||||
# Assert - Should allow duplicates (or handle appropriately)
|
||||
param_specs = [spec for spec in processor.specifications
|
||||
if isinstance(spec, ParameterSpec) and spec.name == "model"]
|
||||
|
||||
# Either should have 2 duplicates or the system should handle deduplication
|
||||
assert len(param_specs) >= 1 # At least one should be registered
|
||||
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
|
||||
@patch('trustgraph.base.flow_processor.Flow')
|
||||
async def test_parameter_specs_available_to_flows(self, mock_flow_class):
|
||||
"""Test that parameter specs are available when flows are created"""
|
||||
# Arrange
|
||||
config = {
|
||||
'id': 'test-flow-processor',
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock()
|
||||
}
|
||||
|
||||
processor = FlowProcessor(**config)
|
||||
processor.id = 'test-processor'
|
||||
|
||||
# Register parameter specs
|
||||
model_spec = ParameterSpec(name="model")
|
||||
temperature_spec = ParameterSpec(name="temperature")
|
||||
processor.register_specification(model_spec)
|
||||
processor.register_specification(temperature_spec)
|
||||
|
||||
mock_flow = AsyncMock()
|
||||
mock_flow_class.return_value = mock_flow
|
||||
|
||||
flow_name = 'test-flow'
|
||||
flow_defn = {'config': 'test-config'}
|
||||
|
||||
# Act
|
||||
await processor.start_flow(flow_name, flow_defn)
|
||||
|
||||
# Assert - Flow should be created with access to processor specifications
|
||||
mock_flow_class.assert_called_once_with('test-processor', flow_name, processor, flow_defn)
|
||||
|
||||
# The flow should have access to the processor's specifications
|
||||
# (The exact mechanism depends on Flow implementation)
|
||||
assert len(processor.specifications) >= 2
|
||||
|
||||
|
||||
class TestParameterSpecValidation(IsolatedAsyncioTestCase):
|
||||
"""Test parameter specification validation functionality"""
|
||||
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
|
||||
def test_parameter_spec_name_validation(self):
|
||||
"""Test parameter spec name validation"""
|
||||
# Arrange
|
||||
config = {
|
||||
'id': 'test-flow-processor',
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock()
|
||||
}
|
||||
|
||||
processor = FlowProcessor(**config)
|
||||
|
||||
# Act & Assert - Valid parameter names
|
||||
valid_specs = [
|
||||
ParameterSpec(name="model"),
|
||||
ParameterSpec(name="temperature"),
|
||||
ParameterSpec(name="max_tokens"),
|
||||
ParameterSpec(name="api_key")
|
||||
]
|
||||
|
||||
for spec in valid_specs:
|
||||
# Should not raise any exceptions
|
||||
processor.register_specification(spec)
|
||||
|
||||
assert len([s for s in processor.specifications if isinstance(s, ParameterSpec)]) >= 4
|
||||
|
||||
def test_parameter_spec_creation_validation(self):
|
||||
"""Test parameter spec creation with various inputs"""
|
||||
# Test valid parameter spec creation
|
||||
valid_specs = [
|
||||
ParameterSpec(name="model"),
|
||||
ParameterSpec(name="temperature"),
|
||||
ParameterSpec(name="max_output"),
|
||||
]
|
||||
|
||||
for spec in valid_specs:
|
||||
assert spec.name is not None
|
||||
assert isinstance(spec.name, str)
|
||||
|
||||
# Test edge cases (if parameter specs have validation)
|
||||
# This depends on the actual ParameterSpec implementation
|
||||
try:
|
||||
empty_name_spec = ParameterSpec(name="")
|
||||
# May or may not be valid depending on implementation
|
||||
except Exception:
|
||||
# If validation exists, it should catch invalid names
|
||||
pass
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pytest.main([__file__])
|
||||
264
tests/unit/test_base/test_llm_service_parameters.py
Normal file
264
tests/unit/test_base/test_llm_service_parameters.py
Normal file
|
|
@ -0,0 +1,264 @@
|
|||
"""
|
||||
Unit tests for LLM Service Parameter Specifications
|
||||
Testing the new parameter-aware functionality added to the LLM base service
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from unittest import IsolatedAsyncioTestCase
|
||||
|
||||
from trustgraph.base.llm_service import LlmService, LlmResult
|
||||
from trustgraph.base import ParameterSpec, ConsumerSpec, ProducerSpec
|
||||
from trustgraph.schema import TextCompletionRequest, TextCompletionResponse
|
||||
|
||||
|
||||
class MockAsyncProcessor:
|
||||
def __init__(self, **params):
|
||||
self.config_handlers = []
|
||||
self.id = params.get('id', 'test-service')
|
||||
self.specifications = []
|
||||
|
||||
|
||||
|
||||
|
||||
class TestLlmServiceParameters(IsolatedAsyncioTestCase):
|
||||
"""Test LLM service parameter specification functionality"""
|
||||
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
|
||||
def test_parameter_specs_registration(self):
|
||||
"""Test that LLM service registers model and temperature parameter specs"""
|
||||
# Arrange
|
||||
config = {
|
||||
'id': 'test-llm-service',
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock() # Add required taskgroup
|
||||
}
|
||||
|
||||
# Act
|
||||
service = LlmService(**config)
|
||||
|
||||
# Assert
|
||||
param_specs = {spec.name: spec for spec in service.specifications
|
||||
if isinstance(spec, ParameterSpec)}
|
||||
|
||||
assert "model" in param_specs
|
||||
assert "temperature" in param_specs
|
||||
assert len(param_specs) >= 2 # May have other parameter specs
|
||||
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
|
||||
def test_model_parameter_spec_properties(self):
|
||||
"""Test that model parameter spec has correct properties"""
|
||||
# Arrange
|
||||
config = {
|
||||
'id': 'test-llm-service',
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock()
|
||||
}
|
||||
|
||||
# Act
|
||||
service = LlmService(**config)
|
||||
|
||||
# Assert
|
||||
model_spec = None
|
||||
for spec in service.specifications:
|
||||
if isinstance(spec, ParameterSpec) and spec.name == "model":
|
||||
model_spec = spec
|
||||
break
|
||||
|
||||
assert model_spec is not None
|
||||
assert model_spec.name == "model"
|
||||
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
|
||||
def test_temperature_parameter_spec_properties(self):
|
||||
"""Test that temperature parameter spec has correct properties"""
|
||||
# Arrange
|
||||
config = {
|
||||
'id': 'test-llm-service',
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock()
|
||||
}
|
||||
|
||||
# Act
|
||||
service = LlmService(**config)
|
||||
|
||||
# Assert
|
||||
temperature_spec = None
|
||||
for spec in service.specifications:
|
||||
if isinstance(spec, ParameterSpec) and spec.name == "temperature":
|
||||
temperature_spec = spec
|
||||
break
|
||||
|
||||
assert temperature_spec is not None
|
||||
assert temperature_spec.name == "temperature"
|
||||
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
|
||||
async def test_on_request_extracts_parameters_from_flow(self):
|
||||
"""Test that on_request method extracts model and temperature from flow"""
|
||||
# Arrange
|
||||
config = {
|
||||
'id': 'test-llm-service',
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock()
|
||||
}
|
||||
|
||||
service = LlmService(**config)
|
||||
|
||||
# Mock the metrics
|
||||
service.text_completion_model_metric = MagicMock()
|
||||
service.text_completion_model_metric.labels.return_value.info = AsyncMock()
|
||||
|
||||
# Mock the generate_content method to capture parameters
|
||||
service.generate_content = AsyncMock(return_value=LlmResult(
|
||||
text="test response",
|
||||
in_token=10,
|
||||
out_token=5,
|
||||
model="gpt-4"
|
||||
))
|
||||
|
||||
# Mock message and flow
|
||||
mock_message = MagicMock()
|
||||
mock_message.value.return_value = MagicMock()
|
||||
mock_message.value.return_value.system = "system prompt"
|
||||
mock_message.value.return_value.prompt = "user prompt"
|
||||
mock_message.properties.return_value = {"id": "test-id"}
|
||||
|
||||
mock_consumer = MagicMock()
|
||||
mock_consumer.name = "request"
|
||||
|
||||
mock_flow = MagicMock()
|
||||
mock_flow.name = "test-flow"
|
||||
mock_flow.return_value = "test-model" # flow("model") returns this
|
||||
mock_flow.side_effect = lambda param: {
|
||||
"model": "gpt-4",
|
||||
"temperature": 0.7
|
||||
}.get(param, f"mock-{param}")
|
||||
|
||||
mock_producer = AsyncMock()
|
||||
mock_flow.producer = {"response": mock_producer}
|
||||
|
||||
# Act
|
||||
await service.on_request(mock_message, mock_consumer, mock_flow)
|
||||
|
||||
# Assert
|
||||
# Verify that generate_content was called with parameters from flow
|
||||
service.generate_content.assert_called_once()
|
||||
call_args = service.generate_content.call_args
|
||||
|
||||
assert call_args[0][0] == "system prompt" # system
|
||||
assert call_args[0][1] == "user prompt" # prompt
|
||||
assert call_args[0][2] == "gpt-4" # model
|
||||
assert call_args[0][3] == 0.7 # temperature
|
||||
|
||||
# Verify flow was queried for both parameters
|
||||
mock_flow.assert_any_call("model")
|
||||
mock_flow.assert_any_call("temperature")
|
||||
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
|
||||
async def test_on_request_handles_missing_parameters_gracefully(self):
|
||||
"""Test that on_request handles missing parameters gracefully"""
|
||||
# Arrange
|
||||
config = {
|
||||
'id': 'test-llm-service',
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock()
|
||||
}
|
||||
|
||||
service = LlmService(**config)
|
||||
|
||||
# Mock the metrics
|
||||
service.text_completion_model_metric = MagicMock()
|
||||
service.text_completion_model_metric.labels.return_value.info = AsyncMock()
|
||||
|
||||
# Mock the generate_content method
|
||||
service.generate_content = AsyncMock(return_value=LlmResult(
|
||||
text="test response",
|
||||
in_token=10,
|
||||
out_token=5,
|
||||
model="default-model"
|
||||
))
|
||||
|
||||
# Mock message and flow where flow returns None for parameters
|
||||
mock_message = MagicMock()
|
||||
mock_message.value.return_value = MagicMock()
|
||||
mock_message.value.return_value.system = "system prompt"
|
||||
mock_message.value.return_value.prompt = "user prompt"
|
||||
mock_message.properties.return_value = {"id": "test-id"}
|
||||
|
||||
mock_consumer = MagicMock()
|
||||
mock_consumer.name = "request"
|
||||
|
||||
mock_flow = MagicMock()
|
||||
mock_flow.name = "test-flow"
|
||||
mock_flow.return_value = None # Both parameters return None
|
||||
|
||||
mock_producer = AsyncMock()
|
||||
mock_flow.producer = {"response": mock_producer}
|
||||
|
||||
# Act
|
||||
await service.on_request(mock_message, mock_consumer, mock_flow)
|
||||
|
||||
# Assert
|
||||
# Should still call generate_content, with None values that will use processor defaults
|
||||
service.generate_content.assert_called_once()
|
||||
call_args = service.generate_content.call_args
|
||||
|
||||
assert call_args[0][0] == "system prompt" # system
|
||||
assert call_args[0][1] == "user prompt" # prompt
|
||||
assert call_args[0][2] is None # model (will use processor default)
|
||||
assert call_args[0][3] is None # temperature (will use processor default)
|
||||
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
|
||||
async def test_on_request_error_handling_preserves_behavior(self):
|
||||
"""Test that parameter extraction doesn't break existing error handling"""
|
||||
# Arrange
|
||||
config = {
|
||||
'id': 'test-llm-service',
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock()
|
||||
}
|
||||
|
||||
service = LlmService(**config)
|
||||
|
||||
# Mock the metrics
|
||||
service.text_completion_model_metric = MagicMock()
|
||||
service.text_completion_model_metric.labels.return_value.info = AsyncMock()
|
||||
|
||||
# Mock generate_content to raise an exception
|
||||
service.generate_content = AsyncMock(side_effect=Exception("Test error"))
|
||||
|
||||
# Mock message and flow
|
||||
mock_message = MagicMock()
|
||||
mock_message.value.return_value = MagicMock()
|
||||
mock_message.value.return_value.system = "system prompt"
|
||||
mock_message.value.return_value.prompt = "user prompt"
|
||||
mock_message.properties.return_value = {"id": "test-id"}
|
||||
|
||||
mock_consumer = MagicMock()
|
||||
mock_consumer.name = "request"
|
||||
|
||||
mock_flow = MagicMock()
|
||||
mock_flow.name = "test-flow"
|
||||
mock_flow.side_effect = lambda param: {
|
||||
"model": "gpt-4",
|
||||
"temperature": 0.7
|
||||
}.get(param, f"mock-{param}")
|
||||
|
||||
mock_producer = AsyncMock()
|
||||
mock_flow.producer = {"response": mock_producer}
|
||||
|
||||
# Act
|
||||
await service.on_request(mock_message, mock_consumer, mock_flow)
|
||||
|
||||
# Assert
|
||||
# Should have sent error response
|
||||
mock_producer.send.assert_called_once()
|
||||
error_response = mock_producer.send.call_args[0][0]
|
||||
|
||||
assert error_response.error is not None
|
||||
assert error_response.error.type == "llm-error"
|
||||
assert "Test error" in error_response.error.message
|
||||
assert error_response.response is None
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pytest.main([__file__])
|
||||
|
|
@ -1,211 +1,236 @@
|
|||
"""
|
||||
Unit tests for trustgraph.chunking.recursive
|
||||
Testing parameter override functionality for chunk-size and chunk-overlap
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock, Mock, patch, MagicMock
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from unittest import IsolatedAsyncioTestCase
|
||||
|
||||
# Import the service under test
|
||||
from trustgraph.chunking.recursive.chunker import Processor
|
||||
from trustgraph.schema import TextDocument, Chunk, Metadata
|
||||
from trustgraph.chunking.recursive.chunker import Processor as RecursiveChunker
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_flow():
|
||||
output_mock = AsyncMock()
|
||||
flow_mock = Mock(return_value=output_mock)
|
||||
return flow_mock, output_mock
|
||||
class MockAsyncProcessor:
|
||||
def __init__(self, **params):
|
||||
self.config_handlers = []
|
||||
self.id = params.get('id', 'test-service')
|
||||
self.specifications = []
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_consumer():
|
||||
consumer = Mock()
|
||||
consumer.id = "test-consumer"
|
||||
consumer.flow = "test-flow"
|
||||
return consumer
|
||||
class TestRecursiveChunkerSimple(IsolatedAsyncioTestCase):
|
||||
"""Test Recursive chunker functionality"""
|
||||
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
|
||||
def test_processor_initialization_basic(self):
|
||||
"""Test basic processor initialization"""
|
||||
# Arrange
|
||||
config = {
|
||||
'id': 'test-chunker',
|
||||
'chunk_size': 1500,
|
||||
'chunk_overlap': 150,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock()
|
||||
}
|
||||
|
||||
# Act
|
||||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
assert processor.default_chunk_size == 1500
|
||||
assert processor.default_chunk_overlap == 150
|
||||
assert hasattr(processor, 'text_splitter')
|
||||
|
||||
# Verify parameter specs are registered
|
||||
param_specs = [spec for spec in processor.specifications
|
||||
if hasattr(spec, 'name') and spec.name in ['chunk-size', 'chunk-overlap']]
|
||||
assert len(param_specs) == 2
|
||||
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
|
||||
async def test_chunk_document_with_chunk_size_override(self):
|
||||
"""Test chunk_document with chunk-size parameter override"""
|
||||
# Arrange
|
||||
config = {
|
||||
'id': 'test-chunker',
|
||||
'chunk_size': 1000, # Default chunk size
|
||||
'chunk_overlap': 100,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock()
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Mock message and flow
|
||||
mock_message = MagicMock()
|
||||
mock_consumer = MagicMock()
|
||||
mock_flow = MagicMock()
|
||||
mock_flow.side_effect = lambda param: {
|
||||
"chunk-size": 2000, # Override chunk size
|
||||
"chunk-overlap": None # Use default chunk overlap
|
||||
}.get(param)
|
||||
|
||||
# Act
|
||||
chunk_size, chunk_overlap = await processor.chunk_document(
|
||||
mock_message, mock_consumer, mock_flow, 1000, 100
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert chunk_size == 2000 # Should use overridden value
|
||||
assert chunk_overlap == 100 # Should use default value
|
||||
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
|
||||
async def test_chunk_document_with_chunk_overlap_override(self):
|
||||
"""Test chunk_document with chunk-overlap parameter override"""
|
||||
# Arrange
|
||||
config = {
|
||||
'id': 'test-chunker',
|
||||
'chunk_size': 1000,
|
||||
'chunk_overlap': 100, # Default chunk overlap
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock()
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Mock message and flow
|
||||
mock_message = MagicMock()
|
||||
mock_consumer = MagicMock()
|
||||
mock_flow = MagicMock()
|
||||
mock_flow.side_effect = lambda param: {
|
||||
"chunk-size": None, # Use default chunk size
|
||||
"chunk-overlap": 200 # Override chunk overlap
|
||||
}.get(param)
|
||||
|
||||
# Act
|
||||
chunk_size, chunk_overlap = await processor.chunk_document(
|
||||
mock_message, mock_consumer, mock_flow, 1000, 100
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert chunk_size == 1000 # Should use default value
|
||||
assert chunk_overlap == 200 # Should use overridden value
|
||||
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
|
||||
async def test_chunk_document_with_both_parameters_override(self):
|
||||
"""Test chunk_document with both chunk-size and chunk-overlap overrides"""
|
||||
# Arrange
|
||||
config = {
|
||||
'id': 'test-chunker',
|
||||
'chunk_size': 1000,
|
||||
'chunk_overlap': 100,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock()
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Mock message and flow
|
||||
mock_message = MagicMock()
|
||||
mock_consumer = MagicMock()
|
||||
mock_flow = MagicMock()
|
||||
mock_flow.side_effect = lambda param: {
|
||||
"chunk-size": 1500, # Override chunk size
|
||||
"chunk-overlap": 150 # Override chunk overlap
|
||||
}.get(param)
|
||||
|
||||
# Act
|
||||
chunk_size, chunk_overlap = await processor.chunk_document(
|
||||
mock_message, mock_consumer, mock_flow, 1000, 100
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert chunk_size == 1500 # Should use overridden value
|
||||
assert chunk_overlap == 150 # Should use overridden value
|
||||
|
||||
@patch('trustgraph.chunking.recursive.chunker.RecursiveCharacterTextSplitter')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
|
||||
async def test_on_message_uses_flow_parameters(self, mock_splitter_class):
|
||||
"""Test that on_message method uses parameters from flow"""
|
||||
# Arrange
|
||||
mock_splitter = MagicMock()
|
||||
mock_document = MagicMock()
|
||||
mock_document.page_content = "Test chunk content"
|
||||
mock_splitter.create_documents.return_value = [mock_document]
|
||||
mock_splitter_class.return_value = mock_splitter
|
||||
|
||||
config = {
|
||||
'id': 'test-chunker',
|
||||
'chunk_size': 1000,
|
||||
'chunk_overlap': 100,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock()
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Mock message with TextDocument
|
||||
mock_message = MagicMock()
|
||||
mock_text_doc = MagicMock()
|
||||
mock_text_doc.metadata = Metadata(
|
||||
id="test-doc-123",
|
||||
metadata=[],
|
||||
user="test-user",
|
||||
collection="test-collection"
|
||||
)
|
||||
mock_text_doc.text = b"This is test document content"
|
||||
mock_message.value.return_value = mock_text_doc
|
||||
|
||||
# Mock consumer and flow with parameter overrides
|
||||
mock_consumer = MagicMock()
|
||||
mock_producer = AsyncMock()
|
||||
mock_flow = MagicMock()
|
||||
mock_flow.side_effect = lambda param: {
|
||||
"chunk-size": 1500,
|
||||
"chunk-overlap": 150,
|
||||
"output": mock_producer
|
||||
}.get(param)
|
||||
|
||||
# Act
|
||||
await processor.on_message(mock_message, mock_consumer, mock_flow)
|
||||
|
||||
# Assert
|
||||
# Verify RecursiveCharacterTextSplitter was called with overridden parameters (last call)
|
||||
actual_last_call = mock_splitter_class.call_args_list[-1]
|
||||
assert actual_last_call.kwargs['chunk_size'] == 1500
|
||||
assert actual_last_call.kwargs['chunk_overlap'] == 150
|
||||
assert actual_last_call.kwargs['length_function'] == len
|
||||
assert actual_last_call.kwargs['is_separator_regex'] == False
|
||||
|
||||
# Verify chunk was sent to output
|
||||
mock_producer.send.assert_called_once()
|
||||
sent_chunk = mock_producer.send.call_args[0][0]
|
||||
assert isinstance(sent_chunk, Chunk)
|
||||
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
|
||||
async def test_chunk_document_with_no_overrides(self):
|
||||
"""Test chunk_document when no parameters are overridden (flow returns None)"""
|
||||
# Arrange
|
||||
config = {
|
||||
'id': 'test-chunker',
|
||||
'chunk_size': 1000,
|
||||
'chunk_overlap': 100,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock()
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Mock message and flow that returns None for all parameters
|
||||
mock_message = MagicMock()
|
||||
mock_consumer = MagicMock()
|
||||
mock_flow = MagicMock()
|
||||
mock_flow.return_value = None # No overrides
|
||||
|
||||
# Act
|
||||
chunk_size, chunk_overlap = await processor.chunk_document(
|
||||
mock_message, mock_consumer, mock_flow, 1000, 100
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert chunk_size == 1000 # Should use default value
|
||||
assert chunk_overlap == 100 # Should use default value
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_document():
|
||||
metadata = Metadata(
|
||||
id="test-doc-1",
|
||||
metadata=[],
|
||||
user="test-user",
|
||||
collection="test-collection"
|
||||
)
|
||||
text = "This is a test document. " * 100 # Create text long enough to be chunked
|
||||
return TextDocument(
|
||||
metadata=metadata,
|
||||
text=text.encode("utf-8")
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def short_document():
|
||||
metadata = Metadata(
|
||||
id="test-doc-2",
|
||||
metadata=[],
|
||||
user="test-user",
|
||||
collection="test-collection"
|
||||
)
|
||||
text = "This is a very short document."
|
||||
return TextDocument(
|
||||
metadata=metadata,
|
||||
text=text.encode("utf-8")
|
||||
)
|
||||
|
||||
|
||||
class TestRecursiveChunker:
|
||||
|
||||
def test_init_default_params(self, mock_async_processor_init):
|
||||
processor = RecursiveChunker()
|
||||
assert processor.text_splitter._chunk_size == 2000
|
||||
assert processor.text_splitter._chunk_overlap == 100
|
||||
|
||||
def test_init_custom_params(self, mock_async_processor_init):
|
||||
processor = RecursiveChunker(chunk_size=500, chunk_overlap=50)
|
||||
assert processor.text_splitter._chunk_size == 500
|
||||
assert processor.text_splitter._chunk_overlap == 50
|
||||
|
||||
def test_init_with_id(self, mock_async_processor_init):
|
||||
processor = RecursiveChunker(id="custom-chunker")
|
||||
assert processor.id == "custom-chunker"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_message_single_chunk(self, mock_async_processor_init, mock_flow, mock_consumer, short_document):
|
||||
flow_mock, output_mock = mock_flow
|
||||
processor = RecursiveChunker(chunk_size=2000, chunk_overlap=100)
|
||||
|
||||
msg = Mock()
|
||||
msg.value.return_value = short_document
|
||||
|
||||
await processor.on_message(msg, mock_consumer, flow_mock)
|
||||
|
||||
# Should produce exactly one chunk for short text
|
||||
assert output_mock.send.call_count == 1
|
||||
|
||||
# Verify the chunk was created correctly
|
||||
chunk_call = output_mock.send.call_args[0][0]
|
||||
assert isinstance(chunk_call, Chunk)
|
||||
assert chunk_call.metadata == short_document.metadata
|
||||
assert chunk_call.chunk.decode("utf-8") == short_document.text.decode("utf-8")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_message_multiple_chunks(self, mock_async_processor_init, mock_flow, mock_consumer, sample_document):
|
||||
flow_mock, output_mock = mock_flow
|
||||
processor = RecursiveChunker(chunk_size=100, chunk_overlap=20)
|
||||
|
||||
msg = Mock()
|
||||
msg.value.return_value = sample_document
|
||||
|
||||
await processor.on_message(msg, mock_consumer, flow_mock)
|
||||
|
||||
# Should produce multiple chunks
|
||||
assert output_mock.send.call_count > 1
|
||||
|
||||
# Verify all chunks have correct metadata
|
||||
for call in output_mock.send.call_args_list:
|
||||
chunk = call[0][0]
|
||||
assert isinstance(chunk, Chunk)
|
||||
assert chunk.metadata == sample_document.metadata
|
||||
assert len(chunk.chunk) > 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_message_chunk_overlap(self, mock_async_processor_init, mock_flow, mock_consumer):
|
||||
flow_mock, output_mock = mock_flow
|
||||
processor = RecursiveChunker(chunk_size=50, chunk_overlap=10)
|
||||
|
||||
# Create a document with predictable content
|
||||
metadata = Metadata(id="test", metadata=[], user="test-user", collection="test-collection")
|
||||
text = "ABCDEFGHIJ" * 10 # 100 characters
|
||||
document = TextDocument(metadata=metadata, text=text.encode("utf-8"))
|
||||
|
||||
msg = Mock()
|
||||
msg.value.return_value = document
|
||||
|
||||
await processor.on_message(msg, mock_consumer, flow_mock)
|
||||
|
||||
# Collect all chunks
|
||||
chunks = []
|
||||
for call in output_mock.send.call_args_list:
|
||||
chunk_text = call[0][0].chunk.decode("utf-8")
|
||||
chunks.append(chunk_text)
|
||||
|
||||
# Verify chunks have expected overlap
|
||||
for i in range(len(chunks) - 1):
|
||||
# The end of chunk i should overlap with the beginning of chunk i+1
|
||||
# Check if there's some overlap (exact overlap depends on text splitter logic)
|
||||
assert len(chunks[i]) <= 50 + 10 # chunk_size + some tolerance
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_message_empty_document(self, mock_async_processor_init, mock_flow, mock_consumer):
|
||||
flow_mock, output_mock = mock_flow
|
||||
processor = RecursiveChunker()
|
||||
|
||||
metadata = Metadata(id="empty", metadata=[], user="test-user", collection="test-collection")
|
||||
document = TextDocument(metadata=metadata, text=b"")
|
||||
|
||||
msg = Mock()
|
||||
msg.value.return_value = document
|
||||
|
||||
await processor.on_message(msg, mock_consumer, flow_mock)
|
||||
|
||||
# Empty documents typically don't produce chunks with langchain splitters
|
||||
# This behavior is expected - no chunks should be produced
|
||||
assert output_mock.send.call_count == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_message_unicode_handling(self, mock_async_processor_init, mock_flow, mock_consumer):
|
||||
flow_mock, output_mock = mock_flow
|
||||
processor = RecursiveChunker(chunk_size=500, chunk_overlap=20) # Fixed overlap < chunk_size
|
||||
|
||||
metadata = Metadata(id="unicode", metadata=[], user="test-user", collection="test-collection")
|
||||
text = "Hello 世界! 🌍 This is a test with émojis and spëcial characters."
|
||||
document = TextDocument(metadata=metadata, text=text.encode("utf-8"))
|
||||
|
||||
msg = Mock()
|
||||
msg.value.return_value = document
|
||||
|
||||
await processor.on_message(msg, mock_consumer, flow_mock)
|
||||
|
||||
# Verify unicode is preserved correctly
|
||||
all_chunks = []
|
||||
for call in output_mock.send.call_args_list:
|
||||
chunk_text = call[0][0].chunk.decode("utf-8")
|
||||
all_chunks.append(chunk_text)
|
||||
|
||||
# Reconstruct text (approximately, due to overlap)
|
||||
reconstructed = "".join(all_chunks)
|
||||
assert "世界" in reconstructed
|
||||
assert "🌍" in reconstructed
|
||||
assert "émojis" in reconstructed
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_metrics_recorded(self, mock_async_processor_init, mock_flow, mock_consumer, sample_document):
|
||||
flow_mock, output_mock = mock_flow
|
||||
processor = RecursiveChunker(chunk_size=100)
|
||||
|
||||
msg = Mock()
|
||||
msg.value.return_value = sample_document
|
||||
|
||||
# Mock the metric
|
||||
with patch.object(RecursiveChunker.chunk_metric, 'labels') as mock_labels:
|
||||
mock_observe = Mock()
|
||||
mock_labels.return_value.observe = mock_observe
|
||||
|
||||
await processor.on_message(msg, mock_consumer, flow_mock)
|
||||
|
||||
# Verify metrics were recorded
|
||||
mock_labels.assert_called_with(id="test-consumer", flow="test-flow")
|
||||
assert mock_observe.call_count > 0
|
||||
|
||||
# Verify chunk sizes were observed
|
||||
for call in mock_observe.call_args_list:
|
||||
chunk_size = call[0][0]
|
||||
assert chunk_size > 0
|
||||
|
||||
def test_add_args(self):
|
||||
parser = Mock()
|
||||
RecursiveChunker.add_args(parser)
|
||||
|
||||
# Verify arguments were added
|
||||
calls = parser.add_argument.call_args_list
|
||||
arg_names = [call[0][0] for call in calls]
|
||||
|
||||
assert '-z' in arg_names or '--chunk-size' in arg_names
|
||||
assert '-v' in arg_names or '--chunk-overlap' in arg_names
|
||||
if __name__ == '__main__':
|
||||
pytest.main([__file__])
|
||||
|
|
@ -1,275 +1,256 @@
|
|||
"""
|
||||
Unit tests for trustgraph.chunking.token
|
||||
Testing parameter override functionality for chunk-size and chunk-overlap
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from unittest import IsolatedAsyncioTestCase
|
||||
|
||||
# Import the service under test
|
||||
from trustgraph.chunking.token.chunker import Processor
|
||||
from trustgraph.schema import TextDocument, Chunk, Metadata
|
||||
from trustgraph.chunking.token.chunker import Processor as TokenChunker
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_flow():
|
||||
output_mock = AsyncMock()
|
||||
flow_mock = Mock(return_value=output_mock)
|
||||
return flow_mock, output_mock
|
||||
class MockAsyncProcessor:
|
||||
def __init__(self, **params):
|
||||
self.config_handlers = []
|
||||
self.id = params.get('id', 'test-service')
|
||||
self.specifications = []
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_consumer():
|
||||
consumer = Mock()
|
||||
consumer.id = "test-consumer"
|
||||
consumer.flow = "test-flow"
|
||||
return consumer
|
||||
class TestTokenChunkerSimple(IsolatedAsyncioTestCase):
|
||||
"""Test Token chunker functionality"""
|
||||
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
|
||||
def test_processor_initialization_basic(self):
|
||||
"""Test basic processor initialization"""
|
||||
# Arrange
|
||||
config = {
|
||||
'id': 'test-chunker',
|
||||
'chunk_size': 300,
|
||||
'chunk_overlap': 20,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock()
|
||||
}
|
||||
|
||||
# Act
|
||||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
assert processor.default_chunk_size == 300
|
||||
assert processor.default_chunk_overlap == 20
|
||||
assert hasattr(processor, 'text_splitter')
|
||||
|
||||
# Verify parameter specs are registered
|
||||
param_specs = [spec for spec in processor.specifications
|
||||
if hasattr(spec, 'name') and spec.name in ['chunk-size', 'chunk-overlap']]
|
||||
assert len(param_specs) == 2
|
||||
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
|
||||
async def test_chunk_document_with_chunk_size_override(self):
|
||||
"""Test chunk_document with chunk-size parameter override"""
|
||||
# Arrange
|
||||
config = {
|
||||
'id': 'test-chunker',
|
||||
'chunk_size': 250, # Default chunk size
|
||||
'chunk_overlap': 15,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock()
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Mock message and flow
|
||||
mock_message = MagicMock()
|
||||
mock_consumer = MagicMock()
|
||||
mock_flow = MagicMock()
|
||||
mock_flow.side_effect = lambda param: {
|
||||
"chunk-size": 400, # Override chunk size
|
||||
"chunk-overlap": None # Use default chunk overlap
|
||||
}.get(param)
|
||||
|
||||
# Act
|
||||
chunk_size, chunk_overlap = await processor.chunk_document(
|
||||
mock_message, mock_consumer, mock_flow, 250, 15
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert chunk_size == 400 # Should use overridden value
|
||||
assert chunk_overlap == 15 # Should use default value
|
||||
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
|
||||
async def test_chunk_document_with_chunk_overlap_override(self):
|
||||
"""Test chunk_document with chunk-overlap parameter override"""
|
||||
# Arrange
|
||||
config = {
|
||||
'id': 'test-chunker',
|
||||
'chunk_size': 250,
|
||||
'chunk_overlap': 15, # Default chunk overlap
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock()
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Mock message and flow
|
||||
mock_message = MagicMock()
|
||||
mock_consumer = MagicMock()
|
||||
mock_flow = MagicMock()
|
||||
mock_flow.side_effect = lambda param: {
|
||||
"chunk-size": None, # Use default chunk size
|
||||
"chunk-overlap": 25 # Override chunk overlap
|
||||
}.get(param)
|
||||
|
||||
# Act
|
||||
chunk_size, chunk_overlap = await processor.chunk_document(
|
||||
mock_message, mock_consumer, mock_flow, 250, 15
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert chunk_size == 250 # Should use default value
|
||||
assert chunk_overlap == 25 # Should use overridden value
|
||||
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
|
||||
async def test_chunk_document_with_both_parameters_override(self):
|
||||
"""Test chunk_document with both chunk-size and chunk-overlap overrides"""
|
||||
# Arrange
|
||||
config = {
|
||||
'id': 'test-chunker',
|
||||
'chunk_size': 250,
|
||||
'chunk_overlap': 15,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock()
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Mock message and flow
|
||||
mock_message = MagicMock()
|
||||
mock_consumer = MagicMock()
|
||||
mock_flow = MagicMock()
|
||||
mock_flow.side_effect = lambda param: {
|
||||
"chunk-size": 350, # Override chunk size
|
||||
"chunk-overlap": 30 # Override chunk overlap
|
||||
}.get(param)
|
||||
|
||||
# Act
|
||||
chunk_size, chunk_overlap = await processor.chunk_document(
|
||||
mock_message, mock_consumer, mock_flow, 250, 15
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert chunk_size == 350 # Should use overridden value
|
||||
assert chunk_overlap == 30 # Should use overridden value
|
||||
|
||||
@patch('trustgraph.chunking.token.chunker.TokenTextSplitter')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
|
||||
async def test_on_message_uses_flow_parameters(self, mock_splitter_class):
|
||||
"""Test that on_message method uses parameters from flow"""
|
||||
# Arrange
|
||||
mock_splitter = MagicMock()
|
||||
mock_document = MagicMock()
|
||||
mock_document.page_content = "Test token chunk content"
|
||||
mock_splitter.create_documents.return_value = [mock_document]
|
||||
mock_splitter_class.return_value = mock_splitter
|
||||
|
||||
config = {
|
||||
'id': 'test-chunker',
|
||||
'chunk_size': 250,
|
||||
'chunk_overlap': 15,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock()
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Mock message with TextDocument
|
||||
mock_message = MagicMock()
|
||||
mock_text_doc = MagicMock()
|
||||
mock_text_doc.metadata = Metadata(
|
||||
id="test-doc-456",
|
||||
metadata=[],
|
||||
user="test-user",
|
||||
collection="test-collection"
|
||||
)
|
||||
mock_text_doc.text = b"This is test document content for token chunking"
|
||||
mock_message.value.return_value = mock_text_doc
|
||||
|
||||
# Mock consumer and flow with parameter overrides
|
||||
mock_consumer = MagicMock()
|
||||
mock_producer = AsyncMock()
|
||||
mock_flow = MagicMock()
|
||||
mock_flow.side_effect = lambda param: {
|
||||
"chunk-size": 400,
|
||||
"chunk-overlap": 40,
|
||||
"output": mock_producer
|
||||
}.get(param)
|
||||
|
||||
# Act
|
||||
await processor.on_message(mock_message, mock_consumer, mock_flow)
|
||||
|
||||
# Assert
|
||||
# Verify TokenTextSplitter was called with overridden parameters (last call)
|
||||
expected_call = [
|
||||
('encoding_name', 'cl100k_base'),
|
||||
('chunk_size', 400),
|
||||
('chunk_overlap', 40)
|
||||
]
|
||||
actual_last_call = mock_splitter_class.call_args_list[-1]
|
||||
assert actual_last_call.kwargs['encoding_name'] == "cl100k_base"
|
||||
assert actual_last_call.kwargs['chunk_size'] == 400
|
||||
assert actual_last_call.kwargs['chunk_overlap'] == 40
|
||||
|
||||
# Verify chunk was sent to output
|
||||
mock_producer.send.assert_called_once()
|
||||
sent_chunk = mock_producer.send.call_args[0][0]
|
||||
assert isinstance(sent_chunk, Chunk)
|
||||
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
|
||||
async def test_chunk_document_with_no_overrides(self):
|
||||
"""Test chunk_document when no parameters are overridden (flow returns None)"""
|
||||
# Arrange
|
||||
config = {
|
||||
'id': 'test-chunker',
|
||||
'chunk_size': 250,
|
||||
'chunk_overlap': 15,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock()
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Mock message and flow that returns None for all parameters
|
||||
mock_message = MagicMock()
|
||||
mock_consumer = MagicMock()
|
||||
mock_flow = MagicMock()
|
||||
mock_flow.return_value = None # No overrides
|
||||
|
||||
# Act
|
||||
chunk_size, chunk_overlap = await processor.chunk_document(
|
||||
mock_message, mock_consumer, mock_flow, 250, 15
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert chunk_size == 250 # Should use default value
|
||||
assert chunk_overlap == 15 # Should use default value
|
||||
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
|
||||
def test_token_chunker_uses_different_defaults(self):
|
||||
"""Test that token chunker has different defaults than recursive chunker"""
|
||||
# Arrange & Act
|
||||
config = {
|
||||
'id': 'test-chunker',
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock()
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Assert - Token chunker should have different defaults
|
||||
assert processor.default_chunk_size == 250 # Token chunker default
|
||||
assert processor.default_chunk_overlap == 15 # Token chunker default
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_document():
|
||||
metadata = Metadata(
|
||||
id="test-doc-1",
|
||||
metadata=[],
|
||||
user="test-user",
|
||||
collection="test-collection"
|
||||
)
|
||||
# Create text that will result in multiple token chunks
|
||||
text = "The quick brown fox jumps over the lazy dog. " * 50
|
||||
return TextDocument(
|
||||
metadata=metadata,
|
||||
text=text.encode("utf-8")
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def short_document():
|
||||
metadata = Metadata(
|
||||
id="test-doc-2",
|
||||
metadata=[],
|
||||
user="test-user",
|
||||
collection="test-collection"
|
||||
)
|
||||
text = "Short text."
|
||||
return TextDocument(
|
||||
metadata=metadata,
|
||||
text=text.encode("utf-8")
|
||||
)
|
||||
|
||||
|
||||
class TestTokenChunker:
|
||||
|
||||
def test_init_default_params(self, mock_async_processor_init):
|
||||
processor = TokenChunker()
|
||||
assert processor.text_splitter._chunk_size == 250
|
||||
assert processor.text_splitter._chunk_overlap == 15
|
||||
# Just verify the text splitter was created (encoding verification is complex)
|
||||
assert processor.text_splitter is not None
|
||||
assert hasattr(processor.text_splitter, 'split_text')
|
||||
|
||||
def test_init_custom_params(self, mock_async_processor_init):
|
||||
processor = TokenChunker(chunk_size=100, chunk_overlap=10)
|
||||
assert processor.text_splitter._chunk_size == 100
|
||||
assert processor.text_splitter._chunk_overlap == 10
|
||||
|
||||
def test_init_with_id(self, mock_async_processor_init):
|
||||
processor = TokenChunker(id="custom-token-chunker")
|
||||
assert processor.id == "custom-token-chunker"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_message_single_chunk(self, mock_async_processor_init, mock_flow, mock_consumer, short_document):
|
||||
flow_mock, output_mock = mock_flow
|
||||
processor = TokenChunker(chunk_size=250, chunk_overlap=15)
|
||||
|
||||
msg = Mock()
|
||||
msg.value.return_value = short_document
|
||||
|
||||
await processor.on_message(msg, mock_consumer, flow_mock)
|
||||
|
||||
# Short text should produce exactly one chunk
|
||||
assert output_mock.send.call_count == 1
|
||||
|
||||
# Verify the chunk was created correctly
|
||||
chunk_call = output_mock.send.call_args[0][0]
|
||||
assert isinstance(chunk_call, Chunk)
|
||||
assert chunk_call.metadata == short_document.metadata
|
||||
assert chunk_call.chunk.decode("utf-8") == short_document.text.decode("utf-8")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_message_multiple_chunks(self, mock_async_processor_init, mock_flow, mock_consumer, sample_document):
|
||||
flow_mock, output_mock = mock_flow
|
||||
processor = TokenChunker(chunk_size=50, chunk_overlap=5)
|
||||
|
||||
msg = Mock()
|
||||
msg.value.return_value = sample_document
|
||||
|
||||
await processor.on_message(msg, mock_consumer, flow_mock)
|
||||
|
||||
# Should produce multiple chunks
|
||||
assert output_mock.send.call_count > 1
|
||||
|
||||
# Verify all chunks have correct metadata
|
||||
for call in output_mock.send.call_args_list:
|
||||
chunk = call[0][0]
|
||||
assert isinstance(chunk, Chunk)
|
||||
assert chunk.metadata == sample_document.metadata
|
||||
assert len(chunk.chunk) > 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_message_token_overlap(self, mock_async_processor_init, mock_flow, mock_consumer):
|
||||
flow_mock, output_mock = mock_flow
|
||||
processor = TokenChunker(chunk_size=20, chunk_overlap=5)
|
||||
|
||||
# Create a document with repeated pattern
|
||||
metadata = Metadata(id="test", metadata=[], user="test-user", collection="test-collection")
|
||||
text = "one two three four five six seven eight nine ten " * 5
|
||||
document = TextDocument(metadata=metadata, text=text.encode("utf-8"))
|
||||
|
||||
msg = Mock()
|
||||
msg.value.return_value = document
|
||||
|
||||
await processor.on_message(msg, mock_consumer, flow_mock)
|
||||
|
||||
# Collect all chunks
|
||||
chunks = []
|
||||
for call in output_mock.send.call_args_list:
|
||||
chunk_text = call[0][0].chunk.decode("utf-8")
|
||||
chunks.append(chunk_text)
|
||||
|
||||
# Should have multiple chunks
|
||||
assert len(chunks) > 1
|
||||
|
||||
# Verify chunks are not empty
|
||||
for chunk in chunks:
|
||||
assert len(chunk) > 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_message_empty_document(self, mock_async_processor_init, mock_flow, mock_consumer):
|
||||
flow_mock, output_mock = mock_flow
|
||||
processor = TokenChunker()
|
||||
|
||||
metadata = Metadata(id="empty", metadata=[], user="test-user", collection="test-collection")
|
||||
document = TextDocument(metadata=metadata, text=b"")
|
||||
|
||||
msg = Mock()
|
||||
msg.value.return_value = document
|
||||
|
||||
await processor.on_message(msg, mock_consumer, flow_mock)
|
||||
|
||||
# Empty documents typically don't produce chunks with langchain splitters
|
||||
# This behavior is expected - no chunks should be produced
|
||||
assert output_mock.send.call_count == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_message_unicode_handling(self, mock_async_processor_init, mock_flow, mock_consumer):
|
||||
flow_mock, output_mock = mock_flow
|
||||
processor = TokenChunker(chunk_size=50)
|
||||
|
||||
metadata = Metadata(id="unicode", metadata=[], user="test-user", collection="test-collection")
|
||||
# Test with various unicode characters
|
||||
text = "Hello 世界! 🌍 Test émojis café naïve résumé. Greek: αβγδε Hebrew: אבגדה"
|
||||
document = TextDocument(metadata=metadata, text=text.encode("utf-8"))
|
||||
|
||||
msg = Mock()
|
||||
msg.value.return_value = document
|
||||
|
||||
await processor.on_message(msg, mock_consumer, flow_mock)
|
||||
|
||||
# Verify unicode is preserved correctly
|
||||
all_chunks = []
|
||||
for call in output_mock.send.call_args_list:
|
||||
chunk_text = call[0][0].chunk.decode("utf-8")
|
||||
all_chunks.append(chunk_text)
|
||||
|
||||
# Reconstruct text
|
||||
reconstructed = "".join(all_chunks)
|
||||
assert "世界" in reconstructed
|
||||
assert "🌍" in reconstructed
|
||||
assert "émojis" in reconstructed
|
||||
assert "αβγδε" in reconstructed
|
||||
assert "אבגדה" in reconstructed
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_message_token_boundary_preservation(self, mock_async_processor_init, mock_flow, mock_consumer):
|
||||
flow_mock, output_mock = mock_flow
|
||||
processor = TokenChunker(chunk_size=10, chunk_overlap=2)
|
||||
|
||||
metadata = Metadata(id="boundary", metadata=[], user="test-user", collection="test-collection")
|
||||
# Text with clear word boundaries
|
||||
text = "This is a test of token boundaries and proper splitting."
|
||||
document = TextDocument(metadata=metadata, text=text.encode("utf-8"))
|
||||
|
||||
msg = Mock()
|
||||
msg.value.return_value = document
|
||||
|
||||
await processor.on_message(msg, mock_consumer, flow_mock)
|
||||
|
||||
# Collect all chunks
|
||||
chunks = []
|
||||
for call in output_mock.send.call_args_list:
|
||||
chunk_text = call[0][0].chunk.decode("utf-8")
|
||||
chunks.append(chunk_text)
|
||||
|
||||
# Token chunker should respect token boundaries
|
||||
for chunk in chunks:
|
||||
# Chunks should not start or end with partial words (in most cases)
|
||||
assert len(chunk.strip()) > 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_metrics_recorded(self, mock_async_processor_init, mock_flow, mock_consumer, sample_document):
|
||||
flow_mock, output_mock = mock_flow
|
||||
processor = TokenChunker(chunk_size=50)
|
||||
|
||||
msg = Mock()
|
||||
msg.value.return_value = sample_document
|
||||
|
||||
# Mock the metric
|
||||
with patch.object(TokenChunker.chunk_metric, 'labels') as mock_labels:
|
||||
mock_observe = Mock()
|
||||
mock_labels.return_value.observe = mock_observe
|
||||
|
||||
await processor.on_message(msg, mock_consumer, flow_mock)
|
||||
|
||||
# Verify metrics were recorded
|
||||
mock_labels.assert_called_with(id="test-consumer", flow="test-flow")
|
||||
assert mock_observe.call_count > 0
|
||||
|
||||
# Verify chunk sizes were observed
|
||||
for call in mock_observe.call_args_list:
|
||||
chunk_size = call[0][0]
|
||||
assert chunk_size > 0
|
||||
|
||||
def test_add_args(self):
|
||||
parser = Mock()
|
||||
TokenChunker.add_args(parser)
|
||||
|
||||
# Verify arguments were added
|
||||
calls = parser.add_argument.call_args_list
|
||||
arg_names = [call[0][0] for call in calls]
|
||||
|
||||
assert '-z' in arg_names or '--chunk-size' in arg_names
|
||||
assert '-v' in arg_names or '--chunk-overlap' in arg_names
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_encoding_specific_behavior(self, mock_async_processor_init, mock_flow, mock_consumer):
|
||||
flow_mock, output_mock = mock_flow
|
||||
processor = TokenChunker(chunk_size=10, chunk_overlap=0)
|
||||
|
||||
metadata = Metadata(id="encoding", metadata=[], user="test-user", collection="test-collection")
|
||||
# Test text that might tokenize differently with cl100k_base encoding
|
||||
text = "GPT-4 is an AI model. It uses tokens."
|
||||
document = TextDocument(metadata=metadata, text=text.encode("utf-8"))
|
||||
|
||||
msg = Mock()
|
||||
msg.value.return_value = document
|
||||
|
||||
await processor.on_message(msg, mock_consumer, flow_mock)
|
||||
|
||||
# Verify chunking happened
|
||||
assert output_mock.send.call_count >= 1
|
||||
|
||||
# Collect all chunks
|
||||
chunks = []
|
||||
for call in output_mock.send.call_args_list:
|
||||
chunk_text = call[0][0].chunk.decode("utf-8")
|
||||
chunks.append(chunk_text)
|
||||
|
||||
# Verify all text is preserved (allowing for overlap)
|
||||
all_text = " ".join(chunks)
|
||||
assert "GPT-4" in all_text
|
||||
assert "AI model" in all_text
|
||||
assert "tokens" in all_text
|
||||
if __name__ == '__main__':
|
||||
pytest.main([__file__])
|
||||
|
|
@ -34,7 +34,9 @@ class TestGraphRag:
|
|||
assert graph_rag.graph_embeddings_client == mock_graph_embeddings_client
|
||||
assert graph_rag.triples_client == mock_triples_client
|
||||
assert graph_rag.verbose is False # Default value
|
||||
assert graph_rag.label_cache == {} # Empty cache initially
|
||||
# Verify label_cache is an LRUCacheWithTTL instance
|
||||
from trustgraph.retrieval.graph_rag.graph_rag import LRUCacheWithTTL
|
||||
assert isinstance(graph_rag.label_cache, LRUCacheWithTTL)
|
||||
|
||||
def test_graph_rag_initialization_with_verbose(self):
|
||||
"""Test GraphRag initialization with verbose enabled"""
|
||||
|
|
@ -59,7 +61,9 @@ class TestGraphRag:
|
|||
assert graph_rag.graph_embeddings_client == mock_graph_embeddings_client
|
||||
assert graph_rag.triples_client == mock_triples_client
|
||||
assert graph_rag.verbose is True
|
||||
assert graph_rag.label_cache == {} # Empty cache initially
|
||||
# Verify label_cache is an LRUCacheWithTTL instance
|
||||
from trustgraph.retrieval.graph_rag.graph_rag import LRUCacheWithTTL
|
||||
assert isinstance(graph_rag.label_cache, LRUCacheWithTTL)
|
||||
|
||||
|
||||
class TestQuery:
|
||||
|
|
@ -228,8 +232,11 @@ class TestQuery:
|
|||
"""Test Query.maybe_label method with cached label"""
|
||||
# Create mock GraphRag with label cache
|
||||
mock_rag = MagicMock()
|
||||
mock_rag.label_cache = {"entity1": "Entity One Label"}
|
||||
|
||||
# Create mock LRUCacheWithTTL
|
||||
mock_cache = MagicMock()
|
||||
mock_cache.get.return_value = "Entity One Label"
|
||||
mock_rag.label_cache = mock_cache
|
||||
|
||||
# Initialize Query
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
|
|
@ -237,27 +244,32 @@ class TestQuery:
|
|||
collection="test_collection",
|
||||
verbose=False
|
||||
)
|
||||
|
||||
|
||||
# Call maybe_label with cached entity
|
||||
result = await query.maybe_label("entity1")
|
||||
|
||||
|
||||
# Verify cached label is returned
|
||||
assert result == "Entity One Label"
|
||||
# Verify cache was checked with proper key format (user:collection:entity)
|
||||
mock_cache.get.assert_called_once_with("test_user:test_collection:entity1")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_maybe_label_with_label_lookup(self):
|
||||
"""Test Query.maybe_label method with database label lookup"""
|
||||
# Create mock GraphRag with triples client
|
||||
mock_rag = MagicMock()
|
||||
mock_rag.label_cache = {} # Empty cache
|
||||
# Create mock LRUCacheWithTTL that returns None (cache miss)
|
||||
mock_cache = MagicMock()
|
||||
mock_cache.get.return_value = None
|
||||
mock_rag.label_cache = mock_cache
|
||||
mock_triples_client = AsyncMock()
|
||||
mock_rag.triples_client = mock_triples_client
|
||||
|
||||
|
||||
# Mock triple result with label
|
||||
mock_triple = MagicMock()
|
||||
mock_triple.o = "Human Readable Label"
|
||||
mock_triples_client.query.return_value = [mock_triple]
|
||||
|
||||
|
||||
# Initialize Query
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
|
|
@ -265,10 +277,10 @@ class TestQuery:
|
|||
collection="test_collection",
|
||||
verbose=False
|
||||
)
|
||||
|
||||
|
||||
# Call maybe_label
|
||||
result = await query.maybe_label("http://example.com/entity")
|
||||
|
||||
|
||||
# Verify triples client was called correctly
|
||||
mock_triples_client.query.assert_called_once_with(
|
||||
s="http://example.com/entity",
|
||||
|
|
@ -278,17 +290,21 @@ class TestQuery:
|
|||
user="test_user",
|
||||
collection="test_collection"
|
||||
)
|
||||
|
||||
# Verify result and cache update
|
||||
|
||||
# Verify result and cache update with proper key
|
||||
assert result == "Human Readable Label"
|
||||
assert mock_rag.label_cache["http://example.com/entity"] == "Human Readable Label"
|
||||
cache_key = "test_user:test_collection:http://example.com/entity"
|
||||
mock_cache.put.assert_called_once_with(cache_key, "Human Readable Label")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_maybe_label_with_no_label_found(self):
|
||||
"""Test Query.maybe_label method when no label is found"""
|
||||
# Create mock GraphRag with triples client
|
||||
mock_rag = MagicMock()
|
||||
mock_rag.label_cache = {} # Empty cache
|
||||
# Create mock LRUCacheWithTTL that returns None (cache miss)
|
||||
mock_cache = MagicMock()
|
||||
mock_cache.get.return_value = None
|
||||
mock_rag.label_cache = mock_cache
|
||||
mock_triples_client = AsyncMock()
|
||||
mock_rag.triples_client = mock_triples_client
|
||||
|
||||
|
|
@ -318,7 +334,8 @@ class TestQuery:
|
|||
|
||||
# Verify result is entity itself and cache is updated
|
||||
assert result == "unlabeled_entity"
|
||||
assert mock_rag.label_cache["unlabeled_entity"] == "unlabeled_entity"
|
||||
cache_key = "test_user:test_collection:unlabeled_entity"
|
||||
mock_cache.put.assert_called_once_with(cache_key, "unlabeled_entity")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_follow_edges_basic_functionality(self):
|
||||
|
|
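The maybe_label tests above assume the label cache is keyed as user:collection:entity and exposes get/put with LRU eviction plus a TTL. A rough, hypothetical stand-in for such a cache (not the library's actual LRUCacheWithTTL implementation) could look like:

import time
from collections import OrderedDict

class SimpleLRUCacheWithTTL:
    # Hypothetical helper: least-recently-used eviction plus per-entry expiry
    def __init__(self, max_size=1000, ttl=300):
        self.max_size = max_size
        self.ttl = ttl
        self.data = OrderedDict()          # key -> (value, expiry timestamp)

    def get(self, key):
        entry = self.data.get(key)
        if entry is None:
            return None
        value, expires = entry
        if time.monotonic() > expires:
            del self.data[key]             # expired entries behave like misses
            return None
        self.data.move_to_end(key)         # mark as recently used
        return value

    def put(self, key, value):
        self.data[key] = (value, time.monotonic() + self.ttl)
        self.data.move_to_end(key)
        if len(self.data) > self.max_size:
            self.data.popitem(last=False)  # drop the least-recently-used entry

The keys the tests assert on are scoped per user and collection, e.g. cache.put(f"{user}:{collection}:{entity}", label).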
@ -441,40 +458,40 @@ class TestQuery:
|
|||
@pytest.mark.asyncio
|
||||
async def test_get_subgraph_method(self):
|
||||
"""Test Query.get_subgraph method orchestrates entity and edge discovery"""
|
||||
# Create mock Query that patches get_entities and follow_edges
|
||||
# Create mock Query that patches get_entities and follow_edges_batch
|
||||
mock_rag = MagicMock()
|
||||
|
||||
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
user="test_user",
|
||||
user="test_user",
|
||||
collection="test_collection",
|
||||
verbose=False,
|
||||
max_path_length=1
|
||||
)
|
||||
|
||||
|
||||
# Mock get_entities to return test entities
|
||||
query.get_entities = AsyncMock(return_value=["entity1", "entity2"])
|
||||
|
||||
# Mock follow_edges to add triples to subgraph
|
||||
async def mock_follow_edges(ent, subgraph, path_length):
|
||||
subgraph.add((ent, "predicate", "object"))
|
||||
|
||||
query.follow_edges = AsyncMock(side_effect=mock_follow_edges)
|
||||
|
||||
|
||||
# Mock follow_edges_batch to return test triples
|
||||
query.follow_edges_batch = AsyncMock(return_value={
|
||||
("entity1", "predicate1", "object1"),
|
||||
("entity2", "predicate2", "object2")
|
||||
})
|
||||
|
||||
# Call get_subgraph
|
||||
result = await query.get_subgraph("test query")
|
||||
|
||||
|
||||
# Verify get_entities was called
|
||||
query.get_entities.assert_called_once_with("test query")
|
||||
|
||||
# Verify follow_edges was called for each entity
|
||||
assert query.follow_edges.call_count == 2
|
||||
query.follow_edges.assert_any_call("entity1", unittest.mock.ANY, 1)
|
||||
query.follow_edges.assert_any_call("entity2", unittest.mock.ANY, 1)
|
||||
|
||||
# Verify result is list format
|
||||
|
||||
# Verify follow_edges_batch was called with entities and max_path_length
|
||||
query.follow_edges_batch.assert_called_once_with(["entity1", "entity2"], 1)
|
||||
|
||||
# Verify result is list format and contains expected triples
|
||||
assert isinstance(result, list)
|
||||
assert len(result) == 2
|
||||
assert ("entity1", "predicate1", "object1") in result
|
||||
assert ("entity2", "predicate2", "object2") in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_labelgraph_method(self):
|
||||
|
|
|
|||
|
|
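The get_subgraph test above exercises a batched edge-following flow: seed entities are resolved from the query, then follow_edges_batch expands them up to max_path_length and returns a set of triples. A sketch of that orchestration, assuming both helpers exist with the signatures used in the test:

    async def get_subgraph(self, query):
        # Resolve seed entities for the query (vector search in the real service)
        entities = await self.get_entities(query)

        # Batch expansion returns a set of (s, p, o) triples, so duplicates
        # collapse automatically across entities and path steps
        triples = await self.follow_edges_batch(entities, self.max_path_length)

        # Callers (and the test) expect a list back
        return list(triples)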
@ -178,37 +178,24 @@ class TestPineconeDocEmbeddingsStorageProcessor:
|
|||
assert calls[2][1]['vectors'][0]['metadata']['doc'] == "This is the second document chunk"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_document_embeddings_index_creation(self, processor):
|
||||
"""Test automatic index creation when index doesn't exist"""
|
||||
async def test_store_document_embeddings_index_validation(self, processor):
|
||||
"""Test that writing to non-existent index raises ValueError"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
|
||||
|
||||
chunk = ChunkEmbeddings(
|
||||
chunk=b"Test document content",
|
||||
vectors=[[0.1, 0.2, 0.3]]
|
||||
)
|
||||
message.chunks = [chunk]
|
||||
|
||||
# Mock index doesn't exist initially
|
||||
|
||||
# Mock index doesn't exist
|
||||
processor.pinecone.has_index.return_value = False
|
||||
mock_index = MagicMock()
|
||||
processor.pinecone.Index.return_value = mock_index
|
||||
|
||||
# Mock index creation
|
||||
processor.pinecone.describe_index.return_value.status = {"ready": True}
|
||||
|
||||
with patch('uuid.uuid4', return_value='test-id'):
|
||||
|
||||
with pytest.raises(ValueError, match="Collection .* does not exist"):
|
||||
await processor.store_document_embeddings(message)
|
||||
|
||||
# Verify index creation was called
|
||||
expected_index_name = "d-test_user-test_collection"
|
||||
processor.pinecone.create_index.assert_called_once()
|
||||
create_call = processor.pinecone.create_index.call_args
|
||||
assert create_call[1]['name'] == expected_index_name
|
||||
assert create_call[1]['dimension'] == 3
|
||||
assert create_call[1]['metric'] == "cosine"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_document_embeddings_empty_chunk(self, processor):
|
||||
|
|
@ -357,47 +344,44 @@ class TestPineconeDocEmbeddingsStorageProcessor:
|
|||
mock_index.upsert.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_document_embeddings_index_creation_failure(self, processor):
|
||||
"""Test handling of index creation failure"""
|
||||
async def test_store_document_embeddings_validation_before_creation(self, processor):
|
||||
"""Test that validation error occurs before creation attempts"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
|
||||
|
||||
chunk = ChunkEmbeddings(
|
||||
chunk=b"Test document content",
|
||||
vectors=[[0.1, 0.2, 0.3]]
|
||||
)
|
||||
message.chunks = [chunk]
|
||||
|
||||
# Mock index doesn't exist and creation fails
|
||||
|
||||
# Mock index doesn't exist
|
||||
processor.pinecone.has_index.return_value = False
|
||||
processor.pinecone.create_index.side_effect = Exception("Index creation failed")
|
||||
|
||||
with pytest.raises(Exception, match="Index creation failed"):
|
||||
|
||||
with pytest.raises(ValueError, match="Collection .* does not exist"):
|
||||
await processor.store_document_embeddings(message)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_document_embeddings_index_creation_timeout(self, processor):
|
||||
"""Test handling of index creation timeout"""
|
||||
async def test_store_document_embeddings_validates_before_timeout(self, processor):
|
||||
"""Test that validation error occurs before timeout checks"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
|
||||
|
||||
chunk = ChunkEmbeddings(
|
||||
chunk=b"Test document content",
|
||||
vectors=[[0.1, 0.2, 0.3]]
|
||||
)
|
||||
message.chunks = [chunk]
|
||||
|
||||
# Mock index doesn't exist and never becomes ready
|
||||
|
||||
# Mock index doesn't exist
|
||||
processor.pinecone.has_index.return_value = False
|
||||
processor.pinecone.describe_index.return_value.status = {"ready": False}
|
||||
|
||||
with patch('time.sleep'): # Speed up the test
|
||||
with pytest.raises(RuntimeError, match="Gave up waiting for index creation"):
|
||||
await processor.store_document_embeddings(message)
|
||||
|
||||
with pytest.raises(ValueError, match="Collection .* does not exist"):
|
||||
await processor.store_document_embeddings(message)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_document_embeddings_unicode_content(self, processor):
|
||||
|
|
|
|||
|
|
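The Pinecone tests above replace automatic index creation with a hard failure: writing to a missing index now raises ValueError rather than creating it. A sketch of the guard the assertions imply (the index naming and error message are inferred from the tests, not quoted from the implementation):

    async def store_document_embeddings(self, message):
        user = message.metadata.user
        collection = message.metadata.collection
        index_name = f"d-{user}-{collection}"   # graph embeddings use a t- prefix

        # Fail fast instead of creating the index on demand
        if not self.pinecone.has_index(index_name):
            raise ValueError(f"Collection {collection} does not exist")

        index = self.pinecone.Index(index_name)
        # ... upsert vectors for each non-empty chunk into `index` ...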
@ -43,8 +43,6 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
|
|||
# Verify processor attributes
|
||||
assert hasattr(processor, 'qdrant')
|
||||
assert processor.qdrant == mock_qdrant_instance
|
||||
assert hasattr(processor, 'last_collection')
|
||||
assert processor.last_collection is None
|
||||
|
||||
@patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient')
|
||||
@patch('trustgraph.base.DocumentEmbeddingsStoreService.__init__')
|
||||
|
|
@ -245,8 +243,9 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
|
|||
# Arrange
|
||||
mock_base_init.return_value = None
|
||||
mock_qdrant_instance = MagicMock()
|
||||
mock_qdrant_instance.collection_exists.return_value = True # Collection exists
|
||||
mock_qdrant_client.return_value = mock_qdrant_instance
|
||||
|
||||
|
||||
config = {
|
||||
'store_uri': 'http://localhost:6333',
|
||||
'api_key': 'test-api-key',
|
||||
|
|
@ -255,36 +254,37 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
|
|||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
|
||||
# Create mock message with empty chunk
|
||||
mock_message = MagicMock()
|
||||
mock_message.metadata.user = 'empty_user'
|
||||
mock_message.metadata.collection = 'empty_collection'
|
||||
|
||||
|
||||
mock_chunk_empty = MagicMock()
|
||||
mock_chunk_empty.chunk.decode.return_value = "" # Empty string
|
||||
mock_chunk_empty.vectors = [[0.1, 0.2]]
|
||||
|
||||
|
||||
mock_message.chunks = [mock_chunk_empty]
|
||||
|
||||
|
||||
# Act
|
||||
await processor.store_document_embeddings(mock_message)
|
||||
|
||||
# Assert
|
||||
# Should not call upsert for empty chunks
|
||||
mock_qdrant_instance.upsert.assert_not_called()
|
||||
mock_qdrant_instance.collection_exists.assert_not_called()
|
||||
# But collection_exists should be called for validation
|
||||
mock_qdrant_instance.collection_exists.assert_called_once()
|
||||
|
||||
@patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient')
|
||||
@patch('trustgraph.base.DocumentEmbeddingsStoreService.__init__')
|
||||
async def test_collection_creation_when_not_exists(self, mock_base_init, mock_qdrant_client):
|
||||
"""Test collection creation when it doesn't exist"""
|
||||
"""Test that writing to non-existent collection raises ValueError"""
|
||||
# Arrange
|
||||
mock_base_init.return_value = None
|
||||
mock_qdrant_instance = MagicMock()
|
||||
mock_qdrant_instance.collection_exists.return_value = False # Collection doesn't exist
|
||||
mock_qdrant_client.return_value = mock_qdrant_instance
|
||||
|
||||
|
||||
config = {
|
||||
'store_uri': 'http://localhost:6333',
|
||||
'api_key': 'test-api-key',
|
||||
|
|
@ -293,46 +293,32 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
|
|||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
|
||||
# Create mock message
|
||||
mock_message = MagicMock()
|
||||
mock_message.metadata.user = 'new_user'
|
||||
mock_message.metadata.collection = 'new_collection'
|
||||
|
||||
|
||||
mock_chunk = MagicMock()
|
||||
mock_chunk.chunk.decode.return_value = 'test chunk'
|
||||
mock_chunk.vectors = [[0.1, 0.2, 0.3, 0.4, 0.5]] # 5 dimensions
|
||||
|
||||
mock_message.chunks = [mock_chunk]
|
||||
|
||||
# Act
|
||||
await processor.store_document_embeddings(mock_message)
|
||||
|
||||
# Assert
|
||||
expected_collection = 'd_new_user_new_collection'
|
||||
|
||||
# Verify collection existence check and creation
|
||||
mock_qdrant_instance.collection_exists.assert_called_once_with(expected_collection)
|
||||
mock_qdrant_instance.create_collection.assert_called_once()
|
||||
|
||||
# Verify create_collection was called with correct parameters
|
||||
create_call_args = mock_qdrant_instance.create_collection.call_args
|
||||
assert create_call_args[1]['collection_name'] == expected_collection
|
||||
|
||||
# Verify upsert was still called after collection creation
|
||||
mock_qdrant_instance.upsert.assert_called_once()
|
||||
mock_message.chunks = [mock_chunk]
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError, match="Collection .* does not exist"):
|
||||
await processor.store_document_embeddings(mock_message)
|
||||
|
||||
@patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient')
|
||||
@patch('trustgraph.base.DocumentEmbeddingsStoreService.__init__')
|
||||
async def test_collection_creation_exception(self, mock_base_init, mock_qdrant_client):
|
||||
"""Test collection creation handles exceptions"""
|
||||
"""Test that validation error occurs before connection errors"""
|
||||
# Arrange
|
||||
mock_base_init.return_value = None
|
||||
mock_qdrant_instance = MagicMock()
|
||||
mock_qdrant_instance.collection_exists.return_value = False
|
||||
mock_qdrant_instance.create_collection.side_effect = Exception("Qdrant connection failed")
|
||||
mock_qdrant_instance.collection_exists.return_value = False # Collection doesn't exist
|
||||
mock_qdrant_client.return_value = mock_qdrant_instance
|
||||
|
||||
|
||||
config = {
|
||||
'store_uri': 'http://localhost:6333',
|
||||
'api_key': 'test-api-key',
|
||||
|
|
@ -341,32 +327,35 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
|
|||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
|
||||
# Create mock message
|
||||
mock_message = MagicMock()
|
||||
mock_message.metadata.user = 'error_user'
|
||||
mock_message.metadata.collection = 'error_collection'
|
||||
|
||||
|
||||
mock_chunk = MagicMock()
|
||||
mock_chunk.chunk.decode.return_value = 'test chunk'
|
||||
mock_chunk.vectors = [[0.1, 0.2]]
|
||||
|
||||
|
||||
mock_message.chunks = [mock_chunk]
|
||||
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(Exception, match="Qdrant connection failed"):
|
||||
with pytest.raises(ValueError, match="Collection .* does not exist"):
|
||||
await processor.store_document_embeddings(mock_message)
|
||||
|
||||
@patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient')
|
||||
@patch('trustgraph.base.DocumentEmbeddingsStoreService.__init__')
|
||||
async def test_collection_caching_behavior(self, mock_base_init, mock_qdrant_client):
|
||||
"""Test collection caching with last_collection"""
|
||||
@patch('trustgraph.storage.doc_embeddings.qdrant.write.uuid')
|
||||
async def test_collection_validation_on_write(self, mock_uuid, mock_base_init, mock_qdrant_client):
|
||||
"""Test collection validation checks collection exists before writing"""
|
||||
# Arrange
|
||||
mock_base_init.return_value = None
|
||||
mock_qdrant_instance = MagicMock()
|
||||
mock_qdrant_instance.collection_exists.return_value = True
|
||||
mock_qdrant_client.return_value = mock_qdrant_instance
|
||||
|
||||
mock_uuid.uuid4.return_value = MagicMock()
|
||||
mock_uuid.uuid4.return_value.__str__ = MagicMock(return_value='test-uuid')
|
||||
|
||||
config = {
|
||||
'store_uri': 'http://localhost:6333',
|
||||
'api_key': 'test-api-key',
|
||||
|
|
@ -375,46 +364,45 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
|
|||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
|
||||
# Create first mock message
|
||||
mock_message1 = MagicMock()
|
||||
mock_message1.metadata.user = 'cache_user'
|
||||
mock_message1.metadata.collection = 'cache_collection'
|
||||
|
||||
|
||||
mock_chunk1 = MagicMock()
|
||||
mock_chunk1.chunk.decode.return_value = 'first chunk'
|
||||
mock_chunk1.vectors = [[0.1, 0.2, 0.3]]
|
||||
|
||||
|
||||
mock_message1.chunks = [mock_chunk1]
|
||||
|
||||
|
||||
# First call
|
||||
await processor.store_document_embeddings(mock_message1)
|
||||
|
||||
|
||||
# Reset mock to track second call
|
||||
mock_qdrant_instance.reset_mock()
|
||||
|
||||
mock_qdrant_instance.collection_exists.return_value = True
|
||||
|
||||
# Create second mock message with same dimensions
|
||||
mock_message2 = MagicMock()
|
||||
mock_message2.metadata.user = 'cache_user'
|
||||
mock_message2.metadata.collection = 'cache_collection'
|
||||
|
||||
|
||||
mock_chunk2 = MagicMock()
|
||||
mock_chunk2.chunk.decode.return_value = 'second chunk'
|
||||
mock_chunk2.vectors = [[0.4, 0.5, 0.6]] # Same dimension (3)
|
||||
|
||||
|
||||
mock_message2.chunks = [mock_chunk2]
|
||||
|
||||
|
||||
# Act - Second call with same collection
|
||||
await processor.store_document_embeddings(mock_message2)
|
||||
|
||||
# Assert
|
||||
expected_collection = 'd_cache_user_cache_collection'
|
||||
assert processor.last_collection == expected_collection
|
||||
|
||||
# Verify second call skipped existence check (cached)
|
||||
mock_qdrant_instance.collection_exists.assert_not_called()
|
||||
mock_qdrant_instance.create_collection.assert_not_called()
|
||||
|
||||
|
||||
# Verify collection existence is checked on each write
|
||||
mock_qdrant_instance.collection_exists.assert_called_once_with(expected_collection)
|
||||
|
||||
# But upsert should still be called
|
||||
mock_qdrant_instance.upsert.assert_called_once()
|
||||
|
||||
|
|
|
|||
|
|
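The Qdrant document-embedding tests make the same shift: no collection is created on write, existence appears to be checked on every call rather than cached in last_collection, and empty chunks are still skipped after validation. A sketch consistent with those assertions:

    async def store_document_embeddings(self, message):
        user = message.metadata.user
        collection = message.metadata.collection
        name = f"d_{user}_{collection}"

        # Validation happens on every write; missing collections are an error
        if not self.qdrant.collection_exists(name):
            raise ValueError(f"Collection {collection} does not exist")

        for chunk in message.chunks:
            text = chunk.chunk.decode("utf-8")
            if not text:
                continue                       # empty chunks are never upserted
            # ... upsert the chunk's vectors into `name` ...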
@ -178,37 +178,24 @@ class TestPineconeGraphEmbeddingsStorageProcessor:
|
|||
assert calls[2][1]['vectors'][0]['metadata']['entity'] == "entity2"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_graph_embeddings_index_creation(self, processor):
|
||||
"""Test automatic index creation when index doesn't exist"""
|
||||
async def test_store_graph_embeddings_index_validation(self, processor):
|
||||
"""Test that writing to non-existent index raises ValueError"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
|
||||
|
||||
entity = EntityEmbeddings(
|
||||
entity=Value(value="test_entity", is_uri=False),
|
||||
vectors=[[0.1, 0.2, 0.3]]
|
||||
)
|
||||
message.entities = [entity]
|
||||
|
||||
# Mock index doesn't exist initially
|
||||
|
||||
# Mock index doesn't exist
|
||||
processor.pinecone.has_index.return_value = False
|
||||
mock_index = MagicMock()
|
||||
processor.pinecone.Index.return_value = mock_index
|
||||
|
||||
# Mock index creation
|
||||
processor.pinecone.describe_index.return_value.status = {"ready": True}
|
||||
|
||||
with patch('uuid.uuid4', return_value='test-id'):
|
||||
|
||||
with pytest.raises(ValueError, match="Collection .* does not exist"):
|
||||
await processor.store_graph_embeddings(message)
|
||||
|
||||
# Verify index creation was called
|
||||
expected_index_name = "t-test_user-test_collection"
|
||||
processor.pinecone.create_index.assert_called_once()
|
||||
create_call = processor.pinecone.create_index.call_args
|
||||
assert create_call[1]['name'] == expected_index_name
|
||||
assert create_call[1]['dimension'] == 3
|
||||
assert create_call[1]['metric'] == "cosine"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_graph_embeddings_empty_entity_value(self, processor):
|
||||
|
|
@ -328,47 +315,44 @@ class TestPineconeGraphEmbeddingsStorageProcessor:
|
|||
mock_index.upsert.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_graph_embeddings_index_creation_failure(self, processor):
|
||||
"""Test handling of index creation failure"""
|
||||
async def test_store_graph_embeddings_validation_before_creation(self, processor):
|
||||
"""Test that validation error occurs before any creation attempts"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
|
||||
|
||||
entity = EntityEmbeddings(
|
||||
entity=Value(value="test_entity", is_uri=False),
|
||||
vectors=[[0.1, 0.2, 0.3]]
|
||||
)
|
||||
message.entities = [entity]
|
||||
|
||||
# Mock index doesn't exist and creation fails
|
||||
|
||||
# Mock index doesn't exist
|
||||
processor.pinecone.has_index.return_value = False
|
||||
processor.pinecone.create_index.side_effect = Exception("Index creation failed")
|
||||
|
||||
with pytest.raises(Exception, match="Index creation failed"):
|
||||
|
||||
with pytest.raises(ValueError, match="Collection .* does not exist"):
|
||||
await processor.store_graph_embeddings(message)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_graph_embeddings_index_creation_timeout(self, processor):
|
||||
"""Test handling of index creation timeout"""
|
||||
async def test_store_graph_embeddings_validates_before_timeout(self, processor):
|
||||
"""Test that validation error occurs before timeout checks"""
|
||||
message = MagicMock()
|
||||
message.metadata = MagicMock()
|
||||
message.metadata.user = 'test_user'
|
||||
message.metadata.collection = 'test_collection'
|
||||
|
||||
|
||||
entity = EntityEmbeddings(
|
||||
entity=Value(value="test_entity", is_uri=False),
|
||||
vectors=[[0.1, 0.2, 0.3]]
|
||||
)
|
||||
message.entities = [entity]
|
||||
|
||||
# Mock index doesn't exist and never becomes ready
|
||||
|
||||
# Mock index doesn't exist
|
||||
processor.pinecone.has_index.return_value = False
|
||||
processor.pinecone.describe_index.return_value.status = {"ready": False}
|
||||
|
||||
with patch('time.sleep'): # Speed up the test
|
||||
with pytest.raises(RuntimeError, match="Gave up waiting for index creation"):
|
||||
await processor.store_graph_embeddings(message)
|
||||
|
||||
with pytest.raises(ValueError, match="Collection .* does not exist"):
|
||||
await processor.store_graph_embeddings(message)
|
||||
|
||||
def test_add_args_method(self):
|
||||
"""Test that add_args properly configures argument parser"""
|
||||
|
|
|
|||
|
|
@ -43,19 +43,17 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase):
|
|||
# Verify processor attributes
|
||||
assert hasattr(processor, 'qdrant')
|
||||
assert processor.qdrant == mock_qdrant_instance
|
||||
assert hasattr(processor, 'last_collection')
|
||||
assert processor.last_collection is None
|
||||
|
||||
@patch('trustgraph.storage.graph_embeddings.qdrant.write.QdrantClient')
|
||||
@patch('trustgraph.base.GraphEmbeddingsStoreService.__init__')
|
||||
async def test_get_collection_creates_new_collection(self, mock_base_init, mock_qdrant_client):
|
||||
"""Test get_collection creates a new collection when it doesn't exist"""
|
||||
async def test_get_collection_validates_existence(self, mock_base_init, mock_qdrant_client):
|
||||
"""Test get_collection validates that collection exists"""
|
||||
# Arrange
|
||||
mock_base_init.return_value = None
|
||||
mock_qdrant_instance = MagicMock()
|
||||
mock_qdrant_instance.collection_exists.return_value = False
|
||||
mock_qdrant_client.return_value = mock_qdrant_instance
|
||||
|
||||
|
||||
config = {
|
||||
'store_uri': 'http://localhost:6333',
|
||||
'api_key': 'test-api-key',
|
||||
|
|
@ -64,22 +62,10 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase):
|
|||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act
|
||||
collection_name = processor.get_collection(dim=512, user='test_user', collection='test_collection')
|
||||
|
||||
# Assert
|
||||
expected_name = 't_test_user_test_collection'
|
||||
assert collection_name == expected_name
|
||||
assert processor.last_collection == expected_name
|
||||
|
||||
# Verify collection existence check and creation
|
||||
mock_qdrant_instance.collection_exists.assert_called_once_with(expected_name)
|
||||
mock_qdrant_instance.create_collection.assert_called_once()
|
||||
|
||||
# Verify create_collection was called with correct parameters
|
||||
create_call_args = mock_qdrant_instance.create_collection.call_args
|
||||
assert create_call_args[1]['collection_name'] == expected_name
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError, match="Collection .* does not exist"):
|
||||
processor.get_collection(user='test_user', collection='test_collection')
|
||||
|
||||
@patch('trustgraph.storage.graph_embeddings.qdrant.write.QdrantClient')
|
||||
@patch('trustgraph.storage.graph_embeddings.qdrant.write.uuid')
|
||||
|
|
@ -142,7 +128,7 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase):
|
|||
mock_qdrant_instance = MagicMock()
|
||||
mock_qdrant_instance.collection_exists.return_value = True # Collection exists
|
||||
mock_qdrant_client.return_value = mock_qdrant_instance
|
||||
|
||||
|
||||
config = {
|
||||
'store_uri': 'http://localhost:6333',
|
||||
'api_key': 'test-api-key',
|
||||
|
|
@ -151,15 +137,14 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase):
|
|||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
|
||||
# Act
|
||||
collection_name = processor.get_collection(dim=256, user='existing_user', collection='existing_collection')
|
||||
collection_name = processor.get_collection(user='existing_user', collection='existing_collection')
|
||||
|
||||
# Assert
|
||||
expected_name = 't_existing_user_existing_collection'
|
||||
assert collection_name == expected_name
|
||||
assert processor.last_collection == expected_name
|
||||
|
||||
|
||||
# Verify collection existence check was performed
|
||||
mock_qdrant_instance.collection_exists.assert_called_once_with(expected_name)
|
||||
# Verify create_collection was NOT called
|
||||
|
|
@ -167,14 +152,14 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase):
|
|||
|
||||
@patch('trustgraph.storage.graph_embeddings.qdrant.write.QdrantClient')
|
||||
@patch('trustgraph.base.GraphEmbeddingsStoreService.__init__')
|
||||
async def test_get_collection_caches_last_collection(self, mock_base_init, mock_qdrant_client):
|
||||
"""Test get_collection skips checks when using same collection"""
|
||||
async def test_get_collection_validates_on_each_call(self, mock_base_init, mock_qdrant_client):
|
||||
"""Test get_collection validates collection existence on each call"""
|
||||
# Arrange
|
||||
mock_base_init.return_value = None
|
||||
mock_qdrant_instance = MagicMock()
|
||||
mock_qdrant_instance.collection_exists.return_value = True
|
||||
mock_qdrant_client.return_value = mock_qdrant_instance
|
||||
|
||||
|
||||
config = {
|
||||
'store_uri': 'http://localhost:6333',
|
||||
'api_key': 'test-api-key',
|
||||
|
|
@ -183,36 +168,36 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase):
|
|||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
|
||||
# First call
|
||||
collection_name1 = processor.get_collection(dim=128, user='cache_user', collection='cache_collection')
|
||||
|
||||
collection_name1 = processor.get_collection(user='cache_user', collection='cache_collection')
|
||||
|
||||
# Reset mock to track second call
|
||||
mock_qdrant_instance.reset_mock()
|
||||
|
||||
mock_qdrant_instance.collection_exists.return_value = True
|
||||
|
||||
# Act - Second call with same parameters
|
||||
collection_name2 = processor.get_collection(dim=128, user='cache_user', collection='cache_collection')
|
||||
collection_name2 = processor.get_collection(user='cache_user', collection='cache_collection')
|
||||
|
||||
# Assert
|
||||
expected_name = 't_cache_user_cache_collection'
|
||||
assert collection_name1 == expected_name
|
||||
assert collection_name2 == expected_name
|
||||
|
||||
# Verify second call skipped existence check (cached)
|
||||
mock_qdrant_instance.collection_exists.assert_not_called()
|
||||
|
||||
# Verify collection existence check happens on each call
|
||||
mock_qdrant_instance.collection_exists.assert_called_once_with(expected_name)
|
||||
mock_qdrant_instance.create_collection.assert_not_called()
|
||||
|
||||
@patch('trustgraph.storage.graph_embeddings.qdrant.write.QdrantClient')
|
||||
@patch('trustgraph.base.GraphEmbeddingsStoreService.__init__')
|
||||
async def test_get_collection_creation_exception(self, mock_base_init, mock_qdrant_client):
|
||||
"""Test get_collection handles collection creation exceptions"""
|
||||
"""Test get_collection raises ValueError when collection doesn't exist"""
|
||||
# Arrange
|
||||
mock_base_init.return_value = None
|
||||
mock_qdrant_instance = MagicMock()
|
||||
mock_qdrant_instance.collection_exists.return_value = False
|
||||
mock_qdrant_instance.create_collection.side_effect = Exception("Qdrant connection failed")
|
||||
mock_qdrant_client.return_value = mock_qdrant_instance
|
||||
|
||||
|
||||
config = {
|
||||
'store_uri': 'http://localhost:6333',
|
||||
'api_key': 'test-api-key',
|
||||
|
|
@ -221,10 +206,10 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase):
|
|||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(Exception, match="Qdrant connection failed"):
|
||||
processor.get_collection(dim=512, user='error_user', collection='error_collection')
|
||||
with pytest.raises(ValueError, match="Collection .* does not exist"):
|
||||
processor.get_collection(user='error_user', collection='error_collection')
|
||||
|
||||
@patch('trustgraph.storage.graph_embeddings.qdrant.write.QdrantClient')
|
||||
@patch('trustgraph.storage.graph_embeddings.qdrant.write.uuid')
|
||||
|
|
|
|||
|
|
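For graph embeddings the Qdrant collection name uses a t_ prefix and get_collection no longer takes a dimension, since there is nothing left to create. A sketch matching the calls made in the tests above:

    def get_collection(self, user, collection):
        name = f"t_{user}_{collection}"

        # The collection must already exist; validation runs on every call
        if not self.qdrant.collection_exists(name):
            raise ValueError(f"Collection {collection} does not exist")

        return name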
@ -47,7 +47,7 @@ class TestMemgraphUserCollectionIsolation:
|
|||
mock_graph_db.driver.return_value = mock_driver
|
||||
mock_session = MagicMock()
|
||||
mock_driver.session.return_value.__enter__.return_value = mock_session
|
||||
|
||||
|
||||
# Mock execute_query response
|
||||
mock_result = MagicMock()
|
||||
mock_summary = MagicMock()
|
||||
|
|
@ -55,28 +55,30 @@ class TestMemgraphUserCollectionIsolation:
|
|||
mock_summary.result_available_after = 10
|
||||
mock_result.summary = mock_summary
|
||||
mock_driver.execute_query.return_value = mock_result
|
||||
|
||||
|
||||
processor = Processor(taskgroup=MagicMock())
|
||||
|
||||
|
||||
# Create mock triple with URI object
|
||||
triple = MagicMock()
|
||||
triple.s.value = "http://example.com/subject"
|
||||
triple.p.value = "http://example.com/predicate"
|
||||
triple.o.value = "http://example.com/object"
|
||||
triple.o.is_uri = True
|
||||
|
||||
|
||||
# Create mock message with metadata
|
||||
mock_message = MagicMock()
|
||||
mock_message.triples = [triple]
|
||||
mock_message.metadata.user = "test_user"
|
||||
mock_message.metadata.collection = "test_collection"
|
||||
|
||||
await processor.store_triples(mock_message)
|
||||
|
||||
|
||||
# Mock collection_exists to bypass validation in unit tests
|
||||
with patch.object(processor, 'collection_exists', return_value=True):
|
||||
await processor.store_triples(mock_message)
|
||||
|
||||
# Verify user/collection parameters were passed to all operations
|
||||
# Should have: create_node (subject), create_node (object), relate_node = 3 calls
|
||||
assert mock_driver.execute_query.call_count == 3
|
||||
|
||||
|
||||
# Check that user and collection were included in all calls
|
||||
for call in mock_driver.execute_query.call_args_list:
|
||||
call_kwargs = call.kwargs if hasattr(call, 'kwargs') else call[1]
|
||||
|
|
@ -93,7 +95,7 @@ class TestMemgraphUserCollectionIsolation:
|
|||
mock_graph_db.driver.return_value = mock_driver
|
||||
mock_session = MagicMock()
|
||||
mock_driver.session.return_value.__enter__.return_value = mock_session
|
||||
|
||||
|
||||
# Mock execute_query response
|
||||
mock_result = MagicMock()
|
||||
mock_summary = MagicMock()
|
||||
|
|
@ -101,24 +103,26 @@ class TestMemgraphUserCollectionIsolation:
|
|||
mock_summary.result_available_after = 10
|
||||
mock_result.summary = mock_summary
|
||||
mock_driver.execute_query.return_value = mock_result
|
||||
|
||||
|
||||
processor = Processor(taskgroup=MagicMock())
|
||||
|
||||
|
||||
# Create mock triple
|
||||
triple = MagicMock()
|
||||
triple.s.value = "http://example.com/subject"
|
||||
triple.p.value = "http://example.com/predicate"
|
||||
triple.o.value = "literal_value"
|
||||
triple.o.is_uri = False
|
||||
|
||||
|
||||
# Create mock message without user/collection metadata
|
||||
mock_message = MagicMock()
|
||||
mock_message.triples = [triple]
|
||||
mock_message.metadata.user = None
|
||||
mock_message.metadata.collection = None
|
||||
|
||||
await processor.store_triples(mock_message)
|
||||
|
||||
|
||||
# Mock collection_exists to bypass validation in unit tests
|
||||
with patch.object(processor, 'collection_exists', return_value=True):
|
||||
await processor.store_triples(mock_message)
|
||||
|
||||
# Verify defaults were used
|
||||
for call in mock_driver.execute_query.call_args_list:
|
||||
call_kwargs = call.kwargs if hasattr(call, 'kwargs') else call[1]
|
||||
|
|
@ -295,7 +299,7 @@ class TestMemgraphUserCollectionRegression:
|
|||
mock_graph_db.driver.return_value = mock_driver
|
||||
mock_session = MagicMock()
|
||||
mock_driver.session.return_value.__enter__.return_value = mock_session
|
||||
|
||||
|
||||
# Mock execute_query response
|
||||
mock_result = MagicMock()
|
||||
mock_summary = MagicMock()
|
||||
|
|
@ -303,23 +307,25 @@ class TestMemgraphUserCollectionRegression:
|
|||
mock_summary.result_available_after = 10
|
||||
mock_result.summary = mock_summary
|
||||
mock_driver.execute_query.return_value = mock_result
|
||||
|
||||
|
||||
processor = Processor(taskgroup=MagicMock())
|
||||
|
||||
|
||||
# Store data for user1
|
||||
triple = MagicMock()
|
||||
triple.s.value = "http://example.com/subject"
|
||||
triple.p.value = "http://example.com/predicate"
|
||||
triple.o.value = "user1_data"
|
||||
triple.o.is_uri = False
|
||||
|
||||
|
||||
message_user1 = MagicMock()
|
||||
message_user1.triples = [triple]
|
||||
message_user1.metadata.user = "user1"
|
||||
message_user1.metadata.collection = "collection1"
|
||||
|
||||
await processor.store_triples(message_user1)
|
||||
|
||||
|
||||
# Mock collection_exists to bypass validation in unit tests
|
||||
with patch.object(processor, 'collection_exists', return_value=True):
|
||||
await processor.store_triples(message_user1)
|
||||
|
||||
# Verify that all storage operations included user1/collection1 parameters
|
||||
for call in mock_driver.execute_query.call_args_list:
|
||||
call_kwargs = call.kwargs if hasattr(call, 'kwargs') else call[1]
|
||||
|
|
|
|||
|
|
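A recurring idiom in the Memgraph tests above, and in the Neo4j and FalkorDB tests that follow, is bypassing the new collection validation so the unit tests stay focused on query construction. Shown in isolation (collection_exists is assumed to be a method on the processor):

    from unittest.mock import patch

    # Force the validation to pass so store_triples runs against mocked drivers
    with patch.object(processor, 'collection_exists', return_value=True):
        await processor.store_triples(mock_message)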
@ -75,8 +75,10 @@ class TestNeo4jUserCollectionIsolation:
|
|||
mock_summary.counters.nodes_created = 1
|
||||
mock_summary.result_available_after = 10
|
||||
mock_driver.execute_query.return_value.summary = mock_summary
|
||||
|
||||
await processor.store_triples(message)
|
||||
|
||||
# Mock collection_exists to bypass validation in unit tests
|
||||
with patch.object(processor, 'collection_exists', return_value=True):
|
||||
await processor.store_triples(message)
|
||||
|
||||
# Verify nodes and relationships were created with user/collection properties
|
||||
expected_calls = [
|
||||
|
|
@ -141,8 +143,10 @@ class TestNeo4jUserCollectionIsolation:
|
|||
mock_summary.counters.nodes_created = 1
|
||||
mock_summary.result_available_after = 10
|
||||
mock_driver.execute_query.return_value.summary = mock_summary
|
||||
|
||||
await processor.store_triples(message)
|
||||
|
||||
# Mock collection_exists to bypass validation in unit tests
|
||||
with patch.object(processor, 'collection_exists', return_value=True):
|
||||
await processor.store_triples(message)
|
||||
|
||||
# Verify defaults were used
|
||||
mock_driver.execute_query.assert_any_call(
|
||||
|
|
@ -273,10 +277,12 @@ class TestNeo4jUserCollectionIsolation:
|
|||
mock_summary.counters.nodes_created = 1
|
||||
mock_summary.result_available_after = 10
|
||||
mock_driver.execute_query.return_value.summary = mock_summary
|
||||
|
||||
# Store data for both users
|
||||
await processor.store_triples(message_user1)
|
||||
await processor.store_triples(message_user2)
|
||||
|
||||
# Mock collection_exists to bypass validation in unit tests
|
||||
with patch.object(processor, 'collection_exists', return_value=True):
|
||||
# Store data for both users
|
||||
await processor.store_triples(message_user1)
|
||||
await processor.store_triples(message_user2)
|
||||
|
||||
# Verify user1 data was stored with user1/coll1
|
||||
mock_driver.execute_query.assert_any_call(
|
||||
|
|
@ -446,9 +452,11 @@ class TestNeo4jUserCollectionRegression:
|
|||
mock_summary.counters.nodes_created = 1
|
||||
mock_summary.result_available_after = 10
|
||||
mock_driver.execute_query.return_value.summary = mock_summary
|
||||
|
||||
await processor.store_triples(message_user1)
|
||||
await processor.store_triples(message_user2)
|
||||
|
||||
# Mock collection_exists to bypass validation in unit tests
|
||||
with patch.object(processor, 'collection_exists', return_value=True):
|
||||
await processor.store_triples(message_user1)
|
||||
await processor.store_triples(message_user2)
|
||||
|
||||
# Verify two separate nodes were created with same URI but different user/collection
|
||||
user1_node_call = call(
|
||||
|
|
|
|||
|
|
@ -251,6 +251,8 @@ class TestObjectsCassandraStorageLogic:
|
|||
processor.convert_value = Processor.convert_value.__get__(processor, Processor)
|
||||
processor.session = MagicMock()
|
||||
processor.on_object = Processor.on_object.__get__(processor, Processor)
|
||||
processor.known_keyspaces = {"test_user"} # Pre-populate to skip validation query
|
||||
processor.known_tables = {"test_user": set()} # Pre-populate
|
||||
|
||||
# Create test object
|
||||
test_obj = ExtractedObject(
|
||||
|
|
@ -291,18 +293,19 @@ class TestObjectsCassandraStorageLogic:
|
|||
"""Test that secondary indexes are created for indexed fields"""
|
||||
processor = MagicMock()
|
||||
processor.schemas = {}
|
||||
processor.known_keyspaces = set()
|
||||
processor.known_tables = {}
|
||||
processor.known_keyspaces = {"test_user"} # Pre-populate to skip validation query
|
||||
processor.known_tables = {"test_user": set()} # Pre-populate
|
||||
processor.session = MagicMock()
|
||||
processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
|
||||
processor.sanitize_table = Processor.sanitize_table.__get__(processor, Processor)
|
||||
processor.get_cassandra_type = Processor.get_cassandra_type.__get__(processor, Processor)
|
||||
def mock_ensure_keyspace(keyspace):
|
||||
processor.known_keyspaces.add(keyspace)
|
||||
processor.known_tables[keyspace] = set()
|
||||
if keyspace not in processor.known_tables:
|
||||
processor.known_tables[keyspace] = set()
|
||||
processor.ensure_keyspace = mock_ensure_keyspace
|
||||
processor.ensure_table = Processor.ensure_table.__get__(processor, Processor)
|
||||
|
||||
|
||||
# Create schema with indexed field
|
||||
schema = RowSchema(
|
||||
name="products",
|
||||
|
|
@ -313,10 +316,10 @@ class TestObjectsCassandraStorageLogic:
|
|||
Field(name="price", type="float", size=8, indexed=True)
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
# Call ensure_table
|
||||
processor.ensure_table("test_user", "products", schema)
|
||||
|
||||
|
||||
# Should have 3 calls: create table + 2 indexes
|
||||
assert processor.session.execute.call_count == 3
|
||||
|
||||
|
|
@ -346,9 +349,10 @@ class TestObjectsCassandraStorageBatchLogic:
|
|||
]
|
||||
)
|
||||
}
|
||||
processor.known_keyspaces = {"test_user"} # Pre-populate to skip validation query
|
||||
processor.ensure_table = MagicMock()
|
||||
processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
|
||||
processor.sanitize_table = Processor.sanitize_table.__get__(processor, Processor)
|
||||
processor.sanitize_table = Processor.sanitize_table.__get__(processor, Processor)
|
||||
processor.convert_value = Processor.convert_value.__get__(processor, Processor)
|
||||
processor.session = MagicMock()
|
||||
processor.on_object = Processor.on_object.__get__(processor, Processor)
|
||||
|
|
@ -415,6 +419,8 @@ class TestObjectsCassandraStorageBatchLogic:
|
|||
processor.convert_value = Processor.convert_value.__get__(processor, Processor)
|
||||
processor.session = MagicMock()
|
||||
processor.on_object = Processor.on_object.__get__(processor, Processor)
|
||||
processor.known_keyspaces = {"test_user"} # Pre-populate to skip validation query
|
||||
processor.known_tables = {"test_user": set()} # Pre-populate
|
||||
|
||||
# Create empty batch object
|
||||
empty_batch_obj = ExtractedObject(
|
||||
|
|
@ -461,6 +467,8 @@ class TestObjectsCassandraStorageBatchLogic:
|
|||
processor.convert_value = Processor.convert_value.__get__(processor, Processor)
|
||||
processor.session = MagicMock()
|
||||
processor.on_object = Processor.on_object.__get__(processor, Processor)
|
||||
processor.known_keyspaces = {"test_user"} # Pre-populate to skip validation query
|
||||
processor.known_tables = {"test_user": set()} # Pre-populate
|
||||
|
||||
# Create single-item batch object (backward compatibility case)
|
||||
single_batch_obj = ExtractedObject(
|
||||
|
|
|
|||
|
|
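The Cassandra object-storage tests above pre-populate known_keyspaces and known_tables so that on_object never issues a validation query. A sketch of the caching behaviour those fixtures assume (the method body is illustrative, not copied from the processor):

    def ensure_keyspace(self, keyspace):
        # Skip the round trip to Cassandra once a keyspace has been seen
        if keyspace in self.known_keyspaces:
            self.known_tables.setdefault(keyspace, set())
            return
        # ... otherwise validate/create the keyspace via self.session.execute(...) ...
        self.known_keyspaces.add(keyspace)
        self.known_tables.setdefault(keyspace, set())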
@ -194,7 +194,13 @@ class TestFalkorDBStorageProcessor:
|
|||
mock_result.run_time_ms = 10
|
||||
processor.io.query.return_value = mock_result
|
||||
|
||||
await processor.store_triples(message)
|
||||
# Mock collection_exists to bypass validation in unit tests
|
||||
|
||||
|
||||
with patch.object(processor, 'collection_exists', return_value=True):
|
||||
|
||||
|
||||
await processor.store_triples(message)
|
||||
|
||||
# Verify queries were called in the correct order
|
||||
expected_calls = [
|
||||
|
|
@ -225,7 +231,13 @@ class TestFalkorDBStorageProcessor:
|
|||
mock_result.run_time_ms = 10
|
||||
processor.io.query.return_value = mock_result
|
||||
|
||||
await processor.store_triples(mock_message)
|
||||
# Mock collection_exists to bypass validation in unit tests
|
||||
|
||||
|
||||
with patch.object(processor, 'collection_exists', return_value=True):
|
||||
|
||||
|
||||
await processor.store_triples(mock_message)
|
||||
|
||||
# Verify queries were called in the correct order
|
||||
expected_calls = [
|
||||
|
|
@ -273,7 +285,13 @@ class TestFalkorDBStorageProcessor:
|
|||
mock_result.run_time_ms = 10
|
||||
processor.io.query.return_value = mock_result
|
||||
|
||||
await processor.store_triples(message)
|
||||
# Mock collection_exists to bypass validation in unit tests
|
||||
|
||||
|
||||
with patch.object(processor, 'collection_exists', return_value=True):
|
||||
|
||||
|
||||
await processor.store_triples(message)
|
||||
|
||||
# Verify total number of queries (3 per triple)
|
||||
assert processor.io.query.call_count == 6
|
||||
|
|
@ -299,7 +317,13 @@ class TestFalkorDBStorageProcessor:
|
|||
message.metadata.collection = 'test_collection'
|
||||
message.triples = []
|
||||
|
||||
await processor.store_triples(message)
|
||||
# Mock collection_exists to bypass validation in unit tests
|
||||
|
||||
|
||||
with patch.object(processor, 'collection_exists', return_value=True):
|
||||
|
||||
|
||||
await processor.store_triples(message)
|
||||
|
||||
# Verify no queries were made
|
||||
processor.io.query.assert_not_called()
|
||||
|
|
@ -329,7 +353,13 @@ class TestFalkorDBStorageProcessor:
|
|||
mock_result.run_time_ms = 10
|
||||
processor.io.query.return_value = mock_result
|
||||
|
||||
await processor.store_triples(message)
|
||||
# Mock collection_exists to bypass validation in unit tests
|
||||
|
||||
|
||||
with patch.object(processor, 'collection_exists', return_value=True):
|
||||
|
||||
|
||||
await processor.store_triples(message)
|
||||
|
||||
# Verify total number of queries (3 per triple)
|
||||
assert processor.io.query.call_count == 6
|
||||
|
|
|
|||
|
|
@ -308,7 +308,13 @@ class TestMemgraphStorageProcessor:
|
|||
# Reset the mock to clear initialization calls
|
||||
processor.io.execute_query.reset_mock()
|
||||
|
||||
await processor.store_triples(mock_message)
|
||||
# Mock collection_exists to bypass validation in unit tests
|
||||
|
||||
|
||||
with patch.object(processor, 'collection_exists', return_value=True):
|
||||
|
||||
|
||||
await processor.store_triples(mock_message)
|
||||
|
||||
# Verify execute_query was called for create_node, create_literal, and relate_literal
|
||||
# (since mock_message has a literal object)
|
||||
|
|
@ -352,7 +358,13 @@ class TestMemgraphStorageProcessor:
|
|||
)
|
||||
message.triples = [triple1, triple2]
|
||||
|
||||
await processor.store_triples(message)
|
||||
# Mock collection_exists to bypass validation in unit tests
|
||||
|
||||
|
||||
with patch.object(processor, 'collection_exists', return_value=True):
|
||||
|
||||
|
||||
await processor.store_triples(message)
|
||||
|
||||
# Verify execute_query was called:
|
||||
# Triple1: create_node(s) + create_literal(o) + relate_literal = 3 calls
|
||||
|
|
@ -381,7 +393,13 @@ class TestMemgraphStorageProcessor:
|
|||
message.metadata.collection = 'test_collection'
|
||||
message.triples = []
|
||||
|
||||
await processor.store_triples(message)
|
||||
# Mock collection_exists to bypass validation in unit tests
|
||||
|
||||
|
||||
with patch.object(processor, 'collection_exists', return_value=True):
|
||||
|
||||
|
||||
await processor.store_triples(message)
|
||||
|
||||
# Verify no session calls were made (no triples to process)
|
||||
processor.io.session.assert_not_called()
|
||||
|
|
|
|||
|
|
@ -268,7 +268,9 @@ class TestNeo4jStorageProcessor:
|
|||
mock_message.metadata.user = "test_user"
|
||||
mock_message.metadata.collection = "test_collection"
|
||||
|
||||
await processor.store_triples(mock_message)
|
||||
# Mock collection_exists to bypass validation in unit tests
|
||||
with patch.object(processor, 'collection_exists', return_value=True):
|
||||
await processor.store_triples(mock_message)
|
||||
|
||||
# Verify create_node was called for subject and object
|
||||
# Verify relate_node was called
|
||||
|
|
@ -336,7 +338,9 @@ class TestNeo4jStorageProcessor:
|
|||
mock_message.metadata.user = "test_user"
|
||||
mock_message.metadata.collection = "test_collection"
|
||||
|
||||
await processor.store_triples(mock_message)
|
||||
# Mock collection_exists to bypass validation in unit tests
|
||||
with patch.object(processor, 'collection_exists', return_value=True):
|
||||
await processor.store_triples(mock_message)
|
||||
|
||||
# Verify create_node was called for subject
|
||||
# Verify create_literal was called for object
|
||||
|
|
@ -411,7 +415,9 @@ class TestNeo4jStorageProcessor:
|
|||
mock_message.metadata.user = "test_user"
|
||||
mock_message.metadata.collection = "test_collection"
|
||||
|
||||
await processor.store_triples(mock_message)
|
||||
# Mock collection_exists to bypass validation in unit tests
|
||||
with patch.object(processor, 'collection_exists', return_value=True):
|
||||
await processor.store_triples(mock_message)
|
||||
|
||||
# Should have processed both triples
|
||||
# Triple1: 2 nodes + 1 relationship = 3 calls
|
||||
|
|
@ -437,7 +443,9 @@ class TestNeo4jStorageProcessor:
|
|||
mock_message.metadata.user = "test_user"
|
||||
mock_message.metadata.collection = "test_collection"
|
||||
|
||||
await processor.store_triples(mock_message)
|
||||
# Mock collection_exists to bypass validation in unit tests
|
||||
with patch.object(processor, 'collection_exists', return_value=True):
|
||||
await processor.store_triples(mock_message)
|
||||
|
||||
# Should not have made any execute_query calls beyond index creation
|
||||
# Only index creation calls should have been made during initialization
|
||||
|
|
@ -552,7 +560,9 @@ class TestNeo4jStorageProcessor:
|
|||
mock_message.metadata.user = "test_user"
|
||||
mock_message.metadata.collection = "test_collection"
|
||||
|
||||
await processor.store_triples(mock_message)
|
||||
# Mock collection_exists to bypass validation in unit tests
|
||||
with patch.object(processor, 'collection_exists', return_value=True):
|
||||
await processor.store_triples(mock_message)
|
||||
|
||||
# Verify the triple was processed with special characters preserved
|
||||
mock_driver.execute_query.assert_any_call(
|
||||
|
|
|
|||
|
|
@ -44,7 +44,7 @@ class TestAzureOpenAIProcessorSimple(IsolatedAsyncioTestCase):
|
|||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
assert processor.model == 'gpt-4'
|
||||
assert processor.default_model == 'gpt-4'
|
||||
assert processor.temperature == 0.0
|
||||
assert processor.max_output == 4192
|
||||
assert hasattr(processor, 'openai')
|
||||
|
|
@ -254,7 +254,7 @@ class TestAzureOpenAIProcessorSimple(IsolatedAsyncioTestCase):
|
|||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
assert processor.model == 'gpt-35-turbo'
|
||||
assert processor.default_model == 'gpt-35-turbo'
|
||||
assert processor.temperature == 0.7
|
||||
assert processor.max_output == 2048
|
||||
mock_azure_openai_class.assert_called_once_with(
|
||||
|
|
@ -289,7 +289,7 @@ class TestAzureOpenAIProcessorSimple(IsolatedAsyncioTestCase):
|
|||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
assert processor.model == 'gpt-4'
|
||||
assert processor.default_model == 'gpt-4'
|
||||
assert processor.temperature == 0.0 # default_temperature
|
||||
assert processor.max_output == 4192 # default_max_output
|
||||
mock_azure_openai_class.assert_called_once_with(
|
||||
|
|
@ -402,6 +402,156 @@ class TestAzureOpenAIProcessorSimple(IsolatedAsyncioTestCase):
|
|||
assert call_args[1]['max_tokens'] == 1024
|
||||
assert call_args[1]['top_p'] == 1
|
||||
|
||||
@patch('trustgraph.model.text_completion.azure_openai.llm.AzureOpenAI')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_temperature_override(self, mock_llm_init, mock_async_init, mock_azure_openai_class):
|
||||
"""Test temperature parameter override functionality"""
|
||||
# Arrange
|
||||
mock_azure_client = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices = [MagicMock()]
|
||||
mock_response.choices[0].message.content = 'Response with custom temperature'
|
||||
mock_response.usage.prompt_tokens = 20
|
||||
mock_response.usage.completion_tokens = 12
|
||||
|
||||
mock_azure_client.chat.completions.create.return_value = mock_response
|
||||
mock_azure_openai_class.return_value = mock_azure_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'gpt-4',
|
||||
'endpoint': 'https://test.openai.azure.com/',
|
||||
'token': 'test-token',
|
||||
'api_version': '2024-12-01-preview',
|
||||
'temperature': 0.0, # Default temperature
|
||||
'max_output': 4192,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act - Override temperature at runtime
|
||||
result = await processor.generate_content(
|
||||
"System prompt",
|
||||
"User prompt",
|
||||
model=None, # Use default model
|
||||
temperature=0.8 # Override temperature
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Response with custom temperature"
|
||||
|
||||
# Verify Azure OpenAI API was called with overridden temperature
|
||||
call_args = mock_azure_client.chat.completions.create.call_args
|
||||
assert call_args[1]['temperature'] == 0.8 # Should use runtime override
|
||||
assert call_args[1]['model'] == 'gpt-4'
|
||||
|
||||
@patch('trustgraph.model.text_completion.azure_openai.llm.AzureOpenAI')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_model_override(self, mock_llm_init, mock_async_init, mock_azure_openai_class):
|
||||
"""Test model parameter override functionality"""
|
||||
# Arrange
|
||||
mock_azure_client = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices = [MagicMock()]
|
||||
mock_response.choices[0].message.content = 'Response with custom model'
|
||||
mock_response.usage.prompt_tokens = 18
|
||||
mock_response.usage.completion_tokens = 14
|
||||
|
||||
mock_azure_client.chat.completions.create.return_value = mock_response
|
||||
mock_azure_openai_class.return_value = mock_azure_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'gpt-4', # Default model
|
||||
'endpoint': 'https://test.openai.azure.com/',
|
||||
'token': 'test-token',
|
||||
'api_version': '2024-12-01-preview',
|
||||
'temperature': 0.1, # Default temperature
|
||||
'max_output': 4192,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act - Override model at runtime
|
||||
result = await processor.generate_content(
|
||||
"System prompt",
|
||||
"User prompt",
|
||||
model="gpt-4o", # Override model
|
||||
temperature=None # Use default temperature
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Response with custom model"
|
||||
|
||||
# Verify Azure OpenAI API was called with overridden model
|
||||
call_args = mock_azure_client.chat.completions.create.call_args
|
||||
assert call_args[1]['model'] == 'gpt-4o' # Should use runtime override
|
||||
assert call_args[1]['temperature'] == 0.1 # Should use processor default
|
||||
|
||||
@patch('trustgraph.model.text_completion.azure_openai.llm.AzureOpenAI')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_both_parameters_override(self, mock_llm_init, mock_async_init, mock_azure_openai_class):
|
||||
"""Test overriding both model and temperature parameters simultaneously"""
|
||||
# Arrange
|
||||
mock_azure_client = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices = [MagicMock()]
|
||||
mock_response.choices[0].message.content = 'Response with both overrides'
|
||||
mock_response.usage.prompt_tokens = 22
|
||||
mock_response.usage.completion_tokens = 16
|
||||
|
||||
mock_azure_client.chat.completions.create.return_value = mock_response
|
||||
mock_azure_openai_class.return_value = mock_azure_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'gpt-4', # Default model
|
||||
'endpoint': 'https://test.openai.azure.com/',
|
||||
'token': 'test-token',
|
||||
'api_version': '2024-12-01-preview',
|
||||
'temperature': 0.0, # Default temperature
|
||||
'max_output': 4192,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act - Override both parameters at runtime
|
||||
result = await processor.generate_content(
|
||||
"System prompt",
|
||||
"User prompt",
|
||||
model="gpt-4o-mini", # Override model
|
||||
temperature=0.9 # Override temperature
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Response with both overrides"
|
||||
|
||||
# Verify Azure OpenAI API was called with both overrides
|
||||
call_args = mock_azure_client.chat.completions.create.call_args
|
||||
assert call_args[1]['model'] == 'gpt-4o-mini' # Should use runtime override
|
||||
assert call_args[1]['temperature'] == 0.9 # Should use runtime override
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pytest.main([__file__])
|
||||
|
|
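The override tests above pin down a simple resolution rule: a model or temperature passed to generate_content wins, otherwise the processor defaults apply. A sketch of that resolution (parameter names are taken from the test calls, the client attribute follows the hasattr assertion earlier in the file, and the response handling is elided):

    async def generate_content(self, system, prompt, model=None, temperature=None):
        # Runtime overrides take precedence over the processor's defaults
        model = model if model is not None else self.default_model
        temperature = temperature if temperature is not None else self.temperature

        resp = self.openai.chat.completions.create(
            model=model,
            temperature=temperature,
            max_tokens=self.max_output,
            top_p=1,
            messages=[
                {"role": "system", "content": system},
                {"role": "user", "content": prompt},
            ],
        )
        # ... wrap resp into an LlmResult carrying text, model and token counts ...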
@ -43,7 +43,7 @@ class TestAzureProcessorSimple(IsolatedAsyncioTestCase):
|
|||
assert processor.token == 'test-token'
|
||||
assert processor.temperature == 0.0
|
||||
assert processor.max_output == 4192
|
||||
assert processor.model == 'AzureAI'
|
||||
assert processor.default_model == 'AzureAI'
|
||||
|
||||
@patch('trustgraph.model.text_completion.azure.llm.requests')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
|
|
@ -261,7 +261,7 @@ class TestAzureProcessorSimple(IsolatedAsyncioTestCase):
|
|||
assert processor.token == 'custom-token'
|
||||
assert processor.temperature == 0.7
|
||||
assert processor.max_output == 2048
|
||||
assert processor.model == 'AzureAI'
|
||||
assert processor.default_model == 'AzureAI'
|
||||
|
||||
@patch('trustgraph.model.text_completion.azure.llm.requests')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
|
|
@@ -289,7 +289,7 @@ class TestAzureProcessorSimple(IsolatedAsyncioTestCase):
|
|||
assert processor.token == 'test-token'
|
||||
assert processor.temperature == 0.0 # default_temperature
|
||||
assert processor.max_output == 4192 # default_max_output
|
||||
assert processor.model == 'AzureAI' # default_model
|
||||
assert processor.default_model == 'AzureAI' # default_model
|
||||
|
||||
@patch('trustgraph.model.text_completion.azure.llm.requests')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
|
|
@@ -459,5 +459,150 @@ class TestAzureProcessorSimple(IsolatedAsyncioTestCase):
|
|||
)
|
||||
|
||||
|
||||
@patch('trustgraph.model.text_completion.azure.llm.requests')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_with_model_override(self, mock_llm_init, mock_async_init, mock_requests):
|
||||
"""Test generate_content with model parameter override"""
|
||||
# Arrange
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {
|
||||
'choices': [{
|
||||
'message': {
|
||||
'content': 'Response with model override'
|
||||
}
|
||||
}],
|
||||
'usage': {
|
||||
'prompt_tokens': 15,
|
||||
'completion_tokens': 10
|
||||
}
|
||||
}
|
||||
mock_requests.post.return_value = mock_response
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'endpoint': 'https://test.inference.ai.azure.com/v1/chat/completions',
|
||||
'token': 'test-token',
|
||||
'temperature': 0.0,
|
||||
'max_output': 4192,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act - Override model
|
||||
result = await processor.generate_content("System", "Prompt", model="custom-azure-model")
|
||||
|
||||
# Assert
|
||||
assert result.model == "custom-azure-model" # Should use overridden model
|
||||
assert result.text == "Response with model override"
|
||||
|
||||
@patch('trustgraph.model.text_completion.azure.llm.requests')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_with_temperature_override(self, mock_llm_init, mock_async_init, mock_requests):
|
||||
"""Test generate_content with temperature parameter override"""
|
||||
# Arrange
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {
|
||||
'choices': [{
|
||||
'message': {
|
||||
'content': 'Response with temperature override'
|
||||
}
|
||||
}],
|
||||
'usage': {
|
||||
'prompt_tokens': 15,
|
||||
'completion_tokens': 10
|
||||
}
|
||||
}
|
||||
mock_requests.post.return_value = mock_response
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'endpoint': 'https://test.inference.ai.azure.com/v1/chat/completions',
|
||||
'token': 'test-token',
|
||||
'temperature': 0.0, # Default temperature
|
||||
'max_output': 4192,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act - Override temperature
|
||||
result = await processor.generate_content("System", "Prompt", temperature=0.8)
|
||||
|
||||
# Assert
|
||||
assert result.text == "Response with temperature override"
|
||||
|
||||
# Verify the request was made with the overridden temperature
|
||||
mock_requests.post.assert_called_once()
|
||||
call_args = mock_requests.post.call_args
|
||||
|
||||
import json
|
||||
request_body = json.loads(call_args[1]['data'])
|
||||
assert request_body['temperature'] == 0.8
|
||||
|
||||
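These asserts decode the raw POST body, so the serverless Azure path evidently assembles a JSON chat-completions payload by hand rather than going through an SDK client. A hedged sketch of what such an assembly might look like; everything beyond the `temperature` field is an assumption, not the actual trustgraph request shape:

import json

def build_chat_payload(system, prompt, temperature, max_tokens):
    # Assumed OpenAI-compatible chat/completions body; field names are illustrative
    return json.dumps({
        "messages": [
            {"role": "system", "content": system},
            {"role": "user", "content": prompt},
        ],
        "temperature": temperature,
        "max_tokens": max_tokens,
    })

body = json.loads(build_chat_payload("System", "Prompt", 0.8, 4192))
assert body["temperature"] == 0.8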
@patch('trustgraph.model.text_completion.azure.llm.requests')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_with_both_parameters_override(self, mock_llm_init, mock_async_init, mock_requests):
|
||||
"""Test generate_content with both model and temperature overrides"""
|
||||
# Arrange
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {
|
||||
'choices': [{
|
||||
'message': {
|
||||
'content': 'Response with both parameters override'
|
||||
}
|
||||
}],
|
||||
'usage': {
|
||||
'prompt_tokens': 18,
|
||||
'completion_tokens': 12
|
||||
}
|
||||
}
|
||||
mock_requests.post.return_value = mock_response
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'endpoint': 'https://test.inference.ai.azure.com/v1/chat/completions',
|
||||
'token': 'test-token',
|
||||
'temperature': 0.0,
|
||||
'max_output': 4192,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act - Override both parameters
|
||||
result = await processor.generate_content("System", "Prompt", model="override-model", temperature=0.9)
|
||||
|
||||
# Assert
|
||||
assert result.model == "override-model"
|
||||
assert result.text == "Response with both parameters override"
|
||||
|
||||
# Verify the request was made with overridden temperature
|
||||
mock_requests.post.assert_called_once()
|
||||
call_args = mock_requests.post.call_args
|
||||
|
||||
import json
|
||||
request_body = json.loads(call_args[1]['data'])
|
||||
assert request_body['temperature'] == 0.9
|
||||
|
||||
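Taken together, the override tests in this file pin down a simple resolution rule: an explicit argument wins, and None falls back to the processor default. A hedged, standalone sketch of that rule (the function name is hypothetical, not trustgraph API):

def resolve_generation_params(model, temperature, default_model, default_temperature):
    # Compare against None rather than truthiness: 0.0 is a legitimate temperature override
    resolved_model = model if model is not None else default_model
    resolved_temperature = temperature if temperature is not None else default_temperature
    return resolved_model, resolved_temperature

assert resolve_generation_params(None, 0.9, "gpt-4", 0.0) == ("gpt-4", 0.9)
assert resolve_generation_params("gpt-4o-mini", None, "gpt-4", 0.0) == ("gpt-4o-mini", 0.0)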
|
||||
if __name__ == '__main__':
|
||||
pytest.main([__file__])
|
||||
280
tests/unit/test_text_completion/test_bedrock_processor.py
Normal file
|
|
@@ -0,0 +1,280 @@
|
|||
"""
|
||||
Unit tests for trustgraph.model.text_completion.bedrock
|
||||
Following the same successful pattern as other processor tests
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from unittest import IsolatedAsyncioTestCase
|
||||
import json
|
||||
|
||||
# Import the service under test
|
||||
from trustgraph.model.text_completion.bedrock.llm import Processor, Mistral, Anthropic
|
||||
from trustgraph.base import LlmResult
|
||||
|
||||
|
||||
class TestBedrockProcessorSimple(IsolatedAsyncioTestCase):
|
||||
"""Test Bedrock processor functionality"""
|
||||
|
||||
@patch('trustgraph.model.text_completion.bedrock.llm.boto3.Session')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_processor_initialization_basic(self, mock_llm_init, mock_async_init, mock_session_class):
|
||||
"""Test basic processor initialization"""
|
||||
# Arrange
|
||||
mock_session = MagicMock()
|
||||
mock_bedrock = MagicMock()
|
||||
mock_session.client.return_value = mock_bedrock
|
||||
mock_session_class.return_value = mock_session
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'mistral.mistral-large-2407-v1:0',
|
||||
'temperature': 0.1,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
# Act
|
||||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
assert processor.default_model == 'mistral.mistral-large-2407-v1:0'
|
||||
assert processor.temperature == 0.1
|
||||
assert hasattr(processor, 'bedrock')
|
||||
mock_session_class.assert_called_once()
|
||||
|
||||
@patch('trustgraph.model.text_completion.bedrock.llm.boto3.Session')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_success_mistral(self, mock_llm_init, mock_async_init, mock_session_class):
|
||||
"""Test successful content generation with Mistral model"""
|
||||
# Arrange
|
||||
mock_session = MagicMock()
|
||||
mock_bedrock = MagicMock()
|
||||
mock_session.client.return_value = mock_bedrock
|
||||
mock_session_class.return_value = mock_session
|
||||
|
||||
mock_response = {
|
||||
'body': MagicMock(),
|
||||
'ResponseMetadata': {
|
||||
'HTTPHeaders': {
|
||||
'x-amzn-bedrock-input-token-count': '15',
|
||||
'x-amzn-bedrock-output-token-count': '8'
|
||||
}
|
||||
}
|
||||
}
|
||||
mock_response['body'].read.return_value = json.dumps({
|
||||
'outputs': [{'text': 'Generated response from Bedrock'}]
|
||||
})
|
||||
mock_bedrock.invoke_model.return_value = mock_response
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'mistral.mistral-large-2407-v1:0',
|
||||
'temperature': 0.0,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act
|
||||
result = await processor.generate_content("System prompt", "User prompt")
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Generated response from Bedrock"
|
||||
assert result.in_token == 15
|
||||
assert result.out_token == 8
|
||||
assert result.model == 'mistral.mistral-large-2407-v1:0'
|
||||
mock_bedrock.invoke_model.assert_called_once()
|
||||
|
||||
@patch('trustgraph.model.text_completion.bedrock.llm.boto3.Session')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_temperature_override(self, mock_llm_init, mock_async_init, mock_session_class):
|
||||
"""Test temperature parameter override functionality"""
|
||||
# Arrange
|
||||
mock_session = MagicMock()
|
||||
mock_bedrock = MagicMock()
|
||||
mock_session.client.return_value = mock_bedrock
|
||||
mock_session_class.return_value = mock_session
|
||||
|
||||
mock_response = {
|
||||
'body': MagicMock(),
|
||||
'ResponseMetadata': {
|
||||
'HTTPHeaders': {
|
||||
'x-amzn-bedrock-input-token-count': '20',
|
||||
'x-amzn-bedrock-output-token-count': '12'
|
||||
}
|
||||
}
|
||||
}
|
||||
mock_response['body'].read.return_value = json.dumps({
|
||||
'outputs': [{'text': 'Response with custom temperature'}]
|
||||
})
|
||||
mock_bedrock.invoke_model.return_value = mock_response
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'mistral.mistral-large-2407-v1:0',
|
||||
'temperature': 0.0, # Default temperature
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act - Override temperature at runtime
|
||||
result = await processor.generate_content(
|
||||
"System prompt",
|
||||
"User prompt",
|
||||
model=None, # Use default model
|
||||
temperature=0.8 # Override temperature
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Response with custom temperature"
|
||||
|
||||
# Verify the model variant was created with overridden temperature
|
||||
# The cache key should include the temperature
|
||||
cache_key = f"mistral.mistral-large-2407-v1:0:0.8"
|
||||
assert cache_key in processor.model_variants
|
||||
variant = processor.model_variants[cache_key]
|
||||
assert variant.temperature == 0.8
|
||||
|
||||
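The cache-key assertions above imply that the Bedrock processor memoises one model variant per model:temperature pair and reuses it on subsequent calls. A rough, hypothetical sketch of that lookup; the real variant classes are the Mistral/Anthropic types imported at the top of the file, but the cache logic here is assumed:

from collections import namedtuple

class VariantCache:
    """Hypothetical per-(model, temperature) cache, mirroring processor.model_variants."""

    def __init__(self, variant_factory):
        self.model_variants = {}
        self.variant_factory = variant_factory

    def get(self, model, temperature):
        key = f"{model}:{temperature}"
        if key not in self.model_variants:
            # Build lazily on first use, then reuse the same variant object
            self.model_variants[key] = self.variant_factory(model, temperature)
        return self.model_variants[key]

Variant = namedtuple("Variant", "model temperature")
cache = VariantCache(Variant)
assert cache.get("mistral.mistral-large-2407-v1:0", 0.8).temperature == 0.8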
@patch('trustgraph.model.text_completion.bedrock.llm.boto3.Session')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_model_override(self, mock_llm_init, mock_async_init, mock_session_class):
|
||||
"""Test model parameter override functionality"""
|
||||
# Arrange
|
||||
mock_session = MagicMock()
|
||||
mock_bedrock = MagicMock()
|
||||
mock_session.client.return_value = mock_bedrock
|
||||
mock_session_class.return_value = mock_session
|
||||
|
||||
mock_response = {
|
||||
'body': MagicMock(),
|
||||
'ResponseMetadata': {
|
||||
'HTTPHeaders': {
|
||||
'x-amzn-bedrock-input-token-count': '18',
|
||||
'x-amzn-bedrock-output-token-count': '14'
|
||||
}
|
||||
}
|
||||
}
|
||||
mock_response['body'].read.return_value = json.dumps({
|
||||
'content': [{'text': 'Response with custom model'}]
|
||||
})
|
||||
mock_bedrock.invoke_model.return_value = mock_response
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'mistral.mistral-large-2407-v1:0', # Default model
|
||||
'temperature': 0.1, # Default temperature
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act - Override model at runtime
|
||||
result = await processor.generate_content(
|
||||
"System prompt",
|
||||
"User prompt",
|
||||
model="anthropic.claude-3-sonnet-20240229-v1:0", # Override model
|
||||
temperature=None # Use default temperature
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Response with custom model"
|
||||
|
||||
# Verify Bedrock API was called with overridden model
|
||||
mock_bedrock.invoke_model.assert_called_once()
|
||||
call_args = mock_bedrock.invoke_model.call_args
|
||||
assert call_args[1]['modelId'] == "anthropic.claude-3-sonnet-20240229-v1:0"
|
||||
|
||||
# Verify the correct model variant (Anthropic) was used
|
||||
cache_key = f"anthropic.claude-3-sonnet-20240229-v1:0:0.1"
|
||||
assert cache_key in processor.model_variants
|
||||
variant = processor.model_variants[cache_key]
|
||||
assert isinstance(variant, Anthropic)
|
||||
|
||||
@patch('trustgraph.model.text_completion.bedrock.llm.boto3.Session')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_both_parameters_override(self, mock_llm_init, mock_async_init, mock_session_class):
|
||||
"""Test overriding both model and temperature parameters simultaneously"""
|
||||
# Arrange
|
||||
mock_session = MagicMock()
|
||||
mock_bedrock = MagicMock()
|
||||
mock_session.client.return_value = mock_bedrock
|
||||
mock_session_class.return_value = mock_session
|
||||
|
||||
mock_response = {
|
||||
'body': MagicMock(),
|
||||
'ResponseMetadata': {
|
||||
'HTTPHeaders': {
|
||||
'x-amzn-bedrock-input-token-count': '22',
|
||||
'x-amzn-bedrock-output-token-count': '16'
|
||||
}
|
||||
}
|
||||
}
|
||||
mock_response['body'].read.return_value = json.dumps({
|
||||
'generation': 'Response with both overrides'
|
||||
})
|
||||
mock_bedrock.invoke_model.return_value = mock_response
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'mistral.mistral-large-2407-v1:0', # Default model
|
||||
'temperature': 0.0, # Default temperature
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act - Override both parameters at runtime
|
||||
result = await processor.generate_content(
|
||||
"System prompt",
|
||||
"User prompt",
|
||||
model="meta.llama3-70b-instruct-v1:0", # Override model (Meta/Llama)
|
||||
temperature=0.9 # Override temperature
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Response with both overrides"
|
||||
|
||||
# Verify Bedrock API was called with both overrides
|
||||
mock_bedrock.invoke_model.assert_called_once()
|
||||
call_args = mock_bedrock.invoke_model.call_args
|
||||
assert call_args[1]['modelId'] == "meta.llama3-70b-instruct-v1:0"
|
||||
|
||||
# Verify the correct model variant (Meta) was used with correct temperature
|
||||
cache_key = f"meta.llama3-70b-instruct-v1:0:0.9"
|
||||
assert cache_key in processor.model_variants
|
||||
variant = processor.model_variants[cache_key]
|
||||
assert variant.temperature == 0.9
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pytest.main([__file__])
|
||||
|
|
@@ -42,7 +42,7 @@ class TestClaudeProcessorSimple(IsolatedAsyncioTestCase):
|
|||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
assert processor.model == 'claude-3-5-sonnet-20240620'
|
||||
assert processor.default_model == 'claude-3-5-sonnet-20240620'
|
||||
assert processor.temperature == 0.0
|
||||
assert processor.max_output == 8192
|
||||
assert hasattr(processor, 'claude')
|
||||
|
|
@@ -217,7 +217,7 @@ class TestClaudeProcessorSimple(IsolatedAsyncioTestCase):
|
|||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
assert processor.model == 'claude-3-haiku-20240307'
|
||||
assert processor.default_model == 'claude-3-haiku-20240307'
|
||||
assert processor.temperature == 0.7
|
||||
assert processor.max_output == 4096
|
||||
mock_anthropic_class.assert_called_once_with(api_key='custom-api-key')
|
||||
|
|
@@ -246,7 +246,7 @@ class TestClaudeProcessorSimple(IsolatedAsyncioTestCase):
|
|||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
assert processor.model == 'claude-3-5-sonnet-20240620' # default_model
|
||||
assert processor.default_model == 'claude-3-5-sonnet-20240620' # default_model
|
||||
assert processor.temperature == 0.0 # default_temperature
|
||||
assert processor.max_output == 8192 # default_max_output
|
||||
mock_anthropic_class.assert_called_once_with(api_key='test-api-key')
|
||||
|
|
@@ -433,7 +433,157 @@ class TestClaudeProcessorSimple(IsolatedAsyncioTestCase):
|
|||
|
||||
# Verify processor has the client
|
||||
assert processor.claude == mock_claude_client
|
||||
assert processor.model == 'claude-3-opus-20240229'
|
||||
assert processor.default_model == 'claude-3-opus-20240229'
|
||||
|
||||
@patch('trustgraph.model.text_completion.claude.llm.anthropic.Anthropic')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_temperature_override(self, mock_llm_init, mock_async_init, mock_anthropic_class):
|
||||
"""Test temperature parameter override functionality"""
|
||||
# Arrange
|
||||
mock_claude_client = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.content = [MagicMock()]
|
||||
mock_response.content[0].text = "Response with custom temperature"
|
||||
mock_response.usage.input_tokens = 20
|
||||
mock_response.usage.output_tokens = 12
|
||||
|
||||
mock_claude_client.messages.create.return_value = mock_response
|
||||
mock_anthropic_class.return_value = mock_claude_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'claude-3-5-sonnet-20240620',
|
||||
'api_key': 'test-api-key',
|
||||
'temperature': 0.0, # Default temperature
|
||||
'max_output': 8192,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act - Override temperature at runtime
|
||||
result = await processor.generate_content(
|
||||
"System prompt",
|
||||
"User prompt",
|
||||
model=None, # Use default model
|
||||
temperature=0.9 # Override temperature
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Response with custom temperature"
|
||||
|
||||
# Verify Claude API was called with overridden temperature
|
||||
mock_claude_client.messages.create.assert_called_once()
|
||||
call_kwargs = mock_claude_client.messages.create.call_args.kwargs
|
||||
|
||||
assert call_kwargs['temperature'] == 0.9 # Should use runtime override
|
||||
assert call_kwargs['model'] == 'claude-3-5-sonnet-20240620' # Should use processor default
|
||||
|
||||
@patch('trustgraph.model.text_completion.claude.llm.anthropic.Anthropic')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_model_override(self, mock_llm_init, mock_async_init, mock_anthropic_class):
|
||||
"""Test model parameter override functionality"""
|
||||
# Arrange
|
||||
mock_claude_client = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.content = [MagicMock()]
|
||||
mock_response.content[0].text = "Response with custom model"
|
||||
mock_response.usage.input_tokens = 18
|
||||
mock_response.usage.output_tokens = 14
|
||||
|
||||
mock_claude_client.messages.create.return_value = mock_response
|
||||
mock_anthropic_class.return_value = mock_claude_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'claude-3-5-sonnet-20240620', # Default model
|
||||
'api_key': 'test-api-key',
|
||||
'temperature': 0.2, # Default temperature
|
||||
'max_output': 8192,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act - Override model at runtime
|
||||
result = await processor.generate_content(
|
||||
"System prompt",
|
||||
"User prompt",
|
||||
model="claude-3-haiku-20240307", # Override model
|
||||
temperature=None # Use default temperature
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Response with custom model"
|
||||
|
||||
# Verify Claude API was called with overridden model
|
||||
mock_claude_client.messages.create.assert_called_once()
|
||||
call_kwargs = mock_claude_client.messages.create.call_args.kwargs
|
||||
|
||||
assert call_kwargs['model'] == 'claude-3-haiku-20240307' # Should use runtime override
|
||||
assert call_kwargs['temperature'] == 0.2 # Should use processor default
|
||||
|
||||
@patch('trustgraph.model.text_completion.claude.llm.anthropic.Anthropic')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_both_parameters_override(self, mock_llm_init, mock_async_init, mock_anthropic_class):
|
||||
"""Test overriding both model and temperature parameters simultaneously"""
|
||||
# Arrange
|
||||
mock_claude_client = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.content = [MagicMock()]
|
||||
mock_response.content[0].text = "Response with both overrides"
|
||||
mock_response.usage.input_tokens = 22
|
||||
mock_response.usage.output_tokens = 16
|
||||
|
||||
mock_claude_client.messages.create.return_value = mock_response
|
||||
mock_anthropic_class.return_value = mock_claude_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'claude-3-5-sonnet-20240620', # Default model
|
||||
'api_key': 'test-api-key',
|
||||
'temperature': 0.0, # Default temperature
|
||||
'max_output': 8192,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act - Override both parameters at runtime
|
||||
result = await processor.generate_content(
|
||||
"System prompt",
|
||||
"User prompt",
|
||||
model="claude-3-opus-20240229", # Override model
|
||||
temperature=0.8 # Override temperature
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Response with both overrides"
|
||||
|
||||
# Verify Claude API was called with both overrides
|
||||
mock_claude_client.messages.create.assert_called_once()
|
||||
call_kwargs = mock_claude_client.messages.create.call_args.kwargs
|
||||
|
||||
assert call_kwargs['model'] == 'claude-3-opus-20240229' # Should use runtime override
|
||||
assert call_kwargs['temperature'] == 0.8 # Should use runtime override
|
||||
|
||||
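For context, the mocks in these Claude tests line up with the LlmResult fields the asserts read back; a hedged mapping sketch (the constructor keywords are assumed from the attribute names used in the asserts, since the processor source is not part of this diff):

from trustgraph.base import LlmResult

def to_llm_result(response, model):
    # response mirrors the anthropic Messages API shape mocked above
    return LlmResult(
        text=response.content[0].text,            # -> result.text
        in_token=response.usage.input_tokens,     # -> result.in_token
        out_token=response.usage.output_tokens,   # -> result.out_token
        model=model,                              # the model actually used for the call
    )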
|
||||
if __name__ == '__main__':
|
||||
|
|
|
|||
|
|
@@ -41,7 +41,7 @@ class TestCohereProcessorSimple(IsolatedAsyncioTestCase):
|
|||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
assert processor.model == 'c4ai-aya-23-8b'
|
||||
assert processor.default_model == 'c4ai-aya-23-8b'
|
||||
assert processor.temperature == 0.0
|
||||
assert hasattr(processor, 'cohere')
|
||||
mock_cohere_class.assert_called_once_with(api_key='test-api-key')
|
||||
|
|
@@ -201,7 +201,7 @@ class TestCohereProcessorSimple(IsolatedAsyncioTestCase):
|
|||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
assert processor.model == 'command-light'
|
||||
assert processor.default_model == 'command-light'
|
||||
assert processor.temperature == 0.7
|
||||
mock_cohere_class.assert_called_once_with(api_key='custom-api-key')
|
||||
|
||||
|
|
@@ -229,7 +229,7 @@ class TestCohereProcessorSimple(IsolatedAsyncioTestCase):
|
|||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
assert processor.model == 'c4ai-aya-23-8b' # default_model
|
||||
assert processor.default_model == 'c4ai-aya-23-8b' # default_model
|
||||
assert processor.temperature == 0.0 # default_temperature
|
||||
mock_cohere_class.assert_called_once_with(api_key='test-api-key')
|
||||
|
||||
|
|
@@ -395,7 +395,7 @@ class TestCohereProcessorSimple(IsolatedAsyncioTestCase):
|
|||
|
||||
# Verify processor has the client
|
||||
assert processor.cohere == mock_cohere_client
|
||||
assert processor.model == 'command-r'
|
||||
assert processor.default_model == 'command-r'
|
||||
|
||||
@patch('trustgraph.model.text_completion.cohere.llm.cohere.Client')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
|
|
@@ -442,6 +442,162 @@ class TestCohereProcessorSimple(IsolatedAsyncioTestCase):
|
|||
assert call_args[1]['prompt_truncation'] == 'auto'
|
||||
assert call_args[1]['connectors'] == []
|
||||
|
||||
@patch('trustgraph.model.text_completion.cohere.llm.cohere.Client')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_temperature_override(self, mock_llm_init, mock_async_init, mock_cohere_class):
|
||||
"""Test temperature parameter override functionality"""
|
||||
# Arrange
|
||||
mock_cohere_client = MagicMock()
|
||||
mock_output = MagicMock()
|
||||
mock_output.text = 'Response with custom temperature'
|
||||
mock_output.meta.billed_units.input_tokens = 20
|
||||
mock_output.meta.billed_units.output_tokens = 12
|
||||
|
||||
mock_cohere_client.chat.return_value = mock_output
|
||||
mock_cohere_class.return_value = mock_cohere_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'c4ai-aya-23-8b',
|
||||
'api_key': 'test-api-key',
|
||||
'temperature': 0.0, # Default temperature
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act - Override temperature at runtime
|
||||
result = await processor.generate_content(
|
||||
"System prompt",
|
||||
"User prompt",
|
||||
model=None, # Use default model
|
||||
temperature=0.8 # Override temperature
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Response with custom temperature"
|
||||
|
||||
# Verify Cohere API was called with overridden temperature
|
||||
mock_cohere_client.chat.assert_called_once_with(
|
||||
model='c4ai-aya-23-8b',
|
||||
message='User prompt',
|
||||
preamble='System prompt',
|
||||
temperature=0.8, # Should use runtime override
|
||||
chat_history=[],
|
||||
prompt_truncation='auto',
|
||||
connectors=[]
|
||||
)
|
||||
|
||||
@patch('trustgraph.model.text_completion.cohere.llm.cohere.Client')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_model_override(self, mock_llm_init, mock_async_init, mock_cohere_class):
|
||||
"""Test model parameter override functionality"""
|
||||
# Arrange
|
||||
mock_cohere_client = MagicMock()
|
||||
mock_output = MagicMock()
|
||||
mock_output.text = 'Response with custom model'
|
||||
mock_output.meta.billed_units.input_tokens = 18
|
||||
mock_output.meta.billed_units.output_tokens = 14
|
||||
|
||||
mock_cohere_client.chat.return_value = mock_output
|
||||
mock_cohere_class.return_value = mock_cohere_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'c4ai-aya-23-8b', # Default model
|
||||
'api_key': 'test-api-key',
|
||||
'temperature': 0.1, # Default temperature
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act - Override model at runtime
|
||||
result = await processor.generate_content(
|
||||
"System prompt",
|
||||
"User prompt",
|
||||
model="command-r-plus", # Override model
|
||||
temperature=None # Use default temperature
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Response with custom model"
|
||||
|
||||
# Verify Cohere API was called with overridden model
|
||||
mock_cohere_client.chat.assert_called_once_with(
|
||||
model='command-r-plus', # Should use runtime override
|
||||
message='User prompt',
|
||||
preamble='System prompt',
|
||||
temperature=0.1, # Should use processor default
|
||||
chat_history=[],
|
||||
prompt_truncation='auto',
|
||||
connectors=[]
|
||||
)
|
||||
|
||||
@patch('trustgraph.model.text_completion.cohere.llm.cohere.Client')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_both_parameters_override(self, mock_llm_init, mock_async_init, mock_cohere_class):
|
||||
"""Test overriding both model and temperature parameters simultaneously"""
|
||||
# Arrange
|
||||
mock_cohere_client = MagicMock()
|
||||
mock_output = MagicMock()
|
||||
mock_output.text = 'Response with both overrides'
|
||||
mock_output.meta.billed_units.input_tokens = 22
|
||||
mock_output.meta.billed_units.output_tokens = 16
|
||||
|
||||
mock_cohere_client.chat.return_value = mock_output
|
||||
mock_cohere_class.return_value = mock_cohere_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'c4ai-aya-23-8b', # Default model
|
||||
'api_key': 'test-api-key',
|
||||
'temperature': 0.0, # Default temperature
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act - Override both parameters at runtime
|
||||
result = await processor.generate_content(
|
||||
"System prompt",
|
||||
"User prompt",
|
||||
model="command-r", # Override model
|
||||
temperature=0.9 # Override temperature
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Response with both overrides"
|
||||
|
||||
# Verify Cohere API was called with both overrides
|
||||
mock_cohere_client.chat.assert_called_once_with(
|
||||
model='command-r', # Should use runtime override
|
||||
message='User prompt',
|
||||
preamble='System prompt',
|
||||
temperature=0.9, # Should use runtime override
|
||||
chat_history=[],
|
||||
prompt_truncation='auto',
|
||||
connectors=[]
|
||||
)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pytest.main([__file__])
|
||||
|
|
@@ -42,7 +42,7 @@ class TestGoogleAIStudioProcessorSimple(IsolatedAsyncioTestCase):
|
|||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
assert processor.model == 'gemini-2.0-flash-001'
|
||||
assert processor.default_model == 'gemini-2.0-flash-001'
|
||||
assert processor.temperature == 0.0
|
||||
assert processor.max_output == 8192
|
||||
assert hasattr(processor, 'client')
|
||||
|
|
@@ -205,7 +205,7 @@ class TestGoogleAIStudioProcessorSimple(IsolatedAsyncioTestCase):
|
|||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
assert processor.model == 'gemini-1.5-pro'
|
||||
assert processor.default_model == 'gemini-1.5-pro'
|
||||
assert processor.temperature == 0.7
|
||||
assert processor.max_output == 4096
|
||||
mock_genai_class.assert_called_once_with(api_key='custom-api-key')
|
||||
|
|
@@ -234,7 +234,7 @@ class TestGoogleAIStudioProcessorSimple(IsolatedAsyncioTestCase):
|
|||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
assert processor.model == 'gemini-2.0-flash-001' # default_model
|
||||
assert processor.default_model == 'gemini-2.0-flash-001' # default_model
|
||||
assert processor.temperature == 0.0 # default_temperature
|
||||
assert processor.max_output == 8192 # default_max_output
|
||||
mock_genai_class.assert_called_once_with(api_key='test-api-key')
|
||||
|
|
@@ -431,7 +431,7 @@ class TestGoogleAIStudioProcessorSimple(IsolatedAsyncioTestCase):
|
|||
|
||||
# Verify processor has the client
|
||||
assert processor.client == mock_genai_client
|
||||
assert processor.model == 'gemini-1.5-flash'
|
||||
assert processor.default_model == 'gemini-1.5-flash'
|
||||
|
||||
@patch('trustgraph.model.text_completion.googleaistudio.llm.genai.Client')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
|
|
@@ -477,6 +477,156 @@ class TestGoogleAIStudioProcessorSimple(IsolatedAsyncioTestCase):
|
|||
# The system instruction should be in the config object
|
||||
assert call_args[1]['contents'] == "Explain quantum computing"
|
||||
|
||||
@patch('trustgraph.model.text_completion.googleaistudio.llm.genai.Client')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_temperature_override(self, mock_llm_init, mock_async_init, mock_genai_class):
|
||||
"""Test temperature parameter override functionality"""
|
||||
# Arrange
|
||||
mock_genai_client = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.text = 'Response with custom temperature'
|
||||
mock_response.usage_metadata.prompt_token_count = 20
|
||||
mock_response.usage_metadata.candidates_token_count = 12
|
||||
|
||||
mock_genai_client.models.generate_content.return_value = mock_response
|
||||
mock_genai_class.return_value = mock_genai_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'gemini-2.0-flash-001',
|
||||
'api_key': 'test-api-key',
|
||||
'temperature': 0.0, # Default temperature
|
||||
'max_output': 8192,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act - Override temperature at runtime
|
||||
result = await processor.generate_content(
|
||||
"System prompt",
|
||||
"User prompt",
|
||||
model=None, # Use default model
|
||||
temperature=0.8 # Override temperature
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Response with custom temperature"
|
||||
|
||||
# Verify the generation config was created with overridden temperature
|
||||
cache_key = f"gemini-2.0-flash-001:0.8"
|
||||
assert cache_key in processor.generation_configs
|
||||
config_obj = processor.generation_configs[cache_key]
|
||||
assert config_obj.temperature == 0.8
|
||||
|
||||
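The Google AI Studio tests lean on the same per-model:temperature memoisation idea as the Bedrock variant cache, except the cached object is a generation config. A compact, hypothetical version of the lookup (make_config stands in for however the real processor builds its config object):

generation_configs = {}

def get_generation_config(model, temperature, make_config):
    key = f"{model}:{temperature}"
    if key not in generation_configs:
        # Build the config once per key; later calls reuse the cached object
        generation_configs[key] = make_config(temperature)
    return generation_configs[key]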
@patch('trustgraph.model.text_completion.googleaistudio.llm.genai.Client')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_model_override(self, mock_llm_init, mock_async_init, mock_genai_class):
|
||||
"""Test model parameter override functionality"""
|
||||
# Arrange
|
||||
mock_genai_client = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.text = 'Response with custom model'
|
||||
mock_response.usage_metadata.prompt_token_count = 18
|
||||
mock_response.usage_metadata.candidates_token_count = 14
|
||||
|
||||
mock_genai_client.models.generate_content.return_value = mock_response
|
||||
mock_genai_class.return_value = mock_genai_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'gemini-2.0-flash-001', # Default model
|
||||
'api_key': 'test-api-key',
|
||||
'temperature': 0.1, # Default temperature
|
||||
'max_output': 8192,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act - Override model at runtime
|
||||
result = await processor.generate_content(
|
||||
"System prompt",
|
||||
"User prompt",
|
||||
model="gemini-1.5-pro", # Override model
|
||||
temperature=None # Use default temperature
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Response with custom model"
|
||||
|
||||
# Verify Google AI Studio API was called with overridden model
|
||||
call_args = mock_genai_client.models.generate_content.call_args
|
||||
assert call_args[1]['model'] == 'gemini-1.5-pro' # Should use runtime override
|
||||
|
||||
# Verify the generation config was created for the correct model
|
||||
cache_key = f"gemini-1.5-pro:0.1"
|
||||
assert cache_key in processor.generation_configs
|
||||
|
||||
@patch('trustgraph.model.text_completion.googleaistudio.llm.genai.Client')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_both_parameters_override(self, mock_llm_init, mock_async_init, mock_genai_class):
|
||||
"""Test overriding both model and temperature parameters simultaneously"""
|
||||
# Arrange
|
||||
mock_genai_client = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.text = 'Response with both overrides'
|
||||
mock_response.usage_metadata.prompt_token_count = 22
|
||||
mock_response.usage_metadata.candidates_token_count = 16
|
||||
|
||||
mock_genai_client.models.generate_content.return_value = mock_response
|
||||
mock_genai_class.return_value = mock_genai_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'gemini-2.0-flash-001', # Default model
|
||||
'api_key': 'test-api-key',
|
||||
'temperature': 0.0, # Default temperature
|
||||
'max_output': 8192,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act - Override both parameters at runtime
|
||||
result = await processor.generate_content(
|
||||
"System prompt",
|
||||
"User prompt",
|
||||
model="gemini-1.5-flash", # Override model
|
||||
temperature=0.9 # Override temperature
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Response with both overrides"
|
||||
|
||||
# Verify Google AI Studio API was called with both overrides
|
||||
call_args = mock_genai_client.models.generate_content.call_args
|
||||
assert call_args[1]['model'] == 'gemini-1.5-flash' # Should use runtime override
|
||||
|
||||
# Verify the generation config was created with both overrides
|
||||
cache_key = f"gemini-1.5-flash:0.9"
|
||||
assert cache_key in processor.generation_configs
|
||||
config_obj = processor.generation_configs[cache_key]
|
||||
assert config_obj.temperature == 0.9
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pytest.main([__file__])
|
||||
|
|
@@ -42,7 +42,7 @@ class TestLlamaFileProcessorSimple(IsolatedAsyncioTestCase):
|
|||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
assert processor.model == 'LLaMA_CPP'
|
||||
assert processor.default_model == 'LLaMA_CPP'
|
||||
assert processor.llamafile == 'http://localhost:8080/v1'
|
||||
assert processor.temperature == 0.0
|
||||
assert processor.max_output == 4096
|
||||
|
|
@@ -91,7 +91,7 @@ class TestLlamaFileProcessorSimple(IsolatedAsyncioTestCase):
|
|||
assert result.text == "Generated response from LlamaFile"
|
||||
assert result.in_token == 20
|
||||
assert result.out_token == 12
|
||||
assert result.model == 'llama.cpp' # Note: model in result is hardcoded to 'llama.cpp'
|
||||
assert result.model == 'LLaMA_CPP' # Uses the default model name
|
||||
|
||||
# Verify the OpenAI API call structure
|
||||
mock_openai_client.chat.completions.create.assert_called_once_with(
|
||||
|
|
@@ -99,7 +99,15 @@ class TestLlamaFileProcessorSimple(IsolatedAsyncioTestCase):
|
|||
messages=[{
|
||||
"role": "user",
|
||||
"content": "System prompt\n\nUser prompt"
|
||||
}]
|
||||
}],
|
||||
temperature=0.0,
|
||||
max_tokens=4096,
|
||||
top_p=1,
|
||||
frequency_penalty=0,
|
||||
presence_penalty=0,
|
||||
response_format={
|
||||
"type": "text"
|
||||
}
|
||||
)
|
||||
|
||||
@patch('trustgraph.model.text_completion.llamafile.llm.OpenAI')
|
||||
|
|
@@ -157,7 +165,7 @@ class TestLlamaFileProcessorSimple(IsolatedAsyncioTestCase):
|
|||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
assert processor.model == 'custom-llama'
|
||||
assert processor.default_model == 'custom-llama'
|
||||
assert processor.llamafile == 'http://custom-host:8080/v1'
|
||||
assert processor.temperature == 0.7
|
||||
assert processor.max_output == 2048
|
||||
|
|
@@ -189,7 +197,7 @@ class TestLlamaFileProcessorSimple(IsolatedAsyncioTestCase):
|
|||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
assert processor.model == 'LLaMA_CPP' # default_model
|
||||
assert processor.default_model == 'LLaMA_CPP' # default_model
|
||||
assert processor.llamafile == 'http://localhost:8080/v1' # default_llamafile
|
||||
assert processor.temperature == 0.0 # default_temperature
|
||||
assert processor.max_output == 4096 # default_max_output
|
||||
|
|
@@ -237,7 +245,7 @@ class TestLlamaFileProcessorSimple(IsolatedAsyncioTestCase):
|
|||
assert result.text == "Default response"
|
||||
assert result.in_token == 2
|
||||
assert result.out_token == 3
|
||||
assert result.model == 'llama.cpp'
|
||||
assert result.model == 'LLaMA_CPP'
|
||||
|
||||
# Verify the combined prompt is sent correctly
|
||||
call_args = mock_openai_client.chat.completions.create.call_args
|
||||
|
|
@@ -408,8 +416,8 @@ class TestLlamaFileProcessorSimple(IsolatedAsyncioTestCase):
|
|||
result = await processor.generate_content("System", "User")
|
||||
|
||||
# Assert
|
||||
assert result.model == 'llama.cpp' # Should always be 'llama.cpp', not 'custom-model-name'
|
||||
assert processor.model == 'custom-model-name' # But processor.model should still be custom
|
||||
assert result.model == 'custom-model-name' # Uses the actual model name passed to generate_content
|
||||
assert processor.default_model == 'custom-model-name'  # But processor.default_model should still keep the custom name
|
||||
|
||||
@patch('trustgraph.model.text_completion.llamafile.llm.OpenAI')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
|
|
@@ -450,5 +458,132 @@ class TestLlamaFileProcessorSimple(IsolatedAsyncioTestCase):
|
|||
# No specific rate-limit error handling is tested, since a locally served model presumably imposes no rate limits
|
||||
|
||||
|
||||
@patch('trustgraph.model.text_completion.llamafile.llm.OpenAI')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_with_model_override(self, mock_llm_init, mock_async_init, mock_openai_class):
|
||||
"""Test generate_content with model parameter override"""
|
||||
# Arrange
|
||||
mock_openai_client = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices = [MagicMock()]
|
||||
mock_response.choices[0].message.content = "Response from overridden model"
|
||||
mock_response.usage.prompt_tokens = 15
|
||||
mock_response.usage.completion_tokens = 10
|
||||
|
||||
mock_openai_client.chat.completions.create.return_value = mock_response
|
||||
mock_openai_class.return_value = mock_openai_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'LLaMA_CPP',
|
||||
'llamafile': 'http://localhost:8080/v1',
|
||||
'temperature': 0.0,
|
||||
'max_output': 4096,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act - Override model
|
||||
result = await processor.generate_content("System", "Prompt", model="custom-llamafile-model")
|
||||
|
||||
# Assert
|
||||
assert result.model == "custom-llamafile-model" # Should use overridden model
|
||||
assert result.text == "Response from overridden model"
|
||||
|
||||
# Verify the API call was made with overridden model
|
||||
call_args = mock_openai_client.chat.completions.create.call_args
|
||||
assert call_args[1]['model'] == "custom-llamafile-model"
|
||||
|
||||
@patch('trustgraph.model.text_completion.llamafile.llm.OpenAI')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_with_temperature_override(self, mock_llm_init, mock_async_init, mock_openai_class):
|
||||
"""Test generate_content with temperature parameter override"""
|
||||
# Arrange
|
||||
mock_openai_client = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices = [MagicMock()]
|
||||
mock_response.choices[0].message.content = "Response with temperature override"
|
||||
mock_response.usage.prompt_tokens = 18
|
||||
mock_response.usage.completion_tokens = 12
|
||||
|
||||
mock_openai_client.chat.completions.create.return_value = mock_response
|
||||
mock_openai_class.return_value = mock_openai_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'LLaMA_CPP',
|
||||
'llamafile': 'http://localhost:8080/v1',
|
||||
'temperature': 0.0, # Default temperature
|
||||
'max_output': 4096,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act - Override temperature
|
||||
result = await processor.generate_content("System", "Prompt", temperature=0.7)
|
||||
|
||||
# Assert
|
||||
assert result.text == "Response with temperature override"
|
||||
|
||||
# Verify the API call was made with overridden temperature
|
||||
call_args = mock_openai_client.chat.completions.create.call_args
|
||||
assert call_args[1]['temperature'] == 0.7
|
||||
|
||||
@patch('trustgraph.model.text_completion.llamafile.llm.OpenAI')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_with_both_parameters_override(self, mock_llm_init, mock_async_init, mock_openai_class):
|
||||
"""Test generate_content with both model and temperature overrides"""
|
||||
# Arrange
|
||||
mock_openai_client = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices = [MagicMock()]
|
||||
mock_response.choices[0].message.content = "Response with both parameters override"
|
||||
mock_response.usage.prompt_tokens = 20
|
||||
mock_response.usage.completion_tokens = 15
|
||||
|
||||
mock_openai_client.chat.completions.create.return_value = mock_response
|
||||
mock_openai_class.return_value = mock_openai_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'LLaMA_CPP',
|
||||
'llamafile': 'http://localhost:8080/v1',
|
||||
'temperature': 0.0,
|
||||
'max_output': 4096,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act - Override both parameters
|
||||
result = await processor.generate_content("System", "Prompt", model="override-model", temperature=0.8)
|
||||
|
||||
# Assert
|
||||
assert result.model == "override-model"
|
||||
assert result.text == "Response with both parameters override"
|
||||
|
||||
# Verify the API call was made with overridden parameters
|
||||
call_args = mock_openai_client.chat.completions.create.call_args
|
||||
assert call_args[1]['model'] == "override-model"
|
||||
assert call_args[1]['temperature'] == 0.8
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pytest.main([__file__])
|
||||
229
tests/unit/test_text_completion/test_lmstudio_processor.py
Normal file
|
|
@@ -0,0 +1,229 @@
|
|||
"""
|
||||
Unit tests for trustgraph.model.text_completion.lmstudio
|
||||
Following the same successful pattern as previous tests
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from unittest import IsolatedAsyncioTestCase
|
||||
|
||||
# Import the service under test
|
||||
from trustgraph.model.text_completion.lmstudio.llm import Processor
|
||||
from trustgraph.base import LlmResult
|
||||
from trustgraph.exceptions import TooManyRequests
|
||||
|
||||
|
||||
class TestLMStudioProcessorSimple(IsolatedAsyncioTestCase):
|
||||
"""Test LMStudio processor functionality"""
|
||||
|
||||
@patch('trustgraph.model.text_completion.lmstudio.llm.OpenAI')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_processor_initialization_basic(self, mock_llm_init, mock_async_init, mock_openai_class):
|
||||
"""Test basic processor initialization"""
|
||||
# Arrange
|
||||
mock_openai = MagicMock()
|
||||
mock_openai_class.return_value = mock_openai
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'gemma3:9b',
|
||||
'url': 'http://localhost:1234/',
|
||||
'temperature': 0.0,
|
||||
'max_output': 4096,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
# Act
|
||||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
assert processor.default_model == 'gemma3:9b'
|
||||
assert processor.url == 'http://localhost:1234/v1/'
|
||||
assert processor.temperature == 0.0
|
||||
assert processor.max_output == 4096
|
||||
assert hasattr(processor, 'openai')
|
||||
mock_openai_class.assert_called_once_with(
|
||||
base_url='http://localhost:1234/v1/',
|
||||
api_key='sk-no-key-required'
|
||||
)
|
||||
|
||||
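Note that the config above passes url='http://localhost:1234/' yet the assert expects processor.url to end in v1/, which implies the processor normalises the base URL before constructing the OpenAI client. A hedged sketch of such a normalisation; the exact rule trustgraph applies is not shown in this diff:

def normalise_lmstudio_url(url):
    # Assumed rule: guarantee a trailing slash, then the OpenAI-style "v1/" suffix
    if not url.endswith("/"):
        url += "/"
    if not url.endswith("v1/"):
        url += "v1/"
    return url

assert normalise_lmstudio_url("http://localhost:1234/") == "http://localhost:1234/v1/"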
@patch('trustgraph.model.text_completion.lmstudio.llm.OpenAI')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_success(self, mock_llm_init, mock_async_init, mock_openai_class):
|
||||
"""Test successful content generation"""
|
||||
# Arrange
|
||||
mock_openai = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices[0].message.content = 'Generated response from LMStudio'
|
||||
mock_response.usage.prompt_tokens = 20
|
||||
mock_response.usage.completion_tokens = 12
|
||||
|
||||
mock_openai.chat.completions.create.return_value = mock_response
|
||||
mock_openai_class.return_value = mock_openai
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'gemma3:9b',
|
||||
'url': 'http://localhost:1234/',
|
||||
'temperature': 0.0,
|
||||
'max_output': 4096,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act
|
||||
result = await processor.generate_content("System prompt", "User prompt")
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Generated response from LMStudio"
|
||||
assert result.in_token == 20
|
||||
assert result.out_token == 12
|
||||
assert result.model == 'gemma3:9b'
|
||||
|
||||
# Verify the API call was made correctly
|
||||
mock_openai.chat.completions.create.assert_called_once()
|
||||
call_args = mock_openai.chat.completions.create.call_args
|
||||
|
||||
# Check model and temperature
|
||||
assert call_args[1]['model'] == 'gemma3:9b'
|
||||
assert call_args[1]['temperature'] == 0.0
|
||||
assert call_args[1]['max_tokens'] == 4096
|
||||
|
||||
@patch('trustgraph.model.text_completion.lmstudio.llm.OpenAI')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_with_model_override(self, mock_llm_init, mock_async_init, mock_openai_class):
|
||||
"""Test generate_content with model parameter override"""
|
||||
# Arrange
|
||||
mock_openai = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices[0].message.content = 'Response from overridden model'
|
||||
mock_response.usage.prompt_tokens = 15
|
||||
mock_response.usage.completion_tokens = 10
|
||||
|
||||
mock_openai.chat.completions.create.return_value = mock_response
|
||||
mock_openai_class.return_value = mock_openai
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'gemma3:9b',
|
||||
'url': 'http://localhost:1234/',
|
||||
'temperature': 0.0,
|
||||
'max_output': 4096,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act - Override model
|
||||
result = await processor.generate_content("System", "Prompt", model="custom-lmstudio-model")
|
||||
|
||||
# Assert
|
||||
assert result.model == "custom-lmstudio-model" # Should use overridden model
|
||||
assert result.text == "Response from overridden model"
|
||||
|
||||
# Verify the API call was made with overridden model
|
||||
call_args = mock_openai.chat.completions.create.call_args
|
||||
assert call_args[1]['model'] == "custom-lmstudio-model"
|
||||
|
||||
@patch('trustgraph.model.text_completion.lmstudio.llm.OpenAI')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_with_temperature_override(self, mock_llm_init, mock_async_init, mock_openai_class):
|
||||
"""Test generate_content with temperature parameter override"""
|
||||
# Arrange
|
||||
mock_openai = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices[0].message.content = 'Response with temperature override'
|
||||
mock_response.usage.prompt_tokens = 18
|
||||
mock_response.usage.completion_tokens = 12
|
||||
|
||||
mock_openai.chat.completions.create.return_value = mock_response
|
||||
mock_openai_class.return_value = mock_openai
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'gemma3:9b',
|
||||
'url': 'http://localhost:1234/',
|
||||
'temperature': 0.0, # Default temperature
|
||||
'max_output': 4096,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act - Override temperature
|
||||
result = await processor.generate_content("System", "Prompt", temperature=0.7)
|
||||
|
||||
# Assert
|
||||
assert result.text == "Response with temperature override"
|
||||
|
||||
# Verify the API call was made with overridden temperature
|
||||
call_args = mock_openai.chat.completions.create.call_args
|
||||
assert call_args[1]['temperature'] == 0.7
|
||||
|
||||
@patch('trustgraph.model.text_completion.lmstudio.llm.OpenAI')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_with_both_parameters_override(self, mock_llm_init, mock_async_init, mock_openai_class):
|
||||
"""Test generate_content with both model and temperature overrides"""
|
||||
# Arrange
|
||||
mock_openai = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices[0].message.content = 'Response with both parameters override'
|
||||
mock_response.usage.prompt_tokens = 20
|
||||
mock_response.usage.completion_tokens = 15
|
||||
|
||||
mock_openai.chat.completions.create.return_value = mock_response
|
||||
mock_openai_class.return_value = mock_openai
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'gemma3:9b',
|
||||
'url': 'http://localhost:1234/',
|
||||
'temperature': 0.0,
|
||||
'max_output': 4096,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act - Override both parameters
|
||||
result = await processor.generate_content("System", "Prompt", model="override-model", temperature=0.8)
|
||||
|
||||
# Assert
|
||||
assert result.model == "override-model"
|
||||
assert result.text == "Response with both parameters override"
|
||||
|
||||
# Verify the API call was made with overridden parameters
|
||||
call_args = mock_openai.chat.completions.create.call_args
|
||||
assert call_args[1]['model'] == "override-model"
|
||||
assert call_args[1]['temperature'] == 0.8
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pytest.main([__file__])
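A minimal sketch of the default-versus-override resolution these LM Studio tests describe, assuming only that a runtime argument takes precedence and that None falls back to the processor defaults; SketchProcessor and resolve are illustrative names, not the repository's API.

class SketchProcessor:
    def __init__(self, default_model, temperature):
        self.default_model = default_model
        self.temperature = temperature

    def resolve(self, model=None, temperature=None):
        # A runtime argument wins; None means "use the configured default".
        resolved_model = model or self.default_model
        resolved_temperature = temperature if temperature is not None else self.temperature
        return resolved_model, resolved_temperature

p = SketchProcessor("gemma3:9b", 0.0)
assert p.resolve() == ("gemma3:9b", 0.0)
assert p.resolve(model="custom-lmstudio-model", temperature=0.7) == ("custom-lmstudio-model", 0.7)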
275  tests/unit/test_text_completion/test_mistral_processor.py  Normal file
@@ -0,0 +1,275 @@
"""
|
||||
Unit tests for trustgraph.model.text_completion.mistral
|
||||
Following the same successful pattern as other processor tests
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from unittest import IsolatedAsyncioTestCase
|
||||
|
||||
# Import the service under test
|
||||
from trustgraph.model.text_completion.mistral.llm import Processor
|
||||
from trustgraph.base import LlmResult
|
||||
|
||||
|
||||
class TestMistralProcessorSimple(IsolatedAsyncioTestCase):
|
||||
"""Test Mistral processor functionality"""
|
||||
|
||||
@patch('trustgraph.model.text_completion.mistral.llm.Mistral')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_processor_initialization_basic(self, mock_llm_init, mock_async_init, mock_mistral_class):
|
||||
"""Test basic processor initialization"""
|
||||
# Arrange
|
||||
mock_mistral_client = MagicMock()
|
||||
mock_mistral_class.return_value = mock_mistral_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'ministral-8b-latest',
|
||||
'api_key': 'test-api-key',
|
||||
'temperature': 0.1,
|
||||
'max_output': 2048,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
# Act
|
||||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
assert processor.default_model == 'ministral-8b-latest'
|
||||
assert processor.temperature == 0.1
|
||||
assert processor.max_output == 2048
|
||||
assert hasattr(processor, 'mistral')
|
||||
mock_mistral_class.assert_called_once_with(api_key='test-api-key')
|
||||
|
||||
@patch('trustgraph.model.text_completion.mistral.llm.Mistral')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_success(self, mock_llm_init, mock_async_init, mock_mistral_class):
|
||||
"""Test successful content generation"""
|
||||
# Arrange
|
||||
mock_mistral_client = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices[0].message.content = 'Generated response from Mistral'
|
||||
mock_response.usage.prompt_tokens = 15
|
||||
mock_response.usage.completion_tokens = 8
|
||||
mock_mistral_client.chat.complete.return_value = mock_response
|
||||
mock_mistral_class.return_value = mock_mistral_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'ministral-8b-latest',
|
||||
'api_key': 'test-api-key',
|
||||
'temperature': 0.0,
|
||||
'max_output': 4096,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act
|
||||
result = await processor.generate_content("System prompt", "User prompt")
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Generated response from Mistral"
|
||||
assert result.in_token == 15
|
||||
assert result.out_token == 8
|
||||
assert result.model == 'ministral-8b-latest'
|
||||
mock_mistral_client.chat.complete.assert_called_once()
|
||||
|
||||
@patch('trustgraph.model.text_completion.mistral.llm.Mistral')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_temperature_override(self, mock_llm_init, mock_async_init, mock_mistral_class):
|
||||
"""Test temperature parameter override functionality"""
|
||||
# Arrange
|
||||
mock_mistral_client = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices[0].message.content = 'Response with custom temperature'
|
||||
mock_response.usage.prompt_tokens = 20
|
||||
mock_response.usage.completion_tokens = 12
|
||||
mock_mistral_client.chat.complete.return_value = mock_response
|
||||
mock_mistral_class.return_value = mock_mistral_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'ministral-8b-latest',
|
||||
'api_key': 'test-api-key',
|
||||
'temperature': 0.0, # Default temperature
|
||||
'max_output': 4096,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act - Override temperature at runtime
|
||||
result = await processor.generate_content(
|
||||
"System prompt",
|
||||
"User prompt",
|
||||
model=None, # Use default model
|
||||
temperature=0.8 # Override temperature
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Response with custom temperature"
|
||||
|
||||
# Verify Mistral API was called with overridden temperature
|
||||
call_args = mock_mistral_client.chat.complete.call_args
|
||||
assert call_args[1]['temperature'] == 0.8 # Should use runtime override
|
||||
assert call_args[1]['model'] == 'ministral-8b-latest'
|
||||
|
||||
@patch('trustgraph.model.text_completion.mistral.llm.Mistral')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_model_override(self, mock_llm_init, mock_async_init, mock_mistral_class):
|
||||
"""Test model parameter override functionality"""
|
||||
# Arrange
|
||||
mock_mistral_client = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices[0].message.content = 'Response with custom model'
|
||||
mock_response.usage.prompt_tokens = 18
|
||||
mock_response.usage.completion_tokens = 14
|
||||
mock_mistral_client.chat.complete.return_value = mock_response
|
||||
mock_mistral_class.return_value = mock_mistral_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'ministral-8b-latest', # Default model
|
||||
'api_key': 'test-api-key',
|
||||
'temperature': 0.1, # Default temperature
|
||||
'max_output': 4096,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act - Override model at runtime
|
||||
result = await processor.generate_content(
|
||||
"System prompt",
|
||||
"User prompt",
|
||||
model="mistral-large-latest", # Override model
|
||||
temperature=None # Use default temperature
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Response with custom model"
|
||||
|
||||
# Verify Mistral API was called with overridden model
|
||||
call_args = mock_mistral_client.chat.complete.call_args
|
||||
assert call_args[1]['model'] == 'mistral-large-latest' # Should use runtime override
|
||||
assert call_args[1]['temperature'] == 0.1 # Should use processor default
|
||||
|
||||
@patch('trustgraph.model.text_completion.mistral.llm.Mistral')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_both_parameters_override(self, mock_llm_init, mock_async_init, mock_mistral_class):
|
||||
"""Test overriding both model and temperature parameters simultaneously"""
|
||||
# Arrange
|
||||
mock_mistral_client = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices[0].message.content = 'Response with both overrides'
|
||||
mock_response.usage.prompt_tokens = 22
|
||||
mock_response.usage.completion_tokens = 16
|
||||
mock_mistral_client.chat.complete.return_value = mock_response
|
||||
mock_mistral_class.return_value = mock_mistral_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'ministral-8b-latest', # Default model
|
||||
'api_key': 'test-api-key',
|
||||
'temperature': 0.0, # Default temperature
|
||||
'max_output': 4096,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act - Override both parameters at runtime
|
||||
result = await processor.generate_content(
|
||||
"System prompt",
|
||||
"User prompt",
|
||||
model="mistral-large-latest", # Override model
|
||||
temperature=0.9 # Override temperature
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Response with both overrides"
|
||||
|
||||
# Verify Mistral API was called with both overrides
|
||||
call_args = mock_mistral_client.chat.complete.call_args
|
||||
assert call_args[1]['model'] == 'mistral-large-latest' # Should use runtime override
|
||||
assert call_args[1]['temperature'] == 0.9 # Should use runtime override
|
||||
|
||||
@patch('trustgraph.model.text_completion.mistral.llm.Mistral')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_prompt_construction(self, mock_llm_init, mock_async_init, mock_mistral_class):
|
||||
"""Test prompt construction with system and user prompts"""
|
||||
# Arrange
|
||||
mock_mistral_client = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices[0].message.content = 'Response with system instructions'
|
||||
mock_response.usage.prompt_tokens = 25
|
||||
mock_response.usage.completion_tokens = 15
|
||||
mock_mistral_client.chat.complete.return_value = mock_response
|
||||
mock_mistral_class.return_value = mock_mistral_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'ministral-8b-latest',
|
||||
'api_key': 'test-api-key',
|
||||
'temperature': 0.0,
|
||||
'max_output': 4096,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act
|
||||
result = await processor.generate_content("You are a helpful assistant", "What is AI?")
|
||||
|
||||
# Assert
|
||||
assert result.text == "Response with system instructions"
|
||||
assert result.in_token == 25
|
||||
assert result.out_token == 15
|
||||
|
||||
# Verify the combined prompt structure
|
||||
call_args = mock_mistral_client.chat.complete.call_args
|
||||
messages = call_args[1]['messages']
|
||||
assert len(messages) == 1
|
||||
assert messages[0]['role'] == 'user'
|
||||
assert messages[0]['content'][0]['type'] == 'text'
|
||||
assert messages[0]['content'][0]['text'] == "You are a helpful assistant\n\nWhat is AI?"
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pytest.main([__file__])
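The prompt-construction test above pins down the payload shape for Mistral: the system and user prompts are concatenated with a blank line and sent as a single user turn. A hypothetical helper showing that shape (build_mistral_messages is not a repository function):

def build_mistral_messages(system, prompt):
    # One user message whose text is "<system>\n\n<prompt>", as asserted above.
    return [{
        "role": "user",
        "content": [{"type": "text", "text": f"{system}\n\n{prompt}"}],
    }]

messages = build_mistral_messages("You are a helpful assistant", "What is AI?")
assert messages[0]["content"][0]["text"] == "You are a helpful assistant\n\nWhat is AI?"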
@@ -40,7 +40,7 @@ class TestOllamaProcessorSimple(IsolatedAsyncioTestCase):
        processor = Processor(**config)

        # Assert
-       assert processor.model == 'llama2'
+       assert processor.default_model == 'llama2'
        assert hasattr(processor, 'llm')
        mock_client_class.assert_called_once_with(host='http://localhost:11434')

@@ -81,7 +81,7 @@ class TestOllamaProcessorSimple(IsolatedAsyncioTestCase):
        assert result.in_token == 15
        assert result.out_token == 8
        assert result.model == 'llama2'
-       mock_client.generate.assert_called_once_with('llama2', "System prompt\n\nUser prompt")
+       mock_client.generate.assert_called_once_with('llama2', "System prompt\n\nUser prompt", options={'temperature': 0.0})

    @patch('trustgraph.model.text_completion.ollama.llm.Client')
    @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')

@@ -134,7 +134,7 @@ class TestOllamaProcessorSimple(IsolatedAsyncioTestCase):
        processor = Processor(**config)

        # Assert
-       assert processor.model == 'mistral'
+       assert processor.default_model == 'mistral'
        mock_client_class.assert_called_once_with(host='http://192.168.1.100:11434')

    @patch('trustgraph.model.text_completion.ollama.llm.Client')

@@ -160,7 +160,7 @@ class TestOllamaProcessorSimple(IsolatedAsyncioTestCase):
        processor = Processor(**config)

        # Assert
-       assert processor.model == 'gemma2:9b'  # default_model
+       assert processor.default_model == 'gemma2:9b'  # default_model
        # Should use default_ollama (http://localhost:11434 or from OLLAMA_HOST env)
        mock_client_class.assert_called_once()

@@ -203,7 +203,7 @@ class TestOllamaProcessorSimple(IsolatedAsyncioTestCase):
        assert result.model == 'llama2'

        # The prompt should be "" + "\n\n" + "" = "\n\n"
-       mock_client.generate.assert_called_once_with('llama2', "\n\n")
+       mock_client.generate.assert_called_once_with('llama2', "\n\n", options={'temperature': 0.0})

    @patch('trustgraph.model.text_completion.ollama.llm.Client')
    @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')

@@ -310,7 +310,151 @@ class TestOllamaProcessorSimple(IsolatedAsyncioTestCase):
        assert result.out_token == 15

        # Verify the combined prompt
-       mock_client.generate.assert_called_once_with('llama2', "You are a helpful assistant\n\nWhat is AI?")
+       mock_client.generate.assert_called_once_with('llama2', "You are a helpful assistant\n\nWhat is AI?", options={'temperature': 0.0})

@patch('trustgraph.model.text_completion.ollama.llm.Client')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_temperature_override(self, mock_llm_init, mock_async_init, mock_client_class):
|
||||
"""Test temperature parameter override functionality"""
|
||||
# Arrange
|
||||
mock_client = MagicMock()
|
||||
mock_response = {
|
||||
'response': 'Response with custom temperature',
|
||||
'prompt_eval_count': 20,
|
||||
'eval_count': 12
|
||||
}
|
||||
mock_client.generate.return_value = mock_response
|
||||
mock_client_class.return_value = mock_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'llama2',
|
||||
'ollama': 'http://localhost:11434',
|
||||
'temperature': 0.0, # Default temperature
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act - Override temperature at runtime
|
||||
result = await processor.generate_content(
|
||||
"System prompt",
|
||||
"User prompt",
|
||||
model=None, # Use default model
|
||||
temperature=0.8 # Override temperature
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Response with custom temperature"
|
||||
|
||||
# Verify Ollama API was called with overridden temperature
|
||||
mock_client.generate.assert_called_once_with(
|
||||
'llama2',
|
||||
"System prompt\n\nUser prompt",
|
||||
options={'temperature': 0.8} # Should use runtime override
|
||||
)
|
||||
|
||||
@patch('trustgraph.model.text_completion.ollama.llm.Client')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_model_override(self, mock_llm_init, mock_async_init, mock_client_class):
|
||||
"""Test model parameter override functionality"""
|
||||
# Arrange
|
||||
mock_client = MagicMock()
|
||||
mock_response = {
|
||||
'response': 'Response with custom model',
|
||||
'prompt_eval_count': 18,
|
||||
'eval_count': 14
|
||||
}
|
||||
mock_client.generate.return_value = mock_response
|
||||
mock_client_class.return_value = mock_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'llama2', # Default model
|
||||
'ollama': 'http://localhost:11434',
|
||||
'temperature': 0.1, # Default temperature
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act - Override model at runtime
|
||||
result = await processor.generate_content(
|
||||
"System prompt",
|
||||
"User prompt",
|
||||
model="mistral", # Override model
|
||||
temperature=None # Use default temperature
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Response with custom model"
|
||||
|
||||
# Verify Ollama API was called with overridden model
|
||||
mock_client.generate.assert_called_once_with(
|
||||
'mistral', # Should use runtime override
|
||||
"System prompt\n\nUser prompt",
|
||||
options={'temperature': 0.1} # Should use processor default
|
||||
)
|
||||
|
||||
@patch('trustgraph.model.text_completion.ollama.llm.Client')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_both_parameters_override(self, mock_llm_init, mock_async_init, mock_client_class):
|
||||
"""Test overriding both model and temperature parameters simultaneously"""
|
||||
# Arrange
|
||||
mock_client = MagicMock()
|
||||
mock_response = {
|
||||
'response': 'Response with both overrides',
|
||||
'prompt_eval_count': 22,
|
||||
'eval_count': 16
|
||||
}
|
||||
mock_client.generate.return_value = mock_response
|
||||
mock_client_class.return_value = mock_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'llama2', # Default model
|
||||
'ollama': 'http://localhost:11434',
|
||||
'temperature': 0.0, # Default temperature
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act - Override both parameters at runtime
|
||||
result = await processor.generate_content(
|
||||
"System prompt",
|
||||
"User prompt",
|
||||
model="codellama", # Override model
|
||||
temperature=0.9 # Override temperature
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Response with both overrides"
|
||||
|
||||
# Verify Ollama API was called with both overrides
|
||||
mock_client.generate.assert_called_once_with(
|
||||
'codellama', # Should use runtime override
|
||||
"System prompt\n\nUser prompt",
|
||||
options={'temperature': 0.9} # Should use runtime override
|
||||
)
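The Ollama assertions above all reduce to one call shape: client.generate(model, "system\n\nuser", options={'temperature': t}). A sketch of that call with hypothetical fallback logic (call_ollama is an illustrative name, not the production code path):

def call_ollama(client, default_model, default_temperature,
                system, prompt, model=None, temperature=None):
    resolved_model = model or default_model
    resolved_temperature = temperature if temperature is not None else default_temperature
    # e.g. client.generate('codellama', "System prompt\n\nUser prompt", options={'temperature': 0.9})
    return client.generate(
        resolved_model,
        f"{system}\n\n{prompt}",
        options={"temperature": resolved_temperature},
    )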
|
||||
|
||||
|
||||
if __name__ == '__main__':
@@ -43,7 +43,7 @@ class TestOpenAIProcessorSimple(IsolatedAsyncioTestCase):
        processor = Processor(**config)

        # Assert
-       assert processor.model == 'gpt-3.5-turbo'
+       assert processor.default_model == 'gpt-3.5-turbo'
        assert processor.temperature == 0.0
        assert processor.max_output == 4096
        assert hasattr(processor, 'openai')

@@ -222,7 +222,7 @@ class TestOpenAIProcessorSimple(IsolatedAsyncioTestCase):
        processor = Processor(**config)

        # Assert
-       assert processor.model == 'gpt-4'
+       assert processor.default_model == 'gpt-4'
        assert processor.temperature == 0.7
        assert processor.max_output == 2048
        mock_openai_class.assert_called_once_with(base_url='https://custom-openai-url.com/v1', api_key='custom-api-key')

@@ -251,7 +251,7 @@ class TestOpenAIProcessorSimple(IsolatedAsyncioTestCase):
        processor = Processor(**config)

        # Assert
-       assert processor.model == 'gpt-3.5-turbo'  # default_model
+       assert processor.default_model == 'gpt-3.5-turbo'  # default_model
        assert processor.temperature == 0.0  # default_temperature
        assert processor.max_output == 4096  # default_max_output
        mock_openai_class.assert_called_once_with(base_url='https://api.openai.com/v1', api_key='test-api-key')

@@ -391,5 +391,210 @@ class TestOpenAIProcessorSimple(IsolatedAsyncioTestCase):
        assert call_args[1]['response_format'] == {"type": "text"}

@patch('trustgraph.model.text_completion.openai.llm.OpenAI')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_temperature_override(self, mock_llm_init, mock_async_init, mock_openai_class):
|
||||
"""Test temperature parameter override functionality"""
|
||||
# Arrange
|
||||
mock_openai_client = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices = [MagicMock()]
|
||||
mock_response.choices[0].message.content = "Response with custom temperature"
|
||||
mock_response.usage.prompt_tokens = 15
|
||||
mock_response.usage.completion_tokens = 10
|
||||
|
||||
mock_openai_client.chat.completions.create.return_value = mock_response
|
||||
mock_openai_class.return_value = mock_openai_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'gpt-3.5-turbo',
|
||||
'api_key': 'test-api-key',
|
||||
'url': 'https://api.openai.com/v1',
|
||||
'temperature': 0.0, # Default temperature
|
||||
'max_output': 4096,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act - Override temperature at runtime
|
||||
result = await processor.generate_content(
|
||||
"System prompt",
|
||||
"User prompt",
|
||||
model=None, # Use default model
|
||||
temperature=0.9 # Override temperature
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Response with custom temperature"
|
||||
|
||||
# Verify the OpenAI API was called with overridden temperature
|
||||
mock_openai_client.chat.completions.create.assert_called_once()
|
||||
call_kwargs = mock_openai_client.chat.completions.create.call_args.kwargs
|
||||
|
||||
assert call_kwargs['temperature'] == 0.9 # Should use runtime override
|
||||
assert call_kwargs['model'] == 'gpt-3.5-turbo' # Should use processor default
|
||||
|
||||
@patch('trustgraph.model.text_completion.openai.llm.OpenAI')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_model_override(self, mock_llm_init, mock_async_init, mock_openai_class):
|
||||
"""Test model parameter override functionality"""
|
||||
# Arrange
|
||||
mock_openai_client = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices = [MagicMock()]
|
||||
mock_response.choices[0].message.content = "Response with custom model"
|
||||
mock_response.usage.prompt_tokens = 15
|
||||
mock_response.usage.completion_tokens = 10
|
||||
|
||||
mock_openai_client.chat.completions.create.return_value = mock_response
|
||||
mock_openai_class.return_value = mock_openai_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'gpt-3.5-turbo', # Default model
|
||||
'api_key': 'test-api-key',
|
||||
'url': 'https://api.openai.com/v1',
|
||||
'temperature': 0.2,
|
||||
'max_output': 4096,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act - Override model at runtime
|
||||
result = await processor.generate_content(
|
||||
"System prompt",
|
||||
"User prompt",
|
||||
model="gpt-4", # Override model
|
||||
temperature=None # Use default temperature
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Response with custom model"
|
||||
|
||||
# Verify the OpenAI API was called with overridden model
|
||||
mock_openai_client.chat.completions.create.assert_called_once()
|
||||
call_kwargs = mock_openai_client.chat.completions.create.call_args.kwargs
|
||||
|
||||
assert call_kwargs['model'] == 'gpt-4' # Should use runtime override
|
||||
assert call_kwargs['temperature'] == 0.2 # Should use processor default
|
||||
|
||||
@patch('trustgraph.model.text_completion.openai.llm.OpenAI')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_both_parameters_override(self, mock_llm_init, mock_async_init, mock_openai_class):
|
||||
"""Test overriding both model and temperature parameters simultaneously"""
|
||||
# Arrange
|
||||
mock_openai_client = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices = [MagicMock()]
|
||||
mock_response.choices[0].message.content = "Response with both overrides"
|
||||
mock_response.usage.prompt_tokens = 15
|
||||
mock_response.usage.completion_tokens = 10
|
||||
|
||||
mock_openai_client.chat.completions.create.return_value = mock_response
|
||||
mock_openai_class.return_value = mock_openai_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'gpt-3.5-turbo', # Default model
|
||||
'api_key': 'test-api-key',
|
||||
'url': 'https://api.openai.com/v1',
|
||||
'temperature': 0.0, # Default temperature
|
||||
'max_output': 4096,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act - Override both parameters at runtime
|
||||
result = await processor.generate_content(
|
||||
"System prompt",
|
||||
"User prompt",
|
||||
model="gpt-4", # Override model
|
||||
temperature=0.7 # Override temperature
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Response with both overrides"
|
||||
|
||||
# Verify the OpenAI API was called with both overrides
|
||||
mock_openai_client.chat.completions.create.assert_called_once()
|
||||
call_kwargs = mock_openai_client.chat.completions.create.call_args.kwargs
|
||||
|
||||
assert call_kwargs['model'] == 'gpt-4' # Should use runtime override
|
||||
assert call_kwargs['temperature'] == 0.7 # Should use runtime override
|
||||
|
||||
@patch('trustgraph.model.text_completion.openai.llm.OpenAI')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_no_override_uses_defaults(self, mock_llm_init, mock_async_init, mock_openai_class):
|
||||
"""Test that when no parameters are overridden, processor defaults are used"""
|
||||
# Arrange
|
||||
mock_openai_client = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices = [MagicMock()]
|
||||
mock_response.choices[0].message.content = "Response with defaults"
|
||||
mock_response.usage.prompt_tokens = 15
|
||||
mock_response.usage.completion_tokens = 10
|
||||
|
||||
mock_openai_client.chat.completions.create.return_value = mock_response
|
||||
mock_openai_class.return_value = mock_openai_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'gpt-4', # Default model
|
||||
'api_key': 'test-api-key',
|
||||
'url': 'https://api.openai.com/v1',
|
||||
'temperature': 0.5, # Default temperature
|
||||
'max_output': 4096,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act - Don't override any parameters (pass None)
|
||||
result = await processor.generate_content(
|
||||
"System prompt",
|
||||
"User prompt",
|
||||
model=None, # Use default model
|
||||
temperature=None # Use default temperature
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Response with defaults"
|
||||
|
||||
# Verify the OpenAI API was called with processor defaults
|
||||
mock_openai_client.chat.completions.create.assert_called_once()
|
||||
call_kwargs = mock_openai_client.chat.completions.create.call_args.kwargs
|
||||
|
||||
assert call_kwargs['model'] == 'gpt-4' # Should use processor default
|
||||
assert call_kwargs['temperature'] == 0.5 # Should use processor default
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pytest.main([__file__])
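These OpenAI tests read keyword arguments through call_args.kwargs; a small standalone illustration of that mock-inspection pattern (not taken from the repository):

from unittest.mock import MagicMock

client = MagicMock()
client.chat.completions.create(model="gpt-4", temperature=0.7, messages=[])
kwargs = client.chat.completions.create.call_args.kwargs
assert kwargs["model"] == "gpt-4"
assert kwargs["temperature"] == 0.7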
186  tests/unit/test_text_completion/test_parameter_caching.py  Normal file
@@ -0,0 +1,186 @@
"""
|
||||
Unit tests for Parameter-Based Caching in LLM Processors
|
||||
Testing processors that cache based on temperature parameters (Bedrock, GoogleAIStudio)
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from unittest import IsolatedAsyncioTestCase
|
||||
|
||||
from trustgraph.model.text_completion.googleaistudio.llm import Processor as GoogleAIProcessor
|
||||
from trustgraph.base import LlmResult
|
||||
|
||||
|
||||
class TestParameterCaching(IsolatedAsyncioTestCase):
|
||||
"""Test parameter-based caching functionality"""
|
||||
|
||||
@patch('trustgraph.model.text_completion.googleaistudio.llm.genai')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_googleai_temperature_cache_keys(self, mock_llm_init, mock_async_init, mock_genai):
|
||||
"""Test that GoogleAI processor creates separate cache entries for different temperatures"""
|
||||
# Arrange
|
||||
mock_client = MagicMock()
|
||||
mock_genai.Client.return_value = mock_client
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.text = "Generated response"
|
||||
mock_response.usage_metadata.prompt_token_count = 10
|
||||
mock_response.usage_metadata.candidates_token_count = 5
|
||||
mock_client.models.generate_content.return_value = mock_response
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'gemini-2.0-flash-001',
|
||||
'api_key': 'test-api-key',
|
||||
'temperature': 0.0, # Default temperature
|
||||
'max_output': 1024,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = GoogleAIProcessor(**config)
|
||||
|
||||
# Act - Call with different temperatures
|
||||
await processor.generate_content("System", "Prompt 1", model="gemini-2.0-flash-001", temperature=0.0)
|
||||
await processor.generate_content("System", "Prompt 2", model="gemini-2.0-flash-001", temperature=0.5)
|
||||
await processor.generate_content("System", "Prompt 3", model="gemini-2.0-flash-001", temperature=1.0)
|
||||
|
||||
# Assert - Should have 3 different cache entries
|
||||
cache_keys = list(processor.generation_configs.keys())
|
||||
|
||||
assert len(cache_keys) == 3
|
||||
assert "gemini-2.0-flash-001:0.0" in cache_keys
|
||||
assert "gemini-2.0-flash-001:0.5" in cache_keys
|
||||
assert "gemini-2.0-flash-001:1.0" in cache_keys
|
||||
|
||||
# Verify each cached config has the correct temperature
|
||||
assert processor.generation_configs["gemini-2.0-flash-001:0.0"].temperature == 0.0
|
||||
assert processor.generation_configs["gemini-2.0-flash-001:0.5"].temperature == 0.5
|
||||
assert processor.generation_configs["gemini-2.0-flash-001:1.0"].temperature == 1.0
|
||||
|
||||
@patch('trustgraph.model.text_completion.googleaistudio.llm.genai')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_googleai_cache_reuse_same_parameters(self, mock_llm_init, mock_async_init, mock_genai):
|
||||
"""Test that GoogleAI processor reuses cache for identical model+temperature combinations"""
|
||||
# Arrange
|
||||
mock_client = MagicMock()
|
||||
mock_genai.Client.return_value = mock_client
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.text = "Generated response"
|
||||
mock_response.usage_metadata.prompt_token_count = 10
|
||||
mock_response.usage_metadata.candidates_token_count = 5
|
||||
mock_client.models.generate_content.return_value = mock_response
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'gemini-2.0-flash-001',
|
||||
'api_key': 'test-api-key',
|
||||
'temperature': 0.0,
|
||||
'max_output': 1024,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = GoogleAIProcessor(**config)
|
||||
|
||||
# Act - Call multiple times with same parameters
|
||||
await processor.generate_content("System", "Prompt 1", model="gemini-2.0-flash-001", temperature=0.7)
|
||||
await processor.generate_content("System", "Prompt 2", model="gemini-2.0-flash-001", temperature=0.7)
|
||||
await processor.generate_content("System", "Prompt 3", model="gemini-2.0-flash-001", temperature=0.7)
|
||||
|
||||
# Assert - Should have only 1 cache entry for the repeated parameters
|
||||
cache_keys = list(processor.generation_configs.keys())
|
||||
assert len(cache_keys) == 1
|
||||
assert "gemini-2.0-flash-001:0.7" in cache_keys
|
||||
|
||||
# The same config object should be reused
|
||||
config_obj = processor.generation_configs["gemini-2.0-flash-001:0.7"]
|
||||
assert config_obj.temperature == 0.7
|
||||
|
||||
@patch('trustgraph.model.text_completion.googleaistudio.llm.genai')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_googleai_different_models_separate_caches(self, mock_llm_init, mock_async_init, mock_genai):
|
||||
"""Test that different models create separate cache entries even with same temperature"""
|
||||
# Arrange
|
||||
mock_client = MagicMock()
|
||||
mock_genai.Client.return_value = mock_client
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.text = "Generated response"
|
||||
mock_response.usage_metadata.prompt_token_count = 10
|
||||
mock_response.usage_metadata.candidates_token_count = 5
|
||||
mock_client.models.generate_content.return_value = mock_response
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'gemini-2.0-flash-001',
|
||||
'api_key': 'test-api-key',
|
||||
'temperature': 0.0,
|
||||
'max_output': 1024,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = GoogleAIProcessor(**config)
|
||||
|
||||
# Act - Call with different models, same temperature
|
||||
await processor.generate_content("System", "Prompt 1", model="gemini-2.0-flash-001", temperature=0.5)
|
||||
await processor.generate_content("System", "Prompt 2", model="gemini-1.5-flash-001", temperature=0.5)
|
||||
|
||||
# Assert - Should have separate cache entries for different models
|
||||
cache_keys = list(processor.generation_configs.keys())
|
||||
assert len(cache_keys) == 2
|
||||
assert "gemini-2.0-flash-001:0.5" in cache_keys
|
||||
assert "gemini-1.5-flash-001:0.5" in cache_keys
|
||||
|
||||
# Note: Bedrock tests would be similar but testing the Bedrock processor's caching behavior
|
||||
# The Bedrock processor caches model variants with temperature in the cache key
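A hedged sketch of the "<model>:<temperature>" cache-key scheme these tests assert; ConfigCache, get_config and the dict-based config factory are illustrative stand-ins for whatever the GoogleAI Studio and Bedrock processors actually use.

class ConfigCache:
    def __init__(self):
        self.generation_configs = {}

    def get_config(self, model, temperature, factory):
        key = f"{model}:{temperature}"
        # Same model + temperature reuses the cached config; anything else creates a new entry.
        if key not in self.generation_configs:
            self.generation_configs[key] = factory(temperature)
        return self.generation_configs[key]

cache = ConfigCache()
cache.get_config("gemini-2.0-flash-001", 0.5, lambda t: {"temperature": t})
cache.get_config("gemini-1.5-flash-001", 0.5, lambda t: {"temperature": t})
cache.get_config("gemini-1.5-flash-001", 0.5, lambda t: {"temperature": t})  # cache hit, no new entry
assert sorted(cache.generation_configs) == ["gemini-1.5-flash-001:0.5", "gemini-2.0-flash-001:0.5"]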
|
||||
|
||||
async def test_bedrock_temperature_cache_keys(self):
|
||||
"""Test Bedrock processor temperature-aware caching"""
|
||||
# This would test the Bedrock processor's _get_or_create_variant method
|
||||
# with different temperature values to ensure proper cache key generation
|
||||
|
||||
# Implementation would follow similar pattern to GoogleAI tests above
|
||||
# but using the Bedrock processor and testing model_variants cache
|
||||
pass
|
||||
|
||||
async def test_bedrock_cache_isolation_different_temperatures(self):
|
||||
"""Test that Bedrock processor isolates cache entries by temperature"""
|
||||
pass
|
||||
|
||||
async def test_cache_memory_efficiency(self):
|
||||
"""Test that caches don't grow unbounded with many different parameter combinations"""
|
||||
# This could test cache size limits or cleanup behavior if implemented
|
||||
pass
|
||||
|
||||
|
||||
class TestCachePerformance(IsolatedAsyncioTestCase):
|
||||
"""Test caching performance characteristics"""
|
||||
|
||||
async def test_cache_hit_performance(self):
|
||||
"""Test that cache hits are faster than cache misses"""
|
||||
# This would measure timing differences between cache hits and misses
|
||||
pass
|
||||
|
||||
async def test_concurrent_cache_access(self):
|
||||
"""Test concurrent access to cached configurations"""
|
||||
# This would test thread-safety of cache access
|
||||
pass
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pytest.main([__file__])
271  tests/unit/test_text_completion/test_tgi_processor.py  Normal file
@@ -0,0 +1,271 @@
"""
|
||||
Unit tests for trustgraph.model.text_completion.tgi
|
||||
Following the same successful pattern as previous tests
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from unittest import IsolatedAsyncioTestCase
|
||||
|
||||
# Import the service under test
|
||||
from trustgraph.model.text_completion.tgi.llm import Processor
|
||||
from trustgraph.base import LlmResult
|
||||
from trustgraph.exceptions import TooManyRequests
|
||||
|
||||
|
||||
class TestTGIProcessorSimple(IsolatedAsyncioTestCase):
|
||||
"""Test TGI processor functionality"""
|
||||
|
||||
@patch('trustgraph.model.text_completion.tgi.llm.aiohttp.ClientSession')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_processor_initialization_basic(self, mock_llm_init, mock_async_init, mock_session_class):
|
||||
"""Test basic processor initialization"""
|
||||
# Arrange
|
||||
mock_session = MagicMock()
|
||||
mock_session_class.return_value = mock_session
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'tgi',
|
||||
'url': 'http://tgi-service:8899/v1',
|
||||
'temperature': 0.0,
|
||||
'max_output': 2048,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
# Act
|
||||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
assert processor.default_model == 'tgi'
|
||||
assert processor.base_url == 'http://tgi-service:8899/v1'
|
||||
assert processor.temperature == 0.0
|
||||
assert processor.max_output == 2048
|
||||
assert hasattr(processor, 'session')
|
||||
mock_session_class.assert_called_once()
|
||||
|
||||
@patch('trustgraph.model.text_completion.tgi.llm.aiohttp.ClientSession')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_success(self, mock_llm_init, mock_async_init, mock_session_class):
|
||||
"""Test successful content generation"""
|
||||
# Arrange
|
||||
mock_session = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.status = 200
|
||||
mock_response.json = AsyncMock(return_value={
|
||||
'choices': [{
|
||||
'message': {
|
||||
'content': 'Generated response from TGI'
|
||||
}
|
||||
}],
|
||||
'usage': {
|
||||
'prompt_tokens': 20,
|
||||
'completion_tokens': 12
|
||||
}
|
||||
})
|
||||
|
||||
# Mock the async context manager
|
||||
mock_session.post.return_value.__aenter__.return_value = mock_response
|
||||
mock_session.post.return_value.__aexit__.return_value = None
|
||||
mock_session_class.return_value = mock_session
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'tgi',
|
||||
'url': 'http://tgi-service:8899/v1',
|
||||
'temperature': 0.0,
|
||||
'max_output': 2048,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act
|
||||
result = await processor.generate_content("System prompt", "User prompt")
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Generated response from TGI"
|
||||
assert result.in_token == 20
|
||||
assert result.out_token == 12
|
||||
assert result.model == 'tgi'
|
||||
|
||||
# Verify the API call was made correctly
|
||||
mock_session.post.assert_called_once()
|
||||
call_args = mock_session.post.call_args
|
||||
|
||||
# Check URL
|
||||
assert call_args[0][0] == 'http://tgi-service:8899/v1/chat/completions'
|
||||
|
||||
# Check request structure
|
||||
request_body = call_args[1]['json']
|
||||
assert request_body['model'] == 'tgi'
|
||||
assert request_body['temperature'] == 0.0
|
||||
assert request_body['max_tokens'] == 2048
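The aiohttp mocking above works because, on Python 3.8 and later, MagicMock exposes __aenter__/__aexit__ as AsyncMock attributes, so async with session.post(...) awaits them transparently. A self-contained illustration of that pattern under the same assumption:

import asyncio
from unittest.mock import AsyncMock, MagicMock

async def demo():
    session = MagicMock()
    response = MagicMock()
    response.status = 200
    response.json = AsyncMock(return_value={"choices": []})
    session.post.return_value.__aenter__.return_value = response
    session.post.return_value.__aexit__.return_value = None

    # Mirrors how the processor under test would consume the session.
    async with session.post("http://tgi-service:8899/v1/chat/completions", json={}) as resp:
        body = await resp.json()
    return resp.status, body

print(asyncio.run(demo()))  # (200, {'choices': []})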
|
||||
|
||||
@patch('trustgraph.model.text_completion.tgi.llm.aiohttp.ClientSession')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_with_model_override(self, mock_llm_init, mock_async_init, mock_session_class):
|
||||
"""Test generate_content with model parameter override"""
|
||||
# Arrange
|
||||
mock_session = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.status = 200
|
||||
mock_response.json = AsyncMock(return_value={
|
||||
'choices': [{
|
||||
'message': {
|
||||
'content': 'Response from overridden model'
|
||||
}
|
||||
}],
|
||||
'usage': {
|
||||
'prompt_tokens': 15,
|
||||
'completion_tokens': 10
|
||||
}
|
||||
})
|
||||
|
||||
mock_session.post.return_value.__aenter__.return_value = mock_response
|
||||
mock_session.post.return_value.__aexit__.return_value = None
|
||||
mock_session_class.return_value = mock_session
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'tgi',
|
||||
'url': 'http://tgi-service:8899/v1',
|
||||
'temperature': 0.0,
|
||||
'max_output': 2048,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act - Override model
|
||||
result = await processor.generate_content("System", "Prompt", model="custom-tgi-model")
|
||||
|
||||
# Assert
|
||||
assert result.model == "custom-tgi-model" # Should use overridden model
|
||||
assert result.text == "Response from overridden model"
|
||||
|
||||
# Verify the API call was made with overridden model
|
||||
call_args = mock_session.post.call_args
|
||||
assert call_args[1]['json']['model'] == "custom-tgi-model"
|
||||
|
||||
@patch('trustgraph.model.text_completion.tgi.llm.aiohttp.ClientSession')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_with_temperature_override(self, mock_llm_init, mock_async_init, mock_session_class):
|
||||
"""Test generate_content with temperature parameter override"""
|
||||
# Arrange
|
||||
mock_session = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.status = 200
|
||||
mock_response.json = AsyncMock(return_value={
|
||||
'choices': [{
|
||||
'message': {
|
||||
'content': 'Response with temperature override'
|
||||
}
|
||||
}],
|
||||
'usage': {
|
||||
'prompt_tokens': 18,
|
||||
'completion_tokens': 12
|
||||
}
|
||||
})
|
||||
|
||||
mock_session.post.return_value.__aenter__.return_value = mock_response
|
||||
mock_session.post.return_value.__aexit__.return_value = None
|
||||
mock_session_class.return_value = mock_session
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'tgi',
|
||||
'url': 'http://tgi-service:8899/v1',
|
||||
'temperature': 0.0, # Default temperature
|
||||
'max_output': 2048,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act - Override temperature
|
||||
result = await processor.generate_content("System", "Prompt", temperature=0.7)
|
||||
|
||||
# Assert
|
||||
assert result.text == "Response with temperature override"
|
||||
|
||||
# Verify the API call was made with overridden temperature
|
||||
call_args = mock_session.post.call_args
|
||||
assert call_args[1]['json']['temperature'] == 0.7
|
||||
|
||||
@patch('trustgraph.model.text_completion.tgi.llm.aiohttp.ClientSession')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_with_both_parameters_override(self, mock_llm_init, mock_async_init, mock_session_class):
|
||||
"""Test generate_content with both model and temperature overrides"""
|
||||
# Arrange
|
||||
mock_session = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.status = 200
|
||||
mock_response.json = AsyncMock(return_value={
|
||||
'choices': [{
|
||||
'message': {
|
||||
'content': 'Response with both parameters override'
|
||||
}
|
||||
}],
|
||||
'usage': {
|
||||
'prompt_tokens': 20,
|
||||
'completion_tokens': 15
|
||||
}
|
||||
})
|
||||
|
||||
mock_session.post.return_value.__aenter__.return_value = mock_response
|
||||
mock_session.post.return_value.__aexit__.return_value = None
|
||||
mock_session_class.return_value = mock_session
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'tgi',
|
||||
'url': 'http://tgi-service:8899/v1',
|
||||
'temperature': 0.0,
|
||||
'max_output': 2048,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act - Override both parameters
|
||||
result = await processor.generate_content("System", "Prompt", model="override-model", temperature=0.8)
|
||||
|
||||
# Assert
|
||||
assert result.model == "override-model"
|
||||
assert result.text == "Response with both parameters override"
|
||||
|
||||
# Verify the API call was made with overridden parameters
|
||||
call_args = mock_session.post.call_args
|
||||
assert call_args[1]['json']['model'] == "override-model"
|
||||
assert call_args[1]['json']['temperature'] == 0.8
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pytest.main([__file__])
@@ -47,10 +47,10 @@ class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase):
        processor = Processor(**config)

        # Assert
-       assert processor.model == 'gemini-2.0-flash-001'  # It's stored as 'model', not 'model_name'
-       assert hasattr(processor, 'generation_config')
+       assert processor.default_model == 'gemini-2.0-flash-001'  # It's stored as 'model', not 'model_name'
+       assert hasattr(processor, 'generation_configs')  # Now a cache dictionary
        assert hasattr(processor, 'safety_settings')
-       assert hasattr(processor, 'llm')
+       assert hasattr(processor, 'model_clients')  # LLM clients are now cached
        mock_service_account.Credentials.from_service_account_file.assert_called_once_with('private.json')
        mock_vertexai.init.assert_called_once()

@@ -102,7 +102,8 @@ class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase):
        mock_model.generate_content.assert_called_once()
        # Verify the call was made with the expected parameters
        call_args = mock_model.generate_content.call_args
-       assert call_args[1]['generation_config'] == processor.generation_config
+       # Generation config is now created dynamically per model
+       assert 'generation_config' in call_args[1]
        assert call_args[1]['safety_settings'] == processor.safety_settings

    @patch('trustgraph.model.text_completion.vertexai.llm.service_account')

@@ -223,7 +224,7 @@ class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase):
        processor = Processor(**config)

        # Assert
-       assert processor.model == 'gemini-2.0-flash-001'
+       assert processor.default_model == 'gemini-2.0-flash-001'
        mock_auth_default.assert_called_once()
        mock_vertexai.init.assert_called_once_with(
            location='us-central1',

@@ -296,11 +297,11 @@ class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase):
        processor = Processor(**config)

        # Assert
-       assert processor.model == 'gemini-1.5-pro'
+       assert processor.default_model == 'gemini-1.5-pro'

        # Verify that generation_config object exists (can't easily check internal values)
-       assert hasattr(processor, 'generation_config')
-       assert processor.generation_config is not None
+       assert hasattr(processor, 'generation_configs')  # Now a cache dictionary
+       assert processor.generation_configs == {}  # Empty cache initially

        # Verify that safety settings are configured
        assert len(processor.safety_settings) == 4

@@ -353,8 +354,8 @@ class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase):
            project='test-project-123'
        )

-       # Verify GenerativeModel was created with the right model name
-       mock_generative_model.assert_called_once_with('gemini-2.0-flash-001')
+       # GenerativeModel is now created lazily on first use, not at initialization
+       mock_generative_model.assert_not_called()

    @patch('trustgraph.model.text_completion.vertexai.llm.service_account')
    @patch('trustgraph.model.text_completion.vertexai.llm.vertexai')

@@ -440,8 +441,8 @@ class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase):
        processor = Processor(**config)

        # Assert
-       assert processor.model == 'claude-3-sonnet@20240229'
-       assert processor.is_anthropic == True
+       assert processor.default_model == 'claude-3-sonnet@20240229'
+       # is_anthropic logic is now determined dynamically per request

        # Verify service account was called with private key
        mock_service_account.Credentials.from_service_account_file.assert_called_once_with('anthropic-key.json')

@@ -459,6 +460,180 @@ class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase):
        assert processor.api_params["top_p"] == 1.0
        assert processor.api_params["top_k"] == 32

@patch('trustgraph.model.text_completion.vertexai.llm.service_account')
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.vertexai')
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.GenerativeModel')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_temperature_override(self, mock_llm_init, mock_async_init, mock_generative_model, mock_vertexai, mock_service_account):
|
||||
"""Test temperature parameter override functionality"""
|
||||
# Arrange
|
||||
mock_credentials = MagicMock()
|
||||
mock_service_account.Credentials.from_service_account_file.return_value = mock_credentials
|
||||
|
||||
mock_model = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.text = "Response with custom temperature"
|
||||
mock_response.usage_metadata.prompt_token_count = 20
|
||||
mock_response.usage_metadata.candidates_token_count = 12
|
||||
mock_model.generate_content.return_value = mock_response
|
||||
mock_generative_model.return_value = mock_model
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'region': 'us-central1',
|
||||
'model': 'gemini-2.0-flash-001',
|
||||
'temperature': 0.0, # Default temperature
|
||||
'max_output': 8192,
|
||||
'private_key': 'private.json',
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act - Override temperature at runtime
|
||||
result = await processor.generate_content(
|
||||
"System prompt",
|
||||
"User prompt",
|
||||
model=None, # Use default model
|
||||
temperature=0.8 # Override temperature
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Response with custom temperature"
|
||||
|
||||
# Verify Gemini API was called with overridden temperature
|
||||
mock_model.generate_content.assert_called_once()
|
||||
call_args = mock_model.generate_content.call_args
|
||||
|
||||
# Check that generation_config was created (we can't directly access temperature from mock)
|
||||
generation_config = call_args.kwargs['generation_config']
|
||||
assert generation_config is not None # Should use overridden temperature configuration
|
||||
|
||||
    @patch('trustgraph.model.text_completion.vertexai.llm.service_account')
    @patch('trustgraph.model.text_completion.vertexai.llm.vertexai')
    @patch('trustgraph.model.text_completion.vertexai.llm.GenerativeModel')
    @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
    @patch('trustgraph.base.llm_service.LlmService.__init__')
    async def test_generate_content_model_override(self, mock_llm_init, mock_async_init, mock_generative_model, mock_vertexai, mock_service_account):
        """Test model parameter override functionality"""
        # Arrange
        mock_credentials = MagicMock()
        mock_service_account.Credentials.from_service_account_file.return_value = mock_credentials

        # Mock different models
        mock_model_default = MagicMock()
        mock_model_override = MagicMock()
        mock_response = MagicMock()
        mock_response.text = "Response with custom model"
        mock_response.usage_metadata.prompt_token_count = 18
        mock_response.usage_metadata.candidates_token_count = 14
        mock_model_override.generate_content.return_value = mock_response

        # GenerativeModel should return different models based on input
        def model_factory(model_name):
            if model_name == 'gemini-1.5-pro':
                return mock_model_override
            return mock_model_default

        mock_generative_model.side_effect = model_factory

        mock_async_init.return_value = None
        mock_llm_init.return_value = None

        config = {
            'region': 'us-central1',
            'model': 'gemini-2.0-flash-001',  # Default model
            'temperature': 0.2,               # Default temperature
            'max_output': 8192,
            'private_key': 'private.json',
            'concurrency': 1,
            'taskgroup': AsyncMock(),
            'id': 'test-processor'
        }

        processor = Processor(**config)

        # Act - Override model at runtime
        result = await processor.generate_content(
            "System prompt",
            "User prompt",
            model="gemini-1.5-pro",   # Override model
            temperature=None          # Use default temperature
        )

        # Assert
        assert isinstance(result, LlmResult)
        assert result.text == "Response with custom model"

        # Verify the overridden model was used
        mock_model_override.generate_content.assert_called_once()
        # Verify GenerativeModel was called with the override model
        mock_generative_model.assert_called_with('gemini-1.5-pro')
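
    # Minimal sketch (an assumption for illustration, not the real Processor code)
    # of the call sequence the model-override test verifies: resolve the model
    # name, construct a client for it, then issue the request with whatever
    # generation config was prepared. Passing a MagicMock as generative_model_cls
    # reproduces exactly what assert_called_with('gemini-1.5-pro') checks above.
    @staticmethod
    def _call_sequence_sketch(generative_model_cls, default_model, model=None,
                              generation_config=None):
        # Runtime override wins over the configured default model name.
        name = model if model is not None else default_model
        client = generative_model_cls(name)
        return client.generate_content("prompt", generation_config=generation_config)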
    @patch('trustgraph.model.text_completion.vertexai.llm.service_account')
    @patch('trustgraph.model.text_completion.vertexai.llm.vertexai')
    @patch('trustgraph.model.text_completion.vertexai.llm.GenerativeModel')
    @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
    @patch('trustgraph.base.llm_service.LlmService.__init__')
    async def test_generate_content_both_parameters_override(self, mock_llm_init, mock_async_init, mock_generative_model, mock_vertexai, mock_service_account):
        """Test overriding both model and temperature parameters simultaneously"""
        # Arrange
        mock_credentials = MagicMock()
        mock_service_account.Credentials.from_service_account_file.return_value = mock_credentials

        mock_model = MagicMock()
        mock_response = MagicMock()
        mock_response.text = "Response with both overrides"
        mock_response.usage_metadata.prompt_token_count = 22
        mock_response.usage_metadata.candidates_token_count = 16
        mock_model.generate_content.return_value = mock_response
        mock_generative_model.return_value = mock_model

        mock_async_init.return_value = None
        mock_llm_init.return_value = None

        config = {
            'region': 'us-central1',
            'model': 'gemini-2.0-flash-001',  # Default model
            'temperature': 0.0,               # Default temperature
            'max_output': 8192,
            'private_key': 'private.json',
            'concurrency': 1,
            'taskgroup': AsyncMock(),
            'id': 'test-processor'
        }

        processor = Processor(**config)

        # Act - Override both parameters at runtime
        result = await processor.generate_content(
            "System prompt",
            "User prompt",
            model="gemini-1.5-flash-001",  # Override model
            temperature=0.9                # Override temperature
        )

        # Assert
        assert isinstance(result, LlmResult)
        assert result.text == "Response with both overrides"

        # Verify both overrides were used
        mock_model.generate_content.assert_called_once()
        call_args = mock_model.generate_content.call_args

        # Verify model override
        mock_generative_model.assert_called_with('gemini-1.5-flash-001')  # Should use runtime override

        # Verify temperature override (we can't directly access temperature from mock)
        generation_config = call_args.kwargs['generation_config']
        assert generation_config is not None  # Should use overridden temperature configuration
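
    # Sketch only: with both overrides supplied, the same resolution rule applies to
    # each parameter independently, which is what the test above checks - the
    # override model name reaches GenerativeModel and a generation_config is built.
    # Example, using the hypothetical helper defined earlier in this class:
    #     self._resolve_param_sketch(0.9, 0.0)  -> 0.9
    #     self._resolve_param_sketch('gemini-1.5-flash-001', 'gemini-2.0-flash-001')
    #         -> 'gemini-1.5-flash-001'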

if __name__ == '__main__':
    pytest.main([__file__])

@@ -42,7 +42,7 @@ class TestVLLMProcessorSimple(IsolatedAsyncioTestCase):
        processor = Processor(**config)

        # Assert
        assert processor.model == 'TheBloke/Mistral-7B-v0.1-AWQ'
        assert processor.default_model == 'TheBloke/Mistral-7B-v0.1-AWQ'
        assert processor.base_url == 'http://vllm-service:8899/v1'
        assert processor.temperature == 0.0
        assert processor.max_output == 2048
@@ -199,7 +199,7 @@ class TestVLLMProcessorSimple(IsolatedAsyncioTestCase):
        processor = Processor(**config)

        # Assert
        assert processor.model == 'custom-model'
        assert processor.default_model == 'custom-model'
        assert processor.base_url == 'http://custom-vllm:8080/v1'
        assert processor.temperature == 0.7
        assert processor.max_output == 1024
@@ -228,7 +228,7 @@ class TestVLLMProcessorSimple(IsolatedAsyncioTestCase):
        processor = Processor(**config)

        # Assert
        assert processor.model == 'TheBloke/Mistral-7B-v0.1-AWQ'  # default_model
        assert processor.default_model == 'TheBloke/Mistral-7B-v0.1-AWQ'  # default_model
        assert processor.base_url == 'http://vllm-service:8899/v1'  # default_base_url
        assert processor.temperature == 0.0  # default_temperature
        assert processor.max_output == 2048  # default_max_output
@@ -485,5 +485,148 @@ class TestVLLMProcessorSimple(IsolatedAsyncioTestCase):
        assert call_args[1]['json']['prompt'] == expected_prompt


    @patch('trustgraph.model.text_completion.vllm.llm.aiohttp.ClientSession')
    @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
    @patch('trustgraph.base.llm_service.LlmService.__init__')
    async def test_generate_content_with_model_override(self, mock_llm_init, mock_async_init, mock_session_class):
        """Test generate_content with model parameter override"""
        # Arrange
        mock_session = MagicMock()
        mock_response = MagicMock()
        mock_response.status = 200
        mock_response.json = AsyncMock(return_value={
            'choices': [{
                'text': 'Response from overridden model'
            }],
            'usage': {
                'prompt_tokens': 12,
                'completion_tokens': 8
            }
        })

        mock_session.post.return_value.__aenter__.return_value = mock_response
        mock_session.post.return_value.__aexit__.return_value = None
        mock_session_class.return_value = mock_session

        mock_async_init.return_value = None
        mock_llm_init.return_value = None

        config = {
            'model': 'TheBloke/Mistral-7B-v0.1-AWQ',
            'url': 'http://vllm-service:8899/v1',
            'temperature': 0.0,
            'max_output': 2048,
            'concurrency': 1,
            'taskgroup': AsyncMock(),
            'id': 'test-processor'
        }

        processor = Processor(**config)

        # Act - Override model
        result = await processor.generate_content("System", "Prompt", model="custom-vllm-model")

        # Assert
        assert result.model == "custom-vllm-model"  # Should use overridden model
        assert result.text == "Response from overridden model"
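
    # Illustrative helper, not used by the original tests: how these tests read back
    # the JSON body given to aiohttp - the arguments of session.post are captured in
    # call_args, and the body is its 'json' keyword argument. Only the 'prompt' and
    # 'temperature' keys are confirmed by assertions in this file; anything else in
    # the payload is an assumption about the processor's request format.
    @staticmethod
    def _posted_json_sketch(mock_session):
        call_args = mock_session.post.call_args
        return call_args[1]['json']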
    @patch('trustgraph.model.text_completion.vllm.llm.aiohttp.ClientSession')
    @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
    @patch('trustgraph.base.llm_service.LlmService.__init__')
    async def test_generate_content_with_temperature_override(self, mock_llm_init, mock_async_init, mock_session_class):
        """Test generate_content with temperature parameter override"""
        # Arrange
        mock_session = MagicMock()
        mock_response = MagicMock()
        mock_response.status = 200
        mock_response.json = AsyncMock(return_value={
            'choices': [{
                'text': 'Response with temperature override'
            }],
            'usage': {
                'prompt_tokens': 15,
                'completion_tokens': 10
            }
        })

        mock_session.post.return_value.__aenter__.return_value = mock_response
        mock_session.post.return_value.__aexit__.return_value = None
        mock_session_class.return_value = mock_session

        mock_async_init.return_value = None
        mock_llm_init.return_value = None

        config = {
            'model': 'TheBloke/Mistral-7B-v0.1-AWQ',
            'url': 'http://vllm-service:8899/v1',
            'temperature': 0.0,  # Default temperature
            'max_output': 2048,
            'concurrency': 1,
            'taskgroup': AsyncMock(),
            'id': 'test-processor'
        }

        processor = Processor(**config)

        # Act - Override temperature
        result = await processor.generate_content("System", "Prompt", temperature=0.7)

        # Assert
        assert result.text == "Response with temperature override"

        # Verify the request was made with overridden temperature
        call_args = mock_session.post.call_args
        assert call_args[1]['json']['temperature'] == 0.7
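
    # Minimal sketch (an assumption, for illustration only) of how a temperature
    # override is presumed to land in the request body these tests inspect: the
    # resolved value is placed under the 'temperature' key of the JSON payload.
    # Key names other than 'prompt' and 'temperature' are hypothetical.
    @staticmethod
    def _build_payload_sketch(prompt, default_temperature, temperature=None):
        resolved = default_temperature if temperature is None else temperature
        return {"prompt": prompt, "temperature": resolved}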
    @patch('trustgraph.model.text_completion.vllm.llm.aiohttp.ClientSession')
    @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
    @patch('trustgraph.base.llm_service.LlmService.__init__')
    async def test_generate_content_with_both_parameters_override(self, mock_llm_init, mock_async_init, mock_session_class):
        """Test generate_content with both model and temperature overrides"""
        # Arrange
        mock_session = MagicMock()
        mock_response = MagicMock()
        mock_response.status = 200
        mock_response.json = AsyncMock(return_value={
            'choices': [{
                'text': 'Response with both parameters override'
            }],
            'usage': {
                'prompt_tokens': 18,
                'completion_tokens': 12
            }
        })

        mock_session.post.return_value.__aenter__.return_value = mock_response
        mock_session.post.return_value.__aexit__.return_value = None
        mock_session_class.return_value = mock_session

        mock_async_init.return_value = None
        mock_llm_init.return_value = None

        config = {
            'model': 'TheBloke/Mistral-7B-v0.1-AWQ',
            'url': 'http://vllm-service:8899/v1',
            'temperature': 0.0,
            'max_output': 2048,
            'concurrency': 1,
            'taskgroup': AsyncMock(),
            'id': 'test-processor'
        }

        processor = Processor(**config)

        # Act - Override both parameters
        result = await processor.generate_content("System", "Prompt", model="override-model", temperature=0.8)

        # Assert
        assert result.model == "override-model"
        assert result.text == "Response with both parameters override"

        # Verify the request was made with overridden temperature
        call_args = mock_session.post.call_args
        assert call_args[1]['json']['temperature'] == 0.8

if __name__ == '__main__':
    pytest.main([__file__])