add mg ser&deser

This commit is contained in:
better629 2023-11-28 09:29:00 +08:00
parent 4db99b825a
commit 949bc747f9
16 changed files with 693 additions and 16 deletions

View file

@ -17,6 +17,7 @@ from metagpt.logs import logger
from metagpt.provider.postprecess.llm_output_postprecess import llm_output_postprecess
from metagpt.utils.common import OutputParser
from metagpt.utils.utils import general_after_log
from metagpt.utils.utils import import_class
class Action(ABC):
@ -51,6 +52,36 @@ class Action(ABC):
def __repr__(self):
return self.__str__()
def serialize(self):
return {
"action_class": self.__class__.__name__,
"module_name": self.__module__,
"name": self.name
}
@classmethod
def deserialize(cls, action_dict: dict):
action_class_str = action_dict.pop("action_class")
module_name = action_dict.pop("module_name")
action_class = import_class(action_class_str, module_name)
return action_class(**action_dict)
@classmethod
def ser_class(cls):
""" serialize class type"""
return {
"action_class": cls.__name__,
"module_name": cls.__module__
}
@classmethod
def deser_class(cls, action_dict: dict):
""" deserialize class type """
action_class_str = action_dict.pop("action_class")
module_name = action_dict.pop("module_name")
action_class = import_class(action_class_str, module_name)
return action_class
async def _aask(self, prompt: str, system_msgs: Optional[list[str]] = None) -> str:
"""Append default prefix"""
if not system_msgs:

View file

@ -60,6 +60,8 @@ SWAGGER_PATH = UT_PATH / "files/api/"
UT_PY_PATH = UT_PATH / "files/ut/"
API_QUESTIONS_PATH = UT_PATH / "files/question/"
SERDES_PATH = DEFAULT_WORKSPACE_ROOT / "storage" # TODO to store `storage` under the individual generated project
TMP = METAGPT_ROOT / "tmp"
SOURCE_ROOT = METAGPT_ROOT / "metagpt"

View file

@ -13,6 +13,7 @@
"""
import asyncio
from typing import Iterable, Set
from pathlib import Path
from pydantic import BaseModel, Field
@ -20,6 +21,7 @@ from metagpt.logs import logger
from metagpt.roles import Role
from metagpt.schema import Message
from metagpt.utils.common import is_subscribed
from metagpt.utils.utils import read_json_file, write_json_file
class Environment(BaseModel):
@ -35,6 +37,42 @@ class Environment(BaseModel):
class Config:
arbitrary_types_allowed = True
def serialize(self, stg_path: Path):
roles_path = stg_path.joinpath("roles.json")
roles_info = []
for role_key, role in self.roles.items():
roles_info.append({
"role_class": role.__class__.__name__,
"module_name": role.__module__,
"role_name": role.name
})
role.serialize(stg_path=stg_path.joinpath(f"roles/{role.__class__.__name__}_{role.name}"))
write_json_file(roles_path, roles_info)
self.memory.serialize(stg_path)
history_path = stg_path.joinpath("history.json")
write_json_file(history_path, {"content": self.history})
def deserialize(self, stg_path: Path):
""" stg_path: ./storage/team/environment/ """
roles_path = stg_path.joinpath("roles.json")
roles_info = read_json_file(roles_path)
for role_info in roles_info:
role_class = role_info.get("role_class")
role_name = role_info.get("role_name")
role_path = stg_path.joinpath(f"roles/{role_class}_{role_name}")
role = Role.deserialize(role_path)
self.add_role(role)
memory = Memory.deserialize(stg_path)
self.memory = memory
history_path = stg_path.joinpath("history.json")
history = read_json_file(history_path)
self.history = history.get("content")
def add_role(self, role: Role):
"""增加一个在当前环境的角色
Add a role in the current environment

View file

@ -8,9 +8,12 @@
"""
from collections import defaultdict
from typing import Iterable, Set
from pathlib import Path
from metagpt.schema import Message
from metagpt.utils.common import any_to_str, any_to_str_set
from metagpt.utils.utils import read_json_file, write_json_file
from metagpt.utils.serialize import serialize_general_message, deserialize_general_message
class Memory:
@ -21,6 +24,33 @@ class Memory:
self.storage: list[Message] = []
self.index: dict[str, list[Message]] = defaultdict(list)
def serialize(self, stg_path: Path):
""" stg_path = ./storage/team/environment/ or ./storage/team/environment/roles/{role_class}_{role_name}/ """
memory_path = stg_path.joinpath("memory.json")
storage = []
for message in self.storage:
# msg_dict = message.serialize()
msg_dict = serialize_general_message(message)
storage.append(msg_dict)
write_json_file(memory_path, storage)
@classmethod
def deserialize(cls, stg_path: Path) -> "Memory":
""" stg_path = ./storage/team/environment/ or ./storage/team/environment/roles/{role_class}_{role_name}/"""
memory_path = stg_path.joinpath("memory.json")
memory = Memory()
memory_list = read_json_file(memory_path)
for message in memory_list:
# distinguish instruct_content type in message
# msg = Message.deserialize(message)
msg = deserialize_general_message(message)
memory.add(msg)
return memory
def add(self, message: Message):
"""Add a new message to storage, while updating the index"""
if message in self.storage:

View file

@ -22,7 +22,7 @@ from __future__ import annotations
from enum import Enum
from typing import Iterable, Set, Type
from pathlib import Path
from pydantic import BaseModel, Field
from metagpt.actions import Action, ActionOutput
@ -30,10 +30,12 @@ from metagpt.actions.action_node import ActionNode
from metagpt.actions.add_requirement import UserRequirement
from metagpt.llm import LLM, HumanProvider
from metagpt.logs import logger
from metagpt.memory import Memory
from metagpt.schema import Message, MessageQueue
from metagpt.utils.common import any_to_str
from metagpt.utils.repair_llm_raw_output import extract_state_value_from_output
from metagpt.memory import Memory
from metagpt.utils.utils import read_json_file, write_json_file, import_class
PREFIX_TEMPLATE = """You are a {profile}, named {name}, your goal is {goal}, and the constraint is {constraints}. """
@ -152,6 +154,87 @@ class Role(metaclass=_RoleInjector):
self._rc = RoleContext()
self._subscription = {any_to_str(self), name} if name else {any_to_str(self)}
self._recovered = False
def serialize(self, stg_path: Path):
role_info_path = stg_path.joinpath("role_info.json")
role_info = {
"role_class": self.__class__.__name__,
"module_name": self.__module__
}
setting = self._setting.dict()
setting.pop("desc")
setting.pop("is_human") # not all inherited roles have this atrr
role_info.update(setting)
write_json_file(role_info_path, role_info)
actions_info_path = stg_path.joinpath("actions/actions_info.json")
actions_info = []
for action in self._actions:
actions_info.append(action.serialize())
write_json_file(actions_info_path, actions_info)
watches_info_path = stg_path.joinpath("watches/watches_info.json")
watches_info = []
for watch in self._rc.watch:
watches_info.append(watch.ser_class())
write_json_file(watches_info_path, watches_info)
actions_todo_path = stg_path.joinpath("actions/todo.json")
actions_todo = {
"cur_state": self._rc.state,
"react_mode": self._rc.react_mode.value,
"max_react_loop": self._rc.max_react_loop
}
write_json_file(actions_todo_path, actions_todo)
self._rc.memory.serialize(stg_path)
@classmethod
def deserialize(cls, stg_path: Path) -> "Role":
""" stg_path = ./storage/team/environment/roles/{role_class}_{role_name}"""
role_info_path = stg_path.joinpath("role_info.json")
role_info = read_json_file(role_info_path)
role_class_str = role_info.pop("role_class")
module_name = role_info.pop("module_name")
role_class = import_class(class_name=role_class_str, module_name=module_name)
role = role_class(**role_info) # initiate particular Role
actions_info_path = stg_path.joinpath("actions/actions_info.json")
actions = []
actions_info = read_json_file(actions_info_path)
for action_info in actions_info:
action = Action.deserialize(action_info)
actions.append(action)
watches_info_path = stg_path.joinpath("watches/watches_info.json")
watches = []
watches_info = read_json_file(watches_info_path)
for watch_info in watches_info:
action = Action.deser_class(watch_info)
watches.append(action)
role.init_actions(actions)
role.watch(watches)
actions_todo_path = stg_path.joinpath("actions/todo.json")
# recover self._rc.state
actions_todo = read_json_file(actions_todo_path)
max_react_loop = actions_todo.get("max_react_loop", 1)
cur_state = actions_todo.get("cur_state", -1)
role.set_state(cur_state)
role.set_recovered(True)
react_mode_str = actions_todo.get("react_mode", RoleReactMode.REACT.value)
if react_mode_str not in RoleReactMode.values():
logger.warning(f"ReactMode: {react_mode_str} not in {RoleReactMode.values()}, use react as default")
react_mode_str = RoleReactMode.REACT.value
role.set_react_mode(RoleReactMode(react_mode_str), max_react_loop)
role_memory = Memory.deserialize(stg_path)
role.set_memory(role_memory)
return role
def _reset(self):
self._states = []
@ -160,6 +243,15 @@ class Role(metaclass=_RoleInjector):
def _init_action_system_message(self, action: Action):
action.set_prefix(self._get_prefix(), self.profile)
def set_recovered(self, recovered: bool = False):
self._recovered = recovered
def set_memory(self, memory: Memory):
self._rc.memory = memory
def init_actions(self, actions):
self._init_actions(actions)
def _init_actions(self, actions):
self._reset()
for idx, action in enumerate(actions):
@ -178,6 +270,9 @@ class Role(metaclass=_RoleInjector):
self._actions.append(i)
self._states.append(f"{idx}. {action}")
def set_react_mode(self, react_mode: RoleReactMode, max_react_loop: int = 1):
self._set_react_mode(react_mode, max_react_loop)
def _set_react_mode(self, react_mode: str, max_react_loop: int = 1):
"""Set strategy of the Role reacting to observed Message. Variation lies in how
this Role elects action to perform during the _think stage, especially if it is capable of multiple Actions.
@ -199,6 +294,9 @@ class Role(metaclass=_RoleInjector):
if react_mode == RoleReactMode.REACT:
self._rc.max_react_loop = max_react_loop
def watch(self, actions: Iterable[Type[Action]]):
self._watch(actions)
def _watch(self, actions: Iterable[Type[Action]]):
"""Watch Actions of interest. Role will select Messages caused by these Actions from its personal message
buffer during _observe.
@ -217,6 +315,9 @@ class Role(metaclass=_RoleInjector):
if self._rc.env: # According to the routing feature plan in Chapter 2.2.3.2 of RFC 113
self._rc.env.set_subscription(self, self._subscription)
def set_state(self, state: int):
self._set_state(state)
def _set_state(self, state: int):
"""Update the current state."""
self._rc.state = state
@ -230,6 +331,10 @@ class Role(metaclass=_RoleInjector):
if env:
env.set_subscription(self, self._subscription)
@property
def name(self):
return self._setting.name
@property
def profile(self):
"""Get the role description (position)"""
@ -257,6 +362,11 @@ class Role(metaclass=_RoleInjector):
# If there is only one action, then only this one can be performed
self._set_state(0)
return
if self._recovered and self._rc.state >= 0:
self._set_state(self._rc.state) # action to run from recovered state
self._recovered = False # avoid max_react_loop out of work
return
prompt = self._get_prefix()
prompt += STATE_TEMPLATE.format(
history=self._rc.history,
@ -349,7 +459,8 @@ class Role(metaclass=_RoleInjector):
async def _act_by_order(self) -> Message:
"""switch action each time by order defined in _init_actions, i.e. _act (Action1) -> _act (Action2) -> ..."""
for i in range(len(self._states)):
start_idx = self._rc.state if self._rc.state >= 0 else 0 # action to run from recovered state
for i in range(start_idx, len(self._states)):
self._set_state(i)
rsp = await self._act()
return rsp # return output from the last action

View file

@ -22,7 +22,6 @@ from asyncio import Queue, QueueEmpty, wait_for
from json import JSONDecodeError
from pathlib import Path
from typing import Dict, List, Optional, Set, TypedDict
from pydantic import BaseModel, Field
from metagpt.config import CONFIG
@ -36,6 +35,9 @@ from metagpt.const import (
)
from metagpt.logs import logger
from metagpt.utils.common import any_to_str, any_to_str_set
# from metagpt.utils.serialize import actionoutout_schema_to_mapping
# from metagpt.actions.action_output import ActionOutput
# from metagpt.actions.action import Action
class RawMessage(TypedDict):
@ -155,6 +157,46 @@ class Message(BaseModel):
def __repr__(self):
return self.__str__()
# def serialize(self):
# message_cp: Message = copy.deepcopy(self)
# ic = message_cp.instruct_content
# if ic:
# # model create by pydantic create_model like `pydantic.main.prd`, can't pickle.dump directly
# schema = ic.schema()
# mapping = actionoutout_schema_to_mapping(schema)
#
# message_cp.instruct_content = {"class": schema["title"], "mapping": mapping, "value": ic.dict()}
# cb = message_cp.cause_by
# if cb:
# message_cp.cause_by = cb.serialize()
#
# return message_cp.dict()
#
# @classmethod
# def deserialize(cls, message_dict: dict):
# instruct_content = message_dict.get("instruct_content")
# if instruct_content:
# ic = instruct_content
# ic_obj = ActionOutput.create_model_class(class_name=ic["class"], mapping=ic["mapping"])
# ic_new = ic_obj(**ic["value"])
# message_dict.instruct_content = ic_new
# cause_by = message_dict.get("cause_by")
# if cause_by:
# message_dict.cause_by = Action.deserialize(cause_by)
#
# return Message(**message_dict)
def dict(self):
return {
"content": self.content,
"instruct_content": self.instruct_content,
"role": self.role,
"cause_by": self.cause_by,
"sent_from": self.sent_from,
"send_to": self.send_to,
"restricted_to": self.restricted_to
}
def to_dict(self) -> dict:
"""Return a dict containing `role` and `content` for the LLM call.l"""
return {"role": self.role, "content": self.content}

View file

@ -7,6 +7,7 @@
@Modified By: mashenquan, 2023/11/27. Add an archiving operation after completing the project, as specified in
Section 2.2.3.3 of RFC 135.
"""
from pathlib import Path
from pydantic import BaseModel, Field
from metagpt.actions import UserRequirement
@ -17,6 +18,7 @@ from metagpt.logs import logger
from metagpt.roles import Role
from metagpt.schema import Message
from metagpt.utils.common import NoMoneyException
from metagpt.utils.utils import read_json_file, write_json_file
class Team(BaseModel):
@ -32,6 +34,30 @@ class Team(BaseModel):
class Config:
arbitrary_types_allowed = True
def serialize(self, stg_path: Path):
team_info_path = stg_path.joinpath("team_info.json")
write_json_file(team_info_path, {
"idea": self.idea,
"investment": self.investment
})
self.environment.serialize(stg_path.joinpath("environment"))
def deserialize(self, stg_path: Path):
""" stg_path = ./storage/team """
# recover team_info
team_info_path = stg_path.joinpath("team_info.json")
if not team_info_path.exists():
logger.error("recover storage not exist, not to recover and continue run the old project.")
team_info = read_json_file(team_info_path)
self.investment = team_info.get("investment", 10.0)
self.idea = team_info.get("idea", "")
# recover environment
environment_path = stg_path.joinpath("environment")
self.environment = Environment()
self.environment.deserialize(stg_path=environment_path)
def hire(self, roles: list[Role]):
"""Hire roles to cooperate"""
self.env.add_roles(roles)

View file

@ -4,13 +4,13 @@
import copy
import pickle
from typing import Dict, List
from metagpt.actions.action_output import ActionOutput
from metagpt.schema import Message
from metagpt.actions.action import Action
def actionoutout_schema_to_mapping(schema: Dict) -> Dict:
def actionoutout_schema_to_mapping(schema: dict) -> dict:
"""
directly traverse the `properties` in the first level.
schema structure likes
@ -35,13 +35,47 @@ def actionoutout_schema_to_mapping(schema: Dict) -> Dict:
if property["type"] == "string":
mapping[field] = (str, ...)
elif property["type"] == "array" and property["items"]["type"] == "string":
mapping[field] = (List[str], ...)
mapping[field] = (list[str], ...)
elif property["type"] == "array" and property["items"]["type"] == "array":
# here only consider the `List[List[str]]` situation
mapping[field] = (List[List[str]], ...)
# here only consider the `list[list[str]]` situation
mapping[field] = (list[list[str]], ...)
return mapping
def actionoutput_mapping_to_str(mapping: dict) -> dict:
new_mapping = {}
for key, value in mapping.items():
new_mapping[key] = str(value)
return new_mapping
def actionoutput_str_to_mapping(mapping: dict) -> dict:
new_mapping = {}
for key, value in mapping.items():
if value == "(<class 'str'>, Ellipsis)":
new_mapping[key] = (str, ...)
else:
new_mapping[key] = eval(value) # `"'(list[str], Ellipsis)"` to `(list[str], ...)`
return new_mapping
def serialize_general_message(message: Message) -> dict:
""" serialize Message, not to save"""
message_cp = copy.deepcopy(message)
ic = message_cp.instruct_content
if ic:
# model create by pydantic create_model like `pydantic.main.prd`, can't pickle.dump directly
schema = ic.schema()
mapping = actionoutout_schema_to_mapping(schema)
mapping = actionoutput_mapping_to_str(mapping)
message_cp.instruct_content = {"class": schema["title"], "mapping": mapping, "value": ic.dict()}
cb = message_cp.cause_by
if cb:
message_cp.cause_by = cb.ser_class()
return message_cp.dict()
def serialize_message(message: Message):
message_cp = copy.deepcopy(message) # avoid `instruct_content` value update by reference
ic = message_cp.instruct_content
@ -56,6 +90,24 @@ def serialize_message(message: Message):
return msg_ser
def deserialize_general_message(message_dict: dict) -> Message:
""" deserialize Message, not to load"""
instruct_content = message_dict.pop("instruct_content")
cause_by = message_dict.pop("cause_by")
message = Message(**message_dict)
if instruct_content:
ic = instruct_content
mapping = actionoutput_str_to_mapping(ic["mapping"])
ic_obj = ActionOutput.create_model_class(class_name=ic["class"], mapping=mapping)
ic_new = ic_obj(**ic["value"])
message.instruct_content = ic_new
if cause_by:
message.cause_by = Action.deser_class(cause_by)
return message
def deserialize_message(message_ser: str) -> Message:
message = pickle.loads(message_ser)
if message.instruct_content:

View file

@ -3,7 +3,10 @@
# @Desc :
import typing
from typing import Any
import json
from pathlib import Path
import importlib
from tenacity import _utils
@ -20,3 +23,36 @@ def general_after_log(logger: "loguru.Logger", sec_format: str = "%0.3f") -> typ
)
return log_it
def read_json_file(json_file: str, encoding=None) -> list[Any]:
if not Path(json_file).exists():
raise FileNotFoundError(f"json_file: {json_file} not exist, return []")
with open(json_file, "r", encoding=encoding) as fin:
try:
data = json.load(fin)
except Exception as exp:
raise ValueError(f"read json file: {json_file} failed")
return data
def write_json_file(json_file: str, data: list, encoding=None):
folder_path = Path(json_file).parent
if not folder_path.exists():
folder_path.mkdir(parents=True, exist_ok=True)
with open(json_file, "w", encoding=encoding) as fout:
json.dump(data, fout, ensure_ascii=False, indent=4)
def import_class(class_name: str, module_name: str) -> type:
module = importlib.import_module(module_name)
a_class = getattr(module, class_name)
return a_class
def import_class_inst(class_name: str, module_name: str, *args, **kwargs) -> object:
a_class = import_class(class_name, module_name)
class_inst = a_class(*args, **kwargs)
return class_inst

81
startup.py Normal file
View file

@ -0,0 +1,81 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import asyncio
import fire
from metagpt.const import SERDES_PATH
from metagpt.roles import (
Architect,
Engineer,
ProductManager,
ProjectManager,
QaEngineer,
)
from metagpt.team import Team
async def startup(
idea: str,
investment: float = 3.0,
n_round: int = 5,
code_review: bool = False,
run_tests: bool = False,
implement: bool = True,
recover_path: bool = False,
):
"""Run a startup. Be a boss."""
company = Team()
if not recover_path:
company.hire(
[
ProductManager(),
Architect(),
ProjectManager(),
]
)
# if implement or code_review
if implement or code_review:
# developing features: implement the idea
company.hire([Engineer(n_borg=5, use_code_review=code_review)])
if run_tests:
# developing features: run tests on the spot and identify bugs
# (bug fixing capability comes soon!)
company.hire([QaEngineer()])
else:
stg_path = SERDES_PATH.joinpath("team")
company.deserialize(stg_path=stg_path)
idea = company.idea # use original idea
company.invest(investment)
company.start_project(idea)
await company.run(n_round=n_round)
def main(
idea: str,
investment: float = 3.0,
n_round: int = 5,
code_review: bool = True,
run_tests: bool = False,
implement: bool = True,
recover_path: str = None,
):
"""
We are a software startup comprised of AI. By investing in us,
you are empowering a future filled with limitless possibilities.
:param idea: Your innovative idea, such as "Creating a snake game."
:param investment: As an investor, you have the opportunity to contribute
a certain dollar amount to this AI company.
:param n_round:
:param code_review: Whether to use code review.
:param recover_path: recover the project from existing serialized storage
:return:
"""
asyncio.run(startup(idea, investment, n_round, code_review, run_tests, implement, recover_path))
if __name__ == "__main__":
fire.Fire(main)

View file

@ -11,3 +11,20 @@ from metagpt.actions import Action, WritePRD, WriteTest
def test_action_repr():
actions = [Action(), WriteTest(), WritePRD()]
assert "WriteTest" in str(actions)
def test_action_serdes():
action_info = WriteTest.ser_class()
assert action_info["action_class"] == "WriteTest"
action_class = Action.deser_class(action_info)
assert action_class == WriteTest
def test_action_class_serdes():
name = "write test"
action_info = WriteTest(name=name).serialize()
assert action_info["name"] == name
action = Action.deserialize(action_info)
assert action.name == name

View file

@ -0,0 +1,42 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc : unittest of memory
from pathlib import Path
from metagpt.schema import Message
from metagpt.memory.memory import Memory
from metagpt.actions.action_output import ActionOutput
from metagpt.actions.design_api import WriteDesign
from metagpt.actions.add_requirement import BossRequirement
serdes_path = Path(__file__).absolute().parent.joinpath("../../data/serdes_storage")
def test_memory_serdes():
msg1 = Message(role="User",
content="write a 2048 game",
cause_by=BossRequirement)
out_mapping = {"field1": (list[str], ...)}
out_data = {"field1": ["field1 value1", "field1 value2"]}
ic_obj = ActionOutput.create_model_class("system_design", out_mapping)
msg2 = Message(role="Architect",
instruct_content=ic_obj(**out_data),
content="system design content",
cause_by=WriteDesign)
memory = Memory()
memory.add_batch([msg1, msg2])
stg_path = serdes_path.joinpath("team/environment")
memory.serialize(stg_path)
assert stg_path.joinpath("memory.json").exists()
new_memory = Memory.deserialize(stg_path)
assert new_memory.count() == 2
new_msg2 = new_memory.get(1)[0]
assert new_msg2.instruct_content.field1 == ["field1 value1", "field1 value2"]
assert new_msg2.cause_by == WriteDesign
stg_path.joinpath("memory.json").unlink()

View file

@ -0,0 +1,85 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc : unittest of Role
from pathlib import Path
import shutil
import pytest
from metagpt.roles.role import Role, RoleReactMode
from metagpt.actions.action import Action
from metagpt.schema import Message
from metagpt.actions.add_requirement import BossRequirement
from metagpt.roles.product_manager import ProductManager
serdes_path = Path(__file__).absolute().parent.joinpath("../../data/serdes_storage")
def test_role_serdes():
stg_path_prefix = serdes_path.joinpath("team/environment/roles/")
shutil.rmtree(serdes_path.joinpath("team"), ignore_errors=True)
pm = ProductManager()
role_tag = f"{pm.__class__.__name__}_{pm.name}"
stg_path = stg_path_prefix.joinpath(role_tag)
pm.serialize(stg_path)
assert stg_path.joinpath("actions/actions_info.json").exists()
new_pm = Role.deserialize(stg_path)
assert new_pm.name == pm.name
assert len(new_pm.get_memories(1)) == 0
class ActionOK(Action):
async def run(self, messages: list["Message"]):
return "ok"
class ActionRaise(Action):
async def run(self, messages: list["Message"]):
raise RuntimeError("parse error")
class RoleA(Role):
def __init__(self,
name: str = "RoleA",
profile: str = "Role A",
goal: str = "",
constraints: str = ""):
super(RoleA, self).__init__(name=name, profile=profile, goal=goal, constraints=constraints)
self._init_actions([ActionOK, ActionRaise])
self._watch([BossRequirement])
self._rc.react_mode = RoleReactMode.BY_ORDER
async def run(self, message: "Message" = None, stg_path: str = None):
try:
await super(RoleA, self).run(message)
except Exception as exp:
print("exp ", exp)
self.serialize(stg_path)
@pytest.mark.asyncio
async def test_role_serdes_interrupt():
role_a = RoleA()
shutil.rmtree(serdes_path.joinpath("team"), ignore_errors=True)
stg_path = serdes_path.joinpath(f"team/environment/roles/{role_a.__class__.__name__}_{role_a.name}")
await role_a.run(
message=Message(content="demo", cause_by=BossRequirement),
stg_path=stg_path
)
assert role_a._rc.memory.count() == 2
assert stg_path.joinpath("actions/todo.json").exists()
new_role_a: Role = Role.deserialize(stg_path)
assert new_role_a._rc.state == 1
await role_a.run(
message=Message(content="demo", cause_by=BossRequirement),
stg_path=stg_path
)

View file

@ -7,6 +7,8 @@
"""
import pytest
from pathlib import Path
import shutil
from metagpt.actions import UserRequirement
from metagpt.environment import Environment
@ -14,6 +16,10 @@ from metagpt.logs import logger
from metagpt.manager import Manager
from metagpt.roles import Architect, ProductManager, Role
from metagpt.schema import Message
from tests.metagpt.roles.test_role import RoleA
serdes_path = Path(__file__).absolute().parent.joinpath("../data/serdes_storage")
@pytest.fixture
@ -36,12 +42,6 @@ def test_get_roles(env: Environment):
assert roles == {role1.profile: role1, role2.profile: role2}
def test_set_manager(env: Environment):
manager = Manager()
env.set_manager(manager)
assert env.manager == manager
@pytest.mark.asyncio
async def test_publish_and_process_message(env: Environment):
product_manager = ProductManager("Alice", "Product Manager", "做AI Native产品", "资源有限")
@ -54,3 +54,18 @@ async def test_publish_and_process_message(env: Environment):
await env.run(k=2)
logger.info(f"{env.history=}")
assert len(env.history) > 10
def test_environment_serdes():
environment = Environment()
role_a = RoleA()
shutil.rmtree(serdes_path.joinpath("team"), ignore_errors=True)
stg_path = serdes_path.joinpath("team/environment")
environment.add_role(role_a)
environment.serialize(stg_path)
new_env: Environment = Environment()
new_env.deserialize(stg_path)
assert len(new_env.roles) == 1

View file

@ -7,12 +7,16 @@
@Modified By: mashenquan, 2023-11-1. In line with Chapter 2.2.1 and 2.2.2 of RFC 116, introduce unit tests for
the utilization of the new feature of `Message` class.
"""
import json
import pytest
from metagpt.actions import Action
from metagpt.schema import AIMessage, Message, SystemMessage, UserMessage
from metagpt.actions.action_output import ActionOutput
from metagpt.actions.write_code import WriteCode
from metagpt.utils.serialize import serialize_general_message, deserialize_general_message
from metagpt.utils.common import get_class_name
@ -70,5 +74,43 @@ def test_routes():
assert m.send_to == {"e", get_class_name(Action)}
def test_message_serdes():
out_mapping = {"field3": (str, ...), "field4": (list[str], ...)}
out_data = {"field3": "field3 value3", "field4": ["field4 value1", "field4 value2"]}
ic_obj = ActionOutput.create_model_class("code", out_mapping)
message = Message(
content="code",
instruct_content=ic_obj(**out_data),
role="engineer",
cause_by=WriteCode
)
message_dict = serialize_general_message(message)
assert message_dict["cause_by"] == {"action_class": "WriteCode"}
assert message_dict["instruct_content"] == {
"class": "code",
"mapping": {
"field3": "(<class 'str'>, Ellipsis)",
"field4": "(list[str], Ellipsis)"
},
"value": {
"field3": "field3 value3",
"field4": ["field4 value1", "field4 value2"]
}
}
new_message = deserialize_general_message(message_dict)
assert new_message.content == message.content
assert new_message.instruct_content == message.instruct_content
assert new_message.cause_by == message.cause_by
assert new_message.instruct_content.field3 == out_data["field3"]
message = Message(content="code")
message_dict = serialize_general_message(message)
new_message = deserialize_general_message(message_dict)
assert new_message.instruct_content is None
assert new_message.cause_by == ""
if __name__ == "__main__":
pytest.main([__file__, "-s"])

View file

@ -0,0 +1,27 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc : unittest of team
from pathlib import Path
import shutil
from metagpt.team import Team
from tests.metagpt.roles.test_role import RoleA
serdes_path = Path(__file__).absolute().parent.joinpath("../data/serdes_storage")
def test_team_serdes():
company = Team()
company.hire([RoleA()])
stg_path = serdes_path.joinpath("team")
shutil.rmtree(stg_path, ignore_errors=True)
company.serialize(stg_path=stg_path)
new_company = Team()
new_company.deserialize(stg_path)
assert len(new_company.environment.roles) == 1