add input/output types

This commit is contained in:
SereneWalden 2023-09-30 22:34:43 +08:00
parent 42c12ab277
commit fb37630766
3 changed files with 22 additions and 22 deletions

View file

@ -18,7 +18,7 @@ from .utils.const import MAZE_ASSET_PATH
from .utils.utils import read_csv_to_list
class Maze:
def __init__(self, maze_asset_path: Path = MAZE_ASSET_PATH):
def __init__(self, maze_asset_path: Path = MAZE_ASSET_PATH) -> None:
# READING IN THE BASIC META INFORMATION ABOUT THE MAP
self.maze_asset_path = maze_asset_path
maze_matrix_path = maze_asset_path.joinpath("matrix")
@ -216,7 +216,7 @@ class Maze:
self.nx_graph = grid_graph
def turn_coordinate_to_tile(self, px_coordinate):
def turn_coordinate_to_tile(self, px_coordinate: tuple[int, int]) -> tuple[int, int]:
"""
Turns a pixel coordinate to a tile coordinate.
@ -234,7 +234,7 @@ class Maze:
return (x, y)
def access_tile(self, tile):
def access_tile(self, tile: tuple[int, int]) -> dict:
"""
Returns the tiles details dictionary that is stored in self.tiles of the
designated x, y location.
@ -257,7 +257,7 @@ class Maze:
return self.tiles[y][x]
def get_tile_path(self, tile, level):
def get_tile_path(self, tile: tuple[int, int], level: str) -> str:
"""
Get the tile string address given its coordinate. You designate the level
by giving it a string level description.
@ -294,7 +294,7 @@ class Maze:
return path
def get_nearby_tiles(self, tile, vision_r):
def get_nearby_tiles(self, tile: tuple[int, int], vision_r: int) -> list[tuple[int, int]]:
"""
Given the current tile and vision_r, return a list of tiles that are
within the radius. Note that this implementation looks at a square
@ -335,7 +335,7 @@ class Maze:
return nearby_tiles
def add_event_from_tile(self, curr_event, tile):
def add_event_from_tile(self, curr_event: tuple[str], tile: tuple[int, int]) -> None:
"""
Add an event triple to a tile.
@ -350,8 +350,8 @@ class Maze:
self.tiles[tile[1]][tile[0]]["events"].add(curr_event)
def remove_event_from_tile(self, curr_event, tile):
"""
def remove_event_from_tile(self, curr_event: tuple[str], tile: tuple[int, int]) -> None:
"""dswaq
Remove an event triple from a tile.
INPUT:
@ -368,7 +368,7 @@ class Maze:
self.tiles[tile[1]][tile[0]]["events"].remove(event)
def turn_event_from_tile_idle(self, curr_event, tile):
def turn_event_from_tile_idle(self, curr_event: tuple[str], tile: tuple[int, int]) -> None:
curr_tile_ev_cp = self.tiles[tile[1]][tile[0]]["events"].copy()
for event in curr_tile_ev_cp:
if event == curr_event:
@ -377,7 +377,7 @@ class Maze:
self.tiles[tile[1]][tile[0]]["events"].add(new_event)
def remove_subject_events_from_tile(self, subject, tile):
def remove_subject_events_from_tile(self, subject: str, tile: tuple[int, int]) -> None:
"""
Remove an event triple that has the input subject from a tile.
@ -393,7 +393,7 @@ class Maze:
self.tiles[tile[1]][tile[0]]["events"].remove(event)
def _find_closest_node(self, coords):
def _find_closest_node(self, coords: tuple[int, int]) -> tuple[int, int]:
target_coords = self.nx_graph.nodes
min_dist = None
closest_coordinate = None
@ -408,9 +408,9 @@ class Maze:
closest_coordinate = target
return closest_coordinate
def find_path(self, start, end):
def find_path(self, start: tuple[int, int], end: tuple[int, int]) -> list[tuple[int, int]]:
if start not in self.nx_graph.nodes:
start = self._find_closest_node(start)
if end not in self.nx_graph.nodes:
end = self._find_closest_node(end)
return self.nx_graph.shortest_path(start, end)
return nx.shortest_path(self.nx_graph, start, end)

View file

@ -9,14 +9,14 @@ import json
import os
class MemoryTree:
def __init__(self, f_saved):
def __init__(self, f_saved: str) -> None:
self.tree = {}
if os.path.isfile(f_saved) and os.path.exists(f_saved):
with open(f_saved) as f:
self.tree = json.load(f)
def print_tree(self):
def print_tree(self) -> None:
def _print_tree(tree, depth):
dash = " >" * depth
if type(tree) == type(list()):
@ -32,12 +32,12 @@ class MemoryTree:
_print_tree(self.tree, 0)
def save(self, out_json):
def save(self, out_json: str) -> None:
with open(out_json, "w") as outfile:
json.dump(self.tree, outfile)
def get_str_accessible_sectors(self, curr_world):
def get_str_accessible_sectors(self, curr_world: str) -> str:
"""
Returns a summary string of all the arenas that the persona can access
within the current sector.
@ -56,7 +56,7 @@ class MemoryTree:
return x
def get_str_accessible_sector_arenas(self, sector):
def get_str_accessible_sector_arenas(self, sector: str) -> str:
"""
Returns a summary string of all the arenas that the persona can access
within the current sector.
@ -78,7 +78,7 @@ class MemoryTree:
return x
def get_str_accessible_arena_game_objects(self, arena):
def get_str_accessible_arena_game_objects(self, arena: str) -> str:
"""
Get a str list of all accessible game objects that are in the arena. If
temp_address is specified, we return the objects that are available in
@ -104,7 +104,7 @@ class MemoryTree:
return x
def add_tile_info(self, tile_info: dict):
def add_tile_info(self, tile_info: dict) -> None:
if tile_info["world"]:
if (tile_info["world"] not in self.tree):
self.tree[tile_info["world"]] = {}

View file

@ -18,7 +18,7 @@ from operator import itemgetter
from metagpt.roles.role import Role, RoleContext
from metagpt.schema import Message
from ..memory.agent_memory import AgentMemory
from ..memory.agent_memory import AgentMemory, BasicMemory
from ..memory.spatial_memory import MemoryTree
from ..actions.dummy_action import DummyAction
from ..actions.user_requirement import UserRequirement
@ -68,7 +68,7 @@ class STRole(Role):
"""
pass
async def observe(self):
async def observe(self) -> list[BasicMemory]:
# TODO observe info from maze_env
"""
Perceive events around the role and saves it to the memory, both events