mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-05-15 11:02:36 +02:00
add android_env obs/action space
This commit is contained in:
parent
5180980b8a
commit
cbb2e66cd9
2 changed files with 149 additions and 15 deletions
|
|
@ -9,8 +9,14 @@ from typing import Any, Optional
|
|||
from pydantic import Field
|
||||
|
||||
from metagpt.environment.android.const import ADB_EXEC_FAIL
|
||||
from metagpt.environment.android.env_space import (
|
||||
EnvAction,
|
||||
EnvActionType,
|
||||
EnvObsParams,
|
||||
EnvObsType,
|
||||
EnvObsValType,
|
||||
)
|
||||
from metagpt.environment.base_env import ExtEnv, mark_as_readable, mark_as_writeable
|
||||
from metagpt.environment.base_env_space import BaseEnvAction, BaseEnvObsParams
|
||||
|
||||
|
||||
class AndroidExtEnv(ExtEnv):
|
||||
|
|
@ -20,20 +26,6 @@ class AndroidExtEnv(ExtEnv):
|
|||
width: int = Field(default=720, description="device screen width")
|
||||
height: int = Field(default=1080, description="device screen height")
|
||||
|
||||
def reset(
    self,
    *,
    seed: Optional[int] = None,
    options: Optional[dict[str, Any]] = None,
) -> tuple[dict[str, Any], dict[str, Any]]:
    """Placeholder reset: does no work and returns ``None``.

    Both keyword-only arguments are accepted and ignored.
    """
    return None
|
||||
|
||||
def observe(self, obs_params: Optional[BaseEnvObsParams] = None) -> Any:
    """Placeholder observation hook: ignores ``obs_params`` and returns ``None``."""
    return None
|
||||
|
||||
def step(self, action: BaseEnvAction) -> tuple[dict[str, Any], float, bool, bool, dict[str, Any]]:
    """Placeholder step hook: ignores ``action`` and returns ``None``."""
    return None
|
||||
|
||||
def __init__(self, **data: Any):
|
||||
super().__init__(**data)
|
||||
if data.get("device_id"):
|
||||
|
|
@ -41,6 +33,58 @@ class AndroidExtEnv(ExtEnv):
|
|||
self.width = data.get("width", width)
|
||||
self.height = data.get("height", height)
|
||||
|
||||
def reset(
    self,
    *,
    seed: Optional[int] = None,
    options: Optional[dict[str, Any]] = None,
) -> tuple[dict[str, Any], dict[str, Any]]:
    """Reset the environment and return ``(observation, info)``.

    The parent class ``reset`` is invoked first with the same ``seed`` and
    ``options``. The observation is whatever ``_get_obs`` returns; the info
    dict is always empty.
    """
    super().reset(seed=seed, options=options)

    info: dict[str, Any] = {}
    return self._get_obs(), info
|
||||
|
||||
def _get_obs(self) -> dict[str, EnvObsValType]:
    """Placeholder for building the observation dict; currently returns ``None``."""
    return None
|
||||
|
||||
def observe(self, obs_params: Optional[EnvObsParams] = None) -> Any:
    """Return the observation selected by ``obs_params``.

    Args:
        obs_params: Describes which observation to fetch and the file name /
            local directory to use. When omitted (or falsy), the request is
            treated as ``EnvObsType.NONE``.

    Returns:
        The result of ``get_screenshot`` for ``EnvObsType.GET_SCREENSHOT``,
        the result of ``get_xml`` for ``EnvObsType.GET_XML``, and ``None``
        for ``EnvObsType.NONE`` or any unrecognised observation type.
    """
    obs_type = obs_params.obs_type if obs_params else EnvObsType.NONE

    # Bind the result up front: the NONE branch (and any unknown type) assigns
    # nothing, which previously made the final ``return`` raise
    # UnboundLocalError.
    obs = None
    if obs_type == EnvObsType.GET_SCREENSHOT:
        obs = self.get_screenshot(ss_name=obs_params.ss_name, local_save_dir=obs_params.local_save_dir)
    elif obs_type == EnvObsType.GET_XML:
        obs = self.get_xml(xml_name=obs_params.xml_name, local_save_dir=obs_params.local_save_dir)
    return obs
|
||||
|
||||
def step(self, action: EnvAction) -> tuple[dict[str, Any], float, bool, bool, dict[str, Any]]:
    """Run ``action`` and return a five-element step tuple.

    The tuple is, in order: an empty observation dict, the constant ``1.0``,
    two ``False`` flags, and an info dict holding the action's result under
    the ``"res"`` key.
    """
    info = {"res": self._execute_env_action(action)}
    return {}, 1.0, False, False, info
|
||||
|
||||
def _execute_env_action(self, action: EnvAction):
    """Dispatch ``action`` to the matching device operation and return its result.

    ``EnvActionType.NONE`` and any unrecognised action type run nothing and
    return ``None``.
    """
    kind = action.action_type

    if kind == EnvActionType.SYSTEM_BACK:
        return self.system_back()
    if kind == EnvActionType.SYSTEM_TAP:
        return self.system_tap(x=action.coord[0], y=action.coord[1])
    if kind == EnvActionType.USER_INPUT:
        return self.user_input(input_txt=action.input_txt)
    if kind == EnvActionType.USER_LONGPRESS:
        return self.user_longpress(x=action.coord[0], y=action.coord[1])
    if kind == EnvActionType.USER_SWIPE:
        return self.user_swipe(x=action.coord[0], y=action.coord[1], orient=action.orient, dist=action.dist)
    if kind == EnvActionType.USER_SWIPE_TO:
        return self.user_swipe_to(start=action.coord, end=action.tgt_coord)

    # Nothing to execute for NONE or an unknown action type.
    return None
|
||||
|
||||
@property
|
||||
def adb_prefix_si(self):
|
||||
"""adb cmd prefix with `device_id` and `shell input`"""
|
||||
|
|
|
|||
90
metagpt/environment/android/env_space.py
Normal file
90
metagpt/environment/android/env_space.py
Normal file
|
|
@ -0,0 +1,90 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc   : observation and action space definitions for the Android environment
|
||||
|
||||
|
||||
import numpy as np
|
||||
import numpy.typing as npt
|
||||
from gymnasium import spaces
|
||||
from pydantic import ConfigDict, Field, field_validator
|
||||
|
||||
from metagpt.environment.base_env_space import (
|
||||
BaseEnvAction,
|
||||
BaseEnvActionType,
|
||||
BaseEnvObsParams,
|
||||
BaseEnvObsType,
|
||||
)
|
||||
|
||||
|
||||
class EnvActionType(BaseEnvActionType):
    """Integer codes for the actions the Android environment can be asked to run."""

    # no action to run, just get observation
    NONE = 0

    # system-level operations
    SYSTEM_BACK = 1
    SYSTEM_TAP = 2

    # user-level operations
    USER_INPUT = 3
    USER_LONGPRESS = 4
    USER_SWIPE = 5
    USER_SWIPE_TO = 6
|
||||
|
||||
|
||||
class EnvAction(BaseEnvAction):
    """One action request for the Android environment.

    Only the fields relevant to the chosen ``action_type`` need to be set;
    every field has a default.
    """

    # numpy arrays are not pydantic-native types, so arbitrary types are allowed.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    action_type: int = Field(default=EnvActionType.NONE, description="action type")
    coord: npt.NDArray[np.int64] = Field(
        default_factory=lambda: np.zeros(2, dtype=np.int64), description="operation coordinate"
    )
    tgt_coord: npt.NDArray[np.int64] = Field(
        default_factory=lambda: np.zeros(2, dtype=np.int64), description="target operation coordinate"
    )
    input_txt: str = Field(default="", description="user input text")
    orient: str = Field(default="up", description="swipe orient")
    dist: str = Field(default="medium", description="swipe dist")

    @field_validator("coord", "tgt_coord", mode="before")
    @classmethod
    def check_coord(cls, coord) -> npt.NDArray[np.int64]:
        """Coerce a coordinate value into a numpy array before validation.

        Args:
            coord: Either an existing ``np.ndarray`` or any value accepted by
                ``np.array`` (e.g. a list or tuple).

        Returns:
            ``coord`` itself when it is already an ndarray, otherwise
            ``np.array(coord)``.
        """
        # A "before" validator must always return the value to validate.
        # Falling through without a return here used to replace an ndarray
        # input with None, which then failed field validation.
        if isinstance(coord, np.ndarray):
            return coord
        return np.array(coord)
|
||||
|
||||
|
||||
class EnvObsType(BaseEnvObsType):
    """Integer codes for the kinds of observation that can be requested."""

    # get whole observation from env
    NONE = 0

    # individual observation kinds
    GET_SCREENSHOT = 1
    GET_XML = 2
|
||||
|
||||
|
||||
class EnvObsParams(BaseEnvObsParams):
    """Parameters describing which observation to fetch and where to save it.

    Every field has a default, so an instance can be created with no arguments.
    """

    model_config = ConfigDict(arbitrary_types_allowed=True)

    # which observation to fetch
    obs_type: int = Field(default=EnvObsType.NONE, description="observation type")

    # file names and destination directory for saved artifacts
    ss_name: str = Field(default="", description="screenshot file name")
    xml_name: str = Field(default="", description="xml file name")
    local_save_dir: str = Field(default="", description="local dir to save file")
|
||||
|
||||
|
||||
EnvObsValType = str
|
||||
|
||||
|
||||
def get_observation_space() -> spaces.Dict:
    """Build the observation space: ``"screenshot"`` and ``"xml"`` entries, each ``spaces.Text(256)``."""
    return spaces.Dict({key: spaces.Text(256) for key in ("screenshot", "xml")})
|
||||
|
||||
|
||||
def get_action_space(device_shape: tuple[int, int]) -> spaces.Dict:
    """Build the action space for a device of the given shape.

    Args:
        device_shape: Two integers used as the upper bounds of the first and
            second coordinate components; the lower bounds are both 0.

    Returns:
        A ``spaces.Dict`` with one entry per ``EnvAction`` field:
        ``action_type`` (discrete over ``EnvActionType``), ``coord`` and
        ``tgt_coord`` (integer boxes), and the text fields ``input_txt``,
        ``orient`` and ``dist``.
    """

    def coord_box() -> spaces.Box:
        # Box defaults to float32, which would downcast the int64 bounds and
        # make samples floats; the coordinate fields are int64, so pin the
        # dtype explicitly. A fresh Box (and fresh bound arrays) is built for
        # each entry so the two spaces share no state.
        lower = np.array([0, 0], dtype=np.int64)
        upper = np.array([device_shape[0], device_shape[1]], dtype=np.int64)
        return spaces.Box(lower, upper, dtype=np.int64)

    return spaces.Dict(
        {
            "action_type": spaces.Discrete(len(EnvActionType)),
            "coord": coord_box(),
            "tgt_coord": coord_box(),
            "input_txt": spaces.Text(256),
            "orient": spaces.Text(16),
            "dist": spaces.Text(16),
        }
    )
|
||||
Loading…
Add table
Add a link
Reference in a new issue