Migrate VertexAI to google-genai SDK from deprecated library (#632)

* Migrate VertexAI to google-genai SDK from deprecated library

* Fix tests, mock the correct API
cybermaggedon 2026-02-09 20:43:33 +00:00 committed by GitHub
parent 2781c7d87c
commit f24f1ebd80
4 changed files with 223 additions and 245 deletions


@@ -12,7 +12,8 @@ requires-python = ">=3.8"
 dependencies = [
     "trustgraph-base>=2.0,<2.1",
     "pulsar-client",
-    "google-cloud-aiplatform",
+    "google-genai",
+    "google-api-core",
     "prometheus-client",
     "anthropic",
 ]

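Note: the replacement dependency drives Vertex AI through a single client object rather than a module-level vertexai.init(). A minimal sketch of the new SDK in Vertex AI mode, assuming application default credentials; project, location, and model name are illustrative placeholders:

    from google import genai

    # Client in Vertex AI mode; no global init step is needed.
    client = genai.Client(
        vertexai=True,
        project="my-project",    # placeholder
        location="us-central1",  # placeholder
    )

    # Hypothetical model name, for illustration only.
    response = client.models.generate_content(
        model="gemini-2.0-flash",
        contents="Say hello.",
    )
    print(response.text)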

@@ -4,29 +4,19 @@ Google Cloud. Input is prompt, output is response.
 Supports both Google's Gemini models and Anthropic's Claude models.
 """

 #
-# Somewhat perplexed by the Google Cloud SDK choices. We're going off this
-# one, which uses the google-cloud-aiplatform library:
-# https://cloud.google.com/python/docs/reference/vertexai/1.94.0
-# It seems it is possible to invoke VertexAI from the google-genai
-# SDK too:
-# https://googleapis.github.io/python-genai/genai.html#module-genai.client
-# That would make this code look very much like the GoogleAIStudio
-# code. And maybe not reliant on the google-cloud-aiplatform library?
-#
-# This module's imports bring in a lot of libraries.
+# Uses the google-genai SDK for Gemini models on Vertex AI:
+# https://googleapis.github.io/python-genai/genai.html#module-genai.client
+#

 from google.oauth2 import service_account
 import google.auth
-import google.api_core.exceptions
-import vertexai
 import logging

-# Why is preview here?
-from vertexai.generative_models import (
-    Content, FunctionDeclaration, GenerativeModel, GenerationConfig,
-    HarmCategory, HarmBlockThreshold, Part, Tool, SafetySetting,
-)
+from google import genai
+from google.genai import types
+from google.genai.types import HarmCategory, HarmBlockThreshold
+from google.api_core.exceptions import ResourceExhausted

 # Added for Anthropic model support
 from anthropic import AnthropicVertex, RateLimitError
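For orientation, the removed vertexai symbols map onto google-genai roughly as follows (a sketch; the right-hand names are the ones this commit actually uses):

    # vertexai.generative_models.GenerativeModel   -> client.models, with the model
    #                                                 name passed per call
    # vertexai.generative_models.GenerationConfig  -> google.genai.types.GenerateContentConfig
    # vertexai.generative_models.SafetySetting     -> google.genai.types.SafetySetting
    # HarmCategory / HarmBlockThreshold            -> same names, from google.genai.types
    # google.api_core.exceptions.ResourceExhausted -> still raised; imported directly now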
@@ -67,12 +57,10 @@ class Processor(LlmService):
         self.max_output = max_output
         self.private_key = private_key

-        # Model client caches
-        self.model_clients = {}  # Cache for model instances
-        self.generation_configs = {}  # Cache for generation configs (Gemini only)
-        self.anthropic_client = None  # Single Anthropic client (handles multiple models)
+        # Anthropic client (handles Claude models)
+        self.anthropic_client = None

-        # Shared parameters for both model types
+        # Shared parameters for Anthropic models
         self.api_params = {
             "temperature": temperature,
             "top_p": 1.0,
@@ -84,10 +72,10 @@ class Processor(LlmService):

         # Unified credential and project ID loading
         if private_key:
-            credentials = (
-                service_account.Credentials.from_service_account_file(
-                    private_key
-                )
+            scopes = ["https://www.googleapis.com/auth/cloud-platform"]
+            credentials = service_account.Credentials.from_service_account_file(
+                private_key,
+                scopes=scopes
             )
             project_id = credentials.project_id
         else:
@@ -103,12 +91,13 @@ class Processor(LlmService):
         self.credentials = credentials
         self.project_id = project_id

-        # Initialize Vertex AI SDK for Gemini models
-        init_kwargs = {'location': region, 'project': project_id}
-        if credentials and private_key:  # Pass credentials only if from a file
-            init_kwargs['credentials'] = credentials
-        vertexai.init(**init_kwargs)
+        # Initialize Google GenAI client for Gemini models
+        self.client = genai.Client(
+            vertexai=True,
+            project=project_id,
+            location=region,
+            credentials=credentials
+        )

         # Pre-initialize Anthropic client if needed (single client handles all Claude models)
         if 'claude' in self.default_model.lower():
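The credentials argument accepts either the service-account credentials loaded from the key file or the application default credentials resolved in the else branch (not shown in the hunk above). A minimal sketch of the two paths, assuming the else branch uses google.auth.default():

    import google.auth
    from google.oauth2 import service_account

    SCOPES = ["https://www.googleapis.com/auth/cloud-platform"]

    if private_key:  # path to a service-account JSON key file
        credentials = service_account.Credentials.from_service_account_file(
            private_key, scopes=SCOPES
        )
        project_id = credentials.project_id
    else:
        # Application default credentials (gcloud login, workload identity, ...)
        credentials, project_id = google.auth.default(scopes=SCOPES)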
@@ -117,24 +106,27 @@ class Processor(LlmService):

         # Safety settings for Gemini models
         block_level = HarmBlockThreshold.BLOCK_ONLY_HIGH
         self.safety_settings = [
-            SafetySetting(
-                category = HarmCategory.HARM_CATEGORY_HARASSMENT,
-                threshold = block_level,
+            types.SafetySetting(
+                category=HarmCategory.HARM_CATEGORY_HATE_SPEECH,
+                threshold=block_level,
             ),
-            SafetySetting(
-                category = HarmCategory.HARM_CATEGORY_HATE_SPEECH,
-                threshold = block_level,
+            types.SafetySetting(
+                category=HarmCategory.HARM_CATEGORY_HARASSMENT,
+                threshold=block_level,
             ),
-            SafetySetting(
-                category = HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
-                threshold = block_level,
+            types.SafetySetting(
+                category=HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
+                threshold=block_level,
             ),
-            SafetySetting(
-                category = HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
-                threshold = block_level,
+            types.SafetySetting(
+                category=HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
+                threshold=block_level,
             ),
         ]

+        # Cache for generation configs
+        self.generation_configs = {}
+
         logger.info("VertexAI initialization complete")

     def _get_anthropic_client(self):
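In the google-genai SDK, safety settings travel inside the request config rather than as a separate generate_content argument; the enums should also be expressible as plain strings. A minimal sketch under that assumption:

    from google.genai import types

    config = types.GenerateContentConfig(
        safety_settings=[
            # String forms assumed equivalent to the HarmCategory /
            # HarmBlockThreshold enums used above.
            types.SafetySetting(
                category="HARM_CATEGORY_HATE_SPEECH",
                threshold="BLOCK_ONLY_HIGH",
            ),
        ],
    )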
@@ -152,25 +144,26 @@ class Processor(LlmService):
         return self.anthropic_client

-    def _get_gemini_model(self, model_name, temperature=None):
-        """Get or create a Gemini model instance"""
-        if model_name not in self.model_clients:
-            logger.info(f"Creating GenerativeModel instance for '{model_name}'")
-            self.model_clients[model_name] = GenerativeModel(model_name)
+    def _get_or_create_config(self, model_name, temperature=None):
+        """Get or create generation config with dynamic temperature"""

         # Use provided temperature or fall back to default
         effective_temperature = temperature if temperature is not None else self.temperature

-        # Create generation config with the effective temperature
-        generation_config = GenerationConfig(
-            temperature=effective_temperature,
-            top_p=1.0,
-            top_k=10,
-            candidate_count=1,
-            max_output_tokens=self.max_output,
-        )
+        # Create cache key that includes temperature to avoid conflicts
+        cache_key = f"{model_name}:{effective_temperature}"

-        return self.model_clients[model_name], generation_config
+        if cache_key not in self.generation_configs:
+            logger.info(f"Creating generation config for '{model_name}' with temperature {effective_temperature}")
+            self.generation_configs[cache_key] = types.GenerateContentConfig(
+                temperature=effective_temperature,
+                top_p=1.0,
+                top_k=40,
+                max_output_tokens=self.max_output,
+                response_mime_type="text/plain",
+                safety_settings=self.safety_settings,
+            )
+
+        return self.generation_configs[cache_key]

     async def generate_content(self, system, prompt, model=None, temperature=None):
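The cache keys on model name plus effective temperature, so repeated calls at the same temperature reuse one GenerateContentConfig. Roughly, given a Processor instance (model name hypothetical):

    # Same model/temperature -> same cached config object.
    a = processor._get_or_create_config("gemini-2.0-flash", temperature=0.2)
    b = processor._get_or_create_config("gemini-2.0-flash", temperature=0.2)
    assert a is b

    # A different temperature creates a separate cache entry.
    c = processor._get_or_create_config("gemini-2.0-flash", temperature=0.7)
    assert a is not c

One consequence of this design: generate_content below assigns system_instruction onto the cached object, so the instruction is overwritten on each request rather than cached with the config.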
@@ -205,22 +198,24 @@
                     model=model_name
                 )
             else:
-                # Gemini API combines system and user prompts
+                # Gemini API using google-genai SDK
                 logger.debug(f"Sending request to Gemini model '{model_name}'...")
-                full_prompt = system + "\n\n" + prompt
-                llm, generation_config = self._get_gemini_model(model_name, effective_temperature)
+                generation_config = self._get_or_create_config(model_name, effective_temperature)
+
+                # Set system instruction per request (can't be cached)
+                generation_config.system_instruction = system

-                response = llm.generate_content(
-                    full_prompt, generation_config = generation_config,
-                    safety_settings = self.safety_settings,
+                response = self.client.models.generate_content(
+                    model=model_name,
+                    config=generation_config,
+                    contents=prompt,
                 )

             resp = LlmResult(
-                text = response.text,
-                in_token = response.usage_metadata.prompt_token_count,
-                out_token = response.usage_metadata.candidates_token_count,
-                model = model_name
+                text=response.text,
+                in_token=int(response.usage_metadata.prompt_token_count),
+                out_token=int(response.usage_metadata.candidates_token_count),
+                model=model_name
             )

             logger.info(f"Input Tokens: {resp.in_token}")
@@ -229,7 +224,7 @@
             return resp

-        except (google.api_core.exceptions.ResourceExhausted, RateLimitError) as e:
+        except (ResourceExhausted, RateLimitError) as e:
             logger.warning(f"Hit rate limit: {e}")
             # Leave rate limit retries to the base handler
             raise TooManyRequests()
@@ -302,17 +297,16 @@
                 logger.info(f"Output Tokens: {total_out_tokens}")
             else:
-                # Gemini streaming
+                # Gemini streaming using google-genai SDK
                 logger.debug(f"Streaming request to Gemini model '{model_name}'...")
-                full_prompt = system + "\n\n" + prompt
-                llm, generation_config = self._get_gemini_model(model_name, effective_temperature)
+                generation_config = self._get_or_create_config(model_name, effective_temperature)
+                generation_config.system_instruction = system

-                response = llm.generate_content(
-                    full_prompt,
-                    generation_config=generation_config,
-                    safety_settings=self.safety_settings,
-                    stream=True  # Enable streaming
+                response = self.client.models.generate_content_stream(
+                    model=model_name,
+                    config=generation_config,
+                    contents=prompt,
                 )

                 total_in_tokens = 0
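generate_content_stream returns an iterator of partial responses; each chunk carries incremental text, and usage metadata appears on at least the later chunks. A consumption sketch under those assumptions, with a hypothetical model name:

    stream = client.models.generate_content_stream(
        model="gemini-2.0-flash",
        contents="Tell me a story.",
        config=generation_config,
    )

    total_in, total_out = 0, 0
    for chunk in stream:
        if chunk.text:
            print(chunk.text, end="")
        # Keep the latest non-empty token counts (assumes later chunks
        # carry cumulative usage_metadata).
        if chunk.usage_metadata:
            total_in = chunk.usage_metadata.prompt_token_count or total_in
            total_out = chunk.usage_metadata.candidates_token_count or total_out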
@@ -348,7 +342,7 @@
                 logger.info(f"Input Tokens: {total_in_tokens}")
                 logger.info(f"Output Tokens: {total_out_tokens}")

-            except (google.api_core.exceptions.ResourceExhausted, RateLimitError) as e:
+            except (ResourceExhausted, RateLimitError) as e:
                 logger.warning(f"Hit rate limit during streaming: {e}")
                 raise TooManyRequests()