From f14a1f63ef93d035b77cb285c078bc511290c094 Mon Sep 17 00:00:00 2001 From: usamimeri_renko <1710269958@qq.com> Date: Mon, 29 Apr 2024 15:04:33 +0800 Subject: [PATCH] add pre-commit --- metagpt/configs/llm_config.py | 1 - metagpt/provider/__init__.py | 2 +- metagpt/provider/bedrock/base_provider.py | 3 +- metagpt/provider/bedrock/bedrock_provider.py | 19 ++-- metagpt/provider/bedrock_api.py | 62 ++++++------- tests/metagpt/provider/req_resp_const.py | 98 +++++++++++++------- tests/metagpt/provider/test_bedrock_api.py | 44 +++++---- 7 files changed, 132 insertions(+), 97 deletions(-) diff --git a/metagpt/configs/llm_config.py b/metagpt/configs/llm_config.py index 41e04ab44..e202150f7 100644 --- a/metagpt/configs/llm_config.py +++ b/metagpt/configs/llm_config.py @@ -100,4 +100,3 @@ class LLMConfig(YamlModel): @classmethod def check_timeout(cls, v): return v or LLM_API_TIMEOUT - diff --git a/metagpt/provider/__init__.py b/metagpt/provider/__init__.py index fc91c7460..fcb5fa32a 100644 --- a/metagpt/provider/__init__.py +++ b/metagpt/provider/__init__.py @@ -31,5 +31,5 @@ __all__ = [ "QianFanLLM", "DashScopeLLM", "AnthropicLLM", - "BedrockLLM" + "BedrockLLM", ] diff --git a/metagpt/provider/bedrock/base_provider.py b/metagpt/provider/bedrock/base_provider.py index fd0508cb0..0d13ae938 100644 --- a/metagpt/provider/bedrock/base_provider.py +++ b/metagpt/provider/bedrock/base_provider.py @@ -11,8 +11,7 @@ class BaseBedrockProvider(ABC): ... def get_request_body(self, messages: list[dict], const_kwargs, *args, **kwargs) -> str: - body = json.dumps( - {"prompt": self.messages_to_prompt(messages), **const_kwargs}) + body = json.dumps({"prompt": self.messages_to_prompt(messages), **const_kwargs}) return body def get_choice_text(self, response_body: dict) -> str: diff --git a/metagpt/provider/bedrock/bedrock_provider.py b/metagpt/provider/bedrock/bedrock_provider.py index 6378939c9..ff1d88a47 100644 --- a/metagpt/provider/bedrock/bedrock_provider.py +++ b/metagpt/provider/bedrock/bedrock_provider.py @@ -1,7 +1,11 @@ import json from typing import Literal + from metagpt.provider.bedrock.base_provider import BaseBedrockProvider -from metagpt.provider.bedrock.utils import messages_to_prompt_llama2, messages_to_prompt_llama3 +from metagpt.provider.bedrock.utils import ( + messages_to_prompt_llama2, + messages_to_prompt_llama3, +) class MistralProvider(BaseBedrockProvider): @@ -18,8 +22,7 @@ class AnthropicProvider(BaseBedrockProvider): # See https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html def get_request_body(self, messages: list[dict], generate_kwargs, *args, **kwargs): - body = json.dumps( - {"messages": messages, "anthropic_version": "bedrock-2023-05-31", **generate_kwargs}) + body = json.dumps({"messages": messages, "anthropic_version": "bedrock-2023-05-31", **generate_kwargs}) return body def _get_completion_from_dict(self, rsp_dict: dict) -> str: @@ -43,7 +46,8 @@ class CohereProvider(BaseBedrockProvider): def get_request_body(self, messages: list[dict], generate_kwargs, *args, **kwargs): body = json.dumps( - {"prompt": self.messages_to_prompt(messages), "stream": kwargs.get("stream", False), **generate_kwargs}) + {"prompt": self.messages_to_prompt(messages), "stream": kwargs.get("stream", False), **generate_kwargs} + ) return body def get_choice_text_from_stream(self, event) -> str: @@ -85,10 +89,7 @@ class AmazonProvider(BaseBedrockProvider): max_tokens_field_name = "maxTokenCount" def get_request_body(self, messages: list[dict], 
generate_kwargs, *args, **kwargs):
-        body = json.dumps({
-            "inputText": self.messages_to_prompt(messages),
-            "textGenerationConfig": generate_kwargs
-        })
+        body = json.dumps({"inputText": self.messages_to_prompt(messages), "textGenerationConfig": generate_kwargs})
         return body
 
     def _get_completion_from_dict(self, rsp_dict: dict) -> str:
@@ -106,7 +107,7 @@ PROVIDERS = {
     "ai21": Ai21Provider,
     "cohere": CohereProvider,
     "anthropic": AnthropicProvider,
-    "amazon": AmazonProvider
+    "amazon": AmazonProvider,
 }
 
diff --git a/metagpt/provider/bedrock_api.py b/metagpt/provider/bedrock_api.py
index b07520fd1..de3fbae94 100644
--- a/metagpt/provider/bedrock_api.py
+++ b/metagpt/provider/bedrock_api.py
@@ -1,15 +1,17 @@
-from typing import Literal
 import json
-from metagpt.const import USE_CONFIG_TIMEOUT
-from metagpt.provider.llm_provider_registry import register_provider
-from metagpt.configs.llm_config import LLMConfig, LLMType
-from metagpt.provider.base_llm import BaseLLM
-from metagpt.logs import log_llm_stream, logger
-from metagpt.provider.bedrock.bedrock_provider import get_provider
-from metagpt.provider.bedrock.utils import NOT_SUUPORT_STREAM_MODELS, get_max_tokens
+from typing import Literal
+
 import boto3
 from botocore.eventstream import EventStream
 
+from metagpt.configs.llm_config import LLMConfig, LLMType
+from metagpt.const import USE_CONFIG_TIMEOUT
+from metagpt.logs import log_llm_stream, logger
+from metagpt.provider.base_llm import BaseLLM
+from metagpt.provider.bedrock.bedrock_provider import get_provider
+from metagpt.provider.bedrock.utils import NOT_SUUPORT_STREAM_MODELS, get_max_tokens
+from metagpt.provider.llm_provider_registry import register_provider
+
 
 @register_provider([LLMType.BEDROCK])
 class BedrockLLM(BaseLLM):
@@ -17,8 +19,7 @@ class BedrockLLM(BaseLLM):
         self.config = config
         self.__client = self.__init_client("bedrock-runtime")
         self.__provider = get_provider(self.config.model)
-        logger.warning(
-            "Amazon bedrock doesn't support asynchronous now")
+        logger.warning("Amazon Bedrock doesn't support asynchronous calls yet")
 
     def __init_client(self, service_name: Literal["bedrock-runtime", "bedrock"]):
         """initialize boto3 client"""
@@ -26,7 +27,7 @@ class BedrockLLM(BaseLLM):
         self.__credentital_kwargs = {
             "aws_secret_access_key": self.config.secret_key,
             "aws_access_key_id": self.config.access_key,
-            "region_name": self.config.region_name
+            "region_name": self.config.region_name,
         }
         session = boto3.Session(**self.__credentital_kwargs)
         client = session.client(service_name)
@@ -52,22 +53,21 @@ class BedrockLLM(BaseLLM):
         client = self.__init_client("bedrock")
         # only output text-generation models
         response = client.list_foundation_models(byOutputModality="TEXT")
-        summaries = [f'{summary["modelId"]:50} Support Streaming:{summary["responseStreamingSupported"]}'
-                     for summary in response["modelSummaries"]]
-        logger.info("\n"+"\n".join(summaries))
+        summaries = [
+            f'{summary["modelId"]:50} Support Streaming:{summary["responseStreamingSupported"]}'
+            for summary in response["modelSummaries"]
+        ]
+        logger.info("\n" + "\n".join(summaries))
 
     def invoke_model(self, request_body: str) -> dict:
-        response = self.__client.invoke_model(
-            modelId=self.config.model, body=request_body
-        )
+        response = self.__client.invoke_model(modelId=self.config.model, body=request_body)
         usage = self._get_usage(response)
         self._update_costs(usage)
         response_body = self._get_response_body(response)
         return response_body
 
     def invoke_model_with_response_stream(self, request_body: str) -> EventStream:
-        response = self.__client.invoke_model_with_response_stream(
-            modelId=self.config.model, body=request_body)
+        response = self.__client.invoke_model_with_response_stream(modelId=self.config.model, body=request_body)
         usage = self._get_usage(response)
         self._update_costs(usage)
         return response
 
@@ -80,26 +80,20 @@ class BedrockLLM(BaseLLM):
         else:
             max_tokens = self.config.max_token
 
-        return {
-            self.__provider.max_tokens_field_name: max_tokens,
-            "temperature": self.config.temperature
-        }
+        return {self.__provider.max_tokens_field_name: max_tokens, "temperature": self.config.temperature}
 
     def completion(self, messages: list[dict]) -> str:
-        request_body = self.__provider.get_request_body(
-            messages, self._const_kwargs)
+        request_body = self.__provider.get_request_body(messages, self._const_kwargs)
         response_body = self.invoke_model(request_body)
         completions = self.__provider.get_choice_text(response_body)
         return completions
 
     def _chat_completion_stream(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT) -> str:
         if self.config.model in NOT_SUUPORT_STREAM_MODELS:
-            logger.warning(
-                f"model {self.config.model} doesn't support streaming output!")
+            logger.warning(f"model {self.config.model} doesn't support streaming output!")
             return self.completion(messages)
 
-        request_body = self.__provider.get_request_body(
-            messages, self._const_kwargs, stream=True)
+        request_body = self.__provider.get_request_body(messages, self._const_kwargs, stream=True)
         response = self.invoke_model_with_response_stream(request_body)
 
         collected_content = []
@@ -134,8 +128,8 @@ class BedrockLLM(BaseLLM):
         headers = response.get("ResponseMetadata", {}).get("HTTPHeaders", {})
         prompt_tokens = int(headers.get("x-amzn-bedrock-input-token-count", 0))
         completion_tokens = int(headers.get("x-amzn-bedrock-output-token-count", 0))
-        usage = {
-            "prompt_tokens": prompt_tokens,
-            "completion_tokens": completion_tokens,
-        },
+        usage = {
+            "prompt_tokens": prompt_tokens,
+            "completion_tokens": completion_tokens,
+        }
         return usage
diff --git a/tests/metagpt/provider/req_resp_const.py b/tests/metagpt/provider/req_resp_const.py
index fb754abf7..111b57f91 100644
--- a/tests/metagpt/provider/req_resp_const.py
+++ b/tests/metagpt/provider/req_resp_const.py
@@ -65,20 +65,17 @@ def get_openai_chat_completion(name: str) -> ChatCompletion:
             Choice(
                 finish_reason="stop",
                 index=0,
-                message=ChatCompletionMessage(
-                    role="assistant", content=resp_cont_tmpl.format(name=name)),
+                message=ChatCompletionMessage(role="assistant", content=resp_cont_tmpl.format(name=name)),
                 logprobs=None,
             )
         ],
-        usage=CompletionUsage(completion_tokens=110,
-                              prompt_tokens=92, total_tokens=202),
+        usage=CompletionUsage(completion_tokens=110, prompt_tokens=92, total_tokens=202),
     )
     return openai_chat_completion
 
 
 def get_openai_chat_completion_chunk(name: str, usage_as_dict: bool = False) -> ChatCompletionChunk:
-    usage = CompletionUsage(completion_tokens=110,
-                            prompt_tokens=92, total_tokens=202)
+    usage = CompletionUsage(completion_tokens=110, prompt_tokens=92, total_tokens=202)
     usage = usage if not usage_as_dict else usage.model_dump()
     openai_chat_completion_chunk = ChatCompletionChunk(
         id="cmpl-a6652c1bb181caae8dd19ad8",
@@ -87,8 +84,7 @@ def get_openai_chat_completion_chunk(name: str, usage_as_dict: bool = False) ->
         created=1703300855,
         choices=[
             AChoice(
-                delta=ChoiceDelta(role="assistant",
-                                  content=resp_cont_tmpl.format(name=name)),
+                delta=ChoiceDelta(role="assistant", content=resp_cont_tmpl.format(name=name)),
                 finish_reason="stop",
                 index=0,
                 logprobs=None,
@@ -137,8 +133,7 @@ def get_dashscope_response(name: str) -> GenerationResponse:
             ],
         }
     ),
-        usage=GenerationUsage(
-            **{"input_tokens": 12, "output_tokens": 98, "total_tokens": 110}),
+        usage=GenerationUsage(**{"input_tokens": 12, "output_tokens": 98, "total_tokens": 110}),
     )
 )
 
@@ -170,8 +165,7 @@ def get_anthropic_response(name: str, stream: bool = False) -> Message:
         model=name,
         role="assistant",
         type="message",
-        content=[ContentBlock(
-            text=resp_cont_tmpl.format(name=name), type="text")],
+        content=[ContentBlock(text=resp_cont_tmpl.format(name=name), type="text")],
         usage=AnthropicUsage(input_tokens=10, output_tokens=10),
     )
 
@@ -198,43 +192,81 @@ BEDROCK_PROVIDER_REQUEST_BODY = {
     "mistral": {"prompt": "", "max_tokens": 0, "stop": [], "temperature": 0.0, "top_p": 0.0, "top_k": 0},
     "meta": {"prompt": "", "temperature": 0.0, "top_p": 0.0, "max_gen_len": 0},
     "ai21": {
-        "prompt": "", "temperature": 0.0, "topP": 0.0, "maxTokens": 0,
-        "stopSequences": [], "countPenalty": {"scale": 0.0},
-        "presencePenalty": {"scale": 0.0}, "frequencyPenalty": {"scale": 0.0}
+        "prompt": "",
+        "temperature": 0.0,
+        "topP": 0.0,
+        "maxTokens": 0,
+        "stopSequences": [],
+        "countPenalty": {"scale": 0.0},
+        "presencePenalty": {"scale": 0.0},
+        "frequencyPenalty": {"scale": 0.0},
     },
     "cohere": {
-        "prompt": "", "temperature": 0.0, "p": 0.0, "k": 0.0, "max_tokens": 0, "stop_sequences": [],
-        "return_likelihoods": "NONE", "stream": False, "num_generations": 0, "logit_bias": {}, "truncate": "NONE"
+        "prompt": "",
+        "temperature": 0.0,
+        "p": 0.0,
+        "k": 0.0,
+        "max_tokens": 0,
+        "stop_sequences": [],
+        "return_likelihoods": "NONE",
+        "stream": False,
+        "num_generations": 0,
+        "logit_bias": {},
+        "truncate": "NONE",
     },
     "anthropic": {
-        "anthropic_version": "bedrock-2023-05-31", "max_tokens": 0, "system": "",
-        "messages": [{"role": "", "content": ""}], "temperature": 0.0, "top_p": 0.0, "top_k": 0, "stop_sequences": []
+        "anthropic_version": "bedrock-2023-05-31",
+        "max_tokens": 0,
+        "system": "",
+        "messages": [{"role": "", "content": ""}],
+        "temperature": 0.0,
+        "top_p": 0.0,
+        "top_k": 0,
+        "stop_sequences": [],
    },
    "amazon": {
-        "inputText": "", "textGenerationConfig": {"temperature": 0.0, "topP": 0.0, "maxTokenCount": 0, "stopSequences": []}
-    }
+        "inputText": "",
+        "textGenerationConfig": {"temperature": 0.0, "topP": 0.0, "maxTokenCount": 0, "stopSequences": []},
+    },
 }
 
 BEDROCK_PROVIDER_RESPONSE_BODY = {
     "mistral": {"outputs": [{"text": "Hello World", "stop_reason": ""}]},
     "meta": {"generation": "Hello World", "prompt_token_count": 0, "generation_token_count": 0, "stop_reason": ""},
     "ai21": {
-        "id": "", "prompt": {"text": "Hello World", "tokens": []},
-        "completions": [{"data": {"text": "Hello World", "tokens": []},
-                         "finishReason": {"reason": "length", "length": 2}}]
+        "id": "",
+        "prompt": {"text": "Hello World", "tokens": []},
+        "completions": [
+            {"data": {"text": "Hello World", "tokens": []}, "finishReason": {"reason": "length", "length": 2}}
+        ],
     },
     "cohere": {
         "generations": [
-            {"finish_reason": "", "id": "", "text": "Hello World", "likelihood": 0.0,
-             "token_likelihoods": [{"token": 0.0}], "is_finished": True, "index": 0}], "id": "", "prompt": ""
+            {
+                "finish_reason": "",
+                "id": "",
+                "text": "Hello World",
+                "likelihood": 0.0,
+                "token_likelihoods": [{"token": 0.0}],
+                "is_finished": True,
+                "index": 0,
+            }
+        ],
+        "id": "",
+        "prompt": "",
     },
     "anthropic": {
-        "id": "", "model": "", "type": "message", "role": "assistant", "content": [{"type": "text", "text": "Hello World"}],
-        "stop_reason": "", "stop_sequence": "", "usage": {"input_tokens": 0, "output_tokens": 0}
+        "id": "",
+        "model": "",
+        "type": "message",
+        "role": "assistant",
+        "content": [{"type": "text", "text": "Hello World"}],
+        "stop_reason": "",
+        "stop_sequence": "",
+        "usage": {"input_tokens": 0, "output_tokens": 0},
+    },
+    "amazon": {
+        "inputTextTokenCount": 0,
+        "results": [{"tokenCount": 0, "outputText": "Hello World", "completionReason": ""}],
     },
-
-    "amazon": {"inputTextTokenCount": 0, "results": [{"tokenCount": 0, "outputText": "Hello World", "completionReason": ""}]}
 }
diff --git a/tests/metagpt/provider/test_bedrock_api.py b/tests/metagpt/provider/test_bedrock_api.py
index 7b2797da0..54ff5afa4 100644
--- a/tests/metagpt/provider/test_bedrock_api.py
+++ b/tests/metagpt/provider/test_bedrock_api.py
@@ -1,12 +1,21 @@
-import pytest
 import json
+
+import pytest
+
+from metagpt.provider.bedrock.utils import (
+    NOT_SUUPORT_STREAM_MODELS,
+    SUPPORT_STREAM_MODELS,
+    get_max_tokens,
+)
 from metagpt.provider.bedrock_api import BedrockLLM
 from tests.metagpt.provider.mock_llm_config import mock_llm_config_bedrock
-from metagpt.provider.bedrock.utils import get_max_tokens, SUPPORT_STREAM_MODELS, NOT_SUUPORT_STREAM_MODELS
-from tests.metagpt.provider.req_resp_const import BEDROCK_PROVIDER_REQUEST_BODY, BEDROCK_PROVIDER_RESPONSE_BODY
+from tests.metagpt.provider.req_resp_const import (
+    BEDROCK_PROVIDER_REQUEST_BODY,
+    BEDROCK_PROVIDER_RESPONSE_BODY,
+)
 
 # all available model from bedrock
-models = (SUPPORT_STREAM_MODELS | NOT_SUUPORT_STREAM_MODELS)
+models = SUPPORT_STREAM_MODELS | NOT_SUUPORT_STREAM_MODELS
 messages = [{"role": "user", "content": "Hi!"}]
 
@@ -19,22 +28,21 @@ def mock_bedrock_provider_stream_response(self, *args, **kwargs) -> dict:
     # use json object to mock EventStream
     def dict2bytes(x):
         return json.dumps(x).encode("utf-8")
+
     provider = self.config.model.split(".")[0]
 
     if provider == "amazon":
         response_body_bytes = dict2bytes({"outputText": "Hello World"})
     elif provider == "anthropic":
-        response_body_bytes = dict2bytes({"type": "content_block_delta", "index": 0,
-                                          "delta": {"type": "text_delta", "text": "Hello World"}})
+        response_body_bytes = dict2bytes(
+            {"type": "content_block_delta", "index": 0, "delta": {"type": "text_delta", "text": "Hello World"}}
+        )
     elif provider == "cohere":
-        response_body_bytes = dict2bytes(
-            {"is_finished": False, "text": "Hello World"})
+        response_body_bytes = dict2bytes({"is_finished": False, "text": "Hello World"})
     else:
-        response_body_bytes = dict2bytes(
-            BEDROCK_PROVIDER_RESPONSE_BODY[provider])
+        response_body_bytes = dict2bytes(BEDROCK_PROVIDER_RESPONSE_BODY[provider])
 
-    response_body_stream = {
-        "body": [{"chunk": {"bytes": response_body_bytes}}]}
+    response_body_stream = {"body": [{"chunk": {"bytes": response_body_bytes}}]}
     return response_body_stream
 
@@ -77,19 +85,19 @@ class TestBedrockAPI:
         mocker.patch("metagpt.provider.bedrock_api.BedrockLLM.invoke_model", mock_bedrock_provider_response)
 
     def _patch_invoke_model_stream(self, mocker):
-        mocker.patch("metagpt.provider.bedrock_api.BedrockLLM.invoke_model_with_response_stream",
-                     mock_bedrock_provider_stream_response)
+        mocker.patch(
+            "metagpt.provider.bedrock_api.BedrockLLM.invoke_model_with_response_stream",
+            mock_bedrock_provider_stream_response,
+        )
 
     def test_const_kwargs(self, bedrock_api: BedrockLLM):
         provider = bedrock_api.provider
-        assert bedrock_api._const_kwargs[provider.max_tokens_field_name] <= get_max_tokens(
-            bedrock_api.config.model)
+        assert bedrock_api._const_kwargs[provider.max_tokens_field_name] <= get_max_tokens(bedrock_api.config.model)
 
     def test_get_request_body(self, bedrock_api: BedrockLLM):
         """Ensure request body has correct format"""
         provider = bedrock_api.provider
-        request_body = json.loads(provider.get_request_body(
-            messages, bedrock_api._const_kwargs))
+        request_body = json.loads(provider.get_request_body(messages, bedrock_api._const_kwargs))
         assert is_subset(request_body, get_bedrock_request_body(bedrock_api.config.model))
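
Note on the _get_usage hunk in metagpt/provider/bedrock_api.py: on the pre-format (-)
side the dict literal is followed by a stray trailing comma, so `usage` was actually a
one-element tuple rather than the dict that `_update_costs` expects; the reformatted (+)
side drops that comma. A minimal standalone sketch of the pitfall (plain Python, no
Bedrock dependencies):

    # A trailing comma after any expression builds a tuple, even when the
    # expression is a multi-line dict literal.
    usage = {
        "prompt_tokens": 10,
        "completion_tokens": 20,
    },  # <- this comma makes `usage` a 1-element tuple: ({...},)
    assert isinstance(usage, tuple)

    usage = {
        "prompt_tokens": 10,
        "completion_tokens": 20,
    }  # without the comma, `usage` is the dict itself
    assert isinstance(usage, dict)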
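Note on the streaming mock in tests/metagpt/provider/test_bedrock_api.py: a real botocore
EventStream yields events shaped like {"chunk": {"bytes": b"..."}}, so a plain list of such
dicts is enough to drive the streaming code path in unit tests. A minimal sketch of how a
consumer walks that shape (the "text" key mirrors the cohere-style chunk used in the mock;
real chunk payloads differ per provider, which is what each provider's
get_choice_text_from_stream hides):

    import json

    # Hypothetical mock, same shape as the one built by
    # mock_bedrock_provider_stream_response in the diff above.
    mock_response = {"body": [{"chunk": {"bytes": json.dumps({"text": "Hello World"}).encode("utf-8")}}]}

    collected_content = []
    for event in mock_response["body"]:
        payload = json.loads(event["chunk"]["bytes"])  # each chunk carries one JSON document
        collected_content.append(payload.get("text", ""))
    print("".join(collected_content))  # -> Hello World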
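For reviewers who want to exercise the provider end to end, a hypothetical usage sketch
follows. The field names (model, access_key, secret_key, region_name) are the ones the
diff reads from LLMConfig; the api_type value and the model id are illustrative
assumptions, not part of this commit:

    from metagpt.configs.llm_config import LLMConfig
    from metagpt.provider.bedrock_api import BedrockLLM

    # Assumed values; replace with real AWS credentials and a model id that is
    # enabled in your Bedrock account.
    config = LLMConfig(
        api_type="bedrock",
        model="anthropic.claude-3-sonnet-20240229-v1:0",
        access_key="AKIA...",
        secret_key="...",
        region_name="us-east-1",
    )
    llm = BedrockLLM(config)
    print(llm.completion([{"role": "user", "content": "Hi!"}]))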