Update AppAgent's self_learn_and_self_reflect's test

This commit is contained in:
Jiayi Zhang 2024-02-22 17:57:25 +08:00
parent a1b0faacf4
commit 13cf80b46a
5 changed files with 77 additions and 56 deletions

View file

@ -61,12 +61,15 @@ class SelfLearnAndReflect(Action):
self, round_count: int, task_desc: str, last_act: str, task_dir: Path, docs_dir: Path, env: AndroidEnv
) -> AndroidActionOutput:
resp = await self.run_self_learn(round_count, task_desc, last_act, task_dir, env)
print(resp)
resp = await self.run_reflect(round_count, task_desc, last_act, task_dir, docs_dir, env)
print(resp)
return resp
async def run_self_learn(
self, round_count: int, task_desc: str, last_act: str, task_dir: Path, env: AndroidEnv
) -> AndroidActionOutput:
logger.info('run_self_learn')
screenshot_path: Path = env.observe(
EnvAPIAbstract(
api_name="get_screenshot", kwargs={"ss_name": f"{round_count}_before", "local_save_dir": task_dir}
@ -80,6 +83,7 @@ class SelfLearnAndReflect(Action):
clickable_list = []
focusable_list = []
# TODO Tuple bug — start debugging from here
# TODO Tuple Bug
traverse_xml_tree(xml_path, clickable_list, "clickable", True)
traverse_xml_tree(xml_path, focusable_list, "focusable", True)
@ -98,7 +102,9 @@ class SelfLearnAndReflect(Action):
bbox = e.bbox
center_ = (bbox[0][0] + bbox[1][0]) // 2, (bbox[0][1] + bbox[1][1]) // 2
dist = (abs(center[0] - center_[0]) ** 2 + abs(center[1] - center_[1]) ** 2) ** 0.5
if dist <= config.get_other("min_dist"):
# TODO min_dist is temporarily hard-coded to 30; restore the config lookup after the single-action test
# if dist <= config.get_other("min_dist"):
if dist <= 30:
close = True
break
if not close:
@ -113,10 +119,12 @@ class SelfLearnAndReflect(Action):
context = self_explore_template.format(task_description=task_desc, last_act=last_act)
node = await SCREENSHOT_PARSE_NODE.fill(context=context, llm=self.llm, images=[img_base64])
print(f"fill result:{node}")
if "error" in node.content:
return AndroidActionOutput(action_state=RunState.FAIL)
prompt = node.compile(context=context, schema="json", mode="auto")
OpLogItem(step=round_count, prompt=prompt, image=screenshot_before_labeled_path, response=node.content)
# Convert WindowsPath to str
OpLogItem(step=round_count, prompt=prompt, image=str(screenshot_before_labeled_path), response=node.content)
op_param = screenshot_parse_extract(node.instruct_content.model_dump(), grid_on=False)
if op_param.param_state == RunState.FINISH:
return AndroidActionOutput(action_state=RunState.FINISH)
@ -126,17 +134,17 @@ class SelfLearnAndReflect(Action):
if isinstance(op_param, TapOp):
self.ui_area = op_param.area
x, y = elem_bbox_to_xy(elem_list[op_param.area - 1].bbox)
res = env.step(EnvAPIAbstract("system_tap", kwargs={"x": x, "y": y}))
res = env.step(EnvAPIAbstract(api_name="system_tap", kwargs={"x": x, "y": y}))
if res == ADB_EXEC_FAIL:
return AndroidActionOutput(action_state=RunState.FAIL)
elif isinstance(op_param, TextOp):
res = env.step(EnvAPIAbstract("user_input", kwargs={"input_txt": op_param.input_str}))
res = env.step(EnvAPIAbstract(api_name="user_input", kwargs={"input_txt": op_param.input_str}))
if res == ADB_EXEC_FAIL:
return AndroidActionOutput(action_state=RunState.FAIL)
elif isinstance(op_param, LongPressOp):
self.ui_area = op_param.area
x, y = elem_bbox_to_xy(elem_list[op_param.area - 1].bbox)
res = env.step(EnvAPIAbstract("user_longpress", kwargs={"x": x, "y": y}))
res = env.step(EnvAPIAbstract(api_name="user_longpress", kwargs={"x": x, "y": y}))
if res == ADB_EXEC_FAIL:
return AndroidActionOutput(action_state=RunState.FAIL)
elif isinstance(op_param, SwipeOp):
@ -158,6 +166,7 @@ class SelfLearnAndReflect(Action):
async def run_reflect(
self, round_count: int, task_desc: str, last_act: str, task_dir: Path, docs_dir: Path, env: AndroidEnv
) -> AndroidActionOutput:
logger.info("run_reflect")
screenshot_path: Path = env.observe(
EnvAPIAbstract(
api_name="get_screenshot", kwargs={"ss_name": f"{round_count}_after", "local_save_dir": task_dir}
@ -170,6 +179,7 @@ class SelfLearnAndReflect(Action):
draw_bbox_multi(screenshot_path, screenshot_after_labeled_path, elem_list=self.elem_list)
img_base64 = encode_image(screenshot_after_labeled_path)
logger.info(f"act_name: {self.act_name}")
if self.act_name == ActionOp.TAP.value:
action = "tapping"
elif self.act_name == ActionOp.LONG_PRESS.value:
@ -194,8 +204,8 @@ class SelfLearnAndReflect(Action):
ReflectLogItem(
step=round_count,
prompt=prompt,
image_before=self.screenshot_before_path,
image_after=screenshot_after_labeled_path,
image_before=str(self.screenshot_before_path),
image_after=str(screenshot_after_labeled_path),
response=node.content,
)
@ -214,7 +224,7 @@ class SelfLearnAndReflect(Action):
self.useless_list.append(resource_id)
last_act = "NONE"
if op_param.decision == Decision.BACK.value:
res = env.step(EnvAPIAbstract("system_back"))
res = env.step(EnvAPIAbstract(api_name="system_back"))
if res == ADB_EXEC_FAIL:
return AndroidActionOutput(action_state=RunState.FAIL)
doc = op_param.documentation

View file

@ -34,7 +34,7 @@ test_manual_parse = ParseRecord()
if __name__ == "__main__":
loop = asyncio.get_event_loop()
test_action_list = [
loop.run_until_complete(
test_self_learning.run(
round_count=20,
task_desc="Create a contact in Contacts App named zjy with a phone number +86 18831933368 ",
@ -42,20 +42,31 @@ if __name__ == "__main__":
task_dir=TASK_PATH,
docs_dir=DOC_PATH,
env=test_env_self_learn_android
),
# test_manual_record.run(
# demo_name=DEMO_NAME,
# task_dir=TASK_PATH,
# env=test_env_manual_learn_android
# ),
# test_manual_parse.run(
# app_name="Contacts",
# demo_name=DEMO_NAME,
# task_dir=TASK_PATH,
# docs_dir=DOC_PATH,
# env=test_env_manual_learn_android
# )
]
loop.run_until_complete(asyncio.gather(*test_action_list))
)
)
# test_action_list = [
# test_self_learning.run(
# round_count=20,
# task_desc="Create a contact in Contacts App named zjy with a phone number +86 18831933368 ",
# last_act="",
# task_dir=TASK_PATH,
# docs_dir=DOC_PATH,
# env=test_env_self_learn_android
# ),
# test_manual_record.run(
# demo_name=DEMO_NAME,
# task_dir=TASK_PATH,
# env=test_env_manual_learn_android
# ),
# test_manual_parse.run(
# app_name="Contacts",
# demo_name=DEMO_NAME,
# task_dir=TASK_PATH,
# docs_dir=DOC_PATH,
# env=test_env_manual_learn_android
# )
# ]
# loop.run_until_complete(asyncio.gather(*test_action_list))
loop.close()
print("Finish")

View file

@ -38,7 +38,7 @@ class Decision(Enum):
class AndroidElement(BaseModel):
"""UI Element"""
uid: str = Field(default="")
bbox: tuple[tuple[int, int]] = Field(default={})
bbox: tuple[tuple[int, int], tuple[int, int]] = Field(default={})
attrib: str = Field(default="")

View file

@ -55,7 +55,9 @@ def traverse_xml_tree(xml_path: Path, elem_list: list[AndroidElement], attrib: s
bbox = e.bbox
center_ = (bbox[0][0] + bbox[1][0]) // 2, (bbox[0][1] + bbox[1][1]) // 2
dist = (abs(center[0] - center_[0]) ** 2 + abs(center[1] - center_[1]) ** 2) ** 0.5
if dist <= config.get_other("min_dist"):
# TODO min_dist is temporarily hard-coded to 30; restore the config lookup after the single-action test
# if dist <= config.get_other("min_dist"):
if dist <= 30:
close = True
break
if not close:
@ -67,7 +69,7 @@ def traverse_xml_tree(xml_path: Path, elem_list: list[AndroidElement], attrib: s
def draw_bbox_multi(img_path: Path, output_path: Path, elem_list: list[AndroidElement], record_mode: bool = False,
dark_mode: bool = False):
imgcv = cv2.imread(img_path)
imgcv = cv2.imread(str(img_path))
count = 1
for elem in elem_list:
try:
@ -97,7 +99,7 @@ def draw_bbox_multi(img_path: Path, output_path: Path, elem_list: list[AndroidEl
except Exception as e:
logger.error(f"ERROR: An exception occurs while labeling the image\n{e}")
count += 1
cv2.imwrite(output_path, imgcv)
cv2.imwrite(str(output_path), imgcv)
return imgcv