mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-06-06 13:25:13 +02:00
Updates to Google AI: (#394)
- Changed the GoogleAIStudio LLM code to match the latest documentation.
- Very minor tweak to the VertexAI LLM code: it now matches what's in the SDK docs, with no actual change to the implementation.
- Tweaked the VertexAI container build to speed it up in development.
- Added comments in the LLM code to mention which docs it was built from; the Google SDKs are confusing at the moment.
This commit is contained in:
parent
25abf802e9
commit
448819ed47
4 changed files with 83 additions and 52 deletions
|
|
@ -4,10 +4,18 @@ Simple LLM service, performs text prompt completion using GoogleAIStudio.
|
|||
Input is prompt, output is response.
|
||||
"""
|
||||
|
||||
import google.generativeai as genai
|
||||
from google.generativeai.types import HarmCategory, HarmBlockThreshold
|
||||
#
|
||||
# Using this SDK:
|
||||
# https://googleapis.github.io/python-genai/genai.html#module-genai.client
|
||||
#
|
||||
# Seems to have simpler dependencies than the 'VertexAI' service, which
|
||||
# TrustGraph implements in the trustgraph-vertexai package.
|
||||
#
|
||||
|
||||
from google import genai
|
||||
from google.genai import types
|
||||
from google.genai.types import HarmCategory, HarmBlockThreshold
|
||||
from google.api_core.exceptions import ResourceExhausted
|
||||
from prometheus_client import Histogram
|
||||
import os
|
||||
|
||||
from .... exceptions import TooManyRequests
|
||||
|
|
@ -15,7 +23,7 @@ from .... base import LlmService, LlmResult
|
|||
|
||||
default_ident = "text-completion"
|
||||
|
||||
default_model = 'gemini-1.5-flash-002'
|
||||
default_model = 'gemini-2.0-flash-001'
|
||||
default_temperature = 0.0
|
||||
default_max_output = 8192
|
||||
default_api_key = os.getenv("GOOGLE_AI_STUDIO_KEY")
|
||||
|
|
@ -40,58 +48,56 @@ class Processor(LlmService):
|
|||
}
|
||||
)
|
||||
|
||||
genai.configure(api_key=api_key)
|
||||
self.client = genai.Client(api_key=api_key)
|
||||
self.model = model
|
||||
self.temperature = temperature
|
||||
self.max_output = max_output
|
||||
|
||||
self.generation_config = {
|
||||
"temperature": temperature,
|
||||
"top_p": 1,
|
||||
"top_k": 40,
|
||||
"max_output_tokens": max_output,
|
||||
"response_mime_type": "text/plain",
|
||||
}
|
||||
|
||||
block_level = HarmBlockThreshold.BLOCK_ONLY_HIGH
|
||||
|
||||
self.safety_settings={
|
||||
HarmCategory.HARM_CATEGORY_HATE_SPEECH: block_level,
|
||||
HarmCategory.HARM_CATEGORY_HARASSMENT: block_level,
|
||||
HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: block_level,
|
||||
HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: block_level,
|
||||
self.safety_settings = [
|
||||
types.SafetySetting(
|
||||
category = HarmCategory.HARM_CATEGORY_HATE_SPEECH,
|
||||
threshold = block_level,
|
||||
),
|
||||
types.SafetySetting(
|
||||
category = HarmCategory.HARM_CATEGORY_HARASSMENT,
|
||||
threshold = block_level,
|
||||
),
|
||||
types.SafetySetting(
|
||||
category = HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
|
||||
threshold = block_level,
|
||||
),
|
||||
types.SafetySetting(
|
||||
category = HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
|
||||
threshold = block_level,
|
||||
),
|
||||
# There is a documentation conflict on whether or not
|
||||
# CIVIC_INTEGRITY is a valid category
|
||||
# HarmCategory.HARM_CATEGORY_CIVIC_INTEGRITY: block_level,
|
||||
}
|
||||
|
||||
self.llm = genai.GenerativeModel(
|
||||
model_name=model,
|
||||
generation_config=self.generation_config,
|
||||
safety_settings=self.safety_settings,
|
||||
system_instruction="You are a helpful AI assistant.",
|
||||
)
|
||||
]
|
||||
|
||||
print("Initialised", flush=True)
|
||||
|
||||
async def generate_content(self, system, prompt):
|
||||
|
||||
# FIXME: There's a system prompt above. Maybe if system changes,
|
||||
# then reset self.llm? It shouldn't do, because system prompt
|
||||
# is set system wide?
|
||||
|
||||
# Or... could keep different LLM structures for different system
|
||||
# prompts?
|
||||
|
||||
prompt = system + "\n\n" + prompt
|
||||
generation_config = types.GenerateContentConfig(
|
||||
temperature = self.temperature,
|
||||
top_p = 1,
|
||||
top_k = 40,
|
||||
max_output_tokens = self.max_output,
|
||||
response_mime_type = "text/plain",
|
||||
system_instruction = system,
|
||||
safety_settings = self.safety_settings,
|
||||
)
|
||||
|
||||
try:
|
||||
|
||||
chat_session = self.llm.start_chat(
|
||||
history=[
|
||||
]
|
||||
response = self.client.models.generate_content(
|
||||
model=self.model,
|
||||
config=generation_config,
|
||||
contents=prompt,
|
||||
)
|
||||
response = chat_session.send_message(prompt)
|
||||
|
||||
resp = response.text
|
||||
inputtokens = int(response.usage_metadata.prompt_token_count)
|
||||
|
|
@ -158,3 +164,4 @@ class Processor(LlmService):
|
|||
def run():
|
||||
|
||||
Processor.launch(default_ident, __doc__)
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue