diff --git a/tests/unit/test_text_completion/test_vertexai_processor.py b/tests/unit/test_text_completion/test_vertexai_processor.py
index f7fcab73..3910a30c 100644
--- a/tests/unit/test_text_completion/test_vertexai_processor.py
+++ b/tests/unit/test_text_completion/test_vertexai_processor.py
@@ -188,16 +188,25 @@ class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase):
         assert result.out_token == 0
         assert result.model == 'gemini-2.0-flash-001'
 
+    @patch('trustgraph.model.text_completion.vertexai.llm.google.auth.default')
     @patch('trustgraph.model.text_completion.vertexai.llm.service_account')
     @patch('trustgraph.model.text_completion.vertexai.llm.vertexai')
     @patch('trustgraph.model.text_completion.vertexai.llm.GenerativeModel')
     @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
     @patch('trustgraph.base.llm_service.LlmService.__init__')
-    async def test_processor_initialization_without_private_key(self, mock_llm_init, mock_async_init, mock_generative_model, mock_vertexai, mock_service_account):
-        """Test processor initialization without private key (should fail)"""
+    async def test_processor_initialization_without_private_key(self, mock_llm_init, mock_async_init, mock_generative_model, mock_vertexai, mock_service_account, mock_auth_default):
+        """Test processor initialization without private key (uses default credentials)"""
         # Arrange
         mock_async_init.return_value = None
         mock_llm_init.return_value = None
+
+        # Mock google.auth.default() to return credentials and project ID
+        mock_credentials = MagicMock()
+        mock_auth_default.return_value = (mock_credentials, "test-project-123")
+
+        # Mock GenerativeModel
+        mock_model = MagicMock()
+        mock_generative_model.return_value = mock_model
 
         config = {
             'region': 'us-central1',
@@ -210,9 +219,16 @@ class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase):
             'id': 'test-processor'
         }
 
-        # Act & Assert
-        with pytest.raises(RuntimeError, match="Private key file not specified"):
-            processor = Processor(**config)
+        # Act
+        processor = Processor(**config)
+
+        # Assert
+        assert processor.model == 'gemini-2.0-flash-001'
+        mock_auth_default.assert_called_once()
+        mock_vertexai.init.assert_called_once_with(
+            location='us-central1',
+            project='test-project-123'
+        )
 
     @patch('trustgraph.model.text_completion.vertexai.llm.service_account')
     @patch('trustgraph.model.text_completion.vertexai.llm.vertexai')
@@ -292,12 +308,11 @@ class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase):
         # Verify service account was called with custom key
         mock_service_account.Credentials.from_service_account_file.assert_called_once_with('custom-key.json')
 
-        # Verify that parameters dict has the correct values (this is accessible)
-        assert processor.parameters["temperature"] == 0.7
-        assert processor.parameters["max_output_tokens"] == 4096
-        assert processor.parameters["top_p"] == 1.0
-        assert processor.parameters["top_k"] == 32
-        assert processor.parameters["candidate_count"] == 1
+        # Verify that api_params dict has the correct values (this is accessible)
+        assert processor.api_params["temperature"] == 0.7
+        assert processor.api_params["max_output_tokens"] == 4096
+        assert processor.api_params["top_p"] == 1.0
+        assert processor.api_params["top_k"] == 32
 
     @patch('trustgraph.model.text_completion.vertexai.llm.service_account')
     @patch('trustgraph.model.text_completion.vertexai.llm.vertexai')
@@ -392,6 +407,58 @@ class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase):
         # The prompt should be "" + "\n\n" + "" = "\n\n"
         assert call_args[0][0] == "\n\n"
 
+    @patch('trustgraph.model.text_completion.vertexai.llm.AnthropicVertex')
+    @patch('trustgraph.model.text_completion.vertexai.llm.service_account')
+    @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
+    @patch('trustgraph.base.llm_service.LlmService.__init__')
+    async def test_anthropic_processor_initialization_with_private_key(self, mock_llm_init, mock_async_init, mock_service_account, mock_anthropic_vertex):
+        """Test Anthropic processor initialization with private key credentials"""
+        # Arrange
+        mock_async_init.return_value = None
+        mock_llm_init.return_value = None
+
+        mock_credentials = MagicMock()
+        mock_credentials.project_id = "test-project-456"
+        mock_service_account.Credentials.from_service_account_file.return_value = mock_credentials
+
+        # Mock AnthropicVertex
+        mock_anthropic_client = MagicMock()
+        mock_anthropic_vertex.return_value = mock_anthropic_client
+
+        config = {
+            'region': 'us-west1',
+            'model': 'claude-3-sonnet@20240229',  # Anthropic model
+            'temperature': 0.5,
+            'max_output': 2048,
+            'private_key': 'anthropic-key.json',
+            'concurrency': 1,
+            'taskgroup': AsyncMock(),
+            'id': 'test-anthropic-processor'
+        }
+
+        # Act
+        processor = Processor(**config)
+
+        # Assert
+        assert processor.model == 'claude-3-sonnet@20240229'
+        assert processor.is_anthropic == True
+
+        # Verify service account was called with private key
+        mock_service_account.Credentials.from_service_account_file.assert_called_once_with('anthropic-key.json')
+
+        # Verify AnthropicVertex was initialized with credentials
+        mock_anthropic_vertex.assert_called_once_with(
+            region='us-west1',
+            project_id='test-project-456',
+            credentials=mock_credentials
+        )
+
+        # Verify api_params are set correctly
+        assert processor.api_params["temperature"] == 0.5
+        assert processor.api_params["max_output_tokens"] == 2048
+        assert processor.api_params["top_p"] == 1.0
+        assert processor.api_params["top_k"] == 32
+
 
 if __name__ == '__main__':
     pytest.main([__file__])
\ No newline at end of file
diff --git a/trustgraph-vertexai/trustgraph/model/text_completion/vertexai/llm.py b/trustgraph-vertexai/trustgraph/model/text_completion/vertexai/llm.py
index 24cc576c..a1ab4717 100755
--- a/trustgraph-vertexai/trustgraph/model/text_completion/vertexai/llm.py
+++ b/trustgraph-vertexai/trustgraph/model/text_completion/vertexai/llm.py
@@ -1,7 +1,7 @@
-
 """
 Simple LLM service, performs text prompt completion using VertexAI on
 Google Cloud. Input is prompt, output is response.
+Supports both Google's Gemini models and Anthropic's Claude models.
 """
 
 #
@@ -17,7 +17,7 @@ Google Cloud. Input is prompt, output is response.
 
 # This module's imports bring in a lot of libraries.
 from google.oauth2 import service_account
-import google
+import google.auth
 import vertexai
 
 import logging
@@ -27,6 +27,9 @@ from vertexai.generative_models import (
     HarmCategory, HarmBlockThreshold, Part, Tool, SafetySetting,
 )
 
+# Added for Anthropic model support
+from anthropic import AnthropicVertex, RateLimitError
+
 from .... exceptions import TooManyRequests
 from .... base import LlmService, LlmResult
 
@@ -35,7 +38,7 @@ logger = logging.getLogger(__name__)
 
 default_ident = "text-completion"
 
 default_model = 'gemini-2.0-flash-001'
 default_region = 'us-central1'
 default_temperature = 0.0
 default_max_output = 8192
@@ -52,111 +55,148 @@ class Processor(LlmService):
         max_output = params.get("max_output", default_max_output)
 
         if private_key is None:
-            raise RuntimeError("Private key file not specified")
+            logger.warning("Private key file not specified, using Application Default Credentials")
 
         super(Processor, self).__init__(**params)
 
-        self.parameters = {
+        self.model = model
+        self.is_anthropic = 'claude' in self.model.lower()
+
+        # Shared parameters for both model types
+        self.api_params = {
             "temperature": temperature,
             "top_p": 1.0,
             "top_k": 32,
-            "candidate_count": 1,
             "max_output_tokens": max_output,
         }
 
-        self.generation_config = GenerationConfig(
-            temperature=temperature,
-            top_p=1.0,
-            top_k=10,
-            candidate_count=1,
-            max_output_tokens=max_output,
-        )
-
-        # Block none doesn't seem to work
-        block_level = HarmBlockThreshold.BLOCK_ONLY_HIGH
-        # block_level = HarmBlockThreshold.BLOCK_NONE
-
-        self.safety_settings = [
-            SafetySetting(
-                category = HarmCategory.HARM_CATEGORY_HARASSMENT,
-                threshold = block_level,
-            ),
-            SafetySetting(
-                category = HarmCategory.HARM_CATEGORY_HATE_SPEECH,
-                threshold = block_level,
-            ),
-            SafetySetting(
-                category = HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
-                threshold = block_level,
-            ),
-            SafetySetting(
-                category = HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
-                threshold = block_level,
-            ),
-        ]
-
         logger.info("Initializing VertexAI...")
 
+        # Unified credential and project ID loading
         if private_key:
             credentials = (
                 service_account.Credentials.from_service_account_file(
                     private_key
                 )
             )
+            project_id = credentials.project_id
         else:
-            credentials = None
+            credentials, project_id = google.auth.default()
 
-        if credentials:
-            vertexai.init(
-                location=region,
-                credentials=credentials,
-                project=credentials.project_id,
-            )
-        else:
-            vertexai.init(
-                location=region
+        if not project_id:
+            raise RuntimeError(
+                "Could not determine Google Cloud project ID. "
+                "Ensure it's set in your environment or service account."
             )
 
-        logger.info(f"Initializing model {model}")
-        self.llm = GenerativeModel(model)
-        self.model = model
+        # Initialize the appropriate client based on the model type
+        if self.is_anthropic:
+            logger.info(f"Initializing Anthropic model '{model}' via AnthropicVertex SDK")
+            # Initialize AnthropicVertex with credentials if provided, otherwise use ADC
+            anthropic_kwargs = {'region': region, 'project_id': project_id}
+            if credentials and private_key: # Pass credentials only if from a file
+                anthropic_kwargs['credentials'] = credentials
+                logger.debug("Using service account credentials for Anthropic model")
+            else:
+                logger.debug("Using Application Default Credentials for Anthropic model")
+
+            self.llm = AnthropicVertex(**anthropic_kwargs)
+        else:
+            # For Gemini models, initialize the Vertex AI SDK
+            logger.info(f"Initializing Google model '{model}' via Vertex AI SDK")
+            init_kwargs = {'location': region, 'project': project_id}
+            if credentials and private_key: # Pass credentials only if from a file
+                init_kwargs['credentials'] = credentials
+
+            vertexai.init(**init_kwargs)
+
+            self.llm = GenerativeModel(model)
+
+            self.generation_config = GenerationConfig(
+                temperature=temperature,
+                top_p=1.0,
+                top_k=10,
+                candidate_count=1,
+                max_output_tokens=max_output,
+            )
+
+            # Block none doesn't seem to work
+            block_level = HarmBlockThreshold.BLOCK_ONLY_HIGH
+            # block_level = HarmBlockThreshold.BLOCK_NONE
+
+            self.safety_settings = [
+                SafetySetting(
+                    category = HarmCategory.HARM_CATEGORY_HARASSMENT,
+                    threshold = block_level,
+                ),
+                SafetySetting(
+                    category = HarmCategory.HARM_CATEGORY_HATE_SPEECH,
+                    threshold = block_level,
+                ),
+                SafetySetting(
+                    category = HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
+                    threshold = block_level,
+                ),
+                SafetySetting(
+                    category = HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
+                    threshold = block_level,
+                ),
+            ]
+
+        logger.info("VertexAI initialization complete")
 
     async def generate_content(self, system, prompt):
 
         try:
+            if self.is_anthropic:
+                # Anthropic API uses a dedicated system prompt
+                logger.debug("Sending request to Anthropic model...")
+                response = self.llm.messages.create(
+                    model=self.model,
+                    system=system,
+                    messages=[{"role": "user", "content": prompt}],
+                    max_tokens=self.api_params['max_output_tokens'],
+                    temperature=self.api_params['temperature'],
+                    top_p=self.api_params['top_p'],
+                    top_k=self.api_params['top_k'],
+                )
 
-            prompt = system + "\n\n" + prompt
+                resp = LlmResult(
+                    text=response.content[0].text,
+                    in_token=response.usage.input_tokens,
+                    out_token=response.usage.output_tokens,
+                    model=self.model
+                )
+            else:
+                # Gemini API combines system and user prompts
+                logger.debug("Sending request to Gemini model...")
+                full_prompt = system + "\n\n" + prompt
 
-            response = self.llm.generate_content(
-                prompt, generation_config = self.generation_config,
-                safety_settings = self.safety_settings,
-            )
+                response = self.llm.generate_content(
+                    full_prompt, generation_config = self.generation_config,
+                    safety_settings = self.safety_settings,
+                )
 
-            resp = LlmResult(
-                text = response.text,
-                in_token = response.usage_metadata.prompt_token_count,
-                out_token = response.usage_metadata.candidates_token_count,
-                model = self.model
-            )
+                resp = LlmResult(
+                    text = response.text,
+                    in_token = response.usage_metadata.prompt_token_count,
+                    out_token = response.usage_metadata.candidates_token_count,
+                    model = self.model
+                )
 
             logger.info(f"Input Tokens: {resp.in_token}")
             logger.info(f"Output Tokens: {resp.out_token}")
-            logger.debug("Send response...")
             return resp
 
-        except google.api_core.exceptions.ResourceExhausted as e:
-
+        except (google.api_core.exceptions.ResourceExhausted, RateLimitError) as e:
             logger.warning(f"Hit rate limit: {e}")
-
             # Leave rate limit retries to the base handler
             raise TooManyRequests()
 
         except Exception as e:
-
             # Apart from rate limits, treat all exceptions as unrecoverable
             logger.error(f"VertexAI LLM exception: {e}", exc_info=True)
             raise e
@@ -169,12 +209,12 @@ class Processor(LlmService):
         parser.add_argument(
             '-m', '--model',
             default=default_model,
-            help=f'LLM model (default: {default_model})'
+            help=f'LLM model (e.g., gemini-2.0-flash-001, claude-3-sonnet@20240229) (default: {default_model})'
         )
 
         parser.add_argument(
             '-k', '--private-key',
-            help=f'Google Cloud private JSON file'
+            help=f'Google Cloud private JSON file (optional, uses ADC if not provided)'
         )
 
         parser.add_argument(
@@ -198,5 +238,4 @@ class Processor(LlmService):
         )
 
 def run():
-    Processor.launch(default_ident, __doc__)
-
+    Processor.launch(default_ident, __doc__)