feat: merge

This commit is contained in:
莘权 马 2023-12-28 18:05:33 +08:00
commit f76078dedf
95 changed files with 1629 additions and 948 deletions

View file

@ -89,6 +89,7 @@ def loguru_caplog(caplog):
@pytest.fixture(scope="session", autouse=True)
def setup_and_teardown_git_repo(request):
CONFIG.git_repo = GitRepository(local_path=DEFAULT_WORKSPACE_ROOT / "unittest")
CONFIG.git_reinit = True
# Destroy git repo at the end of the test session.
def fin():

View file

@ -12,6 +12,7 @@ import pytest
from metagpt.actions import Action
from metagpt.actions.action_node import ActionNode
from metagpt.environment import Environment
from metagpt.llm import LLM
from metagpt.roles import Role
from metagpt.schema import Message
from metagpt.team import Team
@ -76,18 +77,24 @@ async def test_action_node_one_layer():
assert "key-a" in markdown_template
assert node_dict["key-a"] == "instruction-b"
assert "key-a" in repr(node)
@pytest.mark.asyncio
async def test_action_node_two_layer():
node_a = ActionNode(key="key-a", expected_type=str, instruction="i-a", example="e-a")
node_b = ActionNode(key="key-b", expected_type=str, instruction="i-b", example="e-b")
node_a = ActionNode(key="reasoning", expected_type=str, instruction="reasoning step by step", example="")
node_b = ActionNode(key="answer", expected_type=str, instruction="the final answer", example="")
root = ActionNode.from_children(key="", nodes=[node_a, node_b])
assert "key-a" in root.children
root = ActionNode.from_children(key="detail answer", nodes=[node_a, node_b])
assert "reasoning" in root.children
assert node_b in root.children.values()
json_template = root.compile(context="123", schema="json", mode="auto")
assert "i-a" in json_template
# FIXME: ADD MARKDOWN SUPPORT. NEED TO TUNE MARKDOWN SYMBOL FIRST.
answer1 = await root.fill(context="what's the answer to 123+456?", schema="json", strgy="simple", llm=LLM())
assert "579" in answer1.content
answer2 = await root.fill(context="what's the answer to 123+456?", schema="json", strgy="complex", llm=LLM())
assert "579" in answer2.content
t_dict = {
@ -116,16 +123,33 @@ WRITE_TASKS_OUTPUT_MAPPING = {
"Anything UNCLEAR": (str, ...),
}
WRITE_TASKS_OUTPUT_MAPPING_MISSING = {
"Required Python third-party packages": (str, ...),
}
def test_create_model_class():
test_class = ActionNode.create_model_class("test_class", WRITE_TASKS_OUTPUT_MAPPING)
assert test_class.__name__ == "test_class"
output = test_class(**t_dict)
print(output.schema())
assert output.schema()["title"] == "test_class"
assert output.schema()["type"] == "object"
assert output.schema()["properties"]["Full API spec"]
def test_create_model_class_missing():
test_class = ActionNode.create_model_class("test_class", WRITE_TASKS_OUTPUT_MAPPING_MISSING)
assert test_class.__name__ == "test_class"
_ = test_class(**t_dict) # 这里应该要挂掉
def test_create_model_class_with_mapping():
t = ActionNode.create_model_class("test_class_1", WRITE_TASKS_OUTPUT_MAPPING)
t1 = t(**t_dict)
value = t1.dict()["Task list"]
value = t1.model_dump()["Task list"]
assert value == ["game.py", "app.py", "static/css/styles.css", "static/js/script.js", "templates/index.html"]

View file

@ -1,101 +0,0 @@
import os
import tempfile
import pytest
from metagpt.actions.clone_function import (
CloneFunction,
run_function_code,
run_function_script,
)
source_code = """
import pandas as pd
import ta
def user_indicator():
# 读取股票数据
stock_data = pd.read_csv('./tests/data/baba_stock.csv')
stock_data.head()
# 计算简单移动平均线
stock_data['SMA'] = ta.trend.sma_indicator(stock_data['Close'], window=6)
stock_data[['Date', 'Close', 'SMA']].head()
# 计算布林带
stock_data['bb_upper'], stock_data['bb_middle'], stock_data['bb_lower'] = ta.volatility.bollinger_hband_indicator(stock_data['Close'], window=20), ta.volatility.bollinger_mavg(stock_data['Close'], window=20), ta.volatility.bollinger_lband_indicator(stock_data['Close'], window=20)
stock_data[['Date', 'Close', 'bb_upper', 'bb_middle', 'bb_lower']].head()
"""
template_code = """
def stock_indicator(stock_path: str, indicators=['Simple Moving Average', 'BollingerBands', 'MACD]) -> pd.DataFrame:
import pandas as pd
# here is your code.
"""
def get_expected_res():
import pandas as pd
import ta
# 读取股票数据
stock_data = pd.read_csv("./tests/data/baba_stock.csv")
stock_data.head()
# 计算简单移动平均线
stock_data["SMA"] = ta.trend.sma_indicator(stock_data["Close"], window=6)
stock_data[["Date", "Close", "SMA"]].head()
# 计算布林带
stock_data["bb_upper"], stock_data["bb_middle"], stock_data["bb_lower"] = (
ta.volatility.bollinger_hband_indicator(stock_data["Close"], window=20),
ta.volatility.bollinger_mavg(stock_data["Close"], window=20),
ta.volatility.bollinger_lband_indicator(stock_data["Close"], window=20),
)
stock_data[["Date", "Close", "bb_upper", "bb_middle", "bb_lower"]].head()
return stock_data
@pytest.mark.asyncio
async def test_clone_function():
clone = CloneFunction()
code = await clone.run(template_code, source_code)
assert "def " in code
stock_path = "./tests/data/baba_stock.csv"
df, msg = run_function_code(code, "stock_indicator", stock_path)
assert not msg
expected_df = get_expected_res()
assert df.equals(expected_df)
def test_run_function_script():
# 创建一个临时文件并写入脚本内容
script_content = """def valid_function(arg1, arg2):\n return arg1 + arg2\n"""
with tempfile.NamedTemporaryFile(mode="w+", suffix=".py", delete=False) as temp_file:
temp_file.write(script_content)
temp_file_path = temp_file.name
invalid_script_content = """def valid_function(arg1, arg2)\n return arg1 + arg2\n"""
with tempfile.NamedTemporaryFile(mode="w+", suffix=".py", delete=False) as error_temp_file:
error_temp_file.write(invalid_script_content)
error_temp_file_path = error_temp_file.name
try:
# 正常情况下运行脚本
result, _ = run_function_script(temp_file_path, "valid_function", 1, arg2=2)
assert result == 3
# 不存在的脚本路径
with pytest.raises(FileNotFoundError):
run_function_script("nonexistent/path/script.py", "valid_function", 1, arg2=2)
# 无效的脚本内容
result, traceback = run_function_script(error_temp_file_path, "invalid_function", 1, arg2=2)
assert not result
assert "SyntaxError" in traceback
# 函数调用失败的情况
result, traceback = run_function_script(temp_file_path, "function_that_raises_exception", 1, arg2=2)
assert not result
assert "KeyError" in traceback
finally:
# 删除临时文件
if os.path.exists(temp_file_path):
os.remove(temp_file_path)

View file

@ -142,7 +142,7 @@ async def test_debug_error():
"Ran 5 tests in 0.007s\n\nFAILED (failures=1)\n;\n",
)
await FileRepository.save_file(
filename=ctx.output_filename, content=output_data.json(), relative_path=TEST_OUTPUTS_FILE_REPO
filename=ctx.output_filename, content=output_data.model_dump_json(), relative_path=TEST_OUTPUTS_FILE_REPO
)
debug_error = DebugError(context=ctx)

View file

@ -20,7 +20,7 @@ from metagpt.actions.invoice_ocr import GenerateTable, InvoiceOCR, ReplyQuestion
"invoice_path",
[
"../../data/invoices/invoice-3.jpg",
"../../data/invoices/invoice-4.zip",
# "../../data/invoices/invoice-4.zip",
],
)
async def test_invoice_ocr(invoice_path: str):

View file

@ -8,7 +8,7 @@
import pytest
from metagpt.actions import CollectLinks
from metagpt.actions import CollectLinks, research
@pytest.mark.asyncio
@ -18,5 +18,107 @@ async def test_action():
assert result
@pytest.mark.asyncio
async def test_collect_links(mocker):
async def mock_llm_ask(self, prompt: str, system_msgs):
if "Please provide up to 2 necessary keywords" in prompt:
return '["metagpt", "llm"]'
elif "Provide up to 4 queries related to your research topic" in prompt:
return (
'["MetaGPT use cases", "The roadmap of MetaGPT", '
'"The function of MetaGPT", "What llm MetaGPT support"]'
)
elif "sort the remaining search results" in prompt:
return "[1,2]"
mocker.patch("metagpt.provider.base_gpt_api.BaseGPTAPI.aask", mock_llm_ask)
resp = await research.CollectLinks().run("The application of MetaGPT")
for i in ["MetaGPT use cases", "The roadmap of MetaGPT", "The function of MetaGPT", "What llm MetaGPT support"]:
assert i in resp
@pytest.mark.asyncio
async def test_collect_links_with_rank_func(mocker):
rank_before = []
rank_after = []
url_per_query = 4
def rank_func(results):
results = results[:url_per_query]
rank_before.append(results)
results = results[::-1]
rank_after.append(results)
return results
mocker.patch("metagpt.provider.base_gpt_api.BaseGPTAPI.aask", mock_collect_links_llm_ask)
resp = await research.CollectLinks(rank_func=rank_func).run("The application of MetaGPT")
for x, y, z in zip(rank_before, rank_after, resp.values()):
assert x[::-1] == y
assert [i["link"] for i in y] == z
@pytest.mark.asyncio
async def test_web_browse_and_summarize(mocker):
async def mock_llm_ask(*args, **kwargs):
return "metagpt"
mocker.patch("metagpt.provider.base_gpt_api.BaseGPTAPI.aask", mock_llm_ask)
url = "https://github.com/geekan/MetaGPT"
url2 = "https://github.com/trending"
query = "What's new in metagpt"
resp = await research.WebBrowseAndSummarize().run(url, query=query)
assert len(resp) == 1
assert url in resp
assert resp[url] == "metagpt"
resp = await research.WebBrowseAndSummarize().run(url, url2, query=query)
assert len(resp) == 2
async def mock_llm_ask(*args, **kwargs):
return "Not relevant."
mocker.patch("metagpt.provider.base_gpt_api.BaseGPTAPI.aask", mock_llm_ask)
resp = await research.WebBrowseAndSummarize().run(url, query=query)
assert len(resp) == 1
assert url in resp
assert resp[url] is None
@pytest.mark.asyncio
async def test_conduct_research(mocker):
data = None
async def mock_llm_ask(*args, **kwargs):
nonlocal data
data = f"# Research Report\n## Introduction\n{args} {kwargs}"
return data
mocker.patch("metagpt.provider.base_gpt_api.BaseGPTAPI.aask", mock_llm_ask)
content = (
"MetaGPT takes a one line requirement as input and "
"outputs user stories / competitive analysis / requirements / data structures / APIs / documents, etc."
)
resp = await research.ConductResearch().run("The application of MetaGPT", content)
assert resp == data
async def mock_collect_links_llm_ask(self, prompt: str, system_msgs):
if "Please provide up to 2 necessary keywords" in prompt:
return '["metagpt", "llm"]'
elif "Provide up to 4 queries related to your research topic" in prompt:
return (
'["MetaGPT use cases", "The roadmap of MetaGPT", ' '"The function of MetaGPT", "What llm MetaGPT support"]'
)
elif "sort the remaining search results" in prompt:
return "[1,2]"
return ""
if __name__ == "__main__":
pytest.main([__file__, "-s"])

View file

@ -14,13 +14,13 @@ from metagpt.schema import RunCodeContext
@pytest.mark.asyncio
async def test_run_text():
result, errs = await RunCode.run_text("result = 1 + 1")
assert result == 2
assert errs == ""
out, err = await RunCode.run_text("result = 1 + 1")
assert out == 2
assert err == ""
result, errs = await RunCode.run_text("result = 1 / 0")
assert result == ""
assert "ZeroDivisionError" in errs
out, err = await RunCode.run_text("result = 1 / 0")
assert out == ""
assert "division by zero" in err
@pytest.mark.asyncio

View file

@ -32,11 +32,11 @@ async def test_write_code():
context = CodingContext(
filename="task_filename.py", design_doc=Document(content="设计一个名为'add'的函数,该函数接受两个整数作为输入,并返回它们的和。")
)
doc = Document(content=context.json())
doc = Document(content=context.model_dump_json())
write_code = WriteCode(context=doc)
code = await write_code.run()
logger.info(code.json())
logger.info(code.model_dump_json())
# 我们不能精确地预测生成的代码,但我们可以检查某些关键字
assert "def add" in code.code_doc.content

View file

@ -29,7 +29,7 @@ async def test_write_test():
write_test = WriteTest(context=context)
context = await write_test.run()
logger.info(context.json())
logger.info(context.model_dump_json())
# We cannot exactly predict the generated test cases, but we can check if it is a string and if it is not empty
assert isinstance(context.test_doc.content, str)

View file

@ -1,36 +0,0 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time : 2023/6/11 21:08
@Author : alexanderwu
@File : test_milvus_store.py
"""
import random
import numpy as np
from metagpt.document_store.milvus_store import MilvusConnection, MilvusStore
from metagpt.logs import logger
book_columns = {"idx": int, "name": str, "desc": str, "emb": np.ndarray, "price": float}
book_data = [
[i for i in range(10)],
[f"book-{i}" for i in range(10)],
[f"book-desc-{i}" for i in range(10000, 10010)],
[[random.random() for _ in range(2)] for _ in range(10)],
[random.random() for _ in range(10)],
]
def test_milvus_store():
milvus_connection = MilvusConnection(alias="default", host="192.168.50.161", port="30530")
milvus_store = MilvusStore(milvus_connection)
milvus_store.drop("Book")
milvus_store.create_collection("Book", book_columns)
milvus_store.add(book_data)
milvus_store.build_index("emb")
milvus_store.load_collection()
results = milvus_store.search([[1.0, 1.0]], field="emb")
logger.info(results)
assert results

View file

@ -29,7 +29,7 @@ points = [
]
def test_milvus_store():
def test_qdrant_store():
qdrant_connection = QdrantConnection(memory=True)
vectors_config = VectorParams(size=2, distance=Distance.COSINE)
qdrant_store = QdrantStore(qdrant_connection)
@ -43,13 +43,13 @@ def test_milvus_store():
results = qdrant_store.search("Book", query=[1.0, 1.0])
assert results[0]["id"] == 2
assert results[0]["score"] == 0.999106722578389
assert results[1]["score"] == 7
assert results[1]["id"] == 7
assert results[1]["score"] == 0.9961650411397226
results = qdrant_store.search("Book", query=[1.0, 1.0], return_vector=True)
assert results[0]["id"] == 2
assert results[0]["score"] == 0.999106722578389
assert results[0]["vector"] == [0.7363563179969788, 0.6765939593315125]
assert results[1]["score"] == 7
assert results[1]["id"] == 7
assert results[1]["score"] == 0.9961650411397226
assert results[1]["vector"] == [0.7662628889083862, 0.6425272226333618]
results = qdrant_store.search(

View file

@ -5,6 +5,7 @@
@Author : mashenquan
@File : test_brain_memory.py
"""
import pytest
from metagpt.config import LLMProviderEnum

View file

@ -86,31 +86,25 @@ class TestOpenAI:
def test_make_client_kwargs_without_proxy(self, config):
instance = OpenAILLM()
instance.config = config
kwargs, async_kwargs = instance._make_client_kwargs()
kwargs = instance._make_client_kwargs()
assert kwargs == {"api_key": "test_key", "base_url": "test_url"}
assert async_kwargs == {"api_key": "test_key", "base_url": "test_url"}
assert "http_client" not in kwargs
assert "http_client" not in async_kwargs
def test_make_client_kwargs_without_proxy_azure(self, config_azure):
instance = OpenAILLM()
instance.config = config_azure
kwargs, async_kwargs = instance._make_client_kwargs()
kwargs = instance._make_client_kwargs()
assert kwargs == {"api_key": "test_key", "base_url": "test_url"}
assert async_kwargs == {"api_key": "test_key", "base_url": "test_url"}
assert "http_client" not in kwargs
assert "http_client" not in async_kwargs
def test_make_client_kwargs_with_proxy(self, config_proxy):
instance = OpenAILLM()
instance.config = config_proxy
kwargs, async_kwargs = instance._make_client_kwargs()
kwargs = instance._make_client_kwargs()
assert "http_client" in kwargs
assert "http_client" in async_kwargs
def test_make_client_kwargs_with_proxy_azure(self, config_azure_proxy):
instance = OpenAILLM()
instance.config = config_azure_proxy
kwargs, async_kwargs = instance._make_client_kwargs()
kwargs = instance._make_client_kwargs()
assert "http_client" in kwargs
assert "http_client" in async_kwargs

View file

@ -32,3 +32,19 @@ async def test_researcher(mocker):
researcher.RESEARCH_PATH = Path(dirname)
await researcher.Researcher().run(topic)
assert (researcher.RESEARCH_PATH / f"{topic}.md").read_text().startswith("# Research Report")
def test_write_report(mocker):
with TemporaryDirectory() as dirname:
for i, topic in enumerate(
[
("1./metagpt"),
('2.:"metagpt'),
("3.*?<>|metagpt"),
("4. metagpt\n"),
]
):
researcher.RESEARCH_PATH = Path(dirname)
content = "# Research Report"
researcher.Researcher().write_report(topic, content)
assert (researcher.RESEARCH_PATH / f"{i+1}. metagpt.md").read_text().startswith("# Research Report")

View file

@ -8,4 +8,4 @@ from metagpt.roles.role import Role
def test_role_desc():
role = Role(profile="Sales", desc="Best Seller")
assert role.profile == "Sales"
assert role._setting.desc == "Best Seller"
assert role.desc == "Best Seller"

View file

@ -10,15 +10,20 @@ from metagpt.llm import LLM
def test_action_serialize():
action = Action()
ser_action_dict = action.dict()
ser_action_dict = action.model_dump()
assert "name" in ser_action_dict
# assert "llm" not in ser_action_dict # not export
assert "llm" not in ser_action_dict # not export
assert "__module_class_name" not in ser_action_dict
action = Action(name="test")
ser_action_dict = action.model_dump()
assert "test" in ser_action_dict["name"]
@pytest.mark.asyncio
async def test_action_deserialize():
action = Action()
serialized_data = action.dict()
serialized_data = action.model_dump()
new_action = Action(**serialized_data)

View file

@ -10,19 +10,19 @@ from metagpt.roles.architect import Architect
def test_architect_serialize():
role = Architect()
ser_role_dict = role.dict(by_alias=True)
ser_role_dict = role.model_dump(by_alias=True)
assert "name" in ser_role_dict
assert "_states" in ser_role_dict
assert "_actions" in ser_role_dict
assert "states" in ser_role_dict
assert "actions" in ser_role_dict
@pytest.mark.asyncio
async def test_architect_deserialize():
role = Architect()
ser_role_dict = role.dict(by_alias=True)
ser_role_dict = role.model_dump(by_alias=True)
new_role = Architect(**ser_role_dict)
# new_role = Architect.deserialize(ser_role_dict)
assert new_role.name == "Bob"
assert len(new_role._actions) == 1
assert isinstance(new_role._actions[0], Action)
await new_role._actions[0].run(with_messages="write a cli snake game")
assert len(new_role.actions) == 1
assert isinstance(new_role.actions[0], Action)
await new_role.actions[0].run(with_messages="write a cli snake game")

View file

@ -20,14 +20,15 @@ from tests.metagpt.serialize_deserialize.test_serdeser_base import (
def test_env_serialize():
env = Environment()
ser_env_dict = env.dict()
ser_env_dict = env.model_dump()
assert "roles" in ser_env_dict
assert len(ser_env_dict["roles"]) == 0
def test_env_deserialize():
env = Environment()
env.publish_message(message=Message(content="test env serialize"))
ser_env_dict = env.dict()
ser_env_dict = env.model_dump()
new_env = Environment(**ser_env_dict)
assert len(new_env.roles) == 0
assert len(new_env.history) == 25
@ -47,16 +48,16 @@ def test_environment_serdeser():
environment.add_role(role_c)
environment.publish_message(message)
ser_data = environment.dict()
ser_data = environment.model_dump()
assert ser_data["roles"]["Role C"]["name"] == "RoleC"
new_env: Environment = Environment(**ser_data)
assert len(new_env.roles) == 1
assert list(new_env.roles.values())[0]._states == list(environment.roles.values())[0]._states
assert list(new_env.roles.values())[0]._actions == list(environment.roles.values())[0]._actions
assert isinstance(list(environment.roles.values())[0]._actions[0], ActionOK)
assert type(list(new_env.roles.values())[0]._actions[0]) == ActionOK
assert list(new_env.roles.values())[0].states == list(environment.roles.values())[0].states
assert list(new_env.roles.values())[0].actions == list(environment.roles.values())[0].actions
assert isinstance(list(environment.roles.values())[0].actions[0], ActionOK)
assert type(list(new_env.roles.values())[0].actions[0]) == ActionOK
def test_environment_serdeser_v2():
@ -64,13 +65,13 @@ def test_environment_serdeser_v2():
pm = ProjectManager()
environment.add_role(pm)
ser_data = environment.dict()
ser_data = environment.model_dump()
new_env: Environment = Environment(**ser_data)
role = new_env.get_role(pm.profile)
assert isinstance(role, ProjectManager)
assert isinstance(role._actions[0], WriteTasks)
assert isinstance(list(new_env.roles.values())[0]._actions[0], WriteTasks)
assert isinstance(role.actions[0], WriteTasks)
assert isinstance(list(new_env.roles.values())[0].actions[0], WriteTasks)
def test_environment_serdeser_save():
@ -85,4 +86,4 @@ def test_environment_serdeser_save():
new_env: Environment = Environment.deserialize(stg_path)
assert len(new_env.roles) == 1
assert type(list(new_env.roles.values())[0]._actions[0]) == ActionOK
assert type(list(new_env.roles.values())[0].actions[0]) == ActionOK

View file

@ -25,7 +25,7 @@ def test_memory_serdeser():
memory = Memory()
memory.add_batch([msg1, msg2])
ser_data = memory.dict()
ser_data = memory.model_dump()
new_memory = Memory(**ser_data)
assert new_memory.count() == 2
@ -35,6 +35,9 @@ def test_memory_serdeser():
assert new_memory.storage[-1].cause_by == any_to_str(WriteDesign)
assert new_msg2.role == "Boss"
memory = Memory(storage=[msg1, msg2], index={msg1.cause_by: [msg1], msg2.cause_by: [msg2]})
assert memory.count() == 2
def test_memory_serdeser_save():
msg1 = Message(role="User", content="write a 2048 game", cause_by=UserRequirement)

View file

@ -0,0 +1,58 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc : unittest of polymorphic conditions
from pydantic import BaseModel, ConfigDict, SerializeAsAny
from metagpt.actions import Action
from tests.metagpt.serialize_deserialize.test_serdeser_base import (
ActionOKV2,
ActionPass,
)
class ActionSubClasses(BaseModel):
actions: list[SerializeAsAny[Action]] = []
class ActionSubClassesNoSAA(BaseModel):
"""without SerializeAsAny"""
model_config = ConfigDict(arbitrary_types_allowed=True)
actions: list[Action] = []
def test_serialize_as_any():
"""test subclasses of action with different fields in ser&deser"""
# ActionOKV2 with a extra field `extra_field`
action_subcls = ActionSubClasses(actions=[ActionOKV2(), ActionPass()])
action_subcls_dict = action_subcls.model_dump()
assert action_subcls_dict["actions"][0]["extra_field"] == ActionOKV2().extra_field
def test_no_serialize_as_any():
# ActionOKV2 with a extra field `extra_field`
action_subcls = ActionSubClassesNoSAA(actions=[ActionOKV2(), ActionPass()])
action_subcls_dict = action_subcls.model_dump()
# without `SerializeAsAny`, it will serialize as Action
assert "extra_field" not in action_subcls_dict["actions"][0]
def test_polymorphic():
_ = ActionOKV2(
**{"name": "ActionOKV2", "context": "", "prefix": "", "desc": "", "extra_field": "ActionOKV2 Extra Info"}
)
action_subcls = ActionSubClasses(actions=[ActionOKV2(), ActionPass()])
action_subcls_dict = action_subcls.model_dump()
assert "__module_class_name" in action_subcls_dict["actions"][0]
new_action_subcls = ActionSubClasses(**action_subcls_dict)
assert isinstance(new_action_subcls.actions[0], ActionOKV2)
assert isinstance(new_action_subcls.actions[1], ActionPass)
new_action_subcls = ActionSubClasses.model_validate(action_subcls_dict)
assert isinstance(new_action_subcls.actions[0], ActionOKV2)
assert isinstance(new_action_subcls.actions[1], ActionPass)

View file

@ -12,10 +12,10 @@ from metagpt.schema import Message
@pytest.mark.asyncio
async def test_product_manager_deserialize():
role = ProductManager()
ser_role_dict = role.dict(by_alias=True)
ser_role_dict = role.model_dump(by_alias=True)
new_role = ProductManager(**ser_role_dict)
assert new_role.name == "Alice"
assert len(new_role._actions) == 2
assert isinstance(new_role._actions[0], Action)
await new_role._actions[0].run([Message(content="write a cli snake game")])
assert len(new_role.actions) == 2
assert isinstance(new_role.actions[0], Action)
await new_role.actions[0].run([Message(content="write a cli snake game")])

View file

@ -11,20 +11,20 @@ from metagpt.roles.project_manager import ProjectManager
def test_project_manager_serialize():
role = ProjectManager()
ser_role_dict = role.dict(by_alias=True)
ser_role_dict = role.model_dump(by_alias=True)
assert "name" in ser_role_dict
assert "_states" in ser_role_dict
assert "_actions" in ser_role_dict
assert "states" in ser_role_dict
assert "actions" in ser_role_dict
@pytest.mark.asyncio
async def test_project_manager_deserialize():
role = ProjectManager()
ser_role_dict = role.dict(by_alias=True)
ser_role_dict = role.model_dump(by_alias=True)
new_role = ProjectManager(**ser_role_dict)
assert new_role.name == "Eve"
assert len(new_role._actions) == 1
assert isinstance(new_role._actions[0], Action)
assert isinstance(new_role._actions[0], WriteTasks)
# await new_role._actions[0].run(context="write a cli snake game")
assert len(new_role.actions) == 1
assert isinstance(new_role.actions[0], Action)
assert isinstance(new_role.actions[0], WriteTasks)
# await new_role.actions[0].run(context="write a cli snake game")

View file

@ -6,6 +6,7 @@
import shutil
import pytest
from pydantic import BaseModel, SerializeAsAny
from metagpt.actions import WriteCode
from metagpt.actions.add_requirement import UserRequirement
@ -17,48 +18,67 @@ from metagpt.roles.role import Role
from metagpt.schema import Message
from metagpt.utils.common import format_trackback_info
from tests.metagpt.serialize_deserialize.test_serdeser_base import (
ActionOK,
RoleA,
RoleB,
RoleC,
RoleD,
serdeser_path,
)
def test_roles():
role_a = RoleA()
assert len(role_a._rc.watch) == 1
assert len(role_a.rc.watch) == 1
role_b = RoleB()
assert len(role_a._rc.watch) == 1
assert len(role_b._rc.watch) == 1
assert len(role_a.rc.watch) == 1
assert len(role_b.rc.watch) == 1
role_d = RoleD(actions=[ActionOK()])
assert len(role_d.actions) == 1
def test_role_subclasses():
"""test subclasses of role with same fields in ser&deser"""
class RoleSubClasses(BaseModel):
roles: list[SerializeAsAny[Role]] = []
role_subcls = RoleSubClasses(roles=[RoleA(), RoleB()])
role_subcls_dict = role_subcls.model_dump()
new_role_subcls = RoleSubClasses(**role_subcls_dict)
assert isinstance(new_role_subcls.roles[0], RoleA)
assert isinstance(new_role_subcls.roles[1], RoleB)
def test_role_serialize():
role = Role()
ser_role_dict = role.dict(by_alias=True)
ser_role_dict = role.model_dump()
assert "name" in ser_role_dict
assert "_states" in ser_role_dict
assert "_actions" in ser_role_dict
assert "states" in ser_role_dict
assert "actions" in ser_role_dict
def test_engineer_serialize():
role = Engineer()
ser_role_dict = role.dict(by_alias=True)
ser_role_dict = role.model_dump()
assert "name" in ser_role_dict
assert "_states" in ser_role_dict
assert "_actions" in ser_role_dict
assert "states" in ser_role_dict
assert "actions" in ser_role_dict
@pytest.mark.asyncio
async def test_engineer_deserialize():
role = Engineer(use_code_review=True)
ser_role_dict = role.dict(by_alias=True)
ser_role_dict = role.model_dump()
new_role = Engineer(**ser_role_dict)
assert new_role.name == "Alex"
assert new_role.use_code_review is True
assert len(new_role._actions) == 1
assert isinstance(new_role._actions[0], WriteCode)
# await new_role._actions[0].run(context="write a cli snake game", filename="test_code")
assert len(new_role.actions) == 1
assert isinstance(new_role.actions[0], WriteCode)
# await new_role.actions[0].run(context="write a cli snake game", filename="test_code")
def test_role_serdeser_save():
@ -87,10 +107,10 @@ async def test_role_serdeser_interrupt():
logger.error(f"Exception in `role_a.run`, detail: {format_trackback_info()}")
role_c.serialize(stg_path)
assert role_c._rc.memory.count() == 1
assert role_c.rc.memory.count() == 1
new_role_a: Role = Role.deserialize(stg_path)
assert new_role_a._rc.state == 1
assert new_role_a.rc.state == 1
with pytest.raises(Exception):
await new_role_a.run(with_message=Message(content="demo", cause_by=UserRequirement))

View file

@ -4,9 +4,12 @@
from metagpt.actions.action_node import ActionNode
from metagpt.actions.write_code import WriteCode
from metagpt.schema import Message
from metagpt.schema import Document, Documents, Message
from metagpt.utils.common import any_to_str
from tests.metagpt.serialize_deserialize.test_serdeser_base import MockMessage
from tests.metagpt.serialize_deserialize.test_serdeser_base import (
MockICMessage,
MockMessage,
)
def test_message_serdeser():
@ -15,14 +18,24 @@ def test_message_serdeser():
ic_obj = ActionNode.create_model_class("code", out_mapping)
message = Message(content="code", instruct_content=ic_obj(**out_data), role="engineer", cause_by=WriteCode)
ser_data = message.dict()
ser_data = message.model_dump()
assert ser_data["cause_by"] == "metagpt.actions.write_code.WriteCode"
assert ser_data["instruct_content"]["class"] == "code"
new_message = Message(**ser_data)
assert new_message.cause_by == any_to_str(WriteCode)
assert new_message.cause_by in [any_to_str(WriteCode)]
assert new_message.instruct_content == ic_obj(**out_data)
assert new_message.instruct_content != ic_obj(**out_data) # TODO find why `!=`
assert new_message.instruct_content.model_dump() == ic_obj(**out_data).model_dump()
message = Message(content="test_ic", instruct_content=MockICMessage())
ser_data = message.model_dump()
new_message = Message(**ser_data)
assert new_message.instruct_content != MockICMessage() # TODO
message = Message(content="test_documents", instruct_content=Documents(docs={"doc1": Document(content="test doc")}))
ser_data = message.model_dump()
assert "class" in ser_data["instruct_content"]
def test_message_without_postprocess():
@ -31,8 +44,9 @@ def test_message_without_postprocess():
out_data = {"field1": ["field1 value1", "field1 value2"]}
ic_obj = ActionNode.create_model_class("code", out_mapping)
message = MockMessage(content="code", instruct_content=ic_obj(**out_data))
ser_data = message.dict()
assert ser_data["instruct_content"] == {"field1": ["field1 value1", "field1 value2"]}
ser_data = message.model_dump()
assert ser_data["instruct_content"] == {}
ser_data["instruct_content"] = None
new_message = MockMessage(**ser_data)
assert new_message.instruct_content != ic_obj(**out_data)

View file

@ -4,6 +4,7 @@
import asyncio
from pathlib import Path
from typing import Optional
from pydantic import BaseModel, Field
@ -15,15 +16,19 @@ from metagpt.roles.role import Role, RoleReactMode
serdeser_path = Path(__file__).absolute().parent.joinpath("..", "..", "data", "serdeser_storage")
class MockICMessage(BaseModel):
content: str = "test_ic"
class MockMessage(BaseModel):
"""to test normal dict without postprocess"""
content: str = ""
instruct_content: BaseModel = Field(default=None)
instruct_content: Optional[BaseModel] = Field(default=None)
class ActionPass(Action):
name: str = Field(default="ActionPass")
name: str = "ActionPass"
async def run(self, messages: list["Message"]) -> ActionOutput:
await asyncio.sleep(5) # sleep to make other roles can watch the executed Message
@ -35,7 +40,7 @@ class ActionPass(Action):
class ActionOK(Action):
name: str = Field(default="ActionOK")
name: str = "ActionOK"
async def run(self, messages: list["Message"]) -> str:
await asyncio.sleep(5)
@ -43,12 +48,17 @@ class ActionOK(Action):
class ActionRaise(Action):
name: str = Field(default="ActionRaise")
name: str = "ActionRaise"
async def run(self, messages: list["Message"]) -> str:
raise RuntimeError("parse error in ActionRaise")
class ActionOKV2(Action):
name: str = "ActionOKV2"
extra_field: str = "ActionOKV2 Extra Info"
class RoleA(Role):
name: str = Field(default="RoleA")
profile: str = Field(default="Role A")
@ -71,7 +81,7 @@ class RoleB(Role):
super(RoleB, self).__init__(**kwargs)
self._init_actions([ActionOK, ActionRaise])
self._watch([ActionPass])
self._rc.react_mode = RoleReactMode.BY_ORDER
self.rc.react_mode = RoleReactMode.BY_ORDER
class RoleC(Role):
@ -84,5 +94,12 @@ class RoleC(Role):
super(RoleC, self).__init__(**kwargs)
self._init_actions([ActionOK, ActionRaise])
self._watch([UserRequirement])
self._rc.react_mode = RoleReactMode.BY_ORDER
self._rc.memory.ignore_id = True
self.rc.react_mode = RoleReactMode.BY_ORDER
self.rc.memory.ignore_id = True
class RoleD(Role):
name: str = Field(default="RoleD")
profile: str = Field(default="Role D")
goal: str = "RoleD's goal"
constraints: str = "RoleD's constraints"

View file

@ -33,8 +33,8 @@ def test_team_deserialize():
]
)
assert len(company.env.get_roles()) == 3
ser_company = company.dict()
new_company = Team(**ser_company)
ser_company = company.model_dump()
new_company = Team.model_validate(ser_company)
assert len(new_company.env.get_roles()) == 3
assert new_company.env.get_role(pm.profile) is not None
@ -47,6 +47,7 @@ def test_team_deserialize():
def test_team_serdeser_save():
company = Team()
company.hire([RoleC()])
stg_path = serdeser_path.joinpath("team")
@ -71,13 +72,13 @@ async def test_team_recover():
company.run_project(idea)
await company.run(n_round=4)
ser_data = company.dict()
ser_data = company.model_dump()
new_company = Team(**ser_data)
new_role_c = new_company.env.get_role(role_c.profile)
# assert new_role_c._rc.memory == role_c._rc.memory # TODO
assert new_role_c._rc.env != role_c._rc.env # TODO
assert type(list(new_company.env.roles.values())[0]._actions[0]) == ActionOK
new_company.env.get_role(role_c.profile)
# assert new_role_c.rc.memory == role_c.rc.memory # TODO
# assert new_role_c.rc.env != role_c.rc.env # TODO
assert type(list(new_company.env.roles.values())[0].actions[0]) == ActionOK
new_company.run_project(idea)
await new_company.run(n_round=4)
@ -97,11 +98,11 @@ async def test_team_recover_save():
new_company = Team.deserialize(stg_path)
new_role_c = new_company.env.get_role(role_c.profile)
# assert new_role_c._rc.memory == role_c._rc.memory
assert new_role_c._rc.env != role_c._rc.env
# assert new_role_c.rc.memory == role_c.rc.memory
# assert new_role_c.rc.env != role_c.rc.env
assert new_role_c.recovered != role_c.recovered # here cause previous ut is `!=`
assert new_role_c._rc.todo != role_c._rc.todo # serialize exclude `_rc.todo`
assert new_role_c._rc.news != role_c._rc.news # serialize exclude `_rc.news`
assert new_role_c.rc.todo != role_c.rc.todo # serialize exclude `rc.todo`
assert new_role_c.rc.news != role_c.rc.news # serialize exclude `rc.news`
new_company.run_project(idea)
await new_company.run(n_round=4)
@ -116,10 +117,6 @@ async def test_team_recover_multi_roles_save():
role_a = RoleA()
role_b = RoleB()
assert role_a.subscription == {"tests.metagpt.serialize_deserialize.test_serdeser_base.RoleA", "RoleA"}
assert role_b.subscription == {"tests.metagpt.serialize_deserialize.test_serdeser_base.RoleB", "RoleB"}
assert role_b._rc.watch == {"tests.metagpt.serialize_deserialize.test_serdeser_base.ActionPass"}
company = Team()
company.hire([role_a, role_b])
company.run_project(idea)
@ -130,6 +127,6 @@ async def test_team_recover_multi_roles_save():
new_company = Team.deserialize(stg_path)
new_company.run_project(idea)
assert new_company.env.get_role(role_b.profile)._rc.state == 1
assert new_company.env.get_role(role_b.profile).rc.state == 1
await new_company.run(n_round=4)

View file

@ -12,9 +12,9 @@ from metagpt.schema import CodingContext, Document
def test_write_design_serialize():
action = WriteCode()
ser_action_dict = action.dict()
ser_action_dict = action.model_dump()
assert ser_action_dict["name"] == "WriteCode"
# assert "llm" in ser_action_dict # not export
assert "llm" not in ser_action_dict # not export
@pytest.mark.asyncio
@ -22,9 +22,9 @@ async def test_write_code_deserialize():
context = CodingContext(
filename="test_code.py", design_doc=Document(content="write add function to calculate two numbers")
)
doc = Document(content=context.json())
doc = Document(content=context.model_dump_json())
action = WriteCode(context=doc)
serialized_data = action.dict()
serialized_data = action.model_dump()
new_action = WriteCode(**serialized_data)
assert new_action.name == "WriteCode"

View file

@ -22,7 +22,7 @@ def div(a: int, b: int = 0):
)
action = WriteCodeReview(context=context)
serialized_data = action.dict()
serialized_data = action.model_dump()
assert serialized_data["name"] == "WriteCodeReview"
new_action = WriteCodeReview(**serialized_data)

View file

@ -10,22 +10,22 @@ from metagpt.llm import LLM
def test_write_design_serialize():
action = WriteDesign()
ser_action_dict = action.dict()
ser_action_dict = action.model_dump()
assert "name" in ser_action_dict
# assert "llm" in ser_action_dict # not export
assert "llm" not in ser_action_dict # not export
def test_write_task_serialize():
action = WriteTasks()
ser_action_dict = action.dict()
ser_action_dict = action.model_dump()
assert "name" in ser_action_dict
# assert "llm" in ser_action_dict # not export
assert "llm" not in ser_action_dict # not export
@pytest.mark.asyncio
async def test_write_design_deserialize():
action = WriteDesign()
serialized_data = action.dict()
serialized_data = action.model_dump()
new_action = WriteDesign(**serialized_data)
assert new_action.name == ""
assert new_action.llm == LLM()
@ -35,7 +35,7 @@ async def test_write_design_deserialize():
@pytest.mark.asyncio
async def test_write_task_deserialize():
action = WriteTasks()
serialized_data = action.dict()
serialized_data = action.model_dump()
new_action = WriteTasks(**serialized_data)
assert new_action.name == "CreateTasks"
assert new_action.llm == LLM()

View file

@ -12,15 +12,15 @@ from metagpt.schema import Message
def test_action_serialize():
action = WritePRD()
ser_action_dict = action.dict()
ser_action_dict = action.model_dump()
assert "name" in ser_action_dict
# assert "llm" in ser_action_dict # not export
assert "llm" not in ser_action_dict # not export
@pytest.mark.asyncio
async def test_action_deserialize():
action = WritePRD()
serialized_data = action.dict()
serialized_data = action.model_dump()
new_action = WritePRD(**serialized_data)
assert new_action.name == ""
assert new_action.llm == LLM()

View file

@ -33,6 +33,15 @@ class MockRole(Role):
self._init_actions([MockAction()])
def test_basic():
mock_role = MockRole()
assert mock_role.subscription == {"tests.metagpt.test_role.MockRole"}
assert mock_role.rc.watch == {"metagpt.actions.add_requirement.UserRequirement"}
mock_role = MockRole(name="mock_role")
assert mock_role.subscription == {"tests.metagpt.test_role.MockRole", "mock_role"}
@pytest.mark.asyncio
async def test_react():
class Input(BaseModel):
@ -60,12 +69,12 @@ async def test_react():
name=seed.name, profile=seed.profile, goal=seed.goal, constraints=seed.constraints, desc=seed.desc
)
role.subscribe({seed.subscription})
assert role._rc.watch == {any_to_str(UserRequirement)}
assert role.rc.watch == {any_to_str(UserRequirement)}
assert role.name == seed.name
assert role.profile == seed.profile
assert role._setting.goal == seed.goal
assert role._setting.constraints == seed.constraints
assert role._setting.desc == seed.desc
assert role.goal == seed.goal
assert role.constraints == seed.constraints
assert role.desc == seed.desc
assert role.is_idle
env = Environment()
env.add_role(role)

View file

@ -31,6 +31,8 @@ def test_messages():
def test_message():
Message("a", role="v1")
m = Message(content="a", role="v1")
v = m.dump()
d = json.loads(v)
@ -74,22 +76,22 @@ def test_message_serdeser():
ic_obj = ActionNode.create_model_class("code", out_mapping)
message = Message(content="code", instruct_content=ic_obj(**out_data), role="engineer", cause_by=WriteCode)
message_dict = message.dict()
message_dict = message.model_dump()
assert message_dict["cause_by"] == "metagpt.actions.write_code.WriteCode"
assert message_dict["instruct_content"] == {
"class": "code",
"mapping": {"field3": "(<class 'str'>, Ellipsis)", "field4": "(list[str], Ellipsis)"},
"value": {"field3": "field3 value3", "field4": ["field4 value1", "field4 value2"]},
}
new_message = Message(**message_dict)
new_message = Message.model_validate(message_dict)
assert new_message.content == message.content
assert new_message.instruct_content == message.instruct_content
assert new_message.instruct_content.model_dump() == message.instruct_content.model_dump()
assert new_message.instruct_content != message.instruct_content # TODO
assert new_message.cause_by == message.cause_by
assert new_message.instruct_content.field3 == out_data["field3"]
message = Message(content="code")
message_dict = message.dict()
message_dict = message.model_dump()
new_message = Message(**message_dict)
assert new_message.instruct_content is None
assert new_message.cause_by == "metagpt.actions.add_requirement.UserRequirement"

View file

@ -9,23 +9,25 @@ import pytest
from typer.testing import CliRunner
from metagpt.logs import logger
from metagpt.startup import app
from metagpt.team import Team
runner = CliRunner()
@pytest.mark.asyncio
async def test_team():
async def test_empty_team():
# FIXME: we're now using "metagpt" cli, so the entrance should be replaced instead.
company = Team()
company.run_project("做一个基础搜索引擎,可以支持知识库")
history = await company.run(n_round=5)
history = await company.run(idea="Build a simple search system. I will upload my files later.")
logger.info(history)
# def test_startup():
# args = ["Make a 2048 game"]
# result = runner.invoke(app, args)
def test_startup():
args = ["Make a cli snake game"]
result = runner.invoke(app, args)
logger.info(result)
logger.info(result.output)
if __name__ == "__main__":

View file

@ -10,4 +10,4 @@ def test_team():
company = Team()
company.hire([ProjectManager()])
assert len(company.environment.roles) == 1
assert len(company.env.roles) == 1

View file

@ -56,7 +56,7 @@ class TestGetProjectRoot:
def test_any_to_str(self):
class Input(BaseModel):
x: Any
x: Any = None
want: str
inputs = [
@ -74,7 +74,7 @@ class TestGetProjectRoot:
def test_any_to_str_set(self):
class Input(BaseModel):
x: Any
x: Any = None
want: Set
inputs = [

View file

@ -21,8 +21,8 @@ from metagpt.utils.dependency_file import DependencyFile
async def test_dependency_file():
class Input(BaseModel):
x: Union[Path, str]
deps: Optional[Set[Union[Path, str]]]
key: Optional[Union[Path, str]]
deps: Optional[Set[Union[Path, str]]] = None
key: Optional[Union[Path, str]] = None
want: Set[str]
inputs = [