Merge pull request #1342 from Wei-Jianan/feat/bedrock_async

[feat] support async in bedrockLLM by loop.run_in_executor
This commit is contained in:
Alexander Wu 2024-07-15 15:56:05 +08:00 committed by GitHub
commit 1b1606a5bd
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 30 additions and 15 deletions

View file

@ -27,6 +27,7 @@ SUPPORT_STREAM_MODELS = {
"anthropic.claude-3-sonnet-20240229-v1:0:28k": 28000,
"anthropic.claude-3-sonnet-20240229-v1:0:200k": 200000,
"anthropic.claude-3-haiku-20240307-v1:0": 200000,
"anthropic.claude-3-5-sonnet-20240620-v1:0": 200000,
"anthropic.claude-3-haiku-20240307-v1:0:48k": 48000,
"anthropic.claude-3-haiku-20240307-v1:0:200k": 200000,
# currently (2024-4-29) only available at US West (Oregon) AWS Region.

View file

@ -1,5 +1,7 @@
import asyncio
import json
from typing import Literal
from functools import partial
from typing import List, Literal
import boto3
from botocore.eventstream import EventStream
@ -22,7 +24,6 @@ class BedrockLLM(BaseLLM):
self.__client = self.__init_client("bedrock-runtime")
self.__provider = get_provider(self.config.model)
self.cost_manager = CostManager(token_costs=BEDROCK_TOKEN_COSTS)
logger.warning("Amazon bedrock doesn't support asynchronous now")
if self.config.model in NOT_SUUPORT_STREAM_MODELS:
logger.warning(f"model {self.config.model} doesn't support streaming output!")
@ -64,15 +65,21 @@ class BedrockLLM(BaseLLM):
]
logger.info("\n" + "\n".join(summaries))
def invoke_model(self, request_body: str) -> dict:
response = self.__client.invoke_model(modelId=self.config.model, body=request_body)
async def invoke_model(self, request_body: str) -> dict:
loop = asyncio.get_running_loop()
response = await loop.run_in_executor(
None, partial(self.client.invoke_model, modelId=self.config.model, body=request_body)
)
usage = self._get_usage(response)
self._update_costs(usage, self.config.model)
response_body = self._get_response_body(response)
return response_body
def invoke_model_with_response_stream(self, request_body: str) -> EventStream:
response = self.__client.invoke_model_with_response_stream(modelId=self.config.model, body=request_body)
async def invoke_model_with_response_stream(self, request_body: str) -> EventStream:
loop = asyncio.get_running_loop()
response = await loop.run_in_executor(
None, partial(self.client.invoke_model_with_response_stream, modelId=self.config.model, body=request_body)
)
usage = self._get_usage(response)
self._update_costs(usage, self.config.model)
return response
@ -97,7 +104,7 @@ class BedrockLLM(BaseLLM):
async def acompletion(self, messages: list[dict]) -> dict:
request_body = self.__provider.get_request_body(messages, self._const_kwargs)
response_body = self.invoke_model(request_body)
response_body = await self.invoke_model(request_body)
return response_body
async def _achat_completion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT) -> dict:
@ -111,14 +118,8 @@ class BedrockLLM(BaseLLM):
return full_text
request_body = self.__provider.get_request_body(messages, self._const_kwargs, stream=True)
response = self.invoke_model_with_response_stream(request_body)
collected_content = []
for event in response["body"]:
chunk_text = self.__provider.get_choice_text_from_stream(event)
collected_content.append(chunk_text)
log_llm_stream(chunk_text)
stream_response = await self.invoke_model_with_response_stream(request_body)
collected_content = await self._get_stream_response_body(stream_response)
log_llm_stream("\n")
full_text = ("".join(collected_content)).lstrip()
return full_text
@ -127,6 +128,18 @@ class BedrockLLM(BaseLLM):
response_body = json.loads(response["body"].read())
return response_body
async def _get_stream_response_body(self, stream_response) -> List[str]:
def collect_content() -> str:
collected_content = []
for event in stream_response["body"]:
chunk_text = self.__provider.get_choice_text_from_stream(event)
collected_content.append(chunk_text)
log_llm_stream(chunk_text)
return collected_content
loop = asyncio.get_running_loop()
return await loop.run_in_executor(None, collect_content)
def _get_usage(self, response) -> dict[str, int]:
headers = response.get("ResponseMetadata", {}).get("HTTPHeaders", {})
prompt_tokens = int(headers.get("x-amzn-bedrock-input-token-count", 0))

View file

@ -285,6 +285,7 @@ BEDROCK_TOKEN_COSTS = {
"anthropic.claude-3-sonnet-20240229-v1:0": {"prompt": 0.003, "completion": 0.015},
"anthropic.claude-3-sonnet-20240229-v1:0:28k": {"prompt": 0.003, "completion": 0.015},
"anthropic.claude-3-sonnet-20240229-v1:0:200k": {"prompt": 0.003, "completion": 0.015},
"anthropic.claude-3-5-sonnet-20240620-v1:0": {"prompt": 0.003, "completion": 0.015},
"anthropic.claude-3-haiku-20240307-v1:0": {"prompt": 0.00025, "completion": 0.00125},
"anthropic.claude-3-haiku-20240307-v1:0:48k": {"prompt": 0.00025, "completion": 0.00125},
"anthropic.claude-3-haiku-20240307-v1:0:200k": {"prompt": 0.00025, "completion": 0.00125},