add andriod_assistant action self-learn / self-learn-reflect / screenshot-parse

This commit is contained in:
better629 2024-01-26 15:18:29 +08:00
parent 295571fafa
commit 8a49292045
12 changed files with 388 additions and 46 deletions

View file

@ -7,6 +7,7 @@ from metagpt.actions.action import Action
class ManualRecord(Action):
"""do a human operation on the screen with human input"""
name: str = "ManualRecord"
async def run(self):

View file

@ -1,10 +1,11 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc : parse record to generate learned standard operations in stage=learn & mode=manual, LIKE scripts/document_generation.py
from metagpt.actions.action import Action
# @Desc : parse record to generate learned standard operations in stage=learn & mode=manual,
# LIKE scripts/document_generation.py
from examples.andriod_assistant.prompts.operation_prompt import *
from metagpt.actions.action import Action
class ParseRecord(Action):
name: str = "ParseRecord"

View file

@ -2,11 +2,63 @@
# -*- coding: utf-8 -*-
# @Desc : LIKE scripts/task_executor.py in stage=act
from pathlib import Path
from examples.andriod_assistant.prompts.assistant_prompt import (
screenshot_parse_template,
screenshot_parse_with_grid_template,
)
from examples.andriod_assistant.utils.utils import draw_bbox_multi, traverse_xml_tree
from metagpt.actions.action import Action
from metagpt.config2 import config
from metagpt.environment.android_env.android_env import AndroidEnv
from metagpt.environment.api.env_api import EnvAPIAbstract
from metagpt.utils.common import encode_image
class ScreenshotParse(Action):
    """Capture the current screen, label its interactive elements and prepare the parse prompt (stage=act)."""

    name: str = "ScreenshotParse"

    async def run(
        self, round_count: int, task_desc: str, last_act: str, task_dir: Path, env: AndroidEnv, grid_on: bool = False
    ):
        """Take a screenshot plus UI XML dump and draw numbered labels on the interactive elements.

        Args:
            round_count: Round index; used to name the screenshot, the XML dump and the labeled image.
            task_desc: Task description substituted into the prompt template.
            last_act: Description of the previous action substituted into the prompt template.
            task_dir: Directory passed to the environment as `local_save_dir`; the labeled image is written here.
            env: Android environment that executes the `get_screenshot` / `get_xml` APIs.
            grid_on: Use the grid variant of the prompt template when true.

        Returns:
            None. The encoded image and the formatted prompt are computed but not consumed yet.
        """
        screenshot_path: Path = env.step(
            EnvAPIAbstract(
                api_name="get_screenshot", kwargs={"ss_name": f"{round_count}_before", "local_save_dir": task_dir}
            )
        )
        xml_path: Path = env.step(
            EnvAPIAbstract(api_name="get_xml", kwargs={"xml_name": f"{round_count}", "local_save_dir": task_dir})
        )
        if not screenshot_path.exists() or not xml_path.exists():
            # TODO exit
            return

        clickable_list = []
        focusable_list = []
        traverse_xml_tree(xml_path, clickable_list, "clickable", True)
        traverse_xml_tree(xml_path, focusable_list, "focusable", True)

        def center_of(bbox):
            return (bbox[0][0] + bbox[1][0]) // 2, (bbox[0][1] + bbox[1][1]) // 2

        # The threshold does not change while iterating, so read it once.
        min_dist = config.get_other("min_dist")

        # Keep every clickable element; add a focusable one only if its center is farther than
        # `min_dist` from the center of every clickable element.
        elem_list = clickable_list.copy()
        for elem in focusable_list:
            center = center_of(elem.bbox)
            close = False
            for e in clickable_list:
                center_ = center_of(e.bbox)
                dist = (abs(center[0] - center_[0]) ** 2 + abs(center[1] - center_[1]) ** 2) ** 0.5
                if dist <= min_dist:
                    close = True
                    break
            if not close:
                elem_list.append(elem)

        # Prefix the file name with the directory *name* only. Interpolating the whole `task_dir`
        # path into the name made `joinpath` resolve outside `task_dir` (absolute path) or into a
        # nested directory that does not exist (relative path).
        labeled_path = task_dir.joinpath(f"{task_dir.name}_{round_count}_labeled.png")
        draw_bbox_multi(screenshot_path, labeled_path, elem_list)
        encode_image(labeled_path)  # TODO: result is not passed on to the LLM yet

        parse_template = screenshot_parse_with_grid_template if grid_on else screenshot_parse_template
        # makeup `ui_doc`
        # TODO
        ui_doc = ""
        parse_template.format(ui_document=ui_doc, task_description=task_desc, last_act=last_act)

View file

@ -4,59 +4,45 @@
from metagpt.actions.action_node import ActionNode
# Output fields requested from the LLM when it parses a labeled screenshot.
# Each statement below is defined exactly once; the interleaved duplicate argument lines and the
# duplicated NODES / NODES_GRID list definitions have been collapsed into a single valid form.
OBSERVATION = ActionNode(
    key="Observation", expected_type=str, instruction="Describe what you observe in the image", example=""
)

THOUGHT = ActionNode(
    key="Thought",
    expected_type=str,
    instruction="To complete the given task, what is the next step I should do",
    example="",
)

ACTION = ActionNode(
    key="Action",
    expected_type=str,
    instruction="The function call with the correct parameters to proceed with the task. If you believe the task is "
    "completed or there is nothing to be done, you should output FINISH. You cannot output anything else "
    "except a function call or FINISH in this field.",
    example="",
)

# Summary variant for numerically tagged UI elements.
SUMMARY = ActionNode(
    key="Summary",
    expected_type=str,
    instruction="Summarize your past actions along with your latest action in one or two sentences. Do not include "
    "the numeric tag in your summary",
    example="",
)

# Summary variant for the grid overlay; shares the "Summary" key with SUMMARY.
SUMMARY_GRID = ActionNode(
    key="Summary",
    expected_type=str,
    instruction="Summarize your past actions along with your latest action in one or two sentences. Do not include "
    "the grid area number in your summary",
    example="",
)

NODES = [OBSERVATION, THOUGHT, ACTION, SUMMARY]

NODES_GRID = [OBSERVATION, THOUGHT, ACTION, SUMMARY_GRID]

# Parent nodes that bundle the child fields for the plain and the grid parsing modes.
SCREENSHOT_PARSE_NODE = ActionNode.from_children("ScreenshotParse", NODES)
SCREENSHOT_PARSE_GRID_NODE = ActionNode.from_children("ScreenshotParseGrid", NODES_GRID)

View file

@ -2,14 +2,66 @@
# -*- coding: utf-8 -*-
# @Desc : LIKE scripts/self_explorer.py in stage=learn & mode=auto self_explore_task stage
from metagpt.actions.action import Action
from pathlib import Path
from examples.andriod_assistant.actions.screenshot_parse_an import SCREENSHOT_PARSE_NODE
from examples.andriod_assistant.prompts.assistant_prompt import screenshot_parse_self_explore_template
from examples.andriod_assistant.prompts.assistant_prompt import (
screenshot_parse_self_explore_template,
)
from examples.andriod_assistant.utils.utils import draw_bbox_multi, traverse_xml_tree
from metagpt.actions.action import Action
from metagpt.config2 import config
from metagpt.environment.android_env.android_env import AndroidEnv
from metagpt.environment.api.env_api import EnvAPIAbstract
from metagpt.utils.common import encode_image
class SelfLearn(Action):
    """Label the current screen and ask the LLM for the next exploration step (stage=learn, mode=auto)."""

    name: str = "SelfLearn"

    useless_list: list[str] = []  # store useless elements uid

    async def run(self, round_count: int, task_desc: str, last_act: str, task_dir: Path, env: AndroidEnv):
        """Capture the screen, label the candidate elements and fill SCREENSHOT_PARSE_NODE.

        Args:
            round_count: Round index; used to name the screenshot, the XML dump and the labeled image.
            task_desc: Task description substituted into the self-explore prompt template.
            last_act: Description of the previous action substituted into the prompt template.
            task_dir: Directory passed to the environment as `local_save_dir`; the labeled image is written here.
            env: Android environment that executes the `get_screenshot` / `get_xml` APIs.

        Returns:
            None. The filled node is not returned to the caller yet.
        """
        screenshot_path: Path = env.step(
            EnvAPIAbstract(
                api_name="get_screenshot", kwargs={"ss_name": f"{round_count}_before", "local_save_dir": task_dir}
            )
        )
        xml_path: Path = env.step(
            EnvAPIAbstract(api_name="get_xml", kwargs={"xml_name": f"{round_count}", "local_save_dir": task_dir})
        )
        if not screenshot_path.exists() or not xml_path.exists():
            # TODO exit
            return

        clickable_list = []
        focusable_list = []
        traverse_xml_tree(xml_path, clickable_list, "clickable", True)
        traverse_xml_tree(xml_path, focusable_list, "focusable", True)

        def center_of(bbox):
            return (bbox[0][0] + bbox[1][0]) // 2, (bbox[0][1] + bbox[1][1]) // 2

        # The threshold does not change while iterating, so read it once.
        min_dist = config.get_other("min_dist")

        # Clickable elements are kept unless their uid was recorded as useless.
        elem_list = [elem for elem in clickable_list if elem.uid not in self.useless_list]

        # A focusable element is added only if it is not useless and its center is farther than
        # `min_dist` from every clickable element (including clickable ones filtered out above).
        for elem in focusable_list:
            if elem.uid in self.useless_list:
                continue
            center = center_of(elem.bbox)
            close = False
            for e in clickable_list:
                center_ = center_of(e.bbox)
                dist = (abs(center[0] - center_[0]) ** 2 + abs(center[1] - center_[1]) ** 2) ** 0.5
                if dist <= min_dist:
                    close = True
                    break
            if not close:
                elem_list.append(elem)

        labeled_path = task_dir.joinpath(f"{round_count}_before_labeled.png")
        draw_bbox_multi(screenshot_path, labeled_path, elem_list)
        encode_image(labeled_path)  # TODO: result is not passed on to the LLM yet

        self_explore_template = screenshot_parse_self_explore_template
        context = self_explore_template.format(task_description=task_desc, last_act=last_act)
        await SCREENSHOT_PARSE_NODE.fill(context=context, llm=self.llm)

View file

@ -2,13 +2,60 @@
# -*- coding: utf-8 -*-
# @Desc : LIKE scripts/self_explorer.py self_explore_reflect stage
from metagpt.actions.action import Action
from pathlib import Path
from examples.andriod_assistant.prompts.assistant_prompt import screenshot_parse_self_explore_reflect_template
from examples.andriod_assistant.prompts.assistant_prompt import (
screenshot_parse_self_explore_reflect_template,
)
from examples.andriod_assistant.utils.schema import AndroidElement
from examples.andriod_assistant.utils.utils import draw_bbox_multi
from metagpt.actions.action import Action
from metagpt.environment.android_env.android_env import AndroidEnv
from metagpt.environment.api.env_api import EnvAPIAbstract
from metagpt.utils.common import encode_image
class SelfLearnReflect(Action):
    """Capture the screen after an exploration action and prepare the reflection prompt."""

    name: str = "SelfLearnReflect"

    async def run(
        self,
        round_count: int,
        task_desc: str,
        last_act: str,
        task_dir: Path,
        env: AndroidEnv,
        elem_list: list[AndroidElement],
        act_name: str,
        swipe_dir: str,
        ui_area: int,
    ):
        """Label the post-action screenshot and format the reflection prompt.

        Args:
            round_count: Round index; used to name the screenshot and the labeled image.
            task_desc: Task description substituted into the reflect template.
            last_act: Description of the previous action substituted into the reflect template.
            task_dir: Directory passed to the environment as `local_save_dir`; the labeled image is written here.
            env: Android environment that executes the `get_screenshot` API.
            elem_list: Elements to label on the screenshot.
            act_name: Action that was executed: "text", "tap", "long_press" or "swipe".
            swipe_dir: Swipe direction ("up", "down", "left", "right"); only read when `act_name` is "swipe".
            ui_area: Label of the UI element that was acted on; substituted as `ui_element`.

        Returns:
            None. The encoded image and the formatted prompt are computed but not consumed yet.

        Raises:
            ValueError: If `act_name` is none of the supported actions. Previously this surfaced as
                an UnboundLocalError on `action` after the screenshot had already been taken.
        """
        if act_name == "text":
            # TODO ignore current reflect
            return

        action_phrases = {"tap": "tapping", "long_press": "long pressing", "swipe": "swiping"}
        if act_name not in action_phrases:
            raise ValueError(f"unsupported act_name for reflection: {act_name!r}")

        # Save under the `_after` name: this capture happens after the action, and the `_before`
        # name is already used by the self-learn step for the same round.
        screenshot_path: Path = env.step(
            EnvAPIAbstract(
                api_name="get_screenshot", kwargs={"ss_name": f"{round_count}_after", "local_save_dir": task_dir}
            )
        )
        if not screenshot_path.exists():
            # TODO exit
            return

        labeled_path = task_dir.joinpath(f"{round_count}_after_labeled.png")
        draw_bbox_multi(screenshot_path, labeled_path, elem_list)
        encode_image(labeled_path)  # TODO: result is not passed on to the LLM yet

        reflect_template = screenshot_parse_self_explore_reflect_template
        action = action_phrases[act_name]
        if act_name == "swipe":
            # NOTE(review): for a recognised direction the phrase is replaced by "v_swipe"/"h_swipe",
            # unlike the gerund used for the other actions - confirm this is what the template expects.
            if swipe_dir == "up" or swipe_dir == "down":
                action = "v_swipe"
            elif swipe_dir == "left" or swipe_dir == "right":
                action = "h_swipe"
        reflect_template.format(action=action, ui_element=str(ui_area), task_desc=task_desc, last_act=last_act)

View file

@ -165,4 +165,3 @@ Decision: SUCCESS
Thought: <explain why you think the action successfully moved the task forward>
Documentation: <describe the function of the UI element>
"""

View file

@ -0,0 +1 @@
pyshine==0.0.9

View file

@ -2,15 +2,16 @@
# -*- coding: utf-8 -*-
# @Desc : android assistant to learn from app operations and operate apps
from metagpt.roles.role import Role
from metagpt.config2 import config
from metagpt.actions.add_requirement import UserRequirement
from examples.andriod_assistant.actions.manual_record import ManualRecord
from examples.andriod_assistant.actions.parse_record import ParseRecord
from examples.andriod_assistant.actions.screenshot_parse import ScreenshotParse
from examples.andriod_assistant.actions.self_learn import SelfLearn
from examples.andriod_assistant.actions.self_learn_reflect import SelfLearnReflect
from examples.andriod_assistant.actions.screenshot_parse import ScreenshotParse
from metagpt.actions.add_requirement import UserRequirement
from metagpt.config2 import config
from metagpt.logs import logger
from metagpt.roles.role import Role
from metagpt.schema import Message
class AndroidAssistant(Role):
@ -25,6 +26,9 @@ class AndroidAssistant(Role):
self.set_actions([ManualRecord, ParseRecord, SelfLearn, SelfLearnReflect, ScreenshotParse])
async def _think(self) -> bool:
"""Firstly, we decide the state with user config, further, we can do it automatically, like if it's new app,
run the learn first and then do the act stage or learn it during the action.
"""
if config.get_other("stage") == "learn" and config.get_other("mode") == "manual":
# choose ManualRecord and then run ParseRecord
# Remember, only run each action only one time, no need to run n_round.
@ -37,4 +41,4 @@ class AndroidAssistant(Role):
pass
async def _act(self) -> Message:
pass
logger.info(f"{self._setting}: to do {self.rc.todo}({self.rc.todo.name})")

View file

@ -16,6 +16,7 @@ app = typer.Typer(add_completion=False, pretty_exceptions_show_locals=False)
@app.command("", help="Run a Android Assistant")
def startup(
task_desc: str = typer.Argument(help="the task description you want the android assistant to learn or act"),
n_round: int = typer.Option(default=20, help="The max round to do an app operation task."),
stage: str = typer.Option(default="learn", help="stage: learn / act"),
mode: str = typer.Option(default="auto", help="mode: auto / manual , when state=learn"),
@ -49,7 +50,7 @@ def startup(
team = Team(env=AndroidEnv())
team.hire([AndroidAssistant])
team.invest(investment)
company.run_project(idea="") # no need idea, just a mock
company.run_project(idea=task_desc)
asyncio.run(team.run(n_round=n_round))

View file

@ -0,0 +1,39 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc :
from pydantic import Field, BaseModel
class AndroidElement(BaseModel):
    """UI Element"""

    # Identifier string of the element.
    uid: str = Field(default="")
    # Two corner points ((x1, y1), (x2, y2)). The previous annotation `tuple[tuple[int, int]]`
    # describes a tuple holding a single point, which does not match values of the form
    # ((x1, y1), (x2, y2)); the default is now a tuple of the declared shape instead of `{}`.
    bbox: tuple[tuple[int, int], tuple[int, int]] = Field(default=((0, 0), (0, 0)))
    # Name of the attribute the element was selected by (e.g. "clickable" / "focusable").
    attrib: str = Field(default="")
class OpLogItem(BaseModel):
    """One log record of a self-learn or task-act step."""

    # Step number of the record.
    step: int = 0
    # Prompt text of the step.
    prompt: str = ""
    # Image associated with the step (presumably a path or encoded content - confirm with the writer).
    image: str = ""
    # Response text of the step.
    response: str = ""
class ReflectLogItem(BaseModel):
    """One log record of a self-learn-reflect step."""

    # Step number of the record.
    step: int = 0
    # Prompt text of the step.
    prompt: str = ""
    # Images before / after the action (presumably paths or encoded content - confirm with the writer).
    image_before: str = ""
    image_after: str = ""
    # Response text of the step.
    response: str = ""
class DocContent(BaseModel):
    """Text kept per operation kind; every entry defaults to an empty string."""

    tap: str = ""
    text: str = ""
    v_swipe: str = ""
    h_swipe: str = ""
    long_press: str = ""

View file

@ -0,0 +1,159 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc :
from pydantic import Field, BaseModel
from xml.etree.ElementTree import Element, iterparse
import cv2
from pathlib import Path
import pyshine as ps
import base64
from metagpt.config2 import config
from metagpt.logs import logger
from examples.andriod_assistant.utils.schema import AndroidElement
def get_id_from_element(elem: Element) -> str:
    """Build an identifier string for an XML element of a UI dump.

    The identifier is the sanitized `resource-id` when present and non-empty, otherwise
    `<class>_<width>_<height>` computed from `bounds`. A non-empty `content-desc` shorter
    than 20 characters is appended after an underscore, with "/" and ":" mapped to "_" and
    spaces removed.

    Raises:
        KeyError: If `bounds` is missing, or `class` is missing when it is needed.
    """
    attrs = elem.attrib

    # `bounds` looks like "[x1,y1][x2,y2]"; it is always parsed, even when resource-id is used.
    corners = attrs["bounds"][1:-1].split("][")
    left, top = map(int, corners[0].split(","))
    right, bottom = map(int, corners[1].split(","))

    resource_id = attrs.get("resource-id")
    if resource_id:
        identifier = resource_id.replace(":", ".").replace("/", "_")
    else:
        identifier = f"{attrs['class']}_{right - left}_{bottom - top}"

    description = attrs.get("content-desc")
    if description and len(description) < 20:
        cleaned = description.translate(str.maketrans({"/": "_", " ": None, ":": "_"}))
        identifier += f"_{cleaned}"
    return identifier
def traverse_xml_tree(xml_path: Path, elem_list: list[AndroidElement], attrib: str, add_index=False):
    """Append to `elem_list` the elements of an XML dump whose `attrib` attribute equals "true".

    Args:
        xml_path: Path of the XML file to parse.
        elem_list: Output list, extended in place. Elements already in it also take part in the
            proximity check below.
        attrib: Attribute name to test, e.g. "clickable" or "focusable".
        add_index: Append the element's `index` attribute to its uid when true.

    An element is skipped when its center lies within `config.get_other("min_dist")` of the
    center of any element already in `elem_list`.
    """
    ancestors = []
    for event, node in iterparse(str(xml_path), ["start", "end"]):
        if event == "end":
            ancestors.pop()
        if event != "start":
            continue

        # `ancestors` holds the open elements from the root down to the current node.
        ancestors.append(node)
        if node.attrib.get(attrib) != "true":
            continue

        parent_uid = get_id_from_element(ancestors[-2]) if len(ancestors) > 1 else ""

        corners = node.attrib["bounds"][1:-1].split("][")
        x1, y1 = map(int, corners[0].split(","))
        x2, y2 = map(int, corners[1].split(","))
        cx, cy = (x1 + x2) // 2, (y1 + y2) // 2

        uid = get_id_from_element(node)
        if parent_uid:
            uid = parent_uid + "_" + uid
        if add_index:
            uid += f"_{node.attrib['index']}"

        # Euclidean distance between centers, compared with the configured threshold.
        near_existing = any(
            (
                abs(cx - (known.bbox[0][0] + known.bbox[1][0]) // 2) ** 2
                + abs(cy - (known.bbox[0][1] + known.bbox[1][1]) // 2) ** 2
            )
            ** 0.5
            <= config.get_other("min_dist")
            for known in elem_list
        )
        if not near_existing:
            elem_list.append(AndroidElement(uid=uid, bbox=((x1, y1), (x2, y2)), attrib=attrib))
def draw_bbox_multi(
    img_path: Path,
    output_path: Path,
    elem_list: list[AndroidElement],
    record_mode: bool = False,
    dark_mode: bool = False,
):
    """Draw a numbered text label near the center of every element and save the result.

    Args:
        img_path: Image to read.
        output_path: Where the labeled image is written.
        elem_list: Elements to label; labels are 1-based positions in this list.
        record_mode: When true, the label background depends on `elem.attrib`
            ("clickable", "focusable", anything else) and `dark_mode` is ignored.
        dark_mode: When `record_mode` is false, swaps the text and background colors.

    Returns:
        The image object after labeling, as returned by the drawing calls.
    """
    # Pass plain strings to OpenCV: not every OpenCV release accepts `pathlib.Path` for file names.
    imgcv = cv2.imread(str(img_path))
    for count, elem in enumerate(elem_list, start=1):
        try:
            top_left = elem.bbox[0]
            bottom_right = elem.bbox[1]
            left, top = top_left[0], top_left[1]
            right, bottom = bottom_right[0], bottom_right[1]
            if record_mode:
                text_color = (255, 250, 250)
                if elem.attrib == "clickable":
                    bg_color = (250, 0, 0)
                elif elem.attrib == "focusable":
                    bg_color = (0, 0, 250)
                else:
                    bg_color = (0, 250, 0)
            else:
                text_color = (10, 10, 10) if dark_mode else (255, 250, 250)
                bg_color = (255, 250, 250) if dark_mode else (10, 10, 10)
            imgcv = ps.putBText(
                imgcv,
                str(count),
                text_offset_x=(left + right) // 2 + 10,
                text_offset_y=(top + bottom) // 2 + 10,
                vspace=10,
                hspace=10,
                font_scale=1,
                thickness=2,
                background_RGB=bg_color,
                text_RGB=text_color,
                alpha=0.5,
            )
        except Exception as e:
            # Best effort: a failure on one element is logged and the remaining ones are still
            # labeled with their original numbers.
            logger.error(f"ERROR: An exception occurs while labeling the image\n{e}")
    cv2.imwrite(str(output_path), imgcv)
    return imgcv
def draw_grid(img_path: Path, output_path: Path) -> tuple[int, int]:
    """Overlay a numbered grid on an image and save it.

    The cell height / width is the smallest divisor of the image height / width within
    [120, 180], falling back to 120 when no such divisor exists. Cells are numbered row by
    row starting at 1.

    Args:
        img_path: Image to read.
        output_path: Where the image with the grid is written.

    Returns:
        `(rows, cols)` of the grid that was drawn.
    """

    def get_unit_len(n):
        # Only values in [120, 180] can match, so only that part of the range is scanned.
        for i in range(120, min(n, 180) + 1):
            if n % i == 0:
                return i
        return -1

    # Pass plain strings to OpenCV: not every OpenCV release accepts `pathlib.Path` for file names.
    image = cv2.imread(str(img_path))
    height, width, _ = image.shape
    color = (255, 116, 113)
    unit_height = get_unit_len(height)
    if unit_height < 0:
        unit_height = 120
    unit_width = get_unit_len(width)
    if unit_width < 0:
        unit_width = 120
    thick = int(unit_width // 50)
    rows = height // unit_height
    cols = width // unit_width
    font_scale = int(0.01 * unit_width)
    for i in range(rows):
        for j in range(cols):
            label = i * cols + j + 1
            left = int(j * unit_width)
            top = int(i * unit_height)
            right = int((j + 1) * unit_width)
            bottom = int((i + 1) * unit_height)
            cv2.rectangle(image, (left, top), (right, bottom), color, thick // 2)
            text_x = left + int(unit_width * 0.05)
            text_y = top + int(unit_height * 0.3)
            # Black copy shifted by 3 px is drawn first, then the colored label on top of it.
            cv2.putText(image, str(label), (text_x + 3, text_y + 3), 0, font_scale, (0, 0, 0), thick)
            cv2.putText(image, str(label), (text_x, text_y), 0, font_scale, color, thick)
    cv2.imwrite(str(output_path), image)
    return rows, cols
def area_to_xy(width: int, height: int, cols: int, rows: int, area: int, subarea: str) -> tuple[int, int]:
    """Convert a 1-based grid cell number and a sub-area name into pixel coordinates.

    Args:
        width: Image width in pixels.
        height: Image height in pixels.
        cols: Number of grid columns.
        rows: Number of grid rows.
        area: 1-based cell number, counted row by row.
        subarea: One of "top-left", "top", "top-right", "left", "right", "bottom-left",
            "bottom", "bottom-right"; any other value selects the cell center.

    Returns:
        `(x, y)` of the selected point inside the cell.
    """
    # Horizontal / vertical position inside the cell, expressed in quarters of the cell size.
    quarters = {
        "top-left": (1, 1),
        "top": (2, 1),
        "top-right": (3, 1),
        "left": (1, 2),
        "right": (3, 2),
        "bottom-left": (1, 3),
        "bottom": (2, 3),
        "bottom-right": (3, 3),
    }
    cell_w = width // cols
    cell_h = height // rows

    index = area - 1
    origin_x = (index % cols) * cell_w
    origin_y = (index // cols) * cell_h

    qx, qy = quarters.get(subarea, (2, 2))
    return origin_x + cell_w * qx // 4, origin_y + cell_h * qy // 4