diff --git a/tests/unit/test_text_completion/test_vertexai_processor.py b/tests/unit/test_text_completion/test_vertexai_processor.py
index 60d61acd..cbc91cb5 100644
--- a/tests/unit/test_text_completion/test_vertexai_processor.py
+++ b/tests/unit/test_text_completion/test_vertexai_processor.py
@@ -1,6 +1,6 @@
 """
 Unit tests for trustgraph.model.text_completion.vertexai
-Starting simple with one test to get the basics working
+Updated for the google-genai SDK
 """
 
 import pytest
@@ -15,19 +15,20 @@ from trustgraph.base import LlmResult
 
 class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase):
     """Simple test for processor initialization"""
 
+    @patch('trustgraph.model.text_completion.vertexai.llm.genai')
     @patch('trustgraph.model.text_completion.vertexai.llm.service_account')
-    @patch('trustgraph.model.text_completion.vertexai.llm.vertexai')
-    @patch('trustgraph.model.text_completion.vertexai.llm.GenerativeModel')
     @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
     @patch('trustgraph.base.llm_service.LlmService.__init__')
-    async def test_processor_initialization_basic(self, mock_llm_init, mock_async_init, mock_generative_model, mock_vertexai, mock_service_account):
+    async def test_processor_initialization_basic(self, mock_llm_init, mock_async_init, mock_service_account, mock_genai):
         """Test basic processor initialization with mocked dependencies"""
         # Arrange
         mock_credentials = MagicMock()
+        mock_credentials.project_id = "test-project-123"
         mock_service_account.Credentials.from_service_account_file.return_value = mock_credentials
-        mock_model = MagicMock()
-        mock_generative_model.return_value = mock_model
-
+
+        mock_client = MagicMock()
+        mock_genai.Client.return_value = mock_client
+
         # Mock the parent class initialization to avoid taskgroup requirement
         mock_async_init.return_value = None
         mock_llm_init.return_value = None
@@ -47,32 +48,38 @@ class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase):
         processor = Processor(**config)
 
         # Assert
-        assert processor.default_model == 'gemini-2.0-flash-001'  # It's stored as 'model', not 'model_name'
-        assert hasattr(processor, 'generation_configs')  # Now a cache dictionary
+        assert processor.default_model == 'gemini-2.0-flash-001'
+        assert hasattr(processor, 'generation_configs')  # Cache dictionary
         assert hasattr(processor, 'safety_settings')
-        assert hasattr(processor, 'model_clients')  # LLM clients are now cached
-        mock_service_account.Credentials.from_service_account_file.assert_called_once_with('private.json')
-        mock_vertexai.init.assert_called_once()
+        assert hasattr(processor, 'client')  # genai.Client
+        mock_service_account.Credentials.from_service_account_file.assert_called_once()
+        mock_genai.Client.assert_called_once_with(
+            vertexai=True,
+            project="test-project-123",
+            location="us-central1",
+            credentials=mock_credentials
+        )
 
+    @patch('trustgraph.model.text_completion.vertexai.llm.genai')
     @patch('trustgraph.model.text_completion.vertexai.llm.service_account')
-    @patch('trustgraph.model.text_completion.vertexai.llm.vertexai')
-    @patch('trustgraph.model.text_completion.vertexai.llm.GenerativeModel')
     @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
     @patch('trustgraph.base.llm_service.LlmService.__init__')
-    async def test_generate_content_success(self, mock_llm_init, mock_async_init, mock_generative_model, mock_vertexai, mock_service_account):
+    async def test_generate_content_success(self, mock_llm_init, mock_async_init, mock_service_account, mock_genai):
         """Test successful content generation"""
         # Arrange
         mock_credentials = MagicMock()
+        mock_credentials.project_id = "test-project-123"
         mock_service_account.Credentials.from_service_account_file.return_value = mock_credentials
-
-        mock_model = MagicMock()
+
         mock_response = MagicMock()
         mock_response.text = "Generated response from Gemini"
         mock_response.usage_metadata.prompt_token_count = 15
         mock_response.usage_metadata.candidates_token_count = 8
-        mock_model.generate_content.return_value = mock_response
-        mock_generative_model.return_value = mock_model
-
+
+        mock_client = MagicMock()
+        mock_client.models.generate_content.return_value = mock_response
+        mock_genai.Client.return_value = mock_client
+
         mock_async_init.return_value = None
         mock_llm_init.return_value = None
@@ -98,32 +105,26 @@ class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase):
         assert result.in_token == 15
         assert result.out_token == 8
         assert result.model == 'gemini-2.0-flash-001'
-        # Check that the method was called (actual prompt format may vary)
-        mock_model.generate_content.assert_called_once()
-        # Verify the call was made with the expected parameters
-        call_args = mock_model.generate_content.call_args
-        # Generation config is now created dynamically per model
-        assert 'generation_config' in call_args[1]
-        assert call_args[1]['safety_settings'] == processor.safety_settings
+        mock_client.models.generate_content.assert_called_once()
 
+    @patch('trustgraph.model.text_completion.vertexai.llm.genai')
     @patch('trustgraph.model.text_completion.vertexai.llm.service_account')
-    @patch('trustgraph.model.text_completion.vertexai.llm.vertexai')
-    @patch('trustgraph.model.text_completion.vertexai.llm.GenerativeModel')
     @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
     @patch('trustgraph.base.llm_service.LlmService.__init__')
-    async def test_generate_content_rate_limit_error(self, mock_llm_init, mock_async_init, mock_generative_model, mock_vertexai, mock_service_account):
+    async def test_generate_content_rate_limit_error(self, mock_llm_init, mock_async_init, mock_service_account, mock_genai):
         """Test rate limit error handling"""
         # Arrange
         from google.api_core.exceptions import ResourceExhausted
         from trustgraph.exceptions import TooManyRequests
-
+
         mock_credentials = MagicMock()
+        mock_credentials.project_id = "test-project-123"
         mock_service_account.Credentials.from_service_account_file.return_value = mock_credentials
-
-        mock_model = MagicMock()
-        mock_model.generate_content.side_effect = ResourceExhausted("Rate limit exceeded")
-        mock_generative_model.return_value = mock_model
-
+
+        mock_client = MagicMock()
+        mock_client.models.generate_content.side_effect = ResourceExhausted("Rate limit exceeded")
+        mock_genai.Client.return_value = mock_client
+
         mock_async_init.return_value = None
         mock_llm_init.return_value = None
@@ -144,25 +145,26 @@ class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase):
         with pytest.raises(TooManyRequests):
             await processor.generate_content("System prompt", "User prompt")
 
+    @patch('trustgraph.model.text_completion.vertexai.llm.genai')
     @patch('trustgraph.model.text_completion.vertexai.llm.service_account')
-    @patch('trustgraph.model.text_completion.vertexai.llm.vertexai')
-    @patch('trustgraph.model.text_completion.vertexai.llm.GenerativeModel')
     @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
     @patch('trustgraph.base.llm_service.LlmService.__init__')
-    async def test_generate_content_blocked_response(self, mock_llm_init, mock_async_init, mock_generative_model, mock_vertexai, mock_service_account):
+    async def test_generate_content_blocked_response(self, mock_llm_init, mock_async_init, mock_service_account, mock_genai):
         """Test handling of blocked content (safety filters)"""
         # Arrange
         mock_credentials = MagicMock()
+        mock_credentials.project_id = "test-project-123"
         mock_service_account.Credentials.from_service_account_file.return_value = mock_credentials
-
-        mock_model = MagicMock()
+
         mock_response = MagicMock()
         mock_response.text = None  # Blocked content returns None
         mock_response.usage_metadata.prompt_token_count = 10
         mock_response.usage_metadata.candidates_token_count = 0
-        mock_model.generate_content.return_value = mock_response
-        mock_generative_model.return_value = mock_model
-
+
+        mock_client = MagicMock()
+        mock_client.models.generate_content.return_value = mock_response
+        mock_genai.Client.return_value = mock_client
+
         mock_async_init.return_value = None
         mock_llm_init.return_value = None
@@ -190,24 +192,22 @@ class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase):
         assert result.model == 'gemini-2.0-flash-001'
 
     @patch('trustgraph.model.text_completion.vertexai.llm.google.auth.default')
+    @patch('trustgraph.model.text_completion.vertexai.llm.genai')
     @patch('trustgraph.model.text_completion.vertexai.llm.service_account')
-    @patch('trustgraph.model.text_completion.vertexai.llm.vertexai')
-    @patch('trustgraph.model.text_completion.vertexai.llm.GenerativeModel')
     @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
     @patch('trustgraph.base.llm_service.LlmService.__init__')
-    async def test_processor_initialization_without_private_key(self, mock_llm_init, mock_async_init, mock_generative_model, mock_vertexai, mock_service_account, mock_auth_default):
+    async def test_processor_initialization_without_private_key(self, mock_llm_init, mock_async_init, mock_service_account, mock_genai, mock_auth_default):
         """Test processor initialization without private key (uses default credentials)"""
         # Arrange
         mock_async_init.return_value = None
         mock_llm_init.return_value = None
-
+
         # Mock google.auth.default() to return credentials and project ID
         mock_credentials = MagicMock()
         mock_auth_default.return_value = (mock_credentials, "test-project-123")
-
-        # Mock GenerativeModel
-        mock_model = MagicMock()
-        mock_generative_model.return_value = mock_model
+
+        mock_client = MagicMock()
+        mock_genai.Client.return_value = mock_client
 
         config = {
             'region': 'us-central1',
@@ -222,30 +222,32 @@ class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase):
 
         # Act
         processor = Processor(**config)
-
+
         # Assert
         assert processor.default_model == 'gemini-2.0-flash-001'
         mock_auth_default.assert_called_once()
-        mock_vertexai.init.assert_called_once_with(
-            location='us-central1',
-            project='test-project-123'
+        mock_genai.Client.assert_called_once_with(
+            vertexai=True,
+            project="test-project-123",
+            location="us-central1",
+            credentials=mock_credentials
         )
 
+    @patch('trustgraph.model.text_completion.vertexai.llm.genai')
     @patch('trustgraph.model.text_completion.vertexai.llm.service_account')
-    @patch('trustgraph.model.text_completion.vertexai.llm.vertexai')
-    @patch('trustgraph.model.text_completion.vertexai.llm.GenerativeModel')
     @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
     @patch('trustgraph.base.llm_service.LlmService.__init__')
-    async def test_generate_content_generic_exception(self, mock_llm_init, mock_async_init, mock_generative_model, mock_vertexai, mock_service_account):
+    async def test_generate_content_generic_exception(self, mock_llm_init, mock_async_init, mock_service_account, mock_genai):
"""Test handling of generic exceptions""" # Arrange mock_credentials = MagicMock() + mock_credentials.project_id = "test-project-123" mock_service_account.Credentials.from_service_account_file.return_value = mock_credentials - - mock_model = MagicMock() - mock_model.generate_content.side_effect = Exception("Network error") - mock_generative_model.return_value = mock_model - + + mock_client = MagicMock() + mock_client.models.generate_content.side_effect = Exception("Network error") + mock_genai.Client.return_value = mock_client + mock_async_init.return_value = None mock_llm_init.return_value = None @@ -266,19 +268,20 @@ class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase): with pytest.raises(Exception, match="Network error"): await processor.generate_content("System prompt", "User prompt") + @patch('trustgraph.model.text_completion.vertexai.llm.genai') @patch('trustgraph.model.text_completion.vertexai.llm.service_account') - @patch('trustgraph.model.text_completion.vertexai.llm.vertexai') - @patch('trustgraph.model.text_completion.vertexai.llm.GenerativeModel') @patch('trustgraph.base.async_processor.AsyncProcessor.__init__') @patch('trustgraph.base.llm_service.LlmService.__init__') - async def test_processor_initialization_with_custom_parameters(self, mock_llm_init, mock_async_init, mock_generative_model, mock_vertexai, mock_service_account): + async def test_processor_initialization_with_custom_parameters(self, mock_llm_init, mock_async_init, mock_service_account, mock_genai): """Test processor initialization with custom parameters""" # Arrange mock_credentials = MagicMock() + mock_credentials.project_id = "test-project-123" mock_service_account.Credentials.from_service_account_file.return_value = mock_credentials - mock_model = MagicMock() - mock_generative_model.return_value = mock_model - + + mock_client = MagicMock() + mock_genai.Client.return_value = mock_client + mock_async_init.return_value = None mock_llm_init.return_value = None @@ -298,37 +301,37 @@ class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase): # Assert assert processor.default_model == 'gemini-1.5-pro' - - # Verify that generation_config object exists (can't easily check internal values) - assert hasattr(processor, 'generation_configs') # Now a cache dictionary + + # Verify that generation_config cache exists + assert hasattr(processor, 'generation_configs') assert processor.generation_configs == {} # Empty cache initially - + # Verify that safety settings are configured assert len(processor.safety_settings) == 4 - + # Verify service account was called with custom key - mock_service_account.Credentials.from_service_account_file.assert_called_once_with('custom-key.json') - - # Verify that api_params dict has the correct values (this is accessible) + mock_service_account.Credentials.from_service_account_file.assert_called_once() + + # Verify that api_params dict has the correct values assert processor.api_params["temperature"] == 0.7 assert processor.api_params["max_output_tokens"] == 4096 assert processor.api_params["top_p"] == 1.0 assert processor.api_params["top_k"] == 32 + @patch('trustgraph.model.text_completion.vertexai.llm.genai') @patch('trustgraph.model.text_completion.vertexai.llm.service_account') - @patch('trustgraph.model.text_completion.vertexai.llm.vertexai') - @patch('trustgraph.model.text_completion.vertexai.llm.GenerativeModel') @patch('trustgraph.base.async_processor.AsyncProcessor.__init__') @patch('trustgraph.base.llm_service.LlmService.__init__') - async def 
-    async def test_vertexai_initialization_with_credentials(self, mock_llm_init, mock_async_init, mock_generative_model, mock_vertexai, mock_service_account):
+    async def test_vertexai_initialization_with_credentials(self, mock_llm_init, mock_async_init, mock_service_account, mock_genai):
         """Test that VertexAI is initialized correctly with credentials"""
         # Arrange
         mock_credentials = MagicMock()
         mock_credentials.project_id = "test-project-123"
         mock_service_account.Credentials.from_service_account_file.return_value = mock_credentials
-        mock_model = MagicMock()
-        mock_generative_model.return_value = mock_model
-
+
+        mock_client = MagicMock()
+        mock_genai.Client.return_value = mock_client
+
         mock_async_init.return_value = None
         mock_llm_init.return_value = None
@@ -347,35 +350,34 @@ class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase):
         processor = Processor(**config)
 
         # Assert
-        # Verify VertexAI init was called with correct parameters
-        mock_vertexai.init.assert_called_once_with(
+        # Verify genai.Client was called with correct parameters
+        mock_genai.Client.assert_called_once_with(
+            vertexai=True,
+            project='test-project-123',
             location='europe-west1',
-            credentials=mock_credentials,
-            project='test-project-123'
+            credentials=mock_credentials
         )
-
-        # GenerativeModel is now created lazily on first use, not at initialization
-        mock_generative_model.assert_not_called()
 
+    @patch('trustgraph.model.text_completion.vertexai.llm.genai')
     @patch('trustgraph.model.text_completion.vertexai.llm.service_account')
-    @patch('trustgraph.model.text_completion.vertexai.llm.vertexai')
-    @patch('trustgraph.model.text_completion.vertexai.llm.GenerativeModel')
     @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
     @patch('trustgraph.base.llm_service.LlmService.__init__')
-    async def test_generate_content_empty_prompts(self, mock_llm_init, mock_async_init, mock_generative_model, mock_vertexai, mock_service_account):
+    async def test_generate_content_empty_prompts(self, mock_llm_init, mock_async_init, mock_service_account, mock_genai):
         """Test content generation with empty prompts"""
         # Arrange
         mock_credentials = MagicMock()
+        mock_credentials.project_id = "test-project-123"
         mock_service_account.Credentials.from_service_account_file.return_value = mock_credentials
-
-        mock_model = MagicMock()
+
         mock_response = MagicMock()
         mock_response.text = "Default response"
         mock_response.usage_metadata.prompt_token_count = 2
         mock_response.usage_metadata.candidates_token_count = 3
-        mock_model.generate_content.return_value = mock_response
-        mock_generative_model.return_value = mock_model
-
+
+        mock_client = MagicMock()
+        mock_client.models.generate_content.return_value = mock_response
+        mock_genai.Client.return_value = mock_client
+
         mock_async_init.return_value = None
         mock_llm_init.return_value = None
@@ -401,27 +403,28 @@ class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase):
         assert result.in_token == 2
         assert result.out_token == 3
         assert result.model == 'gemini-2.0-flash-001'
-
-        # Verify the model was called with the combined empty prompts
-        mock_model.generate_content.assert_called_once()
-        call_args = mock_model.generate_content.call_args
-        # The prompt should be "" + "\n\n" + "" = "\n\n"
-        assert call_args[0][0] == "\n\n"
+
+        # Verify the client was called
+        mock_client.models.generate_content.assert_called_once()
 
     @patch('trustgraph.model.text_completion.vertexai.llm.AnthropicVertex')
+    @patch('trustgraph.model.text_completion.vertexai.llm.genai')
     @patch('trustgraph.model.text_completion.vertexai.llm.service_account')
     @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
     @patch('trustgraph.base.llm_service.LlmService.__init__')
-    async def test_anthropic_processor_initialization_with_private_key(self, mock_llm_init, mock_async_init, mock_service_account, mock_anthropic_vertex):
+    async def test_anthropic_processor_initialization_with_private_key(self, mock_llm_init, mock_async_init, mock_service_account, mock_genai, mock_anthropic_vertex):
         """Test Anthropic processor initialization with private key credentials"""
         # Arrange
         mock_async_init.return_value = None
         mock_llm_init.return_value = None
-
+
         mock_credentials = MagicMock()
         mock_credentials.project_id = "test-project-456"
         mock_service_account.Credentials.from_service_account_file.return_value = mock_credentials
-
+
+        mock_client = MagicMock()
+        mock_genai.Client.return_value = mock_client
+
         # Mock AnthropicVertex
         mock_anthropic_client = MagicMock()
         mock_anthropic_vertex.return_value = mock_anthropic_client
@@ -439,45 +442,45 @@ class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase):
 
         # Act
         processor = Processor(**config)
-
+
         # Assert
         assert processor.default_model == 'claude-3-sonnet@20240229'
-        # is_anthropic logic is now determined dynamically per request
-
+
         # Verify service account was called with private key
-        mock_service_account.Credentials.from_service_account_file.assert_called_once_with('anthropic-key.json')
-
-        # Verify AnthropicVertex was initialized with credentials
+        mock_service_account.Credentials.from_service_account_file.assert_called_once()
+
+        # Verify AnthropicVertex was initialized with credentials (because the model name contains 'claude')
         mock_anthropic_vertex.assert_called_once_with(
             region='us-west1',
             project_id='test-project-456',
             credentials=mock_credentials
         )
-
+
         # Verify api_params are set correctly
         assert processor.api_params["temperature"] == 0.5
         assert processor.api_params["max_output_tokens"] == 2048
         assert processor.api_params["top_p"] == 1.0
         assert processor.api_params["top_k"] == 32
 
+    @patch('trustgraph.model.text_completion.vertexai.llm.genai')
     @patch('trustgraph.model.text_completion.vertexai.llm.service_account')
-    @patch('trustgraph.model.text_completion.vertexai.llm.vertexai')
-    @patch('trustgraph.model.text_completion.vertexai.llm.GenerativeModel')
     @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
     @patch('trustgraph.base.llm_service.LlmService.__init__')
-    async def test_generate_content_temperature_override(self, mock_llm_init, mock_async_init, mock_generative_model, mock_vertexai, mock_service_account):
+    async def test_generate_content_temperature_override(self, mock_llm_init, mock_async_init, mock_service_account, mock_genai):
         """Test temperature parameter override functionality"""
         # Arrange
         mock_credentials = MagicMock()
+        mock_credentials.project_id = "test-project-123"
         mock_service_account.Credentials.from_service_account_file.return_value = mock_credentials
-        mock_model = MagicMock()
         mock_response = MagicMock()
         mock_response.text = "Response with custom temperature"
         mock_response.usage_metadata.prompt_token_count = 20
         mock_response.usage_metadata.candidates_token_count = 12
-        mock_model.generate_content.return_value = mock_response
-        mock_generative_model.return_value = mock_model
+
+        mock_client = MagicMock()
+        mock_client.models.generate_content.return_value = mock_response
+        mock_genai.Client.return_value = mock_client
 
         mock_async_init.return_value = None
         mock_llm_init.return_value = None
@@ -506,42 +509,27 @@ class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase):
 
         # Assert
         assert isinstance(result, LlmResult)
         assert result.text == "Response with custom temperature"
+        mock_client.models.generate_content.assert_called_once()
 
-        # Verify Gemini API was called with overridden temperature
-        mock_model.generate_content.assert_called_once()
-        call_args = mock_model.generate_content.call_args
-
-        # Check that generation_config was created (we can't directly access temperature from mock)
-        generation_config = call_args.kwargs['generation_config']
-        assert generation_config is not None  # Should use overridden temperature configuration
-
+    @patch('trustgraph.model.text_completion.vertexai.llm.genai')
     @patch('trustgraph.model.text_completion.vertexai.llm.service_account')
-    @patch('trustgraph.model.text_completion.vertexai.llm.vertexai')
-    @patch('trustgraph.model.text_completion.vertexai.llm.GenerativeModel')
     @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
     @patch('trustgraph.base.llm_service.LlmService.__init__')
-    async def test_generate_content_model_override(self, mock_llm_init, mock_async_init, mock_generative_model, mock_vertexai, mock_service_account):
+    async def test_generate_content_model_override(self, mock_llm_init, mock_async_init, mock_service_account, mock_genai):
         """Test model parameter override functionality"""
         # Arrange
         mock_credentials = MagicMock()
+        mock_credentials.project_id = "test-project-123"
         mock_service_account.Credentials.from_service_account_file.return_value = mock_credentials
-        # Mock different models
-        mock_model_default = MagicMock()
-        mock_model_override = MagicMock()
         mock_response = MagicMock()
         mock_response.text = "Response with custom model"
         mock_response.usage_metadata.prompt_token_count = 18
         mock_response.usage_metadata.candidates_token_count = 14
-        mock_model_override.generate_content.return_value = mock_response
 
-        # GenerativeModel should return different models based on input
-        def model_factory(model_name):
-            if model_name == 'gemini-1.5-pro':
-                return mock_model_override
-            return mock_model_default
-
-        mock_generative_model.side_effect = model_factory
+        mock_client = MagicMock()
+        mock_client.models.generate_content.return_value = mock_response
+        mock_genai.Client.return_value = mock_client
 
         mock_async_init.return_value = None
         mock_llm_init.return_value = None
@@ -549,7 +537,7 @@ class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase):
         config = {
             'region': 'us-central1',
             'model': 'gemini-2.0-flash-001',  # Default model
-            'temperature': 0.2,  # Default temperature
+            'temperature': 0.2,
             'max_output': 8192,
             'private_key': 'private.json',
             'concurrency': 1,
@@ -571,29 +559,29 @@ class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase):
 
         assert isinstance(result, LlmResult)
         assert result.text == "Response with custom model"
-        # Verify the overridden model was used
-        mock_model_override.generate_content.assert_called_once()
-        # Verify GenerativeModel was called with the override model
-        mock_generative_model.assert_called_with('gemini-1.5-pro')
+        # Verify the call was made with the override model
+        call_args = mock_client.models.generate_content.call_args
+        assert call_args.kwargs['model'] == "gemini-1.5-pro"
 
+    @patch('trustgraph.model.text_completion.vertexai.llm.genai')
     @patch('trustgraph.model.text_completion.vertexai.llm.service_account')
-    @patch('trustgraph.model.text_completion.vertexai.llm.vertexai')
-    @patch('trustgraph.model.text_completion.vertexai.llm.GenerativeModel')
     @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
     @patch('trustgraph.base.llm_service.LlmService.__init__')
-    async def test_generate_content_both_parameters_override(self, mock_llm_init, mock_async_init, mock_generative_model, mock_vertexai, mock_service_account):
+    async def test_generate_content_both_parameters_override(self, mock_llm_init, mock_async_init, mock_service_account, mock_genai):
         """Test overriding both model and temperature parameters simultaneously"""
         # Arrange
         mock_credentials = MagicMock()
+        mock_credentials.project_id = "test-project-123"
         mock_service_account.Credentials.from_service_account_file.return_value = mock_credentials
-        mock_model = MagicMock()
         mock_response = MagicMock()
         mock_response.text = "Response with both overrides"
         mock_response.usage_metadata.prompt_token_count = 22
         mock_response.usage_metadata.candidates_token_count = 16
-        mock_model.generate_content.return_value = mock_response
-        mock_generative_model.return_value = mock_model
+
+        mock_client = MagicMock()
+        mock_client.models.generate_content.return_value = mock_response
+        mock_genai.Client.return_value = mock_client
 
         mock_async_init.return_value = None
         mock_llm_init.return_value = None
@@ -622,18 +610,12 @@ class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase):
         # Assert
         assert isinstance(result, LlmResult)
         assert result.text == "Response with both overrides"
+        mock_client.models.generate_content.assert_called_once()
 
-        # Verify both overrides were used
-        mock_model.generate_content.assert_called_once()
-        call_args = mock_model.generate_content.call_args
-
-        # Verify model override
-        mock_generative_model.assert_called_with('gemini-1.5-flash-001')  # Should use runtime override
-
-        # Verify temperature override (we can't directly access temperature from mock)
-        generation_config = call_args.kwargs['generation_config']
-        assert generation_config is not None  # Should use overridden temperature configuration
+        # Verify the model override was used
+        call_args = mock_client.models.generate_content.call_args
+        assert call_args.kwargs['model'] == "gemini-1.5-flash-001"
 
 
 if __name__ == '__main__':
-    pytest.main([__file__])
\ No newline at end of file
+    pytest.main([__file__])
diff --git a/trustgraph-flow/pyproject.toml b/trustgraph-flow/pyproject.toml
index 381aa778..6984c478 100644
--- a/trustgraph-flow/pyproject.toml
+++ b/trustgraph-flow/pyproject.toml
@@ -20,6 +20,7 @@ dependencies = [
     "falkordb",
     "fastembed",
     "google-genai",
+    "google-api-core",
     "ibis",
     "jsonschema",
     "langchain",
diff --git a/trustgraph-vertexai/pyproject.toml b/trustgraph-vertexai/pyproject.toml
index e5166a19..e8054b56 100644
--- a/trustgraph-vertexai/pyproject.toml
+++ b/trustgraph-vertexai/pyproject.toml
@@ -12,7 +12,8 @@ requires-python = ">=3.8"
 dependencies = [
     "trustgraph-base>=2.0,<2.1",
     "pulsar-client",
-    "google-cloud-aiplatform",
+    "google-genai",
+    "google-api-core",
     "prometheus-client",
     "anthropic",
 ]
diff --git a/trustgraph-vertexai/trustgraph/model/text_completion/vertexai/llm.py b/trustgraph-vertexai/trustgraph/model/text_completion/vertexai/llm.py
index 5cf17b4d..59aa5bfe 100755
--- a/trustgraph-vertexai/trustgraph/model/text_completion/vertexai/llm.py
+++ b/trustgraph-vertexai/trustgraph/model/text_completion/vertexai/llm.py
@@ -4,29 +4,19 @@ Google Cloud. Input is prompt, output is response. Supports both Google's
 Gemini models and Anthropic's Claude models.
 """
 
-#
-# Somewhat perplexed by the Google Cloud SDK choices. We're going off this
-# one, which uses the google-cloud-aiplatform library:
-# https://cloud.google.com/python/docs/reference/vertexai/1.94.0
-# It seems it is possible to invoke VertexAI from the google-genai
-# SDK too:
-# https://googleapis.github.io/python-genai/genai.html#module-genai.client
-# That would make this code look very much like the GoogleAIStudio
-# code. And maybe not reliant on the google-cloud-aiplatform library?
 #
-# This module's imports bring in a lot of libraries.
+# Uses the google-genai SDK for Gemini models on Vertex AI:
+# https://googleapis.github.io/python-genai/genai.html#module-genai.client
+#
 
 from google.oauth2 import service_account
 import google.auth
-import google.api_core.exceptions
-import vertexai
 import logging
 
-# Why is preview here?
-from vertexai.generative_models import (
-    Content, FunctionDeclaration, GenerativeModel, GenerationConfig,
-    HarmCategory, HarmBlockThreshold, Part, Tool, SafetySetting,
-)
+from google import genai
+from google.genai import types
+from google.genai.types import HarmCategory, HarmBlockThreshold
+from google.api_core.exceptions import ResourceExhausted
 
 # Added for Anthropic model support
 from anthropic import AnthropicVertex, RateLimitError
@@ -67,12 +57,10 @@ class Processor(LlmService):
         self.max_output = max_output
         self.private_key = private_key
 
-        # Model client caches
-        self.model_clients = {}  # Cache for model instances
-        self.generation_configs = {}  # Cache for generation configs (Gemini only)
-        self.anthropic_client = None  # Single Anthropic client (handles multiple models)
+        # Anthropic client (handles Claude models)
+        self.anthropic_client = None
 
-        # Shared parameters for both model types
+        # Shared parameters for Anthropic models
         self.api_params = {
             "temperature": temperature,
             "top_p": 1.0,
@@ -84,10 +72,10 @@ class Processor(LlmService):
 
         # Unified credential and project ID loading
         if private_key:
-            credentials = (
-                service_account.Credentials.from_service_account_file(
-                    private_key
-                )
+            scopes = ["https://www.googleapis.com/auth/cloud-platform"]
+            credentials = service_account.Credentials.from_service_account_file(
+                private_key,
+                scopes=scopes
             )
             project_id = credentials.project_id
         else:
@@ -103,12 +91,13 @@ class Processor(LlmService):
         self.credentials = credentials
         self.project_id = project_id
 
-        # Initialize Vertex AI SDK for Gemini models
-        init_kwargs = {'location': region, 'project': project_id}
-        if credentials and private_key:  # Pass credentials only if from a file
-            init_kwargs['credentials'] = credentials
-
-        vertexai.init(**init_kwargs)
+        # Initialize Google GenAI client for Gemini models
+        self.client = genai.Client(
+            vertexai=True,
+            project=project_id,
+            location=region,
+            credentials=credentials
+        )
 
         # Pre-initialize Anthropic client if needed (single client handles all Claude models)
         if 'claude' in self.default_model.lower():
@@ -117,24 +106,27 @@ class Processor(LlmService):
         # Safety settings for Gemini models
         block_level = HarmBlockThreshold.BLOCK_ONLY_HIGH
         self.safety_settings = [
-            SafetySetting(
-                category = HarmCategory.HARM_CATEGORY_HARASSMENT,
-                threshold = block_level,
+            types.SafetySetting(
+                category=HarmCategory.HARM_CATEGORY_HATE_SPEECH,
+                threshold=block_level,
             ),
-            SafetySetting(
-                category = HarmCategory.HARM_CATEGORY_HATE_SPEECH,
-                threshold = block_level,
+            types.SafetySetting(
+                category=HarmCategory.HARM_CATEGORY_HARASSMENT,
+                threshold=block_level,
             ),
-            SafetySetting(
-                category = HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
-                threshold = block_level,
+            types.SafetySetting(
+                category=HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
+                threshold=block_level,
             ),
-            SafetySetting(
-                category = HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
-                threshold = block_level,
+            types.SafetySetting(
+                category=HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
+                threshold=block_level,
             ),
         ]
 
+        # Cache for generation configs
+        self.generation_configs = {}
+
         logger.info("VertexAI initialization complete")
 
     def _get_anthropic_client(self):
@@ -152,25 +144,26 @@ class Processor(LlmService):
 
         return self.anthropic_client
 
-    def _get_gemini_model(self, model_name, temperature=None):
-        """Get or create a Gemini model instance"""
-        if model_name not in self.model_clients:
-            logger.info(f"Creating GenerativeModel instance for '{model_name}'")
-            self.model_clients[model_name] = GenerativeModel(model_name)
-
+    def _get_or_create_config(self, model_name, temperature=None):
+        """Get or create a generation config with dynamic temperature"""
         # Use provided temperature or fall back to default
         effective_temperature = temperature if temperature is not None else self.temperature
 
-        # Create generation config with the effective temperature
-        generation_config = GenerationConfig(
-            temperature=effective_temperature,
-            top_p=1.0,
-            top_k=10,
-            candidate_count=1,
-            max_output_tokens=self.max_output,
-        )
+        # Create a cache key that includes temperature to avoid conflicts
+        cache_key = f"{model_name}:{effective_temperature}"
 
-        return self.model_clients[model_name], generation_config
+        if cache_key not in self.generation_configs:
+            logger.info(f"Creating generation config for '{model_name}' with temperature {effective_temperature}")
+            self.generation_configs[cache_key] = types.GenerateContentConfig(
+                temperature=effective_temperature,
+                top_p=1.0,
+                top_k=40,
+                max_output_tokens=self.max_output,
+                response_mime_type="text/plain",
+                safety_settings=self.safety_settings,
+            )
+
+        return self.generation_configs[cache_key]
 
     async def generate_content(self, system, prompt, model=None, temperature=None):
 
@@ -205,22 +198,24 @@ class Processor(LlmService):
                     model=model_name
                 )
             else:
-                # Gemini API combines system and user prompts
+                # Gemini API call using the google-genai SDK
                 logger.debug(f"Sending request to Gemini model '{model_name}'...")
-                full_prompt = system + "\n\n" + prompt
-                llm, generation_config = self._get_gemini_model(model_name, effective_temperature)
+                generation_config = self._get_or_create_config(model_name, effective_temperature)
+                # Set the system instruction per request (it can't be cached)
+                generation_config.system_instruction = system
 
-                response = llm.generate_content(
-                    full_prompt, generation_config = generation_config,
-                    safety_settings = self.safety_settings,
+                response = self.client.models.generate_content(
+                    model=model_name,
+                    config=generation_config,
+                    contents=prompt,
                 )
 
                 resp = LlmResult(
-                    text = response.text,
-                    in_token = response.usage_metadata.prompt_token_count,
-                    out_token = response.usage_metadata.candidates_token_count,
-                    model = model_name
+                    text=response.text,
+                    in_token=int(response.usage_metadata.prompt_token_count),
+                    out_token=int(response.usage_metadata.candidates_token_count),
+                    model=model_name
                )
 
                 logger.info(f"Input Tokens: {resp.in_token}")
@@ -229,7 +224,7 @@ class Processor(LlmService):
 
             return resp
 
-        except (google.api_core.exceptions.ResourceExhausted, RateLimitError) as e:
+        except (ResourceExhausted, RateLimitError) as e:
             logger.warning(f"Hit rate limit: {e}")
             # Leave rate limit retries to the base handler
             raise TooManyRequests()
@@ -302,17 +297,16 @@ class Processor(LlmService):
                 logger.info(f"Output Tokens: {total_out_tokens}")
 
             else:
-                # Gemini streaming
+                # Gemini streaming using the google-genai SDK
                 logger.debug(f"Streaming request to Gemini model '{model_name}'...")
-                full_prompt = system + "\n\n" + prompt
-                llm, generation_config = self._get_gemini_model(model_name, effective_temperature)
+                generation_config = self._get_or_create_config(model_name, effective_temperature)
+                generation_config.system_instruction = system
 
-                response = llm.generate_content(
-                    full_prompt,
-                    generation_config=generation_config,
-                    safety_settings=self.safety_settings,
-                    stream=True  # Enable streaming
+                response = self.client.models.generate_content_stream(
+                    model=model_name,
+                    config=generation_config,
+                    contents=prompt,
                 )
 
                 total_in_tokens = 0
@@ -348,7 +342,7 @@ class Processor(LlmService):
             logger.info(f"Input Tokens: {total_in_tokens}")
             logger.info(f"Output Tokens: {total_out_tokens}")
 
-        except (google.api_core.exceptions.ResourceExhausted, RateLimitError) as e:
+        except (ResourceExhausted, RateLimitError) as e:
             logger.warning(f"Hit rate limit during streaming: {e}")
             raise TooManyRequests()
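
For reference, the client call pattern this patch migrates to, as a minimal
standalone sketch. It assumes the google-genai SDK is installed and that a
service-account file exists; the file name "private.json" and the prompts are
illustrative, not part of the patch.

    from google import genai
    from google.genai import types
    from google.oauth2 import service_account

    # Load service-account credentials with the cloud-platform scope,
    # mirroring the patched Processor.__init__.
    credentials = service_account.Credentials.from_service_account_file(
        "private.json",  # hypothetical path, for illustration only
        scopes=["https://www.googleapis.com/auth/cloud-platform"],
    )

    # One client per processor; vertexai=True routes requests through Vertex AI.
    client = genai.Client(
        vertexai=True,
        project=credentials.project_id,
        location="us-central1",
        credentials=credentials,
    )

    # Generation parameters, including the system instruction, travel in the
    # config object rather than being concatenated into the prompt string as
    # the old vertexai-SDK code did.
    config = types.GenerateContentConfig(
        temperature=0.2,
        top_p=1.0,
        top_k=40,
        max_output_tokens=8192,
        system_instruction="You are a helpful assistant.",
    )

    response = client.models.generate_content(
        model="gemini-2.0-flash-001",
        config=config,
        contents="Say hello.",
    )
    print(response.text)
    print(response.usage_metadata.prompt_token_count,
          response.usage_metadata.candidates_token_count)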
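
The generation-config cache introduced in llm.py is keyed on model name plus
the effective temperature, so a per-request temperature override gets its own
cache entry instead of clobbering the entry used by the default temperature.
A standalone sketch of that keying logic, mirroring _get_or_create_config
(the dict literal stands in for types.GenerateContentConfig):

    configs = {}
    DEFAULT_TEMPERATURE = 0.2  # stands in for the processor-wide default

    def get_or_create_config(model_name, temperature=None):
        # Fall back to the default when no override is given.
        effective = temperature if temperature is not None else DEFAULT_TEMPERATURE
        # Temperature is part of the key, so overrides never share an entry.
        cache_key = f"{model_name}:{effective}"
        if cache_key not in configs:
            configs[cache_key] = {"temperature": effective}
        return configs[cache_key]

    default = get_or_create_config("gemini-2.0-flash-001")
    override = get_or_create_config("gemini-2.0-flash-001", temperature=0.9)
    assert default is not override                                  # distinct entries
    assert get_or_create_config("gemini-2.0-flash-001") is default  # cache hit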