From 6c2d4136649f476ebc58d1cea4f1faffe1faa109 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=8E=98=E6=9D=83=20=E9=A9=AC?= Date: Mon, 25 Mar 2024 23:23:43 +0800 Subject: [PATCH] feat: add cost_manager --- metagpt/context.py | 4 ++++ tests/metagpt/serialize_deserialize/test_team.py | 4 +++- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/metagpt/context.py b/metagpt/context.py index f1199c492..2bd541202 100644 --- a/metagpt/context.py +++ b/metagpt/context.py @@ -112,6 +112,7 @@ class Context(BaseModel): return { "workdir": str(self.repo.workdir) if self.repo else "", "kwargs": {k: v for k, v in self.kwargs.__dict__.items()}, + "cost_manager": self.cost_manager.model_dump_json(), } def deserialize(self, serialized_data: Dict[str, Any]): @@ -133,3 +134,6 @@ class Context(BaseModel): if kwargs: for k, v in kwargs.items(): self.kwargs.set(k, v) + cost_manager = serialized_data.get("cost_manager") + if cost_manager: + self.cost_manager.model_validate_json(cost_manager) diff --git a/tests/metagpt/serialize_deserialize/test_team.py b/tests/metagpt/serialize_deserialize/test_team.py index a300c78c1..6312e1fde 100644 --- a/tests/metagpt/serialize_deserialize/test_team.py +++ b/tests/metagpt/serialize_deserialize/test_team.py @@ -150,6 +150,7 @@ async def test_team_recover_multi_roles_save(mocker, context): @pytest.mark.asyncio async def test_context(context): context.kwargs.set("a", "a") + context.cost_manager.max_budget = 9 company = Team(context=context) save_to = context.repo.workdir / "serial" @@ -158,7 +159,8 @@ async def test_context(context): company.deserialize(save_to, Context()) assert company.env.context.repo assert company.env.context.repo.workdir == context.repo.workdir - assert context.kwargs.a == "a" + assert company.env.context.kwargs.a == "a" + assert company.env.context.cost_manager.max_budget == context.cost_manager.max_budget if __name__ == "__main__":