Jamba and Cohere support for Bedrock (#34)

This commit is contained in:
Jack Colquitt 2024-08-25 12:44:00 -07:00 committed by GitHub
parent d69de52b04
commit 980a4c5b93
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -117,7 +117,29 @@ class Processor(ConsumerProducer):
]
})
# Use Mistral format as default
# Jamba Input Format
elif self.model.startswith("ai21"):
promptbody = json.dumps({
"max_tokens": self.max_output,
"temperature": self.temperature,
"top_p": 0.9,
"messages": [
{
"role": "user",
"content": prompt
}
]
})
# Cohere Input Format
elif self.model.startswith("cohere"):
promptbody = json.dumps({
"max_tokens": self.max_output,
"temperature": self.temperature,
"message": prompt
})
# Use Mistral format as default
else:
promptbody = json.dumps({
"prompt": prompt,
@ -151,6 +173,20 @@ class Processor(ConsumerProducer):
model_response = json.loads(response["body"].read())
outputtext = model_response["generation"]
# Jamba Response Structure
elif self.model.startswith("ai21"):
content = response['body'].read()
content_str = content.decode('utf-8')
content_json = json.loads(content_str)
outputtext = content_json['choices'][0]['message']['content']
# Cohere Response Structure
elif self.model.startswith("cohere"):
content = response['body'].read()
content_str = content.decode('utf-8')
content_json = json.loads(content_str)
outputtext = content_json['text']
# Use Mistral as default
else:
response_body = json.loads(response.get("body").read())