mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-04-26 08:56:21 +02:00
release/v1.4 -> master (#548)
This commit is contained in:
parent
3ec2cd54f9
commit
2bd68ed7f4
94 changed files with 8571 additions and 1740 deletions
|
|
@ -44,7 +44,7 @@ class TestAzureOpenAIProcessorSimple(IsolatedAsyncioTestCase):
|
|||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
assert processor.model == 'gpt-4'
|
||||
assert processor.default_model == 'gpt-4'
|
||||
assert processor.temperature == 0.0
|
||||
assert processor.max_output == 4192
|
||||
assert hasattr(processor, 'openai')
|
||||
|
|
@ -254,7 +254,7 @@ class TestAzureOpenAIProcessorSimple(IsolatedAsyncioTestCase):
|
|||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
assert processor.model == 'gpt-35-turbo'
|
||||
assert processor.default_model == 'gpt-35-turbo'
|
||||
assert processor.temperature == 0.7
|
||||
assert processor.max_output == 2048
|
||||
mock_azure_openai_class.assert_called_once_with(
|
||||
|
|
@ -289,7 +289,7 @@ class TestAzureOpenAIProcessorSimple(IsolatedAsyncioTestCase):
|
|||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
assert processor.model == 'gpt-4'
|
||||
assert processor.default_model == 'gpt-4'
|
||||
assert processor.temperature == 0.0 # default_temperature
|
||||
assert processor.max_output == 4192 # default_max_output
|
||||
mock_azure_openai_class.assert_called_once_with(
|
||||
|
|
@ -402,6 +402,156 @@ class TestAzureOpenAIProcessorSimple(IsolatedAsyncioTestCase):
|
|||
assert call_args[1]['max_tokens'] == 1024
|
||||
assert call_args[1]['top_p'] == 1
|
||||
|
||||
@patch('trustgraph.model.text_completion.azure_openai.llm.AzureOpenAI')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_temperature_override(self, mock_llm_init, mock_async_init, mock_azure_openai_class):
|
||||
"""Test temperature parameter override functionality"""
|
||||
# Arrange
|
||||
mock_azure_client = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices = [MagicMock()]
|
||||
mock_response.choices[0].message.content = 'Response with custom temperature'
|
||||
mock_response.usage.prompt_tokens = 20
|
||||
mock_response.usage.completion_tokens = 12
|
||||
|
||||
mock_azure_client.chat.completions.create.return_value = mock_response
|
||||
mock_azure_openai_class.return_value = mock_azure_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'gpt-4',
|
||||
'endpoint': 'https://test.openai.azure.com/',
|
||||
'token': 'test-token',
|
||||
'api_version': '2024-12-01-preview',
|
||||
'temperature': 0.0, # Default temperature
|
||||
'max_output': 4192,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act - Override temperature at runtime
|
||||
result = await processor.generate_content(
|
||||
"System prompt",
|
||||
"User prompt",
|
||||
model=None, # Use default model
|
||||
temperature=0.8 # Override temperature
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Response with custom temperature"
|
||||
|
||||
# Verify Azure OpenAI API was called with overridden temperature
|
||||
call_args = mock_azure_client.chat.completions.create.call_args
|
||||
assert call_args[1]['temperature'] == 0.8 # Should use runtime override
|
||||
assert call_args[1]['model'] == 'gpt-4'
|
||||
|
||||
@patch('trustgraph.model.text_completion.azure_openai.llm.AzureOpenAI')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_model_override(self, mock_llm_init, mock_async_init, mock_azure_openai_class):
|
||||
"""Test model parameter override functionality"""
|
||||
# Arrange
|
||||
mock_azure_client = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices = [MagicMock()]
|
||||
mock_response.choices[0].message.content = 'Response with custom model'
|
||||
mock_response.usage.prompt_tokens = 18
|
||||
mock_response.usage.completion_tokens = 14
|
||||
|
||||
mock_azure_client.chat.completions.create.return_value = mock_response
|
||||
mock_azure_openai_class.return_value = mock_azure_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'gpt-4', # Default model
|
||||
'endpoint': 'https://test.openai.azure.com/',
|
||||
'token': 'test-token',
|
||||
'api_version': '2024-12-01-preview',
|
||||
'temperature': 0.1, # Default temperature
|
||||
'max_output': 4192,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act - Override model at runtime
|
||||
result = await processor.generate_content(
|
||||
"System prompt",
|
||||
"User prompt",
|
||||
model="gpt-4o", # Override model
|
||||
temperature=None # Use default temperature
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Response with custom model"
|
||||
|
||||
# Verify Azure OpenAI API was called with overridden model
|
||||
call_args = mock_azure_client.chat.completions.create.call_args
|
||||
assert call_args[1]['model'] == 'gpt-4o' # Should use runtime override
|
||||
assert call_args[1]['temperature'] == 0.1 # Should use processor default
|
||||
|
||||
@patch('trustgraph.model.text_completion.azure_openai.llm.AzureOpenAI')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_both_parameters_override(self, mock_llm_init, mock_async_init, mock_azure_openai_class):
|
||||
"""Test overriding both model and temperature parameters simultaneously"""
|
||||
# Arrange
|
||||
mock_azure_client = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices = [MagicMock()]
|
||||
mock_response.choices[0].message.content = 'Response with both overrides'
|
||||
mock_response.usage.prompt_tokens = 22
|
||||
mock_response.usage.completion_tokens = 16
|
||||
|
||||
mock_azure_client.chat.completions.create.return_value = mock_response
|
||||
mock_azure_openai_class.return_value = mock_azure_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'gpt-4', # Default model
|
||||
'endpoint': 'https://test.openai.azure.com/',
|
||||
'token': 'test-token',
|
||||
'api_version': '2024-12-01-preview',
|
||||
'temperature': 0.0, # Default temperature
|
||||
'max_output': 4192,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act - Override both parameters at runtime
|
||||
result = await processor.generate_content(
|
||||
"System prompt",
|
||||
"User prompt",
|
||||
model="gpt-4o-mini", # Override model
|
||||
temperature=0.9 # Override temperature
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Response with both overrides"
|
||||
|
||||
# Verify Azure OpenAI API was called with both overrides
|
||||
call_args = mock_azure_client.chat.completions.create.call_args
|
||||
assert call_args[1]['model'] == 'gpt-4o-mini' # Should use runtime override
|
||||
assert call_args[1]['temperature'] == 0.9 # Should use runtime override
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pytest.main([__file__])
|
||||
|
|
@ -43,7 +43,7 @@ class TestAzureProcessorSimple(IsolatedAsyncioTestCase):
|
|||
assert processor.token == 'test-token'
|
||||
assert processor.temperature == 0.0
|
||||
assert processor.max_output == 4192
|
||||
assert processor.model == 'AzureAI'
|
||||
assert processor.default_model == 'AzureAI'
|
||||
|
||||
@patch('trustgraph.model.text_completion.azure.llm.requests')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
|
|
@ -261,7 +261,7 @@ class TestAzureProcessorSimple(IsolatedAsyncioTestCase):
|
|||
assert processor.token == 'custom-token'
|
||||
assert processor.temperature == 0.7
|
||||
assert processor.max_output == 2048
|
||||
assert processor.model == 'AzureAI'
|
||||
assert processor.default_model == 'AzureAI'
|
||||
|
||||
@patch('trustgraph.model.text_completion.azure.llm.requests')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
|
|
@ -289,7 +289,7 @@ class TestAzureProcessorSimple(IsolatedAsyncioTestCase):
|
|||
assert processor.token == 'test-token'
|
||||
assert processor.temperature == 0.0 # default_temperature
|
||||
assert processor.max_output == 4192 # default_max_output
|
||||
assert processor.model == 'AzureAI' # default_model
|
||||
assert processor.default_model == 'AzureAI' # default_model
|
||||
|
||||
@patch('trustgraph.model.text_completion.azure.llm.requests')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
|
|
@ -459,5 +459,150 @@ class TestAzureProcessorSimple(IsolatedAsyncioTestCase):
|
|||
)
|
||||
|
||||
|
||||
@patch('trustgraph.model.text_completion.azure.llm.requests')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_with_model_override(self, mock_llm_init, mock_async_init, mock_requests):
|
||||
"""Test generate_content with model parameter override"""
|
||||
# Arrange
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {
|
||||
'choices': [{
|
||||
'message': {
|
||||
'content': 'Response with model override'
|
||||
}
|
||||
}],
|
||||
'usage': {
|
||||
'prompt_tokens': 15,
|
||||
'completion_tokens': 10
|
||||
}
|
||||
}
|
||||
mock_requests.post.return_value = mock_response
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'endpoint': 'https://test.inference.ai.azure.com/v1/chat/completions',
|
||||
'token': 'test-token',
|
||||
'temperature': 0.0,
|
||||
'max_output': 4192,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act - Override model
|
||||
result = await processor.generate_content("System", "Prompt", model="custom-azure-model")
|
||||
|
||||
# Assert
|
||||
assert result.model == "custom-azure-model" # Should use overridden model
|
||||
assert result.text == "Response with model override"
|
||||
|
||||
@patch('trustgraph.model.text_completion.azure.llm.requests')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_with_temperature_override(self, mock_llm_init, mock_async_init, mock_requests):
|
||||
"""Test generate_content with temperature parameter override"""
|
||||
# Arrange
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {
|
||||
'choices': [{
|
||||
'message': {
|
||||
'content': 'Response with temperature override'
|
||||
}
|
||||
}],
|
||||
'usage': {
|
||||
'prompt_tokens': 15,
|
||||
'completion_tokens': 10
|
||||
}
|
||||
}
|
||||
mock_requests.post.return_value = mock_response
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'endpoint': 'https://test.inference.ai.azure.com/v1/chat/completions',
|
||||
'token': 'test-token',
|
||||
'temperature': 0.0, # Default temperature
|
||||
'max_output': 4192,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act - Override temperature
|
||||
result = await processor.generate_content("System", "Prompt", temperature=0.8)
|
||||
|
||||
# Assert
|
||||
assert result.text == "Response with temperature override"
|
||||
|
||||
# Verify the request was made with the overridden temperature
|
||||
mock_requests.post.assert_called_once()
|
||||
call_args = mock_requests.post.call_args
|
||||
|
||||
import json
|
||||
request_body = json.loads(call_args[1]['data'])
|
||||
assert request_body['temperature'] == 0.8
|
||||
|
||||
@patch('trustgraph.model.text_completion.azure.llm.requests')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_with_both_parameters_override(self, mock_llm_init, mock_async_init, mock_requests):
|
||||
"""Test generate_content with both model and temperature overrides"""
|
||||
# Arrange
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {
|
||||
'choices': [{
|
||||
'message': {
|
||||
'content': 'Response with both parameters override'
|
||||
}
|
||||
}],
|
||||
'usage': {
|
||||
'prompt_tokens': 18,
|
||||
'completion_tokens': 12
|
||||
}
|
||||
}
|
||||
mock_requests.post.return_value = mock_response
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'endpoint': 'https://test.inference.ai.azure.com/v1/chat/completions',
|
||||
'token': 'test-token',
|
||||
'temperature': 0.0,
|
||||
'max_output': 4192,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act - Override both parameters
|
||||
result = await processor.generate_content("System", "Prompt", model="override-model", temperature=0.9)
|
||||
|
||||
# Assert
|
||||
assert result.model == "override-model"
|
||||
assert result.text == "Response with both parameters override"
|
||||
|
||||
# Verify the request was made with overridden temperature
|
||||
mock_requests.post.assert_called_once()
|
||||
call_args = mock_requests.post.call_args
|
||||
|
||||
import json
|
||||
request_body = json.loads(call_args[1]['data'])
|
||||
assert request_body['temperature'] == 0.9
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pytest.main([__file__])
|
||||
280
tests/unit/test_text_completion/test_bedrock_processor.py
Normal file
280
tests/unit/test_text_completion/test_bedrock_processor.py
Normal file
|
|
@ -0,0 +1,280 @@
|
|||
"""
|
||||
Unit tests for trustgraph.model.text_completion.bedrock
|
||||
Following the same successful pattern as other processor tests
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from unittest import IsolatedAsyncioTestCase
|
||||
import json
|
||||
|
||||
# Import the service under test
|
||||
from trustgraph.model.text_completion.bedrock.llm import Processor, Mistral, Anthropic
|
||||
from trustgraph.base import LlmResult
|
||||
|
||||
|
||||
class TestBedrockProcessorSimple(IsolatedAsyncioTestCase):
|
||||
"""Test Bedrock processor functionality"""
|
||||
|
||||
@patch('trustgraph.model.text_completion.bedrock.llm.boto3.Session')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_processor_initialization_basic(self, mock_llm_init, mock_async_init, mock_session_class):
|
||||
"""Test basic processor initialization"""
|
||||
# Arrange
|
||||
mock_session = MagicMock()
|
||||
mock_bedrock = MagicMock()
|
||||
mock_session.client.return_value = mock_bedrock
|
||||
mock_session_class.return_value = mock_session
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'mistral.mistral-large-2407-v1:0',
|
||||
'temperature': 0.1,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
# Act
|
||||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
assert processor.default_model == 'mistral.mistral-large-2407-v1:0'
|
||||
assert processor.temperature == 0.1
|
||||
assert hasattr(processor, 'bedrock')
|
||||
mock_session_class.assert_called_once()
|
||||
|
||||
@patch('trustgraph.model.text_completion.bedrock.llm.boto3.Session')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_success_mistral(self, mock_llm_init, mock_async_init, mock_session_class):
|
||||
"""Test successful content generation with Mistral model"""
|
||||
# Arrange
|
||||
mock_session = MagicMock()
|
||||
mock_bedrock = MagicMock()
|
||||
mock_session.client.return_value = mock_bedrock
|
||||
mock_session_class.return_value = mock_session
|
||||
|
||||
mock_response = {
|
||||
'body': MagicMock(),
|
||||
'ResponseMetadata': {
|
||||
'HTTPHeaders': {
|
||||
'x-amzn-bedrock-input-token-count': '15',
|
||||
'x-amzn-bedrock-output-token-count': '8'
|
||||
}
|
||||
}
|
||||
}
|
||||
mock_response['body'].read.return_value = json.dumps({
|
||||
'outputs': [{'text': 'Generated response from Bedrock'}]
|
||||
})
|
||||
mock_bedrock.invoke_model.return_value = mock_response
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'mistral.mistral-large-2407-v1:0',
|
||||
'temperature': 0.0,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act
|
||||
result = await processor.generate_content("System prompt", "User prompt")
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Generated response from Bedrock"
|
||||
assert result.in_token == 15
|
||||
assert result.out_token == 8
|
||||
assert result.model == 'mistral.mistral-large-2407-v1:0'
|
||||
mock_bedrock.invoke_model.assert_called_once()
|
||||
|
||||
@patch('trustgraph.model.text_completion.bedrock.llm.boto3.Session')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_temperature_override(self, mock_llm_init, mock_async_init, mock_session_class):
|
||||
"""Test temperature parameter override functionality"""
|
||||
# Arrange
|
||||
mock_session = MagicMock()
|
||||
mock_bedrock = MagicMock()
|
||||
mock_session.client.return_value = mock_bedrock
|
||||
mock_session_class.return_value = mock_session
|
||||
|
||||
mock_response = {
|
||||
'body': MagicMock(),
|
||||
'ResponseMetadata': {
|
||||
'HTTPHeaders': {
|
||||
'x-amzn-bedrock-input-token-count': '20',
|
||||
'x-amzn-bedrock-output-token-count': '12'
|
||||
}
|
||||
}
|
||||
}
|
||||
mock_response['body'].read.return_value = json.dumps({
|
||||
'outputs': [{'text': 'Response with custom temperature'}]
|
||||
})
|
||||
mock_bedrock.invoke_model.return_value = mock_response
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'mistral.mistral-large-2407-v1:0',
|
||||
'temperature': 0.0, # Default temperature
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act - Override temperature at runtime
|
||||
result = await processor.generate_content(
|
||||
"System prompt",
|
||||
"User prompt",
|
||||
model=None, # Use default model
|
||||
temperature=0.8 # Override temperature
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Response with custom temperature"
|
||||
|
||||
# Verify the model variant was created with overridden temperature
|
||||
# The cache key should include the temperature
|
||||
cache_key = f"mistral.mistral-large-2407-v1:0:0.8"
|
||||
assert cache_key in processor.model_variants
|
||||
variant = processor.model_variants[cache_key]
|
||||
assert variant.temperature == 0.8
|
||||
|
||||
@patch('trustgraph.model.text_completion.bedrock.llm.boto3.Session')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_model_override(self, mock_llm_init, mock_async_init, mock_session_class):
|
||||
"""Test model parameter override functionality"""
|
||||
# Arrange
|
||||
mock_session = MagicMock()
|
||||
mock_bedrock = MagicMock()
|
||||
mock_session.client.return_value = mock_bedrock
|
||||
mock_session_class.return_value = mock_session
|
||||
|
||||
mock_response = {
|
||||
'body': MagicMock(),
|
||||
'ResponseMetadata': {
|
||||
'HTTPHeaders': {
|
||||
'x-amzn-bedrock-input-token-count': '18',
|
||||
'x-amzn-bedrock-output-token-count': '14'
|
||||
}
|
||||
}
|
||||
}
|
||||
mock_response['body'].read.return_value = json.dumps({
|
||||
'content': [{'text': 'Response with custom model'}]
|
||||
})
|
||||
mock_bedrock.invoke_model.return_value = mock_response
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'mistral.mistral-large-2407-v1:0', # Default model
|
||||
'temperature': 0.1, # Default temperature
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act - Override model at runtime
|
||||
result = await processor.generate_content(
|
||||
"System prompt",
|
||||
"User prompt",
|
||||
model="anthropic.claude-3-sonnet-20240229-v1:0", # Override model
|
||||
temperature=None # Use default temperature
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Response with custom model"
|
||||
|
||||
# Verify Bedrock API was called with overridden model
|
||||
mock_bedrock.invoke_model.assert_called_once()
|
||||
call_args = mock_bedrock.invoke_model.call_args
|
||||
assert call_args[1]['modelId'] == "anthropic.claude-3-sonnet-20240229-v1:0"
|
||||
|
||||
# Verify the correct model variant (Anthropic) was used
|
||||
cache_key = f"anthropic.claude-3-sonnet-20240229-v1:0:0.1"
|
||||
assert cache_key in processor.model_variants
|
||||
variant = processor.model_variants[cache_key]
|
||||
assert isinstance(variant, Anthropic)
|
||||
|
||||
@patch('trustgraph.model.text_completion.bedrock.llm.boto3.Session')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_both_parameters_override(self, mock_llm_init, mock_async_init, mock_session_class):
|
||||
"""Test overriding both model and temperature parameters simultaneously"""
|
||||
# Arrange
|
||||
mock_session = MagicMock()
|
||||
mock_bedrock = MagicMock()
|
||||
mock_session.client.return_value = mock_bedrock
|
||||
mock_session_class.return_value = mock_session
|
||||
|
||||
mock_response = {
|
||||
'body': MagicMock(),
|
||||
'ResponseMetadata': {
|
||||
'HTTPHeaders': {
|
||||
'x-amzn-bedrock-input-token-count': '22',
|
||||
'x-amzn-bedrock-output-token-count': '16'
|
||||
}
|
||||
}
|
||||
}
|
||||
mock_response['body'].read.return_value = json.dumps({
|
||||
'generation': 'Response with both overrides'
|
||||
})
|
||||
mock_bedrock.invoke_model.return_value = mock_response
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'mistral.mistral-large-2407-v1:0', # Default model
|
||||
'temperature': 0.0, # Default temperature
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act - Override both parameters at runtime
|
||||
result = await processor.generate_content(
|
||||
"System prompt",
|
||||
"User prompt",
|
||||
model="meta.llama3-70b-instruct-v1:0", # Override model (Meta/Llama)
|
||||
temperature=0.9 # Override temperature
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Response with both overrides"
|
||||
|
||||
# Verify Bedrock API was called with both overrides
|
||||
mock_bedrock.invoke_model.assert_called_once()
|
||||
call_args = mock_bedrock.invoke_model.call_args
|
||||
assert call_args[1]['modelId'] == "meta.llama3-70b-instruct-v1:0"
|
||||
|
||||
# Verify the correct model variant (Meta) was used with correct temperature
|
||||
cache_key = f"meta.llama3-70b-instruct-v1:0:0.9"
|
||||
assert cache_key in processor.model_variants
|
||||
variant = processor.model_variants[cache_key]
|
||||
assert variant.temperature == 0.9
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pytest.main([__file__])
|
||||
|
|
@ -42,7 +42,7 @@ class TestClaudeProcessorSimple(IsolatedAsyncioTestCase):
|
|||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
assert processor.model == 'claude-3-5-sonnet-20240620'
|
||||
assert processor.default_model == 'claude-3-5-sonnet-20240620'
|
||||
assert processor.temperature == 0.0
|
||||
assert processor.max_output == 8192
|
||||
assert hasattr(processor, 'claude')
|
||||
|
|
@ -217,7 +217,7 @@ class TestClaudeProcessorSimple(IsolatedAsyncioTestCase):
|
|||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
assert processor.model == 'claude-3-haiku-20240307'
|
||||
assert processor.default_model == 'claude-3-haiku-20240307'
|
||||
assert processor.temperature == 0.7
|
||||
assert processor.max_output == 4096
|
||||
mock_anthropic_class.assert_called_once_with(api_key='custom-api-key')
|
||||
|
|
@ -246,7 +246,7 @@ class TestClaudeProcessorSimple(IsolatedAsyncioTestCase):
|
|||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
assert processor.model == 'claude-3-5-sonnet-20240620' # default_model
|
||||
assert processor.default_model == 'claude-3-5-sonnet-20240620' # default_model
|
||||
assert processor.temperature == 0.0 # default_temperature
|
||||
assert processor.max_output == 8192 # default_max_output
|
||||
mock_anthropic_class.assert_called_once_with(api_key='test-api-key')
|
||||
|
|
@ -433,7 +433,157 @@ class TestClaudeProcessorSimple(IsolatedAsyncioTestCase):
|
|||
|
||||
# Verify processor has the client
|
||||
assert processor.claude == mock_claude_client
|
||||
assert processor.model == 'claude-3-opus-20240229'
|
||||
assert processor.default_model == 'claude-3-opus-20240229'
|
||||
|
||||
@patch('trustgraph.model.text_completion.claude.llm.anthropic.Anthropic')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_temperature_override(self, mock_llm_init, mock_async_init, mock_anthropic_class):
|
||||
"""Test temperature parameter override functionality"""
|
||||
# Arrange
|
||||
mock_claude_client = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.content = [MagicMock()]
|
||||
mock_response.content[0].text = "Response with custom temperature"
|
||||
mock_response.usage.input_tokens = 20
|
||||
mock_response.usage.output_tokens = 12
|
||||
|
||||
mock_claude_client.messages.create.return_value = mock_response
|
||||
mock_anthropic_class.return_value = mock_claude_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'claude-3-5-sonnet-20240620',
|
||||
'api_key': 'test-api-key',
|
||||
'temperature': 0.0, # Default temperature
|
||||
'max_output': 8192,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act - Override temperature at runtime
|
||||
result = await processor.generate_content(
|
||||
"System prompt",
|
||||
"User prompt",
|
||||
model=None, # Use default model
|
||||
temperature=0.9 # Override temperature
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Response with custom temperature"
|
||||
|
||||
# Verify Claude API was called with overridden temperature
|
||||
mock_claude_client.messages.create.assert_called_once()
|
||||
call_kwargs = mock_claude_client.messages.create.call_args.kwargs
|
||||
|
||||
assert call_kwargs['temperature'] == 0.9 # Should use runtime override
|
||||
assert call_kwargs['model'] == 'claude-3-5-sonnet-20240620' # Should use processor default
|
||||
|
||||
@patch('trustgraph.model.text_completion.claude.llm.anthropic.Anthropic')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_model_override(self, mock_llm_init, mock_async_init, mock_anthropic_class):
|
||||
"""Test model parameter override functionality"""
|
||||
# Arrange
|
||||
mock_claude_client = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.content = [MagicMock()]
|
||||
mock_response.content[0].text = "Response with custom model"
|
||||
mock_response.usage.input_tokens = 18
|
||||
mock_response.usage.output_tokens = 14
|
||||
|
||||
mock_claude_client.messages.create.return_value = mock_response
|
||||
mock_anthropic_class.return_value = mock_claude_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'claude-3-5-sonnet-20240620', # Default model
|
||||
'api_key': 'test-api-key',
|
||||
'temperature': 0.2, # Default temperature
|
||||
'max_output': 8192,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act - Override model at runtime
|
||||
result = await processor.generate_content(
|
||||
"System prompt",
|
||||
"User prompt",
|
||||
model="claude-3-haiku-20240307", # Override model
|
||||
temperature=None # Use default temperature
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Response with custom model"
|
||||
|
||||
# Verify Claude API was called with overridden model
|
||||
mock_claude_client.messages.create.assert_called_once()
|
||||
call_kwargs = mock_claude_client.messages.create.call_args.kwargs
|
||||
|
||||
assert call_kwargs['model'] == 'claude-3-haiku-20240307' # Should use runtime override
|
||||
assert call_kwargs['temperature'] == 0.2 # Should use processor default
|
||||
|
||||
@patch('trustgraph.model.text_completion.claude.llm.anthropic.Anthropic')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_both_parameters_override(self, mock_llm_init, mock_async_init, mock_anthropic_class):
|
||||
"""Test overriding both model and temperature parameters simultaneously"""
|
||||
# Arrange
|
||||
mock_claude_client = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.content = [MagicMock()]
|
||||
mock_response.content[0].text = "Response with both overrides"
|
||||
mock_response.usage.input_tokens = 22
|
||||
mock_response.usage.output_tokens = 16
|
||||
|
||||
mock_claude_client.messages.create.return_value = mock_response
|
||||
mock_anthropic_class.return_value = mock_claude_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'claude-3-5-sonnet-20240620', # Default model
|
||||
'api_key': 'test-api-key',
|
||||
'temperature': 0.0, # Default temperature
|
||||
'max_output': 8192,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act - Override both parameters at runtime
|
||||
result = await processor.generate_content(
|
||||
"System prompt",
|
||||
"User prompt",
|
||||
model="claude-3-opus-20240229", # Override model
|
||||
temperature=0.8 # Override temperature
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Response with both overrides"
|
||||
|
||||
# Verify Claude API was called with both overrides
|
||||
mock_claude_client.messages.create.assert_called_once()
|
||||
call_kwargs = mock_claude_client.messages.create.call_args.kwargs
|
||||
|
||||
assert call_kwargs['model'] == 'claude-3-opus-20240229' # Should use runtime override
|
||||
assert call_kwargs['temperature'] == 0.8 # Should use runtime override
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
|
|||
|
|
@ -41,7 +41,7 @@ class TestCohereProcessorSimple(IsolatedAsyncioTestCase):
|
|||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
assert processor.model == 'c4ai-aya-23-8b'
|
||||
assert processor.default_model == 'c4ai-aya-23-8b'
|
||||
assert processor.temperature == 0.0
|
||||
assert hasattr(processor, 'cohere')
|
||||
mock_cohere_class.assert_called_once_with(api_key='test-api-key')
|
||||
|
|
@ -201,7 +201,7 @@ class TestCohereProcessorSimple(IsolatedAsyncioTestCase):
|
|||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
assert processor.model == 'command-light'
|
||||
assert processor.default_model == 'command-light'
|
||||
assert processor.temperature == 0.7
|
||||
mock_cohere_class.assert_called_once_with(api_key='custom-api-key')
|
||||
|
||||
|
|
@ -229,7 +229,7 @@ class TestCohereProcessorSimple(IsolatedAsyncioTestCase):
|
|||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
assert processor.model == 'c4ai-aya-23-8b' # default_model
|
||||
assert processor.default_model == 'c4ai-aya-23-8b' # default_model
|
||||
assert processor.temperature == 0.0 # default_temperature
|
||||
mock_cohere_class.assert_called_once_with(api_key='test-api-key')
|
||||
|
||||
|
|
@ -395,7 +395,7 @@ class TestCohereProcessorSimple(IsolatedAsyncioTestCase):
|
|||
|
||||
# Verify processor has the client
|
||||
assert processor.cohere == mock_cohere_client
|
||||
assert processor.model == 'command-r'
|
||||
assert processor.default_model == 'command-r'
|
||||
|
||||
@patch('trustgraph.model.text_completion.cohere.llm.cohere.Client')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
|
|
@ -442,6 +442,162 @@ class TestCohereProcessorSimple(IsolatedAsyncioTestCase):
|
|||
assert call_args[1]['prompt_truncation'] == 'auto'
|
||||
assert call_args[1]['connectors'] == []
|
||||
|
||||
@patch('trustgraph.model.text_completion.cohere.llm.cohere.Client')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_temperature_override(self, mock_llm_init, mock_async_init, mock_cohere_class):
|
||||
"""Test temperature parameter override functionality"""
|
||||
# Arrange
|
||||
mock_cohere_client = MagicMock()
|
||||
mock_output = MagicMock()
|
||||
mock_output.text = 'Response with custom temperature'
|
||||
mock_output.meta.billed_units.input_tokens = 20
|
||||
mock_output.meta.billed_units.output_tokens = 12
|
||||
|
||||
mock_cohere_client.chat.return_value = mock_output
|
||||
mock_cohere_class.return_value = mock_cohere_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'c4ai-aya-23-8b',
|
||||
'api_key': 'test-api-key',
|
||||
'temperature': 0.0, # Default temperature
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act - Override temperature at runtime
|
||||
result = await processor.generate_content(
|
||||
"System prompt",
|
||||
"User prompt",
|
||||
model=None, # Use default model
|
||||
temperature=0.8 # Override temperature
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Response with custom temperature"
|
||||
|
||||
# Verify Cohere API was called with overridden temperature
|
||||
mock_cohere_client.chat.assert_called_once_with(
|
||||
model='c4ai-aya-23-8b',
|
||||
message='User prompt',
|
||||
preamble='System prompt',
|
||||
temperature=0.8, # Should use runtime override
|
||||
chat_history=[],
|
||||
prompt_truncation='auto',
|
||||
connectors=[]
|
||||
)
|
||||
|
||||
@patch('trustgraph.model.text_completion.cohere.llm.cohere.Client')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_model_override(self, mock_llm_init, mock_async_init, mock_cohere_class):
|
||||
"""Test model parameter override functionality"""
|
||||
# Arrange
|
||||
mock_cohere_client = MagicMock()
|
||||
mock_output = MagicMock()
|
||||
mock_output.text = 'Response with custom model'
|
||||
mock_output.meta.billed_units.input_tokens = 18
|
||||
mock_output.meta.billed_units.output_tokens = 14
|
||||
|
||||
mock_cohere_client.chat.return_value = mock_output
|
||||
mock_cohere_class.return_value = mock_cohere_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'c4ai-aya-23-8b', # Default model
|
||||
'api_key': 'test-api-key',
|
||||
'temperature': 0.1, # Default temperature
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act - Override model at runtime
|
||||
result = await processor.generate_content(
|
||||
"System prompt",
|
||||
"User prompt",
|
||||
model="command-r-plus", # Override model
|
||||
temperature=None # Use default temperature
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Response with custom model"
|
||||
|
||||
# Verify Cohere API was called with overridden model
|
||||
mock_cohere_client.chat.assert_called_once_with(
|
||||
model='command-r-plus', # Should use runtime override
|
||||
message='User prompt',
|
||||
preamble='System prompt',
|
||||
temperature=0.1, # Should use processor default
|
||||
chat_history=[],
|
||||
prompt_truncation='auto',
|
||||
connectors=[]
|
||||
)
|
||||
|
||||
@patch('trustgraph.model.text_completion.cohere.llm.cohere.Client')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_both_parameters_override(self, mock_llm_init, mock_async_init, mock_cohere_class):
|
||||
"""Test overriding both model and temperature parameters simultaneously"""
|
||||
# Arrange
|
||||
mock_cohere_client = MagicMock()
|
||||
mock_output = MagicMock()
|
||||
mock_output.text = 'Response with both overrides'
|
||||
mock_output.meta.billed_units.input_tokens = 22
|
||||
mock_output.meta.billed_units.output_tokens = 16
|
||||
|
||||
mock_cohere_client.chat.return_value = mock_output
|
||||
mock_cohere_class.return_value = mock_cohere_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'c4ai-aya-23-8b', # Default model
|
||||
'api_key': 'test-api-key',
|
||||
'temperature': 0.0, # Default temperature
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act - Override both parameters at runtime
|
||||
result = await processor.generate_content(
|
||||
"System prompt",
|
||||
"User prompt",
|
||||
model="command-r", # Override model
|
||||
temperature=0.9 # Override temperature
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Response with both overrides"
|
||||
|
||||
# Verify Cohere API was called with both overrides
|
||||
mock_cohere_client.chat.assert_called_once_with(
|
||||
model='command-r', # Should use runtime override
|
||||
message='User prompt',
|
||||
preamble='System prompt',
|
||||
temperature=0.9, # Should use runtime override
|
||||
chat_history=[],
|
||||
prompt_truncation='auto',
|
||||
connectors=[]
|
||||
)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pytest.main([__file__])
|
||||
|
|
@ -42,7 +42,7 @@ class TestGoogleAIStudioProcessorSimple(IsolatedAsyncioTestCase):
|
|||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
assert processor.model == 'gemini-2.0-flash-001'
|
||||
assert processor.default_model == 'gemini-2.0-flash-001'
|
||||
assert processor.temperature == 0.0
|
||||
assert processor.max_output == 8192
|
||||
assert hasattr(processor, 'client')
|
||||
|
|
@ -205,7 +205,7 @@ class TestGoogleAIStudioProcessorSimple(IsolatedAsyncioTestCase):
|
|||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
assert processor.model == 'gemini-1.5-pro'
|
||||
assert processor.default_model == 'gemini-1.5-pro'
|
||||
assert processor.temperature == 0.7
|
||||
assert processor.max_output == 4096
|
||||
mock_genai_class.assert_called_once_with(api_key='custom-api-key')
|
||||
|
|
@ -234,7 +234,7 @@ class TestGoogleAIStudioProcessorSimple(IsolatedAsyncioTestCase):
|
|||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
assert processor.model == 'gemini-2.0-flash-001' # default_model
|
||||
assert processor.default_model == 'gemini-2.0-flash-001' # default_model
|
||||
assert processor.temperature == 0.0 # default_temperature
|
||||
assert processor.max_output == 8192 # default_max_output
|
||||
mock_genai_class.assert_called_once_with(api_key='test-api-key')
|
||||
|
|
@ -431,7 +431,7 @@ class TestGoogleAIStudioProcessorSimple(IsolatedAsyncioTestCase):
|
|||
|
||||
# Verify processor has the client
|
||||
assert processor.client == mock_genai_client
|
||||
assert processor.model == 'gemini-1.5-flash'
|
||||
assert processor.default_model == 'gemini-1.5-flash'
|
||||
|
||||
@patch('trustgraph.model.text_completion.googleaistudio.llm.genai.Client')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
|
|
@ -477,6 +477,156 @@ class TestGoogleAIStudioProcessorSimple(IsolatedAsyncioTestCase):
|
|||
# The system instruction should be in the config object
|
||||
assert call_args[1]['contents'] == "Explain quantum computing"
|
||||
|
||||
@patch('trustgraph.model.text_completion.googleaistudio.llm.genai.Client')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_temperature_override(self, mock_llm_init, mock_async_init, mock_genai_class):
|
||||
"""Test temperature parameter override functionality"""
|
||||
# Arrange
|
||||
mock_genai_client = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.text = 'Response with custom temperature'
|
||||
mock_response.usage_metadata.prompt_token_count = 20
|
||||
mock_response.usage_metadata.candidates_token_count = 12
|
||||
|
||||
mock_genai_client.models.generate_content.return_value = mock_response
|
||||
mock_genai_class.return_value = mock_genai_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'gemini-2.0-flash-001',
|
||||
'api_key': 'test-api-key',
|
||||
'temperature': 0.0, # Default temperature
|
||||
'max_output': 8192,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act - Override temperature at runtime
|
||||
result = await processor.generate_content(
|
||||
"System prompt",
|
||||
"User prompt",
|
||||
model=None, # Use default model
|
||||
temperature=0.8 # Override temperature
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Response with custom temperature"
|
||||
|
||||
# Verify the generation config was created with overridden temperature
|
||||
cache_key = f"gemini-2.0-flash-001:0.8"
|
||||
assert cache_key in processor.generation_configs
|
||||
config_obj = processor.generation_configs[cache_key]
|
||||
assert config_obj.temperature == 0.8
|
||||
|
||||
@patch('trustgraph.model.text_completion.googleaistudio.llm.genai.Client')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_model_override(self, mock_llm_init, mock_async_init, mock_genai_class):
|
||||
"""Test model parameter override functionality"""
|
||||
# Arrange
|
||||
mock_genai_client = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.text = 'Response with custom model'
|
||||
mock_response.usage_metadata.prompt_token_count = 18
|
||||
mock_response.usage_metadata.candidates_token_count = 14
|
||||
|
||||
mock_genai_client.models.generate_content.return_value = mock_response
|
||||
mock_genai_class.return_value = mock_genai_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'gemini-2.0-flash-001', # Default model
|
||||
'api_key': 'test-api-key',
|
||||
'temperature': 0.1, # Default temperature
|
||||
'max_output': 8192,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act - Override model at runtime
|
||||
result = await processor.generate_content(
|
||||
"System prompt",
|
||||
"User prompt",
|
||||
model="gemini-1.5-pro", # Override model
|
||||
temperature=None # Use default temperature
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Response with custom model"
|
||||
|
||||
# Verify Google AI Studio API was called with overridden model
|
||||
call_args = mock_genai_client.models.generate_content.call_args
|
||||
assert call_args[1]['model'] == 'gemini-1.5-pro' # Should use runtime override
|
||||
|
||||
# Verify the generation config was created for the correct model
|
||||
cache_key = f"gemini-1.5-pro:0.1"
|
||||
assert cache_key in processor.generation_configs
|
||||
|
||||
@patch('trustgraph.model.text_completion.googleaistudio.llm.genai.Client')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_both_parameters_override(self, mock_llm_init, mock_async_init, mock_genai_class):
|
||||
"""Test overriding both model and temperature parameters simultaneously"""
|
||||
# Arrange
|
||||
mock_genai_client = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.text = 'Response with both overrides'
|
||||
mock_response.usage_metadata.prompt_token_count = 22
|
||||
mock_response.usage_metadata.candidates_token_count = 16
|
||||
|
||||
mock_genai_client.models.generate_content.return_value = mock_response
|
||||
mock_genai_class.return_value = mock_genai_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'gemini-2.0-flash-001', # Default model
|
||||
'api_key': 'test-api-key',
|
||||
'temperature': 0.0, # Default temperature
|
||||
'max_output': 8192,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act - Override both parameters at runtime
|
||||
result = await processor.generate_content(
|
||||
"System prompt",
|
||||
"User prompt",
|
||||
model="gemini-1.5-flash", # Override model
|
||||
temperature=0.9 # Override temperature
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Response with both overrides"
|
||||
|
||||
# Verify Google AI Studio API was called with both overrides
|
||||
call_args = mock_genai_client.models.generate_content.call_args
|
||||
assert call_args[1]['model'] == 'gemini-1.5-flash' # Should use runtime override
|
||||
|
||||
# Verify the generation config was created with both overrides
|
||||
cache_key = f"gemini-1.5-flash:0.9"
|
||||
assert cache_key in processor.generation_configs
|
||||
config_obj = processor.generation_configs[cache_key]
|
||||
assert config_obj.temperature == 0.9
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pytest.main([__file__])
|
||||
|
|
@ -42,7 +42,7 @@ class TestLlamaFileProcessorSimple(IsolatedAsyncioTestCase):
|
|||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
assert processor.model == 'LLaMA_CPP'
|
||||
assert processor.default_model == 'LLaMA_CPP'
|
||||
assert processor.llamafile == 'http://localhost:8080/v1'
|
||||
assert processor.temperature == 0.0
|
||||
assert processor.max_output == 4096
|
||||
|
|
@ -91,7 +91,7 @@ class TestLlamaFileProcessorSimple(IsolatedAsyncioTestCase):
|
|||
assert result.text == "Generated response from LlamaFile"
|
||||
assert result.in_token == 20
|
||||
assert result.out_token == 12
|
||||
assert result.model == 'llama.cpp' # Note: model in result is hardcoded to 'llama.cpp'
|
||||
assert result.model == 'LLaMA_CPP' # Uses the default model name
|
||||
|
||||
# Verify the OpenAI API call structure
|
||||
mock_openai_client.chat.completions.create.assert_called_once_with(
|
||||
|
|
@ -99,7 +99,15 @@ class TestLlamaFileProcessorSimple(IsolatedAsyncioTestCase):
|
|||
messages=[{
|
||||
"role": "user",
|
||||
"content": "System prompt\n\nUser prompt"
|
||||
}]
|
||||
}],
|
||||
temperature=0.0,
|
||||
max_tokens=4096,
|
||||
top_p=1,
|
||||
frequency_penalty=0,
|
||||
presence_penalty=0,
|
||||
response_format={
|
||||
"type": "text"
|
||||
}
|
||||
)
|
||||
|
||||
@patch('trustgraph.model.text_completion.llamafile.llm.OpenAI')
|
||||
|
|
@ -157,7 +165,7 @@ class TestLlamaFileProcessorSimple(IsolatedAsyncioTestCase):
|
|||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
assert processor.model == 'custom-llama'
|
||||
assert processor.default_model == 'custom-llama'
|
||||
assert processor.llamafile == 'http://custom-host:8080/v1'
|
||||
assert processor.temperature == 0.7
|
||||
assert processor.max_output == 2048
|
||||
|
|
@ -189,7 +197,7 @@ class TestLlamaFileProcessorSimple(IsolatedAsyncioTestCase):
|
|||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
assert processor.model == 'LLaMA_CPP' # default_model
|
||||
assert processor.default_model == 'LLaMA_CPP' # default_model
|
||||
assert processor.llamafile == 'http://localhost:8080/v1' # default_llamafile
|
||||
assert processor.temperature == 0.0 # default_temperature
|
||||
assert processor.max_output == 4096 # default_max_output
|
||||
|
|
@ -237,7 +245,7 @@ class TestLlamaFileProcessorSimple(IsolatedAsyncioTestCase):
|
|||
assert result.text == "Default response"
|
||||
assert result.in_token == 2
|
||||
assert result.out_token == 3
|
||||
assert result.model == 'llama.cpp'
|
||||
assert result.model == 'LLaMA_CPP'
|
||||
|
||||
# Verify the combined prompt is sent correctly
|
||||
call_args = mock_openai_client.chat.completions.create.call_args
|
||||
|
|
@ -408,8 +416,8 @@ class TestLlamaFileProcessorSimple(IsolatedAsyncioTestCase):
|
|||
result = await processor.generate_content("System", "User")
|
||||
|
||||
# Assert
|
||||
assert result.model == 'llama.cpp' # Should always be 'llama.cpp', not 'custom-model-name'
|
||||
assert processor.model == 'custom-model-name' # But processor.model should still be custom
|
||||
assert result.model == 'custom-model-name' # Uses the actual model name passed to generate_content
|
||||
assert processor.default_model == 'custom-model-name' # But processor.model should still be custom
|
||||
|
||||
@patch('trustgraph.model.text_completion.llamafile.llm.OpenAI')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
|
|
@ -450,5 +458,132 @@ class TestLlamaFileProcessorSimple(IsolatedAsyncioTestCase):
|
|||
# No specific rate limit error handling tested since SLM presumably has no rate limits
|
||||
|
||||
|
||||
@patch('trustgraph.model.text_completion.llamafile.llm.OpenAI')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_with_model_override(self, mock_llm_init, mock_async_init, mock_openai_class):
|
||||
"""Test generate_content with model parameter override"""
|
||||
# Arrange
|
||||
mock_openai_client = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices = [MagicMock()]
|
||||
mock_response.choices[0].message.content = "Response from overridden model"
|
||||
mock_response.usage.prompt_tokens = 15
|
||||
mock_response.usage.completion_tokens = 10
|
||||
|
||||
mock_openai_client.chat.completions.create.return_value = mock_response
|
||||
mock_openai_class.return_value = mock_openai_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'LLaMA_CPP',
|
||||
'llamafile': 'http://localhost:8080/v1',
|
||||
'temperature': 0.0,
|
||||
'max_output': 4096,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act - Override model
|
||||
result = await processor.generate_content("System", "Prompt", model="custom-llamafile-model")
|
||||
|
||||
# Assert
|
||||
assert result.model == "custom-llamafile-model" # Should use overridden model
|
||||
assert result.text == "Response from overridden model"
|
||||
|
||||
# Verify the API call was made with overridden model
|
||||
call_args = mock_openai_client.chat.completions.create.call_args
|
||||
assert call_args[1]['model'] == "custom-llamafile-model"
|
||||
|
||||
@patch('trustgraph.model.text_completion.llamafile.llm.OpenAI')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_with_temperature_override(self, mock_llm_init, mock_async_init, mock_openai_class):
|
||||
"""Test generate_content with temperature parameter override"""
|
||||
# Arrange
|
||||
mock_openai_client = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices = [MagicMock()]
|
||||
mock_response.choices[0].message.content = "Response with temperature override"
|
||||
mock_response.usage.prompt_tokens = 18
|
||||
mock_response.usage.completion_tokens = 12
|
||||
|
||||
mock_openai_client.chat.completions.create.return_value = mock_response
|
||||
mock_openai_class.return_value = mock_openai_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'LLaMA_CPP',
|
||||
'llamafile': 'http://localhost:8080/v1',
|
||||
'temperature': 0.0, # Default temperature
|
||||
'max_output': 4096,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act - Override temperature
|
||||
result = await processor.generate_content("System", "Prompt", temperature=0.7)
|
||||
|
||||
# Assert
|
||||
assert result.text == "Response with temperature override"
|
||||
|
||||
# Verify the API call was made with overridden temperature
|
||||
call_args = mock_openai_client.chat.completions.create.call_args
|
||||
assert call_args[1]['temperature'] == 0.7
|
||||
|
||||
@patch('trustgraph.model.text_completion.llamafile.llm.OpenAI')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_with_both_parameters_override(self, mock_llm_init, mock_async_init, mock_openai_class):
|
||||
"""Test generate_content with both model and temperature overrides"""
|
||||
# Arrange
|
||||
mock_openai_client = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices = [MagicMock()]
|
||||
mock_response.choices[0].message.content = "Response with both parameters override"
|
||||
mock_response.usage.prompt_tokens = 20
|
||||
mock_response.usage.completion_tokens = 15
|
||||
|
||||
mock_openai_client.chat.completions.create.return_value = mock_response
|
||||
mock_openai_class.return_value = mock_openai_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'LLaMA_CPP',
|
||||
'llamafile': 'http://localhost:8080/v1',
|
||||
'temperature': 0.0,
|
||||
'max_output': 4096,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act - Override both parameters
|
||||
result = await processor.generate_content("System", "Prompt", model="override-model", temperature=0.8)
|
||||
|
||||
# Assert
|
||||
assert result.model == "override-model"
|
||||
assert result.text == "Response with both parameters override"
|
||||
|
||||
# Verify the API call was made with overridden parameters
|
||||
call_args = mock_openai_client.chat.completions.create.call_args
|
||||
assert call_args[1]['model'] == "override-model"
|
||||
assert call_args[1]['temperature'] == 0.8
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pytest.main([__file__])
|
||||
229
tests/unit/test_text_completion/test_lmstudio_processor.py
Normal file
229
tests/unit/test_text_completion/test_lmstudio_processor.py
Normal file
|
|
@ -0,0 +1,229 @@
|
|||
"""
|
||||
Unit tests for trustgraph.model.text_completion.lmstudio
|
||||
Following the same successful pattern as previous tests
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from unittest import IsolatedAsyncioTestCase
|
||||
|
||||
# Import the service under test
|
||||
from trustgraph.model.text_completion.lmstudio.llm import Processor
|
||||
from trustgraph.base import LlmResult
|
||||
from trustgraph.exceptions import TooManyRequests
|
||||
|
||||
|
||||
class TestLMStudioProcessorSimple(IsolatedAsyncioTestCase):
|
||||
"""Test LMStudio processor functionality"""
|
||||
|
||||
@patch('trustgraph.model.text_completion.lmstudio.llm.OpenAI')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_processor_initialization_basic(self, mock_llm_init, mock_async_init, mock_openai_class):
|
||||
"""Test basic processor initialization"""
|
||||
# Arrange
|
||||
mock_openai = MagicMock()
|
||||
mock_openai_class.return_value = mock_openai
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'gemma3:9b',
|
||||
'url': 'http://localhost:1234/',
|
||||
'temperature': 0.0,
|
||||
'max_output': 4096,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
# Act
|
||||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
assert processor.default_model == 'gemma3:9b'
|
||||
assert processor.url == 'http://localhost:1234/v1/'
|
||||
assert processor.temperature == 0.0
|
||||
assert processor.max_output == 4096
|
||||
assert hasattr(processor, 'openai')
|
||||
mock_openai_class.assert_called_once_with(
|
||||
base_url='http://localhost:1234/v1/',
|
||||
api_key='sk-no-key-required'
|
||||
)
|
||||
|
||||
@patch('trustgraph.model.text_completion.lmstudio.llm.OpenAI')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_success(self, mock_llm_init, mock_async_init, mock_openai_class):
|
||||
"""Test successful content generation"""
|
||||
# Arrange
|
||||
mock_openai = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices[0].message.content = 'Generated response from LMStudio'
|
||||
mock_response.usage.prompt_tokens = 20
|
||||
mock_response.usage.completion_tokens = 12
|
||||
|
||||
mock_openai.chat.completions.create.return_value = mock_response
|
||||
mock_openai_class.return_value = mock_openai
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'gemma3:9b',
|
||||
'url': 'http://localhost:1234/',
|
||||
'temperature': 0.0,
|
||||
'max_output': 4096,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act
|
||||
result = await processor.generate_content("System prompt", "User prompt")
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Generated response from LMStudio"
|
||||
assert result.in_token == 20
|
||||
assert result.out_token == 12
|
||||
assert result.model == 'gemma3:9b'
|
||||
|
||||
# Verify the API call was made correctly
|
||||
mock_openai.chat.completions.create.assert_called_once()
|
||||
call_args = mock_openai.chat.completions.create.call_args
|
||||
|
||||
# Check model and temperature
|
||||
assert call_args[1]['model'] == 'gemma3:9b'
|
||||
assert call_args[1]['temperature'] == 0.0
|
||||
assert call_args[1]['max_tokens'] == 4096
|
||||
|
||||
@patch('trustgraph.model.text_completion.lmstudio.llm.OpenAI')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_with_model_override(self, mock_llm_init, mock_async_init, mock_openai_class):
|
||||
"""Test generate_content with model parameter override"""
|
||||
# Arrange
|
||||
mock_openai = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices[0].message.content = 'Response from overridden model'
|
||||
mock_response.usage.prompt_tokens = 15
|
||||
mock_response.usage.completion_tokens = 10
|
||||
|
||||
mock_openai.chat.completions.create.return_value = mock_response
|
||||
mock_openai_class.return_value = mock_openai
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'gemma3:9b',
|
||||
'url': 'http://localhost:1234/',
|
||||
'temperature': 0.0,
|
||||
'max_output': 4096,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act - Override model
|
||||
result = await processor.generate_content("System", "Prompt", model="custom-lmstudio-model")
|
||||
|
||||
# Assert
|
||||
assert result.model == "custom-lmstudio-model" # Should use overridden model
|
||||
assert result.text == "Response from overridden model"
|
||||
|
||||
# Verify the API call was made with overridden model
|
||||
call_args = mock_openai.chat.completions.create.call_args
|
||||
assert call_args[1]['model'] == "custom-lmstudio-model"
|
||||
|
||||
@patch('trustgraph.model.text_completion.lmstudio.llm.OpenAI')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_with_temperature_override(self, mock_llm_init, mock_async_init, mock_openai_class):
|
||||
"""Test generate_content with temperature parameter override"""
|
||||
# Arrange
|
||||
mock_openai = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices[0].message.content = 'Response with temperature override'
|
||||
mock_response.usage.prompt_tokens = 18
|
||||
mock_response.usage.completion_tokens = 12
|
||||
|
||||
mock_openai.chat.completions.create.return_value = mock_response
|
||||
mock_openai_class.return_value = mock_openai
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'gemma3:9b',
|
||||
'url': 'http://localhost:1234/',
|
||||
'temperature': 0.0, # Default temperature
|
||||
'max_output': 4096,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act - Override temperature
|
||||
result = await processor.generate_content("System", "Prompt", temperature=0.7)
|
||||
|
||||
# Assert
|
||||
assert result.text == "Response with temperature override"
|
||||
|
||||
# Verify the API call was made with overridden temperature
|
||||
call_args = mock_openai.chat.completions.create.call_args
|
||||
assert call_args[1]['temperature'] == 0.7
|
||||
|
||||
@patch('trustgraph.model.text_completion.lmstudio.llm.OpenAI')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_with_both_parameters_override(self, mock_llm_init, mock_async_init, mock_openai_class):
|
||||
"""Test generate_content with both model and temperature overrides"""
|
||||
# Arrange
|
||||
mock_openai = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices[0].message.content = 'Response with both parameters override'
|
||||
mock_response.usage.prompt_tokens = 20
|
||||
mock_response.usage.completion_tokens = 15
|
||||
|
||||
mock_openai.chat.completions.create.return_value = mock_response
|
||||
mock_openai_class.return_value = mock_openai
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'gemma3:9b',
|
||||
'url': 'http://localhost:1234/',
|
||||
'temperature': 0.0,
|
||||
'max_output': 4096,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act - Override both parameters
|
||||
result = await processor.generate_content("System", "Prompt", model="override-model", temperature=0.8)
|
||||
|
||||
# Assert
|
||||
assert result.model == "override-model"
|
||||
assert result.text == "Response with both parameters override"
|
||||
|
||||
# Verify the API call was made with overridden parameters
|
||||
call_args = mock_openai.chat.completions.create.call_args
|
||||
assert call_args[1]['model'] == "override-model"
|
||||
assert call_args[1]['temperature'] == 0.8
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pytest.main([__file__])
|
||||
275
tests/unit/test_text_completion/test_mistral_processor.py
Normal file
275
tests/unit/test_text_completion/test_mistral_processor.py
Normal file
|
|
@ -0,0 +1,275 @@
|
|||
"""
|
||||
Unit tests for trustgraph.model.text_completion.mistral
|
||||
Following the same successful pattern as other processor tests
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from unittest import IsolatedAsyncioTestCase
|
||||
|
||||
# Import the service under test
|
||||
from trustgraph.model.text_completion.mistral.llm import Processor
|
||||
from trustgraph.base import LlmResult
|
||||
|
||||
|
||||
class TestMistralProcessorSimple(IsolatedAsyncioTestCase):
|
||||
"""Test Mistral processor functionality"""
|
||||
|
||||
@patch('trustgraph.model.text_completion.mistral.llm.Mistral')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_processor_initialization_basic(self, mock_llm_init, mock_async_init, mock_mistral_class):
|
||||
"""Test basic processor initialization"""
|
||||
# Arrange
|
||||
mock_mistral_client = MagicMock()
|
||||
mock_mistral_class.return_value = mock_mistral_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'ministral-8b-latest',
|
||||
'api_key': 'test-api-key',
|
||||
'temperature': 0.1,
|
||||
'max_output': 2048,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
# Act
|
||||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
assert processor.default_model == 'ministral-8b-latest'
|
||||
assert processor.temperature == 0.1
|
||||
assert processor.max_output == 2048
|
||||
assert hasattr(processor, 'mistral')
|
||||
mock_mistral_class.assert_called_once_with(api_key='test-api-key')
|
||||
|
||||
@patch('trustgraph.model.text_completion.mistral.llm.Mistral')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_success(self, mock_llm_init, mock_async_init, mock_mistral_class):
|
||||
"""Test successful content generation"""
|
||||
# Arrange
|
||||
mock_mistral_client = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices[0].message.content = 'Generated response from Mistral'
|
||||
mock_response.usage.prompt_tokens = 15
|
||||
mock_response.usage.completion_tokens = 8
|
||||
mock_mistral_client.chat.complete.return_value = mock_response
|
||||
mock_mistral_class.return_value = mock_mistral_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'ministral-8b-latest',
|
||||
'api_key': 'test-api-key',
|
||||
'temperature': 0.0,
|
||||
'max_output': 4096,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act
|
||||
result = await processor.generate_content("System prompt", "User prompt")
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Generated response from Mistral"
|
||||
assert result.in_token == 15
|
||||
assert result.out_token == 8
|
||||
assert result.model == 'ministral-8b-latest'
|
||||
mock_mistral_client.chat.complete.assert_called_once()
|
||||
|
||||
@patch('trustgraph.model.text_completion.mistral.llm.Mistral')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_temperature_override(self, mock_llm_init, mock_async_init, mock_mistral_class):
|
||||
"""Test temperature parameter override functionality"""
|
||||
# Arrange
|
||||
mock_mistral_client = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices[0].message.content = 'Response with custom temperature'
|
||||
mock_response.usage.prompt_tokens = 20
|
||||
mock_response.usage.completion_tokens = 12
|
||||
mock_mistral_client.chat.complete.return_value = mock_response
|
||||
mock_mistral_class.return_value = mock_mistral_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'ministral-8b-latest',
|
||||
'api_key': 'test-api-key',
|
||||
'temperature': 0.0, # Default temperature
|
||||
'max_output': 4096,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act - Override temperature at runtime
|
||||
result = await processor.generate_content(
|
||||
"System prompt",
|
||||
"User prompt",
|
||||
model=None, # Use default model
|
||||
temperature=0.8 # Override temperature
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Response with custom temperature"
|
||||
|
||||
# Verify Mistral API was called with overridden temperature
|
||||
call_args = mock_mistral_client.chat.complete.call_args
|
||||
assert call_args[1]['temperature'] == 0.8 # Should use runtime override
|
||||
assert call_args[1]['model'] == 'ministral-8b-latest'
|
||||
|
||||
@patch('trustgraph.model.text_completion.mistral.llm.Mistral')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_model_override(self, mock_llm_init, mock_async_init, mock_mistral_class):
|
||||
"""Test model parameter override functionality"""
|
||||
# Arrange
|
||||
mock_mistral_client = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices[0].message.content = 'Response with custom model'
|
||||
mock_response.usage.prompt_tokens = 18
|
||||
mock_response.usage.completion_tokens = 14
|
||||
mock_mistral_client.chat.complete.return_value = mock_response
|
||||
mock_mistral_class.return_value = mock_mistral_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'ministral-8b-latest', # Default model
|
||||
'api_key': 'test-api-key',
|
||||
'temperature': 0.1, # Default temperature
|
||||
'max_output': 4096,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act - Override model at runtime
|
||||
result = await processor.generate_content(
|
||||
"System prompt",
|
||||
"User prompt",
|
||||
model="mistral-large-latest", # Override model
|
||||
temperature=None # Use default temperature
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Response with custom model"
|
||||
|
||||
# Verify Mistral API was called with overridden model
|
||||
call_args = mock_mistral_client.chat.complete.call_args
|
||||
assert call_args[1]['model'] == 'mistral-large-latest' # Should use runtime override
|
||||
assert call_args[1]['temperature'] == 0.1 # Should use processor default
|
||||
|
||||
@patch('trustgraph.model.text_completion.mistral.llm.Mistral')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_both_parameters_override(self, mock_llm_init, mock_async_init, mock_mistral_class):
|
||||
"""Test overriding both model and temperature parameters simultaneously"""
|
||||
# Arrange
|
||||
mock_mistral_client = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices[0].message.content = 'Response with both overrides'
|
||||
mock_response.usage.prompt_tokens = 22
|
||||
mock_response.usage.completion_tokens = 16
|
||||
mock_mistral_client.chat.complete.return_value = mock_response
|
||||
mock_mistral_class.return_value = mock_mistral_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'ministral-8b-latest', # Default model
|
||||
'api_key': 'test-api-key',
|
||||
'temperature': 0.0, # Default temperature
|
||||
'max_output': 4096,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act - Override both parameters at runtime
|
||||
result = await processor.generate_content(
|
||||
"System prompt",
|
||||
"User prompt",
|
||||
model="mistral-large-latest", # Override model
|
||||
temperature=0.9 # Override temperature
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Response with both overrides"
|
||||
|
||||
# Verify Mistral API was called with both overrides
|
||||
call_args = mock_mistral_client.chat.complete.call_args
|
||||
assert call_args[1]['model'] == 'mistral-large-latest' # Should use runtime override
|
||||
assert call_args[1]['temperature'] == 0.9 # Should use runtime override
|
||||
|
||||
@patch('trustgraph.model.text_completion.mistral.llm.Mistral')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_prompt_construction(self, mock_llm_init, mock_async_init, mock_mistral_class):
|
||||
"""Test prompt construction with system and user prompts"""
|
||||
# Arrange
|
||||
mock_mistral_client = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices[0].message.content = 'Response with system instructions'
|
||||
mock_response.usage.prompt_tokens = 25
|
||||
mock_response.usage.completion_tokens = 15
|
||||
mock_mistral_client.chat.complete.return_value = mock_response
|
||||
mock_mistral_class.return_value = mock_mistral_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'ministral-8b-latest',
|
||||
'api_key': 'test-api-key',
|
||||
'temperature': 0.0,
|
||||
'max_output': 4096,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act
|
||||
result = await processor.generate_content("You are a helpful assistant", "What is AI?")
|
||||
|
||||
# Assert
|
||||
assert result.text == "Response with system instructions"
|
||||
assert result.in_token == 25
|
||||
assert result.out_token == 15
|
||||
|
||||
# Verify the combined prompt structure
|
||||
call_args = mock_mistral_client.chat.complete.call_args
|
||||
messages = call_args[1]['messages']
|
||||
assert len(messages) == 1
|
||||
assert messages[0]['role'] == 'user'
|
||||
assert messages[0]['content'][0]['type'] == 'text'
|
||||
assert messages[0]['content'][0]['text'] == "You are a helpful assistant\n\nWhat is AI?"
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pytest.main([__file__])
|
||||
|
|
@ -40,7 +40,7 @@ class TestOllamaProcessorSimple(IsolatedAsyncioTestCase):
|
|||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
assert processor.model == 'llama2'
|
||||
assert processor.default_model == 'llama2'
|
||||
assert hasattr(processor, 'llm')
|
||||
mock_client_class.assert_called_once_with(host='http://localhost:11434')
|
||||
|
||||
|
|
@ -81,7 +81,7 @@ class TestOllamaProcessorSimple(IsolatedAsyncioTestCase):
|
|||
assert result.in_token == 15
|
||||
assert result.out_token == 8
|
||||
assert result.model == 'llama2'
|
||||
mock_client.generate.assert_called_once_with('llama2', "System prompt\n\nUser prompt")
|
||||
mock_client.generate.assert_called_once_with('llama2', "System prompt\n\nUser prompt", options={'temperature': 0.0})
|
||||
|
||||
@patch('trustgraph.model.text_completion.ollama.llm.Client')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
|
|
@ -134,7 +134,7 @@ class TestOllamaProcessorSimple(IsolatedAsyncioTestCase):
|
|||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
assert processor.model == 'mistral'
|
||||
assert processor.default_model == 'mistral'
|
||||
mock_client_class.assert_called_once_with(host='http://192.168.1.100:11434')
|
||||
|
||||
@patch('trustgraph.model.text_completion.ollama.llm.Client')
|
||||
|
|
@ -160,7 +160,7 @@ class TestOllamaProcessorSimple(IsolatedAsyncioTestCase):
|
|||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
assert processor.model == 'gemma2:9b' # default_model
|
||||
assert processor.default_model == 'gemma2:9b' # default_model
|
||||
# Should use default_ollama (http://localhost:11434 or from OLLAMA_HOST env)
|
||||
mock_client_class.assert_called_once()
|
||||
|
||||
|
|
@ -203,7 +203,7 @@ class TestOllamaProcessorSimple(IsolatedAsyncioTestCase):
|
|||
assert result.model == 'llama2'
|
||||
|
||||
# The prompt should be "" + "\n\n" + "" = "\n\n"
|
||||
mock_client.generate.assert_called_once_with('llama2', "\n\n")
|
||||
mock_client.generate.assert_called_once_with('llama2', "\n\n", options={'temperature': 0.0})
|
||||
|
||||
@patch('trustgraph.model.text_completion.ollama.llm.Client')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
|
|
@ -310,7 +310,151 @@ class TestOllamaProcessorSimple(IsolatedAsyncioTestCase):
|
|||
assert result.out_token == 15
|
||||
|
||||
# Verify the combined prompt
|
||||
mock_client.generate.assert_called_once_with('llama2', "You are a helpful assistant\n\nWhat is AI?")
|
||||
mock_client.generate.assert_called_once_with('llama2', "You are a helpful assistant\n\nWhat is AI?", options={'temperature': 0.0})
|
||||
|
||||
@patch('trustgraph.model.text_completion.ollama.llm.Client')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_temperature_override(self, mock_llm_init, mock_async_init, mock_client_class):
|
||||
"""Test temperature parameter override functionality"""
|
||||
# Arrange
|
||||
mock_client = MagicMock()
|
||||
mock_response = {
|
||||
'response': 'Response with custom temperature',
|
||||
'prompt_eval_count': 20,
|
||||
'eval_count': 12
|
||||
}
|
||||
mock_client.generate.return_value = mock_response
|
||||
mock_client_class.return_value = mock_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'llama2',
|
||||
'ollama': 'http://localhost:11434',
|
||||
'temperature': 0.0, # Default temperature
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act - Override temperature at runtime
|
||||
result = await processor.generate_content(
|
||||
"System prompt",
|
||||
"User prompt",
|
||||
model=None, # Use default model
|
||||
temperature=0.8 # Override temperature
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Response with custom temperature"
|
||||
|
||||
# Verify Ollama API was called with overridden temperature
|
||||
mock_client.generate.assert_called_once_with(
|
||||
'llama2',
|
||||
"System prompt\n\nUser prompt",
|
||||
options={'temperature': 0.8} # Should use runtime override
|
||||
)
|
||||
|
||||
@patch('trustgraph.model.text_completion.ollama.llm.Client')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_model_override(self, mock_llm_init, mock_async_init, mock_client_class):
|
||||
"""Test model parameter override functionality"""
|
||||
# Arrange
|
||||
mock_client = MagicMock()
|
||||
mock_response = {
|
||||
'response': 'Response with custom model',
|
||||
'prompt_eval_count': 18,
|
||||
'eval_count': 14
|
||||
}
|
||||
mock_client.generate.return_value = mock_response
|
||||
mock_client_class.return_value = mock_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'llama2', # Default model
|
||||
'ollama': 'http://localhost:11434',
|
||||
'temperature': 0.1, # Default temperature
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act - Override model at runtime
|
||||
result = await processor.generate_content(
|
||||
"System prompt",
|
||||
"User prompt",
|
||||
model="mistral", # Override model
|
||||
temperature=None # Use default temperature
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Response with custom model"
|
||||
|
||||
# Verify Ollama API was called with overridden model
|
||||
mock_client.generate.assert_called_once_with(
|
||||
'mistral', # Should use runtime override
|
||||
"System prompt\n\nUser prompt",
|
||||
options={'temperature': 0.1} # Should use processor default
|
||||
)
|
||||
|
||||
@patch('trustgraph.model.text_completion.ollama.llm.Client')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_both_parameters_override(self, mock_llm_init, mock_async_init, mock_client_class):
|
||||
"""Test overriding both model and temperature parameters simultaneously"""
|
||||
# Arrange
|
||||
mock_client = MagicMock()
|
||||
mock_response = {
|
||||
'response': 'Response with both overrides',
|
||||
'prompt_eval_count': 22,
|
||||
'eval_count': 16
|
||||
}
|
||||
mock_client.generate.return_value = mock_response
|
||||
mock_client_class.return_value = mock_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'llama2', # Default model
|
||||
'ollama': 'http://localhost:11434',
|
||||
'temperature': 0.0, # Default temperature
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act - Override both parameters at runtime
|
||||
result = await processor.generate_content(
|
||||
"System prompt",
|
||||
"User prompt",
|
||||
model="codellama", # Override model
|
||||
temperature=0.9 # Override temperature
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Response with both overrides"
|
||||
|
||||
# Verify Ollama API was called with both overrides
|
||||
mock_client.generate.assert_called_once_with(
|
||||
'codellama', # Should use runtime override
|
||||
"System prompt\n\nUser prompt",
|
||||
options={'temperature': 0.9} # Should use runtime override
|
||||
)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
|
|||
|
|
@ -43,7 +43,7 @@ class TestOpenAIProcessorSimple(IsolatedAsyncioTestCase):
|
|||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
assert processor.model == 'gpt-3.5-turbo'
|
||||
assert processor.default_model == 'gpt-3.5-turbo'
|
||||
assert processor.temperature == 0.0
|
||||
assert processor.max_output == 4096
|
||||
assert hasattr(processor, 'openai')
|
||||
|
|
@ -222,7 +222,7 @@ class TestOpenAIProcessorSimple(IsolatedAsyncioTestCase):
|
|||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
assert processor.model == 'gpt-4'
|
||||
assert processor.default_model == 'gpt-4'
|
||||
assert processor.temperature == 0.7
|
||||
assert processor.max_output == 2048
|
||||
mock_openai_class.assert_called_once_with(base_url='https://custom-openai-url.com/v1', api_key='custom-api-key')
|
||||
|
|
@ -251,7 +251,7 @@ class TestOpenAIProcessorSimple(IsolatedAsyncioTestCase):
|
|||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
assert processor.model == 'gpt-3.5-turbo' # default_model
|
||||
assert processor.default_model == 'gpt-3.5-turbo' # default_model
|
||||
assert processor.temperature == 0.0 # default_temperature
|
||||
assert processor.max_output == 4096 # default_max_output
|
||||
mock_openai_class.assert_called_once_with(base_url='https://api.openai.com/v1', api_key='test-api-key')
|
||||
|
|
@ -391,5 +391,210 @@ class TestOpenAIProcessorSimple(IsolatedAsyncioTestCase):
|
|||
assert call_args[1]['response_format'] == {"type": "text"}
|
||||
|
||||
|
||||
@patch('trustgraph.model.text_completion.openai.llm.OpenAI')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_temperature_override(self, mock_llm_init, mock_async_init, mock_openai_class):
|
||||
"""Test temperature parameter override functionality"""
|
||||
# Arrange
|
||||
mock_openai_client = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices = [MagicMock()]
|
||||
mock_response.choices[0].message.content = "Response with custom temperature"
|
||||
mock_response.usage.prompt_tokens = 15
|
||||
mock_response.usage.completion_tokens = 10
|
||||
|
||||
mock_openai_client.chat.completions.create.return_value = mock_response
|
||||
mock_openai_class.return_value = mock_openai_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'gpt-3.5-turbo',
|
||||
'api_key': 'test-api-key',
|
||||
'url': 'https://api.openai.com/v1',
|
||||
'temperature': 0.0, # Default temperature
|
||||
'max_output': 4096,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act - Override temperature at runtime
|
||||
result = await processor.generate_content(
|
||||
"System prompt",
|
||||
"User prompt",
|
||||
model=None, # Use default model
|
||||
temperature=0.9 # Override temperature
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Response with custom temperature"
|
||||
|
||||
# Verify the OpenAI API was called with overridden temperature
|
||||
mock_openai_client.chat.completions.create.assert_called_once()
|
||||
call_kwargs = mock_openai_client.chat.completions.create.call_args.kwargs
|
||||
|
||||
assert call_kwargs['temperature'] == 0.9 # Should use runtime override
|
||||
assert call_kwargs['model'] == 'gpt-3.5-turbo' # Should use processor default
|
||||
|
||||
@patch('trustgraph.model.text_completion.openai.llm.OpenAI')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_model_override(self, mock_llm_init, mock_async_init, mock_openai_class):
|
||||
"""Test model parameter override functionality"""
|
||||
# Arrange
|
||||
mock_openai_client = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices = [MagicMock()]
|
||||
mock_response.choices[0].message.content = "Response with custom model"
|
||||
mock_response.usage.prompt_tokens = 15
|
||||
mock_response.usage.completion_tokens = 10
|
||||
|
||||
mock_openai_client.chat.completions.create.return_value = mock_response
|
||||
mock_openai_class.return_value = mock_openai_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'gpt-3.5-turbo', # Default model
|
||||
'api_key': 'test-api-key',
|
||||
'url': 'https://api.openai.com/v1',
|
||||
'temperature': 0.2,
|
||||
'max_output': 4096,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act - Override model at runtime
|
||||
result = await processor.generate_content(
|
||||
"System prompt",
|
||||
"User prompt",
|
||||
model="gpt-4", # Override model
|
||||
temperature=None # Use default temperature
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Response with custom model"
|
||||
|
||||
# Verify the OpenAI API was called with overridden model
|
||||
mock_openai_client.chat.completions.create.assert_called_once()
|
||||
call_kwargs = mock_openai_client.chat.completions.create.call_args.kwargs
|
||||
|
||||
assert call_kwargs['model'] == 'gpt-4' # Should use runtime override
|
||||
assert call_kwargs['temperature'] == 0.2 # Should use processor default
|
||||
|
||||
@patch('trustgraph.model.text_completion.openai.llm.OpenAI')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_both_parameters_override(self, mock_llm_init, mock_async_init, mock_openai_class):
|
||||
"""Test overriding both model and temperature parameters simultaneously"""
|
||||
# Arrange
|
||||
mock_openai_client = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices = [MagicMock()]
|
||||
mock_response.choices[0].message.content = "Response with both overrides"
|
||||
mock_response.usage.prompt_tokens = 15
|
||||
mock_response.usage.completion_tokens = 10
|
||||
|
||||
mock_openai_client.chat.completions.create.return_value = mock_response
|
||||
mock_openai_class.return_value = mock_openai_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'gpt-3.5-turbo', # Default model
|
||||
'api_key': 'test-api-key',
|
||||
'url': 'https://api.openai.com/v1',
|
||||
'temperature': 0.0, # Default temperature
|
||||
'max_output': 4096,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act - Override both parameters at runtime
|
||||
result = await processor.generate_content(
|
||||
"System prompt",
|
||||
"User prompt",
|
||||
model="gpt-4", # Override model
|
||||
temperature=0.7 # Override temperature
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Response with both overrides"
|
||||
|
||||
# Verify the OpenAI API was called with both overrides
|
||||
mock_openai_client.chat.completions.create.assert_called_once()
|
||||
call_kwargs = mock_openai_client.chat.completions.create.call_args.kwargs
|
||||
|
||||
assert call_kwargs['model'] == 'gpt-4' # Should use runtime override
|
||||
assert call_kwargs['temperature'] == 0.7 # Should use runtime override
|
||||
|
||||
@patch('trustgraph.model.text_completion.openai.llm.OpenAI')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_no_override_uses_defaults(self, mock_llm_init, mock_async_init, mock_openai_class):
|
||||
"""Test that when no parameters are overridden, processor defaults are used"""
|
||||
# Arrange
|
||||
mock_openai_client = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices = [MagicMock()]
|
||||
mock_response.choices[0].message.content = "Response with defaults"
|
||||
mock_response.usage.prompt_tokens = 15
|
||||
mock_response.usage.completion_tokens = 10
|
||||
|
||||
mock_openai_client.chat.completions.create.return_value = mock_response
|
||||
mock_openai_class.return_value = mock_openai_client
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'gpt-4', # Default model
|
||||
'api_key': 'test-api-key',
|
||||
'url': 'https://api.openai.com/v1',
|
||||
'temperature': 0.5, # Default temperature
|
||||
'max_output': 4096,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act - Don't override any parameters (pass None)
|
||||
result = await processor.generate_content(
|
||||
"System prompt",
|
||||
"User prompt",
|
||||
model=None, # Use default model
|
||||
temperature=None # Use default temperature
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Response with defaults"
|
||||
|
||||
# Verify the OpenAI API was called with processor defaults
|
||||
mock_openai_client.chat.completions.create.assert_called_once()
|
||||
call_kwargs = mock_openai_client.chat.completions.create.call_args.kwargs
|
||||
|
||||
assert call_kwargs['model'] == 'gpt-4' # Should use processor default
|
||||
assert call_kwargs['temperature'] == 0.5 # Should use processor default
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pytest.main([__file__])
|
||||
186
tests/unit/test_text_completion/test_parameter_caching.py
Normal file
186
tests/unit/test_text_completion/test_parameter_caching.py
Normal file
|
|
@ -0,0 +1,186 @@
|
|||
"""
|
||||
Unit tests for Parameter-Based Caching in LLM Processors
|
||||
Testing processors that cache based on temperature parameters (Bedrock, GoogleAIStudio)
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from unittest import IsolatedAsyncioTestCase
|
||||
|
||||
from trustgraph.model.text_completion.googleaistudio.llm import Processor as GoogleAIProcessor
|
||||
from trustgraph.base import LlmResult
|
||||
|
||||
|
||||
class TestParameterCaching(IsolatedAsyncioTestCase):
|
||||
"""Test parameter-based caching functionality"""
|
||||
|
||||
@patch('trustgraph.model.text_completion.googleaistudio.llm.genai')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_googleai_temperature_cache_keys(self, mock_llm_init, mock_async_init, mock_genai):
|
||||
"""Test that GoogleAI processor creates separate cache entries for different temperatures"""
|
||||
# Arrange
|
||||
mock_client = MagicMock()
|
||||
mock_genai.Client.return_value = mock_client
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.text = "Generated response"
|
||||
mock_response.usage_metadata.prompt_token_count = 10
|
||||
mock_response.usage_metadata.candidates_token_count = 5
|
||||
mock_client.models.generate_content.return_value = mock_response
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'gemini-2.0-flash-001',
|
||||
'api_key': 'test-api-key',
|
||||
'temperature': 0.0, # Default temperature
|
||||
'max_output': 1024,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = GoogleAIProcessor(**config)
|
||||
|
||||
# Act - Call with different temperatures
|
||||
await processor.generate_content("System", "Prompt 1", model="gemini-2.0-flash-001", temperature=0.0)
|
||||
await processor.generate_content("System", "Prompt 2", model="gemini-2.0-flash-001", temperature=0.5)
|
||||
await processor.generate_content("System", "Prompt 3", model="gemini-2.0-flash-001", temperature=1.0)
|
||||
|
||||
# Assert - Should have 3 different cache entries
|
||||
cache_keys = list(processor.generation_configs.keys())
|
||||
|
||||
assert len(cache_keys) == 3
|
||||
assert "gemini-2.0-flash-001:0.0" in cache_keys
|
||||
assert "gemini-2.0-flash-001:0.5" in cache_keys
|
||||
assert "gemini-2.0-flash-001:1.0" in cache_keys
|
||||
|
||||
# Verify each cached config has the correct temperature
|
||||
assert processor.generation_configs["gemini-2.0-flash-001:0.0"].temperature == 0.0
|
||||
assert processor.generation_configs["gemini-2.0-flash-001:0.5"].temperature == 0.5
|
||||
assert processor.generation_configs["gemini-2.0-flash-001:1.0"].temperature == 1.0
|
||||
|
||||
@patch('trustgraph.model.text_completion.googleaistudio.llm.genai')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_googleai_cache_reuse_same_parameters(self, mock_llm_init, mock_async_init, mock_genai):
|
||||
"""Test that GoogleAI processor reuses cache for identical model+temperature combinations"""
|
||||
# Arrange
|
||||
mock_client = MagicMock()
|
||||
mock_genai.Client.return_value = mock_client
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.text = "Generated response"
|
||||
mock_response.usage_metadata.prompt_token_count = 10
|
||||
mock_response.usage_metadata.candidates_token_count = 5
|
||||
mock_client.models.generate_content.return_value = mock_response
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'gemini-2.0-flash-001',
|
||||
'api_key': 'test-api-key',
|
||||
'temperature': 0.0,
|
||||
'max_output': 1024,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = GoogleAIProcessor(**config)
|
||||
|
||||
# Act - Call multiple times with same parameters
|
||||
await processor.generate_content("System", "Prompt 1", model="gemini-2.0-flash-001", temperature=0.7)
|
||||
await processor.generate_content("System", "Prompt 2", model="gemini-2.0-flash-001", temperature=0.7)
|
||||
await processor.generate_content("System", "Prompt 3", model="gemini-2.0-flash-001", temperature=0.7)
|
||||
|
||||
# Assert - Should have only 1 cache entry for the repeated parameters
|
||||
cache_keys = list(processor.generation_configs.keys())
|
||||
assert len(cache_keys) == 1
|
||||
assert "gemini-2.0-flash-001:0.7" in cache_keys
|
||||
|
||||
# The same config object should be reused
|
||||
config_obj = processor.generation_configs["gemini-2.0-flash-001:0.7"]
|
||||
assert config_obj.temperature == 0.7
|
||||
|
||||
@patch('trustgraph.model.text_completion.googleaistudio.llm.genai')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_googleai_different_models_separate_caches(self, mock_llm_init, mock_async_init, mock_genai):
|
||||
"""Test that different models create separate cache entries even with same temperature"""
|
||||
# Arrange
|
||||
mock_client = MagicMock()
|
||||
mock_genai.Client.return_value = mock_client
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.text = "Generated response"
|
||||
mock_response.usage_metadata.prompt_token_count = 10
|
||||
mock_response.usage_metadata.candidates_token_count = 5
|
||||
mock_client.models.generate_content.return_value = mock_response
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'gemini-2.0-flash-001',
|
||||
'api_key': 'test-api-key',
|
||||
'temperature': 0.0,
|
||||
'max_output': 1024,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = GoogleAIProcessor(**config)
|
||||
|
||||
# Act - Call with different models, same temperature
|
||||
await processor.generate_content("System", "Prompt 1", model="gemini-2.0-flash-001", temperature=0.5)
|
||||
await processor.generate_content("System", "Prompt 2", model="gemini-1.5-flash-001", temperature=0.5)
|
||||
|
||||
# Assert - Should have separate cache entries for different models
|
||||
cache_keys = list(processor.generation_configs.keys())
|
||||
assert len(cache_keys) == 2
|
||||
assert "gemini-2.0-flash-001:0.5" in cache_keys
|
||||
assert "gemini-1.5-flash-001:0.5" in cache_keys
|
||||
|
||||
# Note: Bedrock tests would be similar but testing the Bedrock processor's caching behavior
|
||||
# The Bedrock processor caches model variants with temperature in the cache key
|
||||
|
||||
async def test_bedrock_temperature_cache_keys(self):
|
||||
"""Test Bedrock processor temperature-aware caching"""
|
||||
# This would test the Bedrock processor's _get_or_create_variant method
|
||||
# with different temperature values to ensure proper cache key generation
|
||||
|
||||
# Implementation would follow similar pattern to GoogleAI tests above
|
||||
# but using the Bedrock processor and testing model_variants cache
|
||||
pass
|
||||
|
||||
async def test_bedrock_cache_isolation_different_temperatures(self):
|
||||
"""Test that Bedrock processor isolates cache entries by temperature"""
|
||||
pass
|
||||
|
||||
async def test_cache_memory_efficiency(self):
|
||||
"""Test that caches don't grow unbounded with many different parameter combinations"""
|
||||
# This could test cache size limits or cleanup behavior if implemented
|
||||
pass
|
||||
|
||||
|
||||
class TestCachePerformance(IsolatedAsyncioTestCase):
|
||||
"""Test caching performance characteristics"""
|
||||
|
||||
async def test_cache_hit_performance(self):
|
||||
"""Test that cache hits are faster than cache misses"""
|
||||
# This would measure timing differences between cache hits and misses
|
||||
pass
|
||||
|
||||
async def test_concurrent_cache_access(self):
|
||||
"""Test concurrent access to cached configurations"""
|
||||
# This would test thread-safety of cache access
|
||||
pass
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pytest.main([__file__])
|
||||
271
tests/unit/test_text_completion/test_tgi_processor.py
Normal file
271
tests/unit/test_text_completion/test_tgi_processor.py
Normal file
|
|
@ -0,0 +1,271 @@
|
|||
"""
|
||||
Unit tests for trustgraph.model.text_completion.tgi
|
||||
Following the same successful pattern as previous tests
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from unittest import IsolatedAsyncioTestCase
|
||||
|
||||
# Import the service under test
|
||||
from trustgraph.model.text_completion.tgi.llm import Processor
|
||||
from trustgraph.base import LlmResult
|
||||
from trustgraph.exceptions import TooManyRequests
|
||||
|
||||
|
||||
class TestTGIProcessorSimple(IsolatedAsyncioTestCase):
|
||||
"""Test TGI processor functionality"""
|
||||
|
||||
@patch('trustgraph.model.text_completion.tgi.llm.aiohttp.ClientSession')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_processor_initialization_basic(self, mock_llm_init, mock_async_init, mock_session_class):
|
||||
"""Test basic processor initialization"""
|
||||
# Arrange
|
||||
mock_session = MagicMock()
|
||||
mock_session_class.return_value = mock_session
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'tgi',
|
||||
'url': 'http://tgi-service:8899/v1',
|
||||
'temperature': 0.0,
|
||||
'max_output': 2048,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
# Act
|
||||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
assert processor.default_model == 'tgi'
|
||||
assert processor.base_url == 'http://tgi-service:8899/v1'
|
||||
assert processor.temperature == 0.0
|
||||
assert processor.max_output == 2048
|
||||
assert hasattr(processor, 'session')
|
||||
mock_session_class.assert_called_once()
|
||||
|
||||
@patch('trustgraph.model.text_completion.tgi.llm.aiohttp.ClientSession')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_success(self, mock_llm_init, mock_async_init, mock_session_class):
|
||||
"""Test successful content generation"""
|
||||
# Arrange
|
||||
mock_session = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.status = 200
|
||||
mock_response.json = AsyncMock(return_value={
|
||||
'choices': [{
|
||||
'message': {
|
||||
'content': 'Generated response from TGI'
|
||||
}
|
||||
}],
|
||||
'usage': {
|
||||
'prompt_tokens': 20,
|
||||
'completion_tokens': 12
|
||||
}
|
||||
})
|
||||
|
||||
# Mock the async context manager
|
||||
mock_session.post.return_value.__aenter__.return_value = mock_response
|
||||
mock_session.post.return_value.__aexit__.return_value = None
|
||||
mock_session_class.return_value = mock_session
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'tgi',
|
||||
'url': 'http://tgi-service:8899/v1',
|
||||
'temperature': 0.0,
|
||||
'max_output': 2048,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act
|
||||
result = await processor.generate_content("System prompt", "User prompt")
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Generated response from TGI"
|
||||
assert result.in_token == 20
|
||||
assert result.out_token == 12
|
||||
assert result.model == 'tgi'
|
||||
|
||||
# Verify the API call was made correctly
|
||||
mock_session.post.assert_called_once()
|
||||
call_args = mock_session.post.call_args
|
||||
|
||||
# Check URL
|
||||
assert call_args[0][0] == 'http://tgi-service:8899/v1/chat/completions'
|
||||
|
||||
# Check request structure
|
||||
request_body = call_args[1]['json']
|
||||
assert request_body['model'] == 'tgi'
|
||||
assert request_body['temperature'] == 0.0
|
||||
assert request_body['max_tokens'] == 2048
|
||||
|
||||
@patch('trustgraph.model.text_completion.tgi.llm.aiohttp.ClientSession')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_with_model_override(self, mock_llm_init, mock_async_init, mock_session_class):
|
||||
"""Test generate_content with model parameter override"""
|
||||
# Arrange
|
||||
mock_session = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.status = 200
|
||||
mock_response.json = AsyncMock(return_value={
|
||||
'choices': [{
|
||||
'message': {
|
||||
'content': 'Response from overridden model'
|
||||
}
|
||||
}],
|
||||
'usage': {
|
||||
'prompt_tokens': 15,
|
||||
'completion_tokens': 10
|
||||
}
|
||||
})
|
||||
|
||||
mock_session.post.return_value.__aenter__.return_value = mock_response
|
||||
mock_session.post.return_value.__aexit__.return_value = None
|
||||
mock_session_class.return_value = mock_session
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'tgi',
|
||||
'url': 'http://tgi-service:8899/v1',
|
||||
'temperature': 0.0,
|
||||
'max_output': 2048,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act - Override model
|
||||
result = await processor.generate_content("System", "Prompt", model="custom-tgi-model")
|
||||
|
||||
# Assert
|
||||
assert result.model == "custom-tgi-model" # Should use overridden model
|
||||
assert result.text == "Response from overridden model"
|
||||
|
||||
# Verify the API call was made with overridden model
|
||||
call_args = mock_session.post.call_args
|
||||
assert call_args[1]['json']['model'] == "custom-tgi-model"
|
||||
|
||||
@patch('trustgraph.model.text_completion.tgi.llm.aiohttp.ClientSession')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_with_temperature_override(self, mock_llm_init, mock_async_init, mock_session_class):
|
||||
"""Test generate_content with temperature parameter override"""
|
||||
# Arrange
|
||||
mock_session = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.status = 200
|
||||
mock_response.json = AsyncMock(return_value={
|
||||
'choices': [{
|
||||
'message': {
|
||||
'content': 'Response with temperature override'
|
||||
}
|
||||
}],
|
||||
'usage': {
|
||||
'prompt_tokens': 18,
|
||||
'completion_tokens': 12
|
||||
}
|
||||
})
|
||||
|
||||
mock_session.post.return_value.__aenter__.return_value = mock_response
|
||||
mock_session.post.return_value.__aexit__.return_value = None
|
||||
mock_session_class.return_value = mock_session
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'tgi',
|
||||
'url': 'http://tgi-service:8899/v1',
|
||||
'temperature': 0.0, # Default temperature
|
||||
'max_output': 2048,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act - Override temperature
|
||||
result = await processor.generate_content("System", "Prompt", temperature=0.7)
|
||||
|
||||
# Assert
|
||||
assert result.text == "Response with temperature override"
|
||||
|
||||
# Verify the API call was made with overridden temperature
|
||||
call_args = mock_session.post.call_args
|
||||
assert call_args[1]['json']['temperature'] == 0.7
|
||||
|
||||
@patch('trustgraph.model.text_completion.tgi.llm.aiohttp.ClientSession')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_with_both_parameters_override(self, mock_llm_init, mock_async_init, mock_session_class):
|
||||
"""Test generate_content with both model and temperature overrides"""
|
||||
# Arrange
|
||||
mock_session = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.status = 200
|
||||
mock_response.json = AsyncMock(return_value={
|
||||
'choices': [{
|
||||
'message': {
|
||||
'content': 'Response with both parameters override'
|
||||
}
|
||||
}],
|
||||
'usage': {
|
||||
'prompt_tokens': 20,
|
||||
'completion_tokens': 15
|
||||
}
|
||||
})
|
||||
|
||||
mock_session.post.return_value.__aenter__.return_value = mock_response
|
||||
mock_session.post.return_value.__aexit__.return_value = None
|
||||
mock_session_class.return_value = mock_session
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'tgi',
|
||||
'url': 'http://tgi-service:8899/v1',
|
||||
'temperature': 0.0,
|
||||
'max_output': 2048,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act - Override both parameters
|
||||
result = await processor.generate_content("System", "Prompt", model="override-model", temperature=0.8)
|
||||
|
||||
# Assert
|
||||
assert result.model == "override-model"
|
||||
assert result.text == "Response with both parameters override"
|
||||
|
||||
# Verify the API call was made with overridden parameters
|
||||
call_args = mock_session.post.call_args
|
||||
assert call_args[1]['json']['model'] == "override-model"
|
||||
assert call_args[1]['json']['temperature'] == 0.8
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pytest.main([__file__])
|
||||
|
|
@ -47,10 +47,10 @@ class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase):
|
|||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
assert processor.model == 'gemini-2.0-flash-001' # It's stored as 'model', not 'model_name'
|
||||
assert hasattr(processor, 'generation_config')
|
||||
assert processor.default_model == 'gemini-2.0-flash-001' # It's stored as 'model', not 'model_name'
|
||||
assert hasattr(processor, 'generation_configs') # Now a cache dictionary
|
||||
assert hasattr(processor, 'safety_settings')
|
||||
assert hasattr(processor, 'llm')
|
||||
assert hasattr(processor, 'model_clients') # LLM clients are now cached
|
||||
mock_service_account.Credentials.from_service_account_file.assert_called_once_with('private.json')
|
||||
mock_vertexai.init.assert_called_once()
|
||||
|
||||
|
|
@ -102,7 +102,8 @@ class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase):
|
|||
mock_model.generate_content.assert_called_once()
|
||||
# Verify the call was made with the expected parameters
|
||||
call_args = mock_model.generate_content.call_args
|
||||
assert call_args[1]['generation_config'] == processor.generation_config
|
||||
# Generation config is now created dynamically per model
|
||||
assert 'generation_config' in call_args[1]
|
||||
assert call_args[1]['safety_settings'] == processor.safety_settings
|
||||
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.service_account')
|
||||
|
|
@ -223,7 +224,7 @@ class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase):
|
|||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
assert processor.model == 'gemini-2.0-flash-001'
|
||||
assert processor.default_model == 'gemini-2.0-flash-001'
|
||||
mock_auth_default.assert_called_once()
|
||||
mock_vertexai.init.assert_called_once_with(
|
||||
location='us-central1',
|
||||
|
|
@ -296,11 +297,11 @@ class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase):
|
|||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
assert processor.model == 'gemini-1.5-pro'
|
||||
assert processor.default_model == 'gemini-1.5-pro'
|
||||
|
||||
# Verify that generation_config object exists (can't easily check internal values)
|
||||
assert hasattr(processor, 'generation_config')
|
||||
assert processor.generation_config is not None
|
||||
assert hasattr(processor, 'generation_configs') # Now a cache dictionary
|
||||
assert processor.generation_configs == {} # Empty cache initially
|
||||
|
||||
# Verify that safety settings are configured
|
||||
assert len(processor.safety_settings) == 4
|
||||
|
|
@ -353,8 +354,8 @@ class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase):
|
|||
project='test-project-123'
|
||||
)
|
||||
|
||||
# Verify GenerativeModel was created with the right model name
|
||||
mock_generative_model.assert_called_once_with('gemini-2.0-flash-001')
|
||||
# GenerativeModel is now created lazily on first use, not at initialization
|
||||
mock_generative_model.assert_not_called()
|
||||
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.service_account')
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.vertexai')
|
||||
|
|
@ -440,8 +441,8 @@ class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase):
|
|||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
assert processor.model == 'claude-3-sonnet@20240229'
|
||||
assert processor.is_anthropic == True
|
||||
assert processor.default_model == 'claude-3-sonnet@20240229'
|
||||
# is_anthropic logic is now determined dynamically per request
|
||||
|
||||
# Verify service account was called with private key
|
||||
mock_service_account.Credentials.from_service_account_file.assert_called_once_with('anthropic-key.json')
|
||||
|
|
@ -459,6 +460,180 @@ class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase):
|
|||
assert processor.api_params["top_p"] == 1.0
|
||||
assert processor.api_params["top_k"] == 32
|
||||
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.service_account')
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.vertexai')
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.GenerativeModel')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_temperature_override(self, mock_llm_init, mock_async_init, mock_generative_model, mock_vertexai, mock_service_account):
|
||||
"""Test temperature parameter override functionality"""
|
||||
# Arrange
|
||||
mock_credentials = MagicMock()
|
||||
mock_service_account.Credentials.from_service_account_file.return_value = mock_credentials
|
||||
|
||||
mock_model = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.text = "Response with custom temperature"
|
||||
mock_response.usage_metadata.prompt_token_count = 20
|
||||
mock_response.usage_metadata.candidates_token_count = 12
|
||||
mock_model.generate_content.return_value = mock_response
|
||||
mock_generative_model.return_value = mock_model
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'region': 'us-central1',
|
||||
'model': 'gemini-2.0-flash-001',
|
||||
'temperature': 0.0, # Default temperature
|
||||
'max_output': 8192,
|
||||
'private_key': 'private.json',
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act - Override temperature at runtime
|
||||
result = await processor.generate_content(
|
||||
"System prompt",
|
||||
"User prompt",
|
||||
model=None, # Use default model
|
||||
temperature=0.8 # Override temperature
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Response with custom temperature"
|
||||
|
||||
# Verify Gemini API was called with overridden temperature
|
||||
mock_model.generate_content.assert_called_once()
|
||||
call_args = mock_model.generate_content.call_args
|
||||
|
||||
# Check that generation_config was created (we can't directly access temperature from mock)
|
||||
generation_config = call_args.kwargs['generation_config']
|
||||
assert generation_config is not None # Should use overridden temperature configuration
|
||||
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.service_account')
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.vertexai')
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.GenerativeModel')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_model_override(self, mock_llm_init, mock_async_init, mock_generative_model, mock_vertexai, mock_service_account):
|
||||
"""Test model parameter override functionality"""
|
||||
# Arrange
|
||||
mock_credentials = MagicMock()
|
||||
mock_service_account.Credentials.from_service_account_file.return_value = mock_credentials
|
||||
|
||||
# Mock different models
|
||||
mock_model_default = MagicMock()
|
||||
mock_model_override = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.text = "Response with custom model"
|
||||
mock_response.usage_metadata.prompt_token_count = 18
|
||||
mock_response.usage_metadata.candidates_token_count = 14
|
||||
mock_model_override.generate_content.return_value = mock_response
|
||||
|
||||
# GenerativeModel should return different models based on input
|
||||
def model_factory(model_name):
|
||||
if model_name == 'gemini-1.5-pro':
|
||||
return mock_model_override
|
||||
return mock_model_default
|
||||
|
||||
mock_generative_model.side_effect = model_factory
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'region': 'us-central1',
|
||||
'model': 'gemini-2.0-flash-001', # Default model
|
||||
'temperature': 0.2, # Default temperature
|
||||
'max_output': 8192,
|
||||
'private_key': 'private.json',
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act - Override model at runtime
|
||||
result = await processor.generate_content(
|
||||
"System prompt",
|
||||
"User prompt",
|
||||
model="gemini-1.5-pro", # Override model
|
||||
temperature=None # Use default temperature
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Response with custom model"
|
||||
|
||||
# Verify the overridden model was used
|
||||
mock_model_override.generate_content.assert_called_once()
|
||||
# Verify GenerativeModel was called with the override model
|
||||
mock_generative_model.assert_called_with('gemini-1.5-pro')
|
||||
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.service_account')
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.vertexai')
|
||||
@patch('trustgraph.model.text_completion.vertexai.llm.GenerativeModel')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_both_parameters_override(self, mock_llm_init, mock_async_init, mock_generative_model, mock_vertexai, mock_service_account):
|
||||
"""Test overriding both model and temperature parameters simultaneously"""
|
||||
# Arrange
|
||||
mock_credentials = MagicMock()
|
||||
mock_service_account.Credentials.from_service_account_file.return_value = mock_credentials
|
||||
|
||||
mock_model = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.text = "Response with both overrides"
|
||||
mock_response.usage_metadata.prompt_token_count = 22
|
||||
mock_response.usage_metadata.candidates_token_count = 16
|
||||
mock_model.generate_content.return_value = mock_response
|
||||
mock_generative_model.return_value = mock_model
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'region': 'us-central1',
|
||||
'model': 'gemini-2.0-flash-001', # Default model
|
||||
'temperature': 0.0, # Default temperature
|
||||
'max_output': 8192,
|
||||
'private_key': 'private.json',
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act - Override both parameters at runtime
|
||||
result = await processor.generate_content(
|
||||
"System prompt",
|
||||
"User prompt",
|
||||
model="gemini-1.5-flash-001", # Override model
|
||||
temperature=0.9 # Override temperature
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, LlmResult)
|
||||
assert result.text == "Response with both overrides"
|
||||
|
||||
# Verify both overrides were used
|
||||
mock_model.generate_content.assert_called_once()
|
||||
call_args = mock_model.generate_content.call_args
|
||||
|
||||
# Verify model override
|
||||
mock_generative_model.assert_called_with('gemini-1.5-flash-001') # Should use runtime override
|
||||
|
||||
# Verify temperature override (we can't directly access temperature from mock)
|
||||
generation_config = call_args.kwargs['generation_config']
|
||||
assert generation_config is not None # Should use overridden temperature configuration
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pytest.main([__file__])
|
||||
|
|
@ -42,7 +42,7 @@ class TestVLLMProcessorSimple(IsolatedAsyncioTestCase):
|
|||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
assert processor.model == 'TheBloke/Mistral-7B-v0.1-AWQ'
|
||||
assert processor.default_model == 'TheBloke/Mistral-7B-v0.1-AWQ'
|
||||
assert processor.base_url == 'http://vllm-service:8899/v1'
|
||||
assert processor.temperature == 0.0
|
||||
assert processor.max_output == 2048
|
||||
|
|
@ -199,7 +199,7 @@ class TestVLLMProcessorSimple(IsolatedAsyncioTestCase):
|
|||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
assert processor.model == 'custom-model'
|
||||
assert processor.default_model == 'custom-model'
|
||||
assert processor.base_url == 'http://custom-vllm:8080/v1'
|
||||
assert processor.temperature == 0.7
|
||||
assert processor.max_output == 1024
|
||||
|
|
@ -228,7 +228,7 @@ class TestVLLMProcessorSimple(IsolatedAsyncioTestCase):
|
|||
processor = Processor(**config)
|
||||
|
||||
# Assert
|
||||
assert processor.model == 'TheBloke/Mistral-7B-v0.1-AWQ' # default_model
|
||||
assert processor.default_model == 'TheBloke/Mistral-7B-v0.1-AWQ' # default_model
|
||||
assert processor.base_url == 'http://vllm-service:8899/v1' # default_base_url
|
||||
assert processor.temperature == 0.0 # default_temperature
|
||||
assert processor.max_output == 2048 # default_max_output
|
||||
|
|
@ -485,5 +485,148 @@ class TestVLLMProcessorSimple(IsolatedAsyncioTestCase):
|
|||
assert call_args[1]['json']['prompt'] == expected_prompt
|
||||
|
||||
|
||||
@patch('trustgraph.model.text_completion.vllm.llm.aiohttp.ClientSession')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_with_model_override(self, mock_llm_init, mock_async_init, mock_session_class):
|
||||
"""Test generate_content with model parameter override"""
|
||||
# Arrange
|
||||
mock_session = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.status = 200
|
||||
mock_response.json = AsyncMock(return_value={
|
||||
'choices': [{
|
||||
'text': 'Response from overridden model'
|
||||
}],
|
||||
'usage': {
|
||||
'prompt_tokens': 12,
|
||||
'completion_tokens': 8
|
||||
}
|
||||
})
|
||||
|
||||
mock_session.post.return_value.__aenter__.return_value = mock_response
|
||||
mock_session.post.return_value.__aexit__.return_value = None
|
||||
mock_session_class.return_value = mock_session
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'TheBloke/Mistral-7B-v0.1-AWQ',
|
||||
'url': 'http://vllm-service:8899/v1',
|
||||
'temperature': 0.0,
|
||||
'max_output': 2048,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act - Override model
|
||||
result = await processor.generate_content("System", "Prompt", model="custom-vllm-model")
|
||||
|
||||
# Assert
|
||||
assert result.model == "custom-vllm-model" # Should use overridden model
|
||||
assert result.text == "Response from overridden model"
|
||||
|
||||
@patch('trustgraph.model.text_completion.vllm.llm.aiohttp.ClientSession')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_with_temperature_override(self, mock_llm_init, mock_async_init, mock_session_class):
|
||||
"""Test generate_content with temperature parameter override"""
|
||||
# Arrange
|
||||
mock_session = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.status = 200
|
||||
mock_response.json = AsyncMock(return_value={
|
||||
'choices': [{
|
||||
'text': 'Response with temperature override'
|
||||
}],
|
||||
'usage': {
|
||||
'prompt_tokens': 15,
|
||||
'completion_tokens': 10
|
||||
}
|
||||
})
|
||||
|
||||
mock_session.post.return_value.__aenter__.return_value = mock_response
|
||||
mock_session.post.return_value.__aexit__.return_value = None
|
||||
mock_session_class.return_value = mock_session
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'TheBloke/Mistral-7B-v0.1-AWQ',
|
||||
'url': 'http://vllm-service:8899/v1',
|
||||
'temperature': 0.0, # Default temperature
|
||||
'max_output': 2048,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act - Override temperature
|
||||
result = await processor.generate_content("System", "Prompt", temperature=0.7)
|
||||
|
||||
# Assert
|
||||
assert result.text == "Response with temperature override"
|
||||
|
||||
# Verify the request was made with overridden temperature
|
||||
call_args = mock_session.post.call_args
|
||||
assert call_args[1]['json']['temperature'] == 0.7
|
||||
|
||||
@patch('trustgraph.model.text_completion.vllm.llm.aiohttp.ClientSession')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||
async def test_generate_content_with_both_parameters_override(self, mock_llm_init, mock_async_init, mock_session_class):
|
||||
"""Test generate_content with both model and temperature overrides"""
|
||||
# Arrange
|
||||
mock_session = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.status = 200
|
||||
mock_response.json = AsyncMock(return_value={
|
||||
'choices': [{
|
||||
'text': 'Response with both parameters override'
|
||||
}],
|
||||
'usage': {
|
||||
'prompt_tokens': 18,
|
||||
'completion_tokens': 12
|
||||
}
|
||||
})
|
||||
|
||||
mock_session.post.return_value.__aenter__.return_value = mock_response
|
||||
mock_session.post.return_value.__aexit__.return_value = None
|
||||
mock_session_class.return_value = mock_session
|
||||
|
||||
mock_async_init.return_value = None
|
||||
mock_llm_init.return_value = None
|
||||
|
||||
config = {
|
||||
'model': 'TheBloke/Mistral-7B-v0.1-AWQ',
|
||||
'url': 'http://vllm-service:8899/v1',
|
||||
'temperature': 0.0,
|
||||
'max_output': 2048,
|
||||
'concurrency': 1,
|
||||
'taskgroup': AsyncMock(),
|
||||
'id': 'test-processor'
|
||||
}
|
||||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Act - Override both parameters
|
||||
result = await processor.generate_content("System", "Prompt", model="override-model", temperature=0.8)
|
||||
|
||||
# Assert
|
||||
assert result.model == "override-model"
|
||||
assert result.text == "Response with both parameters override"
|
||||
|
||||
# Verify the request was made with overridden temperature
|
||||
call_args = mock_session.post.call_args
|
||||
assert call_args[1]['json']['temperature'] == 0.8
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pytest.main([__file__])
|
||||
Loading…
Add table
Add a link
Reference in a new issue