add qianfan ut code and update xx_llm from xx_gpt

better629 2024-02-07 18:32:32 +08:00
parent dc240a2efd
commit d3f6e38e8a
11 changed files with 153 additions and 80 deletions

View file

@@ -52,3 +52,10 @@ mock_llm_config_spark = LLMConfig(
    domain="generalv2",
    base_url="wss://spark-api.xf-yun.com/v3.1/chat",
)

mock_llm_config_qianfan = LLMConfig(
    api_type="qianfan",
    access_key="xxx",
    secret_key="xxx",
    model="ERNIE-Bot-turbo"
)
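A minimal sketch of how this new mock config is consumed, mirroring the qianfan unittest added later in this commit (QianFanLLM comes from metagpt.provider.qianfan_api; the "xxx" credentials are dummies, so any SDK call has to be patched before use):

# Illustrative only: construct the provider under test from the mock config.
from metagpt.provider.qianfan_api import QianFanLLM

qianfan_llm = QianFanLLM(mock_llm_config_qianfan)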

View file

@@ -2,6 +2,7 @@
# -*- coding: utf-8 -*-
# @Desc : default request & response data for provider unittest
from typing import Dict
from openai.types.chat.chat_completion import (
    ChatCompletion,
    ChatCompletionMessage,
@@ -11,6 +12,9 @@ from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
from openai.types.chat.chat_completion_chunk import Choice as AChoice
from openai.types.chat.chat_completion_chunk import ChoiceDelta
from openai.types.completion_usage import CompletionUsage
from qianfan.resources.typing import QfResponse, default_field
from metagpt.provider.base_llm import BaseLLM
prompt = "who are you?"
messages = [{"role": "user", "content": prompt}]
@@ -20,14 +24,14 @@ default_resp_cont = resp_cont_tmpl.format(name="GPT")
# part of whole ChatCompletion of openai like structure
def get_part_chat_completion(llm_name: str) -> dict:
def get_part_chat_completion(name: str) -> dict:
    part_chat_completion = {
        "choices": [
            {
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": resp_cont_tmpl.format(name=llm_name),
                    "content": resp_cont_tmpl.format(name=name),
                },
                "finish_reason": "stop",
            }
@@ -37,7 +41,7 @@ def get_part_chat_completion(llm_name: str) -> dict:
    return part_chat_completion

def get_openai_chat_completion(llm_name: str) -> ChatCompletion:
def get_openai_chat_completion(name: str) -> ChatCompletion:
    openai_chat_completion = ChatCompletion(
        id="cmpl-a6652c1bb181caae8dd19ad8",
        model="xx/xxx",
@@ -47,7 +51,7 @@ def get_openai_chat_completion(llm_name: str) -> ChatCompletion:
            Choice(
                finish_reason="stop",
                index=0,
                message=ChatCompletionMessage(role="assistant", content=resp_cont_tmpl.format(name=llm_name)),
                message=ChatCompletionMessage(role="assistant", content=resp_cont_tmpl.format(name=name)),
                logprobs=None,
            )
        ],
@@ -56,7 +60,7 @@ def get_openai_chat_completion(llm_name: str) -> ChatCompletion:
    return openai_chat_completion

def get_openai_chat_completion_chunk(llm_name: str, usage_as_dict: bool = False) -> ChatCompletionChunk:
def get_openai_chat_completion_chunk(name: str, usage_as_dict: bool = False) -> ChatCompletionChunk:
    usage = CompletionUsage(completion_tokens=110, prompt_tokens=92, total_tokens=202)
    usage = usage if not usage_as_dict else usage.model_dump()
    openai_chat_completion_chunk = ChatCompletionChunk(
@@ -66,7 +70,7 @@ def get_openai_chat_completion_chunk(llm_name: str, usage_as_dict: bool = False)
        created=1703300855,
        choices=[
            AChoice(
                delta=ChoiceDelta(role="assistant", content=resp_cont_tmpl.format(name=llm_name)),
                delta=ChoiceDelta(role="assistant", content=resp_cont_tmpl.format(name=name)),
                finish_reason="stop",
                index=0,
                logprobs=None,
@@ -76,5 +80,44 @@ def get_openai_chat_completion_chunk(llm_name: str, usage_as_dict: bool = False)
    )
    return openai_chat_completion_chunk

# For gemini
gemini_messages = [{"role": "user", "parts": prompt}]

# For QianFan
qf_jsonbody_dict = {
    "id": "as-4v1h587fyv",
    "object": "chat.completion",
    "created": 1695021339,
    "result": "",
    "is_truncated": False,
    "need_clear_history": False,
    "usage": {
        "prompt_tokens": 7,
        "completion_tokens": 15,
        "total_tokens": 22
    }
}

def get_qianfan_response(name: str) -> QfResponse:
    qf_jsonbody_dict["result"] = resp_cont_tmpl.format(name=name)
    return QfResponse(
        code=200,
        body=qf_jsonbody_dict
    )
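A quick sanity sketch of what the canned response looks like (attribute access on QfResponse is assumed from the qianfan SDK typing imported above; this snippet is not part of the committed file):

resp = get_qianfan_response(name="ERNIE-Bot-turbo")
assert resp.code == 200
assert resp.body["result"] == resp_cont_tmpl.format(name="ERNIE-Bot-turbo")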

# For llm general chat functions call
async def llm_general_chat_funcs_test(llm: BaseLLM, prompt: str, messages: list[dict], resp_cont: str):
    resp = await llm.aask(prompt, stream=False)
    assert resp == resp_cont
    resp = await llm.aask(prompt)
    assert resp == resp_cont
    resp = await llm.acompletion_text(messages, stream=False)
    assert resp == resp_cont
    resp = await llm.acompletion_text(messages, stream=True)
    assert resp == resp_cont
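A minimal sketch of how a provider test is expected to call this shared helper, mirroring the zhipuai and qianfan tests further down in this commit (pytest, ZhiPuAILLM, mock_llm_config_zhipu and resp_cont are assumed to be imported and patched as in those tests):

@pytest.mark.asyncio
async def test_provider_acompletion(mocker):
    # provider-specific patches go here, as in the zhipuai/qianfan tests
    llm = ZhiPuAILLM(mock_llm_config_zhipu)
    resp = await llm.acompletion(messages)
    assert resp["choices"][0]["message"]["content"] == resp_cont
    # then exercise the shared aask/acompletion_text paths in one call
    await llm_general_chat_funcs_test(llm, prompt, messages, resp_cont)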

View file

@@ -17,7 +17,7 @@ from tests.metagpt.provider.req_resp_const import (
prompt,
)
llm_name = "GPT"
name = "GPT"
class MockBaseLLM(BaseLLM):
@@ -25,10 +25,10 @@ class MockBaseLLM(BaseLLM):
        pass

    def completion(self, messages: list[dict], timeout=3):
        return get_part_chat_completion(llm_name)
        return get_part_chat_completion(name)

    async def acompletion(self, messages: list[dict], timeout=3):
        return get_part_chat_completion(llm_name)
        return get_part_chat_completion(name)

    async def acompletion_text(self, messages: list[dict], stream=False, timeout=3) -> str:
        return default_resp_cont

View file

@@ -21,10 +21,10 @@ from tests.metagpt.provider.req_resp_const import (
resp_cont_tmpl,
)
llm_name = "fireworks"
resp_cont = resp_cont_tmpl.format(name=llm_name)
default_resp = get_openai_chat_completion(llm_name)
default_resp_chunk = get_openai_chat_completion_chunk(llm_name, usage_as_dict=True)
name = "fireworks"
resp_cont = resp_cont_tmpl.format(name=name)
default_resp = get_openai_chat_completion(name)
default_resp_chunk = get_openai_chat_completion_chunk(name, usage_as_dict=True)
def test_fireworks_costmanager():
@@ -57,27 +57,27 @@ async def mock_openai_acompletions_create(self, stream: bool = False, **kwargs)
async def test_fireworks_acompletion(mocker):
mocker.patch("openai.resources.chat.completions.AsyncCompletions.create", mock_openai_acompletions_create)
fireworks_gpt = FireworksLLM(mock_llm_config)
fireworks_gpt.model = "llama-v2-13b-chat"
fireworks_llm = FireworksLLM(mock_llm_config)
fireworks_llm.model = "llama-v2-13b-chat"
fireworks_gpt._update_costs(
fireworks_llm._update_costs(
usage=CompletionUsage(prompt_tokens=500000, completion_tokens=500000, total_tokens=1000000)
)
assert fireworks_gpt.get_costs() == Costs(
assert fireworks_llm.get_costs() == Costs(
total_prompt_tokens=500000, total_completion_tokens=500000, total_cost=0.5, total_budget=0
)
resp = await fireworks_gpt.acompletion(messages)
resp = await fireworks_llm.acompletion(messages)
assert resp.choices[0].message.content in resp_cont
resp = await fireworks_gpt.aask(prompt, stream=False)
resp = await fireworks_llm.aask(prompt, stream=False)
assert resp == resp_cont
resp = await fireworks_gpt.acompletion_text(messages, stream=False)
resp = await fireworks_llm.acompletion_text(messages, stream=False)
assert resp == resp_cont
resp = await fireworks_gpt.acompletion_text(messages, stream=True)
resp = await fireworks_llm.acompletion_text(messages, stream=True)
assert resp == resp_cont
resp = await fireworks_gpt.aask(prompt)
resp = await fireworks_llm.aask(prompt)
assert resp == resp_cont

View file

@@ -63,28 +63,28 @@ async def test_gemini_acompletion(mocker):
mock_gemini_generate_content_async,
)
gemini_gpt = GeminiLLM(mock_llm_config)
gemini_llm = GeminiLLM(mock_llm_config)
assert gemini_gpt._user_msg(prompt) == {"role": "user", "parts": [prompt]}
assert gemini_gpt._assistant_msg(prompt) == {"role": "model", "parts": [prompt]}
assert gemini_llm._user_msg(prompt) == {"role": "user", "parts": [prompt]}
assert gemini_llm._assistant_msg(prompt) == {"role": "model", "parts": [prompt]}
usage = gemini_gpt.get_usage(gemini_messages, resp_cont)
usage = gemini_llm.get_usage(gemini_messages, resp_cont)
assert usage == {"prompt_tokens": 20, "completion_tokens": 20}
resp = gemini_gpt.completion(gemini_messages)
resp = gemini_llm.completion(gemini_messages)
assert resp == default_resp
resp = await gemini_gpt.acompletion(gemini_messages)
resp = await gemini_llm.acompletion(gemini_messages)
assert resp.text == default_resp.text
resp = await gemini_gpt.aask(prompt, stream=False)
resp = await gemini_llm.aask(prompt, stream=False)
assert resp == resp_cont
resp = await gemini_gpt.acompletion_text(gemini_messages, stream=False)
resp = await gemini_llm.acompletion_text(gemini_messages, stream=False)
assert resp == resp_cont
resp = await gemini_gpt.acompletion_text(gemini_messages, stream=True)
resp = await gemini_llm.acompletion_text(gemini_messages, stream=True)
assert resp == resp_cont
resp = await gemini_gpt.aask(prompt)
resp = await gemini_llm.aask(prompt)
assert resp == resp_cont

View file

@@ -39,19 +39,19 @@ async def mock_ollama_arequest(self, stream: bool = False, **kwargs) -> Tuple[An
async def test_gemini_acompletion(mocker):
mocker.patch("metagpt.provider.general_api_requestor.GeneralAPIRequestor.arequest", mock_ollama_arequest)
ollama_gpt = OllamaLLM(mock_llm_config)
ollama_llm = OllamaLLM(mock_llm_config)
resp = await ollama_gpt.acompletion(messages)
resp = await ollama_llm.acompletion(messages)
assert resp["message"]["content"] == default_resp["message"]["content"]
resp = await ollama_gpt.aask(prompt, stream=False)
resp = await ollama_llm.aask(prompt, stream=False)
assert resp == resp_cont
resp = await ollama_gpt.acompletion_text(messages, stream=False)
resp = await ollama_llm.acompletion_text(messages, stream=False)
assert resp == resp_cont
resp = await ollama_gpt.acompletion_text(messages, stream=True)
resp = await ollama_llm.acompletion_text(messages, stream=True)
assert resp == resp_cont
resp = await ollama_gpt.aask(prompt)
resp = await ollama_llm.aask(prompt)
assert resp == resp_cont

View file

@@ -17,11 +17,11 @@ from tests.metagpt.provider.req_resp_const import (
resp_cont_tmpl,
)
llm_name = "llama2-7b"
resp_cont = resp_cont_tmpl.format(name=llm_name)
default_resp = get_openai_chat_completion(llm_name)
name = "llama2-7b"
resp_cont = resp_cont_tmpl.format(name=name)
default_resp = get_openai_chat_completion(name)
default_resp_chunk = get_openai_chat_completion_chunk(llm_name)
default_resp_chunk = get_openai_chat_completion_chunk(name)
async def mock_openai_acompletions_create(self, stream: bool = False, **kwargs) -> ChatCompletionChunk:
@@ -40,26 +40,26 @@ async def mock_openai_acompletions_create(self, stream: bool = False, **kwargs)
async def test_openllm_acompletion(mocker):
mocker.patch("openai.resources.chat.completions.AsyncCompletions.create", mock_openai_acompletions_create)
openllm_gpt = OpenLLM(mock_llm_config)
openllm_gpt.model = "llama-v2-13b-chat"
openllm_llm = OpenLLM(mock_llm_config)
openllm_llm.model = "llama-v2-13b-chat"
openllm_gpt.cost_manager = CostManager()
openllm_gpt._update_costs(usage=CompletionUsage(prompt_tokens=100, completion_tokens=100, total_tokens=200))
assert openllm_gpt.get_costs() == Costs(
openllm_llm.cost_manager = CostManager()
openllm_llm._update_costs(usage=CompletionUsage(prompt_tokens=100, completion_tokens=100, total_tokens=200))
assert openllm_llm.get_costs() == Costs(
total_prompt_tokens=100, total_completion_tokens=100, total_cost=0, total_budget=0
)
resp = await openllm_gpt.acompletion(messages)
resp = await openllm_llm.acompletion(messages)
assert resp.choices[0].message.content in resp_cont
resp = await openllm_gpt.aask(prompt, stream=False)
resp = await openllm_llm.aask(prompt, stream=False)
assert resp == resp_cont
resp = await openllm_gpt.acompletion_text(messages, stream=False)
resp = await openllm_llm.acompletion_text(messages, stream=False)
assert resp == resp_cont
resp = await openllm_gpt.acompletion_text(messages, stream=True)
resp = await openllm_llm.acompletion_text(messages, stream=True)
assert resp == resp_cont
resp = await openllm_gpt.aask(prompt)
resp = await openllm_llm.aask(prompt)
assert resp == resp_cont

View file

@@ -2,14 +2,45 @@
# -*- coding: utf-8 -*-
# @Desc : the unittest of qianfan api
from typing import Dict, Union, AsyncIterator
import pytest
from qianfan.resources.typing import JsonBody, QfResponse
from metagpt.provider.qianfan_api import QianFanLLM
from tests.metagpt.provider.req_resp_const import prompt, messages, resp_cont_tmpl
from tests.metagpt.provider.mock_llm_config import mock_llm_config_qianfan
from tests.metagpt.provider.req_resp_const import resp_cont_tmpl, prompt, messages, llm_general_chat_funcs_test, get_qianfan_response
name = "ERNIE-Bot-turbo"
resp_cont = resp_cont_tmpl.format(name=name)
resp_cont = resp_cont_tmpl.format(name="ERNIE-Bot-turbo")
def mock_qianfan_do(self, messages: list[dict], model: str, stream: bool = False, system: str = None) -> QfResponse:
    return get_qianfan_response(name=name)

def test_qianfan_acompletion(mocker):
    assert True, True

async def mock_qianfan_ado(self, messages: list[dict], model: str, stream: bool = True, system: str = None) -> Union[QfResponse, AsyncIterator[QfResponse]]:
    resps = [get_qianfan_response(name=name)]
    if stream:
        async def aresp_iterator(resps: list[JsonBody]):
            for resp in resps:
                yield resp
        return aresp_iterator(resps)
    else:
        return resps[0]

@pytest.mark.asyncio
async def test_qianfan_acompletion(mocker):
    mocker.patch("qianfan.resources.llm.chat_completion.ChatCompletion.do", mock_qianfan_do)
    mocker.patch("qianfan.resources.llm.chat_completion.ChatCompletion.ado", mock_qianfan_ado)
    qianfan_llm = QianFanLLM(mock_llm_config_qianfan)

    resp = qianfan_llm.completion(messages)
    assert resp.get("result") == resp_cont

    resp = await qianfan_llm.acompletion(messages)
    assert resp.get("result") == resp_cont

    await llm_general_chat_funcs_test(qianfan_llm, prompt, messages, resp_cont)
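A hypothetical extension, not part of this commit: the streaming branch could be split across several chunks to exercise incremental concatenation. It assumes qf_jsonbody_dict is also imported from req_resp_const and that QianFanLLM joins the "result" field of each streamed QfResponse:

async def mock_qianfan_ado_chunked(self, messages: list[dict], model: str, stream: bool = True, system: str = None):
    full = resp_cont_tmpl.format(name=name)

    async def chunks():
        # yield the reply in two QfResponse chunks instead of one
        for part in (full[: len(full) // 2], full[len(full) // 2 :]):
            yield QfResponse(code=200, body={**qf_jsonbody_dict, "result": part})

    return chunks()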

View file

@@ -50,19 +50,19 @@ async def test_spark_aask(mocker):
async def test_spark_acompletion(mocker):
mocker.patch("metagpt.provider.spark_api.GetMessageFromWeb.run", mock_spark_get_msg_from_web_run)
spark_gpt = SparkLLM(mock_llm_config)
spark_llm = SparkLLM(mock_llm_config)
resp = await spark_gpt.acompletion([])
resp = await spark_llm.acompletion([])
assert resp == resp_cont
resp = await spark_gpt.aask(prompt, stream=False)
resp = await spark_llm.aask(prompt, stream=False)
assert resp == resp_cont
resp = await spark_gpt.acompletion_text([], stream=False)
resp = await spark_llm.acompletion_text([], stream=False)
assert resp == resp_cont
resp = await spark_gpt.acompletion_text([], stream=True)
resp = await spark_llm.acompletion_text([], stream=True)
assert resp == resp_cont
resp = await spark_gpt.aask(prompt)
resp = await spark_llm.aask(prompt)
assert resp == resp_cont

View file

@@ -11,11 +11,12 @@ from tests.metagpt.provider.req_resp_const import (
messages,
prompt,
resp_cont_tmpl,
llm_general_chat_funcs_test
)
llm_name = "ChatGLM-4"
resp_cont = resp_cont_tmpl.format(name=llm_name)
default_resp = get_part_chat_completion(llm_name)
name = "ChatGLM-4"
resp_cont = resp_cont_tmpl.format(name=name)
default_resp = get_part_chat_completion(name)
async def mock_zhipuai_acreate_stream(**kwargs):
@@ -47,22 +48,12 @@ async def test_zhipuai_acompletion(mocker):
mocker.patch("metagpt.provider.zhipuai.zhipu_model_api.ZhiPuModelAPI.acreate", mock_zhipuai_acreate)
mocker.patch("metagpt.provider.zhipuai.zhipu_model_api.ZhiPuModelAPI.acreate_stream", mock_zhipuai_acreate_stream)
zhipu_gpt = ZhiPuAILLM(mock_llm_config_zhipu)
zhipu_llm = ZhiPuAILLM(mock_llm_config_zhipu)
resp = await zhipu_gpt.acompletion(messages)
resp = await zhipu_llm.acompletion(messages)
assert resp["choices"][0]["message"]["content"] == resp_cont
resp = await zhipu_gpt.aask(prompt, stream=False)
assert resp == resp_cont
resp = await zhipu_gpt.acompletion_text(messages, stream=False)
assert resp == resp_cont
resp = await zhipu_gpt.acompletion_text(messages, stream=True)
assert resp == resp_cont
resp = await zhipu_gpt.aask(prompt)
assert resp == resp_cont
await llm_general_chat_funcs_test(zhipu_llm, prompt, messages, resp_cont)
def test_zhipuai_proxy():