release/v1.4 -> master (#548)

cybermaggedon 2025-10-06 17:54:26 +01:00 committed by GitHub
parent 3ec2cd54f9
commit 2bd68ed7f4
94 changed files with 8571 additions and 1740 deletions


@@ -0,0 +1,238 @@
"""
Unit tests for Flow Parameter Specification functionality
Testing parameter specification registration and handling in flow processors
"""
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from unittest import IsolatedAsyncioTestCase
from trustgraph.base.flow_processor import FlowProcessor
from trustgraph.base import ParameterSpec, ConsumerSpec, ProducerSpec
class MockAsyncProcessor:
def __init__(self, **params):
self.config_handlers = []
self.id = params.get('id', 'test-service')
self.specifications = []
class TestFlowParameterSpecs(IsolatedAsyncioTestCase):
"""Test flow processor parameter specification functionality"""
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
def test_parameter_spec_registration(self):
"""Test that parameter specs can be registered with flow processors"""
# Arrange
config = {
'id': 'test-flow-processor',
'concurrency': 1,
'taskgroup': AsyncMock()
}
processor = FlowProcessor(**config)
# Create test parameter specs
model_spec = ParameterSpec(name="model")
temperature_spec = ParameterSpec(name="temperature")
# Act
processor.register_specification(model_spec)
processor.register_specification(temperature_spec)
# Assert
assert len(processor.specifications) >= 2
param_specs = [spec for spec in processor.specifications
if isinstance(spec, ParameterSpec)]
assert len(param_specs) >= 2
param_names = [spec.name for spec in param_specs]
assert "model" in param_names
assert "temperature" in param_names
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
def test_mixed_specification_types(self):
"""Test registration of mixed specification types (parameters, consumers, producers)"""
# Arrange
config = {
'id': 'test-flow-processor',
'concurrency': 1,
'taskgroup': AsyncMock()
}
processor = FlowProcessor(**config)
# Create different spec types
param_spec = ParameterSpec(name="model")
consumer_spec = ConsumerSpec(name="input", schema=MagicMock(), handler=MagicMock())
producer_spec = ProducerSpec(name="output", schema=MagicMock())
# Act
processor.register_specification(param_spec)
processor.register_specification(consumer_spec)
processor.register_specification(producer_spec)
# Assert
assert len(processor.specifications) == 3
# Count each type
param_specs = [s for s in processor.specifications if isinstance(s, ParameterSpec)]
consumer_specs = [s for s in processor.specifications if isinstance(s, ConsumerSpec)]
producer_specs = [s for s in processor.specifications if isinstance(s, ProducerSpec)]
assert len(param_specs) == 1
assert len(consumer_specs) == 1
assert len(producer_specs) == 1
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
def test_parameter_spec_metadata(self):
"""Test parameter specification metadata handling"""
# Arrange
config = {
'id': 'test-flow-processor',
'concurrency': 1,
'taskgroup': AsyncMock()
}
processor = FlowProcessor(**config)
# Create parameter specs with metadata (if supported)
model_spec = ParameterSpec(name="model")
temperature_spec = ParameterSpec(name="temperature")
# Act
processor.register_specification(model_spec)
processor.register_specification(temperature_spec)
# Assert
param_specs = [spec for spec in processor.specifications
if isinstance(spec, ParameterSpec)]
model_spec_registered = next((s for s in param_specs if s.name == "model"), None)
temperature_spec_registered = next((s for s in param_specs if s.name == "temperature"), None)
assert model_spec_registered is not None
assert temperature_spec_registered is not None
assert model_spec_registered.name == "model"
assert temperature_spec_registered.name == "temperature"
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
def test_duplicate_parameter_spec_handling(self):
"""Test handling of duplicate parameter spec registration"""
# Arrange
config = {
'id': 'test-flow-processor',
'concurrency': 1,
'taskgroup': AsyncMock()
}
processor = FlowProcessor(**config)
# Create duplicate parameter specs
model_spec1 = ParameterSpec(name="model")
model_spec2 = ParameterSpec(name="model")
# Act
processor.register_specification(model_spec1)
processor.register_specification(model_spec2)
# Assert - Should allow duplicates (or handle appropriately)
param_specs = [spec for spec in processor.specifications
if isinstance(spec, ParameterSpec) and spec.name == "model"]
# Either should have 2 duplicates or the system should handle deduplication
assert len(param_specs) >= 1 # At least one should be registered
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
@patch('trustgraph.base.flow_processor.Flow')
async def test_parameter_specs_available_to_flows(self, mock_flow_class):
"""Test that parameter specs are available when flows are created"""
# Arrange
config = {
'id': 'test-flow-processor',
'concurrency': 1,
'taskgroup': AsyncMock()
}
processor = FlowProcessor(**config)
processor.id = 'test-processor'
# Register parameter specs
model_spec = ParameterSpec(name="model")
temperature_spec = ParameterSpec(name="temperature")
processor.register_specification(model_spec)
processor.register_specification(temperature_spec)
mock_flow = AsyncMock()
mock_flow_class.return_value = mock_flow
flow_name = 'test-flow'
flow_defn = {'config': 'test-config'}
# Act
await processor.start_flow(flow_name, flow_defn)
# Assert - Flow should be created with access to processor specifications
mock_flow_class.assert_called_once_with('test-processor', flow_name, processor, flow_defn)
# The flow should have access to the processor's specifications
# (The exact mechanism depends on Flow implementation)
assert len(processor.specifications) >= 2
class TestParameterSpecValidation(IsolatedAsyncioTestCase):
"""Test parameter specification validation functionality"""
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
def test_parameter_spec_name_validation(self):
"""Test parameter spec name validation"""
# Arrange
config = {
'id': 'test-flow-processor',
'concurrency': 1,
'taskgroup': AsyncMock()
}
processor = FlowProcessor(**config)
# Act & Assert - Valid parameter names
valid_specs = [
ParameterSpec(name="model"),
ParameterSpec(name="temperature"),
ParameterSpec(name="max_tokens"),
ParameterSpec(name="api_key")
]
for spec in valid_specs:
# Should not raise any exceptions
processor.register_specification(spec)
assert len([s for s in processor.specifications if isinstance(s, ParameterSpec)]) >= 4
def test_parameter_spec_creation_validation(self):
"""Test parameter spec creation with various inputs"""
# Test valid parameter spec creation
valid_specs = [
ParameterSpec(name="model"),
ParameterSpec(name="temperature"),
ParameterSpec(name="max_output"),
]
for spec in valid_specs:
assert spec.name is not None
assert isinstance(spec.name, str)
# Edge case: an empty name may or may not be rejected, depending on
# whether ParameterSpec performs validation, so accept either outcome.
try:
ParameterSpec(name="")
except Exception:
pass
if __name__ == '__main__':
pytest.main([__file__])
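For readers of this changeset, a minimal sketch of the registration contract these tests rely on, assuming only what the assertions show (a specifications list and an append-style register_specification); the real FlowProcessor and ParameterSpec classes carry much more behaviour:

from dataclasses import dataclass, field

@dataclass
class SketchParameterSpec:          # stand-in for trustgraph.base.ParameterSpec
    name: str

@dataclass
class SketchFlowProcessor:          # stand-in for trustgraph.base.flow_processor.FlowProcessor
    id: str = "test-service"
    specifications: list = field(default_factory=list)

    def register_specification(self, spec):
        # All spec types (parameter, consumer, producer) land in one list;
        # duplicates are simply appended rather than rejected.
        self.specifications.append(spec)

proc = SketchFlowProcessor(id="test-flow-processor")
proc.register_specification(SketchParameterSpec(name="model"))
proc.register_specification(SketchParameterSpec(name="temperature"))
assert [s.name for s in proc.specifications] == ["model", "temperature"]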


@@ -0,0 +1,264 @@
"""
Unit tests for LLM Service Parameter Specifications
Testing the new parameter-aware functionality added to the LLM base service
"""
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from unittest import IsolatedAsyncioTestCase
from trustgraph.base.llm_service import LlmService, LlmResult
from trustgraph.base import ParameterSpec, ConsumerSpec, ProducerSpec
from trustgraph.schema import TextCompletionRequest, TextCompletionResponse
class MockAsyncProcessor:
def __init__(self, **params):
self.config_handlers = []
self.id = params.get('id', 'test-service')
self.specifications = []
class TestLlmServiceParameters(IsolatedAsyncioTestCase):
"""Test LLM service parameter specification functionality"""
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
def test_parameter_specs_registration(self):
"""Test that LLM service registers model and temperature parameter specs"""
# Arrange
config = {
'id': 'test-llm-service',
'concurrency': 1,
'taskgroup': AsyncMock() # Add required taskgroup
}
# Act
service = LlmService(**config)
# Assert
param_specs = {spec.name: spec for spec in service.specifications
if isinstance(spec, ParameterSpec)}
assert "model" in param_specs
assert "temperature" in param_specs
assert len(param_specs) >= 2 # May have other parameter specs
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
def test_model_parameter_spec_properties(self):
"""Test that model parameter spec has correct properties"""
# Arrange
config = {
'id': 'test-llm-service',
'concurrency': 1,
'taskgroup': AsyncMock()
}
# Act
service = LlmService(**config)
# Assert
model_spec = None
for spec in service.specifications:
if isinstance(spec, ParameterSpec) and spec.name == "model":
model_spec = spec
break
assert model_spec is not None
assert model_spec.name == "model"
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
def test_temperature_parameter_spec_properties(self):
"""Test that temperature parameter spec has correct properties"""
# Arrange
config = {
'id': 'test-llm-service',
'concurrency': 1,
'taskgroup': AsyncMock()
}
# Act
service = LlmService(**config)
# Assert
temperature_spec = None
for spec in service.specifications:
if isinstance(spec, ParameterSpec) and spec.name == "temperature":
temperature_spec = spec
break
assert temperature_spec is not None
assert temperature_spec.name == "temperature"
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
async def test_on_request_extracts_parameters_from_flow(self):
"""Test that on_request method extracts model and temperature from flow"""
# Arrange
config = {
'id': 'test-llm-service',
'concurrency': 1,
'taskgroup': AsyncMock()
}
service = LlmService(**config)
# Mock the metrics
service.text_completion_model_metric = MagicMock()
service.text_completion_model_metric.labels.return_value.info = AsyncMock()
# Mock the generate_content method to capture parameters
service.generate_content = AsyncMock(return_value=LlmResult(
text="test response",
in_token=10,
out_token=5,
model="gpt-4"
))
# Mock message and flow
mock_message = MagicMock()
mock_message.value.return_value = MagicMock()
mock_message.value.return_value.system = "system prompt"
mock_message.value.return_value.prompt = "user prompt"
mock_message.properties.return_value = {"id": "test-id"}
mock_consumer = MagicMock()
mock_consumer.name = "request"
mock_flow = MagicMock()
mock_flow.name = "test-flow"
mock_flow.return_value = "test-model" # flow("model") returns this
mock_flow.side_effect = lambda param: {
"model": "gpt-4",
"temperature": 0.7
}.get(param, f"mock-{param}")
mock_producer = AsyncMock()
mock_flow.producer = {"response": mock_producer}
# Act
await service.on_request(mock_message, mock_consumer, mock_flow)
# Assert
# Verify that generate_content was called with parameters from flow
service.generate_content.assert_called_once()
call_args = service.generate_content.call_args
assert call_args[0][0] == "system prompt" # system
assert call_args[0][1] == "user prompt" # prompt
assert call_args[0][2] == "gpt-4" # model
assert call_args[0][3] == 0.7 # temperature
# Verify flow was queried for both parameters
mock_flow.assert_any_call("model")
mock_flow.assert_any_call("temperature")
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
async def test_on_request_handles_missing_parameters_gracefully(self):
"""Test that on_request handles missing parameters gracefully"""
# Arrange
config = {
'id': 'test-llm-service',
'concurrency': 1,
'taskgroup': AsyncMock()
}
service = LlmService(**config)
# Mock the metrics
service.text_completion_model_metric = MagicMock()
service.text_completion_model_metric.labels.return_value.info = AsyncMock()
# Mock the generate_content method
service.generate_content = AsyncMock(return_value=LlmResult(
text="test response",
in_token=10,
out_token=5,
model="default-model"
))
# Mock message and flow where flow returns None for parameters
mock_message = MagicMock()
mock_message.value.return_value = MagicMock()
mock_message.value.return_value.system = "system prompt"
mock_message.value.return_value.prompt = "user prompt"
mock_message.properties.return_value = {"id": "test-id"}
mock_consumer = MagicMock()
mock_consumer.name = "request"
mock_flow = MagicMock()
mock_flow.name = "test-flow"
mock_flow.return_value = None # Both parameters return None
mock_producer = AsyncMock()
mock_flow.producer = {"response": mock_producer}
# Act
await service.on_request(mock_message, mock_consumer, mock_flow)
# Assert
# Should still call generate_content, with None values that will use processor defaults
service.generate_content.assert_called_once()
call_args = service.generate_content.call_args
assert call_args[0][0] == "system prompt" # system
assert call_args[0][1] == "user prompt" # prompt
assert call_args[0][2] is None # model (will use processor default)
assert call_args[0][3] is None # temperature (will use processor default)
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
async def test_on_request_error_handling_preserves_behavior(self):
"""Test that parameter extraction doesn't break existing error handling"""
# Arrange
config = {
'id': 'test-llm-service',
'concurrency': 1,
'taskgroup': AsyncMock()
}
service = LlmService(**config)
# Mock the metrics
service.text_completion_model_metric = MagicMock()
service.text_completion_model_metric.labels.return_value.info = AsyncMock()
# Mock generate_content to raise an exception
service.generate_content = AsyncMock(side_effect=Exception("Test error"))
# Mock message and flow
mock_message = MagicMock()
mock_message.value.return_value = MagicMock()
mock_message.value.return_value.system = "system prompt"
mock_message.value.return_value.prompt = "user prompt"
mock_message.properties.return_value = {"id": "test-id"}
mock_consumer = MagicMock()
mock_consumer.name = "request"
mock_flow = MagicMock()
mock_flow.name = "test-flow"
mock_flow.side_effect = lambda param: {
"model": "gpt-4",
"temperature": 0.7
}.get(param, f"mock-{param}")
mock_producer = AsyncMock()
mock_flow.producer = {"response": mock_producer}
# Act
await service.on_request(mock_message, mock_consumer, mock_flow)
# Assert
# Should have sent error response
mock_producer.send.assert_called_once()
error_response = mock_producer.send.call_args[0][0]
assert error_response.error is not None
assert error_response.error.type == "llm-error"
assert "Test error" in error_response.error.message
assert error_response.response is None
if __name__ == '__main__':
pytest.main([__file__])
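A hedged sketch, not the actual LlmService implementation, of the request path the tests above pin down: the flow object is called with a parameter name to resolve per-flow values, which are passed positionally to generate_content, and the error path publishes a response whose error type is "llm-error". The response/error constructors here are hypothetical placeholders.

def build_response(result):            # hypothetical stand-in for the real schema object
    return {"response": result.text, "error": None}

def build_error(err_type, message):    # hypothetical stand-in for the error response
    return {"response": None, "error": {"type": err_type, "message": message}}

async def on_request_sketch(service, msg, consumer, flow):
    request = msg.value()
    model = flow("model")                # per-flow override, may be None
    temperature = flow("temperature")    # per-flow override, may be None
    try:
        result = await service.generate_content(
            request.system, request.prompt, model, temperature
        )
        await flow.producer["response"].send(build_response(result))
    except Exception as e:
        # Matches the assertions above: error type "llm-error", message
        # containing the exception text, response left as None.
        await flow.producer["response"].send(build_error("llm-error", str(e)))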


@@ -1,211 +1,236 @@
"""
Unit tests for trustgraph.chunking.recursive
Testing parameter override functionality for chunk-size and chunk-overlap
"""
import pytest
import asyncio
from unittest.mock import AsyncMock, Mock, patch, MagicMock
from unittest.mock import AsyncMock, MagicMock, patch
from unittest import IsolatedAsyncioTestCase
# Import the service under test
from trustgraph.chunking.recursive.chunker import Processor
from trustgraph.schema import TextDocument, Chunk, Metadata
from trustgraph.chunking.recursive.chunker import Processor as RecursiveChunker
@pytest.fixture
def mock_flow():
output_mock = AsyncMock()
flow_mock = Mock(return_value=output_mock)
return flow_mock, output_mock
class MockAsyncProcessor:
def __init__(self, **params):
self.config_handlers = []
self.id = params.get('id', 'test-service')
self.specifications = []
@pytest.fixture
def mock_consumer():
consumer = Mock()
consumer.id = "test-consumer"
consumer.flow = "test-flow"
return consumer
class TestRecursiveChunkerSimple(IsolatedAsyncioTestCase):
"""Test Recursive chunker functionality"""
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
def test_processor_initialization_basic(self):
"""Test basic processor initialization"""
# Arrange
config = {
'id': 'test-chunker',
'chunk_size': 1500,
'chunk_overlap': 150,
'concurrency': 1,
'taskgroup': AsyncMock()
}
# Act
processor = Processor(**config)
# Assert
assert processor.default_chunk_size == 1500
assert processor.default_chunk_overlap == 150
assert hasattr(processor, 'text_splitter')
# Verify parameter specs are registered
param_specs = [spec for spec in processor.specifications
if hasattr(spec, 'name') and spec.name in ['chunk-size', 'chunk-overlap']]
assert len(param_specs) == 2
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
async def test_chunk_document_with_chunk_size_override(self):
"""Test chunk_document with chunk-size parameter override"""
# Arrange
config = {
'id': 'test-chunker',
'chunk_size': 1000, # Default chunk size
'chunk_overlap': 100,
'concurrency': 1,
'taskgroup': AsyncMock()
}
processor = Processor(**config)
# Mock message and flow
mock_message = MagicMock()
mock_consumer = MagicMock()
mock_flow = MagicMock()
mock_flow.side_effect = lambda param: {
"chunk-size": 2000, # Override chunk size
"chunk-overlap": None # Use default chunk overlap
}.get(param)
# Act
chunk_size, chunk_overlap = await processor.chunk_document(
mock_message, mock_consumer, mock_flow, 1000, 100
)
# Assert
assert chunk_size == 2000 # Should use overridden value
assert chunk_overlap == 100 # Should use default value
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
async def test_chunk_document_with_chunk_overlap_override(self):
"""Test chunk_document with chunk-overlap parameter override"""
# Arrange
config = {
'id': 'test-chunker',
'chunk_size': 1000,
'chunk_overlap': 100, # Default chunk overlap
'concurrency': 1,
'taskgroup': AsyncMock()
}
processor = Processor(**config)
# Mock message and flow
mock_message = MagicMock()
mock_consumer = MagicMock()
mock_flow = MagicMock()
mock_flow.side_effect = lambda param: {
"chunk-size": None, # Use default chunk size
"chunk-overlap": 200 # Override chunk overlap
}.get(param)
# Act
chunk_size, chunk_overlap = await processor.chunk_document(
mock_message, mock_consumer, mock_flow, 1000, 100
)
# Assert
assert chunk_size == 1000 # Should use default value
assert chunk_overlap == 200 # Should use overridden value
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
async def test_chunk_document_with_both_parameters_override(self):
"""Test chunk_document with both chunk-size and chunk-overlap overrides"""
# Arrange
config = {
'id': 'test-chunker',
'chunk_size': 1000,
'chunk_overlap': 100,
'concurrency': 1,
'taskgroup': AsyncMock()
}
processor = Processor(**config)
# Mock message and flow
mock_message = MagicMock()
mock_consumer = MagicMock()
mock_flow = MagicMock()
mock_flow.side_effect = lambda param: {
"chunk-size": 1500, # Override chunk size
"chunk-overlap": 150 # Override chunk overlap
}.get(param)
# Act
chunk_size, chunk_overlap = await processor.chunk_document(
mock_message, mock_consumer, mock_flow, 1000, 100
)
# Assert
assert chunk_size == 1500 # Should use overridden value
assert chunk_overlap == 150 # Should use overridden value
@patch('trustgraph.chunking.recursive.chunker.RecursiveCharacterTextSplitter')
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
async def test_on_message_uses_flow_parameters(self, mock_splitter_class):
"""Test that on_message method uses parameters from flow"""
# Arrange
mock_splitter = MagicMock()
mock_document = MagicMock()
mock_document.page_content = "Test chunk content"
mock_splitter.create_documents.return_value = [mock_document]
mock_splitter_class.return_value = mock_splitter
config = {
'id': 'test-chunker',
'chunk_size': 1000,
'chunk_overlap': 100,
'concurrency': 1,
'taskgroup': AsyncMock()
}
processor = Processor(**config)
# Mock message with TextDocument
mock_message = MagicMock()
mock_text_doc = MagicMock()
mock_text_doc.metadata = Metadata(
id="test-doc-123",
metadata=[],
user="test-user",
collection="test-collection"
)
mock_text_doc.text = b"This is test document content"
mock_message.value.return_value = mock_text_doc
# Mock consumer and flow with parameter overrides
mock_consumer = MagicMock()
mock_producer = AsyncMock()
mock_flow = MagicMock()
mock_flow.side_effect = lambda param: {
"chunk-size": 1500,
"chunk-overlap": 150,
"output": mock_producer
}.get(param)
# Act
await processor.on_message(mock_message, mock_consumer, mock_flow)
# Assert
# Verify RecursiveCharacterTextSplitter was called with overridden parameters (last call)
actual_last_call = mock_splitter_class.call_args_list[-1]
assert actual_last_call.kwargs['chunk_size'] == 1500
assert actual_last_call.kwargs['chunk_overlap'] == 150
assert actual_last_call.kwargs['length_function'] == len
assert actual_last_call.kwargs['is_separator_regex'] == False
# Verify chunk was sent to output
mock_producer.send.assert_called_once()
sent_chunk = mock_producer.send.call_args[0][0]
assert isinstance(sent_chunk, Chunk)
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
async def test_chunk_document_with_no_overrides(self):
"""Test chunk_document when no parameters are overridden (flow returns None)"""
# Arrange
config = {
'id': 'test-chunker',
'chunk_size': 1000,
'chunk_overlap': 100,
'concurrency': 1,
'taskgroup': AsyncMock()
}
processor = Processor(**config)
# Mock message and flow that returns None for all parameters
mock_message = MagicMock()
mock_consumer = MagicMock()
mock_flow = MagicMock()
mock_flow.return_value = None # No overrides
# Act
chunk_size, chunk_overlap = await processor.chunk_document(
mock_message, mock_consumer, mock_flow, 1000, 100
)
# Assert
assert chunk_size == 1000 # Should use default value
assert chunk_overlap == 100 # Should use default value
@pytest.fixture
def sample_document():
metadata = Metadata(
id="test-doc-1",
metadata=[],
user="test-user",
collection="test-collection"
)
text = "This is a test document. " * 100 # Create text long enough to be chunked
return TextDocument(
metadata=metadata,
text=text.encode("utf-8")
)
@pytest.fixture
def short_document():
metadata = Metadata(
id="test-doc-2",
metadata=[],
user="test-user",
collection="test-collection"
)
text = "This is a very short document."
return TextDocument(
metadata=metadata,
text=text.encode("utf-8")
)
class TestRecursiveChunker:
def test_init_default_params(self, mock_async_processor_init):
processor = RecursiveChunker()
assert processor.text_splitter._chunk_size == 2000
assert processor.text_splitter._chunk_overlap == 100
def test_init_custom_params(self, mock_async_processor_init):
processor = RecursiveChunker(chunk_size=500, chunk_overlap=50)
assert processor.text_splitter._chunk_size == 500
assert processor.text_splitter._chunk_overlap == 50
def test_init_with_id(self, mock_async_processor_init):
processor = RecursiveChunker(id="custom-chunker")
assert processor.id == "custom-chunker"
@pytest.mark.asyncio
async def test_on_message_single_chunk(self, mock_async_processor_init, mock_flow, mock_consumer, short_document):
flow_mock, output_mock = mock_flow
processor = RecursiveChunker(chunk_size=2000, chunk_overlap=100)
msg = Mock()
msg.value.return_value = short_document
await processor.on_message(msg, mock_consumer, flow_mock)
# Should produce exactly one chunk for short text
assert output_mock.send.call_count == 1
# Verify the chunk was created correctly
chunk_call = output_mock.send.call_args[0][0]
assert isinstance(chunk_call, Chunk)
assert chunk_call.metadata == short_document.metadata
assert chunk_call.chunk.decode("utf-8") == short_document.text.decode("utf-8")
@pytest.mark.asyncio
async def test_on_message_multiple_chunks(self, mock_async_processor_init, mock_flow, mock_consumer, sample_document):
flow_mock, output_mock = mock_flow
processor = RecursiveChunker(chunk_size=100, chunk_overlap=20)
msg = Mock()
msg.value.return_value = sample_document
await processor.on_message(msg, mock_consumer, flow_mock)
# Should produce multiple chunks
assert output_mock.send.call_count > 1
# Verify all chunks have correct metadata
for call in output_mock.send.call_args_list:
chunk = call[0][0]
assert isinstance(chunk, Chunk)
assert chunk.metadata == sample_document.metadata
assert len(chunk.chunk) > 0
@pytest.mark.asyncio
async def test_on_message_chunk_overlap(self, mock_async_processor_init, mock_flow, mock_consumer):
flow_mock, output_mock = mock_flow
processor = RecursiveChunker(chunk_size=50, chunk_overlap=10)
# Create a document with predictable content
metadata = Metadata(id="test", metadata=[], user="test-user", collection="test-collection")
text = "ABCDEFGHIJ" * 10 # 100 characters
document = TextDocument(metadata=metadata, text=text.encode("utf-8"))
msg = Mock()
msg.value.return_value = document
await processor.on_message(msg, mock_consumer, flow_mock)
# Collect all chunks
chunks = []
for call in output_mock.send.call_args_list:
chunk_text = call[0][0].chunk.decode("utf-8")
chunks.append(chunk_text)
# Verify chunks have expected overlap
for i in range(len(chunks) - 1):
# The end of chunk i should overlap with the beginning of chunk i+1
# Check if there's some overlap (exact overlap depends on text splitter logic)
assert len(chunks[i]) <= 50 + 10 # chunk_size + some tolerance
@pytest.mark.asyncio
async def test_on_message_empty_document(self, mock_async_processor_init, mock_flow, mock_consumer):
flow_mock, output_mock = mock_flow
processor = RecursiveChunker()
metadata = Metadata(id="empty", metadata=[], user="test-user", collection="test-collection")
document = TextDocument(metadata=metadata, text=b"")
msg = Mock()
msg.value.return_value = document
await processor.on_message(msg, mock_consumer, flow_mock)
# Empty documents typically don't produce chunks with langchain splitters
# This behavior is expected - no chunks should be produced
assert output_mock.send.call_count == 0
@pytest.mark.asyncio
async def test_on_message_unicode_handling(self, mock_async_processor_init, mock_flow, mock_consumer):
flow_mock, output_mock = mock_flow
processor = RecursiveChunker(chunk_size=500, chunk_overlap=20) # Fixed overlap < chunk_size
metadata = Metadata(id="unicode", metadata=[], user="test-user", collection="test-collection")
text = "Hello 世界! 🌍 This is a test with émojis and spëcial characters."
document = TextDocument(metadata=metadata, text=text.encode("utf-8"))
msg = Mock()
msg.value.return_value = document
await processor.on_message(msg, mock_consumer, flow_mock)
# Verify unicode is preserved correctly
all_chunks = []
for call in output_mock.send.call_args_list:
chunk_text = call[0][0].chunk.decode("utf-8")
all_chunks.append(chunk_text)
# Reconstruct text (approximately, due to overlap)
reconstructed = "".join(all_chunks)
assert "世界" in reconstructed
assert "🌍" in reconstructed
assert "émojis" in reconstructed
@pytest.mark.asyncio
async def test_metrics_recorded(self, mock_async_processor_init, mock_flow, mock_consumer, sample_document):
flow_mock, output_mock = mock_flow
processor = RecursiveChunker(chunk_size=100)
msg = Mock()
msg.value.return_value = sample_document
# Mock the metric
with patch.object(RecursiveChunker.chunk_metric, 'labels') as mock_labels:
mock_observe = Mock()
mock_labels.return_value.observe = mock_observe
await processor.on_message(msg, mock_consumer, flow_mock)
# Verify metrics were recorded
mock_labels.assert_called_with(id="test-consumer", flow="test-flow")
assert mock_observe.call_count > 0
# Verify chunk sizes were observed
for call in mock_observe.call_args_list:
chunk_size = call[0][0]
assert chunk_size > 0
def test_add_args(self):
parser = Mock()
RecursiveChunker.add_args(parser)
# Verify arguments were added
calls = parser.add_argument.call_args_list
arg_names = [call[0][0] for call in calls]
assert '-z' in arg_names or '--chunk-size' in arg_names
assert '-v' in arg_names or '--chunk-overlap' in arg_names
if __name__ == '__main__':
pytest.main([__file__])
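The override rule these chunk_document tests encode (and which the token chunker suite below repeats) can be summarised in a few lines. This is a sketch under the assumption that chunk_document only resolves the two flow parameters and falls back to the processor defaults when the flow returns None:

async def chunk_document_sketch(msg, consumer, flow, default_size, default_overlap):
    size = flow("chunk-size")          # per-flow override or None
    overlap = flow("chunk-overlap")    # per-flow override or None
    return (
        size if size is not None else default_size,
        overlap if overlap is not None else default_overlap,
    )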


@@ -1,275 +1,256 @@
"""
Unit tests for trustgraph.chunking.token
Testing parameter override functionality for chunk-size and chunk-overlap
"""
import pytest
import asyncio
from unittest.mock import AsyncMock, Mock, patch
from unittest.mock import AsyncMock, MagicMock, patch
from unittest import IsolatedAsyncioTestCase
# Import the service under test
from trustgraph.chunking.token.chunker import Processor
from trustgraph.schema import TextDocument, Chunk, Metadata
from trustgraph.chunking.token.chunker import Processor as TokenChunker
@pytest.fixture
def mock_flow():
output_mock = AsyncMock()
flow_mock = Mock(return_value=output_mock)
return flow_mock, output_mock
class MockAsyncProcessor:
def __init__(self, **params):
self.config_handlers = []
self.id = params.get('id', 'test-service')
self.specifications = []
@pytest.fixture
def mock_consumer():
consumer = Mock()
consumer.id = "test-consumer"
consumer.flow = "test-flow"
return consumer
class TestTokenChunkerSimple(IsolatedAsyncioTestCase):
"""Test Token chunker functionality"""
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
def test_processor_initialization_basic(self):
"""Test basic processor initialization"""
# Arrange
config = {
'id': 'test-chunker',
'chunk_size': 300,
'chunk_overlap': 20,
'concurrency': 1,
'taskgroup': AsyncMock()
}
# Act
processor = Processor(**config)
# Assert
assert processor.default_chunk_size == 300
assert processor.default_chunk_overlap == 20
assert hasattr(processor, 'text_splitter')
# Verify parameter specs are registered
param_specs = [spec for spec in processor.specifications
if hasattr(spec, 'name') and spec.name in ['chunk-size', 'chunk-overlap']]
assert len(param_specs) == 2
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
async def test_chunk_document_with_chunk_size_override(self):
"""Test chunk_document with chunk-size parameter override"""
# Arrange
config = {
'id': 'test-chunker',
'chunk_size': 250, # Default chunk size
'chunk_overlap': 15,
'concurrency': 1,
'taskgroup': AsyncMock()
}
processor = Processor(**config)
# Mock message and flow
mock_message = MagicMock()
mock_consumer = MagicMock()
mock_flow = MagicMock()
mock_flow.side_effect = lambda param: {
"chunk-size": 400, # Override chunk size
"chunk-overlap": None # Use default chunk overlap
}.get(param)
# Act
chunk_size, chunk_overlap = await processor.chunk_document(
mock_message, mock_consumer, mock_flow, 250, 15
)
# Assert
assert chunk_size == 400 # Should use overridden value
assert chunk_overlap == 15 # Should use default value
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
async def test_chunk_document_with_chunk_overlap_override(self):
"""Test chunk_document with chunk-overlap parameter override"""
# Arrange
config = {
'id': 'test-chunker',
'chunk_size': 250,
'chunk_overlap': 15, # Default chunk overlap
'concurrency': 1,
'taskgroup': AsyncMock()
}
processor = Processor(**config)
# Mock message and flow
mock_message = MagicMock()
mock_consumer = MagicMock()
mock_flow = MagicMock()
mock_flow.side_effect = lambda param: {
"chunk-size": None, # Use default chunk size
"chunk-overlap": 25 # Override chunk overlap
}.get(param)
# Act
chunk_size, chunk_overlap = await processor.chunk_document(
mock_message, mock_consumer, mock_flow, 250, 15
)
# Assert
assert chunk_size == 250 # Should use default value
assert chunk_overlap == 25 # Should use overridden value
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
async def test_chunk_document_with_both_parameters_override(self):
"""Test chunk_document with both chunk-size and chunk-overlap overrides"""
# Arrange
config = {
'id': 'test-chunker',
'chunk_size': 250,
'chunk_overlap': 15,
'concurrency': 1,
'taskgroup': AsyncMock()
}
processor = Processor(**config)
# Mock message and flow
mock_message = MagicMock()
mock_consumer = MagicMock()
mock_flow = MagicMock()
mock_flow.side_effect = lambda param: {
"chunk-size": 350, # Override chunk size
"chunk-overlap": 30 # Override chunk overlap
}.get(param)
# Act
chunk_size, chunk_overlap = await processor.chunk_document(
mock_message, mock_consumer, mock_flow, 250, 15
)
# Assert
assert chunk_size == 350 # Should use overridden value
assert chunk_overlap == 30 # Should use overridden value
@patch('trustgraph.chunking.token.chunker.TokenTextSplitter')
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
async def test_on_message_uses_flow_parameters(self, mock_splitter_class):
"""Test that on_message method uses parameters from flow"""
# Arrange
mock_splitter = MagicMock()
mock_document = MagicMock()
mock_document.page_content = "Test token chunk content"
mock_splitter.create_documents.return_value = [mock_document]
mock_splitter_class.return_value = mock_splitter
config = {
'id': 'test-chunker',
'chunk_size': 250,
'chunk_overlap': 15,
'concurrency': 1,
'taskgroup': AsyncMock()
}
processor = Processor(**config)
# Mock message with TextDocument
mock_message = MagicMock()
mock_text_doc = MagicMock()
mock_text_doc.metadata = Metadata(
id="test-doc-456",
metadata=[],
user="test-user",
collection="test-collection"
)
mock_text_doc.text = b"This is test document content for token chunking"
mock_message.value.return_value = mock_text_doc
# Mock consumer and flow with parameter overrides
mock_consumer = MagicMock()
mock_producer = AsyncMock()
mock_flow = MagicMock()
mock_flow.side_effect = lambda param: {
"chunk-size": 400,
"chunk-overlap": 40,
"output": mock_producer
}.get(param)
# Act
await processor.on_message(mock_message, mock_consumer, mock_flow)
# Assert
# Verify TokenTextSplitter was called with overridden parameters (last call)
actual_last_call = mock_splitter_class.call_args_list[-1]
assert actual_last_call.kwargs['encoding_name'] == "cl100k_base"
assert actual_last_call.kwargs['chunk_size'] == 400
assert actual_last_call.kwargs['chunk_overlap'] == 40
# Verify chunk was sent to output
mock_producer.send.assert_called_once()
sent_chunk = mock_producer.send.call_args[0][0]
assert isinstance(sent_chunk, Chunk)
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
async def test_chunk_document_with_no_overrides(self):
"""Test chunk_document when no parameters are overridden (flow returns None)"""
# Arrange
config = {
'id': 'test-chunker',
'chunk_size': 250,
'chunk_overlap': 15,
'concurrency': 1,
'taskgroup': AsyncMock()
}
processor = Processor(**config)
# Mock message and flow that returns None for all parameters
mock_message = MagicMock()
mock_consumer = MagicMock()
mock_flow = MagicMock()
mock_flow.return_value = None # No overrides
# Act
chunk_size, chunk_overlap = await processor.chunk_document(
mock_message, mock_consumer, mock_flow, 250, 15
)
# Assert
assert chunk_size == 250 # Should use default value
assert chunk_overlap == 15 # Should use default value
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
def test_token_chunker_uses_different_defaults(self):
"""Test that token chunker has different defaults than recursive chunker"""
# Arrange & Act
config = {
'id': 'test-chunker',
'concurrency': 1,
'taskgroup': AsyncMock()
}
processor = Processor(**config)
# Assert - Token chunker should have different defaults
assert processor.default_chunk_size == 250 # Token chunker default
assert processor.default_chunk_overlap == 15 # Token chunker default
@pytest.fixture
def sample_document():
metadata = Metadata(
id="test-doc-1",
metadata=[],
user="test-user",
collection="test-collection"
)
# Create text that will result in multiple token chunks
text = "The quick brown fox jumps over the lazy dog. " * 50
return TextDocument(
metadata=metadata,
text=text.encode("utf-8")
)
@pytest.fixture
def short_document():
metadata = Metadata(
id="test-doc-2",
metadata=[],
user="test-user",
collection="test-collection"
)
text = "Short text."
return TextDocument(
metadata=metadata,
text=text.encode("utf-8")
)
class TestTokenChunker:
def test_init_default_params(self, mock_async_processor_init):
processor = TokenChunker()
assert processor.text_splitter._chunk_size == 250
assert processor.text_splitter._chunk_overlap == 15
# Just verify the text splitter was created (encoding verification is complex)
assert processor.text_splitter is not None
assert hasattr(processor.text_splitter, 'split_text')
def test_init_custom_params(self, mock_async_processor_init):
processor = TokenChunker(chunk_size=100, chunk_overlap=10)
assert processor.text_splitter._chunk_size == 100
assert processor.text_splitter._chunk_overlap == 10
def test_init_with_id(self, mock_async_processor_init):
processor = TokenChunker(id="custom-token-chunker")
assert processor.id == "custom-token-chunker"
@pytest.mark.asyncio
async def test_on_message_single_chunk(self, mock_async_processor_init, mock_flow, mock_consumer, short_document):
flow_mock, output_mock = mock_flow
processor = TokenChunker(chunk_size=250, chunk_overlap=15)
msg = Mock()
msg.value.return_value = short_document
await processor.on_message(msg, mock_consumer, flow_mock)
# Short text should produce exactly one chunk
assert output_mock.send.call_count == 1
# Verify the chunk was created correctly
chunk_call = output_mock.send.call_args[0][0]
assert isinstance(chunk_call, Chunk)
assert chunk_call.metadata == short_document.metadata
assert chunk_call.chunk.decode("utf-8") == short_document.text.decode("utf-8")
@pytest.mark.asyncio
async def test_on_message_multiple_chunks(self, mock_async_processor_init, mock_flow, mock_consumer, sample_document):
flow_mock, output_mock = mock_flow
processor = TokenChunker(chunk_size=50, chunk_overlap=5)
msg = Mock()
msg.value.return_value = sample_document
await processor.on_message(msg, mock_consumer, flow_mock)
# Should produce multiple chunks
assert output_mock.send.call_count > 1
# Verify all chunks have correct metadata
for call in output_mock.send.call_args_list:
chunk = call[0][0]
assert isinstance(chunk, Chunk)
assert chunk.metadata == sample_document.metadata
assert len(chunk.chunk) > 0
@pytest.mark.asyncio
async def test_on_message_token_overlap(self, mock_async_processor_init, mock_flow, mock_consumer):
flow_mock, output_mock = mock_flow
processor = TokenChunker(chunk_size=20, chunk_overlap=5)
# Create a document with repeated pattern
metadata = Metadata(id="test", metadata=[], user="test-user", collection="test-collection")
text = "one two three four five six seven eight nine ten " * 5
document = TextDocument(metadata=metadata, text=text.encode("utf-8"))
msg = Mock()
msg.value.return_value = document
await processor.on_message(msg, mock_consumer, flow_mock)
# Collect all chunks
chunks = []
for call in output_mock.send.call_args_list:
chunk_text = call[0][0].chunk.decode("utf-8")
chunks.append(chunk_text)
# Should have multiple chunks
assert len(chunks) > 1
# Verify chunks are not empty
for chunk in chunks:
assert len(chunk) > 0
@pytest.mark.asyncio
async def test_on_message_empty_document(self, mock_async_processor_init, mock_flow, mock_consumer):
flow_mock, output_mock = mock_flow
processor = TokenChunker()
metadata = Metadata(id="empty", metadata=[], user="test-user", collection="test-collection")
document = TextDocument(metadata=metadata, text=b"")
msg = Mock()
msg.value.return_value = document
await processor.on_message(msg, mock_consumer, flow_mock)
# Empty documents typically don't produce chunks with langchain splitters
# This behavior is expected - no chunks should be produced
assert output_mock.send.call_count == 0
@pytest.mark.asyncio
async def test_on_message_unicode_handling(self, mock_async_processor_init, mock_flow, mock_consumer):
flow_mock, output_mock = mock_flow
processor = TokenChunker(chunk_size=50)
metadata = Metadata(id="unicode", metadata=[], user="test-user", collection="test-collection")
# Test with various unicode characters
text = "Hello 世界! 🌍 Test émojis café naïve résumé. Greek: αβγδε Hebrew: אבגדה"
document = TextDocument(metadata=metadata, text=text.encode("utf-8"))
msg = Mock()
msg.value.return_value = document
await processor.on_message(msg, mock_consumer, flow_mock)
# Verify unicode is preserved correctly
all_chunks = []
for call in output_mock.send.call_args_list:
chunk_text = call[0][0].chunk.decode("utf-8")
all_chunks.append(chunk_text)
# Reconstruct text
reconstructed = "".join(all_chunks)
assert "世界" in reconstructed
assert "🌍" in reconstructed
assert "émojis" in reconstructed
assert "αβγδε" in reconstructed
assert "אבגדה" in reconstructed
@pytest.mark.asyncio
async def test_on_message_token_boundary_preservation(self, mock_async_processor_init, mock_flow, mock_consumer):
flow_mock, output_mock = mock_flow
processor = TokenChunker(chunk_size=10, chunk_overlap=2)
metadata = Metadata(id="boundary", metadata=[], user="test-user", collection="test-collection")
# Text with clear word boundaries
text = "This is a test of token boundaries and proper splitting."
document = TextDocument(metadata=metadata, text=text.encode("utf-8"))
msg = Mock()
msg.value.return_value = document
await processor.on_message(msg, mock_consumer, flow_mock)
# Collect all chunks
chunks = []
for call in output_mock.send.call_args_list:
chunk_text = call[0][0].chunk.decode("utf-8")
chunks.append(chunk_text)
# Token chunker should respect token boundaries
for chunk in chunks:
# Chunks should not start or end with partial words (in most cases)
assert len(chunk.strip()) > 0
@pytest.mark.asyncio
async def test_metrics_recorded(self, mock_async_processor_init, mock_flow, mock_consumer, sample_document):
flow_mock, output_mock = mock_flow
processor = TokenChunker(chunk_size=50)
msg = Mock()
msg.value.return_value = sample_document
# Mock the metric
with patch.object(TokenChunker.chunk_metric, 'labels') as mock_labels:
mock_observe = Mock()
mock_labels.return_value.observe = mock_observe
await processor.on_message(msg, mock_consumer, flow_mock)
# Verify metrics were recorded
mock_labels.assert_called_with(id="test-consumer", flow="test-flow")
assert mock_observe.call_count > 0
# Verify chunk sizes were observed
for call in mock_observe.call_args_list:
chunk_size = call[0][0]
assert chunk_size > 0
def test_add_args(self):
parser = Mock()
TokenChunker.add_args(parser)
# Verify arguments were added
calls = parser.add_argument.call_args_list
arg_names = [call[0][0] for call in calls]
assert '-z' in arg_names or '--chunk-size' in arg_names
assert '-v' in arg_names or '--chunk-overlap' in arg_names
@pytest.mark.asyncio
async def test_encoding_specific_behavior(self, mock_async_processor_init, mock_flow, mock_consumer):
flow_mock, output_mock = mock_flow
processor = TokenChunker(chunk_size=10, chunk_overlap=0)
metadata = Metadata(id="encoding", metadata=[], user="test-user", collection="test-collection")
# Test text that might tokenize differently with cl100k_base encoding
text = "GPT-4 is an AI model. It uses tokens."
document = TextDocument(metadata=metadata, text=text.encode("utf-8"))
msg = Mock()
msg.value.return_value = document
await processor.on_message(msg, mock_consumer, flow_mock)
# Verify chunking happened
assert output_mock.send.call_count >= 1
# Collect all chunks
chunks = []
for call in output_mock.send.call_args_list:
chunk_text = call[0][0].chunk.decode("utf-8")
chunks.append(chunk_text)
# Verify all text is preserved (allowing for overlap)
all_text = " ".join(chunks)
assert "GPT-4" in all_text
assert "AI model" in all_text
assert "tokens" in all_text
if __name__ == '__main__':
pytest.main([__file__])
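The token chunker tests also fix the splitter construction: per the keyword arguments asserted in test_on_message_uses_flow_parameters, the splitter is rebuilt per message from the resolved chunk size and overlap using the cl100k_base encoding. A sketch, with the import path as an assumption rather than TrustGraph's actual one:

from langchain_text_splitters import TokenTextSplitter  # assumed import path

def build_token_splitter(chunk_size: int, chunk_overlap: int) -> TokenTextSplitter:
    # Keyword arguments mirror the last call asserted in the test above.
    return TokenTextSplitter(
        encoding_name="cl100k_base",
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
    )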


@@ -34,7 +34,9 @@ class TestGraphRag:
assert graph_rag.graph_embeddings_client == mock_graph_embeddings_client
assert graph_rag.triples_client == mock_triples_client
assert graph_rag.verbose is False # Default value
assert graph_rag.label_cache == {} # Empty cache initially
# Verify label_cache is an LRUCacheWithTTL instance
from trustgraph.retrieval.graph_rag.graph_rag import LRUCacheWithTTL
assert isinstance(graph_rag.label_cache, LRUCacheWithTTL)
def test_graph_rag_initialization_with_verbose(self):
"""Test GraphRag initialization with verbose enabled"""
@@ -59,7 +61,9 @@ class TestGraphRag:
assert graph_rag.graph_embeddings_client == mock_graph_embeddings_client
assert graph_rag.triples_client == mock_triples_client
assert graph_rag.verbose is True
assert graph_rag.label_cache == {} # Empty cache initially
# Verify label_cache is an LRUCacheWithTTL instance
from trustgraph.retrieval.graph_rag.graph_rag import LRUCacheWithTTL
assert isinstance(graph_rag.label_cache, LRUCacheWithTTL)
class TestQuery:
@@ -228,8 +232,11 @@ class TestQuery:
"""Test Query.maybe_label method with cached label"""
# Create mock GraphRag with label cache
mock_rag = MagicMock()
mock_rag.label_cache = {"entity1": "Entity One Label"}
# Create mock LRUCacheWithTTL
mock_cache = MagicMock()
mock_cache.get.return_value = "Entity One Label"
mock_rag.label_cache = mock_cache
# Initialize Query
query = Query(
rag=mock_rag,
@@ -237,27 +244,32 @@ class TestQuery:
collection="test_collection",
verbose=False
)
# Call maybe_label with cached entity
result = await query.maybe_label("entity1")
# Verify cached label is returned
assert result == "Entity One Label"
# Verify cache was checked with proper key format (user:collection:entity)
mock_cache.get.assert_called_once_with("test_user:test_collection:entity1")
@pytest.mark.asyncio
async def test_maybe_label_with_label_lookup(self):
"""Test Query.maybe_label method with database label lookup"""
# Create mock GraphRag with triples client
mock_rag = MagicMock()
mock_rag.label_cache = {} # Empty cache
# Create mock LRUCacheWithTTL that returns None (cache miss)
mock_cache = MagicMock()
mock_cache.get.return_value = None
mock_rag.label_cache = mock_cache
mock_triples_client = AsyncMock()
mock_rag.triples_client = mock_triples_client
# Mock triple result with label
mock_triple = MagicMock()
mock_triple.o = "Human Readable Label"
mock_triples_client.query.return_value = [mock_triple]
# Initialize Query
query = Query(
rag=mock_rag,
@@ -265,10 +277,10 @@ class TestQuery:
collection="test_collection",
verbose=False
)
# Call maybe_label
result = await query.maybe_label("http://example.com/entity")
# Verify triples client was called correctly
mock_triples_client.query.assert_called_once_with(
s="http://example.com/entity",
@@ -278,17 +290,21 @@ class TestQuery:
user="test_user",
collection="test_collection"
)
# Verify result and cache update
# Verify result and cache update with proper key
assert result == "Human Readable Label"
assert mock_rag.label_cache["http://example.com/entity"] == "Human Readable Label"
cache_key = "test_user:test_collection:http://example.com/entity"
mock_cache.put.assert_called_once_with(cache_key, "Human Readable Label")
@pytest.mark.asyncio
async def test_maybe_label_with_no_label_found(self):
"""Test Query.maybe_label method when no label is found"""
# Create mock GraphRag with triples client
mock_rag = MagicMock()
mock_rag.label_cache = {} # Empty cache
# Create mock LRUCacheWithTTL that returns None (cache miss)
mock_cache = MagicMock()
mock_cache.get.return_value = None
mock_rag.label_cache = mock_cache
mock_triples_client = AsyncMock()
mock_rag.triples_client = mock_triples_client
@@ -318,7 +334,8 @@ class TestQuery:
# Verify result is entity itself and cache is updated
assert result == "unlabeled_entity"
assert mock_rag.label_cache["unlabeled_entity"] == "unlabeled_entity"
cache_key = "test_user:test_collection:unlabeled_entity"
mock_cache.put.assert_called_once_with(cache_key, "unlabeled_entity")
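# Illustrative sketch (not code from this commit) of the maybe_label behaviour
# these tests describe: labels live in an LRUCacheWithTTL keyed
# "user:collection:entity"; on a miss the triples store is queried and either
# the found label or the entity itself is cached. The predicate constant and
# exact query arguments are assumptions, not taken from the diff.
RDFS_LABEL = "http://www.w3.org/2000/01/rdf-schema#label"  # assumed predicate
async def maybe_label_sketch(rag, user, collection, entity):
    key = f"{user}:{collection}:{entity}"
    cached = rag.label_cache.get(key)
    if cached is not None:
        return cached
    triples = await rag.triples_client.query(
        s=entity, p=RDFS_LABEL, o=None, user=user, collection=collection,
    )
    label = triples[0].o if triples else entity
    rag.label_cache.put(key, label)
    return label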
@pytest.mark.asyncio
async def test_follow_edges_basic_functionality(self):
@@ -441,40 +458,40 @@ class TestQuery:
@pytest.mark.asyncio
async def test_get_subgraph_method(self):
"""Test Query.get_subgraph method orchestrates entity and edge discovery"""
# Create mock Query that patches get_entities and follow_edges
# Create mock Query that patches get_entities and follow_edges_batch
mock_rag = MagicMock()
query = Query(
rag=mock_rag,
user="test_user",
user="test_user",
collection="test_collection",
verbose=False,
max_path_length=1
)
# Mock get_entities to return test entities
query.get_entities = AsyncMock(return_value=["entity1", "entity2"])
# Mock follow_edges to add triples to subgraph
async def mock_follow_edges(ent, subgraph, path_length):
subgraph.add((ent, "predicate", "object"))
query.follow_edges = AsyncMock(side_effect=mock_follow_edges)
# Mock follow_edges_batch to return test triples
query.follow_edges_batch = AsyncMock(return_value={
("entity1", "predicate1", "object1"),
("entity2", "predicate2", "object2")
})
# Call get_subgraph
result = await query.get_subgraph("test query")
# Verify get_entities was called
query.get_entities.assert_called_once_with("test query")
# Verify follow_edges was called for each entity
assert query.follow_edges.call_count == 2
query.follow_edges.assert_any_call("entity1", unittest.mock.ANY, 1)
query.follow_edges.assert_any_call("entity2", unittest.mock.ANY, 1)
# Verify result is list format
# Verify follow_edges_batch was called with entities and max_path_length
query.follow_edges_batch.assert_called_once_with(["entity1", "entity2"], 1)
# Verify result is list format and contains expected triples
assert isinstance(result, list)
assert len(result) == 2
assert ("entity1", "predicate1", "object1") in result
assert ("entity2", "predicate2", "object2") in result
@pytest.mark.asyncio
async def test_get_labelgraph_method(self):


@@ -178,37 +178,24 @@ class TestPineconeDocEmbeddingsStorageProcessor:
assert calls[2][1]['vectors'][0]['metadata']['doc'] == "This is the second document chunk"
@pytest.mark.asyncio
async def test_store_document_embeddings_index_creation(self, processor):
"""Test automatic index creation when index doesn't exist"""
async def test_store_document_embeddings_index_validation(self, processor):
"""Test that writing to non-existent index raises ValueError"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
chunk = ChunkEmbeddings(
chunk=b"Test document content",
vectors=[[0.1, 0.2, 0.3]]
)
message.chunks = [chunk]
# Mock index doesn't exist initially
# Mock index doesn't exist
processor.pinecone.has_index.return_value = False
mock_index = MagicMock()
processor.pinecone.Index.return_value = mock_index
# Mock index creation
processor.pinecone.describe_index.return_value.status = {"ready": True}
with patch('uuid.uuid4', return_value='test-id'):
with pytest.raises(ValueError, match="Collection .* does not exist"):
await processor.store_document_embeddings(message)
# Verify index creation was called
expected_index_name = "d-test_user-test_collection"
processor.pinecone.create_index.assert_called_once()
create_call = processor.pinecone.create_index.call_args
assert create_call[1]['name'] == expected_index_name
assert create_call[1]['dimension'] == 3
assert create_call[1]['metric'] == "cosine"
@pytest.mark.asyncio
async def test_store_document_embeddings_empty_chunk(self, processor):
@@ -357,47 +344,44 @@ class TestPineconeDocEmbeddingsStorageProcessor:
mock_index.upsert.assert_not_called()
@pytest.mark.asyncio
async def test_store_document_embeddings_index_creation_failure(self, processor):
"""Test handling of index creation failure"""
async def test_store_document_embeddings_validation_before_creation(self, processor):
"""Test that validation error occurs before creation attempts"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
chunk = ChunkEmbeddings(
chunk=b"Test document content",
vectors=[[0.1, 0.2, 0.3]]
)
message.chunks = [chunk]
# Mock index doesn't exist and creation fails
# Mock index doesn't exist
processor.pinecone.has_index.return_value = False
processor.pinecone.create_index.side_effect = Exception("Index creation failed")
with pytest.raises(Exception, match="Index creation failed"):
with pytest.raises(ValueError, match="Collection .* does not exist"):
await processor.store_document_embeddings(message)
@pytest.mark.asyncio
async def test_store_document_embeddings_index_creation_timeout(self, processor):
"""Test handling of index creation timeout"""
async def test_store_document_embeddings_validates_before_timeout(self, processor):
"""Test that validation error occurs before timeout checks"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
chunk = ChunkEmbeddings(
chunk=b"Test document content",
vectors=[[0.1, 0.2, 0.3]]
)
message.chunks = [chunk]
# Mock index doesn't exist and never becomes ready
# Mock index doesn't exist
processor.pinecone.has_index.return_value = False
processor.pinecone.describe_index.return_value.status = {"ready": False}
with patch('time.sleep'): # Speed up the test
with pytest.raises(RuntimeError, match="Gave up waiting for index creation"):
await processor.store_document_embeddings(message)
with pytest.raises(ValueError, match="Collection .* does not exist"):
await processor.store_document_embeddings(message)
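# Sketch of the fail-fast validation these updated tests expect (illustrative
# only, not the commit's code): the store no longer auto-creates missing
# indexes, it raises ValueError up front. Index-name format follows the
# earlier assertion ("d-<user>-<collection>").
def require_index_sketch(pinecone, user, collection):
    index_name = f"d-{user}-{collection}"
    if not pinecone.has_index(index_name):
        raise ValueError(f"Collection {collection} does not exist")
    return pinecone.Index(index_name)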
@pytest.mark.asyncio
async def test_store_document_embeddings_unicode_content(self, processor):


@@ -43,8 +43,6 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
# Verify processor attributes
assert hasattr(processor, 'qdrant')
assert processor.qdrant == mock_qdrant_instance
assert hasattr(processor, 'last_collection')
assert processor.last_collection is None
@patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient')
@patch('trustgraph.base.DocumentEmbeddingsStoreService.__init__')
@@ -245,8 +243,9 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
# Arrange
mock_base_init.return_value = None
mock_qdrant_instance = MagicMock()
mock_qdrant_instance.collection_exists.return_value = True # Collection exists
mock_qdrant_client.return_value = mock_qdrant_instance
config = {
'store_uri': 'http://localhost:6333',
'api_key': 'test-api-key',
@@ -255,36 +254,37 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
}
processor = Processor(**config)
# Create mock message with empty chunk
mock_message = MagicMock()
mock_message.metadata.user = 'empty_user'
mock_message.metadata.collection = 'empty_collection'
mock_chunk_empty = MagicMock()
mock_chunk_empty.chunk.decode.return_value = "" # Empty string
mock_chunk_empty.vectors = [[0.1, 0.2]]
mock_message.chunks = [mock_chunk_empty]
# Act
await processor.store_document_embeddings(mock_message)
# Assert
# Should not call upsert for empty chunks
mock_qdrant_instance.upsert.assert_not_called()
mock_qdrant_instance.collection_exists.assert_not_called()
# But collection_exists should be called for validation
mock_qdrant_instance.collection_exists.assert_called_once()
@patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient')
@patch('trustgraph.base.DocumentEmbeddingsStoreService.__init__')
async def test_collection_creation_when_not_exists(self, mock_base_init, mock_qdrant_client):
"""Test collection creation when it doesn't exist"""
"""Test that writing to non-existent collection raises ValueError"""
# Arrange
mock_base_init.return_value = None
mock_qdrant_instance = MagicMock()
mock_qdrant_instance.collection_exists.return_value = False # Collection doesn't exist
mock_qdrant_client.return_value = mock_qdrant_instance
config = {
'store_uri': 'http://localhost:6333',
'api_key': 'test-api-key',
@@ -293,46 +293,32 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
}
processor = Processor(**config)
# Create mock message
mock_message = MagicMock()
mock_message.metadata.user = 'new_user'
mock_message.metadata.collection = 'new_collection'
mock_chunk = MagicMock()
mock_chunk.chunk.decode.return_value = 'test chunk'
mock_chunk.vectors = [[0.1, 0.2, 0.3, 0.4, 0.5]] # 5 dimensions
mock_message.chunks = [mock_chunk]
# Act
await processor.store_document_embeddings(mock_message)
# Assert
expected_collection = 'd_new_user_new_collection'
# Verify collection existence check and creation
mock_qdrant_instance.collection_exists.assert_called_once_with(expected_collection)
mock_qdrant_instance.create_collection.assert_called_once()
# Verify create_collection was called with correct parameters
create_call_args = mock_qdrant_instance.create_collection.call_args
assert create_call_args[1]['collection_name'] == expected_collection
# Verify upsert was still called after collection creation
mock_qdrant_instance.upsert.assert_called_once()
mock_message.chunks = [mock_chunk]
# Act & Assert
with pytest.raises(ValueError, match="Collection .* does not exist"):
await processor.store_document_embeddings(mock_message)
@patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient')
@patch('trustgraph.base.DocumentEmbeddingsStoreService.__init__')
async def test_collection_creation_exception(self, mock_base_init, mock_qdrant_client):
"""Test collection creation handles exceptions"""
"""Test that validation error occurs before connection errors"""
# Arrange
mock_base_init.return_value = None
mock_qdrant_instance = MagicMock()
mock_qdrant_instance.create_collection.side_effect = Exception("Qdrant connection failed")
mock_qdrant_instance.collection_exists.return_value = False # Collection doesn't exist
mock_qdrant_client.return_value = mock_qdrant_instance
config = {
'store_uri': 'http://localhost:6333',
'api_key': 'test-api-key',
@ -341,32 +327,35 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
}
processor = Processor(**config)
# Create mock message
mock_message = MagicMock()
mock_message.metadata.user = 'error_user'
mock_message.metadata.collection = 'error_collection'
mock_chunk = MagicMock()
mock_chunk.chunk.decode.return_value = 'test chunk'
mock_chunk.vectors = [[0.1, 0.2]]
mock_message.chunks = [mock_chunk]
# Act & Assert
with pytest.raises(Exception, match="Qdrant connection failed"):
with pytest.raises(ValueError, match="Collection .* does not exist"):
await processor.store_document_embeddings(mock_message)
@patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient')
@patch('trustgraph.base.DocumentEmbeddingsStoreService.__init__')
@patch('trustgraph.storage.doc_embeddings.qdrant.write.uuid')
async def test_collection_validation_on_write(self, mock_uuid, mock_base_init, mock_qdrant_client):
"""Test collection validation checks collection exists before writing"""
# Arrange
mock_base_init.return_value = None
mock_qdrant_instance = MagicMock()
mock_qdrant_instance.collection_exists.return_value = True
mock_qdrant_client.return_value = mock_qdrant_instance
mock_uuid.uuid4.return_value = MagicMock()
mock_uuid.uuid4.return_value.__str__ = MagicMock(return_value='test-uuid')
config = {
'store_uri': 'http://localhost:6333',
'api_key': 'test-api-key',
@ -375,46 +364,45 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
}
processor = Processor(**config)
# Create first mock message
mock_message1 = MagicMock()
mock_message1.metadata.user = 'cache_user'
mock_message1.metadata.collection = 'cache_collection'
mock_chunk1 = MagicMock()
mock_chunk1.chunk.decode.return_value = 'first chunk'
mock_chunk1.vectors = [[0.1, 0.2, 0.3]]
mock_message1.chunks = [mock_chunk1]
# First call
await processor.store_document_embeddings(mock_message1)
# Reset mock to track second call
mock_qdrant_instance.reset_mock()
mock_qdrant_instance.collection_exists.return_value = True
# Create second mock message with same dimensions
mock_message2 = MagicMock()
mock_message2.metadata.user = 'cache_user'
mock_message2.metadata.collection = 'cache_collection'
mock_chunk2 = MagicMock()
mock_chunk2.chunk.decode.return_value = 'second chunk'
mock_chunk2.vectors = [[0.4, 0.5, 0.6]] # Same dimension (3)
mock_message2.chunks = [mock_chunk2]
# Act - Second call with same collection
await processor.store_document_embeddings(mock_message2)
# Assert
expected_collection = 'd_cache_user_cache_collection'
mock_qdrant_instance.create_collection.assert_not_called()
# Verify collection existence is checked on each write
mock_qdrant_instance.collection_exists.assert_called_once_with(expected_collection)
# But upsert should still be called
mock_qdrant_instance.upsert.assert_called_once()

View file

@ -178,37 +178,24 @@ class TestPineconeGraphEmbeddingsStorageProcessor:
assert calls[2][1]['vectors'][0]['metadata']['entity'] == "entity2"
@pytest.mark.asyncio
async def test_store_graph_embeddings_index_validation(self, processor):
"""Test that writing to non-existent index raises ValueError"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
entity = EntityEmbeddings(
entity=Value(value="test_entity", is_uri=False),
vectors=[[0.1, 0.2, 0.3]]
)
message.entities = [entity]
# Mock index doesn't exist
processor.pinecone.has_index.return_value = False
with pytest.raises(ValueError, match="Collection .* does not exist"):
await processor.store_graph_embeddings(message)
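# --- Illustrative sketch, not the production trustgraph code ---
# Rough shape of the Pinecone-side check these tests now assert: the index
# name "t-{user}-{collection}" must already exist, otherwise a ValueError is
# raised instead of creating the index and polling for readiness. Names and
# structure are assumptions based on the assertions above.
class IndexValidationSketch:

    def __init__(self, pinecone_client):
        self.pinecone = pinecone_client

    async def store_graph_embeddings(self, message):
        index_name = f"t-{message.metadata.user}-{message.metadata.collection}"
        if not self.pinecone.has_index(index_name):
            raise ValueError(f"Collection {index_name} does not exist")
        index = self.pinecone.Index(index_name)
        # upsert of entity vectors into the index would follow here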
@pytest.mark.asyncio
async def test_store_graph_embeddings_empty_entity_value(self, processor):
@ -328,47 +315,44 @@ class TestPineconeGraphEmbeddingsStorageProcessor:
mock_index.upsert.assert_not_called()
@pytest.mark.asyncio
async def test_store_graph_embeddings_validation_before_creation(self, processor):
"""Test that validation error occurs before any creation attempts"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
entity = EntityEmbeddings(
entity=Value(value="test_entity", is_uri=False),
vectors=[[0.1, 0.2, 0.3]]
)
message.entities = [entity]
# Mock index doesn't exist
processor.pinecone.has_index.return_value = False
processor.pinecone.create_index.side_effect = Exception("Index creation failed")
with pytest.raises(Exception, match="Index creation failed"):
with pytest.raises(ValueError, match="Collection .* does not exist"):
await processor.store_graph_embeddings(message)
@pytest.mark.asyncio
async def test_store_graph_embeddings_validates_before_timeout(self, processor):
"""Test that validation error occurs before timeout checks"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
entity = EntityEmbeddings(
entity=Value(value="test_entity", is_uri=False),
vectors=[[0.1, 0.2, 0.3]]
)
message.entities = [entity]
# Mock index doesn't exist
processor.pinecone.has_index.return_value = False
processor.pinecone.describe_index.return_value.status = {"ready": False}
with pytest.raises(ValueError, match="Collection .* does not exist"):
await processor.store_graph_embeddings(message)
def test_add_args_method(self):
"""Test that add_args properly configures argument parser"""

View file

@ -43,19 +43,17 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase):
# Verify processor attributes
assert hasattr(processor, 'qdrant')
assert processor.qdrant == mock_qdrant_instance
@patch('trustgraph.storage.graph_embeddings.qdrant.write.QdrantClient')
@patch('trustgraph.base.GraphEmbeddingsStoreService.__init__')
async def test_get_collection_validates_existence(self, mock_base_init, mock_qdrant_client):
"""Test get_collection validates that collection exists"""
# Arrange
mock_base_init.return_value = None
mock_qdrant_instance = MagicMock()
mock_qdrant_instance.collection_exists.return_value = False
mock_qdrant_client.return_value = mock_qdrant_instance
config = {
'store_uri': 'http://localhost:6333',
'api_key': 'test-api-key',
@ -64,22 +62,10 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase):
}
processor = Processor(**config)
# Act & Assert
with pytest.raises(ValueError, match="Collection .* does not exist"):
processor.get_collection(user='test_user', collection='test_collection')
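# --- Illustrative sketch, not the production trustgraph code ---
# Plausible post-change shape of get_collection: it only derives and validates
# the "t_{user}_{collection}" name; the dim argument and the create path are
# gone. The exact signature is an assumption based on these tests.
def get_collection_sketch(qdrant, user, collection):
    name = f"t_{user}_{collection}"
    if not qdrant.collection_exists(name):
        raise ValueError(f"Collection {name} does not exist")
    return name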
@patch('trustgraph.storage.graph_embeddings.qdrant.write.QdrantClient')
@patch('trustgraph.storage.graph_embeddings.qdrant.write.uuid')
@ -142,7 +128,7 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase):
mock_qdrant_instance = MagicMock()
mock_qdrant_instance.collection_exists.return_value = True # Collection exists
mock_qdrant_client.return_value = mock_qdrant_instance
config = {
'store_uri': 'http://localhost:6333',
'api_key': 'test-api-key',
@ -151,15 +137,14 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase):
}
processor = Processor(**config)
# Act
collection_name = processor.get_collection(user='existing_user', collection='existing_collection')
# Assert
expected_name = 't_existing_user_existing_collection'
assert collection_name == expected_name
# Verify collection existence check was performed
mock_qdrant_instance.collection_exists.assert_called_once_with(expected_name)
# Verify create_collection was NOT called
@ -167,14 +152,14 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase):
@patch('trustgraph.storage.graph_embeddings.qdrant.write.QdrantClient')
@patch('trustgraph.base.GraphEmbeddingsStoreService.__init__')
async def test_get_collection_validates_on_each_call(self, mock_base_init, mock_qdrant_client):
"""Test get_collection validates collection existence on each call"""
# Arrange
mock_base_init.return_value = None
mock_qdrant_instance = MagicMock()
mock_qdrant_instance.collection_exists.return_value = True
mock_qdrant_client.return_value = mock_qdrant_instance
config = {
'store_uri': 'http://localhost:6333',
'api_key': 'test-api-key',
@ -183,36 +168,36 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase):
}
processor = Processor(**config)
# First call
collection_name1 = processor.get_collection(user='cache_user', collection='cache_collection')
# Reset mock to track second call
mock_qdrant_instance.reset_mock()
mock_qdrant_instance.collection_exists.return_value = True
# Act - Second call with same parameters
collection_name2 = processor.get_collection(user='cache_user', collection='cache_collection')
# Assert
expected_name = 't_cache_user_cache_collection'
assert collection_name1 == expected_name
assert collection_name2 == expected_name
# Verify collection existence check happens on each call
mock_qdrant_instance.collection_exists.assert_called_once_with(expected_name)
mock_qdrant_instance.create_collection.assert_not_called()
@patch('trustgraph.storage.graph_embeddings.qdrant.write.QdrantClient')
@patch('trustgraph.base.GraphEmbeddingsStoreService.__init__')
async def test_get_collection_creation_exception(self, mock_base_init, mock_qdrant_client):
"""Test get_collection handles collection creation exceptions"""
"""Test get_collection raises ValueError when collection doesn't exist"""
# Arrange
mock_base_init.return_value = None
mock_qdrant_instance = MagicMock()
mock_qdrant_instance.collection_exists.return_value = False
mock_qdrant_instance.create_collection.side_effect = Exception("Qdrant connection failed")
mock_qdrant_client.return_value = mock_qdrant_instance
config = {
'store_uri': 'http://localhost:6333',
'api_key': 'test-api-key',
@ -221,10 +206,10 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase):
}
processor = Processor(**config)
# Act & Assert
with pytest.raises(Exception, match="Qdrant connection failed"):
processor.get_collection(dim=512, user='error_user', collection='error_collection')
with pytest.raises(ValueError, match="Collection .* does not exist"):
processor.get_collection(user='error_user', collection='error_collection')
@patch('trustgraph.storage.graph_embeddings.qdrant.write.QdrantClient')
@patch('trustgraph.storage.graph_embeddings.qdrant.write.uuid')

View file

@ -47,7 +47,7 @@ class TestMemgraphUserCollectionIsolation:
mock_graph_db.driver.return_value = mock_driver
mock_session = MagicMock()
mock_driver.session.return_value.__enter__.return_value = mock_session
# Mock execute_query response
mock_result = MagicMock()
mock_summary = MagicMock()
@ -55,28 +55,30 @@ class TestMemgraphUserCollectionIsolation:
mock_summary.result_available_after = 10
mock_result.summary = mock_summary
mock_driver.execute_query.return_value = mock_result
processor = Processor(taskgroup=MagicMock())
# Create mock triple with URI object
triple = MagicMock()
triple.s.value = "http://example.com/subject"
triple.p.value = "http://example.com/predicate"
triple.o.value = "http://example.com/object"
triple.o.is_uri = True
# Create mock message with metadata
mock_message = MagicMock()
mock_message.triples = [triple]
mock_message.metadata.user = "test_user"
mock_message.metadata.collection = "test_collection"
# Mock collection_exists to bypass validation in unit tests
with patch.object(processor, 'collection_exists', return_value=True):
await processor.store_triples(mock_message)
# Verify user/collection parameters were passed to all operations
# Should have: create_node (subject), create_node (object), relate_node = 3 calls
assert mock_driver.execute_query.call_count == 3
# Check that user and collection were included in all calls
for call in mock_driver.execute_query.call_args_list:
call_kwargs = call.kwargs if hasattr(call, 'kwargs') else call[1]
@ -93,7 +95,7 @@ class TestMemgraphUserCollectionIsolation:
mock_graph_db.driver.return_value = mock_driver
mock_session = MagicMock()
mock_driver.session.return_value.__enter__.return_value = mock_session
# Mock execute_query response
mock_result = MagicMock()
mock_summary = MagicMock()
@ -101,24 +103,26 @@ class TestMemgraphUserCollectionIsolation:
mock_summary.result_available_after = 10
mock_result.summary = mock_summary
mock_driver.execute_query.return_value = mock_result
processor = Processor(taskgroup=MagicMock())
# Create mock triple
triple = MagicMock()
triple.s.value = "http://example.com/subject"
triple.p.value = "http://example.com/predicate"
triple.o.value = "literal_value"
triple.o.is_uri = False
# Create mock message without user/collection metadata
mock_message = MagicMock()
mock_message.triples = [triple]
mock_message.metadata.user = None
mock_message.metadata.collection = None
# Mock collection_exists to bypass validation in unit tests
with patch.object(processor, 'collection_exists', return_value=True):
await processor.store_triples(mock_message)
# Verify defaults were used
for call in mock_driver.execute_query.call_args_list:
call_kwargs = call.kwargs if hasattr(call, 'kwargs') else call[1]
@ -295,7 +299,7 @@ class TestMemgraphUserCollectionRegression:
mock_graph_db.driver.return_value = mock_driver
mock_session = MagicMock()
mock_driver.session.return_value.__enter__.return_value = mock_session
# Mock execute_query response
mock_result = MagicMock()
mock_summary = MagicMock()
@ -303,23 +307,25 @@ class TestMemgraphUserCollectionRegression:
mock_summary.result_available_after = 10
mock_result.summary = mock_summary
mock_driver.execute_query.return_value = mock_result
processor = Processor(taskgroup=MagicMock())
# Store data for user1
triple = MagicMock()
triple.s.value = "http://example.com/subject"
triple.p.value = "http://example.com/predicate"
triple.o.value = "user1_data"
triple.o.is_uri = False
message_user1 = MagicMock()
message_user1.triples = [triple]
message_user1.metadata.user = "user1"
message_user1.metadata.collection = "collection1"
# Mock collection_exists to bypass validation in unit tests
with patch.object(processor, 'collection_exists', return_value=True):
await processor.store_triples(message_user1)
# Verify that all storage operations included user1/collection1 parameters
for call in mock_driver.execute_query.call_args_list:
call_kwargs = call.kwargs if hasattr(call, 'kwargs') else call[1]

View file

@ -75,8 +75,10 @@ class TestNeo4jUserCollectionIsolation:
mock_summary.counters.nodes_created = 1
mock_summary.result_available_after = 10
mock_driver.execute_query.return_value.summary = mock_summary
# Mock collection_exists to bypass validation in unit tests
with patch.object(processor, 'collection_exists', return_value=True):
await processor.store_triples(message)
# Verify nodes and relationships were created with user/collection properties
expected_calls = [
@ -141,8 +143,10 @@ class TestNeo4jUserCollectionIsolation:
mock_summary.counters.nodes_created = 1
mock_summary.result_available_after = 10
mock_driver.execute_query.return_value.summary = mock_summary
# Mock collection_exists to bypass validation in unit tests
with patch.object(processor, 'collection_exists', return_value=True):
await processor.store_triples(message)
# Verify defaults were used
mock_driver.execute_query.assert_any_call(
@ -273,10 +277,12 @@ class TestNeo4jUserCollectionIsolation:
mock_summary.counters.nodes_created = 1
mock_summary.result_available_after = 10
mock_driver.execute_query.return_value.summary = mock_summary
# Mock collection_exists to bypass validation in unit tests
with patch.object(processor, 'collection_exists', return_value=True):
# Store data for both users
await processor.store_triples(message_user1)
await processor.store_triples(message_user2)
# Verify user1 data was stored with user1/coll1
mock_driver.execute_query.assert_any_call(
@ -446,9 +452,11 @@ class TestNeo4jUserCollectionRegression:
mock_summary.counters.nodes_created = 1
mock_summary.result_available_after = 10
mock_driver.execute_query.return_value.summary = mock_summary
# Mock collection_exists to bypass validation in unit tests
with patch.object(processor, 'collection_exists', return_value=True):
await processor.store_triples(message_user1)
await processor.store_triples(message_user2)
# Verify two separate nodes were created with same URI but different user/collection
user1_node_call = call(

View file

@ -251,6 +251,8 @@ class TestObjectsCassandraStorageLogic:
processor.convert_value = Processor.convert_value.__get__(processor, Processor)
processor.session = MagicMock()
processor.on_object = Processor.on_object.__get__(processor, Processor)
processor.known_keyspaces = {"test_user"} # Pre-populate to skip validation query
processor.known_tables = {"test_user": set()} # Pre-populate
# Create test object
test_obj = ExtractedObject(
@ -291,18 +293,19 @@ class TestObjectsCassandraStorageLogic:
"""Test that secondary indexes are created for indexed fields"""
processor = MagicMock()
processor.schemas = {}
processor.known_keyspaces = {"test_user"} # Pre-populate to skip validation query
processor.known_tables = {"test_user": set()} # Pre-populate
processor.session = MagicMock()
processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
processor.sanitize_table = Processor.sanitize_table.__get__(processor, Processor)
processor.get_cassandra_type = Processor.get_cassandra_type.__get__(processor, Processor)
def mock_ensure_keyspace(keyspace):
processor.known_keyspaces.add(keyspace)
if keyspace not in processor.known_tables:
processor.known_tables[keyspace] = set()
processor.ensure_keyspace = mock_ensure_keyspace
processor.ensure_table = Processor.ensure_table.__get__(processor, Processor)
# Create schema with indexed field
schema = RowSchema(
name="products",
@ -313,10 +316,10 @@ class TestObjectsCassandraStorageLogic:
Field(name="price", type="float", size=8, indexed=True)
]
)
# Call ensure_table
processor.ensure_table("test_user", "products", schema)
# Should have 3 calls: create table + 2 indexes
assert processor.session.execute.call_count == 3
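# --- Illustrative sketch, not the production trustgraph code ---
# What the pre-populated known_keyspaces / known_tables caches stand in for:
# a keyspace is validated at most once per process and then remembered, so
# tests seed the caches to avoid the validation query. Structure and method
# names are assumptions based on the mocks above.
class KeyspaceCacheSketch:

    def __init__(self, session):
        self.session = session
        self.known_keyspaces = set()
        self.known_tables = {}

    def ensure_keyspace(self, keyspace):
        if keyspace in self.known_keyspaces:
            return  # cached: no validation query is issued again
        # a real implementation would validate or create the keyspace here
        self.known_keyspaces.add(keyspace)
        if keyspace not in self.known_tables:
            self.known_tables[keyspace] = set()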
@ -346,9 +349,10 @@ class TestObjectsCassandraStorageBatchLogic:
]
)
}
processor.known_keyspaces = {"test_user"} # Pre-populate to skip validation query
processor.ensure_table = MagicMock()
processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
processor.sanitize_table = Processor.sanitize_table.__get__(processor, Processor)
processor.convert_value = Processor.convert_value.__get__(processor, Processor)
processor.session = MagicMock()
processor.on_object = Processor.on_object.__get__(processor, Processor)
@ -415,6 +419,8 @@ class TestObjectsCassandraStorageBatchLogic:
processor.convert_value = Processor.convert_value.__get__(processor, Processor)
processor.session = MagicMock()
processor.on_object = Processor.on_object.__get__(processor, Processor)
processor.known_keyspaces = {"test_user"} # Pre-populate to skip validation query
processor.known_tables = {"test_user": set()} # Pre-populate
# Create empty batch object
empty_batch_obj = ExtractedObject(
@ -461,6 +467,8 @@ class TestObjectsCassandraStorageBatchLogic:
processor.convert_value = Processor.convert_value.__get__(processor, Processor)
processor.session = MagicMock()
processor.on_object = Processor.on_object.__get__(processor, Processor)
processor.known_keyspaces = {"test_user"} # Pre-populate to skip validation query
processor.known_tables = {"test_user": set()} # Pre-populate
# Create single-item batch object (backward compatibility case)
single_batch_obj = ExtractedObject(

View file

@ -194,7 +194,13 @@ class TestFalkorDBStorageProcessor:
mock_result.run_time_ms = 10
processor.io.query.return_value = mock_result
# Mock collection_exists to bypass validation in unit tests
with patch.object(processor, 'collection_exists', return_value=True):
await processor.store_triples(message)
# Verify queries were called in the correct order
expected_calls = [
@ -225,7 +231,13 @@ class TestFalkorDBStorageProcessor:
mock_result.run_time_ms = 10
processor.io.query.return_value = mock_result
# Mock collection_exists to bypass validation in unit tests
with patch.object(processor, 'collection_exists', return_value=True):
await processor.store_triples(mock_message)
# Verify queries were called in the correct order
expected_calls = [
@ -273,7 +285,13 @@ class TestFalkorDBStorageProcessor:
mock_result.run_time_ms = 10
processor.io.query.return_value = mock_result
# Mock collection_exists to bypass validation in unit tests
with patch.object(processor, 'collection_exists', return_value=True):
await processor.store_triples(message)
# Verify total number of queries (3 per triple)
assert processor.io.query.call_count == 6
@ -299,7 +317,13 @@ class TestFalkorDBStorageProcessor:
message.metadata.collection = 'test_collection'
message.triples = []
# Mock collection_exists to bypass validation in unit tests
with patch.object(processor, 'collection_exists', return_value=True):
await processor.store_triples(message)
# Verify no queries were made
processor.io.query.assert_not_called()
@ -329,7 +353,13 @@ class TestFalkorDBStorageProcessor:
mock_result.run_time_ms = 10
processor.io.query.return_value = mock_result
# Mock collection_exists to bypass validation in unit tests
with patch.object(processor, 'collection_exists', return_value=True):
await processor.store_triples(message)
# Verify total number of queries (3 per triple)
assert processor.io.query.call_count == 6
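# --- Illustrative sketch, not part of the changed files ---
# The patch.object(...) bypass above is repeated across the FalkorDB, Memgraph
# and Neo4j tests. A small pytest fixture could factor it out; this is only a
# suggestion, not existing project code.
import pytest
from unittest.mock import patch

@pytest.fixture
def bypass_collection_validation():
    def _bypass(processor):
        # Context manager that makes collection_exists report True, so
        # store_triples can run without a live collection manager.
        return patch.object(processor, 'collection_exists', return_value=True)
    return _bypass

# Usage in a test body:
#     with bypass_collection_validation(processor):
#         await processor.store_triples(message)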

View file

@ -308,7 +308,13 @@ class TestMemgraphStorageProcessor:
# Reset the mock to clear initialization calls
processor.io.execute_query.reset_mock()
# Mock collection_exists to bypass validation in unit tests
with patch.object(processor, 'collection_exists', return_value=True):
await processor.store_triples(mock_message)
# Verify execute_query was called for create_node, create_literal, and relate_literal
# (since mock_message has a literal object)
@ -352,7 +358,13 @@ class TestMemgraphStorageProcessor:
)
message.triples = [triple1, triple2]
# Mock collection_exists to bypass validation in unit tests
with patch.object(processor, 'collection_exists', return_value=True):
await processor.store_triples(message)
# Verify execute_query was called:
# Triple1: create_node(s) + create_literal(o) + relate_literal = 3 calls
@ -381,7 +393,13 @@ class TestMemgraphStorageProcessor:
message.metadata.collection = 'test_collection'
message.triples = []
# Mock collection_exists to bypass validation in unit tests
with patch.object(processor, 'collection_exists', return_value=True):
await processor.store_triples(message)
# Verify no session calls were made (no triples to process)
processor.io.session.assert_not_called()

View file

@ -268,7 +268,9 @@ class TestNeo4jStorageProcessor:
mock_message.metadata.user = "test_user"
mock_message.metadata.collection = "test_collection"
# Mock collection_exists to bypass validation in unit tests
with patch.object(processor, 'collection_exists', return_value=True):
await processor.store_triples(mock_message)
# Verify create_node was called for subject and object
# Verify relate_node was called
@ -336,7 +338,9 @@ class TestNeo4jStorageProcessor:
mock_message.metadata.user = "test_user"
mock_message.metadata.collection = "test_collection"
# Mock collection_exists to bypass validation in unit tests
with patch.object(processor, 'collection_exists', return_value=True):
await processor.store_triples(mock_message)
# Verify create_node was called for subject
# Verify create_literal was called for object
@ -411,7 +415,9 @@ class TestNeo4jStorageProcessor:
mock_message.metadata.user = "test_user"
mock_message.metadata.collection = "test_collection"
# Mock collection_exists to bypass validation in unit tests
with patch.object(processor, 'collection_exists', return_value=True):
await processor.store_triples(mock_message)
# Should have processed both triples
# Triple1: 2 nodes + 1 relationship = 3 calls
@ -437,7 +443,9 @@ class TestNeo4jStorageProcessor:
mock_message.metadata.user = "test_user"
mock_message.metadata.collection = "test_collection"
# Mock collection_exists to bypass validation in unit tests
with patch.object(processor, 'collection_exists', return_value=True):
await processor.store_triples(mock_message)
# Should not have made any execute_query calls beyond index creation
# Only index creation calls should have been made during initialization
@ -552,7 +560,9 @@ class TestNeo4jStorageProcessor:
mock_message.metadata.user = "test_user"
mock_message.metadata.collection = "test_collection"
# Mock collection_exists to bypass validation in unit tests
with patch.object(processor, 'collection_exists', return_value=True):
await processor.store_triples(mock_message)
# Verify the triple was processed with special characters preserved
mock_driver.execute_query.assert_any_call(

View file

@ -44,7 +44,7 @@ class TestAzureOpenAIProcessorSimple(IsolatedAsyncioTestCase):
processor = Processor(**config)
# Assert
assert processor.default_model == 'gpt-4'
assert processor.temperature == 0.0
assert processor.max_output == 4192
assert hasattr(processor, 'openai')
@ -254,7 +254,7 @@ class TestAzureOpenAIProcessorSimple(IsolatedAsyncioTestCase):
processor = Processor(**config)
# Assert
assert processor.default_model == 'gpt-35-turbo'
assert processor.temperature == 0.7
assert processor.max_output == 2048
mock_azure_openai_class.assert_called_once_with(
@ -289,7 +289,7 @@ class TestAzureOpenAIProcessorSimple(IsolatedAsyncioTestCase):
processor = Processor(**config)
# Assert
assert processor.default_model == 'gpt-4'
assert processor.temperature == 0.0 # default_temperature
assert processor.max_output == 4192 # default_max_output
mock_azure_openai_class.assert_called_once_with(
@ -402,6 +402,156 @@ class TestAzureOpenAIProcessorSimple(IsolatedAsyncioTestCase):
assert call_args[1]['max_tokens'] == 1024
assert call_args[1]['top_p'] == 1
@patch('trustgraph.model.text_completion.azure_openai.llm.AzureOpenAI')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_temperature_override(self, mock_llm_init, mock_async_init, mock_azure_openai_class):
"""Test temperature parameter override functionality"""
# Arrange
mock_azure_client = MagicMock()
mock_response = MagicMock()
mock_response.choices = [MagicMock()]
mock_response.choices[0].message.content = 'Response with custom temperature'
mock_response.usage.prompt_tokens = 20
mock_response.usage.completion_tokens = 12
mock_azure_client.chat.completions.create.return_value = mock_response
mock_azure_openai_class.return_value = mock_azure_client
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'gpt-4',
'endpoint': 'https://test.openai.azure.com/',
'token': 'test-token',
'api_version': '2024-12-01-preview',
'temperature': 0.0, # Default temperature
'max_output': 4192,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act - Override temperature at runtime
result = await processor.generate_content(
"System prompt",
"User prompt",
model=None, # Use default model
temperature=0.8 # Override temperature
)
# Assert
assert isinstance(result, LlmResult)
assert result.text == "Response with custom temperature"
# Verify Azure OpenAI API was called with overridden temperature
call_args = mock_azure_client.chat.completions.create.call_args
assert call_args[1]['temperature'] == 0.8 # Should use runtime override
assert call_args[1]['model'] == 'gpt-4'
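# --- Illustrative sketch, not the production trustgraph code ---
# The override behaviour asserted in these tests reduces to resolving the
# per-request parameters against the processor defaults. Function and variable
# names below are assumptions; only the None-means-default rule comes from the
# tests themselves.
def resolve_generation_params(default_model, default_temperature,
                              model=None, temperature=None):
    # None means "use the processor default"; any explicit value wins.
    effective_model = model if model is not None else default_model
    effective_temperature = (
        temperature if temperature is not None else default_temperature
    )
    return effective_model, effective_temperature

# e.g. resolve_generation_params('gpt-4', 0.0, temperature=0.8) -> ('gpt-4', 0.8)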
@patch('trustgraph.model.text_completion.azure_openai.llm.AzureOpenAI')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_model_override(self, mock_llm_init, mock_async_init, mock_azure_openai_class):
"""Test model parameter override functionality"""
# Arrange
mock_azure_client = MagicMock()
mock_response = MagicMock()
mock_response.choices = [MagicMock()]
mock_response.choices[0].message.content = 'Response with custom model'
mock_response.usage.prompt_tokens = 18
mock_response.usage.completion_tokens = 14
mock_azure_client.chat.completions.create.return_value = mock_response
mock_azure_openai_class.return_value = mock_azure_client
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'gpt-4', # Default model
'endpoint': 'https://test.openai.azure.com/',
'token': 'test-token',
'api_version': '2024-12-01-preview',
'temperature': 0.1, # Default temperature
'max_output': 4192,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act - Override model at runtime
result = await processor.generate_content(
"System prompt",
"User prompt",
model="gpt-4o", # Override model
temperature=None # Use default temperature
)
# Assert
assert isinstance(result, LlmResult)
assert result.text == "Response with custom model"
# Verify Azure OpenAI API was called with overridden model
call_args = mock_azure_client.chat.completions.create.call_args
assert call_args[1]['model'] == 'gpt-4o' # Should use runtime override
assert call_args[1]['temperature'] == 0.1 # Should use processor default
@patch('trustgraph.model.text_completion.azure_openai.llm.AzureOpenAI')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_both_parameters_override(self, mock_llm_init, mock_async_init, mock_azure_openai_class):
"""Test overriding both model and temperature parameters simultaneously"""
# Arrange
mock_azure_client = MagicMock()
mock_response = MagicMock()
mock_response.choices = [MagicMock()]
mock_response.choices[0].message.content = 'Response with both overrides'
mock_response.usage.prompt_tokens = 22
mock_response.usage.completion_tokens = 16
mock_azure_client.chat.completions.create.return_value = mock_response
mock_azure_openai_class.return_value = mock_azure_client
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'gpt-4', # Default model
'endpoint': 'https://test.openai.azure.com/',
'token': 'test-token',
'api_version': '2024-12-01-preview',
'temperature': 0.0, # Default temperature
'max_output': 4192,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act - Override both parameters at runtime
result = await processor.generate_content(
"System prompt",
"User prompt",
model="gpt-4o-mini", # Override model
temperature=0.9 # Override temperature
)
# Assert
assert isinstance(result, LlmResult)
assert result.text == "Response with both overrides"
# Verify Azure OpenAI API was called with both overrides
call_args = mock_azure_client.chat.completions.create.call_args
assert call_args[1]['model'] == 'gpt-4o-mini' # Should use runtime override
assert call_args[1]['temperature'] == 0.9 # Should use runtime override
if __name__ == '__main__':
pytest.main([__file__])

View file

@ -43,7 +43,7 @@ class TestAzureProcessorSimple(IsolatedAsyncioTestCase):
assert processor.token == 'test-token'
assert processor.temperature == 0.0
assert processor.max_output == 4192
assert processor.default_model == 'AzureAI'
@patch('trustgraph.model.text_completion.azure.llm.requests')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@ -261,7 +261,7 @@ class TestAzureProcessorSimple(IsolatedAsyncioTestCase):
assert processor.token == 'custom-token'
assert processor.temperature == 0.7
assert processor.max_output == 2048
assert processor.default_model == 'AzureAI'
@patch('trustgraph.model.text_completion.azure.llm.requests')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@ -289,7 +289,7 @@ class TestAzureProcessorSimple(IsolatedAsyncioTestCase):
assert processor.token == 'test-token'
assert processor.temperature == 0.0 # default_temperature
assert processor.max_output == 4192 # default_max_output
assert processor.default_model == 'AzureAI' # default_model
@patch('trustgraph.model.text_completion.azure.llm.requests')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@ -459,5 +459,150 @@ class TestAzureProcessorSimple(IsolatedAsyncioTestCase):
)
@patch('trustgraph.model.text_completion.azure.llm.requests')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_with_model_override(self, mock_llm_init, mock_async_init, mock_requests):
"""Test generate_content with model parameter override"""
# Arrange
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.json.return_value = {
'choices': [{
'message': {
'content': 'Response with model override'
}
}],
'usage': {
'prompt_tokens': 15,
'completion_tokens': 10
}
}
mock_requests.post.return_value = mock_response
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'endpoint': 'https://test.inference.ai.azure.com/v1/chat/completions',
'token': 'test-token',
'temperature': 0.0,
'max_output': 4192,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act - Override model
result = await processor.generate_content("System", "Prompt", model="custom-azure-model")
# Assert
assert result.model == "custom-azure-model" # Should use overridden model
assert result.text == "Response with model override"
@patch('trustgraph.model.text_completion.azure.llm.requests')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_with_temperature_override(self, mock_llm_init, mock_async_init, mock_requests):
"""Test generate_content with temperature parameter override"""
# Arrange
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.json.return_value = {
'choices': [{
'message': {
'content': 'Response with temperature override'
}
}],
'usage': {
'prompt_tokens': 15,
'completion_tokens': 10
}
}
mock_requests.post.return_value = mock_response
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'endpoint': 'https://test.inference.ai.azure.com/v1/chat/completions',
'token': 'test-token',
'temperature': 0.0, # Default temperature
'max_output': 4192,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act - Override temperature
result = await processor.generate_content("System", "Prompt", temperature=0.8)
# Assert
assert result.text == "Response with temperature override"
# Verify the request was made with the overridden temperature
mock_requests.post.assert_called_once()
call_args = mock_requests.post.call_args
import json
request_body = json.loads(call_args[1]['data'])
assert request_body['temperature'] == 0.8
@patch('trustgraph.model.text_completion.azure.llm.requests')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_with_both_parameters_override(self, mock_llm_init, mock_async_init, mock_requests):
"""Test generate_content with both model and temperature overrides"""
# Arrange
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.json.return_value = {
'choices': [{
'message': {
'content': 'Response with both parameters override'
}
}],
'usage': {
'prompt_tokens': 18,
'completion_tokens': 12
}
}
mock_requests.post.return_value = mock_response
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'endpoint': 'https://test.inference.ai.azure.com/v1/chat/completions',
'token': 'test-token',
'temperature': 0.0,
'max_output': 4192,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act - Override both parameters
result = await processor.generate_content("System", "Prompt", model="override-model", temperature=0.9)
# Assert
assert result.model == "override-model"
assert result.text == "Response with both parameters override"
# Verify the request was made with overridden temperature
mock_requests.post.assert_called_once()
call_args = mock_requests.post.call_args
import json
request_body = json.loads(call_args[1]['data'])
assert request_body['temperature'] == 0.9
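# --- Illustrative sketch, not the production trustgraph code ---
# The serverless Azure endpoint is driven through requests.post with a JSON
# body, which is why these tests decode call_args[1]['data'] to inspect the
# effective temperature. A rough shape of the request construction; field
# names and headers are assumptions.
import json
import requests

def post_chat_completion(endpoint, token, system, prompt, temperature, max_tokens):
    body = {
        "messages": [
            {"role": "system", "content": system},
            {"role": "user", "content": prompt},
        ],
        "temperature": temperature,
        "max_tokens": max_tokens,
    }
    return requests.post(
        endpoint,
        headers={"Authorization": f"Bearer {token}"},
        data=json.dumps(body),
    )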
if __name__ == '__main__':
pytest.main([__file__])

View file

@ -0,0 +1,280 @@
"""
Unit tests for trustgraph.model.text_completion.bedrock
Following the same successful pattern as other processor tests
"""
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from unittest import IsolatedAsyncioTestCase
import json
# Import the service under test
from trustgraph.model.text_completion.bedrock.llm import Processor, Mistral, Anthropic
from trustgraph.base import LlmResult
class TestBedrockProcessorSimple(IsolatedAsyncioTestCase):
"""Test Bedrock processor functionality"""
@patch('trustgraph.model.text_completion.bedrock.llm.boto3.Session')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_processor_initialization_basic(self, mock_llm_init, mock_async_init, mock_session_class):
"""Test basic processor initialization"""
# Arrange
mock_session = MagicMock()
mock_bedrock = MagicMock()
mock_session.client.return_value = mock_bedrock
mock_session_class.return_value = mock_session
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'mistral.mistral-large-2407-v1:0',
'temperature': 0.1,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
# Act
processor = Processor(**config)
# Assert
assert processor.default_model == 'mistral.mistral-large-2407-v1:0'
assert processor.temperature == 0.1
assert hasattr(processor, 'bedrock')
mock_session_class.assert_called_once()
@patch('trustgraph.model.text_completion.bedrock.llm.boto3.Session')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_success_mistral(self, mock_llm_init, mock_async_init, mock_session_class):
"""Test successful content generation with Mistral model"""
# Arrange
mock_session = MagicMock()
mock_bedrock = MagicMock()
mock_session.client.return_value = mock_bedrock
mock_session_class.return_value = mock_session
mock_response = {
'body': MagicMock(),
'ResponseMetadata': {
'HTTPHeaders': {
'x-amzn-bedrock-input-token-count': '15',
'x-amzn-bedrock-output-token-count': '8'
}
}
}
mock_response['body'].read.return_value = json.dumps({
'outputs': [{'text': 'Generated response from Bedrock'}]
})
mock_bedrock.invoke_model.return_value = mock_response
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'mistral.mistral-large-2407-v1:0',
'temperature': 0.0,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act
result = await processor.generate_content("System prompt", "User prompt")
# Assert
assert isinstance(result, LlmResult)
assert result.text == "Generated response from Bedrock"
assert result.in_token == 15
assert result.out_token == 8
assert result.model == 'mistral.mistral-large-2407-v1:0'
mock_bedrock.invoke_model.assert_called_once()
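# --- Illustrative sketch, not the production trustgraph code ---
# How the assertions above map onto a Bedrock invoke_model response for a
# Mistral-family model: the response body is JSON with outputs[0].text, and
# token counts come from the x-amzn-bedrock-* response headers. The helper
# below is an assumption used only to illustrate that mapping.
import json

def parse_mistral_response(response, model_id):
    body = json.loads(response['body'].read())
    headers = response['ResponseMetadata']['HTTPHeaders']
    return {
        "text": body['outputs'][0]['text'],
        "in_token": int(headers['x-amzn-bedrock-input-token-count']),
        "out_token": int(headers['x-amzn-bedrock-output-token-count']),
        "model": model_id,
    }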
@patch('trustgraph.model.text_completion.bedrock.llm.boto3.Session')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_temperature_override(self, mock_llm_init, mock_async_init, mock_session_class):
"""Test temperature parameter override functionality"""
# Arrange
mock_session = MagicMock()
mock_bedrock = MagicMock()
mock_session.client.return_value = mock_bedrock
mock_session_class.return_value = mock_session
mock_response = {
'body': MagicMock(),
'ResponseMetadata': {
'HTTPHeaders': {
'x-amzn-bedrock-input-token-count': '20',
'x-amzn-bedrock-output-token-count': '12'
}
}
}
mock_response['body'].read.return_value = json.dumps({
'outputs': [{'text': 'Response with custom temperature'}]
})
mock_bedrock.invoke_model.return_value = mock_response
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'mistral.mistral-large-2407-v1:0',
'temperature': 0.0, # Default temperature
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act - Override temperature at runtime
result = await processor.generate_content(
"System prompt",
"User prompt",
model=None, # Use default model
temperature=0.8 # Override temperature
)
# Assert
assert isinstance(result, LlmResult)
assert result.text == "Response with custom temperature"
# Verify the model variant was created with overridden temperature
# The cache key should include the temperature
cache_key = f"mistral.mistral-large-2407-v1:0:0.8"
assert cache_key in processor.model_variants
variant = processor.model_variants[cache_key]
assert variant.temperature == 0.8
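# --- Illustrative sketch, not the production trustgraph code ---
# The model_variants cache these tests inspect: variants are keyed by
# "model:temperature", and the adapter class (Mistral, Anthropic, Meta, ...)
# is picked from the model id family. The dispatch and constructor below are
# assumptions based only on the cache keys and types asserted in these tests.
def get_variant_sketch(model_variants, model_id, temperature, variant_classes):
    key = f"{model_id}:{temperature}"
    if key not in model_variants:
        # e.g. variant_classes = {"mistral": Mistral, "anthropic": Anthropic, ...}
        family = model_id.split(".", 1)[0]
        variant_cls = variant_classes[family]
        model_variants[key] = variant_cls(temperature=temperature)
    return model_variants[key]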
@patch('trustgraph.model.text_completion.bedrock.llm.boto3.Session')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_model_override(self, mock_llm_init, mock_async_init, mock_session_class):
"""Test model parameter override functionality"""
# Arrange
mock_session = MagicMock()
mock_bedrock = MagicMock()
mock_session.client.return_value = mock_bedrock
mock_session_class.return_value = mock_session
mock_response = {
'body': MagicMock(),
'ResponseMetadata': {
'HTTPHeaders': {
'x-amzn-bedrock-input-token-count': '18',
'x-amzn-bedrock-output-token-count': '14'
}
}
}
mock_response['body'].read.return_value = json.dumps({
'content': [{'text': 'Response with custom model'}]
})
mock_bedrock.invoke_model.return_value = mock_response
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'mistral.mistral-large-2407-v1:0', # Default model
'temperature': 0.1, # Default temperature
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act - Override model at runtime
result = await processor.generate_content(
"System prompt",
"User prompt",
model="anthropic.claude-3-sonnet-20240229-v1:0", # Override model
temperature=None # Use default temperature
)
# Assert
assert isinstance(result, LlmResult)
assert result.text == "Response with custom model"
# Verify Bedrock API was called with overridden model
mock_bedrock.invoke_model.assert_called_once()
call_args = mock_bedrock.invoke_model.call_args
assert call_args[1]['modelId'] == "anthropic.claude-3-sonnet-20240229-v1:0"
# Verify the correct model variant (Anthropic) was used
cache_key = f"anthropic.claude-3-sonnet-20240229-v1:0:0.1"
assert cache_key in processor.model_variants
variant = processor.model_variants[cache_key]
assert isinstance(variant, Anthropic)
@patch('trustgraph.model.text_completion.bedrock.llm.boto3.Session')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_both_parameters_override(self, mock_llm_init, mock_async_init, mock_session_class):
"""Test overriding both model and temperature parameters simultaneously"""
# Arrange
mock_session = MagicMock()
mock_bedrock = MagicMock()
mock_session.client.return_value = mock_bedrock
mock_session_class.return_value = mock_session
mock_response = {
'body': MagicMock(),
'ResponseMetadata': {
'HTTPHeaders': {
'x-amzn-bedrock-input-token-count': '22',
'x-amzn-bedrock-output-token-count': '16'
}
}
}
mock_response['body'].read.return_value = json.dumps({
'generation': 'Response with both overrides'
})
mock_bedrock.invoke_model.return_value = mock_response
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'mistral.mistral-large-2407-v1:0', # Default model
'temperature': 0.0, # Default temperature
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act - Override both parameters at runtime
result = await processor.generate_content(
"System prompt",
"User prompt",
model="meta.llama3-70b-instruct-v1:0", # Override model (Meta/Llama)
temperature=0.9 # Override temperature
)
# Assert
assert isinstance(result, LlmResult)
assert result.text == "Response with both overrides"
# Verify Bedrock API was called with both overrides
mock_bedrock.invoke_model.assert_called_once()
call_args = mock_bedrock.invoke_model.call_args
assert call_args[1]['modelId'] == "meta.llama3-70b-instruct-v1:0"
# Verify the correct model variant (Meta) was used with correct temperature
cache_key = f"meta.llama3-70b-instruct-v1:0:0.9"
assert cache_key in processor.model_variants
variant = processor.model_variants[cache_key]
assert variant.temperature == 0.9
if __name__ == '__main__':
pytest.main([__file__])

View file

@ -42,7 +42,7 @@ class TestClaudeProcessorSimple(IsolatedAsyncioTestCase):
processor = Processor(**config)
# Assert
assert processor.default_model == 'claude-3-5-sonnet-20240620'
assert processor.temperature == 0.0
assert processor.max_output == 8192
assert hasattr(processor, 'claude')
@ -217,7 +217,7 @@ class TestClaudeProcessorSimple(IsolatedAsyncioTestCase):
processor = Processor(**config)
# Assert
assert processor.default_model == 'claude-3-haiku-20240307'
assert processor.temperature == 0.7
assert processor.max_output == 4096
mock_anthropic_class.assert_called_once_with(api_key='custom-api-key')
@ -246,7 +246,7 @@ class TestClaudeProcessorSimple(IsolatedAsyncioTestCase):
processor = Processor(**config)
# Assert
assert processor.default_model == 'claude-3-5-sonnet-20240620' # default_model
assert processor.temperature == 0.0 # default_temperature
assert processor.max_output == 8192 # default_max_output
mock_anthropic_class.assert_called_once_with(api_key='test-api-key')
@ -433,7 +433,157 @@ class TestClaudeProcessorSimple(IsolatedAsyncioTestCase):
# Verify processor has the client
assert processor.claude == mock_claude_client
assert processor.default_model == 'claude-3-opus-20240229'
@patch('trustgraph.model.text_completion.claude.llm.anthropic.Anthropic')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_temperature_override(self, mock_llm_init, mock_async_init, mock_anthropic_class):
"""Test temperature parameter override functionality"""
# Arrange
mock_claude_client = MagicMock()
mock_response = MagicMock()
mock_response.content = [MagicMock()]
mock_response.content[0].text = "Response with custom temperature"
mock_response.usage.input_tokens = 20
mock_response.usage.output_tokens = 12
mock_claude_client.messages.create.return_value = mock_response
mock_anthropic_class.return_value = mock_claude_client
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'claude-3-5-sonnet-20240620',
'api_key': 'test-api-key',
'temperature': 0.0, # Default temperature
'max_output': 8192,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act - Override temperature at runtime
result = await processor.generate_content(
"System prompt",
"User prompt",
model=None, # Use default model
temperature=0.9 # Override temperature
)
# Assert
assert isinstance(result, LlmResult)
assert result.text == "Response with custom temperature"
# Verify Claude API was called with overridden temperature
mock_claude_client.messages.create.assert_called_once()
call_kwargs = mock_claude_client.messages.create.call_args.kwargs
assert call_kwargs['temperature'] == 0.9 # Should use runtime override
assert call_kwargs['model'] == 'claude-3-5-sonnet-20240620' # Should use processor default
@patch('trustgraph.model.text_completion.claude.llm.anthropic.Anthropic')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_model_override(self, mock_llm_init, mock_async_init, mock_anthropic_class):
"""Test model parameter override functionality"""
# Arrange
mock_claude_client = MagicMock()
mock_response = MagicMock()
mock_response.content = [MagicMock()]
mock_response.content[0].text = "Response with custom model"
mock_response.usage.input_tokens = 18
mock_response.usage.output_tokens = 14
mock_claude_client.messages.create.return_value = mock_response
mock_anthropic_class.return_value = mock_claude_client
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'claude-3-5-sonnet-20240620', # Default model
'api_key': 'test-api-key',
'temperature': 0.2, # Default temperature
'max_output': 8192,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act - Override model at runtime
result = await processor.generate_content(
"System prompt",
"User prompt",
model="claude-3-haiku-20240307", # Override model
temperature=None # Use default temperature
)
# Assert
assert isinstance(result, LlmResult)
assert result.text == "Response with custom model"
# Verify Claude API was called with overridden model
mock_claude_client.messages.create.assert_called_once()
call_kwargs = mock_claude_client.messages.create.call_args.kwargs
assert call_kwargs['model'] == 'claude-3-haiku-20240307' # Should use runtime override
assert call_kwargs['temperature'] == 0.2 # Should use processor default
@patch('trustgraph.model.text_completion.claude.llm.anthropic.Anthropic')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_both_parameters_override(self, mock_llm_init, mock_async_init, mock_anthropic_class):
"""Test overriding both model and temperature parameters simultaneously"""
# Arrange
mock_claude_client = MagicMock()
mock_response = MagicMock()
mock_response.content = [MagicMock()]
mock_response.content[0].text = "Response with both overrides"
mock_response.usage.input_tokens = 22
mock_response.usage.output_tokens = 16
mock_claude_client.messages.create.return_value = mock_response
mock_anthropic_class.return_value = mock_claude_client
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'claude-3-5-sonnet-20240620', # Default model
'api_key': 'test-api-key',
'temperature': 0.0, # Default temperature
'max_output': 8192,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act - Override both parameters at runtime
result = await processor.generate_content(
"System prompt",
"User prompt",
model="claude-3-opus-20240229", # Override model
temperature=0.8 # Override temperature
)
# Assert
assert isinstance(result, LlmResult)
assert result.text == "Response with both overrides"
# Verify Claude API was called with both overrides
mock_claude_client.messages.create.assert_called_once()
call_kwargs = mock_claude_client.messages.create.call_args.kwargs
assert call_kwargs['model'] == 'claude-3-opus-20240229' # Should use runtime override
assert call_kwargs['temperature'] == 0.8 # Should use runtime override
if __name__ == '__main__':

View file

@ -41,7 +41,7 @@ class TestCohereProcessorSimple(IsolatedAsyncioTestCase):
processor = Processor(**config)
# Assert
assert processor.default_model == 'c4ai-aya-23-8b'
assert processor.temperature == 0.0
assert hasattr(processor, 'cohere')
mock_cohere_class.assert_called_once_with(api_key='test-api-key')
@ -201,7 +201,7 @@ class TestCohereProcessorSimple(IsolatedAsyncioTestCase):
processor = Processor(**config)
# Assert
assert processor.model == 'command-light'
assert processor.default_model == 'command-light'
assert processor.temperature == 0.7
mock_cohere_class.assert_called_once_with(api_key='custom-api-key')
@@ -229,7 +229,7 @@ class TestCohereProcessorSimple(IsolatedAsyncioTestCase):
processor = Processor(**config)
# Assert
assert processor.model == 'c4ai-aya-23-8b' # default_model
assert processor.default_model == 'c4ai-aya-23-8b' # default_model
assert processor.temperature == 0.0 # default_temperature
mock_cohere_class.assert_called_once_with(api_key='test-api-key')
@@ -395,7 +395,7 @@ class TestCohereProcessorSimple(IsolatedAsyncioTestCase):
# Verify processor has the client
assert processor.cohere == mock_cohere_client
assert processor.model == 'command-r'
assert processor.default_model == 'command-r'
@patch('trustgraph.model.text_completion.cohere.llm.cohere.Client')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@@ -442,6 +442,162 @@ class TestCohereProcessorSimple(IsolatedAsyncioTestCase):
assert call_args[1]['prompt_truncation'] == 'auto'
assert call_args[1]['connectors'] == []
@patch('trustgraph.model.text_completion.cohere.llm.cohere.Client')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_temperature_override(self, mock_llm_init, mock_async_init, mock_cohere_class):
"""Test temperature parameter override functionality"""
# Arrange
mock_cohere_client = MagicMock()
mock_output = MagicMock()
mock_output.text = 'Response with custom temperature'
mock_output.meta.billed_units.input_tokens = 20
mock_output.meta.billed_units.output_tokens = 12
mock_cohere_client.chat.return_value = mock_output
mock_cohere_class.return_value = mock_cohere_client
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'c4ai-aya-23-8b',
'api_key': 'test-api-key',
'temperature': 0.0, # Default temperature
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act - Override temperature at runtime
result = await processor.generate_content(
"System prompt",
"User prompt",
model=None, # Use default model
temperature=0.8 # Override temperature
)
# Assert
assert isinstance(result, LlmResult)
assert result.text == "Response with custom temperature"
# Verify Cohere API was called with overridden temperature
mock_cohere_client.chat.assert_called_once_with(
model='c4ai-aya-23-8b',
message='User prompt',
preamble='System prompt',
temperature=0.8, # Should use runtime override
chat_history=[],
prompt_truncation='auto',
connectors=[]
)
@patch('trustgraph.model.text_completion.cohere.llm.cohere.Client')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_model_override(self, mock_llm_init, mock_async_init, mock_cohere_class):
"""Test model parameter override functionality"""
# Arrange
mock_cohere_client = MagicMock()
mock_output = MagicMock()
mock_output.text = 'Response with custom model'
mock_output.meta.billed_units.input_tokens = 18
mock_output.meta.billed_units.output_tokens = 14
mock_cohere_client.chat.return_value = mock_output
mock_cohere_class.return_value = mock_cohere_client
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'c4ai-aya-23-8b', # Default model
'api_key': 'test-api-key',
'temperature': 0.1, # Default temperature
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act - Override model at runtime
result = await processor.generate_content(
"System prompt",
"User prompt",
model="command-r-plus", # Override model
temperature=None # Use default temperature
)
# Assert
assert isinstance(result, LlmResult)
assert result.text == "Response with custom model"
# Verify Cohere API was called with overridden model
mock_cohere_client.chat.assert_called_once_with(
model='command-r-plus', # Should use runtime override
message='User prompt',
preamble='System prompt',
temperature=0.1, # Should use processor default
chat_history=[],
prompt_truncation='auto',
connectors=[]
)
@patch('trustgraph.model.text_completion.cohere.llm.cohere.Client')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_both_parameters_override(self, mock_llm_init, mock_async_init, mock_cohere_class):
"""Test overriding both model and temperature parameters simultaneously"""
# Arrange
mock_cohere_client = MagicMock()
mock_output = MagicMock()
mock_output.text = 'Response with both overrides'
mock_output.meta.billed_units.input_tokens = 22
mock_output.meta.billed_units.output_tokens = 16
mock_cohere_client.chat.return_value = mock_output
mock_cohere_class.return_value = mock_cohere_client
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'c4ai-aya-23-8b', # Default model
'api_key': 'test-api-key',
'temperature': 0.0, # Default temperature
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act - Override both parameters at runtime
result = await processor.generate_content(
"System prompt",
"User prompt",
model="command-r", # Override model
temperature=0.9 # Override temperature
)
# Assert
assert isinstance(result, LlmResult)
assert result.text == "Response with both overrides"
# Verify Cohere API was called with both overrides
mock_cohere_client.chat.assert_called_once_with(
model='command-r', # Should use runtime override
message='User prompt',
preamble='System prompt',
temperature=0.9, # Should use runtime override
chat_history=[],
prompt_truncation='auto',
connectors=[]
)
if __name__ == '__main__':
pytest.main([__file__])


@@ -42,7 +42,7 @@ class TestGoogleAIStudioProcessorSimple(IsolatedAsyncioTestCase):
processor = Processor(**config)
# Assert
assert processor.model == 'gemini-2.0-flash-001'
assert processor.default_model == 'gemini-2.0-flash-001'
assert processor.temperature == 0.0
assert processor.max_output == 8192
assert hasattr(processor, 'client')
@@ -205,7 +205,7 @@ class TestGoogleAIStudioProcessorSimple(IsolatedAsyncioTestCase):
processor = Processor(**config)
# Assert
assert processor.model == 'gemini-1.5-pro'
assert processor.default_model == 'gemini-1.5-pro'
assert processor.temperature == 0.7
assert processor.max_output == 4096
mock_genai_class.assert_called_once_with(api_key='custom-api-key')
@@ -234,7 +234,7 @@ class TestGoogleAIStudioProcessorSimple(IsolatedAsyncioTestCase):
processor = Processor(**config)
# Assert
assert processor.model == 'gemini-2.0-flash-001' # default_model
assert processor.default_model == 'gemini-2.0-flash-001' # default_model
assert processor.temperature == 0.0 # default_temperature
assert processor.max_output == 8192 # default_max_output
mock_genai_class.assert_called_once_with(api_key='test-api-key')
@@ -431,7 +431,7 @@ class TestGoogleAIStudioProcessorSimple(IsolatedAsyncioTestCase):
# Verify processor has the client
assert processor.client == mock_genai_client
assert processor.model == 'gemini-1.5-flash'
assert processor.default_model == 'gemini-1.5-flash'
@patch('trustgraph.model.text_completion.googleaistudio.llm.genai.Client')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@@ -477,6 +477,156 @@ class TestGoogleAIStudioProcessorSimple(IsolatedAsyncioTestCase):
# The system instruction should be in the config object
assert call_args[1]['contents'] == "Explain quantum computing"
@patch('trustgraph.model.text_completion.googleaistudio.llm.genai.Client')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_temperature_override(self, mock_llm_init, mock_async_init, mock_genai_class):
"""Test temperature parameter override functionality"""
# Arrange
mock_genai_client = MagicMock()
mock_response = MagicMock()
mock_response.text = 'Response with custom temperature'
mock_response.usage_metadata.prompt_token_count = 20
mock_response.usage_metadata.candidates_token_count = 12
mock_genai_client.models.generate_content.return_value = mock_response
mock_genai_class.return_value = mock_genai_client
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'gemini-2.0-flash-001',
'api_key': 'test-api-key',
'temperature': 0.0, # Default temperature
'max_output': 8192,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act - Override temperature at runtime
result = await processor.generate_content(
"System prompt",
"User prompt",
model=None, # Use default model
temperature=0.8 # Override temperature
)
# Assert
assert isinstance(result, LlmResult)
assert result.text == "Response with custom temperature"
# Verify the generation config was created with overridden temperature
cache_key = f"gemini-2.0-flash-001:0.8"
assert cache_key in processor.generation_configs
config_obj = processor.generation_configs[cache_key]
assert config_obj.temperature == 0.8
@patch('trustgraph.model.text_completion.googleaistudio.llm.genai.Client')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_model_override(self, mock_llm_init, mock_async_init, mock_genai_class):
"""Test model parameter override functionality"""
# Arrange
mock_genai_client = MagicMock()
mock_response = MagicMock()
mock_response.text = 'Response with custom model'
mock_response.usage_metadata.prompt_token_count = 18
mock_response.usage_metadata.candidates_token_count = 14
mock_genai_client.models.generate_content.return_value = mock_response
mock_genai_class.return_value = mock_genai_client
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'gemini-2.0-flash-001', # Default model
'api_key': 'test-api-key',
'temperature': 0.1, # Default temperature
'max_output': 8192,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act - Override model at runtime
result = await processor.generate_content(
"System prompt",
"User prompt",
model="gemini-1.5-pro", # Override model
temperature=None # Use default temperature
)
# Assert
assert isinstance(result, LlmResult)
assert result.text == "Response with custom model"
# Verify Google AI Studio API was called with overridden model
call_args = mock_genai_client.models.generate_content.call_args
assert call_args[1]['model'] == 'gemini-1.5-pro' # Should use runtime override
# Verify the generation config was created for the correct model
cache_key = f"gemini-1.5-pro:0.1"
assert cache_key in processor.generation_configs
@patch('trustgraph.model.text_completion.googleaistudio.llm.genai.Client')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_both_parameters_override(self, mock_llm_init, mock_async_init, mock_genai_class):
"""Test overriding both model and temperature parameters simultaneously"""
# Arrange
mock_genai_client = MagicMock()
mock_response = MagicMock()
mock_response.text = 'Response with both overrides'
mock_response.usage_metadata.prompt_token_count = 22
mock_response.usage_metadata.candidates_token_count = 16
mock_genai_client.models.generate_content.return_value = mock_response
mock_genai_class.return_value = mock_genai_client
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'gemini-2.0-flash-001', # Default model
'api_key': 'test-api-key',
'temperature': 0.0, # Default temperature
'max_output': 8192,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act - Override both parameters at runtime
result = await processor.generate_content(
"System prompt",
"User prompt",
model="gemini-1.5-flash", # Override model
temperature=0.9 # Override temperature
)
# Assert
assert isinstance(result, LlmResult)
assert result.text == "Response with both overrides"
# Verify Google AI Studio API was called with both overrides
call_args = mock_genai_client.models.generate_content.call_args
assert call_args[1]['model'] == 'gemini-1.5-flash' # Should use runtime override
# Verify the generation config was created with both overrides
cache_key = f"gemini-1.5-flash:0.9"
assert cache_key in processor.generation_configs
config_obj = processor.generation_configs[cache_key]
assert config_obj.temperature == 0.9
if __name__ == '__main__':
pytest.main([__file__])


@@ -42,7 +42,7 @@ class TestLlamaFileProcessorSimple(IsolatedAsyncioTestCase):
processor = Processor(**config)
# Assert
assert processor.model == 'LLaMA_CPP'
assert processor.default_model == 'LLaMA_CPP'
assert processor.llamafile == 'http://localhost:8080/v1'
assert processor.temperature == 0.0
assert processor.max_output == 4096
@@ -91,7 +91,7 @@ class TestLlamaFileProcessorSimple(IsolatedAsyncioTestCase):
assert result.text == "Generated response from LlamaFile"
assert result.in_token == 20
assert result.out_token == 12
assert result.model == 'llama.cpp' # Note: model in result is hardcoded to 'llama.cpp'
assert result.model == 'LLaMA_CPP' # Uses the default model name
# Verify the OpenAI API call structure
mock_openai_client.chat.completions.create.assert_called_once_with(
@@ -99,7 +99,15 @@ class TestLlamaFileProcessorSimple(IsolatedAsyncioTestCase):
messages=[{
"role": "user",
"content": "System prompt\n\nUser prompt"
}]
}],
temperature=0.0,
max_tokens=4096,
top_p=1,
frequency_penalty=0,
presence_penalty=0,
response_format={
"type": "text"
}
)
@patch('trustgraph.model.text_completion.llamafile.llm.OpenAI')
@@ -157,7 +165,7 @@ class TestLlamaFileProcessorSimple(IsolatedAsyncioTestCase):
processor = Processor(**config)
# Assert
assert processor.model == 'custom-llama'
assert processor.default_model == 'custom-llama'
assert processor.llamafile == 'http://custom-host:8080/v1'
assert processor.temperature == 0.7
assert processor.max_output == 2048
@@ -189,7 +197,7 @@ class TestLlamaFileProcessorSimple(IsolatedAsyncioTestCase):
processor = Processor(**config)
# Assert
assert processor.model == 'LLaMA_CPP' # default_model
assert processor.default_model == 'LLaMA_CPP' # default_model
assert processor.llamafile == 'http://localhost:8080/v1' # default_llamafile
assert processor.temperature == 0.0 # default_temperature
assert processor.max_output == 4096 # default_max_output
@@ -237,7 +245,7 @@ class TestLlamaFileProcessorSimple(IsolatedAsyncioTestCase):
assert result.text == "Default response"
assert result.in_token == 2
assert result.out_token == 3
assert result.model == 'llama.cpp'
assert result.model == 'LLaMA_CPP'
# Verify the combined prompt is sent correctly
call_args = mock_openai_client.chat.completions.create.call_args
@@ -408,8 +416,8 @@ class TestLlamaFileProcessorSimple(IsolatedAsyncioTestCase):
result = await processor.generate_content("System", "User")
# Assert
assert result.model == 'llama.cpp' # Should always be 'llama.cpp', not 'custom-model-name'
assert processor.model == 'custom-model-name' # But processor.model should still be custom
assert result.model == 'custom-model-name' # Reflects the model actually used (the configured default, since no override was passed)
assert processor.default_model == 'custom-model-name' # The configured default remains on the processor
@patch('trustgraph.model.text_completion.llamafile.llm.OpenAI')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@@ -450,5 +458,132 @@ class TestLlamaFileProcessorSimple(IsolatedAsyncioTestCase):
# No specific rate limit error handling tested since SLM presumably has no rate limits
@patch('trustgraph.model.text_completion.llamafile.llm.OpenAI')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_with_model_override(self, mock_llm_init, mock_async_init, mock_openai_class):
"""Test generate_content with model parameter override"""
# Arrange
mock_openai_client = MagicMock()
mock_response = MagicMock()
mock_response.choices = [MagicMock()]
mock_response.choices[0].message.content = "Response from overridden model"
mock_response.usage.prompt_tokens = 15
mock_response.usage.completion_tokens = 10
mock_openai_client.chat.completions.create.return_value = mock_response
mock_openai_class.return_value = mock_openai_client
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'LLaMA_CPP',
'llamafile': 'http://localhost:8080/v1',
'temperature': 0.0,
'max_output': 4096,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act - Override model
result = await processor.generate_content("System", "Prompt", model="custom-llamafile-model")
# Assert
assert result.model == "custom-llamafile-model" # Should use overridden model
assert result.text == "Response from overridden model"
# Verify the API call was made with overridden model
call_args = mock_openai_client.chat.completions.create.call_args
assert call_args[1]['model'] == "custom-llamafile-model"
@patch('trustgraph.model.text_completion.llamafile.llm.OpenAI')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_with_temperature_override(self, mock_llm_init, mock_async_init, mock_openai_class):
"""Test generate_content with temperature parameter override"""
# Arrange
mock_openai_client = MagicMock()
mock_response = MagicMock()
mock_response.choices = [MagicMock()]
mock_response.choices[0].message.content = "Response with temperature override"
mock_response.usage.prompt_tokens = 18
mock_response.usage.completion_tokens = 12
mock_openai_client.chat.completions.create.return_value = mock_response
mock_openai_class.return_value = mock_openai_client
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'LLaMA_CPP',
'llamafile': 'http://localhost:8080/v1',
'temperature': 0.0, # Default temperature
'max_output': 4096,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act - Override temperature
result = await processor.generate_content("System", "Prompt", temperature=0.7)
# Assert
assert result.text == "Response with temperature override"
# Verify the API call was made with overridden temperature
call_args = mock_openai_client.chat.completions.create.call_args
assert call_args[1]['temperature'] == 0.7
@patch('trustgraph.model.text_completion.llamafile.llm.OpenAI')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_with_both_parameters_override(self, mock_llm_init, mock_async_init, mock_openai_class):
"""Test generate_content with both model and temperature overrides"""
# Arrange
mock_openai_client = MagicMock()
mock_response = MagicMock()
mock_response.choices = [MagicMock()]
mock_response.choices[0].message.content = "Response with both parameters override"
mock_response.usage.prompt_tokens = 20
mock_response.usage.completion_tokens = 15
mock_openai_client.chat.completions.create.return_value = mock_response
mock_openai_class.return_value = mock_openai_client
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'LLaMA_CPP',
'llamafile': 'http://localhost:8080/v1',
'temperature': 0.0,
'max_output': 4096,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act - Override both parameters
result = await processor.generate_content("System", "Prompt", model="override-model", temperature=0.8)
# Assert
assert result.model == "override-model"
assert result.text == "Response with both parameters override"
# Verify the API call was made with overridden parameters
call_args = mock_openai_client.chat.completions.create.call_args
assert call_args[1]['model'] == "override-model"
assert call_args[1]['temperature'] == 0.8
if __name__ == '__main__':
pytest.main([__file__])


@@ -0,0 +1,229 @@
"""
Unit tests for trustgraph.model.text_completion.lmstudio
Following the same successful pattern as previous tests
"""
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from unittest import IsolatedAsyncioTestCase
# Import the service under test
from trustgraph.model.text_completion.lmstudio.llm import Processor
from trustgraph.base import LlmResult
from trustgraph.exceptions import TooManyRequests
class TestLMStudioProcessorSimple(IsolatedAsyncioTestCase):
"""Test LMStudio processor functionality"""
@patch('trustgraph.model.text_completion.lmstudio.llm.OpenAI')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_processor_initialization_basic(self, mock_llm_init, mock_async_init, mock_openai_class):
"""Test basic processor initialization"""
# Arrange
mock_openai = MagicMock()
mock_openai_class.return_value = mock_openai
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'gemma3:9b',
'url': 'http://localhost:1234/',
'temperature': 0.0,
'max_output': 4096,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
# Act
processor = Processor(**config)
# Assert
assert processor.default_model == 'gemma3:9b'
assert processor.url == 'http://localhost:1234/v1/'
assert processor.temperature == 0.0
assert processor.max_output == 4096
assert hasattr(processor, 'openai')
mock_openai_class.assert_called_once_with(
base_url='http://localhost:1234/v1/',
api_key='sk-no-key-required'
)
@patch('trustgraph.model.text_completion.lmstudio.llm.OpenAI')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_success(self, mock_llm_init, mock_async_init, mock_openai_class):
"""Test successful content generation"""
# Arrange
mock_openai = MagicMock()
mock_response = MagicMock()
mock_response.choices[0].message.content = 'Generated response from LMStudio'
mock_response.usage.prompt_tokens = 20
mock_response.usage.completion_tokens = 12
mock_openai.chat.completions.create.return_value = mock_response
mock_openai_class.return_value = mock_openai
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'gemma3:9b',
'url': 'http://localhost:1234/',
'temperature': 0.0,
'max_output': 4096,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act
result = await processor.generate_content("System prompt", "User prompt")
# Assert
assert isinstance(result, LlmResult)
assert result.text == "Generated response from LMStudio"
assert result.in_token == 20
assert result.out_token == 12
assert result.model == 'gemma3:9b'
# Verify the API call was made correctly
mock_openai.chat.completions.create.assert_called_once()
call_args = mock_openai.chat.completions.create.call_args
# Check model and temperature
assert call_args[1]['model'] == 'gemma3:9b'
assert call_args[1]['temperature'] == 0.0
assert call_args[1]['max_tokens'] == 4096
@patch('trustgraph.model.text_completion.lmstudio.llm.OpenAI')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_with_model_override(self, mock_llm_init, mock_async_init, mock_openai_class):
"""Test generate_content with model parameter override"""
# Arrange
mock_openai = MagicMock()
mock_response = MagicMock()
mock_response.choices[0].message.content = 'Response from overridden model'
mock_response.usage.prompt_tokens = 15
mock_response.usage.completion_tokens = 10
mock_openai.chat.completions.create.return_value = mock_response
mock_openai_class.return_value = mock_openai
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'gemma3:9b',
'url': 'http://localhost:1234/',
'temperature': 0.0,
'max_output': 4096,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act - Override model
result = await processor.generate_content("System", "Prompt", model="custom-lmstudio-model")
# Assert
assert result.model == "custom-lmstudio-model" # Should use overridden model
assert result.text == "Response from overridden model"
# Verify the API call was made with overridden model
call_args = mock_openai.chat.completions.create.call_args
assert call_args[1]['model'] == "custom-lmstudio-model"
@patch('trustgraph.model.text_completion.lmstudio.llm.OpenAI')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_with_temperature_override(self, mock_llm_init, mock_async_init, mock_openai_class):
"""Test generate_content with temperature parameter override"""
# Arrange
mock_openai = MagicMock()
mock_response = MagicMock()
mock_response.choices[0].message.content = 'Response with temperature override'
mock_response.usage.prompt_tokens = 18
mock_response.usage.completion_tokens = 12
mock_openai.chat.completions.create.return_value = mock_response
mock_openai_class.return_value = mock_openai
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'gemma3:9b',
'url': 'http://localhost:1234/',
'temperature': 0.0, # Default temperature
'max_output': 4096,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act - Override temperature
result = await processor.generate_content("System", "Prompt", temperature=0.7)
# Assert
assert result.text == "Response with temperature override"
# Verify the API call was made with overridden temperature
call_args = mock_openai.chat.completions.create.call_args
assert call_args[1]['temperature'] == 0.7
@patch('trustgraph.model.text_completion.lmstudio.llm.OpenAI')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_with_both_parameters_override(self, mock_llm_init, mock_async_init, mock_openai_class):
"""Test generate_content with both model and temperature overrides"""
# Arrange
mock_openai = MagicMock()
mock_response = MagicMock()
mock_response.choices[0].message.content = 'Response with both parameters override'
mock_response.usage.prompt_tokens = 20
mock_response.usage.completion_tokens = 15
mock_openai.chat.completions.create.return_value = mock_response
mock_openai_class.return_value = mock_openai
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'gemma3:9b',
'url': 'http://localhost:1234/',
'temperature': 0.0,
'max_output': 4096,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act - Override both parameters
result = await processor.generate_content("System", "Prompt", model="override-model", temperature=0.8)
# Assert
assert result.model == "override-model"
assert result.text == "Response with both parameters override"
# Verify the API call was made with overridden parameters
call_args = mock_openai.chat.completions.create.call_args
assert call_args[1]['model'] == "override-model"
assert call_args[1]['temperature'] == 0.8
if __name__ == '__main__':
pytest.main([__file__])


@@ -0,0 +1,275 @@
"""
Unit tests for trustgraph.model.text_completion.mistral
Following the same successful pattern as other processor tests
"""
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from unittest import IsolatedAsyncioTestCase
# Import the service under test
from trustgraph.model.text_completion.mistral.llm import Processor
from trustgraph.base import LlmResult
class TestMistralProcessorSimple(IsolatedAsyncioTestCase):
"""Test Mistral processor functionality"""
@patch('trustgraph.model.text_completion.mistral.llm.Mistral')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_processor_initialization_basic(self, mock_llm_init, mock_async_init, mock_mistral_class):
"""Test basic processor initialization"""
# Arrange
mock_mistral_client = MagicMock()
mock_mistral_class.return_value = mock_mistral_client
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'ministral-8b-latest',
'api_key': 'test-api-key',
'temperature': 0.1,
'max_output': 2048,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
# Act
processor = Processor(**config)
# Assert
assert processor.default_model == 'ministral-8b-latest'
assert processor.temperature == 0.1
assert processor.max_output == 2048
assert hasattr(processor, 'mistral')
mock_mistral_class.assert_called_once_with(api_key='test-api-key')
@patch('trustgraph.model.text_completion.mistral.llm.Mistral')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_success(self, mock_llm_init, mock_async_init, mock_mistral_class):
"""Test successful content generation"""
# Arrange
mock_mistral_client = MagicMock()
mock_response = MagicMock()
mock_response.choices[0].message.content = 'Generated response from Mistral'
mock_response.usage.prompt_tokens = 15
mock_response.usage.completion_tokens = 8
mock_mistral_client.chat.complete.return_value = mock_response
mock_mistral_class.return_value = mock_mistral_client
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'ministral-8b-latest',
'api_key': 'test-api-key',
'temperature': 0.0,
'max_output': 4096,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act
result = await processor.generate_content("System prompt", "User prompt")
# Assert
assert isinstance(result, LlmResult)
assert result.text == "Generated response from Mistral"
assert result.in_token == 15
assert result.out_token == 8
assert result.model == 'ministral-8b-latest'
mock_mistral_client.chat.complete.assert_called_once()
@patch('trustgraph.model.text_completion.mistral.llm.Mistral')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_temperature_override(self, mock_llm_init, mock_async_init, mock_mistral_class):
"""Test temperature parameter override functionality"""
# Arrange
mock_mistral_client = MagicMock()
mock_response = MagicMock()
mock_response.choices[0].message.content = 'Response with custom temperature'
mock_response.usage.prompt_tokens = 20
mock_response.usage.completion_tokens = 12
mock_mistral_client.chat.complete.return_value = mock_response
mock_mistral_class.return_value = mock_mistral_client
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'ministral-8b-latest',
'api_key': 'test-api-key',
'temperature': 0.0, # Default temperature
'max_output': 4096,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act - Override temperature at runtime
result = await processor.generate_content(
"System prompt",
"User prompt",
model=None, # Use default model
temperature=0.8 # Override temperature
)
# Assert
assert isinstance(result, LlmResult)
assert result.text == "Response with custom temperature"
# Verify Mistral API was called with overridden temperature
call_args = mock_mistral_client.chat.complete.call_args
assert call_args[1]['temperature'] == 0.8 # Should use runtime override
assert call_args[1]['model'] == 'ministral-8b-latest'
@patch('trustgraph.model.text_completion.mistral.llm.Mistral')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_model_override(self, mock_llm_init, mock_async_init, mock_mistral_class):
"""Test model parameter override functionality"""
# Arrange
mock_mistral_client = MagicMock()
mock_response = MagicMock()
mock_response.choices[0].message.content = 'Response with custom model'
mock_response.usage.prompt_tokens = 18
mock_response.usage.completion_tokens = 14
mock_mistral_client.chat.complete.return_value = mock_response
mock_mistral_class.return_value = mock_mistral_client
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'ministral-8b-latest', # Default model
'api_key': 'test-api-key',
'temperature': 0.1, # Default temperature
'max_output': 4096,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act - Override model at runtime
result = await processor.generate_content(
"System prompt",
"User prompt",
model="mistral-large-latest", # Override model
temperature=None # Use default temperature
)
# Assert
assert isinstance(result, LlmResult)
assert result.text == "Response with custom model"
# Verify Mistral API was called with overridden model
call_args = mock_mistral_client.chat.complete.call_args
assert call_args[1]['model'] == 'mistral-large-latest' # Should use runtime override
assert call_args[1]['temperature'] == 0.1 # Should use processor default
@patch('trustgraph.model.text_completion.mistral.llm.Mistral')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_both_parameters_override(self, mock_llm_init, mock_async_init, mock_mistral_class):
"""Test overriding both model and temperature parameters simultaneously"""
# Arrange
mock_mistral_client = MagicMock()
mock_response = MagicMock()
mock_response.choices[0].message.content = 'Response with both overrides'
mock_response.usage.prompt_tokens = 22
mock_response.usage.completion_tokens = 16
mock_mistral_client.chat.complete.return_value = mock_response
mock_mistral_class.return_value = mock_mistral_client
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'ministral-8b-latest', # Default model
'api_key': 'test-api-key',
'temperature': 0.0, # Default temperature
'max_output': 4096,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act - Override both parameters at runtime
result = await processor.generate_content(
"System prompt",
"User prompt",
model="mistral-large-latest", # Override model
temperature=0.9 # Override temperature
)
# Assert
assert isinstance(result, LlmResult)
assert result.text == "Response with both overrides"
# Verify Mistral API was called with both overrides
call_args = mock_mistral_client.chat.complete.call_args
assert call_args[1]['model'] == 'mistral-large-latest' # Should use runtime override
assert call_args[1]['temperature'] == 0.9 # Should use runtime override
@patch('trustgraph.model.text_completion.mistral.llm.Mistral')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_prompt_construction(self, mock_llm_init, mock_async_init, mock_mistral_class):
"""Test prompt construction with system and user prompts"""
# Arrange
mock_mistral_client = MagicMock()
mock_response = MagicMock()
mock_response.choices[0].message.content = 'Response with system instructions'
mock_response.usage.prompt_tokens = 25
mock_response.usage.completion_tokens = 15
mock_mistral_client.chat.complete.return_value = mock_response
mock_mistral_class.return_value = mock_mistral_client
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'ministral-8b-latest',
'api_key': 'test-api-key',
'temperature': 0.0,
'max_output': 4096,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act
result = await processor.generate_content("You are a helpful assistant", "What is AI?")
# Assert
assert result.text == "Response with system instructions"
assert result.in_token == 25
assert result.out_token == 15
# Verify the combined prompt structure
call_args = mock_mistral_client.chat.complete.call_args
messages = call_args[1]['messages']
assert len(messages) == 1
assert messages[0]['role'] == 'user'
assert messages[0]['content'][0]['type'] == 'text'
assert messages[0]['content'][0]['text'] == "You are a helpful assistant\n\nWhat is AI?"
if __name__ == '__main__':
pytest.main([__file__])


@@ -40,7 +40,7 @@ class TestOllamaProcessorSimple(IsolatedAsyncioTestCase):
processor = Processor(**config)
# Assert
assert processor.model == 'llama2'
assert processor.default_model == 'llama2'
assert hasattr(processor, 'llm')
mock_client_class.assert_called_once_with(host='http://localhost:11434')
@@ -81,7 +81,7 @@ class TestOllamaProcessorSimple(IsolatedAsyncioTestCase):
assert result.in_token == 15
assert result.out_token == 8
assert result.model == 'llama2'
mock_client.generate.assert_called_once_with('llama2', "System prompt\n\nUser prompt")
mock_client.generate.assert_called_once_with('llama2', "System prompt\n\nUser prompt", options={'temperature': 0.0})
@patch('trustgraph.model.text_completion.ollama.llm.Client')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@@ -134,7 +134,7 @@ class TestOllamaProcessorSimple(IsolatedAsyncioTestCase):
processor = Processor(**config)
# Assert
assert processor.model == 'mistral'
assert processor.default_model == 'mistral'
mock_client_class.assert_called_once_with(host='http://192.168.1.100:11434')
@patch('trustgraph.model.text_completion.ollama.llm.Client')
@@ -160,7 +160,7 @@ class TestOllamaProcessorSimple(IsolatedAsyncioTestCase):
processor = Processor(**config)
# Assert
assert processor.model == 'gemma2:9b' # default_model
assert processor.default_model == 'gemma2:9b' # default_model
# Should use default_ollama (http://localhost:11434 or from OLLAMA_HOST env)
mock_client_class.assert_called_once()
@@ -203,7 +203,7 @@ class TestOllamaProcessorSimple(IsolatedAsyncioTestCase):
assert result.model == 'llama2'
# The prompt should be "" + "\n\n" + "" = "\n\n"
mock_client.generate.assert_called_once_with('llama2', "\n\n")
mock_client.generate.assert_called_once_with('llama2', "\n\n", options={'temperature': 0.0})
@patch('trustgraph.model.text_completion.ollama.llm.Client')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@@ -310,7 +310,151 @@ class TestOllamaProcessorSimple(IsolatedAsyncioTestCase):
assert result.out_token == 15
# Verify the combined prompt
mock_client.generate.assert_called_once_with('llama2', "You are a helpful assistant\n\nWhat is AI?")
mock_client.generate.assert_called_once_with('llama2', "You are a helpful assistant\n\nWhat is AI?", options={'temperature': 0.0})
@patch('trustgraph.model.text_completion.ollama.llm.Client')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_temperature_override(self, mock_llm_init, mock_async_init, mock_client_class):
"""Test temperature parameter override functionality"""
# Arrange
mock_client = MagicMock()
mock_response = {
'response': 'Response with custom temperature',
'prompt_eval_count': 20,
'eval_count': 12
}
mock_client.generate.return_value = mock_response
mock_client_class.return_value = mock_client
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'llama2',
'ollama': 'http://localhost:11434',
'temperature': 0.0, # Default temperature
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act - Override temperature at runtime
result = await processor.generate_content(
"System prompt",
"User prompt",
model=None, # Use default model
temperature=0.8 # Override temperature
)
# Assert
assert isinstance(result, LlmResult)
assert result.text == "Response with custom temperature"
# Verify Ollama API was called with overridden temperature
mock_client.generate.assert_called_once_with(
'llama2',
"System prompt\n\nUser prompt",
options={'temperature': 0.8} # Should use runtime override
)
@patch('trustgraph.model.text_completion.ollama.llm.Client')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_model_override(self, mock_llm_init, mock_async_init, mock_client_class):
"""Test model parameter override functionality"""
# Arrange
mock_client = MagicMock()
mock_response = {
'response': 'Response with custom model',
'prompt_eval_count': 18,
'eval_count': 14
}
mock_client.generate.return_value = mock_response
mock_client_class.return_value = mock_client
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'llama2', # Default model
'ollama': 'http://localhost:11434',
'temperature': 0.1, # Default temperature
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act - Override model at runtime
result = await processor.generate_content(
"System prompt",
"User prompt",
model="mistral", # Override model
temperature=None # Use default temperature
)
# Assert
assert isinstance(result, LlmResult)
assert result.text == "Response with custom model"
# Verify Ollama API was called with overridden model
mock_client.generate.assert_called_once_with(
'mistral', # Should use runtime override
"System prompt\n\nUser prompt",
options={'temperature': 0.1} # Should use processor default
)
@patch('trustgraph.model.text_completion.ollama.llm.Client')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_both_parameters_override(self, mock_llm_init, mock_async_init, mock_client_class):
"""Test overriding both model and temperature parameters simultaneously"""
# Arrange
mock_client = MagicMock()
mock_response = {
'response': 'Response with both overrides',
'prompt_eval_count': 22,
'eval_count': 16
}
mock_client.generate.return_value = mock_response
mock_client_class.return_value = mock_client
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'llama2', # Default model
'ollama': 'http://localhost:11434',
'temperature': 0.0, # Default temperature
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act - Override both parameters at runtime
result = await processor.generate_content(
"System prompt",
"User prompt",
model="codellama", # Override model
temperature=0.9 # Override temperature
)
# Assert
assert isinstance(result, LlmResult)
assert result.text == "Response with both overrides"
# Verify Ollama API was called with both overrides
mock_client.generate.assert_called_once_with(
'codellama', # Should use runtime override
"System prompt\n\nUser prompt",
options={'temperature': 0.9} # Should use runtime override
)
if __name__ == '__main__':


@@ -43,7 +43,7 @@ class TestOpenAIProcessorSimple(IsolatedAsyncioTestCase):
processor = Processor(**config)
# Assert
assert processor.model == 'gpt-3.5-turbo'
assert processor.default_model == 'gpt-3.5-turbo'
assert processor.temperature == 0.0
assert processor.max_output == 4096
assert hasattr(processor, 'openai')
@@ -222,7 +222,7 @@ class TestOpenAIProcessorSimple(IsolatedAsyncioTestCase):
processor = Processor(**config)
# Assert
assert processor.model == 'gpt-4'
assert processor.default_model == 'gpt-4'
assert processor.temperature == 0.7
assert processor.max_output == 2048
mock_openai_class.assert_called_once_with(base_url='https://custom-openai-url.com/v1', api_key='custom-api-key')
@@ -251,7 +251,7 @@ class TestOpenAIProcessorSimple(IsolatedAsyncioTestCase):
processor = Processor(**config)
# Assert
assert processor.model == 'gpt-3.5-turbo' # default_model
assert processor.default_model == 'gpt-3.5-turbo' # default_model
assert processor.temperature == 0.0 # default_temperature
assert processor.max_output == 4096 # default_max_output
mock_openai_class.assert_called_once_with(base_url='https://api.openai.com/v1', api_key='test-api-key')
@@ -391,5 +391,210 @@ class TestOpenAIProcessorSimple(IsolatedAsyncioTestCase):
assert call_args[1]['response_format'] == {"type": "text"}
@patch('trustgraph.model.text_completion.openai.llm.OpenAI')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_temperature_override(self, mock_llm_init, mock_async_init, mock_openai_class):
"""Test temperature parameter override functionality"""
# Arrange
mock_openai_client = MagicMock()
mock_response = MagicMock()
mock_response.choices = [MagicMock()]
mock_response.choices[0].message.content = "Response with custom temperature"
mock_response.usage.prompt_tokens = 15
mock_response.usage.completion_tokens = 10
mock_openai_client.chat.completions.create.return_value = mock_response
mock_openai_class.return_value = mock_openai_client
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'gpt-3.5-turbo',
'api_key': 'test-api-key',
'url': 'https://api.openai.com/v1',
'temperature': 0.0, # Default temperature
'max_output': 4096,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act - Override temperature at runtime
result = await processor.generate_content(
"System prompt",
"User prompt",
model=None, # Use default model
temperature=0.9 # Override temperature
)
# Assert
assert isinstance(result, LlmResult)
assert result.text == "Response with custom temperature"
# Verify the OpenAI API was called with overridden temperature
mock_openai_client.chat.completions.create.assert_called_once()
call_kwargs = mock_openai_client.chat.completions.create.call_args.kwargs
assert call_kwargs['temperature'] == 0.9 # Should use runtime override
assert call_kwargs['model'] == 'gpt-3.5-turbo' # Should use processor default
@patch('trustgraph.model.text_completion.openai.llm.OpenAI')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_model_override(self, mock_llm_init, mock_async_init, mock_openai_class):
"""Test model parameter override functionality"""
# Arrange
mock_openai_client = MagicMock()
mock_response = MagicMock()
mock_response.choices = [MagicMock()]
mock_response.choices[0].message.content = "Response with custom model"
mock_response.usage.prompt_tokens = 15
mock_response.usage.completion_tokens = 10
mock_openai_client.chat.completions.create.return_value = mock_response
mock_openai_class.return_value = mock_openai_client
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'gpt-3.5-turbo', # Default model
'api_key': 'test-api-key',
'url': 'https://api.openai.com/v1',
'temperature': 0.2,
'max_output': 4096,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act - Override model at runtime
result = await processor.generate_content(
"System prompt",
"User prompt",
model="gpt-4", # Override model
temperature=None # Use default temperature
)
# Assert
assert isinstance(result, LlmResult)
assert result.text == "Response with custom model"
# Verify the OpenAI API was called with overridden model
mock_openai_client.chat.completions.create.assert_called_once()
call_kwargs = mock_openai_client.chat.completions.create.call_args.kwargs
assert call_kwargs['model'] == 'gpt-4' # Should use runtime override
assert call_kwargs['temperature'] == 0.2 # Should use processor default
@patch('trustgraph.model.text_completion.openai.llm.OpenAI')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_both_parameters_override(self, mock_llm_init, mock_async_init, mock_openai_class):
"""Test overriding both model and temperature parameters simultaneously"""
# Arrange
mock_openai_client = MagicMock()
mock_response = MagicMock()
mock_response.choices = [MagicMock()]
mock_response.choices[0].message.content = "Response with both overrides"
mock_response.usage.prompt_tokens = 15
mock_response.usage.completion_tokens = 10
mock_openai_client.chat.completions.create.return_value = mock_response
mock_openai_class.return_value = mock_openai_client
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'gpt-3.5-turbo', # Default model
'api_key': 'test-api-key',
'url': 'https://api.openai.com/v1',
'temperature': 0.0, # Default temperature
'max_output': 4096,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act - Override both parameters at runtime
result = await processor.generate_content(
"System prompt",
"User prompt",
model="gpt-4", # Override model
temperature=0.7 # Override temperature
)
# Assert
assert isinstance(result, LlmResult)
assert result.text == "Response with both overrides"
# Verify the OpenAI API was called with both overrides
mock_openai_client.chat.completions.create.assert_called_once()
call_kwargs = mock_openai_client.chat.completions.create.call_args.kwargs
assert call_kwargs['model'] == 'gpt-4' # Should use runtime override
assert call_kwargs['temperature'] == 0.7 # Should use runtime override
@patch('trustgraph.model.text_completion.openai.llm.OpenAI')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_no_override_uses_defaults(self, mock_llm_init, mock_async_init, mock_openai_class):
"""Test that when no parameters are overridden, processor defaults are used"""
# Arrange
mock_openai_client = MagicMock()
mock_response = MagicMock()
mock_response.choices = [MagicMock()]
mock_response.choices[0].message.content = "Response with defaults"
mock_response.usage.prompt_tokens = 15
mock_response.usage.completion_tokens = 10
mock_openai_client.chat.completions.create.return_value = mock_response
mock_openai_class.return_value = mock_openai_client
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'gpt-4', # Default model
'api_key': 'test-api-key',
'url': 'https://api.openai.com/v1',
'temperature': 0.5, # Default temperature
'max_output': 4096,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act - Don't override any parameters (pass None)
result = await processor.generate_content(
"System prompt",
"User prompt",
model=None, # Use default model
temperature=None # Use default temperature
)
# Assert
assert isinstance(result, LlmResult)
assert result.text == "Response with defaults"
# Verify the OpenAI API was called with processor defaults
mock_openai_client.chat.completions.create.assert_called_once()
call_kwargs = mock_openai_client.chat.completions.create.call_args.kwargs
assert call_kwargs['model'] == 'gpt-4' # Should use processor default
assert call_kwargs['temperature'] == 0.5 # Should use processor default
if __name__ == '__main__':
pytest.main([__file__])


@@ -0,0 +1,186 @@
"""
Unit tests for Parameter-Based Caching in LLM Processors
Testing processors that cache based on temperature parameters (Bedrock, GoogleAIStudio)
"""
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from unittest import IsolatedAsyncioTestCase
from trustgraph.model.text_completion.googleaistudio.llm import Processor as GoogleAIProcessor
from trustgraph.base import LlmResult
class TestParameterCaching(IsolatedAsyncioTestCase):
"""Test parameter-based caching functionality"""
@patch('trustgraph.model.text_completion.googleaistudio.llm.genai')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_googleai_temperature_cache_keys(self, mock_llm_init, mock_async_init, mock_genai):
"""Test that GoogleAI processor creates separate cache entries for different temperatures"""
# Arrange
mock_client = MagicMock()
mock_genai.Client.return_value = mock_client
mock_response = MagicMock()
mock_response.text = "Generated response"
mock_response.usage_metadata.prompt_token_count = 10
mock_response.usage_metadata.candidates_token_count = 5
mock_client.models.generate_content.return_value = mock_response
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'gemini-2.0-flash-001',
'api_key': 'test-api-key',
'temperature': 0.0, # Default temperature
'max_output': 1024,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = GoogleAIProcessor(**config)
# Act - Call with different temperatures
await processor.generate_content("System", "Prompt 1", model="gemini-2.0-flash-001", temperature=0.0)
await processor.generate_content("System", "Prompt 2", model="gemini-2.0-flash-001", temperature=0.5)
await processor.generate_content("System", "Prompt 3", model="gemini-2.0-flash-001", temperature=1.0)
# Assert - Should have 3 different cache entries
cache_keys = list(processor.generation_configs.keys())
assert len(cache_keys) == 3
assert "gemini-2.0-flash-001:0.0" in cache_keys
assert "gemini-2.0-flash-001:0.5" in cache_keys
assert "gemini-2.0-flash-001:1.0" in cache_keys
# Verify each cached config has the correct temperature
assert processor.generation_configs["gemini-2.0-flash-001:0.0"].temperature == 0.0
assert processor.generation_configs["gemini-2.0-flash-001:0.5"].temperature == 0.5
assert processor.generation_configs["gemini-2.0-flash-001:1.0"].temperature == 1.0
@patch('trustgraph.model.text_completion.googleaistudio.llm.genai')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_googleai_cache_reuse_same_parameters(self, mock_llm_init, mock_async_init, mock_genai):
"""Test that GoogleAI processor reuses cache for identical model+temperature combinations"""
# Arrange
mock_client = MagicMock()
mock_genai.Client.return_value = mock_client
mock_response = MagicMock()
mock_response.text = "Generated response"
mock_response.usage_metadata.prompt_token_count = 10
mock_response.usage_metadata.candidates_token_count = 5
mock_client.models.generate_content.return_value = mock_response
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'gemini-2.0-flash-001',
'api_key': 'test-api-key',
'temperature': 0.0,
'max_output': 1024,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = GoogleAIProcessor(**config)
# Act - Call multiple times with same parameters
await processor.generate_content("System", "Prompt 1", model="gemini-2.0-flash-001", temperature=0.7)
await processor.generate_content("System", "Prompt 2", model="gemini-2.0-flash-001", temperature=0.7)
await processor.generate_content("System", "Prompt 3", model="gemini-2.0-flash-001", temperature=0.7)
# Assert - Should have only 1 cache entry for the repeated parameters
cache_keys = list(processor.generation_configs.keys())
assert len(cache_keys) == 1
assert "gemini-2.0-flash-001:0.7" in cache_keys
# The same config object should be reused
config_obj = processor.generation_configs["gemini-2.0-flash-001:0.7"]
assert config_obj.temperature == 0.7
@patch('trustgraph.model.text_completion.googleaistudio.llm.genai')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_googleai_different_models_separate_caches(self, mock_llm_init, mock_async_init, mock_genai):
"""Test that different models create separate cache entries even with same temperature"""
# Arrange
mock_client = MagicMock()
mock_genai.Client.return_value = mock_client
mock_response = MagicMock()
mock_response.text = "Generated response"
mock_response.usage_metadata.prompt_token_count = 10
mock_response.usage_metadata.candidates_token_count = 5
mock_client.models.generate_content.return_value = mock_response
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'gemini-2.0-flash-001',
'api_key': 'test-api-key',
'temperature': 0.0,
'max_output': 1024,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = GoogleAIProcessor(**config)
# Act - Call with different models, same temperature
await processor.generate_content("System", "Prompt 1", model="gemini-2.0-flash-001", temperature=0.5)
await processor.generate_content("System", "Prompt 2", model="gemini-1.5-flash-001", temperature=0.5)
# Assert - Should have separate cache entries for different models
cache_keys = list(processor.generation_configs.keys())
assert len(cache_keys) == 2
assert "gemini-2.0-flash-001:0.5" in cache_keys
assert "gemini-1.5-flash-001:0.5" in cache_keys
# Note: Bedrock tests would be similar but testing the Bedrock processor's caching behavior
# The Bedrock processor caches model variants with temperature in the cache key
async def test_bedrock_temperature_cache_keys(self):
"""Test Bedrock processor temperature-aware caching"""
# This would test the Bedrock processor's _get_or_create_variant method
# with different temperature values to ensure proper cache key generation
# Implementation would follow similar pattern to GoogleAI tests above
# but using the Bedrock processor and testing model_variants cache
pass
async def test_bedrock_cache_isolation_different_temperatures(self):
"""Test that Bedrock processor isolates cache entries by temperature"""
pass
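    # Hedged sketch (not part of the commit) of the behaviour the two Bedrock
    # placeholders above describe, using a plain dict to stand in for the
    # processor's model_variants cache. The "model:temperature" key format and
    # the _get_or_create_variant signature are assumptions drawn from the
    # comments above, not the real Bedrock implementation.
    async def test_bedrock_cache_key_sketch(self):
        variants = {}

        def get_or_create_variant(model, temperature):
            # Assumed key format, mirroring the GoogleAI tests above
            return variants.setdefault(f"{model}:{temperature}", object())

        first = get_or_create_variant("anthropic.claude-v2", 0.0)
        second = get_or_create_variant("anthropic.claude-v2", 0.5)
        repeat = get_or_create_variant("anthropic.claude-v2", 0.0)  # same parameters again

        assert len(variants) == 2        # entries are isolated by temperature
        assert first is repeat           # identical parameters reuse the cached entry
        assert first is not second       # a different temperature gets its own entry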
async def test_cache_memory_efficiency(self):
"""Test that caches don't grow unbounded with many different parameter combinations"""
# This could test cache size limits or cleanup behavior if implemented
pass
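    # Hedged sketch (not part of the commit): one concrete form the
    # memory-efficiency placeholder above could take, assuming no size limit is
    # currently enforced. If an LRU cap or cleanup were added, the final
    # assertion would check that bound instead.
    async def test_cache_growth_sketch(self):
        cache = {}
        for i in range(500):
            # Every distinct model/temperature combination adds one entry
            cache.setdefault(f"gemini-2.0-flash-001:{i / 1000:.3f}", object())
        assert len(cache) == 500  # growth is unbounded with distinct combinations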
class TestCachePerformance(IsolatedAsyncioTestCase):
"""Test caching performance characteristics"""
async def test_cache_hit_performance(self):
"""Test that cache hits are faster than cache misses"""
# This would measure timing differences between cache hits and misses
pass
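    # Hedged sketch (not part of the commit) of the timing comparison described
    # above, using a dict stand-in and an artificial construction delay; a real
    # test would call the processor's generate_content rather than sleep.
    async def test_cache_hit_timing_sketch(self):
        import time

        cache = {}

        def get_config(key):
            if key not in cache:
                time.sleep(0.005)  # simulate expensive config construction on a miss
                cache[key] = object()
            return cache[key]

        start = time.perf_counter()
        get_config("gemini-2.0-flash-001:0.7")  # first call misses
        miss_time = time.perf_counter() - start

        start = time.perf_counter()
        get_config("gemini-2.0-flash-001:0.7")  # second call hits the cache
        hit_time = time.perf_counter() - start

        assert hit_time < miss_time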
async def test_concurrent_cache_access(self):
"""Test concurrent access to cached configurations"""
# This would test thread-safety of cache access
pass
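    # Hedged sketch (not part of the commit) of concurrent cache access via
    # asyncio.gather, again with a dict stand-in; a real test would drive the
    # processor's generate_content from many tasks instead.
    async def test_concurrent_access_sketch(self):
        import asyncio

        cache = {}

        async def get_config(key):
            await asyncio.sleep(0)  # yield control so tasks interleave
            return cache.setdefault(key, object())

        results = await asyncio.gather(
            *(get_config("gemini-2.0-flash-001:0.5") for _ in range(50))
        )
        assert len(cache) == 1  # all tasks share a single cache entry
        assert all(result is results[0] for result in results)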
if __name__ == '__main__':
pytest.main([__file__])

View file

@@ -0,0 +1,271 @@
"""
Unit tests for trustgraph.model.text_completion.tgi
Following the same successful pattern as previous tests
"""
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from unittest import IsolatedAsyncioTestCase
# Import the service under test
from trustgraph.model.text_completion.tgi.llm import Processor
from trustgraph.base import LlmResult
from trustgraph.exceptions import TooManyRequests
class TestTGIProcessorSimple(IsolatedAsyncioTestCase):
"""Test TGI processor functionality"""
@patch('trustgraph.model.text_completion.tgi.llm.aiohttp.ClientSession')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_processor_initialization_basic(self, mock_llm_init, mock_async_init, mock_session_class):
"""Test basic processor initialization"""
# Arrange
mock_session = MagicMock()
mock_session_class.return_value = mock_session
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'tgi',
'url': 'http://tgi-service:8899/v1',
'temperature': 0.0,
'max_output': 2048,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
# Act
processor = Processor(**config)
# Assert
assert processor.default_model == 'tgi'
assert processor.base_url == 'http://tgi-service:8899/v1'
assert processor.temperature == 0.0
assert processor.max_output == 2048
assert hasattr(processor, 'session')
mock_session_class.assert_called_once()
@patch('trustgraph.model.text_completion.tgi.llm.aiohttp.ClientSession')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_success(self, mock_llm_init, mock_async_init, mock_session_class):
"""Test successful content generation"""
# Arrange
mock_session = MagicMock()
mock_response = MagicMock()
mock_response.status = 200
mock_response.json = AsyncMock(return_value={
'choices': [{
'message': {
'content': 'Generated response from TGI'
}
}],
'usage': {
'prompt_tokens': 20,
'completion_tokens': 12
}
})
# Mock the async context manager
mock_session.post.return_value.__aenter__.return_value = mock_response
mock_session.post.return_value.__aexit__.return_value = None
mock_session_class.return_value = mock_session
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'tgi',
'url': 'http://tgi-service:8899/v1',
'temperature': 0.0,
'max_output': 2048,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act
result = await processor.generate_content("System prompt", "User prompt")
# Assert
assert isinstance(result, LlmResult)
assert result.text == "Generated response from TGI"
assert result.in_token == 20
assert result.out_token == 12
assert result.model == 'tgi'
# Verify the API call was made correctly
mock_session.post.assert_called_once()
call_args = mock_session.post.call_args
# Check URL
assert call_args[0][0] == 'http://tgi-service:8899/v1/chat/completions'
# Check request structure
request_body = call_args[1]['json']
assert request_body['model'] == 'tgi'
assert request_body['temperature'] == 0.0
assert request_body['max_tokens'] == 2048
@patch('trustgraph.model.text_completion.tgi.llm.aiohttp.ClientSession')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_with_model_override(self, mock_llm_init, mock_async_init, mock_session_class):
"""Test generate_content with model parameter override"""
# Arrange
mock_session = MagicMock()
mock_response = MagicMock()
mock_response.status = 200
mock_response.json = AsyncMock(return_value={
'choices': [{
'message': {
'content': 'Response from overridden model'
}
}],
'usage': {
'prompt_tokens': 15,
'completion_tokens': 10
}
})
mock_session.post.return_value.__aenter__.return_value = mock_response
mock_session.post.return_value.__aexit__.return_value = None
mock_session_class.return_value = mock_session
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'tgi',
'url': 'http://tgi-service:8899/v1',
'temperature': 0.0,
'max_output': 2048,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act - Override model
result = await processor.generate_content("System", "Prompt", model="custom-tgi-model")
# Assert
assert result.model == "custom-tgi-model" # Should use overridden model
assert result.text == "Response from overridden model"
# Verify the API call was made with overridden model
call_args = mock_session.post.call_args
assert call_args[1]['json']['model'] == "custom-tgi-model"
@patch('trustgraph.model.text_completion.tgi.llm.aiohttp.ClientSession')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_with_temperature_override(self, mock_llm_init, mock_async_init, mock_session_class):
"""Test generate_content with temperature parameter override"""
# Arrange
mock_session = MagicMock()
mock_response = MagicMock()
mock_response.status = 200
mock_response.json = AsyncMock(return_value={
'choices': [{
'message': {
'content': 'Response with temperature override'
}
}],
'usage': {
'prompt_tokens': 18,
'completion_tokens': 12
}
})
mock_session.post.return_value.__aenter__.return_value = mock_response
mock_session.post.return_value.__aexit__.return_value = None
mock_session_class.return_value = mock_session
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'tgi',
'url': 'http://tgi-service:8899/v1',
'temperature': 0.0, # Default temperature
'max_output': 2048,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act - Override temperature
result = await processor.generate_content("System", "Prompt", temperature=0.7)
# Assert
assert result.text == "Response with temperature override"
# Verify the API call was made with overridden temperature
call_args = mock_session.post.call_args
assert call_args[1]['json']['temperature'] == 0.7
@patch('trustgraph.model.text_completion.tgi.llm.aiohttp.ClientSession')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_with_both_parameters_override(self, mock_llm_init, mock_async_init, mock_session_class):
"""Test generate_content with both model and temperature overrides"""
# Arrange
mock_session = MagicMock()
mock_response = MagicMock()
mock_response.status = 200
mock_response.json = AsyncMock(return_value={
'choices': [{
'message': {
'content': 'Response with both parameters override'
}
}],
'usage': {
'prompt_tokens': 20,
'completion_tokens': 15
}
})
mock_session.post.return_value.__aenter__.return_value = mock_response
mock_session.post.return_value.__aexit__.return_value = None
mock_session_class.return_value = mock_session
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'tgi',
'url': 'http://tgi-service:8899/v1',
'temperature': 0.0,
'max_output': 2048,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act - Override both parameters
result = await processor.generate_content("System", "Prompt", model="override-model", temperature=0.8)
# Assert
assert result.model == "override-model"
assert result.text == "Response with both parameters override"
# Verify the API call was made with overridden parameters
call_args = mock_session.post.call_args
assert call_args[1]['json']['model'] == "override-model"
assert call_args[1]['json']['temperature'] == 0.8
if __name__ == '__main__':
pytest.main([__file__])

View file

@@ -47,10 +47,10 @@ class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase):
processor = Processor(**config)
# Assert
assert processor.model == 'gemini-2.0-flash-001' # It's stored as 'model', not 'model_name'
assert hasattr(processor, 'generation_config')
assert processor.default_model == 'gemini-2.0-flash-001' # It's stored as 'model', not 'model_name'
assert hasattr(processor, 'generation_configs') # Now a cache dictionary
assert hasattr(processor, 'safety_settings')
assert hasattr(processor, 'llm')
assert hasattr(processor, 'model_clients') # LLM clients are now cached
mock_service_account.Credentials.from_service_account_file.assert_called_once_with('private.json')
mock_vertexai.init.assert_called_once()
@@ -102,7 +102,8 @@ class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase):
mock_model.generate_content.assert_called_once()
# Verify the call was made with the expected parameters
call_args = mock_model.generate_content.call_args
assert call_args[1]['generation_config'] == processor.generation_config
# Generation config is now created dynamically per model
assert 'generation_config' in call_args[1]
assert call_args[1]['safety_settings'] == processor.safety_settings
@patch('trustgraph.model.text_completion.vertexai.llm.service_account')
@@ -223,7 +224,7 @@ class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase):
processor = Processor(**config)
# Assert
assert processor.model == 'gemini-2.0-flash-001'
assert processor.default_model == 'gemini-2.0-flash-001'
mock_auth_default.assert_called_once()
mock_vertexai.init.assert_called_once_with(
location='us-central1',
@@ -296,11 +297,11 @@ class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase):
processor = Processor(**config)
# Assert
assert processor.model == 'gemini-1.5-pro'
assert processor.default_model == 'gemini-1.5-pro'
# Verify that generation_config object exists (can't easily check internal values)
assert hasattr(processor, 'generation_config')
assert processor.generation_config is not None
assert hasattr(processor, 'generation_configs') # Now a cache dictionary
assert processor.generation_configs == {} # Empty cache initially
# Verify that safety settings are configured
assert len(processor.safety_settings) == 4
@@ -353,8 +354,8 @@ class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase):
project='test-project-123'
)
# Verify GenerativeModel was created with the right model name
mock_generative_model.assert_called_once_with('gemini-2.0-flash-001')
# GenerativeModel is now created lazily on first use, not at initialization
mock_generative_model.assert_not_called()
@patch('trustgraph.model.text_completion.vertexai.llm.service_account')
@patch('trustgraph.model.text_completion.vertexai.llm.vertexai')
@@ -440,8 +441,8 @@ class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase):
processor = Processor(**config)
# Assert
assert processor.model == 'claude-3-sonnet@20240229'
assert processor.is_anthropic == True
assert processor.default_model == 'claude-3-sonnet@20240229'
# is_anthropic logic is now determined dynamically per request
# Verify service account was called with private key
mock_service_account.Credentials.from_service_account_file.assert_called_once_with('anthropic-key.json')
@@ -459,6 +460,180 @@ class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase):
assert processor.api_params["top_p"] == 1.0
assert processor.api_params["top_k"] == 32
@patch('trustgraph.model.text_completion.vertexai.llm.service_account')
@patch('trustgraph.model.text_completion.vertexai.llm.vertexai')
@patch('trustgraph.model.text_completion.vertexai.llm.GenerativeModel')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_temperature_override(self, mock_llm_init, mock_async_init, mock_generative_model, mock_vertexai, mock_service_account):
"""Test temperature parameter override functionality"""
# Arrange
mock_credentials = MagicMock()
mock_service_account.Credentials.from_service_account_file.return_value = mock_credentials
mock_model = MagicMock()
mock_response = MagicMock()
mock_response.text = "Response with custom temperature"
mock_response.usage_metadata.prompt_token_count = 20
mock_response.usage_metadata.candidates_token_count = 12
mock_model.generate_content.return_value = mock_response
mock_generative_model.return_value = mock_model
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'region': 'us-central1',
'model': 'gemini-2.0-flash-001',
'temperature': 0.0, # Default temperature
'max_output': 8192,
'private_key': 'private.json',
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act - Override temperature at runtime
result = await processor.generate_content(
"System prompt",
"User prompt",
model=None, # Use default model
temperature=0.8 # Override temperature
)
# Assert
assert isinstance(result, LlmResult)
assert result.text == "Response with custom temperature"
# Verify Gemini API was called with overridden temperature
mock_model.generate_content.assert_called_once()
call_args = mock_model.generate_content.call_args
# Check that generation_config was created (we can't directly access temperature from mock)
generation_config = call_args.kwargs['generation_config']
assert generation_config is not None # Should use overridden temperature configuration
@patch('trustgraph.model.text_completion.vertexai.llm.service_account')
@patch('trustgraph.model.text_completion.vertexai.llm.vertexai')
@patch('trustgraph.model.text_completion.vertexai.llm.GenerativeModel')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_model_override(self, mock_llm_init, mock_async_init, mock_generative_model, mock_vertexai, mock_service_account):
"""Test model parameter override functionality"""
# Arrange
mock_credentials = MagicMock()
mock_service_account.Credentials.from_service_account_file.return_value = mock_credentials
# Mock different models
mock_model_default = MagicMock()
mock_model_override = MagicMock()
mock_response = MagicMock()
mock_response.text = "Response with custom model"
mock_response.usage_metadata.prompt_token_count = 18
mock_response.usage_metadata.candidates_token_count = 14
mock_model_override.generate_content.return_value = mock_response
# GenerativeModel should return different models based on input
def model_factory(model_name):
if model_name == 'gemini-1.5-pro':
return mock_model_override
return mock_model_default
mock_generative_model.side_effect = model_factory
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'region': 'us-central1',
'model': 'gemini-2.0-flash-001', # Default model
'temperature': 0.2, # Default temperature
'max_output': 8192,
'private_key': 'private.json',
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act - Override model at runtime
result = await processor.generate_content(
"System prompt",
"User prompt",
model="gemini-1.5-pro", # Override model
temperature=None # Use default temperature
)
# Assert
assert isinstance(result, LlmResult)
assert result.text == "Response with custom model"
# Verify the overridden model was used
mock_model_override.generate_content.assert_called_once()
# Verify GenerativeModel was called with the override model
mock_generative_model.assert_called_with('gemini-1.5-pro')
@patch('trustgraph.model.text_completion.vertexai.llm.service_account')
@patch('trustgraph.model.text_completion.vertexai.llm.vertexai')
@patch('trustgraph.model.text_completion.vertexai.llm.GenerativeModel')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_both_parameters_override(self, mock_llm_init, mock_async_init, mock_generative_model, mock_vertexai, mock_service_account):
"""Test overriding both model and temperature parameters simultaneously"""
# Arrange
mock_credentials = MagicMock()
mock_service_account.Credentials.from_service_account_file.return_value = mock_credentials
mock_model = MagicMock()
mock_response = MagicMock()
mock_response.text = "Response with both overrides"
mock_response.usage_metadata.prompt_token_count = 22
mock_response.usage_metadata.candidates_token_count = 16
mock_model.generate_content.return_value = mock_response
mock_generative_model.return_value = mock_model
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'region': 'us-central1',
'model': 'gemini-2.0-flash-001', # Default model
'temperature': 0.0, # Default temperature
'max_output': 8192,
'private_key': 'private.json',
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act - Override both parameters at runtime
result = await processor.generate_content(
"System prompt",
"User prompt",
model="gemini-1.5-flash-001", # Override model
temperature=0.9 # Override temperature
)
# Assert
assert isinstance(result, LlmResult)
assert result.text == "Response with both overrides"
# Verify both overrides were used
mock_model.generate_content.assert_called_once()
call_args = mock_model.generate_content.call_args
# Verify model override
mock_generative_model.assert_called_with('gemini-1.5-flash-001') # Should use runtime override
# Verify temperature override (we can't directly access temperature from mock)
generation_config = call_args.kwargs['generation_config']
assert generation_config is not None # Should use overridden temperature configuration
if __name__ == '__main__':
pytest.main([__file__])

View file

@@ -42,7 +42,7 @@ class TestVLLMProcessorSimple(IsolatedAsyncioTestCase):
processor = Processor(**config)
# Assert
assert processor.model == 'TheBloke/Mistral-7B-v0.1-AWQ'
assert processor.default_model == 'TheBloke/Mistral-7B-v0.1-AWQ'
assert processor.base_url == 'http://vllm-service:8899/v1'
assert processor.temperature == 0.0
assert processor.max_output == 2048
@@ -199,7 +199,7 @@ class TestVLLMProcessorSimple(IsolatedAsyncioTestCase):
processor = Processor(**config)
# Assert
assert processor.model == 'custom-model'
assert processor.default_model == 'custom-model'
assert processor.base_url == 'http://custom-vllm:8080/v1'
assert processor.temperature == 0.7
assert processor.max_output == 1024
@@ -228,7 +228,7 @@ class TestVLLMProcessorSimple(IsolatedAsyncioTestCase):
processor = Processor(**config)
# Assert
assert processor.model == 'TheBloke/Mistral-7B-v0.1-AWQ' # default_model
assert processor.default_model == 'TheBloke/Mistral-7B-v0.1-AWQ' # default_model
assert processor.base_url == 'http://vllm-service:8899/v1' # default_base_url
assert processor.temperature == 0.0 # default_temperature
assert processor.max_output == 2048 # default_max_output
@@ -485,5 +485,148 @@ class TestVLLMProcessorSimple(IsolatedAsyncioTestCase):
assert call_args[1]['json']['prompt'] == expected_prompt
@patch('trustgraph.model.text_completion.vllm.llm.aiohttp.ClientSession')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_with_model_override(self, mock_llm_init, mock_async_init, mock_session_class):
"""Test generate_content with model parameter override"""
# Arrange
mock_session = MagicMock()
mock_response = MagicMock()
mock_response.status = 200
mock_response.json = AsyncMock(return_value={
'choices': [{
'text': 'Response from overridden model'
}],
'usage': {
'prompt_tokens': 12,
'completion_tokens': 8
}
})
mock_session.post.return_value.__aenter__.return_value = mock_response
mock_session.post.return_value.__aexit__.return_value = None
mock_session_class.return_value = mock_session
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'TheBloke/Mistral-7B-v0.1-AWQ',
'url': 'http://vllm-service:8899/v1',
'temperature': 0.0,
'max_output': 2048,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act - Override model
result = await processor.generate_content("System", "Prompt", model="custom-vllm-model")
# Assert
assert result.model == "custom-vllm-model" # Should use overridden model
assert result.text == "Response from overridden model"
@patch('trustgraph.model.text_completion.vllm.llm.aiohttp.ClientSession')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_with_temperature_override(self, mock_llm_init, mock_async_init, mock_session_class):
"""Test generate_content with temperature parameter override"""
# Arrange
mock_session = MagicMock()
mock_response = MagicMock()
mock_response.status = 200
mock_response.json = AsyncMock(return_value={
'choices': [{
'text': 'Response with temperature override'
}],
'usage': {
'prompt_tokens': 15,
'completion_tokens': 10
}
})
mock_session.post.return_value.__aenter__.return_value = mock_response
mock_session.post.return_value.__aexit__.return_value = None
mock_session_class.return_value = mock_session
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'TheBloke/Mistral-7B-v0.1-AWQ',
'url': 'http://vllm-service:8899/v1',
'temperature': 0.0, # Default temperature
'max_output': 2048,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act - Override temperature
result = await processor.generate_content("System", "Prompt", temperature=0.7)
# Assert
assert result.text == "Response with temperature override"
# Verify the request was made with overridden temperature
call_args = mock_session.post.call_args
assert call_args[1]['json']['temperature'] == 0.7
@patch('trustgraph.model.text_completion.vllm.llm.aiohttp.ClientSession')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_with_both_parameters_override(self, mock_llm_init, mock_async_init, mock_session_class):
"""Test generate_content with both model and temperature overrides"""
# Arrange
mock_session = MagicMock()
mock_response = MagicMock()
mock_response.status = 200
mock_response.json = AsyncMock(return_value={
'choices': [{
'text': 'Response with both parameters override'
}],
'usage': {
'prompt_tokens': 18,
'completion_tokens': 12
}
})
mock_session.post.return_value.__aenter__.return_value = mock_response
mock_session.post.return_value.__aexit__.return_value = None
mock_session_class.return_value = mock_session
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'TheBloke/Mistral-7B-v0.1-AWQ',
'url': 'http://vllm-service:8899/v1',
'temperature': 0.0,
'max_output': 2048,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act - Override both parameters
result = await processor.generate_content("System", "Prompt", model="override-model", temperature=0.8)
# Assert
assert result.model == "override-model"
assert result.text == "Response with both parameters override"
# Verify the request was made with overridden temperature
call_args = mock_session.post.call_args
assert call_args[1]['json']['temperature'] == 0.8
if __name__ == '__main__':
pytest.main([__file__])