From 770dcdc755cd3f34d55bc7f54da90a1f81a628a0 Mon Sep 17 00:00:00 2001
From: didi <2020201387@ruc.edu.cn>
Date: Thu, 28 Sep 2023 00:17:32 +0800
Subject: [PATCH] ga_game Memory & Retrive
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
1. 完善了Memory模块,明天添加不同类型记忆的add方法
2. 添加了Retrive方法
3. 添加了Prompt,Scratch等模块
---
examples/st_game/memory/associative_memory.py | 125 +++++++++++++++-
examples/st_game/memory/retrive.py | 138 ++++++++++++++++++
examples/st_game/memory/scratch.py | 6 +
.../prompts_templates/poignancy_chat_v1.txt | 17 +++
examples/st_game/prompts/run_gpt_prompts.py | 33 +++++
examples/st_game/prompts/wrapper_prompt.py | 42 ++++++
examples/st_game/roles/st_role.py | 4 +-
examples/st_game/utils/utils.py | 9 ++
8 files changed, 369 insertions(+), 5 deletions(-)
create mode 100644 examples/st_game/memory/retrive.py
create mode 100644 examples/st_game/memory/scratch.py
create mode 100644 examples/st_game/prompts/prompts_templates/poignancy_chat_v1.txt
create mode 100644 examples/st_game/prompts/run_gpt_prompts.py
create mode 100644 examples/st_game/prompts/wrapper_prompt.py
diff --git a/examples/st_game/memory/associative_memory.py b/examples/st_game/memory/associative_memory.py
index ef1514398..5b78926b9 100644
--- a/examples/st_game/memory/associative_memory.py
+++ b/examples/st_game/memory/associative_memory.py
@@ -1,9 +1,128 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
-# @Desc : associative_memory to store conversation、plan detail、reflection result and so on.
+# @Desc : MemoryBasic,AgentMemory实现
from metagpt.memory.memory import Memory
+from metagpt.schema import Message
+import json
+from datetime import datetime
+
+class MemoryBasic(Message):
+    """A single memory node: MetaGPT's ``Message`` plus GA-style metadata."""
+
+    def __init__(self, memory_id: str, memory_count: int, type_count: int, memory_type: str, depth: int,
+                 content: str,
+                 creaetd: datetime, expiration: datetime,
+                 subject: str, predicate: str, object: str,
+                 embedding_key: str, poignancy: int, keywords: list, filling: list):
+        """
+        MemoryBasic inherits from MetaGPT's Message class; the ``content`` attribute
+        replaces GA's ``description`` attribute.
+        Message supports the chat type well, but offers little for an individual
+        agent's Perceive, Reflection and Plan.
+        For the type design we keep GA's three kinds, with chat-type conversations
+        to receive a dedicated design (not decided yet).
+
+        Note: the parameter name ``creaetd`` (sic) is kept unchanged so existing
+        keyword callers keep working; it is stored as ``self.created``.
+        """
+        # `super` must be called to obtain the bound proxy; `super.__init__(content)`
+        # would invoke the initializer of the `super` type itself and raise TypeError.
+        super().__init__(content)
+        # Attributes inherited from the parent class:
+        #   content: str                                    -- memory description
+        #   cause_by: Type["Action"] = field(default="")    -- triggering action, only
+        #                                                      initialised when the type is chat
+        # cause_by accepts an Action class. In this project every agent needs a basic
+        # action [Receive] for accepting fake-conversation Messages, and every agent
+        # needs its own unique action class for accepting real-conversation Messages.
+        self.memory_id: str = memory_id  # memory ID
+        self.memory_count: int = memory_count  # ordinal of this memory (same value as the ID, kept as an int)
+        self.type_count: int = type_count  # ordinal within its memory type (how it is generated is still unclear)
+        self.memory_type: str = memory_type  # memory type: one of event, thought, chat
+        self.depth: int = depth  # memory depth, an integer
+
+        self.created: datetime = creaetd  # creation time
+        self.expiration: datetime = expiration  # time the memory expires; empty by default
+        self.last_accessed: datetime = creaetd  # last access time; starts out equal to self.created
+
+        self.subject: str = subject  # subject
+        self.predicate: str = predicate  # predicate
+        self.object: str = object  # object
+
+        self.embedding_key: str = embedding_key  # same text as self.content
+        self.poignancy: int = poignancy  # importance value, an integer
+        self.keywords: list = keywords  # list of keywords
+        self.filling: list = filling  # None or a list
+
+class AgentMemory(Memory):
+    """
+    GA mainly persists three JSON files:
+    1. embedding.json (dict: embedding_key -> embedding)
+    2. Node.json (dict: node_id -> node)
+    3. kw_strength.json
+    """
+
+    def __init__(self, memory_saved: str):
+        """
+        AgentMemory inherits from Memory and overrides ``storage`` as the replacement
+        for GA's id_to_node: it holds every memory node and is the source for the
+        JSON conversion. ``index`` stores the chat information exchanged with other agents.
+
+        TODO(Li Song, Zhang Kai): ``storage`` is a list here; a JSON converter is
+        needed that turns the list into the same format as node.json.
+
+        Args:
+            memory_saved: directory that contains ``embeddings.json``.
+        """
+        # `super()` (called) is required; `super.__init__()` raises TypeError.
+        super().__init__()
+        self.storage: list[MemoryBasic] = []  # overrides storage; holds every MemoryBasic node
+        self.event_list = []  # event memories
+        self.thought_list = []  # thought memories
+
+        self.event_keywords = dict()  # keyword stores
+        self.thought_keywords = dict()
+        self.chat_keywords = dict()
+
+        self.strength_event_keywords = dict()  # purpose unclear, so kept rather than removed
+        self.strength_thought_keywords = dict()
+
+        # dict loaded from embeddings.json; the context manager closes the file
+        # even when parsing fails.
+        with open(memory_saved + "/embeddings.json") as embeddings_file:
+            self.embeddings = json.load(embeddings_file)
+        self.memory_load()
-class AssociativeMemory(Memory):
-    pass
+
+    def memory_save(self):
+        """
+        Store the MemoryBasic nodes in Nodes.json form and reproduce GA's
+        kw_strength.json form.
+        TODO(Zhang Kai): add a callable implementation.
+        """
+        pass
+
+    def memory_load(self):
+        """
+        Parse GA's JSON files and fill them into this AgentMemory.
+        """
+        pass
+
+    def add(self, memory_basic: MemoryBasic):
+        """
+        Add a new memory to storage, while updating the index.
+
+        Overrides ``add`` to take a MemoryBasic instead of a Message and to route
+        it by kind: a node with ``cause_by`` set goes into ``index``, a "thought"
+        into ``thought_list`` and an "event" into ``event_list``. A node that is
+        already stored is ignored.
+        """
+        if memory_basic in self.storage:
+            return
+        self.storage.append(memory_basic)
+        if memory_basic.cause_by:
+            self.index[memory_basic.cause_by].append(memory_basic)
+            return
+        # MemoryBasic stores its kind in `memory_type`; it has no `type` attribute.
+        if memory_basic.memory_type == "thought":
+            self.thought_list.append(memory_basic)
+        elif memory_basic.memory_type == "event":
+            self.event_list.append(memory_basic)
+
+    def add_chat(self):
+        """
+        Call ``add`` to initialise a chat memory; the embedding function needs to
+        be called at creation time.
+        """
+        pass
+
+    def add_thought(self):
+        """
+        Call ``add`` to initialise a thought memory.
+        """
+        pass
+
+    def add_event(self):
+        """
+        Call ``add`` to initialise an event memory.
+        """
+        pass
+
+    def retrive(self,):
+        """
+        Retrieval entry point (not implemented yet).
+        """
+        pass
diff --git a/examples/st_game/memory/retrive.py b/examples/st_game/memory/retrive.py
new file mode 100644
index 000000000..0bca9ea8b
--- /dev/null
+++ b/examples/st_game/memory/retrive.py
@@ -0,0 +1,138 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# @Desc : Retrive函数实现
+
+from numpy import dot
+from numpy.linalg import norm
+from datetime import datetime
+from associative_memory import AgentMemory,MemoryBasic
+from utils.utils import embedding_tools
+
+def agent_retrive(agent: AgentMemory, currtime: datetime, memory_forget: float, query: str,
+                  n: int = 30, topk: int = 4) -> list[MemoryBasic]:
+    """
+    Return the ``topk`` highest-scoring memories for ``query``.
+
+    Retrieval is meant to be driven by a Role, because only the Role holds the
+    AgentMemory and the scratch: the Role calls this function with
+    self._rc.AgentMemory, self._rc.scratch.currtime and self._rc.scratch.memory_forget.
+
+    Args:
+        agent: memory store whose ``storage`` list is searched.
+        currtime: current time, used for the recency score.
+        memory_forget: decay factor applied once per elapsed day.
+        query: text to retrieve memories for.
+        n: number of most recently accessed memories to consider.
+        topk: number of memories to return.
+
+    Returns:
+        Up to ``topk`` MemoryBasic nodes, best first; empty when nothing is stored.
+
+    Each Score_list entry has the form
+    {
+        "memory": MemoryBasic node,
+        "importance": memory.poignancy,
+        "recency": result of the decay computation,
+        "relevance": similarity to the query,
+    }
+    """
+    # Read the instance's storage (not the class) and order by the attribute that
+    # MemoryBasic actually defines, `last_accessed`.
+    recent = sorted(agent.storage, key=lambda node: node.last_accessed, reverse=True)[:n]
+    if not recent:
+        # Nothing to score; also avoids normalising empty lists and a needless embedding call.
+        return []
+
+    Score_list = extract_importance(recent, [])
+    Score_list = extract_recency(currtime, memory_forget, Score_list)
+    # NOTE(review): extract_relevance feeds memory.embedding_key straight into cos_sim,
+    # while MemoryBasic declares that attribute as a str key; the vectors presumably
+    # have to be looked up in agent.embeddings first -- confirm.
+    Score_list = extract_relevance(query, Score_list)
+    Score_list = normalize_Socre_floats(Score_list, 0, 1)
+
+    gw = [1, 1, 1]  # weights of the three factors: importance, recency, relevance
+
+    # Rank the entries directly instead of keying a dict by memory object, so the
+    # memory nodes do not need to be hashable. sorted() is stable, so ties keep
+    # their recency order.
+    ranked = sorted(
+        Score_list,
+        key=lambda entry: (entry['importance'] * gw[0]
+                           + entry['recency'] * gw[1]
+                           + entry['relevance'] * gw[2]),
+        reverse=True,
+    )
+    return [entry['memory'] for entry in ranked[:topk]]
+
+def top_highest_x_values(d, x):
+    """
+    Return the keys of ``d`` that carry the ``x`` largest values.
+
+    The keys are ordered by descending value; for a score dict keyed by memory
+    node this yields a List[MemoryBasic].
+    """
+    ranked_pairs = sorted(d.items(), key=lambda pair: pair[1], reverse=True)
+    return [key for key, _value in ranked_pairs[:x]]
+
+
+def extract_importance(memories, Score_list):
+    """
+    Extract importance.
+
+    Appends one ``{"memory": node, "importance": node.poignancy}`` entry per
+    memory to ``Score_list`` and returns that same list.
+    """
+    Score_list.extend(
+        {"memory": node, "importance": node.poignancy}
+        for node in memories
+    )
+    return Score_list
+
+# Extract relevance
+def extract_relevance(query, Score_list, embeddings=None):
+    """
+    Extract relevance.
+
+    Embeds ``query`` and stores its cosine similarity to every memory under the
+    ``'relevance'`` key of the matching Score_list entry.
+
+    Args:
+        query: text to embed and compare against.
+        Score_list: entries holding a ``"memory"`` node; updated in place.
+        embeddings: optional mapping from ``embedding_key`` to embedding vector.
+            When given, each memory's vector is looked up in it; when omitted,
+            ``embedding_key`` itself is passed to ``cos_sim`` as before, which
+            only works if it already is a numeric vector.
+
+    Returns:
+        The same ``Score_list``.
+    """
+    query_embedding = embedding_tools(query)
+    for entry in Score_list:
+        key = entry["memory"].embedding_key
+        vector = embeddings[key] if embeddings is not None else key
+        entry['relevance'] = cos_sim(vector, query_embedding)
+
+    return Score_list
+
+# Extract recency
+def extract_recency(currtime, memory_forget, Score_list):
+    """
+    Extract recency; currently one decay factor is applied per real-world day elapsed.
+
+    Stores ``memory_forget ** days_since_created`` under the ``'recency'`` key of
+    every entry and returns the same list.
+    """
+    for entry in Score_list:
+        elapsed_days = (currtime - entry['memory'].created).days
+        entry['recency'] = memory_forget ** elapsed_days
+    return Score_list
+
+def cos_sim(a, b):
+    """
+    Compute the cosine similarity of vectors ``a`` and ``b``.
+    """
+    inner_product = dot(a, b)
+    magnitude_product = norm(a) * norm(b)
+    return inner_product / magnitude_product
+
+def normalize_List_floats(Single_list, target_min, target_max):
+    """
+    Min-max normalise a single list into ``[target_min, target_max]``, in place.
+
+    Args:
+        Single_list: list of numbers; modified in place.
+        target_min: lower bound of the target range.
+        target_max: upper bound of the target range.
+
+    Returns:
+        The same list object. An empty list is returned unchanged. When every
+        value is identical, each element becomes the midpoint of the target range.
+    """
+    if not Single_list:
+        # min()/max() raise ValueError on an empty sequence.
+        return Single_list
+
+    min_val = min(Single_list)
+    max_val = max(Single_list)
+    range_val = max_val - min_val
+
+    if range_val == 0:
+        # Midpoint of the target range; the offset by target_min matters whenever
+        # the range does not start at 0.
+        midpoint = target_min + (target_max - target_min) / 2
+        Single_list[:] = [midpoint] * len(Single_list)
+    else:
+        Single_list[:] = [
+            (value - min_val) * (target_max - target_min) / range_val + target_min
+            for value in Single_list
+        ]
+    return Single_list
+
+
+def normalize_Socre_floats(Score_list, target_min, target_max):
+    """
+    Normalise all three score columns of ``Score_list``.
+
+    The 'importance', 'relevance' and 'recency' values are each rescaled into
+    ``[target_min, target_max]`` independently and written back to the entries;
+    the same list is returned.
+    """
+    score_fields = ('importance', 'relevance', 'recency')
+
+    # Gather every column first, entry by entry.
+    columns = {field: [] for field in score_fields}
+    for entry in Score_list:
+        for field in score_fields:
+            columns[field].append(entry[field])
+
+    # Normalise each column.
+    for field in score_fields:
+        columns[field] = normalize_List_floats(columns[field], target_min, target_max)
+
+    # Write the rescaled values back.
+    for position, entry in enumerate(Score_list):
+        for field in score_fields:
+            entry[field] = columns[field][position]
+
+    return Score_list
\ No newline at end of file
diff --git a/examples/st_game/memory/scratch.py b/examples/st_game/memory/scratch.py
new file mode 100644
index 000000000..00da03dd6
--- /dev/null
+++ b/examples/st_game/memory/scratch.py
@@ -0,0 +1,6 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# @Desc : Scratch类实现(角色信息类)
+
+class Scratch:
+    """Role-information (scratch) class; no members are defined yet."""
\ No newline at end of file
diff --git a/examples/st_game/prompts/prompts_templates/poignancy_chat_v1.txt b/examples/st_game/prompts/prompts_templates/poignancy_chat_v1.txt
new file mode 100644
index 000000000..572dd8a05
--- /dev/null
+++ b/examples/st_game/prompts/prompts_templates/poignancy_chat_v1.txt
@@ -0,0 +1,17 @@
+poignancy_chat_v1.txt
+
+!!: agent name
+!!: iss
+!!: name
+!!: event description
+
+###
+Here is a brief description of !!.
+!!
+
+On the scale of 1 to 10, where 1 is purely mundane (e.g., routine morning greetings) and 10 is extremely poignant (e.g., a conversation about breaking up, a fight), rate the likely poignancy of the following conversation for !!.
+
+Conversation:
+!!
+
+Rate (return a number between 1 to 10):
\ No newline at end of file
diff --git a/examples/st_game/prompts/run_gpt_prompts.py b/examples/st_game/prompts/run_gpt_prompts.py
new file mode 100644
index 000000000..4c94f3bea
--- /dev/null
+++ b/examples/st_game/prompts/run_gpt_prompts.py
@@ -0,0 +1,33 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# @Desc : 调用PromptTemplates中模板,实现
+
+from wrapper_prompt import special_response_generate,prompt_generate
+from memory.scratch import Scratch
+from memory.associative_memory import MemoryBasic
+import json
+
+def run_gpt_prompt_chat_poignancy(scratch: Scratch, content: str) -> str:
+    """
+    Rate the poignancy of a conversation for the given role.
+
+    Args:
+        scratch: role information; ``scratch.name`` and ``scratch.iss`` are read.
+        content: text of the conversation (a memory's ``content``).
+
+    Returns:
+        The ``'poignancy'`` field of the model's JSON reply when the reply parses
+        into a mapping with that key; otherwise the raw reply.
+    """
+    from pathlib import Path
+
+    def create_prompt_input(scratch, content):
+        return [scratch.name,
+                scratch.iss,
+                scratch.name,
+                content]
+
+    # 1. Build the prompt
+    # 2. Give the instruction
+    # The template lives in the "prompts_templates" directory next to this module;
+    # resolving it from __file__ keeps the lookup independent of the working directory.
+    prompt_template = str(Path(__file__).resolve().parent / "prompts_templates" / "poignancy_chat_v1.txt")
+    prompt_input = create_prompt_input(scratch, content)
+    prompt = prompt_generate(prompt_input, prompt_template)
+    special_instruction = "The output should ONLY contain ONE integer value on the scale of 1 to 10."
+    poignancy = special_response_generate(prompt, special_instruction)
+    try:
+        poi_dict = json.loads(poignancy)
+        return poi_dict['poignancy']
+    except (TypeError, ValueError, KeyError):
+        # Not JSON (ValueError covers JSONDecodeError), not a mapping, or no
+        # 'poignancy' key: fall back to the raw reply.
+        return poignancy
+
diff --git a/examples/st_game/prompts/wrapper_prompt.py b/examples/st_game/prompts/wrapper_prompt.py
new file mode 100644
index 000000000..b61e13520
--- /dev/null
+++ b/examples/st_game/prompts/wrapper_prompt.py
@@ -0,0 +1,42 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# @Desc : 基于Prmopt Templates 填充Prompt; 为Prompt包装与调用
+
+from metagpt import llm
+
+def prompt_generate(curr_input: list, prompt_path: str):
+    """
+    Fill a prompt template with the given inputs.
+
+    Args:
+        curr_input: list of values in the order the prompt template expects; a
+            single string is treated as a one-element list. Values are converted
+            with ``str``.
+        prompt_path: path of the prompt template file.
+
+    Returns:
+        The filled-in prompt, stripped. If the template contains "###", only the
+        part between the first and the second "###" is returned.
+    """
+    if isinstance(curr_input, str):
+        curr_input = [curr_input]
+    curr_input = [str(i) for i in curr_input]
+
+    # The context manager closes the file even if reading fails.
+    with open(prompt_path, "r") as f:
+        prompt = f.read()
+
+    # NOTE(review): the placeholder carries no index, so the first input replaces
+    # every "!!" and the remaining inputs have nothing left to fill. Presumably a
+    # per-position placeholder was intended -- confirm against the templates.
+    for i in curr_input:
+        prompt = prompt.replace("!!", i)
+    if "###" in prompt:
+        prompt = prompt.split("###")[1]
+    return prompt.strip()
+
+def response_generate(prompt: str):
+    """
+    Send ``prompt`` to the LLM and return its reply.
+
+    TODO: to be improved -- no place to set temperature or max tokens has been
+    found in MetaGPT yet.
+    """
+    reply = llm.ai_func(prompt)
+    return reply
+
+def special_response_generate(prompt: str, special_instruction: str, example_output: str = None):
+    """
+    Call this when the generation has special requirements: it wraps ``prompt``
+    in triple quotes, appends ``special_instruction`` and, when given, an example
+    JSON output built from ``example_output``, then requests the response.
+    """
+    pieces = [
+        '"""\n',
+        prompt,
+        '\n"""\n',
+        f"Output the response to the prompt above in json. {special_instruction}\n",
+    ]
+    if example_output:
+        pieces.append("Example output json:\n")
+        pieces.append('{"output": "' + str(example_output) + '"}')
+    return response_generate("".join(pieces))
+
+
diff --git a/examples/st_game/roles/st_role.py b/examples/st_game/roles/st_role.py
index fa071f069..bc6988e28 100644
--- a/examples/st_game/roles/st_role.py
+++ b/examples/st_game/roles/st_role.py
@@ -17,7 +17,7 @@ from pathlib import Path
from metagpt.roles.role import Role, RoleContext
from metagpt.schema import Message
-from ..memory.associative_memory import AssociativeMemory
+from ..memory.associative_memory import AgentMemory
from ..actions.dummy_action import DummyAction
from ..actions.user_requirement import UserRequirement
from ..maze_environment import MazeEnvironment
@@ -25,7 +25,7 @@ from ..maze_environment import MazeEnvironment
class STRoleContext(RoleContext):
env: 'MazeEnvironment' = Field(default=None)
- memory: AssociativeMemory = Field(default=AssociativeMemory)
+ memory: AgentMemory = Field(default=AgentMemory)
class STRole(Role):
diff --git a/examples/st_game/utils/utils.py b/examples/st_game/utils/utils.py
index a70f7606d..11cbabd8e 100644
--- a/examples/st_game/utils/utils.py
+++ b/examples/st_game/utils/utils.py
@@ -4,6 +4,7 @@
from typing import Any
import json
+import openai
from pathlib import Path
@@ -22,3 +23,11 @@ def read_json_file(json_file: str, encoding=None) -> list[Any]:
def write_json_file(json_file: str, data: list, encoding=None):
with open(json_file, "w", encoding=encoding) as fout:
json.dump(data, fout, ensure_ascii=False, indent=4)
+
+def embedding_tools(query):
+    """
+    Request an embedding for ``query`` from the "text-embedding-ada-002" model
+    and return the embedding of the first result.
+    """
+    response = openai.Embedding.create(model="text-embedding-ada-002", input=query)
+    first_item = response['data'][0]
+    return first_item["embedding"]
\ No newline at end of file