Features/vertex anthropic (#458)

* Added Anthropic support for VertexAI

* Update tests to match code

* Fixed private.json usage with Anthropic (I think).

* Fixed test

---------

Co-authored-by: Cyber MacGeddon <cybermaggedon@gmail.com>
This commit is contained in:
Jack Colquitt 2025-08-19 13:00:22 -07:00 committed by GitHub
parent e89a5b5d23
commit 244da4aec1
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 187 additions and 81 deletions

View file

@ -188,16 +188,25 @@ class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase):
assert result.out_token == 0 assert result.out_token == 0
assert result.model == 'gemini-2.0-flash-001' assert result.model == 'gemini-2.0-flash-001'
@patch('trustgraph.model.text_completion.vertexai.llm.google.auth.default')
@patch('trustgraph.model.text_completion.vertexai.llm.service_account') @patch('trustgraph.model.text_completion.vertexai.llm.service_account')
@patch('trustgraph.model.text_completion.vertexai.llm.vertexai') @patch('trustgraph.model.text_completion.vertexai.llm.vertexai')
@patch('trustgraph.model.text_completion.vertexai.llm.GenerativeModel') @patch('trustgraph.model.text_completion.vertexai.llm.GenerativeModel')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__') @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__') @patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_processor_initialization_without_private_key(self, mock_llm_init, mock_async_init, mock_generative_model, mock_vertexai, mock_service_account): async def test_processor_initialization_without_private_key(self, mock_llm_init, mock_async_init, mock_generative_model, mock_vertexai, mock_service_account, mock_auth_default):
"""Test processor initialization without private key (should fail)""" """Test processor initialization without private key (uses default credentials)"""
# Arrange # Arrange
mock_async_init.return_value = None mock_async_init.return_value = None
mock_llm_init.return_value = None mock_llm_init.return_value = None
# Mock google.auth.default() to return credentials and project ID
mock_credentials = MagicMock()
mock_auth_default.return_value = (mock_credentials, "test-project-123")
# Mock GenerativeModel
mock_model = MagicMock()
mock_generative_model.return_value = mock_model
config = { config = {
'region': 'us-central1', 'region': 'us-central1',
@ -210,9 +219,16 @@ class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase):
'id': 'test-processor' 'id': 'test-processor'
} }
# Act & Assert # Act
with pytest.raises(RuntimeError, match="Private key file not specified"): processor = Processor(**config)
processor = Processor(**config)
# Assert
assert processor.model == 'gemini-2.0-flash-001'
mock_auth_default.assert_called_once()
mock_vertexai.init.assert_called_once_with(
location='us-central1',
project='test-project-123'
)
@patch('trustgraph.model.text_completion.vertexai.llm.service_account') @patch('trustgraph.model.text_completion.vertexai.llm.service_account')
@patch('trustgraph.model.text_completion.vertexai.llm.vertexai') @patch('trustgraph.model.text_completion.vertexai.llm.vertexai')
@ -292,12 +308,11 @@ class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase):
# Verify service account was called with custom key # Verify service account was called with custom key
mock_service_account.Credentials.from_service_account_file.assert_called_once_with('custom-key.json') mock_service_account.Credentials.from_service_account_file.assert_called_once_with('custom-key.json')
# Verify that parameters dict has the correct values (this is accessible) # Verify that api_params dict has the correct values (this is accessible)
assert processor.parameters["temperature"] == 0.7 assert processor.api_params["temperature"] == 0.7
assert processor.parameters["max_output_tokens"] == 4096 assert processor.api_params["max_output_tokens"] == 4096
assert processor.parameters["top_p"] == 1.0 assert processor.api_params["top_p"] == 1.0
assert processor.parameters["top_k"] == 32 assert processor.api_params["top_k"] == 32
assert processor.parameters["candidate_count"] == 1
@patch('trustgraph.model.text_completion.vertexai.llm.service_account') @patch('trustgraph.model.text_completion.vertexai.llm.service_account')
@patch('trustgraph.model.text_completion.vertexai.llm.vertexai') @patch('trustgraph.model.text_completion.vertexai.llm.vertexai')
@ -392,6 +407,58 @@ class TestVertexAIProcessorSimple(IsolatedAsyncioTestCase):
# The prompt should be "" + "\n\n" + "" = "\n\n" # The prompt should be "" + "\n\n" + "" = "\n\n"
assert call_args[0][0] == "\n\n" assert call_args[0][0] == "\n\n"
@patch('trustgraph.model.text_completion.vertexai.llm.AnthropicVertex')
@patch('trustgraph.model.text_completion.vertexai.llm.service_account')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_anthropic_processor_initialization_with_private_key(self, mock_llm_init, mock_async_init, mock_service_account, mock_anthropic_vertex):
"""Test Anthropic processor initialization with private key credentials"""
# Arrange
mock_async_init.return_value = None
mock_llm_init.return_value = None
mock_credentials = MagicMock()
mock_credentials.project_id = "test-project-456"
mock_service_account.Credentials.from_service_account_file.return_value = mock_credentials
# Mock AnthropicVertex
mock_anthropic_client = MagicMock()
mock_anthropic_vertex.return_value = mock_anthropic_client
config = {
'region': 'us-west1',
'model': 'claude-3-sonnet@20240229', # Anthropic model
'temperature': 0.5,
'max_output': 2048,
'private_key': 'anthropic-key.json',
'concurrency': 1,
'taskgroup': AsyncMock(),
'id': 'test-anthropic-processor'
}
# Act
processor = Processor(**config)
# Assert
assert processor.model == 'claude-3-sonnet@20240229'
assert processor.is_anthropic == True
# Verify service account was called with private key
mock_service_account.Credentials.from_service_account_file.assert_called_once_with('anthropic-key.json')
# Verify AnthropicVertex was initialized with credentials
mock_anthropic_vertex.assert_called_once_with(
region='us-west1',
project_id='test-project-456',
credentials=mock_credentials
)
# Verify api_params are set correctly
assert processor.api_params["temperature"] == 0.5
assert processor.api_params["max_output_tokens"] == 2048
assert processor.api_params["top_p"] == 1.0
assert processor.api_params["top_k"] == 32
if __name__ == '__main__': if __name__ == '__main__':
pytest.main([__file__]) pytest.main([__file__])

View file

@ -1,7 +1,7 @@
""" """
Simple LLM service, performs text prompt completion using VertexAI on Simple LLM service, performs text prompt completion using VertexAI on
Google Cloud. Input is prompt, output is response. Google Cloud. Input is prompt, output is response.
Supports both Google's Gemini models and Anthropic's Claude models.
""" """
# #
@ -17,7 +17,7 @@ Google Cloud. Input is prompt, output is response.
# This module's imports bring in a lot of libraries. # This module's imports bring in a lot of libraries.
from google.oauth2 import service_account from google.oauth2 import service_account
import google import google.auth
import vertexai import vertexai
import logging import logging
@ -27,6 +27,9 @@ from vertexai.generative_models import (
HarmCategory, HarmBlockThreshold, Part, Tool, SafetySetting, HarmCategory, HarmBlockThreshold, Part, Tool, SafetySetting,
) )
# Added for Anthropic model support
from anthropic import AnthropicVertex, RateLimitError
from .... exceptions import TooManyRequests from .... exceptions import TooManyRequests
from .... base import LlmService, LlmResult from .... base import LlmService, LlmResult
@ -35,7 +38,7 @@ logger = logging.getLogger(__name__)
default_ident = "text-completion" default_ident = "text-completion"
default_model = 'gemini-2.0-flash-001' default_model = 'gemini-1.5-flash-001'
default_region = 'us-central1' default_region = 'us-central1'
default_temperature = 0.0 default_temperature = 0.0
default_max_output = 8192 default_max_output = 8192
@ -52,111 +55,148 @@ class Processor(LlmService):
max_output = params.get("max_output", default_max_output) max_output = params.get("max_output", default_max_output)
if private_key is None: if private_key is None:
raise RuntimeError("Private key file not specified") logger.warning("Private key file not specified, using Application Default Credentials")
super(Processor, self).__init__(**params) super(Processor, self).__init__(**params)
self.parameters = { self.model = model
self.is_anthropic = 'claude' in self.model.lower()
# Shared parameters for both model types
self.api_params = {
"temperature": temperature, "temperature": temperature,
"top_p": 1.0, "top_p": 1.0,
"top_k": 32, "top_k": 32,
"candidate_count": 1,
"max_output_tokens": max_output, "max_output_tokens": max_output,
} }
self.generation_config = GenerationConfig(
temperature=temperature,
top_p=1.0,
top_k=10,
candidate_count=1,
max_output_tokens=max_output,
)
# Block none doesn't seem to work
block_level = HarmBlockThreshold.BLOCK_ONLY_HIGH
# block_level = HarmBlockThreshold.BLOCK_NONE
self.safety_settings = [
SafetySetting(
category = HarmCategory.HARM_CATEGORY_HARASSMENT,
threshold = block_level,
),
SafetySetting(
category = HarmCategory.HARM_CATEGORY_HATE_SPEECH,
threshold = block_level,
),
SafetySetting(
category = HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
threshold = block_level,
),
SafetySetting(
category = HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
threshold = block_level,
),
]
logger.info("Initializing VertexAI...") logger.info("Initializing VertexAI...")
# Unified credential and project ID loading
if private_key: if private_key:
credentials = ( credentials = (
service_account.Credentials.from_service_account_file( service_account.Credentials.from_service_account_file(
private_key private_key
) )
) )
project_id = credentials.project_id
else: else:
credentials = None credentials, project_id = google.auth.default()
if credentials: if not project_id:
vertexai.init( raise RuntimeError(
location=region, "Could not determine Google Cloud project ID. "
credentials=credentials, "Ensure it's set in your environment or service account."
project=credentials.project_id,
)
else:
vertexai.init(
location=region
) )
logger.info(f"Initializing model {model}") # Initialize the appropriate client based on the model type
self.llm = GenerativeModel(model) if self.is_anthropic:
self.model = model logger.info(f"Initializing Anthropic model '{model}' via AnthropicVertex SDK")
# Initialize AnthropicVertex with credentials if provided, otherwise use ADC
anthropic_kwargs = {'region': region, 'project_id': project_id}
if credentials and private_key: # Pass credentials only if from a file
anthropic_kwargs['credentials'] = credentials
logger.debug(f"Using service account credentials for Anthropic model")
else:
logger.debug(f"Using Application Default Credentials for Anthropic model")
self.llm = AnthropicVertex(**anthropic_kwargs)
else:
# For Gemini models, initialize the Vertex AI SDK
logger.info(f"Initializing Google model '{model}' via Vertex AI SDK")
init_kwargs = {'location': region, 'project': project_id}
if credentials and private_key: # Pass credentials only if from a file
init_kwargs['credentials'] = credentials
vertexai.init(**init_kwargs)
self.llm = GenerativeModel(model)
self.generation_config = GenerationConfig(
temperature=temperature,
top_p=1.0,
top_k=10,
candidate_count=1,
max_output_tokens=max_output,
)
# Block none doesn't seem to work
block_level = HarmBlockThreshold.BLOCK_ONLY_HIGH
# block_level = HarmBlockThreshold.BLOCK_NONE
self.safety_settings = [
SafetySetting(
category = HarmCategory.HARM_CATEGORY_HARASSMENT,
threshold = block_level,
),
SafetySetting(
category = HarmCategory.HARM_CATEGORY_HATE_SPEECH,
threshold = block_level,
),
SafetySetting(
category = HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
threshold = block_level,
),
SafetySetting(
category = HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
threshold = block_level,
),
]
logger.info("VertexAI initialization complete") logger.info("VertexAI initialization complete")
async def generate_content(self, system, prompt): async def generate_content(self, system, prompt):
try: try:
if self.is_anthropic:
# Anthropic API uses a dedicated system prompt
logger.debug("Sending request to Anthropic model...")
response = self.llm.messages.create(
model=self.model,
system=system,
messages=[{"role": "user", "content": prompt}],
max_tokens=self.api_params['max_output_tokens'],
temperature=self.api_params['temperature'],
top_p=self.api_params['top_p'],
top_k=self.api_params['top_k'],
)
prompt = system + "\n\n" + prompt resp = LlmResult(
text=response.content[0].text,
in_token=response.usage.input_tokens,
out_token=response.usage.output_tokens,
model=self.model
)
else:
# Gemini API combines system and user prompts
logger.debug("Sending request to Gemini model...")
full_prompt = system + "\n\n" + prompt
response = self.llm.generate_content( response = self.llm.generate_content(
prompt, generation_config = self.generation_config, full_prompt, generation_config = self.generation_config,
safety_settings = self.safety_settings, safety_settings = self.safety_settings,
) )
resp = LlmResult( resp = LlmResult(
text = response.text, text = response.text,
in_token = response.usage_metadata.prompt_token_count, in_token = response.usage_metadata.prompt_token_count,
out_token = response.usage_metadata.candidates_token_count, out_token = response.usage_metadata.candidates_token_count,
model = self.model model = self.model
) )
logger.info(f"Input Tokens: {resp.in_token}") logger.info(f"Input Tokens: {resp.in_token}")
logger.info(f"Output Tokens: {resp.out_token}") logger.info(f"Output Tokens: {resp.out_token}")
logger.debug("Send response...") logger.debug("Send response...")
return resp return resp
except google.api_core.exceptions.ResourceExhausted as e: except (google.api_core.exceptions.ResourceExhausted, RateLimitError) as e:
logger.warning(f"Hit rate limit: {e}") logger.warning(f"Hit rate limit: {e}")
# Leave rate limit retries to the base handler # Leave rate limit retries to the base handler
raise TooManyRequests() raise TooManyRequests()
except Exception as e: except Exception as e:
# Apart from rate limits, treat all exceptions as unrecoverable # Apart from rate limits, treat all exceptions as unrecoverable
logger.error(f"VertexAI LLM exception: {e}", exc_info=True) logger.error(f"VertexAI LLM exception: {e}", exc_info=True)
raise e raise e
@ -169,12 +209,12 @@ class Processor(LlmService):
parser.add_argument( parser.add_argument(
'-m', '--model', '-m', '--model',
default=default_model, default=default_model,
help=f'LLM model (default: {default_model})' help=f'LLM model (e.g., gemini-1.5-flash-001, claude-3-sonnet@20240229) (default: {default_model})'
) )
parser.add_argument( parser.add_argument(
'-k', '--private-key', '-k', '--private-key',
help=f'Google Cloud private JSON file' help=f'Google Cloud private JSON file (optional, uses ADC if not provided)'
) )
parser.add_argument( parser.add_argument(
@ -198,5 +238,4 @@ class Processor(LlmService):
) )
def run(): def run():
Processor.launch(default_ident, __doc__) Processor.launch(default_ident, __doc__)