refactor: get_class_name

This commit is contained in:
莘权 马 2023-11-10 15:58:47 +08:00
parent 3c38c5c416
commit 710bc40b0a

View file

@ -10,7 +10,6 @@ from collections import defaultdict
from typing import Iterable, Set
from metagpt.schema import Message
from metagpt.utils.common import any_to_str, any_to_str_set
class Memory:
@ -76,14 +75,12 @@ class Memory:
def get_by_action(self, action) -> list[Message]:
"""Return all messages triggered by a specified Action"""
idx = any_to_str(action)
return self.index[idx]
return self.index[action]
def get_by_actions(self, actions: Set) -> list[Message]:
"""Return all messages triggered by specified Actions"""
idxs = any_to_str_set(actions)
rsp = []
for action in idxs:
for action in actions:
if action not in self.index:
continue
rsp += self.index[action]