mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-04-29 02:46:24 +02:00
update
This commit is contained in:
parent
31dbff5474
commit
5d21d255e4
5 changed files with 95 additions and 10 deletions
|
|
@ -57,11 +57,11 @@ class RoleZeroLongTermMemory(BaseModel):
|
|||
if not query:
|
||||
return []
|
||||
|
||||
nodes: list[NodeWithScore] = self.rag_engine.retrieve(query)
|
||||
nodes = self.rag_engine.retrieve(query)
|
||||
items = self._get_items_from_nodes(nodes)
|
||||
|
||||
memories = []
|
||||
for node in nodes:
|
||||
item: LongTermMemoryItem = node.metadata["obj"]
|
||||
for item in items:
|
||||
memories.append(item.user_message)
|
||||
memories.append(item.ai_message)
|
||||
|
||||
|
|
@ -78,3 +78,11 @@ class RoleZeroLongTermMemory(BaseModel):
|
|||
return
|
||||
|
||||
self.rag_engine.add_objs([item])
|
||||
|
||||
def _get_items_from_nodes(self, nodes: list["NodeWithScore"]) -> list[LongTermMemoryItem]:
|
||||
"""Get items from nodes and arrange them in order of their `created_at`."""
|
||||
|
||||
items: list[LongTermMemoryItem] = [node.metadata["obj"] for node in nodes]
|
||||
items.sort(key=lambda item: item.created_at)
|
||||
|
||||
return items
|
||||
|
|
|
|||
|
|
@ -285,12 +285,12 @@ class RoleZero(Role):
|
|||
self._add_memory(AIMessage(content=self.command_rsp))
|
||||
if not ok:
|
||||
error_msg = commands
|
||||
self._add_memory(UserMessage(content=error_msg))
|
||||
self._add_memory(UserMessage(content=error_msg, cause_by=RunCommand))
|
||||
return error_msg
|
||||
logger.info(f"Commands: \n{commands}")
|
||||
outputs = await self._run_commands(commands)
|
||||
logger.info(f"Commands outputs: \n{outputs}")
|
||||
self._add_memory(UserMessage(content=outputs))
|
||||
self._add_memory(UserMessage(content=outputs, cause_by=RunCommand))
|
||||
|
||||
return AIMessage(
|
||||
content=f"I have finished the task, please mark my task as finished. Outputs: {outputs}",
|
||||
|
|
@ -605,7 +605,8 @@ class RoleZero(Role):
|
|||
if not self._should_use_longterm_memory(k=k, k_memories=memories):
|
||||
return memories
|
||||
|
||||
related_memories = self.longterm_memory.fetch(memories[-1].content)
|
||||
query = self._build_longterm_memory_query(memories)
|
||||
related_memories = self.longterm_memory.fetch(query)
|
||||
logger.info(f"Fetched {len(related_memories)} long-term memories.")
|
||||
|
||||
# Keep user and AI messages are paired.
|
||||
|
|
@ -666,3 +667,11 @@ class RoleZero(Role):
|
|||
|
||||
def _is_first_message_from_ai(self, memories: list[Message]) -> bool:
|
||||
return bool(memories and memories[0].is_ai_message())
|
||||
|
||||
def _build_longterm_memory_query(self, memories: list[Message]) -> str:
|
||||
"""Build the content used to query related long-term memory.
|
||||
|
||||
Default is to get the most recent user message, or an empty string if none is found.
|
||||
"""
|
||||
|
||||
return next((m.content for m in reversed(memories) if m.is_real_user_message()), "")
|
||||
|
|
|
|||
|
|
@ -18,6 +18,7 @@ from __future__ import annotations
|
|||
import asyncio
|
||||
import json
|
||||
import os.path
|
||||
import time
|
||||
import uuid
|
||||
from abc import ABC
|
||||
from asyncio import Queue, QueueEmpty, wait_for
|
||||
|
|
@ -408,12 +409,15 @@ class Message(BaseModel):
|
|||
dynamic_class = create_model(class_name, **{key: (value.__class__, ...) for key, value in kvs.items()})
|
||||
return dynamic_class.model_validate(kvs)
|
||||
|
||||
def is_user_message(self):
|
||||
def is_user_message(self) -> bool:
|
||||
return self.role == "user"
|
||||
|
||||
def is_ai_message(self):
|
||||
def is_ai_message(self) -> bool:
|
||||
return self.role == "assistant"
|
||||
|
||||
def is_real_user_message(self) -> bool:
|
||||
return self.is_user_message() and "UserRequirement" in self.cause_by
|
||||
|
||||
|
||||
class UserMessage(Message):
|
||||
"""便于支持OpenAI的消息
|
||||
|
|
@ -966,6 +970,7 @@ class BaseEnum(Enum):
|
|||
class LongTermMemoryItem(BaseModel):
|
||||
user_message: Message
|
||||
ai_message: Message
|
||||
created_at: Optional[float] = Field(default_factory=time.time)
|
||||
|
||||
def rag_key(self) -> str:
|
||||
return self.user_message.content
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
from datetime import datetime, timedelta
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.memory.role_zero_memory import RoleZeroLongTermMemory
|
||||
|
|
@ -46,3 +48,36 @@ class TestRoleZeroLongTermMemory:
|
|||
item = LongTermMemoryItem(user_message=UserMessage(content="user"), ai_message=AIMessage(content="ai"))
|
||||
mock_memory.add(item)
|
||||
mock_memory.rag_engine.add_objs.assert_called_once_with([item])
|
||||
|
||||
def test_get_items_from_nodes(self, mocker, mock_memory: RoleZeroLongTermMemory):
|
||||
mock_node1 = mocker.Mock()
|
||||
mock_node2 = mocker.Mock()
|
||||
mock_node3 = mocker.Mock()
|
||||
|
||||
now = datetime.now()
|
||||
item1 = LongTermMemoryItem(
|
||||
user_message=UserMessage(content="user1"), ai_message=AIMessage(content="ai1"), created_at=now.timestamp()
|
||||
)
|
||||
item2 = LongTermMemoryItem(
|
||||
user_message=UserMessage(content="user2"),
|
||||
ai_message=AIMessage(content="ai2"),
|
||||
created_at=(now - timedelta(minutes=5)).timestamp(),
|
||||
)
|
||||
item3 = LongTermMemoryItem(
|
||||
user_message=UserMessage(content="user3"),
|
||||
ai_message=AIMessage(content="ai3"),
|
||||
created_at=(now + timedelta(minutes=5)).timestamp(),
|
||||
)
|
||||
|
||||
mock_node1.metadata = {"obj": item1}
|
||||
mock_node2.metadata = {"obj": item2}
|
||||
mock_node3.metadata = {"obj": item3}
|
||||
|
||||
result = mock_memory._get_items_from_nodes([mock_node1, mock_node2, mock_node3])
|
||||
|
||||
assert len(result) == 3
|
||||
assert result[0] == item2
|
||||
assert result[1] == item1
|
||||
assert result[2] == item3
|
||||
assert [item.user_message.content for item in result] == ["user2", "user1", "user3"]
|
||||
assert [item.ai_message.content for item in result] == ["ai2", "ai1", "ai3"]
|
||||
|
|
|
|||
|
|
@ -46,7 +46,13 @@ class TestRoleZero:
|
|||
[AIMessage(content="ai1"), UserMessage(content="user"), AIMessage(content="ai2")],
|
||||
[Message(content="related")],
|
||||
True,
|
||||
[Message(content="related"), UserMessage(content="user"), AIMessage(content="ai2")],
|
||||
[
|
||||
Message(content="related"),
|
||||
UserMessage(content="user"),
|
||||
AIMessage(content="ai1"),
|
||||
UserMessage(content="user"),
|
||||
AIMessage(content="ai2"),
|
||||
],
|
||||
),
|
||||
(
|
||||
None,
|
||||
|
|
@ -79,6 +85,7 @@ class TestRoleZero:
|
|||
):
|
||||
mock_role_zero.memory_k = 2
|
||||
mock_role_zero.rc.memory.get = mocker.Mock(return_value=memories)
|
||||
mock_role_zero.rc.memory.get_by_position = mocker.Mock(return_value=UserMessage(content="user"))
|
||||
mock_role_zero._should_use_longterm_memory = mocker.Mock(return_value=should_use_ltm)
|
||||
mock_role_zero.longterm_memory.fetch = mocker.Mock(return_value=related_memories)
|
||||
mock_role_zero._is_first_message_from_ai = mocker.Mock(return_value=is_first_from_ai)
|
||||
|
|
@ -97,7 +104,7 @@ class TestRoleZero:
|
|||
mock_role_zero._should_use_longterm_memory.assert_called_once_with(k=really_k, k_memories=memories)
|
||||
|
||||
if should_use_ltm:
|
||||
mock_role_zero.longterm_memory.fetch.assert_called_once_with(memories[-1].content)
|
||||
mock_role_zero.longterm_memory.fetch.assert_called_once_with("user")
|
||||
mock_role_zero._is_first_message_from_ai.assert_called_once_with(memories)
|
||||
|
||||
def test_add_memory(self, mocker, mock_role_zero: RoleZero):
|
||||
|
|
@ -190,3 +197,24 @@ class TestRoleZero:
|
|||
def test_is_first_message_from_ai(self, mock_role_zero: RoleZero, memories, expected):
|
||||
result = mock_role_zero._is_first_message_from_ai(memories)
|
||||
assert result == expected
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"memories,expected",
|
||||
[
|
||||
([UserMessage(content="user1"), AIMessage(content="ai"), UserMessage(content="user2")], "user2"),
|
||||
(
|
||||
[
|
||||
UserMessage(content="user1", cause_by="test"),
|
||||
AIMessage(content="ai"),
|
||||
UserMessage(content="user2", cause_by="test"),
|
||||
],
|
||||
"",
|
||||
),
|
||||
([AIMessage(content="ai1"), AIMessage(content="ai2")], ""),
|
||||
([UserMessage(content="user")], "user"),
|
||||
([], ""),
|
||||
],
|
||||
)
|
||||
def test_build_longterm_memory_query(self, mock_role_zero: RoleZero, memories, expected):
|
||||
result = mock_role_zero._build_longterm_memory_query(memories)
|
||||
assert result == expected
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue