diff --git a/trustgraph/model/text_completion/bedrock/llm.py b/trustgraph/model/text_completion/bedrock/llm.py
index 4da0ef14..61e52ae0 100755
--- a/trustgraph/model/text_completion/bedrock/llm.py
+++ b/trustgraph/model/text_completion/bedrock/llm.py
@@ -117,7 +117,29 @@ class Processor(ConsumerProducer):
                 ]
             })
 
-        # Use Mistral format as defualt
+        # Jamba Input Format
+        elif self.model.startswith("ai21"):
+            promptbody = json.dumps({
+                "max_tokens": self.max_output,
+                "temperature": self.temperature,
+                "top_p": 0.9,
+                "messages": [
+                    {
+                        "role": "user",
+                        "content": prompt
+                    }
+                ]
+            })
+
+        # Cohere Input Format
+        elif self.model.startswith("cohere"):
+            promptbody = json.dumps({
+                "max_tokens": self.max_output,
+                "temperature": self.temperature,
+                "message": prompt
+            })
+
+        # Use Mistral format as default
         else:
             promptbody = json.dumps({
                 "prompt": prompt,
@@ -151,6 +173,20 @@ class Processor(ConsumerProducer):
             model_response = json.loads(response["body"].read())
             outputtext = model_response["generation"]
 
+        # Jamba Response Structure
+        elif self.model.startswith("ai21"):
+            content = response['body'].read()
+            content_str = content.decode('utf-8')
+            content_json = json.loads(content_str)
+            outputtext = content_json['choices'][0]['message']['content']
+
+        # Cohere Response Structure
+        elif self.model.startswith("cohere"):
+            content = response['body'].read()
+            content_str = content.decode('utf-8')
+            content_json = json.loads(content_str)
+            outputtext = content_json['text']
+
         # Use Mistral as default
         else:
             response_body = json.loads(response.get("body").read())
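
Reviewer note: a minimal standalone sketch of what the two new branches send to and read back from the Bedrock runtime, outside the Processor class. It assumes a boto3 "bedrock-runtime" client; the model IDs, region, and sampling values below are illustrative assumptions, not values the Processor itself uses (those come from its config).

import json
import boto3

# Assumed region; any region with Bedrock access to these model families works.
client = boto3.client("bedrock-runtime", region_name="us-west-2")

# Jamba (ai21.*): chat-style "messages" request, OpenAI-like "choices" response.
body = json.dumps({
    "max_tokens": 512,
    "temperature": 0.2,
    "top_p": 0.9,
    "messages": [{"role": "user", "content": "Hello"}],
})
response = client.invoke_model(modelId="ai21.jamba-instruct-v1:0", body=body)
out = json.loads(response["body"].read().decode("utf-8"))
print(out["choices"][0]["message"]["content"])

# Cohere Command (cohere.*): single "message" string in, flat "text" field out.
body = json.dumps({
    "max_tokens": 512,
    "temperature": 0.2,
    "message": "Hello",
})
response = client.invoke_model(modelId="cohere.command-r-v1:0", body=body)
out = json.loads(response["body"].read().decode("utf-8"))
print(out["text"])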