mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-05-27 14:25:20 +02:00
[fix] serval bugs in bedrock LLM
This commit is contained in:
parent
44ea12fe66
commit
1911e27219
3 changed files with 17 additions and 8 deletions
|
|
@ -25,4 +25,4 @@ class BaseBedrockProvider(ABC):
|
|||
|
||||
def messages_to_prompt(self, messages: list[dict]) -> str:
|
||||
"""[{"role": "user", "content": msg}] to user: <msg> etc."""
|
||||
return "\n".join([f"{i['role']}: {i['content']}" for i in messages])
|
||||
return "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages])
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
import json
|
||||
from typing import Literal
|
||||
from typing import Literal, Tuple
|
||||
|
||||
from metagpt.provider.bedrock.base_provider import BaseBedrockProvider
|
||||
from metagpt.provider.bedrock.utils import (
|
||||
|
|
@ -21,8 +21,19 @@ class MistralProvider(BaseBedrockProvider):
|
|||
class AnthropicProvider(BaseBedrockProvider):
|
||||
# See https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html
|
||||
|
||||
def get_request_body(self, messages: list[dict], generate_kwargs, *args, **kwargs):
|
||||
body = json.dumps({"messages": messages, "anthropic_version": "bedrock-2023-05-31", **generate_kwargs})
|
||||
def _split_system_user_messages(self, messages: list[dict]) -> Tuple[str, list[dict]]:
|
||||
system_messages = []
|
||||
user_messages = []
|
||||
for message in messages:
|
||||
if message["role"] == "system":
|
||||
system_messages.append(message)
|
||||
else:
|
||||
user_messages.append(message)
|
||||
return self.messages_to_prompt(system_messages), user_messages
|
||||
|
||||
def get_request_body(self, messages: list[dict], generate_kwargs, *args, **kwargs) -> str:
|
||||
system_message, user_messages = self._split_system_user_messages(messages)
|
||||
body = json.dumps({"messages": user_messages, "anthropic_version": "bedrock-2023-05-31", "system": system_message, **generate_kwargs})
|
||||
return body
|
||||
|
||||
def _get_completion_from_dict(self, rsp_dict: dict) -> str:
|
||||
|
|
|
|||
|
|
@ -131,10 +131,8 @@ class BedrockLLM(BaseLLM):
|
|||
headers = response.get("ResponseMetadata", {}).get("HTTPHeaders", {})
|
||||
prompt_tokens = int(headers.get("x-amzn-bedrock-input-token-count", 0))
|
||||
completion_tokens = int(headers.get("x-amzn-bedrock-output-token-count", 0))
|
||||
usage = (
|
||||
{
|
||||
usage = {
|
||||
"prompt_tokens": prompt_tokens,
|
||||
"completion_tokens": completion_tokens,
|
||||
},
|
||||
)
|
||||
}
|
||||
return usage
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue