mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-04-29 19:06:23 +02:00
Change the way of taking over memory
This commit is contained in:
parent
d077cd0b2f
commit
c0d5c031f8
8 changed files with 270 additions and 407 deletions
|
|
@ -2,8 +2,10 @@ from datetime import datetime, timedelta
|
|||
|
||||
import pytest
|
||||
|
||||
from metagpt.actions import UserRequirement
|
||||
from metagpt.const import TEAMLEADER_NAME
|
||||
from metagpt.memory.role_zero_memory import RoleZeroLongTermMemory
|
||||
from metagpt.schema import AIMessage, LongTermMemoryItem, UserMessage
|
||||
from metagpt.schema import AIMessage, LongTermMemoryItem, Message, UserMessage
|
||||
|
||||
|
||||
class TestRoleZeroLongTermMemory:
|
||||
|
|
@ -13,71 +15,150 @@ class TestRoleZeroLongTermMemory:
|
|||
memory._resolve_rag_engine = mocker.Mock()
|
||||
return memory
|
||||
|
||||
def test_fetch_empty_query(self, mock_memory: RoleZeroLongTermMemory):
|
||||
assert mock_memory.fetch("") == []
|
||||
def test_add(self, mocker, mock_memory: RoleZeroLongTermMemory):
|
||||
mock_memory._should_use_longterm_memory_for_add = mocker.Mock(return_value=True)
|
||||
mock_memory._transfer_to_longterm_memory = mocker.Mock()
|
||||
|
||||
def test_fetch(self, mocker, mock_memory: RoleZeroLongTermMemory):
|
||||
mock_node1 = mocker.Mock()
|
||||
mock_node2 = mocker.Mock()
|
||||
mock_node1.metadata = {
|
||||
"obj": LongTermMemoryItem(user_message=UserMessage(content="user1"), ai_message=AIMessage(content="ai1"))
|
||||
}
|
||||
mock_node2.metadata = {
|
||||
"obj": LongTermMemoryItem(user_message=UserMessage(content="user2"), ai_message=AIMessage(content="ai2"))
|
||||
}
|
||||
message = UserMessage(content="test")
|
||||
mock_memory.add(message)
|
||||
|
||||
mock_memory.rag_engine.retrieve.return_value = [mock_node1, mock_node2]
|
||||
assert mock_memory.storage[-1] == message
|
||||
mock_memory._transfer_to_longterm_memory.assert_called_once()
|
||||
|
||||
result = mock_memory.fetch("test query")
|
||||
def test_get(self, mocker, mock_memory: RoleZeroLongTermMemory):
|
||||
mock_memory._should_use_longterm_memory_for_get = mocker.Mock(return_value=True)
|
||||
mock_memory._build_longterm_memory_query = mocker.Mock(return_value="query")
|
||||
mock_memory._fetch_longterm_memories = mocker.Mock(return_value=[Message(content="long-term")])
|
||||
|
||||
assert len(result) == 4
|
||||
assert isinstance(result[0], UserMessage)
|
||||
assert isinstance(result[1], AIMessage)
|
||||
assert result[0].content == "user1"
|
||||
assert result[1].content == "ai1"
|
||||
assert result[2].content == "user2"
|
||||
assert result[3].content == "ai2"
|
||||
mock_memory.storage = [Message(content="short-term")]
|
||||
|
||||
mock_memory.rag_engine.retrieve.assert_called_once_with("test query")
|
||||
result = mock_memory.get()
|
||||
|
||||
def test_add_empty_item(self, mock_memory: RoleZeroLongTermMemory):
|
||||
mock_memory.add(None)
|
||||
mock_memory.rag_engine.add_objs.assert_not_called()
|
||||
assert len(result) == 2
|
||||
assert result[0].content == "long-term"
|
||||
assert result[1].content == "short-term"
|
||||
|
||||
def test_should_use_longterm_memory_for_add(self, mocker, mock_memory: RoleZeroLongTermMemory):
|
||||
mocker.patch.object(mock_memory, "storage", [None] * 201)
|
||||
|
||||
mock_memory.memory_k = 200
|
||||
|
||||
assert mock_memory._should_use_longterm_memory_for_add() == True
|
||||
|
||||
mocker.patch.object(mock_memory, "storage", [None] * 199)
|
||||
assert mock_memory._should_use_longterm_memory_for_add() == False
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"k,is_last_from_user,count,expected",
|
||||
[
|
||||
(0, True, 201, False),
|
||||
(1, False, 201, False),
|
||||
(1, True, 199, False),
|
||||
(1, True, 201, True),
|
||||
],
|
||||
)
|
||||
def test_should_use_longterm_memory_for_get(
|
||||
self, mocker, mock_memory: RoleZeroLongTermMemory, k, is_last_from_user, count, expected
|
||||
):
|
||||
mock_memory._is_last_message_from_user_requirement = mocker.Mock(return_value=is_last_from_user)
|
||||
mocker.patch.object(mock_memory, "storage", [None] * count)
|
||||
mock_memory.memory_k = 200
|
||||
|
||||
assert mock_memory._should_use_longterm_memory_for_get(k) == expected
|
||||
|
||||
def test_transfer_to_longterm_memory(self, mocker, mock_memory: RoleZeroLongTermMemory):
|
||||
mock_item = mocker.Mock()
|
||||
mock_memory._get_longterm_memory_item = mocker.Mock(return_value=mock_item)
|
||||
mock_memory._add_to_longterm_memory = mocker.Mock()
|
||||
|
||||
mock_memory._transfer_to_longterm_memory()
|
||||
|
||||
mock_memory._add_to_longterm_memory.assert_called_once_with(mock_item)
|
||||
|
||||
def test_get_longterm_memory_item(self, mocker, mock_memory: RoleZeroLongTermMemory):
|
||||
mock_message = Message(content="test")
|
||||
mock_memory.storage = [mock_message, mock_message]
|
||||
mock_memory.memory_k = 1
|
||||
|
||||
result = mock_memory._get_longterm_memory_item()
|
||||
|
||||
assert isinstance(result, LongTermMemoryItem)
|
||||
assert result.message == mock_message
|
||||
|
||||
def test_add_to_longterm_memory(self, mock_memory: RoleZeroLongTermMemory):
|
||||
item = LongTermMemoryItem(message=Message(content="test"))
|
||||
mock_memory._add_to_longterm_memory(item)
|
||||
|
||||
def test_add_item(self, mock_memory: RoleZeroLongTermMemory):
|
||||
item = LongTermMemoryItem(user_message=UserMessage(content="user"), ai_message=AIMessage(content="ai"))
|
||||
mock_memory.add(item)
|
||||
mock_memory.rag_engine.add_objs.assert_called_once_with([item])
|
||||
|
||||
def test_build_longterm_memory_query(self, mocker, mock_memory: RoleZeroLongTermMemory):
|
||||
mock_message = Message(content="query")
|
||||
mock_memory._get_the_last_message = mocker.Mock(return_value=mock_message)
|
||||
|
||||
result = mock_memory._build_longterm_memory_query()
|
||||
|
||||
assert result == "query"
|
||||
|
||||
def test_get_the_last_message(self, mock_memory: RoleZeroLongTermMemory):
|
||||
mock_memory.storage = [Message(content="1"), Message(content="2")]
|
||||
|
||||
result = mock_memory._get_the_last_message()
|
||||
|
||||
assert result.content == "2"
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"message,expected",
|
||||
[
|
||||
(UserMessage(content="test", cause_by=UserRequirement), True),
|
||||
(UserMessage(content="test", sent_from=TEAMLEADER_NAME), True),
|
||||
(UserMessage(content="test"), True),
|
||||
(AIMessage(content="test"), False),
|
||||
(None, False),
|
||||
],
|
||||
)
|
||||
def test_is_last_message_from_user_requirement(
|
||||
self, mocker, mock_memory: RoleZeroLongTermMemory, message, expected
|
||||
):
|
||||
mock_memory._get_the_last_message = mocker.Mock(return_value=message)
|
||||
|
||||
assert mock_memory._is_last_message_from_user_requirement() == expected
|
||||
|
||||
def test_fetch_longterm_memories(self, mocker, mock_memory: RoleZeroLongTermMemory):
|
||||
mock_nodes = [mocker.Mock(), mocker.Mock()]
|
||||
mock_memory.rag_engine.retrieve = mocker.Mock(return_value=mock_nodes)
|
||||
mock_items = [
|
||||
LongTermMemoryItem(message=UserMessage(content="user1")),
|
||||
LongTermMemoryItem(message=AIMessage(content="ai1")),
|
||||
]
|
||||
mock_memory._get_items_from_nodes = mocker.Mock(return_value=mock_items)
|
||||
|
||||
result = mock_memory._fetch_longterm_memories("query")
|
||||
|
||||
assert len(result) == 2
|
||||
assert result[0].content == "user1"
|
||||
assert result[1].content == "ai1"
|
||||
|
||||
def test_get_items_from_nodes(self, mocker, mock_memory: RoleZeroLongTermMemory):
|
||||
mock_node1 = mocker.Mock()
|
||||
mock_node2 = mocker.Mock()
|
||||
mock_node3 = mocker.Mock()
|
||||
|
||||
now = datetime.now()
|
||||
item1 = LongTermMemoryItem(
|
||||
user_message=UserMessage(content="user1"), ai_message=AIMessage(content="ai1"), created_at=now.timestamp()
|
||||
)
|
||||
item2 = LongTermMemoryItem(
|
||||
user_message=UserMessage(content="user2"),
|
||||
ai_message=AIMessage(content="ai2"),
|
||||
created_at=(now - timedelta(minutes=5)).timestamp(),
|
||||
)
|
||||
item3 = LongTermMemoryItem(
|
||||
user_message=UserMessage(content="user3"),
|
||||
ai_message=AIMessage(content="ai3"),
|
||||
created_at=(now + timedelta(minutes=5)).timestamp(),
|
||||
)
|
||||
mock_nodes = [
|
||||
mocker.Mock(
|
||||
metadata={
|
||||
"obj": LongTermMemoryItem(
|
||||
message=Message(content="2"), created_at=(now - timedelta(minutes=1)).timestamp()
|
||||
)
|
||||
}
|
||||
),
|
||||
mocker.Mock(
|
||||
metadata={
|
||||
"obj": LongTermMemoryItem(
|
||||
message=Message(content="1"), created_at=(now - timedelta(minutes=2)).timestamp()
|
||||
)
|
||||
}
|
||||
),
|
||||
mocker.Mock(metadata={"obj": LongTermMemoryItem(message=Message(content="3"), created_at=now.timestamp())}),
|
||||
]
|
||||
|
||||
mock_node1.metadata = {"obj": item1}
|
||||
mock_node2.metadata = {"obj": item2}
|
||||
mock_node3.metadata = {"obj": item3}
|
||||
|
||||
result = mock_memory._get_items_from_nodes([mock_node1, mock_node2, mock_node3])
|
||||
result = mock_memory._get_items_from_nodes(mock_nodes)
|
||||
|
||||
assert len(result) == 3
|
||||
assert result[0] == item2
|
||||
assert result[1] == item1
|
||||
assert result[2] == item3
|
||||
assert [item.user_message.content for item in result] == ["user2", "user1", "user3"]
|
||||
assert [item.ai_message.content for item in result] == ["ai2", "ai1", "ai3"]
|
||||
assert [item.message.content for item in result] == ["1", "2", "3"]
|
||||
|
|
|
|||
|
|
@ -1,208 +0,0 @@
|
|||
import pytest
|
||||
|
||||
from metagpt.roles.di.role_zero import RoleZero
|
||||
from metagpt.schema import AIMessage, LongTermMemoryItem, Message, UserMessage
|
||||
|
||||
|
||||
class TestRoleZero:
|
||||
@pytest.fixture
|
||||
def mock_role_zero(self, mocker):
|
||||
role_zero = RoleZero()
|
||||
role_zero.rc.memory = mocker.Mock()
|
||||
role_zero.longterm_memory = mocker.Mock()
|
||||
return role_zero
|
||||
|
||||
def test_get_all_memories(self, mocker, mock_role_zero: RoleZero):
|
||||
mock_memories = [Message(content="test1"), Message(content="test2")]
|
||||
mock_role_zero._fetch_memories = mocker.Mock(return_value=mock_memories)
|
||||
|
||||
result = mock_role_zero._get_all_memories()
|
||||
|
||||
assert result == mock_memories
|
||||
mock_role_zero._fetch_memories.assert_called_once_with(k=0)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"k,should_use_ltm,memories,related_memories,is_first_from_ai,expected",
|
||||
[
|
||||
(
|
||||
None,
|
||||
False,
|
||||
[UserMessage(content="user"), AIMessage(content="ai")],
|
||||
[],
|
||||
False,
|
||||
[UserMessage(content="user"), AIMessage(content="ai")],
|
||||
),
|
||||
(
|
||||
1,
|
||||
True,
|
||||
[UserMessage(content="user")],
|
||||
[Message(content="related")],
|
||||
False,
|
||||
[Message(content="related"), UserMessage(content="user")],
|
||||
),
|
||||
(
|
||||
None,
|
||||
True,
|
||||
[AIMessage(content="ai1"), UserMessage(content="user"), AIMessage(content="ai2")],
|
||||
[Message(content="related")],
|
||||
True,
|
||||
[
|
||||
Message(content="related"),
|
||||
UserMessage(content="user"),
|
||||
AIMessage(content="ai1"),
|
||||
UserMessage(content="user"),
|
||||
AIMessage(content="ai2"),
|
||||
],
|
||||
),
|
||||
(
|
||||
None,
|
||||
True,
|
||||
[UserMessage(content="user"), AIMessage(content="ai")],
|
||||
[Message(content="related")],
|
||||
False,
|
||||
[Message(content="related"), UserMessage(content="user"), AIMessage(content="ai")],
|
||||
),
|
||||
(
|
||||
0,
|
||||
False,
|
||||
[UserMessage(content="user"), AIMessage(content="ai")],
|
||||
[],
|
||||
False,
|
||||
[UserMessage(content="user"), AIMessage(content="ai")],
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_fetch_memories(
|
||||
self,
|
||||
mocker,
|
||||
mock_role_zero: RoleZero,
|
||||
k,
|
||||
should_use_ltm,
|
||||
memories,
|
||||
related_memories,
|
||||
is_first_from_ai,
|
||||
expected,
|
||||
):
|
||||
mock_role_zero.memory_k = 2
|
||||
mock_role_zero.rc.memory.get = mocker.Mock(return_value=memories)
|
||||
mock_role_zero.rc.memory.get_by_position = mocker.Mock(return_value=UserMessage(content="user"))
|
||||
mock_role_zero._should_use_longterm_memory = mocker.Mock(return_value=should_use_ltm)
|
||||
mock_role_zero.longterm_memory.fetch = mocker.Mock(return_value=related_memories)
|
||||
mock_role_zero._is_first_message_from_ai = mocker.Mock(return_value=is_first_from_ai)
|
||||
|
||||
result = mock_role_zero._fetch_memories(k)
|
||||
|
||||
assert len(result) == len(expected)
|
||||
for actual, expected_msg in zip(result, expected):
|
||||
assert actual.role == expected_msg.role
|
||||
assert actual.content == expected_msg.content
|
||||
|
||||
really_k = k if k is not None else mock_role_zero.memory_k
|
||||
mock_role_zero.rc.memory.get.assert_called_once_with(really_k)
|
||||
|
||||
if k != 0:
|
||||
mock_role_zero._should_use_longterm_memory.assert_called_once_with(k=really_k)
|
||||
|
||||
if should_use_ltm:
|
||||
mock_role_zero.longterm_memory.fetch.assert_called_once_with("user")
|
||||
mock_role_zero._is_first_message_from_ai.assert_called_once_with(memories)
|
||||
|
||||
def test_add_memory(self, mocker, mock_role_zero: RoleZero):
|
||||
message = AIMessage(content="ai")
|
||||
mock_role_zero.rc.memory.add = mocker.Mock()
|
||||
mock_role_zero._should_use_longterm_memory = mocker.Mock(return_value=True)
|
||||
mock_role_zero._transfer_to_longterm_memory = mocker.Mock()
|
||||
|
||||
mock_role_zero._add_memory(message)
|
||||
|
||||
mock_role_zero.rc.memory.add.assert_called_once_with(message)
|
||||
mock_role_zero._transfer_to_longterm_memory.assert_called_once()
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"k,enable_longterm_memory,memory_count,k_memories,expected",
|
||||
[
|
||||
(0, True, 30, None, False), # k is 0
|
||||
(None, False, 30, None, False), # Long-term memory usage is disabled
|
||||
(None, True, 10, None, False), # Memory count is less than or equal to memory_k
|
||||
(None, True, 30, [], False), # k_memories is empty
|
||||
(None, True, 30, [AIMessage(content="ai")], False), # Last message in k_memories is not a user message
|
||||
(None, True, 30, [AIMessage(content="ai"), UserMessage(content="user")], True), # All conditions are met
|
||||
],
|
||||
)
|
||||
def test_should_use_longterm_memory(
|
||||
self, mocker, mock_role_zero: RoleZero, k, enable_longterm_memory, memory_count, k_memories, expected
|
||||
):
|
||||
mock_role_zero.enable_longterm_memory = enable_longterm_memory
|
||||
mock_role_zero.rc.memory.count = mocker.Mock(return_value=memory_count)
|
||||
mock_role_zero.memory_k = 20
|
||||
|
||||
result = mock_role_zero._should_use_longterm_memory(k, k_memories)
|
||||
|
||||
assert result == expected
|
||||
|
||||
def test_transfer_to_longterm_memory(self, mocker, mock_role_zero: RoleZero):
|
||||
mock_item = LongTermMemoryItem(user_message=UserMessage(content="user"), ai_message=AIMessage(content="ai"))
|
||||
mock_role_zero._get_longterm_memory_item = mocker.Mock(return_value=mock_item)
|
||||
mock_role_zero.longterm_memory = mocker.Mock()
|
||||
|
||||
mock_role_zero._transfer_to_longterm_memory()
|
||||
|
||||
mock_role_zero.longterm_memory.add.assert_called_once_with(mock_item)
|
||||
|
||||
def test_get_longterm_memory_item(self, mocker, mock_role_zero: RoleZero):
|
||||
mock_role_zero.memory_k = 2
|
||||
mock_messages = [
|
||||
UserMessage(content="user1"),
|
||||
AIMessage(content="ai1"),
|
||||
UserMessage(content="user2"),
|
||||
AIMessage(content="ai2"),
|
||||
UserMessage(content="user3"), # memory_k + 2
|
||||
AIMessage(content="ai3"), # memory_k + 1
|
||||
UserMessage(content="recent1"),
|
||||
AIMessage(content="recent2"),
|
||||
]
|
||||
|
||||
mock_role_zero.rc.memory.get_by_position = mocker.Mock(side_effect=lambda i: mock_messages[i])
|
||||
mock_role_zero.rc.memory.count = mocker.Mock(return_value=len(mock_messages))
|
||||
|
||||
result = mock_role_zero._get_longterm_memory_item()
|
||||
|
||||
assert isinstance(result, LongTermMemoryItem)
|
||||
assert result.user_message.content == "user3"
|
||||
assert result.ai_message.content == "ai3"
|
||||
|
||||
mock_role_zero.rc.memory.get_by_position.assert_any_call(-(mock_role_zero.memory_k + 1))
|
||||
mock_role_zero.rc.memory.get_by_position.assert_any_call(-(mock_role_zero.memory_k + 2))
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"memories,expected",
|
||||
[
|
||||
([AIMessage(content="ai")], True),
|
||||
([UserMessage(content="user")], False),
|
||||
([], False),
|
||||
],
|
||||
)
|
||||
def test_is_first_message_from_ai(self, mock_role_zero: RoleZero, memories, expected):
|
||||
result = mock_role_zero._is_first_message_from_ai(memories)
|
||||
assert result == expected
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"memories,expected",
|
||||
[
|
||||
([UserMessage(content="user1"), AIMessage(content="ai"), UserMessage(content="user2")], "user2"),
|
||||
(
|
||||
[
|
||||
UserMessage(content="user1", cause_by="test"),
|
||||
AIMessage(content="ai"),
|
||||
UserMessage(content="user2", cause_by="test"),
|
||||
],
|
||||
"",
|
||||
),
|
||||
([AIMessage(content="ai1"), AIMessage(content="ai2")], ""),
|
||||
([UserMessage(content="user")], "user"),
|
||||
([], ""),
|
||||
],
|
||||
)
|
||||
def test_build_longterm_memory_query(self, mock_role_zero: RoleZero, memories, expected):
|
||||
result = mock_role_zero._build_longterm_memory_query(memories)
|
||||
assert result == expected
|
||||
Loading…
Add table
Add a link
Reference in a new issue