add base environment action_space/observation space and update stanford_town_env

This commit is contained in:
better629 2024-03-26 20:22:45 +08:00
parent e240c0dc01
commit 2b7d09ede2
12 changed files with 341 additions and 75 deletions

View file

@ -4,6 +4,12 @@
from pathlib import Path
from metagpt.environment.stanford_town_env.env_space import (
EnvAction,
EnvActionType,
EnvObsParams,
EnvObsType,
)
from metagpt.environment.stanford_town_env.stanford_town_ext_env import (
StanfordTownExtEnv,
)
@ -27,7 +33,6 @@ def test_stanford_town_ext_env():
assert len(ext_env.get_nearby_tiles(tile=tile, vision_r=5)) == 121
event = ("double studio:double studio:bedroom 2:bed", None, None, None)
ext_env.add_tiles_event(tile[1], tile[0], event=event)
ext_env.add_event_from_tile(event, tile)
assert len(ext_env.tiles[tile[1]][tile[0]]["events"]) == 1
@ -38,3 +43,22 @@ def test_stanford_town_ext_env():
ext_env.remove_subject_events_from_tile(subject=event[0], tile=tile)
assert len(ext_env.tiles[tile[1]][tile[0]]["events"]) == 0
def test_stanford_town_ext_env_observe_step():
ext_env = StanfordTownExtEnv(maze_asset_path=maze_asset_path)
obs, info = ext_env.reset()
assert len(info) == 0
assert len(obs["address_tiles"]) == 306
tile = (58, 9)
obs = ext_env.observe(obs_params=EnvObsParams(obs_type=EnvObsType.TILE_PATH, coord=tile, level="world"))
assert obs == "the Ville"
action = ext_env.action_space.sample()
assert len(action) == 4
assert len(action["event"]) == 4
event = ("double studio:double studio:bedroom 2:bed", None, None, None)
obs, _, _, _, _ = ext_env.step(action=EnvAction(action_type=EnvActionType.ADD_TILE_EVENT, coord=tile, event=event))
assert len(ext_env.tiles[tile[1]][tile[0]]["events"]) == 1

View file

@ -44,11 +44,11 @@ async def test_ext_env():
assert len(apis) > 0
assert len(apis["read_api"]) == 3
_ = await env.step(EnvAPIAbstract(api_name="write_api", kwargs={"a": 5, "b": 10}))
_ = await env.write_thru_api(EnvAPIAbstract(api_name="write_api", kwargs={"a": 5, "b": 10}))
assert env.value == 15
with pytest.raises(ValueError):
await env.observe("not_exist_api")
await env.read_from_api("not_exist_api")
assert await env.observe("read_api_no_param") == 15
assert await env.observe(EnvAPIAbstract(api_name="read_api", kwargs={"a": 5, "b": 5})) == 10
assert await env.read_from_api("read_api_no_param") == 15
assert await env.read_from_api(EnvAPIAbstract(api_name="read_api", kwargs={"a": 5, "b": 5})) == 10