Support Bedrock inference profiles (#314)

* Break out enums for different model types

* Add model detection for inference profiles in US and EU

* Encapsulate model handling, make it easier to manage
cybermaggedon 2025-03-15 12:39:15 +00:00 committed by GitHub
parent 741d54cf72
commit 1db6dd5dfd


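For orientation before the diff: the per-model if/elif branches in handle() are replaced by ModelHandler subclasses, and determine_variant() maps both on-demand model IDs and us./eu. inference-profile IDs onto those handlers. The sketch below is illustrative only; it assumes the handler classes from the diff are in scope, and the model ID, prompts and client wiring are hypothetical stand-ins, not part of this commit.

# Illustrative sketch only: relies on the ModelHandler subclasses defined
# in the diff below; the model ID and prompt text are made-up examples.
import boto3

def pick_handler(model):
    # Same prefix-matching idea as Processor.determine_variant():
    # on-demand IDs and us./eu. inference profiles share handlers.
    prefixes = [
        ("mistral", Mistral), ("meta", Meta), ("anthropic", Anthropic),
        ("ai21", Ai21), ("cohere", Cohere),
        ("us.meta", Meta), ("us.anthropic", Anthropic),
        ("eu.meta", Meta), ("eu.anthropic", Anthropic),
    ]
    for prefix, cls in prefixes:
        if model.startswith(prefix):
            return cls
    return Mistral   # the diff's Default

model = "eu.anthropic.claude-3-sonnet-20240229-v1:0"   # hypothetical profile ID
variant = pick_handler(model)()
variant.set_temperature(0.0)
variant.set_max_output(2048)

# Encode, invoke, decode: the three steps the Processor now delegates.
bedrock = boto3.Session().client("bedrock-runtime")
body = variant.encode_request("You are a helpful assistant.", "Say hello.")
response = bedrock.invoke_model(
    body=body, modelId=model,
    accept="application/json", contentType="application/json",
)
print(variant.decode_response(response))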
@@ -8,6 +8,7 @@ import boto3
import json
from prometheus_client import Histogram
import os
import enum
from .... schema import TextCompletionRequest, TextCompletionResponse, Error
from .... schema import text_completion_request_queue
@@ -24,6 +25,8 @@ default_subscriber = module
default_model = 'mistral.mistral-large-2407-v1:0'
default_temperature = 0.0
default_max_output = 2048
default_top_p = 0.99
default_top_k = 40
# Actually, these could all just be None, no need to get environment
# variables, as Boto3 would pick all these up if not passed in as args
@@ -33,6 +36,119 @@ default_session_token = os.getenv("AWS_SESSION_TOKEN", None)
default_profile = os.getenv("AWS_PROFILE", None)
default_region = os.getenv("AWS_DEFAULT_REGION", None)
# Variant API handling depends on the model type
class ModelHandler:

    def __init__(self):
        self.temperature = default_temperature
        self.max_output = default_max_output
        self.top_p = default_top_p
        self.top_k = default_top_k

    def set_temperature(self, temperature):
        self.temperature = temperature

    def set_max_output(self, max_output):
        self.max_output = max_output

    def set_top_p(self, top_p):
        self.top_p = top_p

    def set_top_k(self, top_k):
        self.top_k = top_k

    def encode_request(self, system, prompt):
        raise RuntimeError("encode_request not implemented")

    def decode_response(self, response):
        raise RuntimeError("decode_response not implemented")
class Mistral(ModelHandler):

    def __init__(self):
        super().__init__()
        self.top_p = 0.99
        self.top_k = 40

    def encode_request(self, system, prompt):
        return json.dumps({
            "prompt": f"{system}\n\n{prompt}",
            "max_tokens": self.max_output,
            "temperature": self.temperature,
            "top_p": self.top_p,
            "top_k": self.top_k,
        })

    def decode_response(self, response):
        response_body = json.loads(response.get("body").read())
        return response_body['outputs'][0]['text']
# Llama 3
class Meta(ModelHandler):

    def __init__(self):
        super().__init__()
        self.top_p = 0.95

    def encode_request(self, system, prompt):
        return json.dumps({
            "prompt": f"{system}\n\n{prompt}",
            "max_gen_len": self.max_output,
            "temperature": self.temperature,
            "top_p": self.top_p,
        })

    def decode_response(self, response):
        model_response = json.loads(response["body"].read())
        return model_response["generation"]
class Anthropic(ModelHandler):

    def __init__(self):
        super().__init__()
        self.top_p = 0.999

    def encode_request(self, system, prompt):
        return json.dumps({
            "anthropic_version": "bedrock-2023-05-31",
            "max_tokens": self.max_output,
            "temperature": self.temperature,
            "top_p": self.top_p,
            "messages": [
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "text",
                            "text": f"{system}\n\n{prompt}",
                        }
                    ]
                }
            ]
        })

    def decode_response(self, response):
        model_response = json.loads(response["body"].read())
        return model_response['content'][0]['text']
class Ai21(ModelHandler):

    def __init__(self):
        super().__init__()
        self.top_p = 0.9

    def encode_request(self, system, prompt):
        return json.dumps({
            "max_tokens": self.max_output,
            "temperature": self.temperature,
            "top_p": self.top_p,
            "messages": [
                {
                    "role": "user",
                    "content": f"{system}\n\n{prompt}"
                }
            ]
        })

    def decode_response(self, response):
        content = response['body'].read()
        content_str = content.decode('utf-8')
        content_json = json.loads(content_str)
        return content_json['choices'][0]['message']['content']
class Cohere(ModelHandler):

    def encode_request(self, system, prompt):
        return json.dumps({
            "max_tokens": self.max_output,
            "temperature": self.temperature,
            "message": f"{system}\n\n{prompt}",
        })

    def decode_response(self, response):
        content = response['body'].read()
        content_str = content.decode('utf-8')
        content_json = json.loads(content_str)
        return content_json['text']

Default = Mistral
class Processor(ConsumerProducer):

    def __init__(self, **params):
@@ -97,6 +213,10 @@ class Processor(ConsumerProducer):
        self.temperature = temperature
        self.max_output = max_output

        self.variant = self.determine_variant(self.model)()
        self.variant.set_temperature(temperature)
        self.variant.set_max_output(max_output)

        self.session = boto3.Session(
            aws_access_key_id=aws_access_key_id,
            aws_secret_access_key=aws_secret_access_key,
@@ -109,6 +229,34 @@
print("Initialised", flush=True)
def determine_variant(self, model):
# FIXME: Missing, Amazon models, Deepseek
# This set of conditions deals with normal bedrock on-demand usage
if self.model.startswith("mistral"):
return Mistral
elif self.model.startswith("meta"):
return Meta
elif self.model.startswith("anthropic"):
return Anthropic
elif self.model.startswith("ai21"):
return Ai21
elif self.model.startswith("cohere"):
return Cohere
# The inference profiles
if self.model.startswith("us.meta"):
return Meta
elif self.model.startswith("us.anthropic"):
return Anthropic
elif self.model.startswith("eu.meta"):
return Meta
elif self.model.startswith("eu.anthropic"):
return Anthropic
return Default
    async def handle(self, msg):

        v = msg.value()
@@ -119,127 +267,27 @@
print(f"Handling prompt {id}...", flush=True)
prompt = v.system + "\n\n" + v.prompt
try:
# Mistral Input Format
if self.model.startswith("mistral"):
promptbody = json.dumps({
"prompt": prompt,
"max_tokens": self.max_output,
"temperature": self.temperature,
"top_p": 0.99,
"top_k": 40
})
# Llama 3.1 Input Format
elif self.model.startswith("meta"):
promptbody = json.dumps({
"prompt": prompt,
"max_gen_len": self.max_output,
"temperature": self.temperature,
"top_p": 0.95,
})
# Anthropic Input Format
elif self.model.startswith("anthropic"):
promptbody = json.dumps({
"anthropic_version": "bedrock-2023-05-31",
"max_tokens": self.max_output,
"temperature": self.temperature,
"top_p": 0.999,
"messages": [
{
"role": "user",
"content": [
{
"type": "text",
"text": prompt
}
]
}
]
})
# Jamba Input Format
elif self.model.startswith("ai21"):
promptbody = json.dumps({
"max_tokens": self.max_output,
"temperature": self.temperature,
"top_p": 0.9,
"messages": [
{
"role": "user",
"content": prompt
}
]
})
# Cohere Input Format
elif self.model.startswith("cohere"):
promptbody = json.dumps({
"max_tokens": self.max_output,
"temperature": self.temperature,
"message": prompt
})
            # Use Mistral format as default
            else:
                promptbody = json.dumps({
                    "prompt": prompt,
                    "max_tokens": self.max_output,
                    "temperature": self.temperature,
                    "top_p": 0.99,
                    "top_k": 40
                })
            promptbody = self.variant.encode_request(v.system, v.prompt)
            accept = 'application/json'
            contentType = 'application/json'
            with __class__.text_completion_metric.time():
                response = self.bedrock.invoke_model(
                    body=promptbody, modelId=self.model, accept=accept,
                    body=promptbody,
                    modelId=self.model,
                    accept=accept,
                    contentType=contentType
                )
            # Mistral Response Structure
            if self.model.startswith("mistral"):
                response_body = json.loads(response.get("body").read())
                outputtext = response_body['outputs'][0]['text']
            # Claude Response Structure
            elif self.model.startswith("anthropic"):
                model_response = json.loads(response["body"].read())
                outputtext = model_response['content'][0]['text']
            # Llama 3.1 Response Structure
            elif self.model.startswith("meta"):
                model_response = json.loads(response["body"].read())
                outputtext = model_response["generation"]
            # Jamba Response Structure
            elif self.model.startswith("ai21"):
                content = response['body'].read()
                content_str = content.decode('utf-8')
                content_json = json.loads(content_str)
                outputtext = content_json['choices'][0]['message']['content']
            # Cohere Response Structure
            elif self.model.startswith("cohere"):
                content = response['body'].read()
                content_str = content.decode('utf-8')
                content_json = json.loads(content_str)
                outputtext = content_json['text']
            # Use Mistral as default
            else:
                response_body = json.loads(response.get("body").read())
                outputtext = response_body['outputs'][0]['text']
            # Response structure decode
            outputtext = self.variant.decode_response(response)
            metadata = response['ResponseMetadata']['HTTPHeaders']
            inputtokens = int(metadata['x-amzn-bedrock-input-token-count'])
            outputtokens = int(metadata['x-amzn-bedrock-output-token-count'])
            print(outputtext, flush=True)
            print(f"Input Tokens: {inputtokens}", flush=True)