mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-06-08 15:05:17 +02:00
rename bedrock class and add more tests
This commit is contained in:
parent
986e3c827e
commit
3f108abd06
4 changed files with 42 additions and 38 deletions
|
|
@ -17,7 +17,7 @@ from metagpt.provider.spark_api import SparkLLM
|
|||
from metagpt.provider.qianfan_api import QianFanLLM
|
||||
from metagpt.provider.dashscope_api import DashScopeLLM
|
||||
from metagpt.provider.anthropic_api import AnthropicLLM
|
||||
from metagpt.provider.bedrock_api import AmazonBedrockLLM
|
||||
from metagpt.provider.bedrock_api import BedrockLLM
|
||||
|
||||
__all__ = [
|
||||
"GeminiLLM",
|
||||
|
|
@ -31,5 +31,5 @@ __all__ = [
|
|||
"QianFanLLM",
|
||||
"DashScopeLLM",
|
||||
"AnthropicLLM",
|
||||
"AmazonBedrockLLM"
|
||||
"BedrockLLM"
|
||||
]
|
||||
|
|
|
|||
|
|
@ -10,9 +10,9 @@ class BaseBedrockProvider(ABC):
|
|||
def _get_completion_from_dict(self, rsp_dict: dict) -> str:
|
||||
...
|
||||
|
||||
def get_request_body(self, messages: list[dict], generate_kwargs, *args, **kwargs):
|
||||
def get_request_body(self, messages: list[dict], const_kwargs, *args, **kwargs) -> str:
|
||||
body = json.dumps(
|
||||
{"prompt": self.messages_to_prompt(messages), **generate_kwargs})
|
||||
{"prompt": self.messages_to_prompt(messages), **const_kwargs})
|
||||
return body
|
||||
|
||||
def get_choice_text(self, response_body: dict) -> str:
|
||||
|
|
|
|||
|
|
@ -7,16 +7,12 @@ from metagpt.provider.base_llm import BaseLLM
|
|||
from metagpt.logs import log_llm_stream, logger
|
||||
from metagpt.provider.bedrock.bedrock_provider import get_provider
|
||||
from metagpt.provider.bedrock.utils import NOT_SUUPORT_STREAM_MODELS, get_max_tokens
|
||||
try:
|
||||
import boto3
|
||||
from botocore.eventstream import EventStream
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"boto3 not found! please install it by `pip install boto3` ")
|
||||
import boto3
|
||||
from botocore.eventstream import EventStream
|
||||
|
||||
|
||||
@register_provider([LLMType.BEDROCK])
|
||||
class AmazonBedrockLLM(BaseLLM):
|
||||
class BedrockLLM(BaseLLM):
|
||||
def __init__(self, config: LLMConfig):
|
||||
self.config = config
|
||||
self.__client = self.__init_client("bedrock-runtime")
|
||||
|
|
@ -77,7 +73,7 @@ class AmazonBedrockLLM(BaseLLM):
|
|||
return response
|
||||
|
||||
@property
|
||||
def _generate_kwargs(self) -> dict:
|
||||
def _const_kwargs(self) -> dict:
|
||||
model_max_tokens = get_max_tokens(self.config.model)
|
||||
if self.config.max_token > model_max_tokens:
|
||||
max_tokens = model_max_tokens
|
||||
|
|
@ -91,7 +87,7 @@ class AmazonBedrockLLM(BaseLLM):
|
|||
|
||||
def completion(self, messages: list[dict]) -> str:
|
||||
request_body = self.__provider.get_request_body(
|
||||
messages, self._generate_kwargs)
|
||||
messages, self._const_kwargs)
|
||||
response_body = self.invoke_model(request_body)
|
||||
completions = self.__provider.get_choice_text(response_body)
|
||||
return completions
|
||||
|
|
@ -103,7 +99,7 @@ class AmazonBedrockLLM(BaseLLM):
|
|||
return self.completion(messages)
|
||||
|
||||
request_body = self.__provider.get_request_body(
|
||||
messages, self._generate_kwargs, stream=True)
|
||||
messages, self._const_kwargs, stream=True)
|
||||
|
||||
response = self.invoke_model_with_response_stream(request_body)
|
||||
collected_content = []
|
||||
|
|
@ -124,12 +120,6 @@ class AmazonBedrockLLM(BaseLLM):
|
|||
async def acompletion(self, messages: list[dict]):
|
||||
return await self._achat_completion(messages)
|
||||
|
||||
async def acompletion_text(self, messages: list[dict], stream: bool = False,
|
||||
timeout: int = USE_CONFIG_TIMEOUT) -> str:
|
||||
if stream:
|
||||
return await self._achat_completion_stream(messages)
|
||||
return await self._achat_completion(messages)
|
||||
|
||||
async def _achat_completion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT):
|
||||
return self.completion(messages)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
import pytest
|
||||
import json
|
||||
from metagpt.provider.bedrock_api import AmazonBedrockLLM
|
||||
from metagpt.provider.bedrock_api import BedrockLLM
|
||||
from tests.metagpt.provider.mock_llm_config import mock_llm_config_bedrock
|
||||
from metagpt.provider.bedrock.utils import get_max_tokens, SUPPORT_STREAM_MODELS, NOT_SUUPORT_STREAM_MODELS
|
||||
from tests.metagpt.provider.req_resp_const import BEDROCK_PROVIDER_REQUEST_BODY, BEDROCK_PROVIDER_RESPONSE_BODY
|
||||
|
|
@ -65,36 +65,50 @@ def is_subset(subset, superset) -> bool:
|
|||
|
||||
|
||||
@pytest.fixture(scope="class", params=models)
|
||||
def bedrock_api(request) -> AmazonBedrockLLM:
|
||||
def bedrock_api(request) -> BedrockLLM:
|
||||
model_id = request.param
|
||||
mock_llm_config_bedrock.model = model_id
|
||||
api = AmazonBedrockLLM(mock_llm_config_bedrock)
|
||||
api = BedrockLLM(mock_llm_config_bedrock)
|
||||
return api
|
||||
|
||||
|
||||
class TestAPI:
|
||||
def test_generate_kwargs(self, bedrock_api: AmazonBedrockLLM):
|
||||
class TestBedrockAPI:
|
||||
def _patch_invoke_model(self, mocker):
|
||||
mocker.patch("metagpt.provider.bedrock_api.BedrockLLM.invoke_model", mock_bedrock_provider_response)
|
||||
|
||||
def _patch_invoke_model_stream(self, mocker):
|
||||
mocker.patch("metagpt.provider.bedrock_api.BedrockLLM.invoke_model_with_response_stream",
|
||||
mock_bedrock_provider_stream_response)
|
||||
|
||||
def test_const_kwargs(self, bedrock_api: BedrockLLM):
|
||||
provider = bedrock_api.provider
|
||||
assert bedrock_api._generate_kwargs[provider.max_tokens_field_name] <= get_max_tokens(
|
||||
assert bedrock_api._const_kwargs[provider.max_tokens_field_name] <= get_max_tokens(
|
||||
bedrock_api.config.model)
|
||||
|
||||
def test_get_request_body(self, bedrock_api: AmazonBedrockLLM):
|
||||
def test_get_request_body(self, bedrock_api: BedrockLLM):
|
||||
"""Ensure request body has correct format"""
|
||||
provider = bedrock_api.provider
|
||||
request_body = json.loads(provider.get_request_body(
|
||||
messages, bedrock_api._generate_kwargs))
|
||||
messages, bedrock_api._const_kwargs))
|
||||
|
||||
assert is_subset(request_body, get_bedrock_request_body(
|
||||
bedrock_api.config.model))
|
||||
assert is_subset(request_body, get_bedrock_request_body(bedrock_api.config.model))
|
||||
|
||||
def test_completion(self, bedrock_api: AmazonBedrockLLM, mocker):
|
||||
mocker.patch("metagpt.provider.bedrock_api.AmazonBedrockLLM.invoke_model",
|
||||
mock_bedrock_provider_response)
|
||||
def test_completion(self, bedrock_api: BedrockLLM, mocker):
|
||||
self._patch_invoke_model(mocker)
|
||||
assert bedrock_api.completion(messages) == "Hello World"
|
||||
|
||||
def test_stream_completion(self, bedrock_api: AmazonBedrockLLM, mocker):
|
||||
mocker.patch("metagpt.provider.bedrock_api.AmazonBedrockLLM.invoke_model",
|
||||
mock_bedrock_provider_response)
|
||||
mocker.patch("metagpt.provider.bedrock_api.AmazonBedrockLLM.invoke_model_with_response_stream",
|
||||
mock_bedrock_provider_stream_response)
|
||||
def test_chat_completion_stream(self, bedrock_api: BedrockLLM, mocker):
|
||||
self._patch_invoke_model(mocker)
|
||||
self._patch_invoke_model_stream(mocker)
|
||||
assert bedrock_api._chat_completion_stream(messages) == "Hello World"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_achat_completion_stream(self, bedrock_api: BedrockLLM, mocker):
|
||||
self._patch_invoke_model_stream(mocker)
|
||||
self._patch_invoke_model(mocker)
|
||||
assert await bedrock_api._achat_completion_stream(messages) == "Hello World"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_acompletion(self, bedrock_api: BedrockLLM, mocker):
|
||||
self._patch_invoke_model(mocker)
|
||||
assert await bedrock_api.acompletion(messages) == "Hello World"
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue