rm useless code and increase UT ratio

This commit is contained in:
better629 2024-01-02 21:21:10 +08:00
parent a4c7156d57
commit 1d35cab9d7
6 changed files with 75 additions and 109 deletions

View file

@ -14,11 +14,14 @@ from metagpt.provider.general_api_base import (
APIRequestor,
ApiType,
OpenAIResponse,
_aiohttp_proxies_arg,
_build_api_url,
_make_session,
_requests_proxies_arg,
log_debug,
log_info,
log_warn,
logfmt,
parse_stream,
parse_stream_helper,
)
@ -36,6 +39,10 @@ def test_basic():
log_warn("warn")
log_info("info")
logfmt({"k1": b"v1", "k2": 1, "k3": "a b"})
_build_api_url(url="http://www.baidu.com/s?wd=", query="baidu")
def test_openai_response():
resp = OpenAIResponse(data=[], headers={"retry-after": 3})
@ -53,11 +60,18 @@ def test_proxy():
assert _requests_proxies_arg(proxy=proxy) == {"http": proxy, "https": proxy}
proxy_dict = {"http": proxy}
assert _requests_proxies_arg(proxy=proxy_dict) == proxy_dict
assert _aiohttp_proxies_arg(proxy_dict) == proxy
proxy_dict = {"https": proxy}
assert _requests_proxies_arg(proxy=proxy_dict) == proxy_dict
assert _aiohttp_proxies_arg(proxy_dict) == proxy
assert _make_session() is not None
assert _aiohttp_proxies_arg(None) is None
assert _aiohttp_proxies_arg("test") == "test"
with pytest.raises(ValueError):
_aiohttp_proxies_arg(-1)
def test_parse_stream():
assert parse_stream_helper(None) is None
@ -83,6 +97,29 @@ async def mock_interpret_async_response(
return b"baidu", True
def test_requestor_headers():
# validate_headers
headers = api_requestor._validate_headers(None)
assert not headers
with pytest.raises(Exception):
api_requestor._validate_headers(-1)
with pytest.raises(Exception):
api_requestor._validate_headers({1: 2})
with pytest.raises(Exception):
api_requestor._validate_headers({"test": 1})
supplied_headers = {"test": "test"}
assert api_requestor._validate_headers(supplied_headers) == supplied_headers
api_requestor.organization = "test"
api_requestor.api_version = "test123"
api_requestor.api_type = ApiType.OPEN_AI
request_id = "test123"
headers = api_requestor.request_headers(method="post", extra={}, request_id=request_id)
assert headers["LLM-Organization"] == api_requestor.organization
assert headers["LLM-Version"] == api_requestor.api_version
assert headers["X-Request-Id"] == request_id
def test_api_requestor(mocker):
mocker.patch("metagpt.provider.general_api_base.APIRequestor._interpret_response", mock_interpret_response)
resp, _, _ = api_requestor.request(method="get", url="/s?wd=baidu")

View file

@ -7,23 +7,25 @@ import pytest
from metagpt.provider.human_provider import HumanProvider
resp_content = "test"
def mock_llm_ask(msg: str, timeout: int = 3) -> str:
return resp_content
async def mock_llm_aask(msg: str, timeout: int = 3) -> str:
return mock_llm_ask(msg)
resp_exit = "exit"
@pytest.mark.asyncio
async def test_async_human_provider(mocker):
mocker.patch("metagpt.provider.human_provider.HumanProvider.aask", mock_llm_aask)
mocker.patch("builtins.input", lambda _: resp_content)
human_provider = HumanProvider()
resp = human_provider.ask(resp_content)
assert resp == resp_content
resp = await human_provider.aask(None)
assert resp_content == resp
mocker.patch("builtins.input", lambda _: resp_exit)
with pytest.raises(SystemExit):
human_provider.ask(resp_exit)
resp = await human_provider.acompletion([])
assert not resp
resp = await human_provider.acompletion_text([])
assert resp == ""

View file

@ -17,10 +17,23 @@ prompt_msg = "who are you"
resp_content = "I'm Spark"
def test_get_msg_from_web():
class MockWebSocketApp(object):
def __init__(self, ws_url, on_message=None, on_error=None, on_close=None, on_open=None):
pass
def run_forever(self, sslopt=None):
pass
def test_get_msg_from_web(mocker):
mocker.patch("websocket.WebSocketApp", MockWebSocketApp)
get_msg_from_web = GetMessageFromWeb(text=prompt_msg)
assert get_msg_from_web.gen_params()["parameter"]["chat"]["domain"] == "xxxxxx"
ret = get_msg_from_web.run()
assert ret == ""
def mock_spark_get_msg_from_web_run(self) -> str:
return resp_content
@ -29,6 +42,7 @@ def mock_spark_get_msg_from_web_run(self) -> str:
@pytest.mark.asyncio
async def test_spark_acompletion(mocker):
mocker.patch("metagpt.provider.spark_api.GetMessageFromWeb.run", mock_spark_get_msg_from_web_run)
spark_gpt = SparkLLM()
resp = await spark_gpt.acompletion([])

View file

@ -16,3 +16,11 @@ async def test_async_sse_client():
async_sse_client = AsyncSSEClient(event_source=Iterator())
async for event in async_sse_client.async_events():
assert event.data, "test_value"
class InvalidIterator(object):
async def __aiter__(self):
yield b"invalid: test_value"
async_sse_client = AsyncSSEClient(event_source=InvalidIterator())
async for event in async_sse_client.async_events():
assert not event

View file

@ -14,7 +14,7 @@ from metagpt.provider.zhipuai.zhipu_model_api import ZhiPuModelAPI
api_key = "xxx.xxx"
zhipuai.api_key = api_key
default_resp = {"result": "test response"}
default_resp = b'{"result": "test response"}'
async def mock_requestor_arequest(self, **kwargs) -> Tuple[Any, Any, str]:
@ -39,3 +39,6 @@ async def test_zhipu_model_api(mocker):
InvokeType.SYNC, stream=False, method="get", headers={}, kwargs={"model": "chatglm_turbo"}
)
assert result == default_resp
result = await ZhiPuModelAPI.ainvoke()
assert result["result"] == "test response"