refactor: @cause_by.setter

This commit is contained in:
莘权 马 2023-11-04 16:46:32 +08:00
parent 1febf168e7
commit d9775037b6
21 changed files with 73 additions and 123 deletions

View file

@ -19,7 +19,7 @@ from metagpt.schema import Message
async def test_write_prd():
product_manager = ProductManager()
requirements = "开发一个基于大语言模型与私有知识库的搜索引擎,希望可以基于大语言模型进行搜索总结"
prd = await product_manager.run(Message(content=requirements, cause_by=BossRequirement.get_class_name()))
prd = await product_manager.run(Message(content=requirements, cause_by=BossRequirement))
logger.info(requirements)
logger.info(prd)

View file

@ -19,24 +19,24 @@ def test_ltm_search():
assert len(openai_api_key) > 20
role_id = "UTUserLtm(Product Manager)"
rc = RoleContext(watch=[BossRequirement.get_class_name()])
rc = RoleContext(watch=[BossRequirement])
ltm = LongTermMemory()
ltm.recover_memory(role_id, rc)
idea = "Write a cli snake game"
message = Message(role="BOSS", content=idea, cause_by=BossRequirement.get_class_name())
message = Message(role="BOSS", content=idea, cause_by=BossRequirement)
news = ltm.find_news([message])
assert len(news) == 1
ltm.add(message)
sim_idea = "Write a game of cli snake"
sim_message = Message(role="BOSS", content=sim_idea, cause_by=BossRequirement.get_class_name())
sim_message = Message(role="BOSS", content=sim_idea, cause_by=BossRequirement)
news = ltm.find_news([sim_message])
assert len(news) == 0
ltm.add(sim_message)
new_idea = "Write a 2048 web game"
new_message = Message(role="BOSS", content=new_idea, cause_by=BossRequirement.get_class_name())
new_message = Message(role="BOSS", content=new_idea, cause_by=BossRequirement)
news = ltm.find_news([new_message])
assert len(news) == 1
ltm.add(new_message)
@ -52,7 +52,7 @@ def test_ltm_search():
assert len(news) == 0
new_idea = "Write a Battle City"
new_message = Message(role="BOSS", content=new_idea, cause_by=BossRequirement.get_class_name())
new_message = Message(role="BOSS", content=new_idea, cause_by=BossRequirement)
news = ltm_new.find_news([new_message])
assert len(news) == 1

View file

@ -18,7 +18,7 @@ from metagpt.schema import Message
def test_idea_message():
idea = "Write a cli snake game"
role_id = "UTUser1(Product Manager)"
message = Message(role="BOSS", content=idea, cause_by=BossRequirement.get_class_name())
message = Message(role="BOSS", content=idea, cause_by=BossRequirement)
memory_storage: MemoryStorage = MemoryStorage()
messages = memory_storage.recover_memory(role_id)
@ -28,12 +28,12 @@ def test_idea_message():
assert memory_storage.is_initialized is True
sim_idea = "Write a game of cli snake"
sim_message = Message(role="BOSS", content=sim_idea, cause_by=BossRequirement.get_class_name())
sim_message = Message(role="BOSS", content=sim_idea, cause_by=BossRequirement)
new_messages = memory_storage.search(sim_message)
assert len(new_messages) == 0 # similar, return []
new_idea = "Write a 2048 web game"
new_message = Message(role="BOSS", content=new_idea, cause_by=BossRequirement.get_class_name())
new_message = Message(role="BOSS", content=new_idea, cause_by=BossRequirement)
new_messages = memory_storage.search(new_message)
assert new_messages[0].content == message.content
@ -49,7 +49,7 @@ def test_actionout_message():
role_id = "UTUser2(Architect)"
content = "The boss has requested the creation of a command-line interface (CLI) snake game"
message = Message(
content=content, instruct_content=ic_obj(**out_data), role="user", cause_by=WritePRD.get_class_name()
content=content, instruct_content=ic_obj(**out_data), role="user", cause_by=WritePRD
) # WritePRD as test action
memory_storage: MemoryStorage = MemoryStorage()
@ -60,16 +60,12 @@ def test_actionout_message():
assert memory_storage.is_initialized is True
sim_conent = "The request is command-line interface (CLI) snake game"
sim_message = Message(
content=sim_conent, instruct_content=ic_obj(**out_data), role="user", cause_by=WritePRD.get_class_name()
)
sim_message = Message(content=sim_conent, instruct_content=ic_obj(**out_data), role="user", cause_by=WritePRD)
new_messages = memory_storage.search(sim_message)
assert len(new_messages) == 0 # similar, return []
new_conent = "Incorporate basic features of a snake game such as scoring and increasing difficulty"
new_message = Message(
content=new_conent, instruct_content=ic_obj(**out_data), role="user", cause_by=WritePRD.get_class_name()
)
new_message = Message(content=new_conent, instruct_content=ic_obj(**out_data), role="user", cause_by=WritePRD)
new_messages = memory_storage.search(new_message)
assert new_messages[0].content == message.content

View file

@ -26,7 +26,7 @@ async def test_action_planner():
role.import_skill(TimeSkill(), "time")
role.import_skill(TextSkill(), "text")
task = "What is the sum of 110 and 990?"
role.put_message(Message(content=task, cause_by=BossRequirement.get_class_name()))
role.put_message(Message(content=task, cause_by=BossRequirement))
await role._observe()
await role._think() # it will choose mathskill.Add
assert "1100" == (await role._act()).content

View file

@ -29,7 +29,7 @@ async def test_basic_planner():
role.import_semantic_skill_from_directory(SKILL_DIRECTORY, "WriterSkill")
role.import_skill(TextSkill(), "TextSkill")
# using BasicPlanner
role.put_message(Message(content=task, cause_by=BossRequirement.get_class_name()))
role.put_message(Message(content=task, cause_by=BossRequirement))
await role._observe()
await role._think()
# assuming sk_agent will think he needs WriterSkill.Brainstorm and WriterSkill.Translate

View file

@ -254,7 +254,7 @@ a = 'a'
class MockMessages:
req = Message(role="Boss", content=BOSS_REQUIREMENT, cause_by=BossRequirement.get_class_name())
prd = Message(role="Product Manager", content=PRD, cause_by=WritePRD.get_class_name())
system_design = Message(role="Architect", content=SYSTEM_DESIGN, cause_by=WriteDesign.get_class_name())
tasks = Message(role="Project Manager", content=TASKS, cause_by=WriteTasks.get_class_name())
req = Message(role="Boss", content=BOSS_REQUIREMENT, cause_by=BossRequirement)
prd = Message(role="Product Manager", content=PRD, cause_by=WritePRD)
system_design = Message(role="Architect", content=SYSTEM_DESIGN, cause_by=WriteDesign)
tasks = Message(role="Project Manager", content=TASKS, cause_by=WriteTasks)

View file

@ -51,7 +51,7 @@ async def test_publish_and_process_message(env: Environment):
env.add_roles([product_manager, architect])
env.set_manager(Manager())
env.publish_message(Message(role="BOSS", content="需要一个基于LLM做总结的搜索引擎", cause_by=BossRequirement.get_class_name()))
env.publish_message(Message(role="BOSS", content="需要一个基于LLM做总结的搜索引擎", cause_by=BossRequirement))
await env.run(k=2)
logger.info(f"{env.history=}")

View file

@ -13,6 +13,7 @@ import pytest
from metagpt.actions import Action
from metagpt.schema import AIMessage, Message, Routes, SystemMessage, UserMessage
from metagpt.utils.common import get_class_name
@pytest.mark.asyncio
@ -54,9 +55,9 @@ def test_message():
m.cause_by = "Message"
assert m.cause_by == "Message"
m.cause_by = Action
assert m.cause_by == Action.get_class_name()
assert m.cause_by == get_class_name(Action)
m.cause_by = Action()
assert m.cause_by == Action.get_class_name()
assert m.cause_by == get_class_name(Action)
m.content = "b"
assert m.content == "b"

View file

@ -1,28 +0,0 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time : 2023-11-1
@Author : mashenquan
@File : test_named.py
"""
import pytest
from metagpt.utils.named import Named
@pytest.mark.asyncio
async def test_suite():
class A(Named):
pass
class B(A):
pass
assert A.get_class_name() == "tests.metagpt.utils.test_named.A"
assert A().get_object_name() == "tests.metagpt.utils.test_named.A"
assert B.get_class_name() == "tests.metagpt.utils.test_named.B"
assert B().get_object_name() == "tests.metagpt.utils.test_named.B"
if __name__ == "__main__":
pytest.main([__file__, "-s"])

View file

@ -59,7 +59,7 @@ def test_serialize_and_deserialize_message():
ic_obj = ActionOutput.create_model_class("prd", out_mapping)
message = Message(
content="prd demand", instruct_content=ic_obj(**out_data), role="user", cause_by=WritePRD.get_class_name()
content="prd demand", instruct_content=ic_obj(**out_data), role="user", cause_by=WritePRD
) # WritePRD as test action
message_ser = serialize_message(message)