From 949bc747f92c368f47bd73966e0eba205d4f7a40 Mon Sep 17 00:00:00 2001 From: better629 Date: Tue, 28 Nov 2023 09:29:00 +0800 Subject: [PATCH] add mg ser&deser --- metagpt/actions/action.py | 31 +++++++ metagpt/const.py | 2 + metagpt/environment.py | 38 +++++++++ metagpt/memory/memory.py | 30 +++++++ metagpt/roles/role.py | 117 ++++++++++++++++++++++++++- metagpt/schema.py | 44 +++++++++- metagpt/team.py | 26 ++++++ metagpt/utils/serialize.py | 62 ++++++++++++-- metagpt/utils/utils.py | 38 ++++++++- startup.py | 81 +++++++++++++++++++ tests/metagpt/actions/test_action.py | 17 ++++ tests/metagpt/memory/test_memory.py | 42 ++++++++++ tests/metagpt/roles/test_role.py | 85 +++++++++++++++++++ tests/metagpt/test_environment.py | 27 +++++-- tests/metagpt/test_schema.py | 42 ++++++++++ tests/metagpt/test_team.py | 27 +++++++ 16 files changed, 693 insertions(+), 16 deletions(-) create mode 100644 startup.py create mode 100644 tests/metagpt/memory/test_memory.py create mode 100644 tests/metagpt/roles/test_role.py create mode 100644 tests/metagpt/test_team.py diff --git a/metagpt/actions/action.py b/metagpt/actions/action.py index 1534b1f4d..3bfb69de4 100644 --- a/metagpt/actions/action.py +++ b/metagpt/actions/action.py @@ -17,6 +17,7 @@ from metagpt.logs import logger from metagpt.provider.postprecess.llm_output_postprecess import llm_output_postprecess from metagpt.utils.common import OutputParser from metagpt.utils.utils import general_after_log +from metagpt.utils.utils import import_class class Action(ABC): @@ -51,6 +52,36 @@ class Action(ABC): def __repr__(self): return self.__str__() + def serialize(self): + return { + "action_class": self.__class__.__name__, + "module_name": self.__module__, + "name": self.name + } + + @classmethod + def deserialize(cls, action_dict: dict): + action_class_str = action_dict.pop("action_class") + module_name = action_dict.pop("module_name") + action_class = import_class(action_class_str, module_name) + return action_class(**action_dict) + + @classmethod + def ser_class(cls): + """ serialize class type""" + return { + "action_class": cls.__name__, + "module_name": cls.__module__ + } + + @classmethod + def deser_class(cls, action_dict: dict): + """ deserialize class type """ + action_class_str = action_dict.pop("action_class") + module_name = action_dict.pop("module_name") + action_class = import_class(action_class_str, module_name) + return action_class + async def _aask(self, prompt: str, system_msgs: Optional[list[str]] = None) -> str: """Append default prefix""" if not system_msgs: diff --git a/metagpt/const.py b/metagpt/const.py index 10de0ff66..b46bc15a4 100644 --- a/metagpt/const.py +++ b/metagpt/const.py @@ -60,6 +60,8 @@ SWAGGER_PATH = UT_PATH / "files/api/" UT_PY_PATH = UT_PATH / "files/ut/" API_QUESTIONS_PATH = UT_PATH / "files/question/" +SERDES_PATH = DEFAULT_WORKSPACE_ROOT / "storage" # TODO to store `storage` under the individual generated project + TMP = METAGPT_ROOT / "tmp" SOURCE_ROOT = METAGPT_ROOT / "metagpt" diff --git a/metagpt/environment.py b/metagpt/environment.py index 89b6f9d46..14da6cd95 100644 --- a/metagpt/environment.py +++ b/metagpt/environment.py @@ -13,6 +13,7 @@ """ import asyncio from typing import Iterable, Set +from pathlib import Path from pydantic import BaseModel, Field @@ -20,6 +21,7 @@ from metagpt.logs import logger from metagpt.roles import Role from metagpt.schema import Message from metagpt.utils.common import is_subscribed +from metagpt.utils.utils import read_json_file, write_json_file class Environment(BaseModel): @@ -35,6 +37,42 @@ class Environment(BaseModel): class Config: arbitrary_types_allowed = True + def serialize(self, stg_path: Path): + roles_path = stg_path.joinpath("roles.json") + roles_info = [] + for role_key, role in self.roles.items(): + roles_info.append({ + "role_class": role.__class__.__name__, + "module_name": role.__module__, + "role_name": role.name + }) + role.serialize(stg_path=stg_path.joinpath(f"roles/{role.__class__.__name__}_{role.name}")) + write_json_file(roles_path, roles_info) + + self.memory.serialize(stg_path) + history_path = stg_path.joinpath("history.json") + write_json_file(history_path, {"content": self.history}) + + def deserialize(self, stg_path: Path): + """ stg_path: ./storage/team/environment/ """ + roles_path = stg_path.joinpath("roles.json") + roles_info = read_json_file(roles_path) + for role_info in roles_info: + role_class = role_info.get("role_class") + role_name = role_info.get("role_name") + + role_path = stg_path.joinpath(f"roles/{role_class}_{role_name}") + role = Role.deserialize(role_path) + + self.add_role(role) + + memory = Memory.deserialize(stg_path) + self.memory = memory + + history_path = stg_path.joinpath("history.json") + history = read_json_file(history_path) + self.history = history.get("content") + def add_role(self, role: Role): """增加一个在当前环境的角色 Add a role in the current environment diff --git a/metagpt/memory/memory.py b/metagpt/memory/memory.py index 53b65fcf7..43bd33e59 100644 --- a/metagpt/memory/memory.py +++ b/metagpt/memory/memory.py @@ -8,9 +8,12 @@ """ from collections import defaultdict from typing import Iterable, Set +from pathlib import Path from metagpt.schema import Message from metagpt.utils.common import any_to_str, any_to_str_set +from metagpt.utils.utils import read_json_file, write_json_file +from metagpt.utils.serialize import serialize_general_message, deserialize_general_message class Memory: @@ -21,6 +24,33 @@ class Memory: self.storage: list[Message] = [] self.index: dict[str, list[Message]] = defaultdict(list) + def serialize(self, stg_path: Path): + """ stg_path = ./storage/team/environment/ or ./storage/team/environment/roles/{role_class}_{role_name}/ """ + memory_path = stg_path.joinpath("memory.json") + + storage = [] + for message in self.storage: + # msg_dict = message.serialize() + msg_dict = serialize_general_message(message) + storage.append(msg_dict) + + write_json_file(memory_path, storage) + + @classmethod + def deserialize(cls, stg_path: Path) -> "Memory": + """ stg_path = ./storage/team/environment/ or ./storage/team/environment/roles/{role_class}_{role_name}/""" + memory_path = stg_path.joinpath("memory.json") + + memory = Memory() + memory_list = read_json_file(memory_path) + for message in memory_list: + # distinguish instruct_content type in message + # msg = Message.deserialize(message) + msg = deserialize_general_message(message) + memory.add(msg) + + return memory + def add(self, message: Message): """Add a new message to storage, while updating the index""" if message in self.storage: diff --git a/metagpt/roles/role.py b/metagpt/roles/role.py index 1e7ebf711..bb3b2acfe 100644 --- a/metagpt/roles/role.py +++ b/metagpt/roles/role.py @@ -22,7 +22,7 @@ from __future__ import annotations from enum import Enum from typing import Iterable, Set, Type - +from pathlib import Path from pydantic import BaseModel, Field from metagpt.actions import Action, ActionOutput @@ -30,10 +30,12 @@ from metagpt.actions.action_node import ActionNode from metagpt.actions.add_requirement import UserRequirement from metagpt.llm import LLM, HumanProvider from metagpt.logs import logger -from metagpt.memory import Memory from metagpt.schema import Message, MessageQueue from metagpt.utils.common import any_to_str from metagpt.utils.repair_llm_raw_output import extract_state_value_from_output +from metagpt.memory import Memory +from metagpt.utils.utils import read_json_file, write_json_file, import_class + PREFIX_TEMPLATE = """You are a {profile}, named {name}, your goal is {goal}, and the constraint is {constraints}. """ @@ -152,6 +154,87 @@ class Role(metaclass=_RoleInjector): self._rc = RoleContext() self._subscription = {any_to_str(self), name} if name else {any_to_str(self)} + self._recovered = False + + def serialize(self, stg_path: Path): + role_info_path = stg_path.joinpath("role_info.json") + role_info = { + "role_class": self.__class__.__name__, + "module_name": self.__module__ + } + setting = self._setting.dict() + setting.pop("desc") + setting.pop("is_human") # not all inherited roles have this atrr + role_info.update(setting) + write_json_file(role_info_path, role_info) + + actions_info_path = stg_path.joinpath("actions/actions_info.json") + actions_info = [] + for action in self._actions: + actions_info.append(action.serialize()) + write_json_file(actions_info_path, actions_info) + + watches_info_path = stg_path.joinpath("watches/watches_info.json") + watches_info = [] + for watch in self._rc.watch: + watches_info.append(watch.ser_class()) + write_json_file(watches_info_path, watches_info) + + actions_todo_path = stg_path.joinpath("actions/todo.json") + actions_todo = { + "cur_state": self._rc.state, + "react_mode": self._rc.react_mode.value, + "max_react_loop": self._rc.max_react_loop + } + write_json_file(actions_todo_path, actions_todo) + + self._rc.memory.serialize(stg_path) + + @classmethod + def deserialize(cls, stg_path: Path) -> "Role": + """ stg_path = ./storage/team/environment/roles/{role_class}_{role_name}""" + role_info_path = stg_path.joinpath("role_info.json") + role_info = read_json_file(role_info_path) + + role_class_str = role_info.pop("role_class") + module_name = role_info.pop("module_name") + role_class = import_class(class_name=role_class_str, module_name=module_name) + + role = role_class(**role_info) # initiate particular Role + actions_info_path = stg_path.joinpath("actions/actions_info.json") + actions = [] + actions_info = read_json_file(actions_info_path) + for action_info in actions_info: + action = Action.deserialize(action_info) + actions.append(action) + + watches_info_path = stg_path.joinpath("watches/watches_info.json") + watches = [] + watches_info = read_json_file(watches_info_path) + for watch_info in watches_info: + action = Action.deser_class(watch_info) + watches.append(action) + + role.init_actions(actions) + role.watch(watches) + + actions_todo_path = stg_path.joinpath("actions/todo.json") + # recover self._rc.state + actions_todo = read_json_file(actions_todo_path) + max_react_loop = actions_todo.get("max_react_loop", 1) + cur_state = actions_todo.get("cur_state", -1) + role.set_state(cur_state) + role.set_recovered(True) + react_mode_str = actions_todo.get("react_mode", RoleReactMode.REACT.value) + if react_mode_str not in RoleReactMode.values(): + logger.warning(f"ReactMode: {react_mode_str} not in {RoleReactMode.values()}, use react as default") + react_mode_str = RoleReactMode.REACT.value + role.set_react_mode(RoleReactMode(react_mode_str), max_react_loop) + + role_memory = Memory.deserialize(stg_path) + role.set_memory(role_memory) + + return role def _reset(self): self._states = [] @@ -160,6 +243,15 @@ class Role(metaclass=_RoleInjector): def _init_action_system_message(self, action: Action): action.set_prefix(self._get_prefix(), self.profile) + def set_recovered(self, recovered: bool = False): + self._recovered = recovered + + def set_memory(self, memory: Memory): + self._rc.memory = memory + + def init_actions(self, actions): + self._init_actions(actions) + def _init_actions(self, actions): self._reset() for idx, action in enumerate(actions): @@ -178,6 +270,9 @@ class Role(metaclass=_RoleInjector): self._actions.append(i) self._states.append(f"{idx}. {action}") + def set_react_mode(self, react_mode: RoleReactMode, max_react_loop: int = 1): + self._set_react_mode(react_mode, max_react_loop) + def _set_react_mode(self, react_mode: str, max_react_loop: int = 1): """Set strategy of the Role reacting to observed Message. Variation lies in how this Role elects action to perform during the _think stage, especially if it is capable of multiple Actions. @@ -199,6 +294,9 @@ class Role(metaclass=_RoleInjector): if react_mode == RoleReactMode.REACT: self._rc.max_react_loop = max_react_loop + def watch(self, actions: Iterable[Type[Action]]): + self._watch(actions) + def _watch(self, actions: Iterable[Type[Action]]): """Watch Actions of interest. Role will select Messages caused by these Actions from its personal message buffer during _observe. @@ -217,6 +315,9 @@ class Role(metaclass=_RoleInjector): if self._rc.env: # According to the routing feature plan in Chapter 2.2.3.2 of RFC 113 self._rc.env.set_subscription(self, self._subscription) + def set_state(self, state: int): + self._set_state(state) + def _set_state(self, state: int): """Update the current state.""" self._rc.state = state @@ -230,6 +331,10 @@ class Role(metaclass=_RoleInjector): if env: env.set_subscription(self, self._subscription) + @property + def name(self): + return self._setting.name + @property def profile(self): """Get the role description (position)""" @@ -257,6 +362,11 @@ class Role(metaclass=_RoleInjector): # If there is only one action, then only this one can be performed self._set_state(0) return + if self._recovered and self._rc.state >= 0: + self._set_state(self._rc.state) # action to run from recovered state + self._recovered = False # avoid max_react_loop out of work + return + prompt = self._get_prefix() prompt += STATE_TEMPLATE.format( history=self._rc.history, @@ -349,7 +459,8 @@ class Role(metaclass=_RoleInjector): async def _act_by_order(self) -> Message: """switch action each time by order defined in _init_actions, i.e. _act (Action1) -> _act (Action2) -> ...""" - for i in range(len(self._states)): + start_idx = self._rc.state if self._rc.state >= 0 else 0 # action to run from recovered state + for i in range(start_idx, len(self._states)): self._set_state(i) rsp = await self._act() return rsp # return output from the last action diff --git a/metagpt/schema.py b/metagpt/schema.py index 5aec378e4..78e4a6031 100644 --- a/metagpt/schema.py +++ b/metagpt/schema.py @@ -22,7 +22,6 @@ from asyncio import Queue, QueueEmpty, wait_for from json import JSONDecodeError from pathlib import Path from typing import Dict, List, Optional, Set, TypedDict - from pydantic import BaseModel, Field from metagpt.config import CONFIG @@ -36,6 +35,9 @@ from metagpt.const import ( ) from metagpt.logs import logger from metagpt.utils.common import any_to_str, any_to_str_set +# from metagpt.utils.serialize import actionoutout_schema_to_mapping +# from metagpt.actions.action_output import ActionOutput +# from metagpt.actions.action import Action class RawMessage(TypedDict): @@ -155,6 +157,46 @@ class Message(BaseModel): def __repr__(self): return self.__str__() + # def serialize(self): + # message_cp: Message = copy.deepcopy(self) + # ic = message_cp.instruct_content + # if ic: + # # model create by pydantic create_model like `pydantic.main.prd`, can't pickle.dump directly + # schema = ic.schema() + # mapping = actionoutout_schema_to_mapping(schema) + # + # message_cp.instruct_content = {"class": schema["title"], "mapping": mapping, "value": ic.dict()} + # cb = message_cp.cause_by + # if cb: + # message_cp.cause_by = cb.serialize() + # + # return message_cp.dict() + # + # @classmethod + # def deserialize(cls, message_dict: dict): + # instruct_content = message_dict.get("instruct_content") + # if instruct_content: + # ic = instruct_content + # ic_obj = ActionOutput.create_model_class(class_name=ic["class"], mapping=ic["mapping"]) + # ic_new = ic_obj(**ic["value"]) + # message_dict.instruct_content = ic_new + # cause_by = message_dict.get("cause_by") + # if cause_by: + # message_dict.cause_by = Action.deserialize(cause_by) + # + # return Message(**message_dict) + + def dict(self): + return { + "content": self.content, + "instruct_content": self.instruct_content, + "role": self.role, + "cause_by": self.cause_by, + "sent_from": self.sent_from, + "send_to": self.send_to, + "restricted_to": self.restricted_to + } + def to_dict(self) -> dict: """Return a dict containing `role` and `content` for the LLM call.l""" return {"role": self.role, "content": self.content} diff --git a/metagpt/team.py b/metagpt/team.py index a5c405f80..02c48a138 100644 --- a/metagpt/team.py +++ b/metagpt/team.py @@ -7,6 +7,7 @@ @Modified By: mashenquan, 2023/11/27. Add an archiving operation after completing the project, as specified in Section 2.2.3.3 of RFC 135. """ +from pathlib import Path from pydantic import BaseModel, Field from metagpt.actions import UserRequirement @@ -17,6 +18,7 @@ from metagpt.logs import logger from metagpt.roles import Role from metagpt.schema import Message from metagpt.utils.common import NoMoneyException +from metagpt.utils.utils import read_json_file, write_json_file class Team(BaseModel): @@ -32,6 +34,30 @@ class Team(BaseModel): class Config: arbitrary_types_allowed = True + def serialize(self, stg_path: Path): + team_info_path = stg_path.joinpath("team_info.json") + write_json_file(team_info_path, { + "idea": self.idea, + "investment": self.investment + }) + + self.environment.serialize(stg_path.joinpath("environment")) + + def deserialize(self, stg_path: Path): + """ stg_path = ./storage/team """ + # recover team_info + team_info_path = stg_path.joinpath("team_info.json") + if not team_info_path.exists(): + logger.error("recover storage not exist, not to recover and continue run the old project.") + team_info = read_json_file(team_info_path) + self.investment = team_info.get("investment", 10.0) + self.idea = team_info.get("idea", "") + + # recover environment + environment_path = stg_path.joinpath("environment") + self.environment = Environment() + self.environment.deserialize(stg_path=environment_path) + def hire(self, roles: list[Role]): """Hire roles to cooperate""" self.env.add_roles(roles) diff --git a/metagpt/utils/serialize.py b/metagpt/utils/serialize.py index 124176fcb..56a866f2e 100644 --- a/metagpt/utils/serialize.py +++ b/metagpt/utils/serialize.py @@ -4,13 +4,13 @@ import copy import pickle -from typing import Dict, List from metagpt.actions.action_output import ActionOutput from metagpt.schema import Message +from metagpt.actions.action import Action -def actionoutout_schema_to_mapping(schema: Dict) -> Dict: +def actionoutout_schema_to_mapping(schema: dict) -> dict: """ directly traverse the `properties` in the first level. schema structure likes @@ -35,13 +35,47 @@ def actionoutout_schema_to_mapping(schema: Dict) -> Dict: if property["type"] == "string": mapping[field] = (str, ...) elif property["type"] == "array" and property["items"]["type"] == "string": - mapping[field] = (List[str], ...) + mapping[field] = (list[str], ...) elif property["type"] == "array" and property["items"]["type"] == "array": - # here only consider the `List[List[str]]` situation - mapping[field] = (List[List[str]], ...) + # here only consider the `list[list[str]]` situation + mapping[field] = (list[list[str]], ...) return mapping +def actionoutput_mapping_to_str(mapping: dict) -> dict: + new_mapping = {} + for key, value in mapping.items(): + new_mapping[key] = str(value) + return new_mapping + + +def actionoutput_str_to_mapping(mapping: dict) -> dict: + new_mapping = {} + for key, value in mapping.items(): + if value == "(, Ellipsis)": + new_mapping[key] = (str, ...) + else: + new_mapping[key] = eval(value) # `"'(list[str], Ellipsis)"` to `(list[str], ...)` + return new_mapping + + +def serialize_general_message(message: Message) -> dict: + """ serialize Message, not to save""" + message_cp = copy.deepcopy(message) + ic = message_cp.instruct_content + if ic: + # model create by pydantic create_model like `pydantic.main.prd`, can't pickle.dump directly + schema = ic.schema() + mapping = actionoutout_schema_to_mapping(schema) + mapping = actionoutput_mapping_to_str(mapping) + + message_cp.instruct_content = {"class": schema["title"], "mapping": mapping, "value": ic.dict()} + cb = message_cp.cause_by + if cb: + message_cp.cause_by = cb.ser_class() + return message_cp.dict() + + def serialize_message(message: Message): message_cp = copy.deepcopy(message) # avoid `instruct_content` value update by reference ic = message_cp.instruct_content @@ -56,6 +90,24 @@ def serialize_message(message: Message): return msg_ser +def deserialize_general_message(message_dict: dict) -> Message: + """ deserialize Message, not to load""" + instruct_content = message_dict.pop("instruct_content") + cause_by = message_dict.pop("cause_by") + + message = Message(**message_dict) + if instruct_content: + ic = instruct_content + mapping = actionoutput_str_to_mapping(ic["mapping"]) + ic_obj = ActionOutput.create_model_class(class_name=ic["class"], mapping=mapping) + ic_new = ic_obj(**ic["value"]) + message.instruct_content = ic_new + if cause_by: + message.cause_by = Action.deser_class(cause_by) + + return message + + def deserialize_message(message_ser: str) -> Message: message = pickle.loads(message_ser) if message.instruct_content: diff --git a/metagpt/utils/utils.py b/metagpt/utils/utils.py index 5ceed65d9..220e228c3 100644 --- a/metagpt/utils/utils.py +++ b/metagpt/utils/utils.py @@ -3,7 +3,10 @@ # @Desc : import typing - +from typing import Any +import json +from pathlib import Path +import importlib from tenacity import _utils @@ -20,3 +23,36 @@ def general_after_log(logger: "loguru.Logger", sec_format: str = "%0.3f") -> typ ) return log_it + + +def read_json_file(json_file: str, encoding=None) -> list[Any]: + if not Path(json_file).exists(): + raise FileNotFoundError(f"json_file: {json_file} not exist, return []") + + with open(json_file, "r", encoding=encoding) as fin: + try: + data = json.load(fin) + except Exception as exp: + raise ValueError(f"read json file: {json_file} failed") + return data + + +def write_json_file(json_file: str, data: list, encoding=None): + folder_path = Path(json_file).parent + if not folder_path.exists(): + folder_path.mkdir(parents=True, exist_ok=True) + + with open(json_file, "w", encoding=encoding) as fout: + json.dump(data, fout, ensure_ascii=False, indent=4) + + +def import_class(class_name: str, module_name: str) -> type: + module = importlib.import_module(module_name) + a_class = getattr(module, class_name) + return a_class + + +def import_class_inst(class_name: str, module_name: str, *args, **kwargs) -> object: + a_class = import_class(class_name, module_name) + class_inst = a_class(*args, **kwargs) + return class_inst diff --git a/startup.py b/startup.py new file mode 100644 index 000000000..9f753d553 --- /dev/null +++ b/startup.py @@ -0,0 +1,81 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +import asyncio + +import fire + +from metagpt.const import SERDES_PATH +from metagpt.roles import ( + Architect, + Engineer, + ProductManager, + ProjectManager, + QaEngineer, +) +from metagpt.team import Team + + +async def startup( + idea: str, + investment: float = 3.0, + n_round: int = 5, + code_review: bool = False, + run_tests: bool = False, + implement: bool = True, + recover_path: bool = False, +): + """Run a startup. Be a boss.""" + company = Team() + if not recover_path: + company.hire( + [ + ProductManager(), + Architect(), + ProjectManager(), + ] + ) + + # if implement or code_review + if implement or code_review: + # developing features: implement the idea + company.hire([Engineer(n_borg=5, use_code_review=code_review)]) + + if run_tests: + # developing features: run tests on the spot and identify bugs + # (bug fixing capability comes soon!) + company.hire([QaEngineer()]) + else: + stg_path = SERDES_PATH.joinpath("team") + company.deserialize(stg_path=stg_path) + idea = company.idea # use original idea + + company.invest(investment) + company.start_project(idea) + await company.run(n_round=n_round) + + +def main( + idea: str, + investment: float = 3.0, + n_round: int = 5, + code_review: bool = True, + run_tests: bool = False, + implement: bool = True, + recover_path: str = None, +): + """ + We are a software startup comprised of AI. By investing in us, + you are empowering a future filled with limitless possibilities. + :param idea: Your innovative idea, such as "Creating a snake game." + :param investment: As an investor, you have the opportunity to contribute + a certain dollar amount to this AI company. + :param n_round: + :param code_review: Whether to use code review. + :param recover_path: recover the project from existing serialized storage + :return: + """ + asyncio.run(startup(idea, investment, n_round, code_review, run_tests, implement, recover_path)) + + +if __name__ == "__main__": + fire.Fire(main) diff --git a/tests/metagpt/actions/test_action.py b/tests/metagpt/actions/test_action.py index 9775630cc..4468a6f6f 100644 --- a/tests/metagpt/actions/test_action.py +++ b/tests/metagpt/actions/test_action.py @@ -11,3 +11,20 @@ from metagpt.actions import Action, WritePRD, WriteTest def test_action_repr(): actions = [Action(), WriteTest(), WritePRD()] assert "WriteTest" in str(actions) + + +def test_action_serdes(): + action_info = WriteTest.ser_class() + assert action_info["action_class"] == "WriteTest" + + action_class = Action.deser_class(action_info) + assert action_class == WriteTest + + +def test_action_class_serdes(): + name = "write test" + action_info = WriteTest(name=name).serialize() + assert action_info["name"] == name + + action = Action.deserialize(action_info) + assert action.name == name diff --git a/tests/metagpt/memory/test_memory.py b/tests/metagpt/memory/test_memory.py new file mode 100644 index 000000000..bda79ded1 --- /dev/null +++ b/tests/metagpt/memory/test_memory.py @@ -0,0 +1,42 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : unittest of memory + +from pathlib import Path + +from metagpt.schema import Message +from metagpt.memory.memory import Memory +from metagpt.actions.action_output import ActionOutput +from metagpt.actions.design_api import WriteDesign +from metagpt.actions.add_requirement import BossRequirement + +serdes_path = Path(__file__).absolute().parent.joinpath("../../data/serdes_storage") + + +def test_memory_serdes(): + msg1 = Message(role="User", + content="write a 2048 game", + cause_by=BossRequirement) + + out_mapping = {"field1": (list[str], ...)} + out_data = {"field1": ["field1 value1", "field1 value2"]} + ic_obj = ActionOutput.create_model_class("system_design", out_mapping) + msg2 = Message(role="Architect", + instruct_content=ic_obj(**out_data), + content="system design content", + cause_by=WriteDesign) + + memory = Memory() + memory.add_batch([msg1, msg2]) + + stg_path = serdes_path.joinpath("team/environment") + memory.serialize(stg_path) + assert stg_path.joinpath("memory.json").exists() + + new_memory = Memory.deserialize(stg_path) + assert new_memory.count() == 2 + new_msg2 = new_memory.get(1)[0] + assert new_msg2.instruct_content.field1 == ["field1 value1", "field1 value2"] + assert new_msg2.cause_by == WriteDesign + + stg_path.joinpath("memory.json").unlink() diff --git a/tests/metagpt/roles/test_role.py b/tests/metagpt/roles/test_role.py new file mode 100644 index 000000000..a19ad9cb5 --- /dev/null +++ b/tests/metagpt/roles/test_role.py @@ -0,0 +1,85 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : unittest of Role + +from pathlib import Path +import shutil +import pytest + +from metagpt.roles.role import Role, RoleReactMode +from metagpt.actions.action import Action +from metagpt.schema import Message +from metagpt.actions.add_requirement import BossRequirement +from metagpt.roles.product_manager import ProductManager + +serdes_path = Path(__file__).absolute().parent.joinpath("../../data/serdes_storage") + + +def test_role_serdes(): + stg_path_prefix = serdes_path.joinpath("team/environment/roles/") + shutil.rmtree(serdes_path.joinpath("team"), ignore_errors=True) + + pm = ProductManager() + role_tag = f"{pm.__class__.__name__}_{pm.name}" + stg_path = stg_path_prefix.joinpath(role_tag) + pm.serialize(stg_path) + assert stg_path.joinpath("actions/actions_info.json").exists() + + new_pm = Role.deserialize(stg_path) + assert new_pm.name == pm.name + assert len(new_pm.get_memories(1)) == 0 + + +class ActionOK(Action): + + async def run(self, messages: list["Message"]): + return "ok" + + +class ActionRaise(Action): + + async def run(self, messages: list["Message"]): + raise RuntimeError("parse error") + + +class RoleA(Role): + + def __init__(self, + name: str = "RoleA", + profile: str = "Role A", + goal: str = "", + constraints: str = ""): + super(RoleA, self).__init__(name=name, profile=profile, goal=goal, constraints=constraints) + self._init_actions([ActionOK, ActionRaise]) + self._watch([BossRequirement]) + self._rc.react_mode = RoleReactMode.BY_ORDER + + async def run(self, message: "Message" = None, stg_path: str = None): + try: + await super(RoleA, self).run(message) + except Exception as exp: + print("exp ", exp) + self.serialize(stg_path) + + +@pytest.mark.asyncio +async def test_role_serdes_interrupt(): + role_a = RoleA() + shutil.rmtree(serdes_path.joinpath("team"), ignore_errors=True) + + stg_path = serdes_path.joinpath(f"team/environment/roles/{role_a.__class__.__name__}_{role_a.name}") + await role_a.run( + message=Message(content="demo", cause_by=BossRequirement), + stg_path=stg_path + ) + assert role_a._rc.memory.count() == 2 + + assert stg_path.joinpath("actions/todo.json").exists() + + new_role_a: Role = Role.deserialize(stg_path) + assert new_role_a._rc.state == 1 + await role_a.run( + message=Message(content="demo", cause_by=BossRequirement), + stg_path=stg_path + ) + diff --git a/tests/metagpt/test_environment.py b/tests/metagpt/test_environment.py index b27bc3da7..03236a08b 100644 --- a/tests/metagpt/test_environment.py +++ b/tests/metagpt/test_environment.py @@ -7,6 +7,8 @@ """ import pytest +from pathlib import Path +import shutil from metagpt.actions import UserRequirement from metagpt.environment import Environment @@ -14,6 +16,10 @@ from metagpt.logs import logger from metagpt.manager import Manager from metagpt.roles import Architect, ProductManager, Role from metagpt.schema import Message +from tests.metagpt.roles.test_role import RoleA + + +serdes_path = Path(__file__).absolute().parent.joinpath("../data/serdes_storage") @pytest.fixture @@ -36,12 +42,6 @@ def test_get_roles(env: Environment): assert roles == {role1.profile: role1, role2.profile: role2} -def test_set_manager(env: Environment): - manager = Manager() - env.set_manager(manager) - assert env.manager == manager - - @pytest.mark.asyncio async def test_publish_and_process_message(env: Environment): product_manager = ProductManager("Alice", "Product Manager", "做AI Native产品", "资源有限") @@ -54,3 +54,18 @@ async def test_publish_and_process_message(env: Environment): await env.run(k=2) logger.info(f"{env.history=}") assert len(env.history) > 10 + + +def test_environment_serdes(): + environment = Environment() + role_a = RoleA() + + shutil.rmtree(serdes_path.joinpath("team"), ignore_errors=True) + + stg_path = serdes_path.joinpath("team/environment") + environment.add_role(role_a) + environment.serialize(stg_path) + + new_env: Environment = Environment() + new_env.deserialize(stg_path) + assert len(new_env.roles) == 1 diff --git a/tests/metagpt/test_schema.py b/tests/metagpt/test_schema.py index 51ebd5baa..4a6f518b1 100644 --- a/tests/metagpt/test_schema.py +++ b/tests/metagpt/test_schema.py @@ -7,12 +7,16 @@ @Modified By: mashenquan, 2023-11-1. In line with Chapter 2.2.1 and 2.2.2 of RFC 116, introduce unit tests for the utilization of the new feature of `Message` class. """ + import json import pytest from metagpt.actions import Action from metagpt.schema import AIMessage, Message, SystemMessage, UserMessage +from metagpt.actions.action_output import ActionOutput +from metagpt.actions.write_code import WriteCode +from metagpt.utils.serialize import serialize_general_message, deserialize_general_message from metagpt.utils.common import get_class_name @@ -70,5 +74,43 @@ def test_routes(): assert m.send_to == {"e", get_class_name(Action)} +def test_message_serdes(): + out_mapping = {"field3": (str, ...), "field4": (list[str], ...)} + out_data = {"field3": "field3 value3", "field4": ["field4 value1", "field4 value2"]} + ic_obj = ActionOutput.create_model_class("code", out_mapping) + + message = Message( + content="code", + instruct_content=ic_obj(**out_data), + role="engineer", + cause_by=WriteCode + ) + message_dict = serialize_general_message(message) + assert message_dict["cause_by"] == {"action_class": "WriteCode"} + assert message_dict["instruct_content"] == { + "class": "code", + "mapping": { + "field3": "(, Ellipsis)", + "field4": "(list[str], Ellipsis)" + }, + "value": { + "field3": "field3 value3", + "field4": ["field4 value1", "field4 value2"] + } + } + + new_message = deserialize_general_message(message_dict) + assert new_message.content == message.content + assert new_message.instruct_content == message.instruct_content + assert new_message.cause_by == message.cause_by + assert new_message.instruct_content.field3 == out_data["field3"] + + message = Message(content="code") + message_dict = serialize_general_message(message) + new_message = deserialize_general_message(message_dict) + assert new_message.instruct_content is None + assert new_message.cause_by == "" + + if __name__ == "__main__": pytest.main([__file__, "-s"]) diff --git a/tests/metagpt/test_team.py b/tests/metagpt/test_team.py new file mode 100644 index 000000000..ab201152c --- /dev/null +++ b/tests/metagpt/test_team.py @@ -0,0 +1,27 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : unittest of team + +from pathlib import Path +import shutil + +from metagpt.team import Team + +from tests.metagpt.roles.test_role import RoleA + +serdes_path = Path(__file__).absolute().parent.joinpath("../data/serdes_storage") + + +def test_team_serdes(): + company = Team() + company.hire([RoleA()]) + + stg_path = serdes_path.joinpath("team") + shutil.rmtree(stg_path, ignore_errors=True) + + company.serialize(stg_path=stg_path) + + new_company = Team() + new_company.deserialize(stg_path) + + assert len(new_company.environment.roles) == 1