diff --git a/metagpt/const.py b/metagpt/const.py index c53e8494a..e7a0dc31b 100644 --- a/metagpt/const.py +++ b/metagpt/const.py @@ -157,3 +157,6 @@ SWE_SETUP_PATH = get_metagpt_package_root() / "metagpt/tools/swe_agent_commands/ # experience pool EXPERIENCE_MASK = "" + +# Used to identify user requirements in the memory index. +USER_REQUIREMENT = "metagpt.actions.add_requirement.UserRequirement" diff --git a/metagpt/roles/di/role_zero.py b/metagpt/roles/di/role_zero.py index 5818bc1f3..9e64a1954 100644 --- a/metagpt/roles/di/role_zero.py +++ b/metagpt/roles/di/role_zero.py @@ -13,7 +13,7 @@ from metagpt.actions import Action, UserRequirement from metagpt.actions.analyze_requirements import AnalyzeRequirementsRestrictions from metagpt.actions.di.run_command import RunCommand from metagpt.actions.search_enhanced_qa import SearchEnhancedQA -from metagpt.const import IMAGES +from metagpt.const import IMAGES, USER_REQUIREMENT from metagpt.exp_pool import exp_cache from metagpt.exp_pool.context_builders import RoleZeroContextBuilder from metagpt.exp_pool.serializers import RoleZeroSerializer @@ -602,10 +602,10 @@ class RoleZero(Role): memories = self.rc.memory.get(k) - if not self._should_use_longterm_memory(k=k, k_memories=memories): + if not self._should_use_longterm_memory(k=k): return memories - query = self._build_longterm_memory_query(memories) + query = self._build_longterm_memory_query() related_memories = self.longterm_memory.fetch(query) logger.info(f"Fetched {len(related_memories)} long-term memories.") @@ -625,19 +625,17 @@ class RoleZero(Role): self._transfer_to_longterm_memory() - def _should_use_longterm_memory(self, k: int = None, k_memories: list[Message] = None) -> bool: + def _should_use_longterm_memory(self, k: int = None) -> bool: """Determines if long-term memory should be used. Long-term memory is used if: - k is not 0. - - k_memories is None or k_memories is not empty, and the last message is a user message. - Long-term memory usage is enabled. - The count of recent memories is greater than self.memory_k. """ conds = [ k != 0, - k_memories is None or self._is_last_message_from_user(k_memories), self.enable_longterm_memory, self.rc.memory.count() > self.memory_k, ] @@ -662,16 +660,19 @@ class RoleZero(Role): return LongTermMemoryItem(user_message=user_message, ai_message=message) - def _is_last_message_from_user(self, memories: list[Message]) -> bool: - return bool(memories and memories[-1].is_user_message()) - def _is_first_message_from_ai(self, memories: list[Message]) -> bool: return bool(memories and memories[0].is_ai_message()) - def _build_longterm_memory_query(self, memories: list[Message]) -> str: + def _build_longterm_memory_query(self) -> str: """Build the content used to query related long-term memory. Default is to get the most recent user message, or an empty string if none is found. """ + message = self._get_the_last_user_message() - return next((m.content for m in reversed(memories) if m.is_real_user_message()), "") + return message.content if message else "" + + def _get_the_last_user_message(self) -> Message: + values = self.rc.memory.index.get(USER_REQUIREMENT, []) + + return values[-1] if values else None diff --git a/metagpt/schema.py b/metagpt/schema.py index 8481bccf3..9352664e2 100644 --- a/metagpt/schema.py +++ b/metagpt/schema.py @@ -415,9 +415,6 @@ class Message(BaseModel): def is_ai_message(self) -> bool: return self.role == "assistant" - def is_real_user_message(self) -> bool: - return self.is_user_message() and "UserRequirement" in self.cause_by - class UserMessage(Message): """便于支持OpenAI的消息 diff --git a/tests/metagpt/roles/di/test_role_zero.py b/tests/metagpt/roles/di/test_role_zero.py index 964d456a7..0d427ce0f 100644 --- a/tests/metagpt/roles/di/test_role_zero.py +++ b/tests/metagpt/roles/di/test_role_zero.py @@ -101,7 +101,7 @@ class TestRoleZero: mock_role_zero.rc.memory.get.assert_called_once_with(really_k) if k != 0: - mock_role_zero._should_use_longterm_memory.assert_called_once_with(k=really_k, k_memories=memories) + mock_role_zero._should_use_longterm_memory.assert_called_once_with(k=really_k) if should_use_ltm: mock_role_zero.longterm_memory.fetch.assert_called_once_with("user") @@ -174,18 +174,6 @@ class TestRoleZero: mock_role_zero.rc.memory.get_by_position.assert_any_call(-(mock_role_zero.memory_k + 1)) mock_role_zero.rc.memory.get_by_position.assert_any_call(-(mock_role_zero.memory_k + 2)) - @pytest.mark.parametrize( - "memories,expected", - [ - ([UserMessage(content="user")], True), - ([AIMessage(content="ai")], False), - ([], False), - ], - ) - def test_is_last_message_from_user(self, mock_role_zero: RoleZero, memories, expected): - result = mock_role_zero._is_last_message_from_user(memories) - assert result == expected - @pytest.mark.parametrize( "memories,expected", [