From 6561c7aa7e03756b57ed809b8f97e44e32ffb0d9 Mon Sep 17 00:00:00 2001 From: usamimeri_renko <1710269958@qq.com> Date: Fri, 26 Apr 2024 01:29:43 +0800 Subject: [PATCH] add docs --- .../provider/bedrock/amazon_bedrock_api.py | 27 ++++++++++++++++--- metagpt/provider/bedrock/utils.py | 7 +++-- 2 files changed, 28 insertions(+), 6 deletions(-) diff --git a/metagpt/provider/bedrock/amazon_bedrock_api.py b/metagpt/provider/bedrock/amazon_bedrock_api.py index a5cacec8c..640be3534 100644 --- a/metagpt/provider/bedrock/amazon_bedrock_api.py +++ b/metagpt/provider/bedrock/amazon_bedrock_api.py @@ -11,14 +11,21 @@ import boto3 @register_provider([LLMType.AMAZON_BEDROCK]) class AmazonBedrockLLM(BaseLLM): + """ + check out: + https://docs.aws.amazon.com/code-library/latest/ug/python_3_bedrock-runtime_code_examples.html + """ + def __init__(self, config: LLMConfig): self.config = config self.__client = self.__init_client("bedrock-runtime") self.__provider = get_provider(self.config.model) - logger.warning("Amazon bedrock doesn't support async now") + logger.warning( + "Amazon bedrock doesn't support asynchronous calls now") def __init_client(self, service_name: Literal["bedrock-runtime", "bedrock"]): - # access key from https://us-east-1.console.aws.amazon.com/iam + """initialize boto3 client""" + # access key and secret key from https://us-east-1.console.aws.amazon.com/iam self.__credentital_kwards = { "aws_secret_access_key": self.config.secret_key, "aws_access_key_id": self.config.access_key, @@ -29,6 +36,14 @@ class AmazonBedrockLLM(BaseLLM): return client def list_models(self): + """list all available text-generation models + + ```shell + ai21.j2-ultra-v1 Support Streaming:False + meta.llama3-70b-instruct-v1:0 Support Streaming:True + …… + ``` + """ client = self.__init_client("bedrock") # only output text-generation models response = client.list_foundation_models(byOutputModality='TEXT') @@ -38,12 +53,12 @@ class AmazonBedrockLLM(BaseLLM): @property def _generate_kwargs(self) -> dict: - # for now only use temperature due to the difference of request body model_max_tokens = get_max_tokens(self.config.model) if self.config.max_token > model_max_tokens: max_tokens = model_max_tokens else: max_tokens = self.config.max_token + return { self.__provider.max_tokens_field_name: max_tokens, "temperature": self.config.temperature @@ -81,8 +96,12 @@ class AmazonBedrockLLM(BaseLLM): full_text = ("".join(collected_content)).lstrip() return full_text + # boto3 don't support support asynchronous calls. + # for asynchronous version of boto3,check out: + # https://aioboto3.readthedocs.io/en/latest/usage.html + # However,aioboto3 doesn't support invoke model + async def acompletion(self, messages: list[dict]): - # Amazon bedrock doesn't support async now return await self._achat_completion(messages) async def acompletion_text(self, messages: list[dict], stream: bool = False, diff --git a/metagpt/provider/bedrock/utils.py b/metagpt/provider/bedrock/utils.py index 47a23caeb..80b7b82bd 100644 --- a/metagpt/provider/bedrock/utils.py +++ b/metagpt/provider/bedrock/utils.py @@ -1,5 +1,6 @@ from metagpt.logs import logger +# max_tokens for each model NOT_SUUPORT_STREAM_MODELS = { "ai21.j2-grande-instruct": 8000, "ai21.j2-jumbo-instruct": 8000, @@ -29,7 +30,7 @@ SUPPORT_STREAM_MODELS = { "mistral.mistral-large-2402-v1:0": 32000, } -# TODO:use a general function for constructing chat templates. +# TODO:use a more general function for constructing chat templates. def messages_to_prompt_llama2(messages: list[dict]): @@ -64,6 +65,7 @@ def messages_to_prompt_llama3(messages: list[dict]): role = message["role"] content = message["content"] prompt += GENERAL_TEMPLATE.format(role=role, content=content) + if role != "assistant": prompt += f"<|start_header_id|>assistant<|end_header_id|>" @@ -77,11 +79,12 @@ def messages_to_prompt_claude(messages: list[dict]): role = message["role"] content = message["content"] prompt += GENERAL_TEMPLATE.format(role=role, content=content) + if role != "assistant": prompt += f"\n\nAssistant:" + return prompt def get_max_tokens(model_id) -> int: return (NOT_SUUPORT_STREAM_MODELS | SUPPORT_STREAM_MODELS)[model_id] -