use memory index to get the last user message

This commit is contained in:
seehi 2024-09-10 20:05:56 +08:00
parent 5d21d255e4
commit d077cd0b2f
4 changed files with 16 additions and 27 deletions

View file

@ -157,3 +157,6 @@ SWE_SETUP_PATH = get_metagpt_package_root() / "metagpt/tools/swe_agent_commands/
# experience pool
EXPERIENCE_MASK = "<experience>"
# Used to identify user requirements in the memory index.
USER_REQUIREMENT = "metagpt.actions.add_requirement.UserRequirement"

View file

@ -13,7 +13,7 @@ from metagpt.actions import Action, UserRequirement
from metagpt.actions.analyze_requirements import AnalyzeRequirementsRestrictions
from metagpt.actions.di.run_command import RunCommand
from metagpt.actions.search_enhanced_qa import SearchEnhancedQA
from metagpt.const import IMAGES
from metagpt.const import IMAGES, USER_REQUIREMENT
from metagpt.exp_pool import exp_cache
from metagpt.exp_pool.context_builders import RoleZeroContextBuilder
from metagpt.exp_pool.serializers import RoleZeroSerializer
@ -602,10 +602,10 @@ class RoleZero(Role):
memories = self.rc.memory.get(k)
if not self._should_use_longterm_memory(k=k, k_memories=memories):
if not self._should_use_longterm_memory(k=k):
return memories
query = self._build_longterm_memory_query(memories)
query = self._build_longterm_memory_query()
related_memories = self.longterm_memory.fetch(query)
logger.info(f"Fetched {len(related_memories)} long-term memories.")
@ -625,19 +625,17 @@ class RoleZero(Role):
self._transfer_to_longterm_memory()
def _should_use_longterm_memory(self, k: int = None, k_memories: list[Message] = None) -> bool:
def _should_use_longterm_memory(self, k: int = None) -> bool:
"""Determines if long-term memory should be used.
Long-term memory is used if:
- k is not 0.
- k_memories is None or k_memories is not empty, and the last message is a user message.
- Long-term memory usage is enabled.
- The count of recent memories is greater than self.memory_k.
"""
conds = [
k != 0,
k_memories is None or self._is_last_message_from_user(k_memories),
self.enable_longterm_memory,
self.rc.memory.count() > self.memory_k,
]
@ -662,16 +660,19 @@ class RoleZero(Role):
return LongTermMemoryItem(user_message=user_message, ai_message=message)
def _is_last_message_from_user(self, memories: list[Message]) -> bool:
return bool(memories and memories[-1].is_user_message())
def _is_first_message_from_ai(self, memories: list[Message]) -> bool:
return bool(memories and memories[0].is_ai_message())
def _build_longterm_memory_query(self, memories: list[Message]) -> str:
def _build_longterm_memory_query(self) -> str:
"""Build the content used to query related long-term memory.
Default is to get the most recent user message, or an empty string if none is found.
"""
message = self._get_the_last_user_message()
return next((m.content for m in reversed(memories) if m.is_real_user_message()), "")
return message.content if message else ""
def _get_the_last_user_message(self) -> Message:
values = self.rc.memory.index.get(USER_REQUIREMENT, [])
return values[-1] if values else None

View file

@ -415,9 +415,6 @@ class Message(BaseModel):
def is_ai_message(self) -> bool:
return self.role == "assistant"
def is_real_user_message(self) -> bool:
return self.is_user_message() and "UserRequirement" in self.cause_by
class UserMessage(Message):
"""便于支持OpenAI的消息

View file

@ -101,7 +101,7 @@ class TestRoleZero:
mock_role_zero.rc.memory.get.assert_called_once_with(really_k)
if k != 0:
mock_role_zero._should_use_longterm_memory.assert_called_once_with(k=really_k, k_memories=memories)
mock_role_zero._should_use_longterm_memory.assert_called_once_with(k=really_k)
if should_use_ltm:
mock_role_zero.longterm_memory.fetch.assert_called_once_with("user")
@ -174,18 +174,6 @@ class TestRoleZero:
mock_role_zero.rc.memory.get_by_position.assert_any_call(-(mock_role_zero.memory_k + 1))
mock_role_zero.rc.memory.get_by_position.assert_any_call(-(mock_role_zero.memory_k + 2))
@pytest.mark.parametrize(
"memories,expected",
[
([UserMessage(content="user")], True),
([AIMessage(content="ai")], False),
([], False),
],
)
def test_is_last_message_from_user(self, mock_role_zero: RoleZero, memories, expected):
result = mock_role_zero._is_last_message_from_user(memories)
assert result == expected
@pytest.mark.parametrize(
"memories,expected",
[