mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-05-24 14:15:17 +02:00
add base environment action_space/observation space and update stanford_town_env
This commit is contained in:
parent
e240c0dc01
commit
2b7d09ede2
12 changed files with 341 additions and 75 deletions
|
|
@ -4,7 +4,7 @@
|
|||
|
||||
import random
|
||||
|
||||
from metagpt.environment.api.env_api import EnvAPIAbstract
|
||||
from metagpt.environment.stanford_town_env.env_space import EnvObsParams, EnvObsType
|
||||
from metagpt.logs import logger
|
||||
|
||||
from .st_action import STAction
|
||||
|
|
@ -367,8 +367,8 @@ class GenActionDetails(STAction):
|
|||
return fs
|
||||
|
||||
async def run(self, role: "STRole", act_desp: str, act_dura):
|
||||
access_tile = await role.rc.env.observe(
|
||||
EnvAPIAbstract(api_name="access_tile", kwargs={"tile": role.scratch.curr_tile})
|
||||
access_tile = role.rc.env.observe(
|
||||
obs_params=EnvObsParams(obs_type=EnvObsType.GET_TITLE, coord=role.scratch.curr_tile)
|
||||
)
|
||||
act_world = access_tile["world"]
|
||||
act_sector = await GenActionSector().run(role, access_tile, act_desp)
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@
|
|||
|
||||
from examples.st_game.actions.st_action import STAction
|
||||
from examples.st_game.utils.utils import extract_first_json_dict
|
||||
from metagpt.environment.api.env_api import EnvAPIAbstract
|
||||
from metagpt.environment.stanford_town_env.env_space import EnvObsParams, EnvObsType
|
||||
from metagpt.logs import logger
|
||||
|
||||
|
||||
|
|
@ -113,8 +113,8 @@ class GenIterChatUTT(STAction):
|
|||
]
|
||||
return prompt_input
|
||||
|
||||
access_tile = await init_role.rc.env.observe(
|
||||
EnvAPIAbstract(api_name="access_tile", kwargs={"tile": init_role.scratch.curr_tile})
|
||||
access_tile = init_role.rc.env.observe(
|
||||
obs_params=EnvObsParams(obs_type=EnvObsType.GET_TITLE, coord=init_role.scratch.curr_tile)
|
||||
)
|
||||
prompt_input = create_prompt_input(access_tile, init_role, target_role, retrieved, curr_context, curr_chat)
|
||||
prompt = self.generate_prompt_with_tmpl_filename(prompt_input, "iterative_convo_v1.txt")
|
||||
|
|
|
|||
|
|
@ -36,7 +36,12 @@ from examples.st_game.utils.mg_ga_transform import (
|
|||
)
|
||||
from examples.st_game.utils.utils import get_embedding, path_finder
|
||||
from metagpt.actions.add_requirement import UserRequirement
|
||||
from metagpt.environment.api.env_api import EnvAPIAbstract
|
||||
from metagpt.environment.stanford_town_env.env_space import (
|
||||
EnvAction,
|
||||
EnvActionType,
|
||||
EnvObsParams,
|
||||
EnvObsType,
|
||||
)
|
||||
from metagpt.logs import logger
|
||||
from metagpt.roles.role import Role, RoleContext
|
||||
from metagpt.schema import Message
|
||||
|
|
@ -115,10 +120,12 @@ class STRole(Role):
|
|||
pt_x = role_env["x"]
|
||||
pt_y = role_env["y"]
|
||||
self.rc.scratch.curr_tile = (pt_x, pt_y)
|
||||
await self.rc.env.step(
|
||||
EnvAPIAbstract(
|
||||
api_name="add_tiles_event",
|
||||
kwargs={"pt_y": pt_y, "pt_x": pt_x, "event": self.scratch.get_curr_event_and_desc()},
|
||||
|
||||
self.rc.env.step(
|
||||
EnvAction(
|
||||
action_type=EnvActionType.ADD_TILE_EVENT,
|
||||
coord=(pt_x, pt_y),
|
||||
event=self.scratch.get_curr_event_and_desc(),
|
||||
)
|
||||
)
|
||||
|
||||
|
|
@ -231,24 +238,24 @@ class STRole(Role):
|
|||
# PERCEIVE SPACE
|
||||
# We get the nearby tiles given our current tile and the persona's vision
|
||||
# radius.
|
||||
nearby_tiles = await self.rc.env.observe(
|
||||
EnvAPIAbstract(
|
||||
api_name="get_nearby_tiles",
|
||||
kwargs={"tile": self.rc.scratch.curr_tile, "vision_r": self.rc.scratch.vision_r},
|
||||
nearby_tiles = self.rc.env.observe(
|
||||
EnvObsParams(
|
||||
obs_type=EnvObsType.TILE_NBR, coord=self.rc.scratch.curr_tile, vision_radius=self.rc.scratch.vision_r
|
||||
)
|
||||
)
|
||||
|
||||
# We then store the perceived space. Note that the s_mem of the persona is
|
||||
# in the form of a tree constructed using dictionaries.
|
||||
for tile in nearby_tiles:
|
||||
tile_info = await self.rc.env.observe(EnvAPIAbstract(api_name="access_tile", kwargs={"tile": tile}))
|
||||
tile_info = self.rc.env.observe(EnvObsParams(obs_type=EnvObsType.GET_TITLE, coord=tile))
|
||||
self.rc.spatial_memory.add_tile_info(tile_info)
|
||||
|
||||
# PERCEIVE EVENTS.
|
||||
# We will perceive events that take place in the same arena as the
|
||||
# persona's current arena.
|
||||
curr_arena_path = await self.rc.env.observe(
|
||||
EnvAPIAbstract(api_name="get_tile_path", kwargs={"tile": self.rc.scratch.curr_tile, "level": "arena"})
|
||||
|
||||
curr_arena_path = self.rc.env.observe(
|
||||
EnvObsParams(obs_type=EnvObsType.TILE_PATH, coord=self.rc.scratch.curr_tile, level="arena")
|
||||
)
|
||||
|
||||
# We do not perceive the same event twice (this can happen if an object is
|
||||
|
|
@ -260,10 +267,10 @@ class STRole(Role):
|
|||
# First, we put all events that are occuring in the nearby tiles into the
|
||||
# percept_events_list
|
||||
for tile in nearby_tiles:
|
||||
tile_details = await self.rc.env.observe(EnvAPIAbstract(api_name="access_tile", kwargs={"tile": tile}))
|
||||
tile_details = self.rc.env.observe(EnvObsParams(obs_type=EnvObsType.GET_TITLE, coord=tile))
|
||||
if tile_details["events"]:
|
||||
tmp_arena_path = await self.rc.env.observe(
|
||||
EnvAPIAbstract(api_name="get_tile_path", kwargs={"tile": tile, "level": "arena"})
|
||||
tmp_arena_path = self.rc.env.observe(
|
||||
EnvObsParams(obs_type=EnvObsType.TILE_PATH, coord=tile, level="arena")
|
||||
)
|
||||
|
||||
if tmp_arena_path == curr_arena_path:
|
||||
|
|
@ -418,14 +425,14 @@ class STRole(Role):
|
|||
if "<persona>" in plan:
|
||||
# Executing persona-persona interaction.
|
||||
target_p_tile = roles[plan.split("<persona>")[-1].strip()].scratch.curr_tile
|
||||
collision_maze = await self.rc.env.observe(EnvAPIAbstract(api_name="get_collision_maze"))
|
||||
collision_maze = self.rc.env.observe()["collision_maze"]
|
||||
potential_path = path_finder(
|
||||
collision_maze, self.rc.scratch.curr_tile, target_p_tile, collision_block_id
|
||||
)
|
||||
if len(potential_path) <= 2:
|
||||
target_tiles = [potential_path[0]]
|
||||
else:
|
||||
collision_maze = await self.rc.env.observe(EnvAPIAbstract(api_name="get_collision_maze"))
|
||||
collision_maze = self.rc.env.observe()["collision_maze"]
|
||||
potential_1 = path_finder(
|
||||
collision_maze,
|
||||
self.rc.scratch.curr_tile,
|
||||
|
|
@ -455,7 +462,7 @@ class STRole(Role):
|
|||
# Executing a random location action.
|
||||
plan = ":".join(plan.split(":")[:-1])
|
||||
|
||||
address_tiles = await self.rc.env.observe(EnvAPIAbstract(api_name="get_address_tiles"))
|
||||
address_tiles = self.rc.env.observe()["address_tiles"]
|
||||
target_tiles = address_tiles[plan]
|
||||
target_tiles = random.sample(list(target_tiles), 1)
|
||||
|
||||
|
|
@ -465,7 +472,7 @@ class STRole(Role):
|
|||
# Retrieve the target addresses. Again, plan is an action address in its
|
||||
# string form. <maze.address_tiles> takes this and returns candidate
|
||||
# coordinates.
|
||||
address_tiles = await self.rc.env.observe(EnvAPIAbstract(api_name="get_address_tiles"))
|
||||
address_tiles = self.rc.env.observe()["address_tiles"]
|
||||
if plan not in address_tiles:
|
||||
address_tiles["Johnson Park:park:park garden"] # ERRORRRRRRR
|
||||
else:
|
||||
|
|
@ -485,7 +492,7 @@ class STRole(Role):
|
|||
persona_name_set = set(roles.keys())
|
||||
new_target_tiles = []
|
||||
for i in target_tiles:
|
||||
access_tile = await self.rc.env.observe(EnvAPIAbstract(api_name="access_tile", kwargs={"tile": i}))
|
||||
access_tile = self.rc.env.observe(EnvObsParams(obs_type=EnvObsType.GET_TITLE, coord=i))
|
||||
curr_event_set = access_tile["events"]
|
||||
pass_curr_tile = False
|
||||
for j in curr_event_set:
|
||||
|
|
@ -507,7 +514,7 @@ class STRole(Role):
|
|||
# an input, and returns a list of coordinate tuples that becomes the
|
||||
# path.
|
||||
# e.g., [(0, 1), (1, 1), (1, 2), (1, 3), (1, 4)...]
|
||||
collision_maze = await self.rc.env.observe(EnvAPIAbstract(api_name="get_collision_maze"))
|
||||
collision_maze = self.rc.env.observe()["collision_maze"]
|
||||
curr_path = path_finder(collision_maze, curr_tile, i, collision_block_id)
|
||||
if not closest_target_tile:
|
||||
closest_target_tile = i
|
||||
|
|
@ -539,23 +546,20 @@ class STRole(Role):
|
|||
ret = True
|
||||
if role_env:
|
||||
for key, val in self.game_obj_cleanup.items():
|
||||
await self.rc.env.step(
|
||||
EnvAPIAbstract(api_name="turn_event_from_tile_idle", kwargs={"curr_event": key, "tile": val})
|
||||
)
|
||||
self.rc.env.step(EnvAction(action_type=EnvActionType.TURN_TILE_EVENT_IDLE, coord=val, event=key))
|
||||
|
||||
# reset game_obj_cleanup
|
||||
self.game_obj_cleanup = dict()
|
||||
curr_tile = self.role_tile
|
||||
new_tile = (role_env["x"], role_env["y"])
|
||||
await self.rc.env.step(
|
||||
EnvAPIAbstract(
|
||||
api_name="remove_subject_events_from_tile", kwargs={"subject": self.name, "tile": curr_tile}
|
||||
)
|
||||
self.rc.env.step(
|
||||
EnvAction(action_type=EnvActionType.RM_TITLE_SUB_EVENT, coord=curr_tile, subject=self.name)
|
||||
)
|
||||
await self.rc.env.step(
|
||||
EnvAPIAbstract(
|
||||
api_name="add_event_from_tile",
|
||||
kwargs={"curr_event": self.scratch.get_curr_event_and_desc(), "tile": new_tile},
|
||||
self.rc.env.step(
|
||||
EnvAction(
|
||||
action_type=EnvActionType.ADD_TILE_EVENT,
|
||||
coord=new_tile,
|
||||
event=self.scratch.get_curr_event_and_desc(),
|
||||
)
|
||||
)
|
||||
|
||||
|
|
@ -563,16 +567,16 @@ class STRole(Role):
|
|||
# the persona gets there, we activate the object action.
|
||||
if not self.scratch.planned_path:
|
||||
self.game_obj_cleanup[self.scratch.get_curr_event_and_desc()] = new_tile
|
||||
await self.rc.env.step(
|
||||
EnvAPIAbstract(
|
||||
api_name="add_event_from_tile",
|
||||
kwargs={"curr_event": self.scratch.get_curr_event_and_desc(), "tile": new_tile},
|
||||
self.rc.env.step(
|
||||
EnvAction(
|
||||
action_type=EnvActionType.ADD_TILE_EVENT,
|
||||
coord=new_tile,
|
||||
event=self.scratch.get_curr_event_and_desc(),
|
||||
)
|
||||
)
|
||||
|
||||
blank = (self.scratch.get_curr_obj_event_and_desc()[0], None, None, None)
|
||||
await self.rc.env.step(
|
||||
EnvAPIAbstract(api_name="remove_event_from_tile", kwargs={"curr_event": blank, "tile": new_tile})
|
||||
)
|
||||
self.rc.env.step(EnvAction(action_type=EnvActionType.RM_TILE_EVENT, coord=new_tile, event=blank))
|
||||
|
||||
# update role's new tile
|
||||
self.rc.scratch.curr_tile = new_tile
|
||||
|
|
|
|||
1
examples/st_game/storage/.gitignore
vendored
1
examples/st_game/storage/.gitignore
vendored
|
|
@ -1,2 +1,3 @@
|
|||
# path to store simulation data
|
||||
test_*
|
||||
July*
|
||||
|
|
@ -31,7 +31,7 @@ async def test_gen_action_details():
|
|||
act_desp = "sleeping"
|
||||
act_dura = "120"
|
||||
|
||||
access_tile = await role.rc.env.observe(
|
||||
access_tile = await role.rc.env.read_from_api(
|
||||
EnvAPIAbstract(api_name="access_tile", kwargs={"tile": role.scratch.curr_tile})
|
||||
)
|
||||
act_world = access_tile["world"]
|
||||
|
|
|
|||
|
|
@ -3,9 +3,12 @@
|
|||
# @Desc : base env of executing environment
|
||||
|
||||
import asyncio
|
||||
from abc import abstractmethod
|
||||
from enum import Enum
|
||||
from typing import TYPE_CHECKING, Any, Dict, Iterable, Optional, Set, Union
|
||||
|
||||
from gymnasium import spaces
|
||||
from gymnasium.core import ActType, ObsType
|
||||
from pydantic import BaseModel, ConfigDict, Field, SerializeAsAny, model_validator
|
||||
|
||||
from metagpt.context import Context
|
||||
|
|
@ -14,6 +17,7 @@ from metagpt.environment.api.env_api import (
|
|||
ReadAPIRegistry,
|
||||
WriteAPIRegistry,
|
||||
)
|
||||
from metagpt.environment.base_env_space import BaseEnvAction, BaseEnvObsParams
|
||||
from metagpt.logs import logger
|
||||
from metagpt.schema import Message
|
||||
from metagpt.utils.common import get_function_schema, is_coroutine_func, is_send_to
|
||||
|
|
@ -49,6 +53,11 @@ def mark_as_writeable(func):
|
|||
class ExtEnv(BaseModel):
|
||||
"""External Env to integrate actual game environment"""
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
action_space: spaces.Space[ActType] = Field(default_factory=spaces.Space, exclude=True)
|
||||
observation_space: spaces.Space[ObsType] = Field(default_factory=spaces.Space, exclude=True)
|
||||
|
||||
def _check_api_exist(self, rw_api: Optional[str] = None):
|
||||
if not rw_api:
|
||||
raise ValueError(f"{rw_api} not exists")
|
||||
|
|
@ -61,39 +70,56 @@ class ExtEnv(BaseModel):
|
|||
else:
|
||||
return env_write_api_registry.get_apis()
|
||||
|
||||
async def observe(self, env_action: Union[str, EnvAPIAbstract]):
|
||||
async def read_from_api(self, env_action: Union[str, EnvAPIAbstract]):
|
||||
"""get observation from particular api of ExtEnv"""
|
||||
if isinstance(env_action, str):
|
||||
read_api = env_read_api_registry.get(api_name=env_action)["func"]
|
||||
self._check_api_exist(read_api)
|
||||
if is_coroutine_func(read_api):
|
||||
res = await read_api(self)
|
||||
env_read_api = env_read_api_registry.get(api_name=env_action)["func"]
|
||||
self._check_api_exist(env_read_api)
|
||||
if is_coroutine_func(env_read_api):
|
||||
res = await env_read_api(self)
|
||||
else:
|
||||
res = read_api(self)
|
||||
res = env_read_api(self)
|
||||
elif isinstance(env_action, EnvAPIAbstract):
|
||||
read_api = env_read_api_registry.get(api_name=env_action.api_name)["func"]
|
||||
self._check_api_exist(read_api)
|
||||
if is_coroutine_func(read_api):
|
||||
res = await read_api(self, *env_action.args, **env_action.kwargs)
|
||||
env_read_api = env_read_api_registry.get(api_name=env_action.api_name)["func"]
|
||||
self._check_api_exist(env_read_api)
|
||||
if is_coroutine_func(env_read_api):
|
||||
res = await env_read_api(self, *env_action.args, **env_action.kwargs)
|
||||
else:
|
||||
res = read_api(self, *env_action.args, **env_action.kwargs)
|
||||
res = env_read_api(self, *env_action.args, **env_action.kwargs)
|
||||
return res
|
||||
|
||||
async def step(self, env_action: Union[str, Message, EnvAPIAbstract, list[EnvAPIAbstract]]):
|
||||
async def write_thru_api(self, env_action: Union[str, Message, EnvAPIAbstract, list[EnvAPIAbstract]]):
|
||||
"""execute through particular api of ExtEnv"""
|
||||
res = None
|
||||
if isinstance(env_action, Message):
|
||||
self.publish_message(env_action)
|
||||
elif isinstance(env_action, EnvAPIAbstract):
|
||||
write_api = env_write_api_registry.get(env_action.api_name)["func"]
|
||||
self._check_api_exist(write_api)
|
||||
if is_coroutine_func(write_api):
|
||||
res = await write_api(self, *env_action.args, **env_action.kwargs)
|
||||
env_write_api = env_write_api_registry.get(env_action.api_name)["func"]
|
||||
self._check_api_exist(env_write_api)
|
||||
if is_coroutine_func(env_write_api):
|
||||
res = await env_write_api(self, *env_action.args, **env_action.kwargs)
|
||||
else:
|
||||
res = write_api(self, *env_action.args, **env_action.kwargs)
|
||||
res = env_write_api(self, *env_action.args, **env_action.kwargs)
|
||||
|
||||
return res
|
||||
|
||||
@abstractmethod
|
||||
def reset(
|
||||
self,
|
||||
*,
|
||||
seed: Optional[int] = None,
|
||||
options: Optional[dict[str, Any]] = None,
|
||||
) -> tuple[dict[str, Any], dict[str, Any]]:
|
||||
"""Implement this to get init observation"""
|
||||
|
||||
@abstractmethod
|
||||
def observe(self, obs_params: Optional[BaseEnvObsParams] = None) -> Any:
|
||||
"""Implement this if you want to get partial observation from the env"""
|
||||
|
||||
@abstractmethod
|
||||
def step(self, action: BaseEnvAction) -> tuple[dict[str, Any], float, bool, bool, dict[str, Any]]:
|
||||
"""Implement this to feed a action and then get new observation from the env"""
|
||||
|
||||
|
||||
class Environment(ExtEnv):
|
||||
"""环境,承载一批角色,角色可以向环境发布消息,可以被其他角色观察到
|
||||
|
|
|
|||
33
metagpt/environment/base_env_space.py
Normal file
33
metagpt/environment/base_env_space.py
Normal file
|
|
@ -0,0 +1,33 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc :
|
||||
|
||||
from enum import IntEnum
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
|
||||
class BaseEnvActionType(IntEnum):
|
||||
# # NONE = 0 # no action to run, just get observation
|
||||
pass
|
||||
|
||||
|
||||
class BaseEnvAction(BaseModel):
|
||||
"""env action type and its related params of action functions/apis"""
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
action_type: int = Field(default=0, description="action type")
|
||||
|
||||
|
||||
class BaseEnvObsType(IntEnum):
|
||||
# # NONE = 0 # get whole observation from env
|
||||
pass
|
||||
|
||||
|
||||
class BaseEnvObsParams(BaseModel):
|
||||
"""observation params for different EnvObsType to get its observe result"""
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
obs_type: int = Field(default=0, description="observation type")
|
||||
105
metagpt/environment/stanford_town_env/env_space.py
Normal file
105
metagpt/environment/stanford_town_env/env_space.py
Normal file
|
|
@ -0,0 +1,105 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc :
|
||||
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import numpy.typing as npt
|
||||
from gymnasium import spaces
|
||||
from pydantic import ConfigDict, Field, field_validator
|
||||
|
||||
from metagpt.environment.base_env_space import (
|
||||
BaseEnvAction,
|
||||
BaseEnvActionType,
|
||||
BaseEnvObsParams,
|
||||
BaseEnvObsType,
|
||||
)
|
||||
|
||||
|
||||
class EnvActionType(BaseEnvActionType):
|
||||
NONE = 0 # no action to run, just get observation
|
||||
|
||||
ADD_TILE_EVENT = 1 # Add an event triple to a tile
|
||||
RM_TILE_EVENT = 2 # Remove an event triple from a tile
|
||||
TURN_TILE_EVENT_IDLE = 3 # Turn an event triple from a tile into idle
|
||||
RM_TITLE_SUB_EVENT = 4 # Remove an event triple that has the input subject from a tile
|
||||
|
||||
|
||||
class EnvAction(BaseEnvAction):
|
||||
"""env action type and its related params of action functions/apis"""
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
action_type: int = Field(default=EnvActionType.NONE, description="action type")
|
||||
coord: npt.NDArray[np.int64] = Field(
|
||||
default_factory=lambda: np.zeros(2, dtype=np.int64), description="tile coordinate"
|
||||
)
|
||||
subject: str = Field(default="", description="subject name of first element in event")
|
||||
event: tuple[str, Optional[str], Optional[str], Optional[str]] = Field(
|
||||
default=["", None, None, None], description="tile event"
|
||||
)
|
||||
|
||||
@field_validator("coord", mode="before")
|
||||
@classmethod
|
||||
def check_coord(cls, coord) -> npt.NDArray[np.int64]:
|
||||
if not isinstance(coord, np.ndarray):
|
||||
return np.array(coord)
|
||||
|
||||
|
||||
class EnvObsType(BaseEnvObsType):
|
||||
"""get part observation with specific params"""
|
||||
|
||||
NONE = 0 # get whole observation from env
|
||||
|
||||
GET_TITLE = 1 # get the tile detail dictionary with given tile coord
|
||||
TILE_PATH = 2 # get the tile address with given tile coord
|
||||
TILE_NBR = 3 # get the neighbors of given tile coord and its vision radius
|
||||
|
||||
|
||||
class EnvObsParams(BaseEnvObsParams):
|
||||
"""observation params for different EnvObsType"""
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
obs_type: int = Field(default=EnvObsType.NONE, description="observation type")
|
||||
coord: npt.NDArray[np.int64] = Field(
|
||||
default_factory=lambda: np.zeros(2, dtype=np.int64), description="tile coordinate"
|
||||
)
|
||||
level: str = Field(default="", description="different level of title")
|
||||
vision_radius: int = Field(default=0, description="the vision radius of current tile")
|
||||
|
||||
@field_validator("coord", mode="before")
|
||||
@classmethod
|
||||
def check_coord(cls, coord) -> npt.NDArray[np.int64]:
|
||||
if not isinstance(coord, np.ndarray):
|
||||
return np.array(coord)
|
||||
|
||||
|
||||
EnvObsValType = Union[list[list[str]], dict[str, set[tuple[int, int]]], list[list[dict[str, Any]]]]
|
||||
|
||||
|
||||
def get_observation_space() -> spaces.Dict:
|
||||
# it's a
|
||||
space = spaces.Dict(
|
||||
{"collision_maze": spaces.Discrete(2), "tiles": spaces.Discrete(2), "address_tiles": spaces.Discrete(2)}
|
||||
)
|
||||
|
||||
return space
|
||||
|
||||
|
||||
def get_action_space(maze_shape: tuple[int, int]) -> spaces.Dict:
|
||||
"""The fields defined by the space correspond to the input parameters of the action except `action_type`"""
|
||||
space = spaces.Dict(
|
||||
{
|
||||
"action_type": spaces.Discrete(len(EnvActionType)),
|
||||
"coord": spaces.Box(
|
||||
np.array([0, 0], dtype=np.int64), np.array([maze_shape[0], maze_shape[1]], dtype=np.int64)
|
||||
), # coord of the tile
|
||||
"subject": spaces.Text(256), # the first element of an tile event
|
||||
"event": spaces.Tuple(
|
||||
(spaces.Text(256), spaces.Text(256), spaces.Text(256), spaces.Text(256))
|
||||
), # event is a tuple of four str
|
||||
}
|
||||
)
|
||||
return space
|
||||
|
|
@ -5,11 +5,20 @@
|
|||
|
||||
import math
|
||||
from pathlib import Path
|
||||
from typing import Optional, Tuple
|
||||
from typing import Any, Optional
|
||||
|
||||
from pydantic import ConfigDict, Field, model_validator
|
||||
|
||||
from metagpt.environment.base_env import ExtEnv, mark_as_readable, mark_as_writeable
|
||||
from metagpt.environment.stanford_town_env.env_space import (
|
||||
EnvAction,
|
||||
EnvActionType,
|
||||
EnvObsParams,
|
||||
EnvObsType,
|
||||
EnvObsValType,
|
||||
get_action_space,
|
||||
get_observation_space,
|
||||
)
|
||||
from metagpt.utils.common import read_csv_to_list, read_json_file
|
||||
|
||||
|
||||
|
|
@ -197,15 +206,82 @@ class StanfordTownExtEnv(ExtEnv):
|
|||
else:
|
||||
address_tiles[add] = set([(j, i)])
|
||||
values["address_tiles"] = address_tiles
|
||||
|
||||
values["action_space"] = get_action_space((maze_width, maze_height))
|
||||
values["observation_space"] = get_observation_space()
|
||||
return values
|
||||
|
||||
def reset(
|
||||
self,
|
||||
*,
|
||||
seed: Optional[int] = None,
|
||||
options: Optional[dict[str, Any]] = None,
|
||||
) -> tuple[dict[str, EnvObsValType], dict[str, Any]]:
|
||||
"""reset env and get the init observation
|
||||
Return results corresponding to `observation, info`
|
||||
"""
|
||||
super().reset(seed=seed, options=options)
|
||||
|
||||
obs = self._get_obs()
|
||||
|
||||
return obs, {}
|
||||
|
||||
def _get_obs(self) -> dict[str, EnvObsValType]:
|
||||
"""Get observation"""
|
||||
return {
|
||||
"collision_maze": self.get_collision_maze(),
|
||||
"tiles": self.tiles,
|
||||
"address_tiles": self.get_address_tiles(),
|
||||
}
|
||||
|
||||
def observe(self, obs_params: Optional[EnvObsParams] = None) -> Any:
|
||||
"""Get partial or full observation from the env"""
|
||||
obs_type = obs_params.obs_type if obs_params else EnvObsType.NONE
|
||||
if obs_type == EnvObsType.NONE:
|
||||
obs = self._get_obs()
|
||||
elif obs_type == EnvObsType.GET_TITLE:
|
||||
obs = self.access_tile(tile=obs_params.coord)
|
||||
elif obs_type == EnvObsType.TILE_PATH:
|
||||
obs = self.get_tile_path(tile=obs_params.coord, level=obs_params.level)
|
||||
elif obs_type == EnvObsType.TILE_NBR:
|
||||
obs = self.get_nearby_tiles(tile=obs_params.coord, vision_r=obs_params.vision_radius)
|
||||
return obs
|
||||
|
||||
def step(self, action: EnvAction) -> tuple[dict[str, EnvObsValType], float, bool, bool, dict[str, Any]]:
|
||||
"""Execute action and then return observation
|
||||
Return results corresponding to `observation, reward, terminated, truncated, info`
|
||||
"""
|
||||
terminated = False
|
||||
try:
|
||||
self._execute_env_action(action)
|
||||
except Exception:
|
||||
terminated = True
|
||||
|
||||
obs = self._get_obs()
|
||||
|
||||
ret = (obs, 1.0, terminated, False, {})
|
||||
return ret
|
||||
|
||||
def _execute_env_action(self, action: EnvAction):
|
||||
action_type = action.action_type
|
||||
if action_type == EnvActionType.NONE:
|
||||
pass
|
||||
elif action_type == EnvActionType.ADD_TILE_EVENT:
|
||||
self.add_event_from_tile(curr_event=action.event, tile=action.coord)
|
||||
elif action_type == EnvActionType.RM_TILE_EVENT:
|
||||
self.remove_event_from_tile(curr_event=action.event, tile=action.coord)
|
||||
elif action_type == EnvActionType.TURN_TILE_EVENT_IDLE:
|
||||
self.turn_event_from_tile_idle(curr_event=action.event, tile=action.coord)
|
||||
elif action_type == EnvActionType.RM_TITLE_SUB_EVENT:
|
||||
self.remove_subject_events_from_tile(subject=action.subject, tile=action.coord)
|
||||
|
||||
def turn_coordinate_to_tile(self, px_coordinate: tuple[int, int]) -> tuple[int, int]:
|
||||
"""
|
||||
Turns a pixel coordinate to a tile coordinate.
|
||||
"""
|
||||
x = math.ceil(px_coordinate[0] / self.sq_tile_size)
|
||||
y = math.ceil(px_coordinate[1] / self.sq_tile_size)
|
||||
return (x, y)
|
||||
return x, y
|
||||
|
||||
@mark_as_readable
|
||||
def get_collision_maze(self) -> list:
|
||||
|
|
@ -316,10 +392,6 @@ class StanfordTownExtEnv(ExtEnv):
|
|||
nearby_tiles += [(i, j)]
|
||||
return nearby_tiles
|
||||
|
||||
@mark_as_writeable
|
||||
def add_tiles_event(self, pt_y: int, pt_x: int, event: Tuple[str, str, str, str]):
|
||||
self.tiles[pt_y][pt_x]["events"].add(event)
|
||||
|
||||
@mark_as_writeable
|
||||
def add_event_from_tile(self, curr_event: tuple[str], tile: tuple[int, int]) -> None:
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -69,4 +69,5 @@ imap_tools==1.5.0 # Used by metagpt/tools/libs/email_login.py
|
|||
qianfan==0.3.2
|
||||
dashscope==1.14.1
|
||||
rank-bm25==0.2.2 # for tool recommendation
|
||||
jieba==0.42.1 # for tool recommendation
|
||||
jieba==0.42.1 # for tool recommendation
|
||||
gymnasium==0.29.1
|
||||
|
|
@ -4,6 +4,12 @@
|
|||
|
||||
from pathlib import Path
|
||||
|
||||
from metagpt.environment.stanford_town_env.env_space import (
|
||||
EnvAction,
|
||||
EnvActionType,
|
||||
EnvObsParams,
|
||||
EnvObsType,
|
||||
)
|
||||
from metagpt.environment.stanford_town_env.stanford_town_ext_env import (
|
||||
StanfordTownExtEnv,
|
||||
)
|
||||
|
|
@ -27,7 +33,6 @@ def test_stanford_town_ext_env():
|
|||
assert len(ext_env.get_nearby_tiles(tile=tile, vision_r=5)) == 121
|
||||
|
||||
event = ("double studio:double studio:bedroom 2:bed", None, None, None)
|
||||
ext_env.add_tiles_event(tile[1], tile[0], event=event)
|
||||
ext_env.add_event_from_tile(event, tile)
|
||||
assert len(ext_env.tiles[tile[1]][tile[0]]["events"]) == 1
|
||||
|
||||
|
|
@ -38,3 +43,22 @@ def test_stanford_town_ext_env():
|
|||
|
||||
ext_env.remove_subject_events_from_tile(subject=event[0], tile=tile)
|
||||
assert len(ext_env.tiles[tile[1]][tile[0]]["events"]) == 0
|
||||
|
||||
|
||||
def test_stanford_town_ext_env_observe_step():
|
||||
ext_env = StanfordTownExtEnv(maze_asset_path=maze_asset_path)
|
||||
obs, info = ext_env.reset()
|
||||
assert len(info) == 0
|
||||
assert len(obs["address_tiles"]) == 306
|
||||
|
||||
tile = (58, 9)
|
||||
obs = ext_env.observe(obs_params=EnvObsParams(obs_type=EnvObsType.TILE_PATH, coord=tile, level="world"))
|
||||
assert obs == "the Ville"
|
||||
|
||||
action = ext_env.action_space.sample()
|
||||
assert len(action) == 4
|
||||
assert len(action["event"]) == 4
|
||||
|
||||
event = ("double studio:double studio:bedroom 2:bed", None, None, None)
|
||||
obs, _, _, _, _ = ext_env.step(action=EnvAction(action_type=EnvActionType.ADD_TILE_EVENT, coord=tile, event=event))
|
||||
assert len(ext_env.tiles[tile[1]][tile[0]]["events"]) == 1
|
||||
|
|
|
|||
|
|
@ -44,11 +44,11 @@ async def test_ext_env():
|
|||
assert len(apis) > 0
|
||||
assert len(apis["read_api"]) == 3
|
||||
|
||||
_ = await env.step(EnvAPIAbstract(api_name="write_api", kwargs={"a": 5, "b": 10}))
|
||||
_ = await env.write_thru_api(EnvAPIAbstract(api_name="write_api", kwargs={"a": 5, "b": 10}))
|
||||
assert env.value == 15
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
await env.observe("not_exist_api")
|
||||
await env.read_from_api("not_exist_api")
|
||||
|
||||
assert await env.observe("read_api_no_param") == 15
|
||||
assert await env.observe(EnvAPIAbstract(api_name="read_api", kwargs={"a": 5, "b": 5})) == 10
|
||||
assert await env.read_from_api("read_api_no_param") == 15
|
||||
assert await env.read_from_api(EnvAPIAbstract(api_name="read_api", kwargs={"a": 5, "b": 5})) == 10
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue