fix pre-commit

This commit is contained in:
geekan 2024-05-17 19:13:26 +08:00
parent d9ed99e85f
commit 43df876f24
4 changed files with 66 additions and 39 deletions

View file

@ -2,18 +2,17 @@
# -*- coding: utf-8 -*-
# @Desc : The Android external environment to integrate with Android apps
import subprocess
import clip
import time
from pathlib import Path
from typing import Any, Optional
import clip
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
from PIL import Image
from pydantic import Field
from metagpt.environment.android.text_icon_localization import *
from metagpt.const import DEFAULT_WORKSPACE_ROOT
from metagpt.environment.android.const import ADB_EXEC_FAIL
from metagpt.environment.android.env_space import (
EnvAction,
@ -22,16 +21,21 @@ from metagpt.environment.android.env_space import (
EnvObsType,
EnvObsValType,
)
from metagpt.environment.android.text_icon_localization import (
clip_for_icon,
crop_for_clip,
det,
load_model,
ocr,
)
from metagpt.environment.base_env import ExtEnv, mark_as_readable, mark_as_writeable
from metagpt.logs import logger
from metagpt.utils.common import download_model
from metagpt.const import DEFAULT_WORKSPACE_ROOT
def load_cv_model(device: str = "cpu") -> any:
ocr_detection = pipeline(Tasks.ocr_detection, model="damo/cv_resnet18_ocr-detection-line-level_damo")
ocr_recognition = pipeline(Tasks.ocr_recognition,
model="damo/cv_convnextTiny_ocr-recognition-document_damo")
ocr_recognition = pipeline(Tasks.ocr_recognition, model="damo/cv_convnextTiny_ocr-recognition-document_damo")
file_url = "https://huggingface.co/ShilongLiu/GroundingDINO/blob/main/groundingdino_swint_ogc.pth"
target_folder = Path(f"{DEFAULT_WORKSPACE_ROOT}/weights")
file_path = download_model(file_url, target_folder)
@ -64,10 +68,10 @@ class AndroidExtEnv(ExtEnv):
self.create_device_path(self.xml_dir)
def reset(
self,
*,
seed: Optional[int] = None,
options: Optional[dict[str, Any]] = None,
self,
*,
seed: Optional[int] = None,
options: Optional[dict[str, Any]] = None,
) -> tuple[dict[str, Any], dict[str, Any]]:
super().reset(seed=seed, options=options)
@ -284,8 +288,14 @@ class AndroidExtEnv(ExtEnv):
@mark_as_writeable
def user_open_app(self, app_name: str) -> str:
ocr_result = self._ocr_text(app_name)
in_coordinate, out_coordinate, x, y, iw, ih = (
ocr_result[0], ocr_result[1], ocr_result[2], ocr_result[3], ocr_result[4], ocr_result[5])
in_coordinate, _, x, y, iw, ih = (
ocr_result[0],
ocr_result[1],
ocr_result[2],
ocr_result[3],
ocr_result[4],
ocr_result[5],
)
if len(in_coordinate) == 0:
logger.info(f"No App named {app_name}.")
return "no app here"
@ -300,19 +310,30 @@ class AndroidExtEnv(ExtEnv):
@mark_as_writeable
def user_click_text(self, text: str) -> str:
ocr_result = self._ocr_text(text)
in_coordinate, out_coordinate, x, y, iw, ih, image = (
ocr_result[0], ocr_result[1], ocr_result[2], ocr_result[3], ocr_result[4], ocr_result[5], ocr_result[6])
in_coordinate, out_coordinate, x, y, iw, ih, _ = (
ocr_result[0],
ocr_result[1],
ocr_result[2],
ocr_result[3],
ocr_result[4],
ocr_result[5],
ocr_result[6],
)
if len(out_coordinate) == 0:
logger.info(
f"Failed to execute action click text ({text}). The text \"{text}\" is not detected in the screenshot.")
f'Failed to execute action click text ({text}). The text "{text}" is not detected in the screenshot.'
)
elif len(out_coordinate) == 1:
tap_coordinate = [(in_coordinate[0][0] + in_coordinate[0][2]) / 2,
(in_coordinate[0][1] + in_coordinate[0][3]) / 2]
tap_coordinate = [
(in_coordinate[0][0] + in_coordinate[0][2]) / 2,
(in_coordinate[0][1] + in_coordinate[0][3]) / 2,
]
tap_coordinate = [round(tap_coordinate[0] / iw, 2), round(tap_coordinate[1] / ih, 2)]
return self.system_tap(tap_coordinate[0] * x, tap_coordinate[1] * y)
else:
logger.info(
f"Failed to execute action click text ({text}). There are too many text \"{text}\" in the screenshot.")
f'Failed to execute action click text ({text}). There are too many text "{text}" in the screenshot.'
)
@mark_as_writeable
def user_stop(self):
@ -321,7 +342,7 @@ class AndroidExtEnv(ExtEnv):
@mark_as_writeable
def user_click_icon(self, icon_shape_color: str) -> str:
screenshot_path = self.get_screenshot("screenshot", self.screenshot_dir)
image= screenshot_path
image = screenshot_path
iw, ih = Image.open(image).size
x, y = self.device_shape
if iw > ih:
@ -329,8 +350,10 @@ class AndroidExtEnv(ExtEnv):
iw, ih = ih, iw
in_coordinate, out_coordinate = det(image, "icon", self.groundingdino_model) # 检测icon
if len(out_coordinate) == 1: # only one icon
tap_coordinate = [(in_coordinate[0][0] + in_coordinate[0][2]) / 2,
(in_coordinate[0][1] + in_coordinate[0][3]) / 2]
tap_coordinate = [
(in_coordinate[0][0] + in_coordinate[0][2]) / 2,
(in_coordinate[0][1] + in_coordinate[0][3]) / 2,
]
tap_coordinate = [round(tap_coordinate[0] / iw, 2), round(tap_coordinate[1] / ih, 2)]
return self.system_tap(tap_coordinate[0] * x, tap_coordinate[1] * y)
@ -343,7 +366,7 @@ class AndroidExtEnv(ExtEnv):
hash_table.append(td)
crop_image = f"{i}.png"
clip_filter.append(temp_file.joinpath(crop_image))
clip_model, clip_preprocess = clip.load("ViT-B/32", device=device)
clip_model, clip_preprocess = clip.load("ViT-B/32") # FIXME: device=device
clip_filter = clip_for_icon(clip_model, clip_preprocess, clip_filter, icon_shape_color)
final_box = hash_table[clip_filter]
tap_coordinate = [(final_box[0] + final_box[2]) / 2, (final_box[1] + final_box[3]) / 2]

View file

@ -2,22 +2,21 @@
# https://github.com/X-PLUG/MobileAgent.git
import math
from pathlib import Path
import clip
import cv2
import groundingdino.datasets.transforms as T
import numpy as np
import torch
import subprocess
import time
from pathlib import Path
import groundingdino.datasets.transforms as T
from groundingdino.models import build_model
from groundingdino.util.slconfig import SLConfig
from groundingdino.util.utils import clean_state_dict, get_phrases_from_posmap
from PIL import Image, ImageDraw
from PIL import Image
################################## text_localization using ocr #######################
def crop_image(img: any, position: any) -> any:
def distance(x1, y1, x2, y2):
return math.sqrt(pow(x1 - x2, 2) + pow(y1 - y2, 2))
@ -271,7 +270,9 @@ def load_model(model_checkpoint_path: Path, device: str) -> any:
return model
def get_grounding_output(model: any, image: any, caption: str, box_threshold: any, text_threshold: any, with_logits: bool = True) -> any:
def get_grounding_output(
model: any, image: any, caption: str, box_threshold: any, text_threshold: any, with_logits: bool = True
) -> any:
caption = caption.lower()
caption = caption.strip()
if not caption.endswith("."):
@ -328,7 +329,13 @@ def remove_boxes(boxes_filt: any, size: any, iou_threshold: float = 0.5) -> any:
return boxes_filt
def det(input_image: any, text_prompt: str, groundingdino_model: any, box_threshold:float = 0.05, text_threshold:float = 0.5) -> any:
def det(
input_image: any,
text_prompt: str,
groundingdino_model: any,
box_threshold: float = 0.05,
text_threshold: float = 0.5,
) -> any:
image = Image.open(input_image)
size = image.size
@ -359,5 +366,3 @@ def det(input_image: any, text_prompt: str, groundingdino_model: any, box_thresh
)
return image_data, coordinate

View file

@ -219,7 +219,7 @@ class OutputParser:
if start_index != -1 and end_index != -1:
# Extract the structure part
structure_text = text[start_index: end_index + 1]
structure_text = text[start_index : end_index + 1]
try:
# Attempt to convert the text to a Python data type using ast.literal_eval
@ -844,7 +844,7 @@ def get_markdown_codeblock_type(filename: str) -> str:
def download_model(file_url: str, target_folder: Path) -> Path:
file_name = file_url.split('/')[-1]
file_name = file_url.split("/")[-1]
file_path = target_folder.joinpath(f"{file_name}")
if not file_path.exists():
file_path.mkdir(parents=True, exist_ok=True)
@ -852,10 +852,10 @@ def download_model(file_url: str, target_folder: Path) -> Path:
response = requests.get(file_url, stream=True)
response.raise_for_status() # 检查请求是否成功
# Save the file
with open(file_path, 'wb') as f:
with open(file_path, "wb") as f:
for chunk in response.iter_content(chunk_size=8192):
f.write(chunk)
logger.info(f'权重文件已下载并保存至 {file_path}')
logger.info(f"权重文件已下载并保存至 {file_path}")
except requests.exceptions.HTTPError as err:
logger.info(f'权重文件下载过程中发生错误: {err}')
logger.info(f"权重文件下载过程中发生错误: {err}")
return file_path

View file

@ -67,7 +67,7 @@ extras_require = {
"shapely",
"groundingdino-py",
"datasets==2.18.0",
"clip-openai"
"clip-openai",
],
}
@ -119,5 +119,4 @@ setup(
],
},
include_package_data=True,
)