update env

This commit is contained in:
better629 2024-01-23 16:38:42 +08:00
parent 095ce5caf4
commit 20daa8e93a
10 changed files with 86 additions and 20 deletions

View file

@ -2,7 +2,7 @@
# -*- coding: utf-8 -*-
# @Desc : MG Android Env
from metagpt.env.android_env.android_ext_env import AndroidExtEnv
from metagpt.environment.android_env.android_ext_env import AndroidExtEnv
class AndroidEnv(AndroidExtEnv):

View file

@ -18,19 +18,28 @@ class EnvAPIAbstract(BaseModel):
class EnvAPIRegistry(BaseModel):
"""the registry to store environment w&r api/interface"""
registry: dict[str, Callable] = Field(default=dict(), include=False)
registry: dict[str, Callable] = Field(default=dict(), exclude=True)
def get(self, api_name: str):
return self.registry.get(api_name)
def __getitem__(self, api_name: str) -> Callable:
return self.get(api_name)
def __setitem__(self, api_name: str, func: Callable):
self.registry[api_name] = func
def __len__(self):
return len(self.registry)
class WriteAPIRegistry(EnvAPIRegistry):
"""just as a new class name"""
"""just as a explicit class name"""
pass
class ReadAPIRegistry(EnvAPIRegistry):
"""just as a new class name"""
"""just as a explicit class name"""
pass

View file

@ -3,11 +3,15 @@
# @Desc : base env of executing environment
from enum import Enum
from typing import Union
from typing import Optional, Union
from pydantic import BaseModel, ConfigDict, Field
from metagpt.env.api.env_api import EnvAPIAbstract, ReadAPIRegistry, WriteAPIRegistry
from metagpt.environment.api.env_api import (
EnvAPIAbstract,
ReadAPIRegistry,
WriteAPIRegistry,
)
from metagpt.schema import Message
@ -23,7 +27,7 @@ def mark_as_readable(func):
"""mark functionn as a readable one in ExtEnv, it observes something from ExtEnv"""
def wrapper(self: ExtEnv, *args, **kwargs):
api_name = str(func) # TODO
api_name = func.__name__
self.read_api_registry[api_name] = func
return func(self, *args, **kwargs)
@ -31,10 +35,10 @@ def mark_as_readable(func):
def mark_as_writeable(func):
"""mark functionn as a writeable one in ExtEnv, it do something to ExtEnv"""
"""mark functionn as a writeable one in ExtEnv, it does something to ExtEnv"""
def wrapper(self: ExtEnv, *args, **kwargs):
api_name = str(func) # TODO
api_name = func.__name__
self.write_api_registry[api_name] = func
return func(self, *args, **kwargs)
@ -44,8 +48,8 @@ def mark_as_writeable(func):
class ExtEnv(BaseModel):
"""External Env to intergate actual game environment"""
write_api_registry: WriteAPIRegistry = Field(default_factory=WriteAPIRegistry, include=False)
read_api_registry: ReadAPIRegistry = Field(default_factory=ReadAPIRegistry, include=False)
write_api_registry: WriteAPIRegistry = Field(default_factory=WriteAPIRegistry, exclude=True)
read_api_registry: ReadAPIRegistry = Field(default_factory=ReadAPIRegistry, exclude=True)
class Env(ExtEnv):
@ -53,10 +57,19 @@ class Env(ExtEnv):
model_config = ConfigDict(arbitrary_types_allowed=True)
def _check_api_exist(self, rw_api: Optional[str] = None):
if not rw_api:
raise ValueError(f"{rw_api} not exists")
def observe(self, env_action: Union[str, EnvAPIAbstract]):
api_name = env_action.api_name if isinstance(env_action, EnvAPIAbstract) else env_action
read_api = self.read_api_registry.get(api_name)
res = read_api(*env_action.args, **env_action.kwargs)
if isinstance(env_action, str):
read_api = self.read_api_registry.get(api_name=env_action)
self._check_api_exist(read_api)
res = read_api(self)
elif isinstance(env_action, EnvAPIAbstract):
read_api = self.read_api_registry.get(api_name=env_action.api_name)
self._check_api_exist(read_api)
res = read_api(self, *env_action.args, **env_action.kwargs)
return res
def step(self, env_action: Union[str, Message, EnvAPIAbstract, list[EnvAPIAbstract]]):
@ -65,7 +78,8 @@ class Env(ExtEnv):
self.publish_message(env_action)
elif isinstance(env_action, EnvAPIAbstract):
write_api = self.write_api_registry.get(env_action.api_name)
res = write_api(*env_action.args, **env_action.kwargs)
self._check_api_exist(write_api)
res = write_api(self, *env_action.args, **env_action.kwargs)
return res

View file

@ -1,3 +1,7 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc :
# @Desc : MG Gym Env
class GymEnv:
pass

View file

@ -1,3 +1,9 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc :
# @Desc : MG Mincraft Env
from metagpt.environment.mincraft_env.mincraft_ext_env import MincraftExtEnv
class MincraftEnv(MincraftExtEnv):
pass

View file

@ -1,3 +1,9 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc :
# @Desc : The Mincraft external environment to integrate with Mincraft game
from metagpt.environment.base_env import ExtEnv
class MincraftExtEnv(ExtEnv):
pass

View file

@ -2,7 +2,7 @@
# -*- coding: utf-8 -*-
# @Desc : MG Werewolf Env
from metagpt.env.werewolf_env.werewolf_ext_env import WerewolfExtEnv
from metagpt.environment.werewolf_env.werewolf_ext_env import WerewolfExtEnv
class WerewolfEnv(WerewolfExtEnv):

View file

@ -6,7 +6,7 @@ from enum import Enum
from pydantic import Field
from metagpt.env.base_env import ExtEnv, mark_as_readable, mark_as_writeable
from metagpt.environment.base_env import ExtEnv, mark_as_readable, mark_as_writeable
class RoleState(Enum):

View file

@ -13,6 +13,7 @@ from __future__ import annotations
import ast
import contextlib
import csv
import importlib
import inspect
import json
@ -465,6 +466,29 @@ def write_json_file(json_file: str, data: list, encoding=None):
json.dump(data, fout, ensure_ascii=False, indent=4, default=to_jsonable_python)
def read_csv_to_list(curr_file: str, header=False, strip_trail=True):
"""
Reads in a csv file to a list of list. If header is True, it returns a
tuple with (header row, all rows)
ARGS:
curr_file: path to the current csv file.
RETURNS:
List of list where the component lists are the rows of the file.
"""
logger.debug(f"start read csv: {curr_file}")
analysis_list = []
with open(curr_file) as f_analysis_file:
data_reader = csv.reader(f_analysis_file, delimiter=",")
for count, row in enumerate(data_reader):
if strip_trail:
row = [i.strip() for i in row]
analysis_list += [row]
if not header:
return analysis_list
else:
return analysis_list[0], analysis_list[1:]
def import_class(class_name: str, module_name: str) -> type:
module = importlib.import_module(module_name)
a_class = getattr(module, class_name)

View file

@ -0,0 +1,3 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc :