diff --git a/containers/Containerfile.vertexai b/containers/Containerfile.vertexai
index 659cf376..9d7028c0 100644
--- a/containers/Containerfile.vertexai
+++ b/containers/Containerfile.vertexai
@@ -13,6 +13,7 @@ RUN dnf install -y python3.12 && \
     python -m ensurepip --upgrade && \
     pip3 install --no-cache-dir wheel aiohttp && \
     pip3 install --no-cache-dir pulsar-client==3.7.0 && \
+    pip3 install --no-cache-dir google-cloud-aiplatform && \
     dnf clean all
 
 # ----------------------------------------------------------------------------
@@ -48,5 +49,3 @@ RUN \
 
 WORKDIR /
 
-
-
diff --git a/trustgraph-flow/setup.py b/trustgraph-flow/setup.py
index 55786d6e..256f7bc5 100644
--- a/trustgraph-flow/setup.py
+++ b/trustgraph-flow/setup.py
@@ -42,7 +42,7 @@ setuptools.setup(
         "cryptography",
         "falkordb",
         "fastembed",
-        "google-generativeai",
+        "google-genai",
         "ibis",
         "jsonschema",
         "langchain",
diff --git a/trustgraph-flow/trustgraph/model/text_completion/googleaistudio/llm.py b/trustgraph-flow/trustgraph/model/text_completion/googleaistudio/llm.py
index 051e2fe5..ec568e61 100644
--- a/trustgraph-flow/trustgraph/model/text_completion/googleaistudio/llm.py
+++ b/trustgraph-flow/trustgraph/model/text_completion/googleaistudio/llm.py
@@ -4,10 +4,18 @@ Simple LLM service, performs text prompt completion using GoogleAIStudio.
 Input is prompt, output is response.
 """
 
-import google.generativeai as genai
-from google.generativeai.types import HarmCategory, HarmBlockThreshold
+#
+# Using this SDK:
+# https://googleapis.github.io/python-genai/genai.html#module-genai.client
+#
+# It seems to have a simpler dependency footprint than the 'VertexAI'
+# SDK, which TrustGraph supports in the trustgraph-vertexai package.
+#
+
+from google import genai
+from google.genai import types
+from google.genai.types import HarmCategory, HarmBlockThreshold
 from google.api_core.exceptions import ResourceExhausted
-from prometheus_client import Histogram
 import os
 
 from .... exceptions import TooManyRequests
@@ -15,7 +23,7 @@ from .... base import LlmService, LlmResult
 
 default_ident = "text-completion"
 
-default_model = 'gemini-1.5-flash-002'
+default_model = 'gemini-2.0-flash-001'
 default_temperature = 0.0
 default_max_output = 8192
 default_api_key = os.getenv("GOOGLE_AI_STUDIO_KEY")
@@ -40,58 +48,56 @@ class Processor(LlmService):
             }
         )
-        genai.configure(api_key=api_key)
+        self.client = genai.Client(api_key=api_key)
 
         self.model = model
         self.temperature = temperature
         self.max_output = max_output
 
-        self.generation_config = {
-            "temperature": temperature,
-            "top_p": 1,
-            "top_k": 40,
-            "max_output_tokens": max_output,
-            "response_mime_type": "text/plain",
-        }
-
         block_level = HarmBlockThreshold.BLOCK_ONLY_HIGH
-        self.safety_settings={
-            HarmCategory.HARM_CATEGORY_HATE_SPEECH: block_level,
-            HarmCategory.HARM_CATEGORY_HARASSMENT: block_level,
-            HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: block_level,
-            HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: block_level,
+        self.safety_settings = [
+            types.SafetySetting(
+                category = HarmCategory.HARM_CATEGORY_HATE_SPEECH,
+                threshold = block_level,
+            ),
+            types.SafetySetting(
+                category = HarmCategory.HARM_CATEGORY_HARASSMENT,
+                threshold = block_level,
+            ),
+            types.SafetySetting(
+                category = HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
+                threshold = block_level,
+            ),
+            types.SafetySetting(
+                category = HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
+                threshold = block_level,
+            ),
             # There is a documentation conflict on whether or not
             # CIVIC_INTEGRITY is a valid category
            # HarmCategory.HARM_CATEGORY_CIVIC_INTEGRITY: block_level,
-        }
-
-        self.llm = genai.GenerativeModel(
-            model_name=model,
-            generation_config=self.generation_config,
-            safety_settings=self.safety_settings,
-            system_instruction="You are a helpful AI assistant.",
-        )
+        ]
 
         print("Initialised", flush=True)
 
     async def generate_content(self, system, prompt):
 
-        # FIXME: There's a system prompt above. Maybe if system changes,
-        # then reset self.llm? It shouldn't do, because system prompt
-        # is set system wide?
-
-        # Or... could keep different LLM structures for different system
-        # prompts?
-
-        prompt = system + "\n\n" + prompt
+        generation_config = types.GenerateContentConfig(
+            temperature = self.temperature,
+            top_p = 1,
+            top_k = 40,
+            max_output_tokens = self.max_output,
+            response_mime_type = "text/plain",
+            system_instruction = system,
+            safety_settings = self.safety_settings,
+        )
 
         try:
 
-            chat_session = self.llm.start_chat(
-                history=[
-                ]
+            response = self.client.models.generate_content(
+                model=self.model,
+                config=generation_config,
+                contents=prompt,
             )
 
-            response = chat_session.send_message(prompt)
             resp = response.text
 
             inputtokens = int(response.usage_metadata.prompt_token_count)
@@ -158,3 +164,4 @@ class Processor(LlmService):
 
 def run():
     Processor.launch(default_ident, __doc__)
+
diff --git a/trustgraph-vertexai/trustgraph/model/text_completion/vertexai/llm.py b/trustgraph-vertexai/trustgraph/model/text_completion/vertexai/llm.py
index 854be961..c6d869e6 100755
--- a/trustgraph-vertexai/trustgraph/model/text_completion/vertexai/llm.py
+++ b/trustgraph-vertexai/trustgraph/model/text_completion/vertexai/llm.py
@@ -4,13 +4,26 @@ Simple LLM service, performs text prompt completion using VertexAI on
 Google Cloud. Input is prompt, output is response.
 """
 
+#
+# Google's Cloud SDK offerings are somewhat perplexing. We're going off
+# this one, which uses the google-cloud-aiplatform library:
+# https://cloud.google.com/python/docs/reference/vertexai/1.94.0
+# It seems it is also possible to invoke VertexAI from the google-genai
+# SDK:
+# https://googleapis.github.io/python-genai/genai.html#module-genai.client
+# That would make this code look very much like the GoogleAIStudio
+# code, and would perhaps remove the reliance on google-cloud-aiplatform.
+#
+# Note that this module's imports bring in a lot of libraries.
+
 from google.oauth2 import service_account
 import google
 import vertexai
-from vertexai.preview.generative_models import (
+# Note: this previously used the vertexai.preview namespace.
+from vertexai.generative_models import (
     Content, FunctionDeclaration, GenerativeModel, GenerationConfig,
-    HarmCategory, HarmBlockThreshold, Part, Tool,
+    HarmCategory, HarmBlockThreshold, Part, Tool, SafetySetting,
 )
 
 
 from .... exceptions import TooManyRequests
@@ -18,7 +31,7 @@ from .... base import LlmService, LlmResult
 
 default_ident = "text-completion"
 
-default_model = 'gemini-1.0-pro-001'
+default_model = 'gemini-2.0-flash-001'
 default_region = 'us-central1'
 default_temperature = 0.0
 default_max_output = 8192
@@ -59,12 +72,24 @@ class Processor(LlmService):
 
         block_level = HarmBlockThreshold.BLOCK_ONLY_HIGH
         # block_level = HarmBlockThreshold.BLOCK_NONE
 
-        self.safety_settings = {
-            HarmCategory.HARM_CATEGORY_HARASSMENT: block_level,
-            HarmCategory.HARM_CATEGORY_HATE_SPEECH: block_level,
-            HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: block_level,
-            HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: block_level,
-        }
+        self.safety_settings = [
+            SafetySetting(
+                category = HarmCategory.HARM_CATEGORY_HARASSMENT,
+                threshold = block_level,
+            ),
+            SafetySetting(
+                category = HarmCategory.HARM_CATEGORY_HATE_SPEECH,
+                threshold = block_level,
+            ),
+            SafetySetting(
+                category = HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
+                threshold = block_level,
+            ),
+            SafetySetting(
+                category = HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
+                threshold = block_level,
+            ),
+        ]
 
         print("Initialise VertexAI...", flush=True)
@@ -101,8 +126,8 @@ class Processor(LlmService):
         prompt = system + "\n\n" + prompt
 
         response = self.llm.generate_content(
-            prompt, generation_config=self.generation_config,
-            safety_settings=self.safety_settings
+            prompt, generation_config = self.generation_config,
+            safety_settings = self.safety_settings,
         )
 
         resp = LlmResult(
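
Reviewer note: below is a minimal sketch of the google-genai call pattern this
change adopts, handy for smoke-testing the SDK swap outside the service. It
assumes GOOGLE_AI_STUDIO_KEY is set; the model name and config values mirror
the new defaults in llm.py, and the prompt text is a placeholder.

    import os

    from google import genai
    from google.genai import types

    client = genai.Client(api_key=os.getenv("GOOGLE_AI_STUDIO_KEY"))

    # The per-request config now carries the system prompt and safety
    # settings, replacing the old model-level generation_config and
    # system_instruction.
    config = types.GenerateContentConfig(
        temperature=0.0,
        top_p=1,
        top_k=40,
        max_output_tokens=8192,
        response_mime_type="text/plain",
        system_instruction="You are a helpful AI assistant.",
    )

    response = client.models.generate_content(
        model="gemini-2.0-flash-001",
        config=config,
        contents="Say hello.",
    )

    print(response.text)
    print(response.usage_metadata.prompt_token_count)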
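
Reviewer note: the comment block in the trustgraph-vertexai module muses that
VertexAI could be driven through the google-genai SDK as well. A sketch of
what that would look like, assuming the google-genai client's vertexai mode
and Application Default Credentials; the project and location values are
placeholders, and this path is untested here.

    from google import genai

    # Same call pattern as the GoogleAIStudio processor, but routed to
    # VertexAI. Placeholder project/location; not values from this change.
    client = genai.Client(
        vertexai=True,
        project="my-gcp-project",
        location="us-central1",
    )

    response = client.models.generate_content(
        model="gemini-2.0-flash-001",
        contents="Say hello.",
    )

    print(response.text)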