feat: memory + tags

This commit is contained in:
莘权 马 2023-08-23 13:02:23 +08:00
parent 01c487fb6a
commit 937bd12a63
4 changed files with 23 additions and 3 deletions

View file

@ -91,3 +91,11 @@ class Memory:
key = class_names[type(action).__name__]
rsp += self.index[key]
return rsp
def get_by_tags(self, tags: list) -> list[Message]:
"""Return messages with specified tags"""
result = []
for m in self.storage:
if m.is_contain_tags(tags):
result.append(m)
return result

View file

@ -17,7 +17,7 @@ from metagpt.provider.openai_api import OpenAIGPTAPI as LLM
from metagpt.actions import Action, ActionOutput
from metagpt.logs import logger
from metagpt.memory import Memory, LongTermMemory
from metagpt.schema import Message
from metagpt.schema import Message, MessageTag
PREFIX_TEMPLATE = """You are a {profile}, named {name}, your goal is {goal}, and the constraint is {constraints}. """
@ -90,6 +90,11 @@ class RoleContext(BaseModel):
def history(self) -> list[Message]:
return self.memory.get()
@property
def prerequisite(self):
"""Retrieve information with `prerequisite` tag"""
return self.memory.get_by_tags([MessageTag.Prerequisite.value])
class Role:
"""Role/Proxy"""
@ -209,7 +214,7 @@ class Role:
# history=self.history)
logger.info(f"{self._setting}: ready to {self._rc.todo}")
requirement = self._rc.important_memory
requirement = self._rc.important_memory or self._rc.prerequisite
response = await self._rc.todo.run(requirement, **self._options)
# logger.info(response)
if isinstance(response, ActionOutput):

View file

@ -60,6 +60,13 @@ class Message:
return
self.tags.remove(tag)
def is_contain_tags(self, tags: list) -> bool:
"""Determine whether the message contains tags."""
if not tags or not self.tags:
return False
intersection = set(tags) & self.tags
return len(intersection) > 0
@dataclass
class UserMessage(Message):

View file

@ -60,7 +60,7 @@ def test_init():
for i in inputs:
seed = Inputs(**i)
options = Config().runtime_options
cost_manager = CostManager(options=options)
cost_manager = CostManager(**options)
teacher = Teacher(options=options, cost_manager=cost_manager, name=seed.name, profile=seed.profile,
goal=seed.goal, constraints=seed.constraints,
desc=seed.desc, **seed.kwargs)