fix circular dependency for role/env

This commit is contained in:
shenchucheng 2024-08-11 01:43:05 +08:00
parent f4a3ff2261
commit ee4a536d55
14 changed files with 120 additions and 89 deletions

View file

@ -8,9 +8,9 @@ from typing import Any, Optional
from pydantic import Field
from metagpt.base.base_env_space import BaseEnvAction, BaseEnvObsParams
from metagpt.environment.android.const import ADB_EXEC_FAIL
from metagpt.environment.base_env import ExtEnv, mark_as_readable, mark_as_writeable
from metagpt.environment.base_env_space import BaseEnvAction, BaseEnvObsParams
class AndroidExtEnv(ExtEnv):

View file

@ -5,28 +5,26 @@
import asyncio
from abc import abstractmethod
from enum import Enum
from typing import TYPE_CHECKING, Any, Dict, Iterable, Optional, Set, Union
from typing import Any, Dict, Iterable, Optional, Set, Union
from gymnasium import spaces
from gymnasium.core import ActType, ObsType
from pydantic import BaseModel, ConfigDict, Field, SerializeAsAny, model_validator
from metagpt.base import BaseEnvironment, BaseRole
from metagpt.base.base_env_space import BaseEnvAction, BaseEnvObsParams
from metagpt.context import Context
from metagpt.environment.api.env_api import (
EnvAPIAbstract,
ReadAPIRegistry,
WriteAPIRegistry,
)
from metagpt.environment.base_env_space import BaseEnvAction, BaseEnvObsParams
from metagpt.logs import logger
from metagpt.memory import Memory
from metagpt.schema import Message
from metagpt.utils.common import get_function_schema, is_coroutine_func, is_send_to
from metagpt.utils.git_repository import GitRepository
if TYPE_CHECKING:
from metagpt.roles.role import Role # noqa: F401
class EnvType(Enum):
ANDROID = "Android"
@ -52,7 +50,7 @@ def mark_as_writeable(func):
return func
class ExtEnv(BaseModel):
class ExtEnv(BaseEnvironment, BaseModel):
"""External Env to integrate actual game environment"""
model_config = ConfigDict(arbitrary_types_allowed=True)
@ -131,8 +129,8 @@ class Environment(ExtEnv):
model_config = ConfigDict(arbitrary_types_allowed=True)
desc: str = Field(default="") # 环境描述
roles: dict[str, SerializeAsAny["Role"]] = Field(default_factory=dict, validate_default=True)
member_addrs: Dict["Role", Set] = Field(default_factory=dict, exclude=True)
roles: dict[str, SerializeAsAny[BaseRole]] = Field(default_factory=dict, validate_default=True)
member_addrs: Dict[BaseRole, Set] = Field(default_factory=dict, exclude=True)
history: Memory = Field(default_factory=Memory) # For debug
context: Context = Field(default_factory=Context, exclude=True)
@ -155,7 +153,7 @@ class Environment(ExtEnv):
self.add_roles(self.roles.values())
return self
def add_role(self, role: "Role"):
def add_role(self, role: BaseRole):
"""增加一个在当前环境的角色
Add a role in the current environment
"""
@ -163,7 +161,7 @@ class Environment(ExtEnv):
role.set_env(self)
role.context = self.context
def add_roles(self, roles: Iterable["Role"]):
def add_roles(self, roles: Iterable[BaseRole]):
"""增加一批在当前环境的角色
Add a batch of characters in the current environment
"""
@ -212,13 +210,13 @@ class Environment(ExtEnv):
await asyncio.gather(*futures)
logger.debug(f"is idle: {self.is_idle}")
def get_roles(self) -> dict[str, "Role"]:
def get_roles(self) -> dict[str, BaseRole]:
"""获得环境内的所有角色
Process all Role runs at once
"""
return self.roles
def get_role(self, name: str) -> "Role":
def get_role(self, name: str) -> BaseRole:
"""获得环境内的指定角色
get all the environment roles
"""
@ -247,12 +245,3 @@ class Environment(ExtEnv):
if auto_archive and self.context.kwargs.get("project_path"):
git_repo = GitRepository(self.context.kwargs.project_path)
git_repo.archive()
@classmethod
def model_rebuild(cls, **kwargs):
from metagpt.roles.role import Role # noqa: F401
super().model_rebuild(**kwargs)
Environment.model_rebuild()

View file

@ -1,33 +0,0 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc :
from enum import IntEnum
from pydantic import BaseModel, ConfigDict, Field
class BaseEnvActionType(IntEnum):
# # NONE = 0 # no action to run, just get observation
pass
class BaseEnvAction(BaseModel):
"""env action type and its related params of action functions/apis"""
model_config = ConfigDict(arbitrary_types_allowed=True)
action_type: int = Field(default=0, description="action type")
class BaseEnvObsType(IntEnum):
# # NONE = 0 # get whole observation from env
pass
class BaseEnvObsParams(BaseModel):
"""observation params for different EnvObsType to get its observe result"""
model_config = ConfigDict(arbitrary_types_allowed=True)
obs_type: int = Field(default=0, description="observation type")

View file

@ -10,8 +10,8 @@ from typing import Any, Optional
import requests
from pydantic import ConfigDict, Field, model_validator
from metagpt.base.base_env_space import BaseEnvAction, BaseEnvObsParams
from metagpt.environment.base_env import ExtEnv, mark_as_writeable
from metagpt.environment.base_env_space import BaseEnvAction, BaseEnvObsParams
from metagpt.environment.minecraft.const import (
MC_CKPT_DIR,
MC_CORE_INVENTORY_ITEMS,

View file

@ -9,7 +9,7 @@ import numpy.typing as npt
from gymnasium import spaces
from pydantic import ConfigDict, Field, field_validator
from metagpt.environment.base_env_space import (
from metagpt.base.base_env_space import (
BaseEnvAction,
BaseEnvActionType,
BaseEnvObsParams,

View file

@ -9,8 +9,8 @@ from typing import Any, Callable, Optional
from pydantic import ConfigDict, Field
from metagpt.base.base_env_space import BaseEnvAction, BaseEnvObsParams
from metagpt.environment.base_env import ExtEnv, mark_as_readable, mark_as_writeable
from metagpt.environment.base_env_space import BaseEnvAction, BaseEnvObsParams
from metagpt.logs import logger