mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-05-01 20:03:28 +02:00
fix conflict
This commit is contained in:
commit
f9d64d4184
3 changed files with 94 additions and 2 deletions
|
|
@ -75,6 +75,7 @@ class Config(CLIParams, YamlModel):
|
|||
iflytek_api_key: str = ""
|
||||
azure_tts_subscription_key: str = ""
|
||||
azure_tts_region: str = ""
|
||||
other: dict = dict() # other dict
|
||||
|
||||
@classmethod
|
||||
def from_home(cls, path):
|
||||
|
|
@ -136,7 +137,6 @@ class Config(CLIParams, YamlModel):
|
|||
else:
|
||||
return self.other.get(key, default_value)
|
||||
|
||||
|
||||
def get_openai_llm(self) -> Optional[LLMConfig]:
|
||||
"""Get OpenAI LLMConfig by name. If no OpenAI, raise Exception"""
|
||||
if self.llm.api_type == LLMType.OPENAI:
|
||||
|
|
|
|||
92
metagpt/environment/android_env/env_space.py
Normal file
92
metagpt/environment/android_env/env_space.py
Normal file
|
|
@ -0,0 +1,92 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc :
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
import numpy as np
|
||||
import numpy.typing as npt
|
||||
from gymnasium import spaces
|
||||
from pydantic import ConfigDict, Field, field_validator
|
||||
|
||||
from metagpt.environment.base_env_space import (
|
||||
BaseEnvAction,
|
||||
BaseEnvActionType,
|
||||
BaseEnvObsParams,
|
||||
BaseEnvObsType,
|
||||
)
|
||||
|
||||
|
||||
class EnvActionType(BaseEnvActionType):
|
||||
NONE = 0 # no action to run, just get observation
|
||||
|
||||
SYSTEM_BACK = 1
|
||||
SYSTEM_TAP = 2
|
||||
USER_INPUT = 3
|
||||
USER_LONGPRESS = 4
|
||||
USER_SWIPE = 5
|
||||
USER_SWIPE_TO = 6
|
||||
|
||||
|
||||
class EnvAction(BaseEnvAction):
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
action_type: int = Field(default=EnvActionType.NONE, description="action type")
|
||||
coord: npt.NDArray[np.int64] = Field(
|
||||
default_factory=lambda: np.zeros(2, dtype=np.int64), description="operation coordinate"
|
||||
)
|
||||
tgt_coord: npt.NDArray[np.int64] = Field(
|
||||
default_factory=lambda: np.zeros(2, dtype=np.int64), description="target operation coordinate"
|
||||
)
|
||||
input_txt: str = Field(default="", description="user input text")
|
||||
orient: str = Field(default="up", description="swipe orient")
|
||||
dist: str = Field(default="medium", description="swipe dist")
|
||||
|
||||
@field_validator("coord", "tgt_coord", mode="before")
|
||||
@classmethod
|
||||
def check_coord(cls, coord) -> npt.NDArray[np.int64]:
|
||||
if not isinstance(coord, np.ndarray):
|
||||
return np.array(coord)
|
||||
|
||||
|
||||
class EnvObsType(BaseEnvObsType):
|
||||
NONE = 0 # get whole observation from env
|
||||
|
||||
GET_SCREENSHOT = 1
|
||||
GET_XML = 2
|
||||
|
||||
|
||||
class EnvObsParams(BaseEnvObsParams):
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
obs_type: int = Field(default=EnvObsType.NONE, description="observation type")
|
||||
ss_name: str = Field(default="", description="screenshot file name")
|
||||
xml_name: str = Field(default="", description="xml file name")
|
||||
local_save_dir: Union[str, Path] = Field(default="", description="local dir to save file")
|
||||
|
||||
|
||||
EnvObsValType = str
|
||||
|
||||
|
||||
def get_observation_space() -> spaces.Dict:
|
||||
space = spaces.Dict({"screenshot": spaces.Text(256), "xml": spaces.Text(256)})
|
||||
return space
|
||||
|
||||
|
||||
def get_action_space(device_shape: tuple[int, int]) -> spaces.Dict:
|
||||
space = spaces.Dict(
|
||||
{
|
||||
"action_type": spaces.Discrete(len(EnvActionType)),
|
||||
"coord": spaces.Box(
|
||||
np.array([0, 0], dtype=np.int64), np.array([device_shape[0], device_shape[1]], dtype=np.int64)
|
||||
),
|
||||
"tgt_coord": spaces.Box(
|
||||
np.array([0, 0], dtype=np.int64), np.array([device_shape[0], device_shape[1]], dtype=np.int64)
|
||||
),
|
||||
"input_txt": spaces.Text(256),
|
||||
"orient": spaces.Text(16),
|
||||
"dist": spaces.Text(16),
|
||||
}
|
||||
)
|
||||
return space
|
||||
|
|
@ -22,7 +22,7 @@ class EnvAPIRegistry(BaseModel):
|
|||
|
||||
def get(self, api_name: str):
|
||||
if api_name not in self.registry:
|
||||
raise ValueError
|
||||
raise KeyError(f"api_name: {api_name} not found")
|
||||
return self.registry.get(api_name)
|
||||
|
||||
def __getitem__(self, api_name: str) -> Callable:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue