From 2b7d09ede241efcc07c476bf8016a7d3d85e4734 Mon Sep 17 00:00:00 2001 From: better629 Date: Tue, 26 Mar 2024 20:22:45 +0800 Subject: [PATCH] add base environment action_space/observation space and update stanford_town_env --- .../st_game/actions/gen_action_details.py | 6 +- examples/st_game/actions/gen_iter_chat_utt.py | 6 +- examples/st_game/roles/st_role.py | 82 +++++++------- examples/st_game/storage/.gitignore | 1 + .../tests/actions/test_gen_action_details.py | 2 +- metagpt/environment/base_env.py | 60 +++++++--- metagpt/environment/base_env_space.py | 33 ++++++ .../stanford_town_env/env_space.py | 105 ++++++++++++++++++ .../stanford_town_ext_env.py | 84 +++++++++++++- requirements.txt | 3 +- .../test_stanford_town_ext_env.py | 26 ++++- tests/metagpt/environment/test_base_env.py | 8 +- 12 files changed, 341 insertions(+), 75 deletions(-) create mode 100644 metagpt/environment/base_env_space.py create mode 100644 metagpt/environment/stanford_town_env/env_space.py diff --git a/examples/st_game/actions/gen_action_details.py b/examples/st_game/actions/gen_action_details.py index 92a53087a..6af2cb338 100644 --- a/examples/st_game/actions/gen_action_details.py +++ b/examples/st_game/actions/gen_action_details.py @@ -4,7 +4,7 @@ import random -from metagpt.environment.api.env_api import EnvAPIAbstract +from metagpt.environment.stanford_town_env.env_space import EnvObsParams, EnvObsType from metagpt.logs import logger from .st_action import STAction @@ -367,8 +367,8 @@ class GenActionDetails(STAction): return fs async def run(self, role: "STRole", act_desp: str, act_dura): - access_tile = await role.rc.env.observe( - EnvAPIAbstract(api_name="access_tile", kwargs={"tile": role.scratch.curr_tile}) + access_tile = role.rc.env.observe( + obs_params=EnvObsParams(obs_type=EnvObsType.GET_TITLE, coord=role.scratch.curr_tile) ) act_world = access_tile["world"] act_sector = await GenActionSector().run(role, access_tile, act_desp) diff --git a/examples/st_game/actions/gen_iter_chat_utt.py b/examples/st_game/actions/gen_iter_chat_utt.py index 2b0d46f4e..eb5f569c7 100644 --- a/examples/st_game/actions/gen_iter_chat_utt.py +++ b/examples/st_game/actions/gen_iter_chat_utt.py @@ -4,7 +4,7 @@ from examples.st_game.actions.st_action import STAction from examples.st_game.utils.utils import extract_first_json_dict -from metagpt.environment.api.env_api import EnvAPIAbstract +from metagpt.environment.stanford_town_env.env_space import EnvObsParams, EnvObsType from metagpt.logs import logger @@ -113,8 +113,8 @@ class GenIterChatUTT(STAction): ] return prompt_input - access_tile = await init_role.rc.env.observe( - EnvAPIAbstract(api_name="access_tile", kwargs={"tile": init_role.scratch.curr_tile}) + access_tile = init_role.rc.env.observe( + obs_params=EnvObsParams(obs_type=EnvObsType.GET_TITLE, coord=init_role.scratch.curr_tile) ) prompt_input = create_prompt_input(access_tile, init_role, target_role, retrieved, curr_context, curr_chat) prompt = self.generate_prompt_with_tmpl_filename(prompt_input, "iterative_convo_v1.txt") diff --git a/examples/st_game/roles/st_role.py b/examples/st_game/roles/st_role.py index d5dd994f9..48de34f15 100644 --- a/examples/st_game/roles/st_role.py +++ b/examples/st_game/roles/st_role.py @@ -36,7 +36,12 @@ from examples.st_game.utils.mg_ga_transform import ( ) from examples.st_game.utils.utils import get_embedding, path_finder from metagpt.actions.add_requirement import UserRequirement -from metagpt.environment.api.env_api import EnvAPIAbstract +from metagpt.environment.stanford_town_env.env_space import ( + EnvAction, + EnvActionType, + EnvObsParams, + EnvObsType, +) from metagpt.logs import logger from metagpt.roles.role import Role, RoleContext from metagpt.schema import Message @@ -115,10 +120,12 @@ class STRole(Role): pt_x = role_env["x"] pt_y = role_env["y"] self.rc.scratch.curr_tile = (pt_x, pt_y) - await self.rc.env.step( - EnvAPIAbstract( - api_name="add_tiles_event", - kwargs={"pt_y": pt_y, "pt_x": pt_x, "event": self.scratch.get_curr_event_and_desc()}, + + self.rc.env.step( + EnvAction( + action_type=EnvActionType.ADD_TILE_EVENT, + coord=(pt_x, pt_y), + event=self.scratch.get_curr_event_and_desc(), ) ) @@ -231,24 +238,24 @@ class STRole(Role): # PERCEIVE SPACE # We get the nearby tiles given our current tile and the persona's vision # radius. - nearby_tiles = await self.rc.env.observe( - EnvAPIAbstract( - api_name="get_nearby_tiles", - kwargs={"tile": self.rc.scratch.curr_tile, "vision_r": self.rc.scratch.vision_r}, + nearby_tiles = self.rc.env.observe( + EnvObsParams( + obs_type=EnvObsType.TILE_NBR, coord=self.rc.scratch.curr_tile, vision_radius=self.rc.scratch.vision_r ) ) # We then store the perceived space. Note that the s_mem of the persona is # in the form of a tree constructed using dictionaries. for tile in nearby_tiles: - tile_info = await self.rc.env.observe(EnvAPIAbstract(api_name="access_tile", kwargs={"tile": tile})) + tile_info = self.rc.env.observe(EnvObsParams(obs_type=EnvObsType.GET_TITLE, coord=tile)) self.rc.spatial_memory.add_tile_info(tile_info) # PERCEIVE EVENTS. # We will perceive events that take place in the same arena as the # persona's current arena. - curr_arena_path = await self.rc.env.observe( - EnvAPIAbstract(api_name="get_tile_path", kwargs={"tile": self.rc.scratch.curr_tile, "level": "arena"}) + + curr_arena_path = self.rc.env.observe( + EnvObsParams(obs_type=EnvObsType.TILE_PATH, coord=self.rc.scratch.curr_tile, level="arena") ) # We do not perceive the same event twice (this can happen if an object is @@ -260,10 +267,10 @@ class STRole(Role): # First, we put all events that are occuring in the nearby tiles into the # percept_events_list for tile in nearby_tiles: - tile_details = await self.rc.env.observe(EnvAPIAbstract(api_name="access_tile", kwargs={"tile": tile})) + tile_details = self.rc.env.observe(EnvObsParams(obs_type=EnvObsType.GET_TITLE, coord=tile)) if tile_details["events"]: - tmp_arena_path = await self.rc.env.observe( - EnvAPIAbstract(api_name="get_tile_path", kwargs={"tile": tile, "level": "arena"}) + tmp_arena_path = self.rc.env.observe( + EnvObsParams(obs_type=EnvObsType.TILE_PATH, coord=tile, level="arena") ) if tmp_arena_path == curr_arena_path: @@ -418,14 +425,14 @@ class STRole(Role): if "" in plan: # Executing persona-persona interaction. target_p_tile = roles[plan.split("")[-1].strip()].scratch.curr_tile - collision_maze = await self.rc.env.observe(EnvAPIAbstract(api_name="get_collision_maze")) + collision_maze = self.rc.env.observe()["collision_maze"] potential_path = path_finder( collision_maze, self.rc.scratch.curr_tile, target_p_tile, collision_block_id ) if len(potential_path) <= 2: target_tiles = [potential_path[0]] else: - collision_maze = await self.rc.env.observe(EnvAPIAbstract(api_name="get_collision_maze")) + collision_maze = self.rc.env.observe()["collision_maze"] potential_1 = path_finder( collision_maze, self.rc.scratch.curr_tile, @@ -455,7 +462,7 @@ class STRole(Role): # Executing a random location action. plan = ":".join(plan.split(":")[:-1]) - address_tiles = await self.rc.env.observe(EnvAPIAbstract(api_name="get_address_tiles")) + address_tiles = self.rc.env.observe()["address_tiles"] target_tiles = address_tiles[plan] target_tiles = random.sample(list(target_tiles), 1) @@ -465,7 +472,7 @@ class STRole(Role): # Retrieve the target addresses. Again, plan is an action address in its # string form. takes this and returns candidate # coordinates. - address_tiles = await self.rc.env.observe(EnvAPIAbstract(api_name="get_address_tiles")) + address_tiles = self.rc.env.observe()["address_tiles"] if plan not in address_tiles: address_tiles["Johnson Park:park:park garden"] # ERRORRRRRRR else: @@ -485,7 +492,7 @@ class STRole(Role): persona_name_set = set(roles.keys()) new_target_tiles = [] for i in target_tiles: - access_tile = await self.rc.env.observe(EnvAPIAbstract(api_name="access_tile", kwargs={"tile": i})) + access_tile = self.rc.env.observe(EnvObsParams(obs_type=EnvObsType.GET_TITLE, coord=i)) curr_event_set = access_tile["events"] pass_curr_tile = False for j in curr_event_set: @@ -507,7 +514,7 @@ class STRole(Role): # an input, and returns a list of coordinate tuples that becomes the # path. # e.g., [(0, 1), (1, 1), (1, 2), (1, 3), (1, 4)...] - collision_maze = await self.rc.env.observe(EnvAPIAbstract(api_name="get_collision_maze")) + collision_maze = self.rc.env.observe()["collision_maze"] curr_path = path_finder(collision_maze, curr_tile, i, collision_block_id) if not closest_target_tile: closest_target_tile = i @@ -539,23 +546,20 @@ class STRole(Role): ret = True if role_env: for key, val in self.game_obj_cleanup.items(): - await self.rc.env.step( - EnvAPIAbstract(api_name="turn_event_from_tile_idle", kwargs={"curr_event": key, "tile": val}) - ) + self.rc.env.step(EnvAction(action_type=EnvActionType.TURN_TILE_EVENT_IDLE, coord=val, event=key)) # reset game_obj_cleanup self.game_obj_cleanup = dict() curr_tile = self.role_tile new_tile = (role_env["x"], role_env["y"]) - await self.rc.env.step( - EnvAPIAbstract( - api_name="remove_subject_events_from_tile", kwargs={"subject": self.name, "tile": curr_tile} - ) + self.rc.env.step( + EnvAction(action_type=EnvActionType.RM_TITLE_SUB_EVENT, coord=curr_tile, subject=self.name) ) - await self.rc.env.step( - EnvAPIAbstract( - api_name="add_event_from_tile", - kwargs={"curr_event": self.scratch.get_curr_event_and_desc(), "tile": new_tile}, + self.rc.env.step( + EnvAction( + action_type=EnvActionType.ADD_TILE_EVENT, + coord=new_tile, + event=self.scratch.get_curr_event_and_desc(), ) ) @@ -563,16 +567,16 @@ class STRole(Role): # the persona gets there, we activate the object action. if not self.scratch.planned_path: self.game_obj_cleanup[self.scratch.get_curr_event_and_desc()] = new_tile - await self.rc.env.step( - EnvAPIAbstract( - api_name="add_event_from_tile", - kwargs={"curr_event": self.scratch.get_curr_event_and_desc(), "tile": new_tile}, + self.rc.env.step( + EnvAction( + action_type=EnvActionType.ADD_TILE_EVENT, + coord=new_tile, + event=self.scratch.get_curr_event_and_desc(), ) ) + blank = (self.scratch.get_curr_obj_event_and_desc()[0], None, None, None) - await self.rc.env.step( - EnvAPIAbstract(api_name="remove_event_from_tile", kwargs={"curr_event": blank, "tile": new_tile}) - ) + self.rc.env.step(EnvAction(action_type=EnvActionType.RM_TILE_EVENT, coord=new_tile, event=blank)) # update role's new tile self.rc.scratch.curr_tile = new_tile diff --git a/examples/st_game/storage/.gitignore b/examples/st_game/storage/.gitignore index 6c37f8efd..72b125e04 100644 --- a/examples/st_game/storage/.gitignore +++ b/examples/st_game/storage/.gitignore @@ -1,2 +1,3 @@ # path to store simulation data test_* +July* \ No newline at end of file diff --git a/examples/st_game/tests/actions/test_gen_action_details.py b/examples/st_game/tests/actions/test_gen_action_details.py index 3edf9b116..49e24481d 100644 --- a/examples/st_game/tests/actions/test_gen_action_details.py +++ b/examples/st_game/tests/actions/test_gen_action_details.py @@ -31,7 +31,7 @@ async def test_gen_action_details(): act_desp = "sleeping" act_dura = "120" - access_tile = await role.rc.env.observe( + access_tile = await role.rc.env.read_from_api( EnvAPIAbstract(api_name="access_tile", kwargs={"tile": role.scratch.curr_tile}) ) act_world = access_tile["world"] diff --git a/metagpt/environment/base_env.py b/metagpt/environment/base_env.py index 942bf2409..c6bfcbc12 100644 --- a/metagpt/environment/base_env.py +++ b/metagpt/environment/base_env.py @@ -3,9 +3,12 @@ # @Desc : base env of executing environment import asyncio +from abc import abstractmethod from enum import Enum from typing import TYPE_CHECKING, Any, Dict, Iterable, Optional, Set, Union +from gymnasium import spaces +from gymnasium.core import ActType, ObsType from pydantic import BaseModel, ConfigDict, Field, SerializeAsAny, model_validator from metagpt.context import Context @@ -14,6 +17,7 @@ from metagpt.environment.api.env_api import ( ReadAPIRegistry, WriteAPIRegistry, ) +from metagpt.environment.base_env_space import BaseEnvAction, BaseEnvObsParams from metagpt.logs import logger from metagpt.schema import Message from metagpt.utils.common import get_function_schema, is_coroutine_func, is_send_to @@ -49,6 +53,11 @@ def mark_as_writeable(func): class ExtEnv(BaseModel): """External Env to integrate actual game environment""" + model_config = ConfigDict(arbitrary_types_allowed=True) + + action_space: spaces.Space[ActType] = Field(default_factory=spaces.Space, exclude=True) + observation_space: spaces.Space[ObsType] = Field(default_factory=spaces.Space, exclude=True) + def _check_api_exist(self, rw_api: Optional[str] = None): if not rw_api: raise ValueError(f"{rw_api} not exists") @@ -61,39 +70,56 @@ class ExtEnv(BaseModel): else: return env_write_api_registry.get_apis() - async def observe(self, env_action: Union[str, EnvAPIAbstract]): + async def read_from_api(self, env_action: Union[str, EnvAPIAbstract]): """get observation from particular api of ExtEnv""" if isinstance(env_action, str): - read_api = env_read_api_registry.get(api_name=env_action)["func"] - self._check_api_exist(read_api) - if is_coroutine_func(read_api): - res = await read_api(self) + env_read_api = env_read_api_registry.get(api_name=env_action)["func"] + self._check_api_exist(env_read_api) + if is_coroutine_func(env_read_api): + res = await env_read_api(self) else: - res = read_api(self) + res = env_read_api(self) elif isinstance(env_action, EnvAPIAbstract): - read_api = env_read_api_registry.get(api_name=env_action.api_name)["func"] - self._check_api_exist(read_api) - if is_coroutine_func(read_api): - res = await read_api(self, *env_action.args, **env_action.kwargs) + env_read_api = env_read_api_registry.get(api_name=env_action.api_name)["func"] + self._check_api_exist(env_read_api) + if is_coroutine_func(env_read_api): + res = await env_read_api(self, *env_action.args, **env_action.kwargs) else: - res = read_api(self, *env_action.args, **env_action.kwargs) + res = env_read_api(self, *env_action.args, **env_action.kwargs) return res - async def step(self, env_action: Union[str, Message, EnvAPIAbstract, list[EnvAPIAbstract]]): + async def write_thru_api(self, env_action: Union[str, Message, EnvAPIAbstract, list[EnvAPIAbstract]]): """execute through particular api of ExtEnv""" res = None if isinstance(env_action, Message): self.publish_message(env_action) elif isinstance(env_action, EnvAPIAbstract): - write_api = env_write_api_registry.get(env_action.api_name)["func"] - self._check_api_exist(write_api) - if is_coroutine_func(write_api): - res = await write_api(self, *env_action.args, **env_action.kwargs) + env_write_api = env_write_api_registry.get(env_action.api_name)["func"] + self._check_api_exist(env_write_api) + if is_coroutine_func(env_write_api): + res = await env_write_api(self, *env_action.args, **env_action.kwargs) else: - res = write_api(self, *env_action.args, **env_action.kwargs) + res = env_write_api(self, *env_action.args, **env_action.kwargs) return res + @abstractmethod + def reset( + self, + *, + seed: Optional[int] = None, + options: Optional[dict[str, Any]] = None, + ) -> tuple[dict[str, Any], dict[str, Any]]: + """Implement this to get init observation""" + + @abstractmethod + def observe(self, obs_params: Optional[BaseEnvObsParams] = None) -> Any: + """Implement this if you want to get partial observation from the env""" + + @abstractmethod + def step(self, action: BaseEnvAction) -> tuple[dict[str, Any], float, bool, bool, dict[str, Any]]: + """Implement this to feed a action and then get new observation from the env""" + class Environment(ExtEnv): """环境,承载一批角色,角色可以向环境发布消息,可以被其他角色观察到 diff --git a/metagpt/environment/base_env_space.py b/metagpt/environment/base_env_space.py new file mode 100644 index 000000000..fd0cfa399 --- /dev/null +++ b/metagpt/environment/base_env_space.py @@ -0,0 +1,33 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : + +from enum import IntEnum + +from pydantic import BaseModel, ConfigDict, Field + + +class BaseEnvActionType(IntEnum): + # # NONE = 0 # no action to run, just get observation + pass + + +class BaseEnvAction(BaseModel): + """env action type and its related params of action functions/apis""" + + model_config = ConfigDict(arbitrary_types_allowed=True) + + action_type: int = Field(default=0, description="action type") + + +class BaseEnvObsType(IntEnum): + # # NONE = 0 # get whole observation from env + pass + + +class BaseEnvObsParams(BaseModel): + """observation params for different EnvObsType to get its observe result""" + + model_config = ConfigDict(arbitrary_types_allowed=True) + + obs_type: int = Field(default=0, description="observation type") diff --git a/metagpt/environment/stanford_town_env/env_space.py b/metagpt/environment/stanford_town_env/env_space.py new file mode 100644 index 000000000..e100a2952 --- /dev/null +++ b/metagpt/environment/stanford_town_env/env_space.py @@ -0,0 +1,105 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : + +from typing import Any, Optional, Union + +import numpy as np +import numpy.typing as npt +from gymnasium import spaces +from pydantic import ConfigDict, Field, field_validator + +from metagpt.environment.base_env_space import ( + BaseEnvAction, + BaseEnvActionType, + BaseEnvObsParams, + BaseEnvObsType, +) + + +class EnvActionType(BaseEnvActionType): + NONE = 0 # no action to run, just get observation + + ADD_TILE_EVENT = 1 # Add an event triple to a tile + RM_TILE_EVENT = 2 # Remove an event triple from a tile + TURN_TILE_EVENT_IDLE = 3 # Turn an event triple from a tile into idle + RM_TITLE_SUB_EVENT = 4 # Remove an event triple that has the input subject from a tile + + +class EnvAction(BaseEnvAction): + """env action type and its related params of action functions/apis""" + + model_config = ConfigDict(arbitrary_types_allowed=True) + + action_type: int = Field(default=EnvActionType.NONE, description="action type") + coord: npt.NDArray[np.int64] = Field( + default_factory=lambda: np.zeros(2, dtype=np.int64), description="tile coordinate" + ) + subject: str = Field(default="", description="subject name of first element in event") + event: tuple[str, Optional[str], Optional[str], Optional[str]] = Field( + default=["", None, None, None], description="tile event" + ) + + @field_validator("coord", mode="before") + @classmethod + def check_coord(cls, coord) -> npt.NDArray[np.int64]: + if not isinstance(coord, np.ndarray): + return np.array(coord) + + +class EnvObsType(BaseEnvObsType): + """get part observation with specific params""" + + NONE = 0 # get whole observation from env + + GET_TITLE = 1 # get the tile detail dictionary with given tile coord + TILE_PATH = 2 # get the tile address with given tile coord + TILE_NBR = 3 # get the neighbors of given tile coord and its vision radius + + +class EnvObsParams(BaseEnvObsParams): + """observation params for different EnvObsType""" + + model_config = ConfigDict(arbitrary_types_allowed=True) + + obs_type: int = Field(default=EnvObsType.NONE, description="observation type") + coord: npt.NDArray[np.int64] = Field( + default_factory=lambda: np.zeros(2, dtype=np.int64), description="tile coordinate" + ) + level: str = Field(default="", description="different level of title") + vision_radius: int = Field(default=0, description="the vision radius of current tile") + + @field_validator("coord", mode="before") + @classmethod + def check_coord(cls, coord) -> npt.NDArray[np.int64]: + if not isinstance(coord, np.ndarray): + return np.array(coord) + + +EnvObsValType = Union[list[list[str]], dict[str, set[tuple[int, int]]], list[list[dict[str, Any]]]] + + +def get_observation_space() -> spaces.Dict: + # it's a + space = spaces.Dict( + {"collision_maze": spaces.Discrete(2), "tiles": spaces.Discrete(2), "address_tiles": spaces.Discrete(2)} + ) + + return space + + +def get_action_space(maze_shape: tuple[int, int]) -> spaces.Dict: + """The fields defined by the space correspond to the input parameters of the action except `action_type`""" + space = spaces.Dict( + { + "action_type": spaces.Discrete(len(EnvActionType)), + "coord": spaces.Box( + np.array([0, 0], dtype=np.int64), np.array([maze_shape[0], maze_shape[1]], dtype=np.int64) + ), # coord of the tile + "subject": spaces.Text(256), # the first element of an tile event + "event": spaces.Tuple( + (spaces.Text(256), spaces.Text(256), spaces.Text(256), spaces.Text(256)) + ), # event is a tuple of four str + } + ) + return space diff --git a/metagpt/environment/stanford_town_env/stanford_town_ext_env.py b/metagpt/environment/stanford_town_env/stanford_town_ext_env.py index 8a9a65965..b41ae375c 100644 --- a/metagpt/environment/stanford_town_env/stanford_town_ext_env.py +++ b/metagpt/environment/stanford_town_env/stanford_town_ext_env.py @@ -5,11 +5,20 @@ import math from pathlib import Path -from typing import Optional, Tuple +from typing import Any, Optional from pydantic import ConfigDict, Field, model_validator from metagpt.environment.base_env import ExtEnv, mark_as_readable, mark_as_writeable +from metagpt.environment.stanford_town_env.env_space import ( + EnvAction, + EnvActionType, + EnvObsParams, + EnvObsType, + EnvObsValType, + get_action_space, + get_observation_space, +) from metagpt.utils.common import read_csv_to_list, read_json_file @@ -197,15 +206,82 @@ class StanfordTownExtEnv(ExtEnv): else: address_tiles[add] = set([(j, i)]) values["address_tiles"] = address_tiles + + values["action_space"] = get_action_space((maze_width, maze_height)) + values["observation_space"] = get_observation_space() return values + def reset( + self, + *, + seed: Optional[int] = None, + options: Optional[dict[str, Any]] = None, + ) -> tuple[dict[str, EnvObsValType], dict[str, Any]]: + """reset env and get the init observation + Return results corresponding to `observation, info` + """ + super().reset(seed=seed, options=options) + + obs = self._get_obs() + + return obs, {} + + def _get_obs(self) -> dict[str, EnvObsValType]: + """Get observation""" + return { + "collision_maze": self.get_collision_maze(), + "tiles": self.tiles, + "address_tiles": self.get_address_tiles(), + } + + def observe(self, obs_params: Optional[EnvObsParams] = None) -> Any: + """Get partial or full observation from the env""" + obs_type = obs_params.obs_type if obs_params else EnvObsType.NONE + if obs_type == EnvObsType.NONE: + obs = self._get_obs() + elif obs_type == EnvObsType.GET_TITLE: + obs = self.access_tile(tile=obs_params.coord) + elif obs_type == EnvObsType.TILE_PATH: + obs = self.get_tile_path(tile=obs_params.coord, level=obs_params.level) + elif obs_type == EnvObsType.TILE_NBR: + obs = self.get_nearby_tiles(tile=obs_params.coord, vision_r=obs_params.vision_radius) + return obs + + def step(self, action: EnvAction) -> tuple[dict[str, EnvObsValType], float, bool, bool, dict[str, Any]]: + """Execute action and then return observation + Return results corresponding to `observation, reward, terminated, truncated, info` + """ + terminated = False + try: + self._execute_env_action(action) + except Exception: + terminated = True + + obs = self._get_obs() + + ret = (obs, 1.0, terminated, False, {}) + return ret + + def _execute_env_action(self, action: EnvAction): + action_type = action.action_type + if action_type == EnvActionType.NONE: + pass + elif action_type == EnvActionType.ADD_TILE_EVENT: + self.add_event_from_tile(curr_event=action.event, tile=action.coord) + elif action_type == EnvActionType.RM_TILE_EVENT: + self.remove_event_from_tile(curr_event=action.event, tile=action.coord) + elif action_type == EnvActionType.TURN_TILE_EVENT_IDLE: + self.turn_event_from_tile_idle(curr_event=action.event, tile=action.coord) + elif action_type == EnvActionType.RM_TITLE_SUB_EVENT: + self.remove_subject_events_from_tile(subject=action.subject, tile=action.coord) + def turn_coordinate_to_tile(self, px_coordinate: tuple[int, int]) -> tuple[int, int]: """ Turns a pixel coordinate to a tile coordinate. """ x = math.ceil(px_coordinate[0] / self.sq_tile_size) y = math.ceil(px_coordinate[1] / self.sq_tile_size) - return (x, y) + return x, y @mark_as_readable def get_collision_maze(self) -> list: @@ -316,10 +392,6 @@ class StanfordTownExtEnv(ExtEnv): nearby_tiles += [(i, j)] return nearby_tiles - @mark_as_writeable - def add_tiles_event(self, pt_y: int, pt_x: int, event: Tuple[str, str, str, str]): - self.tiles[pt_y][pt_x]["events"].add(event) - @mark_as_writeable def add_event_from_tile(self, curr_event: tuple[str], tile: tuple[int, int]) -> None: """ diff --git a/requirements.txt b/requirements.txt index da8aa26b2..d150d61f3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -69,4 +69,5 @@ imap_tools==1.5.0 # Used by metagpt/tools/libs/email_login.py qianfan==0.3.2 dashscope==1.14.1 rank-bm25==0.2.2 # for tool recommendation -jieba==0.42.1 # for tool recommendation \ No newline at end of file +jieba==0.42.1 # for tool recommendation +gymnasium==0.29.1 \ No newline at end of file diff --git a/tests/metagpt/environment/stanford_town_env/test_stanford_town_ext_env.py b/tests/metagpt/environment/stanford_town_env/test_stanford_town_ext_env.py index b167f83bb..63e88cf32 100644 --- a/tests/metagpt/environment/stanford_town_env/test_stanford_town_ext_env.py +++ b/tests/metagpt/environment/stanford_town_env/test_stanford_town_ext_env.py @@ -4,6 +4,12 @@ from pathlib import Path +from metagpt.environment.stanford_town_env.env_space import ( + EnvAction, + EnvActionType, + EnvObsParams, + EnvObsType, +) from metagpt.environment.stanford_town_env.stanford_town_ext_env import ( StanfordTownExtEnv, ) @@ -27,7 +33,6 @@ def test_stanford_town_ext_env(): assert len(ext_env.get_nearby_tiles(tile=tile, vision_r=5)) == 121 event = ("double studio:double studio:bedroom 2:bed", None, None, None) - ext_env.add_tiles_event(tile[1], tile[0], event=event) ext_env.add_event_from_tile(event, tile) assert len(ext_env.tiles[tile[1]][tile[0]]["events"]) == 1 @@ -38,3 +43,22 @@ def test_stanford_town_ext_env(): ext_env.remove_subject_events_from_tile(subject=event[0], tile=tile) assert len(ext_env.tiles[tile[1]][tile[0]]["events"]) == 0 + + +def test_stanford_town_ext_env_observe_step(): + ext_env = StanfordTownExtEnv(maze_asset_path=maze_asset_path) + obs, info = ext_env.reset() + assert len(info) == 0 + assert len(obs["address_tiles"]) == 306 + + tile = (58, 9) + obs = ext_env.observe(obs_params=EnvObsParams(obs_type=EnvObsType.TILE_PATH, coord=tile, level="world")) + assert obs == "the Ville" + + action = ext_env.action_space.sample() + assert len(action) == 4 + assert len(action["event"]) == 4 + + event = ("double studio:double studio:bedroom 2:bed", None, None, None) + obs, _, _, _, _ = ext_env.step(action=EnvAction(action_type=EnvActionType.ADD_TILE_EVENT, coord=tile, event=event)) + assert len(ext_env.tiles[tile[1]][tile[0]]["events"]) == 1 diff --git a/tests/metagpt/environment/test_base_env.py b/tests/metagpt/environment/test_base_env.py index fd73679d8..28815a874 100644 --- a/tests/metagpt/environment/test_base_env.py +++ b/tests/metagpt/environment/test_base_env.py @@ -44,11 +44,11 @@ async def test_ext_env(): assert len(apis) > 0 assert len(apis["read_api"]) == 3 - _ = await env.step(EnvAPIAbstract(api_name="write_api", kwargs={"a": 5, "b": 10})) + _ = await env.write_thru_api(EnvAPIAbstract(api_name="write_api", kwargs={"a": 5, "b": 10})) assert env.value == 15 with pytest.raises(ValueError): - await env.observe("not_exist_api") + await env.read_from_api("not_exist_api") - assert await env.observe("read_api_no_param") == 15 - assert await env.observe(EnvAPIAbstract(api_name="read_api", kwargs={"a": 5, "b": 5})) == 10 + assert await env.read_from_api("read_api_no_param") == 15 + assert await env.read_from_api(EnvAPIAbstract(api_name="read_api", kwargs={"a": 5, "b": 5})) == 10