Merge remote-tracking branch 'origin/main'

This commit is contained in:
kithib 2024-04-18 15:36:55 +08:00
parent 8bebed3fc2
commit 4f18c4b4b5
5 changed files with 207 additions and 58 deletions

View file

@ -0,0 +1,43 @@
# Flat hyper-parameter module for the "groundingdino" model with the
# "swin_T_224_1k" backbone (see `modelname` / `backbone` below).
# NOTE(review): presumably this is the config file that `load_model` reads via
# `SLConfig.fromfile` — confirm that the file name matches the one used there.
batch_size = 1
modelname = "groundingdino"

# --- Backbone and positional embedding ---
backbone = "swin_T_224_1k"
position_embedding = "sine"
pe_temperatureH = 20
pe_temperatureW = 20
return_interm_indices = [1, 2, 3]
backbone_freeze_keywords = None

# --- Transformer encoder / decoder ---
enc_layers = 6
dec_layers = 6
pre_norm = False
dim_feedforward = 2048
hidden_dim = 256
dropout = 0.0
nheads = 8
num_queries = 900
query_dim = 4
num_patterns = 0
num_feature_levels = 4
enc_n_points = 4
dec_n_points = 4
two_stage_type = "standard"
two_stage_bbox_embed_share = False
two_stage_class_embed_share = False
transformer_activation = "relu"
dec_pred_bbox_embed_share = True

# --- `dn_*` settings ---
# NOTE(review): presumably denoising-training options; their exact meaning is
# defined by the model code, not by this file — confirm before changing.
dn_box_noise_scale = 1.0
dn_label_noise_ratio = 0.5
dn_label_coef = 1.0
dn_bbox_coef = 1.0
embed_init_tgt = True
dn_labelbook_size = 2000

# --- Text encoder and text/image fusion switches ---
max_text_len = 256
text_encoder_type = "bert-base-uncased"
use_text_enhancer = True
use_fusion_layer = True
use_checkpoint = True
use_transformer_ckpt = True
use_text_cross_attention = True
text_dropout = 0.0
fusion_dropout = 0.0
fusion_droppath = 0.1
sub_sentence_present = True

View file

@ -1,17 +1,20 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc : The Android external environment to integrate with Android apps
import os
import subprocess
import clip
import time
from pathlib import Path
from typing import Any, Optional
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
from PIL import Image
from pydantic import Field
from metagpt.environment.android.text_icon_localization import ocr
from metagpt.environment.android.text_icon_localization import *
from metagpt.environment.android.const import ADB_EXEC_FAIL
from metagpt.environment.android.env_space import (
EnvAction,
@ -22,6 +25,7 @@ from metagpt.environment.android.env_space import (
)
from metagpt.environment.base_env import ExtEnv, mark_as_readable, mark_as_writeable
from metagpt.logs import logger
from metagpt.utils.download_modelweight import download_model
class AndroidExtEnv(ExtEnv):
@ -42,14 +46,14 @@ class AndroidExtEnv(ExtEnv):
self.width = data.get("width", width)
self.height = data.get("height", height)
self.create_device_path(self.screenshot_dir)
self.create_device_path(self.xml_dir)
# self.create_device_path(self.screenshot_dir)
# self.create_device_path(self.xml_dir)
def reset(
self,
*,
seed: Optional[int] = None,
options: Optional[dict[str, Any]] = None,
self,
*,
seed: Optional[int] = None,
options: Optional[dict[str, Any]] = None,
) -> tuple[dict[str, Any], dict[str, Any]]:
super().reset(seed=seed, options=options)
@ -154,14 +158,17 @@ class AndroidExtEnv(ExtEnv):
ss_remote_path = Path(self.screenshot_dir).joinpath(f"{ss_name}.png")
ss_cmd = f"{self.adb_prefix_shell} screencap -p {ss_remote_path}"
ss_res = self.execute_adb_with_cmd(ss_cmd)
time.sleep(0.1)
res = ADB_EXEC_FAIL
if ss_res != ADB_EXEC_FAIL:
ss_local_path = Path(local_save_dir).joinpath(f"{ss_name}.png")
pull_cmd = f"{self.adb_prefix} pull {ss_remote_path} {ss_local_path}"
pull_res = self.execute_adb_with_cmd(pull_cmd)
time.sleep(0.1)
if pull_res != ADB_EXEC_FAIL:
res = ss_local_path
else:
res = get_screenshot_only(local_save_dir)
return Path(res)
@mark_as_readable
@ -229,22 +236,22 @@ class AndroidExtEnv(ExtEnv):
return swipe_res
@mark_as_writeable
def user_swipe_to(self, start: tuple[int, int], end: tuple[int, int], duration: int = 400):
def user_swipe_to(self, start: tuple[int, int], end: tuple[int, int], duration: int = 400) -> str:
adb_cmd = f"{self.adb_prefix_si} swipe {start[0]} {start[1]} {end[0]} {end[1]} {duration}"
swipe_res = self.execute_adb_with_cmd(adb_cmd)
return swipe_res
@mark_as_writeable
def user_exit(self):
adb_cmd = "adb shell am start -a android.intent.action.MAIN -c android.intent.category.HOME"
def user_exit(self) -> str:
adb_cmd = f"{self.adb_prefix_shell} am start -a android.intent.action.MAIN -c android.intent.category.HOME"
exit_res = self.execute_adb_with_cmd(adb_cmd)
return exit_res
@mark_as_writeable
def user_openApp(self, app_name: str):
# openApp without xml
screenshot_path = self.get_screenshot("screenshot", "../../../examples/data/screenshot")
image = screenshot_path
def _ocr_text(self, text: str) -> list:
if not os.path.exists(self.screenshot_dir):
os.makedirs(self.screenshot_dir)
image = self.get_screenshot("screenshot", self.screenshot_dir)
ocr_detection = pipeline(Tasks.ocr_detection, model="damo/cv_resnet18_ocr-detection-line-level_damo")
ocr_recognition = pipeline(Tasks.ocr_recognition, model="damo/cv_convnextTiny_ocr-recognition-document_damo")
iw, ih = Image.open(image).size
@ -252,7 +259,15 @@ class AndroidExtEnv(ExtEnv):
if iw > ih:
x, y = y, x
iw, ih = ih, iw
in_coordinate, out_coordinate = ocr(image, app_name, ocr_detection, ocr_recognition, iw, ih)
in_coordinate, out_coordinate = ocr(image, text, ocr_detection, ocr_recognition, iw, ih)
output_list = [in_coordinate, out_coordinate, x, y, iw, ih, image]
return output_list
@mark_as_writeable
def user_open_app(self, app_name: str) -> str:
ocr_result = self._ocr_text(app_name)
in_coordinate, out_coordinate, x, y, iw, ih = (
ocr_result[0], ocr_result[1], ocr_result[2], ocr_result[3], ocr_result[4], ocr_result[5])
if len(in_coordinate) == 0:
logger.info(f"No App named {app_name}.")
return "no"
@ -262,11 +277,69 @@ class AndroidExtEnv(ExtEnv):
(in_coordinate[0][1] + in_coordinate[0][3]) / 2,
]
tap_coordinate = [round(tap_coordinate[0] / iw, 2), round(tap_coordinate[1] / ih, 2)]
#print(f"{parameter}在屏幕的坐标为为{tap_coordinate[0] * x} ,{(tap_coordinate[1] - round(50 / y, 2)) * y}")
return self.system_tap(tap_coordinate[0] * x, (tap_coordinate[1] - round(50 / y, 2)) * y)
@mark_as_writeable
def user_click_text(self, text: str) -> str:
    """Tap the single on-screen occurrence of ``text`` located by OCR.

    A screenshot is taken and searched through ``self._ocr_text``. The tap
    is only performed when exactly one match is reported; zero or several
    matches are logged and treated as a failure.

    Args:
        text: The text to look for in the current screenshot.

    Returns:
        The result of ``self.system_tap`` when exactly one match exists,
        otherwise ``ADB_EXEC_FAIL`` (previously the failure branches fell
        through and returned ``None`` despite the ``-> str`` annotation).
    """
    in_coordinate, out_coordinate, x, y, iw, ih, _image = self._ocr_text(text)

    if len(out_coordinate) == 0:
        logger.info(
            f"Failed to execute action click text ({text}). The text \"{text}\" is not detected in the screenshot.")
        return ADB_EXEC_FAIL
    if len(out_coordinate) > 1:
        logger.info(
            f"Failed to execute action click text ({text}). There are too many text \"{text}\" in the screenshot.")
        return ADB_EXEC_FAIL

    # Centre of the matched box, in screenshot pixels.
    box = in_coordinate[0]
    center_x = (box[0] + box[2]) / 2
    center_y = (box[1] + box[3]) / 2
    # Normalise by the screenshot size (rounded to 2 decimals, as before),
    # then scale by the x/y values supplied by `_ocr_text`.
    ratio_x = round(center_x / iw, 2)
    ratio_y = round(center_y / ih, 2)
    return self.system_tap(ratio_x * x, ratio_y * y)
@mark_as_writeable
def user_stop(self):
    """Signal the end of a task; the only effect is an info-level log line."""
    message = "Successful execution of tasks"
    logger.info(message)
# todo user_clickIcon
@mark_as_writeable
def user_click_icon(self, icon_shape_color: str) -> str:
    """Tap the icon on screen that best matches ``icon_shape_color``.

    Steps: take a screenshot, detect candidate boxes with GroundingDINO
    using the prompt "icon", and tap the centre of the chosen box. When
    exactly one box is detected it is tapped directly; otherwise every
    candidate accepted by ``crop_for_clip`` is cropped to disk and
    ``clip_for_icon`` picks the crop matching the description.

    Args:
        icon_shape_color: Free-text description of the wanted icon, passed
            unchanged to ``clip_for_icon``.

    Returns:
        The result of ``self.system_tap``, or ``ADB_EXEC_FAIL`` when no
        candidate icon is available to choose from.
    """
    # NOTE(review): assumes `metagpt.const` exposes DEFAULT_WORKSPACE_ROOT (the
    # repository's `workspace` directory) — confirm. It replaces absolute paths
    # that only existed on the original author's machine.
    from metagpt.const import DEFAULT_WORKSPACE_ROOT

    os.makedirs(self.screenshot_dir, exist_ok=True)
    image = self.get_screenshot("screenshot", self.screenshot_dir)
    device = "cpu"
    with Image.open(image) as screenshot:
        iw, ih = screenshot.size
    x, y = self.device_shape
    if iw > ih:
        # Landscape screenshot: swap both the device and the image dimensions.
        x, y = y, x
        iw, ih = ih, iw

    # Download the weight file. The `/resolve/` endpoint serves the raw file;
    # the `/blob/` endpoint used previously serves the HTML viewer page.
    file_url = "https://huggingface.co/ShilongLiu/GroundingDINO/resolve/main/groundingdino_swint_ogc.pth"
    weights_dir = os.path.join(str(DEFAULT_WORKSPACE_ROOT), "weights")
    file_path = download_model(file_url, weights_dir)
    groundingdino_model = load_model(file_path, device=device).eval()
    in_coordinate, out_coordinate = det(image, "icon", groundingdino_model)  # detect icons

    if len(out_coordinate) == 1:  # only one icon
        final_box = in_coordinate[0]
    else:
        temp_dir = os.path.join(str(DEFAULT_WORKSPACE_ROOT), "temp")
        os.makedirs(temp_dir, exist_ok=True)
        hash_table, crop_paths = [], []
        for i, (td, _box) in enumerate(zip(in_coordinate, out_coordinate)):
            if crop_for_clip(image, td, i, temp_dir):
                hash_table.append(td)
                crop_paths.append(os.path.join(temp_dir, f"{i}.jpg"))
        if not hash_table:
            # Nothing to rank: avoid calling CLIP on an empty candidate list.
            logger.info(
                f"Failed to execute action click icon ({icon_shape_color}). No icon is detected in the screenshot.")
            return ADB_EXEC_FAIL
        clip_model, clip_preprocess = clip.load("ViT-B/32", device=device)
        best_index = clip_for_icon(clip_model, clip_preprocess, crop_paths, icon_shape_color)
        final_box = hash_table[best_index]

    # Box centre -> ratio of the screenshot size (2 decimals) -> device coordinates.
    center_x = (final_box[0] + final_box[2]) / 2
    center_y = (final_box[1] + final_box[3]) / 2
    ratio_x = round(center_x / iw, 2)
    ratio_y = round(center_y / ih, 2)
    return self.system_tap(ratio_x * x, ratio_y * y)

View file

@ -3,7 +3,9 @@ import clip
import cv2
import numpy as np
import torch
import subprocess
import time
from pathlib import Path
import groundingdino.datasets.transforms as T
from groundingdino.models import build_model
from groundingdino.util.slconfig import SLConfig
@ -12,6 +14,7 @@ from PIL import Image, ImageDraw
################################## text_localization using ocr #######################
def crop_image(img, position):
def distance(x1, y1, x2, y2):
return math.sqrt(pow(x1 - x2, 2) + pow(y1 - y2, 2))
@ -96,7 +99,7 @@ def ocr(image_path, prompt, ocr_detection, ocr_recognition, x, y):
image = Image.open(image_path)
iw, ih = image.size
image_full = cv2.imread(image_path)
image_full = cv2.imread(str(image_path))
det_result = ocr_detection(image_full)
det_result = det_result["polygons"]
for i in range(det_result.shape[0]):
@ -205,7 +208,6 @@ def calculate_iou(box1, box2):
def crop(image, box, i, text_data=None):
image = Image.open(image)
if text_data:
draw = ImageDraw.Draw(image)
draw.rectangle(((text_data[0], text_data[1]), (text_data[2], text_data[3])), outline="red", width=5)
@ -224,31 +226,13 @@ def in_box(box, target):
return False
def crop_for_clip(image, box, i, position):
def crop_for_clip(image, box, i, temp_file):
image = Image.open(image)
w, h = image.size
if position == "left":
bound = [0, 0, w / 2, h]
elif position == "right":
bound = [w / 2, 0, w, h]
elif position == "top":
bound = [0, 0, w, h / 2]
elif position == "bottom":
bound = [0, h / 2, w, h]
elif position == "top left":
bound = [0, 0, w / 2, h / 2]
elif position == "top right":
bound = [w / 2, 0, w, h / 2]
elif position == "bottom left":
bound = [0, h / 2, w / 2, h]
elif position == "bottom right":
bound = [w / 2, h / 2, w, h]
else:
bound = [0, 0, w, h]
bound = [0, 0, w, h]
if in_box(box, bound):
cropped_image = image.crop(box)
cropped_image.save(f"./temp/{i}.jpg")
cropped_image.save(f"{temp_file}/{i}.jpg")
return True
else:
return False
@ -286,7 +270,8 @@ def transform_image(image_pil):
return image
def load_model(model_config_path, model_checkpoint_path, device):
def load_model(model_checkpoint_path, device):
model_config_path = 'GroundingDINO_SwinT_OGC.py'
args = SLConfig.fromfile(model_config_path)
args.device = device
model = build_model(args)
@ -387,17 +372,23 @@ def det(input_image, text_prompt, groundingdino_model, box_threshold=0.05, text_
return image_data, coordinate
if __name__ == "__main__":
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
def get_screenshot_only(screenshot_dir: Path) -> str:
    """Capture the device screen with adb and save a half-size JPEG copy.

    The screenshot is written to ``/sdcard/screenshot.png`` on the device,
    pulled into ``screenshot_dir`` and re-saved there as ``screenshot.jpg``
    at 50% of the original width and height.

    The adb commands are passed as argument lists (no shell), so a
    ``screenshot_dir`` containing spaces or shell metacharacters is handled
    correctly. As before, adb exit codes are not checked; a failed capture
    surfaces as an error when the pulled PNG is opened.

    Args:
        screenshot_dir: Local directory that receives the files; created if
            it does not exist.

    Returns:
        The path of the resized JPEG, ``"{screenshot_dir}/screenshot.jpg"``.
    """
    remote_png = "/sdcard/screenshot.png"
    Path(screenshot_dir).mkdir(parents=True, exist_ok=True)

    # Remove any previous capture so a stale image is never pulled.
    subprocess.run(["adb", "shell", "rm", remote_png], capture_output=True, text=True)
    time.sleep(0.1)
    subprocess.run(["adb", "shell", "screencap", "-p", remote_png], capture_output=True, text=True)
    time.sleep(0.1)
    subprocess.run(["adb", "pull", remote_png, str(screenshot_dir)], capture_output=True, text=True)

    image_path = f"{screenshot_dir}/screenshot.png"
    save_path = f"{screenshot_dir}/screenshot.jpg"
    with Image.open(image_path) as image:
        original_width, original_height = image.size
        new_size = (int(original_width * 0.5), int(original_height * 0.5))
        resized_image = image.resize(new_size)
    # JPEG has no alpha channel, hence the RGB conversion.
    resized_image.convert("RGB").save(save_path, "JPEG")
    time.sleep(0.1)
    return save_path
image_ori = "./screenshot/screenshot.png"
image = "./screenshot/screenshot.png"
parameter = "抖音"
ocr_detection = pipeline(Tasks.ocr_detection, model="damo/cv_resnet18_ocr-detection-line-level_damo")
ocr_recognition = pipeline(Tasks.ocr_recognition, model="damo/cv_convnextTiny_ocr-recognition-document_damo")
iw, ih = Image.open(image).size
if iw > ih:
iw, ih = ih, iw
in_coordinate, out_coordinate = ocr(image_ori, parameter, ocr_detection, ocr_recognition, iw, ih)
print(f"ocr 计算结果为 {in_coordinate} ,{out_coordinate} ")

View file

@ -0,0 +1,22 @@
import os
import requests
from pathlib import Path
def download_model(file_url: str, target_folder: str) -> str:
    """Download ``file_url`` into ``target_folder`` and return the local path.

    The file name is the last ``/``-separated component of the URL. If a
    non-empty file with that name already exists in ``target_folder`` it is
    reused and no request is made, so repeated calls do not re-fetch large
    weight files. The download is streamed to a ``.part`` file and moved
    into place only when complete, so an interrupted transfer never leaves
    a truncated file at the final path.

    Args:
        file_url: HTTP(S) URL of the file to fetch.
        target_folder: Directory to store the file in; created if missing.

    Returns:
        ``os.path.join(target_folder, <file name>)``. As before, the path is
        returned even when the server answers with an HTTP error status; the
        error is printed and no file is written in that case.

    Raises:
        requests.exceptions.RequestException: For failures other than an
            HTTP error status (connection errors, timeouts), as before.
    """
    file_name = file_url.split('/')[-1]  # file name taken from the URL
    file_path = os.path.join(target_folder, file_name)  # full local path
    os.makedirs(target_folder, exist_ok=True)

    # Reuse a previous download instead of fetching it again.
    if os.path.isfile(file_path) and os.path.getsize(file_path) > 0:
        return file_path

    part_path = f"{file_path}.part"
    try:
        # (connect, read) timeouts so a stalled server cannot hang the caller.
        with requests.get(file_url, stream=True, timeout=(10, 60)) as response:
            response.raise_for_status()  # fail on 4xx/5xx before writing anything
            with open(part_path, 'wb') as f:
                for chunk in response.iter_content(chunk_size=8192):
                    f.write(chunk)
        os.replace(part_path, file_path)
        print(f'权重文件已下载并保存至 {file_path}')
    except requests.exceptions.HTTPError as err:
        print(f'权重文件下载过程中发生错误: {err}')
    finally:
        # Drop an incomplete download, whatever interrupted it.
        if os.path.exists(part_path):
            os.remove(part_path)
    return file_path

View file

@ -70,4 +70,24 @@ qianfan==0.3.2
dashscope==1.14.1
rank-bm25==0.2.2 # for tool recommendation
jieba==0.42.1 # for tool recommendation
gymnasium==0.29.1
gymnasium==0.29.1
# for clip and ocr
git+https://github.com/openai/CLIP.git
protobuf<3.20,>=3.9.2
modelscope
# PEP 508: `os_name` is Python's os.name ("posix", "nt", ...), so it never equals
# 'linux' or 'darwin'; `sys_platform` is the marker that carries those values.
tensorflow==2.9.1; sys_platform == 'linux'
tensorflow-macos==2.9; sys_platform == 'darwin'
keras==2.9.0
torch
torchvision
transformers
opencv-python
matplotlib
pycocotools
timm
SentencePiece
tf_slim
tf_keras
pyclipper
shapely