From 68c8ef107347f713ee6f3433735374d175b98017 Mon Sep 17 00:00:00 2001 From: better629 Date: Wed, 20 Dec 2023 10:44:30 +0800 Subject: [PATCH] update ser&deser code --- metagpt/actions/action.py | 1 - metagpt/roles/role.py | 26 ++++-- metagpt/schema.py | 8 +- metagpt/startup.py | 37 +++++--- metagpt/utils/utils.py | 17 ++-- startup.py | 86 ------------------- .../serialize_deserialize/test_role.py | 2 +- .../serialize_deserialize/test_team.py | 14 ++- 8 files changed, 70 insertions(+), 121 deletions(-) delete mode 100644 startup.py diff --git a/metagpt/actions/action.py b/metagpt/actions/action.py index 570863388..8cba18945 100644 --- a/metagpt/actions/action.py +++ b/metagpt/actions/action.py @@ -20,7 +20,6 @@ from metagpt.provider.base_gpt_api import BaseGPTAPI from metagpt.provider.postprecess.llm_output_postprecess import llm_output_postprecess from metagpt.utils.common import OutputParser from metagpt.utils.utils import general_after_log -from metagpt.utils.utils import import_class action_subclass_registry = {} diff --git a/metagpt/roles/role.py b/metagpt/roles/role.py index 9b1e0bf94..09371ae08 100644 --- a/metagpt/roles/role.py +++ b/metagpt/roles/role.py @@ -39,7 +39,7 @@ from metagpt.provider.human_provider import HumanProvider from metagpt.schema import Message, MessageQueue from metagpt.utils.common import any_to_str from metagpt.utils.repair_llm_raw_output import extract_state_value_from_output -from metagpt.utils.utils import read_json_file, write_json_file, import_class +from metagpt.utils.utils import read_json_file, write_json_file, import_class, role_raise_decorator PREFIX_TEMPLATE = """You are a {profile}, named {name}, your goal is {goal}, and the constraint is {constraints}. """ @@ -137,6 +137,7 @@ class Role(BaseModel): # builtin variables recovered: bool = False # to tag if a recovered role + latest_observed_msg: Message = None # record the latest observed message when interrupted builtin_class_name: str = "" _private_attributes = { @@ -200,7 +201,6 @@ class Role(BaseModel): def _reset(self): object.__setattr__(self, "_states", []) object.__setattr__(self, "_actions", []) - # object.__setattr__(self, "_rc", RoleContext()) @property def _setting(self): @@ -210,7 +210,7 @@ class Role(BaseModel): stg_path = SERDESER_PATH.joinpath(f"team/environment/roles/{self.__class__.__name__}_{self.name}") \ if stg_path is None else stg_path - role_info = self.dict(exclude={"_rc": {"memory": True}, "_llm": True}) + role_info = self.dict(exclude={"_rc": {"memory": True, "msg_buffer": True}, "_llm": True}) role_info.update({ "role_class": self.__class__.__name__, "module_name": self.__module__ @@ -311,7 +311,7 @@ class Role(BaseModel): def _set_state(self, state: int): """Update the current state.""" self._rc.state = state - logger.debug(self._actions) + logger.debug(f"actions={self._actions}, state={state}") self._rc.todo = self._actions[self._rc.state] if state >= 0 else None def set_env(self, env: "Environment"): @@ -388,15 +388,30 @@ class Role(BaseModel): return msg + def _find_news(self, observed: list[Message], existed: list[Message]) -> list[Message]: + news = [] + # Warning, remove `id` here to make it work for recover + observed_pure = [msg.dict(exclude={"id": True}) for msg in observed] + existed_pure = [msg.dict(exclude={"id": True}) for msg in existed] + for idx, new in enumerate(observed_pure): + if new["cause_by"] in self._rc.watch and new not in existed_pure: + news.append(observed[idx]) + return news + async def _observe(self, ignore_memory=False) -> int: """Prepare new messages for processing from the message buffer and other sources.""" # Read unprocessed messages from the msg buffer. news = self._rc.msg_buffer.pop_all() + if self.recovered: + news = [self.latest_observed_msg] if self.latest_observed_msg else [] + else: + self.latest_observed_msg = news[-1] if len(news) > 0 else None # record the latest observed msg + # Store the read messages in your own memory to prevent duplicate processing. old_messages = [] if ignore_memory else self._rc.memory.get() self._rc.memory.add_batch(news) # Filter out messages of interest. - self._rc.news = [n for n in news if n.cause_by in self._rc.watch and n not in old_messages] + self._rc.news = self._find_news(news, old_messages) # Design Rules: # If you need to further categorize Message objects, you can do so using the Message.set_meta function. @@ -484,6 +499,7 @@ class Role(BaseModel): """A wrapper to return the most recent k memories of this role, return all when k=0""" return self._rc.memory.get(k=k) + @role_raise_decorator async def run(self, with_message=None): """Observe, and think and act based on the results of the observation""" if with_message: diff --git a/metagpt/schema.py b/metagpt/schema.py index 0ec9b5c60..0fdc24e02 100644 --- a/metagpt/schema.py +++ b/metagpt/schema.py @@ -26,7 +26,6 @@ from typing import Dict, List, Set, TypedDict, Optional, Any from pydantic import BaseModel, Field -from metagpt.actions import UserRequirement from metagpt.config import CONFIG from metagpt.const import ( MESSAGE_ROUTE_CAUSE_BY, @@ -118,8 +117,9 @@ class Message(BaseModel): ic_new = ic_obj(**ic["value"]) kwargs["instruct_content"] = ic_new - kwargs["id"] = uuid.uuid4().hex - kwargs["cause_by"] = any_to_str(kwargs.get("cause_by", UserRequirement)) + kwargs["id"] = kwargs.get("id", uuid.uuid4().hex) + kwargs["cause_by"] = any_to_str(kwargs.get("cause_by", + import_class("UserRequirement", "metagpt.actions.add_requirement"))) kwargs["sent_from"] = any_to_str(kwargs.get("sent_from", "")) kwargs["send_to"] = any_to_str_set(kwargs.get("send_to", {MESSAGE_ROUTE_TO_ALL})) super(Message, self).__init__(**kwargs) @@ -218,7 +218,7 @@ class MessageQueue(BaseModel): if key in kwargs: object.__setattr__(self, key, kwargs[key]) else: - object.__setattr__(self, key, self._private_attributes[key]) + object.__setattr__(self, key, Queue()) def pop(self) -> Message | None: """Pop one message from the queue.""" diff --git a/metagpt/startup.py b/metagpt/startup.py index f930c386b..17eb26665 100644 --- a/metagpt/startup.py +++ b/metagpt/startup.py @@ -27,8 +27,10 @@ def startup( reqa_file: str = typer.Option(default="", help="Specify the source file name for rewriting the quality test code."), max_auto_summarize_code: int = typer.Option( default=-1, - help="The maximum number of times the 'SummarizeCode' action is automatically invoked, with -1 indicating unlimited. This parameter is used for debugging the workflow.", + help="The maximum number of times the 'SummarizeCode' action is automatically invoked, " + "with -1 indicating unlimited. This parameter is used for debugging the workflow.", ), + recover_path: str = typer.Option(default=None, help="recover the project from existing serialized storage") ): """Run a startup. Be a boss.""" from metagpt.roles import ( @@ -50,20 +52,29 @@ def startup( CONFIG.reqa_file = reqa_file CONFIG.max_auto_summarize_code = max_auto_summarize_code - company = Team() - company.hire( - [ - ProductManager(), - Architect(), - ProjectManager(), - ] - ) + if not recover_path: + company = Team() + company.hire( + [ + ProductManager(), + Architect(), + ProjectManager(), + ] + ) - if implement or code_review: - company.hire([Engineer(n_borg=5, use_code_review=code_review)]) + if implement or code_review: + company.hire([Engineer(n_borg=5, use_code_review=code_review)]) - if run_tests: - company.hire([QaEngineer()]) + if run_tests: + company.hire([QaEngineer()]) + else: + # # stg_path = SERDESER_PATH.joinpath("team") + stg_path = Path(recover_path) + if not stg_path.exists() or not str(stg_path).endswith("team"): + raise FileNotFoundError(f"{recover_path} not exists or not endswith `team`") + + company = Team.recover(stg_path=stg_path) + idea = company.idea # use original idea company.invest(investment) company.run_project(idea) diff --git a/metagpt/utils/utils.py b/metagpt/utils/utils.py index 57da57b00..aa7c039c4 100644 --- a/metagpt/utils/utils.py +++ b/metagpt/utils/utils.py @@ -88,18 +88,15 @@ def role_raise_decorator(func): return await func(self, *args, **kwargs) except KeyboardInterrupt as kbi: logger.error(f"KeyboardInterrupt: {kbi} occurs, start to serialize the project") - if self._rc.env: - newest_msgs = self._rc.env.memory.get(1) - if len(newest_msgs) > 0: - self._rc.memory.delete(newest_msgs[0]) + if self.latest_observed_msg: + self._rc.memory.delete(self.latest_observed_msg) raise Exception(format_trackback_info(limit=None)) # raise again to make it captured outside except Exception as exp: - if self._rc.env: - newest_msgs = self._rc.env.memory.get(1) - if len(newest_msgs) > 0: - logger.warning("There is a exception in role's execution, in order to resume, " - "we delete the newest role communication message in the role's memory.") - self._rc.memory.delete(newest_msgs[0]) # remove newest msg of the role to make it observed again + if self.latest_observed_msg: + logger.warning("There is a exception in role's execution, in order to resume, " + "we delete the newest role communication message in the role's memory.") + # remove role newest observed msg to make it observed again + self._rc.memory.delete(self.latest_observed_msg) raise Exception(format_trackback_info(limit=None)) # raise again to make it captured outside return wrapper diff --git a/startup.py b/startup.py deleted file mode 100644 index c4928a1b5..000000000 --- a/startup.py +++ /dev/null @@ -1,86 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- - -from typing import Optional -import asyncio -import fire -from pathlib import Path - -from metagpt.roles import ( - Architect, - Engineer, - ProductManager, - ProjectManager, - QaEngineer, -) -from metagpt.team import Team - - -async def startup( - idea: str, - investment: float = 3.0, - n_round: int = 5, - code_review: bool = False, - run_tests: bool = False, - implement: bool = True, - recover_path: Optional[str] = None, -): - """Run a startup. Be a boss.""" - if not recover_path: - company = Team() - company.hire( - [ - ProductManager(), - Architect(), - ProjectManager(), - ] - ) - - # if implement or code_review - if implement or code_review: - # developing features: implement the idea - company.hire([Engineer(n_borg=5, use_code_review=code_review)]) - - if run_tests: - # developing features: run tests on the spot and identify bugs - # (bug fixing capability comes soon!) - company.hire([QaEngineer()]) - else: - # # stg_path = SERDESER_PATH.joinpath("team") - stg_path = Path(recover_path) - if not stg_path.exists() or not str(stg_path).endswith("team"): - raise FileNotFoundError(f"{recover_path} not exists or not endswith `team`") - - company = Team.recover(stg_path=stg_path) - idea = company.idea # use original idea - - company.invest(investment) - company.start_project(idea) - await company.run(n_round=n_round) - - -def main( - idea: str, - investment: float = 3.0, - n_round: int = 5, - code_review: bool = True, - run_tests: bool = False, - implement: bool = True, - recover_path: str = None, -): - """ - We are a software startup comprised of AI. By investing in us, - you are empowering a future filled with limitless possibilities. - :param idea: Your innovative idea, such as "Creating a snake game." - :param investment: As an investor, you have the opportunity to contribute - a certain dollar amount to this AI company. - :param n_round: - :param code_review: Whether to use code review. - :param recover_path: recover the project from existing serialized storage - :return: - """ - asyncio.run(startup(idea, investment, n_round, code_review, run_tests, implement, recover_path)) - - -if __name__ == "__main__": - fire.Fire(main) diff --git a/tests/metagpt/serialize_deserialize/test_role.py b/tests/metagpt/serialize_deserialize/test_role.py index f25403dc0..87cf75caa 100644 --- a/tests/metagpt/serialize_deserialize/test_role.py +++ b/tests/metagpt/serialize_deserialize/test_role.py @@ -84,7 +84,7 @@ async def test_role_serdeser_interrupt(): logger.error(f"Exception in `role_a.run`, detail: {format_trackback_info()}") role_c.serialize(stg_path) - assert role_c._rc.memory.count() == 2 + assert role_c._rc.memory.count() == 1 new_role_a: Role = Role.deserialize(stg_path) assert new_role_a._rc.state == 1 diff --git a/tests/metagpt/serialize_deserialize/test_team.py b/tests/metagpt/serialize_deserialize/test_team.py index 01e0a6c70..e87df9b52 100644 --- a/tests/metagpt/serialize_deserialize/test_team.py +++ b/tests/metagpt/serialize_deserialize/test_team.py @@ -106,11 +106,23 @@ async def test_team_recover_multi_roles_save(): stg_path = SERDESER_PATH.joinpath("team") shutil.rmtree(stg_path, ignore_errors=True) + role_a = RoleA() + role_b = RoleB() + + assert role_a.subscription == {"tests.metagpt.serialize_deserialize.test_serdeser_base.RoleA", + "RoleA"} + assert role_b.subscription == {"tests.metagpt.serialize_deserialize.test_serdeser_base.RoleB", + "RoleB"} + assert role_b._rc.watch == {"tests.metagpt.serialize_deserialize.test_serdeser_base.ActionPass"} + company = Team() - company.hire([RoleA(), RoleB()]) + company.hire([role_a, role_b]) company.run_project(idea) await company.run(n_round=4) new_company = Team.recover(stg_path) new_company.run_project(idea) + + assert new_company.env.get_role(role_b.profile)._rc.state == 1 + await new_company.run(n_round=4)