class RoleZeroLongTermMemory(BaseModel):
    """Long-term memory for a role, backed by a RAG engine.

    Stores (user message, AI message) pairs and retrieves the pairs most
    relevant to a query, letting a role recall context that has scrolled
    out of its recent-memory window.
    """

    # Directory where the Chroma collection is persisted on disk.
    persist_path: str = Field(default=".role_memory_data", description="The directory to save data.")
    # Collection name; typically the role name so each role gets its own store.
    collection_name: str = Field(default="role_zero", description="The name of the collection, such as the role name.")

    # Lazily created SimpleEngine; typed as Any so importing this module
    # does not require the optional rag dependency.
    _rag_engine: Any = None

    @property
    def rag_engine(self) -> "SimpleEngine":
        """Return the RAG engine, creating it on first access."""
        if self._rag_engine is None:
            self._rag_engine = self._resolve_rag_engine()

        return self._rag_engine

    def _resolve_rag_engine(self) -> "SimpleEngine":
        """Build a SimpleEngine with a Chroma retriever and an LLM ranker.

        The rag modules are imported lazily so the optional dependency is
        only required when long-term memory is actually used.

        Raises:
            ImportError: If the optional rag module is not installed.
        """
        try:
            from metagpt.rag.engines import SimpleEngine
            from metagpt.rag.schema import ChromaRetrieverConfig, LLMRankerConfig
        except ImportError as e:
            # Fix: name the actual class in the message and chain the cause.
            raise ImportError("To use RoleZeroLongTermMemory, you need to install the rag module.") from e

        retriever_configs = [
            ChromaRetrieverConfig(persist_path=self.persist_path, collection_name=self.collection_name)
        ]
        ranker_configs = [LLMRankerConfig()]

        return SimpleEngine.from_objs(retriever_configs=retriever_configs, ranker_configs=ranker_configs)

    def fetch(self, query: str) -> list[Message]:
        """Retrieve messages related to `query` from long-term memory.

        Each retrieved node carries a LongTermMemoryItem under the "obj"
        metadata key; its user and AI messages are appended in order.

        Returns:
            A flat list of messages; empty when `query` is empty.
        """
        if not query:
            return []

        nodes: list[NodeWithScore] = self.rag_engine.retrieve(query)

        memories = []
        for node in nodes:
            item: LongTermMemoryItem = node.metadata["obj"]
            memories.append(item.user_message)
            memories.append(item.ai_message)

        return memories

    def add(self, item: LongTermMemoryItem):
        """Index `item` into the RAG engine; no-op when `item` is falsy."""
        if not item:
            return

        self.rag_engine.add_objs([item])
import Planner from metagpt.tools.libs.browser import Browser @@ -42,6 +43,7 @@ from metagpt.tools.libs.editor import Editor from metagpt.tools.tool_recommend import BM25ToolRecommender, ToolRecommender from metagpt.tools.tool_registry import register_tool from metagpt.utils.common import CodeParser, any_to_str, extract_and_encode_images +from metagpt.utils.exceptions import handle_exception from metagpt.utils.repair_llm_raw_output import ( RepairType, repair_escape_error, @@ -86,6 +88,8 @@ class RoleZero(Role): command_rsp: str = "" # the raw string containing the commands commands: list[dict] = [] # commands to be executed memory_k: int = 20 # number of memories (messages) to use as historical context + enable_longterm_memory: bool = True # whether to use longterm memory + longterm_memory: RoleZeroLongTermMemory = None use_fixed_sop: bool = False requirements_constraints: str = "" # the constraints in user requirements use_summary: bool = True # whether to summarize at the end @@ -140,6 +144,19 @@ class RoleZero(Role): self._update_tool_execution() return self + @model_validator(mode="after") + def set_longterm_memory(self) -> "RoleZero": + """Set longterm memory. + + If enable_longterm_memory is True and longterm_memory is not set, set it. + The role name will be used as the collection name. + """ + + if self.enable_longterm_memory and not self.longterm_memory: + self.longterm_memory = RoleZeroLongTermMemory(collection_name=self.name.replace(" ", "")) + + return self + def _update_tool_execution(self): pass @@ -154,7 +171,7 @@ class RoleZero(Role): return False if not self.planner.plan.goal: - self.planner.plan.goal = self.get_memories()[-1].content + self.planner.plan.goal = self._get_all_memories()[-1].content self.requirements_constraints = await AnalyzeRequirementsRestrictions().run(self.planner.plan.goal) ### 1. 
Experience ### @@ -186,7 +203,7 @@ class RoleZero(Role): ) ### Recent Observation ### - memory = self.rc.memory.get(self.memory_k) + memory = self._fetch_memories() memory = await self.parse_browser_actions(memory) memory = self.parse_images(memory) @@ -202,7 +219,7 @@ class RoleZero(Role): self.command_rsp = await self._check_duplicates(req, self.command_rsp) - self.rc.memory.add(AIMessage(content=self.command_rsp)) + self._add_memory(AIMessage(content=self.command_rsp)) return True @exp_cache(context_builder=RoleZeroContextBuilder(), serializer=RoleZeroSerializer()) @@ -245,12 +262,12 @@ class RoleZero(Role): commands, ok = await self._parse_commands(self.command_rsp) if not ok: error_msg = commands - self.rc.memory.add(UserMessage(content=error_msg)) + self._add_memory(UserMessage(content=error_msg)) return error_msg logger.info(f"Commands: \n{commands}") outputs = await self._run_commands(commands) logger.info(f"Commands outputs: \n{outputs}") - self.rc.memory.add(UserMessage(content=outputs)) + self._add_memory(UserMessage(content=outputs)) return AIMessage( content=f"I have finished the task, please mark my task as finished. 
Outputs: {outputs}", @@ -303,7 +320,7 @@ class RoleZero(Role): return rsp_msg, "" # routing - memory = self.get_memories(k=self.memory_k) + memory = self._fetch_memories() context = self.llm.format_msg(memory + [UserMessage(content=QUICK_THINK_PROMPT)]) async with ThoughtReporter() as reporter: await reporter.async_report({"type": "classify"}) @@ -328,7 +345,7 @@ class RoleZero(Role): answer = await SearchEnhancedQA().run(query) if answer: - self.rc.memory.add(AIMessage(content=answer, cause_by=RunCommand)) + self._add_memory(AIMessage(content=answer, cause_by=RunCommand)) await self.reply_to_human(content=answer) rsp_msg = AIMessage( content="Complete run", @@ -339,7 +356,7 @@ class RoleZero(Role): return rsp_msg, intent_result async def _check_duplicates(self, req: list[dict], command_rsp: str): - past_rsp = [mem.content for mem in self.rc.memory.get(self.memory_k)] + past_rsp = [mem.content for mem in self._fetch_memories()] if command_rsp in past_rsp: # Normal response with thought contents are highly unlikely to reproduce # If an identical response is detected, it is a bad response, mostly due to LLM repeating generated content @@ -479,7 +496,7 @@ class RoleZero(Role): def _retrieve_experience(self) -> str: """Default implementation of experience retrieval. Can be overwritten in subclasses.""" - context = [str(msg) for msg in self.rc.memory.get(self.memory_k)] + context = [str(msg) for msg in self._fetch_memories()] context = "\n\n".join(context) example = self.experience_retriever.retrieve(context=context) return example @@ -504,9 +521,9 @@ class RoleZero(Role): async def _end(self): self._set_state(-1) - memory = self.rc.memory.get(self.memory_k) + memory = self._fetch_memories() # Ensure reply to the human before the "end" command is executed. Hard code k=5 for checking. 
- if not any(["reply_to_human" in memory.content for memory in self.get_memories(k=5)]): + if not any(["reply_to_human" in memory.content for memory in self._fetch_memories(k=5)]): logger.info("manually reply to human") pattern = r"\[Language Restrictions\](.*?)\n" match = re.search(pattern, self.requirements_constraints, re.DOTALL) @@ -515,10 +532,95 @@ class RoleZero(Role): await reporter.async_report({"type": "quick"}) reply_content = await self.llm.aask(self.llm.format_msg(memory + [UserMessage(reply_to_human_prompt)])) await self.reply_to_human(content=reply_content) - self.rc.memory.add(AIMessage(content=reply_content, cause_by=RunCommand)) + self._add_memory(AIMessage(content=reply_content, cause_by=RunCommand)) outputs = "" # Summary of the Completed Task and Deliverables if self.use_summary: logger.info("end current run and summarize") outputs = await self.llm.aask(self.llm.format_msg(memory + [UserMessage(SUMMARY_PROMPT)])) return outputs + + def _get_all_memories(self) -> list[Message]: + return self._fetch_memories(k=0) + + def _fetch_memories(self, k: Optional[int] = None) -> list[Message]: + """Fetches recent memories and optionally combines them with related long-term memories. + + If long-term memory is not enabled or the last message is not from the user, + it returns the recent memories without fetching from long-term memory. + + Args: + k (Optional[int]): The number of recent memories to fetch. If None, defaults to self.memory_k. + + Returns: + List[Message]: A list of messages representing the combined memories. 
+ """ + + if k is None: + k = self.memory_k + + memories = self.rc.memory.get(k) + + if not self._should_use_longterm_memory(k=k, k_memories=memories): + return memories + + related_memories = self.longterm_memory.fetch(memories[-1].content) + logger.info(f"Fetched {len(related_memories)} long-term memories.") + + if related_memories and self._is_first_message_from_ai(memories): + memories = memories[1:] + + final_memories = related_memories + memories + + return final_memories + + def _add_memory(self, message: Message): + self.rc.memory.add(message) + + if not self._should_use_longterm_memory(): + return + + self._transfer_to_longterm_memory() + + def _should_use_longterm_memory(self, k: int = None, k_memories: list[Message] = None) -> bool: + """Determines if long-term memory should be used. + + Long-term memory is used if: + - k is not 0. + - k_memories is None or k_memories is not empty, and the last message is a user message. + - Long-term memory usage is enabled. + - The count of recent memories is greater than self.memory_k. 
+ """ + + conds = [ + k != 0, + k_memories is None or self._is_last_message_from_user(k_memories), + self.enable_longterm_memory, + self.rc.memory.count() > self.memory_k, + ] + + return all(conds) + + def _transfer_to_longterm_memory(self): + item = self._get_longterm_memory_item() + self.longterm_memory.add(item) + + @handle_exception + def _get_longterm_memory_item(self) -> Optional[LongTermMemoryItem]: + """Retrieves the most recent pair of user and AI messages before the last k messages.""" + + index = -(self.memory_k + 1) + message = self.rc.memory.get_by_position(index) + if not message.is_ai_message(): + return None + + index = -(self.memory_k + 2) + user_message = self.rc.memory.get_by_position(index) + + return LongTermMemoryItem(user_message=user_message, ai_message=message) + + def _is_last_message_from_user(self, memories: list[Message]) -> bool: + return bool(memories and memories[-1].is_user_message()) + + def _is_first_message_from_ai(self, memories: list[Message]) -> bool: + return bool(memories and memories[0].is_ai_message()) diff --git a/metagpt/schema.py b/metagpt/schema.py index ce64d130a..63a80f62a 100644 --- a/metagpt/schema.py +++ b/metagpt/schema.py @@ -408,6 +408,12 @@ class Message(BaseModel): dynamic_class = create_model(class_name, **{key: (value.__class__, ...) 
class LongTermMemoryItem(BaseModel):
    """A user/AI message pair persisted as one unit in long-term memory."""

    user_message: Message
    ai_message: Message

    def rag_key(self) -> str:
        """Return the text the RAG engine indexes this item by: the user's content."""
        return self.user_message.content
mock_memory.rag_engine.retrieve.assert_called_once_with("test query") + + def test_add_empty_item(self, mock_memory: RoleZeroLongTermMemory): + mock_memory.add(None) + mock_memory.rag_engine.add_objs.assert_not_called() + + def test_add_item(self, mock_memory: RoleZeroLongTermMemory): + item = LongTermMemoryItem(user_message=UserMessage(content="user"), ai_message=AIMessage(content="ai")) + mock_memory.add(item) + mock_memory.rag_engine.add_objs.assert_called_once_with([item]) diff --git a/tests/metagpt/roles/di/test_role_zero.py b/tests/metagpt/roles/di/test_role_zero.py new file mode 100644 index 000000000..d4d4a46da --- /dev/null +++ b/tests/metagpt/roles/di/test_role_zero.py @@ -0,0 +1,192 @@ +import pytest + +from metagpt.roles.di.role_zero import RoleZero +from metagpt.schema import AIMessage, LongTermMemoryItem, Message, UserMessage + + +class TestRoleZero: + @pytest.fixture + def mock_role_zero(self, mocker): + role_zero = RoleZero() + role_zero.rc.memory = mocker.Mock() + role_zero.longterm_memory = mocker.Mock() + return role_zero + + def test_get_all_memories(self, mocker, mock_role_zero: RoleZero): + mock_memories = [Message(content="test1"), Message(content="test2")] + mock_role_zero._fetch_memories = mocker.Mock(return_value=mock_memories) + + result = mock_role_zero._get_all_memories() + + assert result == mock_memories + mock_role_zero._fetch_memories.assert_called_once_with(k=0) + + @pytest.mark.parametrize( + "k,should_use_ltm,memories,related_memories,is_first_from_ai,expected", + [ + ( + None, + False, + [UserMessage(content="user"), AIMessage(content="ai")], + [], + False, + [UserMessage(content="user"), AIMessage(content="ai")], + ), + ( + 1, + True, + [UserMessage(content="user")], + [Message(content="related")], + False, + [Message(content="related"), UserMessage(content="user")], + ), + ( + None, + True, + [AIMessage(content="ai1"), UserMessage(content="user"), AIMessage(content="ai2")], + [Message(content="related")], + True, + 
[Message(content="related"), UserMessage(content="user"), AIMessage(content="ai2")], + ), + ( + None, + True, + [UserMessage(content="user"), AIMessage(content="ai")], + [Message(content="related")], + False, + [Message(content="related"), UserMessage(content="user"), AIMessage(content="ai")], + ), + ( + 0, + False, + [UserMessage(content="user"), AIMessage(content="ai")], + [], + False, + [UserMessage(content="user"), AIMessage(content="ai")], + ), + ], + ) + def test_fetch_memories( + self, + mocker, + mock_role_zero: RoleZero, + k, + should_use_ltm, + memories, + related_memories, + is_first_from_ai, + expected, + ): + mock_role_zero.memory_k = 2 + mock_role_zero.rc.memory.get = mocker.Mock(return_value=memories) + mock_role_zero._should_use_longterm_memory = mocker.Mock(return_value=should_use_ltm) + mock_role_zero.longterm_memory.fetch = mocker.Mock(return_value=related_memories) + mock_role_zero._is_first_message_from_ai = mocker.Mock(return_value=is_first_from_ai) + + result = mock_role_zero._fetch_memories(k) + + assert len(result) == len(expected) + for actual, expected_msg in zip(result, expected): + assert actual.role == expected_msg.role + assert actual.content == expected_msg.content + + really_k = k if k is not None else mock_role_zero.memory_k + mock_role_zero.rc.memory.get.assert_called_once_with(really_k) + + if k != 0: + mock_role_zero._should_use_longterm_memory.assert_called_once_with(k=really_k, k_memories=memories) + + if should_use_ltm: + mock_role_zero.longterm_memory.fetch.assert_called_once_with(memories[-1].content) + mock_role_zero._is_first_message_from_ai.assert_called_once_with(memories) + + def test_add_memory(self, mocker, mock_role_zero: RoleZero): + message = AIMessage(content="ai") + mock_role_zero.rc.memory.add = mocker.Mock() + mock_role_zero._should_use_longterm_memory = mocker.Mock(return_value=True) + mock_role_zero._transfer_to_longterm_memory = mocker.Mock() + + mock_role_zero._add_memory(message) + + 
mock_role_zero.rc.memory.add.assert_called_once_with(message) + mock_role_zero._transfer_to_longterm_memory.assert_called_once() + + @pytest.mark.parametrize( + "k,enable_longterm_memory,memory_count,k_memories,expected", + [ + (0, True, 30, None, False), # k is 0 + (None, False, 30, None, False), # Long-term memory usage is disabled + (None, True, 10, None, False), # Memory count is less than or equal to memory_k + (None, True, 30, [], False), # k_memories is empty + (None, True, 30, [AIMessage(content="ai")], False), # Last message in k_memories is not a user message + (None, True, 30, [AIMessage(content="ai"), UserMessage(content="user")], True), # All conditions are met + ], + ) + def test_should_use_longterm_memory( + self, mocker, mock_role_zero: RoleZero, k, enable_longterm_memory, memory_count, k_memories, expected + ): + mock_role_zero.enable_longterm_memory = enable_longterm_memory + mock_role_zero.rc.memory.count = mocker.Mock(return_value=memory_count) + mock_role_zero.memory_k = 20 + + result = mock_role_zero._should_use_longterm_memory(k, k_memories) + + assert result == expected + + def test_transfer_to_longterm_memory(self, mocker, mock_role_zero: RoleZero): + mock_item = LongTermMemoryItem(user_message=UserMessage(content="user"), ai_message=AIMessage(content="ai")) + mock_role_zero._get_longterm_memory_item = mocker.Mock(return_value=mock_item) + mock_role_zero.longterm_memory = mocker.Mock() + + mock_role_zero._transfer_to_longterm_memory() + + mock_role_zero.longterm_memory.add.assert_called_once_with(mock_item) + + def test_get_longterm_memory_item(self, mocker, mock_role_zero: RoleZero): + mock_role_zero.memory_k = 2 + mock_messages = [ + UserMessage(content="user1"), + AIMessage(content="ai1"), + UserMessage(content="user2"), + AIMessage(content="ai2"), + UserMessage(content="user3"), # memory_k + 2 + AIMessage(content="ai3"), # memory_k + 1 + UserMessage(content="recent1"), + AIMessage(content="recent2"), + ] + + 
mock_role_zero.rc.memory.get_by_position = mocker.Mock(side_effect=lambda i: mock_messages[i]) + mock_role_zero.rc.memory.count = mocker.Mock(return_value=len(mock_messages)) + + result = mock_role_zero._get_longterm_memory_item() + + assert isinstance(result, LongTermMemoryItem) + assert result.user_message.content == "user3" + assert result.ai_message.content == "ai3" + + mock_role_zero.rc.memory.get_by_position.assert_any_call(-(mock_role_zero.memory_k + 1)) + mock_role_zero.rc.memory.get_by_position.assert_any_call(-(mock_role_zero.memory_k + 2)) + + @pytest.mark.parametrize( + "memories,expected", + [ + ([UserMessage(content="user")], True), + ([AIMessage(content="ai")], False), + ([], False), + ], + ) + def test_is_last_message_from_user(self, mock_role_zero: RoleZero, memories, expected): + result = mock_role_zero._is_last_message_from_user(memories) + assert result == expected + + @pytest.mark.parametrize( + "memories,expected", + [ + ([AIMessage(content="ai")], True), + ([UserMessage(content="user")], False), + ([], False), + ], + ) + def test_is_first_message_from_ai(self, mock_role_zero: RoleZero, memories, expected): + result = mock_role_zero._is_first_message_from_ai(memories) + assert result == expected