diff --git a/metagpt/provider/bedrock/amazon_bedrock_api.py b/metagpt/provider/bedrock/amazon_bedrock_api.py index e4382f7bd..a5cacec8c 100644 --- a/metagpt/provider/bedrock/amazon_bedrock_api.py +++ b/metagpt/provider/bedrock/amazon_bedrock_api.py @@ -96,4 +96,3 @@ class AmazonBedrockLLM(BaseLLM): async def _achat_completion_stream(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT): return self._chat_completion_stream(messages) - diff --git a/metagpt/provider/bedrock/bedrock_provider.py b/metagpt/provider/bedrock/bedrock_provider.py index 2aa90c7ee..729697ad7 100644 --- a/metagpt/provider/bedrock/bedrock_provider.py +++ b/metagpt/provider/bedrock/bedrock_provider.py @@ -1,13 +1,14 @@ import json +from typing import Literal from metagpt.provider.bedrock.base_provider import BaseBedrockProvider -from metagpt.provider.bedrock.utils import messages_to_prompt_llama +from metagpt.provider.bedrock.utils import messages_to_prompt_llama2, messages_to_prompt_llama3 class MistralProvider(BaseBedrockProvider): # See https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-mistral.html def messages_to_prompt(self, messages: list[dict]): - return messages_to_prompt_llama(messages) + return messages_to_prompt_llama2(messages) def _get_completion_from_dict(self, rsp_dict: dict) -> str: return rsp_dict["outputs"][0]["text"] @@ -36,8 +37,14 @@ class MetaProvider(BaseBedrockProvider): # See https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-meta.html max_tokens_field_name = "max_gen_len" + def __init__(self, llama_version: Literal["llama2", "llama3"]) -> None: + self.llama_version = llama_version + def messages_to_prompt(self, messages: list[dict]): - return messages_to_prompt_llama(messages) + if self.llama_version == "llama2": + return messages_to_prompt_llama2(messages) + else: + return messages_to_prompt_llama3(messages) def _get_completion_from_dict(self, rsp_dict: dict) -> str: return rsp_dict["generation"] @@ -72,17 +79,20 @@ class AmazonProvider(BaseBedrockProvider): PROVIDERS = { - "mistral": MistralProvider(), - "meta": MetaProvider(), - "ai21": Ai21Provider(), - "cohere": CohereProvider(), - "anthropic": AnthropicProvider(), - "amazon": AmazonProvider() + "mistral": MistralProvider, + "meta": MetaProvider, + "ai21": Ai21Provider, + "cohere": CohereProvider, + "anthropic": AnthropicProvider, + "amazon": AmazonProvider } def get_provider(model_id: str): - model_name = model_id.split(".")[0] # meta、mistral…… - if model_name not in PROVIDERS: - raise KeyError(f"{model_name} is not supported!") - return PROVIDERS[model_name] + provider, model_name = model_id.split(".")[0:2] # meta、mistral…… + if provider not in PROVIDERS: + raise KeyError(f"{provider} is not supported!") + if provider == "meta": + # distinguish llama2 and llama3 + return PROVIDERS[provider](model_name[:6]) + return PROVIDERS[provider]() diff --git a/metagpt/provider/bedrock/utils.py b/metagpt/provider/bedrock/utils.py index 2df0bf163..61778e8e8 100644 --- a/metagpt/provider/bedrock/utils.py +++ b/metagpt/provider/bedrock/utils.py @@ -29,8 +29,10 @@ SUPPORT_STREAM_MODELS = { "mistral.mistral-large-2402-v1:0": 32000, } +# TODO:use a general function for constructing chat templates. -def messages_to_prompt_llama(messages: list[dict]): + +def messages_to_prompt_llama2(messages: list[dict]): BOS, EOS = "", "" B_INST, E_INST = "[INST]", "[/INST]" B_SYS, E_SYS = "<>\n", "\n<>\n\n" @@ -53,7 +55,20 @@ def messages_to_prompt_llama(messages: list[dict]): return prompt +def messages_to_prompt_llama3(messages: list[dict]): + BOS, EOS = "<|begin_of_text|>", "<|eot_id|>" + GENERAL_TEMPLATE = "<|start_header_id|>{role}<|end_header_id|>\n\n{content}<|eot_id|>" + + prompt = f"{BOS}" + for message in messages: + role = message["role"] + content = message["content"] + prompt += GENERAL_TEMPLATE.format(role=role, content=content) + if role != "assistant": + prompt += f"<|start_header_id|>assistant<|end_header_id|>" + + return prompt + + def get_max_tokens(model_id) -> int: return (NOT_SUUPORT_STREAM_MODELS | SUPPORT_STREAM_MODELS)[model_id] - -