[feat] support async in bedrockLLM by loop.run_in_executor

This commit is contained in:
Wei-Jianan 2024-06-13 05:04:47 +00:00
parent 38cea1daf2
commit f77e685bd1

View file

@ -1,4 +1,6 @@
import asyncio
import json
from functools import partial
from typing import Literal
import boto3
@ -22,7 +24,6 @@ class BedrockLLM(BaseLLM):
self.__client = self.__init_client("bedrock-runtime")
self.__provider = get_provider(self.config.model)
self.cost_manager = CostManager(token_costs=BEDROCK_TOKEN_COSTS)
logger.warning("Amazon bedrock doesn't support asynchronous now")
if self.config.model in NOT_SUUPORT_STREAM_MODELS:
logger.warning(f"model {self.config.model} doesn't support streaming output!")
@ -64,15 +65,21 @@ class BedrockLLM(BaseLLM):
]
logger.info("\n" + "\n".join(summaries))
def invoke_model(self, request_body: str) -> dict:
response = self.__client.invoke_model(modelId=self.config.model, body=request_body)
async def invoke_model(self, request_body: str) -> dict:
loop = asyncio.get_running_loop()
response = await loop.run_in_executor(
None, partial(self.client.invoke_model, modelId=self.config.model, body=request_body)
)
usage = self._get_usage(response)
self._update_costs(usage, self.config.model)
response_body = self._get_response_body(response)
return response_body
def invoke_model_with_response_stream(self, request_body: str) -> EventStream:
response = self.__client.invoke_model_with_response_stream(modelId=self.config.model, body=request_body)
async def invoke_model_with_response_stream(self, request_body: str) -> EventStream:
loop = asyncio.get_running_loop()
response = await loop.run_in_executor(
None, partial(self.client.invoke_model_with_response_stream, modelId=self.config.model, body=request_body)
)
usage = self._get_usage(response)
self._update_costs(usage, self.config.model)
return response
@ -97,7 +104,7 @@ class BedrockLLM(BaseLLM):
async def acompletion(self, messages: list[dict]) -> dict:
request_body = self.__provider.get_request_body(messages, self._const_kwargs)
response_body = self.invoke_model(request_body)
response_body = await self.invoke_model(request_body)
return response_body
async def _achat_completion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT) -> dict:
@ -111,14 +118,8 @@ class BedrockLLM(BaseLLM):
return full_text
request_body = self.__provider.get_request_body(messages, self._const_kwargs, stream=True)
response = self.invoke_model_with_response_stream(request_body)
collected_content = []
for event in response["body"]:
chunk_text = self.__provider.get_choice_text_from_stream(event)
collected_content.append(chunk_text)
log_llm_stream(chunk_text)
stream_response = await self.invoke_model_with_response_stream(request_body)
collected_content = await self._get_stream_response_body(stream_response)
log_llm_stream("\n")
full_text = ("".join(collected_content)).lstrip()
return full_text
@ -127,6 +128,18 @@ class BedrockLLM(BaseLLM):
response_body = json.loads(response["body"].read())
return response_body
async def _get_stream_response_body(self, stream_response) -> str:
def collect_content() -> str:
collected_content = []
for event in stream_response["body"]:
chunk_text = self.__provider.get_choice_text_from_stream(event)
collected_content.append(chunk_text)
log_llm_stream(chunk_text)
return collected_content
loop = asyncio.get_running_loop()
return await loop.run_in_executor(None, collect_content)
def _get_usage(self, response) -> dict[str, int]:
headers = response.get("ResponseMetadata", {}).get("HTTPHeaders", {})
prompt_tokens = int(headers.get("x-amzn-bedrock-input-token-count", 0))