diff --git a/metagpt/environment/android_env/android_env.py b/metagpt/environment/android_env/android_env.py index c6058aa4a..87b49750d 100644 --- a/metagpt/environment/android_env/android_env.py +++ b/metagpt/environment/android_env/android_env.py @@ -2,7 +2,7 @@ # -*- coding: utf-8 -*- # @Desc : MG Android Env -from metagpt.env.android_env.android_ext_env import AndroidExtEnv +from metagpt.environment.android_env.android_ext_env import AndroidExtEnv class AndroidEnv(AndroidExtEnv): diff --git a/metagpt/environment/api/env_api.py b/metagpt/environment/api/env_api.py index bb7a75243..6469e5b4c 100644 --- a/metagpt/environment/api/env_api.py +++ b/metagpt/environment/api/env_api.py @@ -18,19 +18,28 @@ class EnvAPIAbstract(BaseModel): class EnvAPIRegistry(BaseModel): """the registry to store environment w&r api/interface""" - registry: dict[str, Callable] = Field(default=dict(), include=False) + registry: dict[str, Callable] = Field(default=dict(), exclude=True) def get(self, api_name: str): return self.registry.get(api_name) + def __getitem__(self, api_name: str) -> Callable: + return self.get(api_name) + + def __setitem__(self, api_name: str, func: Callable): + self.registry[api_name] = func + + def __len__(self): + return len(self.registry) + class WriteAPIRegistry(EnvAPIRegistry): - """just as a new class name""" + """just as a explicit class name""" pass class ReadAPIRegistry(EnvAPIRegistry): - """just as a new class name""" + """just as a explicit class name""" pass diff --git a/metagpt/environment/base_env.py b/metagpt/environment/base_env.py index 4c25ae044..48917549e 100644 --- a/metagpt/environment/base_env.py +++ b/metagpt/environment/base_env.py @@ -3,11 +3,15 @@ # @Desc : base env of executing environment from enum import Enum -from typing import Union +from typing import Optional, Union from pydantic import BaseModel, ConfigDict, Field -from metagpt.env.api.env_api import EnvAPIAbstract, ReadAPIRegistry, WriteAPIRegistry +from metagpt.environment.api.env_api import ( + EnvAPIAbstract, + ReadAPIRegistry, + WriteAPIRegistry, +) from metagpt.schema import Message @@ -23,7 +27,7 @@ def mark_as_readable(func): """mark functionn as a readable one in ExtEnv, it observes something from ExtEnv""" def wrapper(self: ExtEnv, *args, **kwargs): - api_name = str(func) # TODO + api_name = func.__name__ self.read_api_registry[api_name] = func return func(self, *args, **kwargs) @@ -31,10 +35,10 @@ def mark_as_readable(func): def mark_as_writeable(func): - """mark functionn as a writeable one in ExtEnv, it do something to ExtEnv""" + """mark functionn as a writeable one in ExtEnv, it does something to ExtEnv""" def wrapper(self: ExtEnv, *args, **kwargs): - api_name = str(func) # TODO + api_name = func.__name__ self.write_api_registry[api_name] = func return func(self, *args, **kwargs) @@ -44,8 +48,8 @@ def mark_as_writeable(func): class ExtEnv(BaseModel): """External Env to intergate actual game environment""" - write_api_registry: WriteAPIRegistry = Field(default_factory=WriteAPIRegistry, include=False) - read_api_registry: ReadAPIRegistry = Field(default_factory=ReadAPIRegistry, include=False) + write_api_registry: WriteAPIRegistry = Field(default_factory=WriteAPIRegistry, exclude=True) + read_api_registry: ReadAPIRegistry = Field(default_factory=ReadAPIRegistry, exclude=True) class Env(ExtEnv): @@ -53,10 +57,19 @@ class Env(ExtEnv): model_config = ConfigDict(arbitrary_types_allowed=True) + def _check_api_exist(self, rw_api: Optional[str] = None): + if not rw_api: + raise ValueError(f"{rw_api} not exists") + def observe(self, env_action: Union[str, EnvAPIAbstract]): - api_name = env_action.api_name if isinstance(env_action, EnvAPIAbstract) else env_action - read_api = self.read_api_registry.get(api_name) - res = read_api(*env_action.args, **env_action.kwargs) + if isinstance(env_action, str): + read_api = self.read_api_registry.get(api_name=env_action) + self._check_api_exist(read_api) + res = read_api(self) + elif isinstance(env_action, EnvAPIAbstract): + read_api = self.read_api_registry.get(api_name=env_action.api_name) + self._check_api_exist(read_api) + res = read_api(self, *env_action.args, **env_action.kwargs) return res def step(self, env_action: Union[str, Message, EnvAPIAbstract, list[EnvAPIAbstract]]): @@ -65,7 +78,8 @@ class Env(ExtEnv): self.publish_message(env_action) elif isinstance(env_action, EnvAPIAbstract): write_api = self.write_api_registry.get(env_action.api_name) - res = write_api(*env_action.args, **env_action.kwargs) + self._check_api_exist(write_api) + res = write_api(self, *env_action.args, **env_action.kwargs) return res diff --git a/metagpt/environment/gym_env/gym_env.py b/metagpt/environment/gym_env/gym_env.py index 2bcf8efd0..b83d988d6 100644 --- a/metagpt/environment/gym_env/gym_env.py +++ b/metagpt/environment/gym_env/gym_env.py @@ -1,3 +1,7 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- -# @Desc : +# @Desc : MG Gym Env + + +class GymEnv: + pass diff --git a/metagpt/environment/mincraft_env/mincraft_env.py b/metagpt/environment/mincraft_env/mincraft_env.py index 2bcf8efd0..e79b87cf0 100644 --- a/metagpt/environment/mincraft_env/mincraft_env.py +++ b/metagpt/environment/mincraft_env/mincraft_env.py @@ -1,3 +1,9 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- -# @Desc : +# @Desc : MG Mincraft Env + +from metagpt.environment.mincraft_env.mincraft_ext_env import MincraftExtEnv + + +class MincraftEnv(MincraftExtEnv): + pass diff --git a/metagpt/environment/mincraft_env/mincraft_ext_env.py b/metagpt/environment/mincraft_env/mincraft_ext_env.py index 2bcf8efd0..6012a80d9 100644 --- a/metagpt/environment/mincraft_env/mincraft_ext_env.py +++ b/metagpt/environment/mincraft_env/mincraft_ext_env.py @@ -1,3 +1,9 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- -# @Desc : +# @Desc : The Mincraft external environment to integrate with Mincraft game + +from metagpt.environment.base_env import ExtEnv + + +class MincraftExtEnv(ExtEnv): + pass diff --git a/metagpt/environment/werewolf_env/werewolf_env.py b/metagpt/environment/werewolf_env/werewolf_env.py index 29e9f9b81..831f8e020 100644 --- a/metagpt/environment/werewolf_env/werewolf_env.py +++ b/metagpt/environment/werewolf_env/werewolf_env.py @@ -2,7 +2,7 @@ # -*- coding: utf-8 -*- # @Desc : MG Werewolf Env -from metagpt.env.werewolf_env.werewolf_ext_env import WerewolfExtEnv +from metagpt.environment.werewolf_env.werewolf_ext_env import WerewolfExtEnv class WerewolfEnv(WerewolfExtEnv): diff --git a/metagpt/environment/werewolf_env/werewolf_ext_env.py b/metagpt/environment/werewolf_env/werewolf_ext_env.py index 8543ce246..014417009 100644 --- a/metagpt/environment/werewolf_env/werewolf_ext_env.py +++ b/metagpt/environment/werewolf_env/werewolf_ext_env.py @@ -6,7 +6,7 @@ from enum import Enum from pydantic import Field -from metagpt.env.base_env import ExtEnv, mark_as_readable, mark_as_writeable +from metagpt.environment.base_env import ExtEnv, mark_as_readable, mark_as_writeable class RoleState(Enum): diff --git a/metagpt/utils/common.py b/metagpt/utils/common.py index e09d49d84..74024fdd6 100644 --- a/metagpt/utils/common.py +++ b/metagpt/utils/common.py @@ -13,6 +13,7 @@ from __future__ import annotations import ast import contextlib +import csv import importlib import inspect import json @@ -465,6 +466,29 @@ def write_json_file(json_file: str, data: list, encoding=None): json.dump(data, fout, ensure_ascii=False, indent=4, default=to_jsonable_python) +def read_csv_to_list(curr_file: str, header=False, strip_trail=True): + """ + Reads in a csv file to a list of list. If header is True, it returns a + tuple with (header row, all rows) + ARGS: + curr_file: path to the current csv file. + RETURNS: + List of list where the component lists are the rows of the file. + """ + logger.debug(f"start read csv: {curr_file}") + analysis_list = [] + with open(curr_file) as f_analysis_file: + data_reader = csv.reader(f_analysis_file, delimiter=",") + for count, row in enumerate(data_reader): + if strip_trail: + row = [i.strip() for i in row] + analysis_list += [row] + if not header: + return analysis_list + else: + return analysis_list[0], analysis_list[1:] + + def import_class(class_name: str, module_name: str) -> type: module = importlib.import_module(module_name) a_class = getattr(module, class_name) diff --git a/tests/metagpt/environment/android_env/__init__.py b/tests/metagpt/environment/android_env/__init__.py new file mode 100644 index 000000000..2bcf8efd0 --- /dev/null +++ b/tests/metagpt/environment/android_env/__init__.py @@ -0,0 +1,3 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc :