From 11bef70fd7c6579cb7ab784a6550e8e52ca143d7 Mon Sep 17 00:00:00 2001 From: didi <2020201387@ruc.edu.cn> Date: Wed, 4 Oct 2023 17:36:52 +0800 Subject: [PATCH] =?UTF-8?q?memory=20&=20retrieve=20=E6=B5=8B=E8=AF=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- examples/st_game/memory/agent_memory.py | 26 ++++---- examples/st_game/memory/retrieve.py | 26 +++++--- examples/st_game/tests/test_agent_memory.py | 72 +++++++++++++++++++++ examples/st_game/tests/test_basic_memory.py | 66 +++++++++++++++++++ examples/st_game/tests/test_memory.py | 65 ------------------- 5 files changed, 167 insertions(+), 88 deletions(-) create mode 100644 examples/st_game/tests/test_agent_memory.py create mode 100644 examples/st_game/tests/test_basic_memory.py delete mode 100644 examples/st_game/tests/test_memory.py diff --git a/examples/st_game/memory/agent_memory.py b/examples/st_game/memory/agent_memory.py index 1db2566f6..97644c5b8 100644 --- a/examples/st_game/memory/agent_memory.py +++ b/examples/st_game/memory/agent_memory.py @@ -128,12 +128,11 @@ class AgentMemory(Memory): """ 将MemoryBasic类存储为Nodes.json形式。复现GA中的Kw Strength.json形式 这里添加一个路径即可 - TODO 这里需要添加Const常量 + TODO 这里在存储时候进行倒序存储,之后需要验证(test_memory通过) """ - memory_json = dict() for i in range(len(self.storage)): - memory_node = self.storage[i] + memory_node = self.storage[len(self.storage)-i-1] memory_node = memory_node.save_to_dict() memory_json.update(memory_node) with open(memory_saved + "/nodes.json", "w") as outfile: @@ -175,14 +174,15 @@ class AgentMemory(Memory): poignancy = node_details["poignancy"] keywords = set(node_details["keywords"]) filling = node_details["filling"] - + # print(node_type) if node_type == "event": self.add_event(created, expiration, s, p, o, description, keywords, poignancy, embedding_pair, filling) elif node_type == "chat": - cause_by = node_details["cause_by"] + # cause_by = node_details["cause_by"] + logger.info(f"{node_id}") self.add_chat(created, expiration, s, p, o, - description, keywords, poignancy, embedding_pair, filling, cause_by) + description, keywords, poignancy, embedding_pair, filling) elif node_type == "thought": self.add_thought(created, expiration, s, p, o, description, keywords, poignancy, embedding_pair, filling) @@ -198,11 +198,11 @@ class AgentMemory(Memory): Add a new message to storage, while updating the index 重写add方法,修改原有的Message类为BasicMemory类,并添加不同的记忆类型添加方式 """ - if memory_basic in self.storage: + if memory_basic.memory_id in self.storage: return self.storage.append(memory_basic) - if memory_basic.cause_by: - self.index[memory_basic.cause_by][0:0] = [memory_basic] + if memory_basic.memory_type == "chat": + self.chat_list[0:0] = [memory_basic] return if memory_basic.memory_type == "thought": self.thought_list[0:0] = [memory_basic] @@ -213,14 +213,14 @@ class AgentMemory(Memory): def add_chat(self, created, expiration, s, p, o, content, keywords, poignancy, embedding_pair, filling, - cause_by): + cause_by = ''): """ 调用add方法,初始化chat,在创建的时候就需要调用embedding函数 """ memory_count = len(self.storage) + 1 type_count = len(self.thought_list) + 1 memory_type = "chat" - memory_id = f"memory_{str(memory_count)}" + memory_id = f"node_{str(memory_count)}" depth = 1 memory_node = BasicMemory(memory_id, memory_count, type_count, memory_type, depth, @@ -251,7 +251,7 @@ class AgentMemory(Memory): memory_count = len(self.storage) + 1 type_count = len(self.thought_list) + 1 memory_type = "event" - memory_id = f"memory_{str(memory_count)}" + memory_id = f"node_{str(memory_count)}" depth = 1 try: @@ -296,7 +296,7 @@ class AgentMemory(Memory): memory_count = len(self.storage) + 1 type_count = len(self.event_list) + 1 memory_type = "event" - memory_id = f"memory_{str(memory_count)}" + memory_id = f"node_{str(memory_count)}" depth = 0 if "(" in content: diff --git a/examples/st_game/memory/retrieve.py b/examples/st_game/memory/retrieve.py index 55e5b873f..ffba5dcbb 100644 --- a/examples/st_game/memory/retrieve.py +++ b/examples/st_game/memory/retrieve.py @@ -9,10 +9,10 @@ from numpy.linalg import norm from examples.st_game.memory.agent_memory import BasicMemory from examples.st_game.utils.utils import get_embedding -from examples.st_game.roles.st_role import STRole +from metagpt.logs import logger -def agent_retrieve(curr_time: datetime.datetime, memory_forget: float, query: str, nodes: list[BasicMemory], +def agent_retrieve(agent_memory, curr_time: datetime.datetime, memory_forget: float, query: str, nodes: list[BasicMemory], topk: int = 4, ) -> list[BasicMemory]: """ Retrieve需要集合Role使用,原因在于Role才具有AgentMemory,scratch @@ -28,12 +28,13 @@ def agent_retrieve(curr_time: datetime.datetime, memory_forget: float, query: st } """ memories = nodes + agent_memory_embedding = agent_memory.embeddings memories = sorted(memories, key=lambda memory_node: memory_node.last_accessed, reverse=True) score_list = [] score_list = extract_importance(memories, score_list) score_list = extract_recency(curr_time, memory_forget, score_list) - score_list = extract_relevance(query, score_list) + score_list = extract_relevance(agent_memory_embedding,query, score_list) score_list = normalize_score_floats(score_list, 0, 1) total_dict = {} @@ -43,14 +44,14 @@ def agent_retrieve(curr_time: datetime.datetime, memory_forget: float, query: st score_list[i]['recency'] * gw[1] + score_list[i]['relevance'] * gw[2] ) - total_dict[score_list[i]['memory']] = total_score + total_dict[score_list[i]['memory'].memory_id] = total_score result = top_highest_x_values(total_dict, topk) return result # 返回的是一个BasicMemory列表 -def new_agent_retrieve(role: STRole, focus_points: list, n_count=30) -> dict: +def new_agent_retrieve(role, focus_points: list, n_count=30) -> dict: """ 输入为role,关注点列表,返回记忆数量 输出为字典,键为focus_point,值为对应的记忆列表 @@ -62,12 +63,16 @@ def new_agent_retrieve(role: STRole, focus_points: list, n_count=30) -> dict: if "idle" not in i.embedding_key] nodes = sorted(nodes, key=lambda x: x[0]) nodes = [i for created, i in nodes] - results = agent_retrieve(role.scratch.curr_time, role.scratch.recency_decay, + results = agent_retrieve(role.memory, role.scratch.curr_time, role.scratch.recency_decay, focal_pt, nodes, n_count) + final_result = [] for n in results: - n.last_accessed = role.scratch.curr_time + for i in role.memory.storage: + if i.memory_id == n: + i.last_accessed = role.scratch.curr_time + final_result.append(i.content) - retrieved[focal_pt] = results + retrieved[focal_pt] = final_result return retrieved @@ -93,14 +98,15 @@ def extract_importance(memories, score_list): return score_list -def extract_relevance(query, score_list): +def extract_relevance(agent_memory_embedding,query, score_list): """ 抽取相关性 """ query_embedding = get_embedding(query) # 进行 for i in range(len(score_list)): - result = cos_sim(score_list[i]["memory"].embedding_key, query_embedding) + node_embedding = agent_memory_embedding[score_list[i]["memory"].embedding_key] + result = cos_sim(node_embedding, query_embedding) score_list[i]['relevance'] = result return score_list diff --git a/examples/st_game/tests/test_agent_memory.py b/examples/st_game/tests/test_agent_memory.py new file mode 100644 index 000000000..08d065afd --- /dev/null +++ b/examples/st_game/tests/test_agent_memory.py @@ -0,0 +1,72 @@ +import pytest +import os +import time +from datetime import datetime, timedelta +from metagpt.logs import logger +from examples.st_game.memory.agent_memory import AgentMemory +from examples.st_game.utils.const import STORAGE_PATH +from examples.st_game.memory.retrieve import agent_retrieve + +""" +memory测试思路 +1. Basic Memory测试 +2. Agent Memory测试 + 2.1 Load & Save方法测试; Load方法中使用了add方法,验证Load即可验证所有add + 2.2 Get方法测试 +""" +memory_easy_storage_path = os.path.join(STORAGE_PATH,"July1_the_ville_isabella_maria_klaus-step-3-4/personas/Isabella Rodriguez/bootstrap_memory/associative_memory") +memroy_chat_storage_path = os.path.join(STORAGE_PATH,"July1_the_ville_isabella_maria_klaus-step-3-11/personas/Isabella Rodriguez/bootstrap_memory/associative_memory") +memory_save_easy_test_path = os.path.join(STORAGE_PATH,"July1_the_ville_isabella_maria_klaus-step-3-4/personas/Isabella Rodriguez/bootstrap_memory/test_memory") +memory_save_chat_test_path = os.path.join(STORAGE_PATH,"July1_the_ville_isabella_maria_klaus-step-3-11/personas/Isabella Rodriguez/bootstrap_memory/test_memory") +class TestAgentMemory: + @pytest.fixture + def agent_memory(self): + # 创建一个AgentMemory实例并返回,可以在所有测试用例中共享 + test_agent_memory = AgentMemory() + test_agent_memory.set_mem_path(memroy_chat_storage_path) + return test_agent_memory + + def test_load(self,agent_memory): + logger.info(f"存储路径为:{agent_memory.memory_saved}") + logger.info(f"存储记忆条数为:{len(agent_memory.storage)}") + logger.info(f"kw_strength为{agent_memory.kw_strength_event},{agent_memory.kw_strength_thought}") + logger.info(f"embeeding.json条数为{len(agent_memory.embeddings)}") + + assert agent_memory.embeddings != None + + def test_save(self,agent_memory): + try: + agent_memory.save(memory_save_chat_test_path) + logger.info("成功存储") + except: + pass + + def test_summary_function(self, agent_memory): + logger.info(f"event长度为{len(agent_memory.event_list)}") + logger.info(f"thought长度为{len(agent_memory.thought_list)}") + result1 = agent_memory.get_summarized_latest_events(4) + logger.info(f"总结最近事件结果为:{result1}") + def test_get_last_chat_function(self,agent_memory): + result2 = agent_memory.get_last_chat("customers") + logger.info(f"上一次对话是{result2}") + + def test_retrieve_function(self,agent_memory): + focus_points = ["who i love?"] + retrieved = dict() + for focal_pt in focus_points: + nodes = [[i.last_accessed, i] + for i in agent_memory.event_list + agent_memory.thought_list + if "idle" not in i.embedding_key] + nodes = sorted(nodes, key=lambda x: x[0]) + nodes = [i for created, i in nodes] + results = agent_retrieve(agent_memory, datetime.now()-timedelta(days=120), 0.99, + focal_pt, nodes, 5) + final_result = [] + for n in results: + for i in agent_memory.storage: + if i.memory_id == n: + i.last_accessed = datetime.now()-timedelta(days=120) + final_result.append(i.content) + + retrieved[focal_pt] = final_result + logger.info(f"检索结果为{retrieved}") \ No newline at end of file diff --git a/examples/st_game/tests/test_basic_memory.py b/examples/st_game/tests/test_basic_memory.py new file mode 100644 index 000000000..79184ceae --- /dev/null +++ b/examples/st_game/tests/test_basic_memory.py @@ -0,0 +1,66 @@ +from datetime import datetime, timedelta +from metagpt.logs import logger +from examples.st_game.memory.agent_memory import BasicMemory +import pytest + +""" +memory测试思路 +1. Basic Memory测试 +2. Agent Memory测试 + 2.1 Load & Save方法测试 + 2.2 Add方法测试 + 2.3 Get方法测试 +""" + +# Create some sample BasicMemory instances +memory1 = BasicMemory( + memory_id="1", + memory_count=1, + type_count=1, + memory_type="event", + depth=1, + created=datetime.now(), + expiration=datetime.now() + timedelta(days=30), + subject="Subject1", + predicate="Predicate1", + object="Object1", + content="This is content 1", + embedding_key="embedding_key_1", + poignancy=1, + keywords=["keyword1", "keyword2"], + filling=["memory_id_2"] +) +memory2 = BasicMemory( + memory_id="2", + memory_count=2, + type_count=2, + memory_type="thought", + depth=2, + created=datetime.now(), + expiration=datetime.now() + timedelta(days=30), + subject="Subject2", + predicate="Predicate2", + object="Object2", + content="This is content 2", + embedding_key="embedding_key_2", + poignancy=2, + keywords=["keyword3", "keyword4"], + filling=[] +) + +@pytest.fixture +def basic_mem_set(): + basic_mem2 = memory2 + yield basic_mem2 + +def test_basic_mem_function(basic_mem_set): + a, b, c = basic_mem_set.summary() + logger.info(f"{a}{b}{c}") + assert a == "Subject2" + +def test_basic_mem_save(basic_mem_set): + result = basic_mem_set.save_to_dict() + logger.info(f"save结果为{result}") + +if __name__ == "__main__": + pytest.main() diff --git a/examples/st_game/tests/test_memory.py b/examples/st_game/tests/test_memory.py deleted file mode 100644 index 451addc8b..000000000 --- a/examples/st_game/tests/test_memory.py +++ /dev/null @@ -1,65 +0,0 @@ -from datetime import datetime -from metagpt.logs import logger -from examples.st_game.memory.agent_memory import AgentMemory, BasicMemory - -# Create some sample BasicMemory instances -memory1 = BasicMemory( - memory_id="1", - memory_count=1, - type_count=1, - memory_type="event", - depth=1, - created=datetime.now(), - expiration=datetime.now(), - subject="Subject1", - predicate="Predicate1", - object="Object1", - content="This is content 1", - embedding_key="embedding_key_1", - poignancy=1, - keywords=["keyword1", "keyword2"], - filling=["memory_id_2"] -) - -memory2 = BasicMemory( - memory_id="2", - memory_count=2, - type_count=2, - memory_type="thought", - depth=2, - created=datetime.now(), - expiration=None, - subject="Subject2", - predicate="Predicate2", - object="Object2", - content="This is content 2", - embedding_key="embedding_key_2", - poignancy=2, - keywords=["keyword3", "keyword4"], - filling=[] -) - -if __name__ == "__main__": - # Create an AgentMemory instance and add the created BasicMemory instances - agent_memory = AgentMemory(memory_saved="sample_memory_folder") - agent_memory.add_event(memory1) - agent_memory.add_thought(memory2) - - # Save the AgentMemory to a JSON file - agent_memory.save("sample_memory_folder") - - # Load the AgentMemory from the JSON file - loaded_agent_memory = AgentMemory(memory_saved="sample_memory_folder") - - # Get the summarized latest events - latest_events = loaded_agent_memory.get_summarized_latest_events(retention=2) - print("Summarized Latest Events:") - for event in latest_events: - print(event) - - # Get the last chat for a specific role - last_chat = loaded_agent_memory.get_last_chat(target_role_name="role1") - if last_chat: - print(f"Last chat for role1: {last_chat.content}") - else: - print("No chat found for role1")