diff --git a/metagpt/memory/role_zero_memory.py b/metagpt/memory/role_zero_memory.py index 570d0cc41..7f777df56 100644 --- a/metagpt/memory/role_zero_memory.py +++ b/metagpt/memory/role_zero_memory.py @@ -57,11 +57,11 @@ class RoleZeroLongTermMemory(BaseModel): if not query: return [] - nodes: list[NodeWithScore] = self.rag_engine.retrieve(query) + nodes = self.rag_engine.retrieve(query) + items = self._get_items_from_nodes(nodes) memories = [] - for node in nodes: - item: LongTermMemoryItem = node.metadata["obj"] + for item in items: memories.append(item.user_message) memories.append(item.ai_message) @@ -78,3 +78,11 @@ class RoleZeroLongTermMemory(BaseModel): return self.rag_engine.add_objs([item]) + + def _get_items_from_nodes(self, nodes: list["NodeWithScore"]) -> list[LongTermMemoryItem]: + """Get items from nodes and arrange them in order of their `created_at`.""" + + items: list[LongTermMemoryItem] = [node.metadata["obj"] for node in nodes] + items.sort(key=lambda item: item.created_at) + + return items diff --git a/metagpt/roles/di/role_zero.py b/metagpt/roles/di/role_zero.py index 34d1812fc..5818bc1f3 100644 --- a/metagpt/roles/di/role_zero.py +++ b/metagpt/roles/di/role_zero.py @@ -285,12 +285,12 @@ class RoleZero(Role): self._add_memory(AIMessage(content=self.command_rsp)) if not ok: error_msg = commands - self._add_memory(UserMessage(content=error_msg)) + self._add_memory(UserMessage(content=error_msg, cause_by=RunCommand)) return error_msg logger.info(f"Commands: \n{commands}") outputs = await self._run_commands(commands) logger.info(f"Commands outputs: \n{outputs}") - self._add_memory(UserMessage(content=outputs)) + self._add_memory(UserMessage(content=outputs, cause_by=RunCommand)) return AIMessage( content=f"I have finished the task, please mark my task as finished. Outputs: {outputs}", @@ -605,7 +605,8 @@ class RoleZero(Role): if not self._should_use_longterm_memory(k=k, k_memories=memories): return memories - related_memories = self.longterm_memory.fetch(memories[-1].content) + query = self._build_longterm_memory_query(memories) + related_memories = self.longterm_memory.fetch(query) logger.info(f"Fetched {len(related_memories)} long-term memories.") # Keep user and AI messages are paired. @@ -666,3 +667,11 @@ class RoleZero(Role): def _is_first_message_from_ai(self, memories: list[Message]) -> bool: return bool(memories and memories[0].is_ai_message()) + + def _build_longterm_memory_query(self, memories: list[Message]) -> str: + """Build the content used to query related long-term memory. + + Default is to get the most recent user message, or an empty string if none is found. + """ + + return next((m.content for m in reversed(memories) if m.is_real_user_message()), "") diff --git a/metagpt/schema.py b/metagpt/schema.py index 63a80f62a..8481bccf3 100644 --- a/metagpt/schema.py +++ b/metagpt/schema.py @@ -18,6 +18,7 @@ from __future__ import annotations import asyncio import json import os.path +import time import uuid from abc import ABC from asyncio import Queue, QueueEmpty, wait_for @@ -408,12 +409,15 @@ class Message(BaseModel): dynamic_class = create_model(class_name, **{key: (value.__class__, ...) for key, value in kvs.items()}) return dynamic_class.model_validate(kvs) - def is_user_message(self): + def is_user_message(self) -> bool: return self.role == "user" - def is_ai_message(self): + def is_ai_message(self) -> bool: return self.role == "assistant" + def is_real_user_message(self) -> bool: + return self.is_user_message() and "UserRequirement" in self.cause_by + class UserMessage(Message): """便于支持OpenAI的消息 @@ -966,6 +970,7 @@ class BaseEnum(Enum): class LongTermMemoryItem(BaseModel): user_message: Message ai_message: Message + created_at: Optional[float] = Field(default_factory=time.time) def rag_key(self) -> str: return self.user_message.content diff --git a/tests/metagpt/memory/test_role_zero_memory.py b/tests/metagpt/memory/test_role_zero_memory.py index 8e2532bfc..1c6fb785e 100644 --- a/tests/metagpt/memory/test_role_zero_memory.py +++ b/tests/metagpt/memory/test_role_zero_memory.py @@ -1,3 +1,5 @@ +from datetime import datetime, timedelta + import pytest from metagpt.memory.role_zero_memory import RoleZeroLongTermMemory @@ -46,3 +48,36 @@ class TestRoleZeroLongTermMemory: item = LongTermMemoryItem(user_message=UserMessage(content="user"), ai_message=AIMessage(content="ai")) mock_memory.add(item) mock_memory.rag_engine.add_objs.assert_called_once_with([item]) + + def test_get_items_from_nodes(self, mocker, mock_memory: RoleZeroLongTermMemory): + mock_node1 = mocker.Mock() + mock_node2 = mocker.Mock() + mock_node3 = mocker.Mock() + + now = datetime.now() + item1 = LongTermMemoryItem( + user_message=UserMessage(content="user1"), ai_message=AIMessage(content="ai1"), created_at=now.timestamp() + ) + item2 = LongTermMemoryItem( + user_message=UserMessage(content="user2"), + ai_message=AIMessage(content="ai2"), + created_at=(now - timedelta(minutes=5)).timestamp(), + ) + item3 = LongTermMemoryItem( + user_message=UserMessage(content="user3"), + ai_message=AIMessage(content="ai3"), + created_at=(now + timedelta(minutes=5)).timestamp(), + ) + + mock_node1.metadata = {"obj": item1} + mock_node2.metadata = {"obj": item2} + mock_node3.metadata = {"obj": item3} + + result = mock_memory._get_items_from_nodes([mock_node1, mock_node2, mock_node3]) + + assert len(result) == 3 + assert result[0] == item2 + assert result[1] == item1 + assert result[2] == item3 + assert [item.user_message.content for item in result] == ["user2", "user1", "user3"] + assert [item.ai_message.content for item in result] == ["ai2", "ai1", "ai3"] diff --git a/tests/metagpt/roles/di/test_role_zero.py b/tests/metagpt/roles/di/test_role_zero.py index d4d4a46da..964d456a7 100644 --- a/tests/metagpt/roles/di/test_role_zero.py +++ b/tests/metagpt/roles/di/test_role_zero.py @@ -46,7 +46,13 @@ class TestRoleZero: [AIMessage(content="ai1"), UserMessage(content="user"), AIMessage(content="ai2")], [Message(content="related")], True, - [Message(content="related"), UserMessage(content="user"), AIMessage(content="ai2")], + [ + Message(content="related"), + UserMessage(content="user"), + AIMessage(content="ai1"), + UserMessage(content="user"), + AIMessage(content="ai2"), + ], ), ( None, @@ -79,6 +85,7 @@ class TestRoleZero: ): mock_role_zero.memory_k = 2 mock_role_zero.rc.memory.get = mocker.Mock(return_value=memories) + mock_role_zero.rc.memory.get_by_position = mocker.Mock(return_value=UserMessage(content="user")) mock_role_zero._should_use_longterm_memory = mocker.Mock(return_value=should_use_ltm) mock_role_zero.longterm_memory.fetch = mocker.Mock(return_value=related_memories) mock_role_zero._is_first_message_from_ai = mocker.Mock(return_value=is_first_from_ai) @@ -97,7 +104,7 @@ class TestRoleZero: mock_role_zero._should_use_longterm_memory.assert_called_once_with(k=really_k, k_memories=memories) if should_use_ltm: - mock_role_zero.longterm_memory.fetch.assert_called_once_with(memories[-1].content) + mock_role_zero.longterm_memory.fetch.assert_called_once_with("user") mock_role_zero._is_first_message_from_ai.assert_called_once_with(memories) def test_add_memory(self, mocker, mock_role_zero: RoleZero): @@ -190,3 +197,24 @@ class TestRoleZero: def test_is_first_message_from_ai(self, mock_role_zero: RoleZero, memories, expected): result = mock_role_zero._is_first_message_from_ai(memories) assert result == expected + + @pytest.mark.parametrize( + "memories,expected", + [ + ([UserMessage(content="user1"), AIMessage(content="ai"), UserMessage(content="user2")], "user2"), + ( + [ + UserMessage(content="user1", cause_by="test"), + AIMessage(content="ai"), + UserMessage(content="user2", cause_by="test"), + ], + "", + ), + ([AIMessage(content="ai1"), AIMessage(content="ai2")], ""), + ([UserMessage(content="user")], "user"), + ([], ""), + ], + ) + def test_build_longterm_memory_query(self, mock_role_zero: RoleZero, memories, expected): + result = mock_role_zero._build_longterm_memory_query(memories) + assert result == expected