diff --git a/metagpt/provider/bedrock/utils.py b/metagpt/provider/bedrock/utils.py index a87367b22..ee31da1b9 100644 --- a/metagpt/provider/bedrock/utils.py +++ b/metagpt/provider/bedrock/utils.py @@ -13,25 +13,34 @@ NOT_SUUPORT_STREAM_MODELS = { SUPPORT_STREAM_MODELS = { "amazon.titan-tg1-large": 8000, "amazon.titan-text-express-v1": 8000, + "amazon.titan-text-express-v1:0:8k": 8000, + "amazon.titan-text-lite-v1:0:4k": 4000, + "amazon.titan-text-lite-v1": 4000, "anthropic.claude-instant-v1": 100000, "anthropic.claude-instant-v1:2:100k": 100000, "anthropic.claude-v1": 100000, "anthropic.claude-v2": 100000, "anthropic.claude-v2:1": 200000, + "anthropic.claude-v2:0:18k": 18000, + "anthropic.claude-v2:1:200k": 200000, "anthropic.claude-3-sonnet-20240229-v1:0": 200000, "anthropic.claude-3-sonnet-20240229-v1:0:28k": 28000, "anthropic.claude-3-sonnet-20240229-v1:0:200k": 200000, "anthropic.claude-3-haiku-20240307-v1:0": 200000, "anthropic.claude-3-haiku-20240307-v1:0:48k": 48000, "anthropic.claude-3-haiku-20240307-v1:0:200k": 200000, + # currently (2024-4-29) only available at US West (Oregon) AWS Region. + "anthropic.claude-3-opus-20240229-v1:0": 200000, "cohere.command-text-v14": 4000, "cohere.command-text-v14:7:4k": 4000, "cohere.command-light-text-v14": 4000, "cohere.command-light-text-v14:7:4k": 4000, - "meta.llama2-70b-v1": 4000, "meta.llama2-13b-chat-v1:0:4k": 4000, "meta.llama2-13b-chat-v1": 2000, + "meta.llama2-70b-v1": 4000, "meta.llama2-70b-v1:0:4k": 4000, + "meta.llama2-70b-chat-v1": 4000, + "meta.llama2-70b-chat-v1:0:4k": 4000, "meta.llama3-8b-instruct-v1:0": 2000, "meta.llama3-70b-instruct-v1:0": 2000, "mistral.mistral-7b-instruct-v0:2": 32000, @@ -43,14 +52,14 @@ SUPPORT_STREAM_MODELS = { def messages_to_prompt_llama2(messages: list[dict]) -> str: - BOS, EOS = "", "" + BOS = ("",) B_INST, E_INST = "[INST]", "[/INST]" B_SYS, E_SYS = "<>\n", "\n<>\n\n" prompt = f"{BOS}" for message in messages: - role = message["role"] - content = message["content"] + role = message.get("role", "") + content = message.get("content", "") if role == "system": prompt += f"{B_SYS} {content} {E_SYS}" elif role == "user": @@ -58,25 +67,24 @@ def messages_to_prompt_llama2(messages: list[dict]) -> str: elif role == "assistant": prompt += f"{content}" else: - logger.warning( - f"Unknown role name {role} when formatting messages") + logger.warning(f"Unknown role name {role} when formatting messages") prompt += f"{content}" return prompt def messages_to_prompt_llama3(messages: list[dict]) -> str: - BOS, EOS = "<|begin_of_text|>", "<|eot_id|>" + BOS = "<|begin_of_text|>" GENERAL_TEMPLATE = "<|start_header_id|>{role}<|end_header_id|>\n\n{content}<|eot_id|>" prompt = f"{BOS}" for message in messages: - role = message["role"] - content = message["content"] + role = message.get("role", "") + content = message.get("content", "") prompt += GENERAL_TEMPLATE.format(role=role, content=content) if role != "assistant": - prompt += f"<|start_header_id|>assistant<|end_header_id|>" + prompt += "<|start_header_id|>assistant<|end_header_id|>" return prompt @@ -85,15 +93,20 @@ def messages_to_prompt_claude2(messages: list[dict]) -> str: GENERAL_TEMPLATE = "\n\n{role}: {content}" prompt = "" for message in messages: - role = message["role"] - content = message["content"] + role = message.get("role", "") + content = message.get("content", "") prompt += GENERAL_TEMPLATE.format(role=role, content=content) if role != "assistant": - prompt += f"\n\nAssistant:" + prompt += "\n\nAssistant:" return prompt -def get_max_tokens(model_id) -> int: - return (NOT_SUUPORT_STREAM_MODELS | SUPPORT_STREAM_MODELS)[model_id] +def get_max_tokens(model_id: str) -> int: + try: + max_tokens = (NOT_SUUPORT_STREAM_MODELS | SUPPORT_STREAM_MODELS)[model_id] + except KeyError: + logger.warning(f"Couldn't find model:{model_id} , max tokens has been set to 2048") + max_tokens = 2048 + return max_tokens diff --git a/metagpt/provider/bedrock_api.py b/metagpt/provider/bedrock_api.py index de3fbae94..483b08f29 100644 --- a/metagpt/provider/bedrock_api.py +++ b/metagpt/provider/bedrock_api.py @@ -11,6 +11,8 @@ from metagpt.provider.base_llm import BaseLLM from metagpt.provider.bedrock.bedrock_provider import get_provider from metagpt.provider.bedrock.utils import NOT_SUUPORT_STREAM_MODELS, get_max_tokens from metagpt.provider.llm_provider_registry import register_provider +from metagpt.utils.cost_manager import CostManager +from metagpt.utils.token_counter import BEDROCK_TOKEN_COSTS @register_provider([LLMType.BEDROCK]) @@ -19,6 +21,7 @@ class BedrockLLM(BaseLLM): self.config = config self.__client = self.__init_client("bedrock-runtime") self.__provider = get_provider(self.config.model) + self.cost_manager = CostManager(token_costs=BEDROCK_TOKEN_COSTS) logger.warning("Amazon bedrock doesn't support asynchronous now") def __init_client(self, service_name: Literal["bedrock-runtime", "bedrock"]): @@ -62,14 +65,14 @@ class BedrockLLM(BaseLLM): def invoke_model(self, request_body: str) -> dict: response = self.__client.invoke_model(modelId=self.config.model, body=request_body) usage = self._get_usage(response) - self._update_costs(usage) + self._update_costs(usage, self.config.model) response_body = self._get_response_body(response) return response_body def invoke_model_with_response_stream(self, request_body: str) -> EventStream: response = self.__client.invoke_model_with_response_stream(modelId=self.config.model, body=request_body) usage = self._get_usage(response) - self._update_costs(usage) + self._update_costs(usage, self.config.model) return response @property @@ -82,16 +85,29 @@ class BedrockLLM(BaseLLM): return {self.__provider.max_tokens_field_name: max_tokens, "temperature": self.config.temperature} - def completion(self, messages: list[dict]) -> str: + # boto3 don't support support asynchronous calls. + # for asynchronous version of boto3, check out: + # https://aioboto3.readthedocs.io/en/latest/usage.html + # However,aioboto3 doesn't support invoke model + + def get_choice_text(self, rsp: dict) -> str: + return self.__provider.get_choice_text(rsp) + + async def acompletion(self, messages: list[dict]) -> dict: request_body = self.__provider.get_request_body(messages, self._const_kwargs) response_body = self.invoke_model(request_body) - completions = self.__provider.get_choice_text(response_body) - return completions + return response_body - def _chat_completion_stream(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT) -> str: + async def _achat_completion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT) -> dict: + return await self.acompletion(messages) + + async def _achat_completion_stream(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT) -> str: if self.config.model in NOT_SUUPORT_STREAM_MODELS: logger.warning(f"model {self.config.model} doesn't support streaming output!") - return self.completion(messages) + rsp = await self.acompletion(messages) + full_text = self.get_choice_text(rsp) + log_llm_stream(full_text) + return full_text request_body = self.__provider.get_request_body(messages, self._const_kwargs, stream=True) @@ -106,20 +122,6 @@ class BedrockLLM(BaseLLM): full_text = ("".join(collected_content)).lstrip() return full_text - # boto3 don't support support asynchronous calls. - # for asynchronous version of boto3, check out: - # https://aioboto3.readthedocs.io/en/latest/usage.html - # However,aioboto3 doesn't support invoke model - - async def acompletion(self, messages: list[dict]): - return await self._achat_completion(messages) - - async def _achat_completion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT): - return self.completion(messages) - - async def _achat_completion_stream(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT): - return self._chat_completion_stream(messages) - def _get_response_body(self, response) -> dict: response_body = json.loads(response["body"].read()) return response_body diff --git a/metagpt/utils/token_counter.py b/metagpt/utils/token_counter.py index 724d49afc..9249d674e 100644 --- a/metagpt/utils/token_counter.py +++ b/metagpt/utils/token_counter.py @@ -198,6 +198,53 @@ TOKEN_MAX = { "openai/gpt-4-turbo-preview": 128000, } +# For Amazon Bedrock US region +# See https://aws.amazon.com/cn/bedrock/pricing/ + +BEDROCK_TOKEN_COSTS = { + "amazon.titan-tg1-large": {"prompt": 0.0008, "completion": 0.0008}, + "amazon.titan-text-express-v1": {"prompt": 0.0008, "completion": 0.0008}, + "amazon.titan-text-express-v1:0:8k": {"prompt": 0.0008, "completion": 0.0008}, + "amazon.titan-text-lite-v1:0:4k": {"prompt": 0.0003, "completion": 0.0004}, + "amazon.titan-text-lite-v1": {"prompt": 0.0003, "completion": 0.0004}, + "anthropic.claude-instant-v1": {"prompt": 0.0008, "completion": 0.00024}, + "anthropic.claude-instant-v1:2:100k": {"prompt": 0.0008, "completion": 0.00024}, + "anthropic.claude-v1": {"prompt": 0.008, "completion": 0.0024}, + "anthropic.claude-v2": {"prompt": 0.008, "completion": 0.0024}, + "anthropic.claude-v2:1": {"prompt": 0.008, "completion": 0.0024}, + "anthropic.claude-v2:0:18k": {"prompt": 0.008, "completion": 0.0024}, + "anthropic.claude-v2:1:200k": {"prompt": 0.008, "completion": 0.0024}, + "anthropic.claude-3-sonnet-20240229-v1:0": {"prompt": 0.003, "completion": 0.015}, + "anthropic.claude-3-sonnet-20240229-v1:0:28k": {"prompt": 0.003, "completion": 0.015}, + "anthropic.claude-3-sonnet-20240229-v1:0:200k": {"prompt": 0.003, "completion": 0.015}, + "anthropic.claude-3-haiku-20240307-v1:0": {"prompt": 0.00025, "completion": 0.00125}, + "anthropic.claude-3-haiku-20240307-v1:0:48k": {"prompt": 0.00025, "completion": 0.00125}, + "anthropic.claude-3-haiku-20240307-v1:0:200k": {"prompt": 0.00025, "completion": 0.00125}, + # currently (2024-4-29) only available at US West (Oregon) AWS Region. + "anthropic.claude-3-opus-20240229-v1:0": {"prompt": 0.015, "completion": 0.075}, + "cohere.command-text-v14": {"prompt": 0.0015, "completion": 0.0015}, + "cohere.command-text-v14:7:4k": {"prompt": 0.0015, "completion": 0.0015}, + "cohere.command-light-text-v14": {"prompt": 0.0003, "completion": 0.0003}, + "cohere.command-light-text-v14:7:4k": {"prompt": 0.0003, "completion": 0.0003}, + "meta.llama2-13b-chat-v1:0:4k": {"prompt": 0.00075, "completion": 0.001}, + "meta.llama2-13b-chat-v1": {"prompt": 0.00075, "completion": 0.001}, + "meta.llama2-70b-v1": {"prompt": 0.00195, "completion": 0.00256}, + "meta.llama2-70b-v1:0:4k": {"prompt": 0.00195, "completion": 0.00256}, + "meta.llama2-70b-chat-v1": {"prompt": 0.00195, "completion": 0.00256}, + "meta.llama2-70b-chat-v1:0:4k": {"prompt": 0.00195, "completion": 0.00256}, + "meta.llama3-8b-instruct-v1:0": {"prompt": 0.0004, "completion": 0.0006}, + "meta.llama3-70b-instruct-v1:0": {"prompt": 0.00265, "completion": 0.0035}, + "mistral.mistral-7b-instruct-v0:2": {"prompt": 0.00015, "completion": 0.0002}, + "mistral.mixtral-8x7b-instruct-v0:1": {"prompt": 0.00045, "completion": 0.0007}, + "mistral.mistral-large-2402-v1:0": {"prompt": 0.008, "completion": 0.024}, + "ai21.j2-grande-instruct": {"prompt": 0.0125, "completion": 0.0125}, + "ai21.j2-jumbo-instruct": {"prompt": 0.0188, "completion": 0.0188}, + "ai21.j2-mid": {"prompt": 0.0125, "completion": 0.0125}, + "ai21.j2-mid-v1": {"prompt": 0.0125, "completion": 0.0125}, + "ai21.j2-ultra": {"prompt": 0.0188, "completion": 0.0188}, + "ai21.j2-ultra-v1": {"prompt": 0.0188, "completion": 0.0188}, +} + def count_message_tokens(messages, model="gpt-3.5-turbo-0125"): """Return the number of tokens used by a list of messages.""" diff --git a/tests/metagpt/provider/test_bedrock_api.py b/tests/metagpt/provider/test_bedrock_api.py index 54ff5afa4..4760a2db2 100644 --- a/tests/metagpt/provider/test_bedrock_api.py +++ b/tests/metagpt/provider/test_bedrock_api.py @@ -5,7 +5,6 @@ import pytest from metagpt.provider.bedrock.utils import ( NOT_SUUPORT_STREAM_MODELS, SUPPORT_STREAM_MODELS, - get_max_tokens, ) from metagpt.provider.bedrock_api import BedrockLLM from tests.metagpt.provider.mock_llm_config import mock_llm_config_bedrock @@ -17,14 +16,19 @@ from tests.metagpt.provider.req_resp_const import ( # all available model from bedrock models = SUPPORT_STREAM_MODELS | NOT_SUUPORT_STREAM_MODELS messages = [{"role": "user", "content": "Hi!"}] +usage = { + "prompt_tokens": 1000000, + "completion_tokens": 1000000, +} -def mock_bedrock_provider_response(self, *args, **kwargs) -> dict: +def mock_invoke_model(self: BedrockLLM, *args, **kwargs) -> dict: provider = self.config.model.split(".")[0] + self._update_costs(usage, self.config.model) return BEDROCK_PROVIDER_RESPONSE_BODY[provider] -def mock_bedrock_provider_stream_response(self, *args, **kwargs) -> dict: +def mock_invoke_model_stream(self: BedrockLLM, *args, **kwargs) -> dict: # use json object to mock EventStream def dict2bytes(x): return json.dumps(x).encode("utf-8") @@ -43,6 +47,7 @@ def mock_bedrock_provider_stream_response(self, *args, **kwargs) -> dict: response_body_bytes = dict2bytes(BEDROCK_PROVIDER_RESPONSE_BODY[provider]) response_body_stream = {"body": [{"chunk": {"bytes": response_body_bytes}}]} + self._update_costs(usage, self.config.model) return response_body_stream @@ -82,41 +87,23 @@ def bedrock_api(request) -> BedrockLLM: class TestBedrockAPI: def _patch_invoke_model(self, mocker): - mocker.patch("metagpt.provider.bedrock_api.BedrockLLM.invoke_model", mock_bedrock_provider_response) + mocker.patch("metagpt.provider.bedrock_api.BedrockLLM.invoke_model", mock_invoke_model) def _patch_invoke_model_stream(self, mocker): mocker.patch( "metagpt.provider.bedrock_api.BedrockLLM.invoke_model_with_response_stream", - mock_bedrock_provider_stream_response, + mock_invoke_model_stream, ) - def test_const_kwargs(self, bedrock_api: BedrockLLM): - provider = bedrock_api.provider - assert bedrock_api._const_kwargs[provider.max_tokens_field_name] <= get_max_tokens(bedrock_api.config.model) - def test_get_request_body(self, bedrock_api: BedrockLLM): """Ensure request body has correct format""" provider = bedrock_api.provider request_body = json.loads(provider.get_request_body(messages, bedrock_api._const_kwargs)) - assert is_subset(request_body, get_bedrock_request_body(bedrock_api.config.model)) - def test_completion(self, bedrock_api: BedrockLLM, mocker): - self._patch_invoke_model(mocker) - assert bedrock_api.completion(messages) == "Hello World" - - def test_chat_completion_stream(self, bedrock_api: BedrockLLM, mocker): + @pytest.mark.asyncio + async def test_aask(self, bedrock_api: BedrockLLM, mocker): self._patch_invoke_model(mocker) self._patch_invoke_model_stream(mocker) - assert bedrock_api._chat_completion_stream(messages) == "Hello World" - - @pytest.mark.asyncio - async def test_achat_completion_stream(self, bedrock_api: BedrockLLM, mocker): - self._patch_invoke_model_stream(mocker) - self._patch_invoke_model(mocker) - assert await bedrock_api._achat_completion_stream(messages) == "Hello World" - - @pytest.mark.asyncio - async def test_acompletion(self, bedrock_api: BedrockLLM, mocker): - self._patch_invoke_model(mocker) - assert await bedrock_api.acompletion(messages) == "Hello World" + assert await bedrock_api.aask(messages, stream=False) == "Hello World" + assert await bedrock_api.aask(messages, stream=True) == "Hello World"