Merge remote-tracking branch 'origin/main'

This commit is contained in:
kit 2024-04-29 15:19:42 +08:00
parent 17580333b8
commit cf9d86b832
4 changed files with 29 additions and 35 deletions

View file

@ -28,22 +28,31 @@ from metagpt.utils.common import download_model
from metagpt.const import DEFAULT_WORKSPACE_ROOT
def load_cv_model(device: str = "cpu") -> any:
ocr_detection = pipeline(Tasks.ocr_detection, model="damo/cv_resnet18_ocr-detection-line-level_damo")
ocr_recognition = pipeline(Tasks.ocr_recognition,
model="damo/cv_convnextTiny_ocr-recognition-document_damo")
file_url = "https://huggingface.co/ShilongLiu/GroundingDINO/blob/main/groundingdino_swint_ogc.pth"
target_folder = Path(f"{DEFAULT_WORKSPACE_ROOT}/weights")
file_path = download_model(file_url, target_folder)
groundingdino_model = load_model(file_path, device=device).eval()
return ocr_detection, ocr_recognition, groundingdino_model
class AndroidExtEnv(ExtEnv):
device_id: Optional[str] = Field(default=None)
screenshot_dir: Optional[Path] = Field(default=None)
xml_dir: Optional[Path] = Field(default=None)
width: int = Field(default=720, description="device screen width")
height: int = Field(default=1080, description="device screen height")
cv_model_status: dict = Field(default=None, description="Record model loading status")
ocr_detection: any = Field(default=None, description="ocr detection model")
ocr_recognition: any = Field(default=None, description="ocr recognition model")
groundingdino_model: any = Field(default=None, description="clip groundingdino model")
def __init__(self, **data: Any):
super().__init__(**data)
device_id = data.get("device_id")
self.cv_model_status = {
'ocr_detection_loaded': False,
'ocr_recognition_loaded': False,
'clip_model_loaded': False
}
self.ocr_detection, self.ocr_recognition, self.groundingdino_model = load_cv_model()
if device_id:
devices = self.list_devices()
if device_id not in devices:
@ -51,8 +60,8 @@ class AndroidExtEnv(ExtEnv):
(width, height) = self.device_shape
self.width = data.get("width", width)
self.height = data.get("height", height)
self.create_device_path(self.screenshot_dir)
self.create_device_path(self.xml_dir)
#self.create_device_path(self.screenshot_dir)
#self.create_device_path(self.xml_dir)
def reset(
self,
@ -173,16 +182,16 @@ class AndroidExtEnv(ExtEnv):
if pull_res != ADB_EXEC_FAIL:
res = ss_local_path
else:
ss_cmd = f"{self.adb_prefix_shell} rm /sdcard/screenshot.png"
ss_cmd = f"{self.adb_prefix_shell} rm /sdcard/{ss_name}.png"
ss_res = self.execute_adb_with_cmd(ss_cmd)
time.sleep(0.1)
ss_cmd = f"{self.adb_prefix_shell} screencap -p /sdcard/screenshot.png"
ss_cmd = f"{self.adb_prefix_shell} screencap -p /sdcard/{ss_name}.png"
ss_res = self.execute_adb_with_cmd(ss_cmd)
time.sleep(0.1)
ss_cmd = f"{self.adb_prefix} pull /sdcard/screenshot.png {self.screenshot_dir}"
ss_cmd = f"{self.adb_prefix} pull /sdcard/{ss_name}.png {self.screenshot_dir}"
ss_res = self.execute_adb_with_cmd(ss_cmd)
image_path = Path(f"{self.screenshot_dir}/screenshot.png")
res = image_path
image_path = Path(f"{self.screenshot_dir}/{ss_name}.png")
res = image_path
return Path(res)
@mark_as_readable
@ -261,23 +270,16 @@ class AndroidExtEnv(ExtEnv):
exit_res = self.execute_adb_with_cmd(adb_cmd)
return exit_res
def _ocr_text(self, text: str) -> list:
if not self.screenshot_dir.exists():
self.screenshot_dir.mkdir(parents=True, exist_ok=True)
image = self.get_screenshot("screenshot", self.screenshot_dir)
if self.cv_model_status['ocr_detection_loaded'] == False:
ocr_detection = pipeline(Tasks.ocr_detection, model="damo/cv_resnet18_ocr-detection-line-level_damo")
self.cv_model_status['ocr_detection_loaded'] = True
if self.cv_model_status['ocr_recognition_loaded'] == False:
ocr_recognition = pipeline(Tasks.ocr_recognition, model="damo/cv_convnextTiny_ocr-recognition-document_damo")
self.cv_model_status['ocr_recognition_loaded'] == True
iw, ih = Image.open(image).size
x, y = self.device_shape
if iw > ih:
x, y = y, x
iw, ih = ih, iw
in_coordinate, out_coordinate = ocr(image, text, ocr_detection, ocr_recognition, iw, ih)
in_coordinate, out_coordinate = ocr(image, text, self.ocr_detection, self.ocr_recognition, iw, ih)
output_list = [in_coordinate, out_coordinate, x, y, iw, ih, image]
return output_list
@ -323,19 +325,13 @@ class AndroidExtEnv(ExtEnv):
if not self.screenshot_dir.exists():
self.screenshot_dir.mkdir(parents=True, exist_ok=True)
screenshot_path = self.get_screenshot("screenshot", self.screenshot_dir)
image, device = screenshot_path, 'cpu'
image= screenshot_path
iw, ih = Image.open(image).size
x, y = self.device_shape
if iw > ih:
x, y = y, x
iw, ih = ih, iw
if self.cv_model_status['clip_model_loaded'] == False:
file_url = 'https://huggingface.co/ShilongLiu/GroundingDINO/blob/main/groundingdino_swint_ogc.pth' # 加载远程model
target_folder = Path(f'{DEFAULT_WORKSPACE_ROOT}/weights')
file_path = download_model(file_url, target_folder)
groundingdino_model = load_model(file_path, device=device).eval()
self.cv_model_status['clip_model_loaded'] = True
in_coordinate, out_coordinate = det(image, "icon", groundingdino_model) # 检测icon
in_coordinate, out_coordinate = det(image, "icon", self.groundingdino_model) # 检测icon
if len(out_coordinate) == 1: # only one icon
tap_coordinate = [(in_coordinate[0][0] + in_coordinate[0][2]) / 2,
(in_coordinate[0][1] + in_coordinate[0][3]) / 2]
@ -344,8 +340,7 @@ class AndroidExtEnv(ExtEnv):
else:
temp_file = Path(f"{DEFAULT_WORKSPACE_ROOT}/temp")
if not temp_file.exists():
temp_file.mkdir(parents=True, exist_ok=True)
temp_file.mkdir(parents=True, exist_ok=True)
hash_table, clip_filter = [], []
for i, (td, box) in enumerate(zip(in_coordinate, out_coordinate)):
if crop_for_clip(image, td, i, temp_file):

View file

@ -260,7 +260,7 @@ def transform_image(image_pil: any) -> any:
def load_model(model_checkpoint_path: Path, device: str) -> any:
model_config_path = 'grounding_dino_config.py'
model_config_path = "grounding_dino_config.py"
args = SLConfig.fromfile(model_config_path)
args.device = device
model = build_model(args)