From 90fe017240475fc2782ba53bcb0b6de206d39808 Mon Sep 17 00:00:00 2001
From: Cyber MacGeddon
Date: Wed, 7 Aug 2024 20:37:24 +0100
Subject: [PATCH] Bedrock LLM fix

---
 .../model/text_completion/bedrock/llm.py      | 55 ++++++++++++++++---
 1 file changed, 48 insertions(+), 7 deletions(-)

diff --git a/trustgraph/model/text_completion/bedrock/llm.py b/trustgraph/model/text_completion/bedrock/llm.py
index 21ff73d8..8a158183 100755
--- a/trustgraph/model/text_completion/bedrock/llm.py
+++ b/trustgraph/model/text_completion/bedrock/llm.py
@@ -69,13 +69,54 @@ class Processor(ConsumerProducer):
 
         prompt = v.prompt
 
-        promptbody = json.dumps({
-            "prompt": prompt,
-            "max_tokens": 8192,
-            "temperature": 0.0,
-            "top_p": 0.99,
-            "top_k": 40
-        })
+        # Mistral Input Format
+        if self.model.startswith("mistral"):
+            promptbody = json.dumps({
+                "prompt": prompt,
+                "max_tokens": 8192,
+                "temperature": 0.0,
+                "top_p": 0.99,
+                "top_k": 40
+            })
+
+        # Llama 3.1 Input Format
+        elif self.model.startswith("meta"):
+            promptbody = json.dumps({
+                "prompt": prompt,
+                "max_gen_len": 2048,
+                "temperature": 0.0,
+                "top_p": 0.95,
+            })
+
+        # Anthropic Input Format
+        elif self.model.startswith("anthropic"):
+            promptbody = json.dumps({
+                "anthropic_version": "bedrock-2023-05-31",
+                "max_tokens": 8192,
+                "temperature": 0,
+                "top_p": 0.999,
+                "messages": [
+                    {
+                        "role": "user",
+                        "content": [
+                            {
+                                "type": "text",
+                                "text": prompt
+                            }
+                        ]
+                    }
+                ]
+            })
+
+        # Use Mistral format as default
+        else:
+            promptbody = json.dumps({
+                "prompt": prompt,
+                "max_tokens": 8192,
+                "temperature": 0.0,
+                "top_p": 0.99,
+                "top_k": 40
+            })
 
         accept = 'application/json'
         contentType = 'application/json'