mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-06-14 15:25:17 +02:00
Update env & test code
This commit is contained in:
parent
07c360b9c7
commit
cfc0cc1fa5
12 changed files with 248 additions and 140 deletions
|
|
@ -16,3 +16,6 @@ ## Free Your Hands
|
|||
### By Text
|
||||
|
||||
### By Voice
|
||||
|
||||
## Run It
|
||||
python run_assistant.py "Create a contact in Contacts App named zjy with a phone number +86 18831933368" --mode "manual" --app-name "Contacts"
|
||||
|
|
@ -33,7 +33,8 @@ class ManualRecord(Action):
|
|||
screenshot_after_path: Path = ""
|
||||
xml_path: Path = ""
|
||||
|
||||
async def run(self, demo_name: str, task_desc: str,task_dir: Path, env: AndroidEnv):
|
||||
# async def run(self, demo_name: str, task_desc: str,task_dir: Path, env: AndroidEnv):
|
||||
async def run(self, task_desc: str, task_dir: Path, env: AndroidEnv):
|
||||
|
||||
self.record_path = Path(task_dir) / "record.txt"
|
||||
self.task_desc_path = Path(task_dir) / "task_desc.txt"
|
||||
|
|
@ -53,16 +54,18 @@ class ManualRecord(Action):
|
|||
step = 0
|
||||
while True:
|
||||
step += 1
|
||||
screenshot_path: Path = env.observe(
|
||||
screenshot_path: Path = await env.observe(
|
||||
EnvAPIAbstract(
|
||||
api_name="get_screenshot",
|
||||
kwargs={"ss_name": f"{demo_name}_{step}", "local_save_dir": self.screenshot_before_path}
|
||||
# kwargs={"ss_name": f"{demo_name}_{step}", "local_save_dir": self.screenshot_before_path}
|
||||
kwargs={"ss_name": f"{step}", "local_save_dir": self.screenshot_before_path}
|
||||
)
|
||||
)
|
||||
xml_path: Path = env.observe(
|
||||
xml_path: Path = await env.observe(
|
||||
EnvAPIAbstract(
|
||||
api_name="get_xml",
|
||||
kwargs={"xml_name": f"{demo_name}_{step}", "local_save_dir": self.xml_path}
|
||||
# kwargs={"xml_name": f"{demo_name}_{step}", "local_save_dir": self.xml_path}
|
||||
kwargs={"xml_name": f"{step}", "local_save_dir": self.xml_path}
|
||||
)
|
||||
)
|
||||
if not screenshot_path.exists() or not xml_path.exists():
|
||||
|
|
@ -86,14 +89,13 @@ class ManualRecord(Action):
|
|||
bbox = e.bbox
|
||||
center_ = (bbox[0][0] + bbox[1][0]) // 2, (bbox[0][1] + bbox[1][1]) // 2
|
||||
dist = (abs(center[0] - center_[0]) ** 2 + abs(center[1] - center_[1]) ** 2) ** 0.5
|
||||
# TODO Modify config to default 30. It should be modified back config after single action test
|
||||
# if dist <= config.get_other("min_dist"):
|
||||
if dist <= 30:
|
||||
if dist <= config.get_other("min_dist"):
|
||||
close = True
|
||||
break
|
||||
if not close:
|
||||
elem_list.append(elem)
|
||||
screenshot_labeled_path = Path(self.screenshot_after_path).joinpath(f"{demo_name}_{step}_labeled.png")
|
||||
screenshot_labeled_path = Path(self.screenshot_after_path).joinpath(f"{step}_labeled.png")
|
||||
# screenshot_labeled_path = Path(self.screenshot_after_path).joinpath(f"{demo_name}_{step}_labeled.png")
|
||||
labeled_img = draw_bbox_multi(screenshot_path, screenshot_labeled_path, elem_list)
|
||||
|
||||
cv2.imshow("image", labeled_img)
|
||||
|
|
@ -142,7 +144,7 @@ class ManualRecord(Action):
|
|||
user_input = ""
|
||||
while not user_input:
|
||||
user_input = input()
|
||||
env.step(EnvAPIAbstract(api_name="user_input", kwargs={"input_txt": user_input}))
|
||||
await env.step(EnvAPIAbstract(api_name="user_input", kwargs={"input_txt": user_input}))
|
||||
record_file.write(f'text({input_area}:sep:"{user_input}"):::{elem_list[int(input_area) - 1].uid}\n')
|
||||
elif user_input.lower() == ActionOp.LONG_PRESS.value:
|
||||
logger.info(
|
||||
|
|
|
|||
|
|
@ -38,9 +38,9 @@ class ParseRecord(Action):
|
|||
screenshot_before_path: Path = ""
|
||||
screenshot_after_path: Path = ""
|
||||
|
||||
async def run(self, app_name: str, demo_name: str, task_dir: Path, docs_dir: Path, env: AndroidEnv):
|
||||
if not docs_dir.exists():
|
||||
docs_dir.mkdir(parents=True, exist_ok=True)
|
||||
# async def run(self, app_name: str, demo_name: str, task_dir: Path, docs_dir: Path, env: AndroidEnv):
|
||||
async def run(self, app_name: str, task_dir: Path, docs_dir: Path, env: AndroidEnv):
|
||||
docs_dir.mkdir(parents=True, exist_ok=True)
|
||||
doc_count = 0
|
||||
self.record_path = Path(task_dir) / "record.txt"
|
||||
self.task_desc_path = Path(task_dir) / "task_desc.txt"
|
||||
|
|
@ -51,8 +51,10 @@ class ParseRecord(Action):
|
|||
record_step_count = len(record_file.readlines()) - 1
|
||||
record_file.seek(0)
|
||||
for step in range(1, record_step_count + 1):
|
||||
img_before_base64 = encode_image(self.screenshot_after_path.joinpath(f"{demo_name}_{step}_labeled.png"))
|
||||
img_after_base64 = encode_image(self.screenshot_after_path.joinpath(f"{demo_name}_{step + 1}_labeled.png"))
|
||||
# img_before_base64 = encode_image(self.screenshot_after_path.joinpath(f"{demo_name}_{step}_labeled.png"))
|
||||
# img_after_base64 = encode_image(self.screenshot_after_path.joinpath(f"{demo_name}_{step + 1}_labeled.png"))
|
||||
img_before_base64 = encode_image(self.screenshot_after_path.joinpath(f"{step}_labeled.png"))
|
||||
img_after_base64 = encode_image(self.screenshot_after_path.joinpath(f"{step + 1}_labeled.png"))
|
||||
rec = record_file.readline().strip()
|
||||
action, resource_id = rec.split(":::")
|
||||
action_type = action.split("(")[0]
|
||||
|
|
@ -110,8 +112,8 @@ class ParseRecord(Action):
|
|||
)
|
||||
if "error" in node.content:
|
||||
return AndroidActionOutput(action_state=RunState.FAIL)
|
||||
|
||||
log_path = task_dir.joinpath(f"log_{app_name}_{demo_name}.txt")
|
||||
# log_path = task_dir.joinpath(f"log_{app_name}_{demo_name}.txt")
|
||||
log_path = task_dir.joinpath(f"log_{app_name}.txt")
|
||||
prompt = node.compile(context=context, schema="json", mode="auto")
|
||||
msg = node.content
|
||||
doc_content[action_type] = msg
|
||||
|
|
|
|||
|
|
@ -92,13 +92,13 @@ class ScreenshotParse(Action):
|
|||
if not path.exists():
|
||||
path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
screenshot_path: Path = env.observe(
|
||||
screenshot_path: Path = await env.observe(
|
||||
EnvAPIAbstract(
|
||||
api_name="get_screenshot",
|
||||
kwargs={"ss_name": f"{round_count}_before", "local_save_dir": task_dir}
|
||||
)
|
||||
)
|
||||
xml_path: Path = env.observe(
|
||||
xml_path: Path = await env.observe(
|
||||
EnvAPIAbstract(
|
||||
api_name="get_xml",
|
||||
kwargs={"xml_name": f"{round_count}", "local_save_dir": task_dir}
|
||||
|
|
@ -121,9 +121,7 @@ class ScreenshotParse(Action):
|
|||
bbox = e.bbox
|
||||
center_ = (bbox[0][0] + bbox[1][0]) // 2, (bbox[0][1] + bbox[1][1]) // 2
|
||||
dist = (abs(center[0] - center_[0]) ** 2 + abs(center[1] - center_[1]) ** 2) ** 0.5
|
||||
# TODO Modify config to default 30. It should be modified back config after single action test
|
||||
# if dist <= config.get_other("min_dist"):
|
||||
if dist <= 30:
|
||||
if dist <= config.get_other("min_dist"):
|
||||
close = True
|
||||
break
|
||||
if not close:
|
||||
|
|
@ -156,21 +154,21 @@ class ScreenshotParse(Action):
|
|||
|
||||
if isinstance(op_param, TapOp):
|
||||
x, y = elem_bbox_to_xy(elem_list[op_param.area - 1].bbox)
|
||||
res = env.step(EnvAPIAbstract(api_name="system_tap", kwargs={"x": x, "y": y}))
|
||||
res = await env.step(EnvAPIAbstract(api_name="system_tap", kwargs={"x": x, "y": y}))
|
||||
if res == ADB_EXEC_FAIL:
|
||||
return AndroidActionOutput(action_state=RunState.FAIL)
|
||||
elif isinstance(op_param, TextOp):
|
||||
res = env.step(EnvAPIAbstract(api_name="user_input", kwargs={"input_txt": op_param.input_str}))
|
||||
res = await env.step(EnvAPIAbstract(api_name="user_input", kwargs={"input_txt": op_param.input_str}))
|
||||
if res == ADB_EXEC_FAIL:
|
||||
return AndroidActionOutput(action_state=RunState.FAIL)
|
||||
elif isinstance(op_param, LongPressOp):
|
||||
x, y = elem_bbox_to_xy(elem_list[op_param.area - 1].bbox)
|
||||
res = env.step(EnvAPIAbstract(api_name="user_longpress", kwargs={"x": x, "y": y}))
|
||||
res = await env.step(EnvAPIAbstract(api_name="user_longpress", kwargs={"x": x, "y": y}))
|
||||
if res == ADB_EXEC_FAIL:
|
||||
return AndroidActionOutput(action_state=RunState.FAIL)
|
||||
elif isinstance(op_param, SwipeOp_3):
|
||||
x, y = elem_bbox_to_xy(elem_list[op_param.area - 1].bbox)
|
||||
res = env.step(
|
||||
res = await env.step(
|
||||
EnvAPIAbstract(
|
||||
api_name="user_swipe",
|
||||
kwargs={"x": x, "y": y, "orient": op_param.swipe_orient, "dist": op_param.dist}
|
||||
|
|
@ -183,18 +181,18 @@ class ScreenshotParse(Action):
|
|||
elif isinstance(op_param, TapGridOp) or isinstance(op_param, LongPressGridOp):
|
||||
x, y = area_to_xy(op_param.area, op_param.subarea, env.width, env.height, env.rows, env.cols)
|
||||
if isinstance(op_param, TapGridOp):
|
||||
res = env.step(EnvAPIAbstract(api_name="system_tap", kwargs={"x": x, "y": y}))
|
||||
res = await env.step(EnvAPIAbstract(api_name="system_tap", kwargs={"x": x, "y": y}))
|
||||
if res == ADB_EXEC_FAIL:
|
||||
return AndroidActionOutput(action_state=RunState.FAIL)
|
||||
else:
|
||||
# LongPressGridOp
|
||||
res = env.step(EnvAPIAbstract(api_name="user_longpress", kwargs={"x": x, "y": y}))
|
||||
res = await env.step(EnvAPIAbstract(api_name="user_longpress", kwargs={"x": x, "y": y}))
|
||||
if res == ADB_EXEC_FAIL:
|
||||
return AndroidActionOutput(action_state=RunState.FAIL)
|
||||
elif isinstance(op_param, SwipeGridOp):
|
||||
start_x, start_y = area_to_xy(op_param.start_area, op_param.start_subarea, width, height, rows, cols)
|
||||
end_x, end_y = area_to_xy(op_param.end_area, op_param.end_subarea, width, height, rows, cols)
|
||||
res = env.step(
|
||||
res = await env.step(
|
||||
EnvAPIAbstract(api_name="user_swipe_to", kwargs={"start": (start_x, start_y), "end": (end_x, end_y)}))
|
||||
if res == ADB_EXEC_FAIL:
|
||||
return AndroidActionOutput(action_state=RunState.FAIL)
|
||||
|
|
|
|||
|
|
@ -26,6 +26,7 @@ from examples.andriod_assistant.utils.schema import (
|
|||
ReflectLogItem,
|
||||
RunState,
|
||||
SwipeOp,
|
||||
SwipeOp_3,
|
||||
TapOp,
|
||||
TextOp,
|
||||
)
|
||||
|
|
@ -70,12 +71,12 @@ class SelfLearnAndReflect(Action):
|
|||
async def run_self_learn(
|
||||
self, round_count: int, task_desc: str, last_act: str, task_dir: Path, env: AndroidEnv
|
||||
) -> AndroidActionOutput:
|
||||
screenshot_path: Path = env.observe(
|
||||
screenshot_path: Path = await env.observe(
|
||||
EnvAPIAbstract(
|
||||
api_name="get_screenshot", kwargs={"ss_name": f"{round_count}_before", "local_save_dir": task_dir}
|
||||
)
|
||||
)
|
||||
xml_path: Path = env.observe(
|
||||
xml_path: Path = await env.observe(
|
||||
EnvAPIAbstract(api_name="get_xml", kwargs={"xml_name": f"{round_count}", "local_save_dir": task_dir})
|
||||
)
|
||||
if not screenshot_path.exists() or not xml_path.exists():
|
||||
|
|
@ -100,9 +101,7 @@ class SelfLearnAndReflect(Action):
|
|||
bbox = e.bbox
|
||||
center_ = (bbox[0][0] + bbox[1][0]) // 2, (bbox[0][1] + bbox[1][1]) // 2
|
||||
dist = (abs(center[0] - center_[0]) ** 2 + abs(center[1] - center_[1]) ** 2) ** 0.5
|
||||
# TODO Modify config to default 30. It should be modified back config after single action test
|
||||
# if dist <= config.get_other("min_dist"):
|
||||
if dist <= 30:
|
||||
if dist <= config.get_other("min_dist"):
|
||||
close = True
|
||||
break
|
||||
if not close:
|
||||
|
|
@ -125,7 +124,6 @@ class SelfLearnAndReflect(Action):
|
|||
OpLogItem(step=round_count, prompt=prompt, image=str(screenshot_before_labeled_path), response=node.content)
|
||||
op_param = screenshot_parse_extract(node.instruct_content.model_dump(), grid_on=False)
|
||||
# TODO Modify Op_param. When op_param.action is FINISH, how to solve this ?
|
||||
logger.info(op_param)
|
||||
if op_param.param_state == RunState.FINISH:
|
||||
return AndroidActionOutput(action_state=RunState.FINISH)
|
||||
if op_param.param_state == RunState.FAIL:
|
||||
|
|
@ -134,26 +132,26 @@ class SelfLearnAndReflect(Action):
|
|||
if isinstance(op_param, TapOp):
|
||||
self.ui_area = op_param.area
|
||||
x, y = elem_bbox_to_xy(elem_list[op_param.area - 1].bbox)
|
||||
res = env.step(EnvAPIAbstract(api_name="system_tap", kwargs={"x": x, "y": y}))
|
||||
res = await env.step(EnvAPIAbstract(api_name="system_tap", kwargs={"x": x, "y": y}))
|
||||
if res == ADB_EXEC_FAIL:
|
||||
return AndroidActionOutput(action_state=RunState.FAIL)
|
||||
elif isinstance(op_param, TextOp):
|
||||
res = env.step(EnvAPIAbstract(api_name="user_input", kwargs={"input_txt": op_param.input_str}))
|
||||
res = await env.step(EnvAPIAbstract(api_name="user_input", kwargs={"input_txt": op_param.input_str}))
|
||||
if res == ADB_EXEC_FAIL:
|
||||
return AndroidActionOutput(action_state=RunState.FAIL)
|
||||
elif isinstance(op_param, LongPressOp):
|
||||
self.ui_area = op_param.area
|
||||
x, y = elem_bbox_to_xy(elem_list[op_param.area - 1].bbox)
|
||||
res = env.step(EnvAPIAbstract(api_name="user_longpress", kwargs={"x": x, "y": y}))
|
||||
res = await env.step(EnvAPIAbstract(api_name="user_longpress", kwargs={"x": x, "y": y}))
|
||||
if res == ADB_EXEC_FAIL:
|
||||
return AndroidActionOutput(action_state=RunState.FAIL)
|
||||
elif isinstance(op_param, SwipeOp):
|
||||
elif isinstance(op_param, SwipeOp_3):
|
||||
self.ui_area = op_param.area
|
||||
self.swipe_orient = op_param.swipe_orient
|
||||
x, y = elem_bbox_to_xy(elem_list[op_param.area - 1].bbox)
|
||||
res = env.step(
|
||||
res = await env.step(
|
||||
EnvAPIAbstract(
|
||||
"user_swipe", kwargs={"x": x, "y": y, "orient": op_param.swipe_orient, "dist": op_param.dist}
|
||||
api_name="user_swipe", kwargs={"x": x, "y": y, "orient": op_param.swipe_orient, "dist": op_param.dist}
|
||||
)
|
||||
)
|
||||
if res == ADB_EXEC_FAIL:
|
||||
|
|
@ -167,8 +165,7 @@ class SelfLearnAndReflect(Action):
|
|||
async def run_reflect(
|
||||
self, round_count: int, task_desc: str, last_act: str, task_dir: Path, docs_dir: Path, env: AndroidEnv
|
||||
) -> AndroidActionOutput:
|
||||
logger.info("run_reflect")
|
||||
screenshot_path: Path = env.observe(
|
||||
screenshot_path: Path = await env.observe(
|
||||
EnvAPIAbstract(
|
||||
api_name="get_screenshot", kwargs={"ss_name": f"{round_count}_after", "local_save_dir": task_dir}
|
||||
)
|
||||
|
|
@ -180,7 +177,6 @@ class SelfLearnAndReflect(Action):
|
|||
draw_bbox_multi(screenshot_path, screenshot_after_labeled_path, elem_list=self.elem_list)
|
||||
img_base64 = encode_image(screenshot_after_labeled_path)
|
||||
|
||||
logger.info(f"act_name: {self.act_name}")
|
||||
if self.act_name == ActionOp.TAP.value:
|
||||
action = "tapping"
|
||||
elif self.act_name == ActionOp.LONG_PRESS.value:
|
||||
|
|
@ -225,7 +221,7 @@ class SelfLearnAndReflect(Action):
|
|||
self.useless_list.append(resource_id)
|
||||
last_act = "NONE"
|
||||
if op_param.decision == Decision.BACK.value:
|
||||
res = env.step(EnvAPIAbstract(api_name="system_back"))
|
||||
res = await env.step(EnvAPIAbstract(api_name="system_back"))
|
||||
if res == ADB_EXEC_FAIL:
|
||||
return AndroidActionOutput(action_state=RunState.FAIL)
|
||||
doc = op_param.documentation
|
||||
|
|
|
|||
|
|
@ -37,7 +37,8 @@ class AndroidAssistant(Role):
|
|||
self._watch([UserRequirement])
|
||||
|
||||
app_name = config.get_other("app_name", "demo")
|
||||
data_dir = Path(__file__).parent.joinpath("..", "output")
|
||||
curr_path = Path(__file__).parent
|
||||
data_dir = curr_path.joinpath("..", "output")
|
||||
cur_datetime = datetime.fromtimestamp(int(time.time())).strftime("%Y-%m-%d_%H-%M-%S")
|
||||
|
||||
"""Firstly, we decide the state with user config, further, we can do it automatically, like if it's new app,
|
||||
|
|
@ -67,39 +68,57 @@ class AndroidAssistant(Role):
|
|||
self._set_react_mode(RoleReactMode.BY_ORDER)
|
||||
|
||||
def _check_dir(self):
|
||||
self.task_dir.mkdir(exist_ok=True)
|
||||
self.docs_dir.mkdir(exist_ok=True)
|
||||
self.task_dir.mkdir(parents=True, exist_ok=True)
|
||||
self.docs_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
async def react(self) -> Message:
|
||||
self.round_count += 1
|
||||
super().react()
|
||||
result = await super().react()
|
||||
print(f"react result {result}")
|
||||
return result
|
||||
|
||||
async def _act(self) -> Message:
|
||||
# Question: How to achieve self_learn's loop action ?
|
||||
logger.info(f"{self._setting}: to do {self.rc.todo}({self.rc.todo.name})")
|
||||
todo = self.rc.todo
|
||||
send_to = ""
|
||||
if isinstance(todo, ManualRecord):
|
||||
resp = await todo.run()
|
||||
resp = await todo.run(
|
||||
# demo_name="",
|
||||
task_dir=self.task_dir,
|
||||
task_desc=self.task_desc,
|
||||
env=self.rc.env
|
||||
)
|
||||
elif isinstance(todo, ParseRecord):
|
||||
resp = await todo.run()
|
||||
resp = await todo.run(
|
||||
app_name=config.get_other("app_name", "demo"),
|
||||
task_dir=self.task_dir,
|
||||
docs_dir=self.docs_dir,
|
||||
env=self.rc.env
|
||||
)
|
||||
elif isinstance(todo, SelfLearnAndReflect):
|
||||
resp = await todo.run(round_count=self.round_count,
|
||||
task_desc=self.task_desc,
|
||||
last_act=self.last_act,
|
||||
task_dir=self.task_dir,
|
||||
docs_dir=self.docs_dir,
|
||||
env=self.rc.env)
|
||||
resp = await todo.run(
|
||||
round_count=self.round_count,
|
||||
task_desc=self.task_desc,
|
||||
last_act=self.last_act,
|
||||
task_dir=self.task_dir,
|
||||
docs_dir=self.docs_dir,
|
||||
env=self.rc.env
|
||||
)
|
||||
if resp.action_state == RunState.SUCCESS:
|
||||
self.last_act = resp.data.get("last_act")
|
||||
send_to = self.name
|
||||
|
||||
elif isinstance(todo, ScreenshotParse):
|
||||
resp = await todo.run(round_count=self.round_count,
|
||||
task_desc=self.task_desc,
|
||||
last_act=self.last_act,
|
||||
task_dir=self.task_dir,
|
||||
grid_on=self.grid_on,
|
||||
env=self.rc.env)
|
||||
resp = await todo.run(
|
||||
round_count=self.round_count,
|
||||
task_desc=self.task_desc,
|
||||
last_act=self.last_act,
|
||||
task_dir=self.task_dir,
|
||||
docs_dir=self.docs_dir,
|
||||
grid_on=self.grid_on,
|
||||
env=self.rc.env
|
||||
)
|
||||
if resp.action_state == RunState.SUCCESS:
|
||||
self.grid_on = resp.data.get("grid_on")
|
||||
send_to = self.name
|
||||
|
|
|
|||
|
|
@ -50,12 +50,13 @@ def startup(
|
|||
)
|
||||
|
||||
team = Team(env=AndroidEnv())
|
||||
team.hire([AndroidAssistant])
|
||||
team.hire([AndroidAssistant()])
|
||||
team.invest(investment)
|
||||
company.run_project(idea=task_desc)
|
||||
team.run_project(idea=task_desc)
|
||||
|
||||
asyncio.run(team.run(n_round=n_round))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
app()
|
||||
# Command python run_assistant.py "Create a contact in Contacts App named zjy with a phone number +86 18831933368"
|
||||
|
|
@ -1,6 +1,6 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc : test on android emulator
|
||||
# @Desc : test actions on an Android emulator. Deprecated now that the Role test has been modified.
|
||||
import asyncio
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
|
@ -50,14 +50,14 @@ if __name__ == "__main__":
|
|||
env=test_env_self_learn_android
|
||||
),
|
||||
test_manual_record.run(
|
||||
demo_name=DEMO_NAME,
|
||||
# demo_name=DEMO_NAME,
|
||||
task_dir=TASK_PATH / "demos" / f"manual_record_{DEMO_NAME}",
|
||||
task_desc="Create a contact in Contacts App named zjy with a phone number +86 18831933368 ",
|
||||
env=test_env_manual_learn_android
|
||||
),
|
||||
test_manual_parse.run(
|
||||
app_name="Contacts",
|
||||
demo_name=DEMO_NAME,
|
||||
# demo_name=DEMO_NAME,
|
||||
task_dir=TASK_PATH / "demos" / f"manual_record_{DEMO_NAME}", # needs to be modified
|
||||
docs_dir=PARSE_RECORD_DOC_PATH, # needs to be modified
|
||||
env=test_env_manual_learn_android
|
||||
|
|
|
|||
|
|
@ -122,8 +122,12 @@ class Config(CLIParams, YamlModel):
|
|||
def set_other(self, other: dict):
|
||||
self.other = other
|
||||
|
||||
def get_other(self, key: str):
|
||||
return self.other.get(key)
|
||||
def get_other(self, key: str, default_value: Optional[str] = None):
    """Return the value stored under ``key`` in ``self.other``.

    Args:
        key: Name of the extra configuration entry to look up.
        default_value: Value returned when ``key`` is absent. Defaults to
            ``None``, matching plain ``dict.get(key)``.

    Returns:
        The stored value, or ``default_value`` when the key is missing.
    """
    # dict.get(key, None) is identical to dict.get(key), so no branching on
    # the default is needed.
    return self.other.get(key, default_value)
|
||||
|
||||
|
||||
def get_openai_llm(self) -> Optional[LLMConfig]:
|
||||
"""Get OpenAI LLMConfig by name. If no OpenAI, raise Exception"""
|
||||
|
|
|
|||
|
|
@ -1,7 +1,8 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc :
|
||||
|
||||
# TODO
|
||||
from metagpt.environment.base_env import Environment
|
||||
from metagpt.environment.android_env.android_env import AndroidEnv
|
||||
from metagpt.environment.gym_env.gym_env import GymEnv
|
||||
from metagpt.environment.mincraft_env.mincraft_env import MincraftExtEnv
|
||||
|
|
|
|||
|
|
@ -9,10 +9,10 @@ from typing import Any, Optional
|
|||
from pydantic import Field
|
||||
|
||||
from metagpt.const import ADB_EXEC_FAIL
|
||||
from metagpt.environment.base_env import Env, ExtEnv, mark_as_readable, mark_as_writeable
|
||||
from metagpt.environment.base_env import Environment, ExtEnv, mark_as_readable, mark_as_writeable
|
||||
|
||||
|
||||
class AndroidExtEnv(Env, ExtEnv):
|
||||
class AndroidExtEnv(Environment, ExtEnv):
|
||||
device_id: Optional[str] = Field(default=None)
|
||||
screenshot_dir: Optional[Path] = Field(default=None)
|
||||
xml_dir: Optional[Path] = Field(default=None)
|
||||
|
|
|
|||
|
|
@ -2,25 +2,29 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# @Desc : base env of executing environment
|
||||
|
||||
import asyncio
|
||||
from enum import Enum
|
||||
from typing import Optional, Union, Any
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Dict, Iterable, Optional, Set, Union
|
||||
from pydantic import BaseModel, ConfigDict, Field, SerializeAsAny, model_validator
|
||||
from metagpt.context import Context
|
||||
from metagpt.environment.api.env_api import (
|
||||
EnvAPIAbstract,
|
||||
ReadAPIRegistry,
|
||||
WriteAPIRegistry,
|
||||
)
|
||||
from metagpt.logs import logger
|
||||
from metagpt.schema import Message
|
||||
from metagpt.utils.common import get_function_schema, is_coroutine_func
|
||||
from metagpt.utils.common import get_function_schema, is_coroutine_func, is_send_to
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from metagpt.roles.role import Role # noqa: F401
|
||||
|
||||
|
||||
class EnvType(Enum):
|
||||
ANDROID = "Android"
|
||||
GYM = "Gym"
|
||||
WEREWOLF = "Werewolf"
|
||||
MINCRAFT = "Minsraft"
|
||||
MINCRAFT = "Mincraft"
|
||||
STANFORDTOWN = "StanfordTown"
|
||||
|
||||
|
||||
|
|
@ -28,49 +32,25 @@ env_write_api_registry = WriteAPIRegistry()
|
|||
env_read_api_registry = ReadAPIRegistry()
|
||||
|
||||
|
||||
# def mark_as_readable(func):
|
||||
# """mark functionn as a readable one in ExtEnv, it observes something from ExtEnv"""
|
||||
#
|
||||
# def wrapper(self: ExtEnv, *args, **kwargs):
|
||||
# api_name = func.__name__
|
||||
# self.read_api_registry[api_name] = func
|
||||
# return func(self, *args, **kwargs)
|
||||
#
|
||||
# return wrapper
|
||||
#
|
||||
# def mark_as_writeable(func):
|
||||
# """mark functionn as a writeable one in ExtEnv, it does something to ExtEnv"""
|
||||
#
|
||||
# def wrapper(self: ExtEnv, *args, **kwargs):
|
||||
# api_name = func.__name__
|
||||
# self.write_api_registry[api_name] = func
|
||||
# return func(self, *args, **kwargs)
|
||||
#
|
||||
# return wrapper
|
||||
|
||||
def mark_as_readable(func):
|
||||
"""mark function as a readable one in ExtEnv, it observes something from ExtEnv"""
|
||||
"""mark function as a readable one in ExtEnv, it observes something from ExtEnv"""
|
||||
env_read_api_registry[func.__name__] = get_function_schema(func)
|
||||
return func
|
||||
|
||||
|
||||
def mark_as_writeable(func):
|
||||
"""mark function as a writeable one in ExtEnv, it does something to ExtEnv"""
|
||||
"""mark function as a writeable one in ExtEnv, it does something to ExtEnv"""
|
||||
env_write_api_registry[func.__name__] = get_function_schema(func)
|
||||
return func
|
||||
|
||||
|
||||
class ExtEnv(BaseModel):
|
||||
"""External Env to intergate actual game environment"""
|
||||
|
||||
write_api_registry: WriteAPIRegistry = Field(default_factory=WriteAPIRegistry, exclude=True)
|
||||
read_api_registry: ReadAPIRegistry = Field(default_factory=ReadAPIRegistry, exclude=True)
|
||||
|
||||
|
||||
class Env(ExtEnv):
|
||||
"""Env to intergate with MetaGPT"""
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
class ExtEnv(BaseModel):
|
||||
"""External Env to integrate actual game environment"""
|
||||
|
||||
def _check_api_exist(self, rw_api: Optional[str] = None):
|
||||
if not rw_api:
|
||||
|
|
@ -84,45 +64,25 @@ class Env(ExtEnv):
|
|||
else:
|
||||
return env_write_api_registry.get_apis()
|
||||
|
||||
# TODO adds is_coroutine_func
|
||||
# def observe(self, env_action: Union[str, EnvAPIAbstract]):
|
||||
# if isinstance(env_action, str):
|
||||
# read_api = env_write_api_registry.get(api_name=env_action)
|
||||
# self._check_api_exist(read_api)
|
||||
# res = read_api(self)
|
||||
# elif isinstance(env_action, EnvAPIAbstract):
|
||||
# read_api = env_write_api_registry.get(api_name=env_action.api_name)
|
||||
# self._check_api_exist(read_api)
|
||||
# res = read_api(self, *env_action.args, **env_action.kwargs)
|
||||
# return res
|
||||
#
|
||||
# def step(self, env_action: Union[str, Message, EnvAPIAbstract, list[EnvAPIAbstract]]):
|
||||
# res = None
|
||||
# if isinstance(env_action, Message):
|
||||
# self.publish_message(env_action)
|
||||
# elif isinstance(env_action, EnvAPIAbstract):
|
||||
# print(f"CURRENT API NAME: {env_action.api_name}")
|
||||
# write_api = self.write_api_registry.get(env_action.api_name)
|
||||
# self._check_api_exist(write_api)
|
||||
# res = write_api(self, *env_action.args, **env_action.kwargs)
|
||||
#
|
||||
# return res
|
||||
|
||||
def observe(self, env_action: Union[str, EnvAPIAbstract]):
|
||||
# TODO Adds is_coroutine_func
|
||||
async def observe(self, env_action: Union[str, EnvAPIAbstract]):
|
||||
"""get observation from particular api of ExtEnv"""
|
||||
if isinstance(env_action, str):
|
||||
read_api = env_read_api_registry.get(api_name=env_action)["func"]
|
||||
self._check_api_exist(read_api)
|
||||
res = read_api(self)
|
||||
if is_coroutine_func(read_api):
|
||||
res = await read_api(self)
|
||||
else:
|
||||
res = read_api(self)
|
||||
elif isinstance(env_action, EnvAPIAbstract):
|
||||
read_api = env_read_api_registry.get(api_name=env_action.api_name)["func"]
|
||||
self._check_api_exist(read_api)
|
||||
res = read_api(self, *env_action.args, **env_action.kwargs)
|
||||
|
||||
if is_coroutine_func(read_api):
|
||||
res = await read_api(self, *env_action.args, **env_action.kwargs)
|
||||
else:
|
||||
res = read_api(self, *env_action.args, **env_action.kwargs)
|
||||
return res
|
||||
|
||||
def step(self, env_action: Union[str, Message, EnvAPIAbstract, list[EnvAPIAbstract]]):
|
||||
async def step(self, env_action: Union[str, Message, EnvAPIAbstract, list[EnvAPIAbstract]]):
|
||||
"""execute through particular api of ExtEnv"""
|
||||
res = None
|
||||
if isinstance(env_action, Message):
|
||||
|
|
@ -130,9 +90,131 @@ class Env(ExtEnv):
|
|||
elif isinstance(env_action, EnvAPIAbstract):
|
||||
write_api = env_write_api_registry.get(env_action.api_name)["func"]
|
||||
self._check_api_exist(write_api)
|
||||
res = write_api(self, *env_action.args, **env_action.kwargs)
|
||||
if is_coroutine_func(write_api):
|
||||
res = await write_api(self, *env_action.args, **env_action.kwargs)
|
||||
else:
|
||||
res = write_api(self, *env_action.args, **env_action.kwargs)
|
||||
|
||||
return res
|
||||
|
||||
def publish_message(self, message: "Message"):
|
||||
pass
|
||||
class Environment(ExtEnv):
    """Environment hosting a batch of roles.

    Roles can publish messages to the environment, and those messages can be
    observed by other roles.
    """

    model_config = ConfigDict(arbitrary_types_allowed=True)

    desc: str = Field(default="")  # description of the environment
    roles: dict[str, SerializeAsAny["Role"]] = Field(default_factory=dict, validate_default=True)
    member_addrs: Dict["Role", Set] = Field(default_factory=dict, exclude=True)
    history: str = ""  # For debug
    context: Context = Field(default_factory=Context, exclude=True)

    @model_validator(mode="after")
    def init_roles(self):
        """Register every role supplied at construction time with this environment."""
        self.add_roles(self.roles.values())
        return self

    def add_role(self, role: "Role"):
        """Add a single role to the current environment.

        The role is stored under ``role.profile``, bound to this environment
        via ``set_env`` and given this environment's context.
        """
        self.roles[role.profile] = role
        role.set_env(self)
        role.context = self.context

    def add_roles(self, roles: Iterable["Role"]):
        """Add a batch of roles to the current environment.

        Args:
            roles: Any iterable of roles; it is consumed exactly once.
        """
        # Snapshot the input first: it is walked twice below, and it may be a
        # one-shot iterator or a live view of ``self.roles`` (see init_roles),
        # which must not be iterated while ``self.roles`` is being assigned to.
        role_list = list(roles)
        for role in role_list:
            self.roles[role.profile] = role

        for role in role_list:  # setup system message with roles
            role.set_env(self)
            role.context = self.context

    def publish_message(self, message: Message, peekable: bool = True) -> bool:
        """
        Distribute the message to the recipients.
        In accordance with the Message routing structure design in Chapter 2.2.1 of RFC 116, as already planned
        in RFC 113 for the entire system, the routing information in the Message is only responsible for
        specifying the message recipient, without concern for where the message recipient is located. How to
        route the message to the message recipient is a problem addressed by the transport framework designed
        in RFC 113.

        Args:
            message: The message to deliver.
            peekable: Not used by this implementation.

        Returns:
            Always ``True``; a missing recipient is only logged as a warning.
        """
        logger.debug(f"publish_message: {message.dump()}")
        found = False
        # According to the routing feature plan in Chapter 2.2.3.2 of RFC 113
        for role, addrs in self.member_addrs.items():
            if is_send_to(message, addrs):
                role.put_message(message)
                found = True
        if not found:
            logger.warning(f"Message no recipients: {message.dump()}")
        self.history += f"\n{message}"  # For debug

        return True

    async def run(self, k=1):
        """Run every role concurrently, ``k`` rounds in a row.

        Each round calls ``role.run()`` on all hosted roles and awaits them
        together with ``asyncio.gather``.
        """
        for _ in range(k):
            futures = []
            for role in self.roles.values():
                future = role.run()
                futures.append(future)

            await asyncio.gather(*futures)
            logger.debug(f"is idle: {self.is_idle}")

    def get_roles(self) -> dict[str, "Role"]:
        """Get all the roles in the environment."""
        return self.roles

    def get_role(self, name: str) -> "Role":
        """Get the specified role in the environment.

        Looks ``name`` up among the keys of ``self.roles`` (``add_role`` stores
        roles under ``role.profile``) and returns ``None`` when no entry exists.
        """
        return self.roles.get(name, None)

    def role_names(self) -> list[str]:
        """Return the ``name`` attribute of every hosted role."""
        return [i.name for i in self.roles.values()]

    @property
    def is_idle(self):
        """If true, all actions have been executed."""
        return all(r.is_idle for r in self.roles.values())

    def get_addresses(self, obj):
        """Get the addresses of the object."""
        return self.member_addrs.get(obj, {})

    def set_addresses(self, obj, addresses):
        """Set the addresses of the object"""
        self.member_addrs[obj] = addresses

    def archive(self, auto_archive=True):
        """Archive the context's git repo when ``auto_archive`` is set and a repo exists."""
        if auto_archive and self.context.git_repo:
            self.context.git_repo.archive()

    @classmethod
    def model_rebuild(cls, **kwargs):
        """Rebuild the model after importing ``Role`` so the forward reference resolves."""
        # Imported here rather than at module level; the module only imports
        # Role under TYPE_CHECKING.
        from metagpt.roles.role import Role  # noqa: F401

        super().model_rebuild(**kwargs)
|
||||
|
||||
|
||||
Environment.model_rebuild()
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue