Migrate VertexAI to google-genai SDK from deprecated library (#632)
* Migrate VertexAI to google-genai SDK from deprecated library
* Fix tests, mock the correct API
parent 2781c7d87c
commit f24f1ebd80
4 changed files with 223 additions and 245 deletions
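At its core, the change swaps the deprecated `vertexai.generative_models` call pattern for the `google-genai` client pointed at Vertex AI. A minimal before/after sketch; the project, region and model name are illustrative, not taken from this diff:

```python
# Before: deprecated vertexai SDK (google-cloud-aiplatform)
import vertexai
from vertexai.generative_models import GenerativeModel

vertexai.init(project="my-project", location="us-central1")
response = GenerativeModel("gemini-1.5-flash").generate_content("Hello")

# After: google-genai SDK targeting Vertex AI
from google import genai

client = genai.Client(vertexai=True, project="my-project", location="us-central1")
response = client.models.generate_content(model="gemini-1.5-flash", contents="Hello")
print(response.text)
```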
@@ -12,7 +12,8 @@ requires-python = ">=3.8"
 dependencies = [
     "trustgraph-base>=2.0,<2.1",
     "pulsar-client",
-    "google-cloud-aiplatform",
+    "google-genai",
+    "google-api-core",
     "prometheus-client",
     "anthropic",
 ]
@@ -4,29 +4,19 @@ Google Cloud. Input is prompt, output is response.
 Supports both Google's Gemini models and Anthropic's Claude models.
 """
 
 #
-# Somewhat perplexed by the Google Cloud SDK choices. We're going off this
-# one, which uses the google-cloud-aiplatform library:
-# https://cloud.google.com/python/docs/reference/vertexai/1.94.0
-# It seems it is possible to invoke VertexAI from the google-genai
-# SDK too:
-# https://googleapis.github.io/python-genai/genai.html#module-genai.client
-# That would make this code look very much like the GoogleAIStudio
-# code. And maybe not reliant on the google-cloud-aiplatform library?
-#
-# This module's imports bring in a lot of libraries.
+# Uses the google-genai SDK for Gemini models on Vertex AI:
+# https://googleapis.github.io/python-genai/genai.html#module-genai.client
+#
 
 from google.oauth2 import service_account
 import google.auth
-import google.api_core.exceptions
-import vertexai
 import logging
 
-# Why is preview here?
-from vertexai.generative_models import (
-    Content, FunctionDeclaration, GenerativeModel, GenerationConfig,
-    HarmCategory, HarmBlockThreshold, Part, Tool, SafetySetting,
-)
+from google import genai
+from google.genai import types
+from google.genai.types import HarmCategory, HarmBlockThreshold
+from google.api_core.exceptions import ResourceExhausted
 
 # Added for Anthropic model support
 from anthropic import AnthropicVertex, RateLimitError
@@ -67,12 +57,10 @@ class Processor(LlmService):
         self.max_output = max_output
         self.private_key = private_key
 
-        # Model client caches
-        self.model_clients = {}  # Cache for model instances
-        self.generation_configs = {}  # Cache for generation configs (Gemini only)
-        self.anthropic_client = None  # Single Anthropic client (handles multiple models)
+        # Anthropic client (handles Claude models)
+        self.anthropic_client = None
 
-        # Shared parameters for both model types
+        # Shared parameters for Anthropic models
         self.api_params = {
             "temperature": temperature,
             "top_p": 1.0,
@@ -84,10 +72,10 @@ class Processor(LlmService):
 
         # Unified credential and project ID loading
         if private_key:
-            credentials = (
-                service_account.Credentials.from_service_account_file(
-                    private_key
-                )
-            )
+            scopes = ["https://www.googleapis.com/auth/cloud-platform"]
+            credentials = service_account.Credentials.from_service_account_file(
+                private_key,
+                scopes=scopes
+            )
             project_id = credentials.project_id
         else:
@@ -103,12 +91,13 @@ class Processor(LlmService):
         self.credentials = credentials
         self.project_id = project_id
 
-        # Initialize Vertex AI SDK for Gemini models
-        init_kwargs = {'location': region, 'project': project_id}
-        if credentials and private_key:  # Pass credentials only if from a file
-            init_kwargs['credentials'] = credentials
-
-        vertexai.init(**init_kwargs)
+        # Initialize Google GenAI client for Gemini models
+        self.client = genai.Client(
+            vertexai=True,
+            project=project_id,
+            location=region,
+            credentials=credentials
+        )
 
         # Pre-initialize Anthropic client if needed (single client handles all Claude models)
         if 'claude' in self.default_model.lower():
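As a standalone sketch of what this hunk sets up, assuming a service-account key file (the path and region are hypothetical; with `credentials=None`, `genai.Client` falls back to application-default credentials):

```python
from google import genai
from google.oauth2 import service_account

# Hypothetical key path; scopes match the ones this commit adds
scopes = ["https://www.googleapis.com/auth/cloud-platform"]
credentials = service_account.Credentials.from_service_account_file(
    "key.json", scopes=scopes,
)

client = genai.Client(
    vertexai=True,
    project=credentials.project_id,
    location="us-central1",  # illustrative region
    credentials=credentials,
)
```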
@@ -117,24 +106,27 @@ class Processor(LlmService):
         # Safety settings for Gemini models
         block_level = HarmBlockThreshold.BLOCK_ONLY_HIGH
         self.safety_settings = [
-            SafetySetting(
-                category = HarmCategory.HARM_CATEGORY_HARASSMENT,
-                threshold = block_level,
+            types.SafetySetting(
+                category=HarmCategory.HARM_CATEGORY_HATE_SPEECH,
+                threshold=block_level,
             ),
-            SafetySetting(
-                category = HarmCategory.HARM_CATEGORY_HATE_SPEECH,
-                threshold = block_level,
+            types.SafetySetting(
+                category=HarmCategory.HARM_CATEGORY_HARASSMENT,
+                threshold=block_level,
             ),
-            SafetySetting(
-                category = HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
-                threshold = block_level,
+            types.SafetySetting(
+                category=HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
+                threshold=block_level,
             ),
-            SafetySetting(
-                category = HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
-                threshold = block_level,
+            types.SafetySetting(
+                category=HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
+                threshold=block_level,
             ),
         ]
 
+        # Cache for generation configs
+        self.generation_configs = {}
+
         logger.info("VertexAI initialization complete")
 
     def _get_anthropic_client(self):
@@ -152,25 +144,26 @@ class Processor(LlmService):
 
         return self.anthropic_client
 
-    def _get_gemini_model(self, model_name, temperature=None):
-        """Get or create a Gemini model instance"""
-        if model_name not in self.model_clients:
-            logger.info(f"Creating GenerativeModel instance for '{model_name}'")
-            self.model_clients[model_name] = GenerativeModel(model_name)
-
+    def _get_or_create_config(self, model_name, temperature=None):
+        """Get or create generation config with dynamic temperature"""
         # Use provided temperature or fall back to default
         effective_temperature = temperature if temperature is not None else self.temperature
 
-        # Create generation config with the effective temperature
-        generation_config = GenerationConfig(
-            temperature=effective_temperature,
-            top_p=1.0,
-            top_k=10,
-            candidate_count=1,
-            max_output_tokens=self.max_output,
-        )
+        # Create cache key that includes temperature to avoid conflicts
+        cache_key = f"{model_name}:{effective_temperature}"
 
-        return self.model_clients[model_name], generation_config
+        if cache_key not in self.generation_configs:
+            logger.info(f"Creating generation config for '{model_name}' with temperature {effective_temperature}")
+            self.generation_configs[cache_key] = types.GenerateContentConfig(
+                temperature=effective_temperature,
+                top_p=1.0,
+                top_k=40,
+                max_output_tokens=self.max_output,
+                response_mime_type="text/plain",
+                safety_settings=self.safety_settings,
+            )
+
+        return self.generation_configs[cache_key]
 
     async def generate_content(self, system, prompt, model=None, temperature=None):
 
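Because the cache key embeds the temperature, repeated calls with the same model and temperature reuse one `GenerateContentConfig`, while a different temperature creates a separate entry. A sketch, assuming `processor` is an initialized `Processor`:

```python
cfg_a = processor._get_or_create_config("gemini-1.5-flash", temperature=0.2)
cfg_b = processor._get_or_create_config("gemini-1.5-flash", temperature=0.2)
cfg_c = processor._get_or_create_config("gemini-1.5-flash", temperature=0.9)

assert cfg_a is cfg_b      # same cache key "gemini-1.5-flash:0.2"
assert cfg_a is not cfg_c  # "gemini-1.5-flash:0.9" gets its own config
```

Note that `generate_content` below then sets `system_instruction` on the cached config per request, so concurrent requests sharing a cache key would also share that field.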
@@ -205,22 +198,24 @@ class Processor(LlmService):
                     model=model_name
                 )
             else:
-                # Gemini API combines system and user prompts
+                # Gemini API using google-genai SDK
                 logger.debug(f"Sending request to Gemini model '{model_name}'...")
-                full_prompt = system + "\n\n" + prompt
 
-                llm, generation_config = self._get_gemini_model(model_name, effective_temperature)
+                generation_config = self._get_or_create_config(model_name, effective_temperature)
+                # Set system instruction per request (can't be cached)
+                generation_config.system_instruction = system
 
-                response = llm.generate_content(
-                    full_prompt, generation_config = generation_config,
-                    safety_settings = self.safety_settings,
+                response = self.client.models.generate_content(
+                    model=model_name,
+                    config=generation_config,
+                    contents=prompt,
                 )
 
                 resp = LlmResult(
-                    text = response.text,
-                    in_token = response.usage_metadata.prompt_token_count,
-                    out_token = response.usage_metadata.candidates_token_count,
-                    model = model_name
+                    text=response.text,
+                    in_token=int(response.usage_metadata.prompt_token_count),
+                    out_token=int(response.usage_metadata.candidates_token_count),
+                    model=model_name
                 )
 
                 logger.info(f"Input Tokens: {resp.in_token}")
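Stripped of the processor plumbing, the equivalent bare `google-genai` call looks like this (client construction and model name are illustrative); `usage_metadata` carries the token counts that `LlmResult` records:

```python
from google import genai
from google.genai import types

client = genai.Client(vertexai=True, project="my-project", location="us-central1")

config = types.GenerateContentConfig(
    temperature=0.0,
    system_instruction="You are a terse assistant.",  # illustrative system prompt
)
response = client.models.generate_content(
    model="gemini-1.5-flash",
    config=config,
    contents="What is a knowledge graph?",
)
print(response.text)
print(response.usage_metadata.prompt_token_count,
      response.usage_metadata.candidates_token_count)
```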
@@ -229,7 +224,7 @@ class Processor(LlmService):
 
             return resp
 
-        except (google.api_core.exceptions.ResourceExhausted, RateLimitError) as e:
+        except (ResourceExhausted, RateLimitError) as e:
             logger.warning(f"Hit rate limit: {e}")
             # Leave rate limit retries to the base handler
             raise TooManyRequests()
@@ -302,17 +297,16 @@ class Processor(LlmService):
                 logger.info(f"Output Tokens: {total_out_tokens}")
 
             else:
-                # Gemini streaming
+                # Gemini streaming using google-genai SDK
                 logger.debug(f"Streaming request to Gemini model '{model_name}'...")
-                full_prompt = system + "\n\n" + prompt
 
-                llm, generation_config = self._get_gemini_model(model_name, effective_temperature)
+                generation_config = self._get_or_create_config(model_name, effective_temperature)
+                generation_config.system_instruction = system
 
-                response = llm.generate_content(
-                    full_prompt,
-                    generation_config=generation_config,
-                    safety_settings=self.safety_settings,
-                    stream=True  # Enable streaming
+                response = self.client.models.generate_content_stream(
+                    model=model_name,
+                    config=generation_config,
+                    contents=prompt,
                 )
 
                 total_in_tokens = 0
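`generate_content_stream` returns an iterator of response chunks rather than a single response; the unchanged token-counting code below presumably consumes it along these lines (a sketch, not the module's exact loop; `client` and `config` as in the sketches above):

```python
stream = client.models.generate_content_stream(
    model="gemini-1.5-flash",  # illustrative model name
    config=config,
    contents="Stream a short answer.",
)

for chunk in stream:
    if chunk.text:
        print(chunk.text, end="", flush=True)
    # usage_metadata is typically populated on the final chunk
    if chunk.usage_metadata:
        in_tokens = chunk.usage_metadata.prompt_token_count
        out_tokens = chunk.usage_metadata.candidates_token_count
```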
@@ -348,7 +342,7 @@ class Processor(LlmService):
                 logger.info(f"Input Tokens: {total_in_tokens}")
                 logger.info(f"Output Tokens: {total_out_tokens}")
 
-        except (google.api_core.exceptions.ResourceExhausted, RateLimitError) as e:
+        except (ResourceExhausted, RateLimitError) as e:
             logger.warning(f"Hit rate limit during streaming: {e}")
             raise TooManyRequests()
 