diff --git a/trustgraph-bedrock/trustgraph/model/text_completion/bedrock/llm.py b/trustgraph-bedrock/trustgraph/model/text_completion/bedrock/llm.py
index 3bca324d..75868b56 100755
--- a/trustgraph-bedrock/trustgraph/model/text_completion/bedrock/llm.py
+++ b/trustgraph-bedrock/trustgraph/model/text_completion/bedrock/llm.py
@@ -8,6 +8,7 @@ import boto3
 import json
 from prometheus_client import Histogram
 import os
+import enum
 
 from .... schema import TextCompletionRequest, TextCompletionResponse, Error
 from .... schema import text_completion_request_queue
@@ -24,6 +25,8 @@ default_subscriber = module
 default_model = 'mistral.mistral-large-2407-v1:0'
 default_temperature = 0.0
 default_max_output = 2048
+default_top_p = 0.99
+default_top_k = 40
 
 # Actually, these could all just be None, no need to get environment
 # variables, as Boto3 would pick all these up if not passed in as args
@@ -33,6 +36,119 @@ default_session_token = os.getenv("AWS_SESSION_TOKEN", None)
 default_profile = os.getenv("AWS_PROFILE", None)
 default_region = os.getenv("AWS_DEFAULT_REGION", None)
 
+# Variant API handling depends on the model type
+
+class ModelHandler:
+    def __init__(self):
+        self.temperature = default_temperature
+        self.max_output = default_max_output
+        self.top_p = default_top_p
+        self.top_k = default_top_k
+    def set_temperature(self, temperature):
+        self.temperature = temperature
+    def set_max_output(self, max_output):
+        self.max_output = max_output
+    def set_top_p(self, top_p):
+        self.top_p = top_p
+    def set_top_k(self, top_k):
+        self.top_k = top_k
+    def encode_request(self, system, prompt):
+        raise RuntimeError("encode_request not implemented")
+    def decode_response(self, response):
+        raise RuntimeError("decode_response not implemented")
+
+class Mistral(ModelHandler):
+    def __init__(self):
+        self.top_p = 0.99
+        self.top_k = 40
+    def encode_request(self, system, prompt):
+        return json.dumps({
+            "prompt": f"{system}\n\n{prompt}",
+            "max_tokens": self.max_output,
+            "temperature": self.temperature,
+            "top_p": self.top_p,
+            "top_k": self.top_k,
+        })
+    def decode_response(self, response):
+        response_body = json.loads(response.get("body").read())
+        return response_body['outputs'][0]['text']
+
+# Llama 3
+class Meta(ModelHandler):
+    def __init__(self):
+        self.top_p = 0.95
+    def encode_request(self, system, prompt):
+        return json.dumps({
+            "prompt": f"{system}\n\n{prompt}",
+            "max_gen_len": self.max_output,
+            "temperature": self.temperature,
+            "top_p": self.top_p,
+        })
+    def decode_response(self, response):
+        model_response = json.loads(response["body"].read())
+        return model_response["generation"]
+
+class Anthropic(ModelHandler):
+    def __init__(self):
+        self.top_p = 0.999
+    def encode_request(self, system, prompt):
+        return json.dumps({
+            "anthropic_version": "bedrock-2023-05-31",
+            "max_tokens": self.max_output,
+            "temperature": self.temperature,
+            "top_p": self.top_p,
+            "messages": [
+                {
+                    "role": "user",
+                    "content": [
+                        {
+                            "type": "text",
+                            "text": f"{system}\n\n{prompt}",
+                        }
+                    ]
+                }
+            ]
+        })
+    def decode_response(self, response):
+        model_response = json.loads(response["body"].read())
+        return model_response['content'][0]['text']
+
+class Ai21(ModelHandler):
+    def __init__(self):
+        self.top_p = 0.9
+    def encode_request(self, system, prompt):
+        return json.dumps({
+            "max_tokens": self.max_output,
+            "temperature": self.temperature,
+            "top_p": self.top_p,
+            "messages": [
+                {
+                    "role": "user",
+                    "content": f"{system}\n\n{prompt}"
+                }
+            ]
+        })
+    def decode_response(self, response):
+        content = response['body'].read()
+        content_str = content.decode('utf-8')
+        content_json = json.loads(content_str)
+        return content_json['choices'][0]['message']['content']
+
+class Cohere(ModelHandler):
+    def encode_request(self, system, prompt):
+        return json.dumps({
+            "max_tokens": self.max_output,
+            "temperature": self.temperature,
+            "message": f"{system}\n\n{prompt}",
+        })
+    def decode_response(self, response):
+        content = response['body'].read()
+        content_str = content.decode('utf-8')
+        content_json = json.loads(content_str)
+        return content_json['text']
+
+Default = Mistral
+
 class Processor(ConsumerProducer):
 
     def __init__(self, **params):
@@ -97,6 +213,10 @@ class Processor(ConsumerProducer):
         self.temperature = temperature
         self.max_output = max_output
 
+        self.variant = self.determine_variant(self.model)()
+        self.variant.set_temperature(temperature)
+        self.variant.set_max_output(max_output)
+
         self.session = boto3.Session(
             aws_access_key_id=aws_access_key_id,
             aws_secret_access_key=aws_secret_access_key,
@@ -109,6 +229,34 @@ class Processor(ConsumerProducer):
 
         print("Initialised", flush=True)
 
+    def determine_variant(self, model):
+
+        # FIXME: Missing: Amazon models, Deepseek
+
+        # This set of conditions deals with normal Bedrock on-demand usage
+        if model.startswith("mistral"):
+            return Mistral
+        elif model.startswith("meta"):
+            return Meta
+        elif model.startswith("anthropic"):
+            return Anthropic
+        elif model.startswith("ai21"):
+            return Ai21
+        elif model.startswith("cohere"):
+            return Cohere
+
+        # The inference profiles
+        if model.startswith("us.meta"):
+            return Meta
+        elif model.startswith("us.anthropic"):
+            return Anthropic
+        elif model.startswith("eu.meta"):
+            return Meta
+        elif model.startswith("eu.anthropic"):
+            return Anthropic
+
+        return Default
+
     async def handle(self, msg):
 
         v = msg.value()
@@ -119,127 +267,27 @@ class Processor(ConsumerProducer):
 
         print(f"Handling prompt {id}...", flush=True)
 
-        prompt = v.system + "\n\n" + v.prompt
-
         try:
 
-            # Mistral Input Format
-            if self.model.startswith("mistral"):
-                promptbody = json.dumps({
-                    "prompt": prompt,
-                    "max_tokens": self.max_output,
-                    "temperature": self.temperature,
-                    "top_p": 0.99,
-                    "top_k": 40
-                })
-
-            # Llama 3.1 Input Format
-            elif self.model.startswith("meta"):
-                promptbody = json.dumps({
-                    "prompt": prompt,
-                    "max_gen_len": self.max_output,
-                    "temperature": self.temperature,
-                    "top_p": 0.95,
-                })
-
-            # Anthropic Input Format
-            elif self.model.startswith("anthropic"):
-                promptbody = json.dumps({
-                    "anthropic_version": "bedrock-2023-05-31",
-                    "max_tokens": self.max_output,
-                    "temperature": self.temperature,
-                    "top_p": 0.999,
-                    "messages": [
-                        {
-                            "role": "user",
-                            "content": [
-                                {
-                                    "type": "text",
-                                    "text": prompt
-                                }
-                            ]
-                        }
-                    ]
-                })
-
-            # Jamba Input Format
-            elif self.model.startswith("ai21"):
-                promptbody = json.dumps({
-                    "max_tokens": self.max_output,
-                    "temperature": self.temperature,
-                    "top_p": 0.9,
-                    "messages": [
-                        {
-                            "role": "user",
-                            "content": prompt
-                        }
-                    ]
-                })
-
-            # Cohere Input Format
-            elif self.model.startswith("cohere"):
-                promptbody = json.dumps({
-                    "max_tokens": self.max_output,
-                    "temperature": self.temperature,
-                    "message": prompt
-                })
-
-            # Use Mistral format as defualt
-            else:
-                promptbody = json.dumps({
-                    "prompt": prompt,
-                    "max_tokens": self.max_output,
-                    "temperature": self.temperature,
-                    "top_p": 0.99,
-                    "top_k": 40
-                })
+            promptbody = self.variant.encode_request(v.system, v.prompt)
 
             accept = 'application/json'
             contentType = 'application/json'
 
             with __class__.text_completion_metric.time():
                 response = self.bedrock.invoke_model(
-                    body=promptbody, modelId=self.model, accept=accept,
+                    body=promptbody,
+                    modelId=self.model,
+                    accept=accept,
                     contentType=contentType
                 )
 
-            # Mistral Response Structure
-            if self.model.startswith("mistral"):
-                response_body = json.loads(response.get("body").read())
-                outputtext = response_body['outputs'][0]['text']
-
-            # Claude Response Structure
-            elif self.model.startswith("anthropic"):
-                model_response = json.loads(response["body"].read())
-                outputtext = model_response['content'][0]['text']
-
-            # Llama 3.1 Response Structure
-            elif self.model.startswith("meta"):
-                model_response = json.loads(response["body"].read())
-                outputtext = model_response["generation"]
-
-            # Jamba Response Structure
-            elif self.model.startswith("ai21"):
-                content = response['body'].read()
-                content_str = content.decode('utf-8')
-                content_json = json.loads(content_str)
-                outputtext = content_json['choices'][0]['message']['content']
-
-            # Cohere Input Format
-            elif self.model.startswith("cohere"):
-                content = response['body'].read()
-                content_str = content.decode('utf-8')
-                content_json = json.loads(content_str)
-                outputtext = content_json['text']
-
-            # Use Mistral as default
-            else:
-                response_body = json.loads(response.get("body").read())
-                outputtext = response_body['outputs'][0]['text']
+            # Decode the model-specific response structure
+            outputtext = self.variant.decode_response(response)
 
             metadata = response['ResponseMetadata']['HTTPHeaders']
             inputtokens = int(metadata['x-amzn-bedrock-input-token-count'])
-            outputtokens = int(metadata['x-amzn-bedrock-output-token-count'])
+            outputtokens = int(metadata['x-amzn-bedrock-output-token-count'])
 
             print(outputtext, flush=True)
             print(f"Input Tokens: {inputtokens}", flush=True)
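
Note on the new variant mechanism: the patch replaces the two per-vendor if/elif chains in handle() with a single ModelHandler strategy object, chosen once in the constructor from the model ID prefix, so supporting a new vendor now means one new subclass plus one branch in determine_variant(). The sketch below is a minimal illustration of that encode/decode contract and is not part of the patch: MistralHandler is a hypothetical, trimmed stand-in for the diff's Mistral class, and fake_response only mimics the shape a boto3 invoke_model result is assumed to have (a dict whose "body" is a readable byte stream).

import io
import json

# Hypothetical stand-in for the patch's Mistral handler, reduced to the
# two methods the Processor actually calls.
class MistralHandler:
    temperature = 0.0
    max_output = 2048
    top_p = 0.99
    top_k = 40

    def encode_request(self, system, prompt):
        # Same request body shape the patch builds for Mistral models
        return json.dumps({
            "prompt": f"{system}\n\n{prompt}",
            "max_tokens": self.max_output,
            "temperature": self.temperature,
            "top_p": self.top_p,
            "top_k": self.top_k,
        })

    def decode_response(self, response):
        # Mistral responses carry the completion under outputs[0].text
        body = json.loads(response["body"].read())
        return body["outputs"][0]["text"]

handler = MistralHandler()
print(handler.encode_request("You are terse.", "Name one planet."))

# Canned response standing in for an invoke_model result: a dict whose
# "body" behaves like a readable stream of JSON bytes.
fake_response = {
    "body": io.BytesIO(
        json.dumps({"outputs": [{"text": "Mars"}]}).encode("utf-8")
    ),
}
print(handler.decode_response(fake_response))  # -> Mars

Because every handler exposes the same two methods, the handle() path is now identical for all vendors; only the JSON shapes inside each subclass differ.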