Merge pull request #1282 from Wei-Jianan/fix/bedrock_bugs

Fix/bedrock bugs
This commit is contained in:
Alexander Wu 2024-05-17 19:15:42 +08:00 committed by GitHub
commit 99405cca20
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 28 additions and 12 deletions

View file

@ -25,4 +25,4 @@ class BaseBedrockProvider(ABC):
def messages_to_prompt(self, messages: list[dict]) -> str:
"""[{"role": "user", "content": msg}] to user: <msg> etc."""
return "\n".join([f"{i['role']}: {i['content']}" for i in messages])
return "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages])

View file

@ -1,5 +1,5 @@
import json
from typing import Literal
from typing import Literal, Tuple
from metagpt.provider.bedrock.base_provider import BaseBedrockProvider
from metagpt.provider.bedrock.utils import (
@ -21,8 +21,26 @@ class MistralProvider(BaseBedrockProvider):
class AnthropicProvider(BaseBedrockProvider):
# See https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html
def get_request_body(self, messages: list[dict], generate_kwargs, *args, **kwargs):
body = json.dumps({"messages": messages, "anthropic_version": "bedrock-2023-05-31", **generate_kwargs})
def _split_system_user_messages(self, messages: list[dict]) -> Tuple[str, list[dict]]:
system_messages = []
user_messages = []
for message in messages:
if message["role"] == "system":
system_messages.append(message)
else:
user_messages.append(message)
return self.messages_to_prompt(system_messages), user_messages
def get_request_body(self, messages: list[dict], generate_kwargs, *args, **kwargs) -> str:
system_message, user_messages = self._split_system_user_messages(messages)
body = json.dumps(
{
"messages": user_messages,
"anthropic_version": "bedrock-2023-05-31",
"system": system_message,
**generate_kwargs,
}
)
return body
def _get_completion_from_dict(self, rsp_dict: dict) -> str:

View file

@ -39,8 +39,8 @@ SUPPORT_STREAM_MODELS = {
"meta.llama2-13b-chat-v1": 2000,
"meta.llama2-70b-v1": 4000,
"meta.llama2-70b-v1:0:4k": 4000,
"meta.llama2-70b-chat-v1": 4000,
"meta.llama2-70b-chat-v1:0:4k": 4000,
"meta.llama2-70b-chat-v1": 2000,
"meta.llama2-70b-chat-v1:0:4k": 2000,
"meta.llama3-8b-instruct-v1:0": 2000,
"meta.llama3-70b-instruct-v1:0": 2000,
"mistral.mistral-7b-instruct-v0:2": 32000,

View file

@ -131,10 +131,8 @@ class BedrockLLM(BaseLLM):
headers = response.get("ResponseMetadata", {}).get("HTTPHeaders", {})
prompt_tokens = int(headers.get("x-amzn-bedrock-input-token-count", 0))
completion_tokens = int(headers.get("x-amzn-bedrock-output-token-count", 0))
usage = (
{
"prompt_tokens": prompt_tokens,
"completion_tokens": completion_tokens,
},
)
usage = {
"prompt_tokens": prompt_tokens,
"completion_tokens": completion_tokens,
}
return usage