diff --git a/metagpt/provider/bedrock/amazon_bedrock_api.py b/metagpt/provider/bedrock/amazon_bedrock_api.py
index e4382f7bd..a5cacec8c 100644
--- a/metagpt/provider/bedrock/amazon_bedrock_api.py
+++ b/metagpt/provider/bedrock/amazon_bedrock_api.py
@@ -96,4 +96,3 @@ class AmazonBedrockLLM(BaseLLM):
async def _achat_completion_stream(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT):
return self._chat_completion_stream(messages)
-
diff --git a/metagpt/provider/bedrock/bedrock_provider.py b/metagpt/provider/bedrock/bedrock_provider.py
index 2aa90c7ee..729697ad7 100644
--- a/metagpt/provider/bedrock/bedrock_provider.py
+++ b/metagpt/provider/bedrock/bedrock_provider.py
@@ -1,13 +1,14 @@
import json
+from typing import Literal
from metagpt.provider.bedrock.base_provider import BaseBedrockProvider
-from metagpt.provider.bedrock.utils import messages_to_prompt_llama
+from metagpt.provider.bedrock.utils import messages_to_prompt_llama2, messages_to_prompt_llama3
class MistralProvider(BaseBedrockProvider):
# See https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-mistral.html
def messages_to_prompt(self, messages: list[dict]):
- return messages_to_prompt_llama(messages)
+ return messages_to_prompt_llama2(messages)
def _get_completion_from_dict(self, rsp_dict: dict) -> str:
return rsp_dict["outputs"][0]["text"]
@@ -36,8 +37,14 @@ class MetaProvider(BaseBedrockProvider):
# See https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-meta.html
max_tokens_field_name = "max_gen_len"
+ def __init__(self, llama_version: Literal["llama2", "llama3"]) -> None:
+ self.llama_version = llama_version  # read by messages_to_prompt to pick the llama2 or llama3 template
+
def messages_to_prompt(self, messages: list[dict]):
- return messages_to_prompt_llama(messages)
+        # Any version other than "llama2" is rendered with the Llama 3 template.
+        use_llama2 = self.llama_version == "llama2"
+        render = messages_to_prompt_llama2 if use_llama2 else messages_to_prompt_llama3
+        return render(messages)
def _get_completion_from_dict(self, rsp_dict: dict) -> str:
return rsp_dict["generation"]
@@ -72,17 +79,20 @@ class AmazonProvider(BaseBedrockProvider):
PROVIDERS = {
- "mistral": MistralProvider(),
- "meta": MetaProvider(),
- "ai21": Ai21Provider(),
- "cohere": CohereProvider(),
- "anthropic": AnthropicProvider(),
- "amazon": AmazonProvider()
+ "mistral": MistralProvider,  # values are provider classes; get_provider instantiates one per call
+ "meta": MetaProvider,
+ "ai21": Ai21Provider,
+ "cohere": CohereProvider,
+ "anthropic": AnthropicProvider,
+ "amazon": AmazonProvider
}
def get_provider(model_id: str):
- model_name = model_id.split(".")[0] # meta、mistral……
- if model_name not in PROVIDERS:
- raise KeyError(f"{model_name} is not supported!")
- return PROVIDERS[model_name]
+    provider, _, model_name = model_id.partition(".")  # provider prefix, e.g. "meta" or "mistral"
+    if provider not in PROVIDERS:
+        raise KeyError(f"{provider} is not supported!")
+    if provider == "meta":
+        # The first six characters of the model name ("llama2"/"llama3") select the prompt template.
+        return PROVIDERS[provider](model_name[:6])
+    return PROVIDERS[provider]()
diff --git a/metagpt/provider/bedrock/utils.py b/metagpt/provider/bedrock/utils.py
index 2df0bf163..61778e8e8 100644
--- a/metagpt/provider/bedrock/utils.py
+++ b/metagpt/provider/bedrock/utils.py
@@ -29,8 +29,10 @@ SUPPORT_STREAM_MODELS = {
"mistral.mistral-large-2402-v1:0": 32000,
}
+# TODO: use a general function for constructing chat templates.
-def messages_to_prompt_llama(messages: list[dict]):
+
+def messages_to_prompt_llama2(messages: list[dict]):
BOS, EOS = "", ""
B_INST, E_INST = "[INST]", "[/INST]"
B_SYS, E_SYS = "<>\n", "\n<>\n\n"
@@ -53,7 +55,20 @@ def messages_to_prompt_llama(messages: list[dict]):
return prompt
+def messages_to_prompt_llama3(messages: list[dict]) -> str:
+    """Build a Llama 3 instruct prompt from chat messages (dicts with "role" and "content" keys)."""
+    BOS = "<|begin_of_text|>"
+    GENERAL_TEMPLATE = "<|start_header_id|>{role}<|end_header_id|>\n\n{content}<|eot_id|>"
+
+    prompt = BOS + "".join(
+        GENERAL_TEMPLATE.format(role=message["role"], content=message["content"]) for message in messages
+    )
+    # Open the assistant turn exactly once, after the last message, unless the
+    # conversation already ends with an assistant message (also safe for an empty list).
+    if not messages or messages[-1]["role"] != "assistant":
+        prompt += "<|start_header_id|>assistant<|end_header_id|>"
+    return prompt
+
+
def get_max_tokens(model_id) -> int:
return (NOT_SUUPORT_STREAM_MODELS | SUPPORT_STREAM_MODELS)[model_id]
-
-