mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-05-21 14:05:17 +02:00
rm useless code and increase UT ratio
This commit is contained in:
parent
a4c7156d57
commit
1d35cab9d7
6 changed files with 75 additions and 109 deletions
|
|
@ -15,7 +15,6 @@ from enum import Enum
|
|||
from typing import (
|
||||
AsyncGenerator,
|
||||
AsyncIterator,
|
||||
Callable,
|
||||
Dict,
|
||||
Iterator,
|
||||
Optional,
|
||||
|
|
@ -240,54 +239,6 @@ class APIRequestor:
|
|||
self.api_version = api_version or openai.api_version
|
||||
self.organization = organization or openai.organization
|
||||
|
||||
def _check_polling_response(self, response: OpenAIResponse, predicate: Callable[[OpenAIResponse], bool]):
|
||||
if not predicate(response):
|
||||
return
|
||||
error_data = response.data["error"]
|
||||
message = error_data.get("message", "Operation failed")
|
||||
code = error_data.get("code")
|
||||
raise openai.APIError(message=message, body=dict(code=code))
|
||||
|
||||
def _poll(
|
||||
self, method, url, until, failed, params=None, headers=None, interval=None, delay=None
|
||||
) -> Tuple[Iterator[OpenAIResponse], bool, str]:
|
||||
if delay:
|
||||
time.sleep(delay)
|
||||
|
||||
response, b, api_key = self.request(method, url, params, headers)
|
||||
self._check_polling_response(response, failed)
|
||||
start_time = time.time()
|
||||
while not until(response):
|
||||
if time.time() - start_time > TIMEOUT_SECS:
|
||||
raise openai.APITimeoutError("Operation polling timed out.")
|
||||
|
||||
time.sleep(interval or response.retry_after or 10)
|
||||
response, b, api_key = self.request(method, url, params, headers)
|
||||
self._check_polling_response(response, failed)
|
||||
|
||||
response.data = response.data["result"]
|
||||
return response, b, api_key
|
||||
|
||||
async def _apoll(
|
||||
self, method, url, until, failed, params=None, headers=None, interval=None, delay=None
|
||||
) -> Tuple[Iterator[OpenAIResponse], bool, str]:
|
||||
if delay:
|
||||
await asyncio.sleep(delay)
|
||||
|
||||
response, b, api_key = await self.arequest(method, url, params, headers)
|
||||
self._check_polling_response(response, failed)
|
||||
start_time = time.time()
|
||||
while not until(response):
|
||||
if time.time() - start_time > TIMEOUT_SECS:
|
||||
raise openai.APITimeoutError("Operation polling timed out.")
|
||||
|
||||
await asyncio.sleep(interval or response.retry_after or 10)
|
||||
response, b, api_key = await self.arequest(method, url, params, headers)
|
||||
self._check_polling_response(response, failed)
|
||||
|
||||
response.data = response.data["result"]
|
||||
return response, b, api_key
|
||||
|
||||
@overload
|
||||
def request(
|
||||
self,
|
||||
|
|
@ -469,55 +420,6 @@ class APIRequestor:
|
|||
await ctx.__aexit__(None, None, None)
|
||||
return resp, got_stream, self.api_key
|
||||
|
||||
def handle_error_response(self, rbody, rcode, resp, rheaders, stream_error=False):
|
||||
try:
|
||||
error_data = resp["error"]
|
||||
except (KeyError, TypeError):
|
||||
raise openai.APIError(
|
||||
"Invalid response object from API: %r (HTTP response code " "was %d)" % (rbody, rcode)
|
||||
)
|
||||
|
||||
if "internal_message" in error_data:
|
||||
error_data["message"] += "\n\n" + error_data["internal_message"]
|
||||
|
||||
log_info(
|
||||
"LLM API error received",
|
||||
error_code=error_data.get("code"),
|
||||
error_type=error_data.get("type"),
|
||||
error_message=error_data.get("message"),
|
||||
error_param=error_data.get("param"),
|
||||
stream_error=stream_error,
|
||||
)
|
||||
|
||||
# Rate limits were previously coded as 400's with code 'rate_limit'
|
||||
if rcode == 429:
|
||||
return openai.RateLimitError(f"{error_data.get('message')} {rbody} {rcode} {resp} {rheaders}", body=rbody)
|
||||
elif rcode in [400, 404, 415]:
|
||||
return openai.BadRequestError(
|
||||
message=f'{error_data.get("message")}, {error_data.get("param")}, {error_data.get("code")} {rbody} {rcode} {resp} {rheaders}',
|
||||
body=rbody,
|
||||
)
|
||||
elif rcode == 401:
|
||||
return openai.AuthenticationError(
|
||||
f"{error_data.get('message')} {rbody} {rcode} {resp} {rheaders}", body=rbody
|
||||
)
|
||||
elif rcode == 403:
|
||||
return openai.PermissionDeniedError(
|
||||
f"{error_data.get('message')} {rbody} {rcode} {resp} {rheaders}", body=rbody
|
||||
)
|
||||
elif rcode == 409:
|
||||
return openai.ConflictError(f"{error_data.get('message')} {rbody} {rcode} {resp} {rheaders}", body=rbody)
|
||||
elif stream_error:
|
||||
# TODO: we will soon attach status codes to stream errors
|
||||
parts = [error_data.get("message"), "(Error occurred while streaming.)"]
|
||||
message = " ".join([p for p in parts if p is not None])
|
||||
return openai.APIError(f"{message} {rbody} {rcode} {resp} {rheaders}", body=rbody)
|
||||
else:
|
||||
return openai.APIError(
|
||||
f"{error_data.get('message')} {rbody} {rcode} {resp} {rheaders}",
|
||||
body=rbody,
|
||||
)
|
||||
|
||||
def request_headers(self, method: str, extra, request_id: Optional[str]) -> Dict[str, str]:
|
||||
user_agent = "LLM/v1 PythonBindings/%s" % (version.VERSION,)
|
||||
|
||||
|
|
|
|||
|
|
@ -14,11 +14,14 @@ from metagpt.provider.general_api_base import (
|
|||
APIRequestor,
|
||||
ApiType,
|
||||
OpenAIResponse,
|
||||
_aiohttp_proxies_arg,
|
||||
_build_api_url,
|
||||
_make_session,
|
||||
_requests_proxies_arg,
|
||||
log_debug,
|
||||
log_info,
|
||||
log_warn,
|
||||
logfmt,
|
||||
parse_stream,
|
||||
parse_stream_helper,
|
||||
)
|
||||
|
|
@ -36,6 +39,10 @@ def test_basic():
|
|||
log_warn("warn")
|
||||
log_info("info")
|
||||
|
||||
logfmt({"k1": b"v1", "k2": 1, "k3": "a b"})
|
||||
|
||||
_build_api_url(url="http://www.baidu.com/s?wd=", query="baidu")
|
||||
|
||||
|
||||
def test_openai_response():
|
||||
resp = OpenAIResponse(data=[], headers={"retry-after": 3})
|
||||
|
|
@ -53,11 +60,18 @@ def test_proxy():
|
|||
assert _requests_proxies_arg(proxy=proxy) == {"http": proxy, "https": proxy}
|
||||
proxy_dict = {"http": proxy}
|
||||
assert _requests_proxies_arg(proxy=proxy_dict) == proxy_dict
|
||||
assert _aiohttp_proxies_arg(proxy_dict) == proxy
|
||||
proxy_dict = {"https": proxy}
|
||||
assert _requests_proxies_arg(proxy=proxy_dict) == proxy_dict
|
||||
assert _aiohttp_proxies_arg(proxy_dict) == proxy
|
||||
|
||||
assert _make_session() is not None
|
||||
|
||||
assert _aiohttp_proxies_arg(None) is None
|
||||
assert _aiohttp_proxies_arg("test") == "test"
|
||||
with pytest.raises(ValueError):
|
||||
_aiohttp_proxies_arg(-1)
|
||||
|
||||
|
||||
def test_parse_stream():
|
||||
assert parse_stream_helper(None) is None
|
||||
|
|
@ -83,6 +97,29 @@ async def mock_interpret_async_response(
|
|||
return b"baidu", True
|
||||
|
||||
|
||||
def test_requestor_headers():
|
||||
# validate_headers
|
||||
headers = api_requestor._validate_headers(None)
|
||||
assert not headers
|
||||
with pytest.raises(Exception):
|
||||
api_requestor._validate_headers(-1)
|
||||
with pytest.raises(Exception):
|
||||
api_requestor._validate_headers({1: 2})
|
||||
with pytest.raises(Exception):
|
||||
api_requestor._validate_headers({"test": 1})
|
||||
supplied_headers = {"test": "test"}
|
||||
assert api_requestor._validate_headers(supplied_headers) == supplied_headers
|
||||
|
||||
api_requestor.organization = "test"
|
||||
api_requestor.api_version = "test123"
|
||||
api_requestor.api_type = ApiType.OPEN_AI
|
||||
request_id = "test123"
|
||||
headers = api_requestor.request_headers(method="post", extra={}, request_id=request_id)
|
||||
assert headers["LLM-Organization"] == api_requestor.organization
|
||||
assert headers["LLM-Version"] == api_requestor.api_version
|
||||
assert headers["X-Request-Id"] == request_id
|
||||
|
||||
|
||||
def test_api_requestor(mocker):
|
||||
mocker.patch("metagpt.provider.general_api_base.APIRequestor._interpret_response", mock_interpret_response)
|
||||
resp, _, _ = api_requestor.request(method="get", url="/s?wd=baidu")
|
||||
|
|
|
|||
|
|
@ -7,23 +7,25 @@ import pytest
|
|||
from metagpt.provider.human_provider import HumanProvider
|
||||
|
||||
resp_content = "test"
|
||||
|
||||
|
||||
def mock_llm_ask(msg: str, timeout: int = 3) -> str:
|
||||
return resp_content
|
||||
|
||||
|
||||
async def mock_llm_aask(msg: str, timeout: int = 3) -> str:
|
||||
return mock_llm_ask(msg)
|
||||
resp_exit = "exit"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_human_provider(mocker):
|
||||
mocker.patch("metagpt.provider.human_provider.HumanProvider.aask", mock_llm_aask)
|
||||
mocker.patch("builtins.input", lambda _: resp_content)
|
||||
human_provider = HumanProvider()
|
||||
|
||||
resp = human_provider.ask(resp_content)
|
||||
assert resp == resp_content
|
||||
resp = await human_provider.aask(None)
|
||||
assert resp_content == resp
|
||||
|
||||
mocker.patch("builtins.input", lambda _: resp_exit)
|
||||
with pytest.raises(SystemExit):
|
||||
human_provider.ask(resp_exit)
|
||||
|
||||
resp = await human_provider.acompletion([])
|
||||
assert not resp
|
||||
|
||||
resp = await human_provider.acompletion_text([])
|
||||
assert resp == ""
|
||||
|
|
|
|||
|
|
@ -17,10 +17,23 @@ prompt_msg = "who are you"
|
|||
resp_content = "I'm Spark"
|
||||
|
||||
|
||||
def test_get_msg_from_web():
|
||||
class MockWebSocketApp(object):
|
||||
def __init__(self, ws_url, on_message=None, on_error=None, on_close=None, on_open=None):
|
||||
pass
|
||||
|
||||
def run_forever(self, sslopt=None):
|
||||
pass
|
||||
|
||||
|
||||
def test_get_msg_from_web(mocker):
|
||||
mocker.patch("websocket.WebSocketApp", MockWebSocketApp)
|
||||
|
||||
get_msg_from_web = GetMessageFromWeb(text=prompt_msg)
|
||||
assert get_msg_from_web.gen_params()["parameter"]["chat"]["domain"] == "xxxxxx"
|
||||
|
||||
ret = get_msg_from_web.run()
|
||||
assert ret == ""
|
||||
|
||||
|
||||
def mock_spark_get_msg_from_web_run(self) -> str:
|
||||
return resp_content
|
||||
|
|
@ -29,6 +42,7 @@ def mock_spark_get_msg_from_web_run(self) -> str:
|
|||
@pytest.mark.asyncio
|
||||
async def test_spark_acompletion(mocker):
|
||||
mocker.patch("metagpt.provider.spark_api.GetMessageFromWeb.run", mock_spark_get_msg_from_web_run)
|
||||
|
||||
spark_gpt = SparkLLM()
|
||||
|
||||
resp = await spark_gpt.acompletion([])
|
||||
|
|
|
|||
|
|
@ -16,3 +16,11 @@ async def test_async_sse_client():
|
|||
async_sse_client = AsyncSSEClient(event_source=Iterator())
|
||||
async for event in async_sse_client.async_events():
|
||||
assert event.data, "test_value"
|
||||
|
||||
class InvalidIterator(object):
|
||||
async def __aiter__(self):
|
||||
yield b"invalid: test_value"
|
||||
|
||||
async_sse_client = AsyncSSEClient(event_source=InvalidIterator())
|
||||
async for event in async_sse_client.async_events():
|
||||
assert not event
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@ from metagpt.provider.zhipuai.zhipu_model_api import ZhiPuModelAPI
|
|||
api_key = "xxx.xxx"
|
||||
zhipuai.api_key = api_key
|
||||
|
||||
default_resp = {"result": "test response"}
|
||||
default_resp = b'{"result": "test response"}'
|
||||
|
||||
|
||||
async def mock_requestor_arequest(self, **kwargs) -> Tuple[Any, Any, str]:
|
||||
|
|
@ -39,3 +39,6 @@ async def test_zhipu_model_api(mocker):
|
|||
InvokeType.SYNC, stream=False, method="get", headers={}, kwargs={"model": "chatglm_turbo"}
|
||||
)
|
||||
assert result == default_resp
|
||||
|
||||
result = await ZhiPuModelAPI.ainvoke()
|
||||
assert result["result"] == "test response"
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue