From 1d35cab9d77adb2828a579d5e398b176d672e920 Mon Sep 17 00:00:00 2001 From: better629 Date: Tue, 2 Jan 2024 21:21:10 +0800 Subject: [PATCH] rm useless code and increase UT ratio --- metagpt/provider/general_api_base.py | 98 ------------------- .../metagpt/provider/test_general_api_base.py | 37 +++++++ tests/metagpt/provider/test_human_provider.py | 20 ++-- tests/metagpt/provider/test_spark_api.py | 16 ++- .../provider/zhipuai/test_async_sse_client.py | 8 ++ .../provider/zhipuai/test_zhipu_model_api.py | 5 +- 6 files changed, 75 insertions(+), 109 deletions(-) diff --git a/metagpt/provider/general_api_base.py b/metagpt/provider/general_api_base.py index bbe03774c..1b9149396 100644 --- a/metagpt/provider/general_api_base.py +++ b/metagpt/provider/general_api_base.py @@ -15,7 +15,6 @@ from enum import Enum from typing import ( AsyncGenerator, AsyncIterator, - Callable, Dict, Iterator, Optional, @@ -240,54 +239,6 @@ class APIRequestor: self.api_version = api_version or openai.api_version self.organization = organization or openai.organization - def _check_polling_response(self, response: OpenAIResponse, predicate: Callable[[OpenAIResponse], bool]): - if not predicate(response): - return - error_data = response.data["error"] - message = error_data.get("message", "Operation failed") - code = error_data.get("code") - raise openai.APIError(message=message, body=dict(code=code)) - - def _poll( - self, method, url, until, failed, params=None, headers=None, interval=None, delay=None - ) -> Tuple[Iterator[OpenAIResponse], bool, str]: - if delay: - time.sleep(delay) - - response, b, api_key = self.request(method, url, params, headers) - self._check_polling_response(response, failed) - start_time = time.time() - while not until(response): - if time.time() - start_time > TIMEOUT_SECS: - raise openai.APITimeoutError("Operation polling timed out.") - - time.sleep(interval or response.retry_after or 10) - response, b, api_key = self.request(method, url, params, headers) - self._check_polling_response(response, failed) - - response.data = response.data["result"] - return response, b, api_key - - async def _apoll( - self, method, url, until, failed, params=None, headers=None, interval=None, delay=None - ) -> Tuple[Iterator[OpenAIResponse], bool, str]: - if delay: - await asyncio.sleep(delay) - - response, b, api_key = await self.arequest(method, url, params, headers) - self._check_polling_response(response, failed) - start_time = time.time() - while not until(response): - if time.time() - start_time > TIMEOUT_SECS: - raise openai.APITimeoutError("Operation polling timed out.") - - await asyncio.sleep(interval or response.retry_after or 10) - response, b, api_key = await self.arequest(method, url, params, headers) - self._check_polling_response(response, failed) - - response.data = response.data["result"] - return response, b, api_key - @overload def request( self, @@ -469,55 +420,6 @@ class APIRequestor: await ctx.__aexit__(None, None, None) return resp, got_stream, self.api_key - def handle_error_response(self, rbody, rcode, resp, rheaders, stream_error=False): - try: - error_data = resp["error"] - except (KeyError, TypeError): - raise openai.APIError( - "Invalid response object from API: %r (HTTP response code " "was %d)" % (rbody, rcode) - ) - - if "internal_message" in error_data: - error_data["message"] += "\n\n" + error_data["internal_message"] - - log_info( - "LLM API error received", - error_code=error_data.get("code"), - error_type=error_data.get("type"), - error_message=error_data.get("message"), - error_param=error_data.get("param"), - stream_error=stream_error, - ) - - # Rate limits were previously coded as 400's with code 'rate_limit' - if rcode == 429: - return openai.RateLimitError(f"{error_data.get('message')} {rbody} {rcode} {resp} {rheaders}", body=rbody) - elif rcode in [400, 404, 415]: - return openai.BadRequestError( - message=f'{error_data.get("message")}, {error_data.get("param")}, {error_data.get("code")} {rbody} {rcode} {resp} {rheaders}', - body=rbody, - ) - elif rcode == 401: - return openai.AuthenticationError( - f"{error_data.get('message')} {rbody} {rcode} {resp} {rheaders}", body=rbody - ) - elif rcode == 403: - return openai.PermissionDeniedError( - f"{error_data.get('message')} {rbody} {rcode} {resp} {rheaders}", body=rbody - ) - elif rcode == 409: - return openai.ConflictError(f"{error_data.get('message')} {rbody} {rcode} {resp} {rheaders}", body=rbody) - elif stream_error: - # TODO: we will soon attach status codes to stream errors - parts = [error_data.get("message"), "(Error occurred while streaming.)"] - message = " ".join([p for p in parts if p is not None]) - return openai.APIError(f"{message} {rbody} {rcode} {resp} {rheaders}", body=rbody) - else: - return openai.APIError( - f"{error_data.get('message')} {rbody} {rcode} {resp} {rheaders}", - body=rbody, - ) - def request_headers(self, method: str, extra, request_id: Optional[str]) -> Dict[str, str]: user_agent = "LLM/v1 PythonBindings/%s" % (version.VERSION,) diff --git a/tests/metagpt/provider/test_general_api_base.py b/tests/metagpt/provider/test_general_api_base.py index ae768ce95..b8ab619f7 100644 --- a/tests/metagpt/provider/test_general_api_base.py +++ b/tests/metagpt/provider/test_general_api_base.py @@ -14,11 +14,14 @@ from metagpt.provider.general_api_base import ( APIRequestor, ApiType, OpenAIResponse, + _aiohttp_proxies_arg, + _build_api_url, _make_session, _requests_proxies_arg, log_debug, log_info, log_warn, + logfmt, parse_stream, parse_stream_helper, ) @@ -36,6 +39,10 @@ def test_basic(): log_warn("warn") log_info("info") + logfmt({"k1": b"v1", "k2": 1, "k3": "a b"}) + + _build_api_url(url="http://www.baidu.com/s?wd=", query="baidu") + def test_openai_response(): resp = OpenAIResponse(data=[], headers={"retry-after": 3}) @@ -53,11 +60,18 @@ def test_proxy(): assert _requests_proxies_arg(proxy=proxy) == {"http": proxy, "https": proxy} proxy_dict = {"http": proxy} assert _requests_proxies_arg(proxy=proxy_dict) == proxy_dict + assert _aiohttp_proxies_arg(proxy_dict) == proxy proxy_dict = {"https": proxy} assert _requests_proxies_arg(proxy=proxy_dict) == proxy_dict + assert _aiohttp_proxies_arg(proxy_dict) == proxy assert _make_session() is not None + assert _aiohttp_proxies_arg(None) is None + assert _aiohttp_proxies_arg("test") == "test" + with pytest.raises(ValueError): + _aiohttp_proxies_arg(-1) + def test_parse_stream(): assert parse_stream_helper(None) is None @@ -83,6 +97,29 @@ async def mock_interpret_async_response( return b"baidu", True +def test_requestor_headers(): + # validate_headers + headers = api_requestor._validate_headers(None) + assert not headers + with pytest.raises(Exception): + api_requestor._validate_headers(-1) + with pytest.raises(Exception): + api_requestor._validate_headers({1: 2}) + with pytest.raises(Exception): + api_requestor._validate_headers({"test": 1}) + supplied_headers = {"test": "test"} + assert api_requestor._validate_headers(supplied_headers) == supplied_headers + + api_requestor.organization = "test" + api_requestor.api_version = "test123" + api_requestor.api_type = ApiType.OPEN_AI + request_id = "test123" + headers = api_requestor.request_headers(method="post", extra={}, request_id=request_id) + assert headers["LLM-Organization"] == api_requestor.organization + assert headers["LLM-Version"] == api_requestor.api_version + assert headers["X-Request-Id"] == request_id + + def test_api_requestor(mocker): mocker.patch("metagpt.provider.general_api_base.APIRequestor._interpret_response", mock_interpret_response) resp, _, _ = api_requestor.request(method="get", url="/s?wd=baidu") diff --git a/tests/metagpt/provider/test_human_provider.py b/tests/metagpt/provider/test_human_provider.py index 8ba532781..3f63410c0 100644 --- a/tests/metagpt/provider/test_human_provider.py +++ b/tests/metagpt/provider/test_human_provider.py @@ -7,23 +7,25 @@ import pytest from metagpt.provider.human_provider import HumanProvider resp_content = "test" - - -def mock_llm_ask(msg: str, timeout: int = 3) -> str: - return resp_content - - -async def mock_llm_aask(msg: str, timeout: int = 3) -> str: - return mock_llm_ask(msg) +resp_exit = "exit" @pytest.mark.asyncio async def test_async_human_provider(mocker): - mocker.patch("metagpt.provider.human_provider.HumanProvider.aask", mock_llm_aask) + mocker.patch("builtins.input", lambda _: resp_content) human_provider = HumanProvider() + resp = human_provider.ask(resp_content) + assert resp == resp_content resp = await human_provider.aask(None) assert resp_content == resp + mocker.patch("builtins.input", lambda _: resp_exit) + with pytest.raises(SystemExit): + human_provider.ask(resp_exit) + resp = await human_provider.acompletion([]) assert not resp + + resp = await human_provider.acompletion_text([]) + assert resp == "" diff --git a/tests/metagpt/provider/test_spark_api.py b/tests/metagpt/provider/test_spark_api.py index 6d5a0e1f6..ee2d02c97 100644 --- a/tests/metagpt/provider/test_spark_api.py +++ b/tests/metagpt/provider/test_spark_api.py @@ -17,10 +17,23 @@ prompt_msg = "who are you" resp_content = "I'm Spark" -def test_get_msg_from_web(): +class MockWebSocketApp(object): + def __init__(self, ws_url, on_message=None, on_error=None, on_close=None, on_open=None): + pass + + def run_forever(self, sslopt=None): + pass + + +def test_get_msg_from_web(mocker): + mocker.patch("websocket.WebSocketApp", MockWebSocketApp) + get_msg_from_web = GetMessageFromWeb(text=prompt_msg) assert get_msg_from_web.gen_params()["parameter"]["chat"]["domain"] == "xxxxxx" + ret = get_msg_from_web.run() + assert ret == "" + def mock_spark_get_msg_from_web_run(self) -> str: return resp_content @@ -29,6 +42,7 @@ def mock_spark_get_msg_from_web_run(self) -> str: @pytest.mark.asyncio async def test_spark_acompletion(mocker): mocker.patch("metagpt.provider.spark_api.GetMessageFromWeb.run", mock_spark_get_msg_from_web_run) + spark_gpt = SparkLLM() resp = await spark_gpt.acompletion([]) diff --git a/tests/metagpt/provider/zhipuai/test_async_sse_client.py b/tests/metagpt/provider/zhipuai/test_async_sse_client.py index 9e5bd5f2e..2649f595b 100644 --- a/tests/metagpt/provider/zhipuai/test_async_sse_client.py +++ b/tests/metagpt/provider/zhipuai/test_async_sse_client.py @@ -16,3 +16,11 @@ async def test_async_sse_client(): async_sse_client = AsyncSSEClient(event_source=Iterator()) async for event in async_sse_client.async_events(): assert event.data, "test_value" + + class InvalidIterator(object): + async def __aiter__(self): + yield b"invalid: test_value" + + async_sse_client = AsyncSSEClient(event_source=InvalidIterator()) + async for event in async_sse_client.async_events(): + assert not event diff --git a/tests/metagpt/provider/zhipuai/test_zhipu_model_api.py b/tests/metagpt/provider/zhipuai/test_zhipu_model_api.py index 83ae2de60..1f0a42fa6 100644 --- a/tests/metagpt/provider/zhipuai/test_zhipu_model_api.py +++ b/tests/metagpt/provider/zhipuai/test_zhipu_model_api.py @@ -14,7 +14,7 @@ from metagpt.provider.zhipuai.zhipu_model_api import ZhiPuModelAPI api_key = "xxx.xxx" zhipuai.api_key = api_key -default_resp = {"result": "test response"} +default_resp = b'{"result": "test response"}' async def mock_requestor_arequest(self, **kwargs) -> Tuple[Any, Any, str]: @@ -39,3 +39,6 @@ async def test_zhipu_model_api(mocker): InvokeType.SYNC, stream=False, method="get", headers={}, kwargs={"model": "chatglm_turbo"} ) assert result == default_resp + + result = await ZhiPuModelAPI.ainvoke() + assert result["result"] == "test response"