FIx Format and Some bugs in android_assistant.py

This commit is contained in:
didi 2024-03-04 16:47:27 +08:00
parent f58012611c
commit 138bb6e63d
16 changed files with 510 additions and 177 deletions

View file

@ -9,9 +9,8 @@ import cv2
from examples.andriod_assistant.utils.schema import (
ActionOp,
AndroidActionOutput,
AndroidElement,
RunState,
SwipeOp
SwipeOp,
)
from examples.andriod_assistant.utils.utils import draw_bbox_multi, traverse_xml_tree
from metagpt.actions.action import Action
@ -24,6 +23,7 @@ from metagpt.logs import logger
class ManualRecord(Action):
"""do a human operation on the screen with human input"""
name: str = "ManualRecord"
useless_list: list[str] = [] # store useless elements uid
@ -35,19 +35,18 @@ class ManualRecord(Action):
# async def run(self, demo_name: str, task_desc: str,task_dir: Path, env: AndroidEnv):
async def run(self, task_desc: str, task_dir: Path, env: AndroidEnv):
self.record_path = Path(task_dir) / "record.txt"
self.task_desc_path = Path(task_dir) / "task_desc.txt"
self.screenshot_before_path = Path(task_dir)/"raw_screenshots"
self.screenshot_after_path = Path(task_dir)/"labeled_screenshots"
self.xml_path = Path(task_dir)/"xml"
self.screenshot_before_path = Path(task_dir) / "raw_screenshots"
self.screenshot_after_path = Path(task_dir) / "labeled_screenshots"
self.xml_path = Path(task_dir) / "xml"
for path in [self.screenshot_before_path,self.screenshot_after_path, self.xml_path]:
for path in [self.screenshot_before_path, self.screenshot_after_path, self.xml_path]:
if not path.exists():
path.mkdir(parents=True, exist_ok=True)
with open(self.record_path, 'w') as file:
file.write('')
with open(self.record_path, "w") as file:
file.write("")
record_file = open(self.record_path, "w")
with open(self.task_desc_path, "w") as f:
f.write(task_desc)
@ -58,14 +57,14 @@ class ManualRecord(Action):
EnvAPIAbstract(
api_name="get_screenshot",
# kwargs={"ss_name": f"{demo_name}_{step}", "local_save_dir": self.screenshot_before_path}
kwargs={"ss_name": f"{step}", "local_save_dir": self.screenshot_before_path}
kwargs={"ss_name": f"{step}", "local_save_dir": self.screenshot_before_path},
)
)
xml_path: Path = await env.observe(
EnvAPIAbstract(
api_name="get_xml",
# kwargs={"xml_name": f"{demo_name}_{step}", "local_save_dir": self.xml_path}
kwargs={"xml_name": f"{step}", "local_save_dir": self.xml_path}
kwargs={"xml_name": f"{step}", "local_save_dir": self.xml_path},
)
)
if not screenshot_path.exists() or not xml_path.exists():
@ -110,11 +109,11 @@ class ManualRecord(Action):
)
while (
user_input.lower() != ActionOp.TAP.value
and user_input.lower() != ActionOp.TEXT.value
and user_input.lower() != ActionOp.LONG_PRESS.value
and user_input.lower() != ActionOp.SWIPE.value
and user_input.lower() != ActionOp.STOP.value
user_input.lower() != ActionOp.TAP.value
and user_input.lower() != ActionOp.TEXT.value
and user_input.lower() != ActionOp.LONG_PRESS.value
and user_input.lower() != ActionOp.SWIPE.value
and user_input.lower() != ActionOp.STOP.value
):
user_input = input()
@ -167,10 +166,10 @@ class ManualRecord(Action):
)
user_input = ""
while (
user_input != SwipeOp.UP.value
and user_input != SwipeOp.DOWN.value
and user_input != SwipeOp.LEFT.value
and user_input != SwipeOp.RIGHT.value
user_input != SwipeOp.UP.value
and user_input != SwipeOp.DOWN.value
and user_input != SwipeOp.LEFT.value
and user_input != SwipeOp.RIGHT.value
):
user_input = input()
swipe_dir = user_input
@ -179,7 +178,9 @@ class ManualRecord(Action):
user_input = input()
tl, br = elem_list[int(user_input) - 1].bbox
x, y = (tl[0] + br[0]) // 2, (tl[1] + br[1]) // 2
ret = await env.step(EnvAPIAbstract(api_name="user_swipe", kwargs={"x": x, "y": y, "orient": swipe_dir}))
ret = await env.step(
EnvAPIAbstract(api_name="user_swipe", kwargs={"x": x, "y": y, "orient": swipe_dir})
)
if ret == ADB_EXEC_FAIL:
return AndroidActionOutput(action_state=RunState.FAIL)
record_file.write(f"swipe({int(user_input)}:sep:{swipe_dir}):::{elem_list[int(user_input) - 1].uid}\n")
@ -190,5 +191,3 @@ class ManualRecord(Action):
else:
break
time.sleep(3)

View file

@ -6,7 +6,6 @@
import ast
import json
import re
import time
from pathlib import Path
from examples.andriod_assistant.actions.parse_record_an import RECORD_PARSE_NODE
@ -44,8 +43,8 @@ class ParseRecord(Action):
doc_count = 0
self.record_path = Path(task_dir) / "record.txt"
self.task_desc_path = Path(task_dir) / "task_desc.txt"
self.screenshot_before_path = Path(task_dir)/"raw_screenshots"
self.screenshot_after_path = Path(task_dir)/"labeled_screenshots"
self.screenshot_before_path = Path(task_dir) / "raw_screenshots"
self.screenshot_after_path = Path(task_dir) / "labeled_screenshots"
with open(self.record_path, "r") as record_file:
record_step_count = len(record_file.readlines()) - 1
@ -137,5 +136,6 @@ class ParseRecord(Action):
logger.info(f"Documentation generation phase completed. {doc_count} docs generated.")
# TODO
# 1. LOG中记录方式有问题需要把IMG的部分拿出去丢掉
# 1. LOG中记录方式有问题需要把IMG的部分拿出去丢掉

View file

@ -26,8 +26,8 @@ from examples.andriod_assistant.utils.schema import (
)
from examples.andriod_assistant.utils.utils import (
area_to_xy,
draw_grid,
draw_bbox_multi,
draw_grid,
elem_bbox_to_xy,
screenshot_parse_extract,
traverse_xml_tree,
@ -79,14 +79,14 @@ class ScreenshotParse(Action):
return ui_doc
async def run(
self,
round_count: int,
task_desc: str,
last_act: str,
task_dir: Path,
docs_dir: Path,
grid_on: bool,
env: AndroidEnv,
self,
round_count: int,
task_desc: str,
last_act: str,
task_dir: Path,
docs_dir: Path,
grid_on: bool,
env: AndroidEnv,
):
for path in [task_dir, docs_dir]:
if not path.exists():
@ -94,15 +94,11 @@ class ScreenshotParse(Action):
screenshot_path: Path = await env.observe(
EnvAPIAbstract(
api_name="get_screenshot",
kwargs={"ss_name": f"{round_count}_before", "local_save_dir": task_dir}
api_name="get_screenshot", kwargs={"ss_name": f"{round_count}_before", "local_save_dir": task_dir}
)
)
xml_path: Path = await env.observe(
EnvAPIAbstract(
api_name="get_xml",
kwargs={"xml_name": f"{round_count}", "local_save_dir": task_dir}
)
EnvAPIAbstract(api_name="get_xml", kwargs={"xml_name": f"{round_count}", "local_save_dir": task_dir})
)
width, height = env.device_shape
if not screenshot_path.exists() or not xml_path.exists():
@ -134,7 +130,7 @@ class ScreenshotParse(Action):
parse_template = screenshot_parse_with_grid_template if grid_on else screenshot_parse_template
if grid_on:
rows, cols = draw_grid(screenshot_path, task_dir / f"{round_count}_grid.png")
env.rows, env.cols = draw_grid(screenshot_path, task_dir / f"{round_count}_grid.png")
ui_doc = self._makeup_ui_document(elem_list, docs_dir)
context = parse_template.format(ui_document=ui_doc, task_description=task_desc, last_act=last_act)
@ -171,7 +167,7 @@ class ScreenshotParse(Action):
res = await env.step(
EnvAPIAbstract(
api_name="user_swipe",
kwargs={"x": x, "y": y, "orient": op_param.swipe_orient, "dist": op_param.dist}
kwargs={"x": x, "y": y, "orient": op_param.swipe_orient, "dist": op_param.dist},
)
)
if res == ADB_EXEC_FAIL:
@ -190,10 +186,15 @@ class ScreenshotParse(Action):
if res == ADB_EXEC_FAIL:
return AndroidActionOutput(action_state=RunState.FAIL)
elif isinstance(op_param, SwipeGridOp):
start_x, start_y = area_to_xy(op_param.start_area, op_param.start_subarea, env.width, env.height, env.rows, env.cols)
end_x, end_y = area_to_xy(op_param.end_area, op_param.end_subarea, env.width, env.height, env.rows, env.cols)
start_x, start_y = area_to_xy(
op_param.start_area, op_param.start_subarea, env.width, env.height, env.rows, env.cols
)
end_x, end_y = area_to_xy(
op_param.end_area, op_param.end_subarea, env.width, env.height, env.rows, env.cols
)
res = await env.step(
EnvAPIAbstract(api_name="user_swipe_to", kwargs={"start": (start_x, start_y), "end": (end_x, end_y)}))
EnvAPIAbstract(api_name="user_swipe_to", kwargs={"start": (start_x, start_y), "end": (end_x, end_y)})
)
if res == ADB_EXEC_FAIL:
return AndroidActionOutput(action_state=RunState.FAIL)

View file

@ -59,17 +59,17 @@ class SelfLearnAndReflect(Action):
ui_area: int = -1
async def run(
self, round_count: int, task_desc: str, last_act: str, task_dir: Path, docs_dir: Path, env: AndroidEnv
self, round_count: int, task_desc: str, last_act: str, task_dir: Path, docs_dir: Path, env: AndroidEnv
) -> AndroidActionOutput:
for path in [task_dir,docs_dir]:
for path in [task_dir, docs_dir]:
if not path.exists():
path.mkdir(parents=True,exist_ok=True)
path.mkdir(parents=True, exist_ok=True)
resp = await self.run_self_learn(round_count, task_desc, last_act, task_dir, env)
resp = await self.run_reflect(round_count, task_desc, last_act, task_dir, docs_dir, env)
return resp
async def run_self_learn(
self, round_count: int, task_desc: str, last_act: str, task_dir: Path, env: AndroidEnv
self, round_count: int, task_desc: str, last_act: str, task_dir: Path, env: AndroidEnv
) -> AndroidActionOutput:
screenshot_path: Path = await env.observe(
EnvAPIAbstract(
@ -151,7 +151,8 @@ class SelfLearnAndReflect(Action):
x, y = elem_bbox_to_xy(elem_list[op_param.area - 1].bbox)
res = await env.step(
EnvAPIAbstract(
api_name="user_swipe", kwargs={"x": x, "y": y, "orient": op_param.swipe_orient, "dist": op_param.dist}
api_name="user_swipe",
kwargs={"x": x, "y": y, "orient": op_param.swipe_orient, "dist": op_param.dist},
)
)
if res == ADB_EXEC_FAIL:
@ -159,11 +160,10 @@ class SelfLearnAndReflect(Action):
self.elem_list = elem_list
self.act_name = op_param.act_name
print("探索阶段结束")
return AndroidActionOutput()
async def run_reflect(
self, round_count: int, task_desc: str, last_act: str, task_dir: Path, docs_dir: Path, env: AndroidEnv
self, round_count: int, task_desc: str, last_act: str, task_dir: Path, docs_dir: Path, env: AndroidEnv
) -> AndroidActionOutput:
screenshot_path: Path = await env.observe(
EnvAPIAbstract(
@ -176,7 +176,6 @@ class SelfLearnAndReflect(Action):
screenshot_after_labeled_path = task_dir.joinpath(f"{round_count}_after_labeled.png")
draw_bbox_multi(screenshot_path, screenshot_after_labeled_path, elem_list=self.elem_list)
img_base64 = encode_image(screenshot_after_labeled_path)
if self.act_name == ActionOp.TAP.value:
action = "tapping"
elif self.act_name == ActionOp.LONG_PRESS.value:
@ -187,6 +186,11 @@ class SelfLearnAndReflect(Action):
action = "v_swipe"
elif self.swipe_orient == SwipeOp.LEFT.value or self.swipe_orient == SwipeOp.RIGHT.value:
action = "h_swipe"
else:
# TODO Test for assignment, This error is eupiped with the next.
logger.info(f"Warning: current action name:{self.act_name}")
logger.info("Warning: act_name parse wrong!")
action = None
context = reflect_template.format(
action=action, ui_element=str(self.ui_area), task_desc=task_desc, last_act=last_act
)
@ -211,7 +215,8 @@ class SelfLearnAndReflect(Action):
return AndroidActionOutput(action_state=RunState.FINISH)
if op_param.param_state == RunState.FAIL:
return AndroidActionOutput(action_state=RunState.FAIL)
# TODO 这里经常出现错误
logger.info(f"Error 高发地区, 长度为{len(self.elem_list)}ui_erea为{self.ui_area}")
resource_id = self.elem_list[int(self.ui_area) - 1].uid
if op_param.decision == Decision.INEFFECTIVE.value:
self.useless_list.append(resource_id)
@ -235,8 +240,7 @@ class SelfLearnAndReflect(Action):
doc_content = DocContent()
setattr(doc_content, self.act_name, doc)
doc_path.write_text(str(doc_content))
print("反思阶段结束")
return AndroidActionOutput(data={"last_act": last_act})
# TODO 如何处理 FINISH 状态这一点应该需要与role 联动才能解决
# TODO 如何处理 FINISH 状态这一点应该需要与role 联动才能解决

View file

@ -2,16 +2,19 @@
# -*- coding: utf-8 -*-
# @Desc : android assistant to learn from app operations and operate apps
import time
from typing import Optional
from pathlib import Path
from pydantic import Field
from datetime import datetime
from pathlib import Path
from typing import Optional
from pydantic import Field
from examples.andriod_assistant.actions.manual_record import ManualRecord
from examples.andriod_assistant.actions.parse_record import ParseRecord
from examples.andriod_assistant.actions.screenshot_parse import ScreenshotParse
from examples.andriod_assistant.actions.self_learn_and_reflect import SelfLearnAndReflect
from examples.andriod_assistant.utils.schema import RunState, AndroidActionOutput
from examples.andriod_assistant.actions.self_learn_and_reflect import (
SelfLearnAndReflect,
)
from examples.andriod_assistant.utils.schema import AndroidActionOutput, RunState
from metagpt.actions.add_requirement import UserRequirement
from metagpt.config2 import config
from metagpt.logs import logger
@ -35,7 +38,7 @@ class AndroidAssistant(Role):
super().__init__(**data)
self._watch([UserRequirement, AndroidActionOutput])
self.task_desc = config.get_other("task_desc", "Just explore any app in this phone!")
app_name = config.get_other("app_name", "demo")
curr_path = Path(__file__).parent
data_dir = curr_path.joinpath("..", "output")
@ -49,20 +52,20 @@ class AndroidAssistant(Role):
# Remember, only run each action only one time, no need to run n_round.
self.set_actions([ManualRecord, ParseRecord])
self.task_dir = data_dir.joinpath(app_name, f"manual_learn_{cur_datetime}")
self.docs_dir = data_dir.joinpath(app_name, f"manual_docs")
self.docs_dir = data_dir.joinpath(app_name, "manual_docs")
elif config.get_other("stage") == "learn" and config.get_other("mode") == "auto":
# choose SelfLearnAndReflect to run
self.set_actions([SelfLearnAndReflect])
self.task_dir = data_dir.joinpath(app_name, f"auto_learn_{cur_datetime}")
self.docs_dir = data_dir.joinpath(app_name, f"auto_docs")
self.docs_dir = data_dir.joinpath(app_name, "auto_docs")
elif config.get_other("stage") == "act":
# choose ScreenshotParse to run
self.set_actions([ScreenshotParse])
self.task_dir = data_dir.joinpath(app_name, f"act_{cur_datetime}")
if config.get_other("mode") == "manual":
self.docs_dir = data_dir.joinpath(app_name, f"manual_docs")
self.docs_dir = data_dir.joinpath(app_name, "manual_docs")
else:
self.docs_dir = data_dir.joinpath(app_name, f"auto_docs")
self.docs_dir = data_dir.joinpath(app_name, "auto_docs")
self._check_dir()
self._set_react_mode(RoleReactMode.BY_ORDER)
@ -80,20 +83,14 @@ class AndroidAssistant(Role):
async def _act(self) -> Message:
logger.info(f"{self._setting}: to do {self.rc.todo}({self.rc.todo.name})")
todo = self.rc.todo
# TODO 这里修改 Send to 会有作用吗?
send_to = ""
if isinstance(todo, ManualRecord):
resp = await todo.run(
task_dir=self.task_dir,
task_desc=self.task_desc,
env=self.rc.env
)
resp = await todo.run(task_dir=self.task_dir, task_desc=self.task_desc, env=self.rc.env)
elif isinstance(todo, ParseRecord):
resp = await todo.run(
app_name=config.get_other("app_name", "demo"),
task_dir=self.task_dir,
docs_dir=self.docs_dir,
env=self.rc.env
env=self.rc.env,
)
elif isinstance(todo, SelfLearnAndReflect):
resp = await todo.run(
@ -102,11 +99,10 @@ class AndroidAssistant(Role):
last_act=self.last_act,
task_dir=self.task_dir,
docs_dir=self.docs_dir,
env=self.rc.env
env=self.rc.env,
)
if resp.action_state == RunState.SUCCESS:
self.last_act = resp.data.get("last_act")
send_to = self.name
elif isinstance(todo, ScreenshotParse):
resp = await todo.run(
round_count=self.round_count,
@ -115,19 +111,18 @@ class AndroidAssistant(Role):
task_dir=self.task_dir,
docs_dir=self.docs_dir,
grid_on=self.grid_on,
env=self.rc.env
env=self.rc.env,
)
if resp.action_state == RunState.SUCCESS:
logger.info(f"grid_on: {resp.data.get('grid_on')}")
self.grid_on = resp.data.get("grid_on")
send_to = self.name
msg = Message(
content=f"RoundCount: {self.round_count}",
role=self.profile,
cause_by=type(todo),
cause_by=type(resp),
send_from=self.name,
send_to=self.name
send_to=self.name,
)
self.publish_message(msg)
# self.publish_message(msg)
self.rc.memory.add(msg)
return msg

View file

@ -44,6 +44,7 @@ def startup(
"stage": stage,
"mode": mode,
"app_name": app_name,
"task_desc": task_desc,
"refine_doc": refine_doc,
"min_dist": min_dist,
"android_screenshot_dir": android_screenshot_dir,
@ -68,15 +69,3 @@ def startup(
if __name__ == "__main__":
app()
# Command python run_assistant.py "Create a contact in Contacts App named zjy with a phone number +86 18831933368"
# python run_assistant.py "Create a contact in Contacts App named zjy with a phone number +86 18831933368" --mode "auto" --app-name "Contacts"examples\andriod_assistant>
# TODO
# 0. How to set Round ?
# 1. Manual Record & Parse Record Success
# 2. Self Learn Fail
# local variable 'action' referenced before assignment
# 3. Act
# 3.1 TODO Act with Manual Docs
# 3.2 TDOO Act with Auto Docs

View file

@ -3,7 +3,8 @@
# @Desc :
from enum import Enum
from pydantic import Field, BaseModel, field_validator
from pydantic import BaseModel, Field, field_validator
class ActionOp(Enum):
@ -37,6 +38,7 @@ class Decision(Enum):
class AndroidElement(BaseModel):
"""UI Element"""
uid: str = Field(default="")
bbox: tuple[tuple[int, int], tuple[int, int]] = Field(default={})
attrib: str = Field(default="")
@ -44,6 +46,7 @@ class AndroidElement(BaseModel):
class OpLogItem(BaseModel):
"""log content for self-learn or task act"""
step: int = Field(default=0)
prompt: str = Field(default="")
image: str = Field(default="")
@ -52,6 +55,7 @@ class OpLogItem(BaseModel):
class ReflectLogItem(BaseModel):
"""log content for self-learn-reflect"""
step: int = Field(default=0)
prompt: str = Field(default="")
image_before: str = Field(default="")
@ -61,6 +65,7 @@ class ReflectLogItem(BaseModel):
class RecordLogItem(BaseModel):
"""log content for record parse, same as ReflectLogItem"""
step: int = Field(default=0)
prompt: str = Field(default="")
image_before: str = Field(default="")
@ -79,6 +84,7 @@ class DocContent(BaseModel):
# start =================== define different Action Op and its params =============
class RunState(Enum):
"""run state"""
SUCCESS = "success"
FINISH = "finish"
FAIL = "fail"
@ -101,6 +107,7 @@ class TextOp(BaseOpParam):
class LongPressOp(BaseOpParam):
area: int = Field(default=-1)
# Modify This SwipeOp to SwipeOp_3, Need better name
class SwipeOp_3(BaseOpParam):
area: int = Field(default=-1)
@ -113,7 +120,6 @@ class GridOp(BaseModel):
class BaseGridOpParam(BaseOpParam):
@field_validator("act_name", mode="before")
@classmethod
def check_act_name(cls, act_name: str) -> str:

View file

@ -2,20 +2,33 @@
# -*- coding: utf-8 -*-
# @Desc :
import re
from pathlib import Path
from typing import Union
from xml.etree.ElementTree import Element, iterparse
import cv2
from pathlib import Path
import pyshine as ps
import re
from metagpt.config2 import config
from examples.andriod_assistant.utils.schema import (
ActionOp,
AndroidElement,
BaseGridOpParam,
BaseOpParam,
Decision,
GridOp,
LongPressGridOp,
LongPressOp,
ReflectOp,
RunState,
SwipeGridOp,
SwipeOp_3,
TapGridOp,
TapOp,
TextOp,
)
from metagpt.logs import logger
from examples.andriod_assistant.utils.schema import AndroidElement
from examples.andriod_assistant.utils.schema import BaseOpParam, BaseGridOpParam, GridOp, ActionOp, TapOp, TapGridOp, \
LongPressOp, LongPressGridOp, SwipeOp_3, SwipeGridOp, TextOp, RunState, ReflectOp, Decision
def get_id_from_element(elem: Element) -> str:
bounds = elem.attrib["bounds"][1:-1].split("][")
@ -67,8 +80,13 @@ def traverse_xml_tree(xml_path: Path, elem_list: list[AndroidElement], attrib: s
path.pop()
def draw_bbox_multi(img_path: Path, output_path: Path, elem_list: list[AndroidElement], record_mode: bool = False,
dark_mode: bool = False):
def draw_bbox_multi(
img_path: Path,
output_path: Path,
elem_list: list[AndroidElement],
record_mode: bool = False,
dark_mode: bool = False,
):
imgcv = cv2.imread(str(img_path))
count = 1
for elem in elem_list:
@ -85,17 +103,35 @@ def draw_bbox_multi(img_path: Path, output_path: Path, elem_list: list[AndroidEl
color = (0, 0, 250)
else:
color = (0, 250, 0)
imgcv = ps.putBText(imgcv, label, text_offset_x=(left + right) // 2 + 10,
text_offset_y=(top + bottom) // 2 + 10,
vspace=10, hspace=10, font_scale=1, thickness=2, background_RGB=color,
text_RGB=(255, 250, 250), alpha=0.5)
imgcv = ps.putBText(
imgcv,
label,
text_offset_x=(left + right) // 2 + 10,
text_offset_y=(top + bottom) // 2 + 10,
vspace=10,
hspace=10,
font_scale=1,
thickness=2,
background_RGB=color,
text_RGB=(255, 250, 250),
alpha=0.5,
)
else:
text_color = (10, 10, 10) if dark_mode else (255, 250, 250)
bg_color = (255, 250, 250) if dark_mode else (10, 10, 10)
imgcv = ps.putBText(imgcv, label, text_offset_x=(left + right) // 2 + 10,
text_offset_y=(top + bottom) // 2 + 10,
vspace=10, hspace=10, font_scale=1, thickness=2, background_RGB=bg_color,
text_RGB=text_color, alpha=0.5)
imgcv = ps.putBText(
imgcv,
label,
text_offset_x=(left + right) // 2 + 10,
text_offset_y=(top + bottom) // 2 + 10,
vspace=10,
hspace=10,
font_scale=1,
thickness=2,
background_RGB=bg_color,
text_RGB=text_color,
alpha=0.5,
)
except Exception as e:
logger.error(f"ERROR: An exception occurs while labeling the image\n{e}")
count += 1
@ -110,7 +146,7 @@ def draw_grid(img_path: Path, output_path: Path) -> tuple[int, int]:
return i
return -1
image = cv2.imread(img_path)
image = cv2.imread(str(img_path))
height, width, _ = image.shape
color = (255, 116, 113)
unit_height = get_unit_len(height)
@ -130,16 +166,31 @@ def draw_grid(img_path: Path, output_path: Path) -> tuple[int, int]:
right = int((j + 1) * unit_width)
bottom = int((i + 1) * unit_height)
cv2.rectangle(image, (left, top), (right, bottom), color, thick // 2)
cv2.putText(image, str(label), (left + int(unit_width * 0.05) + 3, top + int(unit_height * 0.3) + 3), 0,
int(0.01 * unit_width), (0, 0, 0), thick)
cv2.putText(image, str(label), (left + int(unit_width * 0.05), top + int(unit_height * 0.3)), 0,
int(0.01 * unit_width), color, thick)
cv2.imwrite(output_path, image)
cv2.putText(
image,
str(label),
(left + int(unit_width * 0.05) + 3, top + int(unit_height * 0.3) + 3),
0,
int(0.01 * unit_width),
(0, 0, 0),
thick,
)
cv2.putText(
image,
str(label),
(left + int(unit_width * 0.05), top + int(unit_height * 0.3)),
0,
int(0.01 * unit_width),
color,
thick,
)
cv2.imwrite(str(output_path), image)
return rows, cols
def area_to_xy(area: int, subarea: str, width: int, height: int, rows: int, cols: int) -> tuple[int, int]:
area -= 1
logger.info(f"{cols}")
row, col = area // cols, area % cols
x_0, y_0 = col * (width // cols), row * (height // rows)
if subarea == "top-left":
@ -174,9 +225,11 @@ def reflect_parse_extarct(parsed_json: dict) -> ReflectOp:
if decision not in Decision.values():
op = ReflectOp(param_state=RunState.FAIL)
else:
op = ReflectOp(decision=parsed_json.get("Decision"),
thought=parsed_json.get("Thought"),
documentation=parsed_json.get("Documentation"))
op = ReflectOp(
decision=parsed_json.get("Decision"),
thought=parsed_json.get("Thought"),
documentation=parsed_json.get("Documentation"),
)
return op
@ -237,11 +290,9 @@ def screenshot_parse_extract_with_grid(act_name: str, act: str, last_act: str) -
elif act_name == ActionOp.SWIPE.value:
params = re.findall(r"swipe\((.*?)\)", act)[0].split(",")
params = op_params_clean(params)
op = SwipeGridOp(act_name=act_name,
start_area=params[0],
start_subarea=params[1],
end_area=params[2],
end_subarea=params[3])
op = SwipeGridOp(
act_name=act_name, start_area=params[0], start_subarea=params[1], end_area=params[2], end_subarea=params[3]
)
elif act_name == ActionOp.GRID.value:
op = GridOp(act_name=act_name)
else:

View file

@ -140,14 +140,14 @@ class ActionNode:
instruct_content: BaseModel
def __init__(
self,
key: str,
expected_type: Type,
instruction: str,
example: Any,
content: str = "",
children: dict[str, "ActionNode"] = None,
schema: str = "",
self,
key: str,
expected_type: Type,
instruction: str,
example: Any,
content: str = "",
children: dict[str, "ActionNode"] = None,
schema: str = "",
):
self.key = key
self.expected_type = expected_type
@ -349,14 +349,14 @@ class ActionNode:
after=general_after_log(logger),
)
async def _aask_v1(
self,
prompt: str,
output_class_name: str,
output_data_mapping: dict,
images: Optional[Union[str, list[str]]] = None,
system_msgs: Optional[list[str]] = None,
schema="markdown", # compatible to original format
timeout=3,
self,
prompt: str,
output_class_name: str,
output_data_mapping: dict,
images: Optional[Union[str, list[str]]] = None,
system_msgs: Optional[list[str]] = None,
schema="markdown", # compatible to original format
timeout=3,
) -> (str, BaseModel):
"""Use ActionOutput to wrap the output of aask"""
content = await self.llm.aask(prompt, system_msgs, images=images, timeout=timeout)
@ -406,15 +406,15 @@ class ActionNode:
return self
async def fill(
self,
context,
llm,
schema="json",
mode="auto",
strgy="simple",
images: Optional[Union[str, list[str]]] = None,
timeout=3,
exclude=[],
self,
context,
llm,
schema="json",
mode="auto",
strgy="simple",
images: Optional[Union[str, list[str]]] = None,
timeout=3,
exclude=[],
):
logger.info("进入fill")
"""Fill the node(s) with mode.
@ -560,7 +560,7 @@ class ActionNode:
return nodes_output
async def auto_revise(
self, revise_mode: ReviseMode = ReviseMode.AUTO, template: str = REVISE_TEMPLATE
self, revise_mode: ReviseMode = ReviseMode.AUTO, template: str = REVISE_TEMPLATE
) -> dict[str, str]:
"""revise the value of incorrect keys"""
# generate review comments

View file

@ -1,13 +1,11 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc :
# TODO
from metagpt.environment.base_env import Environment
from metagpt.environment.android_env.android_env import AndroidEnv
from metagpt.environment.gym_env.gym_env import GymEnv
from metagpt.environment.mincraft_env.mincraft_env import MincraftExtEnv
from metagpt.environment.werewolf_env.werewolf_env import WerewolfEnv
from metagpt.environment.stanford_town_env.stanford_town_env import StanfordTownEnv
# from metagpt.environment.software_env.software_env import SoftwareEnv

View file

@ -9,7 +9,12 @@ from typing import Any, Optional
from pydantic import Field
from metagpt.const import ADB_EXEC_FAIL
from metagpt.environment.base_env import Environment, ExtEnv, mark_as_readable, mark_as_writeable
from metagpt.environment.base_env import (
Environment,
ExtEnv,
mark_as_readable,
mark_as_writeable,
)
class AndroidExtEnv(Environment, ExtEnv):
@ -42,7 +47,7 @@ class AndroidExtEnv(Environment, ExtEnv):
return f"adb -s {self.device_id} "
def execute_adb_with_cmd(self, adb_cmd: str) -> str:
adb_cmd = adb_cmd.replace('\\', '/')
adb_cmd = adb_cmd.replace("\\", "/")
res = subprocess.run(adb_cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
exec_res = ADB_EXEC_FAIL
if not res.returncode:

View file

@ -5,7 +5,9 @@
import asyncio
from enum import Enum
from typing import TYPE_CHECKING, Any, Dict, Iterable, Optional, Set, Union
from pydantic import BaseModel, ConfigDict, Field, SerializeAsAny, model_validator
from metagpt.context import Context
from metagpt.environment.api.env_api import (
EnvAPIAbstract,
@ -43,12 +45,14 @@ def mark_as_writeable(func):
env_write_api_registry[func.__name__] = get_function_schema(func)
return func
class ExtEnv(BaseModel):
"""External Env to intergate actual game environment"""
write_api_registry: WriteAPIRegistry = Field(default_factory=WriteAPIRegistry, exclude=True)
read_api_registry: ReadAPIRegistry = Field(default_factory=ReadAPIRegistry, exclude=True)
class ExtEnv(BaseModel):
"""External Env to intergate actual game environment"""
@ -97,6 +101,7 @@ class ExtEnv(BaseModel):
return res
class Environment(ExtEnv):
"""环境,承载一批角色,角色可以向环境发布消息,可以被其他角色观察到
Environment, hosting a batch of roles, roles can publish messages to the environment, and can be observed by other roles
@ -212,9 +217,3 @@ class Environment(ExtEnv):
Environment.model_rebuild()

View file

@ -41,6 +41,7 @@ from metagpt.const import (
MESSAGE_ROUTE_FROM,
MESSAGE_ROUTE_TO,
MESSAGE_ROUTE_TO_ALL,
PRDS_FILE_REPO,
SYSTEM_DESIGN_FILE_REPO,
TASK_FILE_REPO,
)
@ -328,6 +329,200 @@ class AIMessage(Message):
super().__init__(content=content, role="assistant")
class Task(BaseModel):
task_id: str = ""
dependent_task_ids: list[str] = [] # Tasks prerequisite to this Task
instruction: str = ""
task_type: str = ""
code: str = ""
result: str = ""
is_success: bool = False
is_finished: bool = False
def reset(self):
self.code = ""
self.result = ""
self.is_success = False
self.is_finished = False
def update_task_result(self, task_result: TaskResult):
self.code = task_result.code
self.result = task_result.result
self.is_success = task_result.is_success
class TaskResult(BaseModel):
"""Result of taking a task, with result and is_success required to be filled"""
code: str = ""
result: str
is_success: bool
class Plan(BaseModel):
goal: str
context: str = ""
tasks: list[Task] = []
task_map: dict[str, Task] = {}
current_task_id: str = ""
def _topological_sort(self, tasks: list[Task]):
task_map = {task.task_id: task for task in tasks}
dependencies = {task.task_id: set(task.dependent_task_ids) for task in tasks}
sorted_tasks = []
visited = set()
def visit(task_id):
if task_id in visited:
return
visited.add(task_id)
for dependent_id in dependencies.get(task_id, []):
visit(dependent_id)
sorted_tasks.append(task_map[task_id])
for task in tasks:
visit(task.task_id)
return sorted_tasks
def add_tasks(self, tasks: list[Task]):
"""
Integrates new tasks into the existing plan, ensuring dependency order is maintained.
This method performs two primary functions based on the current state of the task list:
1. If there are no existing tasks, it topologically sorts the provided tasks to ensure
correct execution order based on dependencies, and sets these as the current tasks.
2. If there are existing tasks, it merges the new tasks with the existing ones. It maintains
any common prefix of tasks (based on task_id and instruction) and appends the remainder
of the new tasks. The current task is updated to the first unfinished task in this merged list.
Args:
tasks (list[Task]): A list of tasks (may be unordered) to add to the plan.
Returns:
None: The method updates the internal state of the plan but does not return anything.
"""
if not tasks:
return
# Topologically sort the new tasks to ensure correct dependency order
new_tasks = self._topological_sort(tasks)
if not self.tasks:
# If there are no existing tasks, set the new tasks as the current tasks
self.tasks = new_tasks
else:
# Find the length of the common prefix between existing and new tasks
prefix_length = 0
for old_task, new_task in zip(self.tasks, new_tasks):
if old_task.task_id != new_task.task_id or old_task.instruction != new_task.instruction:
break
prefix_length += 1
# Combine the common prefix with the remainder of the new tasks
final_tasks = self.tasks[:prefix_length] + new_tasks[prefix_length:]
self.tasks = final_tasks
# Update current_task_id to the first unfinished task in the merged list
self._update_current_task()
# Update the task map for quick access to tasks by ID
self.task_map = {task.task_id: task for task in self.tasks}
def reset_task(self, task_id: str):
"""
Clear code and result of the task based on task_id, and set the task as unfinished.
Args:
task_id (str): The ID of the task to be reset.
Returns:
None
"""
if task_id in self.task_map:
task = self.task_map[task_id]
task.reset()
def replace_task(self, new_task: Task):
"""
Replace an existing task with the new input task based on task_id, and reset all tasks depending on it.
Args:
new_task (Task): The new task that will replace an existing one.
Returns:
None
"""
assert new_task.task_id in self.task_map
# Replace the task in the task map and the task list
self.task_map[new_task.task_id] = new_task
for i, task in enumerate(self.tasks):
if task.task_id == new_task.task_id:
self.tasks[i] = new_task
break
# Reset dependent tasks
for task in self.tasks:
if new_task.task_id in task.dependent_task_ids:
self.reset_task(task.task_id)
def append_task(self, new_task: Task):
"""
Append a new task to the end of existing task sequences
Args:
new_task (Task): The new task to be appended to the existing task sequence
Returns:
None
"""
assert not self.has_task_id(new_task.task_id), "Task already in current plan, use replace_task instead"
assert all(
[self.has_task_id(dep_id) for dep_id in new_task.dependent_task_ids]
), "New task has unknown dependencies"
# Existing tasks do not depend on the new task, it's fine to put it to the end of the sorted task sequence
self.tasks.append(new_task)
self.task_map[new_task.task_id] = new_task
self._update_current_task()
def has_task_id(self, task_id: str) -> bool:
return task_id in self.task_map
def _update_current_task(self):
current_task_id = ""
for task in self.tasks:
if not task.is_finished:
current_task_id = task.task_id
break
self.current_task_id = current_task_id # all tasks finished
@property
def current_task(self) -> Task:
"""Find current task to execute
Returns:
Task: the current task to be executed
"""
return self.task_map.get(self.current_task_id, None)
def finish_current_task(self):
"""Finish current task, set Task.is_finished=True, set current task to next task"""
if self.current_task_id:
self.current_task.is_finished = True
self._update_current_task() # set to next task
def get_finished_tasks(self) -> list[Task]:
"""return all finished tasks in correct linearized order
Returns:
list[Task]: list of finished tasks
"""
return [task for task in self.tasks if task.is_finished]
class MessageQueue(BaseModel):
"""Message queue which supports asynchronous updates."""
@ -417,6 +612,7 @@ class CodingContext(BaseContext):
design_doc: Optional[Document] = None
task_doc: Optional[Document] = None
code_doc: Optional[Document] = None
code_plan_and_change_doc: Optional[Document] = None
class TestingContext(BaseContext):
@ -470,6 +666,29 @@ class BugFixContext(BaseContext):
filename: str = ""
class CodePlanAndChangeContext(BaseModel):
requirement: str = ""
prd_filename: str = ""
design_filename: str = ""
task_filename: str = ""
@staticmethod
def loads(filenames: List, **kwargs) -> CodePlanAndChangeContext:
ctx = CodePlanAndChangeContext(requirement=kwargs.get("requirement", ""))
for filename in filenames:
filename = Path(filename)
if filename.is_relative_to(PRDS_FILE_REPO):
ctx.prd_filename = filename.name
continue
if filename.is_relative_to(SYSTEM_DESIGN_FILE_REPO):
ctx.design_filename = filename.name
continue
if filename.is_relative_to(TASK_FILE_REPO):
ctx.task_filename = filename.name
continue
return ctx
# mermaid class view
class ClassMeta(BaseModel):
name: str = ""

View file

@ -76,7 +76,7 @@ class Team(BaseModel):
def hire(self, roles: list[Role]):
"""Hire roles to cooperate"""
only_role = roles[0]
roles[0]
self.env.add_roles(roles)
@property
@ -134,4 +134,4 @@ class Team(BaseModel):
await self.env.run()
self.env.archive(auto_archive)
return self.env.history
return self.env.history

View file

@ -24,13 +24,16 @@ import re
import sys
import traceback
import typing
from io import BytesIO
from pathlib import Path
from typing import Any, List, Tuple, Union, Callable
from typing import Any, Callable, List, Tuple, Union
import aiofiles
import loguru
import requests
from PIL import Image
from pydantic_core import to_jsonable_python
from tenacity import RetryCallState, _utils
from tenacity import RetryCallState, RetryError, _utils
from metagpt.const import MESSAGE_ROUTE_TO_ALL
from metagpt.logs import logger
@ -214,7 +217,7 @@ class OutputParser:
if start_index != -1 and end_index != -1:
# Extract the structure part
structure_text = text[start_index: end_index + 1]
structure_text = text[start_index : end_index + 1]
try:
# Attempt to convert the text to a Python data type using ast.literal_eval
@ -358,6 +361,31 @@ def parse_recipient(text):
return ""
def create_func_call_config(func_schema: dict) -> dict:
"""Create new function call config"""
tools = [{"type": "function", "function": func_schema}]
tool_choice = {"type": "function", "function": {"name": func_schema["name"]}}
return {
"tools": tools,
"tool_choice": tool_choice,
}
def remove_comments(code_str: str) -> str:
"""Remove comments from code."""
pattern = r"(\".*?\"|\'.*?\')|(\#.*?$)"
def replace_func(match):
if match.group(2) is not None:
return ""
else:
return match.group(1)
clean_code = re.sub(pattern, replace_func, code_str, flags=re.MULTILINE)
clean_code = os.linesep.join([s.rstrip() for s in clean_code.splitlines() if s.strip()])
return clean_code
def get_class_name(cls) -> str:
"""Return class name"""
return f"{cls.__module__}.{cls.__name__}"
@ -466,13 +494,13 @@ def read_json_file(json_file: str, encoding="utf-8") -> list[Any]:
return data
def write_json_file(json_file: str, data: list, encoding=None):
def write_json_file(json_file: str, data: list, encoding: str = None, indent: int = 4):
folder_path = Path(json_file).parent
if not folder_path.exists():
folder_path.mkdir(parents=True, exist_ok=True)
with open(json_file, "w", encoding=encoding) as fout:
json.dump(data, fout, ensure_ascii=False, indent=4, default=to_jsonable_python)
json.dump(data, fout, ensure_ascii=False, indent=indent, default=to_jsonable_python)
def read_csv_to_list(curr_file: str, header=False, strip_trail=True):
@ -538,7 +566,7 @@ def role_raise_decorator(func):
self.rc.memory.delete(self.latest_observed_msg)
# raise again to make it captured outside
raise Exception(format_trackback_info(limit=None))
except Exception:
except Exception as e:
if self.latest_observed_msg:
logger.warning(
"There is a exception in role's execution, in order to resume, "
@ -547,6 +575,12 @@ def role_raise_decorator(func):
# remove role newest observed msg to make it observed again
self.rc.memory.delete(self.latest_observed_msg)
# raise again to make it captured outside
if isinstance(e, RetryError):
last_error = e.last_attempt._exception
name = any_to_str(last_error)
if re.match(r"^openai\.", name) or re.match(r"^httpx\.", name):
raise last_error
raise Exception(format_trackback_info(limit=None))
return wrapper
@ -606,6 +640,39 @@ def is_coroutine_func(func: Callable) -> bool:
return inspect.iscoroutinefunction(func)
def encode_image(image_path: Path, encoding: str = "utf-8") -> str:
with open(str(image_path), "rb") as image_file:
return base64.b64encode(image_file.read()).decode(encoding)
def load_mc_skills_code(skill_names: list[str] = None, skills_dir: Path = None) -> list[str]:
"""load mincraft skill from js files"""
if not skills_dir:
skills_dir = Path(__file__).parent.absolute()
if skill_names is None:
skill_names = [skill[:-3] for skill in os.listdir(f"{skills_dir}") if skill.endswith(".js")]
skills = [skills_dir.joinpath(f"{skill_name}.js").read_text() for skill_name in skill_names]
return skills
def encode_image(image_path_or_pil: Union[Path, Image], encoding: str = "utf-8") -> str:
"""encode image from file or PIL.Image into base64"""
if isinstance(image_path_or_pil, Image.Image):
buffer = BytesIO()
image_path_or_pil.save(buffer, format="JPEG")
bytes_data = buffer.getvalue()
else:
if not image_path_or_pil.exists():
raise FileNotFoundError(f"{image_path_or_pil} not exists")
with open(str(image_path_or_pil), "rb") as image_file:
bytes_data = image_file.read()
return base64.b64encode(bytes_data).decode(encoding)
def decode_image(img_url_or_b64: str) -> Image:
"""decode image from url or base64 into PIL.Image"""
if img_url_or_b64.startswith("http"):
# image http(s) url
resp = requests.get(img_url_or_b64)
img = Image.open(BytesIO(resp.content))
else:
# image b64_json
b64_data = re.sub("^data:image/.+;base64,", "", img_url_or_b64)
img_data = BytesIO(base64.b64decode(b64_data))
img = Image.open(img_data)
return img

View file

@ -4,8 +4,8 @@
import pytest
from metagpt.environment.base_env import Env, mark_as_writeable, mark_as_readable
from metagpt.environment.api.env_api import EnvAPIAbstract
from metagpt.environment.base_env import Env, mark_as_readable, mark_as_writeable
class ForTestEnv(Env):