mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-05-21 14:05:17 +02:00
fix zhipuai error follow https://github.com/geekan/MetaGPT/pull/770/files
This commit is contained in:
parent
1008bbb4c3
commit
57befdac30
11 changed files with 105 additions and 212 deletions
|
|
@ -3,7 +3,6 @@
|
|||
# @Desc : the unittest of ZhiPuAILLM
|
||||
|
||||
import pytest
|
||||
from zhipuai.utils.sse_client import Event
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.provider.zhipuai_api import ZhiPuAILLM
|
||||
|
|
@ -15,35 +14,16 @@ messages = [{"role": "user", "content": prompt_msg}]
|
|||
|
||||
resp_content = "I'm chatglm-turbo"
|
||||
default_resp = {
|
||||
"code": 200,
|
||||
"data": {
|
||||
"choices": [{"role": "assistant", "content": resp_content}],
|
||||
"usage": {"prompt_tokens": 20, "completion_tokens": 20},
|
||||
},
|
||||
"choices": [{"finish_reason": "stop", "index": 0, "message": {"content": resp_content, "role": "assistant"}}],
|
||||
"usage": {"completion_tokens": 22, "prompt_tokens": 19, "total_tokens": 41},
|
||||
}
|
||||
|
||||
|
||||
def mock_zhipuai_invoke(**kwargs) -> dict:
|
||||
return default_resp
|
||||
|
||||
|
||||
async def mock_zhipuai_ainvoke(**kwargs) -> dict:
|
||||
return default_resp
|
||||
|
||||
|
||||
async def mock_zhipuai_asse_invoke(**kwargs):
|
||||
async def mock_zhipuai_acreate_stream(self, **kwargs):
|
||||
class MockResponse(object):
|
||||
async def _aread(self):
|
||||
class Iterator(object):
|
||||
events = [
|
||||
Event(id="xxx", event="add", data=resp_content, retry=0),
|
||||
Event(
|
||||
id="xxx",
|
||||
event="finish",
|
||||
data="",
|
||||
meta='{"usage": {"completion_tokens": 20,"prompt_tokens": 20}}',
|
||||
),
|
||||
]
|
||||
events = [{"choices": [{"index": 0, "delta": {"content": resp_content, "role": "assistant"}}]}]
|
||||
|
||||
async def __aiter__(self):
|
||||
for event in self.events:
|
||||
|
|
@ -52,23 +32,26 @@ async def mock_zhipuai_asse_invoke(**kwargs):
|
|||
async for chunk in Iterator():
|
||||
yield chunk
|
||||
|
||||
async def async_events(self):
|
||||
async def stream(self):
|
||||
async for chunk in self._aread():
|
||||
yield chunk
|
||||
|
||||
return MockResponse()
|
||||
|
||||
|
||||
async def mock_zhipuai_acreate(self, **kwargs) -> dict:
|
||||
return default_resp
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_zhipuai_acompletion(mocker):
|
||||
mocker.patch("metagpt.provider.zhipuai.zhipu_model_api.ZhiPuModelAPI.invoke", mock_zhipuai_invoke)
|
||||
mocker.patch("metagpt.provider.zhipuai.zhipu_model_api.ZhiPuModelAPI.ainvoke", mock_zhipuai_ainvoke)
|
||||
mocker.patch("metagpt.provider.zhipuai.zhipu_model_api.ZhiPuModelAPI.asse_invoke", mock_zhipuai_asse_invoke)
|
||||
mocker.patch("metagpt.provider.zhipuai.zhipu_model_api.ZhiPuModelAPI.acreate", mock_zhipuai_acreate)
|
||||
mocker.patch("metagpt.provider.zhipuai.zhipu_model_api.ZhiPuModelAPI.acreate_stream", mock_zhipuai_acreate_stream)
|
||||
|
||||
zhipu_gpt = ZhiPuAILLM()
|
||||
|
||||
resp = await zhipu_gpt.acompletion(messages)
|
||||
assert resp["data"]["choices"][0]["content"] == resp_content
|
||||
assert resp["choices"][0]["message"]["content"] == resp_content
|
||||
|
||||
resp = await zhipu_gpt.aask(prompt_msg, stream=False)
|
||||
assert resp == resp_content
|
||||
|
|
|
|||
|
|
@ -11,16 +11,16 @@ from metagpt.provider.zhipuai.async_sse_client import AsyncSSEClient
|
|||
async def test_async_sse_client():
|
||||
class Iterator(object):
|
||||
async def __aiter__(self):
|
||||
yield b"data: test_value"
|
||||
yield b'data: {"test_key": "test_value"}'
|
||||
|
||||
async_sse_client = AsyncSSEClient(event_source=Iterator())
|
||||
async for event in async_sse_client.async_events():
|
||||
assert event.data, "test_value"
|
||||
async for chunk in async_sse_client.stream():
|
||||
assert "test_value" in chunk.values()
|
||||
|
||||
class InvalidIterator(object):
|
||||
async def __aiter__(self):
|
||||
yield b"invalid: test_value"
|
||||
|
||||
async_sse_client = AsyncSSEClient(event_source=InvalidIterator())
|
||||
async for event in async_sse_client.async_events():
|
||||
assert not event
|
||||
async for chunk in async_sse_client.stream():
|
||||
assert not chunk
|
||||
|
|
|
|||
|
|
@ -6,15 +6,13 @@ from typing import Any, Tuple
|
|||
|
||||
import pytest
|
||||
import zhipuai
|
||||
from zhipuai.model_api.api import InvokeType
|
||||
from zhipuai.utils.http_client import headers as zhipuai_default_headers
|
||||
|
||||
from metagpt.provider.zhipuai.zhipu_model_api import ZhiPuModelAPI
|
||||
|
||||
api_key = "xxx.xxx"
|
||||
zhipuai.api_key = api_key
|
||||
|
||||
default_resp = b'{"result": "test response"}'
|
||||
default_resp = b'{"choices": [{"finish_reason": "stop", "index": 0, "message": {"content": "test response", "role": "assistant"}}]}'
|
||||
|
||||
|
||||
async def mock_requestor_arequest(self, **kwargs) -> Tuple[Any, Any, str]:
|
||||
|
|
@ -23,22 +21,15 @@ async def mock_requestor_arequest(self, **kwargs) -> Tuple[Any, Any, str]:
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_zhipu_model_api(mocker):
|
||||
header = ZhiPuModelAPI.get_header()
|
||||
zhipuai_default_headers.update({"Authorization": api_key})
|
||||
assert header == zhipuai_default_headers
|
||||
|
||||
sse_header = ZhiPuModelAPI.get_sse_header()
|
||||
assert len(sse_header["Authorization"]) == 191
|
||||
|
||||
url_prefix, url_suffix = ZhiPuModelAPI.split_zhipu_api_url(InvokeType.SYNC, kwargs={"model": "chatglm_turbo"})
|
||||
url_prefix, url_suffix = ZhiPuModelAPI(api_key=api_key).split_zhipu_api_url()
|
||||
assert url_prefix == "https://open.bigmodel.cn/api"
|
||||
assert url_suffix == "/paas/v3/model-api/chatglm_turbo/invoke"
|
||||
assert url_suffix == "/paas/v4/chat/completions"
|
||||
|
||||
mocker.patch("metagpt.provider.general_api_requestor.GeneralAPIRequestor.arequest", mock_requestor_arequest)
|
||||
result = await ZhiPuModelAPI.arequest(
|
||||
InvokeType.SYNC, stream=False, method="get", headers={}, kwargs={"model": "chatglm_turbo"}
|
||||
result = await ZhiPuModelAPI(api_key=api_key).arequest(
|
||||
stream=False, method="get", headers={}, kwargs={"model": "glm-3-turbo"}
|
||||
)
|
||||
assert result == default_resp
|
||||
|
||||
result = await ZhiPuModelAPI.ainvoke()
|
||||
assert result["result"] == "test response"
|
||||
result = await ZhiPuModelAPI(api_key=api_key).acreate()
|
||||
assert result["choices"][0]["message"]["content"] == "test response"
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue