simplify provider ut code

This commit is contained in:
better629 2024-02-07 18:42:22 +08:00
parent d3f6e38e8a
commit 997e25e97d
9 changed files with 40 additions and 76 deletions

View file

@ -53,9 +53,4 @@ mock_llm_config_spark = LLMConfig(
base_url="wss://spark-api.xf-yun.com/v3.1/chat",
)
mock_llm_config_qianfan = LLMConfig(
api_type="qianfan",
access_key="xxx",
secret_key="xxx",
model="ERNIE-Bot-turbo"
)
mock_llm_config_qianfan = LLMConfig(api_type="qianfan", access_key="xxx", secret_key="xxx", model="ERNIE-Bot-turbo")

View file

@ -2,7 +2,7 @@
# -*- coding: utf-8 -*-
# @Desc : default request & response data for provider unittest
from typing import Dict
from openai.types.chat.chat_completion import (
ChatCompletion,
ChatCompletionMessage,
@ -12,7 +12,7 @@ from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
from openai.types.chat.chat_completion_chunk import Choice as AChoice
from openai.types.chat.chat_completion_chunk import ChoiceDelta
from openai.types.completion_usage import CompletionUsage
from qianfan.resources.typing import QfResponse, default_field
from qianfan.resources.typing import QfResponse
from metagpt.provider.base_llm import BaseLLM
@ -80,6 +80,7 @@ def get_openai_chat_completion_chunk(name: str, usage_as_dict: bool = False) ->
)
return openai_chat_completion_chunk
# For gemini
gemini_messages = [{"role": "user", "parts": prompt}]
@ -92,20 +93,13 @@ qf_jsonbody_dict = {
"result": "",
"is_truncated": False,
"need_clear_history": False,
"usage": {
"prompt_tokens": 7,
"completion_tokens": 15,
"total_tokens": 22
}
"usage": {"prompt_tokens": 7, "completion_tokens": 15, "total_tokens": 22},
}
def get_qianfan_response(name: str) -> QfResponse:
qf_jsonbody_dict["result"] = resp_cont_tmpl.format(name=name)
return QfResponse(
code=200,
body=qf_jsonbody_dict
)
return QfResponse(code=200, body=qf_jsonbody_dict)
# For llm general chat functions call

View file

@ -16,6 +16,7 @@ from tests.metagpt.provider.mock_llm_config import mock_llm_config
from tests.metagpt.provider.req_resp_const import (
get_openai_chat_completion,
get_openai_chat_completion_chunk,
llm_general_chat_funcs_test,
messages,
prompt,
resp_cont_tmpl,
@ -70,14 +71,4 @@ async def test_fireworks_acompletion(mocker):
resp = await fireworks_llm.acompletion(messages)
assert resp.choices[0].message.content in resp_cont
resp = await fireworks_llm.aask(prompt, stream=False)
assert resp == resp_cont
resp = await fireworks_llm.acompletion_text(messages, stream=False)
assert resp == resp_cont
resp = await fireworks_llm.acompletion_text(messages, stream=True)
assert resp == resp_cont
resp = await fireworks_llm.aask(prompt)
assert resp == resp_cont
await llm_general_chat_funcs_test(fireworks_llm, prompt, messages, resp_cont)

View file

@ -13,6 +13,7 @@ from metagpt.provider.google_gemini_api import GeminiLLM
from tests.metagpt.provider.mock_llm_config import mock_llm_config
from tests.metagpt.provider.req_resp_const import (
gemini_messages,
llm_general_chat_funcs_test,
prompt,
resp_cont_tmpl,
)
@ -77,14 +78,4 @@ async def test_gemini_acompletion(mocker):
resp = await gemini_llm.acompletion(gemini_messages)
assert resp.text == default_resp.text
resp = await gemini_llm.aask(prompt, stream=False)
assert resp == resp_cont
resp = await gemini_llm.acompletion_text(gemini_messages, stream=False)
assert resp == resp_cont
resp = await gemini_llm.acompletion_text(gemini_messages, stream=True)
assert resp == resp_cont
resp = await gemini_llm.aask(prompt)
assert resp == resp_cont
await llm_general_chat_funcs_test(gemini_llm, prompt, gemini_messages, resp_cont)

View file

@ -9,7 +9,12 @@ import pytest
from metagpt.provider.ollama_api import OllamaLLM
from tests.metagpt.provider.mock_llm_config import mock_llm_config
from tests.metagpt.provider.req_resp_const import messages, prompt, resp_cont_tmpl
from tests.metagpt.provider.req_resp_const import (
llm_general_chat_funcs_test,
messages,
prompt,
resp_cont_tmpl,
)
resp_cont = resp_cont_tmpl.format(name="ollama")
default_resp = {"message": {"role": "assistant", "content": resp_cont}}
@ -47,11 +52,4 @@ async def test_gemini_acompletion(mocker):
resp = await ollama_llm.aask(prompt, stream=False)
assert resp == resp_cont
resp = await ollama_llm.acompletion_text(messages, stream=False)
assert resp == resp_cont
resp = await ollama_llm.acompletion_text(messages, stream=True)
assert resp == resp_cont
resp = await ollama_llm.aask(prompt)
assert resp == resp_cont
await llm_general_chat_funcs_test(ollama_llm, prompt, messages, resp_cont)

View file

@ -12,6 +12,7 @@ from tests.metagpt.provider.mock_llm_config import mock_llm_config
from tests.metagpt.provider.req_resp_const import (
get_openai_chat_completion,
get_openai_chat_completion_chunk,
llm_general_chat_funcs_test,
messages,
prompt,
resp_cont_tmpl,
@ -52,14 +53,4 @@ async def test_openllm_acompletion(mocker):
resp = await openllm_llm.acompletion(messages)
assert resp.choices[0].message.content in resp_cont
resp = await openllm_llm.aask(prompt, stream=False)
assert resp == resp_cont
resp = await openllm_llm.acompletion_text(messages, stream=False)
assert resp == resp_cont
resp = await openllm_llm.acompletion_text(messages, stream=True)
assert resp == resp_cont
resp = await openllm_llm.aask(prompt)
assert resp == resp_cont
await llm_general_chat_funcs_test(openllm_llm, prompt, messages, resp_cont)

View file

@ -2,14 +2,20 @@
# -*- coding: utf-8 -*-
# @Desc : the unittest of qianfan api
from typing import Dict, Union, AsyncIterator
import pytest
from typing import AsyncIterator, Union
import pytest
from qianfan.resources.typing import JsonBody, QfResponse
from metagpt.provider.qianfan_api import QianFanLLM
from tests.metagpt.provider.mock_llm_config import mock_llm_config_qianfan
from tests.metagpt.provider.req_resp_const import resp_cont_tmpl, prompt, messages, llm_general_chat_funcs_test, get_qianfan_response
from tests.metagpt.provider.req_resp_const import (
get_qianfan_response,
llm_general_chat_funcs_test,
messages,
prompt,
resp_cont_tmpl,
)
name = "ERNIE-Bot-turbo"
resp_cont = resp_cont_tmpl.format(name=name)
@ -19,12 +25,16 @@ def mock_qianfan_do(self, messages: list[dict], model: str, stream: bool = False
return get_qianfan_response(name=name)
async def mock_qianfan_ado(self, messages: list[dict], model: str, stream: bool = True, system: str = None) -> Union[QfResponse, AsyncIterator[QfResponse]]:
async def mock_qianfan_ado(
self, messages: list[dict], model: str, stream: bool = True, system: str = None
) -> Union[QfResponse, AsyncIterator[QfResponse]]:
resps = [get_qianfan_response(name=name)]
if stream:
async def aresp_iterator(resps: list[JsonBody]):
for resp in resps:
yield resp
return aresp_iterator(resps)
else:
return resps[0]

View file

@ -9,7 +9,11 @@ from tests.metagpt.provider.mock_llm_config import (
mock_llm_config,
mock_llm_config_spark,
)
from tests.metagpt.provider.req_resp_const import prompt, resp_cont_tmpl
from tests.metagpt.provider.req_resp_const import (
llm_general_chat_funcs_test,
prompt,
resp_cont_tmpl,
)
resp_cont = resp_cont_tmpl.format(name="Spark")
@ -55,14 +59,4 @@ async def test_spark_acompletion(mocker):
resp = await spark_llm.acompletion([])
assert resp == resp_cont
resp = await spark_llm.aask(prompt, stream=False)
assert resp == resp_cont
resp = await spark_llm.acompletion_text([], stream=False)
assert resp == resp_cont
resp = await spark_llm.acompletion_text([], stream=True)
assert resp == resp_cont
resp = await spark_llm.aask(prompt)
assert resp == resp_cont
await llm_general_chat_funcs_test(spark_llm, prompt, prompt, resp_cont)

View file

@ -8,10 +8,10 @@ from metagpt.provider.zhipuai_api import ZhiPuAILLM
from tests.metagpt.provider.mock_llm_config import mock_llm_config_zhipu
from tests.metagpt.provider.req_resp_const import (
get_part_chat_completion,
llm_general_chat_funcs_test,
messages,
prompt,
resp_cont_tmpl,
llm_general_chat_funcs_test
)
name = "ChatGLM-4"