mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-05-09 15:52:38 +02:00
rm useless code and increase UT ratio
This commit is contained in:
parent
a4c7156d57
commit
1d35cab9d7
6 changed files with 75 additions and 109 deletions
|
|
@ -14,11 +14,14 @@ from metagpt.provider.general_api_base import (
|
|||
APIRequestor,
|
||||
ApiType,
|
||||
OpenAIResponse,
|
||||
_aiohttp_proxies_arg,
|
||||
_build_api_url,
|
||||
_make_session,
|
||||
_requests_proxies_arg,
|
||||
log_debug,
|
||||
log_info,
|
||||
log_warn,
|
||||
logfmt,
|
||||
parse_stream,
|
||||
parse_stream_helper,
|
||||
)
|
||||
|
|
@ -36,6 +39,10 @@ def test_basic():
|
|||
log_warn("warn")
|
||||
log_info("info")
|
||||
|
||||
logfmt({"k1": b"v1", "k2": 1, "k3": "a b"})
|
||||
|
||||
_build_api_url(url="http://www.baidu.com/s?wd=", query="baidu")
|
||||
|
||||
|
||||
def test_openai_response():
|
||||
resp = OpenAIResponse(data=[], headers={"retry-after": 3})
|
||||
|
|
@ -53,11 +60,18 @@ def test_proxy():
|
|||
assert _requests_proxies_arg(proxy=proxy) == {"http": proxy, "https": proxy}
|
||||
proxy_dict = {"http": proxy}
|
||||
assert _requests_proxies_arg(proxy=proxy_dict) == proxy_dict
|
||||
assert _aiohttp_proxies_arg(proxy_dict) == proxy
|
||||
proxy_dict = {"https": proxy}
|
||||
assert _requests_proxies_arg(proxy=proxy_dict) == proxy_dict
|
||||
assert _aiohttp_proxies_arg(proxy_dict) == proxy
|
||||
|
||||
assert _make_session() is not None
|
||||
|
||||
assert _aiohttp_proxies_arg(None) is None
|
||||
assert _aiohttp_proxies_arg("test") == "test"
|
||||
with pytest.raises(ValueError):
|
||||
_aiohttp_proxies_arg(-1)
|
||||
|
||||
|
||||
def test_parse_stream():
|
||||
assert parse_stream_helper(None) is None
|
||||
|
|
@ -83,6 +97,29 @@ async def mock_interpret_async_response(
|
|||
return b"baidu", True
|
||||
|
||||
|
||||
def test_requestor_headers():
|
||||
# validate_headers
|
||||
headers = api_requestor._validate_headers(None)
|
||||
assert not headers
|
||||
with pytest.raises(Exception):
|
||||
api_requestor._validate_headers(-1)
|
||||
with pytest.raises(Exception):
|
||||
api_requestor._validate_headers({1: 2})
|
||||
with pytest.raises(Exception):
|
||||
api_requestor._validate_headers({"test": 1})
|
||||
supplied_headers = {"test": "test"}
|
||||
assert api_requestor._validate_headers(supplied_headers) == supplied_headers
|
||||
|
||||
api_requestor.organization = "test"
|
||||
api_requestor.api_version = "test123"
|
||||
api_requestor.api_type = ApiType.OPEN_AI
|
||||
request_id = "test123"
|
||||
headers = api_requestor.request_headers(method="post", extra={}, request_id=request_id)
|
||||
assert headers["LLM-Organization"] == api_requestor.organization
|
||||
assert headers["LLM-Version"] == api_requestor.api_version
|
||||
assert headers["X-Request-Id"] == request_id
|
||||
|
||||
|
||||
def test_api_requestor(mocker):
|
||||
mocker.patch("metagpt.provider.general_api_base.APIRequestor._interpret_response", mock_interpret_response)
|
||||
resp, _, _ = api_requestor.request(method="get", url="/s?wd=baidu")
|
||||
|
|
|
|||
|
|
@ -7,23 +7,25 @@ import pytest
|
|||
from metagpt.provider.human_provider import HumanProvider
|
||||
|
||||
resp_content = "test"
|
||||
|
||||
|
||||
def mock_llm_ask(msg: str, timeout: int = 3) -> str:
|
||||
return resp_content
|
||||
|
||||
|
||||
async def mock_llm_aask(msg: str, timeout: int = 3) -> str:
|
||||
return mock_llm_ask(msg)
|
||||
resp_exit = "exit"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_human_provider(mocker):
|
||||
mocker.patch("metagpt.provider.human_provider.HumanProvider.aask", mock_llm_aask)
|
||||
mocker.patch("builtins.input", lambda _: resp_content)
|
||||
human_provider = HumanProvider()
|
||||
|
||||
resp = human_provider.ask(resp_content)
|
||||
assert resp == resp_content
|
||||
resp = await human_provider.aask(None)
|
||||
assert resp_content == resp
|
||||
|
||||
mocker.patch("builtins.input", lambda _: resp_exit)
|
||||
with pytest.raises(SystemExit):
|
||||
human_provider.ask(resp_exit)
|
||||
|
||||
resp = await human_provider.acompletion([])
|
||||
assert not resp
|
||||
|
||||
resp = await human_provider.acompletion_text([])
|
||||
assert resp == ""
|
||||
|
|
|
|||
|
|
@ -17,10 +17,23 @@ prompt_msg = "who are you"
|
|||
resp_content = "I'm Spark"
|
||||
|
||||
|
||||
def test_get_msg_from_web():
|
||||
class MockWebSocketApp(object):
|
||||
def __init__(self, ws_url, on_message=None, on_error=None, on_close=None, on_open=None):
|
||||
pass
|
||||
|
||||
def run_forever(self, sslopt=None):
|
||||
pass
|
||||
|
||||
|
||||
def test_get_msg_from_web(mocker):
|
||||
mocker.patch("websocket.WebSocketApp", MockWebSocketApp)
|
||||
|
||||
get_msg_from_web = GetMessageFromWeb(text=prompt_msg)
|
||||
assert get_msg_from_web.gen_params()["parameter"]["chat"]["domain"] == "xxxxxx"
|
||||
|
||||
ret = get_msg_from_web.run()
|
||||
assert ret == ""
|
||||
|
||||
|
||||
def mock_spark_get_msg_from_web_run(self) -> str:
|
||||
return resp_content
|
||||
|
|
@ -29,6 +42,7 @@ def mock_spark_get_msg_from_web_run(self) -> str:
|
|||
@pytest.mark.asyncio
|
||||
async def test_spark_acompletion(mocker):
|
||||
mocker.patch("metagpt.provider.spark_api.GetMessageFromWeb.run", mock_spark_get_msg_from_web_run)
|
||||
|
||||
spark_gpt = SparkLLM()
|
||||
|
||||
resp = await spark_gpt.acompletion([])
|
||||
|
|
|
|||
|
|
@ -16,3 +16,11 @@ async def test_async_sse_client():
|
|||
async_sse_client = AsyncSSEClient(event_source=Iterator())
|
||||
async for event in async_sse_client.async_events():
|
||||
assert event.data, "test_value"
|
||||
|
||||
class InvalidIterator(object):
|
||||
async def __aiter__(self):
|
||||
yield b"invalid: test_value"
|
||||
|
||||
async_sse_client = AsyncSSEClient(event_source=InvalidIterator())
|
||||
async for event in async_sse_client.async_events():
|
||||
assert not event
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@ from metagpt.provider.zhipuai.zhipu_model_api import ZhiPuModelAPI
|
|||
api_key = "xxx.xxx"
|
||||
zhipuai.api_key = api_key
|
||||
|
||||
default_resp = {"result": "test response"}
|
||||
default_resp = b'{"result": "test response"}'
|
||||
|
||||
|
||||
async def mock_requestor_arequest(self, **kwargs) -> Tuple[Any, Any, str]:
|
||||
|
|
@ -39,3 +39,6 @@ async def test_zhipu_model_api(mocker):
|
|||
InvokeType.SYNC, stream=False, method="get", headers={}, kwargs={"model": "chatglm_turbo"}
|
||||
)
|
||||
assert result == default_resp
|
||||
|
||||
result = await ZhiPuModelAPI.ainvoke()
|
||||
assert result["result"] == "test response"
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue