diff --git a/metagpt/provider/bedrock/bedrock_provider.py b/metagpt/provider/bedrock/bedrock_provider.py index a4b90a82f..29cdf38a9 100644 --- a/metagpt/provider/bedrock/bedrock_provider.py +++ b/metagpt/provider/bedrock/bedrock_provider.py @@ -74,6 +74,11 @@ class AmazonProvider(BaseBedrockProvider): def _get_completion_from_dict(self, rsp_dict: dict) -> str: return rsp_dict['results'][0]['outputText'] + + def get_choice_text_from_stream(self, event) -> str: + rsp_dict = json.loads(event["chunk"]["bytes"]) + completions = rsp_dict["outputText"] + return completions PROVIDERS = { diff --git a/tests/metagpt/provider/test_amazon_bedrock_api.py b/tests/metagpt/provider/test_amazon_bedrock_api.py index 46d44e81b..c10b74d0a 100644 --- a/tests/metagpt/provider/test_amazon_bedrock_api.py +++ b/tests/metagpt/provider/test_amazon_bedrock_api.py @@ -17,12 +17,17 @@ def mock_bedrock_provider_response(self, *args, **kwargs) -> dict: def mock_bedrock_provider_stream_response(self, *args, **kwargs) -> dict: # use json object to mock EventStream + def dict2bytes(x): + return json.dumps(x).encode("utf-8") provider = self.config.model.split(".")[0] - response_body_bytes = json.dumps( - BEDROCK_PROVIDER_RESPONSE_BODY[provider]).encode("utf-8") - # decoded bytes share the same format as non-stream response_body - response_body_stream = { - "body": [{'chunk': {'bytes': response_body_bytes}},]} + response_body_bytes = dict2bytes(BEDROCK_PROVIDER_RESPONSE_BODY[provider]) + # decoded bytes share the same format as non-stream response_body except for titan + if provider == "amazon": + response_body_stream = { + "body": [{'chunk': {'bytes': dict2bytes({"outputText": "Hello World"})}}]} + else: + response_body_stream = { + "body": [{'chunk': {'bytes': response_body_bytes}}]} return response_body_stream