[fix] serval bugs in bedrock LLM

This commit is contained in:
Wei-Jianan 2024-05-17 10:21:16 +00:00
parent 44ea12fe66
commit 1911e27219
3 changed files with 17 additions and 8 deletions

View file

@ -25,4 +25,4 @@ class BaseBedrockProvider(ABC):
def messages_to_prompt(self, messages: list[dict]) -> str:
"""[{"role": "user", "content": msg}] to user: <msg> etc."""
return "\n".join([f"{i['role']}: {i['content']}" for i in messages])
return "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages])

View file

@ -1,5 +1,5 @@
import json
from typing import Literal
from typing import Literal, Tuple
from metagpt.provider.bedrock.base_provider import BaseBedrockProvider
from metagpt.provider.bedrock.utils import (
@ -21,8 +21,19 @@ class MistralProvider(BaseBedrockProvider):
class AnthropicProvider(BaseBedrockProvider):
# See https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html
def get_request_body(self, messages: list[dict], generate_kwargs, *args, **kwargs):
body = json.dumps({"messages": messages, "anthropic_version": "bedrock-2023-05-31", **generate_kwargs})
def _split_system_user_messages(self, messages: list[dict]) -> Tuple[str, list[dict]]:
system_messages = []
user_messages = []
for message in messages:
if message["role"] == "system":
system_messages.append(message)
else:
user_messages.append(message)
return self.messages_to_prompt(system_messages), user_messages
def get_request_body(self, messages: list[dict], generate_kwargs, *args, **kwargs) -> str:
system_message, user_messages = self._split_system_user_messages(messages)
body = json.dumps({"messages": user_messages, "anthropic_version": "bedrock-2023-05-31", "system": system_message, **generate_kwargs})
return body
def _get_completion_from_dict(self, rsp_dict: dict) -> str:

View file

@ -131,10 +131,8 @@ class BedrockLLM(BaseLLM):
headers = response.get("ResponseMetadata", {}).get("HTTPHeaders", {})
prompt_tokens = int(headers.get("x-amzn-bedrock-input-token-count", 0))
completion_tokens = int(headers.get("x-amzn-bedrock-output-token-count", 0))
usage = (
{
usage = {
"prompt_tokens": prompt_tokens,
"completion_tokens": completion_tokens,
},
)
}
return usage