diff --git a/.gitignore b/.gitignore
index 24dd046be..73f32f75d 100644
--- a/.gitignore
+++ b/.gitignore
@@ -163,6 +163,7 @@ examples/image__vector_store.json
 examples/index_store.json
 .chroma
 .chroma_exp_data
+.role_memory_data
 *~$*
 workspace/*
 tmp
diff --git a/config/config2.example.yaml b/config/config2.example.yaml
index 2a0ebcc47..e4dfff1eb 100644
--- a/config/config2.example.yaml
+++ b/config/config2.example.yaml
@@ -83,6 +83,8 @@ exp_pool:
   use_llm_ranker: true # Default is `true`, it will use LLM Reranker to get better result.
   collection_name: experience_pool # When `retrieval_type` is `chroma`, `collection_name` is the collection name in chromadb.
 
+role_zero:
+  enable_longterm_memory: false # Whether to use long-term memory. Default is `false`.
 
 azure_tts_subscription_key: "YOUR_SUBSCRIPTION_KEY"
 azure_tts_region: "eastus"
diff --git a/metagpt/config2.py b/metagpt/config2.py
index 7b6ddf8c6..fd0cb0948 100644
--- a/metagpt/config2.py
+++ b/metagpt/config2.py
@@ -19,6 +19,7 @@ from metagpt.configs.mermaid_config import MermaidConfig
 from metagpt.configs.omniparse_config import OmniParseConfig
 from metagpt.configs.redis_config import RedisConfig
 from metagpt.configs.role_custom_config import RoleCustomConfig
+from metagpt.configs.role_zero_config import RoleZeroConfig
 from metagpt.configs.s3_config import S3Config
 from metagpt.configs.search_config import SearchConfig
 from metagpt.configs.workspace_config import WorkspaceConfig
@@ -89,6 +90,9 @@ class Config(CLIParams, YamlModel):
     # Role's custom configuration
     roles: Optional[List[RoleCustomConfig]] = None
 
+    # RoleZero's configuration
+    role_zero: RoleZeroConfig = Field(default_factory=RoleZeroConfig)
+
     omniparse: Optional[OmniParseConfig] = None
 
     @classmethod
diff --git a/metagpt/configs/role_zero_config.py b/metagpt/configs/role_zero_config.py
new file mode 100644
index 000000000..27103ddf6
--- /dev/null
+++ b/metagpt/configs/role_zero_config.py
@@ -0,0 +1,7 @@
+from pydantic import Field
+
+from metagpt.utils.yaml_model import YamlModel
+
+
+class RoleZeroConfig(YamlModel):
+    enable_longterm_memory: bool = Field(default=False, description="Whether to use long-term memory.")
diff --git a/metagpt/const.py b/metagpt/const.py
index c53e8494a..4fe8dca3d 100644
--- a/metagpt/const.py
+++ b/metagpt/const.py
@@ -157,3 +157,6 @@ SWE_SETUP_PATH = get_metagpt_package_root() / "metagpt/tools/swe_agent_commands/
 
 # experience pool
 EXPERIENCE_MASK = ""
+
+# TeamLeader's name
+TEAMLEADER_NAME = "Mike"
diff --git a/metagpt/environment/mgx/mgx_env.py b/metagpt/environment/mgx/mgx_env.py
index 4df04d3ce..53690f7d7 100644
--- a/metagpt/environment/mgx/mgx_env.py
+++ b/metagpt/environment/mgx/mgx_env.py
@@ -8,7 +8,7 @@ from metagpt.actions import (
     WriteTest,
 )
 from metagpt.actions.summarize_code import SummarizeCode
-from metagpt.const import AGENT, IMAGES
+from metagpt.const import AGENT, IMAGES, TEAMLEADER_NAME
 from metagpt.environment.base_env import Environment
 from metagpt.logs import get_human_input
 from metagpt.roles import Architect, ProductManager, ProjectManager, Role
@@ -31,7 +31,7 @@ class MGXEnv(Environment, SerializationMixin):
         """let the team leader take over message publishing"""
         message = self.attach_images(message)  # for multi-modal message
 
-        tl = self.get_role("Mike")  # TeamLeader's name is Mike
+        tl = self.get_role(TEAMLEADER_NAME)  # TeamLeader's name is Mike
 
         if user_defined_recipient:
             # human user's direct chat message to a certain role
diff --git a/metagpt/memory/memory.py b/metagpt/memory/memory.py
index 580361d33..0707a36ea 100644
--- a/metagpt/memory/memory.py
+++ b/metagpt/memory/memory.py
@@ -7,13 +7,14 @@
 @Modified By: mashenquan, 2023-11-1. According to RFC 116: Updated the type of index key.
 """
 from collections import defaultdict
-from typing import DefaultDict, Iterable, Set
+from typing import DefaultDict, Iterable, Optional, Set
 
 from pydantic import BaseModel, Field, SerializeAsAny
 
 from metagpt.const import IGNORED_MESSAGE_ID
 from metagpt.schema import Message
 from metagpt.utils.common import any_to_str, any_to_str_set
+from metagpt.utils.exceptions import handle_exception
 
 
 class Memory(BaseModel):
@@ -104,3 +105,8 @@ class Memory(BaseModel):
                 continue
             rsp += self.index[action]
         return rsp
+
+    @handle_exception
+    def get_by_position(self, position: int) -> Optional[Message]:
+        """Returns the message at the given position if valid, otherwise returns None"""
+        return self.storage[position]
diff --git a/metagpt/memory/role_zero_memory.py b/metagpt/memory/role_zero_memory.py
new file mode 100644
index 000000000..857f2473b
--- /dev/null
+++ b/metagpt/memory/role_zero_memory.py
@@ -0,0 +1,189 @@
+"""
+This module implements a memory system combining short-term and long-term storage for AI role memory management.
+It utilizes a RAG (Retrieval-Augmented Generation) engine for long-term memory storage and retrieval.
+"""
+
+from typing import TYPE_CHECKING, Any, Optional
+
+from pydantic import Field
+
+from metagpt.actions import UserRequirement
+from metagpt.const import TEAMLEADER_NAME
+from metagpt.logs import logger
+from metagpt.memory import Memory
+from metagpt.schema import LongTermMemoryItem, Message
+from metagpt.utils.common import any_to_str
+from metagpt.utils.exceptions import handle_exception
+
+if TYPE_CHECKING:
+    from llama_index.core.schema import NodeWithScore
+
+    from metagpt.rag.engines import SimpleEngine
+
+
+class RoleZeroLongTermMemory(Memory):
+    """
+    Implements a memory system combining short-term and long-term storage using a RAG engine.
+    Transfers old memories to long-term storage when short-term capacity is reached.
+    Retrieves combined short-term and long-term memories as needed.
+    """
+
+    persist_path: str = Field(default=".role_memory_data", description="The directory to save data.")
+    collection_name: str = Field(default="role_zero", description="The name of the collection, such as the role name.")
+    memory_k: int = Field(default=200, description="The capacity of short-term memory.")
+
+    _rag_engine: Any = None
+
+    @property
+    def rag_engine(self) -> "SimpleEngine":
+        if self._rag_engine is None:
+            self._rag_engine = self._resolve_rag_engine()
+
+        return self._rag_engine
+
+    def _resolve_rag_engine(self) -> "SimpleEngine":
+        """Lazy loading of the RAG engine components, ensuring they are only loaded when needed.
+
+        It uses `Chroma` for retrieval and `LLMRanker` for ranking.
+ """ + + try: + from metagpt.rag.engines import SimpleEngine + from metagpt.rag.schema import ChromaRetrieverConfig, LLMRankerConfig + except ImportError: + raise ImportError("To use the RoleZeroMemory, you need to install the rag module.") + + retriever_configs = [ + ChromaRetrieverConfig(persist_path=self.persist_path, collection_name=self.collection_name) + ] + ranker_configs = [LLMRankerConfig()] + + rag_engine = SimpleEngine.from_objs(retriever_configs=retriever_configs, ranker_configs=ranker_configs) + + return rag_engine + + def add(self, message: Message): + """Add a new message and potentially transfer it to long-term memory.""" + + super().add(message) + + if not self._should_use_longterm_memory_for_add(): + return + + self._transfer_to_longterm_memory() + + def get(self, k=0) -> list[Message]: + """Return recent memories and optionally combines them with related long-term memories.""" + + memories = super().get(k) + + if not self._should_use_longterm_memory_for_get(k=k): + return memories + + query = self._build_longterm_memory_query() + related_memories = self._fetch_longterm_memories(query) + logger.info(f"Fetched {len(related_memories)} long-term memories.") + + final_memories = related_memories + memories + + return final_memories + + def _should_use_longterm_memory_for_add(self) -> bool: + """Determines if long-term memory should be used for add.""" + + return self.count() > self.memory_k + + def _should_use_longterm_memory_for_get(self, k: int) -> bool: + """Determines if long-term memory should be used for get. + + Long-term memory is used if: + - k is not 0. + - The last message is from user requirement. + - The count of recent memories is greater than self.memory_k. + """ + + conds = [ + k != 0, + self._is_last_message_from_user_requirement(), + self.count() > self.memory_k, + ] + + return all(conds) + + def _transfer_to_longterm_memory(self): + item = self._get_longterm_memory_item() + self._add_to_longterm_memory(item) + + @handle_exception + def _get_longterm_memory_item(self) -> Optional[LongTermMemoryItem]: + """Retrieves the most recent message before the last k messages.""" + + index = -(self.memory_k + 1) + message = self.get_by_position(index) + + return LongTermMemoryItem(message=message) + + def _add_to_longterm_memory(self, item: LongTermMemoryItem): + """Adds a long-term memory item to the RAG engine.""" + + if not item: + return + + self.rag_engine.add_objs([item]) + + def _build_longterm_memory_query(self) -> str: + """Build the content used to query related long-term memory. + + Default is to get the most recent user message, or an empty string if none is found. + """ + + message = self._get_the_last_message() + + return message.content if message else "" + + def _get_the_last_message(self) -> Optional[Message]: + if not self.count(): + return None + + return self.get_by_position(-1) + + def _is_last_message_from_user_requirement(self) -> bool: + """Checks if the last message is from a user requirement or sent by the team leader.""" + + message = self._get_the_last_message() + + if not message: + return False + + is_user_message = message.is_user_message() + cause_by_user_requirement = message.cause_by == any_to_str(UserRequirement) + sent_from_team_leader = message.sent_from == TEAMLEADER_NAME + + return is_user_message and (cause_by_user_requirement or sent_from_team_leader) + + def _fetch_longterm_memories(self, query: str) -> list[Message]: + """Fetches long-term memories based on a query. 
+
+        Args:
+            query (str): The query string to search for relevant memories.
+
+        Returns:
+            list[Message]: A list of user and AI messages related to the query.
+        """
+
+        if not query:
+            return []
+
+        nodes = self.rag_engine.retrieve(query)
+        items = self._get_items_from_nodes(nodes)
+        memories = [item.message for item in items]
+
+        return memories
+
+    def _get_items_from_nodes(self, nodes: list["NodeWithScore"]) -> list[LongTermMemoryItem]:
+        """Get items from nodes and arrange them in ascending order of their `created_at`."""
+
+        items: list[LongTermMemoryItem] = [node.metadata["obj"] for node in nodes]
+        items.sort(key=lambda item: item.created_at)
+
+        return items
diff --git a/metagpt/rag/prompts/default_prompts.py b/metagpt/rag/prompts/default_prompts.py
index 12a5e2f06..eadcaa770 100644
--- a/metagpt/rag/prompts/default_prompts.py
+++ b/metagpt/rag/prompts/default_prompts.py
@@ -23,6 +23,7 @@ Doc: 9, Relevance: 7
 - Evaluate the relevance between the question and the documents.
 - The relevance score is a number from 1-10 based on how relevant you think the document is to the question.
 - Do not include any documents that are not relevant to the question.
+- If none of the documents provided contain information that directly answers the question, simply respond with "no relevant documents".
 
 ## Constraint
 Format: Just print the result in format like **Format Example**.
diff --git a/metagpt/roles/di/role_zero.py b/metagpt/roles/di/role_zero.py
index 0e7e04969..7cbac2d04 100644
--- a/metagpt/roles/di/role_zero.py
+++ b/metagpt/roles/di/role_zero.py
@@ -17,6 +17,7 @@ from metagpt.exp_pool import exp_cache
 from metagpt.exp_pool.context_builders import RoleZeroContextBuilder
 from metagpt.exp_pool.serializers import RoleZeroSerializer
 from metagpt.logs import logger
+from metagpt.memory.role_zero_memory import RoleZeroLongTermMemory
 from metagpt.prompts.di.role_zero import (
     ASK_HUMAN_COMMAND,
     CMD_PROMPT,
@@ -166,6 +167,24 @@ class RoleZero(Role):
         self._update_tool_execution()
         return self
 
+    @model_validator(mode="after")
+    def set_longterm_memory(self) -> "RoleZero":
+        """Set up long-term memory for the role if enabled in the configuration.
+
+        If `enable_longterm_memory` is True, the default memory is replaced with `RoleZeroLongTermMemory`,
+        and the role name (with spaces removed) is used as the collection name.
+        """
+
+        if self.config.role_zero.enable_longterm_memory:
+            self.rc.memory = RoleZeroLongTermMemory(
+                **self.rc.memory.model_dump(),
+                collection_name=self.name.replace(" ", ""),
+                memory_k=self.memory_k,
+            )
+            logger.info(f"Long-term memory set for role '{self.name}'")
+
+        return self
+
     def _update_tool_execution(self):
         pass
 
@@ -289,12 +308,12 @@ class RoleZero(Role):
         self.rc.memory.add(AIMessage(content=self.command_rsp))
         if not ok:
             error_msg = commands
-            self.rc.memory.add(UserMessage(content=error_msg))
+            self.rc.memory.add(UserMessage(content=error_msg, cause_by=RunCommand))
             return error_msg
         logger.info(f"Commands: \n{commands}")
         outputs = await self._run_commands(commands)
         logger.info(f"Commands outputs: \n{outputs}")
-        self.rc.memory.add(UserMessage(content=outputs))
+        self.rc.memory.add(UserMessage(content=outputs, cause_by=RunCommand))
         return AIMessage(
             content=f"I have finished the task, please mark my task as finished. Outputs: {outputs}",
diff --git a/metagpt/roles/di/team_leader.py b/metagpt/roles/di/team_leader.py
index 112ca5a84..0724ffdea 100644
--- a/metagpt/roles/di/team_leader.py
+++ b/metagpt/roles/di/team_leader.py
@@ -5,6 +5,7 @@ from typing import Annotated
 from pydantic import Field
 
 from metagpt.actions.di.run_command import RunCommand
+from metagpt.const import TEAMLEADER_NAME
 from metagpt.prompts.di.team_leader import (
     FINISH_CURRENT_TASK_CMD,
     TL_INFO,
@@ -19,7 +20,7 @@ from metagpt.tools.tool_registry import register_tool
 
 @register_tool(include_functions=["publish_team_message"])
 class TeamLeader(RoleZero):
-    name: str = "Mike"
+    name: str = TEAMLEADER_NAME
     profile: str = "Team Leader"
     goal: str = "Manage a team to assist users"
     thought_guidance: str = TL_THOUGHT_GUIDANCE
diff --git a/metagpt/schema.py b/metagpt/schema.py
index 13814b1d6..5224eaf14 100644
--- a/metagpt/schema.py
+++ b/metagpt/schema.py
@@ -18,6 +18,7 @@ from __future__ import annotations
 import asyncio
 import json
 import os.path
+import time
 import uuid
 from abc import ABC
 from asyncio import Queue, QueueEmpty, wait_for
@@ -408,6 +409,12 @@ class Message(BaseModel):
         dynamic_class = create_model(class_name, **{key: (value.__class__, ...) for key, value in kvs.items()})
         return dynamic_class.model_validate(kvs)
 
+    def is_user_message(self) -> bool:
+        return self.role == "user"
+
+    def is_ai_message(self) -> bool:
+        return self.role == "assistant"
+
 
 class UserMessage(Message):
     """A message type for convenient support of OpenAI-style messages
@@ -955,3 +962,11 @@ class BaseEnum(Enum):
         obj._value_ = value
         obj.desc = desc
         return obj
+
+
+class LongTermMemoryItem(BaseModel):
+    message: Message
+    created_at: Optional[float] = Field(default_factory=time.time)
+
+    def rag_key(self) -> str:
+        return self.message.content
diff --git a/tests/metagpt/memory/test_role_zero_memory.py b/tests/metagpt/memory/test_role_zero_memory.py
new file mode 100644
index 000000000..80eb58e49
--- /dev/null
+++ b/tests/metagpt/memory/test_role_zero_memory.py
@@ -0,0 +1,164 @@
+from datetime import datetime, timedelta
+
+import pytest
+
+from metagpt.actions import UserRequirement
+from metagpt.const import TEAMLEADER_NAME
+from metagpt.memory.role_zero_memory import RoleZeroLongTermMemory
+from metagpt.schema import AIMessage, LongTermMemoryItem, Message, UserMessage
+
+
+class TestRoleZeroLongTermMemory:
+    @pytest.fixture
+    def mock_memory(self, mocker) -> RoleZeroLongTermMemory:
+        memory = RoleZeroLongTermMemory()
+        memory._resolve_rag_engine = mocker.Mock()
+        return memory
+
+    def test_add(self, mocker, mock_memory: RoleZeroLongTermMemory):
+        mock_memory._should_use_longterm_memory_for_add = mocker.Mock(return_value=True)
+        mock_memory._transfer_to_longterm_memory = mocker.Mock()
+
+        message = UserMessage(content="test")
+        mock_memory.add(message)
+
+        assert mock_memory.storage[-1] == message
+        mock_memory._transfer_to_longterm_memory.assert_called_once()
+
+    def test_get(self, mocker, mock_memory: RoleZeroLongTermMemory):
+        mock_memory._should_use_longterm_memory_for_get = mocker.Mock(return_value=True)
+        mock_memory._build_longterm_memory_query = mocker.Mock(return_value="query")
+        mock_memory._fetch_longterm_memories = mocker.Mock(return_value=[Message(content="long-term")])
+
+        mock_memory.storage = [Message(content="short-term")]
+
+        result = mock_memory.get()
+
+        assert len(result) == 2
+        assert result[0].content == "long-term"
+        assert result[1].content == "short-term"
+
+    def test_should_use_longterm_memory_for_add(self, mocker, mock_memory: RoleZeroLongTermMemory):
+        mocker.patch.object(mock_memory, "storage", [None] * 201)
+        mock_memory.memory_k = 200
+
+        assert mock_memory._should_use_longterm_memory_for_add() is True
+
+        mocker.patch.object(mock_memory, "storage", [None] * 199)
+        assert mock_memory._should_use_longterm_memory_for_add() is False
+
+    @pytest.mark.parametrize(
+        "k,is_last_from_user,count,expected",
+        [
+            (0, True, 201, False),
+            (1, False, 201, False),
+            (1, True, 199, False),
+            (1, True, 201, True),
+        ],
+    )
+    def test_should_use_longterm_memory_for_get(
+        self, mocker, mock_memory: RoleZeroLongTermMemory, k, is_last_from_user, count, expected
+    ):
+        mock_memory._is_last_message_from_user_requirement = mocker.Mock(return_value=is_last_from_user)
+        mocker.patch.object(mock_memory, "storage", [None] * count)
+        mock_memory.memory_k = 200
+
+        assert mock_memory._should_use_longterm_memory_for_get(k) == expected
+
+    def test_transfer_to_longterm_memory(self, mocker, mock_memory: RoleZeroLongTermMemory):
+        mock_item = mocker.Mock()
+        mock_memory._get_longterm_memory_item = mocker.Mock(return_value=mock_item)
+        mock_memory._add_to_longterm_memory = mocker.Mock()
+
+        mock_memory._transfer_to_longterm_memory()
+
+        mock_memory._add_to_longterm_memory.assert_called_once_with(mock_item)
+
+    def test_get_longterm_memory_item(self, mocker, mock_memory: RoleZeroLongTermMemory):
+        mock_message = Message(content="test")
+        mock_memory.storage = [mock_message, mock_message]
+        mock_memory.memory_k = 1
+
+        result = mock_memory._get_longterm_memory_item()
+
+        assert isinstance(result, LongTermMemoryItem)
+        assert result.message == mock_message
+
+    def test_add_to_longterm_memory(self, mock_memory: RoleZeroLongTermMemory):
+        item = LongTermMemoryItem(message=Message(content="test"))
+        mock_memory._add_to_longterm_memory(item)
+
+        mock_memory.rag_engine.add_objs.assert_called_once_with([item])
+
+    def test_build_longterm_memory_query(self, mocker, mock_memory: RoleZeroLongTermMemory):
+        mock_message = Message(content="query")
+        mock_memory._get_the_last_message = mocker.Mock(return_value=mock_message)
+
+        result = mock_memory._build_longterm_memory_query()
+
+        assert result == "query"
+
+    def test_get_the_last_message(self, mock_memory: RoleZeroLongTermMemory):
+        mock_memory.storage = [Message(content="1"), Message(content="2")]
+
+        result = mock_memory._get_the_last_message()
+
+        assert result.content == "2"
+
+    @pytest.mark.parametrize(
+        "message,expected",
+        [
+            (UserMessage(content="test", cause_by=UserRequirement), True),
+            (UserMessage(content="test", sent_from=TEAMLEADER_NAME), True),
+            (UserMessage(content="test"), True),
+            (AIMessage(content="test"), False),
+            (None, False),
+        ],
+    )
+    def test_is_last_message_from_user_requirement(
+        self, mocker, mock_memory: RoleZeroLongTermMemory, message, expected
+    ):
+        mock_memory._get_the_last_message = mocker.Mock(return_value=message)
+
+        assert mock_memory._is_last_message_from_user_requirement() == expected
+
+    def test_fetch_longterm_memories(self, mocker, mock_memory: RoleZeroLongTermMemory):
+        mock_nodes = [mocker.Mock(), mocker.Mock()]
+        mock_memory.rag_engine.retrieve = mocker.Mock(return_value=mock_nodes)
+        mock_items = [
+            LongTermMemoryItem(message=UserMessage(content="user1")),
+            LongTermMemoryItem(message=AIMessage(content="ai1")),
+        ]
+        mock_memory._get_items_from_nodes = mocker.Mock(return_value=mock_items)
+
+        result = mock_memory._fetch_longterm_memories("query")
+
+        assert len(result) == 2
+        assert result[0].content == "user1"
+        assert result[1].content == "ai1"
+
+    def test_get_items_from_nodes(self, mocker, mock_memory: RoleZeroLongTermMemory):
+        now = datetime.now()
+        mock_nodes = [
+            mocker.Mock(
+                metadata={
+                    "obj": LongTermMemoryItem(
+                        message=Message(content="2"), created_at=(now - timedelta(minutes=1)).timestamp()
+                    )
+                }
+            ),
+            mocker.Mock(
+                metadata={
+                    "obj": LongTermMemoryItem(
+                        message=Message(content="1"), created_at=(now - timedelta(minutes=2)).timestamp()
+                    )
+                }
+            ),
+            mocker.Mock(metadata={"obj": LongTermMemoryItem(message=Message(content="3"), created_at=now.timestamp())}),
+        ]
+
+        result = mock_memory._get_items_from_nodes(mock_nodes)
+
+        assert len(result) == 3
+        assert [item.message.content for item in result] == ["1", "2", "3"]
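
Usage sketch (illustrative only, not part of the patch). The snippet below shows how the memory introduced above behaves, assuming the optional rag module and its Chroma/LLMRanker dependencies are installed and an LLM is configured. In normal use the class is wired in automatically by setting `role_zero.enable_longterm_memory: true` in `config2.yaml`; here it is constructed directly, and the tiny `memory_k` (default 200) is chosen only to make the short-term overflow visible:

    from metagpt.memory.role_zero_memory import RoleZeroLongTermMemory
    from metagpt.schema import UserMessage

    # Short-term capacity of 2: the third add() makes count() > memory_k, so the
    # message at index -(memory_k + 1) is wrapped in a LongTermMemoryItem and
    # persisted to the Chroma collection under persist_path.
    memory = RoleZeroLongTermMemory(collection_name="DemoRole", memory_k=2)

    memory.add(UserMessage(content="Build a snake game"))
    memory.add(UserMessage(content="Step 1 done"))
    memory.add(UserMessage(content="Step 2 done"))  # oldest message moves to long-term storage

    # get(k=0) stays purely short-term. With k != 0, a user-requirement (or
    # team-leader) last message, and count() > memory_k, related long-term
    # memories are retrieved through the RAG engine and prepended.
    recent_only = memory.get()   # short-term only
    combined = memory.get(k=2)   # long-term hits (if any) + the last 2 messages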