mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-06-08 15:05:17 +02:00
fix pre-commit
This commit is contained in:
parent
d9ed99e85f
commit
43df876f24
4 changed files with 66 additions and 39 deletions
|
|
@ -2,18 +2,17 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# @Desc : The Android external environment to integrate with Android apps
|
||||
import subprocess
|
||||
import clip
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional
|
||||
|
||||
import clip
|
||||
from modelscope.pipelines import pipeline
|
||||
from modelscope.utils.constant import Tasks
|
||||
|
||||
from PIL import Image
|
||||
from pydantic import Field
|
||||
|
||||
from metagpt.environment.android.text_icon_localization import *
|
||||
from metagpt.const import DEFAULT_WORKSPACE_ROOT
|
||||
from metagpt.environment.android.const import ADB_EXEC_FAIL
|
||||
from metagpt.environment.android.env_space import (
|
||||
EnvAction,
|
||||
|
|
@ -22,16 +21,21 @@ from metagpt.environment.android.env_space import (
|
|||
EnvObsType,
|
||||
EnvObsValType,
|
||||
)
|
||||
from metagpt.environment.android.text_icon_localization import (
|
||||
clip_for_icon,
|
||||
crop_for_clip,
|
||||
det,
|
||||
load_model,
|
||||
ocr,
|
||||
)
|
||||
from metagpt.environment.base_env import ExtEnv, mark_as_readable, mark_as_writeable
|
||||
from metagpt.logs import logger
|
||||
from metagpt.utils.common import download_model
|
||||
from metagpt.const import DEFAULT_WORKSPACE_ROOT
|
||||
|
||||
|
||||
def load_cv_model(device: str = "cpu") -> any:
|
||||
ocr_detection = pipeline(Tasks.ocr_detection, model="damo/cv_resnet18_ocr-detection-line-level_damo")
|
||||
ocr_recognition = pipeline(Tasks.ocr_recognition,
|
||||
model="damo/cv_convnextTiny_ocr-recognition-document_damo")
|
||||
ocr_recognition = pipeline(Tasks.ocr_recognition, model="damo/cv_convnextTiny_ocr-recognition-document_damo")
|
||||
file_url = "https://huggingface.co/ShilongLiu/GroundingDINO/blob/main/groundingdino_swint_ogc.pth"
|
||||
target_folder = Path(f"{DEFAULT_WORKSPACE_ROOT}/weights")
|
||||
file_path = download_model(file_url, target_folder)
|
||||
|
|
@ -64,10 +68,10 @@ class AndroidExtEnv(ExtEnv):
|
|||
self.create_device_path(self.xml_dir)
|
||||
|
||||
def reset(
|
||||
self,
|
||||
*,
|
||||
seed: Optional[int] = None,
|
||||
options: Optional[dict[str, Any]] = None,
|
||||
self,
|
||||
*,
|
||||
seed: Optional[int] = None,
|
||||
options: Optional[dict[str, Any]] = None,
|
||||
) -> tuple[dict[str, Any], dict[str, Any]]:
|
||||
super().reset(seed=seed, options=options)
|
||||
|
||||
|
|
@ -284,8 +288,14 @@ class AndroidExtEnv(ExtEnv):
|
|||
@mark_as_writeable
|
||||
def user_open_app(self, app_name: str) -> str:
|
||||
ocr_result = self._ocr_text(app_name)
|
||||
in_coordinate, out_coordinate, x, y, iw, ih = (
|
||||
ocr_result[0], ocr_result[1], ocr_result[2], ocr_result[3], ocr_result[4], ocr_result[5])
|
||||
in_coordinate, _, x, y, iw, ih = (
|
||||
ocr_result[0],
|
||||
ocr_result[1],
|
||||
ocr_result[2],
|
||||
ocr_result[3],
|
||||
ocr_result[4],
|
||||
ocr_result[5],
|
||||
)
|
||||
if len(in_coordinate) == 0:
|
||||
logger.info(f"No App named {app_name}.")
|
||||
return "no app here"
|
||||
|
|
@ -300,19 +310,30 @@ class AndroidExtEnv(ExtEnv):
|
|||
@mark_as_writeable
|
||||
def user_click_text(self, text: str) -> str:
|
||||
ocr_result = self._ocr_text(text)
|
||||
in_coordinate, out_coordinate, x, y, iw, ih, image = (
|
||||
ocr_result[0], ocr_result[1], ocr_result[2], ocr_result[3], ocr_result[4], ocr_result[5], ocr_result[6])
|
||||
in_coordinate, out_coordinate, x, y, iw, ih, _ = (
|
||||
ocr_result[0],
|
||||
ocr_result[1],
|
||||
ocr_result[2],
|
||||
ocr_result[3],
|
||||
ocr_result[4],
|
||||
ocr_result[5],
|
||||
ocr_result[6],
|
||||
)
|
||||
if len(out_coordinate) == 0:
|
||||
logger.info(
|
||||
f"Failed to execute action click text ({text}). The text \"{text}\" is not detected in the screenshot.")
|
||||
f'Failed to execute action click text ({text}). The text "{text}" is not detected in the screenshot.'
|
||||
)
|
||||
elif len(out_coordinate) == 1:
|
||||
tap_coordinate = [(in_coordinate[0][0] + in_coordinate[0][2]) / 2,
|
||||
(in_coordinate[0][1] + in_coordinate[0][3]) / 2]
|
||||
tap_coordinate = [
|
||||
(in_coordinate[0][0] + in_coordinate[0][2]) / 2,
|
||||
(in_coordinate[0][1] + in_coordinate[0][3]) / 2,
|
||||
]
|
||||
tap_coordinate = [round(tap_coordinate[0] / iw, 2), round(tap_coordinate[1] / ih, 2)]
|
||||
return self.system_tap(tap_coordinate[0] * x, tap_coordinate[1] * y)
|
||||
else:
|
||||
logger.info(
|
||||
f"Failed to execute action click text ({text}). There are too many text \"{text}\" in the screenshot.")
|
||||
f'Failed to execute action click text ({text}). There are too many text "{text}" in the screenshot.'
|
||||
)
|
||||
|
||||
@mark_as_writeable
|
||||
def user_stop(self):
|
||||
|
|
@ -321,7 +342,7 @@ class AndroidExtEnv(ExtEnv):
|
|||
@mark_as_writeable
|
||||
def user_click_icon(self, icon_shape_color: str) -> str:
|
||||
screenshot_path = self.get_screenshot("screenshot", self.screenshot_dir)
|
||||
image= screenshot_path
|
||||
image = screenshot_path
|
||||
iw, ih = Image.open(image).size
|
||||
x, y = self.device_shape
|
||||
if iw > ih:
|
||||
|
|
@ -329,8 +350,10 @@ class AndroidExtEnv(ExtEnv):
|
|||
iw, ih = ih, iw
|
||||
in_coordinate, out_coordinate = det(image, "icon", self.groundingdino_model) # 检测icon
|
||||
if len(out_coordinate) == 1: # only one icon
|
||||
tap_coordinate = [(in_coordinate[0][0] + in_coordinate[0][2]) / 2,
|
||||
(in_coordinate[0][1] + in_coordinate[0][3]) / 2]
|
||||
tap_coordinate = [
|
||||
(in_coordinate[0][0] + in_coordinate[0][2]) / 2,
|
||||
(in_coordinate[0][1] + in_coordinate[0][3]) / 2,
|
||||
]
|
||||
tap_coordinate = [round(tap_coordinate[0] / iw, 2), round(tap_coordinate[1] / ih, 2)]
|
||||
return self.system_tap(tap_coordinate[0] * x, tap_coordinate[1] * y)
|
||||
|
||||
|
|
@ -343,7 +366,7 @@ class AndroidExtEnv(ExtEnv):
|
|||
hash_table.append(td)
|
||||
crop_image = f"{i}.png"
|
||||
clip_filter.append(temp_file.joinpath(crop_image))
|
||||
clip_model, clip_preprocess = clip.load("ViT-B/32", device=device)
|
||||
clip_model, clip_preprocess = clip.load("ViT-B/32") # FIXME: device=device
|
||||
clip_filter = clip_for_icon(clip_model, clip_preprocess, clip_filter, icon_shape_color)
|
||||
final_box = hash_table[clip_filter]
|
||||
tap_coordinate = [(final_box[0] + final_box[2]) / 2, (final_box[1] + final_box[3]) / 2]
|
||||
|
|
|
|||
|
|
@ -2,22 +2,21 @@
|
|||
# https://github.com/X-PLUG/MobileAgent.git
|
||||
|
||||
import math
|
||||
from pathlib import Path
|
||||
|
||||
import clip
|
||||
import cv2
|
||||
import groundingdino.datasets.transforms as T
|
||||
import numpy as np
|
||||
import torch
|
||||
import subprocess
|
||||
import time
|
||||
from pathlib import Path
|
||||
import groundingdino.datasets.transforms as T
|
||||
from groundingdino.models import build_model
|
||||
from groundingdino.util.slconfig import SLConfig
|
||||
from groundingdino.util.utils import clean_state_dict, get_phrases_from_posmap
|
||||
from PIL import Image, ImageDraw
|
||||
|
||||
from PIL import Image
|
||||
|
||||
################################## text_localization using ocr #######################
|
||||
|
||||
|
||||
def crop_image(img: any, position: any) -> any:
|
||||
def distance(x1, y1, x2, y2):
|
||||
return math.sqrt(pow(x1 - x2, 2) + pow(y1 - y2, 2))
|
||||
|
|
@ -271,7 +270,9 @@ def load_model(model_checkpoint_path: Path, device: str) -> any:
|
|||
return model
|
||||
|
||||
|
||||
def get_grounding_output(model: any, image: any, caption: str, box_threshold: any, text_threshold: any, with_logits: bool = True) -> any:
|
||||
def get_grounding_output(
|
||||
model: any, image: any, caption: str, box_threshold: any, text_threshold: any, with_logits: bool = True
|
||||
) -> any:
|
||||
caption = caption.lower()
|
||||
caption = caption.strip()
|
||||
if not caption.endswith("."):
|
||||
|
|
@ -328,7 +329,13 @@ def remove_boxes(boxes_filt: any, size: any, iou_threshold: float = 0.5) -> any:
|
|||
return boxes_filt
|
||||
|
||||
|
||||
def det(input_image: any, text_prompt: str, groundingdino_model: any, box_threshold:float = 0.05, text_threshold:float = 0.5) -> any:
|
||||
def det(
|
||||
input_image: any,
|
||||
text_prompt: str,
|
||||
groundingdino_model: any,
|
||||
box_threshold: float = 0.05,
|
||||
text_threshold: float = 0.5,
|
||||
) -> any:
|
||||
image = Image.open(input_image)
|
||||
size = image.size
|
||||
|
||||
|
|
@ -359,5 +366,3 @@ def det(input_image: any, text_prompt: str, groundingdino_model: any, box_thresh
|
|||
)
|
||||
|
||||
return image_data, coordinate
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -219,7 +219,7 @@ class OutputParser:
|
|||
|
||||
if start_index != -1 and end_index != -1:
|
||||
# Extract the structure part
|
||||
structure_text = text[start_index: end_index + 1]
|
||||
structure_text = text[start_index : end_index + 1]
|
||||
|
||||
try:
|
||||
# Attempt to convert the text to a Python data type using ast.literal_eval
|
||||
|
|
@ -844,7 +844,7 @@ def get_markdown_codeblock_type(filename: str) -> str:
|
|||
|
||||
|
||||
def download_model(file_url: str, target_folder: Path) -> Path:
|
||||
file_name = file_url.split('/')[-1]
|
||||
file_name = file_url.split("/")[-1]
|
||||
file_path = target_folder.joinpath(f"{file_name}")
|
||||
if not file_path.exists():
|
||||
file_path.mkdir(parents=True, exist_ok=True)
|
||||
|
|
@ -852,10 +852,10 @@ def download_model(file_url: str, target_folder: Path) -> Path:
|
|||
response = requests.get(file_url, stream=True)
|
||||
response.raise_for_status() # 检查请求是否成功
|
||||
# 保存文件
|
||||
with open(file_path, 'wb') as f:
|
||||
with open(file_path, "wb") as f:
|
||||
for chunk in response.iter_content(chunk_size=8192):
|
||||
f.write(chunk)
|
||||
logger.info(f'权重文件已下载并保存至 {file_path}')
|
||||
logger.info(f"权重文件已下载并保存至 {file_path}")
|
||||
except requests.exceptions.HTTPError as err:
|
||||
logger.info(f'权重文件下载过程中发生错误: {err}')
|
||||
logger.info(f"权重文件下载过程中发生错误: {err}")
|
||||
return file_path
|
||||
|
|
|
|||
3
setup.py
3
setup.py
|
|
@ -67,7 +67,7 @@ extras_require = {
|
|||
"shapely",
|
||||
"groundingdino-py",
|
||||
"datasets==2.18.0",
|
||||
"clip-openai"
|
||||
"clip-openai",
|
||||
],
|
||||
}
|
||||
|
||||
|
|
@ -119,5 +119,4 @@ setup(
|
|||
],
|
||||
},
|
||||
include_package_data=True,
|
||||
|
||||
)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue