From 5e3607f85bc4fec0ff97c57ff7d866f108e3c9c3 Mon Sep 17 00:00:00 2001 From: better629 Date: Thu, 30 Nov 2023 15:18:24 +0800 Subject: [PATCH] update environment/message to BaseModel, update the ser&deser of roles/actions --- metagpt/actions/action.py | 29 ++++- metagpt/actions/design_api.py | 8 +- metagpt/actions/project_management.py | 3 +- metagpt/actions/search_and_summarize.py | 15 ++- metagpt/actions/write_code.py | 3 +- metagpt/actions/write_code_review.py | 4 +- metagpt/actions/write_prd.py | 6 +- metagpt/actions/write_test.py | 11 +- metagpt/const.py | 2 +- metagpt/environment.py | 39 +++++-- metagpt/memory/longterm_memory.py | 14 ++- metagpt/memory/memory.py | 79 ++++++++++---- metagpt/roles/architect.py | 4 +- metagpt/roles/customer_service.py | 19 ++-- metagpt/roles/engineer.py | 4 +- metagpt/roles/product_manager.py | 5 +- metagpt/roles/project_manager.py | 4 +- metagpt/roles/qa_engineer.py | 16 +-- metagpt/roles/role.py | 130 +++++++++++++--------- metagpt/roles/sales.py | 31 +++--- metagpt/roles/seacher.py | 21 ++-- metagpt/schema.py | 138 ++++++++++++++---------- metagpt/team.py | 39 ++++--- metagpt/utils/serialize.py | 26 +++-- metagpt/utils/utils.py | 43 ++++++++ startup.py | 17 +-- 26 files changed, 458 insertions(+), 252 deletions(-) diff --git a/metagpt/actions/action.py b/metagpt/actions/action.py index aefe6d39d..7a7f194f4 100644 --- a/metagpt/actions/action.py +++ b/metagpt/actions/action.py @@ -5,8 +5,9 @@ @Author : alexanderwu @File : action.py """ + +from __future__ import annotations import re -from abc import ABC from typing import Optional, Any from pydantic import BaseModel, Field @@ -14,25 +15,43 @@ from tenacity import retry, stop_after_attempt, wait_fixed from metagpt.actions.action_output import ActionOutput from metagpt.llm import LLM +from metagpt.provider.base_gpt_api import BaseGPTAPI from metagpt.logs import logger from metagpt.utils.common import OutputParser from metagpt.utils.custom_decoder import CustomDecoder from metagpt.utils.utils import import_class +action_subclass_registry = {} + + class Action(BaseModel): name: str = "" - llm: LLM = Field(default_factory=LLM) + llm: BaseGPTAPI = Field(default_factory=LLM, exclude=True) context = "" prefix = "" profile = "" desc = "" content: Optional[str] = None instruct_content: Optional[str] = None + + # builtin variables + builtin_class_name: str = "" + + class Config: + arbitrary_types_allowed = True def __init__(self, **kwargs: Any): super().__init__(**kwargs) - + + # deserialize child classes dynamically for inherited `action` + object.__setattr__(self, "builtin_class_name", self.__class__.__name__) + self.__fields__["builtin_class_name"].default = self.__class__.__name__ + + def __init_subclass__(cls, **kwargs: Any) -> None: + super().__init_subclass__(**kwargs) + action_subclass_registry[cls.__name__] = cls + def set_prefix(self, prefix, profile): """Set prefix for later usage""" self.prefix = prefix @@ -52,14 +71,14 @@ class Action(BaseModel): } @classmethod - def deserialize(cls, action_dict: dict): + def deserialize(cls, action_dict: dict) -> "Action": action_class_str = action_dict.pop("action_class") module_name = action_dict.pop("module_name") action_class = import_class(action_class_str, module_name) return action_class(**action_dict) @classmethod - def ser_class(cls): + def ser_class(cls) -> dict: """ serialize class type""" return { "action_class": cls.__name__, diff --git a/metagpt/actions/design_api.py b/metagpt/actions/design_api.py index 30df70ce7..015678baa 100644 --- a/metagpt/actions/design_api.py +++ b/metagpt/actions/design_api.py @@ -13,6 +13,7 @@ from pydantic import Field from metagpt.actions import Action, ActionOutput from metagpt.llm import LLM +from metagpt.provider.base_gpt_api import BaseGPTAPI from metagpt.config import CONFIG from metagpt.const import WORKSPACE_ROOT from metagpt.logs import logger @@ -155,12 +156,11 @@ OUTPUT_MAPPING = { class WriteDesign(Action): name: str = "" context: Optional[str] = None - llm: LLM = Field(default_factory=LLM) + llm: BaseGPTAPI = Field(default_factory=LLM) desc: str = "Based on the PRD, think about the system design, and design the corresponding APIs, " "data structures, library tables, processes, and paths. Please provide your design, feedback " "clearly and in detail." - def recreate_workspace(self, workspace: Path): try: shutil.rmtree(workspace) @@ -168,7 +168,6 @@ class WriteDesign(Action): pass # Folder does not exist, but we don't care workspace.mkdir(parents=True, exist_ok=True) - async def _save_prd(self, docs_path, resources_path, context): prd_file = docs_path / "prd.md" if context[-1].instruct_content and context[-1].instruct_content.dict()["Competitive Quadrant Chart"]: @@ -179,7 +178,6 @@ class WriteDesign(Action): logger.info(f"Saving PRD to {prd_file}") prd_file.write_text(json_to_markdown(context[-1].instruct_content.dict())) - async def _save_system_design(self, docs_path, resources_path, system_design): data_api_design = system_design.instruct_content.dict()[ "Data structures and interface definitions" @@ -193,7 +191,6 @@ class WriteDesign(Action): logger.info(f"Saving System Designs to {system_design_file}") system_design_file.write_text((json_to_markdown(system_design.instruct_content.dict()))) - async def _save(self, context, system_design): if isinstance(system_design, ActionOutput): ws_name = system_design.instruct_content.dict()["Python package name"] @@ -211,7 +208,6 @@ class WriteDesign(Action): logger.error(f"Failed to save PRD {e}") await self._save_system_design(docs_path, resources_path, system_design) - async def run(self, context, format=CONFIG.prompt_format): prompt_template, format_example = get_template(templates, format) prompt = prompt_template.format(context=context, format_example=format_example) diff --git a/metagpt/actions/project_management.py b/metagpt/actions/project_management.py index b72507ee3..cf44906cd 100644 --- a/metagpt/actions/project_management.py +++ b/metagpt/actions/project_management.py @@ -11,6 +11,7 @@ from pydantic import Field from metagpt.actions.action import Action from metagpt.llm import LLM +from metagpt.provider.base_gpt_api import BaseGPTAPI from metagpt.config import CONFIG from metagpt.const import WORKSPACE_ROOT from metagpt.utils.common import CodeParser @@ -168,7 +169,7 @@ OUTPUT_MAPPING = { class WriteTasks(Action): name: str = "CreateTasks" context: Optional[str] = None - llm: LLM = Field(default_factory=LLM) + llm: BaseGPTAPI = Field(default_factory=LLM) def _save(self, context, rsp): try: diff --git a/metagpt/actions/search_and_summarize.py b/metagpt/actions/search_and_summarize.py index 0580303e6..6b0c1f717 100644 --- a/metagpt/actions/search_and_summarize.py +++ b/metagpt/actions/search_and_summarize.py @@ -8,14 +8,15 @@ import pydantic from typing import Optional, Any from pydantic import BaseModel, Field +from pydantic import root_validator from metagpt.actions import Action from metagpt.llm import LLM -from metagpt.config import Config +from metagpt.provider.base_gpt_api import BaseGPTAPI +from metagpt.config import Config, CONFIG from metagpt.logs import logger from metagpt.schema import Message from metagpt.tools.search_engine import SearchEngine -from pydantic import root_validator SEARCH_AND_SUMMARIZE_SYSTEM = """### Requirements 1. Please summarize the latest dialogue based on the reference information (secondary) and dialogue history (primary). Do not include text that is irrelevant to the conversation. @@ -106,13 +107,13 @@ You are a member of a professional butler team and will provide helpful suggesti class SearchAndSummarize(Action): name: str = "" content: Optional[str] = None - llm: None = Field(default_factory=LLM) + llm: BaseGPTAPI = Field(default_factory=LLM) config: None = Field(default_factory=Config) - engine: Optional[str] = None + engine: Optional[str] = CONFIG.search_engine search_func: Optional[str] = None + search_engine: SearchEngine = None result = "" - @root_validator def validate_engine_and_run_func(cls, values): @@ -130,9 +131,7 @@ class SearchAndSummarize(Action): values['search_engine'] = search_engine return values - - - + async def run(self, context: list[Message], system_text=SEARCH_AND_SUMMARIZE_SYSTEM) -> str: print(context) if self.search_engine is None: diff --git a/metagpt/actions/write_code.py b/metagpt/actions/write_code.py index 2dc240591..10487e53a 100644 --- a/metagpt/actions/write_code.py +++ b/metagpt/actions/write_code.py @@ -13,6 +13,7 @@ from tenacity import retry, stop_after_attempt, wait_fixed from metagpt.actions import WriteDesign from metagpt.actions.action import Action from metagpt.llm import LLM +from metagpt.provider.base_gpt_api import BaseGPTAPI from metagpt.const import WORKSPACE_ROOT from metagpt.logs import logger from metagpt.schema import Message @@ -50,7 +51,7 @@ ATTENTION: Use '##' to SPLIT SECTIONS, not '#'. Output format carefully referenc class WriteCode(Action): name: str = "WriteCode" context: Optional[str] = None - llm: LLM = Field(default_factory=LLM) + llm: BaseGPTAPI = Field(default_factory=LLM) def _is_invalid(self, filename): return any(i in filename for i in ["mp3", "wav"]) diff --git a/metagpt/actions/write_code_review.py b/metagpt/actions/write_code_review.py index 3d86d7c63..79e462f76 100644 --- a/metagpt/actions/write_code_review.py +++ b/metagpt/actions/write_code_review.py @@ -12,7 +12,7 @@ from tenacity import retry, stop_after_attempt, wait_fixed from metagpt.llm import LLM from metagpt.actions.action import Action from metagpt.logs import logger -from metagpt.schema import Message +from metagpt.provider.base_gpt_api import BaseGPTAPI from metagpt.utils.common import CodeParser PROMPT_TEMPLATE = """ @@ -67,7 +67,7 @@ FORMAT_EXAMPLE = """ class WriteCodeReview(Action): name: str = "WriteCodeReview" context: Optional[str] = None - llm: LLM = Field(default_factory=LLM) + llm: BaseGPTAPI = Field(default_factory=LLM) @retry(stop=stop_after_attempt(2), wait=wait_fixed(1)) async def write_code(self, prompt): diff --git a/metagpt/actions/write_prd.py b/metagpt/actions/write_prd.py index 660d7fb95..450bed7e7 100644 --- a/metagpt/actions/write_prd.py +++ b/metagpt/actions/write_prd.py @@ -11,6 +11,7 @@ from pydantic import BaseModel, Field from metagpt.actions import Action, ActionOutput from metagpt.llm import LLM +from metagpt.provider.base_gpt_api import BaseGPTAPI from metagpt.actions.search_and_summarize import SearchAndSummarize from metagpt.config import CONFIG from metagpt.logs import logger @@ -224,12 +225,9 @@ OUTPUT_MAPPING = { class WritePRD(Action): name: str = "" content: Optional[str] = None - llm: LLM = Field(default_factory=LLM) + llm: BaseGPTAPI = Field(default_factory=LLM) assistant_search_action: Action = None - def __init__(self, **kwargs): - super().__init__(**kwargs) - async def run(self, requirements, format=CONFIG.prompt_format, *args, **kwargs) -> ActionOutput: # self.assistant_search_action = SearchAndSummarize() if self.assistant_search_action is None: diff --git a/metagpt/actions/write_test.py b/metagpt/actions/write_test.py index 35ff36dc2..6c902444a 100644 --- a/metagpt/actions/write_test.py +++ b/metagpt/actions/write_test.py @@ -5,6 +5,12 @@ @Author : alexanderwu @File : environment.py """ + +from typing import Optional +from pydantic import Field + +from metagpt.llm import LLM +from metagpt.provider.base_gpt_api import BaseGPTAPI from metagpt.actions.action import Action from metagpt.logs import logger from metagpt.utils.common import CodeParser @@ -31,8 +37,9 @@ you should correctly import the necessary classes based on these file locations! class WriteTest(Action): - def __init__(self, name="WriteTest", context=None, llm=None): - super().__init__(name, context, llm) + name: str = "WriteTest" + context: Optional[str] = None + llm: BaseGPTAPI = Field(default_factory=LLM) async def write_code(self, prompt): code_rsp = await self._aask(prompt) diff --git a/metagpt/const.py b/metagpt/const.py index 711546d03..4b063a3dd 100644 --- a/metagpt/const.py +++ b/metagpt/const.py @@ -42,7 +42,7 @@ TMP = PROJECT_ROOT / "tmp" RESEARCH_PATH = DATA_PATH / "research" TUTORIAL_PATH = DATA_PATH / "tutorial_docx" INVOICE_OCR_TABLE_PATH = DATA_PATH / "invoice_table" -SERDES_PATH = WORKSPACE_ROOT / "storage" # TODO to store `storage` under the individual generated project +SERDESER_PATH = WORKSPACE_ROOT / "storage" # TODO to store `storage` under the individual generated project SKILL_DIRECTORY = PROJECT_ROOT / "metagpt/skills" diff --git a/metagpt/environment.py b/metagpt/environment.py index e867ad6fc..bade53f50 100644 --- a/metagpt/environment.py +++ b/metagpt/environment.py @@ -12,7 +12,7 @@ from pathlib import Path from pydantic import BaseModel, Field from metagpt.memory import Memory -from metagpt.roles import Role +from metagpt.roles.role import Role, role_subclass_registry from metagpt.schema import Message from metagpt.utils.utils import read_json_file, write_json_file @@ -30,6 +30,19 @@ class Environment(BaseModel): class Config: arbitrary_types_allowed = True + def __init__(self, **kwargs): + for role_key, role in kwargs.get("roles", {}).items(): + current_role = kwargs["roles"][role_key] + if isinstance(current_role, dict): + item_class_name = current_role.get("builtin_class_name", None) + for name, subclass in role_subclass_registry.items(): + registery_class_name = subclass.__fields__["builtin_class_name"].default + if item_class_name == registery_class_name: + current_role = subclass(**current_role) + break + kwargs["roles"][role_key] = current_role + super().__init__(**kwargs) + def serialize(self, stg_path: Path): roles_path = stg_path.joinpath("roles.json") roles_info = [] @@ -46,33 +59,39 @@ class Environment(BaseModel): history_path = stg_path.joinpath("history.json") write_json_file(history_path, {"content": self.history}) - def deserialize(self, stg_path: Path): + @classmethod + def deserialize(cls, stg_path: Path) -> "Environment": """ stg_path: ./storage/team/environment/ """ roles_path = stg_path.joinpath("roles.json") roles_info = read_json_file(roles_path) + roles = [] for role_info in roles_info: role_class = role_info.get("role_class") role_name = role_info.get("role_name") role_path = stg_path.joinpath(f"roles/{role_class}_{role_name}") role = Role.deserialize(role_path) - - self.add_role(role) + roles.append(role) memory = Memory.deserialize(stg_path) - self.memory = memory - history_path = stg_path.joinpath("history.json") - history = read_json_file(history_path) - self.history = history.get("content") + history = read_json_file(stg_path.joinpath("history.json")) + history = history.get("content") + + environment = Environment(**{ + "memory": memory, + "history": history + }) + environment.add_roles(roles) + return environment def add_role(self, role: Role): - """增加一个在当前环境的角色, 默认为profile/role_profile + """增加一个在当前环境的角色, 默认为profile Add a role in the current environment """ role.set_env(self) # use alias - self.roles[role.role_profile] = role + self.roles[role.profile] = role def add_roles(self, roles: Iterable[Role]): """增加一批在当前环境的角色 diff --git a/metagpt/memory/longterm_memory.py b/metagpt/memory/longterm_memory.py index f8abea5f3..5d149ee7a 100644 --- a/metagpt/memory/longterm_memory.py +++ b/metagpt/memory/longterm_memory.py @@ -2,6 +2,9 @@ # -*- coding: utf-8 -*- # @Desc : the implement of Long-term memory +from typing import Optional +from pydantic import Field + from metagpt.logs import logger from metagpt.memory import Memory from metagpt.memory.memory_storage import MemoryStorage @@ -15,11 +18,12 @@ class LongTermMemory(Memory): - update memory when it changed """ - def __init__(self): - self.memory_storage: MemoryStorage = MemoryStorage() - super(LongTermMemory, self).__init__() - self.rc = None # RoleContext - self.msg_from_recover = False + memory_storage: MemoryStorage = Field(default_factory=MemoryStorage) + rc: Optional["RoleContext"] = None + msg_from_recover: bool = False + + class Config: + arbitrary_types_allowed = True def recover_memory(self, role_id: str, rc: "RoleContext"): messages = self.memory_storage.recover_memory(role_id) diff --git a/metagpt/memory/memory.py b/metagpt/memory/memory.py index a839bb038..c88cc750e 100644 --- a/metagpt/memory/memory.py +++ b/metagpt/memory/memory.py @@ -5,34 +5,65 @@ @Author : alexanderwu @File : memory.py """ +import copy from collections import defaultdict -from typing import Iterable, Type +from typing import Iterable, Type, Union, Optional from pathlib import Path +from pydantic import BaseModel, Field +import json from metagpt.actions import Action from metagpt.schema import Message from metagpt.utils.utils import read_json_file, write_json_file -from metagpt.utils.serialize import serialize_general_message, deserialize_general_message +from metagpt.utils.utils import import_class -class Memory: +class Memory(BaseModel): """The most basic memory: super-memory""" - def __init__(self): - """Initialize an empty storage list and an empty index dictionary""" - self.storage: list[Message] = [] - self.index: dict[Type[Action], list[Message]] = defaultdict(list) + storage: list[Message] = Field(default=[]) + index: dict[Type[Action], list[Message]] = Field(default_factory=defaultdict(list)) + + def __init__(self, **kwargs): + index = kwargs.get("index", {}) + new_index = defaultdict(list) + for action_str, value in index.items(): + action_dict = json.loads(action_str) + action_class = import_class("Action", "metagpt.actions.action") + action_obj = action_class.deser_class(action_dict) + new_index[action_obj] = [Message(**item_dict) for item_dict in value] + kwargs["index"] = new_index + super(Memory, self).__init__(**kwargs) + self.index = new_index + + def dict(self, + *, + include: Optional[Union["AbstractSetIntStr", "MappingIntStrAny"]] = None, + exclude: Optional[Union["AbstractSetIntStr", "MappingIntStrAny"]] = None, + by_alias: bool = False, + skip_defaults: Optional[bool] = None, + exclude_unset: bool = False, + exclude_defaults: bool = False, + exclude_none: bool = False) -> "DictStrAny": + """ overwrite the `dict` to dump dynamic pydantic model""" + obj_dict = super(Memory, self).dict(include=include, + exclude=exclude, + by_alias=by_alias, + skip_defaults=skip_defaults, + exclude_unset=exclude_unset, + exclude_defaults=exclude_defaults, + exclude_none=exclude_none) + new_obj_dict = copy.deepcopy(obj_dict) + new_obj_dict["index"] = {} + for action, value in obj_dict["index"].items(): + action_ser = json.dumps(action.ser_class()) + new_obj_dict["index"][action_ser] = value + return new_obj_dict def serialize(self, stg_path: Path): """ stg_path = ./storage/team/environment/ or ./storage/team/environment/roles/{role_class}_{role_name}/ """ memory_path = stg_path.joinpath("memory.json") - - storage = [] - for message in self.storage: - # msg_dict = message.serialize() - msg_dict = serialize_general_message(message) - storage.append(msg_dict) - + storage = self.dict() write_json_file(memory_path, storage) @classmethod @@ -40,13 +71,8 @@ class Memory: """ stg_path = ./storage/team/environment/ or ./storage/team/environment/roles/{role_class}_{role_name}/""" memory_path = stg_path.joinpath("memory.json") - memory = Memory() - memory_list = read_json_file(memory_path) - for message in memory_list: - # distinguish instruct_content type in message - # msg = Message.deserialize(message) - msg = deserialize_general_message(message) - memory.add(msg) + memory_dict = read_json_file(memory_path) + memory = Memory(**memory_dict) return memory @@ -70,6 +96,16 @@ class Memory: """Return all messages containing a specified content""" return [message for message in self.storage if content in message.content] + def delete_newest(self) -> "Message": + """ delete the newest message from the storage""" + if len(self.storage) > 0: + newest_msg = self.storage.pop() + if newest_msg.cause_by and newest_msg in self.index[newest_msg.cause_by]: + self.index[newest_msg.cause_by].remove(newest_msg) + else: + newest_msg = None + return newest_msg + def delete(self, message: Message): """Delete the specified message from storage, while updating the index""" self.storage.remove(message) @@ -115,4 +151,3 @@ class Memory: continue rsp += self.index[action] return rsp - \ No newline at end of file diff --git a/metagpt/roles/architect.py b/metagpt/roles/architect.py index face22a68..09d52edbe 100644 --- a/metagpt/roles/architect.py +++ b/metagpt/roles/architect.py @@ -22,8 +22,8 @@ class Architect(Role): goal (str): Primary goal or responsibility of the architect. constraints (str): Constraints or guidelines for the architect. """ - name: str = "Bob" - role_profile: str = Field(default="Architect" , alias='profile') + name: str = Field(default="Bob") + profile: str = Field(default="Architect") goal: str = "Design a concise, usable, complete python system" constraints: str = "Try to specify good open source tools as much as possible" diff --git a/metagpt/roles/customer_service.py b/metagpt/roles/customer_service.py index 4547f8190..62792696f 100644 --- a/metagpt/roles/customer_service.py +++ b/metagpt/roles/customer_service.py @@ -5,6 +5,9 @@ @Author : alexanderwu @File : sales.py """ +from typing import Optional +from pydantic import Field + from metagpt.roles import Sales # from metagpt.actions import SearchAndSummarize @@ -24,12 +27,14 @@ DESC = """ class CustomerService(Sales): + + name: str = Field(default="Xiaomei") + profile: str = Field(default="Human customer service") + desc: str = DESC, + + store: Optional[str] = None + def __init__( self, - name="Xiaomei", - profile="Human customer service", - desc=DESC, - store=None - ): - super().__init__(name, profile, desc=desc, store=store) - \ No newline at end of file + **kwargs): + super().__init__(**kwargs) diff --git a/metagpt/roles/engineer.py b/metagpt/roles/engineer.py index 129bedeb8..e90f586f0 100644 --- a/metagpt/roles/engineer.py +++ b/metagpt/roles/engineer.py @@ -60,8 +60,8 @@ class Engineer(Role): use_code_review (bool): Whether to use code review. todos (list): List of tasks. """ - name: str = "Alex" - role_profile: str = Field(default="Engineer", alias='profile') + name: str = Field(default="Alex") + profile: str = Field(default="Engineer") goal: str = "Write elegant, readable, extensible, efficient code" constraints: str = "The code should conform to standards like PEP8 and be modular and maintainable" n_borg: int = 1 diff --git a/metagpt/roles/product_manager.py b/metagpt/roles/product_manager.py index b099fb4d9..6f68fe5ba 100644 --- a/metagpt/roles/product_manager.py +++ b/metagpt/roles/product_manager.py @@ -21,10 +21,11 @@ class ProductManager(Role): goal (str): Goal of the product manager. constraints (str): Constraints or limitations for the product manager. """ - name: str = "Alice" - role_profile: str = Field(default="Product Manager", alias='profile') + name: str = Field(default="Alice") + profile: str = Field(default="Product Manager") goal: str = "Efficiently create a successful product" constraints: str = "" + """ Represents a Product Manager role responsible for product development and management. """ diff --git a/metagpt/roles/project_manager.py b/metagpt/roles/project_manager.py index a2b227f22..c8e785d85 100644 --- a/metagpt/roles/project_manager.py +++ b/metagpt/roles/project_manager.py @@ -22,8 +22,8 @@ class ProjectManager(Role): goal (str): Goal of the project manager. constraints (str): Constraints or limitations for the project manager. """ - name: str = "Eve" - role_profile: str = Field(default="Project Manager", alias='profile') + name: str = Field(default="Eve") + profile: str = Field(default="Project Manager") goal: str = "Improve team efficiency and deliver with quality and quantity" constraints: str = "" diff --git a/metagpt/roles/qa_engineer.py b/metagpt/roles/qa_engineer.py index a763c2ce8..bad3f2409 100644 --- a/metagpt/roles/qa_engineer.py +++ b/metagpt/roles/qa_engineer.py @@ -7,6 +7,7 @@ """ import os from pathlib import Path +from pydantic import Field from metagpt.actions import ( DebugError, @@ -25,21 +26,22 @@ from metagpt.utils.special_tokens import FILENAME_CODE_SEP, MSG_SEP class QaEngineer(Role): + name: str = Field(default="Edward") + profile: str = Field(default="QaEngineer") + goal: str = "Write comprehensive and robust tests to ensure codes will work as expected without bugs" + constraints: str = "The test code you write should conform to code standard like PEP8, be modular, easy to read and maintain" + test_round_allowed: int = 5 + def __init__( self, - name="Edward", - profile="QaEngineer", - goal="Write comprehensive and robust tests to ensure codes will work as expected without bugs", - constraints="The test code you write should conform to code standard like PEP8, be modular, easy to read and maintain", - test_round_allowed=5, + **kwargs ): - super().__init__(name, profile, goal, constraints) + super().__init__(**kwargs) self._init_actions( [WriteTest] ) # FIXME: a bit hack here, only init one action to circumvent _think() logic, will overwrite _think() in future updates self._watch([WriteCode, WriteCodeReview, WriteTest, RunCode, DebugError]) self.test_round = 0 - self.test_round_allowed = test_round_allowed @classmethod def parse_workspace(cls, system_design_msg: Message) -> str: diff --git a/metagpt/roles/role.py b/metagpt/roles/role.py index e9371c2c0..b6332aa4c 100644 --- a/metagpt/roles/role.py +++ b/metagpt/roles/role.py @@ -6,27 +6,29 @@ @File : role.py """ +from __future__ import annotations from enum import Enum from pathlib import Path -from __future__ import annotations from typing import ( Iterable, - Type + Type, + Any ) -import re -from pydantic import BaseModel, Field -from importlib import import_module +from pydantic import BaseModel, Field, validator # from metagpt.environment import Environment from metagpt.config import CONFIG -from metagpt.actions import Action, ActionOutput +from metagpt.actions.action import Action, ActionOutput, action_subclass_registry from metagpt.llm import LLM +from metagpt.provider.base_gpt_api import BaseGPTAPI from metagpt.logs import logger from metagpt.memory import Memory, LongTermMemory from metagpt.schema import Message from metagpt.provider.human_provider import HumanProvider -from metagpt.utils.utils import read_json_file, write_json_file, import_class +from metagpt.utils.utils import read_json_file, write_json_file, import_class, role_raise_decorator +from metagpt.const import SERDESER_PATH + PREFIX_TEMPLATE = """You are a {profile}, named {name}, your goal is {goal}, and the constraint is {constraints}. """ @@ -57,6 +59,7 @@ ROLE_TEMPLATE = """Your response should be based on the previous conversation hi {name}: {result} """ + class RoleReactMode(str, Enum): REACT = "react" BY_ORDER = "by_order" @@ -74,6 +77,7 @@ class RoleSetting(BaseModel): goal: str = "" constraints: str = "" desc: str = "" + is_human: bool = False def __str__(self): return f"{self.name}({self.profile})" @@ -84,10 +88,10 @@ class RoleSetting(BaseModel): class RoleContext(BaseModel): """Role Runtime Context""" - env: 'Environment' = Field(default=None) + env: "Environment" = Field(default=None) memory: Memory = Field(default_factory=Memory) long_term_memory: LongTermMemory = Field(default_factory=LongTermMemory) - state: int = Field(default=0) + state: int = Field(default=-1) # -1 indicates initial or termination state where todo is None todo: Action = Field(default=None) watch: set[Type[Action]] = Field(default_factory=set) news: list[Type[Message]] = Field(default=[]) @@ -112,53 +116,86 @@ class RoleContext(BaseModel): return self.memory.get() +role_subclass_registry = {} + + class Role(BaseModel): """Role/Agent""" - name: str = "" profile: str = "" goal: str = "" constraints: str = "" desc: str = "" - _setting: RoleSetting = Field(default_factory=RoleSetting, alias="_setting") - _setting = RoleSetting(name=name, profile=profile, goal=goal, constraints=constraints) + is_human: bool = False + + _llm: BaseGPTAPI = Field(default_factory=LLM, exclude=True) + _setting: RoleSetting = Field(default_factory=RoleSetting, alias=True) _role_id: str = "" - _states: list = Field(default=[]) - _actions: list = Field(default=[]) - _actions_type: list = Field(default=[]) + _states: list[str] = Field(default=[]) + _actions: list[Action] = Field(default=[]) _rc: RoleContext = RoleContext() - + + # builtin variables + recovered: bool = False # to tag if a recovered role + builtin_class_name: str = "" + _private_attributes = { - "_setting": _setting, + "_llm": LLM() if not is_human else HumanProvider(), "_role_id": _role_id, "_states": [], - "_actions": [], - "_actions_type": [] # 用于记录和序列化 + "_actions": [] } - + class Config: arbitrary_types_allowed = True - - def __init__(self, **kwargs): + exclude = ["_llm"] + + def __init__(self, **kwargs: Any): + for index in range(len(kwargs.get("_actions", []))): + current_action = kwargs["_actions"][index] + if isinstance(current_action, dict): + item_class_name = current_action.get("builtin_class_name", None) + for name, subclass in action_subclass_registry.items(): + registery_class_name = subclass.__fields__["builtin_class_name"].default + if item_class_name == registery_class_name: + current_action = subclass(**current_action) + break + kwargs["_actions"][index] = current_action + super().__init__(**kwargs) + # 关于私有变量的初始化 https://github.com/pydantic/pydantic/issues/655 + self._private_attributes["_llm"] = LLM() if not self.is_human else HumanProvider() + self._private_attributes["_setting"] = RoleSetting(name=self.name, profile=self.profile, goal=self.goal, + desc=self.desc, constraints=self.constraints, + is_human=self.is_human) for key in self._private_attributes.keys(): if key in kwargs: object.__setattr__(self, key, kwargs[key]) - if key =="_setting": - _setting = RoleSetting(**kwargs[key]) - object.__setattr__(self, '_setting', _setting) + if key == "_setting": + setting = RoleSetting(**kwargs[key]) + object.__setattr__(self, "_setting", setting) elif key == "_rc": _rc = RoleContext - object.__setattr__(self, '_rc', _rc) + object.__setattr__(self, "_rc", _rc) else: object.__setattr__(self, key, self._private_attributes[key]) + + # deserialize child classes dynamically for inherited `role` + object.__setattr__(self, "builtin_class_name", self.__class__.__name__) + self.__fields__["builtin_class_name"].default = self.__class__.__name__ + + def __init_subclass__(cls, **kwargs: Any) -> None: + super().__init_subclass__(**kwargs) + role_subclass_registry[cls.__name__] = cls def _reset(self): - object.__setattr__(self, '_states', []) - object.__setattr__(self, '_actions', []) + object.__setattr__(self, "_states", []) + object.__setattr__(self, "_actions", []) - def serialize(self, stg_path: Path): + def serialize(self, stg_path: Path = None): + stg_path = SERDESER_PATH.joinpath(f"team/environment/roles/{self.__class__.__name__}_{self.name}") \ + if stg_path is None else stg_path role_info_path = stg_path.joinpath("role_info.json") role_info = { "role_class": self.__class__.__name__, @@ -207,7 +244,7 @@ class Role(BaseModel): actions = [] actions_info = read_json_file(actions_info_path) for action_info in actions_info: - action = Action.deserialize(action_info) + action = Action.deser_class(action_info) actions.append(action) watches_info_path = stg_path.joinpath("watches/watches_info.json") @@ -238,12 +275,8 @@ class Role(BaseModel): return role - def _reset(self): - self._states = [] - self._actions = [] - def set_recovered(self, recovered: bool = False): - self._recovered = recovered + self.recovered = recovered def set_memory(self, memory: Memory): self._rc.memory = memory @@ -256,7 +289,8 @@ class Role(BaseModel): for idx, action in enumerate(actions): if not isinstance(action, Action): ## 默认初始化 - i = action("", llm=self._llm) + # import pdb; pdb.set_trace() + i = action(name="", llm=self._llm) else: if self._setting.is_human and not isinstance(action.llm, HumanProvider): logger.warning(f"is_human attribute does not take effect," @@ -265,8 +299,6 @@ class Role(BaseModel): i.set_prefix(self._get_prefix(), self.profile) self._actions.append(i) self._states.append(f"{idx}. {action}") - action_title = action.schema()["title"] - self._actions_type.append(action_title) def set_react_mode(self, react_mode: RoleReactMode, max_react_loop: int = 1): self._set_react_mode(react_mode, max_react_loop) @@ -310,19 +342,10 @@ class Role(BaseModel): logger.debug(self._actions) self._rc.todo = self._actions[self._rc.state] if state >= 0 else None - def set_env(self, env: 'Environment'): + def set_env(self, env: "Environment"): """Set the environment in which the role works. The role can talk to the environment and can also receive messages by observing.""" self._rc.env = env - @property - def name(self): - return self._setting.name - - @property - def profile(self): - """Get the role description (position)""" - return self._setting.profile - def _get_prefix(self): """Get the role prefix""" if self._setting.desc: @@ -347,7 +370,7 @@ class Role(BaseModel): logger.debug(f"{prompt=}") if (not next_state.isdigit() and next_state != "-1") \ or int(next_state) not in range(-1, len(self._states)): - logger.warning(f'Invalid answer of state, {next_state=}, will be set to -1') + logger.warning(f"Invalid answer of state, {next_state=}, will be set to -1") next_state = -1 else: next_state = int(next_state) @@ -384,7 +407,7 @@ class Role(BaseModel): news_text = [f"{i.role}: {i.content[:20]}..." for i in self._rc.news] if news_text: - logger.debug(f'{self._setting} observed: {news_text}') + logger.debug(f"{self._setting} observed: {news_text}") return len(self._rc.news) def _publish_message(self, msg): @@ -400,7 +423,7 @@ class Role(BaseModel): Use llm to select actions in _think dynamically """ actions_taken = 0 - rsp = Message("No actions taken yet") # will be overwritten after Role _act + rsp = Message(content="No actions taken yet") # will be overwritten after Role _act while actions_taken < self._rc.max_react_loop: # think await self._think() @@ -410,7 +433,7 @@ class Role(BaseModel): logger.debug(f"{self._setting}: {self._rc.state=}, will do {self._rc.todo}") rsp = await self._act() actions_taken += 1 - return rsp # return output from the last action + return rsp # return output from the last action async def _act_by_order(self) -> Message: """switch action each time by order defined in _init_actions, i.e. _act (Action1) -> _act (Action2) -> ...""" @@ -454,7 +477,8 @@ class Role(BaseModel): def get_memories(self, k=0) -> list[Message]: """A wrapper to return the most recent k memories of this role, return all when k=0""" return self._rc.memory.get(k=k) - + + @role_raise_decorator async def run(self, message=None): """Observe, and think and act based on the results of the observation""" if message: diff --git a/metagpt/roles/sales.py b/metagpt/roles/sales.py index a45ad6f1b..dd360d82a 100644 --- a/metagpt/roles/sales.py +++ b/metagpt/roles/sales.py @@ -5,26 +5,34 @@ @Author : alexanderwu @File : sales.py """ + +from typing import Optional +from pydantic import Field + from metagpt.actions import SearchAndSummarize from metagpt.roles import Role from metagpt.tools import SearchEngineType class Sales(Role): + + name: str = Field(default="Xiaomei") + profile: str = Field(default="Retail sales guide") + desc: str = "I am a sales guide in retail. My name is Xiaomei. I will answer some customer questions next, and I " + "will answer questions only based on the information in the knowledge base." + "If I feel that you can't get the answer from the reference material, then I will directly reply that" + " I don't know, and I won't tell you that this is from the knowledge base," + "but pretend to be what I know. Note that each of my replies will be replied in the tone of a " + "professional guide", + + store: Optional[str] = None + def __init__( self, - name="Xiaomei", - profile="Retail sales guide", - desc="I am a sales guide in retail. My name is Xiaomei. I will answer some customer questions next, and I " - "will answer questions only based on the information in the knowledge base." - "If I feel that you can't get the answer from the reference material, then I will directly reply that" - " I don't know, and I won't tell you that this is from the knowledge base," - "but pretend to be what I know. Note that each of my replies will be replied in the tone of a " - "professional guide", - store=None + **kwargs ): - super().__init__(name, profile, desc=desc) - self._set_store(store) + super().__init__(**kwargs) + self._set_store(self.store) def _set_store(self, store): if store: @@ -32,4 +40,3 @@ class Sales(Role): else: action = SearchAndSummarize() self._init_actions([action]) - \ No newline at end of file diff --git a/metagpt/roles/seacher.py b/metagpt/roles/seacher.py index 0b6e089da..e8f291d0d 100644 --- a/metagpt/roles/seacher.py +++ b/metagpt/roles/seacher.py @@ -5,6 +5,9 @@ @Author : alexanderwu @File : seacher.py """ + +from pydantic import Field + from metagpt.actions import ActionOutput, SearchAndSummarize from metagpt.logs import logger from metagpt.roles import Role @@ -23,14 +26,14 @@ class Searcher(Role): constraints (str): Constraints or limitations for the searcher. engine (SearchEngineType): The type of search engine to use. """ + + name: str = Field(default="Alice") + profile: str = Field(default="Smart Assistant") + goal: str = "Provide search services for users" + constraints: str = "Answer is rich and complete" + engine: SearchEngineType = SearchEngineType.SERPAPI_GOOGLE - def __init__(self, - name: str = 'Alice', - profile: str = 'Smart Assistant', - goal: str = 'Provide search services for users', - constraints: str = 'Answer is rich and complete', - engine=SearchEngineType.SERPAPI_GOOGLE, - **kwargs) -> None: + def __init__(self, **kwargs) -> None: """ Initializes the Searcher role with given attributes. @@ -41,8 +44,8 @@ class Searcher(Role): constraints (str): Constraints or limitations for the searcher. engine (SearchEngineType): The type of search engine to use. """ - super().__init__(name, profile, goal, constraints, **kwargs) - self._init_actions([SearchAndSummarize(engine=engine)]) + super().__init__(**kwargs) + self._init_actions([SearchAndSummarize(engine=self.engine)]) def set_search_func(self, search_func): """Sets a custom search function for the searcher.""" diff --git a/metagpt/schema.py b/metagpt/schema.py index 3374a7241..60aa819b0 100644 --- a/metagpt/schema.py +++ b/metagpt/schema.py @@ -5,18 +5,17 @@ @Author : alexanderwu @File : schema.py """ -from __future__ import annotations from dataclasses import dataclass, field -from typing import Type, TypedDict -import copy +from typing import Type, TypedDict, Union, Optional -from pydantic import BaseModel +from pydantic import BaseModel, Field +from pydantic.main import ModelMetaclass from metagpt.logs import logger -# from metagpt.utils.serialize import actionoutout_schema_to_mapping -# from metagpt.actions.action_output import ActionOutput -# from metagpt.actions.action import Action +from metagpt.utils.serialize import actionoutout_schema_to_mapping, actionoutput_mapping_to_str, \ + actionoutput_str_to_mapping +from metagpt.utils.utils import import_class class RawMessage(TypedDict): @@ -24,16 +23,72 @@ class RawMessage(TypedDict): role: str -@dataclass -class Message: - """list[: ]""" - content: str - instruct_content: BaseModel = field(default=None) - role: str = field(default='user') # system / user / assistant - cause_by: Type["Action"] = field(default="") - sent_from: str = field(default="") - send_to: str = field(default="") - restricted_to: str = field(default="") +class Message(BaseModel): + content: str = "" + instruct_content: BaseModel = Field(default=None) + role: str = "user" # system / user / assistant + cause_by: Type["Action"] = Field(default=None) + sent_from: str = "" + send_to: str = "" + restricted_to: str = "" + + def __init__(self, **kwargs): + instruct_content = kwargs.get("instruct_content", None) + cause_by = kwargs.get("cause_by", None) + if instruct_content and not isinstance(instruct_content, BaseModel): + ic = instruct_content + mapping = actionoutput_str_to_mapping(ic["mapping"]) + + actionoutput_class = import_class("ActionOutput", "metagpt.actions.action_output") + ic_obj = actionoutput_class.create_model_class(class_name=ic["class"], mapping=mapping) + ic_new = ic_obj(**ic["value"]) + kwargs["instruct_content"] = ic_new + if cause_by and not isinstance(cause_by, ModelMetaclass): + action_class = import_class("Action", "metagpt.actions.action") + kwargs["cause_by"] = action_class.deser_class(cause_by) + super(Message, self).__init__(**kwargs) + + def dict(self, + *, + include: Optional[Union["AbstractSetIntStr", "MappingIntStrAny"]] = None, + exclude: Optional[Union["AbstractSetIntStr", "MappingIntStrAny"]] = None, + by_alias: bool = False, + skip_defaults: Optional[bool] = None, + exclude_unset: bool = False, + exclude_defaults: bool = False, + exclude_none: bool = False) -> "DictStrAny": + """ overwrite the `dict` to dump dynamic pydantic model""" + obj_dict = super(Message, self).dict(include=include, + exclude=exclude, + by_alias=by_alias, + skip_defaults=skip_defaults, + exclude_unset=exclude_unset, + exclude_defaults=exclude_defaults, + exclude_none=exclude_none) + ic = self.instruct_content # deal custom-defined action + if ic: + schema = ic.schema() + mapping = actionoutout_schema_to_mapping(schema) + mapping = actionoutput_mapping_to_str(mapping) + + obj_dict["instruct_content"] = {"class": schema["title"], "mapping": mapping, "value": ic.dict()} + cb = self.cause_by + if cb: + obj_dict["cause_by"] = cb.ser_class() + return obj_dict + +# +# +# @dataclass +# class Message: +# """list[: ]""" +# content: str +# instruct_content: BaseModel = field(default=None) +# role: str = field(default='user') # system / user / assistant +# cause_by: Type["Action"] = field(default="") +# sent_from: str = field(default="") +# send_to: str = field(default="") +# restricted_to: str = field(default="") def __str__(self): # prefix = '-'.join([self.role, str(self.cause_by)]) @@ -42,45 +97,16 @@ class Message: def __repr__(self): return self.__str__() - # def serialize(self): - # message_cp: Message = copy.deepcopy(self) - # ic = message_cp.instruct_content - # if ic: - # # model create by pydantic create_model like `pydantic.main.prd`, can't pickle.dump directly - # schema = ic.schema() - # mapping = actionoutout_schema_to_mapping(schema) - # - # message_cp.instruct_content = {"class": schema["title"], "mapping": mapping, "value": ic.dict()} - # cb = message_cp.cause_by - # if cb: - # message_cp.cause_by = cb.serialize() - # - # return message_cp.dict() - # - # @classmethod - # def deserialize(cls, message_dict: dict): - # instruct_content = message_dict.get("instruct_content") - # if instruct_content: - # ic = instruct_content - # ic_obj = ActionOutput.create_model_class(class_name=ic["class"], mapping=ic["mapping"]) - # ic_new = ic_obj(**ic["value"]) - # message_dict.instruct_content = ic_new - # cause_by = message_dict.get("cause_by") - # if cause_by: - # message_dict.cause_by = Action.deserialize(cause_by) - # - # return Message(**message_dict) - - def dict(self): - return { - "content": self.content, - "instruct_content": self.instruct_content, - "role": self.role, - "cause_by": self.cause_by, - "sent_from": self.sent_from, - "send_to": self.send_to, - "restricted_to": self.restricted_to - } + # def dict(self): + # return { + # "content": self.content, + # "instruct_content": self.instruct_content, + # "role": self.role, + # "cause_by": self.cause_by, + # "sent_from": self.sent_from, + # "send_to": self.send_to, + # "restricted_to": self.restricted_to + # } def to_dict(self) -> dict: return { diff --git a/metagpt/team.py b/metagpt/team.py index 3b76e5ff4..795019b92 100644 --- a/metagpt/team.py +++ b/metagpt/team.py @@ -15,7 +15,8 @@ from metagpt.logs import logger from metagpt.roles import Role from metagpt.schema import Message from metagpt.utils.common import NoMoneyException -from metagpt.utils.utils import read_json_file, write_json_file +from metagpt.utils.utils import read_json_file, write_json_file, serialize_decorator +from metagpt.const import SERDESER_PATH class Team(BaseModel): @@ -30,29 +31,35 @@ class Team(BaseModel): class Config: arbitrary_types_allowed = True - def serialize(self, stg_path: Path): + def serialize(self, stg_path: Path = None): + stg_path = SERDESER_PATH.joinpath("team") if stg_path is None else stg_path + team_info_path = stg_path.joinpath("team_info.json") - write_json_file(team_info_path, { - "idea": self.idea, - "investment": self.investment - }) + write_json_file(team_info_path, self.dict(exclude={"environment": True})) - self.environment.serialize(stg_path.joinpath("environment")) + self.environment.serialize(stg_path.joinpath("environment")) # save environment alone - def deserialize(self, stg_path: Path): + @classmethod + def recover(cls, stg_path: Path) -> "Team": + return cls.deserialize(stg_path) + + @classmethod + def deserialize(cls, stg_path: Path) -> "Team": """ stg_path = ./storage/team """ # recover team_info team_info_path = stg_path.joinpath("team_info.json") if not team_info_path.exists(): - logger.error("recover storage not exist, not to recover and continue run the old project.") - team_info = read_json_file(team_info_path) - self.investment = team_info.get("investment", 10.0) - self.idea = team_info.get("idea", "") + raise FileNotFoundError("recover storage meta file `team_info.json` not exist, " + "not to recover and please start a new project.") + + team_info: dict = read_json_file(team_info_path) # recover environment - environment_path = stg_path.joinpath("environment") - self.environment = Environment() - self.environment.deserialize(stg_path=environment_path) + environment = Environment.deserialize(stg_path=stg_path.joinpath("environment")) + team_info.update({"environment": environment}) + + team = Team(**team_info) + return team def hire(self, roles: list[Role]): """Hire roles to cooperate""" @@ -76,6 +83,7 @@ class Team(BaseModel): def _save(self): logger.info(self.json()) + @serialize_decorator async def run(self, n_round=3): """Run company until target round or no money""" while n_round > 0: @@ -85,4 +93,3 @@ class Team(BaseModel): self._check_balance() await self.environment.run() return self.environment.history - \ No newline at end of file diff --git a/metagpt/utils/serialize.py b/metagpt/utils/serialize.py index 56a866f2e..9a7049214 100644 --- a/metagpt/utils/serialize.py +++ b/metagpt/utils/serialize.py @@ -5,9 +5,7 @@ import copy import pickle -from metagpt.actions.action_output import ActionOutput -from metagpt.schema import Message -from metagpt.actions.action import Action +from metagpt.utils.utils import import_class def actionoutout_schema_to_mapping(schema: dict) -> dict: @@ -59,7 +57,7 @@ def actionoutput_str_to_mapping(mapping: dict) -> dict: return new_mapping -def serialize_general_message(message: Message) -> dict: +def serialize_general_message(message: "Message") -> dict: """ serialize Message, not to save""" message_cp = copy.deepcopy(message) ic = message_cp.instruct_content @@ -76,7 +74,7 @@ def serialize_general_message(message: Message) -> dict: return message_cp.dict() -def serialize_message(message: Message): +def serialize_message(message: "Message"): message_cp = copy.deepcopy(message) # avoid `instruct_content` value update by reference ic = message_cp.instruct_content if ic: @@ -90,29 +88,35 @@ def serialize_message(message: Message): return msg_ser -def deserialize_general_message(message_dict: dict) -> Message: +def deserialize_general_message(message_dict: dict) -> "Message": """ deserialize Message, not to load""" instruct_content = message_dict.pop("instruct_content") cause_by = message_dict.pop("cause_by") - message = Message(**message_dict) + message_cls = import_class("Message", "metagpt.schema") + message = message_cls(**message_dict) if instruct_content: ic = instruct_content mapping = actionoutput_str_to_mapping(ic["mapping"]) - ic_obj = ActionOutput.create_model_class(class_name=ic["class"], mapping=mapping) + + actionoutput_class = import_class("ActionOutput", "metagpt.actions.action_output") + ic_obj = actionoutput_class.create_model_class(class_name=ic["class"], mapping=mapping) ic_new = ic_obj(**ic["value"]) message.instruct_content = ic_new if cause_by: - message.cause_by = Action.deser_class(cause_by) + action_class = import_class("Action", "metagpt.actions.action") + message.cause_by = action_class.deser_class(cause_by) return message -def deserialize_message(message_ser: str) -> Message: +def deserialize_message(message_ser: str) -> "Message": message = pickle.loads(message_ser) if message.instruct_content: ic = message.instruct_content - ic_obj = ActionOutput.create_model_class(class_name=ic["class"], mapping=ic["mapping"]) + + actionoutput_class = import_class("ActionOutput", "metagpt.actions.action_output") + ic_obj = actionoutput_class.create_model_class(class_name=ic["class"], mapping=ic["mapping"]) ic_new = ic_obj(**ic["value"]) message.instruct_content = ic_new diff --git a/metagpt/utils/utils.py b/metagpt/utils/utils.py index 81ceea884..1cf618ba0 100644 --- a/metagpt/utils/utils.py +++ b/metagpt/utils/utils.py @@ -6,6 +6,9 @@ from typing import Any import json from pathlib import Path import importlib +import traceback + +from metagpt.logs import logger def read_json_file(json_file: str, encoding=None) -> list[Any]: @@ -39,3 +42,43 @@ def import_class_inst(class_name: str, module_name: str, *args, **kwargs) -> obj a_class = import_class(class_name, module_name) class_inst = a_class(*args, **kwargs) return class_inst + + +def format_trackback_info(limit: int = 2): + return traceback.format_exc(limit=limit) + + +def serialize_decorator(func): + async def wrapper(self, *args, **kwargs): + try: + return await func(self, *args, **kwargs) + except KeyboardInterrupt as kbi: + logger.error(f"KeyboardInterrupt occurs, start to serialize the project, exp:\n{format_trackback_info()}") + self.serialize() # Team.serialize + except Exception as exp: + logger.error(f"Exception occurs, start to serialize the project, exp:\n{format_trackback_info()}") + self.serialize() # Team.serialize + + return wrapper + + +def role_raise_decorator(func): + async def wrapper(self, *args, **kwargs): + try: + return await func(self, *args, **kwargs) + except KeyboardInterrupt as kbi: + logger.error(f"KeyboardInterrupt: {kbi} occurs, start to serialize the project") + if self._rc.env: + newest_msgs = self._rc.env.memory.get(1) + if len(newest_msgs) > 0: + self._rc.memory.delete(newest_msgs[0]) + except Exception as exp: + if self._rc.env: + newest_msgs = self._rc.env.memory.get(1) + if len(newest_msgs) > 0: + logger.warning("There is a exception in role's execution, in order to resume, " + "we delete the newest role communication message in the role's memory.") + self._rc.memory.delete(newest_msgs[0]) # remove newest msg of the role to make it observed again + raise Exception(format_trackback_info(limit=None)) # raise again to make it captured outside + + return wrapper diff --git a/startup.py b/startup.py index 9f753d553..c4928a1b5 100644 --- a/startup.py +++ b/startup.py @@ -1,10 +1,11 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- + +from typing import Optional import asyncio - import fire +from pathlib import Path -from metagpt.const import SERDES_PATH from metagpt.roles import ( Architect, Engineer, @@ -22,11 +23,11 @@ async def startup( code_review: bool = False, run_tests: bool = False, implement: bool = True, - recover_path: bool = False, + recover_path: Optional[str] = None, ): """Run a startup. Be a boss.""" - company = Team() if not recover_path: + company = Team() company.hire( [ ProductManager(), @@ -45,8 +46,12 @@ async def startup( # (bug fixing capability comes soon!) company.hire([QaEngineer()]) else: - stg_path = SERDES_PATH.joinpath("team") - company.deserialize(stg_path=stg_path) + # # stg_path = SERDESER_PATH.joinpath("team") + stg_path = Path(recover_path) + if not stg_path.exists() or not str(stg_path).endswith("team"): + raise FileNotFoundError(f"{recover_path} not exists or not endswith `team`") + + company = Team.recover(stg_path=stg_path) idea = company.idea # use original idea company.invest(investment)