Update env & test code

This commit is contained in:
didi 2024-02-28 17:00:51 +08:00
parent 07c360b9c7
commit cfc0cc1fa5
12 changed files with 248 additions and 140 deletions

View file

@ -16,3 +16,6 @@ ## Free Your Hands
### By Text
### By Voice
## Run It
python run_assistant.py "Create a contact in Contacts App named zjy with a phone number +86 18831933368" --mode "manual" --app-name "Contacts"

View file

@ -33,7 +33,8 @@ class ManualRecord(Action):
screenshot_after_path: Path = ""
xml_path: Path = ""
async def run(self, demo_name: str, task_desc: str,task_dir: Path, env: AndroidEnv):
# async def run(self, demo_name: str, task_desc: str,task_dir: Path, env: AndroidEnv):
async def run(self, task_desc: str, task_dir: Path, env: AndroidEnv):
self.record_path = Path(task_dir) / "record.txt"
self.task_desc_path = Path(task_dir) / "task_desc.txt"
@ -53,16 +54,18 @@ class ManualRecord(Action):
step = 0
while True:
step += 1
screenshot_path: Path = env.observe(
screenshot_path: Path = await env.observe(
EnvAPIAbstract(
api_name="get_screenshot",
kwargs={"ss_name": f"{demo_name}_{step}", "local_save_dir": self.screenshot_before_path}
# kwargs={"ss_name": f"{demo_name}_{step}", "local_save_dir": self.screenshot_before_path}
kwargs={"ss_name": f"{step}", "local_save_dir": self.screenshot_before_path}
)
)
xml_path: Path = env.observe(
xml_path: Path = await env.observe(
EnvAPIAbstract(
api_name="get_xml",
kwargs={"xml_name": f"{demo_name}_{step}", "local_save_dir": self.xml_path}
# kwargs={"xml_name": f"{demo_name}_{step}", "local_save_dir": self.xml_path}
kwargs={"xml_name": f"{step}", "local_save_dir": self.xml_path}
)
)
if not screenshot_path.exists() or not xml_path.exists():
@ -86,14 +89,13 @@ class ManualRecord(Action):
bbox = e.bbox
center_ = (bbox[0][0] + bbox[1][0]) // 2, (bbox[0][1] + bbox[1][1]) // 2
dist = (abs(center[0] - center_[0]) ** 2 + abs(center[1] - center_[1]) ** 2) ** 0.5
# TODO Modify config to default 30. It should be modified back config after single action test
# if dist <= config.get_other("min_dist"):
if dist <= 30:
if dist <= config.get_other("min_dist"):
close = True
break
if not close:
elem_list.append(elem)
screenshot_labeled_path = Path(self.screenshot_after_path).joinpath(f"{demo_name}_{step}_labeled.png")
screenshot_labeled_path = Path(self.screenshot_after_path).joinpath(f"{step}_labeled.png")
# screenshot_labeled_path = Path(self.screenshot_after_path).joinpath(f"{demo_name}_{step}_labeled.png")
labeled_img = draw_bbox_multi(screenshot_path, screenshot_labeled_path, elem_list)
cv2.imshow("image", labeled_img)
@ -142,7 +144,7 @@ class ManualRecord(Action):
user_input = ""
while not user_input:
user_input = input()
env.step(EnvAPIAbstract(api_name="user_input", kwargs={"input_txt": user_input}))
await env.step(EnvAPIAbstract(api_name="user_input", kwargs={"input_txt": user_input}))
record_file.write(f'text({input_area}:sep:"{user_input}"):::{elem_list[int(input_area) - 1].uid}\n')
elif user_input.lower() == ActionOp.LONG_PRESS.value:
logger.info(

View file

@ -38,9 +38,9 @@ class ParseRecord(Action):
screenshot_before_path: Path = ""
screenshot_after_path: Path = ""
async def run(self, app_name: str, demo_name: str, task_dir: Path, docs_dir: Path, env: AndroidEnv):
if not docs_dir.exists():
docs_dir.mkdir(parents=True, exist_ok=True)
# async def run(self, app_name: str, demo_name: str, task_dir: Path, docs_dir: Path, env: AndroidEnv):
async def run(self, app_name: str, task_dir: Path, docs_dir: Path, env: AndroidEnv):
docs_dir.mkdir(parents=True, exist_ok=True)
doc_count = 0
self.record_path = Path(task_dir) / "record.txt"
self.task_desc_path = Path(task_dir) / "task_desc.txt"
@ -51,8 +51,10 @@ class ParseRecord(Action):
record_step_count = len(record_file.readlines()) - 1
record_file.seek(0)
for step in range(1, record_step_count + 1):
img_before_base64 = encode_image(self.screenshot_after_path.joinpath(f"{demo_name}_{step}_labeled.png"))
img_after_base64 = encode_image(self.screenshot_after_path.joinpath(f"{demo_name}_{step + 1}_labeled.png"))
# img_before_base64 = encode_image(self.screenshot_after_path.joinpath(f"{demo_name}_{step}_labeled.png"))
# img_after_base64 = encode_image(self.screenshot_after_path.joinpath(f"{demo_name}_{step + 1}_labeled.png"))
img_before_base64 = encode_image(self.screenshot_after_path.joinpath(f"{step}_labeled.png"))
img_after_base64 = encode_image(self.screenshot_after_path.joinpath(f"{step + 1}_labeled.png"))
rec = record_file.readline().strip()
action, resource_id = rec.split(":::")
action_type = action.split("(")[0]
@ -110,8 +112,8 @@ class ParseRecord(Action):
)
if "error" in node.content:
return AndroidActionOutput(action_state=RunState.FAIL)
log_path = task_dir.joinpath(f"log_{app_name}_{demo_name}.txt")
# log_path = task_dir.joinpath(f"log_{app_name}_{demo_name}.txt")
log_path = task_dir.joinpath(f"log_{app_name}.txt")
prompt = node.compile(context=context, schema="json", mode="auto")
msg = node.content
doc_content[action_type] = msg

View file

@ -92,13 +92,13 @@ class ScreenshotParse(Action):
if not path.exists():
path.mkdir(parents=True, exist_ok=True)
screenshot_path: Path = env.observe(
screenshot_path: Path = await env.observe(
EnvAPIAbstract(
api_name="get_screenshot",
kwargs={"ss_name": f"{round_count}_before", "local_save_dir": task_dir}
)
)
xml_path: Path = env.observe(
xml_path: Path = await env.observe(
EnvAPIAbstract(
api_name="get_xml",
kwargs={"xml_name": f"{round_count}", "local_save_dir": task_dir}
@ -121,9 +121,7 @@ class ScreenshotParse(Action):
bbox = e.bbox
center_ = (bbox[0][0] + bbox[1][0]) // 2, (bbox[0][1] + bbox[1][1]) // 2
dist = (abs(center[0] - center_[0]) ** 2 + abs(center[1] - center_[1]) ** 2) ** 0.5
# TODO Modify config to default 30. It should be modified back config after single action test
# if dist <= config.get_other("min_dist"):
if dist <= 30:
if dist <= config.get_other("min_dist"):
close = True
break
if not close:
@ -156,21 +154,21 @@ class ScreenshotParse(Action):
if isinstance(op_param, TapOp):
x, y = elem_bbox_to_xy(elem_list[op_param.area - 1].bbox)
res = env.step(EnvAPIAbstract(api_name="system_tap", kwargs={"x": x, "y": y}))
res = await env.step(EnvAPIAbstract(api_name="system_tap", kwargs={"x": x, "y": y}))
if res == ADB_EXEC_FAIL:
return AndroidActionOutput(action_state=RunState.FAIL)
elif isinstance(op_param, TextOp):
res = env.step(EnvAPIAbstract(api_name="user_input", kwargs={"input_txt": op_param.input_str}))
res = await env.step(EnvAPIAbstract(api_name="user_input", kwargs={"input_txt": op_param.input_str}))
if res == ADB_EXEC_FAIL:
return AndroidActionOutput(action_state=RunState.FAIL)
elif isinstance(op_param, LongPressOp):
x, y = elem_bbox_to_xy(elem_list[op_param.area - 1].bbox)
res = env.step(EnvAPIAbstract(api_name="user_longpress", kwargs={"x": x, "y": y}))
res = await env.step(EnvAPIAbstract(api_name="user_longpress", kwargs={"x": x, "y": y}))
if res == ADB_EXEC_FAIL:
return AndroidActionOutput(action_state=RunState.FAIL)
elif isinstance(op_param, SwipeOp_3):
x, y = elem_bbox_to_xy(elem_list[op_param.area - 1].bbox)
res = env.step(
res = await env.step(
EnvAPIAbstract(
api_name="user_swipe",
kwargs={"x": x, "y": y, "orient": op_param.swipe_orient, "dist": op_param.dist}
@ -183,18 +181,18 @@ class ScreenshotParse(Action):
elif isinstance(op_param, TapGridOp) or isinstance(op_param, LongPressGridOp):
x, y = area_to_xy(op_param.area, op_param.subarea, env.width, env.height, env.rows, env.cols)
if isinstance(op_param, TapGridOp):
res = env.step(EnvAPIAbstract(api_name="system_tap", kwargs={"x": x, "y": y}))
res = await env.step(EnvAPIAbstract(api_name="system_tap", kwargs={"x": x, "y": y}))
if res == ADB_EXEC_FAIL:
return AndroidActionOutput(action_state=RunState.FAIL)
else:
# LongPressGridOp
res = env.step(EnvAPIAbstract(api_name="user_longpress", kwargs={"x": x, "y": y}))
res = await env.step(EnvAPIAbstract(api_name="user_longpress", kwargs={"x": x, "y": y}))
if res == ADB_EXEC_FAIL:
return AndroidActionOutput(action_state=RunState.FAIL)
elif isinstance(op_param, SwipeGridOp):
start_x, start_y = area_to_xy(op_param.start_area, op_param.start_subarea, width, height, rows, cols)
end_x, end_y = area_to_xy(op_param.end_area, op_param.end_subarea, width, height, rows, cols)
res = env.step(
res = await env.step(
EnvAPIAbstract(api_name="user_swipe_to", kwargs={"start": (start_x, start_y), "end": (end_x, end_y)}))
if res == ADB_EXEC_FAIL:
return AndroidActionOutput(action_state=RunState.FAIL)

View file

@ -26,6 +26,7 @@ from examples.andriod_assistant.utils.schema import (
ReflectLogItem,
RunState,
SwipeOp,
SwipeOp_3,
TapOp,
TextOp,
)
@ -70,12 +71,12 @@ class SelfLearnAndReflect(Action):
async def run_self_learn(
self, round_count: int, task_desc: str, last_act: str, task_dir: Path, env: AndroidEnv
) -> AndroidActionOutput:
screenshot_path: Path = env.observe(
screenshot_path: Path = await env.observe(
EnvAPIAbstract(
api_name="get_screenshot", kwargs={"ss_name": f"{round_count}_before", "local_save_dir": task_dir}
)
)
xml_path: Path = env.observe(
xml_path: Path = await env.observe(
EnvAPIAbstract(api_name="get_xml", kwargs={"xml_name": f"{round_count}", "local_save_dir": task_dir})
)
if not screenshot_path.exists() or not xml_path.exists():
@ -100,9 +101,7 @@ class SelfLearnAndReflect(Action):
bbox = e.bbox
center_ = (bbox[0][0] + bbox[1][0]) // 2, (bbox[0][1] + bbox[1][1]) // 2
dist = (abs(center[0] - center_[0]) ** 2 + abs(center[1] - center_[1]) ** 2) ** 0.5
# TODO Modify config to default 30. It should be modified back config after single action test
# if dist <= config.get_other("min_dist"):
if dist <= 30:
if dist <= config.get_other("min_dist"):
close = True
break
if not close:
@ -125,7 +124,6 @@ class SelfLearnAndReflect(Action):
OpLogItem(step=round_count, prompt=prompt, image=str(screenshot_before_labeled_path), response=node.content)
op_param = screenshot_parse_extract(node.instruct_content.model_dump(), grid_on=False)
# TODO Modify Op_param. When op_param.action is FINISH, how to solve this ?
logger.info(op_param)
if op_param.param_state == RunState.FINISH:
return AndroidActionOutput(action_state=RunState.FINISH)
if op_param.param_state == RunState.FAIL:
@ -134,26 +132,26 @@ class SelfLearnAndReflect(Action):
if isinstance(op_param, TapOp):
self.ui_area = op_param.area
x, y = elem_bbox_to_xy(elem_list[op_param.area - 1].bbox)
res = env.step(EnvAPIAbstract(api_name="system_tap", kwargs={"x": x, "y": y}))
res = await env.step(EnvAPIAbstract(api_name="system_tap", kwargs={"x": x, "y": y}))
if res == ADB_EXEC_FAIL:
return AndroidActionOutput(action_state=RunState.FAIL)
elif isinstance(op_param, TextOp):
res = env.step(EnvAPIAbstract(api_name="user_input", kwargs={"input_txt": op_param.input_str}))
res = await env.step(EnvAPIAbstract(api_name="user_input", kwargs={"input_txt": op_param.input_str}))
if res == ADB_EXEC_FAIL:
return AndroidActionOutput(action_state=RunState.FAIL)
elif isinstance(op_param, LongPressOp):
self.ui_area = op_param.area
x, y = elem_bbox_to_xy(elem_list[op_param.area - 1].bbox)
res = env.step(EnvAPIAbstract(api_name="user_longpress", kwargs={"x": x, "y": y}))
res = await env.step(EnvAPIAbstract(api_name="user_longpress", kwargs={"x": x, "y": y}))
if res == ADB_EXEC_FAIL:
return AndroidActionOutput(action_state=RunState.FAIL)
elif isinstance(op_param, SwipeOp):
elif isinstance(op_param, SwipeOp_3):
self.ui_area = op_param.area
self.swipe_orient = op_param.swipe_orient
x, y = elem_bbox_to_xy(elem_list[op_param.area - 1].bbox)
res = env.step(
res = await env.step(
EnvAPIAbstract(
"user_swipe", kwargs={"x": x, "y": y, "orient": op_param.swipe_orient, "dist": op_param.dist}
api_name="user_swipe", kwargs={"x": x, "y": y, "orient": op_param.swipe_orient, "dist": op_param.dist}
)
)
if res == ADB_EXEC_FAIL:
@ -167,8 +165,7 @@ class SelfLearnAndReflect(Action):
async def run_reflect(
self, round_count: int, task_desc: str, last_act: str, task_dir: Path, docs_dir: Path, env: AndroidEnv
) -> AndroidActionOutput:
logger.info("run_reflect")
screenshot_path: Path = env.observe(
screenshot_path: Path = await env.observe(
EnvAPIAbstract(
api_name="get_screenshot", kwargs={"ss_name": f"{round_count}_after", "local_save_dir": task_dir}
)
@ -180,7 +177,6 @@ class SelfLearnAndReflect(Action):
draw_bbox_multi(screenshot_path, screenshot_after_labeled_path, elem_list=self.elem_list)
img_base64 = encode_image(screenshot_after_labeled_path)
logger.info(f"act_name: {self.act_name}")
if self.act_name == ActionOp.TAP.value:
action = "tapping"
elif self.act_name == ActionOp.LONG_PRESS.value:
@ -225,7 +221,7 @@ class SelfLearnAndReflect(Action):
self.useless_list.append(resource_id)
last_act = "NONE"
if op_param.decision == Decision.BACK.value:
res = env.step(EnvAPIAbstract(api_name="system_back"))
res = await env.step(EnvAPIAbstract(api_name="system_back"))
if res == ADB_EXEC_FAIL:
return AndroidActionOutput(action_state=RunState.FAIL)
doc = op_param.documentation

View file

@ -37,7 +37,8 @@ class AndroidAssistant(Role):
self._watch([UserRequirement])
app_name = config.get_other("app_name", "demo")
data_dir = Path(__file__).parent.joinpath("..", "output")
curr_path = Path(__file__).parent
data_dir = curr_path.joinpath("..", "output")
cur_datetime = datetime.fromtimestamp(int(time.time())).strftime("%Y-%m-%d_%H-%M-%S")
"""Firstly, we decide the state with user config, further, we can do it automatically, like if it's new app,
@ -67,39 +68,57 @@ class AndroidAssistant(Role):
self._set_react_mode(RoleReactMode.BY_ORDER)
def _check_dir(self):
self.task_dir.mkdir(exist_ok=True)
self.docs_dir.mkdir(exist_ok=True)
self.task_dir.mkdir(parents=True, exist_ok=True)
self.docs_dir.mkdir(parents=True, exist_ok=True)
async def react(self) -> Message:
self.round_count += 1
super().react()
result = await super().react()
print(f"react result {result}")
return result
async def _act(self) -> Message:
# Question: How to achieve self_learn's loop action ?
logger.info(f"{self._setting}: to do {self.rc.todo}({self.rc.todo.name})")
todo = self.rc.todo
send_to = ""
if isinstance(todo, ManualRecord):
resp = await todo.run()
resp = await todo.run(
# demo_name="",
task_dir=self.task_dir,
task_desc=self.task_desc,
env=self.rc.env
)
elif isinstance(todo, ParseRecord):
resp = await todo.run()
resp = await todo.run(
app_name=config.get_other("app_name", "demo"),
task_dir=self.task_dir,
docs_dir=self.docs_dir,
env=self.rc.env
)
elif isinstance(todo, SelfLearnAndReflect):
resp = await todo.run(round_count=self.round_count,
task_desc=self.task_desc,
last_act=self.last_act,
task_dir=self.task_dir,
docs_dir=self.docs_dir,
env=self.rc.env)
resp = await todo.run(
round_count=self.round_count,
task_desc=self.task_desc,
last_act=self.last_act,
task_dir=self.task_dir,
docs_dir=self.docs_dir,
env=self.rc.env
)
if resp.action_state == RunState.SUCCESS:
self.last_act = resp.data.get("last_act")
send_to = self.name
elif isinstance(todo, ScreenshotParse):
resp = await todo.run(round_count=self.round_count,
task_desc=self.task_desc,
last_act=self.last_act,
task_dir=self.task_dir,
grid_on=self.grid_on,
env=self.rc.env)
resp = await todo.run(
round_count=self.round_count,
task_desc=self.task_desc,
last_act=self.last_act,
task_dir=self.task_dir,
docs_dir=self.docs_dir,
grid_on=self.grid_on,
env=self.rc.env
)
if resp.action_state == RunState.SUCCESS:
self.grid_on = resp.data.get("grid_on")
send_to = self.name

View file

@ -50,12 +50,13 @@ def startup(
)
team = Team(env=AndroidEnv())
team.hire([AndroidAssistant])
team.hire([AndroidAssistant()])
team.invest(investment)
company.run_project(idea=task_desc)
team.run_project(idea=task_desc)
asyncio.run(team.run(n_round=n_round))
if __name__ == "__main__":
app()
# Command python run_assistant.py "Create a contact in Contacts App named zjy with a phone number +86 18831933368"

View file

@ -1,6 +1,6 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc : test on android emulator
# @Desc : test on android emulator action. After Modify Role Test, this script is discarded.
import asyncio
import time
from pathlib import Path
@ -50,14 +50,14 @@ if __name__ == "__main__":
env=test_env_self_learn_android
),
test_manual_record.run(
demo_name=DEMO_NAME,
# demo_name=DEMO_NAME,
task_dir=TASK_PATH / "demos" / f"manual_record_{DEMO_NAME}",
task_desc="Create a contact in Contacts App named zjy with a phone number +86 18831933368 ",
env=test_env_manual_learn_android
),
test_manual_parse.run(
app_name="Contacts",
demo_name=DEMO_NAME,
# demo_name=DEMO_NAME,
task_dir=TASK_PATH / "demos" / f"manual_record_{DEMO_NAME}", # 修要修改
docs_dir=PARSE_RECORD_DOC_PATH, # 需要修改
env=test_env_manual_learn_android

View file

@ -122,8 +122,12 @@ class Config(CLIParams, YamlModel):
def set_other(self, other: dict):
self.other = other
def get_other(self, key: str):
return self.other.get(key)
def get_other(self, key: str, default_value: str = None):
if default_value is None:
return self.other.get(key)
else:
return self.other.get(key, default_value)
def get_openai_llm(self) -> Optional[LLMConfig]:
"""Get OpenAI LLMConfig by name. If no OpenAI, raise Exception"""

View file

@ -1,7 +1,8 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc :
# TODO
from metagpt.environment.base_env import Environment
from metagpt.environment.android_env.android_env import AndroidEnv
from metagpt.environment.gym_env.gym_env import GymEnv
from metagpt.environment.mincraft_env.mincraft_env import MincraftExtEnv

View file

@ -9,10 +9,10 @@ from typing import Any, Optional
from pydantic import Field
from metagpt.const import ADB_EXEC_FAIL
from metagpt.environment.base_env import Env, ExtEnv, mark_as_readable, mark_as_writeable
from metagpt.environment.base_env import Environment, ExtEnv, mark_as_readable, mark_as_writeable
class AndroidExtEnv(Env, ExtEnv):
class AndroidExtEnv(Environment, ExtEnv):
device_id: Optional[str] = Field(default=None)
screenshot_dir: Optional[Path] = Field(default=None)
xml_dir: Optional[Path] = Field(default=None)

View file

@ -2,25 +2,29 @@
# -*- coding: utf-8 -*-
# @Desc : base env of executing environment
import asyncio
from enum import Enum
from typing import Optional, Union, Any
from pydantic import BaseModel, ConfigDict, Field
from typing import TYPE_CHECKING, Any, Dict, Iterable, Optional, Set, Union
from pydantic import BaseModel, ConfigDict, Field, SerializeAsAny, model_validator
from metagpt.context import Context
from metagpt.environment.api.env_api import (
EnvAPIAbstract,
ReadAPIRegistry,
WriteAPIRegistry,
)
from metagpt.logs import logger
from metagpt.schema import Message
from metagpt.utils.common import get_function_schema, is_coroutine_func
from metagpt.utils.common import get_function_schema, is_coroutine_func, is_send_to
if TYPE_CHECKING:
from metagpt.roles.role import Role # noqa: F401
class EnvType(Enum):
ANDROID = "Android"
GYM = "Gym"
WEREWOLF = "Werewolf"
MINCRAFT = "Minsraft"
MINCRAFT = "Mincraft"
STANFORDTOWN = "StanfordTown"
@ -28,49 +32,25 @@ env_write_api_registry = WriteAPIRegistry()
env_read_api_registry = ReadAPIRegistry()
# def mark_as_readable(func):
# """mark functionn as a readable one in ExtEnv, it observes something from ExtEnv"""
#
# def wrapper(self: ExtEnv, *args, **kwargs):
# api_name = func.__name__
# self.read_api_registry[api_name] = func
# return func(self, *args, **kwargs)
#
# return wrapper
#
# def mark_as_writeable(func):
# """mark functionn as a writeable one in ExtEnv, it does something to ExtEnv"""
#
# def wrapper(self: ExtEnv, *args, **kwargs):
# api_name = func.__name__
# self.write_api_registry[api_name] = func
# return func(self, *args, **kwargs)
#
# return wrapper
def mark_as_readable(func):
"""mark function as a readable one in ExtEnv, it observes something from ExtEnv"""
"""mark functionn as a readable one in ExtEnv, it observes something from ExtEnv"""
env_read_api_registry[func.__name__] = get_function_schema(func)
return func
def mark_as_writeable(func):
"""mark function as a writeable one in ExtEnv, it does something to ExtEnv"""
"""mark functionn as a writeable one in ExtEnv, it does something to ExtEnv"""
env_write_api_registry[func.__name__] = get_function_schema(func)
return func
class ExtEnv(BaseModel):
"""External Env to intergate actual game environment"""
write_api_registry: WriteAPIRegistry = Field(default_factory=WriteAPIRegistry, exclude=True)
read_api_registry: ReadAPIRegistry = Field(default_factory=ReadAPIRegistry, exclude=True)
class Env(ExtEnv):
"""Env to intergate with MetaGPT"""
model_config = ConfigDict(arbitrary_types_allowed=True)
class ExtEnv(BaseModel):
"""External Env to intergate actual game environment"""
def _check_api_exist(self, rw_api: Optional[str] = None):
if not rw_api:
@ -84,45 +64,25 @@ class Env(ExtEnv):
else:
return env_write_api_registry.get_apis()
# TODO adds is_coroutine_func
# def observe(self, env_action: Union[str, EnvAPIAbstract]):
# if isinstance(env_action, str):
# read_api = env_write_api_registry.get(api_name=env_action)
# self._check_api_exist(read_api)
# res = read_api(self)
# elif isinstance(env_action, EnvAPIAbstract):
# read_api = env_write_api_registry.get(api_name=env_action.api_name)
# self._check_api_exist(read_api)
# res = read_api(self, *env_action.args, **env_action.kwargs)
# return res
#
# def step(self, env_action: Union[str, Message, EnvAPIAbstract, list[EnvAPIAbstract]]):
# res = None
# if isinstance(env_action, Message):
# self.publish_message(env_action)
# elif isinstance(env_action, EnvAPIAbstract):
# print(f"CURRENT API NAME: {env_action.api_name}")
# write_api = self.write_api_registry.get(env_action.api_name)
# self._check_api_exist(write_api)
# res = write_api(self, *env_action.args, **env_action.kwargs)
#
# return res
def observe(self, env_action: Union[str, EnvAPIAbstract]):
# TODO Adds is_coroutine_func
async def observe(self, env_action: Union[str, EnvAPIAbstract]):
"""get observation from particular api of ExtEnv"""
if isinstance(env_action, str):
read_api = env_read_api_registry.get(api_name=env_action)["func"]
self._check_api_exist(read_api)
res = read_api(self)
if is_coroutine_func(read_api):
res = await read_api(self)
else:
res = read_api(self)
elif isinstance(env_action, EnvAPIAbstract):
read_api = env_read_api_registry.get(api_name=env_action.api_name)["func"]
self._check_api_exist(read_api)
res = read_api(self, *env_action.args, **env_action.kwargs)
if is_coroutine_func(read_api):
res = await read_api(self, *env_action.args, **env_action.kwargs)
else:
res = read_api(self, *env_action.args, **env_action.kwargs)
return res
def step(self, env_action: Union[str, Message, EnvAPIAbstract, list[EnvAPIAbstract]]):
async def step(self, env_action: Union[str, Message, EnvAPIAbstract, list[EnvAPIAbstract]]):
"""execute through particular api of ExtEnv"""
res = None
if isinstance(env_action, Message):
@ -130,9 +90,131 @@ class Env(ExtEnv):
elif isinstance(env_action, EnvAPIAbstract):
write_api = env_write_api_registry.get(env_action.api_name)["func"]
self._check_api_exist(write_api)
res = write_api(self, *env_action.args, **env_action.kwargs)
if is_coroutine_func(write_api):
res = await write_api(self, *env_action.args, **env_action.kwargs)
else:
res = write_api(self, *env_action.args, **env_action.kwargs)
return res
def publish_message(self, message: "Message"):
pass
class Environment(ExtEnv):
"""环境,承载一批角色,角色可以向环境发布消息,可以被其他角色观察到
Environment, hosting a batch of roles, roles can publish messages to the environment, and can be observed by other roles
"""
model_config = ConfigDict(arbitrary_types_allowed=True)
desc: str = Field(default="") # 环境描述
roles: dict[str, SerializeAsAny["Role"]] = Field(default_factory=dict, validate_default=True)
member_addrs: Dict["Role", Set] = Field(default_factory=dict, exclude=True)
history: str = "" # For debug
context: Context = Field(default_factory=Context, exclude=True)
@model_validator(mode="after")
def init_roles(self):
self.add_roles(self.roles.values())
return self
def add_role(self, role: "Role"):
"""增加一个在当前环境的角色
Add a role in the current environment
"""
self.roles[role.profile] = role
role.set_env(self)
role.context = self.context
def add_roles(self, roles: Iterable["Role"]):
"""增加一批在当前环境的角色
Add a batch of characters in the current environment
"""
for role in roles:
self.roles[role.profile] = role
for role in roles: # setup system message with roles
role.set_env(self)
role.context = self.context
def publish_message(self, message: Message, peekable: bool = True) -> bool:
"""
Distribute the message to the recipients.
In accordance with the Message routing structure design in Chapter 2.2.1 of RFC 116, as already planned
in RFC 113 for the entire system, the routing information in the Message is only responsible for
specifying the message recipient, without concern for where the message recipient is located. How to
route the message to the message recipient is a problem addressed by the transport framework designed
in RFC 113.
"""
logger.debug(f"publish_message: {message.dump()}")
found = False
# According to the routing feature plan in Chapter 2.2.3.2 of RFC 113
for role, addrs in self.member_addrs.items():
if is_send_to(message, addrs):
role.put_message(message)
found = True
if not found:
logger.warning(f"Message no recipients: {message.dump()}")
self.history += f"\n{message}" # For debug
return True
async def run(self, k=1):
"""处理一次所有信息的运行
Process all Role runs at once
"""
for _ in range(k):
futures = []
for role in self.roles.values():
future = role.run()
futures.append(future)
await asyncio.gather(*futures)
logger.debug(f"is idle: {self.is_idle}")
def get_roles(self) -> dict[str, "Role"]:
"""获得环境内的所有角色
Process all Role runs at once
"""
return self.roles
def get_role(self, name: str) -> "Role":
"""获得环境内的指定角色
get all the environment roles
"""
return self.roles.get(name, None)
def role_names(self) -> list[str]:
return [i.name for i in self.roles.values()]
@property
def is_idle(self):
"""If true, all actions have been executed."""
for r in self.roles.values():
if not r.is_idle:
return False
return True
def get_addresses(self, obj):
"""Get the addresses of the object."""
return self.member_addrs.get(obj, {})
def set_addresses(self, obj, addresses):
"""Set the addresses of the object"""
self.member_addrs[obj] = addresses
def archive(self, auto_archive=True):
if auto_archive and self.context.git_repo:
self.context.git_repo.archive()
@classmethod
def model_rebuild(cls, **kwargs):
from metagpt.roles.role import Role # noqa: F401
super().model_rebuild(**kwargs)
Environment.model_rebuild()