refactor: get_class_name

This commit is contained in:
莘权 马 2023-11-10 15:55:33 +08:00
parent d36b4e2088
commit 3c38c5c416
2 changed files with 9 additions and 6 deletions

View file

@ -14,7 +14,7 @@ from metagpt.actions.action import Action
from metagpt.const import WORKSPACE_ROOT
from metagpt.logs import logger
from metagpt.schema import Message
from metagpt.utils.common import CodeParser, get_class_name
from metagpt.utils.common import CodeParser, any_to_str
PROMPT_TEMPLATE = """
NOTICE
@ -58,7 +58,7 @@ class WriteCode(Action):
if self._is_invalid(filename):
return
design = [i for i in context if i.cause_by == get_class_name(WriteDesign)][0]
design = [i for i in context if i.cause_by == any_to_str(WriteDesign)][0]
ws_name = CodeParser.parse_str(block="Python package name", text=design.content)
ws_path = WORKSPACE_ROOT / ws_name

View file

@ -10,6 +10,7 @@ from collections import defaultdict
from typing import Iterable, Set
from metagpt.schema import Message
from metagpt.utils.common import any_to_str, any_to_str_set
class Memory:
@ -73,14 +74,16 @@ class Memory:
news.append(i)
return news
def get_by_action(self, action: str) -> list[Message]:
def get_by_action(self, action) -> list[Message]:
"""Return all messages triggered by a specified Action"""
return self.index[action]
idx = any_to_str(action)
return self.index[idx]
def get_by_actions(self, actions: Set[str]) -> list[Message]:
def get_by_actions(self, actions: Set) -> list[Message]:
"""Return all messages triggered by specified Actions"""
idxs = any_to_str_set(actions)
rsp = []
for action in actions:
for action in idxs:
if action not in self.index:
continue
rsp += self.index[action]