mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-04-25 00:16:23 +02:00
Support bedrock inference profiles (#314)
* Break out enums for different model types * Add model detection for inference profiles in US and EU * Encapsulate model handling, make it easier to manage
This commit is contained in:
parent
741d54cf72
commit
1db6dd5dfd
1 changed files with 155 additions and 107 deletions
|
|
@ -8,6 +8,7 @@ import boto3
|
|||
import json
|
||||
from prometheus_client import Histogram
|
||||
import os
|
||||
import enum
|
||||
|
||||
from .... schema import TextCompletionRequest, TextCompletionResponse, Error
|
||||
from .... schema import text_completion_request_queue
|
||||
|
|
@ -24,6 +25,8 @@ default_subscriber = module
|
|||
default_model = 'mistral.mistral-large-2407-v1:0'
|
||||
default_temperature = 0.0
|
||||
default_max_output = 2048
|
||||
default_top_p = 0.99
|
||||
default_top_k = 40
|
||||
|
||||
# Actually, these could all just be None, no need to get environment
|
||||
# variables, as Boto3 would pick all these up if not passed in as args
|
||||
|
|
@ -33,6 +36,119 @@ default_session_token = os.getenv("AWS_SESSION_TOKEN", None)
|
|||
default_profile = os.getenv("AWS_PROFILE", None)
|
||||
default_region = os.getenv("AWS_DEFAULT_REGION", None)
|
||||
|
||||
# Variant API handling depends on the model type
|
||||
|
||||
class ModelHandler:

    """Base class for per-vendor Bedrock request/response handling.

    Holds the sampling parameters shared by every model family.
    Subclasses override encode_request/decode_response to implement
    the vendor-specific JSON body formats, and may override
    individual sampling defaults in their __init__.
    """

    def __init__(self):
        # Defaults come from the module-level configuration values.
        self.temperature = default_temperature
        self.max_output = default_max_output
        self.top_p = default_top_p
        self.top_k = default_top_k

    def set_temperature(self, temperature):
        self.temperature = temperature

    def set_max_output(self, max_output):
        self.max_output = max_output

    def set_top_p(self, top_p):
        self.top_p = top_p

    def set_top_k(self, top_k):
        self.top_k = top_k

    def encode_request(self, system, prompt):
        """Return the JSON request body string for this model family."""
        # NotImplementedError is a RuntimeError subclass, so callers
        # catching RuntimeError continue to work.
        raise NotImplementedError("encode_request not implemented")

    def decode_response(self, response):
        """Extract the generated text from a Bedrock invoke_model response."""
        # Fixed: the message previously said "format_request not
        # implemented" (copy-paste from encode_request).
        raise NotImplementedError("decode_response not implemented")
|
||||
|
||||
class Mistral(ModelHandler):

    """Handler for Mistral models (prompt / outputs body format)."""

    def __init__(self):
        # Fixed: call super().__init__() so temperature and max_output
        # are initialised even when the setters are never invoked.
        super().__init__()
        # Vendor-preferred sampling defaults.
        self.top_p = 0.99
        self.top_k = 40

    def encode_request(self, system, prompt):
        return json.dumps({
            "prompt": f"{system}\n\n{prompt}",
            "max_tokens": self.max_output,
            "temperature": self.temperature,
            "top_p": self.top_p,
            "top_k": self.top_k,
        })

    def decode_response(self, response):
        # Response body: {"outputs": [{"text": ...}, ...]}
        response_body = json.loads(response.get("body").read())
        return response_body['outputs'][0]['text']
|
||||
|
||||
# Llama 3
class Meta(ModelHandler):

    """Handler for Meta (Llama 3) models (prompt / generation body format)."""

    def __init__(self):
        # Fixed: call super().__init__() so temperature and max_output
        # are initialised even when the setters are never invoked.
        super().__init__()
        # Vendor-preferred sampling default.
        self.top_p = 0.95

    def encode_request(self, system, prompt):
        # Note: Llama uses "max_gen_len" rather than "max_tokens".
        return json.dumps({
            "prompt": f"{system}\n\n{prompt}",
            "max_gen_len": self.max_output,
            "temperature": self.temperature,
            "top_p": self.top_p,
        })

    def decode_response(self, response):
        # Response body: {"generation": ...}
        model_response = json.loads(response["body"].read())
        return model_response["generation"]
|
||||
|
||||
class Anthropic(ModelHandler):

    """Handler for Anthropic (Claude) models (messages body format)."""

    def __init__(self):
        # Fixed: call super().__init__() so temperature and max_output
        # are initialised even when the setters are never invoked.
        super().__init__()
        # Vendor-preferred sampling default.
        self.top_p = 0.999

    def encode_request(self, system, prompt):
        return json.dumps({
            "anthropic_version": "bedrock-2023-05-31",
            "max_tokens": self.max_output,
            "temperature": self.temperature,
            "top_p": self.top_p,
            "messages": [
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "text",
                            "text": f"{system}\n\n{prompt}",
                        }
                    ]
                }
            ]
        })

    def decode_response(self, response):
        # Response body: {"content": [{"text": ...}, ...]}
        model_response = json.loads(response["body"].read())
        return model_response['content'][0]['text']
|
||||
|
||||
class Ai21(ModelHandler):

    """Handler for AI21 (Jamba) models (OpenAI-style choices body format)."""

    def __init__(self):
        # Fixed: call super().__init__() so temperature and max_output
        # are initialised even when the setters are never invoked.
        super().__init__()
        # Vendor-preferred sampling default.
        self.top_p = 0.9

    def encode_request(self, system, prompt):
        return json.dumps({
            "max_tokens": self.max_output,
            "temperature": self.temperature,
            "top_p": self.top_p,
            "messages": [
                {
                    "role": "user",
                    "content": f"{system}\n\n{prompt}"
                }
            ]
        })

    def decode_response(self, response):
        # Response body: {"choices": [{"message": {"content": ...}}, ...]}
        content = response['body'].read()
        content_str = content.decode('utf-8')
        content_json = json.loads(content_str)
        return content_json['choices'][0]['message']['content']
|
||||
|
||||
class Cohere(ModelHandler):

    """Handler for Cohere models (message / text body format).

    Uses the sampling defaults inherited from ModelHandler.
    """

    def encode_request(self, system, prompt):
        body = {
            "max_tokens": self.max_output,
            "temperature": self.temperature,
            "message": f"{system}\n\n{prompt}",
        }
        return json.dumps(body)

    def decode_response(self, response):
        # Response body: {"text": ...}
        raw = response['body'].read()
        decoded = json.loads(raw.decode('utf-8'))
        return decoded['text']
|
||||
|
||||
# Fallback handler used when a model id matches no known vendor prefix.
Default=Mistral
|
||||
|
||||
class Processor(ConsumerProducer):
|
||||
|
||||
def __init__(self, **params):
|
||||
|
|
@ -97,6 +213,10 @@ class Processor(ConsumerProducer):
|
|||
self.temperature = temperature
|
||||
self.max_output = max_output
|
||||
|
||||
self.variant = self.determine_variant(self.model)()
|
||||
self.variant.set_temperature(temperature)
|
||||
self.variant.set_max_output(max_output)
|
||||
|
||||
self.session = boto3.Session(
|
||||
aws_access_key_id=aws_access_key_id,
|
||||
aws_secret_access_key=aws_secret_access_key,
|
||||
|
|
@ -109,6 +229,34 @@ class Processor(ConsumerProducer):
|
|||
|
||||
print("Initialised", flush=True)
|
||||
|
||||
def determine_variant(self, model):
    """Map a Bedrock model id to its ModelHandler subclass.

    Handles both on-demand model ids (e.g. "anthropic...") and the
    regional inference-profile ids (e.g. "us.anthropic...",
    "eu.anthropic..."). Returns Default for unrecognised ids.
    """

    # FIXME: Missing, Amazon models, Deepseek

    # Fixed: use the `model` parameter (it was previously ignored in
    # favour of self.model; callers pass self.model, so behaviour is
    # unchanged). Tuple startswith covers both the on-demand id and
    # the US/EU inference-profile prefixes in one branch.
    if model.startswith("mistral"):
        return Mistral
    elif model.startswith(("meta", "us.meta", "eu.meta")):
        return Meta
    elif model.startswith(("anthropic", "us.anthropic", "eu.anthropic")):
        return Anthropic
    elif model.startswith("ai21"):
        return Ai21
    elif model.startswith("cohere"):
        return Cohere

    # Unknown vendor prefix: fall back to the module default handler.
    return Default
|
||||
|
||||
async def handle(self, msg):
|
||||
|
||||
v = msg.value()
|
||||
|
|
@ -119,127 +267,27 @@ class Processor(ConsumerProducer):
|
|||
|
||||
print(f"Handling prompt {id}...", flush=True)
|
||||
|
||||
prompt = v.system + "\n\n" + v.prompt
|
||||
|
||||
try:
|
||||
|
||||
# Mistral Input Format
|
||||
if self.model.startswith("mistral"):
|
||||
promptbody = json.dumps({
|
||||
"prompt": prompt,
|
||||
"max_tokens": self.max_output,
|
||||
"temperature": self.temperature,
|
||||
"top_p": 0.99,
|
||||
"top_k": 40
|
||||
})
|
||||
|
||||
# Llama 3.1 Input Format
|
||||
elif self.model.startswith("meta"):
|
||||
promptbody = json.dumps({
|
||||
"prompt": prompt,
|
||||
"max_gen_len": self.max_output,
|
||||
"temperature": self.temperature,
|
||||
"top_p": 0.95,
|
||||
})
|
||||
|
||||
# Anthropic Input Format
|
||||
elif self.model.startswith("anthropic"):
|
||||
promptbody = json.dumps({
|
||||
"anthropic_version": "bedrock-2023-05-31",
|
||||
"max_tokens": self.max_output,
|
||||
"temperature": self.temperature,
|
||||
"top_p": 0.999,
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": prompt
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
})
|
||||
|
||||
# Jamba Input Format
|
||||
elif self.model.startswith("ai21"):
|
||||
promptbody = json.dumps({
|
||||
"max_tokens": self.max_output,
|
||||
"temperature": self.temperature,
|
||||
"top_p": 0.9,
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": prompt
|
||||
}
|
||||
]
|
||||
})
|
||||
|
||||
# Cohere Input Format
|
||||
elif self.model.startswith("cohere"):
|
||||
promptbody = json.dumps({
|
||||
"max_tokens": self.max_output,
|
||||
"temperature": self.temperature,
|
||||
"message": prompt
|
||||
})
|
||||
|
||||
# Use Mistral format as defualt
|
||||
else:
|
||||
promptbody = json.dumps({
|
||||
"prompt": prompt,
|
||||
"max_tokens": self.max_output,
|
||||
"temperature": self.temperature,
|
||||
"top_p": 0.99,
|
||||
"top_k": 40
|
||||
})
|
||||
promptbody = self.variant.encode_request(v.system, v.prompt)
|
||||
|
||||
accept = 'application/json'
|
||||
contentType = 'application/json'
|
||||
|
||||
with __class__.text_completion_metric.time():
|
||||
response = self.bedrock.invoke_model(
|
||||
body=promptbody, modelId=self.model, accept=accept,
|
||||
body=promptbody,
|
||||
modelId=self.model,
|
||||
accept=accept,
|
||||
contentType=contentType
|
||||
)
|
||||
|
||||
# Mistral Response Structure
|
||||
if self.model.startswith("mistral"):
|
||||
response_body = json.loads(response.get("body").read())
|
||||
outputtext = response_body['outputs'][0]['text']
|
||||
|
||||
# Claude Response Structure
|
||||
elif self.model.startswith("anthropic"):
|
||||
model_response = json.loads(response["body"].read())
|
||||
outputtext = model_response['content'][0]['text']
|
||||
|
||||
# Llama 3.1 Response Structure
|
||||
elif self.model.startswith("meta"):
|
||||
model_response = json.loads(response["body"].read())
|
||||
outputtext = model_response["generation"]
|
||||
|
||||
# Jamba Response Structure
|
||||
elif self.model.startswith("ai21"):
|
||||
content = response['body'].read()
|
||||
content_str = content.decode('utf-8')
|
||||
content_json = json.loads(content_str)
|
||||
outputtext = content_json['choices'][0]['message']['content']
|
||||
|
||||
# Cohere Input Format
|
||||
elif self.model.startswith("cohere"):
|
||||
content = response['body'].read()
|
||||
content_str = content.decode('utf-8')
|
||||
content_json = json.loads(content_str)
|
||||
outputtext = content_json['text']
|
||||
|
||||
# Use Mistral as default
|
||||
else:
|
||||
response_body = json.loads(response.get("body").read())
|
||||
outputtext = response_body['outputs'][0]['text']
|
||||
# Response structure decode
|
||||
outputtext = self.variant.decode_response(response)
|
||||
|
||||
metadata = response['ResponseMetadata']['HTTPHeaders']
|
||||
inputtokens = int(metadata['x-amzn-bedrock-input-token-count'])
|
||||
outputtokens = int(metadata['x-amzn-bedrock-output-token-count'])
|
||||
outputtokens = int(metadata['x-amzn-bedrock-output-token-count'])
|
||||
|
||||
print(outputtext, flush=True)
|
||||
print(f"Input Tokens: {inputtokens}", flush=True)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue