mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-05-03 04:12:37 +02:00
Jamba and Cohere support for Bedrock (#34)
This commit is contained in:
parent
d69de52b04
commit
980a4c5b93
1 changed files with 37 additions and 1 deletions
|
|
@ -117,7 +117,29 @@ class Processor(ConsumerProducer):
|
||||||
]
|
]
|
||||||
})
|
})
|
||||||
|
|
||||||
# Use Mistral format as defualt
|
# Jamba Input Format
|
||||||
|
elif self.model.startswith("ai21"):
|
||||||
|
promptbody = json.dumps({
|
||||||
|
"max_tokens": self.max_output,
|
||||||
|
"temperature": self.temperature,
|
||||||
|
"top_p": 0.9,
|
||||||
|
"messages": [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": prompt
|
||||||
|
}
|
||||||
|
]
|
||||||
|
})
|
||||||
|
|
||||||
|
# Cohere Input Format
|
||||||
|
elif self.model.startswith("cohere"):
|
||||||
|
promptbody = json.dumps({
|
||||||
|
"max_tokens": self.max_output,
|
||||||
|
"temperature": self.temperature,
|
||||||
|
"message": prompt
|
||||||
|
})
|
||||||
|
|
||||||
|
# Use Mistral format as defualt
|
||||||
else:
|
else:
|
||||||
promptbody = json.dumps({
|
promptbody = json.dumps({
|
||||||
"prompt": prompt,
|
"prompt": prompt,
|
||||||
|
|
@ -151,6 +173,20 @@ class Processor(ConsumerProducer):
|
||||||
model_response = json.loads(response["body"].read())
|
model_response = json.loads(response["body"].read())
|
||||||
outputtext = model_response["generation"]
|
outputtext = model_response["generation"]
|
||||||
|
|
||||||
|
# Jamba Response Structure
|
||||||
|
elif self.model.startswith("ai21"):
|
||||||
|
content = response['body'].read()
|
||||||
|
content_str = content.decode('utf-8')
|
||||||
|
content_json = json.loads(content_str)
|
||||||
|
outputtext = content_json['choices'][0]['message']['content']
|
||||||
|
|
||||||
|
# Cohere Input Format
|
||||||
|
elif self.model.startswith("cohere"):
|
||||||
|
content = response['body'].read()
|
||||||
|
content_str = content.decode('utf-8')
|
||||||
|
content_json = json.loads(content_str)
|
||||||
|
outputtext = content_json['text']
|
||||||
|
|
||||||
# Use Mistral as default
|
# Use Mistral as default
|
||||||
else:
|
else:
|
||||||
response_body = json.loads(response.get("body").read())
|
response_body = json.loads(response.get("body").read())
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue