Merge remote-tracking branch 'origin/main'

This commit is contained in:
kithib 2024-04-19 17:56:36 +08:00
parent d2e461a1e8
commit d0e898dcfa
7 changed files with 91 additions and 111 deletions

View file

@ -1,7 +1,6 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc : The Android external environment to integrate with Android apps
import os
import subprocess
import clip
import time
@ -25,7 +24,8 @@ from metagpt.environment.android.env_space import (
)
from metagpt.environment.base_env import ExtEnv, mark_as_readable, mark_as_writeable
from metagpt.logs import logger
from metagpt.utils.download_modelweight import download_model
from metagpt.utils.common import download_model
from metagpt.const import DEFAULT_WORKSPACE_ROOT
class AndroidExtEnv(ExtEnv):
@ -46,14 +46,14 @@ class AndroidExtEnv(ExtEnv):
self.width = data.get("width", width)
self.height = data.get("height", height)
self.create_device_path(self.screenshot_dir)
self.create_device_path(self.xml_dir)
#self.create_device_path(self.screenshot_dir)
#self.create_device_path(self.xml_dir)
def reset(
self,
*,
seed: Optional[int] = None,
options: Optional[dict[str, Any]] = None,
self,
*,
seed: Optional[int] = None,
options: Optional[dict[str, Any]] = None,
) -> tuple[dict[str, Any], dict[str, Any]]:
super().reset(seed=seed, options=options)
@ -247,10 +247,9 @@ class AndroidExtEnv(ExtEnv):
exit_res = self.execute_adb_with_cmd(adb_cmd)
return exit_res
@mark_as_writeable
def _ocr_text(self, text: str) -> list:
if not os.path.exists(self.screenshot_dir):
os.makedirs(self.screenshot_dir)
if not self.screenshot_dir.exists():
self.screenshot_dir.mkdir(parents=True, exist_ok=True)
image = self.get_screenshot("screenshot", self.screenshot_dir)
ocr_detection = pipeline(Tasks.ocr_detection, model="damo/cv_resnet18_ocr-detection-line-level_damo")
ocr_recognition = pipeline(Tasks.ocr_recognition, model="damo/cv_convnextTiny_ocr-recognition-document_damo")
@ -302,8 +301,8 @@ class AndroidExtEnv(ExtEnv):
@mark_as_writeable
def user_click_icon(self, icon_shape_color: str) -> str:
if not os.path.exists(self.screenshot_dir):
os.makedirs(self.screenshot_dir)
if not self.screenshot_dir.exists():
self.screenshot_dir.mkdir(parents=True, exist_ok=True)
screenshot_path = self.get_screenshot("screenshot", self.screenshot_dir)
image, device = screenshot_path, 'cpu'
iw, ih = Image.open(image).size
@ -311,9 +310,8 @@ class AndroidExtEnv(ExtEnv):
if iw > ih:
x, y = y, x
iw, ih = ih, iw
# 下载权重文件
file_url = 'https://huggingface.co/ShilongLiu/GroundingDINO/blob/main/groundingdino_swint_ogc.pth' # 加载远程model
target_folder = 'workspace/weights'
target_folder = Path(f'{DEFAULT_WORKSPACE_ROOT}/weights')
file_path = download_model(file_url, target_folder)
groundingdino_model = load_model(file_path, device=device).eval()
in_coordinate, out_coordinate = det(image, "icon", groundingdino_model) # 检测icon
@ -324,22 +322,18 @@ class AndroidExtEnv(ExtEnv):
return self.system_tap(tap_coordinate[0] * x, tap_coordinate[1] * y)
else:
temp_file = "workspace/temp"
if not os.path.exists(temp_file):
os.mkdir(temp_file)
hash_table, clip_filter= [],[]
temp_file = Path(f"{DEFAULT_WORKSPACE_ROOT}/temp")
if not temp_file.exists():
temp_file.mkdir(parents=True, exist_ok=True)
hash_table, clip_filter = [], []
for i, (td, box) in enumerate(zip(in_coordinate, out_coordinate)):
if crop_for_clip(image, td, i, temp_file):
hash_table.append(td)
crop_image = f"{i}.jpg"
clip_filter.append(os.path.join(temp_file, crop_image))
clip_filter.append(temp_file.joinpath(crop_image))
clip_model, clip_preprocess = clip.load("ViT-B/32", device=device)
clip_filter = clip_for_icon(clip_model, clip_preprocess, clip_filter, icon_shape_color)
final_box = hash_table[clip_filter]
tap_coordinate = [(final_box[0] + final_box[2]) / 2, (final_box[1] + final_box[3]) / 2]
tap_coordinate = [round(tap_coordinate[0] / iw, 2), round(tap_coordinate[1] / ih, 2)]
return self.system_tap(tap_coordinate[0] * x, tap_coordinate[1] * y)

View file

@ -1,3 +1,6 @@
# The code in this file was adapted from MobileAgent
# https://github.com/X-PLUG/MobileAgent.git
import math
import clip
import cv2
@ -15,7 +18,7 @@ from PIL import Image, ImageDraw
################################## text_localization using ocr #######################
def crop_image(img, position):
def crop_image(img: any, position: any) -> any:
def distance(x1, y1, x2, y2):
return math.sqrt(pow(x1 - x2, 2) + pow(y1 - y2, 2))
@ -61,12 +64,12 @@ def crop_image(img, position):
return dst
def calculate_size(box: list) -> float:
    """Return the area of a box indexed as ``[x1, y1, x2, y2]``.

    The area is ``(box[2] - box[0]) * (box[3] - box[1])``; no validation is
    done, so a box whose second corner precedes the first yields a negative
    or sign-flipped result.

    Args:
        box: Indexable with at least four numeric coordinates.

    Returns:
        The product of the box's width and height.
    """
    width = box[2] - box[0]
    height = box[3] - box[1]
    return width * height
def order_point(coor):
arr = np.array(coor).reshape([4, 2])
def order_point(cooperation: any) -> any:
arr = np.array(cooperation).reshape([4, 2])
sum_ = np.sum(arr, 0)
centroid = sum_ / arr.shape[0]
theta = np.arctan2(arr[:, 1] - centroid[1], arr[:, 0] - centroid[0])
@ -78,11 +81,10 @@ def order_point(coor):
return sort_points
def longest_common_substring_length(str1, str2):
def longest_common_substring_length(str1: str, str2: str) -> int:
m = len(str1)
n = len(str2)
dp = [[0] * (n + 1) for _ in range(m + 1)]
for i in range(1, m + 1):
for j in range(1, n + 1):
if str1[i - 1] == str2[j - 1]:
@ -93,7 +95,7 @@ def longest_common_substring_length(str1, str2):
return dp[m][n]
def ocr(image_path, prompt, ocr_detection, ocr_recognition, x, y):
def ocr(image_path: Path, prompt: str, ocr_detection: any, ocr_recognition: any, x: int, y: int) -> any:
text_data = []
coordinate = []
image = Image.open(image_path)
@ -191,54 +193,41 @@ def ocr(image_path, prompt, ocr_detection, ocr_recognition, x, y):
################################## icon_localization using clip #######################
def calculate_iou(box1: list, box2: list) -> float:
    """Return the intersection-over-union of two boxes indexed as ``[x1, y1, x2, y2]``.

    Args:
        box1: First box, four numeric coordinates.
        box2: Second box, four numeric coordinates.

    Returns:
        Intersection area divided by union area. When the union has no area
        (both boxes are degenerate), 0.0 is returned instead of dividing by
        zero.
    """
    # Corners of the overlapping rectangle (if any)
    x_a = max(box1[0], box2[0])
    y_a = max(box1[1], box2[1])
    x_b = min(box1[2], box2[2])
    y_b = min(box1[3], box2[3])

    # Clamp at 0 so disjoint boxes contribute no intersection area
    inter_area = max(0, x_b - x_a) * max(0, y_b - y_a)
    box1_area = (box1[2] - box1[0]) * (box1[3] - box1[1])
    box2_area = (box2[2] - box2[0]) * (box2[3] - box2[1])
    union_area = box1_area + box2_area - inter_area

    if union_area == 0:
        # Two zero-area boxes: there is nothing to overlap
        return 0.0
    return inter_area / union_area
def crop(image, box, i, text_data=None):
    """Crop ``box`` out of the image opened from ``image`` and save it as ``./temp/{i}.jpg``.

    When ``text_data`` is truthy, a red rectangle with corners
    ``(text_data[0], text_data[1])`` and ``(text_data[2], text_data[3])`` is
    drawn on the image before it is cropped.
    """
    picture = Image.open(image)

    if text_data:
        top_left = (text_data[0], text_data[1])
        bottom_right = (text_data[2], text_data[3])
        ImageDraw.Draw(picture).rectangle((top_left, bottom_right), outline="red", width=5)
        # Drawing the index label next to the rectangle is intentionally left disabled.

    picture.crop(box).save(f"./temp/{i}.jpg")
def in_box(box: list, target: list) -> bool:
    """Tell whether ``box`` lies strictly inside ``target``.

    Both arguments are indexed as ``[x1, y1, x2, y2]``. The comparison is
    strict on every side, so a box that touches an edge of ``target`` is not
    considered inside.
    """
    # First corner must be strictly past target's first corner
    if not (box[0] > target[0] and box[1] > target[1]):
        return False
    # Second corner must stop strictly before target's second corner
    return bool(box[2] < target[2] and box[3] < target[3])
def crop_for_clip(image: Path, box: list, i: int, temp_file: Path) -> bool:
    """Crop ``box`` from an image and save it as ``{i}.jpg`` inside ``temp_file``.

    The crop is only written when ``box`` lies strictly inside the image
    bounds ``[0, 0, width, height]`` (checked with ``in_box``).

    Args:
        image: Image to open with ``Image.open`` (callers in this project pass a path).
        box: Crop rectangle indexed as ``[x1, y1, x2, y2]``.
        i: Index used as the output file name.
        temp_file: Folder that receives the crop. Despite the name this is a
            directory; a ``str`` is accepted as well as a ``Path``.

    Returns:
        True if the crop was saved, False if ``box`` was not inside the image.
    """
    # Context manager releases the underlying file handle once the crop is saved
    with Image.open(image) as source:
        width, height = source.size
        if not in_box(box, [0, 0, width, height]):
            return False
        # Path(...) keeps older callers that pass a plain string folder working
        source.crop(box).save(Path(temp_file).joinpath(f"{i}.jpg"))
    return True
def clip_for_icon(clip_model, clip_preprocess, images, prompt):
def clip_for_icon(clip_model: any, clip_preprocess: any, images: any, prompt: str) -> any:
image_features = []
for image_file in images:
image = clip_preprocess(Image.open(image_file)).unsqueeze(0).to(next(clip_model.parameters()).device)
@ -258,7 +247,7 @@ def clip_for_icon(clip_model, clip_preprocess, images, prompt):
return pos
def transform_image(image_pil):
def transform_image(image_pil: any) -> any:
transform = T.Compose(
[
T.RandomResize([800], max_size=1333),
@ -270,8 +259,8 @@ def transform_image(image_pil):
return image
def load_model(model_checkpoint_path, device):
model_config_path = 'GroundingDINO_SwinT_OGC.py'
def load_model(model_checkpoint_path: Path, device: str) -> any:
model_config_path = 'grounding_dino_config.py'
args = SLConfig.fromfile(model_config_path)
args.device = device
model = build_model(args)
@ -282,7 +271,7 @@ def load_model(model_checkpoint_path, device):
return model
def get_grounding_output(model, image, caption, box_threshold, text_threshold, with_logits=True):
def get_grounding_output(model: any, image: any, caption: str, box_threshold: any, text_threshold: any, with_logits=True) -> any:
caption = caption.lower()
caption = caption.strip()
if not caption.endswith("."):
@ -317,7 +306,7 @@ def get_grounding_output(model, image, caption, box_threshold, text_threshold, w
return boxes_filt, torch.Tensor(scores), pred_phrases
def remove_boxes(boxes_filt, size, iou_threshold=0.5):
def remove_boxes(boxes_filt: any, size: any, iou_threshold=0.5) -> any:
boxes_to_remove = set()
for i in range(len(boxes_filt)):
@ -339,7 +328,7 @@ def remove_boxes(boxes_filt, size, iou_threshold=0.5):
return boxes_filt
def det(input_image, text_prompt, groundingdino_model, box_threshold=0.05, text_threshold=0.5):
def det(input_image: any, text_prompt: str, groundingdino_model: any, box_threshold=0.05, text_threshold=0.5) -> any:
image = Image.open(input_image)
size = image.size
@ -372,7 +361,7 @@ def det(input_image, text_prompt, groundingdino_model, box_threshold=0.05, text_
return image_data, coordinate
def get_screenshot_only(screenshot_dir: Path) -> str:
def get_screenshot_only(screenshot_dir: Path) -> Path:
command = " adb shell rm /sdcard/screenshot.png"
subprocess.run(command, capture_output=True, text=True, shell=True)
time.sleep(0.1)
@ -381,8 +370,8 @@ def get_screenshot_only(screenshot_dir: Path) -> str:
time.sleep(0.1)
command = f"adb pull /sdcard/screenshot.png {screenshot_dir}"
subprocess.run(command, capture_output=True, text=True, shell=True)
image_path = f"{screenshot_dir}/screenshot.png"
save_path = f"{screenshot_dir}/screenshot.jpg"
image_path = Path(f"{screenshot_dir}/screenshot.png")
save_path = Path(f"{screenshot_dir}/screenshot.jpg")
image = Image.open(image_path)
original_width, original_height = image.size
new_width = int(original_width * 0.5)
@ -391,4 +380,3 @@ def get_screenshot_only(screenshot_dir: Path) -> str:
resized_image.convert("RGB").save(save_path, "JPEG")
time.sleep(0.1)
return save_path

View file

@ -219,7 +219,7 @@ class OutputParser:
if start_index != -1 and end_index != -1:
# Extract the structure part
structure_text = text[start_index : end_index + 1]
structure_text = text[start_index: end_index + 1]
try:
# Attempt to convert the text to a Python data type using ast.literal_eval
@ -841,3 +841,21 @@ def get_markdown_codeblock_type(filename: str) -> str:
"application/sql": "sql",
}
return mappings.get(mime_type, "text")
def download_model(file_url: str, target_folder: Path) -> Path:
    """Download ``file_url`` into ``target_folder`` and return the local file path.

    The file name is the last ``/``-separated segment of ``file_url``. If a
    file with that name already exists in ``target_folder`` it is reused and
    nothing is downloaded. Otherwise the content is streamed to a temporary
    ``<name>.part`` file and moved into place only after the download
    finished, so an interrupted download never leaves a truncated file at the
    final path.

    Args:
        file_url: URL that serves the file content directly.
        target_folder: Directory that receives the file; created if missing.

    Returns:
        ``target_folder / <file name>``. When the server answers with an HTTP
        error the error is logged and the path is still returned, but no file
        exists there.
    """
    file_name = file_url.split('/')[-1]
    file_path = target_folder.joinpath(file_name)

    # Create the folder that holds the file. Creating a directory at
    # file_path itself would make the open(..., 'wb') below fail.
    target_folder.mkdir(parents=True, exist_ok=True)

    if file_path.is_file():
        # Already downloaded: avoid fetching the weights again
        return file_path

    partial_path = target_folder.joinpath(f"{file_name}.part")
    try:
        response = requests.get(file_url, stream=True)
        response.raise_for_status()  # raises HTTPError on a 4xx/5xx answer
        with open(partial_path, 'wb') as f:
            for chunk in response.iter_content(chunk_size=8192):
                f.write(chunk)
        # Publish the file only once it is complete
        partial_path.replace(file_path)
        logger.info(f'权重文件已下载并保存至 {file_path}')
    except requests.exceptions.HTTPError as err:
        logger.error(f'权重文件下载过程中发生错误: {err}')
    finally:
        # Drop leftovers of a failed or interrupted download
        if partial_path.exists():
            partial_path.unlink()
    return file_path

View file

@ -1,22 +0,0 @@
import os
import requests
from pathlib import Path
def download_model(file_url: str, target_folder: str) -> str:
    """Fetch ``file_url`` into ``target_folder`` and return the local file path.

    The file name is the last ``/``-separated segment of the URL. An HTTP
    error status is reported with ``print`` and the path is returned anyway.
    """
    # Full destination path: folder + file name taken from the URL
    destination = os.path.join(target_folder, file_url.rsplit('/', 1)[-1])

    if not os.path.exists(target_folder):
        os.makedirs(target_folder)

    try:
        # Stream the body so the whole file is never held in memory
        reply = requests.get(file_url, stream=True)
        reply.raise_for_status()  # fail on 4xx/5xx answers
        with open(destination, 'wb') as out:
            out.writelines(reply.iter_content(chunk_size=8192))
        print(f'权重文件已下载并保存至 {destination}')
    except requests.exceptions.HTTPError as err:
        print(f'权重文件下载过程中发生错误: {err}')

    return destination

View file

@ -71,23 +71,4 @@ dashscope==1.14.1
rank-bm25==0.2.2 # for tool recommendation
jieba==0.42.1 # for tool recommendation
gymnasium==0.29.1
# for clip and ocr
git+https://github.com/openai/CLIP.git
protobuf<3.20,>=3.9.2
modelscope
tensorflow==2.9.1; os_name == 'linux'
tensorflow-macos==2.9; os_name == 'darwin'
keras==2.9.0
torch
torchvision
transformers
opencv-python
matplotlib
pycocotools
timm
SentencePiece
tf_slim
tf_keras
pyclipper
shapely

View file

@ -42,7 +42,28 @@ extras_require = {
"llama-index-vector-stores-chroma==0.1.6",
"docx2txt==0.8",
],
"android_assistant": ["pyshine==0.0.9", "opencv-python==4.6.0.66"],
"android_assistant": [
"pyshine==0.0.9",
"opencv-python==4.6.0.66",
"git+https://github.com/openai/CLIP.git",
"protobuf<3.20,>=3.9.2",
"modelscope",
"tensorflow==2.9.1; os_name == 'linux'",
"tensorflow-macos==2.9; os_name == 'darwin'",
"keras==2.9.0",
"torch",
"torchvision",
"transformers",
"opencv-python",
"matplotlib",
"pycocotools",
"SentencePiece",
"tf_slim",
"tf_keras",
"pyclipper",
"shapely",
"groundingdino-py",
],
}
extras_require["test"] = [