rename bedrock class and add more tests

This commit is contained in:
usamimeri_renko 2024-04-29 10:46:50 +08:00
parent 986e3c827e
commit 3f108abd06
4 changed files with 42 additions and 38 deletions

View file

@ -1,6 +1,6 @@
import pytest
import json
from metagpt.provider.bedrock_api import AmazonBedrockLLM
from metagpt.provider.bedrock_api import BedrockLLM
from tests.metagpt.provider.mock_llm_config import mock_llm_config_bedrock
from metagpt.provider.bedrock.utils import get_max_tokens, SUPPORT_STREAM_MODELS, NOT_SUUPORT_STREAM_MODELS
from tests.metagpt.provider.req_resp_const import BEDROCK_PROVIDER_REQUEST_BODY, BEDROCK_PROVIDER_RESPONSE_BODY
@ -65,36 +65,50 @@ def is_subset(subset, superset) -> bool:
@pytest.fixture(scope="class", params=models)
def bedrock_api(request) -> AmazonBedrockLLM:
def bedrock_api(request) -> BedrockLLM:
model_id = request.param
mock_llm_config_bedrock.model = model_id
api = AmazonBedrockLLM(mock_llm_config_bedrock)
api = BedrockLLM(mock_llm_config_bedrock)
return api
class TestAPI:
def test_generate_kwargs(self, bedrock_api: AmazonBedrockLLM):
class TestBedrockAPI:
def _patch_invoke_model(self, mocker):
mocker.patch("metagpt.provider.bedrock_api.BedrockLLM.invoke_model", mock_bedrock_provider_response)
def _patch_invoke_model_stream(self, mocker):
mocker.patch("metagpt.provider.bedrock_api.BedrockLLM.invoke_model_with_response_stream",
mock_bedrock_provider_stream_response)
def test_const_kwargs(self, bedrock_api: BedrockLLM):
provider = bedrock_api.provider
assert bedrock_api._generate_kwargs[provider.max_tokens_field_name] <= get_max_tokens(
assert bedrock_api._const_kwargs[provider.max_tokens_field_name] <= get_max_tokens(
bedrock_api.config.model)
def test_get_request_body(self, bedrock_api: AmazonBedrockLLM):
def test_get_request_body(self, bedrock_api: BedrockLLM):
"""Ensure request body has correct format"""
provider = bedrock_api.provider
request_body = json.loads(provider.get_request_body(
messages, bedrock_api._generate_kwargs))
messages, bedrock_api._const_kwargs))
assert is_subset(request_body, get_bedrock_request_body(
bedrock_api.config.model))
assert is_subset(request_body, get_bedrock_request_body(bedrock_api.config.model))
def test_completion(self, bedrock_api: AmazonBedrockLLM, mocker):
mocker.patch("metagpt.provider.bedrock_api.AmazonBedrockLLM.invoke_model",
mock_bedrock_provider_response)
def test_completion(self, bedrock_api: BedrockLLM, mocker):
self._patch_invoke_model(mocker)
assert bedrock_api.completion(messages) == "Hello World"
def test_stream_completion(self, bedrock_api: AmazonBedrockLLM, mocker):
mocker.patch("metagpt.provider.bedrock_api.AmazonBedrockLLM.invoke_model",
mock_bedrock_provider_response)
mocker.patch("metagpt.provider.bedrock_api.AmazonBedrockLLM.invoke_model_with_response_stream",
mock_bedrock_provider_stream_response)
def test_chat_completion_stream(self, bedrock_api: BedrockLLM, mocker):
self._patch_invoke_model(mocker)
self._patch_invoke_model_stream(mocker)
assert bedrock_api._chat_completion_stream(messages) == "Hello World"
@pytest.mark.asyncio
async def test_achat_completion_stream(self, bedrock_api: BedrockLLM, mocker):
self._patch_invoke_model_stream(mocker)
self._patch_invoke_model(mocker)
assert await bedrock_api._achat_completion_stream(messages) == "Hello World"
@pytest.mark.asyncio
async def test_acompletion(self, bedrock_api: BedrockLLM, mocker):
self._patch_invoke_model(mocker)
assert await bedrock_api.acompletion(messages) == "Hello World"