mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-05-18 13:55:17 +02:00
Change the way of taking over memory
This commit is contained in:
parent
d077cd0b2f
commit
c0d5c031f8
8 changed files with 270 additions and 407 deletions
|
|
@ -158,5 +158,5 @@ SWE_SETUP_PATH = get_metagpt_package_root() / "metagpt/tools/swe_agent_commands/
|
|||
# experience pool
|
||||
EXPERIENCE_MASK = "<experience>"
|
||||
|
||||
# Used to identify user requirements in the memory index.
|
||||
USER_REQUIREMENT = "metagpt.actions.add_requirement.UserRequirement"
|
||||
# TeamLeader's name
|
||||
TEAMLEADER_NAME = "Mike"
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ from metagpt.actions import (
|
|||
WriteTest,
|
||||
)
|
||||
from metagpt.actions.summarize_code import SummarizeCode
|
||||
from metagpt.const import AGENT, IMAGES
|
||||
from metagpt.const import AGENT, IMAGES, TEAMLEADER_NAME
|
||||
from metagpt.environment.base_env import Environment
|
||||
from metagpt.logs import get_human_input
|
||||
from metagpt.roles import Architect, ProductManager, ProjectManager, Role
|
||||
|
|
@ -31,7 +31,7 @@ class MGXEnv(Environment, SerializationMixin):
|
|||
"""let the team leader take over message publishing"""
|
||||
message = self.attach_images(message) # for multi-modal message
|
||||
|
||||
tl = self.get_role("Mike") # TeamLeader's name is Mike
|
||||
tl = self.get_role(TEAMLEADER_NAME) # TeamLeader's name is Mike
|
||||
|
||||
if user_defined_recipient:
|
||||
# human user's direct chat message to a certain role
|
||||
|
|
|
|||
|
|
@ -1,8 +1,14 @@
|
|||
from typing import TYPE_CHECKING, Any
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import Field
|
||||
|
||||
from metagpt.actions import UserRequirement
|
||||
from metagpt.const import TEAMLEADER_NAME
|
||||
from metagpt.logs import logger
|
||||
from metagpt.memory import Memory
|
||||
from metagpt.schema import LongTermMemoryItem, Message
|
||||
from metagpt.utils.common import any_to_str
|
||||
from metagpt.utils.exceptions import handle_exception
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from llama_index.core.schema import NodeWithScore
|
||||
|
|
@ -10,9 +16,10 @@ if TYPE_CHECKING:
|
|||
from metagpt.rag.engines import SimpleEngine
|
||||
|
||||
|
||||
class RoleZeroLongTermMemory(BaseModel):
|
||||
class RoleZeroLongTermMemory(Memory):
|
||||
persist_path: str = Field(default=".role_memory_data", description="The directory to save data.")
|
||||
collection_name: str = Field(default="role_zero", description="The name of the collection, such as the role name.")
|
||||
memory_k: int = Field(default=200, description="The capacity of short-term memory.")
|
||||
|
||||
_rag_engine: Any = None
|
||||
|
||||
|
|
@ -44,7 +51,104 @@ class RoleZeroLongTermMemory(BaseModel):
|
|||
|
||||
return rag_engine
|
||||
|
||||
def fetch(self, query: str) -> list[Message]:
|
||||
def add(self, message: Message):
|
||||
"""Add a new message and potentially transfer it to long-term memory."""
|
||||
|
||||
super().add(message)
|
||||
|
||||
if not self._should_use_longterm_memory_for_add():
|
||||
return
|
||||
|
||||
self._transfer_to_longterm_memory()
|
||||
|
||||
def get(self, k=0) -> list[Message]:
|
||||
"""Return recent memories and optionally combines them with related long-term memories."""
|
||||
|
||||
memories = super().get(k)
|
||||
|
||||
if not self._should_use_longterm_memory_for_get(k=k):
|
||||
return memories
|
||||
|
||||
query = self._build_longterm_memory_query()
|
||||
related_memories = self._fetch_longterm_memories(query)
|
||||
logger.info(f"Fetched {len(related_memories)} long-term memories.")
|
||||
|
||||
final_memories = related_memories + memories
|
||||
|
||||
return final_memories
|
||||
|
||||
def _should_use_longterm_memory_for_add(self) -> bool:
|
||||
"""Determines if long-term memory should be used for add."""
|
||||
|
||||
return self.count() > self.memory_k
|
||||
|
||||
def _should_use_longterm_memory_for_get(self, k: int) -> bool:
|
||||
"""Determines if long-term memory should be used for get.
|
||||
|
||||
Long-term memory is used if:
|
||||
- k is not 0.
|
||||
- The last message is from user requirement.
|
||||
- The count of recent memories is greater than self.memory_k.
|
||||
"""
|
||||
|
||||
conds = [
|
||||
k != 0,
|
||||
self._is_last_message_from_user_requirement(),
|
||||
self.count() > self.memory_k,
|
||||
]
|
||||
|
||||
return all(conds)
|
||||
|
||||
def _transfer_to_longterm_memory(self):
|
||||
item = self._get_longterm_memory_item()
|
||||
self._add_to_longterm_memory(item)
|
||||
|
||||
@handle_exception
|
||||
def _get_longterm_memory_item(self) -> Optional[LongTermMemoryItem]:
|
||||
"""Retrieves the most recent message before the last k messages."""
|
||||
|
||||
index = -(self.memory_k + 1)
|
||||
message = self.get_by_position(index)
|
||||
|
||||
return LongTermMemoryItem(message=message)
|
||||
|
||||
def _add_to_longterm_memory(self, item: LongTermMemoryItem):
|
||||
"""Adds a long-term memory item to the RAG engine."""
|
||||
|
||||
if not item:
|
||||
return
|
||||
|
||||
self.rag_engine.add_objs([item])
|
||||
|
||||
def _build_longterm_memory_query(self) -> str:
|
||||
"""Build the content used to query related long-term memory.
|
||||
|
||||
Default is to get the most recent user message, or an empty string if none is found.
|
||||
"""
|
||||
|
||||
message = self._get_the_last_message()
|
||||
|
||||
return message.content if message else ""
|
||||
|
||||
def _get_the_last_message(self) -> Optional[Message]:
|
||||
if not self.count():
|
||||
return None
|
||||
|
||||
return self.get_by_position(-1)
|
||||
|
||||
def _is_last_message_from_user_requirement(self) -> bool:
|
||||
message = self._get_the_last_message()
|
||||
|
||||
if not message:
|
||||
return False
|
||||
|
||||
is_user_message = message.is_user_message()
|
||||
cause_by_user_requirement = message.cause_by == any_to_str(UserRequirement)
|
||||
sent_from_team_leader = message.sent_from == TEAMLEADER_NAME
|
||||
|
||||
return is_user_message and (cause_by_user_requirement or sent_from_team_leader)
|
||||
|
||||
def _fetch_longterm_memories(self, query: str) -> list[Message]:
|
||||
"""Fetches long-term memories based on a query.
|
||||
|
||||
Args:
|
||||
|
|
@ -59,26 +163,10 @@ class RoleZeroLongTermMemory(BaseModel):
|
|||
|
||||
nodes = self.rag_engine.retrieve(query)
|
||||
items = self._get_items_from_nodes(nodes)
|
||||
|
||||
memories = []
|
||||
for item in items:
|
||||
memories.append(item.user_message)
|
||||
memories.append(item.ai_message)
|
||||
memories = [item.message for item in items]
|
||||
|
||||
return memories
|
||||
|
||||
def add(self, item: LongTermMemoryItem):
|
||||
"""Adds a long-term memory item to the RAG engine.
|
||||
|
||||
Args:
|
||||
item (LongTermMemoryItem): The memory item containing user and AI messages.
|
||||
"""
|
||||
|
||||
if not item:
|
||||
return
|
||||
|
||||
self.rag_engine.add_objs([item])
|
||||
|
||||
def _get_items_from_nodes(self, nodes: list["NodeWithScore"]) -> list[LongTermMemoryItem]:
|
||||
"""Get items from nodes and arrange them in order of their `created_at`."""
|
||||
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ from metagpt.actions import Action, UserRequirement
|
|||
from metagpt.actions.analyze_requirements import AnalyzeRequirementsRestrictions
|
||||
from metagpt.actions.di.run_command import RunCommand
|
||||
from metagpt.actions.search_enhanced_qa import SearchEnhancedQA
|
||||
from metagpt.const import IMAGES, USER_REQUIREMENT
|
||||
from metagpt.const import IMAGES
|
||||
from metagpt.exp_pool import exp_cache
|
||||
from metagpt.exp_pool.context_builders import RoleZeroContextBuilder
|
||||
from metagpt.exp_pool.serializers import RoleZeroSerializer
|
||||
|
|
@ -34,7 +34,7 @@ from metagpt.prompts.di.role_zero import (
|
|||
SYSTEM_PROMPT,
|
||||
)
|
||||
from metagpt.roles import Role
|
||||
from metagpt.schema import AIMessage, LongTermMemoryItem, Message, UserMessage
|
||||
from metagpt.schema import AIMessage, Message, UserMessage
|
||||
from metagpt.strategy.experience_retriever import DummyExpRetriever, ExpRetriever
|
||||
from metagpt.strategy.planner import Planner
|
||||
from metagpt.tools.libs.browser import Browser
|
||||
|
|
@ -42,7 +42,6 @@ from metagpt.tools.libs.editor import Editor
|
|||
from metagpt.tools.tool_recommend import BM25ToolRecommender, ToolRecommender
|
||||
from metagpt.tools.tool_registry import register_tool
|
||||
from metagpt.utils.common import CodeParser, any_to_str, extract_and_encode_images
|
||||
from metagpt.utils.exceptions import handle_exception
|
||||
from metagpt.utils.repair_llm_raw_output import (
|
||||
RepairType,
|
||||
repair_escape_error,
|
||||
|
|
@ -94,7 +93,6 @@ class RoleZero(Role):
|
|||
commands: list[dict] = [] # commands to be executed
|
||||
memory_k: int = 200 # number of memories (messages) to use as historical context
|
||||
enable_longterm_memory: bool = True # whether to use longterm memory
|
||||
longterm_memory: RoleZeroLongTermMemory = None
|
||||
use_fixed_sop: bool = False
|
||||
requirements_constraints: str = "" # the constraints in user requirements
|
||||
use_summary: bool = True # whether to summarize at the end
|
||||
|
|
@ -176,8 +174,8 @@ class RoleZero(Role):
|
|||
The role name will be used as the collection name.
|
||||
"""
|
||||
|
||||
if self.enable_longterm_memory and not self.longterm_memory:
|
||||
self.longterm_memory = RoleZeroLongTermMemory(collection_name=self.name.replace(" ", ""))
|
||||
if self.enable_longterm_memory:
|
||||
self.rc.memory = RoleZeroLongTermMemory(collection_name=self.name.replace(" ", ""), memory_k=self.memory_k)
|
||||
|
||||
return self
|
||||
|
||||
|
|
@ -195,7 +193,7 @@ class RoleZero(Role):
|
|||
return False
|
||||
|
||||
if not self.planner.plan.goal:
|
||||
self.planner.plan.goal = self._get_all_memories()[-1].content
|
||||
self.planner.plan.goal = self.get_memories()[-1].content
|
||||
self.requirements_constraints = await AnalyzeRequirementsRestrictions().run(self.planner.plan.goal)
|
||||
|
||||
### 1. Experience ###
|
||||
|
|
@ -227,7 +225,7 @@ class RoleZero(Role):
|
|||
)
|
||||
|
||||
### Recent Observation ###
|
||||
memory = self._fetch_memories()
|
||||
memory = self.rc.memory.get(self.memory_k)
|
||||
memory = await self.parse_browser_actions(memory)
|
||||
memory = self.parse_images(memory)
|
||||
|
||||
|
|
@ -282,15 +280,15 @@ class RoleZero(Role):
|
|||
return await super()._act()
|
||||
|
||||
commands, ok, self.command_rsp = await self._parse_commands(self.command_rsp)
|
||||
self._add_memory(AIMessage(content=self.command_rsp))
|
||||
self.rc.memory.add(AIMessage(content=self.command_rsp))
|
||||
if not ok:
|
||||
error_msg = commands
|
||||
self._add_memory(UserMessage(content=error_msg, cause_by=RunCommand))
|
||||
self.rc.memory.add(UserMessage(content=error_msg, cause_by=RunCommand))
|
||||
return error_msg
|
||||
logger.info(f"Commands: \n{commands}")
|
||||
outputs = await self._run_commands(commands)
|
||||
logger.info(f"Commands outputs: \n{outputs}")
|
||||
self._add_memory(UserMessage(content=outputs, cause_by=RunCommand))
|
||||
self.rc.memory.add(UserMessage(content=outputs, cause_by=RunCommand))
|
||||
|
||||
return AIMessage(
|
||||
content=f"I have finished the task, please mark my task as finished. Outputs: {outputs}",
|
||||
|
|
@ -343,7 +341,7 @@ class RoleZero(Role):
|
|||
return rsp_msg, ""
|
||||
|
||||
# routing
|
||||
memory = self._fetch_memories()
|
||||
memory = self.get_memories(k=self.memory_k)
|
||||
context = self.llm.format_msg(memory + [UserMessage(content=QUICK_THINK_PROMPT)])
|
||||
async with ThoughtReporter() as reporter:
|
||||
await reporter.async_report({"type": "classify"})
|
||||
|
|
@ -368,7 +366,7 @@ class RoleZero(Role):
|
|||
answer = await SearchEnhancedQA().run(query)
|
||||
|
||||
if answer:
|
||||
self._add_memory(AIMessage(content=answer, cause_by=RunCommand))
|
||||
self.rc.memory.add(AIMessage(content=answer, cause_by=RunCommand))
|
||||
await self.reply_to_human(content=answer)
|
||||
rsp_msg = AIMessage(
|
||||
content="Complete run",
|
||||
|
|
@ -379,7 +377,7 @@ class RoleZero(Role):
|
|||
return rsp_msg, intent_result
|
||||
|
||||
async def _check_duplicates(self, req: list[dict], command_rsp: str):
|
||||
past_rsp = [mem.content for mem in self._fetch_memories()]
|
||||
past_rsp = [mem.content for mem in self.rc.memory.get(self.memory_k)]
|
||||
if command_rsp in past_rsp:
|
||||
# Normal response with thought contents are highly unlikely to reproduce
|
||||
# If an identical response is detected, it is a bad response, mostly due to LLM repeating generated content
|
||||
|
|
@ -537,7 +535,7 @@ class RoleZero(Role):
|
|||
|
||||
def _retrieve_experience(self) -> str:
|
||||
"""Default implementation of experience retrieval. Can be overwritten in subclasses."""
|
||||
context = [str(msg) for msg in self._fetch_memories()]
|
||||
context = [str(msg) for msg in self.rc.memory.get(self.memory_k)]
|
||||
context = "\n\n".join(context)
|
||||
example = self.experience_retriever.retrieve(context=context)
|
||||
return example
|
||||
|
|
@ -562,9 +560,9 @@ class RoleZero(Role):
|
|||
|
||||
async def _end(self, **kwarg):
|
||||
self._set_state(-1)
|
||||
memory = self._fetch_memories()
|
||||
memory = self.rc.memory.get(self.memory_k)
|
||||
# Ensure reply to the human before the "end" command is executed. Hard code k=5 for checking.
|
||||
if not any(["reply_to_human" in memory.content for memory in self._fetch_memories(k=5)]):
|
||||
if not any(["reply_to_human" in memory.content for memory in self.get_memories(k=5)]):
|
||||
logger.info("manually reply to human")
|
||||
pattern = r"\[Language Restrictions\](.*?)\n"
|
||||
match = re.search(pattern, self.requirements_constraints, re.DOTALL)
|
||||
|
|
@ -573,106 +571,10 @@ class RoleZero(Role):
|
|||
await reporter.async_report({"type": "quick"})
|
||||
reply_content = await self.llm.aask(self.llm.format_msg(memory + [UserMessage(reply_to_human_prompt)]))
|
||||
await self.reply_to_human(content=reply_content)
|
||||
self._add_memory(AIMessage(content=reply_content, cause_by=RunCommand))
|
||||
self.rc.memory.add(AIMessage(content=reply_content, cause_by=RunCommand))
|
||||
outputs = ""
|
||||
# Summary of the Completed Task and Deliverables
|
||||
if self.use_summary:
|
||||
logger.info("end current run and summarize")
|
||||
outputs = await self.llm.aask(self.llm.format_msg(memory + [UserMessage(SUMMARY_PROMPT)]))
|
||||
return outputs
|
||||
|
||||
def _get_all_memories(self) -> list[Message]:
|
||||
return self._fetch_memories(k=0)
|
||||
|
||||
def _fetch_memories(self, k: Optional[int] = None) -> list[Message]:
|
||||
"""Fetches recent memories and optionally combines them with related long-term memories.
|
||||
|
||||
If long-term memory is not enabled or the last message is not from the user,
|
||||
it returns the recent memories without fetching from long-term memory.
|
||||
|
||||
Args:
|
||||
k (Optional[int]): The number of recent memories to fetch. If None, defaults to self.memory_k.
|
||||
|
||||
Returns:
|
||||
List[Message]: A list of messages representing the combined memories.
|
||||
"""
|
||||
|
||||
if k is None:
|
||||
k = self.memory_k
|
||||
|
||||
memories = self.rc.memory.get(k)
|
||||
|
||||
if not self._should_use_longterm_memory(k=k):
|
||||
return memories
|
||||
|
||||
query = self._build_longterm_memory_query()
|
||||
related_memories = self.longterm_memory.fetch(query)
|
||||
logger.info(f"Fetched {len(related_memories)} long-term memories.")
|
||||
|
||||
# Keep user and AI messages are paired.
|
||||
if self._is_first_message_from_ai(memories):
|
||||
memories.insert(0, self.rc.memory.get_by_position(-(k + 1)))
|
||||
|
||||
final_memories = related_memories + memories
|
||||
|
||||
return final_memories
|
||||
|
||||
def _add_memory(self, message: Message):
|
||||
self.rc.memory.add(message)
|
||||
|
||||
if not self._should_use_longterm_memory():
|
||||
return
|
||||
|
||||
self._transfer_to_longterm_memory()
|
||||
|
||||
def _should_use_longterm_memory(self, k: int = None) -> bool:
|
||||
"""Determines if long-term memory should be used.
|
||||
|
||||
Long-term memory is used if:
|
||||
- k is not 0.
|
||||
- Long-term memory usage is enabled.
|
||||
- The count of recent memories is greater than self.memory_k.
|
||||
"""
|
||||
|
||||
conds = [
|
||||
k != 0,
|
||||
self.enable_longterm_memory,
|
||||
self.rc.memory.count() > self.memory_k,
|
||||
]
|
||||
|
||||
return all(conds)
|
||||
|
||||
def _transfer_to_longterm_memory(self):
|
||||
item = self._get_longterm_memory_item()
|
||||
self.longterm_memory.add(item)
|
||||
|
||||
@handle_exception
|
||||
def _get_longterm_memory_item(self) -> Optional[LongTermMemoryItem]:
|
||||
"""Retrieves the most recent pair of user and AI messages before the last k messages."""
|
||||
|
||||
index = -(self.memory_k + 1)
|
||||
message = self.rc.memory.get_by_position(index)
|
||||
if not message.is_ai_message():
|
||||
return None
|
||||
|
||||
index = -(self.memory_k + 2)
|
||||
user_message = self.rc.memory.get_by_position(index)
|
||||
|
||||
return LongTermMemoryItem(user_message=user_message, ai_message=message)
|
||||
|
||||
def _is_first_message_from_ai(self, memories: list[Message]) -> bool:
|
||||
return bool(memories and memories[0].is_ai_message())
|
||||
|
||||
def _build_longterm_memory_query(self) -> str:
|
||||
"""Build the content used to query related long-term memory.
|
||||
|
||||
Default is to get the most recent user message, or an empty string if none is found.
|
||||
"""
|
||||
message = self._get_the_last_user_message()
|
||||
|
||||
return message.content if message else ""
|
||||
|
||||
def _get_the_last_user_message(self) -> Message:
|
||||
values = self.rc.memory.index.get(USER_REQUIREMENT, [])
|
||||
|
||||
return values[-1] if values else None
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ from typing import Annotated
|
|||
from pydantic import Field
|
||||
|
||||
from metagpt.actions.di.run_command import RunCommand
|
||||
from metagpt.const import TEAMLEADER_NAME
|
||||
from metagpt.prompts.di.team_leader import (
|
||||
FINISH_CURRENT_TASK_CMD,
|
||||
TL_INFO,
|
||||
|
|
@ -19,7 +20,7 @@ from metagpt.tools.tool_registry import register_tool
|
|||
|
||||
@register_tool(include_functions=["publish_team_message"])
|
||||
class TeamLeader(RoleZero):
|
||||
name: str = "Mike"
|
||||
name: str = TEAMLEADER_NAME
|
||||
profile: str = "Team Leader"
|
||||
goal: str = "Manage a team to assist users"
|
||||
thought_guidance: str = TL_THOUGHT_GUIDANCE
|
||||
|
|
@ -84,4 +85,4 @@ class TeamLeader(RoleZero):
|
|||
|
||||
def finish_current_task(self):
|
||||
self.planner.plan.finish_current_task()
|
||||
self._add_memory(AIMessage(content=FINISH_CURRENT_TASK_CMD))
|
||||
self.rc.memory.add(AIMessage(content=FINISH_CURRENT_TASK_CMD))
|
||||
|
|
|
|||
|
|
@ -965,9 +965,8 @@ class BaseEnum(Enum):
|
|||
|
||||
|
||||
class LongTermMemoryItem(BaseModel):
|
||||
user_message: Message
|
||||
ai_message: Message
|
||||
message: Message
|
||||
created_at: Optional[float] = Field(default_factory=time.time)
|
||||
|
||||
def rag_key(self) -> str:
|
||||
return self.user_message.content
|
||||
return self.message.content
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue