mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-05-24 14:15:17 +02:00
Merge pull request #1207 from kithib/main
Add new Android operations (open app, exit, stop)
This commit is contained in:
commit
d9ed99e85f
5 changed files with 580 additions and 10 deletions
|
|
@ -1,13 +1,19 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc : The Android external environment to integrate with Android apps
|
||||
|
||||
import subprocess
|
||||
import clip
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional
|
||||
|
||||
from modelscope.pipelines import pipeline
|
||||
from modelscope.utils.constant import Tasks
|
||||
|
||||
from PIL import Image
|
||||
from pydantic import Field
|
||||
|
||||
from metagpt.environment.android.text_icon_localization import *
|
||||
from metagpt.environment.android.const import ADB_EXEC_FAIL
|
||||
from metagpt.environment.android.env_space import (
|
||||
EnvAction,
|
||||
|
|
@ -17,6 +23,20 @@ from metagpt.environment.android.env_space import (
|
|||
EnvObsValType,
|
||||
)
|
||||
from metagpt.environment.base_env import ExtEnv, mark_as_readable, mark_as_writeable
|
||||
from metagpt.logs import logger
|
||||
from metagpt.utils.common import download_model
|
||||
from metagpt.const import DEFAULT_WORKSPACE_ROOT
|
||||
|
||||
|
||||
def load_cv_model(device: str = "cpu") -> tuple[Any, Any, Any]:
    """Build the OCR pipelines and the GroundingDINO detector used on screenshots.

    Args:
        device: Device string forwarded to ``load_model`` for the GroundingDINO
            model. Defaults to ``"cpu"``.

    Returns:
        The tuple ``(ocr_detection, ocr_recognition, groundingdino_model)``.
    """
    ocr_detection = pipeline(Tasks.ocr_detection, model="damo/cv_resnet18_ocr-detection-line-level_damo")
    ocr_recognition = pipeline(
        Tasks.ocr_recognition,
        model="damo/cv_convnextTiny_ocr-recognition-document_damo",
    )
    # Use the "resolve" endpoint: the "blob" URL of a Hugging Face repo serves the
    # HTML file-viewer page, so the saved file would not be a loadable checkpoint.
    file_url = "https://huggingface.co/ShilongLiu/GroundingDINO/resolve/main/groundingdino_swint_ogc.pth"
    target_folder = Path(f"{DEFAULT_WORKSPACE_ROOT}/weights")
    file_path = download_model(file_url, target_folder)
    groundingdino_model = load_model(file_path, device=device).eval()
    return ocr_detection, ocr_recognition, groundingdino_model
|
||||
|
||||
|
||||
class AndroidExtEnv(ExtEnv):
|
||||
|
|
@ -25,10 +45,14 @@ class AndroidExtEnv(ExtEnv):
|
|||
xml_dir: Optional[Path] = Field(default=None)
|
||||
width: int = Field(default=720, description="device screen width")
|
||||
height: int = Field(default=1080, description="device screen height")
|
||||
ocr_detection: any = Field(default=None, description="ocr detection model")
|
||||
ocr_recognition: any = Field(default=None, description="ocr recognition model")
|
||||
groundingdino_model: any = Field(default=None, description="clip groundingdino model")
|
||||
|
||||
def __init__(self, **data: Any):
|
||||
super().__init__(**data)
|
||||
device_id = data.get("device_id")
|
||||
self.ocr_detection, self.ocr_recognition, self.groundingdino_model = load_cv_model()
|
||||
if device_id:
|
||||
devices = self.list_devices()
|
||||
if device_id not in devices:
|
||||
|
|
@ -36,15 +60,14 @@ class AndroidExtEnv(ExtEnv):
|
|||
(width, height) = self.device_shape
|
||||
self.width = data.get("width", width)
|
||||
self.height = data.get("height", height)
|
||||
|
||||
self.create_device_path(self.screenshot_dir)
|
||||
self.create_device_path(self.xml_dir)
|
||||
|
||||
def reset(
|
||||
self,
|
||||
*,
|
||||
seed: Optional[int] = None,
|
||||
options: Optional[dict[str, Any]] = None,
|
||||
self,
|
||||
*,
|
||||
seed: Optional[int] = None,
|
||||
options: Optional[dict[str, Any]] = None,
|
||||
) -> tuple[dict[str, Any], dict[str, Any]]:
|
||||
super().reset(seed=seed, options=options)
|
||||
|
||||
|
|
@ -149,14 +172,26 @@ class AndroidExtEnv(ExtEnv):
|
|||
ss_remote_path = Path(self.screenshot_dir).joinpath(f"{ss_name}.png")
|
||||
ss_cmd = f"{self.adb_prefix_shell} screencap -p {ss_remote_path}"
|
||||
ss_res = self.execute_adb_with_cmd(ss_cmd)
|
||||
|
||||
time.sleep(0.1)
|
||||
res = ADB_EXEC_FAIL
|
||||
if ss_res != ADB_EXEC_FAIL:
|
||||
ss_local_path = Path(local_save_dir).joinpath(f"{ss_name}.png")
|
||||
pull_cmd = f"{self.adb_prefix} pull {ss_remote_path} {ss_local_path}"
|
||||
pull_res = self.execute_adb_with_cmd(pull_cmd)
|
||||
time.sleep(0.1)
|
||||
if pull_res != ADB_EXEC_FAIL:
|
||||
res = ss_local_path
|
||||
else:
|
||||
ss_cmd = f"{self.adb_prefix_shell} rm /sdcard/{ss_name}.png"
|
||||
ss_res = self.execute_adb_with_cmd(ss_cmd)
|
||||
time.sleep(0.1)
|
||||
ss_cmd = f"{self.adb_prefix_shell} screencap -p /sdcard/{ss_name}.png"
|
||||
ss_res = self.execute_adb_with_cmd(ss_cmd)
|
||||
time.sleep(0.1)
|
||||
ss_cmd = f"{self.adb_prefix} pull /sdcard/{ss_name}.png {self.screenshot_dir}"
|
||||
ss_res = self.execute_adb_with_cmd(ss_cmd)
|
||||
image_path = Path(f"{self.screenshot_dir}/{ss_name}.png")
|
||||
res = image_path
|
||||
return Path(res)
|
||||
|
||||
@mark_as_readable
|
||||
|
|
@ -224,7 +259,94 @@ class AndroidExtEnv(ExtEnv):
|
|||
return swipe_res
|
||||
|
||||
@mark_as_writeable
|
||||
def user_swipe_to(self, start: tuple[int, int], end: tuple[int, int], duration: int = 400):
|
||||
def user_swipe_to(self, start: tuple[int, int], end: tuple[int, int], duration: int = 400) -> str:
    """Issue an adb ``swipe`` from ``start`` to ``end``.

    Args:
        start: ``(x, y)`` of the swipe origin.
        end: ``(x, y)`` of the swipe destination.
        duration: Value appended as the last argument of the swipe command.

    Returns:
        Whatever ``execute_adb_with_cmd`` returns for the command.
    """
    from_x, from_y = start[0], start[1]
    to_x, to_y = end[0], end[1]
    return self.execute_adb_with_cmd(
        f"{self.adb_prefix_si} swipe {from_x} {from_y} {to_x} {to_y} {duration}"
    )
|
||||
|
||||
@mark_as_writeable
|
||||
def user_exit(self) -> str:
    """Send the MAIN/HOME intent through adb shell.

    Returns:
        Whatever ``execute_adb_with_cmd`` returns for the command.
    """
    home_intent = "am start -a android.intent.action.MAIN -c android.intent.category.HOME"
    return self.execute_adb_with_cmd(f"{self.adb_prefix_shell} {home_intent}")
|
||||
|
||||
def _ocr_text(self, text: str) -> list:
    """Take a screenshot and run OCR on it looking for ``text``.

    Args:
        text: The text passed to ``ocr`` as the prompt.

    Returns:
        A 7-item list ``[in_coordinate, out_coordinate, x, y, iw, ih, image]``:
        the two results of ``ocr``, the device shape, the screenshot size, and
        the screenshot path. When the screenshot is wider than tall, both the
        device shape and the image size are swapped.
    """
    image = self.get_screenshot("screenshot", self.screenshot_dir)
    # Only the size is needed here; close the file handle right away.
    with Image.open(image) as screenshot:
        iw, ih = screenshot.size
    x, y = self.device_shape
    if iw > ih:
        x, y = y, x
        iw, ih = ih, iw
    in_coordinate, out_coordinate = ocr(image, text, self.ocr_detection, self.ocr_recognition, iw, ih)
    return [in_coordinate, out_coordinate, x, y, iw, ih, image]
|
||||
|
||||
@mark_as_writeable
|
||||
def user_open_app(self, app_name: str) -> str:
    """Tap the app whose label is found on screen by OCR.

    Args:
        app_name: Label text to look for.

    Returns:
        ``"no app here"`` when OCR finds no match, otherwise the result of
        ``system_tap`` at the first match.
    """
    hits, _, dev_w, dev_h, img_w, img_h = self._ocr_text(app_name)[:6]
    if len(hits) == 0:
        logger.info(f"No App named {app_name}.")
        return "no app here"

    first = hits[0]
    # Centre of the first matched box, as a fraction of the image size (2 decimals).
    rel_x = round(((first[0] + first[2]) / 2) / img_w, 2)
    rel_y = round(((first[1] + first[3]) / 2) / img_h, 2)
    # The tap is moved up by 50 device units.
    # NOTE(review): presumably to hit the icon above its label - confirm.
    y_shift = round(50 / dev_h, 2)
    return self.system_tap(rel_x * dev_w, (rel_y - y_shift) * dev_h)
|
||||
|
||||
@mark_as_writeable
|
||||
def user_click_text(self, text: str) -> str:
    """Tap ``text`` on screen when OCR finds exactly one occurrence.

    Args:
        text: Text to look for.

    Returns:
        The result of ``system_tap`` when there is exactly one match. With zero
        or several matches a message is logged and nothing is returned.
    """
    hits, contexts, dev_w, dev_h, img_w, img_h, _ = self._ocr_text(text)[:7]
    match_count = len(contexts)

    if match_count == 1:
        first = hits[0]
        # Centre of the matched box, as a fraction of the image size (2 decimals).
        rel_x = round(((first[0] + first[2]) / 2) / img_w, 2)
        rel_y = round(((first[1] + first[3]) / 2) / img_h, 2)
        return self.system_tap(rel_x * dev_w, rel_y * dev_h)

    if match_count == 0:
        logger.info(
            f"Failed to execute action click text ({text}). The text \"{text}\" is not detected in the screenshot.")
    else:
        logger.info(
            f"Failed to execute action click text ({text}). There are too many text \"{text}\" in the screenshot.")
|
||||
|
||||
@mark_as_writeable
|
||||
def user_stop(self):
    """Log that the tasks finished successfully; no device command is issued."""
    logger.info("Successful execution of tasks")
|
||||
|
||||
@mark_as_writeable
|
||||
def user_click_icon(self, icon_shape_color: str) -> str:
    """Tap the on-screen icon that best matches a textual description.

    Icons are detected on a fresh screenshot with the GroundingDINO model. When
    exactly one icon is found it is tapped directly; otherwise every usable
    detection is cropped to a temp folder and CLIP picks the crop closest to
    ``icon_shape_color``.

    Args:
        icon_shape_color: Text prompt given to CLIP to choose among candidates.

    Returns:
        The result of ``system_tap``, or ``ADB_EXEC_FAIL`` when no candidate
        crop could be produced.
    """
    image = self.get_screenshot("screenshot", self.screenshot_dir)
    # Only the size is needed here; close the file handle right away.
    with Image.open(image) as screenshot:
        iw, ih = screenshot.size
    x, y = self.device_shape
    if iw > ih:
        x, y = y, x
        iw, ih = ih, iw

    in_coordinate, out_coordinate = det(image, "icon", self.groundingdino_model)  # detect icons
    if len(out_coordinate) == 1:  # only one icon
        final_box = in_coordinate[0]
    else:
        temp_file = Path(f"{DEFAULT_WORKSPACE_ROOT}/temp")
        temp_file.mkdir(parents=True, exist_ok=True)
        # hash_table[k] is the box whose crop is stored at clip_filter[k].
        hash_table, clip_filter = [], []
        for i, (td, _) in enumerate(zip(in_coordinate, out_coordinate)):
            if crop_for_clip(image, td, i, temp_file):
                hash_table.append(td)
                clip_filter.append(temp_file.joinpath(f"{i}.png"))

        if not hash_table:
            # Without this guard torch.cat() inside clip_for_icon fails on an empty list.
            logger.info(
                f"Failed to execute action click icon ({icon_shape_color}). No icon is detected in the screenshot.")
            return ADB_EXEC_FAIL

        # The original passed an undefined `device` name here (NameError).
        # clip.load picks its own default device, and clip_for_icon moves its
        # inputs to whatever device the model ends up on.
        clip_model, clip_preprocess = clip.load("ViT-B/32")
        best_index = clip_for_icon(clip_model, clip_preprocess, clip_filter, icon_shape_color)
        final_box = hash_table[best_index]

    # Centre of the chosen box, as a fraction of the image size (2 decimals).
    tap_coordinate = [(final_box[0] + final_box[2]) / 2, (final_box[1] + final_box[3]) / 2]
    tap_coordinate = [round(tap_coordinate[0] / iw, 2), round(tap_coordinate[1] / ih, 2)]
    return self.system_tap(tap_coordinate[0] * x, tap_coordinate[1] * y)
|
||||
|
|
|
|||
43
metagpt/environment/android/grounding_dino_config.py
Normal file
43
metagpt/environment/android/grounding_dino_config.py
Normal file
|
|
@ -0,0 +1,43 @@
|
|||
# GroundingDINO model settings.
# This module is read by file name ("grounding_dino_config.py") through
# SLConfig.fromfile() in text_icon_localization.load_model, and the resulting
# object is handed to groundingdino's build_model().
batch_size = 1
modelname = "groundingdino"
backbone = "swin_T_224_1k"
position_embedding = "sine"
pe_temperatureH = 20
pe_temperatureW = 20
return_interm_indices = [1, 2, 3]
backbone_freeze_keywords = None
enc_layers = 6
dec_layers = 6
pre_norm = False
dim_feedforward = 2048
hidden_dim = 256
dropout = 0.0
nheads = 8
num_queries = 900
query_dim = 4
num_patterns = 0
num_feature_levels = 4
enc_n_points = 4
dec_n_points = 4
two_stage_type = "standard"
two_stage_bbox_embed_share = False
two_stage_class_embed_share = False
transformer_activation = "relu"
dec_pred_bbox_embed_share = True
dn_box_noise_scale = 1.0
dn_label_noise_ratio = 0.5
dn_label_coef = 1.0
dn_bbox_coef = 1.0
embed_init_tgt = True
dn_labelbook_size = 2000
max_text_len = 256
text_encoder_type = "bert-base-uncased"
use_text_enhancer = True
use_fusion_layer = True
use_checkpoint = True
use_transformer_ckpt = True
use_text_cross_attention = True
text_dropout = 0.0
fusion_dropout = 0.0
fusion_droppath = 0.1
sub_sentence_present = True
|
||||
363
metagpt/environment/android/text_icon_localization.py
Normal file
363
metagpt/environment/android/text_icon_localization.py
Normal file
|
|
@ -0,0 +1,363 @@
|
|||
# The code in this file was modified by MobileAgent
|
||||
# https://github.com/X-PLUG/MobileAgent.git
|
||||
|
||||
import math
|
||||
import clip
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
import subprocess
|
||||
import time
|
||||
from pathlib import Path
|
||||
import groundingdino.datasets.transforms as T
|
||||
from groundingdino.models import build_model
|
||||
from groundingdino.util.slconfig import SLConfig
|
||||
from groundingdino.util.utils import clean_state_dict, get_phrases_from_posmap
|
||||
from PIL import Image, ImageDraw
|
||||
|
||||
|
||||
################################## text_localization using ocr #######################
|
||||
|
||||
def crop_image(img: any, position: any) -> any:
    """Cut a quadrilateral region out of ``img`` and warp it to an upright rectangle.

    Args:
        img: Image array passed to ``cv2.warpPerspective``.
        position: Array of four ``(x, y)`` corner points (must offer ``tolist()``).

    Returns:
        The warped crop produced by ``cv2.warpPerspective``.
    """

    def gap(ax, ay, bx, by):
        return math.sqrt(pow(ax - bx, 2) + pow(ay - by, 2))

    pts = position.tolist()

    # Exchange sort by x so pts[0:2] are the two left corners, pts[2:4] the right ones.
    for i in range(4):
        for j in range(i + 1, 4):
            if pts[i][0] > pts[j][0]:
                pts[i], pts[j] = pts[j], pts[i]
    # Within each side, put the corner with the smaller y first.
    if pts[0][1] > pts[1][1]:
        pts[0], pts[1] = pts[1], pts[0]
    if pts[2][1] > pts[3][1]:
        pts[2], pts[3] = pts[3], pts[2]

    x1, y1 = pts[0][:2]  # left side, smaller y
    x2, y2 = pts[2][:2]  # right side, smaller y
    x3, y3 = pts[3][:2]  # right side, larger y
    x4, y4 = pts[1][:2]  # left side, larger y

    corners = np.array([[x1, y1], [x2, y2], [x4, y4], [x3, y3]], dtype=np.float32)

    # Output size: distance between the midpoints of opposite edges.
    img_width = gap((x1 + x4) / 2, (y1 + y4) / 2, (x2 + x3) / 2, (y2 + y3) / 2)
    img_height = gap((x1 + x2) / 2, (y1 + y2) / 2, (x4 + x3) / 2, (y4 + y3) / 2)

    corners_trans = np.array(
        [
            [0, 0],
            [img_width - 1, 0],
            [0, img_height - 1],
            [img_width - 1, img_height - 1],
        ],
        dtype=np.float32,
    )

    transform = cv2.getPerspectiveTransform(corners, corners_trans)
    return cv2.warpPerspective(img, transform, (int(img_width), int(img_height)))
|
||||
|
||||
|
||||
def calculate_size(box: any) -> any:
    """Return the area of ``box`` given as ``[x0, y0, x1, y1]``."""
    width = box[2] - box[0]
    height = box[3] - box[1]
    return width * height
|
||||
|
||||
|
||||
def order_point(cooperation: any) -> any:
    """Order four corner points by their angle around the centroid.

    Args:
        cooperation: Eight values (or any array reshapeable to 4x2) holding
            four ``(x, y)`` points.

    Returns:
        A ``(4, 2)`` float32 array. Points are sorted by ``arctan2`` angle
        around the centroid; if the first sorted point lies right of the
        centroid, the order is rotated by one so the last point comes first.
    """
    pts = np.array(cooperation).reshape([4, 2])
    center = pts.sum(axis=0) / pts.shape[0]
    angles = np.arctan2(pts[:, 1] - center[1], pts[:, 0] - center[0])
    ordered = pts[np.argsort(angles)].reshape([4, -1])
    if ordered[0][0] > center[0]:
        # Rotate by one position: last point becomes the first.
        ordered = ordered[[3, 0, 1, 2]]
    return ordered.reshape([4, 2]).astype("float32")
|
||||
|
||||
|
||||
def longest_common_substring_length(str1: str, str2: str) -> int:
    """Return the length of the longest common subsequence of two strings.

    Despite the name, matched characters do not have to be contiguous: on a
    mismatch the best length so far is carried over, so
    ``("abcde", "ace")`` yields 3.
    """
    # Only the previous DP row is needed to compute the next one.
    prev_row = [0] * (len(str2) + 1)
    for ch1 in str1:
        row = [0]
        for col, ch2 in enumerate(str2, start=1):
            if ch1 == ch2:
                row.append(prev_row[col - 1] + 1)
            else:
                row.append(max(prev_row[col], row[col - 1]))
        prev_row = row
    return prev_row[-1]
|
||||
|
||||
|
||||
def ocr(image_path: Path, prompt: str, ocr_detection: any, ocr_recognition: any, x: int, y: int) -> any:
    """Locate ``prompt`` on an image using OCR detection + recognition.

    Exact matches are collected first (regions larger than 5% of the image are
    ignored). If there is none, the single region whose text overlaps
    ``prompt`` the most is used, provided the overlap reaches a
    length-dependent share of ``prompt`` (0.8 up to 10 chars, 0.5 up to 20,
    0.4 beyond).

    Args:
        image_path: Path of the image to analyse.
        prompt: Text to look for.
        ocr_detection: Callable returning a mapping with a ``"polygons"`` array.
        ocr_recognition: Callable returning a mapping whose ``"text"`` entry is
            indexable; element 0 is used as the recognised string.
        x: Width of the coordinate space the returned boxes are scaled to.
        y: Height of the coordinate space the returned boxes are scaled to.

    Returns:
        ``(text_data, coordinate)``: two parallel lists of ``[x0, y0, x1, y1]``
        boxes, the first padded by 10 px, the second by 300 px horizontally and
        400 px vertically (padding is applied in image pixels before scaling).
        Both lists are empty when nothing qualifies.
    """
    # Only the size is needed from PIL; close the file handle right away.
    with Image.open(image_path) as image:
        iw, ih = image.size

    def corner_box(pts):
        # Keep the 1st and 3rd ordered corners as [x0, y0, x1, y1].
        flat = [int(e) for e in list(pts.reshape(-1))]
        return [flat[0], flat[1], flat[4], flat[5]]

    def padded(box, pad_x, pad_y):
        # Pad in image pixels (clamped to the image), then scale to (x, y).
        return [
            int(max(0, box[0] - pad_x) * x / iw),
            int(max(0, box[1] - pad_y) * y / ih),
            int(min(box[2] + pad_x, iw) * x / iw),
            int(min(box[3] + pad_y, ih) * y / ih),
        ]

    image_full = cv2.imread(str(image_path))
    det_result = ocr_detection(image_full)["polygons"]

    # Recognise every region once; both passes below reuse these results
    # instead of running the recognition model a second time.
    recognized = []
    for i in range(det_result.shape[0]):
        pts = order_point(det_result[i])
        result = ocr_recognition(crop_image(image_full, pts))["text"][0]
        recognized.append((pts, result))

    # Pass 1: exact matches.
    text_data, coordinate = [], []
    for pts, result in recognized:
        if result != prompt:
            continue
        box = corner_box(pts)
        if calculate_size(box) > 0.05 * iw * ih:
            continue
        text_data.append(padded(box, 10, 10))
        coordinate.append(padded(box, 300, 400))
    if len(text_data) != 0:
        return text_data, coordinate

    # Pass 2: best partial match.
    max_length = 0
    for pts, result in recognized:
        if len(result) < 0.3 * len(prompt):
            continue
        if result in prompt:
            now_length = len(result)
        else:
            now_length = longest_common_substring_length(result, prompt)
        if now_length > max_length:
            max_length = now_length
            box = corner_box(pts)
            text_data = [padded(box, 10, 10)]
            coordinate = [padded(box, 300, 400)]

    if len(prompt) <= 10:
        required_share = 0.8
    elif len(prompt) <= 20:
        required_share = 0.5
    else:
        required_share = 0.4
    if max_length >= required_share * len(prompt):
        return text_data, coordinate
    return [], []
|
||||
|
||||
|
||||
################################## icon_localization using clip #######################
|
||||
|
||||
|
||||
def calculate_iou(box1: list, box2: list) -> float:
    """Return the intersection-over-union of two ``[x0, y0, x1, y1]`` boxes.

    Returns ``0.0`` when the union has zero area (both boxes degenerate)
    instead of raising ``ZeroDivisionError``.
    """
    x_a = max(box1[0], box2[0])
    y_a = max(box1[1], box2[1])
    x_b = min(box1[2], box2[2])
    y_b = min(box1[3], box2[3])

    # Clamp at 0 so disjoint boxes give an empty intersection.
    inter_area = max(0, x_b - x_a) * max(0, y_b - y_a)
    box1_area = (box1[2] - box1[0]) * (box1[3] - box1[1])
    box2_area = (box2[2] - box2[0]) * (box2[3] - box2[1])
    union_area = box1_area + box2_area - inter_area
    if union_area == 0:
        return 0.0

    return inter_area / union_area
|
||||
|
||||
|
||||
def in_box(box: list, target: list) -> bool:
    """Return True when ``box`` lies strictly inside ``target``.

    Both are ``[x0, y0, x1, y1]``; a box touching any edge of ``target`` is
    not considered inside.
    """
    if not (box[0] > target[0]):
        return False
    if not (box[1] > target[1]):
        return False
    if not (box[2] < target[2]):
        return False
    if not (box[3] < target[3]):
        return False
    return True
|
||||
|
||||
|
||||
def crop_for_clip(image: any, box: any, i: int, temp_file: Path) -> bool:
    """Save the ``box`` region of an image as ``<temp_file>/<i>.png``.

    Args:
        image: Path (or file object) accepted by ``Image.open``.
        box: ``[x0, y0, x1, y1]`` region to crop.
        i: Index used as the output file name.
        temp_file: Directory the crop is written into.

    Returns:
        True when the crop was written; False when ``box`` is not strictly
        inside the image bounds (nothing is written in that case).
    """
    # The context manager closes the source file once the crop is saved.
    with Image.open(image) as source:
        w, h = source.size
        if not in_box(box, [0, 0, w, h]):
            return False
        source.crop(box).save(temp_file.joinpath(f"{i}.png"))
        return True
|
||||
|
||||
|
||||
def clip_for_icon(clip_model: any, clip_preprocess: any, images: any, prompt: str) -> any:
    """Return the index of the image that CLIP scores closest to ``prompt``.

    Args:
        clip_model: CLIP model offering ``encode_image`` / ``encode_text``.
        clip_preprocess: Preprocessing callable applied to each opened image.
        images: Iterable of image paths.
        prompt: Text the images are compared against.

    Returns:
        Position (int) in ``images`` of the highest-scoring image.
    """

    def on_model_device(tensor):
        # Inputs follow whatever device the model parameters live on.
        return tensor.to(next(clip_model.parameters()).device)

    per_image = [
        clip_model.encode_image(on_model_device(clip_preprocess(Image.open(image_file)).unsqueeze(0)))
        for image_file in images
    ]
    image_features = torch.cat(per_image)
    text_features = clip_model.encode_text(on_model_device(clip.tokenize([prompt])))

    # L2-normalise both sides before comparing them.
    image_features = image_features / image_features.norm(dim=-1, keepdim=True)
    text_features = text_features / text_features.norm(dim=-1, keepdim=True)

    # Softmax runs across the images (dim 0).
    similarity = (100.0 * image_features @ text_features.T).softmax(dim=0).squeeze(0)
    best_position = torch.max(similarity, dim=0)[1]
    return best_position.item()
|
||||
|
||||
|
||||
def transform_image(image_pil: any) -> any:
    """Apply the resize / to-tensor / normalise pipeline to a PIL image.

    Returns:
        The transformed image (first element of the transform's output).
    """
    steps = [
        T.RandomResize([800], max_size=1333),
        T.ToTensor(),
        T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]
    # The transforms take (image, target); there is no target here.
    tensor, _ = T.Compose(steps)(image_pil, None)  # 3, h, w
    return tensor
|
||||
|
||||
|
||||
def load_model(model_checkpoint_path: Path, device: str) -> any:
    """Build the GroundingDINO model and load its checkpoint.

    Args:
        model_checkpoint_path: Path of the checkpoint; its ``"model"`` entry is
            loaded (non-strictly) into the network.
        device: Stored on the config as ``args.device`` before the model is
            built. The checkpoint itself is always mapped to CPU.

    Returns:
        The model, switched to eval mode.
    """
    # Resolve the config next to this module so loading does not depend on the
    # process's current working directory.
    model_config_path = Path(__file__).resolve().parent / "grounding_dino_config.py"
    args = SLConfig.fromfile(str(model_config_path))
    args.device = device
    model = build_model(args)
    checkpoint = torch.load(model_checkpoint_path, map_location="cpu")
    load_res = model.load_state_dict(clean_state_dict(checkpoint["model"]), strict=False)
    print(load_res)
    _ = model.eval()
    return model
|
||||
|
||||
|
||||
def get_grounding_output(model: any, image: any, caption: str, box_threshold: any, text_threshold: any, with_logits: bool = True) -> any:
    """Run the grounding model on one image and keep confident predictions.

    Args:
        model: Model called as ``model(image[None], captions=[caption])``; must
            expose a ``tokenizer`` attribute.
        image: Image tensor (a batch axis is added here).
        caption: Text prompt; lower-cased, stripped and terminated with ".".
        box_threshold: A query is kept when its highest logit exceeds this.
        text_threshold: Per-token cut-off used to build the predicted phrase.
        with_logits: Append the (4-char truncated) top score to each phrase.

    Returns:
        ``(boxes, scores, phrases)`` for the kept queries: the box tensor, a
        tensor of each query's top score, and the list of phrases.
    """
    caption = caption.lower().strip()
    if not caption.endswith("."):
        caption = caption + "."

    with torch.no_grad():
        outputs = model(image[None], captions=[caption])
    logits = outputs["pred_logits"].cpu().sigmoid()[0]  # (nq, 256)
    boxes = outputs["pred_boxes"].cpu()[0]  # (nq, 4)

    # Boolean-mask indexing yields fresh tensors, so the model outputs stay untouched.
    keep = logits.max(dim=1)[0] > box_threshold
    logits_filt = logits[keep]  # num_filt, 256
    boxes_filt = boxes[keep]  # num_filt, 4

    tokenizer = model.tokenizer
    tokenized = tokenizer(caption)

    pred_phrases, scores = [], []
    for logit in logits_filt:
        phrase = get_phrases_from_posmap(logit > text_threshold, tokenized, tokenizer)
        top_score = logit.max().item()
        if with_logits:
            phrase = phrase + f"({str(top_score)[:4]})"
        pred_phrases.append(phrase)
        scores.append(top_score)

    return boxes_filt, torch.Tensor(scores), pred_phrases
|
||||
|
||||
|
||||
def remove_boxes(boxes_filt: any, size: any, iou_threshold: float = 0.5) -> any:
    """Drop oversized boxes and later boxes that overlap an earlier kept one.

    Args:
        boxes_filt: Sequence of ``[x0, y0, x1, y1]`` boxes.
        size: ``(width, height)`` of the image; boxes covering more than 5% of
            ``width * height`` are removed.
        iou_threshold: A box is removed when its IoU with an earlier surviving
            box is at least this value.

    Returns:
        List of the surviving boxes in their original order.
    """
    dropped = {
        idx for idx, box in enumerate(boxes_filt) if calculate_size(box) > 0.05 * size[0] * size[1]
    }

    total = len(boxes_filt)
    for i in range(total):
        if i in dropped:
            continue
        # IoU is symmetric, so pairs with j < i were already judged when j was the anchor.
        for j in range(i + 1, total):
            if j in dropped:
                continue
            if calculate_iou(boxes_filt[i], boxes_filt[j]) >= iou_threshold:
                dropped.add(j)

    return [box for idx, box in enumerate(boxes_filt) if idx not in dropped]
|
||||
|
||||
|
||||
def det(input_image: any, text_prompt: str, groundingdino_model: any, box_threshold: float = 0.05, text_threshold: float = 0.5) -> any:
    """Detect regions matching ``text_prompt`` on an image with GroundingDINO.

    Args:
        input_image: Path (or file object) accepted by ``Image.open``.
        text_prompt: Caption passed to the grounding model.
        groundingdino_model: Model forwarded to ``get_grounding_output``.
        box_threshold: Score threshold forwarded to ``get_grounding_output``.
        text_threshold: Token threshold forwarded to ``get_grounding_output``.

    Returns:
        ``(image_data, coordinate)``: two parallel lists of pixel boxes
        ``[x0, y0, x1, y1]`` clamped to the image, padded by 10 px and 25 px
        respectively.
    """
    # convert() returns an independent, loaded image, so the file can be closed here.
    with Image.open(input_image) as image:
        size = image.size
        image_pil = image.convert("RGB")

    transformed_image = transform_image(image_pil)
    boxes_filt, _scores, _pred_phrases = get_grounding_output(
        groundingdino_model, transformed_image, text_prompt, box_threshold, text_threshold
    )

    H, W = size[1], size[0]
    scale = torch.Tensor([W, H, W, H])
    for i in range(boxes_filt.size(0)):
        # Scale to pixels, then turn (first pair, second pair) from
        # centre/extent into top-left / bottom-right corners.
        boxes_filt[i] = boxes_filt[i] * scale
        boxes_filt[i][:2] -= boxes_filt[i][2:] / 2
        boxes_filt[i][2:] += boxes_filt[i][:2]

    boxes_filt = boxes_filt.cpu().int().tolist()
    filtered_boxes = remove_boxes(boxes_filt, size)  # [:9]

    coordinate = []
    image_data = []
    for box in filtered_boxes:
        image_data.append(
            [max(0, box[0] - 10), max(0, box[1] - 10), min(box[2] + 10, size[0]), min(box[3] + 10, size[1])]
        )
        coordinate.append(
            [max(0, box[0] - 25), max(0, box[1] - 25), min(box[2] + 25, size[0]), min(box[3] + 25, size[1])]
        )

    return image_data, coordinate
|
||||
|
||||
|
||||
|
|
@ -219,7 +219,7 @@ class OutputParser:
|
|||
|
||||
if start_index != -1 and end_index != -1:
|
||||
# Extract the structure part
|
||||
structure_text = text[start_index : end_index + 1]
|
||||
structure_text = text[start_index: end_index + 1]
|
||||
|
||||
try:
|
||||
# Attempt to convert the text to a Python data type using ast.literal_eval
|
||||
|
|
@ -841,3 +841,21 @@ def get_markdown_codeblock_type(filename: str) -> str:
|
|||
"application/sql": "sql",
|
||||
}
|
||||
return mappings.get(mime_type, "text")
|
||||
|
||||
|
||||
def download_model(file_url: str, target_folder: Path) -> Path:
    """Download ``file_url`` into ``target_folder`` unless it is already there.

    Args:
        file_url: URL of the file; the text after the last ``/`` is the file name.
        target_folder: Directory the file is stored in (created if missing).

    Returns:
        The path of the file inside ``target_folder``. If the server answers
        with an HTTP error, the error is logged and the returned path does not
        exist.
    """
    file_name = file_url.split("/")[-1]
    file_path = target_folder.joinpath(f"{file_name}")
    if not file_path.exists():
        # Create the containing folder. Calling mkdir() on file_path itself
        # would create a directory named like the file and break the write below.
        target_folder.mkdir(parents=True, exist_ok=True)
        # Download to a side file first so an interrupted transfer is never
        # mistaken for a complete file on the next call.
        partial_path = target_folder.joinpath(f"{file_name}.part")
        try:
            response = requests.get(file_url, stream=True, timeout=60)
            response.raise_for_status()  # check that the request succeeded
            # save the file
            with open(partial_path, "wb") as f:
                for chunk in response.iter_content(chunk_size=8192):
                    f.write(chunk)
            partial_path.replace(file_path)
            logger.info(f'权重文件已下载并保存至 {file_path}')
        except requests.exceptions.HTTPError as err:
            logger.info(f'权重文件下载过程中发生错误: {err}')
    return file_path
|
||||
|
|
|
|||
26
setup.py
26
setup.py
|
|
@ -45,7 +45,30 @@ extras_require = {
|
|||
"llama-index-postprocessor-flag-embedding-reranker==0.1.2",
|
||||
"docx2txt==0.8",
|
||||
],
|
||||
"android_assistant": ["pyshine==0.0.9", "opencv-python==4.6.0.66"],
|
||||
"android_assistant": [
|
||||
"pyshine==0.0.9",
|
||||
"opencv-python==4.6.0.66",
|
||||
"protobuf<3.20,>=3.9.2",
|
||||
"modelscope",
|
||||
"tensorflow==2.9.1; os_name == 'linux'",
|
||||
"tensorflow==2.9.1; os_name == 'win32'",
|
||||
"tensorflow-macos==2.9; os_name == 'darwin'",
|
||||
"keras==2.9.0",
|
||||
"torch",
|
||||
"torchvision",
|
||||
"transformers",
|
||||
"opencv-python",
|
||||
"matplotlib",
|
||||
"pycocotools",
|
||||
"SentencePiece",
|
||||
"tf_slim",
|
||||
"tf_keras",
|
||||
"pyclipper",
|
||||
"shapely",
|
||||
"groundingdino-py",
|
||||
"datasets==2.18.0",
|
||||
"clip-openai"
|
||||
],
|
||||
}
|
||||
|
||||
extras_require["test"] = [
|
||||
|
|
@ -96,4 +119,5 @@ setup(
|
|||
],
|
||||
},
|
||||
include_package_data=True,
|
||||
|
||||
)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue