diff --git a/examples/st_game/maze.py b/examples/st_game/maze.py index 1e2ef8ccc..98edbff7c 100644 --- a/examples/st_game/maze.py +++ b/examples/st_game/maze.py @@ -7,7 +7,7 @@ Author: Joon Sung Park (joonspk@stanford.edu) File: maze.py Description: Defines the Maze class, which represents the map of the simulated -world in a 2-dimensional matrix. +world in a 2-dimensional matrix. """ import json @@ -17,400 +17,396 @@ import networkx as nx from .utils.const import MAZE_ASSET_PATH from .utils.utils import read_csv_to_list -class Maze: - def __init__(self, maze_asset_path: Path = MAZE_ASSET_PATH) -> None: - # READING IN THE BASIC META INFORMATION ABOUT THE MAP - self.maze_asset_path = maze_asset_path - maze_matrix_path = maze_asset_path.joinpath("matrix") - # Reading in the meta information about the world. If you want tp see the - # example variables, check out the maze_meta_info.json file. - meta_info = json.load(open(maze_matrix_path.joinpath("maze_meta_info.json"))) - # and denote the number of tiles make up the - # height and width of the map. - self.maze_width = int(meta_info["maze_width"]) - self.maze_height = int(meta_info["maze_height"]) - # denotes the pixel height/width of a tile. - self.sq_tile_size = int(meta_info["sq_tile_size"]) - # is a string description of any relevant special - # constraints the world might have. - # e.g., "planning to stay at home all day and never go out of her home" - self.special_constraint = meta_info["special_constraint"] - # READING IN SPECIAL BLOCKS - # Special blocks are those that are colored in the Tiled map. +class Maze: + def __init__(self, maze_asset_path: Path = MAZE_ASSET_PATH) -> None: + # READING IN THE BASIC META INFORMATION ABOUT THE MAP + self.maze_asset_path = maze_asset_path + maze_matrix_path = maze_asset_path.joinpath("matrix") + # Reading in the meta information about the world. If you want tp see the + # example variables, check out the maze_meta_info.json file. + meta_info = json.load(open(maze_matrix_path.joinpath("maze_meta_info.json"))) + # and denote the number of tiles make up the + # height and width of the map. + self.maze_width = int(meta_info["maze_width"]) + self.maze_height = int(meta_info["maze_height"]) + # denotes the pixel height/width of a tile. + self.sq_tile_size = int(meta_info["sq_tile_size"]) + # is a string description of any relevant special + # constraints the world might have. + # e.g., "planning to stay at home all day and never go out of her home" + self.special_constraint = meta_info["special_constraint"] - # Here is an example row for the arena block file: - # e.g., "25335, Double Studio, Studio, Common Room" - # And here is another example row for the game object block file: - # e.g, "25331, Double Studio, Studio, Bedroom 2, Painting" + # READING IN SPECIAL BLOCKS + # Special blocks are those that are colored in the Tiled map. - # Notice that the first element here is the color marker digit from the - # Tiled export. Then we basically have the block path: - # World, Sector, Arena, Game Object -- again, these paths need to be - # unique within an instance of Reverie. - blocks_folder = maze_matrix_path.joinpath("special_blocks") + # Here is an example row for the arena block file: + # e.g., "25335, Double Studio, Studio, Common Room" + # And here is another example row for the game object block file: + # e.g, "25331, Double Studio, Studio, Bedroom 2, Painting" - _wb = blocks_folder.joinpath("world_blocks.csv") - wb_rows = read_csv_to_list(_wb, header=False) - wb = wb_rows[0][-1] - - _sb = blocks_folder.joinpath("sector_blocks.csv") - sb_rows = read_csv_to_list(_sb, header=False) - sb_dict = dict() - for i in sb_rows: sb_dict[i[0]] = i[-1] - - _ab = blocks_folder.joinpath("arena_blocks.csv") - ab_rows = read_csv_to_list(_ab, header=False) - ab_dict = dict() - for i in ab_rows: ab_dict[i[0]] = i[-1] - - _gob = blocks_folder.joinpath("game_object_blocks.csv") - gob_rows = read_csv_to_list(_gob, header=False) - gob_dict = dict() - for i in gob_rows: gob_dict[i[0]] = i[-1] - - _slb = blocks_folder.joinpath("spawning_location_blocks.csv") - slb_rows = read_csv_to_list(_slb, header=False) - slb_dict = dict() - for i in slb_rows: slb_dict[i[0]] = i[-1] + # Notice that the first element here is the color marker digit from the + # Tiled export. Then we basically have the block path: + # World, Sector, Arena, Game Object -- again, these paths need to be + # unique within an instance of Reverie. + blocks_folder = maze_matrix_path.joinpath("special_blocks") - # [SECTION 3] Reading in the matrices - # This is your typical two dimensional matrices. It's made up of 0s and - # the number that represents the color block from the blocks folder. - maze_folder = maze_matrix_path.joinpath("maze") + _wb = blocks_folder.joinpath("world_blocks.csv") + wb_rows = read_csv_to_list(_wb, header=False) + wb = wb_rows[0][-1] - _cm = maze_folder.joinpath("collision_maze.csv") - collision_maze_raw = read_csv_to_list(_cm, header=False)[0] - _sm = maze_folder.joinpath("sector_maze.csv") - sector_maze_raw = read_csv_to_list(_sm, header=False)[0] - _am = maze_folder.joinpath("arena_maze.csv") - arena_maze_raw = read_csv_to_list(_am, header=False)[0] - _gom = maze_folder.joinpath("game_object_maze.csv") - game_object_maze_raw = read_csv_to_list(_gom, header=False)[0] - _slm = maze_folder.joinpath("spawning_location_maze.csv") - spawning_location_maze_raw = read_csv_to_list(_slm, header=False)[0] + _sb = blocks_folder.joinpath("sector_blocks.csv") + sb_rows = read_csv_to_list(_sb, header=False) + sb_dict = dict() + for i in sb_rows: + sb_dict[i[0]] = i[-1] - # Loading the maze. The mazes are taken directly from the json exports of - # Tiled maps. They should be in csv format. - # Importantly, they are "not" in a 2-d matrix format -- they are single - # row matrices with the length of width x height of the maze. So we need - # to convert here. - # We can do this all at once since the dimension of all these matrices are - # identical (e.g., 70 x 40). - # example format: [['0', '0', ... '25309', '0',...], ['0',...]...] - # 25309 is the collision bar number right now. - self.collision_maze = [] - sector_maze = [] - arena_maze = [] - game_object_maze = [] - spawning_location_maze = [] - for i in range(0, len(collision_maze_raw), meta_info["maze_width"]): - tw = meta_info["maze_width"] - self.collision_maze += [collision_maze_raw[i:i+tw]] - sector_maze += [sector_maze_raw[i:i+tw]] - arena_maze += [arena_maze_raw[i:i+tw]] - game_object_maze += [game_object_maze_raw[i:i+tw]] - spawning_location_maze += [spawning_location_maze_raw[i:i+tw]] + _ab = blocks_folder.joinpath("arena_blocks.csv") + ab_rows = read_csv_to_list(_ab, header=False) + ab_dict = dict() + for i in ab_rows: + ab_dict[i[0]] = i[-1] - # Once we are done loading in the maze, we now set up self.tiles. This is - # a matrix accessed by row:col where each access point is a dictionary - # that contains all the things that are taking place in that tile. - # More specifically, it contains information about its "world," "sector," - # "arena," "game_object," "spawning_location," as well as whether it is a - # collision block, and a set of all events taking place in it. - # e.g., self.tiles[32][59] = {'world': 'double studio', - # 'sector': '', 'arena': '', 'game_object': '', - # 'spawning_location': '', 'collision': False, 'events': set()} - # e.g., self.tiles[9][58] = {'world': 'double studio', - # 'sector': 'double studio', 'arena': 'bedroom 2', - # 'game_object': 'bed', 'spawning_location': 'bedroom-2-a', - # 'collision': False, - # 'events': {('double studio:double studio:bedroom 2:bed', - # None, None)}} - self.tiles = [] - for i in range(self.maze_height): - row = [] - for j in range(self.maze_width): - tile_details = dict() - tile_details["world"] = wb - - tile_details["sector"] = "" - if sector_maze[i][j] in sb_dict: - tile_details["sector"] = sb_dict[sector_maze[i][j]] - - tile_details["arena"] = "" - if arena_maze[i][j] in ab_dict: - tile_details["arena"] = ab_dict[arena_maze[i][j]] - - tile_details["game_object"] = "" - if game_object_maze[i][j] in gob_dict: - tile_details["game_object"] = gob_dict[game_object_maze[i][j]] - - tile_details["spawning_location"] = "" - if spawning_location_maze[i][j] in slb_dict: - tile_details["spawning_location"] = slb_dict[spawning_location_maze[i][j]] - - tile_details["collision"] = False - if self.collision_maze[i][j] != "0": - tile_details["collision"] = True + _gob = blocks_folder.joinpath("game_object_blocks.csv") + gob_rows = read_csv_to_list(_gob, header=False) + gob_dict = dict() + for i in gob_rows: + gob_dict[i[0]] = i[-1] - tile_details["events"] = set() - - row += [tile_details] - self.tiles += [row] - # Each game object occupies an event in the tile. We are setting up the - # default event value here. - for i in range(self.maze_height): - for j in range(self.maze_width): - if self.tiles[i][j]["game_object"]: - object_name = ":".join([self.tiles[i][j]["world"], - self.tiles[i][j]["sector"], - self.tiles[i][j]["arena"], - self.tiles[i][j]["game_object"]]) - go_event = (object_name, None, None, None) - self.tiles[i][j]["events"].add(go_event) + _slb = blocks_folder.joinpath("spawning_location_blocks.csv") + slb_rows = read_csv_to_list(_slb, header=False) + slb_dict = dict() + for i in slb_rows: + slb_dict[i[0]] = i[-1] - # Reverse tile access. - # -- given a string address, we return a set of all - # tile coordinates belonging to that address (this is opposite of - # self.tiles that give you the string address given a coordinate). This is - # an optimization component for finding paths for the personas' movement. - # self.address_tiles['bedroom-2-a'] == {(58, 9)} - # self.address_tiles['double studio:recreation:pool table'] - # == {(29, 14), (31, 11), (30, 14), (32, 11), ...}, - self.address_tiles = dict() - for i in range(self.maze_height): - for j in range(self.maze_width): - addresses = [] - if self.tiles[i][j]["sector"]: - add = f'{self.tiles[i][j]["world"]}:' - add += f'{self.tiles[i][j]["sector"]}' - addresses += [add] - if self.tiles[i][j]["arena"]: - add = f'{self.tiles[i][j]["world"]}:' - add += f'{self.tiles[i][j]["sector"]}:' - add += f'{self.tiles[i][j]["arena"]}' - addresses += [add] - if self.tiles[i][j]["game_object"]: - add = f'{self.tiles[i][j]["world"]}:' - add += f'{self.tiles[i][j]["sector"]}:' - add += f'{self.tiles[i][j]["arena"]}:' - add += f'{self.tiles[i][j]["game_object"]}' - addresses += [add] - if self.tiles[i][j]["spawning_location"]: - add = f'{self.tiles[i][j]["spawning_location"]}' - addresses += [add] + # [SECTION 3] Reading in the matrices + # This is your typical two dimensional matrices. It's made up of 0s and + # the number that represents the color block from the blocks folder. + maze_folder = maze_matrix_path.joinpath("maze") - for add in addresses: - if add in self.address_tiles: - self.address_tiles[add].add((j, i)) - else: - self.address_tiles[add] = set([(j, i)]) + _cm = maze_folder.joinpath("collision_maze.csv") + collision_maze_raw = read_csv_to_list(_cm, header=False)[0] + _sm = maze_folder.joinpath("sector_maze.csv") + sector_maze_raw = read_csv_to_list(_sm, header=False)[0] + _am = maze_folder.joinpath("arena_maze.csv") + arena_maze_raw = read_csv_to_list(_am, header=False)[0] + _gom = maze_folder.joinpath("game_object_maze.csv") + game_object_maze_raw = read_csv_to_list(_gom, header=False)[0] + _slm = maze_folder.joinpath("spawning_location_maze.csv") + spawning_location_maze_raw = read_csv_to_list(_slm, header=False)[0] - # Build an nx.Graph. - grid_graph = nx.grid_2d_graph(m=self.maze_width, n=self.maze_height) - for i in range(self.maze_height): - for j in range(self.maze_width): - if self.collision_maze[i][j]!=0: - grid_graph.remove_node((i,j)) - self.nx_graph = grid_graph + # Loading the maze. The mazes are taken directly from the json exports of + # Tiled maps. They should be in csv format. + # Importantly, they are "not" in a 2-d matrix format -- they are single + # row matrices with the length of width x height of the maze. So we need + # to convert here. + # We can do this all at once since the dimension of all these matrices are + # identical (e.g., 70 x 40). + # example format: [['0', '0', ... '25309', '0',...], ['0',...]...] + # 25309 is the collision bar number right now. + self.collision_maze = [] + sector_maze = [] + arena_maze = [] + game_object_maze = [] + spawning_location_maze = [] + for i in range(0, len(collision_maze_raw), meta_info["maze_width"]): + tw = meta_info["maze_width"] + self.collision_maze += [collision_maze_raw[i:i + tw]] + sector_maze += [sector_maze_raw[i:i + tw]] + arena_maze += [arena_maze_raw[i:i + tw]] + game_object_maze += [game_object_maze_raw[i:i + tw]] + spawning_location_maze += [spawning_location_maze_raw[i:i + tw]] + # Once we are done loading in the maze, we now set up self.tiles. This is + # a matrix accessed by row:col where each access point is a dictionary + # that contains all the things that are taking place in that tile. + # More specifically, it contains information about its "world," "sector," + # "arena," "game_object," "spawning_location," as well as whether it is a + # collision block, and a set of all events taking place in it. + # e.g., self.tiles[32][59] = {'world': 'double studio', + # 'sector': '', 'arena': '', 'game_object': '', + # 'spawning_location': '', 'collision': False, 'events': set()} + # e.g., self.tiles[9][58] = {'world': 'double studio', + # 'sector': 'double studio', 'arena': 'bedroom 2', + # 'game_object': 'bed', 'spawning_location': 'bedroom-2-a', + # 'collision': False, + # 'events': {('double studio:double studio:bedroom 2:bed', + # None, None)}} + self.tiles = [] + for i in range(self.maze_height): + row = [] + for j in range(self.maze_width): + tile_details = dict() + tile_details["world"] = wb - def turn_coordinate_to_tile(self, px_coordinate: tuple[int, int]) -> tuple[int, int]: - """ - Turns a pixel coordinate to a tile coordinate. + tile_details["sector"] = "" + if sector_maze[i][j] in sb_dict: + tile_details["sector"] = sb_dict[sector_maze[i][j]] - INPUT - px_coordinate: The pixel coordinate of our interest. Comes in the x, y - format. - OUTPUT - tile coordinate (x, y): The tile coordinate that corresponds to the - pixel coordinate. - EXAMPLE OUTPUT - Given (1600, 384), outputs (50, 12) - """ - x = math.ceil(px_coordinate[0]/self.sq_tile_size) - y = math.ceil(px_coordinate[1]/self.sq_tile_size) - return (x, y) + tile_details["arena"] = "" + if arena_maze[i][j] in ab_dict: + tile_details["arena"] = ab_dict[arena_maze[i][j]] + tile_details["game_object"] = "" + if game_object_maze[i][j] in gob_dict: + tile_details["game_object"] = gob_dict[game_object_maze[i][j]] - def access_tile(self, tile: tuple[int, int]) -> dict: - """ - Returns the tiles details dictionary that is stored in self.tiles of the - designated x, y location. + tile_details["spawning_location"] = "" + if spawning_location_maze[i][j] in slb_dict: + tile_details["spawning_location"] = slb_dict[spawning_location_maze[i][j]] - INPUT - tile: The tile coordinate of our interest in (x, y) form. - OUTPUT - The tile detail dictionary for the designated tile. - EXAMPLE OUTPUT - Given (58, 9), - self.tiles[9][58] = {'world': 'double studio', - 'sector': 'double studio', 'arena': 'bedroom 2', - 'game_object': 'bed', 'spawning_location': 'bedroom-2-a', - 'collision': False, - 'events': {('double studio:double studio:bedroom 2:bed', - None, None)}} - """ - x = tile[0] - y = tile[1] - return self.tiles[y][x] + tile_details["collision"] = False + if self.collision_maze[i][j] != "0": + tile_details["collision"] = True + tile_details["events"] = set() - def get_tile_path(self, tile: tuple[int, int], level: str) -> str: - """ - Get the tile string address given its coordinate. You designate the level - by giving it a string level description. + row += [tile_details] + self.tiles += [row] + # Each game object occupies an event in the tile. We are setting up the + # default event value here. + for i in range(self.maze_height): + for j in range(self.maze_width): + if self.tiles[i][j]["game_object"]: + object_name = ":".join([self.tiles[i][j]["world"], + self.tiles[i][j]["sector"], + self.tiles[i][j]["arena"], + self.tiles[i][j]["game_object"]]) + go_event = (object_name, None, None, None) + self.tiles[i][j]["events"].add(go_event) - INPUT: - tile: The tile coordinate of our interest in (x, y) form. - level: world, sector, arena, or game object - OUTPUT - The string address for the tile. - EXAMPLE OUTPUT - Given tile=(58, 9), and level=arena, - "double studio:double studio:bedroom 2" - """ - x = tile[0] - y = tile[1] - tile = self.tiles[y][x] + # Reverse tile access. + # -- given a string address, we return a set of all + # tile coordinates belonging to that address (this is opposite of + # self.tiles that give you the string address given a coordinate). This is + # an optimization component for finding paths for the personas' movement. + # self.address_tiles['bedroom-2-a'] == {(58, 9)} + # self.address_tiles['double studio:recreation:pool table'] + # == {(29, 14), (31, 11), (30, 14), (32, 11), ...}, + self.address_tiles = dict() + for i in range(self.maze_height): + for j in range(self.maze_width): + addresses = [] + if self.tiles[i][j]["sector"]: + add = f'{self.tiles[i][j]["world"]}:' + add += f'{self.tiles[i][j]["sector"]}' + addresses += [add] + if self.tiles[i][j]["arena"]: + add = f'{self.tiles[i][j]["world"]}:' + add += f'{self.tiles[i][j]["sector"]}:' + add += f'{self.tiles[i][j]["arena"]}' + addresses += [add] + if self.tiles[i][j]["game_object"]: + add = f'{self.tiles[i][j]["world"]}:' + add += f'{self.tiles[i][j]["sector"]}:' + add += f'{self.tiles[i][j]["arena"]}:' + add += f'{self.tiles[i][j]["game_object"]}' + addresses += [add] + if self.tiles[i][j]["spawning_location"]: + add = f'{self.tiles[i][j]["spawning_location"]}' + addresses += [add] - path = f"{tile['world']}" - if level == "world": - return path - else: - path += f":{tile['sector']}" - - if level == "sector": - return path - else: - path += f":{tile['arena']}" + for add in addresses: + if add in self.address_tiles: + self.address_tiles[add].add((j, i)) + else: + self.address_tiles[add] = set([(j, i)]) - if level == "arena": - return path - else: - path += f":{tile['game_object']}" + # Build an nx.Graph. + grid_graph = nx.grid_2d_graph(m=self.maze_width, n=self.maze_height) + for i in range(self.maze_height): + for j in range(self.maze_width): + if self.collision_maze[i][j] != 0: + grid_graph.remove_node((i, j)) + self.nx_graph = grid_graph - return path + def turn_coordinate_to_tile(self, px_coordinate: tuple[int, int]) -> tuple[int, int]: + """ + Turns a pixel coordinate to a tile coordinate. + INPUT + px_coordinate: The pixel coordinate of our interest. Comes in the x, y + format. + OUTPUT + tile coordinate (x, y): The tile coordinate that corresponds to the + pixel coordinate. + EXAMPLE OUTPUT + Given (1600, 384), outputs (50, 12) + """ + x = math.ceil(px_coordinate[0] / self.sq_tile_size) + y = math.ceil(px_coordinate[1] / self.sq_tile_size) + return (x, y) - def get_nearby_tiles(self, tile: tuple[int, int], vision_r: int) -> list[tuple[int, int]]: - """ - Given the current tile and vision_r, return a list of tiles that are - within the radius. Note that this implementation looks at a square - boundary when determining what is within the radius. - i.e., for vision_r, returns x's. - x x x x x - x x x x x - x x P x x - x x x x x - x x x x x + def access_tile(self, tile: tuple[int, int]) -> dict: + """ + Returns the tiles details dictionary that is stored in self.tiles of the + designated x, y location. - INPUT: - tile: The tile coordinate of our interest in (x, y) form. - vision_r: The radius of the persona's vision. - OUTPUT: - nearby_tiles: a list of tiles that are within the radius. - """ - left_end = 0 - if tile[0] - vision_r > left_end: - left_end = tile[0] - vision_r + INPUT + tile: The tile coordinate of our interest in (x, y) form. + OUTPUT + The tile detail dictionary for the designated tile. + EXAMPLE OUTPUT + Given (58, 9), + self.tiles[9][58] = {'world': 'double studio', + 'sector': 'double studio', 'arena': 'bedroom 2', + 'game_object': 'bed', 'spawning_location': 'bedroom-2-a', + 'collision': False, + 'events': {('double studio:double studio:bedroom 2:bed', + None, None)}} + """ + x = tile[0] + y = tile[1] + return self.tiles[y][x] - right_end = self.maze_width - 1 - if tile[0] + vision_r + 1 < right_end: - right_end = tile[0] + vision_r + 1 + def get_tile_path(self, tile: tuple[int, int], level: str) -> str: + """ + Get the tile string address given its coordinate. You designate the level + by giving it a string level description. - bottom_end = self.maze_height - 1 - if tile[1] + vision_r + 1 < bottom_end: - bottom_end = tile[1] + vision_r + 1 + INPUT: + tile: The tile coordinate of our interest in (x, y) form. + level: world, sector, arena, or game object + OUTPUT + The string address for the tile. + EXAMPLE OUTPUT + Given tile=(58, 9), and level=arena, + "double studio:double studio:bedroom 2" + """ + x = tile[0] + y = tile[1] + tile = self.tiles[y][x] - top_end = 0 - if tile[1] - vision_r > top_end: - top_end = tile[1] - vision_r + path = f"{tile['world']}" + if level == "world": + return path + else: + path += f":{tile['sector']}" - nearby_tiles = [] - for i in range(left_end, right_end): - for j in range(top_end, bottom_end): - nearby_tiles += [(i, j)] - return nearby_tiles + if level == "sector": + return path + else: + path += f":{tile['arena']}" + if level == "arena": + return path + else: + path += f":{tile['game_object']}" - def add_event_from_tile(self, curr_event: tuple[str], tile: tuple[int, int]) -> None: - """ - Add an event triple to a tile. + return path - INPUT: - curr_event: Current event triple. - e.g., ('double studio:double studio:bedroom 2:bed', None, - None) - tile: The tile coordinate of our interest in (x, y) form. - OUPUT: - None - """ - self.tiles[tile[1]][tile[0]]["events"].add(curr_event) + def get_nearby_tiles(self, tile: tuple[int, int], vision_r: int) -> list[tuple[int, int]]: + """ + Given the current tile and vision_r, return a list of tiles that are + within the radius. Note that this implementation looks at a square + boundary when determining what is within the radius. + i.e., for vision_r, returns x's. + x x x x x + x x x x x + x x P x x + x x x x x + x x x x x + INPUT: + tile: The tile coordinate of our interest in (x, y) form. + vision_r: The radius of the persona's vision. + OUTPUT: + nearby_tiles: a list of tiles that are within the radius. + """ + left_end = 0 + if tile[0] - vision_r > left_end: + left_end = tile[0] - vision_r - def remove_event_from_tile(self, curr_event: tuple[str], tile: tuple[int, int]) -> None: - """dswaq - Remove an event triple from a tile. + right_end = self.maze_width - 1 + if tile[0] + vision_r + 1 < right_end: + right_end = tile[0] + vision_r + 1 - INPUT: - curr_event: Current event triple. - e.g., ('double studio:double studio:bedroom 2:bed', None, - None) - tile: The tile coordinate of our interest in (x, y) form. - OUPUT: - None - """ - curr_tile_ev_cp = self.tiles[tile[1]][tile[0]]["events"].copy() - for event in curr_tile_ev_cp: - if event == curr_event: - self.tiles[tile[1]][tile[0]]["events"].remove(event) + bottom_end = self.maze_height - 1 + if tile[1] + vision_r + 1 < bottom_end: + bottom_end = tile[1] + vision_r + 1 + top_end = 0 + if tile[1] - vision_r > top_end: + top_end = tile[1] - vision_r - def turn_event_from_tile_idle(self, curr_event: tuple[str], tile: tuple[int, int]) -> None: - curr_tile_ev_cp = self.tiles[tile[1]][tile[0]]["events"].copy() - for event in curr_tile_ev_cp: - if event == curr_event: - self.tiles[tile[1]][tile[0]]["events"].remove(event) - new_event = (event[0], None, None, None) - self.tiles[tile[1]][tile[0]]["events"].add(new_event) + nearby_tiles = [] + for i in range(left_end, right_end): + for j in range(top_end, bottom_end): + nearby_tiles += [(i, j)] + return nearby_tiles + def add_event_from_tile(self, curr_event: tuple[str], tile: tuple[int, int]) -> None: + """ + Add an event triple to a tile. - def remove_subject_events_from_tile(self, subject: str, tile: tuple[int, int]) -> None: - """ - Remove an event triple that has the input subject from a tile. + INPUT: + curr_event: Current event triple. + e.g., ('double studio:double studio:bedroom 2:bed', None, + None) + tile: The tile coordinate of our interest in (x, y) form. + OUPUT: + None + """ + self.tiles[tile[1]][tile[0]]["events"].add(curr_event) - INPUT: - subject: "Isabella Rodriguez" - tile: The tile coordinate of our interest in (x, y) form. - OUPUT: - None - """ - curr_tile_ev_cp = self.tiles[tile[1]][tile[0]]["events"].copy() - for event in curr_tile_ev_cp: - if event[0] == subject: - self.tiles[tile[1]][tile[0]]["events"].remove(event) + def remove_event_from_tile(self, curr_event: tuple[str], tile: tuple[int, int]) -> None: + """dswaq + Remove an event triple from a tile. + INPUT: + curr_event: Current event triple. + e.g., ('double studio:double studio:bedroom 2:bed', None, + None) + tile: The tile coordinate of our interest in (x, y) form. + OUPUT: + None + """ + curr_tile_ev_cp = self.tiles[tile[1]][tile[0]]["events"].copy() + for event in curr_tile_ev_cp: + if event == curr_event: + self.tiles[tile[1]][tile[0]]["events"].remove(event) - def _find_closest_node(self, coords: tuple[int, int]) -> tuple[int, int]: - target_coords = self.nx_graph.nodes - min_dist = None - closest_coordinate = None - for target in target_coords: - dist = math.dist(coords, target) - if not closest_coordinate: - min_dist = dist - closest_coordinate = target - else: - if min_dist > dist: - min_dist = dist - closest_coordinate = target - return closest_coordinate + def turn_event_from_tile_idle(self, curr_event: tuple[str], tile: tuple[int, int]) -> None: + curr_tile_ev_cp = self.tiles[tile[1]][tile[0]]["events"].copy() + for event in curr_tile_ev_cp: + if event == curr_event: + self.tiles[tile[1]][tile[0]]["events"].remove(event) + new_event = (event[0], None, None, None) + self.tiles[tile[1]][tile[0]]["events"].add(new_event) - def find_path(self, start: tuple[int, int], end: tuple[int, int]) -> list[tuple[int, int]]: - if start not in self.nx_graph.nodes: - start = self._find_closest_node(start) - if end not in self.nx_graph.nodes: - end = self._find_closest_node(end) - return nx.shortest_path(self.nx_graph, start, end) + def remove_subject_events_from_tile(self, subject: str, tile: tuple[int, int]) -> None: + """ + Remove an event triple that has the input subject from a tile. + + INPUT: + subject: "Isabella Rodriguez" + tile: The tile coordinate of our interest in (x, y) form. + OUPUT: + None + """ + curr_tile_ev_cp = self.tiles[tile[1]][tile[0]]["events"].copy() + for event in curr_tile_ev_cp: + if event[0] == subject: + self.tiles[tile[1]][tile[0]]["events"].remove(event) + + def _find_closest_node(self, coords: tuple[int, int]) -> tuple[int, int]: + target_coords = self.nx_graph.nodes + min_dist = None + closest_coordinate = None + for target in target_coords: + dist = math.dist(coords, target) + if not closest_coordinate: + min_dist = dist + closest_coordinate = target + else: + if min_dist > dist: + min_dist = dist + closest_coordinate = target + return closest_coordinate + + def find_path(self, start: tuple[int, int], end: tuple[int, int]) -> list[tuple[int, int]]: + if start not in self.nx_graph.nodes: + start = self._find_closest_node(start) + if end not in self.nx_graph.nodes: + end = self._find_closest_node(end) + return nx.shortest_path(self.nx_graph, start, end) diff --git a/examples/st_game/memory/agent_memory.py b/examples/st_game/memory/agent_memory.py index 59a61155f..ff93965dc 100644 --- a/examples/st_game/memory/agent_memory.py +++ b/examples/st_game/memory/agent_memory.py @@ -14,39 +14,40 @@ class BasicMemory(Message): created: datetime, expiration: datetime, subject: str, predicate: str, object: str, content: str, embedding_key: str, poignancy: int, keywords: list, filling: list, - cause_by = ""): + cause_by=""): """ BasicMemory继承于MG的Message类,其中content属性替代description属性 Message类中对于Chat类型支持的非常好,对于Agent个体的Perceive,Reflection,Plan支持的并不多 在Type设计上,我们延续GA的三个种类,但是对于Chat种类的对话进行特别设计(具体怎么设计还没想好) """ - super().__init__(content,cause_by=cause_by) + super().__init__(content, cause_by=cause_by) """ 从父类中继承的属性 content: str # 记忆描述 cause_by: Type["Action"] = field(default="") # 触发动作,只在Type为chat时初始化 - cause_by 接受一个Action类,在此项目中,每个Agent需要有一个基础动作[Receive] 用于接受假对话Message;而每个Agent需要有独一无二的动作类,用以接受真对话Message + cause_by 接受一个Action类,在此项目中,每个Agent需要有一个基础动作[Receive] 用于接受假对话Message; + 而每个Agent需要有独一无二的动作类,用以接受真对话Message """ - self.memory_id: str = memory_id # 记忆ID - self.memory_count: int = memory_count # 第几个记忆,实际数值与Memory相等 - self.type_count: int = type_count # 第几种记忆,类型为整数(具体不太理解如何生成的) - self.memory_type: str = memory_type # 记忆类型,包含 event,thought,chat三种类型 - self.depth: str = depth # 记忆深度,类型为整数 + self.memory_id: str = memory_id # 记忆ID + self.memory_count: int = memory_count # 第几个记忆,实际数值与Memory相等 + self.type_count: int = type_count # 第几种记忆,类型为整数(具体不太理解如何生成的) + self.memory_type: str = memory_type # 记忆类型,包含 event,thought,chat三种类型 + self.depth: str = depth # 记忆深度,类型为整数 - self.created: datetime = created # 创建时间 - self.expiration: datetime = expiration # 记忆失效时间,默认为空() - self.last_accessed: datetime = created # 上一次调用的时间,初始化时候与self.created一致 + self.created: datetime = created # 创建时间 + self.expiration: datetime = expiration # 记忆失效时间,默认为空() + self.last_accessed: datetime = created # 上一次调用的时间,初始化时候与self.created一致 - self.subject: str = subject # 主语 - self.predicate: str = predicate # 谓语 - self.object: str = object # 宾语 + self.subject: str = subject # 主语 + self.predicate: str = predicate # 谓语 + self.object: str = object # 宾语 - self.embedding_key: str = embedding_key # 内容与self.content一致 - self.poignancy: int = poignancy # importance值 - self.keywords: list = keywords # keywords - self.filling: list = filling # 装的与之相关联的memory_id的列表 + self.embedding_key: str = embedding_key # 内容与self.content一致 + self.poignancy: int = poignancy # importance值 + self.keywords: list = keywords # keywords + self.filling: list = filling # 装的与之相关联的memory_id的列表 - def summary(self): + def summary(self): return (self.subject, self.predicate, self.object) def save_to_dict(self) -> dict: @@ -65,9 +66,9 @@ class BasicMemory(Message): memory_dict[node_id]["cmemory_dicteated"] = self.created.strftime('%Y-%m-%d %H:%M:%S') memory_dict[node_id]["expiration"] = None - if self.expiration: + if self.expiration: memory_dict[node_id]["expiration"] = (self.expiration - .strftime('%Y-%m-%d %H:%M:%S')) + .strftime('%Y-%m-%d %H:%M:%S')) memory_dict[node_id]["subject"] = self.subject memory_dict[node_id]["predicate"] = self.predicate @@ -83,6 +84,7 @@ class BasicMemory(Message): return memory_dict + class AgentMemory(Memory): """ GA中主要存储三种JSON @@ -90,6 +92,7 @@ class AgentMemory(Memory): 2. Node.json (Dict Node_id:Node) 3. kw_strength.json """ + def __init__(self, memory_saved: str): """ AgentMemory类继承自Memory类,重写storage替代GA中id_to_node,一方面存储所有信息,一方面作为JSON转化 @@ -97,21 +100,20 @@ class AgentMemory(Memory): @李嵩@张凯 这里的storage是List,你们需要写一个JSON转化器,将List修改为node.json一致的格式 """ super.__init__() - self.storage: list[BasicMemory] = [] # 重写Stroage,存储BasicMemory所有节点 - self.event_list = [] # 存储event记忆 - self.thought_list = [] # 存储thought记忆 + self.storage: list[BasicMemory] = [] # 重写Stroage,存储BasicMemory所有节点 + self.event_list = [] # 存储event记忆 + self.thought_list = [] # 存储thought记忆 - self.event_keywords = dict() # 存储keywords - self.thought_keywords = dict() + self.event_keywords = dict() # 存储keywords + self.thought_keywords = dict() self.chat_keywords = dict() - self.kw_strength_event = dict() # 关键词影响存储 - self.kw_strength_thought = dict() + self.kw_strength_event = dict() # 关键词影响存储 + self.kw_strength_thought = dict() self.load(memory_saved) - - def save(self,memory_saved:str): + def save(self, memory_saved: str): """ 将MemormyBasic类存储为Nodes.json形式。复现GA中的Kw Strength.json形式 这里添加一个路径即可 @@ -121,36 +123,35 @@ class AgentMemory(Memory): for i in range(len(self.storage)): memory_node = self.storage[i] memory_json.update(memory_node) - with open(memory_saved+"/nodes.json", "w") as outfile: + with open(memory_saved + "/nodes.json", "w") as outfile: json.dump(memory_json, outfile) - with open(memory_saved+"/embeddings.json", "w") as outfile: + with open(memory_saved + "/embeddings.json", "w") as outfile: json.dump(self.embeddings, outfile) strength_json = dict() strength_json["kw_strength_event"] = self.kw_strength_event strength_json["kw_strength_thought"] = self.kw_strength_thought - with open(memory_saved+"/kw_strength.json", "w") as outfile: + with open(memory_saved + "/kw_strength.json", "w") as outfile: json.dump(strength_json, outfile) - - def load(self,memory_saved:str): + def load(self, memory_saved: str): """ 将GA的JSON解析,填充到AgentMemory类之中 """ self.embeddings = json.load(open(memory_saved + "/embeddings.json")) memory_load = json.load(open(memory_saved + "/nodes.json")) for count in range(len(memory_load.keys())): - node_id = f"node_{str(count+1)}" + node_id = f"node_{str(count + 1)}" node_details = memory_load[node_id] node_type = node_details["type"] - created = datetime.datetime.strptime(node_details["created"], - '%Y-%m-%d %H:%M:%S') + created = datetime.datetime.strptime(node_details["created"], + '%Y-%m-%d %H:%M:%S') expiration = None - if node_details["expiration"]: + if node_details["expiration"]: expiration = datetime.datetime.strptime(node_details["expiration"], '%Y-%m-%d %H:%M:%S') - + if node_details["cause_by"]: cause_by = node_details["cause_by"] @@ -159,29 +160,28 @@ class AgentMemory(Memory): o = node_details["object"] description = node_details["description"] - embedding_pair = (node_details["embedding_key"], - self.embeddings[node_details["embedding_key"]]) - poignancy =node_details["poignancy"] + embedding_pair = (node_details["embedding_key"], + self.embeddings[node_details["embedding_key"]]) + poignancy = node_details["poignancy"] keywords = set(node_details["keywords"]) filling = node_details["filling"] - - if node_type == "event": - self.add_event(created, expiration, s, p, o, - description, keywords, poignancy, embedding_pair, filling) - elif node_type == "chat": - self.add_chat(created, expiration, s, p, o, - description, keywords, poignancy, embedding_pair, filling,cause_by) - elif node_type == "thought": - self.add_thought(created, expiration, s, p, o, - description, keywords, poignancy, embedding_pair, filling) + + if node_type == "event": + self.add_event(created, expiration, s, p, o, + description, keywords, poignancy, embedding_pair, filling) + elif node_type == "chat": + self.add_chat(created, expiration, s, p, o, + description, keywords, poignancy, embedding_pair, filling, cause_by) + elif node_type == "thought": + self.add_thought(created, expiration, s, p, o, + description, keywords, poignancy, embedding_pair, filling) strength_keywords_load = json.load(open(memory_saved + "/kw_strength.json")) - if strength_keywords_load["kw_strength_event"]: + if strength_keywords_load["kw_strength_event"]: self.kw_strength_event = strength_keywords_load["kw_strength_event"] - if strength_keywords_load["kw_strength_thought"]: + if strength_keywords_load["kw_strength_thought"]: self.kw_strength_thought = strength_keywords_load["kw_strength_thought"] - def add(self, memory_basic: BasicMemory): """ Add a new message to storage, while updating the index @@ -192,18 +192,17 @@ class AgentMemory(Memory): self.storage.append(memory_basic) if memory_basic.cause_by: self.index[memory_basic.cause_by][0:0] = [memory_basic] - return + return if memory_basic.type == "thought": self.thought_list[0:0] = [memory_basic] return if memory_basic.type == "event": self.event_list[0:0] = [memory_basic] - - def add_chat(self, created, expiration, s, p, o, - content, keywords, poignancy, - embedding_pair, filling, - cause_by): + def add_chat(self, created, expiration, s, p, o, + content, keywords, poignancy, + embedding_pair, filling, + cause_by): """ 调用add方法,初始化chat,在创建的时候就需要调用embeeding函数 """ @@ -211,31 +210,30 @@ class AgentMemory(Memory): type_count = len(self.thought_list) + 1 memory_type = "chat" memory_id = f"memory_{str(memory_count)}" - depth = 1 + depth = 1 memory_node = BasicMemory(memory_id, memory_count, type_count, memory_type, depth, created, expiration, - s, p ,o, + s, p, o, content, embedding_pair[0], - poignancy, keywords, filling, + poignancy, keywords, filling, cause_by) keywords = [i.lower() for i in keywords] - for kw in keywords: - if kw in self.chat_keywords: + for kw in keywords: + if kw in self.chat_keywords: self.chat_keywords[kw][0:0] = [memory_node] - else: + else: self.chat_keywords[kw] = [memory_node] - + self.add(memory_node) self.embeddings[embedding_pair[0]] = embedding_pair[1] - return memory_node + return memory_node - - def add_thought(self, created, expiration, s, p, o, - content, keywords, poignancy, - embedding_pair, filling): + def add_thought(self, created, expiration, s, p, o, + content, keywords, poignancy, + embedding_pair, filling): """ 调用add方法,初始化thought """ @@ -243,44 +241,43 @@ class AgentMemory(Memory): type_count = len(self.thought_list) + 1 memory_type = "event" memory_id = f"memory_{str(memory_count)}" - depth = 1 - + depth = 1 + try: - if filling: - depth_list = [memory_node.depth for memory_node in self.storage if memory_node.memory_id in filling ] + if filling: + depth_list = [memory_node.depth for memory_node in self.storage if memory_node.memory_id in filling] depth += max(depth_list) - except: - pass + except Exception as exp: + pass memory_node = BasicMemory(memory_id, memory_count, type_count, memory_type, depth, created, expiration, - s, p ,o, + s, p, o, content, embedding_pair[0], poignancy, keywords, filling) keywords = [i.lower() for i in keywords] - for kw in keywords: - if kw in self.thought_keywords: + for kw in keywords: + if kw in self.thought_keywords: self.thought_keywords[kw][0:0] = [memory_node] - else: + else: self.thought_keywords[kw] = [memory_node] - + self.add(memory_node) - if f"{p} {o}" != "is idle": - for kw in keywords: + if f"{p} {o}" != "is idle": + for kw in keywords: if kw in self.kw_strength_thought: self.kw_strength_thought[kw] += 1 - else: + else: self.kw_strength_thought[kw] = 1 self.embeddings[embedding_pair[0]] = embedding_pair[1] return memory_node - - def add_event(self, created, expiration, s, p, o, - content, keywords, poignancy, - embedding_pair, filling): + def add_event(self, created, expiration, s, p, o, + content, keywords, poignancy, + embedding_pair, filling): """ 调用add方法,初始化event """ @@ -289,40 +286,39 @@ class AgentMemory(Memory): memory_type = "event" memory_id = f"memory_{str(memory_count)}" depth = 0 - + if "(" in content: - content = (" ".join(content.split()[:3]) - + " " - + content.split("(")[-1][:-1]) - + content = (" ".join(content.split()[:3]) + + " " + + content.split("(")[-1][:-1]) + memory_node = BasicMemory(memory_id, memory_count, type_count, memory_type, depth, created, expiration, - s, p ,o, + s, p, o, content, embedding_pair[0], poignancy, keywords, filling) keywords = [i.lower() for i in keywords] - for kw in keywords: - if kw in self.event_keywords: + for kw in keywords: + if kw in self.event_keywords: self.event_keywords[kw][0:0] = [memory_node] - else: + else: self.event_keywords[kw] = [memory_node] - + self.add(memory_node) - if f"{p} {o}" != "is idle": - for kw in keywords: + if f"{p} {o}" != "is idle": + for kw in keywords: if kw in self.kw_strength_event: self.kw_strength_event[kw] += 1 - else: + else: self.kw_strength_event[kw] = 1 self.embeddings[embedding_pair[0]] = embedding_pair[1] return memory_node - - def get_summarized_latest_events(self, retention): + def get_summarized_latest_events(self, retention): ret_set = set() - for e_node in self.event_list[:retention]: + for e_node in self.event_list[:retention]: ret_set.add(e_node.summary()) - return ret_set \ No newline at end of file + return ret_set diff --git a/examples/st_game/memory/retrieve.py b/examples/st_game/memory/retrieve.py index 97eb3b6f0..d35418b77 100644 --- a/examples/st_game/memory/retrieve.py +++ b/examples/st_game/memory/retrieve.py @@ -9,7 +9,8 @@ from examples.st_game.memory.agent_memory import AgentMemory, BasicMemory from utils.utils import embedding_tools -def agent_retrieve(agent_memory: AgentMemory, curr_time: datetime.datetime, memory_forget: float, query: str, n: int = 30, topk: int = 4) -> list[BasicMemory]: +def agent_retrieve(agent_memory: AgentMemory, curr_time: datetime.datetime, memory_forget: float, query: str, + n: int = 30, topk: int = 4) -> list[BasicMemory]: """ Retrieve需要集合Role使用,原因在于Role才具有AgentMemory,scratch 逻辑:Role调用该函数,self._rc.AgentMemory,self._rc.scratch.curr_time,self._rc.scratch.memory_forget @@ -87,7 +88,7 @@ def extract_recency(curr_time, memory_forget, score_list): """ for i in range(len(score_list)): day_count = (curr_time - score_list[i]['memory'].created).days - score_list[i]['recency'] = memory_forget**day_count + score_list[i]['recency'] = memory_forget ** day_count return score_list diff --git a/examples/st_game/memory/scratch.py b/examples/st_game/memory/scratch.py index d0d13002e..19a566fa0 100644 --- a/examples/st_game/memory/scratch.py +++ b/examples/st_game/memory/scratch.py @@ -4,534 +4,509 @@ import datetime import json -import sys -sys.path.append('../../') from ..utils.check import check_if_file_exists -class Scratch: - def __init__(self, f_saved): - # 类别1:人物超参 - self.vision_r = 4 - self.att_bandwidth = 3 - self.retention = 5 - # 类别2:世界信息 - self.curr_time = None - self.curr_tile = None - self.daily_plan_req = None - - # 类别3:人物角色的核心身份 - self.name = None - self.first_name = None - self.last_name = None - self.age = None - # L0 permanent core traits. - self.innate = None - # L1 stable traits. - self.learned = None - # L2 external implementation. - self.currently = None - self.lifestyle = None - self.living_area = None +class Scratch: + def __init__(self, f_saved): + # 类别1:人物超参 + self.vision_r = 4 + self.att_bandwidth = 3 + self.retention = 5 - # 类别4:旧反思变量 - self.concept_forget = 100 - self.daily_reflection_time = 60 * 3 - self.daily_reflection_size = 5 - self.overlap_reflect_th = 2 - self.kw_strg_event_reflect_th = 4 - self.kw_strg_thought_reflect_th = 4 - - # 类别5:新反思变量 - self.recency_w = 1 - self.relevance_w = 1 - self.importance_w = 1 - self.recency_decay = 0.99 - self.importance_trigger_max = 150 - self.importance_trigger_curr = self.importance_trigger_max - self.importance_ele_n = 0 - self.thought_count = 5 - - # 类别6:个人计划 - self.daily_req = [] - self.f_daily_schedule = [] - self.f_daily_schedule_hourly_org = [] - - # 类别7:当前动作 - self.act_address = None - self.act_start_time = None - self.act_duration = None - self.act_description = None - self.act_pronunciatio = None - self.act_event = (self.name, None, None) - - self.act_obj_description = None - self.act_obj_pronunciatio = None - self.act_obj_event = (self.name, None, None) - - self.chatting_with = None - self.chat = None - self.chatting_with_buffer = dict() - self.chatting_end_time = None - - self.act_path_set = False - self.planned_path = [] - - if check_if_file_exists(f_saved): - # If we have a bootstrap file, load that here. - scratch_load = json.load(open(f_saved)) - - self.vision_r = scratch_load["vision_r"] - self.att_bandwidth = scratch_load["att_bandwidth"] - self.retention = scratch_load["retention"] - - if scratch_load["curr_time"]: - self.curr_time = datetime.datetime.strptime(scratch_load["curr_time"], - "%B %d, %Y, %H:%M:%S") - else: + # 类别2:世界信息 self.curr_time = None - self.curr_tile = scratch_load["curr_tile"] - self.daily_plan_req = scratch_load["daily_plan_req"] + self.curr_tile = None + self.daily_plan_req = None - self.name = scratch_load["name"] - self.first_name = scratch_load["first_name"] - self.last_name = scratch_load["last_name"] - self.age = scratch_load["age"] - self.innate = scratch_load["innate"] - self.learned = scratch_load["learned"] - self.currently = scratch_load["currently"] - self.lifestyle = scratch_load["lifestyle"] - self.living_area = scratch_load["living_area"] + # 类别3:人物角色的核心身份 + self.name = None + self.first_name = None + self.last_name = None + self.age = None + # L0 permanent core traits. + self.innate = None + # L1 stable traits. + self.learned = None + # L2 external implementation. + self.currently = None + self.lifestyle = None + self.living_area = None - self.concept_forget = scratch_load["concept_forget"] - self.daily_reflection_time = scratch_load["daily_reflection_time"] - self.daily_reflection_size = scratch_load["daily_reflection_size"] - self.overlap_reflect_th = scratch_load["overlap_reflect_th"] - self.kw_strg_event_reflect_th = scratch_load["kw_strg_event_reflect_th"] - self.kw_strg_thought_reflect_th = scratch_load["kw_strg_thought_reflect_th"] + # 类别4:旧反思变量 + self.concept_forget = 100 + self.daily_reflection_time = 60 * 3 + self.daily_reflection_size = 5 + self.overlap_reflect_th = 2 + self.kw_strg_event_reflect_th = 4 + self.kw_strg_thought_reflect_th = 4 - self.recency_w = scratch_load["recency_w"] - self.relevance_w = scratch_load["relevance_w"] - self.importance_w = scratch_load["importance_w"] - self.recency_decay = scratch_load["recency_decay"] - self.importance_trigger_max = scratch_load["importance_trigger_max"] - self.importance_trigger_curr = scratch_load["importance_trigger_curr"] - self.importance_ele_n = scratch_load["importance_ele_n"] - self.thought_count = scratch_load["thought_count"] + # 类别5:新反思变量 + self.recency_w = 1 + self.relevance_w = 1 + self.importance_w = 1 + self.recency_decay = 0.99 + self.importance_trigger_max = 150 + self.importance_trigger_curr = self.importance_trigger_max + self.importance_ele_n = 0 + self.thought_count = 5 - self.daily_req = scratch_load["daily_req"] - self.f_daily_schedule = scratch_load["f_daily_schedule"] - self.f_daily_schedule_hourly_org = scratch_load["f_daily_schedule_hourly_org"] + # 类别6:个人计划 + self.daily_req = [] + self.f_daily_schedule = [] + self.f_daily_schedule_hourly_org = [] - self.act_address = scratch_load["act_address"] - if scratch_load["act_start_time"]: - self.act_start_time = datetime.datetime.strptime( - scratch_load["act_start_time"], - "%B %d, %Y, %H:%M:%S") - else: - self.curr_time = None - self.act_duration = scratch_load["act_duration"] - self.act_description = scratch_load["act_description"] - self.act_pronunciatio = scratch_load["act_pronunciatio"] - self.act_event = tuple(scratch_load["act_event"]) + # 类别7:当前动作 + self.act_address = None + self.act_start_time = None + self.act_duration = None + self.act_description = None + self.act_pronunciatio = None + self.act_event = (self.name, None, None) - self.act_obj_description = scratch_load["act_obj_description"] - self.act_obj_pronunciatio = scratch_load["act_obj_pronunciatio"] - self.act_obj_event = tuple(scratch_load["act_obj_event"]) + self.act_obj_description = None + self.act_obj_pronunciatio = None + self.act_obj_event = (self.name, None, None) - self.chatting_with = scratch_load["chatting_with"] - self.chat = scratch_load["chat"] - self.chatting_with_buffer = scratch_load["chatting_with_buffer"] - if scratch_load["chatting_end_time"]: - self.chatting_end_time = datetime.datetime.strptime( - scratch_load["chatting_end_time"], - "%B %d, %Y, %H:%M:%S") - else: + self.chatting_with = None + self.chat = None + self.chatting_with_buffer = dict() self.chatting_end_time = None - self.act_path_set = scratch_load["act_path_set"] - self.planned_path = scratch_load["planned_path"] + self.act_path_set = False + self.planned_path = [] + if check_if_file_exists(f_saved): + # If we have a bootstrap file, load that here. + scratch_load = json.load(open(f_saved)) - def save(self, out_json): - """ - Save persona's scratch. + self.vision_r = scratch_load["vision_r"] + self.att_bandwidth = scratch_load["att_bandwidth"] + self.retention = scratch_load["retention"] - INPUT: - out_json: The file where we wil be saving our persona's state. - OUTPUT: - None - """ - scratch = dict() - scratch["vision_r"] = self.vision_r - scratch["att_bandwidth"] = self.att_bandwidth - scratch["retention"] = self.retention + if scratch_load["curr_time"]: + self.curr_time = datetime.datetime.strptime(scratch_load["curr_time"], + "%B %d, %Y, %H:%M:%S") + else: + self.curr_time = None + self.curr_tile = scratch_load["curr_tile"] + self.daily_plan_req = scratch_load["daily_plan_req"] - scratch["curr_time"] = self.curr_time.strftime("%B %d, %Y, %H:%M:%S") - scratch["curr_tile"] = self.curr_tile - scratch["daily_plan_req"] = self.daily_plan_req + self.name = scratch_load["name"] + self.first_name = scratch_load["first_name"] + self.last_name = scratch_load["last_name"] + self.age = scratch_load["age"] + self.innate = scratch_load["innate"] + self.learned = scratch_load["learned"] + self.currently = scratch_load["currently"] + self.lifestyle = scratch_load["lifestyle"] + self.living_area = scratch_load["living_area"] - scratch["name"] = self.name - scratch["first_name"] = self.first_name - scratch["last_name"] = self.last_name - scratch["age"] = self.age - scratch["innate"] = self.innate - scratch["learned"] = self.learned - scratch["currently"] = self.currently - scratch["lifestyle"] = self.lifestyle - scratch["living_area"] = self.living_area + self.concept_forget = scratch_load["concept_forget"] + self.daily_reflection_time = scratch_load["daily_reflection_time"] + self.daily_reflection_size = scratch_load["daily_reflection_size"] + self.overlap_reflect_th = scratch_load["overlap_reflect_th"] + self.kw_strg_event_reflect_th = scratch_load["kw_strg_event_reflect_th"] + self.kw_strg_thought_reflect_th = scratch_load["kw_strg_thought_reflect_th"] - scratch["concept_forget"] = self.concept_forget - scratch["daily_reflection_time"] = self.daily_reflection_time - scratch["daily_reflection_size"] = self.daily_reflection_size - scratch["overlap_reflect_th"] = self.overlap_reflect_th - scratch["kw_strg_event_reflect_th"] = self.kw_strg_event_reflect_th - scratch["kw_strg_thought_reflect_th"] = self.kw_strg_thought_reflect_th + self.recency_w = scratch_load["recency_w"] + self.relevance_w = scratch_load["relevance_w"] + self.importance_w = scratch_load["importance_w"] + self.recency_decay = scratch_load["recency_decay"] + self.importance_trigger_max = scratch_load["importance_trigger_max"] + self.importance_trigger_curr = scratch_load["importance_trigger_curr"] + self.importance_ele_n = scratch_load["importance_ele_n"] + self.thought_count = scratch_load["thought_count"] - scratch["recency_w"] = self.recency_w - scratch["relevance_w"] = self.relevance_w - scratch["importance_w"] = self.importance_w - scratch["recency_decay"] = self.recency_decay - scratch["importance_trigger_max"] = self.importance_trigger_max - scratch["importance_trigger_curr"] = self.importance_trigger_curr - scratch["importance_ele_n"] = self.importance_ele_n - scratch["thought_count"] = self.thought_count + self.daily_req = scratch_load["daily_req"] + self.f_daily_schedule = scratch_load["f_daily_schedule"] + self.f_daily_schedule_hourly_org = scratch_load["f_daily_schedule_hourly_org"] - scratch["daily_req"] = self.daily_req - scratch["f_daily_schedule"] = self.f_daily_schedule - scratch["f_daily_schedule_hourly_org"] = self.f_daily_schedule_hourly_org + self.act_address = scratch_load["act_address"] + if scratch_load["act_start_time"]: + self.act_start_time = datetime.datetime.strptime( + scratch_load["act_start_time"], + "%B %d, %Y, %H:%M:%S") + else: + self.curr_time = None + self.act_duration = scratch_load["act_duration"] + self.act_description = scratch_load["act_description"] + self.act_pronunciatio = scratch_load["act_pronunciatio"] + self.act_event = tuple(scratch_load["act_event"]) - scratch["act_address"] = self.act_address - scratch["act_start_time"] = (self.act_start_time + self.act_obj_description = scratch_load["act_obj_description"] + self.act_obj_pronunciatio = scratch_load["act_obj_pronunciatio"] + self.act_obj_event = tuple(scratch_load["act_obj_event"]) + + self.chatting_with = scratch_load["chatting_with"] + self.chat = scratch_load["chat"] + self.chatting_with_buffer = scratch_load["chatting_with_buffer"] + if scratch_load["chatting_end_time"]: + self.chatting_end_time = datetime.datetime.strptime( + scratch_load["chatting_end_time"], + "%B %d, %Y, %H:%M:%S") + else: + self.chatting_end_time = None + + self.act_path_set = scratch_load["act_path_set"] + self.planned_path = scratch_load["planned_path"] + + def save(self, out_json): + """ + Save persona's scratch. + + INPUT: + out_json: The file where we wil be saving our persona's state. + OUTPUT: + None + """ + scratch = dict() + scratch["vision_r"] = self.vision_r + scratch["att_bandwidth"] = self.att_bandwidth + scratch["retention"] = self.retention + + scratch["curr_time"] = self.curr_time.strftime("%B %d, %Y, %H:%M:%S") + scratch["curr_tile"] = self.curr_tile + scratch["daily_plan_req"] = self.daily_plan_req + + scratch["name"] = self.name + scratch["first_name"] = self.first_name + scratch["last_name"] = self.last_name + scratch["age"] = self.age + scratch["innate"] = self.innate + scratch["learned"] = self.learned + scratch["currently"] = self.currently + scratch["lifestyle"] = self.lifestyle + scratch["living_area"] = self.living_area + + scratch["concept_forget"] = self.concept_forget + scratch["daily_reflection_time"] = self.daily_reflection_time + scratch["daily_reflection_size"] = self.daily_reflection_size + scratch["overlap_reflect_th"] = self.overlap_reflect_th + scratch["kw_strg_event_reflect_th"] = self.kw_strg_event_reflect_th + scratch["kw_strg_thought_reflect_th"] = self.kw_strg_thought_reflect_th + + scratch["recency_w"] = self.recency_w + scratch["relevance_w"] = self.relevance_w + scratch["importance_w"] = self.importance_w + scratch["recency_decay"] = self.recency_decay + scratch["importance_trigger_max"] = self.importance_trigger_max + scratch["importance_trigger_curr"] = self.importance_trigger_curr + scratch["importance_ele_n"] = self.importance_ele_n + scratch["thought_count"] = self.thought_count + + scratch["daily_req"] = self.daily_req + scratch["f_daily_schedule"] = self.f_daily_schedule + scratch["f_daily_schedule_hourly_org"] = self.f_daily_schedule_hourly_org + + scratch["act_address"] = self.act_address + scratch["act_start_time"] = (self.act_start_time .strftime("%B %d, %Y, %H:%M:%S")) - scratch["act_duration"] = self.act_duration - scratch["act_description"] = self.act_description - scratch["act_pronunciatio"] = self.act_pronunciatio - scratch["act_event"] = self.act_event + scratch["act_duration"] = self.act_duration + scratch["act_description"] = self.act_description + scratch["act_pronunciatio"] = self.act_pronunciatio + scratch["act_event"] = self.act_event - scratch["act_obj_description"] = self.act_obj_description - scratch["act_obj_pronunciatio"] = self.act_obj_pronunciatio - scratch["act_obj_event"] = self.act_obj_event + scratch["act_obj_description"] = self.act_obj_description + scratch["act_obj_pronunciatio"] = self.act_obj_pronunciatio + scratch["act_obj_event"] = self.act_obj_event - scratch["chatting_with"] = self.chatting_with - scratch["chat"] = self.chat - scratch["chatting_with_buffer"] = self.chatting_with_buffer - if self.chatting_end_time: - scratch["chatting_end_time"] = (self.chatting_end_time - .strftime("%B %d, %Y, %H:%M:%S")) - else: - scratch["chatting_end_time"] = None + scratch["chatting_with"] = self.chatting_with + scratch["chat"] = self.chat + scratch["chatting_with_buffer"] = self.chatting_with_buffer + if self.chatting_end_time: + scratch["chatting_end_time"] = (self.chatting_end_time + .strftime("%B %d, %Y, %H:%M:%S")) + else: + scratch["chatting_end_time"] = None - scratch["act_path_set"] = self.act_path_set - scratch["planned_path"] = self.planned_path + scratch["act_path_set"] = self.act_path_set + scratch["planned_path"] = self.planned_path - with open(out_json, "w") as outfile: - json.dump(scratch, outfile, indent=2) + with open(out_json, "w") as outfile: + json.dump(scratch, outfile, indent=2) + def get_f_daily_schedule_index(self, advance=0): + """ + We get the current index of self.f_daily_schedule. - def get_f_daily_schedule_index(self, advance=0): - """ - We get the current index of self.f_daily_schedule. + Recall that self.f_daily_schedule stores the decomposed action sequences + up until now, and the hourly sequences of the future action for the rest + of today. Given that self.f_daily_schedule is a list of list where the + inner list is composed of [task, duration], we continue to add up the + duration until we reach "if elapsed > today_min_elapsed" condition. The + index where we stop is the index we will return. - Recall that self.f_daily_schedule stores the decomposed action sequences - up until now, and the hourly sequences of the future action for the rest - of today. Given that self.f_daily_schedule is a list of list where the - inner list is composed of [task, duration], we continue to add up the - duration until we reach "if elapsed > today_min_elapsed" condition. The - index where we stop is the index we will return. + INPUT + advance: Integer value of the number minutes we want to look into the + future. This allows us to get the index of a future timeframe. + OUTPUT + an integer value for the current index of f_daily_schedule. + """ + # We first calculate teh number of minutes elapsed today. + today_min_elapsed = 0 + today_min_elapsed += self.curr_time.hour * 60 + today_min_elapsed += self.curr_time.minute + today_min_elapsed += advance - INPUT - advance: Integer value of the number minutes we want to look into the - future. This allows us to get the index of a future timeframe. - OUTPUT - an integer value for the current index of f_daily_schedule. - """ - # We first calculate teh number of minutes elapsed today. - today_min_elapsed = 0 - today_min_elapsed += self.curr_time.hour * 60 - today_min_elapsed += self.curr_time.minute - today_min_elapsed += advance + x = 0 + for task, duration in self.f_daily_schedule: + x += duration + x = 0 + for task, duration in self.f_daily_schedule_hourly_org: + x += duration - x = 0 - for task, duration in self.f_daily_schedule: - x += duration - x = 0 - for task, duration in self.f_daily_schedule_hourly_org: - x += duration + # We then calculate the current index based on that. + curr_index = 0 + elapsed = 0 + for task, duration in self.f_daily_schedule: + elapsed += duration + if elapsed > today_min_elapsed: + return curr_index + curr_index += 1 - # We then calculate the current index based on that. - curr_index = 0 - elapsed = 0 - for task, duration in self.f_daily_schedule: - elapsed += duration - if elapsed > today_min_elapsed: return curr_index - curr_index += 1 - return curr_index + def get_f_daily_schedule_hourly_org_index(self, advance=0): + """ + We get the current index of self.f_daily_schedule_hourly_org. + It is otherwise the same as get_f_daily_schedule_index. - - def get_f_daily_schedule_hourly_org_index(self, advance=0): - """ - We get the current index of self.f_daily_schedule_hourly_org. - It is otherwise the same as get_f_daily_schedule_index. - - INPUT - advance: Integer value of the number minutes we want to look into the - future. This allows us to get the index of a future timeframe. - OUTPUT - an integer value for the current index of f_daily_schedule. - """ - # We first calculate teh number of minutes elapsed today. - today_min_elapsed = 0 - today_min_elapsed += self.curr_time.hour * 60 - today_min_elapsed += self.curr_time.minute - today_min_elapsed += advance - # We then calculate the current index based on that. - curr_index = 0 - elapsed = 0 - for task, duration in self.f_daily_schedule_hourly_org: - elapsed += duration - if elapsed > today_min_elapsed: + INPUT + advance: Integer value of the number minutes we want to look into the + future. This allows us to get the index of a future timeframe. + OUTPUT + an integer value for the current index of f_daily_schedule. + """ + # We first calculate teh number of minutes elapsed today. + today_min_elapsed = 0 + today_min_elapsed += self.curr_time.hour * 60 + today_min_elapsed += self.curr_time.minute + today_min_elapsed += advance + # We then calculate the current index based on that. + curr_index = 0 + elapsed = 0 + for task, duration in self.f_daily_schedule_hourly_org: + elapsed += duration + if elapsed > today_min_elapsed: + return curr_index + curr_index += 1 return curr_index - curr_index += 1 - return curr_index + def get_str_iss(self): + """ + ISS stands for "identity stable set." This describes the commonset summary + of this persona -- basically, the bare minimum description of the persona + that gets used in almost all prompts that need to call on the persona. - def get_str_iss(self): - """ - ISS stands for "identity stable set." This describes the commonset summary - of this persona -- basically, the bare minimum description of the persona - that gets used in almost all prompts that need to call on the persona. + INPUT + None + OUTPUT + the identity stable set summary of the persona in a string form. + EXAMPLE STR OUTPUT + "Name: Dolores Heitmiller + Age: 28 + Innate traits: hard-edged, independent, loyal + Learned traits: Dolores is a painter who wants live quietly and paint + while enjoying her everyday life. + Currently: Dolores is preparing for her first solo show. She mostly + works from home. + Lifestyle: Dolores goes to bed around 11pm, sleeps for 7 hours, eats + dinner around 6pm. + Daily plan requirement: Dolores is planning to stay at home all day and + never go out." + """ + commonset = "" + commonset += f"Name: {self.name}\n" + commonset += f"Age: {self.age}\n" + commonset += f"Innate traits: {self.innate}\n" + commonset += f"Learned traits: {self.learned}\n" + commonset += f"Currently: {self.currently}\n" + commonset += f"Lifestyle: {self.lifestyle}\n" + commonset += f"Daily plan requirement: {self.daily_plan_req}\n" + commonset += f"Current Date: {self.curr_time.strftime('%A %B %d')}\n" + return commonset - INPUT - None - OUTPUT - the identity stable set summary of the persona in a string form. - EXAMPLE STR OUTPUT - "Name: Dolores Heitmiller - Age: 28 - Innate traits: hard-edged, independent, loyal - Learned traits: Dolores is a painter who wants live quietly and paint - while enjoying her everyday life. - Currently: Dolores is preparing for her first solo show. She mostly - works from home. - Lifestyle: Dolores goes to bed around 11pm, sleeps for 7 hours, eats - dinner around 6pm. - Daily plan requirement: Dolores is planning to stay at home all day and - never go out." - """ - commonset = "" - commonset += f"Name: {self.name}\n" - commonset += f"Age: {self.age}\n" - commonset += f"Innate traits: {self.innate}\n" - commonset += f"Learned traits: {self.learned}\n" - commonset += f"Currently: {self.currently}\n" - commonset += f"Lifestyle: {self.lifestyle}\n" - commonset += f"Daily plan requirement: {self.daily_plan_req}\n" - commonset += f"Current Date: {self.curr_time.strftime('%A %B %d')}\n" - return commonset + def get_str_name(self): + return self.name + def get_str_firstname(self): + return self.first_name - def get_str_name(self): - return self.name + def get_str_lastname(self): + return self.last_name + def get_str_age(self): + return str(self.age) - def get_str_firstname(self): - return self.first_name + def get_str_innate(self): + return self.innate + def get_str_learned(self): + return self.learned - def get_str_lastname(self): - return self.last_name + def get_str_currently(self): + return self.currently + def get_str_lifestyle(self): + return self.lifestyle - def get_str_age(self): - return str(self.age) + def get_str_daily_plan_req(self): + return self.daily_plan_req + def get_str_curr_date_str(self): + return self.curr_time.strftime("%A %B %d") - def get_str_innate(self): - return self.innate + def get_curr_event(self): + if not self.act_address: + return (self.name, None, None) + else: + return self.act_event + def get_curr_event_and_desc(self): + if not self.act_address: + return (self.name, None, None, None) + else: + return (self.act_event[0], + self.act_event[1], + self.act_event[2], + self.act_description) - def get_str_learned(self): - return self.learned + def get_curr_obj_event_and_desc(self): + if not self.act_address: + return ("", None, None, None) + else: + return (self.act_address, + self.act_obj_event[1], + self.act_obj_event[2], + self.act_obj_description) + def add_new_action(self, + action_address, + action_duration, + action_description, + action_pronunciatio, + action_event, + chatting_with, + chat, + chatting_with_buffer, + chatting_end_time, + act_obj_description, + act_obj_pronunciatio, + act_obj_event, + act_start_time=None): + self.act_address = action_address + self.act_duration = action_duration + self.act_description = action_description + self.act_pronunciatio = action_pronunciatio + self.act_event = action_event - def get_str_currently(self): - return self.currently + self.chatting_with = chatting_with + self.chat = chat + if chatting_with_buffer: + self.chatting_with_buffer.update(chatting_with_buffer) + self.chatting_end_time = chatting_end_time + self.act_obj_description = act_obj_description + self.act_obj_pronunciatio = act_obj_pronunciatio + self.act_obj_event = act_obj_event - def get_str_lifestyle(self): - return self.lifestyle + self.act_start_time = self.curr_time + self.act_path_set = False - def get_str_daily_plan_req(self): - return self.daily_plan_req + def act_time_str(self): + """ + Returns a string output of the current time. + INPUT + None + OUTPUT + A string output of the current time. + EXAMPLE STR OUTPUT + "14:05 P.M." + """ + return self.act_start_time.strftime("%H:%M %p") - def get_str_curr_date_str(self): - return self.curr_time.strftime("%A %B %d") + def act_check_finished(self): + """ + Checks whether the self.Action instance has finished. + INPUT + curr_datetime: Current time. If current time is later than the action's + start time + its duration, then the action has finished. + OUTPUT + Boolean [True]: Action has finished. + Boolean [False]: Action has not finished and is still ongoing. + """ + if not self.act_address: + return True - def get_curr_event(self): - if not self.act_address: - return (self.name, None, None) - else: - return self.act_event + if self.chatting_with: + end_time = self.chatting_end_time + else: + x = self.act_start_time + if x.second != 0: + x = x.replace(second=0) + x = (x + datetime.timedelta(minutes=1)) + end_time = (x + datetime.timedelta(minutes=self.act_duration)) + if end_time.strftime("%H:%M:%S") == self.curr_time.strftime("%H:%M:%S"): + return True + return False - def get_curr_event_and_desc(self): - if not self.act_address: - return (self.name, None, None, None) - else: - return (self.act_event[0], - self.act_event[1], - self.act_event[2], - self.act_description) + def act_summarize(self): + """ + Summarize the current action as a dictionary. + INPUT + None + OUTPUT + ret: A human readable summary of the action. + """ + exp = dict() + exp["persona"] = self.name + exp["address"] = self.act_address + exp["start_datetime"] = self.act_start_time + exp["duration"] = self.act_duration + exp["description"] = self.act_description + exp["pronunciatio"] = self.act_pronunciatio + return exp - def get_curr_obj_event_and_desc(self): - if not self.act_address: - return ("", None, None, None) - else: - return (self.act_address, - self.act_obj_event[1], - self.act_obj_event[2], - self.act_obj_description) + def act_summary_str(self): + """ + Returns a string summary of the current action. Meant to be + human-readable. + INPUT + None + OUTPUT + ret: A human readable summary of the action. + """ + start_datetime_str = self.act_start_time.strftime("%A %B %d -- %H:%M %p") + ret = f"[{start_datetime_str}]\n" + ret += f"Activity: {self.name} is {self.act_description}\n" + ret += f"Address: {self.act_address}\n" + ret += f"Duration in minutes (e.g., x min): {str(self.act_duration)} min\n" + return ret - def add_new_action(self, - action_address, - action_duration, - action_description, - action_pronunciatio, - action_event, - chatting_with, - chat, - chatting_with_buffer, - chatting_end_time, - act_obj_description, - act_obj_pronunciatio, - act_obj_event, - act_start_time=None): - self.act_address = action_address - self.act_duration = action_duration - self.act_description = action_description - self.act_pronunciatio = action_pronunciatio - self.act_event = action_event + def get_str_daily_schedule_summary(self): + ret = "" + curr_min_sum = 0 + for row in self.f_daily_schedule: + curr_min_sum += row[1] + hour = int(curr_min_sum / 60) + minute = curr_min_sum % 60 + ret += f"{hour:02}:{minute:02} || {row[0]}\n" + return ret - self.chatting_with = chatting_with - self.chat = chat - if chatting_with_buffer: - self.chatting_with_buffer.update(chatting_with_buffer) - self.chatting_end_time = chatting_end_time - - self.act_obj_description = act_obj_description - self.act_obj_pronunciatio = act_obj_pronunciatio - self.act_obj_event = act_obj_event - - self.act_start_time = self.curr_time - - self.act_path_set = False - - - def act_time_str(self): - """ - Returns a string output of the current time. - - INPUT - None - OUTPUT - A string output of the current time. - EXAMPLE STR OUTPUT - "14:05 P.M." - """ - return self.act_start_time.strftime("%H:%M %p") - - - def act_check_finished(self): - """ - Checks whether the self.Action instance has finished. - - INPUT - curr_datetime: Current time. If current time is later than the action's - start time + its duration, then the action has finished. - OUTPUT - Boolean [True]: Action has finished. - Boolean [False]: Action has not finished and is still ongoing. - """ - if not self.act_address: - return True - - if self.chatting_with: - end_time = self.chatting_end_time - else: - x = self.act_start_time - if x.second != 0: - x = x.replace(second=0) - x = (x + datetime.timedelta(minutes=1)) - end_time = (x + datetime.timedelta(minutes=self.act_duration)) - - if end_time.strftime("%H:%M:%S") == self.curr_time.strftime("%H:%M:%S"): - return True - return False - - - def act_summarize(self): - """ - Summarize the current action as a dictionary. - - INPUT - None - OUTPUT - ret: A human readable summary of the action. - """ - exp = dict() - exp["persona"] = self.name - exp["address"] = self.act_address - exp["start_datetime"] = self.act_start_time - exp["duration"] = self.act_duration - exp["description"] = self.act_description - exp["pronunciatio"] = self.act_pronunciatio - return exp - - - def act_summary_str(self): - """ - Returns a string summary of the current action. Meant to be - human-readable. - - INPUT - None - OUTPUT - ret: A human readable summary of the action. - """ - start_datetime_str = self.act_start_time.strftime("%A %B %d -- %H:%M %p") - ret = f"[{start_datetime_str}]\n" - ret += f"Activity: {self.name} is {self.act_description}\n" - ret += f"Address: {self.act_address}\n" - ret += f"Duration in minutes (e.g., x min): {str(self.act_duration)} min\n" - return ret - - - def get_str_daily_schedule_summary(self): - ret = "" - curr_min_sum = 0 - for row in self.f_daily_schedule: - curr_min_sum += row[1] - hour = int(curr_min_sum/60) - minute = curr_min_sum%60 - ret += f"{hour:02}:{minute:02} || {row[0]}\n" - return ret - - - def get_str_daily_schedule_hourly_org_summary(self): - ret = "" - curr_min_sum = 0 - for row in self.f_daily_schedule_hourly_org: - curr_min_sum += row[1] - hour = int(curr_min_sum/60) - minute = curr_min_sum%60 - ret += f"{hour:02}:{minute:02} || {row[0]}\n" - return ret + def get_str_daily_schedule_hourly_org_summary(self): + ret = "" + curr_min_sum = 0 + for row in self.f_daily_schedule_hourly_org: + curr_min_sum += row[1] + hour = int(curr_min_sum / 60) + minute = curr_min_sum % 60 + ret += f"{hour:02}:{minute:02} || {row[0]}\n" + return ret diff --git a/examples/st_game/memory/spatial_memory.py b/examples/st_game/memory/spatial_memory.py index 455d60e05..73bdc552b 100644 --- a/examples/st_game/memory/spatial_memory.py +++ b/examples/st_game/memory/spatial_memory.py @@ -3,130 +3,113 @@ Author: Joon Sung Park (joonspk@stanford.edu) File: spatial_memory.py Description: Defines the MemoryTree class that serves as the agents' spatial -memory that aids in grounding their behavior in the game world. +memory that aids in grounding their behavior in the game world. """ import json import os -class MemoryTree: - def __init__(self, f_saved: str) -> None: - self.tree = {} - if os.path.isfile(f_saved) and os.path.exists(f_saved): - with open(f_saved) as f: - self.tree = json.load(f) +class MemoryTree: + def __init__(self, f_saved: str) -> None: + self.tree = {} + if os.path.isfile(f_saved) and os.path.exists(f_saved): + with open(f_saved) as f: + self.tree = json.load(f) - def print_tree(self) -> None: - def _print_tree(tree, depth): - dash = " >" * depth - if type(tree) == type(list()): - if tree: - print (dash, tree) - return + def print_tree(self) -> None: + def _print_tree(tree, depth): + dash = " >" * depth + if isinstance(tree, list): + if tree: + print(dash, tree) + return - for key, val in tree.items(): - if key: - print (dash, key) - _print_tree(val, depth+1) - - _print_tree(self.tree, 0) - + for key, val in tree.items(): + if key: + print(dash, key) + _print_tree(val, depth + 1) - def save(self, out_json: str) -> None: - with open(out_json, "w") as outfile: - json.dump(self.tree, outfile) + _print_tree(self.tree, 0) + def save(self, out_json: str) -> None: + with open(out_json, "w") as outfile: + json.dump(self.tree, outfile) - def get_str_accessible_sectors(self, curr_world: str) -> str: - """ - Returns a summary string of all the arenas that the persona can access - within the current sector. + def get_str_accessible_sectors(self, curr_world: str) -> str: + """ + Returns a summary string of all the arenas that the persona can access + within the current sector. - Note that there are places a given persona cannot enter. This information - is provided in the persona sheet. We account for this in this function. - - INPUT - None - OUTPUT - A summary string of all the arenas that the persona can access. - EXAMPLE STR OUTPUT - "bedroom, kitchen, dining room, office, bathroom" - """ - x = ", ".join(list(self.tree[curr_world].keys())) - return x - - - def get_str_accessible_sector_arenas(self, sector: str) -> str: - """ - Returns a summary string of all the arenas that the persona can access - within the current sector. - - Note that there are places a given persona cannot enter. This information - is provided in the persona sheet. We account for this in this function. - - INPUT - None - OUTPUT - A summary string of all the arenas that the persona can access. - EXAMPLE STR OUTPUT - "bedroom, kitchen, dining room, office, bathroom" - """ - curr_world, curr_sector = sector.split(":") - if not curr_sector: - return "" - x = ", ".join(list(self.tree[curr_world][curr_sector].keys())) - return x - - - def get_str_accessible_arena_game_objects(self, arena: str) -> str: - """ - Get a str list of all accessible game objects that are in the arena. If - temp_address is specified, we return the objects that are available in - that arena, and if not, we return the objects that are in the arena our - persona is currently in. - - INPUT - temp_address: optional arena address - OUTPUT - str list of all accessible game objects in the gmae arena. - EXAMPLE STR OUTPUT - "phone, charger, bed, nightstand" - """ - curr_world, curr_sector, curr_arena = arena.split(":") - - if not curr_arena: - return "" - - try: - x = ", ".join(list(self.tree[curr_world][curr_sector][curr_arena])) - except: - x = ", ".join(list(self.tree[curr_world][curr_sector][curr_arena.lower()])) - return x - - - def add_tile_info(self, tile_info: dict) -> None: - if tile_info["world"]: - if (tile_info["world"] not in self.tree): - self.tree[tile_info["world"]] = {} - if tile_info["sector"]: - if (tile_info["sector"] not in self.tree[tile_info["world"]]): - self.tree[tile_info["world"]][tile_info["sector"]] = {} - if tile_info["arena"]: - if (tile_info["arena"] not in self.tree[tile_info["world"]] - [tile_info["sector"]]): - self.tree[tile_info["world"]][tile_info["sector"]][tile_info["arena"]] = [] - if tile_info["game_object"]: - if (tile_info["game_object"] not in self.tree[tile_info["world"]] - [tile_info["sector"]] - [tile_info["arena"]]): - self.tree[tile_info["world"]][tile_info["sector"]][tile_info["arena"]] += [ - tile_info["game_object"]] + Note that there are places a given persona cannot enter. This information + is provided in the persona sheet. We account for this in this function. + INPUT + None + OUTPUT + A summary string of all the arenas that the persona can access. + EXAMPLE STR OUTPUT + "bedroom, kitchen, dining room, office, bathroom" + """ + x = ", ".join(list(self.tree[curr_world].keys())) + return x + def get_str_accessible_sector_arenas(self, sector: str) -> str: + """ + Returns a summary string of all the arenas that the persona can access + within the current sector. + Note that there are places a given persona cannot enter. This information + is provided in the persona sheet. We account for this in this function. + INPUT + None + OUTPUT + A summary string of all the arenas that the persona can access. + EXAMPLE STR OUTPUT + "bedroom, kitchen, dining room, office, bathroom" + """ + curr_world, curr_sector = sector.split(":") + if not curr_sector: + return "" + x = ", ".join(list(self.tree[curr_world][curr_sector].keys())) + return x + def get_str_accessible_arena_game_objects(self, arena: str) -> str: + """ + Get a str list of all accessible game objects that are in the arena. If + temp_address is specified, we return the objects that are available in + that arena, and if not, we return the objects that are in the arena our + persona is currently in. + INPUT + temp_address: optional arena address + OUTPUT + str list of all accessible game objects in the gmae arena. + EXAMPLE STR OUTPUT + "phone, charger, bed, nightstand" + """ + curr_world, curr_sector, curr_arena = arena.split(":") + if not curr_arena: + return "" + try: + x = ", ".join(list(self.tree[curr_world][curr_sector][curr_arena])) + except Exception as exp: + x = ", ".join(list(self.tree[curr_world][curr_sector][curr_arena.lower()])) + return x + def add_tile_info(self, tile_info: dict) -> None: + if tile_info["world"]: + if tile_info["world"] not in self.tree: + self.tree[tile_info["world"]] = {} + if tile_info["sector"]: + if tile_info["sector"] not in self.tree[tile_info["world"]]: + self.tree[tile_info["world"]][tile_info["sector"]] = {} + if tile_info["arena"]: + if tile_info["arena"] not in self.tree[tile_info["world"]][tile_info["sector"]]: + self.tree[tile_info["world"]][tile_info["sector"]][tile_info["arena"]] = [] + if tile_info["game_object"]: + if tile_info["game_object"] not in self.tree[tile_info["world"]][tile_info["sector"]][tile_info["arena"]]: + self.tree[tile_info["world"]][tile_info["sector"]][tile_info["arena"]] += [ + tile_info["game_object"]] diff --git a/examples/st_game/requirements.txt b/examples/st_game/requirements.txt new file mode 100644 index 000000000..6ae7ffa24 --- /dev/null +++ b/examples/st_game/requirements.txt @@ -0,0 +1 @@ +networkx \ No newline at end of file diff --git a/examples/st_game/roles/st_role.py b/examples/st_game/roles/st_role.py index 0c70c1d80..e4de84afa 100644 --- a/examples/st_game/roles/st_role.py +++ b/examples/st_game/roles/st_role.py @@ -44,8 +44,9 @@ class STRole(Role): def __init__(self, name: str = "Klaus Mueller", profile: str = "STMember", + sim_path: str = "new_sim", has_inner_voice: bool = False): - + self.sim_path = sim_path self._rc = STRoleContext() super(STRole, self).__init__(name=name, profile=profile) @@ -73,147 +74,148 @@ class STRole(Role): async def observe(self) -> list[BasicMemory]: # TODO observe info from maze_env """ - Perceive events around the role and saves it to the memory, both events - and spaces. + Perceive events around the role and saves it to the memory, both events + and spaces. - We first perceive the events nearby the role, as determined by its - . If there are a lot of events happening within that radius, we + We first perceive the events nearby the role, as determined by its + . If there are a lot of events happening within that radius, we take the of the closest events. Finally, we check whether any of them are new, as determined by . If they are new, then we - save those and return the instances for those events. + save those and return the instances for those events. - OUTPUT: - ret_events: a list of that are perceived and new. + OUTPUT: + ret_events: a list of that are perceived and new. """ maze = self._rc.env.maze # PERCEIVE SPACE # We get the nearby tiles given our current tile and the persona's vision - # radius. - nearby_tiles = maze.get_nearby_tiles(self._rc.scratch.curr_tile, - self._rc.scratch.vision_r) + # radius. + nearby_tiles = maze.get_nearby_tiles(self._rc.scratch.curr_tile, + self._rc.scratch.vision_r) # We then store the perceived space. Note that the s_mem of the persona is - # in the form of a tree constructed using dictionaries. - for tile in nearby_tiles: + # in the form of a tree constructed using dictionaries. + for tile in nearby_tiles: tile_info = maze.access_tile(tile) self._rc.spatial_memory.add_tile_info(tile_info) - # PERCEIVE EVENTS. + # PERCEIVE EVENTS. # We will perceive events that take place in the same arena as the - # persona's current arena. + # persona's current arena. curr_arena_path = maze.get_tile_path(self._rc.scratch.curr_tile, "arena") # We do not perceive the same event twice (this can happen if an object is # extended across multiple tiles). percept_events_set = set() # We will order our percept based on the distance, with the closest ones - # getting priorities. + # getting priorities. percept_events_list = [] # First, we put all events that are occuring in the nearby tiles into the # percept_events_list - for tile in nearby_tiles: + for tile in nearby_tiles: tile_details = maze.access_tile(tile) - if tile_details["events"]: - if maze.get_tile_path(tile, "arena") == curr_arena_path: - # This calculates the distance between the persona's current tile, - # and the target tile. - dist = math.dist([tile[0], tile[1]], - [self._rc.scratch.curr_tile[0], - self._rc.scratch.curr_tile[1]]) - # Add any relevant events to our temp set/list with the distant info. - for event in tile_details["events"]: - if event not in percept_events_set: + if tile_details["events"]: + if maze.get_tile_path(tile, "arena") == curr_arena_path: + # This calculates the distance between the persona's current tile, + # and the target tile. + dist = math.dist([tile[0], tile[1]], + [self._rc.scratch.curr_tile[0], + self._rc.scratch.curr_tile[1]]) + # Add any relevant events to our temp set/list with the distant info. + for event in tile_details["events"]: + if event not in percept_events_set: percept_events_list += [[dist, event]] percept_events_set.add(event) # We sort, and perceive only self._rc.scratch.att_bandwidth of the closest # events. If the bandwidth is larger, then it means the persona can perceive - # more elements within a small area. + # more elements within a small area. percept_events_list = sorted(percept_events_list, key=itemgetter(0)) perceived_events = [] - for dist, event in percept_events_list[:self._rc.scratch.att_bandwidth]: + for dist, event in percept_events_list[:self._rc.scratch.att_bandwidth]: perceived_events += [event] - # Storing events. - # is a list of instances from the persona's - # associative memory. + # Storing events. + # is a list of instances from the persona's + # associative memory. ret_events = [] - for p_event in perceived_events: + for p_event in perceived_events: s, p, o, desc = p_event - if not p: - # If the object is not present, then we default the event to "idle". + if not p: + # If the object is not present, then we default the event to "idle". p = "is" o = "idle" desc = "idle" desc = f"{s.split(':')[-1]} is {desc}" p_event = (s, p, o) - # We retrieve the latest self._rc.scratch.retention events. If there is + # We retrieve the latest self._rc.scratch.retention events. If there is # something new that is happening (that is, p_event not in latest_events), - # then we add that event to the a_mem and return it. + # then we add that event to the a_mem and return it. latest_events = self._rc.memory.get_summarized_latest_events( - self._rc.scratch.retention) + self._rc.scratch.retention) if p_event not in latest_events: - # We start by managing keywords. + # We start by managing keywords. keywords = set() sub = p_event[0] obj = p_event[2] - if ":" in p_event[0]: + if ":" in p_event[0]: sub = p_event[0].split(":")[-1] - if ":" in p_event[2]: + if ":" in p_event[2]: obj = p_event[2].split(":")[-1] keywords.update([sub, obj]) # Get event embedding desc_embedding_in = desc - if "(" in desc: + if "(" in desc: desc_embedding_in = (desc_embedding_in.split("(")[1] - .split(")")[0] - .strip()) - if desc_embedding_in in self._rc.memory.embeddings: + .split(")")[0] + .strip()) + if desc_embedding_in in self._rc.memory.embeddings: event_embedding = self._rc.memory.embeddings[desc_embedding_in] - else: + else: event_embedding = get_embedding(desc_embedding_in) event_embedding_pair = (desc_embedding_in, event_embedding) - - # Get event poignancy. - event_poignancy = generate_poig_score(self, - "action", - desc_embedding_in) + + # Get event poignancy. + event_poignancy = generate_poig_score(self, + "action", + desc_embedding_in) # If we observe the persona's self chat, we include that in the memory - # of the persona here. + # of the persona here. chat_node_ids = [] - if p_event[0] == f"{self.name}" and p_event[1] == "chat with": + if p_event[0] == f"{self.name}" and p_event[1] == "chat with": curr_event = self._rc.scratch.act_event - if self._rc.scratch.act_description in self._rc.memory.embeddings: + if self._rc.scratch.act_description in self._rc.memory.embeddings: chat_embedding = self._rc.memory.embeddings[ - self._rc.scratch.act_description] - else: + self._rc.scratch.act_description] + else: chat_embedding = get_embedding(self._rc.scratch - .act_description) - chat_embedding_pair = (self._rc.scratch.act_description, - chat_embedding) - chat_poignancy = generate_poig_score(self._rc.scratch, "chat", - self._rc.scratch.act_description) + .act_description) + chat_embedding_pair = (self._rc.scratch.act_description, + chat_embedding) + chat_poignancy = generate_poig_score(self._rc.scratch, "chat", + self._rc.scratch.act_description) chat_node = self._rc.memory.add_chat(self._rc.scratch.curr_time, None, - curr_event[0], curr_event[1], curr_event[2], - self._rc.scratch.act_description, keywords, - chat_poignancy, chat_embedding_pair, - self._rc.scratch.chat) + curr_event[0], curr_event[1], curr_event[2], + self._rc.scratch.act_description, keywords, + chat_poignancy, chat_embedding_pair, + self._rc.scratch.chat) chat_node_ids = [chat_node.node_id] - # Finally, we add the current event to the agent's memory. + # Finally, we add the current event to the agent's memory. ret_events += [self._rc.memory.add_event(self._rc.scratch.curr_time, None, - s, p, o, desc, keywords, event_poignancy, - event_embedding_pair, chat_node_ids)] + s, p, o, desc, keywords, event_poignancy, + event_embedding_pair, chat_node_ids)] self._rc.scratch.importance_trigger_curr -= event_poignancy self._rc.scratch.importance_ele_n += 1 return ret_events - async def retrieve(self, query, n = 30 ,topk = 4): + async def retrieve(self, query, n=30, topk=4): # TODO retrieve memories from agent_memory - retrieve_memories = agent_retrieve(self._rc.memory, self._rc.scratch.curr_time, self._rc.scratch.recency_decay, query, n, topk) + retrieve_memories = agent_retrieve(self._rc.memory, self._rc.scratch.curr_time, self._rc.scratch.recency_decay, + query, n, topk) return retrieve_memories async def plan(self): diff --git a/examples/st_game/run_st_game.py b/examples/st_game/run_st_game.py index db0cc190e..7b00a6f71 100644 --- a/examples/st_game/run_st_game.py +++ b/examples/st_game/run_st_game.py @@ -20,11 +20,14 @@ async def startup(idea: str, # get role names from `storage/{simulation_name}/reverie/meta.json` and then init roles reverie_meta = get_reverie_meta(fork_sim_code) roles = [] - # TODO + sim_path = STORAGE_PATH.joinpath(sim_code) for idx, role_name in enumerate(reverie_meta["persona_names"]): role_stg_path = STORAGE_PATH.joinpath(fork_sim_code).joinpath(f"personas/{role_name}") has_inner_voice = True if idx == 0 else False - role = STRole(name=role_name, has_inner_voice=has_inner_voice) + role = STRole(name=role_name, + sim_path=sim_path, + profile=f"STMember_{idx}", + has_inner_voice=has_inner_voice) role.load_from(role_stg_path) roles.append(role) diff --git a/examples/st_game/stanford_town.py b/examples/st_game/stanford_town.py index c29b3484e..0522710f7 100644 --- a/examples/st_game/stanford_town.py +++ b/examples/st_game/stanford_town.py @@ -7,6 +7,7 @@ from pydantic import Field from metagpt.software_company import SoftwareCompany from metagpt.roles.role import Role from metagpt.schema import Message +from metagpt.logs import logger from maze_environment import MazeEnvironment from actions.user_requirement import UserRequirement @@ -17,6 +18,7 @@ class StanfordTown(SoftwareCompany): environment: MazeEnvironment = Field(default_factory=MazeEnvironment) def wakeup_roles(self, roles: list[Role]): + logger.warning(f"The Town add {len(roles)} roles, and start to operate.") self.environment.add_roles(roles) def start_project(self, idea):