mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-04-26 09:16:21 +02:00
add qianfan ut code and update xx_llm from xx_gpt
This commit is contained in:
parent
dc240a2efd
commit
d3f6e38e8a
11 changed files with 153 additions and 80 deletions
|
|
@ -52,3 +52,10 @@ mock_llm_config_spark = LLMConfig(
|
|||
domain="generalv2",
|
||||
base_url="wss://spark-api.xf-yun.com/v3.1/chat",
|
||||
)
|
||||
|
||||
mock_llm_config_qianfan = LLMConfig(
|
||||
api_type="qianfan",
|
||||
access_key="xxx",
|
||||
secret_key="xxx",
|
||||
model="ERNIE-Bot-turbo"
|
||||
)
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# @Desc : default request & response data for provider unittest
|
||||
|
||||
from typing import Dict
|
||||
from openai.types.chat.chat_completion import (
|
||||
ChatCompletion,
|
||||
ChatCompletionMessage,
|
||||
|
|
@ -11,6 +12,9 @@ from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
|
|||
from openai.types.chat.chat_completion_chunk import Choice as AChoice
|
||||
from openai.types.chat.chat_completion_chunk import ChoiceDelta
|
||||
from openai.types.completion_usage import CompletionUsage
|
||||
from qianfan.resources.typing import QfResponse, default_field
|
||||
|
||||
from metagpt.provider.base_llm import BaseLLM
|
||||
|
||||
prompt = "who are you?"
|
||||
messages = [{"role": "user", "content": prompt}]
|
||||
|
|
@ -20,14 +24,14 @@ default_resp_cont = resp_cont_tmpl.format(name="GPT")
|
|||
|
||||
|
||||
# part of whole ChatCompletion of openai like structure
|
||||
def get_part_chat_completion(llm_name: str) -> dict:
|
||||
def get_part_chat_completion(name: str) -> dict:
|
||||
part_chat_completion = {
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": resp_cont_tmpl.format(name=llm_name),
|
||||
"content": resp_cont_tmpl.format(name=name),
|
||||
},
|
||||
"finish_reason": "stop",
|
||||
}
|
||||
|
|
@ -37,7 +41,7 @@ def get_part_chat_completion(llm_name: str) -> dict:
|
|||
return part_chat_completion
|
||||
|
||||
|
||||
def get_openai_chat_completion(llm_name: str) -> ChatCompletion:
|
||||
def get_openai_chat_completion(name: str) -> ChatCompletion:
|
||||
openai_chat_completion = ChatCompletion(
|
||||
id="cmpl-a6652c1bb181caae8dd19ad8",
|
||||
model="xx/xxx",
|
||||
|
|
@ -47,7 +51,7 @@ def get_openai_chat_completion(llm_name: str) -> ChatCompletion:
|
|||
Choice(
|
||||
finish_reason="stop",
|
||||
index=0,
|
||||
message=ChatCompletionMessage(role="assistant", content=resp_cont_tmpl.format(name=llm_name)),
|
||||
message=ChatCompletionMessage(role="assistant", content=resp_cont_tmpl.format(name=name)),
|
||||
logprobs=None,
|
||||
)
|
||||
],
|
||||
|
|
@ -56,7 +60,7 @@ def get_openai_chat_completion(llm_name: str) -> ChatCompletion:
|
|||
return openai_chat_completion
|
||||
|
||||
|
||||
def get_openai_chat_completion_chunk(llm_name: str, usage_as_dict: bool = False) -> ChatCompletionChunk:
|
||||
def get_openai_chat_completion_chunk(name: str, usage_as_dict: bool = False) -> ChatCompletionChunk:
|
||||
usage = CompletionUsage(completion_tokens=110, prompt_tokens=92, total_tokens=202)
|
||||
usage = usage if not usage_as_dict else usage.model_dump()
|
||||
openai_chat_completion_chunk = ChatCompletionChunk(
|
||||
|
|
@ -66,7 +70,7 @@ def get_openai_chat_completion_chunk(llm_name: str, usage_as_dict: bool = False)
|
|||
created=1703300855,
|
||||
choices=[
|
||||
AChoice(
|
||||
delta=ChoiceDelta(role="assistant", content=resp_cont_tmpl.format(name=llm_name)),
|
||||
delta=ChoiceDelta(role="assistant", content=resp_cont_tmpl.format(name=name)),
|
||||
finish_reason="stop",
|
||||
index=0,
|
||||
logprobs=None,
|
||||
|
|
@ -76,5 +80,44 @@ def get_openai_chat_completion_chunk(llm_name: str, usage_as_dict: bool = False)
|
|||
)
|
||||
return openai_chat_completion_chunk
|
||||
|
||||
|
||||
# For gemini
|
||||
gemini_messages = [{"role": "user", "parts": prompt}]
|
||||
|
||||
|
||||
# For QianFan
|
||||
qf_jsonbody_dict = {
|
||||
"id": "as-4v1h587fyv",
|
||||
"object": "chat.completion",
|
||||
"created": 1695021339,
|
||||
"result": "",
|
||||
"is_truncated": False,
|
||||
"need_clear_history": False,
|
||||
"usage": {
|
||||
"prompt_tokens": 7,
|
||||
"completion_tokens": 15,
|
||||
"total_tokens": 22
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
def get_qianfan_response(name: str) -> QfResponse:
|
||||
qf_jsonbody_dict["result"] = resp_cont_tmpl.format(name=name)
|
||||
return QfResponse(
|
||||
code=200,
|
||||
body=qf_jsonbody_dict
|
||||
)
|
||||
|
||||
|
||||
# For llm general chat functions call
|
||||
async def llm_general_chat_funcs_test(llm: BaseLLM, prompt: str, messages: list[dict], resp_cont: str):
|
||||
resp = await llm.aask(prompt, stream=False)
|
||||
assert resp == resp_cont
|
||||
|
||||
resp = await llm.aask(prompt)
|
||||
assert resp == resp_cont
|
||||
|
||||
resp = await llm.acompletion_text(messages, stream=False)
|
||||
assert resp == resp_cont
|
||||
|
||||
resp = await llm.acompletion_text(messages, stream=True)
|
||||
assert resp == resp_cont
|
||||
|
|
|
|||
|
|
@ -17,7 +17,7 @@ from tests.metagpt.provider.req_resp_const import (
|
|||
prompt,
|
||||
)
|
||||
|
||||
llm_name = "GPT"
|
||||
name = "GPT"
|
||||
|
||||
|
||||
class MockBaseLLM(BaseLLM):
|
||||
|
|
@ -25,10 +25,10 @@ class MockBaseLLM(BaseLLM):
|
|||
pass
|
||||
|
||||
def completion(self, messages: list[dict], timeout=3):
|
||||
return get_part_chat_completion(llm_name)
|
||||
return get_part_chat_completion(name)
|
||||
|
||||
async def acompletion(self, messages: list[dict], timeout=3):
|
||||
return get_part_chat_completion(llm_name)
|
||||
return get_part_chat_completion(name)
|
||||
|
||||
async def acompletion_text(self, messages: list[dict], stream=False, timeout=3) -> str:
|
||||
return default_resp_cont
|
||||
|
|
|
|||
|
|
@ -21,10 +21,10 @@ from tests.metagpt.provider.req_resp_const import (
|
|||
resp_cont_tmpl,
|
||||
)
|
||||
|
||||
llm_name = "fireworks"
|
||||
resp_cont = resp_cont_tmpl.format(name=llm_name)
|
||||
default_resp = get_openai_chat_completion(llm_name)
|
||||
default_resp_chunk = get_openai_chat_completion_chunk(llm_name, usage_as_dict=True)
|
||||
name = "fireworks"
|
||||
resp_cont = resp_cont_tmpl.format(name=name)
|
||||
default_resp = get_openai_chat_completion(name)
|
||||
default_resp_chunk = get_openai_chat_completion_chunk(name, usage_as_dict=True)
|
||||
|
||||
|
||||
def test_fireworks_costmanager():
|
||||
|
|
@ -57,27 +57,27 @@ async def mock_openai_acompletions_create(self, stream: bool = False, **kwargs)
|
|||
async def test_fireworks_acompletion(mocker):
|
||||
mocker.patch("openai.resources.chat.completions.AsyncCompletions.create", mock_openai_acompletions_create)
|
||||
|
||||
fireworks_gpt = FireworksLLM(mock_llm_config)
|
||||
fireworks_gpt.model = "llama-v2-13b-chat"
|
||||
fireworks_llm = FireworksLLM(mock_llm_config)
|
||||
fireworks_llm.model = "llama-v2-13b-chat"
|
||||
|
||||
fireworks_gpt._update_costs(
|
||||
fireworks_llm._update_costs(
|
||||
usage=CompletionUsage(prompt_tokens=500000, completion_tokens=500000, total_tokens=1000000)
|
||||
)
|
||||
assert fireworks_gpt.get_costs() == Costs(
|
||||
assert fireworks_llm.get_costs() == Costs(
|
||||
total_prompt_tokens=500000, total_completion_tokens=500000, total_cost=0.5, total_budget=0
|
||||
)
|
||||
|
||||
resp = await fireworks_gpt.acompletion(messages)
|
||||
resp = await fireworks_llm.acompletion(messages)
|
||||
assert resp.choices[0].message.content in resp_cont
|
||||
|
||||
resp = await fireworks_gpt.aask(prompt, stream=False)
|
||||
resp = await fireworks_llm.aask(prompt, stream=False)
|
||||
assert resp == resp_cont
|
||||
|
||||
resp = await fireworks_gpt.acompletion_text(messages, stream=False)
|
||||
resp = await fireworks_llm.acompletion_text(messages, stream=False)
|
||||
assert resp == resp_cont
|
||||
|
||||
resp = await fireworks_gpt.acompletion_text(messages, stream=True)
|
||||
resp = await fireworks_llm.acompletion_text(messages, stream=True)
|
||||
assert resp == resp_cont
|
||||
|
||||
resp = await fireworks_gpt.aask(prompt)
|
||||
resp = await fireworks_llm.aask(prompt)
|
||||
assert resp == resp_cont
|
||||
|
|
|
|||
|
|
@ -63,28 +63,28 @@ async def test_gemini_acompletion(mocker):
|
|||
mock_gemini_generate_content_async,
|
||||
)
|
||||
|
||||
gemini_gpt = GeminiLLM(mock_llm_config)
|
||||
gemini_llm = GeminiLLM(mock_llm_config)
|
||||
|
||||
assert gemini_gpt._user_msg(prompt) == {"role": "user", "parts": [prompt]}
|
||||
assert gemini_gpt._assistant_msg(prompt) == {"role": "model", "parts": [prompt]}
|
||||
assert gemini_llm._user_msg(prompt) == {"role": "user", "parts": [prompt]}
|
||||
assert gemini_llm._assistant_msg(prompt) == {"role": "model", "parts": [prompt]}
|
||||
|
||||
usage = gemini_gpt.get_usage(gemini_messages, resp_cont)
|
||||
usage = gemini_llm.get_usage(gemini_messages, resp_cont)
|
||||
assert usage == {"prompt_tokens": 20, "completion_tokens": 20}
|
||||
|
||||
resp = gemini_gpt.completion(gemini_messages)
|
||||
resp = gemini_llm.completion(gemini_messages)
|
||||
assert resp == default_resp
|
||||
|
||||
resp = await gemini_gpt.acompletion(gemini_messages)
|
||||
resp = await gemini_llm.acompletion(gemini_messages)
|
||||
assert resp.text == default_resp.text
|
||||
|
||||
resp = await gemini_gpt.aask(prompt, stream=False)
|
||||
resp = await gemini_llm.aask(prompt, stream=False)
|
||||
assert resp == resp_cont
|
||||
|
||||
resp = await gemini_gpt.acompletion_text(gemini_messages, stream=False)
|
||||
resp = await gemini_llm.acompletion_text(gemini_messages, stream=False)
|
||||
assert resp == resp_cont
|
||||
|
||||
resp = await gemini_gpt.acompletion_text(gemini_messages, stream=True)
|
||||
resp = await gemini_llm.acompletion_text(gemini_messages, stream=True)
|
||||
assert resp == resp_cont
|
||||
|
||||
resp = await gemini_gpt.aask(prompt)
|
||||
resp = await gemini_llm.aask(prompt)
|
||||
assert resp == resp_cont
|
||||
|
|
|
|||
|
|
@ -39,19 +39,19 @@ async def mock_ollama_arequest(self, stream: bool = False, **kwargs) -> Tuple[An
|
|||
async def test_gemini_acompletion(mocker):
|
||||
mocker.patch("metagpt.provider.general_api_requestor.GeneralAPIRequestor.arequest", mock_ollama_arequest)
|
||||
|
||||
ollama_gpt = OllamaLLM(mock_llm_config)
|
||||
ollama_llm = OllamaLLM(mock_llm_config)
|
||||
|
||||
resp = await ollama_gpt.acompletion(messages)
|
||||
resp = await ollama_llm.acompletion(messages)
|
||||
assert resp["message"]["content"] == default_resp["message"]["content"]
|
||||
|
||||
resp = await ollama_gpt.aask(prompt, stream=False)
|
||||
resp = await ollama_llm.aask(prompt, stream=False)
|
||||
assert resp == resp_cont
|
||||
|
||||
resp = await ollama_gpt.acompletion_text(messages, stream=False)
|
||||
resp = await ollama_llm.acompletion_text(messages, stream=False)
|
||||
assert resp == resp_cont
|
||||
|
||||
resp = await ollama_gpt.acompletion_text(messages, stream=True)
|
||||
resp = await ollama_llm.acompletion_text(messages, stream=True)
|
||||
assert resp == resp_cont
|
||||
|
||||
resp = await ollama_gpt.aask(prompt)
|
||||
resp = await ollama_llm.aask(prompt)
|
||||
assert resp == resp_cont
|
||||
|
|
|
|||
|
|
@ -17,11 +17,11 @@ from tests.metagpt.provider.req_resp_const import (
|
|||
resp_cont_tmpl,
|
||||
)
|
||||
|
||||
llm_name = "llama2-7b"
|
||||
resp_cont = resp_cont_tmpl.format(name=llm_name)
|
||||
default_resp = get_openai_chat_completion(llm_name)
|
||||
name = "llama2-7b"
|
||||
resp_cont = resp_cont_tmpl.format(name=name)
|
||||
default_resp = get_openai_chat_completion(name)
|
||||
|
||||
default_resp_chunk = get_openai_chat_completion_chunk(llm_name)
|
||||
default_resp_chunk = get_openai_chat_completion_chunk(name)
|
||||
|
||||
|
||||
async def mock_openai_acompletions_create(self, stream: bool = False, **kwargs) -> ChatCompletionChunk:
|
||||
|
|
@ -40,26 +40,26 @@ async def mock_openai_acompletions_create(self, stream: bool = False, **kwargs)
|
|||
async def test_openllm_acompletion(mocker):
|
||||
mocker.patch("openai.resources.chat.completions.AsyncCompletions.create", mock_openai_acompletions_create)
|
||||
|
||||
openllm_gpt = OpenLLM(mock_llm_config)
|
||||
openllm_gpt.model = "llama-v2-13b-chat"
|
||||
openllm_llm = OpenLLM(mock_llm_config)
|
||||
openllm_llm.model = "llama-v2-13b-chat"
|
||||
|
||||
openllm_gpt.cost_manager = CostManager()
|
||||
openllm_gpt._update_costs(usage=CompletionUsage(prompt_tokens=100, completion_tokens=100, total_tokens=200))
|
||||
assert openllm_gpt.get_costs() == Costs(
|
||||
openllm_llm.cost_manager = CostManager()
|
||||
openllm_llm._update_costs(usage=CompletionUsage(prompt_tokens=100, completion_tokens=100, total_tokens=200))
|
||||
assert openllm_llm.get_costs() == Costs(
|
||||
total_prompt_tokens=100, total_completion_tokens=100, total_cost=0, total_budget=0
|
||||
)
|
||||
|
||||
resp = await openllm_gpt.acompletion(messages)
|
||||
resp = await openllm_llm.acompletion(messages)
|
||||
assert resp.choices[0].message.content in resp_cont
|
||||
|
||||
resp = await openllm_gpt.aask(prompt, stream=False)
|
||||
resp = await openllm_llm.aask(prompt, stream=False)
|
||||
assert resp == resp_cont
|
||||
|
||||
resp = await openllm_gpt.acompletion_text(messages, stream=False)
|
||||
resp = await openllm_llm.acompletion_text(messages, stream=False)
|
||||
assert resp == resp_cont
|
||||
|
||||
resp = await openllm_gpt.acompletion_text(messages, stream=True)
|
||||
resp = await openllm_llm.acompletion_text(messages, stream=True)
|
||||
assert resp == resp_cont
|
||||
|
||||
resp = await openllm_gpt.aask(prompt)
|
||||
resp = await openllm_llm.aask(prompt)
|
||||
assert resp == resp_cont
|
||||
|
|
|
|||
|
|
@ -2,14 +2,45 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# @Desc : the unittest of qianfan api
|
||||
|
||||
from typing import Dict, Union, AsyncIterator
|
||||
import pytest
|
||||
|
||||
from qianfan.resources.typing import JsonBody, QfResponse
|
||||
|
||||
from metagpt.provider.qianfan_api import QianFanLLM
|
||||
from tests.metagpt.provider.req_resp_const import prompt, messages, resp_cont_tmpl
|
||||
from tests.metagpt.provider.mock_llm_config import mock_llm_config_qianfan
|
||||
from tests.metagpt.provider.req_resp_const import resp_cont_tmpl, prompt, messages, llm_general_chat_funcs_test, get_qianfan_response
|
||||
|
||||
name = "ERNIE-Bot-turbo"
|
||||
resp_cont = resp_cont_tmpl.format(name=name)
|
||||
|
||||
|
||||
resp_cont = resp_cont_tmpl.format(name="ERNIE-Bot-turbo")
|
||||
def mock_qianfan_do(self, messages: list[dict], model: str, stream: bool = False, system: str = None) -> QfResponse:
|
||||
return get_qianfan_response(name=name)
|
||||
|
||||
|
||||
def test_qianfan_acompletion(mocker):
|
||||
assert True, True
|
||||
async def mock_qianfan_ado(self, messages: list[dict], model: str, stream: bool = True, system: str = None) -> Union[QfResponse, AsyncIterator[QfResponse]]:
|
||||
resps = [get_qianfan_response(name=name)]
|
||||
if stream:
|
||||
async def aresp_iterator(resps: list[JsonBody]):
|
||||
for resp in resps:
|
||||
yield resp
|
||||
return aresp_iterator(resps)
|
||||
else:
|
||||
return resps[0]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_qianfan_acompletion(mocker):
|
||||
mocker.patch("qianfan.resources.llm.chat_completion.ChatCompletion.do", mock_qianfan_do)
|
||||
mocker.patch("qianfan.resources.llm.chat_completion.ChatCompletion.ado", mock_qianfan_ado)
|
||||
|
||||
qianfan_llm = QianFanLLM(mock_llm_config_qianfan)
|
||||
|
||||
resp = qianfan_llm.completion(messages)
|
||||
assert resp.get("result") == resp_cont
|
||||
|
||||
resp = await qianfan_llm.acompletion(messages)
|
||||
assert resp.get("result") == resp_cont
|
||||
|
||||
await llm_general_chat_funcs_test(qianfan_llm, prompt, messages, resp_cont)
|
||||
|
|
|
|||
|
|
@ -50,19 +50,19 @@ async def test_spark_aask(mocker):
|
|||
async def test_spark_acompletion(mocker):
|
||||
mocker.patch("metagpt.provider.spark_api.GetMessageFromWeb.run", mock_spark_get_msg_from_web_run)
|
||||
|
||||
spark_gpt = SparkLLM(mock_llm_config)
|
||||
spark_llm = SparkLLM(mock_llm_config)
|
||||
|
||||
resp = await spark_gpt.acompletion([])
|
||||
resp = await spark_llm.acompletion([])
|
||||
assert resp == resp_cont
|
||||
|
||||
resp = await spark_gpt.aask(prompt, stream=False)
|
||||
resp = await spark_llm.aask(prompt, stream=False)
|
||||
assert resp == resp_cont
|
||||
|
||||
resp = await spark_gpt.acompletion_text([], stream=False)
|
||||
resp = await spark_llm.acompletion_text([], stream=False)
|
||||
assert resp == resp_cont
|
||||
|
||||
resp = await spark_gpt.acompletion_text([], stream=True)
|
||||
resp = await spark_llm.acompletion_text([], stream=True)
|
||||
assert resp == resp_cont
|
||||
|
||||
resp = await spark_gpt.aask(prompt)
|
||||
resp = await spark_llm.aask(prompt)
|
||||
assert resp == resp_cont
|
||||
|
|
|
|||
|
|
@ -11,11 +11,12 @@ from tests.metagpt.provider.req_resp_const import (
|
|||
messages,
|
||||
prompt,
|
||||
resp_cont_tmpl,
|
||||
llm_general_chat_funcs_test
|
||||
)
|
||||
|
||||
llm_name = "ChatGLM-4"
|
||||
resp_cont = resp_cont_tmpl.format(name=llm_name)
|
||||
default_resp = get_part_chat_completion(llm_name)
|
||||
name = "ChatGLM-4"
|
||||
resp_cont = resp_cont_tmpl.format(name=name)
|
||||
default_resp = get_part_chat_completion(name)
|
||||
|
||||
|
||||
async def mock_zhipuai_acreate_stream(**kwargs):
|
||||
|
|
@ -47,22 +48,12 @@ async def test_zhipuai_acompletion(mocker):
|
|||
mocker.patch("metagpt.provider.zhipuai.zhipu_model_api.ZhiPuModelAPI.acreate", mock_zhipuai_acreate)
|
||||
mocker.patch("metagpt.provider.zhipuai.zhipu_model_api.ZhiPuModelAPI.acreate_stream", mock_zhipuai_acreate_stream)
|
||||
|
||||
zhipu_gpt = ZhiPuAILLM(mock_llm_config_zhipu)
|
||||
zhipu_llm = ZhiPuAILLM(mock_llm_config_zhipu)
|
||||
|
||||
resp = await zhipu_gpt.acompletion(messages)
|
||||
resp = await zhipu_llm.acompletion(messages)
|
||||
assert resp["choices"][0]["message"]["content"] == resp_cont
|
||||
|
||||
resp = await zhipu_gpt.aask(prompt, stream=False)
|
||||
assert resp == resp_cont
|
||||
|
||||
resp = await zhipu_gpt.acompletion_text(messages, stream=False)
|
||||
assert resp == resp_cont
|
||||
|
||||
resp = await zhipu_gpt.acompletion_text(messages, stream=True)
|
||||
assert resp == resp_cont
|
||||
|
||||
resp = await zhipu_gpt.aask(prompt)
|
||||
assert resp == resp_cont
|
||||
await llm_general_chat_funcs_test(zhipu_llm, prompt, messages, resp_cont)
|
||||
|
||||
|
||||
def test_zhipuai_proxy():
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue