diff --git a/examples/st_game/maze_environment.py b/examples/st_game/maze_environment.py index 434516ec6..9d393d765 100644 --- a/examples/st_game/maze_environment.py +++ b/examples/st_game/maze_environment.py @@ -10,3 +10,5 @@ class MazeEnvironment(Environment): def __init__(self, name: str, maze: Maze) -> None: self.name = name self.maze = maze + + \ No newline at end of file diff --git a/examples/st_game/memory/agent_memory.py b/examples/st_game/memory/agent_memory.py index a56100ee7..59a61155f 100644 --- a/examples/st_game/memory/agent_memory.py +++ b/examples/st_game/memory/agent_memory.py @@ -46,6 +46,9 @@ class BasicMemory(Message): self.keywords: list = keywords # keywords self.filling: list = filling # 装的与之相关联的memory_id的列表 + def summary(self): + return (self.subject, self.predicate, self.object) + def save_to_dict(self) -> dict: """ 将MemoryBasic类转化为字典,用于存储json文件 @@ -316,3 +319,10 @@ class AgentMemory(Memory): self.embeddings[embedding_pair[0]] = embedding_pair[1] return memory_node + + + def get_summarized_latest_events(self, retention): + ret_set = set() + for e_node in self.event_list[:retention]: + ret_set.add(e_node.summary()) + return ret_set \ No newline at end of file diff --git a/examples/st_game/memory/spatial_memory.py b/examples/st_game/memory/spatial_memory.py index edbe9641d..b3357b962 100644 --- a/examples/st_game/memory/spatial_memory.py +++ b/examples/st_game/memory/spatial_memory.py @@ -104,6 +104,24 @@ class MemoryTree: return x + def add_tile_info(self, tile_info: dict): + if tile_info["world"]: + if (tile_info["world"] not in self.tree): + self.tree[tile_info["world"]] = {} + if tile_info["sector"]: + if (tile_info["sector"] not in self.tree[tile_info["world"]]): + self.tree[tile_info["world"]][tile_info["sector"]] = {} + if tile_info["arena"]: + if (tile_info["arena"] not in self.tree[tile_info["world"]] + [tile_info["sector"]]): + self.tree[tile_info["world"]][tile_info["sector"]][tile_info["arena"]] = [] + if tile_info["game_object"]: + if (tile_info["game_object"] not in self.tree[tile_info["world"]] + [tile_info["sector"]] + [tile_info["arena"]]): + self.tree[tile_info["world"]][tile_info["sector"]][tile_info["arena"]] += [ + tile_info["game_object"]] + diff --git a/examples/st_game/roles/st_role.py b/examples/st_game/roles/st_role.py index 1fcab2344..c16180a72 100644 --- a/examples/st_game/roles/st_role.py +++ b/examples/st_game/roles/st_role.py @@ -10,9 +10,10 @@ Do the steps following: - reflect, do the High-level thinking based on memories and re-add into the memory - execute, move or else in the Maze """ - +import math from pydantic import Field from pathlib import Path +from operator import itemgetter from metagpt.roles.role import Role, RoleContext from metagpt.schema import Message @@ -24,6 +25,7 @@ from ..actions.user_requirement import UserRequirement from ..maze_environment import MazeEnvironment from ..memory.retrieve import agent_retrieve from ..memory.scratch import Scratch +from ..utils.utils import get_embedding, generate_poig_score class STRoleContext(RoleContext): @@ -68,14 +70,150 @@ class STRole(Role): async def observe(self): # TODO observe info from maze_env - pass + """ + Perceive events around the role and saves it to the memory, both events + and spaces. + + We first perceive the events nearby the role, as determined by its + . If there are a lot of events happening within that radius, we + take the of the closest events. Finally, we check whether + any of them are new, as determined by . If they are new, then we + save those and return the instances for those events. + + OUTPUT: + ret_events: a list of that are perceived and new. + """ + maze = self._rc.env.maze + # PERCEIVE SPACE + # We get the nearby tiles given our current tile and the persona's vision + # radius. + nearby_tiles = maze.get_nearby_tiles(self._rc.scratch.curr_tile, + self._rc.scratch.vision_r) + + # We then store the perceived space. Note that the s_mem of the persona is + # in the form of a tree constructed using dictionaries. + for tile in nearby_tiles: + tile_info = maze.access_tile(tile) + self._rc.spatial_memory.add_tile_info(tile_info) + + # PERCEIVE EVENTS. + # We will perceive events that take place in the same arena as the + # persona's current arena. + curr_arena_path = maze.get_tile_path(self._rc.scratch.curr_tile, "arena") + # We do not perceive the same event twice (this can happen if an object is + # extended across multiple tiles). + percept_events_set = set() + # We will order our percept based on the distance, with the closest ones + # getting priorities. + percept_events_list = [] + # First, we put all events that are occuring in the nearby tiles into the + # percept_events_list + for tile in nearby_tiles: + tile_details = maze.access_tile(tile) + if tile_details["events"]: + if maze.get_tile_path(tile, "arena") == curr_arena_path: + # This calculates the distance between the persona's current tile, + # and the target tile. + dist = math.dist([tile[0], tile[1]], + [self._rc.scratch.curr_tile[0], + self._rc.scratch.curr_tile[1]]) + # Add any relevant events to our temp set/list with the distant info. + for event in tile_details["events"]: + if event not in percept_events_set: + percept_events_list += [[dist, event]] + percept_events_set.add(event) + + # We sort, and perceive only self._rc.scratch.att_bandwidth of the closest + # events. If the bandwidth is larger, then it means the persona can perceive + # more elements within a small area. + percept_events_list = sorted(percept_events_list, key=itemgetter(0)) + perceived_events = [] + for dist, event in percept_events_list[:self._rc.scratch.att_bandwidth]: + perceived_events += [event] + + # Storing events. + # is a list of instances from the persona's + # associative memory. + ret_events = [] + for p_event in perceived_events: + s, p, o, desc = p_event + if not p: + # If the object is not present, then we default the event to "idle". + p = "is" + o = "idle" + desc = "idle" + desc = f"{s.split(':')[-1]} is {desc}" + p_event = (s, p, o) + + # We retrieve the latest self._rc.scratch.retention events. If there is + # something new that is happening (that is, p_event not in latest_events), + # then we add that event to the a_mem and return it. + latest_events = self._rc.memory.get_summarized_latest_events( + self._rc.scratch.retention) + if p_event not in latest_events: + # We start by managing keywords. + keywords = set() + sub = p_event[0] + obj = p_event[2] + if ":" in p_event[0]: + sub = p_event[0].split(":")[-1] + if ":" in p_event[2]: + obj = p_event[2].split(":")[-1] + keywords.update([sub, obj]) + + # Get event embedding + desc_embedding_in = desc + if "(" in desc: + desc_embedding_in = (desc_embedding_in.split("(")[1] + .split(")")[0] + .strip()) + if desc_embedding_in in self._rc.memory.embeddings: + event_embedding = self._rc.memory.embeddings[desc_embedding_in] + else: + event_embedding = get_embedding(desc_embedding_in) + event_embedding_pair = (desc_embedding_in, event_embedding) + + # Get event poignancy. + event_poignancy = generate_poig_score(self, + "action", + desc_embedding_in) + + # If we observe the persona's self chat, we include that in the memory + # of the persona here. + chat_node_ids = [] + if p_event[0] == f"{self.name}" and p_event[1] == "chat with": + curr_event = self._rc.scratch.act_event + if self._rc.scratch.act_description in self._rc.memory.embeddings: + chat_embedding = self._rc.memory.embeddings[ + self._rc.scratch.act_description] + else: + chat_embedding = get_embedding(self._rc.scratch + .act_description) + chat_embedding_pair = (self._rc.scratch.act_description, + chat_embedding) + chat_poignancy = generate_poig_score(self._rc.scratch, "chat", + self._rc.scratch.act_description) + chat_node = self._rc.memory.add_chat(self._rc.scratch.curr_time, None, + curr_event[0], curr_event[1], curr_event[2], + self._rc.scratch.act_description, keywords, + chat_poignancy, chat_embedding_pair, + self._rc.scratch.chat) + chat_node_ids = [chat_node.node_id] + + # Finally, we add the current event to the agent's memory. + ret_events += [self._rc.memory.add_event(self._rc.scratch.curr_time, None, + s, p, o, desc, keywords, event_poignancy, + event_embedding_pair, chat_node_ids)] + self._rc.scratch.importance_trigger_curr -= event_poignancy + self._rc.scratch.importance_ele_n += 1 + + return ret_events async def retrieve(self, query, n = 30 ,topk = 4): # TODO retrieve memories from agent_memory retrieve_memories = agent_retrieve(self._rc.memory, self._rc.scratch.curr_time, self._rc.scratch.recency_decay, query, n, topk) return retrieve_memories - async def plan(self): # TODO make a plan diff --git a/examples/st_game/utils/utils.py b/examples/st_game/utils/utils.py index fb94d8c5c..e6b29a667 100644 --- a/examples/st_game/utils/utils.py +++ b/examples/st_game/utils/utils.py @@ -7,6 +7,7 @@ import json import openai from pathlib import Path import csv +from ..prompts.run_gpt_prompts import get_poignancy_action, get_poignancy_chat def read_json_file(json_file: str, encoding=None) -> list[Any]: @@ -25,6 +26,7 @@ def write_json_file(json_file: str, data: list, encoding=None): with open(json_file, "w", encoding=encoding) as fout: json.dump(data, fout, ensure_ascii=False, indent=4) + def embedding_tools(query): embedding_result = openai.Embedding.create( model="text-embedding-ada-002", @@ -33,6 +35,7 @@ def embedding_tools(query): embedding_key = embedding_result['data'][0]["embedding"] return embedding_key + def read_csv_to_list(curr_file: str, header=False, strip_trail=True): """ Reads in a csv file to a list of list. If header is True, it returns a @@ -61,3 +64,20 @@ def read_csv_to_list(curr_file: str, header=False, strip_trail=True): analysis_list += [row] return analysis_list[0], analysis_list[1:] + +def get_embedding(text, model: str="text-embedding-ada-002"): + text = text.replace("\n", " ") + if not text: + text = "this is blank" + return openai.Embedding.create( + input=[text], model=model)['data'][0]['embedding'] + + +def generate_poig_score(scratch, event_type, description): + if "is idle" in description: + return 1 + if event_type == "action": + return get_poignancy_action(scratch, description)[0] + elif event_type == "chat": + return get_poignancy_chat(scratch, description)[0] +