diff --git a/tests/metagpt/provider/mock_llm_config.py b/tests/metagpt/provider/mock_llm_config.py index e61e32e8b..e0afaa51e 100644 --- a/tests/metagpt/provider/mock_llm_config.py +++ b/tests/metagpt/provider/mock_llm_config.py @@ -53,9 +53,4 @@ mock_llm_config_spark = LLMConfig( base_url="wss://spark-api.xf-yun.com/v3.1/chat", ) -mock_llm_config_qianfan = LLMConfig( - api_type="qianfan", - access_key="xxx", - secret_key="xxx", - model="ERNIE-Bot-turbo" -) +mock_llm_config_qianfan = LLMConfig(api_type="qianfan", access_key="xxx", secret_key="xxx", model="ERNIE-Bot-turbo") diff --git a/tests/metagpt/provider/req_resp_const.py b/tests/metagpt/provider/req_resp_const.py index 20d8e0914..73939e1c6 100644 --- a/tests/metagpt/provider/req_resp_const.py +++ b/tests/metagpt/provider/req_resp_const.py @@ -2,7 +2,7 @@ # -*- coding: utf-8 -*- # @Desc : default request & response data for provider unittest -from typing import Dict + from openai.types.chat.chat_completion import ( ChatCompletion, ChatCompletionMessage, @@ -12,7 +12,7 @@ from openai.types.chat.chat_completion_chunk import ChatCompletionChunk from openai.types.chat.chat_completion_chunk import Choice as AChoice from openai.types.chat.chat_completion_chunk import ChoiceDelta from openai.types.completion_usage import CompletionUsage -from qianfan.resources.typing import QfResponse, default_field +from qianfan.resources.typing import QfResponse from metagpt.provider.base_llm import BaseLLM @@ -80,6 +80,7 @@ def get_openai_chat_completion_chunk(name: str, usage_as_dict: bool = False) -> ) return openai_chat_completion_chunk + # For gemini gemini_messages = [{"role": "user", "parts": prompt}] @@ -92,20 +93,13 @@ qf_jsonbody_dict = { "result": "", "is_truncated": False, "need_clear_history": False, - "usage": { - "prompt_tokens": 7, - "completion_tokens": 15, - "total_tokens": 22 - } + "usage": {"prompt_tokens": 7, "completion_tokens": 15, "total_tokens": 22}, } def get_qianfan_response(name: str) -> QfResponse: qf_jsonbody_dict["result"] = resp_cont_tmpl.format(name=name) - return QfResponse( - code=200, - body=qf_jsonbody_dict - ) + return QfResponse(code=200, body=qf_jsonbody_dict) # For llm general chat functions call diff --git a/tests/metagpt/provider/test_fireworks_llm.py b/tests/metagpt/provider/test_fireworks_llm.py index e28f7500b..1c1aa9caa 100644 --- a/tests/metagpt/provider/test_fireworks_llm.py +++ b/tests/metagpt/provider/test_fireworks_llm.py @@ -16,6 +16,7 @@ from tests.metagpt.provider.mock_llm_config import mock_llm_config from tests.metagpt.provider.req_resp_const import ( get_openai_chat_completion, get_openai_chat_completion_chunk, + llm_general_chat_funcs_test, messages, prompt, resp_cont_tmpl, @@ -70,14 +71,4 @@ async def test_fireworks_acompletion(mocker): resp = await fireworks_llm.acompletion(messages) assert resp.choices[0].message.content in resp_cont - resp = await fireworks_llm.aask(prompt, stream=False) - assert resp == resp_cont - - resp = await fireworks_llm.acompletion_text(messages, stream=False) - assert resp == resp_cont - - resp = await fireworks_llm.acompletion_text(messages, stream=True) - assert resp == resp_cont - - resp = await fireworks_llm.aask(prompt) - assert resp == resp_cont + await llm_general_chat_funcs_test(fireworks_llm, prompt, messages, resp_cont) diff --git a/tests/metagpt/provider/test_google_gemini_api.py b/tests/metagpt/provider/test_google_gemini_api.py index dae9d123b..50c15ee19 100644 --- a/tests/metagpt/provider/test_google_gemini_api.py +++ b/tests/metagpt/provider/test_google_gemini_api.py @@ -13,6 +13,7 @@ from metagpt.provider.google_gemini_api import GeminiLLM from tests.metagpt.provider.mock_llm_config import mock_llm_config from tests.metagpt.provider.req_resp_const import ( gemini_messages, + llm_general_chat_funcs_test, prompt, resp_cont_tmpl, ) @@ -77,14 +78,4 @@ async def test_gemini_acompletion(mocker): resp = await gemini_llm.acompletion(gemini_messages) assert resp.text == default_resp.text - resp = await gemini_llm.aask(prompt, stream=False) - assert resp == resp_cont - - resp = await gemini_llm.acompletion_text(gemini_messages, stream=False) - assert resp == resp_cont - - resp = await gemini_llm.acompletion_text(gemini_messages, stream=True) - assert resp == resp_cont - - resp = await gemini_llm.aask(prompt) - assert resp == resp_cont + await llm_general_chat_funcs_test(gemini_llm, prompt, gemini_messages, resp_cont) diff --git a/tests/metagpt/provider/test_ollama_api.py b/tests/metagpt/provider/test_ollama_api.py index 01d53251c..af2e929e9 100644 --- a/tests/metagpt/provider/test_ollama_api.py +++ b/tests/metagpt/provider/test_ollama_api.py @@ -9,7 +9,12 @@ import pytest from metagpt.provider.ollama_api import OllamaLLM from tests.metagpt.provider.mock_llm_config import mock_llm_config -from tests.metagpt.provider.req_resp_const import messages, prompt, resp_cont_tmpl +from tests.metagpt.provider.req_resp_const import ( + llm_general_chat_funcs_test, + messages, + prompt, + resp_cont_tmpl, +) resp_cont = resp_cont_tmpl.format(name="ollama") default_resp = {"message": {"role": "assistant", "content": resp_cont}} @@ -47,11 +52,4 @@ async def test_gemini_acompletion(mocker): resp = await ollama_llm.aask(prompt, stream=False) assert resp == resp_cont - resp = await ollama_llm.acompletion_text(messages, stream=False) - assert resp == resp_cont - - resp = await ollama_llm.acompletion_text(messages, stream=True) - assert resp == resp_cont - - resp = await ollama_llm.aask(prompt) - assert resp == resp_cont + await llm_general_chat_funcs_test(ollama_llm, prompt, messages, resp_cont) diff --git a/tests/metagpt/provider/test_open_llm_api.py b/tests/metagpt/provider/test_open_llm_api.py index b2e759d06..aa38b95a6 100644 --- a/tests/metagpt/provider/test_open_llm_api.py +++ b/tests/metagpt/provider/test_open_llm_api.py @@ -12,6 +12,7 @@ from tests.metagpt.provider.mock_llm_config import mock_llm_config from tests.metagpt.provider.req_resp_const import ( get_openai_chat_completion, get_openai_chat_completion_chunk, + llm_general_chat_funcs_test, messages, prompt, resp_cont_tmpl, @@ -52,14 +53,4 @@ async def test_openllm_acompletion(mocker): resp = await openllm_llm.acompletion(messages) assert resp.choices[0].message.content in resp_cont - resp = await openllm_llm.aask(prompt, stream=False) - assert resp == resp_cont - - resp = await openllm_llm.acompletion_text(messages, stream=False) - assert resp == resp_cont - - resp = await openllm_llm.acompletion_text(messages, stream=True) - assert resp == resp_cont - - resp = await openllm_llm.aask(prompt) - assert resp == resp_cont + await llm_general_chat_funcs_test(openllm_llm, prompt, messages, resp_cont) diff --git a/tests/metagpt/provider/test_qianfan_api.py b/tests/metagpt/provider/test_qianfan_api.py index 30ac06911..28341425c 100644 --- a/tests/metagpt/provider/test_qianfan_api.py +++ b/tests/metagpt/provider/test_qianfan_api.py @@ -2,14 +2,20 @@ # -*- coding: utf-8 -*- # @Desc : the unittest of qianfan api -from typing import Dict, Union, AsyncIterator -import pytest +from typing import AsyncIterator, Union +import pytest from qianfan.resources.typing import JsonBody, QfResponse from metagpt.provider.qianfan_api import QianFanLLM from tests.metagpt.provider.mock_llm_config import mock_llm_config_qianfan -from tests.metagpt.provider.req_resp_const import resp_cont_tmpl, prompt, messages, llm_general_chat_funcs_test, get_qianfan_response +from tests.metagpt.provider.req_resp_const import ( + get_qianfan_response, + llm_general_chat_funcs_test, + messages, + prompt, + resp_cont_tmpl, +) name = "ERNIE-Bot-turbo" resp_cont = resp_cont_tmpl.format(name=name) @@ -19,12 +25,16 @@ def mock_qianfan_do(self, messages: list[dict], model: str, stream: bool = False return get_qianfan_response(name=name) -async def mock_qianfan_ado(self, messages: list[dict], model: str, stream: bool = True, system: str = None) -> Union[QfResponse, AsyncIterator[QfResponse]]: +async def mock_qianfan_ado( + self, messages: list[dict], model: str, stream: bool = True, system: str = None +) -> Union[QfResponse, AsyncIterator[QfResponse]]: resps = [get_qianfan_response(name=name)] if stream: + async def aresp_iterator(resps: list[JsonBody]): for resp in resps: yield resp + return aresp_iterator(resps) else: return resps[0] diff --git a/tests/metagpt/provider/test_spark_api.py b/tests/metagpt/provider/test_spark_api.py index 8aa8bc7a8..9c278267d 100644 --- a/tests/metagpt/provider/test_spark_api.py +++ b/tests/metagpt/provider/test_spark_api.py @@ -9,7 +9,11 @@ from tests.metagpt.provider.mock_llm_config import ( mock_llm_config, mock_llm_config_spark, ) -from tests.metagpt.provider.req_resp_const import prompt, resp_cont_tmpl +from tests.metagpt.provider.req_resp_const import ( + llm_general_chat_funcs_test, + prompt, + resp_cont_tmpl, +) resp_cont = resp_cont_tmpl.format(name="Spark") @@ -55,14 +59,4 @@ async def test_spark_acompletion(mocker): resp = await spark_llm.acompletion([]) assert resp == resp_cont - resp = await spark_llm.aask(prompt, stream=False) - assert resp == resp_cont - - resp = await spark_llm.acompletion_text([], stream=False) - assert resp == resp_cont - - resp = await spark_llm.acompletion_text([], stream=True) - assert resp == resp_cont - - resp = await spark_llm.aask(prompt) - assert resp == resp_cont + await llm_general_chat_funcs_test(spark_llm, prompt, prompt, resp_cont) diff --git a/tests/metagpt/provider/test_zhipuai_api.py b/tests/metagpt/provider/test_zhipuai_api.py index 3dada367c..8ec9ab4f9 100644 --- a/tests/metagpt/provider/test_zhipuai_api.py +++ b/tests/metagpt/provider/test_zhipuai_api.py @@ -8,10 +8,10 @@ from metagpt.provider.zhipuai_api import ZhiPuAILLM from tests.metagpt.provider.mock_llm_config import mock_llm_config_zhipu from tests.metagpt.provider.req_resp_const import ( get_part_chat_completion, + llm_general_chat_funcs_test, messages, prompt, resp_cont_tmpl, - llm_general_chat_funcs_test ) name = "ChatGLM-4"