mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-06-17 15:35:21 +02:00
memory & retrieve 测试
This commit is contained in:
parent
6e0701eb7e
commit
11bef70fd7
5 changed files with 167 additions and 88 deletions
|
|
@ -128,12 +128,11 @@ class AgentMemory(Memory):
|
|||
"""
|
||||
将MemoryBasic类存储为Nodes.json形式。复现GA中的Kw Strength.json形式
|
||||
这里添加一个路径即可
|
||||
TODO 这里需要添加Const常量
|
||||
TODO 这里在存储时候进行倒序存储,之后需要验证(test_memory通过)
|
||||
"""
|
||||
|
||||
memory_json = dict()
|
||||
for i in range(len(self.storage)):
|
||||
memory_node = self.storage[i]
|
||||
memory_node = self.storage[len(self.storage)-i-1]
|
||||
memory_node = memory_node.save_to_dict()
|
||||
memory_json.update(memory_node)
|
||||
with open(memory_saved + "/nodes.json", "w") as outfile:
|
||||
|
|
@ -175,14 +174,15 @@ class AgentMemory(Memory):
|
|||
poignancy = node_details["poignancy"]
|
||||
keywords = set(node_details["keywords"])
|
||||
filling = node_details["filling"]
|
||||
|
||||
# print(node_type)
|
||||
if node_type == "event":
|
||||
self.add_event(created, expiration, s, p, o,
|
||||
description, keywords, poignancy, embedding_pair, filling)
|
||||
elif node_type == "chat":
|
||||
cause_by = node_details["cause_by"]
|
||||
# cause_by = node_details["cause_by"]
|
||||
logger.info(f"{node_id}")
|
||||
self.add_chat(created, expiration, s, p, o,
|
||||
description, keywords, poignancy, embedding_pair, filling, cause_by)
|
||||
description, keywords, poignancy, embedding_pair, filling)
|
||||
elif node_type == "thought":
|
||||
self.add_thought(created, expiration, s, p, o,
|
||||
description, keywords, poignancy, embedding_pair, filling)
|
||||
|
|
@ -198,11 +198,11 @@ class AgentMemory(Memory):
|
|||
Add a new message to storage, while updating the index
|
||||
重写add方法,修改原有的Message类为BasicMemory类,并添加不同的记忆类型添加方式
|
||||
"""
|
||||
if memory_basic in self.storage:
|
||||
if memory_basic.memory_id in self.storage:
|
||||
return
|
||||
self.storage.append(memory_basic)
|
||||
if memory_basic.cause_by:
|
||||
self.index[memory_basic.cause_by][0:0] = [memory_basic]
|
||||
if memory_basic.memory_type == "chat":
|
||||
self.chat_list[0:0] = [memory_basic]
|
||||
return
|
||||
if memory_basic.memory_type == "thought":
|
||||
self.thought_list[0:0] = [memory_basic]
|
||||
|
|
@ -213,14 +213,14 @@ class AgentMemory(Memory):
|
|||
def add_chat(self, created, expiration, s, p, o,
|
||||
content, keywords, poignancy,
|
||||
embedding_pair, filling,
|
||||
cause_by):
|
||||
cause_by = ''):
|
||||
"""
|
||||
调用add方法,初始化chat,在创建的时候就需要调用embedding函数
|
||||
"""
|
||||
memory_count = len(self.storage) + 1
|
||||
type_count = len(self.thought_list) + 1
|
||||
memory_type = "chat"
|
||||
memory_id = f"memory_{str(memory_count)}"
|
||||
memory_id = f"node_{str(memory_count)}"
|
||||
depth = 1
|
||||
|
||||
memory_node = BasicMemory(memory_id, memory_count, type_count, memory_type, depth,
|
||||
|
|
@ -251,7 +251,7 @@ class AgentMemory(Memory):
|
|||
memory_count = len(self.storage) + 1
|
||||
type_count = len(self.thought_list) + 1
|
||||
memory_type = "event"
|
||||
memory_id = f"memory_{str(memory_count)}"
|
||||
memory_id = f"node_{str(memory_count)}"
|
||||
depth = 1
|
||||
|
||||
try:
|
||||
|
|
@ -296,7 +296,7 @@ class AgentMemory(Memory):
|
|||
memory_count = len(self.storage) + 1
|
||||
type_count = len(self.event_list) + 1
|
||||
memory_type = "event"
|
||||
memory_id = f"memory_{str(memory_count)}"
|
||||
memory_id = f"node_{str(memory_count)}"
|
||||
depth = 0
|
||||
|
||||
if "(" in content:
|
||||
|
|
|
|||
|
|
@ -9,10 +9,10 @@ from numpy.linalg import norm
|
|||
|
||||
from examples.st_game.memory.agent_memory import BasicMemory
|
||||
from examples.st_game.utils.utils import get_embedding
|
||||
from examples.st_game.roles.st_role import STRole
|
||||
from metagpt.logs import logger
|
||||
|
||||
|
||||
def agent_retrieve(curr_time: datetime.datetime, memory_forget: float, query: str, nodes: list[BasicMemory],
|
||||
def agent_retrieve(agent_memory, curr_time: datetime.datetime, memory_forget: float, query: str, nodes: list[BasicMemory],
|
||||
topk: int = 4, ) -> list[BasicMemory]:
|
||||
"""
|
||||
Retrieve需要集合Role使用,原因在于Role才具有AgentMemory,scratch
|
||||
|
|
@ -28,12 +28,13 @@ def agent_retrieve(curr_time: datetime.datetime, memory_forget: float, query: st
|
|||
}
|
||||
"""
|
||||
memories = nodes
|
||||
agent_memory_embedding = agent_memory.embeddings
|
||||
memories = sorted(memories, key=lambda memory_node: memory_node.last_accessed, reverse=True)
|
||||
|
||||
score_list = []
|
||||
score_list = extract_importance(memories, score_list)
|
||||
score_list = extract_recency(curr_time, memory_forget, score_list)
|
||||
score_list = extract_relevance(query, score_list)
|
||||
score_list = extract_relevance(agent_memory_embedding,query, score_list)
|
||||
score_list = normalize_score_floats(score_list, 0, 1)
|
||||
|
||||
total_dict = {}
|
||||
|
|
@ -43,14 +44,14 @@ def agent_retrieve(curr_time: datetime.datetime, memory_forget: float, query: st
|
|||
score_list[i]['recency'] * gw[1] +
|
||||
score_list[i]['relevance'] * gw[2]
|
||||
)
|
||||
total_dict[score_list[i]['memory']] = total_score
|
||||
total_dict[score_list[i]['memory'].memory_id] = total_score
|
||||
|
||||
result = top_highest_x_values(total_dict, topk)
|
||||
|
||||
return result # 返回的是一个BasicMemory列表
|
||||
|
||||
|
||||
def new_agent_retrieve(role: STRole, focus_points: list, n_count=30) -> dict:
|
||||
def new_agent_retrieve(role, focus_points: list, n_count=30) -> dict:
|
||||
"""
|
||||
输入为role,关注点列表,返回记忆数量
|
||||
输出为字典,键为focus_point,值为对应的记忆列表
|
||||
|
|
@ -62,12 +63,16 @@ def new_agent_retrieve(role: STRole, focus_points: list, n_count=30) -> dict:
|
|||
if "idle" not in i.embedding_key]
|
||||
nodes = sorted(nodes, key=lambda x: x[0])
|
||||
nodes = [i for created, i in nodes]
|
||||
results = agent_retrieve(role.scratch.curr_time, role.scratch.recency_decay,
|
||||
results = agent_retrieve(role.memory, role.scratch.curr_time, role.scratch.recency_decay,
|
||||
focal_pt, nodes, n_count)
|
||||
final_result = []
|
||||
for n in results:
|
||||
n.last_accessed = role.scratch.curr_time
|
||||
for i in role.memory.storage:
|
||||
if i.memory_id == n:
|
||||
i.last_accessed = role.scratch.curr_time
|
||||
final_result.append(i.content)
|
||||
|
||||
retrieved[focal_pt] = results
|
||||
retrieved[focal_pt] = final_result
|
||||
|
||||
return retrieved
|
||||
|
||||
|
|
@ -93,14 +98,15 @@ def extract_importance(memories, score_list):
|
|||
return score_list
|
||||
|
||||
|
||||
def extract_relevance(query, score_list):
|
||||
def extract_relevance(agent_memory_embedding,query, score_list):
|
||||
"""
|
||||
抽取相关性
|
||||
"""
|
||||
query_embedding = get_embedding(query)
|
||||
# 进行
|
||||
for i in range(len(score_list)):
|
||||
result = cos_sim(score_list[i]["memory"].embedding_key, query_embedding)
|
||||
node_embedding = agent_memory_embedding[score_list[i]["memory"].embedding_key]
|
||||
result = cos_sim(node_embedding, query_embedding)
|
||||
score_list[i]['relevance'] = result
|
||||
|
||||
return score_list
|
||||
|
|
|
|||
72
examples/st_game/tests/test_agent_memory.py
Normal file
72
examples/st_game/tests/test_agent_memory.py
Normal file
|
|
@ -0,0 +1,72 @@
|
|||
import pytest
|
||||
import os
|
||||
import time
|
||||
from datetime import datetime, timedelta
|
||||
from metagpt.logs import logger
|
||||
from examples.st_game.memory.agent_memory import AgentMemory
|
||||
from examples.st_game.utils.const import STORAGE_PATH
|
||||
from examples.st_game.memory.retrieve import agent_retrieve
|
||||
|
||||
"""
|
||||
memory测试思路
|
||||
1. Basic Memory测试
|
||||
2. Agent Memory测试
|
||||
2.1 Load & Save方法测试; Load方法中使用了add方法,验证Load即可验证所有add
|
||||
2.2 Get方法测试
|
||||
"""
|
||||
memory_easy_storage_path = os.path.join(STORAGE_PATH,"July1_the_ville_isabella_maria_klaus-step-3-4/personas/Isabella Rodriguez/bootstrap_memory/associative_memory")
|
||||
memroy_chat_storage_path = os.path.join(STORAGE_PATH,"July1_the_ville_isabella_maria_klaus-step-3-11/personas/Isabella Rodriguez/bootstrap_memory/associative_memory")
|
||||
memory_save_easy_test_path = os.path.join(STORAGE_PATH,"July1_the_ville_isabella_maria_klaus-step-3-4/personas/Isabella Rodriguez/bootstrap_memory/test_memory")
|
||||
memory_save_chat_test_path = os.path.join(STORAGE_PATH,"July1_the_ville_isabella_maria_klaus-step-3-11/personas/Isabella Rodriguez/bootstrap_memory/test_memory")
|
||||
class TestAgentMemory:
|
||||
@pytest.fixture
|
||||
def agent_memory(self):
|
||||
# 创建一个AgentMemory实例并返回,可以在所有测试用例中共享
|
||||
test_agent_memory = AgentMemory()
|
||||
test_agent_memory.set_mem_path(memroy_chat_storage_path)
|
||||
return test_agent_memory
|
||||
|
||||
def test_load(self,agent_memory):
|
||||
logger.info(f"存储路径为:{agent_memory.memory_saved}")
|
||||
logger.info(f"存储记忆条数为:{len(agent_memory.storage)}")
|
||||
logger.info(f"kw_strength为{agent_memory.kw_strength_event},{agent_memory.kw_strength_thought}")
|
||||
logger.info(f"embeeding.json条数为{len(agent_memory.embeddings)}")
|
||||
|
||||
assert agent_memory.embeddings != None
|
||||
|
||||
def test_save(self,agent_memory):
|
||||
try:
|
||||
agent_memory.save(memory_save_chat_test_path)
|
||||
logger.info("成功存储")
|
||||
except:
|
||||
pass
|
||||
|
||||
def test_summary_function(self, agent_memory):
|
||||
logger.info(f"event长度为{len(agent_memory.event_list)}")
|
||||
logger.info(f"thought长度为{len(agent_memory.thought_list)}")
|
||||
result1 = agent_memory.get_summarized_latest_events(4)
|
||||
logger.info(f"总结最近事件结果为:{result1}")
|
||||
def test_get_last_chat_function(self,agent_memory):
|
||||
result2 = agent_memory.get_last_chat("customers")
|
||||
logger.info(f"上一次对话是{result2}")
|
||||
|
||||
def test_retrieve_function(self,agent_memory):
|
||||
focus_points = ["who i love?"]
|
||||
retrieved = dict()
|
||||
for focal_pt in focus_points:
|
||||
nodes = [[i.last_accessed, i]
|
||||
for i in agent_memory.event_list + agent_memory.thought_list
|
||||
if "idle" not in i.embedding_key]
|
||||
nodes = sorted(nodes, key=lambda x: x[0])
|
||||
nodes = [i for created, i in nodes]
|
||||
results = agent_retrieve(agent_memory, datetime.now()-timedelta(days=120), 0.99,
|
||||
focal_pt, nodes, 5)
|
||||
final_result = []
|
||||
for n in results:
|
||||
for i in agent_memory.storage:
|
||||
if i.memory_id == n:
|
||||
i.last_accessed = datetime.now()-timedelta(days=120)
|
||||
final_result.append(i.content)
|
||||
|
||||
retrieved[focal_pt] = final_result
|
||||
logger.info(f"检索结果为{retrieved}")
|
||||
66
examples/st_game/tests/test_basic_memory.py
Normal file
66
examples/st_game/tests/test_basic_memory.py
Normal file
|
|
@ -0,0 +1,66 @@
|
|||
from datetime import datetime, timedelta
|
||||
from metagpt.logs import logger
|
||||
from examples.st_game.memory.agent_memory import BasicMemory
|
||||
import pytest
|
||||
|
||||
"""
|
||||
memory测试思路
|
||||
1. Basic Memory测试
|
||||
2. Agent Memory测试
|
||||
2.1 Load & Save方法测试
|
||||
2.2 Add方法测试
|
||||
2.3 Get方法测试
|
||||
"""
|
||||
|
||||
# Create some sample BasicMemory instances
|
||||
memory1 = BasicMemory(
|
||||
memory_id="1",
|
||||
memory_count=1,
|
||||
type_count=1,
|
||||
memory_type="event",
|
||||
depth=1,
|
||||
created=datetime.now(),
|
||||
expiration=datetime.now() + timedelta(days=30),
|
||||
subject="Subject1",
|
||||
predicate="Predicate1",
|
||||
object="Object1",
|
||||
content="This is content 1",
|
||||
embedding_key="embedding_key_1",
|
||||
poignancy=1,
|
||||
keywords=["keyword1", "keyword2"],
|
||||
filling=["memory_id_2"]
|
||||
)
|
||||
memory2 = BasicMemory(
|
||||
memory_id="2",
|
||||
memory_count=2,
|
||||
type_count=2,
|
||||
memory_type="thought",
|
||||
depth=2,
|
||||
created=datetime.now(),
|
||||
expiration=datetime.now() + timedelta(days=30),
|
||||
subject="Subject2",
|
||||
predicate="Predicate2",
|
||||
object="Object2",
|
||||
content="This is content 2",
|
||||
embedding_key="embedding_key_2",
|
||||
poignancy=2,
|
||||
keywords=["keyword3", "keyword4"],
|
||||
filling=[]
|
||||
)
|
||||
|
||||
@pytest.fixture
|
||||
def basic_mem_set():
|
||||
basic_mem2 = memory2
|
||||
yield basic_mem2
|
||||
|
||||
def test_basic_mem_function(basic_mem_set):
|
||||
a, b, c = basic_mem_set.summary()
|
||||
logger.info(f"{a}{b}{c}")
|
||||
assert a == "Subject2"
|
||||
|
||||
def test_basic_mem_save(basic_mem_set):
|
||||
result = basic_mem_set.save_to_dict()
|
||||
logger.info(f"save结果为{result}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main()
|
||||
|
|
@ -1,65 +0,0 @@
|
|||
from datetime import datetime
|
||||
from metagpt.logs import logger
|
||||
from examples.st_game.memory.agent_memory import AgentMemory, BasicMemory
|
||||
|
||||
# Create some sample BasicMemory instances
|
||||
memory1 = BasicMemory(
|
||||
memory_id="1",
|
||||
memory_count=1,
|
||||
type_count=1,
|
||||
memory_type="event",
|
||||
depth=1,
|
||||
created=datetime.now(),
|
||||
expiration=datetime.now(),
|
||||
subject="Subject1",
|
||||
predicate="Predicate1",
|
||||
object="Object1",
|
||||
content="This is content 1",
|
||||
embedding_key="embedding_key_1",
|
||||
poignancy=1,
|
||||
keywords=["keyword1", "keyword2"],
|
||||
filling=["memory_id_2"]
|
||||
)
|
||||
|
||||
memory2 = BasicMemory(
|
||||
memory_id="2",
|
||||
memory_count=2,
|
||||
type_count=2,
|
||||
memory_type="thought",
|
||||
depth=2,
|
||||
created=datetime.now(),
|
||||
expiration=None,
|
||||
subject="Subject2",
|
||||
predicate="Predicate2",
|
||||
object="Object2",
|
||||
content="This is content 2",
|
||||
embedding_key="embedding_key_2",
|
||||
poignancy=2,
|
||||
keywords=["keyword3", "keyword4"],
|
||||
filling=[]
|
||||
)
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Create an AgentMemory instance and add the created BasicMemory instances
|
||||
agent_memory = AgentMemory(memory_saved="sample_memory_folder")
|
||||
agent_memory.add_event(memory1)
|
||||
agent_memory.add_thought(memory2)
|
||||
|
||||
# Save the AgentMemory to a JSON file
|
||||
agent_memory.save("sample_memory_folder")
|
||||
|
||||
# Load the AgentMemory from the JSON file
|
||||
loaded_agent_memory = AgentMemory(memory_saved="sample_memory_folder")
|
||||
|
||||
# Get the summarized latest events
|
||||
latest_events = loaded_agent_memory.get_summarized_latest_events(retention=2)
|
||||
print("Summarized Latest Events:")
|
||||
for event in latest_events:
|
||||
print(event)
|
||||
|
||||
# Get the last chat for a specific role
|
||||
last_chat = loaded_agent_memory.get_last_chat(target_role_name="role1")
|
||||
if last_chat:
|
||||
print(f"Last chat for role1: {last_chat.content}")
|
||||
else:
|
||||
print("No chat found for role1")
|
||||
Loading…
Add table
Add a link
Reference in a new issue