Mirror of https://github.com/FoundationAgents/MetaGPT.git (synced 2026-05-04 21:32:38 +02:00)
Merge branch 'feat-memory-opt' into 'mgx_ops'

use long-term memory in rolezero

See merge request pub/MetaGPT!369

Commit e0ec17e6aa: 13 changed files with 418 additions and 6 deletions

.gitignore (vendored, 1 change)
@@ -163,6 +163,7 @@ examples/image__vector_store.json
 examples/index_store.json
 .chroma
 .chroma_exp_data
+.role_memory_data
 *~$*
 workspace/*
 tmp

@@ -83,6 +83,8 @@ exp_pool:
   use_llm_ranker: true # Default is `true`, it will use LLM Reranker to get better result.
   collection_name: experience_pool # When `retrieval_type` is `chroma`, `collection_name` is the collection name in chromadb.
 
+role_zero:
+  enable_longterm_memory: false # Whether to use long-term memory. Default is `false`.
 
 azure_tts_subscription_key: "YOUR_SUBSCRIPTION_KEY"
 azure_tts_region: "eastus"

@@ -19,6 +19,7 @@ from metagpt.configs.mermaid_config import MermaidConfig
 from metagpt.configs.omniparse_config import OmniParseConfig
 from metagpt.configs.redis_config import RedisConfig
 from metagpt.configs.role_custom_config import RoleCustomConfig
+from metagpt.configs.role_zero_config import RoleZeroConfig
 from metagpt.configs.s3_config import S3Config
 from metagpt.configs.search_config import SearchConfig
 from metagpt.configs.workspace_config import WorkspaceConfig
@@ -89,6 +90,9 @@ class Config(CLIParams, YamlModel):
     # Role's custom configuration
     roles: Optional[List[RoleCustomConfig]] = None
 
+    # RoleZero's configuration
+    role_zero: RoleZeroConfig = Field(default_factory=RoleZeroConfig)
+
     omniparse: Optional[OmniParseConfig] = None
 
     @classmethod

metagpt/configs/role_zero_config.py (new file, 7 lines)
@@ -0,0 +1,7 @@
from pydantic import Field

from metagpt.utils.yaml_model import YamlModel


class RoleZeroConfig(YamlModel):
    enable_longterm_memory: bool = Field(default=False, description="Whether to use long-term memory.")
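
A minimal sketch of the new switch in isolation; the values are illustrative and the flag defaults to off:

from metagpt.configs.role_zero_config import RoleZeroConfig

cfg = RoleZeroConfig()
assert cfg.enable_longterm_memory is False  # disabled unless config2.yaml (or code) opts in

cfg = RoleZeroConfig(enable_longterm_memory=True)
assert cfg.enable_longterm_memory is True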

@@ -157,3 +157,6 @@ SWE_SETUP_PATH = get_metagpt_package_root() / "metagpt/tools/swe_agent_commands/
 
 # experience pool
 EXPERIENCE_MASK = "<experience>"
+
+# TeamLeader's name
+TEAMLEADER_NAME = "Mike"

@@ -8,7 +8,7 @@ from metagpt.actions import (
     WriteTest,
 )
 from metagpt.actions.summarize_code import SummarizeCode
-from metagpt.const import AGENT, IMAGES
+from metagpt.const import AGENT, IMAGES, TEAMLEADER_NAME
 from metagpt.environment.base_env import Environment
 from metagpt.logs import get_human_input
 from metagpt.roles import Architect, ProductManager, ProjectManager, Role
@@ -31,7 +31,7 @@ class MGXEnv(Environment, SerializationMixin):
         """let the team leader take over message publishing"""
         message = self.attach_images(message)  # for multi-modal message
 
-        tl = self.get_role("Mike")  # TeamLeader's name is Mike
+        tl = self.get_role(TEAMLEADER_NAME)  # TeamLeader's name is Mike
 
         if user_defined_recipient:
             # human user's direct chat message to a certain role

@@ -7,13 +7,14 @@
 @Modified By: mashenquan, 2023-11-1. According to RFC 116: Updated the type of index key.
 """
 from collections import defaultdict
-from typing import DefaultDict, Iterable, Set
+from typing import DefaultDict, Iterable, Optional, Set
 
 from pydantic import BaseModel, Field, SerializeAsAny
 
 from metagpt.const import IGNORED_MESSAGE_ID
 from metagpt.schema import Message
 from metagpt.utils.common import any_to_str, any_to_str_set
+from metagpt.utils.exceptions import handle_exception
 
 
 class Memory(BaseModel):
@@ -104,3 +105,8 @@ class Memory(BaseModel):
                 continue
             rsp += self.index[action]
         return rsp
+
+    @handle_exception
+    def get_by_position(self, position: int) -> Optional[Message]:
+        """Returns the message at the given position if valid, otherwise returns None"""
+        return self.storage[position]
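
A quick sketch of the new helper, assuming `@handle_exception` falls back to returning None when the lookup raises:

from metagpt.memory import Memory
from metagpt.schema import Message

mem = Memory()
mem.add(Message(content="first"))
mem.add(Message(content="second"))

assert mem.get_by_position(-1).content == "second"  # negative positions count from the end
assert mem.get_by_position(10) is None  # out-of-range IndexError is swallowed by the decorator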

metagpt/memory/role_zero_memory.py (new file, 189 lines)
@@ -0,0 +1,189 @@
"""
This module implements a memory system combining short-term and long-term storage for AI role memory management.
It utilizes a RAG (Retrieval-Augmented Generation) engine for long-term memory storage and retrieval.
"""

from typing import TYPE_CHECKING, Any, Optional

from pydantic import Field

from metagpt.actions import UserRequirement
from metagpt.const import TEAMLEADER_NAME
from metagpt.logs import logger
from metagpt.memory import Memory
from metagpt.schema import LongTermMemoryItem, Message
from metagpt.utils.common import any_to_str
from metagpt.utils.exceptions import handle_exception

if TYPE_CHECKING:
    from llama_index.core.schema import NodeWithScore

    from metagpt.rag.engines import SimpleEngine


class RoleZeroLongTermMemory(Memory):
    """
    Implements a memory system combining short-term and long-term storage using a RAG engine.
    Transfers old memories to long-term storage when short-term capacity is reached.
    Retrieves combined short-term and long-term memories as needed.
    """

    persist_path: str = Field(default=".role_memory_data", description="The directory to save data.")
    collection_name: str = Field(default="role_zero", description="The name of the collection, such as the role name.")
    memory_k: int = Field(default=200, description="The capacity of short-term memory.")

    _rag_engine: Any = None

    @property
    def rag_engine(self) -> "SimpleEngine":
        if self._rag_engine is None:
            self._rag_engine = self._resolve_rag_engine()

        return self._rag_engine

    def _resolve_rag_engine(self) -> "SimpleEngine":
        """Lazy loading of the RAG engine components, ensuring they are only loaded when needed.

        It uses `Chroma` for retrieval and `LLMRanker` for ranking.
        """

        try:
            from metagpt.rag.engines import SimpleEngine
            from metagpt.rag.schema import ChromaRetrieverConfig, LLMRankerConfig
        except ImportError:
            raise ImportError("To use the RoleZeroMemory, you need to install the rag module.")

        retriever_configs = [
            ChromaRetrieverConfig(persist_path=self.persist_path, collection_name=self.collection_name)
        ]
        ranker_configs = [LLMRankerConfig()]

        rag_engine = SimpleEngine.from_objs(retriever_configs=retriever_configs, ranker_configs=ranker_configs)

        return rag_engine

    def add(self, message: Message):
        """Add a new message and potentially transfer it to long-term memory."""

        super().add(message)

        if not self._should_use_longterm_memory_for_add():
            return

        self._transfer_to_longterm_memory()

    def get(self, k=0) -> list[Message]:
        """Return recent memories and optionally combines them with related long-term memories."""

        memories = super().get(k)

        if not self._should_use_longterm_memory_for_get(k=k):
            return memories

        query = self._build_longterm_memory_query()
        related_memories = self._fetch_longterm_memories(query)
        logger.info(f"Fetched {len(related_memories)} long-term memories.")

        final_memories = related_memories + memories

        return final_memories

    def _should_use_longterm_memory_for_add(self) -> bool:
        """Determines if long-term memory should be used for add."""

        return self.count() > self.memory_k

    def _should_use_longterm_memory_for_get(self, k: int) -> bool:
        """Determines if long-term memory should be used for get.

        Long-term memory is used if:
        - k is not 0.
        - The last message is from user requirement.
        - The count of recent memories is greater than self.memory_k.
        """

        conds = [
            k != 0,
            self._is_last_message_from_user_requirement(),
            self.count() > self.memory_k,
        ]

        return all(conds)

    def _transfer_to_longterm_memory(self):
        item = self._get_longterm_memory_item()
        self._add_to_longterm_memory(item)

    @handle_exception
    def _get_longterm_memory_item(self) -> Optional[LongTermMemoryItem]:
        """Retrieves the most recent message before the last k messages."""

        index = -(self.memory_k + 1)
        message = self.get_by_position(index)

        return LongTermMemoryItem(message=message)

    def _add_to_longterm_memory(self, item: LongTermMemoryItem):
        """Adds a long-term memory item to the RAG engine."""

        if not item:
            return

        self.rag_engine.add_objs([item])

    def _build_longterm_memory_query(self) -> str:
        """Build the content used to query related long-term memory.

        Default is to get the most recent user message, or an empty string if none is found.
        """

        message = self._get_the_last_message()

        return message.content if message else ""

    def _get_the_last_message(self) -> Optional[Message]:
        if not self.count():
            return None

        return self.get_by_position(-1)

    def _is_last_message_from_user_requirement(self) -> bool:
        """Checks if the last message is from a user requirement or sent by the team leader."""

        message = self._get_the_last_message()

        if not message:
            return False

        is_user_message = message.is_user_message()
        cause_by_user_requirement = message.cause_by == any_to_str(UserRequirement)
        sent_from_team_leader = message.sent_from == TEAMLEADER_NAME

        return is_user_message and (cause_by_user_requirement or sent_from_team_leader)

    def _fetch_longterm_memories(self, query: str) -> list[Message]:
        """Fetches long-term memories based on a query.

        Args:
            query (str): The query string to search for relevant memories.

        Returns:
            list[Message]: A list of user and AI messages related to the query.
        """

        if not query:
            return []

        nodes = self.rag_engine.retrieve(query)
        items = self._get_items_from_nodes(nodes)
        memories = [item.message for item in items]

        return memories

    def _get_items_from_nodes(self, nodes: list["NodeWithScore"]) -> list[LongTermMemoryItem]:
        """Get items from nodes and arrange them in order of their `created_at`."""

        items: list[LongTermMemoryItem] = [node.metadata["obj"] for node in nodes]
        items.sort(key=lambda item: item.created_at)

        return items
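
A hedged usage sketch of the class above; it assumes the optional RAG dependencies (chromadb, llama-index) are installed and an LLM is configured for the ranker, and the path, collection name, and messages are illustrative:

from metagpt.memory.role_zero_memory import RoleZeroLongTermMemory
from metagpt.schema import UserMessage

memory = RoleZeroLongTermMemory(
    persist_path=".role_memory_data",  # where the Chroma collection is persisted
    collection_name="Alice",           # typically the role name
    memory_k=200,                      # short-term window; older items spill into the RAG store
)

memory.add(UserMessage(content="Build a 2048 game"))

# get(k) returns the recent window; when the last message is a user requirement and
# more than memory_k messages exist, related long-term memories are prepended.
recent_and_related = memory.get(k=20)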

@@ -23,6 +23,7 @@ Doc: 9, Relevance: 7
 - Evaluate the relevance between the question and the documents.
 - The relevance score is a number from 1-10 based on how relevant you think the document is to the question.
 - Do not include any documents that are not relevant to the question.
+- If none of the documents provided contain information that directly answers the question, simply respond with "no relevant documents".
 
 ## Constraint
 Format: Just print the result in format like **Format Example**.

@@ -17,6 +17,7 @@ from metagpt.exp_pool import exp_cache
 from metagpt.exp_pool.context_builders import RoleZeroContextBuilder
 from metagpt.exp_pool.serializers import RoleZeroSerializer
 from metagpt.logs import logger
+from metagpt.memory.role_zero_memory import RoleZeroLongTermMemory
 from metagpt.prompts.di.role_zero import (
     ASK_HUMAN_COMMAND,
     CMD_PROMPT,
@@ -166,6 +167,24 @@ class RoleZero(Role):
         self._update_tool_execution()
         return self
 
+    @model_validator(mode="after")
+    def set_longterm_memory(self) -> "RoleZero":
+        """Set up long-term memory for the role if enabled in the configuration.
+
+        If `enable_longterm_memory` is True, set up long-term memory.
+        The role name will be used as the collection name.
+        """
+
+        if self.config.role_zero.enable_longterm_memory:
+            self.rc.memory = RoleZeroLongTermMemory(
+                **self.rc.memory.model_dump(),
+                collection_name=self.name.replace(" ", ""),
+                memory_k=self.memory_k,
+            )
+            logger.info(f"Long-term memory set for role '{self.name}'")
+
+        return self
+
     def _update_tool_execution(self):
         pass
 
@@ -289,12 +308,12 @@
         self.rc.memory.add(AIMessage(content=self.command_rsp))
         if not ok:
             error_msg = commands
-            self.rc.memory.add(UserMessage(content=error_msg))
+            self.rc.memory.add(UserMessage(content=error_msg, cause_by=RunCommand))
             return error_msg
         logger.info(f"Commands: \n{commands}")
         outputs = await self._run_commands(commands)
         logger.info(f"Commands outputs: \n{outputs}")
-        self.rc.memory.add(UserMessage(content=outputs))
+        self.rc.memory.add(UserMessage(content=outputs, cause_by=RunCommand))
 
         return AIMessage(
             content=f"I have finished the task, please mark my task as finished. Outputs: {outputs}",
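
A small sketch of what the validator's swap preserves: dumping the existing Memory and re-validating it as RoleZeroLongTermMemory keeps the short-term history already collected. The collection name and memory_k below are illustrative, not values from the diff:

from metagpt.memory import Memory
from metagpt.memory.role_zero_memory import RoleZeroLongTermMemory
from metagpt.schema import UserMessage

plain = Memory()
plain.add(UserMessage(content="hello"))

# Mirrors the validator's kwargs expansion, so prior messages carry over.
upgraded = RoleZeroLongTermMemory(**plain.model_dump(), collection_name="Alice", memory_k=200)
assert upgraded.storage[-1].content == "hello"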

@@ -5,6 +5,7 @@ from typing import Annotated
 from pydantic import Field
 
 from metagpt.actions.di.run_command import RunCommand
+from metagpt.const import TEAMLEADER_NAME
 from metagpt.prompts.di.team_leader import (
     FINISH_CURRENT_TASK_CMD,
     TL_INFO,
@@ -19,7 +20,7 @@ from metagpt.tools.tool_registry import register_tool
 
 @register_tool(include_functions=["publish_team_message"])
 class TeamLeader(RoleZero):
-    name: str = "Mike"
+    name: str = TEAMLEADER_NAME
     profile: str = "Team Leader"
     goal: str = "Manage a team to assist users"
     thought_guidance: str = TL_THOUGHT_GUIDANCE

@@ -18,6 +18,7 @@ from __future__ import annotations
 import asyncio
 import json
 import os.path
+import time
 import uuid
 from abc import ABC
 from asyncio import Queue, QueueEmpty, wait_for
@@ -408,6 +409,12 @@ class Message(BaseModel):
         dynamic_class = create_model(class_name, **{key: (value.__class__, ...) for key, value in kvs.items()})
         return dynamic_class.model_validate(kvs)
 
+    def is_user_message(self) -> bool:
+        return self.role == "user"
+
+    def is_ai_message(self) -> bool:
+        return self.role == "assistant"
+
 
 class UserMessage(Message):
     """Message type to facilitate OpenAI support
@@ -955,3 +962,11 @@ class BaseEnum(Enum):
         obj._value_ = value
         obj.desc = desc
         return obj
+
+
+class LongTermMemoryItem(BaseModel):
+    message: Message
+    created_at: Optional[float] = Field(default_factory=time.time)
+
+    def rag_key(self) -> str:
+        return self.message.content
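
A small illustration of the new LongTermMemoryItem: rag_key() exposes the message content (the text the memory module above hands to the retriever), and created_at is stamped at construction so retrieved items can be re-ordered; the message content is illustrative:

import time

from metagpt.schema import LongTermMemoryItem, UserMessage

item = LongTermMemoryItem(message=UserMessage(content="Build a 2048 game"))
assert item.rag_key() == "Build a 2048 game"  # the text indexed and retrieved against
assert item.created_at <= time.time()         # default timestamp, used to sort retrieved items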

tests/metagpt/memory/test_role_zero_memory.py (new file, 164 lines)
@@ -0,0 +1,164 @@
from datetime import datetime, timedelta

import pytest

from metagpt.actions import UserRequirement
from metagpt.const import TEAMLEADER_NAME
from metagpt.memory.role_zero_memory import RoleZeroLongTermMemory
from metagpt.schema import AIMessage, LongTermMemoryItem, Message, UserMessage


class TestRoleZeroLongTermMemory:
    @pytest.fixture
    def mock_memory(self, mocker) -> RoleZeroLongTermMemory:
        memory = RoleZeroLongTermMemory()
        memory._resolve_rag_engine = mocker.Mock()
        return memory

    def test_add(self, mocker, mock_memory: RoleZeroLongTermMemory):
        mock_memory._should_use_longterm_memory_for_add = mocker.Mock(return_value=True)
        mock_memory._transfer_to_longterm_memory = mocker.Mock()

        message = UserMessage(content="test")
        mock_memory.add(message)

        assert mock_memory.storage[-1] == message
        mock_memory._transfer_to_longterm_memory.assert_called_once()

    def test_get(self, mocker, mock_memory: RoleZeroLongTermMemory):
        mock_memory._should_use_longterm_memory_for_get = mocker.Mock(return_value=True)
        mock_memory._build_longterm_memory_query = mocker.Mock(return_value="query")
        mock_memory._fetch_longterm_memories = mocker.Mock(return_value=[Message(content="long-term")])

        mock_memory.storage = [Message(content="short-term")]

        result = mock_memory.get()

        assert len(result) == 2
        assert result[0].content == "long-term"
        assert result[1].content == "short-term"

    def test_should_use_longterm_memory_for_add(self, mocker, mock_memory: RoleZeroLongTermMemory):
        mocker.patch.object(mock_memory, "storage", [None] * 201)

        mock_memory.memory_k = 200

        assert mock_memory._should_use_longterm_memory_for_add() == True

        mocker.patch.object(mock_memory, "storage", [None] * 199)
        assert mock_memory._should_use_longterm_memory_for_add() == False

    @pytest.mark.parametrize(
        "k,is_last_from_user,count,expected",
        [
            (0, True, 201, False),
            (1, False, 201, False),
            (1, True, 199, False),
            (1, True, 201, True),
        ],
    )
    def test_should_use_longterm_memory_for_get(
        self, mocker, mock_memory: RoleZeroLongTermMemory, k, is_last_from_user, count, expected
    ):
        mock_memory._is_last_message_from_user_requirement = mocker.Mock(return_value=is_last_from_user)
        mocker.patch.object(mock_memory, "storage", [None] * count)
        mock_memory.memory_k = 200

        assert mock_memory._should_use_longterm_memory_for_get(k) == expected

    def test_transfer_to_longterm_memory(self, mocker, mock_memory: RoleZeroLongTermMemory):
        mock_item = mocker.Mock()
        mock_memory._get_longterm_memory_item = mocker.Mock(return_value=mock_item)
        mock_memory._add_to_longterm_memory = mocker.Mock()

        mock_memory._transfer_to_longterm_memory()

        mock_memory._add_to_longterm_memory.assert_called_once_with(mock_item)

    def test_get_longterm_memory_item(self, mocker, mock_memory: RoleZeroLongTermMemory):
        mock_message = Message(content="test")
        mock_memory.storage = [mock_message, mock_message]
        mock_memory.memory_k = 1

        result = mock_memory._get_longterm_memory_item()

        assert isinstance(result, LongTermMemoryItem)
        assert result.message == mock_message

    def test_add_to_longterm_memory(self, mock_memory: RoleZeroLongTermMemory):
        item = LongTermMemoryItem(message=Message(content="test"))
        mock_memory._add_to_longterm_memory(item)

        mock_memory.rag_engine.add_objs.assert_called_once_with([item])

    def test_build_longterm_memory_query(self, mocker, mock_memory: RoleZeroLongTermMemory):
        mock_message = Message(content="query")
        mock_memory._get_the_last_message = mocker.Mock(return_value=mock_message)

        result = mock_memory._build_longterm_memory_query()

        assert result == "query"

    def test_get_the_last_message(self, mock_memory: RoleZeroLongTermMemory):
        mock_memory.storage = [Message(content="1"), Message(content="2")]

        result = mock_memory._get_the_last_message()

        assert result.content == "2"

    @pytest.mark.parametrize(
        "message,expected",
        [
            (UserMessage(content="test", cause_by=UserRequirement), True),
            (UserMessage(content="test", sent_from=TEAMLEADER_NAME), True),
            (UserMessage(content="test"), True),
            (AIMessage(content="test"), False),
            (None, False),
        ],
    )
    def test_is_last_message_from_user_requirement(
        self, mocker, mock_memory: RoleZeroLongTermMemory, message, expected
    ):
        mock_memory._get_the_last_message = mocker.Mock(return_value=message)

        assert mock_memory._is_last_message_from_user_requirement() == expected

    def test_fetch_longterm_memories(self, mocker, mock_memory: RoleZeroLongTermMemory):
        mock_nodes = [mocker.Mock(), mocker.Mock()]
        mock_memory.rag_engine.retrieve = mocker.Mock(return_value=mock_nodes)
        mock_items = [
            LongTermMemoryItem(message=UserMessage(content="user1")),
            LongTermMemoryItem(message=AIMessage(content="ai1")),
        ]
        mock_memory._get_items_from_nodes = mocker.Mock(return_value=mock_items)

        result = mock_memory._fetch_longterm_memories("query")

        assert len(result) == 2
        assert result[0].content == "user1"
        assert result[1].content == "ai1"

    def test_get_items_from_nodes(self, mocker, mock_memory: RoleZeroLongTermMemory):
        now = datetime.now()
        mock_nodes = [
            mocker.Mock(
                metadata={
                    "obj": LongTermMemoryItem(
                        message=Message(content="2"), created_at=(now - timedelta(minutes=1)).timestamp()
                    )
                }
            ),
            mocker.Mock(
                metadata={
                    "obj": LongTermMemoryItem(
                        message=Message(content="1"), created_at=(now - timedelta(minutes=2)).timestamp()
                    )
                }
            ),
            mocker.Mock(metadata={"obj": LongTermMemoryItem(message=Message(content="3"), created_at=now.timestamp())}),
        ]

        result = mock_memory._get_items_from_nodes(mock_nodes)

        assert len(result) == 3
        assert [item.message.content for item in result] == ["1", "2", "3"]