From 30de3b4d6498cd8ebf2d9efdeb9e6f0a5d861a5a Mon Sep 17 00:00:00 2001 From: yzlin Date: Wed, 31 Jan 2024 15:08:40 +0800 Subject: [PATCH] fix message init bug --- metagpt/schema.py | 2 +- tests/metagpt/test_schema.py | 274 +++++++++++++++++------------------ 2 files changed, 138 insertions(+), 138 deletions(-) diff --git a/metagpt/schema.py b/metagpt/schema.py index e6a447fba..08f97be94 100644 --- a/metagpt/schema.py +++ b/metagpt/schema.py @@ -327,7 +327,7 @@ class AIMessage(Message): """ def __init__(self, content: str): - super().__init__(content, "assistant") + super().__init__(content=content, role="assistant") class Task(BaseModel): diff --git a/tests/metagpt/test_schema.py b/tests/metagpt/test_schema.py index 17d2bb22c..a8fa27151 100644 --- a/tests/metagpt/test_schema.py +++ b/tests/metagpt/test_schema.py @@ -46,6 +46,143 @@ def test_messages(): assert all([i in text for i in roles]) +def test_message(): + Message("a", role="v1") + + m = Message(content="a", role="v1") + v = m.dump() + d = json.loads(v) + assert d + assert d.get("content") == "a" + assert d.get("role") == "v1" + m.role = "v2" + v = m.dump() + assert v + m = Message.load(v) + assert m.content == "a" + assert m.role == "v2" + + m = Message(content="a", role="b", cause_by="c", x="d", send_to="c") + assert m.content == "a" + assert m.role == "b" + assert m.send_to == {"c"} + assert m.cause_by == "c" + m.sent_from = "e" + assert m.sent_from == "e" + + m.cause_by = "Message" + assert m.cause_by == "Message" + m.cause_by = Action + assert m.cause_by == any_to_str(Action) + m.cause_by = Action() + assert m.cause_by == any_to_str(Action) + m.content = "b" + assert m.content == "b" + + +def test_routes(): + m = Message(content="a", role="b", cause_by="c", x="d", send_to="c") + m.send_to = "b" + assert m.send_to == {"b"} + m.send_to = {"e", Action} + assert m.send_to == {"e", any_to_str(Action)} + + +def test_message_serdeser(): + out_mapping = {"field3": (str, ...), "field4": (list[str], ...)} + out_data = {"field3": "field3 value3", "field4": ["field4 value1", "field4 value2"]} + ic_obj = ActionNode.create_model_class("code", out_mapping) + + message = Message(content="code", instruct_content=ic_obj(**out_data), role="engineer", cause_by=WriteCode) + message_dict = message.model_dump() + assert message_dict["cause_by"] == "metagpt.actions.write_code.WriteCode" + assert message_dict["instruct_content"] == { + "class": "code", + "mapping": {"field3": "(, Ellipsis)", "field4": "(list[str], Ellipsis)"}, + "value": {"field3": "field3 value3", "field4": ["field4 value1", "field4 value2"]}, + } + new_message = Message.model_validate(message_dict) + assert new_message.content == message.content + assert new_message.instruct_content.model_dump() == message.instruct_content.model_dump() + assert new_message.instruct_content == message.instruct_content # TODO + assert new_message.cause_by == message.cause_by + assert new_message.instruct_content.field3 == out_data["field3"] + + message = Message(content="code") + message_dict = message.model_dump() + new_message = Message(**message_dict) + assert new_message.instruct_content is None + assert new_message.cause_by == "metagpt.actions.add_requirement.UserRequirement" + assert not Message.load("{") + + +def test_document(): + doc = Document(root_path="a", filename="b", content="c") + meta_doc = doc.get_meta() + assert doc.root_path == meta_doc.root_path + assert doc.filename == meta_doc.filename + assert meta_doc.content == "" + + +@pytest.mark.asyncio +async def test_message_queue(): + mq = MessageQueue() + val = await mq.dump() + assert val == "[]" + mq.push(Message(content="1")) + mq.push(Message(content="2中文测试aaa")) + msg = mq.pop() + assert msg.content == "1" + + val = await mq.dump() + assert val + new_mq = MessageQueue.load(val) + assert new_mq.pop_all() == mq.pop_all() + + +@pytest.mark.parametrize( + ("file_list", "want"), + [ + ( + [f"{SYSTEM_DESIGN_FILE_REPO}/a.txt", f"{TASK_FILE_REPO}/b.txt"], + CodeSummarizeContext( + design_filename=f"{SYSTEM_DESIGN_FILE_REPO}/a.txt", task_filename=f"{TASK_FILE_REPO}/b.txt" + ), + ) + ], +) +def test_CodeSummarizeContext(file_list, want): + ctx = CodeSummarizeContext.loads(file_list) + assert ctx == want + m = {ctx: ctx} + assert want in m + + +def test_class_view(): + attr_a = ClassAttribute(name="a", value_type="int", default_value="0", visibility="+", abstraction=True) + assert attr_a.get_mermaid(align=1) == "\t+int a=0*" + attr_b = ClassAttribute(name="b", value_type="str", default_value="0", visibility="#", static=True) + assert attr_b.get_mermaid(align=0) == '#str b="0"$' + class_view = ClassView(name="A") + class_view.attributes = [attr_a, attr_b] + + method_a = ClassMethod(name="run", visibility="+", abstraction=True) + assert method_a.get_mermaid(align=1) == "\t+run()*" + method_b = ClassMethod( + name="_test", + visibility="#", + static=True, + args=[ClassAttribute(name="a", value_type="str"), ClassAttribute(name="b", value_type="int")], + return_type="str", + ) + assert method_b.get_mermaid(align=0) == "#_test(str a,int b):str$" + class_view.methods = [method_a, method_b] + assert ( + class_view.get_mermaid(align=0) + == 'class A{\n\t+int a=0*\n\t#str b="0"$\n\t+run()*\n\t#_test(str a,int b):str$\n}\n' + ) + + class TestPlan: def test_add_tasks_ordering(self): plan = Plan(goal="") @@ -214,142 +351,5 @@ class TestPlan: assert plan.current_task_id == "2" -def test_message(): - Message("a", role="v1") - - m = Message(content="a", role="v1") - v = m.dump() - d = json.loads(v) - assert d - assert d.get("content") == "a" - assert d.get("role") == "v1" - m.role = "v2" - v = m.dump() - assert v - m = Message.load(v) - assert m.content == "a" - assert m.role == "v2" - - m = Message(content="a", role="b", cause_by="c", x="d", send_to="c") - assert m.content == "a" - assert m.role == "b" - assert m.send_to == {"c"} - assert m.cause_by == "c" - m.sent_from = "e" - assert m.sent_from == "e" - - m.cause_by = "Message" - assert m.cause_by == "Message" - m.cause_by = Action - assert m.cause_by == any_to_str(Action) - m.cause_by = Action() - assert m.cause_by == any_to_str(Action) - m.content = "b" - assert m.content == "b" - - -def test_routes(): - m = Message(content="a", role="b", cause_by="c", x="d", send_to="c") - m.send_to = "b" - assert m.send_to == {"b"} - m.send_to = {"e", Action} - assert m.send_to == {"e", any_to_str(Action)} - - -def test_message_serdeser(): - out_mapping = {"field3": (str, ...), "field4": (list[str], ...)} - out_data = {"field3": "field3 value3", "field4": ["field4 value1", "field4 value2"]} - ic_obj = ActionNode.create_model_class("code", out_mapping) - - message = Message(content="code", instruct_content=ic_obj(**out_data), role="engineer", cause_by=WriteCode) - message_dict = message.model_dump() - assert message_dict["cause_by"] == "metagpt.actions.write_code.WriteCode" - assert message_dict["instruct_content"] == { - "class": "code", - "mapping": {"field3": "(, Ellipsis)", "field4": "(list[str], Ellipsis)"}, - "value": {"field3": "field3 value3", "field4": ["field4 value1", "field4 value2"]}, - } - new_message = Message.model_validate(message_dict) - assert new_message.content == message.content - assert new_message.instruct_content.model_dump() == message.instruct_content.model_dump() - assert new_message.instruct_content == message.instruct_content # TODO - assert new_message.cause_by == message.cause_by - assert new_message.instruct_content.field3 == out_data["field3"] - - message = Message(content="code") - message_dict = message.model_dump() - new_message = Message(**message_dict) - assert new_message.instruct_content is None - assert new_message.cause_by == "metagpt.actions.add_requirement.UserRequirement" - assert not Message.load("{") - - -def test_document(): - doc = Document(root_path="a", filename="b", content="c") - meta_doc = doc.get_meta() - assert doc.root_path == meta_doc.root_path - assert doc.filename == meta_doc.filename - assert meta_doc.content == "" - - -@pytest.mark.asyncio -async def test_message_queue(): - mq = MessageQueue() - val = await mq.dump() - assert val == "[]" - mq.push(Message(content="1")) - mq.push(Message(content="2中文测试aaa")) - msg = mq.pop() - assert msg.content == "1" - - val = await mq.dump() - assert val - new_mq = MessageQueue.load(val) - assert new_mq.pop_all() == mq.pop_all() - - -@pytest.mark.parametrize( - ("file_list", "want"), - [ - ( - [f"{SYSTEM_DESIGN_FILE_REPO}/a.txt", f"{TASK_FILE_REPO}/b.txt"], - CodeSummarizeContext( - design_filename=f"{SYSTEM_DESIGN_FILE_REPO}/a.txt", task_filename=f"{TASK_FILE_REPO}/b.txt" - ), - ) - ], -) -def test_CodeSummarizeContext(file_list, want): - ctx = CodeSummarizeContext.loads(file_list) - assert ctx == want - m = {ctx: ctx} - assert want in m - - -def test_class_view(): - attr_a = ClassAttribute(name="a", value_type="int", default_value="0", visibility="+", abstraction=True) - assert attr_a.get_mermaid(align=1) == "\t+int a=0*" - attr_b = ClassAttribute(name="b", value_type="str", default_value="0", visibility="#", static=True) - assert attr_b.get_mermaid(align=0) == '#str b="0"$' - class_view = ClassView(name="A") - class_view.attributes = [attr_a, attr_b] - - method_a = ClassMethod(name="run", visibility="+", abstraction=True) - assert method_a.get_mermaid(align=1) == "\t+run()*" - method_b = ClassMethod( - name="_test", - visibility="#", - static=True, - args=[ClassAttribute(name="a", value_type="str"), ClassAttribute(name="b", value_type="int")], - return_type="str", - ) - assert method_b.get_mermaid(align=0) == "#_test(str a,int b):str$" - class_view.methods = [method_a, method_b] - assert ( - class_view.get_mermaid(align=0) - == 'class A{\n\t+int a=0*\n\t#str b="0"$\n\t+run()*\n\t#_test(str a,int b):str$\n}\n' - ) - - if __name__ == "__main__": pytest.main([__file__, "-s"])