diff --git a/tests/integration/test_dynamic_llm_parameters.py b/tests/integration/test_dynamic_llm_parameters.py new file mode 100644 index 00000000..bb7f999c --- /dev/null +++ b/tests/integration/test_dynamic_llm_parameters.py @@ -0,0 +1,276 @@ +""" +Integration tests for Dynamic LLM Parameters +Testing end-to-end flow of runtime parameter changes in LLM processors +""" + +import pytest +from unittest.mock import AsyncMock, MagicMock, patch +from openai.types.chat import ChatCompletion, ChatCompletionMessage +from openai.types.chat.chat_completion import Choice +from openai.types.completion_usage import CompletionUsage + +from trustgraph.model.text_completion.openai.llm import Processor as OpenAIProcessor +from trustgraph.base import LlmResult + + +@pytest.mark.integration +class TestDynamicLlmParameters: + """Integration tests for dynamic parameter configuration""" + + @pytest.fixture + def mock_openai_client(self): + """Mock OpenAI client that returns realistic responses""" + client = MagicMock() + + # Default mock response + usage = CompletionUsage(prompt_tokens=25, completion_tokens=15, total_tokens=40) + message = ChatCompletionMessage(role="assistant", content="Dynamic parameter test response") + choice = Choice(index=0, message=message, finish_reason="stop") + + completion = ChatCompletion( + id="chatcmpl-test-dynamic", + choices=[choice], + created=1234567890, + model="gpt-4", # Will be overridden based on test + object="chat.completion", + usage=usage + ) + + client.chat.completions.create.return_value = completion + return client + + @pytest.fixture + def base_processor_config(self): + """Base configuration for test processors""" + return { + "api_key": "test-api-key", + "url": "https://api.openai.com/v1", + "temperature": 0.0, # Default temperature + "max_output": 1024, + } + + @patch('trustgraph.model.text_completion.openai.llm.OpenAI') + @patch('trustgraph.base.async_processor.AsyncProcessor.__init__') + @patch('trustgraph.base.llm_service.LlmService.__init__') + async def test_runtime_temperature_override(self, mock_llm_init, mock_async_init, + mock_openai_class, mock_openai_client, base_processor_config): + """Test that temperature can be overridden at runtime""" + # Arrange + mock_openai_class.return_value = mock_openai_client + mock_async_init.return_value = None + mock_llm_init.return_value = None + + config = base_processor_config | { + "model": "gpt-3.5-turbo", + "concurrency": 1, + "taskgroup": AsyncMock(), + "id": "test-processor" + } + + processor = OpenAIProcessor(**config) + + # Act - Call with different temperature than configured default (0.0) + result = await processor.generate_content( + "System prompt", + "User prompt", + model=None, # Use default model + temperature=0.9 # Override temperature + ) + + # Assert + assert isinstance(result, LlmResult) + assert result.text == "Dynamic parameter test response" + + # Verify the OpenAI API was called with the overridden temperature + mock_openai_client.chat.completions.create.assert_called_once() + call_args = mock_openai_client.chat.completions.create.call_args + + assert call_args.kwargs['temperature'] == 0.9 # Should use runtime parameter + assert call_args.kwargs['model'] == "gpt-3.5-turbo" # Should use processor default + + @patch('trustgraph.model.text_completion.openai.llm.OpenAI') + @patch('trustgraph.base.async_processor.AsyncProcessor.__init__') + @patch('trustgraph.base.llm_service.LlmService.__init__') + async def test_runtime_model_override(self, mock_llm_init, mock_async_init, + mock_openai_class, 
mock_openai_client, base_processor_config): + """Test that model can be overridden at runtime""" + # Arrange + mock_openai_class.return_value = mock_openai_client + mock_async_init.return_value = None + mock_llm_init.return_value = None + + config = base_processor_config | { + "model": "gpt-3.5-turbo", # Default model + "concurrency": 1, + "taskgroup": AsyncMock(), + "id": "test-processor" + } + + processor = OpenAIProcessor(**config) + + # Act - Call with different model than configured default + result = await processor.generate_content( + "System prompt", + "User prompt", + model="gpt-4", # Override model + temperature=None # Use default temperature + ) + + # Assert + assert isinstance(result, LlmResult) + + # Verify the OpenAI API was called with the overridden model + mock_openai_client.chat.completions.create.assert_called_once() + call_args = mock_openai_client.chat.completions.create.call_args + + assert call_args.kwargs['model'] == "gpt-4" # Should use runtime parameter + assert call_args.kwargs['temperature'] == 0.0 # Should use processor default + + @patch('trustgraph.model.text_completion.openai.llm.OpenAI') + @patch('trustgraph.base.async_processor.AsyncProcessor.__init__') + @patch('trustgraph.base.llm_service.LlmService.__init__') + async def test_both_parameters_override(self, mock_llm_init, mock_async_init, + mock_openai_class, mock_openai_client, base_processor_config): + """Test that both model and temperature can be overridden simultaneously""" + # Arrange + mock_openai_class.return_value = mock_openai_client + mock_async_init.return_value = None + mock_llm_init.return_value = None + + config = base_processor_config | { + "model": "gpt-3.5-turbo", # Default model + "concurrency": 1, + "taskgroup": AsyncMock(), + "id": "test-processor" + } + + processor = OpenAIProcessor(**config) + + # Act - Override both parameters + result = await processor.generate_content( + "System prompt", + "User prompt", + model="gpt-4", # Override model + temperature=0.5 # Override temperature + ) + + # Assert + assert isinstance(result, LlmResult) + + # Verify both parameters were overridden + mock_openai_client.chat.completions.create.assert_called_once() + call_args = mock_openai_client.chat.completions.create.call_args + + assert call_args.kwargs['model'] == "gpt-4" # Should use runtime parameter + assert call_args.kwargs['temperature'] == 0.5 # Should use runtime parameter + + @patch('trustgraph.model.text_completion.openai.llm.OpenAI') + @patch('trustgraph.base.async_processor.AsyncProcessor.__init__') + @patch('trustgraph.base.llm_service.LlmService.__init__') + async def test_fallback_to_defaults_when_no_override(self, mock_llm_init, mock_async_init, + mock_openai_class, mock_openai_client, base_processor_config): + """Test that processor falls back to configured defaults when no parameters are provided""" + # Arrange + mock_openai_class.return_value = mock_openai_client + mock_async_init.return_value = None + mock_llm_init.return_value = None + + config = base_processor_config | { + "model": "gpt-3.5-turbo", # Default model + "temperature": 0.2, # Default temperature + "concurrency": 1, + "taskgroup": AsyncMock(), + "id": "test-processor" + } + + processor = OpenAIProcessor(**config) + + # Act - Call with no parameter overrides + result = await processor.generate_content( + "System prompt", + "User prompt", + model=None, # Use default + temperature=None # Use default + ) + + # Assert + assert isinstance(result, LlmResult) + + # Verify defaults were used + 
mock_openai_client.chat.completions.create.assert_called_once() + call_args = mock_openai_client.chat.completions.create.call_args + + assert call_args.kwargs['model'] == "gpt-3.5-turbo" # Should use processor default + assert call_args.kwargs['temperature'] == 0.2 # Should use processor default + + @patch('trustgraph.model.text_completion.openai.llm.OpenAI') + @patch('trustgraph.base.async_processor.AsyncProcessor.__init__') + @patch('trustgraph.base.llm_service.LlmService.__init__') + async def test_multiple_concurrent_calls_different_parameters(self, mock_llm_init, mock_async_init, + mock_openai_class, mock_openai_client, base_processor_config): + """Test multiple concurrent calls with different parameters don't interfere""" + # Arrange + mock_openai_class.return_value = mock_openai_client + mock_async_init.return_value = None + mock_llm_init.return_value = None + + config = base_processor_config | { + "model": "gpt-3.5-turbo", + "concurrency": 1, + "taskgroup": AsyncMock(), + "id": "test-processor" + } + + processor = OpenAIProcessor(**config) + + # Reset the mock to track multiple calls + mock_openai_client.reset_mock() + + # Act - Make multiple calls with different parameters concurrently + import asyncio + tasks = [ + processor.generate_content("System 1", "Prompt 1", model="gpt-3.5-turbo", temperature=0.1), + processor.generate_content("System 2", "Prompt 2", model="gpt-4", temperature=0.8), + processor.generate_content("System 3", "Prompt 3", model="gpt-3.5-turbo", temperature=0.5) + ] + + results = await asyncio.gather(*tasks) + + # Assert + assert len(results) == 3 + for result in results: + assert isinstance(result, LlmResult) + + # Verify all calls were made with correct parameters + assert mock_openai_client.chat.completions.create.call_count == 3 + + # Get all call arguments + call_args_list = mock_openai_client.chat.completions.create.call_args_list + + # Verify each call had the expected parameters + expected_params = [ + ("gpt-3.5-turbo", 0.1), + ("gpt-4", 0.8), + ("gpt-3.5-turbo", 0.5) + ] + + for i, (expected_model, expected_temp) in enumerate(expected_params): + call_kwargs = call_args_list[i].kwargs + assert call_kwargs['model'] == expected_model + assert call_kwargs['temperature'] == expected_temp + + async def test_parameter_boundary_values(self, mock_openai_client, base_processor_config): + """Test parameter boundary values (edge cases)""" + # This would test extreme values like temperature=0.0, temperature=2.0, etc. + # Implementation depends on specific validation requirements + pass + + async def test_invalid_parameter_types_handling(self, mock_openai_client, base_processor_config): + """Test handling of invalid parameter types""" + # This would test what happens with invalid temperature values, non-existent models, etc. 
+ # Implementation depends on error handling requirements + pass + + +if __name__ == '__main__': + pytest.main([__file__]) \ No newline at end of file diff --git a/tests/unit/test_base/test_flow_parameter_specs.py b/tests/unit/test_base/test_flow_parameter_specs.py new file mode 100644 index 00000000..c813d66c --- /dev/null +++ b/tests/unit/test_base/test_flow_parameter_specs.py @@ -0,0 +1,238 @@ +""" +Unit tests for Flow Parameter Specification functionality +Testing parameter specification registration and handling in flow processors +""" + +import pytest +from unittest.mock import AsyncMock, MagicMock, patch +from unittest import IsolatedAsyncioTestCase + +from trustgraph.base.flow_processor import FlowProcessor +from trustgraph.base import ParameterSpec, ConsumerSpec, ProducerSpec + + +class MockAsyncProcessor: + def __init__(self, **params): + self.config_handlers = [] + self.id = params.get('id', 'test-service') + self.specifications = [] + + +class TestFlowParameterSpecs(IsolatedAsyncioTestCase): + """Test flow processor parameter specification functionality""" + + @patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor) + def test_parameter_spec_registration(self): + """Test that parameter specs can be registered with flow processors""" + # Arrange + config = { + 'id': 'test-flow-processor', + 'concurrency': 1, + 'taskgroup': AsyncMock() + } + + processor = FlowProcessor(**config) + + # Create test parameter specs + model_spec = ParameterSpec(name="model") + temperature_spec = ParameterSpec(name="temperature") + + # Act + processor.register_specification(model_spec) + processor.register_specification(temperature_spec) + + # Assert + assert len(processor.specifications) >= 2 + + param_specs = [spec for spec in processor.specifications + if isinstance(spec, ParameterSpec)] + assert len(param_specs) >= 2 + + param_names = [spec.name for spec in param_specs] + assert "model" in param_names + assert "temperature" in param_names + + @patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor) + def test_mixed_specification_types(self): + """Test registration of mixed specification types (parameters, consumers, producers)""" + # Arrange + config = { + 'id': 'test-flow-processor', + 'concurrency': 1, + 'taskgroup': AsyncMock() + } + + processor = FlowProcessor(**config) + + # Create different spec types + param_spec = ParameterSpec(name="model") + consumer_spec = ConsumerSpec(name="input", schema=MagicMock(), handler=MagicMock()) + producer_spec = ProducerSpec(name="output", schema=MagicMock()) + + # Act + processor.register_specification(param_spec) + processor.register_specification(consumer_spec) + processor.register_specification(producer_spec) + + # Assert + assert len(processor.specifications) == 3 + + # Count each type + param_specs = [s for s in processor.specifications if isinstance(s, ParameterSpec)] + consumer_specs = [s for s in processor.specifications if isinstance(s, ConsumerSpec)] + producer_specs = [s for s in processor.specifications if isinstance(s, ProducerSpec)] + + assert len(param_specs) == 1 + assert len(consumer_specs) == 1 + assert len(producer_specs) == 1 + + @patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor) + def test_parameter_spec_metadata(self): + """Test parameter specification metadata handling""" + # Arrange + config = { + 'id': 'test-flow-processor', + 'concurrency': 1, + 'taskgroup': AsyncMock() + } + + processor = FlowProcessor(**config) + + # Create parameter specs with metadata (if 
supported) + model_spec = ParameterSpec(name="model") + temperature_spec = ParameterSpec(name="temperature") + + # Act + processor.register_specification(model_spec) + processor.register_specification(temperature_spec) + + # Assert + param_specs = [spec for spec in processor.specifications + if isinstance(spec, ParameterSpec)] + + model_spec_registered = next((s for s in param_specs if s.name == "model"), None) + temperature_spec_registered = next((s for s in param_specs if s.name == "temperature"), None) + + assert model_spec_registered is not None + assert temperature_spec_registered is not None + assert model_spec_registered.name == "model" + assert temperature_spec_registered.name == "temperature" + + @patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor) + def test_duplicate_parameter_spec_handling(self): + """Test handling of duplicate parameter spec registration""" + # Arrange + config = { + 'id': 'test-flow-processor', + 'concurrency': 1, + 'taskgroup': AsyncMock() + } + + processor = FlowProcessor(**config) + + # Create duplicate parameter specs + model_spec1 = ParameterSpec(name="model") + model_spec2 = ParameterSpec(name="model") + + # Act + processor.register_specification(model_spec1) + processor.register_specification(model_spec2) + + # Assert - Should allow duplicates (or handle appropriately) + param_specs = [spec for spec in processor.specifications + if isinstance(spec, ParameterSpec) and spec.name == "model"] + + # Either should have 2 duplicates or the system should handle deduplication + assert len(param_specs) >= 1 # At least one should be registered + + @patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor) + @patch('trustgraph.base.flow_processor.Flow') + async def test_parameter_specs_available_to_flows(self, mock_flow_class): + """Test that parameter specs are available when flows are created""" + # Arrange + config = { + 'id': 'test-flow-processor', + 'concurrency': 1, + 'taskgroup': AsyncMock() + } + + processor = FlowProcessor(**config) + processor.id = 'test-processor' + + # Register parameter specs + model_spec = ParameterSpec(name="model") + temperature_spec = ParameterSpec(name="temperature") + processor.register_specification(model_spec) + processor.register_specification(temperature_spec) + + mock_flow = AsyncMock() + mock_flow_class.return_value = mock_flow + + flow_name = 'test-flow' + flow_defn = {'config': 'test-config'} + + # Act + await processor.start_flow(flow_name, flow_defn) + + # Assert - Flow should be created with access to processor specifications + mock_flow_class.assert_called_once_with('test-processor', flow_name, processor, flow_defn) + + # The flow should have access to the processor's specifications + # (The exact mechanism depends on Flow implementation) + assert len(processor.specifications) >= 2 + + +class TestParameterSpecValidation(IsolatedAsyncioTestCase): + """Test parameter specification validation functionality""" + + @patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor) + def test_parameter_spec_name_validation(self): + """Test parameter spec name validation""" + # Arrange + config = { + 'id': 'test-flow-processor', + 'concurrency': 1, + 'taskgroup': AsyncMock() + } + + processor = FlowProcessor(**config) + + # Act & Assert - Valid parameter names + valid_specs = [ + ParameterSpec(name="model"), + ParameterSpec(name="temperature"), + ParameterSpec(name="max_tokens"), + ParameterSpec(name="api_key") + ] + + for spec in valid_specs: + # Should not raise any 
exceptions + processor.register_specification(spec) + + assert len([s for s in processor.specifications if isinstance(s, ParameterSpec)]) >= 4 + + def test_parameter_spec_creation_validation(self): + """Test parameter spec creation with various inputs""" + # Test valid parameter spec creation + valid_specs = [ + ParameterSpec(name="model"), + ParameterSpec(name="temperature"), + ParameterSpec(name="max_output"), + ] + + for spec in valid_specs: + assert spec.name is not None + assert isinstance(spec.name, str) + + # Test edge cases (if parameter specs have validation) + # This depends on the actual ParameterSpec implementation + try: + empty_name_spec = ParameterSpec(name="") + # May or may not be valid depending on implementation + except Exception: + # If validation exists, it should catch invalid names + pass + + +if __name__ == '__main__': + pytest.main([__file__]) \ No newline at end of file diff --git a/tests/unit/test_base/test_llm_service_parameters.py b/tests/unit/test_base/test_llm_service_parameters.py new file mode 100644 index 00000000..65cdf9a5 --- /dev/null +++ b/tests/unit/test_base/test_llm_service_parameters.py @@ -0,0 +1,264 @@ +""" +Unit tests for LLM Service Parameter Specifications +Testing the new parameter-aware functionality added to the LLM base service +""" + +import pytest +from unittest.mock import AsyncMock, MagicMock, patch +from unittest import IsolatedAsyncioTestCase + +from trustgraph.base.llm_service import LlmService, LlmResult +from trustgraph.base import ParameterSpec, ConsumerSpec, ProducerSpec +from trustgraph.schema import TextCompletionRequest, TextCompletionResponse + + +class MockAsyncProcessor: + def __init__(self, **params): + self.config_handlers = [] + self.id = params.get('id', 'test-service') + self.specifications = [] + + + + +class TestLlmServiceParameters(IsolatedAsyncioTestCase): + """Test LLM service parameter specification functionality""" + + @patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor) + def test_parameter_specs_registration(self): + """Test that LLM service registers model and temperature parameter specs""" + # Arrange + config = { + 'id': 'test-llm-service', + 'concurrency': 1, + 'taskgroup': AsyncMock() # Add required taskgroup + } + + # Act + service = LlmService(**config) + + # Assert + param_specs = {spec.name: spec for spec in service.specifications + if isinstance(spec, ParameterSpec)} + + assert "model" in param_specs + assert "temperature" in param_specs + assert len(param_specs) >= 2 # May have other parameter specs + + @patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor) + def test_model_parameter_spec_properties(self): + """Test that model parameter spec has correct properties""" + # Arrange + config = { + 'id': 'test-llm-service', + 'concurrency': 1, + 'taskgroup': AsyncMock() + } + + # Act + service = LlmService(**config) + + # Assert + model_spec = None + for spec in service.specifications: + if isinstance(spec, ParameterSpec) and spec.name == "model": + model_spec = spec + break + + assert model_spec is not None + assert model_spec.name == "model" + + @patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor) + def test_temperature_parameter_spec_properties(self): + """Test that temperature parameter spec has correct properties""" + # Arrange + config = { + 'id': 'test-llm-service', + 'concurrency': 1, + 'taskgroup': AsyncMock() + } + + # Act + service = LlmService(**config) + + # Assert + temperature_spec = None + for spec in 
service.specifications: + if isinstance(spec, ParameterSpec) and spec.name == "temperature": + temperature_spec = spec + break + + assert temperature_spec is not None + assert temperature_spec.name == "temperature" + + @patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor) + async def test_on_request_extracts_parameters_from_flow(self): + """Test that on_request method extracts model and temperature from flow""" + # Arrange + config = { + 'id': 'test-llm-service', + 'concurrency': 1, + 'taskgroup': AsyncMock() + } + + service = LlmService(**config) + + # Mock the metrics + service.text_completion_model_metric = MagicMock() + service.text_completion_model_metric.labels.return_value.info = AsyncMock() + + # Mock the generate_content method to capture parameters + service.generate_content = AsyncMock(return_value=LlmResult( + text="test response", + in_token=10, + out_token=5, + model="gpt-4" + )) + + # Mock message and flow + mock_message = MagicMock() + mock_message.value.return_value = MagicMock() + mock_message.value.return_value.system = "system prompt" + mock_message.value.return_value.prompt = "user prompt" + mock_message.properties.return_value = {"id": "test-id"} + + mock_consumer = MagicMock() + mock_consumer.name = "request" + + mock_flow = MagicMock() + mock_flow.name = "test-flow" + mock_flow.return_value = "test-model" # flow("model") returns this + mock_flow.side_effect = lambda param: { + "model": "gpt-4", + "temperature": 0.7 + }.get(param, f"mock-{param}") + + mock_producer = AsyncMock() + mock_flow.producer = {"response": mock_producer} + + # Act + await service.on_request(mock_message, mock_consumer, mock_flow) + + # Assert + # Verify that generate_content was called with parameters from flow + service.generate_content.assert_called_once() + call_args = service.generate_content.call_args + + assert call_args[0][0] == "system prompt" # system + assert call_args[0][1] == "user prompt" # prompt + assert call_args[0][2] == "gpt-4" # model + assert call_args[0][3] == 0.7 # temperature + + # Verify flow was queried for both parameters + mock_flow.assert_any_call("model") + mock_flow.assert_any_call("temperature") + + @patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor) + async def test_on_request_handles_missing_parameters_gracefully(self): + """Test that on_request handles missing parameters gracefully""" + # Arrange + config = { + 'id': 'test-llm-service', + 'concurrency': 1, + 'taskgroup': AsyncMock() + } + + service = LlmService(**config) + + # Mock the metrics + service.text_completion_model_metric = MagicMock() + service.text_completion_model_metric.labels.return_value.info = AsyncMock() + + # Mock the generate_content method + service.generate_content = AsyncMock(return_value=LlmResult( + text="test response", + in_token=10, + out_token=5, + model="default-model" + )) + + # Mock message and flow where flow returns None for parameters + mock_message = MagicMock() + mock_message.value.return_value = MagicMock() + mock_message.value.return_value.system = "system prompt" + mock_message.value.return_value.prompt = "user prompt" + mock_message.properties.return_value = {"id": "test-id"} + + mock_consumer = MagicMock() + mock_consumer.name = "request" + + mock_flow = MagicMock() + mock_flow.name = "test-flow" + mock_flow.return_value = None # Both parameters return None + + mock_producer = AsyncMock() + mock_flow.producer = {"response": mock_producer} + + # Act + await service.on_request(mock_message, mock_consumer, mock_flow) + + # 
Assert + # Should still call generate_content, with None values that will use processor defaults + service.generate_content.assert_called_once() + call_args = service.generate_content.call_args + + assert call_args[0][0] == "system prompt" # system + assert call_args[0][1] == "user prompt" # prompt + assert call_args[0][2] is None # model (will use processor default) + assert call_args[0][3] is None # temperature (will use processor default) + + @patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor) + async def test_on_request_error_handling_preserves_behavior(self): + """Test that parameter extraction doesn't break existing error handling""" + # Arrange + config = { + 'id': 'test-llm-service', + 'concurrency': 1, + 'taskgroup': AsyncMock() + } + + service = LlmService(**config) + + # Mock the metrics + service.text_completion_model_metric = MagicMock() + service.text_completion_model_metric.labels.return_value.info = AsyncMock() + + # Mock generate_content to raise an exception + service.generate_content = AsyncMock(side_effect=Exception("Test error")) + + # Mock message and flow + mock_message = MagicMock() + mock_message.value.return_value = MagicMock() + mock_message.value.return_value.system = "system prompt" + mock_message.value.return_value.prompt = "user prompt" + mock_message.properties.return_value = {"id": "test-id"} + + mock_consumer = MagicMock() + mock_consumer.name = "request" + + mock_flow = MagicMock() + mock_flow.name = "test-flow" + mock_flow.side_effect = lambda param: { + "model": "gpt-4", + "temperature": 0.7 + }.get(param, f"mock-{param}") + + mock_producer = AsyncMock() + mock_flow.producer = {"response": mock_producer} + + # Act + await service.on_request(mock_message, mock_consumer, mock_flow) + + # Assert + # Should have sent error response + mock_producer.send.assert_called_once() + error_response = mock_producer.send.call_args[0][0] + + assert error_response.error is not None + assert error_response.error.type == "llm-error" + assert "Test error" in error_response.error.message + assert error_response.response is None + + +if __name__ == '__main__': + pytest.main([__file__]) \ No newline at end of file diff --git a/tests/unit/test_text_completion/test_azure_openai_processor.py b/tests/unit/test_text_completion/test_azure_openai_processor.py index 967a3893..02c85d54 100644 --- a/tests/unit/test_text_completion/test_azure_openai_processor.py +++ b/tests/unit/test_text_completion/test_azure_openai_processor.py @@ -402,6 +402,156 @@ class TestAzureOpenAIProcessorSimple(IsolatedAsyncioTestCase): assert call_args[1]['max_tokens'] == 1024 assert call_args[1]['top_p'] == 1 + @patch('trustgraph.model.text_completion.azure_openai.llm.AzureOpenAI') + @patch('trustgraph.base.async_processor.AsyncProcessor.__init__') + @patch('trustgraph.base.llm_service.LlmService.__init__') + async def test_generate_content_temperature_override(self, mock_llm_init, mock_async_init, mock_azure_openai_class): + """Test temperature parameter override functionality""" + # Arrange + mock_azure_client = MagicMock() + mock_response = MagicMock() + mock_response.choices = [MagicMock()] + mock_response.choices[0].message.content = 'Response with custom temperature' + mock_response.usage.prompt_tokens = 20 + mock_response.usage.completion_tokens = 12 + + mock_azure_client.chat.completions.create.return_value = mock_response + mock_azure_openai_class.return_value = mock_azure_client + + mock_async_init.return_value = None + mock_llm_init.return_value = None + + config = { + 
'model': 'gpt-4', + 'endpoint': 'https://test.openai.azure.com/', + 'token': 'test-token', + 'api_version': '2024-12-01-preview', + 'temperature': 0.0, # Default temperature + 'max_output': 4192, + 'concurrency': 1, + 'taskgroup': AsyncMock(), + 'id': 'test-processor' + } + + processor = Processor(**config) + + # Act - Override temperature at runtime + result = await processor.generate_content( + "System prompt", + "User prompt", + model=None, # Use default model + temperature=0.8 # Override temperature + ) + + # Assert + assert isinstance(result, LlmResult) + assert result.text == "Response with custom temperature" + + # Verify Azure OpenAI API was called with overridden temperature + call_args = mock_azure_client.chat.completions.create.call_args + assert call_args[1]['temperature'] == 0.8 # Should use runtime override + assert call_args[1]['model'] == 'gpt-4' + + @patch('trustgraph.model.text_completion.azure_openai.llm.AzureOpenAI') + @patch('trustgraph.base.async_processor.AsyncProcessor.__init__') + @patch('trustgraph.base.llm_service.LlmService.__init__') + async def test_generate_content_model_override(self, mock_llm_init, mock_async_init, mock_azure_openai_class): + """Test model parameter override functionality""" + # Arrange + mock_azure_client = MagicMock() + mock_response = MagicMock() + mock_response.choices = [MagicMock()] + mock_response.choices[0].message.content = 'Response with custom model' + mock_response.usage.prompt_tokens = 18 + mock_response.usage.completion_tokens = 14 + + mock_azure_client.chat.completions.create.return_value = mock_response + mock_azure_openai_class.return_value = mock_azure_client + + mock_async_init.return_value = None + mock_llm_init.return_value = None + + config = { + 'model': 'gpt-4', # Default model + 'endpoint': 'https://test.openai.azure.com/', + 'token': 'test-token', + 'api_version': '2024-12-01-preview', + 'temperature': 0.1, # Default temperature + 'max_output': 4192, + 'concurrency': 1, + 'taskgroup': AsyncMock(), + 'id': 'test-processor' + } + + processor = Processor(**config) + + # Act - Override model at runtime + result = await processor.generate_content( + "System prompt", + "User prompt", + model="gpt-4o", # Override model + temperature=None # Use default temperature + ) + + # Assert + assert isinstance(result, LlmResult) + assert result.text == "Response with custom model" + + # Verify Azure OpenAI API was called with overridden model + call_args = mock_azure_client.chat.completions.create.call_args + assert call_args[1]['model'] == 'gpt-4o' # Should use runtime override + assert call_args[1]['temperature'] == 0.1 # Should use processor default + + @patch('trustgraph.model.text_completion.azure_openai.llm.AzureOpenAI') + @patch('trustgraph.base.async_processor.AsyncProcessor.__init__') + @patch('trustgraph.base.llm_service.LlmService.__init__') + async def test_generate_content_both_parameters_override(self, mock_llm_init, mock_async_init, mock_azure_openai_class): + """Test overriding both model and temperature parameters simultaneously""" + # Arrange + mock_azure_client = MagicMock() + mock_response = MagicMock() + mock_response.choices = [MagicMock()] + mock_response.choices[0].message.content = 'Response with both overrides' + mock_response.usage.prompt_tokens = 22 + mock_response.usage.completion_tokens = 16 + + mock_azure_client.chat.completions.create.return_value = mock_response + mock_azure_openai_class.return_value = mock_azure_client + + mock_async_init.return_value = None + mock_llm_init.return_value = None + + 
config = { + 'model': 'gpt-4', # Default model + 'endpoint': 'https://test.openai.azure.com/', + 'token': 'test-token', + 'api_version': '2024-12-01-preview', + 'temperature': 0.0, # Default temperature + 'max_output': 4192, + 'concurrency': 1, + 'taskgroup': AsyncMock(), + 'id': 'test-processor' + } + + processor = Processor(**config) + + # Act - Override both parameters at runtime + result = await processor.generate_content( + "System prompt", + "User prompt", + model="gpt-4o-mini", # Override model + temperature=0.9 # Override temperature + ) + + # Assert + assert isinstance(result, LlmResult) + assert result.text == "Response with both overrides" + + # Verify Azure OpenAI API was called with both overrides + call_args = mock_azure_client.chat.completions.create.call_args + assert call_args[1]['model'] == 'gpt-4o-mini' # Should use runtime override + assert call_args[1]['temperature'] == 0.9 # Should use runtime override + if __name__ == '__main__': pytest.main([__file__]) \ No newline at end of file diff --git a/tests/unit/test_text_completion/test_azure_processor.py b/tests/unit/test_text_completion/test_azure_processor.py index a1e2ba75..529a12ab 100644 --- a/tests/unit/test_text_completion/test_azure_processor.py +++ b/tests/unit/test_text_completion/test_azure_processor.py @@ -459,5 +459,150 @@ class TestAzureProcessorSimple(IsolatedAsyncioTestCase): ) + @patch('trustgraph.model.text_completion.azure.llm.requests') + @patch('trustgraph.base.async_processor.AsyncProcessor.__init__') + @patch('trustgraph.base.llm_service.LlmService.__init__') + async def test_generate_content_with_model_override(self, mock_llm_init, mock_async_init, mock_requests): + """Test generate_content with model parameter override""" + # Arrange + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = { + 'choices': [{ + 'message': { + 'content': 'Response with model override' + } + }], + 'usage': { + 'prompt_tokens': 15, + 'completion_tokens': 10 + } + } + mock_requests.post.return_value = mock_response + + mock_async_init.return_value = None + mock_llm_init.return_value = None + + config = { + 'endpoint': 'https://test.inference.ai.azure.com/v1/chat/completions', + 'token': 'test-token', + 'temperature': 0.0, + 'max_output': 4192, + 'concurrency': 1, + 'taskgroup': AsyncMock(), + 'id': 'test-processor' + } + + processor = Processor(**config) + + # Act - Override model + result = await processor.generate_content("System", "Prompt", model="custom-azure-model") + + # Assert + assert result.model == "custom-azure-model" # Should use overridden model + assert result.text == "Response with model override" + + @patch('trustgraph.model.text_completion.azure.llm.requests') + @patch('trustgraph.base.async_processor.AsyncProcessor.__init__') + @patch('trustgraph.base.llm_service.LlmService.__init__') + async def test_generate_content_with_temperature_override(self, mock_llm_init, mock_async_init, mock_requests): + """Test generate_content with temperature parameter override""" + # Arrange + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = { + 'choices': [{ + 'message': { + 'content': 'Response with temperature override' + } + }], + 'usage': { + 'prompt_tokens': 15, + 'completion_tokens': 10 + } + } + mock_requests.post.return_value = mock_response + + mock_async_init.return_value = None + mock_llm_init.return_value = None + + config = { + 'endpoint': 'https://test.inference.ai.azure.com/v1/chat/completions', + 'token': 
'test-token', + 'temperature': 0.0, # Default temperature + 'max_output': 4192, + 'concurrency': 1, + 'taskgroup': AsyncMock(), + 'id': 'test-processor' + } + + processor = Processor(**config) + + # Act - Override temperature + result = await processor.generate_content("System", "Prompt", temperature=0.8) + + # Assert + assert result.text == "Response with temperature override" + + # Verify the request was made with the overridden temperature + mock_requests.post.assert_called_once() + call_args = mock_requests.post.call_args + + import json + request_body = json.loads(call_args[1]['data']) + assert request_body['temperature'] == 0.8 + + @patch('trustgraph.model.text_completion.azure.llm.requests') + @patch('trustgraph.base.async_processor.AsyncProcessor.__init__') + @patch('trustgraph.base.llm_service.LlmService.__init__') + async def test_generate_content_with_both_parameters_override(self, mock_llm_init, mock_async_init, mock_requests): + """Test generate_content with both model and temperature overrides""" + # Arrange + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = { + 'choices': [{ + 'message': { + 'content': 'Response with both parameters override' + } + }], + 'usage': { + 'prompt_tokens': 18, + 'completion_tokens': 12 + } + } + mock_requests.post.return_value = mock_response + + mock_async_init.return_value = None + mock_llm_init.return_value = None + + config = { + 'endpoint': 'https://test.inference.ai.azure.com/v1/chat/completions', + 'token': 'test-token', + 'temperature': 0.0, + 'max_output': 4192, + 'concurrency': 1, + 'taskgroup': AsyncMock(), + 'id': 'test-processor' + } + + processor = Processor(**config) + + # Act - Override both parameters + result = await processor.generate_content("System", "Prompt", model="override-model", temperature=0.9) + + # Assert + assert result.model == "override-model" + assert result.text == "Response with both parameters override" + + # Verify the request was made with overridden temperature + mock_requests.post.assert_called_once() + call_args = mock_requests.post.call_args + + import json + request_body = json.loads(call_args[1]['data']) + assert request_body['temperature'] == 0.9 + + if __name__ == '__main__': pytest.main([__file__]) \ No newline at end of file diff --git a/tests/unit/test_text_completion/test_bedrock_processor.py b/tests/unit/test_text_completion/test_bedrock_processor.py new file mode 100644 index 00000000..b8ead9a6 --- /dev/null +++ b/tests/unit/test_text_completion/test_bedrock_processor.py @@ -0,0 +1,280 @@ +""" +Unit tests for trustgraph.model.text_completion.bedrock +Following the same successful pattern as other processor tests +""" + +import pytest +from unittest.mock import AsyncMock, MagicMock, patch +from unittest import IsolatedAsyncioTestCase +import json + +# Import the service under test +from trustgraph.model.text_completion.bedrock.llm import Processor, Mistral, Anthropic +from trustgraph.base import LlmResult + + +class TestBedrockProcessorSimple(IsolatedAsyncioTestCase): + """Test Bedrock processor functionality""" + + @patch('trustgraph.model.text_completion.bedrock.llm.boto3.Session') + @patch('trustgraph.base.async_processor.AsyncProcessor.__init__') + @patch('trustgraph.base.llm_service.LlmService.__init__') + async def test_processor_initialization_basic(self, mock_llm_init, mock_async_init, mock_session_class): + """Test basic processor initialization""" + # Arrange + mock_session = MagicMock() + mock_bedrock = MagicMock() + 
mock_session.client.return_value = mock_bedrock + mock_session_class.return_value = mock_session + + mock_async_init.return_value = None + mock_llm_init.return_value = None + + config = { + 'model': 'mistral.mistral-large-2407-v1:0', + 'temperature': 0.1, + 'concurrency': 1, + 'taskgroup': AsyncMock(), + 'id': 'test-processor' + } + + # Act + processor = Processor(**config) + + # Assert + assert processor.default_model == 'mistral.mistral-large-2407-v1:0' + assert processor.temperature == 0.1 + assert hasattr(processor, 'bedrock') + mock_session_class.assert_called_once() + + @patch('trustgraph.model.text_completion.bedrock.llm.boto3.Session') + @patch('trustgraph.base.async_processor.AsyncProcessor.__init__') + @patch('trustgraph.base.llm_service.LlmService.__init__') + async def test_generate_content_success_mistral(self, mock_llm_init, mock_async_init, mock_session_class): + """Test successful content generation with Mistral model""" + # Arrange + mock_session = MagicMock() + mock_bedrock = MagicMock() + mock_session.client.return_value = mock_bedrock + mock_session_class.return_value = mock_session + + mock_response = { + 'body': MagicMock(), + 'ResponseMetadata': { + 'HTTPHeaders': { + 'x-amzn-bedrock-input-token-count': '15', + 'x-amzn-bedrock-output-token-count': '8' + } + } + } + mock_response['body'].read.return_value = json.dumps({ + 'outputs': [{'text': 'Generated response from Bedrock'}] + }) + mock_bedrock.invoke_model.return_value = mock_response + + mock_async_init.return_value = None + mock_llm_init.return_value = None + + config = { + 'model': 'mistral.mistral-large-2407-v1:0', + 'temperature': 0.0, + 'concurrency': 1, + 'taskgroup': AsyncMock(), + 'id': 'test-processor' + } + + processor = Processor(**config) + + # Act + result = await processor.generate_content("System prompt", "User prompt") + + # Assert + assert isinstance(result, LlmResult) + assert result.text == "Generated response from Bedrock" + assert result.in_token == 15 + assert result.out_token == 8 + assert result.model == 'mistral.mistral-large-2407-v1:0' + mock_bedrock.invoke_model.assert_called_once() + + @patch('trustgraph.model.text_completion.bedrock.llm.boto3.Session') + @patch('trustgraph.base.async_processor.AsyncProcessor.__init__') + @patch('trustgraph.base.llm_service.LlmService.__init__') + async def test_generate_content_temperature_override(self, mock_llm_init, mock_async_init, mock_session_class): + """Test temperature parameter override functionality""" + # Arrange + mock_session = MagicMock() + mock_bedrock = MagicMock() + mock_session.client.return_value = mock_bedrock + mock_session_class.return_value = mock_session + + mock_response = { + 'body': MagicMock(), + 'ResponseMetadata': { + 'HTTPHeaders': { + 'x-amzn-bedrock-input-token-count': '20', + 'x-amzn-bedrock-output-token-count': '12' + } + } + } + mock_response['body'].read.return_value = json.dumps({ + 'outputs': [{'text': 'Response with custom temperature'}] + }) + mock_bedrock.invoke_model.return_value = mock_response + + mock_async_init.return_value = None + mock_llm_init.return_value = None + + config = { + 'model': 'mistral.mistral-large-2407-v1:0', + 'temperature': 0.0, # Default temperature + 'concurrency': 1, + 'taskgroup': AsyncMock(), + 'id': 'test-processor' + } + + processor = Processor(**config) + + # Act - Override temperature at runtime + result = await processor.generate_content( + "System prompt", + "User prompt", + model=None, # Use default model + temperature=0.8 # Override temperature + ) + + # Assert + assert 
isinstance(result, LlmResult) + assert result.text == "Response with custom temperature" + + # Verify the model variant was created with overridden temperature + # The cache key should include the temperature + cache_key = f"mistral.mistral-large-2407-v1:0:0.8" + assert cache_key in processor.model_variants + variant = processor.model_variants[cache_key] + assert variant.temperature == 0.8 + + @patch('trustgraph.model.text_completion.bedrock.llm.boto3.Session') + @patch('trustgraph.base.async_processor.AsyncProcessor.__init__') + @patch('trustgraph.base.llm_service.LlmService.__init__') + async def test_generate_content_model_override(self, mock_llm_init, mock_async_init, mock_session_class): + """Test model parameter override functionality""" + # Arrange + mock_session = MagicMock() + mock_bedrock = MagicMock() + mock_session.client.return_value = mock_bedrock + mock_session_class.return_value = mock_session + + mock_response = { + 'body': MagicMock(), + 'ResponseMetadata': { + 'HTTPHeaders': { + 'x-amzn-bedrock-input-token-count': '18', + 'x-amzn-bedrock-output-token-count': '14' + } + } + } + mock_response['body'].read.return_value = json.dumps({ + 'content': [{'text': 'Response with custom model'}] + }) + mock_bedrock.invoke_model.return_value = mock_response + + mock_async_init.return_value = None + mock_llm_init.return_value = None + + config = { + 'model': 'mistral.mistral-large-2407-v1:0', # Default model + 'temperature': 0.1, # Default temperature + 'concurrency': 1, + 'taskgroup': AsyncMock(), + 'id': 'test-processor' + } + + processor = Processor(**config) + + # Act - Override model at runtime + result = await processor.generate_content( + "System prompt", + "User prompt", + model="anthropic.claude-3-sonnet-20240229-v1:0", # Override model + temperature=None # Use default temperature + ) + + # Assert + assert isinstance(result, LlmResult) + assert result.text == "Response with custom model" + + # Verify Bedrock API was called with overridden model + mock_bedrock.invoke_model.assert_called_once() + call_args = mock_bedrock.invoke_model.call_args + assert call_args[1]['modelId'] == "anthropic.claude-3-sonnet-20240229-v1:0" + + # Verify the correct model variant (Anthropic) was used + cache_key = f"anthropic.claude-3-sonnet-20240229-v1:0:0.1" + assert cache_key in processor.model_variants + variant = processor.model_variants[cache_key] + assert isinstance(variant, Anthropic) + + @patch('trustgraph.model.text_completion.bedrock.llm.boto3.Session') + @patch('trustgraph.base.async_processor.AsyncProcessor.__init__') + @patch('trustgraph.base.llm_service.LlmService.__init__') + async def test_generate_content_both_parameters_override(self, mock_llm_init, mock_async_init, mock_session_class): + """Test overriding both model and temperature parameters simultaneously""" + # Arrange + mock_session = MagicMock() + mock_bedrock = MagicMock() + mock_session.client.return_value = mock_bedrock + mock_session_class.return_value = mock_session + + mock_response = { + 'body': MagicMock(), + 'ResponseMetadata': { + 'HTTPHeaders': { + 'x-amzn-bedrock-input-token-count': '22', + 'x-amzn-bedrock-output-token-count': '16' + } + } + } + mock_response['body'].read.return_value = json.dumps({ + 'generation': 'Response with both overrides' + }) + mock_bedrock.invoke_model.return_value = mock_response + + mock_async_init.return_value = None + mock_llm_init.return_value = None + + config = { + 'model': 'mistral.mistral-large-2407-v1:0', # Default model + 'temperature': 0.0, # Default temperature + 
'concurrency': 1, + 'taskgroup': AsyncMock(), + 'id': 'test-processor' + } + + processor = Processor(**config) + + # Act - Override both parameters at runtime + result = await processor.generate_content( + "System prompt", + "User prompt", + model="meta.llama3-70b-instruct-v1:0", # Override model (Meta/Llama) + temperature=0.9 # Override temperature + ) + + # Assert + assert isinstance(result, LlmResult) + assert result.text == "Response with both overrides" + + # Verify Bedrock API was called with both overrides + mock_bedrock.invoke_model.assert_called_once() + call_args = mock_bedrock.invoke_model.call_args + assert call_args[1]['modelId'] == "meta.llama3-70b-instruct-v1:0" + + # Verify the correct model variant (Meta) was used with correct temperature + cache_key = f"meta.llama3-70b-instruct-v1:0:0.9" + assert cache_key in processor.model_variants + variant = processor.model_variants[cache_key] + assert variant.temperature == 0.9 + + +if __name__ == '__main__': + pytest.main([__file__]) \ No newline at end of file diff --git a/tests/unit/test_text_completion/test_claude_processor.py b/tests/unit/test_text_completion/test_claude_processor.py index 12951019..dd9d9b6a 100644 --- a/tests/unit/test_text_completion/test_claude_processor.py +++ b/tests/unit/test_text_completion/test_claude_processor.py @@ -435,6 +435,156 @@ class TestClaudeProcessorSimple(IsolatedAsyncioTestCase): assert processor.claude == mock_claude_client assert processor.default_model == 'claude-3-opus-20240229' + @patch('trustgraph.model.text_completion.claude.llm.anthropic.Anthropic') + @patch('trustgraph.base.async_processor.AsyncProcessor.__init__') + @patch('trustgraph.base.llm_service.LlmService.__init__') + async def test_generate_content_temperature_override(self, mock_llm_init, mock_async_init, mock_anthropic_class): + """Test temperature parameter override functionality""" + # Arrange + mock_claude_client = MagicMock() + mock_response = MagicMock() + mock_response.content = [MagicMock()] + mock_response.content[0].text = "Response with custom temperature" + mock_response.usage.input_tokens = 20 + mock_response.usage.output_tokens = 12 + + mock_claude_client.messages.create.return_value = mock_response + mock_anthropic_class.return_value = mock_claude_client + + mock_async_init.return_value = None + mock_llm_init.return_value = None + + config = { + 'model': 'claude-3-5-sonnet-20240620', + 'api_key': 'test-api-key', + 'temperature': 0.0, # Default temperature + 'max_output': 8192, + 'concurrency': 1, + 'taskgroup': AsyncMock(), + 'id': 'test-processor' + } + + processor = Processor(**config) + + # Act - Override temperature at runtime + result = await processor.generate_content( + "System prompt", + "User prompt", + model=None, # Use default model + temperature=0.9 # Override temperature + ) + + # Assert + assert isinstance(result, LlmResult) + assert result.text == "Response with custom temperature" + + # Verify Claude API was called with overridden temperature + mock_claude_client.messages.create.assert_called_once() + call_kwargs = mock_claude_client.messages.create.call_args.kwargs + + assert call_kwargs['temperature'] == 0.9 # Should use runtime override + assert call_kwargs['model'] == 'claude-3-5-sonnet-20240620' # Should use processor default + + @patch('trustgraph.model.text_completion.claude.llm.anthropic.Anthropic') + @patch('trustgraph.base.async_processor.AsyncProcessor.__init__') + @patch('trustgraph.base.llm_service.LlmService.__init__') + async def test_generate_content_model_override(self, 
mock_llm_init, mock_async_init, mock_anthropic_class): + """Test model parameter override functionality""" + # Arrange + mock_claude_client = MagicMock() + mock_response = MagicMock() + mock_response.content = [MagicMock()] + mock_response.content[0].text = "Response with custom model" + mock_response.usage.input_tokens = 18 + mock_response.usage.output_tokens = 14 + + mock_claude_client.messages.create.return_value = mock_response + mock_anthropic_class.return_value = mock_claude_client + + mock_async_init.return_value = None + mock_llm_init.return_value = None + + config = { + 'model': 'claude-3-5-sonnet-20240620', # Default model + 'api_key': 'test-api-key', + 'temperature': 0.2, # Default temperature + 'max_output': 8192, + 'concurrency': 1, + 'taskgroup': AsyncMock(), + 'id': 'test-processor' + } + + processor = Processor(**config) + + # Act - Override model at runtime + result = await processor.generate_content( + "System prompt", + "User prompt", + model="claude-3-haiku-20240307", # Override model + temperature=None # Use default temperature + ) + + # Assert + assert isinstance(result, LlmResult) + assert result.text == "Response with custom model" + + # Verify Claude API was called with overridden model + mock_claude_client.messages.create.assert_called_once() + call_kwargs = mock_claude_client.messages.create.call_args.kwargs + + assert call_kwargs['model'] == 'claude-3-haiku-20240307' # Should use runtime override + assert call_kwargs['temperature'] == 0.2 # Should use processor default + + @patch('trustgraph.model.text_completion.claude.llm.anthropic.Anthropic') + @patch('trustgraph.base.async_processor.AsyncProcessor.__init__') + @patch('trustgraph.base.llm_service.LlmService.__init__') + async def test_generate_content_both_parameters_override(self, mock_llm_init, mock_async_init, mock_anthropic_class): + """Test overriding both model and temperature parameters simultaneously""" + # Arrange + mock_claude_client = MagicMock() + mock_response = MagicMock() + mock_response.content = [MagicMock()] + mock_response.content[0].text = "Response with both overrides" + mock_response.usage.input_tokens = 22 + mock_response.usage.output_tokens = 16 + + mock_claude_client.messages.create.return_value = mock_response + mock_anthropic_class.return_value = mock_claude_client + + mock_async_init.return_value = None + mock_llm_init.return_value = None + + config = { + 'model': 'claude-3-5-sonnet-20240620', # Default model + 'api_key': 'test-api-key', + 'temperature': 0.0, # Default temperature + 'max_output': 8192, + 'concurrency': 1, + 'taskgroup': AsyncMock(), + 'id': 'test-processor' + } + + processor = Processor(**config) + + # Act - Override both parameters at runtime + result = await processor.generate_content( + "System prompt", + "User prompt", + model="claude-3-opus-20240229", # Override model + temperature=0.8 # Override temperature + ) + + # Assert + assert isinstance(result, LlmResult) + assert result.text == "Response with both overrides" + + # Verify Claude API was called with both overrides + mock_claude_client.messages.create.assert_called_once() + call_kwargs = mock_claude_client.messages.create.call_args.kwargs + + assert call_kwargs['model'] == 'claude-3-opus-20240229' # Should use runtime override + assert call_kwargs['temperature'] == 0.8 # Should use runtime override + if __name__ == '__main__': pytest.main([__file__]) \ No newline at end of file diff --git a/tests/unit/test_text_completion/test_cohere_processor.py b/tests/unit/test_text_completion/test_cohere_processor.py 
index 9e4397bc..6201f95c 100644 --- a/tests/unit/test_text_completion/test_cohere_processor.py +++ b/tests/unit/test_text_completion/test_cohere_processor.py @@ -442,6 +442,162 @@ class TestCohereProcessorSimple(IsolatedAsyncioTestCase): assert call_args[1]['prompt_truncation'] == 'auto' assert call_args[1]['connectors'] == [] + @patch('trustgraph.model.text_completion.cohere.llm.cohere.Client') + @patch('trustgraph.base.async_processor.AsyncProcessor.__init__') + @patch('trustgraph.base.llm_service.LlmService.__init__') + async def test_generate_content_temperature_override(self, mock_llm_init, mock_async_init, mock_cohere_class): + """Test temperature parameter override functionality""" + # Arrange + mock_cohere_client = MagicMock() + mock_output = MagicMock() + mock_output.text = 'Response with custom temperature' + mock_output.meta.billed_units.input_tokens = 20 + mock_output.meta.billed_units.output_tokens = 12 + + mock_cohere_client.chat.return_value = mock_output + mock_cohere_class.return_value = mock_cohere_client + + mock_async_init.return_value = None + mock_llm_init.return_value = None + + config = { + 'model': 'c4ai-aya-23-8b', + 'api_key': 'test-api-key', + 'temperature': 0.0, # Default temperature + 'concurrency': 1, + 'taskgroup': AsyncMock(), + 'id': 'test-processor' + } + + processor = Processor(**config) + + # Act - Override temperature at runtime + result = await processor.generate_content( + "System prompt", + "User prompt", + model=None, # Use default model + temperature=0.8 # Override temperature + ) + + # Assert + assert isinstance(result, LlmResult) + assert result.text == "Response with custom temperature" + + # Verify Cohere API was called with overridden temperature + mock_cohere_client.chat.assert_called_once_with( + model='c4ai-aya-23-8b', + message='User prompt', + preamble='System prompt', + temperature=0.8, # Should use runtime override + chat_history=[], + prompt_truncation='auto', + connectors=[] + ) + + @patch('trustgraph.model.text_completion.cohere.llm.cohere.Client') + @patch('trustgraph.base.async_processor.AsyncProcessor.__init__') + @patch('trustgraph.base.llm_service.LlmService.__init__') + async def test_generate_content_model_override(self, mock_llm_init, mock_async_init, mock_cohere_class): + """Test model parameter override functionality""" + # Arrange + mock_cohere_client = MagicMock() + mock_output = MagicMock() + mock_output.text = 'Response with custom model' + mock_output.meta.billed_units.input_tokens = 18 + mock_output.meta.billed_units.output_tokens = 14 + + mock_cohere_client.chat.return_value = mock_output + mock_cohere_class.return_value = mock_cohere_client + + mock_async_init.return_value = None + mock_llm_init.return_value = None + + config = { + 'model': 'c4ai-aya-23-8b', # Default model + 'api_key': 'test-api-key', + 'temperature': 0.1, # Default temperature + 'concurrency': 1, + 'taskgroup': AsyncMock(), + 'id': 'test-processor' + } + + processor = Processor(**config) + + # Act - Override model at runtime + result = await processor.generate_content( + "System prompt", + "User prompt", + model="command-r-plus", # Override model + temperature=None # Use default temperature + ) + + # Assert + assert isinstance(result, LlmResult) + assert result.text == "Response with custom model" + + # Verify Cohere API was called with overridden model + mock_cohere_client.chat.assert_called_once_with( + model='command-r-plus', # Should use runtime override + message='User prompt', + preamble='System prompt', + temperature=0.1, # Should use 
processor default + chat_history=[], + prompt_truncation='auto', + connectors=[] + ) + + @patch('trustgraph.model.text_completion.cohere.llm.cohere.Client') + @patch('trustgraph.base.async_processor.AsyncProcessor.__init__') + @patch('trustgraph.base.llm_service.LlmService.__init__') + async def test_generate_content_both_parameters_override(self, mock_llm_init, mock_async_init, mock_cohere_class): + """Test overriding both model and temperature parameters simultaneously""" + # Arrange + mock_cohere_client = MagicMock() + mock_output = MagicMock() + mock_output.text = 'Response with both overrides' + mock_output.meta.billed_units.input_tokens = 22 + mock_output.meta.billed_units.output_tokens = 16 + + mock_cohere_client.chat.return_value = mock_output + mock_cohere_class.return_value = mock_cohere_client + + mock_async_init.return_value = None + mock_llm_init.return_value = None + + config = { + 'model': 'c4ai-aya-23-8b', # Default model + 'api_key': 'test-api-key', + 'temperature': 0.0, # Default temperature + 'concurrency': 1, + 'taskgroup': AsyncMock(), + 'id': 'test-processor' + } + + processor = Processor(**config) + + # Act - Override both parameters at runtime + result = await processor.generate_content( + "System prompt", + "User prompt", + model="command-r", # Override model + temperature=0.9 # Override temperature + ) + + # Assert + assert isinstance(result, LlmResult) + assert result.text == "Response with both overrides" + + # Verify Cohere API was called with both overrides + mock_cohere_client.chat.assert_called_once_with( + model='command-r', # Should use runtime override + message='User prompt', + preamble='System prompt', + temperature=0.9, # Should use runtime override + chat_history=[], + prompt_truncation='auto', + connectors=[] + ) + if __name__ == '__main__': pytest.main([__file__]) \ No newline at end of file diff --git a/tests/unit/test_text_completion/test_googleaistudio_processor.py b/tests/unit/test_text_completion/test_googleaistudio_processor.py index f31715d2..c54b3928 100644 --- a/tests/unit/test_text_completion/test_googleaistudio_processor.py +++ b/tests/unit/test_text_completion/test_googleaistudio_processor.py @@ -477,6 +477,156 @@ class TestGoogleAIStudioProcessorSimple(IsolatedAsyncioTestCase): # The system instruction should be in the config object assert call_args[1]['contents'] == "Explain quantum computing" + @patch('trustgraph.model.text_completion.googleaistudio.llm.genai.Client') + @patch('trustgraph.base.async_processor.AsyncProcessor.__init__') + @patch('trustgraph.base.llm_service.LlmService.__init__') + async def test_generate_content_temperature_override(self, mock_llm_init, mock_async_init, mock_genai_class): + """Test temperature parameter override functionality""" + # Arrange + mock_genai_client = MagicMock() + mock_response = MagicMock() + mock_response.text = 'Response with custom temperature' + mock_response.usage_metadata.prompt_token_count = 20 + mock_response.usage_metadata.candidates_token_count = 12 + + mock_genai_client.models.generate_content.return_value = mock_response + mock_genai_class.return_value = mock_genai_client + + mock_async_init.return_value = None + mock_llm_init.return_value = None + + config = { + 'model': 'gemini-2.0-flash-001', + 'api_key': 'test-api-key', + 'temperature': 0.0, # Default temperature + 'max_output': 8192, + 'concurrency': 1, + 'taskgroup': AsyncMock(), + 'id': 'test-processor' + } + + processor = Processor(**config) + + # Act - Override temperature at runtime + result = await 
processor.generate_content( + "System prompt", + "User prompt", + model=None, # Use default model + temperature=0.8 # Override temperature + ) + + # Assert + assert isinstance(result, LlmResult) + assert result.text == "Response with custom temperature" + + # Verify the generation config was created with overridden temperature + cache_key = f"gemini-2.0-flash-001:0.8" + assert cache_key in processor.generation_configs + config_obj = processor.generation_configs[cache_key] + assert config_obj.temperature == 0.8 + + @patch('trustgraph.model.text_completion.googleaistudio.llm.genai.Client') + @patch('trustgraph.base.async_processor.AsyncProcessor.__init__') + @patch('trustgraph.base.llm_service.LlmService.__init__') + async def test_generate_content_model_override(self, mock_llm_init, mock_async_init, mock_genai_class): + """Test model parameter override functionality""" + # Arrange + mock_genai_client = MagicMock() + mock_response = MagicMock() + mock_response.text = 'Response with custom model' + mock_response.usage_metadata.prompt_token_count = 18 + mock_response.usage_metadata.candidates_token_count = 14 + + mock_genai_client.models.generate_content.return_value = mock_response + mock_genai_class.return_value = mock_genai_client + + mock_async_init.return_value = None + mock_llm_init.return_value = None + + config = { + 'model': 'gemini-2.0-flash-001', # Default model + 'api_key': 'test-api-key', + 'temperature': 0.1, # Default temperature + 'max_output': 8192, + 'concurrency': 1, + 'taskgroup': AsyncMock(), + 'id': 'test-processor' + } + + processor = Processor(**config) + + # Act - Override model at runtime + result = await processor.generate_content( + "System prompt", + "User prompt", + model="gemini-1.5-pro", # Override model + temperature=None # Use default temperature + ) + + # Assert + assert isinstance(result, LlmResult) + assert result.text == "Response with custom model" + + # Verify Google AI Studio API was called with overridden model + call_args = mock_genai_client.models.generate_content.call_args + assert call_args[1]['model'] == 'gemini-1.5-pro' # Should use runtime override + + # Verify the generation config was created for the correct model + cache_key = f"gemini-1.5-pro:0.1" + assert cache_key in processor.generation_configs + + @patch('trustgraph.model.text_completion.googleaistudio.llm.genai.Client') + @patch('trustgraph.base.async_processor.AsyncProcessor.__init__') + @patch('trustgraph.base.llm_service.LlmService.__init__') + async def test_generate_content_both_parameters_override(self, mock_llm_init, mock_async_init, mock_genai_class): + """Test overriding both model and temperature parameters simultaneously""" + # Arrange + mock_genai_client = MagicMock() + mock_response = MagicMock() + mock_response.text = 'Response with both overrides' + mock_response.usage_metadata.prompt_token_count = 22 + mock_response.usage_metadata.candidates_token_count = 16 + + mock_genai_client.models.generate_content.return_value = mock_response + mock_genai_class.return_value = mock_genai_client + + mock_async_init.return_value = None + mock_llm_init.return_value = None + + config = { + 'model': 'gemini-2.0-flash-001', # Default model + 'api_key': 'test-api-key', + 'temperature': 0.0, # Default temperature + 'max_output': 8192, + 'concurrency': 1, + 'taskgroup': AsyncMock(), + 'id': 'test-processor' + } + + processor = Processor(**config) + + # Act - Override both parameters at runtime + result = await processor.generate_content( + "System prompt", + "User prompt", + 
model="gemini-1.5-flash", # Override model + temperature=0.9 # Override temperature + ) + + # Assert + assert isinstance(result, LlmResult) + assert result.text == "Response with both overrides" + + # Verify Google AI Studio API was called with both overrides + call_args = mock_genai_client.models.generate_content.call_args + assert call_args[1]['model'] == 'gemini-1.5-flash' # Should use runtime override + + # Verify the generation config was created with both overrides + cache_key = f"gemini-1.5-flash:0.9" + assert cache_key in processor.generation_configs + config_obj = processor.generation_configs[cache_key] + assert config_obj.temperature == 0.9 + if __name__ == '__main__': pytest.main([__file__]) \ No newline at end of file diff --git a/tests/unit/test_text_completion/test_llamafile_processor.py b/tests/unit/test_text_completion/test_llamafile_processor.py index 425cfce8..410b3ff2 100644 --- a/tests/unit/test_text_completion/test_llamafile_processor.py +++ b/tests/unit/test_text_completion/test_llamafile_processor.py @@ -458,5 +458,132 @@ class TestLlamaFileProcessorSimple(IsolatedAsyncioTestCase): # No specific rate limit error handling tested since SLM presumably has no rate limits + @patch('trustgraph.model.text_completion.llamafile.llm.OpenAI') + @patch('trustgraph.base.async_processor.AsyncProcessor.__init__') + @patch('trustgraph.base.llm_service.LlmService.__init__') + async def test_generate_content_with_model_override(self, mock_llm_init, mock_async_init, mock_openai_class): + """Test generate_content with model parameter override""" + # Arrange + mock_openai_client = MagicMock() + mock_response = MagicMock() + mock_response.choices = [MagicMock()] + mock_response.choices[0].message.content = "Response from overridden model" + mock_response.usage.prompt_tokens = 15 + mock_response.usage.completion_tokens = 10 + + mock_openai_client.chat.completions.create.return_value = mock_response + mock_openai_class.return_value = mock_openai_client + + mock_async_init.return_value = None + mock_llm_init.return_value = None + + config = { + 'model': 'LLaMA_CPP', + 'llamafile': 'http://localhost:8080/v1', + 'temperature': 0.0, + 'max_output': 4096, + 'concurrency': 1, + 'taskgroup': AsyncMock(), + 'id': 'test-processor' + } + + processor = Processor(**config) + + # Act - Override model + result = await processor.generate_content("System", "Prompt", model="custom-llamafile-model") + + # Assert + assert result.model == "custom-llamafile-model" # Should use overridden model + assert result.text == "Response from overridden model" + + # Verify the API call was made with overridden model + call_args = mock_openai_client.chat.completions.create.call_args + assert call_args[1]['model'] == "custom-llamafile-model" + + @patch('trustgraph.model.text_completion.llamafile.llm.OpenAI') + @patch('trustgraph.base.async_processor.AsyncProcessor.__init__') + @patch('trustgraph.base.llm_service.LlmService.__init__') + async def test_generate_content_with_temperature_override(self, mock_llm_init, mock_async_init, mock_openai_class): + """Test generate_content with temperature parameter override""" + # Arrange + mock_openai_client = MagicMock() + mock_response = MagicMock() + mock_response.choices = [MagicMock()] + mock_response.choices[0].message.content = "Response with temperature override" + mock_response.usage.prompt_tokens = 18 + mock_response.usage.completion_tokens = 12 + + mock_openai_client.chat.completions.create.return_value = mock_response + mock_openai_class.return_value = mock_openai_client + 
+ mock_async_init.return_value = None + mock_llm_init.return_value = None + + config = { + 'model': 'LLaMA_CPP', + 'llamafile': 'http://localhost:8080/v1', + 'temperature': 0.0, # Default temperature + 'max_output': 4096, + 'concurrency': 1, + 'taskgroup': AsyncMock(), + 'id': 'test-processor' + } + + processor = Processor(**config) + + # Act - Override temperature + result = await processor.generate_content("System", "Prompt", temperature=0.7) + + # Assert + assert result.text == "Response with temperature override" + + # Verify the API call was made with overridden temperature + call_args = mock_openai_client.chat.completions.create.call_args + assert call_args[1]['temperature'] == 0.7 + + @patch('trustgraph.model.text_completion.llamafile.llm.OpenAI') + @patch('trustgraph.base.async_processor.AsyncProcessor.__init__') + @patch('trustgraph.base.llm_service.LlmService.__init__') + async def test_generate_content_with_both_parameters_override(self, mock_llm_init, mock_async_init, mock_openai_class): + """Test generate_content with both model and temperature overrides""" + # Arrange + mock_openai_client = MagicMock() + mock_response = MagicMock() + mock_response.choices = [MagicMock()] + mock_response.choices[0].message.content = "Response with both parameters override" + mock_response.usage.prompt_tokens = 20 + mock_response.usage.completion_tokens = 15 + + mock_openai_client.chat.completions.create.return_value = mock_response + mock_openai_class.return_value = mock_openai_client + + mock_async_init.return_value = None + mock_llm_init.return_value = None + + config = { + 'model': 'LLaMA_CPP', + 'llamafile': 'http://localhost:8080/v1', + 'temperature': 0.0, + 'max_output': 4096, + 'concurrency': 1, + 'taskgroup': AsyncMock(), + 'id': 'test-processor' + } + + processor = Processor(**config) + + # Act - Override both parameters + result = await processor.generate_content("System", "Prompt", model="override-model", temperature=0.8) + + # Assert + assert result.model == "override-model" + assert result.text == "Response with both parameters override" + + # Verify the API call was made with overridden parameters + call_args = mock_openai_client.chat.completions.create.call_args + assert call_args[1]['model'] == "override-model" + assert call_args[1]['temperature'] == 0.8 + + if __name__ == '__main__': pytest.main([__file__]) \ No newline at end of file diff --git a/tests/unit/test_text_completion/test_lmstudio_processor.py b/tests/unit/test_text_completion/test_lmstudio_processor.py new file mode 100644 index 00000000..4864151f --- /dev/null +++ b/tests/unit/test_text_completion/test_lmstudio_processor.py @@ -0,0 +1,229 @@ +""" +Unit tests for trustgraph.model.text_completion.lmstudio +Following the same successful pattern as previous tests +""" + +import pytest +from unittest.mock import AsyncMock, MagicMock, patch +from unittest import IsolatedAsyncioTestCase + +# Import the service under test +from trustgraph.model.text_completion.lmstudio.llm import Processor +from trustgraph.base import LlmResult +from trustgraph.exceptions import TooManyRequests + + +class TestLMStudioProcessorSimple(IsolatedAsyncioTestCase): + """Test LMStudio processor functionality""" + + @patch('trustgraph.model.text_completion.lmstudio.llm.OpenAI') + @patch('trustgraph.base.async_processor.AsyncProcessor.__init__') + @patch('trustgraph.base.llm_service.LlmService.__init__') + async def test_processor_initialization_basic(self, mock_llm_init, mock_async_init, mock_openai_class): + """Test basic processor 
initialization""" + # Arrange + mock_openai = MagicMock() + mock_openai_class.return_value = mock_openai + + mock_async_init.return_value = None + mock_llm_init.return_value = None + + config = { + 'model': 'gemma3:9b', + 'url': 'http://localhost:1234/', + 'temperature': 0.0, + 'max_output': 4096, + 'concurrency': 1, + 'taskgroup': AsyncMock(), + 'id': 'test-processor' + } + + # Act + processor = Processor(**config) + + # Assert + assert processor.default_model == 'gemma3:9b' + assert processor.url == 'http://localhost:1234/v1/' + assert processor.temperature == 0.0 + assert processor.max_output == 4096 + assert hasattr(processor, 'openai') + mock_openai_class.assert_called_once_with( + base_url='http://localhost:1234/v1/', + api_key='sk-no-key-required' + ) + + @patch('trustgraph.model.text_completion.lmstudio.llm.OpenAI') + @patch('trustgraph.base.async_processor.AsyncProcessor.__init__') + @patch('trustgraph.base.llm_service.LlmService.__init__') + async def test_generate_content_success(self, mock_llm_init, mock_async_init, mock_openai_class): + """Test successful content generation""" + # Arrange + mock_openai = MagicMock() + mock_response = MagicMock() + mock_response.choices[0].message.content = 'Generated response from LMStudio' + mock_response.usage.prompt_tokens = 20 + mock_response.usage.completion_tokens = 12 + + mock_openai.chat.completions.create.return_value = mock_response + mock_openai_class.return_value = mock_openai + + mock_async_init.return_value = None + mock_llm_init.return_value = None + + config = { + 'model': 'gemma3:9b', + 'url': 'http://localhost:1234/', + 'temperature': 0.0, + 'max_output': 4096, + 'concurrency': 1, + 'taskgroup': AsyncMock(), + 'id': 'test-processor' + } + + processor = Processor(**config) + + # Act + result = await processor.generate_content("System prompt", "User prompt") + + # Assert + assert isinstance(result, LlmResult) + assert result.text == "Generated response from LMStudio" + assert result.in_token == 20 + assert result.out_token == 12 + assert result.model == 'gemma3:9b' + + # Verify the API call was made correctly + mock_openai.chat.completions.create.assert_called_once() + call_args = mock_openai.chat.completions.create.call_args + + # Check model and temperature + assert call_args[1]['model'] == 'gemma3:9b' + assert call_args[1]['temperature'] == 0.0 + assert call_args[1]['max_tokens'] == 4096 + + @patch('trustgraph.model.text_completion.lmstudio.llm.OpenAI') + @patch('trustgraph.base.async_processor.AsyncProcessor.__init__') + @patch('trustgraph.base.llm_service.LlmService.__init__') + async def test_generate_content_with_model_override(self, mock_llm_init, mock_async_init, mock_openai_class): + """Test generate_content with model parameter override""" + # Arrange + mock_openai = MagicMock() + mock_response = MagicMock() + mock_response.choices[0].message.content = 'Response from overridden model' + mock_response.usage.prompt_tokens = 15 + mock_response.usage.completion_tokens = 10 + + mock_openai.chat.completions.create.return_value = mock_response + mock_openai_class.return_value = mock_openai + + mock_async_init.return_value = None + mock_llm_init.return_value = None + + config = { + 'model': 'gemma3:9b', + 'url': 'http://localhost:1234/', + 'temperature': 0.0, + 'max_output': 4096, + 'concurrency': 1, + 'taskgroup': AsyncMock(), + 'id': 'test-processor' + } + + processor = Processor(**config) + + # Act - Override model + result = await processor.generate_content("System", "Prompt", model="custom-lmstudio-model") + + # Assert 
+ assert result.model == "custom-lmstudio-model" # Should use overridden model + assert result.text == "Response from overridden model" + + # Verify the API call was made with overridden model + call_args = mock_openai.chat.completions.create.call_args + assert call_args[1]['model'] == "custom-lmstudio-model" + + @patch('trustgraph.model.text_completion.lmstudio.llm.OpenAI') + @patch('trustgraph.base.async_processor.AsyncProcessor.__init__') + @patch('trustgraph.base.llm_service.LlmService.__init__') + async def test_generate_content_with_temperature_override(self, mock_llm_init, mock_async_init, mock_openai_class): + """Test generate_content with temperature parameter override""" + # Arrange + mock_openai = MagicMock() + mock_response = MagicMock() + mock_response.choices[0].message.content = 'Response with temperature override' + mock_response.usage.prompt_tokens = 18 + mock_response.usage.completion_tokens = 12 + + mock_openai.chat.completions.create.return_value = mock_response + mock_openai_class.return_value = mock_openai + + mock_async_init.return_value = None + mock_llm_init.return_value = None + + config = { + 'model': 'gemma3:9b', + 'url': 'http://localhost:1234/', + 'temperature': 0.0, # Default temperature + 'max_output': 4096, + 'concurrency': 1, + 'taskgroup': AsyncMock(), + 'id': 'test-processor' + } + + processor = Processor(**config) + + # Act - Override temperature + result = await processor.generate_content("System", "Prompt", temperature=0.7) + + # Assert + assert result.text == "Response with temperature override" + + # Verify the API call was made with overridden temperature + call_args = mock_openai.chat.completions.create.call_args + assert call_args[1]['temperature'] == 0.7 + + @patch('trustgraph.model.text_completion.lmstudio.llm.OpenAI') + @patch('trustgraph.base.async_processor.AsyncProcessor.__init__') + @patch('trustgraph.base.llm_service.LlmService.__init__') + async def test_generate_content_with_both_parameters_override(self, mock_llm_init, mock_async_init, mock_openai_class): + """Test generate_content with both model and temperature overrides""" + # Arrange + mock_openai = MagicMock() + mock_response = MagicMock() + mock_response.choices[0].message.content = 'Response with both parameters override' + mock_response.usage.prompt_tokens = 20 + mock_response.usage.completion_tokens = 15 + + mock_openai.chat.completions.create.return_value = mock_response + mock_openai_class.return_value = mock_openai + + mock_async_init.return_value = None + mock_llm_init.return_value = None + + config = { + 'model': 'gemma3:9b', + 'url': 'http://localhost:1234/', + 'temperature': 0.0, + 'max_output': 4096, + 'concurrency': 1, + 'taskgroup': AsyncMock(), + 'id': 'test-processor' + } + + processor = Processor(**config) + + # Act - Override both parameters + result = await processor.generate_content("System", "Prompt", model="override-model", temperature=0.8) + + # Assert + assert result.model == "override-model" + assert result.text == "Response with both parameters override" + + # Verify the API call was made with overridden parameters + call_args = mock_openai.chat.completions.create.call_args + assert call_args[1]['model'] == "override-model" + assert call_args[1]['temperature'] == 0.8 + + +if __name__ == '__main__': + pytest.main([__file__]) \ No newline at end of file diff --git a/tests/unit/test_text_completion/test_mistral_processor.py b/tests/unit/test_text_completion/test_mistral_processor.py new file mode 100644 index 00000000..a40cca70 --- /dev/null +++ 
b/tests/unit/test_text_completion/test_mistral_processor.py @@ -0,0 +1,275 @@ +""" +Unit tests for trustgraph.model.text_completion.mistral +Following the same successful pattern as other processor tests +""" + +import pytest +from unittest.mock import AsyncMock, MagicMock, patch +from unittest import IsolatedAsyncioTestCase + +# Import the service under test +from trustgraph.model.text_completion.mistral.llm import Processor +from trustgraph.base import LlmResult + + +class TestMistralProcessorSimple(IsolatedAsyncioTestCase): + """Test Mistral processor functionality""" + + @patch('trustgraph.model.text_completion.mistral.llm.Mistral') + @patch('trustgraph.base.async_processor.AsyncProcessor.__init__') + @patch('trustgraph.base.llm_service.LlmService.__init__') + async def test_processor_initialization_basic(self, mock_llm_init, mock_async_init, mock_mistral_class): + """Test basic processor initialization""" + # Arrange + mock_mistral_client = MagicMock() + mock_mistral_class.return_value = mock_mistral_client + + mock_async_init.return_value = None + mock_llm_init.return_value = None + + config = { + 'model': 'ministral-8b-latest', + 'api_key': 'test-api-key', + 'temperature': 0.1, + 'max_output': 2048, + 'concurrency': 1, + 'taskgroup': AsyncMock(), + 'id': 'test-processor' + } + + # Act + processor = Processor(**config) + + # Assert + assert processor.default_model == 'ministral-8b-latest' + assert processor.temperature == 0.1 + assert processor.max_output == 2048 + assert hasattr(processor, 'mistral') + mock_mistral_class.assert_called_once_with(api_key='test-api-key') + + @patch('trustgraph.model.text_completion.mistral.llm.Mistral') + @patch('trustgraph.base.async_processor.AsyncProcessor.__init__') + @patch('trustgraph.base.llm_service.LlmService.__init__') + async def test_generate_content_success(self, mock_llm_init, mock_async_init, mock_mistral_class): + """Test successful content generation""" + # Arrange + mock_mistral_client = MagicMock() + mock_response = MagicMock() + mock_response.choices[0].message.content = 'Generated response from Mistral' + mock_response.usage.prompt_tokens = 15 + mock_response.usage.completion_tokens = 8 + mock_mistral_client.chat.complete.return_value = mock_response + mock_mistral_class.return_value = mock_mistral_client + + mock_async_init.return_value = None + mock_llm_init.return_value = None + + config = { + 'model': 'ministral-8b-latest', + 'api_key': 'test-api-key', + 'temperature': 0.0, + 'max_output': 4096, + 'concurrency': 1, + 'taskgroup': AsyncMock(), + 'id': 'test-processor' + } + + processor = Processor(**config) + + # Act + result = await processor.generate_content("System prompt", "User prompt") + + # Assert + assert isinstance(result, LlmResult) + assert result.text == "Generated response from Mistral" + assert result.in_token == 15 + assert result.out_token == 8 + assert result.model == 'ministral-8b-latest' + mock_mistral_client.chat.complete.assert_called_once() + + @patch('trustgraph.model.text_completion.mistral.llm.Mistral') + @patch('trustgraph.base.async_processor.AsyncProcessor.__init__') + @patch('trustgraph.base.llm_service.LlmService.__init__') + async def test_generate_content_temperature_override(self, mock_llm_init, mock_async_init, mock_mistral_class): + """Test temperature parameter override functionality""" + # Arrange + mock_mistral_client = MagicMock() + mock_response = MagicMock() + mock_response.choices[0].message.content = 'Response with custom temperature' + mock_response.usage.prompt_tokens = 20 + 
mock_response.usage.completion_tokens = 12 + mock_mistral_client.chat.complete.return_value = mock_response + mock_mistral_class.return_value = mock_mistral_client + + mock_async_init.return_value = None + mock_llm_init.return_value = None + + config = { + 'model': 'ministral-8b-latest', + 'api_key': 'test-api-key', + 'temperature': 0.0, # Default temperature + 'max_output': 4096, + 'concurrency': 1, + 'taskgroup': AsyncMock(), + 'id': 'test-processor' + } + + processor = Processor(**config) + + # Act - Override temperature at runtime + result = await processor.generate_content( + "System prompt", + "User prompt", + model=None, # Use default model + temperature=0.8 # Override temperature + ) + + # Assert + assert isinstance(result, LlmResult) + assert result.text == "Response with custom temperature" + + # Verify Mistral API was called with overridden temperature + call_args = mock_mistral_client.chat.complete.call_args + assert call_args[1]['temperature'] == 0.8 # Should use runtime override + assert call_args[1]['model'] == 'ministral-8b-latest' + + @patch('trustgraph.model.text_completion.mistral.llm.Mistral') + @patch('trustgraph.base.async_processor.AsyncProcessor.__init__') + @patch('trustgraph.base.llm_service.LlmService.__init__') + async def test_generate_content_model_override(self, mock_llm_init, mock_async_init, mock_mistral_class): + """Test model parameter override functionality""" + # Arrange + mock_mistral_client = MagicMock() + mock_response = MagicMock() + mock_response.choices[0].message.content = 'Response with custom model' + mock_response.usage.prompt_tokens = 18 + mock_response.usage.completion_tokens = 14 + mock_mistral_client.chat.complete.return_value = mock_response + mock_mistral_class.return_value = mock_mistral_client + + mock_async_init.return_value = None + mock_llm_init.return_value = None + + config = { + 'model': 'ministral-8b-latest', # Default model + 'api_key': 'test-api-key', + 'temperature': 0.1, # Default temperature + 'max_output': 4096, + 'concurrency': 1, + 'taskgroup': AsyncMock(), + 'id': 'test-processor' + } + + processor = Processor(**config) + + # Act - Override model at runtime + result = await processor.generate_content( + "System prompt", + "User prompt", + model="mistral-large-latest", # Override model + temperature=None # Use default temperature + ) + + # Assert + assert isinstance(result, LlmResult) + assert result.text == "Response with custom model" + + # Verify Mistral API was called with overridden model + call_args = mock_mistral_client.chat.complete.call_args + assert call_args[1]['model'] == 'mistral-large-latest' # Should use runtime override + assert call_args[1]['temperature'] == 0.1 # Should use processor default + + @patch('trustgraph.model.text_completion.mistral.llm.Mistral') + @patch('trustgraph.base.async_processor.AsyncProcessor.__init__') + @patch('trustgraph.base.llm_service.LlmService.__init__') + async def test_generate_content_both_parameters_override(self, mock_llm_init, mock_async_init, mock_mistral_class): + """Test overriding both model and temperature parameters simultaneously""" + # Arrange + mock_mistral_client = MagicMock() + mock_response = MagicMock() + mock_response.choices[0].message.content = 'Response with both overrides' + mock_response.usage.prompt_tokens = 22 + mock_response.usage.completion_tokens = 16 + mock_mistral_client.chat.complete.return_value = mock_response + mock_mistral_class.return_value = mock_mistral_client + + mock_async_init.return_value = None + mock_llm_init.return_value = None 
+ + config = { + 'model': 'ministral-8b-latest', # Default model + 'api_key': 'test-api-key', + 'temperature': 0.0, # Default temperature + 'max_output': 4096, + 'concurrency': 1, + 'taskgroup': AsyncMock(), + 'id': 'test-processor' + } + + processor = Processor(**config) + + # Act - Override both parameters at runtime + result = await processor.generate_content( + "System prompt", + "User prompt", + model="mistral-large-latest", # Override model + temperature=0.9 # Override temperature + ) + + # Assert + assert isinstance(result, LlmResult) + assert result.text == "Response with both overrides" + + # Verify Mistral API was called with both overrides + call_args = mock_mistral_client.chat.complete.call_args + assert call_args[1]['model'] == 'mistral-large-latest' # Should use runtime override + assert call_args[1]['temperature'] == 0.9 # Should use runtime override + + @patch('trustgraph.model.text_completion.mistral.llm.Mistral') + @patch('trustgraph.base.async_processor.AsyncProcessor.__init__') + @patch('trustgraph.base.llm_service.LlmService.__init__') + async def test_generate_content_prompt_construction(self, mock_llm_init, mock_async_init, mock_mistral_class): + """Test prompt construction with system and user prompts""" + # Arrange + mock_mistral_client = MagicMock() + mock_response = MagicMock() + mock_response.choices[0].message.content = 'Response with system instructions' + mock_response.usage.prompt_tokens = 25 + mock_response.usage.completion_tokens = 15 + mock_mistral_client.chat.complete.return_value = mock_response + mock_mistral_class.return_value = mock_mistral_client + + mock_async_init.return_value = None + mock_llm_init.return_value = None + + config = { + 'model': 'ministral-8b-latest', + 'api_key': 'test-api-key', + 'temperature': 0.0, + 'max_output': 4096, + 'concurrency': 1, + 'taskgroup': AsyncMock(), + 'id': 'test-processor' + } + + processor = Processor(**config) + + # Act + result = await processor.generate_content("You are a helpful assistant", "What is AI?") + + # Assert + assert result.text == "Response with system instructions" + assert result.in_token == 25 + assert result.out_token == 15 + + # Verify the combined prompt structure + call_args = mock_mistral_client.chat.complete.call_args + messages = call_args[1]['messages'] + assert len(messages) == 1 + assert messages[0]['role'] == 'user' + assert messages[0]['content'][0]['type'] == 'text' + assert messages[0]['content'][0]['text'] == "You are a helpful assistant\n\nWhat is AI?" 
+ + +if __name__ == '__main__': + pytest.main([__file__]) \ No newline at end of file diff --git a/tests/unit/test_text_completion/test_ollama_processor.py b/tests/unit/test_text_completion/test_ollama_processor.py index 138a8598..0bf5e0ab 100644 --- a/tests/unit/test_text_completion/test_ollama_processor.py +++ b/tests/unit/test_text_completion/test_ollama_processor.py @@ -312,6 +312,150 @@ class TestOllamaProcessorSimple(IsolatedAsyncioTestCase): # Verify the combined prompt mock_client.generate.assert_called_once_with('llama2', "You are a helpful assistant\n\nWhat is AI?", options={'temperature': 0.0}) + @patch('trustgraph.model.text_completion.ollama.llm.Client') + @patch('trustgraph.base.async_processor.AsyncProcessor.__init__') + @patch('trustgraph.base.llm_service.LlmService.__init__') + async def test_generate_content_temperature_override(self, mock_llm_init, mock_async_init, mock_client_class): + """Test temperature parameter override functionality""" + # Arrange + mock_client = MagicMock() + mock_response = { + 'response': 'Response with custom temperature', + 'prompt_eval_count': 20, + 'eval_count': 12 + } + mock_client.generate.return_value = mock_response + mock_client_class.return_value = mock_client + + mock_async_init.return_value = None + mock_llm_init.return_value = None + + config = { + 'model': 'llama2', + 'ollama': 'http://localhost:11434', + 'temperature': 0.0, # Default temperature + 'concurrency': 1, + 'taskgroup': AsyncMock(), + 'id': 'test-processor' + } + + processor = Processor(**config) + + # Act - Override temperature at runtime + result = await processor.generate_content( + "System prompt", + "User prompt", + model=None, # Use default model + temperature=0.8 # Override temperature + ) + + # Assert + assert isinstance(result, LlmResult) + assert result.text == "Response with custom temperature" + + # Verify Ollama API was called with overridden temperature + mock_client.generate.assert_called_once_with( + 'llama2', + "System prompt\n\nUser prompt", + options={'temperature': 0.8} # Should use runtime override + ) + + @patch('trustgraph.model.text_completion.ollama.llm.Client') + @patch('trustgraph.base.async_processor.AsyncProcessor.__init__') + @patch('trustgraph.base.llm_service.LlmService.__init__') + async def test_generate_content_model_override(self, mock_llm_init, mock_async_init, mock_client_class): + """Test model parameter override functionality""" + # Arrange + mock_client = MagicMock() + mock_response = { + 'response': 'Response with custom model', + 'prompt_eval_count': 18, + 'eval_count': 14 + } + mock_client.generate.return_value = mock_response + mock_client_class.return_value = mock_client + + mock_async_init.return_value = None + mock_llm_init.return_value = None + + config = { + 'model': 'llama2', # Default model + 'ollama': 'http://localhost:11434', + 'temperature': 0.1, # Default temperature + 'concurrency': 1, + 'taskgroup': AsyncMock(), + 'id': 'test-processor' + } + + processor = Processor(**config) + + # Act - Override model at runtime + result = await processor.generate_content( + "System prompt", + "User prompt", + model="mistral", # Override model + temperature=None # Use default temperature + ) + + # Assert + assert isinstance(result, LlmResult) + assert result.text == "Response with custom model" + + # Verify Ollama API was called with overridden model + mock_client.generate.assert_called_once_with( + 'mistral', # Should use runtime override + "System prompt\n\nUser prompt", + options={'temperature': 0.1} # Should use processor 
default + ) + + @patch('trustgraph.model.text_completion.ollama.llm.Client') + @patch('trustgraph.base.async_processor.AsyncProcessor.__init__') + @patch('trustgraph.base.llm_service.LlmService.__init__') + async def test_generate_content_both_parameters_override(self, mock_llm_init, mock_async_init, mock_client_class): + """Test overriding both model and temperature parameters simultaneously""" + # Arrange + mock_client = MagicMock() + mock_response = { + 'response': 'Response with both overrides', + 'prompt_eval_count': 22, + 'eval_count': 16 + } + mock_client.generate.return_value = mock_response + mock_client_class.return_value = mock_client + + mock_async_init.return_value = None + mock_llm_init.return_value = None + + config = { + 'model': 'llama2', # Default model + 'ollama': 'http://localhost:11434', + 'temperature': 0.0, # Default temperature + 'concurrency': 1, + 'taskgroup': AsyncMock(), + 'id': 'test-processor' + } + + processor = Processor(**config) + + # Act - Override both parameters at runtime + result = await processor.generate_content( + "System prompt", + "User prompt", + model="codellama", # Override model + temperature=0.9 # Override temperature + ) + + # Assert + assert isinstance(result, LlmResult) + assert result.text == "Response with both overrides" + + # Verify Ollama API was called with both overrides + mock_client.generate.assert_called_once_with( + 'codellama', # Should use runtime override + "System prompt\n\nUser prompt", + options={'temperature': 0.9} # Should use runtime override + ) + if __name__ == '__main__': pytest.main([__file__]) \ No newline at end of file diff --git a/tests/unit/test_text_completion/test_openai_processor.py b/tests/unit/test_text_completion/test_openai_processor.py index 56425285..a9a43b37 100644 --- a/tests/unit/test_text_completion/test_openai_processor.py +++ b/tests/unit/test_text_completion/test_openai_processor.py @@ -391,5 +391,210 @@ class TestOpenAIProcessorSimple(IsolatedAsyncioTestCase): assert call_args[1]['response_format'] == {"type": "text"} + @patch('trustgraph.model.text_completion.openai.llm.OpenAI') + @patch('trustgraph.base.async_processor.AsyncProcessor.__init__') + @patch('trustgraph.base.llm_service.LlmService.__init__') + async def test_generate_content_temperature_override(self, mock_llm_init, mock_async_init, mock_openai_class): + """Test temperature parameter override functionality""" + # Arrange + mock_openai_client = MagicMock() + mock_response = MagicMock() + mock_response.choices = [MagicMock()] + mock_response.choices[0].message.content = "Response with custom temperature" + mock_response.usage.prompt_tokens = 15 + mock_response.usage.completion_tokens = 10 + + mock_openai_client.chat.completions.create.return_value = mock_response + mock_openai_class.return_value = mock_openai_client + + mock_async_init.return_value = None + mock_llm_init.return_value = None + + config = { + 'model': 'gpt-3.5-turbo', + 'api_key': 'test-api-key', + 'url': 'https://api.openai.com/v1', + 'temperature': 0.0, # Default temperature + 'max_output': 4096, + 'concurrency': 1, + 'taskgroup': AsyncMock(), + 'id': 'test-processor' + } + + processor = Processor(**config) + + # Act - Override temperature at runtime + result = await processor.generate_content( + "System prompt", + "User prompt", + model=None, # Use default model + temperature=0.9 # Override temperature + ) + + # Assert + assert isinstance(result, LlmResult) + assert result.text == "Response with custom temperature" + + # Verify the OpenAI API was called with 
overridden temperature + mock_openai_client.chat.completions.create.assert_called_once() + call_kwargs = mock_openai_client.chat.completions.create.call_args.kwargs + + assert call_kwargs['temperature'] == 0.9 # Should use runtime override + assert call_kwargs['model'] == 'gpt-3.5-turbo' # Should use processor default + + @patch('trustgraph.model.text_completion.openai.llm.OpenAI') + @patch('trustgraph.base.async_processor.AsyncProcessor.__init__') + @patch('trustgraph.base.llm_service.LlmService.__init__') + async def test_generate_content_model_override(self, mock_llm_init, mock_async_init, mock_openai_class): + """Test model parameter override functionality""" + # Arrange + mock_openai_client = MagicMock() + mock_response = MagicMock() + mock_response.choices = [MagicMock()] + mock_response.choices[0].message.content = "Response with custom model" + mock_response.usage.prompt_tokens = 15 + mock_response.usage.completion_tokens = 10 + + mock_openai_client.chat.completions.create.return_value = mock_response + mock_openai_class.return_value = mock_openai_client + + mock_async_init.return_value = None + mock_llm_init.return_value = None + + config = { + 'model': 'gpt-3.5-turbo', # Default model + 'api_key': 'test-api-key', + 'url': 'https://api.openai.com/v1', + 'temperature': 0.2, + 'max_output': 4096, + 'concurrency': 1, + 'taskgroup': AsyncMock(), + 'id': 'test-processor' + } + + processor = Processor(**config) + + # Act - Override model at runtime + result = await processor.generate_content( + "System prompt", + "User prompt", + model="gpt-4", # Override model + temperature=None # Use default temperature + ) + + # Assert + assert isinstance(result, LlmResult) + assert result.text == "Response with custom model" + + # Verify the OpenAI API was called with overridden model + mock_openai_client.chat.completions.create.assert_called_once() + call_kwargs = mock_openai_client.chat.completions.create.call_args.kwargs + + assert call_kwargs['model'] == 'gpt-4' # Should use runtime override + assert call_kwargs['temperature'] == 0.2 # Should use processor default + + @patch('trustgraph.model.text_completion.openai.llm.OpenAI') + @patch('trustgraph.base.async_processor.AsyncProcessor.__init__') + @patch('trustgraph.base.llm_service.LlmService.__init__') + async def test_generate_content_both_parameters_override(self, mock_llm_init, mock_async_init, mock_openai_class): + """Test overriding both model and temperature parameters simultaneously""" + # Arrange + mock_openai_client = MagicMock() + mock_response = MagicMock() + mock_response.choices = [MagicMock()] + mock_response.choices[0].message.content = "Response with both overrides" + mock_response.usage.prompt_tokens = 15 + mock_response.usage.completion_tokens = 10 + + mock_openai_client.chat.completions.create.return_value = mock_response + mock_openai_class.return_value = mock_openai_client + + mock_async_init.return_value = None + mock_llm_init.return_value = None + + config = { + 'model': 'gpt-3.5-turbo', # Default model + 'api_key': 'test-api-key', + 'url': 'https://api.openai.com/v1', + 'temperature': 0.0, # Default temperature + 'max_output': 4096, + 'concurrency': 1, + 'taskgroup': AsyncMock(), + 'id': 'test-processor' + } + + processor = Processor(**config) + + # Act - Override both parameters at runtime + result = await processor.generate_content( + "System prompt", + "User prompt", + model="gpt-4", # Override model + temperature=0.7 # Override temperature + ) + + # Assert + assert isinstance(result, LlmResult) + assert result.text 
== "Response with both overrides" + + # Verify the OpenAI API was called with both overrides + mock_openai_client.chat.completions.create.assert_called_once() + call_kwargs = mock_openai_client.chat.completions.create.call_args.kwargs + + assert call_kwargs['model'] == 'gpt-4' # Should use runtime override + assert call_kwargs['temperature'] == 0.7 # Should use runtime override + + @patch('trustgraph.model.text_completion.openai.llm.OpenAI') + @patch('trustgraph.base.async_processor.AsyncProcessor.__init__') + @patch('trustgraph.base.llm_service.LlmService.__init__') + async def test_generate_content_no_override_uses_defaults(self, mock_llm_init, mock_async_init, mock_openai_class): + """Test that when no parameters are overridden, processor defaults are used""" + # Arrange + mock_openai_client = MagicMock() + mock_response = MagicMock() + mock_response.choices = [MagicMock()] + mock_response.choices[0].message.content = "Response with defaults" + mock_response.usage.prompt_tokens = 15 + mock_response.usage.completion_tokens = 10 + + mock_openai_client.chat.completions.create.return_value = mock_response + mock_openai_class.return_value = mock_openai_client + + mock_async_init.return_value = None + mock_llm_init.return_value = None + + config = { + 'model': 'gpt-4', # Default model + 'api_key': 'test-api-key', + 'url': 'https://api.openai.com/v1', + 'temperature': 0.5, # Default temperature + 'max_output': 4096, + 'concurrency': 1, + 'taskgroup': AsyncMock(), + 'id': 'test-processor' + } + + processor = Processor(**config) + + # Act - Don't override any parameters (pass None) + result = await processor.generate_content( + "System prompt", + "User prompt", + model=None, # Use default model + temperature=None # Use default temperature + ) + + # Assert + assert isinstance(result, LlmResult) + assert result.text == "Response with defaults" + + # Verify the OpenAI API was called with processor defaults + mock_openai_client.chat.completions.create.assert_called_once() + call_kwargs = mock_openai_client.chat.completions.create.call_args.kwargs + + assert call_kwargs['model'] == 'gpt-4' # Should use processor default + assert call_kwargs['temperature'] == 0.5 # Should use processor default + + if __name__ == '__main__': pytest.main([__file__]) \ No newline at end of file diff --git a/tests/unit/test_text_completion/test_parameter_caching.py b/tests/unit/test_text_completion/test_parameter_caching.py new file mode 100644 index 00000000..39e58863 --- /dev/null +++ b/tests/unit/test_text_completion/test_parameter_caching.py @@ -0,0 +1,186 @@ +""" +Unit tests for Parameter-Based Caching in LLM Processors +Testing processors that cache based on temperature parameters (Bedrock, GoogleAIStudio) +""" + +import pytest +from unittest.mock import AsyncMock, MagicMock, patch +from unittest import IsolatedAsyncioTestCase + +from trustgraph.model.text_completion.googleaistudio.llm import Processor as GoogleAIProcessor +from trustgraph.base import LlmResult + + +class TestParameterCaching(IsolatedAsyncioTestCase): + """Test parameter-based caching functionality""" + + @patch('trustgraph.model.text_completion.googleaistudio.llm.genai') + @patch('trustgraph.base.async_processor.AsyncProcessor.__init__') + @patch('trustgraph.base.llm_service.LlmService.__init__') + async def test_googleai_temperature_cache_keys(self, mock_llm_init, mock_async_init, mock_genai): + """Test that GoogleAI processor creates separate cache entries for different temperatures""" + # Arrange + mock_client = MagicMock() + 
mock_genai.Client.return_value = mock_client + + mock_response = MagicMock() + mock_response.text = "Generated response" + mock_response.usage_metadata.prompt_token_count = 10 + mock_response.usage_metadata.candidates_token_count = 5 + mock_client.models.generate_content.return_value = mock_response + + mock_async_init.return_value = None + mock_llm_init.return_value = None + + config = { + 'model': 'gemini-2.0-flash-001', + 'api_key': 'test-api-key', + 'temperature': 0.0, # Default temperature + 'max_output': 1024, + 'concurrency': 1, + 'taskgroup': AsyncMock(), + 'id': 'test-processor' + } + + processor = GoogleAIProcessor(**config) + + # Act - Call with different temperatures + await processor.generate_content("System", "Prompt 1", model="gemini-2.0-flash-001", temperature=0.0) + await processor.generate_content("System", "Prompt 2", model="gemini-2.0-flash-001", temperature=0.5) + await processor.generate_content("System", "Prompt 3", model="gemini-2.0-flash-001", temperature=1.0) + + # Assert - Should have 3 different cache entries + cache_keys = list(processor.generation_configs.keys()) + + assert len(cache_keys) == 3 + assert "gemini-2.0-flash-001:0.0" in cache_keys + assert "gemini-2.0-flash-001:0.5" in cache_keys + assert "gemini-2.0-flash-001:1.0" in cache_keys + + # Verify each cached config has the correct temperature + assert processor.generation_configs["gemini-2.0-flash-001:0.0"].temperature == 0.0 + assert processor.generation_configs["gemini-2.0-flash-001:0.5"].temperature == 0.5 + assert processor.generation_configs["gemini-2.0-flash-001:1.0"].temperature == 1.0 + + @patch('trustgraph.model.text_completion.googleaistudio.llm.genai') + @patch('trustgraph.base.async_processor.AsyncProcessor.__init__') + @patch('trustgraph.base.llm_service.LlmService.__init__') + async def test_googleai_cache_reuse_same_parameters(self, mock_llm_init, mock_async_init, mock_genai): + """Test that GoogleAI processor reuses cache for identical model+temperature combinations""" + # Arrange + mock_client = MagicMock() + mock_genai.Client.return_value = mock_client + + mock_response = MagicMock() + mock_response.text = "Generated response" + mock_response.usage_metadata.prompt_token_count = 10 + mock_response.usage_metadata.candidates_token_count = 5 + mock_client.models.generate_content.return_value = mock_response + + mock_async_init.return_value = None + mock_llm_init.return_value = None + + config = { + 'model': 'gemini-2.0-flash-001', + 'api_key': 'test-api-key', + 'temperature': 0.0, + 'max_output': 1024, + 'concurrency': 1, + 'taskgroup': AsyncMock(), + 'id': 'test-processor' + } + + processor = GoogleAIProcessor(**config) + + # Act - Call multiple times with same parameters + await processor.generate_content("System", "Prompt 1", model="gemini-2.0-flash-001", temperature=0.7) + await processor.generate_content("System", "Prompt 2", model="gemini-2.0-flash-001", temperature=0.7) + await processor.generate_content("System", "Prompt 3", model="gemini-2.0-flash-001", temperature=0.7) + + # Assert - Should have only 1 cache entry for the repeated parameters + cache_keys = list(processor.generation_configs.keys()) + assert len(cache_keys) == 1 + assert "gemini-2.0-flash-001:0.7" in cache_keys + + # The same config object should be reused + config_obj = processor.generation_configs["gemini-2.0-flash-001:0.7"] + assert config_obj.temperature == 0.7 + + @patch('trustgraph.model.text_completion.googleaistudio.llm.genai') + @patch('trustgraph.base.async_processor.AsyncProcessor.__init__') + 
@patch('trustgraph.base.llm_service.LlmService.__init__') + async def test_googleai_different_models_separate_caches(self, mock_llm_init, mock_async_init, mock_genai): + """Test that different models create separate cache entries even with same temperature""" + # Arrange + mock_client = MagicMock() + mock_genai.Client.return_value = mock_client + + mock_response = MagicMock() + mock_response.text = "Generated response" + mock_response.usage_metadata.prompt_token_count = 10 + mock_response.usage_metadata.candidates_token_count = 5 + mock_client.models.generate_content.return_value = mock_response + + mock_async_init.return_value = None + mock_llm_init.return_value = None + + config = { + 'model': 'gemini-2.0-flash-001', + 'api_key': 'test-api-key', + 'temperature': 0.0, + 'max_output': 1024, + 'concurrency': 1, + 'taskgroup': AsyncMock(), + 'id': 'test-processor' + } + + processor = GoogleAIProcessor(**config) + + # Act - Call with different models, same temperature + await processor.generate_content("System", "Prompt 1", model="gemini-2.0-flash-001", temperature=0.5) + await processor.generate_content("System", "Prompt 2", model="gemini-1.5-flash-001", temperature=0.5) + + # Assert - Should have separate cache entries for different models + cache_keys = list(processor.generation_configs.keys()) + assert len(cache_keys) == 2 + assert "gemini-2.0-flash-001:0.5" in cache_keys + assert "gemini-1.5-flash-001:0.5" in cache_keys + + # Note: Bedrock tests would be similar but testing the Bedrock processor's caching behavior + # The Bedrock processor caches model variants with temperature in the cache key + + async def test_bedrock_temperature_cache_keys(self): + """Test Bedrock processor temperature-aware caching""" + # This would test the Bedrock processor's _get_or_create_variant method + # with different temperature values to ensure proper cache key generation + + # Implementation would follow similar pattern to GoogleAI tests above + # but using the Bedrock processor and testing model_variants cache + pass + + async def test_bedrock_cache_isolation_different_temperatures(self): + """Test that Bedrock processor isolates cache entries by temperature""" + pass + + async def test_cache_memory_efficiency(self): + """Test that caches don't grow unbounded with many different parameter combinations""" + # This could test cache size limits or cleanup behavior if implemented + pass + + +class TestCachePerformance(IsolatedAsyncioTestCase): + """Test caching performance characteristics""" + + async def test_cache_hit_performance(self): + """Test that cache hits are faster than cache misses""" + # This would measure timing differences between cache hits and misses + pass + + async def test_concurrent_cache_access(self): + """Test concurrent access to cached configurations""" + # This would test thread-safety of cache access + pass + + +if __name__ == '__main__': + pytest.main([__file__]) \ No newline at end of file diff --git a/tests/unit/test_text_completion/test_tgi_processor.py b/tests/unit/test_text_completion/test_tgi_processor.py new file mode 100644 index 00000000..ca897023 --- /dev/null +++ b/tests/unit/test_text_completion/test_tgi_processor.py @@ -0,0 +1,271 @@ +""" +Unit tests for trustgraph.model.text_completion.tgi +Following the same successful pattern as previous tests +""" + +import pytest +from unittest.mock import AsyncMock, MagicMock, patch +from unittest import IsolatedAsyncioTestCase + +# Import the service under test +from trustgraph.model.text_completion.tgi.llm import Processor 
+from trustgraph.base import LlmResult +from trustgraph.exceptions import TooManyRequests + + +class TestTGIProcessorSimple(IsolatedAsyncioTestCase): + """Test TGI processor functionality""" + + @patch('trustgraph.model.text_completion.tgi.llm.aiohttp.ClientSession') + @patch('trustgraph.base.async_processor.AsyncProcessor.__init__') + @patch('trustgraph.base.llm_service.LlmService.__init__') + async def test_processor_initialization_basic(self, mock_llm_init, mock_async_init, mock_session_class): + """Test basic processor initialization""" + # Arrange + mock_session = MagicMock() + mock_session_class.return_value = mock_session + + mock_async_init.return_value = None + mock_llm_init.return_value = None + + config = { + 'model': 'tgi', + 'url': 'http://tgi-service:8899/v1', + 'temperature': 0.0, + 'max_output': 2048, + 'concurrency': 1, + 'taskgroup': AsyncMock(), + 'id': 'test-processor' + } + + # Act + processor = Processor(**config) + + # Assert + assert processor.default_model == 'tgi' + assert processor.base_url == 'http://tgi-service:8899/v1' + assert processor.temperature == 0.0 + assert processor.max_output == 2048 + assert hasattr(processor, 'session') + mock_session_class.assert_called_once() + + @patch('trustgraph.model.text_completion.tgi.llm.aiohttp.ClientSession') + @patch('trustgraph.base.async_processor.AsyncProcessor.__init__') + @patch('trustgraph.base.llm_service.LlmService.__init__') + async def test_generate_content_success(self, mock_llm_init, mock_async_init, mock_session_class): + """Test successful content generation""" + # Arrange + mock_session = MagicMock() + mock_response = MagicMock() + mock_response.status = 200 + mock_response.json = AsyncMock(return_value={ + 'choices': [{ + 'message': { + 'content': 'Generated response from TGI' + } + }], + 'usage': { + 'prompt_tokens': 20, + 'completion_tokens': 12 + } + }) + + # Mock the async context manager + mock_session.post.return_value.__aenter__.return_value = mock_response + mock_session.post.return_value.__aexit__.return_value = None + mock_session_class.return_value = mock_session + + mock_async_init.return_value = None + mock_llm_init.return_value = None + + config = { + 'model': 'tgi', + 'url': 'http://tgi-service:8899/v1', + 'temperature': 0.0, + 'max_output': 2048, + 'concurrency': 1, + 'taskgroup': AsyncMock(), + 'id': 'test-processor' + } + + processor = Processor(**config) + + # Act + result = await processor.generate_content("System prompt", "User prompt") + + # Assert + assert isinstance(result, LlmResult) + assert result.text == "Generated response from TGI" + assert result.in_token == 20 + assert result.out_token == 12 + assert result.model == 'tgi' + + # Verify the API call was made correctly + mock_session.post.assert_called_once() + call_args = mock_session.post.call_args + + # Check URL + assert call_args[0][0] == 'http://tgi-service:8899/v1/chat/completions' + + # Check request structure + request_body = call_args[1]['json'] + assert request_body['model'] == 'tgi' + assert request_body['temperature'] == 0.0 + assert request_body['max_tokens'] == 2048 + + @patch('trustgraph.model.text_completion.tgi.llm.aiohttp.ClientSession') + @patch('trustgraph.base.async_processor.AsyncProcessor.__init__') + @patch('trustgraph.base.llm_service.LlmService.__init__') + async def test_generate_content_with_model_override(self, mock_llm_init, mock_async_init, mock_session_class): + """Test generate_content with model parameter override""" + # Arrange + mock_session = MagicMock() + mock_response = MagicMock() + 
mock_response.status = 200 + mock_response.json = AsyncMock(return_value={ + 'choices': [{ + 'message': { + 'content': 'Response from overridden model' + } + }], + 'usage': { + 'prompt_tokens': 15, + 'completion_tokens': 10 + } + }) + + mock_session.post.return_value.__aenter__.return_value = mock_response + mock_session.post.return_value.__aexit__.return_value = None + mock_session_class.return_value = mock_session + + mock_async_init.return_value = None + mock_llm_init.return_value = None + + config = { + 'model': 'tgi', + 'url': 'http://tgi-service:8899/v1', + 'temperature': 0.0, + 'max_output': 2048, + 'concurrency': 1, + 'taskgroup': AsyncMock(), + 'id': 'test-processor' + } + + processor = Processor(**config) + + # Act - Override model + result = await processor.generate_content("System", "Prompt", model="custom-tgi-model") + + # Assert + assert result.model == "custom-tgi-model" # Should use overridden model + assert result.text == "Response from overridden model" + + # Verify the API call was made with overridden model + call_args = mock_session.post.call_args + assert call_args[1]['json']['model'] == "custom-tgi-model" + + @patch('trustgraph.model.text_completion.tgi.llm.aiohttp.ClientSession') + @patch('trustgraph.base.async_processor.AsyncProcessor.__init__') + @patch('trustgraph.base.llm_service.LlmService.__init__') + async def test_generate_content_with_temperature_override(self, mock_llm_init, mock_async_init, mock_session_class): + """Test generate_content with temperature parameter override""" + # Arrange + mock_session = MagicMock() + mock_response = MagicMock() + mock_response.status = 200 + mock_response.json = AsyncMock(return_value={ + 'choices': [{ + 'message': { + 'content': 'Response with temperature override' + } + }], + 'usage': { + 'prompt_tokens': 18, + 'completion_tokens': 12 + } + }) + + mock_session.post.return_value.__aenter__.return_value = mock_response + mock_session.post.return_value.__aexit__.return_value = None + mock_session_class.return_value = mock_session + + mock_async_init.return_value = None + mock_llm_init.return_value = None + + config = { + 'model': 'tgi', + 'url': 'http://tgi-service:8899/v1', + 'temperature': 0.0, # Default temperature + 'max_output': 2048, + 'concurrency': 1, + 'taskgroup': AsyncMock(), + 'id': 'test-processor' + } + + processor = Processor(**config) + + # Act - Override temperature + result = await processor.generate_content("System", "Prompt", temperature=0.7) + + # Assert + assert result.text == "Response with temperature override" + + # Verify the API call was made with overridden temperature + call_args = mock_session.post.call_args + assert call_args[1]['json']['temperature'] == 0.7 + + @patch('trustgraph.model.text_completion.tgi.llm.aiohttp.ClientSession') + @patch('trustgraph.base.async_processor.AsyncProcessor.__init__') + @patch('trustgraph.base.llm_service.LlmService.__init__') + async def test_generate_content_with_both_parameters_override(self, mock_llm_init, mock_async_init, mock_session_class): + """Test generate_content with both model and temperature overrides""" + # Arrange + mock_session = MagicMock() + mock_response = MagicMock() + mock_response.status = 200 + mock_response.json = AsyncMock(return_value={ + 'choices': [{ + 'message': { + 'content': 'Response with both parameters override' + } + }], + 'usage': { + 'prompt_tokens': 20, + 'completion_tokens': 15 + } + }) + + mock_session.post.return_value.__aenter__.return_value = mock_response + mock_session.post.return_value.__aexit__.return_value = 
None + mock_session_class.return_value = mock_session + + mock_async_init.return_value = None + mock_llm_init.return_value = None + + config = { + 'model': 'tgi', + 'url': 'http://tgi-service:8899/v1', + 'temperature': 0.0, + 'max_output': 2048, + 'concurrency': 1, + 'taskgroup': AsyncMock(), + 'id': 'test-processor' + } + + processor = Processor(**config) + + # Act - Override both parameters + result = await processor.generate_content("System", "Prompt", model="override-model", temperature=0.8) + + # Assert + assert result.model == "override-model" + assert result.text == "Response with both parameters override" + + # Verify the API call was made with overridden parameters + call_args = mock_session.post.call_args + assert call_args[1]['json']['model'] == "override-model" + assert call_args[1]['json']['temperature'] == 0.8 + + +if __name__ == '__main__': + pytest.main([__file__]) \ No newline at end of file diff --git a/tests/unit/test_text_completion/test_vertexai_processor.py b/tests/unit/test_text_completion/test_vertexai_processor.py index 48c70d15..60d61acd 100644 --- a/tests/unit/test_text_completion/test_vertexai_processor.py +++ b/tests/unit/test_text_completion/test_vertexai_processor.py @@ -460,6 +460,180 @@ class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase): assert processor.api_params["top_p"] == 1.0 assert processor.api_params["top_k"] == 32 + @patch('trustgraph.model.text_completion.vertexai.llm.service_account') + @patch('trustgraph.model.text_completion.vertexai.llm.vertexai') + @patch('trustgraph.model.text_completion.vertexai.llm.GenerativeModel') + @patch('trustgraph.base.async_processor.AsyncProcessor.__init__') + @patch('trustgraph.base.llm_service.LlmService.__init__') + async def test_generate_content_temperature_override(self, mock_llm_init, mock_async_init, mock_generative_model, mock_vertexai, mock_service_account): + """Test temperature parameter override functionality""" + # Arrange + mock_credentials = MagicMock() + mock_service_account.Credentials.from_service_account_file.return_value = mock_credentials + + mock_model = MagicMock() + mock_response = MagicMock() + mock_response.text = "Response with custom temperature" + mock_response.usage_metadata.prompt_token_count = 20 + mock_response.usage_metadata.candidates_token_count = 12 + mock_model.generate_content.return_value = mock_response + mock_generative_model.return_value = mock_model + + mock_async_init.return_value = None + mock_llm_init.return_value = None + + config = { + 'region': 'us-central1', + 'model': 'gemini-2.0-flash-001', + 'temperature': 0.0, # Default temperature + 'max_output': 8192, + 'private_key': 'private.json', + 'concurrency': 1, + 'taskgroup': AsyncMock(), + 'id': 'test-processor' + } + + processor = Processor(**config) + + # Act - Override temperature at runtime + result = await processor.generate_content( + "System prompt", + "User prompt", + model=None, # Use default model + temperature=0.8 # Override temperature + ) + + # Assert + assert isinstance(result, LlmResult) + assert result.text == "Response with custom temperature" + + # Verify Gemini API was called with overridden temperature + mock_model.generate_content.assert_called_once() + call_args = mock_model.generate_content.call_args + + # Check that generation_config was created (we can't directly access temperature from mock) + generation_config = call_args.kwargs['generation_config'] + assert generation_config is not None # Should use overridden temperature configuration + + 
@patch('trustgraph.model.text_completion.vertexai.llm.service_account') + @patch('trustgraph.model.text_completion.vertexai.llm.vertexai') + @patch('trustgraph.model.text_completion.vertexai.llm.GenerativeModel') + @patch('trustgraph.base.async_processor.AsyncProcessor.__init__') + @patch('trustgraph.base.llm_service.LlmService.__init__') + async def test_generate_content_model_override(self, mock_llm_init, mock_async_init, mock_generative_model, mock_vertexai, mock_service_account): + """Test model parameter override functionality""" + # Arrange + mock_credentials = MagicMock() + mock_service_account.Credentials.from_service_account_file.return_value = mock_credentials + + # Mock different models + mock_model_default = MagicMock() + mock_model_override = MagicMock() + mock_response = MagicMock() + mock_response.text = "Response with custom model" + mock_response.usage_metadata.prompt_token_count = 18 + mock_response.usage_metadata.candidates_token_count = 14 + mock_model_override.generate_content.return_value = mock_response + + # GenerativeModel should return different models based on input + def model_factory(model_name): + if model_name == 'gemini-1.5-pro': + return mock_model_override + return mock_model_default + + mock_generative_model.side_effect = model_factory + + mock_async_init.return_value = None + mock_llm_init.return_value = None + + config = { + 'region': 'us-central1', + 'model': 'gemini-2.0-flash-001', # Default model + 'temperature': 0.2, # Default temperature + 'max_output': 8192, + 'private_key': 'private.json', + 'concurrency': 1, + 'taskgroup': AsyncMock(), + 'id': 'test-processor' + } + + processor = Processor(**config) + + # Act - Override model at runtime + result = await processor.generate_content( + "System prompt", + "User prompt", + model="gemini-1.5-pro", # Override model + temperature=None # Use default temperature + ) + + # Assert + assert isinstance(result, LlmResult) + assert result.text == "Response with custom model" + + # Verify the overridden model was used + mock_model_override.generate_content.assert_called_once() + # Verify GenerativeModel was called with the override model + mock_generative_model.assert_called_with('gemini-1.5-pro') + + @patch('trustgraph.model.text_completion.vertexai.llm.service_account') + @patch('trustgraph.model.text_completion.vertexai.llm.vertexai') + @patch('trustgraph.model.text_completion.vertexai.llm.GenerativeModel') + @patch('trustgraph.base.async_processor.AsyncProcessor.__init__') + @patch('trustgraph.base.llm_service.LlmService.__init__') + async def test_generate_content_both_parameters_override(self, mock_llm_init, mock_async_init, mock_generative_model, mock_vertexai, mock_service_account): + """Test overriding both model and temperature parameters simultaneously""" + # Arrange + mock_credentials = MagicMock() + mock_service_account.Credentials.from_service_account_file.return_value = mock_credentials + + mock_model = MagicMock() + mock_response = MagicMock() + mock_response.text = "Response with both overrides" + mock_response.usage_metadata.prompt_token_count = 22 + mock_response.usage_metadata.candidates_token_count = 16 + mock_model.generate_content.return_value = mock_response + mock_generative_model.return_value = mock_model + + mock_async_init.return_value = None + mock_llm_init.return_value = None + + config = { + 'region': 'us-central1', + 'model': 'gemini-2.0-flash-001', # Default model + 'temperature': 0.0, # Default temperature + 'max_output': 8192, + 'private_key': 'private.json', + 'concurrency': 
1, + 'taskgroup': AsyncMock(), + 'id': 'test-processor' + } + + processor = Processor(**config) + + # Act - Override both parameters at runtime + result = await processor.generate_content( + "System prompt", + "User prompt", + model="gemini-1.5-flash-001", # Override model + temperature=0.9 # Override temperature + ) + + # Assert + assert isinstance(result, LlmResult) + assert result.text == "Response with both overrides" + + # Verify both overrides were used + mock_model.generate_content.assert_called_once() + call_args = mock_model.generate_content.call_args + + # Verify model override + mock_generative_model.assert_called_with('gemini-1.5-flash-001') # Should use runtime override + + # Verify temperature override (we can't directly access temperature from mock) + generation_config = call_args.kwargs['generation_config'] + assert generation_config is not None # Should use overridden temperature configuration + if __name__ == '__main__': pytest.main([__file__]) \ No newline at end of file diff --git a/tests/unit/test_text_completion/test_vllm_processor.py b/tests/unit/test_text_completion/test_vllm_processor.py index 7124c229..64da8ff9 100644 --- a/tests/unit/test_text_completion/test_vllm_processor.py +++ b/tests/unit/test_text_completion/test_vllm_processor.py @@ -485,5 +485,148 @@ class TestVLLMProcessorSimple(IsolatedAsyncioTestCase): assert call_args[1]['json']['prompt'] == expected_prompt + @patch('trustgraph.model.text_completion.vllm.llm.aiohttp.ClientSession') + @patch('trustgraph.base.async_processor.AsyncProcessor.__init__') + @patch('trustgraph.base.llm_service.LlmService.__init__') + async def test_generate_content_with_model_override(self, mock_llm_init, mock_async_init, mock_session_class): + """Test generate_content with model parameter override""" + # Arrange + mock_session = MagicMock() + mock_response = MagicMock() + mock_response.status = 200 + mock_response.json = AsyncMock(return_value={ + 'choices': [{ + 'text': 'Response from overridden model' + }], + 'usage': { + 'prompt_tokens': 12, + 'completion_tokens': 8 + } + }) + + mock_session.post.return_value.__aenter__.return_value = mock_response + mock_session.post.return_value.__aexit__.return_value = None + mock_session_class.return_value = mock_session + + mock_async_init.return_value = None + mock_llm_init.return_value = None + + config = { + 'model': 'TheBloke/Mistral-7B-v0.1-AWQ', + 'url': 'http://vllm-service:8899/v1', + 'temperature': 0.0, + 'max_output': 2048, + 'concurrency': 1, + 'taskgroup': AsyncMock(), + 'id': 'test-processor' + } + + processor = Processor(**config) + + # Act - Override model + result = await processor.generate_content("System", "Prompt", model="custom-vllm-model") + + # Assert + assert result.model == "custom-vllm-model" # Should use overridden model + assert result.text == "Response from overridden model" + + @patch('trustgraph.model.text_completion.vllm.llm.aiohttp.ClientSession') + @patch('trustgraph.base.async_processor.AsyncProcessor.__init__') + @patch('trustgraph.base.llm_service.LlmService.__init__') + async def test_generate_content_with_temperature_override(self, mock_llm_init, mock_async_init, mock_session_class): + """Test generate_content with temperature parameter override""" + # Arrange + mock_session = MagicMock() + mock_response = MagicMock() + mock_response.status = 200 + mock_response.json = AsyncMock(return_value={ + 'choices': [{ + 'text': 'Response with temperature override' + }], + 'usage': { + 'prompt_tokens': 15, + 'completion_tokens': 10 + } + }) + + 
mock_session.post.return_value.__aenter__.return_value = mock_response + mock_session.post.return_value.__aexit__.return_value = None + mock_session_class.return_value = mock_session + + mock_async_init.return_value = None + mock_llm_init.return_value = None + + config = { + 'model': 'TheBloke/Mistral-7B-v0.1-AWQ', + 'url': 'http://vllm-service:8899/v1', + 'temperature': 0.0, # Default temperature + 'max_output': 2048, + 'concurrency': 1, + 'taskgroup': AsyncMock(), + 'id': 'test-processor' + } + + processor = Processor(**config) + + # Act - Override temperature + result = await processor.generate_content("System", "Prompt", temperature=0.7) + + # Assert + assert result.text == "Response with temperature override" + + # Verify the request was made with overridden temperature + call_args = mock_session.post.call_args + assert call_args[1]['json']['temperature'] == 0.7 + + @patch('trustgraph.model.text_completion.vllm.llm.aiohttp.ClientSession') + @patch('trustgraph.base.async_processor.AsyncProcessor.__init__') + @patch('trustgraph.base.llm_service.LlmService.__init__') + async def test_generate_content_with_both_parameters_override(self, mock_llm_init, mock_async_init, mock_session_class): + """Test generate_content with both model and temperature overrides""" + # Arrange + mock_session = MagicMock() + mock_response = MagicMock() + mock_response.status = 200 + mock_response.json = AsyncMock(return_value={ + 'choices': [{ + 'text': 'Response with both parameters override' + }], + 'usage': { + 'prompt_tokens': 18, + 'completion_tokens': 12 + } + }) + + mock_session.post.return_value.__aenter__.return_value = mock_response + mock_session.post.return_value.__aexit__.return_value = None + mock_session_class.return_value = mock_session + + mock_async_init.return_value = None + mock_llm_init.return_value = None + + config = { + 'model': 'TheBloke/Mistral-7B-v0.1-AWQ', + 'url': 'http://vllm-service:8899/v1', + 'temperature': 0.0, + 'max_output': 2048, + 'concurrency': 1, + 'taskgroup': AsyncMock(), + 'id': 'test-processor' + } + + processor = Processor(**config) + + # Act - Override both parameters + result = await processor.generate_content("System", "Prompt", model="override-model", temperature=0.8) + + # Assert + assert result.model == "override-model" + assert result.text == "Response with both parameters override" + + # Verify the request was made with overridden temperature + call_args = mock_session.post.call_args + assert call_args[1]['json']['temperature'] == 0.8 + + if __name__ == '__main__': pytest.main([__file__]) \ No newline at end of file
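Taken together, the TGI, Vertex AI, and vLLM tests all exercise the same resolution rule: an argument passed to generate_content wins, and None falls back to the value configured at construction time. A minimal, illustrative sketch of that rule follows; the function and parameter names are hypothetical and do not claim to match the actual LlmService API:

    def resolve_params(default_model, default_temperature, model=None, temperature=None):
        """Runtime overrides win; None falls back to the configured defaults."""
        effective_model = model if model is not None else default_model
        effective_temperature = temperature if temperature is not None else default_temperature
        return effective_model, effective_temperature

    # Example: temperature override only, model falls back to the default.
    assert resolve_params("gpt-3.5-turbo", 0.0, temperature=0.9) == ("gpt-3.5-turbo", 0.9)

Using an explicit "is not None" check rather than truthiness matters here, since 0.0 is a legitimate temperature (and the configured default in several of these tests) and must not be silently replaced.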