From 187e9ef698cc1820451aed917e19a57ddb61b7e0 Mon Sep 17 00:00:00 2001 From: usamimeri_renko <1710269958@qq.com> Date: Thu, 25 Apr 2024 20:24:39 +0800 Subject: [PATCH] support anthropic --- metagpt/provider/bedrock/base_provider.py | 10 +++++++--- metagpt/provider/bedrock/bedrock_provider.py | 17 +++++++++++++++-- 2 files changed, 22 insertions(+), 5 deletions(-) diff --git a/metagpt/provider/bedrock/base_provider.py b/metagpt/provider/bedrock/base_provider.py index 3ecd5789a..2a17c335b 100644 --- a/metagpt/provider/bedrock/base_provider.py +++ b/metagpt/provider/bedrock/base_provider.py @@ -4,7 +4,9 @@ import json class BaseBedrockProvider(object): # to handle different generation kwargs def get_request_body(self, messages, **generate_kwargs): - return json.dumps({"prompt": self.messages_to_prompt(messages)} | generate_kwargs) + body = json.dumps( + {"prompt": self.messages_to_prompt(messages), **generate_kwargs}) + return body def get_choice_text(self, response) -> str: response_body = self._get_response_body_json(response) @@ -12,10 +14,12 @@ class BaseBedrockProvider(object): return completions def get_choice_text_from_stream(self, event): - return json.loads(event["chunk"]["bytes"])["outputs"][0]["text"] + completions = json.loads(event["chunk"]["bytes"])["outputs"][0]["text"] + return completions def _get_response_body_json(self, response): - return json.loads(response["body"].read()) + response_body = json.loads(response["body"].read()) + return response_body def messages_to_prompt(self, messages: list[dict]): """[{"role": "user", "content": msg}] to user: etc.""" diff --git a/metagpt/provider/bedrock/bedrock_provider.py b/metagpt/provider/bedrock/bedrock_provider.py index a50f9abed..bbf3de223 100644 --- a/metagpt/provider/bedrock/bedrock_provider.py +++ b/metagpt/provider/bedrock/bedrock_provider.py @@ -11,7 +11,19 @@ class MistralProvider(BaseBedrockProvider): class AnthropicProvider(BaseBedrockProvider): # See https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html - pass + def get_request_body(self, messages, **generate_kwargs): + body = json.dumps( + {"messages": messages, "anthropic_version": "bedrock-2023-05-31", **generate_kwargs}) + return body + + def get_choice_text(self, response) -> str: + response_body = self._get_response_body_json(response) + completions = response_body["content"][0]['text'] + return completions + + def get_choice_text_from_stream(self, event): + completions = json.loads(event["chunk"]["bytes"])["content"][0]["text"] + return completions class CohereProvider(BaseBedrockProvider): @@ -28,7 +40,8 @@ class MetaProvider(BaseBedrockProvider): return completions def get_choice_text_from_stream(self, event): - return json.loads(event["chunk"]["bytes"])["generation"] + completions = json.loads(event["chunk"]["bytes"])["generation"] + return completions class Ai21Provider(BaseBedrockProvider):