From 61141e97b9c0177c88d208cdc0e8dff8e1f2e095 Mon Sep 17 00:00:00 2001 From: JGalego Date: Sat, 17 Aug 2024 03:57:38 +0100 Subject: [PATCH] Added support for Command R/R+ --- metagpt/provider/bedrock/bedrock_provider.py | 48 ++++++++++++++++++-- 1 file changed, 44 insertions(+), 4 deletions(-) diff --git a/metagpt/provider/bedrock/bedrock_provider.py b/metagpt/provider/bedrock/bedrock_provider.py index 3222d2c76..90475bf41 100644 --- a/metagpt/provider/bedrock/bedrock_provider.py +++ b/metagpt/provider/bedrock/bedrock_provider.py @@ -57,15 +57,52 @@ class AnthropicProvider(BaseBedrockProvider): class CohereProvider(BaseBedrockProvider): - # See https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-cohere-command.html + # For more information, see + # (Command) https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-cohere-command.html + # (Command R/R+) https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-cohere-command-r-plus.html + + def __init__(self, model_name: str) -> None: + self.model_name = model_name def _get_completion_from_dict(self, rsp_dict: dict) -> str: return rsp_dict["generations"][0]["text"] + def messages_to_prompt(self, messages: list[dict]) -> str: + if "command-r" in self.model_name: + role_map = { + "user": "USER", + "assistant": "CHATBOT", + "system": "USER" + } + messages = list( + map( + lambda message: { + "role": role_map[message["role"]], + "message": message["content"] + }, + messages + ) + ) + return messages + else: + """[{"role": "user", "content": msg}] to user: etc.""" + return "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages]) + def get_request_body(self, messages: list[dict], generate_kwargs, *args, **kwargs): - body = json.dumps( - {"prompt": self.messages_to_prompt(messages), "stream": kwargs.get("stream", False), **generate_kwargs} - ) + prompt = self.messages_to_prompt(messages) + if "command-r" in self.model_name: + chat_history, message = prompt[:-1], prompt[-1]["message"] + body = json.dumps({ + "message": message, + "chat_history": chat_history, + **generate_kwargs + }) + else: + body = json.dumps({ + "prompt": prompt, + "stream": kwargs.get("stream", False), + **generate_kwargs + }) return body def get_choice_text_from_stream(self, event) -> str: @@ -166,4 +203,7 @@ def get_provider(model_id: str): elif provider == "ai21": # distinguish between j2 and jamba return PROVIDERS[provider](model_name.split("-")[0]) + elif provider == "cohere": + # distinguish between R/R+ and older models + return PROVIDERS[provider](model_name) return PROVIDERS[provider]()