mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-06-11 15:15:18 +02:00
update env
This commit is contained in:
parent
095ce5caf4
commit
20daa8e93a
10 changed files with 86 additions and 20 deletions
|
|
@ -2,7 +2,7 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# @Desc : MG Android Env
|
||||
|
||||
from metagpt.env.android_env.android_ext_env import AndroidExtEnv
|
||||
from metagpt.environment.android_env.android_ext_env import AndroidExtEnv
|
||||
|
||||
|
||||
class AndroidEnv(AndroidExtEnv):
|
||||
|
|
|
|||
|
|
@ -18,19 +18,28 @@ class EnvAPIAbstract(BaseModel):
|
|||
class EnvAPIRegistry(BaseModel):
|
||||
"""the registry to store environment w&r api/interface"""
|
||||
|
||||
registry: dict[str, Callable] = Field(default=dict(), include=False)
|
||||
registry: dict[str, Callable] = Field(default=dict(), exclude=True)
|
||||
|
||||
def get(self, api_name: str):
|
||||
return self.registry.get(api_name)
|
||||
|
||||
def __getitem__(self, api_name: str) -> Callable:
|
||||
return self.get(api_name)
|
||||
|
||||
def __setitem__(self, api_name: str, func: Callable):
|
||||
self.registry[api_name] = func
|
||||
|
||||
def __len__(self):
|
||||
return len(self.registry)
|
||||
|
||||
|
||||
class WriteAPIRegistry(EnvAPIRegistry):
|
||||
"""just as a new class name"""
|
||||
"""just as a explicit class name"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class ReadAPIRegistry(EnvAPIRegistry):
|
||||
"""just as a new class name"""
|
||||
"""just as a explicit class name"""
|
||||
|
||||
pass
|
||||
|
|
|
|||
|
|
@ -3,11 +3,15 @@
|
|||
# @Desc : base env of executing environment
|
||||
|
||||
from enum import Enum
|
||||
from typing import Union
|
||||
from typing import Optional, Union
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from metagpt.env.api.env_api import EnvAPIAbstract, ReadAPIRegistry, WriteAPIRegistry
|
||||
from metagpt.environment.api.env_api import (
|
||||
EnvAPIAbstract,
|
||||
ReadAPIRegistry,
|
||||
WriteAPIRegistry,
|
||||
)
|
||||
from metagpt.schema import Message
|
||||
|
||||
|
||||
|
|
@ -23,7 +27,7 @@ def mark_as_readable(func):
|
|||
"""mark functionn as a readable one in ExtEnv, it observes something from ExtEnv"""
|
||||
|
||||
def wrapper(self: ExtEnv, *args, **kwargs):
|
||||
api_name = str(func) # TODO
|
||||
api_name = func.__name__
|
||||
self.read_api_registry[api_name] = func
|
||||
return func(self, *args, **kwargs)
|
||||
|
||||
|
|
@ -31,10 +35,10 @@ def mark_as_readable(func):
|
|||
|
||||
|
||||
def mark_as_writeable(func):
|
||||
"""mark functionn as a writeable one in ExtEnv, it do something to ExtEnv"""
|
||||
"""mark functionn as a writeable one in ExtEnv, it does something to ExtEnv"""
|
||||
|
||||
def wrapper(self: ExtEnv, *args, **kwargs):
|
||||
api_name = str(func) # TODO
|
||||
api_name = func.__name__
|
||||
self.write_api_registry[api_name] = func
|
||||
return func(self, *args, **kwargs)
|
||||
|
||||
|
|
@ -44,8 +48,8 @@ def mark_as_writeable(func):
|
|||
class ExtEnv(BaseModel):
|
||||
"""External Env to intergate actual game environment"""
|
||||
|
||||
write_api_registry: WriteAPIRegistry = Field(default_factory=WriteAPIRegistry, include=False)
|
||||
read_api_registry: ReadAPIRegistry = Field(default_factory=ReadAPIRegistry, include=False)
|
||||
write_api_registry: WriteAPIRegistry = Field(default_factory=WriteAPIRegistry, exclude=True)
|
||||
read_api_registry: ReadAPIRegistry = Field(default_factory=ReadAPIRegistry, exclude=True)
|
||||
|
||||
|
||||
class Env(ExtEnv):
|
||||
|
|
@ -53,10 +57,19 @@ class Env(ExtEnv):
|
|||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
def _check_api_exist(self, rw_api: Optional[str] = None):
|
||||
if not rw_api:
|
||||
raise ValueError(f"{rw_api} not exists")
|
||||
|
||||
def observe(self, env_action: Union[str, EnvAPIAbstract]):
|
||||
api_name = env_action.api_name if isinstance(env_action, EnvAPIAbstract) else env_action
|
||||
read_api = self.read_api_registry.get(api_name)
|
||||
res = read_api(*env_action.args, **env_action.kwargs)
|
||||
if isinstance(env_action, str):
|
||||
read_api = self.read_api_registry.get(api_name=env_action)
|
||||
self._check_api_exist(read_api)
|
||||
res = read_api(self)
|
||||
elif isinstance(env_action, EnvAPIAbstract):
|
||||
read_api = self.read_api_registry.get(api_name=env_action.api_name)
|
||||
self._check_api_exist(read_api)
|
||||
res = read_api(self, *env_action.args, **env_action.kwargs)
|
||||
return res
|
||||
|
||||
def step(self, env_action: Union[str, Message, EnvAPIAbstract, list[EnvAPIAbstract]]):
|
||||
|
|
@ -65,7 +78,8 @@ class Env(ExtEnv):
|
|||
self.publish_message(env_action)
|
||||
elif isinstance(env_action, EnvAPIAbstract):
|
||||
write_api = self.write_api_registry.get(env_action.api_name)
|
||||
res = write_api(*env_action.args, **env_action.kwargs)
|
||||
self._check_api_exist(write_api)
|
||||
res = write_api(self, *env_action.args, **env_action.kwargs)
|
||||
|
||||
return res
|
||||
|
||||
|
|
|
|||
|
|
@ -1,3 +1,7 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc :
|
||||
# @Desc : MG Gym Env
|
||||
|
||||
|
||||
class GymEnv:
|
||||
pass
|
||||
|
|
|
|||
|
|
@ -1,3 +1,9 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc :
|
||||
# @Desc : MG Mincraft Env
|
||||
|
||||
from metagpt.environment.mincraft_env.mincraft_ext_env import MincraftExtEnv
|
||||
|
||||
|
||||
class MincraftEnv(MincraftExtEnv):
|
||||
pass
|
||||
|
|
|
|||
|
|
@ -1,3 +1,9 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc :
|
||||
# @Desc : The Mincraft external environment to integrate with Mincraft game
|
||||
|
||||
from metagpt.environment.base_env import ExtEnv
|
||||
|
||||
|
||||
class MincraftExtEnv(ExtEnv):
|
||||
pass
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# @Desc : MG Werewolf Env
|
||||
|
||||
from metagpt.env.werewolf_env.werewolf_ext_env import WerewolfExtEnv
|
||||
from metagpt.environment.werewolf_env.werewolf_ext_env import WerewolfExtEnv
|
||||
|
||||
|
||||
class WerewolfEnv(WerewolfExtEnv):
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ from enum import Enum
|
|||
|
||||
from pydantic import Field
|
||||
|
||||
from metagpt.env.base_env import ExtEnv, mark_as_readable, mark_as_writeable
|
||||
from metagpt.environment.base_env import ExtEnv, mark_as_readable, mark_as_writeable
|
||||
|
||||
|
||||
class RoleState(Enum):
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ from __future__ import annotations
|
|||
|
||||
import ast
|
||||
import contextlib
|
||||
import csv
|
||||
import importlib
|
||||
import inspect
|
||||
import json
|
||||
|
|
@ -465,6 +466,29 @@ def write_json_file(json_file: str, data: list, encoding=None):
|
|||
json.dump(data, fout, ensure_ascii=False, indent=4, default=to_jsonable_python)
|
||||
|
||||
|
||||
def read_csv_to_list(curr_file: str, header=False, strip_trail=True):
|
||||
"""
|
||||
Reads in a csv file to a list of list. If header is True, it returns a
|
||||
tuple with (header row, all rows)
|
||||
ARGS:
|
||||
curr_file: path to the current csv file.
|
||||
RETURNS:
|
||||
List of list where the component lists are the rows of the file.
|
||||
"""
|
||||
logger.debug(f"start read csv: {curr_file}")
|
||||
analysis_list = []
|
||||
with open(curr_file) as f_analysis_file:
|
||||
data_reader = csv.reader(f_analysis_file, delimiter=",")
|
||||
for count, row in enumerate(data_reader):
|
||||
if strip_trail:
|
||||
row = [i.strip() for i in row]
|
||||
analysis_list += [row]
|
||||
if not header:
|
||||
return analysis_list
|
||||
else:
|
||||
return analysis_list[0], analysis_list[1:]
|
||||
|
||||
|
||||
def import_class(class_name: str, module_name: str) -> type:
|
||||
module = importlib.import_module(module_name)
|
||||
a_class = getattr(module, class_name)
|
||||
|
|
|
|||
3
tests/metagpt/environment/android_env/__init__.py
Normal file
3
tests/metagpt/environment/android_env/__init__.py
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc :
|
||||
Loading…
Add table
Add a link
Reference in a new issue