diff --git a/metagpt/provider/qianfan_api.py b/metagpt/provider/qianfan_api.py
index fbbff7085..6f94b9cea 100644
--- a/metagpt/provider/qianfan_api.py
+++ b/metagpt/provider/qianfan_api.py
@@ -5,6 +5,7 @@ import copy
 import os
 
 import qianfan
+from qianfan import ChatCompletion
 from qianfan.resources.typing import JsonBody
 from tenacity import (
     after_log,
@@ -78,7 +79,7 @@ class QianFanLLM(BaseLLM):
         # self deployed model on the cloud not to calculate usage, it charges resource pool rental fee
         self.calc_usage = self.config.calc_usage and self.config.endpoint is None
 
-        self.aclient = qianfan.ChatCompletion()
+        self.aclient: ChatCompletion = qianfan.ChatCompletion()
 
     def _const_kwargs(self, messages: list[dict], stream: bool = False) -> dict:
         kwargs = {
diff --git a/tests/metagpt/provider/mock_llm_config.py b/tests/metagpt/provider/mock_llm_config.py
index 21780f914..e61e32e8b 100644
--- a/tests/metagpt/provider/mock_llm_config.py
+++ b/tests/metagpt/provider/mock_llm_config.py
@@ -52,3 +52,10 @@ mock_llm_config_spark = LLMConfig(
     domain="generalv2",
     base_url="wss://spark-api.xf-yun.com/v3.1/chat",
 )
+
+mock_llm_config_qianfan = LLMConfig(
+    api_type="qianfan",
+    access_key="xxx",
+    secret_key="xxx",
+    model="ERNIE-Bot-turbo",
+)
diff --git a/tests/metagpt/provider/req_resp_const.py b/tests/metagpt/provider/req_resp_const.py
index a3a7a363c..20d8e0914 100644
--- a/tests/metagpt/provider/req_resp_const.py
+++ b/tests/metagpt/provider/req_resp_const.py
@@ -2,6 +2,7 @@
 # -*- coding: utf-8 -*-
 # @Desc   : default request & response data for provider unittest
 
+from typing import Dict
 from openai.types.chat.chat_completion import (
     ChatCompletion,
     ChatCompletionMessage,
@@ -11,6 +12,9 @@ from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
 from openai.types.chat.chat_completion_chunk import Choice as AChoice
 from openai.types.chat.chat_completion_chunk import ChoiceDelta
 from openai.types.completion_usage import CompletionUsage
+from qianfan.resources.typing import QfResponse
+
+from metagpt.provider.base_llm import BaseLLM
 
 prompt = "who are you?"
messages = [{"role": "user", "content": prompt}] @@ -20,14 +24,14 @@ default_resp_cont = resp_cont_tmpl.format(name="GPT") # part of whole ChatCompletion of openai like structure -def get_part_chat_completion(llm_name: str) -> dict: +def get_part_chat_completion(name: str) -> dict: part_chat_completion = { "choices": [ { "index": 0, "message": { "role": "assistant", - "content": resp_cont_tmpl.format(name=llm_name), + "content": resp_cont_tmpl.format(name=name), }, "finish_reason": "stop", } @@ -37,7 +41,7 @@ def get_part_chat_completion(llm_name: str) -> dict: return part_chat_completion -def get_openai_chat_completion(llm_name: str) -> ChatCompletion: +def get_openai_chat_completion(name: str) -> ChatCompletion: openai_chat_completion = ChatCompletion( id="cmpl-a6652c1bb181caae8dd19ad8", model="xx/xxx", @@ -47,7 +51,7 @@ def get_openai_chat_completion(llm_name: str) -> ChatCompletion: Choice( finish_reason="stop", index=0, - message=ChatCompletionMessage(role="assistant", content=resp_cont_tmpl.format(name=llm_name)), + message=ChatCompletionMessage(role="assistant", content=resp_cont_tmpl.format(name=name)), logprobs=None, ) ], @@ -56,7 +60,7 @@ def get_openai_chat_completion(llm_name: str) -> ChatCompletion: return openai_chat_completion -def get_openai_chat_completion_chunk(llm_name: str, usage_as_dict: bool = False) -> ChatCompletionChunk: +def get_openai_chat_completion_chunk(name: str, usage_as_dict: bool = False) -> ChatCompletionChunk: usage = CompletionUsage(completion_tokens=110, prompt_tokens=92, total_tokens=202) usage = usage if not usage_as_dict else usage.model_dump() openai_chat_completion_chunk = ChatCompletionChunk( @@ -66,7 +70,7 @@ def get_openai_chat_completion_chunk(llm_name: str, usage_as_dict: bool = False) created=1703300855, choices=[ AChoice( - delta=ChoiceDelta(role="assistant", content=resp_cont_tmpl.format(name=llm_name)), + delta=ChoiceDelta(role="assistant", content=resp_cont_tmpl.format(name=name)), finish_reason="stop", index=0, logprobs=None, @@ -76,5 +80,44 @@ def get_openai_chat_completion_chunk(llm_name: str, usage_as_dict: bool = False) ) return openai_chat_completion_chunk - +# For gemini gemini_messages = [{"role": "user", "parts": prompt}] + + +# For QianFan +qf_jsonbody_dict = { + "id": "as-4v1h587fyv", + "object": "chat.completion", + "created": 1695021339, + "result": "", + "is_truncated": False, + "need_clear_history": False, + "usage": { + "prompt_tokens": 7, + "completion_tokens": 15, + "total_tokens": 22 + } +} + + +def get_qianfan_response(name: str) -> QfResponse: + qf_jsonbody_dict["result"] = resp_cont_tmpl.format(name=name) + return QfResponse( + code=200, + body=qf_jsonbody_dict + ) + + +# For llm general chat functions call +async def llm_general_chat_funcs_test(llm: BaseLLM, prompt: str, messages: list[dict], resp_cont: str): + resp = await llm.aask(prompt, stream=False) + assert resp == resp_cont + + resp = await llm.aask(prompt) + assert resp == resp_cont + + resp = await llm.acompletion_text(messages, stream=False) + assert resp == resp_cont + + resp = await llm.acompletion_text(messages, stream=True) + assert resp == resp_cont diff --git a/tests/metagpt/provider/test_base_llm.py b/tests/metagpt/provider/test_base_llm.py index 0babd6d5f..cf44343bc 100644 --- a/tests/metagpt/provider/test_base_llm.py +++ b/tests/metagpt/provider/test_base_llm.py @@ -17,7 +17,7 @@ from tests.metagpt.provider.req_resp_const import ( prompt, ) -llm_name = "GPT" +name = "GPT" class MockBaseLLM(BaseLLM): @@ -25,10 +25,10 @@ class MockBaseLLM(BaseLLM): 
         pass
 
     def completion(self, messages: list[dict], timeout=3):
-        return get_part_chat_completion(llm_name)
+        return get_part_chat_completion(name)
 
     async def acompletion(self, messages: list[dict], timeout=3):
-        return get_part_chat_completion(llm_name)
+        return get_part_chat_completion(name)
 
     async def acompletion_text(self, messages: list[dict], stream=False, timeout=3) -> str:
         return default_resp_cont
diff --git a/tests/metagpt/provider/test_fireworks_llm.py b/tests/metagpt/provider/test_fireworks_llm.py
index 834f6305f..e28f7500b 100644
--- a/tests/metagpt/provider/test_fireworks_llm.py
+++ b/tests/metagpt/provider/test_fireworks_llm.py
@@ -21,10 +21,10 @@ from tests.metagpt.provider.req_resp_const import (
     resp_cont_tmpl,
 )
 
-llm_name = "fireworks"
-resp_cont = resp_cont_tmpl.format(name=llm_name)
-default_resp = get_openai_chat_completion(llm_name)
-default_resp_chunk = get_openai_chat_completion_chunk(llm_name, usage_as_dict=True)
+name = "fireworks"
+resp_cont = resp_cont_tmpl.format(name=name)
+default_resp = get_openai_chat_completion(name)
+default_resp_chunk = get_openai_chat_completion_chunk(name, usage_as_dict=True)
 
 
 def test_fireworks_costmanager():
@@ -57,27 +57,27 @@ async def mock_openai_acompletions_create(self, stream: bool = False, **kwargs)
 async def test_fireworks_acompletion(mocker):
     mocker.patch("openai.resources.chat.completions.AsyncCompletions.create", mock_openai_acompletions_create)
 
-    fireworks_gpt = FireworksLLM(mock_llm_config)
-    fireworks_gpt.model = "llama-v2-13b-chat"
+    fireworks_llm = FireworksLLM(mock_llm_config)
+    fireworks_llm.model = "llama-v2-13b-chat"
 
-    fireworks_gpt._update_costs(
+    fireworks_llm._update_costs(
         usage=CompletionUsage(prompt_tokens=500000, completion_tokens=500000, total_tokens=1000000)
     )
-    assert fireworks_gpt.get_costs() == Costs(
+    assert fireworks_llm.get_costs() == Costs(
         total_prompt_tokens=500000, total_completion_tokens=500000, total_cost=0.5, total_budget=0
     )
 
-    resp = await fireworks_gpt.acompletion(messages)
+    resp = await fireworks_llm.acompletion(messages)
     assert resp.choices[0].message.content in resp_cont
 
-    resp = await fireworks_gpt.aask(prompt, stream=False)
+    resp = await fireworks_llm.aask(prompt, stream=False)
     assert resp == resp_cont
 
-    resp = await fireworks_gpt.acompletion_text(messages, stream=False)
+    resp = await fireworks_llm.acompletion_text(messages, stream=False)
     assert resp == resp_cont
 
-    resp = await fireworks_gpt.acompletion_text(messages, stream=True)
+    resp = await fireworks_llm.acompletion_text(messages, stream=True)
     assert resp == resp_cont
 
-    resp = await fireworks_gpt.aask(prompt)
+    resp = await fireworks_llm.aask(prompt)
     assert resp == resp_cont
diff --git a/tests/metagpt/provider/test_google_gemini_api.py b/tests/metagpt/provider/test_google_gemini_api.py
index ad0c7bbfe..dae9d123b 100644
--- a/tests/metagpt/provider/test_google_gemini_api.py
+++ b/tests/metagpt/provider/test_google_gemini_api.py
@@ -63,28 +63,28 @@ async def test_gemini_acompletion(mocker):
         mock_gemini_generate_content_async,
     )
 
-    gemini_gpt = GeminiLLM(mock_llm_config)
+    gemini_llm = GeminiLLM(mock_llm_config)
 
-    assert gemini_gpt._user_msg(prompt) == {"role": "user", "parts": [prompt]}
-    assert gemini_gpt._assistant_msg(prompt) == {"role": "model", "parts": [prompt]}
+    assert gemini_llm._user_msg(prompt) == {"role": "user", "parts": [prompt]}
+    assert gemini_llm._assistant_msg(prompt) == {"role": "model", "parts": [prompt]}
 
-    usage = gemini_gpt.get_usage(gemini_messages, resp_cont)
+    usage = gemini_llm.get_usage(gemini_messages, resp_cont)
     assert usage == {"prompt_tokens": 20, "completion_tokens": 20}
 
-    resp = gemini_gpt.completion(gemini_messages)
+    resp = gemini_llm.completion(gemini_messages)
     assert resp == default_resp
 
-    resp = await gemini_gpt.acompletion(gemini_messages)
+    resp = await gemini_llm.acompletion(gemini_messages)
     assert resp.text == default_resp.text
 
-    resp = await gemini_gpt.aask(prompt, stream=False)
+    resp = await gemini_llm.aask(prompt, stream=False)
     assert resp == resp_cont
 
-    resp = await gemini_gpt.acompletion_text(gemini_messages, stream=False)
+    resp = await gemini_llm.acompletion_text(gemini_messages, stream=False)
     assert resp == resp_cont
 
-    resp = await gemini_gpt.acompletion_text(gemini_messages, stream=True)
+    resp = await gemini_llm.acompletion_text(gemini_messages, stream=True)
     assert resp == resp_cont
 
-    resp = await gemini_gpt.aask(prompt)
+    resp = await gemini_llm.aask(prompt)
     assert resp == resp_cont
diff --git a/tests/metagpt/provider/test_ollama_api.py b/tests/metagpt/provider/test_ollama_api.py
index 8e2625e35..01d53251c 100644
--- a/tests/metagpt/provider/test_ollama_api.py
+++ b/tests/metagpt/provider/test_ollama_api.py
@@ -39,19 +39,19 @@ async def mock_ollama_arequest(self, stream: bool = False, **kwargs) -> Tuple[An
 async def test_gemini_acompletion(mocker):
     mocker.patch("metagpt.provider.general_api_requestor.GeneralAPIRequestor.arequest", mock_ollama_arequest)
 
-    ollama_gpt = OllamaLLM(mock_llm_config)
+    ollama_llm = OllamaLLM(mock_llm_config)
 
-    resp = await ollama_gpt.acompletion(messages)
+    resp = await ollama_llm.acompletion(messages)
     assert resp["message"]["content"] == default_resp["message"]["content"]
 
-    resp = await ollama_gpt.aask(prompt, stream=False)
+    resp = await ollama_llm.aask(prompt, stream=False)
     assert resp == resp_cont
 
-    resp = await ollama_gpt.acompletion_text(messages, stream=False)
+    resp = await ollama_llm.acompletion_text(messages, stream=False)
     assert resp == resp_cont
 
-    resp = await ollama_gpt.acompletion_text(messages, stream=True)
+    resp = await ollama_llm.acompletion_text(messages, stream=True)
     assert resp == resp_cont
 
-    resp = await ollama_gpt.aask(prompt)
+    resp = await ollama_llm.aask(prompt)
     assert resp == resp_cont
diff --git a/tests/metagpt/provider/test_open_llm_api.py b/tests/metagpt/provider/test_open_llm_api.py
index 5b8a506e9..b2e759d06 100644
--- a/tests/metagpt/provider/test_open_llm_api.py
+++ b/tests/metagpt/provider/test_open_llm_api.py
@@ -17,11 +17,11 @@ from tests.metagpt.provider.req_resp_const import (
     resp_cont_tmpl,
 )
 
-llm_name = "llama2-7b"
-resp_cont = resp_cont_tmpl.format(name=llm_name)
-default_resp = get_openai_chat_completion(llm_name)
+name = "llama2-7b"
+resp_cont = resp_cont_tmpl.format(name=name)
+default_resp = get_openai_chat_completion(name)
 
-default_resp_chunk = get_openai_chat_completion_chunk(llm_name)
+default_resp_chunk = get_openai_chat_completion_chunk(name)
 
 
 async def mock_openai_acompletions_create(self, stream: bool = False, **kwargs) -> ChatCompletionChunk:
@@ -40,26 +40,26 @@ async def mock_openai_acompletions_create(self, stream: bool = False, **kwargs)
 async def test_openllm_acompletion(mocker):
     mocker.patch("openai.resources.chat.completions.AsyncCompletions.create", mock_openai_acompletions_create)
 
-    openllm_gpt = OpenLLM(mock_llm_config)
-    openllm_gpt.model = "llama-v2-13b-chat"
+    openllm_llm = OpenLLM(mock_llm_config)
+    openllm_llm.model = "llama-v2-13b-chat"
 
-    openllm_gpt.cost_manager = CostManager()
-    openllm_gpt._update_costs(usage=CompletionUsage(prompt_tokens=100, completion_tokens=100, total_tokens=200))
-    assert openllm_gpt.get_costs() == Costs(
+    openllm_llm.cost_manager = CostManager()
+    openllm_llm._update_costs(usage=CompletionUsage(prompt_tokens=100, completion_tokens=100, total_tokens=200))
+    assert openllm_llm.get_costs() == Costs(
         total_prompt_tokens=100, total_completion_tokens=100, total_cost=0, total_budget=0
     )
 
-    resp = await openllm_gpt.acompletion(messages)
+    resp = await openllm_llm.acompletion(messages)
     assert resp.choices[0].message.content in resp_cont
 
-    resp = await openllm_gpt.aask(prompt, stream=False)
+    resp = await openllm_llm.aask(prompt, stream=False)
     assert resp == resp_cont
 
-    resp = await openllm_gpt.acompletion_text(messages, stream=False)
+    resp = await openllm_llm.acompletion_text(messages, stream=False)
     assert resp == resp_cont
 
-    resp = await openllm_gpt.acompletion_text(messages, stream=True)
+    resp = await openllm_llm.acompletion_text(messages, stream=True)
     assert resp == resp_cont
 
-    resp = await openllm_gpt.aask(prompt)
+    resp = await openllm_llm.aask(prompt)
     assert resp == resp_cont
diff --git a/tests/metagpt/provider/test_qianfan_api.py b/tests/metagpt/provider/test_qianfan_api.py
index 76271b1e8..30ac06911 100644
--- a/tests/metagpt/provider/test_qianfan_api.py
+++ b/tests/metagpt/provider/test_qianfan_api.py
@@ -2,14 +2,53 @@
 # -*- coding: utf-8 -*-
 # @Desc   : the unittest of qianfan api
 
+from typing import AsyncIterator, Union
+
 import pytest
+from qianfan.resources.typing import QfResponse
 
 from metagpt.provider.qianfan_api import QianFanLLM
-from tests.metagpt.provider.req_resp_const import prompt, messages, resp_cont_tmpl
+from tests.metagpt.provider.mock_llm_config import mock_llm_config_qianfan
+from tests.metagpt.provider.req_resp_const import (
+    get_qianfan_response,
+    llm_general_chat_funcs_test,
+    messages,
+    prompt,
+    resp_cont_tmpl,
+)
 
-resp_cont = resp_cont_tmpl.format(name="ERNIE-Bot-turbo")
+name = "ERNIE-Bot-turbo"
+resp_cont = resp_cont_tmpl.format(name=name)
 
 
-def test_qianfan_acompletion(mocker):
-    assert True, True
+def mock_qianfan_do(self, messages: list[dict], model: str, stream: bool = False, system: str = None) -> QfResponse:
+    return get_qianfan_response(name=name)
+
+
+async def mock_qianfan_ado(
+    self, messages: list[dict], model: str, stream: bool = True, system: str = None
+) -> Union[QfResponse, AsyncIterator[QfResponse]]:
+    resps = [get_qianfan_response(name=name)]
+    if stream:
+        async def aresp_iterator(resps: list[QfResponse]):
+            for resp in resps:
+                yield resp
+        return aresp_iterator(resps)
+    else:
+        return resps[0]
+
+
+@pytest.mark.asyncio
+async def test_qianfan_acompletion(mocker):
+    mocker.patch("qianfan.resources.llm.chat_completion.ChatCompletion.do", mock_qianfan_do)
+    mocker.patch("qianfan.resources.llm.chat_completion.ChatCompletion.ado", mock_qianfan_ado)
+
+    qianfan_llm = QianFanLLM(mock_llm_config_qianfan)
+
+    resp = qianfan_llm.completion(messages)
+    assert resp.get("result") == resp_cont
+
+    resp = await qianfan_llm.acompletion(messages)
+    assert resp.get("result") == resp_cont
+
+    await llm_general_chat_funcs_test(qianfan_llm, prompt, messages, resp_cont)
diff --git a/tests/metagpt/provider/test_spark_api.py b/tests/metagpt/provider/test_spark_api.py
index 32a839393..8aa8bc7a8 100644
--- a/tests/metagpt/provider/test_spark_api.py
+++ b/tests/metagpt/provider/test_spark_api.py
@@ -50,19 +50,19 @@ async def test_spark_aask(mocker):
 async def test_spark_acompletion(mocker):
     mocker.patch("metagpt.provider.spark_api.GetMessageFromWeb.run", mock_spark_get_msg_from_web_run)
 
-    spark_gpt = SparkLLM(mock_llm_config)
+    spark_llm = SparkLLM(mock_llm_config)
 
-    resp = await spark_gpt.acompletion([])
+    resp = await spark_llm.acompletion([])
     assert resp == resp_cont
 
-    resp = await spark_gpt.aask(prompt, stream=False)
+    resp = await spark_llm.aask(prompt, stream=False)
     assert resp == resp_cont
 
-    resp = await spark_gpt.acompletion_text([], stream=False)
+    resp = await spark_llm.acompletion_text([], stream=False)
     assert resp == resp_cont
 
-    resp = await spark_gpt.acompletion_text([], stream=True)
+    resp = await spark_llm.acompletion_text([], stream=True)
     assert resp == resp_cont
 
-    resp = await spark_gpt.aask(prompt)
+    resp = await spark_llm.aask(prompt)
     assert resp == resp_cont
diff --git a/tests/metagpt/provider/test_zhipuai_api.py b/tests/metagpt/provider/test_zhipuai_api.py
index 064562bff..3dada367c 100644
--- a/tests/metagpt/provider/test_zhipuai_api.py
+++ b/tests/metagpt/provider/test_zhipuai_api.py
@@ -11,11 +11,12 @@ from tests.metagpt.provider.req_resp_const import (
+    llm_general_chat_funcs_test,
     messages,
     prompt,
     resp_cont_tmpl,
 )
 
-llm_name = "ChatGLM-4"
-resp_cont = resp_cont_tmpl.format(name=llm_name)
-default_resp = get_part_chat_completion(llm_name)
+name = "ChatGLM-4"
+resp_cont = resp_cont_tmpl.format(name=name)
+default_resp = get_part_chat_completion(name)
 
 
 async def mock_zhipuai_acreate_stream(**kwargs):
@@ -47,22 +48,12 @@ async def test_zhipuai_acompletion(mocker):
     mocker.patch("metagpt.provider.zhipuai.zhipu_model_api.ZhiPuModelAPI.acreate", mock_zhipuai_acreate)
     mocker.patch("metagpt.provider.zhipuai.zhipu_model_api.ZhiPuModelAPI.acreate_stream", mock_zhipuai_acreate_stream)
 
-    zhipu_gpt = ZhiPuAILLM(mock_llm_config_zhipu)
+    zhipu_llm = ZhiPuAILLM(mock_llm_config_zhipu)
 
-    resp = await zhipu_gpt.acompletion(messages)
+    resp = await zhipu_llm.acompletion(messages)
     assert resp["choices"][0]["message"]["content"] == resp_cont
 
-    resp = await zhipu_gpt.aask(prompt, stream=False)
-    assert resp == resp_cont
-
-    resp = await zhipu_gpt.acompletion_text(messages, stream=False)
-    assert resp == resp_cont
-
-    resp = await zhipu_gpt.acompletion_text(messages, stream=True)
-    assert resp == resp_cont
-
-    resp = await zhipu_gpt.aask(prompt)
-    assert resp == resp_cont
+    await llm_general_chat_funcs_test(zhipu_llm, prompt, messages, resp_cont)
 
 
 def test_zhipuai_proxy():