From 454e6164fb804bba1fcc58797140e3ee15e137ab Mon Sep 17 00:00:00 2001 From: better629 Date: Mon, 25 Dec 2023 18:00:51 +0800 Subject: [PATCH] update provider unittests --- metagpt/provider/anthropic_api.py | 10 +- metagpt/provider/base_gpt_api.py | 2 +- metagpt/provider/fireworks_api.py | 4 +- metagpt/provider/google_gemini_api.py | 7 +- metagpt/provider/ollama_api.py | 7 +- metagpt/provider/spark_api.py | 11 +- metagpt/provider/zhipuai_api.py | 5 +- tests/metagpt/provider/test_anthropic_api.py | 29 +++++ tests/metagpt/provider/test_base_gpt_api.py | 100 +++++++++++++++++- tests/metagpt/provider/test_fireworks_api.py | 67 +++++++++--- .../provider/test_general_api_requestor.py | 20 ++++ .../provider/test_google_gemini_api.py | 53 +++++++--- tests/metagpt/provider/test_human_provider.py | 38 +++++++ .../metagpt/provider/test_metagpt_llm_api.py | 4 +- tests/metagpt/provider/test_ollama_api.py | 52 ++++++--- tests/metagpt/provider/test_openai.py | 19 +++- tests/metagpt/provider/test_spark_api.py | 56 ++++++++-- tests/metagpt/provider/test_zhipuai_api.py | 54 +++++++--- 18 files changed, 460 insertions(+), 78 deletions(-) create mode 100644 tests/metagpt/provider/test_anthropic_api.py create mode 100644 tests/metagpt/provider/test_general_api_requestor.py create mode 100644 tests/metagpt/provider/test_human_provider.py diff --git a/metagpt/provider/anthropic_api.py b/metagpt/provider/anthropic_api.py index f5b06c855..b9d7d9e38 100644 --- a/metagpt/provider/anthropic_api.py +++ b/metagpt/provider/anthropic_api.py @@ -7,13 +7,13 @@ """ import anthropic -from anthropic import Anthropic +from anthropic import Anthropic, AsyncAnthropic from metagpt.config import CONFIG class Claude2: - def ask(self, prompt): + def ask(self, prompt: str) -> str: client = Anthropic(api_key=CONFIG.anthropic_api_key) res = client.completions.create( @@ -23,10 +23,10 @@ class Claude2: ) return res.completion - async def aask(self, prompt): - client = Anthropic(api_key=CONFIG.anthropic_api_key) + async def aask(self, prompt: str) -> str: + aclient = AsyncAnthropic(api_key=CONFIG.anthropic_api_key) - res = client.completions.create( + res = await aclient.completions.create( model="claude-2", prompt=f"{anthropic.HUMAN_PROMPT} {prompt} {anthropic.AI_PROMPT}", max_tokens_to_sample=1000, diff --git a/metagpt/provider/base_gpt_api.py b/metagpt/provider/base_gpt_api.py index f650305e3..a5541324f 100644 --- a/metagpt/provider/base_gpt_api.py +++ b/metagpt/provider/base_gpt_api.py @@ -162,7 +162,7 @@ class BaseGPTAPI(BaseChatbot): def messages_to_prompt(self, messages: list[dict]): """[{"role": "user", "content": msg}] to user: etc.""" - return "\n".join([f"{i['role']}: {i['content']}" for i in messages]) + return "\n".join([f"{i.role}: {i.content}" for i in messages]) def messages_to_dict(self, messages): """objects to [{"role": "user", "content": msg}] etc.""" diff --git a/metagpt/provider/fireworks_api.py b/metagpt/provider/fireworks_api.py index 96b7db453..55b1b6c28 100644 --- a/metagpt/provider/fireworks_api.py +++ b/metagpt/provider/fireworks_api.py @@ -133,7 +133,9 @@ class FireWorksGPTAPI(OpenAIGPTAPI): retry=retry_if_exception_type(APIConnectionError), retry_error_callback=log_and_reraise, ) - async def acompletion_text(self, messages: list[dict], stream=False, generator: bool = False, timeout=3) -> str: + async def acompletion_text( + self, messages: list[dict], stream=False, generator: bool = False, timeout: int = 3 + ) -> str: """when streaming, print each token in place.""" if stream: return await self._achat_completion_stream(messages) diff --git a/metagpt/provider/google_gemini_api.py b/metagpt/provider/google_gemini_api.py index eb91cc32b..e9d3ea70d 100644 --- a/metagpt/provider/google_gemini_api.py +++ b/metagpt/provider/google_gemini_api.py @@ -79,6 +79,9 @@ class GeminiGPTAPI(BaseGPTAPI): except Exception as e: logger.error(f"google gemini updats costs failed! exp: {e}") + def close(self): + pass + def get_choice_text(self, resp: GenerateContentResponse) -> str: return resp.text @@ -133,7 +136,9 @@ class GeminiGPTAPI(BaseGPTAPI): retry=retry_if_exception_type(ConnectionError), retry_error_callback=log_and_reraise, ) - async def acompletion_text(self, messages: list[dict], stream=False) -> str: + async def acompletion_text( + self, messages: list[dict], stream=False, generator: bool = False, timeout: int = 3 + ) -> str: """response in async with stream or non-stream mode""" if stream: return await self._achat_completion_stream(messages) diff --git a/metagpt/provider/ollama_api.py b/metagpt/provider/ollama_api.py index 05bdb5a1f..7d858e769 100644 --- a/metagpt/provider/ollama_api.py +++ b/metagpt/provider/ollama_api.py @@ -57,6 +57,9 @@ class OllamaGPTAPI(BaseGPTAPI): self.model = config.ollama_api_model + def close(self): + pass + def _const_kwargs(self, messages: list[dict], stream: bool = False) -> dict: kwargs = {"model": self.model, "messages": messages, "options": {"temperature": 0.3}, "stream": stream} return kwargs @@ -144,7 +147,9 @@ class OllamaGPTAPI(BaseGPTAPI): retry=retry_if_exception_type(ConnectionError), retry_error_callback=log_and_reraise, ) - async def acompletion_text(self, messages: list[dict], stream=False) -> str: + async def acompletion_text( + self, messages: list[dict], stream=False, generator: bool = False, timeout: int = 3 + ) -> str: """response in async with stream or non-stream mode""" if stream: return await self._achat_completion_stream(messages) diff --git a/metagpt/provider/spark_api.py b/metagpt/provider/spark_api.py index 484fa7956..70076bc86 100644 --- a/metagpt/provider/spark_api.py +++ b/metagpt/provider/spark_api.py @@ -26,16 +26,19 @@ from metagpt.provider.llm_provider_registry import register_provider @register_provider(LLMProviderEnum.SPARK) -class SparkAPI(BaseGPTAPI): +class SparkGPTAPI(BaseGPTAPI): def __init__(self): logger.warning("当前方法无法支持异步运行。当你使用acompletion时,并不能并行访问。") + def close(self): + pass + def ask(self, msg: str) -> str: message = [self._default_system_msg(), self._user_msg(msg)] rsp = self.completion(message) return rsp - async def aask(self, msg: str, system_msgs: Optional[list[str]] = None) -> str: + async def aask(self, msg: str, system_msgs: Optional[list[str]] = None, stream: bool = True) -> str: if system_msgs: message = self._system_msgs(system_msgs) + [self._user_msg(msg)] else: @@ -47,7 +50,9 @@ class SparkAPI(BaseGPTAPI): def get_choice_text(self, rsp: dict) -> str: return rsp["payload"]["choices"]["text"][-1]["content"] - async def acompletion_text(self, messages: list[dict], stream=False) -> str: + async def acompletion_text( + self, messages: list[dict], stream=False, generator: bool = False, timeout: int = 3 + ) -> str: # 不支持 logger.error("该功能禁用。") w = GetMessageFromWeb(messages) diff --git a/metagpt/provider/zhipuai_api.py b/metagpt/provider/zhipuai_api.py index 4a2cae51d..0d5663431 100644 --- a/metagpt/provider/zhipuai_api.py +++ b/metagpt/provider/zhipuai_api.py @@ -64,6 +64,9 @@ class ZhiPuAIGPTAPI(BaseGPTAPI): except Exception as e: logger.error(f"zhipuai updats costs failed! exp: {e}") + def close(self): + pass + def get_choice_text(self, resp: dict) -> str: """get the first text of choice from llm response""" assist_msg = resp.get("data", {}).get("choices", [{"role": "error"}])[-1] @@ -131,6 +134,6 @@ class ZhiPuAIGPTAPI(BaseGPTAPI): async def acompletion_text(self, messages: list[dict], stream=False, generator: bool = False, timeout=3) -> str: """response in async with stream or non-stream mode""" if stream: - return await self._achat_completion_stream(messages, timeout=timeout) + return await self._achat_completion_stream(messages) resp = await self._achat_completion(messages) return self.get_choice_text(resp) diff --git a/tests/metagpt/provider/test_anthropic_api.py b/tests/metagpt/provider/test_anthropic_api.py new file mode 100644 index 000000000..4d3de5320 --- /dev/null +++ b/tests/metagpt/provider/test_anthropic_api.py @@ -0,0 +1,29 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : the unittest of Claude2 + +import pytest + +from metagpt.provider.anthropic_api import Claude2 + +prompt = "who are you" +resp = "I'am Claude2" + + +def mock_llm_ask(self, msg: str) -> str: + return resp + + +async def mock_llm_aask(self, msg: str) -> str: + return resp + + +def test_claude2_ask(mocker): + mocker.patch("metagpt.provider.anthropic_api.Claude2.ask", mock_llm_ask) + assert resp == Claude2().ask(prompt) + + +@pytest.mark.asyncio +async def test_claude2_aask(mocker): + mocker.patch("metagpt.provider.anthropic_api.Claude2.aask", mock_llm_aask) + assert resp == await Claude2().aask(prompt) diff --git a/tests/metagpt/provider/test_base_gpt_api.py b/tests/metagpt/provider/test_base_gpt_api.py index 6cfe3b02d..aaa7b64ff 100644 --- a/tests/metagpt/provider/test_base_gpt_api.py +++ b/tests/metagpt/provider/test_base_gpt_api.py @@ -6,10 +6,106 @@ @File : test_base_gpt_api.py """ +import pytest + +from metagpt.provider.base_gpt_api import BaseGPTAPI from metagpt.schema import Message +default_chat_resp = { + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": "I'am GPT", + }, + "finish_reason": "stop", + } + ] +} +prompt_msg = "who are you" +resp_content = default_chat_resp["choices"][0]["message"]["content"] -def test_message(): - message = Message(role="user", content="wtf") + +class MockBaseGPTAPI(BaseGPTAPI): + def completion(self, messages: list[dict], timeout=3): + return default_chat_resp + + async def acompletion(self, messages: list[dict], timeout=3): + return default_chat_resp + + async def acompletion_text(self, messages: list[dict], stream=False, generator: bool = False, timeout=3) -> str: + return resp_content + + async def close(self): + return default_chat_resp + + +def test_base_gpt_api(): + message = Message(role="user", content="hello") assert "role" in message.to_dict() assert "user" in str(message) + + base_gpt_api = MockBaseGPTAPI() + msg_prompt = base_gpt_api.messages_to_prompt([message]) + assert msg_prompt == "user: hello" + + msg_dict = base_gpt_api.messages_to_dict([message]) + assert msg_dict == [{"role": "user", "content": "hello"}] + + openai_funccall_resp = { + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": "test", + "tool_calls": [ + { + "id": "call_Y5r6Ddr2Qc2ZrqgfwzPX5l72", + "type": "function", + "function": { + "name": "execute", + "arguments": '{\n "language": "python",\n "code": "print(\'Hello, World!\')"\n}', + }, + } + ], + }, + "finish_reason": "stop", + } + ] + } + func: dict = base_gpt_api.get_choice_function(openai_funccall_resp) + assert func == { + "name": "execute", + "arguments": '{\n "language": "python",\n "code": "print(\'Hello, World!\')"\n}', + } + + func_args: dict = base_gpt_api.get_choice_function_arguments(openai_funccall_resp) + assert func_args == {"language": "python", "code": "print('Hello, World!')"} + + choice_text = base_gpt_api.get_choice_text(openai_funccall_resp) + assert choice_text == openai_funccall_resp["choices"][0]["message"]["content"] + + resp = base_gpt_api.ask(prompt_msg) + assert resp == resp_content + + resp = base_gpt_api.ask_batch([prompt_msg]) + assert resp == resp_content + + resp = base_gpt_api.ask_code([prompt_msg]) + assert resp == resp_content + + +@pytest.mark.asyncio +async def test_async_base_gpt_api(): + base_gpt_api = MockBaseGPTAPI() + + resp = await base_gpt_api.aask(prompt_msg) + assert resp == resp_content + + resp = await base_gpt_api.aask_batch([prompt_msg]) + assert resp == resp_content + + resp = await base_gpt_api.aask_code([prompt_msg]) + assert resp == resp_content diff --git a/tests/metagpt/provider/test_fireworks_api.py b/tests/metagpt/provider/test_fireworks_api.py index 43e45adf3..caf8b9f45 100644 --- a/tests/metagpt/provider/test_fireworks_api.py +++ b/tests/metagpt/provider/test_fireworks_api.py @@ -10,41 +10,82 @@ from openai.types.chat.chat_completion import ( ) from openai.types.completion_usage import CompletionUsage -from metagpt.provider.fireworks_api import FireWorksGPTAPI +from metagpt.provider.fireworks_api import ( + MODEL_GRADE_TOKEN_COSTS, + FireworksCostManager, + FireWorksGPTAPI, +) +resp_content = "I'm fireworks" default_resp = ChatCompletion( id="cmpl-a6652c1bb181caae8dd19ad8", model="accounts/fireworks/models/llama-v2-13b-chat", object="chat.completion", created=1703300855, choices=[ - Choice(finish_reason="stop", index=0, message=ChatCompletionMessage(role="assistant", content="I'm fireworks")) + Choice(finish_reason="stop", index=0, message=ChatCompletionMessage(role="assistant", content=resp_content)) ], usage=CompletionUsage(completion_tokens=110, prompt_tokens=92, total_tokens=202), ) -messages = [{"role": "user", "content": "who are you"}] +prompt_msg = "who are you" +messages = [{"role": "user", "content": prompt_msg}] -def mock_llm_ask(self, messages: list[dict]) -> ChatCompletion: +def test_fireworks_costmanager(): + cost_manager = FireworksCostManager() + assert MODEL_GRADE_TOKEN_COSTS["-1"] == cost_manager.model_grade_token_costs("test") + assert MODEL_GRADE_TOKEN_COSTS["-1"] == cost_manager.model_grade_token_costs("xxx-81b-chat") + assert MODEL_GRADE_TOKEN_COSTS["16"] == cost_manager.model_grade_token_costs("llama-v2-13b-chat") + assert MODEL_GRADE_TOKEN_COSTS["16"] == cost_manager.model_grade_token_costs("xxx-15.5b-chat") + assert MODEL_GRADE_TOKEN_COSTS["16"] == cost_manager.model_grade_token_costs("xxx-16b-chat") + assert MODEL_GRADE_TOKEN_COSTS["80"] == cost_manager.model_grade_token_costs("xxx-80b-chat") + assert MODEL_GRADE_TOKEN_COSTS["mixtral-8x7b"] == cost_manager.model_grade_token_costs("mixtral-8x7b-chat") + + +def mock_llm_completion(self, messages: list[dict], timeout: int = 60) -> ChatCompletion: return default_resp +async def mock_llm_acompletion(self, messgaes: list[dict], stream: bool = False, timeout: int = 60) -> ChatCompletion: + return default_resp + + +async def mock_llm_achat_completion_stream(self, messgaes: list[dict]) -> str: + return default_resp.choices[0].message.content + + def test_fireworks_completion(mocker): - mocker.patch("metagpt.provider.fireworks_api.FireWorksGPTAPI.completion", mock_llm_ask) + mocker.patch("metagpt.provider.fireworks_api.FireWorksGPTAPI.completion", mock_llm_completion) + fireworks_gpt = FireWorksGPTAPI() - resp = FireWorksGPTAPI().completion(messages) - assert "fireworks" in resp.choices[0].message.content + resp = fireworks_gpt.completion(messages) + assert resp.choices[0].message.content == resp_content - -async def mock_llm_aask(self, messgaes: list[dict], stream: bool = False) -> ChatCompletion: - return default_resp + resp = fireworks_gpt.ask(prompt_msg) + assert resp == resp_content @pytest.mark.asyncio async def test_fireworks_acompletion(mocker): - mocker.patch("metagpt.provider.fireworks_api.FireWorksGPTAPI.acompletion", mock_llm_aask) + mocker.patch("metagpt.provider.fireworks_api.FireWorksGPTAPI.acompletion", mock_llm_acompletion) + mocker.patch("metagpt.provider.fireworks_api.FireWorksGPTAPI._achat_completion", mock_llm_acompletion) + mocker.patch( + "metagpt.provider.fireworks_api.FireWorksGPTAPI._achat_completion_stream", mock_llm_achat_completion_stream + ) + fireworks_gpt = FireWorksGPTAPI() - resp = await FireWorksGPTAPI().acompletion(messages, stream=False) + resp = await fireworks_gpt.acompletion(messages, stream=False) + assert resp.choices[0].message.content in resp_content - assert "fireworks" in resp.choices[0].message.content + resp = await fireworks_gpt.aask(prompt_msg, stream=False) + assert resp == resp_content + + resp = await fireworks_gpt.acompletion_text(messages, stream=False) + assert resp == resp_content + + resp = await fireworks_gpt.acompletion_text(messages, stream=True) + assert resp == resp_content + + resp = await fireworks_gpt.aask(prompt_msg) + assert resp == resp_content diff --git a/tests/metagpt/provider/test_general_api_requestor.py b/tests/metagpt/provider/test_general_api_requestor.py new file mode 100644 index 000000000..28130fa65 --- /dev/null +++ b/tests/metagpt/provider/test_general_api_requestor.py @@ -0,0 +1,20 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : the unittest of APIRequestor + +import pytest + +from metagpt.provider.general_api_requestor import GeneralAPIRequestor + +api_requestor = GeneralAPIRequestor(base_url="http://www.baidu.com") + + +def test_api_requestor(): + resp, _, _ = api_requestor.request(method="get", url="/s?wd=baidu") + assert b"baidu" in resp + + +@pytest.mark.asyncio +async def test_async_api_requestor(): + resp, _, _ = await api_requestor.arequest(method="get", url="/s?wd=baidu") + assert b"baidu" in resp diff --git a/tests/metagpt/provider/test_google_gemini_api.py b/tests/metagpt/provider/test_google_gemini_api.py index 9c8cf46c0..aec7b8520 100644 --- a/tests/metagpt/provider/test_google_gemini_api.py +++ b/tests/metagpt/provider/test_google_gemini_api.py @@ -9,33 +9,62 @@ import pytest from metagpt.provider.google_gemini_api import GeminiGPTAPI -messages = [{"role": "user", "parts": "who are you"}] - @dataclass class MockGeminiResponse(ABC): text: str -default_resp = MockGeminiResponse(text="I'm gemini from google") +prompt_msg = "who are you" +messages = [{"role": "user", "parts": prompt_msg}] +resp_content = "I'm gemini from google" +default_resp = MockGeminiResponse(text=resp_content) -def mock_llm_ask(self, messages: list[dict]) -> MockGeminiResponse: +def mock_llm_completion(self, messages: list[dict], timeout: int = 60) -> MockGeminiResponse: return default_resp +async def mock_llm_acompletion( + self, messgaes: list[dict], stream: bool = False, timeout: int = 60 +) -> MockGeminiResponse: + return default_resp + + +async def mock_llm_achat_completion_stream(self, messgaes: list[dict]) -> str: + return resp_content + + def test_gemini_completion(mocker): - mocker.patch("metagpt.provider.google_gemini_api.GeminiGPTAPI.completion", mock_llm_ask) - resp = GeminiGPTAPI().completion(messages) - assert resp.text == default_resp.text + mocker.patch("metagpt.provider.google_gemini_api.GeminiGPTAPI.completion", mock_llm_completion) + gemini_gpt = GeminiGPTAPI() + resp = gemini_gpt.completion(messages) + assert resp.text == resp_content - -async def mock_llm_aask(self, messgaes: list[dict]) -> MockGeminiResponse: - return default_resp + resp = gemini_gpt.ask(prompt_msg) + assert resp == resp_content @pytest.mark.asyncio async def test_gemini_acompletion(mocker): - mocker.patch("metagpt.provider.google_gemini_api.GeminiGPTAPI.acompletion", mock_llm_aask) - resp = await GeminiGPTAPI().acompletion(messages) + mocker.patch("metagpt.provider.google_gemini_api.GeminiGPTAPI.acompletion", mock_llm_acompletion) + mocker.patch("metagpt.provider.google_gemini_api.GeminiGPTAPI._achat_completion", mock_llm_acompletion) + mocker.patch( + "metagpt.provider.google_gemini_api.GeminiGPTAPI._achat_completion_stream", mock_llm_achat_completion_stream + ) + gemini_gpt = GeminiGPTAPI() + + resp = await gemini_gpt.acompletion(messages) assert resp.text == default_resp.text + + resp = await gemini_gpt.aask(prompt_msg, stream=False) + assert resp == resp_content + + resp = await gemini_gpt.acompletion_text(messages, stream=False) + assert resp == resp_content + + resp = await gemini_gpt.acompletion_text(messages, stream=True) + assert resp == resp_content + + resp = await gemini_gpt.aask(prompt_msg) + assert resp == resp_content diff --git a/tests/metagpt/provider/test_human_provider.py b/tests/metagpt/provider/test_human_provider.py new file mode 100644 index 000000000..caab9f15f --- /dev/null +++ b/tests/metagpt/provider/test_human_provider.py @@ -0,0 +1,38 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : the unittest of HumanProvider + +import pytest + +from metagpt.provider.human_provider import HumanProvider + +resp_content = "test" + + +def mock_llm_ask(msg: str, timeout: int = 3) -> str: + return resp_content + + +async def mock_llm_aask(msg: str, timeout: int = 3) -> str: + return mock_llm_ask(msg) + + +def test_human_provider(mocker): + mocker.patch("metagpt.provider.human_provider.HumanProvider.ask", mock_llm_ask) + human_provider = HumanProvider() + + assert resp_content == human_provider.ask(None) + + assert not human_provider.completion(messages=[]) + + +@pytest.mark.asyncio +async def test_async_human_provider(mocker): + mocker.patch("metagpt.provider.human_provider.HumanProvider.aask", mock_llm_aask) + human_provider = HumanProvider() + + resp = await human_provider.aask(None) + assert resp_content == resp + + resp = await human_provider.acompletion([]) + assert not resp diff --git a/tests/metagpt/provider/test_metagpt_llm_api.py b/tests/metagpt/provider/test_metagpt_llm_api.py index 9c8356ca6..f454b08a7 100644 --- a/tests/metagpt/provider/test_metagpt_llm_api.py +++ b/tests/metagpt/provider/test_metagpt_llm_api.py @@ -5,11 +5,11 @@ @Author : mashenquan @File : test_metagpt_llm_api.py """ -from metagpt.provider.metagpt_llm_api import MetaGPTLLMAPI +from metagpt.provider.metagpt_api import MetaGPTAPI def test_metagpt(): - llm = MetaGPTLLMAPI() + llm = MetaGPTAPI() assert llm diff --git a/tests/metagpt/provider/test_ollama_api.py b/tests/metagpt/provider/test_ollama_api.py index 2798f5cc3..d552d9f9e 100644 --- a/tests/metagpt/provider/test_ollama_api.py +++ b/tests/metagpt/provider/test_ollama_api.py @@ -4,30 +4,58 @@ import pytest +from metagpt.config import CONFIG from metagpt.provider.ollama_api import OllamaGPTAPI -messages = [{"role": "user", "content": "who are you"}] +prompt_msg = "who are you" +messages = [{"role": "user", "content": prompt_msg}] + +resp_content = "I'm ollama" +default_resp = {"message": {"role": "assistant", "content": resp_content}} + +CONFIG.ollama_api_base = "http://xxx" -default_resp = {"message": {"role": "assisant", "content": "I'm ollama"}} - - -def mock_llm_ask(self, messages: list[dict]) -> dict: +def mock_llm_completion(self, messages: list[dict], timeout: int = 60) -> dict: return default_resp +async def mock_llm_acompletion(self, messgaes: list[dict], stream: bool = False, timeout: int = 60) -> dict: + return default_resp + + +async def mock_llm_achat_completion_stream(self, messgaes: list[dict]) -> str: + return resp_content + + def test_gemini_completion(mocker): - mocker.patch("metagpt.provider.ollama_api.OllamaGPTAPI.completion", mock_llm_ask) - resp = OllamaGPTAPI().completion(messages) + mocker.patch("metagpt.provider.ollama_api.OllamaGPTAPI.completion", mock_llm_completion) + ollama_gpt = OllamaGPTAPI() + resp = ollama_gpt.completion(messages) assert resp["message"]["content"] == default_resp["message"]["content"] - -async def mock_llm_aask(self, messgaes: list[dict]) -> dict: - return default_resp + resp = ollama_gpt.ask(prompt_msg) + assert resp == resp_content @pytest.mark.asyncio async def test_gemini_acompletion(mocker): - mocker.patch("metagpt.provider.ollama_api.OllamaGPTAPI.acompletion", mock_llm_aask) - resp = await OllamaGPTAPI().acompletion(messages) + mocker.patch("metagpt.provider.ollama_api.OllamaGPTAPI.acompletion", mock_llm_acompletion) + mocker.patch("metagpt.provider.ollama_api.OllamaGPTAPI._achat_completion", mock_llm_acompletion) + mocker.patch("metagpt.provider.ollama_api.OllamaGPTAPI._achat_completion_stream", mock_llm_achat_completion_stream) + ollama_gpt = OllamaGPTAPI() + + resp = await ollama_gpt.acompletion(messages) assert resp["message"]["content"] == default_resp["message"]["content"] + + resp = await ollama_gpt.aask(prompt_msg, stream=False) + assert resp == resp_content + + resp = await ollama_gpt.acompletion_text(messages, stream=False) + assert resp == resp_content + + resp = await ollama_gpt.acompletion_text(messages, stream=True) + assert resp == resp_content + + resp = await ollama_gpt.aask(prompt_msg) + assert resp == resp_content diff --git a/tests/metagpt/provider/test_openai.py b/tests/metagpt/provider/test_openai.py index 332d554cf..1f25951b1 100644 --- a/tests/metagpt/provider/test_openai.py +++ b/tests/metagpt/provider/test_openai.py @@ -85,14 +85,23 @@ def test_ask_code_list_str(): class TestOpenAI: @pytest.fixture def config(self): - return Mock(openai_api_key="test_key", openai_base_url="test_url", openai_proxy=None, openai_api_type="other") + return Mock( + openai_api_key="test_key", + OPENAI_API_KEY="test_key", + openai_base_url="test_url", + OPENAI_BASE_URL="test_url", + openai_proxy=None, + openai_api_type="other", + ) @pytest.fixture def config_azure(self): return Mock( openai_api_key="test_key", + OPENAI_API_KEY="test_key", openai_api_version="test_version", openai_base_url="test_url", + OPENAI_BASE_URL="test_url", openai_proxy=None, openai_api_type="azure", ) @@ -101,7 +110,9 @@ class TestOpenAI: def config_proxy(self): return Mock( openai_api_key="test_key", + OPENAI_API_KEY="test_key", openai_base_url="test_url", + OPENAI_BASE_URL="test_url", openai_proxy="http://proxy.com", openai_api_type="other", ) @@ -110,8 +121,10 @@ class TestOpenAI: def config_azure_proxy(self): return Mock( openai_api_key="test_key", + OPENAI_API_KEY="test_key", openai_api_version="test_version", openai_base_url="test_url", + OPENAI_BASE_URL="test_url", openai_proxy="http://proxy.com", openai_api_type="azure", ) @@ -129,8 +142,8 @@ class TestOpenAI: instance = OpenAIGPTAPI() instance.config = config_azure kwargs, async_kwargs = instance._make_client_kwargs() - assert kwargs == {"api_key": "test_key", "api_version": "test_version", "azure_endpoint": "test_url"} - assert async_kwargs == {"api_key": "test_key", "api_version": "test_version", "azure_endpoint": "test_url"} + assert kwargs == {"api_key": "test_key", "base_url": "test_url"} + assert async_kwargs == {"api_key": "test_key", "base_url": "test_url"} assert "http_client" not in kwargs assert "http_client" not in async_kwargs diff --git a/tests/metagpt/provider/test_spark_api.py b/tests/metagpt/provider/test_spark_api.py index 3b3dd67f4..61ae8cbec 100644 --- a/tests/metagpt/provider/test_spark_api.py +++ b/tests/metagpt/provider/test_spark_api.py @@ -1,11 +1,51 @@ -from metagpt.logs import logger -from metagpt.provider.spark_api import SparkAPI +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : the unittest of spark api + +import pytest + +from metagpt.provider.spark_api import SparkGPTAPI + +prompt_msg = "who are you" +resp_content = "I'm Spark" -def test_message(): - llm = SparkAPI() +def mock_llm_completion(self, messages: list[dict], timeout: int = 60) -> str: + return resp_content - logger.info(llm.ask('只回答"收到了"这三个字。')) - result = llm.ask("写一篇五百字的日记") - logger.info(result) - assert len(result) > 100 + +async def mock_llm_acompletion(self, messgaes: list[dict], stream: bool = False, timeout: int = 60) -> str: + return resp_content + + +def test_spark_completion(mocker): + mocker.patch("metagpt.provider.spark_api.SparkGPTAPI.completion", mock_llm_completion) + spark_gpt = SparkGPTAPI() + + resp = spark_gpt.completion([]) + assert resp == resp_content + + resp = spark_gpt.ask(prompt_msg) + assert resp == resp_content + + +@pytest.mark.asyncio +async def test_spark_acompletion(mocker): + mocker.patch("metagpt.provider.spark_api.SparkGPTAPI.acompletion", mock_llm_acompletion) + mocker.patch("metagpt.provider.spark_api.SparkGPTAPI.acompletion_text", mock_llm_acompletion) + spark_gpt = SparkGPTAPI() + + resp = await spark_gpt.acompletion([], stream=False) + assert resp == resp_content + + resp = await spark_gpt.aask(prompt_msg, stream=False) + assert resp == resp_content + + resp = await spark_gpt.acompletion_text([], stream=False) + assert resp == resp_content + + resp = await spark_gpt.acompletion_text([], stream=True) + assert resp == resp_content + + resp = await spark_gpt.aask(prompt_msg) + assert resp == resp_content diff --git a/tests/metagpt/provider/test_zhipuai_api.py b/tests/metagpt/provider/test_zhipuai_api.py index 4684e8887..ec02e1b47 100644 --- a/tests/metagpt/provider/test_zhipuai_api.py +++ b/tests/metagpt/provider/test_zhipuai_api.py @@ -4,34 +4,62 @@ import pytest +from metagpt.config import CONFIG from metagpt.provider.zhipuai_api import ZhiPuAIGPTAPI -default_resp = {"code": 200, "data": {"choices": [{"role": "assistant", "content": "I'm chatglm-turbo"}]}} +CONFIG.zhipuai_api_key = "xxx" -messages = [{"role": "user", "content": "who are you"}] +prompt_msg = "who are you" +messages = [{"role": "user", "content": prompt_msg}] + +resp_content = "I'm chatglm-turbo" +default_resp = {"code": 200, "data": {"choices": [{"role": "assistant", "content": resp_content}]}} -def mock_llm_ask(self, messages: list[dict]) -> dict: +def mock_llm_completion(self, messages: list[dict], timeout: int = 60) -> dict: return default_resp +async def mock_llm_acompletion(self, messgaes: list[dict], stream: bool = False, timeout: int = 60) -> dict: + return default_resp + + +async def mock_llm_achat_completion_stream(self, messgaes: list[dict]) -> str: + return resp_content + + def test_zhipuai_completion(mocker): - mocker.patch("metagpt.provider.zhipuai_api.ZhiPuAIGPTAPI.completion", mock_llm_ask) + mocker.patch("metagpt.provider.zhipuai_api.ZhiPuAIGPTAPI.completion", mock_llm_completion) + zhipu_gpt = ZhiPuAIGPTAPI() - resp = ZhiPuAIGPTAPI().completion(messages) + resp = zhipu_gpt.completion(messages) assert resp["code"] == 200 - assert "chatglm-turbo" in resp["data"]["choices"][0]["content"] + assert resp["data"]["choices"][0]["content"] == resp_content - -async def mock_llm_aask(self, messgaes: list[dict], stream: bool = False) -> dict: - return default_resp + resp = zhipu_gpt.ask(prompt_msg) + assert resp == resp_content @pytest.mark.asyncio async def test_zhipuai_acompletion(mocker): - mocker.patch("metagpt.provider.zhipuai_api.ZhiPuAIGPTAPI.acompletion_text", mock_llm_aask) + mocker.patch("metagpt.provider.zhipuai_api.ZhiPuAIGPTAPI.acompletion", mock_llm_acompletion) + mocker.patch("metagpt.provider.zhipuai_api.ZhiPuAIGPTAPI._achat_completion", mock_llm_acompletion) + mocker.patch( + "metagpt.provider.zhipuai_api.ZhiPuAIGPTAPI._achat_completion_stream", mock_llm_achat_completion_stream + ) + zhipu_gpt = ZhiPuAIGPTAPI() - resp = await ZhiPuAIGPTAPI().acompletion_text(messages, stream=False) + resp = await zhipu_gpt.acompletion(messages) + assert resp["data"]["choices"][0]["content"] == resp_content - assert resp["code"] == 200 - assert "chatglm-turbo" in resp["data"]["choices"][0]["content"] + resp = await zhipu_gpt.aask(prompt_msg, stream=False) + assert resp == resp_content + + resp = await zhipu_gpt.acompletion_text(messages, stream=False) + assert resp == resp_content + + resp = await zhipu_gpt.acompletion_text(messages, stream=True) + assert resp == resp_content + + resp = await zhipu_gpt.aask(prompt_msg) + assert resp == resp_content