From 784c6265ee89f8e26271559009358bd1ffed6d25 Mon Sep 17 00:00:00 2001 From: usamimeri_renko <1710269958@qq.com> Date: Thu, 25 Apr 2024 23:41:21 +0800 Subject: [PATCH] update max tokens and support max_tokens field --- .../provider/bedrock/amazon_bedrock_api.py | 18 ++++- metagpt/provider/bedrock/base_provider.py | 1 + metagpt/provider/bedrock/bedrock_provider.py | 3 + metagpt/provider/bedrock/utils.py | 75 ++++++++----------- 4 files changed, 53 insertions(+), 44 deletions(-) diff --git a/metagpt/provider/bedrock/amazon_bedrock_api.py b/metagpt/provider/bedrock/amazon_bedrock_api.py index 7262a11be..e4382f7bd 100644 --- a/metagpt/provider/bedrock/amazon_bedrock_api.py +++ b/metagpt/provider/bedrock/amazon_bedrock_api.py @@ -5,7 +5,7 @@ from metagpt.configs.llm_config import LLMConfig, LLMType from metagpt.provider.base_llm import BaseLLM from metagpt.logs import log_llm_stream, logger from metagpt.provider.bedrock.bedrock_provider import get_provider -from metagpt.provider.bedrock.utils import NOT_SUUPORT_STREAM_MODELS, SUPPORT_STREAM_MODELS +from metagpt.provider.bedrock.utils import NOT_SUUPORT_STREAM_MODELS, get_max_tokens import boto3 @@ -15,6 +15,7 @@ class AmazonBedrockLLM(BaseLLM): self.config = config self.__client = self.__init_client("bedrock-runtime") self.__provider = get_provider(self.config.model) + logger.warning("Amazon bedrock doesn't support async now") def __init_client(self, service_name: Literal["bedrock-runtime", "bedrock"]): # access key from https://us-east-1.console.aws.amazon.com/iam @@ -38,7 +39,13 @@ class AmazonBedrockLLM(BaseLLM): @property def _generate_kwargs(self) -> dict: # for now only use temperature due to the difference of request body + model_max_tokens = get_max_tokens(self.config.model) + if self.config.max_token > model_max_tokens: + max_tokens = model_max_tokens + else: + max_tokens = self.config.max_token return { + self.__provider.max_tokens_field_name: max_tokens, "temperature": self.config.temperature } @@ -59,6 +66,7 @@ class AmazonBedrockLLM(BaseLLM): request_body = self.__provider.get_request_body( messages, **self._generate_kwargs) + response = self.__client.invoke_model_with_response_stream( modelId=self.config.model, body=request_body ) @@ -75,7 +83,13 @@ class AmazonBedrockLLM(BaseLLM): async def acompletion(self, messages: list[dict]): # Amazon bedrock doesn't support async now - return self._achat_completion(messages) + return await self._achat_completion(messages) + + async def acompletion_text(self, messages: list[dict], stream: bool = False, + timeout: int = USE_CONFIG_TIMEOUT) -> str: + if stream: + return await self._achat_completion_stream(messages) + return await self._achat_completion(messages) async def _achat_completion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT): return self.completion(messages) diff --git a/metagpt/provider/bedrock/base_provider.py b/metagpt/provider/bedrock/base_provider.py index 724cbb669..c24556645 100644 --- a/metagpt/provider/bedrock/base_provider.py +++ b/metagpt/provider/bedrock/base_provider.py @@ -4,6 +4,7 @@ from abc import ABC, abstractmethod class BaseBedrockProvider(ABC): # to handle different generation kwargs + max_tokens_field_name = "max_tokens" @abstractmethod def _get_completion_from_dict(self, rsp_dict: dict) -> str: diff --git a/metagpt/provider/bedrock/bedrock_provider.py b/metagpt/provider/bedrock/bedrock_provider.py index 47348c083..2aa90c7ee 100644 --- a/metagpt/provider/bedrock/bedrock_provider.py +++ b/metagpt/provider/bedrock/bedrock_provider.py @@ -34,6 +34,7 @@ class CohereProvider(BaseBedrockProvider): class MetaProvider(BaseBedrockProvider): # See https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-meta.html + max_tokens_field_name = "max_gen_len" def messages_to_prompt(self, messages: list[dict]): return messages_to_prompt_llama(messages) @@ -44,6 +45,7 @@ class MetaProvider(BaseBedrockProvider): class Ai21Provider(BaseBedrockProvider): # See https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-jurassic2.html + max_tokens_field_name = "maxTokens" def _get_completion_from_dict(self, rsp_dict: dict) -> str: return rsp_dict['completions'][0]["data"]["text"] @@ -51,6 +53,7 @@ class Ai21Provider(BaseBedrockProvider): class AmazonProvider(BaseBedrockProvider): # See https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-titan-text.html + max_tokens_field_name = "maxTokenCount" def get_request_body(self, messages: list[dict], **generate_kwargs): body = json.dumps({ diff --git a/metagpt/provider/bedrock/utils.py b/metagpt/provider/bedrock/utils.py index 58236157d..2df0bf163 100644 --- a/metagpt/provider/bedrock/utils.py +++ b/metagpt/provider/bedrock/utils.py @@ -1,5 +1,35 @@ from metagpt.logs import logger +NOT_SUUPORT_STREAM_MODELS = { + "ai21.j2-grande-instruct": 8000, + "ai21.j2-jumbo-instruct": 8000, + "ai21.j2-mid": 8000, + "ai21.j2-mid-v1": 8000, + "ai21.j2-ultra": 8000, + "ai21.j2-ultra-v1": 8000, +} + +SUPPORT_STREAM_MODELS = { + "amazon.titan-tg1-large": 8000, + "amazon.titan-text-express-v1": 8000, + "anthropic.claude-instant-v1": 100000, + "anthropic.claude-v1": 100000, + "anthropic.claude-v2": 100000, + "anthropic.claude-v2:1": 200000, + "anthropic.claude-3-sonnet-20240229-v1:0": 200000, + "anthropic.claude-3-haiku-20240307-v1:0": 200000, + "anthropic.claude-3-opus-20240229-v1:0": 200000, + "cohere.command-text-v14": 4096, + "cohere.command-light-text-v14": 4096, + "meta.llama2-70b-v1": 4096, + "meta.llama3-8b-instruct-v1:0": 2000, + "meta.llama3-70b-instruct-v1:0": 2000, + "mistral.mistral-7b-instruct-v0:2": 32000, + "mistral.mixtral-8x7b-instruct-v0:1": 32000, + "mistral.mistral-large-2402-v1:0": 32000, +} + + def messages_to_prompt_llama(messages: list[dict]): BOS, EOS = "", "" B_INST, E_INST = "[INST]", "[/INST]" @@ -23,46 +53,7 @@ def messages_to_prompt_llama(messages: list[dict]): return prompt -NOT_SUUPORT_STREAM_MODELS = { - "ai21.j2-grande-instruct", - "ai21.j2-jumbo-instruct", - "ai21.j2-mid", - "ai21.j2-mid-v1", - "ai21.j2-ultra", - "ai21.j2-ultra-v1", -} +def get_max_tokens(model_id) -> int: + return (NOT_SUUPORT_STREAM_MODELS | SUPPORT_STREAM_MODELS)[model_id] + -SUPPORT_STREAM_MODELS = { - "amazon.titan-tg1-large", - "amazon.titan-text-lite-v1:0:4k", - "amazon.titan-text-lite-v1", - "amazon.titan-text-express-v1:0:8k", - "amazon.titan-text-express-v1", - "anthropic.claude-instant-v1:2:100k", - "anthropic.claude-instant-v1", - "anthropic.claude-v2:0:18k", - "anthropic.claude-v2:0:100k", - "anthropic.claude-v2:1:18k", - "anthropic.claude-v2:1:200k", - "anthropic.claude-v2:1", - "anthropic.claude-v2:2:18k", - "anthropic.claude-v2:2:200k", - "anthropic.claude-v2:2", - "anthropic.claude-v2", - "anthropic.claude-3-sonnet-20240229-v1:0:28k", - "anthropic.claude-3-sonnet-20240229-v1:0:200k", - "anthropic.claude-3-sonnet-20240229-v1:0", - "anthropic.claude-3-haiku-20240307-v1:0:48k", - "anthropic.claude-3-haiku-20240307-v1:0:200k", - "anthropic.claude-3-haiku-20240307-v1:0", - "cohere.command-text-v14:7:4k", - "cohere.command-text-v14", - "cohere.command-light-text-v14:7:4k", - "cohere.command-light-text-v14", - "meta.llama2-70b-v1", - "meta.llama3-8b-instruct-v1:0", - "meta.llama3-70b-instruct-v1:0", - "mistral.mistral-7b-instruct-v0:2", - "mistral.mixtral-8x7b-instruct-v0:1", - "mistral.mistral-large-2402-v1:0", -}