mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-06-29 15:59:42 +02:00
Update test for action node & Modify extenv (self reflection)
This commit is contained in:
parent
32211ff5f2
commit
a1b0faacf4
5 changed files with 159 additions and 33 deletions
|
|
@ -9,10 +9,10 @@ from typing import Any, Optional
|
|||
from pydantic import Field
|
||||
|
||||
from metagpt.const import ADB_EXEC_FAIL
|
||||
from metagpt.environment.base_env import ExtEnv, mark_as_readable, mark_as_writeable
|
||||
from metagpt.environment.base_env import Env, ExtEnv, mark_as_readable, mark_as_writeable
|
||||
|
||||
|
||||
class AndroidExtEnv(ExtEnv):
|
||||
class AndroidExtEnv(Env, ExtEnv):
|
||||
device_id: Optional[str] = Field(default=None)
|
||||
screenshot_dir: Optional[Path] = Field(default=None)
|
||||
xml_dir: Optional[Path] = Field(default=None)
|
||||
|
|
@ -42,6 +42,7 @@ class AndroidExtEnv(ExtEnv):
|
|||
return f"adb -s {self.device_id} "
|
||||
|
||||
def execute_adb_with_cmd(self, adb_cmd: str) -> str:
|
||||
adb_cmd = adb_cmd.replace('\\', '/')
|
||||
res = subprocess.run(adb_cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
|
||||
exec_res = ADB_EXEC_FAIL
|
||||
if not res.returncode:
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@
|
|||
# @Desc : base env of executing environment
|
||||
|
||||
from enum import Enum
|
||||
from typing import Optional, Union
|
||||
from typing import Optional, Union, Any
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
|
|
@ -13,6 +13,7 @@ from metagpt.environment.api.env_api import (
|
|||
WriteAPIRegistry,
|
||||
)
|
||||
from metagpt.schema import Message
|
||||
from metagpt.utils.common import get_function_schema, is_coroutine_func
|
||||
|
||||
|
||||
class EnvType(Enum):
|
||||
|
|
@ -23,26 +24,40 @@ class EnvType(Enum):
|
|||
STANFORDTOWN = "StanfordTown"
|
||||
|
||||
|
||||
env_write_api_registry = WriteAPIRegistry()
|
||||
env_read_api_registry = ReadAPIRegistry()
|
||||
|
||||
|
||||
# def mark_as_readable(func):
|
||||
# """mark functionn as a readable one in ExtEnv, it observes something from ExtEnv"""
|
||||
#
|
||||
# def wrapper(self: ExtEnv, *args, **kwargs):
|
||||
# api_name = func.__name__
|
||||
# self.read_api_registry[api_name] = func
|
||||
# return func(self, *args, **kwargs)
|
||||
#
|
||||
# return wrapper
|
||||
#
|
||||
# def mark_as_writeable(func):
|
||||
# """mark functionn as a writeable one in ExtEnv, it does something to ExtEnv"""
|
||||
#
|
||||
# def wrapper(self: ExtEnv, *args, **kwargs):
|
||||
# api_name = func.__name__
|
||||
# self.write_api_registry[api_name] = func
|
||||
# return func(self, *args, **kwargs)
|
||||
#
|
||||
# return wrapper
|
||||
|
||||
def mark_as_readable(func):
|
||||
"""mark functionn as a readable one in ExtEnv, it observes something from ExtEnv"""
|
||||
|
||||
def wrapper(self: ExtEnv, *args, **kwargs):
|
||||
api_name = func.__name__
|
||||
self.read_api_registry[api_name] = func
|
||||
return func(self, *args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
"""mark function as a readable one in ExtEnv, it observes something from ExtEnv"""
|
||||
env_read_api_registry[func.__name__] = get_function_schema(func)
|
||||
return func
|
||||
|
||||
|
||||
def mark_as_writeable(func):
|
||||
"""mark functionn as a writeable one in ExtEnv, it does something to ExtEnv"""
|
||||
|
||||
def wrapper(self: ExtEnv, *args, **kwargs):
|
||||
api_name = func.__name__
|
||||
self.write_api_registry[api_name] = func
|
||||
return func(self, *args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
"""mark function as a writeable one in ExtEnv, it does something to ExtEnv"""
|
||||
env_write_api_registry[func.__name__] = get_function_schema(func)
|
||||
return func
|
||||
|
||||
|
||||
class ExtEnv(BaseModel):
|
||||
|
|
@ -61,23 +76,59 @@ class Env(ExtEnv):
|
|||
if not rw_api:
|
||||
raise ValueError(f"{rw_api} not exists")
|
||||
|
||||
def get_all_available_apis(self, mode: str = "read") -> list[Any]:
|
||||
"""get available read/write apis definition"""
|
||||
assert mode in ["read", "write"]
|
||||
if mode == "read":
|
||||
return env_read_api_registry.get_apis()
|
||||
else:
|
||||
return env_write_api_registry.get_apis()
|
||||
|
||||
# TODO adds is_coroutine_func
|
||||
# def observe(self, env_action: Union[str, EnvAPIAbstract]):
|
||||
# if isinstance(env_action, str):
|
||||
# read_api = env_write_api_registry.get(api_name=env_action)
|
||||
# self._check_api_exist(read_api)
|
||||
# res = read_api(self)
|
||||
# elif isinstance(env_action, EnvAPIAbstract):
|
||||
# read_api = env_write_api_registry.get(api_name=env_action.api_name)
|
||||
# self._check_api_exist(read_api)
|
||||
# res = read_api(self, *env_action.args, **env_action.kwargs)
|
||||
# return res
|
||||
#
|
||||
# def step(self, env_action: Union[str, Message, EnvAPIAbstract, list[EnvAPIAbstract]]):
|
||||
# res = None
|
||||
# if isinstance(env_action, Message):
|
||||
# self.publish_message(env_action)
|
||||
# elif isinstance(env_action, EnvAPIAbstract):
|
||||
# print(f"CURRENT API NAME: {env_action.api_name}")
|
||||
# write_api = self.write_api_registry.get(env_action.api_name)
|
||||
# self._check_api_exist(write_api)
|
||||
# res = write_api(self, *env_action.args, **env_action.kwargs)
|
||||
#
|
||||
# return res
|
||||
|
||||
def observe(self, env_action: Union[str, EnvAPIAbstract]):
|
||||
# TODO Adds is_coroutine_func
|
||||
"""get observation from particular api of ExtEnv"""
|
||||
if isinstance(env_action, str):
|
||||
read_api = self.read_api_registry.get(api_name=env_action)
|
||||
read_api = env_read_api_registry.get(api_name=env_action)["func"]
|
||||
self._check_api_exist(read_api)
|
||||
res = read_api(self)
|
||||
elif isinstance(env_action, EnvAPIAbstract):
|
||||
read_api = self.read_api_registry.get(api_name=env_action.api_name)
|
||||
read_api = env_read_api_registry.get(api_name=env_action.api_name)["func"]
|
||||
self._check_api_exist(read_api)
|
||||
res = read_api(self, *env_action.args, **env_action.kwargs)
|
||||
|
||||
return res
|
||||
|
||||
def step(self, env_action: Union[str, Message, EnvAPIAbstract, list[EnvAPIAbstract]]):
|
||||
"""execute through particular api of ExtEnv"""
|
||||
res = None
|
||||
if isinstance(env_action, Message):
|
||||
self.publish_message(env_action)
|
||||
elif isinstance(env_action, EnvAPIAbstract):
|
||||
write_api = self.write_api_registry.get(env_action.api_name)
|
||||
write_api = env_write_api_registry.get(env_action.api_name)["func"]
|
||||
self._check_api_exist(write_api)
|
||||
res = write_api(self, *env_action.args, **env_action.kwargs)
|
||||
|
||||
|
|
|
|||
|
|
@ -25,7 +25,7 @@ import sys
|
|||
import traceback
|
||||
import typing
|
||||
from pathlib import Path
|
||||
from typing import Any, List, Tuple, Union
|
||||
from typing import Any, List, Tuple, Union, Callable
|
||||
|
||||
import aiofiles
|
||||
import loguru
|
||||
|
|
@ -214,7 +214,7 @@ class OutputParser:
|
|||
|
||||
if start_index != -1 and end_index != -1:
|
||||
# Extract the structure part
|
||||
structure_text = text[start_index : end_index + 1]
|
||||
structure_text = text[start_index: end_index + 1]
|
||||
|
||||
try:
|
||||
# Attempt to convert the text to a Python data type using ast.literal_eval
|
||||
|
|
@ -337,6 +337,14 @@ def print_members(module, indent=0):
|
|||
print(f"{prefix}Method: {name}")
|
||||
|
||||
|
||||
def get_function_schema(func: Callable) -> dict[str, Union[dict, Any, str]]:
|
||||
sig = inspect.signature(func)
|
||||
parameters = sig.parameters
|
||||
return_type = sig.return_annotation
|
||||
param_schema = {name: parameter.annotation for name, parameter in parameters.items()}
|
||||
return {"input_params": param_schema, "return_type": return_type, "func_desc": func.__doc__, "func": func}
|
||||
|
||||
|
||||
def parse_recipient(text):
|
||||
# FIXME: use ActionNode instead.
|
||||
pattern = r"## Send To:\s*([A-Za-z]+)\s*?" # hard code for now
|
||||
|
|
@ -594,6 +602,10 @@ def list_files(root: str | Path) -> List[Path]:
|
|||
return files
|
||||
|
||||
|
||||
def is_coroutine_func(func: Callable) -> bool:
|
||||
return inspect.iscoroutinefunction(func)
|
||||
|
||||
|
||||
def encode_image(image_path: Path, encoding: str = "utf-8") -> str:
|
||||
with open(str(image_path), "rb") as image_file:
|
||||
return base64.b64encode(image_file.read()).decode(encoding)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue