From 770dcdc755cd3f34d55bc7f54da90a1f81a628a0 Mon Sep 17 00:00:00 2001 From: didi <2020201387@ruc.edu.cn> Date: Thu, 28 Sep 2023 00:17:32 +0800 Subject: [PATCH 1/9] ga_game Memory & Retrive MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 1. 完善了Memory模块,明天添加不同类型记忆的add方法 2. 添加了Retrive方法 3. 添加了Prompt,Scracth等模块 --- examples/st_game/memory/associative_memory.py | 125 +++++++++++++++- examples/st_game/memory/retrive.py | 138 ++++++++++++++++++ examples/st_game/memory/scratch.py | 6 + .../prompts_templates/poignancy_chat_v1.txt | 17 +++ examples/st_game/prompts/run_gpt_prompts.py | 33 +++++ examples/st_game/prompts/wrapper_prompt.py | 42 ++++++ examples/st_game/roles/st_role.py | 4 +- examples/st_game/utils/utils.py | 9 ++ 8 files changed, 369 insertions(+), 5 deletions(-) create mode 100644 examples/st_game/memory/retrive.py create mode 100644 examples/st_game/memory/scratch.py create mode 100644 examples/st_game/prompts/prompts_templates/poignancy_chat_v1.txt create mode 100644 examples/st_game/prompts/run_gpt_prompts.py create mode 100644 examples/st_game/prompts/wrapper_prompt.py diff --git a/examples/st_game/memory/associative_memory.py b/examples/st_game/memory/associative_memory.py index ef1514398..5b78926b9 100644 --- a/examples/st_game/memory/associative_memory.py +++ b/examples/st_game/memory/associative_memory.py @@ -1,9 +1,128 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- -# @Desc : associative_memory to store conversation、plan detail、reflection result and so on. +# @Desc : MemoryBasic,AgentMemory实现 from metagpt.memory.memory import Memory +from metagpt.schema import Message +import json +from datetime import datetime + +class MemoryBasic(Message): + + def __init__(self,memory_id:str,memory_count:int,type_count:int,memory_type:str,depth:int,content:int, + creaetd:datetime,expiration:datetime, + subject:str,predicate:str,object:str, + embedding_key:str,poignancy:int,keywords:list,filling:list): + """ + MemoryBasic继承于MG的Message类,其中content属性替代description属性 + Message类中对于Chat类型支持的非常好,对于Agent个体的Perceive,Reflection,Plan支持的并不多 + 在Type设计上,我们延续GA的三个种类,但是对于Chat种类的对话进行特别设计(具体怎么设计还没想好) + """ + super.__init__(content) + """ + 从父类中继承的属性 + content: str # 记忆描述 + cause_by: Type["Action"] = field(default="") # 触发动作,只在Type为chat时初始化 + cause_by 接受一个Action类,在此项目中,每个Agent需要有一个基础动作[Receive] 用于接受假对话Message;而每个Agent需要有独一无二的动作类,用以接受真对话Message + """ + self.memory_id: str = memory_id # 记忆ID + self.memory_count: int = memory_count # 第几个记忆,实际数值与Memory相等,但是类型为整数 + self.type_count: int = type_count # 第几种记忆,类型为整数(具体不太理解如何生成的) + self.memory_type: str = memory_type # 记忆类型,使用Field,包含 event,thought,chat三种类型 + self.depth:str = depth # 记忆深度,类型为整数 + + self.created: datetime = creaetd # 创建时间 + self.expiration: datetime = expiration # 记忆失效时间,默认为空() + self.last_accessed: datetime = creaetd # 上一次调用的时间,初始化时候与self.created一致 + + self.subject: str = subject # 主语,str类型 + self.predicate:str = predicate # 谓语,str类型 + self.object:str = object # 宾语,str类型 + + self.embedding_key: str = embedding_key # 内容与self.content一致 + self.poignancy:int = poignancy # importance值,整数类型 + self.keywords:list = keywords # keywords,列表 + self.filling:list = filling # None或者列表 + +class AgentMemory(Memory): + """ + GA中主要存储三种JSON + 1. embedding.json (Dict embedding_key:embedding) + 2. Node.json (Dict Node_id:Node) + 3. kw_strength.json + """ + def __init__(self,memory_saved:str): + """ + AgentMemory类继承自Memory类,重写storage替代GA中id_to_node,一方面存储所有信息,一方面作为JSON转化 + index存储与不同Agent的chat信息 + @李嵩@张凯 这里的storage是List,你们需要写一个JSON转化器,将List修改为node.json一致的格式 + """ + super.__init__() + self.storage: list[MemoryBasic] = [] # 重写Stroage,存储MemoryBasic所有节点 + self.event_list = [] # 存储event记忆 + self.thought_list = [] # 存储thought记忆 + + self.event_keywords = dict() # 存储keywords + self.thought_keywords = dict() + self.chat_keywords = dict() + + self.strength_event_keywords = dict() # 不知道具体作用,所以没有删除 + self.strength_thought_keywords = dict() + + self.embeddings = json.load(open(memory_saved + "/embeddings.json")) # dict类型,load embedding.json + self.memory_load() -class AssociativeMemory(Memory): - pass + def memory_save(self): + """ + 将MemormyBasic类存储为Nodes.json形式。复现GA中的Kw Strength.json形式 + @张凯补充一个可调用的函数 + """ + pass + + def memory_load(self): + """ + 将GA的JSON解析,填充到AgentMemory类之中 + """ + pass + + def add(self, memory_basic: MemoryBasic): + """ + Add a new message to storage, while updating the index + 重写add方法,修改原有的Message类为MemoryBasic类,并添加不同的记忆类型添加方式 + """ + if memory_basic in self.storage: + return + self.storage.append(memory_basic) + if memory_basic.cause_by: + self.index[memory_basic.cause_by].append(memory_basic) + return + if memory_basic.type == "thought": + self.thought_list.append(memory_basic) + return + if memory_basic.type == "event": + self.event_list.append(memory_basic) + + def add_chat(self): + """ + 调用add方法,初始化chat,在创建的时候就需要调用embeeding函数 + """ + pass + + def add_thought(self): + """ + 调用add方法,初始化thought + """ + pass + + def add_event(self): + """ + 调用add方法,初始化event + """ + pass + + def retrive(self,): + """ + 调用 + """ + pass diff --git a/examples/st_game/memory/retrive.py b/examples/st_game/memory/retrive.py new file mode 100644 index 000000000..0bca9ea8b --- /dev/null +++ b/examples/st_game/memory/retrive.py @@ -0,0 +1,138 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : Retrive函数实现 + +from numpy import dot +from numpy.linalg import norm +from datetime import datetime +from associative_memory import AgentMemory,MemoryBasic +from utils.utils import embedding_tools + +def agent_retrive(agent:AgentMemory,currtime:datetime,memory_forget:float,query:str,n:int= 30,topk:int=4) -> list[MemoryBasic]: + """ + retrive需要集合Role使用,原因在于Role才具有AgentMemory,scratch + 逻辑:Role调用该函数,self._rc.AgentMemory,self._rc.scratch.currtime,self._rc.scratch.memory_forget + 输入希望查询的内容与希望回顾的条数,返回TopK条高分记忆,即List[MemoryBasic] + + Score_lists示例 + { + "memory":memories[i], MemoryBasic类 + "importance":memories[i].poignancy + "recency":衰减因子计算结果 + "relevance":搜索结果 + } + """ + memories = AgentMemory.storage + sorted_memories = sorted(memories, key=lambda memory_node: memory_node.last_accessed_time,reverse=True) + memories = sorted_memories[:n] if len(sorted_memories) >= n else sorted_memories + + Score_list = [] + Score_list = extract_importance(memories,Score_list) + Score_list = extract_recency(currtime,memory_forget,Score_list) + Score_list = extract_relevance(query,Score_list) + Score_list = normalize_Socre_floats(Score_list,0,1) + + total_dict = {} + gw = [1,1,1] # 三个因素的权重,重要性,近因性,相关性 + for i in range(len(Score_list)): + total_score = (Score_list[i]['importance']*gw[0] + + Score_list[i]['recency']*gw[1] + + Score_list[i]['relevance']*gw[2] + ) + total_dict[Score_list[i]['memory']] = total_score + + result = top_highest_x_values(total_dict,topk) + + return result + +def top_highest_x_values(d, x): + """ + 输入字典,Topx + 返回以字典值排序,字典键组成的List[MemoryBasic] + """ + top_v = [item[0] for item in sorted(d.items(),key=lambda item: item[1],reverse= True)[:x]] + return top_v + + +def extract_importance(memories,Score_list): + """ + 抽取重要性 + """ + for i in range(len(memories)): + Score = {"memory":memories[i], + "importance":memories[i].poignancy + } + Score_list.append(Score) + return Score_list + +# 抽取相关性 +def extract_relevance(query,Score_list): + """ + 抽取相关性 + """ + query_embedding = embedding_tools(query) + # 进行 + for i in range(len(Score_list)): + result = cos_sim(Score_list[i]["memory"].embedding_key,query_embedding) + Score_list[i]['relevance'] = result + + return Score_list + +# 抽取近因性 +def extract_recency(currtime,memory_forget,Score_list): + """ + 抽取近因性,目前使用的现实世界过一天走一个衰减因子 + """ + for i in range(len(Score_list)): + day_count = (currtime-Score_list[i]['memory'].created).days + Score_list[i]['recency'] = memory_forget**day_count + return Score_list + +def cos_sim(a, b): + """ + 计算余弦相似度 + """ + return dot(a, b)/(norm(a)*norm(b)) + +def normalize_List_floats(Single_list,target_min, target_max): + """ + 单个列表归一化 + """ + min_val = min(Single_list) + max_val = max(Single_list) + range_val = max_val - min_val + + if range_val == 0: + for i in range(len(Single_list)): + Single_list[i] = (target_max - target_min)/2 + else: + for i in range(len(Single_list)): + Single_list[i] = ((Single_list[i] - min_val) * (target_max - target_min) + / range_val + target_min) + return Single_list + + +def normalize_Socre_floats(Score_list, target_min, target_max): + """ + 整体归一化 + """ + importance_list = [] + relevance_list = [] + recency_list = [] + + for i in range(len(Score_list)): + importance_list.append(Score_list[i]['importance']) + relevance_list.append(Score_list[i]['relevance']) + recency_list.append(Score_list[i]['recency']) + + # 进行归一化操作 + importance_list = normalize_List_floats(importance_list,target_min, target_max) + relevance_list = normalize_List_floats(relevance_list,target_min, target_max) + recency_list =normalize_List_floats(recency_list,target_min, target_max) + + for i in range(len(Score_list)): + Score_list[i]['importance'] = importance_list[i] + Score_list[i]['relevance'] = relevance_list[i] + Score_list[i]['recency'] = recency_list[i] + + return Score_list \ No newline at end of file diff --git a/examples/st_game/memory/scratch.py b/examples/st_game/memory/scratch.py new file mode 100644 index 000000000..00da03dd6 --- /dev/null +++ b/examples/st_game/memory/scratch.py @@ -0,0 +1,6 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : Scratch类实现(角色信息类) + +class Scratch(): + pass \ No newline at end of file diff --git a/examples/st_game/prompts/prompts_templates/poignancy_chat_v1.txt b/examples/st_game/prompts/prompts_templates/poignancy_chat_v1.txt new file mode 100644 index 000000000..572dd8a05 --- /dev/null +++ b/examples/st_game/prompts/prompts_templates/poignancy_chat_v1.txt @@ -0,0 +1,17 @@ +poignancy_chat_v1.txt + +!!: agent name +!!: iss +!!: name +!!: event description + +### +Here is a brief description of !!. +!! + +On the scale of 1 to 10, where 1 is purely mundane (e.g., routine morning greetings) and 10 is extremely poignant (e.g., a conversation about breaking up, a fight), rate the likely poignancy of the following conversation for !!. + +Conversation: +!! + +Rate (return a number between 1 to 10): \ No newline at end of file diff --git a/examples/st_game/prompts/run_gpt_prompts.py b/examples/st_game/prompts/run_gpt_prompts.py new file mode 100644 index 000000000..4c94f3bea --- /dev/null +++ b/examples/st_game/prompts/run_gpt_prompts.py @@ -0,0 +1,33 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : 调用PromptTemplates中模板,实现 + +from wrapper_prompt import special_response_generate,prompt_generate +from memory.scratch import Scratch +from memory.associative_memory import MemoryBasic +import json + +def run_gpt_prompt_chat_poignancy(scratch:Scratch,content:MemoryBasic.content)->str: + """ + 衡量事件心酸度 + """ + def create_prompt_input(scratch,content): + prompt_input = [scratch.name, + scratch.iss, + scratch.name, + content] + return prompt_input + + # 1. Prompt构建 + # 2. Instruction给出 + prompt_template = "prompt_templates/poignancy_chat_v1.txt" ######## + prompt_input = create_prompt_input(scratch, content) ######## + prompt = prompt_generate(prompt_input, prompt_template) + special_instruction = "The output should ONLY contain ONE integer value on the scale of 1 to 10." + poignancy = special_response_generate(prompt,special_instruction) + try: + poi_dict = json.loads(poignancy) + return (poi_dict['poignancy']) + except: + return poignancy + diff --git a/examples/st_game/prompts/wrapper_prompt.py b/examples/st_game/prompts/wrapper_prompt.py new file mode 100644 index 000000000..b61e13520 --- /dev/null +++ b/examples/st_game/prompts/wrapper_prompt.py @@ -0,0 +1,42 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : 基于Prmopt Templates 填充Prompt; 为Prompt包装与调用 + +from metagpt import llm + +def prompt_generate(curr_input:list, prompt_path:str): + """ + curr_input:输入一个按照PromptTemplate的要求的列表 + prompt_path:输入一个Promptpath + """ + if type(curr_input) == type("string"): + curr_input = [curr_input] + curr_input = [str(i) for i in curr_input] + + f = open(prompt_path, "r") + prompt = f.read() + f.close() + for count, i in enumerate(curr_input): + prompt = prompt.replace(f"!!", i) + if "###" in prompt: + prompt = prompt.split("###")[1] + return prompt.strip() + +def response_generate(prompt:str): + """ + 待完善,我没有找到MG中可以设置Temprature以及Maxtoken的位置 + """ + return llm.ai_func(prompt) + +def special_response_generate(prompt:str,special_instruction:str,example_output:str = None): + """ + 当对于Prompt生成有特殊要求时,调用该函数增加special_instruction或example_output + """ + prompt = '"""\n' + prompt + '\n"""\n' + prompt += f"Output the response to the prompt above in json. {special_instruction}\n" + if example_output: + prompt += "Example output json:\n" + prompt += '{"output": "' + str(example_output) + '"}' + return response_generate(prompt) + + diff --git a/examples/st_game/roles/st_role.py b/examples/st_game/roles/st_role.py index fa071f069..bc6988e28 100644 --- a/examples/st_game/roles/st_role.py +++ b/examples/st_game/roles/st_role.py @@ -17,7 +17,7 @@ from pathlib import Path from metagpt.roles.role import Role, RoleContext from metagpt.schema import Message -from ..memory.associative_memory import AssociativeMemory +from ..memory.associative_memory import AgentMemory from ..actions.dummy_action import DummyAction from ..actions.user_requirement import UserRequirement from ..maze_environment import MazeEnvironment @@ -25,7 +25,7 @@ from ..maze_environment import MazeEnvironment class STRoleContext(RoleContext): env: 'MazeEnvironment' = Field(default=None) - memory: AssociativeMemory = Field(default=AssociativeMemory) + memory: AgentMemory = Field(default=AgentMemory) class STRole(Role): diff --git a/examples/st_game/utils/utils.py b/examples/st_game/utils/utils.py index a70f7606d..11cbabd8e 100644 --- a/examples/st_game/utils/utils.py +++ b/examples/st_game/utils/utils.py @@ -4,6 +4,7 @@ from typing import Any import json +import openai from pathlib import Path @@ -22,3 +23,11 @@ def read_json_file(json_file: str, encoding=None) -> list[Any]: def write_json_file(json_file: str, data: list, encoding=None): with open(json_file, "w", encoding=encoding) as fout: json.dump(data, fout, ensure_ascii=False, indent=4) + +def embedding_tools(query): + embedding_result = openai.Embedding.create( + model="text-embedding-ada-002", + input=query + ) + embedding_key = embedding_result['data'][0]["embedding"] + return embedding_key \ No newline at end of file From 9672d8296981a34ca14b29addc963b5700145857 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=88=B1=E5=90=83=E5=B1=81=E7=9A=84=E5=B0=8F=E5=BC=A0?= <84363704+fucking-dog@users.noreply.github.com> Date: Thu, 28 Sep 2023 11:02:35 +0800 Subject: [PATCH 2/9] Update associative_memory.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 修改规范错误 --- examples/st_game/memory/associative_memory.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/examples/st_game/memory/associative_memory.py b/examples/st_game/memory/associative_memory.py index 5b78926b9..6a40b3dda 100644 --- a/examples/st_game/memory/associative_memory.py +++ b/examples/st_game/memory/associative_memory.py @@ -9,10 +9,10 @@ from datetime import datetime class MemoryBasic(Message): - def __init__(self,memory_id:str,memory_count:int,type_count:int,memory_type:str,depth:int,content:int, - creaetd:datetime,expiration:datetime, - subject:str,predicate:str,object:str, - embedding_key:str,poignancy:int,keywords:list,filling:list): + def __init__(self, memory_id:str, memory_count:int, type_count:int, memory_type:str, depth:int, content:int, + creaetd:datetime, expiration:datetime, + subject:str, predicate:str, object:str, + embedding_key:str, poignancy:int, keywords:list, filling:list): """ MemoryBasic继承于MG的Message类,其中content属性替代description属性 Message类中对于Chat类型支持的非常好,对于Agent个体的Perceive,Reflection,Plan支持的并不多 @@ -51,7 +51,7 @@ class AgentMemory(Memory): 2. Node.json (Dict Node_id:Node) 3. kw_strength.json """ - def __init__(self,memory_saved:str): + def __init__(self, memory_saved:str): """ AgentMemory类继承自Memory类,重写storage替代GA中id_to_node,一方面存储所有信息,一方面作为JSON转化 index存储与不同Agent的chat信息 From 6911ca87ab38f09c001819e75544b610fee774ed Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=88=B1=E5=90=83=E5=B1=81=E7=9A=84=E5=B0=8F=E5=BC=A0?= <84363704+fucking-dog@users.noreply.github.com> Date: Thu, 28 Sep 2023 11:07:20 +0800 Subject: [PATCH 3/9] Update and rename retrive.py to retrieve.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 修改了命名与格式问题 retrieve暂时无法放在AgentMemory类中,这个方法计划是交给Role调用的,因为需要使用到AgentMemory类与Scratch类 --- .../st_game/memory/{retrive.py => retrieve.py} | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) rename examples/st_game/memory/{retrive.py => retrieve.py} (89%) diff --git a/examples/st_game/memory/retrive.py b/examples/st_game/memory/retrieve.py similarity index 89% rename from examples/st_game/memory/retrive.py rename to examples/st_game/memory/retrieve.py index 0bca9ea8b..4119524aa 100644 --- a/examples/st_game/memory/retrive.py +++ b/examples/st_game/memory/retrieve.py @@ -8,7 +8,7 @@ from datetime import datetime from associative_memory import AgentMemory,MemoryBasic from utils.utils import embedding_tools -def agent_retrive(agent:AgentMemory,currtime:datetime,memory_forget:float,query:str,n:int= 30,topk:int=4) -> list[MemoryBasic]: +def agent_retrive(agentmemory:AgentMemory, currtime:datetime, memory_forget:float, query:str, n:int= 30, topk:int=4) -> list[MemoryBasic]: """ retrive需要集合Role使用,原因在于Role才具有AgentMemory,scratch 逻辑:Role调用该函数,self._rc.AgentMemory,self._rc.scratch.currtime,self._rc.scratch.memory_forget @@ -22,7 +22,7 @@ def agent_retrive(agent:AgentMemory,currtime:datetime,memory_forget:float,query: "relevance":搜索结果 } """ - memories = AgentMemory.storage + memories = agentmemory.storage sorted_memories = sorted(memories, key=lambda memory_node: memory_node.last_accessed_time,reverse=True) memories = sorted_memories[:n] if len(sorted_memories) >= n else sorted_memories @@ -54,7 +54,7 @@ def top_highest_x_values(d, x): return top_v -def extract_importance(memories,Score_list): +def extract_importance(memories, Score_list): """ 抽取重要性 """ @@ -66,7 +66,7 @@ def extract_importance(memories,Score_list): return Score_list # 抽取相关性 -def extract_relevance(query,Score_list): +def extract_relevance(query, Score_list): """ 抽取相关性 """ @@ -79,7 +79,7 @@ def extract_relevance(query,Score_list): return Score_list # 抽取近因性 -def extract_recency(currtime,memory_forget,Score_list): +def extract_recency(currtime, memory_forget, Score_list): """ 抽取近因性,目前使用的现实世界过一天走一个衰减因子 """ @@ -94,7 +94,7 @@ def cos_sim(a, b): """ return dot(a, b)/(norm(a)*norm(b)) -def normalize_List_floats(Single_list,target_min, target_max): +def normalize_List_floats(Single_list, target_min, target_max): """ 单个列表归一化 """ @@ -112,7 +112,7 @@ def normalize_List_floats(Single_list,target_min, target_max): return Single_list -def normalize_Socre_floats(Score_list, target_min, target_max): +def normalize_socre_floats(Score_list, target_min, target_max): """ 整体归一化 """ @@ -135,4 +135,4 @@ def normalize_Socre_floats(Score_list, target_min, target_max): Score_list[i]['relevance'] = relevance_list[i] Score_list[i]['recency'] = recency_list[i] - return Score_list \ No newline at end of file + return Score_list From c10f1306f511ff80f7d0c21f49439b0b39727a4d Mon Sep 17 00:00:00 2001 From: didi <2020201387@ruc.edu.cn> Date: Thu, 28 Sep 2023 11:18:22 +0800 Subject: [PATCH 4/9] =?UTF-8?q?=E5=AE=8C=E6=88=90=E4=BF=AE=E6=94=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 完成PR374修改 --- examples/st_game/memory/retrieve.py | 12 ++++++------ .../{prompts_templates => }/poignancy_chat_v1.txt | 0 examples/st_game/prompts/run_gpt_prompts.py | 8 ++++---- 3 files changed, 10 insertions(+), 10 deletions(-) rename examples/st_game/prompts/{prompts_templates => }/poignancy_chat_v1.txt (100%) diff --git a/examples/st_game/memory/retrieve.py b/examples/st_game/memory/retrieve.py index 4119524aa..79042df06 100644 --- a/examples/st_game/memory/retrieve.py +++ b/examples/st_game/memory/retrieve.py @@ -27,10 +27,10 @@ def agent_retrive(agentmemory:AgentMemory, currtime:datetime, memory_forget:floa memories = sorted_memories[:n] if len(sorted_memories) >= n else sorted_memories Score_list = [] - Score_list = extract_importance(memories,Score_list) - Score_list = extract_recency(currtime,memory_forget,Score_list) - Score_list = extract_relevance(query,Score_list) - Score_list = normalize_Socre_floats(Score_list,0,1) + Score_list = extract_importance(memories, Score_list) + Score_list = extract_recency(currtime, memory_forget, Score_list) + Score_list = extract_relevance(query, Score_list) + Score_list = normalize_Socre_floats(Score_list, 0, 1) total_dict = {} gw = [1,1,1] # 三个因素的权重,重要性,近因性,相关性 @@ -41,7 +41,7 @@ def agent_retrive(agentmemory:AgentMemory, currtime:datetime, memory_forget:floa ) total_dict[Score_list[i]['memory']] = total_score - result = top_highest_x_values(total_dict,topk) + result = top_highest_x_values(total_dict, topk) return result @@ -73,7 +73,7 @@ def extract_relevance(query, Score_list): query_embedding = embedding_tools(query) # 进行 for i in range(len(Score_list)): - result = cos_sim(Score_list[i]["memory"].embedding_key,query_embedding) + result = cos_sim(Score_list[i]["memory"].embedding_key, query_embedding) Score_list[i]['relevance'] = result return Score_list diff --git a/examples/st_game/prompts/prompts_templates/poignancy_chat_v1.txt b/examples/st_game/prompts/poignancy_chat_v1.txt similarity index 100% rename from examples/st_game/prompts/prompts_templates/poignancy_chat_v1.txt rename to examples/st_game/prompts/poignancy_chat_v1.txt diff --git a/examples/st_game/prompts/run_gpt_prompts.py b/examples/st_game/prompts/run_gpt_prompts.py index 4c94f3bea..86db1c7c2 100644 --- a/examples/st_game/prompts/run_gpt_prompts.py +++ b/examples/st_game/prompts/run_gpt_prompts.py @@ -7,11 +7,11 @@ from memory.scratch import Scratch from memory.associative_memory import MemoryBasic import json -def run_gpt_prompt_chat_poignancy(scratch:Scratch,content:MemoryBasic.content)->str: +def get_poignancy_action(scratch:Scratch, content:MemoryBasic.content)->str: """ 衡量事件心酸度 """ - def create_prompt_input(scratch,content): + def create_prompt_input(scratch, content): prompt_input = [scratch.name, scratch.iss, scratch.name, @@ -20,11 +20,11 @@ def run_gpt_prompt_chat_poignancy(scratch:Scratch,content:MemoryBasic.content)-> # 1. Prompt构建 # 2. Instruction给出 - prompt_template = "prompt_templates/poignancy_chat_v1.txt" ######## + prompt_template = "poignancy_chat_v1.txt" ######## prompt_input = create_prompt_input(scratch, content) ######## prompt = prompt_generate(prompt_input, prompt_template) special_instruction = "The output should ONLY contain ONE integer value on the scale of 1 to 10." - poignancy = special_response_generate(prompt,special_instruction) + poignancy = special_response_generate(prompt, special_instruction) try: poi_dict = json.loads(poignancy) return (poi_dict['poignancy']) From e7c966653e85d599a4ea21f18b49dc48e6b4ace9 Mon Sep 17 00:00:00 2001 From: didi <2020201387@ruc.edu.cn> Date: Thu, 28 Sep 2023 11:54:11 +0800 Subject: [PATCH 5/9] =?UTF-8?q?format=E4=BF=AE=E6=94=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 完成 --- examples/st_game/memory/associative_memory.py | 30 ++-- examples/st_game/memory/retrieve.py | 152 +++++++++--------- examples/st_game/prompts/run_gpt_prompts.py | 20 +-- examples/st_game/prompts/wrapper_prompt.py | 52 +++--- 4 files changed, 134 insertions(+), 120 deletions(-) diff --git a/examples/st_game/memory/associative_memory.py b/examples/st_game/memory/associative_memory.py index 6a40b3dda..c771906ec 100644 --- a/examples/st_game/memory/associative_memory.py +++ b/examples/st_game/memory/associative_memory.py @@ -7,12 +7,13 @@ from metagpt.schema import Message import json from datetime import datetime + class MemoryBasic(Message): - def __init__(self, memory_id:str, memory_count:int, type_count:int, memory_type:str, depth:int, content:int, - creaetd:datetime, expiration:datetime, - subject:str, predicate:str, object:str, - embedding_key:str, poignancy:int, keywords:list, filling:list): + def __init__(self, memory_id: str, memory_count: int, type_count: int, memory_type: str, depth: int, content: int, + creaetd: datetime, expiration: datetime, + subject: str, predicate: str, object: str, + embedding_key: str, poignancy: int, keywords: list, filling: list): """ MemoryBasic继承于MG的Message类,其中content属性替代description属性 Message类中对于Chat类型支持的非常好,对于Agent个体的Perceive,Reflection,Plan支持的并不多 @@ -29,29 +30,30 @@ class MemoryBasic(Message): self.memory_count: int = memory_count # 第几个记忆,实际数值与Memory相等,但是类型为整数 self.type_count: int = type_count # 第几种记忆,类型为整数(具体不太理解如何生成的) self.memory_type: str = memory_type # 记忆类型,使用Field,包含 event,thought,chat三种类型 - self.depth:str = depth # 记忆深度,类型为整数 + self.depth: str = depth # 记忆深度,类型为整数 self.created: datetime = creaetd # 创建时间 self.expiration: datetime = expiration # 记忆失效时间,默认为空() self.last_accessed: datetime = creaetd # 上一次调用的时间,初始化时候与self.created一致 self.subject: str = subject # 主语,str类型 - self.predicate:str = predicate # 谓语,str类型 - self.object:str = object # 宾语,str类型 + self.predicate: str = predicate # 谓语,str类型 + self.object: str = object # 宾语,str类型 self.embedding_key: str = embedding_key # 内容与self.content一致 - self.poignancy:int = poignancy # importance值,整数类型 - self.keywords:list = keywords # keywords,列表 - self.filling:list = filling # None或者列表 + self.poignancy: int = poignancy # importance值,整数类型 + self.keywords: list = keywords # keywords,列表 + self.filling: list = filling # None或者列表 + class AgentMemory(Memory): """ GA中主要存储三种JSON 1. embedding.json (Dict embedding_key:embedding) - 2. Node.json (Dict Node_id:Node) - 3. kw_strength.json + 2. Node.json (Dict Node_id:Node) + 3. kw_strength.json """ - def __init__(self, memory_saved:str): + def __init__(self, memory_saved: str): """ AgentMemory类继承自Memory类,重写storage替代GA中id_to_node,一方面存储所有信息,一方面作为JSON转化 index存储与不同Agent的chat信息 @@ -61,7 +63,7 @@ class AgentMemory(Memory): self.storage: list[MemoryBasic] = [] # 重写Stroage,存储MemoryBasic所有节点 self.event_list = [] # 存储event记忆 self.thought_list = [] # 存储thought记忆 - + self.event_keywords = dict() # 存储keywords self.thought_keywords = dict() self.chat_keywords = dict() diff --git a/examples/st_game/memory/retrieve.py b/examples/st_game/memory/retrieve.py index 79042df06..5ac4a9b29 100644 --- a/examples/st_game/memory/retrieve.py +++ b/examples/st_game/memory/retrieve.py @@ -1,118 +1,122 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- -# @Desc : Retrive函数实现 +# @Desc : Retrieve函数实现 +import datetime from numpy import dot from numpy.linalg import norm -from datetime import datetime -from associative_memory import AgentMemory,MemoryBasic +from associative_memory import AgentMemory, MemoryBasic from utils.utils import embedding_tools -def agent_retrive(agentmemory:AgentMemory, currtime:datetime, memory_forget:float, query:str, n:int= 30, topk:int=4) -> list[MemoryBasic]: + +def agent_retrieve(agent_memory: AgentMemory, curr_time: datetime.datetime, memory_forget: float, query: str, n: int = 30, topk: int = 4) -> list[MemoryBasic]: """ - retrive需要集合Role使用,原因在于Role才具有AgentMemory,scratch - 逻辑:Role调用该函数,self._rc.AgentMemory,self._rc.scratch.currtime,self._rc.scratch.memory_forget + Retrieve需要集合Role使用,原因在于Role才具有AgentMemory,scratch + 逻辑:Role调用该函数,self._rc.AgentMemory,self._rc.scratch.curr_time,self._rc.scratch.memory_forget 输入希望查询的内容与希望回顾的条数,返回TopK条高分记忆,即List[MemoryBasic] Score_lists示例 { - "memory":memories[i], MemoryBasic类 - "importance":memories[i].poignancy - "recency":衰减因子计算结果 - "relevance":搜索结果 + "memory": memories[i], MemoryBasic类 + "importance": memories[i].poignancy + "recency": 衰减因子计算结果 + "relevance": 搜索结果 } """ - memories = agentmemory.storage - sorted_memories = sorted(memories, key=lambda memory_node: memory_node.last_accessed_time,reverse=True) + memories = agent_memory.storage + sorted_memories = sorted(memories, key=lambda memory_node: memory_node.last_accessed_time, reverse=True) memories = sorted_memories[:n] if len(sorted_memories) >= n else sorted_memories - Score_list = [] - Score_list = extract_importance(memories, Score_list) - Score_list = extract_recency(currtime, memory_forget, Score_list) - Score_list = extract_relevance(query, Score_list) - Score_list = normalize_Socre_floats(Score_list, 0, 1) + score_list = [] + score_list = extract_importance(memories, score_list) + score_list = extract_recency(curr_time, memory_forget, score_list) + score_list = extract_relevance(query, score_list) + score_list = normalize_score_floats(score_list, 0, 1) + + total_dict = {} + gw = [1, 1, 1] # 三个因素的权重,重要性,近因性,相关性 + for i in range(len(score_list)): + total_score = (score_list[i]['importance'] * gw[0] + + score_list[i]['recency'] * gw[1] + + score_list[i]['relevance'] * gw[2] + ) + total_dict[score_list[i]['memory']] = total_score - total_dict = {} - gw = [1,1,1] # 三个因素的权重,重要性,近因性,相关性 - for i in range(len(Score_list)): - total_score = (Score_list[i]['importance']*gw[0] + - Score_list[i]['recency']*gw[1] + - Score_list[i]['relevance']*gw[2] - ) - total_dict[Score_list[i]['memory']] = total_score - result = top_highest_x_values(total_dict, topk) return result + def top_highest_x_values(d, x): """ 输入字典,Topx 返回以字典值排序,字典键组成的List[MemoryBasic] """ - top_v = [item[0] for item in sorted(d.items(),key=lambda item: item[1],reverse= True)[:x]] + top_v = [item[0] for item in sorted(d.items(), key=lambda item: item[1], reverse=True)[:x]] return top_v -def extract_importance(memories, Score_list): +def extract_importance(memories, score_list): """ 抽取重要性 """ for i in range(len(memories)): - Score = {"memory":memories[i], - "importance":memories[i].poignancy + score = {"memory": memories[i], + "importance": memories[i].poignancy } - Score_list.append(Score) - return Score_list + score_list.append(score) + return score_list -# 抽取相关性 -def extract_relevance(query, Score_list): + +def extract_relevance(query, score_list): """ 抽取相关性 """ query_embedding = embedding_tools(query) # 进行 - for i in range(len(Score_list)): - result = cos_sim(Score_list[i]["memory"].embedding_key, query_embedding) - Score_list[i]['relevance'] = result + for i in range(len(score_list)): + result = cos_sim(score_list[i]["memory"].embedding_key, query_embedding) + score_list[i]['relevance'] = result - return Score_list + return score_list -# 抽取近因性 -def extract_recency(currtime, memory_forget, Score_list): + +def extract_recency(curr_time, memory_forget, score_list): """ 抽取近因性,目前使用的现实世界过一天走一个衰减因子 """ - for i in range(len(Score_list)): - day_count = (currtime-Score_list[i]['memory'].created).days - Score_list[i]['recency'] = memory_forget**day_count - return Score_list + for i in range(len(score_list)): + day_count = (curr_time - score_list[i]['memory'].created).days + score_list[i]['recency'] = memory_forget**day_count + return score_list -def cos_sim(a, b): - """ - 计算余弦相似度 - """ - return dot(a, b)/(norm(a)*norm(b)) -def normalize_List_floats(Single_list, target_min, target_max): +def cos_sim(a, b): + """ + 计算余弦相似度 + """ + return dot(a, b) / (norm(a) * norm(b)) + + +def normalize_list_floats(single_list, target_min, target_max): """ 单个列表归一化 """ - min_val = min(Single_list) - max_val = max(Single_list) + min_val = min(single_list) + max_val = max(single_list) range_val = max_val - min_val - if range_val == 0: - for i in range(len(Single_list)): - Single_list[i] = (target_max - target_min)/2 - else: - for i in range(len(Single_list)): - Single_list[i] = ((Single_list[i] - min_val) * (target_max - target_min) - / range_val + target_min) - return Single_list + if range_val == 0: + for i in range(len(single_list)): + single_list[i] = (target_max - target_min) / 2 + else: + for i in range(len(single_list)): + single_list[i] = ((single_list[i] - min_val) * (target_max - target_min) + / range_val + target_min) + return single_list -def normalize_socre_floats(Score_list, target_min, target_max): +def normalize_score_floats(score_list, target_min, target_max): """ 整体归一化 """ @@ -120,19 +124,19 @@ def normalize_socre_floats(Score_list, target_min, target_max): relevance_list = [] recency_list = [] - for i in range(len(Score_list)): - importance_list.append(Score_list[i]['importance']) - relevance_list.append(Score_list[i]['relevance']) - recency_list.append(Score_list[i]['recency']) + for i in range(len(score_list)): + importance_list.append(score_list[i]['importance']) + relevance_list.append(score_list[i]['relevance']) + recency_list.append(score_list[i]['recency']) # 进行归一化操作 - importance_list = normalize_List_floats(importance_list,target_min, target_max) - relevance_list = normalize_List_floats(relevance_list,target_min, target_max) - recency_list =normalize_List_floats(recency_list,target_min, target_max) + importance_list = normalize_list_floats(importance_list, target_min, target_max) + relevance_list = normalize_list_floats(relevance_list, target_min, target_max) + recency_list = normalize_list_floats(recency_list, target_min, target_max) - for i in range(len(Score_list)): - Score_list[i]['importance'] = importance_list[i] - Score_list[i]['relevance'] = relevance_list[i] - Score_list[i]['recency'] = recency_list[i] - - return Score_list + for i in range(len(score_list)): + score_list[i]['importance'] = importance_list[i] + score_list[i]['relevance'] = relevance_list[i] + score_list[i]['recency'] = recency_list[i] + + return score_list diff --git a/examples/st_game/prompts/run_gpt_prompts.py b/examples/st_game/prompts/run_gpt_prompts.py index 86db1c7c2..16ccbc29c 100644 --- a/examples/st_game/prompts/run_gpt_prompts.py +++ b/examples/st_game/prompts/run_gpt_prompts.py @@ -1,18 +1,19 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- -# @Desc : 调用PromptTemplates中模板,实现 +# @Desc : 调用Prompts中模板,实现相关Action -from wrapper_prompt import special_response_generate,prompt_generate +from wrapper_prompt import special_response_generate, prompt_generate from memory.scratch import Scratch from memory.associative_memory import MemoryBasic import json -def get_poignancy_action(scratch:Scratch, content:MemoryBasic.content)->str: + +def get_poignancy_action(scratch: Scratch, content: MemoryBasic.content) -> str: """ 衡量事件心酸度 """ - def create_prompt_input(scratch, content): - prompt_input = [scratch.name, + def create_prompt_input(scratch, content): + prompt_input = [scratch.name, scratch.iss, scratch.name, content] @@ -20,14 +21,13 @@ def get_poignancy_action(scratch:Scratch, content:MemoryBasic.content)->str: # 1. Prompt构建 # 2. Instruction给出 - prompt_template = "poignancy_chat_v1.txt" ######## - prompt_input = create_prompt_input(scratch, content) ######## + prompt_template = "poignancy_chat_v1.txt" # 保留原来的注释 + prompt_input = create_prompt_input(scratch, content) # 保留原来的注释 prompt = prompt_generate(prompt_input, prompt_template) special_instruction = "The output should ONLY contain ONE integer value on the scale of 1 to 10." poignancy = special_response_generate(prompt, special_instruction) try: poi_dict = json.loads(poignancy) - return (poi_dict['poignancy']) - except: + return str(poi_dict['poignancy']) # 将返回值强制转换为字符串 + except json.JSONDecodeError as e: return poignancy - diff --git a/examples/st_game/prompts/wrapper_prompt.py b/examples/st_game/prompts/wrapper_prompt.py index b61e13520..0950f99d1 100644 --- a/examples/st_game/prompts/wrapper_prompt.py +++ b/examples/st_game/prompts/wrapper_prompt.py @@ -1,42 +1,50 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- -# @Desc : 基于Prmopt Templates 填充Prompt; 为Prompt包装与调用 +# @Desc : 基于Prompt Templates 填充Prompt; 为Prompt包装与调用 from metagpt import llm -def prompt_generate(curr_input:list, prompt_path:str): - """ - curr_input:输入一个按照PromptTemplate的要求的列表 - prompt_path:输入一个Promptpath - """ - if type(curr_input) == type("string"): - curr_input = [curr_input] - curr_input = [str(i) for i in curr_input] - f = open(prompt_path, "r") - prompt = f.read() - f.close() - for count, i in enumerate(curr_input): +def prompt_generate(curr_input: list, prompt_path: str): + """ + curr_input: 输入一个按照Prompt Template的要求的列表 + prompt_path: 输入一个Prompt path + """ + # 如果输入是字符串,将其转换为列表 + if isinstance(curr_input, str): + curr_input = [curr_input] + + # 将输入列表中的每个元素转换为字符串 + curr_input = [str(i) for i in curr_input] + + with open(prompt_path, "r") as f: + prompt = f.read() + + for count, i in enumerate(curr_input): prompt = prompt.replace(f"!!", i) - if "###" in prompt: + + if "###" in prompt: prompt = prompt.split("###")[1] + return prompt.strip() -def response_generate(prompt:str): + +def response_generate(prompt: str): """ - 待完善,我没有找到MG中可以设置Temprature以及Maxtoken的位置 + 待完善,我没有找到MG中可以设置Temperature以及Maxtoken的位置 """ return llm.ai_func(prompt) -def special_response_generate(prompt:str,special_instruction:str,example_output:str = None): + +def special_response_generate(prompt: str, special_instruction: str, example_output: str = None): """ 当对于Prompt生成有特殊要求时,调用该函数增加special_instruction或example_output """ prompt = '"""\n' + prompt + '\n"""\n' - prompt += f"Output the response to the prompt above in json. {special_instruction}\n" - if example_output: - prompt += "Example output json:\n" + prompt += f"Output the response to the prompt above in JSON. {special_instruction}\n" + + if example_output: + prompt += "Example output JSON:\n" prompt += '{"output": "' + str(example_output) + '"}' + return response_generate(prompt) - - From 3b42ab42a0fd48d53581d59846330eea18f05c9e Mon Sep 17 00:00:00 2001 From: didi <2020201387@ruc.edu.cn> Date: Thu, 28 Sep 2023 21:54:54 +0800 Subject: [PATCH 6/9] =?UTF-8?q?=E5=B7=B2=E5=AE=8C=E6=88=90=E6=B3=A8?= =?UTF-8?q?=E9=87=8A=E4=BF=AE=E6=94=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ...{associative_memory.py => agent_memory.py} | 36 +++++++++---------- examples/st_game/memory/retrieve.py | 12 +++---- examples/st_game/prompts/run_gpt_prompts.py | 4 +-- examples/st_game/roles/st_role.py | 2 +- 4 files changed, 27 insertions(+), 27 deletions(-) rename examples/st_game/memory/{associative_memory.py => agent_memory.py} (78%) diff --git a/examples/st_game/memory/associative_memory.py b/examples/st_game/memory/agent_memory.py similarity index 78% rename from examples/st_game/memory/associative_memory.py rename to examples/st_game/memory/agent_memory.py index c771906ec..adb8a5f1f 100644 --- a/examples/st_game/memory/associative_memory.py +++ b/examples/st_game/memory/agent_memory.py @@ -1,6 +1,6 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- -# @Desc : MemoryBasic,AgentMemory实现 +# @Desc : BasicMemory,AgentMemory实现 from metagpt.memory.memory import Memory from metagpt.schema import Message @@ -8,14 +8,14 @@ import json from datetime import datetime -class MemoryBasic(Message): +class BasicMemory(Message): def __init__(self, memory_id: str, memory_count: int, type_count: int, memory_type: str, depth: int, content: int, creaetd: datetime, expiration: datetime, subject: str, predicate: str, object: str, embedding_key: str, poignancy: int, keywords: list, filling: list): """ - MemoryBasic继承于MG的Message类,其中content属性替代description属性 + BasicMemory继承于MG的Message类,其中content属性替代description属性 Message类中对于Chat类型支持的非常好,对于Agent个体的Perceive,Reflection,Plan支持的并不多 在Type设计上,我们延续GA的三个种类,但是对于Chat种类的对话进行特别设计(具体怎么设计还没想好) """ @@ -27,23 +27,23 @@ class MemoryBasic(Message): cause_by 接受一个Action类,在此项目中,每个Agent需要有一个基础动作[Receive] 用于接受假对话Message;而每个Agent需要有独一无二的动作类,用以接受真对话Message """ self.memory_id: str = memory_id # 记忆ID - self.memory_count: int = memory_count # 第几个记忆,实际数值与Memory相等,但是类型为整数 + self.memory_count: int = memory_count # 第几个记忆,实际数值与Memory相等 self.type_count: int = type_count # 第几种记忆,类型为整数(具体不太理解如何生成的) - self.memory_type: str = memory_type # 记忆类型,使用Field,包含 event,thought,chat三种类型 + self.memory_type: str = memory_type # 记忆类型,包含 event,thought,chat三种类型 self.depth: str = depth # 记忆深度,类型为整数 self.created: datetime = creaetd # 创建时间 self.expiration: datetime = expiration # 记忆失效时间,默认为空() self.last_accessed: datetime = creaetd # 上一次调用的时间,初始化时候与self.created一致 - self.subject: str = subject # 主语,str类型 - self.predicate: str = predicate # 谓语,str类型 - self.object: str = object # 宾语,str类型 + self.subject: str = subject # 主语 + self.predicate: str = predicate # 谓语 + self.object: str = object # 宾语 self.embedding_key: str = embedding_key # 内容与self.content一致 - self.poignancy: int = poignancy # importance值,整数类型 - self.keywords: list = keywords # keywords,列表 - self.filling: list = filling # None或者列表 + self.poignancy: int = poignancy # importance值 + self.keywords: list = keywords # keywords + self.filling: list = filling # None或者列表 class AgentMemory(Memory): @@ -60,7 +60,7 @@ class AgentMemory(Memory): @李嵩@张凯 这里的storage是List,你们需要写一个JSON转化器,将List修改为node.json一致的格式 """ super.__init__() - self.storage: list[MemoryBasic] = [] # 重写Stroage,存储MemoryBasic所有节点 + self.storage: list[BasicMemory] = [] # 重写Stroage,存储BasicMemory所有节点 self.event_list = [] # 存储event记忆 self.thought_list = [] # 存储thought记忆 @@ -71,27 +71,27 @@ class AgentMemory(Memory): self.strength_event_keywords = dict() # 不知道具体作用,所以没有删除 self.strength_thought_keywords = dict() - self.embeddings = json.load(open(memory_saved + "/embeddings.json")) # dict类型,load embedding.json - self.memory_load() + self.embeddings = json.load(open(memory_saved + "/embeddings.json")) + self.load() - def memory_save(self): + def save(self): """ 将MemormyBasic类存储为Nodes.json形式。复现GA中的Kw Strength.json形式 @张凯补充一个可调用的函数 """ pass - def memory_load(self): + def load(self): """ 将GA的JSON解析,填充到AgentMemory类之中 """ pass - def add(self, memory_basic: MemoryBasic): + def add(self, memory_basic: BasicMemory): """ Add a new message to storage, while updating the index - 重写add方法,修改原有的Message类为MemoryBasic类,并添加不同的记忆类型添加方式 + 重写add方法,修改原有的Message类为BasicMemory类,并添加不同的记忆类型添加方式 """ if memory_basic in self.storage: return diff --git a/examples/st_game/memory/retrieve.py b/examples/st_game/memory/retrieve.py index 5ac4a9b29..97eb3b6f0 100644 --- a/examples/st_game/memory/retrieve.py +++ b/examples/st_game/memory/retrieve.py @@ -5,19 +5,19 @@ import datetime from numpy import dot from numpy.linalg import norm -from associative_memory import AgentMemory, MemoryBasic +from examples.st_game.memory.agent_memory import AgentMemory, BasicMemory from utils.utils import embedding_tools -def agent_retrieve(agent_memory: AgentMemory, curr_time: datetime.datetime, memory_forget: float, query: str, n: int = 30, topk: int = 4) -> list[MemoryBasic]: +def agent_retrieve(agent_memory: AgentMemory, curr_time: datetime.datetime, memory_forget: float, query: str, n: int = 30, topk: int = 4) -> list[BasicMemory]: """ Retrieve需要集合Role使用,原因在于Role才具有AgentMemory,scratch 逻辑:Role调用该函数,self._rc.AgentMemory,self._rc.scratch.curr_time,self._rc.scratch.memory_forget - 输入希望查询的内容与希望回顾的条数,返回TopK条高分记忆,即List[MemoryBasic] + 输入希望查询的内容与希望回顾的条数,返回TopK条高分记忆,即List[BasicMemory] Score_lists示例 { - "memory": memories[i], MemoryBasic类 + "memory": memories[i], BasicMemory类 "importance": memories[i].poignancy "recency": 衰减因子计算结果 "relevance": 搜索结果 @@ -34,7 +34,7 @@ def agent_retrieve(agent_memory: AgentMemory, curr_time: datetime.datetime, memo score_list = normalize_score_floats(score_list, 0, 1) total_dict = {} - gw = [1, 1, 1] # 三个因素的权重,重要性,近因性,相关性 + gw = [1, 1, 1] # 三个因素的权重,重要性,近因性,相关性 for i in range(len(score_list)): total_score = (score_list[i]['importance'] * gw[0] + score_list[i]['recency'] * gw[1] + @@ -50,7 +50,7 @@ def agent_retrieve(agent_memory: AgentMemory, curr_time: datetime.datetime, memo def top_highest_x_values(d, x): """ 输入字典,Topx - 返回以字典值排序,字典键组成的List[MemoryBasic] + 返回以字典值排序,字典键组成的List[BasicMemory] """ top_v = [item[0] for item in sorted(d.items(), key=lambda item: item[1], reverse=True)[:x]] return top_v diff --git a/examples/st_game/prompts/run_gpt_prompts.py b/examples/st_game/prompts/run_gpt_prompts.py index 16ccbc29c..14b699c15 100644 --- a/examples/st_game/prompts/run_gpt_prompts.py +++ b/examples/st_game/prompts/run_gpt_prompts.py @@ -4,11 +4,11 @@ from wrapper_prompt import special_response_generate, prompt_generate from memory.scratch import Scratch -from memory.associative_memory import MemoryBasic +from examples.st_game.memory.agent_memory import BasicMemory import json -def get_poignancy_action(scratch: Scratch, content: MemoryBasic.content) -> str: +def get_poignancy_action(scratch: Scratch, content: BasicMemory.content) -> str: """ 衡量事件心酸度 """ diff --git a/examples/st_game/roles/st_role.py b/examples/st_game/roles/st_role.py index bc6988e28..047f9545b 100644 --- a/examples/st_game/roles/st_role.py +++ b/examples/st_game/roles/st_role.py @@ -17,7 +17,7 @@ from pathlib import Path from metagpt.roles.role import Role, RoleContext from metagpt.schema import Message -from ..memory.associative_memory import AgentMemory +from ..memory.agent_memory import AgentMemory from ..actions.dummy_action import DummyAction from ..actions.user_requirement import UserRequirement from ..maze_environment import MazeEnvironment From 0f8f4fba5bdc2f959b23aa1e5e2195fa35318d48 Mon Sep 17 00:00:00 2001 From: didi <2020201387@ruc.edu.cn> Date: Thu, 28 Sep 2023 22:53:45 +0800 Subject: [PATCH 7/9] =?UTF-8?q?=E5=B0=8F=E9=94=99=E8=AF=AF?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- examples/st_game/memory/agent_memory.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/st_game/memory/agent_memory.py b/examples/st_game/memory/agent_memory.py index adb8a5f1f..db7ff80b8 100644 --- a/examples/st_game/memory/agent_memory.py +++ b/examples/st_game/memory/agent_memory.py @@ -11,7 +11,7 @@ from datetime import datetime class BasicMemory(Message): def __init__(self, memory_id: str, memory_count: int, type_count: int, memory_type: str, depth: int, content: int, - creaetd: datetime, expiration: datetime, + created: datetime, expiration: datetime, subject: str, predicate: str, object: str, embedding_key: str, poignancy: int, keywords: list, filling: list): """ @@ -32,9 +32,9 @@ class BasicMemory(Message): self.memory_type: str = memory_type # 记忆类型,包含 event,thought,chat三种类型 self.depth: str = depth # 记忆深度,类型为整数 - self.created: datetime = creaetd # 创建时间 + self.created: datetime = created # 创建时间 self.expiration: datetime = expiration # 记忆失效时间,默认为空() - self.last_accessed: datetime = creaetd # 上一次调用的时间,初始化时候与self.created一致 + self.last_accessed: datetime = created # 上一次调用的时间,初始化时候与self.created一致 self.subject: str = subject # 主语 self.predicate: str = predicate # 谓语 From 13e75bab8e84046c5281d9f50bfd060f1eeba18a Mon Sep 17 00:00:00 2001 From: didi <2020201387@ruc.edu.cn> Date: Thu, 28 Sep 2023 23:10:44 +0800 Subject: [PATCH 8/9] Update agent_memory.py --- examples/st_game/memory/agent_memory.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/examples/st_game/memory/agent_memory.py b/examples/st_game/memory/agent_memory.py index db7ff80b8..ecf6190d4 100644 --- a/examples/st_game/memory/agent_memory.py +++ b/examples/st_game/memory/agent_memory.py @@ -10,7 +10,7 @@ from datetime import datetime class BasicMemory(Message): - def __init__(self, memory_id: str, memory_count: int, type_count: int, memory_type: str, depth: int, content: int, + def __init__(self, memory_id: str, memory_count: int, type_count: int, memory_type: str, depth: int, content: str, created: datetime, expiration: datetime, subject: str, predicate: str, object: str, embedding_key: str, poignancy: int, keywords: list, filling: list): @@ -19,7 +19,7 @@ class BasicMemory(Message): Message类中对于Chat类型支持的非常好,对于Agent个体的Perceive,Reflection,Plan支持的并不多 在Type设计上,我们延续GA的三个种类,但是对于Chat种类的对话进行特别设计(具体怎么设计还没想好) """ - super.__init__(content) + super().__init__(content) """ 从父类中继承的属性 content: str # 记忆描述 @@ -128,3 +128,6 @@ class AgentMemory(Memory): 调用 """ pass + +if __name__ == "__main__": + \ No newline at end of file From e035706091ebef6626ca38265c47200e3a42cc5e Mon Sep 17 00:00:00 2001 From: didi <2020201387@ruc.edu.cn> Date: Fri, 29 Sep 2023 17:49:31 +0800 Subject: [PATCH 9/9] =?UTF-8?q?9.29=20=E6=9B=B4=E6=96=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 完善了一下AgentMemory的Add方法(反思与Plan中都要用),save load方法(需要跟组员对修改新的一版);添加了Scratch类属性与方法(给了一版文档介绍);修改了STrole中STrolecontext属性,添加了Strole中Retrieve方法 --- examples/st_game/memory/agent_memory.py | 243 +++++++++-- examples/st_game/memory/scratch.py | 535 +++++++++++++++++++++++- examples/st_game/roles/st_role.py | 11 +- examples/st_game/utils/check.py | 14 + 4 files changed, 771 insertions(+), 32 deletions(-) create mode 100644 examples/st_game/utils/check.py diff --git a/examples/st_game/memory/agent_memory.py b/examples/st_game/memory/agent_memory.py index ecf6190d4..a56100ee7 100644 --- a/examples/st_game/memory/agent_memory.py +++ b/examples/st_game/memory/agent_memory.py @@ -10,16 +10,17 @@ from datetime import datetime class BasicMemory(Message): - def __init__(self, memory_id: str, memory_count: int, type_count: int, memory_type: str, depth: int, content: str, + def __init__(self, memory_id: str, memory_count: int, type_count: int, memory_type: str, depth: int, created: datetime, expiration: datetime, subject: str, predicate: str, object: str, - embedding_key: str, poignancy: int, keywords: list, filling: list): + content: str, embedding_key: str, poignancy: int, keywords: list, filling: list, + cause_by = ""): """ BasicMemory继承于MG的Message类,其中content属性替代description属性 Message类中对于Chat类型支持的非常好,对于Agent个体的Perceive,Reflection,Plan支持的并不多 在Type设计上,我们延续GA的三个种类,但是对于Chat种类的对话进行特别设计(具体怎么设计还没想好) """ - super().__init__(content) + super().__init__(content,cause_by=cause_by) """ 从父类中继承的属性 content: str # 记忆描述 @@ -43,8 +44,41 @@ class BasicMemory(Message): self.embedding_key: str = embedding_key # 内容与self.content一致 self.poignancy: int = poignancy # importance值 self.keywords: list = keywords # keywords - self.filling: list = filling # None或者列表 + self.filling: list = filling # 装的与之相关联的memory_id的列表 + def save_to_dict(self) -> dict: + """ + 将MemoryBasic类转化为字典,用于存储json文件 + 这里需要注意,cause_by跟GA不兼容,所以需要做一个格式转换 + """ + memory_dict = dict() + node_id = self.memory_id + + memory_dict[node_id] = dict() + memory_dict[node_id]["node_count"] = self.memory_count + memory_dict[node_id]["type_count"] = self.type_count + memory_dict[node_id]["type"] = self.type + memory_dict[node_id]["depth"] = self.depth + + memory_dict[node_id]["cmemory_dicteated"] = self.created.strftime('%Y-%m-%d %H:%M:%S') + memory_dict[node_id]["expiration"] = None + if self.expiration: + memory_dict[node_id]["expiration"] = (self.expiration + .strftime('%Y-%m-%d %H:%M:%S')) + + memory_dict[node_id]["subject"] = self.subject + memory_dict[node_id]["predicate"] = self.predicate + memory_dict[node_id]["object"] = self.object + + memory_dict[node_id]["description"] = self.description + memory_dict[node_id]["embedding_key"] = self.embedding_key + memory_dict[node_id]["poignancy"] = self.poignancy + memory_dict[node_id]["keywords"] = list(self.keywords) + memory_dict[node_id]["filling"] = self.filling + if self.cause_by: + memory_dict[node_id]["cause_by"] = self.cause_by + + return memory_dict class AgentMemory(Memory): """ @@ -68,25 +102,82 @@ class AgentMemory(Memory): self.thought_keywords = dict() self.chat_keywords = dict() - self.strength_event_keywords = dict() # 不知道具体作用,所以没有删除 - self.strength_thought_keywords = dict() + self.kw_strength_event = dict() # 关键词影响存储 + self.kw_strength_thought = dict() - self.embeddings = json.load(open(memory_saved + "/embeddings.json")) - self.load() + self.load(memory_saved) - def save(self): + def save(self,memory_saved:str): """ 将MemormyBasic类存储为Nodes.json形式。复现GA中的Kw Strength.json形式 - @张凯补充一个可调用的函数 + 这里添加一个路径即可 """ - pass - def load(self): + memory_json = dict() + for i in range(len(self.storage)): + memory_node = self.storage[i] + memory_json.update(memory_node) + with open(memory_saved+"/nodes.json", "w") as outfile: + json.dump(memory_json, outfile) + + with open(memory_saved+"/embeddings.json", "w") as outfile: + json.dump(self.embeddings, outfile) + + strength_json = dict() + strength_json["kw_strength_event"] = self.kw_strength_event + strength_json["kw_strength_thought"] = self.kw_strength_thought + with open(memory_saved+"/kw_strength.json", "w") as outfile: + json.dump(strength_json, outfile) + + + def load(self,memory_saved:str): """ 将GA的JSON解析,填充到AgentMemory类之中 """ - pass + self.embeddings = json.load(open(memory_saved + "/embeddings.json")) + memory_load = json.load(open(memory_saved + "/nodes.json")) + for count in range(len(memory_load.keys())): + node_id = f"node_{str(count+1)}" + node_details = memory_load[node_id] + node_type = node_details["type"] + created = datetime.datetime.strptime(node_details["created"], + '%Y-%m-%d %H:%M:%S') + expiration = None + if node_details["expiration"]: + expiration = datetime.datetime.strptime(node_details["expiration"], + '%Y-%m-%d %H:%M:%S') + + if node_details["cause_by"]: + cause_by = node_details["cause_by"] + + s = node_details["subject"] + p = node_details["predicate"] + o = node_details["object"] + + description = node_details["description"] + embedding_pair = (node_details["embedding_key"], + self.embeddings[node_details["embedding_key"]]) + poignancy =node_details["poignancy"] + keywords = set(node_details["keywords"]) + filling = node_details["filling"] + + if node_type == "event": + self.add_event(created, expiration, s, p, o, + description, keywords, poignancy, embedding_pair, filling) + elif node_type == "chat": + self.add_chat(created, expiration, s, p, o, + description, keywords, poignancy, embedding_pair, filling,cause_by) + elif node_type == "thought": + self.add_thought(created, expiration, s, p, o, + description, keywords, poignancy, embedding_pair, filling) + + strength_keywords_load = json.load(open(memory_saved + "/kw_strength.json")) + if strength_keywords_load["kw_strength_event"]: + self.kw_strength_event = strength_keywords_load["kw_strength_event"] + if strength_keywords_load["kw_strength_thought"]: + self.kw_strength_thought = strength_keywords_load["kw_strength_thought"] + def add(self, memory_basic: BasicMemory): """ @@ -97,37 +188,131 @@ class AgentMemory(Memory): return self.storage.append(memory_basic) if memory_basic.cause_by: - self.index[memory_basic.cause_by].append(memory_basic) + self.index[memory_basic.cause_by][0:0] = [memory_basic] return if memory_basic.type == "thought": - self.thought_list.append(memory_basic) + self.thought_list[0:0] = [memory_basic] return if memory_basic.type == "event": - self.event_list.append(memory_basic) + self.event_list[0:0] = [memory_basic] - def add_chat(self): + + def add_chat(self, created, expiration, s, p, o, + content, keywords, poignancy, + embedding_pair, filling, + cause_by): """ 调用add方法,初始化chat,在创建的时候就需要调用embeeding函数 """ - pass + memory_count = len(self.storage) + 1 + type_count = len(self.thought_list) + 1 + memory_type = "chat" + memory_id = f"memory_{str(memory_count)}" + depth = 1 - def add_thought(self): + memory_node = BasicMemory(memory_id, memory_count, type_count, memory_type, depth, + created, expiration, + s, p ,o, + content, embedding_pair[0], + poignancy, keywords, filling, + cause_by) + + keywords = [i.lower() for i in keywords] + for kw in keywords: + if kw in self.chat_keywords: + self.chat_keywords[kw][0:0] = [memory_node] + else: + self.chat_keywords[kw] = [memory_node] + + self.add(memory_node) + + self.embeddings[embedding_pair[0]] = embedding_pair[1] + return memory_node + + + def add_thought(self, created, expiration, s, p, o, + content, keywords, poignancy, + embedding_pair, filling): """ 调用add方法,初始化thought """ - pass + memory_count = len(self.storage) + 1 + type_count = len(self.thought_list) + 1 + memory_type = "event" + memory_id = f"memory_{str(memory_count)}" + depth = 1 + + try: + if filling: + depth_list = [memory_node.depth for memory_node in self.storage if memory_node.memory_id in filling ] + depth += max(depth_list) + except: + pass - def add_event(self): + memory_node = BasicMemory(memory_id, memory_count, type_count, memory_type, depth, + created, expiration, + s, p ,o, + content, embedding_pair[0], + poignancy, keywords, filling) + + keywords = [i.lower() for i in keywords] + for kw in keywords: + if kw in self.thought_keywords: + self.thought_keywords[kw][0:0] = [memory_node] + else: + self.thought_keywords[kw] = [memory_node] + + self.add(memory_node) + + if f"{p} {o}" != "is idle": + for kw in keywords: + if kw in self.kw_strength_thought: + self.kw_strength_thought[kw] += 1 + else: + self.kw_strength_thought[kw] = 1 + + self.embeddings[embedding_pair[0]] = embedding_pair[1] + return memory_node + + + def add_event(self, created, expiration, s, p, o, + content, keywords, poignancy, + embedding_pair, filling): """ 调用add方法,初始化event """ - pass + memory_count = len(self.storage) + 1 + type_count = len(self.event_list) + 1 + memory_type = "event" + memory_id = f"memory_{str(memory_count)}" + depth = 0 + + if "(" in content: + content = (" ".join(content.split()[:3]) + + " " + + content.split("(")[-1][:-1]) + + memory_node = BasicMemory(memory_id, memory_count, type_count, memory_type, depth, + created, expiration, + s, p ,o, + content, embedding_pair[0], + poignancy, keywords, filling) - def retrive(self,): - """ - 调用 - """ - pass + keywords = [i.lower() for i in keywords] + for kw in keywords: + if kw in self.event_keywords: + self.event_keywords[kw][0:0] = [memory_node] + else: + self.event_keywords[kw] = [memory_node] + + self.add(memory_node) -if __name__ == "__main__": - \ No newline at end of file + if f"{p} {o}" != "is idle": + for kw in keywords: + if kw in self.kw_strength_event: + self.kw_strength_event[kw] += 1 + else: + self.kw_strength_event[kw] = 1 + + self.embeddings[embedding_pair[0]] = embedding_pair[1] + return memory_node diff --git a/examples/st_game/memory/scratch.py b/examples/st_game/memory/scratch.py index 00da03dd6..d0d13002e 100644 --- a/examples/st_game/memory/scratch.py +++ b/examples/st_game/memory/scratch.py @@ -2,5 +2,536 @@ # -*- coding: utf-8 -*- # @Desc : Scratch类实现(角色信息类) -class Scratch(): - pass \ No newline at end of file +import datetime +import json +import sys +sys.path.append('../../') + +from ..utils.check import check_if_file_exists + +class Scratch: + def __init__(self, f_saved): + # 类别1:人物超参 + self.vision_r = 4 + self.att_bandwidth = 3 + self.retention = 5 + + # 类别2:世界信息 + self.curr_time = None + self.curr_tile = None + self.daily_plan_req = None + + # 类别3:人物角色的核心身份 + self.name = None + self.first_name = None + self.last_name = None + self.age = None + # L0 permanent core traits. + self.innate = None + # L1 stable traits. + self.learned = None + # L2 external implementation. + self.currently = None + self.lifestyle = None + self.living_area = None + + # 类别4:旧反思变量 + self.concept_forget = 100 + self.daily_reflection_time = 60 * 3 + self.daily_reflection_size = 5 + self.overlap_reflect_th = 2 + self.kw_strg_event_reflect_th = 4 + self.kw_strg_thought_reflect_th = 4 + + # 类别5:新反思变量 + self.recency_w = 1 + self.relevance_w = 1 + self.importance_w = 1 + self.recency_decay = 0.99 + self.importance_trigger_max = 150 + self.importance_trigger_curr = self.importance_trigger_max + self.importance_ele_n = 0 + self.thought_count = 5 + + # 类别6:个人计划 + self.daily_req = [] + self.f_daily_schedule = [] + self.f_daily_schedule_hourly_org = [] + + # 类别7:当前动作 + self.act_address = None + self.act_start_time = None + self.act_duration = None + self.act_description = None + self.act_pronunciatio = None + self.act_event = (self.name, None, None) + + self.act_obj_description = None + self.act_obj_pronunciatio = None + self.act_obj_event = (self.name, None, None) + + self.chatting_with = None + self.chat = None + self.chatting_with_buffer = dict() + self.chatting_end_time = None + + self.act_path_set = False + self.planned_path = [] + + if check_if_file_exists(f_saved): + # If we have a bootstrap file, load that here. + scratch_load = json.load(open(f_saved)) + + self.vision_r = scratch_load["vision_r"] + self.att_bandwidth = scratch_load["att_bandwidth"] + self.retention = scratch_load["retention"] + + if scratch_load["curr_time"]: + self.curr_time = datetime.datetime.strptime(scratch_load["curr_time"], + "%B %d, %Y, %H:%M:%S") + else: + self.curr_time = None + self.curr_tile = scratch_load["curr_tile"] + self.daily_plan_req = scratch_load["daily_plan_req"] + + self.name = scratch_load["name"] + self.first_name = scratch_load["first_name"] + self.last_name = scratch_load["last_name"] + self.age = scratch_load["age"] + self.innate = scratch_load["innate"] + self.learned = scratch_load["learned"] + self.currently = scratch_load["currently"] + self.lifestyle = scratch_load["lifestyle"] + self.living_area = scratch_load["living_area"] + + self.concept_forget = scratch_load["concept_forget"] + self.daily_reflection_time = scratch_load["daily_reflection_time"] + self.daily_reflection_size = scratch_load["daily_reflection_size"] + self.overlap_reflect_th = scratch_load["overlap_reflect_th"] + self.kw_strg_event_reflect_th = scratch_load["kw_strg_event_reflect_th"] + self.kw_strg_thought_reflect_th = scratch_load["kw_strg_thought_reflect_th"] + + self.recency_w = scratch_load["recency_w"] + self.relevance_w = scratch_load["relevance_w"] + self.importance_w = scratch_load["importance_w"] + self.recency_decay = scratch_load["recency_decay"] + self.importance_trigger_max = scratch_load["importance_trigger_max"] + self.importance_trigger_curr = scratch_load["importance_trigger_curr"] + self.importance_ele_n = scratch_load["importance_ele_n"] + self.thought_count = scratch_load["thought_count"] + + self.daily_req = scratch_load["daily_req"] + self.f_daily_schedule = scratch_load["f_daily_schedule"] + self.f_daily_schedule_hourly_org = scratch_load["f_daily_schedule_hourly_org"] + + self.act_address = scratch_load["act_address"] + if scratch_load["act_start_time"]: + self.act_start_time = datetime.datetime.strptime( + scratch_load["act_start_time"], + "%B %d, %Y, %H:%M:%S") + else: + self.curr_time = None + self.act_duration = scratch_load["act_duration"] + self.act_description = scratch_load["act_description"] + self.act_pronunciatio = scratch_load["act_pronunciatio"] + self.act_event = tuple(scratch_load["act_event"]) + + self.act_obj_description = scratch_load["act_obj_description"] + self.act_obj_pronunciatio = scratch_load["act_obj_pronunciatio"] + self.act_obj_event = tuple(scratch_load["act_obj_event"]) + + self.chatting_with = scratch_load["chatting_with"] + self.chat = scratch_load["chat"] + self.chatting_with_buffer = scratch_load["chatting_with_buffer"] + if scratch_load["chatting_end_time"]: + self.chatting_end_time = datetime.datetime.strptime( + scratch_load["chatting_end_time"], + "%B %d, %Y, %H:%M:%S") + else: + self.chatting_end_time = None + + self.act_path_set = scratch_load["act_path_set"] + self.planned_path = scratch_load["planned_path"] + + + def save(self, out_json): + """ + Save persona's scratch. + + INPUT: + out_json: The file where we wil be saving our persona's state. + OUTPUT: + None + """ + scratch = dict() + scratch["vision_r"] = self.vision_r + scratch["att_bandwidth"] = self.att_bandwidth + scratch["retention"] = self.retention + + scratch["curr_time"] = self.curr_time.strftime("%B %d, %Y, %H:%M:%S") + scratch["curr_tile"] = self.curr_tile + scratch["daily_plan_req"] = self.daily_plan_req + + scratch["name"] = self.name + scratch["first_name"] = self.first_name + scratch["last_name"] = self.last_name + scratch["age"] = self.age + scratch["innate"] = self.innate + scratch["learned"] = self.learned + scratch["currently"] = self.currently + scratch["lifestyle"] = self.lifestyle + scratch["living_area"] = self.living_area + + scratch["concept_forget"] = self.concept_forget + scratch["daily_reflection_time"] = self.daily_reflection_time + scratch["daily_reflection_size"] = self.daily_reflection_size + scratch["overlap_reflect_th"] = self.overlap_reflect_th + scratch["kw_strg_event_reflect_th"] = self.kw_strg_event_reflect_th + scratch["kw_strg_thought_reflect_th"] = self.kw_strg_thought_reflect_th + + scratch["recency_w"] = self.recency_w + scratch["relevance_w"] = self.relevance_w + scratch["importance_w"] = self.importance_w + scratch["recency_decay"] = self.recency_decay + scratch["importance_trigger_max"] = self.importance_trigger_max + scratch["importance_trigger_curr"] = self.importance_trigger_curr + scratch["importance_ele_n"] = self.importance_ele_n + scratch["thought_count"] = self.thought_count + + scratch["daily_req"] = self.daily_req + scratch["f_daily_schedule"] = self.f_daily_schedule + scratch["f_daily_schedule_hourly_org"] = self.f_daily_schedule_hourly_org + + scratch["act_address"] = self.act_address + scratch["act_start_time"] = (self.act_start_time + .strftime("%B %d, %Y, %H:%M:%S")) + scratch["act_duration"] = self.act_duration + scratch["act_description"] = self.act_description + scratch["act_pronunciatio"] = self.act_pronunciatio + scratch["act_event"] = self.act_event + + scratch["act_obj_description"] = self.act_obj_description + scratch["act_obj_pronunciatio"] = self.act_obj_pronunciatio + scratch["act_obj_event"] = self.act_obj_event + + scratch["chatting_with"] = self.chatting_with + scratch["chat"] = self.chat + scratch["chatting_with_buffer"] = self.chatting_with_buffer + if self.chatting_end_time: + scratch["chatting_end_time"] = (self.chatting_end_time + .strftime("%B %d, %Y, %H:%M:%S")) + else: + scratch["chatting_end_time"] = None + + scratch["act_path_set"] = self.act_path_set + scratch["planned_path"] = self.planned_path + + with open(out_json, "w") as outfile: + json.dump(scratch, outfile, indent=2) + + + def get_f_daily_schedule_index(self, advance=0): + """ + We get the current index of self.f_daily_schedule. + + Recall that self.f_daily_schedule stores the decomposed action sequences + up until now, and the hourly sequences of the future action for the rest + of today. Given that self.f_daily_schedule is a list of list where the + inner list is composed of [task, duration], we continue to add up the + duration until we reach "if elapsed > today_min_elapsed" condition. The + index where we stop is the index we will return. + + INPUT + advance: Integer value of the number minutes we want to look into the + future. This allows us to get the index of a future timeframe. + OUTPUT + an integer value for the current index of f_daily_schedule. + """ + # We first calculate teh number of minutes elapsed today. + today_min_elapsed = 0 + today_min_elapsed += self.curr_time.hour * 60 + today_min_elapsed += self.curr_time.minute + today_min_elapsed += advance + + x = 0 + for task, duration in self.f_daily_schedule: + x += duration + x = 0 + for task, duration in self.f_daily_schedule_hourly_org: + x += duration + + # We then calculate the current index based on that. + curr_index = 0 + elapsed = 0 + for task, duration in self.f_daily_schedule: + elapsed += duration + if elapsed > today_min_elapsed: + return curr_index + curr_index += 1 + + return curr_index + + + def get_f_daily_schedule_hourly_org_index(self, advance=0): + """ + We get the current index of self.f_daily_schedule_hourly_org. + It is otherwise the same as get_f_daily_schedule_index. + + INPUT + advance: Integer value of the number minutes we want to look into the + future. This allows us to get the index of a future timeframe. + OUTPUT + an integer value for the current index of f_daily_schedule. + """ + # We first calculate teh number of minutes elapsed today. + today_min_elapsed = 0 + today_min_elapsed += self.curr_time.hour * 60 + today_min_elapsed += self.curr_time.minute + today_min_elapsed += advance + # We then calculate the current index based on that. + curr_index = 0 + elapsed = 0 + for task, duration in self.f_daily_schedule_hourly_org: + elapsed += duration + if elapsed > today_min_elapsed: + return curr_index + curr_index += 1 + return curr_index + + + def get_str_iss(self): + """ + ISS stands for "identity stable set." This describes the commonset summary + of this persona -- basically, the bare minimum description of the persona + that gets used in almost all prompts that need to call on the persona. + + INPUT + None + OUTPUT + the identity stable set summary of the persona in a string form. + EXAMPLE STR OUTPUT + "Name: Dolores Heitmiller + Age: 28 + Innate traits: hard-edged, independent, loyal + Learned traits: Dolores is a painter who wants live quietly and paint + while enjoying her everyday life. + Currently: Dolores is preparing for her first solo show. She mostly + works from home. + Lifestyle: Dolores goes to bed around 11pm, sleeps for 7 hours, eats + dinner around 6pm. + Daily plan requirement: Dolores is planning to stay at home all day and + never go out." + """ + commonset = "" + commonset += f"Name: {self.name}\n" + commonset += f"Age: {self.age}\n" + commonset += f"Innate traits: {self.innate}\n" + commonset += f"Learned traits: {self.learned}\n" + commonset += f"Currently: {self.currently}\n" + commonset += f"Lifestyle: {self.lifestyle}\n" + commonset += f"Daily plan requirement: {self.daily_plan_req}\n" + commonset += f"Current Date: {self.curr_time.strftime('%A %B %d')}\n" + return commonset + + + def get_str_name(self): + return self.name + + + def get_str_firstname(self): + return self.first_name + + + def get_str_lastname(self): + return self.last_name + + + def get_str_age(self): + return str(self.age) + + + def get_str_innate(self): + return self.innate + + + def get_str_learned(self): + return self.learned + + + def get_str_currently(self): + return self.currently + + + def get_str_lifestyle(self): + return self.lifestyle + + + def get_str_daily_plan_req(self): + return self.daily_plan_req + + + def get_str_curr_date_str(self): + return self.curr_time.strftime("%A %B %d") + + + def get_curr_event(self): + if not self.act_address: + return (self.name, None, None) + else: + return self.act_event + + + def get_curr_event_and_desc(self): + if not self.act_address: + return (self.name, None, None, None) + else: + return (self.act_event[0], + self.act_event[1], + self.act_event[2], + self.act_description) + + + def get_curr_obj_event_and_desc(self): + if not self.act_address: + return ("", None, None, None) + else: + return (self.act_address, + self.act_obj_event[1], + self.act_obj_event[2], + self.act_obj_description) + + + def add_new_action(self, + action_address, + action_duration, + action_description, + action_pronunciatio, + action_event, + chatting_with, + chat, + chatting_with_buffer, + chatting_end_time, + act_obj_description, + act_obj_pronunciatio, + act_obj_event, + act_start_time=None): + self.act_address = action_address + self.act_duration = action_duration + self.act_description = action_description + self.act_pronunciatio = action_pronunciatio + self.act_event = action_event + + self.chatting_with = chatting_with + self.chat = chat + if chatting_with_buffer: + self.chatting_with_buffer.update(chatting_with_buffer) + self.chatting_end_time = chatting_end_time + + self.act_obj_description = act_obj_description + self.act_obj_pronunciatio = act_obj_pronunciatio + self.act_obj_event = act_obj_event + + self.act_start_time = self.curr_time + + self.act_path_set = False + + + def act_time_str(self): + """ + Returns a string output of the current time. + + INPUT + None + OUTPUT + A string output of the current time. + EXAMPLE STR OUTPUT + "14:05 P.M." + """ + return self.act_start_time.strftime("%H:%M %p") + + + def act_check_finished(self): + """ + Checks whether the self.Action instance has finished. + + INPUT + curr_datetime: Current time. If current time is later than the action's + start time + its duration, then the action has finished. + OUTPUT + Boolean [True]: Action has finished. + Boolean [False]: Action has not finished and is still ongoing. + """ + if not self.act_address: + return True + + if self.chatting_with: + end_time = self.chatting_end_time + else: + x = self.act_start_time + if x.second != 0: + x = x.replace(second=0) + x = (x + datetime.timedelta(minutes=1)) + end_time = (x + datetime.timedelta(minutes=self.act_duration)) + + if end_time.strftime("%H:%M:%S") == self.curr_time.strftime("%H:%M:%S"): + return True + return False + + + def act_summarize(self): + """ + Summarize the current action as a dictionary. + + INPUT + None + OUTPUT + ret: A human readable summary of the action. + """ + exp = dict() + exp["persona"] = self.name + exp["address"] = self.act_address + exp["start_datetime"] = self.act_start_time + exp["duration"] = self.act_duration + exp["description"] = self.act_description + exp["pronunciatio"] = self.act_pronunciatio + return exp + + + def act_summary_str(self): + """ + Returns a string summary of the current action. Meant to be + human-readable. + + INPUT + None + OUTPUT + ret: A human readable summary of the action. + """ + start_datetime_str = self.act_start_time.strftime("%A %B %d -- %H:%M %p") + ret = f"[{start_datetime_str}]\n" + ret += f"Activity: {self.name} is {self.act_description}\n" + ret += f"Address: {self.act_address}\n" + ret += f"Duration in minutes (e.g., x min): {str(self.act_duration)} min\n" + return ret + + + def get_str_daily_schedule_summary(self): + ret = "" + curr_min_sum = 0 + for row in self.f_daily_schedule: + curr_min_sum += row[1] + hour = int(curr_min_sum/60) + minute = curr_min_sum%60 + ret += f"{hour:02}:{minute:02} || {row[0]}\n" + return ret + + + def get_str_daily_schedule_hourly_org_summary(self): + ret = "" + curr_min_sum = 0 + for row in self.f_daily_schedule_hourly_org: + curr_min_sum += row[1] + hour = int(curr_min_sum/60) + minute = curr_min_sum%60 + ret += f"{hour:02}:{minute:02} || {row[0]}\n" + return ret diff --git a/examples/st_game/roles/st_role.py b/examples/st_game/roles/st_role.py index 047f9545b..8f0c47f7f 100644 --- a/examples/st_game/roles/st_role.py +++ b/examples/st_game/roles/st_role.py @@ -21,15 +21,18 @@ from ..memory.agent_memory import AgentMemory from ..actions.dummy_action import DummyAction from ..actions.user_requirement import UserRequirement from ..maze_environment import MazeEnvironment +from ..memory.retrieve import agent_retrieve +from ..memory.scratch import Scratch class STRoleContext(RoleContext): env: 'MazeEnvironment' = Field(default=None) memory: AgentMemory = Field(default=AgentMemory) + scratch: Scratch = Field(default=Scratch) class STRole(Role): - + # 继承Role类,Role类继承RoleContext,这里的逻辑需要认真考虑 # add a role's property structure to store role's age and so on like GA's Scratch. def __init__(self, @@ -65,6 +68,12 @@ class STRole(Role): # TODO observe info from maze_env pass + async def retrieve(self, query, n = 30 ,topk = 4): + # TODO retrieve memories from agent_memory + retrieve_memories = agent_retrieve(self._rc.memory, self._rc.scratch.curr_time, self._rc.scratch.recency_decay, query, n, topk) + return retrieve_memories + + async def plan(self): # TODO make a plan diff --git a/examples/st_game/utils/check.py b/examples/st_game/utils/check.py new file mode 100644 index 000000000..0a806fe2d --- /dev/null +++ b/examples/st_game/utils/check.py @@ -0,0 +1,14 @@ +def check_if_file_exists(curr_file): + """ + Checks if a file exists + ARGS: + curr_file: path to the current csv file. + RETURNS: + True if the file exists + False if the file does not exist + """ + try: + with open(curr_file) as f_analysis_file: pass + return True + except: + return False \ No newline at end of file