From c0d5c031f8c2d8a2f3f45a40cce3294733182e7b Mon Sep 17 00:00:00 2001 From: seehi <6580@pm.me> Date: Wed, 11 Sep 2024 20:25:29 +0800 Subject: [PATCH] Change the way of taking over memory --- metagpt/const.py | 4 +- metagpt/environment/mgx/mgx_env.py | 4 +- metagpt/memory/role_zero_memory.py | 130 +++++++++-- metagpt/roles/di/role_zero.py | 130 ++--------- metagpt/roles/di/team_leader.py | 5 +- metagpt/schema.py | 5 +- tests/metagpt/memory/test_role_zero_memory.py | 191 +++++++++++----- tests/metagpt/roles/di/test_role_zero.py | 208 ------------------ 8 files changed, 270 insertions(+), 407 deletions(-) delete mode 100644 tests/metagpt/roles/di/test_role_zero.py diff --git a/metagpt/const.py b/metagpt/const.py index e7a0dc31b..4fe8dca3d 100644 --- a/metagpt/const.py +++ b/metagpt/const.py @@ -158,5 +158,5 @@ SWE_SETUP_PATH = get_metagpt_package_root() / "metagpt/tools/swe_agent_commands/ # experience pool EXPERIENCE_MASK = "" -# Used to identify user requirements in the memory index. -USER_REQUIREMENT = "metagpt.actions.add_requirement.UserRequirement" +# TeamLeader's name +TEAMLEADER_NAME = "Mike" diff --git a/metagpt/environment/mgx/mgx_env.py b/metagpt/environment/mgx/mgx_env.py index 4df04d3ce..53690f7d7 100644 --- a/metagpt/environment/mgx/mgx_env.py +++ b/metagpt/environment/mgx/mgx_env.py @@ -8,7 +8,7 @@ from metagpt.actions import ( WriteTest, ) from metagpt.actions.summarize_code import SummarizeCode -from metagpt.const import AGENT, IMAGES +from metagpt.const import AGENT, IMAGES, TEAMLEADER_NAME from metagpt.environment.base_env import Environment from metagpt.logs import get_human_input from metagpt.roles import Architect, ProductManager, ProjectManager, Role @@ -31,7 +31,7 @@ class MGXEnv(Environment, SerializationMixin): """let the team leader take over message publishing""" message = self.attach_images(message) # for multi-modal message - tl = self.get_role("Mike") # TeamLeader's name is Mike + tl = self.get_role(TEAMLEADER_NAME) # TeamLeader's name is Mike if user_defined_recipient: # human user's direct chat message to a certain role diff --git a/metagpt/memory/role_zero_memory.py b/metagpt/memory/role_zero_memory.py index 7f777df56..ff403d129 100644 --- a/metagpt/memory/role_zero_memory.py +++ b/metagpt/memory/role_zero_memory.py @@ -1,8 +1,14 @@ -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Optional -from pydantic import BaseModel, Field +from pydantic import Field +from metagpt.actions import UserRequirement +from metagpt.const import TEAMLEADER_NAME +from metagpt.logs import logger +from metagpt.memory import Memory from metagpt.schema import LongTermMemoryItem, Message +from metagpt.utils.common import any_to_str +from metagpt.utils.exceptions import handle_exception if TYPE_CHECKING: from llama_index.core.schema import NodeWithScore @@ -10,9 +16,10 @@ if TYPE_CHECKING: from metagpt.rag.engines import SimpleEngine -class RoleZeroLongTermMemory(BaseModel): +class RoleZeroLongTermMemory(Memory): persist_path: str = Field(default=".role_memory_data", description="The directory to save data.") collection_name: str = Field(default="role_zero", description="The name of the collection, such as the role name.") + memory_k: int = Field(default=200, description="The capacity of short-term memory.") _rag_engine: Any = None @@ -44,7 +51,104 @@ class RoleZeroLongTermMemory(BaseModel): return rag_engine - def fetch(self, query: str) -> list[Message]: + def add(self, message: Message): + """Add a new message and potentially transfer it to long-term memory.""" + + super().add(message) + + if not self._should_use_longterm_memory_for_add(): + return + + self._transfer_to_longterm_memory() + + def get(self, k=0) -> list[Message]: + """Return recent memories and optionally combines them with related long-term memories.""" + + memories = super().get(k) + + if not self._should_use_longterm_memory_for_get(k=k): + return memories + + query = self._build_longterm_memory_query() + related_memories = self._fetch_longterm_memories(query) + logger.info(f"Fetched {len(related_memories)} long-term memories.") + + final_memories = related_memories + memories + + return final_memories + + def _should_use_longterm_memory_for_add(self) -> bool: + """Determines if long-term memory should be used for add.""" + + return self.count() > self.memory_k + + def _should_use_longterm_memory_for_get(self, k: int) -> bool: + """Determines if long-term memory should be used for get. + + Long-term memory is used if: + - k is not 0. + - The last message is from user requirement. + - The count of recent memories is greater than self.memory_k. + """ + + conds = [ + k != 0, + self._is_last_message_from_user_requirement(), + self.count() > self.memory_k, + ] + + return all(conds) + + def _transfer_to_longterm_memory(self): + item = self._get_longterm_memory_item() + self._add_to_longterm_memory(item) + + @handle_exception + def _get_longterm_memory_item(self) -> Optional[LongTermMemoryItem]: + """Retrieves the most recent message before the last k messages.""" + + index = -(self.memory_k + 1) + message = self.get_by_position(index) + + return LongTermMemoryItem(message=message) + + def _add_to_longterm_memory(self, item: LongTermMemoryItem): + """Adds a long-term memory item to the RAG engine.""" + + if not item: + return + + self.rag_engine.add_objs([item]) + + def _build_longterm_memory_query(self) -> str: + """Build the content used to query related long-term memory. + + Default is to get the most recent user message, or an empty string if none is found. + """ + + message = self._get_the_last_message() + + return message.content if message else "" + + def _get_the_last_message(self) -> Optional[Message]: + if not self.count(): + return None + + return self.get_by_position(-1) + + def _is_last_message_from_user_requirement(self) -> bool: + message = self._get_the_last_message() + + if not message: + return False + + is_user_message = message.is_user_message() + cause_by_user_requirement = message.cause_by == any_to_str(UserRequirement) + sent_from_team_leader = message.sent_from == TEAMLEADER_NAME + + return is_user_message and (cause_by_user_requirement or sent_from_team_leader) + + def _fetch_longterm_memories(self, query: str) -> list[Message]: """Fetches long-term memories based on a query. Args: @@ -59,26 +163,10 @@ class RoleZeroLongTermMemory(BaseModel): nodes = self.rag_engine.retrieve(query) items = self._get_items_from_nodes(nodes) - - memories = [] - for item in items: - memories.append(item.user_message) - memories.append(item.ai_message) + memories = [item.message for item in items] return memories - def add(self, item: LongTermMemoryItem): - """Adds a long-term memory item to the RAG engine. - - Args: - item (LongTermMemoryItem): The memory item containing user and AI messages. - """ - - if not item: - return - - self.rag_engine.add_objs([item]) - def _get_items_from_nodes(self, nodes: list["NodeWithScore"]) -> list[LongTermMemoryItem]: """Get items from nodes and arrange them in order of their `created_at`.""" diff --git a/metagpt/roles/di/role_zero.py b/metagpt/roles/di/role_zero.py index 9e64a1954..466d87f5c 100644 --- a/metagpt/roles/di/role_zero.py +++ b/metagpt/roles/di/role_zero.py @@ -13,7 +13,7 @@ from metagpt.actions import Action, UserRequirement from metagpt.actions.analyze_requirements import AnalyzeRequirementsRestrictions from metagpt.actions.di.run_command import RunCommand from metagpt.actions.search_enhanced_qa import SearchEnhancedQA -from metagpt.const import IMAGES, USER_REQUIREMENT +from metagpt.const import IMAGES from metagpt.exp_pool import exp_cache from metagpt.exp_pool.context_builders import RoleZeroContextBuilder from metagpt.exp_pool.serializers import RoleZeroSerializer @@ -34,7 +34,7 @@ from metagpt.prompts.di.role_zero import ( SYSTEM_PROMPT, ) from metagpt.roles import Role -from metagpt.schema import AIMessage, LongTermMemoryItem, Message, UserMessage +from metagpt.schema import AIMessage, Message, UserMessage from metagpt.strategy.experience_retriever import DummyExpRetriever, ExpRetriever from metagpt.strategy.planner import Planner from metagpt.tools.libs.browser import Browser @@ -42,7 +42,6 @@ from metagpt.tools.libs.editor import Editor from metagpt.tools.tool_recommend import BM25ToolRecommender, ToolRecommender from metagpt.tools.tool_registry import register_tool from metagpt.utils.common import CodeParser, any_to_str, extract_and_encode_images -from metagpt.utils.exceptions import handle_exception from metagpt.utils.repair_llm_raw_output import ( RepairType, repair_escape_error, @@ -94,7 +93,6 @@ class RoleZero(Role): commands: list[dict] = [] # commands to be executed memory_k: int = 200 # number of memories (messages) to use as historical context enable_longterm_memory: bool = True # whether to use longterm memory - longterm_memory: RoleZeroLongTermMemory = None use_fixed_sop: bool = False requirements_constraints: str = "" # the constraints in user requirements use_summary: bool = True # whether to summarize at the end @@ -176,8 +174,8 @@ class RoleZero(Role): The role name will be used as the collection name. """ - if self.enable_longterm_memory and not self.longterm_memory: - self.longterm_memory = RoleZeroLongTermMemory(collection_name=self.name.replace(" ", "")) + if self.enable_longterm_memory: + self.rc.memory = RoleZeroLongTermMemory(collection_name=self.name.replace(" ", ""), memory_k=self.memory_k) return self @@ -195,7 +193,7 @@ class RoleZero(Role): return False if not self.planner.plan.goal: - self.planner.plan.goal = self._get_all_memories()[-1].content + self.planner.plan.goal = self.get_memories()[-1].content self.requirements_constraints = await AnalyzeRequirementsRestrictions().run(self.planner.plan.goal) ### 1. Experience ### @@ -227,7 +225,7 @@ class RoleZero(Role): ) ### Recent Observation ### - memory = self._fetch_memories() + memory = self.rc.memory.get(self.memory_k) memory = await self.parse_browser_actions(memory) memory = self.parse_images(memory) @@ -282,15 +280,15 @@ class RoleZero(Role): return await super()._act() commands, ok, self.command_rsp = await self._parse_commands(self.command_rsp) - self._add_memory(AIMessage(content=self.command_rsp)) + self.rc.memory.add(AIMessage(content=self.command_rsp)) if not ok: error_msg = commands - self._add_memory(UserMessage(content=error_msg, cause_by=RunCommand)) + self.rc.memory.add(UserMessage(content=error_msg, cause_by=RunCommand)) return error_msg logger.info(f"Commands: \n{commands}") outputs = await self._run_commands(commands) logger.info(f"Commands outputs: \n{outputs}") - self._add_memory(UserMessage(content=outputs, cause_by=RunCommand)) + self.rc.memory.add(UserMessage(content=outputs, cause_by=RunCommand)) return AIMessage( content=f"I have finished the task, please mark my task as finished. Outputs: {outputs}", @@ -343,7 +341,7 @@ class RoleZero(Role): return rsp_msg, "" # routing - memory = self._fetch_memories() + memory = self.get_memories(k=self.memory_k) context = self.llm.format_msg(memory + [UserMessage(content=QUICK_THINK_PROMPT)]) async with ThoughtReporter() as reporter: await reporter.async_report({"type": "classify"}) @@ -368,7 +366,7 @@ class RoleZero(Role): answer = await SearchEnhancedQA().run(query) if answer: - self._add_memory(AIMessage(content=answer, cause_by=RunCommand)) + self.rc.memory.add(AIMessage(content=answer, cause_by=RunCommand)) await self.reply_to_human(content=answer) rsp_msg = AIMessage( content="Complete run", @@ -379,7 +377,7 @@ class RoleZero(Role): return rsp_msg, intent_result async def _check_duplicates(self, req: list[dict], command_rsp: str): - past_rsp = [mem.content for mem in self._fetch_memories()] + past_rsp = [mem.content for mem in self.rc.memory.get(self.memory_k)] if command_rsp in past_rsp: # Normal response with thought contents are highly unlikely to reproduce # If an identical response is detected, it is a bad response, mostly due to LLM repeating generated content @@ -537,7 +535,7 @@ class RoleZero(Role): def _retrieve_experience(self) -> str: """Default implementation of experience retrieval. Can be overwritten in subclasses.""" - context = [str(msg) for msg in self._fetch_memories()] + context = [str(msg) for msg in self.rc.memory.get(self.memory_k)] context = "\n\n".join(context) example = self.experience_retriever.retrieve(context=context) return example @@ -562,9 +560,9 @@ class RoleZero(Role): async def _end(self, **kwarg): self._set_state(-1) - memory = self._fetch_memories() + memory = self.rc.memory.get(self.memory_k) # Ensure reply to the human before the "end" command is executed. Hard code k=5 for checking. - if not any(["reply_to_human" in memory.content for memory in self._fetch_memories(k=5)]): + if not any(["reply_to_human" in memory.content for memory in self.get_memories(k=5)]): logger.info("manually reply to human") pattern = r"\[Language Restrictions\](.*?)\n" match = re.search(pattern, self.requirements_constraints, re.DOTALL) @@ -573,106 +571,10 @@ class RoleZero(Role): await reporter.async_report({"type": "quick"}) reply_content = await self.llm.aask(self.llm.format_msg(memory + [UserMessage(reply_to_human_prompt)])) await self.reply_to_human(content=reply_content) - self._add_memory(AIMessage(content=reply_content, cause_by=RunCommand)) + self.rc.memory.add(AIMessage(content=reply_content, cause_by=RunCommand)) outputs = "" # Summary of the Completed Task and Deliverables if self.use_summary: logger.info("end current run and summarize") outputs = await self.llm.aask(self.llm.format_msg(memory + [UserMessage(SUMMARY_PROMPT)])) return outputs - - def _get_all_memories(self) -> list[Message]: - return self._fetch_memories(k=0) - - def _fetch_memories(self, k: Optional[int] = None) -> list[Message]: - """Fetches recent memories and optionally combines them with related long-term memories. - - If long-term memory is not enabled or the last message is not from the user, - it returns the recent memories without fetching from long-term memory. - - Args: - k (Optional[int]): The number of recent memories to fetch. If None, defaults to self.memory_k. - - Returns: - List[Message]: A list of messages representing the combined memories. - """ - - if k is None: - k = self.memory_k - - memories = self.rc.memory.get(k) - - if not self._should_use_longterm_memory(k=k): - return memories - - query = self._build_longterm_memory_query() - related_memories = self.longterm_memory.fetch(query) - logger.info(f"Fetched {len(related_memories)} long-term memories.") - - # Keep user and AI messages are paired. - if self._is_first_message_from_ai(memories): - memories.insert(0, self.rc.memory.get_by_position(-(k + 1))) - - final_memories = related_memories + memories - - return final_memories - - def _add_memory(self, message: Message): - self.rc.memory.add(message) - - if not self._should_use_longterm_memory(): - return - - self._transfer_to_longterm_memory() - - def _should_use_longterm_memory(self, k: int = None) -> bool: - """Determines if long-term memory should be used. - - Long-term memory is used if: - - k is not 0. - - Long-term memory usage is enabled. - - The count of recent memories is greater than self.memory_k. - """ - - conds = [ - k != 0, - self.enable_longterm_memory, - self.rc.memory.count() > self.memory_k, - ] - - return all(conds) - - def _transfer_to_longterm_memory(self): - item = self._get_longterm_memory_item() - self.longterm_memory.add(item) - - @handle_exception - def _get_longterm_memory_item(self) -> Optional[LongTermMemoryItem]: - """Retrieves the most recent pair of user and AI messages before the last k messages.""" - - index = -(self.memory_k + 1) - message = self.rc.memory.get_by_position(index) - if not message.is_ai_message(): - return None - - index = -(self.memory_k + 2) - user_message = self.rc.memory.get_by_position(index) - - return LongTermMemoryItem(user_message=user_message, ai_message=message) - - def _is_first_message_from_ai(self, memories: list[Message]) -> bool: - return bool(memories and memories[0].is_ai_message()) - - def _build_longterm_memory_query(self) -> str: - """Build the content used to query related long-term memory. - - Default is to get the most recent user message, or an empty string if none is found. - """ - message = self._get_the_last_user_message() - - return message.content if message else "" - - def _get_the_last_user_message(self) -> Message: - values = self.rc.memory.index.get(USER_REQUIREMENT, []) - - return values[-1] if values else None diff --git a/metagpt/roles/di/team_leader.py b/metagpt/roles/di/team_leader.py index 1bf6364d9..0724ffdea 100644 --- a/metagpt/roles/di/team_leader.py +++ b/metagpt/roles/di/team_leader.py @@ -5,6 +5,7 @@ from typing import Annotated from pydantic import Field from metagpt.actions.di.run_command import RunCommand +from metagpt.const import TEAMLEADER_NAME from metagpt.prompts.di.team_leader import ( FINISH_CURRENT_TASK_CMD, TL_INFO, @@ -19,7 +20,7 @@ from metagpt.tools.tool_registry import register_tool @register_tool(include_functions=["publish_team_message"]) class TeamLeader(RoleZero): - name: str = "Mike" + name: str = TEAMLEADER_NAME profile: str = "Team Leader" goal: str = "Manage a team to assist users" thought_guidance: str = TL_THOUGHT_GUIDANCE @@ -84,4 +85,4 @@ class TeamLeader(RoleZero): def finish_current_task(self): self.planner.plan.finish_current_task() - self._add_memory(AIMessage(content=FINISH_CURRENT_TASK_CMD)) + self.rc.memory.add(AIMessage(content=FINISH_CURRENT_TASK_CMD)) diff --git a/metagpt/schema.py b/metagpt/schema.py index 9352664e2..d4c75bfb4 100644 --- a/metagpt/schema.py +++ b/metagpt/schema.py @@ -965,9 +965,8 @@ class BaseEnum(Enum): class LongTermMemoryItem(BaseModel): - user_message: Message - ai_message: Message + message: Message created_at: Optional[float] = Field(default_factory=time.time) def rag_key(self) -> str: - return self.user_message.content + return self.message.content diff --git a/tests/metagpt/memory/test_role_zero_memory.py b/tests/metagpt/memory/test_role_zero_memory.py index 1c6fb785e..80eb58e49 100644 --- a/tests/metagpt/memory/test_role_zero_memory.py +++ b/tests/metagpt/memory/test_role_zero_memory.py @@ -2,8 +2,10 @@ from datetime import datetime, timedelta import pytest +from metagpt.actions import UserRequirement +from metagpt.const import TEAMLEADER_NAME from metagpt.memory.role_zero_memory import RoleZeroLongTermMemory -from metagpt.schema import AIMessage, LongTermMemoryItem, UserMessage +from metagpt.schema import AIMessage, LongTermMemoryItem, Message, UserMessage class TestRoleZeroLongTermMemory: @@ -13,71 +15,150 @@ class TestRoleZeroLongTermMemory: memory._resolve_rag_engine = mocker.Mock() return memory - def test_fetch_empty_query(self, mock_memory: RoleZeroLongTermMemory): - assert mock_memory.fetch("") == [] + def test_add(self, mocker, mock_memory: RoleZeroLongTermMemory): + mock_memory._should_use_longterm_memory_for_add = mocker.Mock(return_value=True) + mock_memory._transfer_to_longterm_memory = mocker.Mock() - def test_fetch(self, mocker, mock_memory: RoleZeroLongTermMemory): - mock_node1 = mocker.Mock() - mock_node2 = mocker.Mock() - mock_node1.metadata = { - "obj": LongTermMemoryItem(user_message=UserMessage(content="user1"), ai_message=AIMessage(content="ai1")) - } - mock_node2.metadata = { - "obj": LongTermMemoryItem(user_message=UserMessage(content="user2"), ai_message=AIMessage(content="ai2")) - } + message = UserMessage(content="test") + mock_memory.add(message) - mock_memory.rag_engine.retrieve.return_value = [mock_node1, mock_node2] + assert mock_memory.storage[-1] == message + mock_memory._transfer_to_longterm_memory.assert_called_once() - result = mock_memory.fetch("test query") + def test_get(self, mocker, mock_memory: RoleZeroLongTermMemory): + mock_memory._should_use_longterm_memory_for_get = mocker.Mock(return_value=True) + mock_memory._build_longterm_memory_query = mocker.Mock(return_value="query") + mock_memory._fetch_longterm_memories = mocker.Mock(return_value=[Message(content="long-term")]) - assert len(result) == 4 - assert isinstance(result[0], UserMessage) - assert isinstance(result[1], AIMessage) - assert result[0].content == "user1" - assert result[1].content == "ai1" - assert result[2].content == "user2" - assert result[3].content == "ai2" + mock_memory.storage = [Message(content="short-term")] - mock_memory.rag_engine.retrieve.assert_called_once_with("test query") + result = mock_memory.get() - def test_add_empty_item(self, mock_memory: RoleZeroLongTermMemory): - mock_memory.add(None) - mock_memory.rag_engine.add_objs.assert_not_called() + assert len(result) == 2 + assert result[0].content == "long-term" + assert result[1].content == "short-term" + + def test_should_use_longterm_memory_for_add(self, mocker, mock_memory: RoleZeroLongTermMemory): + mocker.patch.object(mock_memory, "storage", [None] * 201) + + mock_memory.memory_k = 200 + + assert mock_memory._should_use_longterm_memory_for_add() == True + + mocker.patch.object(mock_memory, "storage", [None] * 199) + assert mock_memory._should_use_longterm_memory_for_add() == False + + @pytest.mark.parametrize( + "k,is_last_from_user,count,expected", + [ + (0, True, 201, False), + (1, False, 201, False), + (1, True, 199, False), + (1, True, 201, True), + ], + ) + def test_should_use_longterm_memory_for_get( + self, mocker, mock_memory: RoleZeroLongTermMemory, k, is_last_from_user, count, expected + ): + mock_memory._is_last_message_from_user_requirement = mocker.Mock(return_value=is_last_from_user) + mocker.patch.object(mock_memory, "storage", [None] * count) + mock_memory.memory_k = 200 + + assert mock_memory._should_use_longterm_memory_for_get(k) == expected + + def test_transfer_to_longterm_memory(self, mocker, mock_memory: RoleZeroLongTermMemory): + mock_item = mocker.Mock() + mock_memory._get_longterm_memory_item = mocker.Mock(return_value=mock_item) + mock_memory._add_to_longterm_memory = mocker.Mock() + + mock_memory._transfer_to_longterm_memory() + + mock_memory._add_to_longterm_memory.assert_called_once_with(mock_item) + + def test_get_longterm_memory_item(self, mocker, mock_memory: RoleZeroLongTermMemory): + mock_message = Message(content="test") + mock_memory.storage = [mock_message, mock_message] + mock_memory.memory_k = 1 + + result = mock_memory._get_longterm_memory_item() + + assert isinstance(result, LongTermMemoryItem) + assert result.message == mock_message + + def test_add_to_longterm_memory(self, mock_memory: RoleZeroLongTermMemory): + item = LongTermMemoryItem(message=Message(content="test")) + mock_memory._add_to_longterm_memory(item) - def test_add_item(self, mock_memory: RoleZeroLongTermMemory): - item = LongTermMemoryItem(user_message=UserMessage(content="user"), ai_message=AIMessage(content="ai")) - mock_memory.add(item) mock_memory.rag_engine.add_objs.assert_called_once_with([item]) + def test_build_longterm_memory_query(self, mocker, mock_memory: RoleZeroLongTermMemory): + mock_message = Message(content="query") + mock_memory._get_the_last_message = mocker.Mock(return_value=mock_message) + + result = mock_memory._build_longterm_memory_query() + + assert result == "query" + + def test_get_the_last_message(self, mock_memory: RoleZeroLongTermMemory): + mock_memory.storage = [Message(content="1"), Message(content="2")] + + result = mock_memory._get_the_last_message() + + assert result.content == "2" + + @pytest.mark.parametrize( + "message,expected", + [ + (UserMessage(content="test", cause_by=UserRequirement), True), + (UserMessage(content="test", sent_from=TEAMLEADER_NAME), True), + (UserMessage(content="test"), True), + (AIMessage(content="test"), False), + (None, False), + ], + ) + def test_is_last_message_from_user_requirement( + self, mocker, mock_memory: RoleZeroLongTermMemory, message, expected + ): + mock_memory._get_the_last_message = mocker.Mock(return_value=message) + + assert mock_memory._is_last_message_from_user_requirement() == expected + + def test_fetch_longterm_memories(self, mocker, mock_memory: RoleZeroLongTermMemory): + mock_nodes = [mocker.Mock(), mocker.Mock()] + mock_memory.rag_engine.retrieve = mocker.Mock(return_value=mock_nodes) + mock_items = [ + LongTermMemoryItem(message=UserMessage(content="user1")), + LongTermMemoryItem(message=AIMessage(content="ai1")), + ] + mock_memory._get_items_from_nodes = mocker.Mock(return_value=mock_items) + + result = mock_memory._fetch_longterm_memories("query") + + assert len(result) == 2 + assert result[0].content == "user1" + assert result[1].content == "ai1" + def test_get_items_from_nodes(self, mocker, mock_memory: RoleZeroLongTermMemory): - mock_node1 = mocker.Mock() - mock_node2 = mocker.Mock() - mock_node3 = mocker.Mock() - now = datetime.now() - item1 = LongTermMemoryItem( - user_message=UserMessage(content="user1"), ai_message=AIMessage(content="ai1"), created_at=now.timestamp() - ) - item2 = LongTermMemoryItem( - user_message=UserMessage(content="user2"), - ai_message=AIMessage(content="ai2"), - created_at=(now - timedelta(minutes=5)).timestamp(), - ) - item3 = LongTermMemoryItem( - user_message=UserMessage(content="user3"), - ai_message=AIMessage(content="ai3"), - created_at=(now + timedelta(minutes=5)).timestamp(), - ) + mock_nodes = [ + mocker.Mock( + metadata={ + "obj": LongTermMemoryItem( + message=Message(content="2"), created_at=(now - timedelta(minutes=1)).timestamp() + ) + } + ), + mocker.Mock( + metadata={ + "obj": LongTermMemoryItem( + message=Message(content="1"), created_at=(now - timedelta(minutes=2)).timestamp() + ) + } + ), + mocker.Mock(metadata={"obj": LongTermMemoryItem(message=Message(content="3"), created_at=now.timestamp())}), + ] - mock_node1.metadata = {"obj": item1} - mock_node2.metadata = {"obj": item2} - mock_node3.metadata = {"obj": item3} - - result = mock_memory._get_items_from_nodes([mock_node1, mock_node2, mock_node3]) + result = mock_memory._get_items_from_nodes(mock_nodes) assert len(result) == 3 - assert result[0] == item2 - assert result[1] == item1 - assert result[2] == item3 - assert [item.user_message.content for item in result] == ["user2", "user1", "user3"] - assert [item.ai_message.content for item in result] == ["ai2", "ai1", "ai3"] + assert [item.message.content for item in result] == ["1", "2", "3"] diff --git a/tests/metagpt/roles/di/test_role_zero.py b/tests/metagpt/roles/di/test_role_zero.py deleted file mode 100644 index 0d427ce0f..000000000 --- a/tests/metagpt/roles/di/test_role_zero.py +++ /dev/null @@ -1,208 +0,0 @@ -import pytest - -from metagpt.roles.di.role_zero import RoleZero -from metagpt.schema import AIMessage, LongTermMemoryItem, Message, UserMessage - - -class TestRoleZero: - @pytest.fixture - def mock_role_zero(self, mocker): - role_zero = RoleZero() - role_zero.rc.memory = mocker.Mock() - role_zero.longterm_memory = mocker.Mock() - return role_zero - - def test_get_all_memories(self, mocker, mock_role_zero: RoleZero): - mock_memories = [Message(content="test1"), Message(content="test2")] - mock_role_zero._fetch_memories = mocker.Mock(return_value=mock_memories) - - result = mock_role_zero._get_all_memories() - - assert result == mock_memories - mock_role_zero._fetch_memories.assert_called_once_with(k=0) - - @pytest.mark.parametrize( - "k,should_use_ltm,memories,related_memories,is_first_from_ai,expected", - [ - ( - None, - False, - [UserMessage(content="user"), AIMessage(content="ai")], - [], - False, - [UserMessage(content="user"), AIMessage(content="ai")], - ), - ( - 1, - True, - [UserMessage(content="user")], - [Message(content="related")], - False, - [Message(content="related"), UserMessage(content="user")], - ), - ( - None, - True, - [AIMessage(content="ai1"), UserMessage(content="user"), AIMessage(content="ai2")], - [Message(content="related")], - True, - [ - Message(content="related"), - UserMessage(content="user"), - AIMessage(content="ai1"), - UserMessage(content="user"), - AIMessage(content="ai2"), - ], - ), - ( - None, - True, - [UserMessage(content="user"), AIMessage(content="ai")], - [Message(content="related")], - False, - [Message(content="related"), UserMessage(content="user"), AIMessage(content="ai")], - ), - ( - 0, - False, - [UserMessage(content="user"), AIMessage(content="ai")], - [], - False, - [UserMessage(content="user"), AIMessage(content="ai")], - ), - ], - ) - def test_fetch_memories( - self, - mocker, - mock_role_zero: RoleZero, - k, - should_use_ltm, - memories, - related_memories, - is_first_from_ai, - expected, - ): - mock_role_zero.memory_k = 2 - mock_role_zero.rc.memory.get = mocker.Mock(return_value=memories) - mock_role_zero.rc.memory.get_by_position = mocker.Mock(return_value=UserMessage(content="user")) - mock_role_zero._should_use_longterm_memory = mocker.Mock(return_value=should_use_ltm) - mock_role_zero.longterm_memory.fetch = mocker.Mock(return_value=related_memories) - mock_role_zero._is_first_message_from_ai = mocker.Mock(return_value=is_first_from_ai) - - result = mock_role_zero._fetch_memories(k) - - assert len(result) == len(expected) - for actual, expected_msg in zip(result, expected): - assert actual.role == expected_msg.role - assert actual.content == expected_msg.content - - really_k = k if k is not None else mock_role_zero.memory_k - mock_role_zero.rc.memory.get.assert_called_once_with(really_k) - - if k != 0: - mock_role_zero._should_use_longterm_memory.assert_called_once_with(k=really_k) - - if should_use_ltm: - mock_role_zero.longterm_memory.fetch.assert_called_once_with("user") - mock_role_zero._is_first_message_from_ai.assert_called_once_with(memories) - - def test_add_memory(self, mocker, mock_role_zero: RoleZero): - message = AIMessage(content="ai") - mock_role_zero.rc.memory.add = mocker.Mock() - mock_role_zero._should_use_longterm_memory = mocker.Mock(return_value=True) - mock_role_zero._transfer_to_longterm_memory = mocker.Mock() - - mock_role_zero._add_memory(message) - - mock_role_zero.rc.memory.add.assert_called_once_with(message) - mock_role_zero._transfer_to_longterm_memory.assert_called_once() - - @pytest.mark.parametrize( - "k,enable_longterm_memory,memory_count,k_memories,expected", - [ - (0, True, 30, None, False), # k is 0 - (None, False, 30, None, False), # Long-term memory usage is disabled - (None, True, 10, None, False), # Memory count is less than or equal to memory_k - (None, True, 30, [], False), # k_memories is empty - (None, True, 30, [AIMessage(content="ai")], False), # Last message in k_memories is not a user message - (None, True, 30, [AIMessage(content="ai"), UserMessage(content="user")], True), # All conditions are met - ], - ) - def test_should_use_longterm_memory( - self, mocker, mock_role_zero: RoleZero, k, enable_longterm_memory, memory_count, k_memories, expected - ): - mock_role_zero.enable_longterm_memory = enable_longterm_memory - mock_role_zero.rc.memory.count = mocker.Mock(return_value=memory_count) - mock_role_zero.memory_k = 20 - - result = mock_role_zero._should_use_longterm_memory(k, k_memories) - - assert result == expected - - def test_transfer_to_longterm_memory(self, mocker, mock_role_zero: RoleZero): - mock_item = LongTermMemoryItem(user_message=UserMessage(content="user"), ai_message=AIMessage(content="ai")) - mock_role_zero._get_longterm_memory_item = mocker.Mock(return_value=mock_item) - mock_role_zero.longterm_memory = mocker.Mock() - - mock_role_zero._transfer_to_longterm_memory() - - mock_role_zero.longterm_memory.add.assert_called_once_with(mock_item) - - def test_get_longterm_memory_item(self, mocker, mock_role_zero: RoleZero): - mock_role_zero.memory_k = 2 - mock_messages = [ - UserMessage(content="user1"), - AIMessage(content="ai1"), - UserMessage(content="user2"), - AIMessage(content="ai2"), - UserMessage(content="user3"), # memory_k + 2 - AIMessage(content="ai3"), # memory_k + 1 - UserMessage(content="recent1"), - AIMessage(content="recent2"), - ] - - mock_role_zero.rc.memory.get_by_position = mocker.Mock(side_effect=lambda i: mock_messages[i]) - mock_role_zero.rc.memory.count = mocker.Mock(return_value=len(mock_messages)) - - result = mock_role_zero._get_longterm_memory_item() - - assert isinstance(result, LongTermMemoryItem) - assert result.user_message.content == "user3" - assert result.ai_message.content == "ai3" - - mock_role_zero.rc.memory.get_by_position.assert_any_call(-(mock_role_zero.memory_k + 1)) - mock_role_zero.rc.memory.get_by_position.assert_any_call(-(mock_role_zero.memory_k + 2)) - - @pytest.mark.parametrize( - "memories,expected", - [ - ([AIMessage(content="ai")], True), - ([UserMessage(content="user")], False), - ([], False), - ], - ) - def test_is_first_message_from_ai(self, mock_role_zero: RoleZero, memories, expected): - result = mock_role_zero._is_first_message_from_ai(memories) - assert result == expected - - @pytest.mark.parametrize( - "memories,expected", - [ - ([UserMessage(content="user1"), AIMessage(content="ai"), UserMessage(content="user2")], "user2"), - ( - [ - UserMessage(content="user1", cause_by="test"), - AIMessage(content="ai"), - UserMessage(content="user2", cause_by="test"), - ], - "", - ), - ([AIMessage(content="ai1"), AIMessage(content="ai2")], ""), - ([UserMessage(content="user")], "user"), - ([], ""), - ], - ) - def test_build_longterm_memory_query(self, mock_role_zero: RoleZero, memories, expected): - result = mock_role_zero._build_longterm_memory_query(memories) - assert result == expected