From 630eecd18bb79bbe32f440dbf98141536dbe264a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=8E=98=E6=9D=83=20=E9=A9=AC?= Date: Mon, 25 Mar 2024 21:04:16 +0800 Subject: [PATCH 1/2] fixbug: #1094 --- metagpt/context.py | 39 +++++++++++++++---- metagpt/team.py | 5 ++- .../serialize_deserialize/test_team.py | 15 +++++++ 3 files changed, 51 insertions(+), 8 deletions(-) diff --git a/metagpt/context.py b/metagpt/context.py index 0add4c71a..f1199c492 100644 --- a/metagpt/context.py +++ b/metagpt/context.py @@ -7,7 +7,7 @@ """ import os from pathlib import Path -from typing import Any, Optional +from typing import Any, Dict, Optional from pydantic import BaseModel, ConfigDict @@ -78,12 +78,6 @@ class Context(BaseModel): # env.update({k: v for k, v in i.items() if isinstance(v, str)}) return env - # def use_llm(self, name: Optional[str] = None, provider: LLMType = LLMType.OPENAI) -> BaseLLM: - # """Use a LLM instance""" - # self._llm_config = self.config.get_llm_config(name, provider) - # self._llm = None - # return self._llm - def _select_costmanager(self, llm_config: LLMConfig) -> CostManager: """Return a CostManager instance""" if llm_config.api_type == LLMType.FIREWORKS: @@ -108,3 +102,34 @@ class Context(BaseModel): if llm.cost_manager is None: llm.cost_manager = self._select_costmanager(llm_config) return llm + + def serialize(self) -> Dict[str, Any]: + """Serialize the object's attributes into a dictionary. + + Returns: + Dict[str, Any]: A dictionary containing serialized data. + """ + return { + "workdir": str(self.repo.workdir) if self.repo else "", + "kwargs": {k: v for k, v in self.kwargs.__dict__.items()}, + } + + def deserialize(self, serialized_data: Dict[str, Any]): + """Deserialize the given serialized data and update the object's attributes accordingly. + + Args: + serialized_data (Dict[str, Any]): A dictionary containing serialized data. + """ + if not serialized_data: + return + workdir = serialized_data.get("workdir") + if workdir: + self.git_repo = GitRepository(local_path=workdir, auto_init=True) + self.repo = ProjectRepo(self.git_repo) + src_workspace = self.git_repo.workdir / self.git_repo.workdir.name + if src_workspace.exists(): + self.src_workspace = src_workspace + kwargs = serialized_data.get("kwargs") + if kwargs: + for k, v in kwargs.items(): + self.kwargs.set(k, v) diff --git a/metagpt/team.py b/metagpt/team.py index 35f987b57..79c4c36aa 100644 --- a/metagpt/team.py +++ b/metagpt/team.py @@ -56,8 +56,10 @@ class Team(BaseModel): def serialize(self, stg_path: Path = None): stg_path = SERDESER_PATH.joinpath("team") if stg_path is None else stg_path team_info_path = stg_path.joinpath("team.json") + serialized_data = self.model_dump() + serialized_data["context"] = self.env.context.serialize() - write_json_file(team_info_path, self.model_dump()) + write_json_file(team_info_path, serialized_data) @classmethod def deserialize(cls, stg_path: Path, context: Context = None) -> "Team": @@ -71,6 +73,7 @@ class Team(BaseModel): team_info: dict = read_json_file(team_info_path) ctx = context or Context() + ctx.deserialize(team_info.pop("context", None)) team = Team(**team_info, context=ctx) return team diff --git a/tests/metagpt/serialize_deserialize/test_team.py b/tests/metagpt/serialize_deserialize/test_team.py index dbd38422d..a300c78c1 100644 --- a/tests/metagpt/serialize_deserialize/test_team.py +++ b/tests/metagpt/serialize_deserialize/test_team.py @@ -8,6 +8,7 @@ from pathlib import Path import pytest +from metagpt.context import Context from metagpt.logs import logger from metagpt.roles import Architect, ProductManager, ProjectManager from metagpt.team import Team @@ -146,5 +147,19 @@ async def test_team_recover_multi_roles_save(mocker, context): await new_company.run(n_round=4) +@pytest.mark.asyncio +async def test_context(context): + context.kwargs.set("a", "a") + company = Team(context=context) + + save_to = context.repo.workdir / "serial" + company.serialize(save_to) + + company.deserialize(save_to, Context()) + assert company.env.context.repo + assert company.env.context.repo.workdir == context.repo.workdir + assert context.kwargs.a == "a" + + if __name__ == "__main__": pytest.main([__file__, "-s"]) From 6c2d4136649f476ebc58d1cea4f1faffe1faa109 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=8E=98=E6=9D=83=20=E9=A9=AC?= Date: Mon, 25 Mar 2024 23:23:43 +0800 Subject: [PATCH 2/2] feat: add cost_manager --- metagpt/context.py | 4 ++++ tests/metagpt/serialize_deserialize/test_team.py | 4 +++- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/metagpt/context.py b/metagpt/context.py index f1199c492..2bd541202 100644 --- a/metagpt/context.py +++ b/metagpt/context.py @@ -112,6 +112,7 @@ class Context(BaseModel): return { "workdir": str(self.repo.workdir) if self.repo else "", "kwargs": {k: v for k, v in self.kwargs.__dict__.items()}, + "cost_manager": self.cost_manager.model_dump_json(), } def deserialize(self, serialized_data: Dict[str, Any]): @@ -133,3 +134,6 @@ class Context(BaseModel): if kwargs: for k, v in kwargs.items(): self.kwargs.set(k, v) + cost_manager = serialized_data.get("cost_manager") + if cost_manager: + self.cost_manager.model_validate_json(cost_manager) diff --git a/tests/metagpt/serialize_deserialize/test_team.py b/tests/metagpt/serialize_deserialize/test_team.py index a300c78c1..6312e1fde 100644 --- a/tests/metagpt/serialize_deserialize/test_team.py +++ b/tests/metagpt/serialize_deserialize/test_team.py @@ -150,6 +150,7 @@ async def test_team_recover_multi_roles_save(mocker, context): @pytest.mark.asyncio async def test_context(context): context.kwargs.set("a", "a") + context.cost_manager.max_budget = 9 company = Team(context=context) save_to = context.repo.workdir / "serial" @@ -158,7 +159,8 @@ async def test_context(context): company.deserialize(save_to, Context()) assert company.env.context.repo assert company.env.context.repo.workdir == context.repo.workdir - assert context.kwargs.a == "a" + assert company.env.context.kwargs.a == "a" + assert company.env.context.cost_manager.max_budget == context.cost_manager.max_budget if __name__ == "__main__":