add android_env obs/action space

This commit is contained in:
better629 2024-03-27 11:02:22 +08:00
parent 5180980b8a
commit cbb2e66cd9
2 changed files with 149 additions and 15 deletions

View file

@@ -9,8 +9,14 @@ from typing import Any, Optional
from pydantic import Field
from metagpt.environment.android.const import ADB_EXEC_FAIL
from metagpt.environment.android.env_space import (
EnvAction,
EnvActionType,
EnvObsParams,
EnvObsType,
EnvObsValType,
)
from metagpt.environment.base_env import ExtEnv, mark_as_readable, mark_as_writeable
from metagpt.environment.base_env_space import BaseEnvAction, BaseEnvObsParams
class AndroidExtEnv(ExtEnv):
@@ -20,20 +26,6 @@ class AndroidExtEnv(ExtEnv):
width: int = Field(default=720, description="device screen width")
height: int = Field(default=1080, description="device screen height")
def reset(
    self,
    *,
    seed: Optional[int] = None,
    options: Optional[dict[str, Any]] = None,
) -> tuple[dict[str, Any], dict[str, Any]]:
    """Placeholder reset hook that accepts ``seed`` and ``options`` and does nothing.

    The body is empty, so the call returns ``None`` rather than the
    annotated ``(observation, info)`` tuple.
    """
    return None
def observe(self, obs_params: Optional[BaseEnvObsParams] = None) -> Any:
    """Placeholder observation hook; ignores ``obs_params`` and returns ``None``."""
    return None
def step(self, action: BaseEnvAction) -> tuple[dict[str, Any], float, bool, bool, dict[str, Any]]:
    """Placeholder step hook; ignores ``action`` and returns ``None``.

    Nothing is executed, so the annotated 5-tuple is not produced here.
    """
    return None
def __init__(self, **data: Any):
super().__init__(**data)
if data.get("device_id"):
@@ -41,6 +33,58 @@ class AndroidExtEnv(ExtEnv):
self.width = data.get("width", width)
self.height = data.get("height", height)
def reset(
    self,
    *,
    seed: Optional[int] = None,
    options: Optional[dict[str, Any]] = None,
) -> tuple[dict[str, Any], dict[str, Any]]:
    """Reset the environment and return ``(observation, info)``.

    Args:
        seed: Forwarded unchanged to the parent class ``reset``.
        options: Forwarded unchanged to the parent class ``reset``.

    Returns:
        A 2-tuple of whatever ``self._get_obs()`` yields and an empty
        info dict.
    """
    # Let the base environment handle seeding/options first.
    super().reset(seed=seed, options=options)
    return self._get_obs(), {}
def _get_obs(self) -> dict[str, EnvObsValType]:
    """Collect the current observation (not implemented yet).

    The body is a stub, so ``None`` is returned instead of the annotated
    mapping.
    """
    return None
def observe(self, obs_params: Optional[EnvObsParams] = None) -> Any:
    """Return an observation selected by ``obs_params.obs_type``.

    Args:
        obs_params: Observation request. When omitted (or falsy) the
            request is treated as ``EnvObsType.NONE``.

    Returns:
        * ``EnvObsType.GET_SCREENSHOT``: the result of ``self.get_screenshot``
          called with ``ss_name`` and ``local_save_dir``.
        * ``EnvObsType.GET_XML``: the result of ``self.get_xml`` called with
          ``xml_name`` and ``local_save_dir``.
        * ``EnvObsType.NONE``: the result of ``self._get_obs()``, i.e. the
          same observation ``reset`` returns.
        * any other value: ``None``.
    """
    obs_type = obs_params.obs_type if obs_params else EnvObsType.NONE

    if obs_type == EnvObsType.GET_SCREENSHOT:
        return self.get_screenshot(ss_name=obs_params.ss_name, local_save_dir=obs_params.local_save_dir)
    if obs_type == EnvObsType.GET_XML:
        return self.get_xml(xml_name=obs_params.xml_name, local_save_dir=obs_params.local_save_dir)
    if obs_type == EnvObsType.NONE:
        # Previously this path fell through to `return obs` with `obs`
        # never assigned, raising UnboundLocalError.
        return self._get_obs()

    # Unknown observation type: mirror `_execute_env_action`, which yields
    # None for unrecognised requests instead of raising.
    return None
def step(self, action: EnvAction) -> tuple[dict[str, Any], float, bool, bool, dict[str, Any]]:
    """Execute ``action`` and return a 5-element step tuple.

    Args:
        action: The action passed on to ``self._execute_env_action``.

    Returns:
        ``({}, 1.0, False, False, {"res": <action result>})``. The
        observation is always an empty dict, the reward is fixed at
        ``1.0`` and both boolean flags are ``False``.
        NOTE(review): the two booleans presumably correspond to gymnasium's
        ``terminated`` / ``truncated`` flags; confirm against the base env.
    """
    info = {"res": self._execute_env_action(action)}
    return {}, 1.0, False, False, info
def _execute_env_action(self, action: EnvAction):
    """Dispatch ``action`` to the matching device-control method.

    Args:
        action: Carries ``action_type`` plus the arguments used by the
            selected operation (``coord``, ``tgt_coord``, ``input_txt``,
            ``orient``, ``dist``).

    Returns:
        Whatever the invoked method returns, or ``None`` when the action
        type is ``EnvActionType.NONE`` or is not recognised.
    """
    kind = action.action_type

    # Nothing to run for the explicit no-op action.
    if kind == EnvActionType.NONE:
        return None

    if kind == EnvActionType.SYSTEM_BACK:
        return self.system_back()

    if kind == EnvActionType.SYSTEM_TAP:
        return self.system_tap(x=action.coord[0], y=action.coord[1])

    if kind == EnvActionType.USER_INPUT:
        return self.user_input(input_txt=action.input_txt)

    if kind == EnvActionType.USER_LONGPRESS:
        return self.user_longpress(x=action.coord[0], y=action.coord[1])

    if kind == EnvActionType.USER_SWIPE:
        return self.user_swipe(x=action.coord[0], y=action.coord[1], orient=action.orient, dist=action.dist)

    if kind == EnvActionType.USER_SWIPE_TO:
        return self.user_swipe_to(start=action.coord, end=action.tgt_coord)

    # Unrecognised action type: no operation performed.
    return None
@property
def adb_prefix_si(self):
"""adb cmd prefix with `device_id` and `shell input`"""

View file

@@ -0,0 +1,90 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc :
import numpy as np
import numpy.typing as npt
from gymnasium import spaces
from pydantic import ConfigDict, Field, field_validator
from metagpt.environment.base_env_space import (
BaseEnvAction,
BaseEnvActionType,
BaseEnvObsParams,
BaseEnvObsType,
)
class EnvActionType(BaseEnvActionType):
    """Integer codes for the actions an ``EnvAction`` can request."""

    # No device operation is run; only an observation is wanted.
    NONE = 0
    SYSTEM_BACK = 1
    SYSTEM_TAP = 2
    USER_INPUT = 3
    USER_LONGPRESS = 4
    USER_SWIPE = 5
    USER_SWIPE_TO = 6
class EnvAction(BaseEnvAction):
    """Action request for the Android environment.

    ``action_type`` selects the operation (see ``EnvActionType``); the other
    fields carry that operation's arguments and keep their defaults when
    the operation does not use them.
    """

    # numpy arrays are not natively known to pydantic, hence arbitrary types.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    action_type: int = Field(default=EnvActionType.NONE, description="action type")
    coord: npt.NDArray[np.int64] = Field(
        default_factory=lambda: np.zeros(2, dtype=np.int64), description="operation coordinate"
    )
    tgt_coord: npt.NDArray[np.int64] = Field(
        default_factory=lambda: np.zeros(2, dtype=np.int64), description="target operation coordinate"
    )
    input_txt: str = Field(default="", description="user input text")
    orient: str = Field(default="up", description="swipe orient")
    dist: str = Field(default="medium", description="swipe dist")

    @field_validator("coord", "tgt_coord", mode="before")
    @classmethod
    def check_coord(cls, coord) -> npt.NDArray[np.int64]:
        """Coerce a coordinate given as a list/tuple into a numpy array.

        Values that already are ``np.ndarray`` are passed through
        unchanged. Previously that path had no ``return`` and produced
        ``None``, discarding the caller's coordinate.
        """
        if isinstance(coord, np.ndarray):
            return coord
        return np.array(coord)
class EnvObsType(BaseEnvObsType):
    """Integer codes selecting which observation to fetch."""

    # Get the whole observation from the environment.
    NONE = 0
    GET_SCREENSHOT = 1
    GET_XML = 2
class EnvObsParams(BaseEnvObsParams):
    """Parameters of an observation request.

    ``obs_type`` picks the observation (see ``EnvObsType``). ``ss_name`` and
    ``xml_name`` name the screenshot / xml file respectively, and
    ``local_save_dir`` is the local directory to save the file in.
    """

    model_config = ConfigDict(arbitrary_types_allowed=True)

    obs_type: int = Field(default=EnvObsType.NONE, description="observation type")
    ss_name: str = Field(default="", description="screenshot file name")
    xml_name: str = Field(default="", description="xml file name")
    local_save_dir: str = Field(default="", description="local dir to save file")
# Type alias for a single observation value; currently a plain string.
EnvObsValType = str
def get_observation_space() -> spaces.Dict:
    """Build the observation space of the Android environment.

    Returns:
        A ``spaces.Dict`` with two text entries, ``"screenshot"`` and
        ``"xml"``, each created as ``spaces.Text(256)``.
    """
    entries = {
        "screenshot": spaces.Text(256),
        "xml": spaces.Text(256),
    }
    return spaces.Dict(entries)
def get_action_space(device_shape: tuple[int, int]) -> spaces.Dict:
    """Build the action space of the Android environment.

    Args:
        device_shape: Two integers used as the upper bounds of the first and
            second coordinate components respectively.

    Returns:
        A ``spaces.Dict`` whose keys mirror the fields of ``EnvAction``:
        ``action_type`` (discrete over ``EnvActionType``), ``coord`` and
        ``tgt_coord`` (integer boxes from ``[0, 0]`` to ``device_shape``), and
        the text fields ``input_txt``, ``orient`` and ``dist``.
    """
    lower = np.array([0, 0], dtype=np.int64)
    upper = np.array([device_shape[0], device_shape[1]], dtype=np.int64)

    def _coord_box() -> spaces.Box:
        # `Box` defaults to float32 when no dtype is given; state int64
        # explicitly so samples and membership checks agree with the int64
        # coordinates carried by `EnvAction`.
        return spaces.Box(lower.copy(), upper.copy(), dtype=np.int64)

    return spaces.Dict(
        {
            "action_type": spaces.Discrete(len(EnvActionType)),
            "coord": _coord_box(),
            "tgt_coord": _coord_box(),
            "input_txt": spaces.Text(256),
            "orient": spaces.Text(16),
            "dist": spaces.Text(16),
        }
    )