diff --git a/metagpt/configs/llm_config.py b/metagpt/configs/llm_config.py index 7388063aa..b7bbd2bf7 100644 --- a/metagpt/configs/llm_config.py +++ b/metagpt/configs/llm_config.py @@ -57,6 +57,7 @@ class LLMConfig(YamlModel): # For Cloud Service Provider like Baidu/ Alibaba access_key: Optional[str] = None secret_key: Optional[str] = None + session_token: Optional[str] = None endpoint: Optional[str] = None # for self-deployed model on the cloud # For Spark(Xunfei), maybe remove later diff --git a/metagpt/provider/bedrock/bedrock_provider.py b/metagpt/provider/bedrock/bedrock_provider.py index 1236bf56b..90475bf41 100644 --- a/metagpt/provider/bedrock/bedrock_provider.py +++ b/metagpt/provider/bedrock/bedrock_provider.py @@ -57,15 +57,52 @@ class AnthropicProvider(BaseBedrockProvider): class CohereProvider(BaseBedrockProvider): - # See https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-cohere-command.html + # For more information, see + # (Command) https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-cohere-command.html + # (Command R/R+) https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-cohere-command-r-plus.html + + def __init__(self, model_name: str) -> None: + self.model_name = model_name def _get_completion_from_dict(self, rsp_dict: dict) -> str: return rsp_dict["generations"][0]["text"] + def messages_to_prompt(self, messages: list[dict]) -> str: + if "command-r" in self.model_name: + role_map = { + "user": "USER", + "assistant": "CHATBOT", + "system": "USER" + } + messages = list( + map( + lambda message: { + "role": role_map[message["role"]], + "message": message["content"] + }, + messages + ) + ) + return messages + else: + """[{"role": "user", "content": msg}] to user: etc.""" + return "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages]) + def get_request_body(self, messages: list[dict], generate_kwargs, *args, **kwargs): - body = json.dumps( - {"prompt": self.messages_to_prompt(messages), "stream": kwargs.get("stream", False), **generate_kwargs} - ) + prompt = self.messages_to_prompt(messages) + if "command-r" in self.model_name: + chat_history, message = prompt[:-1], prompt[-1]["message"] + body = json.dumps({ + "message": message, + "chat_history": chat_history, + **generate_kwargs + }) + else: + body = json.dumps({ + "prompt": prompt, + "stream": kwargs.get("stream", False), + **generate_kwargs + }) return body def get_choice_text_from_stream(self, event) -> str: @@ -95,10 +132,37 @@ class MetaProvider(BaseBedrockProvider): class Ai21Provider(BaseBedrockProvider): # See https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-jurassic2.html - max_tokens_field_name = "maxTokens" + def __init__(self, model_type: Literal["j2", "jamba"]) -> None: + self.model_type = model_type + if self.model_type == "j2": + self.max_tokens_field_name = "maxTokens" + else: + self.max_tokens_field_name = "max_tokens" + + def get_request_body(self, messages: list[dict], generate_kwargs, *args, **kwargs) -> str: + if self.model_type == "j2": + body = super().get_request_body(messages, generate_kwargs, *args, **kwargs) + else: + body = json.dumps( + { + "messages": messages, + **generate_kwargs, + } + ) + return body + + def get_choice_text_from_stream(self, event) -> str: + rsp_dict = json.loads(event["chunk"]["bytes"]) + completions = rsp_dict.get("choices", [{}])[0].get("delta", {}).get("content", "") + return completions def _get_completion_from_dict(self, rsp_dict: dict) -> str: - return rsp_dict["completions"][0]["data"]["text"] + if self.model_type == "j2": + # See https://docs.ai21.com/reference/j2-complete-ref + return rsp_dict["completions"][0]["data"]["text"] + else: + # See https://docs.ai21.com/reference/jamba-instruct-api + return rsp_dict["choices"][0]["message"]["content"] class AmazonProvider(BaseBedrockProvider): @@ -136,4 +200,10 @@ def get_provider(model_id: str): if provider == "meta": # distinguish llama2 and llama3 return PROVIDERS[provider](model_name[:6]) + elif provider == "ai21": + # distinguish between j2 and jamba + return PROVIDERS[provider](model_name.split("-")[0]) + elif provider == "cohere": + # distinguish between R/R+ and older models + return PROVIDERS[provider](model_name) return PROVIDERS[provider]() diff --git a/metagpt/provider/bedrock/utils.py b/metagpt/provider/bedrock/utils.py index 46520d1d5..66e98c759 100644 --- a/metagpt/provider/bedrock/utils.py +++ b/metagpt/provider/bedrock/utils.py @@ -1,52 +1,97 @@ from metagpt.logs import logger # max_tokens for each model -NOT_SUUPORT_STREAM_MODELS = { - "ai21.j2-grande-instruct": 8000, - "ai21.j2-jumbo-instruct": 8000, - "ai21.j2-mid": 8000, - "ai21.j2-mid-v1": 8000, - "ai21.j2-ultra": 8000, - "ai21.j2-ultra-v1": 8000, +NOT_SUPPORT_STREAM_MODELS = { + # Jurassic-2 Mid-v1 and Ultra-v1 + # + Legacy date: 2024-04-30 (us-west-2/Oregon) + # + EOL date: 2024-08-31 (us-west-2/Oregon) + "ai21.j2-mid-v1": 8191, + "ai21.j2-ultra-v1": 8191, } SUPPORT_STREAM_MODELS = { - "amazon.titan-tg1-large": 8000, - "amazon.titan-text-express-v1": 8000, - "amazon.titan-text-express-v1:0:8k": 8000, - "amazon.titan-text-lite-v1:0:4k": 4000, - "amazon.titan-text-lite-v1": 4000, - "anthropic.claude-instant-v1": 100000, - "anthropic.claude-instant-v1:2:100k": 100000, - "anthropic.claude-v1": 100000, - "anthropic.claude-v2": 100000, - "anthropic.claude-v2:1": 200000, - "anthropic.claude-v2:0:18k": 18000, - "anthropic.claude-v2:1:200k": 200000, - "anthropic.claude-3-sonnet-20240229-v1:0": 200000, - "anthropic.claude-3-sonnet-20240229-v1:0:28k": 28000, - "anthropic.claude-3-sonnet-20240229-v1:0:200k": 200000, - "anthropic.claude-3-haiku-20240307-v1:0": 200000, - "anthropic.claude-3-5-sonnet-20240620-v1:0": 200000, - "anthropic.claude-3-haiku-20240307-v1:0:48k": 48000, - "anthropic.claude-3-haiku-20240307-v1:0:200k": 200000, - # currently (2024-4-29) only available at US West (Oregon) AWS Region. - "anthropic.claude-3-opus-20240229-v1:0": 200000, - "cohere.command-text-v14": 4000, - "cohere.command-text-v14:7:4k": 4000, - "cohere.command-light-text-v14": 4000, - "cohere.command-light-text-v14:7:4k": 4000, - "meta.llama2-13b-chat-v1:0:4k": 4000, - "meta.llama2-13b-chat-v1": 2000, - "meta.llama2-70b-v1": 4000, - "meta.llama2-70b-v1:0:4k": 4000, - "meta.llama2-70b-chat-v1": 2000, - "meta.llama2-70b-chat-v1:0:4k": 2000, - "meta.llama3-8b-instruct-v1:0": 2000, - "meta.llama3-70b-instruct-v1:0": 2000, - "mistral.mistral-7b-instruct-v0:2": 32000, - "mistral.mixtral-8x7b-instruct-v0:1": 32000, - "mistral.mistral-large-2402-v1:0": 32000, + # Jamba-Instruct + # https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-jamba.html + "ai21.jamba-instruct-v1:0": 4096, + # Titan Text G1 - Lite + # https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-titan-text.html + "amazon.titan-text-lite-v1:0:4k": 4096, + "amazon.titan-text-lite-v1": 4096, + # Titan Text G1 - Express + "amazon.titan-text-express-v1": 8192, + "amazon.titan-text-express-v1:0:8k": 8192, + # Titan Text Premier + "amazon.titan-text-premier-v1:0": 3072, + "amazon.titan-text-premier-v1:0:32k": 3072, + # Claude Instant v1 + # https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-text-completion.html + # https://docs.anthropic.com/en/docs/about-claude/models#model-comparison + "anthropic.claude-instant-v1": 4096, + "anthropic.claude-instant-v1:2:100k": 4096, + # Claude v2 + "anthropic.claude-v2": 4096, + "anthropic.claude-v2:0:18k": 4096, + "anthropic.claude-v2:0:100k": 4096, + # Claude v2.1 + "anthropic.claude-v2:1": 4096, + "anthropic.claude-v2:1:18k": 4096, + "anthropic.claude-v2:1:200k": 4096, + # Claude 3 Sonnet + "anthropic.claude-3-sonnet-20240229-v1:0": 4096, + "anthropic.claude-3-sonnet-20240229-v1:0:28k": 4096, + "anthropic.claude-3-sonnet-20240229-v1:0:200k": 4096, + # Claude 3 Haiku + "anthropic.claude-3-haiku-20240307-v1:0": 4096, + "anthropic.claude-3-haiku-20240307-v1:0:48k": 4096, + "anthropic.claude-3-haiku-20240307-v1:0:200k": 4096, + # Claude 3 Opus + "anthropic.claude-3-opus-20240229-v1:0": 4096, + # Claude 3.5 Sonnet + "anthropic.claude-3-5-sonnet-20240620-v1:0": 8192, + # Command Text + # https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-cohere-command.html + "cohere.command-text-v14": 4096, + "cohere.command-text-v14:7:4k": 4096, + # Command Light Text + "cohere.command-light-text-v14": 4096, + "cohere.command-light-text-v14:7:4k": 4096, + # Command R + # https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-cohere-command-r-plus.html + "cohere.command-r-v1:0": 4096, + # Command R+ + "cohere.command-r-plus-v1:0": 4096, + # Llama 2 (--> Llama 3/3.1/3.2) !!! + # https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-meta.html + # + Legacy: 2024-05-12 + # + EOL: 2024-10-30 + # "meta.llama2-13b-chat-v1": 2048, + # "meta.llama2-13b-chat-v1:0:4k": 2048, + # "meta.llama2-70b-v1": 2048, + # "meta.llama2-70b-v1:0:4k": 2048, + # "meta.llama2-70b-chat-v1": 2048, + # "meta.llama2-70b-chat-v1:0:4k": 2048, + # Llama 3 Instruct + # "meta.llama3-8b-instruct-v1:0": 2048, + "meta.llama3-70b-instruct-v1:0": 2048, + # Llama 3.1 Instruct + # "meta.llama3-1-8b-instruct-v1:0": 2048, + "meta.llama3-1-70b-instruct-v1:0": 2048, + "meta.llama3-1-405b-instruct-v1:0": 2048, + # Llama 3.2 Instruct + # "meta.llama3-2-3b-instruct-v1:0": 2048, + # "meta.llama3-2-11b-instruct-v1:0": 2048, + "meta.llama3-2-90b-instruct-v1:0": 2048, + # Mistral 7B Instruct + # https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-mistral-text-completion.html + # "mistral.mistral-7b-instruct-v0:2": 8192, + # Mixtral 8x7B Instruct + "mistral.mixtral-8x7b-instruct-v0:1": 4096, + # Mistral Small + "mistral.mistral-small-2402-v1:0": 8192, + # Mistral Large (24.02) + "mistral.mistral-large-2402-v1:0": 8192, + # Mistral Large 2 (24.07) + "mistral.mistral-large-2407-v1:0": 8192 } # TODO:use a more general function for constructing chat templates. @@ -106,7 +151,7 @@ def messages_to_prompt_claude2(messages: list[dict]) -> str: def get_max_tokens(model_id: str) -> int: try: - max_tokens = (NOT_SUUPORT_STREAM_MODELS | SUPPORT_STREAM_MODELS)[model_id] + max_tokens = (NOT_SUPPORT_STREAM_MODELS | SUPPORT_STREAM_MODELS)[model_id] except KeyError: logger.warning(f"Couldn't find model:{model_id} , max tokens has been set to 2048") max_tokens = 2048 diff --git a/metagpt/provider/bedrock_api.py b/metagpt/provider/bedrock_api.py index 03954e5c2..4cf22f41b 100644 --- a/metagpt/provider/bedrock_api.py +++ b/metagpt/provider/bedrock_api.py @@ -1,3 +1,4 @@ +import os import asyncio import json from functools import partial @@ -11,7 +12,7 @@ from metagpt.const import USE_CONFIG_TIMEOUT from metagpt.logs import log_llm_stream, logger from metagpt.provider.base_llm import BaseLLM from metagpt.provider.bedrock.bedrock_provider import get_provider -from metagpt.provider.bedrock.utils import NOT_SUUPORT_STREAM_MODELS, get_max_tokens +from metagpt.provider.bedrock.utils import NOT_SUPPORT_STREAM_MODELS, get_max_tokens from metagpt.provider.llm_provider_registry import register_provider from metagpt.utils.cost_manager import CostManager from metagpt.utils.token_counter import BEDROCK_TOKEN_COSTS @@ -24,18 +25,19 @@ class BedrockLLM(BaseLLM): self.__client = self.__init_client("bedrock-runtime") self.__provider = get_provider(self.config.model) self.cost_manager = CostManager(token_costs=BEDROCK_TOKEN_COSTS) - if self.config.model in NOT_SUUPORT_STREAM_MODELS: + if self.config.model in NOT_SUPPORT_STREAM_MODELS: logger.warning(f"model {self.config.model} doesn't support streaming output!") def __init_client(self, service_name: Literal["bedrock-runtime", "bedrock"]): """initialize boto3 client""" # access key and secret key from https://us-east-1.console.aws.amazon.com/iam - self.__credentital_kwargs = { - "aws_secret_access_key": self.config.secret_key, - "aws_access_key_id": self.config.access_key, - "region_name": self.config.region_name, + self.__credential_kwargs = { + "aws_secret_access_key": os.environ.get("AWS_SECRET_ACCESS_KEY", self.config.secret_key), + "aws_access_key_id": os.environ.get("AWS_ACCESS_KEY_ID", self.config.access_key), + "aws_session_token": os.environ.get("AWS_SESSION_TOKEN", self.config.session_token), + "region_name": os.environ.get("AWS_DEFAULT_REGION", self.config.region_name), } - session = boto3.Session(**self.__credentital_kwargs) + session = boto3.Session(**self.__credential_kwargs) client = session.client(service_name) return client @@ -111,7 +113,7 @@ class BedrockLLM(BaseLLM): return await self.acompletion(messages) async def _achat_completion_stream(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT) -> str: - if self.config.model in NOT_SUUPORT_STREAM_MODELS: + if self.config.model in NOT_SUPPORT_STREAM_MODELS: rsp = await self.acompletion(messages) full_text = self.get_choice_text(rsp) log_llm_stream(full_text) diff --git a/tests/metagpt/provider/test_bedrock_api.py b/tests/metagpt/provider/test_bedrock_api.py index b9c9e0f93..28d1d7008 100644 --- a/tests/metagpt/provider/test_bedrock_api.py +++ b/tests/metagpt/provider/test_bedrock_api.py @@ -3,7 +3,7 @@ import json import pytest from metagpt.provider.bedrock.utils import ( - NOT_SUUPORT_STREAM_MODELS, + NOT_SUPPORT_STREAM_MODELS, SUPPORT_STREAM_MODELS, ) from metagpt.provider.bedrock_api import BedrockLLM @@ -14,7 +14,7 @@ from tests.metagpt.provider.req_resp_const import ( ) # all available model from bedrock -models = SUPPORT_STREAM_MODELS | NOT_SUUPORT_STREAM_MODELS +models = SUPPORT_STREAM_MODELS | NOT_SUPPORT_STREAM_MODELS messages = [{"role": "user", "content": "Hi!"}] usage = { "prompt_tokens": 1000000,