From 4f18c4b4b5765b63d0f5fe765577c0f1d411ce34 Mon Sep 17 00:00:00 2001 From: kithib <1010465183@qq.com> Date: Thu, 18 Apr 2024 15:36:55 +0800 Subject: [PATCH] Merge remote-tracking branch 'origin/main' --- .../android/GroundingDINO_SwinT_OGC.py | 43 +++++++ .../environment/android/android_ext_env.py | 111 +++++++++++++++--- .../android/text_icon_localization.py | 67 +++++------ metagpt/utils/download_modelweight.py | 22 ++++ requirements.txt | 22 +++- 5 files changed, 207 insertions(+), 58 deletions(-) create mode 100644 metagpt/environment/android/GroundingDINO_SwinT_OGC.py create mode 100644 metagpt/utils/download_modelweight.py diff --git a/metagpt/environment/android/GroundingDINO_SwinT_OGC.py b/metagpt/environment/android/GroundingDINO_SwinT_OGC.py new file mode 100644 index 000000000..9158d5f62 --- /dev/null +++ b/metagpt/environment/android/GroundingDINO_SwinT_OGC.py @@ -0,0 +1,43 @@ +batch_size = 1 +modelname = "groundingdino" +backbone = "swin_T_224_1k" +position_embedding = "sine" +pe_temperatureH = 20 +pe_temperatureW = 20 +return_interm_indices = [1, 2, 3] +backbone_freeze_keywords = None +enc_layers = 6 +dec_layers = 6 +pre_norm = False +dim_feedforward = 2048 +hidden_dim = 256 +dropout = 0.0 +nheads = 8 +num_queries = 900 +query_dim = 4 +num_patterns = 0 +num_feature_levels = 4 +enc_n_points = 4 +dec_n_points = 4 +two_stage_type = "standard" +two_stage_bbox_embed_share = False +two_stage_class_embed_share = False +transformer_activation = "relu" +dec_pred_bbox_embed_share = True +dn_box_noise_scale = 1.0 +dn_label_noise_ratio = 0.5 +dn_label_coef = 1.0 +dn_bbox_coef = 1.0 +embed_init_tgt = True +dn_labelbook_size = 2000 +max_text_len = 256 +text_encoder_type = "bert-base-uncased" +use_text_enhancer = True +use_fusion_layer = True +use_checkpoint = True +use_transformer_ckpt = True +use_text_cross_attention = True +text_dropout = 0.0 +fusion_dropout = 0.0 +fusion_droppath = 0.1 +sub_sentence_present = True diff --git a/metagpt/environment/android/android_ext_env.py b/metagpt/environment/android/android_ext_env.py index 152c71d04..230a351ad 100644 --- a/metagpt/environment/android/android_ext_env.py +++ b/metagpt/environment/android/android_ext_env.py @@ -1,17 +1,20 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- # @Desc : The Android external environment to integrate with Android apps - +import os import subprocess +import clip +import time from pathlib import Path from typing import Any, Optional from modelscope.pipelines import pipeline from modelscope.utils.constant import Tasks + from PIL import Image from pydantic import Field -from metagpt.environment.android.text_icon_localization import ocr +from metagpt.environment.android.text_icon_localization import * from metagpt.environment.android.const import ADB_EXEC_FAIL from metagpt.environment.android.env_space import ( EnvAction, @@ -22,6 +25,7 @@ from metagpt.environment.android.env_space import ( ) from metagpt.environment.base_env import ExtEnv, mark_as_readable, mark_as_writeable from metagpt.logs import logger +from metagpt.utils.download_modelweight import download_model class AndroidExtEnv(ExtEnv): @@ -42,14 +46,14 @@ class AndroidExtEnv(ExtEnv): self.width = data.get("width", width) self.height = data.get("height", height) - self.create_device_path(self.screenshot_dir) - self.create_device_path(self.xml_dir) + # self.create_device_path(self.screenshot_dir) + # self.create_device_path(self.xml_dir) def reset( - self, - *, - seed: Optional[int] = None, - options: Optional[dict[str, Any]] = None, + self, + *, + seed: Optional[int] = None, + options: Optional[dict[str, Any]] = None, ) -> tuple[dict[str, Any], dict[str, Any]]: super().reset(seed=seed, options=options) @@ -154,14 +158,17 @@ class AndroidExtEnv(ExtEnv): ss_remote_path = Path(self.screenshot_dir).joinpath(f"{ss_name}.png") ss_cmd = f"{self.adb_prefix_shell} screencap -p {ss_remote_path}" ss_res = self.execute_adb_with_cmd(ss_cmd) - + time.sleep(0.1) res = ADB_EXEC_FAIL if ss_res != ADB_EXEC_FAIL: ss_local_path = Path(local_save_dir).joinpath(f"{ss_name}.png") pull_cmd = f"{self.adb_prefix} pull {ss_remote_path} {ss_local_path}" pull_res = self.execute_adb_with_cmd(pull_cmd) + time.sleep(0.1) if pull_res != ADB_EXEC_FAIL: res = ss_local_path + else: + res = get_screenshot_only(local_save_dir) return Path(res) @mark_as_readable @@ -229,22 +236,22 @@ class AndroidExtEnv(ExtEnv): return swipe_res @mark_as_writeable - def user_swipe_to(self, start: tuple[int, int], end: tuple[int, int], duration: int = 400): + def user_swipe_to(self, start: tuple[int, int], end: tuple[int, int], duration: int = 400) -> str: adb_cmd = f"{self.adb_prefix_si} swipe {start[0]} {start[1]} {end[0]} {end[1]} {duration}" swipe_res = self.execute_adb_with_cmd(adb_cmd) return swipe_res @mark_as_writeable - def user_exit(self): - adb_cmd = "adb shell am start -a android.intent.action.MAIN -c android.intent.category.HOME" + def user_exit(self) -> str: + adb_cmd = f"{self.adb_prefix_shell} am start -a android.intent.action.MAIN -c android.intent.category.HOME" exit_res = self.execute_adb_with_cmd(adb_cmd) return exit_res @mark_as_writeable - def user_openApp(self, app_name: str): - # openApp without xml - screenshot_path = self.get_screenshot("screenshot", "../../../examples/data/screenshot") - image = screenshot_path + def _ocr_text(self, text: str) -> list: + if not os.path.exists(self.screenshot_dir): + os.makedirs(self.screenshot_dir) + image = self.get_screenshot("screenshot", self.screenshot_dir) ocr_detection = pipeline(Tasks.ocr_detection, model="damo/cv_resnet18_ocr-detection-line-level_damo") ocr_recognition = pipeline(Tasks.ocr_recognition, model="damo/cv_convnextTiny_ocr-recognition-document_damo") iw, ih = Image.open(image).size @@ -252,7 +259,15 @@ class AndroidExtEnv(ExtEnv): if iw > ih: x, y = y, x iw, ih = ih, iw - in_coordinate, out_coordinate = ocr(image, app_name, ocr_detection, ocr_recognition, iw, ih) + in_coordinate, out_coordinate = ocr(image, text, ocr_detection, ocr_recognition, iw, ih) + output_list = [in_coordinate, out_coordinate, x, y, iw, ih, image] + return output_list + + @mark_as_writeable + def user_open_app(self, app_name: str) -> str: + ocr_result = self._ocr_text(app_name) + in_coordinate, out_coordinate, x, y, iw, ih = ( + ocr_result[0], ocr_result[1], ocr_result[2], ocr_result[3], ocr_result[4], ocr_result[5]) if len(in_coordinate) == 0: logger.info(f"No App named {app_name}.") return "no" @@ -262,11 +277,69 @@ class AndroidExtEnv(ExtEnv): (in_coordinate[0][1] + in_coordinate[0][3]) / 2, ] tap_coordinate = [round(tap_coordinate[0] / iw, 2), round(tap_coordinate[1] / ih, 2)] - #print(f"{parameter}在屏幕的坐标为为{tap_coordinate[0] * x} ,{(tap_coordinate[1] - round(50 / y, 2)) * y}") return self.system_tap(tap_coordinate[0] * x, (tap_coordinate[1] - round(50 / y, 2)) * y) + @mark_as_writeable + def user_click_text(self, text: str) -> str: + ocr_result = self._ocr_text(text) + in_coordinate, out_coordinate, x, y, iw, ih, image = ( + ocr_result[0], ocr_result[1], ocr_result[2], ocr_result[3], ocr_result[4], ocr_result[5], ocr_result[6]) + if len(out_coordinate) == 0: + logger.info( + f"Failed to execute action click text ({text}). The text \"{text}\" is not detected in the screenshot.") + elif len(out_coordinate) == 1: + tap_coordinate = [(in_coordinate[0][0] + in_coordinate[0][2]) / 2, + (in_coordinate[0][1] + in_coordinate[0][3]) / 2] + tap_coordinate = [round(tap_coordinate[0] / iw, 2), round(tap_coordinate[1] / ih, 2)] + return self.system_tap(tap_coordinate[0] * x, tap_coordinate[1] * y) + else: + logger.info( + f"Failed to execute action click text ({text}). There are too many text \"{text}\" in the screenshot.") + @mark_as_writeable def user_stop(self): logger.info("Successful execution of tasks") - # todo : user_clickIcon + @mark_as_writeable + def user_click_icon(self, icon_shape_color: str) -> str: + if not os.path.exists(self.screenshot_dir): + os.makedirs(self.screenshot_dir) + screenshot_path = self.get_screenshot("screenshot", self.screenshot_dir) + image, device = screenshot_path, 'cpu' + iw, ih = Image.open(image).size + x, y = self.device_shape + if iw > ih: + x, y = y, x + iw, ih = ih, iw + # 下载权重文件 + file_url = 'https://huggingface.co/ShilongLiu/GroundingDINO/blob/main/groundingdino_swint_ogc.pth' # 加载远程model + target_folder = '/Users/kit/Desktop/深度赋值/amzingproject/MetaGPT/workspace/weights' + file_path = download_model(file_url, target_folder) + groundingdino_model = load_model(file_path, device=device).eval() + in_coordinate, out_coordinate = det(image, "icon", groundingdino_model) # 检测icon + if len(out_coordinate) == 1: # only one icon + tap_coordinate = [(in_coordinate[0][0] + in_coordinate[0][2]) / 2, + (in_coordinate[0][1] + in_coordinate[0][3]) / 2] + tap_coordinate = [round(tap_coordinate[0] / iw, 2), round(tap_coordinate[1] / ih, 2)] + return self.system_tap(tap_coordinate[0] * x, tap_coordinate[1] * y) + + else: + temp_file = "/Users/kit/Desktop/深度赋值/amzingproject/MetaGPT/workspace/temp" + if not os.path.exists(temp_file): + os.mkdir(temp_file) + hash_table, clip_filter= [],[] + for i, (td, box) in enumerate(zip(in_coordinate, out_coordinate)): + if crop_for_clip(image, td, i, temp_file): + hash_table.append(td) + crop_image = f"{i}.jpg" + clip_filter.append(os.path.join(temp_file, crop_image)) + clip_model, clip_preprocess = clip.load("ViT-B/32", device=device) + clip_filter = clip_for_icon(clip_model, clip_preprocess, clip_filter, icon_shape_color) + final_box = hash_table[clip_filter] + tap_coordinate = [(final_box[0] + final_box[2]) / 2, (final_box[1] + final_box[3]) / 2] + tap_coordinate = [round(tap_coordinate[0] / iw, 2), round(tap_coordinate[1] / ih, 2)] + return self.system_tap(tap_coordinate[0] * x, tap_coordinate[1] * y) + + + + diff --git a/metagpt/environment/android/text_icon_localization.py b/metagpt/environment/android/text_icon_localization.py index d1b5ba2f9..8c3d22c7c 100644 --- a/metagpt/environment/android/text_icon_localization.py +++ b/metagpt/environment/android/text_icon_localization.py @@ -3,7 +3,9 @@ import clip import cv2 import numpy as np import torch - +import subprocess +import time +from pathlib import Path import groundingdino.datasets.transforms as T from groundingdino.models import build_model from groundingdino.util.slconfig import SLConfig @@ -12,6 +14,7 @@ from PIL import Image, ImageDraw ################################## text_localization using ocr ####################### + def crop_image(img, position): def distance(x1, y1, x2, y2): return math.sqrt(pow(x1 - x2, 2) + pow(y1 - y2, 2)) @@ -96,7 +99,7 @@ def ocr(image_path, prompt, ocr_detection, ocr_recognition, x, y): image = Image.open(image_path) iw, ih = image.size - image_full = cv2.imread(image_path) + image_full = cv2.imread(str(image_path)) det_result = ocr_detection(image_full) det_result = det_result["polygons"] for i in range(det_result.shape[0]): @@ -205,7 +208,6 @@ def calculate_iou(box1, box2): def crop(image, box, i, text_data=None): image = Image.open(image) - if text_data: draw = ImageDraw.Draw(image) draw.rectangle(((text_data[0], text_data[1]), (text_data[2], text_data[3])), outline="red", width=5) @@ -224,31 +226,13 @@ def in_box(box, target): return False -def crop_for_clip(image, box, i, position): +def crop_for_clip(image, box, i, temp_file): image = Image.open(image) w, h = image.size - if position == "left": - bound = [0, 0, w / 2, h] - elif position == "right": - bound = [w / 2, 0, w, h] - elif position == "top": - bound = [0, 0, w, h / 2] - elif position == "bottom": - bound = [0, h / 2, w, h] - elif position == "top left": - bound = [0, 0, w / 2, h / 2] - elif position == "top right": - bound = [w / 2, 0, w, h / 2] - elif position == "bottom left": - bound = [0, h / 2, w / 2, h] - elif position == "bottom right": - bound = [w / 2, h / 2, w, h] - else: - bound = [0, 0, w, h] - + bound = [0, 0, w, h] if in_box(box, bound): cropped_image = image.crop(box) - cropped_image.save(f"./temp/{i}.jpg") + cropped_image.save(f"{temp_file}/{i}.jpg") return True else: return False @@ -286,7 +270,8 @@ def transform_image(image_pil): return image -def load_model(model_config_path, model_checkpoint_path, device): +def load_model(model_checkpoint_path, device): + model_config_path = 'GroundingDINO_SwinT_OGC.py' args = SLConfig.fromfile(model_config_path) args.device = device model = build_model(args) @@ -387,17 +372,23 @@ def det(input_image, text_prompt, groundingdino_model, box_threshold=0.05, text_ return image_data, coordinate -if __name__ == "__main__": - from modelscope.pipelines import pipeline - from modelscope.utils.constant import Tasks +def get_screenshot_only(screenshot_dir: Path) -> str: + command = " adb shell rm /sdcard/screenshot.png" + subprocess.run(command, capture_output=True, text=True, shell=True) + time.sleep(0.1) + command = "adb shell screencap -p /sdcard/screenshot.png" + subprocess.run(command, capture_output=True, text=True, shell=True) + time.sleep(0.1) + command = f"adb pull /sdcard/screenshot.png {screenshot_dir}" + subprocess.run(command, capture_output=True, text=True, shell=True) + image_path = f"{screenshot_dir}/screenshot.png" + save_path = f"{screenshot_dir}/screenshot.jpg" + image = Image.open(image_path) + original_width, original_height = image.size + new_width = int(original_width * 0.5) + new_height = int(original_height * 0.5) + resized_image = image.resize((new_width, new_height)) + resized_image.convert("RGB").save(save_path, "JPEG") + time.sleep(0.1) + return save_path - image_ori = "./screenshot/screenshot.png" - image = "./screenshot/screenshot.png" - parameter = "抖音" - ocr_detection = pipeline(Tasks.ocr_detection, model="damo/cv_resnet18_ocr-detection-line-level_damo") - ocr_recognition = pipeline(Tasks.ocr_recognition, model="damo/cv_convnextTiny_ocr-recognition-document_damo") - iw, ih = Image.open(image).size - if iw > ih: - iw, ih = ih, iw - in_coordinate, out_coordinate = ocr(image_ori, parameter, ocr_detection, ocr_recognition, iw, ih) - print(f"ocr 计算结果为 {in_coordinate} ,{out_coordinate} ") diff --git a/metagpt/utils/download_modelweight.py b/metagpt/utils/download_modelweight.py new file mode 100644 index 000000000..2b8bcf41b --- /dev/null +++ b/metagpt/utils/download_modelweight.py @@ -0,0 +1,22 @@ +import os +import requests +from pathlib import Path + + +def download_model(file_url: str, target_folder: str) -> str: + file_name = file_url.split('/')[-1] # 文件名(从URL中提取) + file_path = os.path.join(target_folder, file_name) # 完整的文件路径 + if not os.path.exists(target_folder): + os.makedirs(target_folder) + # 发起GET请求下载文件 + try: + response = requests.get(file_url, stream=True) + response.raise_for_status() # 检查请求是否成功 + # 保存文件 + with open(file_path, 'wb') as f: + for chunk in response.iter_content(chunk_size=8192): + f.write(chunk) + print(f'权重文件已下载并保存至 {file_path}') + except requests.exceptions.HTTPError as err: + print(f'权重文件下载过程中发生错误: {err}') + return file_path diff --git a/requirements.txt b/requirements.txt index d150d61f3..46832e943 100644 --- a/requirements.txt +++ b/requirements.txt @@ -70,4 +70,24 @@ qianfan==0.3.2 dashscope==1.14.1 rank-bm25==0.2.2 # for tool recommendation jieba==0.42.1 # for tool recommendation -gymnasium==0.29.1 \ No newline at end of file +gymnasium==0.29.1 + +# for clip and ocr +git+https://github.com/openai/CLIP.git +protobuf<3.20,>=3.9.2 +modelscope +tensorflow==2.9.1; os_name == 'linux' +tensorflow-macos==2.9; os_name == 'darwin' +keras==2.9.0 +torch +torchvision +transformers +opencv-python +matplotlib +pycocotools +timm +SentencePiece +tf_slim +tf_keras +pyclipper +shapely \ No newline at end of file