mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-05-01 03:46:23 +02:00
Merge remote-tracking branch 'origin/main'
This commit is contained in:
parent
17580333b8
commit
cf9d86b832
4 changed files with 29 additions and 35 deletions
|
|
@ -28,22 +28,31 @@ from metagpt.utils.common import download_model
|
|||
from metagpt.const import DEFAULT_WORKSPACE_ROOT
|
||||
|
||||
|
||||
def load_cv_model(device: str = "cpu") -> any:
|
||||
ocr_detection = pipeline(Tasks.ocr_detection, model="damo/cv_resnet18_ocr-detection-line-level_damo")
|
||||
ocr_recognition = pipeline(Tasks.ocr_recognition,
|
||||
model="damo/cv_convnextTiny_ocr-recognition-document_damo")
|
||||
file_url = "https://huggingface.co/ShilongLiu/GroundingDINO/blob/main/groundingdino_swint_ogc.pth"
|
||||
target_folder = Path(f"{DEFAULT_WORKSPACE_ROOT}/weights")
|
||||
file_path = download_model(file_url, target_folder)
|
||||
groundingdino_model = load_model(file_path, device=device).eval()
|
||||
return ocr_detection, ocr_recognition, groundingdino_model
|
||||
|
||||
|
||||
class AndroidExtEnv(ExtEnv):
|
||||
device_id: Optional[str] = Field(default=None)
|
||||
screenshot_dir: Optional[Path] = Field(default=None)
|
||||
xml_dir: Optional[Path] = Field(default=None)
|
||||
width: int = Field(default=720, description="device screen width")
|
||||
height: int = Field(default=1080, description="device screen height")
|
||||
cv_model_status: dict = Field(default=None, description="Record model loading status")
|
||||
ocr_detection: any = Field(default=None, description="ocr detection model")
|
||||
ocr_recognition: any = Field(default=None, description="ocr recognition model")
|
||||
groundingdino_model: any = Field(default=None, description="clip groundingdino model")
|
||||
|
||||
def __init__(self, **data: Any):
|
||||
super().__init__(**data)
|
||||
device_id = data.get("device_id")
|
||||
self.cv_model_status = {
|
||||
'ocr_detection_loaded': False,
|
||||
'ocr_recognition_loaded': False,
|
||||
'clip_model_loaded': False
|
||||
}
|
||||
self.ocr_detection, self.ocr_recognition, self.groundingdino_model = load_cv_model()
|
||||
if device_id:
|
||||
devices = self.list_devices()
|
||||
if device_id not in devices:
|
||||
|
|
@ -51,8 +60,8 @@ class AndroidExtEnv(ExtEnv):
|
|||
(width, height) = self.device_shape
|
||||
self.width = data.get("width", width)
|
||||
self.height = data.get("height", height)
|
||||
self.create_device_path(self.screenshot_dir)
|
||||
self.create_device_path(self.xml_dir)
|
||||
#self.create_device_path(self.screenshot_dir)
|
||||
#self.create_device_path(self.xml_dir)
|
||||
|
||||
def reset(
|
||||
self,
|
||||
|
|
@ -173,16 +182,16 @@ class AndroidExtEnv(ExtEnv):
|
|||
if pull_res != ADB_EXEC_FAIL:
|
||||
res = ss_local_path
|
||||
else:
|
||||
ss_cmd = f"{self.adb_prefix_shell} rm /sdcard/screenshot.png"
|
||||
ss_cmd = f"{self.adb_prefix_shell} rm /sdcard/{ss_name}.png"
|
||||
ss_res = self.execute_adb_with_cmd(ss_cmd)
|
||||
time.sleep(0.1)
|
||||
ss_cmd = f"{self.adb_prefix_shell} screencap -p /sdcard/screenshot.png"
|
||||
ss_cmd = f"{self.adb_prefix_shell} screencap -p /sdcard/{ss_name}.png"
|
||||
ss_res = self.execute_adb_with_cmd(ss_cmd)
|
||||
time.sleep(0.1)
|
||||
ss_cmd = f"{self.adb_prefix} pull /sdcard/screenshot.png {self.screenshot_dir}"
|
||||
ss_cmd = f"{self.adb_prefix} pull /sdcard/{ss_name}.png {self.screenshot_dir}"
|
||||
ss_res = self.execute_adb_with_cmd(ss_cmd)
|
||||
image_path = Path(f"{self.screenshot_dir}/screenshot.png")
|
||||
res = image_path
|
||||
image_path = Path(f"{self.screenshot_dir}/{ss_name}.png")
|
||||
res = image_path
|
||||
return Path(res)
|
||||
|
||||
@mark_as_readable
|
||||
|
|
@ -261,23 +270,16 @@ class AndroidExtEnv(ExtEnv):
|
|||
exit_res = self.execute_adb_with_cmd(adb_cmd)
|
||||
return exit_res
|
||||
|
||||
|
||||
def _ocr_text(self, text: str) -> list:
|
||||
if not self.screenshot_dir.exists():
|
||||
self.screenshot_dir.mkdir(parents=True, exist_ok=True)
|
||||
image = self.get_screenshot("screenshot", self.screenshot_dir)
|
||||
if self.cv_model_status['ocr_detection_loaded'] == False:
|
||||
ocr_detection = pipeline(Tasks.ocr_detection, model="damo/cv_resnet18_ocr-detection-line-level_damo")
|
||||
self.cv_model_status['ocr_detection_loaded'] = True
|
||||
if self.cv_model_status['ocr_recognition_loaded'] == False:
|
||||
ocr_recognition = pipeline(Tasks.ocr_recognition, model="damo/cv_convnextTiny_ocr-recognition-document_damo")
|
||||
self.cv_model_status['ocr_recognition_loaded'] == True
|
||||
iw, ih = Image.open(image).size
|
||||
x, y = self.device_shape
|
||||
if iw > ih:
|
||||
x, y = y, x
|
||||
iw, ih = ih, iw
|
||||
in_coordinate, out_coordinate = ocr(image, text, ocr_detection, ocr_recognition, iw, ih)
|
||||
in_coordinate, out_coordinate = ocr(image, text, self.ocr_detection, self.ocr_recognition, iw, ih)
|
||||
output_list = [in_coordinate, out_coordinate, x, y, iw, ih, image]
|
||||
return output_list
|
||||
|
||||
|
|
@ -323,19 +325,13 @@ class AndroidExtEnv(ExtEnv):
|
|||
if not self.screenshot_dir.exists():
|
||||
self.screenshot_dir.mkdir(parents=True, exist_ok=True)
|
||||
screenshot_path = self.get_screenshot("screenshot", self.screenshot_dir)
|
||||
image, device = screenshot_path, 'cpu'
|
||||
image= screenshot_path
|
||||
iw, ih = Image.open(image).size
|
||||
x, y = self.device_shape
|
||||
if iw > ih:
|
||||
x, y = y, x
|
||||
iw, ih = ih, iw
|
||||
if self.cv_model_status['clip_model_loaded'] == False:
|
||||
file_url = 'https://huggingface.co/ShilongLiu/GroundingDINO/blob/main/groundingdino_swint_ogc.pth' # 加载远程model
|
||||
target_folder = Path(f'{DEFAULT_WORKSPACE_ROOT}/weights')
|
||||
file_path = download_model(file_url, target_folder)
|
||||
groundingdino_model = load_model(file_path, device=device).eval()
|
||||
self.cv_model_status['clip_model_loaded'] = True
|
||||
in_coordinate, out_coordinate = det(image, "icon", groundingdino_model) # 检测icon
|
||||
in_coordinate, out_coordinate = det(image, "icon", self.groundingdino_model) # 检测icon
|
||||
if len(out_coordinate) == 1: # only one icon
|
||||
tap_coordinate = [(in_coordinate[0][0] + in_coordinate[0][2]) / 2,
|
||||
(in_coordinate[0][1] + in_coordinate[0][3]) / 2]
|
||||
|
|
@ -344,8 +340,7 @@ class AndroidExtEnv(ExtEnv):
|
|||
|
||||
else:
|
||||
temp_file = Path(f"{DEFAULT_WORKSPACE_ROOT}/temp")
|
||||
if not temp_file.exists():
|
||||
temp_file.mkdir(parents=True, exist_ok=True)
|
||||
temp_file.mkdir(parents=True, exist_ok=True)
|
||||
hash_table, clip_filter = [], []
|
||||
for i, (td, box) in enumerate(zip(in_coordinate, out_coordinate)):
|
||||
if crop_for_clip(image, td, i, temp_file):
|
||||
|
|
|
|||
|
|
@ -260,7 +260,7 @@ def transform_image(image_pil: any) -> any:
|
|||
|
||||
|
||||
def load_model(model_checkpoint_path: Path, device: str) -> any:
|
||||
model_config_path = 'grounding_dino_config.py'
|
||||
model_config_path = "grounding_dino_config.py"
|
||||
args = SLConfig.fromfile(model_config_path)
|
||||
args.device = device
|
||||
model = build_model(args)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue