diff --git a/metagpt/provider/bedrock/base_provider.py b/metagpt/provider/bedrock/base_provider.py index 0d13ae938..ebc55483b 100644 --- a/metagpt/provider/bedrock/base_provider.py +++ b/metagpt/provider/bedrock/base_provider.py @@ -25,4 +25,4 @@ class BaseBedrockProvider(ABC): def messages_to_prompt(self, messages: list[dict]) -> str: """[{"role": "user", "content": msg}] to user: etc.""" - return "\n".join([f"{i['role']}: {i['content']}" for i in messages]) + return "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages]) diff --git a/metagpt/provider/bedrock/bedrock_provider.py b/metagpt/provider/bedrock/bedrock_provider.py index ff1d88a47..06a0029bd 100644 --- a/metagpt/provider/bedrock/bedrock_provider.py +++ b/metagpt/provider/bedrock/bedrock_provider.py @@ -1,5 +1,5 @@ import json -from typing import Literal +from typing import Literal, Tuple from metagpt.provider.bedrock.base_provider import BaseBedrockProvider from metagpt.provider.bedrock.utils import ( @@ -21,8 +21,19 @@ class MistralProvider(BaseBedrockProvider): class AnthropicProvider(BaseBedrockProvider): # See https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html - def get_request_body(self, messages: list[dict], generate_kwargs, *args, **kwargs): - body = json.dumps({"messages": messages, "anthropic_version": "bedrock-2023-05-31", **generate_kwargs}) + def _split_system_user_messages(self, messages: list[dict]) -> Tuple[str, list[dict]]: + system_messages = [] + user_messages = [] + for message in messages: + if message["role"] == "system": + system_messages.append(message) + else: + user_messages.append(message) + return self.messages_to_prompt(system_messages), user_messages + + def get_request_body(self, messages: list[dict], generate_kwargs, *args, **kwargs) -> str: + system_message, user_messages = self._split_system_user_messages(messages) + body = json.dumps({"messages": user_messages, "anthropic_version": "bedrock-2023-05-31", "system": system_message, **generate_kwargs}) return body def _get_completion_from_dict(self, rsp_dict: dict) -> str: diff --git a/metagpt/provider/bedrock_api.py b/metagpt/provider/bedrock_api.py index d192a5478..b4a5f267f 100644 --- a/metagpt/provider/bedrock_api.py +++ b/metagpt/provider/bedrock_api.py @@ -131,10 +131,8 @@ class BedrockLLM(BaseLLM): headers = response.get("ResponseMetadata", {}).get("HTTPHeaders", {}) prompt_tokens = int(headers.get("x-amzn-bedrock-input-token-count", 0)) completion_tokens = int(headers.get("x-amzn-bedrock-output-token-count", 0)) - usage = ( - { + usage = { "prompt_tokens": prompt_tokens, "completion_tokens": completion_tokens, - }, - ) + } return usage