refactor: get_by_action(s)

This commit is contained in:
莘权 马 2023-11-10 16:15:07 +08:00
parent 83a5e03b72
commit a61f3f80e9
4 changed files with 16 additions and 20 deletions

View file

@ -15,7 +15,6 @@ from metagpt.logs import logger
from metagpt.roles import Role
from metagpt.schema import Message
from metagpt.software_company import SoftwareCompany
from metagpt.utils.common import any_to_str_set
class ShoutOut(Action):
@ -66,7 +65,7 @@ class Trump(Role):
async def _act(self) -> Message:
logger.info(f"{self._setting}: ready to {self._rc.todo}")
msg_history = self._rc.memory.get_by_actions(any_to_str_set([ShoutOut]))
msg_history = self._rc.memory.get_by_actions([ShoutOut])
context = []
for m in msg_history:
context.append(str(m))
@ -108,7 +107,7 @@ class Biden(Role):
async def _act(self) -> Message:
logger.info(f"{self._setting}: ready to {self._rc.todo}")
msg_history = self._rc.memory.get_by_actions(any_to_str_set([BossRequirement, ShoutOut]))
msg_history = self._rc.memory.get_by_actions([BossRequirement, ShoutOut])
context = []
for m in msg_history:
context.append(str(m))

View file

@ -10,6 +10,7 @@ from collections import defaultdict
from typing import Iterable, Set
from metagpt.schema import Message
from metagpt.utils.common import any_to_str, any_to_str_set
class Memory:
@ -73,14 +74,16 @@ class Memory:
news.append(i)
return news
def get_by_action(self, action: str) -> list[Message]:
def get_by_action(self, action) -> list[Message]:
"""Return all messages triggered by a specified Action"""
return self.index[action]
index = any_to_str(action)
return self.index[index]
def get_by_actions(self, actions: Set) -> list[Message]:
"""Return all messages triggered by specified Actions"""
rsp = []
for action in actions:
indices = any_to_str_set(actions)
for action in indices:
if action not in self.index:
continue
rsp += self.index[action]

View file

@ -21,7 +21,7 @@ from metagpt.const import WORKSPACE_ROOT
from metagpt.logs import logger
from metagpt.roles import Role
from metagpt.schema import Message
from metagpt.utils.common import CodeParser, any_to_str, any_to_str_set
from metagpt.utils.common import CodeParser, any_to_str
from metagpt.utils.special_tokens import FILENAME_CODE_SEP, MSG_SEP
@ -102,7 +102,7 @@ class Engineer(Role):
return CodeParser.parse_str(block="Python package name", text=system_design_msg.content)
def get_workspace(self) -> Path:
msg = self._rc.memory.get_by_action(any_to_str(WriteDesign))[-1]
msg = self._rc.memory.get_by_action(WriteDesign)[-1]
if not msg:
return WORKSPACE_ROOT / "src"
workspace = self.parse_workspace(msg)
@ -130,7 +130,7 @@ class Engineer(Role):
todo_coros = []
for todo in self.todos:
todo_coro = WriteCode().run(
context=self._rc.memory.get_by_actions(any_to_str_set([WriteTasks, WriteDesign])),
context=self._rc.memory.get_by_actions([WriteTasks, WriteDesign]),
filename=todo,
)
todo_coros.append(todo_coro)
@ -185,7 +185,7 @@ class Engineer(Role):
TODO: The goal is not to need it. After clear task decomposition, based on the design idea, you should be able to write a single file without needing other codes. If you can't, it means you need a clearer definition. This is the key to writing longer code.
"""
context = []
msg = self._rc.memory.get_by_actions(any_to_str_set([WriteDesign, WriteTasks, WriteCode]))
msg = self._rc.memory.get_by_actions([WriteDesign, WriteTasks, WriteCode])
for m in msg:
context.append(m.content)
context_str = "\n".join(context)
@ -240,8 +240,7 @@ class Engineer(Role):
async def _think(self) -> None:
# In asynchronous scenarios, first check if the required messages are ready.
filters = {any_to_str(WriteTasks)}
msgs = self._rc.memory.get_by_actions(filters)
msgs = self._rc.memory.get_by_actions({WriteTasks})
if not msgs:
self._rc.todo = None
return

View file

@ -22,12 +22,7 @@ from metagpt.const import WORKSPACE_ROOT
from metagpt.logs import logger
from metagpt.roles import Role
from metagpt.schema import Message
from metagpt.utils.common import (
CodeParser,
any_to_str_set,
get_class_name,
parse_recipient,
)
from metagpt.utils.common import CodeParser, any_to_str, any_to_str_set, parse_recipient
from metagpt.utils.special_tokens import FILENAME_CODE_SEP, MSG_SEP
@ -55,7 +50,7 @@ class QaEngineer(Role):
return CodeParser.parse_str(block="Python package name", text=system_design_msg.content)
def get_workspace(self, return_proj_dir=True) -> Path:
msg = self._rc.memory.get_by_action(get_class_name(WriteDesign))[-1]
msg = self._rc.memory.get_by_action(WriteDesign)[-1]
if not msg:
return WORKSPACE_ROOT / "src"
workspace = self.parse_workspace(msg)
@ -104,7 +99,7 @@ class QaEngineer(Role):
msg = Message(
content=str(file_info),
role=self.profile,
cause_by=WriteTest,
cause_by=any_to_str(WriteTest),
sent_from=self.profile,
send_to=self.profile,
)