feat: add cost_manager

This commit is contained in:
莘权 马 2024-03-25 23:23:43 +08:00
parent 630eecd18b
commit 6c2d413664
2 changed files with 7 additions and 1 deletions

View file

@ -112,6 +112,7 @@ class Context(BaseModel):
return {
"workdir": str(self.repo.workdir) if self.repo else "",
"kwargs": {k: v for k, v in self.kwargs.__dict__.items()},
"cost_manager": self.cost_manager.model_dump_json(),
}
def deserialize(self, serialized_data: Dict[str, Any]):
@ -133,3 +134,6 @@ class Context(BaseModel):
if kwargs:
for k, v in kwargs.items():
self.kwargs.set(k, v)
cost_manager = serialized_data.get("cost_manager")
if cost_manager:
self.cost_manager.model_validate_json(cost_manager)

View file

@ -150,6 +150,7 @@ async def test_team_recover_multi_roles_save(mocker, context):
@pytest.mark.asyncio
async def test_context(context):
context.kwargs.set("a", "a")
context.cost_manager.max_budget = 9
company = Team(context=context)
save_to = context.repo.workdir / "serial"
@ -158,7 +159,8 @@ async def test_context(context):
company.deserialize(save_to, Context())
assert company.env.context.repo
assert company.env.context.repo.workdir == context.repo.workdir
assert context.kwargs.a == "a"
assert company.env.context.kwargs.a == "a"
assert company.env.context.cost_manager.max_budget == context.cost_manager.max_budget
if __name__ == "__main__":