Merge remote-tracking branch 'origin/main'

This commit is contained in:
kit 2024-04-26 11:52:54 +08:00
parent 3da74ec00d
commit 84a8c0d0bd
2 changed files with 33 additions and 30 deletions

View file

@ -34,10 +34,16 @@ class AndroidExtEnv(ExtEnv):
xml_dir: Optional[Path] = Field(default=None)
width: int = Field(default=720, description="device screen width")
height: int = Field(default=1080, description="device screen height")
cv_model_status: dict = Field(default=None, description="Record model loading status")
def __init__(self, **data: Any):
super().__init__(**data)
device_id = data.get("device_id")
self.cv_model_status = {
'ocr_detection_loaded': False,
'ocr_recognition_loaded': False,
'clip_model_loaded': False
}
if device_id:
devices = self.list_devices()
if device_id not in devices:
@ -45,8 +51,8 @@ class AndroidExtEnv(ExtEnv):
(width, height) = self.device_shape
self.width = data.get("width", width)
self.height = data.get("height", height)
self.create_device_path(self.screenshot_dir)
self.create_device_path(self.xml_dir)
#self.create_device_path(self.screenshot_dir)
#self.create_device_path(self.xml_dir)
def reset(
self,
@ -167,7 +173,16 @@ class AndroidExtEnv(ExtEnv):
if pull_res != ADB_EXEC_FAIL:
res = ss_local_path
else:
res = get_screenshot_only(local_save_dir)
ss_cmd = f"{self.adb_prefix_shell} rm /sdcard/screenshot.png"
ss_res = self.execute_adb_with_cmd(ss_cmd)
time.sleep(0.1)
ss_cmd = f"{self.adb_prefix_shell} screencap -p /sdcard/screenshot.png"
ss_res = self.execute_adb_with_cmd(ss_cmd)
time.sleep(0.1)
ss_cmd = f"{self.adb_prefix} pull /sdcard/screenshot.png {self.screenshot_dir}"
ss_res = self.execute_adb_with_cmd(ss_cmd)
image_path = Path(f"{self.screenshot_dir}/screenshot.png")
res = image_path
return Path(res)
@mark_as_readable
@ -246,12 +261,17 @@ class AndroidExtEnv(ExtEnv):
exit_res = self.execute_adb_with_cmd(adb_cmd)
return exit_res
def _ocr_text(self, text: str) -> list:
if not self.screenshot_dir.exists():
self.screenshot_dir.mkdir(parents=True, exist_ok=True)
image = self.get_screenshot("screenshot", self.screenshot_dir)
ocr_detection = pipeline(Tasks.ocr_detection, model="damo/cv_resnet18_ocr-detection-line-level_damo")
ocr_recognition = pipeline(Tasks.ocr_recognition, model="damo/cv_convnextTiny_ocr-recognition-document_damo")
if self.cv_model_status['ocr_detection_loaded'] == False:
ocr_detection = pipeline(Tasks.ocr_detection, model="damo/cv_resnet18_ocr-detection-line-level_damo")
self.cv_model_status['ocr_detection_loaded'] = True
if self.cv_model_status['ocr_recognition_loaded'] == False:
ocr_recognition = pipeline(Tasks.ocr_recognition, model="damo/cv_convnextTiny_ocr-recognition-document_damo")
self.cv_model_status['ocr_recognition_loaded'] == True
iw, ih = Image.open(image).size
x, y = self.device_shape
if iw > ih:
@ -312,7 +332,9 @@ class AndroidExtEnv(ExtEnv):
file_url = 'https://huggingface.co/ShilongLiu/GroundingDINO/blob/main/groundingdino_swint_ogc.pth' # 加载远程model
target_folder = Path(f'{DEFAULT_WORKSPACE_ROOT}/weights')
file_path = download_model(file_url, target_folder)
groundingdino_model = load_model(file_path, device=device).eval()
if self.cv_model_status['clip_model_loaded'] == False:
groundingdino_model = load_model(file_path, device=device).eval()
self.cv_model_status['clip_model_loaded'] = True
in_coordinate, out_coordinate = det(image, "icon", groundingdino_model) # 检测icon
if len(out_coordinate) == 1: # only one icon
tap_coordinate = [(in_coordinate[0][0] + in_coordinate[0][2]) / 2,
@ -328,7 +350,7 @@ class AndroidExtEnv(ExtEnv):
for i, (td, box) in enumerate(zip(in_coordinate, out_coordinate)):
if crop_for_clip(image, td, i, temp_file):
hash_table.append(td)
crop_image = f"{i}.jpg"
crop_image = f"{i}.png"
clip_filter.append(temp_file.joinpath(crop_image))
clip_model, clip_preprocess = clip.load("ViT-B/32", device=device)
clip_filter = clip_for_icon(clip_model, clip_preprocess, clip_filter, icon_shape_color)

View file

@ -221,7 +221,7 @@ def crop_for_clip(image: any, box: any, i: int, temp_file: Path) -> bool:
bound = [0, 0, w, h]
if in_box(box, bound):
cropped_image = image.crop(box)
cropped_image.save(temp_file.joinpath(f"{i}.jpg"))
cropped_image.save(temp_file.joinpath(f"{i}.png"))
return True
else:
return False
@ -271,7 +271,7 @@ def load_model(model_checkpoint_path: Path, device: str) -> any:
return model
def get_grounding_output(model: any, image: any, caption: str, box_threshold: any, text_threshold: any, with_logits=True) -> any:
def get_grounding_output(model: any, image: any, caption: str, box_threshold: any, text_threshold: any, with_logits: bool = True) -> any:
caption = caption.lower()
caption = caption.strip()
if not caption.endswith("."):
@ -306,7 +306,7 @@ def get_grounding_output(model: any, image: any, caption: str, box_threshold: an
return boxes_filt, torch.Tensor(scores), pred_phrases
def remove_boxes(boxes_filt: any, size: any, iou_threshold=0.5) -> any:
def remove_boxes(boxes_filt: any, size: any, iou_threshold: float = 0.5) -> any:
boxes_to_remove = set()
for i in range(len(boxes_filt)):
@ -328,7 +328,7 @@ def remove_boxes(boxes_filt: any, size: any, iou_threshold=0.5) -> any:
return boxes_filt
def det(input_image: any, text_prompt: str, groundingdino_model: any, box_threshold=0.05, text_threshold=0.5) -> any:
def det(input_image: any, text_prompt: str, groundingdino_model: any, box_threshold:float = 0.05, text_threshold:float = 0.5) -> any:
image = Image.open(input_image)
size = image.size
@ -361,22 +361,3 @@ def det(input_image: any, text_prompt: str, groundingdino_model: any, box_thresh
return image_data, coordinate
def get_screenshot_only(screenshot_dir: Path) -> Path:
command = " adb shell rm /sdcard/screenshot.png"
subprocess.run(command, capture_output=True, text=True, shell=True)
time.sleep(0.1)
command = "adb shell screencap -p /sdcard/screenshot.png"
subprocess.run(command, capture_output=True, text=True, shell=True)
time.sleep(0.1)
command = f"adb pull /sdcard/screenshot.png {screenshot_dir}"
subprocess.run(command, capture_output=True, text=True, shell=True)
image_path = Path(f"{screenshot_dir}/screenshot.png")
save_path = Path(f"{screenshot_dir}/screenshot.jpg")
image = Image.open(image_path)
original_width, original_height = image.size
new_width = int(original_width * 0.5)
new_height = int(original_height * 0.5)
resized_image = image.resize((new_width, new_height))
resized_image.convert("RGB").save(save_path, "JPEG")
time.sleep(0.1)
return save_path