update andriod_assistant

This commit is contained in:
better629 2024-01-26 21:16:26 +08:00
parent 8a49292045
commit 7f64fbee5a
5 changed files with 243 additions and 18 deletions

View file

@ -8,12 +8,16 @@ from examples.andriod_assistant.prompts.assistant_prompt import (
screenshot_parse_template,
screenshot_parse_with_grid_template,
)
from examples.andriod_assistant.utils.utils import draw_bbox_multi, traverse_xml_tree
from examples.andriod_assistant.utils.schema import OpLogItem, ActionOp, ParamExtState, GridOp, ActionOp, TapOp, TapGridOp, \
LongPressOp, LongPressGridOp, SwipeOp, SwipeGridOp, TextOp, AndroidElement
from examples.andriod_assistant.actions.screenshot_parse_an import SCREENSHOT_PARSE_NODE
from examples.andriod_assistant.utils.utils import draw_bbox_multi, traverse_xml_tree, area_to_xy, screenshot_parse_extract, elem_bbox_to_xy
from metagpt.actions.action import Action
from metagpt.config2 import config
from metagpt.environment.android_env.android_env import AndroidEnv
from metagpt.environment.api.env_api import EnvAPIAbstract
from metagpt.utils.common import encode_image
from metagpt.const import ADB_EXEC_FAIL
class ScreenshotParse(Action):
@ -38,7 +42,7 @@ class ScreenshotParse(Action):
focusable_list = []
traverse_xml_tree(xml_path, clickable_list, "clickable", True)
traverse_xml_tree(xml_path, focusable_list, "focusable", True)
elem_list = clickable_list.copy()
elem_list: list[AndroidElement] = clickable_list.copy()
for elem in focusable_list:
bbox = elem.bbox
center = (bbox[0][0] + bbox[1][0]) // 2, (bbox[0][1] + bbox[1][1]) // 2
@ -52,8 +56,10 @@ class ScreenshotParse(Action):
break
if not close:
elem_list.append(elem)
draw_bbox_multi(screenshot_path, task_dir.joinpath(f"{task_dir}_{round_count}_labeled.png"), elem_list)
encode_image(task_dir.joinpath(f"{task_dir}_{round_count}_labeled.png"))
screenshot_labeled_path = task_dir.joinpath(f"{task_dir}_{round_count}_labeled.png")
draw_bbox_multi(screenshot_path, screenshot_labeled_path, elem_list)
img_base64 = encode_image(screenshot_labeled_path)
parse_template = screenshot_parse_with_grid_template if grid_on else screenshot_parse_template
@ -61,4 +67,69 @@ class ScreenshotParse(Action):
# TODO
ui_doc = ""
parse_template.format(ui_document=ui_doc, task_description=task_desc, last_act=last_act)
context = parse_template.format(ui_document=ui_doc, task_description=task_desc, last_act=last_act)
node = await SCREENSHOT_PARSE_NODE.fill(context=context, llm=self.llm, images=[img_base64])
if "error" in node.content:
# TODO
return
prompt = node.compile(context=context, schema="json", mode="auto")
log_item = OpLogItem(step=round_count, prompt=prompt, image=screenshot_labeled_path, response=node.content)
op_param = screenshot_parse_extract(node.instruct_content.model_dump(), grid_on)
if op_param.param_state == ParamExtState.FINISH:
# TODO
return
if op_param.param_state == ParamExtState.FAIL:
# TODO
return
if isinstance(op_param, TapOp):
x, y = elem_bbox_to_xy(elem_list[op_param.area - 1].bbox)
res = env.step(EnvAPIAbstract("system_tap", kwargs={"x": x, "y": y}))
if res == ADB_EXEC_FAIL:
# TODO
return
elif isinstance(op_param, TextOp):
res = env.step(EnvAPIAbstract("user_input", kwargs={"input_txt": op_param.input_str}))
if res == ADB_EXEC_FAIL:
# TODO
return
elif isinstance(op_param, LongPressOp):
x, y = elem_bbox_to_xy(elem_list[op_param.area - 1].bbox)
res = env.step(EnvAPIAbstract("user_longpress", kwargs={"x": x, "y": y}))
if res == ADB_EXEC_FAIL:
# TODO
return
elif isinstance(op_param, SwipeOp):
x, y = elem_bbox_to_xy(elem_list[op_param.area - 1].bbox)
res = env.step(EnvAPIAbstract("user_swipe", kwargs={"x": x, "y": y, "dir": op_param.swipe_orient, "dist": op_param.dist}))
if res == ADB_EXEC_FAIL:
# TODO
return
elif isinstance(op_param, GridOp):
grid_on = True
elif isinstance(op_param, TapGridOp) or isinstance(op_param, LongPressGridOp):
x, y = area_to_xy(op_param.area, op_param.subarea, env.width, env.height, env.rows, env.cols)
if isinstance(op_param, TapGridOp):
res = env.step(EnvAPIAbstract("system_tap", kwargs={"x": x, "y": y}))
if res == ADB_EXEC_FAIL:
# TODO
return
else:
# LongPressGridOp
res = env.step(EnvAPIAbstract("user_longpress", kwargs={"x": x, "y": y}))
if res == ADB_EXEC_FAIL:
# TODO
return
elif isinstance(op_param, SwipeGridOp):
start_x, start_y = area_to_xy(op_param.start_area, op_param.start_subarea)
end_x, end_y = area_to_xy(op_param.end_area, op_param.end_subarea)
res = env.step(EnvAPIAbstract("user_swipe_to", kwargs={"start": (start_x, start_y), "end": (end_x, end_y)}))
if res == ADB_EXEC_FAIL:
# TODO
return
if op_param.act_name != "grid":
grid_on = True # TODO overwrite it

View file

@ -59,9 +59,9 @@ class SelfLearn(Action):
if not close:
elem_list.append(elem)
draw_bbox_multi(screenshot_path, task_dir.joinpath(f"{round_count}_before_labeled.png"), elem_list)
encode_image(task_dir.joinpath(f"{round_count}_before_labeled.png"))
img_base64 = encode_image(task_dir.joinpath(f"{round_count}_before_labeled.png"))
self_explore_template = screenshot_parse_self_explore_template
context = self_explore_template.format(task_description=task_desc, last_act=last_act)
await SCREENSHOT_PARSE_NODE.fill(context=context, llm=self.llm)
node = await SCREENSHOT_PARSE_NODE.fill(context=context, llm=self.llm, images=[img_base64])

View file

@ -7,7 +7,7 @@ from pathlib import Path
from examples.andriod_assistant.prompts.assistant_prompt import (
screenshot_parse_self_explore_reflect_template,
)
from examples.andriod_assistant.utils.schema import AndroidElement
from examples.andriod_assistant.utils.schema import AndroidElement, ActionOp, SwipeOp
from examples.andriod_assistant.utils.utils import draw_bbox_multi
from metagpt.actions.action import Action
from metagpt.environment.android_env.android_env import AndroidEnv
@ -27,7 +27,7 @@ class SelfLearnReflect(Action):
env: AndroidEnv,
elem_list: list[AndroidElement],
act_name: str,
swipe_dir: str,
swipe_orient: str,
ui_area: int,
):
if act_name == "text":
@ -47,15 +47,15 @@ class SelfLearnReflect(Action):
encode_image(task_dir.joinpath(f"{round_count}_after_labeled.png"))
reflect_template = screenshot_parse_self_explore_reflect_template
if act_name == "tap":
if act_name == ActionOp.TAP.value:
action = "tapping"
elif act_name == "long_press":
elif act_name == ActionOp.LONG_PRESS.value:
action = "long pressing"
elif act_name == "swipe":
elif act_name == ActionOp.SWIPE.value:
action = "swiping"
if swipe_dir == "up" or swipe_dir == "down":
if swipe_orient == SwipeOp.UP.value or swipe_orient == SwipeOp.DOWN.value:
action = "v_swipe"
elif swipe_dir == "left" or swipe_dir == "right":
elif swipe_orient == SwipeOp.LEFT.value or swipe_orient == SwipeOp.RIGHT.value:
action = "h_swipe"
reflect_template.format(action=action, ui_element=str(ui_area), task_desc=task_desc, last_act=last_act)

View file

@ -2,7 +2,23 @@
# -*- coding: utf-8 -*-
# @Desc :
from pydantic import Field, BaseModel
from enum import Enum
from pydantic import Field, BaseModel, field_validator
class ActionOp(Enum):
    """Action names an operation can carry; each value is the action's string form."""

    TAP = "tap"
    LONG_PRESS = "long_press"
    TEXT = "text"
    SWIPE = "swipe"
    GRID = "grid"
class SwipeOp(Enum):
    """Swipe direction names; each value is the direction's string form."""

    # NOTE(review): a pydantic model that is also named `SwipeOp` is declared
    # further down in this module. That later `class SwipeOp(BaseOpParam)`
    # rebinds the module-level name, so `SwipeOp.UP` would no longer resolve to
    # this enum after import -- confirm and rename one of the two.
    UP = "up"
    DOWN = "down"
    LEFT = "left"
    RIGHT = "right"
class AndroidElement(BaseModel):
@ -37,3 +53,64 @@ class DocContent(BaseModel):
long_press: str = Field(default="")
# start =================== define different Action Op and its params =============
class ParamExtState(Enum):
    """Op params extract state"""

    # Used as the default `param_state` of `BaseOpParam`.
    SUCCESS = "success"
    # presumably signals that the task is complete -- confirm against the parser
    FINISH = "finish"
    # presumably signals that no usable operation could be extracted -- confirm
    FAIL = "fail"
class BaseOpParam(BaseModel):
    """Fields common to every operation model: action name, summary and extract state."""

    # validate_default=True asks pydantic to run validators on the default too.
    act_name: str = Field(validate_default=True, default="")
    last_act: str = ""
    param_state: ParamExtState = Field(
        description="return state when extract params",
        default=ParamExtState.SUCCESS,
    )
class TapOp(BaseOpParam):
    """Parameters of a tap operation."""

    # -1 is the "not set" default.
    area: int = -1
class TextOp(BaseOpParam):
    """Parameters of a text-input operation."""

    # Empty string is the "not set" default.
    input_str: str = ""
class LongPressOp(BaseOpParam):
    """Parameters of a long-press operation."""

    # -1 is the "not set" default.
    area: int = -1
class SwipeOp(BaseOpParam):
    """Parameters of a swipe operation: target area, orientation and distance."""

    # NOTE(review): this class reuses the name of the `SwipeOp` enum declared
    # earlier in this module and therefore replaces it at module level --
    # confirm and rename one of the two.
    area: int = -1
    swipe_orient: str = "up"
    dist: str = ""
class GridOp(BaseOpParam):
    """Operation model for the "grid" action.

    Derives from `BaseOpParam` (instead of the bare `BaseModel`) so that it
    exposes `param_state` and `last_act` like every other operation model;
    code that inspects `param_state` on an arbitrary extract result would
    otherwise raise `AttributeError` when handed a `GridOp`.
    """

    # Re-declared with a plain default, exactly as before the base-class change.
    act_name: str = Field(default="")
class BaseGridOpParam(BaseOpParam):
    """Base for grid-mode operation models; suffixes `act_name` with "_grid"."""

    @field_validator("act_name", mode="before")
    @classmethod
    def check_act_name(cls, act_name: str) -> str:
        """Return `act_name` with a single "_grid" suffix.

        The suffix is only appended when it is not already present, so that
        re-validating an already-built model (e.g. from its own dump) does not
        produce names such as "tap_grid_grid".
        """
        suffix = "_grid"
        if isinstance(act_name, str) and act_name.endswith(suffix):
            return act_name
        return f"{act_name}{suffix}"
class TapGridOp(BaseGridOpParam):
    """Parameters of a grid-mode tap: grid area plus a sub-area name."""

    # -1 / "" are the "not set" defaults.
    area: int = -1
    subarea: str = ""
class LongPressGridOp(BaseGridOpParam):
    """Parameters of a grid-mode long press: grid area plus a sub-area name."""

    # -1 / "" are the "not set" defaults.
    area: int = -1
    subarea: str = ""
class SwipeGridOp(BaseGridOpParam):
    """Parameters of a grid-mode swipe: start and end (area, sub-area) pairs."""

    # Start point; -1 / "" are the "not set" defaults.
    start_area: int = -1
    start_subarea: str = ""
    # End point; same defaults.
    end_area: int = -1
    end_subarea: str = ""
# end =================== define different Action Op and its params =============

View file

@ -2,17 +2,19 @@
# -*- coding: utf-8 -*-
# @Desc :
from pydantic import Field, BaseModel
from typing import Union
from xml.etree.ElementTree import Element, iterparse
import cv2
from pathlib import Path
import pyshine as ps
import base64
import re
from metagpt.config2 import config
from metagpt.logs import logger
from examples.andriod_assistant.utils.schema import AndroidElement
from examples.andriod_assistant.utils.schema import BaseOpParam, BaseGridOpParam, GridOp, ActionOp, TapOp, TapGridOp, \
LongPressOp, LongPressGridOp, SwipeOp, SwipeGridOp, TextOp, ParamExtState
def get_id_from_element(elem: Element) -> str:
@ -134,7 +136,7 @@ def draw_grid(img_path: Path, output_path: Path) -> tuple[int, int]:
return rows, cols
def area_to_xy(width: int, height: int, cols: int, rows: int, area: int, subarea: str) -> tuple[int, int]:
def area_to_xy(area: int, subarea: str, width: int, height: int, rows: int, cols: int) -> tuple[int, int]:
area -= 1
row, col = area // cols, area % cols
x_0, y_0 = col * (width // cols), row * (height // rows)
@ -157,3 +159,78 @@ def area_to_xy(width: int, height: int, cols: int, rows: int, area: int, subarea
else:
x, y = x_0 + (width // cols) // 2, y_0 + (height // rows) // 2
return x, y
def elem_bbox_to_xy(bbox: tuple[tuple[int, int], tuple[int, int]]) -> tuple[int, int]:
    """Return the integer center point of a bounding box.

    Args:
        bbox: A pair of corner points; the body names them top-left and
            bottom-right. Each point is indexed as ``(x, y)``.

    Returns:
        The ``(x, y)`` center, computed with floor division.
    """
    # The annotation now states two corner points, matching this unpacking.
    top_left, bottom_right = bbox
    center_x = (top_left[0] + bottom_right[0]) // 2
    center_y = (top_left[1] + bottom_right[1]) // 2
    return center_x, center_y
def screenshot_parse_extract(parsed_json: dict, grid_on: bool = False) -> Union[BaseOpParam, BaseGridOpParam, GridOp]:
    """Turn a parsed model reply into an operation model.

    Args:
        parsed_json: Mapping read for the keys "Action" (the action text) and
            "Summary" (carried over as `last_act`).
        grid_on: Selects the grid-mode parser when true, otherwise the
            element-mode parser.

    Returns:
        A `BaseOpParam` with state FINISH when the action text contains
        "FINISH", a `BaseOpParam` with state FAIL when no textual "Action" is
        present, otherwise whatever the selected parser returns.
    """
    act = parsed_json.get("Action")
    # A missing summary becomes "" so the `last_act: str` field still validates.
    last_act = parsed_json.get("Summary") or ""

    # Without a textual action there is nothing to split or search; report a
    # failed extraction instead of raising AttributeError on None.
    if not isinstance(act, str):
        return BaseOpParam(param_state=ParamExtState.FAIL)

    if ParamExtState.FINISH.value.upper() in act:
        return BaseOpParam(param_state=ParamExtState.FINISH)

    # Everything before the first "(" is the action name; surrounding
    # whitespace is dropped so " tap(3)" is still recognised.
    act_name = act.split("(")[0].strip()

    if grid_on:
        return screenshot_parse_extract_with_grid(act_name, act, last_act)
    return screenshot_parse_extract_without_grid(act_name, act, last_act)
def op_params_clean(params: list[str]) -> list[Union[int, str]]:
    """Convert raw parameter strings into ints or unquoted strings.

    A parameter containing a single or double quote is stripped of surrounding
    whitespace and then loses its first and last character; any other
    parameter is converted with ``int``.
    """

    def _convert(raw: str) -> Union[int, str]:
        quoted = '"' in raw or "'" in raw
        if not quoted:
            return int(raw)
        # remove `"`
        return raw.strip()[1:-1]

    return [_convert(raw) for raw in params]
def screenshot_parse_extract_without_grid(act_name: str, act: str, last_act: str) -> Union[BaseOpParam, GridOp]:
    """Build the element-mode operation model for one action string.

    Args:
        act_name: Action name, compared against the `ActionOp` values.
        act: Full action text, e.g. ``tap(3)``; the parenthesised payload is
            pulled out with a regex.
        last_act: Summary text stored on the resulting model.

    Returns:
        The matching operation model, a `GridOp` for the grid action, or a
        `BaseOpParam` with state FAIL when the action name is unknown or its
        payload cannot be parsed.
    """
    try:
        if act_name == ActionOp.TAP.value:
            area = int(re.findall(r"tap\((.*?)\)", act)[0])
            op = TapOp(act_name=act_name, area=area, last_act=last_act)
        elif act_name == ActionOp.TEXT.value:
            # [1:-1] drops the quote characters around the text payload.
            input_str = re.findall(r"text\((.*?)\)", act)[0][1:-1]
            op = TextOp(act_name=act_name, input_str=input_str, last_act=last_act)
        elif act_name == ActionOp.LONG_PRESS.value:
            area = int(re.findall(r"long_press\((.*?)\)", act)[0])
            op = LongPressOp(act_name=act_name, area=area, last_act=last_act)
        elif act_name == ActionOp.SWIPE.value:
            params = re.findall(r"swipe\((.*?)\)", act)[0].split(",")
            params = op_params_clean(params)  # area, swipe_orient, dist
            op = SwipeOp(act_name=act_name, area=params[0], swipe_orient=params[1], dist=params[2], last_act=last_act)
        elif act_name == ActionOp.GRID.value:
            op = GridOp(act_name=act_name)
        else:
            op = BaseOpParam(param_state=ParamExtState.FAIL)
    except (IndexError, ValueError):
        # IndexError: the regex found no "(...)" payload, or too few
        # comma-separated parameters. ValueError: a non-integer where an int
        # was expected. Both mean the extraction failed, so report it through
        # the existing FAIL state instead of crashing the caller.
        op = BaseOpParam(param_state=ParamExtState.FAIL)
    return op
def screenshot_parse_extract_with_grid(act_name: str, act: str, last_act: str) -> Union[BaseGridOpParam, GridOp]:
    """Build the grid-mode operation model for one action string.

    Args:
        act_name: Action name, compared against the `ActionOp` values.
        act: Full action text; the parenthesised, comma-separated payload is
            pulled out with a regex and cleaned with `op_params_clean`.
        last_act: Summary text stored on the resulting model.

    Returns:
        The matching grid operation model, a `GridOp` for the grid action, or
        a `BaseGridOpParam` with state FAIL when the action name is unknown or
        its payload cannot be parsed.
    """
    try:
        if act_name == ActionOp.TAP.value:
            params = re.findall(r"tap\((.*?)\)", act)[0].split(",")
            params = op_params_clean(params)  # area, subarea
            op = TapGridOp(act_name=act_name, area=params[0], subarea=params[1], last_act=last_act)
        elif act_name == ActionOp.LONG_PRESS.value:
            params = re.findall(r"long_press\((.*?)\)", act)[0].split(",")
            params = op_params_clean(params)  # area, subarea
            op = LongPressGridOp(act_name=act_name, area=params[0], subarea=params[1], last_act=last_act)
        elif act_name == ActionOp.SWIPE.value:
            params = re.findall(r"swipe\((.*?)\)", act)[0].split(",")
            params = op_params_clean(params)  # start_area, start_subarea, end_area, end_subarea
            op = SwipeGridOp(
                act_name=act_name,
                start_area=params[0],
                start_subarea=params[1],
                end_area=params[2],
                end_subarea=params[3],
                # Carried over like in the tap / long-press branches.
                last_act=last_act,
            )
        elif act_name == ActionOp.GRID.value:
            op = GridOp(act_name=act_name)
        else:
            op = BaseGridOpParam(param_state=ParamExtState.FAIL)
    except (IndexError, ValueError):
        # IndexError: the regex found no "(...)" payload, or too few
        # comma-separated parameters. ValueError: a non-integer where an int
        # was expected. Both mean the extraction failed, so report it through
        # the existing FAIL state instead of crashing the caller.
        op = BaseGridOpParam(param_state=ParamExtState.FAIL)
    return op