mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-06-08 15:05:17 +02:00
use long-term memory in rolezero
This commit is contained in:
parent
4e82d86166
commit
f20d3ca2dd
8 changed files with 440 additions and 15 deletions
1
.gitignore
vendored
1
.gitignore
vendored
|
|
@ -163,6 +163,7 @@ examples/image__vector_store.json
|
|||
examples/index_store.json
|
||||
.chroma
|
||||
.chroma_exp_data
|
||||
.role_memory_data
|
||||
*~$*
|
||||
workspace/*
|
||||
tmp
|
||||
|
|
|
|||
|
|
@ -1,8 +1,12 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from abc import abstractmethod
|
||||
from typing import Optional, Union
|
||||
from typing import TYPE_CHECKING, Optional, Union
|
||||
|
||||
from metagpt.base.base_serialization import BaseSerialization
|
||||
from metagpt.schema import Message
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from metagpt.schema import Message
|
||||
|
||||
|
||||
class BaseRole(BaseSerialization):
|
||||
|
|
|
|||
|
|
@ -104,3 +104,7 @@ class Memory(BaseModel):
|
|||
continue
|
||||
rsp += self.index[action]
|
||||
return rsp
|
||||
|
||||
def get_by_position(self, position: int) -> Message:
|
||||
"""Return the message by its position"""
|
||||
return self.storage[position]
|
||||
|
|
|
|||
60
metagpt/memory/role_zero_memory.py
Normal file
60
metagpt/memory/role_zero_memory.py
Normal file
|
|
@ -0,0 +1,60 @@
|
|||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from metagpt.schema import LongTermMemoryItem, Message
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from llama_index.core.schema import NodeWithScore
|
||||
|
||||
from metagpt.rag.engines import SimpleEngine
|
||||
|
||||
|
||||
class RoleZeroLongTermMemory(BaseModel):
|
||||
persist_path: str = Field(default=".role_memory_data", description="The directory to save data.")
|
||||
collection_name: str = Field(default="role_zero", description="The name of the collection, such as the role name.")
|
||||
|
||||
_rag_engine: Any = None
|
||||
|
||||
@property
|
||||
def rag_engine(self) -> "SimpleEngine":
|
||||
if self._rag_engine is None:
|
||||
self._rag_engine = self._resolve_rag_engine()
|
||||
|
||||
return self._rag_engine
|
||||
|
||||
def _resolve_rag_engine(self) -> "SimpleEngine":
|
||||
try:
|
||||
from metagpt.rag.engines import SimpleEngine
|
||||
from metagpt.rag.schema import ChromaRetrieverConfig, LLMRankerConfig
|
||||
except ImportError:
|
||||
raise ImportError("To use the RoleZeroMemory, you need to install the rag module.")
|
||||
|
||||
retriever_configs = [
|
||||
ChromaRetrieverConfig(persist_path=self.persist_path, collection_name=self.collection_name)
|
||||
]
|
||||
ranker_configs = [LLMRankerConfig()]
|
||||
|
||||
rag_engine = SimpleEngine.from_objs(retriever_configs=retriever_configs, ranker_configs=ranker_configs)
|
||||
|
||||
return rag_engine
|
||||
|
||||
def fetch(self, query: str) -> list[Message]:
|
||||
if not query:
|
||||
return []
|
||||
|
||||
nodes: list[NodeWithScore] = self.rag_engine.retrieve(query)
|
||||
|
||||
memories = []
|
||||
for node in nodes:
|
||||
item: LongTermMemoryItem = node.metadata["obj"]
|
||||
memories.append(item.user_message)
|
||||
memories.append(item.ai_message)
|
||||
|
||||
return memories
|
||||
|
||||
def add(self, item: LongTermMemoryItem):
|
||||
if not item:
|
||||
return
|
||||
|
||||
self.rag_engine.add_objs([item])
|
||||
|
|
@ -18,6 +18,7 @@ from metagpt.exp_pool import exp_cache
|
|||
from metagpt.exp_pool.context_builders import RoleZeroContextBuilder
|
||||
from metagpt.exp_pool.serializers import RoleZeroSerializer
|
||||
from metagpt.logs import logger
|
||||
from metagpt.memory.role_zero_memory import RoleZeroLongTermMemory
|
||||
from metagpt.prompts.di.role_zero import (
|
||||
ASK_HUMAN_COMMAND,
|
||||
CMD_PROMPT,
|
||||
|
|
@ -34,7 +35,7 @@ from metagpt.prompts.di.role_zero import (
|
|||
THOUGHT_GUIDANCE,
|
||||
)
|
||||
from metagpt.roles import Role
|
||||
from metagpt.schema import AIMessage, Message, UserMessage
|
||||
from metagpt.schema import AIMessage, LongTermMemoryItem, Message, UserMessage
|
||||
from metagpt.strategy.experience_retriever import DummyExpRetriever, ExpRetriever
|
||||
from metagpt.strategy.planner import Planner
|
||||
from metagpt.tools.libs.browser import Browser
|
||||
|
|
@ -42,6 +43,7 @@ from metagpt.tools.libs.editor import Editor
|
|||
from metagpt.tools.tool_recommend import BM25ToolRecommender, ToolRecommender
|
||||
from metagpt.tools.tool_registry import register_tool
|
||||
from metagpt.utils.common import CodeParser, any_to_str, extract_and_encode_images
|
||||
from metagpt.utils.exceptions import handle_exception
|
||||
from metagpt.utils.repair_llm_raw_output import (
|
||||
RepairType,
|
||||
repair_escape_error,
|
||||
|
|
@ -86,6 +88,8 @@ class RoleZero(Role):
|
|||
command_rsp: str = "" # the raw string containing the commands
|
||||
commands: list[dict] = [] # commands to be executed
|
||||
memory_k: int = 20 # number of memories (messages) to use as historical context
|
||||
enable_longterm_memory: bool = True # whether to use longterm memory
|
||||
longterm_memory: RoleZeroLongTermMemory = None
|
||||
use_fixed_sop: bool = False
|
||||
requirements_constraints: str = "" # the constraints in user requirements
|
||||
use_summary: bool = True # whether to summarize at the end
|
||||
|
|
@ -140,6 +144,19 @@ class RoleZero(Role):
|
|||
self._update_tool_execution()
|
||||
return self
|
||||
|
||||
@model_validator(mode="after")
|
||||
def set_longterm_memory(self) -> "RoleZero":
|
||||
"""Set longterm memory.
|
||||
|
||||
If enable_longterm_memory is True and longterm_memory is not set, set it.
|
||||
The role name will be used as the collection name.
|
||||
"""
|
||||
|
||||
if self.enable_longterm_memory and not self.longterm_memory:
|
||||
self.longterm_memory = RoleZeroLongTermMemory(collection_name=self.name.replace(" ", ""))
|
||||
|
||||
return self
|
||||
|
||||
def _update_tool_execution(self):
|
||||
pass
|
||||
|
||||
|
|
@ -154,7 +171,7 @@ class RoleZero(Role):
|
|||
return False
|
||||
|
||||
if not self.planner.plan.goal:
|
||||
self.planner.plan.goal = self.get_memories()[-1].content
|
||||
self.planner.plan.goal = self._get_all_memories()[-1].content
|
||||
self.requirements_constraints = await AnalyzeRequirementsRestrictions().run(self.planner.plan.goal)
|
||||
|
||||
### 1. Experience ###
|
||||
|
|
@ -186,7 +203,7 @@ class RoleZero(Role):
|
|||
)
|
||||
|
||||
### Recent Observation ###
|
||||
memory = self.rc.memory.get(self.memory_k)
|
||||
memory = self._fetch_memories()
|
||||
memory = await self.parse_browser_actions(memory)
|
||||
memory = self.parse_images(memory)
|
||||
|
||||
|
|
@ -202,7 +219,7 @@ class RoleZero(Role):
|
|||
|
||||
self.command_rsp = await self._check_duplicates(req, self.command_rsp)
|
||||
|
||||
self.rc.memory.add(AIMessage(content=self.command_rsp))
|
||||
self._add_memory(AIMessage(content=self.command_rsp))
|
||||
return True
|
||||
|
||||
@exp_cache(context_builder=RoleZeroContextBuilder(), serializer=RoleZeroSerializer())
|
||||
|
|
@ -245,12 +262,12 @@ class RoleZero(Role):
|
|||
commands, ok = await self._parse_commands(self.command_rsp)
|
||||
if not ok:
|
||||
error_msg = commands
|
||||
self.rc.memory.add(UserMessage(content=error_msg))
|
||||
self._add_memory(UserMessage(content=error_msg))
|
||||
return error_msg
|
||||
logger.info(f"Commands: \n{commands}")
|
||||
outputs = await self._run_commands(commands)
|
||||
logger.info(f"Commands outputs: \n{outputs}")
|
||||
self.rc.memory.add(UserMessage(content=outputs))
|
||||
self._add_memory(UserMessage(content=outputs))
|
||||
|
||||
return AIMessage(
|
||||
content=f"I have finished the task, please mark my task as finished. Outputs: {outputs}",
|
||||
|
|
@ -303,7 +320,7 @@ class RoleZero(Role):
|
|||
return rsp_msg, ""
|
||||
|
||||
# routing
|
||||
memory = self.get_memories(k=self.memory_k)
|
||||
memory = self._fetch_memories()
|
||||
context = self.llm.format_msg(memory + [UserMessage(content=QUICK_THINK_PROMPT)])
|
||||
async with ThoughtReporter() as reporter:
|
||||
await reporter.async_report({"type": "classify"})
|
||||
|
|
@ -328,7 +345,7 @@ class RoleZero(Role):
|
|||
answer = await SearchEnhancedQA().run(query)
|
||||
|
||||
if answer:
|
||||
self.rc.memory.add(AIMessage(content=answer, cause_by=RunCommand))
|
||||
self._add_memory(AIMessage(content=answer, cause_by=RunCommand))
|
||||
await self.reply_to_human(content=answer)
|
||||
rsp_msg = AIMessage(
|
||||
content="Complete run",
|
||||
|
|
@ -339,7 +356,7 @@ class RoleZero(Role):
|
|||
return rsp_msg, intent_result
|
||||
|
||||
async def _check_duplicates(self, req: list[dict], command_rsp: str):
|
||||
past_rsp = [mem.content for mem in self.rc.memory.get(self.memory_k)]
|
||||
past_rsp = [mem.content for mem in self._fetch_memories()]
|
||||
if command_rsp in past_rsp:
|
||||
# Normal response with thought contents are highly unlikely to reproduce
|
||||
# If an identical response is detected, it is a bad response, mostly due to LLM repeating generated content
|
||||
|
|
@ -479,7 +496,7 @@ class RoleZero(Role):
|
|||
|
||||
def _retrieve_experience(self) -> str:
|
||||
"""Default implementation of experience retrieval. Can be overwritten in subclasses."""
|
||||
context = [str(msg) for msg in self.rc.memory.get(self.memory_k)]
|
||||
context = [str(msg) for msg in self._fetch_memories()]
|
||||
context = "\n\n".join(context)
|
||||
example = self.experience_retriever.retrieve(context=context)
|
||||
return example
|
||||
|
|
@ -504,9 +521,9 @@ class RoleZero(Role):
|
|||
|
||||
async def _end(self):
|
||||
self._set_state(-1)
|
||||
memory = self.rc.memory.get(self.memory_k)
|
||||
memory = self._fetch_memories()
|
||||
# Ensure reply to the human before the "end" command is executed. Hard code k=5 for checking.
|
||||
if not any(["reply_to_human" in memory.content for memory in self.get_memories(k=5)]):
|
||||
if not any(["reply_to_human" in memory.content for memory in self._fetch_memories(k=5)]):
|
||||
logger.info("manually reply to human")
|
||||
pattern = r"\[Language Restrictions\](.*?)\n"
|
||||
match = re.search(pattern, self.requirements_constraints, re.DOTALL)
|
||||
|
|
@ -515,10 +532,95 @@ class RoleZero(Role):
|
|||
await reporter.async_report({"type": "quick"})
|
||||
reply_content = await self.llm.aask(self.llm.format_msg(memory + [UserMessage(reply_to_human_prompt)]))
|
||||
await self.reply_to_human(content=reply_content)
|
||||
self.rc.memory.add(AIMessage(content=reply_content, cause_by=RunCommand))
|
||||
self._add_memory(AIMessage(content=reply_content, cause_by=RunCommand))
|
||||
outputs = ""
|
||||
# Summary of the Completed Task and Deliverables
|
||||
if self.use_summary:
|
||||
logger.info("end current run and summarize")
|
||||
outputs = await self.llm.aask(self.llm.format_msg(memory + [UserMessage(SUMMARY_PROMPT)]))
|
||||
return outputs
|
||||
|
||||
def _get_all_memories(self) -> list[Message]:
|
||||
return self._fetch_memories(k=0)
|
||||
|
||||
def _fetch_memories(self, k: Optional[int] = None) -> list[Message]:
|
||||
"""Fetches recent memories and optionally combines them with related long-term memories.
|
||||
|
||||
If long-term memory is not enabled or the last message is not from the user,
|
||||
it returns the recent memories without fetching from long-term memory.
|
||||
|
||||
Args:
|
||||
k (Optional[int]): The number of recent memories to fetch. If None, defaults to self.memory_k.
|
||||
|
||||
Returns:
|
||||
List[Message]: A list of messages representing the combined memories.
|
||||
"""
|
||||
|
||||
if k is None:
|
||||
k = self.memory_k
|
||||
|
||||
memories = self.rc.memory.get(k)
|
||||
|
||||
if not self._should_use_longterm_memory(k=k, k_memories=memories):
|
||||
return memories
|
||||
|
||||
related_memories = self.longterm_memory.fetch(memories[-1].content)
|
||||
logger.info(f"Fetched {len(related_memories)} long-term memories.")
|
||||
|
||||
if related_memories and self._is_first_message_from_ai(memories):
|
||||
memories = memories[1:]
|
||||
|
||||
final_memories = related_memories + memories
|
||||
|
||||
return final_memories
|
||||
|
||||
def _add_memory(self, message: Message):
|
||||
self.rc.memory.add(message)
|
||||
|
||||
if not self._should_use_longterm_memory():
|
||||
return
|
||||
|
||||
self._transfer_to_longterm_memory()
|
||||
|
||||
def _should_use_longterm_memory(self, k: int = None, k_memories: list[Message] = None) -> bool:
|
||||
"""Determines if long-term memory should be used.
|
||||
|
||||
Long-term memory is used if:
|
||||
- k is not 0.
|
||||
- k_memories is None or k_memories is not empty, and the last message is a user message.
|
||||
- Long-term memory usage is enabled.
|
||||
- The count of recent memories is greater than self.memory_k.
|
||||
"""
|
||||
|
||||
conds = [
|
||||
k != 0,
|
||||
k_memories is None or self._is_last_message_from_user(k_memories),
|
||||
self.enable_longterm_memory,
|
||||
self.rc.memory.count() > self.memory_k,
|
||||
]
|
||||
|
||||
return all(conds)
|
||||
|
||||
def _transfer_to_longterm_memory(self):
|
||||
item = self._get_longterm_memory_item()
|
||||
self.longterm_memory.add(item)
|
||||
|
||||
@handle_exception
|
||||
def _get_longterm_memory_item(self) -> Optional[LongTermMemoryItem]:
|
||||
"""Retrieves the most recent pair of user and AI messages before the last k messages."""
|
||||
|
||||
index = -(self.memory_k + 1)
|
||||
message = self.rc.memory.get_by_position(index)
|
||||
if not message.is_ai_message():
|
||||
return None
|
||||
|
||||
index = -(self.memory_k + 2)
|
||||
user_message = self.rc.memory.get_by_position(index)
|
||||
|
||||
return LongTermMemoryItem(user_message=user_message, ai_message=message)
|
||||
|
||||
def _is_last_message_from_user(self, memories: list[Message]) -> bool:
|
||||
return bool(memories and memories[-1].is_user_message())
|
||||
|
||||
def _is_first_message_from_ai(self, memories: list[Message]) -> bool:
|
||||
return bool(memories and memories[0].is_ai_message())
|
||||
|
|
|
|||
|
|
@ -408,6 +408,12 @@ class Message(BaseModel):
|
|||
dynamic_class = create_model(class_name, **{key: (value.__class__, ...) for key, value in kvs.items()})
|
||||
return dynamic_class.model_validate(kvs)
|
||||
|
||||
def is_user_message(self):
|
||||
return self.role == "user"
|
||||
|
||||
def is_ai_message(self):
|
||||
return self.role == "assistant"
|
||||
|
||||
|
||||
class UserMessage(Message):
|
||||
"""便于支持OpenAI的消息
|
||||
|
|
@ -955,3 +961,11 @@ class BaseEnum(Enum):
|
|||
obj._value_ = value
|
||||
obj.desc = desc
|
||||
return obj
|
||||
|
||||
|
||||
class LongTermMemoryItem(BaseModel):
|
||||
user_message: Message
|
||||
ai_message: Message
|
||||
|
||||
def rag_key(self) -> str:
|
||||
return self.user_message.content
|
||||
|
|
|
|||
48
tests/metagpt/memory/test_role_zero_memory.py
Normal file
48
tests/metagpt/memory/test_role_zero_memory.py
Normal file
|
|
@ -0,0 +1,48 @@
|
|||
import pytest
|
||||
|
||||
from metagpt.memory.role_zero_memory import RoleZeroLongTermMemory
|
||||
from metagpt.schema import AIMessage, LongTermMemoryItem, UserMessage
|
||||
|
||||
|
||||
class TestRoleZeroLongTermMemory:
|
||||
@pytest.fixture
|
||||
def mock_memory(self, mocker) -> RoleZeroLongTermMemory:
|
||||
memory = RoleZeroLongTermMemory()
|
||||
memory._resolve_rag_engine = mocker.Mock()
|
||||
return memory
|
||||
|
||||
def test_fetch_empty_query(self, mock_memory: RoleZeroLongTermMemory):
|
||||
assert mock_memory.fetch("") == []
|
||||
|
||||
def test_fetch(self, mocker, mock_memory: RoleZeroLongTermMemory):
|
||||
mock_node1 = mocker.Mock()
|
||||
mock_node2 = mocker.Mock()
|
||||
mock_node1.metadata = {
|
||||
"obj": LongTermMemoryItem(user_message=UserMessage(content="user1"), ai_message=AIMessage(content="ai1"))
|
||||
}
|
||||
mock_node2.metadata = {
|
||||
"obj": LongTermMemoryItem(user_message=UserMessage(content="user2"), ai_message=AIMessage(content="ai2"))
|
||||
}
|
||||
|
||||
mock_memory.rag_engine.retrieve.return_value = [mock_node1, mock_node2]
|
||||
|
||||
result = mock_memory.fetch("test query")
|
||||
|
||||
assert len(result) == 4
|
||||
assert isinstance(result[0], UserMessage)
|
||||
assert isinstance(result[1], AIMessage)
|
||||
assert result[0].content == "user1"
|
||||
assert result[1].content == "ai1"
|
||||
assert result[2].content == "user2"
|
||||
assert result[3].content == "ai2"
|
||||
|
||||
mock_memory.rag_engine.retrieve.assert_called_once_with("test query")
|
||||
|
||||
def test_add_empty_item(self, mock_memory: RoleZeroLongTermMemory):
|
||||
mock_memory.add(None)
|
||||
mock_memory.rag_engine.add_objs.assert_not_called()
|
||||
|
||||
def test_add_item(self, mock_memory: RoleZeroLongTermMemory):
|
||||
item = LongTermMemoryItem(user_message=UserMessage(content="user"), ai_message=AIMessage(content="ai"))
|
||||
mock_memory.add(item)
|
||||
mock_memory.rag_engine.add_objs.assert_called_once_with([item])
|
||||
192
tests/metagpt/roles/di/test_role_zero.py
Normal file
192
tests/metagpt/roles/di/test_role_zero.py
Normal file
|
|
@ -0,0 +1,192 @@
|
|||
import pytest
|
||||
|
||||
from metagpt.roles.di.role_zero import RoleZero
|
||||
from metagpt.schema import AIMessage, LongTermMemoryItem, Message, UserMessage
|
||||
|
||||
|
||||
class TestRoleZero:
|
||||
@pytest.fixture
|
||||
def mock_role_zero(self, mocker):
|
||||
role_zero = RoleZero()
|
||||
role_zero.rc.memory = mocker.Mock()
|
||||
role_zero.longterm_memory = mocker.Mock()
|
||||
return role_zero
|
||||
|
||||
def test_get_all_memories(self, mocker, mock_role_zero: RoleZero):
|
||||
mock_memories = [Message(content="test1"), Message(content="test2")]
|
||||
mock_role_zero._fetch_memories = mocker.Mock(return_value=mock_memories)
|
||||
|
||||
result = mock_role_zero._get_all_memories()
|
||||
|
||||
assert result == mock_memories
|
||||
mock_role_zero._fetch_memories.assert_called_once_with(k=0)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"k,should_use_ltm,memories,related_memories,is_first_from_ai,expected",
|
||||
[
|
||||
(
|
||||
None,
|
||||
False,
|
||||
[UserMessage(content="user"), AIMessage(content="ai")],
|
||||
[],
|
||||
False,
|
||||
[UserMessage(content="user"), AIMessage(content="ai")],
|
||||
),
|
||||
(
|
||||
1,
|
||||
True,
|
||||
[UserMessage(content="user")],
|
||||
[Message(content="related")],
|
||||
False,
|
||||
[Message(content="related"), UserMessage(content="user")],
|
||||
),
|
||||
(
|
||||
None,
|
||||
True,
|
||||
[AIMessage(content="ai1"), UserMessage(content="user"), AIMessage(content="ai2")],
|
||||
[Message(content="related")],
|
||||
True,
|
||||
[Message(content="related"), UserMessage(content="user"), AIMessage(content="ai2")],
|
||||
),
|
||||
(
|
||||
None,
|
||||
True,
|
||||
[UserMessage(content="user"), AIMessage(content="ai")],
|
||||
[Message(content="related")],
|
||||
False,
|
||||
[Message(content="related"), UserMessage(content="user"), AIMessage(content="ai")],
|
||||
),
|
||||
(
|
||||
0,
|
||||
False,
|
||||
[UserMessage(content="user"), AIMessage(content="ai")],
|
||||
[],
|
||||
False,
|
||||
[UserMessage(content="user"), AIMessage(content="ai")],
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_fetch_memories(
|
||||
self,
|
||||
mocker,
|
||||
mock_role_zero: RoleZero,
|
||||
k,
|
||||
should_use_ltm,
|
||||
memories,
|
||||
related_memories,
|
||||
is_first_from_ai,
|
||||
expected,
|
||||
):
|
||||
mock_role_zero.memory_k = 2
|
||||
mock_role_zero.rc.memory.get = mocker.Mock(return_value=memories)
|
||||
mock_role_zero._should_use_longterm_memory = mocker.Mock(return_value=should_use_ltm)
|
||||
mock_role_zero.longterm_memory.fetch = mocker.Mock(return_value=related_memories)
|
||||
mock_role_zero._is_first_message_from_ai = mocker.Mock(return_value=is_first_from_ai)
|
||||
|
||||
result = mock_role_zero._fetch_memories(k)
|
||||
|
||||
assert len(result) == len(expected)
|
||||
for actual, expected_msg in zip(result, expected):
|
||||
assert actual.role == expected_msg.role
|
||||
assert actual.content == expected_msg.content
|
||||
|
||||
really_k = k if k is not None else mock_role_zero.memory_k
|
||||
mock_role_zero.rc.memory.get.assert_called_once_with(really_k)
|
||||
|
||||
if k != 0:
|
||||
mock_role_zero._should_use_longterm_memory.assert_called_once_with(k=really_k, k_memories=memories)
|
||||
|
||||
if should_use_ltm:
|
||||
mock_role_zero.longterm_memory.fetch.assert_called_once_with(memories[-1].content)
|
||||
mock_role_zero._is_first_message_from_ai.assert_called_once_with(memories)
|
||||
|
||||
def test_add_memory(self, mocker, mock_role_zero: RoleZero):
|
||||
message = AIMessage(content="ai")
|
||||
mock_role_zero.rc.memory.add = mocker.Mock()
|
||||
mock_role_zero._should_use_longterm_memory = mocker.Mock(return_value=True)
|
||||
mock_role_zero._transfer_to_longterm_memory = mocker.Mock()
|
||||
|
||||
mock_role_zero._add_memory(message)
|
||||
|
||||
mock_role_zero.rc.memory.add.assert_called_once_with(message)
|
||||
mock_role_zero._transfer_to_longterm_memory.assert_called_once()
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"k,enable_longterm_memory,memory_count,k_memories,expected",
|
||||
[
|
||||
(0, True, 30, None, False), # k is 0
|
||||
(None, False, 30, None, False), # Long-term memory usage is disabled
|
||||
(None, True, 10, None, False), # Memory count is less than or equal to memory_k
|
||||
(None, True, 30, [], False), # k_memories is empty
|
||||
(None, True, 30, [AIMessage(content="ai")], False), # Last message in k_memories is not a user message
|
||||
(None, True, 30, [AIMessage(content="ai"), UserMessage(content="user")], True), # All conditions are met
|
||||
],
|
||||
)
|
||||
def test_should_use_longterm_memory(
|
||||
self, mocker, mock_role_zero: RoleZero, k, enable_longterm_memory, memory_count, k_memories, expected
|
||||
):
|
||||
mock_role_zero.enable_longterm_memory = enable_longterm_memory
|
||||
mock_role_zero.rc.memory.count = mocker.Mock(return_value=memory_count)
|
||||
mock_role_zero.memory_k = 20
|
||||
|
||||
result = mock_role_zero._should_use_longterm_memory(k, k_memories)
|
||||
|
||||
assert result == expected
|
||||
|
||||
def test_transfer_to_longterm_memory(self, mocker, mock_role_zero: RoleZero):
|
||||
mock_item = LongTermMemoryItem(user_message=UserMessage(content="user"), ai_message=AIMessage(content="ai"))
|
||||
mock_role_zero._get_longterm_memory_item = mocker.Mock(return_value=mock_item)
|
||||
mock_role_zero.longterm_memory = mocker.Mock()
|
||||
|
||||
mock_role_zero._transfer_to_longterm_memory()
|
||||
|
||||
mock_role_zero.longterm_memory.add.assert_called_once_with(mock_item)
|
||||
|
||||
def test_get_longterm_memory_item(self, mocker, mock_role_zero: RoleZero):
|
||||
mock_role_zero.memory_k = 2
|
||||
mock_messages = [
|
||||
UserMessage(content="user1"),
|
||||
AIMessage(content="ai1"),
|
||||
UserMessage(content="user2"),
|
||||
AIMessage(content="ai2"),
|
||||
UserMessage(content="user3"), # memory_k + 2
|
||||
AIMessage(content="ai3"), # memory_k + 1
|
||||
UserMessage(content="recent1"),
|
||||
AIMessage(content="recent2"),
|
||||
]
|
||||
|
||||
mock_role_zero.rc.memory.get_by_position = mocker.Mock(side_effect=lambda i: mock_messages[i])
|
||||
mock_role_zero.rc.memory.count = mocker.Mock(return_value=len(mock_messages))
|
||||
|
||||
result = mock_role_zero._get_longterm_memory_item()
|
||||
|
||||
assert isinstance(result, LongTermMemoryItem)
|
||||
assert result.user_message.content == "user3"
|
||||
assert result.ai_message.content == "ai3"
|
||||
|
||||
mock_role_zero.rc.memory.get_by_position.assert_any_call(-(mock_role_zero.memory_k + 1))
|
||||
mock_role_zero.rc.memory.get_by_position.assert_any_call(-(mock_role_zero.memory_k + 2))
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"memories,expected",
|
||||
[
|
||||
([UserMessage(content="user")], True),
|
||||
([AIMessage(content="ai")], False),
|
||||
([], False),
|
||||
],
|
||||
)
|
||||
def test_is_last_message_from_user(self, mock_role_zero: RoleZero, memories, expected):
|
||||
result = mock_role_zero._is_last_message_from_user(memories)
|
||||
assert result == expected
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"memories,expected",
|
||||
[
|
||||
([AIMessage(content="ai")], True),
|
||||
([UserMessage(content="user")], False),
|
||||
([], False),
|
||||
],
|
||||
)
|
||||
def test_is_first_message_from_ai(self, mock_role_zero: RoleZero, memories, expected):
|
||||
result = mock_role_zero._is_first_message_from_ai(memories)
|
||||
assert result == expected
|
||||
Loading…
Add table
Add a link
Reference in a new issue