Added support for Command R/R+

This commit is contained in:
JGalego 2024-08-17 03:57:38 +01:00
parent 4d59dd69cc
commit 61141e97b9

View file

@ -57,15 +57,52 @@ class AnthropicProvider(BaseBedrockProvider):
class CohereProvider(BaseBedrockProvider):
# See https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-cohere-command.html
# For more information, see
# (Command) https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-cohere-command.html
# (Command R/R+) https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-cohere-command-r-plus.html
def __init__(self, model_name: str) -> None:
self.model_name = model_name
def _get_completion_from_dict(self, rsp_dict: dict) -> str:
return rsp_dict["generations"][0]["text"]
def messages_to_prompt(self, messages: list[dict]) -> str:
if "command-r" in self.model_name:
role_map = {
"user": "USER",
"assistant": "CHATBOT",
"system": "USER"
}
messages = list(
map(
lambda message: {
"role": role_map[message["role"]],
"message": message["content"]
},
messages
)
)
return messages
else:
"""[{"role": "user", "content": msg}] to user: <msg> etc."""
return "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages])
def get_request_body(self, messages: list[dict], generate_kwargs, *args, **kwargs):
body = json.dumps(
{"prompt": self.messages_to_prompt(messages), "stream": kwargs.get("stream", False), **generate_kwargs}
)
prompt = self.messages_to_prompt(messages)
if "command-r" in self.model_name:
chat_history, message = prompt[:-1], prompt[-1]["message"]
body = json.dumps({
"message": message,
"chat_history": chat_history,
**generate_kwargs
})
else:
body = json.dumps({
"prompt": prompt,
"stream": kwargs.get("stream", False),
**generate_kwargs
})
return body
def get_choice_text_from_stream(self, event) -> str:
@ -166,4 +203,7 @@ def get_provider(model_id: str):
elif provider == "ai21":
# distinguish between j2 and jamba
return PROVIDERS[provider](model_name.split("-")[0])
elif provider == "cohere":
# distinguish between R/R+ and older models
return PROVIDERS[provider](model_name)
return PROVIDERS[provider]()