simplify provider ut code

This commit is contained in:
better629 2024-02-07 17:40:27 +08:00
parent 15a9c5e941
commit dc240a2efd
14 changed files with 235 additions and 201 deletions

View file

@ -54,7 +54,6 @@ jobs:
export ALLOW_OPENAI_API_CALL=0
echo "${{ secrets.METAGPT_KEY_YAML }}" | base64 -d > config/key.yaml
mkdir -p ~/.metagpt && echo "${{ secrets.METAGPT_CONFIG2_YAML }}" | base64 -d > ~/.metagpt/config2.yaml
echo "${{ secrets.SPARK_YAML }}" | base64 -d > ~/.metagpt/spark.yaml
pytest tests/ --doctest-modules --cov=./metagpt/ --cov-report=xml:cov.xml --cov-report=html:htmlcov --durations=20 | tee unittest.txt
- name: Show coverage report
run: |

View file

@ -31,7 +31,7 @@ jobs:
- name: Test with pytest
run: |
export ALLOW_OPENAI_API_CALL=0
mkdir -p ~/.metagpt && cp tests/config2.yaml ~/.metagpt/config2.yaml && cp tests/spark.yaml ~/.metagpt/spark.yaml
mkdir -p ~/.metagpt && cp tests/config2.yaml ~/.metagpt/config2.yaml
pytest tests/ --doctest-modules --cov=./metagpt/ --cov-report=xml:cov.xml --cov-report=html:htmlcov --durations=20 | tee unittest.txt
- name: Show coverage report
run: |

View file

@ -42,3 +42,13 @@ mock_llm_config_zhipu = LLMConfig(
model="mock_zhipu_model",
proxy="http://localhost:8080",
)
mock_llm_config_spark = LLMConfig(
api_type="spark",
app_id="xxx",
api_key="xxx",
api_secret="xxx",
domain="generalv2",
base_url="wss://spark-api.xf-yun.com/v3.1/chat",
)

View file

@ -0,0 +1,80 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc : default request & response data for provider unittest
from openai.types.chat.chat_completion import (
ChatCompletion,
ChatCompletionMessage,
Choice,
)
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
from openai.types.chat.chat_completion_chunk import Choice as AChoice
from openai.types.chat.chat_completion_chunk import ChoiceDelta
from openai.types.completion_usage import CompletionUsage
prompt = "who are you?"
messages = [{"role": "user", "content": prompt}]
resp_cont_tmpl = "I'm {name}"
default_resp_cont = resp_cont_tmpl.format(name="GPT")
# part of whole ChatCompletion of openai like structure
def get_part_chat_completion(llm_name: str) -> dict:
part_chat_completion = {
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"content": resp_cont_tmpl.format(name=llm_name),
},
"finish_reason": "stop",
}
],
"usage": {"completion_tokens": 22, "prompt_tokens": 19, "total_tokens": 41},
}
return part_chat_completion
def get_openai_chat_completion(llm_name: str) -> ChatCompletion:
openai_chat_completion = ChatCompletion(
id="cmpl-a6652c1bb181caae8dd19ad8",
model="xx/xxx",
object="chat.completion",
created=1703300855,
choices=[
Choice(
finish_reason="stop",
index=0,
message=ChatCompletionMessage(role="assistant", content=resp_cont_tmpl.format(name=llm_name)),
logprobs=None,
)
],
usage=CompletionUsage(completion_tokens=110, prompt_tokens=92, total_tokens=202),
)
return openai_chat_completion
def get_openai_chat_completion_chunk(llm_name: str, usage_as_dict: bool = False) -> ChatCompletionChunk:
usage = CompletionUsage(completion_tokens=110, prompt_tokens=92, total_tokens=202)
usage = usage if not usage_as_dict else usage.model_dump()
openai_chat_completion_chunk = ChatCompletionChunk(
id="cmpl-a6652c1bb181caae8dd19ad8",
model="xx/xxx",
object="chat.completion.chunk",
created=1703300855,
choices=[
AChoice(
delta=ChoiceDelta(role="assistant", content=resp_cont_tmpl.format(name=llm_name)),
finish_reason="stop",
index=0,
logprobs=None,
)
],
usage=usage,
)
return openai_chat_completion_chunk
gemini_messages = [{"role": "user", "parts": prompt}]

View file

@ -8,25 +8,25 @@ from anthropic.resources.completions import Completion
from metagpt.provider.anthropic_api import Claude2
from tests.metagpt.provider.mock_llm_config import mock_llm_config
from tests.metagpt.provider.req_resp_const import prompt, resp_cont_tmpl
prompt = "who are you"
resp = "I'am Claude2"
resp_cont = resp_cont_tmpl.format(name="Claude")
def mock_anthropic_completions_create(self, model: str, prompt: str, max_tokens_to_sample: int) -> Completion:
return Completion(id="xx", completion=resp, model="claude-2", stop_reason="stop_sequence", type="completion")
return Completion(id="xx", completion=resp_cont, model="claude-2", stop_reason="stop_sequence", type="completion")
async def mock_anthropic_acompletions_create(self, model: str, prompt: str, max_tokens_to_sample: int) -> Completion:
return Completion(id="xx", completion=resp, model="claude-2", stop_reason="stop_sequence", type="completion")
return Completion(id="xx", completion=resp_cont, model="claude-2", stop_reason="stop_sequence", type="completion")
def test_claude2_ask(mocker):
mocker.patch("anthropic.resources.completions.Completions.create", mock_anthropic_completions_create)
assert resp == Claude2(mock_llm_config).ask(prompt)
assert resp_cont == Claude2(mock_llm_config).ask(prompt)
@pytest.mark.asyncio
async def test_claude2_aask(mocker):
mocker.patch("anthropic.resources.completions.AsyncCompletions.create", mock_anthropic_acompletions_create)
assert resp == await Claude2(mock_llm_config).aask(prompt)
assert resp_cont == await Claude2(mock_llm_config).aask(prompt)

View file

@ -11,21 +11,13 @@ import pytest
from metagpt.configs.llm_config import LLMConfig
from metagpt.provider.base_llm import BaseLLM
from metagpt.schema import Message
from tests.metagpt.provider.req_resp_const import (
default_resp_cont,
get_part_chat_completion,
prompt,
)
default_chat_resp = {
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"content": "I'am GPT",
},
"finish_reason": "stop",
}
]
}
prompt_msg = "who are you"
resp_content = default_chat_resp["choices"][0]["message"]["content"]
llm_name = "GPT"
class MockBaseLLM(BaseLLM):
@ -33,16 +25,13 @@ class MockBaseLLM(BaseLLM):
pass
def completion(self, messages: list[dict], timeout=3):
return default_chat_resp
return get_part_chat_completion(llm_name)
async def acompletion(self, messages: list[dict], timeout=3):
return default_chat_resp
return get_part_chat_completion(llm_name)
async def acompletion_text(self, messages: list[dict], stream=False, timeout=3) -> str:
return resp_content
async def close(self):
return default_chat_resp
return default_resp_cont
def test_base_llm():
@ -86,25 +75,25 @@ def test_base_llm():
choice_text = base_llm.get_choice_text(openai_funccall_resp)
assert choice_text == openai_funccall_resp["choices"][0]["message"]["content"]
# resp = base_llm.ask(prompt_msg)
# assert resp == resp_content
# resp = base_llm.ask(prompt)
# assert resp == default_resp_cont
# resp = base_llm.ask_batch([prompt_msg])
# assert resp == resp_content
# resp = base_llm.ask_batch([prompt])
# assert resp == default_resp_cont
# resp = base_llm.ask_code([prompt_msg])
# assert resp == resp_content
# resp = base_llm.ask_code([prompt])
# assert resp == default_resp_cont
@pytest.mark.asyncio
async def test_async_base_llm():
base_llm = MockBaseLLM()
resp = await base_llm.aask(prompt_msg)
assert resp == resp_content
resp = await base_llm.aask(prompt)
assert resp == default_resp_cont
resp = await base_llm.aask_batch([prompt_msg])
assert resp == resp_content
resp = await base_llm.aask_batch([prompt])
assert resp == default_resp_cont
# resp = await base_llm.aask_code([prompt_msg])
# assert resp == resp_content
# resp = await base_llm.aask_code([prompt])
# assert resp == default_resp_cont

View file

@ -3,14 +3,7 @@
# @Desc : the unittest of fireworks api
import pytest
from openai.types.chat.chat_completion import (
ChatCompletion,
ChatCompletionMessage,
Choice,
)
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
from openai.types.chat.chat_completion_chunk import Choice as AChoice
from openai.types.chat.chat_completion_chunk import ChoiceDelta
from openai.types.completion_usage import CompletionUsage
from metagpt.provider.fireworks_api import (
@ -20,42 +13,18 @@ from metagpt.provider.fireworks_api import (
)
from metagpt.utils.cost_manager import Costs
from tests.metagpt.provider.mock_llm_config import mock_llm_config
resp_content = "I'm fireworks"
default_resp = ChatCompletion(
id="cmpl-a6652c1bb181caae8dd19ad8",
model="accounts/fireworks/models/llama-v2-13b-chat",
object="chat.completion",
created=1703300855,
choices=[
Choice(
finish_reason="stop",
index=0,
message=ChatCompletionMessage(role="assistant", content=resp_content),
logprobs=None,
)
],
usage=CompletionUsage(completion_tokens=110, prompt_tokens=92, total_tokens=202),
from tests.metagpt.provider.req_resp_const import (
get_openai_chat_completion,
get_openai_chat_completion_chunk,
messages,
prompt,
resp_cont_tmpl,
)
default_resp_chunk = ChatCompletionChunk(
id=default_resp.id,
model=default_resp.model,
object="chat.completion.chunk",
created=default_resp.created,
choices=[
AChoice(
delta=ChoiceDelta(content=resp_content, role="assistant"),
finish_reason="stop",
index=0,
logprobs=None,
)
],
usage=dict(default_resp.usage),
)
prompt_msg = "who are you"
messages = [{"role": "user", "content": prompt_msg}]
llm_name = "fireworks"
resp_cont = resp_cont_tmpl.format(name=llm_name)
default_resp = get_openai_chat_completion(llm_name)
default_resp_chunk = get_openai_chat_completion_chunk(llm_name, usage_as_dict=True)
def test_fireworks_costmanager():
@ -99,16 +68,16 @@ async def test_fireworks_acompletion(mocker):
)
resp = await fireworks_gpt.acompletion(messages)
assert resp.choices[0].message.content in resp_content
assert resp.choices[0].message.content in resp_cont
resp = await fireworks_gpt.aask(prompt_msg, stream=False)
assert resp == resp_content
resp = await fireworks_gpt.aask(prompt, stream=False)
assert resp == resp_cont
resp = await fireworks_gpt.acompletion_text(messages, stream=False)
assert resp == resp_content
assert resp == resp_cont
resp = await fireworks_gpt.acompletion_text(messages, stream=True)
assert resp == resp_content
assert resp == resp_cont
resp = await fireworks_gpt.aask(prompt_msg)
assert resp == resp_content
resp = await fireworks_gpt.aask(prompt)
assert resp == resp_cont

View file

@ -11,6 +11,11 @@ from google.generativeai.types import content_types
from metagpt.provider.google_gemini_api import GeminiLLM
from tests.metagpt.provider.mock_llm_config import mock_llm_config
from tests.metagpt.provider.req_resp_const import (
gemini_messages,
prompt,
resp_cont_tmpl,
)
@dataclass
@ -18,10 +23,8 @@ class MockGeminiResponse(ABC):
text: str
prompt_msg = "who are you"
messages = [{"role": "user", "parts": prompt_msg}]
resp_content = "I'm gemini from google"
default_resp = MockGeminiResponse(text=resp_content)
resp_cont = resp_cont_tmpl.format(name="gemini")
default_resp = MockGeminiResponse(text=resp_cont)
def mock_gemini_count_tokens(self, contents: content_types.ContentsType) -> glm.CountTokensResponse:
@ -62,26 +65,26 @@ async def test_gemini_acompletion(mocker):
gemini_gpt = GeminiLLM(mock_llm_config)
assert gemini_gpt._user_msg(prompt_msg) == {"role": "user", "parts": [prompt_msg]}
assert gemini_gpt._assistant_msg(prompt_msg) == {"role": "model", "parts": [prompt_msg]}
assert gemini_gpt._user_msg(prompt) == {"role": "user", "parts": [prompt]}
assert gemini_gpt._assistant_msg(prompt) == {"role": "model", "parts": [prompt]}
usage = gemini_gpt.get_usage(messages, resp_content)
usage = gemini_gpt.get_usage(gemini_messages, resp_cont)
assert usage == {"prompt_tokens": 20, "completion_tokens": 20}
resp = gemini_gpt.completion(messages)
resp = gemini_gpt.completion(gemini_messages)
assert resp == default_resp
resp = await gemini_gpt.acompletion(messages)
resp = await gemini_gpt.acompletion(gemini_messages)
assert resp.text == default_resp.text
resp = await gemini_gpt.aask(prompt_msg, stream=False)
assert resp == resp_content
resp = await gemini_gpt.aask(prompt, stream=False)
assert resp == resp_cont
resp = await gemini_gpt.acompletion_text(messages, stream=False)
assert resp == resp_content
resp = await gemini_gpt.acompletion_text(gemini_messages, stream=False)
assert resp == resp_cont
resp = await gemini_gpt.acompletion_text(messages, stream=True)
assert resp == resp_content
resp = await gemini_gpt.acompletion_text(gemini_messages, stream=True)
assert resp == resp_cont
resp = await gemini_gpt.aask(prompt_msg)
assert resp == resp_content
resp = await gemini_gpt.aask(prompt)
assert resp == resp_cont

View file

@ -9,12 +9,10 @@ import pytest
from metagpt.provider.ollama_api import OllamaLLM
from tests.metagpt.provider.mock_llm_config import mock_llm_config
from tests.metagpt.provider.req_resp_const import messages, prompt, resp_cont_tmpl
prompt_msg = "who are you"
messages = [{"role": "user", "content": prompt_msg}]
resp_content = "I'm ollama"
default_resp = {"message": {"role": "assistant", "content": resp_content}}
resp_cont = resp_cont_tmpl.format(name="ollama")
default_resp = {"message": {"role": "assistant", "content": resp_cont}}
async def mock_ollama_arequest(self, stream: bool = False, **kwargs) -> Tuple[Any, Any, bool]:
@ -46,14 +44,14 @@ async def test_gemini_acompletion(mocker):
resp = await ollama_gpt.acompletion(messages)
assert resp["message"]["content"] == default_resp["message"]["content"]
resp = await ollama_gpt.aask(prompt_msg, stream=False)
assert resp == resp_content
resp = await ollama_gpt.aask(prompt, stream=False)
assert resp == resp_cont
resp = await ollama_gpt.acompletion_text(messages, stream=False)
assert resp == resp_content
assert resp == resp_cont
resp = await ollama_gpt.acompletion_text(messages, stream=True)
assert resp == resp_content
assert resp == resp_cont
resp = await ollama_gpt.aask(prompt_msg)
assert resp == resp_content
resp = await ollama_gpt.aask(prompt)
assert resp == resp_cont

View file

@ -3,53 +3,25 @@
# @Desc :
import pytest
from openai.types.chat.chat_completion import (
ChatCompletion,
ChatCompletionMessage,
Choice,
)
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
from openai.types.chat.chat_completion_chunk import Choice as AChoice
from openai.types.chat.chat_completion_chunk import ChoiceDelta
from openai.types.completion_usage import CompletionUsage
from metagpt.provider.open_llm_api import OpenLLM
from metagpt.utils.cost_manager import Costs
from metagpt.utils.cost_manager import CostManager, Costs
from tests.metagpt.provider.mock_llm_config import mock_llm_config
resp_content = "I'm llama2"
default_resp = ChatCompletion(
id="cmpl-a6652c1bb181caae8dd19ad8",
model="llama-v2-13b-chat",
object="chat.completion",
created=1703302755,
choices=[
Choice(
finish_reason="stop",
index=0,
message=ChatCompletionMessage(role="assistant", content=resp_content),
logprobs=None,
)
],
from tests.metagpt.provider.req_resp_const import (
get_openai_chat_completion,
get_openai_chat_completion_chunk,
messages,
prompt,
resp_cont_tmpl,
)
default_resp_chunk = ChatCompletionChunk(
id=default_resp.id,
model=default_resp.model,
object="chat.completion.chunk",
created=default_resp.created,
choices=[
AChoice(
delta=ChoiceDelta(content=resp_content, role="assistant"),
finish_reason="stop",
index=0,
logprobs=None,
)
],
)
llm_name = "llama2-7b"
resp_cont = resp_cont_tmpl.format(name=llm_name)
default_resp = get_openai_chat_completion(llm_name)
prompt_msg = "who are you"
messages = [{"role": "user", "content": prompt_msg}]
default_resp_chunk = get_openai_chat_completion_chunk(llm_name)
async def mock_openai_acompletions_create(self, stream: bool = False, **kwargs) -> ChatCompletionChunk:
@ -71,22 +43,23 @@ async def test_openllm_acompletion(mocker):
openllm_gpt = OpenLLM(mock_llm_config)
openllm_gpt.model = "llama-v2-13b-chat"
openllm_gpt.cost_manager = CostManager()
openllm_gpt._update_costs(usage=CompletionUsage(prompt_tokens=100, completion_tokens=100, total_tokens=200))
assert openllm_gpt.get_costs() == Costs(
total_prompt_tokens=100, total_completion_tokens=100, total_cost=0, total_budget=0
)
resp = await openllm_gpt.acompletion(messages)
assert resp.choices[0].message.content in resp_content
assert resp.choices[0].message.content in resp_cont
resp = await openllm_gpt.aask(prompt_msg, stream=False)
assert resp == resp_content
resp = await openllm_gpt.aask(prompt, stream=False)
assert resp == resp_cont
resp = await openllm_gpt.acompletion_text(messages, stream=False)
assert resp == resp_content
assert resp == resp_cont
resp = await openllm_gpt.acompletion_text(messages, stream=True)
assert resp == resp_content
assert resp == resp_cont
resp = await openllm_gpt.aask(prompt_msg)
assert resp == resp_content
resp = await openllm_gpt.aask(prompt)
assert resp == resp_cont

View file

@ -0,0 +1,15 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc : the unittest of qianfan api
import pytest
from metagpt.provider.qianfan_api import QianFanLLM
from tests.metagpt.provider.req_resp_const import prompt, messages, resp_cont_tmpl
resp_cont = resp_cont_tmpl.format(name="ERNIE-Bot-turbo")
def test_qianfan_acompletion(mocker):
assert True, True

View file

@ -4,12 +4,14 @@
import pytest
from metagpt.config2 import Config
from metagpt.provider.spark_api import GetMessageFromWeb, SparkLLM
from tests.metagpt.provider.mock_llm_config import mock_llm_config
from tests.metagpt.provider.mock_llm_config import (
mock_llm_config,
mock_llm_config_spark,
)
from tests.metagpt.provider.req_resp_const import prompt, resp_cont_tmpl
prompt_msg = "who are you"
resp_content = "I'm Spark"
resp_cont = resp_cont_tmpl.format(name="Spark")
class MockWebSocketApp(object):
@ -23,7 +25,7 @@ class MockWebSocketApp(object):
def test_get_msg_from_web(mocker):
mocker.patch("websocket.WebSocketApp", MockWebSocketApp)
get_msg_from_web = GetMessageFromWeb(prompt_msg, mock_llm_config)
get_msg_from_web = GetMessageFromWeb(prompt, mock_llm_config)
assert get_msg_from_web.gen_params()["parameter"]["chat"]["domain"] == "mock_domain"
ret = get_msg_from_web.run()
@ -31,15 +33,17 @@ def test_get_msg_from_web(mocker):
def mock_spark_get_msg_from_web_run(self) -> str:
return resp_content
return resp_cont
@pytest.mark.asyncio
async def test_spark_aask():
llm = SparkLLM(Config.from_home("spark.yaml").llm)
async def test_spark_aask(mocker):
mocker.patch("metagpt.provider.spark_api.GetMessageFromWeb.run", mock_spark_get_msg_from_web_run)
llm = SparkLLM(mock_llm_config_spark)
resp = await llm.aask("Hello!")
print(resp)
assert resp == resp_cont
@pytest.mark.asyncio
@ -49,16 +53,16 @@ async def test_spark_acompletion(mocker):
spark_gpt = SparkLLM(mock_llm_config)
resp = await spark_gpt.acompletion([])
assert resp == resp_content
assert resp == resp_cont
resp = await spark_gpt.aask(prompt_msg, stream=False)
assert resp == resp_content
resp = await spark_gpt.aask(prompt, stream=False)
assert resp == resp_cont
resp = await spark_gpt.acompletion_text([], stream=False)
assert resp == resp_content
assert resp == resp_cont
resp = await spark_gpt.acompletion_text([], stream=True)
assert resp == resp_content
assert resp == resp_cont
resp = await spark_gpt.aask(prompt_msg)
assert resp == resp_content
resp = await spark_gpt.aask(prompt)
assert resp == resp_cont

View file

@ -6,22 +6,23 @@ import pytest
from metagpt.provider.zhipuai_api import ZhiPuAILLM
from tests.metagpt.provider.mock_llm_config import mock_llm_config_zhipu
from tests.metagpt.provider.req_resp_const import (
get_part_chat_completion,
messages,
prompt,
resp_cont_tmpl,
)
prompt_msg = "who are you"
messages = [{"role": "user", "content": prompt_msg}]
resp_content = "I'm chatglm-turbo"
default_resp = {
"choices": [{"finish_reason": "stop", "index": 0, "message": {"content": resp_content, "role": "assistant"}}],
"usage": {"completion_tokens": 22, "prompt_tokens": 19, "total_tokens": 41},
}
llm_name = "ChatGLM-4"
resp_cont = resp_cont_tmpl.format(name=llm_name)
default_resp = get_part_chat_completion(llm_name)
async def mock_zhipuai_acreate_stream(**kwargs):
class MockResponse(object):
async def _aread(self):
class Iterator(object):
events = [{"choices": [{"index": 0, "delta": {"content": resp_content, "role": "assistant"}}]}]
events = [{"choices": [{"index": 0, "delta": {"content": resp_cont, "role": "assistant"}}]}]
async def __aiter__(self):
for event in self.events:
@ -49,19 +50,19 @@ async def test_zhipuai_acompletion(mocker):
zhipu_gpt = ZhiPuAILLM(mock_llm_config_zhipu)
resp = await zhipu_gpt.acompletion(messages)
assert resp["choices"][0]["message"]["content"] == resp_content
assert resp["choices"][0]["message"]["content"] == resp_cont
resp = await zhipu_gpt.aask(prompt_msg, stream=False)
assert resp == resp_content
resp = await zhipu_gpt.aask(prompt, stream=False)
assert resp == resp_cont
resp = await zhipu_gpt.acompletion_text(messages, stream=False)
assert resp == resp_content
assert resp == resp_cont
resp = await zhipu_gpt.acompletion_text(messages, stream=True)
assert resp == resp_content
assert resp == resp_cont
resp = await zhipu_gpt.aask(prompt_msg)
assert resp == resp_content
resp = await zhipu_gpt.aask(prompt)
assert resp == resp_cont
def test_zhipuai_proxy():

View file

@ -1,7 +0,0 @@
llm:
api_type: "spark"
app_id: "xxx"
api_key: "xxx"
api_secret: "xxx"
domain: "generalv2"
base_url: "wss://spark-api.xf-yun.com/v3.1/chat"