Migrate to VertexAI to google-genai SDK from deprecated library (#632)

* Migrate to VertexAI to google-genai SDK from deprecated library

* Fix tests, mock the correct API
This commit is contained in:
cybermaggedon 2026-02-09 20:43:33 +00:00 committed by GitHub
parent 2781c7d87c
commit f24f1ebd80
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 223 additions and 245 deletions

View file

@ -1,6 +1,6 @@
"""
Unit tests for trustgraph.model.text_completion.vertexai
Starting simple with one test to get the basics working
Updated for google-genai SDK
"""
import pytest
@ -15,19 +15,20 @@ from trustgraph.base import LlmResult
class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase):
"""Simple test for processor initialization"""
@patch('trustgraph.model.text_completion.vertexai.llm.genai')
@patch('trustgraph.model.text_completion.vertexai.llm.service_account')
@patch('trustgraph.model.text_completion.vertexai.llm.vertexai')
@patch('trustgraph.model.text_completion.vertexai.llm.GenerativeModel')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_processor_initialization_basic(self, mock_llm_init, mock_async_init, mock_generative_model, mock_vertexai, mock_service_account):
async def test_processor_initialization_basic(self, mock_llm_init, mock_async_init, mock_service_account, mock_genai):
"""Test basic processor initialization with mocked dependencies"""
# Arrange
mock_credentials = MagicMock()
mock_credentials.project_id = "test-project-123"
mock_service_account.Credentials.from_service_account_file.return_value = mock_credentials
mock_model = MagicMock()
mock_generative_model.return_value = mock_model
mock_client = MagicMock()
mock_genai.Client.return_value = mock_client
# Mock the parent class initialization to avoid taskgroup requirement
mock_async_init.return_value = None
mock_llm_init.return_value = None
@ -47,32 +48,38 @@ class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase):
processor = Processor(**config)
# Assert
assert processor.default_model == 'gemini-2.0-flash-001' # It's stored as 'model', not 'model_name'
assert hasattr(processor, 'generation_configs') # Now a cache dictionary
assert processor.default_model == 'gemini-2.0-flash-001'
assert hasattr(processor, 'generation_configs') # Cache dictionary
assert hasattr(processor, 'safety_settings')
assert hasattr(processor, 'model_clients') # LLM clients are now cached
mock_service_account.Credentials.from_service_account_file.assert_called_once_with('private.json')
mock_vertexai.init.assert_called_once()
assert hasattr(processor, 'client') # genai.Client
mock_service_account.Credentials.from_service_account_file.assert_called_once()
mock_genai.Client.assert_called_once_with(
vertexai=True,
project="test-project-123",
location="us-central1",
credentials=mock_credentials
)
@patch('trustgraph.model.text_completion.vertexai.llm.genai')
@patch('trustgraph.model.text_completion.vertexai.llm.service_account')
@patch('trustgraph.model.text_completion.vertexai.llm.vertexai')
@patch('trustgraph.model.text_completion.vertexai.llm.GenerativeModel')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_success(self, mock_llm_init, mock_async_init, mock_generative_model, mock_vertexai, mock_service_account):
async def test_generate_content_success(self, mock_llm_init, mock_async_init, mock_service_account, mock_genai):
"""Test successful content generation"""
# Arrange
mock_credentials = MagicMock()
mock_credentials.project_id = "test-project-123"
mock_service_account.Credentials.from_service_account_file.return_value = mock_credentials
mock_model = MagicMock()
mock_response = MagicMock()
mock_response.text = "Generated response from Gemini"
mock_response.usage_metadata.prompt_token_count = 15
mock_response.usage_metadata.candidates_token_count = 8
mock_model.generate_content.return_value = mock_response
mock_generative_model.return_value = mock_model
mock_client = MagicMock()
mock_client.models.generate_content.return_value = mock_response
mock_genai.Client.return_value = mock_client
mock_async_init.return_value = None
mock_llm_init.return_value = None
@ -98,32 +105,26 @@ class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase):
assert result.in_token == 15
assert result.out_token == 8
assert result.model == 'gemini-2.0-flash-001'
# Check that the method was called (actual prompt format may vary)
mock_model.generate_content.assert_called_once()
# Verify the call was made with the expected parameters
call_args = mock_model.generate_content.call_args
# Generation config is now created dynamically per model
assert 'generation_config' in call_args[1]
assert call_args[1]['safety_settings'] == processor.safety_settings
mock_client.models.generate_content.assert_called_once()
@patch('trustgraph.model.text_completion.vertexai.llm.genai')
@patch('trustgraph.model.text_completion.vertexai.llm.service_account')
@patch('trustgraph.model.text_completion.vertexai.llm.vertexai')
@patch('trustgraph.model.text_completion.vertexai.llm.GenerativeModel')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_rate_limit_error(self, mock_llm_init, mock_async_init, mock_generative_model, mock_vertexai, mock_service_account):
async def test_generate_content_rate_limit_error(self, mock_llm_init, mock_async_init, mock_service_account, mock_genai):
"""Test rate limit error handling"""
# Arrange
from google.api_core.exceptions import ResourceExhausted
from trustgraph.exceptions import TooManyRequests
mock_credentials = MagicMock()
mock_credentials.project_id = "test-project-123"
mock_service_account.Credentials.from_service_account_file.return_value = mock_credentials
mock_model = MagicMock()
mock_model.generate_content.side_effect = ResourceExhausted("Rate limit exceeded")
mock_generative_model.return_value = mock_model
mock_client = MagicMock()
mock_client.models.generate_content.side_effect = ResourceExhausted("Rate limit exceeded")
mock_genai.Client.return_value = mock_client
mock_async_init.return_value = None
mock_llm_init.return_value = None
@ -144,25 +145,26 @@ class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase):
with pytest.raises(TooManyRequests):
await processor.generate_content("System prompt", "User prompt")
@patch('trustgraph.model.text_completion.vertexai.llm.genai')
@patch('trustgraph.model.text_completion.vertexai.llm.service_account')
@patch('trustgraph.model.text_completion.vertexai.llm.vertexai')
@patch('trustgraph.model.text_completion.vertexai.llm.GenerativeModel')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_blocked_response(self, mock_llm_init, mock_async_init, mock_generative_model, mock_vertexai, mock_service_account):
async def test_generate_content_blocked_response(self, mock_llm_init, mock_async_init, mock_service_account, mock_genai):
"""Test handling of blocked content (safety filters)"""
# Arrange
mock_credentials = MagicMock()
mock_credentials.project_id = "test-project-123"
mock_service_account.Credentials.from_service_account_file.return_value = mock_credentials
mock_model = MagicMock()
mock_response = MagicMock()
mock_response.text = None # Blocked content returns None
mock_response.usage_metadata.prompt_token_count = 10
mock_response.usage_metadata.candidates_token_count = 0
mock_model.generate_content.return_value = mock_response
mock_generative_model.return_value = mock_model
mock_client = MagicMock()
mock_client.models.generate_content.return_value = mock_response
mock_genai.Client.return_value = mock_client
mock_async_init.return_value = None
mock_llm_init.return_value = None
@ -190,24 +192,22 @@ class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase):
assert result.model == 'gemini-2.0-flash-001'
@patch('trustgraph.model.text_completion.vertexai.llm.google.auth.default')
@patch('trustgraph.model.text_completion.vertexai.llm.genai')
@patch('trustgraph.model.text_completion.vertexai.llm.service_account')
@patch('trustgraph.model.text_completion.vertexai.llm.vertexai')
@patch('trustgraph.model.text_completion.vertexai.llm.GenerativeModel')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_processor_initialization_without_private_key(self, mock_llm_init, mock_async_init, mock_generative_model, mock_vertexai, mock_service_account, mock_auth_default):
async def test_processor_initialization_without_private_key(self, mock_llm_init, mock_async_init, mock_service_account, mock_genai, mock_auth_default):
"""Test processor initialization without private key (uses default credentials)"""
# Arrange
mock_async_init.return_value = None
mock_llm_init.return_value = None
# Mock google.auth.default() to return credentials and project ID
mock_credentials = MagicMock()
mock_auth_default.return_value = (mock_credentials, "test-project-123")
# Mock GenerativeModel
mock_model = MagicMock()
mock_generative_model.return_value = mock_model
mock_client = MagicMock()
mock_genai.Client.return_value = mock_client
config = {
'region': 'us-central1',
@ -222,30 +222,32 @@ class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase):
# Act
processor = Processor(**config)
# Assert
assert processor.default_model == 'gemini-2.0-flash-001'
mock_auth_default.assert_called_once()
mock_vertexai.init.assert_called_once_with(
location='us-central1',
project='test-project-123'
mock_genai.Client.assert_called_once_with(
vertexai=True,
project="test-project-123",
location="us-central1",
credentials=mock_credentials
)
@patch('trustgraph.model.text_completion.vertexai.llm.genai')
@patch('trustgraph.model.text_completion.vertexai.llm.service_account')
@patch('trustgraph.model.text_completion.vertexai.llm.vertexai')
@patch('trustgraph.model.text_completion.vertexai.llm.GenerativeModel')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_generic_exception(self, mock_llm_init, mock_async_init, mock_generative_model, mock_vertexai, mock_service_account):
async def test_generate_content_generic_exception(self, mock_llm_init, mock_async_init, mock_service_account, mock_genai):
"""Test handling of generic exceptions"""
# Arrange
mock_credentials = MagicMock()
mock_credentials.project_id = "test-project-123"
mock_service_account.Credentials.from_service_account_file.return_value = mock_credentials
mock_model = MagicMock()
mock_model.generate_content.side_effect = Exception("Network error")
mock_generative_model.return_value = mock_model
mock_client = MagicMock()
mock_client.models.generate_content.side_effect = Exception("Network error")
mock_genai.Client.return_value = mock_client
mock_async_init.return_value = None
mock_llm_init.return_value = None
@ -266,19 +268,20 @@ class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase):
with pytest.raises(Exception, match="Network error"):
await processor.generate_content("System prompt", "User prompt")
@patch('trustgraph.model.text_completion.vertexai.llm.genai')
@patch('trustgraph.model.text_completion.vertexai.llm.service_account')
@patch('trustgraph.model.text_completion.vertexai.llm.vertexai')
@patch('trustgraph.model.text_completion.vertexai.llm.GenerativeModel')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_processor_initialization_with_custom_parameters(self, mock_llm_init, mock_async_init, mock_generative_model, mock_vertexai, mock_service_account):
async def test_processor_initialization_with_custom_parameters(self, mock_llm_init, mock_async_init, mock_service_account, mock_genai):
"""Test processor initialization with custom parameters"""
# Arrange
mock_credentials = MagicMock()
mock_credentials.project_id = "test-project-123"
mock_service_account.Credentials.from_service_account_file.return_value = mock_credentials
mock_model = MagicMock()
mock_generative_model.return_value = mock_model
mock_client = MagicMock()
mock_genai.Client.return_value = mock_client
mock_async_init.return_value = None
mock_llm_init.return_value = None
@ -298,37 +301,37 @@ class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase):
# Assert
assert processor.default_model == 'gemini-1.5-pro'
# Verify that generation_config object exists (can't easily check internal values)
assert hasattr(processor, 'generation_configs') # Now a cache dictionary
# Verify that generation_config cache exists
assert hasattr(processor, 'generation_configs')
assert processor.generation_configs == {} # Empty cache initially
# Verify that safety settings are configured
assert len(processor.safety_settings) == 4
# Verify service account was called with custom key
mock_service_account.Credentials.from_service_account_file.assert_called_once_with('custom-key.json')
# Verify that api_params dict has the correct values (this is accessible)
mock_service_account.Credentials.from_service_account_file.assert_called_once()
# Verify that api_params dict has the correct values
assert processor.api_params["temperature"] == 0.7
assert processor.api_params["max_output_tokens"] == 4096
assert processor.api_params["top_p"] == 1.0
assert processor.api_params["top_k"] == 32
@patch('trustgraph.model.text_completion.vertexai.llm.genai')
@patch('trustgraph.model.text_completion.vertexai.llm.service_account')
@patch('trustgraph.model.text_completion.vertexai.llm.vertexai')
@patch('trustgraph.model.text_completion.vertexai.llm.GenerativeModel')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_vertexai_initialization_with_credentials(self, mock_llm_init, mock_async_init, mock_generative_model, mock_vertexai, mock_service_account):
async def test_vertexai_initialization_with_credentials(self, mock_llm_init, mock_async_init, mock_service_account, mock_genai):
"""Test that VertexAI is initialized correctly with credentials"""
# Arrange
mock_credentials = MagicMock()
mock_credentials.project_id = "test-project-123"
mock_service_account.Credentials.from_service_account_file.return_value = mock_credentials
mock_model = MagicMock()
mock_generative_model.return_value = mock_model
mock_client = MagicMock()
mock_genai.Client.return_value = mock_client
mock_async_init.return_value = None
mock_llm_init.return_value = None
@ -347,35 +350,34 @@ class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase):
processor = Processor(**config)
# Assert
# Verify VertexAI init was called with correct parameters
mock_vertexai.init.assert_called_once_with(
# Verify genai.Client was called with correct parameters
mock_genai.Client.assert_called_once_with(
vertexai=True,
project='test-project-123',
location='europe-west1',
credentials=mock_credentials,
project='test-project-123'
credentials=mock_credentials
)
# GenerativeModel is now created lazily on first use, not at initialization
mock_generative_model.assert_not_called()
@patch('trustgraph.model.text_completion.vertexai.llm.genai')
@patch('trustgraph.model.text_completion.vertexai.llm.service_account')
@patch('trustgraph.model.text_completion.vertexai.llm.vertexai')
@patch('trustgraph.model.text_completion.vertexai.llm.GenerativeModel')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_empty_prompts(self, mock_llm_init, mock_async_init, mock_generative_model, mock_vertexai, mock_service_account):
async def test_generate_content_empty_prompts(self, mock_llm_init, mock_async_init, mock_service_account, mock_genai):
"""Test content generation with empty prompts"""
# Arrange
mock_credentials = MagicMock()
mock_credentials.project_id = "test-project-123"
mock_service_account.Credentials.from_service_account_file.return_value = mock_credentials
mock_model = MagicMock()
mock_response = MagicMock()
mock_response.text = "Default response"
mock_response.usage_metadata.prompt_token_count = 2
mock_response.usage_metadata.candidates_token_count = 3
mock_model.generate_content.return_value = mock_response
mock_generative_model.return_value = mock_model
mock_client = MagicMock()
mock_client.models.generate_content.return_value = mock_response
mock_genai.Client.return_value = mock_client
mock_async_init.return_value = None
mock_llm_init.return_value = None
@ -401,27 +403,28 @@ class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase):
assert result.in_token == 2
assert result.out_token == 3
assert result.model == 'gemini-2.0-flash-001'
# Verify the model was called with the combined empty prompts
mock_model.generate_content.assert_called_once()
call_args = mock_model.generate_content.call_args
# The prompt should be "" + "\n\n" + "" = "\n\n"
assert call_args[0][0] == "\n\n"
# Verify the client was called
mock_client.models.generate_content.assert_called_once()
@patch('trustgraph.model.text_completion.vertexai.llm.AnthropicVertex')
@patch('trustgraph.model.text_completion.vertexai.llm.genai')
@patch('trustgraph.model.text_completion.vertexai.llm.service_account')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_anthropic_processor_initialization_with_private_key(self, mock_llm_init, mock_async_init, mock_service_account, mock_anthropic_vertex):
async def test_anthropic_processor_initialization_with_private_key(self, mock_llm_init, mock_async_init, mock_service_account, mock_genai, mock_anthropic_vertex):
"""Test Anthropic processor initialization with private key credentials"""
# Arrange
mock_async_init.return_value = None
mock_llm_init.return_value = None
mock_credentials = MagicMock()
mock_credentials.project_id = "test-project-456"
mock_service_account.Credentials.from_service_account_file.return_value = mock_credentials
mock_client = MagicMock()
mock_genai.Client.return_value = mock_client
# Mock AnthropicVertex
mock_anthropic_client = MagicMock()
mock_anthropic_vertex.return_value = mock_anthropic_client
@ -439,45 +442,45 @@ class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase):
# Act
processor = Processor(**config)
# Assert
assert processor.default_model == 'claude-3-sonnet@20240229'
# is_anthropic logic is now determined dynamically per request
# Verify service account was called with private key
mock_service_account.Credentials.from_service_account_file.assert_called_once_with('anthropic-key.json')
# Verify AnthropicVertex was initialized with credentials
mock_service_account.Credentials.from_service_account_file.assert_called_once()
# Verify AnthropicVertex was initialized with credentials (because model contains 'claude')
mock_anthropic_vertex.assert_called_once_with(
region='us-west1',
project_id='test-project-456',
credentials=mock_credentials
)
# Verify api_params are set correctly
assert processor.api_params["temperature"] == 0.5
assert processor.api_params["max_output_tokens"] == 2048
assert processor.api_params["top_p"] == 1.0
assert processor.api_params["top_k"] == 32
@patch('trustgraph.model.text_completion.vertexai.llm.genai')
@patch('trustgraph.model.text_completion.vertexai.llm.service_account')
@patch('trustgraph.model.text_completion.vertexai.llm.vertexai')
@patch('trustgraph.model.text_completion.vertexai.llm.GenerativeModel')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_temperature_override(self, mock_llm_init, mock_async_init, mock_generative_model, mock_vertexai, mock_service_account):
async def test_generate_content_temperature_override(self, mock_llm_init, mock_async_init, mock_service_account, mock_genai):
"""Test temperature parameter override functionality"""
# Arrange
mock_credentials = MagicMock()
mock_credentials.project_id = "test-project-123"
mock_service_account.Credentials.from_service_account_file.return_value = mock_credentials
mock_model = MagicMock()
mock_response = MagicMock()
mock_response.text = "Response with custom temperature"
mock_response.usage_metadata.prompt_token_count = 20
mock_response.usage_metadata.candidates_token_count = 12
mock_model.generate_content.return_value = mock_response
mock_generative_model.return_value = mock_model
mock_client = MagicMock()
mock_client.models.generate_content.return_value = mock_response
mock_genai.Client.return_value = mock_client
mock_async_init.return_value = None
mock_llm_init.return_value = None
@ -506,42 +509,27 @@ class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase):
# Assert
assert isinstance(result, LlmResult)
assert result.text == "Response with custom temperature"
mock_client.models.generate_content.assert_called_once()
# Verify Gemini API was called with overridden temperature
mock_model.generate_content.assert_called_once()
call_args = mock_model.generate_content.call_args
# Check that generation_config was created (we can't directly access temperature from mock)
generation_config = call_args.kwargs['generation_config']
assert generation_config is not None # Should use overridden temperature configuration
@patch('trustgraph.model.text_completion.vertexai.llm.genai')
@patch('trustgraph.model.text_completion.vertexai.llm.service_account')
@patch('trustgraph.model.text_completion.vertexai.llm.vertexai')
@patch('trustgraph.model.text_completion.vertexai.llm.GenerativeModel')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_model_override(self, mock_llm_init, mock_async_init, mock_generative_model, mock_vertexai, mock_service_account):
async def test_generate_content_model_override(self, mock_llm_init, mock_async_init, mock_service_account, mock_genai):
"""Test model parameter override functionality"""
# Arrange
mock_credentials = MagicMock()
mock_credentials.project_id = "test-project-123"
mock_service_account.Credentials.from_service_account_file.return_value = mock_credentials
# Mock different models
mock_model_default = MagicMock()
mock_model_override = MagicMock()
mock_response = MagicMock()
mock_response.text = "Response with custom model"
mock_response.usage_metadata.prompt_token_count = 18
mock_response.usage_metadata.candidates_token_count = 14
mock_model_override.generate_content.return_value = mock_response
# GenerativeModel should return different models based on input
def model_factory(model_name):
if model_name == 'gemini-1.5-pro':
return mock_model_override
return mock_model_default
mock_generative_model.side_effect = model_factory
mock_client = MagicMock()
mock_client.models.generate_content.return_value = mock_response
mock_genai.Client.return_value = mock_client
mock_async_init.return_value = None
mock_llm_init.return_value = None
@ -549,7 +537,7 @@ class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase):
config = {
'region': 'us-central1',
'model': 'gemini-2.0-flash-001', # Default model
'temperature': 0.2, # Default temperature
'temperature': 0.2,
'max_output': 8192,
'private_key': 'private.json',
'concurrency': 1,
@ -571,29 +559,29 @@ class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase):
assert isinstance(result, LlmResult)
assert result.text == "Response with custom model"
# Verify the overridden model was used
mock_model_override.generate_content.assert_called_once()
# Verify GenerativeModel was called with the override model
mock_generative_model.assert_called_with('gemini-1.5-pro')
# Verify the call was made with the override model
call_args = mock_client.models.generate_content.call_args
assert call_args.kwargs['model'] == "gemini-1.5-pro"
@patch('trustgraph.model.text_completion.vertexai.llm.genai')
@patch('trustgraph.model.text_completion.vertexai.llm.service_account')
@patch('trustgraph.model.text_completion.vertexai.llm.vertexai')
@patch('trustgraph.model.text_completion.vertexai.llm.GenerativeModel')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_generate_content_both_parameters_override(self, mock_llm_init, mock_async_init, mock_generative_model, mock_vertexai, mock_service_account):
async def test_generate_content_both_parameters_override(self, mock_llm_init, mock_async_init, mock_service_account, mock_genai):
"""Test overriding both model and temperature parameters simultaneously"""
# Arrange
mock_credentials = MagicMock()
mock_credentials.project_id = "test-project-123"
mock_service_account.Credentials.from_service_account_file.return_value = mock_credentials
mock_model = MagicMock()
mock_response = MagicMock()
mock_response.text = "Response with both overrides"
mock_response.usage_metadata.prompt_token_count = 22
mock_response.usage_metadata.candidates_token_count = 16
mock_model.generate_content.return_value = mock_response
mock_generative_model.return_value = mock_model
mock_client = MagicMock()
mock_client.models.generate_content.return_value = mock_response
mock_genai.Client.return_value = mock_client
mock_async_init.return_value = None
mock_llm_init.return_value = None
@ -622,18 +610,12 @@ class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase):
# Assert
assert isinstance(result, LlmResult)
assert result.text == "Response with both overrides"
mock_client.models.generate_content.assert_called_once()
# Verify both overrides were used
mock_model.generate_content.assert_called_once()
call_args = mock_model.generate_content.call_args
# Verify model override
mock_generative_model.assert_called_with('gemini-1.5-flash-001') # Should use runtime override
# Verify temperature override (we can't directly access temperature from mock)
generation_config = call_args.kwargs['generation_config']
assert generation_config is not None # Should use overridden temperature configuration
# Verify the model override was used
call_args = mock_client.models.generate_content.call_args
assert call_args.kwargs['model'] == "gemini-1.5-flash-001"
if __name__ == '__main__':
pytest.main([__file__])
pytest.main([__file__])

View file

@ -20,6 +20,7 @@ dependencies = [
"falkordb",
"fastembed",
"google-genai",
"google-api-core",
"ibis",
"jsonschema",
"langchain",

View file

@ -12,7 +12,8 @@ requires-python = ">=3.8"
dependencies = [
"trustgraph-base>=2.0,<2.1",
"pulsar-client",
"google-cloud-aiplatform",
"google-genai",
"google-api-core",
"prometheus-client",
"anthropic",
]

View file

@ -4,29 +4,19 @@ Google Cloud. Input is prompt, output is response.
Supports both Google's Gemini models and Anthropic's Claude models.
"""
#
# Somewhat perplexed by the Google Cloud SDK choices. We're going off this
# one, which uses the google-cloud-aiplatform library:
# https://cloud.google.com/python/docs/reference/vertexai/1.94.0
# It seems it is possible to invoke VertexAI from the google-genai
# SDK too:
# https://googleapis.github.io/python-genai/genai.html#module-genai.client
# That would make this code look very much like the GoogleAIStudio
# code. And maybe not reliant on the google-cloud-aiplatform library?
#
# This module's imports bring in a lot of libraries.
# Uses the google-genai SDK for Gemini models on Vertex AI:
# https://googleapis.github.io/python-genai/genai.html#module-genai.client
#
from google.oauth2 import service_account
import google.auth
import google.api_core.exceptions
import vertexai
import logging
# Why is preview here?
from vertexai.generative_models import (
Content, FunctionDeclaration, GenerativeModel, GenerationConfig,
HarmCategory, HarmBlockThreshold, Part, Tool, SafetySetting,
)
from google import genai
from google.genai import types
from google.genai.types import HarmCategory, HarmBlockThreshold
from google.api_core.exceptions import ResourceExhausted
# Added for Anthropic model support
from anthropic import AnthropicVertex, RateLimitError
@ -67,12 +57,10 @@ class Processor(LlmService):
self.max_output = max_output
self.private_key = private_key
# Model client caches
self.model_clients = {} # Cache for model instances
self.generation_configs = {} # Cache for generation configs (Gemini only)
self.anthropic_client = None # Single Anthropic client (handles multiple models)
# Anthropic client (handles Claude models)
self.anthropic_client = None
# Shared parameters for both model types
# Shared parameters for Anthropic models
self.api_params = {
"temperature": temperature,
"top_p": 1.0,
@ -84,10 +72,10 @@ class Processor(LlmService):
# Unified credential and project ID loading
if private_key:
credentials = (
service_account.Credentials.from_service_account_file(
private_key
)
scopes = ["https://www.googleapis.com/auth/cloud-platform"]
credentials = service_account.Credentials.from_service_account_file(
private_key,
scopes=scopes
)
project_id = credentials.project_id
else:
@ -103,12 +91,13 @@ class Processor(LlmService):
self.credentials = credentials
self.project_id = project_id
# Initialize Vertex AI SDK for Gemini models
init_kwargs = {'location': region, 'project': project_id}
if credentials and private_key: # Pass credentials only if from a file
init_kwargs['credentials'] = credentials
vertexai.init(**init_kwargs)
# Initialize Google GenAI client for Gemini models
self.client = genai.Client(
vertexai=True,
project=project_id,
location=region,
credentials=credentials
)
# Pre-initialize Anthropic client if needed (single client handles all Claude models)
if 'claude' in self.default_model.lower():
@ -117,24 +106,27 @@ class Processor(LlmService):
# Safety settings for Gemini models
block_level = HarmBlockThreshold.BLOCK_ONLY_HIGH
self.safety_settings = [
SafetySetting(
category = HarmCategory.HARM_CATEGORY_HARASSMENT,
threshold = block_level,
types.SafetySetting(
category=HarmCategory.HARM_CATEGORY_HATE_SPEECH,
threshold=block_level,
),
SafetySetting(
category = HarmCategory.HARM_CATEGORY_HATE_SPEECH,
threshold = block_level,
types.SafetySetting(
category=HarmCategory.HARM_CATEGORY_HARASSMENT,
threshold=block_level,
),
SafetySetting(
category = HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
threshold = block_level,
types.SafetySetting(
category=HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
threshold=block_level,
),
SafetySetting(
category = HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
threshold = block_level,
types.SafetySetting(
category=HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
threshold=block_level,
),
]
# Cache for generation configs
self.generation_configs = {}
logger.info("VertexAI initialization complete")
def _get_anthropic_client(self):
@ -152,25 +144,26 @@ class Processor(LlmService):
return self.anthropic_client
def _get_gemini_model(self, model_name, temperature=None):
"""Get or create a Gemini model instance"""
if model_name not in self.model_clients:
logger.info(f"Creating GenerativeModel instance for '{model_name}'")
self.model_clients[model_name] = GenerativeModel(model_name)
def _get_or_create_config(self, model_name, temperature=None):
"""Get or create generation config with dynamic temperature"""
# Use provided temperature or fall back to default
effective_temperature = temperature if temperature is not None else self.temperature
# Create generation config with the effective temperature
generation_config = GenerationConfig(
temperature=effective_temperature,
top_p=1.0,
top_k=10,
candidate_count=1,
max_output_tokens=self.max_output,
)
# Create cache key that includes temperature to avoid conflicts
cache_key = f"{model_name}:{effective_temperature}"
return self.model_clients[model_name], generation_config
if cache_key not in self.generation_configs:
logger.info(f"Creating generation config for '{model_name}' with temperature {effective_temperature}")
self.generation_configs[cache_key] = types.GenerateContentConfig(
temperature=effective_temperature,
top_p=1.0,
top_k=40,
max_output_tokens=self.max_output,
response_mime_type="text/plain",
safety_settings=self.safety_settings,
)
return self.generation_configs[cache_key]
async def generate_content(self, system, prompt, model=None, temperature=None):
@ -205,22 +198,24 @@ class Processor(LlmService):
model=model_name
)
else:
# Gemini API combines system and user prompts
# Gemini API using google-genai SDK
logger.debug(f"Sending request to Gemini model '{model_name}'...")
full_prompt = system + "\n\n" + prompt
llm, generation_config = self._get_gemini_model(model_name, effective_temperature)
generation_config = self._get_or_create_config(model_name, effective_temperature)
# Set system instruction per request (can't be cached)
generation_config.system_instruction = system
response = llm.generate_content(
full_prompt, generation_config = generation_config,
safety_settings = self.safety_settings,
response = self.client.models.generate_content(
model=model_name,
config=generation_config,
contents=prompt,
)
resp = LlmResult(
text = response.text,
in_token = response.usage_metadata.prompt_token_count,
out_token = response.usage_metadata.candidates_token_count,
model = model_name
text=response.text,
in_token=int(response.usage_metadata.prompt_token_count),
out_token=int(response.usage_metadata.candidates_token_count),
model=model_name
)
logger.info(f"Input Tokens: {resp.in_token}")
@ -229,7 +224,7 @@ class Processor(LlmService):
return resp
except (google.api_core.exceptions.ResourceExhausted, RateLimitError) as e:
except (ResourceExhausted, RateLimitError) as e:
logger.warning(f"Hit rate limit: {e}")
# Leave rate limit retries to the base handler
raise TooManyRequests()
@ -302,17 +297,16 @@ class Processor(LlmService):
logger.info(f"Output Tokens: {total_out_tokens}")
else:
# Gemini streaming
# Gemini streaming using google-genai SDK
logger.debug(f"Streaming request to Gemini model '{model_name}'...")
full_prompt = system + "\n\n" + prompt
llm, generation_config = self._get_gemini_model(model_name, effective_temperature)
generation_config = self._get_or_create_config(model_name, effective_temperature)
generation_config.system_instruction = system
response = llm.generate_content(
full_prompt,
generation_config=generation_config,
safety_settings=self.safety_settings,
stream=True # Enable streaming
response = self.client.models.generate_content_stream(
model=model_name,
config=generation_config,
contents=prompt,
)
total_in_tokens = 0
@ -348,7 +342,7 @@ class Processor(LlmService):
logger.info(f"Input Tokens: {total_in_tokens}")
logger.info(f"Output Tokens: {total_out_tokens}")
except (google.api_core.exceptions.ResourceExhausted, RateLimitError) as e:
except (ResourceExhausted, RateLimitError) as e:
logger.warning(f"Hit rate limit during streaming: {e}")
raise TooManyRequests()