fix env=None when init Team with env=xxx

This commit is contained in:
better629 2023-12-27 16:34:43 +08:00
parent 7d523b3922
commit 2dbaee0ff2
4 changed files with 53 additions and 60 deletions

View file

@ -57,6 +57,7 @@ class Environment(BaseModel):
@model_validator(mode="after")
def init_roles(self):
self.add_roles(self.roles.values())
return self
def serialize(self, stg_path: Path):
roles_path = stg_path.joinpath("roles.json")

View file

@ -195,7 +195,7 @@ class Message(BaseModel):
def dump(self) -> str:
"""Convert the object to json string"""
return self.model_dump_json(exclude_none=True)
return self.model_dump_json(exclude_none=True, warnings=False)
@staticmethod
@handle_exception(exception_type=JSONDecodeError, default_return=None)
@ -250,15 +250,6 @@ class MessageQueue(BaseModel):
_queue: Queue = PrivateAttr(default_factory=Queue)
# _private_attributes = {"_queue": Queue()}
# def __init__(self, **kwargs: Any):
# for key in self._private_attributes.keys():
# if key in kwargs:
# object.__setattr__(self, key, kwargs[key])
# else:
# object.__setattr__(self, key, Queue())
def pop(self) -> Message | None:
"""Pop one message from the queue."""
try:

View file

@ -71,9 +71,8 @@ class Team(BaseModel):
# recover environment
environment = Environment.deserialize(stg_path=stg_path.joinpath("environment"))
# team_info.update({"env": environment})
team_info.update({"env": environment})
team = Team(**team_info)
team.env = environment
return team
def hire(self, roles: list[Role]):

View file

@ -9,38 +9,40 @@ import pytest
from metagpt.const import SERDESER_PATH
from metagpt.logs import logger
from metagpt.roles import Architect, ProductManager, ProjectManager
from metagpt.team import Team
from tests.metagpt.serialize_deserialize.test_serdeser_base import (
ActionOK,
RoleA,
RoleB,
RoleC,
serdeser_path,
)
# def test_team_deserialize():
# company = Team()
#
# pm = ProductManager()
# arch = Architect()
# company.hire(
# [
# pm,
# arch,
# ProjectManager(),
# ]
# )
# assert len(company.env.get_roles()) == 3
# ser_company = company.model_dump()
# print("ser_company ", ser_company)
# new_company = Team.model_validate(ser_company)
#
# assert len(new_company.env.get_roles()) == 3
# assert new_company.env.get_role(pm.profile) is not None
#
# new_pm = new_company.env.get_role(pm.profile)
# assert type(new_pm) == ProductManager
# assert new_company.env.get_role(pm.profile) is not None
# assert new_company.env.get_role(arch.profile) is not None
def test_team_deserialize():
company = Team()
pm = ProductManager()
arch = Architect()
company.hire(
[
pm,
arch,
ProjectManager(),
]
)
assert len(company.env.get_roles()) == 3
ser_company = company.model_dump()
new_company = Team.model_validate(ser_company)
assert len(new_company.env.get_roles()) == 3
assert new_company.env.get_role(pm.profile) is not None
new_pm = new_company.env.get_role(pm.profile)
assert type(new_pm) == ProductManager
assert new_company.env.get_role(pm.profile) is not None
assert new_company.env.get_role(arch.profile) is not None
def test_team_serdeser_save():
@ -58,30 +60,30 @@ def test_team_serdeser_save():
assert len(new_company.env.roles) == 1
# @pytest.mark.asyncio
# async def test_team_recover():
# idea = "write a snake game"
# stg_path = SERDESER_PATH.joinpath("team")
# shutil.rmtree(stg_path, ignore_errors=True)
#
# company = Team()
# role_c = RoleC()
# company.hire([role_c])
# company.run_project(idea)
# await company.run(n_round=4)
#
# ser_data = company.model_dump()
# new_company = Team(**ser_data)
#
# new_role_c = new_company.env.get_role(role_c.profile)
# # assert new_role_c.rc.memory == role_c.rc.memory # TODO
# assert new_role_c.rc.env != role_c.rc.env # TODO
# assert type(list(new_company.env.roles.values())[0].actions[0]) == ActionOK
#
# new_company.run_project(idea)
# await new_company.run(n_round=4)
#
#
@pytest.mark.asyncio
async def test_team_recover():
idea = "write a snake game"
stg_path = SERDESER_PATH.joinpath("team")
shutil.rmtree(stg_path, ignore_errors=True)
company = Team()
role_c = RoleC()
company.hire([role_c])
company.run_project(idea)
await company.run(n_round=4)
ser_data = company.model_dump()
new_company = Team(**ser_data)
new_company.env.get_role(role_c.profile)
# assert new_role_c.rc.memory == role_c.rc.memory # TODO
# assert new_role_c.rc.env != role_c.rc.env # TODO
assert type(list(new_company.env.roles.values())[0].actions[0]) == ActionOK
new_company.run_project(idea)
await new_company.run(n_round=4)
@pytest.mark.asyncio
async def test_team_recover_save():
idea = "write a 2048 web game"