From fb376307661f08ce53317ddfe1226d299ee8e52d Mon Sep 17 00:00:00 2001 From: SereneWalden <22496084+SereneWalden@users.noreply.github.com> Date: Sat, 30 Sep 2023 22:34:43 +0800 Subject: [PATCH] add input/output types --- examples/st_game/maze.py | 26 +++++++++++------------ examples/st_game/memory/spatial_memory.py | 14 ++++++------ examples/st_game/roles/st_role.py | 4 ++-- 3 files changed, 22 insertions(+), 22 deletions(-) diff --git a/examples/st_game/maze.py b/examples/st_game/maze.py index 8add39f90..1e2ef8ccc 100644 --- a/examples/st_game/maze.py +++ b/examples/st_game/maze.py @@ -18,7 +18,7 @@ from .utils.const import MAZE_ASSET_PATH from .utils.utils import read_csv_to_list class Maze: - def __init__(self, maze_asset_path: Path = MAZE_ASSET_PATH): + def __init__(self, maze_asset_path: Path = MAZE_ASSET_PATH) -> None: # READING IN THE BASIC META INFORMATION ABOUT THE MAP self.maze_asset_path = maze_asset_path maze_matrix_path = maze_asset_path.joinpath("matrix") @@ -216,7 +216,7 @@ class Maze: self.nx_graph = grid_graph - def turn_coordinate_to_tile(self, px_coordinate): + def turn_coordinate_to_tile(self, px_coordinate: tuple[int, int]) -> tuple[int, int]: """ Turns a pixel coordinate to a tile coordinate. @@ -234,7 +234,7 @@ class Maze: return (x, y) - def access_tile(self, tile): + def access_tile(self, tile: tuple[int, int]) -> dict: """ Returns the tiles details dictionary that is stored in self.tiles of the designated x, y location. @@ -257,7 +257,7 @@ class Maze: return self.tiles[y][x] - def get_tile_path(self, tile, level): + def get_tile_path(self, tile: tuple[int, int], level: str) -> str: """ Get the tile string address given its coordinate. You designate the level by giving it a string level description. @@ -294,7 +294,7 @@ class Maze: return path - def get_nearby_tiles(self, tile, vision_r): + def get_nearby_tiles(self, tile: tuple[int, int], vision_r: int) -> list[tuple[int, int]]: """ Given the current tile and vision_r, return a list of tiles that are within the radius. Note that this implementation looks at a square @@ -335,7 +335,7 @@ class Maze: return nearby_tiles - def add_event_from_tile(self, curr_event, tile): + def add_event_from_tile(self, curr_event: tuple[str], tile: tuple[int, int]) -> None: """ Add an event triple to a tile. @@ -350,8 +350,8 @@ class Maze: self.tiles[tile[1]][tile[0]]["events"].add(curr_event) - def remove_event_from_tile(self, curr_event, tile): - """ + def remove_event_from_tile(self, curr_event: tuple[str], tile: tuple[int, int]) -> None: + """dswaq Remove an event triple from a tile. INPUT: @@ -368,7 +368,7 @@ class Maze: self.tiles[tile[1]][tile[0]]["events"].remove(event) - def turn_event_from_tile_idle(self, curr_event, tile): + def turn_event_from_tile_idle(self, curr_event: tuple[str], tile: tuple[int, int]) -> None: curr_tile_ev_cp = self.tiles[tile[1]][tile[0]]["events"].copy() for event in curr_tile_ev_cp: if event == curr_event: @@ -377,7 +377,7 @@ class Maze: self.tiles[tile[1]][tile[0]]["events"].add(new_event) - def remove_subject_events_from_tile(self, subject, tile): + def remove_subject_events_from_tile(self, subject: str, tile: tuple[int, int]) -> None: """ Remove an event triple that has the input subject from a tile. @@ -393,7 +393,7 @@ class Maze: self.tiles[tile[1]][tile[0]]["events"].remove(event) - def _find_closest_node(self, coords): + def _find_closest_node(self, coords: tuple[int, int]) -> tuple[int, int]: target_coords = self.nx_graph.nodes min_dist = None closest_coordinate = None @@ -408,9 +408,9 @@ class Maze: closest_coordinate = target return closest_coordinate - def find_path(self, start, end): + def find_path(self, start: tuple[int, int], end: tuple[int, int]) -> list[tuple[int, int]]: if start not in self.nx_graph.nodes: start = self._find_closest_node(start) if end not in self.nx_graph.nodes: end = self._find_closest_node(end) - return self.nx_graph.shortest_path(start, end) + return nx.shortest_path(self.nx_graph, start, end) diff --git a/examples/st_game/memory/spatial_memory.py b/examples/st_game/memory/spatial_memory.py index b3357b962..455d60e05 100644 --- a/examples/st_game/memory/spatial_memory.py +++ b/examples/st_game/memory/spatial_memory.py @@ -9,14 +9,14 @@ import json import os class MemoryTree: - def __init__(self, f_saved): + def __init__(self, f_saved: str) -> None: self.tree = {} if os.path.isfile(f_saved) and os.path.exists(f_saved): with open(f_saved) as f: self.tree = json.load(f) - def print_tree(self): + def print_tree(self) -> None: def _print_tree(tree, depth): dash = " >" * depth if type(tree) == type(list()): @@ -32,12 +32,12 @@ class MemoryTree: _print_tree(self.tree, 0) - def save(self, out_json): + def save(self, out_json: str) -> None: with open(out_json, "w") as outfile: json.dump(self.tree, outfile) - def get_str_accessible_sectors(self, curr_world): + def get_str_accessible_sectors(self, curr_world: str) -> str: """ Returns a summary string of all the arenas that the persona can access within the current sector. @@ -56,7 +56,7 @@ class MemoryTree: return x - def get_str_accessible_sector_arenas(self, sector): + def get_str_accessible_sector_arenas(self, sector: str) -> str: """ Returns a summary string of all the arenas that the persona can access within the current sector. @@ -78,7 +78,7 @@ class MemoryTree: return x - def get_str_accessible_arena_game_objects(self, arena): + def get_str_accessible_arena_game_objects(self, arena: str) -> str: """ Get a str list of all accessible game objects that are in the arena. If temp_address is specified, we return the objects that are available in @@ -104,7 +104,7 @@ class MemoryTree: return x - def add_tile_info(self, tile_info: dict): + def add_tile_info(self, tile_info: dict) -> None: if tile_info["world"]: if (tile_info["world"] not in self.tree): self.tree[tile_info["world"]] = {} diff --git a/examples/st_game/roles/st_role.py b/examples/st_game/roles/st_role.py index c16180a72..ff807b95b 100644 --- a/examples/st_game/roles/st_role.py +++ b/examples/st_game/roles/st_role.py @@ -18,7 +18,7 @@ from operator import itemgetter from metagpt.roles.role import Role, RoleContext from metagpt.schema import Message -from ..memory.agent_memory import AgentMemory +from ..memory.agent_memory import AgentMemory, BasicMemory from ..memory.spatial_memory import MemoryTree from ..actions.dummy_action import DummyAction from ..actions.user_requirement import UserRequirement @@ -68,7 +68,7 @@ class STRole(Role): """ pass - async def observe(self): + async def observe(self) -> list[BasicMemory]: # TODO observe info from maze_env """ Perceive events around the role and saves it to the memory, both events