# metagpt/environment/android/android_ext_env.py
#
# Reformatted from the collapsed patch text. Only the members that are fully
# visible in the patch are reproduced here. ``__init__`` and the adb helpers
# (``adb_prefix_si`` onward) are only partially shown by the patch and are
# therefore not repeated.
from typing import Any, Optional

from pydantic import Field

from metagpt.environment.android.const import ADB_EXEC_FAIL
from metagpt.environment.android.env_space import (
    EnvAction,
    EnvActionType,
    EnvObsParams,
    EnvObsType,
    EnvObsValType,
)
from metagpt.environment.base_env import ExtEnv, mark_as_readable, mark_as_writeable


class AndroidExtEnv(ExtEnv):
    width: int = Field(default=720, description="device screen width")
    height: int = Field(default=1080, description="device screen height")

    # NOTE(review): __init__ and the adb helper methods/properties live in the
    # parts of the class that the patch does not show in full.

    def reset(
        self,
        *,
        seed: Optional[int] = None,
        options: Optional[dict[str, Any]] = None,
    ) -> tuple[dict[str, Any], dict[str, Any]]:
        """Reset the environment and return ``(observation, info)``.

        Args:
            seed: Forwarded unchanged to ``ExtEnv.reset``.
            options: Forwarded unchanged to ``ExtEnv.reset``.

        Returns:
            A 2-tuple of whatever ``_get_obs()`` produces and an empty info dict.
        """
        super().reset(seed=seed, options=options)

        obs = self._get_obs()

        return obs, {}

    def _get_obs(self) -> dict[str, EnvObsValType]:
        """Collect the whole observation of the environment.

        Placeholder: the body is still a stub, so this currently returns ``None``.
        """
        pass

    def observe(self, obs_params: Optional[EnvObsParams] = None) -> Any:
        """Return the observation selected by ``obs_params.obs_type``.

        Args:
            obs_params: Which observation to fetch and where to store it.
                ``None`` is treated like ``EnvObsType.NONE``.

        Returns:
            * ``EnvObsType.GET_SCREENSHOT``: result of ``self.get_screenshot(...)``.
            * ``EnvObsType.GET_XML``: result of ``self.get_xml(...)``.
            * ``EnvObsType.NONE``: result of ``self._get_obs()``.
            * any other value: ``None``.
        """
        obs_type = obs_params.obs_type if obs_params else EnvObsType.NONE

        # Start from a defined value so that every path can reach the return
        # statement; previously the NONE / unknown paths left ``obs`` unbound
        # and ``return obs`` raised UnboundLocalError.
        obs = None
        if obs_type == EnvObsType.NONE:
            obs = self._get_obs()
        elif obs_type == EnvObsType.GET_SCREENSHOT:
            obs = self.get_screenshot(ss_name=obs_params.ss_name, local_save_dir=obs_params.local_save_dir)
        elif obs_type == EnvObsType.GET_XML:
            obs = self.get_xml(xml_name=obs_params.xml_name, local_save_dir=obs_params.local_save_dir)
        return obs

    def step(self, action: EnvAction) -> tuple[dict[str, Any], float, bool, bool, dict[str, Any]]:
        """Execute ``action`` and return a 5-tuple step result.

        Returns:
            ``(obs, reward, terminated, truncated, info)`` where ``obs`` is an
            empty dict, ``reward`` is always ``1.0``, both flags are ``False``
            and ``info`` carries the action result under the ``"res"`` key.
        """
        res = self._execute_env_action(action)

        obs = {}

        ret = (obs, 1.0, False, False, {"res": res})
        return ret

    def _execute_env_action(self, action: EnvAction):
        """Dispatch ``action`` to the matching device operation.

        The operations themselves (``system_back``, ``system_tap``, ...) are
        methods of this class that are defined outside the visible patch hunks.

        Returns:
            The return value of the invoked operation, or ``None`` for
            ``EnvActionType.NONE`` and for unrecognized action types.
        """
        action_type = action.action_type
        res = None
        if action_type == EnvActionType.NONE:
            pass
        elif action_type == EnvActionType.SYSTEM_BACK:
            res = self.system_back()
        elif action_type == EnvActionType.SYSTEM_TAP:
            res = self.system_tap(x=action.coord[0], y=action.coord[1])
        elif action_type == EnvActionType.USER_INPUT:
            res = self.user_input(input_txt=action.input_txt)
        elif action_type == EnvActionType.USER_LONGPRESS:
            res = self.user_longpress(x=action.coord[0], y=action.coord[1])
        elif action_type == EnvActionType.USER_SWIPE:
            res = self.user_swipe(x=action.coord[0], y=action.coord[1], orient=action.orient, dist=action.dist)
        elif action_type == EnvActionType.USER_SWIPE_TO:
            res = self.user_swipe_to(start=action.coord, end=action.tgt_coord)
        return res
# metagpt/environment/android/env_space.py
#
# Action / observation parameter models and gymnasium space builders for the
# Android environment. Reformatted from the collapsed patch text.
import numpy as np
import numpy.typing as npt
from gymnasium import spaces
from pydantic import ConfigDict, Field, field_validator

from metagpt.environment.base_env_space import (
    BaseEnvAction,
    BaseEnvActionType,
    BaseEnvObsParams,
    BaseEnvObsType,
)


class EnvActionType(BaseEnvActionType):
    """Identifiers of the actions that can be executed on the device."""

    NONE = 0  # no action to run, just get observation

    SYSTEM_BACK = 1
    SYSTEM_TAP = 2
    USER_INPUT = 3
    USER_LONGPRESS = 4
    USER_SWIPE = 5
    USER_SWIPE_TO = 6


class EnvAction(BaseEnvAction):
    """Parameters of a single action; which fields matter depends on ``action_type``."""

    # numpy arrays are not a pydantic-native type
    model_config = ConfigDict(arbitrary_types_allowed=True)

    action_type: int = Field(default=EnvActionType.NONE, description="action type")
    coord: npt.NDArray[np.int64] = Field(
        default_factory=lambda: np.zeros(2, dtype=np.int64), description="operation coordinate"
    )
    tgt_coord: npt.NDArray[np.int64] = Field(
        default_factory=lambda: np.zeros(2, dtype=np.int64), description="target operation coordinate"
    )
    input_txt: str = Field(default="", description="user input text")
    orient: str = Field(default="up", description="swipe orient")
    dist: str = Field(default="medium", description="swipe dist")

    @field_validator("coord", "tgt_coord", mode="before")
    @classmethod
    def check_coord(cls, coord) -> npt.NDArray[np.int64]:
        """Coerce list/tuple coordinates to ``np.ndarray``.

        An input that already is an ``np.ndarray`` is returned unchanged.
        Previously that case fell off the end of the function, so the
        validator handed ``None`` to the field.
        """
        return np.asarray(coord)


class EnvObsType(BaseEnvObsType):
    """Identifiers of the observations that can be requested."""

    NONE = 0  # get whole observation from env

    GET_SCREENSHOT = 1
    GET_XML = 2


class EnvObsParams(BaseEnvObsParams):
    """Parameters of an observation request."""

    model_config = ConfigDict(arbitrary_types_allowed=True)

    obs_type: int = Field(default=EnvObsType.NONE, description="observation type")
    ss_name: str = Field(default="", description="screenshot file name")
    xml_name: str = Field(default="", description="xml file name")
    local_save_dir: str = Field(default="", description="local dir to save file")


# Type of a single value inside an observation dict.
EnvObsValType = str


def get_observation_space() -> spaces.Dict:
    """Return the observation space: two text entries, ``screenshot`` and ``xml``."""
    space = spaces.Dict({"screenshot": spaces.Text(256), "xml": spaces.Text(256)})
    return space


def get_action_space(device_shape: tuple[int, int]) -> spaces.Dict:
    """Return the action space whose entries mirror the fields of ``EnvAction``.

    Args:
        device_shape: Upper bounds for the two coordinate components; element 0
            bounds the first component and element 1 the second.

    Returns:
        A ``spaces.Dict`` with ``action_type``, ``coord``, ``tgt_coord``,
        ``input_txt``, ``orient`` and ``dist`` entries.
    """

    def _coord_box() -> spaces.Box:
        # A fresh Box per entry. dtype is passed explicitly because Box
        # defaults to float32, which would not match the int64 bounds or the
        # int64 coordinate fields of EnvAction.
        low = np.array([0, 0], dtype=np.int64)
        high = np.array([device_shape[0], device_shape[1]], dtype=np.int64)
        return spaces.Box(low, high, dtype=np.int64)

    space = spaces.Dict(
        {
            "action_type": spaces.Discrete(len(EnvActionType)),
            "coord": _coord_box(),
            "tgt_coord": _coord_box(),
            "input_txt": spaces.Text(256),
            "orient": spaces.Text(16),
            "dist": spaces.Text(16),
        }
    )
    return space