diff --git a/metagpt/environment/android/android_ext_env.py b/metagpt/environment/android/android_ext_env.py index 78f27923f..63a421fa2 100644 --- a/metagpt/environment/android/android_ext_env.py +++ b/metagpt/environment/android/android_ext_env.py @@ -2,18 +2,17 @@ # -*- coding: utf-8 -*- # @Desc : The Android external environment to integrate with Android apps import subprocess -import clip import time from pathlib import Path from typing import Any, Optional +import clip from modelscope.pipelines import pipeline from modelscope.utils.constant import Tasks - from PIL import Image from pydantic import Field -from metagpt.environment.android.text_icon_localization import * +from metagpt.const import DEFAULT_WORKSPACE_ROOT from metagpt.environment.android.const import ADB_EXEC_FAIL from metagpt.environment.android.env_space import ( EnvAction, @@ -22,16 +21,21 @@ from metagpt.environment.android.env_space import ( EnvObsType, EnvObsValType, ) +from metagpt.environment.android.text_icon_localization import ( + clip_for_icon, + crop_for_clip, + det, + load_model, + ocr, +) from metagpt.environment.base_env import ExtEnv, mark_as_readable, mark_as_writeable from metagpt.logs import logger from metagpt.utils.common import download_model -from metagpt.const import DEFAULT_WORKSPACE_ROOT def load_cv_model(device: str = "cpu") -> any: ocr_detection = pipeline(Tasks.ocr_detection, model="damo/cv_resnet18_ocr-detection-line-level_damo") - ocr_recognition = pipeline(Tasks.ocr_recognition, - model="damo/cv_convnextTiny_ocr-recognition-document_damo") + ocr_recognition = pipeline(Tasks.ocr_recognition, model="damo/cv_convnextTiny_ocr-recognition-document_damo") file_url = "https://huggingface.co/ShilongLiu/GroundingDINO/blob/main/groundingdino_swint_ogc.pth" target_folder = Path(f"{DEFAULT_WORKSPACE_ROOT}/weights") file_path = download_model(file_url, target_folder) @@ -64,10 +68,10 @@ class AndroidExtEnv(ExtEnv): self.create_device_path(self.xml_dir) def reset( - self, - *, - seed: Optional[int] = None, - options: Optional[dict[str, Any]] = None, + self, + *, + seed: Optional[int] = None, + options: Optional[dict[str, Any]] = None, ) -> tuple[dict[str, Any], dict[str, Any]]: super().reset(seed=seed, options=options) @@ -284,8 +288,14 @@ class AndroidExtEnv(ExtEnv): @mark_as_writeable def user_open_app(self, app_name: str) -> str: ocr_result = self._ocr_text(app_name) - in_coordinate, out_coordinate, x, y, iw, ih = ( - ocr_result[0], ocr_result[1], ocr_result[2], ocr_result[3], ocr_result[4], ocr_result[5]) + in_coordinate, _, x, y, iw, ih = ( + ocr_result[0], + ocr_result[1], + ocr_result[2], + ocr_result[3], + ocr_result[4], + ocr_result[5], + ) if len(in_coordinate) == 0: logger.info(f"No App named {app_name}.") return "no app here" @@ -300,19 +310,30 @@ class AndroidExtEnv(ExtEnv): @mark_as_writeable def user_click_text(self, text: str) -> str: ocr_result = self._ocr_text(text) - in_coordinate, out_coordinate, x, y, iw, ih, image = ( - ocr_result[0], ocr_result[1], ocr_result[2], ocr_result[3], ocr_result[4], ocr_result[5], ocr_result[6]) + in_coordinate, out_coordinate, x, y, iw, ih, _ = ( + ocr_result[0], + ocr_result[1], + ocr_result[2], + ocr_result[3], + ocr_result[4], + ocr_result[5], + ocr_result[6], + ) if len(out_coordinate) == 0: logger.info( - f"Failed to execute action click text ({text}). The text \"{text}\" is not detected in the screenshot.") + f'Failed to execute action click text ({text}). The text "{text}" is not detected in the screenshot.' + ) elif len(out_coordinate) == 1: - tap_coordinate = [(in_coordinate[0][0] + in_coordinate[0][2]) / 2, - (in_coordinate[0][1] + in_coordinate[0][3]) / 2] + tap_coordinate = [ + (in_coordinate[0][0] + in_coordinate[0][2]) / 2, + (in_coordinate[0][1] + in_coordinate[0][3]) / 2, + ] tap_coordinate = [round(tap_coordinate[0] / iw, 2), round(tap_coordinate[1] / ih, 2)] return self.system_tap(tap_coordinate[0] * x, tap_coordinate[1] * y) else: logger.info( - f"Failed to execute action click text ({text}). There are too many text \"{text}\" in the screenshot.") + f'Failed to execute action click text ({text}). There are too many text "{text}" in the screenshot.' + ) @mark_as_writeable def user_stop(self): @@ -321,7 +342,7 @@ class AndroidExtEnv(ExtEnv): @mark_as_writeable def user_click_icon(self, icon_shape_color: str) -> str: screenshot_path = self.get_screenshot("screenshot", self.screenshot_dir) - image= screenshot_path + image = screenshot_path iw, ih = Image.open(image).size x, y = self.device_shape if iw > ih: @@ -329,8 +350,10 @@ class AndroidExtEnv(ExtEnv): iw, ih = ih, iw in_coordinate, out_coordinate = det(image, "icon", self.groundingdino_model) # 检测icon if len(out_coordinate) == 1: # only one icon - tap_coordinate = [(in_coordinate[0][0] + in_coordinate[0][2]) / 2, - (in_coordinate[0][1] + in_coordinate[0][3]) / 2] + tap_coordinate = [ + (in_coordinate[0][0] + in_coordinate[0][2]) / 2, + (in_coordinate[0][1] + in_coordinate[0][3]) / 2, + ] tap_coordinate = [round(tap_coordinate[0] / iw, 2), round(tap_coordinate[1] / ih, 2)] return self.system_tap(tap_coordinate[0] * x, tap_coordinate[1] * y) @@ -343,7 +366,7 @@ class AndroidExtEnv(ExtEnv): hash_table.append(td) crop_image = f"{i}.png" clip_filter.append(temp_file.joinpath(crop_image)) - clip_model, clip_preprocess = clip.load("ViT-B/32", device=device) + clip_model, clip_preprocess = clip.load("ViT-B/32") # FIXME: device=device clip_filter = clip_for_icon(clip_model, clip_preprocess, clip_filter, icon_shape_color) final_box = hash_table[clip_filter] tap_coordinate = [(final_box[0] + final_box[2]) / 2, (final_box[1] + final_box[3]) / 2] diff --git a/metagpt/environment/android/text_icon_localization.py b/metagpt/environment/android/text_icon_localization.py index 60d62ed03..e8886b540 100644 --- a/metagpt/environment/android/text_icon_localization.py +++ b/metagpt/environment/android/text_icon_localization.py @@ -2,22 +2,21 @@ # https://github.com/X-PLUG/MobileAgent.git import math +from pathlib import Path + import clip import cv2 +import groundingdino.datasets.transforms as T import numpy as np import torch -import subprocess -import time -from pathlib import Path -import groundingdino.datasets.transforms as T from groundingdino.models import build_model from groundingdino.util.slconfig import SLConfig from groundingdino.util.utils import clean_state_dict, get_phrases_from_posmap -from PIL import Image, ImageDraw - +from PIL import Image ################################## text_localization using ocr ####################### + def crop_image(img: any, position: any) -> any: def distance(x1, y1, x2, y2): return math.sqrt(pow(x1 - x2, 2) + pow(y1 - y2, 2)) @@ -271,7 +270,9 @@ def load_model(model_checkpoint_path: Path, device: str) -> any: return model -def get_grounding_output(model: any, image: any, caption: str, box_threshold: any, text_threshold: any, with_logits: bool = True) -> any: +def get_grounding_output( + model: any, image: any, caption: str, box_threshold: any, text_threshold: any, with_logits: bool = True +) -> any: caption = caption.lower() caption = caption.strip() if not caption.endswith("."): @@ -328,7 +329,13 @@ def remove_boxes(boxes_filt: any, size: any, iou_threshold: float = 0.5) -> any: return boxes_filt -def det(input_image: any, text_prompt: str, groundingdino_model: any, box_threshold:float = 0.05, text_threshold:float = 0.5) -> any: +def det( + input_image: any, + text_prompt: str, + groundingdino_model: any, + box_threshold: float = 0.05, + text_threshold: float = 0.5, +) -> any: image = Image.open(input_image) size = image.size @@ -359,5 +366,3 @@ def det(input_image: any, text_prompt: str, groundingdino_model: any, box_thresh ) return image_data, coordinate - - diff --git a/metagpt/utils/common.py b/metagpt/utils/common.py index 982e6921b..cf490084d 100644 --- a/metagpt/utils/common.py +++ b/metagpt/utils/common.py @@ -219,7 +219,7 @@ class OutputParser: if start_index != -1 and end_index != -1: # Extract the structure part - structure_text = text[start_index: end_index + 1] + structure_text = text[start_index : end_index + 1] try: # Attempt to convert the text to a Python data type using ast.literal_eval @@ -844,7 +844,7 @@ def get_markdown_codeblock_type(filename: str) -> str: def download_model(file_url: str, target_folder: Path) -> Path: - file_name = file_url.split('/')[-1] + file_name = file_url.split("/")[-1] file_path = target_folder.joinpath(f"{file_name}") if not file_path.exists(): file_path.mkdir(parents=True, exist_ok=True) @@ -852,10 +852,10 @@ def download_model(file_url: str, target_folder: Path) -> Path: response = requests.get(file_url, stream=True) response.raise_for_status() # 检查请求是否成功 # 保存文件 - with open(file_path, 'wb') as f: + with open(file_path, "wb") as f: for chunk in response.iter_content(chunk_size=8192): f.write(chunk) - logger.info(f'权重文件已下载并保存至 {file_path}') + logger.info(f"权重文件已下载并保存至 {file_path}") except requests.exceptions.HTTPError as err: - logger.info(f'权重文件下载过程中发生错误: {err}') + logger.info(f"权重文件下载过程中发生错误: {err}") return file_path diff --git a/setup.py b/setup.py index 43c043720..6a15d5eda 100644 --- a/setup.py +++ b/setup.py @@ -67,7 +67,7 @@ extras_require = { "shapely", "groundingdino-py", "datasets==2.18.0", - "clip-openai" + "clip-openai", ], } @@ -119,5 +119,4 @@ setup( ], }, include_package_data=True, - )