diff --git a/trustgraph-vertexai/pyproject.toml b/trustgraph-vertexai/pyproject.toml
index e5166a19..0259b682 100644
--- a/trustgraph-vertexai/pyproject.toml
+++ b/trustgraph-vertexai/pyproject.toml
@@ -12,7 +12,7 @@ requires-python = ">=3.8"
 dependencies = [
     "trustgraph-base>=2.0,<2.1",
     "pulsar-client",
-    "google-cloud-aiplatform",
+    "google-genai",
     "prometheus-client",
     "anthropic",
 ]
diff --git a/trustgraph-vertexai/trustgraph/model/text_completion/vertexai/llm.py b/trustgraph-vertexai/trustgraph/model/text_completion/vertexai/llm.py
index 5cf17b4d..59aa5bfe 100755
--- a/trustgraph-vertexai/trustgraph/model/text_completion/vertexai/llm.py
+++ b/trustgraph-vertexai/trustgraph/model/text_completion/vertexai/llm.py
@@ -4,29 +4,19 @@ Google Cloud. Input is prompt, output is response. Supports both Google's
 Gemini models and Anthropic's Claude models.
 """
 
-#
-# Somewhat perplexed by the Google Cloud SDK choices. We're going off this
-# one, which uses the google-cloud-aiplatform library:
-# https://cloud.google.com/python/docs/reference/vertexai/1.94.0
-# It seems it is possible to invoke VertexAI from the google-genai
-# SDK too:
-# https://googleapis.github.io/python-genai/genai.html#module-genai.client
-# That would make this code look very much like the GoogleAIStudio
-# code. And maybe not reliant on the google-cloud-aiplatform library?
 #
-# This module's imports bring in a lot of libraries.
+# Uses the google-genai SDK for Gemini models on Vertex AI:
+# https://googleapis.github.io/python-genai/genai.html#module-genai.client
+#
 
 from google.oauth2 import service_account
 import google.auth
-import google.api_core.exceptions
-import vertexai
 import logging
 
-# Why is preview here?
-from vertexai.generative_models import (
-    Content, FunctionDeclaration, GenerativeModel, GenerationConfig,
-    HarmCategory, HarmBlockThreshold, Part, Tool, SafetySetting,
-)
+from google import genai
+from google.genai import types
+from google.genai.types import HarmCategory, HarmBlockThreshold
+from google.api_core.exceptions import ResourceExhausted
 
 # Added for Anthropic model support
 from anthropic import AnthropicVertex, RateLimitError
@@ -67,12 +57,10 @@ class Processor(LlmService):
         self.max_output = max_output
         self.private_key = private_key
 
-        # Model client caches
-        self.model_clients = {}  # Cache for model instances
-        self.generation_configs = {}  # Cache for generation configs (Gemini only)
-        self.anthropic_client = None  # Single Anthropic client (handles multiple models)
+        # Anthropic client (handles Claude models)
+        self.anthropic_client = None
 
-        # Shared parameters for both model types
+        # Shared parameters for Anthropic models
         self.api_params = {
             "temperature": temperature,
             "top_p": 1.0,
@@ -84,10 +72,10 @@
 
         # Unified credential and project ID loading
         if private_key:
-            credentials = (
-                service_account.Credentials.from_service_account_file(
-                    private_key
-                )
+            scopes = ["https://www.googleapis.com/auth/cloud-platform"]
+            credentials = service_account.Credentials.from_service_account_file(
+                private_key,
+                scopes=scopes
             )
             project_id = credentials.project_id
         else:
@@ -103,12 +91,13 @@
         self.credentials = credentials
         self.project_id = project_id
 
-        # Initialize Vertex AI SDK for Gemini models
-        init_kwargs = {'location': region, 'project': project_id}
-        if credentials and private_key:  # Pass credentials only if from a file
-            init_kwargs['credentials'] = credentials
-
-        vertexai.init(**init_kwargs)
+        # Initialize Google GenAI client for Gemini models
+        self.client = genai.Client(
+            vertexai=True,
+            project=project_id,
+            location=region,
+            credentials=credentials
+        )
 
         # Pre-initialize Anthropic client if needed (single client handles all Claude models)
         if 'claude' in self.default_model.lower():
@@ -117,24 +106,27 @@
         # Safety settings for Gemini models
         block_level = HarmBlockThreshold.BLOCK_ONLY_HIGH
         self.safety_settings = [
-            SafetySetting(
-                category = HarmCategory.HARM_CATEGORY_HARASSMENT,
-                threshold = block_level,
+            types.SafetySetting(
+                category=HarmCategory.HARM_CATEGORY_HATE_SPEECH,
+                threshold=block_level,
             ),
-            SafetySetting(
-                category = HarmCategory.HARM_CATEGORY_HATE_SPEECH,
-                threshold = block_level,
+            types.SafetySetting(
+                category=HarmCategory.HARM_CATEGORY_HARASSMENT,
+                threshold=block_level,
             ),
-            SafetySetting(
-                category = HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
-                threshold = block_level,
+            types.SafetySetting(
+                category=HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
+                threshold=block_level,
             ),
-            SafetySetting(
-                category = HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
-                threshold = block_level,
+            types.SafetySetting(
+                category=HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
+                threshold=block_level,
             ),
         ]
 
+        # Cache for generation configs
+        self.generation_configs = {}
+
         logger.info("VertexAI initialization complete")
 
     def _get_anthropic_client(self):
@@ -152,25 +144,26 @@
 
         return self.anthropic_client
 
-    def _get_gemini_model(self, model_name, temperature=None):
-        """Get or create a Gemini model instance"""
-        if model_name not in self.model_clients:
-            logger.info(f"Creating GenerativeModel instance for '{model_name}'")
-            self.model_clients[model_name] = GenerativeModel(model_name)
-
+    def _get_or_create_config(self, model_name, temperature=None):
+        """Get or create generation config with dynamic temperature"""
         # Use provided temperature or fall back to default
        effective_temperature = temperature if temperature is not None else self.temperature
 
-        # Create generation config with the effective temperature
-        generation_config = GenerationConfig(
-            temperature=effective_temperature,
-            top_p=1.0,
-            top_k=10,
-            candidate_count=1,
-            max_output_tokens=self.max_output,
-        )
+        # Create cache key that includes temperature to avoid conflicts
+        cache_key = f"{model_name}:{effective_temperature}"
 
-        return self.model_clients[model_name], generation_config
+        if cache_key not in self.generation_configs:
+            logger.info(f"Creating generation config for '{model_name}' with temperature {effective_temperature}")
+            self.generation_configs[cache_key] = types.GenerateContentConfig(
+                temperature=effective_temperature,
+                top_p=1.0,
+                top_k=40,
+                max_output_tokens=self.max_output,
+                response_mime_type="text/plain",
+                safety_settings=self.safety_settings,
+            )
+
+        return self.generation_configs[cache_key]
 
     async def generate_content(self, system, prompt, model=None, temperature=None):
@@ -205,22 +198,24 @@
                     model=model_name
                 )
             else:
-                # Gemini API combines system and user prompts
+                # Gemini API using google-genai SDK
                 logger.debug(f"Sending request to Gemini model '{model_name}'...")
-                full_prompt = system + "\n\n" + prompt
-                llm, generation_config = self._get_gemini_model(model_name, effective_temperature)
+                generation_config = self._get_or_create_config(model_name, effective_temperature)
+                # Set system instruction per request (can't be cached)
+                generation_config.system_instruction = system
 
-                response = llm.generate_content(
-                    full_prompt, generation_config = generation_config,
-                    safety_settings = self.safety_settings,
+                response = self.client.models.generate_content(
+                    model=model_name,
+                    config=generation_config,
+                    contents=prompt,
                 )
 
                 resp = LlmResult(
-                    text = response.text,
-                    in_token = response.usage_metadata.prompt_token_count,
-                    out_token = response.usage_metadata.candidates_token_count,
-                    model = model_name
+                    text=response.text,
+                    in_token=int(response.usage_metadata.prompt_token_count),
+                    out_token=int(response.usage_metadata.candidates_token_count),
+                    model=model_name
                 )
 
             logger.info(f"Input Tokens: {resp.in_token}")
@@ -229,7 +224,7 @@
 
             return resp
 
-        except (google.api_core.exceptions.ResourceExhausted, RateLimitError) as e:
+        except (ResourceExhausted, RateLimitError) as e:
            logger.warning(f"Hit rate limit: {e}")
             # Leave rate limit retries to the base handler
             raise TooManyRequests()
@@ -302,17 +297,16 @@
                 logger.info(f"Output Tokens: {total_out_tokens}")
 
             else:
-                # Gemini streaming
+                # Gemini streaming using google-genai SDK
                 logger.debug(f"Streaming request to Gemini model '{model_name}'...")
-                full_prompt = system + "\n\n" + prompt
-                llm, generation_config = self._get_gemini_model(model_name, effective_temperature)
+                generation_config = self._get_or_create_config(model_name, effective_temperature)
+                generation_config.system_instruction = system
 
-                response = llm.generate_content(
-                    full_prompt,
-                    generation_config=generation_config,
-                    safety_settings=self.safety_settings,
-                    stream=True  # Enable streaming
+                response = self.client.models.generate_content_stream(
+                    model=model_name,
+                    config=generation_config,
+                    contents=prompt,
                 )
 
                 total_in_tokens = 0
@@ -348,7 +342,7 @@
 
                 logger.info(f"Input Tokens: {total_in_tokens}")
                 logger.info(f"Output Tokens: {total_out_tokens}")
 
-        except (google.api_core.exceptions.ResourceExhausted, RateLimitError) as e:
+        except (ResourceExhausted, RateLimitError) as e:
             logger.warning(f"Hit rate limit during streaming: {e}")
             raise TooManyRequests()
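
For reviewers, a minimal standalone sketch of the call path this patch migrates to, using the google-genai SDK against Vertex AI. It assumes the google-genai package is installed and that Application Default Credentials or a service account are available; the project, location, model name, and prompt below are placeholder values, not values taken from this patch.

# Hedged sketch of the google-genai call path adopted above; project,
# location, and model name are placeholders, not from the patch.
from google import genai
from google.genai import types

client = genai.Client(
    vertexai=True,              # route requests through Vertex AI
    project="example-project",  # placeholder GCP project ID
    location="us-central1",     # placeholder region
)

config = types.GenerateContentConfig(
    temperature=0.0,
    top_p=1.0,
    top_k=40,
    max_output_tokens=1024,
    system_instruction="You are a terse assistant.",
)

response = client.models.generate_content(
    model="gemini-2.0-flash",   # placeholder model name
    config=config,
    contents="Say hello.",
)

print(response.text)
print(response.usage_metadata.prompt_token_count,
      response.usage_metadata.candidates_token_count)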