From 1911e272190d87d089ab36e4f8ba9e888c1caa9f Mon Sep 17 00:00:00 2001 From: Wei-Jianan Date: Fri, 17 May 2024 10:21:16 +0000 Subject: [PATCH 1/3] [fix] several bugs in bedrock LLM --- metagpt/provider/bedrock/base_provider.py | 2 +- metagpt/provider/bedrock/bedrock_provider.py | 17 ++++++++++++++--- metagpt/provider/bedrock_api.py | 6 ++---- 3 files changed, 17 insertions(+), 8 deletions(-) diff --git a/metagpt/provider/bedrock/base_provider.py b/metagpt/provider/bedrock/base_provider.py index 0d13ae938..ebc55483b 100644 --- a/metagpt/provider/bedrock/base_provider.py +++ b/metagpt/provider/bedrock/base_provider.py @@ -25,4 +25,4 @@ class BaseBedrockProvider(ABC): def messages_to_prompt(self, messages: list[dict]) -> str: """[{"role": "user", "content": msg}] to user: etc.""" - return "\n".join([f"{i['role']}: {i['content']}" for i in messages]) + return "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages]) diff --git a/metagpt/provider/bedrock/bedrock_provider.py b/metagpt/provider/bedrock/bedrock_provider.py index ff1d88a47..06a0029bd 100644 --- a/metagpt/provider/bedrock/bedrock_provider.py +++ b/metagpt/provider/bedrock/bedrock_provider.py @@ -1,5 +1,5 @@ import json -from typing import Literal +from typing import Literal, Tuple from metagpt.provider.bedrock.base_provider import BaseBedrockProvider from metagpt.provider.bedrock.utils import ( @@ -21,8 +21,19 @@ class MistralProvider(BaseBedrockProvider): class AnthropicProvider(BaseBedrockProvider): # See https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html - def get_request_body(self, messages: list[dict], generate_kwargs, *args, **kwargs): - body = json.dumps({"messages": messages, "anthropic_version": "bedrock-2023-05-31", **generate_kwargs}) + def _split_system_user_messages(self, messages: list[dict]) -> Tuple[str, list[dict]]: + system_messages = [] + user_messages = [] + for message in messages: + if message["role"] == "system": 
system_messages.append(message) + else: + user_messages.append(message) + return self.messages_to_prompt(system_messages), user_messages + + def get_request_body(self, messages: list[dict], generate_kwargs, *args, **kwargs) -> str: + system_message, user_messages = self._split_system_user_messages(messages) + body = json.dumps({"messages": user_messages, "anthropic_version": "bedrock-2023-05-31", "system": system_message, **generate_kwargs}) return body def _get_completion_from_dict(self, rsp_dict: dict) -> str: diff --git a/metagpt/provider/bedrock_api.py b/metagpt/provider/bedrock_api.py index d192a5478..b4a5f267f 100644 --- a/metagpt/provider/bedrock_api.py +++ b/metagpt/provider/bedrock_api.py @@ -131,10 +131,8 @@ class BedrockLLM(BaseLLM): headers = response.get("ResponseMetadata", {}).get("HTTPHeaders", {}) prompt_tokens = int(headers.get("x-amzn-bedrock-input-token-count", 0)) completion_tokens = int(headers.get("x-amzn-bedrock-output-token-count", 0)) - usage = ( - { + usage = { "prompt_tokens": prompt_tokens, "completion_tokens": completion_tokens, - }, - ) + } return usage From 3338c304b5b7e09dc68740dab080560a9233d0e0 Mon Sep 17 00:00:00 2001 From: Wei-Jianan Date: Fri, 17 May 2024 10:34:10 +0000 Subject: [PATCH 2/3] [fix] max token problem in llama2 --- metagpt/provider/bedrock/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/metagpt/provider/bedrock/utils.py b/metagpt/provider/bedrock/utils.py index ee31da1b9..4f3be47ae 100644 --- a/metagpt/provider/bedrock/utils.py +++ b/metagpt/provider/bedrock/utils.py @@ -39,8 +39,8 @@ SUPPORT_STREAM_MODELS = { "meta.llama2-13b-chat-v1": 2000, "meta.llama2-70b-v1": 4000, "meta.llama2-70b-v1:0:4k": 4000, - "meta.llama2-70b-chat-v1": 4000, - "meta.llama2-70b-chat-v1:0:4k": 4000, + "meta.llama2-70b-chat-v1": 2000, + "meta.llama2-70b-chat-v1:0:4k": 2000, "meta.llama3-8b-instruct-v1:0": 2000, "meta.llama3-70b-instruct-v1:0": 2000, "mistral.mistral-7b-instruct-v0:2": 32000, From 
e7c7b44a0b45e40fd78c644576bbe23689ba3605 Mon Sep 17 00:00:00 2001 From: Wei-Jianan Date: Fri, 17 May 2024 18:44:05 +0800 Subject: [PATCH 3/3] [format] --- metagpt/provider/bedrock/bedrock_provider.py | 9 ++++++++- metagpt/provider/bedrock_api.py | 6 +++--- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/metagpt/provider/bedrock/bedrock_provider.py b/metagpt/provider/bedrock/bedrock_provider.py index 06a0029bd..1236bf56b 100644 --- a/metagpt/provider/bedrock/bedrock_provider.py +++ b/metagpt/provider/bedrock/bedrock_provider.py @@ -33,7 +33,14 @@ class AnthropicProvider(BaseBedrockProvider): def get_request_body(self, messages: list[dict], generate_kwargs, *args, **kwargs) -> str: system_message, user_messages = self._split_system_user_messages(messages) - body = json.dumps({"messages": user_messages, "anthropic_version": "bedrock-2023-05-31", "system": system_message, **generate_kwargs}) + body = json.dumps( + { + "messages": user_messages, + "anthropic_version": "bedrock-2023-05-31", + "system": system_message, + **generate_kwargs, + } + ) return body def _get_completion_from_dict(self, rsp_dict: dict) -> str: diff --git a/metagpt/provider/bedrock_api.py b/metagpt/provider/bedrock_api.py index b4a5f267f..f30d4701e 100644 --- a/metagpt/provider/bedrock_api.py +++ b/metagpt/provider/bedrock_api.py @@ -132,7 +132,7 @@ class BedrockLLM(BaseLLM): prompt_tokens = int(headers.get("x-amzn-bedrock-input-token-count", 0)) completion_tokens = int(headers.get("x-amzn-bedrock-output-token-count", 0)) usage = { - "prompt_tokens": prompt_tokens, - "completion_tokens": completion_tokens, - } + "prompt_tokens": prompt_tokens, + "completion_tokens": completion_tokens, + } return usage