mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-05-24 14:15:17 +02:00
Added support for Command R/R+
This commit is contained in:
parent
4d59dd69cc
commit
61141e97b9
1 changed files with 44 additions and 4 deletions
|
|
@ -57,15 +57,52 @@ class AnthropicProvider(BaseBedrockProvider):
|
|||
|
||||
|
||||
class CohereProvider(BaseBedrockProvider):
|
||||
# See https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-cohere-command.html
|
||||
# For more information, see
|
||||
# (Command) https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-cohere-command.html
|
||||
# (Command R/R+) https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-cohere-command-r-plus.html
|
||||
|
||||
def __init__(self, model_name: str) -> None:
|
||||
self.model_name = model_name
|
||||
|
||||
def _get_completion_from_dict(self, rsp_dict: dict) -> str:
|
||||
return rsp_dict["generations"][0]["text"]
|
||||
|
||||
def messages_to_prompt(self, messages: list[dict]) -> str:
|
||||
if "command-r" in self.model_name:
|
||||
role_map = {
|
||||
"user": "USER",
|
||||
"assistant": "CHATBOT",
|
||||
"system": "USER"
|
||||
}
|
||||
messages = list(
|
||||
map(
|
||||
lambda message: {
|
||||
"role": role_map[message["role"]],
|
||||
"message": message["content"]
|
||||
},
|
||||
messages
|
||||
)
|
||||
)
|
||||
return messages
|
||||
else:
|
||||
"""[{"role": "user", "content": msg}] to user: <msg> etc."""
|
||||
return "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages])
|
||||
|
||||
def get_request_body(self, messages: list[dict], generate_kwargs, *args, **kwargs):
|
||||
body = json.dumps(
|
||||
{"prompt": self.messages_to_prompt(messages), "stream": kwargs.get("stream", False), **generate_kwargs}
|
||||
)
|
||||
prompt = self.messages_to_prompt(messages)
|
||||
if "command-r" in self.model_name:
|
||||
chat_history, message = prompt[:-1], prompt[-1]["message"]
|
||||
body = json.dumps({
|
||||
"message": message,
|
||||
"chat_history": chat_history,
|
||||
**generate_kwargs
|
||||
})
|
||||
else:
|
||||
body = json.dumps({
|
||||
"prompt": prompt,
|
||||
"stream": kwargs.get("stream", False),
|
||||
**generate_kwargs
|
||||
})
|
||||
return body
|
||||
|
||||
def get_choice_text_from_stream(self, event) -> str:
|
||||
|
|
@ -166,4 +203,7 @@ def get_provider(model_id: str):
|
|||
elif provider == "ai21":
|
||||
# distinguish between j2 and jamba
|
||||
return PROVIDERS[provider](model_name.split("-")[0])
|
||||
elif provider == "cohere":
|
||||
# distinguish between R/R+ and older models
|
||||
return PROVIDERS[provider](model_name)
|
||||
return PROVIDERS[provider]()
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue