From cfc0cc1fa56ac5afb095cbdec41e464562179274 Mon Sep 17 00:00:00 2001 From: didi Date: Wed, 28 Feb 2024 17:00:51 +0800 Subject: [PATCH] Update env & test code --- examples/andriod_assistant/README.md | 3 + .../actions/manual_record.py | 22 +- .../andriod_assistant/actions/parse_record.py | 16 +- .../actions/screenshot_parse.py | 22 +- .../actions/self_learn_and_reflect.py | 28 +-- .../roles/android_assistant.py | 55 +++-- examples/andriod_assistant/run_assistant.py | 5 +- examples/andriod_assistant/test_for_an.py | 6 +- metagpt/config2.py | 8 +- metagpt/environment/__init__.py | 3 +- .../android_env/android_ext_env.py | 4 +- metagpt/environment/base_env.py | 216 ++++++++++++------ 12 files changed, 248 insertions(+), 140 deletions(-) diff --git a/examples/andriod_assistant/README.md b/examples/andriod_assistant/README.md index 6239bfcc1..48c15be5e 100644 --- a/examples/andriod_assistant/README.md +++ b/examples/andriod_assistant/README.md @@ -16,3 +16,6 @@ ## Free Your Hands ### By Text ### By Voice + +## Run It +python run_assistant.py "Create a contact in Contacts App named zjy with a phone number +86 18831933368" --mode "manual" --app-name "Contacts" \ No newline at end of file diff --git a/examples/andriod_assistant/actions/manual_record.py b/examples/andriod_assistant/actions/manual_record.py index affae143a..abcda3c8e 100644 --- a/examples/andriod_assistant/actions/manual_record.py +++ b/examples/andriod_assistant/actions/manual_record.py @@ -33,7 +33,8 @@ class ManualRecord(Action): screenshot_after_path: Path = "" xml_path: Path = "" - async def run(self, demo_name: str, task_desc: str,task_dir: Path, env: AndroidEnv): + # async def run(self, demo_name: str, task_desc: str,task_dir: Path, env: AndroidEnv): + async def run(self, task_desc: str, task_dir: Path, env: AndroidEnv): self.record_path = Path(task_dir) / "record.txt" self.task_desc_path = Path(task_dir) / "task_desc.txt" @@ -53,16 +54,18 @@ class ManualRecord(Action): step = 0 while True: step += 1 - screenshot_path: Path = env.observe( + screenshot_path: Path = await env.observe( EnvAPIAbstract( api_name="get_screenshot", - kwargs={"ss_name": f"{demo_name}_{step}", "local_save_dir": self.screenshot_before_path} + # kwargs={"ss_name": f"{demo_name}_{step}", "local_save_dir": self.screenshot_before_path} + kwargs={"ss_name": f"{step}", "local_save_dir": self.screenshot_before_path} ) ) - xml_path: Path = env.observe( + xml_path: Path = await env.observe( EnvAPIAbstract( api_name="get_xml", - kwargs={"xml_name": f"{demo_name}_{step}", "local_save_dir": self.xml_path} + # kwargs={"xml_name": f"{demo_name}_{step}", "local_save_dir": self.xml_path} + kwargs={"xml_name": f"{step}", "local_save_dir": self.xml_path} ) ) if not screenshot_path.exists() or not xml_path.exists(): @@ -86,14 +89,13 @@ class ManualRecord(Action): bbox = e.bbox center_ = (bbox[0][0] + bbox[1][0]) // 2, (bbox[0][1] + bbox[1][1]) // 2 dist = (abs(center[0] - center_[0]) ** 2 + abs(center[1] - center_[1]) ** 2) ** 0.5 - # TODO Modify config to default 30. It should be modified back config after single action test - # if dist <= config.get_other("min_dist"): - if dist <= 30: + if dist <= config.get_other("min_dist"): close = True break if not close: elem_list.append(elem) - screenshot_labeled_path = Path(self.screenshot_after_path).joinpath(f"{demo_name}_{step}_labeled.png") + screenshot_labeled_path = Path(self.screenshot_after_path).joinpath(f"{step}_labeled.png") + # screenshot_labeled_path = Path(self.screenshot_after_path).joinpath(f"{demo_name}_{step}_labeled.png") labeled_img = draw_bbox_multi(screenshot_path, screenshot_labeled_path, elem_list) cv2.imshow("image", labeled_img) @@ -142,7 +144,7 @@ class ManualRecord(Action): user_input = "" while not user_input: user_input = input() - env.step(EnvAPIAbstract(api_name="user_input", kwargs={"input_txt": user_input})) + await env.step(EnvAPIAbstract(api_name="user_input", kwargs={"input_txt": user_input})) record_file.write(f'text({input_area}:sep:"{user_input}"):::{elem_list[int(input_area) - 1].uid}\n') elif user_input.lower() == ActionOp.LONG_PRESS.value: logger.info( diff --git a/examples/andriod_assistant/actions/parse_record.py b/examples/andriod_assistant/actions/parse_record.py index 4688f796b..774ae0701 100644 --- a/examples/andriod_assistant/actions/parse_record.py +++ b/examples/andriod_assistant/actions/parse_record.py @@ -38,9 +38,9 @@ class ParseRecord(Action): screenshot_before_path: Path = "" screenshot_after_path: Path = "" - async def run(self, app_name: str, demo_name: str, task_dir: Path, docs_dir: Path, env: AndroidEnv): - if not docs_dir.exists(): - docs_dir.mkdir(parents=True, exist_ok=True) + # async def run(self, app_name: str, demo_name: str, task_dir: Path, docs_dir: Path, env: AndroidEnv): + async def run(self, app_name: str, task_dir: Path, docs_dir: Path, env: AndroidEnv): + docs_dir.mkdir(parents=True, exist_ok=True) doc_count = 0 self.record_path = Path(task_dir) / "record.txt" self.task_desc_path = Path(task_dir) / "task_desc.txt" @@ -51,8 +51,10 @@ class ParseRecord(Action): record_step_count = len(record_file.readlines()) - 1 record_file.seek(0) for step in range(1, record_step_count + 1): - img_before_base64 = encode_image(self.screenshot_after_path.joinpath(f"{demo_name}_{step}_labeled.png")) - img_after_base64 = encode_image(self.screenshot_after_path.joinpath(f"{demo_name}_{step + 1}_labeled.png")) + # img_before_base64 = encode_image(self.screenshot_after_path.joinpath(f"{demo_name}_{step}_labeled.png")) + # img_after_base64 = encode_image(self.screenshot_after_path.joinpath(f"{demo_name}_{step + 1}_labeled.png")) + img_before_base64 = encode_image(self.screenshot_after_path.joinpath(f"{step}_labeled.png")) + img_after_base64 = encode_image(self.screenshot_after_path.joinpath(f"{step + 1}_labeled.png")) rec = record_file.readline().strip() action, resource_id = rec.split(":::") action_type = action.split("(")[0] @@ -110,8 +112,8 @@ class ParseRecord(Action): ) if "error" in node.content: return AndroidActionOutput(action_state=RunState.FAIL) - - log_path = task_dir.joinpath(f"log_{app_name}_{demo_name}.txt") + # log_path = task_dir.joinpath(f"log_{app_name}_{demo_name}.txt") + log_path = task_dir.joinpath(f"log_{app_name}.txt") prompt = node.compile(context=context, schema="json", mode="auto") msg = node.content doc_content[action_type] = msg diff --git a/examples/andriod_assistant/actions/screenshot_parse.py b/examples/andriod_assistant/actions/screenshot_parse.py index 40082bc04..c2bd16863 100644 --- a/examples/andriod_assistant/actions/screenshot_parse.py +++ b/examples/andriod_assistant/actions/screenshot_parse.py @@ -92,13 +92,13 @@ class ScreenshotParse(Action): if not path.exists(): path.mkdir(parents=True, exist_ok=True) - screenshot_path: Path = env.observe( + screenshot_path: Path = await env.observe( EnvAPIAbstract( api_name="get_screenshot", kwargs={"ss_name": f"{round_count}_before", "local_save_dir": task_dir} ) ) - xml_path: Path = env.observe( + xml_path: Path = await env.observe( EnvAPIAbstract( api_name="get_xml", kwargs={"xml_name": f"{round_count}", "local_save_dir": task_dir} @@ -121,9 +121,7 @@ class ScreenshotParse(Action): bbox = e.bbox center_ = (bbox[0][0] + bbox[1][0]) // 2, (bbox[0][1] + bbox[1][1]) // 2 dist = (abs(center[0] - center_[0]) ** 2 + abs(center[1] - center_[1]) ** 2) ** 0.5 - # TODO Modify config to default 30. It should be modified back config after single action test - # if dist <= config.get_other("min_dist"): - if dist <= 30: + if dist <= config.get_other("min_dist"): close = True break if not close: @@ -156,21 +154,21 @@ class ScreenshotParse(Action): if isinstance(op_param, TapOp): x, y = elem_bbox_to_xy(elem_list[op_param.area - 1].bbox) - res = env.step(EnvAPIAbstract(api_name="system_tap", kwargs={"x": x, "y": y})) + res = await env.step(EnvAPIAbstract(api_name="system_tap", kwargs={"x": x, "y": y})) if res == ADB_EXEC_FAIL: return AndroidActionOutput(action_state=RunState.FAIL) elif isinstance(op_param, TextOp): - res = env.step(EnvAPIAbstract(api_name="user_input", kwargs={"input_txt": op_param.input_str})) + res = await env.step(EnvAPIAbstract(api_name="user_input", kwargs={"input_txt": op_param.input_str})) if res == ADB_EXEC_FAIL: return AndroidActionOutput(action_state=RunState.FAIL) elif isinstance(op_param, LongPressOp): x, y = elem_bbox_to_xy(elem_list[op_param.area - 1].bbox) - res = env.step(EnvAPIAbstract(api_name="user_longpress", kwargs={"x": x, "y": y})) + res = await env.step(EnvAPIAbstract(api_name="user_longpress", kwargs={"x": x, "y": y})) if res == ADB_EXEC_FAIL: return AndroidActionOutput(action_state=RunState.FAIL) elif isinstance(op_param, SwipeOp_3): x, y = elem_bbox_to_xy(elem_list[op_param.area - 1].bbox) - res = env.step( + res = await env.step( EnvAPIAbstract( api_name="user_swipe", kwargs={"x": x, "y": y, "orient": op_param.swipe_orient, "dist": op_param.dist} @@ -183,18 +181,18 @@ class ScreenshotParse(Action): elif isinstance(op_param, TapGridOp) or isinstance(op_param, LongPressGridOp): x, y = area_to_xy(op_param.area, op_param.subarea, env.width, env.height, env.rows, env.cols) if isinstance(op_param, TapGridOp): - res = env.step(EnvAPIAbstract(api_name="system_tap", kwargs={"x": x, "y": y})) + res = await env.step(EnvAPIAbstract(api_name="system_tap", kwargs={"x": x, "y": y})) if res == ADB_EXEC_FAIL: return AndroidActionOutput(action_state=RunState.FAIL) else: # LongPressGridOp - res = env.step(EnvAPIAbstract(api_name="user_longpress", kwargs={"x": x, "y": y})) + res = await env.step(EnvAPIAbstract(api_name="user_longpress", kwargs={"x": x, "y": y})) if res == ADB_EXEC_FAIL: return AndroidActionOutput(action_state=RunState.FAIL) elif isinstance(op_param, SwipeGridOp): start_x, start_y = area_to_xy(op_param.start_area, op_param.start_subarea, width, height, rows, cols) end_x, end_y = area_to_xy(op_param.end_area, op_param.end_subarea, width, height, rows, cols) - res = env.step( + res = await env.step( EnvAPIAbstract(api_name="user_swipe_to", kwargs={"start": (start_x, start_y), "end": (end_x, end_y)})) if res == ADB_EXEC_FAIL: return AndroidActionOutput(action_state=RunState.FAIL) diff --git a/examples/andriod_assistant/actions/self_learn_and_reflect.py b/examples/andriod_assistant/actions/self_learn_and_reflect.py index 02193b860..57dea0e79 100644 --- a/examples/andriod_assistant/actions/self_learn_and_reflect.py +++ b/examples/andriod_assistant/actions/self_learn_and_reflect.py @@ -26,6 +26,7 @@ from examples.andriod_assistant.utils.schema import ( ReflectLogItem, RunState, SwipeOp, + SwipeOp_3, TapOp, TextOp, ) @@ -70,12 +71,12 @@ class SelfLearnAndReflect(Action): async def run_self_learn( self, round_count: int, task_desc: str, last_act: str, task_dir: Path, env: AndroidEnv ) -> AndroidActionOutput: - screenshot_path: Path = env.observe( + screenshot_path: Path = await env.observe( EnvAPIAbstract( api_name="get_screenshot", kwargs={"ss_name": f"{round_count}_before", "local_save_dir": task_dir} ) ) - xml_path: Path = env.observe( + xml_path: Path = await env.observe( EnvAPIAbstract(api_name="get_xml", kwargs={"xml_name": f"{round_count}", "local_save_dir": task_dir}) ) if not screenshot_path.exists() or not xml_path.exists(): @@ -100,9 +101,7 @@ class SelfLearnAndReflect(Action): bbox = e.bbox center_ = (bbox[0][0] + bbox[1][0]) // 2, (bbox[0][1] + bbox[1][1]) // 2 dist = (abs(center[0] - center_[0]) ** 2 + abs(center[1] - center_[1]) ** 2) ** 0.5 - # TODO Modify config to default 30. It should be modified back config after single action test - # if dist <= config.get_other("min_dist"): - if dist <= 30: + if dist <= config.get_other("min_dist"): close = True break if not close: @@ -125,7 +124,6 @@ class SelfLearnAndReflect(Action): OpLogItem(step=round_count, prompt=prompt, image=str(screenshot_before_labeled_path), response=node.content) op_param = screenshot_parse_extract(node.instruct_content.model_dump(), grid_on=False) # TODO Modify Op_param. When op_param.action is FINISH, how to solve this ? - logger.info(op_param) if op_param.param_state == RunState.FINISH: return AndroidActionOutput(action_state=RunState.FINISH) if op_param.param_state == RunState.FAIL: @@ -134,26 +132,26 @@ class SelfLearnAndReflect(Action): if isinstance(op_param, TapOp): self.ui_area = op_param.area x, y = elem_bbox_to_xy(elem_list[op_param.area - 1].bbox) - res = env.step(EnvAPIAbstract(api_name="system_tap", kwargs={"x": x, "y": y})) + res = await env.step(EnvAPIAbstract(api_name="system_tap", kwargs={"x": x, "y": y})) if res == ADB_EXEC_FAIL: return AndroidActionOutput(action_state=RunState.FAIL) elif isinstance(op_param, TextOp): - res = env.step(EnvAPIAbstract(api_name="user_input", kwargs={"input_txt": op_param.input_str})) + res = await env.step(EnvAPIAbstract(api_name="user_input", kwargs={"input_txt": op_param.input_str})) if res == ADB_EXEC_FAIL: return AndroidActionOutput(action_state=RunState.FAIL) elif isinstance(op_param, LongPressOp): self.ui_area = op_param.area x, y = elem_bbox_to_xy(elem_list[op_param.area - 1].bbox) - res = env.step(EnvAPIAbstract(api_name="user_longpress", kwargs={"x": x, "y": y})) + res = await env.step(EnvAPIAbstract(api_name="user_longpress", kwargs={"x": x, "y": y})) if res == ADB_EXEC_FAIL: return AndroidActionOutput(action_state=RunState.FAIL) - elif isinstance(op_param, SwipeOp): + elif isinstance(op_param, SwipeOp_3): self.ui_area = op_param.area self.swipe_orient = op_param.swipe_orient x, y = elem_bbox_to_xy(elem_list[op_param.area - 1].bbox) - res = env.step( + res = await env.step( EnvAPIAbstract( - "user_swipe", kwargs={"x": x, "y": y, "orient": op_param.swipe_orient, "dist": op_param.dist} + api_name="user_swipe", kwargs={"x": x, "y": y, "orient": op_param.swipe_orient, "dist": op_param.dist} ) ) if res == ADB_EXEC_FAIL: @@ -167,8 +165,7 @@ class SelfLearnAndReflect(Action): async def run_reflect( self, round_count: int, task_desc: str, last_act: str, task_dir: Path, docs_dir: Path, env: AndroidEnv ) -> AndroidActionOutput: - logger.info("run_reflect") - screenshot_path: Path = env.observe( + screenshot_path: Path = await env.observe( EnvAPIAbstract( api_name="get_screenshot", kwargs={"ss_name": f"{round_count}_after", "local_save_dir": task_dir} ) @@ -180,7 +177,6 @@ class SelfLearnAndReflect(Action): draw_bbox_multi(screenshot_path, screenshot_after_labeled_path, elem_list=self.elem_list) img_base64 = encode_image(screenshot_after_labeled_path) - logger.info(f"act_name: {self.act_name}") if self.act_name == ActionOp.TAP.value: action = "tapping" elif self.act_name == ActionOp.LONG_PRESS.value: @@ -225,7 +221,7 @@ class SelfLearnAndReflect(Action): self.useless_list.append(resource_id) last_act = "NONE" if op_param.decision == Decision.BACK.value: - res = env.step(EnvAPIAbstract(api_name="system_back")) + res = await env.step(EnvAPIAbstract(api_name="system_back")) if res == ADB_EXEC_FAIL: return AndroidActionOutput(action_state=RunState.FAIL) doc = op_param.documentation diff --git a/examples/andriod_assistant/roles/android_assistant.py b/examples/andriod_assistant/roles/android_assistant.py index 38c850f32..cd2d0d807 100644 --- a/examples/andriod_assistant/roles/android_assistant.py +++ b/examples/andriod_assistant/roles/android_assistant.py @@ -37,7 +37,8 @@ class AndroidAssistant(Role): self._watch([UserRequirement]) app_name = config.get_other("app_name", "demo") - data_dir = Path(__file__).parent.joinpath("..", "output") + curr_path = Path(__file__).parent + data_dir = curr_path.joinpath("..", "output") cur_datetime = datetime.fromtimestamp(int(time.time())).strftime("%Y-%m-%d_%H-%M-%S") """Firstly, we decide the state with user config, further, we can do it automatically, like if it's new app, @@ -67,39 +68,57 @@ class AndroidAssistant(Role): self._set_react_mode(RoleReactMode.BY_ORDER) def _check_dir(self): - self.task_dir.mkdir(exist_ok=True) - self.docs_dir.mkdir(exist_ok=True) + self.task_dir.mkdir(parents=True, exist_ok=True) + self.docs_dir.mkdir(parents=True, exist_ok=True) async def react(self) -> Message: self.round_count += 1 - super().react() + result = await super().react() + print(f"react result {result}") + return result async def _act(self) -> Message: + # Question: How to achieve self_learn's loop action ? logger.info(f"{self._setting}: to do {self.rc.todo}({self.rc.todo.name})") todo = self.rc.todo send_to = "" if isinstance(todo, ManualRecord): - resp = await todo.run() + resp = await todo.run( + # demo_name="", + task_dir=self.task_dir, + task_desc=self.task_desc, + env=self.rc.env + ) elif isinstance(todo, ParseRecord): - resp = await todo.run() + resp = await todo.run( + app_name=config.get_other("app_name", "demo"), + task_dir=self.task_dir, + docs_dir=self.docs_dir, + env=self.rc.env + ) elif isinstance(todo, SelfLearnAndReflect): - resp = await todo.run(round_count=self.round_count, - task_desc=self.task_desc, - last_act=self.last_act, - task_dir=self.task_dir, - docs_dir=self.docs_dir, - env=self.rc.env) + resp = await todo.run( + round_count=self.round_count, + task_desc=self.task_desc, + last_act=self.last_act, + task_dir=self.task_dir, + docs_dir=self.docs_dir, + env=self.rc.env + ) if resp.action_state == RunState.SUCCESS: self.last_act = resp.data.get("last_act") send_to = self.name elif isinstance(todo, ScreenshotParse): - resp = await todo.run(round_count=self.round_count, - task_desc=self.task_desc, - last_act=self.last_act, - task_dir=self.task_dir, - grid_on=self.grid_on, - env=self.rc.env) + resp = await todo.run( + round_count=self.round_count, + task_desc=self.task_desc, + last_act=self.last_act, + task_dir=self.task_dir, + docs_dir=self.docs_dir, + grid_on=self.grid_on, + env=self.rc.env + ) if resp.action_state == RunState.SUCCESS: self.grid_on = resp.data.get("grid_on") send_to = self.name diff --git a/examples/andriod_assistant/run_assistant.py b/examples/andriod_assistant/run_assistant.py index ce15d9511..187a8032b 100644 --- a/examples/andriod_assistant/run_assistant.py +++ b/examples/andriod_assistant/run_assistant.py @@ -50,12 +50,13 @@ def startup( ) team = Team(env=AndroidEnv()) - team.hire([AndroidAssistant]) + team.hire([AndroidAssistant()]) team.invest(investment) - company.run_project(idea=task_desc) + team.run_project(idea=task_desc) asyncio.run(team.run(n_round=n_round)) if __name__ == "__main__": app() +# Command python run_assistant.py "Create a contact in Contacts App named zjy with a phone number +86 18831933368" \ No newline at end of file diff --git a/examples/andriod_assistant/test_for_an.py b/examples/andriod_assistant/test_for_an.py index 8f6fb9b91..bccb5f3b3 100644 --- a/examples/andriod_assistant/test_for_an.py +++ b/examples/andriod_assistant/test_for_an.py @@ -1,6 +1,6 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- -# @Desc : test on android emulator +# @Desc : test on android emulator action. After Modify Role Test, this script is discarded. import asyncio import time from pathlib import Path @@ -50,14 +50,14 @@ if __name__ == "__main__": env=test_env_self_learn_android ), test_manual_record.run( - demo_name=DEMO_NAME, + # demo_name=DEMO_NAME, task_dir=TASK_PATH / "demos" / f"manual_record_{DEMO_NAME}", task_desc="Create a contact in Contacts App named zjy with a phone number +86 18831933368 ", env=test_env_manual_learn_android ), test_manual_parse.run( app_name="Contacts", - demo_name=DEMO_NAME, + # demo_name=DEMO_NAME, task_dir=TASK_PATH / "demos" / f"manual_record_{DEMO_NAME}", # 修要修改 docs_dir=PARSE_RECORD_DOC_PATH, # 需要修改 env=test_env_manual_learn_android diff --git a/metagpt/config2.py b/metagpt/config2.py index 9fc94b330..2b9cdc78e 100644 --- a/metagpt/config2.py +++ b/metagpt/config2.py @@ -122,8 +122,12 @@ class Config(CLIParams, YamlModel): def set_other(self, other: dict): self.other = other - def get_other(self, key: str): - return self.other.get(key) + def get_other(self, key: str, default_value: str = None): + if default_value is None: + return self.other.get(key) + else: + return self.other.get(key, default_value) + def get_openai_llm(self) -> Optional[LLMConfig]: """Get OpenAI LLMConfig by name. If no OpenAI, raise Exception""" diff --git a/metagpt/environment/__init__.py b/metagpt/environment/__init__.py index 592164d63..d2df8fd02 100644 --- a/metagpt/environment/__init__.py +++ b/metagpt/environment/__init__.py @@ -1,7 +1,8 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- # @Desc : - +# TODO +from metagpt.environment.base_env import Environment from metagpt.environment.android_env.android_env import AndroidEnv from metagpt.environment.gym_env.gym_env import GymEnv from metagpt.environment.mincraft_env.mincraft_env import MincraftExtEnv diff --git a/metagpt/environment/android_env/android_ext_env.py b/metagpt/environment/android_env/android_ext_env.py index 4219d9cd8..72eae7182 100644 --- a/metagpt/environment/android_env/android_ext_env.py +++ b/metagpt/environment/android_env/android_ext_env.py @@ -9,10 +9,10 @@ from typing import Any, Optional from pydantic import Field from metagpt.const import ADB_EXEC_FAIL -from metagpt.environment.base_env import Env, ExtEnv, mark_as_readable, mark_as_writeable +from metagpt.environment.base_env import Environment, ExtEnv, mark_as_readable, mark_as_writeable -class AndroidExtEnv(Env, ExtEnv): +class AndroidExtEnv(Environment, ExtEnv): device_id: Optional[str] = Field(default=None) screenshot_dir: Optional[Path] = Field(default=None) xml_dir: Optional[Path] = Field(default=None) diff --git a/metagpt/environment/base_env.py b/metagpt/environment/base_env.py index 911f33db9..b39010aa1 100644 --- a/metagpt/environment/base_env.py +++ b/metagpt/environment/base_env.py @@ -2,25 +2,29 @@ # -*- coding: utf-8 -*- # @Desc : base env of executing environment +import asyncio from enum import Enum -from typing import Optional, Union, Any - -from pydantic import BaseModel, ConfigDict, Field - +from typing import TYPE_CHECKING, Any, Dict, Iterable, Optional, Set, Union +from pydantic import BaseModel, ConfigDict, Field, SerializeAsAny, model_validator +from metagpt.context import Context from metagpt.environment.api.env_api import ( EnvAPIAbstract, ReadAPIRegistry, WriteAPIRegistry, ) +from metagpt.logs import logger from metagpt.schema import Message -from metagpt.utils.common import get_function_schema, is_coroutine_func +from metagpt.utils.common import get_function_schema, is_coroutine_func, is_send_to + +if TYPE_CHECKING: + from metagpt.roles.role import Role # noqa: F401 class EnvType(Enum): ANDROID = "Android" GYM = "Gym" WEREWOLF = "Werewolf" - MINCRAFT = "Minsraft" + MINCRAFT = "Mincraft" STANFORDTOWN = "StanfordTown" @@ -28,49 +32,25 @@ env_write_api_registry = WriteAPIRegistry() env_read_api_registry = ReadAPIRegistry() -# def mark_as_readable(func): -# """mark functionn as a readable one in ExtEnv, it observes something from ExtEnv""" -# -# def wrapper(self: ExtEnv, *args, **kwargs): -# api_name = func.__name__ -# self.read_api_registry[api_name] = func -# return func(self, *args, **kwargs) -# -# return wrapper -# -# def mark_as_writeable(func): -# """mark functionn as a writeable one in ExtEnv, it does something to ExtEnv""" -# -# def wrapper(self: ExtEnv, *args, **kwargs): -# api_name = func.__name__ -# self.write_api_registry[api_name] = func -# return func(self, *args, **kwargs) -# -# return wrapper - def mark_as_readable(func): - """mark function as a readable one in ExtEnv, it observes something from ExtEnv""" + """mark functionn as a readable one in ExtEnv, it observes something from ExtEnv""" env_read_api_registry[func.__name__] = get_function_schema(func) return func def mark_as_writeable(func): - """mark function as a writeable one in ExtEnv, it does something to ExtEnv""" + """mark functionn as a writeable one in ExtEnv, it does something to ExtEnv""" env_write_api_registry[func.__name__] = get_function_schema(func) return func - class ExtEnv(BaseModel): """External Env to intergate actual game environment""" write_api_registry: WriteAPIRegistry = Field(default_factory=WriteAPIRegistry, exclude=True) read_api_registry: ReadAPIRegistry = Field(default_factory=ReadAPIRegistry, exclude=True) - -class Env(ExtEnv): - """Env to intergate with MetaGPT""" - - model_config = ConfigDict(arbitrary_types_allowed=True) +class ExtEnv(BaseModel): + """External Env to intergate actual game environment""" def _check_api_exist(self, rw_api: Optional[str] = None): if not rw_api: @@ -84,45 +64,25 @@ class Env(ExtEnv): else: return env_write_api_registry.get_apis() - # TODO adds is_coroutine_func - # def observe(self, env_action: Union[str, EnvAPIAbstract]): - # if isinstance(env_action, str): - # read_api = env_write_api_registry.get(api_name=env_action) - # self._check_api_exist(read_api) - # res = read_api(self) - # elif isinstance(env_action, EnvAPIAbstract): - # read_api = env_write_api_registry.get(api_name=env_action.api_name) - # self._check_api_exist(read_api) - # res = read_api(self, *env_action.args, **env_action.kwargs) - # return res - # - # def step(self, env_action: Union[str, Message, EnvAPIAbstract, list[EnvAPIAbstract]]): - # res = None - # if isinstance(env_action, Message): - # self.publish_message(env_action) - # elif isinstance(env_action, EnvAPIAbstract): - # print(f"CURRENT API NAME: {env_action.api_name}") - # write_api = self.write_api_registry.get(env_action.api_name) - # self._check_api_exist(write_api) - # res = write_api(self, *env_action.args, **env_action.kwargs) - # - # return res - - def observe(self, env_action: Union[str, EnvAPIAbstract]): - # TODO Adds is_coroutine_func + async def observe(self, env_action: Union[str, EnvAPIAbstract]): """get observation from particular api of ExtEnv""" if isinstance(env_action, str): read_api = env_read_api_registry.get(api_name=env_action)["func"] self._check_api_exist(read_api) - res = read_api(self) + if is_coroutine_func(read_api): + res = await read_api(self) + else: + res = read_api(self) elif isinstance(env_action, EnvAPIAbstract): read_api = env_read_api_registry.get(api_name=env_action.api_name)["func"] self._check_api_exist(read_api) - res = read_api(self, *env_action.args, **env_action.kwargs) - + if is_coroutine_func(read_api): + res = await read_api(self, *env_action.args, **env_action.kwargs) + else: + res = read_api(self, *env_action.args, **env_action.kwargs) return res - def step(self, env_action: Union[str, Message, EnvAPIAbstract, list[EnvAPIAbstract]]): + async def step(self, env_action: Union[str, Message, EnvAPIAbstract, list[EnvAPIAbstract]]): """execute through particular api of ExtEnv""" res = None if isinstance(env_action, Message): @@ -130,9 +90,131 @@ class Env(ExtEnv): elif isinstance(env_action, EnvAPIAbstract): write_api = env_write_api_registry.get(env_action.api_name)["func"] self._check_api_exist(write_api) - res = write_api(self, *env_action.args, **env_action.kwargs) + if is_coroutine_func(write_api): + res = await write_api(self, *env_action.args, **env_action.kwargs) + else: + res = write_api(self, *env_action.args, **env_action.kwargs) return res - def publish_message(self, message: "Message"): - pass +class Environment(ExtEnv): + """环境,承载一批角色,角色可以向环境发布消息,可以被其他角色观察到 + Environment, hosting a batch of roles, roles can publish messages to the environment, and can be observed by other roles + """ + + model_config = ConfigDict(arbitrary_types_allowed=True) + + desc: str = Field(default="") # 环境描述 + roles: dict[str, SerializeAsAny["Role"]] = Field(default_factory=dict, validate_default=True) + member_addrs: Dict["Role", Set] = Field(default_factory=dict, exclude=True) + history: str = "" # For debug + context: Context = Field(default_factory=Context, exclude=True) + + @model_validator(mode="after") + def init_roles(self): + self.add_roles(self.roles.values()) + return self + + def add_role(self, role: "Role"): + """增加一个在当前环境的角色 + Add a role in the current environment + """ + self.roles[role.profile] = role + role.set_env(self) + role.context = self.context + + def add_roles(self, roles: Iterable["Role"]): + """增加一批在当前环境的角色 + Add a batch of characters in the current environment + """ + for role in roles: + self.roles[role.profile] = role + + for role in roles: # setup system message with roles + role.set_env(self) + role.context = self.context + + def publish_message(self, message: Message, peekable: bool = True) -> bool: + """ + Distribute the message to the recipients. + In accordance with the Message routing structure design in Chapter 2.2.1 of RFC 116, as already planned + in RFC 113 for the entire system, the routing information in the Message is only responsible for + specifying the message recipient, without concern for where the message recipient is located. How to + route the message to the message recipient is a problem addressed by the transport framework designed + in RFC 113. + """ + logger.debug(f"publish_message: {message.dump()}") + found = False + # According to the routing feature plan in Chapter 2.2.3.2 of RFC 113 + for role, addrs in self.member_addrs.items(): + if is_send_to(message, addrs): + role.put_message(message) + found = True + if not found: + logger.warning(f"Message no recipients: {message.dump()}") + self.history += f"\n{message}" # For debug + + return True + + async def run(self, k=1): + """处理一次所有信息的运行 + Process all Role runs at once + """ + for _ in range(k): + futures = [] + for role in self.roles.values(): + future = role.run() + futures.append(future) + + await asyncio.gather(*futures) + logger.debug(f"is idle: {self.is_idle}") + + def get_roles(self) -> dict[str, "Role"]: + """获得环境内的所有角色 + Process all Role runs at once + """ + return self.roles + + def get_role(self, name: str) -> "Role": + """获得环境内的指定角色 + get all the environment roles + """ + return self.roles.get(name, None) + + def role_names(self) -> list[str]: + return [i.name for i in self.roles.values()] + + @property + def is_idle(self): + """If true, all actions have been executed.""" + for r in self.roles.values(): + if not r.is_idle: + return False + return True + + def get_addresses(self, obj): + """Get the addresses of the object.""" + return self.member_addrs.get(obj, {}) + + def set_addresses(self, obj, addresses): + """Set the addresses of the object""" + self.member_addrs[obj] = addresses + + def archive(self, auto_archive=True): + if auto_archive and self.context.git_repo: + self.context.git_repo.archive() + + @classmethod + def model_rebuild(cls, **kwargs): + from metagpt.roles.role import Role # noqa: F401 + + super().model_rebuild(**kwargs) + + +Environment.model_rebuild() + + + + + +