diff --git a/examples/agent_creator.py b/examples/agent_creator.py index fe883bdf4..bd58840ce 100644 --- a/examples/agent_creator.py +++ b/examples/agent_creator.py @@ -61,7 +61,7 @@ class AgentCreator(Role): def __init__(self, **kwargs): super().__init__(**kwargs) - self.add_actions([CreateAgent]) + self.set_actions([CreateAgent]) async def _act(self) -> Message: logger.info(f"{self._setting}: to do {self.rc.todo}({self.rc.todo.name})") diff --git a/examples/build_customized_agent.py b/examples/build_customized_agent.py index a0c8ddfb3..cfe264b47 100644 --- a/examples/build_customized_agent.py +++ b/examples/build_customized_agent.py @@ -57,7 +57,7 @@ class SimpleCoder(Role): def __init__(self, **kwargs): super().__init__(**kwargs) - self.add_actions([SimpleWriteCode]) + self.set_actions([SimpleWriteCode]) async def _act(self) -> Message: logger.info(f"{self._setting}: to do {self.rc.todo}({self.rc.todo.name})") @@ -76,7 +76,7 @@ class RunnableCoder(Role): def __init__(self, **kwargs): super().__init__(**kwargs) - self.add_actions([SimpleWriteCode, SimpleRunCode]) + self.set_actions([SimpleWriteCode, SimpleRunCode]) self._set_react_mode(react_mode=RoleReactMode.BY_ORDER.value) async def _act(self) -> Message: diff --git a/examples/build_customized_multi_agents.py b/examples/build_customized_multi_agents.py index aceb3f2ab..296323cea 100644 --- a/examples/build_customized_multi_agents.py +++ b/examples/build_customized_multi_agents.py @@ -46,7 +46,7 @@ class SimpleCoder(Role): def __init__(self, **kwargs): super().__init__(**kwargs) self._watch([UserRequirement]) - self.add_actions([SimpleWriteCode]) + self.set_actions([SimpleWriteCode]) class SimpleWriteTest(Action): @@ -75,7 +75,7 @@ class SimpleTester(Role): def __init__(self, **kwargs): super().__init__(**kwargs) - self.add_actions([SimpleWriteTest]) + self.set_actions([SimpleWriteTest]) # self._watch([SimpleWriteCode]) self._watch([SimpleWriteCode, SimpleWriteReview]) # feel free to try this too @@ -114,7 +114,7 @@ class SimpleReviewer(Role): def __init__(self, **kwargs): super().__init__(**kwargs) - self.add_actions([SimpleWriteReview]) + self.set_actions([SimpleWriteReview]) self._watch([SimpleWriteTest]) diff --git a/examples/debate.py b/examples/debate.py index b47eba3cd..72ab8796d 100644 --- a/examples/debate.py +++ b/examples/debate.py @@ -49,7 +49,7 @@ class Debator(Role): def __init__(self, **data: Any): super().__init__(**data) - self.add_actions([SpeakAloud]) + self.set_actions([SpeakAloud]) self._watch([UserRequirement, SpeakAloud]) async def _observe(self) -> int: diff --git a/metagpt/roles/architect.py b/metagpt/roles/architect.py index a22a1c926..166f8cfd0 100644 --- a/metagpt/roles/architect.py +++ b/metagpt/roles/architect.py @@ -33,7 +33,7 @@ class Architect(Role): def __init__(self, **kwargs) -> None: super().__init__(**kwargs) # Initialize actions specific to the Architect role - self.add_actions([WriteDesign]) + self.set_actions([WriteDesign]) # Set events or actions the Architect should watch or be aware of self._watch({WritePRD}) diff --git a/metagpt/roles/engineer.py b/metagpt/roles/engineer.py index 0d277813e..bc56ca813 100644 --- a/metagpt/roles/engineer.py +++ b/metagpt/roles/engineer.py @@ -84,7 +84,7 @@ class Engineer(Role): def __init__(self, **kwargs) -> None: super().__init__(**kwargs) - self.add_actions([WriteCode]) + self.set_actions([WriteCode]) self._watch([WriteTasks, SummarizeCode, WriteCode, WriteCodeReview, FixBug]) self.code_todos = [] self.summarize_todos = [] diff --git a/metagpt/roles/invoice_ocr_assistant.py b/metagpt/roles/invoice_ocr_assistant.py index de7d3f8a3..a39a48b97 100644 --- a/metagpt/roles/invoice_ocr_assistant.py +++ b/metagpt/roles/invoice_ocr_assistant.py @@ -60,7 +60,7 @@ class InvoiceOCRAssistant(Role): def __init__(self, **kwargs): super().__init__(**kwargs) - self.add_actions([InvoiceOCR]) + self.set_actions([InvoiceOCR]) self._set_react_mode(react_mode=RoleReactMode.BY_ORDER.value) async def _act(self) -> Message: @@ -82,10 +82,10 @@ class InvoiceOCRAssistant(Role): resp = await todo.run(file_path) if len(resp) == 1: # Single file support for questioning based on OCR recognition results - self.add_actions([GenerateTable, ReplyQuestion]) + self.set_actions([GenerateTable, ReplyQuestion]) self.orc_data = resp[0] else: - self.add_actions([GenerateTable]) + self.set_actions([GenerateTable]) self.set_todo(None) content = INVOICE_OCR_SUCCESS diff --git a/metagpt/roles/product_manager.py b/metagpt/roles/product_manager.py index a35dcb3a0..ec80d7bb0 100644 --- a/metagpt/roles/product_manager.py +++ b/metagpt/roles/product_manager.py @@ -33,7 +33,7 @@ class ProductManager(Role): def __init__(self, **kwargs) -> None: super().__init__(**kwargs) - self.add_actions([PrepareDocuments, WritePRD]) + self.set_actions([PrepareDocuments, WritePRD]) self._watch([UserRequirement, PrepareDocuments]) self.todo_action = any_to_name(PrepareDocuments) diff --git a/metagpt/roles/project_manager.py b/metagpt/roles/project_manager.py index 7fa16b1e5..422d2889b 100644 --- a/metagpt/roles/project_manager.py +++ b/metagpt/roles/project_manager.py @@ -33,5 +33,5 @@ class ProjectManager(Role): def __init__(self, **kwargs) -> None: super().__init__(**kwargs) - self.add_actions([WriteTasks]) + self.set_actions([WriteTasks]) self._watch([WriteDesign]) diff --git a/metagpt/roles/qa_engineer.py b/metagpt/roles/qa_engineer.py index 9483ea260..783fde9b6 100644 --- a/metagpt/roles/qa_engineer.py +++ b/metagpt/roles/qa_engineer.py @@ -44,7 +44,7 @@ class QaEngineer(Role): # FIXME: a bit hack here, only init one action to circumvent _think() logic, # will overwrite _think() in future updates - self.add_actions([WriteTest]) + self.set_actions([WriteTest]) self._watch([SummarizeCode, WriteTest, RunCode, DebugError]) self.test_round = 0 diff --git a/metagpt/roles/researcher.py b/metagpt/roles/researcher.py index e877778f6..137cfdb4c 100644 --- a/metagpt/roles/researcher.py +++ b/metagpt/roles/researcher.py @@ -34,7 +34,7 @@ class Researcher(Role): def __init__(self, **kwargs): super().__init__(**kwargs) - self.add_actions( + self.set_actions( [CollectLinks(name=self.name), WebBrowseAndSummarize(name=self.name), ConductResearch(name=self.name)] ) self._set_react_mode(react_mode=RoleReactMode.BY_ORDER.value) diff --git a/metagpt/roles/role.py b/metagpt/roles/role.py index 72ee1175b..e467ef83e 100644 --- a/metagpt/roles/role.py +++ b/metagpt/roles/role.py @@ -222,16 +222,17 @@ class Role(SerializationMixin, ContextMixin, BaseModel): def _init_action_system_message(self, action: Action): action.set_prefix(self._get_prefix()) - def add_action(self, action: Action): + def set_action(self, action: Action): """Add action to the role.""" - self.add_actions([action]) + self.set_actions([action]) - def add_actions(self, actions: list[Union[Action, Type[Action]]]): + def set_actions(self, actions: list[Union[Action, Type[Action]]]): """Add actions to the role. Args: actions: list of Action classes or instances """ + self._reset() for action in actions: if not isinstance(action, Action): i = action(name="", llm=self.llm) diff --git a/metagpt/roles/sales.py b/metagpt/roles/sales.py index 8da930888..7929ce7fe 100644 --- a/metagpt/roles/sales.py +++ b/metagpt/roles/sales.py @@ -38,5 +38,5 @@ class Sales(Role): action = SearchAndSummarize(name="", engine=SearchEngineType.CUSTOM_ENGINE, search_func=store.asearch) else: action = SearchAndSummarize() - self.add_actions([action]) + self.set_actions([action]) self._watch([UserRequirement]) diff --git a/metagpt/roles/searcher.py b/metagpt/roles/searcher.py index f37bd4704..e0d2dbb65 100644 --- a/metagpt/roles/searcher.py +++ b/metagpt/roles/searcher.py @@ -48,12 +48,12 @@ class Searcher(Role): engine (SearchEngineType): The type of search engine to use. """ super().__init__(**kwargs) - self.add_actions([SearchAndSummarize(engine=self.engine)]) + self.set_actions([SearchAndSummarize(engine=self.engine)]) def set_search_func(self, search_func): """Sets a custom search function for the searcher.""" action = SearchAndSummarize(name="", engine=SearchEngineType.CUSTOM_ENGINE, search_func=search_func) - self.add_actions([action]) + self.set_actions([action]) async def _act_sp(self) -> Message: """Performs the search action in a single process.""" diff --git a/metagpt/roles/sk_agent.py b/metagpt/roles/sk_agent.py index 200ed5051..71df55fcc 100644 --- a/metagpt/roles/sk_agent.py +++ b/metagpt/roles/sk_agent.py @@ -49,7 +49,7 @@ class SkAgent(Role): def __init__(self, **data: Any) -> None: """Initializes the Engineer role with given attributes.""" super().__init__(**data) - self.add_actions([ExecuteTask()]) + self.set_actions([ExecuteTask()]) self._watch([UserRequirement]) self.kernel = make_sk_kernel() diff --git a/metagpt/roles/teacher.py b/metagpt/roles/teacher.py index 9206d5f80..d47f4af5b 100644 --- a/metagpt/roles/teacher.py +++ b/metagpt/roles/teacher.py @@ -47,7 +47,7 @@ class Teacher(Role): for topic in TeachingPlanBlock.TOPICS: act = WriteTeachingPlanPart(i_context=self.rc.news[0].content, topic=topic, llm=self.llm) actions.append(act) - self.add_actions(actions) + self.set_actions(actions) if self.rc.todo is None: self._set_state(0) diff --git a/metagpt/roles/tutorial_assistant.py b/metagpt/roles/tutorial_assistant.py index d296c7b3f..6cf3a6469 100644 --- a/metagpt/roles/tutorial_assistant.py +++ b/metagpt/roles/tutorial_assistant.py @@ -40,7 +40,7 @@ class TutorialAssistant(Role): def __init__(self, **kwargs): super().__init__(**kwargs) - self.add_actions([WriteDirectory(language=self.language)]) + self.set_actions([WriteDirectory(language=self.language)]) self._set_react_mode(react_mode=RoleReactMode.BY_ORDER.value) async def _handle_directory(self, titles: Dict) -> Message: @@ -63,7 +63,7 @@ class TutorialAssistant(Role): directory += f"- {key}\n" for second_dir in first_dir[key]: directory += f" - {second_dir}\n" - self.add_actions(actions) + self.set_actions(actions) async def _act(self) -> Message: """Perform an action as determined by the role. diff --git a/tests/metagpt/serialize_deserialize/test_serdeser_base.py b/tests/metagpt/serialize_deserialize/test_serdeser_base.py index c97cea597..62ab26d72 100644 --- a/tests/metagpt/serialize_deserialize/test_serdeser_base.py +++ b/tests/metagpt/serialize_deserialize/test_serdeser_base.py @@ -67,7 +67,7 @@ class RoleA(Role): def __init__(self, **kwargs): super(RoleA, self).__init__(**kwargs) - self.add_actions([ActionPass]) + self.set_actions([ActionPass]) self._watch([UserRequirement]) @@ -79,7 +79,7 @@ class RoleB(Role): def __init__(self, **kwargs): super(RoleB, self).__init__(**kwargs) - self.add_actions([ActionOK, ActionRaise]) + self.set_actions([ActionOK, ActionRaise]) self._watch([ActionPass]) self.rc.react_mode = RoleReactMode.BY_ORDER @@ -92,7 +92,7 @@ class RoleC(Role): def __init__(self, **kwargs): super(RoleC, self).__init__(**kwargs) - self.add_actions([ActionOK, ActionRaise]) + self.set_actions([ActionOK, ActionRaise]) self._watch([UserRequirement]) self.rc.react_mode = RoleReactMode.BY_ORDER self.rc.memory.ignore_id = True diff --git a/tests/metagpt/test_role.py b/tests/metagpt/test_role.py index c67a8ad8a..351ba9051 100644 --- a/tests/metagpt/test_role.py +++ b/tests/metagpt/test_role.py @@ -33,7 +33,7 @@ class MockAction(Action): class MockRole(Role): def __init__(self, name="", profile="", goal="", constraints="", desc=""): super().__init__(name=name, profile=profile, goal=goal, constraints=constraints, desc=desc) - self.add_actions([MockAction()]) + self.set_actions([MockAction()]) def test_basic(): @@ -111,7 +111,7 @@ async def test_send_to(): def test_init_action(): role = Role() - role.add_actions([MockAction, MockAction]) + role.set_actions([MockAction, MockAction]) assert len(role.actions) == 2 @@ -127,7 +127,7 @@ async def test_recover(): role.publish_message(None) role.llm = mock_llm - role.add_actions([MockAction, MockAction]) + role.set_actions([MockAction, MockAction]) role.recovered = True role.latest_observed_msg = Message(content="recover_test") role.rc.state = 0 @@ -144,7 +144,7 @@ async def test_think_act(): mock_llm.aask.side_effect = ["ok"] role = Role() - role.add_actions([MockAction]) + role.set_actions([MockAction]) await role.think() role.rc.memory.add(Message("run")) assert len(role.get_memories()) == 1