diff --git a/examples/andriod_assistant/actions/manual_record.py b/examples/andriod_assistant/actions/manual_record.py index ef9796b55..5deafa680 100644 --- a/examples/andriod_assistant/actions/manual_record.py +++ b/examples/andriod_assistant/actions/manual_record.py @@ -9,9 +9,8 @@ import cv2 from examples.andriod_assistant.utils.schema import ( ActionOp, AndroidActionOutput, - AndroidElement, RunState, - SwipeOp + SwipeOp, ) from examples.andriod_assistant.utils.utils import draw_bbox_multi, traverse_xml_tree from metagpt.actions.action import Action @@ -24,6 +23,7 @@ from metagpt.logs import logger class ManualRecord(Action): """do a human operation on the screen with human input""" + name: str = "ManualRecord" useless_list: list[str] = [] # store useless elements uid @@ -35,19 +35,18 @@ class ManualRecord(Action): # async def run(self, demo_name: str, task_desc: str,task_dir: Path, env: AndroidEnv): async def run(self, task_desc: str, task_dir: Path, env: AndroidEnv): - self.record_path = Path(task_dir) / "record.txt" self.task_desc_path = Path(task_dir) / "task_desc.txt" - self.screenshot_before_path = Path(task_dir)/"raw_screenshots" - self.screenshot_after_path = Path(task_dir)/"labeled_screenshots" - self.xml_path = Path(task_dir)/"xml" + self.screenshot_before_path = Path(task_dir) / "raw_screenshots" + self.screenshot_after_path = Path(task_dir) / "labeled_screenshots" + self.xml_path = Path(task_dir) / "xml" - for path in [self.screenshot_before_path,self.screenshot_after_path, self.xml_path]: + for path in [self.screenshot_before_path, self.screenshot_after_path, self.xml_path]: if not path.exists(): path.mkdir(parents=True, exist_ok=True) - with open(self.record_path, 'w') as file: - file.write('') + with open(self.record_path, "w") as file: + file.write("") record_file = open(self.record_path, "w") with open(self.task_desc_path, "w") as f: f.write(task_desc) @@ -58,14 +57,14 @@ class ManualRecord(Action): EnvAPIAbstract( api_name="get_screenshot", # kwargs={"ss_name": f"{demo_name}_{step}", "local_save_dir": self.screenshot_before_path} - kwargs={"ss_name": f"{step}", "local_save_dir": self.screenshot_before_path} + kwargs={"ss_name": f"{step}", "local_save_dir": self.screenshot_before_path}, ) ) xml_path: Path = await env.observe( EnvAPIAbstract( api_name="get_xml", # kwargs={"xml_name": f"{demo_name}_{step}", "local_save_dir": self.xml_path} - kwargs={"xml_name": f"{step}", "local_save_dir": self.xml_path} + kwargs={"xml_name": f"{step}", "local_save_dir": self.xml_path}, ) ) if not screenshot_path.exists() or not xml_path.exists(): @@ -110,11 +109,11 @@ class ManualRecord(Action): ) while ( - user_input.lower() != ActionOp.TAP.value - and user_input.lower() != ActionOp.TEXT.value - and user_input.lower() != ActionOp.LONG_PRESS.value - and user_input.lower() != ActionOp.SWIPE.value - and user_input.lower() != ActionOp.STOP.value + user_input.lower() != ActionOp.TAP.value + and user_input.lower() != ActionOp.TEXT.value + and user_input.lower() != ActionOp.LONG_PRESS.value + and user_input.lower() != ActionOp.SWIPE.value + and user_input.lower() != ActionOp.STOP.value ): user_input = input() @@ -167,10 +166,10 @@ class ManualRecord(Action): ) user_input = "" while ( - user_input != SwipeOp.UP.value - and user_input != SwipeOp.DOWN.value - and user_input != SwipeOp.LEFT.value - and user_input != SwipeOp.RIGHT.value + user_input != SwipeOp.UP.value + and user_input != SwipeOp.DOWN.value + and user_input != SwipeOp.LEFT.value + and user_input != SwipeOp.RIGHT.value ): user_input = input() swipe_dir = user_input @@ -179,7 +178,9 @@ class ManualRecord(Action): user_input = input() tl, br = elem_list[int(user_input) - 1].bbox x, y = (tl[0] + br[0]) // 2, (tl[1] + br[1]) // 2 - ret = await env.step(EnvAPIAbstract(api_name="user_swipe", kwargs={"x": x, "y": y, "orient": swipe_dir})) + ret = await env.step( + EnvAPIAbstract(api_name="user_swipe", kwargs={"x": x, "y": y, "orient": swipe_dir}) + ) if ret == ADB_EXEC_FAIL: return AndroidActionOutput(action_state=RunState.FAIL) record_file.write(f"swipe({int(user_input)}:sep:{swipe_dir}):::{elem_list[int(user_input) - 1].uid}\n") @@ -190,5 +191,3 @@ class ManualRecord(Action): else: break time.sleep(3) - - diff --git a/examples/andriod_assistant/actions/parse_record.py b/examples/andriod_assistant/actions/parse_record.py index 774ae0701..51759d9cd 100644 --- a/examples/andriod_assistant/actions/parse_record.py +++ b/examples/andriod_assistant/actions/parse_record.py @@ -6,7 +6,6 @@ import ast import json import re -import time from pathlib import Path from examples.andriod_assistant.actions.parse_record_an import RECORD_PARSE_NODE @@ -44,8 +43,8 @@ class ParseRecord(Action): doc_count = 0 self.record_path = Path(task_dir) / "record.txt" self.task_desc_path = Path(task_dir) / "task_desc.txt" - self.screenshot_before_path = Path(task_dir)/"raw_screenshots" - self.screenshot_after_path = Path(task_dir)/"labeled_screenshots" + self.screenshot_before_path = Path(task_dir) / "raw_screenshots" + self.screenshot_after_path = Path(task_dir) / "labeled_screenshots" with open(self.record_path, "r") as record_file: record_step_count = len(record_file.readlines()) - 1 @@ -137,5 +136,6 @@ class ParseRecord(Action): logger.info(f"Documentation generation phase completed. {doc_count} docs generated.") + # TODO -# 1. LOG中记录方式有问题,需要把IMG的部分拿出去丢掉 \ No newline at end of file +# 1. LOG中记录方式有问题,需要把IMG的部分拿出去丢掉 diff --git a/examples/andriod_assistant/actions/screenshot_parse.py b/examples/andriod_assistant/actions/screenshot_parse.py index 38db933ea..f3dd7da6c 100644 --- a/examples/andriod_assistant/actions/screenshot_parse.py +++ b/examples/andriod_assistant/actions/screenshot_parse.py @@ -26,8 +26,8 @@ from examples.andriod_assistant.utils.schema import ( ) from examples.andriod_assistant.utils.utils import ( area_to_xy, - draw_grid, draw_bbox_multi, + draw_grid, elem_bbox_to_xy, screenshot_parse_extract, traverse_xml_tree, @@ -79,14 +79,14 @@ class ScreenshotParse(Action): return ui_doc async def run( - self, - round_count: int, - task_desc: str, - last_act: str, - task_dir: Path, - docs_dir: Path, - grid_on: bool, - env: AndroidEnv, + self, + round_count: int, + task_desc: str, + last_act: str, + task_dir: Path, + docs_dir: Path, + grid_on: bool, + env: AndroidEnv, ): for path in [task_dir, docs_dir]: if not path.exists(): @@ -94,15 +94,11 @@ class ScreenshotParse(Action): screenshot_path: Path = await env.observe( EnvAPIAbstract( - api_name="get_screenshot", - kwargs={"ss_name": f"{round_count}_before", "local_save_dir": task_dir} + api_name="get_screenshot", kwargs={"ss_name": f"{round_count}_before", "local_save_dir": task_dir} ) ) xml_path: Path = await env.observe( - EnvAPIAbstract( - api_name="get_xml", - kwargs={"xml_name": f"{round_count}", "local_save_dir": task_dir} - ) + EnvAPIAbstract(api_name="get_xml", kwargs={"xml_name": f"{round_count}", "local_save_dir": task_dir}) ) width, height = env.device_shape if not screenshot_path.exists() or not xml_path.exists(): @@ -134,7 +130,7 @@ class ScreenshotParse(Action): parse_template = screenshot_parse_with_grid_template if grid_on else screenshot_parse_template if grid_on: - rows, cols = draw_grid(screenshot_path, task_dir / f"{round_count}_grid.png") + env.rows, env.cols = draw_grid(screenshot_path, task_dir / f"{round_count}_grid.png") ui_doc = self._makeup_ui_document(elem_list, docs_dir) context = parse_template.format(ui_document=ui_doc, task_description=task_desc, last_act=last_act) @@ -171,7 +167,7 @@ class ScreenshotParse(Action): res = await env.step( EnvAPIAbstract( api_name="user_swipe", - kwargs={"x": x, "y": y, "orient": op_param.swipe_orient, "dist": op_param.dist} + kwargs={"x": x, "y": y, "orient": op_param.swipe_orient, "dist": op_param.dist}, ) ) if res == ADB_EXEC_FAIL: @@ -190,10 +186,15 @@ class ScreenshotParse(Action): if res == ADB_EXEC_FAIL: return AndroidActionOutput(action_state=RunState.FAIL) elif isinstance(op_param, SwipeGridOp): - start_x, start_y = area_to_xy(op_param.start_area, op_param.start_subarea, env.width, env.height, env.rows, env.cols) - end_x, end_y = area_to_xy(op_param.end_area, op_param.end_subarea, env.width, env.height, env.rows, env.cols) + start_x, start_y = area_to_xy( + op_param.start_area, op_param.start_subarea, env.width, env.height, env.rows, env.cols + ) + end_x, end_y = area_to_xy( + op_param.end_area, op_param.end_subarea, env.width, env.height, env.rows, env.cols + ) res = await env.step( - EnvAPIAbstract(api_name="user_swipe_to", kwargs={"start": (start_x, start_y), "end": (end_x, end_y)})) + EnvAPIAbstract(api_name="user_swipe_to", kwargs={"start": (start_x, start_y), "end": (end_x, end_y)}) + ) if res == ADB_EXEC_FAIL: return AndroidActionOutput(action_state=RunState.FAIL) diff --git a/examples/andriod_assistant/actions/self_learn_and_reflect.py b/examples/andriod_assistant/actions/self_learn_and_reflect.py index 57dea0e79..780985947 100644 --- a/examples/andriod_assistant/actions/self_learn_and_reflect.py +++ b/examples/andriod_assistant/actions/self_learn_and_reflect.py @@ -59,17 +59,17 @@ class SelfLearnAndReflect(Action): ui_area: int = -1 async def run( - self, round_count: int, task_desc: str, last_act: str, task_dir: Path, docs_dir: Path, env: AndroidEnv + self, round_count: int, task_desc: str, last_act: str, task_dir: Path, docs_dir: Path, env: AndroidEnv ) -> AndroidActionOutput: - for path in [task_dir,docs_dir]: + for path in [task_dir, docs_dir]: if not path.exists(): - path.mkdir(parents=True,exist_ok=True) + path.mkdir(parents=True, exist_ok=True) resp = await self.run_self_learn(round_count, task_desc, last_act, task_dir, env) resp = await self.run_reflect(round_count, task_desc, last_act, task_dir, docs_dir, env) return resp async def run_self_learn( - self, round_count: int, task_desc: str, last_act: str, task_dir: Path, env: AndroidEnv + self, round_count: int, task_desc: str, last_act: str, task_dir: Path, env: AndroidEnv ) -> AndroidActionOutput: screenshot_path: Path = await env.observe( EnvAPIAbstract( @@ -151,7 +151,8 @@ class SelfLearnAndReflect(Action): x, y = elem_bbox_to_xy(elem_list[op_param.area - 1].bbox) res = await env.step( EnvAPIAbstract( - api_name="user_swipe", kwargs={"x": x, "y": y, "orient": op_param.swipe_orient, "dist": op_param.dist} + api_name="user_swipe", + kwargs={"x": x, "y": y, "orient": op_param.swipe_orient, "dist": op_param.dist}, ) ) if res == ADB_EXEC_FAIL: @@ -159,11 +160,10 @@ class SelfLearnAndReflect(Action): self.elem_list = elem_list self.act_name = op_param.act_name - print("探索阶段结束") return AndroidActionOutput() async def run_reflect( - self, round_count: int, task_desc: str, last_act: str, task_dir: Path, docs_dir: Path, env: AndroidEnv + self, round_count: int, task_desc: str, last_act: str, task_dir: Path, docs_dir: Path, env: AndroidEnv ) -> AndroidActionOutput: screenshot_path: Path = await env.observe( EnvAPIAbstract( @@ -176,7 +176,6 @@ class SelfLearnAndReflect(Action): screenshot_after_labeled_path = task_dir.joinpath(f"{round_count}_after_labeled.png") draw_bbox_multi(screenshot_path, screenshot_after_labeled_path, elem_list=self.elem_list) img_base64 = encode_image(screenshot_after_labeled_path) - if self.act_name == ActionOp.TAP.value: action = "tapping" elif self.act_name == ActionOp.LONG_PRESS.value: @@ -187,6 +186,11 @@ class SelfLearnAndReflect(Action): action = "v_swipe" elif self.swipe_orient == SwipeOp.LEFT.value or self.swipe_orient == SwipeOp.RIGHT.value: action = "h_swipe" + else: + # TODO Test for assignment, This error is eupiped with the next. + logger.info(f"Warning: current action name:{self.act_name}") + logger.info("Warning: act_name parse wrong!") + action = None context = reflect_template.format( action=action, ui_element=str(self.ui_area), task_desc=task_desc, last_act=last_act ) @@ -211,7 +215,8 @@ class SelfLearnAndReflect(Action): return AndroidActionOutput(action_state=RunState.FINISH) if op_param.param_state == RunState.FAIL: return AndroidActionOutput(action_state=RunState.FAIL) - + # TODO 这里经常出现错误 + logger.info(f"Error 高发地区, 长度为{len(self.elem_list)},ui_erea为{self.ui_area}") resource_id = self.elem_list[int(self.ui_area) - 1].uid if op_param.decision == Decision.INEFFECTIVE.value: self.useless_list.append(resource_id) @@ -235,8 +240,7 @@ class SelfLearnAndReflect(Action): doc_content = DocContent() setattr(doc_content, self.act_name, doc) doc_path.write_text(str(doc_content)) - print("反思阶段结束") return AndroidActionOutput(data={"last_act": last_act}) -# TODO 如何处理 FINISH 状态,这一点应该需要与role 联动才能解决 +# TODO 如何处理 FINISH 状态,这一点应该需要与role 联动才能解决 diff --git a/examples/andriod_assistant/roles/android_assistant.py b/examples/andriod_assistant/roles/android_assistant.py index 606d582f7..cf97b5fcd 100644 --- a/examples/andriod_assistant/roles/android_assistant.py +++ b/examples/andriod_assistant/roles/android_assistant.py @@ -2,16 +2,19 @@ # -*- coding: utf-8 -*- # @Desc : android assistant to learn from app operations and operate apps import time -from typing import Optional -from pathlib import Path -from pydantic import Field from datetime import datetime +from pathlib import Path +from typing import Optional + +from pydantic import Field from examples.andriod_assistant.actions.manual_record import ManualRecord from examples.andriod_assistant.actions.parse_record import ParseRecord from examples.andriod_assistant.actions.screenshot_parse import ScreenshotParse -from examples.andriod_assistant.actions.self_learn_and_reflect import SelfLearnAndReflect -from examples.andriod_assistant.utils.schema import RunState, AndroidActionOutput +from examples.andriod_assistant.actions.self_learn_and_reflect import ( + SelfLearnAndReflect, +) +from examples.andriod_assistant.utils.schema import AndroidActionOutput, RunState from metagpt.actions.add_requirement import UserRequirement from metagpt.config2 import config from metagpt.logs import logger @@ -35,7 +38,7 @@ class AndroidAssistant(Role): super().__init__(**data) self._watch([UserRequirement, AndroidActionOutput]) - + self.task_desc = config.get_other("task_desc", "Just explore any app in this phone!") app_name = config.get_other("app_name", "demo") curr_path = Path(__file__).parent data_dir = curr_path.joinpath("..", "output") @@ -49,20 +52,20 @@ class AndroidAssistant(Role): # Remember, only run each action only one time, no need to run n_round. self.set_actions([ManualRecord, ParseRecord]) self.task_dir = data_dir.joinpath(app_name, f"manual_learn_{cur_datetime}") - self.docs_dir = data_dir.joinpath(app_name, f"manual_docs") + self.docs_dir = data_dir.joinpath(app_name, "manual_docs") elif config.get_other("stage") == "learn" and config.get_other("mode") == "auto": # choose SelfLearnAndReflect to run self.set_actions([SelfLearnAndReflect]) self.task_dir = data_dir.joinpath(app_name, f"auto_learn_{cur_datetime}") - self.docs_dir = data_dir.joinpath(app_name, f"auto_docs") + self.docs_dir = data_dir.joinpath(app_name, "auto_docs") elif config.get_other("stage") == "act": # choose ScreenshotParse to run self.set_actions([ScreenshotParse]) self.task_dir = data_dir.joinpath(app_name, f"act_{cur_datetime}") if config.get_other("mode") == "manual": - self.docs_dir = data_dir.joinpath(app_name, f"manual_docs") + self.docs_dir = data_dir.joinpath(app_name, "manual_docs") else: - self.docs_dir = data_dir.joinpath(app_name, f"auto_docs") + self.docs_dir = data_dir.joinpath(app_name, "auto_docs") self._check_dir() self._set_react_mode(RoleReactMode.BY_ORDER) @@ -80,20 +83,14 @@ class AndroidAssistant(Role): async def _act(self) -> Message: logger.info(f"{self._setting}: to do {self.rc.todo}({self.rc.todo.name})") todo = self.rc.todo - # TODO 这里修改 Send to 会有作用吗? - send_to = "" if isinstance(todo, ManualRecord): - resp = await todo.run( - task_dir=self.task_dir, - task_desc=self.task_desc, - env=self.rc.env - ) + resp = await todo.run(task_dir=self.task_dir, task_desc=self.task_desc, env=self.rc.env) elif isinstance(todo, ParseRecord): resp = await todo.run( app_name=config.get_other("app_name", "demo"), task_dir=self.task_dir, docs_dir=self.docs_dir, - env=self.rc.env + env=self.rc.env, ) elif isinstance(todo, SelfLearnAndReflect): resp = await todo.run( @@ -102,11 +99,10 @@ class AndroidAssistant(Role): last_act=self.last_act, task_dir=self.task_dir, docs_dir=self.docs_dir, - env=self.rc.env + env=self.rc.env, ) if resp.action_state == RunState.SUCCESS: self.last_act = resp.data.get("last_act") - send_to = self.name elif isinstance(todo, ScreenshotParse): resp = await todo.run( round_count=self.round_count, @@ -115,19 +111,18 @@ class AndroidAssistant(Role): task_dir=self.task_dir, docs_dir=self.docs_dir, grid_on=self.grid_on, - env=self.rc.env + env=self.rc.env, ) if resp.action_state == RunState.SUCCESS: + logger.info(f"grid_on: {resp.data.get('grid_on')}") self.grid_on = resp.data.get("grid_on") - send_to = self.name - msg = Message( content=f"RoundCount: {self.round_count}", role=self.profile, - cause_by=type(todo), + cause_by=type(resp), send_from=self.name, - send_to=self.name + send_to=self.name, ) - self.publish_message(msg) + # self.publish_message(msg) self.rc.memory.add(msg) return msg diff --git a/examples/andriod_assistant/run_assistant.py b/examples/andriod_assistant/run_assistant.py index eb80c2111..3d9ed5cfa 100644 --- a/examples/andriod_assistant/run_assistant.py +++ b/examples/andriod_assistant/run_assistant.py @@ -44,6 +44,7 @@ def startup( "stage": stage, "mode": mode, "app_name": app_name, + "task_desc": task_desc, "refine_doc": refine_doc, "min_dist": min_dist, "android_screenshot_dir": android_screenshot_dir, @@ -68,15 +69,3 @@ def startup( if __name__ == "__main__": app() -# Command python run_assistant.py "Create a contact in Contacts App named zjy with a phone number +86 18831933368" - -# python run_assistant.py "Create a contact in Contacts App named zjy with a phone number +86 18831933368" --mode "auto" --app-name "Contacts"examples\andriod_assistant> - -# TODO -# 0. How to set Round ? -# 1. Manual Record & Parse Record Success -# 2. Self Learn Fail -# local variable 'action' referenced before assignment -# 3. Act -# 3.1 TODO Act with Manual Docs -# 3.2 TDOO Act with Auto Docs diff --git a/examples/andriod_assistant/utils/schema.py b/examples/andriod_assistant/utils/schema.py index 18e637a0d..d7990de40 100644 --- a/examples/andriod_assistant/utils/schema.py +++ b/examples/andriod_assistant/utils/schema.py @@ -3,7 +3,8 @@ # @Desc : from enum import Enum -from pydantic import Field, BaseModel, field_validator + +from pydantic import BaseModel, Field, field_validator class ActionOp(Enum): @@ -37,6 +38,7 @@ class Decision(Enum): class AndroidElement(BaseModel): """UI Element""" + uid: str = Field(default="") bbox: tuple[tuple[int, int], tuple[int, int]] = Field(default={}) attrib: str = Field(default="") @@ -44,6 +46,7 @@ class AndroidElement(BaseModel): class OpLogItem(BaseModel): """log content for self-learn or task act""" + step: int = Field(default=0) prompt: str = Field(default="") image: str = Field(default="") @@ -52,6 +55,7 @@ class OpLogItem(BaseModel): class ReflectLogItem(BaseModel): """log content for self-learn-reflect""" + step: int = Field(default=0) prompt: str = Field(default="") image_before: str = Field(default="") @@ -61,6 +65,7 @@ class ReflectLogItem(BaseModel): class RecordLogItem(BaseModel): """log content for record parse, same as ReflectLogItem""" + step: int = Field(default=0) prompt: str = Field(default="") image_before: str = Field(default="") @@ -79,6 +84,7 @@ class DocContent(BaseModel): # start =================== define different Action Op and its params ============= class RunState(Enum): """run state""" + SUCCESS = "success" FINISH = "finish" FAIL = "fail" @@ -101,6 +107,7 @@ class TextOp(BaseOpParam): class LongPressOp(BaseOpParam): area: int = Field(default=-1) + # Modify This SwipeOp to SwipeOp_3, Need better name class SwipeOp_3(BaseOpParam): area: int = Field(default=-1) @@ -113,7 +120,6 @@ class GridOp(BaseModel): class BaseGridOpParam(BaseOpParam): - @field_validator("act_name", mode="before") @classmethod def check_act_name(cls, act_name: str) -> str: diff --git a/examples/andriod_assistant/utils/utils.py b/examples/andriod_assistant/utils/utils.py index b82c656a4..b53df55be 100644 --- a/examples/andriod_assistant/utils/utils.py +++ b/examples/andriod_assistant/utils/utils.py @@ -2,20 +2,33 @@ # -*- coding: utf-8 -*- # @Desc : +import re +from pathlib import Path from typing import Union from xml.etree.ElementTree import Element, iterparse + import cv2 -from pathlib import Path import pyshine as ps -import re -from metagpt.config2 import config +from examples.andriod_assistant.utils.schema import ( + ActionOp, + AndroidElement, + BaseGridOpParam, + BaseOpParam, + Decision, + GridOp, + LongPressGridOp, + LongPressOp, + ReflectOp, + RunState, + SwipeGridOp, + SwipeOp_3, + TapGridOp, + TapOp, + TextOp, +) from metagpt.logs import logger -from examples.andriod_assistant.utils.schema import AndroidElement -from examples.andriod_assistant.utils.schema import BaseOpParam, BaseGridOpParam, GridOp, ActionOp, TapOp, TapGridOp, \ - LongPressOp, LongPressGridOp, SwipeOp_3, SwipeGridOp, TextOp, RunState, ReflectOp, Decision - def get_id_from_element(elem: Element) -> str: bounds = elem.attrib["bounds"][1:-1].split("][") @@ -67,8 +80,13 @@ def traverse_xml_tree(xml_path: Path, elem_list: list[AndroidElement], attrib: s path.pop() -def draw_bbox_multi(img_path: Path, output_path: Path, elem_list: list[AndroidElement], record_mode: bool = False, - dark_mode: bool = False): +def draw_bbox_multi( + img_path: Path, + output_path: Path, + elem_list: list[AndroidElement], + record_mode: bool = False, + dark_mode: bool = False, +): imgcv = cv2.imread(str(img_path)) count = 1 for elem in elem_list: @@ -85,17 +103,35 @@ def draw_bbox_multi(img_path: Path, output_path: Path, elem_list: list[AndroidEl color = (0, 0, 250) else: color = (0, 250, 0) - imgcv = ps.putBText(imgcv, label, text_offset_x=(left + right) // 2 + 10, - text_offset_y=(top + bottom) // 2 + 10, - vspace=10, hspace=10, font_scale=1, thickness=2, background_RGB=color, - text_RGB=(255, 250, 250), alpha=0.5) + imgcv = ps.putBText( + imgcv, + label, + text_offset_x=(left + right) // 2 + 10, + text_offset_y=(top + bottom) // 2 + 10, + vspace=10, + hspace=10, + font_scale=1, + thickness=2, + background_RGB=color, + text_RGB=(255, 250, 250), + alpha=0.5, + ) else: text_color = (10, 10, 10) if dark_mode else (255, 250, 250) bg_color = (255, 250, 250) if dark_mode else (10, 10, 10) - imgcv = ps.putBText(imgcv, label, text_offset_x=(left + right) // 2 + 10, - text_offset_y=(top + bottom) // 2 + 10, - vspace=10, hspace=10, font_scale=1, thickness=2, background_RGB=bg_color, - text_RGB=text_color, alpha=0.5) + imgcv = ps.putBText( + imgcv, + label, + text_offset_x=(left + right) // 2 + 10, + text_offset_y=(top + bottom) // 2 + 10, + vspace=10, + hspace=10, + font_scale=1, + thickness=2, + background_RGB=bg_color, + text_RGB=text_color, + alpha=0.5, + ) except Exception as e: logger.error(f"ERROR: An exception occurs while labeling the image\n{e}") count += 1 @@ -110,7 +146,7 @@ def draw_grid(img_path: Path, output_path: Path) -> tuple[int, int]: return i return -1 - image = cv2.imread(img_path) + image = cv2.imread(str(img_path)) height, width, _ = image.shape color = (255, 116, 113) unit_height = get_unit_len(height) @@ -130,16 +166,31 @@ def draw_grid(img_path: Path, output_path: Path) -> tuple[int, int]: right = int((j + 1) * unit_width) bottom = int((i + 1) * unit_height) cv2.rectangle(image, (left, top), (right, bottom), color, thick // 2) - cv2.putText(image, str(label), (left + int(unit_width * 0.05) + 3, top + int(unit_height * 0.3) + 3), 0, - int(0.01 * unit_width), (0, 0, 0), thick) - cv2.putText(image, str(label), (left + int(unit_width * 0.05), top + int(unit_height * 0.3)), 0, - int(0.01 * unit_width), color, thick) - cv2.imwrite(output_path, image) + cv2.putText( + image, + str(label), + (left + int(unit_width * 0.05) + 3, top + int(unit_height * 0.3) + 3), + 0, + int(0.01 * unit_width), + (0, 0, 0), + thick, + ) + cv2.putText( + image, + str(label), + (left + int(unit_width * 0.05), top + int(unit_height * 0.3)), + 0, + int(0.01 * unit_width), + color, + thick, + ) + cv2.imwrite(str(output_path), image) return rows, cols def area_to_xy(area: int, subarea: str, width: int, height: int, rows: int, cols: int) -> tuple[int, int]: area -= 1 + logger.info(f"{cols}") row, col = area // cols, area % cols x_0, y_0 = col * (width // cols), row * (height // rows) if subarea == "top-left": @@ -174,9 +225,11 @@ def reflect_parse_extarct(parsed_json: dict) -> ReflectOp: if decision not in Decision.values(): op = ReflectOp(param_state=RunState.FAIL) else: - op = ReflectOp(decision=parsed_json.get("Decision"), - thought=parsed_json.get("Thought"), - documentation=parsed_json.get("Documentation")) + op = ReflectOp( + decision=parsed_json.get("Decision"), + thought=parsed_json.get("Thought"), + documentation=parsed_json.get("Documentation"), + ) return op @@ -237,11 +290,9 @@ def screenshot_parse_extract_with_grid(act_name: str, act: str, last_act: str) - elif act_name == ActionOp.SWIPE.value: params = re.findall(r"swipe\((.*?)\)", act)[0].split(",") params = op_params_clean(params) - op = SwipeGridOp(act_name=act_name, - start_area=params[0], - start_subarea=params[1], - end_area=params[2], - end_subarea=params[3]) + op = SwipeGridOp( + act_name=act_name, start_area=params[0], start_subarea=params[1], end_area=params[2], end_subarea=params[3] + ) elif act_name == ActionOp.GRID.value: op = GridOp(act_name=act_name) else: diff --git a/metagpt/actions/action_node.py b/metagpt/actions/action_node.py index 6c23c4c70..bbf4aabcd 100644 --- a/metagpt/actions/action_node.py +++ b/metagpt/actions/action_node.py @@ -140,14 +140,14 @@ class ActionNode: instruct_content: BaseModel def __init__( - self, - key: str, - expected_type: Type, - instruction: str, - example: Any, - content: str = "", - children: dict[str, "ActionNode"] = None, - schema: str = "", + self, + key: str, + expected_type: Type, + instruction: str, + example: Any, + content: str = "", + children: dict[str, "ActionNode"] = None, + schema: str = "", ): self.key = key self.expected_type = expected_type @@ -349,14 +349,14 @@ class ActionNode: after=general_after_log(logger), ) async def _aask_v1( - self, - prompt: str, - output_class_name: str, - output_data_mapping: dict, - images: Optional[Union[str, list[str]]] = None, - system_msgs: Optional[list[str]] = None, - schema="markdown", # compatible to original format - timeout=3, + self, + prompt: str, + output_class_name: str, + output_data_mapping: dict, + images: Optional[Union[str, list[str]]] = None, + system_msgs: Optional[list[str]] = None, + schema="markdown", # compatible to original format + timeout=3, ) -> (str, BaseModel): """Use ActionOutput to wrap the output of aask""" content = await self.llm.aask(prompt, system_msgs, images=images, timeout=timeout) @@ -406,15 +406,15 @@ class ActionNode: return self async def fill( - self, - context, - llm, - schema="json", - mode="auto", - strgy="simple", - images: Optional[Union[str, list[str]]] = None, - timeout=3, - exclude=[], + self, + context, + llm, + schema="json", + mode="auto", + strgy="simple", + images: Optional[Union[str, list[str]]] = None, + timeout=3, + exclude=[], ): logger.info("进入fill") """Fill the node(s) with mode. @@ -560,7 +560,7 @@ class ActionNode: return nodes_output async def auto_revise( - self, revise_mode: ReviseMode = ReviseMode.AUTO, template: str = REVISE_TEMPLATE + self, revise_mode: ReviseMode = ReviseMode.AUTO, template: str = REVISE_TEMPLATE ) -> dict[str, str]: """revise the value of incorrect keys""" # generate review comments diff --git a/metagpt/environment/__init__.py b/metagpt/environment/__init__.py index d2df8fd02..04f8658f9 100644 --- a/metagpt/environment/__init__.py +++ b/metagpt/environment/__init__.py @@ -1,13 +1,11 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- # @Desc : -# TODO from metagpt.environment.base_env import Environment from metagpt.environment.android_env.android_env import AndroidEnv from metagpt.environment.gym_env.gym_env import GymEnv from metagpt.environment.mincraft_env.mincraft_env import MincraftExtEnv from metagpt.environment.werewolf_env.werewolf_env import WerewolfEnv - from metagpt.environment.stanford_town_env.stanford_town_env import StanfordTownEnv # from metagpt.environment.software_env.software_env import SoftwareEnv diff --git a/metagpt/environment/android_env/android_ext_env.py b/metagpt/environment/android_env/android_ext_env.py index 72eae7182..298d79ffe 100644 --- a/metagpt/environment/android_env/android_ext_env.py +++ b/metagpt/environment/android_env/android_ext_env.py @@ -9,7 +9,12 @@ from typing import Any, Optional from pydantic import Field from metagpt.const import ADB_EXEC_FAIL -from metagpt.environment.base_env import Environment, ExtEnv, mark_as_readable, mark_as_writeable +from metagpt.environment.base_env import ( + Environment, + ExtEnv, + mark_as_readable, + mark_as_writeable, +) class AndroidExtEnv(Environment, ExtEnv): @@ -42,7 +47,7 @@ class AndroidExtEnv(Environment, ExtEnv): return f"adb -s {self.device_id} " def execute_adb_with_cmd(self, adb_cmd: str) -> str: - adb_cmd = adb_cmd.replace('\\', '/') + adb_cmd = adb_cmd.replace("\\", "/") res = subprocess.run(adb_cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True) exec_res = ADB_EXEC_FAIL if not res.returncode: diff --git a/metagpt/environment/base_env.py b/metagpt/environment/base_env.py index b39010aa1..86ccf99eb 100644 --- a/metagpt/environment/base_env.py +++ b/metagpt/environment/base_env.py @@ -5,7 +5,9 @@ import asyncio from enum import Enum from typing import TYPE_CHECKING, Any, Dict, Iterable, Optional, Set, Union + from pydantic import BaseModel, ConfigDict, Field, SerializeAsAny, model_validator + from metagpt.context import Context from metagpt.environment.api.env_api import ( EnvAPIAbstract, @@ -43,12 +45,14 @@ def mark_as_writeable(func): env_write_api_registry[func.__name__] = get_function_schema(func) return func + class ExtEnv(BaseModel): """External Env to intergate actual game environment""" write_api_registry: WriteAPIRegistry = Field(default_factory=WriteAPIRegistry, exclude=True) read_api_registry: ReadAPIRegistry = Field(default_factory=ReadAPIRegistry, exclude=True) + class ExtEnv(BaseModel): """External Env to intergate actual game environment""" @@ -97,6 +101,7 @@ class ExtEnv(BaseModel): return res + class Environment(ExtEnv): """环境,承载一批角色,角色可以向环境发布消息,可以被其他角色观察到 Environment, hosting a batch of roles, roles can publish messages to the environment, and can be observed by other roles @@ -212,9 +217,3 @@ class Environment(ExtEnv): Environment.model_rebuild() - - - - - - diff --git a/metagpt/schema.py b/metagpt/schema.py index 22bb359b6..7bbb567b9 100644 --- a/metagpt/schema.py +++ b/metagpt/schema.py @@ -41,6 +41,7 @@ from metagpt.const import ( MESSAGE_ROUTE_FROM, MESSAGE_ROUTE_TO, MESSAGE_ROUTE_TO_ALL, + PRDS_FILE_REPO, SYSTEM_DESIGN_FILE_REPO, TASK_FILE_REPO, ) @@ -328,6 +329,200 @@ class AIMessage(Message): super().__init__(content=content, role="assistant") +class Task(BaseModel): + task_id: str = "" + dependent_task_ids: list[str] = [] # Tasks prerequisite to this Task + instruction: str = "" + task_type: str = "" + code: str = "" + result: str = "" + is_success: bool = False + is_finished: bool = False + + def reset(self): + self.code = "" + self.result = "" + self.is_success = False + self.is_finished = False + + def update_task_result(self, task_result: TaskResult): + self.code = task_result.code + self.result = task_result.result + self.is_success = task_result.is_success + + +class TaskResult(BaseModel): + """Result of taking a task, with result and is_success required to be filled""" + + code: str = "" + result: str + is_success: bool + + +class Plan(BaseModel): + goal: str + context: str = "" + tasks: list[Task] = [] + task_map: dict[str, Task] = {} + current_task_id: str = "" + + def _topological_sort(self, tasks: list[Task]): + task_map = {task.task_id: task for task in tasks} + dependencies = {task.task_id: set(task.dependent_task_ids) for task in tasks} + sorted_tasks = [] + visited = set() + + def visit(task_id): + if task_id in visited: + return + visited.add(task_id) + for dependent_id in dependencies.get(task_id, []): + visit(dependent_id) + sorted_tasks.append(task_map[task_id]) + + for task in tasks: + visit(task.task_id) + + return sorted_tasks + + def add_tasks(self, tasks: list[Task]): + """ + Integrates new tasks into the existing plan, ensuring dependency order is maintained. + + This method performs two primary functions based on the current state of the task list: + 1. If there are no existing tasks, it topologically sorts the provided tasks to ensure + correct execution order based on dependencies, and sets these as the current tasks. + 2. If there are existing tasks, it merges the new tasks with the existing ones. It maintains + any common prefix of tasks (based on task_id and instruction) and appends the remainder + of the new tasks. The current task is updated to the first unfinished task in this merged list. + + Args: + tasks (list[Task]): A list of tasks (may be unordered) to add to the plan. + + Returns: + None: The method updates the internal state of the plan but does not return anything. + """ + if not tasks: + return + + # Topologically sort the new tasks to ensure correct dependency order + new_tasks = self._topological_sort(tasks) + + if not self.tasks: + # If there are no existing tasks, set the new tasks as the current tasks + self.tasks = new_tasks + + else: + # Find the length of the common prefix between existing and new tasks + prefix_length = 0 + for old_task, new_task in zip(self.tasks, new_tasks): + if old_task.task_id != new_task.task_id or old_task.instruction != new_task.instruction: + break + prefix_length += 1 + + # Combine the common prefix with the remainder of the new tasks + final_tasks = self.tasks[:prefix_length] + new_tasks[prefix_length:] + self.tasks = final_tasks + + # Update current_task_id to the first unfinished task in the merged list + self._update_current_task() + + # Update the task map for quick access to tasks by ID + self.task_map = {task.task_id: task for task in self.tasks} + + def reset_task(self, task_id: str): + """ + Clear code and result of the task based on task_id, and set the task as unfinished. + + Args: + task_id (str): The ID of the task to be reset. + + Returns: + None + """ + if task_id in self.task_map: + task = self.task_map[task_id] + task.reset() + + def replace_task(self, new_task: Task): + """ + Replace an existing task with the new input task based on task_id, and reset all tasks depending on it. + + Args: + new_task (Task): The new task that will replace an existing one. + + Returns: + None + """ + assert new_task.task_id in self.task_map + # Replace the task in the task map and the task list + self.task_map[new_task.task_id] = new_task + for i, task in enumerate(self.tasks): + if task.task_id == new_task.task_id: + self.tasks[i] = new_task + break + + # Reset dependent tasks + for task in self.tasks: + if new_task.task_id in task.dependent_task_ids: + self.reset_task(task.task_id) + + def append_task(self, new_task: Task): + """ + Append a new task to the end of existing task sequences + + Args: + new_task (Task): The new task to be appended to the existing task sequence + + Returns: + None + """ + assert not self.has_task_id(new_task.task_id), "Task already in current plan, use replace_task instead" + + assert all( + [self.has_task_id(dep_id) for dep_id in new_task.dependent_task_ids] + ), "New task has unknown dependencies" + + # Existing tasks do not depend on the new task, it's fine to put it to the end of the sorted task sequence + self.tasks.append(new_task) + self.task_map[new_task.task_id] = new_task + self._update_current_task() + + def has_task_id(self, task_id: str) -> bool: + return task_id in self.task_map + + def _update_current_task(self): + current_task_id = "" + for task in self.tasks: + if not task.is_finished: + current_task_id = task.task_id + break + self.current_task_id = current_task_id # all tasks finished + + @property + def current_task(self) -> Task: + """Find current task to execute + + Returns: + Task: the current task to be executed + """ + return self.task_map.get(self.current_task_id, None) + + def finish_current_task(self): + """Finish current task, set Task.is_finished=True, set current task to next task""" + if self.current_task_id: + self.current_task.is_finished = True + self._update_current_task() # set to next task + + def get_finished_tasks(self) -> list[Task]: + """return all finished tasks in correct linearized order + + Returns: + list[Task]: list of finished tasks + """ + return [task for task in self.tasks if task.is_finished] + + class MessageQueue(BaseModel): """Message queue which supports asynchronous updates.""" @@ -417,6 +612,7 @@ class CodingContext(BaseContext): design_doc: Optional[Document] = None task_doc: Optional[Document] = None code_doc: Optional[Document] = None + code_plan_and_change_doc: Optional[Document] = None class TestingContext(BaseContext): @@ -470,6 +666,29 @@ class BugFixContext(BaseContext): filename: str = "" +class CodePlanAndChangeContext(BaseModel): + requirement: str = "" + prd_filename: str = "" + design_filename: str = "" + task_filename: str = "" + + @staticmethod + def loads(filenames: List, **kwargs) -> CodePlanAndChangeContext: + ctx = CodePlanAndChangeContext(requirement=kwargs.get("requirement", "")) + for filename in filenames: + filename = Path(filename) + if filename.is_relative_to(PRDS_FILE_REPO): + ctx.prd_filename = filename.name + continue + if filename.is_relative_to(SYSTEM_DESIGN_FILE_REPO): + ctx.design_filename = filename.name + continue + if filename.is_relative_to(TASK_FILE_REPO): + ctx.task_filename = filename.name + continue + return ctx + + # mermaid class view class ClassMeta(BaseModel): name: str = "" diff --git a/metagpt/team.py b/metagpt/team.py index 2cc5d659c..beb1d6186 100644 --- a/metagpt/team.py +++ b/metagpt/team.py @@ -76,7 +76,7 @@ class Team(BaseModel): def hire(self, roles: list[Role]): """Hire roles to cooperate""" - only_role = roles[0] + roles[0] self.env.add_roles(roles) @property @@ -134,4 +134,4 @@ class Team(BaseModel): await self.env.run() self.env.archive(auto_archive) - return self.env.history \ No newline at end of file + return self.env.history diff --git a/metagpt/utils/common.py b/metagpt/utils/common.py index 25aeb54e8..015902c3d 100644 --- a/metagpt/utils/common.py +++ b/metagpt/utils/common.py @@ -24,13 +24,16 @@ import re import sys import traceback import typing +from io import BytesIO from pathlib import Path -from typing import Any, List, Tuple, Union, Callable +from typing import Any, Callable, List, Tuple, Union import aiofiles import loguru +import requests +from PIL import Image from pydantic_core import to_jsonable_python -from tenacity import RetryCallState, _utils +from tenacity import RetryCallState, RetryError, _utils from metagpt.const import MESSAGE_ROUTE_TO_ALL from metagpt.logs import logger @@ -214,7 +217,7 @@ class OutputParser: if start_index != -1 and end_index != -1: # Extract the structure part - structure_text = text[start_index: end_index + 1] + structure_text = text[start_index : end_index + 1] try: # Attempt to convert the text to a Python data type using ast.literal_eval @@ -358,6 +361,31 @@ def parse_recipient(text): return "" +def create_func_call_config(func_schema: dict) -> dict: + """Create new function call config""" + tools = [{"type": "function", "function": func_schema}] + tool_choice = {"type": "function", "function": {"name": func_schema["name"]}} + return { + "tools": tools, + "tool_choice": tool_choice, + } + + +def remove_comments(code_str: str) -> str: + """Remove comments from code.""" + pattern = r"(\".*?\"|\'.*?\')|(\#.*?$)" + + def replace_func(match): + if match.group(2) is not None: + return "" + else: + return match.group(1) + + clean_code = re.sub(pattern, replace_func, code_str, flags=re.MULTILINE) + clean_code = os.linesep.join([s.rstrip() for s in clean_code.splitlines() if s.strip()]) + return clean_code + + def get_class_name(cls) -> str: """Return class name""" return f"{cls.__module__}.{cls.__name__}" @@ -466,13 +494,13 @@ def read_json_file(json_file: str, encoding="utf-8") -> list[Any]: return data -def write_json_file(json_file: str, data: list, encoding=None): +def write_json_file(json_file: str, data: list, encoding: str = None, indent: int = 4): folder_path = Path(json_file).parent if not folder_path.exists(): folder_path.mkdir(parents=True, exist_ok=True) with open(json_file, "w", encoding=encoding) as fout: - json.dump(data, fout, ensure_ascii=False, indent=4, default=to_jsonable_python) + json.dump(data, fout, ensure_ascii=False, indent=indent, default=to_jsonable_python) def read_csv_to_list(curr_file: str, header=False, strip_trail=True): @@ -538,7 +566,7 @@ def role_raise_decorator(func): self.rc.memory.delete(self.latest_observed_msg) # raise again to make it captured outside raise Exception(format_trackback_info(limit=None)) - except Exception: + except Exception as e: if self.latest_observed_msg: logger.warning( "There is a exception in role's execution, in order to resume, " @@ -547,6 +575,12 @@ def role_raise_decorator(func): # remove role newest observed msg to make it observed again self.rc.memory.delete(self.latest_observed_msg) # raise again to make it captured outside + if isinstance(e, RetryError): + last_error = e.last_attempt._exception + name = any_to_str(last_error) + if re.match(r"^openai\.", name) or re.match(r"^httpx\.", name): + raise last_error + raise Exception(format_trackback_info(limit=None)) return wrapper @@ -606,6 +640,39 @@ def is_coroutine_func(func: Callable) -> bool: return inspect.iscoroutinefunction(func) -def encode_image(image_path: Path, encoding: str = "utf-8") -> str: - with open(str(image_path), "rb") as image_file: - return base64.b64encode(image_file.read()).decode(encoding) +def load_mc_skills_code(skill_names: list[str] = None, skills_dir: Path = None) -> list[str]: + """load mincraft skill from js files""" + if not skills_dir: + skills_dir = Path(__file__).parent.absolute() + if skill_names is None: + skill_names = [skill[:-3] for skill in os.listdir(f"{skills_dir}") if skill.endswith(".js")] + skills = [skills_dir.joinpath(f"{skill_name}.js").read_text() for skill_name in skill_names] + return skills + + +def encode_image(image_path_or_pil: Union[Path, Image], encoding: str = "utf-8") -> str: + """encode image from file or PIL.Image into base64""" + if isinstance(image_path_or_pil, Image.Image): + buffer = BytesIO() + image_path_or_pil.save(buffer, format="JPEG") + bytes_data = buffer.getvalue() + else: + if not image_path_or_pil.exists(): + raise FileNotFoundError(f"{image_path_or_pil} not exists") + with open(str(image_path_or_pil), "rb") as image_file: + bytes_data = image_file.read() + return base64.b64encode(bytes_data).decode(encoding) + + +def decode_image(img_url_or_b64: str) -> Image: + """decode image from url or base64 into PIL.Image""" + if img_url_or_b64.startswith("http"): + # image http(s) url + resp = requests.get(img_url_or_b64) + img = Image.open(BytesIO(resp.content)) + else: + # image b64_json + b64_data = re.sub("^data:image/.+;base64,", "", img_url_or_b64) + img_data = BytesIO(base64.b64decode(b64_data)) + img = Image.open(img_data) + return img diff --git a/tests/metagpt/environment/test_base_env.py b/tests/metagpt/environment/test_base_env.py index 85df6d023..0c4fb9ef1 100644 --- a/tests/metagpt/environment/test_base_env.py +++ b/tests/metagpt/environment/test_base_env.py @@ -4,8 +4,8 @@ import pytest -from metagpt.environment.base_env import Env, mark_as_writeable, mark_as_readable from metagpt.environment.api.env_api import EnvAPIAbstract +from metagpt.environment.base_env import Env, mark_as_readable, mark_as_writeable class ForTestEnv(Env):