mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-06-08 15:05:17 +02:00
[feat] support async in bedrockLLM by loop.run_in_executor
This commit is contained in:
parent
38cea1daf2
commit
f77e685bd1
1 changed files with 27 additions and 14 deletions
|
|
@ -1,4 +1,6 @@
|
|||
import asyncio
|
||||
import json
|
||||
from functools import partial
|
||||
from typing import Literal
|
||||
|
||||
import boto3
|
||||
|
|
@ -22,7 +24,6 @@ class BedrockLLM(BaseLLM):
|
|||
self.__client = self.__init_client("bedrock-runtime")
|
||||
self.__provider = get_provider(self.config.model)
|
||||
self.cost_manager = CostManager(token_costs=BEDROCK_TOKEN_COSTS)
|
||||
logger.warning("Amazon bedrock doesn't support asynchronous now")
|
||||
if self.config.model in NOT_SUUPORT_STREAM_MODELS:
|
||||
logger.warning(f"model {self.config.model} doesn't support streaming output!")
|
||||
|
||||
|
|
@ -64,15 +65,21 @@ class BedrockLLM(BaseLLM):
|
|||
]
|
||||
logger.info("\n" + "\n".join(summaries))
|
||||
|
||||
def invoke_model(self, request_body: str) -> dict:
|
||||
response = self.__client.invoke_model(modelId=self.config.model, body=request_body)
|
||||
async def invoke_model(self, request_body: str) -> dict:
|
||||
loop = asyncio.get_running_loop()
|
||||
response = await loop.run_in_executor(
|
||||
None, partial(self.client.invoke_model, modelId=self.config.model, body=request_body)
|
||||
)
|
||||
usage = self._get_usage(response)
|
||||
self._update_costs(usage, self.config.model)
|
||||
response_body = self._get_response_body(response)
|
||||
return response_body
|
||||
|
||||
def invoke_model_with_response_stream(self, request_body: str) -> EventStream:
|
||||
response = self.__client.invoke_model_with_response_stream(modelId=self.config.model, body=request_body)
|
||||
async def invoke_model_with_response_stream(self, request_body: str) -> EventStream:
|
||||
loop = asyncio.get_running_loop()
|
||||
response = await loop.run_in_executor(
|
||||
None, partial(self.client.invoke_model_with_response_stream, modelId=self.config.model, body=request_body)
|
||||
)
|
||||
usage = self._get_usage(response)
|
||||
self._update_costs(usage, self.config.model)
|
||||
return response
|
||||
|
|
@ -97,7 +104,7 @@ class BedrockLLM(BaseLLM):
|
|||
|
||||
async def acompletion(self, messages: list[dict]) -> dict:
|
||||
request_body = self.__provider.get_request_body(messages, self._const_kwargs)
|
||||
response_body = self.invoke_model(request_body)
|
||||
response_body = await self.invoke_model(request_body)
|
||||
return response_body
|
||||
|
||||
async def _achat_completion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT) -> dict:
|
||||
|
|
@ -111,14 +118,8 @@ class BedrockLLM(BaseLLM):
|
|||
return full_text
|
||||
|
||||
request_body = self.__provider.get_request_body(messages, self._const_kwargs, stream=True)
|
||||
|
||||
response = self.invoke_model_with_response_stream(request_body)
|
||||
collected_content = []
|
||||
for event in response["body"]:
|
||||
chunk_text = self.__provider.get_choice_text_from_stream(event)
|
||||
collected_content.append(chunk_text)
|
||||
log_llm_stream(chunk_text)
|
||||
|
||||
stream_response = await self.invoke_model_with_response_stream(request_body)
|
||||
collected_content = await self._get_stream_response_body(stream_response)
|
||||
log_llm_stream("\n")
|
||||
full_text = ("".join(collected_content)).lstrip()
|
||||
return full_text
|
||||
|
|
@ -127,6 +128,18 @@ class BedrockLLM(BaseLLM):
|
|||
response_body = json.loads(response["body"].read())
|
||||
return response_body
|
||||
|
||||
async def _get_stream_response_body(self, stream_response) -> str:
|
||||
def collect_content() -> str:
|
||||
collected_content = []
|
||||
for event in stream_response["body"]:
|
||||
chunk_text = self.__provider.get_choice_text_from_stream(event)
|
||||
collected_content.append(chunk_text)
|
||||
log_llm_stream(chunk_text)
|
||||
return collected_content
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
return await loop.run_in_executor(None, collect_content)
|
||||
|
||||
def _get_usage(self, response) -> dict[str, int]:
|
||||
headers = response.get("ResponseMetadata", {}).get("HTTPHeaders", {})
|
||||
prompt_tokens = int(headers.get("x-amzn-bedrock-input-token-count", 0))
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue