From e372430d9b02a6c9a60c17534ddd2ab9f3187d1d Mon Sep 17 00:00:00 2001 From: geekan Date: Mon, 8 Jan 2024 17:49:09 +0800 Subject: [PATCH] fix tests --- metagpt/provider/azure_openai_api.py | 8 ++++---- tests/metagpt/provider/mock_llm_config.py | 9 +++++++++ tests/metagpt/provider/test_azure_llm.py | 10 ++++++---- tests/metagpt/provider/test_spark_api.py | 3 ++- tests/metagpt/provider/test_zhipuai_api.py | 6 ++---- 5 files changed, 23 insertions(+), 13 deletions(-) diff --git a/metagpt/provider/azure_openai_api.py b/metagpt/provider/azure_openai_api.py index bd965f2cf..0b46b1fa7 100644 --- a/metagpt/provider/azure_openai_api.py +++ b/metagpt/provider/azure_openai_api.py @@ -28,13 +28,13 @@ class AzureOpenAILLM(OpenAILLM): kwargs = self._make_client_kwargs() # https://learn.microsoft.com/zh-cn/azure/ai-services/openai/how-to/migration?tabs=python-new%2Cdalle-fix self.aclient = AsyncAzureOpenAI(**kwargs) - self.model = self.config.DEPLOYMENT_NAME # Used in _calc_usage & _cons_kwargs + self.model = self.config.model # Used in _calc_usage & _cons_kwargs def _make_client_kwargs(self) -> dict: kwargs = dict( - api_key=self.config.OPENAI_API_KEY, - api_version=self.config.OPENAI_API_VERSION, - azure_endpoint=self.config.OPENAI_BASE_URL, + api_key=self.config.api_key, + api_version=self.config.api_version, + azure_endpoint=self.config.base_url, ) # to use proxy, openai v1 needs http_client diff --git a/tests/metagpt/provider/mock_llm_config.py b/tests/metagpt/provider/mock_llm_config.py index 6b1b52335..57f17e427 100644 --- a/tests/metagpt/provider/mock_llm_config.py +++ b/tests/metagpt/provider/mock_llm_config.py @@ -21,3 +21,12 @@ mock_llm_config_proxy = LLMConfig( base_url="mock_base_url", proxy="http://localhost:8080", ) + + +mock_llm_config_azure = LLMConfig( + llm_type="azure", + api_version="2023-09-01-preview", + api_key="mock_api_key", + base_url="mock_base_url", + proxy="http://localhost:8080", +) diff --git a/tests/metagpt/provider/test_azure_llm.py b/tests/metagpt/provider/test_azure_llm.py index 4437eec3b..51e051145 100644 --- a/tests/metagpt/provider/test_azure_llm.py +++ b/tests/metagpt/provider/test_azure_llm.py @@ -2,9 +2,11 @@ # -*- coding: utf-8 -*- # @Desc : - -from metagpt.context import context +from metagpt.provider import AzureOpenAILLM +from tests.metagpt.provider.mock_llm_config import mock_llm_config_azure -def test_azure_openai_api(): - _ = context.llm("azure") +def test_azure_llm(): + llm = AzureOpenAILLM(mock_llm_config_azure) + kwargs = llm._make_client_kwargs() + assert kwargs["azure_endpoint"] == mock_llm_config_azure.base_url diff --git a/tests/metagpt/provider/test_spark_api.py b/tests/metagpt/provider/test_spark_api.py index 8c6218ac4..aded1d9f0 100644 --- a/tests/metagpt/provider/test_spark_api.py +++ b/tests/metagpt/provider/test_spark_api.py @@ -6,6 +6,7 @@ import pytest from metagpt.config2 import config from metagpt.provider.spark_api import GetMessageFromWeb, SparkLLM +from tests.metagpt.provider.mock_llm_config import mock_llm_config prompt_msg = "who are you" resp_content = "I'm Spark" @@ -37,7 +38,7 @@ def mock_spark_get_msg_from_web_run(self) -> str: async def test_spark_acompletion(mocker): mocker.patch("metagpt.provider.spark_api.GetMessageFromWeb.run", mock_spark_get_msg_from_web_run) - spark_gpt = SparkLLM() + spark_gpt = SparkLLM(mock_llm_config) resp = await spark_gpt.acompletion([]) assert resp == resp_content diff --git a/tests/metagpt/provider/test_zhipuai_api.py b/tests/metagpt/provider/test_zhipuai_api.py index ab240260c..fd5067715 100644 --- a/tests/metagpt/provider/test_zhipuai_api.py +++ b/tests/metagpt/provider/test_zhipuai_api.py @@ -5,10 +5,8 @@ import pytest from zhipuai.utils.sse_client import Event -from metagpt.config import CONFIG from metagpt.provider.zhipuai_api import ZhiPuAILLM - -CONFIG.zhipuai_api_key = "xxx.xxx" +from tests.metagpt.provider.mock_llm_config import mock_llm_config prompt_msg = "who are you" messages = [{"role": "user", "content": prompt_msg}] @@ -65,7 +63,7 @@ async def test_zhipuai_acompletion(mocker): mocker.patch("metagpt.provider.zhipuai.zhipu_model_api.ZhiPuModelAPI.ainvoke", mock_zhipuai_ainvoke) mocker.patch("metagpt.provider.zhipuai.zhipu_model_api.ZhiPuModelAPI.asse_invoke", mock_zhipuai_asse_invoke) - zhipu_gpt = ZhiPuAILLM() + zhipu_gpt = ZhiPuAILLM(mock_llm_config) resp = await zhipu_gpt.acompletion(messages) assert resp["data"]["choices"][0]["content"] == resp_content