From 8a49292045ddea415ae93be9382bc0294fab191c Mon Sep 17 00:00:00 2001 From: better629 Date: Fri, 26 Jan 2024 15:18:29 +0800 Subject: [PATCH] add andriod_assistant action self-learn / self-learn-reflect / screenshot-parse --- .../actions/manual_record.py | 1 + .../andriod_assistant/actions/parse_record.py | 7 +- .../actions/screenshot_parse.py | 56 +++++- .../actions/screenshot_parse_an.py | 36 ++-- .../andriod_assistant/actions/self_learn.py | 60 ++++++- .../actions/self_learn_reflect.py | 55 +++++- .../prompts/assistant_prompt.py | 1 - examples/andriod_assistant/requirements.txt | 1 + .../roles/android_assistant.py | 16 +- examples/andriod_assistant/run_assistant.py | 3 +- examples/andriod_assistant/utils/schema.py | 39 +++++ examples/andriod_assistant/utils/utils.py | 159 ++++++++++++++++++ 12 files changed, 388 insertions(+), 46 deletions(-) create mode 100644 examples/andriod_assistant/requirements.txt create mode 100644 examples/andriod_assistant/utils/schema.py create mode 100644 examples/andriod_assistant/utils/utils.py diff --git a/examples/andriod_assistant/actions/manual_record.py b/examples/andriod_assistant/actions/manual_record.py index 23012416d..463bce1fd 100644 --- a/examples/andriod_assistant/actions/manual_record.py +++ b/examples/andriod_assistant/actions/manual_record.py @@ -7,6 +7,7 @@ from metagpt.actions.action import Action class ManualRecord(Action): """do a human operation on the screen with human input""" + name: str = "ManualRecord" async def run(self): diff --git a/examples/andriod_assistant/actions/parse_record.py b/examples/andriod_assistant/actions/parse_record.py index 3ffa4d4e8..77f49fbd3 100644 --- a/examples/andriod_assistant/actions/parse_record.py +++ b/examples/andriod_assistant/actions/parse_record.py @@ -1,10 +1,11 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- -# @Desc : parse record to generate learned standard operations in stage=learn & mode=manual, LIKE scripts/document_generation.py - -from metagpt.actions.action import Action +# @Desc : parse record to generate learned standard operations in stage=learn & mode=manual, +# LIKE scripts/document_generation.py from examples.andriod_assistant.prompts.operation_prompt import * +from metagpt.actions.action import Action + class ParseRecord(Action): name: str = "ParseRecord" diff --git a/examples/andriod_assistant/actions/screenshot_parse.py b/examples/andriod_assistant/actions/screenshot_parse.py index 1fa5a26b9..c956f4d53 100644 --- a/examples/andriod_assistant/actions/screenshot_parse.py +++ b/examples/andriod_assistant/actions/screenshot_parse.py @@ -2,11 +2,63 @@ # -*- coding: utf-8 -*- # @Desc : LIKE scripts/task_executor.py in stage=act +from pathlib import Path + +from examples.andriod_assistant.prompts.assistant_prompt import ( + screenshot_parse_template, + screenshot_parse_with_grid_template, +) +from examples.andriod_assistant.utils.utils import draw_bbox_multi, traverse_xml_tree from metagpt.actions.action import Action +from metagpt.config2 import config +from metagpt.environment.android_env.android_env import AndroidEnv +from metagpt.environment.api.env_api import EnvAPIAbstract +from metagpt.utils.common import encode_image class ScreenshotParse(Action): name: str = "ScreenshotParse" - async def run(self): - pass + async def run( + self, round_count: int, task_desc: str, last_act: str, task_dir: Path, env: AndroidEnv, grid_on: bool = False + ): + screenshot_path: Path = env.step( + EnvAPIAbstract( + api_name="get_screenshot", kwargs={"ss_name": f"{round_count}_before", "local_save_dir": task_dir} + ) + ) + xml_path: Path = env.step( + EnvAPIAbstract(api_name="get_xml", kwargs={"xml_name": f"{round_count}", "local_save_dir": task_dir}) + ) + if not screenshot_path.exists() or not xml_path.exists(): + # TODO exit + return + + clickable_list = [] + focusable_list = [] + traverse_xml_tree(xml_path, clickable_list, "clickable", True) + traverse_xml_tree(xml_path, focusable_list, "focusable", True) + elem_list = clickable_list.copy() + for elem in focusable_list: + bbox = elem.bbox + center = (bbox[0][0] + bbox[1][0]) // 2, (bbox[0][1] + bbox[1][1]) // 2 + close = False + for e in clickable_list: + bbox = e.bbox + center_ = (bbox[0][0] + bbox[1][0]) // 2, (bbox[0][1] + bbox[1][1]) // 2 + dist = (abs(center[0] - center_[0]) ** 2 + abs(center[1] - center_[1]) ** 2) ** 0.5 + if dist <= config.get_other("min_dist"): + close = True + break + if not close: + elem_list.append(elem) + draw_bbox_multi(screenshot_path, task_dir.joinpath(f"{task_dir}_{round_count}_labeled.png"), elem_list) + encode_image(task_dir.joinpath(f"{task_dir}_{round_count}_labeled.png")) + + parse_template = screenshot_parse_with_grid_template if grid_on else screenshot_parse_template + + # makeup `ui_doc` + # TODO + ui_doc = "" + + parse_template.format(ui_document=ui_doc, task_description=task_desc, last_act=last_act) diff --git a/examples/andriod_assistant/actions/screenshot_parse_an.py b/examples/andriod_assistant/actions/screenshot_parse_an.py index d9879bdb6..eb23ba934 100644 --- a/examples/andriod_assistant/actions/screenshot_parse_an.py +++ b/examples/andriod_assistant/actions/screenshot_parse_an.py @@ -4,59 +4,45 @@ from metagpt.actions.action_node import ActionNode - OBSERVATION = ActionNode( - key="Observation", - expected_type=str, - instruction="Describe what you observe in the image", - example="" + key="Observation", expected_type=str, instruction="Describe what you observe in the image", example="" ) THOUGHT = ActionNode( key="Thought", expected_type=str, instruction="To complete the given task, what is the next step I should do", - example="" + example="", ) ACTION = ActionNode( key="Action", expected_type=str, instruction="The function call with the correct parameters to proceed with the task. If you believe the task is " - "completed or there is nothing to be done, you should output FINISH. You cannot output anything else " - "except a function call or FINISH in this field.", - example="" + "completed or there is nothing to be done, you should output FINISH. You cannot output anything else " + "except a function call or FINISH in this field.", + example="", ) SUMMARY = ActionNode( key="Summary", expected_type=str, instruction="Summarize your past actions along with your latest action in one or two sentences. Do not include " - "the numeric tag in your summary", - example="" + "the numeric tag in your summary", + example="", ) SUMMARY_GRID = ActionNode( key="Summary", expected_type=str, instruction="Summarize your past actions along with your latest action in one or two sentences. Do not include " - "the grid area number in your summary", - example="" + "the grid area number in your summary", + example="", ) -NODES = [ - OBSERVATION, - THOUGHT, - ACTION, - SUMMARY -] +NODES = [OBSERVATION, THOUGHT, ACTION, SUMMARY] -NODES_GRID = [ - OBSERVATION, - THOUGHT, - ACTION, - SUMMARY_GRID -] +NODES_GRID = [OBSERVATION, THOUGHT, ACTION, SUMMARY_GRID] SCREENSHOT_PARSE_NODE = ActionNode.from_children("ScreenshotParse", NODES) SCREENSHOT_PARSE_GRID_NODE = ActionNode.from_children("ScreenshotParseGrid", NODES_GRID) diff --git a/examples/andriod_assistant/actions/self_learn.py b/examples/andriod_assistant/actions/self_learn.py index ffc52f535..cbb78c2a2 100644 --- a/examples/andriod_assistant/actions/self_learn.py +++ b/examples/andriod_assistant/actions/self_learn.py @@ -2,14 +2,66 @@ # -*- coding: utf-8 -*- # @Desc : LIKE scripts/self_explorer.py in stage=learn & mode=auto self_explore_task stage -from metagpt.actions.action import Action +from pathlib import Path from examples.andriod_assistant.actions.screenshot_parse_an import SCREENSHOT_PARSE_NODE -from examples.andriod_assistant.prompts.assistant_prompt import screenshot_parse_self_explore_template +from examples.andriod_assistant.prompts.assistant_prompt import ( + screenshot_parse_self_explore_template, +) +from examples.andriod_assistant.utils.utils import draw_bbox_multi, traverse_xml_tree +from metagpt.actions.action import Action +from metagpt.config2 import config +from metagpt.environment.android_env.android_env import AndroidEnv +from metagpt.environment.api.env_api import EnvAPIAbstract +from metagpt.utils.common import encode_image class SelfLearn(Action): name: str = "SelfLearn" - async def run(self): - pass + useless_list: list[str] = [] # store useless elements uid + + async def run(self, round_count: int, task_desc: str, last_act: str, task_dir: Path, env: AndroidEnv): + screenshot_path: Path = env.step( + EnvAPIAbstract( + api_name="get_screenshot", kwargs={"ss_name": f"{round_count}_before", "local_save_dir": task_dir} + ) + ) + xml_path: Path = env.step( + EnvAPIAbstract(api_name="get_xml", kwargs={"xml_name": f"{round_count}", "local_save_dir": task_dir}) + ) + if not screenshot_path.exists() or not xml_path.exists(): + # TODO exit + return + + clickable_list = [] + focusable_list = [] + traverse_xml_tree(xml_path, clickable_list, "clickable", True) + traverse_xml_tree(xml_path, focusable_list, "focusable", True) + elem_list = [] + for elem in clickable_list: + if elem.uid in self.useless_list: + continue + elem_list.append(elem) + for elem in focusable_list: + if elem.uid in self.useless_list: + continue + bbox = elem.bbox + center = (bbox[0][0] + bbox[1][0]) // 2, (bbox[0][1] + bbox[1][1]) // 2 + close = False + for e in clickable_list: + bbox = e.bbox + center_ = (bbox[0][0] + bbox[1][0]) // 2, (bbox[0][1] + bbox[1][1]) // 2 + dist = (abs(center[0] - center_[0]) ** 2 + abs(center[1] - center_[1]) ** 2) ** 0.5 + if dist <= config.get_other("min_dist"): + close = True + break + if not close: + elem_list.append(elem) + draw_bbox_multi(screenshot_path, task_dir.joinpath(f"{round_count}_before_labeled.png"), elem_list) + encode_image(task_dir.joinpath(f"{round_count}_before_labeled.png")) + + self_explore_template = screenshot_parse_self_explore_template + context = self_explore_template.format(task_description=task_desc, last_act=last_act) + + await SCREENSHOT_PARSE_NODE.fill(context=context, llm=self.llm) diff --git a/examples/andriod_assistant/actions/self_learn_reflect.py b/examples/andriod_assistant/actions/self_learn_reflect.py index 57f87a524..fa76b7b4b 100644 --- a/examples/andriod_assistant/actions/self_learn_reflect.py +++ b/examples/andriod_assistant/actions/self_learn_reflect.py @@ -2,13 +2,60 @@ # -*- coding: utf-8 -*- # @Desc : LIKE scripts/self_explorer.py self_explore_reflect stage -from metagpt.actions.action import Action +from pathlib import Path -from examples.andriod_assistant.prompts.assistant_prompt import screenshot_parse_self_explore_reflect_template +from examples.andriod_assistant.prompts.assistant_prompt import ( + screenshot_parse_self_explore_reflect_template, +) +from examples.andriod_assistant.utils.schema import AndroidElement +from examples.andriod_assistant.utils.utils import draw_bbox_multi +from metagpt.actions.action import Action +from metagpt.environment.android_env.android_env import AndroidEnv +from metagpt.environment.api.env_api import EnvAPIAbstract +from metagpt.utils.common import encode_image class SelfLearnReflect(Action): name: str = "SelfLearnReflect" - async def run(self): - pass + async def run( + self, + round_count: int, + task_desc: str, + last_act: str, + task_dir: Path, + env: AndroidEnv, + elem_list: list[AndroidElement], + act_name: str, + swipe_dir: str, + ui_area: int, + ): + if act_name == "text": + # TODO ignore current reflect + return + + screenshot_path: Path = env.step( + EnvAPIAbstract( + api_name="get_screenshot", kwargs={"ss_name": f"{round_count}_before", "local_save_dir": task_dir} + ) + ) + if not screenshot_path.exists(): + # TODO exit + return + + draw_bbox_multi(screenshot_path, task_dir.joinpath(f"{round_count}_after_labeled.png"), elem_list) + encode_image(task_dir.joinpath(f"{round_count}_after_labeled.png")) + + reflect_template = screenshot_parse_self_explore_reflect_template + if act_name == "tap": + action = "tapping" + elif act_name == "long_press": + action = "long pressing" + elif act_name == "swipe": + action = "swiping" + if swipe_dir == "up" or swipe_dir == "down": + action = "v_swipe" + elif swipe_dir == "left" or swipe_dir == "right": + action = "h_swipe" + + reflect_template.format(action=action, ui_element=str(ui_area), task_desc=task_desc, last_act=last_act) diff --git a/examples/andriod_assistant/prompts/assistant_prompt.py b/examples/andriod_assistant/prompts/assistant_prompt.py index a2c7900c6..068f78f3f 100644 --- a/examples/andriod_assistant/prompts/assistant_prompt.py +++ b/examples/andriod_assistant/prompts/assistant_prompt.py @@ -165,4 +165,3 @@ Decision: SUCCESS Thought: Documentation: """ - diff --git a/examples/andriod_assistant/requirements.txt b/examples/andriod_assistant/requirements.txt new file mode 100644 index 000000000..e879bece5 --- /dev/null +++ b/examples/andriod_assistant/requirements.txt @@ -0,0 +1 @@ +pyshine==0.0.9 \ No newline at end of file diff --git a/examples/andriod_assistant/roles/android_assistant.py b/examples/andriod_assistant/roles/android_assistant.py index 7e5e3d595..9e9a22b0d 100644 --- a/examples/andriod_assistant/roles/android_assistant.py +++ b/examples/andriod_assistant/roles/android_assistant.py @@ -2,15 +2,16 @@ # -*- coding: utf-8 -*- # @Desc : android assistant to learn from app operations and operate apps -from metagpt.roles.role import Role -from metagpt.config2 import config -from metagpt.actions.add_requirement import UserRequirement - from examples.andriod_assistant.actions.manual_record import ManualRecord from examples.andriod_assistant.actions.parse_record import ParseRecord +from examples.andriod_assistant.actions.screenshot_parse import ScreenshotParse from examples.andriod_assistant.actions.self_learn import SelfLearn from examples.andriod_assistant.actions.self_learn_reflect import SelfLearnReflect -from examples.andriod_assistant.actions.screenshot_parse import ScreenshotParse +from metagpt.actions.add_requirement import UserRequirement +from metagpt.config2 import config +from metagpt.logs import logger +from metagpt.roles.role import Role +from metagpt.schema import Message class AndroidAssistant(Role): @@ -25,6 +26,9 @@ class AndroidAssistant(Role): self.set_actions([ManualRecord, ParseRecord, SelfLearn, SelfLearnReflect, ScreenshotParse]) async def _think(self) -> bool: + """Firstly, we decide the state with user config, further, we can do it automatically, like if it's new app, + run the learn first and then do the act stage or learn it during the action. + """ if config.get_other("stage") == "learn" and config.get_other("mode") == "manual": # choose ManualRecord and then run ParseRecord # Remember, only run each action only one time, no need to run n_round. @@ -37,4 +41,4 @@ class AndroidAssistant(Role): pass async def _act(self) -> Message: - pass + logger.info(f"{self._setting}: to do {self.rc.todo}({self.rc.todo.name})") diff --git a/examples/andriod_assistant/run_assistant.py b/examples/andriod_assistant/run_assistant.py index d01e1f1da..4d599e80b 100644 --- a/examples/andriod_assistant/run_assistant.py +++ b/examples/andriod_assistant/run_assistant.py @@ -16,6 +16,7 @@ app = typer.Typer(add_completion=False, pretty_exceptions_show_locals=False) @app.command("", help="Run a Android Assistant") def startup( + task_desc: str = typer.Argument(help="the task description you want the android assistant to learn or act"), n_round: int = typer.Option(default=20, help="The max round to do an app operation task."), stage: str = typer.Option(default="learn", help="stage: learn / act"), mode: str = typer.Option(default="auto", help="mode: auto / manual , when state=learn"), @@ -49,7 +50,7 @@ def startup( team = Team(env=AndroidEnv()) team.hire([AndroidAssistant]) team.invest(investment) - company.run_project(idea="") # no need idea, just a mock + company.run_project(idea=task_desc) asyncio.run(team.run(n_round=n_round)) diff --git a/examples/andriod_assistant/utils/schema.py b/examples/andriod_assistant/utils/schema.py new file mode 100644 index 000000000..d48b401d2 --- /dev/null +++ b/examples/andriod_assistant/utils/schema.py @@ -0,0 +1,39 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : + +from pydantic import Field, BaseModel + + +class AndroidElement(BaseModel): + """UI Element""" + uid: str = Field(default="") + bbox: tuple[tuple[int, int]] = Field(default={}) + attrib: str = Field(default="") + + +class OpLogItem(BaseModel): + """log content for self-learn or task act""" + step: int = Field(default=0) + prompt: str = Field(default="") + image: str = Field(default="") + response: str = Field(default="") + + +class ReflectLogItem(BaseModel): + """log content for self-learn-reflect""" + step: int = Field(default=0) + prompt: str = Field(default="") + image_before: str = Field(default="") + image_after: str = Field(default="") + response: str = Field(default="") + + +class DocContent(BaseModel): + tap: str = Field(default="") + text: str = Field(default="") + v_swipe: str = Field(default="") + h_swipe: str = Field(default="") + long_press: str = Field(default="") + + diff --git a/examples/andriod_assistant/utils/utils.py b/examples/andriod_assistant/utils/utils.py new file mode 100644 index 000000000..7254e49c8 --- /dev/null +++ b/examples/andriod_assistant/utils/utils.py @@ -0,0 +1,159 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : + +from pydantic import Field, BaseModel +from xml.etree.ElementTree import Element, iterparse +import cv2 +from pathlib import Path +import pyshine as ps +import base64 + +from metagpt.config2 import config +from metagpt.logs import logger + +from examples.andriod_assistant.utils.schema import AndroidElement + + +def get_id_from_element(elem: Element) -> str: + bounds = elem.attrib["bounds"][1:-1].split("][") + x1, y1 = map(int, bounds[0].split(",")) + x2, y2 = map(int, bounds[1].split(",")) + elem_w, elem_h = x2 - x1, y2 - y1 + if "resource-id" in elem.attrib and elem.attrib["resource-id"]: + elem_id = elem.attrib["resource-id"].replace(":", ".").replace("/", "_") + else: + elem_id = f"{elem.attrib['class']}_{elem_w}_{elem_h}" + if "content-desc" in elem.attrib and elem.attrib["content-desc"] and len(elem.attrib["content-desc"]) < 20: + content_desc = elem.attrib["content-desc"].replace("/", "_").replace(" ", "").replace(":", "_") + elem_id += f"_{content_desc}" + return elem_id + + +def traverse_xml_tree(xml_path: Path, elem_list: list[AndroidElement], attrib: str, add_index=False): + path = [] + for event, elem in iterparse(str(xml_path), ["start", "end"]): + if event == "start": + path.append(elem) + if attrib in elem.attrib and elem.attrib[attrib] == "true": + parent_prefix = "" + if len(path) > 1: + parent_prefix = get_id_from_element(path[-2]) + bounds = elem.attrib["bounds"][1:-1].split("][") + x1, y1 = map(int, bounds[0].split(",")) + x2, y2 = map(int, bounds[1].split(",")) + center = (x1 + x2) // 2, (y1 + y2) // 2 + elem_id = get_id_from_element(elem) + if parent_prefix: + elem_id = parent_prefix + "_" + elem_id + if add_index: + elem_id += f"_{elem.attrib['index']}" + close = False + for e in elem_list: + bbox = e.bbox + center_ = (bbox[0][0] + bbox[1][0]) // 2, (bbox[0][1] + bbox[1][1]) // 2 + dist = (abs(center[0] - center_[0]) ** 2 + abs(center[1] - center_[1]) ** 2) ** 0.5 + if dist <= config.get_other("min_dist"): + close = True + break + if not close: + elem_list.append(AndroidElement(uid=elem_id, bbox=((x1, y1), (x2, y2)), attrib=attrib)) + + if event == "end": + path.pop() + + +def draw_bbox_multi(img_path: Path, output_path: Path, elem_list: list[AndroidElement], record_mode: bool = False, + dark_mode: bool = False): + imgcv = cv2.imread(img_path) + count = 1 + for elem in elem_list: + try: + top_left = elem.bbox[0] + bottom_right = elem.bbox[1] + left, top = top_left[0], top_left[1] + right, bottom = bottom_right[0], bottom_right[1] + label = str(count) + if record_mode: + if elem.attrib == "clickable": + color = (250, 0, 0) + elif elem.attrib == "focusable": + color = (0, 0, 250) + else: + color = (0, 250, 0) + imgcv = ps.putBText(imgcv, label, text_offset_x=(left + right) // 2 + 10, + text_offset_y=(top + bottom) // 2 + 10, + vspace=10, hspace=10, font_scale=1, thickness=2, background_RGB=color, + text_RGB=(255, 250, 250), alpha=0.5) + else: + text_color = (10, 10, 10) if dark_mode else (255, 250, 250) + bg_color = (255, 250, 250) if dark_mode else (10, 10, 10) + imgcv = ps.putBText(imgcv, label, text_offset_x=(left + right) // 2 + 10, + text_offset_y=(top + bottom) // 2 + 10, + vspace=10, hspace=10, font_scale=1, thickness=2, background_RGB=bg_color, + text_RGB=text_color, alpha=0.5) + except Exception as e: + logger.error(f"ERROR: An exception occurs while labeling the image\n{e}") + count += 1 + cv2.imwrite(output_path, imgcv) + return imgcv + + +def draw_grid(img_path: Path, output_path: Path) -> tuple[int, int]: + def get_unit_len(n): + for i in range(1, n + 1): + if n % i == 0 and 120 <= i <= 180: + return i + return -1 + + image = cv2.imread(img_path) + height, width, _ = image.shape + color = (255, 116, 113) + unit_height = get_unit_len(height) + if unit_height < 0: + unit_height = 120 + unit_width = get_unit_len(width) + if unit_width < 0: + unit_width = 120 + thick = int(unit_width // 50) + rows = height // unit_height + cols = width // unit_width + for i in range(rows): + for j in range(cols): + label = i * cols + j + 1 + left = int(j * unit_width) + top = int(i * unit_height) + right = int((j + 1) * unit_width) + bottom = int((i + 1) * unit_height) + cv2.rectangle(image, (left, top), (right, bottom), color, thick // 2) + cv2.putText(image, str(label), (left + int(unit_width * 0.05) + 3, top + int(unit_height * 0.3) + 3), 0, + int(0.01 * unit_width), (0, 0, 0), thick) + cv2.putText(image, str(label), (left + int(unit_width * 0.05), top + int(unit_height * 0.3)), 0, + int(0.01 * unit_width), color, thick) + cv2.imwrite(output_path, image) + return rows, cols + + +def area_to_xy(width: int, height: int, cols: int, rows: int, area: int, subarea: str) -> tuple[int, int]: + area -= 1 + row, col = area // cols, area % cols + x_0, y_0 = col * (width // cols), row * (height // rows) + if subarea == "top-left": + x, y = x_0 + (width // cols) // 4, y_0 + (height // rows) // 4 + elif subarea == "top": + x, y = x_0 + (width // cols) // 2, y_0 + (height // rows) // 4 + elif subarea == "top-right": + x, y = x_0 + (width // cols) * 3 // 4, y_0 + (height // rows) // 4 + elif subarea == "left": + x, y = x_0 + (width // cols) // 4, y_0 + (height // rows) // 2 + elif subarea == "right": + x, y = x_0 + (width // cols) * 3 // 4, y_0 + (height // rows) // 2 + elif subarea == "bottom-left": + x, y = x_0 + (width // cols) // 4, y_0 + (height // rows) * 3 // 4 + elif subarea == "bottom": + x, y = x_0 + (width // cols) // 2, y_0 + (height // rows) * 3 // 4 + elif subarea == "bottom-right": + x, y = x_0 + (width // cols) * 3 // 4, y_0 + (height // rows) * 3 // 4 + else: + x, y = x_0 + (width // cols) // 2, y_0 + (height // rows) // 2 + return x, y