From f20d3ca2ddac5908c4b4c635c63253e6954251fa Mon Sep 17 00:00:00 2001 From: seehi <6580@pm.me> Date: Fri, 6 Sep 2024 10:23:07 +0800 Subject: [PATCH 01/18] use long-term memory in rolezero --- .gitignore | 1 + metagpt/base/base_role.py | 8 +- metagpt/memory/memory.py | 4 + metagpt/memory/role_zero_memory.py | 60 ++++++ metagpt/roles/di/role_zero.py | 128 ++++++++++-- metagpt/schema.py | 14 ++ tests/metagpt/memory/test_role_zero_memory.py | 48 +++++ tests/metagpt/roles/di/test_role_zero.py | 192 ++++++++++++++++++ 8 files changed, 440 insertions(+), 15 deletions(-) create mode 100644 metagpt/memory/role_zero_memory.py create mode 100644 tests/metagpt/memory/test_role_zero_memory.py create mode 100644 tests/metagpt/roles/di/test_role_zero.py diff --git a/.gitignore b/.gitignore index 24dd046be..73f32f75d 100644 --- a/.gitignore +++ b/.gitignore @@ -163,6 +163,7 @@ examples/image__vector_store.json examples/index_store.json .chroma .chroma_exp_data +.role_memory_data *~$* workspace/* tmp diff --git a/metagpt/base/base_role.py b/metagpt/base/base_role.py index b500b2cd6..55b8f6de2 100644 --- a/metagpt/base/base_role.py +++ b/metagpt/base/base_role.py @@ -1,8 +1,12 @@ +from __future__ import annotations + from abc import abstractmethod -from typing import Optional, Union +from typing import TYPE_CHECKING, Optional, Union from metagpt.base.base_serialization import BaseSerialization -from metagpt.schema import Message + +if TYPE_CHECKING: + from metagpt.schema import Message class BaseRole(BaseSerialization): diff --git a/metagpt/memory/memory.py b/metagpt/memory/memory.py index 580361d33..d44753413 100644 --- a/metagpt/memory/memory.py +++ b/metagpt/memory/memory.py @@ -104,3 +104,7 @@ class Memory(BaseModel): continue rsp += self.index[action] return rsp + + def get_by_position(self, position: int) -> Message: + """Return the message by its position""" + return self.storage[position] diff --git a/metagpt/memory/role_zero_memory.py b/metagpt/memory/role_zero_memory.py new file mode 100644 index 000000000..d02e31e7f --- /dev/null +++ b/metagpt/memory/role_zero_memory.py @@ -0,0 +1,60 @@ +from typing import TYPE_CHECKING, Any + +from pydantic import BaseModel, Field + +from metagpt.schema import LongTermMemoryItem, Message + +if TYPE_CHECKING: + from llama_index.core.schema import NodeWithScore + + from metagpt.rag.engines import SimpleEngine + + +class RoleZeroLongTermMemory(BaseModel): + persist_path: str = Field(default=".role_memory_data", description="The directory to save data.") + collection_name: str = Field(default="role_zero", description="The name of the collection, such as the role name.") + + _rag_engine: Any = None + + @property + def rag_engine(self) -> "SimpleEngine": + if self._rag_engine is None: + self._rag_engine = self._resolve_rag_engine() + + return self._rag_engine + + def _resolve_rag_engine(self) -> "SimpleEngine": + try: + from metagpt.rag.engines import SimpleEngine + from metagpt.rag.schema import ChromaRetrieverConfig, LLMRankerConfig + except ImportError: + raise ImportError("To use the RoleZeroMemory, you need to install the rag module.") + + retriever_configs = [ + ChromaRetrieverConfig(persist_path=self.persist_path, collection_name=self.collection_name) + ] + ranker_configs = [LLMRankerConfig()] + + rag_engine = SimpleEngine.from_objs(retriever_configs=retriever_configs, ranker_configs=ranker_configs) + + return rag_engine + + def fetch(self, query: str) -> list[Message]: + if not query: + return [] + + nodes: list[NodeWithScore] = 
self.rag_engine.retrieve(query) + + memories = [] + for node in nodes: + item: LongTermMemoryItem = node.metadata["obj"] + memories.append(item.user_message) + memories.append(item.ai_message) + + return memories + + def add(self, item: LongTermMemoryItem): + if not item: + return + + self.rag_engine.add_objs([item]) diff --git a/metagpt/roles/di/role_zero.py b/metagpt/roles/di/role_zero.py index 5edd7f88c..738de2244 100644 --- a/metagpt/roles/di/role_zero.py +++ b/metagpt/roles/di/role_zero.py @@ -18,6 +18,7 @@ from metagpt.exp_pool import exp_cache from metagpt.exp_pool.context_builders import RoleZeroContextBuilder from metagpt.exp_pool.serializers import RoleZeroSerializer from metagpt.logs import logger +from metagpt.memory.role_zero_memory import RoleZeroLongTermMemory from metagpt.prompts.di.role_zero import ( ASK_HUMAN_COMMAND, CMD_PROMPT, @@ -34,7 +35,7 @@ from metagpt.prompts.di.role_zero import ( THOUGHT_GUIDANCE, ) from metagpt.roles import Role -from metagpt.schema import AIMessage, Message, UserMessage +from metagpt.schema import AIMessage, LongTermMemoryItem, Message, UserMessage from metagpt.strategy.experience_retriever import DummyExpRetriever, ExpRetriever from metagpt.strategy.planner import Planner from metagpt.tools.libs.browser import Browser @@ -42,6 +43,7 @@ from metagpt.tools.libs.editor import Editor from metagpt.tools.tool_recommend import BM25ToolRecommender, ToolRecommender from metagpt.tools.tool_registry import register_tool from metagpt.utils.common import CodeParser, any_to_str, extract_and_encode_images +from metagpt.utils.exceptions import handle_exception from metagpt.utils.repair_llm_raw_output import ( RepairType, repair_escape_error, @@ -86,6 +88,8 @@ class RoleZero(Role): command_rsp: str = "" # the raw string containing the commands commands: list[dict] = [] # commands to be executed memory_k: int = 20 # number of memories (messages) to use as historical context + enable_longterm_memory: bool = True # whether to use longterm memory + longterm_memory: RoleZeroLongTermMemory = None use_fixed_sop: bool = False requirements_constraints: str = "" # the constraints in user requirements use_summary: bool = True # whether to summarize at the end @@ -140,6 +144,19 @@ class RoleZero(Role): self._update_tool_execution() return self + @model_validator(mode="after") + def set_longterm_memory(self) -> "RoleZero": + """Set longterm memory. + + If enable_longterm_memory is True and longterm_memory is not set, set it. + The role name will be used as the collection name. + """ + + if self.enable_longterm_memory and not self.longterm_memory: + self.longterm_memory = RoleZeroLongTermMemory(collection_name=self.name.replace(" ", "")) + + return self + def _update_tool_execution(self): pass @@ -154,7 +171,7 @@ class RoleZero(Role): return False if not self.planner.plan.goal: - self.planner.plan.goal = self.get_memories()[-1].content + self.planner.plan.goal = self._get_all_memories()[-1].content self.requirements_constraints = await AnalyzeRequirementsRestrictions().run(self.planner.plan.goal) ### 1. 
Experience ### @@ -186,7 +203,7 @@ class RoleZero(Role): ) ### Recent Observation ### - memory = self.rc.memory.get(self.memory_k) + memory = self._fetch_memories() memory = await self.parse_browser_actions(memory) memory = self.parse_images(memory) @@ -202,7 +219,7 @@ class RoleZero(Role): self.command_rsp = await self._check_duplicates(req, self.command_rsp) - self.rc.memory.add(AIMessage(content=self.command_rsp)) + self._add_memory(AIMessage(content=self.command_rsp)) return True @exp_cache(context_builder=RoleZeroContextBuilder(), serializer=RoleZeroSerializer()) @@ -245,12 +262,12 @@ class RoleZero(Role): commands, ok = await self._parse_commands(self.command_rsp) if not ok: error_msg = commands - self.rc.memory.add(UserMessage(content=error_msg)) + self._add_memory(UserMessage(content=error_msg)) return error_msg logger.info(f"Commands: \n{commands}") outputs = await self._run_commands(commands) logger.info(f"Commands outputs: \n{outputs}") - self.rc.memory.add(UserMessage(content=outputs)) + self._add_memory(UserMessage(content=outputs)) return AIMessage( content=f"I have finished the task, please mark my task as finished. Outputs: {outputs}", @@ -303,7 +320,7 @@ class RoleZero(Role): return rsp_msg, "" # routing - memory = self.get_memories(k=self.memory_k) + memory = self._fetch_memories() context = self.llm.format_msg(memory + [UserMessage(content=QUICK_THINK_PROMPT)]) async with ThoughtReporter() as reporter: await reporter.async_report({"type": "classify"}) @@ -328,7 +345,7 @@ class RoleZero(Role): answer = await SearchEnhancedQA().run(query) if answer: - self.rc.memory.add(AIMessage(content=answer, cause_by=RunCommand)) + self._add_memory(AIMessage(content=answer, cause_by=RunCommand)) await self.reply_to_human(content=answer) rsp_msg = AIMessage( content="Complete run", @@ -339,7 +356,7 @@ class RoleZero(Role): return rsp_msg, intent_result async def _check_duplicates(self, req: list[dict], command_rsp: str): - past_rsp = [mem.content for mem in self.rc.memory.get(self.memory_k)] + past_rsp = [mem.content for mem in self._fetch_memories()] if command_rsp in past_rsp: # Normal response with thought contents are highly unlikely to reproduce # If an identical response is detected, it is a bad response, mostly due to LLM repeating generated content @@ -479,7 +496,7 @@ class RoleZero(Role): def _retrieve_experience(self) -> str: """Default implementation of experience retrieval. Can be overwritten in subclasses.""" - context = [str(msg) for msg in self.rc.memory.get(self.memory_k)] + context = [str(msg) for msg in self._fetch_memories()] context = "\n\n".join(context) example = self.experience_retriever.retrieve(context=context) return example @@ -504,9 +521,9 @@ class RoleZero(Role): async def _end(self): self._set_state(-1) - memory = self.rc.memory.get(self.memory_k) + memory = self._fetch_memories() # Ensure reply to the human before the "end" command is executed. Hard code k=5 for checking. 
- if not any(["reply_to_human" in memory.content for memory in self.get_memories(k=5)]): + if not any(["reply_to_human" in memory.content for memory in self._fetch_memories(k=5)]): logger.info("manually reply to human") pattern = r"\[Language Restrictions\](.*?)\n" match = re.search(pattern, self.requirements_constraints, re.DOTALL) @@ -515,10 +532,95 @@ class RoleZero(Role): await reporter.async_report({"type": "quick"}) reply_content = await self.llm.aask(self.llm.format_msg(memory + [UserMessage(reply_to_human_prompt)])) await self.reply_to_human(content=reply_content) - self.rc.memory.add(AIMessage(content=reply_content, cause_by=RunCommand)) + self._add_memory(AIMessage(content=reply_content, cause_by=RunCommand)) outputs = "" # Summary of the Completed Task and Deliverables if self.use_summary: logger.info("end current run and summarize") outputs = await self.llm.aask(self.llm.format_msg(memory + [UserMessage(SUMMARY_PROMPT)])) return outputs + + def _get_all_memories(self) -> list[Message]: + return self._fetch_memories(k=0) + + def _fetch_memories(self, k: Optional[int] = None) -> list[Message]: + """Fetches recent memories and optionally combines them with related long-term memories. + + If long-term memory is not enabled or the last message is not from the user, + it returns the recent memories without fetching from long-term memory. + + Args: + k (Optional[int]): The number of recent memories to fetch. If None, defaults to self.memory_k. + + Returns: + List[Message]: A list of messages representing the combined memories. + """ + + if k is None: + k = self.memory_k + + memories = self.rc.memory.get(k) + + if not self._should_use_longterm_memory(k=k, k_memories=memories): + return memories + + related_memories = self.longterm_memory.fetch(memories[-1].content) + logger.info(f"Fetched {len(related_memories)} long-term memories.") + + if related_memories and self._is_first_message_from_ai(memories): + memories = memories[1:] + + final_memories = related_memories + memories + + return final_memories + + def _add_memory(self, message: Message): + self.rc.memory.add(message) + + if not self._should_use_longterm_memory(): + return + + self._transfer_to_longterm_memory() + + def _should_use_longterm_memory(self, k: int = None, k_memories: list[Message] = None) -> bool: + """Determines if long-term memory should be used. + + Long-term memory is used if: + - k is not 0. + - k_memories is None or k_memories is not empty, and the last message is a user message. + - Long-term memory usage is enabled. + - The count of recent memories is greater than self.memory_k. 
+ """ + + conds = [ + k != 0, + k_memories is None or self._is_last_message_from_user(k_memories), + self.enable_longterm_memory, + self.rc.memory.count() > self.memory_k, + ] + + return all(conds) + + def _transfer_to_longterm_memory(self): + item = self._get_longterm_memory_item() + self.longterm_memory.add(item) + + @handle_exception + def _get_longterm_memory_item(self) -> Optional[LongTermMemoryItem]: + """Retrieves the most recent pair of user and AI messages before the last k messages.""" + + index = -(self.memory_k + 1) + message = self.rc.memory.get_by_position(index) + if not message.is_ai_message(): + return None + + index = -(self.memory_k + 2) + user_message = self.rc.memory.get_by_position(index) + + return LongTermMemoryItem(user_message=user_message, ai_message=message) + + def _is_last_message_from_user(self, memories: list[Message]) -> bool: + return bool(memories and memories[-1].is_user_message()) + + def _is_first_message_from_ai(self, memories: list[Message]) -> bool: + return bool(memories and memories[0].is_ai_message()) diff --git a/metagpt/schema.py b/metagpt/schema.py index ce64d130a..63a80f62a 100644 --- a/metagpt/schema.py +++ b/metagpt/schema.py @@ -408,6 +408,12 @@ class Message(BaseModel): dynamic_class = create_model(class_name, **{key: (value.__class__, ...) for key, value in kvs.items()}) return dynamic_class.model_validate(kvs) + def is_user_message(self): + return self.role == "user" + + def is_ai_message(self): + return self.role == "assistant" + class UserMessage(Message): """便于支持OpenAI的消息 @@ -955,3 +961,11 @@ class BaseEnum(Enum): obj._value_ = value obj.desc = desc return obj + + +class LongTermMemoryItem(BaseModel): + user_message: Message + ai_message: Message + + def rag_key(self) -> str: + return self.user_message.content diff --git a/tests/metagpt/memory/test_role_zero_memory.py b/tests/metagpt/memory/test_role_zero_memory.py new file mode 100644 index 000000000..8e2532bfc --- /dev/null +++ b/tests/metagpt/memory/test_role_zero_memory.py @@ -0,0 +1,48 @@ +import pytest + +from metagpt.memory.role_zero_memory import RoleZeroLongTermMemory +from metagpt.schema import AIMessage, LongTermMemoryItem, UserMessage + + +class TestRoleZeroLongTermMemory: + @pytest.fixture + def mock_memory(self, mocker) -> RoleZeroLongTermMemory: + memory = RoleZeroLongTermMemory() + memory._resolve_rag_engine = mocker.Mock() + return memory + + def test_fetch_empty_query(self, mock_memory: RoleZeroLongTermMemory): + assert mock_memory.fetch("") == [] + + def test_fetch(self, mocker, mock_memory: RoleZeroLongTermMemory): + mock_node1 = mocker.Mock() + mock_node2 = mocker.Mock() + mock_node1.metadata = { + "obj": LongTermMemoryItem(user_message=UserMessage(content="user1"), ai_message=AIMessage(content="ai1")) + } + mock_node2.metadata = { + "obj": LongTermMemoryItem(user_message=UserMessage(content="user2"), ai_message=AIMessage(content="ai2")) + } + + mock_memory.rag_engine.retrieve.return_value = [mock_node1, mock_node2] + + result = mock_memory.fetch("test query") + + assert len(result) == 4 + assert isinstance(result[0], UserMessage) + assert isinstance(result[1], AIMessage) + assert result[0].content == "user1" + assert result[1].content == "ai1" + assert result[2].content == "user2" + assert result[3].content == "ai2" + + mock_memory.rag_engine.retrieve.assert_called_once_with("test query") + + def test_add_empty_item(self, mock_memory: RoleZeroLongTermMemory): + mock_memory.add(None) + mock_memory.rag_engine.add_objs.assert_not_called() + + def 
test_add_item(self, mock_memory: RoleZeroLongTermMemory): + item = LongTermMemoryItem(user_message=UserMessage(content="user"), ai_message=AIMessage(content="ai")) + mock_memory.add(item) + mock_memory.rag_engine.add_objs.assert_called_once_with([item]) diff --git a/tests/metagpt/roles/di/test_role_zero.py b/tests/metagpt/roles/di/test_role_zero.py new file mode 100644 index 000000000..d4d4a46da --- /dev/null +++ b/tests/metagpt/roles/di/test_role_zero.py @@ -0,0 +1,192 @@ +import pytest + +from metagpt.roles.di.role_zero import RoleZero +from metagpt.schema import AIMessage, LongTermMemoryItem, Message, UserMessage + + +class TestRoleZero: + @pytest.fixture + def mock_role_zero(self, mocker): + role_zero = RoleZero() + role_zero.rc.memory = mocker.Mock() + role_zero.longterm_memory = mocker.Mock() + return role_zero + + def test_get_all_memories(self, mocker, mock_role_zero: RoleZero): + mock_memories = [Message(content="test1"), Message(content="test2")] + mock_role_zero._fetch_memories = mocker.Mock(return_value=mock_memories) + + result = mock_role_zero._get_all_memories() + + assert result == mock_memories + mock_role_zero._fetch_memories.assert_called_once_with(k=0) + + @pytest.mark.parametrize( + "k,should_use_ltm,memories,related_memories,is_first_from_ai,expected", + [ + ( + None, + False, + [UserMessage(content="user"), AIMessage(content="ai")], + [], + False, + [UserMessage(content="user"), AIMessage(content="ai")], + ), + ( + 1, + True, + [UserMessage(content="user")], + [Message(content="related")], + False, + [Message(content="related"), UserMessage(content="user")], + ), + ( + None, + True, + [AIMessage(content="ai1"), UserMessage(content="user"), AIMessage(content="ai2")], + [Message(content="related")], + True, + [Message(content="related"), UserMessage(content="user"), AIMessage(content="ai2")], + ), + ( + None, + True, + [UserMessage(content="user"), AIMessage(content="ai")], + [Message(content="related")], + False, + [Message(content="related"), UserMessage(content="user"), AIMessage(content="ai")], + ), + ( + 0, + False, + [UserMessage(content="user"), AIMessage(content="ai")], + [], + False, + [UserMessage(content="user"), AIMessage(content="ai")], + ), + ], + ) + def test_fetch_memories( + self, + mocker, + mock_role_zero: RoleZero, + k, + should_use_ltm, + memories, + related_memories, + is_first_from_ai, + expected, + ): + mock_role_zero.memory_k = 2 + mock_role_zero.rc.memory.get = mocker.Mock(return_value=memories) + mock_role_zero._should_use_longterm_memory = mocker.Mock(return_value=should_use_ltm) + mock_role_zero.longterm_memory.fetch = mocker.Mock(return_value=related_memories) + mock_role_zero._is_first_message_from_ai = mocker.Mock(return_value=is_first_from_ai) + + result = mock_role_zero._fetch_memories(k) + + assert len(result) == len(expected) + for actual, expected_msg in zip(result, expected): + assert actual.role == expected_msg.role + assert actual.content == expected_msg.content + + really_k = k if k is not None else mock_role_zero.memory_k + mock_role_zero.rc.memory.get.assert_called_once_with(really_k) + + if k != 0: + mock_role_zero._should_use_longterm_memory.assert_called_once_with(k=really_k, k_memories=memories) + + if should_use_ltm: + mock_role_zero.longterm_memory.fetch.assert_called_once_with(memories[-1].content) + mock_role_zero._is_first_message_from_ai.assert_called_once_with(memories) + + def test_add_memory(self, mocker, mock_role_zero: RoleZero): + message = AIMessage(content="ai") + mock_role_zero.rc.memory.add = mocker.Mock() 
+ mock_role_zero._should_use_longterm_memory = mocker.Mock(return_value=True) + mock_role_zero._transfer_to_longterm_memory = mocker.Mock() + + mock_role_zero._add_memory(message) + + mock_role_zero.rc.memory.add.assert_called_once_with(message) + mock_role_zero._transfer_to_longterm_memory.assert_called_once() + + @pytest.mark.parametrize( + "k,enable_longterm_memory,memory_count,k_memories,expected", + [ + (0, True, 30, None, False), # k is 0 + (None, False, 30, None, False), # Long-term memory usage is disabled + (None, True, 10, None, False), # Memory count is less than or equal to memory_k + (None, True, 30, [], False), # k_memories is empty + (None, True, 30, [AIMessage(content="ai")], False), # Last message in k_memories is not a user message + (None, True, 30, [AIMessage(content="ai"), UserMessage(content="user")], True), # All conditions are met + ], + ) + def test_should_use_longterm_memory( + self, mocker, mock_role_zero: RoleZero, k, enable_longterm_memory, memory_count, k_memories, expected + ): + mock_role_zero.enable_longterm_memory = enable_longterm_memory + mock_role_zero.rc.memory.count = mocker.Mock(return_value=memory_count) + mock_role_zero.memory_k = 20 + + result = mock_role_zero._should_use_longterm_memory(k, k_memories) + + assert result == expected + + def test_transfer_to_longterm_memory(self, mocker, mock_role_zero: RoleZero): + mock_item = LongTermMemoryItem(user_message=UserMessage(content="user"), ai_message=AIMessage(content="ai")) + mock_role_zero._get_longterm_memory_item = mocker.Mock(return_value=mock_item) + mock_role_zero.longterm_memory = mocker.Mock() + + mock_role_zero._transfer_to_longterm_memory() + + mock_role_zero.longterm_memory.add.assert_called_once_with(mock_item) + + def test_get_longterm_memory_item(self, mocker, mock_role_zero: RoleZero): + mock_role_zero.memory_k = 2 + mock_messages = [ + UserMessage(content="user1"), + AIMessage(content="ai1"), + UserMessage(content="user2"), + AIMessage(content="ai2"), + UserMessage(content="user3"), # memory_k + 2 + AIMessage(content="ai3"), # memory_k + 1 + UserMessage(content="recent1"), + AIMessage(content="recent2"), + ] + + mock_role_zero.rc.memory.get_by_position = mocker.Mock(side_effect=lambda i: mock_messages[i]) + mock_role_zero.rc.memory.count = mocker.Mock(return_value=len(mock_messages)) + + result = mock_role_zero._get_longterm_memory_item() + + assert isinstance(result, LongTermMemoryItem) + assert result.user_message.content == "user3" + assert result.ai_message.content == "ai3" + + mock_role_zero.rc.memory.get_by_position.assert_any_call(-(mock_role_zero.memory_k + 1)) + mock_role_zero.rc.memory.get_by_position.assert_any_call(-(mock_role_zero.memory_k + 2)) + + @pytest.mark.parametrize( + "memories,expected", + [ + ([UserMessage(content="user")], True), + ([AIMessage(content="ai")], False), + ([], False), + ], + ) + def test_is_last_message_from_user(self, mock_role_zero: RoleZero, memories, expected): + result = mock_role_zero._is_last_message_from_user(memories) + assert result == expected + + @pytest.mark.parametrize( + "memories,expected", + [ + ([AIMessage(content="ai")], True), + ([UserMessage(content="user")], False), + ([], False), + ], + ) + def test_is_first_message_from_ai(self, mock_role_zero: RoleZero, memories, expected): + result = mock_role_zero._is_first_message_from_ai(memories) + assert result == expected From 0f7630ef66bc3ed6fa57b9ec1b62c47ecd13c3f4 Mon Sep 17 00:00:00 2001 From: seehi <6580@pm.me> Date: Fri, 6 Sep 2024 10:41:43 +0800 Subject: [PATCH 02/18] fix 
conflict --- metagpt/base/base_role.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/metagpt/base/base_role.py b/metagpt/base/base_role.py index e77819d76..1f7f00fa2 100644 --- a/metagpt/base/base_role.py +++ b/metagpt/base/base_role.py @@ -1,13 +1,8 @@ -from __future__ import annotations - from abc import abstractmethod -from typing import TYPE_CHECKING, Optional, Union +from typing import Optional, Union from metagpt.base.base_serialization import BaseSerialization -if TYPE_CHECKING: - from metagpt.schema import Message - class BaseRole(BaseSerialization): """Abstract base class for all roles.""" From 53ef7be68c425d5d1f73ff6bc03f8b108733df9b Mon Sep 17 00:00:00 2001 From: seehi <6580@pm.me> Date: Fri, 6 Sep 2024 11:30:49 +0800 Subject: [PATCH 03/18] update comment --- metagpt/memory/role_zero_memory.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/metagpt/memory/role_zero_memory.py b/metagpt/memory/role_zero_memory.py index d02e31e7f..570d0cc41 100644 --- a/metagpt/memory/role_zero_memory.py +++ b/metagpt/memory/role_zero_memory.py @@ -24,6 +24,11 @@ class RoleZeroLongTermMemory(BaseModel): return self._rag_engine def _resolve_rag_engine(self) -> "SimpleEngine": + """Lazy loading of the RAG engine components, ensuring they are only loaded when needed. + + It uses `Chroma` for retrieval and `LLMRanker` for ranking. + """ + try: from metagpt.rag.engines import SimpleEngine from metagpt.rag.schema import ChromaRetrieverConfig, LLMRankerConfig @@ -40,6 +45,15 @@ class RoleZeroLongTermMemory(BaseModel): return rag_engine def fetch(self, query: str) -> list[Message]: + """Fetches long-term memories based on a query. + + Args: + query (str): The query string to search for relevant memories. + + Returns: + list[Message]: A list of user and AI messages related to the query. + """ + if not query: return [] @@ -54,6 +68,12 @@ class RoleZeroLongTermMemory(BaseModel): return memories def add(self, item: LongTermMemoryItem): + """Adds a long-term memory item to the RAG engine. + + Args: + item (LongTermMemoryItem): The memory item containing user and AI messages. 
+        """
+
         if not item:
             return
 
         self.rag_engine.add_objs([item])

From e2600c0a64e5545bac2dc5e1810e29404c8f8acf Mon Sep 17 00:00:00 2001
From: seehi <6580@pm.me>
Date: Mon, 9 Sep 2024 11:17:41 +0800
Subject: [PATCH 04/18] use _add_memory in TeamLeader.finish_current_task

---
 metagpt/roles/di/team_leader.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/metagpt/roles/di/team_leader.py b/metagpt/roles/di/team_leader.py
index 112ca5a84..1bf6364d9 100644
--- a/metagpt/roles/di/team_leader.py
+++ b/metagpt/roles/di/team_leader.py
@@ -84,4 +84,4 @@ class TeamLeader(RoleZero):
 
     def finish_current_task(self):
         self.planner.plan.finish_current_task()
-        self.rc.memory.add(AIMessage(content=FINISH_CURRENT_TASK_CMD))
+        self._add_memory(AIMessage(content=FINISH_CURRENT_TASK_CMD))

From 047f1e429dcbd7d48c57cbac58f6039b55c456f3 Mon Sep 17 00:00:00 2001
From: seehi <6580@pm.me>
Date: Mon, 9 Sep 2024 14:02:26 +0800
Subject: [PATCH 05/18] Keep user and AI messages paired

---
 metagpt/roles/di/role_zero.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/metagpt/roles/di/role_zero.py b/metagpt/roles/di/role_zero.py
index 5e1d239c0..b4ee43228 100644
--- a/metagpt/roles/di/role_zero.py
+++ b/metagpt/roles/di/role_zero.py
@@ -608,8 +608,9 @@ class RoleZero(Role):
         related_memories = self.longterm_memory.fetch(memories[-1].content)
         logger.info(f"Fetched {len(related_memories)} long-term memories.")
 
-        if related_memories and self._is_first_message_from_ai(memories):
-            memories = memories[1:]
+        # Keep user and AI messages are paired.
+        if self._is_first_message_from_ai(memories):
+            memories.append(self.rc.memory.get_by_position(-(k + 1)))
 
         final_memories = related_memories + memories
 

From 31dbff5474d4dbc2ef278e8fdd9c35ad95db3878 Mon Sep 17 00:00:00 2001
From: seehi <6580@pm.me>
Date: Mon, 9 Sep 2024 14:24:03 +0800
Subject: [PATCH 06/18] Keep user and AI messages paired

---
 metagpt/roles/di/role_zero.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/metagpt/roles/di/role_zero.py b/metagpt/roles/di/role_zero.py
index b4ee43228..34d1812fc 100644
--- a/metagpt/roles/di/role_zero.py
+++ b/metagpt/roles/di/role_zero.py
@@ -610,7 +610,7 @@ class RoleZero(Role):
 
         # Keep user and AI messages are paired.
if self._is_first_message_from_ai(memories): - memories.append(self.rc.memory.get_by_position(-(k + 1))) + memories.insert(0, self.rc.memory.get_by_position(-(k + 1))) final_memories = related_memories + memories From 5d21d255e4c1a567ea0ee5c00a0a5e26247d74de Mon Sep 17 00:00:00 2001 From: seehi <6580@pm.me> Date: Tue, 10 Sep 2024 16:44:22 +0800 Subject: [PATCH 07/18] update --- metagpt/memory/role_zero_memory.py | 14 ++++++-- metagpt/roles/di/role_zero.py | 15 ++++++-- metagpt/schema.py | 9 +++-- tests/metagpt/memory/test_role_zero_memory.py | 35 +++++++++++++++++++ tests/metagpt/roles/di/test_role_zero.py | 32 +++++++++++++++-- 5 files changed, 95 insertions(+), 10 deletions(-) diff --git a/metagpt/memory/role_zero_memory.py b/metagpt/memory/role_zero_memory.py index 570d0cc41..7f777df56 100644 --- a/metagpt/memory/role_zero_memory.py +++ b/metagpt/memory/role_zero_memory.py @@ -57,11 +57,11 @@ class RoleZeroLongTermMemory(BaseModel): if not query: return [] - nodes: list[NodeWithScore] = self.rag_engine.retrieve(query) + nodes = self.rag_engine.retrieve(query) + items = self._get_items_from_nodes(nodes) memories = [] - for node in nodes: - item: LongTermMemoryItem = node.metadata["obj"] + for item in items: memories.append(item.user_message) memories.append(item.ai_message) @@ -78,3 +78,11 @@ class RoleZeroLongTermMemory(BaseModel): return self.rag_engine.add_objs([item]) + + def _get_items_from_nodes(self, nodes: list["NodeWithScore"]) -> list[LongTermMemoryItem]: + """Get items from nodes and arrange them in order of their `created_at`.""" + + items: list[LongTermMemoryItem] = [node.metadata["obj"] for node in nodes] + items.sort(key=lambda item: item.created_at) + + return items diff --git a/metagpt/roles/di/role_zero.py b/metagpt/roles/di/role_zero.py index 34d1812fc..5818bc1f3 100644 --- a/metagpt/roles/di/role_zero.py +++ b/metagpt/roles/di/role_zero.py @@ -285,12 +285,12 @@ class RoleZero(Role): self._add_memory(AIMessage(content=self.command_rsp)) if not ok: error_msg = commands - self._add_memory(UserMessage(content=error_msg)) + self._add_memory(UserMessage(content=error_msg, cause_by=RunCommand)) return error_msg logger.info(f"Commands: \n{commands}") outputs = await self._run_commands(commands) logger.info(f"Commands outputs: \n{outputs}") - self._add_memory(UserMessage(content=outputs)) + self._add_memory(UserMessage(content=outputs, cause_by=RunCommand)) return AIMessage( content=f"I have finished the task, please mark my task as finished. Outputs: {outputs}", @@ -605,7 +605,8 @@ class RoleZero(Role): if not self._should_use_longterm_memory(k=k, k_memories=memories): return memories - related_memories = self.longterm_memory.fetch(memories[-1].content) + query = self._build_longterm_memory_query(memories) + related_memories = self.longterm_memory.fetch(query) logger.info(f"Fetched {len(related_memories)} long-term memories.") # Keep user and AI messages are paired. @@ -666,3 +667,11 @@ class RoleZero(Role): def _is_first_message_from_ai(self, memories: list[Message]) -> bool: return bool(memories and memories[0].is_ai_message()) + + def _build_longterm_memory_query(self, memories: list[Message]) -> str: + """Build the content used to query related long-term memory. + + Default is to get the most recent user message, or an empty string if none is found. 
+ """ + + return next((m.content for m in reversed(memories) if m.is_real_user_message()), "") diff --git a/metagpt/schema.py b/metagpt/schema.py index 63a80f62a..8481bccf3 100644 --- a/metagpt/schema.py +++ b/metagpt/schema.py @@ -18,6 +18,7 @@ from __future__ import annotations import asyncio import json import os.path +import time import uuid from abc import ABC from asyncio import Queue, QueueEmpty, wait_for @@ -408,12 +409,15 @@ class Message(BaseModel): dynamic_class = create_model(class_name, **{key: (value.__class__, ...) for key, value in kvs.items()}) return dynamic_class.model_validate(kvs) - def is_user_message(self): + def is_user_message(self) -> bool: return self.role == "user" - def is_ai_message(self): + def is_ai_message(self) -> bool: return self.role == "assistant" + def is_real_user_message(self) -> bool: + return self.is_user_message() and "UserRequirement" in self.cause_by + class UserMessage(Message): """便于支持OpenAI的消息 @@ -966,6 +970,7 @@ class BaseEnum(Enum): class LongTermMemoryItem(BaseModel): user_message: Message ai_message: Message + created_at: Optional[float] = Field(default_factory=time.time) def rag_key(self) -> str: return self.user_message.content diff --git a/tests/metagpt/memory/test_role_zero_memory.py b/tests/metagpt/memory/test_role_zero_memory.py index 8e2532bfc..1c6fb785e 100644 --- a/tests/metagpt/memory/test_role_zero_memory.py +++ b/tests/metagpt/memory/test_role_zero_memory.py @@ -1,3 +1,5 @@ +from datetime import datetime, timedelta + import pytest from metagpt.memory.role_zero_memory import RoleZeroLongTermMemory @@ -46,3 +48,36 @@ class TestRoleZeroLongTermMemory: item = LongTermMemoryItem(user_message=UserMessage(content="user"), ai_message=AIMessage(content="ai")) mock_memory.add(item) mock_memory.rag_engine.add_objs.assert_called_once_with([item]) + + def test_get_items_from_nodes(self, mocker, mock_memory: RoleZeroLongTermMemory): + mock_node1 = mocker.Mock() + mock_node2 = mocker.Mock() + mock_node3 = mocker.Mock() + + now = datetime.now() + item1 = LongTermMemoryItem( + user_message=UserMessage(content="user1"), ai_message=AIMessage(content="ai1"), created_at=now.timestamp() + ) + item2 = LongTermMemoryItem( + user_message=UserMessage(content="user2"), + ai_message=AIMessage(content="ai2"), + created_at=(now - timedelta(minutes=5)).timestamp(), + ) + item3 = LongTermMemoryItem( + user_message=UserMessage(content="user3"), + ai_message=AIMessage(content="ai3"), + created_at=(now + timedelta(minutes=5)).timestamp(), + ) + + mock_node1.metadata = {"obj": item1} + mock_node2.metadata = {"obj": item2} + mock_node3.metadata = {"obj": item3} + + result = mock_memory._get_items_from_nodes([mock_node1, mock_node2, mock_node3]) + + assert len(result) == 3 + assert result[0] == item2 + assert result[1] == item1 + assert result[2] == item3 + assert [item.user_message.content for item in result] == ["user2", "user1", "user3"] + assert [item.ai_message.content for item in result] == ["ai2", "ai1", "ai3"] diff --git a/tests/metagpt/roles/di/test_role_zero.py b/tests/metagpt/roles/di/test_role_zero.py index d4d4a46da..964d456a7 100644 --- a/tests/metagpt/roles/di/test_role_zero.py +++ b/tests/metagpt/roles/di/test_role_zero.py @@ -46,7 +46,13 @@ class TestRoleZero: [AIMessage(content="ai1"), UserMessage(content="user"), AIMessage(content="ai2")], [Message(content="related")], True, - [Message(content="related"), UserMessage(content="user"), AIMessage(content="ai2")], + [ + Message(content="related"), + UserMessage(content="user"), + 
AIMessage(content="ai1"), + UserMessage(content="user"), + AIMessage(content="ai2"), + ], ), ( None, @@ -79,6 +85,7 @@ class TestRoleZero: ): mock_role_zero.memory_k = 2 mock_role_zero.rc.memory.get = mocker.Mock(return_value=memories) + mock_role_zero.rc.memory.get_by_position = mocker.Mock(return_value=UserMessage(content="user")) mock_role_zero._should_use_longterm_memory = mocker.Mock(return_value=should_use_ltm) mock_role_zero.longterm_memory.fetch = mocker.Mock(return_value=related_memories) mock_role_zero._is_first_message_from_ai = mocker.Mock(return_value=is_first_from_ai) @@ -97,7 +104,7 @@ class TestRoleZero: mock_role_zero._should_use_longterm_memory.assert_called_once_with(k=really_k, k_memories=memories) if should_use_ltm: - mock_role_zero.longterm_memory.fetch.assert_called_once_with(memories[-1].content) + mock_role_zero.longterm_memory.fetch.assert_called_once_with("user") mock_role_zero._is_first_message_from_ai.assert_called_once_with(memories) def test_add_memory(self, mocker, mock_role_zero: RoleZero): @@ -190,3 +197,24 @@ class TestRoleZero: def test_is_first_message_from_ai(self, mock_role_zero: RoleZero, memories, expected): result = mock_role_zero._is_first_message_from_ai(memories) assert result == expected + + @pytest.mark.parametrize( + "memories,expected", + [ + ([UserMessage(content="user1"), AIMessage(content="ai"), UserMessage(content="user2")], "user2"), + ( + [ + UserMessage(content="user1", cause_by="test"), + AIMessage(content="ai"), + UserMessage(content="user2", cause_by="test"), + ], + "", + ), + ([AIMessage(content="ai1"), AIMessage(content="ai2")], ""), + ([UserMessage(content="user")], "user"), + ([], ""), + ], + ) + def test_build_longterm_memory_query(self, mock_role_zero: RoleZero, memories, expected): + result = mock_role_zero._build_longterm_memory_query(memories) + assert result == expected From d077cd0b2f4b47bfa436ef11a87f4c0594b6a07b Mon Sep 17 00:00:00 2001 From: seehi <6580@pm.me> Date: Tue, 10 Sep 2024 20:05:56 +0800 Subject: [PATCH 08/18] use memory index to get the last user message --- metagpt/const.py | 3 +++ metagpt/roles/di/role_zero.py | 23 ++++++++++++----------- metagpt/schema.py | 3 --- tests/metagpt/roles/di/test_role_zero.py | 14 +------------- 4 files changed, 16 insertions(+), 27 deletions(-) diff --git a/metagpt/const.py b/metagpt/const.py index c53e8494a..e7a0dc31b 100644 --- a/metagpt/const.py +++ b/metagpt/const.py @@ -157,3 +157,6 @@ SWE_SETUP_PATH = get_metagpt_package_root() / "metagpt/tools/swe_agent_commands/ # experience pool EXPERIENCE_MASK = "" + +# Used to identify user requirements in the memory index. 
+USER_REQUIREMENT = "metagpt.actions.add_requirement.UserRequirement" diff --git a/metagpt/roles/di/role_zero.py b/metagpt/roles/di/role_zero.py index 5818bc1f3..9e64a1954 100644 --- a/metagpt/roles/di/role_zero.py +++ b/metagpt/roles/di/role_zero.py @@ -13,7 +13,7 @@ from metagpt.actions import Action, UserRequirement from metagpt.actions.analyze_requirements import AnalyzeRequirementsRestrictions from metagpt.actions.di.run_command import RunCommand from metagpt.actions.search_enhanced_qa import SearchEnhancedQA -from metagpt.const import IMAGES +from metagpt.const import IMAGES, USER_REQUIREMENT from metagpt.exp_pool import exp_cache from metagpt.exp_pool.context_builders import RoleZeroContextBuilder from metagpt.exp_pool.serializers import RoleZeroSerializer @@ -602,10 +602,10 @@ class RoleZero(Role): memories = self.rc.memory.get(k) - if not self._should_use_longterm_memory(k=k, k_memories=memories): + if not self._should_use_longterm_memory(k=k): return memories - query = self._build_longterm_memory_query(memories) + query = self._build_longterm_memory_query() related_memories = self.longterm_memory.fetch(query) logger.info(f"Fetched {len(related_memories)} long-term memories.") @@ -625,19 +625,17 @@ class RoleZero(Role): self._transfer_to_longterm_memory() - def _should_use_longterm_memory(self, k: int = None, k_memories: list[Message] = None) -> bool: + def _should_use_longterm_memory(self, k: int = None) -> bool: """Determines if long-term memory should be used. Long-term memory is used if: - k is not 0. - - k_memories is None or k_memories is not empty, and the last message is a user message. - Long-term memory usage is enabled. - The count of recent memories is greater than self.memory_k. """ conds = [ k != 0, - k_memories is None or self._is_last_message_from_user(k_memories), self.enable_longterm_memory, self.rc.memory.count() > self.memory_k, ] @@ -662,16 +660,19 @@ class RoleZero(Role): return LongTermMemoryItem(user_message=user_message, ai_message=message) - def _is_last_message_from_user(self, memories: list[Message]) -> bool: - return bool(memories and memories[-1].is_user_message()) - def _is_first_message_from_ai(self, memories: list[Message]) -> bool: return bool(memories and memories[0].is_ai_message()) - def _build_longterm_memory_query(self, memories: list[Message]) -> str: + def _build_longterm_memory_query(self) -> str: """Build the content used to query related long-term memory. Default is to get the most recent user message, or an empty string if none is found. 
""" + message = self._get_the_last_user_message() - return next((m.content for m in reversed(memories) if m.is_real_user_message()), "") + return message.content if message else "" + + def _get_the_last_user_message(self) -> Message: + values = self.rc.memory.index.get(USER_REQUIREMENT, []) + + return values[-1] if values else None diff --git a/metagpt/schema.py b/metagpt/schema.py index 8481bccf3..9352664e2 100644 --- a/metagpt/schema.py +++ b/metagpt/schema.py @@ -415,9 +415,6 @@ class Message(BaseModel): def is_ai_message(self) -> bool: return self.role == "assistant" - def is_real_user_message(self) -> bool: - return self.is_user_message() and "UserRequirement" in self.cause_by - class UserMessage(Message): """便于支持OpenAI的消息 diff --git a/tests/metagpt/roles/di/test_role_zero.py b/tests/metagpt/roles/di/test_role_zero.py index 964d456a7..0d427ce0f 100644 --- a/tests/metagpt/roles/di/test_role_zero.py +++ b/tests/metagpt/roles/di/test_role_zero.py @@ -101,7 +101,7 @@ class TestRoleZero: mock_role_zero.rc.memory.get.assert_called_once_with(really_k) if k != 0: - mock_role_zero._should_use_longterm_memory.assert_called_once_with(k=really_k, k_memories=memories) + mock_role_zero._should_use_longterm_memory.assert_called_once_with(k=really_k) if should_use_ltm: mock_role_zero.longterm_memory.fetch.assert_called_once_with("user") @@ -174,18 +174,6 @@ class TestRoleZero: mock_role_zero.rc.memory.get_by_position.assert_any_call(-(mock_role_zero.memory_k + 1)) mock_role_zero.rc.memory.get_by_position.assert_any_call(-(mock_role_zero.memory_k + 2)) - @pytest.mark.parametrize( - "memories,expected", - [ - ([UserMessage(content="user")], True), - ([AIMessage(content="ai")], False), - ([], False), - ], - ) - def test_is_last_message_from_user(self, mock_role_zero: RoleZero, memories, expected): - result = mock_role_zero._is_last_message_from_user(memories) - assert result == expected - @pytest.mark.parametrize( "memories,expected", [ From c0d5c031f8c2d8a2f3f45a40cce3294733182e7b Mon Sep 17 00:00:00 2001 From: seehi <6580@pm.me> Date: Wed, 11 Sep 2024 20:25:29 +0800 Subject: [PATCH 09/18] Change the way of taking over memory --- metagpt/const.py | 4 +- metagpt/environment/mgx/mgx_env.py | 4 +- metagpt/memory/role_zero_memory.py | 130 +++++++++-- metagpt/roles/di/role_zero.py | 130 ++--------- metagpt/roles/di/team_leader.py | 5 +- metagpt/schema.py | 5 +- tests/metagpt/memory/test_role_zero_memory.py | 191 +++++++++++----- tests/metagpt/roles/di/test_role_zero.py | 208 ------------------ 8 files changed, 270 insertions(+), 407 deletions(-) delete mode 100644 tests/metagpt/roles/di/test_role_zero.py diff --git a/metagpt/const.py b/metagpt/const.py index e7a0dc31b..4fe8dca3d 100644 --- a/metagpt/const.py +++ b/metagpt/const.py @@ -158,5 +158,5 @@ SWE_SETUP_PATH = get_metagpt_package_root() / "metagpt/tools/swe_agent_commands/ # experience pool EXPERIENCE_MASK = "" -# Used to identify user requirements in the memory index. 
-USER_REQUIREMENT = "metagpt.actions.add_requirement.UserRequirement" +# TeamLeader's name +TEAMLEADER_NAME = "Mike" diff --git a/metagpt/environment/mgx/mgx_env.py b/metagpt/environment/mgx/mgx_env.py index 4df04d3ce..53690f7d7 100644 --- a/metagpt/environment/mgx/mgx_env.py +++ b/metagpt/environment/mgx/mgx_env.py @@ -8,7 +8,7 @@ from metagpt.actions import ( WriteTest, ) from metagpt.actions.summarize_code import SummarizeCode -from metagpt.const import AGENT, IMAGES +from metagpt.const import AGENT, IMAGES, TEAMLEADER_NAME from metagpt.environment.base_env import Environment from metagpt.logs import get_human_input from metagpt.roles import Architect, ProductManager, ProjectManager, Role @@ -31,7 +31,7 @@ class MGXEnv(Environment, SerializationMixin): """let the team leader take over message publishing""" message = self.attach_images(message) # for multi-modal message - tl = self.get_role("Mike") # TeamLeader's name is Mike + tl = self.get_role(TEAMLEADER_NAME) # TeamLeader's name is Mike if user_defined_recipient: # human user's direct chat message to a certain role diff --git a/metagpt/memory/role_zero_memory.py b/metagpt/memory/role_zero_memory.py index 7f777df56..ff403d129 100644 --- a/metagpt/memory/role_zero_memory.py +++ b/metagpt/memory/role_zero_memory.py @@ -1,8 +1,14 @@ -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Optional -from pydantic import BaseModel, Field +from pydantic import Field +from metagpt.actions import UserRequirement +from metagpt.const import TEAMLEADER_NAME +from metagpt.logs import logger +from metagpt.memory import Memory from metagpt.schema import LongTermMemoryItem, Message +from metagpt.utils.common import any_to_str +from metagpt.utils.exceptions import handle_exception if TYPE_CHECKING: from llama_index.core.schema import NodeWithScore @@ -10,9 +16,10 @@ if TYPE_CHECKING: from metagpt.rag.engines import SimpleEngine -class RoleZeroLongTermMemory(BaseModel): +class RoleZeroLongTermMemory(Memory): persist_path: str = Field(default=".role_memory_data", description="The directory to save data.") collection_name: str = Field(default="role_zero", description="The name of the collection, such as the role name.") + memory_k: int = Field(default=200, description="The capacity of short-term memory.") _rag_engine: Any = None @@ -44,7 +51,104 @@ class RoleZeroLongTermMemory(BaseModel): return rag_engine - def fetch(self, query: str) -> list[Message]: + def add(self, message: Message): + """Add a new message and potentially transfer it to long-term memory.""" + + super().add(message) + + if not self._should_use_longterm_memory_for_add(): + return + + self._transfer_to_longterm_memory() + + def get(self, k=0) -> list[Message]: + """Return recent memories and optionally combines them with related long-term memories.""" + + memories = super().get(k) + + if not self._should_use_longterm_memory_for_get(k=k): + return memories + + query = self._build_longterm_memory_query() + related_memories = self._fetch_longterm_memories(query) + logger.info(f"Fetched {len(related_memories)} long-term memories.") + + final_memories = related_memories + memories + + return final_memories + + def _should_use_longterm_memory_for_add(self) -> bool: + """Determines if long-term memory should be used for add.""" + + return self.count() > self.memory_k + + def _should_use_longterm_memory_for_get(self, k: int) -> bool: + """Determines if long-term memory should be used for get. + + Long-term memory is used if: + - k is not 0. 
+ - The last message is from user requirement. + - The count of recent memories is greater than self.memory_k. + """ + + conds = [ + k != 0, + self._is_last_message_from_user_requirement(), + self.count() > self.memory_k, + ] + + return all(conds) + + def _transfer_to_longterm_memory(self): + item = self._get_longterm_memory_item() + self._add_to_longterm_memory(item) + + @handle_exception + def _get_longterm_memory_item(self) -> Optional[LongTermMemoryItem]: + """Retrieves the most recent message before the last k messages.""" + + index = -(self.memory_k + 1) + message = self.get_by_position(index) + + return LongTermMemoryItem(message=message) + + def _add_to_longterm_memory(self, item: LongTermMemoryItem): + """Adds a long-term memory item to the RAG engine.""" + + if not item: + return + + self.rag_engine.add_objs([item]) + + def _build_longterm_memory_query(self) -> str: + """Build the content used to query related long-term memory. + + Default is to get the most recent user message, or an empty string if none is found. + """ + + message = self._get_the_last_message() + + return message.content if message else "" + + def _get_the_last_message(self) -> Optional[Message]: + if not self.count(): + return None + + return self.get_by_position(-1) + + def _is_last_message_from_user_requirement(self) -> bool: + message = self._get_the_last_message() + + if not message: + return False + + is_user_message = message.is_user_message() + cause_by_user_requirement = message.cause_by == any_to_str(UserRequirement) + sent_from_team_leader = message.sent_from == TEAMLEADER_NAME + + return is_user_message and (cause_by_user_requirement or sent_from_team_leader) + + def _fetch_longterm_memories(self, query: str) -> list[Message]: """Fetches long-term memories based on a query. Args: @@ -59,26 +163,10 @@ class RoleZeroLongTermMemory(BaseModel): nodes = self.rag_engine.retrieve(query) items = self._get_items_from_nodes(nodes) - - memories = [] - for item in items: - memories.append(item.user_message) - memories.append(item.ai_message) + memories = [item.message for item in items] return memories - def add(self, item: LongTermMemoryItem): - """Adds a long-term memory item to the RAG engine. - - Args: - item (LongTermMemoryItem): The memory item containing user and AI messages. 
- """ - - if not item: - return - - self.rag_engine.add_objs([item]) - def _get_items_from_nodes(self, nodes: list["NodeWithScore"]) -> list[LongTermMemoryItem]: """Get items from nodes and arrange them in order of their `created_at`.""" diff --git a/metagpt/roles/di/role_zero.py b/metagpt/roles/di/role_zero.py index 9e64a1954..466d87f5c 100644 --- a/metagpt/roles/di/role_zero.py +++ b/metagpt/roles/di/role_zero.py @@ -13,7 +13,7 @@ from metagpt.actions import Action, UserRequirement from metagpt.actions.analyze_requirements import AnalyzeRequirementsRestrictions from metagpt.actions.di.run_command import RunCommand from metagpt.actions.search_enhanced_qa import SearchEnhancedQA -from metagpt.const import IMAGES, USER_REQUIREMENT +from metagpt.const import IMAGES from metagpt.exp_pool import exp_cache from metagpt.exp_pool.context_builders import RoleZeroContextBuilder from metagpt.exp_pool.serializers import RoleZeroSerializer @@ -34,7 +34,7 @@ from metagpt.prompts.di.role_zero import ( SYSTEM_PROMPT, ) from metagpt.roles import Role -from metagpt.schema import AIMessage, LongTermMemoryItem, Message, UserMessage +from metagpt.schema import AIMessage, Message, UserMessage from metagpt.strategy.experience_retriever import DummyExpRetriever, ExpRetriever from metagpt.strategy.planner import Planner from metagpt.tools.libs.browser import Browser @@ -42,7 +42,6 @@ from metagpt.tools.libs.editor import Editor from metagpt.tools.tool_recommend import BM25ToolRecommender, ToolRecommender from metagpt.tools.tool_registry import register_tool from metagpt.utils.common import CodeParser, any_to_str, extract_and_encode_images -from metagpt.utils.exceptions import handle_exception from metagpt.utils.repair_llm_raw_output import ( RepairType, repair_escape_error, @@ -94,7 +93,6 @@ class RoleZero(Role): commands: list[dict] = [] # commands to be executed memory_k: int = 200 # number of memories (messages) to use as historical context enable_longterm_memory: bool = True # whether to use longterm memory - longterm_memory: RoleZeroLongTermMemory = None use_fixed_sop: bool = False requirements_constraints: str = "" # the constraints in user requirements use_summary: bool = True # whether to summarize at the end @@ -176,8 +174,8 @@ class RoleZero(Role): The role name will be used as the collection name. """ - if self.enable_longterm_memory and not self.longterm_memory: - self.longterm_memory = RoleZeroLongTermMemory(collection_name=self.name.replace(" ", "")) + if self.enable_longterm_memory: + self.rc.memory = RoleZeroLongTermMemory(collection_name=self.name.replace(" ", ""), memory_k=self.memory_k) return self @@ -195,7 +193,7 @@ class RoleZero(Role): return False if not self.planner.plan.goal: - self.planner.plan.goal = self._get_all_memories()[-1].content + self.planner.plan.goal = self.get_memories()[-1].content self.requirements_constraints = await AnalyzeRequirementsRestrictions().run(self.planner.plan.goal) ### 1. 
Experience ### @@ -227,7 +225,7 @@ class RoleZero(Role): ) ### Recent Observation ### - memory = self._fetch_memories() + memory = self.rc.memory.get(self.memory_k) memory = await self.parse_browser_actions(memory) memory = self.parse_images(memory) @@ -282,15 +280,15 @@ class RoleZero(Role): return await super()._act() commands, ok, self.command_rsp = await self._parse_commands(self.command_rsp) - self._add_memory(AIMessage(content=self.command_rsp)) + self.rc.memory.add(AIMessage(content=self.command_rsp)) if not ok: error_msg = commands - self._add_memory(UserMessage(content=error_msg, cause_by=RunCommand)) + self.rc.memory.add(UserMessage(content=error_msg, cause_by=RunCommand)) return error_msg logger.info(f"Commands: \n{commands}") outputs = await self._run_commands(commands) logger.info(f"Commands outputs: \n{outputs}") - self._add_memory(UserMessage(content=outputs, cause_by=RunCommand)) + self.rc.memory.add(UserMessage(content=outputs, cause_by=RunCommand)) return AIMessage( content=f"I have finished the task, please mark my task as finished. Outputs: {outputs}", @@ -343,7 +341,7 @@ class RoleZero(Role): return rsp_msg, "" # routing - memory = self._fetch_memories() + memory = self.get_memories(k=self.memory_k) context = self.llm.format_msg(memory + [UserMessage(content=QUICK_THINK_PROMPT)]) async with ThoughtReporter() as reporter: await reporter.async_report({"type": "classify"}) @@ -368,7 +366,7 @@ class RoleZero(Role): answer = await SearchEnhancedQA().run(query) if answer: - self._add_memory(AIMessage(content=answer, cause_by=RunCommand)) + self.rc.memory.add(AIMessage(content=answer, cause_by=RunCommand)) await self.reply_to_human(content=answer) rsp_msg = AIMessage( content="Complete run", @@ -379,7 +377,7 @@ class RoleZero(Role): return rsp_msg, intent_result async def _check_duplicates(self, req: list[dict], command_rsp: str): - past_rsp = [mem.content for mem in self._fetch_memories()] + past_rsp = [mem.content for mem in self.rc.memory.get(self.memory_k)] if command_rsp in past_rsp: # Normal response with thought contents are highly unlikely to reproduce # If an identical response is detected, it is a bad response, mostly due to LLM repeating generated content @@ -537,7 +535,7 @@ class RoleZero(Role): def _retrieve_experience(self) -> str: """Default implementation of experience retrieval. Can be overwritten in subclasses.""" - context = [str(msg) for msg in self._fetch_memories()] + context = [str(msg) for msg in self.rc.memory.get(self.memory_k)] context = "\n\n".join(context) example = self.experience_retriever.retrieve(context=context) return example @@ -562,9 +560,9 @@ class RoleZero(Role): async def _end(self, **kwarg): self._set_state(-1) - memory = self._fetch_memories() + memory = self.rc.memory.get(self.memory_k) # Ensure reply to the human before the "end" command is executed. Hard code k=5 for checking. 
- if not any(["reply_to_human" in memory.content for memory in self._fetch_memories(k=5)]): + if not any(["reply_to_human" in memory.content for memory in self.get_memories(k=5)]): logger.info("manually reply to human") pattern = r"\[Language Restrictions\](.*?)\n" match = re.search(pattern, self.requirements_constraints, re.DOTALL) @@ -573,106 +571,10 @@ class RoleZero(Role): await reporter.async_report({"type": "quick"}) reply_content = await self.llm.aask(self.llm.format_msg(memory + [UserMessage(reply_to_human_prompt)])) await self.reply_to_human(content=reply_content) - self._add_memory(AIMessage(content=reply_content, cause_by=RunCommand)) + self.rc.memory.add(AIMessage(content=reply_content, cause_by=RunCommand)) outputs = "" # Summary of the Completed Task and Deliverables if self.use_summary: logger.info("end current run and summarize") outputs = await self.llm.aask(self.llm.format_msg(memory + [UserMessage(SUMMARY_PROMPT)])) return outputs - - def _get_all_memories(self) -> list[Message]: - return self._fetch_memories(k=0) - - def _fetch_memories(self, k: Optional[int] = None) -> list[Message]: - """Fetches recent memories and optionally combines them with related long-term memories. - - If long-term memory is not enabled or the last message is not from the user, - it returns the recent memories without fetching from long-term memory. - - Args: - k (Optional[int]): The number of recent memories to fetch. If None, defaults to self.memory_k. - - Returns: - List[Message]: A list of messages representing the combined memories. - """ - - if k is None: - k = self.memory_k - - memories = self.rc.memory.get(k) - - if not self._should_use_longterm_memory(k=k): - return memories - - query = self._build_longterm_memory_query() - related_memories = self.longterm_memory.fetch(query) - logger.info(f"Fetched {len(related_memories)} long-term memories.") - - # Keep user and AI messages are paired. - if self._is_first_message_from_ai(memories): - memories.insert(0, self.rc.memory.get_by_position(-(k + 1))) - - final_memories = related_memories + memories - - return final_memories - - def _add_memory(self, message: Message): - self.rc.memory.add(message) - - if not self._should_use_longterm_memory(): - return - - self._transfer_to_longterm_memory() - - def _should_use_longterm_memory(self, k: int = None) -> bool: - """Determines if long-term memory should be used. - - Long-term memory is used if: - - k is not 0. - - Long-term memory usage is enabled. - - The count of recent memories is greater than self.memory_k. - """ - - conds = [ - k != 0, - self.enable_longterm_memory, - self.rc.memory.count() > self.memory_k, - ] - - return all(conds) - - def _transfer_to_longterm_memory(self): - item = self._get_longterm_memory_item() - self.longterm_memory.add(item) - - @handle_exception - def _get_longterm_memory_item(self) -> Optional[LongTermMemoryItem]: - """Retrieves the most recent pair of user and AI messages before the last k messages.""" - - index = -(self.memory_k + 1) - message = self.rc.memory.get_by_position(index) - if not message.is_ai_message(): - return None - - index = -(self.memory_k + 2) - user_message = self.rc.memory.get_by_position(index) - - return LongTermMemoryItem(user_message=user_message, ai_message=message) - - def _is_first_message_from_ai(self, memories: list[Message]) -> bool: - return bool(memories and memories[0].is_ai_message()) - - def _build_longterm_memory_query(self) -> str: - """Build the content used to query related long-term memory. 
- - Default is to get the most recent user message, or an empty string if none is found. - """ - message = self._get_the_last_user_message() - - return message.content if message else "" - - def _get_the_last_user_message(self) -> Message: - values = self.rc.memory.index.get(USER_REQUIREMENT, []) - - return values[-1] if values else None diff --git a/metagpt/roles/di/team_leader.py b/metagpt/roles/di/team_leader.py index 1bf6364d9..0724ffdea 100644 --- a/metagpt/roles/di/team_leader.py +++ b/metagpt/roles/di/team_leader.py @@ -5,6 +5,7 @@ from typing import Annotated from pydantic import Field from metagpt.actions.di.run_command import RunCommand +from metagpt.const import TEAMLEADER_NAME from metagpt.prompts.di.team_leader import ( FINISH_CURRENT_TASK_CMD, TL_INFO, @@ -19,7 +20,7 @@ from metagpt.tools.tool_registry import register_tool @register_tool(include_functions=["publish_team_message"]) class TeamLeader(RoleZero): - name: str = "Mike" + name: str = TEAMLEADER_NAME profile: str = "Team Leader" goal: str = "Manage a team to assist users" thought_guidance: str = TL_THOUGHT_GUIDANCE @@ -84,4 +85,4 @@ class TeamLeader(RoleZero): def finish_current_task(self): self.planner.plan.finish_current_task() - self._add_memory(AIMessage(content=FINISH_CURRENT_TASK_CMD)) + self.rc.memory.add(AIMessage(content=FINISH_CURRENT_TASK_CMD)) diff --git a/metagpt/schema.py b/metagpt/schema.py index 9352664e2..d4c75bfb4 100644 --- a/metagpt/schema.py +++ b/metagpt/schema.py @@ -965,9 +965,8 @@ class BaseEnum(Enum): class LongTermMemoryItem(BaseModel): - user_message: Message - ai_message: Message + message: Message created_at: Optional[float] = Field(default_factory=time.time) def rag_key(self) -> str: - return self.user_message.content + return self.message.content diff --git a/tests/metagpt/memory/test_role_zero_memory.py b/tests/metagpt/memory/test_role_zero_memory.py index 1c6fb785e..80eb58e49 100644 --- a/tests/metagpt/memory/test_role_zero_memory.py +++ b/tests/metagpt/memory/test_role_zero_memory.py @@ -2,8 +2,10 @@ from datetime import datetime, timedelta import pytest +from metagpt.actions import UserRequirement +from metagpt.const import TEAMLEADER_NAME from metagpt.memory.role_zero_memory import RoleZeroLongTermMemory -from metagpt.schema import AIMessage, LongTermMemoryItem, UserMessage +from metagpt.schema import AIMessage, LongTermMemoryItem, Message, UserMessage class TestRoleZeroLongTermMemory: @@ -13,71 +15,150 @@ class TestRoleZeroLongTermMemory: memory._resolve_rag_engine = mocker.Mock() return memory - def test_fetch_empty_query(self, mock_memory: RoleZeroLongTermMemory): - assert mock_memory.fetch("") == [] + def test_add(self, mocker, mock_memory: RoleZeroLongTermMemory): + mock_memory._should_use_longterm_memory_for_add = mocker.Mock(return_value=True) + mock_memory._transfer_to_longterm_memory = mocker.Mock() - def test_fetch(self, mocker, mock_memory: RoleZeroLongTermMemory): - mock_node1 = mocker.Mock() - mock_node2 = mocker.Mock() - mock_node1.metadata = { - "obj": LongTermMemoryItem(user_message=UserMessage(content="user1"), ai_message=AIMessage(content="ai1")) - } - mock_node2.metadata = { - "obj": LongTermMemoryItem(user_message=UserMessage(content="user2"), ai_message=AIMessage(content="ai2")) - } + message = UserMessage(content="test") + mock_memory.add(message) - mock_memory.rag_engine.retrieve.return_value = [mock_node1, mock_node2] + assert mock_memory.storage[-1] == message + mock_memory._transfer_to_longterm_memory.assert_called_once() - result = mock_memory.fetch("test 
query") + def test_get(self, mocker, mock_memory: RoleZeroLongTermMemory): + mock_memory._should_use_longterm_memory_for_get = mocker.Mock(return_value=True) + mock_memory._build_longterm_memory_query = mocker.Mock(return_value="query") + mock_memory._fetch_longterm_memories = mocker.Mock(return_value=[Message(content="long-term")]) - assert len(result) == 4 - assert isinstance(result[0], UserMessage) - assert isinstance(result[1], AIMessage) - assert result[0].content == "user1" - assert result[1].content == "ai1" - assert result[2].content == "user2" - assert result[3].content == "ai2" + mock_memory.storage = [Message(content="short-term")] - mock_memory.rag_engine.retrieve.assert_called_once_with("test query") + result = mock_memory.get() - def test_add_empty_item(self, mock_memory: RoleZeroLongTermMemory): - mock_memory.add(None) - mock_memory.rag_engine.add_objs.assert_not_called() + assert len(result) == 2 + assert result[0].content == "long-term" + assert result[1].content == "short-term" + + def test_should_use_longterm_memory_for_add(self, mocker, mock_memory: RoleZeroLongTermMemory): + mocker.patch.object(mock_memory, "storage", [None] * 201) + + mock_memory.memory_k = 200 + + assert mock_memory._should_use_longterm_memory_for_add() == True + + mocker.patch.object(mock_memory, "storage", [None] * 199) + assert mock_memory._should_use_longterm_memory_for_add() == False + + @pytest.mark.parametrize( + "k,is_last_from_user,count,expected", + [ + (0, True, 201, False), + (1, False, 201, False), + (1, True, 199, False), + (1, True, 201, True), + ], + ) + def test_should_use_longterm_memory_for_get( + self, mocker, mock_memory: RoleZeroLongTermMemory, k, is_last_from_user, count, expected + ): + mock_memory._is_last_message_from_user_requirement = mocker.Mock(return_value=is_last_from_user) + mocker.patch.object(mock_memory, "storage", [None] * count) + mock_memory.memory_k = 200 + + assert mock_memory._should_use_longterm_memory_for_get(k) == expected + + def test_transfer_to_longterm_memory(self, mocker, mock_memory: RoleZeroLongTermMemory): + mock_item = mocker.Mock() + mock_memory._get_longterm_memory_item = mocker.Mock(return_value=mock_item) + mock_memory._add_to_longterm_memory = mocker.Mock() + + mock_memory._transfer_to_longterm_memory() + + mock_memory._add_to_longterm_memory.assert_called_once_with(mock_item) + + def test_get_longterm_memory_item(self, mocker, mock_memory: RoleZeroLongTermMemory): + mock_message = Message(content="test") + mock_memory.storage = [mock_message, mock_message] + mock_memory.memory_k = 1 + + result = mock_memory._get_longterm_memory_item() + + assert isinstance(result, LongTermMemoryItem) + assert result.message == mock_message + + def test_add_to_longterm_memory(self, mock_memory: RoleZeroLongTermMemory): + item = LongTermMemoryItem(message=Message(content="test")) + mock_memory._add_to_longterm_memory(item) - def test_add_item(self, mock_memory: RoleZeroLongTermMemory): - item = LongTermMemoryItem(user_message=UserMessage(content="user"), ai_message=AIMessage(content="ai")) - mock_memory.add(item) mock_memory.rag_engine.add_objs.assert_called_once_with([item]) + def test_build_longterm_memory_query(self, mocker, mock_memory: RoleZeroLongTermMemory): + mock_message = Message(content="query") + mock_memory._get_the_last_message = mocker.Mock(return_value=mock_message) + + result = mock_memory._build_longterm_memory_query() + + assert result == "query" + + def test_get_the_last_message(self, mock_memory: RoleZeroLongTermMemory): + 
mock_memory.storage = [Message(content="1"), Message(content="2")] + + result = mock_memory._get_the_last_message() + + assert result.content == "2" + + @pytest.mark.parametrize( + "message,expected", + [ + (UserMessage(content="test", cause_by=UserRequirement), True), + (UserMessage(content="test", sent_from=TEAMLEADER_NAME), True), + (UserMessage(content="test"), True), + (AIMessage(content="test"), False), + (None, False), + ], + ) + def test_is_last_message_from_user_requirement( + self, mocker, mock_memory: RoleZeroLongTermMemory, message, expected + ): + mock_memory._get_the_last_message = mocker.Mock(return_value=message) + + assert mock_memory._is_last_message_from_user_requirement() == expected + + def test_fetch_longterm_memories(self, mocker, mock_memory: RoleZeroLongTermMemory): + mock_nodes = [mocker.Mock(), mocker.Mock()] + mock_memory.rag_engine.retrieve = mocker.Mock(return_value=mock_nodes) + mock_items = [ + LongTermMemoryItem(message=UserMessage(content="user1")), + LongTermMemoryItem(message=AIMessage(content="ai1")), + ] + mock_memory._get_items_from_nodes = mocker.Mock(return_value=mock_items) + + result = mock_memory._fetch_longterm_memories("query") + + assert len(result) == 2 + assert result[0].content == "user1" + assert result[1].content == "ai1" + def test_get_items_from_nodes(self, mocker, mock_memory: RoleZeroLongTermMemory): - mock_node1 = mocker.Mock() - mock_node2 = mocker.Mock() - mock_node3 = mocker.Mock() - now = datetime.now() - item1 = LongTermMemoryItem( - user_message=UserMessage(content="user1"), ai_message=AIMessage(content="ai1"), created_at=now.timestamp() - ) - item2 = LongTermMemoryItem( - user_message=UserMessage(content="user2"), - ai_message=AIMessage(content="ai2"), - created_at=(now - timedelta(minutes=5)).timestamp(), - ) - item3 = LongTermMemoryItem( - user_message=UserMessage(content="user3"), - ai_message=AIMessage(content="ai3"), - created_at=(now + timedelta(minutes=5)).timestamp(), - ) + mock_nodes = [ + mocker.Mock( + metadata={ + "obj": LongTermMemoryItem( + message=Message(content="2"), created_at=(now - timedelta(minutes=1)).timestamp() + ) + } + ), + mocker.Mock( + metadata={ + "obj": LongTermMemoryItem( + message=Message(content="1"), created_at=(now - timedelta(minutes=2)).timestamp() + ) + } + ), + mocker.Mock(metadata={"obj": LongTermMemoryItem(message=Message(content="3"), created_at=now.timestamp())}), + ] - mock_node1.metadata = {"obj": item1} - mock_node2.metadata = {"obj": item2} - mock_node3.metadata = {"obj": item3} - - result = mock_memory._get_items_from_nodes([mock_node1, mock_node2, mock_node3]) + result = mock_memory._get_items_from_nodes(mock_nodes) assert len(result) == 3 - assert result[0] == item2 - assert result[1] == item1 - assert result[2] == item3 - assert [item.user_message.content for item in result] == ["user2", "user1", "user3"] - assert [item.ai_message.content for item in result] == ["ai2", "ai1", "ai3"] + assert [item.message.content for item in result] == ["1", "2", "3"] diff --git a/tests/metagpt/roles/di/test_role_zero.py b/tests/metagpt/roles/di/test_role_zero.py deleted file mode 100644 index 0d427ce0f..000000000 --- a/tests/metagpt/roles/di/test_role_zero.py +++ /dev/null @@ -1,208 +0,0 @@ -import pytest - -from metagpt.roles.di.role_zero import RoleZero -from metagpt.schema import AIMessage, LongTermMemoryItem, Message, UserMessage - - -class TestRoleZero: - @pytest.fixture - def mock_role_zero(self, mocker): - role_zero = RoleZero() - role_zero.rc.memory = mocker.Mock() - 
role_zero.longterm_memory = mocker.Mock() - return role_zero - - def test_get_all_memories(self, mocker, mock_role_zero: RoleZero): - mock_memories = [Message(content="test1"), Message(content="test2")] - mock_role_zero._fetch_memories = mocker.Mock(return_value=mock_memories) - - result = mock_role_zero._get_all_memories() - - assert result == mock_memories - mock_role_zero._fetch_memories.assert_called_once_with(k=0) - - @pytest.mark.parametrize( - "k,should_use_ltm,memories,related_memories,is_first_from_ai,expected", - [ - ( - None, - False, - [UserMessage(content="user"), AIMessage(content="ai")], - [], - False, - [UserMessage(content="user"), AIMessage(content="ai")], - ), - ( - 1, - True, - [UserMessage(content="user")], - [Message(content="related")], - False, - [Message(content="related"), UserMessage(content="user")], - ), - ( - None, - True, - [AIMessage(content="ai1"), UserMessage(content="user"), AIMessage(content="ai2")], - [Message(content="related")], - True, - [ - Message(content="related"), - UserMessage(content="user"), - AIMessage(content="ai1"), - UserMessage(content="user"), - AIMessage(content="ai2"), - ], - ), - ( - None, - True, - [UserMessage(content="user"), AIMessage(content="ai")], - [Message(content="related")], - False, - [Message(content="related"), UserMessage(content="user"), AIMessage(content="ai")], - ), - ( - 0, - False, - [UserMessage(content="user"), AIMessage(content="ai")], - [], - False, - [UserMessage(content="user"), AIMessage(content="ai")], - ), - ], - ) - def test_fetch_memories( - self, - mocker, - mock_role_zero: RoleZero, - k, - should_use_ltm, - memories, - related_memories, - is_first_from_ai, - expected, - ): - mock_role_zero.memory_k = 2 - mock_role_zero.rc.memory.get = mocker.Mock(return_value=memories) - mock_role_zero.rc.memory.get_by_position = mocker.Mock(return_value=UserMessage(content="user")) - mock_role_zero._should_use_longterm_memory = mocker.Mock(return_value=should_use_ltm) - mock_role_zero.longterm_memory.fetch = mocker.Mock(return_value=related_memories) - mock_role_zero._is_first_message_from_ai = mocker.Mock(return_value=is_first_from_ai) - - result = mock_role_zero._fetch_memories(k) - - assert len(result) == len(expected) - for actual, expected_msg in zip(result, expected): - assert actual.role == expected_msg.role - assert actual.content == expected_msg.content - - really_k = k if k is not None else mock_role_zero.memory_k - mock_role_zero.rc.memory.get.assert_called_once_with(really_k) - - if k != 0: - mock_role_zero._should_use_longterm_memory.assert_called_once_with(k=really_k) - - if should_use_ltm: - mock_role_zero.longterm_memory.fetch.assert_called_once_with("user") - mock_role_zero._is_first_message_from_ai.assert_called_once_with(memories) - - def test_add_memory(self, mocker, mock_role_zero: RoleZero): - message = AIMessage(content="ai") - mock_role_zero.rc.memory.add = mocker.Mock() - mock_role_zero._should_use_longterm_memory = mocker.Mock(return_value=True) - mock_role_zero._transfer_to_longterm_memory = mocker.Mock() - - mock_role_zero._add_memory(message) - - mock_role_zero.rc.memory.add.assert_called_once_with(message) - mock_role_zero._transfer_to_longterm_memory.assert_called_once() - - @pytest.mark.parametrize( - "k,enable_longterm_memory,memory_count,k_memories,expected", - [ - (0, True, 30, None, False), # k is 0 - (None, False, 30, None, False), # Long-term memory usage is disabled - (None, True, 10, None, False), # Memory count is less than or equal to memory_k - (None, True, 30, [], False), # 
k_memories is empty - (None, True, 30, [AIMessage(content="ai")], False), # Last message in k_memories is not a user message - (None, True, 30, [AIMessage(content="ai"), UserMessage(content="user")], True), # All conditions are met - ], - ) - def test_should_use_longterm_memory( - self, mocker, mock_role_zero: RoleZero, k, enable_longterm_memory, memory_count, k_memories, expected - ): - mock_role_zero.enable_longterm_memory = enable_longterm_memory - mock_role_zero.rc.memory.count = mocker.Mock(return_value=memory_count) - mock_role_zero.memory_k = 20 - - result = mock_role_zero._should_use_longterm_memory(k, k_memories) - - assert result == expected - - def test_transfer_to_longterm_memory(self, mocker, mock_role_zero: RoleZero): - mock_item = LongTermMemoryItem(user_message=UserMessage(content="user"), ai_message=AIMessage(content="ai")) - mock_role_zero._get_longterm_memory_item = mocker.Mock(return_value=mock_item) - mock_role_zero.longterm_memory = mocker.Mock() - - mock_role_zero._transfer_to_longterm_memory() - - mock_role_zero.longterm_memory.add.assert_called_once_with(mock_item) - - def test_get_longterm_memory_item(self, mocker, mock_role_zero: RoleZero): - mock_role_zero.memory_k = 2 - mock_messages = [ - UserMessage(content="user1"), - AIMessage(content="ai1"), - UserMessage(content="user2"), - AIMessage(content="ai2"), - UserMessage(content="user3"), # memory_k + 2 - AIMessage(content="ai3"), # memory_k + 1 - UserMessage(content="recent1"), - AIMessage(content="recent2"), - ] - - mock_role_zero.rc.memory.get_by_position = mocker.Mock(side_effect=lambda i: mock_messages[i]) - mock_role_zero.rc.memory.count = mocker.Mock(return_value=len(mock_messages)) - - result = mock_role_zero._get_longterm_memory_item() - - assert isinstance(result, LongTermMemoryItem) - assert result.user_message.content == "user3" - assert result.ai_message.content == "ai3" - - mock_role_zero.rc.memory.get_by_position.assert_any_call(-(mock_role_zero.memory_k + 1)) - mock_role_zero.rc.memory.get_by_position.assert_any_call(-(mock_role_zero.memory_k + 2)) - - @pytest.mark.parametrize( - "memories,expected", - [ - ([AIMessage(content="ai")], True), - ([UserMessage(content="user")], False), - ([], False), - ], - ) - def test_is_first_message_from_ai(self, mock_role_zero: RoleZero, memories, expected): - result = mock_role_zero._is_first_message_from_ai(memories) - assert result == expected - - @pytest.mark.parametrize( - "memories,expected", - [ - ([UserMessage(content="user1"), AIMessage(content="ai"), UserMessage(content="user2")], "user2"), - ( - [ - UserMessage(content="user1", cause_by="test"), - AIMessage(content="ai"), - UserMessage(content="user2", cause_by="test"), - ], - "", - ), - ([AIMessage(content="ai1"), AIMessage(content="ai2")], ""), - ([UserMessage(content="user")], "user"), - ([], ""), - ], - ) - def test_build_longterm_memory_query(self, mock_role_zero: RoleZero, memories, expected): - result = mock_role_zero._build_longterm_memory_query(memories) - assert result == expected From cf3b80d5f10ced58cd908826f048283bf643a37f Mon Sep 17 00:00:00 2001 From: seehi <6580@pm.me> Date: Thu, 12 Sep 2024 11:20:40 +0800 Subject: [PATCH 10/18] add `enable_longterm_memory` config --- config/config2.example.yaml | 2 ++ metagpt/config2.py | 4 ++++ metagpt/configs/role_zero_config.py | 7 +++++++ metagpt/roles/di/role_zero.py | 9 +++++---- 4 files changed, 18 insertions(+), 4 deletions(-) create mode 100644 metagpt/configs/role_zero_config.py diff --git a/config/config2.example.yaml 
b/config/config2.example.yaml index 2a0ebcc47..1c0934567 100644 --- a/config/config2.example.yaml +++ b/config/config2.example.yaml @@ -83,6 +83,8 @@ exp_pool: use_llm_ranker: true # Default is `true`, it will use LLM Reranker to get better result. collection_name: experience_pool # When `retrieval_type` is `chroma`, `collection_name` is the collection name in chromadb. +role_zero: + enable_longterm_memory: false # Whether to use long-term memory. azure_tts_subscription_key: "YOUR_SUBSCRIPTION_KEY" azure_tts_region: "eastus" diff --git a/metagpt/config2.py b/metagpt/config2.py index 7b6ddf8c6..c328713c5 100644 --- a/metagpt/config2.py +++ b/metagpt/config2.py @@ -19,6 +19,7 @@ from metagpt.configs.mermaid_config import MermaidConfig from metagpt.configs.omniparse_config import OmniParseConfig from metagpt.configs.redis_config import RedisConfig from metagpt.configs.role_custom_config import RoleCustomConfig +from metagpt.configs.role_zero_config import RoleZeroConfig from metagpt.configs.s3_config import S3Config from metagpt.configs.search_config import SearchConfig from metagpt.configs.workspace_config import WorkspaceConfig @@ -89,6 +90,9 @@ class Config(CLIParams, YamlModel): # Role's custom configuration roles: Optional[List[RoleCustomConfig]] = None + # RoleZero's configuration + role_zero: Optional[RoleZeroConfig] = None + omniparse: Optional[OmniParseConfig] = None @classmethod diff --git a/metagpt/configs/role_zero_config.py b/metagpt/configs/role_zero_config.py new file mode 100644 index 000000000..27103ddf6 --- /dev/null +++ b/metagpt/configs/role_zero_config.py @@ -0,0 +1,7 @@ +from pydantic import Field + +from metagpt.utils.yaml_model import YamlModel + + +class RoleZeroConfig(YamlModel): + enable_longterm_memory: bool = Field(default=False, description="Whether to use long-term memory.") diff --git a/metagpt/roles/di/role_zero.py b/metagpt/roles/di/role_zero.py index 466d87f5c..586e5345f 100644 --- a/metagpt/roles/di/role_zero.py +++ b/metagpt/roles/di/role_zero.py @@ -92,7 +92,6 @@ class RoleZero(Role): command_rsp: str = "" # the raw string containing the commands commands: list[dict] = [] # commands to be executed memory_k: int = 200 # number of memories (messages) to use as historical context - enable_longterm_memory: bool = True # whether to use longterm memory use_fixed_sop: bool = False requirements_constraints: str = "" # the constraints in user requirements use_summary: bool = True # whether to summarize at the end @@ -168,14 +167,16 @@ class RoleZero(Role): @model_validator(mode="after") def set_longterm_memory(self) -> "RoleZero": - """Set longterm memory. + """Set up long-term memory for the role if enabled in the configuration. - If enable_longterm_memory is True and longterm_memory is not set, set it. + If `enable_longterm_memory` is True, set up long-term memory. The role name will be used as the collection name. 
""" - if self.enable_longterm_memory: + enable_longterm_memory = bool(self.config.role_zero and self.config.role_zero.enable_longterm_memory) + if enable_longterm_memory: self.rc.memory = RoleZeroLongTermMemory(collection_name=self.name.replace(" ", ""), memory_k=self.memory_k) + logger.info(f"Long-term memory set for role '{self.name}'") return self From d9ab81bbae03dd8ea166b0dda535b0f3ec655434 Mon Sep 17 00:00:00 2001 From: seehi <6580@pm.me> Date: Thu, 12 Sep 2024 11:21:59 +0800 Subject: [PATCH 11/18] add `enable_longterm_memory` config --- metagpt/roles/di/role_zero.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/metagpt/roles/di/role_zero.py b/metagpt/roles/di/role_zero.py index 586e5345f..9b68378cb 100644 --- a/metagpt/roles/di/role_zero.py +++ b/metagpt/roles/di/role_zero.py @@ -173,8 +173,7 @@ class RoleZero(Role): The role name will be used as the collection name. """ - enable_longterm_memory = bool(self.config.role_zero and self.config.role_zero.enable_longterm_memory) - if enable_longterm_memory: + if self.config.role_zero and self.config.role_zero.enable_longterm_memory: self.rc.memory = RoleZeroLongTermMemory(collection_name=self.name.replace(" ", ""), memory_k=self.memory_k) logger.info(f"Long-term memory set for role '{self.name}'") From 99e5a73fbb919a769a4a91a1040d2e44f4de4e33 Mon Sep 17 00:00:00 2001 From: seehi <6580@pm.me> Date: Thu, 12 Sep 2024 11:31:12 +0800 Subject: [PATCH 12/18] add `enable_longterm_memory` config --- config/config2.example.yaml | 2 +- metagpt/config2.py | 2 +- metagpt/roles/di/role_zero.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/config/config2.example.yaml b/config/config2.example.yaml index 1c0934567..e4dfff1eb 100644 --- a/config/config2.example.yaml +++ b/config/config2.example.yaml @@ -84,7 +84,7 @@ exp_pool: collection_name: experience_pool # When `retrieval_type` is `chroma`, `collection_name` is the collection name in chromadb. role_zero: - enable_longterm_memory: false # Whether to use long-term memory. + enable_longterm_memory: false # Whether to use long-term memory. Default is `false`. azure_tts_subscription_key: "YOUR_SUBSCRIPTION_KEY" azure_tts_region: "eastus" diff --git a/metagpt/config2.py b/metagpt/config2.py index c328713c5..fd0cb0948 100644 --- a/metagpt/config2.py +++ b/metagpt/config2.py @@ -91,7 +91,7 @@ class Config(CLIParams, YamlModel): roles: Optional[List[RoleCustomConfig]] = None # RoleZero's configuration - role_zero: Optional[RoleZeroConfig] = None + role_zero: RoleZeroConfig = Field(default_factory=RoleZeroConfig) omniparse: Optional[OmniParseConfig] = None diff --git a/metagpt/roles/di/role_zero.py b/metagpt/roles/di/role_zero.py index 9b68378cb..f77471a2a 100644 --- a/metagpt/roles/di/role_zero.py +++ b/metagpt/roles/di/role_zero.py @@ -173,7 +173,7 @@ class RoleZero(Role): The role name will be used as the collection name. 
""" - if self.config.role_zero and self.config.role_zero.enable_longterm_memory: + if self.config.role_zero.enable_longterm_memory: self.rc.memory = RoleZeroLongTermMemory(collection_name=self.name.replace(" ", ""), memory_k=self.memory_k) logger.info(f"Long-term memory set for role '{self.name}'") From ea9582cf0e8cc833613606de122419fb087b6fd8 Mon Sep 17 00:00:00 2001 From: seehi <6580@pm.me> Date: Thu, 12 Sep 2024 11:35:01 +0800 Subject: [PATCH 13/18] add `enable_longterm_memory` config --- metagpt/memory/role_zero_memory.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/metagpt/memory/role_zero_memory.py b/metagpt/memory/role_zero_memory.py index ff403d129..4840b5871 100644 --- a/metagpt/memory/role_zero_memory.py +++ b/metagpt/memory/role_zero_memory.py @@ -137,6 +137,8 @@ class RoleZeroLongTermMemory(Memory): return self.get_by_position(-1) def _is_last_message_from_user_requirement(self) -> bool: + """Checks if the last message is from a user requirement or sent by the team leader.""" + message = self._get_the_last_message() if not message: From a16535fb089e0970d2fcf12e3b11fca002ddb9f8 Mon Sep 17 00:00:00 2001 From: seehi <6580@pm.me> Date: Thu, 12 Sep 2024 13:52:50 +0800 Subject: [PATCH 14/18] add handle_exception --- metagpt/memory/memory.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/metagpt/memory/memory.py b/metagpt/memory/memory.py index d44753413..0707a36ea 100644 --- a/metagpt/memory/memory.py +++ b/metagpt/memory/memory.py @@ -7,13 +7,14 @@ @Modified By: mashenquan, 2023-11-1. According to RFC 116: Updated the type of index key. """ from collections import defaultdict -from typing import DefaultDict, Iterable, Set +from typing import DefaultDict, Iterable, Optional, Set from pydantic import BaseModel, Field, SerializeAsAny from metagpt.const import IGNORED_MESSAGE_ID from metagpt.schema import Message from metagpt.utils.common import any_to_str, any_to_str_set +from metagpt.utils.exceptions import handle_exception class Memory(BaseModel): @@ -105,6 +106,7 @@ class Memory(BaseModel): rsp += self.index[action] return rsp - def get_by_position(self, position: int) -> Message: - """Return the message by its position""" + @handle_exception + def get_by_position(self, position: int) -> Optional[Message]: + """Returns the message at the given position if valid, otherwise returns None""" return self.storage[position] From 7b6dc3d744cdbd083c7b72228a8e65273325ded8 Mon Sep 17 00:00:00 2001 From: seehi <6580@pm.me> Date: Fri, 20 Sep 2024 18:44:32 +0800 Subject: [PATCH 15/18] update llmranker prompt --- metagpt/rag/prompts/default_prompts.py | 1 + 1 file changed, 1 insertion(+) diff --git a/metagpt/rag/prompts/default_prompts.py b/metagpt/rag/prompts/default_prompts.py index 12a5e2f06..eadcaa770 100644 --- a/metagpt/rag/prompts/default_prompts.py +++ b/metagpt/rag/prompts/default_prompts.py @@ -23,6 +23,7 @@ Doc: 9, Relevance: 7 - Evaluate the relevance between the question and the documents. - The relevance score is a number from 1-10 based on how relevant you think the document is to the question. - Do not include any documents that are not relevant to the question. +- If none of the documents provided contain information that directly answers the question, simply respond with "no relevant documents". ## Constraint Format: Just print the result in format like **Format Example**. 
From 57b61083180f536557336acf197d9e9bf0cafb85 Mon Sep 17 00:00:00 2001 From: seehi <6580@pm.me> Date: Mon, 23 Sep 2024 13:41:26 +0800 Subject: [PATCH 16/18] opt RoleZeroLongTermMemory init --- metagpt/roles/di/role_zero.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/metagpt/roles/di/role_zero.py b/metagpt/roles/di/role_zero.py index f77471a2a..d2e132c4b 100644 --- a/metagpt/roles/di/role_zero.py +++ b/metagpt/roles/di/role_zero.py @@ -174,7 +174,13 @@ class RoleZero(Role): """ if self.config.role_zero.enable_longterm_memory: - self.rc.memory = RoleZeroLongTermMemory(collection_name=self.name.replace(" ", ""), memory_k=self.memory_k) + self.rc.memory = RoleZeroLongTermMemory( + collection_name=self.name.replace(" ", ""), + memory_k=self.memory_k, + storage=self.rc.memory.storage, + index=self.rc.memory.index, + ignore_id=self.rc.memory.ignore_id, + ) logger.info(f"Long-term memory set for role '{self.name}'") return self From 59d94f77603b189554d7344472c57b9f091aed1b Mon Sep 17 00:00:00 2001 From: seehi <6580@pm.me> Date: Mon, 23 Sep 2024 15:38:08 +0800 Subject: [PATCH 17/18] opt RoleZeroLongTermMemory init --- metagpt/roles/di/role_zero.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/metagpt/roles/di/role_zero.py b/metagpt/roles/di/role_zero.py index d2e132c4b..15e09db67 100644 --- a/metagpt/roles/di/role_zero.py +++ b/metagpt/roles/di/role_zero.py @@ -175,11 +175,9 @@ class RoleZero(Role): if self.config.role_zero.enable_longterm_memory: self.rc.memory = RoleZeroLongTermMemory( + **self.rc.memory.model_dump(), collection_name=self.name.replace(" ", ""), memory_k=self.memory_k, - storage=self.rc.memory.storage, - index=self.rc.memory.index, - ignore_id=self.rc.memory.ignore_id, ) logger.info(f"Long-term memory set for role '{self.name}'") From c69752a600517507c45433dc54c5c5ef8f1350fa Mon Sep 17 00:00:00 2001 From: seehi <6580@pm.me> Date: Tue, 24 Sep 2024 11:12:22 +0800 Subject: [PATCH 18/18] update comment --- metagpt/memory/role_zero_memory.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/metagpt/memory/role_zero_memory.py b/metagpt/memory/role_zero_memory.py index 4840b5871..857f2473b 100644 --- a/metagpt/memory/role_zero_memory.py +++ b/metagpt/memory/role_zero_memory.py @@ -1,3 +1,8 @@ +""" +This module implements a memory system combining short-term and long-term storage for AI role memory management. +It utilizes a RAG (Retrieval-Augmented Generation) engine for long-term memory storage and retrieval. +""" + from typing import TYPE_CHECKING, Any, Optional from pydantic import Field @@ -17,6 +22,12 @@ if TYPE_CHECKING: class RoleZeroLongTermMemory(Memory): + """ + Implements a memory system combining short-term and long-term storage using a RAG engine. + Transfers old memories to long-term storage when short-term capacity is reached. + Retrieves combined short-term and long-term memories as needed. + """ + persist_path: str = Field(default=".role_memory_data", description="The directory to save data.") collection_name: str = Field(default="role_zero", description="The name of the collection, such as the role name.") memory_k: int = Field(default=200, description="The capacity of short-term memory.")
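
With the series complete, `RoleZeroLongTermMemory` acts as a drop-in replacement for `Memory`: `add()` still appends to short-term storage but, once the message count grows past `memory_k`, transfers the message falling out of the recent window into the Chroma-backed RAG engine under `persist_path`; `get()` prepends any retrieved long-term memories to the recent ones when the last message qualifies as a user requirement. A hypothetical usage sketch, assuming the rag extras are installed and `role_zero.enable_longterm_memory: true` is set in config2.yaml; the role name and message content are illustrative, not taken from the patches:

    from metagpt.memory.role_zero_memory import RoleZeroLongTermMemory
    from metagpt.schema import UserMessage

    memory = RoleZeroLongTermMemory(collection_name="Alice", memory_k=200)
    memory.add(UserMessage(content="Build a 2048 game"))  # short-term only for now
    # Long-term retrieval kicks in only after storage exceeds memory_k and the
    # last message is a user requirement; until then get() just returns the
    # recent window.
    recent = memory.get(k=200)  # long-term hits, if any, come first in the list

In everyday use none of this is called directly: enabling the config key makes `RoleZero.set_longterm_memory` swap `rc.memory` for this class, and the rest of the role code keeps calling plain `rc.memory.add()` / `rc.memory.get()`.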