mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-05-05 05:42:37 +02:00
commit
ba9e3fe626
3 changed files with 57 additions and 8 deletions
|
|
@ -7,7 +7,7 @@
|
|||
"""
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
|
|
@ -78,12 +78,6 @@ class Context(BaseModel):
|
|||
# env.update({k: v for k, v in i.items() if isinstance(v, str)})
|
||||
return env
|
||||
|
||||
# def use_llm(self, name: Optional[str] = None, provider: LLMType = LLMType.OPENAI) -> BaseLLM:
|
||||
# """Use a LLM instance"""
|
||||
# self._llm_config = self.config.get_llm_config(name, provider)
|
||||
# self._llm = None
|
||||
# return self._llm
|
||||
|
||||
def _select_costmanager(self, llm_config: LLMConfig) -> CostManager:
|
||||
"""Return a CostManager instance"""
|
||||
if llm_config.api_type == LLMType.FIREWORKS:
|
||||
|
|
@ -108,3 +102,38 @@ class Context(BaseModel):
|
|||
if llm.cost_manager is None:
|
||||
llm.cost_manager = self._select_costmanager(llm_config)
|
||||
return llm
|
||||
|
||||
def serialize(self) -> Dict[str, Any]:
|
||||
"""Serialize the object's attributes into a dictionary.
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: A dictionary containing serialized data.
|
||||
"""
|
||||
return {
|
||||
"workdir": str(self.repo.workdir) if self.repo else "",
|
||||
"kwargs": {k: v for k, v in self.kwargs.__dict__.items()},
|
||||
"cost_manager": self.cost_manager.model_dump_json(),
|
||||
}
|
||||
|
||||
def deserialize(self, serialized_data: Dict[str, Any]):
|
||||
"""Deserialize the given serialized data and update the object's attributes accordingly.
|
||||
|
||||
Args:
|
||||
serialized_data (Dict[str, Any]): A dictionary containing serialized data.
|
||||
"""
|
||||
if not serialized_data:
|
||||
return
|
||||
workdir = serialized_data.get("workdir")
|
||||
if workdir:
|
||||
self.git_repo = GitRepository(local_path=workdir, auto_init=True)
|
||||
self.repo = ProjectRepo(self.git_repo)
|
||||
src_workspace = self.git_repo.workdir / self.git_repo.workdir.name
|
||||
if src_workspace.exists():
|
||||
self.src_workspace = src_workspace
|
||||
kwargs = serialized_data.get("kwargs")
|
||||
if kwargs:
|
||||
for k, v in kwargs.items():
|
||||
self.kwargs.set(k, v)
|
||||
cost_manager = serialized_data.get("cost_manager")
|
||||
if cost_manager:
|
||||
self.cost_manager.model_validate_json(cost_manager)
|
||||
|
|
|
|||
|
|
@ -56,8 +56,10 @@ class Team(BaseModel):
|
|||
def serialize(self, stg_path: Path = None):
|
||||
stg_path = SERDESER_PATH.joinpath("team") if stg_path is None else stg_path
|
||||
team_info_path = stg_path.joinpath("team.json")
|
||||
serialized_data = self.model_dump()
|
||||
serialized_data["context"] = self.env.context.serialize()
|
||||
|
||||
write_json_file(team_info_path, self.model_dump())
|
||||
write_json_file(team_info_path, serialized_data)
|
||||
|
||||
@classmethod
|
||||
def deserialize(cls, stg_path: Path, context: Context = None) -> "Team":
|
||||
|
|
@ -71,6 +73,7 @@ class Team(BaseModel):
|
|||
|
||||
team_info: dict = read_json_file(team_info_path)
|
||||
ctx = context or Context()
|
||||
ctx.deserialize(team_info.pop("context", None))
|
||||
team = Team(**team_info, context=ctx)
|
||||
return team
|
||||
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ from pathlib import Path
|
|||
|
||||
import pytest
|
||||
|
||||
from metagpt.context import Context
|
||||
from metagpt.logs import logger
|
||||
from metagpt.roles import Architect, ProductManager, ProjectManager
|
||||
from metagpt.team import Team
|
||||
|
|
@ -146,5 +147,21 @@ async def test_team_recover_multi_roles_save(mocker, context):
|
|||
await new_company.run(n_round=4)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_context(context):
|
||||
context.kwargs.set("a", "a")
|
||||
context.cost_manager.max_budget = 9
|
||||
company = Team(context=context)
|
||||
|
||||
save_to = context.repo.workdir / "serial"
|
||||
company.serialize(save_to)
|
||||
|
||||
company.deserialize(save_to, Context())
|
||||
assert company.env.context.repo
|
||||
assert company.env.context.repo.workdir == context.repo.workdir
|
||||
assert company.env.context.kwargs.a == "a"
|
||||
assert company.env.context.cost_manager.max_budget == context.cost_manager.max_budget
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-s"])
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue