From 770dcdc755cd3f34d55bc7f54da90a1f81a628a0 Mon Sep 17 00:00:00 2001 From: didi <2020201387@ruc.edu.cn> Date: Thu, 28 Sep 2023 00:17:32 +0800 Subject: [PATCH] ga_game Memory & Retrive MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 1. 完善了Memory模块,明天添加不同类型记忆的add方法 2. 添加了Retrive方法 3. 添加了Prompt,Scracth等模块 --- examples/st_game/memory/associative_memory.py | 125 +++++++++++++++- examples/st_game/memory/retrive.py | 138 ++++++++++++++++++ examples/st_game/memory/scratch.py | 6 + .../prompts_templates/poignancy_chat_v1.txt | 17 +++ examples/st_game/prompts/run_gpt_prompts.py | 33 +++++ examples/st_game/prompts/wrapper_prompt.py | 42 ++++++ examples/st_game/roles/st_role.py | 4 +- examples/st_game/utils/utils.py | 9 ++ 8 files changed, 369 insertions(+), 5 deletions(-) create mode 100644 examples/st_game/memory/retrive.py create mode 100644 examples/st_game/memory/scratch.py create mode 100644 examples/st_game/prompts/prompts_templates/poignancy_chat_v1.txt create mode 100644 examples/st_game/prompts/run_gpt_prompts.py create mode 100644 examples/st_game/prompts/wrapper_prompt.py diff --git a/examples/st_game/memory/associative_memory.py b/examples/st_game/memory/associative_memory.py index ef1514398..5b78926b9 100644 --- a/examples/st_game/memory/associative_memory.py +++ b/examples/st_game/memory/associative_memory.py @@ -1,9 +1,128 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- -# @Desc : associative_memory to store conversation、plan detail、reflection result and so on. +# @Desc : MemoryBasic,AgentMemory实现 from metagpt.memory.memory import Memory +from metagpt.schema import Message +import json +from datetime import datetime + +class MemoryBasic(Message): + + def __init__(self,memory_id:str,memory_count:int,type_count:int,memory_type:str,depth:int,content:int, + creaetd:datetime,expiration:datetime, + subject:str,predicate:str,object:str, + embedding_key:str,poignancy:int,keywords:list,filling:list): + """ + MemoryBasic继承于MG的Message类,其中content属性替代description属性 + Message类中对于Chat类型支持的非常好,对于Agent个体的Perceive,Reflection,Plan支持的并不多 + 在Type设计上,我们延续GA的三个种类,但是对于Chat种类的对话进行特别设计(具体怎么设计还没想好) + """ + super.__init__(content) + """ + 从父类中继承的属性 + content: str # 记忆描述 + cause_by: Type["Action"] = field(default="") # 触发动作,只在Type为chat时初始化 + cause_by 接受一个Action类,在此项目中,每个Agent需要有一个基础动作[Receive] 用于接受假对话Message;而每个Agent需要有独一无二的动作类,用以接受真对话Message + """ + self.memory_id: str = memory_id # 记忆ID + self.memory_count: int = memory_count # 第几个记忆,实际数值与Memory相等,但是类型为整数 + self.type_count: int = type_count # 第几种记忆,类型为整数(具体不太理解如何生成的) + self.memory_type: str = memory_type # 记忆类型,使用Field,包含 event,thought,chat三种类型 + self.depth:str = depth # 记忆深度,类型为整数 + + self.created: datetime = creaetd # 创建时间 + self.expiration: datetime = expiration # 记忆失效时间,默认为空() + self.last_accessed: datetime = creaetd # 上一次调用的时间,初始化时候与self.created一致 + + self.subject: str = subject # 主语,str类型 + self.predicate:str = predicate # 谓语,str类型 + self.object:str = object # 宾语,str类型 + + self.embedding_key: str = embedding_key # 内容与self.content一致 + self.poignancy:int = poignancy # importance值,整数类型 + self.keywords:list = keywords # keywords,列表 + self.filling:list = filling # None或者列表 + +class AgentMemory(Memory): + """ + GA中主要存储三种JSON + 1. embedding.json (Dict embedding_key:embedding) + 2. Node.json (Dict Node_id:Node) + 3. kw_strength.json + """ + def __init__(self,memory_saved:str): + """ + AgentMemory类继承自Memory类,重写storage替代GA中id_to_node,一方面存储所有信息,一方面作为JSON转化 + index存储与不同Agent的chat信息 + @李嵩@张凯 这里的storage是List,你们需要写一个JSON转化器,将List修改为node.json一致的格式 + """ + super.__init__() + self.storage: list[MemoryBasic] = [] # 重写Stroage,存储MemoryBasic所有节点 + self.event_list = [] # 存储event记忆 + self.thought_list = [] # 存储thought记忆 + + self.event_keywords = dict() # 存储keywords + self.thought_keywords = dict() + self.chat_keywords = dict() + + self.strength_event_keywords = dict() # 不知道具体作用,所以没有删除 + self.strength_thought_keywords = dict() + + self.embeddings = json.load(open(memory_saved + "/embeddings.json")) # dict类型,load embedding.json + self.memory_load() -class AssociativeMemory(Memory): - pass + def memory_save(self): + """ + 将MemormyBasic类存储为Nodes.json形式。复现GA中的Kw Strength.json形式 + @张凯补充一个可调用的函数 + """ + pass + + def memory_load(self): + """ + 将GA的JSON解析,填充到AgentMemory类之中 + """ + pass + + def add(self, memory_basic: MemoryBasic): + """ + Add a new message to storage, while updating the index + 重写add方法,修改原有的Message类为MemoryBasic类,并添加不同的记忆类型添加方式 + """ + if memory_basic in self.storage: + return + self.storage.append(memory_basic) + if memory_basic.cause_by: + self.index[memory_basic.cause_by].append(memory_basic) + return + if memory_basic.type == "thought": + self.thought_list.append(memory_basic) + return + if memory_basic.type == "event": + self.event_list.append(memory_basic) + + def add_chat(self): + """ + 调用add方法,初始化chat,在创建的时候就需要调用embeeding函数 + """ + pass + + def add_thought(self): + """ + 调用add方法,初始化thought + """ + pass + + def add_event(self): + """ + 调用add方法,初始化event + """ + pass + + def retrive(self,): + """ + 调用 + """ + pass diff --git a/examples/st_game/memory/retrive.py b/examples/st_game/memory/retrive.py new file mode 100644 index 000000000..0bca9ea8b --- /dev/null +++ b/examples/st_game/memory/retrive.py @@ -0,0 +1,138 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : Retrive函数实现 + +from numpy import dot +from numpy.linalg import norm +from datetime import datetime +from associative_memory import AgentMemory,MemoryBasic +from utils.utils import embedding_tools + +def agent_retrive(agent:AgentMemory,currtime:datetime,memory_forget:float,query:str,n:int= 30,topk:int=4) -> list[MemoryBasic]: + """ + retrive需要集合Role使用,原因在于Role才具有AgentMemory,scratch + 逻辑:Role调用该函数,self._rc.AgentMemory,self._rc.scratch.currtime,self._rc.scratch.memory_forget + 输入希望查询的内容与希望回顾的条数,返回TopK条高分记忆,即List[MemoryBasic] + + Score_lists示例 + { + "memory":memories[i], MemoryBasic类 + "importance":memories[i].poignancy + "recency":衰减因子计算结果 + "relevance":搜索结果 + } + """ + memories = AgentMemory.storage + sorted_memories = sorted(memories, key=lambda memory_node: memory_node.last_accessed_time,reverse=True) + memories = sorted_memories[:n] if len(sorted_memories) >= n else sorted_memories + + Score_list = [] + Score_list = extract_importance(memories,Score_list) + Score_list = extract_recency(currtime,memory_forget,Score_list) + Score_list = extract_relevance(query,Score_list) + Score_list = normalize_Socre_floats(Score_list,0,1) + + total_dict = {} + gw = [1,1,1] # 三个因素的权重,重要性,近因性,相关性 + for i in range(len(Score_list)): + total_score = (Score_list[i]['importance']*gw[0] + + Score_list[i]['recency']*gw[1] + + Score_list[i]['relevance']*gw[2] + ) + total_dict[Score_list[i]['memory']] = total_score + + result = top_highest_x_values(total_dict,topk) + + return result + +def top_highest_x_values(d, x): + """ + 输入字典,Topx + 返回以字典值排序,字典键组成的List[MemoryBasic] + """ + top_v = [item[0] for item in sorted(d.items(),key=lambda item: item[1],reverse= True)[:x]] + return top_v + + +def extract_importance(memories,Score_list): + """ + 抽取重要性 + """ + for i in range(len(memories)): + Score = {"memory":memories[i], + "importance":memories[i].poignancy + } + Score_list.append(Score) + return Score_list + +# 抽取相关性 +def extract_relevance(query,Score_list): + """ + 抽取相关性 + """ + query_embedding = embedding_tools(query) + # 进行 + for i in range(len(Score_list)): + result = cos_sim(Score_list[i]["memory"].embedding_key,query_embedding) + Score_list[i]['relevance'] = result + + return Score_list + +# 抽取近因性 +def extract_recency(currtime,memory_forget,Score_list): + """ + 抽取近因性,目前使用的现实世界过一天走一个衰减因子 + """ + for i in range(len(Score_list)): + day_count = (currtime-Score_list[i]['memory'].created).days + Score_list[i]['recency'] = memory_forget**day_count + return Score_list + +def cos_sim(a, b): + """ + 计算余弦相似度 + """ + return dot(a, b)/(norm(a)*norm(b)) + +def normalize_List_floats(Single_list,target_min, target_max): + """ + 单个列表归一化 + """ + min_val = min(Single_list) + max_val = max(Single_list) + range_val = max_val - min_val + + if range_val == 0: + for i in range(len(Single_list)): + Single_list[i] = (target_max - target_min)/2 + else: + for i in range(len(Single_list)): + Single_list[i] = ((Single_list[i] - min_val) * (target_max - target_min) + / range_val + target_min) + return Single_list + + +def normalize_Socre_floats(Score_list, target_min, target_max): + """ + 整体归一化 + """ + importance_list = [] + relevance_list = [] + recency_list = [] + + for i in range(len(Score_list)): + importance_list.append(Score_list[i]['importance']) + relevance_list.append(Score_list[i]['relevance']) + recency_list.append(Score_list[i]['recency']) + + # 进行归一化操作 + importance_list = normalize_List_floats(importance_list,target_min, target_max) + relevance_list = normalize_List_floats(relevance_list,target_min, target_max) + recency_list =normalize_List_floats(recency_list,target_min, target_max) + + for i in range(len(Score_list)): + Score_list[i]['importance'] = importance_list[i] + Score_list[i]['relevance'] = relevance_list[i] + Score_list[i]['recency'] = recency_list[i] + + return Score_list \ No newline at end of file diff --git a/examples/st_game/memory/scratch.py b/examples/st_game/memory/scratch.py new file mode 100644 index 000000000..00da03dd6 --- /dev/null +++ b/examples/st_game/memory/scratch.py @@ -0,0 +1,6 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : Scratch类实现(角色信息类) + +class Scratch(): + pass \ No newline at end of file diff --git a/examples/st_game/prompts/prompts_templates/poignancy_chat_v1.txt b/examples/st_game/prompts/prompts_templates/poignancy_chat_v1.txt new file mode 100644 index 000000000..572dd8a05 --- /dev/null +++ b/examples/st_game/prompts/prompts_templates/poignancy_chat_v1.txt @@ -0,0 +1,17 @@ +poignancy_chat_v1.txt + +!!: agent name +!!: iss +!!: name +!!: event description + +### +Here is a brief description of !!. +!! + +On the scale of 1 to 10, where 1 is purely mundane (e.g., routine morning greetings) and 10 is extremely poignant (e.g., a conversation about breaking up, a fight), rate the likely poignancy of the following conversation for !!. + +Conversation: +!! + +Rate (return a number between 1 to 10): \ No newline at end of file diff --git a/examples/st_game/prompts/run_gpt_prompts.py b/examples/st_game/prompts/run_gpt_prompts.py new file mode 100644 index 000000000..4c94f3bea --- /dev/null +++ b/examples/st_game/prompts/run_gpt_prompts.py @@ -0,0 +1,33 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : 调用PromptTemplates中模板,实现 + +from wrapper_prompt import special_response_generate,prompt_generate +from memory.scratch import Scratch +from memory.associative_memory import MemoryBasic +import json + +def run_gpt_prompt_chat_poignancy(scratch:Scratch,content:MemoryBasic.content)->str: + """ + 衡量事件心酸度 + """ + def create_prompt_input(scratch,content): + prompt_input = [scratch.name, + scratch.iss, + scratch.name, + content] + return prompt_input + + # 1. Prompt构建 + # 2. Instruction给出 + prompt_template = "prompt_templates/poignancy_chat_v1.txt" ######## + prompt_input = create_prompt_input(scratch, content) ######## + prompt = prompt_generate(prompt_input, prompt_template) + special_instruction = "The output should ONLY contain ONE integer value on the scale of 1 to 10." + poignancy = special_response_generate(prompt,special_instruction) + try: + poi_dict = json.loads(poignancy) + return (poi_dict['poignancy']) + except: + return poignancy + diff --git a/examples/st_game/prompts/wrapper_prompt.py b/examples/st_game/prompts/wrapper_prompt.py new file mode 100644 index 000000000..b61e13520 --- /dev/null +++ b/examples/st_game/prompts/wrapper_prompt.py @@ -0,0 +1,42 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : 基于Prmopt Templates 填充Prompt; 为Prompt包装与调用 + +from metagpt import llm + +def prompt_generate(curr_input:list, prompt_path:str): + """ + curr_input:输入一个按照PromptTemplate的要求的列表 + prompt_path:输入一个Promptpath + """ + if type(curr_input) == type("string"): + curr_input = [curr_input] + curr_input = [str(i) for i in curr_input] + + f = open(prompt_path, "r") + prompt = f.read() + f.close() + for count, i in enumerate(curr_input): + prompt = prompt.replace(f"!!", i) + if "###" in prompt: + prompt = prompt.split("###")[1] + return prompt.strip() + +def response_generate(prompt:str): + """ + 待完善,我没有找到MG中可以设置Temprature以及Maxtoken的位置 + """ + return llm.ai_func(prompt) + +def special_response_generate(prompt:str,special_instruction:str,example_output:str = None): + """ + 当对于Prompt生成有特殊要求时,调用该函数增加special_instruction或example_output + """ + prompt = '"""\n' + prompt + '\n"""\n' + prompt += f"Output the response to the prompt above in json. {special_instruction}\n" + if example_output: + prompt += "Example output json:\n" + prompt += '{"output": "' + str(example_output) + '"}' + return response_generate(prompt) + + diff --git a/examples/st_game/roles/st_role.py b/examples/st_game/roles/st_role.py index fa071f069..bc6988e28 100644 --- a/examples/st_game/roles/st_role.py +++ b/examples/st_game/roles/st_role.py @@ -17,7 +17,7 @@ from pathlib import Path from metagpt.roles.role import Role, RoleContext from metagpt.schema import Message -from ..memory.associative_memory import AssociativeMemory +from ..memory.associative_memory import AgentMemory from ..actions.dummy_action import DummyAction from ..actions.user_requirement import UserRequirement from ..maze_environment import MazeEnvironment @@ -25,7 +25,7 @@ from ..maze_environment import MazeEnvironment class STRoleContext(RoleContext): env: 'MazeEnvironment' = Field(default=None) - memory: AssociativeMemory = Field(default=AssociativeMemory) + memory: AgentMemory = Field(default=AgentMemory) class STRole(Role): diff --git a/examples/st_game/utils/utils.py b/examples/st_game/utils/utils.py index a70f7606d..11cbabd8e 100644 --- a/examples/st_game/utils/utils.py +++ b/examples/st_game/utils/utils.py @@ -4,6 +4,7 @@ from typing import Any import json +import openai from pathlib import Path @@ -22,3 +23,11 @@ def read_json_file(json_file: str, encoding=None) -> list[Any]: def write_json_file(json_file: str, data: list, encoding=None): with open(json_file, "w", encoding=encoding) as fout: json.dump(data, fout, ensure_ascii=False, indent=4) + +def embedding_tools(query): + embedding_result = openai.Embedding.create( + model="text-embedding-ada-002", + input=query + ) + embedding_key = embedding_result['data'][0]["embedding"] + return embedding_key \ No newline at end of file