Merge branch 'feat-memory-opt' into 'mgx_ops'

Use long-term memory in RoleZero

See merge request pub/MetaGPT!369
This commit is contained in:
林义章 2024-09-25 07:13:41 +00:00
commit e0ec17e6aa
13 changed files with 418 additions and 6 deletions

1
.gitignore vendored
View file

@ -163,6 +163,7 @@ examples/image__vector_store.json
examples/index_store.json
.chroma
.chroma_exp_data
.role_memory_data
*~$*
workspace/*
tmp

View file

@ -83,6 +83,8 @@ exp_pool:
use_llm_ranker: true # Default is `true`; it will use the LLM Reranker to get better results.
collection_name: experience_pool # When `retrieval_type` is `chroma`, `collection_name` is the collection name in chromadb.
role_zero:
enable_longterm_memory: false # Whether to use long-term memory. Default is `false`.
azure_tts_subscription_key: "YOUR_SUBSCRIPTION_KEY"
azure_tts_region: "eastus"

View file

@ -19,6 +19,7 @@ from metagpt.configs.mermaid_config import MermaidConfig
from metagpt.configs.omniparse_config import OmniParseConfig
from metagpt.configs.redis_config import RedisConfig
from metagpt.configs.role_custom_config import RoleCustomConfig
from metagpt.configs.role_zero_config import RoleZeroConfig
from metagpt.configs.s3_config import S3Config
from metagpt.configs.search_config import SearchConfig
from metagpt.configs.workspace_config import WorkspaceConfig
@ -89,6 +90,9 @@ class Config(CLIParams, YamlModel):
# Role's custom configuration
roles: Optional[List[RoleCustomConfig]] = None
# RoleZero's configuration
role_zero: RoleZeroConfig = Field(default_factory=RoleZeroConfig)
omniparse: Optional[OmniParseConfig] = None
@classmethod

View file

@ -0,0 +1,7 @@
from pydantic import Field
from metagpt.utils.yaml_model import YamlModel
class RoleZeroConfig(YamlModel):
    """YAML-backed configuration for RoleZero roles.

    Loaded as the `role_zero` section of the main config; consumed by
    RoleZero's `set_longterm_memory` validator to decide whether to swap the
    default short-term memory for `RoleZeroLongTermMemory`.
    """

    # When True, RoleZero replaces its memory with a RAG-backed long-term memory.
    enable_longterm_memory: bool = Field(default=False, description="Whether to use long-term memory.")

View file

@ -157,3 +157,6 @@ SWE_SETUP_PATH = get_metagpt_package_root() / "metagpt/tools/swe_agent_commands/
# experience pool
EXPERIENCE_MASK = "<experience>"
# TeamLeader's name
TEAMLEADER_NAME = "Mike"

View file

@ -8,7 +8,7 @@ from metagpt.actions import (
WriteTest,
)
from metagpt.actions.summarize_code import SummarizeCode
from metagpt.const import AGENT, IMAGES
from metagpt.const import AGENT, IMAGES, TEAMLEADER_NAME
from metagpt.environment.base_env import Environment
from metagpt.logs import get_human_input
from metagpt.roles import Architect, ProductManager, ProjectManager, Role
@ -31,7 +31,7 @@ class MGXEnv(Environment, SerializationMixin):
"""let the team leader take over message publishing"""
message = self.attach_images(message) # for multi-modal message
tl = self.get_role("Mike") # TeamLeader's name is Mike
tl = self.get_role(TEAMLEADER_NAME) # TeamLeader's name is Mike
if user_defined_recipient:
# human user's direct chat message to a certain role

View file

@ -7,13 +7,14 @@
@Modified By: mashenquan, 2023-11-1. According to RFC 116: Updated the type of index key.
"""
from collections import defaultdict
from typing import DefaultDict, Iterable, Set
from typing import DefaultDict, Iterable, Optional, Set
from pydantic import BaseModel, Field, SerializeAsAny
from metagpt.const import IGNORED_MESSAGE_ID
from metagpt.schema import Message
from metagpt.utils.common import any_to_str, any_to_str_set
from metagpt.utils.exceptions import handle_exception
class Memory(BaseModel):
@ -104,3 +105,8 @@ class Memory(BaseModel):
continue
rsp += self.index[action]
return rsp
@handle_exception
def get_by_position(self, position: int) -> Optional[Message]:
    """Returns the message at the given position if valid, otherwise returns None"""
    # Plain list indexing: negative positions are honored (Python semantics).
    # An out-of-range position raises IndexError, which @handle_exception
    # absorbs — presumably returning None. TODO(review): confirm the
    # decorator's default return value is None.
    return self.storage[position]

View file

@ -0,0 +1,189 @@
"""
This module implements a memory system combining short-term and long-term storage for AI role memory management.
It utilizes a RAG (Retrieval-Augmented Generation) engine for long-term memory storage and retrieval.
"""
from typing import TYPE_CHECKING, Any, Optional
from pydantic import Field
from metagpt.actions import UserRequirement
from metagpt.const import TEAMLEADER_NAME
from metagpt.logs import logger
from metagpt.memory import Memory
from metagpt.schema import LongTermMemoryItem, Message
from metagpt.utils.common import any_to_str
from metagpt.utils.exceptions import handle_exception
if TYPE_CHECKING:
from llama_index.core.schema import NodeWithScore
from metagpt.rag.engines import SimpleEngine
class RoleZeroLongTermMemory(Memory):
    """
    Implements a memory system combining short-term and long-term storage using a RAG engine.
    Transfers old memories to long-term storage when short-term capacity is reached.
    Retrieves combined short-term and long-term memories as needed.
    """

    # Directory where the Chroma collection is persisted between runs.
    persist_path: str = Field(default=".role_memory_data", description="The directory to save data.")
    # One collection per role; RoleZero passes the role name (spaces removed).
    collection_name: str = Field(default="role_zero", description="The name of the collection, such as the role name.")
    # Short-term window size: messages beyond the most recent `memory_k` are
    # candidates for transfer into the RAG store.
    memory_k: int = Field(default=200, description="The capacity of short-term memory.")

    # Pydantic private attribute caching the lazily built engine (see `rag_engine`).
    _rag_engine: Any = None

    @property
    def rag_engine(self) -> "SimpleEngine":
        # Lazily create the engine on first access so the optional RAG stack is
        # only imported when long-term memory is actually used.
        if self._rag_engine is None:
            self._rag_engine = self._resolve_rag_engine()
        return self._rag_engine

    def _resolve_rag_engine(self) -> "SimpleEngine":
        """Lazy loading of the RAG engine components, ensuring they are only loaded when needed.

        It uses `Chroma` for retrieval and `LLMRanker` for ranking.

        Raises:
            ImportError: If the optional rag module is not installed.
        """
        try:
            from metagpt.rag.engines import SimpleEngine
            from metagpt.rag.schema import ChromaRetrieverConfig, LLMRankerConfig
        except ImportError:
            raise ImportError("To use the RoleZeroMemory, you need to install the rag module.")

        retriever_configs = [
            ChromaRetrieverConfig(persist_path=self.persist_path, collection_name=self.collection_name)
        ]
        ranker_configs = [LLMRankerConfig()]

        rag_engine = SimpleEngine.from_objs(retriever_configs=retriever_configs, ranker_configs=ranker_configs)

        return rag_engine

    def add(self, message: Message):
        """Add a new message and potentially transfer it to long-term memory."""
        super().add(message)

        if not self._should_use_longterm_memory_for_add():
            return

        self._transfer_to_longterm_memory()

    def get(self, k=0) -> list[Message]:
        """Return recent memories and optionally combines them with related long-term memories."""
        memories = super().get(k)

        if not self._should_use_longterm_memory_for_get(k=k):
            return memories

        query = self._build_longterm_memory_query()
        related_memories = self._fetch_longterm_memories(query)
        logger.info(f"Fetched {len(related_memories)} long-term memories.")

        # Long-term memories are prepended so the short-term (recent) ones stay last.
        final_memories = related_memories + memories

        return final_memories

    def _should_use_longterm_memory_for_add(self) -> bool:
        """Determines if long-term memory should be used for add."""
        # Only once the short-term window has overflowed.
        return self.count() > self.memory_k

    def _should_use_longterm_memory_for_get(self, k: int) -> bool:
        """Determines if long-term memory should be used for get.

        Long-term memory is used if:
        - k is not 0.
        - The last message is from user requirement.
        - The count of recent memories is greater than self.memory_k.
        """
        conds = [
            k != 0,
            self._is_last_message_from_user_requirement(),
            self.count() > self.memory_k,
        ]

        return all(conds)

    def _transfer_to_longterm_memory(self):
        # Move the message that just slid out of the short-term window into RAG.
        item = self._get_longterm_memory_item()
        self._add_to_longterm_memory(item)

    @handle_exception
    def _get_longterm_memory_item(self) -> Optional[LongTermMemoryItem]:
        """Retrieves the most recent message before the last k messages."""
        # Index -(memory_k + 1) is the newest message outside the recent-k window.
        # If there are not enough messages (or `message` is None), construction
        # fails and @handle_exception absorbs it — presumably yielding None.
        index = -(self.memory_k + 1)
        message = self.get_by_position(index)

        return LongTermMemoryItem(message=message)

    def _add_to_longterm_memory(self, item: LongTermMemoryItem):
        """Adds a long-term memory item to the RAG engine."""
        # `item` may be None when _get_longterm_memory_item failed; skip silently.
        if not item:
            return

        self.rag_engine.add_objs([item])

    def _build_longterm_memory_query(self) -> str:
        """Build the content used to query related long-term memory.

        Default is to get the most recent user message, or an empty string if none is found.
        """
        message = self._get_the_last_message()

        return message.content if message else ""

    def _get_the_last_message(self) -> Optional[Message]:
        # None when the memory is empty.
        if not self.count():
            return None

        return self.get_by_position(-1)

    def _is_last_message_from_user_requirement(self) -> bool:
        """Checks if the last message is from a user requirement or sent by the team leader."""
        message = self._get_the_last_message()

        if not message:
            return False

        is_user_message = message.is_user_message()
        cause_by_user_requirement = message.cause_by == any_to_str(UserRequirement)
        sent_from_team_leader = message.sent_from == TEAMLEADER_NAME

        return is_user_message and (cause_by_user_requirement or sent_from_team_leader)

    def _fetch_longterm_memories(self, query: str) -> list[Message]:
        """Fetches long-term memories based on a query.

        Args:
            query (str): The query string to search for relevant memories.

        Returns:
            list[Message]: A list of user and AI messages related to the query.
        """
        if not query:
            return []

        nodes = self.rag_engine.retrieve(query)
        items = self._get_items_from_nodes(nodes)
        memories = [item.message for item in items]

        return memories

    def _get_items_from_nodes(self, nodes: list["NodeWithScore"]) -> list[LongTermMemoryItem]:
        """Get items from nodes and arrange them in order of their `created_at`."""
        # Each node carries the original LongTermMemoryItem under metadata["obj"];
        # sorting by created_at restores chronological order (oldest first).
        items: list[LongTermMemoryItem] = [node.metadata["obj"] for node in nodes]
        items.sort(key=lambda item: item.created_at)

        return items

View file

@ -23,6 +23,7 @@ Doc: 9, Relevance: 7
- Evaluate the relevance between the question and the documents.
- The relevance score is a number from 1-10 based on how relevant you think the document is to the question.
- Do not include any documents that are not relevant to the question.
- If none of the documents provided contain information that directly answers the question, simply respond with "no relevant documents".
## Constraint
Format: Just print the result in format like **Format Example**.

View file

@ -17,6 +17,7 @@ from metagpt.exp_pool import exp_cache
from metagpt.exp_pool.context_builders import RoleZeroContextBuilder
from metagpt.exp_pool.serializers import RoleZeroSerializer
from metagpt.logs import logger
from metagpt.memory.role_zero_memory import RoleZeroLongTermMemory
from metagpt.prompts.di.role_zero import (
ASK_HUMAN_COMMAND,
CMD_PROMPT,
@ -166,6 +167,24 @@ class RoleZero(Role):
self._update_tool_execution()
return self
@model_validator(mode="after")
def set_longterm_memory(self) -> "RoleZero":
"""Set up long-term memory for the role if enabled in the configuration.
If `enable_longterm_memory` is True, set up long-term memory.
The role name will be used as the collection name.
"""
if self.config.role_zero.enable_longterm_memory:
self.rc.memory = RoleZeroLongTermMemory(
**self.rc.memory.model_dump(),
collection_name=self.name.replace(" ", ""),
memory_k=self.memory_k,
)
logger.info(f"Long-term memory set for role '{self.name}'")
return self
def _update_tool_execution(self):
pass
@ -289,12 +308,12 @@ class RoleZero(Role):
self.rc.memory.add(AIMessage(content=self.command_rsp))
if not ok:
error_msg = commands
self.rc.memory.add(UserMessage(content=error_msg))
self.rc.memory.add(UserMessage(content=error_msg, cause_by=RunCommand))
return error_msg
logger.info(f"Commands: \n{commands}")
outputs = await self._run_commands(commands)
logger.info(f"Commands outputs: \n{outputs}")
self.rc.memory.add(UserMessage(content=outputs))
self.rc.memory.add(UserMessage(content=outputs, cause_by=RunCommand))
return AIMessage(
content=f"I have finished the task, please mark my task as finished. Outputs: {outputs}",

View file

@ -5,6 +5,7 @@ from typing import Annotated
from pydantic import Field
from metagpt.actions.di.run_command import RunCommand
from metagpt.const import TEAMLEADER_NAME
from metagpt.prompts.di.team_leader import (
FINISH_CURRENT_TASK_CMD,
TL_INFO,
@ -19,7 +20,7 @@ from metagpt.tools.tool_registry import register_tool
@register_tool(include_functions=["publish_team_message"])
class TeamLeader(RoleZero):
name: str = "Mike"
name: str = TEAMLEADER_NAME
profile: str = "Team Leader"
goal: str = "Manage a team to assist users"
thought_guidance: str = TL_THOUGHT_GUIDANCE

View file

@ -18,6 +18,7 @@ from __future__ import annotations
import asyncio
import json
import os.path
import time
import uuid
from abc import ABC
from asyncio import Queue, QueueEmpty, wait_for
@ -408,6 +409,12 @@ class Message(BaseModel):
dynamic_class = create_model(class_name, **{key: (value.__class__, ...) for key, value in kvs.items()})
return dynamic_class.model_validate(kvs)
def is_user_message(self) -> bool:
    """Return True when this message's role is "user" (human/user-authored)."""
    return self.role == "user"
def is_ai_message(self) -> bool:
    """Return True when this message's role is "assistant" (AI-authored)."""
    return self.role == "assistant"
class UserMessage(Message):
"""便于支持OpenAI的消息
@ -955,3 +962,11 @@ class BaseEnum(Enum):
obj._value_ = value
obj.desc = desc
return obj
class LongTermMemoryItem(BaseModel):
    """A message wrapped with its creation time, for storage in a RAG engine."""

    # The memory payload.
    message: Message
    # Unix timestamp captured at construction; used to sort retrieved items
    # back into chronological order.
    created_at: Optional[float] = Field(default_factory=time.time)

    def rag_key(self) -> str:
        """Key the RAG engine indexes this item by: the message content."""
        return self.message.content

View file

@ -0,0 +1,164 @@
from datetime import datetime, timedelta
import pytest
from metagpt.actions import UserRequirement
from metagpt.const import TEAMLEADER_NAME
from metagpt.memory.role_zero_memory import RoleZeroLongTermMemory
from metagpt.schema import AIMessage, LongTermMemoryItem, Message, UserMessage
class TestRoleZeroLongTermMemory:
    """Unit tests for RoleZeroLongTermMemory with the RAG engine mocked out."""

    @pytest.fixture
    def mock_memory(self, mocker) -> RoleZeroLongTermMemory:
        """A memory whose RAG engine resolution is replaced by a Mock."""
        memory = RoleZeroLongTermMemory()
        memory._resolve_rag_engine = mocker.Mock()
        return memory

    def test_add(self, mocker, mock_memory: RoleZeroLongTermMemory):
        mock_memory._should_use_longterm_memory_for_add = mocker.Mock(return_value=True)
        mock_memory._transfer_to_longterm_memory = mocker.Mock()

        message = UserMessage(content="test")
        mock_memory.add(message)

        assert mock_memory.storage[-1] == message
        mock_memory._transfer_to_longterm_memory.assert_called_once()

    def test_get(self, mocker, mock_memory: RoleZeroLongTermMemory):
        mock_memory._should_use_longterm_memory_for_get = mocker.Mock(return_value=True)
        mock_memory._build_longterm_memory_query = mocker.Mock(return_value="query")
        mock_memory._fetch_longterm_memories = mocker.Mock(return_value=[Message(content="long-term")])
        mock_memory.storage = [Message(content="short-term")]

        result = mock_memory.get()

        # Long-term memories come first, then the short-term ones.
        assert len(result) == 2
        assert result[0].content == "long-term"
        assert result[1].content == "short-term"

    def test_should_use_longterm_memory_for_add(self, mocker, mock_memory: RoleZeroLongTermMemory):
        # Use bare truthiness asserts, not `== True` / `== False` (ruff E712).
        mocker.patch.object(mock_memory, "storage", [None] * 201)
        mock_memory.memory_k = 200
        assert mock_memory._should_use_longterm_memory_for_add()

        mocker.patch.object(mock_memory, "storage", [None] * 199)
        assert not mock_memory._should_use_longterm_memory_for_add()

    @pytest.mark.parametrize(
        "k,is_last_from_user,count,expected",
        [
            (0, True, 201, False),
            (1, False, 201, False),
            (1, True, 199, False),
            (1, True, 201, True),
        ],
    )
    def test_should_use_longterm_memory_for_get(
        self, mocker, mock_memory: RoleZeroLongTermMemory, k, is_last_from_user, count, expected
    ):
        mock_memory._is_last_message_from_user_requirement = mocker.Mock(return_value=is_last_from_user)
        mocker.patch.object(mock_memory, "storage", [None] * count)
        mock_memory.memory_k = 200

        # Identity check is the correct idiom for an exact boolean result.
        assert mock_memory._should_use_longterm_memory_for_get(k) is expected

    def test_transfer_to_longterm_memory(self, mocker, mock_memory: RoleZeroLongTermMemory):
        mock_item = mocker.Mock()
        mock_memory._get_longterm_memory_item = mocker.Mock(return_value=mock_item)
        mock_memory._add_to_longterm_memory = mocker.Mock()

        mock_memory._transfer_to_longterm_memory()

        mock_memory._add_to_longterm_memory.assert_called_once_with(mock_item)

    def test_get_longterm_memory_item(self, mocker, mock_memory: RoleZeroLongTermMemory):
        mock_message = Message(content="test")
        mock_memory.storage = [mock_message, mock_message]
        mock_memory.memory_k = 1

        result = mock_memory._get_longterm_memory_item()

        assert isinstance(result, LongTermMemoryItem)
        assert result.message == mock_message

    def test_add_to_longterm_memory(self, mock_memory: RoleZeroLongTermMemory):
        item = LongTermMemoryItem(message=Message(content="test"))

        mock_memory._add_to_longterm_memory(item)

        mock_memory.rag_engine.add_objs.assert_called_once_with([item])

    def test_build_longterm_memory_query(self, mocker, mock_memory: RoleZeroLongTermMemory):
        mock_message = Message(content="query")
        mock_memory._get_the_last_message = mocker.Mock(return_value=mock_message)

        result = mock_memory._build_longterm_memory_query()

        assert result == "query"

    def test_get_the_last_message(self, mock_memory: RoleZeroLongTermMemory):
        mock_memory.storage = [Message(content="1"), Message(content="2")]

        result = mock_memory._get_the_last_message()

        assert result.content == "2"

    @pytest.mark.parametrize(
        "message,expected",
        [
            (UserMessage(content="test", cause_by=UserRequirement), True),
            (UserMessage(content="test", sent_from=TEAMLEADER_NAME), True),
            (UserMessage(content="test"), True),
            (AIMessage(content="test"), False),
            (None, False),
        ],
    )
    def test_is_last_message_from_user_requirement(
        self, mocker, mock_memory: RoleZeroLongTermMemory, message, expected
    ):
        mock_memory._get_the_last_message = mocker.Mock(return_value=message)

        assert mock_memory._is_last_message_from_user_requirement() is expected

    def test_fetch_longterm_memories(self, mocker, mock_memory: RoleZeroLongTermMemory):
        mock_nodes = [mocker.Mock(), mocker.Mock()]
        mock_memory.rag_engine.retrieve = mocker.Mock(return_value=mock_nodes)
        mock_items = [
            LongTermMemoryItem(message=UserMessage(content="user1")),
            LongTermMemoryItem(message=AIMessage(content="ai1")),
        ]
        mock_memory._get_items_from_nodes = mocker.Mock(return_value=mock_items)

        result = mock_memory._fetch_longterm_memories("query")

        assert len(result) == 2
        assert result[0].content == "user1"
        assert result[1].content == "ai1"

    def test_get_items_from_nodes(self, mocker, mock_memory: RoleZeroLongTermMemory):
        now = datetime.now()
        mock_nodes = [
            mocker.Mock(
                metadata={
                    "obj": LongTermMemoryItem(
                        message=Message(content="2"), created_at=(now - timedelta(minutes=1)).timestamp()
                    )
                }
            ),
            mocker.Mock(
                metadata={
                    "obj": LongTermMemoryItem(
                        message=Message(content="1"), created_at=(now - timedelta(minutes=2)).timestamp()
                    )
                }
            ),
            mocker.Mock(metadata={"obj": LongTermMemoryItem(message=Message(content="3"), created_at=now.timestamp())}),
        ]

        result = mock_memory._get_items_from_nodes(mock_nodes)

        # Items must come back sorted by created_at, oldest first.
        assert len(result) == 3
        assert [item.message.content for item in result] == ["1", "2", "3"]