From f77e685bd19420f126eafa8b1705e5300735320f Mon Sep 17 00:00:00 2001 From: Wei-Jianan Date: Thu, 13 Jun 2024 05:04:47 +0000 Subject: [PATCH] [feat] support async in bedrockLLM by loop.run_in_executor --- metagpt/provider/bedrock_api.py | 41 ++++++++++++++++++++++----------- 1 file changed, 27 insertions(+), 14 deletions(-) diff --git a/metagpt/provider/bedrock_api.py b/metagpt/provider/bedrock_api.py index f30d4701e..005488ecf 100644 --- a/metagpt/provider/bedrock_api.py +++ b/metagpt/provider/bedrock_api.py @@ -1,4 +1,6 @@ +import asyncio import json +from functools import partial from typing import Literal import boto3 @@ -22,7 +24,6 @@ class BedrockLLM(BaseLLM): self.__client = self.__init_client("bedrock-runtime") self.__provider = get_provider(self.config.model) self.cost_manager = CostManager(token_costs=BEDROCK_TOKEN_COSTS) - logger.warning("Amazon bedrock doesn't support asynchronous now") if self.config.model in NOT_SUUPORT_STREAM_MODELS: logger.warning(f"model {self.config.model} doesn't support streaming output!") @@ -64,15 +65,21 @@ class BedrockLLM(BaseLLM): ] logger.info("\n" + "\n".join(summaries)) - def invoke_model(self, request_body: str) -> dict: - response = self.__client.invoke_model(modelId=self.config.model, body=request_body) + async def invoke_model(self, request_body: str) -> dict: + loop = asyncio.get_running_loop() + response = await loop.run_in_executor( + None, partial(self.client.invoke_model, modelId=self.config.model, body=request_body) + ) usage = self._get_usage(response) self._update_costs(usage, self.config.model) response_body = self._get_response_body(response) return response_body - def invoke_model_with_response_stream(self, request_body: str) -> EventStream: - response = self.__client.invoke_model_with_response_stream(modelId=self.config.model, body=request_body) + async def invoke_model_with_response_stream(self, request_body: str) -> EventStream: + loop = asyncio.get_running_loop() + response = await loop.run_in_executor( + None, partial(self.client.invoke_model_with_response_stream, modelId=self.config.model, body=request_body) + ) usage = self._get_usage(response) self._update_costs(usage, self.config.model) return response @@ -97,7 +104,7 @@ class BedrockLLM(BaseLLM): async def acompletion(self, messages: list[dict]) -> dict: request_body = self.__provider.get_request_body(messages, self._const_kwargs) - response_body = self.invoke_model(request_body) + response_body = await self.invoke_model(request_body) return response_body async def _achat_completion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT) -> dict: @@ -111,14 +118,8 @@ class BedrockLLM(BaseLLM): return full_text request_body = self.__provider.get_request_body(messages, self._const_kwargs, stream=True) - - response = self.invoke_model_with_response_stream(request_body) - collected_content = [] - for event in response["body"]: - chunk_text = self.__provider.get_choice_text_from_stream(event) - collected_content.append(chunk_text) - log_llm_stream(chunk_text) - + stream_response = await self.invoke_model_with_response_stream(request_body) + collected_content = await self._get_stream_response_body(stream_response) log_llm_stream("\n") full_text = ("".join(collected_content)).lstrip() return full_text @@ -127,6 +128,18 @@ class BedrockLLM(BaseLLM): response_body = json.loads(response["body"].read()) return response_body + async def _get_stream_response_body(self, stream_response) -> str: + def collect_content() -> str: + collected_content = [] + for event in stream_response["body"]: + chunk_text = self.__provider.get_choice_text_from_stream(event) + collected_content.append(chunk_text) + log_llm_stream(chunk_text) + return collected_content + + loop = asyncio.get_running_loop() + return await loop.run_in_executor(None, collect_content) + def _get_usage(self, response) -> dict[str, int]: headers = response.get("ResponseMetadata", {}).get("HTTPHeaders", {}) prompt_tokens = int(headers.get("x-amzn-bedrock-input-token-count", 0))