mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-06-08 15:05:17 +02:00
feat: memory + tags
This commit is contained in:
parent
01c487fb6a
commit
937bd12a63
4 changed files with 23 additions and 3 deletions
|
|
@ -91,3 +91,11 @@ class Memory:
|
|||
key = class_names[type(action).__name__]
|
||||
rsp += self.index[key]
|
||||
return rsp
|
||||
|
||||
def get_by_tags(self, tags: list) -> list[Message]:
|
||||
"""Return messages with specified tags"""
|
||||
result = []
|
||||
for m in self.storage:
|
||||
if m.is_contain_tags(tags):
|
||||
result.append(m)
|
||||
return result
|
||||
|
|
|
|||
|
|
@ -17,7 +17,7 @@ from metagpt.provider.openai_api import OpenAIGPTAPI as LLM
|
|||
from metagpt.actions import Action, ActionOutput
|
||||
from metagpt.logs import logger
|
||||
from metagpt.memory import Memory, LongTermMemory
|
||||
from metagpt.schema import Message
|
||||
from metagpt.schema import Message, MessageTag
|
||||
|
||||
PREFIX_TEMPLATE = """You are a {profile}, named {name}, your goal is {goal}, and the constraint is {constraints}. """
|
||||
|
||||
|
|
@ -90,6 +90,11 @@ class RoleContext(BaseModel):
|
|||
def history(self) -> list[Message]:
|
||||
return self.memory.get()
|
||||
|
||||
@property
|
||||
def prerequisite(self):
|
||||
"""Retrieve information with `prerequisite` tag"""
|
||||
return self.memory.get_by_tags([MessageTag.Prerequisite.value])
|
||||
|
||||
|
||||
class Role:
|
||||
"""Role/Proxy"""
|
||||
|
|
@ -209,7 +214,7 @@ class Role:
|
|||
# history=self.history)
|
||||
|
||||
logger.info(f"{self._setting}: ready to {self._rc.todo}")
|
||||
requirement = self._rc.important_memory
|
||||
requirement = self._rc.important_memory or self._rc.prerequisite
|
||||
response = await self._rc.todo.run(requirement, **self._options)
|
||||
# logger.info(response)
|
||||
if isinstance(response, ActionOutput):
|
||||
|
|
|
|||
|
|
@ -60,6 +60,13 @@ class Message:
|
|||
return
|
||||
self.tags.remove(tag)
|
||||
|
||||
def is_contain_tags(self, tags: list) -> bool:
|
||||
"""Determine whether the message contains tags."""
|
||||
if not tags or not self.tags:
|
||||
return False
|
||||
intersection = set(tags) & self.tags
|
||||
return len(intersection) > 0
|
||||
|
||||
|
||||
@dataclass
|
||||
class UserMessage(Message):
|
||||
|
|
|
|||
|
|
@ -60,7 +60,7 @@ def test_init():
|
|||
for i in inputs:
|
||||
seed = Inputs(**i)
|
||||
options = Config().runtime_options
|
||||
cost_manager = CostManager(options=options)
|
||||
cost_manager = CostManager(**options)
|
||||
teacher = Teacher(options=options, cost_manager=cost_manager, name=seed.name, profile=seed.profile,
|
||||
goal=seed.goal, constraints=seed.constraints,
|
||||
desc=seed.desc, **seed.kwargs)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue