From d0e898dcfada0c3b64286cd9661efb2d88e8355b Mon Sep 17 00:00:00 2001 From: kithib <1010465183@qq.com> Date: Fri, 19 Apr 2024 17:56:36 +0800 Subject: [PATCH] Merge remote-tracking branch 'origin/main' --- .../environment/android/android_ext_env.py | 42 +++++----- ..._SwinT_OGC.py => grounding_dino_config.py} | 0 .../android/text_icon_localization.py | 76 ++++++++----------- metagpt/utils/common.py | 20 ++++- metagpt/utils/download_modelweight.py | 22 ------ requirements.txt | 19 ----- setup.py | 23 +++++- 7 files changed, 91 insertions(+), 111 deletions(-) rename metagpt/environment/android/{GroundingDINO_SwinT_OGC.py => grounding_dino_config.py} (100%) delete mode 100644 metagpt/utils/download_modelweight.py diff --git a/metagpt/environment/android/android_ext_env.py b/metagpt/environment/android/android_ext_env.py index eb8a69330..cba0636c7 100644 --- a/metagpt/environment/android/android_ext_env.py +++ b/metagpt/environment/android/android_ext_env.py @@ -1,7 +1,6 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- # @Desc : The Android external environment to integrate with Android apps -import os import subprocess import clip import time @@ -25,7 +24,8 @@ from metagpt.environment.android.env_space import ( ) from metagpt.environment.base_env import ExtEnv, mark_as_readable, mark_as_writeable from metagpt.logs import logger -from metagpt.utils.download_modelweight import download_model +from metagpt.utils.common import download_model +from metagpt.const import DEFAULT_WORKSPACE_ROOT class AndroidExtEnv(ExtEnv): @@ -46,14 +46,14 @@ class AndroidExtEnv(ExtEnv): self.width = data.get("width", width) self.height = data.get("height", height) - self.create_device_path(self.screenshot_dir) - self.create_device_path(self.xml_dir) + #self.create_device_path(self.screenshot_dir) + #self.create_device_path(self.xml_dir) def reset( - self, - *, - seed: Optional[int] = None, - options: Optional[dict[str, Any]] = None, + self, + *, + seed: Optional[int] = None, + options: 
Optional[dict[str, Any]] = None, ) -> tuple[dict[str, Any], dict[str, Any]]: super().reset(seed=seed, options=options) @@ -247,10 +247,9 @@ class AndroidExtEnv(ExtEnv): exit_res = self.execute_adb_with_cmd(adb_cmd) return exit_res - @mark_as_writeable def _ocr_text(self, text: str) -> list: - if not os.path.exists(self.screenshot_dir): - os.makedirs(self.screenshot_dir) + if not self.screenshot_dir.exists(): + self.screenshot_dir.mkdir(parents=True, exist_ok=True) image = self.get_screenshot("screenshot", self.screenshot_dir) ocr_detection = pipeline(Tasks.ocr_detection, model="damo/cv_resnet18_ocr-detection-line-level_damo") ocr_recognition = pipeline(Tasks.ocr_recognition, model="damo/cv_convnextTiny_ocr-recognition-document_damo") @@ -302,8 +301,8 @@ class AndroidExtEnv(ExtEnv): @mark_as_writeable def user_click_icon(self, icon_shape_color: str) -> str: - if not os.path.exists(self.screenshot_dir): - os.makedirs(self.screenshot_dir) + if not self.screenshot_dir.exists(): + self.screenshot_dir.mkdir(parents=True, exist_ok=True) screenshot_path = self.get_screenshot("screenshot", self.screenshot_dir) image, device = screenshot_path, 'cpu' iw, ih = Image.open(image).size @@ -311,9 +310,8 @@ class AndroidExtEnv(ExtEnv): if iw > ih: x, y = y, x iw, ih = ih, iw - # 下载权重文件 file_url = 'https://huggingface.co/ShilongLiu/GroundingDINO/blob/main/groundingdino_swint_ogc.pth' # 加载远程model - target_folder = 'workspace/weights' + target_folder = Path(f'{DEFAULT_WORKSPACE_ROOT}/weights') file_path = download_model(file_url, target_folder) groundingdino_model = load_model(file_path, device=device).eval() in_coordinate, out_coordinate = det(image, "icon", groundingdino_model) # 检测icon @@ -324,22 +322,18 @@ class AndroidExtEnv(ExtEnv): return self.system_tap(tap_coordinate[0] * x, tap_coordinate[1] * y) else: - temp_file = "workspace/temp" - if not os.path.exists(temp_file): - os.mkdir(temp_file) - hash_table, clip_filter= [],[] + temp_file = Path(f"{DEFAULT_WORKSPACE_ROOT}/temp") 
+ if not temp_file.exists(): + temp_file.mkdir(parents=True, exist_ok=True) + hash_table, clip_filter = [], [] for i, (td, box) in enumerate(zip(in_coordinate, out_coordinate)): if crop_for_clip(image, td, i, temp_file): hash_table.append(td) crop_image = f"{i}.jpg" - clip_filter.append(os.path.join(temp_file, crop_image)) + clip_filter.append(temp_file.joinpath(crop_image)) clip_model, clip_preprocess = clip.load("ViT-B/32", device=device) clip_filter = clip_for_icon(clip_model, clip_preprocess, clip_filter, icon_shape_color) final_box = hash_table[clip_filter] tap_coordinate = [(final_box[0] + final_box[2]) / 2, (final_box[1] + final_box[3]) / 2] tap_coordinate = [round(tap_coordinate[0] / iw, 2), round(tap_coordinate[1] / ih, 2)] return self.system_tap(tap_coordinate[0] * x, tap_coordinate[1] * y) - - - - diff --git a/metagpt/environment/android/GroundingDINO_SwinT_OGC.py b/metagpt/environment/android/grounding_dino_config.py similarity index 100% rename from metagpt/environment/android/GroundingDINO_SwinT_OGC.py rename to metagpt/environment/android/grounding_dino_config.py diff --git a/metagpt/environment/android/text_icon_localization.py b/metagpt/environment/android/text_icon_localization.py index 8c3d22c7c..4dd17ca60 100644 --- a/metagpt/environment/android/text_icon_localization.py +++ b/metagpt/environment/android/text_icon_localization.py @@ -1,3 +1,6 @@ +# The code in this file was modified by MobileAgent +# https://github.com/X-PLUG/MobileAgent.git + import math import clip import cv2 @@ -15,7 +18,7 @@ from PIL import Image, ImageDraw ################################## text_localization using ocr ####################### -def crop_image(img, position): +def crop_image(img: any, position: any) -> any: def distance(x1, y1, x2, y2): return math.sqrt(pow(x1 - x2, 2) + pow(y1 - y2, 2)) @@ -61,12 +64,12 @@ def crop_image(img, position): return dst -def calculate_size(box): +def calculate_size(box: any) -> any: return (box[2] - box[0]) * (box[3] - box[1]) -def 
order_point(coor): - arr = np.array(coor).reshape([4, 2]) +def order_point(cooperation: any) -> any: + arr = np.array(cooperation).reshape([4, 2]) sum_ = np.sum(arr, 0) centroid = sum_ / arr.shape[0] theta = np.arctan2(arr[:, 1] - centroid[1], arr[:, 0] - centroid[0]) @@ -78,11 +81,10 @@ def order_point(coor): return sort_points -def longest_common_substring_length(str1, str2): +def longest_common_substring_length(str1: str, str2: str) -> int: m = len(str1) n = len(str2) dp = [[0] * (n + 1) for _ in range(m + 1)] - for i in range(1, m + 1): for j in range(1, n + 1): if str1[i - 1] == str2[j - 1]: @@ -93,7 +95,7 @@ def longest_common_substring_length(str1, str2): return dp[m][n] -def ocr(image_path, prompt, ocr_detection, ocr_recognition, x, y): +def ocr(image_path: Path, prompt: str, ocr_detection: any, ocr_recognition: any, x: int, y: int) -> any: text_data = [] coordinate = [] image = Image.open(image_path) @@ -191,54 +193,41 @@ def ocr(image_path, prompt, ocr_detection, ocr_recognition, x, y): ################################## icon_localization using clip ####################### -def calculate_iou(box1, box2): - xA = max(box1[0], box2[0]) - yA = max(box1[1], box2[1]) - xB = min(box1[2], box2[2]) - yB = min(box1[3], box2[3]) +def calculate_iou(box1: list, box2: list) -> float: + x_a = max(box1[0], box2[0]) + y_a = max(box1[1], box2[1]) + x_b = min(box1[2], box2[2]) + y_b = min(box1[3], box2[3]) - interArea = max(0, xB - xA) * max(0, yB - yA) - box1Area = (box1[2] - box1[0]) * (box1[3] - box1[1]) - box2Area = (box2[2] - box2[0]) * (box2[3] - box2[1]) - unionArea = box1Area + box2Area - interArea - iou = interArea / unionArea + inter_area = max(0, x_b - x_a) * max(0, y_b - y_a) + box1_area = (box1[2] - box1[0]) * (box1[3] - box1[1]) + box2_area = (box2[2] - box2[0]) * (box2[3] - box2[1]) + union_area = box1_area + box2_area - inter_area + iou = inter_area / union_area return iou -def crop(image, box, i, text_data=None): - image = Image.open(image) - if text_data: 
- draw = ImageDraw.Draw(image) - draw.rectangle(((text_data[0], text_data[1]), (text_data[2], text_data[3])), outline="red", width=5) - # font_size = int((text_data[3] - text_data[1])*0.75) - # font = ImageFont.truetype("arial.ttf", font_size) - # draw.text((text_data[0]+5, text_data[1]+5), str(i), font=font, fill="red") - - cropped_image = image.crop(box) - cropped_image.save(f"./temp/{i}.jpg") - - -def in_box(box, target): +def in_box(box: list, target: list) -> bool: if (box[0] > target[0]) and (box[1] > target[1]) and (box[2] < target[2]) and (box[3] < target[3]): return True else: return False -def crop_for_clip(image, box, i, temp_file): +def crop_for_clip(image: any, box: any, i: int, temp_file: Path) -> bool: image = Image.open(image) w, h = image.size bound = [0, 0, w, h] if in_box(box, bound): cropped_image = image.crop(box) - cropped_image.save(f"{temp_file}/{i}.jpg") + cropped_image.save(temp_file.joinpath(f"{i}.jpg")) return True else: return False -def clip_for_icon(clip_model, clip_preprocess, images, prompt): +def clip_for_icon(clip_model: any, clip_preprocess: any, images: any, prompt: str) -> any: image_features = [] for image_file in images: image = clip_preprocess(Image.open(image_file)).unsqueeze(0).to(next(clip_model.parameters()).device) @@ -258,7 +247,7 @@ def clip_for_icon(clip_model, clip_preprocess, images, prompt): return pos -def transform_image(image_pil): +def transform_image(image_pil: any) -> any: transform = T.Compose( [ T.RandomResize([800], max_size=1333), @@ -270,8 +259,8 @@ def transform_image(image_pil): return image -def load_model(model_checkpoint_path, device): - model_config_path = 'GroundingDINO_SwinT_OGC.py' +def load_model(model_checkpoint_path: Path, device: str) -> any: + model_config_path = 'grounding_dino_config.py' args = SLConfig.fromfile(model_config_path) args.device = device model = build_model(args) @@ -282,7 +271,7 @@ def load_model(model_checkpoint_path, device): return model -def get_grounding_output(model, 
image, caption, box_threshold, text_threshold, with_logits=True): +def get_grounding_output(model: any, image: any, caption: str, box_threshold: any, text_threshold: any, with_logits=True) -> any: caption = caption.lower() caption = caption.strip() if not caption.endswith("."): @@ -317,7 +306,7 @@ def get_grounding_output(model, image, caption, box_threshold, text_threshold, w return boxes_filt, torch.Tensor(scores), pred_phrases -def remove_boxes(boxes_filt, size, iou_threshold=0.5): +def remove_boxes(boxes_filt: any, size: any, iou_threshold=0.5) -> any: boxes_to_remove = set() for i in range(len(boxes_filt)): @@ -339,7 +328,7 @@ def remove_boxes(boxes_filt, size, iou_threshold=0.5): return boxes_filt -def det(input_image, text_prompt, groundingdino_model, box_threshold=0.05, text_threshold=0.5): +def det(input_image: any, text_prompt: str, groundingdino_model: any, box_threshold=0.05, text_threshold=0.5) -> any: image = Image.open(input_image) size = image.size @@ -372,7 +361,7 @@ def det(input_image, text_prompt, groundingdino_model, box_threshold=0.05, text_ return image_data, coordinate -def get_screenshot_only(screenshot_dir: Path) -> str: +def get_screenshot_only(screenshot_dir: Path) -> Path: command = " adb shell rm /sdcard/screenshot.png" subprocess.run(command, capture_output=True, text=True, shell=True) time.sleep(0.1) @@ -381,8 +370,8 @@ def get_screenshot_only(screenshot_dir: Path) -> str: time.sleep(0.1) command = f"adb pull /sdcard/screenshot.png {screenshot_dir}" subprocess.run(command, capture_output=True, text=True, shell=True) - image_path = f"{screenshot_dir}/screenshot.png" - save_path = f"{screenshot_dir}/screenshot.jpg" + image_path = Path(f"{screenshot_dir}/screenshot.png") + save_path = Path(f"{screenshot_dir}/screenshot.jpg") image = Image.open(image_path) original_width, original_height = image.size new_width = int(original_width * 0.5) @@ -391,4 +380,3 @@ def get_screenshot_only(screenshot_dir: Path) -> str: 
resized_image.convert("RGB").save(save_path, "JPEG") time.sleep(0.1) return save_path - diff --git a/metagpt/utils/common.py b/metagpt/utils/common.py index 0876b85ad..982e6921b 100644 --- a/metagpt/utils/common.py +++ b/metagpt/utils/common.py @@ -219,7 +219,7 @@ class OutputParser: if start_index != -1 and end_index != -1: # Extract the structure part - structure_text = text[start_index : end_index + 1] + structure_text = text[start_index: end_index + 1] try: # Attempt to convert the text to a Python data type using ast.literal_eval @@ -841,3 +841,21 @@ def get_markdown_codeblock_type(filename: str) -> str: "application/sql": "sql", } return mappings.get(mime_type, "text") + + +def download_model(file_url: str, target_folder: Path) -> Path: + file_name = file_url.split('/')[-1] + file_path = target_folder.joinpath(f"{file_name}") + if not file_path.exists(): + target_folder.mkdir(parents=True, exist_ok=True) + try: + response = requests.get(file_url, stream=True) + response.raise_for_status() # 检查请求是否成功 + # 保存文件 + with open(file_path, 'wb') as f: + for chunk in response.iter_content(chunk_size=8192): + f.write(chunk) + logger.info(f'权重文件已下载并保存至 {file_path}') + except requests.exceptions.HTTPError as err: + logger.info(f'权重文件下载过程中发生错误: {err}') + return file_path diff --git a/metagpt/utils/download_modelweight.py b/metagpt/utils/download_modelweight.py deleted file mode 100644 index 2b8bcf41b..000000000 --- a/metagpt/utils/download_modelweight.py +++ /dev/null @@ -1,22 +0,0 @@ -import os -import requests -from pathlib import Path - - -def download_model(file_url: str, target_folder: str) -> str: - file_name = file_url.split('/')[-1] # 文件名(从URL中提取) - file_path = os.path.join(target_folder, file_name) # 完整的文件路径 - if not os.path.exists(target_folder): - os.makedirs(target_folder) - # 发起GET请求下载文件 - try: - response = requests.get(file_url, stream=True) - response.raise_for_status() # 检查请求是否成功 - # 保存文件 - with open(file_path, 'wb') as f: - for chunk in 
response.iter_content(chunk_size=8192): - f.write(chunk) - print(f'权重文件已下载并保存至 {file_path}') - except requests.exceptions.HTTPError as err: - print(f'权重文件下载过程中发生错误: {err}') - return file_path diff --git a/requirements.txt b/requirements.txt index 46832e943..c6d46fa25 100644 --- a/requirements.txt +++ b/requirements.txt @@ -71,23 +71,4 @@ dashscope==1.14.1 rank-bm25==0.2.2 # for tool recommendation jieba==0.42.1 # for tool recommendation gymnasium==0.29.1 - -# for clip and ocr -git+https://github.com/openai/CLIP.git -protobuf<3.20,>=3.9.2 -modelscope -tensorflow==2.9.1; os_name == 'linux' -tensorflow-macos==2.9; os_name == 'darwin' -keras==2.9.0 -torch -torchvision -transformers -opencv-python -matplotlib -pycocotools timm -SentencePiece -tf_slim -tf_keras -pyclipper -shapely \ No newline at end of file diff --git a/setup.py b/setup.py index e43bf3ed0..d33ac8e0f 100644 --- a/setup.py +++ b/setup.py @@ -42,7 +42,28 @@ extras_require = { "llama-index-vector-stores-chroma==0.1.6", "docx2txt==0.8", ], - "android_assistant": ["pyshine==0.0.9", "opencv-python==4.6.0.66"], + "android_assistant": [ + "pyshine==0.0.9", + "opencv-python==4.6.0.66", + "git+https://github.com/openai/CLIP.git", + "protobuf<3.20,>=3.9.2", + "modelscope", + "tensorflow==2.9.1; os_name == 'linux'", + "tensorflow-macos==2.9; os_name == 'darwin'", + "keras==2.9.0", + "torch", + "torchvision", + "transformers", + "opencv-python", + "matplotlib", + "pycocotools", + "SentencePiece", + "tf_slim", + "tf_keras", + "pyclipper", + "shapely", + "groundingdino-py", + ], } extras_require["test"] = [