Bedrock provider: run blocking boto3 calls in an executor; add Claude 3.5 Sonnet.

* utils.py / token_counter.py: register "anthropic.claude-3-5-sonnet-20240620-v1:0"
  (value 200000 in the streaming-model table; prompt 0.003, completion 0.015).
* bedrock_api.py: invoke_model and invoke_model_with_response_stream become
  coroutines that hand the boto3 call to loop.run_in_executor, and draining of
  the response stream moves into the new _get_stream_response_body helper.

diff --git a/metagpt/provider/bedrock/utils.py b/metagpt/provider/bedrock/utils.py
index 4f3be47ae..46520d1d5 100644
--- a/metagpt/provider/bedrock/utils.py
+++ b/metagpt/provider/bedrock/utils.py
@@ -27,6 +27,7 @@ SUPPORT_STREAM_MODELS = {
     "anthropic.claude-3-sonnet-20240229-v1:0:28k": 28000,
     "anthropic.claude-3-sonnet-20240229-v1:0:200k": 200000,
     "anthropic.claude-3-haiku-20240307-v1:0": 200000,
+    "anthropic.claude-3-5-sonnet-20240620-v1:0": 200000,
     "anthropic.claude-3-haiku-20240307-v1:0:48k": 48000,
     "anthropic.claude-3-haiku-20240307-v1:0:200k": 200000,
     # currently (2024-4-29) only available at US West (Oregon) AWS Region.
diff --git a/metagpt/provider/bedrock_api.py b/metagpt/provider/bedrock_api.py
index f30d4701e..03954e5c2 100644
--- a/metagpt/provider/bedrock_api.py
+++ b/metagpt/provider/bedrock_api.py
@@ -1,5 +1,7 @@
+import asyncio
 import json
-from typing import Literal
+from functools import partial
+from typing import List, Literal
 
 import boto3
 from botocore.eventstream import EventStream
@@ -22,7 +24,6 @@ class BedrockLLM(BaseLLM):
         self.__client = self.__init_client("bedrock-runtime")
         self.__provider = get_provider(self.config.model)
         self.cost_manager = CostManager(token_costs=BEDROCK_TOKEN_COSTS)
-        logger.warning("Amazon bedrock doesn't support asynchronous now")
         if self.config.model in NOT_SUUPORT_STREAM_MODELS:
             logger.warning(f"model {self.config.model} doesn't support streaming output!")
 
@@ -64,15 +65,23 @@ class BedrockLLM(BaseLLM):
         ]
         logger.info("\n" + "\n".join(summaries))
 
-    def invoke_model(self, request_body: str) -> dict:
-        response = self.__client.invoke_model(modelId=self.config.model, body=request_body)
+    async def invoke_model(self, request_body: str) -> dict:
+        # The boto3 call is synchronous; run it in the default executor so the event loop stays free.
+        loop = asyncio.get_running_loop()
+        response = await loop.run_in_executor(
+            None, partial(self.client.invoke_model, modelId=self.config.model, body=request_body)
+        )
         usage = self._get_usage(response)
         self._update_costs(usage, self.config.model)
         response_body = self._get_response_body(response)
         return response_body
 
-    def invoke_model_with_response_stream(self, request_body: str) -> EventStream:
-        response = self.__client.invoke_model_with_response_stream(modelId=self.config.model, body=request_body)
+    async def invoke_model_with_response_stream(self, request_body: str) -> EventStream:
+        # The boto3 call is synchronous; run it in the default executor so the event loop stays free.
+        loop = asyncio.get_running_loop()
+        response = await loop.run_in_executor(
+            None, partial(self.client.invoke_model_with_response_stream, modelId=self.config.model, body=request_body)
+        )
         usage = self._get_usage(response)
         self._update_costs(usage, self.config.model)
         return response
@@ -97,7 +106,7 @@ class BedrockLLM(BaseLLM):
 
     async def acompletion(self, messages: list[dict]) -> dict:
         request_body = self.__provider.get_request_body(messages, self._const_kwargs)
-        response_body = self.invoke_model(request_body)
+        response_body = await self.invoke_model(request_body)
         return response_body
 
     async def _achat_completion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT) -> dict:
@@ -111,14 +120,8 @@ class BedrockLLM(BaseLLM):
             return full_text
 
         request_body = self.__provider.get_request_body(messages, self._const_kwargs, stream=True)
-
-        response = self.invoke_model_with_response_stream(request_body)
-        collected_content = []
-        for event in response["body"]:
-            chunk_text = self.__provider.get_choice_text_from_stream(event)
-            collected_content.append(chunk_text)
-            log_llm_stream(chunk_text)
-
+        stream_response = await self.invoke_model_with_response_stream(request_body)
+        collected_content = await self._get_stream_response_body(stream_response)
         log_llm_stream("\n")
         full_text = ("".join(collected_content)).lstrip()
         return full_text
@@ -127,6 +130,20 @@ class BedrockLLM(BaseLLM):
         response_body = json.loads(response["body"].read())
         return response_body
 
+    async def _get_stream_response_body(self, stream_response) -> List[str]:
+        """Drain stream_response["body"] in an executor thread and return the text chunks in order."""
+        def collect_content() -> List[str]:
+            # Runs inside the executor, so this synchronous loop does not stall the event loop.
+            collected_content = []
+            for event in stream_response["body"]:
+                chunk_text = self.__provider.get_choice_text_from_stream(event)
+                collected_content.append(chunk_text)
+                log_llm_stream(chunk_text)
+            return collected_content
+
+        loop = asyncio.get_running_loop()
+        return await loop.run_in_executor(None, collect_content)
+
     def _get_usage(self, response) -> dict[str, int]:
         headers = response.get("ResponseMetadata", {}).get("HTTPHeaders", {})
         prompt_tokens = int(headers.get("x-amzn-bedrock-input-token-count", 0))
diff --git a/metagpt/utils/token_counter.py b/metagpt/utils/token_counter.py
index e2b2c230d..0d69fca10 100644
--- a/metagpt/utils/token_counter.py
+++ b/metagpt/utils/token_counter.py
@@ -285,6 +285,7 @@ BEDROCK_TOKEN_COSTS = {
     "anthropic.claude-3-sonnet-20240229-v1:0": {"prompt": 0.003, "completion": 0.015},
     "anthropic.claude-3-sonnet-20240229-v1:0:28k": {"prompt": 0.003, "completion": 0.015},
     "anthropic.claude-3-sonnet-20240229-v1:0:200k": {"prompt": 0.003, "completion": 0.015},
+    "anthropic.claude-3-5-sonnet-20240620-v1:0": {"prompt": 0.003, "completion": 0.015},
     "anthropic.claude-3-haiku-20240307-v1:0": {"prompt": 0.00025, "completion": 0.00125},
     "anthropic.claude-3-haiku-20240307-v1:0:48k": {"prompt": 0.00025, "completion": 0.00125},
     "anthropic.claude-3-haiku-20240307-v1:0:200k": {"prompt": 0.00025, "completion": 0.00125},