Update env & test code

This commit is contained in:
didi 2024-02-28 17:00:51 +08:00
parent 07c360b9c7
commit cfc0cc1fa5
12 changed files with 248 additions and 140 deletions

View file

@ -122,8 +122,12 @@ class Config(CLIParams, YamlModel):
def set_other(self, other: dict):
self.other = other
def get_other(self, key: str):
return self.other.get(key)
def get_other(self, key: str, default_value: str = None):
if default_value is None:
return self.other.get(key)
else:
return self.other.get(key, default_value)
def get_openai_llm(self) -> Optional[LLMConfig]:
"""Get OpenAI LLMConfig by name. If no OpenAI, raise Exception"""

View file

@ -1,7 +1,8 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc :
# TODO
from metagpt.environment.base_env import Environment
from metagpt.environment.android_env.android_env import AndroidEnv
from metagpt.environment.gym_env.gym_env import GymEnv
from metagpt.environment.mincraft_env.mincraft_env import MincraftExtEnv

View file

@ -9,10 +9,10 @@ from typing import Any, Optional
from pydantic import Field
from metagpt.const import ADB_EXEC_FAIL
from metagpt.environment.base_env import Env, ExtEnv, mark_as_readable, mark_as_writeable
from metagpt.environment.base_env import Environment, ExtEnv, mark_as_readable, mark_as_writeable
class AndroidExtEnv(Env, ExtEnv):
class AndroidExtEnv(Environment, ExtEnv):
device_id: Optional[str] = Field(default=None)
screenshot_dir: Optional[Path] = Field(default=None)
xml_dir: Optional[Path] = Field(default=None)

View file

@ -2,25 +2,29 @@
# -*- coding: utf-8 -*-
# @Desc : base env of executing environment
import asyncio
from enum import Enum
from typing import Optional, Union, Any
from pydantic import BaseModel, ConfigDict, Field
from typing import TYPE_CHECKING, Any, Dict, Iterable, Optional, Set, Union
from pydantic import BaseModel, ConfigDict, Field, SerializeAsAny, model_validator
from metagpt.context import Context
from metagpt.environment.api.env_api import (
EnvAPIAbstract,
ReadAPIRegistry,
WriteAPIRegistry,
)
from metagpt.logs import logger
from metagpt.schema import Message
from metagpt.utils.common import get_function_schema, is_coroutine_func
from metagpt.utils.common import get_function_schema, is_coroutine_func, is_send_to
if TYPE_CHECKING:
from metagpt.roles.role import Role # noqa: F401
class EnvType(Enum):
ANDROID = "Android"
GYM = "Gym"
WEREWOLF = "Werewolf"
MINCRAFT = "Minsraft"
MINCRAFT = "Mincraft"
STANFORDTOWN = "StanfordTown"
@ -28,49 +32,25 @@ env_write_api_registry = WriteAPIRegistry()
env_read_api_registry = ReadAPIRegistry()
# def mark_as_readable(func):
# """mark functionn as a readable one in ExtEnv, it observes something from ExtEnv"""
#
# def wrapper(self: ExtEnv, *args, **kwargs):
# api_name = func.__name__
# self.read_api_registry[api_name] = func
# return func(self, *args, **kwargs)
#
# return wrapper
#
# def mark_as_writeable(func):
# """mark functionn as a writeable one in ExtEnv, it does something to ExtEnv"""
#
# def wrapper(self: ExtEnv, *args, **kwargs):
# api_name = func.__name__
# self.write_api_registry[api_name] = func
# return func(self, *args, **kwargs)
#
# return wrapper
def mark_as_readable(func):
"""mark function as a readable one in ExtEnv, it observes something from ExtEnv"""
"""mark functionn as a readable one in ExtEnv, it observes something from ExtEnv"""
env_read_api_registry[func.__name__] = get_function_schema(func)
return func
def mark_as_writeable(func):
"""mark function as a writeable one in ExtEnv, it does something to ExtEnv"""
"""mark functionn as a writeable one in ExtEnv, it does something to ExtEnv"""
env_write_api_registry[func.__name__] = get_function_schema(func)
return func
class ExtEnv(BaseModel):
"""External Env to intergate actual game environment"""
write_api_registry: WriteAPIRegistry = Field(default_factory=WriteAPIRegistry, exclude=True)
read_api_registry: ReadAPIRegistry = Field(default_factory=ReadAPIRegistry, exclude=True)
class Env(ExtEnv):
"""Env to intergate with MetaGPT"""
model_config = ConfigDict(arbitrary_types_allowed=True)
class ExtEnv(BaseModel):
"""External Env to intergate actual game environment"""
def _check_api_exist(self, rw_api: Optional[str] = None):
if not rw_api:
@ -84,45 +64,25 @@ class Env(ExtEnv):
else:
return env_write_api_registry.get_apis()
# TODO adds is_coroutine_func
# def observe(self, env_action: Union[str, EnvAPIAbstract]):
# if isinstance(env_action, str):
# read_api = env_write_api_registry.get(api_name=env_action)
# self._check_api_exist(read_api)
# res = read_api(self)
# elif isinstance(env_action, EnvAPIAbstract):
# read_api = env_write_api_registry.get(api_name=env_action.api_name)
# self._check_api_exist(read_api)
# res = read_api(self, *env_action.args, **env_action.kwargs)
# return res
#
# def step(self, env_action: Union[str, Message, EnvAPIAbstract, list[EnvAPIAbstract]]):
# res = None
# if isinstance(env_action, Message):
# self.publish_message(env_action)
# elif isinstance(env_action, EnvAPIAbstract):
# print(f"CURRENT API NAME: {env_action.api_name}")
# write_api = self.write_api_registry.get(env_action.api_name)
# self._check_api_exist(write_api)
# res = write_api(self, *env_action.args, **env_action.kwargs)
#
# return res
def observe(self, env_action: Union[str, EnvAPIAbstract]):
# TODO Adds is_coroutine_func
async def observe(self, env_action: Union[str, EnvAPIAbstract]):
"""get observation from particular api of ExtEnv"""
if isinstance(env_action, str):
read_api = env_read_api_registry.get(api_name=env_action)["func"]
self._check_api_exist(read_api)
res = read_api(self)
if is_coroutine_func(read_api):
res = await read_api(self)
else:
res = read_api(self)
elif isinstance(env_action, EnvAPIAbstract):
read_api = env_read_api_registry.get(api_name=env_action.api_name)["func"]
self._check_api_exist(read_api)
res = read_api(self, *env_action.args, **env_action.kwargs)
if is_coroutine_func(read_api):
res = await read_api(self, *env_action.args, **env_action.kwargs)
else:
res = read_api(self, *env_action.args, **env_action.kwargs)
return res
def step(self, env_action: Union[str, Message, EnvAPIAbstract, list[EnvAPIAbstract]]):
async def step(self, env_action: Union[str, Message, EnvAPIAbstract, list[EnvAPIAbstract]]):
"""execute through particular api of ExtEnv"""
res = None
if isinstance(env_action, Message):
@ -130,9 +90,131 @@ class Env(ExtEnv):
elif isinstance(env_action, EnvAPIAbstract):
write_api = env_write_api_registry.get(env_action.api_name)["func"]
self._check_api_exist(write_api)
res = write_api(self, *env_action.args, **env_action.kwargs)
if is_coroutine_func(write_api):
res = await write_api(self, *env_action.args, **env_action.kwargs)
else:
res = write_api(self, *env_action.args, **env_action.kwargs)
return res
def publish_message(self, message: "Message"):
pass
class Environment(ExtEnv):
"""环境,承载一批角色,角色可以向环境发布消息,可以被其他角色观察到
Environment, hosting a batch of roles, roles can publish messages to the environment, and can be observed by other roles
"""
model_config = ConfigDict(arbitrary_types_allowed=True)
desc: str = Field(default="") # 环境描述
roles: dict[str, SerializeAsAny["Role"]] = Field(default_factory=dict, validate_default=True)
member_addrs: Dict["Role", Set] = Field(default_factory=dict, exclude=True)
history: str = "" # For debug
context: Context = Field(default_factory=Context, exclude=True)
@model_validator(mode="after")
def init_roles(self):
self.add_roles(self.roles.values())
return self
def add_role(self, role: "Role"):
"""增加一个在当前环境的角色
Add a role in the current environment
"""
self.roles[role.profile] = role
role.set_env(self)
role.context = self.context
def add_roles(self, roles: Iterable["Role"]):
"""增加一批在当前环境的角色
Add a batch of characters in the current environment
"""
for role in roles:
self.roles[role.profile] = role
for role in roles: # setup system message with roles
role.set_env(self)
role.context = self.context
def publish_message(self, message: Message, peekable: bool = True) -> bool:
"""
Distribute the message to the recipients.
In accordance with the Message routing structure design in Chapter 2.2.1 of RFC 116, as already planned
in RFC 113 for the entire system, the routing information in the Message is only responsible for
specifying the message recipient, without concern for where the message recipient is located. How to
route the message to the message recipient is a problem addressed by the transport framework designed
in RFC 113.
"""
logger.debug(f"publish_message: {message.dump()}")
found = False
# According to the routing feature plan in Chapter 2.2.3.2 of RFC 113
for role, addrs in self.member_addrs.items():
if is_send_to(message, addrs):
role.put_message(message)
found = True
if not found:
logger.warning(f"Message no recipients: {message.dump()}")
self.history += f"\n{message}" # For debug
return True
async def run(self, k=1):
"""处理一次所有信息的运行
Process all Role runs at once
"""
for _ in range(k):
futures = []
for role in self.roles.values():
future = role.run()
futures.append(future)
await asyncio.gather(*futures)
logger.debug(f"is idle: {self.is_idle}")
def get_roles(self) -> dict[str, "Role"]:
"""获得环境内的所有角色
Process all Role runs at once
"""
return self.roles
def get_role(self, name: str) -> "Role":
"""获得环境内的指定角色
get all the environment roles
"""
return self.roles.get(name, None)
def role_names(self) -> list[str]:
return [i.name for i in self.roles.values()]
@property
def is_idle(self):
"""If true, all actions have been executed."""
for r in self.roles.values():
if not r.is_idle:
return False
return True
def get_addresses(self, obj):
"""Get the addresses of the object."""
return self.member_addrs.get(obj, {})
def set_addresses(self, obj, addresses):
"""Set the addresses of the object"""
self.member_addrs[obj] = addresses
def archive(self, auto_archive=True):
if auto_archive and self.context.git_repo:
self.context.git_repo.archive()
@classmethod
def model_rebuild(cls, **kwargs):
from metagpt.roles.role import Role # noqa: F401
super().model_rebuild(**kwargs)
Environment.model_rebuild()