From 4008ce3b1503d524a51a9cbf392ae476b8f1178f Mon Sep 17 00:00:00 2001
From: didi <2020201387@ruc.edu.cn>
Date: Tue, 3 Oct 2023 18:30:00 +0800
Subject: [PATCH] add_inner_voice & reflect_role & agent_memory bug fixs
---
examples/__init__.py | 0
.../st_game/actions/inner_voice_action.py | 37 +++++++++++
.../st_game/actions/run_reflect_action.py | 6 +-
examples/st_game/memory/agent_memory.py | 28 ++++----
examples/st_game/memory/retrieve.py | 1 -
.../prompts/whisper_inner_thought_v1.txt | 11 ++++
examples/st_game/reflect/reflect.py | 6 +-
examples/st_game/roles/st_role.py | 30 +++++++--
.../st_game/tests/actions/test_reflect.py | 0
.../st_game/tests/actions/test_retrieve.py | 0
examples/st_game/tests/test_memory.py | 65 +++++++++++++++++++
11 files changed, 158 insertions(+), 26 deletions(-)
create mode 100644 examples/__init__.py
create mode 100644 examples/st_game/actions/inner_voice_action.py
create mode 100644 examples/st_game/prompts/whisper_inner_thought_v1.txt
create mode 100644 examples/st_game/tests/actions/test_reflect.py
create mode 100644 examples/st_game/tests/actions/test_retrieve.py
create mode 100644 examples/st_game/tests/test_memory.py
diff --git a/examples/__init__.py b/examples/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/examples/st_game/actions/inner_voice_action.py b/examples/st_game/actions/inner_voice_action.py
new file mode 100644
index 000000000..2a9bb0afc
--- /dev/null
+++ b/examples/st_game/actions/inner_voice_action.py
@@ -0,0 +1,37 @@
+import re
+from examples.st_game.roles.st_role import STRole
+from examples.st_game.actions.st_action import STAction
+from examples.st_game.memory.agent_memory import BasicMemory
+from metagpt.logs import logger
+
+class AgentWhisperThoughtAction(STAction):
+
+ def __init__(self, name="AgentWhisperThoughtAction", context: list[BasicMemory] = None, llm=None):
+ super().__init__(name, context, llm)
+
+ def _func_validate(self, llm_resp: str, prompt: str) -> bool:
+ try:
+ self._func_cleanup(llm_resp, prompt)
+ return True
+ except:
+ return False
+
+ def _func_cleanup(self, llm_resp: str, prompt: str = "") -> list:
+ return llm_resp.split('"')[0].strip()
+
+ def _func_fail_default_resp(self) -> str:
+ pass
+
+ async def run(self, role: STRole, statements: str, test_input=None, verbose=False) -> str:
+ def create_prompt_input(role: STRole, statements, test_input=None):
+ prompt_input = [role.scratch.name, statements]
+ return prompt_input
+
+ prompt_input = create_prompt_input(role, statements)
+ prompt = self.generate_prompt_with_tmpl_filename(prompt_input,
+ "whisper_inner_thought_v1.txt")
+
+ output = await self._run_v1(prompt)
+ logger.info(f"Run action: {self.__class__.__name__} with result: {output}")
+ return output
+
diff --git a/examples/st_game/actions/run_reflect_action.py b/examples/st_game/actions/run_reflect_action.py
index ab83a22ee..923cf68eb 100644
--- a/examples/st_game/actions/run_reflect_action.py
+++ b/examples/st_game/actions/run_reflect_action.py
@@ -22,7 +22,7 @@ class AgentFocusPt(STAction):
except:
return False
- def _func_cleanup(self, llm_resp: str, prompt: str = "") -> str:
+ def _func_cleanup(self, llm_resp: str, prompt: str = "") -> list:
llm_resp = "1) " + llm_resp.strip()
ret = []
for i in llm_resp.split("\n"):
@@ -145,7 +145,7 @@ class AgentEventPoignancy(STAction):
except:
return False
- def _func_cleanup(self, llm_resp: str, prompt: str = "") -> str:
+ def _func_cleanup(self, llm_resp: str, prompt: str = "") -> int:
llm_resp = int(llm_resp.strip())
return llm_resp
@@ -186,7 +186,7 @@ class AgentChatPoignancy(STAction):
except:
return False
- def _func_cleanup(self, llm_resp: str, prompt: str = "") -> str:
+ def _func_cleanup(self, llm_resp: str, prompt: str = "") -> int:
llm_resp = int(llm_resp.strip())
return llm_resp
diff --git a/examples/st_game/memory/agent_memory.py b/examples/st_game/memory/agent_memory.py
index 60aa4ae81..617603195 100644
--- a/examples/st_game/memory/agent_memory.py
+++ b/examples/st_game/memory/agent_memory.py
@@ -33,7 +33,7 @@ class BasicMemory(Message):
self.memory_count: int = memory_count # 第几个记忆,实际数值与Memory相等
self.type_count: int = type_count # 第几种记忆,类型为整数(具体不太理解如何生成的)
self.memory_type: str = memory_type # 记忆类型,包含 event,thought,chat三种类型
- self.depth: str = depth # 记忆深度,类型为整数
+ self.depth: int = depth # 记忆深度,类型为整数
self.created: datetime = created # 创建时间
self.expiration: datetime = expiration # 记忆失效时间,默认为空()
@@ -62,10 +62,10 @@ class BasicMemory(Message):
memory_dict[node_id] = dict()
memory_dict[node_id]["node_count"] = self.memory_count
memory_dict[node_id]["type_count"] = self.type_count
- memory_dict[node_id]["type"] = self.type
+ memory_dict[node_id]["type"] = self.memory_type
memory_dict[node_id]["depth"] = self.depth
- memory_dict[node_id]["cmemory_dicteated"] = self.created.strftime('%Y-%m-%d %H:%M:%S')
+ memory_dict[node_id]["created"] = self.created.strftime('%Y-%m-%d %H:%M:%S')
memory_dict[node_id]["expiration"] = None
if self.expiration:
memory_dict[node_id]["expiration"] = (self.expiration
@@ -75,7 +75,7 @@ class BasicMemory(Message):
memory_dict[node_id]["predicate"] = self.predicate
memory_dict[node_id]["object"] = self.object
- memory_dict[node_id]["description"] = self.description
+ memory_dict[node_id]["description"] = self.content
memory_dict[node_id]["embedding_key"] = self.embedding_key
memory_dict[node_id]["poignancy"] = self.poignancy
memory_dict[node_id]["keywords"] = list(self.keywords)
@@ -102,7 +102,7 @@ class AgentMemory(Memory):
"""
super(AgentMemory, self).__init__()
self.id_to_node = dict() # TODO jiayi add
- self.storage: list[BasicMemory] = [] # 重写Stroage,存储BasicMemory所有节点
+ self.storage: list[BasicMemory] = [] # 重写Storage,存储BasicMemory所有节点
self.event_list = [] # 存储event记忆
self.thought_list = [] # 存储thought记忆
self.chat_list = [] # chat-related memory
@@ -122,7 +122,7 @@ class AgentMemory(Memory):
def save(self, memory_saved: str):
"""
- 将MemormyBasic类存储为Nodes.json形式。复现GA中的Kw Strength.json形式
+ 将MemoryBasic类存储为Nodes.json形式。复现GA中的Kw Strength.json形式
这里添加一个路径即可
"""
@@ -152,16 +152,13 @@ class AgentMemory(Memory):
node_id = f"node_{str(count + 1)}"
node_details = memory_load[node_id]
node_type = node_details["type"]
- created = datetime.datetime.strptime(node_details["created"],
+ created = datetime.strptime(node_details["created"],
'%Y-%m-%d %H:%M:%S')
expiration = None
if node_details["expiration"]:
- expiration = datetime.datetime.strptime(node_details["expiration"],
+ expiration = datetime.strptime(node_details["expiration"],
'%Y-%m-%d %H:%M:%S')
- if node_details["cause_by"]:
- cause_by = node_details["cause_by"]
-
s = node_details["subject"]
p = node_details["predicate"]
o = node_details["object"]
@@ -177,6 +174,7 @@ class AgentMemory(Memory):
self.add_event(created, expiration, s, p, o,
description, keywords, poignancy, embedding_pair, filling)
elif node_type == "chat":
+ cause_by = node_details["cause_by"]
self.add_chat(created, expiration, s, p, o,
description, keywords, poignancy, embedding_pair, filling, cause_by)
elif node_type == "thought":
@@ -200,10 +198,10 @@ class AgentMemory(Memory):
if memory_basic.cause_by:
self.index[memory_basic.cause_by][0:0] = [memory_basic]
return
- if memory_basic.type == "thought":
+ if memory_basic.memory_type == "thought":
self.thought_list[0:0] = [memory_basic]
return
- if memory_basic.type == "event":
+ if memory_basic.memory_type == "event":
self.event_list[0:0] = [memory_basic]
def add_chat(self, created, expiration, s, p, o,
@@ -211,7 +209,7 @@ class AgentMemory(Memory):
embedding_pair, filling,
cause_by):
"""
- 调用add方法,初始化chat,在创建的时候就需要调用embeeding函数
+ 调用add方法,初始化chat,在创建的时候就需要调用embedding函数
"""
memory_count = len(self.storage) + 1
type_count = len(self.thought_list) + 1
@@ -330,7 +328,7 @@ class AgentMemory(Memory):
ret_set.add(e_node.summary())
return ret_set
- def get_last_chat(self, target_role_name: str) -> str:
+ def get_last_chat(self, target_role_name: str):
if target_role_name.lower() in self.chat_keywords:
return self.chat_keywords[target_role_name.lower()][0]
else:
diff --git a/examples/st_game/memory/retrieve.py b/examples/st_game/memory/retrieve.py
index 9f19a41d4..0656d5c05 100644
--- a/examples/st_game/memory/retrieve.py
+++ b/examples/st_game/memory/retrieve.py
@@ -3,7 +3,6 @@
# @Desc : Retrieve函数实现
import datetime
-from typing import Union
from numpy import dot
from numpy.linalg import norm
diff --git a/examples/st_game/prompts/whisper_inner_thought_v1.txt b/examples/st_game/prompts/whisper_inner_thought_v1.txt
new file mode 100644
index 000000000..b1ed50aaa
--- /dev/null
+++ b/examples/st_game/prompts/whisper_inner_thought_v1.txt
@@ -0,0 +1,11 @@
+whisper_inner_thought_v1.txt
+
+Variables:
+!! -- init persona name
+!! -- whisper
+
+###
+Translate the following thought into a statement about !!.
+
+Thought: "!!"
+Statement: "
\ No newline at end of file
diff --git a/examples/st_game/reflect/reflect.py b/examples/st_game/reflect/reflect.py
index 6c19cf3fc..6cb7d86f2 100644
--- a/examples/st_game/reflect/reflect.py
+++ b/examples/st_game/reflect/reflect.py
@@ -5,7 +5,7 @@
import datetime
from metagpt.logs import logger
-
+from examples.st_game.roles.st_role import STRole
from examples.st_game.utils.utils import get_embedding
from examples.st_game.actions.run_reflect_action import (
AgentFocusPt, AgentInsightAndGuidance, AgentEventTriple,
@@ -62,7 +62,7 @@ def generate_action_event_triple(act_desp, role):
return AgentEventTriple(act_desp, role)
-def generate_poig_score(role: "STRole", event_type, description):
+def generate_poig_score(role: STRole, event_type, description):
if "is idle" in description:
return 1
@@ -167,7 +167,7 @@ def reset_reflection_counter(role: "STRole"):
# Question 1 chat函数
-def reflect(role: "STRole"):
+def role_reflect(role: "STRole"):
"""
The main reflection module for the role. We first check if the trigger
conditions are met, and if so, run the reflection and reset any of the
diff --git a/examples/st_game/roles/st_role.py b/examples/st_game/roles/st_role.py
index bd96b70c6..43e2e0472 100644
--- a/examples/st_game/roles/st_role.py
+++ b/examples/st_game/roles/st_role.py
@@ -34,6 +34,9 @@ from examples.st_game.utils.utils import get_embedding, path_finder
from examples.st_game.utils.const import collision_block_id, STORAGE_PATH
from examples.st_game.reflect.reflect import generate_poig_score
from examples.st_game.utils.mg_ga_transform import save_movement, get_role_environment
+from examples.st_game.actions.inner_voice_action import AgentWhisperThoughtAction
+from examples.st_game.actions.run_reflect_action import AgentEventTriple
+from examples.st_game.reflect.reflect import role_reflect
class STRoleContext(RoleContext):
@@ -122,13 +125,32 @@ class STRole(Role):
if len(self._rc.news) == 1 and isinstance(self._rc.news[0], UserRequirement):
# add inner voice
# TODO
+ self.add_inner_voice(self._rc.news[0].content)
logger.warning(f"Role: {self.name} add inner voice: {self._rc.news[0].content}")
return 1 # always return 1 to execute role's `_react`
- def add_inner_voice(self):
+ def add_inner_voice(self, whisper):
# TODO
- pass
+ def generate_inner_thought(strole: STRole, whisper):
+ run_whisper_thought = AgentWhisperThoughtAction()
+ inner_thought = run_whisper_thought.run(self, whisper)
+ return inner_thought
+
+ whisper = input("Enter Input: ")
+ thought = generate_inner_thought(whisper)
+
+ created = self._rc.scratch.curr_time
+ expiration = self._rc.scratch.curr_time + datetime.timedelta(days=30)
+ run_event_triple = AgentEventTriple()
+ s, p, o = run_event_triple(thought, self)
+ keywords = set([s, p, o])
+ thought_poignancy = generate_poig_score(self, "event", whisper)
+ thought_embedding_pair = (thought, get_embedding(thought))
+ self._rc.memory.add_thought(created, expiration, s, p, o,
+ thought, keywords, thought_poignancy,
+ thought_embedding_pair, None)
+
async def observe(self) -> list[BasicMemory]:
# TODO observe info from maze_env
@@ -288,9 +310,9 @@ class STRole(Role):
async def reflect(self):
# TODO reflection if meet reflect condition
-
+ role_reflect(self)
# TODO re-add result to memory
- pass
+ # 已封装到Reflect函数之中
def execute(self, plan: str):
"""
diff --git a/examples/st_game/tests/actions/test_reflect.py b/examples/st_game/tests/actions/test_reflect.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/examples/st_game/tests/actions/test_retrieve.py b/examples/st_game/tests/actions/test_retrieve.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/examples/st_game/tests/test_memory.py b/examples/st_game/tests/test_memory.py
new file mode 100644
index 000000000..9c0354b36
--- /dev/null
+++ b/examples/st_game/tests/test_memory.py
@@ -0,0 +1,65 @@
+from datetime import datetime
+from metagpt.logs import logger
+from ..memory.agent_memory import AgentMemory, BasicMemory
+
+# Create some sample BasicMemory instances
+memory1 = BasicMemory(
+ memory_id="1",
+ memory_count=1,
+ type_count=1,
+ memory_type="event",
+ depth=1,
+ created=datetime.now(),
+ expiration=datetime.now(),
+ subject="Subject1",
+ predicate="Predicate1",
+ object="Object1",
+ content="This is content 1",
+ embedding_key="embedding_key_1",
+ poignancy=1,
+ keywords=["keyword1", "keyword2"],
+ filling=["memory_id_2"]
+)
+
+memory2 = BasicMemory(
+ memory_id="2",
+ memory_count=2,
+ type_count=2,
+ memory_type="thought",
+ depth=2,
+ created=datetime.now(),
+ expiration=None,
+ subject="Subject2",
+ predicate="Predicate2",
+ object="Object2",
+ content="This is content 2",
+ embedding_key="embedding_key_2",
+ poignancy=2,
+ keywords=["keyword3", "keyword4"],
+ filling=[]
+)
+
+if __name__ == "__main__":
+ # Create an AgentMemory instance and add the created BasicMemory instances
+ agent_memory = AgentMemory(memory_saved="sample_memory_folder")
+ agent_memory.add_event(memory1)
+ agent_memory.add_thought(memory2)
+
+ # Save the AgentMemory to a JSON file
+ agent_memory.save("sample_memory_folder")
+
+ # Load the AgentMemory from the JSON file
+ loaded_agent_memory = AgentMemory(memory_saved="sample_memory_folder")
+
+ # Get the summarized latest events
+ latest_events = loaded_agent_memory.get_summarized_latest_events(retention=2)
+ print("Summarized Latest Events:")
+ for event in latest_events:
+ print(event)
+
+ # Get the last chat for a specific role
+ last_chat = loaded_agent_memory.get_last_chat(target_role_name="role1")
+ if last_chat:
+ print(f"Last chat for role1: {last_chat.content}")
+ else:
+ print("No chat found for role1")