diff --git a/tests/unit/test_text_completion/test_vertexai_processor.py b/tests/unit/test_text_completion/test_vertexai_processor.py index 60d61acd..cbc91cb5 100644 --- a/tests/unit/test_text_completion/test_vertexai_processor.py +++ b/tests/unit/test_text_completion/test_vertexai_processor.py @@ -1,6 +1,6 @@ """ Unit tests for trustgraph.model.text_completion.vertexai -Starting simple with one test to get the basics working +Updated for google-genai SDK """ import pytest @@ -15,19 +15,20 @@ from trustgraph.base import LlmResult class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase): """Simple test for processor initialization""" + @patch('trustgraph.model.text_completion.vertexai.llm.genai') @patch('trustgraph.model.text_completion.vertexai.llm.service_account') - @patch('trustgraph.model.text_completion.vertexai.llm.vertexai') - @patch('trustgraph.model.text_completion.vertexai.llm.GenerativeModel') @patch('trustgraph.base.async_processor.AsyncProcessor.__init__') @patch('trustgraph.base.llm_service.LlmService.__init__') - async def test_processor_initialization_basic(self, mock_llm_init, mock_async_init, mock_generative_model, mock_vertexai, mock_service_account): + async def test_processor_initialization_basic(self, mock_llm_init, mock_async_init, mock_service_account, mock_genai): """Test basic processor initialization with mocked dependencies""" # Arrange mock_credentials = MagicMock() + mock_credentials.project_id = "test-project-123" mock_service_account.Credentials.from_service_account_file.return_value = mock_credentials - mock_model = MagicMock() - mock_generative_model.return_value = mock_model - + + mock_client = MagicMock() + mock_genai.Client.return_value = mock_client + # Mock the parent class initialization to avoid taskgroup requirement mock_async_init.return_value = None mock_llm_init.return_value = None @@ -47,32 +48,38 @@ class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase): processor = Processor(**config) # Assert - assert processor.default_model == 'gemini-2.0-flash-001' # It's stored as 'model', not 'model_name' - assert hasattr(processor, 'generation_configs') # Now a cache dictionary + assert processor.default_model == 'gemini-2.0-flash-001' + assert hasattr(processor, 'generation_configs') # Cache dictionary assert hasattr(processor, 'safety_settings') - assert hasattr(processor, 'model_clients') # LLM clients are now cached - mock_service_account.Credentials.from_service_account_file.assert_called_once_with('private.json') - mock_vertexai.init.assert_called_once() + assert hasattr(processor, 'client') # genai.Client + mock_service_account.Credentials.from_service_account_file.assert_called_once() + mock_genai.Client.assert_called_once_with( + vertexai=True, + project="test-project-123", + location="us-central1", + credentials=mock_credentials + ) + @patch('trustgraph.model.text_completion.vertexai.llm.genai') @patch('trustgraph.model.text_completion.vertexai.llm.service_account') - @patch('trustgraph.model.text_completion.vertexai.llm.vertexai') - @patch('trustgraph.model.text_completion.vertexai.llm.GenerativeModel') @patch('trustgraph.base.async_processor.AsyncProcessor.__init__') @patch('trustgraph.base.llm_service.LlmService.__init__') - async def test_generate_content_success(self, mock_llm_init, mock_async_init, mock_generative_model, mock_vertexai, mock_service_account): + async def test_generate_content_success(self, mock_llm_init, mock_async_init, mock_service_account, mock_genai): """Test successful content generation""" # Arrange mock_credentials = MagicMock() + mock_credentials.project_id = "test-project-123" mock_service_account.Credentials.from_service_account_file.return_value = mock_credentials - - mock_model = MagicMock() + mock_response = MagicMock() mock_response.text = "Generated response from Gemini" mock_response.usage_metadata.prompt_token_count = 15 mock_response.usage_metadata.candidates_token_count = 8 - mock_model.generate_content.return_value = mock_response - mock_generative_model.return_value = mock_model - + + mock_client = MagicMock() + mock_client.models.generate_content.return_value = mock_response + mock_genai.Client.return_value = mock_client + mock_async_init.return_value = None mock_llm_init.return_value = None @@ -98,32 +105,26 @@ class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase): assert result.in_token == 15 assert result.out_token == 8 assert result.model == 'gemini-2.0-flash-001' - # Check that the method was called (actual prompt format may vary) - mock_model.generate_content.assert_called_once() - # Verify the call was made with the expected parameters - call_args = mock_model.generate_content.call_args - # Generation config is now created dynamically per model - assert 'generation_config' in call_args[1] - assert call_args[1]['safety_settings'] == processor.safety_settings + mock_client.models.generate_content.assert_called_once() + @patch('trustgraph.model.text_completion.vertexai.llm.genai') @patch('trustgraph.model.text_completion.vertexai.llm.service_account') - @patch('trustgraph.model.text_completion.vertexai.llm.vertexai') - @patch('trustgraph.model.text_completion.vertexai.llm.GenerativeModel') @patch('trustgraph.base.async_processor.AsyncProcessor.__init__') @patch('trustgraph.base.llm_service.LlmService.__init__') - async def test_generate_content_rate_limit_error(self, mock_llm_init, mock_async_init, mock_generative_model, mock_vertexai, mock_service_account): + async def test_generate_content_rate_limit_error(self, mock_llm_init, mock_async_init, mock_service_account, mock_genai): """Test rate limit error handling""" # Arrange from google.api_core.exceptions import ResourceExhausted from trustgraph.exceptions import TooManyRequests - + mock_credentials = MagicMock() + mock_credentials.project_id = "test-project-123" mock_service_account.Credentials.from_service_account_file.return_value = mock_credentials - - mock_model = MagicMock() - mock_model.generate_content.side_effect = ResourceExhausted("Rate limit exceeded") - mock_generative_model.return_value = mock_model - + + mock_client = MagicMock() + mock_client.models.generate_content.side_effect = ResourceExhausted("Rate limit exceeded") + mock_genai.Client.return_value = mock_client + mock_async_init.return_value = None mock_llm_init.return_value = None @@ -144,25 +145,26 @@ class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase): with pytest.raises(TooManyRequests): await processor.generate_content("System prompt", "User prompt") + @patch('trustgraph.model.text_completion.vertexai.llm.genai') @patch('trustgraph.model.text_completion.vertexai.llm.service_account') - @patch('trustgraph.model.text_completion.vertexai.llm.vertexai') - @patch('trustgraph.model.text_completion.vertexai.llm.GenerativeModel') @patch('trustgraph.base.async_processor.AsyncProcessor.__init__') @patch('trustgraph.base.llm_service.LlmService.__init__') - async def test_generate_content_blocked_response(self, mock_llm_init, mock_async_init, mock_generative_model, mock_vertexai, mock_service_account): + async def test_generate_content_blocked_response(self, mock_llm_init, mock_async_init, mock_service_account, mock_genai): """Test handling of blocked content (safety filters)""" # Arrange mock_credentials = MagicMock() + mock_credentials.project_id = "test-project-123" mock_service_account.Credentials.from_service_account_file.return_value = mock_credentials - - mock_model = MagicMock() + mock_response = MagicMock() mock_response.text = None # Blocked content returns None mock_response.usage_metadata.prompt_token_count = 10 mock_response.usage_metadata.candidates_token_count = 0 - mock_model.generate_content.return_value = mock_response - mock_generative_model.return_value = mock_model - + + mock_client = MagicMock() + mock_client.models.generate_content.return_value = mock_response + mock_genai.Client.return_value = mock_client + mock_async_init.return_value = None mock_llm_init.return_value = None @@ -190,24 +192,22 @@ class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase): assert result.model == 'gemini-2.0-flash-001' @patch('trustgraph.model.text_completion.vertexai.llm.google.auth.default') + @patch('trustgraph.model.text_completion.vertexai.llm.genai') @patch('trustgraph.model.text_completion.vertexai.llm.service_account') - @patch('trustgraph.model.text_completion.vertexai.llm.vertexai') - @patch('trustgraph.model.text_completion.vertexai.llm.GenerativeModel') @patch('trustgraph.base.async_processor.AsyncProcessor.__init__') @patch('trustgraph.base.llm_service.LlmService.__init__') - async def test_processor_initialization_without_private_key(self, mock_llm_init, mock_async_init, mock_generative_model, mock_vertexai, mock_service_account, mock_auth_default): + async def test_processor_initialization_without_private_key(self, mock_llm_init, mock_async_init, mock_service_account, mock_genai, mock_auth_default): """Test processor initialization without private key (uses default credentials)""" # Arrange mock_async_init.return_value = None mock_llm_init.return_value = None - + # Mock google.auth.default() to return credentials and project ID mock_credentials = MagicMock() mock_auth_default.return_value = (mock_credentials, "test-project-123") - - # Mock GenerativeModel - mock_model = MagicMock() - mock_generative_model.return_value = mock_model + + mock_client = MagicMock() + mock_genai.Client.return_value = mock_client config = { 'region': 'us-central1', @@ -222,30 +222,32 @@ class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase): # Act processor = Processor(**config) - + # Assert assert processor.default_model == 'gemini-2.0-flash-001' mock_auth_default.assert_called_once() - mock_vertexai.init.assert_called_once_with( - location='us-central1', - project='test-project-123' + mock_genai.Client.assert_called_once_with( + vertexai=True, + project="test-project-123", + location="us-central1", + credentials=mock_credentials ) + @patch('trustgraph.model.text_completion.vertexai.llm.genai') @patch('trustgraph.model.text_completion.vertexai.llm.service_account') - @patch('trustgraph.model.text_completion.vertexai.llm.vertexai') - @patch('trustgraph.model.text_completion.vertexai.llm.GenerativeModel') @patch('trustgraph.base.async_processor.AsyncProcessor.__init__') @patch('trustgraph.base.llm_service.LlmService.__init__') - async def test_generate_content_generic_exception(self, mock_llm_init, mock_async_init, mock_generative_model, mock_vertexai, mock_service_account): + async def test_generate_content_generic_exception(self, mock_llm_init, mock_async_init, mock_service_account, mock_genai): """Test handling of generic exceptions""" # Arrange mock_credentials = MagicMock() + mock_credentials.project_id = "test-project-123" mock_service_account.Credentials.from_service_account_file.return_value = mock_credentials - - mock_model = MagicMock() - mock_model.generate_content.side_effect = Exception("Network error") - mock_generative_model.return_value = mock_model - + + mock_client = MagicMock() + mock_client.models.generate_content.side_effect = Exception("Network error") + mock_genai.Client.return_value = mock_client + mock_async_init.return_value = None mock_llm_init.return_value = None @@ -266,19 +268,20 @@ class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase): with pytest.raises(Exception, match="Network error"): await processor.generate_content("System prompt", "User prompt") + @patch('trustgraph.model.text_completion.vertexai.llm.genai') @patch('trustgraph.model.text_completion.vertexai.llm.service_account') - @patch('trustgraph.model.text_completion.vertexai.llm.vertexai') - @patch('trustgraph.model.text_completion.vertexai.llm.GenerativeModel') @patch('trustgraph.base.async_processor.AsyncProcessor.__init__') @patch('trustgraph.base.llm_service.LlmService.__init__') - async def test_processor_initialization_with_custom_parameters(self, mock_llm_init, mock_async_init, mock_generative_model, mock_vertexai, mock_service_account): + async def test_processor_initialization_with_custom_parameters(self, mock_llm_init, mock_async_init, mock_service_account, mock_genai): """Test processor initialization with custom parameters""" # Arrange mock_credentials = MagicMock() + mock_credentials.project_id = "test-project-123" mock_service_account.Credentials.from_service_account_file.return_value = mock_credentials - mock_model = MagicMock() - mock_generative_model.return_value = mock_model - + + mock_client = MagicMock() + mock_genai.Client.return_value = mock_client + mock_async_init.return_value = None mock_llm_init.return_value = None @@ -298,37 +301,37 @@ class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase): # Assert assert processor.default_model == 'gemini-1.5-pro' - - # Verify that generation_config object exists (can't easily check internal values) - assert hasattr(processor, 'generation_configs') # Now a cache dictionary + + # Verify that generation_config cache exists + assert hasattr(processor, 'generation_configs') assert processor.generation_configs == {} # Empty cache initially - + # Verify that safety settings are configured assert len(processor.safety_settings) == 4 - + # Verify service account was called with custom key - mock_service_account.Credentials.from_service_account_file.assert_called_once_with('custom-key.json') - - # Verify that api_params dict has the correct values (this is accessible) + mock_service_account.Credentials.from_service_account_file.assert_called_once() + + # Verify that api_params dict has the correct values assert processor.api_params["temperature"] == 0.7 assert processor.api_params["max_output_tokens"] == 4096 assert processor.api_params["top_p"] == 1.0 assert processor.api_params["top_k"] == 32 + @patch('trustgraph.model.text_completion.vertexai.llm.genai') @patch('trustgraph.model.text_completion.vertexai.llm.service_account') - @patch('trustgraph.model.text_completion.vertexai.llm.vertexai') - @patch('trustgraph.model.text_completion.vertexai.llm.GenerativeModel') @patch('trustgraph.base.async_processor.AsyncProcessor.__init__') @patch('trustgraph.base.llm_service.LlmService.__init__') - async def test_vertexai_initialization_with_credentials(self, mock_llm_init, mock_async_init, mock_generative_model, mock_vertexai, mock_service_account): + async def test_vertexai_initialization_with_credentials(self, mock_llm_init, mock_async_init, mock_service_account, mock_genai): """Test that VertexAI is initialized correctly with credentials""" # Arrange mock_credentials = MagicMock() mock_credentials.project_id = "test-project-123" mock_service_account.Credentials.from_service_account_file.return_value = mock_credentials - mock_model = MagicMock() - mock_generative_model.return_value = mock_model - + + mock_client = MagicMock() + mock_genai.Client.return_value = mock_client + mock_async_init.return_value = None mock_llm_init.return_value = None @@ -347,35 +350,34 @@ class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase): processor = Processor(**config) # Assert - # Verify VertexAI init was called with correct parameters - mock_vertexai.init.assert_called_once_with( + # Verify genai.Client was called with correct parameters + mock_genai.Client.assert_called_once_with( + vertexai=True, + project='test-project-123', location='europe-west1', - credentials=mock_credentials, - project='test-project-123' + credentials=mock_credentials ) - - # GenerativeModel is now created lazily on first use, not at initialization - mock_generative_model.assert_not_called() + @patch('trustgraph.model.text_completion.vertexai.llm.genai') @patch('trustgraph.model.text_completion.vertexai.llm.service_account') - @patch('trustgraph.model.text_completion.vertexai.llm.vertexai') - @patch('trustgraph.model.text_completion.vertexai.llm.GenerativeModel') @patch('trustgraph.base.async_processor.AsyncProcessor.__init__') @patch('trustgraph.base.llm_service.LlmService.__init__') - async def test_generate_content_empty_prompts(self, mock_llm_init, mock_async_init, mock_generative_model, mock_vertexai, mock_service_account): + async def test_generate_content_empty_prompts(self, mock_llm_init, mock_async_init, mock_service_account, mock_genai): """Test content generation with empty prompts""" # Arrange mock_credentials = MagicMock() + mock_credentials.project_id = "test-project-123" mock_service_account.Credentials.from_service_account_file.return_value = mock_credentials - - mock_model = MagicMock() + mock_response = MagicMock() mock_response.text = "Default response" mock_response.usage_metadata.prompt_token_count = 2 mock_response.usage_metadata.candidates_token_count = 3 - mock_model.generate_content.return_value = mock_response - mock_generative_model.return_value = mock_model - + + mock_client = MagicMock() + mock_client.models.generate_content.return_value = mock_response + mock_genai.Client.return_value = mock_client + mock_async_init.return_value = None mock_llm_init.return_value = None @@ -401,27 +403,28 @@ class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase): assert result.in_token == 2 assert result.out_token == 3 assert result.model == 'gemini-2.0-flash-001' - - # Verify the model was called with the combined empty prompts - mock_model.generate_content.assert_called_once() - call_args = mock_model.generate_content.call_args - # The prompt should be "" + "\n\n" + "" = "\n\n" - assert call_args[0][0] == "\n\n" + + # Verify the client was called + mock_client.models.generate_content.assert_called_once() @patch('trustgraph.model.text_completion.vertexai.llm.AnthropicVertex') + @patch('trustgraph.model.text_completion.vertexai.llm.genai') @patch('trustgraph.model.text_completion.vertexai.llm.service_account') @patch('trustgraph.base.async_processor.AsyncProcessor.__init__') @patch('trustgraph.base.llm_service.LlmService.__init__') - async def test_anthropic_processor_initialization_with_private_key(self, mock_llm_init, mock_async_init, mock_service_account, mock_anthropic_vertex): + async def test_anthropic_processor_initialization_with_private_key(self, mock_llm_init, mock_async_init, mock_service_account, mock_genai, mock_anthropic_vertex): """Test Anthropic processor initialization with private key credentials""" # Arrange mock_async_init.return_value = None mock_llm_init.return_value = None - + mock_credentials = MagicMock() mock_credentials.project_id = "test-project-456" mock_service_account.Credentials.from_service_account_file.return_value = mock_credentials - + + mock_client = MagicMock() + mock_genai.Client.return_value = mock_client + # Mock AnthropicVertex mock_anthropic_client = MagicMock() mock_anthropic_vertex.return_value = mock_anthropic_client @@ -439,45 +442,45 @@ class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase): # Act processor = Processor(**config) - + # Assert assert processor.default_model == 'claude-3-sonnet@20240229' - # is_anthropic logic is now determined dynamically per request - + # Verify service account was called with private key - mock_service_account.Credentials.from_service_account_file.assert_called_once_with('anthropic-key.json') - - # Verify AnthropicVertex was initialized with credentials + mock_service_account.Credentials.from_service_account_file.assert_called_once() + + # Verify AnthropicVertex was initialized with credentials (because model contains 'claude') mock_anthropic_vertex.assert_called_once_with( region='us-west1', project_id='test-project-456', credentials=mock_credentials ) - + # Verify api_params are set correctly assert processor.api_params["temperature"] == 0.5 assert processor.api_params["max_output_tokens"] == 2048 assert processor.api_params["top_p"] == 1.0 assert processor.api_params["top_k"] == 32 + @patch('trustgraph.model.text_completion.vertexai.llm.genai') @patch('trustgraph.model.text_completion.vertexai.llm.service_account') - @patch('trustgraph.model.text_completion.vertexai.llm.vertexai') - @patch('trustgraph.model.text_completion.vertexai.llm.GenerativeModel') @patch('trustgraph.base.async_processor.AsyncProcessor.__init__') @patch('trustgraph.base.llm_service.LlmService.__init__') - async def test_generate_content_temperature_override(self, mock_llm_init, mock_async_init, mock_generative_model, mock_vertexai, mock_service_account): + async def test_generate_content_temperature_override(self, mock_llm_init, mock_async_init, mock_service_account, mock_genai): """Test temperature parameter override functionality""" # Arrange mock_credentials = MagicMock() + mock_credentials.project_id = "test-project-123" mock_service_account.Credentials.from_service_account_file.return_value = mock_credentials - mock_model = MagicMock() mock_response = MagicMock() mock_response.text = "Response with custom temperature" mock_response.usage_metadata.prompt_token_count = 20 mock_response.usage_metadata.candidates_token_count = 12 - mock_model.generate_content.return_value = mock_response - mock_generative_model.return_value = mock_model + + mock_client = MagicMock() + mock_client.models.generate_content.return_value = mock_response + mock_genai.Client.return_value = mock_client mock_async_init.return_value = None mock_llm_init.return_value = None @@ -506,42 +509,27 @@ class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase): # Assert assert isinstance(result, LlmResult) assert result.text == "Response with custom temperature" + mock_client.models.generate_content.assert_called_once() - # Verify Gemini API was called with overridden temperature - mock_model.generate_content.assert_called_once() - call_args = mock_model.generate_content.call_args - - # Check that generation_config was created (we can't directly access temperature from mock) - generation_config = call_args.kwargs['generation_config'] - assert generation_config is not None # Should use overridden temperature configuration - + @patch('trustgraph.model.text_completion.vertexai.llm.genai') @patch('trustgraph.model.text_completion.vertexai.llm.service_account') - @patch('trustgraph.model.text_completion.vertexai.llm.vertexai') - @patch('trustgraph.model.text_completion.vertexai.llm.GenerativeModel') @patch('trustgraph.base.async_processor.AsyncProcessor.__init__') @patch('trustgraph.base.llm_service.LlmService.__init__') - async def test_generate_content_model_override(self, mock_llm_init, mock_async_init, mock_generative_model, mock_vertexai, mock_service_account): + async def test_generate_content_model_override(self, mock_llm_init, mock_async_init, mock_service_account, mock_genai): """Test model parameter override functionality""" # Arrange mock_credentials = MagicMock() + mock_credentials.project_id = "test-project-123" mock_service_account.Credentials.from_service_account_file.return_value = mock_credentials - # Mock different models - mock_model_default = MagicMock() - mock_model_override = MagicMock() mock_response = MagicMock() mock_response.text = "Response with custom model" mock_response.usage_metadata.prompt_token_count = 18 mock_response.usage_metadata.candidates_token_count = 14 - mock_model_override.generate_content.return_value = mock_response - # GenerativeModel should return different models based on input - def model_factory(model_name): - if model_name == 'gemini-1.5-pro': - return mock_model_override - return mock_model_default - - mock_generative_model.side_effect = model_factory + mock_client = MagicMock() + mock_client.models.generate_content.return_value = mock_response + mock_genai.Client.return_value = mock_client mock_async_init.return_value = None mock_llm_init.return_value = None @@ -549,7 +537,7 @@ class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase): config = { 'region': 'us-central1', 'model': 'gemini-2.0-flash-001', # Default model - 'temperature': 0.2, # Default temperature + 'temperature': 0.2, 'max_output': 8192, 'private_key': 'private.json', 'concurrency': 1, @@ -571,29 +559,29 @@ class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase): assert isinstance(result, LlmResult) assert result.text == "Response with custom model" - # Verify the overridden model was used - mock_model_override.generate_content.assert_called_once() - # Verify GenerativeModel was called with the override model - mock_generative_model.assert_called_with('gemini-1.5-pro') + # Verify the call was made with the override model + call_args = mock_client.models.generate_content.call_args + assert call_args.kwargs['model'] == "gemini-1.5-pro" + @patch('trustgraph.model.text_completion.vertexai.llm.genai') @patch('trustgraph.model.text_completion.vertexai.llm.service_account') - @patch('trustgraph.model.text_completion.vertexai.llm.vertexai') - @patch('trustgraph.model.text_completion.vertexai.llm.GenerativeModel') @patch('trustgraph.base.async_processor.AsyncProcessor.__init__') @patch('trustgraph.base.llm_service.LlmService.__init__') - async def test_generate_content_both_parameters_override(self, mock_llm_init, mock_async_init, mock_generative_model, mock_vertexai, mock_service_account): + async def test_generate_content_both_parameters_override(self, mock_llm_init, mock_async_init, mock_service_account, mock_genai): """Test overriding both model and temperature parameters simultaneously""" # Arrange mock_credentials = MagicMock() + mock_credentials.project_id = "test-project-123" mock_service_account.Credentials.from_service_account_file.return_value = mock_credentials - mock_model = MagicMock() mock_response = MagicMock() mock_response.text = "Response with both overrides" mock_response.usage_metadata.prompt_token_count = 22 mock_response.usage_metadata.candidates_token_count = 16 - mock_model.generate_content.return_value = mock_response - mock_generative_model.return_value = mock_model + + mock_client = MagicMock() + mock_client.models.generate_content.return_value = mock_response + mock_genai.Client.return_value = mock_client mock_async_init.return_value = None mock_llm_init.return_value = None @@ -622,18 +610,12 @@ class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase): # Assert assert isinstance(result, LlmResult) assert result.text == "Response with both overrides" + mock_client.models.generate_content.assert_called_once() - # Verify both overrides were used - mock_model.generate_content.assert_called_once() - call_args = mock_model.generate_content.call_args - - # Verify model override - mock_generative_model.assert_called_with('gemini-1.5-flash-001') # Should use runtime override - - # Verify temperature override (we can't directly access temperature from mock) - generation_config = call_args.kwargs['generation_config'] - assert generation_config is not None # Should use overridden temperature configuration + # Verify the model override was used + call_args = mock_client.models.generate_content.call_args + assert call_args.kwargs['model'] == "gemini-1.5-flash-001" if __name__ == '__main__': - pytest.main([__file__]) \ No newline at end of file + pytest.main([__file__]) diff --git a/trustgraph-flow/pyproject.toml b/trustgraph-flow/pyproject.toml index 381aa778..6984c478 100644 --- a/trustgraph-flow/pyproject.toml +++ b/trustgraph-flow/pyproject.toml @@ -20,6 +20,7 @@ dependencies = [ "falkordb", "fastembed", "google-genai", + "google-api-core", "ibis", "jsonschema", "langchain", diff --git a/trustgraph-vertexai/pyproject.toml b/trustgraph-vertexai/pyproject.toml index 0259b682..e8054b56 100644 --- a/trustgraph-vertexai/pyproject.toml +++ b/trustgraph-vertexai/pyproject.toml @@ -13,6 +13,7 @@ dependencies = [ "trustgraph-base>=2.0,<2.1", "pulsar-client", "google-genai", + "google-api-core", "prometheus-client", "anthropic", ]