From 770dcdc755cd3f34d55bc7f54da90a1f81a628a0 Mon Sep 17 00:00:00 2001
From: didi <2020201387@ruc.edu.cn>
Date: Thu, 28 Sep 2023 00:17:32 +0800
Subject: [PATCH 1/9] ga_game Memory & Retrive
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
1. 完善了Memory模块,明天添加不同类型记忆的add方法
2. 添加了Retrive方法
3. 添加了Prompt,Scracth等模块
---
examples/st_game/memory/associative_memory.py | 125 +++++++++++++++-
examples/st_game/memory/retrive.py | 138 ++++++++++++++++++
examples/st_game/memory/scratch.py | 6 +
.../prompts_templates/poignancy_chat_v1.txt | 17 +++
examples/st_game/prompts/run_gpt_prompts.py | 33 +++++
examples/st_game/prompts/wrapper_prompt.py | 42 ++++++
examples/st_game/roles/st_role.py | 4 +-
examples/st_game/utils/utils.py | 9 ++
8 files changed, 369 insertions(+), 5 deletions(-)
create mode 100644 examples/st_game/memory/retrive.py
create mode 100644 examples/st_game/memory/scratch.py
create mode 100644 examples/st_game/prompts/prompts_templates/poignancy_chat_v1.txt
create mode 100644 examples/st_game/prompts/run_gpt_prompts.py
create mode 100644 examples/st_game/prompts/wrapper_prompt.py
diff --git a/examples/st_game/memory/associative_memory.py b/examples/st_game/memory/associative_memory.py
index ef1514398..5b78926b9 100644
--- a/examples/st_game/memory/associative_memory.py
+++ b/examples/st_game/memory/associative_memory.py
@@ -1,9 +1,128 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
-# @Desc : associative_memory to store conversation、plan detail、reflection result and so on.
+# @Desc : MemoryBasic,AgentMemory实现
from metagpt.memory.memory import Memory
+from metagpt.schema import Message
+import json
+from datetime import datetime
+
+class MemoryBasic(Message):
+
+ def __init__(self,memory_id:str,memory_count:int,type_count:int,memory_type:str,depth:int,content:int,
+ creaetd:datetime,expiration:datetime,
+ subject:str,predicate:str,object:str,
+ embedding_key:str,poignancy:int,keywords:list,filling:list):
+ """
+ MemoryBasic继承于MG的Message类,其中content属性替代description属性
+ Message类中对于Chat类型支持的非常好,对于Agent个体的Perceive,Reflection,Plan支持的并不多
+ 在Type设计上,我们延续GA的三个种类,但是对于Chat种类的对话进行特别设计(具体怎么设计还没想好)
+ """
+ super.__init__(content)
+ """
+ 从父类中继承的属性
+ content: str # 记忆描述
+ cause_by: Type["Action"] = field(default="") # 触发动作,只在Type为chat时初始化
+ cause_by 接受一个Action类,在此项目中,每个Agent需要有一个基础动作[Receive] 用于接受假对话Message;而每个Agent需要有独一无二的动作类,用以接受真对话Message
+ """
+ self.memory_id: str = memory_id # 记忆ID
+ self.memory_count: int = memory_count # 第几个记忆,实际数值与Memory相等,但是类型为整数
+ self.type_count: int = type_count # 第几种记忆,类型为整数(具体不太理解如何生成的)
+ self.memory_type: str = memory_type # 记忆类型,使用Field,包含 event,thought,chat三种类型
+ self.depth:str = depth # 记忆深度,类型为整数
+
+ self.created: datetime = creaetd # 创建时间
+ self.expiration: datetime = expiration # 记忆失效时间,默认为空()
+ self.last_accessed: datetime = creaetd # 上一次调用的时间,初始化时候与self.created一致
+
+ self.subject: str = subject # 主语,str类型
+ self.predicate:str = predicate # 谓语,str类型
+ self.object:str = object # 宾语,str类型
+
+ self.embedding_key: str = embedding_key # 内容与self.content一致
+ self.poignancy:int = poignancy # importance值,整数类型
+ self.keywords:list = keywords # keywords,列表
+ self.filling:list = filling # None或者列表
+
+class AgentMemory(Memory):
+ """
+ GA中主要存储三种JSON
+ 1. embedding.json (Dict embedding_key:embedding)
+ 2. Node.json (Dict Node_id:Node)
+ 3. kw_strength.json
+ """
+ def __init__(self,memory_saved:str):
+ """
+ AgentMemory类继承自Memory类,重写storage替代GA中id_to_node,一方面存储所有信息,一方面作为JSON转化
+ index存储与不同Agent的chat信息
+ @李嵩@张凯 这里的storage是List,你们需要写一个JSON转化器,将List修改为node.json一致的格式
+ """
+ super.__init__()
+ self.storage: list[MemoryBasic] = [] # 重写Stroage,存储MemoryBasic所有节点
+ self.event_list = [] # 存储event记忆
+ self.thought_list = [] # 存储thought记忆
+
+ self.event_keywords = dict() # 存储keywords
+ self.thought_keywords = dict()
+ self.chat_keywords = dict()
+
+ self.strength_event_keywords = dict() # 不知道具体作用,所以没有删除
+ self.strength_thought_keywords = dict()
+
+ self.embeddings = json.load(open(memory_saved + "/embeddings.json")) # dict类型,load embedding.json
+ self.memory_load()
-class AssociativeMemory(Memory):
- pass
+ def memory_save(self):
+ """
+ 将MemormyBasic类存储为Nodes.json形式。复现GA中的Kw Strength.json形式
+ @张凯补充一个可调用的函数
+ """
+ pass
+
+ def memory_load(self):
+ """
+ 将GA的JSON解析,填充到AgentMemory类之中
+ """
+ pass
+
+ def add(self, memory_basic: MemoryBasic):
+ """
+ Add a new message to storage, while updating the index
+ 重写add方法,修改原有的Message类为MemoryBasic类,并添加不同的记忆类型添加方式
+ """
+ if memory_basic in self.storage:
+ return
+ self.storage.append(memory_basic)
+ if memory_basic.cause_by:
+ self.index[memory_basic.cause_by].append(memory_basic)
+ return
+ if memory_basic.type == "thought":
+ self.thought_list.append(memory_basic)
+ return
+ if memory_basic.type == "event":
+ self.event_list.append(memory_basic)
+
+ def add_chat(self):
+ """
+ 调用add方法,初始化chat,在创建的时候就需要调用embeeding函数
+ """
+ pass
+
+ def add_thought(self):
+ """
+ 调用add方法,初始化thought
+ """
+ pass
+
+ def add_event(self):
+ """
+ 调用add方法,初始化event
+ """
+ pass
+
+ def retrive(self,):
+ """
+ 调用
+ """
+ pass
diff --git a/examples/st_game/memory/retrive.py b/examples/st_game/memory/retrive.py
new file mode 100644
index 000000000..0bca9ea8b
--- /dev/null
+++ b/examples/st_game/memory/retrive.py
@@ -0,0 +1,138 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# @Desc : Retrive函数实现
+
+from numpy import dot
+from numpy.linalg import norm
+from datetime import datetime
+from associative_memory import AgentMemory,MemoryBasic
+from utils.utils import embedding_tools
+
+def agent_retrive(agent:AgentMemory,currtime:datetime,memory_forget:float,query:str,n:int= 30,topk:int=4) -> list[MemoryBasic]:
+ """
+ retrive需要集合Role使用,原因在于Role才具有AgentMemory,scratch
+ 逻辑:Role调用该函数,self._rc.AgentMemory,self._rc.scratch.currtime,self._rc.scratch.memory_forget
+ 输入希望查询的内容与希望回顾的条数,返回TopK条高分记忆,即List[MemoryBasic]
+
+ Score_lists示例
+ {
+ "memory":memories[i], MemoryBasic类
+ "importance":memories[i].poignancy
+ "recency":衰减因子计算结果
+ "relevance":搜索结果
+ }
+ """
+ memories = AgentMemory.storage
+ sorted_memories = sorted(memories, key=lambda memory_node: memory_node.last_accessed_time,reverse=True)
+ memories = sorted_memories[:n] if len(sorted_memories) >= n else sorted_memories
+
+ Score_list = []
+ Score_list = extract_importance(memories,Score_list)
+ Score_list = extract_recency(currtime,memory_forget,Score_list)
+ Score_list = extract_relevance(query,Score_list)
+ Score_list = normalize_Socre_floats(Score_list,0,1)
+
+ total_dict = {}
+ gw = [1,1,1] # 三个因素的权重,重要性,近因性,相关性
+ for i in range(len(Score_list)):
+ total_score = (Score_list[i]['importance']*gw[0] +
+ Score_list[i]['recency']*gw[1] +
+ Score_list[i]['relevance']*gw[2]
+ )
+ total_dict[Score_list[i]['memory']] = total_score
+
+ result = top_highest_x_values(total_dict,topk)
+
+ return result
+
+def top_highest_x_values(d, x):
+ """
+ 输入字典,Topx
+ 返回以字典值排序,字典键组成的List[MemoryBasic]
+ """
+ top_v = [item[0] for item in sorted(d.items(),key=lambda item: item[1],reverse= True)[:x]]
+ return top_v
+
+
+def extract_importance(memories,Score_list):
+ """
+ 抽取重要性
+ """
+ for i in range(len(memories)):
+ Score = {"memory":memories[i],
+ "importance":memories[i].poignancy
+ }
+ Score_list.append(Score)
+ return Score_list
+
+# 抽取相关性
+def extract_relevance(query,Score_list):
+ """
+ 抽取相关性
+ """
+ query_embedding = embedding_tools(query)
+ # 进行
+ for i in range(len(Score_list)):
+ result = cos_sim(Score_list[i]["memory"].embedding_key,query_embedding)
+ Score_list[i]['relevance'] = result
+
+ return Score_list
+
+# 抽取近因性
+def extract_recency(currtime,memory_forget,Score_list):
+ """
+ 抽取近因性,目前使用的现实世界过一天走一个衰减因子
+ """
+ for i in range(len(Score_list)):
+ day_count = (currtime-Score_list[i]['memory'].created).days
+ Score_list[i]['recency'] = memory_forget**day_count
+ return Score_list
+
+def cos_sim(a, b):
+ """
+ 计算余弦相似度
+ """
+ return dot(a, b)/(norm(a)*norm(b))
+
+def normalize_List_floats(Single_list,target_min, target_max):
+ """
+ 单个列表归一化
+ """
+ min_val = min(Single_list)
+ max_val = max(Single_list)
+ range_val = max_val - min_val
+
+ if range_val == 0:
+ for i in range(len(Single_list)):
+ Single_list[i] = (target_max - target_min)/2
+ else:
+ for i in range(len(Single_list)):
+ Single_list[i] = ((Single_list[i] - min_val) * (target_max - target_min)
+ / range_val + target_min)
+ return Single_list
+
+
+def normalize_Socre_floats(Score_list, target_min, target_max):
+ """
+ 整体归一化
+ """
+ importance_list = []
+ relevance_list = []
+ recency_list = []
+
+ for i in range(len(Score_list)):
+ importance_list.append(Score_list[i]['importance'])
+ relevance_list.append(Score_list[i]['relevance'])
+ recency_list.append(Score_list[i]['recency'])
+
+ # 进行归一化操作
+ importance_list = normalize_List_floats(importance_list,target_min, target_max)
+ relevance_list = normalize_List_floats(relevance_list,target_min, target_max)
+ recency_list =normalize_List_floats(recency_list,target_min, target_max)
+
+ for i in range(len(Score_list)):
+ Score_list[i]['importance'] = importance_list[i]
+ Score_list[i]['relevance'] = relevance_list[i]
+ Score_list[i]['recency'] = recency_list[i]
+
+ return Score_list
\ No newline at end of file
diff --git a/examples/st_game/memory/scratch.py b/examples/st_game/memory/scratch.py
new file mode 100644
index 000000000..00da03dd6
--- /dev/null
+++ b/examples/st_game/memory/scratch.py
@@ -0,0 +1,6 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# @Desc : Scratch类实现(角色信息类)
+
+class Scratch():
+ pass
\ No newline at end of file
diff --git a/examples/st_game/prompts/prompts_templates/poignancy_chat_v1.txt b/examples/st_game/prompts/prompts_templates/poignancy_chat_v1.txt
new file mode 100644
index 000000000..572dd8a05
--- /dev/null
+++ b/examples/st_game/prompts/prompts_templates/poignancy_chat_v1.txt
@@ -0,0 +1,17 @@
+poignancy_chat_v1.txt
+
+!!: agent name
+!!: iss
+!!: name
+!!: event description
+
+###
+Here is a brief description of !!.
+!!
+
+On the scale of 1 to 10, where 1 is purely mundane (e.g., routine morning greetings) and 10 is extremely poignant (e.g., a conversation about breaking up, a fight), rate the likely poignancy of the following conversation for !!.
+
+Conversation:
+!!
+
+Rate (return a number between 1 to 10):
\ No newline at end of file
diff --git a/examples/st_game/prompts/run_gpt_prompts.py b/examples/st_game/prompts/run_gpt_prompts.py
new file mode 100644
index 000000000..4c94f3bea
--- /dev/null
+++ b/examples/st_game/prompts/run_gpt_prompts.py
@@ -0,0 +1,33 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# @Desc : 调用PromptTemplates中模板,实现
+
+from wrapper_prompt import special_response_generate,prompt_generate
+from memory.scratch import Scratch
+from memory.associative_memory import MemoryBasic
+import json
+
+def run_gpt_prompt_chat_poignancy(scratch:Scratch,content:MemoryBasic.content)->str:
+ """
+ 衡量事件心酸度
+ """
+ def create_prompt_input(scratch,content):
+ prompt_input = [scratch.name,
+ scratch.iss,
+ scratch.name,
+ content]
+ return prompt_input
+
+ # 1. Prompt构建
+ # 2. Instruction给出
+ prompt_template = "prompt_templates/poignancy_chat_v1.txt" ########
+ prompt_input = create_prompt_input(scratch, content) ########
+ prompt = prompt_generate(prompt_input, prompt_template)
+ special_instruction = "The output should ONLY contain ONE integer value on the scale of 1 to 10."
+ poignancy = special_response_generate(prompt,special_instruction)
+ try:
+ poi_dict = json.loads(poignancy)
+ return (poi_dict['poignancy'])
+ except:
+ return poignancy
+
diff --git a/examples/st_game/prompts/wrapper_prompt.py b/examples/st_game/prompts/wrapper_prompt.py
new file mode 100644
index 000000000..b61e13520
--- /dev/null
+++ b/examples/st_game/prompts/wrapper_prompt.py
@@ -0,0 +1,42 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# @Desc : 基于Prmopt Templates 填充Prompt; 为Prompt包装与调用
+
+from metagpt import llm
+
+def prompt_generate(curr_input:list, prompt_path:str):
+ """
+ curr_input:输入一个按照PromptTemplate的要求的列表
+ prompt_path:输入一个Promptpath
+ """
+ if type(curr_input) == type("string"):
+ curr_input = [curr_input]
+ curr_input = [str(i) for i in curr_input]
+
+ f = open(prompt_path, "r")
+ prompt = f.read()
+ f.close()
+ for count, i in enumerate(curr_input):
+ prompt = prompt.replace(f"!!", i)
+ if "###" in prompt:
+ prompt = prompt.split("###")[1]
+ return prompt.strip()
+
+def response_generate(prompt:str):
+ """
+ 待完善,我没有找到MG中可以设置Temprature以及Maxtoken的位置
+ """
+ return llm.ai_func(prompt)
+
+def special_response_generate(prompt:str,special_instruction:str,example_output:str = None):
+ """
+ 当对于Prompt生成有特殊要求时,调用该函数增加special_instruction或example_output
+ """
+ prompt = '"""\n' + prompt + '\n"""\n'
+ prompt += f"Output the response to the prompt above in json. {special_instruction}\n"
+ if example_output:
+ prompt += "Example output json:\n"
+ prompt += '{"output": "' + str(example_output) + '"}'
+ return response_generate(prompt)
+
+
diff --git a/examples/st_game/roles/st_role.py b/examples/st_game/roles/st_role.py
index fa071f069..bc6988e28 100644
--- a/examples/st_game/roles/st_role.py
+++ b/examples/st_game/roles/st_role.py
@@ -17,7 +17,7 @@ from pathlib import Path
from metagpt.roles.role import Role, RoleContext
from metagpt.schema import Message
-from ..memory.associative_memory import AssociativeMemory
+from ..memory.associative_memory import AgentMemory
from ..actions.dummy_action import DummyAction
from ..actions.user_requirement import UserRequirement
from ..maze_environment import MazeEnvironment
@@ -25,7 +25,7 @@ from ..maze_environment import MazeEnvironment
class STRoleContext(RoleContext):
env: 'MazeEnvironment' = Field(default=None)
- memory: AssociativeMemory = Field(default=AssociativeMemory)
+ memory: AgentMemory = Field(default=AgentMemory)
class STRole(Role):
diff --git a/examples/st_game/utils/utils.py b/examples/st_game/utils/utils.py
index a70f7606d..11cbabd8e 100644
--- a/examples/st_game/utils/utils.py
+++ b/examples/st_game/utils/utils.py
@@ -4,6 +4,7 @@
from typing import Any
import json
+import openai
from pathlib import Path
@@ -22,3 +23,11 @@ def read_json_file(json_file: str, encoding=None) -> list[Any]:
def write_json_file(json_file: str, data: list, encoding=None):
with open(json_file, "w", encoding=encoding) as fout:
json.dump(data, fout, ensure_ascii=False, indent=4)
+
+def embedding_tools(query):
+ embedding_result = openai.Embedding.create(
+ model="text-embedding-ada-002",
+ input=query
+ )
+ embedding_key = embedding_result['data'][0]["embedding"]
+ return embedding_key
\ No newline at end of file
From 9672d8296981a34ca14b29addc963b5700145857 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E7=88=B1=E5=90=83=E5=B1=81=E7=9A=84=E5=B0=8F=E5=BC=A0?=
<84363704+fucking-dog@users.noreply.github.com>
Date: Thu, 28 Sep 2023 11:02:35 +0800
Subject: [PATCH 2/9] Update associative_memory.py
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
修改规范错误
---
examples/st_game/memory/associative_memory.py | 10 +++++-----
1 file changed, 5 insertions(+), 5 deletions(-)
diff --git a/examples/st_game/memory/associative_memory.py b/examples/st_game/memory/associative_memory.py
index 5b78926b9..6a40b3dda 100644
--- a/examples/st_game/memory/associative_memory.py
+++ b/examples/st_game/memory/associative_memory.py
@@ -9,10 +9,10 @@ from datetime import datetime
class MemoryBasic(Message):
- def __init__(self,memory_id:str,memory_count:int,type_count:int,memory_type:str,depth:int,content:int,
- creaetd:datetime,expiration:datetime,
- subject:str,predicate:str,object:str,
- embedding_key:str,poignancy:int,keywords:list,filling:list):
+ def __init__(self, memory_id:str, memory_count:int, type_count:int, memory_type:str, depth:int, content:int,
+ creaetd:datetime, expiration:datetime,
+ subject:str, predicate:str, object:str,
+ embedding_key:str, poignancy:int, keywords:list, filling:list):
"""
MemoryBasic继承于MG的Message类,其中content属性替代description属性
Message类中对于Chat类型支持的非常好,对于Agent个体的Perceive,Reflection,Plan支持的并不多
@@ -51,7 +51,7 @@ class AgentMemory(Memory):
2. Node.json (Dict Node_id:Node)
3. kw_strength.json
"""
- def __init__(self,memory_saved:str):
+ def __init__(self, memory_saved:str):
"""
AgentMemory类继承自Memory类,重写storage替代GA中id_to_node,一方面存储所有信息,一方面作为JSON转化
index存储与不同Agent的chat信息
From 6911ca87ab38f09c001819e75544b610fee774ed Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E7=88=B1=E5=90=83=E5=B1=81=E7=9A=84=E5=B0=8F=E5=BC=A0?=
<84363704+fucking-dog@users.noreply.github.com>
Date: Thu, 28 Sep 2023 11:07:20 +0800
Subject: [PATCH 3/9] Update and rename retrive.py to retrieve.py
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
修改了命名与格式问题
retrieve暂时无法放在AgentMemory类中,这个方法计划是交给Role调用的,因为需要使用到AgentMemory类与Scratch类
---
.../st_game/memory/{retrive.py => retrieve.py} | 16 ++++++++--------
1 file changed, 8 insertions(+), 8 deletions(-)
rename examples/st_game/memory/{retrive.py => retrieve.py} (89%)
diff --git a/examples/st_game/memory/retrive.py b/examples/st_game/memory/retrieve.py
similarity index 89%
rename from examples/st_game/memory/retrive.py
rename to examples/st_game/memory/retrieve.py
index 0bca9ea8b..4119524aa 100644
--- a/examples/st_game/memory/retrive.py
+++ b/examples/st_game/memory/retrieve.py
@@ -8,7 +8,7 @@ from datetime import datetime
from associative_memory import AgentMemory,MemoryBasic
from utils.utils import embedding_tools
-def agent_retrive(agent:AgentMemory,currtime:datetime,memory_forget:float,query:str,n:int= 30,topk:int=4) -> list[MemoryBasic]:
+def agent_retrive(agentmemory:AgentMemory, currtime:datetime, memory_forget:float, query:str, n:int= 30, topk:int=4) -> list[MemoryBasic]:
"""
retrive需要集合Role使用,原因在于Role才具有AgentMemory,scratch
逻辑:Role调用该函数,self._rc.AgentMemory,self._rc.scratch.currtime,self._rc.scratch.memory_forget
@@ -22,7 +22,7 @@ def agent_retrive(agent:AgentMemory,currtime:datetime,memory_forget:float,query:
"relevance":搜索结果
}
"""
- memories = AgentMemory.storage
+ memories = agentmemory.storage
sorted_memories = sorted(memories, key=lambda memory_node: memory_node.last_accessed_time,reverse=True)
memories = sorted_memories[:n] if len(sorted_memories) >= n else sorted_memories
@@ -54,7 +54,7 @@ def top_highest_x_values(d, x):
return top_v
-def extract_importance(memories,Score_list):
+def extract_importance(memories, Score_list):
"""
抽取重要性
"""
@@ -66,7 +66,7 @@ def extract_importance(memories,Score_list):
return Score_list
# 抽取相关性
-def extract_relevance(query,Score_list):
+def extract_relevance(query, Score_list):
"""
抽取相关性
"""
@@ -79,7 +79,7 @@ def extract_relevance(query,Score_list):
return Score_list
# 抽取近因性
-def extract_recency(currtime,memory_forget,Score_list):
+def extract_recency(currtime, memory_forget, Score_list):
"""
抽取近因性,目前使用的现实世界过一天走一个衰减因子
"""
@@ -94,7 +94,7 @@ def cos_sim(a, b):
"""
return dot(a, b)/(norm(a)*norm(b))
-def normalize_List_floats(Single_list,target_min, target_max):
+def normalize_List_floats(Single_list, target_min, target_max):
"""
单个列表归一化
"""
@@ -112,7 +112,7 @@ def normalize_List_floats(Single_list,target_min, target_max):
return Single_list
-def normalize_Socre_floats(Score_list, target_min, target_max):
+def normalize_socre_floats(Score_list, target_min, target_max):
"""
整体归一化
"""
@@ -135,4 +135,4 @@ def normalize_Socre_floats(Score_list, target_min, target_max):
Score_list[i]['relevance'] = relevance_list[i]
Score_list[i]['recency'] = recency_list[i]
- return Score_list
\ No newline at end of file
+ return Score_list
From c10f1306f511ff80f7d0c21f49439b0b39727a4d Mon Sep 17 00:00:00 2001
From: didi <2020201387@ruc.edu.cn>
Date: Thu, 28 Sep 2023 11:18:22 +0800
Subject: [PATCH 4/9] =?UTF-8?q?=E5=AE=8C=E6=88=90=E4=BF=AE=E6=94=B9?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
完成PR374修改
---
examples/st_game/memory/retrieve.py | 12 ++++++------
.../{prompts_templates => }/poignancy_chat_v1.txt | 0
examples/st_game/prompts/run_gpt_prompts.py | 8 ++++----
3 files changed, 10 insertions(+), 10 deletions(-)
rename examples/st_game/prompts/{prompts_templates => }/poignancy_chat_v1.txt (100%)
diff --git a/examples/st_game/memory/retrieve.py b/examples/st_game/memory/retrieve.py
index 4119524aa..79042df06 100644
--- a/examples/st_game/memory/retrieve.py
+++ b/examples/st_game/memory/retrieve.py
@@ -27,10 +27,10 @@ def agent_retrive(agentmemory:AgentMemory, currtime:datetime, memory_forget:floa
memories = sorted_memories[:n] if len(sorted_memories) >= n else sorted_memories
Score_list = []
- Score_list = extract_importance(memories,Score_list)
- Score_list = extract_recency(currtime,memory_forget,Score_list)
- Score_list = extract_relevance(query,Score_list)
- Score_list = normalize_Socre_floats(Score_list,0,1)
+ Score_list = extract_importance(memories, Score_list)
+ Score_list = extract_recency(currtime, memory_forget, Score_list)
+ Score_list = extract_relevance(query, Score_list)
+ Score_list = normalize_Socre_floats(Score_list, 0, 1)
total_dict = {}
gw = [1,1,1] # 三个因素的权重,重要性,近因性,相关性
@@ -41,7 +41,7 @@ def agent_retrive(agentmemory:AgentMemory, currtime:datetime, memory_forget:floa
)
total_dict[Score_list[i]['memory']] = total_score
- result = top_highest_x_values(total_dict,topk)
+ result = top_highest_x_values(total_dict, topk)
return result
@@ -73,7 +73,7 @@ def extract_relevance(query, Score_list):
query_embedding = embedding_tools(query)
# 进行
for i in range(len(Score_list)):
- result = cos_sim(Score_list[i]["memory"].embedding_key,query_embedding)
+ result = cos_sim(Score_list[i]["memory"].embedding_key, query_embedding)
Score_list[i]['relevance'] = result
return Score_list
diff --git a/examples/st_game/prompts/prompts_templates/poignancy_chat_v1.txt b/examples/st_game/prompts/poignancy_chat_v1.txt
similarity index 100%
rename from examples/st_game/prompts/prompts_templates/poignancy_chat_v1.txt
rename to examples/st_game/prompts/poignancy_chat_v1.txt
diff --git a/examples/st_game/prompts/run_gpt_prompts.py b/examples/st_game/prompts/run_gpt_prompts.py
index 4c94f3bea..86db1c7c2 100644
--- a/examples/st_game/prompts/run_gpt_prompts.py
+++ b/examples/st_game/prompts/run_gpt_prompts.py
@@ -7,11 +7,11 @@ from memory.scratch import Scratch
from memory.associative_memory import MemoryBasic
import json
-def run_gpt_prompt_chat_poignancy(scratch:Scratch,content:MemoryBasic.content)->str:
+def get_poignancy_action(scratch:Scratch, content:MemoryBasic.content)->str:
"""
衡量事件心酸度
"""
- def create_prompt_input(scratch,content):
+ def create_prompt_input(scratch, content):
prompt_input = [scratch.name,
scratch.iss,
scratch.name,
@@ -20,11 +20,11 @@ def run_gpt_prompt_chat_poignancy(scratch:Scratch,content:MemoryBasic.content)->
# 1. Prompt构建
# 2. Instruction给出
- prompt_template = "prompt_templates/poignancy_chat_v1.txt" ########
+ prompt_template = "poignancy_chat_v1.txt" ########
prompt_input = create_prompt_input(scratch, content) ########
prompt = prompt_generate(prompt_input, prompt_template)
special_instruction = "The output should ONLY contain ONE integer value on the scale of 1 to 10."
- poignancy = special_response_generate(prompt,special_instruction)
+ poignancy = special_response_generate(prompt, special_instruction)
try:
poi_dict = json.loads(poignancy)
return (poi_dict['poignancy'])
From e7c966653e85d599a4ea21f18b49dc48e6b4ace9 Mon Sep 17 00:00:00 2001
From: didi <2020201387@ruc.edu.cn>
Date: Thu, 28 Sep 2023 11:54:11 +0800
Subject: [PATCH 5/9] =?UTF-8?q?format=E4=BF=AE=E6=94=B9?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
完成
---
examples/st_game/memory/associative_memory.py | 30 ++--
examples/st_game/memory/retrieve.py | 152 +++++++++---------
examples/st_game/prompts/run_gpt_prompts.py | 20 +--
examples/st_game/prompts/wrapper_prompt.py | 52 +++---
4 files changed, 134 insertions(+), 120 deletions(-)
diff --git a/examples/st_game/memory/associative_memory.py b/examples/st_game/memory/associative_memory.py
index 6a40b3dda..c771906ec 100644
--- a/examples/st_game/memory/associative_memory.py
+++ b/examples/st_game/memory/associative_memory.py
@@ -7,12 +7,13 @@ from metagpt.schema import Message
import json
from datetime import datetime
+
class MemoryBasic(Message):
- def __init__(self, memory_id:str, memory_count:int, type_count:int, memory_type:str, depth:int, content:int,
- creaetd:datetime, expiration:datetime,
- subject:str, predicate:str, object:str,
- embedding_key:str, poignancy:int, keywords:list, filling:list):
+ def __init__(self, memory_id: str, memory_count: int, type_count: int, memory_type: str, depth: int, content: int,
+ creaetd: datetime, expiration: datetime,
+ subject: str, predicate: str, object: str,
+ embedding_key: str, poignancy: int, keywords: list, filling: list):
"""
MemoryBasic继承于MG的Message类,其中content属性替代description属性
Message类中对于Chat类型支持的非常好,对于Agent个体的Perceive,Reflection,Plan支持的并不多
@@ -29,29 +30,30 @@ class MemoryBasic(Message):
self.memory_count: int = memory_count # 第几个记忆,实际数值与Memory相等,但是类型为整数
self.type_count: int = type_count # 第几种记忆,类型为整数(具体不太理解如何生成的)
self.memory_type: str = memory_type # 记忆类型,使用Field,包含 event,thought,chat三种类型
- self.depth:str = depth # 记忆深度,类型为整数
+ self.depth: str = depth # 记忆深度,类型为整数
self.created: datetime = creaetd # 创建时间
self.expiration: datetime = expiration # 记忆失效时间,默认为空()
self.last_accessed: datetime = creaetd # 上一次调用的时间,初始化时候与self.created一致
self.subject: str = subject # 主语,str类型
- self.predicate:str = predicate # 谓语,str类型
- self.object:str = object # 宾语,str类型
+ self.predicate: str = predicate # 谓语,str类型
+ self.object: str = object # 宾语,str类型
self.embedding_key: str = embedding_key # 内容与self.content一致
- self.poignancy:int = poignancy # importance值,整数类型
- self.keywords:list = keywords # keywords,列表
- self.filling:list = filling # None或者列表
+ self.poignancy: int = poignancy # importance值,整数类型
+ self.keywords: list = keywords # keywords,列表
+ self.filling: list = filling # None或者列表
+
class AgentMemory(Memory):
"""
GA中主要存储三种JSON
1. embedding.json (Dict embedding_key:embedding)
- 2. Node.json (Dict Node_id:Node)
- 3. kw_strength.json
+ 2. Node.json (Dict Node_id:Node)
+ 3. kw_strength.json
"""
- def __init__(self, memory_saved:str):
+ def __init__(self, memory_saved: str):
"""
AgentMemory类继承自Memory类,重写storage替代GA中id_to_node,一方面存储所有信息,一方面作为JSON转化
index存储与不同Agent的chat信息
@@ -61,7 +63,7 @@ class AgentMemory(Memory):
self.storage: list[MemoryBasic] = [] # 重写Stroage,存储MemoryBasic所有节点
self.event_list = [] # 存储event记忆
self.thought_list = [] # 存储thought记忆
-
+
self.event_keywords = dict() # 存储keywords
self.thought_keywords = dict()
self.chat_keywords = dict()
diff --git a/examples/st_game/memory/retrieve.py b/examples/st_game/memory/retrieve.py
index 79042df06..5ac4a9b29 100644
--- a/examples/st_game/memory/retrieve.py
+++ b/examples/st_game/memory/retrieve.py
@@ -1,118 +1,122 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
-# @Desc : Retrive函数实现
+# @Desc : Retrieve函数实现
+import datetime
from numpy import dot
from numpy.linalg import norm
-from datetime import datetime
-from associative_memory import AgentMemory,MemoryBasic
+from associative_memory import AgentMemory, MemoryBasic
from utils.utils import embedding_tools
-def agent_retrive(agentmemory:AgentMemory, currtime:datetime, memory_forget:float, query:str, n:int= 30, topk:int=4) -> list[MemoryBasic]:
+
+def agent_retrieve(agent_memory: AgentMemory, curr_time: datetime.datetime, memory_forget: float, query: str, n: int = 30, topk: int = 4) -> list[MemoryBasic]:
"""
- retrive需要集合Role使用,原因在于Role才具有AgentMemory,scratch
- 逻辑:Role调用该函数,self._rc.AgentMemory,self._rc.scratch.currtime,self._rc.scratch.memory_forget
+ Retrieve需要集合Role使用,原因在于Role才具有AgentMemory,scratch
+ 逻辑:Role调用该函数,self._rc.AgentMemory,self._rc.scratch.curr_time,self._rc.scratch.memory_forget
输入希望查询的内容与希望回顾的条数,返回TopK条高分记忆,即List[MemoryBasic]
Score_lists示例
{
- "memory":memories[i], MemoryBasic类
- "importance":memories[i].poignancy
- "recency":衰减因子计算结果
- "relevance":搜索结果
+ "memory": memories[i], MemoryBasic类
+ "importance": memories[i].poignancy
+ "recency": 衰减因子计算结果
+ "relevance": 搜索结果
}
"""
- memories = agentmemory.storage
- sorted_memories = sorted(memories, key=lambda memory_node: memory_node.last_accessed_time,reverse=True)
+ memories = agent_memory.storage
+ sorted_memories = sorted(memories, key=lambda memory_node: memory_node.last_accessed_time, reverse=True)
memories = sorted_memories[:n] if len(sorted_memories) >= n else sorted_memories
- Score_list = []
- Score_list = extract_importance(memories, Score_list)
- Score_list = extract_recency(currtime, memory_forget, Score_list)
- Score_list = extract_relevance(query, Score_list)
- Score_list = normalize_Socre_floats(Score_list, 0, 1)
+ score_list = []
+ score_list = extract_importance(memories, score_list)
+ score_list = extract_recency(curr_time, memory_forget, score_list)
+ score_list = extract_relevance(query, score_list)
+ score_list = normalize_score_floats(score_list, 0, 1)
+
+ total_dict = {}
+ gw = [1, 1, 1] # 三个因素的权重,重要性,近因性,相关性
+ for i in range(len(score_list)):
+ total_score = (score_list[i]['importance'] * gw[0] +
+ score_list[i]['recency'] * gw[1] +
+ score_list[i]['relevance'] * gw[2]
+ )
+ total_dict[score_list[i]['memory']] = total_score
- total_dict = {}
- gw = [1,1,1] # 三个因素的权重,重要性,近因性,相关性
- for i in range(len(Score_list)):
- total_score = (Score_list[i]['importance']*gw[0] +
- Score_list[i]['recency']*gw[1] +
- Score_list[i]['relevance']*gw[2]
- )
- total_dict[Score_list[i]['memory']] = total_score
-
result = top_highest_x_values(total_dict, topk)
return result
+
def top_highest_x_values(d, x):
"""
输入字典,Topx
返回以字典值排序,字典键组成的List[MemoryBasic]
"""
- top_v = [item[0] for item in sorted(d.items(),key=lambda item: item[1],reverse= True)[:x]]
+ top_v = [item[0] for item in sorted(d.items(), key=lambda item: item[1], reverse=True)[:x]]
return top_v
-def extract_importance(memories, Score_list):
+def extract_importance(memories, score_list):
"""
抽取重要性
"""
for i in range(len(memories)):
- Score = {"memory":memories[i],
- "importance":memories[i].poignancy
+ score = {"memory": memories[i],
+ "importance": memories[i].poignancy
}
- Score_list.append(Score)
- return Score_list
+ score_list.append(score)
+ return score_list
-# 抽取相关性
-def extract_relevance(query, Score_list):
+
+def extract_relevance(query, score_list):
"""
抽取相关性
"""
query_embedding = embedding_tools(query)
# 进行
- for i in range(len(Score_list)):
- result = cos_sim(Score_list[i]["memory"].embedding_key, query_embedding)
- Score_list[i]['relevance'] = result
+ for i in range(len(score_list)):
+ result = cos_sim(score_list[i]["memory"].embedding_key, query_embedding)
+ score_list[i]['relevance'] = result
- return Score_list
+ return score_list
-# 抽取近因性
-def extract_recency(currtime, memory_forget, Score_list):
+
+def extract_recency(curr_time, memory_forget, score_list):
"""
抽取近因性,目前使用的现实世界过一天走一个衰减因子
"""
- for i in range(len(Score_list)):
- day_count = (currtime-Score_list[i]['memory'].created).days
- Score_list[i]['recency'] = memory_forget**day_count
- return Score_list
+ for i in range(len(score_list)):
+ day_count = (curr_time - score_list[i]['memory'].created).days
+ score_list[i]['recency'] = memory_forget**day_count
+ return score_list
-def cos_sim(a, b):
- """
- 计算余弦相似度
- """
- return dot(a, b)/(norm(a)*norm(b))
-def normalize_List_floats(Single_list, target_min, target_max):
+def cos_sim(a, b):
+ """
+ 计算余弦相似度
+ """
+ return dot(a, b) / (norm(a) * norm(b))
+
+
+def normalize_list_floats(single_list, target_min, target_max):
"""
单个列表归一化
"""
- min_val = min(Single_list)
- max_val = max(Single_list)
+ min_val = min(single_list)
+ max_val = max(single_list)
range_val = max_val - min_val
- if range_val == 0:
- for i in range(len(Single_list)):
- Single_list[i] = (target_max - target_min)/2
- else:
- for i in range(len(Single_list)):
- Single_list[i] = ((Single_list[i] - min_val) * (target_max - target_min)
- / range_val + target_min)
- return Single_list
+ if range_val == 0:
+ for i in range(len(single_list)):
+ single_list[i] = (target_max - target_min) / 2
+ else:
+ for i in range(len(single_list)):
+ single_list[i] = ((single_list[i] - min_val) * (target_max - target_min)
+ / range_val + target_min)
+ return single_list
-def normalize_socre_floats(Score_list, target_min, target_max):
+def normalize_score_floats(score_list, target_min, target_max):
"""
整体归一化
"""
@@ -120,19 +124,19 @@ def normalize_socre_floats(Score_list, target_min, target_max):
relevance_list = []
recency_list = []
- for i in range(len(Score_list)):
- importance_list.append(Score_list[i]['importance'])
- relevance_list.append(Score_list[i]['relevance'])
- recency_list.append(Score_list[i]['recency'])
+ for i in range(len(score_list)):
+ importance_list.append(score_list[i]['importance'])
+ relevance_list.append(score_list[i]['relevance'])
+ recency_list.append(score_list[i]['recency'])
# 进行归一化操作
- importance_list = normalize_List_floats(importance_list,target_min, target_max)
- relevance_list = normalize_List_floats(relevance_list,target_min, target_max)
- recency_list =normalize_List_floats(recency_list,target_min, target_max)
+ importance_list = normalize_list_floats(importance_list, target_min, target_max)
+ relevance_list = normalize_list_floats(relevance_list, target_min, target_max)
+ recency_list = normalize_list_floats(recency_list, target_min, target_max)
- for i in range(len(Score_list)):
- Score_list[i]['importance'] = importance_list[i]
- Score_list[i]['relevance'] = relevance_list[i]
- Score_list[i]['recency'] = recency_list[i]
-
- return Score_list
+ for i in range(len(score_list)):
+ score_list[i]['importance'] = importance_list[i]
+ score_list[i]['relevance'] = relevance_list[i]
+ score_list[i]['recency'] = recency_list[i]
+
+ return score_list
diff --git a/examples/st_game/prompts/run_gpt_prompts.py b/examples/st_game/prompts/run_gpt_prompts.py
index 86db1c7c2..16ccbc29c 100644
--- a/examples/st_game/prompts/run_gpt_prompts.py
+++ b/examples/st_game/prompts/run_gpt_prompts.py
@@ -1,18 +1,19 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
-# @Desc : 调用PromptTemplates中模板,实现
+# @Desc : 调用Prompts中模板,实现相关Action
-from wrapper_prompt import special_response_generate,prompt_generate
+from wrapper_prompt import special_response_generate, prompt_generate
from memory.scratch import Scratch
from memory.associative_memory import MemoryBasic
import json
-def get_poignancy_action(scratch:Scratch, content:MemoryBasic.content)->str:
+
+def get_poignancy_action(scratch: Scratch, content: MemoryBasic.content) -> str:
"""
衡量事件心酸度
"""
- def create_prompt_input(scratch, content):
- prompt_input = [scratch.name,
+ def create_prompt_input(scratch, content):
+ prompt_input = [scratch.name,
scratch.iss,
scratch.name,
content]
@@ -20,14 +21,13 @@ def get_poignancy_action(scratch:Scratch, content:MemoryBasic.content)->str:
# 1. Prompt构建
# 2. Instruction给出
- prompt_template = "poignancy_chat_v1.txt" ########
- prompt_input = create_prompt_input(scratch, content) ########
+ prompt_template = "poignancy_chat_v1.txt" # 保留原来的注释
+ prompt_input = create_prompt_input(scratch, content) # 保留原来的注释
prompt = prompt_generate(prompt_input, prompt_template)
special_instruction = "The output should ONLY contain ONE integer value on the scale of 1 to 10."
poignancy = special_response_generate(prompt, special_instruction)
try:
poi_dict = json.loads(poignancy)
- return (poi_dict['poignancy'])
- except:
+ return str(poi_dict['poignancy']) # 将返回值强制转换为字符串
+ except json.JSONDecodeError as e:
return poignancy
-
diff --git a/examples/st_game/prompts/wrapper_prompt.py b/examples/st_game/prompts/wrapper_prompt.py
index b61e13520..0950f99d1 100644
--- a/examples/st_game/prompts/wrapper_prompt.py
+++ b/examples/st_game/prompts/wrapper_prompt.py
@@ -1,42 +1,50 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
-# @Desc : 基于Prmopt Templates 填充Prompt; 为Prompt包装与调用
+# @Desc : 基于Prompt Templates 填充Prompt; 为Prompt包装与调用
from metagpt import llm
-def prompt_generate(curr_input:list, prompt_path:str):
- """
- curr_input:输入一个按照PromptTemplate的要求的列表
- prompt_path:输入一个Promptpath
- """
- if type(curr_input) == type("string"):
- curr_input = [curr_input]
- curr_input = [str(i) for i in curr_input]
- f = open(prompt_path, "r")
- prompt = f.read()
- f.close()
- for count, i in enumerate(curr_input):
+def prompt_generate(curr_input: list, prompt_path: str):
+ """
+ curr_input: 输入一个按照Prompt Template的要求的列表
+ prompt_path: 输入一个Prompt path
+ """
+ # 如果输入是字符串,将其转换为列表
+ if isinstance(curr_input, str):
+ curr_input = [curr_input]
+
+ # 将输入列表中的每个元素转换为字符串
+ curr_input = [str(i) for i in curr_input]
+
+ with open(prompt_path, "r") as f:
+ prompt = f.read()
+
+ for count, i in enumerate(curr_input):
prompt = prompt.replace(f"!!", i)
- if "###" in prompt:
+
+ if "###" in prompt:
prompt = prompt.split("###")[1]
+
return prompt.strip()
-def response_generate(prompt:str):
+
+def response_generate(prompt: str):
"""
- 待完善,我没有找到MG中可以设置Temprature以及Maxtoken的位置
+ 待完善,我没有找到MG中可以设置Temperature以及Maxtoken的位置
"""
return llm.ai_func(prompt)
-def special_response_generate(prompt:str,special_instruction:str,example_output:str = None):
+
+def special_response_generate(prompt: str, special_instruction: str, example_output: str = None):
"""
当对于Prompt生成有特殊要求时,调用该函数增加special_instruction或example_output
"""
prompt = '"""\n' + prompt + '\n"""\n'
- prompt += f"Output the response to the prompt above in json. {special_instruction}\n"
- if example_output:
- prompt += "Example output json:\n"
+ prompt += f"Output the response to the prompt above in JSON. {special_instruction}\n"
+
+ if example_output:
+ prompt += "Example output JSON:\n"
prompt += '{"output": "' + str(example_output) + '"}'
+
return response_generate(prompt)
-
-
From 3b42ab42a0fd48d53581d59846330eea18f05c9e Mon Sep 17 00:00:00 2001
From: didi <2020201387@ruc.edu.cn>
Date: Thu, 28 Sep 2023 21:54:54 +0800
Subject: [PATCH 6/9] =?UTF-8?q?=E5=B7=B2=E5=AE=8C=E6=88=90=E6=B3=A8?=
=?UTF-8?q?=E9=87=8A=E4=BF=AE=E6=94=B9?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
...{associative_memory.py => agent_memory.py} | 36 +++++++++----------
examples/st_game/memory/retrieve.py | 12 +++----
examples/st_game/prompts/run_gpt_prompts.py | 4 +--
examples/st_game/roles/st_role.py | 2 +-
4 files changed, 27 insertions(+), 27 deletions(-)
rename examples/st_game/memory/{associative_memory.py => agent_memory.py} (78%)
diff --git a/examples/st_game/memory/associative_memory.py b/examples/st_game/memory/agent_memory.py
similarity index 78%
rename from examples/st_game/memory/associative_memory.py
rename to examples/st_game/memory/agent_memory.py
index c771906ec..adb8a5f1f 100644
--- a/examples/st_game/memory/associative_memory.py
+++ b/examples/st_game/memory/agent_memory.py
@@ -1,6 +1,6 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
-# @Desc : MemoryBasic,AgentMemory实现
+# @Desc : BasicMemory,AgentMemory实现
from metagpt.memory.memory import Memory
from metagpt.schema import Message
@@ -8,14 +8,14 @@ import json
from datetime import datetime
-class MemoryBasic(Message):
+class BasicMemory(Message):
def __init__(self, memory_id: str, memory_count: int, type_count: int, memory_type: str, depth: int, content: int,
creaetd: datetime, expiration: datetime,
subject: str, predicate: str, object: str,
embedding_key: str, poignancy: int, keywords: list, filling: list):
"""
- MemoryBasic继承于MG的Message类,其中content属性替代description属性
+ BasicMemory继承于MG的Message类,其中content属性替代description属性
Message类中对于Chat类型支持的非常好,对于Agent个体的Perceive,Reflection,Plan支持的并不多
在Type设计上,我们延续GA的三个种类,但是对于Chat种类的对话进行特别设计(具体怎么设计还没想好)
"""
@@ -27,23 +27,23 @@ class MemoryBasic(Message):
cause_by 接受一个Action类,在此项目中,每个Agent需要有一个基础动作[Receive] 用于接受假对话Message;而每个Agent需要有独一无二的动作类,用以接受真对话Message
"""
self.memory_id: str = memory_id # 记忆ID
- self.memory_count: int = memory_count # 第几个记忆,实际数值与Memory相等,但是类型为整数
+ self.memory_count: int = memory_count # 第几个记忆,实际数值与Memory相等
self.type_count: int = type_count # 第几种记忆,类型为整数(具体不太理解如何生成的)
- self.memory_type: str = memory_type # 记忆类型,使用Field,包含 event,thought,chat三种类型
+ self.memory_type: str = memory_type # 记忆类型,包含 event,thought,chat三种类型
self.depth: str = depth # 记忆深度,类型为整数
self.created: datetime = creaetd # 创建时间
self.expiration: datetime = expiration # 记忆失效时间,默认为空()
self.last_accessed: datetime = creaetd # 上一次调用的时间,初始化时候与self.created一致
- self.subject: str = subject # 主语,str类型
- self.predicate: str = predicate # 谓语,str类型
- self.object: str = object # 宾语,str类型
+ self.subject: str = subject # 主语
+ self.predicate: str = predicate # 谓语
+ self.object: str = object # 宾语
self.embedding_key: str = embedding_key # 内容与self.content一致
- self.poignancy: int = poignancy # importance值,整数类型
- self.keywords: list = keywords # keywords,列表
- self.filling: list = filling # None或者列表
+ self.poignancy: int = poignancy # importance值
+ self.keywords: list = keywords # keywords
+ self.filling: list = filling # None或者列表
class AgentMemory(Memory):
@@ -60,7 +60,7 @@ class AgentMemory(Memory):
@李嵩@张凯 这里的storage是List,你们需要写一个JSON转化器,将List修改为node.json一致的格式
"""
super.__init__()
- self.storage: list[MemoryBasic] = [] # 重写Stroage,存储MemoryBasic所有节点
+ self.storage: list[BasicMemory] = [] # 重写Stroage,存储BasicMemory所有节点
self.event_list = [] # 存储event记忆
self.thought_list = [] # 存储thought记忆
@@ -71,27 +71,27 @@ class AgentMemory(Memory):
self.strength_event_keywords = dict() # 不知道具体作用,所以没有删除
self.strength_thought_keywords = dict()
- self.embeddings = json.load(open(memory_saved + "/embeddings.json")) # dict类型,load embedding.json
- self.memory_load()
+ self.embeddings = json.load(open(memory_saved + "/embeddings.json"))
+ self.load()
- def memory_save(self):
+ def save(self):
"""
将MemormyBasic类存储为Nodes.json形式。复现GA中的Kw Strength.json形式
@张凯补充一个可调用的函数
"""
pass
- def memory_load(self):
+ def load(self):
"""
将GA的JSON解析,填充到AgentMemory类之中
"""
pass
- def add(self, memory_basic: MemoryBasic):
+ def add(self, memory_basic: BasicMemory):
"""
Add a new message to storage, while updating the index
- 重写add方法,修改原有的Message类为MemoryBasic类,并添加不同的记忆类型添加方式
+ 重写add方法,修改原有的Message类为BasicMemory类,并添加不同的记忆类型添加方式
"""
if memory_basic in self.storage:
return
diff --git a/examples/st_game/memory/retrieve.py b/examples/st_game/memory/retrieve.py
index 5ac4a9b29..97eb3b6f0 100644
--- a/examples/st_game/memory/retrieve.py
+++ b/examples/st_game/memory/retrieve.py
@@ -5,19 +5,19 @@
import datetime
from numpy import dot
from numpy.linalg import norm
-from associative_memory import AgentMemory, MemoryBasic
+from examples.st_game.memory.agent_memory import AgentMemory, BasicMemory
from utils.utils import embedding_tools
-def agent_retrieve(agent_memory: AgentMemory, curr_time: datetime.datetime, memory_forget: float, query: str, n: int = 30, topk: int = 4) -> list[MemoryBasic]:
+def agent_retrieve(agent_memory: AgentMemory, curr_time: datetime.datetime, memory_forget: float, query: str, n: int = 30, topk: int = 4) -> list[BasicMemory]:
"""
Retrieve需要集合Role使用,原因在于Role才具有AgentMemory,scratch
逻辑:Role调用该函数,self._rc.AgentMemory,self._rc.scratch.curr_time,self._rc.scratch.memory_forget
- 输入希望查询的内容与希望回顾的条数,返回TopK条高分记忆,即List[MemoryBasic]
+ 输入希望查询的内容与希望回顾的条数,返回TopK条高分记忆,即List[BasicMemory]
Score_lists示例
{
- "memory": memories[i], MemoryBasic类
+ "memory": memories[i], BasicMemory类
"importance": memories[i].poignancy
"recency": 衰减因子计算结果
"relevance": 搜索结果
@@ -34,7 +34,7 @@ def agent_retrieve(agent_memory: AgentMemory, curr_time: datetime.datetime, memo
score_list = normalize_score_floats(score_list, 0, 1)
total_dict = {}
- gw = [1, 1, 1] # 三个因素的权重,重要性,近因性,相关性
+ gw = [1, 1, 1] # 三个因素的权重,重要性,近因性,相关性
for i in range(len(score_list)):
total_score = (score_list[i]['importance'] * gw[0] +
score_list[i]['recency'] * gw[1] +
@@ -50,7 +50,7 @@ def agent_retrieve(agent_memory: AgentMemory, curr_time: datetime.datetime, memo
def top_highest_x_values(d, x):
"""
输入字典,Topx
- 返回以字典值排序,字典键组成的List[MemoryBasic]
+ 返回以字典值排序,字典键组成的List[BasicMemory]
"""
top_v = [item[0] for item in sorted(d.items(), key=lambda item: item[1], reverse=True)[:x]]
return top_v
diff --git a/examples/st_game/prompts/run_gpt_prompts.py b/examples/st_game/prompts/run_gpt_prompts.py
index 16ccbc29c..14b699c15 100644
--- a/examples/st_game/prompts/run_gpt_prompts.py
+++ b/examples/st_game/prompts/run_gpt_prompts.py
@@ -4,11 +4,11 @@
from wrapper_prompt import special_response_generate, prompt_generate
from memory.scratch import Scratch
-from memory.associative_memory import MemoryBasic
+from examples.st_game.memory.agent_memory import BasicMemory
import json
-def get_poignancy_action(scratch: Scratch, content: MemoryBasic.content) -> str:
+def get_poignancy_action(scratch: Scratch, content: BasicMemory.content) -> str:
"""
衡量事件心酸度
"""
diff --git a/examples/st_game/roles/st_role.py b/examples/st_game/roles/st_role.py
index bc6988e28..047f9545b 100644
--- a/examples/st_game/roles/st_role.py
+++ b/examples/st_game/roles/st_role.py
@@ -17,7 +17,7 @@ from pathlib import Path
from metagpt.roles.role import Role, RoleContext
from metagpt.schema import Message
-from ..memory.associative_memory import AgentMemory
+from ..memory.agent_memory import AgentMemory
from ..actions.dummy_action import DummyAction
from ..actions.user_requirement import UserRequirement
from ..maze_environment import MazeEnvironment
From 0f8f4fba5bdc2f959b23aa1e5e2195fa35318d48 Mon Sep 17 00:00:00 2001
From: didi <2020201387@ruc.edu.cn>
Date: Thu, 28 Sep 2023 22:53:45 +0800
Subject: [PATCH 7/9] =?UTF-8?q?=E5=B0=8F=E9=94=99=E8=AF=AF?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
examples/st_game/memory/agent_memory.py | 6 +++---
1 file changed, 3 insertions(+), 3 deletions(-)
diff --git a/examples/st_game/memory/agent_memory.py b/examples/st_game/memory/agent_memory.py
index adb8a5f1f..db7ff80b8 100644
--- a/examples/st_game/memory/agent_memory.py
+++ b/examples/st_game/memory/agent_memory.py
@@ -11,7 +11,7 @@ from datetime import datetime
class BasicMemory(Message):
def __init__(self, memory_id: str, memory_count: int, type_count: int, memory_type: str, depth: int, content: int,
- creaetd: datetime, expiration: datetime,
+ created: datetime, expiration: datetime,
subject: str, predicate: str, object: str,
embedding_key: str, poignancy: int, keywords: list, filling: list):
"""
@@ -32,9 +32,9 @@ class BasicMemory(Message):
self.memory_type: str = memory_type # 记忆类型,包含 event,thought,chat三种类型
self.depth: str = depth # 记忆深度,类型为整数
- self.created: datetime = creaetd # 创建时间
+ self.created: datetime = created # 创建时间
self.expiration: datetime = expiration # 记忆失效时间,默认为空()
- self.last_accessed: datetime = creaetd # 上一次调用的时间,初始化时候与self.created一致
+ self.last_accessed: datetime = created # 上一次调用的时间,初始化时候与self.created一致
self.subject: str = subject # 主语
self.predicate: str = predicate # 谓语
From 13e75bab8e84046c5281d9f50bfd060f1eeba18a Mon Sep 17 00:00:00 2001
From: didi <2020201387@ruc.edu.cn>
Date: Thu, 28 Sep 2023 23:10:44 +0800
Subject: [PATCH 8/9] Update agent_memory.py
---
examples/st_game/memory/agent_memory.py | 7 +++++--
1 file changed, 5 insertions(+), 2 deletions(-)
diff --git a/examples/st_game/memory/agent_memory.py b/examples/st_game/memory/agent_memory.py
index db7ff80b8..ecf6190d4 100644
--- a/examples/st_game/memory/agent_memory.py
+++ b/examples/st_game/memory/agent_memory.py
@@ -10,7 +10,7 @@ from datetime import datetime
class BasicMemory(Message):
- def __init__(self, memory_id: str, memory_count: int, type_count: int, memory_type: str, depth: int, content: int,
+ def __init__(self, memory_id: str, memory_count: int, type_count: int, memory_type: str, depth: int, content: str,
created: datetime, expiration: datetime,
subject: str, predicate: str, object: str,
embedding_key: str, poignancy: int, keywords: list, filling: list):
@@ -19,7 +19,7 @@ class BasicMemory(Message):
Message类中对于Chat类型支持的非常好,对于Agent个体的Perceive,Reflection,Plan支持的并不多
在Type设计上,我们延续GA的三个种类,但是对于Chat种类的对话进行特别设计(具体怎么设计还没想好)
"""
- super.__init__(content)
+ super().__init__(content)
"""
从父类中继承的属性
content: str # 记忆描述
@@ -128,3 +128,6 @@ class AgentMemory(Memory):
调用
"""
pass
+
+if __name__ == "__main__":
+
\ No newline at end of file
From e035706091ebef6626ca38265c47200e3a42cc5e Mon Sep 17 00:00:00 2001
From: didi <2020201387@ruc.edu.cn>
Date: Fri, 29 Sep 2023 17:49:31 +0800
Subject: [PATCH 9/9] =?UTF-8?q?9.29=20=E6=9B=B4=E6=96=B0?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
完善了一下AgentMemory的Add方法(反思与Plan中都要用),save load方法(需要跟组员对修改新的一版);添加了Scratch类属性与方法(给了一版文档介绍);修改了STrole中STrolecontext属性,添加了Strole中Retrieve方法
---
examples/st_game/memory/agent_memory.py | 243 +++++++++--
examples/st_game/memory/scratch.py | 535 +++++++++++++++++++++++-
examples/st_game/roles/st_role.py | 11 +-
examples/st_game/utils/check.py | 14 +
4 files changed, 771 insertions(+), 32 deletions(-)
create mode 100644 examples/st_game/utils/check.py
diff --git a/examples/st_game/memory/agent_memory.py b/examples/st_game/memory/agent_memory.py
index ecf6190d4..a56100ee7 100644
--- a/examples/st_game/memory/agent_memory.py
+++ b/examples/st_game/memory/agent_memory.py
@@ -10,16 +10,17 @@ from datetime import datetime
class BasicMemory(Message):
- def __init__(self, memory_id: str, memory_count: int, type_count: int, memory_type: str, depth: int, content: str,
+ def __init__(self, memory_id: str, memory_count: int, type_count: int, memory_type: str, depth: int,
created: datetime, expiration: datetime,
subject: str, predicate: str, object: str,
- embedding_key: str, poignancy: int, keywords: list, filling: list):
+ content: str, embedding_key: str, poignancy: int, keywords: list, filling: list,
+ cause_by = ""):
"""
BasicMemory继承于MG的Message类,其中content属性替代description属性
Message类中对于Chat类型支持的非常好,对于Agent个体的Perceive,Reflection,Plan支持的并不多
在Type设计上,我们延续GA的三个种类,但是对于Chat种类的对话进行特别设计(具体怎么设计还没想好)
"""
- super().__init__(content)
+ super().__init__(content,cause_by=cause_by)
"""
从父类中继承的属性
content: str # 记忆描述
@@ -43,8 +44,41 @@ class BasicMemory(Message):
self.embedding_key: str = embedding_key # 内容与self.content一致
self.poignancy: int = poignancy # importance值
self.keywords: list = keywords # keywords
- self.filling: list = filling # None或者列表
+ self.filling: list = filling # 装的与之相关联的memory_id的列表
+ def save_to_dict(self) -> dict:
+ """
+ 将MemoryBasic类转化为字典,用于存储json文件
+ 这里需要注意,cause_by跟GA不兼容,所以需要做一个格式转换
+ """
+ memory_dict = dict()
+ node_id = self.memory_id
+
+ memory_dict[node_id] = dict()
+ memory_dict[node_id]["node_count"] = self.memory_count
+ memory_dict[node_id]["type_count"] = self.type_count
+ memory_dict[node_id]["type"] = self.type
+ memory_dict[node_id]["depth"] = self.depth
+
+ memory_dict[node_id]["cmemory_dicteated"] = self.created.strftime('%Y-%m-%d %H:%M:%S')
+ memory_dict[node_id]["expiration"] = None
+ if self.expiration:
+ memory_dict[node_id]["expiration"] = (self.expiration
+ .strftime('%Y-%m-%d %H:%M:%S'))
+
+ memory_dict[node_id]["subject"] = self.subject
+ memory_dict[node_id]["predicate"] = self.predicate
+ memory_dict[node_id]["object"] = self.object
+
+ memory_dict[node_id]["description"] = self.description
+ memory_dict[node_id]["embedding_key"] = self.embedding_key
+ memory_dict[node_id]["poignancy"] = self.poignancy
+ memory_dict[node_id]["keywords"] = list(self.keywords)
+ memory_dict[node_id]["filling"] = self.filling
+ if self.cause_by:
+ memory_dict[node_id]["cause_by"] = self.cause_by
+
+ return memory_dict
class AgentMemory(Memory):
"""
@@ -68,25 +102,82 @@ class AgentMemory(Memory):
self.thought_keywords = dict()
self.chat_keywords = dict()
- self.strength_event_keywords = dict() # 不知道具体作用,所以没有删除
- self.strength_thought_keywords = dict()
+ self.kw_strength_event = dict() # 关键词影响存储
+ self.kw_strength_thought = dict()
- self.embeddings = json.load(open(memory_saved + "/embeddings.json"))
- self.load()
+ self.load(memory_saved)
- def save(self):
+ def save(self,memory_saved:str):
"""
将MemormyBasic类存储为Nodes.json形式。复现GA中的Kw Strength.json形式
- @张凯补充一个可调用的函数
+ 这里添加一个路径即可
"""
- pass
- def load(self):
+ memory_json = dict()
+ for i in range(len(self.storage)):
+ memory_node = self.storage[i]
+ memory_json.update(memory_node)
+ with open(memory_saved+"/nodes.json", "w") as outfile:
+ json.dump(memory_json, outfile)
+
+ with open(memory_saved+"/embeddings.json", "w") as outfile:
+ json.dump(self.embeddings, outfile)
+
+ strength_json = dict()
+ strength_json["kw_strength_event"] = self.kw_strength_event
+ strength_json["kw_strength_thought"] = self.kw_strength_thought
+ with open(memory_saved+"/kw_strength.json", "w") as outfile:
+ json.dump(strength_json, outfile)
+
+
+ def load(self,memory_saved:str):
"""
将GA的JSON解析,填充到AgentMemory类之中
"""
- pass
+ self.embeddings = json.load(open(memory_saved + "/embeddings.json"))
+ memory_load = json.load(open(memory_saved + "/nodes.json"))
+ for count in range(len(memory_load.keys())):
+ node_id = f"node_{str(count+1)}"
+ node_details = memory_load[node_id]
+ node_type = node_details["type"]
+ created = datetime.datetime.strptime(node_details["created"],
+ '%Y-%m-%d %H:%M:%S')
+ expiration = None
+ if node_details["expiration"]:
+ expiration = datetime.datetime.strptime(node_details["expiration"],
+ '%Y-%m-%d %H:%M:%S')
+
+ if node_details["cause_by"]:
+ cause_by = node_details["cause_by"]
+
+ s = node_details["subject"]
+ p = node_details["predicate"]
+ o = node_details["object"]
+
+ description = node_details["description"]
+ embedding_pair = (node_details["embedding_key"],
+ self.embeddings[node_details["embedding_key"]])
+ poignancy =node_details["poignancy"]
+ keywords = set(node_details["keywords"])
+ filling = node_details["filling"]
+
+ if node_type == "event":
+ self.add_event(created, expiration, s, p, o,
+ description, keywords, poignancy, embedding_pair, filling)
+ elif node_type == "chat":
+ self.add_chat(created, expiration, s, p, o,
+ description, keywords, poignancy, embedding_pair, filling,cause_by)
+ elif node_type == "thought":
+ self.add_thought(created, expiration, s, p, o,
+ description, keywords, poignancy, embedding_pair, filling)
+
+ strength_keywords_load = json.load(open(memory_saved + "/kw_strength.json"))
+ if strength_keywords_load["kw_strength_event"]:
+ self.kw_strength_event = strength_keywords_load["kw_strength_event"]
+ if strength_keywords_load["kw_strength_thought"]:
+ self.kw_strength_thought = strength_keywords_load["kw_strength_thought"]
+
def add(self, memory_basic: BasicMemory):
"""
@@ -97,37 +188,131 @@ class AgentMemory(Memory):
return
self.storage.append(memory_basic)
if memory_basic.cause_by:
- self.index[memory_basic.cause_by].append(memory_basic)
+ self.index[memory_basic.cause_by][0:0] = [memory_basic]
return
if memory_basic.type == "thought":
- self.thought_list.append(memory_basic)
+ self.thought_list[0:0] = [memory_basic]
return
if memory_basic.type == "event":
- self.event_list.append(memory_basic)
+ self.event_list[0:0] = [memory_basic]
- def add_chat(self):
+
+ def add_chat(self, created, expiration, s, p, o,
+ content, keywords, poignancy,
+ embedding_pair, filling,
+ cause_by):
"""
调用add方法,初始化chat,在创建的时候就需要调用embeeding函数
"""
- pass
+ memory_count = len(self.storage) + 1
+ type_count = len(self.thought_list) + 1
+ memory_type = "chat"
+ memory_id = f"memory_{str(memory_count)}"
+ depth = 1
- def add_thought(self):
+ memory_node = BasicMemory(memory_id, memory_count, type_count, memory_type, depth,
+ created, expiration,
+ s, p ,o,
+ content, embedding_pair[0],
+ poignancy, keywords, filling,
+ cause_by)
+
+ keywords = [i.lower() for i in keywords]
+ for kw in keywords:
+ if kw in self.chat_keywords:
+ self.chat_keywords[kw][0:0] = [memory_node]
+ else:
+ self.chat_keywords[kw] = [memory_node]
+
+ self.add(memory_node)
+
+ self.embeddings[embedding_pair[0]] = embedding_pair[1]
+ return memory_node
+
+
+ def add_thought(self, created, expiration, s, p, o,
+ content, keywords, poignancy,
+ embedding_pair, filling):
"""
调用add方法,初始化thought
"""
- pass
+ memory_count = len(self.storage) + 1
+ type_count = len(self.thought_list) + 1
+ memory_type = "event"
+ memory_id = f"memory_{str(memory_count)}"
+ depth = 1
+
+ try:
+ if filling:
+ depth_list = [memory_node.depth for memory_node in self.storage if memory_node.memory_id in filling ]
+ depth += max(depth_list)
+ except:
+ pass
- def add_event(self):
+ memory_node = BasicMemory(memory_id, memory_count, type_count, memory_type, depth,
+ created, expiration,
+ s, p ,o,
+ content, embedding_pair[0],
+ poignancy, keywords, filling)
+
+ keywords = [i.lower() for i in keywords]
+ for kw in keywords:
+ if kw in self.thought_keywords:
+ self.thought_keywords[kw][0:0] = [memory_node]
+ else:
+ self.thought_keywords[kw] = [memory_node]
+
+ self.add(memory_node)
+
+ if f"{p} {o}" != "is idle":
+ for kw in keywords:
+ if kw in self.kw_strength_thought:
+ self.kw_strength_thought[kw] += 1
+ else:
+ self.kw_strength_thought[kw] = 1
+
+ self.embeddings[embedding_pair[0]] = embedding_pair[1]
+ return memory_node
+
+
+ def add_event(self, created, expiration, s, p, o,
+ content, keywords, poignancy,
+ embedding_pair, filling):
"""
调用add方法,初始化event
"""
- pass
+ memory_count = len(self.storage) + 1
+ type_count = len(self.event_list) + 1
+ memory_type = "event"
+ memory_id = f"memory_{str(memory_count)}"
+ depth = 0
+
+ if "(" in content:
+ content = (" ".join(content.split()[:3])
+ + " "
+ + content.split("(")[-1][:-1])
+
+ memory_node = BasicMemory(memory_id, memory_count, type_count, memory_type, depth,
+ created, expiration,
+ s, p ,o,
+ content, embedding_pair[0],
+ poignancy, keywords, filling)
- def retrive(self,):
- """
- 调用
- """
- pass
+ keywords = [i.lower() for i in keywords]
+ for kw in keywords:
+ if kw in self.event_keywords:
+ self.event_keywords[kw][0:0] = [memory_node]
+ else:
+ self.event_keywords[kw] = [memory_node]
+
+ self.add(memory_node)
-if __name__ == "__main__":
-
\ No newline at end of file
+ if f"{p} {o}" != "is idle":
+ for kw in keywords:
+ if kw in self.kw_strength_event:
+ self.kw_strength_event[kw] += 1
+ else:
+ self.kw_strength_event[kw] = 1
+
+ self.embeddings[embedding_pair[0]] = embedding_pair[1]
+ return memory_node
diff --git a/examples/st_game/memory/scratch.py b/examples/st_game/memory/scratch.py
index 00da03dd6..d0d13002e 100644
--- a/examples/st_game/memory/scratch.py
+++ b/examples/st_game/memory/scratch.py
@@ -2,5 +2,536 @@
# -*- coding: utf-8 -*-
# @Desc : Scratch类实现(角色信息类)
-class Scratch():
- pass
\ No newline at end of file
+import datetime
+import json
+import sys
+sys.path.append('../../')
+
+from ..utils.check import check_if_file_exists
+
+class Scratch:
+ def __init__(self, f_saved):
+ # 类别1:人物超参
+ self.vision_r = 4
+ self.att_bandwidth = 3
+ self.retention = 5
+
+ # 类别2:世界信息
+ self.curr_time = None
+ self.curr_tile = None
+ self.daily_plan_req = None
+
+ # 类别3:人物角色的核心身份
+ self.name = None
+ self.first_name = None
+ self.last_name = None
+ self.age = None
+ # L0 permanent core traits.
+ self.innate = None
+ # L1 stable traits.
+ self.learned = None
+ # L2 external implementation.
+ self.currently = None
+ self.lifestyle = None
+ self.living_area = None
+
+ # 类别4:旧反思变量
+ self.concept_forget = 100
+ self.daily_reflection_time = 60 * 3
+ self.daily_reflection_size = 5
+ self.overlap_reflect_th = 2
+ self.kw_strg_event_reflect_th = 4
+ self.kw_strg_thought_reflect_th = 4
+
+ # 类别5:新反思变量
+ self.recency_w = 1
+ self.relevance_w = 1
+ self.importance_w = 1
+ self.recency_decay = 0.99
+ self.importance_trigger_max = 150
+ self.importance_trigger_curr = self.importance_trigger_max
+ self.importance_ele_n = 0
+ self.thought_count = 5
+
+ # 类别6:个人计划
+ self.daily_req = []
+ self.f_daily_schedule = []
+ self.f_daily_schedule_hourly_org = []
+
+ # 类别7:当前动作
+ self.act_address = None
+ self.act_start_time = None
+ self.act_duration = None
+ self.act_description = None
+ self.act_pronunciatio = None
+ self.act_event = (self.name, None, None)
+
+ self.act_obj_description = None
+ self.act_obj_pronunciatio = None
+ self.act_obj_event = (self.name, None, None)
+
+ self.chatting_with = None
+ self.chat = None
+ self.chatting_with_buffer = dict()
+ self.chatting_end_time = None
+
+ self.act_path_set = False
+ self.planned_path = []
+
+ if check_if_file_exists(f_saved):
+ # If we have a bootstrap file, load that here.
+ scratch_load = json.load(open(f_saved))
+
+ self.vision_r = scratch_load["vision_r"]
+ self.att_bandwidth = scratch_load["att_bandwidth"]
+ self.retention = scratch_load["retention"]
+
+ if scratch_load["curr_time"]:
+ self.curr_time = datetime.datetime.strptime(scratch_load["curr_time"],
+ "%B %d, %Y, %H:%M:%S")
+ else:
+ self.curr_time = None
+ self.curr_tile = scratch_load["curr_tile"]
+ self.daily_plan_req = scratch_load["daily_plan_req"]
+
+ self.name = scratch_load["name"]
+ self.first_name = scratch_load["first_name"]
+ self.last_name = scratch_load["last_name"]
+ self.age = scratch_load["age"]
+ self.innate = scratch_load["innate"]
+ self.learned = scratch_load["learned"]
+ self.currently = scratch_load["currently"]
+ self.lifestyle = scratch_load["lifestyle"]
+ self.living_area = scratch_load["living_area"]
+
+ self.concept_forget = scratch_load["concept_forget"]
+ self.daily_reflection_time = scratch_load["daily_reflection_time"]
+ self.daily_reflection_size = scratch_load["daily_reflection_size"]
+ self.overlap_reflect_th = scratch_load["overlap_reflect_th"]
+ self.kw_strg_event_reflect_th = scratch_load["kw_strg_event_reflect_th"]
+ self.kw_strg_thought_reflect_th = scratch_load["kw_strg_thought_reflect_th"]
+
+ self.recency_w = scratch_load["recency_w"]
+ self.relevance_w = scratch_load["relevance_w"]
+ self.importance_w = scratch_load["importance_w"]
+ self.recency_decay = scratch_load["recency_decay"]
+ self.importance_trigger_max = scratch_load["importance_trigger_max"]
+ self.importance_trigger_curr = scratch_load["importance_trigger_curr"]
+ self.importance_ele_n = scratch_load["importance_ele_n"]
+ self.thought_count = scratch_load["thought_count"]
+
+ self.daily_req = scratch_load["daily_req"]
+ self.f_daily_schedule = scratch_load["f_daily_schedule"]
+ self.f_daily_schedule_hourly_org = scratch_load["f_daily_schedule_hourly_org"]
+
+ self.act_address = scratch_load["act_address"]
+ if scratch_load["act_start_time"]:
+ self.act_start_time = datetime.datetime.strptime(
+ scratch_load["act_start_time"],
+ "%B %d, %Y, %H:%M:%S")
+ else:
+ self.curr_time = None
+ self.act_duration = scratch_load["act_duration"]
+ self.act_description = scratch_load["act_description"]
+ self.act_pronunciatio = scratch_load["act_pronunciatio"]
+ self.act_event = tuple(scratch_load["act_event"])
+
+ self.act_obj_description = scratch_load["act_obj_description"]
+ self.act_obj_pronunciatio = scratch_load["act_obj_pronunciatio"]
+ self.act_obj_event = tuple(scratch_load["act_obj_event"])
+
+ self.chatting_with = scratch_load["chatting_with"]
+ self.chat = scratch_load["chat"]
+ self.chatting_with_buffer = scratch_load["chatting_with_buffer"]
+ if scratch_load["chatting_end_time"]:
+ self.chatting_end_time = datetime.datetime.strptime(
+ scratch_load["chatting_end_time"],
+ "%B %d, %Y, %H:%M:%S")
+ else:
+ self.chatting_end_time = None
+
+ self.act_path_set = scratch_load["act_path_set"]
+ self.planned_path = scratch_load["planned_path"]
+
+
+ def save(self, out_json):
+ """
+ Save persona's scratch.
+
+ INPUT:
+ out_json: The file where we wil be saving our persona's state.
+ OUTPUT:
+ None
+ """
+ scratch = dict()
+ scratch["vision_r"] = self.vision_r
+ scratch["att_bandwidth"] = self.att_bandwidth
+ scratch["retention"] = self.retention
+
+ scratch["curr_time"] = self.curr_time.strftime("%B %d, %Y, %H:%M:%S")
+ scratch["curr_tile"] = self.curr_tile
+ scratch["daily_plan_req"] = self.daily_plan_req
+
+ scratch["name"] = self.name
+ scratch["first_name"] = self.first_name
+ scratch["last_name"] = self.last_name
+ scratch["age"] = self.age
+ scratch["innate"] = self.innate
+ scratch["learned"] = self.learned
+ scratch["currently"] = self.currently
+ scratch["lifestyle"] = self.lifestyle
+ scratch["living_area"] = self.living_area
+
+ scratch["concept_forget"] = self.concept_forget
+ scratch["daily_reflection_time"] = self.daily_reflection_time
+ scratch["daily_reflection_size"] = self.daily_reflection_size
+ scratch["overlap_reflect_th"] = self.overlap_reflect_th
+ scratch["kw_strg_event_reflect_th"] = self.kw_strg_event_reflect_th
+ scratch["kw_strg_thought_reflect_th"] = self.kw_strg_thought_reflect_th
+
+ scratch["recency_w"] = self.recency_w
+ scratch["relevance_w"] = self.relevance_w
+ scratch["importance_w"] = self.importance_w
+ scratch["recency_decay"] = self.recency_decay
+ scratch["importance_trigger_max"] = self.importance_trigger_max
+ scratch["importance_trigger_curr"] = self.importance_trigger_curr
+ scratch["importance_ele_n"] = self.importance_ele_n
+ scratch["thought_count"] = self.thought_count
+
+ scratch["daily_req"] = self.daily_req
+ scratch["f_daily_schedule"] = self.f_daily_schedule
+ scratch["f_daily_schedule_hourly_org"] = self.f_daily_schedule_hourly_org
+
+ scratch["act_address"] = self.act_address
+ scratch["act_start_time"] = (self.act_start_time
+ .strftime("%B %d, %Y, %H:%M:%S"))
+ scratch["act_duration"] = self.act_duration
+ scratch["act_description"] = self.act_description
+ scratch["act_pronunciatio"] = self.act_pronunciatio
+ scratch["act_event"] = self.act_event
+
+ scratch["act_obj_description"] = self.act_obj_description
+ scratch["act_obj_pronunciatio"] = self.act_obj_pronunciatio
+ scratch["act_obj_event"] = self.act_obj_event
+
+ scratch["chatting_with"] = self.chatting_with
+ scratch["chat"] = self.chat
+ scratch["chatting_with_buffer"] = self.chatting_with_buffer
+ if self.chatting_end_time:
+ scratch["chatting_end_time"] = (self.chatting_end_time
+ .strftime("%B %d, %Y, %H:%M:%S"))
+ else:
+ scratch["chatting_end_time"] = None
+
+ scratch["act_path_set"] = self.act_path_set
+ scratch["planned_path"] = self.planned_path
+
+ with open(out_json, "w") as outfile:
+ json.dump(scratch, outfile, indent=2)
+
+
+ def get_f_daily_schedule_index(self, advance=0):
+ """
+ We get the current index of self.f_daily_schedule.
+
+ Recall that self.f_daily_schedule stores the decomposed action sequences
+ up until now, and the hourly sequences of the future action for the rest
+ of today. Given that self.f_daily_schedule is a list of list where the
+ inner list is composed of [task, duration], we continue to add up the
+ duration until we reach "if elapsed > today_min_elapsed" condition. The
+ index where we stop is the index we will return.
+
+ INPUT
+ advance: Integer value of the number minutes we want to look into the
+ future. This allows us to get the index of a future timeframe.
+ OUTPUT
+ an integer value for the current index of f_daily_schedule.
+ """
+ # We first calculate teh number of minutes elapsed today.
+ today_min_elapsed = 0
+ today_min_elapsed += self.curr_time.hour * 60
+ today_min_elapsed += self.curr_time.minute
+ today_min_elapsed += advance
+
+ x = 0
+ for task, duration in self.f_daily_schedule:
+ x += duration
+ x = 0
+ for task, duration in self.f_daily_schedule_hourly_org:
+ x += duration
+
+ # We then calculate the current index based on that.
+ curr_index = 0
+ elapsed = 0
+ for task, duration in self.f_daily_schedule:
+ elapsed += duration
+ if elapsed > today_min_elapsed:
+ return curr_index
+ curr_index += 1
+
+ return curr_index
+
+
+ def get_f_daily_schedule_hourly_org_index(self, advance=0):
+ """
+ We get the current index of self.f_daily_schedule_hourly_org.
+ It is otherwise the same as get_f_daily_schedule_index.
+
+ INPUT
+ advance: Integer value of the number minutes we want to look into the
+ future. This allows us to get the index of a future timeframe.
+ OUTPUT
+ an integer value for the current index of f_daily_schedule.
+ """
+ # We first calculate teh number of minutes elapsed today.
+ today_min_elapsed = 0
+ today_min_elapsed += self.curr_time.hour * 60
+ today_min_elapsed += self.curr_time.minute
+ today_min_elapsed += advance
+ # We then calculate the current index based on that.
+ curr_index = 0
+ elapsed = 0
+ for task, duration in self.f_daily_schedule_hourly_org:
+ elapsed += duration
+ if elapsed > today_min_elapsed:
+ return curr_index
+ curr_index += 1
+ return curr_index
+
+
+ def get_str_iss(self):
+ """
+ ISS stands for "identity stable set." This describes the commonset summary
+ of this persona -- basically, the bare minimum description of the persona
+ that gets used in almost all prompts that need to call on the persona.
+
+ INPUT
+ None
+ OUTPUT
+ the identity stable set summary of the persona in a string form.
+ EXAMPLE STR OUTPUT
+ "Name: Dolores Heitmiller
+ Age: 28
+ Innate traits: hard-edged, independent, loyal
+ Learned traits: Dolores is a painter who wants live quietly and paint
+ while enjoying her everyday life.
+ Currently: Dolores is preparing for her first solo show. She mostly
+ works from home.
+ Lifestyle: Dolores goes to bed around 11pm, sleeps for 7 hours, eats
+ dinner around 6pm.
+ Daily plan requirement: Dolores is planning to stay at home all day and
+ never go out."
+ """
+ commonset = ""
+ commonset += f"Name: {self.name}\n"
+ commonset += f"Age: {self.age}\n"
+ commonset += f"Innate traits: {self.innate}\n"
+ commonset += f"Learned traits: {self.learned}\n"
+ commonset += f"Currently: {self.currently}\n"
+ commonset += f"Lifestyle: {self.lifestyle}\n"
+ commonset += f"Daily plan requirement: {self.daily_plan_req}\n"
+ commonset += f"Current Date: {self.curr_time.strftime('%A %B %d')}\n"
+ return commonset
+
+
+ def get_str_name(self):
+ return self.name
+
+
+ def get_str_firstname(self):
+ return self.first_name
+
+
+ def get_str_lastname(self):
+ return self.last_name
+
+
+ def get_str_age(self):
+ return str(self.age)
+
+
+ def get_str_innate(self):
+ return self.innate
+
+
+ def get_str_learned(self):
+ return self.learned
+
+
+ def get_str_currently(self):
+ return self.currently
+
+
+ def get_str_lifestyle(self):
+ return self.lifestyle
+
+
+ def get_str_daily_plan_req(self):
+ return self.daily_plan_req
+
+
+ def get_str_curr_date_str(self):
+ return self.curr_time.strftime("%A %B %d")
+
+
+ def get_curr_event(self):
+ if not self.act_address:
+ return (self.name, None, None)
+ else:
+ return self.act_event
+
+
+ def get_curr_event_and_desc(self):
+ if not self.act_address:
+ return (self.name, None, None, None)
+ else:
+ return (self.act_event[0],
+ self.act_event[1],
+ self.act_event[2],
+ self.act_description)
+
+
+ def get_curr_obj_event_and_desc(self):
+ if not self.act_address:
+ return ("", None, None, None)
+ else:
+ return (self.act_address,
+ self.act_obj_event[1],
+ self.act_obj_event[2],
+ self.act_obj_description)
+
+
+ def add_new_action(self,
+ action_address,
+ action_duration,
+ action_description,
+ action_pronunciatio,
+ action_event,
+ chatting_with,
+ chat,
+ chatting_with_buffer,
+ chatting_end_time,
+ act_obj_description,
+ act_obj_pronunciatio,
+ act_obj_event,
+ act_start_time=None):
+ self.act_address = action_address
+ self.act_duration = action_duration
+ self.act_description = action_description
+ self.act_pronunciatio = action_pronunciatio
+ self.act_event = action_event
+
+ self.chatting_with = chatting_with
+ self.chat = chat
+ if chatting_with_buffer:
+ self.chatting_with_buffer.update(chatting_with_buffer)
+ self.chatting_end_time = chatting_end_time
+
+ self.act_obj_description = act_obj_description
+ self.act_obj_pronunciatio = act_obj_pronunciatio
+ self.act_obj_event = act_obj_event
+
+ self.act_start_time = self.curr_time
+
+ self.act_path_set = False
+
+
+ def act_time_str(self):
+ """
+ Returns a string output of the current time.
+
+ INPUT
+ None
+ OUTPUT
+ A string output of the current time.
+ EXAMPLE STR OUTPUT
+ "14:05 P.M."
+ """
+ return self.act_start_time.strftime("%H:%M %p")
+
+
+ def act_check_finished(self):
+ """
+ Checks whether the self.Action instance has finished.
+
+ INPUT
+ curr_datetime: Current time. If current time is later than the action's
+ start time + its duration, then the action has finished.
+ OUTPUT
+ Boolean [True]: Action has finished.
+ Boolean [False]: Action has not finished and is still ongoing.
+ """
+ if not self.act_address:
+ return True
+
+ if self.chatting_with:
+ end_time = self.chatting_end_time
+ else:
+ x = self.act_start_time
+ if x.second != 0:
+ x = x.replace(second=0)
+ x = (x + datetime.timedelta(minutes=1))
+ end_time = (x + datetime.timedelta(minutes=self.act_duration))
+
+ if end_time.strftime("%H:%M:%S") == self.curr_time.strftime("%H:%M:%S"):
+ return True
+ return False
+
+
+ def act_summarize(self):
+ """
+ Summarize the current action as a dictionary.
+
+ INPUT
+ None
+ OUTPUT
+ ret: A human readable summary of the action.
+ """
+ exp = dict()
+ exp["persona"] = self.name
+ exp["address"] = self.act_address
+ exp["start_datetime"] = self.act_start_time
+ exp["duration"] = self.act_duration
+ exp["description"] = self.act_description
+ exp["pronunciatio"] = self.act_pronunciatio
+ return exp
+
+
+ def act_summary_str(self):
+ """
+ Returns a string summary of the current action. Meant to be
+ human-readable.
+
+ INPUT
+ None
+ OUTPUT
+ ret: A human readable summary of the action.
+ """
+ start_datetime_str = self.act_start_time.strftime("%A %B %d -- %H:%M %p")
+ ret = f"[{start_datetime_str}]\n"
+ ret += f"Activity: {self.name} is {self.act_description}\n"
+ ret += f"Address: {self.act_address}\n"
+ ret += f"Duration in minutes (e.g., x min): {str(self.act_duration)} min\n"
+ return ret
+
+
+ def get_str_daily_schedule_summary(self):
+ ret = ""
+ curr_min_sum = 0
+ for row in self.f_daily_schedule:
+ curr_min_sum += row[1]
+ hour = int(curr_min_sum/60)
+ minute = curr_min_sum%60
+ ret += f"{hour:02}:{minute:02} || {row[0]}\n"
+ return ret
+
+
+ def get_str_daily_schedule_hourly_org_summary(self):
+ ret = ""
+ curr_min_sum = 0
+ for row in self.f_daily_schedule_hourly_org:
+ curr_min_sum += row[1]
+ hour = int(curr_min_sum/60)
+ minute = curr_min_sum%60
+ ret += f"{hour:02}:{minute:02} || {row[0]}\n"
+ return ret
diff --git a/examples/st_game/roles/st_role.py b/examples/st_game/roles/st_role.py
index 047f9545b..8f0c47f7f 100644
--- a/examples/st_game/roles/st_role.py
+++ b/examples/st_game/roles/st_role.py
@@ -21,15 +21,18 @@ from ..memory.agent_memory import AgentMemory
from ..actions.dummy_action import DummyAction
from ..actions.user_requirement import UserRequirement
from ..maze_environment import MazeEnvironment
+from ..memory.retrieve import agent_retrieve
+from ..memory.scratch import Scratch
class STRoleContext(RoleContext):
env: 'MazeEnvironment' = Field(default=None)
memory: AgentMemory = Field(default=AgentMemory)
+ scratch: Scratch = Field(default=Scratch)
class STRole(Role):
-
+ # 继承Role类,Role类继承RoleContext,这里的逻辑需要认真考虑
# add a role's property structure to store role's age and so on like GA's Scratch.
def __init__(self,
@@ -65,6 +68,12 @@ class STRole(Role):
# TODO observe info from maze_env
pass
+ async def retrieve(self, query, n = 30 ,topk = 4):
+ # TODO retrieve memories from agent_memory
+ retrieve_memories = agent_retrieve(self._rc.memory, self._rc.scratch.curr_time, self._rc.scratch.recency_decay, query, n, topk)
+ return retrieve_memories
+
+
async def plan(self):
# TODO make a plan
diff --git a/examples/st_game/utils/check.py b/examples/st_game/utils/check.py
new file mode 100644
index 000000000..0a806fe2d
--- /dev/null
+++ b/examples/st_game/utils/check.py
@@ -0,0 +1,14 @@
+def check_if_file_exists(curr_file):
+ """
+ Checks if a file exists
+ ARGS:
+ curr_file: path to the current csv file.
+ RETURNS:
+ True if the file exists
+ False if the file does not exist
+ """
+ try:
+ with open(curr_file) as f_analysis_file: pass
+ return True
+ except:
+ return False
\ No newline at end of file