add base environment action_space/observation space and update stanford_town_env

This commit is contained in:
better629 2024-03-26 20:22:45 +08:00
parent e240c0dc01
commit 2b7d09ede2
12 changed files with 341 additions and 75 deletions

View file

@ -3,9 +3,12 @@
# @Desc : base env of executing environment
import asyncio
from abc import abstractmethod
from enum import Enum
from typing import TYPE_CHECKING, Any, Dict, Iterable, Optional, Set, Union
from gymnasium import spaces
from gymnasium.core import ActType, ObsType
from pydantic import BaseModel, ConfigDict, Field, SerializeAsAny, model_validator
from metagpt.context import Context
@ -14,6 +17,7 @@ from metagpt.environment.api.env_api import (
ReadAPIRegistry,
WriteAPIRegistry,
)
from metagpt.environment.base_env_space import BaseEnvAction, BaseEnvObsParams
from metagpt.logs import logger
from metagpt.schema import Message
from metagpt.utils.common import get_function_schema, is_coroutine_func, is_send_to
@ -49,6 +53,11 @@ def mark_as_writeable(func):
class ExtEnv(BaseModel):
"""External Env to integrate actual game environment"""
model_config = ConfigDict(arbitrary_types_allowed=True)
action_space: spaces.Space[ActType] = Field(default_factory=spaces.Space, exclude=True)
observation_space: spaces.Space[ObsType] = Field(default_factory=spaces.Space, exclude=True)
def _check_api_exist(self, rw_api: Optional[str] = None):
if not rw_api:
raise ValueError(f"{rw_api} not exists")
@ -61,39 +70,56 @@ class ExtEnv(BaseModel):
else:
return env_write_api_registry.get_apis()
async def read_from_api(self, env_action: Union[str, EnvAPIAbstract]):
    """Call a registered read API of the ExtEnv and return its result.

    Args:
        env_action: either the name of a registered read API (called with no
            extra arguments) or an `EnvAPIAbstract` carrying the api name plus
            the positional and keyword arguments to forward.

    Returns:
        The value returned by the resolved read API (awaited when the API is a
        coroutine function).

    Raises:
        TypeError: if `env_action` is neither a `str` nor an `EnvAPIAbstract`.
        ValueError: raised by `_check_api_exist` when the registry entry holds
            no function.
    """
    if isinstance(env_action, str):
        api_name, args, kwargs = env_action, (), {}
    elif isinstance(env_action, EnvAPIAbstract):
        api_name, args, kwargs = env_action.api_name, env_action.args, env_action.kwargs
    else:
        # Any other type used to reach `return res` with `res` never assigned.
        raise TypeError(f"unsupported env_action type: {type(env_action).__name__}")

    env_read_api = env_read_api_registry.get(api_name=api_name)["func"]
    self._check_api_exist(env_read_api)
    if is_coroutine_func(env_read_api):
        return await env_read_api(self, *args, **kwargs)
    return env_read_api(self, *args, **kwargs)
async def write_thru_api(self, env_action: Union[str, Message, EnvAPIAbstract, list[EnvAPIAbstract]]):
    """Execute a write through a particular API of the ExtEnv.

    Args:
        env_action: a `Message` is published via `publish_message`; an
            `EnvAPIAbstract` is resolved in the write-API registry and called
            with its args/kwargs. Other accepted kinds (`str`, list) are not
            handled here and yield `None`.

    Returns:
        The write API's return value for an `EnvAPIAbstract`, otherwise `None`.

    Raises:
        ValueError: raised by `_check_api_exist` when the registry entry holds
            no function.
    """
    res = None
    if isinstance(env_action, Message):
        self.publish_message(env_action)
    elif isinstance(env_action, EnvAPIAbstract):
        env_write_api = env_write_api_registry.get(env_action.api_name)["func"]
        self._check_api_exist(env_write_api)
        if is_coroutine_func(env_write_api):
            res = await env_write_api(self, *env_action.args, **env_action.kwargs)
        else:
            res = env_write_api(self, *env_action.args, **env_action.kwargs)
    return res
@abstractmethod
def reset(
    self,
    *,
    seed: Optional[int] = None,
    options: Optional[dict[str, Any]] = None,
) -> tuple[dict[str, Any], dict[str, Any]]:
    """Abstract hook: reset the env and hand back its initial observation.

    Args:
        seed: optional integer seed, keyword-only.
        options: optional dict of extra reset options, keyword-only.

    Returns:
        A pair of dicts, as declared by the return annotation.
    """
@abstractmethod
def observe(self, obs_params: Optional[BaseEnvObsParams] = None) -> Any:
    """Abstract hook: return a partial observation of the env.

    Args:
        obs_params: optional parameters selecting which observation to return.
    """
@abstractmethod
def step(self, action: BaseEnvAction) -> tuple[dict[str, Any], float, bool, bool, dict[str, Any]]:
    """Abstract hook: feed an action to the env and get the new observation.

    Args:
        action: the action to apply.

    Returns:
        A 5-tuple `(dict, float, bool, bool, dict)`, as declared by the return
        annotation.
    """
class Environment(ExtEnv):
"""环境,承载一批角色,角色可以向环境发布消息,可以被其他角色观察到

View file

@ -0,0 +1,33 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc :
from enum import IntEnum
from pydantic import BaseModel, ConfigDict, Field
class BaseEnvActionType(IntEnum):
    """Member-less base enum for env action types.

    Concrete envs subclass it to declare their own members, e.g. `NONE = 0`
    for "no action to run, just get observation".
    """
class BaseEnvAction(BaseModel):
    """Base model for an env action: the action type plus, in subclasses, the
    parameters of the action functions/apis."""

    model_config = ConfigDict(arbitrary_types_allowed=True)

    # integer id of the action to run; defaults to 0
    action_type: int = Field(description="action type", default=0)
class BaseEnvObsType(IntEnum):
    """Member-less base enum for env observation types.

    Concrete envs subclass it to declare their own members, e.g. `NONE = 0`
    for "get whole observation from env".
    """
class BaseEnvObsParams(BaseModel):
    """Base model for observation parameters: the observation type plus, in
    subclasses, whatever each observation type needs to produce its result."""

    model_config = ConfigDict(arbitrary_types_allowed=True)

    # integer id of the observation to fetch; defaults to 0
    obs_type: int = Field(description="observation type", default=0)

View file

@ -0,0 +1,105 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc :
from typing import Any, Optional, Union
import numpy as np
import numpy.typing as npt
from gymnasium import spaces
from pydantic import ConfigDict, Field, field_validator
from metagpt.environment.base_env_space import (
BaseEnvAction,
BaseEnvActionType,
BaseEnvObsParams,
BaseEnvObsType,
)
class EnvActionType(BaseEnvActionType):
    """Action types of the Stanford-town env."""

    # run nothing, only fetch an observation
    NONE = 0
    # add an event triple to a tile
    ADD_TILE_EVENT = 1
    # remove an event triple from a tile
    RM_TILE_EVENT = 2
    # turn an event triple on a tile into idle
    TURN_TILE_EVENT_IDLE = 3
    # remove the event triple with the given subject from a tile
    RM_TITLE_SUB_EVENT = 4
class EnvAction(BaseEnvAction):
    """Env action type and the related params of the action functions/apis.

    Fields:
        action_type: which action to run (an `EnvActionType` value).
        coord: tile coordinate as a numpy array; non-array input is converted.
        subject: subject name, i.e. the first element of an event.
        event: the tile event, a 4-tuple whose last three items may be None.
    """

    model_config = ConfigDict(arbitrary_types_allowed=True)

    action_type: int = Field(default=EnvActionType.NONE, description="action type")
    coord: npt.NDArray[np.int64] = Field(
        default_factory=lambda: np.zeros(2, dtype=np.int64), description="tile coordinate"
    )
    subject: str = Field(default="", description="subject name of first element in event")
    # The default is a tuple, matching the annotation; a list default would be
    # unhashable if the event is later stored in a set.
    event: tuple[str, Optional[str], Optional[str], Optional[str]] = Field(
        default=("", None, None, None), description="tile event"
    )

    @field_validator("coord", mode="before")
    @classmethod
    def check_coord(cls, coord) -> npt.NDArray[np.int64]:
        """Return `coord` as a numpy array.

        An ndarray is passed through unchanged. A "before" validator must
        always return the value; the previous version returned None here.
        """
        if isinstance(coord, np.ndarray):
            return coord
        return np.array(coord)
class EnvObsType(BaseEnvObsType):
    """Kinds of (partial) observation that can be requested with specific params."""

    # whole observation of the env
    NONE = 0
    # tile detail dictionary for a given tile coord
    GET_TITLE = 1
    # tile address for a given tile coord
    TILE_PATH = 2
    # neighbors of a given tile coord within its vision radius
    TILE_NBR = 3
class EnvObsParams(BaseEnvObsParams):
    """Observation params for the different `EnvObsType` values.

    Fields:
        obs_type: which observation to fetch (an `EnvObsType` value).
        coord: tile coordinate as a numpy array; non-array input is converted.
        level: the level used when looking up a tile path.
        vision_radius: vision radius around the current tile.
    """

    model_config = ConfigDict(arbitrary_types_allowed=True)

    obs_type: int = Field(default=EnvObsType.NONE, description="observation type")
    coord: npt.NDArray[np.int64] = Field(
        default_factory=lambda: np.zeros(2, dtype=np.int64), description="tile coordinate"
    )
    level: str = Field(default="", description="different level of title")
    vision_radius: int = Field(default=0, description="the vision radius of current tile")

    @field_validator("coord", mode="before")
    @classmethod
    def check_coord(cls, coord) -> npt.NDArray[np.int64]:
        """Return `coord` as a numpy array.

        An ndarray is passed through unchanged. A "before" validator must
        always return the value; the previous version returned None here.
        """
        if isinstance(coord, np.ndarray):
            return coord
        return np.array(coord)
# Value types an observation entry may take.
# presumably: collision maze rows, address -> tile-coordinate sets, tile grid — confirm
EnvObsValType = Union[
    list[list[str]],
    dict[str, set[tuple[int, int]]],
    list[list[dict[str, Any]]],
]
def get_observation_space() -> spaces.Dict:
    """Build the observation space.

    Returns:
        A `spaces.Dict` with one `spaces.Discrete(2)` entry for each of
        `collision_maze`, `tiles` and `address_tiles`.
    """
    obs_keys = ("collision_maze", "tiles", "address_tiles")
    return spaces.Dict({key: spaces.Discrete(2) for key in obs_keys})
def get_action_space(maze_shape: tuple[int, int]) -> spaces.Dict:
    """Build the action space; its fields mirror the fields of `EnvAction`.

    Args:
        maze_shape: two integers used as the upper bounds of the tile coord.

    Returns:
        A `spaces.Dict` with `action_type`, `coord`, `subject` and `event` entries.
    """
    coord_low = np.array([0, 0], dtype=np.int64)
    coord_high = np.array([maze_shape[0], maze_shape[1]], dtype=np.int64)
    space = spaces.Dict(
        {
            "action_type": spaces.Discrete(len(EnvActionType)),
            # Box defaults to float32; pass the integer dtype explicitly so the
            # space matches the int64 bounds.
            # NOTE(review): Box bounds are inclusive, so `maze_shape` itself is
            # admitted — confirm whether the bound should be `maze_shape - 1`.
            "coord": spaces.Box(coord_low, coord_high, dtype=np.int64),  # coord of the tile
            "subject": spaces.Text(256),  # the first element of an tile event
            "event": spaces.Tuple(
                (spaces.Text(256), spaces.Text(256), spaces.Text(256), spaces.Text(256))
            ),  # event is a tuple of four str
        }
    )
    return space

View file

@ -5,11 +5,20 @@
import math
from pathlib import Path
from typing import Optional, Tuple
from typing import Any, Optional
from pydantic import ConfigDict, Field, model_validator
from metagpt.environment.base_env import ExtEnv, mark_as_readable, mark_as_writeable
from metagpt.environment.stanford_town_env.env_space import (
EnvAction,
EnvActionType,
EnvObsParams,
EnvObsType,
EnvObsValType,
get_action_space,
get_observation_space,
)
from metagpt.utils.common import read_csv_to_list, read_json_file
@ -197,15 +206,82 @@ class StanfordTownExtEnv(ExtEnv):
else:
address_tiles[add] = set([(j, i)])
values["address_tiles"] = address_tiles
values["action_space"] = get_action_space((maze_width, maze_height))
values["observation_space"] = get_observation_space()
return values
def reset(
    self,
    *,
    seed: Optional[int] = None,
    options: Optional[dict[str, Any]] = None,
) -> tuple[dict[str, EnvObsValType], dict[str, Any]]:
    """Reset the env and fetch the initial observation.

    Args:
        seed: optional seed, forwarded to the parent `reset`.
        options: optional reset options, forwarded to the parent `reset`.

    Returns:
        `(observation, info)`; `info` is an empty dict.
    """
    super().reset(seed=seed, options=options)
    return self._get_obs(), {}
def _get_obs(self) -> dict[str, EnvObsValType]:
    """Assemble the full observation dict.

    Returns:
        A dict with `collision_maze` (from `get_collision_maze()`), `tiles`
        (the `tiles` attribute) and `address_tiles` (from `get_address_tiles()`).
    """
    obs: dict[str, EnvObsValType] = {}
    obs["collision_maze"] = self.get_collision_maze()
    obs["tiles"] = self.tiles
    obs["address_tiles"] = self.get_address_tiles()
    return obs
def observe(self, obs_params: Optional[EnvObsParams] = None) -> Any:
    """Get a partial or full observation from the env.

    Args:
        obs_params: selects the observation kind and its arguments. When
            omitted, the full observation is returned.

    Returns:
        The full observation dict for `EnvObsType.NONE`, otherwise the result
        of `access_tile`, `get_tile_path` or `get_nearby_tiles`.

    Raises:
        ValueError: if `obs_params.obs_type` is not a known `EnvObsType` value.
    """
    obs_type = obs_params.obs_type if obs_params is not None else EnvObsType.NONE
    if obs_type == EnvObsType.NONE:
        return self._get_obs()
    if obs_type == EnvObsType.GET_TITLE:
        return self.access_tile(tile=obs_params.coord)
    if obs_type == EnvObsType.TILE_PATH:
        return self.get_tile_path(tile=obs_params.coord, level=obs_params.level)
    if obs_type == EnvObsType.TILE_NBR:
        return self.get_nearby_tiles(tile=obs_params.coord, vision_r=obs_params.vision_radius)
    # An unknown type used to reach `return obs` with `obs` never assigned.
    raise ValueError(f"unsupported obs_type: {obs_type}")
def step(self, action: EnvAction) -> tuple[dict[str, EnvObsValType], float, bool, bool, dict[str, Any]]:
    """Execute an action and then return the observation.

    Args:
        action: the action to run via `_execute_env_action`.

    Returns:
        `(observation, reward, terminated, truncated, info)`. `reward` is
        always 1.0 and `truncated` is always False. If the action raises,
        `terminated` is True and `info["error"]` holds the exception's repr,
        so the failure is not lost; otherwise `info` is empty.
    """
    terminated = False
    info: dict[str, Any] = {}
    try:
        self._execute_env_action(action)
    except Exception as exc:
        # Best-effort: end the episode rather than propagate, but keep a trace.
        terminated = True
        info["error"] = repr(exc)

    obs = self._get_obs()
    return obs, 1.0, terminated, False, info
def _execute_env_action(self, action: EnvAction):
    """Run the env method that corresponds to `action.action_type`.

    `EnvActionType.NONE` and unrecognized types run nothing.
    """
    kind = action.action_type
    if kind == EnvActionType.NONE:
        return
    if kind == EnvActionType.ADD_TILE_EVENT:
        self.add_event_from_tile(curr_event=action.event, tile=action.coord)
        return
    if kind == EnvActionType.RM_TILE_EVENT:
        self.remove_event_from_tile(curr_event=action.event, tile=action.coord)
        return
    if kind == EnvActionType.TURN_TILE_EVENT_IDLE:
        self.turn_event_from_tile_idle(curr_event=action.event, tile=action.coord)
        return
    if kind == EnvActionType.RM_TITLE_SUB_EVENT:
        self.remove_subject_events_from_tile(subject=action.subject, tile=action.coord)
def turn_coordinate_to_tile(self, px_coordinate: tuple[int, int]) -> tuple[int, int]:
    """Turn a pixel coordinate into a tile coordinate.

    Args:
        px_coordinate: `(x, y)` in pixels.

    Returns:
        `(x, y)` in tiles: each pixel value divided by `self.sq_tile_size`
        and rounded up.
    """
    x = math.ceil(px_coordinate[0] / self.sq_tile_size)
    y = math.ceil(px_coordinate[1] / self.sq_tile_size)
    return x, y
@mark_as_readable
def get_collision_maze(self) -> list:
@ -316,10 +392,6 @@ class StanfordTownExtEnv(ExtEnv):
nearby_tiles += [(i, j)]
return nearby_tiles
@mark_as_writeable
def add_tiles_event(self, pt_y: int, pt_x: int, event: Tuple[str, str, str, str]):
    """Add `event` to the `events` collection of the tile at `tiles[pt_y][pt_x]`."""
    tile_events = self.tiles[pt_y][pt_x]["events"]
    tile_events.add(event)
@mark_as_writeable
def add_event_from_tile(self, curr_event: tuple[str], tile: tuple[int, int]) -> None:
"""