mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-04-28 18:06:21 +02:00
Jamba and Cohere support for Bedrock (#34)
This commit is contained in:
parent
d69de52b04
commit
980a4c5b93
1 changed files with 37 additions and 1 deletions
|
|
@ -117,7 +117,29 @@ class Processor(ConsumerProducer):
|
|||
]
|
||||
})
|
||||
|
||||
# Use Mistral format as defualt
|
||||
# Jamba Input Format
|
||||
elif self.model.startswith("ai21"):
|
||||
promptbody = json.dumps({
|
||||
"max_tokens": self.max_output,
|
||||
"temperature": self.temperature,
|
||||
"top_p": 0.9,
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": prompt
|
||||
}
|
||||
]
|
||||
})
|
||||
|
||||
# Cohere Input Format
|
||||
elif self.model.startswith("cohere"):
|
||||
promptbody = json.dumps({
|
||||
"max_tokens": self.max_output,
|
||||
"temperature": self.temperature,
|
||||
"message": prompt
|
||||
})
|
||||
|
||||
# Use Mistral format as defualt
|
||||
else:
|
||||
promptbody = json.dumps({
|
||||
"prompt": prompt,
|
||||
|
|
@ -151,6 +173,20 @@ class Processor(ConsumerProducer):
|
|||
model_response = json.loads(response["body"].read())
|
||||
outputtext = model_response["generation"]
|
||||
|
||||
# Jamba Response Structure
|
||||
elif self.model.startswith("ai21"):
|
||||
content = response['body'].read()
|
||||
content_str = content.decode('utf-8')
|
||||
content_json = json.loads(content_str)
|
||||
outputtext = content_json['choices'][0]['message']['content']
|
||||
|
||||
# Cohere Input Format
|
||||
elif self.model.startswith("cohere"):
|
||||
content = response['body'].read()
|
||||
content_str = content.decode('utf-8')
|
||||
content_json = json.loads(content_str)
|
||||
outputtext = content_json['text']
|
||||
|
||||
# Use Mistral as default
|
||||
else:
|
||||
response_body = json.loads(response.get("body").read())
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue