Add the DashScope provider (including Qwen models) and its unit tests

This commit is contained in:
better629 2024-02-07 22:50:30 +08:00
parent 997e25e97d
commit c7ee54ace1
9 changed files with 371 additions and 1 deletions

View file

@ -21,7 +21,7 @@ async def main():
logger.info(
await llm.aask(
"who are you", system_msgs=["act as a robot, answer 'I'am robot' if the question is 'who are you'"]
"who are you", system_msgs=["act as a robot, just answer 'I'am robot' if the question is 'who are you'"]
)
)

View file

@ -25,6 +25,7 @@ class LLMType(Enum):
AZURE = "azure"
OLLAMA = "ollama"
QIANFAN = "qianfan" # Baidu BCE
DASHSCOPE = "dashscope" # Aliyun LingJi DashScope
def __missing__(self, key):
return self.OPENAI

View file

@ -17,6 +17,7 @@ from metagpt.provider.metagpt_api import MetaGPTLLM
from metagpt.provider.human_provider import HumanProvider
from metagpt.provider.spark_api import SparkLLM
from metagpt.provider.qianfan_api import QianFanLLM
from metagpt.provider.dashscope_api import DashScopeLLM
__all__ = [
"FireworksLLM",
@ -30,4 +31,5 @@ __all__ = [
"HumanProvider",
"SparkLLM",
"QianFanLLM",
"DashScopeLLM",
]

View file

@ -0,0 +1,246 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc :
import json
from http import HTTPStatus
from typing import Any, AsyncGenerator, Dict, List, Union
import dashscope
from dashscope.aigc.generation import Generation
from dashscope.api_entities.aiohttp_request import AioHttpRequest
from dashscope.api_entities.api_request_data import ApiRequestData
from dashscope.api_entities.api_request_factory import _get_protocol_params
from dashscope.api_entities.dashscope_response import (
GenerationOutput,
GenerationResponse,
Message,
)
from dashscope.client.base_api import BaseAioApi
from dashscope.common.constants import SERVICE_API_PATH, ApiProtocol
from dashscope.common.error import (
InputDataRequired,
InputRequired,
ModelRequired,
UnsupportedApiProtocol,
)
from tenacity import (
after_log,
retry,
retry_if_exception_type,
stop_after_attempt,
wait_random_exponential,
)
from metagpt.logs import log_llm_stream, logger
from metagpt.provider.base_llm import BaseLLM, LLMConfig
from metagpt.provider.llm_provider_registry import LLMType, register_provider
from metagpt.provider.openai_api import log_and_reraise
from metagpt.utils.cost_manager import CostManager
from metagpt.utils.token_counter import DashScore_TOKEN_COSTS
def build_api_arequest(
    model: str, input: object, task_group: str, task: str, function: str, api_key: str, is_service=True, **kwargs
):
    """Assemble an ``AioHttpRequest`` (with payload attached) for a DashScope HTTP(S) endpoint.

    Args:
        model: Model name placed in the request payload.
        input: Input payload; may be ``None`` only when form data is supplied.
        task_group: Optional URL path segment appended after the service prefix.
        task: Optional URL path segment appended after ``task_group``.
        function: Optional final URL path segment.
        api_key: API key handed to the request object.
        is_service: When true, ``SERVICE_API_PATH`` is inserted after the base URL.
        **kwargs: Protocol options read by ``_get_protocol_params`` plus model parameters;
            ``task_id`` is popped here and whatever is left is passed to
            ``ApiRequestData.add_parameters``.

    Returns:
        The request object, ready to be awaited by the caller.

    Raises:
        UnsupportedApiProtocol: The resolved protocol is neither HTTP nor HTTPS.
        InputDataRequired: Neither ``input`` nor form data was provided.
    """
    (
        api_protocol,
        _ws_stream_mode,  # returned by the helper but not used by this HTTP-only builder
        is_binary_input,
        http_method,
        stream,
        async_request,
        query,
        headers,
        request_timeout,
        form,
        resources,
    ) = _get_protocol_params(kwargs)
    task_id = kwargs.pop("task_id", None)

    # Only the HTTP(S) transport is implemented here; bail out early for anything else.
    if api_protocol not in (ApiProtocol.HTTP, ApiProtocol.HTTPS):
        raise UnsupportedApiProtocol("Unsupported protocol: %s, support [http, https, websocket]" % api_protocol)

    # Build the URL piece by piece: <base>/[service prefix/][task_group/][task/][function]
    base_url = dashscope.base_http_api_url
    url_parts = [base_url if base_url.endswith("/") else base_url + "/"]
    if is_service:
        url_parts.append(SERVICE_API_PATH + "/")
    url_parts.extend("%s/" % segment for segment in (task_group, task) if segment)
    if function:
        url_parts.append(function)

    request = AioHttpRequest(
        url="".join(url_parts),
        api_key=api_key,
        http_method=http_method,
        stream=stream,
        async_request=async_request,
        query=query,
        timeout=request_timeout,
        task_id=task_id,
    )
    if headers is not None:
        request.add_headers(headers=headers)

    if input is None and form is None:
        raise InputDataRequired("There is no input data and form data")

    payload = ApiRequestData(
        model,
        task_group=task_group,
        task=task,
        function=function,
        input=input,
        form=form,
        is_binary_input=is_binary_input,
        api_protocol=api_protocol,
    )
    payload.add_resources(resources)
    payload.add_parameters(**kwargs)
    request.data = payload
    return request
class AGeneration(Generation, BaseAioApi):
    """dashscope ``Generation`` extended with an awaitable ``acall`` entry point."""

    @classmethod
    async def acall(
        cls,
        model: str,
        prompt: Any = None,
        history: list = None,
        api_key: str = None,
        messages: List[Message] = None,
        plugins: Union[str, Dict[str, Any]] = None,
        **kwargs,
    ) -> Union[GenerationResponse, AsyncGenerator[GenerationResponse, None]]:
        """Send a text-generation request asynchronously.

        Args:
            model: Model name; must be non-empty.
            prompt: Prompt text; required when ``messages`` is empty.
            history: Passed through to ``_build_input_parameters``.
            api_key: API key, validated by ``BaseAioApi._validate_params``.
            messages: Chat messages; required when ``prompt`` is empty.
            plugins: Plugin spec sent in the ``X-DashScope-Plugin`` header, either as a
                ready-made string or as a dict that is JSON-encoded.
            **kwargs: Extra request options forwarded to ``build_api_arequest``.

        Returns:
            A single ``GenerationResponse``, or an async generator of them when
            ``stream=True`` is passed in ``kwargs``.

        Raises:
            InputRequired: Both ``prompt`` and ``messages`` are empty.
            ModelRequired: ``model`` is empty.
        """
        if not prompt and not messages:
            raise InputRequired("prompt or messages is required!")
        if not model:
            raise ModelRequired("Model is required!")

        if plugins is not None:
            plugin_headers = kwargs.pop("headers", {})
            plugin_headers["X-DashScope-Plugin"] = plugins if isinstance(plugins, str) else json.dumps(plugins)
            kwargs["headers"] = plugin_headers

        # NOTE(review): the second value returned by _build_input_parameters is discarded and
        # the caller's kwargs are forwarded unchanged - confirm this is intended.
        request_input, _ = cls._build_input_parameters(model, prompt, history, messages, **kwargs)
        api_key, model = BaseAioApi._validate_params(api_key, model)
        request = build_api_arequest(
            model=model,
            input=request_input,
            task_group="aigc",  # fixed value
            task=Generation.task,
            function="generation",  # fixed value
            api_key=api_key,
            **kwargs,
        )
        raw_response = await request.aio_call()

        if kwargs.get("stream", False):
            # Lazily convert every streamed raw item into a GenerationResponse.
            return (GenerationResponse.from_api_response(item) async for item in raw_response)
        return GenerationResponse.from_api_response(raw_response)
@register_provider(LLMType.DASHSCOPE)
class DashScopeLLM(BaseLLM):
    """LLM provider that talks to Aliyun DashScope text-generation models through ``AGeneration``."""

    def __init__(self, llm_config: LLMConfig):
        """Store the config, resolve model settings and create the cost manager."""
        self.config = llm_config
        self.use_system_prompt = False  # only some models support system_prompt
        self.__init_dashscope()
        self.cost_manager = CostManager(token_costs=self.token_costs)

    def __init_dashscope(self):
        """Copy model/key from the config and decide whether system prompts may be used."""
        self.model = self.config.model
        self.api_key = self.config.api_key
        self.token_costs = DashScore_TOKEN_COSTS
        # The class itself is stored: `call` / `acall` are invoked on it directly.
        self.aclient: AGeneration = AGeneration

        # Name fragments of the models that accept a system message.
        support_system_models = (
            "qwen-",  # all support
            "llama2-",  # all support
            "baichuan2-7b-chat-v1",
            "chatglm3-6b",
        )
        if any(fragment in self.model for fragment in support_system_models):
            self.use_system_prompt = True

    def _const_kwargs(self, messages: list[dict], stream: bool = False) -> dict:
        """Build the keyword arguments for ``AGeneration.call`` / ``AGeneration.acall``."""
        kwargs = {
            "api_key": self.api_key,
            "model": self.model,
            "messages": messages,
            "stream": stream,
            "result_format": "message",
        }
        if self.config.temperature > 0:
            # Each model has its own default temperature; only override it when one is configured.
            kwargs["temperature"] = self.config.temperature
        if stream:
            # By default every streamed chunk repeats the text generated so far; the stream
            # consumer below concatenates chunks, so ask for deltas to avoid duplicated output.
            kwargs["incremental_output"] = True
        return kwargs

    def _check_response(self, resp: GenerationResponse):
        """Raise ``RuntimeError`` carrying the service's error details unless the status is 200."""
        if resp.status_code != HTTPStatus.OK:
            raise RuntimeError(f"code: {resp.code}, request_id: {resp.request_id}, message: {resp.message}")

    def get_choice_text(self, output: GenerationOutput) -> str:
        """Return the first choice's message content, or ``""`` when it is absent.

        A missing, ``None`` or empty ``choices`` entry (and a missing message) yields an
        empty string instead of raising.
        """
        choices = output.get("choices") or [{}]
        message = choices[0].get("message") or {}
        return message.get("content", "")

    def completion(self, messages: list[dict]) -> GenerationOutput:
        """Run a blocking, non-streaming completion and record its token usage."""
        resp: GenerationResponse = self.aclient.call(**self._const_kwargs(messages, stream=False))
        self._check_response(resp)
        # NOTE(review): usage is passed through as returned by the SDK - confirm that
        # BaseLLM._update_costs reads the key names DashScope uses.
        self._update_costs(dict(resp.usage))
        return resp.output

    async def _achat_completion(self, messages: list[dict]) -> GenerationOutput:
        """Run an awaitable, non-streaming completion and record its token usage."""
        resp: GenerationResponse = await self.aclient.acall(**self._const_kwargs(messages, stream=False))
        self._check_response(resp)
        self._update_costs(dict(resp.usage))
        return resp.output

    async def acompletion(self, messages: list[dict], timeout=3) -> GenerationOutput:
        """Async completion entry point; ``timeout`` is accepted but not used here."""
        return await self._achat_completion(messages)

    async def _achat_completion_stream(self, messages: list[dict]) -> str:
        """Stream a completion, echoing each chunk to the stream log, and return the full text."""
        resp = await self.aclient.acall(**self._const_kwargs(messages, stream=True))
        collected_content = []
        usage = {}
        async for chunk in resp:
            self._check_response(chunk)
            content = chunk.output.choices[0]["message"]["content"]
            usage = dict(chunk.usage)  # each chunk has usage; the last one seen is recorded
            log_llm_stream(content)
            collected_content.append(content)
        log_llm_stream("\n")
        self._update_costs(usage)
        return "".join(collected_content)

    @retry(
        stop=stop_after_attempt(3),
        wait=wait_random_exponential(min=1, max=60),
        after=after_log(logger, logger.level("WARNING").name),
        retry=retry_if_exception_type(ConnectionError),
        retry_error_callback=log_and_reraise,
    )
    async def acompletion_text(self, messages: list[dict], stream=False, timeout: int = 3) -> str:
        """Return the reply text, streamed or not; retried up to 3 times on ``ConnectionError``.

        ``timeout`` is accepted but not used here.
        """
        if stream:
            return await self._achat_completion_stream(messages)
        resp = await self._achat_completion(messages)
        return self.get_choice_text(resp)

View file

@ -90,6 +90,35 @@ QianFan_EndPoint_TOKEN_COSTS = {
"yi_34b_chat": QianFan_MODEL_TOKEN_COSTS["Yi-34B-Chat"],
}
"""
DashScope token prices: https://help.aliyun.com/zh/dashscope/developer-reference/tongyi-thousand-questions-metering-and-billing
Each model has its own pricing detail page. Note that some models are free for a limited time.
"""
# Per-model token prices, keyed by model name. Each row of the table below is
# (model, prompt price, completion price); a price of 0.0 means no charge is recorded.
DashScore_TOKEN_COSTS = {
    model_name: {"prompt": prompt_price, "completion": completion_price}
    for model_name, prompt_price, completion_price in (
        ("qwen-turbo", 0.0011, 0.0011),
        ("qwen-plus", 0.0028, 0.0028),
        ("qwen-max", 0.0, 0.0),
        ("qwen-max-1201", 0.0, 0.0),
        ("qwen-max-longcontext", 0.0, 0.0),
        ("llama2-7b-chat-v2", 0.0, 0.0),
        ("llama2-13b-chat-v2", 0.0, 0.0),
        ("qwen-72b-chat", 0.0, 0.0),
        ("qwen-14b-chat", 0.0011, 0.0011),
        ("qwen-7b-chat", 0.00084, 0.00084),
        ("qwen-1.8b-chat", 0.0, 0.0),
        ("baichuan2-13b-chat-v1", 0.0011, 0.0011),
        ("baichuan2-7b-chat-v1", 0.00084, 0.00084),
        ("baichuan-7b-v1", 0.0, 0.0),
        ("chatglm-6b-v2", 0.0011, 0.0011),
        ("chatglm3-6b", 0.0, 0.0),
        ("ziya-llama-13b-v1", 0.0, 0.0),  # no price page, judge it as free
        ("dolly-12b-v2", 0.0, 0.0),
        ("belle-llama-13b-2m-v1", 0.0, 0.0),
        ("moss-moon-003-sft-v1", 0.0, 0.0),
        ("chatyuan-large-v2", 0.0, 0.0),
        ("billa-7b-sft-v1", 0.0, 0.0),
    )
}
TOKEN_MAX = {
"gpt-3.5-turbo": 4096,

View file

@ -68,3 +68,4 @@ anytree
ipywidgets==8.1.1
Pillow
qianfan==0.3.1
dashscope==1.14.1

View file

@ -54,3 +54,5 @@ mock_llm_config_spark = LLMConfig(
)
mock_llm_config_qianfan = LLMConfig(api_type="qianfan", access_key="xxx", secret_key="xxx", model="ERNIE-Bot-turbo")
# api_type is spelled exactly like the LLMType.DASHSCOPE enum value ("dashscope").
mock_llm_config_dashscope = LLMConfig(api_type="dashscope", api_key="xxx", model="qwen-max")

View file

@ -3,6 +3,12 @@
# @Desc : default request & response data for provider unittest
from dashscope.api_entities.dashscope_response import (
DashScopeAPIResponse,
GenerationOutput,
GenerationResponse,
GenerationUsage,
)
from openai.types.chat.chat_completion import (
ChatCompletion,
ChatCompletionMessage,
@ -102,6 +108,28 @@ def get_qianfan_response(name: str) -> QfResponse:
return QfResponse(code=200, body=qf_jsonbody_dict)
# For DashScope
def get_dashscope_response(name: str) -> GenerationResponse:
    """Build a canned DashScope ``GenerationResponse`` with status 200 for provider tests.

    Args:
        name: Value substituted into ``resp_cont_tmpl`` to form the assistant's reply text.

    Returns:
        A response holding one "stop" choice with the formatted assistant message and
        fixed usage counts (12 input, 98 output, 110 total tokens).
    """
    assistant_choice = {
        "finish_reason": "stop",
        "message": {"role": "assistant", "content": resp_cont_tmpl.format(name=name)},
    }
    output = GenerationOutput(text="", finish_reason="", choices=[assistant_choice])
    usage = GenerationUsage(input_tokens=12, output_tokens=98, total_tokens=110)
    api_response = DashScopeAPIResponse(status_code=200, output=output, usage=usage)
    return GenerationResponse.from_api_response(api_response)
# For llm general chat functions call
async def llm_general_chat_funcs_test(llm: BaseLLM, prompt: str, messages: list[dict], resp_cont: str):
resp = await llm.aask(prompt, stream=False)

View file

@ -0,0 +1,61 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc : the unittest of DashScopeLLM
from typing import AsyncGenerator, Union
import pytest
from dashscope.api_entities.dashscope_response import GenerationResponse
from metagpt.provider.dashscope_api import DashScopeLLM
from tests.metagpt.provider.mock_llm_config import mock_llm_config_dashscope
from tests.metagpt.provider.req_resp_const import (
get_dashscope_response,
llm_general_chat_funcs_test,
messages,
prompt,
resp_cont_tmpl,
)
# Model name substituted into the canned response template by the mocks below.
name = "qwen-max"
# Reply text the mocked responses carry; the test compares results against it.
resp_cont = resp_cont_tmpl.format(name=name)
@classmethod
def mock_dashscope_call(
    cls, messages: list[dict], model: str, api_key: str, result_format: str, stream: bool = False, **kwargs
) -> GenerationResponse:
    """Stand-in for the blocking ``Generation.call``; always returns the canned response.

    ``**kwargs`` absorbs any optional request argument the provider forwards (for example
    ``temperature``), so the mock does not fail with ``TypeError`` on such calls.
    """
    return get_dashscope_response(name)
@classmethod
async def mock_dashscope_acall(
    cls, messages: list[dict], model: str, api_key: str, result_format: str, stream: bool = False, **kwargs
) -> Union[AsyncGenerator[GenerationResponse, None], GenerationResponse]:
    """Stand-in for ``AGeneration.acall``.

    Returns the canned response directly, or an async generator yielding it once when
    ``stream`` is true. ``**kwargs`` absorbs any optional request argument the provider
    forwards (for example ``temperature``), so the mock does not fail with ``TypeError``.
    """
    resps = [get_dashscope_response(name)]

    if stream:

        async def aresp_iterator(resps: list[GenerationResponse]):
            for resp in resps:
                yield resp

        return aresp_iterator(resps)
    return resps[0]
@pytest.mark.asyncio
async def test_dashscope_acompletion(mocker):
    """Check the sync, async and general chat entry points of ``DashScopeLLM`` against mocks."""
    # Replace both the blocking and the awaitable SDK calls with the canned mocks.
    mocker.patch("dashscope.aigc.generation.Generation.call", mock_dashscope_call)
    mocker.patch("metagpt.provider.dashscope_api.AGeneration.acall", mock_dashscope_acall)

    llm = DashScopeLLM(mock_llm_config_dashscope)

    sync_output = llm.completion(messages)
    assert sync_output.choices[0]["message"]["content"] == resp_cont

    async_output = await llm.acompletion(messages)
    assert async_output.choices[0]["message"]["content"] == resp_cont

    await llm_general_chat_funcs_test(llm, prompt, messages, resp_cont)