From 4c394a1cac32fca2a526425a6600f3355e209aba Mon Sep 17 00:00:00 2001 From: usamimeri_renko <1710269958@qq.com> Date: Fri, 26 Apr 2024 16:08:39 +0800 Subject: [PATCH] add test --- .../provider/bedrock/amazon_bedrock_api.py | 39 ++++++-- metagpt/provider/bedrock/base_provider.py | 7 +- metagpt/provider/bedrock/bedrock_provider.py | 5 +- metagpt/provider/bedrock/utils.py | 6 +- tests/metagpt/provider/mock_llm_config.py | 9 ++ tests/metagpt/provider/req_resp_const.py | 70 ++++++++++++-- .../provider/test_amazon_bedrock_api.py | 91 +++++++++++++++++++ 7 files changed, 198 insertions(+), 29 deletions(-) create mode 100644 tests/metagpt/provider/test_amazon_bedrock_api.py diff --git a/metagpt/provider/bedrock/amazon_bedrock_api.py b/metagpt/provider/bedrock/amazon_bedrock_api.py index 184e934fa..07687c682 100644 --- a/metagpt/provider/bedrock/amazon_bedrock_api.py +++ b/metagpt/provider/bedrock/amazon_bedrock_api.py @@ -1,4 +1,5 @@ from typing import Literal +import json from metagpt.const import USE_CONFIG_TIMEOUT from metagpt.provider.llm_provider_registry import register_provider from metagpt.configs.llm_config import LLMConfig, LLMType @@ -8,9 +9,10 @@ from metagpt.provider.bedrock.bedrock_provider import get_provider from metagpt.provider.bedrock.utils import NOT_SUUPORT_STREAM_MODELS, get_max_tokens try: import boto3 + from botocore.response import StreamingBody except ImportError: raise ImportError( - "boto3 not found! please install it by `pip install boto3` first ") + "boto3 not found! please install it by `pip install boto3` ") @register_provider([LLMType.AMAZON_BEDROCK]) @@ -25,7 +27,7 @@ class AmazonBedrockLLM(BaseLLM): self.__client = self.__init_client("bedrock-runtime") self.__provider = get_provider(self.config.model) logger.warning( - "Amazon bedrock doesn't support asynchronous calls now") + "Amazon bedrock doesn't support asynchronous now") def __init_client(self, service_name: Literal["bedrock-runtime", "bedrock"]): """initialize boto3 client""" @@ -39,6 +41,12 @@ class AmazonBedrockLLM(BaseLLM): client = session.client(service_name) return client + def _get_client(self): + return self.__client + + def _get_provider(self): + return self.__provider + def list_models(self): """list all available text-generation models @@ -55,6 +63,19 @@ class AmazonBedrockLLM(BaseLLM): for summary in response["modelSummaries"]] logger.info("\n"+"\n".join(summaries)) + def invoke_model(self, request_body) -> dict: + response = self.__client.invoke_model( + modelId=self.config.model, body=request_body + ) + response_body = self._get_response_body(response) + return response_body + + def invoke_model_with_response_stream(self, request_body) -> StreamingBody: + response = self.__client.invoke_model_with_response_stream( + modelId=self.config.model, body=request_body + ) + return response + @property def _generate_kwargs(self) -> dict: model_max_tokens = get_max_tokens(self.config.model) @@ -71,10 +92,8 @@ class AmazonBedrockLLM(BaseLLM): def completion(self, messages: list[dict]) -> str: request_body = self.__provider.get_request_body( messages, **self._generate_kwargs) - response = self.__client.invoke_model( - modelId=self.config.model, body=request_body - ) - completions = self.__provider.get_choice_text(response) + response_body = self.invoke_model(request_body) + completions = self.__provider.get_choice_text(response_body) return completions def _chat_completion_stream(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT) -> str: @@ -86,9 +105,7 @@ class AmazonBedrockLLM(BaseLLM): request_body = self.__provider.get_request_body( messages, **self._generate_kwargs) - response = self.__client.invoke_model_with_response_stream( - modelId=self.config.model, body=request_body - ) + response = self.invoke_model_with_response_stream(request_body) collected_content = [] for event in response["body"]: @@ -119,3 +136,7 @@ class AmazonBedrockLLM(BaseLLM): async def _achat_completion_stream(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT): return self._chat_completion_stream(messages) + + def _get_response_body(self, response) -> dict: + response_body = json.loads(response["body"].read()) + return response_body diff --git a/metagpt/provider/bedrock/base_provider.py b/metagpt/provider/bedrock/base_provider.py index c24556645..449e0f5c8 100644 --- a/metagpt/provider/bedrock/base_provider.py +++ b/metagpt/provider/bedrock/base_provider.py @@ -15,8 +15,7 @@ class BaseBedrockProvider(ABC): {"prompt": self.messages_to_prompt(messages), **generate_kwargs}) return body - def get_choice_text(self, response) -> str: - response_body = self._get_response_body_json(response) + def get_choice_text(self, response_body: dict) -> str: completions = self._get_completion_from_dict(response_body) return completions @@ -25,10 +24,6 @@ class BaseBedrockProvider(ABC): completions = self._get_completion_from_dict(rsp_dict) return completions - def _get_response_body_json(self, response): - response_body = json.loads(response["body"].read()) - return response_body - def messages_to_prompt(self, messages: list[dict]) -> str: """[{"role": "user", "content": msg}] to user: etc.""" return "\n".join([f"{i['role']}: {i['content']}" for i in messages]) diff --git a/metagpt/provider/bedrock/bedrock_provider.py b/metagpt/provider/bedrock/bedrock_provider.py index 375657f12..e01184705 100644 --- a/metagpt/provider/bedrock/bedrock_provider.py +++ b/metagpt/provider/bedrock/bedrock_provider.py @@ -1,7 +1,7 @@ import json from typing import Literal from metagpt.provider.bedrock.base_provider import BaseBedrockProvider -from metagpt.provider.bedrock.utils import messages_to_prompt_llama2, messages_to_prompt_llama3, messages_to_prompt_claude +from metagpt.provider.bedrock.utils import messages_to_prompt_llama2, messages_to_prompt_llama3 class MistralProvider(BaseBedrockProvider): @@ -17,9 +17,6 @@ class MistralProvider(BaseBedrockProvider): class AnthropicProvider(BaseBedrockProvider): # See https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html - def messages_to_prompt(self, messages: list[dict]) -> str: - return messages_to_prompt_claude(messages) - def get_request_body(self, messages: list[dict], **generate_kwargs): body = json.dumps( {"messages": messages, "anthropic_version": "bedrock-2023-05-31", **generate_kwargs}) diff --git a/metagpt/provider/bedrock/utils.py b/metagpt/provider/bedrock/utils.py index 80b7b82bd..f69d04a7b 100644 --- a/metagpt/provider/bedrock/utils.py +++ b/metagpt/provider/bedrock/utils.py @@ -33,7 +33,7 @@ SUPPORT_STREAM_MODELS = { # TODO:use a more general function for constructing chat templates. -def messages_to_prompt_llama2(messages: list[dict]): +def messages_to_prompt_llama2(messages: list[dict]) -> str: BOS, EOS = "", "" B_INST, E_INST = "[INST]", "[/INST]" B_SYS, E_SYS = "<>\n", "\n<>\n\n" @@ -56,7 +56,7 @@ def messages_to_prompt_llama2(messages: list[dict]): return prompt -def messages_to_prompt_llama3(messages: list[dict]): +def messages_to_prompt_llama3(messages: list[dict]) -> str: BOS, EOS = "<|begin_of_text|>", "<|eot_id|>" GENERAL_TEMPLATE = "<|start_header_id|>{role}<|end_header_id|>\n\n{content}<|eot_id|>" @@ -72,7 +72,7 @@ def messages_to_prompt_llama3(messages: list[dict]): return prompt -def messages_to_prompt_claude(messages: list[dict]): +def messages_to_prompt_claude2(messages: list[dict]) -> str: GENERAL_TEMPLATE = "\n\n{role}: {content}" prompt = "" for message in messages: diff --git a/tests/metagpt/provider/mock_llm_config.py b/tests/metagpt/provider/mock_llm_config.py index 0c56cc8ea..8660bc24f 100644 --- a/tests/metagpt/provider/mock_llm_config.py +++ b/tests/metagpt/provider/mock_llm_config.py @@ -60,3 +60,12 @@ mock_llm_config_dashscope = LLMConfig(api_type="dashscope", api_key="xxx", model mock_llm_config_anthropic = LLMConfig( api_type="anthropic", api_key="xxx", base_url="https://api.anthropic.com", model="claude-3-opus-20240229" ) + +mock_llm_config_bedrock = LLMConfig( + api_type="amazon_bedrock", + model="gpt-100", + region_name="somewhere", + access_key="123abc", + secret_key="123abc", + max_token=10000, +) diff --git a/tests/metagpt/provider/req_resp_const.py b/tests/metagpt/provider/req_resp_const.py index 7e4c1a49c..6a244cbe4 100644 --- a/tests/metagpt/provider/req_resp_const.py +++ b/tests/metagpt/provider/req_resp_const.py @@ -65,17 +65,20 @@ def get_openai_chat_completion(name: str) -> ChatCompletion: Choice( finish_reason="stop", index=0, - message=ChatCompletionMessage(role="assistant", content=resp_cont_tmpl.format(name=name)), + message=ChatCompletionMessage( + role="assistant", content=resp_cont_tmpl.format(name=name)), logprobs=None, ) ], - usage=CompletionUsage(completion_tokens=110, prompt_tokens=92, total_tokens=202), + usage=CompletionUsage(completion_tokens=110, + prompt_tokens=92, total_tokens=202), ) return openai_chat_completion def get_openai_chat_completion_chunk(name: str, usage_as_dict: bool = False) -> ChatCompletionChunk: - usage = CompletionUsage(completion_tokens=110, prompt_tokens=92, total_tokens=202) + usage = CompletionUsage(completion_tokens=110, + prompt_tokens=92, total_tokens=202) usage = usage if not usage_as_dict else usage.model_dump() openai_chat_completion_chunk = ChatCompletionChunk( id="cmpl-a6652c1bb181caae8dd19ad8", @@ -84,7 +87,8 @@ def get_openai_chat_completion_chunk(name: str, usage_as_dict: bool = False) -> created=1703300855, choices=[ AChoice( - delta=ChoiceDelta(role="assistant", content=resp_cont_tmpl.format(name=name)), + delta=ChoiceDelta(role="assistant", + content=resp_cont_tmpl.format(name=name)), finish_reason="stop", index=0, logprobs=None, @@ -133,7 +137,8 @@ def get_dashscope_response(name: str) -> GenerationResponse: ], } ), - usage=GenerationUsage(**{"input_tokens": 12, "output_tokens": 98, "total_tokens": 110}), + usage=GenerationUsage( + **{"input_tokens": 12, "output_tokens": 98, "total_tokens": 110}), ) ) @@ -155,7 +160,8 @@ def get_anthropic_response(name: str, stream: bool = False) -> Message: ), ContentBlockDeltaEvent( index=0, - delta=TextDelta(text=resp_cont_tmpl.format(name=name), type="text_delta"), + delta=TextDelta(text=resp_cont_tmpl.format( + name=name), type="text_delta"), type="content_block_delta", ), ] @@ -165,7 +171,8 @@ def get_anthropic_response(name: str, stream: bool = False) -> Message: model=name, role="assistant", type="message", - content=[ContentBlock(text=resp_cont_tmpl.format(name=name), type="text")], + content=[ContentBlock( + text=resp_cont_tmpl.format(name=name), type="text")], usage=AnthropicUsage(input_tokens=10, output_tokens=10), ) @@ -183,3 +190,52 @@ async def llm_general_chat_funcs_test(llm: BaseLLM, prompt: str, messages: list[ resp = await llm.acompletion_text(messages, stream=True) assert resp == resp_cont + + +# For Amazon Bedrock +BEDROCK_PROVIDER_REQUEST_BODY = { + "mistral": {"prompt": "", "max_tokens": 0, "stop": [], "temperature": 0.0, "top_p": 0.0, "top_k": 0}, + "meta": {"prompt": "", "temperature": 0.0, "top_p": 0.0, "max_gen_len": 0}, + "ai21": { + "prompt": "", "temperature": 0.0, "topP": 0.0, "maxTokens": 0, + "stopSequences": [], "countPenalty": {"scale": 0.0}, + "presencePenalty": {"scale": 0.0}, "frequencyPenalty": {"scale": 0.0} + }, + "cohere": { + "prompt": "", "temperature": 0.0, "p": 0.0, "k": 0.0, "max_tokens": 0, "stop_sequences": [], + "return_likelihoods": "NONE", "stream": False, "num_generations": 0, "logit_bias": {}, "truncate": "NONE" + }, + "anthropic": { + "anthropic_version": "bedrock-2023-05-31", "max_tokens": 0, "system": "", + "messages": [{"role": "", "content": ""}], "temperature": 0.0, "top_p": 0.0, "top_k": 0, "stop_sequences": [] + }, + "amazon": { + "inputText": "", "textGenerationConfig": {"temperature": 0.0, "topP": 0.0, "maxTokenCount": 0, "stopSequences": []} + } +} + +BEDROCK_PROVIDER_RESPONSE_BODY = { + "mistral": {"outputs": [{"text": "Hello World", "stop_reason": ""}]}, + + "meta": {"generation": "Hello World", "prompt_token_count": 0, "generation_token_count": 0, "stop_reason": ""}, + + "ai21": { + "id": "", "prompt": {"text": "Hello World", "tokens": []}, + "completions": [{"data": {"text": "Hello World", "tokens": []}, + "finishReason": {"reason": "length", "length": 2}}] + }, + "cohere": { + "generations": [ + {"finish_reason": "", "id": "", "text": "Hello World", "likelihood": 0.0, + "token_likelihoods": [{"token": 0.0}], "is_finished": True, "index": 0}], "id": "", "prompt": "" + }, + + "anthropic": { + "id": "", "model": "", "type": "message", "role": "assistant", "content": [{"type": "text", "text": "Hello World"}], + "stop_reason": "", "stop_sequence": "", "usage": {"input_tokens": 0, "output_tokens": 0} + }, + + "amazon": {'inputTextTokenCount': 0, 'results': [{'tokenCount': 0, 'outputText': 'Hello World', 'completionReason': ""}]} +} + +BEDROCK_PROVIDER_STREAM_RESPONSE = {} diff --git a/tests/metagpt/provider/test_amazon_bedrock_api.py b/tests/metagpt/provider/test_amazon_bedrock_api.py new file mode 100644 index 000000000..0b8206463 --- /dev/null +++ b/tests/metagpt/provider/test_amazon_bedrock_api.py @@ -0,0 +1,91 @@ +import pytest +import json +from metagpt.provider.bedrock.amazon_bedrock_api import AmazonBedrockLLM +from tests.metagpt.provider.mock_llm_config import mock_llm_config_bedrock +from metagpt.provider.bedrock.utils import get_max_tokens, SUPPORT_STREAM_MODELS, NOT_SUUPORT_STREAM_MODELS +from tests.metagpt.provider.req_resp_const import ( + BEDROCK_PROVIDER_REQUEST_BODY, + BEDROCK_PROVIDER_RESPONSE_BODY, + BEDROCK_PROVIDER_STREAM_RESPONSE) +from botocore.response import StreamingBody + +# all available model from bedrock +models = SUPPORT_STREAM_MODELS | NOT_SUUPORT_STREAM_MODELS +messages = [{"role": "user", "content": "Hi!"}] + + +def mock_bedrock_provider_response(self, *args, **kwargs) -> dict: + provider = self.config.model.split(".")[0] + return BEDROCK_PROVIDER_RESPONSE_BODY[provider] + + +def mock_bedrock_provider_stream_response(self, *args, **kwargs) -> StreamingBody: + provider = self.config.model.split(".")[0] + response_json = BEDROCK_PROVIDER_STREAM_RESPONSE[provider] + return + + +def get_bedrock_request_body(model_id) -> dict: + provider = model_id.split(".")[0] + return BEDROCK_PROVIDER_REQUEST_BODY[provider] + + +def is_subset(subset, superset): + """Ensure all fields in request body are allowed. + + ```python + subset = {"prompt": "hello","kwargs": {"temperature": 0.9,"p": 0.0}} + superset = {"prompt": "hello", "kwargs": {"temperature": 0.0, "top-p": 0.0}} + is_subset(subset, superset) + ``` + >>>False + """ + for key, value in subset.items(): + if key not in superset: + return False + if isinstance(value, dict): + if not isinstance(superset[key], dict): + return False + if not is_subset(value, superset[key]): + return False + return True + + +@ pytest.fixture(scope="class", params=models) +def bedrock_api(request) -> AmazonBedrockLLM: + model_id = request.param + mock_llm_config_bedrock.model = model_id + api = AmazonBedrockLLM(mock_llm_config_bedrock) + return api + + +class TestAPI: + def test_generate_kwargs(self, bedrock_api: AmazonBedrockLLM): + provider = bedrock_api._get_provider() + assert bedrock_api._generate_kwargs[provider.max_tokens_field_name] <= get_max_tokens( + bedrock_api.config.model) + + def test_get_request_body(self, bedrock_api: AmazonBedrockLLM): + provider = bedrock_api._get_provider() + request_body = json.loads(provider.get_request_body( + messages, **bedrock_api._generate_kwargs)) + print(get_bedrock_request_body( + bedrock_api.config.model)) + print(request_body) + + assert is_subset(request_body, get_bedrock_request_body( + bedrock_api.config.model)) + + def test_completion(self, bedrock_api: AmazonBedrockLLM, mocker): + mocker.patch( + "metagpt.provider.bedrock.amazon_bedrock_api.AmazonBedrockLLM.invoke_model", mock_bedrock_provider_response) + assert bedrock_api.completion(messages) == "Hello World" + + # def test_stream_completion(self, bedrock_api: AmazonBedrockLLM, mocker): + # mocker.patch( + # "metagpt.provider.bedrock.amazon_bedrock_api.AmazonBedrockLLM.invoke_model_with_response_stream", mock_bedrock_provider_response) + # assert bedrock_api._chat_completion_stream(messages) == "Hello World" + + +if __name__ == '__main__': + print(get_bedrock_request_body("amazon.titan-tg1-large"))