add base environment action_space/observation space and update stanford_town_env

This commit is contained in:
better629 2024-03-26 20:22:45 +08:00
parent 6637312e2d
commit 9a8627f23c
7 changed files with 160 additions and 46 deletions

View file

@ -4,7 +4,7 @@
import random
from metagpt.environment.api.env_api import EnvAPIAbstract
from metagpt.environment.stanford_town_env.env_space import EnvObsParams, EnvObsType
from metagpt.logs import logger
from .st_action import STAction
@ -367,8 +367,8 @@ class GenActionDetails(STAction):
return fs
async def run(self, role: "STRole", act_desp: str, act_dura):
access_tile = await role.rc.env.observe(
EnvAPIAbstract(api_name="access_tile", kwargs={"tile": role.scratch.curr_tile})
access_tile = role.rc.env.observe(
obs_params=EnvObsParams(obs_type=EnvObsType.GET_TITLE, coord=role.scratch.curr_tile)
)
act_world = access_tile["world"]
act_sector = await GenActionSector().run(role, access_tile, act_desp)

View file

@ -4,7 +4,7 @@
from examples.st_game.actions.st_action import STAction
from examples.st_game.utils.utils import extract_first_json_dict
from metagpt.environment.api.env_api import EnvAPIAbstract
from metagpt.environment.stanford_town_env.env_space import EnvObsParams, EnvObsType
from metagpt.logs import logger
@ -113,8 +113,8 @@ class GenIterChatUTT(STAction):
]
return prompt_input
access_tile = await init_role.rc.env.observe(
EnvAPIAbstract(api_name="access_tile", kwargs={"tile": init_role.scratch.curr_tile})
access_tile = init_role.rc.env.observe(
obs_params=EnvObsParams(obs_type=EnvObsType.GET_TITLE, coord=init_role.scratch.curr_tile)
)
prompt_input = create_prompt_input(access_tile, init_role, target_role, retrieved, curr_context, curr_chat)
prompt = self.generate_prompt_with_tmpl_filename(prompt_input, "iterative_convo_v1.txt")

View file

@ -36,7 +36,12 @@ from examples.st_game.utils.mg_ga_transform import (
)
from examples.st_game.utils.utils import get_embedding, path_finder
from metagpt.actions.add_requirement import UserRequirement
from metagpt.environment.api.env_api import EnvAPIAbstract
from metagpt.environment.stanford_town_env.env_space import (
EnvAction,
EnvActionType,
EnvObsParams,
EnvObsType,
)
from metagpt.logs import logger
from metagpt.roles.role import Role, RoleContext
from metagpt.schema import Message
@ -115,10 +120,12 @@ class STRole(Role):
pt_x = role_env["x"]
pt_y = role_env["y"]
self.rc.scratch.curr_tile = (pt_x, pt_y)
await self.rc.env.step(
EnvAPIAbstract(
api_name="add_tiles_event",
kwargs={"pt_y": pt_y, "pt_x": pt_x, "event": self.scratch.get_curr_event_and_desc()},
self.rc.env.step(
EnvAction(
action_type=EnvActionType.ADD_TILE_EVENT,
coord=(pt_x, pt_y),
event=self.scratch.get_curr_event_and_desc(),
)
)
@ -231,24 +238,24 @@ class STRole(Role):
# PERCEIVE SPACE
# We get the nearby tiles given our current tile and the persona's vision
# radius.
nearby_tiles = await self.rc.env.observe(
EnvAPIAbstract(
api_name="get_nearby_tiles",
kwargs={"tile": self.rc.scratch.curr_tile, "vision_r": self.rc.scratch.vision_r},
nearby_tiles = self.rc.env.observe(
EnvObsParams(
obs_type=EnvObsType.TILE_NBR, coord=self.rc.scratch.curr_tile, vision_radius=self.rc.scratch.vision_r
)
)
# We then store the perceived space. Note that the s_mem of the persona is
# in the form of a tree constructed using dictionaries.
for tile in nearby_tiles:
tile_info = await self.rc.env.observe(EnvAPIAbstract(api_name="access_tile", kwargs={"tile": tile}))
tile_info = self.rc.env.observe(EnvObsParams(obs_type=EnvObsType.GET_TITLE, coord=tile))
self.rc.spatial_memory.add_tile_info(tile_info)
# PERCEIVE EVENTS.
# We will perceive events that take place in the same arena as the
# persona's current arena.
curr_arena_path = await self.rc.env.observe(
EnvAPIAbstract(api_name="get_tile_path", kwargs={"tile": self.rc.scratch.curr_tile, "level": "arena"})
curr_arena_path = self.rc.env.observe(
EnvObsParams(obs_type=EnvObsType.TILE_PATH, coord=self.rc.scratch.curr_tile, level="arena")
)
# We do not perceive the same event twice (this can happen if an object is
@ -260,10 +267,10 @@ class STRole(Role):
# First, we put all events that are occuring in the nearby tiles into the
# percept_events_list
for tile in nearby_tiles:
tile_details = await self.rc.env.observe(EnvAPIAbstract(api_name="access_tile", kwargs={"tile": tile}))
tile_details = self.rc.env.observe(EnvObsParams(obs_type=EnvObsType.GET_TITLE, coord=tile))
if tile_details["events"]:
tmp_arena_path = await self.rc.env.observe(
EnvAPIAbstract(api_name="get_tile_path", kwargs={"tile": tile, "level": "arena"})
tmp_arena_path = self.rc.env.observe(
EnvObsParams(obs_type=EnvObsType.TILE_PATH, coord=tile, level="arena")
)
if tmp_arena_path == curr_arena_path:
@ -418,14 +425,14 @@ class STRole(Role):
if "<persona>" in plan:
# Executing persona-persona interaction.
target_p_tile = roles[plan.split("<persona>")[-1].strip()].scratch.curr_tile
collision_maze = await self.rc.env.observe(EnvAPIAbstract(api_name="get_collision_maze"))
collision_maze = self.rc.env.observe()["collision_maze"]
potential_path = path_finder(
collision_maze, self.rc.scratch.curr_tile, target_p_tile, collision_block_id
)
if len(potential_path) <= 2:
target_tiles = [potential_path[0]]
else:
collision_maze = await self.rc.env.observe(EnvAPIAbstract(api_name="get_collision_maze"))
collision_maze = self.rc.env.observe()["collision_maze"]
potential_1 = path_finder(
collision_maze,
self.rc.scratch.curr_tile,
@ -455,7 +462,7 @@ class STRole(Role):
# Executing a random location action.
plan = ":".join(plan.split(":")[:-1])
address_tiles = await self.rc.env.observe(EnvAPIAbstract(api_name="get_address_tiles"))
address_tiles = self.rc.env.observe()["address_tiles"]
target_tiles = address_tiles[plan]
target_tiles = random.sample(list(target_tiles), 1)
@ -465,7 +472,7 @@ class STRole(Role):
# Retrieve the target addresses. Again, plan is an action address in its
# string form. <maze.address_tiles> takes this and returns candidate
# coordinates.
address_tiles = await self.rc.env.observe(EnvAPIAbstract(api_name="get_address_tiles"))
address_tiles = self.rc.env.observe()["address_tiles"]
if plan not in address_tiles:
address_tiles["Johnson Park:park:park garden"] # ERRORRRRRRR
else:
@ -485,7 +492,7 @@ class STRole(Role):
persona_name_set = set(roles.keys())
new_target_tiles = []
for i in target_tiles:
access_tile = await self.rc.env.observe(EnvAPIAbstract(api_name="access_tile", kwargs={"tile": i}))
access_tile = self.rc.env.observe(EnvObsParams(obs_type=EnvObsType.GET_TITLE, coord=i))
curr_event_set = access_tile["events"]
pass_curr_tile = False
for j in curr_event_set:
@ -507,7 +514,7 @@ class STRole(Role):
# an input, and returns a list of coordinate tuples that becomes the
# path.
# e.g., [(0, 1), (1, 1), (1, 2), (1, 3), (1, 4)...]
collision_maze = await self.rc.env.observe(EnvAPIAbstract(api_name="get_collision_maze"))
collision_maze = self.rc.env.observe()["collision_maze"]
curr_path = path_finder(collision_maze, curr_tile, i, collision_block_id)
if not closest_target_tile:
closest_target_tile = i
@ -539,23 +546,20 @@ class STRole(Role):
ret = True
if role_env:
for key, val in self.game_obj_cleanup.items():
await self.rc.env.step(
EnvAPIAbstract(api_name="turn_event_from_tile_idle", kwargs={"curr_event": key, "tile": val})
)
self.rc.env.step(EnvAction(action_type=EnvActionType.TURN_TILE_EVENT_IDLE, coord=val, event=key))
# reset game_obj_cleanup
self.game_obj_cleanup = dict()
curr_tile = self.role_tile
new_tile = (role_env["x"], role_env["y"])
await self.rc.env.step(
EnvAPIAbstract(
api_name="remove_subject_events_from_tile", kwargs={"subject": self.name, "tile": curr_tile}
)
self.rc.env.step(
EnvAction(action_type=EnvActionType.RM_TITLE_SUB_EVENT, coord=curr_tile, subject=self.name)
)
await self.rc.env.step(
EnvAPIAbstract(
api_name="add_event_from_tile",
kwargs={"curr_event": self.scratch.get_curr_event_and_desc(), "tile": new_tile},
self.rc.env.step(
EnvAction(
action_type=EnvActionType.ADD_TILE_EVENT,
coord=new_tile,
event=self.scratch.get_curr_event_and_desc(),
)
)
@ -563,16 +567,16 @@ class STRole(Role):
# the persona gets there, we activate the object action.
if not self.scratch.planned_path:
self.game_obj_cleanup[self.scratch.get_curr_event_and_desc()] = new_tile
await self.rc.env.step(
EnvAPIAbstract(
api_name="add_event_from_tile",
kwargs={"curr_event": self.scratch.get_curr_event_and_desc(), "tile": new_tile},
self.rc.env.step(
EnvAction(
action_type=EnvActionType.ADD_TILE_EVENT,
coord=new_tile,
event=self.scratch.get_curr_event_and_desc(),
)
)
blank = (self.scratch.get_curr_obj_event_and_desc()[0], None, None, None)
await self.rc.env.step(
EnvAPIAbstract(api_name="remove_event_from_tile", kwargs={"curr_event": blank, "tile": new_tile})
)
self.rc.env.step(EnvAction(action_type=EnvActionType.RM_TILE_EVENT, coord=new_tile, event=blank))
# update role's new tile
self.rc.scratch.curr_tile = new_tile

View file

@ -1,2 +1,3 @@
# path to store simulation data
test_*
July*

View file

@ -31,7 +31,7 @@ async def test_gen_action_details():
act_desp = "sleeping"
act_dura = "120"
access_tile = await role.rc.env.observe(
access_tile = await role.rc.env.read_from_api(
EnvAPIAbstract(api_name="access_tile", kwargs={"tile": role.scratch.curr_tile})
)
act_world = access_tile["world"]

View file

@ -10,7 +10,11 @@ from typing import Any, Optional
from pydantic import ConfigDict, Field, model_validator
from metagpt.environment.base_env import ExtEnv, mark_as_readable, mark_as_writeable
<<<<<<< HEAD:metagpt/environment/stanford_town/stanford_town_ext_env.py
from metagpt.environment.stanford_town.env_space import (
=======
from metagpt.environment.stanford_town_env.env_space import (
>>>>>>> 5e6f2757 (add base environment action_space/observation space and update stanford_town_env):metagpt/environment/stanford_town_env/stanford_town_ext_env.py
EnvAction,
EnvActionType,
EnvObsParams,

View file

@ -0,0 +1,105 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc :
from typing import Any, Optional, Union
import numpy as np
import numpy.typing as npt
from gymnasium import spaces
from pydantic import ConfigDict, Field, field_validator
from metagpt.environment.base_env_space import (
BaseEnvAction,
BaseEnvActionType,
BaseEnvObsParams,
BaseEnvObsType,
)
class EnvActionType(BaseEnvActionType):
NONE = 0 # no action to run, just get observation
ADD_TILE_EVENT = 1 # Add an event triple to a tile
RM_TILE_EVENT = 2 # Remove an event triple from a tile
TURN_TILE_EVENT_IDLE = 3 # Turn an event triple from a tile into idle
RM_TITLE_SUB_EVENT = 4 # Remove an event triple that has the input subject from a tile
class EnvAction(BaseEnvAction):
"""env action type and its related params of action functions/apis"""
model_config = ConfigDict(arbitrary_types_allowed=True)
action_type: int = Field(default=EnvActionType.NONE, description="action type")
coord: npt.NDArray[np.int64] = Field(
default_factory=lambda: np.zeros(2, dtype=np.int64), description="tile coordinate"
)
subject: str = Field(default="", description="subject name of first element in event")
event: tuple[str, Optional[str], Optional[str], Optional[str]] = Field(
default=["", None, None, None], description="tile event"
)
@field_validator("coord", mode="before")
@classmethod
def check_coord(cls, coord) -> npt.NDArray[np.int64]:
if not isinstance(coord, np.ndarray):
return np.array(coord)
class EnvObsType(BaseEnvObsType):
"""get part observation with specific params"""
NONE = 0 # get whole observation from env
GET_TITLE = 1 # get the tile detail dictionary with given tile coord
TILE_PATH = 2 # get the tile address with given tile coord
TILE_NBR = 3 # get the neighbors of given tile coord and its vision radius
class EnvObsParams(BaseEnvObsParams):
"""observation params for different EnvObsType"""
model_config = ConfigDict(arbitrary_types_allowed=True)
obs_type: int = Field(default=EnvObsType.NONE, description="observation type")
coord: npt.NDArray[np.int64] = Field(
default_factory=lambda: np.zeros(2, dtype=np.int64), description="tile coordinate"
)
level: str = Field(default="", description="different level of title")
vision_radius: int = Field(default=0, description="the vision radius of current tile")
@field_validator("coord", mode="before")
@classmethod
def check_coord(cls, coord) -> npt.NDArray[np.int64]:
if not isinstance(coord, np.ndarray):
return np.array(coord)
EnvObsValType = Union[list[list[str]], dict[str, set[tuple[int, int]]], list[list[dict[str, Any]]]]
def get_observation_space() -> spaces.Dict:
# it's a
space = spaces.Dict(
{"collision_maze": spaces.Discrete(2), "tiles": spaces.Discrete(2), "address_tiles": spaces.Discrete(2)}
)
return space
def get_action_space(maze_shape: tuple[int, int]) -> spaces.Dict:
"""The fields defined by the space correspond to the input parameters of the action except `action_type`"""
space = spaces.Dict(
{
"action_type": spaces.Discrete(len(EnvActionType)),
"coord": spaces.Box(
np.array([0, 0], dtype=np.int64), np.array([maze_shape[0], maze_shape[1]], dtype=np.int64)
), # coord of the tile
"subject": spaces.Text(256), # the first element of an tile event
"event": spaces.Tuple(
(spaces.Text(256), spaces.Text(256), spaces.Text(256), spaces.Text(256))
), # event is a tuple of four str
}
)
return space