use long-term memory in rolezero

This commit is contained in:
seehi 2024-09-06 10:23:07 +08:00
parent 4e82d86166
commit f20d3ca2dd
8 changed files with 440 additions and 15 deletions

View file

@ -1,8 +1,12 @@
from __future__ import annotations
from abc import abstractmethod
from typing import Optional, Union
from typing import TYPE_CHECKING, Optional, Union
from metagpt.base.base_serialization import BaseSerialization
from metagpt.schema import Message
if TYPE_CHECKING:
from metagpt.schema import Message
class BaseRole(BaseSerialization):

View file

@ -104,3 +104,7 @@ class Memory(BaseModel):
continue
rsp += self.index[action]
return rsp
def get_by_position(self, position: int) -> Message:
    """Look up a single stored message by its index.

    Args:
        position: Index handed straight to the subscript operator of
            ``self.storage``.

    Returns:
        The item found at ``self.storage[position]``.
    """
    message = self.storage[position]
    return message

View file

@ -0,0 +1,60 @@
from typing import TYPE_CHECKING, Any
from pydantic import BaseModel, Field
from metagpt.schema import LongTermMemoryItem, Message
if TYPE_CHECKING:
from llama_index.core.schema import NodeWithScore
from metagpt.rag.engines import SimpleEngine
class RoleZeroLongTermMemory(BaseModel):
    """Long-term memory store backed by a lazily created RAG engine.

    Items are written with the engine's ``add_objs`` and read back with its
    ``retrieve``; the engine itself is only built on first access of
    ``rag_engine``.
    """

    persist_path: str = Field(default=".role_memory_data", description="The directory to save data.")
    collection_name: str = Field(default="role_zero", description="The name of the collection, such as the role name.")

    # Cached engine instance; stays None until `rag_engine` is first read.
    _rag_engine: Any = None

    @property
    def rag_engine(self) -> "SimpleEngine":
        """Return the engine, building it via ``_resolve_rag_engine`` on first use."""
        if self._rag_engine is None:
            self._rag_engine = self._resolve_rag_engine()

        return self._rag_engine

    def _resolve_rag_engine(self) -> "SimpleEngine":
        """Build a ``SimpleEngine`` with a Chroma retriever and an LLM ranker.

        The rag imports are done here rather than at module level so that the
        module can be imported without them.

        Raises:
            ImportError: If the rag engine or its config classes cannot be imported.
        """
        try:
            from metagpt.rag.engines import SimpleEngine
            from metagpt.rag.schema import ChromaRetrieverConfig, LLMRankerConfig
        except ImportError as exc:
            # Chain explicitly so the original import failure stays visible as the cause.
            raise ImportError("To use the RoleZeroMemory, you need to install the rag module.") from exc

        retriever_configs = [
            ChromaRetrieverConfig(persist_path=self.persist_path, collection_name=self.collection_name)
        ]
        ranker_configs = [LLMRankerConfig()]

        return SimpleEngine.from_objs(retriever_configs=retriever_configs, ranker_configs=ranker_configs)

    def fetch(self, query: str) -> list[Message]:
        """Retrieve stored messages related to ``query``.

        Args:
            query: Text passed to the engine's ``retrieve``.

        Returns:
            A flat list holding, for every retrieved node in order, the item's
            user message followed by its AI message. Empty when ``query`` is falsy.
        """
        if not query:
            return []

        nodes: list[NodeWithScore] = self.rag_engine.retrieve(query)

        memories: list[Message] = []
        for node in nodes:
            # Each node carries the stored item under the "obj" metadata key.
            item: LongTermMemoryItem = node.metadata["obj"]
            memories.extend((item.user_message, item.ai_message))

        return memories

    def add(self, item: LongTermMemoryItem):
        """Store ``item`` through the engine's ``add_objs``; a falsy ``item`` is ignored."""
        if not item:
            return

        self.rag_engine.add_objs([item])

View file

@ -18,6 +18,7 @@ from metagpt.exp_pool import exp_cache
from metagpt.exp_pool.context_builders import RoleZeroContextBuilder
from metagpt.exp_pool.serializers import RoleZeroSerializer
from metagpt.logs import logger
from metagpt.memory.role_zero_memory import RoleZeroLongTermMemory
from metagpt.prompts.di.role_zero import (
ASK_HUMAN_COMMAND,
CMD_PROMPT,
@ -34,7 +35,7 @@ from metagpt.prompts.di.role_zero import (
THOUGHT_GUIDANCE,
)
from metagpt.roles import Role
from metagpt.schema import AIMessage, Message, UserMessage
from metagpt.schema import AIMessage, LongTermMemoryItem, Message, UserMessage
from metagpt.strategy.experience_retriever import DummyExpRetriever, ExpRetriever
from metagpt.strategy.planner import Planner
from metagpt.tools.libs.browser import Browser
@ -42,6 +43,7 @@ from metagpt.tools.libs.editor import Editor
from metagpt.tools.tool_recommend import BM25ToolRecommender, ToolRecommender
from metagpt.tools.tool_registry import register_tool
from metagpt.utils.common import CodeParser, any_to_str, extract_and_encode_images
from metagpt.utils.exceptions import handle_exception
from metagpt.utils.repair_llm_raw_output import (
RepairType,
repair_escape_error,
@ -86,6 +88,8 @@ class RoleZero(Role):
command_rsp: str = "" # the raw string containing the commands
commands: list[dict] = [] # commands to be executed
memory_k: int = 20 # number of memories (messages) to use as historical context
enable_longterm_memory: bool = True # whether to use longterm memory
longterm_memory: RoleZeroLongTermMemory = None
use_fixed_sop: bool = False
requirements_constraints: str = "" # the constraints in user requirements
use_summary: bool = True # whether to summarize at the end
@ -140,6 +144,19 @@ class RoleZero(Role):
self._update_tool_execution()
return self
@model_validator(mode="after")
def set_longterm_memory(self) -> "RoleZero":
    """Attach a default long-term memory when one is wanted but missing.

    Nothing changes if ``enable_longterm_memory`` is off or ``longterm_memory``
    is already set. Otherwise a ``RoleZeroLongTermMemory`` is created whose
    collection name is the role name with spaces removed.
    """
    if not self.enable_longterm_memory or self.longterm_memory:
        return self

    collection = self.name.replace(" ", "")
    self.longterm_memory = RoleZeroLongTermMemory(collection_name=collection)
    return self
def _update_tool_execution(self):
pass
@ -154,7 +171,7 @@ class RoleZero(Role):
return False
if not self.planner.plan.goal:
self.planner.plan.goal = self.get_memories()[-1].content
self.planner.plan.goal = self._get_all_memories()[-1].content
self.requirements_constraints = await AnalyzeRequirementsRestrictions().run(self.planner.plan.goal)
### 1. Experience ###
@ -186,7 +203,7 @@ class RoleZero(Role):
)
### Recent Observation ###
memory = self.rc.memory.get(self.memory_k)
memory = self._fetch_memories()
memory = await self.parse_browser_actions(memory)
memory = self.parse_images(memory)
@ -202,7 +219,7 @@ class RoleZero(Role):
self.command_rsp = await self._check_duplicates(req, self.command_rsp)
self.rc.memory.add(AIMessage(content=self.command_rsp))
self._add_memory(AIMessage(content=self.command_rsp))
return True
@exp_cache(context_builder=RoleZeroContextBuilder(), serializer=RoleZeroSerializer())
@ -245,12 +262,12 @@ class RoleZero(Role):
commands, ok = await self._parse_commands(self.command_rsp)
if not ok:
error_msg = commands
self.rc.memory.add(UserMessage(content=error_msg))
self._add_memory(UserMessage(content=error_msg))
return error_msg
logger.info(f"Commands: \n{commands}")
outputs = await self._run_commands(commands)
logger.info(f"Commands outputs: \n{outputs}")
self.rc.memory.add(UserMessage(content=outputs))
self._add_memory(UserMessage(content=outputs))
return AIMessage(
content=f"I have finished the task, please mark my task as finished. Outputs: {outputs}",
@ -303,7 +320,7 @@ class RoleZero(Role):
return rsp_msg, ""
# routing
memory = self.get_memories(k=self.memory_k)
memory = self._fetch_memories()
context = self.llm.format_msg(memory + [UserMessage(content=QUICK_THINK_PROMPT)])
async with ThoughtReporter() as reporter:
await reporter.async_report({"type": "classify"})
@ -328,7 +345,7 @@ class RoleZero(Role):
answer = await SearchEnhancedQA().run(query)
if answer:
self.rc.memory.add(AIMessage(content=answer, cause_by=RunCommand))
self._add_memory(AIMessage(content=answer, cause_by=RunCommand))
await self.reply_to_human(content=answer)
rsp_msg = AIMessage(
content="Complete run",
@ -339,7 +356,7 @@ class RoleZero(Role):
return rsp_msg, intent_result
async def _check_duplicates(self, req: list[dict], command_rsp: str):
past_rsp = [mem.content for mem in self.rc.memory.get(self.memory_k)]
past_rsp = [mem.content for mem in self._fetch_memories()]
if command_rsp in past_rsp:
# Normal response with thought contents are highly unlikely to reproduce
# If an identical response is detected, it is a bad response, mostly due to LLM repeating generated content
@ -479,7 +496,7 @@ class RoleZero(Role):
def _retrieve_experience(self) -> str:
"""Default implementation of experience retrieval. Can be overwritten in subclasses."""
context = [str(msg) for msg in self.rc.memory.get(self.memory_k)]
context = [str(msg) for msg in self._fetch_memories()]
context = "\n\n".join(context)
example = self.experience_retriever.retrieve(context=context)
return example
@ -504,9 +521,9 @@ class RoleZero(Role):
async def _end(self):
self._set_state(-1)
memory = self.rc.memory.get(self.memory_k)
memory = self._fetch_memories()
# Ensure reply to the human before the "end" command is executed. Hard code k=5 for checking.
if not any(["reply_to_human" in memory.content for memory in self.get_memories(k=5)]):
if not any(["reply_to_human" in memory.content for memory in self._fetch_memories(k=5)]):
logger.info("manually reply to human")
pattern = r"\[Language Restrictions\](.*?)\n"
match = re.search(pattern, self.requirements_constraints, re.DOTALL)
@ -515,10 +532,95 @@ class RoleZero(Role):
await reporter.async_report({"type": "quick"})
reply_content = await self.llm.aask(self.llm.format_msg(memory + [UserMessage(reply_to_human_prompt)]))
await self.reply_to_human(content=reply_content)
self.rc.memory.add(AIMessage(content=reply_content, cause_by=RunCommand))
self._add_memory(AIMessage(content=reply_content, cause_by=RunCommand))
outputs = ""
# Summary of the Completed Task and Deliverables
if self.use_summary:
logger.info("end current run and summarize")
outputs = await self.llm.aask(self.llm.format_msg(memory + [UserMessage(SUMMARY_PROMPT)]))
return outputs
def _get_all_memories(self) -> list[Message]:
return self._fetch_memories(k=0)
def _fetch_memories(self, k: Optional[int] = None) -> list[Message]:
"""Fetches recent memories and optionally combines them with related long-term memories.
If long-term memory is not enabled or the last message is not from the user,
it returns the recent memories without fetching from long-term memory.
Args:
k (Optional[int]): The number of recent memories to fetch. If None, defaults to self.memory_k.
Returns:
List[Message]: A list of messages representing the combined memories.
"""
if k is None:
k = self.memory_k
memories = self.rc.memory.get(k)
if not self._should_use_longterm_memory(k=k, k_memories=memories):
return memories
related_memories = self.longterm_memory.fetch(memories[-1].content)
logger.info(f"Fetched {len(related_memories)} long-term memories.")
if related_memories and self._is_first_message_from_ai(memories):
memories = memories[1:]
final_memories = related_memories + memories
return final_memories
def _add_memory(self, message: Message):
self.rc.memory.add(message)
if not self._should_use_longterm_memory():
return
self._transfer_to_longterm_memory()
def _should_use_longterm_memory(self, k: int = None, k_memories: list[Message] = None) -> bool:
"""Determines if long-term memory should be used.
Long-term memory is used if:
- k is not 0.
- k_memories is None or k_memories is not empty, and the last message is a user message.
- Long-term memory usage is enabled.
- The count of recent memories is greater than self.memory_k.
"""
conds = [
k != 0,
k_memories is None or self._is_last_message_from_user(k_memories),
self.enable_longterm_memory,
self.rc.memory.count() > self.memory_k,
]
return all(conds)
def _transfer_to_longterm_memory(self):
item = self._get_longterm_memory_item()
self.longterm_memory.add(item)
@handle_exception
def _get_longterm_memory_item(self) -> Optional[LongTermMemoryItem]:
    """Retrieves the most recent pair of user and AI messages before the last k messages.

    The message at position ``-(memory_k + 1)`` must be an AI message; the one
    directly before it is taken as the user message of the pair.

    Returns:
        The paired item, or None when the message at ``-(memory_k + 1)`` is not
        an AI message.
    """
    # NOTE(review): positions may be out of range for short memories; presumably
    # `handle_exception` deals with that - confirm against its implementation.
    window = self.memory_k

    ai_message = self.rc.memory.get_by_position(-(window + 1))
    if not ai_message.is_ai_message():
        return None

    user_message = self.rc.memory.get_by_position(-(window + 2))
    return LongTermMemoryItem(user_message=user_message, ai_message=ai_message)
def _is_last_message_from_user(self, memories: list[Message]) -> bool:
return bool(memories and memories[-1].is_user_message())
def _is_first_message_from_ai(self, memories: list[Message]) -> bool:
return bool(memories and memories[0].is_ai_message())

View file

@ -408,6 +408,12 @@ class Message(BaseModel):
dynamic_class = create_model(class_name, **{key: (value.__class__, ...) for key, value in kvs.items()})
return dynamic_class.model_validate(kvs)
def is_user_message(self):
    """Report whether this message's ``role`` equals ``"user"``."""
    role = self.role
    return role == "user"
def is_ai_message(self):
    """Report whether this message's ``role`` equals ``"assistant"``."""
    role = self.role
    return role == "assistant"
class UserMessage(Message):
"""便于支持OpenAI的消息
@ -955,3 +961,11 @@ class BaseEnum(Enum):
obj._value_ = value
obj.desc = desc
return obj
class LongTermMemoryItem(BaseModel):
    """A user message paired with an AI message."""

    user_message: Message
    ai_message: Message

    def rag_key(self) -> str:
        """Return the content of the user message.

        NOTE(review): presumably the RAG engine indexes items by this key - confirm.
        """
        key = self.user_message.content
        return key