This commit is contained in:
usamimeri_renko 2024-04-26 16:08:39 +08:00
parent cafe666bfd
commit 4c394a1cac
7 changed files with 198 additions and 29 deletions

View file

@ -1,4 +1,5 @@
from typing import Literal
import json
from metagpt.const import USE_CONFIG_TIMEOUT
from metagpt.provider.llm_provider_registry import register_provider
from metagpt.configs.llm_config import LLMConfig, LLMType
@ -8,9 +9,10 @@ from metagpt.provider.bedrock.bedrock_provider import get_provider
from metagpt.provider.bedrock.utils import NOT_SUUPORT_STREAM_MODELS, get_max_tokens
try:
import boto3
from botocore.response import StreamingBody
except ImportError:
raise ImportError(
"boto3 not found! please install it by `pip install boto3` first ")
"boto3 not found! please install it by `pip install boto3` ")
@register_provider([LLMType.AMAZON_BEDROCK])
@ -25,7 +27,7 @@ class AmazonBedrockLLM(BaseLLM):
self.__client = self.__init_client("bedrock-runtime")
self.__provider = get_provider(self.config.model)
logger.warning(
"Amazon bedrock doesn't support asynchronous calls now")
"Amazon bedrock doesn't support asynchronous now")
def __init_client(self, service_name: Literal["bedrock-runtime", "bedrock"]):
"""initialize boto3 client"""
@ -39,6 +41,12 @@ class AmazonBedrockLLM(BaseLLM):
client = session.client(service_name)
return client
def _get_client(self):
return self.__client
def _get_provider(self):
return self.__provider
def list_models(self):
"""list all available text-generation models
@ -55,6 +63,19 @@ class AmazonBedrockLLM(BaseLLM):
for summary in response["modelSummaries"]]
logger.info("\n"+"\n".join(summaries))
def invoke_model(self, request_body) -> dict:
response = self.__client.invoke_model(
modelId=self.config.model, body=request_body
)
response_body = self._get_response_body(response)
return response_body
def invoke_model_with_response_stream(self, request_body) -> StreamingBody:
response = self.__client.invoke_model_with_response_stream(
modelId=self.config.model, body=request_body
)
return response
@property
def _generate_kwargs(self) -> dict:
model_max_tokens = get_max_tokens(self.config.model)
@ -71,10 +92,8 @@ class AmazonBedrockLLM(BaseLLM):
def completion(self, messages: list[dict]) -> str:
request_body = self.__provider.get_request_body(
messages, **self._generate_kwargs)
response = self.__client.invoke_model(
modelId=self.config.model, body=request_body
)
completions = self.__provider.get_choice_text(response)
response_body = self.invoke_model(request_body)
completions = self.__provider.get_choice_text(response_body)
return completions
def _chat_completion_stream(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT) -> str:
@ -86,9 +105,7 @@ class AmazonBedrockLLM(BaseLLM):
request_body = self.__provider.get_request_body(
messages, **self._generate_kwargs)
response = self.__client.invoke_model_with_response_stream(
modelId=self.config.model, body=request_body
)
response = self.invoke_model_with_response_stream(request_body)
collected_content = []
for event in response["body"]:
@ -119,3 +136,7 @@ class AmazonBedrockLLM(BaseLLM):
async def _achat_completion_stream(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT):
return self._chat_completion_stream(messages)
def _get_response_body(self, response) -> dict:
response_body = json.loads(response["body"].read())
return response_body

View file

@ -15,8 +15,7 @@ class BaseBedrockProvider(ABC):
{"prompt": self.messages_to_prompt(messages), **generate_kwargs})
return body
def get_choice_text(self, response) -> str:
response_body = self._get_response_body_json(response)
def get_choice_text(self, response_body: dict) -> str:
completions = self._get_completion_from_dict(response_body)
return completions
@ -25,10 +24,6 @@ class BaseBedrockProvider(ABC):
completions = self._get_completion_from_dict(rsp_dict)
return completions
def _get_response_body_json(self, response):
response_body = json.loads(response["body"].read())
return response_body
def messages_to_prompt(self, messages: list[dict]) -> str:
"""[{"role": "user", "content": msg}] to user: <msg> etc."""
return "\n".join([f"{i['role']}: {i['content']}" for i in messages])

View file

@ -1,7 +1,7 @@
import json
from typing import Literal
from metagpt.provider.bedrock.base_provider import BaseBedrockProvider
from metagpt.provider.bedrock.utils import messages_to_prompt_llama2, messages_to_prompt_llama3, messages_to_prompt_claude
from metagpt.provider.bedrock.utils import messages_to_prompt_llama2, messages_to_prompt_llama3
class MistralProvider(BaseBedrockProvider):
@ -17,9 +17,6 @@ class MistralProvider(BaseBedrockProvider):
class AnthropicProvider(BaseBedrockProvider):
# See https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html
def messages_to_prompt(self, messages: list[dict]) -> str:
return messages_to_prompt_claude(messages)
def get_request_body(self, messages: list[dict], **generate_kwargs):
body = json.dumps(
{"messages": messages, "anthropic_version": "bedrock-2023-05-31", **generate_kwargs})

View file

@ -33,7 +33,7 @@ SUPPORT_STREAM_MODELS = {
# TODO:use a more general function for constructing chat templates.
def messages_to_prompt_llama2(messages: list[dict]):
def messages_to_prompt_llama2(messages: list[dict]) -> str:
BOS, EOS = "<s>", "</s>"
B_INST, E_INST = "[INST]", "[/INST]"
B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
@ -56,7 +56,7 @@ def messages_to_prompt_llama2(messages: list[dict]):
return prompt
def messages_to_prompt_llama3(messages: list[dict]):
def messages_to_prompt_llama3(messages: list[dict]) -> str:
BOS, EOS = "<|begin_of_text|>", "<|eot_id|>"
GENERAL_TEMPLATE = "<|start_header_id|>{role}<|end_header_id|>\n\n{content}<|eot_id|>"
@ -72,7 +72,7 @@ def messages_to_prompt_llama3(messages: list[dict]):
return prompt
def messages_to_prompt_claude(messages: list[dict]):
def messages_to_prompt_claude2(messages: list[dict]) -> str:
GENERAL_TEMPLATE = "\n\n{role}: {content}"
prompt = ""
for message in messages: