From 2abe99cf45ec07bf69c44ec4c374704a798fd4c6 Mon Sep 17 00:00:00 2001 From: better629 Date: Thu, 30 Nov 2023 15:18:24 +0800 Subject: [PATCH] update environment/message to BaseModel, update the ser&deser of roles/actions --- metagpt/actions/action.py | 28 ++++- metagpt/actions/design_api.py | 3 +- metagpt/actions/project_management.py | 1 + metagpt/actions/search_and_summarize.py | 7 +- metagpt/actions/write_code.py | 9 +- metagpt/actions/write_code_review.py | 3 +- metagpt/actions/write_prd.py | 3 +- metagpt/actions/write_test.py | 11 +- metagpt/environment.py | 20 +++- metagpt/memory/longterm_memory.py | 14 ++- metagpt/memory/memory.py | 64 +++++++---- metagpt/roles/customer_service.py | 16 ++- metagpt/roles/product_manager.py | 1 + metagpt/roles/project_manager.py | 2 +- metagpt/roles/qa_engineer.py | 24 +++-- metagpt/roles/role.py | 52 ++++++--- metagpt/roles/sales.py | 33 +++--- metagpt/roles/searcher.py | 23 ++-- metagpt/schema.py | 134 ++++++++++-------------- metagpt/team.py | 38 ++++--- metagpt/utils/serialize.py | 26 +++-- metagpt/utils/utils.py | 40 +++++++ startup.py | 17 +-- 23 files changed, 361 insertions(+), 208 deletions(-) diff --git a/metagpt/actions/action.py b/metagpt/actions/action.py index e890ef76a..499b5e794 100644 --- a/metagpt/actions/action.py +++ b/metagpt/actions/action.py @@ -6,12 +6,17 @@ @File : action.py """ +from __future__ import annotations +import re +from typing import Optional, Any + from typing import Optional, Any from tenacity import retry, stop_after_attempt, wait_random_exponential from pydantic import BaseModel, Field from metagpt.actions.action_output import ActionOutput from metagpt.llm import LLM +from metagpt.provider.base_gpt_api import BaseGPTAPI from metagpt.logs import logger from metagpt.provider.postprecess.llm_output_postprecess import llm_output_postprecess from metagpt.utils.common import OutputParser @@ -24,18 +29,31 @@ action_subclass_registry = {} class Action(BaseModel): name: str = "" - llm: LLM = Field(default_factory=LLM) - context = None + llm: BaseGPTAPI = Field(default_factory=LLM, exclude=True) + context = "" prefix = "" # aask*时会加上prefix,作为system_message profile = "" # FIXME: USELESS desc = "" # for skill manager - nodes = None # content: Optional[str] = None # instruct_content: Optional[str] = None + + # builtin variables + builtin_class_name: str = "" + + class Config: + arbitrary_types_allowed = True def __init__(self, **kwargs: Any): super().__init__(**kwargs) + # deserialize child classes dynamically for inherited `action` + object.__setattr__(self, "builtin_class_name", self.__class__.__name__) + self.__fields__["builtin_class_name"].default = self.__class__.__name__ + + def __init_subclass__(cls, **kwargs: Any) -> None: + super().__init_subclass__(**kwargs) + action_subclass_registry[cls.__name__] = cls + def set_prefix(self, prefix, profile): """Set prefix for later usage""" self.prefix = prefix @@ -56,14 +74,14 @@ class Action(BaseModel): } @classmethod - def deserialize(cls, action_dict: dict): + def deserialize(cls, action_dict: dict) -> "Action": action_class_str = action_dict.pop("action_class") module_name = action_dict.pop("module_name") action_class = import_class(action_class_str, module_name) return action_class(**action_dict) @classmethod - def ser_class(cls): + def ser_class(cls) -> dict: """ serialize class type""" return { "action_class": cls.__name__, diff --git a/metagpt/actions/design_api.py b/metagpt/actions/design_api.py index a10ff1c9a..504328582 100644 --- a/metagpt/actions/design_api.py +++ b/metagpt/actions/design_api.py @@ -17,6 +17,7 @@ from pydantic import Field from metagpt.actions import Action, ActionOutput from metagpt.actions.design_api_an import DESIGN_API_NODE from metagpt.llm import LLM +from metagpt.provider.base_gpt_api import BaseGPTAPI from metagpt.config import CONFIG from metagpt.const import ( DATA_API_DESIGN_FILE_REPO, @@ -43,7 +44,7 @@ NEW_REQ_TEMPLATE = """ class WriteDesign(Action): name: str = "" context: Optional[str] = None - llm: LLM = Field(default_factory=LLM) + llm: BaseGPTAPI = Field(default_factory=LLM) desc: str = "Based on the PRD, think about the system design, and design the corresponding APIs, " "data structures, library tables, processes, and paths. Please provide your design, feedback " "clearly and in detail." diff --git a/metagpt/actions/project_management.py b/metagpt/actions/project_management.py index d830a4c15..98a948b64 100644 --- a/metagpt/actions/project_management.py +++ b/metagpt/actions/project_management.py @@ -18,6 +18,7 @@ from metagpt.actions import ActionOutput from metagpt.actions.action import Action from metagpt.actions.project_management_an import PM_NODE from metagpt.llm import LLM +from metagpt.provider.base_gpt_api import BaseGPTAPI from metagpt.config import CONFIG from metagpt.const import ( PACKAGE_REQUIREMENTS_FILENAME, diff --git a/metagpt/actions/search_and_summarize.py b/metagpt/actions/search_and_summarize.py index 7b549518e..7bff1c113 100644 --- a/metagpt/actions/search_and_summarize.py +++ b/metagpt/actions/search_and_summarize.py @@ -11,7 +11,8 @@ from pydantic import BaseModel, Field from metagpt.actions import Action from metagpt.llm import LLM -from metagpt.config import Config +from metagpt.provider.base_gpt_api import BaseGPTAPI +from metagpt.config import Config, CONFIG from metagpt.logs import logger from metagpt.schema import Message from metagpt.tools.search_engine import SearchEngine @@ -106,9 +107,9 @@ You are a member of a professional butler team and will provide helpful suggesti class SearchAndSummarize(Action): name: str = "" content: Optional[str] = None - llm: None = Field(default_factory=LLM) + llm: BaseGPTAPI = Field(default_factory=LLM) config: None = Field(default_factory=Config) - engine: Optional[str] = None + engine: Optional[str] = CONFIG.search_engine search_func: Optional[str] = None search_engine: SearchEngine = None diff --git a/metagpt/actions/write_code.py b/metagpt/actions/write_code.py index 2d155e6bf..bad9a0890 100644 --- a/metagpt/actions/write_code.py +++ b/metagpt/actions/write_code.py @@ -16,14 +16,9 @@ """ import json - from tenacity import retry, stop_after_attempt, wait_random_exponential - - - from typing import List, Optional, Any from pydantic import Field -from tenacity import retry, stop_after_attempt, wait_fixed from metagpt.actions.action import Action from metagpt.config import CONFIG @@ -34,8 +29,8 @@ from metagpt.const import ( TASK_FILE_REPO, TEST_OUTPUTS_FILE_REPO, ) -from metagpt.actions import WriteDesign from metagpt.llm import LLM +from metagpt.provider.base_gpt_api import BaseGPTAPI from metagpt.logs import logger from metagpt.schema import CodingContext, Document, RunCodeResult from metagpt.utils.common import CodeParser @@ -95,7 +90,7 @@ ATTENTION: Use '##' to SPLIT SECTIONS, not '#'. Output format carefully referenc class WriteCode(Action): name: str = "WriteCode" context: Optional[str] = None - llm: LLM = Field(default_factory=LLM) + llm: BaseGPTAPI = Field(default_factory=LLM) @retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6)) async def write_code(self, prompt) -> str: diff --git a/metagpt/actions/write_code_review.py b/metagpt/actions/write_code_review.py index bf07d0a93..83225060a 100644 --- a/metagpt/actions/write_code_review.py +++ b/metagpt/actions/write_code_review.py @@ -18,6 +18,7 @@ from metagpt.actions.action import Action from metagpt.config import CONFIG from metagpt.logs import logger from metagpt.schema import CodingContext +from metagpt.provider.base_gpt_api import BaseGPTAPI from metagpt.utils.common import CodeParser PROMPT_TEMPLATE = """ @@ -124,7 +125,7 @@ REWRITE_CODE_TEMPLATE = """ class WriteCodeReview(Action): name: str = "WriteCodeReview" context: Optional[str] = None - llm: LLM = Field(default_factory=LLM) + llm: BaseGPTAPI = Field(default_factory=LLM) @retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6)) async def write_code_review_and_rewrite(self, context_prompt, cr_prompt, filename): diff --git a/metagpt/actions/write_prd.py b/metagpt/actions/write_prd.py index 7f9089763..8510733ac 100644 --- a/metagpt/actions/write_prd.py +++ b/metagpt/actions/write_prd.py @@ -27,6 +27,7 @@ from metagpt.actions.write_prd_an import ( WRITE_PRD_NODE, ) from metagpt.llm import LLM +from metagpt.provider.base_gpt_api import BaseGPTAPI from metagpt.actions.search_and_summarize import SearchAndSummarize from metagpt.config import CONFIG from metagpt.const import ( @@ -67,7 +68,7 @@ NEW_REQ_TEMPLATE = """ class WritePRD(Action): name: str = "" content: Optional[str] = None - llm: LLM = Field(default_factory=LLM) + llm: BaseGPTAPI = Field(default_factory=LLM) async def run(self, with_messages, format=CONFIG.prompt_format, *args, **kwargs) -> ActionOutput | Message: # Determine which requirement documents need to be rewritten: Use LLM to assess whether new requirements are diff --git a/metagpt/actions/write_test.py b/metagpt/actions/write_test.py index 9dd967788..fa3931ba6 100644 --- a/metagpt/actions/write_test.py +++ b/metagpt/actions/write_test.py @@ -7,6 +7,12 @@ @Modified By: mashenquan, 2023-11-27. Following the think-act principle, solidify the task parameters when creating the WriteTest object, rather than passing them in when calling the run function. """ + +from typing import Optional +from pydantic import Field + +from metagpt.llm import LLM +from metagpt.provider.base_gpt_api import BaseGPTAPI from metagpt.actions.action import Action from metagpt.config import CONFIG from metagpt.const import TEST_CODES_FILE_REPO @@ -36,8 +42,9 @@ you should correctly import the necessary classes based on these file locations! class WriteTest(Action): - def __init__(self, name="WriteTest", context=None, llm=None): - super().__init__(name, context, llm) + name: str = "WriteTest" + context: Optional[str] = None + llm: BaseGPTAPI = Field(default_factory=LLM) async def write_code(self, prompt): code_rsp = await self._aask(prompt) diff --git a/metagpt/environment.py b/metagpt/environment.py index 19197bd10..242581e17 100644 --- a/metagpt/environment.py +++ b/metagpt/environment.py @@ -19,6 +19,8 @@ from pydantic import BaseModel, Field from metagpt.logs import logger from metagpt.roles import Role +from metagpt.memory import Memory +from metagpt.roles.role import Role, role_subclass_registry from metagpt.schema import Message from metagpt.utils.common import is_subscribed from metagpt.utils.utils import read_json_file, write_json_file @@ -37,6 +39,19 @@ class Environment(BaseModel): class Config: arbitrary_types_allowed = True + def __init__(self, **kwargs): + for role_key, role in kwargs.get("roles", {}).items(): + current_role = kwargs["roles"][role_key] + if isinstance(current_role, dict): + item_class_name = current_role.get("builtin_class_name", None) + for name, subclass in role_subclass_registry.items(): + registery_class_name = subclass.__fields__["builtin_class_name"].default + if item_class_name == registery_class_name: + current_role = subclass(**current_role) + break + kwargs["roles"][role_key] = current_role + super().__init__(**kwargs) + def serialize(self, stg_path: Path): roles_path = stg_path.joinpath("roles.json") roles_info = [] @@ -53,7 +68,8 @@ class Environment(BaseModel): history_path = stg_path.joinpath("history.json") write_json_file(history_path, {"content": self.history}) - def deserialize(self, stg_path: Path): + @classmethod + def deserialize(cls, stg_path: Path) -> "Environment": """ stg_path: ./storage/team/environment/ """ """ stg_path: ./storage/team/environment/ """ roles_path = stg_path.joinpath("roles.json") @@ -80,7 +96,7 @@ class Environment(BaseModel): """ role.set_env(self) # use alias - self.roles[role.role_profile] = role + self.roles[role.profile] = role def add_roles(self, roles: Iterable[Role]): """增加一批在当前环境的角色 diff --git a/metagpt/memory/longterm_memory.py b/metagpt/memory/longterm_memory.py index 22032a86e..e8a5be395 100644 --- a/metagpt/memory/longterm_memory.py +++ b/metagpt/memory/longterm_memory.py @@ -4,6 +4,9 @@ @Desc : the implement of Long-term memory """ +from typing import Optional +from pydantic import Field + from metagpt.logs import logger from metagpt.memory import Memory from metagpt.memory.memory_storage import MemoryStorage @@ -17,11 +20,12 @@ class LongTermMemory(Memory): - update memory when it changed """ - def __init__(self): - self.memory_storage: MemoryStorage = MemoryStorage() - super(LongTermMemory, self).__init__() - self.rc = None # RoleContext - self.msg_from_recover = False + memory_storage: MemoryStorage = Field(default_factory=MemoryStorage) + rc: Optional["RoleContext"] = None + msg_from_recover: bool = False + + class Config: + arbitrary_types_allowed = True def recover_memory(self, role_id: str, rc: "RoleContext"): messages = self.memory_storage.recover_memory(role_id) diff --git a/metagpt/memory/memory.py b/metagpt/memory/memory.py index 43bd33e59..adef0d283 100644 --- a/metagpt/memory/memory.py +++ b/metagpt/memory/memory.py @@ -6,34 +6,51 @@ @File : memory.py @Modified By: mashenquan, 2023-11-1. According to RFC 116: Updated the type of index key. """ +import copy from collections import defaultdict -from typing import Iterable, Set +from typing import Iterable, Type, Union, Optional, Set from pathlib import Path +from pydantic import BaseModel, Field +import json from metagpt.schema import Message from metagpt.utils.common import any_to_str, any_to_str_set from metagpt.utils.utils import read_json_file, write_json_file -from metagpt.utils.serialize import serialize_general_message, deserialize_general_message +from metagpt.utils.utils import import_class -class Memory: +class Memory(BaseModel): """The most basic memory: super-memory""" - def __init__(self): - """Initialize an empty storage list and an empty index dictionary""" - self.storage: list[Message] = [] - self.index: dict[str, list[Message]] = defaultdict(list) + storage: list[Message] = Field(default=[]) + index: dict[str, list[Message]] = Field(default_factory=defaultdict(list)) + + def __init__(self, **kwargs): + index = kwargs.get("index", {}) + new_index = defaultdict(list) + for action_str, value in index.items(): + action_dict = json.loads(action_str) + action_class = import_class("Action", "metagpt.actions.action") + action_obj = action_class.deser_class(action_dict) + new_index[action_obj] = [Message(**item_dict) for item_dict in value] + kwargs["index"] = new_index + super(Memory, self).__init__(**kwargs) + self.index = new_index + + def dict(self, *args, **kwargs) -> "DictStrAny": + """ overwrite the `dict` to dump dynamic pydantic model""" + obj_dict = super(Memory, self).dict(*args, **kwargs) + new_obj_dict = copy.deepcopy(obj_dict) + new_obj_dict["index"] = {} + for action, value in obj_dict["index"].items(): + action_ser = json.dumps(action.ser_class()) + new_obj_dict["index"][action_ser] = value + return new_obj_dict def serialize(self, stg_path: Path): """ stg_path = ./storage/team/environment/ or ./storage/team/environment/roles/{role_class}_{role_name}/ """ memory_path = stg_path.joinpath("memory.json") - - storage = [] - for message in self.storage: - # msg_dict = message.serialize() - msg_dict = serialize_general_message(message) - storage.append(msg_dict) - + storage = self.dict() write_json_file(memory_path, storage) @classmethod @@ -41,13 +58,8 @@ class Memory: """ stg_path = ./storage/team/environment/ or ./storage/team/environment/roles/{role_class}_{role_name}/""" memory_path = stg_path.joinpath("memory.json") - memory = Memory() - memory_list = read_json_file(memory_path) - for message in memory_list: - # distinguish instruct_content type in message - # msg = Message.deserialize(message) - msg = deserialize_general_message(message) - memory.add(msg) + memory_dict = read_json_file(memory_path) + memory = Memory(**memory_dict) return memory @@ -71,6 +83,16 @@ class Memory: """Return all messages containing a specified content""" return [message for message in self.storage if content in message.content] + def delete_newest(self) -> "Message": + """ delete the newest message from the storage""" + if len(self.storage) > 0: + newest_msg = self.storage.pop() + if newest_msg.cause_by and newest_msg in self.index[newest_msg.cause_by]: + self.index[newest_msg.cause_by].remove(newest_msg) + else: + newest_msg = None + return newest_msg + def delete(self, message: Message): """Delete the specified message from storage, while updating the index""" self.storage.remove(message) diff --git a/metagpt/roles/customer_service.py b/metagpt/roles/customer_service.py index 188182d47..62792696f 100644 --- a/metagpt/roles/customer_service.py +++ b/metagpt/roles/customer_service.py @@ -5,6 +5,9 @@ @Author : alexanderwu @File : sales.py """ +from typing import Optional +from pydantic import Field + from metagpt.roles import Sales # from metagpt.actions import SearchAndSummarize @@ -24,5 +27,14 @@ DESC = """ class CustomerService(Sales): - def __init__(self, name="Xiaomei", profile="Human customer service", desc=DESC, store=None): - super().__init__(name, profile, desc=desc, store=store) + + name: str = Field(default="Xiaomei") + profile: str = Field(default="Human customer service") + desc: str = DESC, + + store: Optional[str] = None + + def __init__( + self, + **kwargs): + super().__init__(**kwargs) diff --git a/metagpt/roles/product_manager.py b/metagpt/roles/product_manager.py index a49459fca..30017b60d 100644 --- a/metagpt/roles/product_manager.py +++ b/metagpt/roles/product_manager.py @@ -29,6 +29,7 @@ class ProductManager(Role): role_profile: str = Field(default="Product Manager", alias='profile') goal: str = "efficiently create a successful product" constraints: str = "use same language as user requiremen" + """ Represents a Product Manager role responsible for product development and management. """ diff --git a/metagpt/roles/project_manager.py b/metagpt/roles/project_manager.py index 211e41d3b..b7ee1ed53 100644 --- a/metagpt/roles/project_manager.py +++ b/metagpt/roles/project_manager.py @@ -22,7 +22,7 @@ class ProjectManager(Role): goal (str): Goal of the project manager. constraints (str): Constraints or limitations for the project manager. """ - name: str = "Eve" + name: str = Field(default="Eve") profile: str = Field(default="Project Manager") goal: str = "reak down tasks according to PRD/technical design, generate a task list, and analyze task " \ diff --git a/metagpt/roles/qa_engineer.py b/metagpt/roles/qa_engineer.py index 4439b9b19..ec404570c 100644 --- a/metagpt/roles/qa_engineer.py +++ b/metagpt/roles/qa_engineer.py @@ -14,7 +14,9 @@ @Modified By: mashenquan, 2023-12-5. Enhance the workflow to navigate to WriteCode or QaEngineer based on the results of SummarizeCode. """ -from metagpt.actions import DebugError, RunCode, WriteTest + +from pydantic import Field + from metagpt.actions.summarize_code import SummarizeCode from metagpt.config import CONFIG from metagpt.const import ( @@ -22,6 +24,11 @@ from metagpt.const import ( TEST_CODES_FILE_REPO, TEST_OUTPUTS_FILE_REPO, ) +from metagpt.actions import ( + DebugError, + RunCode, + WriteTest, +) from metagpt.logs import logger from metagpt.roles import Role from metagpt.schema import Document, Message, RunCodeContext, TestingContext @@ -30,21 +37,22 @@ from metagpt.utils.file_repository import FileRepository class QaEngineer(Role): + name: str = Field(default="Edward") + profile: str = Field(default="QaEngineer") + goal: str = "Write comprehensive and robust tests to ensure codes will work as expected without bugs" + constraints: str = "The test code you write should conform to code standard like PEP8, be modular, easy to read and maintain" + test_round_allowed: int = 5 + def __init__( self, - name="Edward", - profile="QaEngineer", - goal="Write comprehensive and robust tests to ensure codes will work as expected without bugs", - constraints="The test code you write should conform to code standard like PEP8, be modular, easy to read and maintain", - test_round_allowed=5, + **kwargs ): - super().__init__(name, profile, goal, constraints) + super().__init__(**kwargs) self._init_actions( [WriteTest] ) # FIXME: a bit hack here, only init one action to circumvent _think() logic, will overwrite _think() in future updates self._watch([SummarizeCode, WriteTest, RunCode, DebugError]) self.test_round = 0 - self.test_round_allowed = test_round_allowed async def _write_test(self, message: Message) -> None: src_file_repo = CONFIG.git_repo.new_file_repository(CONFIG.src_workspace) diff --git a/metagpt/roles/role.py b/metagpt/roles/role.py index f1d7df5e7..114e9e599 100644 --- a/metagpt/roles/role.py +++ b/metagpt/roles/role.py @@ -28,15 +28,32 @@ from pydantic import BaseModel, Field from metagpt.actions.action import Action, ActionOutput, action_subclass_registry from metagpt.actions.action_node import ActionNode from metagpt.actions.add_requirement import UserRequirement + +from pathlib import Path + +from typing import ( + Iterable, + Type, + Any +) +from pydantic import BaseModel, Field, validator + +# from metagpt.environment import Environment +from metagpt.config import CONFIG +from metagpt.actions.action import Action, ActionOutput, action_subclass_registry from metagpt.llm import LLM +from metagpt.provider.base_gpt_api import BaseGPTAPI from metagpt.logs import logger from metagpt.schema import Message, MessageQueue from metagpt.utils.common import any_to_str from metagpt.utils.repair_llm_raw_output import extract_state_value_from_output from metagpt.memory import Memory from metagpt.provider.human_provider import HumanProvider + from metagpt.utils.utils import read_json_file, write_json_file, import_class from metagpt.provider.base_gpt_api import BaseGPTAPI + +from metagpt.utils.utils import read_json_file, write_json_file, import_class, role_raise_decorator from metagpt.const import SERDESER_PATH @@ -80,13 +97,12 @@ class RoleReactMode(str, Enum): class RoleSetting(BaseModel): """Role Settings""" - - name: str - profile: str - goal: str - constraints: str - desc: str - is_human: bool + name: str = "" + profile: str = "" + goal: str = "" + constraints: str = "" + desc: str = "" + is_human: bool = False def __str__(self): return f"{self.name}({self.profile})" @@ -174,8 +190,8 @@ class Role(BaseModel): class Config: arbitrary_types_allowed = True exclude = ["_llm"] - - def __init__(self, **kwargs): + + def __init__(self, **kwargs: Any): for index in range(len(kwargs.get("_actions", []))): current_action = kwargs["_actions"][index] if isinstance(current_action, dict): @@ -212,15 +228,19 @@ class Role(BaseModel): object.__setattr__(self, "builtin_class_name", self.__class__.__name__) self.__fields__["builtin_class_name"].default = self.__class__.__name__ + def __init_subclass__(cls, **kwargs: Any) -> None: + super().__init_subclass__(**kwargs) + role_subclass_registry[cls.__name__] = cls + def _reset(self): - object.__setattr__(self, '_states', []) - object.__setattr__(self, '_actions', []) + object.__setattr__(self, "_states", []) + object.__setattr__(self, "_actions", []) @property def _setting(self): return f"{self.name}({self.profile})" - def serialize(self, stg_path: Path): + def serialize(self, stg_path: Path = None): stg_path = SERDESER_PATH.joinpath(f"team/environment/roles/{self.__class__.__name__}_{self.name}") \ if stg_path is None else stg_path @@ -256,7 +276,7 @@ class Role(BaseModel): action.set_prefix(self._get_prefix(), self.profile) def set_recovered(self, recovered: bool = False): - self._recovered = recovered + self.recovered = recovered def set_memory(self, memory: Memory): self._rc.memory = memory @@ -269,7 +289,7 @@ class Role(BaseModel): for idx, action in enumerate(actions): if not isinstance(action, Action): ## 默认初始化 - i = action() + i = action(name="", llm=self._llm) else: if self._setting.is_human and not isinstance(action.llm, HumanProvider): logger.warning( @@ -358,6 +378,10 @@ class Role(BaseModel): def subscription(self) -> Set: """The labels for messages to be consumed by the Role object.""" return self._subscription + + def set_env(self, env: "Environment"): + """Set the environment in which the role works. The role can talk to the environment and can also receive messages by observing.""" + self._rc.env = env def _get_prefix(self): """Get the role prefix""" diff --git a/metagpt/roles/sales.py b/metagpt/roles/sales.py index d5aac1824..826413dc8 100644 --- a/metagpt/roles/sales.py +++ b/metagpt/roles/sales.py @@ -5,26 +5,31 @@ @Author : alexanderwu @File : sales.py """ + +from typing import Optional +from pydantic import Field + from metagpt.actions import SearchAndSummarize from metagpt.roles import Role from metagpt.tools import SearchEngineType class Sales(Role): - def __init__( - self, - name="Xiaomei", - profile="Retail sales guide", - desc="I am a sales guide in retail. My name is Xiaomei. I will answer some customer questions next, and I " - "will answer questions only based on the information in the knowledge base." - "If I feel that you can't get the answer from the reference material, then I will directly reply that" - " I don't know, and I won't tell you that this is from the knowledge base," - "but pretend to be what I know. Note that each of my replies will be replied in the tone of a " - "professional guide", - store=None, - ): - super().__init__(name, profile, desc=desc) - self._set_store(store) + + name: str = Field(default="Xiaomei") + profile: str = Field(default="Retail sales guide") + desc: str = "I am a sales guide in retail. My name is Xiaomei. I will answer some customer questions next, and I " + "will answer questions only based on the information in the knowledge base." + "If I feel that you can't get the answer from the reference material, then I will directly reply that" + " I don't know, and I won't tell you that this is from the knowledge base," + "but pretend to be what I know. Note that each of my replies will be replied in the tone of a " + "professional guide", + + store: Optional[str] = None + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self._set_store(self.store) def _set_store(self, store): if store: diff --git a/metagpt/roles/searcher.py b/metagpt/roles/searcher.py index 5760202ff..7d58ad922 100644 --- a/metagpt/roles/searcher.py +++ b/metagpt/roles/searcher.py @@ -7,6 +7,9 @@ @Modified By: mashenquan, 2023-11-1. According to Chapter 2.2.1 and 2.2.2 of RFC 116, change the data type of the `cause_by` value in the `Message` to a string to support the new message distribution feature. """ + +from pydantic import Field + from metagpt.actions import ActionOutput, SearchAndSummarize from metagpt.actions.action_node import ActionNode from metagpt.logs import logger @@ -27,15 +30,13 @@ class Searcher(Role): engine (SearchEngineType): The type of search engine to use. """ - def __init__( - self, - name: str = "Alice", - profile: str = "Smart Assistant", - goal: str = "Provide search services for users", - constraints: str = "Answer is rich and complete", - engine=SearchEngineType.SERPAPI_GOOGLE, - **kwargs, - ) -> None: + name: str = Field(default="Alice") + profile: str = Field(default="Smart Assistant") + goal: str = "Provide search services for users" + constraints: str = "Answer is rich and complete" + engine: SearchEngineType = SearchEngineType.SERPAPI_GOOGLE + + def __init__(self, **kwargs) -> None: """ Initializes the Searcher role with given attributes. @@ -46,8 +47,8 @@ class Searcher(Role): constraints (str): Constraints or limitations for the searcher. engine (SearchEngineType): The type of search engine to use. """ - super().__init__(name, profile, goal, constraints, **kwargs) - self._init_actions([SearchAndSummarize(engine=engine)]) + super().__init__(**kwargs) + self._init_actions([SearchAndSummarize(engine=self.engine)]) def set_search_func(self, search_func): """Sets a custom search function for the searcher.""" diff --git a/metagpt/schema.py b/metagpt/schema.py index 78e4a6031..a872481bb 100644 --- a/metagpt/schema.py +++ b/metagpt/schema.py @@ -12,7 +12,6 @@ between actions. 3. Add `id` to `Message` according to Section 2.2.3.1.1 of RFC 135. """ -from __future__ import annotations import asyncio import json @@ -24,6 +23,12 @@ from pathlib import Path from typing import Dict, List, Optional, Set, TypedDict from pydantic import BaseModel, Field +from dataclasses import dataclass, field +from typing import Type, TypedDict, Union, Optional + +from pydantic import BaseModel, Field +from pydantic.main import ModelMetaclass + from metagpt.config import CONFIG from metagpt.const import ( MESSAGE_ROUTE_CAUSE_BY, @@ -34,11 +39,16 @@ from metagpt.const import ( TASK_FILE_REPO, ) from metagpt.logs import logger + from metagpt.utils.common import any_to_str, any_to_str_set # from metagpt.utils.serialize import actionoutout_schema_to_mapping # from metagpt.actions.action_output import ActionOutput # from metagpt.actions.action import Action +from metagpt.utils.serialize import actionoutout_schema_to_mapping, actionoutput_mapping_to_str, \ + actionoutput_str_to_mapping +from metagpt.utils.utils import import_class + class RawMessage(TypedDict): content: str @@ -54,7 +64,7 @@ class Document(BaseModel): filename: str = "" content: str = "" - def get_meta(self) -> Document: + def get_meta(self) -> "Document"": """Get metadata of the document. :return: A new Document instance with the same root path and filename. @@ -104,39 +114,21 @@ class Message(BaseModel): sent_from: str = "" send_to: Set = Field(default_factory={MESSAGE_ROUTE_TO_ALL}) - def __init__( - self, - content, - instruct_content=None, - role="user", - cause_by="", - sent_from="", - send_to=MESSAGE_ROUTE_TO_ALL, - **kwargs, - ): - """ - Parameters not listed below will be stored as meta info, including custom parameters. - :param content: Message content. - :param instruct_content: Message content struct. - :param cause_by: Message producer - :param sent_from: Message route info tells who sent this message. - :param send_to: Specifies the target recipient or consumer for message delivery in the environment. - :param role: Message meta info tells who sent this message. - """ - if not cause_by: - from metagpt.actions import UserRequirement - cause_by = UserRequirement + def __init__(self, **kwargs): + instruct_content = kwargs.get("instruct_content", None) + cause_by = kwargs.get("cause_by", None) + if instruct_content and not isinstance(instruct_content, BaseModel): + ic = instruct_content + mapping = actionoutput_str_to_mapping(ic["mapping"]) - super().__init__( - id=uuid.uuid4().hex, - content=content, - instruct_content=instruct_content, - role=role, - cause_by=any_to_str(cause_by), - sent_from=any_to_str(sent_from), - send_to=any_to_str_set(send_to), - **kwargs, - ) + actionoutput_class = import_class("ActionOutput", "metagpt.actions.action_output") + ic_obj = actionoutput_class.create_model_class(class_name=ic["class"], mapping=mapping) + ic_new = ic_obj(**ic["value"]) + kwargs["instruct_content"] = ic_new + if cause_by and not isinstance(cause_by, ModelMetaclass): + action_class = import_class("Action", "metagpt.actions.action") + kwargs["cause_by"] = action_class.deser_class(cause_by) + super(Message, self).__init__(**kwargs) def __setattr__(self, key, val): """Override `@property.setter`, convert non-string parameters into string parameters.""" @@ -150,6 +142,21 @@ class Message(BaseModel): new_val = val super().__setattr__(key, new_val) + def dict(self, *args, **kwargs) -> "DictStrAny": + """ overwrite the `dict` to dump dynamic pydantic model""" + obj_dict = super(Message, self).dict(*args, **kwargs) + ic = self.instruct_content # deal custom-defined action + if ic: + schema = ic.schema() + mapping = actionoutout_schema_to_mapping(schema) + mapping = actionoutput_mapping_to_str(mapping) + + obj_dict["instruct_content"] = {"class": schema["title"], "mapping": mapping, "value": ic.dict()} + cb = self.cause_by + if cb: + obj_dict["cause_by"] = cb.ser_class() + return obj_dict + def __str__(self): # prefix = '-'.join([self.role, str(self.cause_by)]) return f"{self.role}: {self.content}" @@ -157,45 +164,16 @@ class Message(BaseModel): def __repr__(self): return self.__str__() - # def serialize(self): - # message_cp: Message = copy.deepcopy(self) - # ic = message_cp.instruct_content - # if ic: - # # model create by pydantic create_model like `pydantic.main.prd`, can't pickle.dump directly - # schema = ic.schema() - # mapping = actionoutout_schema_to_mapping(schema) - # - # message_cp.instruct_content = {"class": schema["title"], "mapping": mapping, "value": ic.dict()} - # cb = message_cp.cause_by - # if cb: - # message_cp.cause_by = cb.serialize() - # - # return message_cp.dict() - # - # @classmethod - # def deserialize(cls, message_dict: dict): - # instruct_content = message_dict.get("instruct_content") - # if instruct_content: - # ic = instruct_content - # ic_obj = ActionOutput.create_model_class(class_name=ic["class"], mapping=ic["mapping"]) - # ic_new = ic_obj(**ic["value"]) - # message_dict.instruct_content = ic_new - # cause_by = message_dict.get("cause_by") - # if cause_by: - # message_dict.cause_by = Action.deserialize(cause_by) - # - # return Message(**message_dict) - - def dict(self): - return { - "content": self.content, - "instruct_content": self.instruct_content, - "role": self.role, - "cause_by": self.cause_by, - "sent_from": self.sent_from, - "send_to": self.send_to, - "restricted_to": self.restricted_to - } + # def dict(self): + # return { + # "content": self.content, + # "instruct_content": self.instruct_content, + # "role": self.role, + # "cause_by": self.cause_by, + # "sent_from": self.sent_from, + # "send_to": self.send_to, + # "restricted_to": self.restricted_to + # } def to_dict(self) -> dict: """Return a dict containing `role` and `content` for the LLM call.l""" @@ -316,7 +294,7 @@ class CodingContext(BaseModel): code_doc: Optional[Document] @staticmethod - def loads(val: str) -> CodingContext | None: + def loads(val: str) -> "CodingContext" | None: try: m = json.loads(val) return CodingContext(**m) @@ -330,7 +308,7 @@ class TestingContext(BaseModel): test_doc: Optional[Document] @staticmethod - def loads(val: str) -> TestingContext | None: + def loads(val: str) -> "TestingContext" | None: try: m = json.loads(val) return TestingContext(**m) @@ -351,7 +329,7 @@ class RunCodeContext(BaseModel): output: Optional[str] @staticmethod - def loads(val: str) -> RunCodeContext | None: + def loads(val: str) -> "RunCodeContext" | None: try: m = json.loads(val) return RunCodeContext(**m) @@ -365,7 +343,7 @@ class RunCodeResult(BaseModel): stderr: str @staticmethod - def loads(val: str) -> RunCodeResult | None: + def loads(val: str) -> "RunCodeResult" | None: try: m = json.loads(val) return RunCodeResult(**m) @@ -380,7 +358,7 @@ class CodeSummarizeContext(BaseModel): reason: str = "" @staticmethod - def loads(filenames: List) -> CodeSummarizeContext: + def loads(filenames: List) -> "CodeSummarizeContext": ctx = CodeSummarizeContext() for filename in filenames: if Path(filename).is_relative_to(SYSTEM_DESIGN_FILE_REPO): diff --git a/metagpt/team.py b/metagpt/team.py index 02c48a138..87a6766f6 100644 --- a/metagpt/team.py +++ b/metagpt/team.py @@ -18,7 +18,8 @@ from metagpt.logs import logger from metagpt.roles import Role from metagpt.schema import Message from metagpt.utils.common import NoMoneyException -from metagpt.utils.utils import read_json_file, write_json_file +from metagpt.utils.utils import read_json_file, write_json_file, serialize_decorator +from metagpt.const import SERDESER_PATH class Team(BaseModel): @@ -34,29 +35,35 @@ class Team(BaseModel): class Config: arbitrary_types_allowed = True - def serialize(self, stg_path: Path): + def serialize(self, stg_path: Path = None): + stg_path = SERDESER_PATH.joinpath("team") if stg_path is None else stg_path + team_info_path = stg_path.joinpath("team_info.json") - write_json_file(team_info_path, { - "idea": self.idea, - "investment": self.investment - }) + write_json_file(team_info_path, self.dict(exclude={"environment": True})) - self.environment.serialize(stg_path.joinpath("environment")) + self.environment.serialize(stg_path.joinpath("environment")) # save environment alone - def deserialize(self, stg_path: Path): + @classmethod + def recover(cls, stg_path: Path) -> "Team": + return cls.deserialize(stg_path) + + @classmethod + def deserialize(cls, stg_path: Path) -> "Team": """ stg_path = ./storage/team """ # recover team_info team_info_path = stg_path.joinpath("team_info.json") if not team_info_path.exists(): - logger.error("recover storage not exist, not to recover and continue run the old project.") - team_info = read_json_file(team_info_path) - self.investment = team_info.get("investment", 10.0) - self.idea = team_info.get("idea", "") + raise FileNotFoundError("recover storage meta file `team_info.json` not exist, " + "not to recover and please start a new project.") + + team_info: dict = read_json_file(team_info_path) # recover environment - environment_path = stg_path.joinpath("environment") - self.environment = Environment() - self.environment.deserialize(stg_path=environment_path) + environment = Environment.deserialize(stg_path=stg_path.joinpath("environment")) + team_info.update({"environment": environment}) + + team = Team(**team_info) + return team def hire(self, roles: list[Role]): """Hire roles to cooperate""" @@ -84,6 +91,7 @@ class Team(BaseModel): def _save(self): logger.info(self.json(ensure_ascii=False)) + @serialize_decorator async def run(self, n_round=3): """Run company until target round or no money""" while n_round > 0: diff --git a/metagpt/utils/serialize.py b/metagpt/utils/serialize.py index 56a866f2e..9a7049214 100644 --- a/metagpt/utils/serialize.py +++ b/metagpt/utils/serialize.py @@ -5,9 +5,7 @@ import copy import pickle -from metagpt.actions.action_output import ActionOutput -from metagpt.schema import Message -from metagpt.actions.action import Action +from metagpt.utils.utils import import_class def actionoutout_schema_to_mapping(schema: dict) -> dict: @@ -59,7 +57,7 @@ def actionoutput_str_to_mapping(mapping: dict) -> dict: return new_mapping -def serialize_general_message(message: Message) -> dict: +def serialize_general_message(message: "Message") -> dict: """ serialize Message, not to save""" message_cp = copy.deepcopy(message) ic = message_cp.instruct_content @@ -76,7 +74,7 @@ def serialize_general_message(message: Message) -> dict: return message_cp.dict() -def serialize_message(message: Message): +def serialize_message(message: "Message"): message_cp = copy.deepcopy(message) # avoid `instruct_content` value update by reference ic = message_cp.instruct_content if ic: @@ -90,29 +88,35 @@ def serialize_message(message: Message): return msg_ser -def deserialize_general_message(message_dict: dict) -> Message: +def deserialize_general_message(message_dict: dict) -> "Message": """ deserialize Message, not to load""" instruct_content = message_dict.pop("instruct_content") cause_by = message_dict.pop("cause_by") - message = Message(**message_dict) + message_cls = import_class("Message", "metagpt.schema") + message = message_cls(**message_dict) if instruct_content: ic = instruct_content mapping = actionoutput_str_to_mapping(ic["mapping"]) - ic_obj = ActionOutput.create_model_class(class_name=ic["class"], mapping=mapping) + + actionoutput_class = import_class("ActionOutput", "metagpt.actions.action_output") + ic_obj = actionoutput_class.create_model_class(class_name=ic["class"], mapping=mapping) ic_new = ic_obj(**ic["value"]) message.instruct_content = ic_new if cause_by: - message.cause_by = Action.deser_class(cause_by) + action_class = import_class("Action", "metagpt.actions.action") + message.cause_by = action_class.deser_class(cause_by) return message -def deserialize_message(message_ser: str) -> Message: +def deserialize_message(message_ser: str) -> "Message": message = pickle.loads(message_ser) if message.instruct_content: ic = message.instruct_content - ic_obj = ActionOutput.create_model_class(class_name=ic["class"], mapping=ic["mapping"]) + + actionoutput_class = import_class("ActionOutput", "metagpt.actions.action_output") + ic_obj = actionoutput_class.create_model_class(class_name=ic["class"], mapping=ic["mapping"]) ic_new = ic_obj(**ic["value"]) message.instruct_content = ic_new diff --git a/metagpt/utils/utils.py b/metagpt/utils/utils.py index 220e228c3..ad5c7626a 100644 --- a/metagpt/utils/utils.py +++ b/metagpt/utils/utils.py @@ -56,3 +56,43 @@ def import_class_inst(class_name: str, module_name: str, *args, **kwargs) -> obj a_class = import_class(class_name, module_name) class_inst = a_class(*args, **kwargs) return class_inst + + +def format_trackback_info(limit: int = 2): + return traceback.format_exc(limit=limit) + + +def serialize_decorator(func): + async def wrapper(self, *args, **kwargs): + try: + return await func(self, *args, **kwargs) + except KeyboardInterrupt as kbi: + logger.error(f"KeyboardInterrupt occurs, start to serialize the project, exp:\n{format_trackback_info()}") + self.serialize() # Team.serialize + except Exception as exp: + logger.error(f"Exception occurs, start to serialize the project, exp:\n{format_trackback_info()}") + self.serialize() # Team.serialize + + return wrapper + + +def role_raise_decorator(func): + async def wrapper(self, *args, **kwargs): + try: + return await func(self, *args, **kwargs) + except KeyboardInterrupt as kbi: + logger.error(f"KeyboardInterrupt: {kbi} occurs, start to serialize the project") + if self._rc.env: + newest_msgs = self._rc.env.memory.get(1) + if len(newest_msgs) > 0: + self._rc.memory.delete(newest_msgs[0]) + except Exception as exp: + if self._rc.env: + newest_msgs = self._rc.env.memory.get(1) + if len(newest_msgs) > 0: + logger.warning("There is a exception in role's execution, in order to resume, " + "we delete the newest role communication message in the role's memory.") + self._rc.memory.delete(newest_msgs[0]) # remove newest msg of the role to make it observed again + raise Exception(format_trackback_info(limit=None)) # raise again to make it captured outside + + return wrapper diff --git a/startup.py b/startup.py index 9f753d553..c4928a1b5 100644 --- a/startup.py +++ b/startup.py @@ -1,10 +1,11 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- + +from typing import Optional import asyncio - import fire +from pathlib import Path -from metagpt.const import SERDES_PATH from metagpt.roles import ( Architect, Engineer, @@ -22,11 +23,11 @@ async def startup( code_review: bool = False, run_tests: bool = False, implement: bool = True, - recover_path: bool = False, + recover_path: Optional[str] = None, ): """Run a startup. Be a boss.""" - company = Team() if not recover_path: + company = Team() company.hire( [ ProductManager(), @@ -45,8 +46,12 @@ async def startup( # (bug fixing capability comes soon!) company.hire([QaEngineer()]) else: - stg_path = SERDES_PATH.joinpath("team") - company.deserialize(stg_path=stg_path) + # # stg_path = SERDESER_PATH.joinpath("team") + stg_path = Path(recover_path) + if not stg_path.exists() or not str(stg_path).endswith("team"): + raise FileNotFoundError(f"{recover_path} not exists or not endswith `team`") + + company = Team.recover(stg_path=stg_path) idea = company.idea # use original idea company.invest(investment)