mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-05-15 11:02:36 +02:00
Merge pull request #1342 from Wei-Jianan/feat/bedrock_async
[feat] support async in bedrockLLM by loop.run_in_executor
This commit is contained in:
commit
1b1606a5bd
3 changed files with 30 additions and 15 deletions
|
|
@ -27,6 +27,7 @@ SUPPORT_STREAM_MODELS = {
|
|||
"anthropic.claude-3-sonnet-20240229-v1:0:28k": 28000,
|
||||
"anthropic.claude-3-sonnet-20240229-v1:0:200k": 200000,
|
||||
"anthropic.claude-3-haiku-20240307-v1:0": 200000,
|
||||
"anthropic.claude-3-5-sonnet-20240620-v1:0": 200000,
|
||||
"anthropic.claude-3-haiku-20240307-v1:0:48k": 48000,
|
||||
"anthropic.claude-3-haiku-20240307-v1:0:200k": 200000,
|
||||
# currently (2024-4-29) only available at US West (Oregon) AWS Region.
|
||||
|
|
|
|||
|
|
@ -1,5 +1,7 @@
|
|||
import asyncio
|
||||
import json
|
||||
from typing import Literal
|
||||
from functools import partial
|
||||
from typing import List, Literal
|
||||
|
||||
import boto3
|
||||
from botocore.eventstream import EventStream
|
||||
|
|
@ -22,7 +24,6 @@ class BedrockLLM(BaseLLM):
|
|||
self.__client = self.__init_client("bedrock-runtime")
|
||||
self.__provider = get_provider(self.config.model)
|
||||
self.cost_manager = CostManager(token_costs=BEDROCK_TOKEN_COSTS)
|
||||
logger.warning("Amazon bedrock doesn't support asynchronous now")
|
||||
if self.config.model in NOT_SUUPORT_STREAM_MODELS:
|
||||
logger.warning(f"model {self.config.model} doesn't support streaming output!")
|
||||
|
||||
|
|
@ -64,15 +65,21 @@ class BedrockLLM(BaseLLM):
|
|||
]
|
||||
logger.info("\n" + "\n".join(summaries))
|
||||
|
||||
def invoke_model(self, request_body: str) -> dict:
|
||||
response = self.__client.invoke_model(modelId=self.config.model, body=request_body)
|
||||
async def invoke_model(self, request_body: str) -> dict:
|
||||
loop = asyncio.get_running_loop()
|
||||
response = await loop.run_in_executor(
|
||||
None, partial(self.client.invoke_model, modelId=self.config.model, body=request_body)
|
||||
)
|
||||
usage = self._get_usage(response)
|
||||
self._update_costs(usage, self.config.model)
|
||||
response_body = self._get_response_body(response)
|
||||
return response_body
|
||||
|
||||
def invoke_model_with_response_stream(self, request_body: str) -> EventStream:
|
||||
response = self.__client.invoke_model_with_response_stream(modelId=self.config.model, body=request_body)
|
||||
async def invoke_model_with_response_stream(self, request_body: str) -> EventStream:
|
||||
loop = asyncio.get_running_loop()
|
||||
response = await loop.run_in_executor(
|
||||
None, partial(self.client.invoke_model_with_response_stream, modelId=self.config.model, body=request_body)
|
||||
)
|
||||
usage = self._get_usage(response)
|
||||
self._update_costs(usage, self.config.model)
|
||||
return response
|
||||
|
|
@ -97,7 +104,7 @@ class BedrockLLM(BaseLLM):
|
|||
|
||||
async def acompletion(self, messages: list[dict]) -> dict:
|
||||
request_body = self.__provider.get_request_body(messages, self._const_kwargs)
|
||||
response_body = self.invoke_model(request_body)
|
||||
response_body = await self.invoke_model(request_body)
|
||||
return response_body
|
||||
|
||||
async def _achat_completion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT) -> dict:
|
||||
|
|
@ -111,14 +118,8 @@ class BedrockLLM(BaseLLM):
|
|||
return full_text
|
||||
|
||||
request_body = self.__provider.get_request_body(messages, self._const_kwargs, stream=True)
|
||||
|
||||
response = self.invoke_model_with_response_stream(request_body)
|
||||
collected_content = []
|
||||
for event in response["body"]:
|
||||
chunk_text = self.__provider.get_choice_text_from_stream(event)
|
||||
collected_content.append(chunk_text)
|
||||
log_llm_stream(chunk_text)
|
||||
|
||||
stream_response = await self.invoke_model_with_response_stream(request_body)
|
||||
collected_content = await self._get_stream_response_body(stream_response)
|
||||
log_llm_stream("\n")
|
||||
full_text = ("".join(collected_content)).lstrip()
|
||||
return full_text
|
||||
|
|
@ -127,6 +128,18 @@ class BedrockLLM(BaseLLM):
|
|||
response_body = json.loads(response["body"].read())
|
||||
return response_body
|
||||
|
||||
async def _get_stream_response_body(self, stream_response) -> List[str]:
    """Collect the text chunks of a streaming response off the event loop.

    Args:
        stream_response: Response mapping whose ``"body"`` entry is iterated
            event by event (presumably the botocore ``EventStream`` returned
            by ``invoke_model_with_response_stream`` -- confirm against caller).

    Returns:
        The chunk texts in the order the events were received.
    """

    def collect_content() -> List[str]:
        # Runs in a worker thread: iterating the body is synchronous, so it is
        # kept out of the coroutine itself.
        collected_content = []
        for event in stream_response["body"]:
            chunk_text = self.__provider.get_choice_text_from_stream(event)
            collected_content.append(chunk_text)
            # Emit each chunk as soon as it is decoded.
            log_llm_stream(chunk_text)
        return collected_content

    loop = asyncio.get_running_loop()
    # Passing None selects the loop's default executor.
    return await loop.run_in_executor(None, collect_content)
|
||||
|
||||
def _get_usage(self, response) -> dict[str, int]:
|
||||
headers = response.get("ResponseMetadata", {}).get("HTTPHeaders", {})
|
||||
prompt_tokens = int(headers.get("x-amzn-bedrock-input-token-count", 0))
|
||||
|
|
|
|||
|
|
@ -285,6 +285,7 @@ BEDROCK_TOKEN_COSTS = {
|
|||
"anthropic.claude-3-sonnet-20240229-v1:0": {"prompt": 0.003, "completion": 0.015},
|
||||
"anthropic.claude-3-sonnet-20240229-v1:0:28k": {"prompt": 0.003, "completion": 0.015},
|
||||
"anthropic.claude-3-sonnet-20240229-v1:0:200k": {"prompt": 0.003, "completion": 0.015},
|
||||
"anthropic.claude-3-5-sonnet-20240620-v1:0": {"prompt": 0.003, "completion": 0.015},
|
||||
"anthropic.claude-3-haiku-20240307-v1:0": {"prompt": 0.00025, "completion": 0.00125},
|
||||
"anthropic.claude-3-haiku-20240307-v1:0:48k": {"prompt": 0.00025, "completion": 0.00125},
|
||||
"anthropic.claude-3-haiku-20240307-v1:0:200k": {"prompt": 0.00025, "completion": 0.00125},
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue