More LLM param test coverage (#535)

* More LLM tests

* Fixing tests
cybermaggedon 2025-09-26 01:00:30 +01:00 committed by GitHub
parent b0a3716b0e
commit 43cfcb18a0
18 changed files with 3563 additions and 0 deletions


@@ -0,0 +1,238 @@
"""
Unit tests for Flow Parameter Specification functionality
Testing parameter specification registration and handling in flow processors
"""
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from unittest import IsolatedAsyncioTestCase
from trustgraph.base.flow_processor import FlowProcessor
from trustgraph.base import ParameterSpec, ConsumerSpec, ProducerSpec
class MockAsyncProcessor:
def __init__(self, **params):
self.config_handlers = []
self.id = params.get('id', 'test-service')
self.specifications = []
class TestFlowParameterSpecs(IsolatedAsyncioTestCase):
"""Test flow processor parameter specification functionality"""
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
def test_parameter_spec_registration(self):
"""Test that parameter specs can be registered with flow processors"""
# Arrange
config = {
'id': 'test-flow-processor',
'concurrency': 1,
'taskgroup': AsyncMock()
}
processor = FlowProcessor(**config)
# Create test parameter specs
model_spec = ParameterSpec(name="model")
temperature_spec = ParameterSpec(name="temperature")
# Act
processor.register_specification(model_spec)
processor.register_specification(temperature_spec)
# Assert
assert len(processor.specifications) >= 2
param_specs = [spec for spec in processor.specifications
if isinstance(spec, ParameterSpec)]
assert len(param_specs) >= 2
param_names = [spec.name for spec in param_specs]
assert "model" in param_names
assert "temperature" in param_names
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
def test_mixed_specification_types(self):
"""Test registration of mixed specification types (parameters, consumers, producers)"""
# Arrange
config = {
'id': 'test-flow-processor',
'concurrency': 1,
'taskgroup': AsyncMock()
}
processor = FlowProcessor(**config)
# Create different spec types
param_spec = ParameterSpec(name="model")
consumer_spec = ConsumerSpec(name="input", schema=MagicMock(), handler=MagicMock())
producer_spec = ProducerSpec(name="output", schema=MagicMock())
# Act
processor.register_specification(param_spec)
processor.register_specification(consumer_spec)
processor.register_specification(producer_spec)
# Assert
assert len(processor.specifications) == 3
# Count each type
param_specs = [s for s in processor.specifications if isinstance(s, ParameterSpec)]
consumer_specs = [s for s in processor.specifications if isinstance(s, ConsumerSpec)]
producer_specs = [s for s in processor.specifications if isinstance(s, ProducerSpec)]
assert len(param_specs) == 1
assert len(consumer_specs) == 1
assert len(producer_specs) == 1
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
def test_parameter_spec_metadata(self):
"""Test parameter specification metadata handling"""
# Arrange
config = {
'id': 'test-flow-processor',
'concurrency': 1,
'taskgroup': AsyncMock()
}
processor = FlowProcessor(**config)
# Create parameter specs with metadata (if supported)
model_spec = ParameterSpec(name="model")
temperature_spec = ParameterSpec(name="temperature")
# Act
processor.register_specification(model_spec)
processor.register_specification(temperature_spec)
# Assert
param_specs = [spec for spec in processor.specifications
if isinstance(spec, ParameterSpec)]
model_spec_registered = next((s for s in param_specs if s.name == "model"), None)
temperature_spec_registered = next((s for s in param_specs if s.name == "temperature"), None)
assert model_spec_registered is not None
assert temperature_spec_registered is not None
assert model_spec_registered.name == "model"
assert temperature_spec_registered.name == "temperature"
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
def test_duplicate_parameter_spec_handling(self):
"""Test handling of duplicate parameter spec registration"""
# Arrange
config = {
'id': 'test-flow-processor',
'concurrency': 1,
'taskgroup': AsyncMock()
}
processor = FlowProcessor(**config)
# Create duplicate parameter specs
model_spec1 = ParameterSpec(name="model")
model_spec2 = ParameterSpec(name="model")
# Act
processor.register_specification(model_spec1)
processor.register_specification(model_spec2)
# Assert - Should allow duplicates (or handle appropriately)
param_specs = [spec for spec in processor.specifications
if isinstance(spec, ParameterSpec) and spec.name == "model"]
# Either both duplicate registrations are kept, or the system deduplicates them
assert len(param_specs) >= 1 # At least one should be registered
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
@patch('trustgraph.base.flow_processor.Flow')
async def test_parameter_specs_available_to_flows(self, mock_flow_class):
"""Test that parameter specs are available when flows are created"""
# Arrange
config = {
'id': 'test-flow-processor',
'concurrency': 1,
'taskgroup': AsyncMock()
}
processor = FlowProcessor(**config)
processor.id = 'test-processor'
# Register parameter specs
model_spec = ParameterSpec(name="model")
temperature_spec = ParameterSpec(name="temperature")
processor.register_specification(model_spec)
processor.register_specification(temperature_spec)
mock_flow = AsyncMock()
mock_flow_class.return_value = mock_flow
flow_name = 'test-flow'
flow_defn = {'config': 'test-config'}
# Act
await processor.start_flow(flow_name, flow_defn)
# Assert - Flow should be created with access to processor specifications
mock_flow_class.assert_called_once_with('test-processor', flow_name, processor, flow_defn)
# The flow should have access to the processor's specifications
# (The exact mechanism depends on Flow implementation)
assert len(processor.specifications) >= 2
class TestParameterSpecValidation(IsolatedAsyncioTestCase):
"""Test parameter specification validation functionality"""
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
def test_parameter_spec_name_validation(self):
"""Test parameter spec name validation"""
# Arrange
config = {
'id': 'test-flow-processor',
'concurrency': 1,
'taskgroup': AsyncMock()
}
processor = FlowProcessor(**config)
# Act & Assert - Valid parameter names
valid_specs = [
ParameterSpec(name="model"),
ParameterSpec(name="temperature"),
ParameterSpec(name="max_tokens"),
ParameterSpec(name="api_key")
]
for spec in valid_specs:
# Should not raise any exceptions
processor.register_specification(spec)
assert len([s for s in processor.specifications if isinstance(s, ParameterSpec)]) >= 4
def test_parameter_spec_creation_validation(self):
"""Test parameter spec creation with various inputs"""
# Test valid parameter spec creation
valid_specs = [
ParameterSpec(name="model"),
ParameterSpec(name="temperature"),
ParameterSpec(name="max_output"),
]
for spec in valid_specs:
assert spec.name is not None
assert isinstance(spec.name, str)
# Test edge cases (if parameter specs have validation)
# This depends on the actual ParameterSpec implementation
try:
empty_name_spec = ParameterSpec(name="")
# May or may not be valid depending on implementation
except Exception:
# If validation exists, it should catch invalid names
pass
if __name__ == '__main__':
pytest.main([__file__])
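
These tests only exercise the registration surface: ParameterSpec objects carry a name, and register_specification accumulates them on the processor. A minimal sketch of that contract follows; the internals are illustrative assumptions, not the actual trustgraph implementation.

class ParameterSpecSketch:
    # Illustrative stand-in; the real trustgraph class may carry more fields.
    def __init__(self, name):
        self.name = name

class FlowProcessorSketch:
    def __init__(self, **params):
        self.id = params.get("id")
        self.specifications = []

    def register_specification(self, spec):
        # Registration is append-only; the tests only require that specs
        # are later discoverable by isinstance() and by name.
        self.specifications.append(spec)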


@@ -0,0 +1,264 @@
"""
Unit tests for LLM Service Parameter Specifications
Testing the new parameter-aware functionality added to the LLM base service
"""
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from unittest import IsolatedAsyncioTestCase
from trustgraph.base.llm_service import LlmService, LlmResult
from trustgraph.base import ParameterSpec, ConsumerSpec, ProducerSpec
from trustgraph.schema import TextCompletionRequest, TextCompletionResponse
class MockAsyncProcessor:
def __init__(self, **params):
self.config_handlers = []
self.id = params.get('id', 'test-service')
self.specifications = []
class TestLlmServiceParameters(IsolatedAsyncioTestCase):
"""Test LLM service parameter specification functionality"""
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
def test_parameter_specs_registration(self):
"""Test that LLM service registers model and temperature parameter specs"""
# Arrange
config = {
'id': 'test-llm-service',
'concurrency': 1,
'taskgroup': AsyncMock() # Add required taskgroup
}
# Act
service = LlmService(**config)
# Assert
param_specs = {spec.name: spec for spec in service.specifications
if isinstance(spec, ParameterSpec)}
assert "model" in param_specs
assert "temperature" in param_specs
assert len(param_specs) >= 2 # May have other parameter specs
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
def test_model_parameter_spec_properties(self):
"""Test that model parameter spec has correct properties"""
# Arrange
config = {
'id': 'test-llm-service',
'concurrency': 1,
'taskgroup': AsyncMock()
}
# Act
service = LlmService(**config)
# Assert
model_spec = None
for spec in service.specifications:
if isinstance(spec, ParameterSpec) and spec.name == "model":
model_spec = spec
break
assert model_spec is not None
assert model_spec.name == "model"
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
def test_temperature_parameter_spec_properties(self):
"""Test that temperature parameter spec has correct properties"""
# Arrange
config = {
'id': 'test-llm-service',
'concurrency': 1,
'taskgroup': AsyncMock()
}
# Act
service = LlmService(**config)
# Assert
temperature_spec = None
for spec in service.specifications:
if isinstance(spec, ParameterSpec) and spec.name == "temperature":
temperature_spec = spec
break
assert temperature_spec is not None
assert temperature_spec.name == "temperature"
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
async def test_on_request_extracts_parameters_from_flow(self):
"""Test that on_request method extracts model and temperature from flow"""
# Arrange
config = {
'id': 'test-llm-service',
'concurrency': 1,
'taskgroup': AsyncMock()
}
service = LlmService(**config)
# Mock the metrics
service.text_completion_model_metric = MagicMock()
service.text_completion_model_metric.labels.return_value.info = AsyncMock()
# Mock the generate_content method to capture parameters
service.generate_content = AsyncMock(return_value=LlmResult(
text="test response",
in_token=10,
out_token=5,
model="gpt-4"
))
# Mock message and flow
mock_message = MagicMock()
mock_message.value.return_value = MagicMock()
mock_message.value.return_value.system = "system prompt"
mock_message.value.return_value.prompt = "user prompt"
mock_message.properties.return_value = {"id": "test-id"}
mock_consumer = MagicMock()
mock_consumer.name = "request"
mock_flow = MagicMock()
mock_flow.name = "test-flow"
mock_flow.side_effect = lambda param: {
"model": "gpt-4",
"temperature": 0.7
}.get(param, f"mock-{param}")
mock_producer = AsyncMock()
mock_flow.producer = {"response": mock_producer}
# Act
await service.on_request(mock_message, mock_consumer, mock_flow)
# Assert
# Verify that generate_content was called with parameters from flow
service.generate_content.assert_called_once()
call_args = service.generate_content.call_args
assert call_args[0][0] == "system prompt" # system
assert call_args[0][1] == "user prompt" # prompt
assert call_args[0][2] == "gpt-4" # model
assert call_args[0][3] == 0.7 # temperature
# Verify flow was queried for both parameters
mock_flow.assert_any_call("model")
mock_flow.assert_any_call("temperature")
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
async def test_on_request_handles_missing_parameters_gracefully(self):
"""Test that on_request handles missing parameters gracefully"""
# Arrange
config = {
'id': 'test-llm-service',
'concurrency': 1,
'taskgroup': AsyncMock()
}
service = LlmService(**config)
# Mock the metrics
service.text_completion_model_metric = MagicMock()
service.text_completion_model_metric.labels.return_value.info = AsyncMock()
# Mock the generate_content method
service.generate_content = AsyncMock(return_value=LlmResult(
text="test response",
in_token=10,
out_token=5,
model="default-model"
))
# Mock message and flow where flow returns None for parameters
mock_message = MagicMock()
mock_message.value.return_value = MagicMock()
mock_message.value.return_value.system = "system prompt"
mock_message.value.return_value.prompt = "user prompt"
mock_message.properties.return_value = {"id": "test-id"}
mock_consumer = MagicMock()
mock_consumer.name = "request"
mock_flow = MagicMock()
mock_flow.name = "test-flow"
mock_flow.return_value = None # Both parameters return None
mock_producer = AsyncMock()
mock_flow.producer = {"response": mock_producer}
# Act
await service.on_request(mock_message, mock_consumer, mock_flow)
# Assert
# Should still call generate_content, with None values that will use processor defaults
service.generate_content.assert_called_once()
call_args = service.generate_content.call_args
assert call_args[0][0] == "system prompt" # system
assert call_args[0][1] == "user prompt" # prompt
assert call_args[0][2] is None # model (will use processor default)
assert call_args[0][3] is None # temperature (will use processor default)
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
async def test_on_request_error_handling_preserves_behavior(self):
"""Test that parameter extraction doesn't break existing error handling"""
# Arrange
config = {
'id': 'test-llm-service',
'concurrency': 1,
'taskgroup': AsyncMock()
}
service = LlmService(**config)
# Mock the metrics
service.text_completion_model_metric = MagicMock()
service.text_completion_model_metric.labels.return_value.info = AsyncMock()
# Mock generate_content to raise an exception
service.generate_content = AsyncMock(side_effect=Exception("Test error"))
# Mock message and flow
mock_message = MagicMock()
mock_message.value.return_value = MagicMock()
mock_message.value.return_value.system = "system prompt"
mock_message.value.return_value.prompt = "user prompt"
mock_message.properties.return_value = {"id": "test-id"}
mock_consumer = MagicMock()
mock_consumer.name = "request"
mock_flow = MagicMock()
mock_flow.name = "test-flow"
mock_flow.side_effect = lambda param: {
"model": "gpt-4",
"temperature": 0.7
}.get(param, f"mock-{param}")
mock_producer = AsyncMock()
mock_flow.producer = {"response": mock_producer}
# Act
await service.on_request(mock_message, mock_consumer, mock_flow)
# Assert
# Should have sent error response
mock_producer.send.assert_called_once()
error_response = mock_producer.send.call_args[0][0]
assert error_response.error is not None
assert error_response.error.type == "llm-error"
assert "Test error" in error_response.error.message
assert error_response.response is None
if __name__ == '__main__':
pytest.main([__file__])
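
Taken together, the mocks pin down the request path: the flow object is called with a parameter name and yields the per-flow value (or None), and the resolved values are passed positionally to generate_content. A hedged sketch of that path, with error handling and schema wrapping elided:

class LlmServiceSketch:
    # Simplified shape inferred from the assertions above; not the actual
    # trustgraph LlmService code.
    async def on_request(self, message, consumer, flow):
        request = message.value()
        model = flow("model")              # None means the provider default applies
        temperature = flow("temperature")  # likewise
        result = await self.generate_content(
            request.system, request.prompt, model, temperature,
        )
        # On failure, the real service instead sends an error object with
        # type "llm-error" via the same producer.
        await flow.producer["response"].send(result)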


@@ -402,6 +402,156 @@ class TestAzureOpenAIProcessorSimple(IsolatedAsyncioTestCase):
assert call_args[1]['max_tokens'] == 1024
assert call_args[1]['top_p'] == 1
@patch('trustgraph.model.text_completion.azure_openai.llm.AzureOpenAI')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_temperature_override(self, mock_llm_init, mock_async_init, mock_azure_openai_class):
"""Test temperature parameter override functionality"""
# Arrange
mock_azure_client = MagicMock()
mock_response = MagicMock()
mock_response.choices = [MagicMock()]
mock_response.choices[0].message.content = 'Response with custom temperature'
mock_response.usage.prompt_tokens = 20
mock_response.usage.completion_tokens = 12
mock_azure_client.chat.completions.create.return_value = mock_response
mock_azure_openai_class.return_value = mock_azure_client
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'gpt-4',
'endpoint': 'https://test.openai.azure.com/',
'token': 'test-token',
'api_version': '2024-12-01-preview',
'temperature': 0.0, # Default temperature
'max_output': 4192,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act - Override temperature at runtime
result = await processor.generate_content(
"System prompt",
"User prompt",
model=None, # Use default model
temperature=0.8 # Override temperature
)
# Assert
assert isinstance(result, LlmResult)
assert result.text == "Response with custom temperature"
# Verify Azure OpenAI API was called with overridden temperature
call_args = mock_azure_client.chat.completions.create.call_args
assert call_args[1]['temperature'] == 0.8 # Should use runtime override
assert call_args[1]['model'] == 'gpt-4'
@patch('trustgraph.model.text_completion.azure_openai.llm.AzureOpenAI')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_model_override(self, mock_llm_init, mock_async_init, mock_azure_openai_class):
"""Test model parameter override functionality"""
# Arrange
mock_azure_client = MagicMock()
mock_response = MagicMock()
mock_response.choices = [MagicMock()]
mock_response.choices[0].message.content = 'Response with custom model'
mock_response.usage.prompt_tokens = 18
mock_response.usage.completion_tokens = 14
mock_azure_client.chat.completions.create.return_value = mock_response
mock_azure_openai_class.return_value = mock_azure_client
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'gpt-4', # Default model
'endpoint': 'https://test.openai.azure.com/',
'token': 'test-token',
'api_version': '2024-12-01-preview',
'temperature': 0.1, # Default temperature
'max_output': 4192,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act - Override model at runtime
result = await processor.generate_content(
"System prompt",
"User prompt",
model="gpt-4o", # Override model
temperature=None # Use default temperature
)
# Assert
assert isinstance(result, LlmResult)
assert result.text == "Response with custom model"
# Verify Azure OpenAI API was called with overridden model
call_args = mock_azure_client.chat.completions.create.call_args
assert call_args[1]['model'] == 'gpt-4o' # Should use runtime override
assert call_args[1]['temperature'] == 0.1 # Should use processor default
@patch('trustgraph.model.text_completion.azure_openai.llm.AzureOpenAI')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_both_parameters_override(self, mock_llm_init, mock_async_init, mock_azure_openai_class):
"""Test overriding both model and temperature parameters simultaneously"""
# Arrange
mock_azure_client = MagicMock()
mock_response = MagicMock()
mock_response.choices = [MagicMock()]
mock_response.choices[0].message.content = 'Response with both overrides'
mock_response.usage.prompt_tokens = 22
mock_response.usage.completion_tokens = 16
mock_azure_client.chat.completions.create.return_value = mock_response
mock_azure_openai_class.return_value = mock_azure_client
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'gpt-4', # Default model
'endpoint': 'https://test.openai.azure.com/',
'token': 'test-token',
'api_version': '2024-12-01-preview',
'temperature': 0.0, # Default temperature
'max_output': 4192,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act - Override both parameters at runtime
result = await processor.generate_content(
"System prompt",
"User prompt",
model="gpt-4o-mini", # Override model
temperature=0.9 # Override temperature
)
# Assert
assert isinstance(result, LlmResult)
assert result.text == "Response with both overrides"
# Verify Azure OpenAI API was called with both overrides
call_args = mock_azure_client.chat.completions.create.call_args
assert call_args[1]['model'] == 'gpt-4o-mini' # Should use runtime override
assert call_args[1]['temperature'] == 0.9 # Should use runtime override
if __name__ == '__main__':
pytest.main([__file__])
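
Each override test asserts the same fallback rule: an explicit argument wins, and None falls back to the processor's configured default. The rule is small enough to state as a standalone helper (a sketch, not the actual processor code):

def resolve_params(default_model, default_temperature, model=None, temperature=None):
    # Explicit per-request values win; None falls back to configured defaults.
    # Testing "is not None" matters: temperature=0.0 is a valid override.
    resolved_model = model if model is not None else default_model
    resolved_temperature = temperature if temperature is not None else default_temperature
    return resolved_model, resolved_temperature

# resolve_params("gpt-4", 0.0, temperature=0.8) -> ("gpt-4", 0.8)
# resolve_params("gpt-4", 0.1, model="gpt-4o")  -> ("gpt-4o", 0.1)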


@@ -459,5 +459,150 @@ class TestAzureProcessorSimple(IsolatedAsyncioTestCase):
)
@patch('trustgraph.model.text_completion.azure.llm.requests')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_with_model_override(self, mock_llm_init, mock_async_init, mock_requests):
"""Test generate_content with model parameter override"""
# Arrange
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.json.return_value = {
'choices': [{
'message': {
'content': 'Response with model override'
}
}],
'usage': {
'prompt_tokens': 15,
'completion_tokens': 10
}
}
mock_requests.post.return_value = mock_response
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'endpoint': 'https://test.inference.ai.azure.com/v1/chat/completions',
'token': 'test-token',
'temperature': 0.0,
'max_output': 4192,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act - Override model
result = await processor.generate_content("System", "Prompt", model="custom-azure-model")
# Assert
assert result.model == "custom-azure-model" # Should use overridden model
assert result.text == "Response with model override"
@patch('trustgraph.model.text_completion.azure.llm.requests')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_with_temperature_override(self, mock_llm_init, mock_async_init, mock_requests):
"""Test generate_content with temperature parameter override"""
# Arrange
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.json.return_value = {
'choices': [{
'message': {
'content': 'Response with temperature override'
}
}],
'usage': {
'prompt_tokens': 15,
'completion_tokens': 10
}
}
mock_requests.post.return_value = mock_response
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'endpoint': 'https://test.inference.ai.azure.com/v1/chat/completions',
'token': 'test-token',
'temperature': 0.0, # Default temperature
'max_output': 4192,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act - Override temperature
result = await processor.generate_content("System", "Prompt", temperature=0.8)
# Assert
assert result.text == "Response with temperature override"
# Verify the request was made with the overridden temperature
mock_requests.post.assert_called_once()
call_args = mock_requests.post.call_args
import json
request_body = json.loads(call_args[1]['data'])
assert request_body['temperature'] == 0.8
@patch('trustgraph.model.text_completion.azure.llm.requests')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_with_both_parameters_override(self, mock_llm_init, mock_async_init, mock_requests):
"""Test generate_content with both model and temperature overrides"""
# Arrange
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.json.return_value = {
'choices': [{
'message': {
'content': 'Response with both parameters override'
}
}],
'usage': {
'prompt_tokens': 18,
'completion_tokens': 12
}
}
mock_requests.post.return_value = mock_response
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'endpoint': 'https://test.inference.ai.azure.com/v1/chat/completions',
'token': 'test-token',
'temperature': 0.0,
'max_output': 4192,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act - Override both parameters
result = await processor.generate_content("System", "Prompt", model="override-model", temperature=0.9)
# Assert
assert result.model == "override-model"
assert result.text == "Response with both parameters override"
# Verify the request was made with overridden temperature
mock_requests.post.assert_called_once()
call_args = mock_requests.post.call_args
import json
request_body = json.loads(call_args[1]['data'])
assert request_body['temperature'] == 0.9
if __name__ == '__main__':
pytest.main([__file__])


@@ -0,0 +1,280 @@
"""
Unit tests for trustgraph.model.text_completion.bedrock
Following the same successful pattern as other processor tests
"""
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from unittest import IsolatedAsyncioTestCase
import json
# Import the service under test
from trustgraph.model.text_completion.bedrock.llm import Processor, Mistral, Anthropic
from trustgraph.base import LlmResult
class TestBedrockProcessorSimple(IsolatedAsyncioTestCase):
"""Test Bedrock processor functionality"""
@patch('trustgraph.model.text_completion.bedrock.llm.boto3.Session')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_processor_initialization_basic(self, mock_llm_init, mock_async_init, mock_session_class):
"""Test basic processor initialization"""
# Arrange
mock_session = MagicMock()
mock_bedrock = MagicMock()
mock_session.client.return_value = mock_bedrock
mock_session_class.return_value = mock_session
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'mistral.mistral-large-2407-v1:0',
'temperature': 0.1,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
# Act
processor = Processor(**config)
# Assert
assert processor.default_model == 'mistral.mistral-large-2407-v1:0'
assert processor.temperature == 0.1
assert hasattr(processor, 'bedrock')
mock_session_class.assert_called_once()
@patch('trustgraph.model.text_completion.bedrock.llm.boto3.Session')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_success_mistral(self, mock_llm_init, mock_async_init, mock_session_class):
"""Test successful content generation with Mistral model"""
# Arrange
mock_session = MagicMock()
mock_bedrock = MagicMock()
mock_session.client.return_value = mock_bedrock
mock_session_class.return_value = mock_session
mock_response = {
'body': MagicMock(),
'ResponseMetadata': {
'HTTPHeaders': {
'x-amzn-bedrock-input-token-count': '15',
'x-amzn-bedrock-output-token-count': '8'
}
}
}
mock_response['body'].read.return_value = json.dumps({
'outputs': [{'text': 'Generated response from Bedrock'}]
})
mock_bedrock.invoke_model.return_value = mock_response
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'mistral.mistral-large-2407-v1:0',
'temperature': 0.0,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act
result = await processor.generate_content("System prompt", "User prompt")
# Assert
assert isinstance(result, LlmResult)
assert result.text == "Generated response from Bedrock"
assert result.in_token == 15
assert result.out_token == 8
assert result.model == 'mistral.mistral-large-2407-v1:0'
mock_bedrock.invoke_model.assert_called_once()
@patch('trustgraph.model.text_completion.bedrock.llm.boto3.Session')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_temperature_override(self, mock_llm_init, mock_async_init, mock_session_class):
"""Test temperature parameter override functionality"""
# Arrange
mock_session = MagicMock()
mock_bedrock = MagicMock()
mock_session.client.return_value = mock_bedrock
mock_session_class.return_value = mock_session
mock_response = {
'body': MagicMock(),
'ResponseMetadata': {
'HTTPHeaders': {
'x-amzn-bedrock-input-token-count': '20',
'x-amzn-bedrock-output-token-count': '12'
}
}
}
mock_response['body'].read.return_value = json.dumps({
'outputs': [{'text': 'Response with custom temperature'}]
})
mock_bedrock.invoke_model.return_value = mock_response
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'mistral.mistral-large-2407-v1:0',
'temperature': 0.0, # Default temperature
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act - Override temperature at runtime
result = await processor.generate_content(
"System prompt",
"User prompt",
model=None, # Use default model
temperature=0.8 # Override temperature
)
# Assert
assert isinstance(result, LlmResult)
assert result.text == "Response with custom temperature"
# Verify the model variant was created with overridden temperature
# The cache key should include the temperature
cache_key = f"mistral.mistral-large-2407-v1:0:0.8"
assert cache_key in processor.model_variants
variant = processor.model_variants[cache_key]
assert variant.temperature == 0.8
@patch('trustgraph.model.text_completion.bedrock.llm.boto3.Session')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_model_override(self, mock_llm_init, mock_async_init, mock_session_class):
"""Test model parameter override functionality"""
# Arrange
mock_session = MagicMock()
mock_bedrock = MagicMock()
mock_session.client.return_value = mock_bedrock
mock_session_class.return_value = mock_session
mock_response = {
'body': MagicMock(),
'ResponseMetadata': {
'HTTPHeaders': {
'x-amzn-bedrock-input-token-count': '18',
'x-amzn-bedrock-output-token-count': '14'
}
}
}
mock_response['body'].read.return_value = json.dumps({
'content': [{'text': 'Response with custom model'}]
})
mock_bedrock.invoke_model.return_value = mock_response
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'mistral.mistral-large-2407-v1:0', # Default model
'temperature': 0.1, # Default temperature
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act - Override model at runtime
result = await processor.generate_content(
"System prompt",
"User prompt",
model="anthropic.claude-3-sonnet-20240229-v1:0", # Override model
temperature=None # Use default temperature
)
# Assert
assert isinstance(result, LlmResult)
assert result.text == "Response with custom model"
# Verify Bedrock API was called with overridden model
mock_bedrock.invoke_model.assert_called_once()
call_args = mock_bedrock.invoke_model.call_args
assert call_args[1]['modelId'] == "anthropic.claude-3-sonnet-20240229-v1:0"
# Verify the correct model variant (Anthropic) was used
cache_key = f"anthropic.claude-3-sonnet-20240229-v1:0:0.1"
assert cache_key in processor.model_variants
variant = processor.model_variants[cache_key]
assert isinstance(variant, Anthropic)
@patch('trustgraph.model.text_completion.bedrock.llm.boto3.Session')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_both_parameters_override(self, mock_llm_init, mock_async_init, mock_session_class):
"""Test overriding both model and temperature parameters simultaneously"""
# Arrange
mock_session = MagicMock()
mock_bedrock = MagicMock()
mock_session.client.return_value = mock_bedrock
mock_session_class.return_value = mock_session
mock_response = {
'body': MagicMock(),
'ResponseMetadata': {
'HTTPHeaders': {
'x-amzn-bedrock-input-token-count': '22',
'x-amzn-bedrock-output-token-count': '16'
}
}
}
mock_response['body'].read.return_value = json.dumps({
'generation': 'Response with both overrides'
})
mock_bedrock.invoke_model.return_value = mock_response
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'mistral.mistral-large-2407-v1:0', # Default model
'temperature': 0.0, # Default temperature
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act - Override both parameters at runtime
result = await processor.generate_content(
"System prompt",
"User prompt",
model="meta.llama3-70b-instruct-v1:0", # Override model (Meta/Llama)
temperature=0.9 # Override temperature
)
# Assert
assert isinstance(result, LlmResult)
assert result.text == "Response with both overrides"
# Verify Bedrock API was called with both overrides
mock_bedrock.invoke_model.assert_called_once()
call_args = mock_bedrock.invoke_model.call_args
assert call_args[1]['modelId'] == "meta.llama3-70b-instruct-v1:0"
# Verify the correct model variant (Meta) was used with correct temperature
cache_key = f"meta.llama3-70b-instruct-v1:0:0.9"
assert cache_key in processor.model_variants
variant = processor.model_variants[cache_key]
assert variant.temperature == 0.9
if __name__ == '__main__':
pytest.main([__file__])
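
The Bedrock assertions reveal a memoised variant table: adapters are cached under a "model:temperature" key, and the adapter class (Mistral, Anthropic, Meta) is chosen from the model-ID family. A minimal sketch under those assumptions; only the dict name and key format come from the tests, the factory logic is assumed:

class Variant:
    def __init__(self, model_id, temperature):
        self.model_id = model_id
        self.temperature = temperature

class Mistral(Variant): pass
class Anthropic(Variant): pass
class Meta(Variant): pass

FAMILIES = {"mistral": Mistral, "anthropic": Anthropic, "meta": Meta}

model_variants = {}

def get_variant(model_id, temperature):
    # Cache key matches the tests' assertions,
    # e.g. "meta.llama3-70b-instruct-v1:0:0.9"
    key = f"{model_id}:{temperature}"
    if key not in model_variants:
        family = FAMILIES[model_id.split(".", 1)[0]]
        model_variants[key] = family(model_id, temperature)
    return model_variants[key]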


@@ -435,6 +435,156 @@ class TestClaudeProcessorSimple(IsolatedAsyncioTestCase):
assert processor.claude == mock_claude_client
assert processor.default_model == 'claude-3-opus-20240229'
@patch('trustgraph.model.text_completion.claude.llm.anthropic.Anthropic')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_temperature_override(self, mock_llm_init, mock_async_init, mock_anthropic_class):
"""Test temperature parameter override functionality"""
# Arrange
mock_claude_client = MagicMock()
mock_response = MagicMock()
mock_response.content = [MagicMock()]
mock_response.content[0].text = "Response with custom temperature"
mock_response.usage.input_tokens = 20
mock_response.usage.output_tokens = 12
mock_claude_client.messages.create.return_value = mock_response
mock_anthropic_class.return_value = mock_claude_client
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'claude-3-5-sonnet-20240620',
'api_key': 'test-api-key',
'temperature': 0.0, # Default temperature
'max_output': 8192,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act - Override temperature at runtime
result = await processor.generate_content(
"System prompt",
"User prompt",
model=None, # Use default model
temperature=0.9 # Override temperature
)
# Assert
assert isinstance(result, LlmResult)
assert result.text == "Response with custom temperature"
# Verify Claude API was called with overridden temperature
mock_claude_client.messages.create.assert_called_once()
call_kwargs = mock_claude_client.messages.create.call_args.kwargs
assert call_kwargs['temperature'] == 0.9 # Should use runtime override
assert call_kwargs['model'] == 'claude-3-5-sonnet-20240620' # Should use processor default
@patch('trustgraph.model.text_completion.claude.llm.anthropic.Anthropic')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_model_override(self, mock_llm_init, mock_async_init, mock_anthropic_class):
"""Test model parameter override functionality"""
# Arrange
mock_claude_client = MagicMock()
mock_response = MagicMock()
mock_response.content = [MagicMock()]
mock_response.content[0].text = "Response with custom model"
mock_response.usage.input_tokens = 18
mock_response.usage.output_tokens = 14
mock_claude_client.messages.create.return_value = mock_response
mock_anthropic_class.return_value = mock_claude_client
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'claude-3-5-sonnet-20240620', # Default model
'api_key': 'test-api-key',
'temperature': 0.2, # Default temperature
'max_output': 8192,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act - Override model at runtime
result = await processor.generate_content(
"System prompt",
"User prompt",
model="claude-3-haiku-20240307", # Override model
temperature=None # Use default temperature
)
# Assert
assert isinstance(result, LlmResult)
assert result.text == "Response with custom model"
# Verify Claude API was called with overridden model
mock_claude_client.messages.create.assert_called_once()
call_kwargs = mock_claude_client.messages.create.call_args.kwargs
assert call_kwargs['model'] == 'claude-3-haiku-20240307' # Should use runtime override
assert call_kwargs['temperature'] == 0.2 # Should use processor default
@patch('trustgraph.model.text_completion.claude.llm.anthropic.Anthropic')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_both_parameters_override(self, mock_llm_init, mock_async_init, mock_anthropic_class):
"""Test overriding both model and temperature parameters simultaneously"""
# Arrange
mock_claude_client = MagicMock()
mock_response = MagicMock()
mock_response.content = [MagicMock()]
mock_response.content[0].text = "Response with both overrides"
mock_response.usage.input_tokens = 22
mock_response.usage.output_tokens = 16
mock_claude_client.messages.create.return_value = mock_response
mock_anthropic_class.return_value = mock_claude_client
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'claude-3-5-sonnet-20240620', # Default model
'api_key': 'test-api-key',
'temperature': 0.0, # Default temperature
'max_output': 8192,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act - Override both parameters at runtime
result = await processor.generate_content(
"System prompt",
"User prompt",
model="claude-3-opus-20240229", # Override model
temperature=0.8 # Override temperature
)
# Assert
assert isinstance(result, LlmResult)
assert result.text == "Response with both overrides"
# Verify Claude API was called with both overrides
mock_claude_client.messages.create.assert_called_once()
call_kwargs = mock_claude_client.messages.create.call_args.kwargs
assert call_kwargs['model'] == 'claude-3-opus-20240229' # Should use runtime override
assert call_kwargs['temperature'] == 0.8 # Should use runtime override
if __name__ == '__main__':
pytest.main([__file__])


@@ -442,6 +442,162 @@ class TestCohereProcessorSimple(IsolatedAsyncioTestCase):
assert call_args[1]['prompt_truncation'] == 'auto'
assert call_args[1]['connectors'] == []
@patch('trustgraph.model.text_completion.cohere.llm.cohere.Client')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_temperature_override(self, mock_llm_init, mock_async_init, mock_cohere_class):
"""Test temperature parameter override functionality"""
# Arrange
mock_cohere_client = MagicMock()
mock_output = MagicMock()
mock_output.text = 'Response with custom temperature'
mock_output.meta.billed_units.input_tokens = 20
mock_output.meta.billed_units.output_tokens = 12
mock_cohere_client.chat.return_value = mock_output
mock_cohere_class.return_value = mock_cohere_client
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'c4ai-aya-23-8b',
'api_key': 'test-api-key',
'temperature': 0.0, # Default temperature
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act - Override temperature at runtime
result = await processor.generate_content(
"System prompt",
"User prompt",
model=None, # Use default model
temperature=0.8 # Override temperature
)
# Assert
assert isinstance(result, LlmResult)
assert result.text == "Response with custom temperature"
# Verify Cohere API was called with overridden temperature
mock_cohere_client.chat.assert_called_once_with(
model='c4ai-aya-23-8b',
message='User prompt',
preamble='System prompt',
temperature=0.8, # Should use runtime override
chat_history=[],
prompt_truncation='auto',
connectors=[]
)
@patch('trustgraph.model.text_completion.cohere.llm.cohere.Client')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_model_override(self, mock_llm_init, mock_async_init, mock_cohere_class):
"""Test model parameter override functionality"""
# Arrange
mock_cohere_client = MagicMock()
mock_output = MagicMock()
mock_output.text = 'Response with custom model'
mock_output.meta.billed_units.input_tokens = 18
mock_output.meta.billed_units.output_tokens = 14
mock_cohere_client.chat.return_value = mock_output
mock_cohere_class.return_value = mock_cohere_client
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'c4ai-aya-23-8b', # Default model
'api_key': 'test-api-key',
'temperature': 0.1, # Default temperature
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act - Override model at runtime
result = await processor.generate_content(
"System prompt",
"User prompt",
model="command-r-plus", # Override model
temperature=None # Use default temperature
)
# Assert
assert isinstance(result, LlmResult)
assert result.text == "Response with custom model"
# Verify Cohere API was called with overridden model
mock_cohere_client.chat.assert_called_once_with(
model='command-r-plus', # Should use runtime override
message='User prompt',
preamble='System prompt',
temperature=0.1, # Should use processor default
chat_history=[],
prompt_truncation='auto',
connectors=[]
)
@patch('trustgraph.model.text_completion.cohere.llm.cohere.Client')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_both_parameters_override(self, mock_llm_init, mock_async_init, mock_cohere_class):
"""Test overriding both model and temperature parameters simultaneously"""
# Arrange
mock_cohere_client = MagicMock()
mock_output = MagicMock()
mock_output.text = 'Response with both overrides'
mock_output.meta.billed_units.input_tokens = 22
mock_output.meta.billed_units.output_tokens = 16
mock_cohere_client.chat.return_value = mock_output
mock_cohere_class.return_value = mock_cohere_client
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'c4ai-aya-23-8b', # Default model
'api_key': 'test-api-key',
'temperature': 0.0, # Default temperature
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act - Override both parameters at runtime
result = await processor.generate_content(
"System prompt",
"User prompt",
model="command-r", # Override model
temperature=0.9 # Override temperature
)
# Assert
assert isinstance(result, LlmResult)
assert result.text == "Response with both overrides"
# Verify Cohere API was called with both overrides
mock_cohere_client.chat.assert_called_once_with(
model='command-r', # Should use runtime override
message='User prompt',
preamble='System prompt',
temperature=0.9, # Should use runtime override
chat_history=[],
prompt_truncation='auto',
connectors=[]
)
if __name__ == '__main__':
pytest.main([__file__])


@@ -477,6 +477,156 @@ class TestGoogleAIStudioProcessorSimple(IsolatedAsyncioTestCase):
# The system instruction should be in the config object
assert call_args[1]['contents'] == "Explain quantum computing"
@patch('trustgraph.model.text_completion.googleaistudio.llm.genai.Client')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_temperature_override(self, mock_llm_init, mock_async_init, mock_genai_class):
"""Test temperature parameter override functionality"""
# Arrange
mock_genai_client = MagicMock()
mock_response = MagicMock()
mock_response.text = 'Response with custom temperature'
mock_response.usage_metadata.prompt_token_count = 20
mock_response.usage_metadata.candidates_token_count = 12
mock_genai_client.models.generate_content.return_value = mock_response
mock_genai_class.return_value = mock_genai_client
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'gemini-2.0-flash-001',
'api_key': 'test-api-key',
'temperature': 0.0, # Default temperature
'max_output': 8192,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act - Override temperature at runtime
result = await processor.generate_content(
"System prompt",
"User prompt",
model=None, # Use default model
temperature=0.8 # Override temperature
)
# Assert
assert isinstance(result, LlmResult)
assert result.text == "Response with custom temperature"
# Verify the generation config was created with overridden temperature
cache_key = f"gemini-2.0-flash-001:0.8"
assert cache_key in processor.generation_configs
config_obj = processor.generation_configs[cache_key]
assert config_obj.temperature == 0.8
@patch('trustgraph.model.text_completion.googleaistudio.llm.genai.Client')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_model_override(self, mock_llm_init, mock_async_init, mock_genai_class):
"""Test model parameter override functionality"""
# Arrange
mock_genai_client = MagicMock()
mock_response = MagicMock()
mock_response.text = 'Response with custom model'
mock_response.usage_metadata.prompt_token_count = 18
mock_response.usage_metadata.candidates_token_count = 14
mock_genai_client.models.generate_content.return_value = mock_response
mock_genai_class.return_value = mock_genai_client
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'gemini-2.0-flash-001', # Default model
'api_key': 'test-api-key',
'temperature': 0.1, # Default temperature
'max_output': 8192,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act - Override model at runtime
result = await processor.generate_content(
"System prompt",
"User prompt",
model="gemini-1.5-pro", # Override model
temperature=None # Use default temperature
)
# Assert
assert isinstance(result, LlmResult)
assert result.text == "Response with custom model"
# Verify Google AI Studio API was called with overridden model
call_args = mock_genai_client.models.generate_content.call_args
assert call_args[1]['model'] == 'gemini-1.5-pro' # Should use runtime override
# Verify the generation config was created for the correct model
cache_key = f"gemini-1.5-pro:0.1"
assert cache_key in processor.generation_configs
@patch('trustgraph.model.text_completion.googleaistudio.llm.genai.Client')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_both_parameters_override(self, mock_llm_init, mock_async_init, mock_genai_class):
"""Test overriding both model and temperature parameters simultaneously"""
# Arrange
mock_genai_client = MagicMock()
mock_response = MagicMock()
mock_response.text = 'Response with both overrides'
mock_response.usage_metadata.prompt_token_count = 22
mock_response.usage_metadata.candidates_token_count = 16
mock_genai_client.models.generate_content.return_value = mock_response
mock_genai_class.return_value = mock_genai_client
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'gemini-2.0-flash-001', # Default model
'api_key': 'test-api-key',
'temperature': 0.0, # Default temperature
'max_output': 8192,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act - Override both parameters at runtime
result = await processor.generate_content(
"System prompt",
"User prompt",
model="gemini-1.5-flash", # Override model
temperature=0.9 # Override temperature
)
# Assert
assert isinstance(result, LlmResult)
assert result.text == "Response with both overrides"
# Verify Google AI Studio API was called with both overrides
call_args = mock_genai_client.models.generate_content.call_args
assert call_args[1]['model'] == 'gemini-1.5-flash' # Should use runtime override
# Verify the generation config was created with both overrides
cache_key = f"gemini-1.5-flash:0.9"
assert cache_key in processor.generation_configs
config_obj = processor.generation_configs[cache_key]
assert config_obj.temperature == 0.9
if __name__ == '__main__':
pytest.main([__file__])
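
The Google AI Studio tests assert the same "model:temperature" memoisation, this time over generation configs. A hedged sketch using the google-genai SDK types (assumed, since the processor internals are not shown here):

from google.genai import types

generation_configs = {}

def get_generation_config(model, temperature, max_output):
    key = f"{model}:{temperature}"
    if key not in generation_configs:
        generation_configs[key] = types.GenerateContentConfig(
            temperature=temperature,
            max_output_tokens=max_output,
        )
    return generation_configs[key]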


@@ -458,5 +458,132 @@ class TestLlamaFileProcessorSimple(IsolatedAsyncioTestCase):
# No specific rate-limit error handling is tested, since a locally served SLM presumably imposes no rate limits
@patch('trustgraph.model.text_completion.llamafile.llm.OpenAI')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_with_model_override(self, mock_llm_init, mock_async_init, mock_openai_class):
"""Test generate_content with model parameter override"""
# Arrange
mock_openai_client = MagicMock()
mock_response = MagicMock()
mock_response.choices = [MagicMock()]
mock_response.choices[0].message.content = "Response from overridden model"
mock_response.usage.prompt_tokens = 15
mock_response.usage.completion_tokens = 10
mock_openai_client.chat.completions.create.return_value = mock_response
mock_openai_class.return_value = mock_openai_client
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'LLaMA_CPP',
'llamafile': 'http://localhost:8080/v1',
'temperature': 0.0,
'max_output': 4096,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act - Override model
result = await processor.generate_content("System", "Prompt", model="custom-llamafile-model")
# Assert
assert result.model == "custom-llamafile-model" # Should use overridden model
assert result.text == "Response from overridden model"
# Verify the API call was made with overridden model
call_args = mock_openai_client.chat.completions.create.call_args
assert call_args[1]['model'] == "custom-llamafile-model"
@patch('trustgraph.model.text_completion.llamafile.llm.OpenAI')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_with_temperature_override(self, mock_llm_init, mock_async_init, mock_openai_class):
"""Test generate_content with temperature parameter override"""
# Arrange
mock_openai_client = MagicMock()
mock_response = MagicMock()
mock_response.choices = [MagicMock()]
mock_response.choices[0].message.content = "Response with temperature override"
mock_response.usage.prompt_tokens = 18
mock_response.usage.completion_tokens = 12
mock_openai_client.chat.completions.create.return_value = mock_response
mock_openai_class.return_value = mock_openai_client
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'LLaMA_CPP',
'llamafile': 'http://localhost:8080/v1',
'temperature': 0.0, # Default temperature
'max_output': 4096,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act - Override temperature
result = await processor.generate_content("System", "Prompt", temperature=0.7)
# Assert
assert result.text == "Response with temperature override"
# Verify the API call was made with overridden temperature
call_args = mock_openai_client.chat.completions.create.call_args
assert call_args[1]['temperature'] == 0.7
@patch('trustgraph.model.text_completion.llamafile.llm.OpenAI')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_with_both_parameters_override(self, mock_llm_init, mock_async_init, mock_openai_class):
"""Test generate_content with both model and temperature overrides"""
# Arrange
mock_openai_client = MagicMock()
mock_response = MagicMock()
mock_response.choices = [MagicMock()]
mock_response.choices[0].message.content = "Response with both parameters override"
mock_response.usage.prompt_tokens = 20
mock_response.usage.completion_tokens = 15
mock_openai_client.chat.completions.create.return_value = mock_response
mock_openai_class.return_value = mock_openai_client
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'LLaMA_CPP',
'llamafile': 'http://localhost:8080/v1',
'temperature': 0.0,
'max_output': 4096,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act - Override both parameters
result = await processor.generate_content("System", "Prompt", model="override-model", temperature=0.8)
# Assert
assert result.model == "override-model"
assert result.text == "Response with both parameters override"
# Verify the API call was made with overridden parameters
call_args = mock_openai_client.chat.completions.create.call_args
assert call_args[1]['model'] == "override-model"
assert call_args[1]['temperature'] == 0.8
if __name__ == '__main__':
pytest.main([__file__])


@@ -0,0 +1,229 @@
"""
Unit tests for trustgraph.model.text_completion.lmstudio
Following the same successful pattern as previous tests
"""
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from unittest import IsolatedAsyncioTestCase
# Import the service under test
from trustgraph.model.text_completion.lmstudio.llm import Processor
from trustgraph.base import LlmResult
from trustgraph.exceptions import TooManyRequests
class TestLMStudioProcessorSimple(IsolatedAsyncioTestCase):
"""Test LMStudio processor functionality"""
@patch('trustgraph.model.text_completion.lmstudio.llm.OpenAI')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_processor_initialization_basic(self, mock_llm_init, mock_async_init, mock_openai_class):
"""Test basic processor initialization"""
# Arrange
mock_openai = MagicMock()
mock_openai_class.return_value = mock_openai
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'gemma3:9b',
'url': 'http://localhost:1234/',
'temperature': 0.0,
'max_output': 4096,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
# Act
processor = Processor(**config)
# Assert
assert processor.default_model == 'gemma3:9b'
assert processor.url == 'http://localhost:1234/v1/'
assert processor.temperature == 0.0
assert processor.max_output == 4096
assert hasattr(processor, 'openai')
mock_openai_class.assert_called_once_with(
base_url='http://localhost:1234/v1/',
api_key='sk-no-key-required'
)
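
The assertion above also pins down a URL normalisation step: a base URL such as http://localhost:1234/ gains a v1/ suffix before being handed to the OpenAI client. A sketch of that rule (the helper name is illustrative, not the actual lmstudio processor code):

def normalise_url(url: str) -> str:
    # Ensure exactly one trailing slash, then append the OpenAI-compatible path.
    if not url.endswith("/"):
        url += "/"
    return url + "v1/"

# normalise_url("http://localhost:1234/") -> "http://localhost:1234/v1/"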
@patch('trustgraph.model.text_completion.lmstudio.llm.OpenAI')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_success(self, mock_llm_init, mock_async_init, mock_openai_class):
"""Test successful content generation"""
# Arrange
mock_openai = MagicMock()
mock_response = MagicMock()
mock_response.choices[0].message.content = 'Generated response from LMStudio'
mock_response.usage.prompt_tokens = 20
mock_response.usage.completion_tokens = 12
mock_openai.chat.completions.create.return_value = mock_response
mock_openai_class.return_value = mock_openai
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'gemma3:9b',
'url': 'http://localhost:1234/',
'temperature': 0.0,
'max_output': 4096,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act
result = await processor.generate_content("System prompt", "User prompt")
# Assert
assert isinstance(result, LlmResult)
assert result.text == "Generated response from LMStudio"
assert result.in_token == 20
assert result.out_token == 12
assert result.model == 'gemma3:9b'
# Verify the API call was made correctly
mock_openai.chat.completions.create.assert_called_once()
call_args = mock_openai.chat.completions.create.call_args
# Check model and temperature
assert call_args[1]['model'] == 'gemma3:9b'
assert call_args[1]['temperature'] == 0.0
assert call_args[1]['max_tokens'] == 4096
@patch('trustgraph.model.text_completion.lmstudio.llm.OpenAI')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_with_model_override(self, mock_llm_init, mock_async_init, mock_openai_class):
"""Test generate_content with model parameter override"""
# Arrange
mock_openai = MagicMock()
mock_response = MagicMock()
mock_response.choices[0].message.content = 'Response from overridden model'
mock_response.usage.prompt_tokens = 15
mock_response.usage.completion_tokens = 10
mock_openai.chat.completions.create.return_value = mock_response
mock_openai_class.return_value = mock_openai
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'gemma3:9b',
'url': 'http://localhost:1234/',
'temperature': 0.0,
'max_output': 4096,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act - Override model
result = await processor.generate_content("System", "Prompt", model="custom-lmstudio-model")
# Assert
assert result.model == "custom-lmstudio-model" # Should use overridden model
assert result.text == "Response from overridden model"
# Verify the API call was made with overridden model
call_args = mock_openai.chat.completions.create.call_args
assert call_args[1]['model'] == "custom-lmstudio-model"
@patch('trustgraph.model.text_completion.lmstudio.llm.OpenAI')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_with_temperature_override(self, mock_llm_init, mock_async_init, mock_openai_class):
"""Test generate_content with temperature parameter override"""
# Arrange
mock_openai = MagicMock()
mock_response = MagicMock()
mock_response.choices[0].message.content = 'Response with temperature override'
mock_response.usage.prompt_tokens = 18
mock_response.usage.completion_tokens = 12
mock_openai.chat.completions.create.return_value = mock_response
mock_openai_class.return_value = mock_openai
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'gemma3:9b',
'url': 'http://localhost:1234/',
'temperature': 0.0, # Default temperature
'max_output': 4096,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act - Override temperature
result = await processor.generate_content("System", "Prompt", temperature=0.7)
# Assert
assert result.text == "Response with temperature override"
# Verify the API call was made with overridden temperature
call_args = mock_openai.chat.completions.create.call_args
assert call_args[1]['temperature'] == 0.7
@patch('trustgraph.model.text_completion.lmstudio.llm.OpenAI')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_with_both_parameters_override(self, mock_llm_init, mock_async_init, mock_openai_class):
"""Test generate_content with both model and temperature overrides"""
# Arrange
mock_openai = MagicMock()
mock_response = MagicMock()
mock_response.choices[0].message.content = 'Response with both parameters override'
mock_response.usage.prompt_tokens = 20
mock_response.usage.completion_tokens = 15
mock_openai.chat.completions.create.return_value = mock_response
mock_openai_class.return_value = mock_openai
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'gemma3:9b',
'url': 'http://localhost:1234/',
'temperature': 0.0,
'max_output': 4096,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act - Override both parameters
result = await processor.generate_content("System", "Prompt", model="override-model", temperature=0.8)
# Assert
assert result.model == "override-model"
assert result.text == "Response with both parameters override"
# Verify the API call was made with overridden parameters
call_args = mock_openai.chat.completions.create.call_args
assert call_args[1]['model'] == "override-model"
assert call_args[1]['temperature'] == 0.8
if __name__ == '__main__':
pytest.main([__file__])


@ -0,0 +1,275 @@
"""
Unit tests for trustgraph.model.text_completion.mistral
Following the same successful pattern as other processor tests
"""
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from unittest import IsolatedAsyncioTestCase
# Import the service under test
from trustgraph.model.text_completion.mistral.llm import Processor
from trustgraph.base import LlmResult
class TestMistralProcessorSimple(IsolatedAsyncioTestCase):
"""Test Mistral processor functionality"""
@patch('trustgraph.model.text_completion.mistral.llm.Mistral')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_processor_initialization_basic(self, mock_llm_init, mock_async_init, mock_mistral_class):
"""Test basic processor initialization"""
# Arrange
mock_mistral_client = MagicMock()
mock_mistral_class.return_value = mock_mistral_client
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'ministral-8b-latest',
'api_key': 'test-api-key',
'temperature': 0.1,
'max_output': 2048,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
# Act
processor = Processor(**config)
# Assert
assert processor.default_model == 'ministral-8b-latest'
assert processor.temperature == 0.1
assert processor.max_output == 2048
assert hasattr(processor, 'mistral')
mock_mistral_class.assert_called_once_with(api_key='test-api-key')
@patch('trustgraph.model.text_completion.mistral.llm.Mistral')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_success(self, mock_llm_init, mock_async_init, mock_mistral_class):
"""Test successful content generation"""
# Arrange
mock_mistral_client = MagicMock()
mock_response = MagicMock()
mock_response.choices[0].message.content = 'Generated response from Mistral'
mock_response.usage.prompt_tokens = 15
mock_response.usage.completion_tokens = 8
mock_mistral_client.chat.complete.return_value = mock_response
mock_mistral_class.return_value = mock_mistral_client
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'ministral-8b-latest',
'api_key': 'test-api-key',
'temperature': 0.0,
'max_output': 4096,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act
result = await processor.generate_content("System prompt", "User prompt")
# Assert
assert isinstance(result, LlmResult)
assert result.text == "Generated response from Mistral"
assert result.in_token == 15
assert result.out_token == 8
assert result.model == 'ministral-8b-latest'
mock_mistral_client.chat.complete.assert_called_once()
@patch('trustgraph.model.text_completion.mistral.llm.Mistral')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_temperature_override(self, mock_llm_init, mock_async_init, mock_mistral_class):
"""Test temperature parameter override functionality"""
# Arrange
mock_mistral_client = MagicMock()
mock_response = MagicMock()
mock_response.choices[0].message.content = 'Response with custom temperature'
mock_response.usage.prompt_tokens = 20
mock_response.usage.completion_tokens = 12
mock_mistral_client.chat.complete.return_value = mock_response
mock_mistral_class.return_value = mock_mistral_client
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'ministral-8b-latest',
'api_key': 'test-api-key',
'temperature': 0.0, # Default temperature
'max_output': 4096,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act - Override temperature at runtime
result = await processor.generate_content(
"System prompt",
"User prompt",
model=None, # Use default model
temperature=0.8 # Override temperature
)
# Assert
assert isinstance(result, LlmResult)
assert result.text == "Response with custom temperature"
# Verify Mistral API was called with overridden temperature
call_args = mock_mistral_client.chat.complete.call_args
assert call_args[1]['temperature'] == 0.8 # Should use runtime override
assert call_args[1]['model'] == 'ministral-8b-latest'
@patch('trustgraph.model.text_completion.mistral.llm.Mistral')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_model_override(self, mock_llm_init, mock_async_init, mock_mistral_class):
"""Test model parameter override functionality"""
# Arrange
mock_mistral_client = MagicMock()
mock_response = MagicMock()
mock_response.choices[0].message.content = 'Response with custom model'
mock_response.usage.prompt_tokens = 18
mock_response.usage.completion_tokens = 14
mock_mistral_client.chat.complete.return_value = mock_response
mock_mistral_class.return_value = mock_mistral_client
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'ministral-8b-latest', # Default model
'api_key': 'test-api-key',
'temperature': 0.1, # Default temperature
'max_output': 4096,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act - Override model at runtime
result = await processor.generate_content(
"System prompt",
"User prompt",
model="mistral-large-latest", # Override model
temperature=None # Use default temperature
)
# Assert
assert isinstance(result, LlmResult)
assert result.text == "Response with custom model"
# Verify Mistral API was called with overridden model
call_args = mock_mistral_client.chat.complete.call_args
assert call_args[1]['model'] == 'mistral-large-latest' # Should use runtime override
assert call_args[1]['temperature'] == 0.1 # Should use processor default
@patch('trustgraph.model.text_completion.mistral.llm.Mistral')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_both_parameters_override(self, mock_llm_init, mock_async_init, mock_mistral_class):
"""Test overriding both model and temperature parameters simultaneously"""
# Arrange
mock_mistral_client = MagicMock()
mock_response = MagicMock()
mock_response.choices[0].message.content = 'Response with both overrides'
mock_response.usage.prompt_tokens = 22
mock_response.usage.completion_tokens = 16
mock_mistral_client.chat.complete.return_value = mock_response
mock_mistral_class.return_value = mock_mistral_client
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'ministral-8b-latest', # Default model
'api_key': 'test-api-key',
'temperature': 0.0, # Default temperature
'max_output': 4096,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act - Override both parameters at runtime
result = await processor.generate_content(
"System prompt",
"User prompt",
model="mistral-large-latest", # Override model
temperature=0.9 # Override temperature
)
# Assert
assert isinstance(result, LlmResult)
assert result.text == "Response with both overrides"
# Verify Mistral API was called with both overrides
call_args = mock_mistral_client.chat.complete.call_args
assert call_args[1]['model'] == 'mistral-large-latest' # Should use runtime override
assert call_args[1]['temperature'] == 0.9 # Should use runtime override
@patch('trustgraph.model.text_completion.mistral.llm.Mistral')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_prompt_construction(self, mock_llm_init, mock_async_init, mock_mistral_class):
"""Test prompt construction with system and user prompts"""
# Arrange
mock_mistral_client = MagicMock()
mock_response = MagicMock()
mock_response.choices[0].message.content = 'Response with system instructions'
mock_response.usage.prompt_tokens = 25
mock_response.usage.completion_tokens = 15
mock_mistral_client.chat.complete.return_value = mock_response
mock_mistral_class.return_value = mock_mistral_client
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'ministral-8b-latest',
'api_key': 'test-api-key',
'temperature': 0.0,
'max_output': 4096,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act
result = await processor.generate_content("You are a helpful assistant", "What is AI?")
# Assert
assert result.text == "Response with system instructions"
assert result.in_token == 25
assert result.out_token == 15
# Verify the combined prompt structure
call_args = mock_mistral_client.chat.complete.call_args
messages = call_args[1]['messages']
assert len(messages) == 1
assert messages[0]['role'] == 'user'
assert messages[0]['content'][0]['type'] == 'text'
assert messages[0]['content'][0]['text'] == "You are a helpful assistant\n\nWhat is AI?"
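# The assertions above fix the message shape: both prompts are flattened into
# a single user message. A minimal sketch of that construction, assuming a
# hypothetical helper mirroring what the processor is expected to do:
#
#     def build_messages(system: str, prompt: str) -> list:
#         return [{
#             "role": "user",
#             "content": [{"type": "text", "text": f"{system}\n\n{prompt}"}],
#         }]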
if __name__ == '__main__':
pytest.main([__file__])


@ -312,6 +312,150 @@ class TestOllamaProcessorSimple(IsolatedAsyncioTestCase):
# Verify the combined prompt
mock_client.generate.assert_called_once_with('llama2', "You are a helpful assistant\n\nWhat is AI?", options={'temperature': 0.0})
@patch('trustgraph.model.text_completion.ollama.llm.Client')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_temperature_override(self, mock_llm_init, mock_async_init, mock_client_class):
"""Test temperature parameter override functionality"""
# Arrange
mock_client = MagicMock()
mock_response = {
'response': 'Response with custom temperature',
'prompt_eval_count': 20,
'eval_count': 12
}
mock_client.generate.return_value = mock_response
mock_client_class.return_value = mock_client
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'llama2',
'ollama': 'http://localhost:11434',
'temperature': 0.0, # Default temperature
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act - Override temperature at runtime
result = await processor.generate_content(
"System prompt",
"User prompt",
model=None, # Use default model
temperature=0.8 # Override temperature
)
# Assert
assert isinstance(result, LlmResult)
assert result.text == "Response with custom temperature"
# Verify Ollama API was called with overridden temperature
mock_client.generate.assert_called_once_with(
'llama2',
"System prompt\n\nUser prompt",
options={'temperature': 0.8} # Should use runtime override
)
@patch('trustgraph.model.text_completion.ollama.llm.Client')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_model_override(self, mock_llm_init, mock_async_init, mock_client_class):
"""Test model parameter override functionality"""
# Arrange
mock_client = MagicMock()
mock_response = {
'response': 'Response with custom model',
'prompt_eval_count': 18,
'eval_count': 14
}
mock_client.generate.return_value = mock_response
mock_client_class.return_value = mock_client
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'llama2', # Default model
'ollama': 'http://localhost:11434',
'temperature': 0.1, # Default temperature
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act - Override model at runtime
result = await processor.generate_content(
"System prompt",
"User prompt",
model="mistral", # Override model
temperature=None # Use default temperature
)
# Assert
assert isinstance(result, LlmResult)
assert result.text == "Response with custom model"
# Verify Ollama API was called with overridden model
mock_client.generate.assert_called_once_with(
'mistral', # Should use runtime override
"System prompt\n\nUser prompt",
options={'temperature': 0.1} # Should use processor default
)
@patch('trustgraph.model.text_completion.ollama.llm.Client')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_both_parameters_override(self, mock_llm_init, mock_async_init, mock_client_class):
"""Test overriding both model and temperature parameters simultaneously"""
# Arrange
mock_client = MagicMock()
mock_response = {
'response': 'Response with both overrides',
'prompt_eval_count': 22,
'eval_count': 16
}
mock_client.generate.return_value = mock_response
mock_client_class.return_value = mock_client
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'llama2', # Default model
'ollama': 'http://localhost:11434',
'temperature': 0.0, # Default temperature
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act - Override both parameters at runtime
result = await processor.generate_content(
"System prompt",
"User prompt",
model="codellama", # Override model
temperature=0.9 # Override temperature
)
# Assert
assert isinstance(result, LlmResult)
assert result.text == "Response with both overrides"
# Verify Ollama API was called with both overrides
mock_client.generate.assert_called_once_with(
'codellama', # Should use runtime override
"System prompt\n\nUser prompt",
options={'temperature': 0.9} # Should use runtime override
)
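# Taken together, the override tests imply a parameter resolution rule of
# this shape, feeding a positional model name, a combined prompt, and an
# options dict into generate(); a sketch under that assumption (illustrative,
# not the processor's actual code):
#
#     model = model if model is not None else self.default_model
#     temperature = temperature if temperature is not None else self.temperature
#     response = self.client.generate(
#         model,
#         f"{system}\n\n{prompt}",
#         options={"temperature": temperature},
#     )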
if __name__ == '__main__':
pytest.main([__file__])


@ -391,5 +391,210 @@ class TestOpenAIProcessorSimple(IsolatedAsyncioTestCase):
assert call_args[1]['response_format'] == {"type": "text"}
@patch('trustgraph.model.text_completion.openai.llm.OpenAI')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_temperature_override(self, mock_llm_init, mock_async_init, mock_openai_class):
"""Test temperature parameter override functionality"""
# Arrange
mock_openai_client = MagicMock()
mock_response = MagicMock()
mock_response.choices = [MagicMock()]
mock_response.choices[0].message.content = "Response with custom temperature"
mock_response.usage.prompt_tokens = 15
mock_response.usage.completion_tokens = 10
mock_openai_client.chat.completions.create.return_value = mock_response
mock_openai_class.return_value = mock_openai_client
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'gpt-3.5-turbo',
'api_key': 'test-api-key',
'url': 'https://api.openai.com/v1',
'temperature': 0.0, # Default temperature
'max_output': 4096,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act - Override temperature at runtime
result = await processor.generate_content(
"System prompt",
"User prompt",
model=None, # Use default model
temperature=0.9 # Override temperature
)
# Assert
assert isinstance(result, LlmResult)
assert result.text == "Response with custom temperature"
# Verify the OpenAI API was called with overridden temperature
mock_openai_client.chat.completions.create.assert_called_once()
call_kwargs = mock_openai_client.chat.completions.create.call_args.kwargs
assert call_kwargs['temperature'] == 0.9 # Should use runtime override
assert call_kwargs['model'] == 'gpt-3.5-turbo' # Should use processor default
@patch('trustgraph.model.text_completion.openai.llm.OpenAI')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_model_override(self, mock_llm_init, mock_async_init, mock_openai_class):
"""Test model parameter override functionality"""
# Arrange
mock_openai_client = MagicMock()
mock_response = MagicMock()
mock_response.choices = [MagicMock()]
mock_response.choices[0].message.content = "Response with custom model"
mock_response.usage.prompt_tokens = 15
mock_response.usage.completion_tokens = 10
mock_openai_client.chat.completions.create.return_value = mock_response
mock_openai_class.return_value = mock_openai_client
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'gpt-3.5-turbo', # Default model
'api_key': 'test-api-key',
'url': 'https://api.openai.com/v1',
'temperature': 0.2,
'max_output': 4096,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act - Override model at runtime
result = await processor.generate_content(
"System prompt",
"User prompt",
model="gpt-4", # Override model
temperature=None # Use default temperature
)
# Assert
assert isinstance(result, LlmResult)
assert result.text == "Response with custom model"
# Verify the OpenAI API was called with overridden model
mock_openai_client.chat.completions.create.assert_called_once()
call_kwargs = mock_openai_client.chat.completions.create.call_args.kwargs
assert call_kwargs['model'] == 'gpt-4' # Should use runtime override
assert call_kwargs['temperature'] == 0.2 # Should use processor default
@patch('trustgraph.model.text_completion.openai.llm.OpenAI')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_both_parameters_override(self, mock_llm_init, mock_async_init, mock_openai_class):
"""Test overriding both model and temperature parameters simultaneously"""
# Arrange
mock_openai_client = MagicMock()
mock_response = MagicMock()
mock_response.choices = [MagicMock()]
mock_response.choices[0].message.content = "Response with both overrides"
mock_response.usage.prompt_tokens = 15
mock_response.usage.completion_tokens = 10
mock_openai_client.chat.completions.create.return_value = mock_response
mock_openai_class.return_value = mock_openai_client
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'gpt-3.5-turbo', # Default model
'api_key': 'test-api-key',
'url': 'https://api.openai.com/v1',
'temperature': 0.0, # Default temperature
'max_output': 4096,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act - Override both parameters at runtime
result = await processor.generate_content(
"System prompt",
"User prompt",
model="gpt-4", # Override model
temperature=0.7 # Override temperature
)
# Assert
assert isinstance(result, LlmResult)
assert result.text == "Response with both overrides"
# Verify the OpenAI API was called with both overrides
mock_openai_client.chat.completions.create.assert_called_once()
call_kwargs = mock_openai_client.chat.completions.create.call_args.kwargs
assert call_kwargs['model'] == 'gpt-4' # Should use runtime override
assert call_kwargs['temperature'] == 0.7 # Should use runtime override
@patch('trustgraph.model.text_completion.openai.llm.OpenAI')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_no_override_uses_defaults(self, mock_llm_init, mock_async_init, mock_openai_class):
"""Test that when no parameters are overridden, processor defaults are used"""
# Arrange
mock_openai_client = MagicMock()
mock_response = MagicMock()
mock_response.choices = [MagicMock()]
mock_response.choices[0].message.content = "Response with defaults"
mock_response.usage.prompt_tokens = 15
mock_response.usage.completion_tokens = 10
mock_openai_client.chat.completions.create.return_value = mock_response
mock_openai_class.return_value = mock_openai_client
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'gpt-4', # Default model
'api_key': 'test-api-key',
'url': 'https://api.openai.com/v1',
'temperature': 0.5, # Default temperature
'max_output': 4096,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act - Don't override any parameters (pass None)
result = await processor.generate_content(
"System prompt",
"User prompt",
model=None, # Use default model
temperature=None # Use default temperature
)
# Assert
assert isinstance(result, LlmResult)
assert result.text == "Response with defaults"
# Verify the OpenAI API was called with processor defaults
mock_openai_client.chat.completions.create.assert_called_once()
call_kwargs = mock_openai_client.chat.completions.create.call_args.kwargs
assert call_kwargs['model'] == 'gpt-4' # Should use processor default
assert call_kwargs['temperature'] == 0.5 # Should use processor default
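# The default-fallback behaviour verified here suggests the processor
# resolves keyword arguments before calling the API, treating None as "use
# the processor default"; a sketch under that assumption (names illustrative):
#
#     call_kwargs = {
#         "model": model if model is not None else self.default_model,
#         "temperature": temperature if temperature is not None else self.temperature,
#         "max_tokens": self.max_output,
#     }
#     response = self.openai.chat.completions.create(messages=messages, **call_kwargs)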
if __name__ == '__main__':
pytest.main([__file__])


@ -0,0 +1,186 @@
"""
Unit tests for Parameter-Based Caching in LLM Processors
Testing processors that cache based on temperature parameters (Bedrock, GoogleAIStudio)
"""
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from unittest import IsolatedAsyncioTestCase
from trustgraph.model.text_completion.googleaistudio.llm import Processor as GoogleAIProcessor
from trustgraph.base import LlmResult
class TestParameterCaching(IsolatedAsyncioTestCase):
"""Test parameter-based caching functionality"""
@patch('trustgraph.model.text_completion.googleaistudio.llm.genai')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_googleai_temperature_cache_keys(self, mock_llm_init, mock_async_init, mock_genai):
"""Test that GoogleAI processor creates separate cache entries for different temperatures"""
# Arrange
mock_client = MagicMock()
mock_genai.Client.return_value = mock_client
mock_response = MagicMock()
mock_response.text = "Generated response"
mock_response.usage_metadata.prompt_token_count = 10
mock_response.usage_metadata.candidates_token_count = 5
mock_client.models.generate_content.return_value = mock_response
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'gemini-2.0-flash-001',
'api_key': 'test-api-key',
'temperature': 0.0, # Default temperature
'max_output': 1024,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = GoogleAIProcessor(**config)
# Act - Call with different temperatures
await processor.generate_content("System", "Prompt 1", model="gemini-2.0-flash-001", temperature=0.0)
await processor.generate_content("System", "Prompt 2", model="gemini-2.0-flash-001", temperature=0.5)
await processor.generate_content("System", "Prompt 3", model="gemini-2.0-flash-001", temperature=1.0)
# Assert - Should have 3 different cache entries
cache_keys = list(processor.generation_configs.keys())
assert len(cache_keys) == 3
assert "gemini-2.0-flash-001:0.0" in cache_keys
assert "gemini-2.0-flash-001:0.5" in cache_keys
assert "gemini-2.0-flash-001:1.0" in cache_keys
# Verify each cached config has the correct temperature
assert processor.generation_configs["gemini-2.0-flash-001:0.0"].temperature == 0.0
assert processor.generation_configs["gemini-2.0-flash-001:0.5"].temperature == 0.5
assert processor.generation_configs["gemini-2.0-flash-001:1.0"].temperature == 1.0
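# The asserted keys imply a "model:temperature" memoisation scheme; a minimal
# sketch of the lookup these tests exercise (make_config stands in for the
# real config factory, which these tests don't name):
#
#     key = f"{model}:{temperature}"
#     if key not in self.generation_configs:
#         self.generation_configs[key] = make_config(temperature=temperature)
#     generation_config = self.generation_configs[key]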
@patch('trustgraph.model.text_completion.googleaistudio.llm.genai')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_googleai_cache_reuse_same_parameters(self, mock_llm_init, mock_async_init, mock_genai):
"""Test that GoogleAI processor reuses cache for identical model+temperature combinations"""
# Arrange
mock_client = MagicMock()
mock_genai.Client.return_value = mock_client
mock_response = MagicMock()
mock_response.text = "Generated response"
mock_response.usage_metadata.prompt_token_count = 10
mock_response.usage_metadata.candidates_token_count = 5
mock_client.models.generate_content.return_value = mock_response
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'gemini-2.0-flash-001',
'api_key': 'test-api-key',
'temperature': 0.0,
'max_output': 1024,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = GoogleAIProcessor(**config)
# Act - Call multiple times with same parameters
await processor.generate_content("System", "Prompt 1", model="gemini-2.0-flash-001", temperature=0.7)
await processor.generate_content("System", "Prompt 2", model="gemini-2.0-flash-001", temperature=0.7)
await processor.generate_content("System", "Prompt 3", model="gemini-2.0-flash-001", temperature=0.7)
# Assert - Should have only 1 cache entry for the repeated parameters
cache_keys = list(processor.generation_configs.keys())
assert len(cache_keys) == 1
assert "gemini-2.0-flash-001:0.7" in cache_keys
# The same config object should be reused
config_obj = processor.generation_configs["gemini-2.0-flash-001:0.7"]
assert config_obj.temperature == 0.7
@patch('trustgraph.model.text_completion.googleaistudio.llm.genai')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_googleai_different_models_separate_caches(self, mock_llm_init, mock_async_init, mock_genai):
"""Test that different models create separate cache entries even with same temperature"""
# Arrange
mock_client = MagicMock()
mock_genai.Client.return_value = mock_client
mock_response = MagicMock()
mock_response.text = "Generated response"
mock_response.usage_metadata.prompt_token_count = 10
mock_response.usage_metadata.candidates_token_count = 5
mock_client.models.generate_content.return_value = mock_response
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'gemini-2.0-flash-001',
'api_key': 'test-api-key',
'temperature': 0.0,
'max_output': 1024,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = GoogleAIProcessor(**config)
# Act - Call with different models, same temperature
await processor.generate_content("System", "Prompt 1", model="gemini-2.0-flash-001", temperature=0.5)
await processor.generate_content("System", "Prompt 2", model="gemini-1.5-flash-001", temperature=0.5)
# Assert - Should have separate cache entries for different models
cache_keys = list(processor.generation_configs.keys())
assert len(cache_keys) == 2
assert "gemini-2.0-flash-001:0.5" in cache_keys
assert "gemini-1.5-flash-001:0.5" in cache_keys
# Note: Bedrock tests would be similar but testing the Bedrock processor's caching behavior
# The Bedrock processor caches model variants with temperature in the cache key
async def test_bedrock_temperature_cache_keys(self):
"""Test Bedrock processor temperature-aware caching"""
# This would test the Bedrock processor's _get_or_create_variant method
# with different temperature values to ensure proper cache key generation
# Implementation would follow similar pattern to GoogleAI tests above
# but using the Bedrock processor and testing model_variants cache
pytest.skip("Bedrock cache key generation test not yet implemented")
async def test_bedrock_cache_isolation_different_temperatures(self):
"""Test that Bedrock processor isolates cache entries by temperature"""
pytest.skip("Bedrock cache isolation test not yet implemented")
async def test_cache_memory_efficiency(self):
"""Test that caches don't grow unbounded with many different parameter combinations"""
# This could test cache size limits or cleanup behavior if implemented
pytest.skip("Cache size/cleanup behaviour not yet implemented")
class TestCachePerformance(IsolatedAsyncioTestCase):
"""Test caching performance characteristics"""
async def test_cache_hit_performance(self):
"""Test that cache hits are faster than cache misses"""
# This would measure timing differences between cache hits and misses
pytest.skip("Cache timing comparison not yet implemented")
async def test_concurrent_cache_access(self):
"""Test concurrent access to cached configurations"""
# This would test thread-safety of cache access
pytest.skip("Concurrent cache access test not yet implemented")
if __name__ == '__main__':
pytest.main([__file__])


@ -0,0 +1,271 @@
"""
Unit tests for trustgraph.model.text_completion.tgi
Following the same successful pattern as previous tests
"""
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from unittest import IsolatedAsyncioTestCase
# Import the service under test
from trustgraph.model.text_completion.tgi.llm import Processor
from trustgraph.base import LlmResult
class TestTGIProcessorSimple(IsolatedAsyncioTestCase):
"""Test TGI processor functionality"""
@patch('trustgraph.model.text_completion.tgi.llm.aiohttp.ClientSession')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_processor_initialization_basic(self, mock_llm_init, mock_async_init, mock_session_class):
"""Test basic processor initialization"""
# Arrange
mock_session = MagicMock()
mock_session_class.return_value = mock_session
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'tgi',
'url': 'http://tgi-service:8899/v1',
'temperature': 0.0,
'max_output': 2048,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
# Act
processor = Processor(**config)
# Assert
assert processor.default_model == 'tgi'
assert processor.base_url == 'http://tgi-service:8899/v1'
assert processor.temperature == 0.0
assert processor.max_output == 2048
assert hasattr(processor, 'session')
mock_session_class.assert_called_once()
@patch('trustgraph.model.text_completion.tgi.llm.aiohttp.ClientSession')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_success(self, mock_llm_init, mock_async_init, mock_session_class):
"""Test successful content generation"""
# Arrange
mock_session = MagicMock()
mock_response = MagicMock()
mock_response.status = 200
mock_response.json = AsyncMock(return_value={
'choices': [{
'message': {
'content': 'Generated response from TGI'
}
}],
'usage': {
'prompt_tokens': 20,
'completion_tokens': 12
}
})
# Mock the async context manager
mock_session.post.return_value.__aenter__.return_value = mock_response
mock_session.post.return_value.__aexit__.return_value = None
mock_session_class.return_value = mock_session
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'tgi',
'url': 'http://tgi-service:8899/v1',
'temperature': 0.0,
'max_output': 2048,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act
result = await processor.generate_content("System prompt", "User prompt")
# Assert
assert isinstance(result, LlmResult)
assert result.text == "Generated response from TGI"
assert result.in_token == 20
assert result.out_token == 12
assert result.model == 'tgi'
# Verify the API call was made correctly
mock_session.post.assert_called_once()
call_args = mock_session.post.call_args
# Check URL
assert call_args[0][0] == 'http://tgi-service:8899/v1/chat/completions'
# Check request structure
request_body = call_args[1]['json']
assert request_body['model'] == 'tgi'
assert request_body['temperature'] == 0.0
assert request_body['max_tokens'] == 2048
@patch('trustgraph.model.text_completion.tgi.llm.aiohttp.ClientSession')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_with_model_override(self, mock_llm_init, mock_async_init, mock_session_class):
"""Test generate_content with model parameter override"""
# Arrange
mock_session = MagicMock()
mock_response = MagicMock()
mock_response.status = 200
mock_response.json = AsyncMock(return_value={
'choices': [{
'message': {
'content': 'Response from overridden model'
}
}],
'usage': {
'prompt_tokens': 15,
'completion_tokens': 10
}
})
mock_session.post.return_value.__aenter__.return_value = mock_response
mock_session.post.return_value.__aexit__.return_value = None
mock_session_class.return_value = mock_session
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'tgi',
'url': 'http://tgi-service:8899/v1',
'temperature': 0.0,
'max_output': 2048,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act - Override model
result = await processor.generate_content("System", "Prompt", model="custom-tgi-model")
# Assert
assert result.model == "custom-tgi-model" # Should use overridden model
assert result.text == "Response from overridden model"
# Verify the API call was made with overridden model
call_args = mock_session.post.call_args
assert call_args[1]['json']['model'] == "custom-tgi-model"
@patch('trustgraph.model.text_completion.tgi.llm.aiohttp.ClientSession')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_with_temperature_override(self, mock_llm_init, mock_async_init, mock_session_class):
"""Test generate_content with temperature parameter override"""
# Arrange
mock_session = MagicMock()
mock_response = MagicMock()
mock_response.status = 200
mock_response.json = AsyncMock(return_value={
'choices': [{
'message': {
'content': 'Response with temperature override'
}
}],
'usage': {
'prompt_tokens': 18,
'completion_tokens': 12
}
})
mock_session.post.return_value.__aenter__.return_value = mock_response
mock_session.post.return_value.__aexit__.return_value = None
mock_session_class.return_value = mock_session
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'tgi',
'url': 'http://tgi-service:8899/v1',
'temperature': 0.0, # Default temperature
'max_output': 2048,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act - Override temperature
result = await processor.generate_content("System", "Prompt", temperature=0.7)
# Assert
assert result.text == "Response with temperature override"
# Verify the API call was made with overridden temperature
call_args = mock_session.post.call_args
assert call_args[1]['json']['temperature'] == 0.7
@patch('trustgraph.model.text_completion.tgi.llm.aiohttp.ClientSession')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_with_both_parameters_override(self, mock_llm_init, mock_async_init, mock_session_class):
"""Test generate_content with both model and temperature overrides"""
# Arrange
mock_session = MagicMock()
mock_response = MagicMock()
mock_response.status = 200
mock_response.json = AsyncMock(return_value={
'choices': [{
'message': {
'content': 'Response with both parameters override'
}
}],
'usage': {
'prompt_tokens': 20,
'completion_tokens': 15
}
})
mock_session.post.return_value.__aenter__.return_value = mock_response
mock_session.post.return_value.__aexit__.return_value = None
mock_session_class.return_value = mock_session
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'tgi',
'url': 'http://tgi-service:8899/v1',
'temperature': 0.0,
'max_output': 2048,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act - Override both parameters
result = await processor.generate_content("System", "Prompt", model="override-model", temperature=0.8)
# Assert
assert result.model == "override-model"
assert result.text == "Response with both parameters override"
# Verify the API call was made with overridden parameters
call_args = mock_session.post.call_args
assert call_args[1]['json']['model'] == "override-model"
assert call_args[1]['json']['temperature'] == 0.8
if __name__ == '__main__':
pytest.main([__file__])


@ -460,6 +460,180 @@ class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase):
assert processor.api_params["top_p"] == 1.0
assert processor.api_params["top_k"] == 32
@patch('trustgraph.model.text_completion.vertexai.llm.service_account')
@patch('trustgraph.model.text_completion.vertexai.llm.vertexai')
@patch('trustgraph.model.text_completion.vertexai.llm.GenerativeModel')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_temperature_override(self, mock_llm_init, mock_async_init, mock_generative_model, mock_vertexai, mock_service_account):
"""Test temperature parameter override functionality"""
# Arrange
mock_credentials = MagicMock()
mock_service_account.Credentials.from_service_account_file.return_value = mock_credentials
mock_model = MagicMock()
mock_response = MagicMock()
mock_response.text = "Response with custom temperature"
mock_response.usage_metadata.prompt_token_count = 20
mock_response.usage_metadata.candidates_token_count = 12
mock_model.generate_content.return_value = mock_response
mock_generative_model.return_value = mock_model
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'region': 'us-central1',
'model': 'gemini-2.0-flash-001',
'temperature': 0.0, # Default temperature
'max_output': 8192,
'private_key': 'private.json',
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act - Override temperature at runtime
result = await processor.generate_content(
"System prompt",
"User prompt",
model=None, # Use default model
temperature=0.8 # Override temperature
)
# Assert
assert isinstance(result, LlmResult)
assert result.text == "Response with custom temperature"
# Verify Gemini API was called with overridden temperature
mock_model.generate_content.assert_called_once()
call_args = mock_model.generate_content.call_args
# Check that a generation_config was created (we can't read the temperature back from the mock)
generation_config = call_args.kwargs['generation_config']
assert generation_config is not None # Should use overridden temperature configuration
@patch('trustgraph.model.text_completion.vertexai.llm.service_account')
@patch('trustgraph.model.text_completion.vertexai.llm.vertexai')
@patch('trustgraph.model.text_completion.vertexai.llm.GenerativeModel')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_model_override(self, mock_llm_init, mock_async_init, mock_generative_model, mock_vertexai, mock_service_account):
"""Test model parameter override functionality"""
# Arrange
mock_credentials = MagicMock()
mock_service_account.Credentials.from_service_account_file.return_value = mock_credentials
# Mock different models
mock_model_default = MagicMock()
mock_model_override = MagicMock()
mock_response = MagicMock()
mock_response.text = "Response with custom model"
mock_response.usage_metadata.prompt_token_count = 18
mock_response.usage_metadata.candidates_token_count = 14
mock_model_override.generate_content.return_value = mock_response
# GenerativeModel should return different models based on input
def model_factory(model_name):
if model_name == 'gemini-1.5-pro':
return mock_model_override
return mock_model_default
mock_generative_model.side_effect = model_factory
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'region': 'us-central1',
'model': 'gemini-2.0-flash-001', # Default model
'temperature': 0.2, # Default temperature
'max_output': 8192,
'private_key': 'private.json',
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act - Override model at runtime
result = await processor.generate_content(
"System prompt",
"User prompt",
model="gemini-1.5-pro", # Override model
temperature=None # Use default temperature
)
# Assert
assert isinstance(result, LlmResult)
assert result.text == "Response with custom model"
# Verify the overridden model was used
mock_model_override.generate_content.assert_called_once()
# Verify GenerativeModel was called with the override model
mock_generative_model.assert_called_with('gemini-1.5-pro')
@patch('trustgraph.model.text_completion.vertexai.llm.service_account')
@patch('trustgraph.model.text_completion.vertexai.llm.vertexai')
@patch('trustgraph.model.text_completion.vertexai.llm.GenerativeModel')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_both_parameters_override(self, mock_llm_init, mock_async_init, mock_generative_model, mock_vertexai, mock_service_account):
"""Test overriding both model and temperature parameters simultaneously"""
# Arrange
mock_credentials = MagicMock()
mock_service_account.Credentials.from_service_account_file.return_value = mock_credentials
mock_model = MagicMock()
mock_response = MagicMock()
mock_response.text = "Response with both overrides"
mock_response.usage_metadata.prompt_token_count = 22
mock_response.usage_metadata.candidates_token_count = 16
mock_model.generate_content.return_value = mock_response
mock_generative_model.return_value = mock_model
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'region': 'us-central1',
'model': 'gemini-2.0-flash-001', # Default model
'temperature': 0.0, # Default temperature
'max_output': 8192,
'private_key': 'private.json',
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act - Override both parameters at runtime
result = await processor.generate_content(
"System prompt",
"User prompt",
model="gemini-1.5-flash-001", # Override model
temperature=0.9 # Override temperature
)
# Assert
assert isinstance(result, LlmResult)
assert result.text == "Response with both overrides"
# Verify both overrides were used
mock_model.generate_content.assert_called_once()
call_args = mock_model.generate_content.call_args
# Verify model override
mock_generative_model.assert_called_with('gemini-1.5-flash-001') # Should use runtime override
# Verify temperature override (we can't read the temperature back from the mock)
generation_config = call_args.kwargs['generation_config']
assert generation_config is not None # Should use overridden temperature configuration
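# These tests can only check that a generation_config object was passed; a
# sketch of the assumed construction (GenerationConfig per the vertexai SDK;
# exact field handling assumed):
#
#     generation_config = GenerationConfig(
#         temperature=temperature if temperature is not None else self.temperature,
#         max_output_tokens=self.max_output,
#     )
#     model_obj = GenerativeModel(model_name)
#     response = model_obj.generate_content(prompt, generation_config=generation_config)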
if __name__ == '__main__':
pytest.main([__file__])


@ -485,5 +485,148 @@ class TestVLLMProcessorSimple(IsolatedAsyncioTestCase):
assert call_args[1]['json']['prompt'] == expected_prompt
@patch('trustgraph.model.text_completion.vllm.llm.aiohttp.ClientSession')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_with_model_override(self, mock_llm_init, mock_async_init, mock_session_class):
"""Test generate_content with model parameter override"""
# Arrange
mock_session = MagicMock()
mock_response = MagicMock()
mock_response.status = 200
mock_response.json = AsyncMock(return_value={
'choices': [{
'text': 'Response from overridden model'
}],
'usage': {
'prompt_tokens': 12,
'completion_tokens': 8
}
})
mock_session.post.return_value.__aenter__.return_value = mock_response
mock_session.post.return_value.__aexit__.return_value = None
mock_session_class.return_value = mock_session
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'TheBloke/Mistral-7B-v0.1-AWQ',
'url': 'http://vllm-service:8899/v1',
'temperature': 0.0,
'max_output': 2048,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act - Override model
result = await processor.generate_content("System", "Prompt", model="custom-vllm-model")
# Assert
assert result.model == "custom-vllm-model" # Should use overridden model
assert result.text == "Response from overridden model"
@patch('trustgraph.model.text_completion.vllm.llm.aiohttp.ClientSession')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_with_temperature_override(self, mock_llm_init, mock_async_init, mock_session_class):
"""Test generate_content with temperature parameter override"""
# Arrange
mock_session = MagicMock()
mock_response = MagicMock()
mock_response.status = 200
mock_response.json = AsyncMock(return_value={
'choices': [{
'text': 'Response with temperature override'
}],
'usage': {
'prompt_tokens': 15,
'completion_tokens': 10
}
})
mock_session.post.return_value.__aenter__.return_value = mock_response
mock_session.post.return_value.__aexit__.return_value = None
mock_session_class.return_value = mock_session
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'TheBloke/Mistral-7B-v0.1-AWQ',
'url': 'http://vllm-service:8899/v1',
'temperature': 0.0, # Default temperature
'max_output': 2048,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act - Override temperature
result = await processor.generate_content("System", "Prompt", temperature=0.7)
# Assert
assert result.text == "Response with temperature override"
# Verify the request was made with overridden temperature
call_args = mock_session.post.call_args
assert call_args[1]['json']['temperature'] == 0.7
@patch('trustgraph.model.text_completion.vllm.llm.aiohttp.ClientSession')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_with_both_parameters_override(self, mock_llm_init, mock_async_init, mock_session_class):
"""Test generate_content with both model and temperature overrides"""
# Arrange
mock_session = MagicMock()
mock_response = MagicMock()
mock_response.status = 200
mock_response.json = AsyncMock(return_value={
'choices': [{
'text': 'Response with both parameters override'
}],
'usage': {
'prompt_tokens': 18,
'completion_tokens': 12
}
})
mock_session.post.return_value.__aenter__.return_value = mock_response
mock_session.post.return_value.__aexit__.return_value = None
mock_session_class.return_value = mock_session
mock_async_init.return_value = None
mock_llm_init.return_value = None
config = {
'model': 'TheBloke/Mistral-7B-v0.1-AWQ',
'url': 'http://vllm-service:8899/v1',
'temperature': 0.0,
'max_output': 2048,
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-processor'
}
processor = Processor(**config)
# Act - Override both parameters
result = await processor.generate_content("System", "Prompt", model="override-model", temperature=0.8)
# Assert
assert result.model == "override-model"
assert result.text == "Response with both parameters override"
# Verify the request was made with overridden temperature
call_args = mock_session.post.call_args
assert call_args[1]['json']['temperature'] == 0.8
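# Unlike the TGI tests, the vLLM mocks use the plain completions shape
# ('text' rather than chat messages); a sketch of the assumed request and
# response handling (endpoint path and attribute names assumed):
#
#     async with self.session.post(
#         f"{self.base_url}/completions",
#         json={"prompt": prompt, "temperature": temperature, "max_tokens": self.max_output},
#     ) as resp:
#         body = await resp.json()
#         text = body["choices"][0]["text"]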
if __name__ == '__main__':
pytest.main([__file__])