diff --git a/examples/st_game/actions/run_reflect_action.py b/examples/st_game/actions/run_reflect_action.py new file mode 100644 index 000000000..7d4a897b8 --- /dev/null +++ b/examples/st_game/actions/run_reflect_action.py @@ -0,0 +1,277 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : integration reflect action + +import re +from ..roles.st_role import STRole +from ..actions.st_action import STAction +from ..memory.agent_memory import BasicMemory + +# run_gpt_prompt_focal_pt方法 +class AgentFocusPt(STAction): + + def __init__(self, name="AgentFocusPt", context: list[BasicMemory] = None, llm=None): + super().__init__(name, context, llm) + + def _func_validate(self, llm_resp: str, prompt: str) -> bool: + try: + self._func_cleanup(llm_resp, prompt) + return True + except: + return False + + def _func_cleanup(self, llm_resp: str, prompt: str = "") -> str: + llm_resp = "1) " + llm_resp.strip() + ret = [] + for i in llm_resp.split("\n"): + ret += [i.split(") ")[-1]] + return ret + + def _func_fail_default_resp(self) -> str: + pass + + async def run(self, role: STRole, statements: str, n: int, test_input = None) -> str: + def create_prompt_input(role: STRole, statements, n, test_input=None): + prompt_input = [statements, str(n)] + return prompt_input + + prompt_input = create_prompt_input(role, statements,n) + prompt = self.generate_prompt_with_tmpl_filename(prompt_input, + "generate_focal_pt_v1.txt") + + example_output = '["What should Jane do for lunch", "Does Jane like strawberry", "Who is Jane"]' + special_instruction = "Output must be a list of str." 
+ output = await self._run_v2(prompt, + example_output, + special_instruction) + + return output[0] + + +# run_gpt_prompt_insight_and_guidance +class AgentInsightAndGuidance(STAction): + + def __init__(self, name="AgentInsightAndGuidance", context: list[BasicMemory] = None, llm=None): + super().__init__(name, context, llm) + + def _func_validate(self, llm_resp: str, prompt: str) -> bool: + try: + self._func_cleanup(llm_resp, prompt) + return True + except: + return False + + def _func_cleanup(self, llm_resp: str, prompt: str = "") -> str: + llm_resp = "1. " + llm_resp.strip() + ret = dict() + for i in llm_resp.split("\n"): + row = i.split(". ")[-1] + thought = row.split("(because of ")[0].strip() + evi_raw = row.split("(because of ")[1].split(")")[0].strip() + evi_raw = re.findall(r'\d+', evi_raw) + evi_raw = [int(i.strip()) for i in evi_raw] + ret[thought] = evi_raw + return ret + + def _func_fail_default_resp(self) -> str: + pass + + async def run(self, role: STRole, statements: str, n: int, test_input = None) -> str: + def create_prompt_input(role: STRole, statements, n, test_input=None): + prompt_input = [statements, str(n)] + return prompt_input + + prompt_input = create_prompt_input(role, statements,n) + prompt = self.generate_prompt_with_tmpl_filename(prompt_input, + "insight_and_evidence_v1.txt") + + output = await self._run_v1(prompt) + + return output[0] + +# run_gpt_prompt_event_triple +class AgentEventTriple(STAction): + def __init__(self, name="AgentEventTriple", context: list[BasicMemory] = None, llm=None): + super().__init__(name, context, llm) + + def _func_validate(self, llm_resp: str, prompt: str) -> bool: + try: + llm_resp = self._func_cleanup(llm_resp, prompt="") + if len(llm_resp) != 2: + return False + except: return False + return True + + def _func_cleanup(self, llm_resp: str, prompt: str = "") -> str: + cr = llm_resp.strip() + cr = [i.strip() for i in cr.split(")")[0].split(",")] + return cr + + def _func_fail_default_resp(self) -> str: + 
pass + + async def run(self, statements: str, role: STRole, verbose = False) -> str: + def create_prompt_input(statements, role): + if "(" in statements: + statements = statements.split("(")[-1].split(")")[0] + prompt_input = [role._rc.scratch.name, + statements, + role._rc.scratch.name] + return prompt_input + + prompt_input = create_prompt_input(statements, role) + prompt = self.generate_prompt_with_tmpl_filename(prompt_input, + "generate_event_triple_v1.txt") + + output = await self._run_v1(prompt) + + return output[0] + +# run_gpt_prompt_event_poignancy +class AgentEventPoignancy(STAction): + def __init__(self, name="AgentEventPoignancy", context: list[BasicMemory] = None, llm=None): + super().__init__(name, context, llm) + + def _func_validate(self, llm_resp: str, prompt: str) -> bool: + try: + self._func_cleanup(llm_resp, prompt) + return True + except: + return False + + def _func_cleanup(self, llm_resp: str, prompt: str = "") -> str: + llm_resp = int(llm_resp.strip()) + return llm_resp + + def _func_fail_default_resp(self) -> str: + pass + + async def run(self, role: STRole, statements: str, test_input = None, verbose = False) -> str: + def create_prompt_input(role: STRole, statements: str, test_input=None): + prompt_input = [role._rc.scratch.name, + role._rc.scratch.get_str_iss(), + role._rc.scratch.name, + statements] + return prompt_input + + prompt_input = create_prompt_input(role, statements) + prompt = self.generate_prompt_with_tmpl_filename(prompt_input, + "poignancy_event_v1.txt") + + example_output = "5" ######## + special_instruction = "The output should ONLY contain ONE integer value on the scale of 1 to 10." 
+ output = await self._run_v2(prompt, + example_output, + special_instruction) + + return output[0] + +# run_gpt_prompt_chat_poignancy +class AgentChatPoignancy(STAction): + def __init__(self, name="AgentChatPoignancy", context: list[BasicMemory] = None, llm=None): + super().__init__(name, context, llm) + + def _func_validate(self, llm_resp: str, prompt: str) -> bool: + try: + self._func_cleanup(llm_resp, prompt) + return True + except: + return False + + def _func_cleanup(self, llm_resp: str, prompt: str = "") -> str: + llm_resp = int(llm_resp.strip()) + return llm_resp + + def _func_fail_default_resp(self) -> str: + pass + + async def run(self, role: STRole, statements: str, test_input = None, verbose = False) -> str: + def create_prompt_input(role: STRole, statements, test_input=None): + prompt_input = [role._rc.scratch.name, + role._rc.scratch.get_str_iss(), + role._rc.scratch.name, + statements] + return prompt_input + + prompt_input = create_prompt_input(role, statements) + prompt = self.generate_prompt_with_tmpl_filename(prompt_input, + "poignancy_chat_v1.txt") + + example_output = "5" ######## + special_instruction = "The output should ONLY contain ONE integer value on the scale of 1 to 10." 
+ output = await self._run_v2(prompt, + example_output, + special_instruction) + + return output[0] + +# run_gpt_prompt_planning_thought_on_convo +class AgentPlanThoughtOnConvo(STAction): + def __init__(self, name="AgentPlanThoughtOnConvo", context: list[BasicMemory] = None, llm=None): + super().__init__(name, context, llm) + + def _func_validate(self, llm_resp: str, prompt: str) -> bool: + try: + self._func_cleanup(llm_resp, prompt) + return True + except: + return False + + def _func_cleanup(self, llm_resp: str, prompt: str = "") -> str: + return llm_resp.split('"')[0].strip() + + def _func_fail_default_resp(self) -> str: + pass + + async def run(self, role: STRole, statements: str, test_input = None, verbose = False) -> str: + def create_prompt_input(role, statements, test_input=None): + prompt_input = [statements, + role._rc.scratch.name, + role._rc.scratch.name, + role._rc.scratch.name] + return prompt_input + + prompt_input = create_prompt_input(role, statements) + prompt = self.generate_prompt_with_tmpl_filename(prompt_input, + "planning_thought_on_convo_v1.txt") + + output = await self._run_v1(prompt) + + return output[0] + +# run_gpt_prompt_memo_on_convo +class AgentMemoryOnConvo(STAction): + def __init__(self, name="AgentMemoryOnConvo", context: list[BasicMemory] = None, llm=None): + super().__init__(name, context, llm) + + def _func_validate(self, llm_resp: str, prompt: str) -> bool: + try: + self._func_cleanup(llm_resp, prompt) + return True + except: + return False + + def _func_cleanup(self, llm_resp: str, prompt: str = "") -> str: + return llm_resp.split('"')[0].strip() + + def _func_fail_default_resp(self) -> str: + pass + + async def run(self, role: STRole, statements: str, test_input = None, verbose = False) -> str: + def create_prompt_input(role, statements, test_input=None): + prompt_input = [statements, + role._rc.scratch.name, + role._rc.scratch.name, + role._rc.scratch.name] + return prompt_input + + prompt_input = create_prompt_input(role, 
statements) + prompt = self.generate_prompt_with_tmpl_filename(prompt_input, + "memo_on_convo_v1.txt") + example_output = 'Jane Doe was interesting to talk to.' + special_instruction = 'The output should ONLY contain a string that summarizes anything interesting that the agent may have noticed' + output = await self._run_v2(prompt, + example_output, + special_instruction) + + return output[0] \ No newline at end of file diff --git a/examples/st_game/compress_sim_storage.py b/examples/st_game/compress_sim_storage.py new file mode 100644 index 000000000..348c231d3 --- /dev/null +++ b/examples/st_game/compress_sim_storage.py @@ -0,0 +1,66 @@ +""" +Author: Joon Sung Park (joonspk@stanford.edu) + +File: compress_sim_storage.py +Description: Compresses a simulation for replay demos. +""" + +import shutil +import json +from utils.tools import find_filenames, create_folder_if_not_there + +def compress(sim_code): + # 构建文件路径 + sim_storage = f"../environment/frontend_server/storage/{sim_code}" + compressed_storage = f"../environment/frontend_server/compressed_storage/{sim_code}" + persona_folder = sim_storage + "/personas" + move_folder = sim_storage + "/movement" + meta_file = sim_storage + "/reverie/meta.json" + + # 收集角色名称 + persona_names = [] + for i in find_filenames(persona_folder, ""): + x = i.split("/")[-1].strip() + if x[0] != ".": + persona_names += [x] + + # 最大移动计算 + max_move_count = max([int(i.split("/")[-1].split(".")[0]) + for i in find_filenames(move_folder, "json")]) + + persona_last_move = dict() + master_move = dict() + for i in range(max_move_count + 1): + master_move[i] = dict() + with open(f"{move_folder}/{str(i)}.json") as json_file: + i_move_dict = json.load(json_file)["persona"] + for p in persona_names: + move = False + if i == 0: + move = True + elif (i_move_dict[p]["movement"] != persona_last_move[p]["movement"] + or i_move_dict[p]["pronunciatio"] != persona_last_move[p]["pronunciatio"] + or i_move_dict[p]["description"] != 
persona_last_move[p]["description"] + or i_move_dict[p]["chat"] != persona_last_move[p]["chat"]): + move = True + + if move: + persona_last_move[p] = {"movement": i_move_dict[p]["movement"], + "pronunciatio": i_move_dict[p]["pronunciatio"], + "description": i_move_dict[p]["description"], + "chat": i_move_dict[p]["chat"]} + master_move[i][p] = {"movement": i_move_dict[p]["movement"], + "pronunciatio": i_move_dict[p]["pronunciatio"], + "description": i_move_dict[p]["description"], + "chat": i_move_dict[p]["chat"]} + + # 创建存储目录 + create_folder_if_not_there(compressed_storage) + with open(f"{compressed_storage}/master_movement.json", "w") as outfile: + outfile.write(json.dumps(master_move, indent=2)) + + shutil.copyfile(meta_file, f"{compressed_storage}/meta.json") + shutil.copytree(persona_folder, f"{compressed_storage}/personas/") + +if __name__ == '__main__': + compress("July1_the_ville_isabella_maria_klaus-step-3-9") diff --git a/examples/st_game/memory/agent_memory.py b/examples/st_game/memory/agent_memory.py index 46ee3cb0e..8cce0964d 100644 --- a/examples/st_game/memory/agent_memory.py +++ b/examples/st_game/memory/agent_memory.py @@ -36,7 +36,7 @@ class BasicMemory(Message): self.created: datetime = created # 创建时间 self.expiration: datetime = expiration # 记忆失效时间,默认为空() - self.last_accessed: datetime = created # 上一次调用的时间,初始化时候与self.created一致 + self.last_accessed: datetime = self.created # 上一次调用的时间,初始化时候与self.created一致 self.subject: str = subject # 主语 self.predicate: str = predicate # 谓语 diff --git a/examples/st_game/memory/retrieve.py b/examples/st_game/memory/retrieve.py index e41145812..423ff224a 100644 --- a/examples/st_game/memory/retrieve.py +++ b/examples/st_game/memory/retrieve.py @@ -10,10 +10,9 @@ from numpy.linalg import norm from ..memory.agent_memory import AgentMemory, BasicMemory from ..utils.utils import get_embedding +from ..roles.st_role import STRole - -def agent_retrieve(agent_memory: AgentMemory, curr_time: datetime.datetime, memory_forget: 
float, query: str, - n: int = 30, topk: int = 4) -> list[BasicMemory]: +def agent_retrieve(agent_memory: AgentMemory, curr_time: datetime.datetime, memory_forget: float, query: str, topk: int = 4) -> list[BasicMemory]: """ Retrieve需要集合Role使用,原因在于Role才具有AgentMemory,scratch 逻辑:Role调用该函数,self._rc.AgentMemory,self._rc.scratch.curr_time,self._rc.scratch.memory_forget @@ -28,8 +27,7 @@ def agent_retrieve(agent_memory: AgentMemory, curr_time: datetime.datetime, memo } """ memories = agent_memory.storage - sorted_memories = sorted(memories, key=lambda memory_node: memory_node.last_accessed_time, reverse=True) - memories = sorted_memories[:n] if len(sorted_memories) >= n else sorted_memories + memories = sorted(memories, key=lambda memory_node: memory_node.last_accessed, reverse=True) score_list = [] score_list = extract_importance(memories, score_list) @@ -38,7 +36,7 @@ def agent_retrieve(agent_memory: AgentMemory, curr_time: datetime.datetime, memo score_list = normalize_score_floats(score_list, 0, 1) total_dict = {} - gw = [1, 1, 1] # 三个因素的权重,重要性,近因性,相关性 + gw = [1, 1, 1] # 三个因素的权重,重要性,近因性,相关性, for i in range(len(score_list)): total_score = (score_list[i]['importance'] * gw[0] + score_list[i]['recency'] * gw[1] + @@ -48,7 +46,25 @@ def agent_retrieve(agent_memory: AgentMemory, curr_time: datetime.datetime, memo result = top_highest_x_values(total_dict, topk) - return result + return result # 返回的是一个BasicMemory列表 + +def new_agent_retrieve(strole: STRole, focus_points: list, n_count = 30): + """ + 输入为Strole,关注点列表,返回记忆数量 + 输出为字典,键为focus_point,值为对应的记忆列表 + """ + retrieved = dict() + for focal_pt in focus_points: + nodes = [[i.last_accessed, i] + for i in strole._rc.memory.event_list + strole._rc.memory.thought_list + if "idle" not in i.embedding_key] + nodes = sorted(nodes, key=lambda x: x[0]) + nodes = [i for created, i in nodes] + results = agent_retrieve(strole._rc.memory, strole._rc.scratch.curr_time, strole._rc.scratch.recency_decay, focal_pt, n_count) + for n in results: 
+ n.last_accessed = strole._rc.scratch.curr_time + + retrieved[focal_pt] = results def top_highest_x_values(d, x): @@ -143,117 +159,4 @@ def normalize_score_floats(score_list, target_min, target_max): score_list[i]['relevance'] = relevance_list[i] score_list[i]['recency'] = recency_list[i] - return score_list - - -def normalize_dict_floats(d: dict, target_min: Union[int, float], target_max: Union[int, float]) -> dict: - """ - This function normalizes the float values of a given dictionary 'd' between - a target minimum and maximum value. The normalization is done by scaling the - values to the target range while maintaining the same relative proportions - between the original values. - - INPUT: - d: Dictionary. The input dictionary whose float values need to be - normalized. - target_min: Integer or float. The minimum value to which the original - values should be scaled. - target_max: Integer or float. The maximum value to which the original - values should be scaled. - OUTPUT: - d: A new dictionary with the same keys as the input but with the float - values normalized between the target_min and target_max. - - Example input: - d = {'a':1.2,'b':3.4,'c':5.6,'d':7.8} - target_min = -5 - target_max = 5 - """ - min_val = min(val for val in d.values()) - max_val = max(val for val in d.values()) - range_val = max_val - min_val - - if range_val == 0: - for key, val in d.items(): - d[key] = (target_max - target_min) / 2 - else: - for key, val in d.items(): - d[key] = ((val - min_val) * (target_max - target_min) - / range_val + target_min) - return d - - -def new_retrieve(role, focal_points, n_count=30): - """ - Given the current role and focal points (focal points are events or - thoughts for which we are retrieving), we retrieve a set of nodes for each - of the focal points and return a dictionary. - - INPUT: - role: The current role object whose memory we are retrieving. 
- focal_points: A list of focal points (string description of the events or - thoughts that is the focus of current retrieval). - OUTPUT: - retrieved: A dictionary whose keys are a string focal point, and whose - values are a list of Node object in the agent's associative - memory. - - Example input: - role = object - focal_points = ["How are you?", "Jane is swimming in the pond"] - """ - # is the main dictionary that we are returning - retrieved = dict() - for focal_pt in focal_points: - scratch = role._rc.scratch - # Getting all nodes from the agent's memory (both thoughts and events) and - # sorting them by the datetime of creation. - # You could also imagine getting the raw conversation, but for now. - nodes = [[i.last_accessed, i] - for i in role._rc.memory.event_list + role._rc.memory.thought_list - if "idle" not in i.embedding_key] - nodes = sorted(nodes, key=lambda x: x[0]) - nodes = [i for created, i in nodes] - - # Calculating the component dictionaries and normalizing them. - recency_out = extract_recency(role, nodes) # TODO - recency_out = normalize_dict_floats(recency_out, 0, 1) - importance_out = extract_importance(role, nodes) - importance_out = normalize_dict_floats(importance_out, 0, 1) - relevance_out = extract_relevance(role, nodes, focal_pt) - relevance_out = normalize_dict_floats(relevance_out, 0, 1) - - # Computing the final scores that combines the component values. - # Note to self: test out different weights. [1, 1, 1] tends to work - # decently, but in the future, these weights should likely be learned, - # perhaps through an RL-like process. 
- # gw = [1, 1, 1] - # gw = [1, 2, 1] - gw = [0.5, 3, 2] - master_out = dict() - for key in recency_out.keys(): - master_out[key] = (scratch.recency_w * recency_out[key] * gw[0] - + scratch.relevance_w * relevance_out[key] * gw[1] - + scratch.importance_w * importance_out[key] * gw[2]) - - master_out = top_highest_x_values(master_out, len(master_out.keys())) - for key, val in master_out.items(): - print(role._rc.memory.id_to_node[key].embedding_key, val) - print(scratch.recency_w * recency_out[key] * 1, - scratch.relevance_w * relevance_out[key] * 1, - scratch.importance_w * importance_out[key] * 1) - - # Extracting the highest x values. - # has the key of node.id and value of float. Once we get the - # highest x values, we want to translate the node.id into nodes and return - # the list of nodes. - master_out = top_highest_x_values(master_out, n_count) - master_nodes = [role._rc.memory.id_to_node[key] - for key in list(master_out.keys())] - - for n in master_nodes: - n.last_accessed = scratch.curr_time - - retrieved[focal_pt] = master_nodes - - return retrieved + return score_list \ No newline at end of file diff --git a/examples/st_game/memory/scratch.py b/examples/st_game/memory/scratch.py index 19a566fa0..8c98e20ca 100644 --- a/examples/st_game/memory/scratch.py +++ b/examples/st_game/memory/scratch.py @@ -5,7 +5,7 @@ import datetime import json -from ..utils.check import check_if_file_exists +from ..utils.tools import check_if_file_exists class Scratch: @@ -510,3 +510,4 @@ class Scratch: minute = curr_min_sum % 60 ret += f"{hour:02}:{minute:02} || {row[0]}\n" return ret + \ No newline at end of file diff --git a/examples/st_game/prompts/generate_event_triple_v1.txt b/examples/st_game/prompts/generate_event_triple_v1.txt new file mode 100644 index 000000000..699ce154c --- /dev/null +++ b/examples/st_game/prompts/generate_event_triple_v1.txt @@ -0,0 +1,30 @@ +generate_event_triple_v1.txt + +Variables: +!! -- Persona's full name. +!! 
-- Current action description +!! -- Persona's full name. + +### +Task: Turn the input into (subject, predicate, object). + +Input: Sam Johnson is eating breakfast. +Output: (Sam Johnson, eat, breakfast) +--- +Input: Joon Park is brewing coffee. +Output: (Joon Park, brew, coffee) +--- +Input: Jane Cook is sleeping. +Output: (Jane Cook, is, sleep) +--- +Input: Michael Bernstein is writing email on a computer. +Output: (Michael Bernstein, write, email) +--- +Input: Percy Liang is teaching students in a classroom. +Output: (Percy Liang, teach, students) +--- +Input: Merrie Morris is running on a treadmill. +Output: (Merrie Morris, run, treadmill) +--- +Input: !! is !!. +Output: (!!, \ No newline at end of file diff --git a/examples/st_game/prompts/generate_focal_pt_v1.txt b/examples/st_game/prompts/generate_focal_pt_v1.txt new file mode 100644 index 000000000..73f76ec61 --- /dev/null +++ b/examples/st_game/prompts/generate_focal_pt_v1.txt @@ -0,0 +1,11 @@ +generate_focal_pt_v1.txt + +Variables: +!! -- Event/thought statements +!! -- Count + +### +!! + +Given only the information above, what are !! most salient high-level questions we can answer about the subjects grounded in the statements? +1) \ No newline at end of file diff --git a/examples/st_game/prompts/insight_and_evidence_v1.txt b/examples/st_game/prompts/insight_and_evidence_v1.txt new file mode 100644 index 000000000..579c81637 --- /dev/null +++ b/examples/st_game/prompts/insight_and_evidence_v1.txt @@ -0,0 +1,12 @@ +insight_and_evidence_v1.txt + +Variables: +!! -- Numbered list of event/thought statements +!! -- Number of insights to generate + +### +Input: +!! + +What !! high-level insights can you infer from the above statements? (example format: insight (because of 1, 5, 3)) +1. 
\ No newline at end of file diff --git a/examples/st_game/prompts/memo_on_convo_v1.txt b/examples/st_game/prompts/memo_on_convo_v1.txt new file mode 100644 index 000000000..38b34bfbd --- /dev/null +++ b/examples/st_game/prompts/memo_on_convo_v1.txt @@ -0,0 +1,15 @@ +memo_on_convo_v1.txt + +Variables: +!! -- All convo utterances +!! -- persona name +!! -- persona name +!! -- persona name + +### +[Conversation] +!! + +Write down if there is anything from the conversation that !! might have found interesting from !!'s perspective, in a full sentence. + +"!! \ No newline at end of file diff --git a/examples/st_game/prompts/planning_thought_on_convo_v1.txt b/examples/st_game/prompts/planning_thought_on_convo_v1.txt new file mode 100644 index 000000000..0563dcee9 --- /dev/null +++ b/examples/st_game/prompts/planning_thought_on_convo_v1.txt @@ -0,0 +1,15 @@ +planning_thought_on_convo_v1.txt + +Variables: +!! -- All convo utterances +!! -- persona name +!! -- persona name +!! -- persona name + +### +[Conversation] +!! + +Write down if there is anything from the conversation that !! needs to remember for their planning, from !!'s perspective, in a full sentence. + +"!! \ No newline at end of file diff --git a/examples/st_game/prompts/poignancy_event_v1.txt b/examples/st_game/prompts/poignancy_event_v1.txt new file mode 100644 index 000000000..34975696b --- /dev/null +++ b/examples/st_game/prompts/poignancy_event_v1.txt @@ -0,0 +1,15 @@ +poignancy_event_v1.txt + +!!: agent name +!!: iss +!!: name +!!: event description + +### +Here is a brief description of !!. +!! + +On the scale of 1 to 10, where 1 is purely mundane (e.g., brushing teeth, making bed) and 10 is extremely poignant (e.g., a break up, college acceptance), rate the likely poignancy of the following event for !!. + +Event: !! 
+Rate (return a number between 1 to 10): \ No newline at end of file diff --git a/examples/st_game/reflect/reflect.py b/examples/st_game/reflect/reflect.py new file mode 100644 index 000000000..d76b03fd1 --- /dev/null +++ b/examples/st_game/reflect/reflect.py @@ -0,0 +1,217 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : reflect function +import sys + +import datetime +import random + +from numpy import dot +from numpy.linalg import norm +from ..roles.st_role import STRole +from ..utils.utils import get_embedding +from ..actions.run_reflect_action import AgentFocusPt,AgentInsightAndGuidance,AgentEventTriple,AgentEventPoignancy,AgentChatPoignancy,AgentPlanThoughtOnConvo,AgentMemoryOnConvo + +def generate_focal_points(strole:STRole, n=3): + + nodes = [[i.last_accessed, i] + for i in strole._rc.memory.event_list + strole._rc.memory.thought_list + if "idle" not in i.embedding_key] + + nodes = sorted(nodes, key=lambda x: x[0]) + nodes = [i for created, i in nodes] + + statements = "" + for node in nodes[-1 * strole._rc.scratch.importance_ele_n:]: + statements += node.embedding_key + "\n" + run_focal_pt = AgentFocusPt() + # Question 1 + return run_focal_pt.run(strole, statements, n) + + +def generate_insights_and_evidence(strole, nodes, n=5): + + statements = "" + for count, node in enumerate(nodes): + statements += f'{str(count)}. {node.embedding_key}\n' + run_insight_and_guidance = AgentInsightAndGuidance() + ret = run_insight_and_guidance.run(strole, statements, n) + + print(ret) + try: + + for thought, evi_raw in ret.items(): + evidence_node_id = [nodes[i].node_id for i in evi_raw] + ret[thought] = evidence_node_id + return ret + except: + return {"this is blank": "node_1"} + + +def generate_action_event_triple(act_desp, strole): + """TODO + + INPUT: + act_desp: the description of the action (e.g., "sleeping") + strole: The Persona class instance + OUTPUT: + a string of emoji that translates action description. 
+ EXAMPLE OUTPUT: + "🧈🍞" + """ + run_event_triple = AgentEventTriple() + return AgentEventTriple(act_desp, strole) + + +def generate_poig_score(strole:STRole, event_type, description): + + if "is idle" in description: + return 1 + + if event_type == "event" or event_type == "thought": + run_event_poignancy = AgentEventPoignancy() + return run_event_poignancy.run(strole, description)[0] + elif event_type == "chat": + run_chat_poignancy = AgentChatPoignancy() + return run_chat_poignancy.run(strole, + strole._rc.scratch.act_description)[0] + + +def generate_planning_thought_on_convo(strole, all_utt): + run_planning_on_convo = AgentPlanThoughtOnConvo() + return run_planning_on_convo.run(strole, all_utt) + + +def generate_memo_on_convo(strole, all_utt): + run_memo_on_convo = AgentMemoryOnConvo() + return run_memo_on_convo.run(strole, all_utt) + + + +# Done +def run_reflect(strole:STRole): + """ + Run the actual reflection. We generate the focal points, retrieve any + relevant nodes, and generate thoughts and insights. + + INPUT: + strole: Current Persona object + Output: + None + """ + # Reflection requires certain focal points. Generate that first. + focal_points = generate_focal_points(strole, 3) + # Retrieve the relevant Nodes object for each of the focal points. + # has keys of focal points, and values of the associated Nodes. + retrieved = strole.retrieve(focal_points) + + # For each of the focal points, generate thoughts and save it in the + # agent's memory. 
+ for focal_pt, nodes in retrieved.items(): + xx = [i.embedding_key for i in nodes] + for xxx in xx: print(xxx) + + thoughts = generate_insights_and_evidence(strole, nodes, 5) + # 生成的是字典类型 + for thought, evidence in thoughts.items(): + created = strole.scratch.curr_time + expiration = strole.scratch.curr_time + datetime.timedelta(days=30) + s, p, o = generate_action_event_triple(thought, strole) + keywords = set([s, p, o]) + thought_poignancy = generate_poig_score(strole, "thought", thought) + thought_embedding_pair = (thought, get_embedding(thought)) + + strole._rc.memory.add_thought(created, expiration, s, p, o, + thought, keywords, thought_poignancy, + thought_embedding_pair, evidence) + +# Done +def reflection_trigger(strole: STRole): + """ + Given the current strole, determine whether the strole should run a + reflection. + + Our current implementation checks for whether the sum of the new importance + measure has reached the set (hyper-parameter) threshold. + + INPUT: + strole: Current Persona object + Output: + True if we are running a new reflection. + False otherwise. + """ + print(strole._rc.scratch.name, "strole.scratch.importance_trigger_curr::", strole._rc.scratch.importance_trigger_curr) + print(strole._rc.scratch.importance_trigger_max) + + if (strole._rc.scratch.importance_trigger_curr <= 0 and + [] != strole._rc.memory.seq_event + strole._rc.memory.seq_thought): + return True + return False + +# Done +def reset_reflection_counter(strole: STRole): + """ + We reset the counters used for the reflection trigger. + + INPUT: + strole: Current Persona object + Output: + None + """ + strole_imt_max = strole._rc.scratch.importance_trigger_max + strole._rc.scratch.importance_trigger_curr = strole_imt_max + strole._rc.scratch.importance_ele_n = 0 + +# Question 1 chat函数 +def reflect(strole: STRole): + """ + The main reflection module for the strole. 
We first check if the trigger + conditions are met, and if so, run the reflection and reset any of the + relevant counters. + + INPUT: + strole: Current Persona object + Output: + None + """ + if reflection_trigger(strole): + run_reflect(strole) + reset_reflection_counter(strole) + + if strole._rc.scratch.chatting_end_time: + if strole._rc.scratch.curr_time + datetime.timedelta(0,10) == strole._rc.scratch.chatting_end_time: + all_utt = "" + if strole._rc.scratch.chat: + for row in strole._rc.scratch.chat: + all_utt += f"{row[0]}: {row[1]}\n" + + # Question memory添加对话函数 + evidence = [strole._rc.memory.get_last_chat(strole._rc.scratch.chatting_with).memory_id] + + planning_thought = generate_planning_thought_on_convo(strole, all_utt) + planning_thought = f"For {strole._rc.scratch.name}'s planning: {planning_thought}" + + created = strole._rc.scratch.curr_time + expiration = strole._rc.scratch.curr_time + datetime.timedelta(days=30) + s, p, o = generate_action_event_triple(planning_thought, strole) + keywords = set([s, p, o]) + thought_poignancy = generate_poig_score(strole, "thought", planning_thought) + thought_embedding_pair = (planning_thought, get_embedding(planning_thought)) + + strole._rc.memory.add_thought(created, expiration, s, p, o, + planning_thought, keywords, thought_poignancy, + thought_embedding_pair, evidence) + + memo_thought = generate_memo_on_convo(strole, all_utt) + memo_thought = f"{strole._rc.scratch.name} {memo_thought}" + + created = strole._rc.scratch.curr_time + expiration = strole._rc.scratch.curr_time + datetime.timedelta(days=30) + s, p, o = generate_action_event_triple(memo_thought, strole) + keywords = set([s, p, o]) + thought_poignancy = generate_poig_score(strole, "thought", memo_thought) + thought_embedding_pair = (memo_thought, get_embedding(memo_thought)) + + strole._rc.memory.add_thought(created, expiration, s, p, o, + memo_thought, keywords, thought_poignancy, + thought_embedding_pair, evidence) \ No newline at end of file diff 
--git a/examples/st_game/reflect/st_reflect.py b/examples/st_game/reflect/st_reflect.py index 1b22ce99f..da31a22eb 100644 --- a/examples/st_game/reflect/st_reflect.py +++ b/examples/st_game/reflect/st_reflect.py @@ -4,16 +4,10 @@ import asyncio import json +import time from metagpt.logs import logger -import time -from ga_prompt_generator import final_response -''' -等待Agent和memory更新,保留相关引用但可以忽略。 -''' -from ..memory.associative_memory import MemoryBasic - -import json -import time +from ..prompts.wrapper_prompt import special_response_generate +from ..memory.agent_memory import BasicMemory async def agent_reflect(memories_list): diff --git a/examples/st_game/roles/st_role.py b/examples/st_game/roles/st_role.py index 643853063..5cfd7018d 100644 --- a/examples/st_game/roles/st_role.py +++ b/examples/st_game/roles/st_role.py @@ -23,7 +23,7 @@ from ..memory.spatial_memory import MemoryTree from ..actions.dummy_action import DummyAction from ..actions.user_requirement import UserRequirement from ..maze_environment import MazeEnvironment -from ..memory.retrieve import agent_retrieve +from ..memory.retrieve import new_agent_retrieve from ..memory.scratch import Scratch from ..utils.utils import get_embedding, generate_poig_score @@ -216,10 +216,9 @@ class STRole(Role): return ret_events - async def retrieve(self, query, n=30, topk=4): + async def retrieve(self, focus_points, n=30): # TODO retrieve memories from agent_memory - retrieve_memories = agent_retrieve(self._rc.memory, self._rc.scratch.curr_time, self._rc.scratch.recency_decay, - query, n, topk) + retrieve_memories = new_agent_retrieve(self,focus_points,n) return retrieve_memories async def plan(self): diff --git a/examples/st_game/utils/check.py b/examples/st_game/utils/check.py deleted file mode 100644 index 0a806fe2d..000000000 --- a/examples/st_game/utils/check.py +++ /dev/null @@ -1,14 +0,0 @@ -def check_if_file_exists(curr_file): - """ - Checks if a file exists - ARGS: - curr_file: path to the current csv 
file. - RETURNS: - True if the file exists - False if the file does not exist - """ - try: - with open(curr_file) as f_analysis_file: pass - return True - except: - return False \ No newline at end of file diff --git a/examples/st_game/utils/tools.py b/examples/st_game/utils/tools.py new file mode 100644 index 000000000..320c161d1 --- /dev/null +++ b/examples/st_game/utils/tools.py @@ -0,0 +1,60 @@ +import os + +def check_if_file_exists(curr_file): + """ + Checks if a file exists + ARGS: + curr_file: path to the current csv file. + RETURNS: + True if the file exists + False if the file does not exist + """ + try: + with open(curr_file) as f_analysis_file: + pass + return True + except: + return False + +def create_folder_if_not_there(curr_path): + """ + Checks if a folder in the curr_path exists. If it does not exist, creates + the folder. + Note that if the curr_path designates a file location, it will operate on + the folder that contains the file. But the function also works even if the + path designates to just a folder. + Args: + curr_list: list to write. The list comes in the following form: + [['key1', 'val1-1', 'val1-2'...], + ['key2', 'val2-1', 'val2-2'...],] + outfile: name of the csv file to write + RETURNS: + True: if a new folder is created + False: if a new folder is not created + """ + outfolder_name = curr_path.split("/") + if len(outfolder_name) != 1: + # This checks if the curr path is a file or a folder. + if "." in outfolder_name[-1]: + outfolder_name = outfolder_name[:-1] + + outfolder_name = "/".join(outfolder_name) + if not os.path.exists(outfolder_name): + os.makedirs(outfolder_name) + return True + + return False + +def find_filenames(path_to_dir, suffix=".csv"): + """ + Given a directory, find all files that end with the provided suffix and + return their paths. + ARGS: + path_to_dir: Path to the current directory + suffix: The target suffix. + RETURNS: + A list of paths to all files in the directory. 
+ """ + filenames = os.listdir(path_to_dir) + return [path_to_dir + "/" + filename + for filename in filenames if filename.endswith(suffix)]