From 937bd12a63733d818338f7d3ad8c2d0907fe5c4f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=8E=98=E6=9D=83=20=E9=A9=AC?= Date: Wed, 23 Aug 2023 13:02:23 +0800 Subject: [PATCH] feat: memory + tags --- metagpt/memory/memory.py | 8 ++++++++ metagpt/roles/role.py | 9 +++++++-- metagpt/schema.py | 7 +++++++ tests/metagpt/roles/test_teacher.py | 2 +- 4 files changed, 23 insertions(+), 3 deletions(-) diff --git a/metagpt/memory/memory.py b/metagpt/memory/memory.py index 625d98675..1a8003fba 100644 --- a/metagpt/memory/memory.py +++ b/metagpt/memory/memory.py @@ -91,3 +91,11 @@ class Memory: key = class_names[type(action).__name__] rsp += self.index[key] return rsp + + def get_by_tags(self, tags: list) -> list[Message]: + """Return messages with specified tags""" + result = [] + for m in self.storage: + if m.is_contain_tags(tags): + result.append(m) + return result diff --git a/metagpt/roles/role.py b/metagpt/roles/role.py index 00f8ed45f..217272b54 100644 --- a/metagpt/roles/role.py +++ b/metagpt/roles/role.py @@ -17,7 +17,7 @@ from metagpt.provider.openai_api import OpenAIGPTAPI as LLM from metagpt.actions import Action, ActionOutput from metagpt.logs import logger from metagpt.memory import Memory, LongTermMemory -from metagpt.schema import Message +from metagpt.schema import Message, MessageTag PREFIX_TEMPLATE = """You are a {profile}, named {name}, your goal is {goal}, and the constraint is {constraints}. """ @@ -90,6 +90,11 @@ class RoleContext(BaseModel): def history(self) -> list[Message]: return self.memory.get() + @property + def prerequisite(self): + """Retrieve information with `prerequisite` tag""" + return self.memory.get_by_tags([MessageTag.Prerequisite.value]) + class Role: """Role/Proxy""" @@ -209,7 +214,7 @@ class Role: # history=self.history) logger.info(f"{self._setting}: ready to {self._rc.todo}") - requirement = self._rc.important_memory + requirement = self._rc.important_memory or self._rc.prerequisite response = await self._rc.todo.run(requirement, **self._options) # logger.info(response) if isinstance(response, ActionOutput): diff --git a/metagpt/schema.py b/metagpt/schema.py index 56e9ad95c..4c577fd7b 100644 --- a/metagpt/schema.py +++ b/metagpt/schema.py @@ -60,6 +60,13 @@ class Message: return self.tags.remove(tag) + def is_contain_tags(self, tags: list) -> bool: + """Determine whether the message contains tags.""" + if not tags or not self.tags: + return False + intersection = set(tags) & self.tags + return len(intersection) > 0 + @dataclass class UserMessage(Message): diff --git a/tests/metagpt/roles/test_teacher.py b/tests/metagpt/roles/test_teacher.py index 11c268edb..8f673d6e0 100644 --- a/tests/metagpt/roles/test_teacher.py +++ b/tests/metagpt/roles/test_teacher.py @@ -60,7 +60,7 @@ def test_init(): for i in inputs: seed = Inputs(**i) options = Config().runtime_options - cost_manager = CostManager(options=options) + cost_manager = CostManager(**options) teacher = Teacher(options=options, cost_manager=cost_manager, name=seed.name, profile=seed.profile, goal=seed.goal, constraints=seed.constraints, desc=seed.desc, **seed.kwargs)