Update test for action node & Modify extenv (self reflection)

This commit is contained in:
Jiayi Zhang 2024-02-05 22:17:43 +08:00
parent 32211ff5f2
commit a1b0faacf4
5 changed files with 159 additions and 33 deletions

View file

@ -1,4 +1,4 @@
#!/usr/bin/env python
# !/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc : LIKE scripts/self_explorer.py in stage=learn & mode=auto self_explore_task stage
@ -58,21 +58,21 @@ class SelfLearnAndReflect(Action):
ui_area: int = -1
async def run(
self, round_count: int, task_desc: str, last_act: str, task_dir: Path, docs_dir: Path, env: AndroidEnv
self, round_count: int, task_desc: str, last_act: str, task_dir: Path, docs_dir: Path, env: AndroidEnv
) -> AndroidActionOutput:
resp = self.run_self_learn(round_count, task_desc, last_act, task_dir, env)
resp = self.run_reflect(round_count, task_desc, last_act, task_dir, docs_dir, env)
resp = await self.run_self_learn(round_count, task_desc, last_act, task_dir, env)
resp = await self.run_reflect(round_count, task_desc, last_act, task_dir, docs_dir, env)
return resp
async def run_self_learn(
self, round_count: int, task_desc: str, last_act: str, task_dir: Path, env: AndroidEnv
self, round_count: int, task_desc: str, last_act: str, task_dir: Path, env: AndroidEnv
) -> AndroidActionOutput:
screenshot_path: Path = env.step(
screenshot_path: Path = env.observe(
EnvAPIAbstract(
api_name="get_screenshot", kwargs={"ss_name": f"{round_count}_before", "local_save_dir": task_dir}
)
)
xml_path: Path = env.step(
xml_path: Path = env.observe(
EnvAPIAbstract(api_name="get_xml", kwargs={"xml_name": f"{round_count}", "local_save_dir": task_dir})
)
if not screenshot_path.exists() or not xml_path.exists():
@ -80,6 +80,7 @@ class SelfLearnAndReflect(Action):
clickable_list = []
focusable_list = []
# TODO Tuple Bug
traverse_xml_tree(xml_path, clickable_list, "clickable", True)
traverse_xml_tree(xml_path, focusable_list, "focusable", True)
elem_list = []
@ -155,9 +156,9 @@ class SelfLearnAndReflect(Action):
return AndroidActionOutput()
async def run_reflect(
self, round_count: int, task_desc: str, last_act: str, task_dir: Path, docs_dir: Path, env: AndroidEnv
self, round_count: int, task_desc: str, last_act: str, task_dir: Path, docs_dir: Path, env: AndroidEnv
) -> AndroidActionOutput:
screenshot_path: Path = env.step(
screenshot_path: Path = env.observe(
EnvAPIAbstract(
api_name="get_screenshot", kwargs={"ss_name": f"{round_count}_after", "local_save_dir": task_dir}
)

View file

@ -0,0 +1,61 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc : test on android emulator
import asyncio
import time
from pathlib import Path
from actions.manual_record import ManualRecord
from actions.parse_record import ParseRecord
from actions.self_learn_and_reflect import SelfLearnAndReflect
from metagpt.environment.android_env.android_env import AndroidEnv
# Local working directory for the Contacts task and the docs folder beneath it.
TASK_PATH = Path("apps/Contacts")
DOC_PATH = TASK_PATH / "docs"
# Demo name derived from the current timestamp, so each run gets a distinct name.
DEMO_NAME = str(time.time())

# Both environments below are built with the same emulator settings.
_EMULATOR_ENV_KWARGS = {
    "device_id": "emulator-5554",
    "xml_dir": Path("/sdcard"),
    "screenshot_dir": Path("/sdcard/Pictures/Screenshots"),
}

# TODO Test for Self Learning
test_env_self_learn_android = AndroidEnv(**_EMULATOR_ENV_KWARGS)
test_self_learning = SelfLearnAndReflect()

# TODO Test for Manual Learning
test_env_manual_learn_android = AndroidEnv(**_EMULATOR_ENV_KWARGS)
test_manual_record = ManualRecord()
test_manual_parse = ParseRecord()

# Goals of this script:
# - the behavior works on the emulator (virtual machine)
# - the results of the different Action Nodes match expectations (Action Node)
async def main() -> None:
    """Run the enabled action tests concurrently.

    The coroutines are created inside the running event loop and awaited
    together with ``asyncio.gather``. The manual-record and parse-record
    calls are kept commented out; uncomment them to include those actions.
    """
    test_action_list = [
        test_self_learning.run(
            round_count=20,
            task_desc="Create a contact in Contacts App named zjy with a phone number +86 18831933368 ",
            last_act="",
            task_dir=TASK_PATH,
            docs_dir=DOC_PATH,
            env=test_env_self_learn_android,
        ),
        # test_manual_record.run(
        #     demo_name=DEMO_NAME,
        #     task_dir=TASK_PATH,
        #     env=test_env_manual_learn_android
        # ),
        # test_manual_parse.run(
        #     app_name="Contacts",
        #     demo_name=DEMO_NAME,
        #     task_dir=TASK_PATH,
        #     docs_dir=DOC_PATH,
        #     env=test_env_manual_learn_android
        # )
    ]
    await asyncio.gather(*test_action_list)


if __name__ == "__main__":
    # asyncio.run() creates a fresh event loop, runs main() to completion and
    # closes the loop afterwards. Calling asyncio.get_event_loop() without a
    # running loop is deprecated, and the manual close() was skipped whenever
    # an action raised.
    asyncio.run(main())
    print("Finish")

View file

@ -9,10 +9,10 @@ from typing import Any, Optional
from pydantic import Field
from metagpt.const import ADB_EXEC_FAIL
from metagpt.environment.base_env import ExtEnv, mark_as_readable, mark_as_writeable
from metagpt.environment.base_env import Env, ExtEnv, mark_as_readable, mark_as_writeable
class AndroidExtEnv(ExtEnv):
class AndroidExtEnv(Env, ExtEnv):
device_id: Optional[str] = Field(default=None)
screenshot_dir: Optional[Path] = Field(default=None)
xml_dir: Optional[Path] = Field(default=None)
@ -42,6 +42,7 @@ class AndroidExtEnv(ExtEnv):
return f"adb -s {self.device_id} "
def execute_adb_with_cmd(self, adb_cmd: str) -> str:
adb_cmd = adb_cmd.replace('\\', '/')
res = subprocess.run(adb_cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
exec_res = ADB_EXEC_FAIL
if not res.returncode:

View file

@ -3,7 +3,7 @@
# @Desc : base env of executing environment
from enum import Enum
from typing import Optional, Union
from typing import Optional, Union, Any
from pydantic import BaseModel, ConfigDict, Field
@ -13,6 +13,7 @@ from metagpt.environment.api.env_api import (
WriteAPIRegistry,
)
from metagpt.schema import Message
from metagpt.utils.common import get_function_schema, is_coroutine_func
class EnvType(Enum):
@ -23,26 +24,40 @@ class EnvType(Enum):
STANFORDTOWN = "StanfordTown"
env_write_api_registry = WriteAPIRegistry()
env_read_api_registry = ReadAPIRegistry()
# def mark_as_readable(func):
# """mark function as a readable one in ExtEnv, it observes something from ExtEnv"""
#
# def wrapper(self: ExtEnv, *args, **kwargs):
# api_name = func.__name__
# self.read_api_registry[api_name] = func
# return func(self, *args, **kwargs)
#
# return wrapper
#
# def mark_as_writeable(func):
# """mark function as a writeable one in ExtEnv, it does something to ExtEnv"""
#
# def wrapper(self: ExtEnv, *args, **kwargs):
# api_name = func.__name__
# self.write_api_registry[api_name] = func
# return func(self, *args, **kwargs)
#
# return wrapper
def mark_as_readable(func):
"""mark functionn as a readable one in ExtEnv, it observes something from ExtEnv"""
def wrapper(self: ExtEnv, *args, **kwargs):
api_name = func.__name__
self.read_api_registry[api_name] = func
return func(self, *args, **kwargs)
return wrapper
"""mark function as a readable one in ExtEnv, it observes something from ExtEnv"""
env_read_api_registry[func.__name__] = get_function_schema(func)
return func
def mark_as_writeable(func):
"""mark functionn as a writeable one in ExtEnv, it does something to ExtEnv"""
def wrapper(self: ExtEnv, *args, **kwargs):
api_name = func.__name__
self.write_api_registry[api_name] = func
return func(self, *args, **kwargs)
return wrapper
"""mark function as a writeable one in ExtEnv, it does something to ExtEnv"""
env_write_api_registry[func.__name__] = get_function_schema(func)
return func
class ExtEnv(BaseModel):
@ -61,23 +76,59 @@ class Env(ExtEnv):
if not rw_api:
raise ValueError(f"{rw_api} not exists")
def get_all_available_apis(self, mode: str = "read") -> list[Any]:
    """Return the API definitions registered for the given access mode.

    Args:
        mode: Either "read" or "write"; selects which module-level registry
            is queried.

    Returns:
        Whatever ``get_apis()`` of the selected registry returns.

    Raises:
        AssertionError: If ``mode`` is neither "read" nor "write".
    """
    assert mode in ["read", "write"]
    registry = env_read_api_registry if mode == "read" else env_write_api_registry
    return registry.get_apis()
# TODO: add an is_coroutine_func check
# def observe(self, env_action: Union[str, EnvAPIAbstract]):
# if isinstance(env_action, str):
# read_api = env_write_api_registry.get(api_name=env_action)
# self._check_api_exist(read_api)
# res = read_api(self)
# elif isinstance(env_action, EnvAPIAbstract):
# read_api = env_write_api_registry.get(api_name=env_action.api_name)
# self._check_api_exist(read_api)
# res = read_api(self, *env_action.args, **env_action.kwargs)
# return res
#
# def step(self, env_action: Union[str, Message, EnvAPIAbstract, list[EnvAPIAbstract]]):
# res = None
# if isinstance(env_action, Message):
# self.publish_message(env_action)
# elif isinstance(env_action, EnvAPIAbstract):
# print(f"CURRENT API NAME: {env_action.api_name}")
# write_api = self.write_api_registry.get(env_action.api_name)
# self._check_api_exist(write_api)
# res = write_api(self, *env_action.args, **env_action.kwargs)
#
# return res
def observe(self, env_action: Union[str, EnvAPIAbstract]):
# TODO Adds is_coroutine_func
"""get observation from particular api of ExtEnv"""
if isinstance(env_action, str):
read_api = self.read_api_registry.get(api_name=env_action)
read_api = env_read_api_registry.get(api_name=env_action)["func"]
self._check_api_exist(read_api)
res = read_api(self)
elif isinstance(env_action, EnvAPIAbstract):
read_api = self.read_api_registry.get(api_name=env_action.api_name)
read_api = env_read_api_registry.get(api_name=env_action.api_name)["func"]
self._check_api_exist(read_api)
res = read_api(self, *env_action.args, **env_action.kwargs)
return res
def step(self, env_action: Union[str, Message, EnvAPIAbstract, list[EnvAPIAbstract]]):
"""execute through particular api of ExtEnv"""
res = None
if isinstance(env_action, Message):
self.publish_message(env_action)
elif isinstance(env_action, EnvAPIAbstract):
write_api = self.write_api_registry.get(env_action.api_name)
write_api = env_write_api_registry.get(env_action.api_name)["func"]
self._check_api_exist(write_api)
res = write_api(self, *env_action.args, **env_action.kwargs)

View file

@ -25,7 +25,7 @@ import sys
import traceback
import typing
from pathlib import Path
from typing import Any, List, Tuple, Union
from typing import Any, List, Tuple, Union, Callable
import aiofiles
import loguru
@ -214,7 +214,7 @@ class OutputParser:
if start_index != -1 and end_index != -1:
# Extract the structure part
structure_text = text[start_index : end_index + 1]
structure_text = text[start_index: end_index + 1]
try:
# Attempt to convert the text to a Python data type using ast.literal_eval
@ -337,6 +337,14 @@ def print_members(module, indent=0):
print(f"{prefix}Method: {name}")
def get_function_schema(func: Callable) -> dict[str, Union[dict, Any, str]]:
    """Build a small description of a callable from its signature.

    Args:
        func: The callable to inspect.

    Returns:
        A dict with four keys:
        - "input_params": parameter name -> annotation, in signature order
          (unannotated parameters carry inspect's "empty" marker);
        - "return_type": the return annotation of the signature;
        - "func_desc": the raw ``__doc__`` of the callable (may be None);
        - "func": the callable itself.
    """
    signature = inspect.signature(func)
    annotations_by_param = {}
    for param_name, param in signature.parameters.items():
        annotations_by_param[param_name] = param.annotation
    return {
        "input_params": annotations_by_param,
        "return_type": signature.return_annotation,
        "func_desc": func.__doc__,
        "func": func,
    }
def parse_recipient(text):
# FIXME: use ActionNode instead.
pattern = r"## Send To:\s*([A-Za-z]+)\s*?" # hard code for now
@ -594,6 +602,10 @@ def list_files(root: str | Path) -> List[Path]:
return files
def is_coroutine_func(func: Callable) -> bool:
    """Tell whether ``func`` is a coroutine function, as judged by ``inspect.iscoroutinefunction``."""
    is_coro = inspect.iscoroutinefunction(func)
    return is_coro
def encode_image(image_path: Path, encoding: str = "utf-8") -> str:
    """Read a file in binary mode and return its content base64-encoded as text.

    Args:
        image_path: Location of the file to read; it is converted with
            ``str()`` before being opened.
        encoding: Codec used to decode the base64 bytes into a string.

    Returns:
        The base64 representation of the file content as a ``str``.
    """
    with open(str(image_path), "rb") as image_stream:
        raw_bytes = image_stream.read()
    encoded_bytes = base64.b64encode(raw_bytes)
    return encoded_bytes.decode(encoding)