diff --git a/metagpt/config2.py b/metagpt/config2.py index 757e52506..eb4c38368 100644 --- a/metagpt/config2.py +++ b/metagpt/config2.py @@ -75,6 +75,7 @@ class Config(CLIParams, YamlModel): iflytek_api_key: str = "" azure_tts_subscription_key: str = "" azure_tts_region: str = "" + other: dict = dict() # other dict @classmethod def from_home(cls, path): @@ -136,7 +137,6 @@ class Config(CLIParams, YamlModel): else: return self.other.get(key, default_value) - def get_openai_llm(self) -> Optional[LLMConfig]: """Get OpenAI LLMConfig by name. If no OpenAI, raise Exception""" if self.llm.api_type == LLMType.OPENAI: diff --git a/metagpt/environment/android_env/env_space.py b/metagpt/environment/android_env/env_space.py new file mode 100644 index 000000000..9580e3a7d --- /dev/null +++ b/metagpt/environment/android_env/env_space.py @@ -0,0 +1,92 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : + +from pathlib import Path +from typing import Union + +import numpy as np +import numpy.typing as npt +from gymnasium import spaces +from pydantic import ConfigDict, Field, field_validator + +from metagpt.environment.base_env_space import ( + BaseEnvAction, + BaseEnvActionType, + BaseEnvObsParams, + BaseEnvObsType, +) + + +class EnvActionType(BaseEnvActionType): + NONE = 0 # no action to run, just get observation + + SYSTEM_BACK = 1 + SYSTEM_TAP = 2 + USER_INPUT = 3 + USER_LONGPRESS = 4 + USER_SWIPE = 5 + USER_SWIPE_TO = 6 + + +class EnvAction(BaseEnvAction): + model_config = ConfigDict(arbitrary_types_allowed=True) + + action_type: int = Field(default=EnvActionType.NONE, description="action type") + coord: npt.NDArray[np.int64] = Field( + default_factory=lambda: np.zeros(2, dtype=np.int64), description="operation coordinate" + ) + tgt_coord: npt.NDArray[np.int64] = Field( + default_factory=lambda: np.zeros(2, dtype=np.int64), description="target operation coordinate" + ) + input_txt: str = Field(default="", description="user input text") + orient: str = Field(default="up", description="swipe orient") + dist: str = Field(default="medium", description="swipe dist") + + @field_validator("coord", "tgt_coord", mode="before") + @classmethod + def check_coord(cls, coord) -> npt.NDArray[np.int64]: + if not isinstance(coord, np.ndarray): + return np.array(coord) + + +class EnvObsType(BaseEnvObsType): + NONE = 0 # get whole observation from env + + GET_SCREENSHOT = 1 + GET_XML = 2 + + +class EnvObsParams(BaseEnvObsParams): + model_config = ConfigDict(arbitrary_types_allowed=True) + + obs_type: int = Field(default=EnvObsType.NONE, description="observation type") + ss_name: str = Field(default="", description="screenshot file name") + xml_name: str = Field(default="", description="xml file name") + local_save_dir: Union[str, Path] = Field(default="", description="local dir to save file") + + +EnvObsValType = str + + +def get_observation_space() -> spaces.Dict: + space = spaces.Dict({"screenshot": spaces.Text(256), "xml": spaces.Text(256)}) + return space + + +def get_action_space(device_shape: tuple[int, int]) -> spaces.Dict: + space = spaces.Dict( + { + "action_type": spaces.Discrete(len(EnvActionType)), + "coord": spaces.Box( + np.array([0, 0], dtype=np.int64), np.array([device_shape[0], device_shape[1]], dtype=np.int64) + ), + "tgt_coord": spaces.Box( + np.array([0, 0], dtype=np.int64), np.array([device_shape[0], device_shape[1]], dtype=np.int64) + ), + "input_txt": spaces.Text(256), + "orient": spaces.Text(16), + "dist": spaces.Text(16), + } + ) + return space diff --git a/metagpt/environment/api/env_api.py b/metagpt/environment/api/env_api.py index 60c3215c8..924f6b104 100644 --- a/metagpt/environment/api/env_api.py +++ b/metagpt/environment/api/env_api.py @@ -22,7 +22,7 @@ class EnvAPIRegistry(BaseModel): def get(self, api_name: str): if api_name not in self.registry: - raise ValueError + raise KeyError(f"api_name: {api_name} not found") return self.registry.get(api_name) def __getitem__(self, api_name: str) -> Callable: