fix circular dependency for role/env

This commit is contained in:
shenchucheng 2024-08-11 01:43:05 +08:00
parent f4a3ff2261
commit ee4a536d55
14 changed files with 120 additions and 89 deletions

View file

@ -23,13 +23,14 @@
from __future__ import annotations
from enum import Enum
from typing import TYPE_CHECKING, Iterable, Optional, Set, Type, Union
from typing import Iterable, Optional, Set, Type, Union
from pydantic import BaseModel, ConfigDict, Field, SerializeAsAny, model_validator
from metagpt.actions import Action, ActionOutput
from metagpt.actions.action_node import ActionNode
from metagpt.actions.add_requirement import UserRequirement
from metagpt.base import BaseEnvironment, BaseRole
from metagpt.const import MESSAGE_ROUTE_TO_SELF
from metagpt.context_mixin import ContextMixin
from metagpt.logs import logger
@ -47,9 +48,6 @@ from metagpt.strategy.planner import Planner
from metagpt.utils.common import any_to_name, any_to_str, role_raise_decorator
from metagpt.utils.repair_llm_raw_output import extract_state_value_from_output
if TYPE_CHECKING:
from metagpt.environment import Environment # noqa: F401
PREFIX_TEMPLATE = """You are a {profile}, named {name}, your goal is {goal}. """
CONSTRAINT_TEMPLATE = "the constraint is {constraints}. "
@ -97,7 +95,7 @@ class RoleContext(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True)
# # env exclude=True to avoid `RecursionError: maximum recursion depth exceeded in comparison`
env: "Environment" = Field(default=None, exclude=True) # # avoid circular import
env: BaseEnvironment = Field(default=None, exclude=True) # # avoid circular import
# TODO judge if ser&deser
msg_buffer: MessageQueue = Field(
default_factory=MessageQueue, exclude=True
@ -123,14 +121,8 @@ class RoleContext(BaseModel):
def history(self) -> list[Message]:
return self.memory.get()
@classmethod
def model_rebuild(cls, **kwargs):
from metagpt.environment.base_env import Environment # noqa: F401
super().model_rebuild(**kwargs)
class Role(SerializationMixin, ContextMixin, BaseModel):
class Role(BaseRole, SerializationMixin, ContextMixin, BaseModel):
"""Role/Agent"""
model_config = ConfigDict(arbitrary_types_allowed=True, extra="allow")
@ -310,7 +302,7 @@ class Role(SerializationMixin, ContextMixin, BaseModel):
logger.debug(f"actions={self.actions}, state={state}")
self.set_todo(self.actions[self.rc.state] if state >= 0 else None)
def set_env(self, env: "Environment"):
def set_env(self, env: BaseEnvironment):
"""Set the environment in which the role works. The role can talk to the environment and can also receive
messages by observing."""
self.rc.env = env
@ -590,6 +582,3 @@ class Role(SerializationMixin, ContextMixin, BaseModel):
if self.actions:
return any_to_name(self.actions[0])
return ""
RoleContext.model_rebuild()