diff --git a/metagpt/provider/bedrock/bedrock_provider.py b/metagpt/provider/bedrock/bedrock_provider.py index 729697ad7..850b57c4f 100644 --- a/metagpt/provider/bedrock/bedrock_provider.py +++ b/metagpt/provider/bedrock/bedrock_provider.py @@ -1,7 +1,7 @@ import json from typing import Literal from metagpt.provider.bedrock.base_provider import BaseBedrockProvider -from metagpt.provider.bedrock.utils import messages_to_prompt_llama2, messages_to_prompt_llama3 +from metagpt.provider.bedrock.utils import messages_to_prompt_llama2, messages_to_prompt_llama3, messages_to_prompt_claude class MistralProvider(BaseBedrockProvider): @@ -16,6 +16,8 @@ class MistralProvider(BaseBedrockProvider): class AnthropicProvider(BaseBedrockProvider): # See https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html + def messages_to_prompt(self, messages: list[dict]) -> str: + return messages_to_prompt_claude(messages) def get_request_body(self, messages: list[dict], **generate_kwargs): body = json.dumps( diff --git a/metagpt/provider/bedrock/utils.py b/metagpt/provider/bedrock/utils.py index 61778e8e8..47a23caeb 100644 --- a/metagpt/provider/bedrock/utils.py +++ b/metagpt/provider/bedrock/utils.py @@ -70,5 +70,18 @@ def messages_to_prompt_llama3(messages: list[dict]): return prompt +def messages_to_prompt_claude(messages: list[dict]): + GENERAL_TEMPLATE = "\n\n{role}: {content}" + prompt = "" + for message in messages: + role = message["role"] + content = message["content"] + prompt += GENERAL_TEMPLATE.format(role=role, content=content) + if role != "assistant": + prompt += f"\n\nAssistant:" + return prompt + + def get_max_tokens(model_id) -> int: return (NOT_SUUPORT_STREAM_MODELS | SUPPORT_STREAM_MODELS)[model_id] +