mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-05-24 14:15:17 +02:00
Merge branch 'dev' into dev
This commit is contained in:
commit
539e1c7dce
81 changed files with 1402 additions and 649 deletions
|
|
@ -12,6 +12,7 @@ import pytest
|
|||
from metagpt.actions import Action
|
||||
from metagpt.actions.action_node import ActionNode
|
||||
from metagpt.environment import Environment
|
||||
from metagpt.llm import LLM
|
||||
from metagpt.roles import Role
|
||||
from metagpt.schema import Message
|
||||
from metagpt.team import Team
|
||||
|
|
@ -76,18 +77,24 @@ async def test_action_node_one_layer():
|
|||
assert "key-a" in markdown_template
|
||||
|
||||
assert node_dict["key-a"] == "instruction-b"
|
||||
assert "key-a" in repr(node)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_action_node_two_layer():
|
||||
node_a = ActionNode(key="key-a", expected_type=str, instruction="i-a", example="e-a")
|
||||
node_b = ActionNode(key="key-b", expected_type=str, instruction="i-b", example="e-b")
|
||||
node_a = ActionNode(key="reasoning", expected_type=str, instruction="reasoning step by step", example="")
|
||||
node_b = ActionNode(key="answer", expected_type=str, instruction="the final answer", example="")
|
||||
|
||||
root = ActionNode.from_children(key="", nodes=[node_a, node_b])
|
||||
assert "key-a" in root.children
|
||||
root = ActionNode.from_children(key="detail answer", nodes=[node_a, node_b])
|
||||
assert "reasoning" in root.children
|
||||
assert node_b in root.children.values()
|
||||
json_template = root.compile(context="123", schema="json", mode="auto")
|
||||
assert "i-a" in json_template
|
||||
|
||||
# FIXME: ADD MARKDOWN SUPPORT. NEED TO TUNE MARKDOWN SYMBOL FIRST.
|
||||
answer1 = await root.fill(context="what's the answer to 123+456?", schema="json", strgy="simple", llm=LLM())
|
||||
assert "579" in answer1.content
|
||||
|
||||
answer2 = await root.fill(context="what's the answer to 123+456?", schema="json", strgy="complex", llm=LLM())
|
||||
assert "579" in answer2.content
|
||||
|
||||
|
||||
t_dict = {
|
||||
|
|
@ -116,11 +123,28 @@ WRITE_TASKS_OUTPUT_MAPPING = {
|
|||
"Anything UNCLEAR": (str, ...),
|
||||
}
|
||||
|
||||
WRITE_TASKS_OUTPUT_MAPPING_MISSING = {
|
||||
"Required Python third-party packages": (str, ...),
|
||||
}
|
||||
|
||||
|
||||
def test_create_model_class():
|
||||
test_class = ActionNode.create_model_class("test_class", WRITE_TASKS_OUTPUT_MAPPING)
|
||||
assert test_class.__name__ == "test_class"
|
||||
|
||||
output = test_class(**t_dict)
|
||||
print(output.schema())
|
||||
assert output.schema()["title"] == "test_class"
|
||||
assert output.schema()["type"] == "object"
|
||||
assert output.schema()["properties"]["Full API spec"]
|
||||
|
||||
|
||||
def test_create_model_class_missing():
|
||||
test_class = ActionNode.create_model_class("test_class", WRITE_TASKS_OUTPUT_MAPPING_MISSING)
|
||||
assert test_class.__name__ == "test_class"
|
||||
|
||||
_ = test_class(**t_dict) # 这里应该要挂掉
|
||||
|
||||
|
||||
def test_create_model_class_with_mapping():
|
||||
t = ActionNode.create_model_class("test_class_1", WRITE_TASKS_OUTPUT_MAPPING)
|
||||
|
|
|
|||
|
|
@ -1,16 +0,0 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2023/7/1 22:50
|
||||
@Author : alexanderwu
|
||||
@File : test_azure_tts.py
|
||||
"""
|
||||
from metagpt.tools.azure_tts import AzureTTS
|
||||
|
||||
|
||||
def test_azure_tts():
|
||||
azure_tts = AzureTTS()
|
||||
azure_tts.synthesize_speech("zh-CN", "zh-CN-YunxiNeural", "Boy", "你好,我是卡卡", "output.wav")
|
||||
|
||||
# 运行需要先配置 SUBSCRIPTION_KEY
|
||||
# TODO: 这里如果要检验,还要额外加上对应的asr,才能确保前后生成是接近一致的,但现在还没有
|
||||
|
|
@ -1,101 +0,0 @@
|
|||
import os
|
||||
import tempfile
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.actions.clone_function import (
|
||||
CloneFunction,
|
||||
run_function_code,
|
||||
run_function_script,
|
||||
)
|
||||
|
||||
source_code = """
|
||||
import pandas as pd
|
||||
import ta
|
||||
|
||||
def user_indicator():
|
||||
# 读取股票数据
|
||||
stock_data = pd.read_csv('./tests/data/baba_stock.csv')
|
||||
stock_data.head()
|
||||
# 计算简单移动平均线
|
||||
stock_data['SMA'] = ta.trend.sma_indicator(stock_data['Close'], window=6)
|
||||
stock_data[['Date', 'Close', 'SMA']].head()
|
||||
# 计算布林带
|
||||
stock_data['bb_upper'], stock_data['bb_middle'], stock_data['bb_lower'] = ta.volatility.bollinger_hband_indicator(stock_data['Close'], window=20), ta.volatility.bollinger_mavg(stock_data['Close'], window=20), ta.volatility.bollinger_lband_indicator(stock_data['Close'], window=20)
|
||||
stock_data[['Date', 'Close', 'bb_upper', 'bb_middle', 'bb_lower']].head()
|
||||
"""
|
||||
|
||||
template_code = """
|
||||
def stock_indicator(stock_path: str, indicators=['Simple Moving Average', 'BollingerBands', 'MACD]) -> pd.DataFrame:
|
||||
import pandas as pd
|
||||
# here is your code.
|
||||
"""
|
||||
|
||||
|
||||
def get_expected_res():
|
||||
import pandas as pd
|
||||
import ta
|
||||
|
||||
# 读取股票数据
|
||||
stock_data = pd.read_csv("./tests/data/baba_stock.csv")
|
||||
stock_data.head()
|
||||
# 计算简单移动平均线
|
||||
stock_data["SMA"] = ta.trend.sma_indicator(stock_data["Close"], window=6)
|
||||
stock_data[["Date", "Close", "SMA"]].head()
|
||||
# 计算布林带
|
||||
stock_data["bb_upper"], stock_data["bb_middle"], stock_data["bb_lower"] = (
|
||||
ta.volatility.bollinger_hband_indicator(stock_data["Close"], window=20),
|
||||
ta.volatility.bollinger_mavg(stock_data["Close"], window=20),
|
||||
ta.volatility.bollinger_lband_indicator(stock_data["Close"], window=20),
|
||||
)
|
||||
stock_data[["Date", "Close", "bb_upper", "bb_middle", "bb_lower"]].head()
|
||||
return stock_data
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_clone_function():
|
||||
clone = CloneFunction()
|
||||
code = await clone.run(template_code, source_code)
|
||||
assert "def " in code
|
||||
stock_path = "./tests/data/baba_stock.csv"
|
||||
df, msg = run_function_code(code, "stock_indicator", stock_path)
|
||||
assert not msg
|
||||
expected_df = get_expected_res()
|
||||
assert df.equals(expected_df)
|
||||
|
||||
|
||||
def test_run_function_script():
|
||||
# 创建一个临时文件并写入脚本内容
|
||||
script_content = """def valid_function(arg1, arg2):\n return arg1 + arg2\n"""
|
||||
with tempfile.NamedTemporaryFile(mode="w+", suffix=".py", delete=False) as temp_file:
|
||||
temp_file.write(script_content)
|
||||
temp_file_path = temp_file.name
|
||||
|
||||
invalid_script_content = """def valid_function(arg1, arg2)\n return arg1 + arg2\n"""
|
||||
with tempfile.NamedTemporaryFile(mode="w+", suffix=".py", delete=False) as error_temp_file:
|
||||
error_temp_file.write(invalid_script_content)
|
||||
error_temp_file_path = error_temp_file.name
|
||||
|
||||
try:
|
||||
# 正常情况下运行脚本
|
||||
result, _ = run_function_script(temp_file_path, "valid_function", 1, arg2=2)
|
||||
assert result == 3
|
||||
|
||||
# 不存在的脚本路径
|
||||
with pytest.raises(FileNotFoundError):
|
||||
run_function_script("nonexistent/path/script.py", "valid_function", 1, arg2=2)
|
||||
|
||||
# 无效的脚本内容
|
||||
result, traceback = run_function_script(error_temp_file_path, "invalid_function", 1, arg2=2)
|
||||
assert not result
|
||||
assert "SyntaxError" in traceback
|
||||
|
||||
# 函数调用失败的情况
|
||||
result, traceback = run_function_script(temp_file_path, "function_that_raises_exception", 1, arg2=2)
|
||||
assert not result
|
||||
assert "KeyError" in traceback
|
||||
|
||||
finally:
|
||||
# 删除临时文件
|
||||
if os.path.exists(temp_file_path):
|
||||
os.remove(temp_file_path)
|
||||
|
|
@ -6,27 +6,26 @@
|
|||
@Author : Stitch-z
|
||||
@File : test_invoice_ocr.py
|
||||
"""
|
||||
import json
|
||||
import os
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.actions.invoice_ocr import GenerateTable, InvoiceOCR, ReplyQuestion
|
||||
from metagpt.const import TEST_DATA_PATH
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"invoice_path",
|
||||
[
|
||||
"../../data/invoices/invoice-3.jpg",
|
||||
"../../data/invoices/invoice-4.zip",
|
||||
Path("invoices/invoice-3.jpg"),
|
||||
Path("invoices/invoice-4.zip"),
|
||||
],
|
||||
)
|
||||
async def test_invoice_ocr(invoice_path: str):
|
||||
invoice_path = os.path.abspath(os.path.join(os.getcwd(), invoice_path))
|
||||
filename = os.path.basename(invoice_path)
|
||||
resp = await InvoiceOCR().run(file_path=Path(invoice_path), filename=filename)
|
||||
async def test_invoice_ocr(invoice_path: Path):
|
||||
invoice_path = TEST_DATA_PATH / invoice_path
|
||||
resp = await InvoiceOCR().run(file_path=Path(invoice_path))
|
||||
assert isinstance(resp, list)
|
||||
|
||||
|
||||
|
|
@ -34,25 +33,29 @@ async def test_invoice_ocr(invoice_path: str):
|
|||
@pytest.mark.parametrize(
|
||||
("invoice_path", "expected_result"),
|
||||
[
|
||||
("../../data/invoices/invoice-1.pdf", [{"收款人": "小明", "城市": "深圳市", "总费用/元": "412.00", "开票日期": "2023年02月03日"}]),
|
||||
(Path("invoices/invoice-1.pdf"), {"收款人": "小明", "城市": "深圳", "总费用/元": 412.00, "开票日期": "2023年02月03日"}),
|
||||
],
|
||||
)
|
||||
async def test_generate_table(invoice_path: str, expected_result: list[dict]):
|
||||
invoice_path = os.path.abspath(os.path.join(os.getcwd(), invoice_path))
|
||||
filename = os.path.basename(invoice_path)
|
||||
ocr_result = await InvoiceOCR().run(file_path=Path(invoice_path), filename=filename)
|
||||
async def test_generate_table(invoice_path: Path, expected_result: dict):
|
||||
invoice_path = TEST_DATA_PATH / invoice_path
|
||||
filename = invoice_path.name
|
||||
ocr_result = await InvoiceOCR().run(file_path=Path(invoice_path))
|
||||
table_data = await GenerateTable().run(ocr_results=ocr_result, filename=filename)
|
||||
assert json.dumps(table_data) == json.dumps(expected_result)
|
||||
assert isinstance(table_data, list)
|
||||
table_data = table_data[0]
|
||||
assert expected_result["收款人"] == table_data["收款人"]
|
||||
assert expected_result["城市"] in table_data["城市"]
|
||||
assert float(expected_result["总费用/元"]) == float(table_data["总费用/元"])
|
||||
assert expected_result["开票日期"] == table_data["开票日期"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
("invoice_path", "query", "expected_result"),
|
||||
[("../../data/invoices/invoice-1.pdf", "Invoicing date", "2023年02月03日")],
|
||||
[(Path("invoices/invoice-1.pdf"), "Invoicing date", "2023年02月03日")],
|
||||
)
|
||||
async def test_reply_question(invoice_path: str, query: dict, expected_result: str):
|
||||
invoice_path = os.path.abspath(os.path.join(os.getcwd(), invoice_path))
|
||||
filename = os.path.basename(invoice_path)
|
||||
ocr_result = await InvoiceOCR().run(file_path=Path(invoice_path), filename=filename)
|
||||
async def test_reply_question(invoice_path: Path, query: dict, expected_result: str):
|
||||
invoice_path = TEST_DATA_PATH / invoice_path
|
||||
ocr_result = await InvoiceOCR().run(file_path=Path(invoice_path))
|
||||
result = await ReplyQuestion().run(query=query, ocr_result=ocr_result)
|
||||
assert expected_result in result
|
||||
|
|
|
|||
124
tests/metagpt/actions/test_research.py
Normal file
124
tests/metagpt/actions/test_research.py
Normal file
|
|
@ -0,0 +1,124 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2023/12/28
|
||||
@Author : mashenquan
|
||||
@File : test_research.py
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.actions import CollectLinks, research
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_action():
|
||||
action = CollectLinks()
|
||||
result = await action.run(topic="baidu")
|
||||
assert result
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_collect_links(mocker):
|
||||
async def mock_llm_ask(self, prompt: str, system_msgs):
|
||||
if "Please provide up to 2 necessary keywords" in prompt:
|
||||
return '["metagpt", "llm"]'
|
||||
|
||||
elif "Provide up to 4 queries related to your research topic" in prompt:
|
||||
return (
|
||||
'["MetaGPT use cases", "The roadmap of MetaGPT", '
|
||||
'"The function of MetaGPT", "What llm MetaGPT support"]'
|
||||
)
|
||||
elif "sort the remaining search results" in prompt:
|
||||
return "[1,2]"
|
||||
|
||||
mocker.patch("metagpt.provider.base_llm.BaseLLM.aask", mock_llm_ask)
|
||||
resp = await research.CollectLinks().run("The application of MetaGPT")
|
||||
for i in ["MetaGPT use cases", "The roadmap of MetaGPT", "The function of MetaGPT", "What llm MetaGPT support"]:
|
||||
assert i in resp
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_collect_links_with_rank_func(mocker):
|
||||
rank_before = []
|
||||
rank_after = []
|
||||
url_per_query = 4
|
||||
|
||||
def rank_func(results):
|
||||
results = results[:url_per_query]
|
||||
rank_before.append(results)
|
||||
results = results[::-1]
|
||||
rank_after.append(results)
|
||||
return results
|
||||
|
||||
mocker.patch("metagpt.provider.base_llm.BaseLLM.aask", mock_collect_links_llm_ask)
|
||||
resp = await research.CollectLinks(rank_func=rank_func).run("The application of MetaGPT")
|
||||
for x, y, z in zip(rank_before, rank_after, resp.values()):
|
||||
assert x[::-1] == y
|
||||
assert [i["link"] for i in y] == z
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_web_browse_and_summarize(mocker):
|
||||
async def mock_llm_ask(*args, **kwargs):
|
||||
return "metagpt"
|
||||
|
||||
mocker.patch("metagpt.provider.base_llm.BaseLLM.aask", mock_llm_ask)
|
||||
url = "https://github.com/geekan/MetaGPT"
|
||||
url2 = "https://github.com/trending"
|
||||
query = "What's new in metagpt"
|
||||
resp = await research.WebBrowseAndSummarize().run(url, query=query)
|
||||
|
||||
assert len(resp) == 1
|
||||
assert url in resp
|
||||
assert resp[url] == "metagpt"
|
||||
|
||||
resp = await research.WebBrowseAndSummarize().run(url, url2, query=query)
|
||||
assert len(resp) == 2
|
||||
|
||||
async def mock_llm_ask(*args, **kwargs):
|
||||
return "Not relevant."
|
||||
|
||||
mocker.patch("metagpt.provider.base_llm.BaseLLM.aask", mock_llm_ask)
|
||||
resp = await research.WebBrowseAndSummarize().run(url, query=query)
|
||||
|
||||
assert len(resp) == 1
|
||||
assert url in resp
|
||||
assert resp[url] is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_conduct_research(mocker):
|
||||
data = None
|
||||
|
||||
async def mock_llm_ask(*args, **kwargs):
|
||||
nonlocal data
|
||||
data = f"# Research Report\n## Introduction\n{args} {kwargs}"
|
||||
return data
|
||||
|
||||
mocker.patch("metagpt.provider.base_llm.BaseLLM.aask", mock_llm_ask)
|
||||
content = (
|
||||
"MetaGPT takes a one line requirement as input and "
|
||||
"outputs user stories / competitive analysis / requirements / data structures / APIs / documents, etc."
|
||||
)
|
||||
|
||||
resp = await research.ConductResearch().run("The application of MetaGPT", content)
|
||||
assert resp == data
|
||||
|
||||
|
||||
async def mock_collect_links_llm_ask(self, prompt: str, system_msgs):
|
||||
if "Please provide up to 2 necessary keywords" in prompt:
|
||||
return '["metagpt", "llm"]'
|
||||
|
||||
elif "Provide up to 4 queries related to your research topic" in prompt:
|
||||
return (
|
||||
'["MetaGPT use cases", "The roadmap of MetaGPT", ' '"The function of MetaGPT", "What llm MetaGPT support"]'
|
||||
)
|
||||
elif "sort the remaining search results" in prompt:
|
||||
return "[1,2]"
|
||||
|
||||
return ""
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-s"])
|
||||
|
|
@ -14,13 +14,13 @@ from metagpt.schema import RunCodeContext
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_text():
|
||||
result, errs = await RunCode.run_text("result = 1 + 1")
|
||||
assert result == 2
|
||||
assert errs == ""
|
||||
out, err = await RunCode.run_text("result = 1 + 1")
|
||||
assert out == 2
|
||||
assert err == ""
|
||||
|
||||
result, errs = await RunCode.run_text("result = 1 / 0")
|
||||
assert result == ""
|
||||
assert "ZeroDivisionError" in errs
|
||||
out, err = await RunCode.run_text("result = 1 / 0")
|
||||
assert out == ""
|
||||
assert "division by zero" in err
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
|
|||
|
|
@ -58,7 +58,29 @@ class TestSkillAction:
|
|||
action = SkillAction(skill=self.skill, args=parser_action.args)
|
||||
rsp = await action.run()
|
||||
assert rsp
|
||||
assert "image/png;base64," in rsp.content
|
||||
assert "image/png;base64," in rsp.content or "http" in rsp.content
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("skill_name", "txt", "want"),
|
||||
[
|
||||
("skill1", 'skill1(a="1", b="2")', {"a": "1", "b": "2"}),
|
||||
("skill1", '(a="1", b="2")', None),
|
||||
("skill1", 'skill1(a="1", b="2"', None),
|
||||
],
|
||||
)
|
||||
def test_parse_arguments(self, skill_name, txt, want):
|
||||
args = ArgumentsParingAction.parse_arguments(skill_name, txt)
|
||||
assert args == want
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_find_and_call_function_error(self):
|
||||
with pytest.raises(ValueError):
|
||||
await SkillAction.find_and_call_function("dummy_call", {"a": 1})
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_skill_action_error(self):
|
||||
action = SkillAction(skill=self.skill, args={})
|
||||
await action.run()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
51
tests/metagpt/actions/test_talk_action.py
Normal file
51
tests/metagpt/actions/test_talk_action.py
Normal file
|
|
@ -0,0 +1,51 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2023/12/28
|
||||
@Author : mashenquan
|
||||
@File : test_talk_action.py
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.actions.talk_action import TalkAction
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.schema import Message
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
("agent_description", "language", "context", "knowledge", "history_summary"),
|
||||
[
|
||||
(
|
||||
"mathematician",
|
||||
"English",
|
||||
"How old is Susie?",
|
||||
"Susie is a girl born in 2011/11/14. Today is 2023/12/3",
|
||||
"balabala... (useless words)",
|
||||
),
|
||||
(
|
||||
"mathematician",
|
||||
"Chinese",
|
||||
"Does Susie have an apple?",
|
||||
"Susie is a girl born in 2011/11/14. Today is 2023/12/3",
|
||||
"Susie had an apple, and she ate it right now",
|
||||
),
|
||||
],
|
||||
)
|
||||
async def test_prompt(agent_description, language, context, knowledge, history_summary):
|
||||
# Prerequisites
|
||||
CONFIG.agent_description = agent_description
|
||||
CONFIG.language = language
|
||||
|
||||
action = TalkAction(context=context, knowledge=knowledge, history_summary=history_summary)
|
||||
assert "{" not in action.prompt
|
||||
assert "{" not in action.prompt_gpt4
|
||||
|
||||
rsp = await action.run()
|
||||
assert rsp
|
||||
assert isinstance(rsp, Message)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-s"])
|
||||
|
|
@ -6,12 +6,24 @@
|
|||
@File : test_write_code.py
|
||||
@Modifiled By: mashenquan, 2023-12-6. According to RFC 135
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.actions.write_code import WriteCode
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.const import (
|
||||
CODE_SUMMARIES_FILE_REPO,
|
||||
SYSTEM_DESIGN_FILE_REPO,
|
||||
TASK_FILE_REPO,
|
||||
TEST_OUTPUTS_FILE_REPO,
|
||||
)
|
||||
from metagpt.logs import logger
|
||||
from metagpt.provider.openai_api import OpenAILLM as LLM
|
||||
from metagpt.schema import CodingContext, Document
|
||||
from metagpt.utils.common import aread
|
||||
from metagpt.utils.file_repository import FileRepository
|
||||
from tests.metagpt.actions.mock_markdown import TASKS_2, WRITE_CODE_PROMPT_SAMPLE
|
||||
|
||||
|
||||
|
|
@ -37,3 +49,47 @@ async def test_write_code_directly():
|
|||
llm = LLM()
|
||||
rsp = await llm.aask(prompt)
|
||||
logger.info(rsp)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_write_code_deps():
|
||||
# Prerequisites
|
||||
CONFIG.src_workspace = CONFIG.git_repo.workdir / "snake1/snake1"
|
||||
demo_path = Path(__file__).parent / "../../data/demo_project"
|
||||
await FileRepository.save_file(
|
||||
filename="test_game.py.json",
|
||||
content=await aread(str(demo_path / "test_game.py.json")),
|
||||
relative_path=TEST_OUTPUTS_FILE_REPO,
|
||||
)
|
||||
await FileRepository.save_file(
|
||||
filename="20231221155954.json",
|
||||
content=await aread(str(demo_path / "code_summaries.json")),
|
||||
relative_path=CODE_SUMMARIES_FILE_REPO,
|
||||
)
|
||||
await FileRepository.save_file(
|
||||
filename="20231221155954.json",
|
||||
content=await aread(str(demo_path / "system_design.json")),
|
||||
relative_path=SYSTEM_DESIGN_FILE_REPO,
|
||||
)
|
||||
await FileRepository.save_file(
|
||||
filename="20231221155954.json", content=await aread(str(demo_path / "tasks.json")), relative_path=TASK_FILE_REPO
|
||||
)
|
||||
await FileRepository.save_file(
|
||||
filename="main.py", content='if __name__ == "__main__":\nmain()', relative_path=CONFIG.src_workspace
|
||||
)
|
||||
context = CodingContext(
|
||||
filename="game.py",
|
||||
design_doc=await FileRepository.get_file(filename="20231221155954.json", relative_path=SYSTEM_DESIGN_FILE_REPO),
|
||||
task_doc=await FileRepository.get_file(filename="20231221155954.json", relative_path=TASK_FILE_REPO),
|
||||
code_doc=Document(filename="game.py", content="", root_path="snake1"),
|
||||
)
|
||||
coding_doc = Document(root_path="snake1", filename="game.py", content=context.json())
|
||||
|
||||
action = WriteCode(context=coding_doc)
|
||||
rsp = await action.run()
|
||||
assert rsp
|
||||
assert rsp.code_doc.content
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-s"])
|
||||
|
|
|
|||
|
|
@ -30,3 +30,13 @@ class Person:
|
|||
async def test_write_docstring(style: str, part: str):
|
||||
ret = await WriteDocstring().run(code, style=style)
|
||||
assert part in ret
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_write():
|
||||
code = await WriteDocstring.write_docstring(__file__)
|
||||
assert code
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-s"])
|
||||
|
|
|
|||
|
|
@ -23,10 +23,14 @@ async def test_write_prd_review():
|
|||
Timeline: The feature should be ready for testing in 1.5 months.
|
||||
"""
|
||||
|
||||
write_prd_review = WritePRDReview("write_prd_review")
|
||||
write_prd_review = WritePRDReview(name="write_prd_review")
|
||||
|
||||
prd_review = await write_prd_review.run(prd)
|
||||
|
||||
# We cannot exactly predict the generated PRD review, but we can check if it is a string and if it is not empty
|
||||
assert isinstance(prd_review, str)
|
||||
assert len(prd_review) > 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-s"])
|
||||
|
|
|
|||
|
|
@ -6,53 +6,21 @@
|
|||
@File : test_write_teaching_plan.py
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from typing import Optional
|
||||
|
||||
from langchain.llms.base import LLM
|
||||
from pydantic import BaseModel
|
||||
import pytest
|
||||
|
||||
from metagpt.actions.write_teaching_plan import WriteTeachingPlanPart
|
||||
from metagpt.config import Config
|
||||
from metagpt.schema import Message
|
||||
|
||||
|
||||
class MockWriteTeachingPlanPart(WriteTeachingPlanPart):
|
||||
def __init__(self, options, name: str = "", context=None, llm: LLM = None, topic="", language="Chinese"):
|
||||
super().__init__(options, name, context, llm, topic, language)
|
||||
|
||||
async def _aask(self, prompt: str, system_msgs: Optional[list[str]] = None) -> str:
|
||||
return f"{WriteTeachingPlanPart.DATA_BEGIN_TAG}\nprompt\n{WriteTeachingPlanPart.DATA_END_TAG}"
|
||||
|
||||
|
||||
async def mock_write_teaching_plan_part():
|
||||
class Inputs(BaseModel):
|
||||
input: str
|
||||
name: str
|
||||
topic: str
|
||||
language: str
|
||||
|
||||
inputs = [
|
||||
{"input": "AABBCC", "name": "A", "topic": WriteTeachingPlanPart.COURSE_TITLE, "language": "C"},
|
||||
{"input": "DDEEFFF", "name": "A1", "topic": "B1", "language": "C1"},
|
||||
]
|
||||
|
||||
for i in inputs:
|
||||
seed = Inputs(**i)
|
||||
options = Config().runtime_options
|
||||
act = MockWriteTeachingPlanPart(options=options, name=seed.name, topic=seed.topic, language=seed.language)
|
||||
await act.run([Message(content="")])
|
||||
assert act.topic == seed.topic
|
||||
assert str(act) == seed.topic
|
||||
assert act.name == seed.name
|
||||
assert act.rsp == "# prompt" if seed.topic == WriteTeachingPlanPart.COURSE_TITLE else "prompt"
|
||||
|
||||
|
||||
def test_suite():
|
||||
loop = asyncio.get_event_loop()
|
||||
task = loop.create_task(mock_write_teaching_plan_part())
|
||||
loop.run_until_complete(task)
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
("topic", "context"),
|
||||
[("Title", "Lesson 1: Learn to draw an apple."), ("Teaching Content", "Lesson 1: Learn to draw an apple.")],
|
||||
)
|
||||
async def test_write_teaching_plan_part(topic, context):
|
||||
action = WriteTeachingPlanPart(topic=topic, context=context)
|
||||
rsp = await action.run()
|
||||
assert rsp
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_suite()
|
||||
pytest.main([__file__, "-s"])
|
||||
|
|
|
|||
|
|
@ -1,36 +0,0 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2023/6/11 21:08
|
||||
@Author : alexanderwu
|
||||
@File : test_milvus_store.py
|
||||
"""
|
||||
import random
|
||||
|
||||
import numpy as np
|
||||
|
||||
from metagpt.document_store.milvus_store import MilvusConnection, MilvusStore
|
||||
from metagpt.logs import logger
|
||||
|
||||
book_columns = {"idx": int, "name": str, "desc": str, "emb": np.ndarray, "price": float}
|
||||
book_data = [
|
||||
[i for i in range(10)],
|
||||
[f"book-{i}" for i in range(10)],
|
||||
[f"book-desc-{i}" for i in range(10000, 10010)],
|
||||
[[random.random() for _ in range(2)] for _ in range(10)],
|
||||
[random.random() for _ in range(10)],
|
||||
]
|
||||
|
||||
|
||||
def test_milvus_store():
|
||||
milvus_connection = MilvusConnection(alias="default", host="192.168.50.161", port="30530")
|
||||
milvus_store = MilvusStore(milvus_connection)
|
||||
milvus_store.drop("Book")
|
||||
milvus_store.create_collection("Book", book_columns)
|
||||
milvus_store.add(book_data)
|
||||
milvus_store.build_index("emb")
|
||||
milvus_store.load_collection()
|
||||
|
||||
results = milvus_store.search([[1.0, 1.0]], field="emb")
|
||||
logger.info(results)
|
||||
assert results
|
||||
|
|
@ -29,7 +29,7 @@ points = [
|
|||
]
|
||||
|
||||
|
||||
def test_milvus_store():
|
||||
def test_qdrant_store():
|
||||
qdrant_connection = QdrantConnection(memory=True)
|
||||
vectors_config = VectorParams(size=2, distance=Distance.COSINE)
|
||||
qdrant_store = QdrantStore(qdrant_connection)
|
||||
|
|
@ -43,13 +43,13 @@ def test_milvus_store():
|
|||
results = qdrant_store.search("Book", query=[1.0, 1.0])
|
||||
assert results[0]["id"] == 2
|
||||
assert results[0]["score"] == 0.999106722578389
|
||||
assert results[1]["score"] == 7
|
||||
assert results[1]["id"] == 7
|
||||
assert results[1]["score"] == 0.9961650411397226
|
||||
results = qdrant_store.search("Book", query=[1.0, 1.0], return_vector=True)
|
||||
assert results[0]["id"] == 2
|
||||
assert results[0]["score"] == 0.999106722578389
|
||||
assert results[0]["vector"] == [0.7363563179969788, 0.6765939593315125]
|
||||
assert results[1]["score"] == 7
|
||||
assert results[1]["id"] == 7
|
||||
assert results[1]["score"] == 0.9961650411397226
|
||||
assert results[1]["vector"] == [0.7662628889083862, 0.6425272226333618]
|
||||
results = qdrant_store.search(
|
||||
|
|
|
|||
|
|
@ -7,35 +7,26 @@
|
|||
@Desc : Unit tests.
|
||||
"""
|
||||
|
||||
import base64
|
||||
|
||||
import pytest
|
||||
from pydantic import BaseModel
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.learn.text_to_image import text_to_image
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test():
|
||||
class Input(BaseModel):
|
||||
input: str
|
||||
size_type: str
|
||||
# Prerequisites
|
||||
assert CONFIG.METAGPT_TEXT_TO_IMAGE_MODEL_URL
|
||||
assert CONFIG.OPENAI_API_KEY
|
||||
|
||||
inputs = [{"input": "Panda emoji", "size_type": "512x512"}]
|
||||
|
||||
for i in inputs:
|
||||
seed = Input(**i)
|
||||
base64_data = await text_to_image(seed.input)
|
||||
assert base64_data != ""
|
||||
print(f"{seed.input} -> {base64_data}")
|
||||
flags = ";base64,"
|
||||
assert flags in base64_data
|
||||
ix = base64_data.find(flags) + len(flags)
|
||||
declaration = base64_data[0:ix]
|
||||
assert declaration
|
||||
data = base64_data[ix:]
|
||||
assert data
|
||||
assert base64.b64decode(data, validate=True)
|
||||
data = await text_to_image("Panda emoji", size_type="512x512")
|
||||
assert "base64" in data or "http" in data
|
||||
key = CONFIG.METAGPT_TEXT_TO_IMAGE_MODEL_URL
|
||||
CONFIG.METAGPT_TEXT_TO_IMAGE_MODEL_URL = None
|
||||
data = await text_to_image("Panda emoji", size_type="512x512")
|
||||
assert "base64" in data or "http" in data
|
||||
CONFIG.METAGPT_TEXT_TO_IMAGE_MODEL_URL = key
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
|
|
@ -6,40 +6,33 @@
|
|||
@File : test_text_to_speech.py
|
||||
@Desc : Unit tests.
|
||||
"""
|
||||
import asyncio
|
||||
import base64
|
||||
|
||||
from pydantic import BaseModel
|
||||
import pytest
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.learn.text_to_speech import text_to_speech
|
||||
|
||||
|
||||
async def mock_text_to_speech():
|
||||
class Input(BaseModel):
|
||||
input: str
|
||||
@pytest.mark.asyncio
|
||||
async def test_text_to_speech():
|
||||
# Prerequisites
|
||||
assert CONFIG.IFLYTEK_APP_ID
|
||||
assert CONFIG.IFLYTEK_API_KEY
|
||||
assert CONFIG.IFLYTEK_API_SECRET
|
||||
assert CONFIG.AZURE_TTS_SUBSCRIPTION_KEY and CONFIG.AZURE_TTS_SUBSCRIPTION_KEY != "YOUR_API_KEY"
|
||||
assert CONFIG.AZURE_TTS_REGION
|
||||
|
||||
inputs = [{"input": "Panda emoji"}]
|
||||
# test azure
|
||||
data = await text_to_speech("panda emoji")
|
||||
assert "base64" in data or "http" in data
|
||||
|
||||
for i in inputs:
|
||||
seed = Input(**i)
|
||||
base64_data = await text_to_speech(seed.input)
|
||||
assert base64_data != ""
|
||||
print(f"{seed.input} -> {base64_data}")
|
||||
flags = ";base64,"
|
||||
assert flags in base64_data
|
||||
ix = base64_data.find(flags) + len(flags)
|
||||
declaration = base64_data[0:ix]
|
||||
assert declaration
|
||||
data = base64_data[ix:]
|
||||
assert data
|
||||
assert base64.b64decode(data, validate=True)
|
||||
|
||||
|
||||
def test_suite():
|
||||
loop = asyncio.get_event_loop()
|
||||
task = loop.create_task(mock_text_to_speech())
|
||||
loop.run_until_complete(task)
|
||||
# test iflytek
|
||||
key = CONFIG.AZURE_TTS_SUBSCRIPTION_KEY
|
||||
CONFIG.AZURE_TTS_SUBSCRIPTION_KEY = ""
|
||||
data = await text_to_speech("panda emoji")
|
||||
assert "base64" in data or "http" in data
|
||||
CONFIG.AZURE_TTS_SUBSCRIPTION_KEY = key
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_suite()
|
||||
pytest.main([__file__, "-s"])
|
||||
|
|
|
|||
|
|
@ -14,9 +14,9 @@ def test_skill_manager():
|
|||
manager = SkillManager()
|
||||
logger.info(manager._store)
|
||||
|
||||
write_prd = WritePRD("WritePRD")
|
||||
write_prd = WritePRD(name="WritePRD")
|
||||
write_prd.desc = "基于老板或其他人的需求进行PRD的撰写,包括用户故事、需求分解等"
|
||||
write_test = WriteTest("WriteTest")
|
||||
write_test = WriteTest(name="WriteTest")
|
||||
write_test.desc = "进行测试用例的撰写"
|
||||
manager.add_skill(write_prd)
|
||||
manager.add_skill(write_test)
|
||||
|
|
@ -24,7 +24,7 @@ def test_skill_manager():
|
|||
skill = manager.get_skill("WriteTest")
|
||||
logger.info(skill)
|
||||
|
||||
rsp = manager.retrieve_skill("写PRD")
|
||||
rsp = manager.retrieve_skill("WritePRD")
|
||||
logger.info(rsp)
|
||||
assert rsp[0] == "WritePRD"
|
||||
|
||||
|
|
|
|||
|
|
@ -5,47 +5,64 @@
|
|||
@Author : mashenquan
|
||||
@File : test_brain_memory.py
|
||||
"""
|
||||
# import json
|
||||
# from typing import List
|
||||
#
|
||||
# import pydantic
|
||||
#
|
||||
# from metagpt.memory.brain_memory import BrainMemory
|
||||
# from metagpt.schema import Message
|
||||
#
|
||||
#
|
||||
# def test_json():
|
||||
# class Input(pydantic.BaseModel):
|
||||
# history: List[str]
|
||||
# solution: List[str]
|
||||
# knowledge: List[str]
|
||||
# stack: List[str]
|
||||
#
|
||||
# inputs = [{"history": ["a", "b"], "solution": ["c"], "knowledge": ["d", "e"], "stack": ["f"]}]
|
||||
#
|
||||
# for i in inputs:
|
||||
# v = Input(**i)
|
||||
# bm = BrainMemory()
|
||||
# for h in v.history:
|
||||
# msg = Message(content=h)
|
||||
# bm.history.append(msg.model_dump())
|
||||
# for h in v.solution:
|
||||
# msg = Message(content=h)
|
||||
# bm.solution.append(msg.model_dump())
|
||||
# for h in v.knowledge:
|
||||
# msg = Message(content=h)
|
||||
# bm.knowledge.append(msg.model_dump())
|
||||
# for h in v.stack:
|
||||
# msg = Message(content=h)
|
||||
# bm.stack.append(msg.model_dump())
|
||||
# s = bm.json()
|
||||
# m = json.loads(s)
|
||||
# bm = BrainMemory(**m)
|
||||
# assert bm
|
||||
# for v in bm.history:
|
||||
# msg = Message(**v)
|
||||
# assert msg
|
||||
#
|
||||
#
|
||||
# if __name__ == "__main__":
|
||||
# test_json()
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.config import LLMProviderEnum
|
||||
from metagpt.llm import LLM
|
||||
from metagpt.memory.brain_memory import BrainMemory
|
||||
from metagpt.schema import Message
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_memory():
|
||||
memory = BrainMemory()
|
||||
memory.add_talk(Message(content="talk"))
|
||||
assert memory.history[0].role == "user"
|
||||
memory.add_answer(Message(content="answer"))
|
||||
assert memory.history[1].role == "assistant"
|
||||
redis_key = BrainMemory.to_redis_key("none", "user_id", "chat_id")
|
||||
await memory.dumps(redis_key=redis_key)
|
||||
assert memory.exists("talk")
|
||||
assert 1 == memory.to_int("1", 0)
|
||||
memory.last_talk = "AAA"
|
||||
assert memory.pop_last_talk() == "AAA"
|
||||
assert memory.last_talk is None
|
||||
assert memory.is_history_available
|
||||
assert memory.history_text
|
||||
|
||||
memory = await BrainMemory.loads(redis_key=redis_key)
|
||||
assert memory
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("input", "tag", "val"),
|
||||
[("[TALK]:Hello", "TALK", "Hello"), ("Hello", None, "Hello"), ("[TALK]Hello", None, "[TALK]Hello")],
|
||||
)
|
||||
def test_extract_info(input, tag, val):
|
||||
t, v = BrainMemory.extract_info(input)
|
||||
assert tag == t
|
||||
assert val == v
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("llm", [LLM(provider=LLMProviderEnum.OPENAI), LLM(provider=LLMProviderEnum.METAGPT)])
|
||||
async def test_memory_llm(llm):
|
||||
memory = BrainMemory()
|
||||
for i in range(500):
|
||||
memory.add_talk(Message(content="Lily is a girl.\n"))
|
||||
|
||||
res = await memory.is_related("apple", "moon", llm)
|
||||
assert not res
|
||||
|
||||
res = await memory.rewrite(sentence="apple Lily eating", context="", llm=llm)
|
||||
assert "Lily" in res
|
||||
|
||||
res = await memory.get_title(llm=llm)
|
||||
assert res
|
||||
assert "Lily" in res
|
||||
assert memory.history or memory.historical_summary
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-s"])
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@
|
|||
"""
|
||||
@Time : 2023/5/7 17:40
|
||||
@Author : alexanderwu
|
||||
@File : test_base_gpt_api.py
|
||||
@File : test_base_llm.py
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
|
@ -27,7 +27,7 @@ prompt_msg = "who are you"
|
|||
resp_content = default_chat_resp["choices"][0]["message"]["content"]
|
||||
|
||||
|
||||
class MockBaseGPTAPI(BaseLLM):
|
||||
class MockBaseLLM(BaseLLM):
|
||||
def completion(self, messages: list[dict], timeout=3):
|
||||
return default_chat_resp
|
||||
|
||||
|
|
@ -41,12 +41,12 @@ class MockBaseGPTAPI(BaseLLM):
|
|||
return default_chat_resp
|
||||
|
||||
|
||||
def test_base_gpt_api():
|
||||
def test_base_llm():
|
||||
message = Message(role="user", content="hello")
|
||||
assert "role" in message.to_dict()
|
||||
assert "user" in str(message)
|
||||
|
||||
base_gpt_api = MockBaseGPTAPI()
|
||||
base_llm = MockBaseLLM()
|
||||
|
||||
openai_funccall_resp = {
|
||||
"choices": [
|
||||
|
|
@ -70,37 +70,37 @@ def test_base_gpt_api():
|
|||
}
|
||||
]
|
||||
}
|
||||
func: dict = base_gpt_api.get_choice_function(openai_funccall_resp)
|
||||
func: dict = base_llm.get_choice_function(openai_funccall_resp)
|
||||
assert func == {
|
||||
"name": "execute",
|
||||
"arguments": '{\n "language": "python",\n "code": "print(\'Hello, World!\')"\n}',
|
||||
}
|
||||
|
||||
func_args: dict = base_gpt_api.get_choice_function_arguments(openai_funccall_resp)
|
||||
func_args: dict = base_llm.get_choice_function_arguments(openai_funccall_resp)
|
||||
assert func_args == {"language": "python", "code": "print('Hello, World!')"}
|
||||
|
||||
choice_text = base_gpt_api.get_choice_text(openai_funccall_resp)
|
||||
choice_text = base_llm.get_choice_text(openai_funccall_resp)
|
||||
assert choice_text == openai_funccall_resp["choices"][0]["message"]["content"]
|
||||
|
||||
# resp = base_gpt_api.ask(prompt_msg)
|
||||
# resp = base_llm.ask(prompt_msg)
|
||||
# assert resp == resp_content
|
||||
|
||||
# resp = base_gpt_api.ask_batch([prompt_msg])
|
||||
# resp = base_llm.ask_batch([prompt_msg])
|
||||
# assert resp == resp_content
|
||||
|
||||
# resp = base_gpt_api.ask_code([prompt_msg])
|
||||
# resp = base_llm.ask_code([prompt_msg])
|
||||
# assert resp == resp_content
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_base_gpt_api():
|
||||
base_gpt_api = MockBaseGPTAPI()
|
||||
async def test_async_base_llm():
|
||||
base_llm = MockBaseLLM()
|
||||
|
||||
resp = await base_gpt_api.aask(prompt_msg)
|
||||
resp = await base_llm.aask(prompt_msg)
|
||||
assert resp == resp_content
|
||||
|
||||
resp = await base_gpt_api.aask_batch([prompt_msg])
|
||||
resp = await base_llm.aask_batch([prompt_msg])
|
||||
assert resp == resp_content
|
||||
|
||||
resp = await base_gpt_api.aask_code([prompt_msg])
|
||||
resp = await base_llm.aask_code([prompt_msg])
|
||||
assert resp == resp_content
|
||||
|
|
|
|||
14
tests/metagpt/provider/test_metagpt_api.py
Normal file
14
tests/metagpt/provider/test_metagpt_api.py
Normal file
|
|
@ -0,0 +1,14 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2023/12/28
|
||||
@Author : mashenquan
|
||||
@File : test_metagpt_api.py
|
||||
"""
|
||||
from metagpt.config import LLMProviderEnum
|
||||
from metagpt.llm import LLM
|
||||
|
||||
|
||||
def test_llm():
|
||||
llm = LLM(provider=LLMProviderEnum.METAGPT)
|
||||
assert llm
|
||||
|
|
@ -16,6 +16,7 @@ from tests.metagpt.roles.mock import MockMessages
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_architect():
|
||||
# FIXME: make git as env? Or should we support
|
||||
role = Architect()
|
||||
role.put_message(MockMessages.req)
|
||||
rsp = await role.run(MockMessages.prd)
|
||||
|
|
|
|||
|
|
@ -36,7 +36,7 @@ async def test_run():
|
|||
{
|
||||
"content": "who is tulin",
|
||||
"role": "user",
|
||||
"id": 1,
|
||||
"id": "1",
|
||||
},
|
||||
{"content": "The one who eaten a poison apple.", "role": "assistant"},
|
||||
],
|
||||
|
|
@ -53,7 +53,7 @@ async def test_run():
|
|||
{
|
||||
"content": "can you draw me an picture?",
|
||||
"role": "user",
|
||||
"id": 1,
|
||||
"id": "1",
|
||||
},
|
||||
{"content": "Yes, of course. What do you want me to draw", "role": "assistant"},
|
||||
],
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@ from pathlib import Path
|
|||
import pandas as pd
|
||||
import pytest
|
||||
|
||||
from metagpt.const import DATA_PATH, TEST_DATA_PATH
|
||||
from metagpt.roles.invoice_ocr_assistant import InvoiceOCRAssistant, InvoicePath
|
||||
from metagpt.schema import Message
|
||||
|
||||
|
|
@ -22,29 +23,29 @@ from metagpt.schema import Message
|
|||
[
|
||||
(
|
||||
"Invoicing date",
|
||||
Path("../../data/invoices/invoice-1.pdf"),
|
||||
Path("../../../data/invoice_table/invoice-1.xlsx"),
|
||||
Path("invoices/invoice-1.pdf"),
|
||||
Path("invoice_table/invoice-1.xlsx"),
|
||||
{"收款人": "小明", "城市": "深圳", "总费用/元": 412.00, "开票日期": "2023年02月03日"},
|
||||
),
|
||||
(
|
||||
"Invoicing date",
|
||||
Path("../../data/invoices/invoice-2.png"),
|
||||
Path("../../../data/invoice_table/invoice-2.xlsx"),
|
||||
Path("invoices/invoice-2.png"),
|
||||
Path("invoice_table/invoice-2.xlsx"),
|
||||
{"收款人": "铁头", "城市": "广州", "总费用/元": 898.00, "开票日期": "2023年03月17日"},
|
||||
),
|
||||
(
|
||||
"Invoicing date",
|
||||
Path("../../data/invoices/invoice-3.jpg"),
|
||||
Path("../../../data/invoice_table/invoice-3.xlsx"),
|
||||
Path("invoices/invoice-3.jpg"),
|
||||
Path("invoice_table/invoice-3.xlsx"),
|
||||
{"收款人": "夏天", "城市": "福州", "总费用/元": 2462.00, "开票日期": "2023年08月26日"},
|
||||
),
|
||||
],
|
||||
)
|
||||
async def test_invoice_ocr_assistant(query: str, invoice_path: Path, invoice_table_path: Path, expected_result: dict):
|
||||
invoice_path = Path.cwd() / invoice_path
|
||||
invoice_path = TEST_DATA_PATH / invoice_path
|
||||
role = InvoiceOCRAssistant()
|
||||
await role.run(Message(content=query, instruct_content=InvoicePath(file_path=invoice_path)))
|
||||
invoice_table_path = Path.cwd() / invoice_table_path
|
||||
invoice_table_path = DATA_PATH / invoice_table_path
|
||||
df = pd.read_excel(invoice_table_path)
|
||||
resp = df.to_dict(orient="records")
|
||||
assert isinstance(resp, list)
|
||||
|
|
@ -52,5 +53,5 @@ async def test_invoice_ocr_assistant(query: str, invoice_path: Path, invoice_tab
|
|||
resp = resp[0]
|
||||
assert expected_result["收款人"] == resp["收款人"]
|
||||
assert expected_result["城市"] in resp["城市"]
|
||||
assert int(expected_result["总费用/元"]) == int(resp["总费用/元"])
|
||||
assert float(expected_result["总费用/元"]) == float(resp["总费用/元"])
|
||||
assert expected_result["开票日期"] == resp["开票日期"]
|
||||
|
|
|
|||
|
|
@ -15,7 +15,7 @@ from tests.metagpt.roles.mock import MockMessages
|
|||
@pytest.mark.asyncio
|
||||
async def test_product_manager():
|
||||
product_manager = ProductManager()
|
||||
rsp = await product_manager.handle(MockMessages.req)
|
||||
rsp = await product_manager.run(MockMessages.req)
|
||||
logger.info(rsp)
|
||||
assert len(rsp.content) > 0
|
||||
assert "Product Goals" in rsp.content
|
||||
|
|
|
|||
|
|
@ -15,5 +15,5 @@ from tests.metagpt.roles.mock import MockMessages
|
|||
@pytest.mark.asyncio
|
||||
async def test_project_manager():
|
||||
project_manager = ProjectManager()
|
||||
rsp = await project_manager.handle(MockMessages.system_design)
|
||||
rsp = await project_manager.run(MockMessages.system_design)
|
||||
logger.info(rsp)
|
||||
|
|
|
|||
|
|
@ -28,7 +28,23 @@ async def mock_llm_ask(self, prompt: str, system_msgs):
|
|||
async def test_researcher(mocker):
|
||||
with TemporaryDirectory() as dirname:
|
||||
topic = "dataiku vs. datarobot"
|
||||
mocker.patch("metagpt.provider.base_gpt_api.BaseGPTAPI.aask", mock_llm_ask)
|
||||
mocker.patch("metagpt.provider.base_llm.BaseLLM.aask", mock_llm_ask)
|
||||
researcher.RESEARCH_PATH = Path(dirname)
|
||||
await researcher.Researcher().run(topic)
|
||||
assert (researcher.RESEARCH_PATH / f"{topic}.md").read_text().startswith("# Research Report")
|
||||
|
||||
|
||||
def test_write_report(mocker):
|
||||
with TemporaryDirectory() as dirname:
|
||||
for i, topic in enumerate(
|
||||
[
|
||||
("1./metagpt"),
|
||||
('2.:"metagpt'),
|
||||
("3.*?<>|metagpt"),
|
||||
("4. metagpt\n"),
|
||||
]
|
||||
):
|
||||
researcher.RESEARCH_PATH = Path(dirname)
|
||||
content = "# Research Report"
|
||||
researcher.Researcher().write_report(topic, content)
|
||||
assert (researcher.RESEARCH_PATH / f"{i+1}. metagpt.md").read_text().startswith("# Research Report")
|
||||
|
|
|
|||
|
|
@ -5,7 +5,6 @@
|
|||
@Author : Stitch-z
|
||||
@File : test_tutorial_assistant.py
|
||||
"""
|
||||
import shutil
|
||||
|
||||
import aiofiles
|
||||
import pytest
|
||||
|
|
@ -17,8 +16,6 @@ from metagpt.roles.tutorial_assistant import TutorialAssistant
|
|||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(("language", "topic"), [("Chinese", "Write a tutorial about pip")])
|
||||
async def test_tutorial_assistant(language: str, topic: str):
|
||||
shutil.rmtree(path=TUTORIAL_PATH, ignore_errors=True)
|
||||
|
||||
role = TutorialAssistant(language=language)
|
||||
msg = await role.run(topic)
|
||||
assert TUTORIAL_PATH.exists()
|
||||
|
|
|
|||
|
|
@ -9,23 +9,25 @@ import pytest
|
|||
from typer.testing import CliRunner
|
||||
|
||||
from metagpt.logs import logger
|
||||
from metagpt.startup import app
|
||||
from metagpt.team import Team
|
||||
|
||||
runner = CliRunner()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_team():
|
||||
async def test_empty_team():
|
||||
# FIXME: we're now using "metagpt" cli, so the entrance should be replaced instead.
|
||||
company = Team()
|
||||
company.run_project("做一个基础搜索引擎,可以支持知识库")
|
||||
history = await company.run(n_round=5)
|
||||
history = await company.run(idea="Build a simple search system. I will upload my files later.")
|
||||
logger.info(history)
|
||||
|
||||
|
||||
# def test_startup():
|
||||
# args = ["Make a 2048 game"]
|
||||
# result = runner.invoke(app, args)
|
||||
def test_startup():
|
||||
args = ["Make a cli snake game"]
|
||||
result = runner.invoke(app, args)
|
||||
logger.info(result)
|
||||
logger.info(result.output)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
|
|
@ -10,4 +10,4 @@ def test_team():
|
|||
company = Team()
|
||||
company.hire([ProjectManager()])
|
||||
|
||||
assert len(company.environment.roles) == 1
|
||||
assert len(company.env.roles) == 1
|
||||
|
|
|
|||
|
|
@ -7,6 +7,8 @@
|
|||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Callable
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
|
|
@ -25,7 +27,7 @@ class MockSearchEnine:
|
|||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
("search_engine_typpe", "run_func", "max_results", "as_string"),
|
||||
("search_engine_type", "run_func", "max_results", "as_string"),
|
||||
[
|
||||
(SearchEngineType.SERPAPI_GOOGLE, None, 8, True),
|
||||
(SearchEngineType.SERPAPI_GOOGLE, None, 4, False),
|
||||
|
|
@ -39,23 +41,18 @@ class MockSearchEnine:
|
|||
(SearchEngineType.CUSTOM_ENGINE, MockSearchEnine().run, 6, False),
|
||||
],
|
||||
)
|
||||
async def test_search_engine(
|
||||
search_engine_typpe,
|
||||
run_func,
|
||||
max_results,
|
||||
as_string,
|
||||
):
|
||||
async def test_search_engine(search_engine_type, run_func: Callable, max_results: int, as_string: bool):
|
||||
# Prerequisites
|
||||
if search_engine_typpe is SearchEngineType.SERPAPI_GOOGLE:
|
||||
if search_engine_type is SearchEngineType.SERPAPI_GOOGLE:
|
||||
assert CONFIG.SERPAPI_API_KEY and CONFIG.SERPAPI_API_KEY != "YOUR_API_KEY"
|
||||
elif search_engine_typpe is SearchEngineType.DIRECT_GOOGLE:
|
||||
elif search_engine_type is SearchEngineType.DIRECT_GOOGLE:
|
||||
assert CONFIG.GOOGLE_API_KEY and CONFIG.GOOGLE_API_KEY != "YOUR_API_KEY"
|
||||
assert CONFIG.GOOGLE_CSE_ID and CONFIG.GOOGLE_CSE_ID != "YOUR_CSE_ID"
|
||||
elif search_engine_typpe is SearchEngineType.SERPER_GOOGLE:
|
||||
elif search_engine_type is SearchEngineType.SERPER_GOOGLE:
|
||||
assert CONFIG.SERPER_API_KEY and CONFIG.SERPER_API_KEY != "YOUR_API_KEY"
|
||||
|
||||
search_engine = SearchEngine(search_engine_typpe, run_func)
|
||||
rsp = await search_engine.run("metagpt", max_results=max_results, as_string=as_string)
|
||||
search_engine = SearchEngine(search_engine_type, run_func)
|
||||
rsp = await search_engine.run("metagpt", max_results, as_string)
|
||||
logger.info(rsp)
|
||||
if as_string:
|
||||
assert isinstance(rsp, str)
|
||||
|
|
|
|||
|
|
@ -111,27 +111,27 @@ class TestCodeParser:
|
|||
def test_parse_blocks(self, parser, text):
|
||||
result = parser.parse_blocks(text)
|
||||
print(result)
|
||||
assert result == {"title": "content", "title2": "content2"}
|
||||
assert "game.py" in result["Task list"]
|
||||
|
||||
def test_parse_block(self, parser, text):
|
||||
result = parser.parse_block("title", text)
|
||||
result = parser.parse_block("Task list", text)
|
||||
print(result)
|
||||
assert result == "content"
|
||||
assert "game.py" in result
|
||||
|
||||
def test_parse_code(self, parser, text):
|
||||
result = parser.parse_code("title", text, "python")
|
||||
result = parser.parse_code("Task list", text, "python")
|
||||
print(result)
|
||||
assert result == "print('hello world')"
|
||||
assert "game.py" in result
|
||||
|
||||
def test_parse_str(self, parser, text):
|
||||
result = parser.parse_str("title", text, "python")
|
||||
result = parser.parse_str("Anything UNCLEAR", text, "python")
|
||||
print(result)
|
||||
assert result == "hello world"
|
||||
assert "We need clarification on how the high score " in result
|
||||
|
||||
def test_parse_file_list(self, parser, text):
|
||||
result = parser.parse_file_list("Task list", text)
|
||||
print(result)
|
||||
assert result == ["task1", "task2"]
|
||||
assert "game.py" in result
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
|
|
@ -47,7 +47,8 @@ class TestGetProjectRoot:
|
|||
|
||||
def test_get_project_root(self):
|
||||
project_root = get_metagpt_root()
|
||||
assert project_root.name == "MetaGPT"
|
||||
src_path = project_root / "metagpt"
|
||||
assert src_path.exists()
|
||||
|
||||
def test_get_root_exception(self):
|
||||
self.change_etc_dir()
|
||||
|
|
|
|||
|
|
@ -21,10 +21,11 @@ def test_config_class_get_key_exception():
|
|||
|
||||
|
||||
def test_config_yaml_file_not_exists():
|
||||
config = Config("wtf.yaml")
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
config.get("OPENAI_BASE_URL")
|
||||
assert str(exc_info.value) == "Set OPENAI_API_KEY or Anthropic_API_KEY first"
|
||||
# FIXME: 由于这里是单例,所以会导致Config重新创建失效。后续要将Config改为非单例模式。
|
||||
_ = Config("wtf.yaml")
|
||||
# with pytest.raises(Exception) as exc_info:
|
||||
# config.get("OPENAI_BASE_URL")
|
||||
# assert str(exc_info.value) == "Set OPENAI_API_KEY or Anthropic_API_KEY first"
|
||||
|
||||
|
||||
def test_options():
|
||||
|
|
|
|||
|
|
@ -10,29 +10,31 @@ import pytest
|
|||
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.utils.common import check_cmd_exists
|
||||
from metagpt.utils.mermaid import MMC1, MMC2, mermaid_to_file
|
||||
from metagpt.utils.mermaid import MMC1, mermaid_to_file
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("engine", ["nodejs", "playwright", "pyppeteer", "ink"])
|
||||
@pytest.mark.parametrize("engine", ["nodejs", "ink"]) # TODO: playwright and pyppeteer
|
||||
async def test_mermaid(engine):
|
||||
# Prerequisites
|
||||
# npm install -g @mermaid-js/mermaid-cli
|
||||
# nodejs prerequisites: npm install -g @mermaid-js/mermaid-cli
|
||||
# ink prerequisites: connected to internet
|
||||
# playwright prerequisites: playwright install --with-deps chromium
|
||||
assert check_cmd_exists("npm") == 0
|
||||
assert CONFIG.PYPPETEER_EXECUTABLE_PATH
|
||||
|
||||
CONFIG.mermaid_engine = engine
|
||||
save_to = CONFIG.git_repo.workdir / f"{CONFIG.mermaid_engine}/1"
|
||||
await mermaid_to_file(MMC1, save_to)
|
||||
for ext in [".pdf", ".svg", ".png"]:
|
||||
assert save_to.with_suffix(ext).exists()
|
||||
save_to.with_suffix(ext).unlink(missing_ok=True)
|
||||
|
||||
save_to = CONFIG.git_repo.workdir / f"{CONFIG.mermaid_engine}/2"
|
||||
await mermaid_to_file(MMC2, save_to)
|
||||
for ext in [".pdf", ".svg", ".png"]:
|
||||
assert save_to.with_suffix(ext).exists()
|
||||
save_to.with_suffix(ext).unlink(missing_ok=True)
|
||||
# ink does not support pdf
|
||||
if engine == "ink":
|
||||
for ext in [".svg", ".png"]:
|
||||
assert save_to.with_suffix(ext).exists()
|
||||
save_to.with_suffix(ext).unlink(missing_ok=True)
|
||||
else:
|
||||
for ext in [".pdf", ".svg", ".png"]:
|
||||
assert save_to.with_suffix(ext).exists()
|
||||
save_to.with_suffix(ext).unlink(missing_ok=True)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
|
|
@ -54,13 +54,13 @@ def test_parse_file_list():
|
|||
expected_result = ["file1", "file2", "file3"]
|
||||
assert OutputParser.parse_file_list(test_text) == expected_result
|
||||
|
||||
with pytest.raises(Exception):
|
||||
OutputParser.parse_file_list("wrong_input")
|
||||
# with pytest.raises(Exception):
|
||||
# OutputParser.parse_file_list("wrong_input")
|
||||
|
||||
|
||||
def test_parse_data():
|
||||
test_data = "##block1\n```python\nprint('Hello, world!')\n```\n##block2\nfiles=['file1', 'file2', 'file3']"
|
||||
expected_result = {"block1": "print('Hello, world!')", "block2": ["file1", "file2", "file3"]}
|
||||
expected_result = {"block1": "print('Hello, world!')\n", "block2": ["file1", "file2", "file3"]}
|
||||
assert OutputParser.parse_data(test_data) == expected_result
|
||||
|
||||
|
||||
|
|
@ -94,7 +94,7 @@ def test_parse_data():
|
|||
(
|
||||
"""xxx xx""",
|
||||
list,
|
||||
None,
|
||||
[],
|
||||
[],
|
||||
),
|
||||
(
|
||||
|
|
|
|||
|
|
@ -45,9 +45,11 @@ async def test_s3():
|
|||
@pytest.mark.asyncio
|
||||
async def test_s3_no_error():
|
||||
conn = S3()
|
||||
key = conn.auth_config["aws_secret_access_key"]
|
||||
conn.auth_config["aws_secret_access_key"] = ""
|
||||
res = await conn.cache("ABC", ".bak", "script")
|
||||
assert not res
|
||||
conn.auth_config["aws_secret_access_key"] = key
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue