stream test

This commit is contained in:
usamimeri_renko 2024-04-26 17:03:06 +08:00
parent 4c394a1cac
commit 8fafa2eb4e
4 changed files with 27 additions and 32 deletions

View file

@ -9,7 +9,7 @@ from metagpt.provider.bedrock.bedrock_provider import get_provider
from metagpt.provider.bedrock.utils import NOT_SUUPORT_STREAM_MODELS, get_max_tokens
try:
import boto3
from botocore.response import StreamingBody
from botocore.eventstream import EventStream
except ImportError:
raise ImportError(
"boto3 not found! please install it by `pip install boto3` ")
@ -70,7 +70,7 @@ class AmazonBedrockLLM(BaseLLM):
response_body = self._get_response_body(response)
return response_body
def invoke_model_with_response_stream(self, request_body) -> StreamingBody:
def invoke_model_with_response_stream(self, request_body) -> EventStream:
response = self.__client.invoke_model_with_response_stream(
modelId=self.config.model, body=request_body
)
@ -106,7 +106,6 @@ class AmazonBedrockLLM(BaseLLM):
messages, **self._generate_kwargs)
response = self.invoke_model_with_response_stream(request_body)
collected_content = []
for event in response["body"]:
chunk_text = self.__provider.get_choice_text_from_stream(event)

View file

@ -73,12 +73,7 @@ class AmazonProvider(BaseBedrockProvider):
return body
def _get_completion_from_dict(self, rsp_dict: dict) -> str:
return rsp_dict['results'][0]['outputText'].strip()
def get_choice_text_from_stream(self, event) -> str:
rsp_dict = json.loads(event["chunk"]["bytes"])
completions = rsp_dict["outputText"]
return completions
return rsp_dict['results'][0]['outputText']
PROVIDERS = {

View file

@ -193,6 +193,8 @@ async def llm_general_chat_funcs_test(llm: BaseLLM, prompt: str, messages: list[
# For Amazon Bedrock
# Check the API documentation of each model
# https://docs.aws.amazon.com/bedrock/latest/userguide
BEDROCK_PROVIDER_REQUEST_BODY = {
"mistral": {"prompt": "", "max_tokens": 0, "stop": [], "temperature": 0.0, "top_p": 0.0, "top_k": 0},
"meta": {"prompt": "", "temperature": 0.0, "top_p": 0.0, "max_gen_len": 0},
@ -236,6 +238,4 @@ BEDROCK_PROVIDER_RESPONSE_BODY = {
},
"amazon": {'inputTextTokenCount': 0, 'results': [{'tokenCount': 0, 'outputText': 'Hello World', 'completionReason': ""}]}
}
BEDROCK_PROVIDER_STREAM_RESPONSE = {}
}

View file

@ -3,14 +3,10 @@ import json
from metagpt.provider.bedrock.amazon_bedrock_api import AmazonBedrockLLM
from tests.metagpt.provider.mock_llm_config import mock_llm_config_bedrock
from metagpt.provider.bedrock.utils import get_max_tokens, SUPPORT_STREAM_MODELS, NOT_SUUPORT_STREAM_MODELS
from tests.metagpt.provider.req_resp_const import (
BEDROCK_PROVIDER_REQUEST_BODY,
BEDROCK_PROVIDER_RESPONSE_BODY,
BEDROCK_PROVIDER_STREAM_RESPONSE)
from botocore.response import StreamingBody
from tests.metagpt.provider.req_resp_const import BEDROCK_PROVIDER_REQUEST_BODY, BEDROCK_PROVIDER_RESPONSE_BODY
# all available model from bedrock
models = SUPPORT_STREAM_MODELS | NOT_SUUPORT_STREAM_MODELS
models = (SUPPORT_STREAM_MODELS | NOT_SUUPORT_STREAM_MODELS)
messages = [{"role": "user", "content": "Hi!"}]
@ -19,10 +15,15 @@ def mock_bedrock_provider_response(self, *args, **kwargs) -> dict:
return BEDROCK_PROVIDER_RESPONSE_BODY[provider]
def mock_bedrock_provider_stream_response(self, *args, **kwargs) -> StreamingBody:
def mock_bedrock_provider_stream_response(self, *args, **kwargs) -> dict:
# use json object to mock EventStream
provider = self.config.model.split(".")[0]
response_json = BEDROCK_PROVIDER_STREAM_RESPONSE[provider]
return
response_body_bytes = json.dumps(
BEDROCK_PROVIDER_RESPONSE_BODY[provider]).encode("utf-8")
# decoded bytes share the same format as non-stream response_body
response_body_stream = {
"body": [{'chunk': {'bytes': response_body_bytes}},]}
return response_body_stream
def get_bedrock_request_body(model_id) -> dict:
@ -51,7 +52,7 @@ def is_subset(subset, superset):
return True
@ pytest.fixture(scope="class", params=models)
@pytest.fixture(scope="class", params=models)
def bedrock_api(request) -> AmazonBedrockLLM:
model_id = request.param
mock_llm_config_bedrock.model = model_id
@ -66,6 +67,7 @@ class TestAPI:
bedrock_api.config.model)
def test_get_request_body(self, bedrock_api: AmazonBedrockLLM):
"""Ensure request body has correct format"""
provider = bedrock_api._get_provider()
request_body = json.loads(provider.get_request_body(
messages, **bedrock_api._generate_kwargs))
@ -77,15 +79,14 @@ class TestAPI:
bedrock_api.config.model))
def test_completion(self, bedrock_api: AmazonBedrockLLM, mocker):
mocker.patch(
"metagpt.provider.bedrock.amazon_bedrock_api.AmazonBedrockLLM.invoke_model", mock_bedrock_provider_response)
mocker.patch("metagpt.provider.bedrock.amazon_bedrock_api.AmazonBedrockLLM.invoke_model",
mock_bedrock_provider_response)
assert bedrock_api.completion(messages) == "Hello World"
# def test_stream_completion(self, bedrock_api: AmazonBedrockLLM, mocker):
# mocker.patch(
# "metagpt.provider.bedrock.amazon_bedrock_api.AmazonBedrockLLM.invoke_model_with_response_stream", mock_bedrock_provider_response)
# assert bedrock_api._chat_completion_stream(messages) == "Hello World"
if __name__ == '__main__':
print(get_bedrock_request_body("amazon.titan-tg1-large"))
def test_stream_completion(self, bedrock_api: AmazonBedrockLLM, mocker):
mocker.patch("metagpt.provider.bedrock.amazon_bedrock_api.AmazonBedrockLLM.invoke_model",
mock_bedrock_provider_response)
mocker.patch("metagpt.provider.bedrock.amazon_bedrock_api.AmazonBedrockLLM.invoke_model_with_response_stream",
mock_bedrock_provider_stream_response)
assert bedrock_api._chat_completion_stream(
messages) == "Hello World"