diff --git a/examples/andriod_assistant/actions/self_learn_and_reflect.py b/examples/andriod_assistant/actions/self_learn_and_reflect.py index a943cd846..caba53150 100644 --- a/examples/andriod_assistant/actions/self_learn_and_reflect.py +++ b/examples/andriod_assistant/actions/self_learn_and_reflect.py @@ -1,4 +1,4 @@ -#!/usr/bin/env python +# !/usr/bin/env python # -*- coding: utf-8 -*- # @Desc : LIKE scripts/self_explorer.py in stage=learn & mode=auto self_explore_task stage @@ -58,21 +58,21 @@ class SelfLearnAndReflect(Action): ui_area: int = -1 async def run( - self, round_count: int, task_desc: str, last_act: str, task_dir: Path, docs_dir: Path, env: AndroidEnv + self, round_count: int, task_desc: str, last_act: str, task_dir: Path, docs_dir: Path, env: AndroidEnv ) -> AndroidActionOutput: - resp = self.run_self_learn(round_count, task_desc, last_act, task_dir, env) - resp = self.run_reflect(round_count, task_desc, last_act, task_dir, docs_dir, env) + resp = await self.run_self_learn(round_count, task_desc, last_act, task_dir, env) + resp = await self.run_reflect(round_count, task_desc, last_act, task_dir, docs_dir, env) return resp async def run_self_learn( - self, round_count: int, task_desc: str, last_act: str, task_dir: Path, env: AndroidEnv + self, round_count: int, task_desc: str, last_act: str, task_dir: Path, env: AndroidEnv ) -> AndroidActionOutput: - screenshot_path: Path = env.step( + screenshot_path: Path = env.observe( EnvAPIAbstract( api_name="get_screenshot", kwargs={"ss_name": f"{round_count}_before", "local_save_dir": task_dir} ) ) - xml_path: Path = env.step( + xml_path: Path = env.observe( EnvAPIAbstract(api_name="get_xml", kwargs={"xml_name": f"{round_count}", "local_save_dir": task_dir}) ) if not screenshot_path.exists() or not xml_path.exists(): @@ -80,6 +80,7 @@ class SelfLearnAndReflect(Action): clickable_list = [] focusable_list = [] + # TODO Tuple Bug traverse_xml_tree(xml_path, clickable_list, "clickable", True) traverse_xml_tree(xml_path, focusable_list, "focusable", True) elem_list = [] @@ -155,9 +156,9 @@ class SelfLearnAndReflect(Action): return AndroidActionOutput() async def run_reflect( - self, round_count: int, task_desc: str, last_act: str, task_dir: Path, docs_dir: Path, env: AndroidEnv + self, round_count: int, task_desc: str, last_act: str, task_dir: Path, docs_dir: Path, env: AndroidEnv ) -> AndroidActionOutput: - screenshot_path: Path = env.step( + screenshot_path: Path = env.observe( EnvAPIAbstract( api_name="get_screenshot", kwargs={"ss_name": f"{round_count}_after", "local_save_dir": task_dir} ) diff --git a/examples/andriod_assistant/test_for_an.py b/examples/andriod_assistant/test_for_an.py new file mode 100644 index 000000000..dd3d90b6a --- /dev/null +++ b/examples/andriod_assistant/test_for_an.py @@ -0,0 +1,61 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : test on android emulator +import asyncio +import time +from pathlib import Path +from actions.manual_record import ManualRecord +from actions.parse_record import ParseRecord +from actions.self_learn_and_reflect import SelfLearnAndReflect +from metagpt.environment.android_env.android_env import AndroidEnv + +TASK_PATH = Path("apps/Contacts") +DOC_PATH = TASK_PATH.joinpath("docs") +DEMO_NAME = str(time.time()) +# TODO Test for Self Learning、 +test_env_self_learn_android = AndroidEnv( + device_id="emulator-5554", + xml_dir=Path("/sdcard"), + screenshot_dir=Path("/sdcard/Pictures/Screenshots"), +) +test_self_learning = SelfLearnAndReflect() + +# TODO Test for Manual Learning +test_env_manual_learn_android = AndroidEnv( + device_id="emulator-5554", + xml_dir=Path("/sdcard"), + screenshot_dir=Path("/sdcard/Pictures/Screenshots"), +) +test_manual_record = ManualRecord() +test_manual_parse = ParseRecord() + +# 虚拟机效果实现 +# 不同 Action Node 结果符合预期(Action Node) + +if __name__ == "__main__": + loop = asyncio.get_event_loop() + test_action_list = [ + test_self_learning.run( + round_count=20, + task_desc="Create a contact in Contacts App named zjy with a phone number +86 18831933368 ", + last_act="", + task_dir=TASK_PATH, + docs_dir=DOC_PATH, + env=test_env_self_learn_android + ), + # test_manual_record.run( + # demo_name=DEMO_NAME, + # task_dir=TASK_PATH, + # env=test_env_manual_learn_android + # ), + # test_manual_parse.run( + # app_name="Contacts", + # demo_name=DEMO_NAME, + # task_dir=TASK_PATH, + # docs_dir=DOC_PATH, + # env=test_env_manual_learn_android + # ) + ] + loop.run_until_complete(asyncio.gather(*test_action_list)) + loop.close() + print("Finish") diff --git a/metagpt/environment/android_env/android_ext_env.py b/metagpt/environment/android_env/android_ext_env.py index 7467d394c..4219d9cd8 100644 --- a/metagpt/environment/android_env/android_ext_env.py +++ b/metagpt/environment/android_env/android_ext_env.py @@ -9,10 +9,10 @@ from typing import Any, Optional from pydantic import Field from metagpt.const import ADB_EXEC_FAIL -from metagpt.environment.base_env import ExtEnv, mark_as_readable, mark_as_writeable +from metagpt.environment.base_env import Env, ExtEnv, mark_as_readable, mark_as_writeable -class AndroidExtEnv(ExtEnv): +class AndroidExtEnv(Env, ExtEnv): device_id: Optional[str] = Field(default=None) screenshot_dir: Optional[Path] = Field(default=None) xml_dir: Optional[Path] = Field(default=None) @@ -42,6 +42,7 @@ class AndroidExtEnv(ExtEnv): return f"adb -s {self.device_id} " def execute_adb_with_cmd(self, adb_cmd: str) -> str: + adb_cmd = adb_cmd.replace('\\', '/') res = subprocess.run(adb_cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True) exec_res = ADB_EXEC_FAIL if not res.returncode: diff --git a/metagpt/environment/base_env.py b/metagpt/environment/base_env.py index 48917549e..911f33db9 100644 --- a/metagpt/environment/base_env.py +++ b/metagpt/environment/base_env.py @@ -3,7 +3,7 @@ # @Desc : base env of executing environment from enum import Enum -from typing import Optional, Union +from typing import Optional, Union, Any from pydantic import BaseModel, ConfigDict, Field @@ -13,6 +13,7 @@ from metagpt.environment.api.env_api import ( WriteAPIRegistry, ) from metagpt.schema import Message +from metagpt.utils.common import get_function_schema, is_coroutine_func class EnvType(Enum): @@ -23,26 +24,40 @@ class EnvType(Enum): STANFORDTOWN = "StanfordTown" +env_write_api_registry = WriteAPIRegistry() +env_read_api_registry = ReadAPIRegistry() + + +# def mark_as_readable(func): +# """mark functionn as a readable one in ExtEnv, it observes something from ExtEnv""" +# +# def wrapper(self: ExtEnv, *args, **kwargs): +# api_name = func.__name__ +# self.read_api_registry[api_name] = func +# return func(self, *args, **kwargs) +# +# return wrapper +# +# def mark_as_writeable(func): +# """mark functionn as a writeable one in ExtEnv, it does something to ExtEnv""" +# +# def wrapper(self: ExtEnv, *args, **kwargs): +# api_name = func.__name__ +# self.write_api_registry[api_name] = func +# return func(self, *args, **kwargs) +# +# return wrapper + def mark_as_readable(func): - """mark functionn as a readable one in ExtEnv, it observes something from ExtEnv""" - - def wrapper(self: ExtEnv, *args, **kwargs): - api_name = func.__name__ - self.read_api_registry[api_name] = func - return func(self, *args, **kwargs) - - return wrapper + """mark function as a readable one in ExtEnv, it observes something from ExtEnv""" + env_read_api_registry[func.__name__] = get_function_schema(func) + return func def mark_as_writeable(func): - """mark functionn as a writeable one in ExtEnv, it does something to ExtEnv""" - - def wrapper(self: ExtEnv, *args, **kwargs): - api_name = func.__name__ - self.write_api_registry[api_name] = func - return func(self, *args, **kwargs) - - return wrapper + """mark function as a writeable one in ExtEnv, it does something to ExtEnv""" + env_write_api_registry[func.__name__] = get_function_schema(func) + return func class ExtEnv(BaseModel): @@ -61,23 +76,59 @@ class Env(ExtEnv): if not rw_api: raise ValueError(f"{rw_api} not exists") + def get_all_available_apis(self, mode: str = "read") -> list[Any]: + """get available read/write apis definition""" + assert mode in ["read", "write"] + if mode == "read": + return env_read_api_registry.get_apis() + else: + return env_write_api_registry.get_apis() + + # TODO adds is_coroutine_func + # def observe(self, env_action: Union[str, EnvAPIAbstract]): + # if isinstance(env_action, str): + # read_api = env_write_api_registry.get(api_name=env_action) + # self._check_api_exist(read_api) + # res = read_api(self) + # elif isinstance(env_action, EnvAPIAbstract): + # read_api = env_write_api_registry.get(api_name=env_action.api_name) + # self._check_api_exist(read_api) + # res = read_api(self, *env_action.args, **env_action.kwargs) + # return res + # + # def step(self, env_action: Union[str, Message, EnvAPIAbstract, list[EnvAPIAbstract]]): + # res = None + # if isinstance(env_action, Message): + # self.publish_message(env_action) + # elif isinstance(env_action, EnvAPIAbstract): + # print(f"CURRENT API NAME: {env_action.api_name}") + # write_api = self.write_api_registry.get(env_action.api_name) + # self._check_api_exist(write_api) + # res = write_api(self, *env_action.args, **env_action.kwargs) + # + # return res + def observe(self, env_action: Union[str, EnvAPIAbstract]): + # TODO Adds is_coroutine_func + """get observation from particular api of ExtEnv""" if isinstance(env_action, str): - read_api = self.read_api_registry.get(api_name=env_action) + read_api = env_read_api_registry.get(api_name=env_action)["func"] self._check_api_exist(read_api) res = read_api(self) elif isinstance(env_action, EnvAPIAbstract): - read_api = self.read_api_registry.get(api_name=env_action.api_name) + read_api = env_read_api_registry.get(api_name=env_action.api_name)["func"] self._check_api_exist(read_api) res = read_api(self, *env_action.args, **env_action.kwargs) + return res def step(self, env_action: Union[str, Message, EnvAPIAbstract, list[EnvAPIAbstract]]): + """execute through particular api of ExtEnv""" res = None if isinstance(env_action, Message): self.publish_message(env_action) elif isinstance(env_action, EnvAPIAbstract): - write_api = self.write_api_registry.get(env_action.api_name) + write_api = env_write_api_registry.get(env_action.api_name)["func"] self._check_api_exist(write_api) res = write_api(self, *env_action.args, **env_action.kwargs) diff --git a/metagpt/utils/common.py b/metagpt/utils/common.py index 142b88620..25aeb54e8 100644 --- a/metagpt/utils/common.py +++ b/metagpt/utils/common.py @@ -25,7 +25,7 @@ import sys import traceback import typing from pathlib import Path -from typing import Any, List, Tuple, Union +from typing import Any, List, Tuple, Union, Callable import aiofiles import loguru @@ -214,7 +214,7 @@ class OutputParser: if start_index != -1 and end_index != -1: # Extract the structure part - structure_text = text[start_index : end_index + 1] + structure_text = text[start_index: end_index + 1] try: # Attempt to convert the text to a Python data type using ast.literal_eval @@ -337,6 +337,14 @@ def print_members(module, indent=0): print(f"{prefix}Method: {name}") +def get_function_schema(func: Callable) -> dict[str, Union[dict, Any, str]]: + sig = inspect.signature(func) + parameters = sig.parameters + return_type = sig.return_annotation + param_schema = {name: parameter.annotation for name, parameter in parameters.items()} + return {"input_params": param_schema, "return_type": return_type, "func_desc": func.__doc__, "func": func} + + def parse_recipient(text): # FIXME: use ActionNode instead. pattern = r"## Send To:\s*([A-Za-z]+)\s*?" # hard code for now @@ -594,6 +602,10 @@ def list_files(root: str | Path) -> List[Path]: return files +def is_coroutine_func(func: Callable) -> bool: + return inspect.iscoroutinefunction(func) + + def encode_image(image_path: Path, encoding: str = "utf-8") -> str: with open(str(image_path), "rb") as image_file: return base64.b64encode(image_file.read()).decode(encoding)