This commit is contained in:
usamimeri_renko 2024-04-26 01:29:43 +08:00
parent e9723f4955
commit 6561c7aa7e
2 changed files with 28 additions and 6 deletions

View file

@ -11,14 +11,21 @@ import boto3
@register_provider([LLMType.AMAZON_BEDROCK])
class AmazonBedrockLLM(BaseLLM):
"""
check out:
https://docs.aws.amazon.com/code-library/latest/ug/python_3_bedrock-runtime_code_examples.html
"""
def __init__(self, config: LLMConfig):
self.config = config
self.__client = self.__init_client("bedrock-runtime")
self.__provider = get_provider(self.config.model)
logger.warning("Amazon bedrock doesn't support async now")
logger.warning(
"Amazon bedrock doesn't support asynchronous calls now")
def __init_client(self, service_name: Literal["bedrock-runtime", "bedrock"]):
# access key from https://us-east-1.console.aws.amazon.com/iam
"""initialize boto3 client"""
# access key and secret key from https://us-east-1.console.aws.amazon.com/iam
self.__credentital_kwards = {
"aws_secret_access_key": self.config.secret_key,
"aws_access_key_id": self.config.access_key,
@ -29,6 +36,14 @@ class AmazonBedrockLLM(BaseLLM):
return client
def list_models(self):
"""list all available text-generation models
```shell
ai21.j2-ultra-v1 Support Streaming:False
meta.llama3-70b-instruct-v1:0 Support Streaming:True
```
"""
client = self.__init_client("bedrock")
# only output text-generation models
response = client.list_foundation_models(byOutputModality='TEXT')
@ -38,12 +53,12 @@ class AmazonBedrockLLM(BaseLLM):
@property
def _generate_kwargs(self) -> dict:
# for now only use temperature due to the difference of request body
model_max_tokens = get_max_tokens(self.config.model)
if self.config.max_token > model_max_tokens:
max_tokens = model_max_tokens
else:
max_tokens = self.config.max_token
return {
self.__provider.max_tokens_field_name: max_tokens,
"temperature": self.config.temperature
@ -81,8 +96,12 @@ class AmazonBedrockLLM(BaseLLM):
full_text = ("".join(collected_content)).lstrip()
return full_text
# boto3 don't support support asynchronous calls.
# for asynchronous version of boto3,check out:
# https://aioboto3.readthedocs.io/en/latest/usage.html
# However,aioboto3 doesn't support invoke model
async def acompletion(self, messages: list[dict]):
# Amazon bedrock doesn't support async now
return await self._achat_completion(messages)
async def acompletion_text(self, messages: list[dict], stream: bool = False,

View file

@ -1,5 +1,6 @@
from metagpt.logs import logger
# max_tokens for each model
NOT_SUUPORT_STREAM_MODELS = {
"ai21.j2-grande-instruct": 8000,
"ai21.j2-jumbo-instruct": 8000,
@ -29,7 +30,7 @@ SUPPORT_STREAM_MODELS = {
"mistral.mistral-large-2402-v1:0": 32000,
}
# TODO:use a general function for constructing chat templates.
# TODO:use a more general function for constructing chat templates.
def messages_to_prompt_llama2(messages: list[dict]):
@ -64,6 +65,7 @@ def messages_to_prompt_llama3(messages: list[dict]):
role = message["role"]
content = message["content"]
prompt += GENERAL_TEMPLATE.format(role=role, content=content)
if role != "assistant":
prompt += f"<|start_header_id|>assistant<|end_header_id|>"
@ -77,11 +79,12 @@ def messages_to_prompt_claude(messages: list[dict]):
role = message["role"]
content = message["content"]
prompt += GENERAL_TEMPLATE.format(role=role, content=content)
if role != "assistant":
prompt += f"\n\nAssistant:"
return prompt
def get_max_tokens(model_id) -> int:
return (NOT_SUUPORT_STREAM_MODELS | SUPPORT_STREAM_MODELS)[model_id]