update android_env to simplify code

This commit is contained in:
better629 2024-03-27 22:25:22 +08:00
parent 9604dec795
commit c1308f98ba
18 changed files with 200 additions and 209 deletions

View file

@ -2,6 +2,8 @@
# -*- coding: utf-8 -*-
# @Desc :
from pathlib import Path
from typing import Union
import numpy as np
import numpy.typing as npt
@ -61,7 +63,7 @@ class EnvObsParams(BaseEnvObsParams):
obs_type: int = Field(default=EnvObsType.NONE, description="observation type")
ss_name: str = Field(default="", description="screenshot file name")
xml_name: str = Field(default="", description="xml file name")
local_save_dir: str = Field(default="", description="local dir to save file")
local_save_dir: Union[str, Path] = Field(default="", description="local dir to save file")
EnvObsValType = str

View file

@ -5,8 +5,11 @@
from pydantic import Field
from metagpt.environment.android_env.android_ext_env import AndroidExtEnv
from metagpt.environment.base_env import Environment
class AndroidEnv(AndroidExtEnv):
class AndroidEnv(AndroidExtEnv, Environment):
"""in order to use actual `reset`&`observe`, inherited order: AndroidExtEnv, Environment"""
rows: int = Field(default=0, description="rows of a grid on the screenshot")
cols: int = Field(default=0, description="cols of a grid on the screenshot")

View file

@ -8,16 +8,18 @@ from typing import Any, Optional
from pydantic import Field
from metagpt.const import ADB_EXEC_FAIL
from metagpt.environment.base_env import (
Environment,
ExtEnv,
mark_as_readable,
mark_as_writeable,
from metagpt.environment.android.const import ADB_EXEC_FAIL
from metagpt.environment.android.env_space import (
EnvAction,
EnvActionType,
EnvObsParams,
EnvObsType,
EnvObsValType,
)
from metagpt.environment.base_env import ExtEnv, mark_as_readable, mark_as_writeable
class AndroidExtEnv(Environment, ExtEnv):
class AndroidExtEnv(ExtEnv):
device_id: Optional[str] = Field(default=None)
screenshot_dir: Optional[Path] = Field(default=None)
xml_dir: Optional[Path] = Field(default=None)
@ -26,11 +28,70 @@ class AndroidExtEnv(Environment, ExtEnv):
def __init__(self, **data: Any):
super().__init__(**data)
if data.get("device_id"):
device_id = data.get("device_id")
if device_id:
devices = self.list_devices()
if device_id not in devices:
raise RuntimeError(f"device-id: {device_id} not found")
(width, height) = self.device_shape
self.width = data.get("width", width)
self.height = data.get("height", height)
self.create_device_path(self.screenshot_dir)
self.create_device_path(self.xml_dir)
def reset(
self,
*,
seed: Optional[int] = None,
options: Optional[dict[str, Any]] = None,
) -> tuple[dict[str, Any], dict[str, Any]]:
super().reset(seed=seed, options=options)
obs = self._get_obs()
return obs, {}
def _get_obs(self) -> dict[str, EnvObsValType]:
pass
def observe(self, obs_params: Optional[EnvObsParams] = None) -> Any:
obs_type = obs_params.obs_type if obs_params else EnvObsType.NONE
if obs_type == EnvObsType.NONE:
pass
elif obs_type == EnvObsType.GET_SCREENSHOT:
obs = self.get_screenshot(ss_name=obs_params.ss_name, local_save_dir=obs_params.local_save_dir)
elif obs_type == EnvObsType.GET_XML:
obs = self.get_xml(xml_name=obs_params.xml_name, local_save_dir=obs_params.local_save_dir)
return obs
def step(self, action: EnvAction) -> tuple[dict[str, Any], float, bool, bool, dict[str, Any]]:
res = self._execute_env_action(action)
obs = {}
ret = (obs, 1.0, False, False, {"res": res})
return ret
def _execute_env_action(self, action: EnvAction):
action_type = action.action_type
res = None
if action_type == EnvActionType.NONE:
pass
elif action_type == EnvActionType.SYSTEM_BACK:
res = self.system_back()
elif action_type == EnvActionType.SYSTEM_TAP:
res = self.system_tap(x=action.coord[0], y=action.coord[1])
elif action_type == EnvActionType.USER_INPUT:
res = self.user_input(input_txt=action.input_txt)
elif action_type == EnvActionType.USER_LONGPRESS:
res = self.user_longpress(x=action.coord[0], y=action.coord[1])
elif action_type == EnvActionType.USER_SWIPE:
res = self.user_swipe(x=action.coord[0], y=action.coord[1], orient=action.orient, dist=action.dist)
elif action_type == EnvActionType.USER_SWIPE_TO:
res = self.user_swipe_to(start=action.coord, end=action.tgt_coord)
return res
@property
def adb_prefix_si(self):
"""adb cmd prefix with `device_id` and `shell input`"""
@ -54,6 +115,12 @@ class AndroidExtEnv(Environment, ExtEnv):
exec_res = res.stdout.strip()
return exec_res
def create_device_path(self, folder_path: Path):
adb_cmd = f"{self.adb_prefix_shell} mkdir {folder_path} -p"
res = self.execute_adb_with_cmd(adb_cmd)
if res == ADB_EXEC_FAIL:
raise RuntimeError(f"create device path: {folder_path} failed")
@property
def device_shape(self) -> tuple[int, int]:
adb_cmd = f"{self.adb_prefix_shell} wm size"