diff --git a/examples/debate.py b/examples/debate.py index 8f5012d66..630f78cd8 100644 --- a/examples/debate.py +++ b/examples/debate.py @@ -15,7 +15,6 @@ from metagpt.logs import logger from metagpt.roles import Role from metagpt.schema import Message from metagpt.software_company import SoftwareCompany -from metagpt.utils.common import any_to_str_set class ShoutOut(Action): @@ -66,7 +65,7 @@ class Trump(Role): async def _act(self) -> Message: logger.info(f"{self._setting}: ready to {self._rc.todo}") - msg_history = self._rc.memory.get_by_actions(any_to_str_set([ShoutOut])) + msg_history = self._rc.memory.get_by_actions([ShoutOut]) context = [] for m in msg_history: context.append(str(m)) @@ -108,7 +107,7 @@ class Biden(Role): async def _act(self) -> Message: logger.info(f"{self._setting}: ready to {self._rc.todo}") - msg_history = self._rc.memory.get_by_actions(any_to_str_set([BossRequirement, ShoutOut])) + msg_history = self._rc.memory.get_by_actions([BossRequirement, ShoutOut]) context = [] for m in msg_history: context.append(str(m)) diff --git a/metagpt/memory/memory.py b/metagpt/memory/memory.py index 71d999049..53b65fcf7 100644 --- a/metagpt/memory/memory.py +++ b/metagpt/memory/memory.py @@ -10,6 +10,7 @@ from collections import defaultdict from typing import Iterable, Set from metagpt.schema import Message +from metagpt.utils.common import any_to_str, any_to_str_set class Memory: @@ -73,14 +74,16 @@ class Memory: news.append(i) return news - def get_by_action(self, action: str) -> list[Message]: + def get_by_action(self, action) -> list[Message]: """Return all messages triggered by a specified Action""" - return self.index[action] + index = any_to_str(action) + return self.index[index] def get_by_actions(self, actions: Set) -> list[Message]: """Return all messages triggered by specified Actions""" rsp = [] - for action in actions: + indices = any_to_str_set(actions) + for action in indices: if action not in self.index: continue rsp += self.index[action] diff --git a/metagpt/roles/engineer.py b/metagpt/roles/engineer.py index 742e00cc8..960f9c0f3 100644 --- a/metagpt/roles/engineer.py +++ b/metagpt/roles/engineer.py @@ -21,7 +21,7 @@ from metagpt.const import WORKSPACE_ROOT from metagpt.logs import logger from metagpt.roles import Role from metagpt.schema import Message -from metagpt.utils.common import CodeParser, any_to_str, any_to_str_set +from metagpt.utils.common import CodeParser, any_to_str from metagpt.utils.special_tokens import FILENAME_CODE_SEP, MSG_SEP @@ -102,7 +102,7 @@ class Engineer(Role): return CodeParser.parse_str(block="Python package name", text=system_design_msg.content) def get_workspace(self) -> Path: - msg = self._rc.memory.get_by_action(any_to_str(WriteDesign))[-1] + msg = self._rc.memory.get_by_action(WriteDesign)[-1] if not msg: return WORKSPACE_ROOT / "src" workspace = self.parse_workspace(msg) @@ -130,7 +130,7 @@ class Engineer(Role): todo_coros = [] for todo in self.todos: todo_coro = WriteCode().run( - context=self._rc.memory.get_by_actions(any_to_str_set([WriteTasks, WriteDesign])), + context=self._rc.memory.get_by_actions([WriteTasks, WriteDesign]), filename=todo, ) todo_coros.append(todo_coro) @@ -185,7 +185,7 @@ class Engineer(Role): TODO: The goal is not to need it. After clear task decomposition, based on the design idea, you should be able to write a single file without needing other codes. If you can't, it means you need a clearer definition. This is the key to writing longer code. """ context = [] - msg = self._rc.memory.get_by_actions(any_to_str_set([WriteDesign, WriteTasks, WriteCode])) + msg = self._rc.memory.get_by_actions([WriteDesign, WriteTasks, WriteCode]) for m in msg: context.append(m.content) context_str = "\n".join(context) @@ -240,8 +240,7 @@ class Engineer(Role): async def _think(self) -> None: # In asynchronous scenarios, first check if the required messages are ready. - filters = {any_to_str(WriteTasks)} - msgs = self._rc.memory.get_by_actions(filters) + msgs = self._rc.memory.get_by_actions({WriteTasks}) if not msgs: self._rc.todo = None return diff --git a/metagpt/roles/qa_engineer.py b/metagpt/roles/qa_engineer.py index 38fb5a24b..9495e1a12 100644 --- a/metagpt/roles/qa_engineer.py +++ b/metagpt/roles/qa_engineer.py @@ -22,12 +22,7 @@ from metagpt.const import WORKSPACE_ROOT from metagpt.logs import logger from metagpt.roles import Role from metagpt.schema import Message -from metagpt.utils.common import ( - CodeParser, - any_to_str_set, - get_class_name, - parse_recipient, -) +from metagpt.utils.common import CodeParser, any_to_str, any_to_str_set, parse_recipient from metagpt.utils.special_tokens import FILENAME_CODE_SEP, MSG_SEP @@ -55,7 +50,7 @@ class QaEngineer(Role): return CodeParser.parse_str(block="Python package name", text=system_design_msg.content) def get_workspace(self, return_proj_dir=True) -> Path: - msg = self._rc.memory.get_by_action(get_class_name(WriteDesign))[-1] + msg = self._rc.memory.get_by_action(WriteDesign)[-1] if not msg: return WORKSPACE_ROOT / "src" workspace = self.parse_workspace(msg) @@ -104,7 +99,7 @@ class QaEngineer(Role): msg = Message( content=str(file_info), role=self.profile, - cause_by=WriteTest, + cause_by=any_to_str(WriteTest), sent_from=self.profile, send_to=self.profile, )