fix tests

This commit is contained in:
geekan 2024-01-08 17:49:09 +08:00
parent a07f955124
commit e372430d9b
5 changed files with 23 additions and 13 deletions

View file

@ -28,13 +28,13 @@ class AzureOpenAILLM(OpenAILLM):
kwargs = self._make_client_kwargs()
# https://learn.microsoft.com/zh-cn/azure/ai-services/openai/how-to/migration?tabs=python-new%2Cdalle-fix
self.aclient = AsyncAzureOpenAI(**kwargs)
self.model = self.config.DEPLOYMENT_NAME # Used in _calc_usage & _cons_kwargs
self.model = self.config.model # Used in _calc_usage & _cons_kwargs
def _make_client_kwargs(self) -> dict:
kwargs = dict(
api_key=self.config.OPENAI_API_KEY,
api_version=self.config.OPENAI_API_VERSION,
azure_endpoint=self.config.OPENAI_BASE_URL,
api_key=self.config.api_key,
api_version=self.config.api_version,
azure_endpoint=self.config.base_url,
)
# to use proxy, openai v1 needs http_client

View file

@ -21,3 +21,12 @@ mock_llm_config_proxy = LLMConfig(
base_url="mock_base_url",
proxy="http://localhost:8080",
)
mock_llm_config_azure = LLMConfig(
llm_type="azure",
api_version="2023-09-01-preview",
api_key="mock_api_key",
base_url="mock_base_url",
proxy="http://localhost:8080",
)

View file

@ -2,9 +2,11 @@
# -*- coding: utf-8 -*-
# @Desc :
from metagpt.context import context
from metagpt.provider import AzureOpenAILLM
from tests.metagpt.provider.mock_llm_config import mock_llm_config_azure
def test_azure_openai_api():
_ = context.llm("azure")
def test_azure_llm():
llm = AzureOpenAILLM(mock_llm_config_azure)
kwargs = llm._make_client_kwargs()
assert kwargs["azure_endpoint"] == mock_llm_config_azure.base_url

View file

@ -6,6 +6,7 @@ import pytest
from metagpt.config2 import config
from metagpt.provider.spark_api import GetMessageFromWeb, SparkLLM
from tests.metagpt.provider.mock_llm_config import mock_llm_config
prompt_msg = "who are you"
resp_content = "I'm Spark"
@ -37,7 +38,7 @@ def mock_spark_get_msg_from_web_run(self) -> str:
async def test_spark_acompletion(mocker):
mocker.patch("metagpt.provider.spark_api.GetMessageFromWeb.run", mock_spark_get_msg_from_web_run)
spark_gpt = SparkLLM()
spark_gpt = SparkLLM(mock_llm_config)
resp = await spark_gpt.acompletion([])
assert resp == resp_content

View file

@ -5,10 +5,8 @@
import pytest
from zhipuai.utils.sse_client import Event
from metagpt.config import CONFIG
from metagpt.provider.zhipuai_api import ZhiPuAILLM
CONFIG.zhipuai_api_key = "xxx.xxx"
from tests.metagpt.provider.mock_llm_config import mock_llm_config
prompt_msg = "who are you"
messages = [{"role": "user", "content": prompt_msg}]
@ -65,7 +63,7 @@ async def test_zhipuai_acompletion(mocker):
mocker.patch("metagpt.provider.zhipuai.zhipu_model_api.ZhiPuModelAPI.ainvoke", mock_zhipuai_ainvoke)
mocker.patch("metagpt.provider.zhipuai.zhipu_model_api.ZhiPuModelAPI.asse_invoke", mock_zhipuai_asse_invoke)
zhipu_gpt = ZhiPuAILLM()
zhipu_gpt = ZhiPuAILLM(mock_llm_config)
resp = await zhipu_gpt.acompletion(messages)
assert resp["data"]["choices"][0]["content"] == resp_content