memory & retrieve 测试

This commit is contained in:
didi 2023-10-04 17:36:52 +08:00
parent 6e0701eb7e
commit 11bef70fd7
5 changed files with 167 additions and 88 deletions

View file

@ -128,12 +128,11 @@ class AgentMemory(Memory):
"""
将MemoryBasic类存储为Nodes.json形式复现GA中的Kw Strength.json形式
这里添加一个路径即可
TODO 这里需要添加Const常量
TODO 这里在存储时候进行倒序存储之后需要验证test_memory通过
"""
memory_json = dict()
for i in range(len(self.storage)):
memory_node = self.storage[i]
memory_node = self.storage[len(self.storage)-i-1]
memory_node = memory_node.save_to_dict()
memory_json.update(memory_node)
with open(memory_saved + "/nodes.json", "w") as outfile:
@ -175,14 +174,15 @@ class AgentMemory(Memory):
poignancy = node_details["poignancy"]
keywords = set(node_details["keywords"])
filling = node_details["filling"]
# print(node_type)
if node_type == "event":
self.add_event(created, expiration, s, p, o,
description, keywords, poignancy, embedding_pair, filling)
elif node_type == "chat":
cause_by = node_details["cause_by"]
# cause_by = node_details["cause_by"]
logger.info(f"{node_id}")
self.add_chat(created, expiration, s, p, o,
description, keywords, poignancy, embedding_pair, filling, cause_by)
description, keywords, poignancy, embedding_pair, filling)
elif node_type == "thought":
self.add_thought(created, expiration, s, p, o,
description, keywords, poignancy, embedding_pair, filling)
@ -198,11 +198,11 @@ class AgentMemory(Memory):
Add a new message to storage, while updating the index
重写add方法修改原有的Message类为BasicMemory类并添加不同的记忆类型添加方式
"""
if memory_basic in self.storage:
if memory_basic.memory_id in self.storage:
return
self.storage.append(memory_basic)
if memory_basic.cause_by:
self.index[memory_basic.cause_by][0:0] = [memory_basic]
if memory_basic.memory_type == "chat":
self.chat_list[0:0] = [memory_basic]
return
if memory_basic.memory_type == "thought":
self.thought_list[0:0] = [memory_basic]
@ -213,14 +213,14 @@ class AgentMemory(Memory):
def add_chat(self, created, expiration, s, p, o,
content, keywords, poignancy,
embedding_pair, filling,
cause_by):
cause_by = ''):
"""
调用add方法初始化chat在创建的时候就需要调用embedding函数
"""
memory_count = len(self.storage) + 1
type_count = len(self.thought_list) + 1
memory_type = "chat"
memory_id = f"memory_{str(memory_count)}"
memory_id = f"node_{str(memory_count)}"
depth = 1
memory_node = BasicMemory(memory_id, memory_count, type_count, memory_type, depth,
@ -251,7 +251,7 @@ class AgentMemory(Memory):
memory_count = len(self.storage) + 1
type_count = len(self.thought_list) + 1
memory_type = "event"
memory_id = f"memory_{str(memory_count)}"
memory_id = f"node_{str(memory_count)}"
depth = 1
try:
@ -296,7 +296,7 @@ class AgentMemory(Memory):
memory_count = len(self.storage) + 1
type_count = len(self.event_list) + 1
memory_type = "event"
memory_id = f"memory_{str(memory_count)}"
memory_id = f"node_{str(memory_count)}"
depth = 0
if "(" in content:

View file

@ -9,10 +9,10 @@ from numpy.linalg import norm
from examples.st_game.memory.agent_memory import BasicMemory
from examples.st_game.utils.utils import get_embedding
from examples.st_game.roles.st_role import STRole
from metagpt.logs import logger
def agent_retrieve(curr_time: datetime.datetime, memory_forget: float, query: str, nodes: list[BasicMemory],
def agent_retrieve(agent_memory, curr_time: datetime.datetime, memory_forget: float, query: str, nodes: list[BasicMemory],
topk: int = 4, ) -> list[BasicMemory]:
"""
Retrieve需要集合Role使用,原因在于Role才具有AgentMemory,scratch
@ -28,12 +28,13 @@ def agent_retrieve(curr_time: datetime.datetime, memory_forget: float, query: st
}
"""
memories = nodes
agent_memory_embedding = agent_memory.embeddings
memories = sorted(memories, key=lambda memory_node: memory_node.last_accessed, reverse=True)
score_list = []
score_list = extract_importance(memories, score_list)
score_list = extract_recency(curr_time, memory_forget, score_list)
score_list = extract_relevance(query, score_list)
score_list = extract_relevance(agent_memory_embedding,query, score_list)
score_list = normalize_score_floats(score_list, 0, 1)
total_dict = {}
@ -43,14 +44,14 @@ def agent_retrieve(curr_time: datetime.datetime, memory_forget: float, query: st
score_list[i]['recency'] * gw[1] +
score_list[i]['relevance'] * gw[2]
)
total_dict[score_list[i]['memory']] = total_score
total_dict[score_list[i]['memory'].memory_id] = total_score
result = top_highest_x_values(total_dict, topk)
return result # 返回的是一个BasicMemory列表
def new_agent_retrieve(role: STRole, focus_points: list, n_count=30) -> dict:
def new_agent_retrieve(role, focus_points: list, n_count=30) -> dict:
"""
输入为role关注点列表,返回记忆数量
输出为字典键为focus_point值为对应的记忆列表
@ -62,12 +63,16 @@ def new_agent_retrieve(role: STRole, focus_points: list, n_count=30) -> dict:
if "idle" not in i.embedding_key]
nodes = sorted(nodes, key=lambda x: x[0])
nodes = [i for created, i in nodes]
results = agent_retrieve(role.scratch.curr_time, role.scratch.recency_decay,
results = agent_retrieve(role.memory, role.scratch.curr_time, role.scratch.recency_decay,
focal_pt, nodes, n_count)
final_result = []
for n in results:
n.last_accessed = role.scratch.curr_time
for i in role.memory.storage:
if i.memory_id == n:
i.last_accessed = role.scratch.curr_time
final_result.append(i.content)
retrieved[focal_pt] = results
retrieved[focal_pt] = final_result
return retrieved
@ -93,14 +98,15 @@ def extract_importance(memories, score_list):
return score_list
def extract_relevance(agent_memory_embedding, query, score_list):
    """
    Fill in the relevance score of every candidate memory for a query.

    Args:
        agent_memory_embedding: lookup table indexed by a node's
            ``embedding_key``; the caller passes ``agent_memory.embeddings``.
        query: text that is turned into an embedding with ``get_embedding``.
        score_list: list of dicts, each holding a node under ``"memory"``.
            A ``"relevance"`` entry is written into every dict in place.

    Returns:
        The same ``score_list`` object, with ``"relevance"`` set on each entry.
    """
    query_embedding = get_embedding(query)
    for entry in score_list:
        # embedding_key is only a lookup key; the value compared against the
        # query is whatever the embedding table stores under that key.
        node_embedding = agent_memory_embedding[entry["memory"].embedding_key]
        entry["relevance"] = cos_sim(node_embedding, query_embedding)
    return score_list

View file

@ -0,0 +1,72 @@
import pytest
import os
import time
from datetime import datetime, timedelta
from metagpt.logs import logger
from examples.st_game.memory.agent_memory import AgentMemory
from examples.st_game.utils.const import STORAGE_PATH
from examples.st_game.memory.retrieve import agent_retrieve
"""
memory测试思路
1. Basic Memory测试
2. Agent Memory测试
2.1 Load & Save方法测试; Load方法中使用了add方法验证Load即可验证所有add
2.2 Get方法测试
"""
memory_easy_storage_path = os.path.join(STORAGE_PATH,"July1_the_ville_isabella_maria_klaus-step-3-4/personas/Isabella Rodriguez/bootstrap_memory/associative_memory")
memroy_chat_storage_path = os.path.join(STORAGE_PATH,"July1_the_ville_isabella_maria_klaus-step-3-11/personas/Isabella Rodriguez/bootstrap_memory/associative_memory")
memory_save_easy_test_path = os.path.join(STORAGE_PATH,"July1_the_ville_isabella_maria_klaus-step-3-4/personas/Isabella Rodriguez/bootstrap_memory/test_memory")
memory_save_chat_test_path = os.path.join(STORAGE_PATH,"July1_the_ville_isabella_maria_klaus-step-3-11/personas/Isabella Rodriguez/bootstrap_memory/test_memory")
class TestAgentMemory:
    """Smoke tests for AgentMemory: loading, saving, summaries and retrieval."""

    @pytest.fixture
    def agent_memory(self):
        """Return an AgentMemory whose memory path is the chat fixture directory.

        The fixture uses pytest's default (function) scope, so every test
        receives a freshly built instance rather than a shared one.
        """
        test_agent_memory = AgentMemory()
        test_agent_memory.set_mem_path(memroy_chat_storage_path)
        return test_agent_memory

    def test_load(self, agent_memory):
        """Log what the fixture holds and check that embeddings are present."""
        logger.info(f"存储路径为:{agent_memory.memory_saved}")
        logger.info(f"存储记忆条数为:{len(agent_memory.storage)}")
        logger.info(f"kw_strength为{agent_memory.kw_strength_event},{agent_memory.kw_strength_thought}")
        logger.info(f"embeeding.json条数为{len(agent_memory.embeddings)}")
        assert agent_memory.embeddings is not None

    def test_save(self, agent_memory):
        """Save the memory into the test directory; any exception fails the test."""
        # Deliberately no try/except: swallowing the error would let a broken
        # save() pass silently and make this test meaningless.
        agent_memory.save(memory_save_chat_test_path)
        logger.info("成功存储")

    def test_summary_function(self, agent_memory):
        """Log list sizes and the result of get_summarized_latest_events(4)."""
        logger.info(f"event长度为{len(agent_memory.event_list)}")
        logger.info(f"thought长度为{len(agent_memory.thought_list)}")
        result1 = agent_memory.get_summarized_latest_events(4)
        logger.info(f"总结最近事件结果为:{result1}")

    def test_get_last_chat_function(self, agent_memory):
        """Log the result of get_last_chat for the "customers" target."""
        result2 = agent_memory.get_last_chat("customers")
        logger.info(f"上一次对话是{result2}")

    def test_retrieve_function(self, agent_memory):
        """Run agent_retrieve for one focal point and log the retrieved contents."""
        focus_points = ["who i love?"]
        # Reference time computed once so the retrieval call and every
        # last_accessed update use the same value.
        curr_time = datetime.now() - timedelta(days=120)
        retrieved = dict()
        for focal_pt in focus_points:
            # Candidate nodes: events and thoughts whose embedding key does
            # not contain "idle", ordered by last access time (oldest first).
            candidates = [
                node
                for node in agent_memory.event_list + agent_memory.thought_list
                if "idle" not in node.embedding_key
            ]
            nodes = sorted(candidates, key=lambda node: node.last_accessed)
            results = agent_retrieve(agent_memory, curr_time, 0.99,
                                     focal_pt, nodes, 5)
            final_result = []
            # Each retrieved item is compared against memory_id, so map it
            # back to the stored node to stamp it and collect its content.
            for retrieved_id in results:
                for node in agent_memory.storage:
                    if node.memory_id == retrieved_id:
                        node.last_accessed = curr_time
                        final_result.append(node.content)
            retrieved[focal_pt] = final_result
        logger.info(f"检索结果为{retrieved}")

View file

@ -0,0 +1,66 @@
from datetime import datetime, timedelta
from metagpt.logs import logger
from examples.st_game.memory.agent_memory import BasicMemory
import pytest
"""
memory测试思路
1. Basic Memory测试
2. Agent Memory测试
2.1 Load & Save方法测试
2.2 Add方法测试
2.3 Get方法测试
"""
# Sample BasicMemory instances shared by the tests below.
# memory1 is an "event" node that expires 30 days after creation.
memory1 = BasicMemory(
    memory_id="1",
    memory_count=1,
    type_count=1,
    memory_type="event",
    depth=1,
    created=datetime.now(),
    expiration=datetime.now() + timedelta(days=30),
    subject="Subject1",
    predicate="Predicate1",
    object="Object1",
    content="This is content 1",
    embedding_key="embedding_key_1",
    poignancy=1,
    keywords=["keyword1", "keyword2"],
    filling=["memory_id_2"],
)

# memory2 is a "thought" node with an empty filling list.
memory2 = BasicMemory(
    memory_id="2",
    memory_count=2,
    type_count=2,
    memory_type="thought",
    depth=2,
    created=datetime.now(),
    expiration=datetime.now() + timedelta(days=30),
    subject="Subject2",
    predicate="Predicate2",
    object="Object2",
    content="This is content 2",
    embedding_key="embedding_key_2",
    poignancy=2,
    keywords=["keyword3", "keyword4"],
    filling=[],
)
@pytest.fixture
def basic_mem_set():
    """Provide the module-level sample memory ``memory2`` to a test."""
    yield memory2
def test_basic_mem_function(basic_mem_set):
    """Check that the first of the three values from summary() is the subject."""
    summary = basic_mem_set.summary()
    a, b, c = summary
    logger.info(f"{a}{b}{c}")
    assert a == "Subject2"
def test_basic_mem_save(basic_mem_set):
    """Call save_to_dict() on the sample memory and log what it returns."""
    result = basic_mem_set.save_to_dict()
    message = f"save结果为{result}"
    logger.info(message)
# Allow running this test module directly as a script; pytest.main() is
# called without arguments.
if __name__ == "__main__":
    pytest.main()

View file

@ -1,65 +0,0 @@
from datetime import datetime
from metagpt.logs import logger
from examples.st_game.memory.agent_memory import AgentMemory, BasicMemory
# Sample BasicMemory instances used by the script section of this module.
# memory1 is an "event" node whose expiration is set to the current time.
memory1 = BasicMemory(
    memory_id="1",
    memory_count=1,
    type_count=1,
    memory_type="event",
    depth=1,
    created=datetime.now(),
    expiration=datetime.now(),
    subject="Subject1",
    predicate="Predicate1",
    object="Object1",
    content="This is content 1",
    embedding_key="embedding_key_1",
    poignancy=1,
    keywords=["keyword1", "keyword2"],
    filling=["memory_id_2"],
)

# memory2 is a "thought" node with no expiration (None) and an empty filling list.
memory2 = BasicMemory(
    memory_id="2",
    memory_count=2,
    type_count=2,
    memory_type="thought",
    depth=2,
    created=datetime.now(),
    expiration=None,
    subject="Subject2",
    predicate="Predicate2",
    object="Object2",
    content="This is content 2",
    embedding_key="embedding_key_2",
    poignancy=2,
    keywords=["keyword3", "keyword4"],
    filling=[],
)
# Manual script: builds an AgentMemory, adds the two samples, saves it,
# constructs a second instance on the same folder and prints what it returns.
if __name__ == "__main__":
    # Create an AgentMemory instance and add the created BasicMemory instances
    agent_memory = AgentMemory(memory_saved="sample_memory_folder")
    # NOTE(review): other call sites visible in this change pass individual
    # field values (created, expiration, s, p, o, description, ...) to
    # add_event/add_thought rather than a BasicMemory object — confirm that
    # these two calls match the current signatures.
    agent_memory.add_event(memory1)
    agent_memory.add_thought(memory2)
    # Save the AgentMemory to a JSON file
    agent_memory.save("sample_memory_folder")
    # Load the AgentMemory from the JSON file
    # NOTE(review): assumes the constructor loads from memory_saved — TODO confirm.
    loaded_agent_memory = AgentMemory(memory_saved="sample_memory_folder")
    # Get the summarized latest events
    latest_events = loaded_agent_memory.get_summarized_latest_events(retention=2)
    print("Summarized Latest Events:")
    for event in latest_events:
        print(event)
    # Get the last chat for a specific role
    last_chat = loaded_agent_memory.get_last_chat(target_role_name="role1")
    # A falsy result is reported as "no chat found".
    if last_chat:
        print(f"Last chat for role1: {last_chat.content}")
    else:
        print("No chat found for role1")