From 980a4c5b9320c4107bc5756e41d7d0cfedda2d3a Mon Sep 17 00:00:00 2001 From: Jack Colquitt <126733989+JackColquitt@users.noreply.github.com> Date: Sun, 25 Aug 2024 12:44:00 -0700 Subject: [PATCH] Jamba and Cohere support for Bedrock (#34) --- .../model/text_completion/bedrock/llm.py | 38 ++++++++++++++++++- 1 file changed, 37 insertions(+), 1 deletion(-) diff --git a/trustgraph/model/text_completion/bedrock/llm.py b/trustgraph/model/text_completion/bedrock/llm.py index 4da0ef14..61e52ae0 100755 --- a/trustgraph/model/text_completion/bedrock/llm.py +++ b/trustgraph/model/text_completion/bedrock/llm.py @@ -117,7 +117,29 @@ class Processor(ConsumerProducer): ] }) - # Use Mistral format as defualt + # Jamba Input Format + elif self.model.startswith("ai21"): + promptbody = json.dumps({ + "max_tokens": self.max_output, + "temperature": self.temperature, + "top_p": 0.9, + "messages": [ + { + "role": "user", + "content": prompt + } + ] + }) + + # Cohere Input Format + elif self.model.startswith("cohere"): + promptbody = json.dumps({ + "max_tokens": self.max_output, + "temperature": self.temperature, + "message": prompt + }) + + # Use Mistral format as default else: promptbody = json.dumps({ "prompt": prompt, @@ -151,6 +173,20 @@ class Processor(ConsumerProducer): model_response = json.loads(response["body"].read()) outputtext = model_response["generation"] + # Jamba Response Structure + elif self.model.startswith("ai21"): + content = response['body'].read() + content_str = content.decode('utf-8') + content_json = json.loads(content_str) + outputtext = content_json['choices'][0]['message']['content'] + + # Cohere Response Structure + elif self.model.startswith("cohere"): + content = response['body'].read() + content_str = content.decode('utf-8') + content_json = json.loads(content_str) + outputtext = content_json['text'] + # Use Mistral as default else: response_body = json.loads(response.get("body").read())