Mirror of https://github.com/trustgraph-ai/trustgraph.git (synced 2026-04-25 08:26:21 +02:00)
Parent: b0a3716b0e
Commit: 43cfcb18a0
18 changed files with 3563 additions and 0 deletions
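All of the test files touched by this commit exercise the same behaviour: generate_content() accepts optional model and temperature arguments, and each processor falls back to its configured defaults when an argument is None. The following is a minimal illustrative sketch of that resolution logic, inferred only from the assertions in the diffs below; the class and field names here (ExampleProcessor, default_model) are hypothetical and the sketch is not the trustgraph implementation.

# Illustrative sketch only: default-vs-override resolution that the new tests assert.
from dataclasses import dataclass

@dataclass
class LlmResult:
    text: str
    in_token: int
    out_token: int
    model: str

class ExampleProcessor:
    def __init__(self, model, temperature):
        self.default_model = model      # processor-level default
        self.temperature = temperature  # processor-level default

    async def generate_content(self, system, prompt, model=None, temperature=None):
        # Runtime arguments win; None falls back to the processor defaults.
        used_model = model if model is not None else self.default_model
        used_temp = temperature if temperature is not None else self.temperature
        # ... call the provider SDK with used_model / used_temp here ...
        return LlmResult(text="stub", in_token=0, out_token=0, model=used_model)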
@@ -402,6 +402,156 @@ class TestAzureOpenAIProcessorSimple(IsolatedAsyncioTestCase):
        assert call_args[1]['max_tokens'] == 1024
        assert call_args[1]['top_p'] == 1

    @patch('trustgraph.model.text_completion.azure_openai.llm.AzureOpenAI')
    @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
    @patch('trustgraph.base.llm_service.LlmService.__init__')
    async def test_generate_content_temperature_override(self, mock_llm_init, mock_async_init, mock_azure_openai_class):
        """Test temperature parameter override functionality"""
        # Arrange
        mock_azure_client = MagicMock()
        mock_response = MagicMock()
        mock_response.choices = [MagicMock()]
        mock_response.choices[0].message.content = 'Response with custom temperature'
        mock_response.usage.prompt_tokens = 20
        mock_response.usage.completion_tokens = 12

        mock_azure_client.chat.completions.create.return_value = mock_response
        mock_azure_openai_class.return_value = mock_azure_client

        mock_async_init.return_value = None
        mock_llm_init.return_value = None

        config = {
            'model': 'gpt-4',
            'endpoint': 'https://test.openai.azure.com/',
            'token': 'test-token',
            'api_version': '2024-12-01-preview',
            'temperature': 0.0,  # Default temperature
            'max_output': 4192,
            'concurrency': 1,
            'taskgroup': AsyncMock(),
            'id': 'test-processor'
        }

        processor = Processor(**config)

        # Act - Override temperature at runtime
        result = await processor.generate_content(
            "System prompt",
            "User prompt",
            model=None,  # Use default model
            temperature=0.8  # Override temperature
        )

        # Assert
        assert isinstance(result, LlmResult)
        assert result.text == "Response with custom temperature"

        # Verify Azure OpenAI API was called with overridden temperature
        call_args = mock_azure_client.chat.completions.create.call_args
        assert call_args[1]['temperature'] == 0.8  # Should use runtime override
        assert call_args[1]['model'] == 'gpt-4'

    @patch('trustgraph.model.text_completion.azure_openai.llm.AzureOpenAI')
    @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
    @patch('trustgraph.base.llm_service.LlmService.__init__')
    async def test_generate_content_model_override(self, mock_llm_init, mock_async_init, mock_azure_openai_class):
        """Test model parameter override functionality"""
        # Arrange
        mock_azure_client = MagicMock()
        mock_response = MagicMock()
        mock_response.choices = [MagicMock()]
        mock_response.choices[0].message.content = 'Response with custom model'
        mock_response.usage.prompt_tokens = 18
        mock_response.usage.completion_tokens = 14

        mock_azure_client.chat.completions.create.return_value = mock_response
        mock_azure_openai_class.return_value = mock_azure_client

        mock_async_init.return_value = None
        mock_llm_init.return_value = None

        config = {
            'model': 'gpt-4',  # Default model
            'endpoint': 'https://test.openai.azure.com/',
            'token': 'test-token',
            'api_version': '2024-12-01-preview',
            'temperature': 0.1,  # Default temperature
            'max_output': 4192,
            'concurrency': 1,
            'taskgroup': AsyncMock(),
            'id': 'test-processor'
        }

        processor = Processor(**config)

        # Act - Override model at runtime
        result = await processor.generate_content(
            "System prompt",
            "User prompt",
            model="gpt-4o",  # Override model
            temperature=None  # Use default temperature
        )

        # Assert
        assert isinstance(result, LlmResult)
        assert result.text == "Response with custom model"

        # Verify Azure OpenAI API was called with overridden model
        call_args = mock_azure_client.chat.completions.create.call_args
        assert call_args[1]['model'] == 'gpt-4o'  # Should use runtime override
        assert call_args[1]['temperature'] == 0.1  # Should use processor default

    @patch('trustgraph.model.text_completion.azure_openai.llm.AzureOpenAI')
    @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
    @patch('trustgraph.base.llm_service.LlmService.__init__')
    async def test_generate_content_both_parameters_override(self, mock_llm_init, mock_async_init, mock_azure_openai_class):
        """Test overriding both model and temperature parameters simultaneously"""
        # Arrange
        mock_azure_client = MagicMock()
        mock_response = MagicMock()
        mock_response.choices = [MagicMock()]
        mock_response.choices[0].message.content = 'Response with both overrides'
        mock_response.usage.prompt_tokens = 22
        mock_response.usage.completion_tokens = 16

        mock_azure_client.chat.completions.create.return_value = mock_response
        mock_azure_openai_class.return_value = mock_azure_client

        mock_async_init.return_value = None
        mock_llm_init.return_value = None

        config = {
            'model': 'gpt-4',  # Default model
            'endpoint': 'https://test.openai.azure.com/',
            'token': 'test-token',
            'api_version': '2024-12-01-preview',
            'temperature': 0.0,  # Default temperature
            'max_output': 4192,
            'concurrency': 1,
            'taskgroup': AsyncMock(),
            'id': 'test-processor'
        }

        processor = Processor(**config)

        # Act - Override both parameters at runtime
        result = await processor.generate_content(
            "System prompt",
            "User prompt",
            model="gpt-4o-mini",  # Override model
            temperature=0.9  # Override temperature
        )

        # Assert
        assert isinstance(result, LlmResult)
        assert result.text == "Response with both overrides"

        # Verify Azure OpenAI API was called with both overrides
        call_args = mock_azure_client.chat.completions.create.call_args
        assert call_args[1]['model'] == 'gpt-4o-mini'  # Should use runtime override
        assert call_args[1]['temperature'] == 0.9  # Should use runtime override


if __name__ == '__main__':
    pytest.main([__file__])
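A note on the assertion style used throughout these tests: unittest.mock records each call as an (args, kwargs) pair, so call_args[1] is the keyword-argument dict of the most recent call (equivalently call_args.kwargs on Python 3.8+). A small self-contained example:

# Minimal demonstration of how call_args indexing works in unittest.mock.
from unittest.mock import MagicMock

m = MagicMock()
m(model="gpt-4", temperature=0.8)
assert m.call_args[1]["temperature"] == 0.8    # kwargs by index
assert m.call_args.kwargs["model"] == "gpt-4"  # same thing, by attribute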
@@ -459,5 +459,150 @@ class TestAzureProcessorSimple(IsolatedAsyncioTestCase):
        )


    @patch('trustgraph.model.text_completion.azure.llm.requests')
    @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
    @patch('trustgraph.base.llm_service.LlmService.__init__')
    async def test_generate_content_with_model_override(self, mock_llm_init, mock_async_init, mock_requests):
        """Test generate_content with model parameter override"""
        # Arrange
        mock_response = MagicMock()
        mock_response.status_code = 200
        mock_response.json.return_value = {
            'choices': [{
                'message': {
                    'content': 'Response with model override'
                }
            }],
            'usage': {
                'prompt_tokens': 15,
                'completion_tokens': 10
            }
        }
        mock_requests.post.return_value = mock_response

        mock_async_init.return_value = None
        mock_llm_init.return_value = None

        config = {
            'endpoint': 'https://test.inference.ai.azure.com/v1/chat/completions',
            'token': 'test-token',
            'temperature': 0.0,
            'max_output': 4192,
            'concurrency': 1,
            'taskgroup': AsyncMock(),
            'id': 'test-processor'
        }

        processor = Processor(**config)

        # Act - Override model
        result = await processor.generate_content("System", "Prompt", model="custom-azure-model")

        # Assert
        assert result.model == "custom-azure-model"  # Should use overridden model
        assert result.text == "Response with model override"

    @patch('trustgraph.model.text_completion.azure.llm.requests')
    @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
    @patch('trustgraph.base.llm_service.LlmService.__init__')
    async def test_generate_content_with_temperature_override(self, mock_llm_init, mock_async_init, mock_requests):
        """Test generate_content with temperature parameter override"""
        # Arrange
        mock_response = MagicMock()
        mock_response.status_code = 200
        mock_response.json.return_value = {
            'choices': [{
                'message': {
                    'content': 'Response with temperature override'
                }
            }],
            'usage': {
                'prompt_tokens': 15,
                'completion_tokens': 10
            }
        }
        mock_requests.post.return_value = mock_response

        mock_async_init.return_value = None
        mock_llm_init.return_value = None

        config = {
            'endpoint': 'https://test.inference.ai.azure.com/v1/chat/completions',
            'token': 'test-token',
            'temperature': 0.0,  # Default temperature
            'max_output': 4192,
            'concurrency': 1,
            'taskgroup': AsyncMock(),
            'id': 'test-processor'
        }

        processor = Processor(**config)

        # Act - Override temperature
        result = await processor.generate_content("System", "Prompt", temperature=0.8)

        # Assert
        assert result.text == "Response with temperature override"

        # Verify the request was made with the overridden temperature
        mock_requests.post.assert_called_once()
        call_args = mock_requests.post.call_args

        import json
        request_body = json.loads(call_args[1]['data'])
        assert request_body['temperature'] == 0.8

    @patch('trustgraph.model.text_completion.azure.llm.requests')
    @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
    @patch('trustgraph.base.llm_service.LlmService.__init__')
    async def test_generate_content_with_both_parameters_override(self, mock_llm_init, mock_async_init, mock_requests):
        """Test generate_content with both model and temperature overrides"""
        # Arrange
        mock_response = MagicMock()
        mock_response.status_code = 200
        mock_response.json.return_value = {
            'choices': [{
                'message': {
                    'content': 'Response with both parameters override'
                }
            }],
            'usage': {
                'prompt_tokens': 18,
                'completion_tokens': 12
            }
        }
        mock_requests.post.return_value = mock_response

        mock_async_init.return_value = None
        mock_llm_init.return_value = None

        config = {
            'endpoint': 'https://test.inference.ai.azure.com/v1/chat/completions',
            'token': 'test-token',
            'temperature': 0.0,
            'max_output': 4192,
            'concurrency': 1,
            'taskgroup': AsyncMock(),
            'id': 'test-processor'
        }

        processor = Processor(**config)

        # Act - Override both parameters
        result = await processor.generate_content("System", "Prompt", model="override-model", temperature=0.9)

        # Assert
        assert result.model == "override-model"
        assert result.text == "Response with both parameters override"

        # Verify the request was made with overridden temperature
        mock_requests.post.assert_called_once()
        call_args = mock_requests.post.call_args

        import json
        request_body = json.loads(call_args[1]['data'])
        assert request_body['temperature'] == 0.9


if __name__ == '__main__':
    pytest.main([__file__])
tests/unit/test_text_completion/test_bedrock_processor.py (new file, 280 lines added)
@@ -0,0 +1,280 @@
"""
|
||||
Unit tests for trustgraph.model.text_completion.bedrock
|
||||
Following the same successful pattern as other processor tests
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from unittest import IsolatedAsyncioTestCase
|
||||
import json
|
||||
|
||||
# Import the service under test
|
||||
from trustgraph.model.text_completion.bedrock.llm import Processor, Mistral, Anthropic
|
||||
from trustgraph.base import LlmResult
|
||||
|
||||
|
||||
class TestBedrockProcessorSimple(IsolatedAsyncioTestCase):
|
||||
"""Test Bedrock processor functionality"""
|
||||
|
||||
@patch('trustgraph.model.text_completion.bedrock.llm.boto3.Session')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_processor_initialization_basic(self, mock_llm_init, mock_async_init, mock_session_class):
|
||||
"""Test basic processor initialization"""
|
||||
# Arrange
|
||||
mock_session = MagicMock()
|
||||
mock_bedrock = MagicMock()
|
||||
mock_session.client.return_value = mock_bedrock
|
||||
mock_session_class.return_value = mock_session
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'mistral.mistral-large-2407-v1:0',
|
||||
'temperature': 0.1,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
# Act
|
||||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
assert processor.default_model == 'mistral.mistral-large-2407-v1:0'
|
||||
assert processor.temperature == 0.1
|
||||
assert hasattr(processor, 'bedrock')
|
||||
mock_session_class.assert_called_once()
|
||||
|
||||
@patch('trustgraph.model.text_completion.bedrock.llm.boto3.Session')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_success_mistral(self, mock_llm_init, mock_async_init, mock_session_class):
|
||||
"""Test successful content generation with Mistral model"""
|
||||
# Arrange
|
||||
mock_session = MagicMock()
|
||||
mock_bedrock = MagicMock()
|
||||
mock_session.client.return_value = mock_bedrock
|
||||
mock_session_class.return_value = mock_session
|
||||
|
||||
mock_response = {
|
||||
'body': MagicMock(),
|
||||
'ResponseMetadata': {
|
||||
'HTTPHeaders': {
|
||||
'x-amzn-bedrock-input-token-count': '15',
|
||||
'x-amzn-bedrock-output-token-count': '8'
|
||||
}
|
||||
}
|
||||
}
|
||||
mock_response['body'].read.return_value = json.dumps({
|
||||
'outputs': [{'text': 'Generated response from Bedrock'}]
|
||||
})
|
||||
mock_bedrock.invoke_model.return_value = mock_response
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'mistral.mistral-large-2407-v1:0',
|
||||
'temperature': 0.0,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act
|
||||
result = await processor.generate_content("System prompt", "User prompt")
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Generated response from Bedrock"
|
||||
assert result.in_token == 15
|
||||
assert result.out_token == 8
|
||||
assert result.model == 'mistral.mistral-large-2407-v1:0'
|
||||
mock_bedrock.invoke_model.assert_called_once()
|
||||
|
||||
@patch('trustgraph.model.text_completion.bedrock.llm.boto3.Session')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_temperature_override(self, mock_llm_init, mock_async_init, mock_session_class):
|
||||
"""Test temperature parameter override functionality"""
|
||||
# Arrange
|
||||
mock_session = MagicMock()
|
||||
mock_bedrock = MagicMock()
|
||||
mock_session.client.return_value = mock_bedrock
|
||||
mock_session_class.return_value = mock_session
|
||||
|
||||
mock_response = {
|
||||
'body': MagicMock(),
|
||||
'ResponseMetadata': {
|
||||
'HTTPHeaders': {
|
||||
'x-amzn-bedrock-input-token-count': '20',
|
||||
'x-amzn-bedrock-output-token-count': '12'
|
||||
}
|
||||
}
|
||||
}
|
||||
mock_response['body'].read.return_value = json.dumps({
|
||||
'outputs': [{'text': 'Response with custom temperature'}]
|
||||
})
|
||||
mock_bedrock.invoke_model.return_value = mock_response
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'mistral.mistral-large-2407-v1:0',
|
||||
'temperature': 0.0, # Default temperature
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act - Override temperature at runtime
|
||||
result = await processor.generate_content(
|
||||
"System prompt",
|
||||
"User prompt",
|
||||
model=None, # Use default model
|
||||
temperature=0.8 # Override temperature
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Response with custom temperature"
|
||||
|
||||
# Verify the model variant was created with overridden temperature
|
||||
# The cache key should include the temperature
|
||||
cache_key = f"mistral.mistral-large-2407-v1:0:0.8"
|
||||
assert cache_key in processor.model_variants
|
||||
variant = processor.model_variants[cache_key]
|
||||
assert variant.temperature == 0.8
|
||||
|
||||
@patch('trustgraph.model.text_completion.bedrock.llm.boto3.Session')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_model_override(self, mock_llm_init, mock_async_init, mock_session_class):
|
||||
"""Test model parameter override functionality"""
|
||||
# Arrange
|
||||
mock_session = MagicMock()
|
||||
mock_bedrock = MagicMock()
|
||||
mock_session.client.return_value = mock_bedrock
|
||||
mock_session_class.return_value = mock_session
|
||||
|
||||
mock_response = {
|
||||
'body': MagicMock(),
|
||||
'ResponseMetadata': {
|
||||
'HTTPHeaders': {
|
||||
'x-amzn-bedrock-input-token-count': '18',
|
||||
'x-amzn-bedrock-output-token-count': '14'
|
||||
}
|
||||
}
|
||||
}
|
||||
mock_response['body'].read.return_value = json.dumps({
|
||||
'content': [{'text': 'Response with custom model'}]
|
||||
})
|
||||
mock_bedrock.invoke_model.return_value = mock_response
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'mistral.mistral-large-2407-v1:0', # Default model
|
||||
'temperature': 0.1, # Default temperature
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act - Override model at runtime
|
||||
result = await processor.generate_content(
|
||||
"System prompt",
|
||||
"User prompt",
|
||||
model="anthropic.claude-3-sonnet-20240229-v1:0", # Override model
|
||||
temperature=None # Use default temperature
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Response with custom model"
|
||||
|
||||
# Verify Bedrock API was called with overridden model
|
||||
mock_bedrock.invoke_model.assert_called_once()
|
||||
call_args = mock_bedrock.invoke_model.call_args
|
||||
assert call_args[1]['modelId'] == "anthropic.claude-3-sonnet-20240229-v1:0"
|
||||
|
||||
# Verify the correct model variant (Anthropic) was used
|
||||
cache_key = f"anthropic.claude-3-sonnet-20240229-v1:0:0.1"
|
||||
assert cache_key in processor.model_variants
|
||||
variant = processor.model_variants[cache_key]
|
||||
assert isinstance(variant, Anthropic)
|
||||
|
||||
@patch('trustgraph.model.text_completion.bedrock.llm.boto3.Session')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_both_parameters_override(self, mock_llm_init, mock_async_init, mock_session_class):
|
||||
"""Test overriding both model and temperature parameters simultaneously"""
|
||||
# Arrange
|
||||
mock_session = MagicMock()
|
||||
mock_bedrock = MagicMock()
|
||||
mock_session.client.return_value = mock_bedrock
|
||||
mock_session_class.return_value = mock_session
|
||||
|
||||
mock_response = {
|
||||
'body': MagicMock(),
|
||||
'ResponseMetadata': {
|
||||
'HTTPHeaders': {
|
||||
'x-amzn-bedrock-input-token-count': '22',
|
||||
'x-amzn-bedrock-output-token-count': '16'
|
||||
}
|
||||
}
|
||||
}
|
||||
mock_response['body'].read.return_value = json.dumps({
|
||||
'generation': 'Response with both overrides'
|
||||
})
|
||||
mock_bedrock.invoke_model.return_value = mock_response
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'mistral.mistral-large-2407-v1:0', # Default model
|
||||
'temperature': 0.0, # Default temperature
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act - Override both parameters at runtime
|
||||
result = await processor.generate_content(
|
||||
"System prompt",
|
||||
"User prompt",
|
||||
model="meta.llama3-70b-instruct-v1:0", # Override model (Meta/Llama)
|
||||
temperature=0.9 # Override temperature
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Response with both overrides"
|
||||
|
||||
# Verify Bedrock API was called with both overrides
|
||||
mock_bedrock.invoke_model.assert_called_once()
|
||||
call_args = mock_bedrock.invoke_model.call_args
|
||||
assert call_args[1]['modelId'] == "meta.llama3-70b-instruct-v1:0"
|
||||
|
||||
# Verify the correct model variant (Meta) was used with correct temperature
|
||||
cache_key = f"meta.llama3-70b-instruct-v1:0:0.9"
|
||||
assert cache_key in processor.model_variants
|
||||
variant = processor.model_variants[cache_key]
|
||||
assert variant.temperature == 0.9
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pytest.main([__file__])
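The Bedrock tests above additionally assert a per-request variant cache: the processor appears to key a model_variants dict by a "{model_id}:{temperature}" string and to pick a family-specific variant class (Mistral, Anthropic, and so on) for each model. The Google AI Studio tests later in this commit assert the same key format for generation_configs. A rough sketch of that idea, assuming only the names visible in the assertions (model_variants, variant.temperature); this is not the trustgraph implementation.

# Hypothetical sketch of the variant cache the Bedrock tests assert.
class Variant:
    def __init__(self, model_id, temperature):
        self.model_id = model_id
        self.temperature = temperature

def get_variant(cache, model_id, temperature):
    key = f"{model_id}:{temperature}"
    if key not in cache:
        # A real implementation would choose a family-specific class here.
        cache[key] = Variant(model_id, temperature)
    return cache[key]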
@@ -435,6 +435,156 @@ class TestClaudeProcessorSimple(IsolatedAsyncioTestCase):
assert processor.claude == mock_claude_client
|
||||
assert processor.default_model == 'claude-3-opus-20240229'
|
||||
|
||||
@patch('trustgraph.model.text_completion.claude.llm.anthropic.Anthropic')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_temperature_override(self, mock_llm_init, mock_async_init, mock_anthropic_class):
|
||||
"""Test temperature parameter override functionality"""
|
||||
# Arrange
|
||||
mock_claude_client = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.content = [MagicMock()]
|
||||
mock_response.content[0].text = "Response with custom temperature"
|
||||
mock_response.usage.input_tokens = 20
|
||||
mock_response.usage.output_tokens = 12
|
||||
|
||||
mock_claude_client.messages.create.return_value = mock_response
|
||||
mock_anthropic_class.return_value = mock_claude_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'claude-3-5-sonnet-20240620',
|
||||
'api_key': 'test-api-key',
|
||||
'temperature': 0.0, # Default temperature
|
||||
'max_output': 8192,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act - Override temperature at runtime
|
||||
result = await processor.generate_content(
|
||||
"System prompt",
|
||||
"User prompt",
|
||||
model=None, # Use default model
|
||||
temperature=0.9 # Override temperature
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Response with custom temperature"
|
||||
|
||||
# Verify Claude API was called with overridden temperature
|
||||
mock_claude_client.messages.create.assert_called_once()
|
||||
call_kwargs = mock_claude_client.messages.create.call_args.kwargs
|
||||
|
||||
assert call_kwargs['temperature'] == 0.9 # Should use runtime override
|
||||
assert call_kwargs['model'] == 'claude-3-5-sonnet-20240620' # Should use processor default
|
||||
|
||||
@patch('trustgraph.model.text_completion.claude.llm.anthropic.Anthropic')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_model_override(self, mock_llm_init, mock_async_init, mock_anthropic_class):
|
||||
"""Test model parameter override functionality"""
|
||||
# Arrange
|
||||
mock_claude_client = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.content = [MagicMock()]
|
||||
mock_response.content[0].text = "Response with custom model"
|
||||
mock_response.usage.input_tokens = 18
|
||||
mock_response.usage.output_tokens = 14
|
||||
|
||||
mock_claude_client.messages.create.return_value = mock_response
|
||||
mock_anthropic_class.return_value = mock_claude_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'claude-3-5-sonnet-20240620', # Default model
|
||||
'api_key': 'test-api-key',
|
||||
'temperature': 0.2, # Default temperature
|
||||
'max_output': 8192,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act - Override model at runtime
|
||||
result = await processor.generate_content(
|
||||
"System prompt",
|
||||
"User prompt",
|
||||
model="claude-3-haiku-20240307", # Override model
|
||||
temperature=None # Use default temperature
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Response with custom model"
|
||||
|
||||
# Verify Claude API was called with overridden model
|
||||
mock_claude_client.messages.create.assert_called_once()
|
||||
call_kwargs = mock_claude_client.messages.create.call_args.kwargs
|
||||
|
||||
assert call_kwargs['model'] == 'claude-3-haiku-20240307' # Should use runtime override
|
||||
assert call_kwargs['temperature'] == 0.2 # Should use processor default
|
||||
|
||||
@patch('trustgraph.model.text_completion.claude.llm.anthropic.Anthropic')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_both_parameters_override(self, mock_llm_init, mock_async_init, mock_anthropic_class):
|
||||
"""Test overriding both model and temperature parameters simultaneously"""
|
||||
# Arrange
|
||||
mock_claude_client = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.content = [MagicMock()]
|
||||
mock_response.content[0].text = "Response with both overrides"
|
||||
mock_response.usage.input_tokens = 22
|
||||
mock_response.usage.output_tokens = 16
|
||||
|
||||
mock_claude_client.messages.create.return_value = mock_response
|
||||
mock_anthropic_class.return_value = mock_claude_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'claude-3-5-sonnet-20240620', # Default model
|
||||
'api_key': 'test-api-key',
|
||||
'temperature': 0.0, # Default temperature
|
||||
'max_output': 8192,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act - Override both parameters at runtime
|
||||
result = await processor.generate_content(
|
||||
"System prompt",
|
||||
"User prompt",
|
||||
model="claude-3-opus-20240229", # Override model
|
||||
temperature=0.8 # Override temperature
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Response with both overrides"
|
||||
|
||||
# Verify Claude API was called with both overrides
|
||||
mock_claude_client.messages.create.assert_called_once()
|
||||
call_kwargs = mock_claude_client.messages.create.call_args.kwargs
|
||||
|
||||
assert call_kwargs['model'] == 'claude-3-opus-20240229' # Should use runtime override
|
||||
assert call_kwargs['temperature'] == 0.8 # Should use runtime override
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pytest.main([__file__])
@@ -442,6 +442,162 @@ class TestCohereProcessorSimple(IsolatedAsyncioTestCase):
assert call_args[1]['prompt_truncation'] == 'auto'
|
||||
assert call_args[1]['connectors'] == []
|
||||
|
||||
@patch('trustgraph.model.text_completion.cohere.llm.cohere.Client')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_temperature_override(self, mock_llm_init, mock_async_init, mock_cohere_class):
|
||||
"""Test temperature parameter override functionality"""
|
||||
# Arrange
|
||||
mock_cohere_client = MagicMock()
|
||||
mock_output = MagicMock()
|
||||
mock_output.text = 'Response with custom temperature'
|
||||
mock_output.meta.billed_units.input_tokens = 20
|
||||
mock_output.meta.billed_units.output_tokens = 12
|
||||
|
||||
mock_cohere_client.chat.return_value = mock_output
|
||||
mock_cohere_class.return_value = mock_cohere_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'c4ai-aya-23-8b',
|
||||
'api_key': 'test-api-key',
|
||||
'temperature': 0.0, # Default temperature
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act - Override temperature at runtime
|
||||
result = await processor.generate_content(
|
||||
"System prompt",
|
||||
"User prompt",
|
||||
model=None, # Use default model
|
||||
temperature=0.8 # Override temperature
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Response with custom temperature"
|
||||
|
||||
# Verify Cohere API was called with overridden temperature
|
||||
mock_cohere_client.chat.assert_called_once_with(
|
||||
model='c4ai-aya-23-8b',
|
||||
message='User prompt',
|
||||
preamble='System prompt',
|
||||
temperature=0.8, # Should use runtime override
|
||||
chat_history=[],
|
||||
prompt_truncation='auto',
|
||||
connectors=[]
|
||||
)
|
||||
|
||||
@patch('trustgraph.model.text_completion.cohere.llm.cohere.Client')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_model_override(self, mock_llm_init, mock_async_init, mock_cohere_class):
|
||||
"""Test model parameter override functionality"""
|
||||
# Arrange
|
||||
mock_cohere_client = MagicMock()
|
||||
mock_output = MagicMock()
|
||||
mock_output.text = 'Response with custom model'
|
||||
mock_output.meta.billed_units.input_tokens = 18
|
||||
mock_output.meta.billed_units.output_tokens = 14
|
||||
|
||||
mock_cohere_client.chat.return_value = mock_output
|
||||
mock_cohere_class.return_value = mock_cohere_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'c4ai-aya-23-8b', # Default model
|
||||
'api_key': 'test-api-key',
|
||||
'temperature': 0.1, # Default temperature
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act - Override model at runtime
|
||||
result = await processor.generate_content(
|
||||
"System prompt",
|
||||
"User prompt",
|
||||
model="command-r-plus", # Override model
|
||||
temperature=None # Use default temperature
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Response with custom model"
|
||||
|
||||
# Verify Cohere API was called with overridden model
|
||||
mock_cohere_client.chat.assert_called_once_with(
|
||||
model='command-r-plus', # Should use runtime override
|
||||
message='User prompt',
|
||||
preamble='System prompt',
|
||||
temperature=0.1, # Should use processor default
|
||||
chat_history=[],
|
||||
prompt_truncation='auto',
|
||||
connectors=[]
|
||||
)
|
||||
|
||||
@patch('trustgraph.model.text_completion.cohere.llm.cohere.Client')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_both_parameters_override(self, mock_llm_init, mock_async_init, mock_cohere_class):
|
||||
"""Test overriding both model and temperature parameters simultaneously"""
|
||||
# Arrange
|
||||
mock_cohere_client = MagicMock()
|
||||
mock_output = MagicMock()
|
||||
mock_output.text = 'Response with both overrides'
|
||||
mock_output.meta.billed_units.input_tokens = 22
|
||||
mock_output.meta.billed_units.output_tokens = 16
|
||||
|
||||
mock_cohere_client.chat.return_value = mock_output
|
||||
mock_cohere_class.return_value = mock_cohere_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'c4ai-aya-23-8b', # Default model
|
||||
'api_key': 'test-api-key',
|
||||
'temperature': 0.0, # Default temperature
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act - Override both parameters at runtime
|
||||
result = await processor.generate_content(
|
||||
"System prompt",
|
||||
"User prompt",
|
||||
model="command-r", # Override model
|
||||
temperature=0.9 # Override temperature
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Response with both overrides"
|
||||
|
||||
# Verify Cohere API was called with both overrides
|
||||
mock_cohere_client.chat.assert_called_once_with(
|
||||
model='command-r', # Should use runtime override
|
||||
message='User prompt',
|
||||
preamble='System prompt',
|
||||
temperature=0.9, # Should use runtime override
|
||||
chat_history=[],
|
||||
prompt_truncation='auto',
|
||||
connectors=[]
|
||||
)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pytest.main([__file__])
@@ -477,6 +477,156 @@ class TestGoogleAIStudioProcessorSimple(IsolatedAsyncioTestCase):
# The system instruction should be in the config object
|
||||
assert call_args[1]['contents'] == "Explain quantum computing"
|
||||
|
||||
@patch('trustgraph.model.text_completion.googleaistudio.llm.genai.Client')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_temperature_override(self, mock_llm_init, mock_async_init, mock_genai_class):
|
||||
"""Test temperature parameter override functionality"""
|
||||
# Arrange
|
||||
mock_genai_client = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.text = 'Response with custom temperature'
|
||||
mock_response.usage_metadata.prompt_token_count = 20
|
||||
mock_response.usage_metadata.candidates_token_count = 12
|
||||
|
||||
mock_genai_client.models.generate_content.return_value = mock_response
|
||||
mock_genai_class.return_value = mock_genai_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'gemini-2.0-flash-001',
|
||||
'api_key': 'test-api-key',
|
||||
'temperature': 0.0, # Default temperature
|
||||
'max_output': 8192,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act - Override temperature at runtime
|
||||
result = await processor.generate_content(
|
||||
"System prompt",
|
||||
"User prompt",
|
||||
model=None, # Use default model
|
||||
temperature=0.8 # Override temperature
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Response with custom temperature"
|
||||
|
||||
# Verify the generation config was created with overridden temperature
|
||||
cache_key = f"gemini-2.0-flash-001:0.8"
|
||||
assert cache_key in processor.generation_configs
|
||||
config_obj = processor.generation_configs[cache_key]
|
||||
assert config_obj.temperature == 0.8
|
||||
|
||||
@patch('trustgraph.model.text_completion.googleaistudio.llm.genai.Client')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_model_override(self, mock_llm_init, mock_async_init, mock_genai_class):
|
||||
"""Test model parameter override functionality"""
|
||||
# Arrange
|
||||
mock_genai_client = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.text = 'Response with custom model'
|
||||
mock_response.usage_metadata.prompt_token_count = 18
|
||||
mock_response.usage_metadata.candidates_token_count = 14
|
||||
|
||||
mock_genai_client.models.generate_content.return_value = mock_response
|
||||
mock_genai_class.return_value = mock_genai_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'gemini-2.0-flash-001', # Default model
|
||||
'api_key': 'test-api-key',
|
||||
'temperature': 0.1, # Default temperature
|
||||
'max_output': 8192,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act - Override model at runtime
|
||||
result = await processor.generate_content(
|
||||
"System prompt",
|
||||
"User prompt",
|
||||
model="gemini-1.5-pro", # Override model
|
||||
temperature=None # Use default temperature
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Response with custom model"
|
||||
|
||||
# Verify Google AI Studio API was called with overridden model
|
||||
call_args = mock_genai_client.models.generate_content.call_args
|
||||
assert call_args[1]['model'] == 'gemini-1.5-pro' # Should use runtime override
|
||||
|
||||
# Verify the generation config was created for the correct model
|
||||
cache_key = f"gemini-1.5-pro:0.1"
|
||||
assert cache_key in processor.generation_configs
|
||||
|
||||
@patch('trustgraph.model.text_completion.googleaistudio.llm.genai.Client')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_both_parameters_override(self, mock_llm_init, mock_async_init, mock_genai_class):
|
||||
"""Test overriding both model and temperature parameters simultaneously"""
|
||||
# Arrange
|
||||
mock_genai_client = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.text = 'Response with both overrides'
|
||||
mock_response.usage_metadata.prompt_token_count = 22
|
||||
mock_response.usage_metadata.candidates_token_count = 16
|
||||
|
||||
mock_genai_client.models.generate_content.return_value = mock_response
|
||||
mock_genai_class.return_value = mock_genai_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'gemini-2.0-flash-001', # Default model
|
||||
'api_key': 'test-api-key',
|
||||
'temperature': 0.0, # Default temperature
|
||||
'max_output': 8192,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act - Override both parameters at runtime
|
||||
result = await processor.generate_content(
|
||||
"System prompt",
|
||||
"User prompt",
|
||||
model="gemini-1.5-flash", # Override model
|
||||
temperature=0.9 # Override temperature
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Response with both overrides"
|
||||
|
||||
# Verify Google AI Studio API was called with both overrides
|
||||
call_args = mock_genai_client.models.generate_content.call_args
|
||||
assert call_args[1]['model'] == 'gemini-1.5-flash' # Should use runtime override
|
||||
|
||||
# Verify the generation config was created with both overrides
|
||||
cache_key = f"gemini-1.5-flash:0.9"
|
||||
assert cache_key in processor.generation_configs
|
||||
config_obj = processor.generation_configs[cache_key]
|
||||
assert config_obj.temperature == 0.9
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pytest.main([__file__])
@@ -458,5 +458,132 @@ class TestLlamaFileProcessorSimple(IsolatedAsyncioTestCase):
# No specific rate limit error handling tested since SLM presumably has no rate limits
|
||||
|
||||
|
||||
@patch('trustgraph.model.text_completion.llamafile.llm.OpenAI')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_with_model_override(self, mock_llm_init, mock_async_init, mock_openai_class):
|
||||
"""Test generate_content with model parameter override"""
|
||||
# Arrange
|
||||
mock_openai_client = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices = [MagicMock()]
|
||||
mock_response.choices[0].message.content = "Response from overridden model"
|
||||
mock_response.usage.prompt_tokens = 15
|
||||
mock_response.usage.completion_tokens = 10
|
||||
|
||||
mock_openai_client.chat.completions.create.return_value = mock_response
|
||||
mock_openai_class.return_value = mock_openai_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'LLaMA_CPP',
|
||||
'llamafile': 'http://localhost:8080/v1',
|
||||
'temperature': 0.0,
|
||||
'max_output': 4096,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act - Override model
|
||||
result = await processor.generate_content("System", "Prompt", model="custom-llamafile-model")
|
||||
|
||||
# Assert
|
||||
assert result.model == "custom-llamafile-model" # Should use overridden model
|
||||
assert result.text == "Response from overridden model"
|
||||
|
||||
# Verify the API call was made with overridden model
|
||||
call_args = mock_openai_client.chat.completions.create.call_args
|
||||
assert call_args[1]['model'] == "custom-llamafile-model"
|
||||
|
||||
@patch('trustgraph.model.text_completion.llamafile.llm.OpenAI')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_with_temperature_override(self, mock_llm_init, mock_async_init, mock_openai_class):
|
||||
"""Test generate_content with temperature parameter override"""
|
||||
# Arrange
|
||||
mock_openai_client = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices = [MagicMock()]
|
||||
mock_response.choices[0].message.content = "Response with temperature override"
|
||||
mock_response.usage.prompt_tokens = 18
|
||||
mock_response.usage.completion_tokens = 12
|
||||
|
||||
mock_openai_client.chat.completions.create.return_value = mock_response
|
||||
mock_openai_class.return_value = mock_openai_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'LLaMA_CPP',
|
||||
'llamafile': 'http://localhost:8080/v1',
|
||||
'temperature': 0.0, # Default temperature
|
||||
'max_output': 4096,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act - Override temperature
|
||||
result = await processor.generate_content("System", "Prompt", temperature=0.7)
|
||||
|
||||
# Assert
|
||||
assert result.text == "Response with temperature override"
|
||||
|
||||
# Verify the API call was made with overridden temperature
|
||||
call_args = mock_openai_client.chat.completions.create.call_args
|
||||
assert call_args[1]['temperature'] == 0.7
|
||||
|
||||
@patch('trustgraph.model.text_completion.llamafile.llm.OpenAI')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_with_both_parameters_override(self, mock_llm_init, mock_async_init, mock_openai_class):
|
||||
"""Test generate_content with both model and temperature overrides"""
|
||||
# Arrange
|
||||
mock_openai_client = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices = [MagicMock()]
|
||||
mock_response.choices[0].message.content = "Response with both parameters override"
|
||||
mock_response.usage.prompt_tokens = 20
|
||||
mock_response.usage.completion_tokens = 15
|
||||
|
||||
mock_openai_client.chat.completions.create.return_value = mock_response
|
||||
mock_openai_class.return_value = mock_openai_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'LLaMA_CPP',
|
||||
'llamafile': 'http://localhost:8080/v1',
|
||||
'temperature': 0.0,
|
||||
'max_output': 4096,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act - Override both parameters
|
||||
result = await processor.generate_content("System", "Prompt", model="override-model", temperature=0.8)
|
||||
|
||||
# Assert
|
||||
assert result.model == "override-model"
|
||||
assert result.text == "Response with both parameters override"
|
||||
|
||||
# Verify the API call was made with overridden parameters
|
||||
call_args = mock_openai_client.chat.completions.create.call_args
|
||||
assert call_args[1]['model'] == "override-model"
|
||||
assert call_args[1]['temperature'] == 0.8
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pytest.main([__file__])
tests/unit/test_text_completion/test_lmstudio_processor.py (new file, 229 lines added)
@@ -0,0 +1,229 @@
"""
|
||||
Unit tests for trustgraph.model.text_completion.lmstudio
|
||||
Following the same successful pattern as previous tests
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from unittest import IsolatedAsyncioTestCase
|
||||
|
||||
# Import the service under test
|
||||
from trustgraph.model.text_completion.lmstudio.llm import Processor
|
||||
from trustgraph.base import LlmResult
|
||||
from trustgraph.exceptions import TooManyRequests
|
||||
|
||||
|
||||
class TestLMStudioProcessorSimple(IsolatedAsyncioTestCase):
|
||||
"""Test LMStudio processor functionality"""
|
||||
|
||||
@patch('trustgraph.model.text_completion.lmstudio.llm.OpenAI')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_processor_initialization_basic(self, mock_llm_init, mock_async_init, mock_openai_class):
|
||||
"""Test basic processor initialization"""
|
||||
# Arrange
|
||||
mock_openai = MagicMock()
|
||||
mock_openai_class.return_value = mock_openai
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'gemma3:9b',
|
||||
'url': 'http://localhost:1234/',
|
||||
'temperature': 0.0,
|
||||
'max_output': 4096,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
# Act
|
||||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
assert processor.default_model == 'gemma3:9b'
|
||||
assert processor.url == 'http://localhost:1234/v1/'
|
||||
assert processor.temperature == 0.0
|
||||
assert processor.max_output == 4096
|
||||
assert hasattr(processor, 'openai')
|
||||
mock_openai_class.assert_called_once_with(
|
||||
base_url='http://localhost:1234/v1/',
|
||||
api_key='sk-no-key-required'
|
||||
)
|
||||
|
||||
@patch('trustgraph.model.text_completion.lmstudio.llm.OpenAI')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_success(self, mock_llm_init, mock_async_init, mock_openai_class):
|
||||
"""Test successful content generation"""
|
||||
# Arrange
|
||||
mock_openai = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices[0].message.content = 'Generated response from LMStudio'
|
||||
mock_response.usage.prompt_tokens = 20
|
||||
mock_response.usage.completion_tokens = 12
|
||||
|
||||
mock_openai.chat.completions.create.return_value = mock_response
|
||||
mock_openai_class.return_value = mock_openai
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'gemma3:9b',
|
||||
'url': 'http://localhost:1234/',
|
||||
'temperature': 0.0,
|
||||
'max_output': 4096,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act
|
||||
result = await processor.generate_content("System prompt", "User prompt")
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Generated response from LMStudio"
|
||||
assert result.in_token == 20
|
||||
assert result.out_token == 12
|
||||
assert result.model == 'gemma3:9b'
|
||||
|
||||
# Verify the API call was made correctly
|
||||
mock_openai.chat.completions.create.assert_called_once()
|
||||
call_args = mock_openai.chat.completions.create.call_args
|
||||
|
||||
# Check model and temperature
|
||||
assert call_args[1]['model'] == 'gemma3:9b'
|
||||
assert call_args[1]['temperature'] == 0.0
|
||||
assert call_args[1]['max_tokens'] == 4096
|
||||
|
||||
@patch('trustgraph.model.text_completion.lmstudio.llm.OpenAI')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_with_model_override(self, mock_llm_init, mock_async_init, mock_openai_class):
|
||||
"""Test generate_content with model parameter override"""
|
||||
# Arrange
|
||||
mock_openai = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices[0].message.content = 'Response from overridden model'
|
||||
mock_response.usage.prompt_tokens = 15
|
||||
mock_response.usage.completion_tokens = 10
|
||||
|
||||
mock_openai.chat.completions.create.return_value = mock_response
|
||||
mock_openai_class.return_value = mock_openai
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'gemma3:9b',
|
||||
'url': 'http://localhost:1234/',
|
||||
'temperature': 0.0,
|
||||
'max_output': 4096,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act - Override model
|
||||
result = await processor.generate_content("System", "Prompt", model="custom-lmstudio-model")
|
||||
|
||||
# Assert
|
||||
assert result.model == "custom-lmstudio-model" # Should use overridden model
|
||||
assert result.text == "Response from overridden model"
|
||||
|
||||
# Verify the API call was made with overridden model
|
||||
call_args = mock_openai.chat.completions.create.call_args
|
||||
assert call_args[1]['model'] == "custom-lmstudio-model"
|
||||
|
||||
@patch('trustgraph.model.text_completion.lmstudio.llm.OpenAI')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_with_temperature_override(self, mock_llm_init, mock_async_init, mock_openai_class):
|
||||
"""Test generate_content with temperature parameter override"""
|
||||
# Arrange
|
||||
mock_openai = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices[0].message.content = 'Response with temperature override'
|
||||
mock_response.usage.prompt_tokens = 18
|
||||
mock_response.usage.completion_tokens = 12
|
||||
|
||||
mock_openai.chat.completions.create.return_value = mock_response
|
||||
mock_openai_class.return_value = mock_openai
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'gemma3:9b',
|
||||
'url': 'http://localhost:1234/',
|
||||
'temperature': 0.0, # Default temperature
|
||||
'max_output': 4096,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act - Override temperature
|
||||
result = await processor.generate_content("System", "Prompt", temperature=0.7)
|
||||
|
||||
# Assert
|
||||
assert result.text == "Response with temperature override"
|
||||
|
||||
# Verify the API call was made with overridden temperature
|
||||
call_args = mock_openai.chat.completions.create.call_args
|
||||
assert call_args[1]['temperature'] == 0.7
|
||||
|
||||
@patch('trustgraph.model.text_completion.lmstudio.llm.OpenAI')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_with_both_parameters_override(self, mock_llm_init, mock_async_init, mock_openai_class):
|
||||
"""Test generate_content with both model and temperature overrides"""
|
||||
# Arrange
|
||||
mock_openai = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices[0].message.content = 'Response with both parameters override'
|
||||
mock_response.usage.prompt_tokens = 20
|
||||
mock_response.usage.completion_tokens = 15
|
||||
|
||||
mock_openai.chat.completions.create.return_value = mock_response
|
||||
mock_openai_class.return_value = mock_openai
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'gemma3:9b',
|
||||
'url': 'http://localhost:1234/',
|
||||
'temperature': 0.0,
|
||||
'max_output': 4096,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act - Override both parameters
|
||||
result = await processor.generate_content("System", "Prompt", model="override-model", temperature=0.8)
|
||||
|
||||
# Assert
|
||||
assert result.model == "override-model"
|
||||
assert result.text == "Response with both parameters override"
|
||||
|
||||
# Verify the API call was made with overridden parameters
|
||||
call_args = mock_openai.chat.completions.create.call_args
|
||||
assert call_args[1]['model'] == "override-model"
|
||||
assert call_args[1]['temperature'] == 0.8
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pytest.main([__file__])
tests/unit/test_text_completion/test_mistral_processor.py (new file, 275 lines added)
@@ -0,0 +1,275 @@
"""
|
||||
Unit tests for trustgraph.model.text_completion.mistral
|
||||
Following the same successful pattern as other processor tests
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from unittest import IsolatedAsyncioTestCase
|
||||
|
||||
# Import the service under test
|
||||
from trustgraph.model.text_completion.mistral.llm import Processor
|
||||
from trustgraph.base import LlmResult
|
||||
|
||||
|
||||
class TestMistralProcessorSimple(IsolatedAsyncioTestCase):
|
||||
"""Test Mistral processor functionality"""
|
||||
|
||||
@patch('trustgraph.model.text_completion.mistral.llm.Mistral')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_processor_initialization_basic(self, mock_llm_init, mock_async_init, mock_mistral_class):
|
||||
"""Test basic processor initialization"""
|
||||
# Arrange
|
||||
mock_mistral_client = MagicMock()
|
||||
mock_mistral_class.return_value = mock_mistral_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'ministral-8b-latest',
|
||||
'api_key': 'test-api-key',
|
||||
'temperature': 0.1,
|
||||
'max_output': 2048,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
# Act
|
||||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
assert processor.default_model == 'ministral-8b-latest'
|
||||
assert processor.temperature == 0.1
|
||||
assert processor.max_output == 2048
|
||||
assert hasattr(processor, 'mistral')
|
||||
mock_mistral_class.assert_called_once_with(api_key='test-api-key')
|
||||
|
||||
@patch('trustgraph.model.text_completion.mistral.llm.Mistral')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_success(self, mock_llm_init, mock_async_init, mock_mistral_class):
|
||||
"""Test successful content generation"""
|
||||
# Arrange
|
||||
mock_mistral_client = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices[0].message.content = 'Generated response from Mistral'
|
||||
mock_response.usage.prompt_tokens = 15
|
||||
mock_response.usage.completion_tokens = 8
|
||||
mock_mistral_client.chat.complete.return_value = mock_response
|
||||
mock_mistral_class.return_value = mock_mistral_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'ministral-8b-latest',
|
||||
'api_key': 'test-api-key',
|
||||
'temperature': 0.0,
|
||||
'max_output': 4096,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act
|
||||
result = await processor.generate_content("System prompt", "User prompt")
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Generated response from Mistral"
|
||||
assert result.in_token == 15
|
||||
assert result.out_token == 8
|
||||
assert result.model == 'ministral-8b-latest'
|
||||
mock_mistral_client.chat.complete.assert_called_once()
|
||||
|
||||
@patch('trustgraph.model.text_completion.mistral.llm.Mistral')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_temperature_override(self, mock_llm_init, mock_async_init, mock_mistral_class):
|
||||
"""Test temperature parameter override functionality"""
|
||||
# Arrange
|
||||
mock_mistral_client = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices[0].message.content = 'Response with custom temperature'
|
||||
mock_response.usage.prompt_tokens = 20
|
||||
mock_response.usage.completion_tokens = 12
|
||||
mock_mistral_client.chat.complete.return_value = mock_response
|
||||
mock_mistral_class.return_value = mock_mistral_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'ministral-8b-latest',
|
||||
'api_key': 'test-api-key',
|
||||
'temperature': 0.0, # Default temperature
|
||||
'max_output': 4096,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act - Override temperature at runtime
|
||||
result = await processor.generate_content(
|
||||
"System prompt",
|
||||
"User prompt",
|
||||
model=None, # Use default model
|
||||
temperature=0.8 # Override temperature
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Response with custom temperature"
|
||||
|
||||
# Verify Mistral API was called with overridden temperature
|
||||
call_args = mock_mistral_client.chat.complete.call_args
|
||||
assert call_args[1]['temperature'] == 0.8 # Should use runtime override
|
||||
assert call_args[1]['model'] == 'ministral-8b-latest'
|
||||
|
||||
@patch('trustgraph.model.text_completion.mistral.llm.Mistral')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_model_override(self, mock_llm_init, mock_async_init, mock_mistral_class):
|
||||
"""Test model parameter override functionality"""
|
||||
# Arrange
|
||||
mock_mistral_client = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices[0].message.content = 'Response with custom model'
|
||||
mock_response.usage.prompt_tokens = 18
|
||||
mock_response.usage.completion_tokens = 14
|
||||
mock_mistral_client.chat.complete.return_value = mock_response
|
||||
mock_mistral_class.return_value = mock_mistral_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'ministral-8b-latest', # Default model
|
||||
'api_key': 'test-api-key',
|
||||
'temperature': 0.1, # Default temperature
|
||||
'max_output': 4096,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act - Override model at runtime
|
||||
result = await processor.generate_content(
|
||||
"System prompt",
|
||||
"User prompt",
|
||||
model="mistral-large-latest", # Override model
|
||||
temperature=None # Use default temperature
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Response with custom model"
|
||||
|
||||
# Verify Mistral API was called with overridden model
|
||||
call_args = mock_mistral_client.chat.complete.call_args
|
||||
assert call_args[1]['model'] == 'mistral-large-latest' # Should use runtime override
|
||||
assert call_args[1]['temperature'] == 0.1 # Should use processor default
|
||||
|
||||
@patch('trustgraph.model.text_completion.mistral.llm.Mistral')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_both_parameters_override(self, mock_llm_init, mock_async_init, mock_mistral_class):
|
||||
"""Test overriding both model and temperature parameters simultaneously"""
|
||||
# Arrange
|
||||
mock_mistral_client = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices[0].message.content = 'Response with both overrides'
|
||||
mock_response.usage.prompt_tokens = 22
|
||||
mock_response.usage.completion_tokens = 16
|
||||
mock_mistral_client.chat.complete.return_value = mock_response
|
||||
mock_mistral_class.return_value = mock_mistral_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'ministral-8b-latest', # Default model
|
||||
'api_key': 'test-api-key',
|
||||
'temperature': 0.0, # Default temperature
|
||||
'max_output': 4096,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act - Override both parameters at runtime
|
||||
result = await processor.generate_content(
|
||||
"System prompt",
|
||||
"User prompt",
|
||||
model="mistral-large-latest", # Override model
|
||||
temperature=0.9 # Override temperature
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Response with both overrides"
|
||||
|
||||
# Verify Mistral API was called with both overrides
|
||||
call_args = mock_mistral_client.chat.complete.call_args
|
||||
assert call_args[1]['model'] == 'mistral-large-latest' # Should use runtime override
|
||||
assert call_args[1]['temperature'] == 0.9 # Should use runtime override
|
||||
|
||||
@patch('trustgraph.model.text_completion.mistral.llm.Mistral')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_prompt_construction(self, mock_llm_init, mock_async_init, mock_mistral_class):
|
||||
"""Test prompt construction with system and user prompts"""
|
||||
# Arrange
|
||||
mock_mistral_client = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices[0].message.content = 'Response with system instructions'
|
||||
mock_response.usage.prompt_tokens = 25
|
||||
mock_response.usage.completion_tokens = 15
|
||||
mock_mistral_client.chat.complete.return_value = mock_response
|
||||
mock_mistral_class.return_value = mock_mistral_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'ministral-8b-latest',
|
||||
'api_key': 'test-api-key',
|
||||
'temperature': 0.0,
|
||||
'max_output': 4096,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act
|
||||
result = await processor.generate_content("You are a helpful assistant", "What is AI?")
|
||||
|
||||
# Assert
|
||||
assert result.text == "Response with system instructions"
|
||||
assert result.in_token == 25
|
||||
assert result.out_token == 15
|
||||
|
||||
# Verify the combined prompt structure
|
||||
call_args = mock_mistral_client.chat.complete.call_args
|
||||
messages = call_args[1]['messages']
|
||||
assert len(messages) == 1
|
||||
assert messages[0]['role'] == 'user'
|
||||
assert messages[0]['content'][0]['type'] == 'text'
|
||||
assert messages[0]['content'][0]['text'] == "You are a helpful assistant\n\nWhat is AI?"
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pytest.main([__file__])
|
||||
|
|
@ -312,6 +312,150 @@ class TestOllamaProcessorSimple(IsolatedAsyncioTestCase):
|
|||
# Verify the combined prompt
|
||||
mock_client.generate.assert_called_once_with('llama2', "You are a helpful assistant\n\nWhat is AI?", options={'temperature': 0.0})
|
||||
|
||||
@patch('trustgraph.model.text_completion.ollama.llm.Client')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_temperature_override(self, mock_llm_init, mock_async_init, mock_client_class):
|
||||
"""Test temperature parameter override functionality"""
|
||||
# Arrange
|
||||
mock_client = MagicMock()
|
||||
mock_response = {
|
||||
'response': 'Response with custom temperature',
|
||||
'prompt_eval_count': 20,
|
||||
'eval_count': 12
|
||||
}
|
||||
mock_client.generate.return_value = mock_response
|
||||
mock_client_class.return_value = mock_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'llama2',
|
||||
'ollama': 'http://localhost:11434',
|
||||
'temperature': 0.0, # Default temperature
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act - Override temperature at runtime
|
||||
result = await processor.generate_content(
|
||||
"System prompt",
|
||||
"User prompt",
|
||||
model=None, # Use default model
|
||||
temperature=0.8 # Override temperature
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Response with custom temperature"
|
||||
|
||||
# Verify Ollama API was called with overridden temperature
|
||||
mock_client.generate.assert_called_once_with(
|
||||
'llama2',
|
||||
"System prompt\n\nUser prompt",
|
||||
options={'temperature': 0.8} # Should use runtime override
|
||||
)
|
||||
|
||||
@patch('trustgraph.model.text_completion.ollama.llm.Client')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_model_override(self, mock_llm_init, mock_async_init, mock_client_class):
|
||||
"""Test model parameter override functionality"""
|
||||
# Arrange
|
||||
mock_client = MagicMock()
|
||||
mock_response = {
|
||||
'response': 'Response with custom model',
|
||||
'prompt_eval_count': 18,
|
||||
'eval_count': 14
|
||||
}
|
||||
mock_client.generate.return_value = mock_response
|
||||
mock_client_class.return_value = mock_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'llama2', # Default model
|
||||
'ollama': 'http://localhost:11434',
|
||||
'temperature': 0.1, # Default temperature
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act - Override model at runtime
|
||||
result = await processor.generate_content(
|
||||
"System prompt",
|
||||
"User prompt",
|
||||
model="mistral", # Override model
|
||||
temperature=None # Use default temperature
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Response with custom model"
|
||||
|
||||
# Verify Ollama API was called with overridden model
|
||||
mock_client.generate.assert_called_once_with(
|
||||
'mistral', # Should use runtime override
|
||||
"System prompt\n\nUser prompt",
|
||||
options={'temperature': 0.1} # Should use processor default
|
||||
)
|
||||
|
||||
@patch('trustgraph.model.text_completion.ollama.llm.Client')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_both_parameters_override(self, mock_llm_init, mock_async_init, mock_client_class):
|
||||
"""Test overriding both model and temperature parameters simultaneously"""
|
||||
# Arrange
|
||||
mock_client = MagicMock()
|
||||
mock_response = {
|
||||
'response': 'Response with both overrides',
|
||||
'prompt_eval_count': 22,
|
||||
'eval_count': 16
|
||||
}
|
||||
mock_client.generate.return_value = mock_response
|
||||
mock_client_class.return_value = mock_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'llama2', # Default model
|
||||
'ollama': 'http://localhost:11434',
|
||||
'temperature': 0.0, # Default temperature
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act - Override both parameters at runtime
|
||||
result = await processor.generate_content(
|
||||
"System prompt",
|
||||
"User prompt",
|
||||
model="codellama", # Override model
|
||||
temperature=0.9 # Override temperature
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Response with both overrides"
|
||||
|
||||
# Verify Ollama API was called with both overrides
|
||||
mock_client.generate.assert_called_once_with(
|
||||
'codellama', # Should use runtime override
|
||||
"System prompt\n\nUser prompt",
|
||||
options={'temperature': 0.9} # Should use runtime override
|
||||
)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pytest.main([__file__])
|
||||
|
|
@ -391,5 +391,210 @@ class TestOpenAIProcessorSimple(IsolatedAsyncioTestCase):
|
|||
assert call_args[1]['response_format'] == {"type": "text"}
|
||||
|
||||
|
||||
@patch('trustgraph.model.text_completion.openai.llm.OpenAI')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_temperature_override(self, mock_llm_init, mock_async_init, mock_openai_class):
|
||||
"""Test temperature parameter override functionality"""
|
||||
# Arrange
|
||||
mock_openai_client = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices = [MagicMock()]
|
||||
mock_response.choices[0].message.content = "Response with custom temperature"
|
||||
mock_response.usage.prompt_tokens = 15
|
||||
mock_response.usage.completion_tokens = 10
|
||||
|
||||
mock_openai_client.chat.completions.create.return_value = mock_response
|
||||
mock_openai_class.return_value = mock_openai_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'gpt-3.5-turbo',
|
||||
'api_key': 'test-api-key',
|
||||
'url': 'https://api.openai.com/v1',
|
||||
'temperature': 0.0, # Default temperature
|
||||
'max_output': 4096,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act - Override temperature at runtime
|
||||
result = await processor.generate_content(
|
||||
"System prompt",
|
||||
"User prompt",
|
||||
model=None, # Use default model
|
||||
temperature=0.9 # Override temperature
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Response with custom temperature"
|
||||
|
||||
# Verify the OpenAI API was called with overridden temperature
|
||||
mock_openai_client.chat.completions.create.assert_called_once()
|
||||
call_kwargs = mock_openai_client.chat.completions.create.call_args.kwargs
|
||||
|
||||
assert call_kwargs['temperature'] == 0.9 # Should use runtime override
|
||||
assert call_kwargs['model'] == 'gpt-3.5-turbo' # Should use processor default
|
||||
|
||||
@patch('trustgraph.model.text_completion.openai.llm.OpenAI')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_model_override(self, mock_llm_init, mock_async_init, mock_openai_class):
|
||||
"""Test model parameter override functionality"""
|
||||
# Arrange
|
||||
mock_openai_client = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices = [MagicMock()]
|
||||
mock_response.choices[0].message.content = "Response with custom model"
|
||||
mock_response.usage.prompt_tokens = 15
|
||||
mock_response.usage.completion_tokens = 10
|
||||
|
||||
mock_openai_client.chat.completions.create.return_value = mock_response
|
||||
mock_openai_class.return_value = mock_openai_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'gpt-3.5-turbo', # Default model
|
||||
'api_key': 'test-api-key',
|
||||
'url': 'https://api.openai.com/v1',
|
||||
'temperature': 0.2,
|
||||
'max_output': 4096,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act - Override model at runtime
|
||||
result = await processor.generate_content(
|
||||
"System prompt",
|
||||
"User prompt",
|
||||
model="gpt-4", # Override model
|
||||
temperature=None # Use default temperature
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Response with custom model"
|
||||
|
||||
# Verify the OpenAI API was called with overridden model
|
||||
mock_openai_client.chat.completions.create.assert_called_once()
|
||||
call_kwargs = mock_openai_client.chat.completions.create.call_args.kwargs
|
||||
|
||||
assert call_kwargs['model'] == 'gpt-4' # Should use runtime override
|
||||
assert call_kwargs['temperature'] == 0.2 # Should use processor default
|
||||
|
||||
@patch('trustgraph.model.text_completion.openai.llm.OpenAI')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_both_parameters_override(self, mock_llm_init, mock_async_init, mock_openai_class):
|
||||
"""Test overriding both model and temperature parameters simultaneously"""
|
||||
# Arrange
|
||||
mock_openai_client = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices = [MagicMock()]
|
||||
mock_response.choices[0].message.content = "Response with both overrides"
|
||||
mock_response.usage.prompt_tokens = 15
|
||||
mock_response.usage.completion_tokens = 10
|
||||
|
||||
mock_openai_client.chat.completions.create.return_value = mock_response
|
||||
mock_openai_class.return_value = mock_openai_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'gpt-3.5-turbo', # Default model
|
||||
'api_key': 'test-api-key',
|
||||
'url': 'https://api.openai.com/v1',
|
||||
'temperature': 0.0, # Default temperature
|
||||
'max_output': 4096,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act - Override both parameters at runtime
|
||||
result = await processor.generate_content(
|
||||
"System prompt",
|
||||
"User prompt",
|
||||
model="gpt-4", # Override model
|
||||
temperature=0.7 # Override temperature
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Response with both overrides"
|
||||
|
||||
# Verify the OpenAI API was called with both overrides
|
||||
mock_openai_client.chat.completions.create.assert_called_once()
|
||||
call_kwargs = mock_openai_client.chat.completions.create.call_args.kwargs
|
||||
|
||||
assert call_kwargs['model'] == 'gpt-4' # Should use runtime override
|
||||
assert call_kwargs['temperature'] == 0.7 # Should use runtime override
|
||||
|
||||
@patch('trustgraph.model.text_completion.openai.llm.OpenAI')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_no_override_uses_defaults(self, mock_llm_init, mock_async_init, mock_openai_class):
|
||||
"""Test that when no parameters are overridden, processor defaults are used"""
|
||||
# Arrange
|
||||
mock_openai_client = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices = [MagicMock()]
|
||||
mock_response.choices[0].message.content = "Response with defaults"
|
||||
mock_response.usage.prompt_tokens = 15
|
||||
mock_response.usage.completion_tokens = 10
|
||||
|
||||
mock_openai_client.chat.completions.create.return_value = mock_response
|
||||
mock_openai_class.return_value = mock_openai_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'gpt-4', # Default model
|
||||
'api_key': 'test-api-key',
|
||||
'url': 'https://api.openai.com/v1',
|
||||
'temperature': 0.5, # Default temperature
|
||||
'max_output': 4096,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act - Don't override any parameters (pass None)
|
||||
result = await processor.generate_content(
|
||||
"System prompt",
|
||||
"User prompt",
|
||||
model=None, # Use default model
|
||||
temperature=None # Use default temperature
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Response with defaults"
|
||||
|
||||
# Verify the OpenAI API was called with processor defaults
|
||||
mock_openai_client.chat.completions.create.assert_called_once()
|
||||
call_kwargs = mock_openai_client.chat.completions.create.call_args.kwargs
|
||||
|
||||
assert call_kwargs['model'] == 'gpt-4' # Should use processor default
|
||||
assert call_kwargs['temperature'] == 0.5 # Should use processor default
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pytest.main([__file__])

186  tests/unit/test_text_completion/test_parameter_caching.py  Normal file

@ -0,0 +1,186 @@
"""
|
||||
Unit tests for Parameter-Based Caching in LLM Processors
|
||||
Testing processors that cache based on temperature parameters (Bedrock, GoogleAIStudio)
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from unittest import IsolatedAsyncioTestCase
|
||||
|
||||
from trustgraph.model.text_completion.googleaistudio.llm import Processor as GoogleAIProcessor
|
||||
from trustgraph.base import LlmResult
|
||||
|
||||
|
||||
class TestParameterCaching(IsolatedAsyncioTestCase):
|
||||
"""Test parameter-based caching functionality"""
|
||||
|
||||
@patch('trustgraph.model.text_completion.googleaistudio.llm.genai')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_googleai_temperature_cache_keys(self, mock_llm_init, mock_async_init, mock_genai):
|
||||
"""Test that GoogleAI processor creates separate cache entries for different temperatures"""
|
||||
# Arrange
|
||||
mock_client = MagicMock()
|
||||
mock_genai.Client.return_value = mock_client
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.text = "Generated response"
|
||||
mock_response.usage_metadata.prompt_token_count = 10
|
||||
mock_response.usage_metadata.candidates_token_count = 5
|
||||
mock_client.models.generate_content.return_value = mock_response
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'gemini-2.0-flash-001',
|
||||
'api_key': 'test-api-key',
|
||||
'temperature': 0.0, # Default temperature
|
||||
'max_output': 1024,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = GoogleAIProcessor(**config)
|
||||
|
||||
# Act - Call with different temperatures
|
||||
await processor.generate_content("System", "Prompt 1", model="gemini-2.0-flash-001", temperature=0.0)
|
||||
await processor.generate_content("System", "Prompt 2", model="gemini-2.0-flash-001", temperature=0.5)
|
||||
await processor.generate_content("System", "Prompt 3", model="gemini-2.0-flash-001", temperature=1.0)
|
||||
|
||||
# Assert - Should have 3 different cache entries
|
||||
cache_keys = list(processor.generation_configs.keys())
|
||||
|
||||
assert len(cache_keys) == 3
|
||||
assert "gemini-2.0-flash-001:0.0" in cache_keys
|
||||
assert "gemini-2.0-flash-001:0.5" in cache_keys
|
||||
assert "gemini-2.0-flash-001:1.0" in cache_keys
|
||||
|
||||
# Verify each cached config has the correct temperature
|
||||
assert processor.generation_configs["gemini-2.0-flash-001:0.0"].temperature == 0.0
|
||||
assert processor.generation_configs["gemini-2.0-flash-001:0.5"].temperature == 0.5
|
||||
assert processor.generation_configs["gemini-2.0-flash-001:1.0"].temperature == 1.0
|
||||
|
||||
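
    # --- Illustrative sketch (not part of the original test suite) ----------
    # The assertions above pin the cache-key scheme down to "<model>:<temperature>".
    # A minimal, hypothetical caching helper consistent with that behaviour is
    # sketched here; the actual GoogleAIStudio processor may build its
    # generation configs differently.
    @staticmethod
    def _cached_config_sketch(generation_configs, make_config, model, temperature):
        """Hypothetical helper: reuse one config object per (model, temperature) pair."""
        key = f"{model}:{temperature}"
        if key not in generation_configs:
            # Cache miss: build and remember a new config for this combination
            generation_configs[key] = make_config(temperature=temperature)
        return generation_configs[key]
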
    @patch('trustgraph.model.text_completion.googleaistudio.llm.genai')
    @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
    @patch('trustgraph.base.llm_service.LlmService.__init__')
    async def test_googleai_cache_reuse_same_parameters(self, mock_llm_init, mock_async_init, mock_genai):
        """Test that GoogleAI processor reuses cache for identical model+temperature combinations"""
        # Arrange
        mock_client = MagicMock()
        mock_genai.Client.return_value = mock_client

        mock_response = MagicMock()
        mock_response.text = "Generated response"
        mock_response.usage_metadata.prompt_token_count = 10
        mock_response.usage_metadata.candidates_token_count = 5
        mock_client.models.generate_content.return_value = mock_response

        mock_async_init.return_value = None
        mock_llm_init.return_value = None

        config = {
            'model': 'gemini-2.0-flash-001',
            'api_key': 'test-api-key',
            'temperature': 0.0,
            'max_output': 1024,
            'concurrency': 1,
            'taskgroup': AsyncMock(),
            'id': 'test-processor'
        }

        processor = GoogleAIProcessor(**config)

        # Act - Call multiple times with same parameters
        await processor.generate_content("System", "Prompt 1", model="gemini-2.0-flash-001", temperature=0.7)
        await processor.generate_content("System", "Prompt 2", model="gemini-2.0-flash-001", temperature=0.7)
        await processor.generate_content("System", "Prompt 3", model="gemini-2.0-flash-001", temperature=0.7)

        # Assert - Should have only 1 cache entry for the repeated parameters
        cache_keys = list(processor.generation_configs.keys())
        assert len(cache_keys) == 1
        assert "gemini-2.0-flash-001:0.7" in cache_keys

        # The same config object should be reused
        config_obj = processor.generation_configs["gemini-2.0-flash-001:0.7"]
        assert config_obj.temperature == 0.7

    @patch('trustgraph.model.text_completion.googleaistudio.llm.genai')
    @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
    @patch('trustgraph.base.llm_service.LlmService.__init__')
    async def test_googleai_different_models_separate_caches(self, mock_llm_init, mock_async_init, mock_genai):
        """Test that different models create separate cache entries even with same temperature"""
        # Arrange
        mock_client = MagicMock()
        mock_genai.Client.return_value = mock_client

        mock_response = MagicMock()
        mock_response.text = "Generated response"
        mock_response.usage_metadata.prompt_token_count = 10
        mock_response.usage_metadata.candidates_token_count = 5
        mock_client.models.generate_content.return_value = mock_response

        mock_async_init.return_value = None
        mock_llm_init.return_value = None

        config = {
            'model': 'gemini-2.0-flash-001',
            'api_key': 'test-api-key',
            'temperature': 0.0,
            'max_output': 1024,
            'concurrency': 1,
            'taskgroup': AsyncMock(),
            'id': 'test-processor'
        }

        processor = GoogleAIProcessor(**config)

        # Act - Call with different models, same temperature
        await processor.generate_content("System", "Prompt 1", model="gemini-2.0-flash-001", temperature=0.5)
        await processor.generate_content("System", "Prompt 2", model="gemini-1.5-flash-001", temperature=0.5)

        # Assert - Should have separate cache entries for different models
        cache_keys = list(processor.generation_configs.keys())
        assert len(cache_keys) == 2
        assert "gemini-2.0-flash-001:0.5" in cache_keys
        assert "gemini-1.5-flash-001:0.5" in cache_keys

    # Note: Bedrock tests would be similar but testing the Bedrock processor's caching behavior
    # The Bedrock processor caches model variants with temperature in the cache key
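
    # Illustrative sketch (hypothetical, not the actual Bedrock implementation):
    # the stubs below anticipate a variant cache keyed the same way, e.g.
    #
    #     variant = processor.model_variants.setdefault(
    #         f"{model_id}:{temperature}",
    #         make_variant(model_id, temperature),
    #     )
    #
    # where make_variant stands in for whatever _get_or_create_variant builds.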

    async def test_bedrock_temperature_cache_keys(self):
        """Test Bedrock processor temperature-aware caching"""
        # This would test the Bedrock processor's _get_or_create_variant method
        # with different temperature values to ensure proper cache key generation

        # Implementation would follow similar pattern to GoogleAI tests above
        # but using the Bedrock processor and testing model_variants cache
        pass

    async def test_bedrock_cache_isolation_different_temperatures(self):
        """Test that Bedrock processor isolates cache entries by temperature"""
        pass

    async def test_cache_memory_efficiency(self):
        """Test that caches don't grow unbounded with many different parameter combinations"""
        # This could test cache size limits or cleanup behavior if implemented
        pass


class TestCachePerformance(IsolatedAsyncioTestCase):
    """Test caching performance characteristics"""

    async def test_cache_hit_performance(self):
        """Test that cache hits are faster than cache misses"""
        # This would measure timing differences between cache hits and misses
        pass

    async def test_concurrent_cache_access(self):
        """Test concurrent access to cached configurations"""
        # This would test thread-safety of cache access
        pass


if __name__ == '__main__':
    pytest.main([__file__])

271  tests/unit/test_text_completion/test_tgi_processor.py  Normal file

@ -0,0 +1,271 @@
"""
|
||||
Unit tests for trustgraph.model.text_completion.tgi
|
||||
Following the same successful pattern as previous tests
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from unittest import IsolatedAsyncioTestCase
|
||||
|
||||
# Import the service under test
|
||||
from trustgraph.model.text_completion.tgi.llm import Processor
|
||||
from trustgraph.base import LlmResult
|
||||
from trustgraph.exceptions import TooManyRequests
|
||||
|
||||
|
||||
class TestTGIProcessorSimple(IsolatedAsyncioTestCase):
|
||||
"""Test TGI processor functionality"""
|
||||
|
||||
@patch('trustgraph.model.text_completion.tgi.llm.aiohttp.ClientSession')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_processor_initialization_basic(self, mock_llm_init, mock_async_init, mock_session_class):
|
||||
"""Test basic processor initialization"""
|
||||
# Arrange
|
||||
mock_session = MagicMock()
|
||||
mock_session_class.return_value = mock_session
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'tgi',
|
||||
'url': 'http://tgi-service:8899/v1',
|
||||
'temperature': 0.0,
|
||||
'max_output': 2048,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
# Act
|
||||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
assert processor.default_model == 'tgi'
|
||||
assert processor.base_url == 'http://tgi-service:8899/v1'
|
||||
assert processor.temperature == 0.0
|
||||
assert processor.max_output == 2048
|
||||
assert hasattr(processor, 'session')
|
||||
mock_session_class.assert_called_once()
|
||||
|
||||
@patch('trustgraph.model.text_completion.tgi.llm.aiohttp.ClientSession')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_success(self, mock_llm_init, mock_async_init, mock_session_class):
|
||||
"""Test successful content generation"""
|
||||
# Arrange
|
||||
mock_session = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.status = 200
|
||||
mock_response.json = AsyncMock(return_value={
|
||||
'choices': [{
|
||||
'message': {
|
||||
'content': 'Generated response from TGI'
|
||||
}
|
||||
}],
|
||||
'usage': {
|
||||
'prompt_tokens': 20,
|
||||
'completion_tokens': 12
|
||||
}
|
||||
})
|
||||
|
||||
# Mock the async context manager
|
||||
mock_session.post.return_value.__aenter__.return_value = mock_response
|
||||
mock_session.post.return_value.__aexit__.return_value = None
|
||||
mock_session_class.return_value = mock_session
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'tgi',
|
||||
'url': 'http://tgi-service:8899/v1',
|
||||
'temperature': 0.0,
|
||||
'max_output': 2048,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act
|
||||
result = await processor.generate_content("System prompt", "User prompt")
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Generated response from TGI"
|
||||
assert result.in_token == 20
|
||||
assert result.out_token == 12
|
||||
assert result.model == 'tgi'
|
||||
|
||||
# Verify the API call was made correctly
|
||||
mock_session.post.assert_called_once()
|
||||
call_args = mock_session.post.call_args
|
||||
|
||||
# Check URL
|
||||
assert call_args[0][0] == 'http://tgi-service:8899/v1/chat/completions'
|
||||
|
||||
# Check request structure
|
||||
request_body = call_args[1]['json']
|
||||
assert request_body['model'] == 'tgi'
|
||||
assert request_body['temperature'] == 0.0
|
||||
assert request_body['max_tokens'] == 2048
|
||||
|
||||
@patch('trustgraph.model.text_completion.tgi.llm.aiohttp.ClientSession')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_with_model_override(self, mock_llm_init, mock_async_init, mock_session_class):
|
||||
"""Test generate_content with model parameter override"""
|
||||
# Arrange
|
||||
mock_session = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.status = 200
|
||||
mock_response.json = AsyncMock(return_value={
|
||||
'choices': [{
|
||||
'message': {
|
||||
'content': 'Response from overridden model'
|
||||
}
|
||||
}],
|
||||
'usage': {
|
||||
'prompt_tokens': 15,
|
||||
'completion_tokens': 10
|
||||
}
|
||||
})
|
||||
|
||||
mock_session.post.return_value.__aenter__.return_value = mock_response
|
||||
mock_session.post.return_value.__aexit__.return_value = None
|
||||
mock_session_class.return_value = mock_session
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'tgi',
|
||||
'url': 'http://tgi-service:8899/v1',
|
||||
'temperature': 0.0,
|
||||
'max_output': 2048,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act - Override model
|
||||
result = await processor.generate_content("System", "Prompt", model="custom-tgi-model")
|
||||
|
||||
# Assert
|
||||
assert result.model == "custom-tgi-model" # Should use overridden model
|
||||
assert result.text == "Response from overridden model"
|
||||
|
||||
# Verify the API call was made with overridden model
|
||||
call_args = mock_session.post.call_args
|
||||
assert call_args[1]['json']['model'] == "custom-tgi-model"
|
||||
|
||||
@patch('trustgraph.model.text_completion.tgi.llm.aiohttp.ClientSession')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_with_temperature_override(self, mock_llm_init, mock_async_init, mock_session_class):
|
||||
"""Test generate_content with temperature parameter override"""
|
||||
# Arrange
|
||||
mock_session = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.status = 200
|
||||
mock_response.json = AsyncMock(return_value={
|
||||
'choices': [{
|
||||
'message': {
|
||||
'content': 'Response with temperature override'
|
||||
}
|
||||
}],
|
||||
'usage': {
|
||||
'prompt_tokens': 18,
|
||||
'completion_tokens': 12
|
||||
}
|
||||
})
|
||||
|
||||
mock_session.post.return_value.__aenter__.return_value = mock_response
|
||||
mock_session.post.return_value.__aexit__.return_value = None
|
||||
mock_session_class.return_value = mock_session
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'tgi',
|
||||
'url': 'http://tgi-service:8899/v1',
|
||||
'temperature': 0.0, # Default temperature
|
||||
'max_output': 2048,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act - Override temperature
|
||||
result = await processor.generate_content("System", "Prompt", temperature=0.7)
|
||||
|
||||
# Assert
|
||||
assert result.text == "Response with temperature override"
|
||||
|
||||
# Verify the API call was made with overridden temperature
|
||||
call_args = mock_session.post.call_args
|
||||
assert call_args[1]['json']['temperature'] == 0.7
|
||||
|
||||
@patch('trustgraph.model.text_completion.tgi.llm.aiohttp.ClientSession')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_with_both_parameters_override(self, mock_llm_init, mock_async_init, mock_session_class):
|
||||
"""Test generate_content with both model and temperature overrides"""
|
||||
# Arrange
|
||||
mock_session = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.status = 200
|
||||
mock_response.json = AsyncMock(return_value={
|
||||
'choices': [{
|
||||
'message': {
|
||||
'content': 'Response with both parameters override'
|
||||
}
|
||||
}],
|
||||
'usage': {
|
||||
'prompt_tokens': 20,
|
||||
'completion_tokens': 15
|
||||
}
|
||||
})
|
||||
|
||||
mock_session.post.return_value.__aenter__.return_value = mock_response
|
||||
mock_session.post.return_value.__aexit__.return_value = None
|
||||
mock_session_class.return_value = mock_session
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'tgi',
|
||||
'url': 'http://tgi-service:8899/v1',
|
||||
'temperature': 0.0,
|
||||
'max_output': 2048,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act - Override both parameters
|
||||
result = await processor.generate_content("System", "Prompt", model="override-model", temperature=0.8)
|
||||
|
||||
# Assert
|
||||
assert result.model == "override-model"
|
||||
assert result.text == "Response with both parameters override"
|
||||
|
||||
# Verify the API call was made with overridden parameters
|
||||
call_args = mock_session.post.call_args
|
||||
assert call_args[1]['json']['model'] == "override-model"
|
||||
assert call_args[1]['json']['temperature'] == 0.8
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pytest.main([__file__])
|
||||
|
|
@ -460,6 +460,180 @@ class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase):
|
|||
assert processor.api_params["top_p"] == 1.0
|
||||
assert processor.api_params["top_k"] == 32
|
||||
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.service_account')
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.vertexai')
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.GenerativeModel')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_temperature_override(self, mock_llm_init, mock_async_init, mock_generative_model, mock_vertexai, mock_service_account):
|
||||
"""Test temperature parameter override functionality"""
|
||||
# Arrange
|
||||
mock_credentials = MagicMock()
|
||||
mock_service_account.Credentials.from_service_account_file.return_value = mock_credentials
|
||||
|
||||
mock_model = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.text = "Response with custom temperature"
|
||||
mock_response.usage_metadata.prompt_token_count = 20
|
||||
mock_response.usage_metadata.candidates_token_count = 12
|
||||
mock_model.generate_content.return_value = mock_response
|
||||
mock_generative_model.return_value = mock_model
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'region': 'us-central1',
|
||||
'model': 'gemini-2.0-flash-001',
|
||||
'temperature': 0.0, # Default temperature
|
||||
'max_output': 8192,
|
||||
'private_key': 'private.json',
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act - Override temperature at runtime
|
||||
result = await processor.generate_content(
|
||||
"System prompt",
|
||||
"User prompt",
|
||||
model=None, # Use default model
|
||||
temperature=0.8 # Override temperature
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Response with custom temperature"
|
||||
|
||||
# Verify Gemini API was called with overridden temperature
|
||||
mock_model.generate_content.assert_called_once()
|
||||
call_args = mock_model.generate_content.call_args
|
||||
|
||||
# Check that generation_config was created (we can't directly access temperature from mock)
|
||||
generation_config = call_args.kwargs['generation_config']
|
||||
assert generation_config is not None # Should use overridden temperature configuration
|
||||
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.service_account')
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.vertexai')
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.GenerativeModel')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_model_override(self, mock_llm_init, mock_async_init, mock_generative_model, mock_vertexai, mock_service_account):
|
||||
"""Test model parameter override functionality"""
|
||||
# Arrange
|
||||
mock_credentials = MagicMock()
|
||||
mock_service_account.Credentials.from_service_account_file.return_value = mock_credentials
|
||||
|
||||
# Mock different models
|
||||
mock_model_default = MagicMock()
|
||||
mock_model_override = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.text = "Response with custom model"
|
||||
mock_response.usage_metadata.prompt_token_count = 18
|
||||
mock_response.usage_metadata.candidates_token_count = 14
|
||||
mock_model_override.generate_content.return_value = mock_response
|
||||
|
||||
# GenerativeModel should return different models based on input
|
||||
def model_factory(model_name):
|
||||
if model_name == 'gemini-1.5-pro':
|
||||
return mock_model_override
|
||||
return mock_model_default
|
||||
|
||||
mock_generative_model.side_effect = model_factory
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'region': 'us-central1',
|
||||
'model': 'gemini-2.0-flash-001', # Default model
|
||||
'temperature': 0.2, # Default temperature
|
||||
'max_output': 8192,
|
||||
'private_key': 'private.json',
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act - Override model at runtime
|
||||
result = await processor.generate_content(
|
||||
"System prompt",
|
||||
"User prompt",
|
||||
model="gemini-1.5-pro", # Override model
|
||||
temperature=None # Use default temperature
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Response with custom model"
|
||||
|
||||
# Verify the overridden model was used
|
||||
mock_model_override.generate_content.assert_called_once()
|
||||
# Verify GenerativeModel was called with the override model
|
||||
mock_generative_model.assert_called_with('gemini-1.5-pro')
|
||||
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.service_account')
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.vertexai')
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.GenerativeModel')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_both_parameters_override(self, mock_llm_init, mock_async_init, mock_generative_model, mock_vertexai, mock_service_account):
|
||||
"""Test overriding both model and temperature parameters simultaneously"""
|
||||
# Arrange
|
||||
mock_credentials = MagicMock()
|
||||
mock_service_account.Credentials.from_service_account_file.return_value = mock_credentials
|
||||
|
||||
mock_model = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.text = "Response with both overrides"
|
||||
mock_response.usage_metadata.prompt_token_count = 22
|
||||
mock_response.usage_metadata.candidates_token_count = 16
|
||||
mock_model.generate_content.return_value = mock_response
|
||||
mock_generative_model.return_value = mock_model
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'region': 'us-central1',
|
||||
'model': 'gemini-2.0-flash-001', # Default model
|
||||
'temperature': 0.0, # Default temperature
|
||||
'max_output': 8192,
|
||||
'private_key': 'private.json',
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act - Override both parameters at runtime
|
||||
result = await processor.generate_content(
|
||||
"System prompt",
|
||||
"User prompt",
|
||||
model="gemini-1.5-flash-001", # Override model
|
||||
temperature=0.9 # Override temperature
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Response with both overrides"
|
||||
|
||||
# Verify both overrides were used
|
||||
mock_model.generate_content.assert_called_once()
|
||||
call_args = mock_model.generate_content.call_args
|
||||
|
||||
# Verify model override
|
||||
mock_generative_model.assert_called_with('gemini-1.5-flash-001') # Should use runtime override
|
||||
|
||||
# Verify temperature override (we can't directly access temperature from mock)
|
||||
generation_config = call_args.kwargs['generation_config']
|
||||
assert generation_config is not None # Should use overridden temperature configuration
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pytest.main([__file__])

@ -485,5 +485,148 @@ class TestVLLMProcessorSimple(IsolatedAsyncioTestCase):
        assert call_args[1]['json']['prompt'] == expected_prompt

    @patch('trustgraph.model.text_completion.vllm.llm.aiohttp.ClientSession')
    @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
    @patch('trustgraph.base.llm_service.LlmService.__init__')
    async def test_generate_content_with_model_override(self, mock_llm_init, mock_async_init, mock_session_class):
        """Test generate_content with model parameter override"""
        # Arrange
        mock_session = MagicMock()
        mock_response = MagicMock()
        mock_response.status = 200
        mock_response.json = AsyncMock(return_value={
            'choices': [{
                'text': 'Response from overridden model'
            }],
            'usage': {
                'prompt_tokens': 12,
                'completion_tokens': 8
            }
        })

        mock_session.post.return_value.__aenter__.return_value = mock_response
        mock_session.post.return_value.__aexit__.return_value = None
        mock_session_class.return_value = mock_session

        mock_async_init.return_value = None
        mock_llm_init.return_value = None

        config = {
            'model': 'TheBloke/Mistral-7B-v0.1-AWQ',
            'url': 'http://vllm-service:8899/v1',
            'temperature': 0.0,
            'max_output': 2048,
            'concurrency': 1,
            'taskgroup': AsyncMock(),
            'id': 'test-processor'
        }

        processor = Processor(**config)

        # Act - Override model
        result = await processor.generate_content("System", "Prompt", model="custom-vllm-model")

        # Assert
        assert result.model == "custom-vllm-model"  # Should use overridden model
        assert result.text == "Response from overridden model"

    @patch('trustgraph.model.text_completion.vllm.llm.aiohttp.ClientSession')
    @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
    @patch('trustgraph.base.llm_service.LlmService.__init__')
    async def test_generate_content_with_temperature_override(self, mock_llm_init, mock_async_init, mock_session_class):
        """Test generate_content with temperature parameter override"""
        # Arrange
        mock_session = MagicMock()
        mock_response = MagicMock()
        mock_response.status = 200
        mock_response.json = AsyncMock(return_value={
            'choices': [{
                'text': 'Response with temperature override'
            }],
            'usage': {
                'prompt_tokens': 15,
                'completion_tokens': 10
            }
        })

        mock_session.post.return_value.__aenter__.return_value = mock_response
        mock_session.post.return_value.__aexit__.return_value = None
        mock_session_class.return_value = mock_session

        mock_async_init.return_value = None
        mock_llm_init.return_value = None

        config = {
            'model': 'TheBloke/Mistral-7B-v0.1-AWQ',
            'url': 'http://vllm-service:8899/v1',
            'temperature': 0.0,  # Default temperature
            'max_output': 2048,
            'concurrency': 1,
            'taskgroup': AsyncMock(),
            'id': 'test-processor'
        }

        processor = Processor(**config)

        # Act - Override temperature
        result = await processor.generate_content("System", "Prompt", temperature=0.7)

        # Assert
        assert result.text == "Response with temperature override"

        # Verify the request was made with overridden temperature
        call_args = mock_session.post.call_args
        assert call_args[1]['json']['temperature'] == 0.7

    @patch('trustgraph.model.text_completion.vllm.llm.aiohttp.ClientSession')
    @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
    @patch('trustgraph.base.llm_service.LlmService.__init__')
    async def test_generate_content_with_both_parameters_override(self, mock_llm_init, mock_async_init, mock_session_class):
        """Test generate_content with both model and temperature overrides"""
        # Arrange
        mock_session = MagicMock()
        mock_response = MagicMock()
        mock_response.status = 200
        mock_response.json = AsyncMock(return_value={
            'choices': [{
                'text': 'Response with both parameters override'
            }],
            'usage': {
                'prompt_tokens': 18,
                'completion_tokens': 12
            }
        })

        mock_session.post.return_value.__aenter__.return_value = mock_response
        mock_session.post.return_value.__aexit__.return_value = None
        mock_session_class.return_value = mock_session

        mock_async_init.return_value = None
        mock_llm_init.return_value = None

        config = {
            'model': 'TheBloke/Mistral-7B-v0.1-AWQ',
            'url': 'http://vllm-service:8899/v1',
            'temperature': 0.0,
            'max_output': 2048,
            'concurrency': 1,
            'taskgroup': AsyncMock(),
            'id': 'test-processor'
        }

        processor = Processor(**config)

        # Act - Override both parameters
        result = await processor.generate_content("System", "Prompt", model="override-model", temperature=0.8)

        # Assert
        assert result.model == "override-model"
        assert result.text == "Response with both parameters override"

        # Verify the request was made with overridden temperature
        call_args = mock_session.post.call_args
        assert call_args[1]['json']['temperature'] == 0.8
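
    # --- Illustrative sketch (not part of the original test suite) ----------
    # The assertions above only fix two facts about the vLLM request body: the
    # combined prompt is sent in the 'json' payload and the effective
    # temperature is forwarded. A hypothetical payload builder consistent with
    # that (field names beyond 'prompt' and 'temperature' are assumptions) is:
    @staticmethod
    def _completion_payload_sketch(prompt, temperature, max_tokens=2048):
        """Hypothetical helper: shape of a completion request body."""
        return {
            'prompt': prompt,            # e.g. "System prompt\n\nUser prompt"
            'temperature': temperature,  # runtime override or processor default
            'max_tokens': max_tokens,    # assumed to mirror the max_output setting
        }
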


if __name__ == '__main__':
    pytest.main([__file__])