mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-06-26 15:49:42 +02:00
Update env & test code
This commit is contained in:
parent
07c360b9c7
commit
cfc0cc1fa5
12 changed files with 248 additions and 140 deletions
|
|
@ -122,8 +122,12 @@ class Config(CLIParams, YamlModel):
|
|||
def set_other(self, other: dict):
|
||||
self.other = other
|
||||
|
||||
def get_other(self, key: str):
|
||||
return self.other.get(key)
|
||||
def get_other(self, key: str, default_value: str = None):
|
||||
if default_value is None:
|
||||
return self.other.get(key)
|
||||
else:
|
||||
return self.other.get(key, default_value)
|
||||
|
||||
|
||||
def get_openai_llm(self) -> Optional[LLMConfig]:
|
||||
"""Get OpenAI LLMConfig by name. If no OpenAI, raise Exception"""
|
||||
|
|
|
|||
|
|
@ -1,7 +1,8 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc :
|
||||
|
||||
# TODO
|
||||
from metagpt.environment.base_env import Environment
|
||||
from metagpt.environment.android_env.android_env import AndroidEnv
|
||||
from metagpt.environment.gym_env.gym_env import GymEnv
|
||||
from metagpt.environment.mincraft_env.mincraft_env import MincraftExtEnv
|
||||
|
|
|
|||
|
|
@ -9,10 +9,10 @@ from typing import Any, Optional
|
|||
from pydantic import Field
|
||||
|
||||
from metagpt.const import ADB_EXEC_FAIL
|
||||
from metagpt.environment.base_env import Env, ExtEnv, mark_as_readable, mark_as_writeable
|
||||
from metagpt.environment.base_env import Environment, ExtEnv, mark_as_readable, mark_as_writeable
|
||||
|
||||
|
||||
class AndroidExtEnv(Env, ExtEnv):
|
||||
class AndroidExtEnv(Environment, ExtEnv):
|
||||
device_id: Optional[str] = Field(default=None)
|
||||
screenshot_dir: Optional[Path] = Field(default=None)
|
||||
xml_dir: Optional[Path] = Field(default=None)
|
||||
|
|
|
|||
|
|
@ -2,25 +2,29 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# @Desc : base env of executing environment
|
||||
|
||||
import asyncio
|
||||
from enum import Enum
|
||||
from typing import Optional, Union, Any
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Dict, Iterable, Optional, Set, Union
|
||||
from pydantic import BaseModel, ConfigDict, Field, SerializeAsAny, model_validator
|
||||
from metagpt.context import Context
|
||||
from metagpt.environment.api.env_api import (
|
||||
EnvAPIAbstract,
|
||||
ReadAPIRegistry,
|
||||
WriteAPIRegistry,
|
||||
)
|
||||
from metagpt.logs import logger
|
||||
from metagpt.schema import Message
|
||||
from metagpt.utils.common import get_function_schema, is_coroutine_func
|
||||
from metagpt.utils.common import get_function_schema, is_coroutine_func, is_send_to
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from metagpt.roles.role import Role # noqa: F401
|
||||
|
||||
|
||||
class EnvType(Enum):
|
||||
ANDROID = "Android"
|
||||
GYM = "Gym"
|
||||
WEREWOLF = "Werewolf"
|
||||
MINCRAFT = "Minsraft"
|
||||
MINCRAFT = "Mincraft"
|
||||
STANFORDTOWN = "StanfordTown"
|
||||
|
||||
|
||||
|
|
@ -28,49 +32,25 @@ env_write_api_registry = WriteAPIRegistry()
|
|||
env_read_api_registry = ReadAPIRegistry()
|
||||
|
||||
|
||||
# def mark_as_readable(func):
|
||||
# """mark functionn as a readable one in ExtEnv, it observes something from ExtEnv"""
|
||||
#
|
||||
# def wrapper(self: ExtEnv, *args, **kwargs):
|
||||
# api_name = func.__name__
|
||||
# self.read_api_registry[api_name] = func
|
||||
# return func(self, *args, **kwargs)
|
||||
#
|
||||
# return wrapper
|
||||
#
|
||||
# def mark_as_writeable(func):
|
||||
# """mark functionn as a writeable one in ExtEnv, it does something to ExtEnv"""
|
||||
#
|
||||
# def wrapper(self: ExtEnv, *args, **kwargs):
|
||||
# api_name = func.__name__
|
||||
# self.write_api_registry[api_name] = func
|
||||
# return func(self, *args, **kwargs)
|
||||
#
|
||||
# return wrapper
|
||||
|
||||
def mark_as_readable(func):
|
||||
"""mark function as a readable one in ExtEnv, it observes something from ExtEnv"""
|
||||
"""mark functionn as a readable one in ExtEnv, it observes something from ExtEnv"""
|
||||
env_read_api_registry[func.__name__] = get_function_schema(func)
|
||||
return func
|
||||
|
||||
|
||||
def mark_as_writeable(func):
|
||||
"""mark function as a writeable one in ExtEnv, it does something to ExtEnv"""
|
||||
"""mark functionn as a writeable one in ExtEnv, it does something to ExtEnv"""
|
||||
env_write_api_registry[func.__name__] = get_function_schema(func)
|
||||
return func
|
||||
|
||||
|
||||
class ExtEnv(BaseModel):
|
||||
"""External Env to intergate actual game environment"""
|
||||
|
||||
write_api_registry: WriteAPIRegistry = Field(default_factory=WriteAPIRegistry, exclude=True)
|
||||
read_api_registry: ReadAPIRegistry = Field(default_factory=ReadAPIRegistry, exclude=True)
|
||||
|
||||
|
||||
class Env(ExtEnv):
|
||||
"""Env to intergate with MetaGPT"""
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
class ExtEnv(BaseModel):
|
||||
"""External Env to intergate actual game environment"""
|
||||
|
||||
def _check_api_exist(self, rw_api: Optional[str] = None):
|
||||
if not rw_api:
|
||||
|
|
@ -84,45 +64,25 @@ class Env(ExtEnv):
|
|||
else:
|
||||
return env_write_api_registry.get_apis()
|
||||
|
||||
# TODO adds is_coroutine_func
|
||||
# def observe(self, env_action: Union[str, EnvAPIAbstract]):
|
||||
# if isinstance(env_action, str):
|
||||
# read_api = env_write_api_registry.get(api_name=env_action)
|
||||
# self._check_api_exist(read_api)
|
||||
# res = read_api(self)
|
||||
# elif isinstance(env_action, EnvAPIAbstract):
|
||||
# read_api = env_write_api_registry.get(api_name=env_action.api_name)
|
||||
# self._check_api_exist(read_api)
|
||||
# res = read_api(self, *env_action.args, **env_action.kwargs)
|
||||
# return res
|
||||
#
|
||||
# def step(self, env_action: Union[str, Message, EnvAPIAbstract, list[EnvAPIAbstract]]):
|
||||
# res = None
|
||||
# if isinstance(env_action, Message):
|
||||
# self.publish_message(env_action)
|
||||
# elif isinstance(env_action, EnvAPIAbstract):
|
||||
# print(f"CURRENT API NAME: {env_action.api_name}")
|
||||
# write_api = self.write_api_registry.get(env_action.api_name)
|
||||
# self._check_api_exist(write_api)
|
||||
# res = write_api(self, *env_action.args, **env_action.kwargs)
|
||||
#
|
||||
# return res
|
||||
|
||||
def observe(self, env_action: Union[str, EnvAPIAbstract]):
|
||||
# TODO Adds is_coroutine_func
|
||||
async def observe(self, env_action: Union[str, EnvAPIAbstract]):
|
||||
"""get observation from particular api of ExtEnv"""
|
||||
if isinstance(env_action, str):
|
||||
read_api = env_read_api_registry.get(api_name=env_action)["func"]
|
||||
self._check_api_exist(read_api)
|
||||
res = read_api(self)
|
||||
if is_coroutine_func(read_api):
|
||||
res = await read_api(self)
|
||||
else:
|
||||
res = read_api(self)
|
||||
elif isinstance(env_action, EnvAPIAbstract):
|
||||
read_api = env_read_api_registry.get(api_name=env_action.api_name)["func"]
|
||||
self._check_api_exist(read_api)
|
||||
res = read_api(self, *env_action.args, **env_action.kwargs)
|
||||
|
||||
if is_coroutine_func(read_api):
|
||||
res = await read_api(self, *env_action.args, **env_action.kwargs)
|
||||
else:
|
||||
res = read_api(self, *env_action.args, **env_action.kwargs)
|
||||
return res
|
||||
|
||||
def step(self, env_action: Union[str, Message, EnvAPIAbstract, list[EnvAPIAbstract]]):
|
||||
async def step(self, env_action: Union[str, Message, EnvAPIAbstract, list[EnvAPIAbstract]]):
|
||||
"""execute through particular api of ExtEnv"""
|
||||
res = None
|
||||
if isinstance(env_action, Message):
|
||||
|
|
@ -130,9 +90,131 @@ class Env(ExtEnv):
|
|||
elif isinstance(env_action, EnvAPIAbstract):
|
||||
write_api = env_write_api_registry.get(env_action.api_name)["func"]
|
||||
self._check_api_exist(write_api)
|
||||
res = write_api(self, *env_action.args, **env_action.kwargs)
|
||||
if is_coroutine_func(write_api):
|
||||
res = await write_api(self, *env_action.args, **env_action.kwargs)
|
||||
else:
|
||||
res = write_api(self, *env_action.args, **env_action.kwargs)
|
||||
|
||||
return res
|
||||
|
||||
def publish_message(self, message: "Message"):
|
||||
pass
|
||||
class Environment(ExtEnv):
|
||||
"""环境,承载一批角色,角色可以向环境发布消息,可以被其他角色观察到
|
||||
Environment, hosting a batch of roles, roles can publish messages to the environment, and can be observed by other roles
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
desc: str = Field(default="") # 环境描述
|
||||
roles: dict[str, SerializeAsAny["Role"]] = Field(default_factory=dict, validate_default=True)
|
||||
member_addrs: Dict["Role", Set] = Field(default_factory=dict, exclude=True)
|
||||
history: str = "" # For debug
|
||||
context: Context = Field(default_factory=Context, exclude=True)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def init_roles(self):
|
||||
self.add_roles(self.roles.values())
|
||||
return self
|
||||
|
||||
def add_role(self, role: "Role"):
|
||||
"""增加一个在当前环境的角色
|
||||
Add a role in the current environment
|
||||
"""
|
||||
self.roles[role.profile] = role
|
||||
role.set_env(self)
|
||||
role.context = self.context
|
||||
|
||||
def add_roles(self, roles: Iterable["Role"]):
|
||||
"""增加一批在当前环境的角色
|
||||
Add a batch of characters in the current environment
|
||||
"""
|
||||
for role in roles:
|
||||
self.roles[role.profile] = role
|
||||
|
||||
for role in roles: # setup system message with roles
|
||||
role.set_env(self)
|
||||
role.context = self.context
|
||||
|
||||
def publish_message(self, message: Message, peekable: bool = True) -> bool:
|
||||
"""
|
||||
Distribute the message to the recipients.
|
||||
In accordance with the Message routing structure design in Chapter 2.2.1 of RFC 116, as already planned
|
||||
in RFC 113 for the entire system, the routing information in the Message is only responsible for
|
||||
specifying the message recipient, without concern for where the message recipient is located. How to
|
||||
route the message to the message recipient is a problem addressed by the transport framework designed
|
||||
in RFC 113.
|
||||
"""
|
||||
logger.debug(f"publish_message: {message.dump()}")
|
||||
found = False
|
||||
# According to the routing feature plan in Chapter 2.2.3.2 of RFC 113
|
||||
for role, addrs in self.member_addrs.items():
|
||||
if is_send_to(message, addrs):
|
||||
role.put_message(message)
|
||||
found = True
|
||||
if not found:
|
||||
logger.warning(f"Message no recipients: {message.dump()}")
|
||||
self.history += f"\n{message}" # For debug
|
||||
|
||||
return True
|
||||
|
||||
async def run(self, k=1):
|
||||
"""处理一次所有信息的运行
|
||||
Process all Role runs at once
|
||||
"""
|
||||
for _ in range(k):
|
||||
futures = []
|
||||
for role in self.roles.values():
|
||||
future = role.run()
|
||||
futures.append(future)
|
||||
|
||||
await asyncio.gather(*futures)
|
||||
logger.debug(f"is idle: {self.is_idle}")
|
||||
|
||||
def get_roles(self) -> dict[str, "Role"]:
|
||||
"""获得环境内的所有角色
|
||||
Process all Role runs at once
|
||||
"""
|
||||
return self.roles
|
||||
|
||||
def get_role(self, name: str) -> "Role":
|
||||
"""获得环境内的指定角色
|
||||
get all the environment roles
|
||||
"""
|
||||
return self.roles.get(name, None)
|
||||
|
||||
def role_names(self) -> list[str]:
|
||||
return [i.name for i in self.roles.values()]
|
||||
|
||||
@property
|
||||
def is_idle(self):
|
||||
"""If true, all actions have been executed."""
|
||||
for r in self.roles.values():
|
||||
if not r.is_idle:
|
||||
return False
|
||||
return True
|
||||
|
||||
def get_addresses(self, obj):
|
||||
"""Get the addresses of the object."""
|
||||
return self.member_addrs.get(obj, {})
|
||||
|
||||
def set_addresses(self, obj, addresses):
|
||||
"""Set the addresses of the object"""
|
||||
self.member_addrs[obj] = addresses
|
||||
|
||||
def archive(self, auto_archive=True):
|
||||
if auto_archive and self.context.git_repo:
|
||||
self.context.git_repo.archive()
|
||||
|
||||
@classmethod
|
||||
def model_rebuild(cls, **kwargs):
|
||||
from metagpt.roles.role import Role # noqa: F401
|
||||
|
||||
super().model_rebuild(**kwargs)
|
||||
|
||||
|
||||
Environment.model_rebuild()
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue