trustgraph/trustgraph-bedrock/trustgraph/model/text_completion/bedrock/llm.py
"""
Simple LLM service which performs text prompt completion using AWS Bedrock.
Input is a prompt, output is the response. Mistral is the default model.
"""
import boto3
import json
import os
import enum
import logging

from .... exceptions import TooManyRequests
from .... base import LlmService, LlmResult

# Module logger
logger = logging.getLogger(__name__)

default_ident = "text-completion"

default_model = 'mistral.mistral-large-2407-v1:0'
default_temperature = 0.0
default_max_output = 2048
default_top_p = 0.99
default_top_k = 40

# Actually, these could all just be None: there's no need to read the
# environment variables here, as Boto3 would pick them up itself if they
# were not passed in as args
default_access_key_id = os.getenv("AWS_ACCESS_KEY_ID", None)
default_secret_access_key = os.getenv("AWS_SECRET_ACCESS_KEY", None)
default_session_token = os.getenv("AWS_SESSION_TOKEN", None)
default_profile = os.getenv("AWS_PROFILE", None)
default_region = os.getenv("AWS_DEFAULT_REGION", None)

# Variant API handling depends on the model type
class ModelHandler:
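    """
    Base class for model-specific request/response handling. Subclasses
    override encode_request/decode_response with the JSON wire format
    that each Bedrock model family expects.
    """
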
    def __init__(self):
        self.temperature = default_temperature
        self.max_output = default_max_output
        self.top_p = default_top_p
        self.top_k = default_top_k

    def set_temperature(self, temperature):
        self.temperature = temperature

    def set_max_output(self, max_output):
        self.max_output = max_output

    def set_top_p(self, top_p):
        self.top_p = top_p

    def set_top_k(self, top_k):
        self.top_k = top_k

    def encode_request(self, system, prompt):
        raise RuntimeError("encode_request not implemented")

    def decode_response(self, response):
        raise RuntimeError("decode_response not implemented")

class Mistral(ModelHandler):
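    """
    Mistral models: plain prompt string with max_tokens, temperature,
    top_p and top_k parameters.
    """
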
    def __init__(self):
        super().__init__()
        self.top_p = 0.99
        self.top_k = 40

    def encode_request(self, system, prompt):
        return json.dumps({
            "prompt": f"{system}\n\n{prompt}",
            "max_tokens": self.max_output,
            "temperature": self.temperature,
            "top_p": self.top_p,
            "top_k": self.top_k,
        })

    def decode_response(self, response):
        response_body = json.loads(response.get("body").read())
        return response_body['outputs'][0]['text']

# Llama 3
class Meta(ModelHandler):
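    """
    Meta Llama models: plain prompt string; output length is capped by
    max_gen_len.
    """
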
    def __init__(self):
        super().__init__()
        self.top_p = 0.95

    def encode_request(self, system, prompt):
        return json.dumps({
            "prompt": f"{system}\n\n{prompt}",
            "max_gen_len": self.max_output,
            "temperature": self.temperature,
            "top_p": self.top_p,
        })

    def decode_response(self, response):
        model_response = json.loads(response["body"].read())
        return model_response["generation"]

class Anthropic(ModelHandler):
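    """
    Anthropic Claude models: Bedrock messages format, tagged with
    anthropic_version "bedrock-2023-05-31".
    """
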
    def __init__(self):
        super().__init__()
        self.top_p = 0.999

    def encode_request(self, system, prompt):
        return json.dumps({
            "anthropic_version": "bedrock-2023-05-31",
            "max_tokens": self.max_output,
            "temperature": self.temperature,
            "top_p": self.top_p,
            "messages": [
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "text",
                            "text": f"{system}\n\n{prompt}",
                        }
                    ]
                }
            ]
        })

    def decode_response(self, response):
        model_response = json.loads(response["body"].read())
        return model_response['content'][0]['text']

class Ai21(ModelHandler):
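    """
    AI21 models: chat-style messages request; the completion is returned
    under choices[0].message.content.
    """
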
    def __init__(self):
        super().__init__()
        self.top_p = 0.9

    def encode_request(self, system, prompt):
        return json.dumps({
            "max_tokens": self.max_output,
            "temperature": self.temperature,
            "top_p": self.top_p,
            "messages": [
                {
                    "role": "user",
                    "content": f"{system}\n\n{prompt}"
                }
            ]
        })

    def decode_response(self, response):
        content = response['body'].read()
        content_str = content.decode('utf-8')
        content_json = json.loads(content_str)
        return content_json['choices'][0]['message']['content']

class Cohere(ModelHandler):
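    """
    Cohere models: single message field; the completion is returned in
    the response's 'text' field.
    """
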
    def encode_request(self, system, prompt):
        return json.dumps({
            "max_tokens": self.max_output,
            "temperature": self.temperature,
            "message": f"{system}\n\n{prompt}",
        })

    def decode_response(self, response):
        content = response['body'].read()
        content_str = content.decode('utf-8')
        content_json = json.loads(content_str)
        return content_json['text']

# Fallback handler to use when the model family is not recognized
Default = Mistral

class Processor(LlmService):
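    """
    LLM text-completion service backed by AWS Bedrock. Selects a
    model-specific handler for each request, invokes the model, and
    returns the completion along with input/output token counts.
    """
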
    def __init__(self, **params):

        logger.debug(f"Bedrock LLM initialized with params: {params}")

        model = params.get("model", default_model)
        temperature = params.get("temperature", default_temperature)
        max_output = params.get("max_output", default_max_output)

        aws_access_key_id = params.get(
            "aws_access_key_id", default_access_key_id
        )
        aws_secret_access_key = params.get(
            "aws_secret_access_key", default_secret_access_key
        )
        aws_session_token = params.get(
            "aws_session_token", default_session_token
        )
        aws_region = params.get(
            "aws_region", default_region
        )
        aws_profile = params.get(
            "aws_profile", default_profile
        )

        super().__init__(
            **params | {
                "model": model,
                "temperature": temperature,
                "max_output": max_output,
            }
        )

        # Store default configuration
        self.default_model = model
        self.temperature = temperature
        self.max_output = max_output

        # Cache for model variants to avoid re-initialization
        self.model_variants = {}

        self.session = boto3.Session(
            aws_access_key_id=aws_access_key_id,
            aws_secret_access_key=aws_secret_access_key,
            aws_session_token=aws_session_token,
            profile_name=aws_profile,
            region_name=aws_region,
        )

        self.bedrock = self.session.client(service_name='bedrock-runtime')

        logger.info("Bedrock LLM service initialized")

    def determine_variant(self, model):

        # FIXME: Missing: Amazon models, Deepseek

        # This set of conditions deals with normal Bedrock on-demand usage
        if model.startswith("mistral"):
            return Mistral
        elif model.startswith("meta"):
            return Meta
        elif model.startswith("anthropic"):
            return Anthropic
        elif model.startswith("ai21"):
            return Ai21
        elif model.startswith("cohere"):
            return Cohere

        # The cross-region inference profiles
        if model.startswith("us.meta"):
            return Meta
        elif model.startswith("us.anthropic"):
            return Anthropic
        elif model.startswith("eu.meta"):
            return Meta
        elif model.startswith("eu.anthropic"):
            return Anthropic

        return Default

    def _get_or_create_variant(self, model_name):
        """Get cached model variant or create new one"""

        if model_name not in self.model_variants:
            logger.info(f"Creating model variant for '{model_name}'")
            variant_class = self.determine_variant(model_name)
            variant = variant_class()
            variant.set_temperature(self.temperature)
            variant.set_max_output(self.max_output)
            self.model_variants[model_name] = variant

        return self.model_variants[model_name]

    async def generate_content(self, system, prompt, model=None):
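        """
        Generate a completion for the given system/prompt pair using the
        requested model, falling back to the configured default model.
        """
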
        # Use provided model or fall back to default
        model_name = model or self.default_model
        logger.debug(f"Using model: {model_name}")

        try:

            # Get the appropriate variant for this model
            variant = self._get_or_create_variant(model_name)

            promptbody = variant.encode_request(system, prompt)

            accept = 'application/json'
            contentType = 'application/json'

            response = self.bedrock.invoke_model(
                body=promptbody,
                modelId=model_name,
                accept=accept,
                contentType=contentType
            )

            # Decode the model-specific response structure
            outputtext = variant.decode_response(response)

            metadata = response['ResponseMetadata']['HTTPHeaders']
            inputtokens = int(metadata['x-amzn-bedrock-input-token-count'])
            outputtokens = int(metadata['x-amzn-bedrock-output-token-count'])

            logger.debug(f"LLM output: {outputtext}")
            logger.info(f"Input tokens: {inputtokens}")
            logger.info(f"Output tokens: {outputtokens}")

            resp = LlmResult(
                text=outputtext,
                in_token=inputtokens,
                out_token=outputtokens,
                model=model_name
            )

            return resp

        except self.bedrock.exceptions.ThrottlingException as e:

            logger.warning(f"Hit rate limit: {e}")

            # Leave rate-limit retries to the base handler
            raise TooManyRequests()

        except Exception as e:

            # Apart from rate limits, treat all exceptions as unrecoverable
            logger.error(
                f"Bedrock LLM exception ({type(e).__name__}): {e}",
                exc_info=True
            )
            raise e

    @staticmethod
    def add_args(parser):

        LlmService.add_args(parser)

        parser.add_argument(
            '-m', '--model',
            default=default_model,
            help=f'Bedrock model (default: {default_model})'
        )

        parser.add_argument(
            '-z', '--aws-access-key-id',
            default=default_access_key_id,
            help='AWS access key ID'
        )

        parser.add_argument(
            '-k', '--aws-secret-access-key',
            default=default_secret_access_key,
            help='AWS secret access key'
        )

        parser.add_argument(
            '-r', '--aws-region',
            default=default_region,
            help='AWS region'
        )

        parser.add_argument(
            '--aws-profile', '--profile',
            default=default_profile,
            help='AWS profile name'
        )

        parser.add_argument(
            '-t', '--temperature',
            type=float,
            default=default_temperature,
            help=f'LLM temperature parameter (default: {default_temperature})'
        )

        parser.add_argument(
            '-x', '--max-output',
            type=int,
            default=default_max_output,
            help=f'LLM max output tokens (default: {default_max_output})'
        )

def run():
    Processor.launch(default_ident, __doc__)