From 84a8c0d0bd7fc61d76fd573e151010df25cf06b1 Mon Sep 17 00:00:00 2001 From: kit <101046518@qq.com> Date: Fri, 26 Apr 2024 11:52:54 +0800 Subject: [PATCH] Merge remote-tracking branch 'origin/main' --- .../environment/android/android_ext_env.py | 36 +++++++++++++++---- .../android/text_icon_localization.py | 27 +++----------- 2 files changed, 33 insertions(+), 30 deletions(-) diff --git a/metagpt/environment/android/android_ext_env.py b/metagpt/environment/android/android_ext_env.py index 9a3e5a4c0..e15d7fe3f 100644 --- a/metagpt/environment/android/android_ext_env.py +++ b/metagpt/environment/android/android_ext_env.py @@ -34,10 +34,16 @@ class AndroidExtEnv(ExtEnv): xml_dir: Optional[Path] = Field(default=None) width: int = Field(default=720, description="device screen width") height: int = Field(default=1080, description="device screen height") + cv_model_status: dict = Field(default=None, description="Record model loading status") def __init__(self, **data: Any): super().__init__(**data) device_id = data.get("device_id") + self.cv_model_status = { + 'ocr_detection_loaded': False, + 'ocr_recognition_loaded': False, + 'clip_model_loaded': False + } if device_id: devices = self.list_devices() if device_id not in devices: @@ -45,8 +51,8 @@ class AndroidExtEnv(ExtEnv): (width, height) = self.device_shape self.width = data.get("width", width) self.height = data.get("height", height) - self.create_device_path(self.screenshot_dir) - self.create_device_path(self.xml_dir) + #self.create_device_path(self.screenshot_dir) + #self.create_device_path(self.xml_dir) def reset( self, @@ -167,7 +173,16 @@ class AndroidExtEnv(ExtEnv): if pull_res != ADB_EXEC_FAIL: res = ss_local_path else: - res = get_screenshot_only(local_save_dir) + ss_cmd = f"{self.adb_prefix_shell} rm /sdcard/screenshot.png" + ss_res = self.execute_adb_with_cmd(ss_cmd) + time.sleep(0.1) + ss_cmd = f"{self.adb_prefix_shell} screencap -p /sdcard/screenshot.png" + ss_res = self.execute_adb_with_cmd(ss_cmd) + time.sleep(0.1) + ss_cmd = f"{self.adb_prefix} pull /sdcard/screenshot.png {self.screenshot_dir}" + ss_res = self.execute_adb_with_cmd(ss_cmd) + image_path = Path(f"{self.screenshot_dir}/screenshot.png") + res = image_path return Path(res) @mark_as_readable @@ -246,12 +261,17 @@ class AndroidExtEnv(ExtEnv): exit_res = self.execute_adb_with_cmd(adb_cmd) return exit_res + def _ocr_text(self, text: str) -> list: if not self.screenshot_dir.exists(): self.screenshot_dir.mkdir(parents=True, exist_ok=True) image = self.get_screenshot("screenshot", self.screenshot_dir) - ocr_detection = pipeline(Tasks.ocr_detection, model="damo/cv_resnet18_ocr-detection-line-level_damo") - ocr_recognition = pipeline(Tasks.ocr_recognition, model="damo/cv_convnextTiny_ocr-recognition-document_damo") + if self.cv_model_status['ocr_detection_loaded'] == False: + ocr_detection = pipeline(Tasks.ocr_detection, model="damo/cv_resnet18_ocr-detection-line-level_damo") + self.cv_model_status['ocr_detection_loaded'] = True + if self.cv_model_status['ocr_recognition_loaded'] == False: + ocr_recognition = pipeline(Tasks.ocr_recognition, model="damo/cv_convnextTiny_ocr-recognition-document_damo") + self.cv_model_status['ocr_recognition_loaded'] == True iw, ih = Image.open(image).size x, y = self.device_shape if iw > ih: @@ -312,7 +332,9 @@ class AndroidExtEnv(ExtEnv): file_url = 'https://huggingface.co/ShilongLiu/GroundingDINO/blob/main/groundingdino_swint_ogc.pth' # 加载远程model target_folder = Path(f'{DEFAULT_WORKSPACE_ROOT}/weights') file_path = download_model(file_url, target_folder) - groundingdino_model = load_model(file_path, device=device).eval() + if self.cv_model_status['clip_model_loaded'] == False: + groundingdino_model = load_model(file_path, device=device).eval() + self.cv_model_status['clip_model_loaded'] = True in_coordinate, out_coordinate = det(image, "icon", groundingdino_model) # 检测icon if len(out_coordinate) == 1: # only one icon tap_coordinate = [(in_coordinate[0][0] + in_coordinate[0][2]) / 2, @@ -328,7 +350,7 @@ class AndroidExtEnv(ExtEnv): for i, (td, box) in enumerate(zip(in_coordinate, out_coordinate)): if crop_for_clip(image, td, i, temp_file): hash_table.append(td) - crop_image = f"{i}.jpg" + crop_image = f"{i}.png" clip_filter.append(temp_file.joinpath(crop_image)) clip_model, clip_preprocess = clip.load("ViT-B/32", device=device) clip_filter = clip_for_icon(clip_model, clip_preprocess, clip_filter, icon_shape_color) diff --git a/metagpt/environment/android/text_icon_localization.py b/metagpt/environment/android/text_icon_localization.py index 4dd17ca60..2021acec4 100644 --- a/metagpt/environment/android/text_icon_localization.py +++ b/metagpt/environment/android/text_icon_localization.py @@ -221,7 +221,7 @@ def crop_for_clip(image: any, box: any, i: int, temp_file: Path) -> bool: bound = [0, 0, w, h] if in_box(box, bound): cropped_image = image.crop(box) - cropped_image.save(temp_file.joinpath(f"{i}.jpg")) + cropped_image.save(temp_file.joinpath(f"{i}.png")) return True else: return False @@ -271,7 +271,7 @@ def load_model(model_checkpoint_path: Path, device: str) -> any: return model -def get_grounding_output(model: any, image: any, caption: str, box_threshold: any, text_threshold: any, with_logits=True) -> any: +def get_grounding_output(model: any, image: any, caption: str, box_threshold: any, text_threshold: any, with_logits: bool = True) -> any: caption = caption.lower() caption = caption.strip() if not caption.endswith("."): @@ -306,7 +306,7 @@ def get_grounding_output(model: any, image: any, caption: str, box_threshold: an return boxes_filt, torch.Tensor(scores), pred_phrases -def remove_boxes(boxes_filt: any, size: any, iou_threshold=0.5) -> any: +def remove_boxes(boxes_filt: any, size: any, iou_threshold: float = 0.5) -> any: boxes_to_remove = set() for i in range(len(boxes_filt)): @@ -328,7 +328,7 @@ def remove_boxes(boxes_filt: any, size: any, iou_threshold=0.5) -> any: return boxes_filt -def det(input_image: any, text_prompt: str, groundingdino_model: any, box_threshold=0.05, text_threshold=0.5) -> any: +def det(input_image: any, text_prompt: str, groundingdino_model: any, box_threshold:float = 0.05, text_threshold:float = 0.5) -> any: image = Image.open(input_image) size = image.size @@ -361,22 +361,3 @@ def det(input_image: any, text_prompt: str, groundingdino_model: any, box_thresh return image_data, coordinate -def get_screenshot_only(screenshot_dir: Path) -> Path: - command = " adb shell rm /sdcard/screenshot.png" - subprocess.run(command, capture_output=True, text=True, shell=True) - time.sleep(0.1) - command = "adb shell screencap -p /sdcard/screenshot.png" - subprocess.run(command, capture_output=True, text=True, shell=True) - time.sleep(0.1) - command = f"adb pull /sdcard/screenshot.png {screenshot_dir}" - subprocess.run(command, capture_output=True, text=True, shell=True) - image_path = Path(f"{screenshot_dir}/screenshot.png") - save_path = Path(f"{screenshot_dir}/screenshot.jpg") - image = Image.open(image_path) - original_width, original_height = image.size - new_width = int(original_width * 0.5) - new_height = int(original_height * 0.5) - resized_image = image.resize((new_width, new_height)) - resized_image.convert("RGB").save(save_path, "JPEG") - time.sleep(0.1) - return save_path