diff --git a/metagpt/provider/__init__.py b/metagpt/provider/__init__.py index 1311ccf61..fc91c7460 100644 --- a/metagpt/provider/__init__.py +++ b/metagpt/provider/__init__.py @@ -17,7 +17,7 @@ from metagpt.provider.spark_api import SparkLLM from metagpt.provider.qianfan_api import QianFanLLM from metagpt.provider.dashscope_api import DashScopeLLM from metagpt.provider.anthropic_api import AnthropicLLM -from metagpt.provider.bedrock_api import AmazonBedrockLLM +from metagpt.provider.bedrock_api import BedrockLLM __all__ = [ "GeminiLLM", @@ -31,5 +31,5 @@ __all__ = [ "QianFanLLM", "DashScopeLLM", "AnthropicLLM", - "AmazonBedrockLLM" + "BedrockLLM" ] diff --git a/metagpt/provider/bedrock/base_provider.py b/metagpt/provider/bedrock/base_provider.py index 5aa6ae605..fd0508cb0 100644 --- a/metagpt/provider/bedrock/base_provider.py +++ b/metagpt/provider/bedrock/base_provider.py @@ -10,9 +10,9 @@ class BaseBedrockProvider(ABC): def _get_completion_from_dict(self, rsp_dict: dict) -> str: ... - def get_request_body(self, messages: list[dict], generate_kwargs, *args, **kwargs): + def get_request_body(self, messages: list[dict], const_kwargs, *args, **kwargs) -> str: body = json.dumps( - {"prompt": self.messages_to_prompt(messages), **generate_kwargs}) + {"prompt": self.messages_to_prompt(messages), **const_kwargs}) return body def get_choice_text(self, response_body: dict) -> str: diff --git a/metagpt/provider/bedrock_api.py b/metagpt/provider/bedrock_api.py index 16b50d996..b07520fd1 100644 --- a/metagpt/provider/bedrock_api.py +++ b/metagpt/provider/bedrock_api.py @@ -7,16 +7,12 @@ from metagpt.provider.base_llm import BaseLLM from metagpt.logs import log_llm_stream, logger from metagpt.provider.bedrock.bedrock_provider import get_provider from metagpt.provider.bedrock.utils import NOT_SUUPORT_STREAM_MODELS, get_max_tokens -try: - import boto3 - from botocore.eventstream import EventStream -except ImportError: - raise ImportError( - "boto3 not found! please install it by `pip install boto3` ") +import boto3 +from botocore.eventstream import EventStream @register_provider([LLMType.BEDROCK]) -class AmazonBedrockLLM(BaseLLM): +class BedrockLLM(BaseLLM): def __init__(self, config: LLMConfig): self.config = config self.__client = self.__init_client("bedrock-runtime") @@ -77,7 +73,7 @@ class AmazonBedrockLLM(BaseLLM): return response @property - def _generate_kwargs(self) -> dict: + def _const_kwargs(self) -> dict: model_max_tokens = get_max_tokens(self.config.model) if self.config.max_token > model_max_tokens: max_tokens = model_max_tokens @@ -91,7 +87,7 @@ class AmazonBedrockLLM(BaseLLM): def completion(self, messages: list[dict]) -> str: request_body = self.__provider.get_request_body( - messages, self._generate_kwargs) + messages, self._const_kwargs) response_body = self.invoke_model(request_body) completions = self.__provider.get_choice_text(response_body) return completions @@ -103,7 +99,7 @@ class AmazonBedrockLLM(BaseLLM): return self.completion(messages) request_body = self.__provider.get_request_body( - messages, self._generate_kwargs, stream=True) + messages, self._const_kwargs, stream=True) response = self.invoke_model_with_response_stream(request_body) collected_content = [] @@ -124,12 +120,6 @@ class AmazonBedrockLLM(BaseLLM): async def acompletion(self, messages: list[dict]): return await self._achat_completion(messages) - async def acompletion_text(self, messages: list[dict], stream: bool = False, - timeout: int = USE_CONFIG_TIMEOUT) -> str: - if stream: - return await self._achat_completion_stream(messages) - return await self._achat_completion(messages) - async def _achat_completion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT): return self.completion(messages) diff --git a/tests/metagpt/provider/test_bedrock_api.py b/tests/metagpt/provider/test_bedrock_api.py index 7e282db78..7b2797da0 100644 --- a/tests/metagpt/provider/test_bedrock_api.py +++ b/tests/metagpt/provider/test_bedrock_api.py @@ -1,6 +1,6 @@ import pytest import json -from metagpt.provider.bedrock_api import AmazonBedrockLLM +from metagpt.provider.bedrock_api import BedrockLLM from tests.metagpt.provider.mock_llm_config import mock_llm_config_bedrock from metagpt.provider.bedrock.utils import get_max_tokens, SUPPORT_STREAM_MODELS, NOT_SUUPORT_STREAM_MODELS from tests.metagpt.provider.req_resp_const import BEDROCK_PROVIDER_REQUEST_BODY, BEDROCK_PROVIDER_RESPONSE_BODY @@ -65,36 +65,50 @@ def is_subset(subset, superset) -> bool: @pytest.fixture(scope="class", params=models) -def bedrock_api(request) -> AmazonBedrockLLM: +def bedrock_api(request) -> BedrockLLM: model_id = request.param mock_llm_config_bedrock.model = model_id - api = AmazonBedrockLLM(mock_llm_config_bedrock) + api = BedrockLLM(mock_llm_config_bedrock) return api -class TestAPI: - def test_generate_kwargs(self, bedrock_api: AmazonBedrockLLM): +class TestBedrockAPI: + def _patch_invoke_model(self, mocker): + mocker.patch("metagpt.provider.bedrock_api.BedrockLLM.invoke_model", mock_bedrock_provider_response) + + def _patch_invoke_model_stream(self, mocker): + mocker.patch("metagpt.provider.bedrock_api.BedrockLLM.invoke_model_with_response_stream", + mock_bedrock_provider_stream_response) + + def test_const_kwargs(self, bedrock_api: BedrockLLM): provider = bedrock_api.provider - assert bedrock_api._generate_kwargs[provider.max_tokens_field_name] <= get_max_tokens( + assert bedrock_api._const_kwargs[provider.max_tokens_field_name] <= get_max_tokens( bedrock_api.config.model) - def test_get_request_body(self, bedrock_api: AmazonBedrockLLM): + def test_get_request_body(self, bedrock_api: BedrockLLM): """Ensure request body has correct format""" provider = bedrock_api.provider request_body = json.loads(provider.get_request_body( - messages, bedrock_api._generate_kwargs)) + messages, bedrock_api._const_kwargs)) - assert is_subset(request_body, get_bedrock_request_body( - bedrock_api.config.model)) + assert is_subset(request_body, get_bedrock_request_body(bedrock_api.config.model)) - def test_completion(self, bedrock_api: AmazonBedrockLLM, mocker): - mocker.patch("metagpt.provider.bedrock_api.AmazonBedrockLLM.invoke_model", - mock_bedrock_provider_response) + def test_completion(self, bedrock_api: BedrockLLM, mocker): + self._patch_invoke_model(mocker) assert bedrock_api.completion(messages) == "Hello World" - def test_stream_completion(self, bedrock_api: AmazonBedrockLLM, mocker): - mocker.patch("metagpt.provider.bedrock_api.AmazonBedrockLLM.invoke_model", - mock_bedrock_provider_response) - mocker.patch("metagpt.provider.bedrock_api.AmazonBedrockLLM.invoke_model_with_response_stream", - mock_bedrock_provider_stream_response) + def test_chat_completion_stream(self, bedrock_api: BedrockLLM, mocker): + self._patch_invoke_model(mocker) + self._patch_invoke_model_stream(mocker) assert bedrock_api._chat_completion_stream(messages) == "Hello World" + + @pytest.mark.asyncio + async def test_achat_completion_stream(self, bedrock_api: BedrockLLM, mocker): + self._patch_invoke_model_stream(mocker) + self._patch_invoke_model(mocker) + assert await bedrock_api._achat_completion_stream(messages) == "Hello World" + + @pytest.mark.asyncio + async def test_acompletion(self, bedrock_api: BedrockLLM, mocker): + self._patch_invoke_model(mocker) + assert await bedrock_api.acompletion(messages) == "Hello World"