format and add little update

This commit is contained in:
better629 2023-10-01 09:03:51 +08:00
parent bd11e48d7a
commit fd4ee7256c
9 changed files with 1062 additions and 1103 deletions

View file

@ -7,7 +7,7 @@ Author: Joon Sung Park (joonspk@stanford.edu)
File: maze.py
Description: Defines the Maze class, which represents the map of the simulated
world in a 2-dimensional matrix.
world in a 2-dimensional matrix.
"""
import json
@ -17,400 +17,396 @@ import networkx as nx
from .utils.const import MAZE_ASSET_PATH
from .utils.utils import read_csv_to_list
class Maze:
def __init__(self, maze_asset_path: Path = MAZE_ASSET_PATH) -> None:
# READING IN THE BASIC META INFORMATION ABOUT THE MAP
self.maze_asset_path = maze_asset_path
maze_matrix_path = maze_asset_path.joinpath("matrix")
# Reading in the meta information about the world. If you want tp see the
# example variables, check out the maze_meta_info.json file.
meta_info = json.load(open(maze_matrix_path.joinpath("maze_meta_info.json")))
# <maze_width> and <maze_height> denote the number of tiles make up the
# height and width of the map.
self.maze_width = int(meta_info["maze_width"])
self.maze_height = int(meta_info["maze_height"])
# <sq_tile_size> denotes the pixel height/width of a tile.
self.sq_tile_size = int(meta_info["sq_tile_size"])
# <special_constraint> is a string description of any relevant special
# constraints the world might have.
# e.g., "planning to stay at home all day and never go out of her home"
self.special_constraint = meta_info["special_constraint"]
# READING IN SPECIAL BLOCKS
# Special blocks are those that are colored in the Tiled map.
class Maze:
def __init__(self, maze_asset_path: Path = MAZE_ASSET_PATH) -> None:
# READING IN THE BASIC META INFORMATION ABOUT THE MAP
self.maze_asset_path = maze_asset_path
maze_matrix_path = maze_asset_path.joinpath("matrix")
# Reading in the meta information about the world. If you want tp see the
# example variables, check out the maze_meta_info.json file.
meta_info = json.load(open(maze_matrix_path.joinpath("maze_meta_info.json")))
# <maze_width> and <maze_height> denote the number of tiles make up the
# height and width of the map.
self.maze_width = int(meta_info["maze_width"])
self.maze_height = int(meta_info["maze_height"])
# <sq_tile_size> denotes the pixel height/width of a tile.
self.sq_tile_size = int(meta_info["sq_tile_size"])
# <special_constraint> is a string description of any relevant special
# constraints the world might have.
# e.g., "planning to stay at home all day and never go out of her home"
self.special_constraint = meta_info["special_constraint"]
# Here is an example row for the arena block file:
# e.g., "25335, Double Studio, Studio, Common Room"
# And here is another example row for the game object block file:
# e.g, "25331, Double Studio, Studio, Bedroom 2, Painting"
# READING IN SPECIAL BLOCKS
# Special blocks are those that are colored in the Tiled map.
# Notice that the first element here is the color marker digit from the
# Tiled export. Then we basically have the block path:
# World, Sector, Arena, Game Object -- again, these paths need to be
# unique within an instance of Reverie.
blocks_folder = maze_matrix_path.joinpath("special_blocks")
# Here is an example row for the arena block file:
# e.g., "25335, Double Studio, Studio, Common Room"
# And here is another example row for the game object block file:
# e.g, "25331, Double Studio, Studio, Bedroom 2, Painting"
_wb = blocks_folder.joinpath("world_blocks.csv")
wb_rows = read_csv_to_list(_wb, header=False)
wb = wb_rows[0][-1]
_sb = blocks_folder.joinpath("sector_blocks.csv")
sb_rows = read_csv_to_list(_sb, header=False)
sb_dict = dict()
for i in sb_rows: sb_dict[i[0]] = i[-1]
_ab = blocks_folder.joinpath("arena_blocks.csv")
ab_rows = read_csv_to_list(_ab, header=False)
ab_dict = dict()
for i in ab_rows: ab_dict[i[0]] = i[-1]
_gob = blocks_folder.joinpath("game_object_blocks.csv")
gob_rows = read_csv_to_list(_gob, header=False)
gob_dict = dict()
for i in gob_rows: gob_dict[i[0]] = i[-1]
_slb = blocks_folder.joinpath("spawning_location_blocks.csv")
slb_rows = read_csv_to_list(_slb, header=False)
slb_dict = dict()
for i in slb_rows: slb_dict[i[0]] = i[-1]
# Notice that the first element here is the color marker digit from the
# Tiled export. Then we basically have the block path:
# World, Sector, Arena, Game Object -- again, these paths need to be
# unique within an instance of Reverie.
blocks_folder = maze_matrix_path.joinpath("special_blocks")
# [SECTION 3] Reading in the matrices
# This is your typical two dimensional matrices. It's made up of 0s and
# the number that represents the color block from the blocks folder.
maze_folder = maze_matrix_path.joinpath("maze")
_wb = blocks_folder.joinpath("world_blocks.csv")
wb_rows = read_csv_to_list(_wb, header=False)
wb = wb_rows[0][-1]
_cm = maze_folder.joinpath("collision_maze.csv")
collision_maze_raw = read_csv_to_list(_cm, header=False)[0]
_sm = maze_folder.joinpath("sector_maze.csv")
sector_maze_raw = read_csv_to_list(_sm, header=False)[0]
_am = maze_folder.joinpath("arena_maze.csv")
arena_maze_raw = read_csv_to_list(_am, header=False)[0]
_gom = maze_folder.joinpath("game_object_maze.csv")
game_object_maze_raw = read_csv_to_list(_gom, header=False)[0]
_slm = maze_folder.joinpath("spawning_location_maze.csv")
spawning_location_maze_raw = read_csv_to_list(_slm, header=False)[0]
_sb = blocks_folder.joinpath("sector_blocks.csv")
sb_rows = read_csv_to_list(_sb, header=False)
sb_dict = dict()
for i in sb_rows:
sb_dict[i[0]] = i[-1]
# Loading the maze. The mazes are taken directly from the json exports of
# Tiled maps. They should be in csv format.
# Importantly, they are "not" in a 2-d matrix format -- they are single
# row matrices with the length of width x height of the maze. So we need
# to convert here.
# We can do this all at once since the dimension of all these matrices are
# identical (e.g., 70 x 40).
# example format: [['0', '0', ... '25309', '0',...], ['0',...]...]
# 25309 is the collision bar number right now.
self.collision_maze = []
sector_maze = []
arena_maze = []
game_object_maze = []
spawning_location_maze = []
for i in range(0, len(collision_maze_raw), meta_info["maze_width"]):
tw = meta_info["maze_width"]
self.collision_maze += [collision_maze_raw[i:i+tw]]
sector_maze += [sector_maze_raw[i:i+tw]]
arena_maze += [arena_maze_raw[i:i+tw]]
game_object_maze += [game_object_maze_raw[i:i+tw]]
spawning_location_maze += [spawning_location_maze_raw[i:i+tw]]
_ab = blocks_folder.joinpath("arena_blocks.csv")
ab_rows = read_csv_to_list(_ab, header=False)
ab_dict = dict()
for i in ab_rows:
ab_dict[i[0]] = i[-1]
# Once we are done loading in the maze, we now set up self.tiles. This is
# a matrix accessed by row:col where each access point is a dictionary
# that contains all the things that are taking place in that tile.
# More specifically, it contains information about its "world," "sector,"
# "arena," "game_object," "spawning_location," as well as whether it is a
# collision block, and a set of all events taking place in it.
# e.g., self.tiles[32][59] = {'world': 'double studio',
# 'sector': '', 'arena': '', 'game_object': '',
# 'spawning_location': '', 'collision': False, 'events': set()}
# e.g., self.tiles[9][58] = {'world': 'double studio',
# 'sector': 'double studio', 'arena': 'bedroom 2',
# 'game_object': 'bed', 'spawning_location': 'bedroom-2-a',
# 'collision': False,
# 'events': {('double studio:double studio:bedroom 2:bed',
# None, None)}}
self.tiles = []
for i in range(self.maze_height):
row = []
for j in range(self.maze_width):
tile_details = dict()
tile_details["world"] = wb
tile_details["sector"] = ""
if sector_maze[i][j] in sb_dict:
tile_details["sector"] = sb_dict[sector_maze[i][j]]
tile_details["arena"] = ""
if arena_maze[i][j] in ab_dict:
tile_details["arena"] = ab_dict[arena_maze[i][j]]
tile_details["game_object"] = ""
if game_object_maze[i][j] in gob_dict:
tile_details["game_object"] = gob_dict[game_object_maze[i][j]]
tile_details["spawning_location"] = ""
if spawning_location_maze[i][j] in slb_dict:
tile_details["spawning_location"] = slb_dict[spawning_location_maze[i][j]]
tile_details["collision"] = False
if self.collision_maze[i][j] != "0":
tile_details["collision"] = True
_gob = blocks_folder.joinpath("game_object_blocks.csv")
gob_rows = read_csv_to_list(_gob, header=False)
gob_dict = dict()
for i in gob_rows:
gob_dict[i[0]] = i[-1]
tile_details["events"] = set()
row += [tile_details]
self.tiles += [row]
# Each game object occupies an event in the tile. We are setting up the
# default event value here.
for i in range(self.maze_height):
for j in range(self.maze_width):
if self.tiles[i][j]["game_object"]:
object_name = ":".join([self.tiles[i][j]["world"],
self.tiles[i][j]["sector"],
self.tiles[i][j]["arena"],
self.tiles[i][j]["game_object"]])
go_event = (object_name, None, None, None)
self.tiles[i][j]["events"].add(go_event)
_slb = blocks_folder.joinpath("spawning_location_blocks.csv")
slb_rows = read_csv_to_list(_slb, header=False)
slb_dict = dict()
for i in slb_rows:
slb_dict[i[0]] = i[-1]
# Reverse tile access.
# <self.address_tiles> -- given a string address, we return a set of all
# tile coordinates belonging to that address (this is opposite of
# self.tiles that give you the string address given a coordinate). This is
# an optimization component for finding paths for the personas' movement.
# self.address_tiles['<spawn_loc>bedroom-2-a'] == {(58, 9)}
# self.address_tiles['double studio:recreation:pool table']
# == {(29, 14), (31, 11), (30, 14), (32, 11), ...},
self.address_tiles = dict()
for i in range(self.maze_height):
for j in range(self.maze_width):
addresses = []
if self.tiles[i][j]["sector"]:
add = f'{self.tiles[i][j]["world"]}:'
add += f'{self.tiles[i][j]["sector"]}'
addresses += [add]
if self.tiles[i][j]["arena"]:
add = f'{self.tiles[i][j]["world"]}:'
add += f'{self.tiles[i][j]["sector"]}:'
add += f'{self.tiles[i][j]["arena"]}'
addresses += [add]
if self.tiles[i][j]["game_object"]:
add = f'{self.tiles[i][j]["world"]}:'
add += f'{self.tiles[i][j]["sector"]}:'
add += f'{self.tiles[i][j]["arena"]}:'
add += f'{self.tiles[i][j]["game_object"]}'
addresses += [add]
if self.tiles[i][j]["spawning_location"]:
add = f'<spawn_loc>{self.tiles[i][j]["spawning_location"]}'
addresses += [add]
# [SECTION 3] Reading in the matrices
# This is your typical two dimensional matrices. It's made up of 0s and
# the number that represents the color block from the blocks folder.
maze_folder = maze_matrix_path.joinpath("maze")
for add in addresses:
if add in self.address_tiles:
self.address_tiles[add].add((j, i))
else:
self.address_tiles[add] = set([(j, i)])
_cm = maze_folder.joinpath("collision_maze.csv")
collision_maze_raw = read_csv_to_list(_cm, header=False)[0]
_sm = maze_folder.joinpath("sector_maze.csv")
sector_maze_raw = read_csv_to_list(_sm, header=False)[0]
_am = maze_folder.joinpath("arena_maze.csv")
arena_maze_raw = read_csv_to_list(_am, header=False)[0]
_gom = maze_folder.joinpath("game_object_maze.csv")
game_object_maze_raw = read_csv_to_list(_gom, header=False)[0]
_slm = maze_folder.joinpath("spawning_location_maze.csv")
spawning_location_maze_raw = read_csv_to_list(_slm, header=False)[0]
# Build an nx.Graph.
grid_graph = nx.grid_2d_graph(m=self.maze_width, n=self.maze_height)
for i in range(self.maze_height):
for j in range(self.maze_width):
if self.collision_maze[i][j]!=0:
grid_graph.remove_node((i,j))
self.nx_graph = grid_graph
# Loading the maze. The mazes are taken directly from the json exports of
# Tiled maps. They should be in csv format.
# Importantly, they are "not" in a 2-d matrix format -- they are single
# row matrices with the length of width x height of the maze. So we need
# to convert here.
# We can do this all at once since the dimension of all these matrices are
# identical (e.g., 70 x 40).
# example format: [['0', '0', ... '25309', '0',...], ['0',...]...]
# 25309 is the collision bar number right now.
self.collision_maze = []
sector_maze = []
arena_maze = []
game_object_maze = []
spawning_location_maze = []
for i in range(0, len(collision_maze_raw), meta_info["maze_width"]):
tw = meta_info["maze_width"]
self.collision_maze += [collision_maze_raw[i:i + tw]]
sector_maze += [sector_maze_raw[i:i + tw]]
arena_maze += [arena_maze_raw[i:i + tw]]
game_object_maze += [game_object_maze_raw[i:i + tw]]
spawning_location_maze += [spawning_location_maze_raw[i:i + tw]]
# Once we are done loading in the maze, we now set up self.tiles. This is
# a matrix accessed by row:col where each access point is a dictionary
# that contains all the things that are taking place in that tile.
# More specifically, it contains information about its "world," "sector,"
# "arena," "game_object," "spawning_location," as well as whether it is a
# collision block, and a set of all events taking place in it.
# e.g., self.tiles[32][59] = {'world': 'double studio',
# 'sector': '', 'arena': '', 'game_object': '',
# 'spawning_location': '', 'collision': False, 'events': set()}
# e.g., self.tiles[9][58] = {'world': 'double studio',
# 'sector': 'double studio', 'arena': 'bedroom 2',
# 'game_object': 'bed', 'spawning_location': 'bedroom-2-a',
# 'collision': False,
# 'events': {('double studio:double studio:bedroom 2:bed',
# None, None)}}
self.tiles = []
for i in range(self.maze_height):
row = []
for j in range(self.maze_width):
tile_details = dict()
tile_details["world"] = wb
def turn_coordinate_to_tile(self, px_coordinate: tuple[int, int]) -> tuple[int, int]:
"""
Turns a pixel coordinate to a tile coordinate.
tile_details["sector"] = ""
if sector_maze[i][j] in sb_dict:
tile_details["sector"] = sb_dict[sector_maze[i][j]]
INPUT
px_coordinate: The pixel coordinate of our interest. Comes in the x, y
format.
OUTPUT
tile coordinate (x, y): The tile coordinate that corresponds to the
pixel coordinate.
EXAMPLE OUTPUT
Given (1600, 384), outputs (50, 12)
"""
x = math.ceil(px_coordinate[0]/self.sq_tile_size)
y = math.ceil(px_coordinate[1]/self.sq_tile_size)
return (x, y)
tile_details["arena"] = ""
if arena_maze[i][j] in ab_dict:
tile_details["arena"] = ab_dict[arena_maze[i][j]]
tile_details["game_object"] = ""
if game_object_maze[i][j] in gob_dict:
tile_details["game_object"] = gob_dict[game_object_maze[i][j]]
def access_tile(self, tile: tuple[int, int]) -> dict:
"""
Returns the tiles details dictionary that is stored in self.tiles of the
designated x, y location.
tile_details["spawning_location"] = ""
if spawning_location_maze[i][j] in slb_dict:
tile_details["spawning_location"] = slb_dict[spawning_location_maze[i][j]]
INPUT
tile: The tile coordinate of our interest in (x, y) form.
OUTPUT
The tile detail dictionary for the designated tile.
EXAMPLE OUTPUT
Given (58, 9),
self.tiles[9][58] = {'world': 'double studio',
'sector': 'double studio', 'arena': 'bedroom 2',
'game_object': 'bed', 'spawning_location': 'bedroom-2-a',
'collision': False,
'events': {('double studio:double studio:bedroom 2:bed',
None, None)}}
"""
x = tile[0]
y = tile[1]
return self.tiles[y][x]
tile_details["collision"] = False
if self.collision_maze[i][j] != "0":
tile_details["collision"] = True
tile_details["events"] = set()
def get_tile_path(self, tile: tuple[int, int], level: str) -> str:
"""
Get the tile string address given its coordinate. You designate the level
by giving it a string level description.
row += [tile_details]
self.tiles += [row]
# Each game object occupies an event in the tile. We are setting up the
# default event value here.
for i in range(self.maze_height):
for j in range(self.maze_width):
if self.tiles[i][j]["game_object"]:
object_name = ":".join([self.tiles[i][j]["world"],
self.tiles[i][j]["sector"],
self.tiles[i][j]["arena"],
self.tiles[i][j]["game_object"]])
go_event = (object_name, None, None, None)
self.tiles[i][j]["events"].add(go_event)
INPUT:
tile: The tile coordinate of our interest in (x, y) form.
level: world, sector, arena, or game object
OUTPUT
The string address for the tile.
EXAMPLE OUTPUT
Given tile=(58, 9), and level=arena,
"double studio:double studio:bedroom 2"
"""
x = tile[0]
y = tile[1]
tile = self.tiles[y][x]
# Reverse tile access.
# <self.address_tiles> -- given a string address, we return a set of all
# tile coordinates belonging to that address (this is opposite of
# self.tiles that give you the string address given a coordinate). This is
# an optimization component for finding paths for the personas' movement.
# self.address_tiles['<spawn_loc>bedroom-2-a'] == {(58, 9)}
# self.address_tiles['double studio:recreation:pool table']
# == {(29, 14), (31, 11), (30, 14), (32, 11), ...},
self.address_tiles = dict()
for i in range(self.maze_height):
for j in range(self.maze_width):
addresses = []
if self.tiles[i][j]["sector"]:
add = f'{self.tiles[i][j]["world"]}:'
add += f'{self.tiles[i][j]["sector"]}'
addresses += [add]
if self.tiles[i][j]["arena"]:
add = f'{self.tiles[i][j]["world"]}:'
add += f'{self.tiles[i][j]["sector"]}:'
add += f'{self.tiles[i][j]["arena"]}'
addresses += [add]
if self.tiles[i][j]["game_object"]:
add = f'{self.tiles[i][j]["world"]}:'
add += f'{self.tiles[i][j]["sector"]}:'
add += f'{self.tiles[i][j]["arena"]}:'
add += f'{self.tiles[i][j]["game_object"]}'
addresses += [add]
if self.tiles[i][j]["spawning_location"]:
add = f'<spawn_loc>{self.tiles[i][j]["spawning_location"]}'
addresses += [add]
path = f"{tile['world']}"
if level == "world":
return path
else:
path += f":{tile['sector']}"
if level == "sector":
return path
else:
path += f":{tile['arena']}"
for add in addresses:
if add in self.address_tiles:
self.address_tiles[add].add((j, i))
else:
self.address_tiles[add] = set([(j, i)])
if level == "arena":
return path
else:
path += f":{tile['game_object']}"
# Build an nx.Graph.
grid_graph = nx.grid_2d_graph(m=self.maze_width, n=self.maze_height)
for i in range(self.maze_height):
for j in range(self.maze_width):
if self.collision_maze[i][j] != 0:
grid_graph.remove_node((i, j))
self.nx_graph = grid_graph
return path
def turn_coordinate_to_tile(self, px_coordinate: tuple[int, int]) -> tuple[int, int]:
"""
Turns a pixel coordinate to a tile coordinate.
INPUT
px_coordinate: The pixel coordinate of our interest. Comes in the x, y
format.
OUTPUT
tile coordinate (x, y): The tile coordinate that corresponds to the
pixel coordinate.
EXAMPLE OUTPUT
Given (1600, 384), outputs (50, 12)
"""
x = math.ceil(px_coordinate[0] / self.sq_tile_size)
y = math.ceil(px_coordinate[1] / self.sq_tile_size)
return (x, y)
def get_nearby_tiles(self, tile: tuple[int, int], vision_r: int) -> list[tuple[int, int]]:
"""
Given the current tile and vision_r, return a list of tiles that are
within the radius. Note that this implementation looks at a square
boundary when determining what is within the radius.
i.e., for vision_r, returns x's.
x x x x x
x x x x x
x x P x x
x x x x x
x x x x x
def access_tile(self, tile: tuple[int, int]) -> dict:
"""
Returns the tiles details dictionary that is stored in self.tiles of the
designated x, y location.
INPUT:
tile: The tile coordinate of our interest in (x, y) form.
vision_r: The radius of the persona's vision.
OUTPUT:
nearby_tiles: a list of tiles that are within the radius.
"""
left_end = 0
if tile[0] - vision_r > left_end:
left_end = tile[0] - vision_r
INPUT
tile: The tile coordinate of our interest in (x, y) form.
OUTPUT
The tile detail dictionary for the designated tile.
EXAMPLE OUTPUT
Given (58, 9),
self.tiles[9][58] = {'world': 'double studio',
'sector': 'double studio', 'arena': 'bedroom 2',
'game_object': 'bed', 'spawning_location': 'bedroom-2-a',
'collision': False,
'events': {('double studio:double studio:bedroom 2:bed',
None, None)}}
"""
x = tile[0]
y = tile[1]
return self.tiles[y][x]
right_end = self.maze_width - 1
if tile[0] + vision_r + 1 < right_end:
right_end = tile[0] + vision_r + 1
def get_tile_path(self, tile: tuple[int, int], level: str) -> str:
"""
Get the tile string address given its coordinate. You designate the level
by giving it a string level description.
bottom_end = self.maze_height - 1
if tile[1] + vision_r + 1 < bottom_end:
bottom_end = tile[1] + vision_r + 1
INPUT:
tile: The tile coordinate of our interest in (x, y) form.
level: world, sector, arena, or game object
OUTPUT
The string address for the tile.
EXAMPLE OUTPUT
Given tile=(58, 9), and level=arena,
"double studio:double studio:bedroom 2"
"""
x = tile[0]
y = tile[1]
tile = self.tiles[y][x]
top_end = 0
if tile[1] - vision_r > top_end:
top_end = tile[1] - vision_r
path = f"{tile['world']}"
if level == "world":
return path
else:
path += f":{tile['sector']}"
nearby_tiles = []
for i in range(left_end, right_end):
for j in range(top_end, bottom_end):
nearby_tiles += [(i, j)]
return nearby_tiles
if level == "sector":
return path
else:
path += f":{tile['arena']}"
if level == "arena":
return path
else:
path += f":{tile['game_object']}"
def add_event_from_tile(self, curr_event: tuple[str], tile: tuple[int, int]) -> None:
"""
Add an event triple to a tile.
return path
INPUT:
curr_event: Current event triple.
e.g., ('double studio:double studio:bedroom 2:bed', None,
None)
tile: The tile coordinate of our interest in (x, y) form.
OUPUT:
None
"""
self.tiles[tile[1]][tile[0]]["events"].add(curr_event)
def get_nearby_tiles(self, tile: tuple[int, int], vision_r: int) -> list[tuple[int, int]]:
"""
Given the current tile and vision_r, return a list of tiles that are
within the radius. Note that this implementation looks at a square
boundary when determining what is within the radius.
i.e., for vision_r, returns x's.
x x x x x
x x x x x
x x P x x
x x x x x
x x x x x
INPUT:
tile: The tile coordinate of our interest in (x, y) form.
vision_r: The radius of the persona's vision.
OUTPUT:
nearby_tiles: a list of tiles that are within the radius.
"""
left_end = 0
if tile[0] - vision_r > left_end:
left_end = tile[0] - vision_r
def remove_event_from_tile(self, curr_event: tuple[str], tile: tuple[int, int]) -> None:
"""dswaq
Remove an event triple from a tile.
right_end = self.maze_width - 1
if tile[0] + vision_r + 1 < right_end:
right_end = tile[0] + vision_r + 1
INPUT:
curr_event: Current event triple.
e.g., ('double studio:double studio:bedroom 2:bed', None,
None)
tile: The tile coordinate of our interest in (x, y) form.
OUPUT:
None
"""
curr_tile_ev_cp = self.tiles[tile[1]][tile[0]]["events"].copy()
for event in curr_tile_ev_cp:
if event == curr_event:
self.tiles[tile[1]][tile[0]]["events"].remove(event)
bottom_end = self.maze_height - 1
if tile[1] + vision_r + 1 < bottom_end:
bottom_end = tile[1] + vision_r + 1
top_end = 0
if tile[1] - vision_r > top_end:
top_end = tile[1] - vision_r
def turn_event_from_tile_idle(self, curr_event: tuple[str], tile: tuple[int, int]) -> None:
curr_tile_ev_cp = self.tiles[tile[1]][tile[0]]["events"].copy()
for event in curr_tile_ev_cp:
if event == curr_event:
self.tiles[tile[1]][tile[0]]["events"].remove(event)
new_event = (event[0], None, None, None)
self.tiles[tile[1]][tile[0]]["events"].add(new_event)
nearby_tiles = []
for i in range(left_end, right_end):
for j in range(top_end, bottom_end):
nearby_tiles += [(i, j)]
return nearby_tiles
def add_event_from_tile(self, curr_event: tuple[str], tile: tuple[int, int]) -> None:
"""
Add an event triple to a tile.
def remove_subject_events_from_tile(self, subject: str, tile: tuple[int, int]) -> None:
"""
Remove an event triple that has the input subject from a tile.
INPUT:
curr_event: Current event triple.
e.g., ('double studio:double studio:bedroom 2:bed', None,
None)
tile: The tile coordinate of our interest in (x, y) form.
OUPUT:
None
"""
self.tiles[tile[1]][tile[0]]["events"].add(curr_event)
INPUT:
subject: "Isabella Rodriguez"
tile: The tile coordinate of our interest in (x, y) form.
OUPUT:
None
"""
curr_tile_ev_cp = self.tiles[tile[1]][tile[0]]["events"].copy()
for event in curr_tile_ev_cp:
if event[0] == subject:
self.tiles[tile[1]][tile[0]]["events"].remove(event)
def remove_event_from_tile(self, curr_event: tuple[str], tile: tuple[int, int]) -> None:
"""dswaq
Remove an event triple from a tile.
INPUT:
curr_event: Current event triple.
e.g., ('double studio:double studio:bedroom 2:bed', None,
None)
tile: The tile coordinate of our interest in (x, y) form.
OUPUT:
None
"""
curr_tile_ev_cp = self.tiles[tile[1]][tile[0]]["events"].copy()
for event in curr_tile_ev_cp:
if event == curr_event:
self.tiles[tile[1]][tile[0]]["events"].remove(event)
def _find_closest_node(self, coords: tuple[int, int]) -> tuple[int, int]:
target_coords = self.nx_graph.nodes
min_dist = None
closest_coordinate = None
for target in target_coords:
dist = math.dist(coords, target)
if not closest_coordinate:
min_dist = dist
closest_coordinate = target
else:
if min_dist > dist:
min_dist = dist
closest_coordinate = target
return closest_coordinate
def turn_event_from_tile_idle(self, curr_event: tuple[str], tile: tuple[int, int]) -> None:
curr_tile_ev_cp = self.tiles[tile[1]][tile[0]]["events"].copy()
for event in curr_tile_ev_cp:
if event == curr_event:
self.tiles[tile[1]][tile[0]]["events"].remove(event)
new_event = (event[0], None, None, None)
self.tiles[tile[1]][tile[0]]["events"].add(new_event)
def find_path(self, start: tuple[int, int], end: tuple[int, int]) -> list[tuple[int, int]]:
if start not in self.nx_graph.nodes:
start = self._find_closest_node(start)
if end not in self.nx_graph.nodes:
end = self._find_closest_node(end)
return nx.shortest_path(self.nx_graph, start, end)
def remove_subject_events_from_tile(self, subject: str, tile: tuple[int, int]) -> None:
"""
Remove an event triple that has the input subject from a tile.
INPUT:
subject: "Isabella Rodriguez"
tile: The tile coordinate of our interest in (x, y) form.
OUPUT:
None
"""
curr_tile_ev_cp = self.tiles[tile[1]][tile[0]]["events"].copy()
for event in curr_tile_ev_cp:
if event[0] == subject:
self.tiles[tile[1]][tile[0]]["events"].remove(event)
def _find_closest_node(self, coords: tuple[int, int]) -> tuple[int, int]:
target_coords = self.nx_graph.nodes
min_dist = None
closest_coordinate = None
for target in target_coords:
dist = math.dist(coords, target)
if not closest_coordinate:
min_dist = dist
closest_coordinate = target
else:
if min_dist > dist:
min_dist = dist
closest_coordinate = target
return closest_coordinate
def find_path(self, start: tuple[int, int], end: tuple[int, int]) -> list[tuple[int, int]]:
if start not in self.nx_graph.nodes:
start = self._find_closest_node(start)
if end not in self.nx_graph.nodes:
end = self._find_closest_node(end)
return nx.shortest_path(self.nx_graph, start, end)

View file

@ -14,39 +14,40 @@ class BasicMemory(Message):
created: datetime, expiration: datetime,
subject: str, predicate: str, object: str,
content: str, embedding_key: str, poignancy: int, keywords: list, filling: list,
cause_by = ""):
cause_by=""):
"""
BasicMemory继承于MG的Message类其中content属性替代description属性
Message类中对于Chat类型支持的非常好对于Agent个体的Perceive,Reflection,Plan支持的并不多
在Type设计上我们延续GA的三个种类但是对于Chat种类的对话进行特别设计具体怎么设计还没想好
"""
super().__init__(content,cause_by=cause_by)
super().__init__(content, cause_by=cause_by)
"""
从父类中继承的属性
content: str # 记忆描述
cause_by: Type["Action"] = field(default="") # 触发动作只在Type为chat时初始化
cause_by 接受一个Action类在此项目中每个Agent需要有一个基础动作[Receive] 用于接受假对话Message而每个Agent需要有独一无二的动作类用以接受真对话Message
cause_by 接受一个Action类在此项目中每个Agent需要有一个基础动作[Receive] 用于接受假对话Message
而每个Agent需要有独一无二的动作类用以接受真对话Message
"""
self.memory_id: str = memory_id # 记忆ID
self.memory_count: int = memory_count # 第几个记忆实际数值与Memory相等
self.type_count: int = type_count # 第几种记忆,类型为整数(具体不太理解如何生成的)
self.memory_type: str = memory_type # 记忆类型,包含 event,thought,chat三种类型
self.depth: str = depth # 记忆深度,类型为整数
self.memory_id: str = memory_id # 记忆ID
self.memory_count: int = memory_count # 第几个记忆实际数值与Memory相等
self.type_count: int = type_count # 第几种记忆,类型为整数(具体不太理解如何生成的)
self.memory_type: str = memory_type # 记忆类型,包含 event,thought,chat三种类型
self.depth: str = depth # 记忆深度,类型为整数
self.created: datetime = created # 创建时间
self.expiration: datetime = expiration # 记忆失效时间,默认为空()
self.last_accessed: datetime = created # 上一次调用的时间初始化时候与self.created一致
self.created: datetime = created # 创建时间
self.expiration: datetime = expiration # 记忆失效时间,默认为空()
self.last_accessed: datetime = created # 上一次调用的时间初始化时候与self.created一致
self.subject: str = subject # 主语
self.predicate: str = predicate # 谓语
self.object: str = object # 宾语
self.subject: str = subject # 主语
self.predicate: str = predicate # 谓语
self.object: str = object # 宾语
self.embedding_key: str = embedding_key # 内容与self.content一致
self.poignancy: int = poignancy # importance值
self.keywords: list = keywords # keywords
self.filling: list = filling # 装的与之相关联的memory_id的列表
self.embedding_key: str = embedding_key # 内容与self.content一致
self.poignancy: int = poignancy # importance值
self.keywords: list = keywords # keywords
self.filling: list = filling # 装的与之相关联的memory_id的列表
def summary(self):
def summary(self):
return (self.subject, self.predicate, self.object)
def save_to_dict(self) -> dict:
@ -65,9 +66,9 @@ class BasicMemory(Message):
memory_dict[node_id]["cmemory_dicteated"] = self.created.strftime('%Y-%m-%d %H:%M:%S')
memory_dict[node_id]["expiration"] = None
if self.expiration:
if self.expiration:
memory_dict[node_id]["expiration"] = (self.expiration
.strftime('%Y-%m-%d %H:%M:%S'))
.strftime('%Y-%m-%d %H:%M:%S'))
memory_dict[node_id]["subject"] = self.subject
memory_dict[node_id]["predicate"] = self.predicate
@ -83,6 +84,7 @@ class BasicMemory(Message):
return memory_dict
class AgentMemory(Memory):
"""
GA中主要存储三种JSON
@ -90,6 +92,7 @@ class AgentMemory(Memory):
2. Node.json (Dict Node_id:Node)
3. kw_strength.json
"""
def __init__(self, memory_saved: str):
"""
AgentMemory类继承自Memory类重写storage替代GA中id_to_node一方面存储所有信息一方面作为JSON转化
@ -97,21 +100,20 @@ class AgentMemory(Memory):
@李嵩@张凯 这里的storage是List你们需要写一个JSON转化器将List修改为node.json一致的格式
"""
super.__init__()
self.storage: list[BasicMemory] = [] # 重写Stroage存储BasicMemory所有节点
self.event_list = [] # 存储event记忆
self.thought_list = [] # 存储thought记忆
self.storage: list[BasicMemory] = [] # 重写Stroage存储BasicMemory所有节点
self.event_list = [] # 存储event记忆
self.thought_list = [] # 存储thought记忆
self.event_keywords = dict() # 存储keywords
self.thought_keywords = dict()
self.event_keywords = dict() # 存储keywords
self.thought_keywords = dict()
self.chat_keywords = dict()
self.kw_strength_event = dict() # 关键词影响存储
self.kw_strength_thought = dict()
self.kw_strength_event = dict() # 关键词影响存储
self.kw_strength_thought = dict()
self.load(memory_saved)
def save(self,memory_saved:str):
def save(self, memory_saved: str):
"""
将MemormyBasic类存储为Nodes.json形式复现GA中的Kw Strength.json形式
这里添加一个路径即可
@ -121,36 +123,35 @@ class AgentMemory(Memory):
for i in range(len(self.storage)):
memory_node = self.storage[i]
memory_json.update(memory_node)
with open(memory_saved+"/nodes.json", "w") as outfile:
with open(memory_saved + "/nodes.json", "w") as outfile:
json.dump(memory_json, outfile)
with open(memory_saved+"/embeddings.json", "w") as outfile:
with open(memory_saved + "/embeddings.json", "w") as outfile:
json.dump(self.embeddings, outfile)
strength_json = dict()
strength_json["kw_strength_event"] = self.kw_strength_event
strength_json["kw_strength_thought"] = self.kw_strength_thought
with open(memory_saved+"/kw_strength.json", "w") as outfile:
with open(memory_saved + "/kw_strength.json", "w") as outfile:
json.dump(strength_json, outfile)
def load(self,memory_saved:str):
def load(self, memory_saved: str):
"""
将GA的JSON解析填充到AgentMemory类之中
"""
self.embeddings = json.load(open(memory_saved + "/embeddings.json"))
memory_load = json.load(open(memory_saved + "/nodes.json"))
for count in range(len(memory_load.keys())):
node_id = f"node_{str(count+1)}"
node_id = f"node_{str(count + 1)}"
node_details = memory_load[node_id]
node_type = node_details["type"]
created = datetime.datetime.strptime(node_details["created"],
'%Y-%m-%d %H:%M:%S')
created = datetime.datetime.strptime(node_details["created"],
'%Y-%m-%d %H:%M:%S')
expiration = None
if node_details["expiration"]:
if node_details["expiration"]:
expiration = datetime.datetime.strptime(node_details["expiration"],
'%Y-%m-%d %H:%M:%S')
if node_details["cause_by"]:
cause_by = node_details["cause_by"]
@ -159,29 +160,28 @@ class AgentMemory(Memory):
o = node_details["object"]
description = node_details["description"]
embedding_pair = (node_details["embedding_key"],
self.embeddings[node_details["embedding_key"]])
poignancy =node_details["poignancy"]
embedding_pair = (node_details["embedding_key"],
self.embeddings[node_details["embedding_key"]])
poignancy = node_details["poignancy"]
keywords = set(node_details["keywords"])
filling = node_details["filling"]
if node_type == "event":
self.add_event(created, expiration, s, p, o,
description, keywords, poignancy, embedding_pair, filling)
elif node_type == "chat":
self.add_chat(created, expiration, s, p, o,
description, keywords, poignancy, embedding_pair, filling,cause_by)
elif node_type == "thought":
self.add_thought(created, expiration, s, p, o,
description, keywords, poignancy, embedding_pair, filling)
if node_type == "event":
self.add_event(created, expiration, s, p, o,
description, keywords, poignancy, embedding_pair, filling)
elif node_type == "chat":
self.add_chat(created, expiration, s, p, o,
description, keywords, poignancy, embedding_pair, filling, cause_by)
elif node_type == "thought":
self.add_thought(created, expiration, s, p, o,
description, keywords, poignancy, embedding_pair, filling)
strength_keywords_load = json.load(open(memory_saved + "/kw_strength.json"))
if strength_keywords_load["kw_strength_event"]:
if strength_keywords_load["kw_strength_event"]:
self.kw_strength_event = strength_keywords_load["kw_strength_event"]
if strength_keywords_load["kw_strength_thought"]:
if strength_keywords_load["kw_strength_thought"]:
self.kw_strength_thought = strength_keywords_load["kw_strength_thought"]
def add(self, memory_basic: BasicMemory):
"""
Add a new message to storage, while updating the index
@ -192,18 +192,17 @@ class AgentMemory(Memory):
self.storage.append(memory_basic)
if memory_basic.cause_by:
self.index[memory_basic.cause_by][0:0] = [memory_basic]
return
return
if memory_basic.type == "thought":
self.thought_list[0:0] = [memory_basic]
return
if memory_basic.type == "event":
self.event_list[0:0] = [memory_basic]
def add_chat(self, created, expiration, s, p, o,
content, keywords, poignancy,
embedding_pair, filling,
cause_by):
def add_chat(self, created, expiration, s, p, o,
content, keywords, poignancy,
embedding_pair, filling,
cause_by):
"""
调用add方法初始化chat在创建的时候就需要调用embeeding函数
"""
@ -211,31 +210,30 @@ class AgentMemory(Memory):
type_count = len(self.thought_list) + 1
memory_type = "chat"
memory_id = f"memory_{str(memory_count)}"
depth = 1
depth = 1
memory_node = BasicMemory(memory_id, memory_count, type_count, memory_type, depth,
created, expiration,
s, p ,o,
s, p, o,
content, embedding_pair[0],
poignancy, keywords, filling,
poignancy, keywords, filling,
cause_by)
keywords = [i.lower() for i in keywords]
for kw in keywords:
if kw in self.chat_keywords:
for kw in keywords:
if kw in self.chat_keywords:
self.chat_keywords[kw][0:0] = [memory_node]
else:
else:
self.chat_keywords[kw] = [memory_node]
self.add(memory_node)
self.embeddings[embedding_pair[0]] = embedding_pair[1]
return memory_node
return memory_node
def add_thought(self, created, expiration, s, p, o,
content, keywords, poignancy,
embedding_pair, filling):
def add_thought(self, created, expiration, s, p, o,
content, keywords, poignancy,
embedding_pair, filling):
"""
调用add方法初始化thought
"""
@ -243,44 +241,43 @@ class AgentMemory(Memory):
type_count = len(self.thought_list) + 1
memory_type = "event"
memory_id = f"memory_{str(memory_count)}"
depth = 1
depth = 1
try:
if filling:
depth_list = [memory_node.depth for memory_node in self.storage if memory_node.memory_id in filling ]
if filling:
depth_list = [memory_node.depth for memory_node in self.storage if memory_node.memory_id in filling]
depth += max(depth_list)
except:
pass
except Exception as exp:
pass
memory_node = BasicMemory(memory_id, memory_count, type_count, memory_type, depth,
created, expiration,
s, p ,o,
s, p, o,
content, embedding_pair[0],
poignancy, keywords, filling)
keywords = [i.lower() for i in keywords]
for kw in keywords:
if kw in self.thought_keywords:
for kw in keywords:
if kw in self.thought_keywords:
self.thought_keywords[kw][0:0] = [memory_node]
else:
else:
self.thought_keywords[kw] = [memory_node]
self.add(memory_node)
if f"{p} {o}" != "is idle":
for kw in keywords:
if f"{p} {o}" != "is idle":
for kw in keywords:
if kw in self.kw_strength_thought:
self.kw_strength_thought[kw] += 1
else:
else:
self.kw_strength_thought[kw] = 1
self.embeddings[embedding_pair[0]] = embedding_pair[1]
return memory_node
def add_event(self, created, expiration, s, p, o,
content, keywords, poignancy,
embedding_pair, filling):
def add_event(self, created, expiration, s, p, o,
content, keywords, poignancy,
embedding_pair, filling):
"""
调用add方法初始化event
"""
@ -289,40 +286,39 @@ class AgentMemory(Memory):
memory_type = "event"
memory_id = f"memory_{str(memory_count)}"
depth = 0
if "(" in content:
content = (" ".join(content.split()[:3])
+ " "
+ content.split("(")[-1][:-1])
content = (" ".join(content.split()[:3])
+ " "
+ content.split("(")[-1][:-1])
memory_node = BasicMemory(memory_id, memory_count, type_count, memory_type, depth,
created, expiration,
s, p ,o,
s, p, o,
content, embedding_pair[0],
poignancy, keywords, filling)
keywords = [i.lower() for i in keywords]
for kw in keywords:
if kw in self.event_keywords:
for kw in keywords:
if kw in self.event_keywords:
self.event_keywords[kw][0:0] = [memory_node]
else:
else:
self.event_keywords[kw] = [memory_node]
self.add(memory_node)
if f"{p} {o}" != "is idle":
for kw in keywords:
if f"{p} {o}" != "is idle":
for kw in keywords:
if kw in self.kw_strength_event:
self.kw_strength_event[kw] += 1
else:
else:
self.kw_strength_event[kw] = 1
self.embeddings[embedding_pair[0]] = embedding_pair[1]
return memory_node
def get_summarized_latest_events(self, retention):
def get_summarized_latest_events(self, retention):
ret_set = set()
for e_node in self.event_list[:retention]:
for e_node in self.event_list[:retention]:
ret_set.add(e_node.summary())
return ret_set
return ret_set

View file

@ -9,7 +9,8 @@ from examples.st_game.memory.agent_memory import AgentMemory, BasicMemory
from utils.utils import embedding_tools
def agent_retrieve(agent_memory: AgentMemory, curr_time: datetime.datetime, memory_forget: float, query: str, n: int = 30, topk: int = 4) -> list[BasicMemory]:
def agent_retrieve(agent_memory: AgentMemory, curr_time: datetime.datetime, memory_forget: float, query: str,
n: int = 30, topk: int = 4) -> list[BasicMemory]:
"""
Retrieve需要集合Role使用,原因在于Role才具有AgentMemory,scratch
逻辑:Role调用该函数,self._rc.AgentMemory,self._rc.scratch.curr_time,self._rc.scratch.memory_forget
@ -87,7 +88,7 @@ def extract_recency(curr_time, memory_forget, score_list):
"""
for i in range(len(score_list)):
day_count = (curr_time - score_list[i]['memory'].created).days
score_list[i]['recency'] = memory_forget**day_count
score_list[i]['recency'] = memory_forget ** day_count
return score_list

View file

@ -4,534 +4,509 @@
import datetime
import json
import sys
sys.path.append('../../')
from ..utils.check import check_if_file_exists
class Scratch:
def __init__(self, f_saved):
# 类别1:人物超参
self.vision_r = 4
self.att_bandwidth = 3
self.retention = 5
# 类别2:世界信息
self.curr_time = None
self.curr_tile = None
self.daily_plan_req = None
# 类别3:人物角色的核心身份
self.name = None
self.first_name = None
self.last_name = None
self.age = None
# L0 permanent core traits.
self.innate = None
# L1 stable traits.
self.learned = None
# L2 external implementation.
self.currently = None
self.lifestyle = None
self.living_area = None
class Scratch:
def __init__(self, f_saved):
# 类别1:人物超参
self.vision_r = 4
self.att_bandwidth = 3
self.retention = 5
# 类别4:旧反思变量
self.concept_forget = 100
self.daily_reflection_time = 60 * 3
self.daily_reflection_size = 5
self.overlap_reflect_th = 2
self.kw_strg_event_reflect_th = 4
self.kw_strg_thought_reflect_th = 4
# 类别5:新反思变量
self.recency_w = 1
self.relevance_w = 1
self.importance_w = 1
self.recency_decay = 0.99
self.importance_trigger_max = 150
self.importance_trigger_curr = self.importance_trigger_max
self.importance_ele_n = 0
self.thought_count = 5
# 类别6:个人计划
self.daily_req = []
self.f_daily_schedule = []
self.f_daily_schedule_hourly_org = []
# 类别7:当前动作
self.act_address = None
self.act_start_time = None
self.act_duration = None
self.act_description = None
self.act_pronunciatio = None
self.act_event = (self.name, None, None)
self.act_obj_description = None
self.act_obj_pronunciatio = None
self.act_obj_event = (self.name, None, None)
self.chatting_with = None
self.chat = None
self.chatting_with_buffer = dict()
self.chatting_end_time = None
self.act_path_set = False
self.planned_path = []
if check_if_file_exists(f_saved):
# If we have a bootstrap file, load that here.
scratch_load = json.load(open(f_saved))
self.vision_r = scratch_load["vision_r"]
self.att_bandwidth = scratch_load["att_bandwidth"]
self.retention = scratch_load["retention"]
if scratch_load["curr_time"]:
self.curr_time = datetime.datetime.strptime(scratch_load["curr_time"],
"%B %d, %Y, %H:%M:%S")
else:
# 类别2:世界信息
self.curr_time = None
self.curr_tile = scratch_load["curr_tile"]
self.daily_plan_req = scratch_load["daily_plan_req"]
self.curr_tile = None
self.daily_plan_req = None
self.name = scratch_load["name"]
self.first_name = scratch_load["first_name"]
self.last_name = scratch_load["last_name"]
self.age = scratch_load["age"]
self.innate = scratch_load["innate"]
self.learned = scratch_load["learned"]
self.currently = scratch_load["currently"]
self.lifestyle = scratch_load["lifestyle"]
self.living_area = scratch_load["living_area"]
# 类别3:人物角色的核心身份
self.name = None
self.first_name = None
self.last_name = None
self.age = None
# L0 permanent core traits.
self.innate = None
# L1 stable traits.
self.learned = None
# L2 external implementation.
self.currently = None
self.lifestyle = None
self.living_area = None
self.concept_forget = scratch_load["concept_forget"]
self.daily_reflection_time = scratch_load["daily_reflection_time"]
self.daily_reflection_size = scratch_load["daily_reflection_size"]
self.overlap_reflect_th = scratch_load["overlap_reflect_th"]
self.kw_strg_event_reflect_th = scratch_load["kw_strg_event_reflect_th"]
self.kw_strg_thought_reflect_th = scratch_load["kw_strg_thought_reflect_th"]
# 类别4:旧反思变量
self.concept_forget = 100
self.daily_reflection_time = 60 * 3
self.daily_reflection_size = 5
self.overlap_reflect_th = 2
self.kw_strg_event_reflect_th = 4
self.kw_strg_thought_reflect_th = 4
self.recency_w = scratch_load["recency_w"]
self.relevance_w = scratch_load["relevance_w"]
self.importance_w = scratch_load["importance_w"]
self.recency_decay = scratch_load["recency_decay"]
self.importance_trigger_max = scratch_load["importance_trigger_max"]
self.importance_trigger_curr = scratch_load["importance_trigger_curr"]
self.importance_ele_n = scratch_load["importance_ele_n"]
self.thought_count = scratch_load["thought_count"]
# 类别5:新反思变量
self.recency_w = 1
self.relevance_w = 1
self.importance_w = 1
self.recency_decay = 0.99
self.importance_trigger_max = 150
self.importance_trigger_curr = self.importance_trigger_max
self.importance_ele_n = 0
self.thought_count = 5
self.daily_req = scratch_load["daily_req"]
self.f_daily_schedule = scratch_load["f_daily_schedule"]
self.f_daily_schedule_hourly_org = scratch_load["f_daily_schedule_hourly_org"]
# 类别6:个人计划
self.daily_req = []
self.f_daily_schedule = []
self.f_daily_schedule_hourly_org = []
self.act_address = scratch_load["act_address"]
if scratch_load["act_start_time"]:
self.act_start_time = datetime.datetime.strptime(
scratch_load["act_start_time"],
"%B %d, %Y, %H:%M:%S")
else:
self.curr_time = None
self.act_duration = scratch_load["act_duration"]
self.act_description = scratch_load["act_description"]
self.act_pronunciatio = scratch_load["act_pronunciatio"]
self.act_event = tuple(scratch_load["act_event"])
# 类别7:当前动作
self.act_address = None
self.act_start_time = None
self.act_duration = None
self.act_description = None
self.act_pronunciatio = None
self.act_event = (self.name, None, None)
self.act_obj_description = scratch_load["act_obj_description"]
self.act_obj_pronunciatio = scratch_load["act_obj_pronunciatio"]
self.act_obj_event = tuple(scratch_load["act_obj_event"])
self.act_obj_description = None
self.act_obj_pronunciatio = None
self.act_obj_event = (self.name, None, None)
self.chatting_with = scratch_load["chatting_with"]
self.chat = scratch_load["chat"]
self.chatting_with_buffer = scratch_load["chatting_with_buffer"]
if scratch_load["chatting_end_time"]:
self.chatting_end_time = datetime.datetime.strptime(
scratch_load["chatting_end_time"],
"%B %d, %Y, %H:%M:%S")
else:
self.chatting_with = None
self.chat = None
self.chatting_with_buffer = dict()
self.chatting_end_time = None
self.act_path_set = scratch_load["act_path_set"]
self.planned_path = scratch_load["planned_path"]
self.act_path_set = False
self.planned_path = []
if check_if_file_exists(f_saved):
# If we have a bootstrap file, load that here.
scratch_load = json.load(open(f_saved))
def save(self, out_json):
"""
Save persona's scratch.
self.vision_r = scratch_load["vision_r"]
self.att_bandwidth = scratch_load["att_bandwidth"]
self.retention = scratch_load["retention"]
INPUT:
out_json: The file where we wil be saving our persona's state.
OUTPUT:
None
"""
scratch = dict()
scratch["vision_r"] = self.vision_r
scratch["att_bandwidth"] = self.att_bandwidth
scratch["retention"] = self.retention
if scratch_load["curr_time"]:
self.curr_time = datetime.datetime.strptime(scratch_load["curr_time"],
"%B %d, %Y, %H:%M:%S")
else:
self.curr_time = None
self.curr_tile = scratch_load["curr_tile"]
self.daily_plan_req = scratch_load["daily_plan_req"]
scratch["curr_time"] = self.curr_time.strftime("%B %d, %Y, %H:%M:%S")
scratch["curr_tile"] = self.curr_tile
scratch["daily_plan_req"] = self.daily_plan_req
self.name = scratch_load["name"]
self.first_name = scratch_load["first_name"]
self.last_name = scratch_load["last_name"]
self.age = scratch_load["age"]
self.innate = scratch_load["innate"]
self.learned = scratch_load["learned"]
self.currently = scratch_load["currently"]
self.lifestyle = scratch_load["lifestyle"]
self.living_area = scratch_load["living_area"]
scratch["name"] = self.name
scratch["first_name"] = self.first_name
scratch["last_name"] = self.last_name
scratch["age"] = self.age
scratch["innate"] = self.innate
scratch["learned"] = self.learned
scratch["currently"] = self.currently
scratch["lifestyle"] = self.lifestyle
scratch["living_area"] = self.living_area
self.concept_forget = scratch_load["concept_forget"]
self.daily_reflection_time = scratch_load["daily_reflection_time"]
self.daily_reflection_size = scratch_load["daily_reflection_size"]
self.overlap_reflect_th = scratch_load["overlap_reflect_th"]
self.kw_strg_event_reflect_th = scratch_load["kw_strg_event_reflect_th"]
self.kw_strg_thought_reflect_th = scratch_load["kw_strg_thought_reflect_th"]
scratch["concept_forget"] = self.concept_forget
scratch["daily_reflection_time"] = self.daily_reflection_time
scratch["daily_reflection_size"] = self.daily_reflection_size
scratch["overlap_reflect_th"] = self.overlap_reflect_th
scratch["kw_strg_event_reflect_th"] = self.kw_strg_event_reflect_th
scratch["kw_strg_thought_reflect_th"] = self.kw_strg_thought_reflect_th
self.recency_w = scratch_load["recency_w"]
self.relevance_w = scratch_load["relevance_w"]
self.importance_w = scratch_load["importance_w"]
self.recency_decay = scratch_load["recency_decay"]
self.importance_trigger_max = scratch_load["importance_trigger_max"]
self.importance_trigger_curr = scratch_load["importance_trigger_curr"]
self.importance_ele_n = scratch_load["importance_ele_n"]
self.thought_count = scratch_load["thought_count"]
scratch["recency_w"] = self.recency_w
scratch["relevance_w"] = self.relevance_w
scratch["importance_w"] = self.importance_w
scratch["recency_decay"] = self.recency_decay
scratch["importance_trigger_max"] = self.importance_trigger_max
scratch["importance_trigger_curr"] = self.importance_trigger_curr
scratch["importance_ele_n"] = self.importance_ele_n
scratch["thought_count"] = self.thought_count
self.daily_req = scratch_load["daily_req"]
self.f_daily_schedule = scratch_load["f_daily_schedule"]
self.f_daily_schedule_hourly_org = scratch_load["f_daily_schedule_hourly_org"]
scratch["daily_req"] = self.daily_req
scratch["f_daily_schedule"] = self.f_daily_schedule
scratch["f_daily_schedule_hourly_org"] = self.f_daily_schedule_hourly_org
self.act_address = scratch_load["act_address"]
if scratch_load["act_start_time"]:
self.act_start_time = datetime.datetime.strptime(
scratch_load["act_start_time"],
"%B %d, %Y, %H:%M:%S")
else:
self.curr_time = None
self.act_duration = scratch_load["act_duration"]
self.act_description = scratch_load["act_description"]
self.act_pronunciatio = scratch_load["act_pronunciatio"]
self.act_event = tuple(scratch_load["act_event"])
scratch["act_address"] = self.act_address
scratch["act_start_time"] = (self.act_start_time
self.act_obj_description = scratch_load["act_obj_description"]
self.act_obj_pronunciatio = scratch_load["act_obj_pronunciatio"]
self.act_obj_event = tuple(scratch_load["act_obj_event"])
self.chatting_with = scratch_load["chatting_with"]
self.chat = scratch_load["chat"]
self.chatting_with_buffer = scratch_load["chatting_with_buffer"]
if scratch_load["chatting_end_time"]:
self.chatting_end_time = datetime.datetime.strptime(
scratch_load["chatting_end_time"],
"%B %d, %Y, %H:%M:%S")
else:
self.chatting_end_time = None
self.act_path_set = scratch_load["act_path_set"]
self.planned_path = scratch_load["planned_path"]
def save(self, out_json):
"""
Save persona's scratch.
INPUT:
out_json: The file where we wil be saving our persona's state.
OUTPUT:
None
"""
scratch = dict()
scratch["vision_r"] = self.vision_r
scratch["att_bandwidth"] = self.att_bandwidth
scratch["retention"] = self.retention
scratch["curr_time"] = self.curr_time.strftime("%B %d, %Y, %H:%M:%S")
scratch["curr_tile"] = self.curr_tile
scratch["daily_plan_req"] = self.daily_plan_req
scratch["name"] = self.name
scratch["first_name"] = self.first_name
scratch["last_name"] = self.last_name
scratch["age"] = self.age
scratch["innate"] = self.innate
scratch["learned"] = self.learned
scratch["currently"] = self.currently
scratch["lifestyle"] = self.lifestyle
scratch["living_area"] = self.living_area
scratch["concept_forget"] = self.concept_forget
scratch["daily_reflection_time"] = self.daily_reflection_time
scratch["daily_reflection_size"] = self.daily_reflection_size
scratch["overlap_reflect_th"] = self.overlap_reflect_th
scratch["kw_strg_event_reflect_th"] = self.kw_strg_event_reflect_th
scratch["kw_strg_thought_reflect_th"] = self.kw_strg_thought_reflect_th
scratch["recency_w"] = self.recency_w
scratch["relevance_w"] = self.relevance_w
scratch["importance_w"] = self.importance_w
scratch["recency_decay"] = self.recency_decay
scratch["importance_trigger_max"] = self.importance_trigger_max
scratch["importance_trigger_curr"] = self.importance_trigger_curr
scratch["importance_ele_n"] = self.importance_ele_n
scratch["thought_count"] = self.thought_count
scratch["daily_req"] = self.daily_req
scratch["f_daily_schedule"] = self.f_daily_schedule
scratch["f_daily_schedule_hourly_org"] = self.f_daily_schedule_hourly_org
scratch["act_address"] = self.act_address
scratch["act_start_time"] = (self.act_start_time
.strftime("%B %d, %Y, %H:%M:%S"))
scratch["act_duration"] = self.act_duration
scratch["act_description"] = self.act_description
scratch["act_pronunciatio"] = self.act_pronunciatio
scratch["act_event"] = self.act_event
scratch["act_duration"] = self.act_duration
scratch["act_description"] = self.act_description
scratch["act_pronunciatio"] = self.act_pronunciatio
scratch["act_event"] = self.act_event
scratch["act_obj_description"] = self.act_obj_description
scratch["act_obj_pronunciatio"] = self.act_obj_pronunciatio
scratch["act_obj_event"] = self.act_obj_event
scratch["act_obj_description"] = self.act_obj_description
scratch["act_obj_pronunciatio"] = self.act_obj_pronunciatio
scratch["act_obj_event"] = self.act_obj_event
scratch["chatting_with"] = self.chatting_with
scratch["chat"] = self.chat
scratch["chatting_with_buffer"] = self.chatting_with_buffer
if self.chatting_end_time:
scratch["chatting_end_time"] = (self.chatting_end_time
.strftime("%B %d, %Y, %H:%M:%S"))
else:
scratch["chatting_end_time"] = None
scratch["chatting_with"] = self.chatting_with
scratch["chat"] = self.chat
scratch["chatting_with_buffer"] = self.chatting_with_buffer
if self.chatting_end_time:
scratch["chatting_end_time"] = (self.chatting_end_time
.strftime("%B %d, %Y, %H:%M:%S"))
else:
scratch["chatting_end_time"] = None
scratch["act_path_set"] = self.act_path_set
scratch["planned_path"] = self.planned_path
scratch["act_path_set"] = self.act_path_set
scratch["planned_path"] = self.planned_path
with open(out_json, "w") as outfile:
json.dump(scratch, outfile, indent=2)
with open(out_json, "w") as outfile:
json.dump(scratch, outfile, indent=2)
def get_f_daily_schedule_index(self, advance=0):
"""
We get the current index of self.f_daily_schedule.
def get_f_daily_schedule_index(self, advance=0):
"""
We get the current index of self.f_daily_schedule.
Recall that self.f_daily_schedule stores the decomposed action sequences
up until now, and the hourly sequences of the future action for the rest
of today. Given that self.f_daily_schedule is a list of list where the
inner list is composed of [task, duration], we continue to add up the
duration until we reach "if elapsed > today_min_elapsed" condition. The
index where we stop is the index we will return.
Recall that self.f_daily_schedule stores the decomposed action sequences
up until now, and the hourly sequences of the future action for the rest
of today. Given that self.f_daily_schedule is a list of list where the
inner list is composed of [task, duration], we continue to add up the
duration until we reach "if elapsed > today_min_elapsed" condition. The
index where we stop is the index we will return.
INPUT
advance: Integer value of the number minutes we want to look into the
future. This allows us to get the index of a future timeframe.
OUTPUT
an integer value for the current index of f_daily_schedule.
"""
# We first calculate teh number of minutes elapsed today.
today_min_elapsed = 0
today_min_elapsed += self.curr_time.hour * 60
today_min_elapsed += self.curr_time.minute
today_min_elapsed += advance
INPUT
advance: Integer value of the number minutes we want to look into the
future. This allows us to get the index of a future timeframe.
OUTPUT
an integer value for the current index of f_daily_schedule.
"""
# We first calculate teh number of minutes elapsed today.
today_min_elapsed = 0
today_min_elapsed += self.curr_time.hour * 60
today_min_elapsed += self.curr_time.minute
today_min_elapsed += advance
x = 0
for task, duration in self.f_daily_schedule:
x += duration
x = 0
for task, duration in self.f_daily_schedule_hourly_org:
x += duration
x = 0
for task, duration in self.f_daily_schedule:
x += duration
x = 0
for task, duration in self.f_daily_schedule_hourly_org:
x += duration
# We then calculate the current index based on that.
curr_index = 0
elapsed = 0
for task, duration in self.f_daily_schedule:
elapsed += duration
if elapsed > today_min_elapsed:
return curr_index
curr_index += 1
# We then calculate the current index based on that.
curr_index = 0
elapsed = 0
for task, duration in self.f_daily_schedule:
elapsed += duration
if elapsed > today_min_elapsed:
return curr_index
curr_index += 1
return curr_index
def get_f_daily_schedule_hourly_org_index(self, advance=0):
"""
We get the current index of self.f_daily_schedule_hourly_org.
It is otherwise the same as get_f_daily_schedule_index.
def get_f_daily_schedule_hourly_org_index(self, advance=0):
"""
We get the current index of self.f_daily_schedule_hourly_org.
It is otherwise the same as get_f_daily_schedule_index.
INPUT
advance: Integer value of the number minutes we want to look into the
future. This allows us to get the index of a future timeframe.
OUTPUT
an integer value for the current index of f_daily_schedule.
"""
# We first calculate teh number of minutes elapsed today.
today_min_elapsed = 0
today_min_elapsed += self.curr_time.hour * 60
today_min_elapsed += self.curr_time.minute
today_min_elapsed += advance
# We then calculate the current index based on that.
curr_index = 0
elapsed = 0
for task, duration in self.f_daily_schedule_hourly_org:
elapsed += duration
if elapsed > today_min_elapsed:
INPUT
advance: Integer value of the number minutes we want to look into the
future. This allows us to get the index of a future timeframe.
OUTPUT
an integer value for the current index of f_daily_schedule.
"""
# We first calculate teh number of minutes elapsed today.
today_min_elapsed = 0
today_min_elapsed += self.curr_time.hour * 60
today_min_elapsed += self.curr_time.minute
today_min_elapsed += advance
# We then calculate the current index based on that.
curr_index = 0
elapsed = 0
for task, duration in self.f_daily_schedule_hourly_org:
elapsed += duration
if elapsed > today_min_elapsed:
return curr_index
curr_index += 1
return curr_index
curr_index += 1
return curr_index
def get_str_iss(self):
"""
ISS stands for "identity stable set." This describes the commonset summary
of this persona -- basically, the bare minimum description of the persona
that gets used in almost all prompts that need to call on the persona.
def get_str_iss(self):
"""
ISS stands for "identity stable set." This describes the commonset summary
of this persona -- basically, the bare minimum description of the persona
that gets used in almost all prompts that need to call on the persona.
INPUT
None
OUTPUT
the identity stable set summary of the persona in a string form.
EXAMPLE STR OUTPUT
"Name: Dolores Heitmiller
Age: 28
Innate traits: hard-edged, independent, loyal
Learned traits: Dolores is a painter who wants live quietly and paint
while enjoying her everyday life.
Currently: Dolores is preparing for her first solo show. She mostly
works from home.
Lifestyle: Dolores goes to bed around 11pm, sleeps for 7 hours, eats
dinner around 6pm.
Daily plan requirement: Dolores is planning to stay at home all day and
never go out."
"""
commonset = ""
commonset += f"Name: {self.name}\n"
commonset += f"Age: {self.age}\n"
commonset += f"Innate traits: {self.innate}\n"
commonset += f"Learned traits: {self.learned}\n"
commonset += f"Currently: {self.currently}\n"
commonset += f"Lifestyle: {self.lifestyle}\n"
commonset += f"Daily plan requirement: {self.daily_plan_req}\n"
commonset += f"Current Date: {self.curr_time.strftime('%A %B %d')}\n"
return commonset
INPUT
None
OUTPUT
the identity stable set summary of the persona in a string form.
EXAMPLE STR OUTPUT
"Name: Dolores Heitmiller
Age: 28
Innate traits: hard-edged, independent, loyal
Learned traits: Dolores is a painter who wants live quietly and paint
while enjoying her everyday life.
Currently: Dolores is preparing for her first solo show. She mostly
works from home.
Lifestyle: Dolores goes to bed around 11pm, sleeps for 7 hours, eats
dinner around 6pm.
Daily plan requirement: Dolores is planning to stay at home all day and
never go out."
"""
commonset = ""
commonset += f"Name: {self.name}\n"
commonset += f"Age: {self.age}\n"
commonset += f"Innate traits: {self.innate}\n"
commonset += f"Learned traits: {self.learned}\n"
commonset += f"Currently: {self.currently}\n"
commonset += f"Lifestyle: {self.lifestyle}\n"
commonset += f"Daily plan requirement: {self.daily_plan_req}\n"
commonset += f"Current Date: {self.curr_time.strftime('%A %B %d')}\n"
return commonset
def get_str_name(self):
return self.name
def get_str_firstname(self):
return self.first_name
def get_str_name(self):
return self.name
def get_str_lastname(self):
return self.last_name
def get_str_age(self):
return str(self.age)
def get_str_firstname(self):
return self.first_name
def get_str_innate(self):
return self.innate
def get_str_learned(self):
return self.learned
def get_str_lastname(self):
return self.last_name
def get_str_currently(self):
return self.currently
def get_str_lifestyle(self):
return self.lifestyle
def get_str_age(self):
return str(self.age)
def get_str_daily_plan_req(self):
return self.daily_plan_req
def get_str_curr_date_str(self):
return self.curr_time.strftime("%A %B %d")
def get_str_innate(self):
return self.innate
def get_curr_event(self):
if not self.act_address:
return (self.name, None, None)
else:
return self.act_event
def get_curr_event_and_desc(self):
if not self.act_address:
return (self.name, None, None, None)
else:
return (self.act_event[0],
self.act_event[1],
self.act_event[2],
self.act_description)
def get_str_learned(self):
return self.learned
def get_curr_obj_event_and_desc(self):
if not self.act_address:
return ("", None, None, None)
else:
return (self.act_address,
self.act_obj_event[1],
self.act_obj_event[2],
self.act_obj_description)
def add_new_action(self,
action_address,
action_duration,
action_description,
action_pronunciatio,
action_event,
chatting_with,
chat,
chatting_with_buffer,
chatting_end_time,
act_obj_description,
act_obj_pronunciatio,
act_obj_event,
act_start_time=None):
self.act_address = action_address
self.act_duration = action_duration
self.act_description = action_description
self.act_pronunciatio = action_pronunciatio
self.act_event = action_event
def get_str_currently(self):
return self.currently
self.chatting_with = chatting_with
self.chat = chat
if chatting_with_buffer:
self.chatting_with_buffer.update(chatting_with_buffer)
self.chatting_end_time = chatting_end_time
self.act_obj_description = act_obj_description
self.act_obj_pronunciatio = act_obj_pronunciatio
self.act_obj_event = act_obj_event
def get_str_lifestyle(self):
return self.lifestyle
self.act_start_time = self.curr_time
self.act_path_set = False
def get_str_daily_plan_req(self):
return self.daily_plan_req
def act_time_str(self):
"""
Returns a string output of the current time.
INPUT
None
OUTPUT
A string output of the current time.
EXAMPLE STR OUTPUT
"14:05 P.M."
"""
return self.act_start_time.strftime("%H:%M %p")
def get_str_curr_date_str(self):
return self.curr_time.strftime("%A %B %d")
def act_check_finished(self):
"""
Checks whether the self.Action instance has finished.
INPUT
curr_datetime: Current time. If current time is later than the action's
start time + its duration, then the action has finished.
OUTPUT
Boolean [True]: Action has finished.
Boolean [False]: Action has not finished and is still ongoing.
"""
if not self.act_address:
return True
def get_curr_event(self):
if not self.act_address:
return (self.name, None, None)
else:
return self.act_event
if self.chatting_with:
end_time = self.chatting_end_time
else:
x = self.act_start_time
if x.second != 0:
x = x.replace(second=0)
x = (x + datetime.timedelta(minutes=1))
end_time = (x + datetime.timedelta(minutes=self.act_duration))
if end_time.strftime("%H:%M:%S") == self.curr_time.strftime("%H:%M:%S"):
return True
return False
def get_curr_event_and_desc(self):
if not self.act_address:
return (self.name, None, None, None)
else:
return (self.act_event[0],
self.act_event[1],
self.act_event[2],
self.act_description)
def act_summarize(self):
"""
Summarize the current action as a dictionary.
INPUT
None
OUTPUT
ret: A human readable summary of the action.
"""
exp = dict()
exp["persona"] = self.name
exp["address"] = self.act_address
exp["start_datetime"] = self.act_start_time
exp["duration"] = self.act_duration
exp["description"] = self.act_description
exp["pronunciatio"] = self.act_pronunciatio
return exp
def get_curr_obj_event_and_desc(self):
if not self.act_address:
return ("", None, None, None)
else:
return (self.act_address,
self.act_obj_event[1],
self.act_obj_event[2],
self.act_obj_description)
def act_summary_str(self):
"""
Returns a string summary of the current action. Meant to be
human-readable.
INPUT
None
OUTPUT
ret: A human readable summary of the action.
"""
start_datetime_str = self.act_start_time.strftime("%A %B %d -- %H:%M %p")
ret = f"[{start_datetime_str}]\n"
ret += f"Activity: {self.name} is {self.act_description}\n"
ret += f"Address: {self.act_address}\n"
ret += f"Duration in minutes (e.g., x min): {str(self.act_duration)} min\n"
return ret
def add_new_action(self,
action_address,
action_duration,
action_description,
action_pronunciatio,
action_event,
chatting_with,
chat,
chatting_with_buffer,
chatting_end_time,
act_obj_description,
act_obj_pronunciatio,
act_obj_event,
act_start_time=None):
self.act_address = action_address
self.act_duration = action_duration
self.act_description = action_description
self.act_pronunciatio = action_pronunciatio
self.act_event = action_event
def get_str_daily_schedule_summary(self):
ret = ""
curr_min_sum = 0
for row in self.f_daily_schedule:
curr_min_sum += row[1]
hour = int(curr_min_sum / 60)
minute = curr_min_sum % 60
ret += f"{hour:02}:{minute:02} || {row[0]}\n"
return ret
self.chatting_with = chatting_with
self.chat = chat
if chatting_with_buffer:
self.chatting_with_buffer.update(chatting_with_buffer)
self.chatting_end_time = chatting_end_time
self.act_obj_description = act_obj_description
self.act_obj_pronunciatio = act_obj_pronunciatio
self.act_obj_event = act_obj_event
self.act_start_time = self.curr_time
self.act_path_set = False
def act_time_str(self):
"""
Returns a string output of the current time.
INPUT
None
OUTPUT
A string output of the current time.
EXAMPLE STR OUTPUT
"14:05 P.M."
"""
return self.act_start_time.strftime("%H:%M %p")
def act_check_finished(self):
"""
Checks whether the self.Action instance has finished.
INPUT
curr_datetime: Current time. If current time is later than the action's
start time + its duration, then the action has finished.
OUTPUT
Boolean [True]: Action has finished.
Boolean [False]: Action has not finished and is still ongoing.
"""
if not self.act_address:
return True
if self.chatting_with:
end_time = self.chatting_end_time
else:
x = self.act_start_time
if x.second != 0:
x = x.replace(second=0)
x = (x + datetime.timedelta(minutes=1))
end_time = (x + datetime.timedelta(minutes=self.act_duration))
if end_time.strftime("%H:%M:%S") == self.curr_time.strftime("%H:%M:%S"):
return True
return False
def act_summarize(self):
"""
Summarize the current action as a dictionary.
INPUT
None
OUTPUT
ret: A human readable summary of the action.
"""
exp = dict()
exp["persona"] = self.name
exp["address"] = self.act_address
exp["start_datetime"] = self.act_start_time
exp["duration"] = self.act_duration
exp["description"] = self.act_description
exp["pronunciatio"] = self.act_pronunciatio
return exp
def act_summary_str(self):
"""
Returns a string summary of the current action. Meant to be
human-readable.
INPUT
None
OUTPUT
ret: A human readable summary of the action.
"""
start_datetime_str = self.act_start_time.strftime("%A %B %d -- %H:%M %p")
ret = f"[{start_datetime_str}]\n"
ret += f"Activity: {self.name} is {self.act_description}\n"
ret += f"Address: {self.act_address}\n"
ret += f"Duration in minutes (e.g., x min): {str(self.act_duration)} min\n"
return ret
def get_str_daily_schedule_summary(self):
ret = ""
curr_min_sum = 0
for row in self.f_daily_schedule:
curr_min_sum += row[1]
hour = int(curr_min_sum/60)
minute = curr_min_sum%60
ret += f"{hour:02}:{minute:02} || {row[0]}\n"
return ret
def get_str_daily_schedule_hourly_org_summary(self):
ret = ""
curr_min_sum = 0
for row in self.f_daily_schedule_hourly_org:
curr_min_sum += row[1]
hour = int(curr_min_sum/60)
minute = curr_min_sum%60
ret += f"{hour:02}:{minute:02} || {row[0]}\n"
return ret
def get_str_daily_schedule_hourly_org_summary(self):
ret = ""
curr_min_sum = 0
for row in self.f_daily_schedule_hourly_org:
curr_min_sum += row[1]
hour = int(curr_min_sum / 60)
minute = curr_min_sum % 60
ret += f"{hour:02}:{minute:02} || {row[0]}\n"
return ret

View file

@ -3,130 +3,113 @@ Author: Joon Sung Park (joonspk@stanford.edu)
File: spatial_memory.py
Description: Defines the MemoryTree class that serves as the agents' spatial
memory that aids in grounding their behavior in the game world.
memory that aids in grounding their behavior in the game world.
"""
import json
import os
class MemoryTree:
def __init__(self, f_saved: str) -> None:
self.tree = {}
if os.path.isfile(f_saved) and os.path.exists(f_saved):
with open(f_saved) as f:
self.tree = json.load(f)
class MemoryTree:
def __init__(self, f_saved: str) -> None:
self.tree = {}
if os.path.isfile(f_saved) and os.path.exists(f_saved):
with open(f_saved) as f:
self.tree = json.load(f)
def print_tree(self) -> None:
def _print_tree(tree, depth):
dash = " >" * depth
if type(tree) == type(list()):
if tree:
print (dash, tree)
return
def print_tree(self) -> None:
def _print_tree(tree, depth):
dash = " >" * depth
if isinstance(tree, list):
if tree:
print(dash, tree)
return
for key, val in tree.items():
if key:
print (dash, key)
_print_tree(val, depth+1)
_print_tree(self.tree, 0)
for key, val in tree.items():
if key:
print(dash, key)
_print_tree(val, depth + 1)
def save(self, out_json: str) -> None:
with open(out_json, "w") as outfile:
json.dump(self.tree, outfile)
_print_tree(self.tree, 0)
def save(self, out_json: str) -> None:
with open(out_json, "w") as outfile:
json.dump(self.tree, outfile)
def get_str_accessible_sectors(self, curr_world: str) -> str:
"""
Returns a summary string of all the arenas that the persona can access
within the current sector.
def get_str_accessible_sectors(self, curr_world: str) -> str:
"""
Returns a summary string of all the arenas that the persona can access
within the current sector.
Note that there are places a given persona cannot enter. This information
is provided in the persona sheet. We account for this in this function.
INPUT
None
OUTPUT
A summary string of all the arenas that the persona can access.
EXAMPLE STR OUTPUT
"bedroom, kitchen, dining room, office, bathroom"
"""
x = ", ".join(list(self.tree[curr_world].keys()))
return x
def get_str_accessible_sector_arenas(self, sector: str) -> str:
"""
Returns a summary string of all the arenas that the persona can access
within the current sector.
Note that there are places a given persona cannot enter. This information
is provided in the persona sheet. We account for this in this function.
INPUT
None
OUTPUT
A summary string of all the arenas that the persona can access.
EXAMPLE STR OUTPUT
"bedroom, kitchen, dining room, office, bathroom"
"""
curr_world, curr_sector = sector.split(":")
if not curr_sector:
return ""
x = ", ".join(list(self.tree[curr_world][curr_sector].keys()))
return x
def get_str_accessible_arena_game_objects(self, arena: str) -> str:
"""
Get a str list of all accessible game objects that are in the arena. If
temp_address is specified, we return the objects that are available in
that arena, and if not, we return the objects that are in the arena our
persona is currently in.
INPUT
temp_address: optional arena address
OUTPUT
str list of all accessible game objects in the gmae arena.
EXAMPLE STR OUTPUT
"phone, charger, bed, nightstand"
"""
curr_world, curr_sector, curr_arena = arena.split(":")
if not curr_arena:
return ""
try:
x = ", ".join(list(self.tree[curr_world][curr_sector][curr_arena]))
except:
x = ", ".join(list(self.tree[curr_world][curr_sector][curr_arena.lower()]))
return x
def add_tile_info(self, tile_info: dict) -> None:
if tile_info["world"]:
if (tile_info["world"] not in self.tree):
self.tree[tile_info["world"]] = {}
if tile_info["sector"]:
if (tile_info["sector"] not in self.tree[tile_info["world"]]):
self.tree[tile_info["world"]][tile_info["sector"]] = {}
if tile_info["arena"]:
if (tile_info["arena"] not in self.tree[tile_info["world"]]
[tile_info["sector"]]):
self.tree[tile_info["world"]][tile_info["sector"]][tile_info["arena"]] = []
if tile_info["game_object"]:
if (tile_info["game_object"] not in self.tree[tile_info["world"]]
[tile_info["sector"]]
[tile_info["arena"]]):
self.tree[tile_info["world"]][tile_info["sector"]][tile_info["arena"]] += [
tile_info["game_object"]]
Note that there are places a given persona cannot enter. This information
is provided in the persona sheet. We account for this in this function.
INPUT
None
OUTPUT
A summary string of all the arenas that the persona can access.
EXAMPLE STR OUTPUT
"bedroom, kitchen, dining room, office, bathroom"
"""
x = ", ".join(list(self.tree[curr_world].keys()))
return x
def get_str_accessible_sector_arenas(self, sector: str) -> str:
"""
Returns a summary string of all the arenas that the persona can access
within the current sector.
Note that there are places a given persona cannot enter. This information
is provided in the persona sheet. We account for this in this function.
INPUT
None
OUTPUT
A summary string of all the arenas that the persona can access.
EXAMPLE STR OUTPUT
"bedroom, kitchen, dining room, office, bathroom"
"""
curr_world, curr_sector = sector.split(":")
if not curr_sector:
return ""
x = ", ".join(list(self.tree[curr_world][curr_sector].keys()))
return x
def get_str_accessible_arena_game_objects(self, arena: str) -> str:
"""
Get a str list of all accessible game objects that are in the arena. If
temp_address is specified, we return the objects that are available in
that arena, and if not, we return the objects that are in the arena our
persona is currently in.
INPUT
temp_address: optional arena address
OUTPUT
str list of all accessible game objects in the gmae arena.
EXAMPLE STR OUTPUT
"phone, charger, bed, nightstand"
"""
curr_world, curr_sector, curr_arena = arena.split(":")
if not curr_arena:
return ""
try:
x = ", ".join(list(self.tree[curr_world][curr_sector][curr_arena]))
except Exception as exp:
x = ", ".join(list(self.tree[curr_world][curr_sector][curr_arena.lower()]))
return x
def add_tile_info(self, tile_info: dict) -> None:
if tile_info["world"]:
if tile_info["world"] not in self.tree:
self.tree[tile_info["world"]] = {}
if tile_info["sector"]:
if tile_info["sector"] not in self.tree[tile_info["world"]]:
self.tree[tile_info["world"]][tile_info["sector"]] = {}
if tile_info["arena"]:
if tile_info["arena"] not in self.tree[tile_info["world"]][tile_info["sector"]]:
self.tree[tile_info["world"]][tile_info["sector"]][tile_info["arena"]] = []
if tile_info["game_object"]:
if tile_info["game_object"] not in self.tree[tile_info["world"]][tile_info["sector"]][tile_info["arena"]]:
self.tree[tile_info["world"]][tile_info["sector"]][tile_info["arena"]] += [
tile_info["game_object"]]

View file

@ -0,0 +1 @@
networkx

View file

@ -44,8 +44,9 @@ class STRole(Role):
def __init__(self,
name: str = "Klaus Mueller",
profile: str = "STMember",
sim_path: str = "new_sim",
has_inner_voice: bool = False):
self.sim_path = sim_path
self._rc = STRoleContext()
super(STRole, self).__init__(name=name,
profile=profile)
@ -73,147 +74,148 @@ class STRole(Role):
async def observe(self) -> list[BasicMemory]:
# TODO observe info from maze_env
"""
Perceive events around the role and saves it to the memory, both events
and spaces.
Perceive events around the role and saves it to the memory, both events
and spaces.
We first perceive the events nearby the role, as determined by its
<vision_r>. If there are a lot of events happening within that radius, we
We first perceive the events nearby the role, as determined by its
<vision_r>. If there are a lot of events happening within that radius, we
take the <att_bandwidth> of the closest events. Finally, we check whether
any of them are new, as determined by <retention>. If they are new, then we
save those and return the <BasicMemory> instances for those events.
save those and return the <BasicMemory> instances for those events.
OUTPUT:
ret_events: a list of <BasicMemory> that are perceived and new.
OUTPUT:
ret_events: a list of <BasicMemory> that are perceived and new.
"""
maze = self._rc.env.maze
# PERCEIVE SPACE
# We get the nearby tiles given our current tile and the persona's vision
# radius.
nearby_tiles = maze.get_nearby_tiles(self._rc.scratch.curr_tile,
self._rc.scratch.vision_r)
# radius.
nearby_tiles = maze.get_nearby_tiles(self._rc.scratch.curr_tile,
self._rc.scratch.vision_r)
# We then store the perceived space. Note that the s_mem of the persona is
# in the form of a tree constructed using dictionaries.
for tile in nearby_tiles:
# in the form of a tree constructed using dictionaries.
for tile in nearby_tiles:
tile_info = maze.access_tile(tile)
self._rc.spatial_memory.add_tile_info(tile_info)
# PERCEIVE EVENTS.
# PERCEIVE EVENTS.
# We will perceive events that take place in the same arena as the
# persona's current arena.
# persona's current arena.
curr_arena_path = maze.get_tile_path(self._rc.scratch.curr_tile, "arena")
# We do not perceive the same event twice (this can happen if an object is
# extended across multiple tiles).
percept_events_set = set()
# We will order our percept based on the distance, with the closest ones
# getting priorities.
# getting priorities.
percept_events_list = []
# First, we put all events that are occuring in the nearby tiles into the
# percept_events_list
for tile in nearby_tiles:
for tile in nearby_tiles:
tile_details = maze.access_tile(tile)
if tile_details["events"]:
if maze.get_tile_path(tile, "arena") == curr_arena_path:
# This calculates the distance between the persona's current tile,
# and the target tile.
dist = math.dist([tile[0], tile[1]],
[self._rc.scratch.curr_tile[0],
self._rc.scratch.curr_tile[1]])
# Add any relevant events to our temp set/list with the distant info.
for event in tile_details["events"]:
if event not in percept_events_set:
if tile_details["events"]:
if maze.get_tile_path(tile, "arena") == curr_arena_path:
# This calculates the distance between the persona's current tile,
# and the target tile.
dist = math.dist([tile[0], tile[1]],
[self._rc.scratch.curr_tile[0],
self._rc.scratch.curr_tile[1]])
# Add any relevant events to our temp set/list with the distant info.
for event in tile_details["events"]:
if event not in percept_events_set:
percept_events_list += [[dist, event]]
percept_events_set.add(event)
# We sort, and perceive only self._rc.scratch.att_bandwidth of the closest
# events. If the bandwidth is larger, then it means the persona can perceive
# more elements within a small area.
# more elements within a small area.
percept_events_list = sorted(percept_events_list, key=itemgetter(0))
perceived_events = []
for dist, event in percept_events_list[:self._rc.scratch.att_bandwidth]:
for dist, event in percept_events_list[:self._rc.scratch.att_bandwidth]:
perceived_events += [event]
# Storing events.
# <ret_events> is a list of <BasicMemory> instances from the persona's
# associative memory.
# Storing events.
# <ret_events> is a list of <BasicMemory> instances from the persona's
# associative memory.
ret_events = []
for p_event in perceived_events:
for p_event in perceived_events:
s, p, o, desc = p_event
if not p:
# If the object is not present, then we default the event to "idle".
if not p:
# If the object is not present, then we default the event to "idle".
p = "is"
o = "idle"
desc = "idle"
desc = f"{s.split(':')[-1]} is {desc}"
p_event = (s, p, o)
# We retrieve the latest self._rc.scratch.retention events. If there is
# We retrieve the latest self._rc.scratch.retention events. If there is
# something new that is happening (that is, p_event not in latest_events),
# then we add that event to the a_mem and return it.
# then we add that event to the a_mem and return it.
latest_events = self._rc.memory.get_summarized_latest_events(
self._rc.scratch.retention)
self._rc.scratch.retention)
if p_event not in latest_events:
# We start by managing keywords.
# We start by managing keywords.
keywords = set()
sub = p_event[0]
obj = p_event[2]
if ":" in p_event[0]:
if ":" in p_event[0]:
sub = p_event[0].split(":")[-1]
if ":" in p_event[2]:
if ":" in p_event[2]:
obj = p_event[2].split(":")[-1]
keywords.update([sub, obj])
# Get event embedding
desc_embedding_in = desc
if "(" in desc:
if "(" in desc:
desc_embedding_in = (desc_embedding_in.split("(")[1]
.split(")")[0]
.strip())
if desc_embedding_in in self._rc.memory.embeddings:
.split(")")[0]
.strip())
if desc_embedding_in in self._rc.memory.embeddings:
event_embedding = self._rc.memory.embeddings[desc_embedding_in]
else:
else:
event_embedding = get_embedding(desc_embedding_in)
event_embedding_pair = (desc_embedding_in, event_embedding)
# Get event poignancy.
event_poignancy = generate_poig_score(self,
"action",
desc_embedding_in)
# Get event poignancy.
event_poignancy = generate_poig_score(self,
"action",
desc_embedding_in)
# If we observe the persona's self chat, we include that in the memory
# of the persona here.
# of the persona here.
chat_node_ids = []
if p_event[0] == f"{self.name}" and p_event[1] == "chat with":
if p_event[0] == f"{self.name}" and p_event[1] == "chat with":
curr_event = self._rc.scratch.act_event
if self._rc.scratch.act_description in self._rc.memory.embeddings:
if self._rc.scratch.act_description in self._rc.memory.embeddings:
chat_embedding = self._rc.memory.embeddings[
self._rc.scratch.act_description]
else:
self._rc.scratch.act_description]
else:
chat_embedding = get_embedding(self._rc.scratch
.act_description)
chat_embedding_pair = (self._rc.scratch.act_description,
chat_embedding)
chat_poignancy = generate_poig_score(self._rc.scratch, "chat",
self._rc.scratch.act_description)
.act_description)
chat_embedding_pair = (self._rc.scratch.act_description,
chat_embedding)
chat_poignancy = generate_poig_score(self._rc.scratch, "chat",
self._rc.scratch.act_description)
chat_node = self._rc.memory.add_chat(self._rc.scratch.curr_time, None,
curr_event[0], curr_event[1], curr_event[2],
self._rc.scratch.act_description, keywords,
chat_poignancy, chat_embedding_pair,
self._rc.scratch.chat)
curr_event[0], curr_event[1], curr_event[2],
self._rc.scratch.act_description, keywords,
chat_poignancy, chat_embedding_pair,
self._rc.scratch.chat)
chat_node_ids = [chat_node.node_id]
# Finally, we add the current event to the agent's memory.
# Finally, we add the current event to the agent's memory.
ret_events += [self._rc.memory.add_event(self._rc.scratch.curr_time, None,
s, p, o, desc, keywords, event_poignancy,
event_embedding_pair, chat_node_ids)]
s, p, o, desc, keywords, event_poignancy,
event_embedding_pair, chat_node_ids)]
self._rc.scratch.importance_trigger_curr -= event_poignancy
self._rc.scratch.importance_ele_n += 1
return ret_events
async def retrieve(self, query, n = 30 ,topk = 4):
async def retrieve(self, query, n=30, topk=4):
# TODO retrieve memories from agent_memory
retrieve_memories = agent_retrieve(self._rc.memory, self._rc.scratch.curr_time, self._rc.scratch.recency_decay, query, n, topk)
retrieve_memories = agent_retrieve(self._rc.memory, self._rc.scratch.curr_time, self._rc.scratch.recency_decay,
query, n, topk)
return retrieve_memories
async def plan(self):

View file

@ -20,11 +20,14 @@ async def startup(idea: str,
# get role names from `storage/{simulation_name}/reverie/meta.json` and then init roles
reverie_meta = get_reverie_meta(fork_sim_code)
roles = []
# TODO
sim_path = STORAGE_PATH.joinpath(sim_code)
for idx, role_name in enumerate(reverie_meta["persona_names"]):
role_stg_path = STORAGE_PATH.joinpath(fork_sim_code).joinpath(f"personas/{role_name}")
has_inner_voice = True if idx == 0 else False
role = STRole(name=role_name, has_inner_voice=has_inner_voice)
role = STRole(name=role_name,
sim_path=sim_path,
profile=f"STMember_{idx}",
has_inner_voice=has_inner_voice)
role.load_from(role_stg_path)
roles.append(role)

View file

@ -7,6 +7,7 @@ from pydantic import Field
from metagpt.software_company import SoftwareCompany
from metagpt.roles.role import Role
from metagpt.schema import Message
from metagpt.logs import logger
from maze_environment import MazeEnvironment
from actions.user_requirement import UserRequirement
@ -17,6 +18,7 @@ class StanfordTown(SoftwareCompany):
environment: MazeEnvironment = Field(default_factory=MazeEnvironment)
def wakeup_roles(self, roles: list[Role]):
logger.warning(f"The Town add {len(roles)} roles, and start to operate.")
self.environment.add_roles(roles)
def start_project(self, idea):