add dashscope including QWEN and its ut code

This commit is contained in:
better629 2024-02-07 22:50:30 +08:00
parent 997e25e97d
commit c7ee54ace1
9 changed files with 371 additions and 1 deletions

View file

@ -54,3 +54,5 @@ mock_llm_config_spark = LLMConfig(
)
mock_llm_config_qianfan = LLMConfig(api_type="qianfan", access_key="xxx", secret_key="xxx", model="ERNIE-Bot-turbo")
mock_llm_config_dashscope = LLMConfig(api_type="dashscore", api_key="xxx", model="qwen-max")

View file

@ -3,6 +3,12 @@
# @Desc : default request & response data for provider unittest
from dashscope.api_entities.dashscope_response import (
DashScopeAPIResponse,
GenerationOutput,
GenerationResponse,
GenerationUsage,
)
from openai.types.chat.chat_completion import (
ChatCompletion,
ChatCompletionMessage,
@ -102,6 +108,28 @@ def get_qianfan_response(name: str) -> QfResponse:
return QfResponse(code=200, body=qf_jsonbody_dict)
# For DashScope
def get_dashscope_response(name: str) -> GenerationResponse:
return GenerationResponse.from_api_response(
DashScopeAPIResponse(
status_code=200,
output=GenerationOutput(
**{
"text": "",
"finish_reason": "",
"choices": [
{
"finish_reason": "stop",
"message": {"role": "assistant", "content": resp_cont_tmpl.format(name=name)},
}
],
}
),
usage=GenerationUsage(**{"input_tokens": 12, "output_tokens": 98, "total_tokens": 110}),
)
)
# For llm general chat functions call
async def llm_general_chat_funcs_test(llm: BaseLLM, prompt: str, messages: list[dict], resp_cont: str):
resp = await llm.aask(prompt, stream=False)

View file

@ -0,0 +1,61 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc : the unittest of DashScopeLLM
from typing import AsyncGenerator, Union
import pytest
from dashscope.api_entities.dashscope_response import GenerationResponse
from metagpt.provider.dashscope_api import DashScopeLLM
from tests.metagpt.provider.mock_llm_config import mock_llm_config_dashscope
from tests.metagpt.provider.req_resp_const import (
get_dashscope_response,
llm_general_chat_funcs_test,
messages,
prompt,
resp_cont_tmpl,
)
name = "qwen-max"
resp_cont = resp_cont_tmpl.format(name=name)
@classmethod
def mock_dashscope_call(
cls, messages: list[dict], model: str, api_key: str, result_format: str, stream: bool = False
) -> GenerationResponse:
return get_dashscope_response(name)
@classmethod
async def mock_dashscope_acall(
cls, messages: list[dict], model: str, api_key: str, result_format: str, stream: bool = False
) -> Union[AsyncGenerator[GenerationResponse, None], GenerationResponse]:
resps = [get_dashscope_response(name)]
if stream:
async def aresp_iterator(resps: list[GenerationResponse]):
for resp in resps:
yield resp
return aresp_iterator(resps)
else:
return resps[0]
@pytest.mark.asyncio
async def test_dashscope_acompletion(mocker):
mocker.patch("dashscope.aigc.generation.Generation.call", mock_dashscope_call)
mocker.patch("metagpt.provider.dashscope_api.AGeneration.acall", mock_dashscope_acall)
dashscore_llm = DashScopeLLM(mock_llm_config_dashscope)
resp = dashscore_llm.completion(messages)
assert resp.choices[0]["message"]["content"] == resp_cont
resp = await dashscore_llm.acompletion(messages)
assert resp.choices[0]["message"]["content"] == resp_cont
await llm_general_chat_funcs_test(dashscore_llm, prompt, messages, resp_cont)