Merge pull request #1207 from kithib/main

add new android operation(open app,exit,stop)
This commit is contained in:
Alexander Wu 2024-05-17 19:08:00 +08:00 committed by GitHub
commit d9ed99e85f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 580 additions and 10 deletions

View file

@ -1,13 +1,19 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc : The Android external environment to integrate with Android apps
import subprocess
import clip
import time
from pathlib import Path
from typing import Any, Optional
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
from PIL import Image
from pydantic import Field
from metagpt.environment.android.text_icon_localization import *
from metagpt.environment.android.const import ADB_EXEC_FAIL
from metagpt.environment.android.env_space import (
EnvAction,
@ -17,6 +23,20 @@ from metagpt.environment.android.env_space import (
EnvObsValType,
)
from metagpt.environment.base_env import ExtEnv, mark_as_readable, mark_as_writeable
from metagpt.logs import logger
from metagpt.utils.common import download_model
from metagpt.const import DEFAULT_WORKSPACE_ROOT
def load_cv_model(device: str = "cpu") -> any:
ocr_detection = pipeline(Tasks.ocr_detection, model="damo/cv_resnet18_ocr-detection-line-level_damo")
ocr_recognition = pipeline(Tasks.ocr_recognition,
model="damo/cv_convnextTiny_ocr-recognition-document_damo")
file_url = "https://huggingface.co/ShilongLiu/GroundingDINO/blob/main/groundingdino_swint_ogc.pth"
target_folder = Path(f"{DEFAULT_WORKSPACE_ROOT}/weights")
file_path = download_model(file_url, target_folder)
groundingdino_model = load_model(file_path, device=device).eval()
return ocr_detection, ocr_recognition, groundingdino_model
class AndroidExtEnv(ExtEnv):
@ -25,10 +45,14 @@ class AndroidExtEnv(ExtEnv):
xml_dir: Optional[Path] = Field(default=None)
width: int = Field(default=720, description="device screen width")
height: int = Field(default=1080, description="device screen height")
ocr_detection: any = Field(default=None, description="ocr detection model")
ocr_recognition: any = Field(default=None, description="ocr recognition model")
groundingdino_model: any = Field(default=None, description="clip groundingdino model")
def __init__(self, **data: Any):
super().__init__(**data)
device_id = data.get("device_id")
self.ocr_detection, self.ocr_recognition, self.groundingdino_model = load_cv_model()
if device_id:
devices = self.list_devices()
if device_id not in devices:
@ -36,15 +60,14 @@ class AndroidExtEnv(ExtEnv):
(width, height) = self.device_shape
self.width = data.get("width", width)
self.height = data.get("height", height)
self.create_device_path(self.screenshot_dir)
self.create_device_path(self.xml_dir)
def reset(
self,
*,
seed: Optional[int] = None,
options: Optional[dict[str, Any]] = None,
self,
*,
seed: Optional[int] = None,
options: Optional[dict[str, Any]] = None,
) -> tuple[dict[str, Any], dict[str, Any]]:
super().reset(seed=seed, options=options)
@ -149,14 +172,26 @@ class AndroidExtEnv(ExtEnv):
ss_remote_path = Path(self.screenshot_dir).joinpath(f"{ss_name}.png")
ss_cmd = f"{self.adb_prefix_shell} screencap -p {ss_remote_path}"
ss_res = self.execute_adb_with_cmd(ss_cmd)
time.sleep(0.1)
res = ADB_EXEC_FAIL
if ss_res != ADB_EXEC_FAIL:
ss_local_path = Path(local_save_dir).joinpath(f"{ss_name}.png")
pull_cmd = f"{self.adb_prefix} pull {ss_remote_path} {ss_local_path}"
pull_res = self.execute_adb_with_cmd(pull_cmd)
time.sleep(0.1)
if pull_res != ADB_EXEC_FAIL:
res = ss_local_path
else:
ss_cmd = f"{self.adb_prefix_shell} rm /sdcard/{ss_name}.png"
ss_res = self.execute_adb_with_cmd(ss_cmd)
time.sleep(0.1)
ss_cmd = f"{self.adb_prefix_shell} screencap -p /sdcard/{ss_name}.png"
ss_res = self.execute_adb_with_cmd(ss_cmd)
time.sleep(0.1)
ss_cmd = f"{self.adb_prefix} pull /sdcard/{ss_name}.png {self.screenshot_dir}"
ss_res = self.execute_adb_with_cmd(ss_cmd)
image_path = Path(f"{self.screenshot_dir}/{ss_name}.png")
res = image_path
return Path(res)
@mark_as_readable
@ -224,7 +259,94 @@ class AndroidExtEnv(ExtEnv):
return swipe_res
@mark_as_writeable
def user_swipe_to(self, start: tuple[int, int], end: tuple[int, int], duration: int = 400):
def user_swipe_to(self, start: tuple[int, int], end: tuple[int, int], duration: int = 400) -> str:
adb_cmd = f"{self.adb_prefix_si} swipe {start[0]} {start[1]} {end[0]} {end[1]} {duration}"
swipe_res = self.execute_adb_with_cmd(adb_cmd)
return swipe_res
@mark_as_writeable
def user_exit(self) -> str:
adb_cmd = f"{self.adb_prefix_shell} am start -a android.intent.action.MAIN -c android.intent.category.HOME"
exit_res = self.execute_adb_with_cmd(adb_cmd)
return exit_res
def _ocr_text(self, text: str) -> list:
image = self.get_screenshot("screenshot", self.screenshot_dir)
iw, ih = Image.open(image).size
x, y = self.device_shape
if iw > ih:
x, y = y, x
iw, ih = ih, iw
in_coordinate, out_coordinate = ocr(image, text, self.ocr_detection, self.ocr_recognition, iw, ih)
output_list = [in_coordinate, out_coordinate, x, y, iw, ih, image]
return output_list
@mark_as_writeable
def user_open_app(self, app_name: str) -> str:
ocr_result = self._ocr_text(app_name)
in_coordinate, out_coordinate, x, y, iw, ih = (
ocr_result[0], ocr_result[1], ocr_result[2], ocr_result[3], ocr_result[4], ocr_result[5])
if len(in_coordinate) == 0:
logger.info(f"No App named {app_name}.")
return "no app here"
else:
tap_coordinate = [
(in_coordinate[0][0] + in_coordinate[0][2]) / 2,
(in_coordinate[0][1] + in_coordinate[0][3]) / 2,
]
tap_coordinate = [round(tap_coordinate[0] / iw, 2), round(tap_coordinate[1] / ih, 2)]
return self.system_tap(tap_coordinate[0] * x, (tap_coordinate[1] - round(50 / y, 2)) * y)
@mark_as_writeable
def user_click_text(self, text: str) -> str:
ocr_result = self._ocr_text(text)
in_coordinate, out_coordinate, x, y, iw, ih, image = (
ocr_result[0], ocr_result[1], ocr_result[2], ocr_result[3], ocr_result[4], ocr_result[5], ocr_result[6])
if len(out_coordinate) == 0:
logger.info(
f"Failed to execute action click text ({text}). The text \"{text}\" is not detected in the screenshot.")
elif len(out_coordinate) == 1:
tap_coordinate = [(in_coordinate[0][0] + in_coordinate[0][2]) / 2,
(in_coordinate[0][1] + in_coordinate[0][3]) / 2]
tap_coordinate = [round(tap_coordinate[0] / iw, 2), round(tap_coordinate[1] / ih, 2)]
return self.system_tap(tap_coordinate[0] * x, tap_coordinate[1] * y)
else:
logger.info(
f"Failed to execute action click text ({text}). There are too many text \"{text}\" in the screenshot.")
@mark_as_writeable
def user_stop(self):
logger.info("Successful execution of tasks")
@mark_as_writeable
def user_click_icon(self, icon_shape_color: str) -> str:
screenshot_path = self.get_screenshot("screenshot", self.screenshot_dir)
image= screenshot_path
iw, ih = Image.open(image).size
x, y = self.device_shape
if iw > ih:
x, y = y, x
iw, ih = ih, iw
in_coordinate, out_coordinate = det(image, "icon", self.groundingdino_model) # 检测icon
if len(out_coordinate) == 1: # only one icon
tap_coordinate = [(in_coordinate[0][0] + in_coordinate[0][2]) / 2,
(in_coordinate[0][1] + in_coordinate[0][3]) / 2]
tap_coordinate = [round(tap_coordinate[0] / iw, 2), round(tap_coordinate[1] / ih, 2)]
return self.system_tap(tap_coordinate[0] * x, tap_coordinate[1] * y)
else:
temp_file = Path(f"{DEFAULT_WORKSPACE_ROOT}/temp")
temp_file.mkdir(parents=True, exist_ok=True)
hash_table, clip_filter = [], []
for i, (td, box) in enumerate(zip(in_coordinate, out_coordinate)):
if crop_for_clip(image, td, i, temp_file):
hash_table.append(td)
crop_image = f"{i}.png"
clip_filter.append(temp_file.joinpath(crop_image))
clip_model, clip_preprocess = clip.load("ViT-B/32", device=device)
clip_filter = clip_for_icon(clip_model, clip_preprocess, clip_filter, icon_shape_color)
final_box = hash_table[clip_filter]
tap_coordinate = [(final_box[0] + final_box[2]) / 2, (final_box[1] + final_box[3]) / 2]
tap_coordinate = [round(tap_coordinate[0] / iw, 2), round(tap_coordinate[1] / ih, 2)]
print(tap_coordinate[0] * x, tap_coordinate[1] * y)
return self.system_tap(tap_coordinate[0] * x, tap_coordinate[1] * y)

View file

@ -0,0 +1,43 @@
batch_size = 1
modelname = "groundingdino"
backbone = "swin_T_224_1k"
position_embedding = "sine"
pe_temperatureH = 20
pe_temperatureW = 20
return_interm_indices = [1, 2, 3]
backbone_freeze_keywords = None
enc_layers = 6
dec_layers = 6
pre_norm = False
dim_feedforward = 2048
hidden_dim = 256
dropout = 0.0
nheads = 8
num_queries = 900
query_dim = 4
num_patterns = 0
num_feature_levels = 4
enc_n_points = 4
dec_n_points = 4
two_stage_type = "standard"
two_stage_bbox_embed_share = False
two_stage_class_embed_share = False
transformer_activation = "relu"
dec_pred_bbox_embed_share = True
dn_box_noise_scale = 1.0
dn_label_noise_ratio = 0.5
dn_label_coef = 1.0
dn_bbox_coef = 1.0
embed_init_tgt = True
dn_labelbook_size = 2000
max_text_len = 256
text_encoder_type = "bert-base-uncased"
use_text_enhancer = True
use_fusion_layer = True
use_checkpoint = True
use_transformer_ckpt = True
use_text_cross_attention = True
text_dropout = 0.0
fusion_dropout = 0.0
fusion_droppath = 0.1
sub_sentence_present = True

View file

@ -0,0 +1,363 @@
# The code in this file was modified by MobileAgent
# https://github.com/X-PLUG/MobileAgent.git
import math
import clip
import cv2
import numpy as np
import torch
import subprocess
import time
from pathlib import Path
import groundingdino.datasets.transforms as T
from groundingdino.models import build_model
from groundingdino.util.slconfig import SLConfig
from groundingdino.util.utils import clean_state_dict, get_phrases_from_posmap
from PIL import Image, ImageDraw
################################## text_localization using ocr #######################
def crop_image(img: any, position: any) -> any:
def distance(x1, y1, x2, y2):
return math.sqrt(pow(x1 - x2, 2) + pow(y1 - y2, 2))
position = position.tolist()
for i in range(4):
for j in range(i + 1, 4):
if position[i][0] > position[j][0]:
tmp = position[j]
position[j] = position[i]
position[i] = tmp
if position[0][1] > position[1][1]:
tmp = position[0]
position[0] = position[1]
position[1] = tmp
if position[2][1] > position[3][1]:
tmp = position[2]
position[2] = position[3]
position[3] = tmp
x1, y1 = position[0][0], position[0][1]
x2, y2 = position[2][0], position[2][1]
x3, y3 = position[3][0], position[3][1]
x4, y4 = position[1][0], position[1][1]
corners = np.zeros((4, 2), np.float32)
corners[0] = [x1, y1]
corners[1] = [x2, y2]
corners[2] = [x4, y4]
corners[3] = [x3, y3]
img_width = distance((x1 + x4) / 2, (y1 + y4) / 2, (x2 + x3) / 2, (y2 + y3) / 2)
img_height = distance((x1 + x2) / 2, (y1 + y2) / 2, (x4 + x3) / 2, (y4 + y3) / 2)
corners_trans = np.zeros((4, 2), np.float32)
corners_trans[0] = [0, 0]
corners_trans[1] = [img_width - 1, 0]
corners_trans[2] = [0, img_height - 1]
corners_trans[3] = [img_width - 1, img_height - 1]
transform = cv2.getPerspectiveTransform(corners, corners_trans)
dst = cv2.warpPerspective(img, transform, (int(img_width), int(img_height)))
return dst
def calculate_size(box: any) -> any:
return (box[2] - box[0]) * (box[3] - box[1])
def order_point(cooperation: any) -> any:
arr = np.array(cooperation).reshape([4, 2])
sum_ = np.sum(arr, 0)
centroid = sum_ / arr.shape[0]
theta = np.arctan2(arr[:, 1] - centroid[1], arr[:, 0] - centroid[0])
sort_points = arr[np.argsort(theta)]
sort_points = sort_points.reshape([4, -1])
if sort_points[0][0] > centroid[0]:
sort_points = np.concatenate([sort_points[3:], sort_points[:3]])
sort_points = sort_points.reshape([4, 2]).astype("float32")
return sort_points
def longest_common_substring_length(str1: str, str2: str) -> int:
m = len(str1)
n = len(str2)
dp = [[0] * (n + 1) for _ in range(m + 1)]
for i in range(1, m + 1):
for j in range(1, n + 1):
if str1[i - 1] == str2[j - 1]:
dp[i][j] = dp[i - 1][j - 1] + 1
else:
dp[i][j] = max(dp[i - 1][j], dp[i][j - 1])
return dp[m][n]
def ocr(image_path: Path, prompt: str, ocr_detection: any, ocr_recognition: any, x: int, y: int) -> any:
text_data = []
coordinate = []
image = Image.open(image_path)
iw, ih = image.size
image_full = cv2.imread(str(image_path))
det_result = ocr_detection(image_full)
det_result = det_result["polygons"]
for i in range(det_result.shape[0]):
pts = order_point(det_result[i])
image_crop = crop_image(image_full, pts)
result = ocr_recognition(image_crop)["text"][0]
if result == prompt:
box = [int(e) for e in list(pts.reshape(-1))]
box = [box[0], box[1], box[4], box[5]]
if calculate_size(box) > 0.05 * iw * ih:
continue
text_data.append(
[
int(max(0, box[0] - 10) * x / iw),
int(max(0, box[1] - 10) * y / ih),
int(min(box[2] + 10, iw) * x / iw),
int(min(box[3] + 10, ih) * y / ih),
]
)
coordinate.append(
[
int(max(0, box[0] - 300) * x / iw),
int(max(0, box[1] - 400) * y / ih),
int(min(box[2] + 300, iw) * x / iw),
int(min(box[3] + 400, ih) * y / ih),
]
)
max_length = 0
if len(text_data) == 0:
for i in range(det_result.shape[0]):
pts = order_point(det_result[i])
image_crop = crop_image(image_full, pts)
result = ocr_recognition(image_crop)["text"][0]
if len(result) < 0.3 * len(prompt):
continue
if result in prompt:
now_length = len(result)
else:
now_length = longest_common_substring_length(result, prompt)
if now_length > max_length:
max_length = now_length
box = [int(e) for e in list(pts.reshape(-1))]
box = [box[0], box[1], box[4], box[5]]
text_data = [
[
int(max(0, box[0] - 10) * x / iw),
int(max(0, box[1] - 10) * y / ih),
int(min(box[2] + 10, iw) * x / iw),
int(min(box[3] + 10, ih) * y / ih),
]
]
coordinate = [
[
int(max(0, box[0] - 300) * x / iw),
int(max(0, box[1] - 400) * y / ih),
int(min(box[2] + 300, iw) * x / iw),
int(min(box[3] + 400, ih) * y / ih),
]
]
if len(prompt) <= 10:
if max_length >= 0.8 * len(prompt):
return text_data, coordinate
else:
return [], []
elif (len(prompt) > 10) and (len(prompt) <= 20):
if max_length >= 0.5 * len(prompt):
return text_data, coordinate
else:
return [], []
else:
if max_length >= 0.4 * len(prompt):
return text_data, coordinate
else:
return [], []
else:
return text_data, coordinate
################################## icon_localization using clip #######################
def calculate_iou(box1: list, box2: list) -> float:
x_a = max(box1[0], box2[0])
y_a = max(box1[1], box2[1])
x_b = min(box1[2], box2[2])
y_b = min(box1[3], box2[3])
inter_area = max(0, x_b - x_a) * max(0, y_b - y_a)
box1_area = (box1[2] - box1[0]) * (box1[3] - box1[1])
box2_area = (box2[2] - box2[0]) * (box2[3] - box2[1])
union_area = box1_area + box2_area - inter_area
iou = inter_area / union_area
return iou
def in_box(box: list, target: list) -> bool:
if (box[0] > target[0]) and (box[1] > target[1]) and (box[2] < target[2]) and (box[3] < target[3]):
return True
else:
return False
def crop_for_clip(image: any, box: any, i: int, temp_file: Path) -> bool:
image = Image.open(image)
w, h = image.size
bound = [0, 0, w, h]
if in_box(box, bound):
cropped_image = image.crop(box)
cropped_image.save(temp_file.joinpath(f"{i}.png"))
return True
else:
return False
def clip_for_icon(clip_model: any, clip_preprocess: any, images: any, prompt: str) -> any:
image_features = []
for image_file in images:
image = clip_preprocess(Image.open(image_file)).unsqueeze(0).to(next(clip_model.parameters()).device)
image_feature = clip_model.encode_image(image)
image_features.append(image_feature)
image_features = torch.cat(image_features)
text = clip.tokenize([prompt]).to(next(clip_model.parameters()).device)
text_features = clip_model.encode_text(text)
image_features /= image_features.norm(dim=-1, keepdim=True)
text_features /= text_features.norm(dim=-1, keepdim=True)
similarity = (100.0 * image_features @ text_features.T).softmax(dim=0).squeeze(0)
_, max_pos = torch.max(similarity, dim=0)
pos = max_pos.item()
return pos
def transform_image(image_pil: any) -> any:
transform = T.Compose(
[
T.RandomResize([800], max_size=1333),
T.ToTensor(),
T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
]
)
image, _ = transform(image_pil, None) # 3, h, w
return image
def load_model(model_checkpoint_path: Path, device: str) -> any:
model_config_path = "grounding_dino_config.py"
args = SLConfig.fromfile(model_config_path)
args.device = device
model = build_model(args)
checkpoint = torch.load(model_checkpoint_path, map_location="cpu")
load_res = model.load_state_dict(clean_state_dict(checkpoint["model"]), strict=False)
print(load_res)
_ = model.eval()
return model
def get_grounding_output(model: any, image: any, caption: str, box_threshold: any, text_threshold: any, with_logits: bool = True) -> any:
caption = caption.lower()
caption = caption.strip()
if not caption.endswith("."):
caption = caption + "."
with torch.no_grad():
outputs = model(image[None], captions=[caption])
logits = outputs["pred_logits"].cpu().sigmoid()[0] # (nq, 256)
boxes = outputs["pred_boxes"].cpu()[0] # (nq, 4)
logits.shape[0]
logits_filt = logits.clone()
boxes_filt = boxes.clone()
filt_mask = logits_filt.max(dim=1)[0] > box_threshold
logits_filt = logits_filt[filt_mask] # num_filt, 256
boxes_filt = boxes_filt[filt_mask] # num_filt, 4
logits_filt.shape[0]
tokenlizer = model.tokenizer
tokenized = tokenlizer(caption)
pred_phrases = []
scores = []
for logit, box in zip(logits_filt, boxes_filt):
pred_phrase = get_phrases_from_posmap(logit > text_threshold, tokenized, tokenlizer)
if with_logits:
pred_phrases.append(pred_phrase + f"({str(logit.max().item())[:4]})")
else:
pred_phrases.append(pred_phrase)
scores.append(logit.max().item())
return boxes_filt, torch.Tensor(scores), pred_phrases
def remove_boxes(boxes_filt: any, size: any, iou_threshold: float = 0.5) -> any:
boxes_to_remove = set()
for i in range(len(boxes_filt)):
if calculate_size(boxes_filt[i]) > 0.05 * size[0] * size[1]:
boxes_to_remove.add(i)
for j in range(len(boxes_filt)):
if calculate_size(boxes_filt[j]) > 0.05 * size[0] * size[1]:
boxes_to_remove.add(j)
if i == j:
continue
if i in boxes_to_remove or j in boxes_to_remove:
continue
iou = calculate_iou(boxes_filt[i], boxes_filt[j])
if iou >= iou_threshold:
boxes_to_remove.add(j)
boxes_filt = [box for idx, box in enumerate(boxes_filt) if idx not in boxes_to_remove]
return boxes_filt
def det(input_image: any, text_prompt: str, groundingdino_model: any, box_threshold:float = 0.05, text_threshold:float = 0.5) -> any:
image = Image.open(input_image)
size = image.size
image_pil = image.convert("RGB")
image = np.array(image_pil)
transformed_image = transform_image(image_pil)
boxes_filt, scores, pred_phrases = get_grounding_output(
groundingdino_model, transformed_image, text_prompt, box_threshold, text_threshold
)
H, W = size[1], size[0]
for i in range(boxes_filt.size(0)):
boxes_filt[i] = boxes_filt[i] * torch.Tensor([W, H, W, H])
boxes_filt[i][:2] -= boxes_filt[i][2:] / 2
boxes_filt[i][2:] += boxes_filt[i][:2]
boxes_filt = boxes_filt.cpu().int().tolist()
filtered_boxes = remove_boxes(boxes_filt, size) # [:9]
coordinate = []
image_data = []
for box in filtered_boxes:
image_data.append(
[max(0, box[0] - 10), max(0, box[1] - 10), min(box[2] + 10, size[0]), min(box[3] + 10, size[1])]
)
coordinate.append(
[max(0, box[0] - 25), max(0, box[1] - 25), min(box[2] + 25, size[0]), min(box[3] + 25, size[1])]
)
return image_data, coordinate

View file

@ -219,7 +219,7 @@ class OutputParser:
if start_index != -1 and end_index != -1:
# Extract the structure part
structure_text = text[start_index : end_index + 1]
structure_text = text[start_index: end_index + 1]
try:
# Attempt to convert the text to a Python data type using ast.literal_eval
@ -841,3 +841,21 @@ def get_markdown_codeblock_type(filename: str) -> str:
"application/sql": "sql",
}
return mappings.get(mime_type, "text")
def download_model(file_url: str, target_folder: Path) -> Path:
file_name = file_url.split('/')[-1]
file_path = target_folder.joinpath(f"{file_name}")
if not file_path.exists():
file_path.mkdir(parents=True, exist_ok=True)
try:
response = requests.get(file_url, stream=True)
response.raise_for_status() # 检查请求是否成功
# 保存文件
with open(file_path, 'wb') as f:
for chunk in response.iter_content(chunk_size=8192):
f.write(chunk)
logger.info(f'权重文件已下载并保存至 {file_path}')
except requests.exceptions.HTTPError as err:
logger.info(f'权重文件下载过程中发生错误: {err}')
return file_path

View file

@ -45,7 +45,30 @@ extras_require = {
"llama-index-postprocessor-flag-embedding-reranker==0.1.2",
"docx2txt==0.8",
],
"android_assistant": ["pyshine==0.0.9", "opencv-python==4.6.0.66"],
"android_assistant": [
"pyshine==0.0.9",
"opencv-python==4.6.0.66",
"protobuf<3.20,>=3.9.2",
"modelscope",
"tensorflow==2.9.1; os_name == 'linux'",
"tensorflow==2.9.1; os_name == 'win32'",
"tensorflow-macos==2.9; os_name == 'darwin'",
"keras==2.9.0",
"torch",
"torchvision",
"transformers",
"opencv-python",
"matplotlib",
"pycocotools",
"SentencePiece",
"tf_slim",
"tf_keras",
"pyclipper",
"shapely",
"groundingdino-py",
"datasets==2.18.0",
"clip-openai"
],
}
extras_require["test"] = [
@ -96,4 +119,5 @@ setup(
],
},
include_package_data=True,
)