fix circular dependency for role/env

This commit is contained in:
shenchucheng 2024-08-11 01:43:05 +08:00
parent f4a3ff2261
commit ee4a536d55
14 changed files with 120 additions and 89 deletions

0
metagpt/base/__init__.py Normal file
View file

38
metagpt/base/base_env.py Normal file
View file

@ -0,0 +1,38 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc : base environment
from abc import abstractmethod
from typing import Any, Optional
from metagpt.base.base_env_space import BaseEnvAction, BaseEnvObsParams
from metagpt.schema import Message
class BaseEnvironment:
"""Base environment"""
@abstractmethod
def reset(
self,
*,
seed: Optional[int] = None,
options: Optional[dict[str, Any]] = None,
) -> tuple[dict[str, Any], dict[str, Any]]:
"""Implement this to get init observation"""
@abstractmethod
def observe(self, obs_params: Optional[BaseEnvObsParams] = None) -> Any:
"""Implement this if you want to get partial observation from the env"""
@abstractmethod
def step(self, action: BaseEnvAction) -> tuple[dict[str, Any], float, bool, bool, dict[str, Any]]:
"""Implement this to feed a action and then get new observation from the env"""
@abstractmethod
def publish_message(self, message: Message, peekable: bool = True) -> bool:
"""Distribute the message to the recipients."""
@abstractmethod
async def run(self, k=1):
"""Process all task at once"""

33
metagpt/base/base_role.py Normal file
View file

@ -0,0 +1,33 @@
from abc import abstractmethod
from typing import Optional, Union
from metagpt.schema import Message
class BaseRole:
"""Abstract base class for all roles."""
name: str
is_idle: bool
@abstractmethod
def think(self):
"""Consider what to do and decide on the next course of action."""
raise NotImplementedError
@abstractmethod
def act(self):
"""Perform the current action."""
raise NotImplementedError
@abstractmethod
async def react(self) -> Message:
"""Entry to one of three strategies by which Role reacts to the observed Message."""
@abstractmethod
async def run(self, with_message: Optional[Union[str, Message, list[str]]] = None) -> Optional[Message]:
"""Observe, and think and act based on the results of the observation."""
@abstractmethod
def get_memories(self, k: int = 0) -> list[Message]:
"""Return the most recent k memories of this role."""

View file

@ -8,9 +8,9 @@ from typing import Any, Optional
from pydantic import Field
from metagpt.base.base_env_space import BaseEnvAction, BaseEnvObsParams
from metagpt.environment.android.const import ADB_EXEC_FAIL
from metagpt.environment.base_env import ExtEnv, mark_as_readable, mark_as_writeable
from metagpt.environment.base_env_space import BaseEnvAction, BaseEnvObsParams
class AndroidExtEnv(ExtEnv):

View file

@ -5,28 +5,26 @@
import asyncio
from abc import abstractmethod
from enum import Enum
from typing import TYPE_CHECKING, Any, Dict, Iterable, Optional, Set, Union
from typing import Any, Dict, Iterable, Optional, Set, Union
from gymnasium import spaces
from gymnasium.core import ActType, ObsType
from pydantic import BaseModel, ConfigDict, Field, SerializeAsAny, model_validator
from metagpt.base import BaseEnvironment, BaseRole
from metagpt.base.base_env_space import BaseEnvAction, BaseEnvObsParams
from metagpt.context import Context
from metagpt.environment.api.env_api import (
EnvAPIAbstract,
ReadAPIRegistry,
WriteAPIRegistry,
)
from metagpt.environment.base_env_space import BaseEnvAction, BaseEnvObsParams
from metagpt.logs import logger
from metagpt.memory import Memory
from metagpt.schema import Message
from metagpt.utils.common import get_function_schema, is_coroutine_func, is_send_to
from metagpt.utils.git_repository import GitRepository
if TYPE_CHECKING:
from metagpt.roles.role import Role # noqa: F401
class EnvType(Enum):
ANDROID = "Android"
@ -52,7 +50,7 @@ def mark_as_writeable(func):
return func
class ExtEnv(BaseModel):
class ExtEnv(BaseEnvironment, BaseModel):
"""External Env to integrate actual game environment"""
model_config = ConfigDict(arbitrary_types_allowed=True)
@ -131,8 +129,8 @@ class Environment(ExtEnv):
model_config = ConfigDict(arbitrary_types_allowed=True)
desc: str = Field(default="") # 环境描述
roles: dict[str, SerializeAsAny["Role"]] = Field(default_factory=dict, validate_default=True)
member_addrs: Dict["Role", Set] = Field(default_factory=dict, exclude=True)
roles: dict[str, SerializeAsAny[BaseRole]] = Field(default_factory=dict, validate_default=True)
member_addrs: Dict[BaseRole, Set] = Field(default_factory=dict, exclude=True)
history: Memory = Field(default_factory=Memory) # For debug
context: Context = Field(default_factory=Context, exclude=True)
@ -155,7 +153,7 @@ class Environment(ExtEnv):
self.add_roles(self.roles.values())
return self
def add_role(self, role: "Role"):
def add_role(self, role: BaseRole):
"""增加一个在当前环境的角色
Add a role in the current environment
"""
@ -163,7 +161,7 @@ class Environment(ExtEnv):
role.set_env(self)
role.context = self.context
def add_roles(self, roles: Iterable["Role"]):
def add_roles(self, roles: Iterable[BaseRole]):
"""增加一批在当前环境的角色
Add a batch of characters in the current environment
"""
@ -212,13 +210,13 @@ class Environment(ExtEnv):
await asyncio.gather(*futures)
logger.debug(f"is idle: {self.is_idle}")
def get_roles(self) -> dict[str, "Role"]:
def get_roles(self) -> dict[str, BaseRole]:
"""获得环境内的所有角色
Process all Role runs at once
"""
return self.roles
def get_role(self, name: str) -> "Role":
def get_role(self, name: str) -> BaseRole:
"""获得环境内的指定角色
get all the environment roles
"""
@ -247,12 +245,3 @@ class Environment(ExtEnv):
if auto_archive and self.context.kwargs.get("project_path"):
git_repo = GitRepository(self.context.kwargs.project_path)
git_repo.archive()
@classmethod
def model_rebuild(cls, **kwargs):
from metagpt.roles.role import Role # noqa: F401
super().model_rebuild(**kwargs)
Environment.model_rebuild()

View file

@ -10,8 +10,8 @@ from typing import Any, Optional
import requests
from pydantic import ConfigDict, Field, model_validator
from metagpt.base.base_env_space import BaseEnvAction, BaseEnvObsParams
from metagpt.environment.base_env import ExtEnv, mark_as_writeable
from metagpt.environment.base_env_space import BaseEnvAction, BaseEnvObsParams
from metagpt.environment.minecraft.const import (
MC_CKPT_DIR,
MC_CORE_INVENTORY_ITEMS,

View file

@ -9,7 +9,7 @@ import numpy.typing as npt
from gymnasium import spaces
from pydantic import ConfigDict, Field, field_validator
from metagpt.environment.base_env_space import (
from metagpt.base.base_env_space import (
BaseEnvAction,
BaseEnvActionType,
BaseEnvObsParams,

View file

@ -9,8 +9,8 @@ from typing import Any, Callable, Optional
from pydantic import ConfigDict, Field
from metagpt.base.base_env_space import BaseEnvAction, BaseEnvObsParams
from metagpt.environment.base_env import ExtEnv, mark_as_readable, mark_as_writeable
from metagpt.environment.base_env_space import BaseEnvAction, BaseEnvObsParams
from metagpt.logs import logger

View file

@ -16,7 +16,7 @@ import time
from datetime import datetime, timedelta
from operator import itemgetter
from pathlib import Path
from typing import TYPE_CHECKING, Optional
from typing import Optional
from pydantic import ConfigDict, Field, field_validator, model_validator
@ -27,6 +27,7 @@ from metagpt.environment.stanford_town.env_space import (
EnvObsParams,
EnvObsType,
)
from metagpt.environment.stanford_town.stanford_town_env import StanfordTownEnv
from metagpt.ext.stanford_town.actions.dummy_action import DummyAction, DummyMessage
from metagpt.ext.stanford_town.actions.inner_voice_action import (
AgentWhisperThoughtAction,
@ -49,28 +50,15 @@ from metagpt.roles.role import Role, RoleContext
from metagpt.schema import Message
from metagpt.utils.common import any_to_str
if TYPE_CHECKING:
from metagpt.environment.stanford_town.stanford_town_env import ( # noqa: F401
StanfordTownEnv,
)
class STRoleContext(RoleContext):
model_config = ConfigDict(arbitrary_types_allowed=True)
env: "StanfordTownEnv" = Field(default=None, exclude=True)
env: StanfordTownEnv = Field(default=None, exclude=True)
memory: AgentMemory = Field(default_factory=AgentMemory)
scratch: Scratch = Field(default_factory=Scratch)
spatial_memory: MemoryTree = Field(default_factory=MemoryTree)
@classmethod
def model_rebuild(cls, **kwargs):
from metagpt.environment.stanford_town.stanford_town_env import ( # noqa: F401
StanfordTownEnv,
)
super(RoleContext, cls).model_rebuild(**kwargs)
class STRole(Role):
# add a role's property structure to store role's age and so on like GA's Scratch.
@ -635,6 +623,3 @@ class STRole(Role):
time.sleep(0.5)
return DummyMessage()
STRoleContext.model_rebuild()

View file

@ -23,13 +23,14 @@
from __future__ import annotations
from enum import Enum
from typing import TYPE_CHECKING, Iterable, Optional, Set, Type, Union
from typing import Iterable, Optional, Set, Type, Union
from pydantic import BaseModel, ConfigDict, Field, SerializeAsAny, model_validator
from metagpt.actions import Action, ActionOutput
from metagpt.actions.action_node import ActionNode
from metagpt.actions.add_requirement import UserRequirement
from metagpt.base import BaseEnvironment, BaseRole
from metagpt.const import MESSAGE_ROUTE_TO_SELF
from metagpt.context_mixin import ContextMixin
from metagpt.logs import logger
@ -47,9 +48,6 @@ from metagpt.strategy.planner import Planner
from metagpt.utils.common import any_to_name, any_to_str, role_raise_decorator
from metagpt.utils.repair_llm_raw_output import extract_state_value_from_output
if TYPE_CHECKING:
from metagpt.environment import Environment # noqa: F401
PREFIX_TEMPLATE = """You are a {profile}, named {name}, your goal is {goal}. """
CONSTRAINT_TEMPLATE = "the constraint is {constraints}. "
@ -97,7 +95,7 @@ class RoleContext(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True)
# # env exclude=True to avoid `RecursionError: maximum recursion depth exceeded in comparison`
env: "Environment" = Field(default=None, exclude=True) # # avoid circular import
env: BaseEnvironment = Field(default=None, exclude=True) # # avoid circular import
# TODO judge if ser&deser
msg_buffer: MessageQueue = Field(
default_factory=MessageQueue, exclude=True
@ -123,14 +121,8 @@ class RoleContext(BaseModel):
def history(self) -> list[Message]:
return self.memory.get()
@classmethod
def model_rebuild(cls, **kwargs):
from metagpt.environment.base_env import Environment # noqa: F401
super().model_rebuild(**kwargs)
class Role(SerializationMixin, ContextMixin, BaseModel):
class Role(BaseRole, SerializationMixin, ContextMixin, BaseModel):
"""Role/Agent"""
model_config = ConfigDict(arbitrary_types_allowed=True, extra="allow")
@ -310,7 +302,7 @@ class Role(SerializationMixin, ContextMixin, BaseModel):
logger.debug(f"actions={self.actions}, state={state}")
self.set_todo(self.actions[self.rc.state] if state >= 0 else None)
def set_env(self, env: "Environment"):
def set_env(self, env: BaseEnvironment):
"""Set the environment in which the role works. The role can talk to the environment and can also receive
messages by observing."""
self.rc.env = env
@ -590,6 +582,3 @@ class Role(SerializationMixin, ContextMixin, BaseModel):
if self.actions:
return any_to_name(self.actions[0])
return ""
RoleContext.model_rebuild()