From cf9d86b83263e699f31b0b383ecc80f8a0cc1988 Mon Sep 17 00:00:00 2001 From: kit <101046518@qq.com> Date: Mon, 29 Apr 2024 15:19:42 +0800 Subject: [PATCH] Merge remote-tracking branch 'origin/main' --- .../environment/android/android_ext_env.py | 57 +++++++++---------- .../android/text_icon_localization.py | 2 +- requirements.txt | 4 +- setup.py | 1 + 4 files changed, 29 insertions(+), 35 deletions(-) diff --git a/metagpt/environment/android/android_ext_env.py b/metagpt/environment/android/android_ext_env.py index bf64c0988..060d956a3 100644 --- a/metagpt/environment/android/android_ext_env.py +++ b/metagpt/environment/android/android_ext_env.py @@ -28,22 +28,31 @@ from metagpt.utils.common import download_model from metagpt.const import DEFAULT_WORKSPACE_ROOT +def load_cv_model(device: str = "cpu") -> any: + ocr_detection = pipeline(Tasks.ocr_detection, model="damo/cv_resnet18_ocr-detection-line-level_damo") + ocr_recognition = pipeline(Tasks.ocr_recognition, + model="damo/cv_convnextTiny_ocr-recognition-document_damo") + file_url = "https://huggingface.co/ShilongLiu/GroundingDINO/blob/main/groundingdino_swint_ogc.pth" + target_folder = Path(f"{DEFAULT_WORKSPACE_ROOT}/weights") + file_path = download_model(file_url, target_folder) + groundingdino_model = load_model(file_path, device=device).eval() + return ocr_detection, ocr_recognition, groundingdino_model + + class AndroidExtEnv(ExtEnv): device_id: Optional[str] = Field(default=None) screenshot_dir: Optional[Path] = Field(default=None) xml_dir: Optional[Path] = Field(default=None) width: int = Field(default=720, description="device screen width") height: int = Field(default=1080, description="device screen height") - cv_model_status: dict = Field(default=None, description="Record model loading status") + ocr_detection: any = Field(default=None, description="ocr detection model") + ocr_recognition: any = Field(default=None, description="ocr recognition model") + groundingdino_model: any = Field(default=None, description="clip groundingdino model") def __init__(self, **data: Any): super().__init__(**data) device_id = data.get("device_id") - self.cv_model_status = { - 'ocr_detection_loaded': False, - 'ocr_recognition_loaded': False, - 'clip_model_loaded': False - } + self.ocr_detection, self.ocr_recognition, self.groundingdino_model = load_cv_model() if device_id: devices = self.list_devices() if device_id not in devices: @@ -51,8 +60,8 @@ class AndroidExtEnv(ExtEnv): (width, height) = self.device_shape self.width = data.get("width", width) self.height = data.get("height", height) - self.create_device_path(self.screenshot_dir) - self.create_device_path(self.xml_dir) + #self.create_device_path(self.screenshot_dir) + #self.create_device_path(self.xml_dir) def reset( self, @@ -173,16 +182,16 @@ class AndroidExtEnv(ExtEnv): if pull_res != ADB_EXEC_FAIL: res = ss_local_path else: - ss_cmd = f"{self.adb_prefix_shell} rm /sdcard/screenshot.png" + ss_cmd = f"{self.adb_prefix_shell} rm /sdcard/{ss_name}.png" ss_res = self.execute_adb_with_cmd(ss_cmd) time.sleep(0.1) - ss_cmd = f"{self.adb_prefix_shell} screencap -p /sdcard/screenshot.png" + ss_cmd = f"{self.adb_prefix_shell} screencap -p /sdcard/{ss_name}.png" ss_res = self.execute_adb_with_cmd(ss_cmd) time.sleep(0.1) - ss_cmd = f"{self.adb_prefix} pull /sdcard/screenshot.png {self.screenshot_dir}" + ss_cmd = f"{self.adb_prefix} pull /sdcard/{ss_name}.png {self.screenshot_dir}" ss_res = self.execute_adb_with_cmd(ss_cmd) - image_path = Path(f"{self.screenshot_dir}/screenshot.png") - res = image_path + image_path = Path(f"{self.screenshot_dir}/{ss_name}.png") + res = image_path return Path(res) @mark_as_readable @@ -261,23 +270,16 @@ class AndroidExtEnv(ExtEnv): exit_res = self.execute_adb_with_cmd(adb_cmd) return exit_res - def _ocr_text(self, text: str) -> list: if not self.screenshot_dir.exists(): self.screenshot_dir.mkdir(parents=True, exist_ok=True) image = self.get_screenshot("screenshot", self.screenshot_dir) - if self.cv_model_status['ocr_detection_loaded'] == False: - ocr_detection = pipeline(Tasks.ocr_detection, model="damo/cv_resnet18_ocr-detection-line-level_damo") - self.cv_model_status['ocr_detection_loaded'] = True - if self.cv_model_status['ocr_recognition_loaded'] == False: - ocr_recognition = pipeline(Tasks.ocr_recognition, model="damo/cv_convnextTiny_ocr-recognition-document_damo") - self.cv_model_status['ocr_recognition_loaded'] == True iw, ih = Image.open(image).size x, y = self.device_shape if iw > ih: x, y = y, x iw, ih = ih, iw - in_coordinate, out_coordinate = ocr(image, text, ocr_detection, ocr_recognition, iw, ih) + in_coordinate, out_coordinate = ocr(image, text, self.ocr_detection, self.ocr_recognition, iw, ih) output_list = [in_coordinate, out_coordinate, x, y, iw, ih, image] return output_list @@ -323,19 +325,13 @@ class AndroidExtEnv(ExtEnv): if not self.screenshot_dir.exists(): self.screenshot_dir.mkdir(parents=True, exist_ok=True) screenshot_path = self.get_screenshot("screenshot", self.screenshot_dir) - image, device = screenshot_path, 'cpu' + image= screenshot_path iw, ih = Image.open(image).size x, y = self.device_shape if iw > ih: x, y = y, x iw, ih = ih, iw - if self.cv_model_status['clip_model_loaded'] == False: - file_url = 'https://huggingface.co/ShilongLiu/GroundingDINO/blob/main/groundingdino_swint_ogc.pth' # 加载远程model - target_folder = Path(f'{DEFAULT_WORKSPACE_ROOT}/weights') - file_path = download_model(file_url, target_folder) - groundingdino_model = load_model(file_path, device=device).eval() - self.cv_model_status['clip_model_loaded'] = True - in_coordinate, out_coordinate = det(image, "icon", groundingdino_model) # 检测icon + in_coordinate, out_coordinate = det(image, "icon", self.groundingdino_model) # 检测icon if len(out_coordinate) == 1: # only one icon tap_coordinate = [(in_coordinate[0][0] + in_coordinate[0][2]) / 2, (in_coordinate[0][1] + in_coordinate[0][3]) / 2] @@ -344,8 +340,7 @@ class AndroidExtEnv(ExtEnv): else: temp_file = Path(f"{DEFAULT_WORKSPACE_ROOT}/temp") - if not temp_file.exists(): - temp_file.mkdir(parents=True, exist_ok=True) + temp_file.mkdir(parents=True, exist_ok=True) hash_table, clip_filter = [], [] for i, (td, box) in enumerate(zip(in_coordinate, out_coordinate)): if crop_for_clip(image, td, i, temp_file): diff --git a/metagpt/environment/android/text_icon_localization.py b/metagpt/environment/android/text_icon_localization.py index 2021acec4..60d62ed03 100644 --- a/metagpt/environment/android/text_icon_localization.py +++ b/metagpt/environment/android/text_icon_localization.py @@ -260,7 +260,7 @@ def transform_image(image_pil: any) -> any: def load_model(model_checkpoint_path: Path, device: str) -> any: - model_config_path = 'grounding_dino_config.py' + model_config_path = "grounding_dino_config.py" args = SLConfig.fromfile(model_config_path) args.device = device model = build_model(args) diff --git a/requirements.txt b/requirements.txt index 93816c8ef..d150d61f3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -70,6 +70,4 @@ qianfan==0.3.2 dashscope==1.14.1 rank-bm25==0.2.2 # for tool recommendation jieba==0.42.1 # for tool recommendation -gymnasium==0.29.1 - - +gymnasium==0.29.1 \ No newline at end of file diff --git a/setup.py b/setup.py index 1368a67fd..daa86f88c 100644 --- a/setup.py +++ b/setup.py @@ -48,6 +48,7 @@ extras_require = { "protobuf<3.20,>=3.9.2", "modelscope", "tensorflow==2.9.1; os_name == 'linux'", + "tensorflow==2.9.1; os_name == 'win32'", "tensorflow-macos==2.9; os_name == 'darwin'", "keras==2.9.0", "torch",