better629 2023-12-29 02:45:54 +08:00
parent 0f047e5693
commit c8e351f3c8
7 changed files with 63 additions and 52 deletions

View file

@@ -1,8 +1,7 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
-# @Desc :
+# @Desc :
import pytest
from metagpt.provider.postprocess.base_postprocess_plugin import BasePostProcessPlugin
@@ -14,25 +13,19 @@ raw_output = """
[/CONTENT]
"""
raw_schema = {
-    "title":"prd",
-    "type":"object",
-    "properties":{
-        "Original Requirements":{
-            "title":"Original Requirements",
-            "type":"string"
-        },
-    },
-    "required":[
-        "Original Requirements",
-    ]
-}
+    "title": "prd",
+    "type": "object",
+    "properties": {
+        "Original Requirements": {"title": "Original Requirements", "type": "string"},
+    },
+    "required": [
+        "Original Requirements",
+    ],
+}

def test_llm_post_process_plugin():
    post_process_plugin = BasePostProcessPlugin()
-    output = post_process_plugin.run(
-        output=raw_output,
-        schema=raw_schema
-    )
+    output = post_process_plugin.run(output=raw_output, schema=raw_schema)
    assert "Original Requirements" in output

View file

@@ -1,12 +1,13 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
-# @Desc :
+# @Desc :
import pytest
from metagpt.provider.postprocess.llm_output_postprocess import llm_output_postprocess
-from tests.metagpt.provider.postprocess.test_base_postprocess_plugin import raw_output, raw_schema
+from tests.metagpt.provider.postprocess.test_base_postprocess_plugin import (
+    raw_output,
+    raw_schema,
+)
def test_llm_output_postprocess():

View file

@@ -1,11 +1,10 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
-# @Desc :
+# @Desc :
import pytest
-from metagpt.provider.azure_openai_api import AzureOpenAILLM
from metagpt.config import CONFIG
+from metagpt.provider.azure_openai_api import AzureOpenAILLM
CONFIG.OPENAI_API_VERSION = "xx"
CONFIG.openai_proxy = "http://127.0.0.1:80" # fake value

View file

@@ -1,16 +1,27 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
-# @Desc :
+# @Desc :
-import pytest
-import os
-import requests
-import aiohttp
-from typing import Iterator, Tuple, Union, Generator, AsyncGenerator
+from typing import AsyncGenerator, Generator, Iterator, Tuple, Union
+import aiohttp
+import pytest
+import requests
from openai import OpenAIError
-from metagpt.provider.general_api_base import ApiType, log_debug, log_info, log_warn, OpenAIResponse, \
-    _requests_proxies_arg, _aiohttp_proxies_arg, _make_session, parse_stream_helper, parse_stream, APIRequestor
+from metagpt.provider.general_api_base import (
+    APIRequestor,
+    ApiType,
+    OpenAIResponse,
+    _make_session,
+    _requests_proxies_arg,
+    log_debug,
+    log_info,
+    log_warn,
+    parse_stream,
+    parse_stream_helper,
+)
def test_basic():
@@ -60,13 +71,15 @@ def test_parse_stream():
api_requestor = APIRequestor(base_url="http://www.baidu.com")

-def mock_interpret_response(self, result: requests.Response, stream: bool
-    ) -> Tuple[Union[bytes, Iterator[Generator]], bytes]:
+def mock_interpret_response(
+    self, result: requests.Response, stream: bool
+) -> Tuple[Union[bytes, Iterator[Generator]], bytes]:
    return b"baidu", False

-async def mock_interpret_async_response(self, result: aiohttp.ClientResponse, stream: bool
-    ) -> Tuple[Union[OpenAIResponse, AsyncGenerator[OpenAIResponse, None]], bool]:
+async def mock_interpret_async_response(
+    self, result: aiohttp.ClientResponse, stream: bool
+) -> Tuple[Union[OpenAIResponse, AsyncGenerator[OpenAIResponse, None]], bool]:
    return b"baidu", True
@@ -79,6 +92,8 @@ def test_api_requestor(mocker):
@pytest.mark.asyncio
async def test_async_api_requestor(mocker):
-    mocker.patch("metagpt.provider.general_api_base.APIRequestor._interpret_async_response", mock_interpret_async_response)
+    mocker.patch(
+        "metagpt.provider.general_api_base.APIRequestor._interpret_async_response", mock_interpret_async_response
+    )
    resp, _, _ = await api_requestor.arequest(method="get", url="/s?wd=baidu")
    resp, _, _ = await api_requestor.arequest(method="post", url="/s?wd=baidu")
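
Pieced together from the hunks above, the pattern in this file is to stub APIRequestor's response-interpretation hook with pytest-mock so no real HTTP parsing happens, then drive arequest directly. A condensed, self-contained sketch of the async half; the closing assertion is an assumption, since the hunk cuts off before the test's checks:

from typing import AsyncGenerator, Tuple, Union

import aiohttp
import pytest

from metagpt.provider.general_api_base import APIRequestor, OpenAIResponse


async def mock_interpret_async_response(
    self, result: aiohttp.ClientResponse, stream: bool
) -> Tuple[Union[OpenAIResponse, AsyncGenerator[OpenAIResponse, None]], bool]:
    # Return a canned body instead of reading the real aiohttp response.
    return b"baidu", True


@pytest.mark.asyncio
async def test_async_api_requestor(mocker):
    mocker.patch(
        "metagpt.provider.general_api_base.APIRequestor._interpret_async_response",
        mock_interpret_async_response,
    )
    api_requestor = APIRequestor(base_url="http://www.baidu.com")
    resp, _, _ = await api_requestor.arequest(method="get", url="/s?wd=baidu")
    assert resp == b"baidu"  # assumed: arequest surfaces the mocked body unchanged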

View file

@@ -1,15 +1,16 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
-# @Desc :
+# @Desc :
import pytest
from openai.types.chat.chat_completion import (
    ChatCompletion,
    ChatCompletionMessage,
    Choice,
)
-from openai.types.chat.chat_completion_chunk import ChatCompletionChunk, ChoiceDelta, Choice as AChoice
+from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
+from openai.types.chat.chat_completion_chunk import Choice as AChoice
+from openai.types.chat.chat_completion_chunk import ChoiceDelta
from openai.types.completion_usage import CompletionUsage
from metagpt.config import CONFIG
@@ -32,7 +33,7 @@ default_resp = ChatCompletion(
            message=ChatCompletionMessage(role="assistant", content=resp_content),
            logprobs=None,
        )
-    ]
+    ],
)
default_resp_chunk = ChatCompletionChunk(
@@ -42,26 +43,25 @@ default_resp_chunk = ChatCompletionChunk(
    created=default_resp.created,
    choices=[
        AChoice(
-            delta=ChoiceDelta(
-                content=resp_content,
-                role="assistant"
-            ),
+            delta=ChoiceDelta(content=resp_content, role="assistant"),
            finish_reason="stop",
            index=0,
            logprobs=None,
        )
-    ]
+    ],
)
prompt_msg = "who are you"
messages = [{"role": "user", "content": prompt_msg}]
-async def mock_openai_acompletions_create(self, stream: bool=False, **kwargs) -> ChatCompletionChunk:
+async def mock_openai_acompletions_create(self, stream: bool = False, **kwargs) -> ChatCompletionChunk:
    if stream:

        class Iterator(object):
            async def __aiter__(self):
                yield default_resp_chunk

        return Iterator()
    else:
        return default_resp
@@ -75,7 +75,9 @@ async def test_openllm_acompletion(mocker):
    openllm_gpt.model = "llama-v2-13b-chat"
    openllm_gpt._update_costs(usage=CompletionUsage(prompt_tokens=100, completion_tokens=100, total_tokens=200))
-    assert openllm_gpt.get_costs() == Costs(total_prompt_tokens=100, total_completion_tokens=100, total_cost=0, total_budget=0)
+    assert openllm_gpt.get_costs() == Costs(
+        total_prompt_tokens=100, total_completion_tokens=100, total_cost=0, total_budget=0
+    )
    resp = await openllm_gpt.acompletion(messages)
    assert resp.choices[0].message.content in resp_content

View file

@@ -1,6 +1,6 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
-# @Desc :
+# @Desc :
import pytest

View file

@@ -3,15 +3,14 @@
# @Desc :
from typing import Any, Tuple
-import pytest
+import pytest
import zhipuai
from zhipuai.model_api.api import InvokeType
from zhipuai.utils.http_client import headers as zhipuai_default_headers
from metagpt.provider.zhipuai.zhipu_model_api import ZhiPuModelAPI
api_key = "xxx.xxx"
zhipuai.api_key = api_key
@@ -36,5 +35,7 @@ async def test_zhipu_model_api(mocker):
    assert url_suffix == "/paas/v3/model-api/chatglm_turbo/invoke"
    mocker.patch("metagpt.provider.general_api_requestor.GeneralAPIRequestor.arequest", mock_requestor_arequest)
-    result = await ZhiPuModelAPI.arequest(InvokeType.SYNC, stream=False, method="get", headers={}, kwargs={"model": "chatglm_turbo"})
+    result = await ZhiPuModelAPI.arequest(
+        InvokeType.SYNC, stream=False, method="get", headers={}, kwargs={"model": "chatglm_turbo"}
+    )
    assert result == default_resp