From c8e351f3c863950d9d23b2f556d028227b53c2b1 Mon Sep 17 00:00:00 2001 From: better629 Date: Fri, 29 Dec 2023 02:45:54 +0800 Subject: [PATCH] format --- .../test_base_postprocess_plugin.py | 29 ++++++-------- .../test_llm_output_postprocess.py | 9 +++-- .../metagpt/provider/test_azure_openai_api.py | 5 +-- .../metagpt/provider/test_general_api_base.py | 39 +++++++++++++------ tests/metagpt/provider/test_open_llm_api.py | 24 ++++++------ .../provider/zhipuai/test_async_sse_client.py | 2 +- .../provider/zhipuai/test_zhipu_model_api.py | 7 ++-- 7 files changed, 63 insertions(+), 52 deletions(-) diff --git a/tests/metagpt/provider/postprocess/test_base_postprocess_plugin.py b/tests/metagpt/provider/postprocess/test_base_postprocess_plugin.py index e63e4ecfe..824bb88f3 100644 --- a/tests/metagpt/provider/postprocess/test_base_postprocess_plugin.py +++ b/tests/metagpt/provider/postprocess/test_base_postprocess_plugin.py @@ -1,8 +1,7 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- -# @Desc : +# @Desc : -import pytest from metagpt.provider.postprocess.base_postprocess_plugin import BasePostProcessPlugin @@ -14,25 +13,19 @@ raw_output = """ [/CONTENT] """ raw_schema = { - "title":"prd", - "type":"object", - "properties":{ - "Original Requirements":{ - "title":"Original Requirements", - "type":"string" - }, - }, - "required":[ - "Original Requirements", - ] - } + "title": "prd", + "type": "object", + "properties": { + "Original Requirements": {"title": "Original Requirements", "type": "string"}, + }, + "required": [ + "Original Requirements", + ], +} def test_llm_post_process_plugin(): post_process_plugin = BasePostProcessPlugin() - output = post_process_plugin.run( - output=raw_output, - schema=raw_schema - ) + output = post_process_plugin.run(output=raw_output, schema=raw_schema) assert "Original Requirements" in output diff --git a/tests/metagpt/provider/postprocess/test_llm_output_postprocess.py b/tests/metagpt/provider/postprocess/test_llm_output_postprocess.py index 3cb627216..40457b186 100644 --- a/tests/metagpt/provider/postprocess/test_llm_output_postprocess.py +++ b/tests/metagpt/provider/postprocess/test_llm_output_postprocess.py @@ -1,12 +1,13 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- -# @Desc : +# @Desc : -import pytest from metagpt.provider.postprocess.llm_output_postprocess import llm_output_postprocess - -from tests.metagpt.provider.postprocess.test_base_postprocess_plugin import raw_output, raw_schema +from tests.metagpt.provider.postprocess.test_base_postprocess_plugin import ( + raw_output, + raw_schema, +) def test_llm_output_postprocess(): diff --git a/tests/metagpt/provider/test_azure_openai_api.py b/tests/metagpt/provider/test_azure_openai_api.py index 208e3104a..f36740e65 100644 --- a/tests/metagpt/provider/test_azure_openai_api.py +++ b/tests/metagpt/provider/test_azure_openai_api.py @@ -1,11 +1,10 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- -# @Desc : +# @Desc : -import pytest -from metagpt.provider.azure_openai_api import AzureOpenAILLM from metagpt.config import CONFIG +from metagpt.provider.azure_openai_api import AzureOpenAILLM CONFIG.OPENAI_API_VERSION = "xx" CONFIG.openai_proxy = "http://127.0.0.1:80" # fake value diff --git a/tests/metagpt/provider/test_general_api_base.py b/tests/metagpt/provider/test_general_api_base.py index 52ba32f01..ae768ce95 100644 --- a/tests/metagpt/provider/test_general_api_base.py +++ b/tests/metagpt/provider/test_general_api_base.py @@ -1,16 +1,27 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- -# @Desc : +# @Desc : -import pytest import os -import requests -import aiohttp -from typing import Iterator, Tuple, Union, Generator, AsyncGenerator +from typing import AsyncGenerator, Generator, Iterator, Tuple, Union +import aiohttp +import pytest +import requests from openai import OpenAIError -from metagpt.provider.general_api_base import ApiType, log_debug, log_info, log_warn, OpenAIResponse, \ - _requests_proxies_arg, _aiohttp_proxies_arg, _make_session, parse_stream_helper, parse_stream, APIRequestor + +from metagpt.provider.general_api_base import ( + APIRequestor, + ApiType, + OpenAIResponse, + _make_session, + _requests_proxies_arg, + log_debug, + log_info, + log_warn, + parse_stream, + parse_stream_helper, +) def test_basic(): @@ -60,13 +71,15 @@ def test_parse_stream(): api_requestor = APIRequestor(base_url="http://www.baidu.com") -def mock_interpret_response(self, result: requests.Response, stream: bool - ) -> Tuple[Union[bytes, Iterator[Generator]], bytes]: +def mock_interpret_response( + self, result: requests.Response, stream: bool +) -> Tuple[Union[bytes, Iterator[Generator]], bytes]: return b"baidu", False -async def mock_interpret_async_response(self, result: aiohttp.ClientResponse, stream: bool - ) -> Tuple[Union[OpenAIResponse, AsyncGenerator[OpenAIResponse, None]], bool]: +async def mock_interpret_async_response( + self, result: aiohttp.ClientResponse, stream: bool +) -> Tuple[Union[OpenAIResponse, AsyncGenerator[OpenAIResponse, None]], bool]: return b"baidu", True @@ -79,6 +92,8 @@ def test_api_requestor(mocker): @pytest.mark.asyncio async def test_async_api_requestor(mocker): - mocker.patch("metagpt.provider.general_api_base.APIRequestor._interpret_async_response", mock_interpret_async_response) + mocker.patch( + "metagpt.provider.general_api_base.APIRequestor._interpret_async_response", mock_interpret_async_response + ) resp, _, _ = await api_requestor.arequest(method="get", url="/s?wd=baidu") resp, _, _ = await api_requestor.arequest(method="post", url="/s?wd=baidu") diff --git a/tests/metagpt/provider/test_open_llm_api.py b/tests/metagpt/provider/test_open_llm_api.py index bf094d54a..85069c5e1 100644 --- a/tests/metagpt/provider/test_open_llm_api.py +++ b/tests/metagpt/provider/test_open_llm_api.py @@ -1,15 +1,16 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- -# @Desc : +# @Desc : import pytest - from openai.types.chat.chat_completion import ( ChatCompletion, ChatCompletionMessage, Choice, ) -from openai.types.chat.chat_completion_chunk import ChatCompletionChunk, ChoiceDelta, Choice as AChoice +from openai.types.chat.chat_completion_chunk import ChatCompletionChunk +from openai.types.chat.chat_completion_chunk import Choice as AChoice +from openai.types.chat.chat_completion_chunk import ChoiceDelta from openai.types.completion_usage import CompletionUsage from metagpt.config import CONFIG @@ -32,7 +33,7 @@ default_resp = ChatCompletion( message=ChatCompletionMessage(role="assistant", content=resp_content), logprobs=None, ) - ] + ], ) default_resp_chunk = ChatCompletionChunk( @@ -42,26 +43,25 @@ default_resp_chunk = ChatCompletionChunk( created=default_resp.created, choices=[ AChoice( - delta=ChoiceDelta( - content=resp_content, - role="assistant" - ), + delta=ChoiceDelta(content=resp_content, role="assistant"), finish_reason="stop", index=0, logprobs=None, ) - ] + ], ) prompt_msg = "who are you" messages = [{"role": "user", "content": prompt_msg}] -async def mock_openai_acompletions_create(self, stream: bool=False, **kwargs) -> ChatCompletionChunk: +async def mock_openai_acompletions_create(self, stream: bool = False, **kwargs) -> ChatCompletionChunk: if stream: + class Iterator(object): async def __aiter__(self): yield default_resp_chunk + return Iterator() else: return default_resp @@ -75,7 +75,9 @@ async def test_openllm_acompletion(mocker): openllm_gpt.model = "llama-v2-13b-chat" openllm_gpt._update_costs(usage=CompletionUsage(prompt_tokens=100, completion_tokens=100, total_tokens=200)) - assert openllm_gpt.get_costs() == Costs(total_prompt_tokens=100, total_completion_tokens=100, total_cost=0, total_budget=0) + assert openllm_gpt.get_costs() == Costs( + total_prompt_tokens=100, total_completion_tokens=100, total_cost=0, total_budget=0 + ) resp = await openllm_gpt.acompletion(messages) assert resp.choices[0].message.content in resp_content diff --git a/tests/metagpt/provider/zhipuai/test_async_sse_client.py b/tests/metagpt/provider/zhipuai/test_async_sse_client.py index af75e40df..9e5bd5f2e 100644 --- a/tests/metagpt/provider/zhipuai/test_async_sse_client.py +++ b/tests/metagpt/provider/zhipuai/test_async_sse_client.py @@ -1,6 +1,6 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- -# @Desc : +# @Desc : import pytest diff --git a/tests/metagpt/provider/zhipuai/test_zhipu_model_api.py b/tests/metagpt/provider/zhipuai/test_zhipu_model_api.py index b3838e813..83ae2de60 100644 --- a/tests/metagpt/provider/zhipuai/test_zhipu_model_api.py +++ b/tests/metagpt/provider/zhipuai/test_zhipu_model_api.py @@ -3,15 +3,14 @@ # @Desc : from typing import Any, Tuple -import pytest +import pytest import zhipuai from zhipuai.model_api.api import InvokeType from zhipuai.utils.http_client import headers as zhipuai_default_headers from metagpt.provider.zhipuai.zhipu_model_api import ZhiPuModelAPI - api_key = "xxx.xxx" zhipuai.api_key = api_key @@ -36,5 +35,7 @@ async def test_zhipu_model_api(mocker): assert url_suffix == "/paas/v3/model-api/chatglm_turbo/invoke" mocker.patch("metagpt.provider.general_api_requestor.GeneralAPIRequestor.arequest", mock_requestor_arequest) - result = await ZhiPuModelAPI.arequest(InvokeType.SYNC, stream=False, method="get", headers={}, kwargs={"model": "chatglm_turbo"}) + result = await ZhiPuModelAPI.arequest( + InvokeType.SYNC, stream=False, method="get", headers={}, kwargs={"model": "chatglm_turbo"} + ) assert result == default_resp