diff --git a/metagpt/provider/bedrock/amazon_bedrock_api.py b/metagpt/provider/bedrock/amazon_bedrock_api.py index d8aaed8e9..123495da5 100644 --- a/metagpt/provider/bedrock/amazon_bedrock_api.py +++ b/metagpt/provider/bedrock/amazon_bedrock_api.py @@ -51,6 +51,7 @@ class AmazonBedrockLLM(BaseLLM): modelId=self.config.model, body=request_body ) completions = self.provider.get_choice_text(response) + log_llm_stream(completions) return completions def _chat_completion_stream(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT): @@ -59,11 +60,10 @@ class AmazonBedrockLLM(BaseLLM): response = self.__client.invoke_model_with_response_stream( modelId=self.config.model, body=request_body ) - collected_content = [] - for event in response.get("body"): - chunk_text = json.loads(event["chunk"]["bytes"])[ - "outputs"][0]["text"] + collected_content = [] + for event in response["body"]: + chunk_text = self.provider.get_choice_text_from_stream(event) collected_content.append(chunk_text) log_llm_stream(chunk_text) @@ -84,7 +84,11 @@ class AmazonBedrockLLM(BaseLLM): if __name__ == '__main__': from .config import my_config - prompt = "write an essay for living on mars in 1000 word" - messages = [{"role": "user", "content": prompt}] + messages = [ + {"role": "system", "content": "your name is Bob"}, + {"role": "user", "content": "hello"}, + {"role": "assistant", "content": "hello,my friend"}, + {"role": "user", "content": "What is your name?"}] llm = AmazonBedrockLLM(my_config) + llm.completion(messages) llm._chat_completion_stream(messages) diff --git a/metagpt/provider/bedrock/base_provider.py b/metagpt/provider/bedrock/base_provider.py index 9a6f9659c..3ecd5789a 100644 --- a/metagpt/provider/bedrock/base_provider.py +++ b/metagpt/provider/bedrock/base_provider.py @@ -7,16 +7,16 @@ class BaseBedrockProvider(object): return json.dumps({"prompt": self.messages_to_prompt(messages)} | generate_kwargs) def get_choice_text(self, response) -> str: - response_body = 
json.loads(response["body"].read()) + response_body = self._get_response_body_json(response) completions = response_body["outputs"][0]['text'] return completions + def get_choice_text_from_stream(self, event): + return json.loads(event["chunk"]["bytes"])["outputs"][0]["text"] + + def _get_response_body_json(self, response): + return json.loads(response["body"].read()) + def messages_to_prompt(self, messages: list[dict]): """[{"role": "user", "content": msg}] to user: <msg> etc.""" return "\n".join([f"{i['role']}: {i['content']}" for i in messages]) - - def format_prompt(self, prompt: str) -> str: - return prompt - - def format_messages(self, messages: list[dict]) -> list[dict]: - return messages diff --git a/metagpt/provider/bedrock/bedrock_provider.py b/metagpt/provider/bedrock/bedrock_provider.py index 3ae84d8c3..a50f9abed 100644 --- a/metagpt/provider/bedrock/bedrock_provider.py +++ b/metagpt/provider/bedrock/bedrock_provider.py @@ -1,18 +1,16 @@ -from metagpt.provider.bedrock.base_provider import BaseBedrockProvider import json +from metagpt.provider.bedrock.base_provider import BaseBedrockProvider +from metagpt.provider.bedrock.utils import messages_to_prompt_llama class MistralProvider(BaseBedrockProvider): - - def format_prompt(self, prompt: str) -> str: - # for mixtral and llama - return f"[INST]{prompt}[/INST]" - - def get_request_body(self, messages, **generate_kwargs): - return json.dumps({"prompt": self.format_prompt(self.messages_to_prompt(messages))} | generate_kwargs) + # See https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-mistral.html + def messages_to_prompt(self, messages: list[dict]): + return messages_to_prompt_llama(messages) class AnthropicProvider(BaseBedrockProvider): + # See https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html pass @@ -21,7 +19,16 @@ class CohereProvider(BaseBedrockProvider): class MetaProvider(BaseBedrockProvider): - pass + def messages_to_prompt(self, 
messages: list[dict]): + return messages_to_prompt_llama(messages) + + def get_choice_text(self, response) -> str: + response_body = self._get_response_body_json(response) + completions = response_body['generation'] + return completions + + def get_choice_text_from_stream(self, event): + return json.loads(event["chunk"]["bytes"])["generation"] class Ai21Provider(BaseBedrockProvider): @@ -29,7 +36,8 @@ class Ai21Provider(BaseBedrockProvider): PROVIDERS = { - "mistral": MistralProvider() + "mistral": MistralProvider(), + "meta": MetaProvider(), } NOT_SUUPORT_STREAM_MODELS = { diff --git a/metagpt/provider/bedrock/utils.py b/metagpt/provider/bedrock/utils.py new file mode 100644 index 000000000..7352a15a0 --- /dev/null +++ b/metagpt/provider/bedrock/utils.py @@ -0,0 +1,25 @@ +from metagpt.logs import logger + + +def messages_to_prompt_llama(messages: list[dict]): + BOS, EOS = "<s>", "</s>" + B_INST, E_INST = "[INST]", "[/INST]" + B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n" + + prompt = f"{BOS}" + for message in messages: + role = message["role"] + content = message["content"] + if role == "system": + prompt += f"{B_SYS} {content} {E_SYS}" + elif role == "user": + prompt += f"{B_INST} {content} {E_INST}" + elif role == "assistant": + prompt += f"{content}" + else: + logger.warning( + f"Unknown role name {role} when formatting messages") + prompt += f"{content}" + + return prompt +