mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-05-18 13:55:17 +02:00
Merge pull request #940 from better629/feat_new_qianfan
Feat add qianfan and dashscope
This commit is contained in:
commit
d3961a630e
16 changed files with 845 additions and 8 deletions
|
|
@ -42,3 +42,17 @@ mock_llm_config_zhipu = LLMConfig(
|
|||
model="mock_zhipu_model",
|
||||
proxy="http://localhost:8080",
|
||||
)
|
||||
|
||||
|
||||
mock_llm_config_spark = LLMConfig(
|
||||
api_type="spark",
|
||||
app_id="xxx",
|
||||
api_key="xxx",
|
||||
api_secret="xxx",
|
||||
domain="generalv2",
|
||||
base_url="wss://spark-api.xf-yun.com/v3.1/chat",
|
||||
)
|
||||
|
||||
mock_llm_config_qianfan = LLMConfig(api_type="qianfan", access_key="xxx", secret_key="xxx", model="ERNIE-Bot-turbo")
|
||||
|
||||
mock_llm_config_dashscope = LLMConfig(api_type="dashscope", api_key="xxx", model="qwen-max")
|
||||
|
|
|
|||
145
tests/metagpt/provider/req_resp_const.py
Normal file
145
tests/metagpt/provider/req_resp_const.py
Normal file
|
|
@ -0,0 +1,145 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc : default request & response data for provider unittest
|
||||
|
||||
|
||||
from dashscope.api_entities.dashscope_response import (
|
||||
DashScopeAPIResponse,
|
||||
GenerationOutput,
|
||||
GenerationResponse,
|
||||
GenerationUsage,
|
||||
)
|
||||
from openai.types.chat.chat_completion import (
|
||||
ChatCompletion,
|
||||
ChatCompletionMessage,
|
||||
Choice,
|
||||
)
|
||||
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
|
||||
from openai.types.chat.chat_completion_chunk import Choice as AChoice
|
||||
from openai.types.chat.chat_completion_chunk import ChoiceDelta
|
||||
from openai.types.completion_usage import CompletionUsage
|
||||
from qianfan.resources.typing import QfResponse
|
||||
|
||||
from metagpt.provider.base_llm import BaseLLM
|
||||
|
||||
prompt = "who are you?"
|
||||
messages = [{"role": "user", "content": prompt}]
|
||||
|
||||
resp_cont_tmpl = "I'm {name}"
|
||||
default_resp_cont = resp_cont_tmpl.format(name="GPT")
|
||||
|
||||
|
||||
# part of whole ChatCompletion of openai like structure
|
||||
def get_part_chat_completion(name: str) -> dict:
|
||||
part_chat_completion = {
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": resp_cont_tmpl.format(name=name),
|
||||
},
|
||||
"finish_reason": "stop",
|
||||
}
|
||||
],
|
||||
"usage": {"completion_tokens": 22, "prompt_tokens": 19, "total_tokens": 41},
|
||||
}
|
||||
return part_chat_completion
|
||||
|
||||
|
||||
def get_openai_chat_completion(name: str) -> ChatCompletion:
|
||||
openai_chat_completion = ChatCompletion(
|
||||
id="cmpl-a6652c1bb181caae8dd19ad8",
|
||||
model="xx/xxx",
|
||||
object="chat.completion",
|
||||
created=1703300855,
|
||||
choices=[
|
||||
Choice(
|
||||
finish_reason="stop",
|
||||
index=0,
|
||||
message=ChatCompletionMessage(role="assistant", content=resp_cont_tmpl.format(name=name)),
|
||||
logprobs=None,
|
||||
)
|
||||
],
|
||||
usage=CompletionUsage(completion_tokens=110, prompt_tokens=92, total_tokens=202),
|
||||
)
|
||||
return openai_chat_completion
|
||||
|
||||
|
||||
def get_openai_chat_completion_chunk(name: str, usage_as_dict: bool = False) -> ChatCompletionChunk:
|
||||
usage = CompletionUsage(completion_tokens=110, prompt_tokens=92, total_tokens=202)
|
||||
usage = usage if not usage_as_dict else usage.model_dump()
|
||||
openai_chat_completion_chunk = ChatCompletionChunk(
|
||||
id="cmpl-a6652c1bb181caae8dd19ad8",
|
||||
model="xx/xxx",
|
||||
object="chat.completion.chunk",
|
||||
created=1703300855,
|
||||
choices=[
|
||||
AChoice(
|
||||
delta=ChoiceDelta(role="assistant", content=resp_cont_tmpl.format(name=name)),
|
||||
finish_reason="stop",
|
||||
index=0,
|
||||
logprobs=None,
|
||||
)
|
||||
],
|
||||
usage=usage,
|
||||
)
|
||||
return openai_chat_completion_chunk
|
||||
|
||||
|
||||
# For gemini
|
||||
gemini_messages = [{"role": "user", "parts": prompt}]
|
||||
|
||||
|
||||
# For QianFan
|
||||
qf_jsonbody_dict = {
|
||||
"id": "as-4v1h587fyv",
|
||||
"object": "chat.completion",
|
||||
"created": 1695021339,
|
||||
"result": "",
|
||||
"is_truncated": False,
|
||||
"need_clear_history": False,
|
||||
"usage": {"prompt_tokens": 7, "completion_tokens": 15, "total_tokens": 22},
|
||||
}
|
||||
|
||||
|
||||
def get_qianfan_response(name: str) -> QfResponse:
|
||||
qf_jsonbody_dict["result"] = resp_cont_tmpl.format(name=name)
|
||||
return QfResponse(code=200, body=qf_jsonbody_dict)
|
||||
|
||||
|
||||
# For DashScope
|
||||
def get_dashscope_response(name: str) -> GenerationResponse:
|
||||
return GenerationResponse.from_api_response(
|
||||
DashScopeAPIResponse(
|
||||
status_code=200,
|
||||
output=GenerationOutput(
|
||||
**{
|
||||
"text": "",
|
||||
"finish_reason": "",
|
||||
"choices": [
|
||||
{
|
||||
"finish_reason": "stop",
|
||||
"message": {"role": "assistant", "content": resp_cont_tmpl.format(name=name)},
|
||||
}
|
||||
],
|
||||
}
|
||||
),
|
||||
usage=GenerationUsage(**{"input_tokens": 12, "output_tokens": 98, "total_tokens": 110}),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
# For llm general chat functions call
|
||||
async def llm_general_chat_funcs_test(llm: BaseLLM, prompt: str, messages: list[dict], resp_cont: str):
|
||||
resp = await llm.aask(prompt, stream=False)
|
||||
assert resp == resp_cont
|
||||
|
||||
resp = await llm.aask(prompt)
|
||||
assert resp == resp_cont
|
||||
|
||||
resp = await llm.acompletion_text(messages, stream=False)
|
||||
assert resp == resp_cont
|
||||
|
||||
resp = await llm.acompletion_text(messages, stream=True)
|
||||
assert resp == resp_cont
|
||||
73
tests/metagpt/provider/test_dashscope_api.py
Normal file
73
tests/metagpt/provider/test_dashscope_api.py
Normal file
|
|
@ -0,0 +1,73 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc : the unittest of DashScopeLLM
|
||||
|
||||
from typing import AsyncGenerator, Union
|
||||
|
||||
import pytest
|
||||
from dashscope.api_entities.dashscope_response import GenerationResponse
|
||||
|
||||
from metagpt.provider.dashscope_api import DashScopeLLM
|
||||
from tests.metagpt.provider.mock_llm_config import mock_llm_config_dashscope
|
||||
from tests.metagpt.provider.req_resp_const import (
|
||||
get_dashscope_response,
|
||||
llm_general_chat_funcs_test,
|
||||
messages,
|
||||
prompt,
|
||||
resp_cont_tmpl,
|
||||
)
|
||||
|
||||
name = "qwen-max"
|
||||
resp_cont = resp_cont_tmpl.format(name=name)
|
||||
|
||||
|
||||
@classmethod
|
||||
def mock_dashscope_call(
|
||||
cls,
|
||||
messages: list[dict],
|
||||
model: str,
|
||||
api_key: str,
|
||||
result_format: str,
|
||||
incremental_output: bool = True,
|
||||
stream: bool = False,
|
||||
) -> GenerationResponse:
|
||||
return get_dashscope_response(name)
|
||||
|
||||
|
||||
@classmethod
|
||||
async def mock_dashscope_acall(
|
||||
cls,
|
||||
messages: list[dict],
|
||||
model: str,
|
||||
api_key: str,
|
||||
result_format: str,
|
||||
incremental_output: bool = True,
|
||||
stream: bool = False,
|
||||
) -> Union[AsyncGenerator[GenerationResponse, None], GenerationResponse]:
|
||||
resps = [get_dashscope_response(name)]
|
||||
|
||||
if stream:
|
||||
|
||||
async def aresp_iterator(resps: list[GenerationResponse]):
|
||||
for resp in resps:
|
||||
yield resp
|
||||
|
||||
return aresp_iterator(resps)
|
||||
else:
|
||||
return resps[0]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dashscope_acompletion(mocker):
|
||||
mocker.patch("dashscope.aigc.generation.Generation.call", mock_dashscope_call)
|
||||
mocker.patch("metagpt.provider.dashscope_api.AGeneration.acall", mock_dashscope_acall)
|
||||
|
||||
dashscope_llm = DashScopeLLM(mock_llm_config_dashscope)
|
||||
|
||||
resp = dashscope_llm.completion(messages)
|
||||
assert resp.choices[0]["message"]["content"] == resp_cont
|
||||
|
||||
resp = await dashscope_llm.acompletion(messages)
|
||||
assert resp.choices[0]["message"]["content"] == resp_cont
|
||||
|
||||
await llm_general_chat_funcs_test(dashscope_llm, prompt, messages, resp_cont)
|
||||
56
tests/metagpt/provider/test_qianfan_api.py
Normal file
56
tests/metagpt/provider/test_qianfan_api.py
Normal file
|
|
@ -0,0 +1,56 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc : the unittest of qianfan api
|
||||
|
||||
from typing import AsyncIterator, Union
|
||||
|
||||
import pytest
|
||||
from qianfan.resources.typing import JsonBody, QfResponse
|
||||
|
||||
from metagpt.provider.qianfan_api import QianFanLLM
|
||||
from tests.metagpt.provider.mock_llm_config import mock_llm_config_qianfan
|
||||
from tests.metagpt.provider.req_resp_const import (
|
||||
get_qianfan_response,
|
||||
llm_general_chat_funcs_test,
|
||||
messages,
|
||||
prompt,
|
||||
resp_cont_tmpl,
|
||||
)
|
||||
|
||||
name = "ERNIE-Bot-turbo"
|
||||
resp_cont = resp_cont_tmpl.format(name=name)
|
||||
|
||||
|
||||
def mock_qianfan_do(self, messages: list[dict], model: str, stream: bool = False, system: str = None) -> QfResponse:
|
||||
return get_qianfan_response(name=name)
|
||||
|
||||
|
||||
async def mock_qianfan_ado(
|
||||
self, messages: list[dict], model: str, stream: bool = True, system: str = None
|
||||
) -> Union[QfResponse, AsyncIterator[QfResponse]]:
|
||||
resps = [get_qianfan_response(name=name)]
|
||||
if stream:
|
||||
|
||||
async def aresp_iterator(resps: list[JsonBody]):
|
||||
for resp in resps:
|
||||
yield resp
|
||||
|
||||
return aresp_iterator(resps)
|
||||
else:
|
||||
return resps[0]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_qianfan_acompletion(mocker):
|
||||
mocker.patch("qianfan.resources.llm.chat_completion.ChatCompletion.do", mock_qianfan_do)
|
||||
mocker.patch("qianfan.resources.llm.chat_completion.ChatCompletion.ado", mock_qianfan_ado)
|
||||
|
||||
qianfan_llm = QianFanLLM(mock_llm_config_qianfan)
|
||||
|
||||
resp = qianfan_llm.completion(messages)
|
||||
assert resp.get("result") == resp_cont
|
||||
|
||||
resp = await qianfan_llm.acompletion(messages)
|
||||
assert resp.get("result") == resp_cont
|
||||
|
||||
await llm_general_chat_funcs_test(qianfan_llm, prompt, messages, resp_cont)
|
||||
|
|
@ -211,6 +211,11 @@ value
|
|||
output = repair_invalid_json(output, "Expecting ',' delimiter: line 4 column 1")
|
||||
assert output == target_output
|
||||
|
||||
raw_output = '{"key": "url "http" \\"https\\" "}'
|
||||
target_output = '{"key": "url \\"http\\" \\"https\\" "}'
|
||||
output = repair_invalid_json(raw_output, "Expecting ',' delimiter: line 1 column 15 (char 14)")
|
||||
assert output == target_output
|
||||
|
||||
|
||||
def test_retry_parse_json_text():
|
||||
from metagpt.utils.repair_llm_raw_output import retry_parse_json_text
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue