Change the way memory is taken over

This commit is contained in:
seehi 2024-09-11 20:25:29 +08:00
parent d077cd0b2f
commit c0d5c031f8
8 changed files with 270 additions and 407 deletions

View file

@@ -158,5 +158,5 @@ SWE_SETUP_PATH = get_metagpt_package_root() / "metagpt/tools/swe_agent_commands/
# experience pool
EXPERIENCE_MASK = "<experience>"
# Used to identify user requirements in the memory index.
USER_REQUIREMENT = "metagpt.actions.add_requirement.UserRequirement"
# TeamLeader's name
TEAMLEADER_NAME = "Mike"

View file

@@ -8,7 +8,7 @@ from metagpt.actions import (
WriteTest,
)
from metagpt.actions.summarize_code import SummarizeCode
from metagpt.const import AGENT, IMAGES
from metagpt.const import AGENT, IMAGES, TEAMLEADER_NAME
from metagpt.environment.base_env import Environment
from metagpt.logs import get_human_input
from metagpt.roles import Architect, ProductManager, ProjectManager, Role
@@ -31,7 +31,7 @@ class MGXEnv(Environment, SerializationMixin):
"""let the team leader take over message publishing"""
message = self.attach_images(message) # for multi-modal message
tl = self.get_role("Mike") # TeamLeader's name is Mike
tl = self.get_role(TEAMLEADER_NAME) # TeamLeader's name is Mike
if user_defined_recipient:
# human user's direct chat message to a certain role

View file

@@ -1,8 +1,14 @@
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, Optional
from pydantic import BaseModel, Field
from pydantic import Field
from metagpt.actions import UserRequirement
from metagpt.const import TEAMLEADER_NAME
from metagpt.logs import logger
from metagpt.memory import Memory
from metagpt.schema import LongTermMemoryItem, Message
from metagpt.utils.common import any_to_str
from metagpt.utils.exceptions import handle_exception
if TYPE_CHECKING:
from llama_index.core.schema import NodeWithScore
@@ -10,9 +16,10 @@ if TYPE_CHECKING:
from metagpt.rag.engines import SimpleEngine
class RoleZeroLongTermMemory(BaseModel):
class RoleZeroLongTermMemory(Memory):
persist_path: str = Field(default=".role_memory_data", description="The directory to save data.")
collection_name: str = Field(default="role_zero", description="The name of the collection, such as the role name.")
memory_k: int = Field(default=200, description="The capacity of short-term memory.")
_rag_engine: Any = None
@@ -44,7 +51,104 @@ class RoleZeroLongTermMemory(BaseModel):
return rag_engine
def fetch(self, query: str) -> list[Message]:
def add(self, message: Message):
"""Add a new message and potentially transfer it to long-term memory."""
super().add(message)
if not self._should_use_longterm_memory_for_add():
return
self._transfer_to_longterm_memory()
def get(self, k=0) -> list[Message]:
"""Return recent memories and optionally combines them with related long-term memories."""
memories = super().get(k)
if not self._should_use_longterm_memory_for_get(k=k):
return memories
query = self._build_longterm_memory_query()
related_memories = self._fetch_longterm_memories(query)
logger.info(f"Fetched {len(related_memories)} long-term memories.")
final_memories = related_memories + memories
return final_memories
def _should_use_longterm_memory_for_add(self) -> bool:
"""Determines if long-term memory should be used for add."""
return self.count() > self.memory_k
def _should_use_longterm_memory_for_get(self, k: int) -> bool:
"""Determines if long-term memory should be used for get.
Long-term memory is used if:
- k is not 0.
- The last message is from user requirement.
- The count of recent memories is greater than self.memory_k.
"""
conds = [
k != 0,
self._is_last_message_from_user_requirement(),
self.count() > self.memory_k,
]
return all(conds)
def _transfer_to_longterm_memory(self):
item = self._get_longterm_memory_item()
self._add_to_longterm_memory(item)
@handle_exception
def _get_longterm_memory_item(self) -> Optional[LongTermMemoryItem]:
"""Retrieves the most recent message before the last k messages."""
index = -(self.memory_k + 1)
message = self.get_by_position(index)
return LongTermMemoryItem(message=message)
def _add_to_longterm_memory(self, item: LongTermMemoryItem):
"""Adds a long-term memory item to the RAG engine."""
if not item:
return
self.rag_engine.add_objs([item])
def _build_longterm_memory_query(self) -> str:
"""Build the content used to query related long-term memory.
Default is to get the most recent user message, or an empty string if none is found.
"""
message = self._get_the_last_message()
return message.content if message else ""
def _get_the_last_message(self) -> Optional[Message]:
if not self.count():
return None
return self.get_by_position(-1)
def _is_last_message_from_user_requirement(self) -> bool:
message = self._get_the_last_message()
if not message:
return False
is_user_message = message.is_user_message()
cause_by_user_requirement = message.cause_by == any_to_str(UserRequirement)
sent_from_team_leader = message.sent_from == TEAMLEADER_NAME
return is_user_message and (cause_by_user_requirement or sent_from_team_leader)
def _fetch_longterm_memories(self, query: str) -> list[Message]:
"""Fetches long-term memories based on a query.
Args:
@@ -59,26 +163,10 @@ class RoleZeroLongTermMemory(BaseModel):
nodes = self.rag_engine.retrieve(query)
items = self._get_items_from_nodes(nodes)
memories = []
for item in items:
memories.append(item.user_message)
memories.append(item.ai_message)
memories = [item.message for item in items]
return memories
def add(self, item: LongTermMemoryItem):
"""Adds a long-term memory item to the RAG engine.
Args:
item (LongTermMemoryItem): The memory item containing user and AI messages.
"""
if not item:
return
self.rag_engine.add_objs([item])
def _get_items_from_nodes(self, nodes: list["NodeWithScore"]) -> list[LongTermMemoryItem]:
"""Get items from nodes and arrange them in order of their `created_at`."""

View file

@@ -13,7 +13,7 @@ from metagpt.actions import Action, UserRequirement
from metagpt.actions.analyze_requirements import AnalyzeRequirementsRestrictions
from metagpt.actions.di.run_command import RunCommand
from metagpt.actions.search_enhanced_qa import SearchEnhancedQA
from metagpt.const import IMAGES, USER_REQUIREMENT
from metagpt.const import IMAGES
from metagpt.exp_pool import exp_cache
from metagpt.exp_pool.context_builders import RoleZeroContextBuilder
from metagpt.exp_pool.serializers import RoleZeroSerializer
@@ -34,7 +34,7 @@ from metagpt.prompts.di.role_zero import (
SYSTEM_PROMPT,
)
from metagpt.roles import Role
from metagpt.schema import AIMessage, LongTermMemoryItem, Message, UserMessage
from metagpt.schema import AIMessage, Message, UserMessage
from metagpt.strategy.experience_retriever import DummyExpRetriever, ExpRetriever
from metagpt.strategy.planner import Planner
from metagpt.tools.libs.browser import Browser
@@ -42,7 +42,6 @@ from metagpt.tools.libs.editor import Editor
from metagpt.tools.tool_recommend import BM25ToolRecommender, ToolRecommender
from metagpt.tools.tool_registry import register_tool
from metagpt.utils.common import CodeParser, any_to_str, extract_and_encode_images
from metagpt.utils.exceptions import handle_exception
from metagpt.utils.repair_llm_raw_output import (
RepairType,
repair_escape_error,
@@ -94,7 +93,6 @@ class RoleZero(Role):
commands: list[dict] = [] # commands to be executed
memory_k: int = 200 # number of memories (messages) to use as historical context
enable_longterm_memory: bool = True # whether to use longterm memory
longterm_memory: RoleZeroLongTermMemory = None
use_fixed_sop: bool = False
requirements_constraints: str = "" # the constraints in user requirements
use_summary: bool = True # whether to summarize at the end
@@ -176,8 +174,8 @@ class RoleZero(Role):
The role name will be used as the collection name.
"""
if self.enable_longterm_memory and not self.longterm_memory:
self.longterm_memory = RoleZeroLongTermMemory(collection_name=self.name.replace(" ", ""))
if self.enable_longterm_memory:
self.rc.memory = RoleZeroLongTermMemory(collection_name=self.name.replace(" ", ""), memory_k=self.memory_k)
return self
@@ -195,7 +193,7 @@ class RoleZero(Role):
return False
if not self.planner.plan.goal:
self.planner.plan.goal = self._get_all_memories()[-1].content
self.planner.plan.goal = self.get_memories()[-1].content
self.requirements_constraints = await AnalyzeRequirementsRestrictions().run(self.planner.plan.goal)
### 1. Experience ###
@@ -227,7 +225,7 @@
)
### Recent Observation ###
memory = self._fetch_memories()
memory = self.rc.memory.get(self.memory_k)
memory = await self.parse_browser_actions(memory)
memory = self.parse_images(memory)
@@ -282,15 +280,15 @@
return await super()._act()
commands, ok, self.command_rsp = await self._parse_commands(self.command_rsp)
self._add_memory(AIMessage(content=self.command_rsp))
self.rc.memory.add(AIMessage(content=self.command_rsp))
if not ok:
error_msg = commands
self._add_memory(UserMessage(content=error_msg, cause_by=RunCommand))
self.rc.memory.add(UserMessage(content=error_msg, cause_by=RunCommand))
return error_msg
logger.info(f"Commands: \n{commands}")
outputs = await self._run_commands(commands)
logger.info(f"Commands outputs: \n{outputs}")
self._add_memory(UserMessage(content=outputs, cause_by=RunCommand))
self.rc.memory.add(UserMessage(content=outputs, cause_by=RunCommand))
return AIMessage(
content=f"I have finished the task, please mark my task as finished. Outputs: {outputs}",
@@ -343,7 +341,7 @@
return rsp_msg, ""
# routing
memory = self._fetch_memories()
memory = self.get_memories(k=self.memory_k)
context = self.llm.format_msg(memory + [UserMessage(content=QUICK_THINK_PROMPT)])
async with ThoughtReporter() as reporter:
await reporter.async_report({"type": "classify"})
@@ -368,7 +366,7 @@
answer = await SearchEnhancedQA().run(query)
if answer:
self._add_memory(AIMessage(content=answer, cause_by=RunCommand))
self.rc.memory.add(AIMessage(content=answer, cause_by=RunCommand))
await self.reply_to_human(content=answer)
rsp_msg = AIMessage(
content="Complete run",
@@ -379,7 +377,7 @@
return rsp_msg, intent_result
async def _check_duplicates(self, req: list[dict], command_rsp: str):
past_rsp = [mem.content for mem in self._fetch_memories()]
past_rsp = [mem.content for mem in self.rc.memory.get(self.memory_k)]
if command_rsp in past_rsp:
# Normal response with thought contents are highly unlikely to reproduce
# If an identical response is detected, it is a bad response, mostly due to LLM repeating generated content
@@ -537,7 +535,7 @@
def _retrieve_experience(self) -> str:
"""Default implementation of experience retrieval. Can be overwritten in subclasses."""
context = [str(msg) for msg in self._fetch_memories()]
context = [str(msg) for msg in self.rc.memory.get(self.memory_k)]
context = "\n\n".join(context)
example = self.experience_retriever.retrieve(context=context)
return example
@@ -562,9 +560,9 @@
async def _end(self, **kwarg):
self._set_state(-1)
memory = self._fetch_memories()
memory = self.rc.memory.get(self.memory_k)
# Ensure reply to the human before the "end" command is executed. Hard code k=5 for checking.
if not any(["reply_to_human" in memory.content for memory in self._fetch_memories(k=5)]):
if not any(["reply_to_human" in memory.content for memory in self.get_memories(k=5)]):
logger.info("manually reply to human")
pattern = r"\[Language Restrictions\](.*?)\n"
match = re.search(pattern, self.requirements_constraints, re.DOTALL)
@@ -573,106 +571,10 @@
await reporter.async_report({"type": "quick"})
reply_content = await self.llm.aask(self.llm.format_msg(memory + [UserMessage(reply_to_human_prompt)]))
await self.reply_to_human(content=reply_content)
self._add_memory(AIMessage(content=reply_content, cause_by=RunCommand))
self.rc.memory.add(AIMessage(content=reply_content, cause_by=RunCommand))
outputs = ""
# Summary of the Completed Task and Deliverables
if self.use_summary:
logger.info("end current run and summarize")
outputs = await self.llm.aask(self.llm.format_msg(memory + [UserMessage(SUMMARY_PROMPT)]))
return outputs
def _get_all_memories(self) -> list[Message]:
return self._fetch_memories(k=0)
def _fetch_memories(self, k: Optional[int] = None) -> list[Message]:
"""Fetches recent memories and optionally combines them with related long-term memories.
If long-term memory is not enabled or the last message is not from the user,
it returns the recent memories without fetching from long-term memory.
Args:
k (Optional[int]): The number of recent memories to fetch. If None, defaults to self.memory_k.
Returns:
List[Message]: A list of messages representing the combined memories.
"""
if k is None:
k = self.memory_k
memories = self.rc.memory.get(k)
if not self._should_use_longterm_memory(k=k):
return memories
query = self._build_longterm_memory_query()
related_memories = self.longterm_memory.fetch(query)
logger.info(f"Fetched {len(related_memories)} long-term memories.")
# Keep user and AI messages are paired.
if self._is_first_message_from_ai(memories):
memories.insert(0, self.rc.memory.get_by_position(-(k + 1)))
final_memories = related_memories + memories
return final_memories
def _add_memory(self, message: Message):
self.rc.memory.add(message)
if not self._should_use_longterm_memory():
return
self._transfer_to_longterm_memory()
def _should_use_longterm_memory(self, k: int = None) -> bool:
"""Determines if long-term memory should be used.
Long-term memory is used if:
- k is not 0.
- Long-term memory usage is enabled.
- The count of recent memories is greater than self.memory_k.
"""
conds = [
k != 0,
self.enable_longterm_memory,
self.rc.memory.count() > self.memory_k,
]
return all(conds)
def _transfer_to_longterm_memory(self):
item = self._get_longterm_memory_item()
self.longterm_memory.add(item)
@handle_exception
def _get_longterm_memory_item(self) -> Optional[LongTermMemoryItem]:
"""Retrieves the most recent pair of user and AI messages before the last k messages."""
index = -(self.memory_k + 1)
message = self.rc.memory.get_by_position(index)
if not message.is_ai_message():
return None
index = -(self.memory_k + 2)
user_message = self.rc.memory.get_by_position(index)
return LongTermMemoryItem(user_message=user_message, ai_message=message)
def _is_first_message_from_ai(self, memories: list[Message]) -> bool:
return bool(memories and memories[0].is_ai_message())
def _build_longterm_memory_query(self) -> str:
"""Build the content used to query related long-term memory.
Default is to get the most recent user message, or an empty string if none is found.
"""
message = self._get_the_last_user_message()
return message.content if message else ""
def _get_the_last_user_message(self) -> Message:
values = self.rc.memory.index.get(USER_REQUIREMENT, [])
return values[-1] if values else None

View file

@@ -5,6 +5,7 @@ from typing import Annotated
from pydantic import Field
from metagpt.actions.di.run_command import RunCommand
from metagpt.const import TEAMLEADER_NAME
from metagpt.prompts.di.team_leader import (
FINISH_CURRENT_TASK_CMD,
TL_INFO,
@@ -19,7 +20,7 @@ from metagpt.tools.tool_registry import register_tool
@register_tool(include_functions=["publish_team_message"])
class TeamLeader(RoleZero):
name: str = "Mike"
name: str = TEAMLEADER_NAME
profile: str = "Team Leader"
goal: str = "Manage a team to assist users"
thought_guidance: str = TL_THOUGHT_GUIDANCE
@@ -84,4 +85,4 @@ class TeamLeader(RoleZero):
def finish_current_task(self):
self.planner.plan.finish_current_task()
self._add_memory(AIMessage(content=FINISH_CURRENT_TASK_CMD))
self.rc.memory.add(AIMessage(content=FINISH_CURRENT_TASK_CMD))

View file

@@ -965,9 +965,8 @@ class BaseEnum(Enum):
class LongTermMemoryItem(BaseModel):
user_message: Message
ai_message: Message
message: Message
created_at: Optional[float] = Field(default_factory=time.time)
def rag_key(self) -> str:
return self.user_message.content
return self.message.content

View file

@@ -2,8 +2,10 @@ from datetime import datetime, timedelta
import pytest
from metagpt.actions import UserRequirement
from metagpt.const import TEAMLEADER_NAME
from metagpt.memory.role_zero_memory import RoleZeroLongTermMemory
from metagpt.schema import AIMessage, LongTermMemoryItem, UserMessage
from metagpt.schema import AIMessage, LongTermMemoryItem, Message, UserMessage
class TestRoleZeroLongTermMemory:
@@ -13,71 +15,150 @@ class TestRoleZeroLongTermMemory:
memory._resolve_rag_engine = mocker.Mock()
return memory
def test_fetch_empty_query(self, mock_memory: RoleZeroLongTermMemory):
assert mock_memory.fetch("") == []
def test_add(self, mocker, mock_memory: RoleZeroLongTermMemory):
mock_memory._should_use_longterm_memory_for_add = mocker.Mock(return_value=True)
mock_memory._transfer_to_longterm_memory = mocker.Mock()
def test_fetch(self, mocker, mock_memory: RoleZeroLongTermMemory):
mock_node1 = mocker.Mock()
mock_node2 = mocker.Mock()
mock_node1.metadata = {
"obj": LongTermMemoryItem(user_message=UserMessage(content="user1"), ai_message=AIMessage(content="ai1"))
}
mock_node2.metadata = {
"obj": LongTermMemoryItem(user_message=UserMessage(content="user2"), ai_message=AIMessage(content="ai2"))
}
message = UserMessage(content="test")
mock_memory.add(message)
mock_memory.rag_engine.retrieve.return_value = [mock_node1, mock_node2]
assert mock_memory.storage[-1] == message
mock_memory._transfer_to_longterm_memory.assert_called_once()
result = mock_memory.fetch("test query")
def test_get(self, mocker, mock_memory: RoleZeroLongTermMemory):
mock_memory._should_use_longterm_memory_for_get = mocker.Mock(return_value=True)
mock_memory._build_longterm_memory_query = mocker.Mock(return_value="query")
mock_memory._fetch_longterm_memories = mocker.Mock(return_value=[Message(content="long-term")])
assert len(result) == 4
assert isinstance(result[0], UserMessage)
assert isinstance(result[1], AIMessage)
assert result[0].content == "user1"
assert result[1].content == "ai1"
assert result[2].content == "user2"
assert result[3].content == "ai2"
mock_memory.storage = [Message(content="short-term")]
mock_memory.rag_engine.retrieve.assert_called_once_with("test query")
result = mock_memory.get()
def test_add_empty_item(self, mock_memory: RoleZeroLongTermMemory):
mock_memory.add(None)
mock_memory.rag_engine.add_objs.assert_not_called()
assert len(result) == 2
assert result[0].content == "long-term"
assert result[1].content == "short-term"
def test_should_use_longterm_memory_for_add(self, mocker, mock_memory: RoleZeroLongTermMemory):
mocker.patch.object(mock_memory, "storage", [None] * 201)
mock_memory.memory_k = 200
assert mock_memory._should_use_longterm_memory_for_add() == True
mocker.patch.object(mock_memory, "storage", [None] * 199)
assert mock_memory._should_use_longterm_memory_for_add() == False
@pytest.mark.parametrize(
"k,is_last_from_user,count,expected",
[
(0, True, 201, False),
(1, False, 201, False),
(1, True, 199, False),
(1, True, 201, True),
],
)
def test_should_use_longterm_memory_for_get(
self, mocker, mock_memory: RoleZeroLongTermMemory, k, is_last_from_user, count, expected
):
mock_memory._is_last_message_from_user_requirement = mocker.Mock(return_value=is_last_from_user)
mocker.patch.object(mock_memory, "storage", [None] * count)
mock_memory.memory_k = 200
assert mock_memory._should_use_longterm_memory_for_get(k) == expected
def test_transfer_to_longterm_memory(self, mocker, mock_memory: RoleZeroLongTermMemory):
mock_item = mocker.Mock()
mock_memory._get_longterm_memory_item = mocker.Mock(return_value=mock_item)
mock_memory._add_to_longterm_memory = mocker.Mock()
mock_memory._transfer_to_longterm_memory()
mock_memory._add_to_longterm_memory.assert_called_once_with(mock_item)
def test_get_longterm_memory_item(self, mocker, mock_memory: RoleZeroLongTermMemory):
mock_message = Message(content="test")
mock_memory.storage = [mock_message, mock_message]
mock_memory.memory_k = 1
result = mock_memory._get_longterm_memory_item()
assert isinstance(result, LongTermMemoryItem)
assert result.message == mock_message
def test_add_to_longterm_memory(self, mock_memory: RoleZeroLongTermMemory):
item = LongTermMemoryItem(message=Message(content="test"))
mock_memory._add_to_longterm_memory(item)
def test_add_item(self, mock_memory: RoleZeroLongTermMemory):
item = LongTermMemoryItem(user_message=UserMessage(content="user"), ai_message=AIMessage(content="ai"))
mock_memory.add(item)
mock_memory.rag_engine.add_objs.assert_called_once_with([item])
def test_build_longterm_memory_query(self, mocker, mock_memory: RoleZeroLongTermMemory):
mock_message = Message(content="query")
mock_memory._get_the_last_message = mocker.Mock(return_value=mock_message)
result = mock_memory._build_longterm_memory_query()
assert result == "query"
def test_get_the_last_message(self, mock_memory: RoleZeroLongTermMemory):
mock_memory.storage = [Message(content="1"), Message(content="2")]
result = mock_memory._get_the_last_message()
assert result.content == "2"
@pytest.mark.parametrize(
"message,expected",
[
(UserMessage(content="test", cause_by=UserRequirement), True),
(UserMessage(content="test", sent_from=TEAMLEADER_NAME), True),
(UserMessage(content="test"), True),
(AIMessage(content="test"), False),
(None, False),
],
)
def test_is_last_message_from_user_requirement(
self, mocker, mock_memory: RoleZeroLongTermMemory, message, expected
):
mock_memory._get_the_last_message = mocker.Mock(return_value=message)
assert mock_memory._is_last_message_from_user_requirement() == expected
def test_fetch_longterm_memories(self, mocker, mock_memory: RoleZeroLongTermMemory):
mock_nodes = [mocker.Mock(), mocker.Mock()]
mock_memory.rag_engine.retrieve = mocker.Mock(return_value=mock_nodes)
mock_items = [
LongTermMemoryItem(message=UserMessage(content="user1")),
LongTermMemoryItem(message=AIMessage(content="ai1")),
]
mock_memory._get_items_from_nodes = mocker.Mock(return_value=mock_items)
result = mock_memory._fetch_longterm_memories("query")
assert len(result) == 2
assert result[0].content == "user1"
assert result[1].content == "ai1"
def test_get_items_from_nodes(self, mocker, mock_memory: RoleZeroLongTermMemory):
mock_node1 = mocker.Mock()
mock_node2 = mocker.Mock()
mock_node3 = mocker.Mock()
now = datetime.now()
item1 = LongTermMemoryItem(
user_message=UserMessage(content="user1"), ai_message=AIMessage(content="ai1"), created_at=now.timestamp()
)
item2 = LongTermMemoryItem(
user_message=UserMessage(content="user2"),
ai_message=AIMessage(content="ai2"),
created_at=(now - timedelta(minutes=5)).timestamp(),
)
item3 = LongTermMemoryItem(
user_message=UserMessage(content="user3"),
ai_message=AIMessage(content="ai3"),
created_at=(now + timedelta(minutes=5)).timestamp(),
)
mock_nodes = [
mocker.Mock(
metadata={
"obj": LongTermMemoryItem(
message=Message(content="2"), created_at=(now - timedelta(minutes=1)).timestamp()
)
}
),
mocker.Mock(
metadata={
"obj": LongTermMemoryItem(
message=Message(content="1"), created_at=(now - timedelta(minutes=2)).timestamp()
)
}
),
mocker.Mock(metadata={"obj": LongTermMemoryItem(message=Message(content="3"), created_at=now.timestamp())}),
]
mock_node1.metadata = {"obj": item1}
mock_node2.metadata = {"obj": item2}
mock_node3.metadata = {"obj": item3}
result = mock_memory._get_items_from_nodes([mock_node1, mock_node2, mock_node3])
result = mock_memory._get_items_from_nodes(mock_nodes)
assert len(result) == 3
assert result[0] == item2
assert result[1] == item1
assert result[2] == item3
assert [item.user_message.content for item in result] == ["user2", "user1", "user3"]
assert [item.ai_message.content for item in result] == ["ai2", "ai1", "ai3"]
assert [item.message.content for item in result] == ["1", "2", "3"]

View file

@@ -1,208 +0,0 @@
import pytest
from metagpt.roles.di.role_zero import RoleZero
from metagpt.schema import AIMessage, LongTermMemoryItem, Message, UserMessage
class TestRoleZero:
@pytest.fixture
def mock_role_zero(self, mocker):
role_zero = RoleZero()
role_zero.rc.memory = mocker.Mock()
role_zero.longterm_memory = mocker.Mock()
return role_zero
def test_get_all_memories(self, mocker, mock_role_zero: RoleZero):
mock_memories = [Message(content="test1"), Message(content="test2")]
mock_role_zero._fetch_memories = mocker.Mock(return_value=mock_memories)
result = mock_role_zero._get_all_memories()
assert result == mock_memories
mock_role_zero._fetch_memories.assert_called_once_with(k=0)
@pytest.mark.parametrize(
"k,should_use_ltm,memories,related_memories,is_first_from_ai,expected",
[
(
None,
False,
[UserMessage(content="user"), AIMessage(content="ai")],
[],
False,
[UserMessage(content="user"), AIMessage(content="ai")],
),
(
1,
True,
[UserMessage(content="user")],
[Message(content="related")],
False,
[Message(content="related"), UserMessage(content="user")],
),
(
None,
True,
[AIMessage(content="ai1"), UserMessage(content="user"), AIMessage(content="ai2")],
[Message(content="related")],
True,
[
Message(content="related"),
UserMessage(content="user"),
AIMessage(content="ai1"),
UserMessage(content="user"),
AIMessage(content="ai2"),
],
),
(
None,
True,
[UserMessage(content="user"), AIMessage(content="ai")],
[Message(content="related")],
False,
[Message(content="related"), UserMessage(content="user"), AIMessage(content="ai")],
),
(
0,
False,
[UserMessage(content="user"), AIMessage(content="ai")],
[],
False,
[UserMessage(content="user"), AIMessage(content="ai")],
),
],
)
def test_fetch_memories(
self,
mocker,
mock_role_zero: RoleZero,
k,
should_use_ltm,
memories,
related_memories,
is_first_from_ai,
expected,
):
mock_role_zero.memory_k = 2
mock_role_zero.rc.memory.get = mocker.Mock(return_value=memories)
mock_role_zero.rc.memory.get_by_position = mocker.Mock(return_value=UserMessage(content="user"))
mock_role_zero._should_use_longterm_memory = mocker.Mock(return_value=should_use_ltm)
mock_role_zero.longterm_memory.fetch = mocker.Mock(return_value=related_memories)
mock_role_zero._is_first_message_from_ai = mocker.Mock(return_value=is_first_from_ai)
result = mock_role_zero._fetch_memories(k)
assert len(result) == len(expected)
for actual, expected_msg in zip(result, expected):
assert actual.role == expected_msg.role
assert actual.content == expected_msg.content
really_k = k if k is not None else mock_role_zero.memory_k
mock_role_zero.rc.memory.get.assert_called_once_with(really_k)
if k != 0:
mock_role_zero._should_use_longterm_memory.assert_called_once_with(k=really_k)
if should_use_ltm:
mock_role_zero.longterm_memory.fetch.assert_called_once_with("user")
mock_role_zero._is_first_message_from_ai.assert_called_once_with(memories)
def test_add_memory(self, mocker, mock_role_zero: RoleZero):
message = AIMessage(content="ai")
mock_role_zero.rc.memory.add = mocker.Mock()
mock_role_zero._should_use_longterm_memory = mocker.Mock(return_value=True)
mock_role_zero._transfer_to_longterm_memory = mocker.Mock()
mock_role_zero._add_memory(message)
mock_role_zero.rc.memory.add.assert_called_once_with(message)
mock_role_zero._transfer_to_longterm_memory.assert_called_once()
@pytest.mark.parametrize(
"k,enable_longterm_memory,memory_count,k_memories,expected",
[
(0, True, 30, None, False), # k is 0
(None, False, 30, None, False), # Long-term memory usage is disabled
(None, True, 10, None, False), # Memory count is less than or equal to memory_k
(None, True, 30, [], False), # k_memories is empty
(None, True, 30, [AIMessage(content="ai")], False), # Last message in k_memories is not a user message
(None, True, 30, [AIMessage(content="ai"), UserMessage(content="user")], True), # All conditions are met
],
)
def test_should_use_longterm_memory(
self, mocker, mock_role_zero: RoleZero, k, enable_longterm_memory, memory_count, k_memories, expected
):
mock_role_zero.enable_longterm_memory = enable_longterm_memory
mock_role_zero.rc.memory.count = mocker.Mock(return_value=memory_count)
mock_role_zero.memory_k = 20
result = mock_role_zero._should_use_longterm_memory(k, k_memories)
assert result == expected
def test_transfer_to_longterm_memory(self, mocker, mock_role_zero: RoleZero):
mock_item = LongTermMemoryItem(user_message=UserMessage(content="user"), ai_message=AIMessage(content="ai"))
mock_role_zero._get_longterm_memory_item = mocker.Mock(return_value=mock_item)
mock_role_zero.longterm_memory = mocker.Mock()
mock_role_zero._transfer_to_longterm_memory()
mock_role_zero.longterm_memory.add.assert_called_once_with(mock_item)
def test_get_longterm_memory_item(self, mocker, mock_role_zero: RoleZero):
mock_role_zero.memory_k = 2
mock_messages = [
UserMessage(content="user1"),
AIMessage(content="ai1"),
UserMessage(content="user2"),
AIMessage(content="ai2"),
UserMessage(content="user3"), # memory_k + 2
AIMessage(content="ai3"), # memory_k + 1
UserMessage(content="recent1"),
AIMessage(content="recent2"),
]
mock_role_zero.rc.memory.get_by_position = mocker.Mock(side_effect=lambda i: mock_messages[i])
mock_role_zero.rc.memory.count = mocker.Mock(return_value=len(mock_messages))
result = mock_role_zero._get_longterm_memory_item()
assert isinstance(result, LongTermMemoryItem)
assert result.user_message.content == "user3"
assert result.ai_message.content == "ai3"
mock_role_zero.rc.memory.get_by_position.assert_any_call(-(mock_role_zero.memory_k + 1))
mock_role_zero.rc.memory.get_by_position.assert_any_call(-(mock_role_zero.memory_k + 2))
@pytest.mark.parametrize(
"memories,expected",
[
([AIMessage(content="ai")], True),
([UserMessage(content="user")], False),
([], False),
],
)
def test_is_first_message_from_ai(self, mock_role_zero: RoleZero, memories, expected):
result = mock_role_zero._is_first_message_from_ai(memories)
assert result == expected
@pytest.mark.parametrize(
"memories,expected",
[
([UserMessage(content="user1"), AIMessage(content="ai"), UserMessage(content="user2")], "user2"),
(
[
UserMessage(content="user1", cause_by="test"),
AIMessage(content="ai"),
UserMessage(content="user2", cause_by="test"),
],
"",
),
([AIMessage(content="ai1"), AIMessage(content="ai2")], ""),
([UserMessage(content="user")], "user"),
([], ""),
],
)
def test_build_longterm_memory_query(self, mock_role_zero: RoleZero, memories, expected):
result = mock_role_zero._build_longterm_memory_query(memories)
assert result == expected