trustgraph/trustgraph-vertexai/trustgraph/model/text_completion/vertexai/llm.py

"""
Simple LLM service, performs text prompt completion using VertexAI on
Google Cloud. Input is prompt, output is response.
"""
from google.oauth2 import service_account
import google.api_core.exceptions
import vertexai

from vertexai.preview.generative_models import (
    GenerativeModel, GenerationConfig,
    HarmCategory, HarmBlockThreshold,
)

from .... exceptions import TooManyRequests
from .... base import LlmService, LlmResult
default_ident = "text-completion"
default_model = 'gemini-1.0-pro-001'
default_region = 'us-central1'
default_temperature = 0.0
default_max_output = 8192
default_private_key = "private.json"
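
# Processor plugs the VertexAI model in behind the common LlmService base
# class, which handles flow wiring (request/response queues) and
# configuration; this subclass only implements the model-specific call.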
class Processor(LlmService):

    def __init__(self, **params):

        region = params.get("region", default_region)
        model = params.get("model", default_model)
        private_key = params.get("private_key", default_private_key)
        temperature = params.get("temperature", default_temperature)
        max_output = params.get("max_output", default_max_output)

        if private_key is None:
            raise RuntimeError("Private key file not specified")

        super(Processor, self).__init__(**params)

        self.parameters = {
            "temperature": temperature,
            "top_p": 1.0,
            "top_k": 32,
            "candidate_count": 1,
            "max_output_tokens": max_output,
        }
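
        # Note: generate_content() below passes self.generation_config,
        # not self.parameters; the two also differ in top_k (32 vs 10).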
        self.generation_config = GenerationConfig(
            temperature=temperature,
            top_p=1.0,
            top_k=10,
            candidate_count=1,
            max_output_tokens=max_output,
        )

        # BLOCK_NONE doesn't seem to work (on some models it appears to
        # require project-level allowlisting), so block only content with
        # a high probability of harm
        block_level = HarmBlockThreshold.BLOCK_ONLY_HIGH
        # block_level = HarmBlockThreshold.BLOCK_NONE

        self.safety_settings = {
            HarmCategory.HARM_CATEGORY_HARASSMENT: block_level,
            HarmCategory.HARM_CATEGORY_HATE_SPEECH: block_level,
            HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: block_level,
            HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: block_level,
        }
print("Initialise VertexAI...", flush=True)
if private_key:
credentials = (
service_account.Credentials.from_service_account_file(
private_key
)
)
else:
credentials = None
if credentials:
vertexai.init(
location=region,
credentials=credentials,
project=credentials.project_id,
)
else:
vertexai.init(
location=region
)
print(f"Initialise model {model}", flush=True)
self.llm = GenerativeModel(model)
self.model = model
print("Initialisation complete", flush=True)
    async def generate_content(self, system, prompt):

        try:

            prompt = system + "\n\n" + prompt

            response = self.llm.generate_content(
                prompt,
                generation_config=self.generation_config,
                safety_settings=self.safety_settings,
            )

            resp = LlmResult()
            resp.text = response.text
            resp.in_token = response.usage_metadata.prompt_token_count
            resp.out_token = response.usage_metadata.candidates_token_count
            resp.model = self.model

            print(f"Input Tokens: {resp.in_token}", flush=True)
            print(f"Output Tokens: {resp.out_token}", flush=True)
            print("Send response...", flush=True)

            return resp

        except google.api_core.exceptions.ResourceExhausted as e:

            print("Hit rate limit:", e, flush=True)

            # Leave rate limit retries to the base handler
            raise TooManyRequests()

        except Exception as e:

            # Apart from rate limits, treat all exceptions as unrecoverable
            print(f"Exception: {e}", flush=True)
            raise e
    @staticmethod
    def add_args(parser):

        LlmService.add_args(parser)

        parser.add_argument(
            '-m', '--model',
            default=default_model,
            help=f'LLM model (default: {default_model})'
        )

        parser.add_argument(
            '-k', '--private-key',
            help='Google Cloud private key JSON file'
        )

        parser.add_argument(
            '-r', '--region',
            default=default_region,
            help=f'Google Cloud region (default: {default_region})',
        )

        parser.add_argument(
            '-t', '--temperature',
            type=float,
            default=default_temperature,
            help=f'LLM temperature parameter (default: {default_temperature})'
        )

        parser.add_argument(
            '-x', '--max-output',
            type=int,
            default=default_max_output,
            help=f'LLM max output tokens (default: {default_max_output})'
        )
def run():
    Processor.launch(default_ident, __doc__)
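
# Hypothetical invocation via the packaged console script (the actual
# script name is defined in the package metadata, not in this module):
#
#   text-completion-vertexai -k private.json -r us-central1 \
#       -m gemini-1.0-pro-001 -t 0.0 -x 8192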