Merge pull request #833 from shenchucheng/fix-role-not-fully-defined-error-debugger

Fix `Role` not fully defined error debugger
This commit is contained in:
geekan 2024-02-02 20:20:27 +08:00 committed by GitHub
commit a7cb21a3ae
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 33 additions and 18 deletions

View file

@ -4,7 +4,7 @@
import asyncio
from enum import Enum
from typing import Any, Iterable, Optional, Set, Union
from typing import TYPE_CHECKING, Any, Iterable, Optional, Set, Union
from pydantic import BaseModel, ConfigDict, Field, SerializeAsAny, model_validator
@ -15,10 +15,12 @@ from metagpt.environment.api.env_api import (
WriteAPIRegistry,
)
from metagpt.logs import logger
from metagpt.roles.role import Role
from metagpt.schema import Message
from metagpt.utils.common import get_function_schema, is_coroutine_func, is_send_to
if TYPE_CHECKING:
from metagpt.roles.role import Role # noqa: F401
class EnvType(Enum):
ANDROID = "Android"
@ -101,8 +103,8 @@ class Environment(ExtEnv):
model_config = ConfigDict(arbitrary_types_allowed=True)
desc: str = Field(default="") # 环境描述
roles: dict[str, SerializeAsAny[Role]] = Field(default_factory=dict, validate_default=True)
member_addrs: dict[Role, Set] = Field(default_factory=dict, exclude=True)
roles: dict[str, SerializeAsAny["Role"]] = Field(default_factory=dict, validate_default=True)
member_addrs: dict["Role", Set] = Field(default_factory=dict, exclude=True)
history: str = "" # For debug
context: Context = Field(default_factory=Context, exclude=True)
@ -111,7 +113,7 @@ class Environment(ExtEnv):
self.add_roles(self.roles.values())
return self
def add_role(self, role: Role):
def add_role(self, role: "Role"):
"""增加一个在当前环境的角色
Add a role in the current environment
"""
@ -119,7 +121,7 @@ class Environment(ExtEnv):
role.set_env(self)
role.context = self.context
def add_roles(self, roles: Iterable[Role]):
def add_roles(self, roles: Iterable["Role"]):
"""增加一批在当前环境的角色
Add a batch of characters in the current environment
"""
@ -165,13 +167,13 @@ class Environment(ExtEnv):
await asyncio.gather(*futures)
logger.debug(f"is idle: {self.is_idle}")
def get_roles(self) -> dict[str, Role]:
def get_roles(self) -> dict[str, "Role"]:
"""获得环境内的所有角色
Process all Role runs at once
"""
return self.roles
def get_role(self, name: str) -> Role:
def get_role(self, name: str) -> "Role":
"""获得环境内的指定角色
get all the environment roles
"""
@ -199,3 +201,12 @@ class Environment(ExtEnv):
def archive(self, auto_archive=True):
if auto_archive and self.context.git_repo:
self.context.git_repo.archive()
@classmethod
def model_rebuild(cls, **kwargs):
from metagpt.roles.role import Role # noqa: F401
super().model_rebuild(**kwargs)
Environment.model_rebuild()

View file

@ -23,7 +23,7 @@
from __future__ import annotations
from enum import Enum
from typing import Iterable, Optional, Set, Type, Union
from typing import TYPE_CHECKING, Iterable, Optional, Set, Type, Union
from pydantic import BaseModel, ConfigDict, Field, SerializeAsAny, model_validator
@ -39,6 +39,10 @@ from metagpt.utils.common import any_to_name, any_to_str, role_raise_decorator
from metagpt.utils.project_repo import ProjectRepo
from metagpt.utils.repair_llm_raw_output import extract_state_value_from_output
if TYPE_CHECKING:
from metagpt.environment import Environment # noqa: F401
PREFIX_TEMPLATE = """You are a {profile}, named {name}, your goal is {goal}. """
CONSTRAINT_TEMPLATE = "the constraint is {constraints}. "
@ -117,6 +121,12 @@ class RoleContext(BaseModel):
def history(self) -> list[Message]:
return self.memory.get()
@classmethod
def model_rebuild(cls, **kwargs):
from metagpt.environment.base_env import Environment # noqa: F401
super().model_rebuild(**kwargs)
class Role(SerializationMixin, ContextMixin, BaseModel):
"""Role/Agent"""
@ -155,7 +165,6 @@ class Role(SerializationMixin, ContextMixin, BaseModel):
return self
def _process_role_extra(self):
self.pydantic_rebuild_model()
kwargs = self.model_extra or {}
if self.is_human:
@ -168,14 +177,6 @@ class Role(SerializationMixin, ContextMixin, BaseModel):
if self.latest_observed_msg:
self.recovered = True
@staticmethod
def pydantic_rebuild_model():
"""Rebuild model to avoid `RecursionError: maximum recursion depth exceeded in comparison`"""
from metagpt.environment import Environment
Environment
Role.model_rebuild()
@property
def todo(self) -> Action:
"""Get action to do"""
@ -559,3 +560,6 @@ class Role(SerializationMixin, ContextMixin, BaseModel):
if self.actions:
return any_to_name(self.actions[0])
return ""
RoleContext.model_rebuild()