This commit is contained in:
seehi 2024-09-10 16:44:22 +08:00
parent 31dbff5474
commit 5d21d255e4
5 changed files with 95 additions and 10 deletions

View file

@ -1,3 +1,5 @@
from datetime import datetime, timedelta
import pytest
from metagpt.memory.role_zero_memory import RoleZeroLongTermMemory
@ -46,3 +48,36 @@ class TestRoleZeroLongTermMemory:
item = LongTermMemoryItem(user_message=UserMessage(content="user"), ai_message=AIMessage(content="ai"))
mock_memory.add(item)
mock_memory.rag_engine.add_objs.assert_called_once_with([item])
def test_get_items_from_nodes(self, mocker, mock_memory: RoleZeroLongTermMemory):
mock_node1 = mocker.Mock()
mock_node2 = mocker.Mock()
mock_node3 = mocker.Mock()
now = datetime.now()
item1 = LongTermMemoryItem(
user_message=UserMessage(content="user1"), ai_message=AIMessage(content="ai1"), created_at=now.timestamp()
)
item2 = LongTermMemoryItem(
user_message=UserMessage(content="user2"),
ai_message=AIMessage(content="ai2"),
created_at=(now - timedelta(minutes=5)).timestamp(),
)
item3 = LongTermMemoryItem(
user_message=UserMessage(content="user3"),
ai_message=AIMessage(content="ai3"),
created_at=(now + timedelta(minutes=5)).timestamp(),
)
mock_node1.metadata = {"obj": item1}
mock_node2.metadata = {"obj": item2}
mock_node3.metadata = {"obj": item3}
result = mock_memory._get_items_from_nodes([mock_node1, mock_node2, mock_node3])
assert len(result) == 3
assert result[0] == item2
assert result[1] == item1
assert result[2] == item3
assert [item.user_message.content for item in result] == ["user2", "user1", "user3"]
assert [item.ai_message.content for item in result] == ["ai2", "ai1", "ai3"]

View file

@ -46,7 +46,13 @@ class TestRoleZero:
[AIMessage(content="ai1"), UserMessage(content="user"), AIMessage(content="ai2")],
[Message(content="related")],
True,
[Message(content="related"), UserMessage(content="user"), AIMessage(content="ai2")],
[
Message(content="related"),
UserMessage(content="user"),
AIMessage(content="ai1"),
UserMessage(content="user"),
AIMessage(content="ai2"),
],
),
(
None,
@ -79,6 +85,7 @@ class TestRoleZero:
):
mock_role_zero.memory_k = 2
mock_role_zero.rc.memory.get = mocker.Mock(return_value=memories)
mock_role_zero.rc.memory.get_by_position = mocker.Mock(return_value=UserMessage(content="user"))
mock_role_zero._should_use_longterm_memory = mocker.Mock(return_value=should_use_ltm)
mock_role_zero.longterm_memory.fetch = mocker.Mock(return_value=related_memories)
mock_role_zero._is_first_message_from_ai = mocker.Mock(return_value=is_first_from_ai)
@ -97,7 +104,7 @@ class TestRoleZero:
mock_role_zero._should_use_longterm_memory.assert_called_once_with(k=really_k, k_memories=memories)
if should_use_ltm:
mock_role_zero.longterm_memory.fetch.assert_called_once_with(memories[-1].content)
mock_role_zero.longterm_memory.fetch.assert_called_once_with("user")
mock_role_zero._is_first_message_from_ai.assert_called_once_with(memories)
def test_add_memory(self, mocker, mock_role_zero: RoleZero):
@ -190,3 +197,24 @@ class TestRoleZero:
def test_is_first_message_from_ai(self, mock_role_zero: RoleZero, memories, expected):
result = mock_role_zero._is_first_message_from_ai(memories)
assert result == expected
@pytest.mark.parametrize(
"memories,expected",
[
([UserMessage(content="user1"), AIMessage(content="ai"), UserMessage(content="user2")], "user2"),
(
[
UserMessage(content="user1", cause_by="test"),
AIMessage(content="ai"),
UserMessage(content="user2", cause_by="test"),
],
"",
),
([AIMessage(content="ai1"), AIMessage(content="ai2")], ""),
([UserMessage(content="user")], "user"),
([], ""),
],
)
def test_build_longterm_memory_query(self, mock_role_zero: RoleZero, memories, expected):
result = mock_role_zero._build_longterm_memory_query(memories)
assert result == expected