From 6208400f71ee926ed422aed9ed2cc160d7a0de4e Mon Sep 17 00:00:00 2001 From: better629 Date: Thu, 30 Nov 2023 21:42:09 +0800 Subject: [PATCH] fix role._rc init --- metagpt/environment.py | 4 ++++ metagpt/roles/role.py | 11 ++++++----- .../serialize_deserialize/test_team.py | 19 ++++++++++++++++--- 3 files changed, 26 insertions(+), 8 deletions(-) diff --git a/metagpt/environment.py b/metagpt/environment.py index bade53f50..bff12210d 100644 --- a/metagpt/environment.py +++ b/metagpt/environment.py @@ -31,6 +31,7 @@ class Environment(BaseModel): arbitrary_types_allowed = True def __init__(self, **kwargs): + roles = [] for role_key, role in kwargs.get("roles", {}).items(): current_role = kwargs["roles"][role_key] if isinstance(current_role, dict): @@ -41,8 +42,11 @@ class Environment(BaseModel): current_role = subclass(**current_role) break kwargs["roles"][role_key] = current_role + roles.append(current_role) super().__init__(**kwargs) + self.add_roles(roles) # add_roles again to init the Role.set_env + def serialize(self, stg_path: Path): roles_path = stg_path.joinpath("roles.json") roles_info = [] diff --git a/metagpt/roles/role.py b/metagpt/roles/role.py index 38f564caa..b78597d01 100644 --- a/metagpt/roles/role.py +++ b/metagpt/roles/role.py @@ -88,13 +88,14 @@ class RoleSetting(BaseModel): class RoleContext(BaseModel): """Role Runtime Context""" + # # env exclude=True to avoid `RecursionError: maximum recursion depth exceeded in comparison` env: "Environment" = Field(default=None, exclude=True) memory: Memory = Field(default_factory=Memory) - long_term_memory: LongTermMemory = Field(default_factory=LongTermMemory) + long_term_memory: LongTermMemory = Field(default_factory=LongTermMemory, exclude=True) # TODO not used now state: int = Field(default=-1) # -1 indicates initial or termination state where todo is None todo: Action = Field(default=None) watch: set[Type[Action]] = Field(default_factory=set) - news: list[Type[Message]] = Field(default=[]) + news: list[Type[Message]] = Field(default=[], exclude=True) # TODO not used react_mode: RoleReactMode = RoleReactMode.REACT # see `Role._set_react_mode` for definitions of the following two attributes max_react_loop: int = 1 @@ -128,12 +129,12 @@ class Role(BaseModel): desc: str = "" is_human: bool = False - _llm: BaseGPTAPI = Field(default_factory=LLM, exclude=True) + _llm: BaseGPTAPI = Field(default_factory=LLM) _setting: RoleSetting = Field(default_factory=RoleSetting, alias=True) _role_id: str = "" _states: list[str] = Field(default=[]) _actions: list[Action] = Field(default=[]) - _rc: RoleContext = Field(default=RoleContext, exclude=True) + _rc: RoleContext = Field(default=RoleContext) # builtin variables recovered: bool = False # to tag if a recovered role @@ -179,7 +180,7 @@ class Role(BaseModel): setting = RoleSetting(**kwargs[key]) object.__setattr__(self, "_setting", setting) elif key == "_rc": - _rc = RoleContext() + _rc = RoleContext(**kwargs["_rc"]) object.__setattr__(self, "_rc", _rc) else: if key == "_rc": diff --git a/tests/metagpt/serialize_deserialize/test_team.py b/tests/metagpt/serialize_deserialize/test_team.py index b8972135b..e5ec20f2e 100644 --- a/tests/metagpt/serialize_deserialize/test_team.py +++ b/tests/metagpt/serialize_deserialize/test_team.py @@ -39,7 +39,7 @@ def test_team_deserialize(): assert new_company.environment.get_role(arch.profile) is not None -def test_team_serdeser(): +def test_team_serdeser_save(): company = Team() company.hire([RoleC()]) @@ -60,12 +60,19 @@ async def test_team_recover(): shutil.rmtree(stg_path, ignore_errors=True) company = Team() - company.hire([RoleC()]) + role_c = RoleC() + company.hire([role_c]) company.start_project(idea) await company.run(n_round=4) ser_data = company.dict() new_company = Team(**ser_data) + + new_role_c = new_company.environment.get_role(role_c.profile) + assert new_role_c._rc.memory == role_c._rc.memory + assert new_role_c._rc.env != role_c._rc.env # due to Action raise, role's memory has been changed. + assert new_role_c._rc.env.memory == role_c._rc.env.memory + assert new_company.environment.memory.count() == 1 assert type(list(new_company.environment.roles.values())[0]._actions[0]) == ActionOK @@ -80,11 +87,17 @@ async def test_team_recover_save(): shutil.rmtree(stg_path, ignore_errors=True) company = Team() - company.hire([RoleC()]) + role_c = RoleC() + company.hire([role_c]) company.start_project(idea) await company.run(n_round=4) new_company = Team.recover(stg_path) + new_role_c = new_company.environment.get_role(role_c.profile) + assert new_role_c._rc.memory == role_c._rc.memory + assert new_role_c._rc.env != role_c._rc.env # due to Action raise, role's memory has been changed. + assert new_role_c._rc.env.memory == role_c._rc.env.memory + new_company.start_project(idea) await new_company.run(n_round=4)