From b5bfa4b8a71b3d345cc92d9329cd7ff7fc0b31ae Mon Sep 17 00:00:00 2001 From: better629 Date: Tue, 26 Mar 2024 20:35:47 +0800 Subject: [PATCH] add inherited funcs and then implement --- .../android_env/android_ext_env.py | 15 +++++++++++++ .../minecraft_env/minecraft_env.py | 12 +++++------ .../minecraft_env/minecraft_ext_env.py | 21 ++++++++++++++++--- .../werewolf_env/werewolf_ext_env.py | 17 ++++++++++++++- tests/metagpt/environment/test_base_env.py | 17 +++++++++++++++ 5 files changed, 72 insertions(+), 10 deletions(-) diff --git a/metagpt/environment/android_env/android_ext_env.py b/metagpt/environment/android_env/android_ext_env.py index b81b2cd26..01a24c5b9 100644 --- a/metagpt/environment/android_env/android_ext_env.py +++ b/metagpt/environment/android_env/android_ext_env.py @@ -10,6 +10,7 @@ from pydantic import Field from metagpt.environment.android_env.const import ADB_EXEC_FAIL from metagpt.environment.base_env import ExtEnv, mark_as_readable, mark_as_writeable +from metagpt.environment.base_env_space import BaseEnvAction, BaseEnvObsParams class AndroidExtEnv(ExtEnv): @@ -19,6 +20,20 @@ class AndroidExtEnv(ExtEnv): width: int = Field(default=720, description="device screen width") height: int = Field(default=1080, description="device screen height") + def reset( + self, + *, + seed: Optional[int] = None, + options: Optional[dict[str, Any]] = None, + ) -> tuple[dict[str, Any], dict[str, Any]]: + pass + + def observe(self, obs_params: Optional[BaseEnvObsParams] = None) -> Any: + pass + + def step(self, action: BaseEnvAction) -> tuple[dict[str, Any], float, bool, bool, dict[str, Any]]: + pass + def __init__(self, **data: Any): super().__init__(**data) if data.get("device_id"): diff --git a/metagpt/environment/minecraft_env/minecraft_env.py b/metagpt/environment/minecraft_env/minecraft_env.py index 26d4d03a8..bba35ce21 100644 --- a/metagpt/environment/minecraft_env/minecraft_env.py +++ b/metagpt/environment/minecraft_env/minecraft_env.py @@ -282,7 +282,7 @@ class MinecraftEnv(Environment, MinecraftExtEnv): position = event["status"]["position"] blocks.append(block) positions.append(position) - new_events = self.step( + new_events = self._step( f"await givePlacedItemBack(bot, {json.dumps(blocks)}, {json.dumps(positions)})", programs=self.programs, ) @@ -323,7 +323,7 @@ class MinecraftEnv(Environment, MinecraftExtEnv): Exception: If there is an issue retrieving events. """ try: - self.reset( + self._reset( options={ "mode": "soft", "wait_ticks": 20, @@ -332,13 +332,13 @@ class MinecraftEnv(Environment, MinecraftExtEnv): # difficulty = "easy" if len(self.completed_tasks) > 15 else "peaceful" difficulty = "peaceful" - events = self.step("bot.chat(`/time set ${getNextTime()}`);\n" + f"bot.chat('/difficulty {difficulty}');") + events = self._step("bot.chat(`/time set ${getNextTime()}`);\n" + f"bot.chat('/difficulty {difficulty}');") self.update_event(events) return events except Exception as e: time.sleep(3) # wait for mineflayer to exit # reset bot status here - events = self.reset( + events = self._reset( options={ "mode": "hard", "wait_ticks": 20, @@ -365,7 +365,7 @@ class MinecraftEnv(Environment, MinecraftExtEnv): Exception: If there is an issue retrieving events. """ try: - events = self.step( + events = self._step( code=self.code, programs=self.programs, ) @@ -374,7 +374,7 @@ class MinecraftEnv(Environment, MinecraftExtEnv): except Exception as e: time.sleep(3) # wait for mineflayer to exit # reset bot status here - events = self.reset( + events = self._reset( options={ "mode": "hard", "wait_ticks": 20, diff --git a/metagpt/environment/minecraft_env/minecraft_ext_env.py b/metagpt/environment/minecraft_env/minecraft_ext_env.py index 3b793079f..74f417eb0 100644 --- a/metagpt/environment/minecraft_env/minecraft_ext_env.py +++ b/metagpt/environment/minecraft_env/minecraft_ext_env.py @@ -5,12 +5,13 @@ import json import time -from typing import Optional +from typing import Any, Optional import requests from pydantic import ConfigDict, Field, model_validator from metagpt.environment.base_env import ExtEnv, mark_as_writeable +from metagpt.environment.base_env_space import BaseEnvAction, BaseEnvObsParams from metagpt.environment.minecraft_env.const import ( MC_CKPT_DIR, MC_CORE_INVENTORY_ITEMS, @@ -38,6 +39,20 @@ class MinecraftExtEnv(ExtEnv): server_paused: bool = Field(default=False) warm_up: dict = Field(default=dict()) + def reset( + self, + *, + seed: Optional[int] = None, + options: Optional[dict[str, Any]] = None, + ) -> tuple[dict[str, Any], dict[str, Any]]: + pass + + def observe(self, obs_params: Optional[BaseEnvObsParams] = None) -> Any: + pass + + def step(self, action: BaseEnvAction) -> tuple[dict[str, Any], float, bool, bool, dict[str, Any]]: + pass + @property def server(self) -> str: return f"{self.server_host}:{self.server_port}" @@ -115,7 +130,7 @@ class MinecraftExtEnv(ExtEnv): return res.json() @mark_as_writeable - def reset(self, *, seed=None, options=None) -> dict: + def _reset(self, *, seed=None, options=None) -> dict: if options is None: options = {} if options.get("inventory", {}) and options.get("mode", "hard") != "hard": @@ -145,7 +160,7 @@ class MinecraftExtEnv(ExtEnv): return json.loads(returned_data) @mark_as_writeable - def step(self, code: str, programs: str = "") -> dict: + def _step(self, code: str, programs: str = "") -> dict: if not self.has_reset: raise RuntimeError("Environment has not been reset yet") self.check_process() diff --git a/metagpt/environment/werewolf_env/werewolf_ext_env.py b/metagpt/environment/werewolf_env/werewolf_ext_env.py index 7c4b4c475..3f2508b06 100644 --- a/metagpt/environment/werewolf_env/werewolf_ext_env.py +++ b/metagpt/environment/werewolf_env/werewolf_ext_env.py @@ -5,11 +5,12 @@ import random from collections import Counter from enum import Enum -from typing import Callable, Optional +from typing import Any, Callable, Optional from pydantic import ConfigDict, Field from metagpt.environment.base_env import ExtEnv, mark_as_readable, mark_as_writeable +from metagpt.environment.base_env_space import BaseEnvAction, BaseEnvObsParams from metagpt.logs import logger @@ -128,6 +129,20 @@ class WerewolfExtEnv(ExtEnv): player_poisoned: Optional[str] = Field(default=None) player_current_dead: list[str] = Field(default=[]) + def reset( + self, + *, + seed: Optional[int] = None, + options: Optional[dict[str, Any]] = None, + ) -> tuple[dict[str, Any], dict[str, Any]]: + pass + + def observe(self, obs_params: Optional[BaseEnvObsParams] = None) -> Any: + pass + + def step(self, action: BaseEnvAction) -> tuple[dict[str, Any], float, bool, bool, dict[str, Any]]: + pass + @property def living_players(self) -> list[str]: player_names = [] diff --git a/tests/metagpt/environment/test_base_env.py b/tests/metagpt/environment/test_base_env.py index 28815a874..404f1c206 100644 --- a/tests/metagpt/environment/test_base_env.py +++ b/tests/metagpt/environment/test_base_env.py @@ -2,6 +2,8 @@ # -*- coding: utf-8 -*- # @Desc : the unittest of ExtEnv&Env +from typing import Any, Optional + import pytest from metagpt.environment.api.env_api import EnvAPIAbstract @@ -12,11 +14,26 @@ from metagpt.environment.base_env import ( mark_as_readable, mark_as_writeable, ) +from metagpt.environment.base_env_space import BaseEnvAction, BaseEnvObsParams class ForTestEnv(Environment): value: int = 0 + def reset( + self, + *, + seed: Optional[int] = None, + options: Optional[dict[str, Any]] = None, + ) -> tuple[dict[str, Any], dict[str, Any]]: + pass + + def observe(self, obs_params: Optional[BaseEnvObsParams] = None) -> Any: + pass + + def step(self, action: BaseEnvAction) -> tuple[dict[str, Any], float, bool, bool, dict[str, Any]]: + pass + @mark_as_readable def read_api_no_param(self): return self.value