diff --git a/metagpt/provider/bedrock/amazon_bedrock_api.py b/metagpt/provider/bedrock/amazon_bedrock_api.py index aad58f884..6230cd3f2 100644 --- a/metagpt/provider/bedrock/amazon_bedrock_api.py +++ b/metagpt/provider/bedrock/amazon_bedrock_api.py @@ -86,7 +86,7 @@ class AmazonBedrockLLM(BaseLLM): def completion(self, messages: list[dict]) -> str: request_body = self.__provider.get_request_body( - messages, **self._generate_kwargs) + messages, self._generate_kwargs) response_body = self.invoke_model(request_body) completions = self.__provider.get_choice_text(response_body) return completions @@ -98,7 +98,7 @@ class AmazonBedrockLLM(BaseLLM): return self.completion(messages) request_body = self.__provider.get_request_body( - messages, **self._generate_kwargs) + messages, self._generate_kwargs, stream=True) response = self.invoke_model_with_response_stream(request_body) collected_content = [] diff --git a/metagpt/provider/bedrock/base_provider.py b/metagpt/provider/bedrock/base_provider.py index 449e0f5c8..5aa6ae605 100644 --- a/metagpt/provider/bedrock/base_provider.py +++ b/metagpt/provider/bedrock/base_provider.py @@ -10,7 +10,7 @@ class BaseBedrockProvider(ABC): def _get_completion_from_dict(self, rsp_dict: dict) -> str: ... - def get_request_body(self, messages: list[dict], **generate_kwargs): + def get_request_body(self, messages: list[dict], generate_kwargs, *args, **kwargs): body = json.dumps( {"prompt": self.messages_to_prompt(messages), **generate_kwargs}) return body diff --git a/metagpt/provider/bedrock/bedrock_provider.py b/metagpt/provider/bedrock/bedrock_provider.py index 47d699313..01bcaac53 100644 --- a/metagpt/provider/bedrock/bedrock_provider.py +++ b/metagpt/provider/bedrock/bedrock_provider.py @@ -17,7 +17,7 @@ class MistralProvider(BaseBedrockProvider): class AnthropicProvider(BaseBedrockProvider): # See https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html - def get_request_body(self, messages: list[dict], **generate_kwargs): + def get_request_body(self, messages: list[dict], generate_kwargs, *args, **kwargs): body = json.dumps( {"messages": messages, "anthropic_version": "bedrock-2023-05-31", **generate_kwargs}) return body @@ -41,6 +41,16 @@ class CohereProvider(BaseBedrockProvider): def _get_completion_from_dict(self, rsp_dict: dict) -> str: return rsp_dict["generations"][0]["text"] + def get_request_body(self, messages: list[dict], generate_kwargs, *args, **kwargs): + body = json.dumps( + {"prompt": self.messages_to_prompt(messages), "stream": kwargs.get("stream", False), **generate_kwargs}) + return body + + def get_choice_text_from_stream(self, event) -> str: + rsp_dict = json.loads(event["chunk"]["bytes"]) + completions = rsp_dict.get("text", "") + return completions + class MetaProvider(BaseBedrockProvider): # See https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-meta.html @@ -74,7 +84,7 @@ class AmazonProvider(BaseBedrockProvider): max_tokens_field_name = "maxTokenCount" - def get_request_body(self, messages: list[dict], **generate_kwargs): + def get_request_body(self, messages: list[dict], generate_kwargs, *args, **kwargs): body = json.dumps({ "inputText": self.messages_to_prompt(messages), "textGenerationConfig": generate_kwargs diff --git a/metagpt/provider/bedrock/utils.py b/metagpt/provider/bedrock/utils.py index f69d04a7b..a87367b22 100644 --- a/metagpt/provider/bedrock/utils.py +++ b/metagpt/provider/bedrock/utils.py @@ -14,15 +14,24 @@ SUPPORT_STREAM_MODELS = { "amazon.titan-tg1-large": 8000, "amazon.titan-text-express-v1": 8000, "anthropic.claude-instant-v1": 100000, + "anthropic.claude-instant-v1:2:100k": 100000, "anthropic.claude-v1": 100000, "anthropic.claude-v2": 100000, "anthropic.claude-v2:1": 200000, "anthropic.claude-3-sonnet-20240229-v1:0": 200000, + "anthropic.claude-3-sonnet-20240229-v1:0:28k": 28000, + "anthropic.claude-3-sonnet-20240229-v1:0:200k": 200000, "anthropic.claude-3-haiku-20240307-v1:0": 200000, - "anthropic.claude-3-opus-20240229-v1:0": 200000, - "cohere.command-text-v14": 4096, - "cohere.command-light-text-v14": 4096, - "meta.llama2-70b-v1": 4096, + "anthropic.claude-3-haiku-20240307-v1:0:48k": 48000, + "anthropic.claude-3-haiku-20240307-v1:0:200k": 200000, + "cohere.command-text-v14": 4000, + "cohere.command-text-v14:7:4k": 4000, + "cohere.command-light-text-v14": 4000, + "cohere.command-light-text-v14:7:4k": 4000, + "meta.llama2-70b-v1": 4000, + "meta.llama2-13b-chat-v1:0:4k": 4000, + "meta.llama2-13b-chat-v1": 2000, + "meta.llama2-70b-v1:0:4k": 4000, "meta.llama3-8b-instruct-v1:0": 2000, "meta.llama3-70b-instruct-v1:0": 2000, "mistral.mistral-7b-instruct-v0:2": 32000, diff --git a/tests/metagpt/provider/test_amazon_bedrock_api.py b/tests/metagpt/provider/test_amazon_bedrock_api.py index c10b74d0a..cd13c0b24 100644 --- a/tests/metagpt/provider/test_amazon_bedrock_api.py +++ b/tests/metagpt/provider/test_amazon_bedrock_api.py @@ -20,14 +20,21 @@ def mock_bedrock_provider_stream_response(self, *args, **kwargs) -> dict: def dict2bytes(x): return json.dumps(x).encode("utf-8") provider = self.config.model.split(".")[0] - response_body_bytes = dict2bytes(BEDROCK_PROVIDER_RESPONSE_BODY[provider]) - # decoded bytes share the same format as non-stream response_body except for titan + if provider == "amazon": - response_body_stream = { - "body": [{'chunk': {'bytes': dict2bytes({"outputText": "Hello World"})}}]} + response_body_bytes = dict2bytes({"outputText": "Hello World"}) + elif provider == "anthropic": + response_body_bytes = dict2bytes({"type": "content_block_delta", "index": 0, + "delta": {"type": "text_delta", "text": "Hello World"}}) + elif provider == "cohere": + response_body_bytes = dict2bytes( + {"is_finished": False, "text": "Hello World"}) else: - response_body_stream = { - "body": [{'chunk': {'bytes': response_body_bytes}}]} + response_body_bytes = dict2bytes( + BEDROCK_PROVIDER_RESPONSE_BODY[provider]) + + response_body_stream = { + "body": [{'chunk': {'bytes': response_body_bytes}}]} return response_body_stream @@ -75,7 +82,7 @@ class TestAPI: """Ensure request body has correct format""" provider = bedrock_api._get_provider() request_body = json.loads(provider.get_request_body( - messages, **bedrock_api._generate_kwargs)) + messages, bedrock_api._generate_kwargs)) assert is_subset(request_body, get_bedrock_request_body( bedrock_api.config.model))