diff --git a/metagpt/provider/bedrock/amazon_bedrock_api.py b/metagpt/provider/bedrock/amazon_bedrock_api.py index 07687c682..3d1b08f47 100644 --- a/metagpt/provider/bedrock/amazon_bedrock_api.py +++ b/metagpt/provider/bedrock/amazon_bedrock_api.py @@ -9,7 +9,7 @@ from metagpt.provider.bedrock.bedrock_provider import get_provider from metagpt.provider.bedrock.utils import NOT_SUUPORT_STREAM_MODELS, get_max_tokens try: import boto3 - from botocore.response import StreamingBody + from botocore.eventstream import EventStream except ImportError: raise ImportError( "boto3 not found! please install it by `pip install boto3` ") @@ -70,7 +70,7 @@ class AmazonBedrockLLM(BaseLLM): response_body = self._get_response_body(response) return response_body - def invoke_model_with_response_stream(self, request_body) -> StreamingBody: + def invoke_model_with_response_stream(self, request_body) -> EventStream: response = self.__client.invoke_model_with_response_stream( modelId=self.config.model, body=request_body ) @@ -106,7 +106,6 @@ class AmazonBedrockLLM(BaseLLM): messages, **self._generate_kwargs) response = self.invoke_model_with_response_stream(request_body) - collected_content = [] for event in response["body"]: chunk_text = self.__provider.get_choice_text_from_stream(event) diff --git a/metagpt/provider/bedrock/bedrock_provider.py b/metagpt/provider/bedrock/bedrock_provider.py index e01184705..a4b90a82f 100644 --- a/metagpt/provider/bedrock/bedrock_provider.py +++ b/metagpt/provider/bedrock/bedrock_provider.py @@ -73,12 +73,7 @@ class AmazonProvider(BaseBedrockProvider): return body def _get_completion_from_dict(self, rsp_dict: dict) -> str: - return rsp_dict['results'][0]['outputText'].strip() - - def get_choice_text_from_stream(self, event) -> str: - rsp_dict = json.loads(event["chunk"]["bytes"]) - completions = rsp_dict["outputText"] - return completions + return rsp_dict['results'][0]['outputText'] PROVIDERS = { diff --git a/tests/metagpt/provider/req_resp_const.py b/tests/metagpt/provider/req_resp_const.py index 6a244cbe4..893c33704 100644 --- a/tests/metagpt/provider/req_resp_const.py +++ b/tests/metagpt/provider/req_resp_const.py @@ -193,6 +193,8 @@ async def llm_general_chat_funcs_test(llm: BaseLLM, prompt: str, messages: list[ # For Amazon Bedrock +# Check the API documentation of each model +# https://docs.aws.amazon.com/bedrock/latest/userguide BEDROCK_PROVIDER_REQUEST_BODY = { "mistral": {"prompt": "", "max_tokens": 0, "stop": [], "temperature": 0.0, "top_p": 0.0, "top_k": 0}, "meta": {"prompt": "", "temperature": 0.0, "top_p": 0.0, "max_gen_len": 0}, @@ -236,6 +238,4 @@ BEDROCK_PROVIDER_RESPONSE_BODY = { }, "amazon": {'inputTextTokenCount': 0, 'results': [{'tokenCount': 0, 'outputText': 'Hello World', 'completionReason': ""}]} -} - -BEDROCK_PROVIDER_STREAM_RESPONSE = {} +} \ No newline at end of file diff --git a/tests/metagpt/provider/test_amazon_bedrock_api.py b/tests/metagpt/provider/test_amazon_bedrock_api.py index 0b8206463..bf8e467e7 100644 --- a/tests/metagpt/provider/test_amazon_bedrock_api.py +++ b/tests/metagpt/provider/test_amazon_bedrock_api.py @@ -3,14 +3,10 @@ import json from metagpt.provider.bedrock.amazon_bedrock_api import AmazonBedrockLLM from tests.metagpt.provider.mock_llm_config import mock_llm_config_bedrock from metagpt.provider.bedrock.utils import get_max_tokens, SUPPORT_STREAM_MODELS, NOT_SUUPORT_STREAM_MODELS -from tests.metagpt.provider.req_resp_const import ( - BEDROCK_PROVIDER_REQUEST_BODY, - BEDROCK_PROVIDER_RESPONSE_BODY, - BEDROCK_PROVIDER_STREAM_RESPONSE) -from botocore.response import StreamingBody +from tests.metagpt.provider.req_resp_const import BEDROCK_PROVIDER_REQUEST_BODY, BEDROCK_PROVIDER_RESPONSE_BODY # all available model from bedrock -models = SUPPORT_STREAM_MODELS | NOT_SUUPORT_STREAM_MODELS +models = (SUPPORT_STREAM_MODELS | NOT_SUUPORT_STREAM_MODELS) messages = [{"role": "user", "content": "Hi!"}] @@ -19,10 +15,15 @@ def mock_bedrock_provider_response(self, *args, **kwargs) -> dict: return BEDROCK_PROVIDER_RESPONSE_BODY[provider] -def mock_bedrock_provider_stream_response(self, *args, **kwargs) -> StreamingBody: +def mock_bedrock_provider_stream_response(self, *args, **kwargs) -> dict: + # use json object to mock EventStream provider = self.config.model.split(".")[0] - response_json = BEDROCK_PROVIDER_STREAM_RESPONSE[provider] - return + response_body_bytes = json.dumps( + BEDROCK_PROVIDER_RESPONSE_BODY[provider]).encode("utf-8") + # decoded bytes share the same format as non-stream response_body + response_body_stream = { + "body": [{'chunk': {'bytes': response_body_bytes}},]} + return response_body_stream def get_bedrock_request_body(model_id) -> dict: @@ -51,7 +52,7 @@ def is_subset(subset, superset): return True -@ pytest.fixture(scope="class", params=models) +@pytest.fixture(scope="class", params=models) def bedrock_api(request) -> AmazonBedrockLLM: model_id = request.param mock_llm_config_bedrock.model = model_id @@ -66,6 +67,7 @@ class TestAPI: bedrock_api.config.model) def test_get_request_body(self, bedrock_api: AmazonBedrockLLM): + """Ensure request body has correct format""" provider = bedrock_api._get_provider() request_body = json.loads(provider.get_request_body( messages, **bedrock_api._generate_kwargs)) @@ -77,15 +79,14 @@ class TestAPI: bedrock_api.config.model)) def test_completion(self, bedrock_api: AmazonBedrockLLM, mocker): - mocker.patch( - "metagpt.provider.bedrock.amazon_bedrock_api.AmazonBedrockLLM.invoke_model", mock_bedrock_provider_response) + mocker.patch("metagpt.provider.bedrock.amazon_bedrock_api.AmazonBedrockLLM.invoke_model", + mock_bedrock_provider_response) assert bedrock_api.completion(messages) == "Hello World" - # def test_stream_completion(self, bedrock_api: AmazonBedrockLLM, mocker): - # mocker.patch( - # "metagpt.provider.bedrock.amazon_bedrock_api.AmazonBedrockLLM.invoke_model_with_response_stream", mock_bedrock_provider_response) - # assert bedrock_api._chat_completion_stream(messages) == "Hello World" - - -if __name__ == '__main__': - print(get_bedrock_request_body("amazon.titan-tg1-large")) + def test_stream_completion(self, bedrock_api: AmazonBedrockLLM, mocker): + mocker.patch("metagpt.provider.bedrock.amazon_bedrock_api.AmazonBedrockLLM.invoke_model", + mock_bedrock_provider_response) + mocker.patch("metagpt.provider.bedrock.amazon_bedrock_api.AmazonBedrockLLM.invoke_model_with_response_stream", + mock_bedrock_provider_stream_response) + assert bedrock_api._chat_completion_stream( + messages) == "Hello World"