From 4702059caf3c76b05d2a6c7c119a56fbd03a8db9 Mon Sep 17 00:00:00 2001 From: stellahsr Date: Mon, 27 Nov 2023 21:12:50 +0800 Subject: [PATCH 01/14] update basic code for serialize --- metagpt/actions/action.py | 61 +++--- metagpt/actions/design_api.py | 30 ++- metagpt/actions/project_management.py | 27 ++- metagpt/actions/search_and_summarize.py | 52 +++-- metagpt/actions/write_code.py | 15 +- metagpt/actions/write_code_review.py | 13 +- metagpt/actions/write_prd.py | 32 ++- metagpt/environment.py | 5 +- metagpt/roles/architect.py | 18 +- metagpt/roles/engineer.py | 76 +++---- metagpt/roles/product_manager.py | 36 ++-- metagpt/roles/project_manager.py | 27 +-- metagpt/roles/role.py | 271 +++++++++++------------- 13 files changed, 342 insertions(+), 321 deletions(-) diff --git a/metagpt/actions/action.py b/metagpt/actions/action.py index 790295d55..7bb5a151b 100644 --- a/metagpt/actions/action.py +++ b/metagpt/actions/action.py @@ -7,8 +7,9 @@ """ import re from abc import ABC -from typing import Optional +from typing import Optional, Any +from pydantic import BaseModel, Field from tenacity import retry, stop_after_attempt, wait_fixed from metagpt.actions.action_output import ActionOutput @@ -18,45 +19,45 @@ from metagpt.utils.common import OutputParser from metagpt.utils.custom_decoder import CustomDecoder -class Action(ABC): - def __init__(self, name: str = "", context=None, llm: LLM = None): - self.name: str = name - if llm is None: - llm = LLM() - self.llm = llm - self.context = context - self.prefix = "" - self.profile = "" - self.desc = "" - self.content = "" - self.instruct_content = None - +class Action(BaseModel): + name: str = "" + llm: LLM = Field(default_factory=LLM) + context = "" + prefix = "" + profile = "" + desc = "" + content: Optional[str] = None + instruct_content: Optional[str] = None + + def __init__(self, **kwargs: Any): + super().__init__(**kwargs) + def set_prefix(self, prefix, profile): """Set prefix for later usage""" self.prefix = prefix self.profile = profile - + def __str__(self): return self.__class__.__name__ - + def __repr__(self): return self.__str__() - + async def _aask(self, prompt: str, system_msgs: Optional[list[str]] = None) -> str: """Append default prefix""" if not system_msgs: system_msgs = [] system_msgs.append(self.prefix) return await self.llm.aask(prompt, system_msgs) - + @retry(stop=stop_after_attempt(3), wait=wait_fixed(1)) async def _aask_v1( - self, - prompt: str, - output_class_name: str, - output_data_mapping: dict, - system_msgs: Optional[list[str]] = None, - format="markdown", # compatible to original format + self, + prompt: str, + output_class_name: str, + output_data_mapping: dict, + system_msgs: Optional[list[str]] = None, + format="markdown", # compatible to original format ) -> ActionOutput: """Append default prefix""" if not system_msgs: @@ -65,25 +66,25 @@ class Action(ABC): content = await self.llm.aask(prompt, system_msgs) logger.debug(content) output_class = ActionOutput.create_model_class(output_class_name, output_data_mapping) - + if format == "json": pattern = r"\[CONTENT\](\s*\{.*?\}\s*)\[/CONTENT\]" matches = re.findall(pattern, content, re.DOTALL) - + for match in matches: if match: content = match break - + parsed_data = CustomDecoder(strict=False).decode(content) - + else: # using markdown parser parsed_data = OutputParser.parse_data_with_mapping(content, output_data_mapping) - + logger.debug(parsed_data) instruct_content = output_class(**parsed_data) return ActionOutput(content, instruct_content) - + async def run(self, *args, **kwargs): """Run action""" raise NotImplementedError("The run method should be implemented in a subclass.") diff --git a/metagpt/actions/design_api.py b/metagpt/actions/design_api.py index 75df8b909..30df70ce7 100644 --- a/metagpt/actions/design_api.py +++ b/metagpt/actions/design_api.py @@ -7,9 +7,12 @@ """ import shutil from pathlib import Path -from typing import List +from typing import List, Optional, Any + +from pydantic import Field from metagpt.actions import Action, ActionOutput +from metagpt.llm import LLM from metagpt.config import CONFIG from metagpt.const import WORKSPACE_ROOT from metagpt.logs import logger @@ -150,13 +153,13 @@ OUTPUT_MAPPING = { class WriteDesign(Action): - def __init__(self, name, context=None, llm=None): - super().__init__(name, context, llm) - self.desc = ( - "Based on the PRD, think about the system design, and design the corresponding APIs, " - "data structures, library tables, processes, and paths. Please provide your design, feedback " - "clearly and in detail." - ) + name: str = "" + context: Optional[str] = None + llm: LLM = Field(default_factory=LLM) + desc: str = "Based on the PRD, think about the system design, and design the corresponding APIs, " + "data structures, library tables, processes, and paths. Please provide your design, feedback " + "clearly and in detail." + def recreate_workspace(self, workspace: Path): try: @@ -165,16 +168,18 @@ class WriteDesign(Action): pass # Folder does not exist, but we don't care workspace.mkdir(parents=True, exist_ok=True) + async def _save_prd(self, docs_path, resources_path, context): prd_file = docs_path / "prd.md" if context[-1].instruct_content and context[-1].instruct_content.dict()["Competitive Quadrant Chart"]: quadrant_chart = context[-1].instruct_content.dict()["Competitive Quadrant Chart"] await mermaid_to_file(quadrant_chart, resources_path / "competitive_analysis") - + if context[-1].instruct_content: logger.info(f"Saving PRD to {prd_file}") prd_file.write_text(json_to_markdown(context[-1].instruct_content.dict())) + async def _save_system_design(self, docs_path, resources_path, system_design): data_api_design = system_design.instruct_content.dict()[ "Data structures and interface definitions" @@ -188,6 +193,7 @@ class WriteDesign(Action): logger.info(f"Saving System Designs to {system_design_file}") system_design_file.write_text((json_to_markdown(system_design.instruct_content.dict()))) + async def _save(self, context, system_design): if isinstance(system_design, ActionOutput): ws_name = system_design.instruct_content.dict()["Python package name"] @@ -199,9 +205,13 @@ class WriteDesign(Action): resources_path = workspace / "resources" docs_path.mkdir(parents=True, exist_ok=True) resources_path.mkdir(parents=True, exist_ok=True) - await self._save_prd(docs_path, resources_path, context) + try: + await self._save_prd(docs_path, resources_path, context) + except Exception as e: + logger.error(f"Failed to save PRD {e}") await self._save_system_design(docs_path, resources_path, system_design) + async def run(self, context, format=CONFIG.prompt_format): prompt_template, format_example = get_template(templates, format) prompt = prompt_template.format(context=context, format_example=format_example) diff --git a/metagpt/actions/project_management.py b/metagpt/actions/project_management.py index b395fa64e..b72507ee3 100644 --- a/metagpt/actions/project_management.py +++ b/metagpt/actions/project_management.py @@ -5,9 +5,12 @@ @Author : alexanderwu @File : project_management.py """ -from typing import List +from typing import List, Optional, Any + +from pydantic import Field from metagpt.actions.action import Action +from metagpt.llm import LLM from metagpt.config import CONFIG from metagpt.const import WORKSPACE_ROOT from metagpt.utils.common import CodeParser @@ -163,21 +166,25 @@ OUTPUT_MAPPING = { class WriteTasks(Action): - def __init__(self, name="CreateTasks", context=None, llm=None): - super().__init__(name, context, llm) - + name: str = "CreateTasks" + context: Optional[str] = None + llm: LLM = Field(default_factory=LLM) + def _save(self, context, rsp): - if context[-1].instruct_content: - ws_name = context[-1].instruct_content.dict()["Python package name"] - else: - ws_name = CodeParser.parse_str(block="Python package name", text=context[-1].content) + try: + if context[-1].instruct_content: + ws_name = context[-1].instruct_content.dict()["Python package name"] + else: + ws_name = CodeParser.parse_str(block="Python package name", text=context[-1].content) + except: + ws_name = "cli_snake_game" # fixme: 应该透传 file_path = WORKSPACE_ROOT / ws_name / "docs/api_spec_and_tasks.md" file_path.write_text(json_to_markdown(rsp.instruct_content.dict())) - + # Write requirements.txt requirements_path = WORKSPACE_ROOT / ws_name / "requirements.txt" requirements_path.write_text("\n".join(rsp.instruct_content.dict().get("Required Python third-party packages"))) - + async def run(self, context, format=CONFIG.prompt_format): prompt_template, format_example = get_template(templates, format) prompt = prompt_template.format(context=context, format_example=format_example) diff --git a/metagpt/actions/search_and_summarize.py b/metagpt/actions/search_and_summarize.py index 069f2a977..0580303e6 100644 --- a/metagpt/actions/search_and_summarize.py +++ b/metagpt/actions/search_and_summarize.py @@ -6,12 +6,16 @@ @File : search_google.py """ import pydantic +from typing import Optional, Any +from pydantic import BaseModel, Field from metagpt.actions import Action +from metagpt.llm import LLM from metagpt.config import Config from metagpt.logs import logger from metagpt.schema import Message from metagpt.tools.search_engine import SearchEngine +from pydantic import root_validator SEARCH_AND_SUMMARIZE_SYSTEM = """### Requirements 1. Please summarize the latest dialogue based on the reference information (secondary) and dialogue history (primary). Do not include text that is irrelevant to the conversation. @@ -54,7 +58,6 @@ SEARCH_AND_SUMMARIZE_PROMPT = """ """ - SEARCH_AND_SUMMARIZE_SALES_SYSTEM = """## Requirements 1. Please summarize the latest dialogue based on the reference information (secondary) and dialogue history (primary). Do not include text that is irrelevant to the conversation. - The context is for reference only. If it is irrelevant to the user's search request history, please reduce its reference and usage. @@ -101,23 +104,41 @@ You are a member of a professional butler team and will provide helpful suggesti class SearchAndSummarize(Action): - def __init__(self, name="", context=None, llm=None, engine=None, search_func=None): - self.config = Config() - self.engine = engine or self.config.search_engine + name: str = "" + content: Optional[str] = None + llm: None = Field(default_factory=LLM) + config: None = Field(default_factory=Config) + engine: Optional[str] = None + search_func: Optional[str] = None - try: - self.search_engine = SearchEngine(self.engine, run_func=search_func) - except pydantic.ValidationError: - self.search_engine = None + result = "" + - self.result = "" - super().__init__(name, context, llm) + @root_validator + def validate_engine_and_run_func(cls, values): + engine = values.get('engine') + search_func = values.get('search_func') + config = Config() + + if engine is None: + engine = config.search_engine + config_data = { + 'engine': engine, + 'run_func': search_func + } + search_engine = SearchEngine(**config_data) + values['search_engine'] = search_engine + return values + + + async def run(self, context: list[Message], system_text=SEARCH_AND_SUMMARIZE_SYSTEM) -> str: + print(context) if self.search_engine is None: logger.warning("Configure one of SERPAPI_API_KEY, SERPER_API_KEY, GOOGLE_API_KEY to unlock full feature") return "" - + query = context[-1].content # logger.debug(query) rsp = await self.search_engine.run(query) @@ -126,9 +147,9 @@ class SearchAndSummarize(Action): logger.error("empty rsp...") return "" # logger.info(rsp) - + system_prompt = [system_text] - + prompt = SEARCH_AND_SUMMARIZE_PROMPT.format( # PREFIX = self.prefix, ROLE=self.profile, @@ -140,4 +161,7 @@ class SearchAndSummarize(Action): logger.debug(prompt) logger.debug(result) return result - \ No newline at end of file + + +if __name__ == "__main__": + action = SearchAndSummarize() diff --git a/metagpt/actions/write_code.py b/metagpt/actions/write_code.py index c000805c5..2dc240591 100644 --- a/metagpt/actions/write_code.py +++ b/metagpt/actions/write_code.py @@ -5,13 +5,18 @@ @Author : alexanderwu @File : write_code.py """ +from typing import List, Optional, Any + +from pydantic import Field +from tenacity import retry, stop_after_attempt, wait_fixed + from metagpt.actions import WriteDesign from metagpt.actions.action import Action +from metagpt.llm import LLM from metagpt.const import WORKSPACE_ROOT from metagpt.logs import logger from metagpt.schema import Message from metagpt.utils.common import CodeParser -from tenacity import retry, stop_after_attempt, wait_fixed PROMPT_TEMPLATE = """ NOTICE @@ -43,9 +48,10 @@ ATTENTION: Use '##' to SPLIT SECTIONS, not '#'. Output format carefully referenc class WriteCode(Action): - def __init__(self, name="WriteCode", context: list[Message] = None, llm=None): - super().__init__(name, context, llm) - + name: str = "WriteCode" + context: Optional[str] = None + llm: LLM = Field(default_factory=LLM) + def _is_invalid(self, filename): return any(i in filename for i in ["mp3", "wav"]) @@ -79,4 +85,3 @@ class WriteCode(Action): # code_rsp = await self._aask_v1(prompt, "code_rsp", OUTPUT_MAPPING) # self._save(context, filename, code) return code - \ No newline at end of file diff --git a/metagpt/actions/write_code_review.py b/metagpt/actions/write_code_review.py index 4ff4d6cf6..3d86d7c63 100644 --- a/metagpt/actions/write_code_review.py +++ b/metagpt/actions/write_code_review.py @@ -5,12 +5,15 @@ @Author : alexanderwu @File : write_code_review.py """ +from typing import List, Optional, Any +from pydantic import Field +from tenacity import retry, stop_after_attempt, wait_fixed +from metagpt.llm import LLM from metagpt.actions.action import Action from metagpt.logs import logger from metagpt.schema import Message from metagpt.utils.common import CodeParser -from tenacity import retry, stop_after_attempt, wait_fixed PROMPT_TEMPLATE = """ NOTICE @@ -62,9 +65,10 @@ FORMAT_EXAMPLE = """ class WriteCodeReview(Action): - def __init__(self, name="WriteCodeReview", context: list[Message] = None, llm=None): - super().__init__(name, context, llm) - + name: str = "WriteCodeReview" + context: Optional[str] = None + llm: LLM = Field(default_factory=LLM) + @retry(stop=stop_after_attempt(2), wait=wait_fixed(1)) async def write_code(self, prompt): code_rsp = await self._aask(prompt) @@ -79,4 +83,3 @@ class WriteCodeReview(Action): # code_rsp = await self._aask_v1(prompt, "code_rsp", OUTPUT_MAPPING) # self._save(context, filename, code) return code - \ No newline at end of file diff --git a/metagpt/actions/write_prd.py b/metagpt/actions/write_prd.py index bd04ca79e..660d7fb95 100644 --- a/metagpt/actions/write_prd.py +++ b/metagpt/actions/write_prd.py @@ -5,9 +5,12 @@ @Author : alexanderwu @File : write_prd.py """ -from typing import List +from typing import List, Optional, Any + +from pydantic import BaseModel, Field from metagpt.actions import Action, ActionOutput +from metagpt.llm import LLM from metagpt.actions.search_and_summarize import SearchAndSummarize from metagpt.config import CONFIG from metagpt.logs import logger @@ -219,18 +222,25 @@ OUTPUT_MAPPING = { class WritePRD(Action): - def __init__(self, name="", context=None, llm=None): - super().__init__(name, context, llm) - + name: str = "" + content: Optional[str] = None + llm: LLM = Field(default_factory=LLM) + assistant_search_action: Action = None + + def __init__(self, **kwargs): + super().__init__(**kwargs) + async def run(self, requirements, format=CONFIG.prompt_format, *args, **kwargs) -> ActionOutput: - sas = SearchAndSummarize() - # rsp = await sas.run(context=requirements, system_text=SEARCH_AND_SUMMARIZE_SYSTEM_EN_US) - rsp = "" - info = f"### Search Results\n{sas.result}\n\n### Search Summary\n{rsp}" - if sas.result: - logger.info(sas.result) + # self.assistant_search_action = SearchAndSummarize() + if self.assistant_search_action is None: + self.assistant_search_action = SearchAndSummarize() + # self.assistant_search_action = SearchAndSummarize() + rsp = await self.assistant_search_action.run(context=requirements) + info = f"### Search Results\n{self.assistant_search_action.result}\n\n### Search Summary\n{rsp}" + if self.assistant_search_action.result: + logger.info(self.assistant_search_action.result) logger.info(rsp) - + prompt_template, format_example = get_template(templates, format) prompt = prompt_template.format( requirements=requirements, search_information=info, format_example=format_example diff --git a/metagpt/environment.py b/metagpt/environment.py index 24e6ada2f..88ff145e0 100644 --- a/metagpt/environment.py +++ b/metagpt/environment.py @@ -29,11 +29,12 @@ class Environment(BaseModel): arbitrary_types_allowed = True def add_role(self, role: Role): - """增加一个在当前环境的角色 + """增加一个在当前环境的角色, 默认为profile/role_profile Add a role in the current environment """ role.set_env(self) - self.roles[role.profile] = role + # use alias + self.roles[role.role_profile] = role def add_roles(self, roles: Iterable[Role]): """增加一批在当前环境的角色 diff --git a/metagpt/roles/architect.py b/metagpt/roles/architect.py index 15d5fe5b1..face22a68 100644 --- a/metagpt/roles/architect.py +++ b/metagpt/roles/architect.py @@ -5,10 +5,11 @@ @Author : alexanderwu @File : architect.py """ +from pydantic import Field from metagpt.actions import WritePRD from metagpt.actions.design_api import WriteDesign -from metagpt.roles import Role +from metagpt.roles.role import Role class Architect(Role): @@ -21,17 +22,16 @@ class Architect(Role): goal (str): Primary goal or responsibility of the architect. constraints (str): Constraints or guidelines for the architect. """ + name: str = "Bob" + role_profile: str = Field(default="Architect" , alias='profile') + goal: str = "Design a concise, usable, complete python system" + constraints: str = "Try to specify good open source tools as much as possible" def __init__( - self, - name: str = "Bob", - profile: str = "Architect", - goal: str = "Design a concise, usable, complete python system", - constraints: str = "Try to specify good open source tools as much as possible", + self, + **kwargs ) -> None: - """Initializes the Architect with given attributes.""" - super().__init__(name, profile, goal, constraints) - + super().__init__(**kwargs) # Initialize actions specific to the Architect role self._init_actions([WriteDesign]) diff --git a/metagpt/roles/engineer.py b/metagpt/roles/engineer.py index 1f6685b38..129bedeb8 100644 --- a/metagpt/roles/engineer.py +++ b/metagpt/roles/engineer.py @@ -9,11 +9,12 @@ import asyncio import shutil from collections import OrderedDict from pathlib import Path +from pydantic import Field from metagpt.actions import WriteCode, WriteCodeReview, WriteDesign, WriteTasks from metagpt.const import WORKSPACE_ROOT from metagpt.logs import logger -from metagpt.roles import Role +from metagpt.roles.role import Role from metagpt.schema import Message from metagpt.utils.common import CodeParser from metagpt.utils.special_tokens import FILENAME_CODE_SEP, MSG_SEP @@ -23,7 +24,7 @@ async def gather_ordered_k(coros, k) -> list: tasks = OrderedDict() results = [None] * len(coros) done_queue = asyncio.Queue() - + for i, coro in enumerate(coros): if len(tasks) >= k: done, _ = await asyncio.wait(tasks.keys(), return_when=asyncio.FIRST_COMPLETED) @@ -32,17 +33,17 @@ async def gather_ordered_k(coros, k) -> list: await done_queue.put((index, task.result())) task = asyncio.create_task(coro) tasks[task] = i - + if tasks: done, _ = await asyncio.wait(tasks.keys()) for task in done: index = tasks[task] await done_queue.put((index, task.result())) - + while not done_queue.empty(): index, result = await done_queue.get() results[index] = result - + return results @@ -59,42 +60,42 @@ class Engineer(Role): use_code_review (bool): Whether to use code review. todos (list): List of tasks. """ - + name: str = "Alex" + role_profile: str = Field(default="Engineer", alias='profile') + goal: str = "Write elegant, readable, extensible, efficient code" + constraints: str = "The code should conform to standards like PEP8 and be modular and maintainable" + n_borg: int = 1 + use_code_review: bool = False + todos: list = [] + def __init__( - self, - name: str = "Alex", - profile: str = "Engineer", - goal: str = "Write elegant, readable, extensible, efficient code", - constraints: str = "The code should conform to standards like PEP8 and be modular and maintainable", - n_borg: int = 1, - use_code_review: bool = False, + self, + **kwargs ) -> None: - """Initializes the Engineer role with given attributes.""" - super().__init__(name, profile, goal, constraints) - self._init_actions([WriteCode]) - self.use_code_review = use_code_review + super().__init__(**kwargs) + + actions = [WriteCode] if self.use_code_review: - self._init_actions([WriteCode, WriteCodeReview]) + actions = [WriteCode, WriteCodeReview] + self._init_actions(actions) self._watch([WriteTasks]) - self.todos = [] - self.n_borg = n_borg - + @classmethod def parse_tasks(self, task_msg: Message) -> list[str]: if task_msg.instruct_content: return task_msg.instruct_content.dict().get("Task list") return CodeParser.parse_file_list(block="Task list", text=task_msg.content) - + @classmethod def parse_code(self, code_text: str) -> str: return CodeParser.parse_code(block="", text=code_text) - + @classmethod def parse_workspace(cls, system_design_msg: Message) -> str: if system_design_msg.instruct_content: return system_design_msg.instruct_content.dict().get("Python package name").strip().strip("'").strip('"') return CodeParser.parse_str(block="Python package name", text=system_design_msg.content) - + def get_workspace(self) -> Path: msg = self._rc.memory.get_by_action(WriteDesign)[-1] if not msg: @@ -102,7 +103,7 @@ class Engineer(Role): workspace = self.parse_workspace(msg) # Codes are written in workspace/{package_name}/{package_name} return WORKSPACE_ROOT / workspace / workspace - + def recreate_workspace(self): workspace = self.get_workspace() try: @@ -110,7 +111,7 @@ class Engineer(Role): except FileNotFoundError: pass # The folder does not exist, but we don't care workspace.mkdir(parents=True, exist_ok=True) - + def write_file(self, filename: str, code: str): workspace = self.get_workspace() filename = filename.replace('"', "").replace("\n", "") @@ -118,12 +119,12 @@ class Engineer(Role): file.parent.mkdir(parents=True, exist_ok=True) file.write_text(code) return file - + def recv(self, message: Message) -> None: self._rc.memory.add(message) if message in self._rc.important_memory: self.todos = self.parse_tasks(message) - + async def _act_mp(self) -> Message: # self.recreate_workspace() todo_coros = [] @@ -132,7 +133,7 @@ class Engineer(Role): context=self._rc.memory.get_by_actions([WriteTasks, WriteDesign]), filename=todo ) todo_coros.append(todo_coro) - + rsps = await gather_ordered_k(todo_coros, self.n_borg) for todo, code_rsp in zip(self.todos, rsps): _ = self.parse_code(code_rsp) @@ -142,11 +143,11 @@ class Engineer(Role): msg = Message(content=code_rsp, role=self.profile, cause_by=type(self._rc.todo)) self._rc.memory.add(msg) del self.todos[0] - + logger.info(f"Done {self.get_workspace()} generating.") msg = Message(content="all done.", role=self.profile, cause_by=type(self._rc.todo)) return msg - + async def _act_sp(self) -> Message: code_msg_all = [] # gather all code info, will pass to qa_engineer for tests later for todo in self.todos: @@ -157,16 +158,16 @@ class Engineer(Role): file_path = self.write_file(todo, code) msg = Message(content=code, role=self.profile, cause_by=type(self._rc.todo)) self._rc.memory.add(msg) - + code_msg = todo + FILENAME_CODE_SEP + str(file_path) code_msg_all.append(code_msg) - + logger.info(f"Done {self.get_workspace()} generating.") msg = Message( content=MSG_SEP.join(code_msg_all), role=self.profile, cause_by=type(self._rc.todo), send_to="QaEngineer" ) return msg - + async def _act_sp_precision(self) -> Message: code_msg_all = [] # gather all code info, will pass to qa_engineer for tests later for todo in self.todos: @@ -195,19 +196,18 @@ class Engineer(Role): file_path = self.write_file(todo, code) msg = Message(content=code, role=self.profile, cause_by=WriteCode) self._rc.memory.add(msg) - + code_msg = todo + FILENAME_CODE_SEP + str(file_path) code_msg_all.append(code_msg) - + logger.info(f"Done {self.get_workspace()} generating.") msg = Message( content=MSG_SEP.join(code_msg_all), role=self.profile, cause_by=type(self._rc.todo), send_to="QaEngineer" ) return msg - + async def _act(self) -> Message: """Determines the mode of action based on whether code review is used.""" - logger.info(f"{self._setting}: ready to WriteCode") if self.use_code_review: return await self._act_sp_precision() return await self._act_sp() diff --git a/metagpt/roles/product_manager.py b/metagpt/roles/product_manager.py index a58ea5385..b099fb4d9 100644 --- a/metagpt/roles/product_manager.py +++ b/metagpt/roles/product_manager.py @@ -5,37 +5,33 @@ @Author : alexanderwu @File : product_manager.py """ +from pydantic import Field + from metagpt.actions import BossRequirement, WritePRD -from metagpt.roles import Role +from metagpt.roles.role import Role class ProductManager(Role): """ - Represents a Product Manager role responsible for product development and management. + Initializes the ProductManager role with given attributes. - Attributes: + Args: name (str): Name of the product manager. - profile (str): Role profile, default is 'Product Manager'. + profile (str): Role profile. goal (str): Goal of the product manager. constraints (str): Constraints or limitations for the product manager. """ - + name: str = "Alice" + role_profile: str = Field(default="Product Manager", alias='profile') + goal: str = "Efficiently create a successful product" + constraints: str = "" + """ + Represents a Product Manager role responsible for product development and management. + """ def __init__( - self, - name: str = "Alice", - profile: str = "Product Manager", - goal: str = "Efficiently create a successful product", - constraints: str = "", + self, + **kwargs ) -> None: - """ - Initializes the ProductManager role with given attributes. - - Args: - name (str): Name of the product manager. - profile (str): Role profile. - goal (str): Goal of the product manager. - constraints (str): Constraints or limitations for the product manager. - """ - super().__init__(name, profile, goal, constraints) + super().__init__(**kwargs) self._init_actions([WritePRD]) self._watch([BossRequirement]) diff --git a/metagpt/roles/project_manager.py b/metagpt/roles/project_manager.py index 7e7c5699d..a2b227f22 100644 --- a/metagpt/roles/project_manager.py +++ b/metagpt/roles/project_manager.py @@ -5,9 +5,11 @@ @Author : alexanderwu @File : project_manager.py """ +from pydantic import Field + from metagpt.actions import WriteTasks from metagpt.actions.design_api import WriteDesign -from metagpt.roles import Role +from metagpt.roles.role import Role class ProjectManager(Role): @@ -20,23 +22,16 @@ class ProjectManager(Role): goal (str): Goal of the project manager. constraints (str): Constraints or limitations for the project manager. """ + name: str = "Eve" + role_profile: str = Field(default="Project Manager", alias='profile') + + goal: str = "Improve team efficiency and deliver with quality and quantity" + constraints: str = "" def __init__( - self, - name: str = "Eve", - profile: str = "Project Manager", - goal: str = "Improve team efficiency and deliver with quality and quantity", - constraints: str = "", + self, + **kwargs ) -> None: - """ - Initializes the ProjectManager role with given attributes. - - Args: - name (str): Name of the project manager. - profile (str): Role profile. - goal (str): Goal of the project manager. - constraints (str): Constraints or limitations for the project manager. - """ - super().__init__(name, profile, goal, constraints) + super().__init__(**kwargs) self._init_actions([WriteTasks]) self._watch([WriteDesign]) diff --git a/metagpt/roles/role.py b/metagpt/roles/role.py index b96c361c0..9aae64188 100644 --- a/metagpt/roles/role.py +++ b/metagpt/roles/role.py @@ -5,17 +5,26 @@ @Author : alexanderwu @File : role.py """ + from __future__ import annotations -from typing import Iterable, Type, Union -from enum import Enum - +import sys +from types import SimpleNamespace +from typing import ( + Dict, + Optional, + Union, + Iterable, + Type +) +import re from pydantic import BaseModel, Field +from importlib import import_module # from metagpt.environment import Environment from metagpt.config import CONFIG from metagpt.actions import Action, ActionOutput -from metagpt.llm import LLM, HumanProvider +from metagpt.llm import LLM from metagpt.logs import logger from metagpt.memory import Memory, LongTermMemory from metagpt.schema import Message @@ -28,14 +37,12 @@ Please note that only the text between the first and second "===" is information {history} === -Your previous stage: {previous_state} - -Now choose one of the following stages you need to go to in the next step: +You can now choose one of the following stages to decide the stage you need to go in the next step: {states} Just answer a number between 0-{n_states}, choose the most suitable stage according to the understanding of the conversation. Please note that the answer only needs a number, no need to add any other text. -If you think you have completed your goal and don't need to go to any of the stages, return -1. +If there is no conversation record, choose 0. Do not answer anything else, and do not add any other information in your answer. """ @@ -49,27 +56,18 @@ ROLE_TEMPLATE = """Your response should be based on the previous conversation hi {name}: {result} """ -class RoleReactMode(str, Enum): - REACT = "react" - BY_ORDER = "by_order" - PLAN_AND_ACT = "plan_and_act" - - @classmethod - def values(cls): - return [item.value for item in cls] class RoleSetting(BaseModel): """Role Settings""" - name: str - profile: str - goal: str - constraints: str - desc: str - is_human: bool - + name: str = "" + profile: str = "" + goal: str = "" + constraints: str = "" + desc: str = "" + def __str__(self): return f"{self.name}({self.profile})" - + def __repr__(self): return self.__str__() @@ -79,109 +77,128 @@ class RoleContext(BaseModel): env: 'Environment' = Field(default=None) memory: Memory = Field(default_factory=Memory) long_term_memory: LongTermMemory = Field(default_factory=LongTermMemory) - state: int = Field(default=-1) # -1 indicates initial or termination state where todo is None + state: int = Field(default=0) todo: Action = Field(default=None) watch: set[Type[Action]] = Field(default_factory=set) news: list[Type[Message]] = Field(default=[]) - react_mode: RoleReactMode = RoleReactMode.REACT # see `Role._set_react_mode` for definitions of the following two attributes - max_react_loop: int = 1 - + class Config: arbitrary_types_allowed = True - + def check(self, role_id: str): if hasattr(CONFIG, "long_term_memory") and CONFIG.long_term_memory: self.long_term_memory.recover_memory(role_id, self) self.memory = self.long_term_memory # use memory to act as long_term_memory for unify operation - + @property def important_memory(self) -> list[Message]: """Get the information corresponding to the watched actions""" return self.memory.get_by_actions(self.watch) - + @property def history(self) -> list[Message]: return self.memory.get() -class Role: +class Role(BaseModel): """Role/Agent""" - - def __init__(self, name="", profile="", goal="", constraints="", desc="", is_human=False): - self._llm = LLM() if not is_human else HumanProvider() - self._setting = RoleSetting(name=name, profile=profile, goal=goal, - constraints=constraints, desc=desc, is_human=is_human) - self._states = [] - self._actions = [] - self._role_id = str(self._setting) - self._rc = RoleContext() - + name: str = "" + profile: str = "" + goal: str = "" + constraints: str = "" + desc: str = "" + _setting: RoleSetting = Field(default_factory=RoleSetting, alias="_setting") + _setting = RoleSetting(name=name, profile=profile, goal=goal, constraints=constraints) + _role_id: str = "" + _states: list = Field(default=[]) + _actions: list = Field(default=[]) + _actions_type: list = Field(default=[]) + _rc: RoleContext = RoleContext() + + _private_attributes = { + '_setting': _setting, + '_role_id': _role_id, + '_states': [], + '_actions': [], + '_actions_type': [] # 用于记录和序列化 + } + + class Config: + arbitrary_types_allowed = True + + def __init__(self, **kwargs): + super().__init__(**kwargs) + # 关于私有变量的初始化 https://github.com/pydantic/pydantic/issues/655 + for key in self._private_attributes.keys(): + if key in kwargs: + object.__setattr__(self, key, kwargs[key]) + if key =="_setting": + _setting = RoleSetting(**kwargs[key]) + object.__setattr__(self, '_setting', _setting) + elif key == "_rc": + _rc = RoleContext + object.__setattr__(self, '_rc', _rc) + else: + object.__setattr__(self, key, self._private_attributes[key]) + def _reset(self): - self._states = [] - self._actions = [] + object.__setattr__(self, '_states', []) + object.__setattr__(self, '_actions', []) + + + @staticmethod + def _process_class(class_str, module_name): + cleaned_string = re.sub(r"[<>']", "", class_str).replace("class ", "") + package_name = "metagpt" + file_name = cleaned_string.replace(package_name, "").replace("." + module_name, "") + print(file_name) + # print("\n", sys.modules) + module_file = import_module(file_name, package=package_name) + module = getattr(module_file, module_name) + return module + def _init_actions(self, actions): self._reset() for idx, action in enumerate(actions): if not isinstance(action, Action): - i = action("", llm=self._llm) + ## 默认初始化 + i = action() else: - if self._setting.is_human and not isinstance(action.llm, HumanProvider): - logger.warning(f"is_human attribute does not take effect," - f"as Role's {str(action)} was initialized using LLM, try passing in Action classes instead of initialized instances") i = action i.set_prefix(self._get_prefix(), self.profile) self._actions.append(i) self._states.append(f"{idx}. {action}") - - def _set_react_mode(self, react_mode: str, max_react_loop: int = 1): - """Set strategy of the Role reacting to observed Message. Variation lies in how - this Role elects action to perform during the _think stage, especially if it is capable of multiple Actions. - - Args: - react_mode (str): Mode for choosing action during the _think stage, can be one of: - "react": standard think-act loop in the ReAct paper, alternating thinking and acting to solve the task, i.e. _think -> _act -> _think -> _act -> ... - Use llm to select actions in _think dynamically; - "by_order": switch action each time by order defined in _init_actions, i.e. _act (Action1) -> _act (Action2) -> ...; - "plan_and_act": first plan, then execute an action sequence, i.e. _think (of a plan) -> _act -> _act -> ... - Use llm to come up with the plan dynamically. - Defaults to "react". - max_react_loop (int): Maximum react cycles to execute, used to prevent the agent from reacting forever. - Take effect only when react_mode is react, in which we use llm to choose actions, including termination. - Defaults to 1, i.e. _think -> _act (-> return result and end) - """ - assert react_mode in RoleReactMode.values(), f"react_mode must be one of {RoleReactMode.values()}" - self._rc.react_mode = react_mode - if react_mode == RoleReactMode.REACT: - self._rc.max_react_loop = max_react_loop - + action_title = action.schema()["title"] + self._actions_type.append(action_title) + def _watch(self, actions: Iterable[Type[Action]]): """Listen to the corresponding behaviors""" self._rc.watch.update(actions) # check RoleContext after adding watch actions self._rc.check(self._role_id) - - def _set_state(self, state: int): + + def _set_state(self, state): """Update the current state.""" self._rc.state = state logger.debug(self._actions) - self._rc.todo = self._actions[self._rc.state] if state >= 0 else None - + self._rc.todo = self._actions[self._rc.state] + def set_env(self, env: 'Environment'): """Set the environment in which the role works. The role can talk to the environment and can also receive messages by observing.""" self._rc.env = env - + @property def profile(self): """Get the role description (position)""" return self._setting.profile - + def _get_prefix(self): """Get the role prefix""" if self._setting.desc: return self._setting.desc return PREFIX_TEMPLATE.format(**self._setting.dict()) - + async def _think(self) -> None: """Think about what to do and decide on the next action""" if len(self._actions) == 1: @@ -190,104 +207,60 @@ class Role: return prompt = self._get_prefix() prompt += STATE_TEMPLATE.format(history=self._rc.history, states="\n".join(self._states), - n_states=len(self._states) - 1, previous_state=self._rc.state) - # print(prompt) + n_states=len(self._states) - 1) next_state = await self._llm.aask(prompt) logger.debug(f"{prompt=}") - if (not next_state.isdigit() and next_state != "-1") \ - or int(next_state) not in range(-1, len(self._states)): - logger.warning(f'Invalid answer of state, {next_state=}, will be set to -1') - next_state = -1 - else: - next_state = int(next_state) - if next_state == -1: - logger.info(f"End actions with {next_state=}") - self._set_state(next_state) - + if not next_state.isdigit() or int(next_state) not in range(len(self._states)): + logger.warning(f'Invalid answer of state, {next_state=}') + next_state = "0" + self._set_state(int(next_state)) + async def _act(self) -> Message: - # prompt = self.get_prefix() - # prompt += ROLE_TEMPLATE.format(name=self.profile, state=self.states[self.state], result=response, - # history=self.history) - logger.info(f"{self._setting}: ready to {self._rc.todo}") response = await self._rc.todo.run(self._rc.important_memory) # logger.info(response) if isinstance(response, ActionOutput): msg = Message(content=response.content, instruct_content=response.instruct_content, - role=self.profile, cause_by=type(self._rc.todo)) + role=self.profile, cause_by=type(self._rc.todo)) else: msg = Message(content=response, role=self.profile, cause_by=type(self._rc.todo)) self._rc.memory.add(msg) # logger.debug(f"{response}") - + return msg - + async def _observe(self) -> int: """Observe from the environment, obtain important information, and add it to memory""" if not self._rc.env: return 0 env_msgs = self._rc.env.memory.get() - + observed = self._rc.env.memory.get_by_actions(self._rc.watch) - self._rc.news = self._rc.memory.find_news(observed) # find news (previously unseen messages) from observed messages - + self._rc.news = self._rc.memory.find_news( + observed) # find news (previously unseen messages) from observed messages + for i in env_msgs: self.recv(i) - + news_text = [f"{i.role}: {i.content[:20]}..." for i in self._rc.news] if news_text: logger.debug(f'{self._setting} observed: {news_text}') return len(self._rc.news) - + def _publish_message(self, msg): """If the role belongs to env, then the role's messages will be broadcast to env""" if not self._rc.env: # If env does not exist, do not publish the message return self._rc.env.publish_message(msg) - + async def _react(self) -> Message: - """Think first, then act, until the Role _think it is time to stop and requires no more todo. - This is the standard think-act loop in the ReAct paper, which alternates thinking and acting in task solving, i.e. _think -> _act -> _think -> _act -> ... - Use llm to select actions in _think dynamically - """ - actions_taken = 0 - rsp = Message("No actions taken yet") # will be overwritten after Role _act - while actions_taken < self._rc.max_react_loop: - # think - await self._think() - if self._rc.todo is None: - break - # act - logger.debug(f"{self._setting}: {self._rc.state=}, will do {self._rc.todo}") - rsp = await self._act() - actions_taken += 1 - return rsp # return output from the last action - - async def _act_by_order(self) -> Message: - """switch action each time by order defined in _init_actions, i.e. _act (Action1) -> _act (Action2) -> ...""" - for i in range(len(self._states)): - self._set_state(i) - rsp = await self._act() - return rsp # return output from the last action - - async def _plan_and_act(self) -> Message: - """first plan, then execute an action sequence, i.e. _think (of a plan) -> _act -> _act -> ... Use llm to come up with the plan dynamically.""" - # TODO: to be implemented - return Message("") - - async def react(self) -> Message: - """Entry to one of three strategies by which Role reacts to the observed Message""" - if self._rc.react_mode == RoleReactMode.REACT: - rsp = await self._react() - elif self._rc.react_mode == RoleReactMode.BY_ORDER: - rsp = await self._act_by_order() - elif self._rc.react_mode == RoleReactMode.PLAN_AND_ACT: - rsp = await self._plan_and_act() - self._set_state(state=-1) # current reaction is complete, reset state to -1 and todo back to None - return rsp - + """Think first, then act""" + await self._think() + logger.debug(f"{self._setting}: {self._rc.state=}, will do {self._rc.todo}") + return await self._act() + def recv(self, message: Message) -> None: """add message to history.""" # self._history += f"\n{message}" @@ -295,18 +268,14 @@ class Role: if message in self._rc.memory.get(): return self._rc.memory.add(message) - + async def handle(self, message: Message) -> Message: """Receive information and reply with actions""" # logger.debug(f"{self.name=}, {self.profile=}, {message.role=}") self.recv(message) - + return await self._react() - - def get_memories(self, k=0) -> list[Message]: - """A wrapper to return the most recent k memories of this role, return all when k=0""" - return self._rc.memory.get(k=k) - + async def run(self, message=None): """Observe, and think and act based on the results of the observation""" if message: @@ -320,8 +289,8 @@ class Role: # If there is no new information, suspend and wait logger.debug(f"{self._setting}: no news. waiting.") return - - rsp = await self.react() + + rsp = await self._react() # Publish the reply to the environment, waiting for the next subscriber to process self._publish_message(rsp) return rsp From 0dd63e4b2363d30d6c7e5db1705e749f00c9f82f Mon Sep 17 00:00:00 2001 From: stellahsr Date: Mon, 27 Nov 2023 21:13:19 +0800 Subject: [PATCH 02/14] update test cases for serialize_deserialize --- .../metagpt/serialize_deserialize/__init__.py | 4 ++ .../serialize_deserialize/test_actions.py | 24 ++++++++++ .../test_architect_deserialize.py | 26 ++++++++++ .../test_product_manager.py | 21 +++++++++ .../test_project_manager.py | 26 ++++++++++ .../serialize_deserialize/test_role.py | 41 ++++++++++++++++ .../serialize_deserialize/test_team.py | 47 +++++++++++++++++++ .../serialize_deserialize/test_wrire_prd.py | 28 +++++++++++ .../serialize_deserialize/test_write_code.py | 42 +++++++++++++++++ .../test_write_design.py | 39 +++++++++++++++ 10 files changed, 298 insertions(+) create mode 100644 tests/metagpt/serialize_deserialize/__init__.py create mode 100644 tests/metagpt/serialize_deserialize/test_actions.py create mode 100644 tests/metagpt/serialize_deserialize/test_architect_deserialize.py create mode 100644 tests/metagpt/serialize_deserialize/test_product_manager.py create mode 100644 tests/metagpt/serialize_deserialize/test_project_manager.py create mode 100644 tests/metagpt/serialize_deserialize/test_role.py create mode 100644 tests/metagpt/serialize_deserialize/test_team.py create mode 100644 tests/metagpt/serialize_deserialize/test_wrire_prd.py create mode 100644 tests/metagpt/serialize_deserialize/test_write_code.py create mode 100644 tests/metagpt/serialize_deserialize/test_write_design.py diff --git a/tests/metagpt/serialize_deserialize/__init__.py b/tests/metagpt/serialize_deserialize/__init__.py new file mode 100644 index 000000000..78f454fb5 --- /dev/null +++ b/tests/metagpt/serialize_deserialize/__init__.py @@ -0,0 +1,4 @@ +# -*- coding: utf-8 -*- +# @Date : 11/22/2023 11:48 AM +# @Author : stellahong (stellahong@fuzhi.ai) +# @Desc : diff --git a/tests/metagpt/serialize_deserialize/test_actions.py b/tests/metagpt/serialize_deserialize/test_actions.py new file mode 100644 index 000000000..e2efa982b --- /dev/null +++ b/tests/metagpt/serialize_deserialize/test_actions.py @@ -0,0 +1,24 @@ +# -*- coding: utf-8 -*- +# @Date : 11/22/2023 11:48 AM +# @Author : stellahong (stellahong@fuzhi.ai) +# @Desc : +import pytest + +from metagpt.actions import Action +from metagpt.llm import LLM + +def test_action_serialize(): + action = Action() + ser_action_dict = action.dict() + assert "name" in ser_action_dict + assert "llm" in ser_action_dict + +@pytest.mark.asyncio +async def test_action_deserialize(): + action = Action() + serialized_data = action.dict() + + new_action = Action(**serialized_data) + assert new_action.name == "" + assert new_action.llm == LLM() + assert len(await new_action._aask("who are you")) > 0 diff --git a/tests/metagpt/serialize_deserialize/test_architect_deserialize.py b/tests/metagpt/serialize_deserialize/test_architect_deserialize.py new file mode 100644 index 000000000..cff1bbadd --- /dev/null +++ b/tests/metagpt/serialize_deserialize/test_architect_deserialize.py @@ -0,0 +1,26 @@ +# -*- coding: utf-8 -*- +# @Date : 11/26/2023 2:04 PM +# @Author : stellahong (stellahong@fuzhi.ai) +# @Desc : +import pytest + +from metagpt.roles.architect import Architect +from metagpt.actions.action import Action + +def test_architect_serialize(): + role = Architect() + ser_role_dict = role.dict(by_alias=True) + assert "name" in ser_role_dict + assert "_states" in ser_role_dict + assert "_actions" in ser_role_dict + +@pytest.mark.asyncio +async def test_architect_deserialize(): + role = Architect() + ser_role_dict = role.dict(by_alias=True) + new_role = Architect(**ser_role_dict) + # new_role = Architect.deserialize(ser_role_dict) + assert new_role.name == "Bob" + assert len(new_role._actions) == 1 + assert isinstance(new_role._actions[0], Action) + await new_role._actions[0].run(context="write a cli snake game") \ No newline at end of file diff --git a/tests/metagpt/serialize_deserialize/test_product_manager.py b/tests/metagpt/serialize_deserialize/test_product_manager.py new file mode 100644 index 000000000..978c50e5e --- /dev/null +++ b/tests/metagpt/serialize_deserialize/test_product_manager.py @@ -0,0 +1,21 @@ +# -*- coding: utf-8 -*- +# @Date : 11/26/2023 2:07 PM +# @Author : stellahong (stellahong@fuzhi.ai) +# @Desc : +import pytest + +from metagpt.roles.product_manager import ProductManager +from metagpt.actions.action import Action +from metagpt.schema import Message + +@pytest.mark.asyncio +async def test_product_manager_deserialize(): + role = ProductManager() + ser_role_dict = role.dict(by_alias=True) + new_role = ProductManager(**ser_role_dict) + # new_role = ProductManager().deserialize(ser_role_dict) + + assert new_role.name == "Alice" + assert len(new_role._actions) == 1 + assert isinstance(new_role._actions[0], Action) + await new_role._actions[0].run([Message(content="write a cli snake game")]) \ No newline at end of file diff --git a/tests/metagpt/serialize_deserialize/test_project_manager.py b/tests/metagpt/serialize_deserialize/test_project_manager.py new file mode 100644 index 000000000..590bd8109 --- /dev/null +++ b/tests/metagpt/serialize_deserialize/test_project_manager.py @@ -0,0 +1,26 @@ +# -*- coding: utf-8 -*- +# @Date : 11/26/2023 2:06 PM +# @Author : stellahong (stellahong@fuzhi.ai) +# @Desc : +import pytest + +from metagpt.roles.project_manager import ProjectManager +from metagpt.actions.action import Action + +def test_project_manager_serialize(): + role = ProjectManager() + ser_role_dict = role.dict(by_alias=True) + assert "name" in ser_role_dict + assert "_states" in ser_role_dict + assert "_actions" in ser_role_dict + +@pytest.mark.asyncio +async def test_project_manager_deserialize(): + role = ProjectManager() + ser_role_dict = role.dict(by_alias=True) + new_role = ProjectManager(**ser_role_dict) + # new_role = ProjectManager().deserialize(ser_role_dict) + assert new_role.name == "Eve" + assert len(new_role._actions) == 1 + assert isinstance(new_role._actions[0], Action) + await new_role._actions[0].run(context="write a cli snake game") \ No newline at end of file diff --git a/tests/metagpt/serialize_deserialize/test_role.py b/tests/metagpt/serialize_deserialize/test_role.py new file mode 100644 index 000000000..432c9acb7 --- /dev/null +++ b/tests/metagpt/serialize_deserialize/test_role.py @@ -0,0 +1,41 @@ +# -*- coding: utf-8 -*- +# @Date : 11/23/2023 4:49 PM +# @Author : stellahong (stellahong@fuzhi.ai) +# @Desc : +import pytest + +from metagpt.roles.role import Role +from metagpt.roles.engineer import Engineer + +from metagpt.actions.action import Action + + +def test_role_serialize(): + role = Role() + ser_role_dict = role.dict(by_alias=True) + assert "name" in ser_role_dict + assert "_states" in ser_role_dict + assert "_actions" in ser_role_dict + + +def test_engineer_serialize(): + role = Engineer() + ser_role_dict = role.dict(by_alias=True) + assert "name" in ser_role_dict + assert "_states" in ser_role_dict + assert "_actions" in ser_role_dict + + +@pytest.mark.asyncio +async def test_engineer_deserialize(): + role = Engineer(use_code_review=True) + ser_role_dict = role.dict(by_alias=True) + # new_role = Engineer().deserialize(ser_role_dict) + # also can be deserialized in this way: + new_role = Engineer(**ser_role_dict) + assert new_role.name == "Alex" + assert new_role.use_code_review == True + assert len(new_role._actions) == 2 + assert isinstance(new_role._actions[0], Action) + assert isinstance(new_role._actions[1], Action) + await new_role._actions[0].run(context="write a cli snake game", filename="test_code") diff --git a/tests/metagpt/serialize_deserialize/test_team.py b/tests/metagpt/serialize_deserialize/test_team.py new file mode 100644 index 000000000..44a75d262 --- /dev/null +++ b/tests/metagpt/serialize_deserialize/test_team.py @@ -0,0 +1,47 @@ +# -*- coding: utf-8 -*- +# @Date : 11/27/2023 10:07 AM +# @Author : stellahong (stellahong@fuzhi.ai) +# @Desc : +import pytest + +from metagpt.environment import Environment +from metagpt.schema import Message +from metagpt.software_company import SoftwareCompany +from metagpt.roles import ProjectManager, ProductManager, Architect + + +def test_env_serialize(): + env = Environment() + ser_env_dict = env.dict() + assert "roles" in ser_env_dict + assert "memory" in ser_env_dict + assert "memory" in ser_env_dict + + +def test_env_deserialize(): + env = Environment() + env.publish_message(message=Message(content="test env serialize")) + ser_env_dict = env.dict() + new_env = Environment(**ser_env_dict) + assert len(new_env.roles) == 0 + assert new_env.memory.storage[0].content == "test env serialize" + assert len(new_env.history) == 25 + + +def test_softwarecompany_deserialize(): + team = SoftwareCompany() + team.hire( + [ + ProductManager(), + Architect(), + ProjectManager(), + ] + ) + assert len(team.environment.get_roles()) == 3 + ser_team_dict = team.dict() + new_team = SoftwareCompany(**ser_team_dict) + + assert len(new_team.environment.get_roles()) == 3 + assert new_team.environment.get_role('Product Manager') is not None + assert new_team.environment.get_role('Product Manager') is not None + assert new_team.environment.get_role('Architect') is not None diff --git a/tests/metagpt/serialize_deserialize/test_wrire_prd.py b/tests/metagpt/serialize_deserialize/test_wrire_prd.py new file mode 100644 index 000000000..9b2653820 --- /dev/null +++ b/tests/metagpt/serialize_deserialize/test_wrire_prd.py @@ -0,0 +1,28 @@ +# -*- coding: utf-8 -*- +# @Date : 11/22/2023 1:47 PM +# @Author : stellahong (stellahong@fuzhi.ai) +# @Desc : +import pytest + +from metagpt.actions import WritePRD +from metagpt.llm import LLM +from metagpt.schema import Message + + +def test_action_serialize(): + action = WritePRD() + ser_action_dict = action.dict() + assert "name" in ser_action_dict + assert "llm" in ser_action_dict + + +@pytest.mark.asyncio +async def test_action_deserialize(): + action = WritePRD() + serialized_data = action.dict() + new_action = WritePRD(**serialized_data) + # new_action = WritePRD().deserialize(serialized_data) + assert new_action.name == "" + assert new_action.llm == LLM() + assert len(await new_action.run([Message(content="write a cli snake game")]))>0 + diff --git a/tests/metagpt/serialize_deserialize/test_write_code.py b/tests/metagpt/serialize_deserialize/test_write_code.py new file mode 100644 index 000000000..0b1f1dc7c --- /dev/null +++ b/tests/metagpt/serialize_deserialize/test_write_code.py @@ -0,0 +1,42 @@ +# -*- coding: utf-8 -*- +# @Date : 11/23/2023 10:56 AM +# @Author : stellahong (stellahong@fuzhi.ai) +# @Desc : +import pytest + +from metagpt.actions import WriteCode, WriteCodeReview +from metagpt.llm import LLM + +def test_write_design_serialize(): + action = WriteCode() + ser_action_dict = action.dict() + assert ser_action_dict["name"] == "WriteCode" + assert "llm" in ser_action_dict + +def test_write_task_serialize(): + action = WriteCodeReview() + ser_action_dict = action.dict() + assert ser_action_dict["name"] == "WriteCodeReview" + assert "llm" in ser_action_dict + +@pytest.mark.asyncio +async def test_write_code_deserialize(): + action = WriteCode() + serialized_data = action.dict() + new_action = WriteCode(**serialized_data) + # new_action = WriteCode().deserialize(serialized_data) + assert new_action.name == "WriteCode" + assert new_action.llm == LLM() + await new_action.run(context="write a cli snake game", filename="test_code") + +@pytest.mark.asyncio +async def test_write_code_review_deserialize(): + action = WriteCodeReview() + serialized_data = action.dict() + new_action = WriteCodeReview(**serialized_data) + # new_action = WriteCodeReview().deserialize(serialized_data) + code = await WriteCode().run(context="write a cli snake game", filename="test_code") + + assert new_action.name == "WriteCodeReview" + assert new_action.llm == LLM() + await new_action.run(context="write a cli snake game", code =code, filename="test_rewrite_code") \ No newline at end of file diff --git a/tests/metagpt/serialize_deserialize/test_write_design.py b/tests/metagpt/serialize_deserialize/test_write_design.py new file mode 100644 index 000000000..56bf78a63 --- /dev/null +++ b/tests/metagpt/serialize_deserialize/test_write_design.py @@ -0,0 +1,39 @@ +# -*- coding: utf-8 -*- +# @Date : 11/22/2023 8:19 PM +# @Author : stellahong (stellahong@fuzhi.ai) +# @Desc : +import pytest + +from metagpt.actions import WriteDesign, WriteTasks +from metagpt.llm import LLM + +def test_write_design_serialize(): + action = WriteDesign() + ser_action_dict = action.dict() + assert "name" in ser_action_dict + assert "llm" in ser_action_dict + +def test_write_task_serialize(): + action = WriteTasks() + ser_action_dict = action.dict() + assert "name" in ser_action_dict + assert "llm" in ser_action_dict + +@pytest.mark.asyncio +async def test_write_design_deserialize(): + action = WriteDesign() + serialized_data = action.dict() + new_action = WriteDesign().deserialize(serialized_data) + assert new_action.name == "" + assert new_action.llm == LLM() + await new_action.run(context="write a cli snake game") + +@pytest.mark.asyncio +async def test_write_task_deserialize(): + action = WriteTasks() + serialized_data = action.dict() + new_action = WriteTasks(**serialized_data) + # new_action = WriteTasks().deserialize(serialized_data) + assert new_action.name == "CreateTasks" + assert new_action.llm == LLM() + await new_action.run(context="write a cli snake game") \ No newline at end of file From d99b4c62e33d1c37cb832c04030697d37a90be66 Mon Sep 17 00:00:00 2001 From: better629 Date: Tue, 28 Nov 2023 09:29:00 +0800 Subject: [PATCH 03/14] add mg ser&deser --- metagpt/actions/action.py | 32 ++++++++ metagpt/const.py | 1 + metagpt/environment.py | 38 +++++++++ metagpt/memory/memory.py | 30 +++++++ metagpt/roles/role.py | 115 ++++++++++++++++++++++++++- metagpt/schema.py | 44 ++++++++++ metagpt/team.py | 26 ++++++ metagpt/utils/serialize.py | 62 +++++++++++++-- metagpt/utils/utils.py | 41 ++++++++++ startup.py | 41 ++++++---- tests/metagpt/actions/test_action.py | 17 ++++ tests/metagpt/memory/test_memory.py | 42 ++++++++++ tests/metagpt/roles/test_role.py | 85 ++++++++++++++++++++ tests/metagpt/test_environment.py | 29 +++++-- tests/metagpt/test_schema.py | 42 ++++++++++ tests/metagpt/test_team.py | 27 +++++++ 16 files changed, 641 insertions(+), 31 deletions(-) create mode 100644 metagpt/utils/utils.py create mode 100644 tests/metagpt/memory/test_memory.py create mode 100644 tests/metagpt/roles/test_role.py create mode 100644 tests/metagpt/test_team.py diff --git a/metagpt/actions/action.py b/metagpt/actions/action.py index 790295d55..a538baa77 100644 --- a/metagpt/actions/action.py +++ b/metagpt/actions/action.py @@ -9,6 +9,7 @@ import re from abc import ABC from typing import Optional +import importlib from tenacity import retry, stop_after_attempt, wait_fixed from metagpt.actions.action_output import ActionOutput @@ -16,6 +17,7 @@ from metagpt.llm import LLM from metagpt.logs import logger from metagpt.utils.common import OutputParser from metagpt.utils.custom_decoder import CustomDecoder +from metagpt.utils.utils import import_class class Action(ABC): @@ -42,6 +44,36 @@ class Action(ABC): def __repr__(self): return self.__str__() + def serialize(self): + return { + "action_class": self.__class__.__name__, + "module_name": self.__module__, + "name": self.name + } + + @classmethod + def deserialize(cls, action_dict: dict): + action_class_str = action_dict.pop("action_class") + module_name = action_dict.pop("module_name") + action_class = import_class(action_class_str, module_name) + return action_class(**action_dict) + + @classmethod + def ser_class(cls): + """ serialize class type""" + return { + "action_class": cls.__name__, + "module_name": cls.__module__ + } + + @classmethod + def deser_class(cls, action_dict: dict): + """ deserialize class type """ + action_class_str = action_dict.pop("action_class") + module_name = action_dict.pop("module_name") + action_class = import_class(action_class_str, module_name) + return action_class + async def _aask(self, prompt: str, system_msgs: Optional[list[str]] = None) -> str: """Append default prefix""" if not system_msgs: diff --git a/metagpt/const.py b/metagpt/const.py index 407ce803a..711546d03 100644 --- a/metagpt/const.py +++ b/metagpt/const.py @@ -42,6 +42,7 @@ TMP = PROJECT_ROOT / "tmp" RESEARCH_PATH = DATA_PATH / "research" TUTORIAL_PATH = DATA_PATH / "tutorial_docx" INVOICE_OCR_TABLE_PATH = DATA_PATH / "invoice_table" +SERDES_PATH = WORKSPACE_ROOT / "storage" # TODO to store `storage` under the individual generated project SKILL_DIRECTORY = PROJECT_ROOT / "metagpt/skills" diff --git a/metagpt/environment.py b/metagpt/environment.py index 24e6ada2f..d1fa561f0 100644 --- a/metagpt/environment.py +++ b/metagpt/environment.py @@ -7,12 +7,14 @@ """ import asyncio from typing import Iterable +from pathlib import Path from pydantic import BaseModel, Field from metagpt.memory import Memory from metagpt.roles import Role from metagpt.schema import Message +from metagpt.utils.utils import read_json_file, write_json_file class Environment(BaseModel): @@ -28,6 +30,42 @@ class Environment(BaseModel): class Config: arbitrary_types_allowed = True + def serialize(self, stg_path: Path): + roles_path = stg_path.joinpath("roles.json") + roles_info = [] + for role_key, role in self.roles.items(): + roles_info.append({ + "role_class": role.__class__.__name__, + "module_name": role.__module__, + "role_name": role.name + }) + role.serialize(stg_path=stg_path.joinpath(f"roles/{role.__class__.__name__}_{role.name}")) + write_json_file(roles_path, roles_info) + + self.memory.serialize(stg_path) + history_path = stg_path.joinpath("history.json") + write_json_file(history_path, {"content": self.history}) + + def deserialize(self, stg_path: Path): + """ stg_path: ./storage/team/environment/ """ + roles_path = stg_path.joinpath("roles.json") + roles_info = read_json_file(roles_path) + for role_info in roles_info: + role_class = role_info.get("role_class") + role_name = role_info.get("role_name") + + role_path = stg_path.joinpath(f"roles/{role_class}_{role_name}") + role = Role.deserialize(role_path) + + self.add_role(role) + + memory = Memory.deserialize(stg_path) + self.memory = memory + + history_path = stg_path.joinpath("history.json") + history = read_json_file(history_path) + self.history = history.get("content") + def add_role(self, role: Role): """增加一个在当前环境的角色 Add a role in the current environment diff --git a/metagpt/memory/memory.py b/metagpt/memory/memory.py index c818fa707..a839bb038 100644 --- a/metagpt/memory/memory.py +++ b/metagpt/memory/memory.py @@ -7,9 +7,12 @@ """ from collections import defaultdict from typing import Iterable, Type +from pathlib import Path from metagpt.actions import Action from metagpt.schema import Message +from metagpt.utils.utils import read_json_file, write_json_file +from metagpt.utils.serialize import serialize_general_message, deserialize_general_message class Memory: @@ -20,6 +23,33 @@ class Memory: self.storage: list[Message] = [] self.index: dict[Type[Action], list[Message]] = defaultdict(list) + def serialize(self, stg_path: Path): + """ stg_path = ./storage/team/environment/ or ./storage/team/environment/roles/{role_class}_{role_name}/ """ + memory_path = stg_path.joinpath("memory.json") + + storage = [] + for message in self.storage: + # msg_dict = message.serialize() + msg_dict = serialize_general_message(message) + storage.append(msg_dict) + + write_json_file(memory_path, storage) + + @classmethod + def deserialize(cls, stg_path: Path) -> "Memory": + """ stg_path = ./storage/team/environment/ or ./storage/team/environment/roles/{role_class}_{role_name}/""" + memory_path = stg_path.joinpath("memory.json") + + memory = Memory() + memory_list = read_json_file(memory_path) + for message in memory_list: + # distinguish instruct_content type in message + # msg = Message.deserialize(message) + msg = deserialize_general_message(message) + memory.add(msg) + + return memory + def add(self, message: Message): """Add a new message to storage, while updating the index""" if message in self.storage: diff --git a/metagpt/roles/role.py b/metagpt/roles/role.py index b96c361c0..9b0613fd5 100644 --- a/metagpt/roles/role.py +++ b/metagpt/roles/role.py @@ -9,8 +9,9 @@ from __future__ import annotations from typing import Iterable, Type, Union from enum import Enum - +from pathlib import Path from pydantic import BaseModel, Field +import importlib # from metagpt.environment import Environment from metagpt.config import CONFIG @@ -19,6 +20,7 @@ from metagpt.llm import LLM, HumanProvider from metagpt.logs import logger from metagpt.memory import Memory, LongTermMemory from metagpt.schema import Message +from metagpt.utils.utils import read_json_file, write_json_file, import_class PREFIX_TEMPLATE = """You are a {profile}, named {name}, your goal is {goal}, and the constraint is {constraints}. """ @@ -115,11 +117,101 @@ class Role: self._actions = [] self._role_id = str(self._setting) self._rc = RoleContext() + self._recovered = False + + def serialize(self, stg_path: Path): + role_info_path = stg_path.joinpath("role_info.json") + role_info = { + "role_class": self.__class__.__name__, + "module_name": self.__module__ + } + setting = self._setting.dict() + setting.pop("desc") + setting.pop("is_human") # not all inherited roles have this atrr + role_info.update(setting) + write_json_file(role_info_path, role_info) + + actions_info_path = stg_path.joinpath("actions/actions_info.json") + actions_info = [] + for action in self._actions: + actions_info.append(action.serialize()) + write_json_file(actions_info_path, actions_info) + + watches_info_path = stg_path.joinpath("watches/watches_info.json") + watches_info = [] + for watch in self._rc.watch: + watches_info.append(watch.ser_class()) + write_json_file(watches_info_path, watches_info) + + actions_todo_path = stg_path.joinpath("actions/todo.json") + actions_todo = { + "cur_state": self._rc.state, + "react_mode": self._rc.react_mode.value, + "max_react_loop": self._rc.max_react_loop + } + write_json_file(actions_todo_path, actions_todo) + + self._rc.memory.serialize(stg_path) + + @classmethod + def deserialize(cls, stg_path: Path) -> "Role": + """ stg_path = ./storage/team/environment/roles/{role_class}_{role_name}""" + role_info_path = stg_path.joinpath("role_info.json") + role_info = read_json_file(role_info_path) + + role_class_str = role_info.pop("role_class") + module_name = role_info.pop("module_name") + role_class = import_class(class_name=role_class_str, module_name=module_name) + + role = role_class(**role_info) # initiate particular Role + actions_info_path = stg_path.joinpath("actions/actions_info.json") + actions = [] + actions_info = read_json_file(actions_info_path) + for action_info in actions_info: + action = Action.deserialize(action_info) + actions.append(action) + + watches_info_path = stg_path.joinpath("watches/watches_info.json") + watches = [] + watches_info = read_json_file(watches_info_path) + for watch_info in watches_info: + action = Action.deser_class(watch_info) + watches.append(action) + + role.init_actions(actions) + role.watch(watches) + + actions_todo_path = stg_path.joinpath("actions/todo.json") + # recover self._rc.state + actions_todo = read_json_file(actions_todo_path) + max_react_loop = actions_todo.get("max_react_loop", 1) + cur_state = actions_todo.get("cur_state", -1) + role.set_state(cur_state) + role.set_recovered(True) + react_mode_str = actions_todo.get("react_mode", RoleReactMode.REACT.value) + if react_mode_str not in RoleReactMode.values(): + logger.warning(f"ReactMode: {react_mode_str} not in {RoleReactMode.values()}, use react as default") + react_mode_str = RoleReactMode.REACT.value + role.set_react_mode(RoleReactMode(react_mode_str), max_react_loop) + + role_memory = Memory.deserialize(stg_path) + role.set_memory(role_memory) + + return role def _reset(self): self._states = [] self._actions = [] + def set_recovered(self, recovered: bool = False): + self._recovered = recovered + + def set_memory(self, memory: Memory): + self._rc.memory = memory + + def init_actions(self, actions): + self._init_actions(actions) + def _init_actions(self, actions): self._reset() for idx, action in enumerate(actions): @@ -134,6 +226,9 @@ class Role: self._actions.append(i) self._states.append(f"{idx}. {action}") + def set_react_mode(self, react_mode: RoleReactMode, max_react_loop: int = 1): + self._set_react_mode(react_mode, max_react_loop) + def _set_react_mode(self, react_mode: str, max_react_loop: int = 1): """Set strategy of the Role reacting to observed Message. Variation lies in how this Role elects action to perform during the _think stage, especially if it is capable of multiple Actions. @@ -155,12 +250,18 @@ class Role: if react_mode == RoleReactMode.REACT: self._rc.max_react_loop = max_react_loop + def watch(self, actions: Iterable[Type[Action]]): + self._watch(actions) + def _watch(self, actions: Iterable[Type[Action]]): """Listen to the corresponding behaviors""" self._rc.watch.update(actions) # check RoleContext after adding watch actions self._rc.check(self._role_id) + def set_state(self, state: int): + self._set_state(state) + def _set_state(self, state: int): """Update the current state.""" self._rc.state = state @@ -171,6 +272,10 @@ class Role: """Set the environment in which the role works. The role can talk to the environment and can also receive messages by observing.""" self._rc.env = env + @property + def name(self): + return self._setting.name + @property def profile(self): """Get the role description (position)""" @@ -188,6 +293,11 @@ class Role: # If there is only one action, then only this one can be performed self._set_state(0) return + if self._recovered and self._rc.state >= 0: + self._set_state(self._rc.state) # action to run from recovered state + self._recovered = False # avoid max_react_loop out of work + return + prompt = self._get_prefix() prompt += STATE_TEMPLATE.format(history=self._rc.history, states="\n".join(self._states), n_states=len(self._states) - 1, previous_state=self._rc.state) @@ -267,7 +377,8 @@ class Role: async def _act_by_order(self) -> Message: """switch action each time by order defined in _init_actions, i.e. _act (Action1) -> _act (Action2) -> ...""" - for i in range(len(self._states)): + start_idx = self._rc.state if self._rc.state >= 0 else 0 # action to run from recovered state + for i in range(start_idx, len(self._states)): self._set_state(i) rsp = await self._act() return rsp # return output from the last action diff --git a/metagpt/schema.py b/metagpt/schema.py index bdca093c2..3374a7241 100644 --- a/metagpt/schema.py +++ b/metagpt/schema.py @@ -9,10 +9,14 @@ from __future__ import annotations from dataclasses import dataclass, field from typing import Type, TypedDict +import copy from pydantic import BaseModel from metagpt.logs import logger +# from metagpt.utils.serialize import actionoutout_schema_to_mapping +# from metagpt.actions.action_output import ActionOutput +# from metagpt.actions.action import Action class RawMessage(TypedDict): @@ -38,6 +42,46 @@ class Message: def __repr__(self): return self.__str__() + # def serialize(self): + # message_cp: Message = copy.deepcopy(self) + # ic = message_cp.instruct_content + # if ic: + # # model create by pydantic create_model like `pydantic.main.prd`, can't pickle.dump directly + # schema = ic.schema() + # mapping = actionoutout_schema_to_mapping(schema) + # + # message_cp.instruct_content = {"class": schema["title"], "mapping": mapping, "value": ic.dict()} + # cb = message_cp.cause_by + # if cb: + # message_cp.cause_by = cb.serialize() + # + # return message_cp.dict() + # + # @classmethod + # def deserialize(cls, message_dict: dict): + # instruct_content = message_dict.get("instruct_content") + # if instruct_content: + # ic = instruct_content + # ic_obj = ActionOutput.create_model_class(class_name=ic["class"], mapping=ic["mapping"]) + # ic_new = ic_obj(**ic["value"]) + # message_dict.instruct_content = ic_new + # cause_by = message_dict.get("cause_by") + # if cause_by: + # message_dict.cause_by = Action.deserialize(cause_by) + # + # return Message(**message_dict) + + def dict(self): + return { + "content": self.content, + "instruct_content": self.instruct_content, + "role": self.role, + "cause_by": self.cause_by, + "sent_from": self.sent_from, + "send_to": self.send_to, + "restricted_to": self.restricted_to + } + def to_dict(self) -> dict: return { "role": self.role, diff --git a/metagpt/team.py b/metagpt/team.py index 67d3ecec8..3b76e5ff4 100644 --- a/metagpt/team.py +++ b/metagpt/team.py @@ -5,6 +5,7 @@ @Author : alexanderwu @File : software_company.py """ +from pathlib import Path from pydantic import BaseModel, Field from metagpt.actions import BossRequirement @@ -14,6 +15,7 @@ from metagpt.logs import logger from metagpt.roles import Role from metagpt.schema import Message from metagpt.utils.common import NoMoneyException +from metagpt.utils.utils import read_json_file, write_json_file class Team(BaseModel): @@ -28,6 +30,30 @@ class Team(BaseModel): class Config: arbitrary_types_allowed = True + def serialize(self, stg_path: Path): + team_info_path = stg_path.joinpath("team_info.json") + write_json_file(team_info_path, { + "idea": self.idea, + "investment": self.investment + }) + + self.environment.serialize(stg_path.joinpath("environment")) + + def deserialize(self, stg_path: Path): + """ stg_path = ./storage/team """ + # recover team_info + team_info_path = stg_path.joinpath("team_info.json") + if not team_info_path.exists(): + logger.error("recover storage not exist, not to recover and continue run the old project.") + team_info = read_json_file(team_info_path) + self.investment = team_info.get("investment", 10.0) + self.idea = team_info.get("idea", "") + + # recover environment + environment_path = stg_path.joinpath("environment") + self.environment = Environment() + self.environment.deserialize(stg_path=environment_path) + def hire(self, roles: list[Role]): """Hire roles to cooperate""" self.environment.add_roles(roles) diff --git a/metagpt/utils/serialize.py b/metagpt/utils/serialize.py index 124176fcb..56a866f2e 100644 --- a/metagpt/utils/serialize.py +++ b/metagpt/utils/serialize.py @@ -4,13 +4,13 @@ import copy import pickle -from typing import Dict, List from metagpt.actions.action_output import ActionOutput from metagpt.schema import Message +from metagpt.actions.action import Action -def actionoutout_schema_to_mapping(schema: Dict) -> Dict: +def actionoutout_schema_to_mapping(schema: dict) -> dict: """ directly traverse the `properties` in the first level. schema structure likes @@ -35,13 +35,47 @@ def actionoutout_schema_to_mapping(schema: Dict) -> Dict: if property["type"] == "string": mapping[field] = (str, ...) elif property["type"] == "array" and property["items"]["type"] == "string": - mapping[field] = (List[str], ...) + mapping[field] = (list[str], ...) elif property["type"] == "array" and property["items"]["type"] == "array": - # here only consider the `List[List[str]]` situation - mapping[field] = (List[List[str]], ...) + # here only consider the `list[list[str]]` situation + mapping[field] = (list[list[str]], ...) return mapping +def actionoutput_mapping_to_str(mapping: dict) -> dict: + new_mapping = {} + for key, value in mapping.items(): + new_mapping[key] = str(value) + return new_mapping + + +def actionoutput_str_to_mapping(mapping: dict) -> dict: + new_mapping = {} + for key, value in mapping.items(): + if value == "(, Ellipsis)": + new_mapping[key] = (str, ...) + else: + new_mapping[key] = eval(value) # `"'(list[str], Ellipsis)"` to `(list[str], ...)` + return new_mapping + + +def serialize_general_message(message: Message) -> dict: + """ serialize Message, not to save""" + message_cp = copy.deepcopy(message) + ic = message_cp.instruct_content + if ic: + # model create by pydantic create_model like `pydantic.main.prd`, can't pickle.dump directly + schema = ic.schema() + mapping = actionoutout_schema_to_mapping(schema) + mapping = actionoutput_mapping_to_str(mapping) + + message_cp.instruct_content = {"class": schema["title"], "mapping": mapping, "value": ic.dict()} + cb = message_cp.cause_by + if cb: + message_cp.cause_by = cb.ser_class() + return message_cp.dict() + + def serialize_message(message: Message): message_cp = copy.deepcopy(message) # avoid `instruct_content` value update by reference ic = message_cp.instruct_content @@ -56,6 +90,24 @@ def serialize_message(message: Message): return msg_ser +def deserialize_general_message(message_dict: dict) -> Message: + """ deserialize Message, not to load""" + instruct_content = message_dict.pop("instruct_content") + cause_by = message_dict.pop("cause_by") + + message = Message(**message_dict) + if instruct_content: + ic = instruct_content + mapping = actionoutput_str_to_mapping(ic["mapping"]) + ic_obj = ActionOutput.create_model_class(class_name=ic["class"], mapping=mapping) + ic_new = ic_obj(**ic["value"]) + message.instruct_content = ic_new + if cause_by: + message.cause_by = Action.deser_class(cause_by) + + return message + + def deserialize_message(message_ser: str) -> Message: message = pickle.loads(message_ser) if message.instruct_content: diff --git a/metagpt/utils/utils.py b/metagpt/utils/utils.py new file mode 100644 index 000000000..81ceea884 --- /dev/null +++ b/metagpt/utils/utils.py @@ -0,0 +1,41 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : + +from typing import Any +import json +from pathlib import Path +import importlib + + +def read_json_file(json_file: str, encoding=None) -> list[Any]: + if not Path(json_file).exists(): + raise FileNotFoundError(f"json_file: {json_file} not exist, return []") + + with open(json_file, "r", encoding=encoding) as fin: + try: + data = json.load(fin) + except Exception as exp: + raise ValueError(f"read json file: {json_file} failed") + return data + + +def write_json_file(json_file: str, data: list, encoding=None): + folder_path = Path(json_file).parent + if not folder_path.exists(): + folder_path.mkdir(parents=True, exist_ok=True) + + with open(json_file, "w", encoding=encoding) as fout: + json.dump(data, fout, ensure_ascii=False, indent=4) + + +def import_class(class_name: str, module_name: str) -> type: + module = importlib.import_module(module_name) + a_class = getattr(module, class_name) + return a_class + + +def import_class_inst(class_name: str, module_name: str, *args, **kwargs) -> object: + a_class = import_class(class_name, module_name) + class_inst = a_class(*args, **kwargs) + return class_inst diff --git a/startup.py b/startup.py index e9fbf94d3..9f753d553 100644 --- a/startup.py +++ b/startup.py @@ -4,6 +4,7 @@ import asyncio import fire +from metagpt.const import SERDES_PATH from metagpt.roles import ( Architect, Engineer, @@ -21,26 +22,32 @@ async def startup( code_review: bool = False, run_tests: bool = False, implement: bool = True, + recover_path: bool = False, ): """Run a startup. Be a boss.""" company = Team() - company.hire( - [ - ProductManager(), - Architect(), - ProjectManager(), - ] - ) + if not recover_path: + company.hire( + [ + ProductManager(), + Architect(), + ProjectManager(), + ] + ) - # if implement or code_review - if implement or code_review: - # developing features: implement the idea - company.hire([Engineer(n_borg=5, use_code_review=code_review)]) + # if implement or code_review + if implement or code_review: + # developing features: implement the idea + company.hire([Engineer(n_borg=5, use_code_review=code_review)]) - if run_tests: - # developing features: run tests on the spot and identify bugs - # (bug fixing capability comes soon!) - company.hire([QaEngineer()]) + if run_tests: + # developing features: run tests on the spot and identify bugs + # (bug fixing capability comes soon!) + company.hire([QaEngineer()]) + else: + stg_path = SERDES_PATH.joinpath("team") + company.deserialize(stg_path=stg_path) + idea = company.idea # use original idea company.invest(investment) company.start_project(idea) @@ -54,6 +61,7 @@ def main( code_review: bool = True, run_tests: bool = False, implement: bool = True, + recover_path: str = None, ): """ We are a software startup comprised of AI. By investing in us, @@ -63,9 +71,10 @@ def main( a certain dollar amount to this AI company. :param n_round: :param code_review: Whether to use code review. + :param recover_path: recover the project from existing serialized storage :return: """ - asyncio.run(startup(idea, investment, n_round, code_review, run_tests, implement)) + asyncio.run(startup(idea, investment, n_round, code_review, run_tests, implement, recover_path)) if __name__ == "__main__": diff --git a/tests/metagpt/actions/test_action.py b/tests/metagpt/actions/test_action.py index 9775630cc..4468a6f6f 100644 --- a/tests/metagpt/actions/test_action.py +++ b/tests/metagpt/actions/test_action.py @@ -11,3 +11,20 @@ from metagpt.actions import Action, WritePRD, WriteTest def test_action_repr(): actions = [Action(), WriteTest(), WritePRD()] assert "WriteTest" in str(actions) + + +def test_action_serdes(): + action_info = WriteTest.ser_class() + assert action_info["action_class"] == "WriteTest" + + action_class = Action.deser_class(action_info) + assert action_class == WriteTest + + +def test_action_class_serdes(): + name = "write test" + action_info = WriteTest(name=name).serialize() + assert action_info["name"] == name + + action = Action.deserialize(action_info) + assert action.name == name diff --git a/tests/metagpt/memory/test_memory.py b/tests/metagpt/memory/test_memory.py new file mode 100644 index 000000000..bda79ded1 --- /dev/null +++ b/tests/metagpt/memory/test_memory.py @@ -0,0 +1,42 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : unittest of memory + +from pathlib import Path + +from metagpt.schema import Message +from metagpt.memory.memory import Memory +from metagpt.actions.action_output import ActionOutput +from metagpt.actions.design_api import WriteDesign +from metagpt.actions.add_requirement import BossRequirement + +serdes_path = Path(__file__).absolute().parent.joinpath("../../data/serdes_storage") + + +def test_memory_serdes(): + msg1 = Message(role="User", + content="write a 2048 game", + cause_by=BossRequirement) + + out_mapping = {"field1": (list[str], ...)} + out_data = {"field1": ["field1 value1", "field1 value2"]} + ic_obj = ActionOutput.create_model_class("system_design", out_mapping) + msg2 = Message(role="Architect", + instruct_content=ic_obj(**out_data), + content="system design content", + cause_by=WriteDesign) + + memory = Memory() + memory.add_batch([msg1, msg2]) + + stg_path = serdes_path.joinpath("team/environment") + memory.serialize(stg_path) + assert stg_path.joinpath("memory.json").exists() + + new_memory = Memory.deserialize(stg_path) + assert new_memory.count() == 2 + new_msg2 = new_memory.get(1)[0] + assert new_msg2.instruct_content.field1 == ["field1 value1", "field1 value2"] + assert new_msg2.cause_by == WriteDesign + + stg_path.joinpath("memory.json").unlink() diff --git a/tests/metagpt/roles/test_role.py b/tests/metagpt/roles/test_role.py new file mode 100644 index 000000000..a19ad9cb5 --- /dev/null +++ b/tests/metagpt/roles/test_role.py @@ -0,0 +1,85 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : unittest of Role + +from pathlib import Path +import shutil +import pytest + +from metagpt.roles.role import Role, RoleReactMode +from metagpt.actions.action import Action +from metagpt.schema import Message +from metagpt.actions.add_requirement import BossRequirement +from metagpt.roles.product_manager import ProductManager + +serdes_path = Path(__file__).absolute().parent.joinpath("../../data/serdes_storage") + + +def test_role_serdes(): + stg_path_prefix = serdes_path.joinpath("team/environment/roles/") + shutil.rmtree(serdes_path.joinpath("team"), ignore_errors=True) + + pm = ProductManager() + role_tag = f"{pm.__class__.__name__}_{pm.name}" + stg_path = stg_path_prefix.joinpath(role_tag) + pm.serialize(stg_path) + assert stg_path.joinpath("actions/actions_info.json").exists() + + new_pm = Role.deserialize(stg_path) + assert new_pm.name == pm.name + assert len(new_pm.get_memories(1)) == 0 + + +class ActionOK(Action): + + async def run(self, messages: list["Message"]): + return "ok" + + +class ActionRaise(Action): + + async def run(self, messages: list["Message"]): + raise RuntimeError("parse error") + + +class RoleA(Role): + + def __init__(self, + name: str = "RoleA", + profile: str = "Role A", + goal: str = "", + constraints: str = ""): + super(RoleA, self).__init__(name=name, profile=profile, goal=goal, constraints=constraints) + self._init_actions([ActionOK, ActionRaise]) + self._watch([BossRequirement]) + self._rc.react_mode = RoleReactMode.BY_ORDER + + async def run(self, message: "Message" = None, stg_path: str = None): + try: + await super(RoleA, self).run(message) + except Exception as exp: + print("exp ", exp) + self.serialize(stg_path) + + +@pytest.mark.asyncio +async def test_role_serdes_interrupt(): + role_a = RoleA() + shutil.rmtree(serdes_path.joinpath("team"), ignore_errors=True) + + stg_path = serdes_path.joinpath(f"team/environment/roles/{role_a.__class__.__name__}_{role_a.name}") + await role_a.run( + message=Message(content="demo", cause_by=BossRequirement), + stg_path=stg_path + ) + assert role_a._rc.memory.count() == 2 + + assert stg_path.joinpath("actions/todo.json").exists() + + new_role_a: Role = Role.deserialize(stg_path) + assert new_role_a._rc.state == 1 + await role_a.run( + message=Message(content="demo", cause_by=BossRequirement), + stg_path=stg_path + ) + diff --git a/tests/metagpt/test_environment.py b/tests/metagpt/test_environment.py index a0f1f6257..3cc2d8a7a 100644 --- a/tests/metagpt/test_environment.py +++ b/tests/metagpt/test_environment.py @@ -7,13 +7,18 @@ """ import pytest +from pathlib import Path +import shutil from metagpt.actions import BossRequirement from metagpt.environment import Environment from metagpt.logs import logger -from metagpt.manager import Manager from metagpt.roles import Architect, ProductManager, Role from metagpt.schema import Message +from tests.metagpt.roles.test_role import RoleA + + +serdes_path = Path(__file__).absolute().parent.joinpath("../data/serdes_storage") @pytest.fixture @@ -36,21 +41,29 @@ def test_get_roles(env: Environment): assert roles == {role1.profile: role1, role2.profile: role2} -def test_set_manager(env: Environment): - manager = Manager() - env.set_manager(manager) - assert env.manager == manager - - @pytest.mark.asyncio async def test_publish_and_process_message(env: Environment): product_manager = ProductManager("Alice", "Product Manager", "做AI Native产品", "资源有限") architect = Architect("Bob", "Architect", "设计一个可用、高效、较低成本的系统,包括数据结构与接口", "资源有限,需要节省成本") env.add_roles([product_manager, architect]) - env.set_manager(Manager()) env.publish_message(Message(role="BOSS", content="需要一个基于LLM做总结的搜索引擎", cause_by=BossRequirement)) await env.run(k=2) logger.info(f"{env.history=}") assert len(env.history) > 10 + + +def test_environment_serdes(): + environment = Environment() + role_a = RoleA() + + shutil.rmtree(serdes_path.joinpath("team"), ignore_errors=True) + + stg_path = serdes_path.joinpath("team/environment") + environment.add_role(role_a) + environment.serialize(stg_path) + + new_env: Environment = Environment() + new_env.deserialize(stg_path) + assert len(new_env.roles) == 1 diff --git a/tests/metagpt/test_schema.py b/tests/metagpt/test_schema.py index 12666e0d3..f515326e8 100644 --- a/tests/metagpt/test_schema.py +++ b/tests/metagpt/test_schema.py @@ -5,7 +5,11 @@ @Author : alexanderwu @File : test_schema.py """ + from metagpt.schema import AIMessage, Message, SystemMessage, UserMessage +from metagpt.actions.action_output import ActionOutput +from metagpt.actions.write_code import WriteCode +from metagpt.utils.serialize import serialize_general_message, deserialize_general_message def test_messages(): @@ -19,3 +23,41 @@ def test_messages(): text = str(msgs) roles = ['user', 'system', 'assistant', 'QA'] assert all([i in text for i in roles]) + + +def test_message_serdes(): + out_mapping = {"field3": (str, ...), "field4": (list[str], ...)} + out_data = {"field3": "field3 value3", "field4": ["field4 value1", "field4 value2"]} + ic_obj = ActionOutput.create_model_class("code", out_mapping) + + message = Message( + content="code", + instruct_content=ic_obj(**out_data), + role="engineer", + cause_by=WriteCode + ) + message_dict = serialize_general_message(message) + assert message_dict["cause_by"] == {"action_class": "WriteCode"} + assert message_dict["instruct_content"] == { + "class": "code", + "mapping": { + "field3": "(, Ellipsis)", + "field4": "(list[str], Ellipsis)" + }, + "value": { + "field3": "field3 value3", + "field4": ["field4 value1", "field4 value2"] + } + } + + new_message = deserialize_general_message(message_dict) + assert new_message.content == message.content + assert new_message.instruct_content == message.instruct_content + assert new_message.cause_by == message.cause_by + assert new_message.instruct_content.field3 == out_data["field3"] + + message = Message(content="code") + message_dict = serialize_general_message(message) + new_message = deserialize_general_message(message_dict) + assert new_message.instruct_content is None + assert new_message.cause_by == "" diff --git a/tests/metagpt/test_team.py b/tests/metagpt/test_team.py new file mode 100644 index 000000000..ab201152c --- /dev/null +++ b/tests/metagpt/test_team.py @@ -0,0 +1,27 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : unittest of team + +from pathlib import Path +import shutil + +from metagpt.team import Team + +from tests.metagpt.roles.test_role import RoleA + +serdes_path = Path(__file__).absolute().parent.joinpath("../data/serdes_storage") + + +def test_team_serdes(): + company = Team() + company.hire([RoleA()]) + + stg_path = serdes_path.joinpath("team") + shutil.rmtree(stg_path, ignore_errors=True) + + company.serialize(stg_path=stg_path) + + new_company = Team() + new_company.deserialize(stg_path) + + assert len(new_company.environment.roles) == 1 From 39e4aa98ab6101ee1016cc8584c4f36977498077 Mon Sep 17 00:00:00 2001 From: better629 Date: Tue, 28 Nov 2023 10:47:19 +0800 Subject: [PATCH 04/14] fix role and format ut of serialize_deserialize --- metagpt/roles/role.py | 29 +++++-------------- .../serialize_deserialize/test_actions.py | 2 ++ .../test_architect_deserialize.py | 2 ++ .../test_product_manager.py | 1 + .../test_project_manager.py | 2 ++ .../serialize_deserialize/test_role.py | 2 +- .../serialize_deserialize/test_wrire_prd.py | 4 +-- .../serialize_deserialize/test_write_code.py | 6 +++- .../test_write_design.py | 6 +++- 9 files changed, 27 insertions(+), 27 deletions(-) diff --git a/metagpt/roles/role.py b/metagpt/roles/role.py index eb5539f43..e9371c2c0 100644 --- a/metagpt/roles/role.py +++ b/metagpt/roles/role.py @@ -6,16 +6,11 @@ @File : role.py """ -import sys from enum import Enum -import importlib +from pathlib import Path from __future__ import annotations -from types import SimpleNamespace from typing import ( - Dict, - Optional, - Union, Iterable, Type ) @@ -30,6 +25,7 @@ from metagpt.llm import LLM from metagpt.logs import logger from metagpt.memory import Memory, LongTermMemory from metagpt.schema import Message +from metagpt.provider.human_provider import HumanProvider from metagpt.utils.utils import read_json_file, write_json_file, import_class PREFIX_TEMPLATE = """You are a {profile}, named {name}, your goal is {goal}, and the constraint is {constraints}. """ @@ -133,11 +129,11 @@ class Role(BaseModel): _rc: RoleContext = RoleContext() _private_attributes = { - "_setting': _setting, - "_role_id': _role_id, - '_states': [], - '_actions': [], - '_actions_type': [] # 用于记录和序列化 + "_setting": _setting, + "_role_id": _role_id, + "_states": [], + "_actions": [], + "_actions_type": [] # 用于记录和序列化 } class Config: @@ -162,17 +158,6 @@ class Role(BaseModel): object.__setattr__(self, '_states', []) object.__setattr__(self, '_actions', []) - @staticmethod - def _process_class(class_str, module_name): - cleaned_string = re.sub(r"[<>']", "", class_str).replace("class ", "") - package_name = "metagpt" - file_name = cleaned_string.replace(package_name, "").replace("." + module_name, "") - print(file_name) - # print("\n", sys.modules) - module_file = import_module(file_name, package=package_name) - module = getattr(module_file, module_name) - return module - def serialize(self, stg_path: Path): role_info_path = stg_path.joinpath("role_info.json") role_info = { diff --git a/tests/metagpt/serialize_deserialize/test_actions.py b/tests/metagpt/serialize_deserialize/test_actions.py index e2efa982b..2fec2121a 100644 --- a/tests/metagpt/serialize_deserialize/test_actions.py +++ b/tests/metagpt/serialize_deserialize/test_actions.py @@ -7,12 +7,14 @@ import pytest from metagpt.actions import Action from metagpt.llm import LLM + def test_action_serialize(): action = Action() ser_action_dict = action.dict() assert "name" in ser_action_dict assert "llm" in ser_action_dict + @pytest.mark.asyncio async def test_action_deserialize(): action = Action() diff --git a/tests/metagpt/serialize_deserialize/test_architect_deserialize.py b/tests/metagpt/serialize_deserialize/test_architect_deserialize.py index cff1bbadd..d0ee3bc99 100644 --- a/tests/metagpt/serialize_deserialize/test_architect_deserialize.py +++ b/tests/metagpt/serialize_deserialize/test_architect_deserialize.py @@ -7,6 +7,7 @@ import pytest from metagpt.roles.architect import Architect from metagpt.actions.action import Action + def test_architect_serialize(): role = Architect() ser_role_dict = role.dict(by_alias=True) @@ -14,6 +15,7 @@ def test_architect_serialize(): assert "_states" in ser_role_dict assert "_actions" in ser_role_dict + @pytest.mark.asyncio async def test_architect_deserialize(): role = Architect() diff --git a/tests/metagpt/serialize_deserialize/test_product_manager.py b/tests/metagpt/serialize_deserialize/test_product_manager.py index 978c50e5e..2aed87a28 100644 --- a/tests/metagpt/serialize_deserialize/test_product_manager.py +++ b/tests/metagpt/serialize_deserialize/test_product_manager.py @@ -8,6 +8,7 @@ from metagpt.roles.product_manager import ProductManager from metagpt.actions.action import Action from metagpt.schema import Message + @pytest.mark.asyncio async def test_product_manager_deserialize(): role = ProductManager() diff --git a/tests/metagpt/serialize_deserialize/test_project_manager.py b/tests/metagpt/serialize_deserialize/test_project_manager.py index 590bd8109..fbc0dcc08 100644 --- a/tests/metagpt/serialize_deserialize/test_project_manager.py +++ b/tests/metagpt/serialize_deserialize/test_project_manager.py @@ -7,6 +7,7 @@ import pytest from metagpt.roles.project_manager import ProjectManager from metagpt.actions.action import Action + def test_project_manager_serialize(): role = ProjectManager() ser_role_dict = role.dict(by_alias=True) @@ -14,6 +15,7 @@ def test_project_manager_serialize(): assert "_states" in ser_role_dict assert "_actions" in ser_role_dict + @pytest.mark.asyncio async def test_project_manager_deserialize(): role = ProjectManager() diff --git a/tests/metagpt/serialize_deserialize/test_role.py b/tests/metagpt/serialize_deserialize/test_role.py index 432c9acb7..0e438d1a2 100644 --- a/tests/metagpt/serialize_deserialize/test_role.py +++ b/tests/metagpt/serialize_deserialize/test_role.py @@ -34,7 +34,7 @@ async def test_engineer_deserialize(): # also can be deserialized in this way: new_role = Engineer(**ser_role_dict) assert new_role.name == "Alex" - assert new_role.use_code_review == True + assert new_role.use_code_review is True assert len(new_role._actions) == 2 assert isinstance(new_role._actions[0], Action) assert isinstance(new_role._actions[1], Action) diff --git a/tests/metagpt/serialize_deserialize/test_wrire_prd.py b/tests/metagpt/serialize_deserialize/test_wrire_prd.py index 9b2653820..baa08ed76 100644 --- a/tests/metagpt/serialize_deserialize/test_wrire_prd.py +++ b/tests/metagpt/serialize_deserialize/test_wrire_prd.py @@ -24,5 +24,5 @@ async def test_action_deserialize(): # new_action = WritePRD().deserialize(serialized_data) assert new_action.name == "" assert new_action.llm == LLM() - assert len(await new_action.run([Message(content="write a cli snake game")]))>0 - + assert len(await new_action.run([Message(content="write a cli snake game")])) > 0 + diff --git a/tests/metagpt/serialize_deserialize/test_write_code.py b/tests/metagpt/serialize_deserialize/test_write_code.py index 0b1f1dc7c..9d659caaf 100644 --- a/tests/metagpt/serialize_deserialize/test_write_code.py +++ b/tests/metagpt/serialize_deserialize/test_write_code.py @@ -7,18 +7,21 @@ import pytest from metagpt.actions import WriteCode, WriteCodeReview from metagpt.llm import LLM + def test_write_design_serialize(): action = WriteCode() ser_action_dict = action.dict() assert ser_action_dict["name"] == "WriteCode" assert "llm" in ser_action_dict + def test_write_task_serialize(): action = WriteCodeReview() ser_action_dict = action.dict() assert ser_action_dict["name"] == "WriteCodeReview" assert "llm" in ser_action_dict - + + @pytest.mark.asyncio async def test_write_code_deserialize(): action = WriteCode() @@ -29,6 +32,7 @@ async def test_write_code_deserialize(): assert new_action.llm == LLM() await new_action.run(context="write a cli snake game", filename="test_code") + @pytest.mark.asyncio async def test_write_code_review_deserialize(): action = WriteCodeReview() diff --git a/tests/metagpt/serialize_deserialize/test_write_design.py b/tests/metagpt/serialize_deserialize/test_write_design.py index 56bf78a63..e6e236676 100644 --- a/tests/metagpt/serialize_deserialize/test_write_design.py +++ b/tests/metagpt/serialize_deserialize/test_write_design.py @@ -7,18 +7,21 @@ import pytest from metagpt.actions import WriteDesign, WriteTasks from metagpt.llm import LLM + def test_write_design_serialize(): action = WriteDesign() ser_action_dict = action.dict() assert "name" in ser_action_dict assert "llm" in ser_action_dict + def test_write_task_serialize(): action = WriteTasks() ser_action_dict = action.dict() assert "name" in ser_action_dict assert "llm" in ser_action_dict + @pytest.mark.asyncio async def test_write_design_deserialize(): action = WriteDesign() @@ -28,6 +31,7 @@ async def test_write_design_deserialize(): assert new_action.llm == LLM() await new_action.run(context="write a cli snake game") + @pytest.mark.asyncio async def test_write_task_deserialize(): action = WriteTasks() @@ -36,4 +40,4 @@ async def test_write_task_deserialize(): # new_action = WriteTasks().deserialize(serialized_data) assert new_action.name == "CreateTasks" assert new_action.llm == LLM() - await new_action.run(context="write a cli snake game") \ No newline at end of file + await new_action.run(context="write a cli snake game") From 9e5c873d77754f24a7b36be0e697975d30efed04 Mon Sep 17 00:00:00 2001 From: better629 Date: Thu, 30 Nov 2023 15:10:38 +0800 Subject: [PATCH 05/14] update unittest of ser&deser --- tests/metagpt/actions/test_action.py | 17 --- tests/metagpt/roles/test_role.py | 84 +----------- .../serialize_deserialize/test_action.py | 49 +++++++ .../serialize_deserialize/test_actions.py | 26 ---- .../test_architect_deserialize.py | 2 +- .../serialize_deserialize/test_environment.py | 91 +++++++++++++ .../test_memory.py | 34 ++++- .../test_product_manager.py | 4 +- .../test_project_manager.py | 6 +- .../serialize_deserialize/test_role.py | 63 ++++++++- .../serialize_deserialize/test_schema.py | 49 +++++++ .../test_serdeser_base.py | 88 +++++++++++++ .../serialize_deserialize/test_team.py | 124 +++++++++++++----- .../serialize_deserialize/test_wrire_prd.py | 1 - .../serialize_deserialize/test_write_code.py | 2 +- tests/metagpt/test_environment.py | 44 +++---- tests/metagpt/test_role.py | 14 -- tests/metagpt/test_schema.py | 4 +- tests/metagpt/test_team.py | 22 +--- 19 files changed, 496 insertions(+), 228 deletions(-) create mode 100644 tests/metagpt/serialize_deserialize/test_action.py delete mode 100644 tests/metagpt/serialize_deserialize/test_actions.py create mode 100644 tests/metagpt/serialize_deserialize/test_environment.py rename tests/metagpt/{memory => serialize_deserialize}/test_memory.py (52%) create mode 100644 tests/metagpt/serialize_deserialize/test_schema.py create mode 100644 tests/metagpt/serialize_deserialize/test_serdeser_base.py delete mode 100644 tests/metagpt/test_role.py diff --git a/tests/metagpt/actions/test_action.py b/tests/metagpt/actions/test_action.py index 4468a6f6f..9775630cc 100644 --- a/tests/metagpt/actions/test_action.py +++ b/tests/metagpt/actions/test_action.py @@ -11,20 +11,3 @@ from metagpt.actions import Action, WritePRD, WriteTest def test_action_repr(): actions = [Action(), WriteTest(), WritePRD()] assert "WriteTest" in str(actions) - - -def test_action_serdes(): - action_info = WriteTest.ser_class() - assert action_info["action_class"] == "WriteTest" - - action_class = Action.deser_class(action_info) - assert action_class == WriteTest - - -def test_action_class_serdes(): - name = "write test" - action_info = WriteTest(name=name).serialize() - assert action_info["name"] == name - - action = Action.deserialize(action_info) - assert action.name == name diff --git a/tests/metagpt/roles/test_role.py b/tests/metagpt/roles/test_role.py index a19ad9cb5..72cd84a9a 100644 --- a/tests/metagpt/roles/test_role.py +++ b/tests/metagpt/roles/test_role.py @@ -2,84 +2,10 @@ # -*- coding: utf-8 -*- # @Desc : unittest of Role -from pathlib import Path -import shutil -import pytest - -from metagpt.roles.role import Role, RoleReactMode -from metagpt.actions.action import Action -from metagpt.schema import Message -from metagpt.actions.add_requirement import BossRequirement -from metagpt.roles.product_manager import ProductManager - -serdes_path = Path(__file__).absolute().parent.joinpath("../../data/serdes_storage") +from metagpt.roles.role import Role -def test_role_serdes(): - stg_path_prefix = serdes_path.joinpath("team/environment/roles/") - shutil.rmtree(serdes_path.joinpath("team"), ignore_errors=True) - - pm = ProductManager() - role_tag = f"{pm.__class__.__name__}_{pm.name}" - stg_path = stg_path_prefix.joinpath(role_tag) - pm.serialize(stg_path) - assert stg_path.joinpath("actions/actions_info.json").exists() - - new_pm = Role.deserialize(stg_path) - assert new_pm.name == pm.name - assert len(new_pm.get_memories(1)) == 0 - - -class ActionOK(Action): - - async def run(self, messages: list["Message"]): - return "ok" - - -class ActionRaise(Action): - - async def run(self, messages: list["Message"]): - raise RuntimeError("parse error") - - -class RoleA(Role): - - def __init__(self, - name: str = "RoleA", - profile: str = "Role A", - goal: str = "", - constraints: str = ""): - super(RoleA, self).__init__(name=name, profile=profile, goal=goal, constraints=constraints) - self._init_actions([ActionOK, ActionRaise]) - self._watch([BossRequirement]) - self._rc.react_mode = RoleReactMode.BY_ORDER - - async def run(self, message: "Message" = None, stg_path: str = None): - try: - await super(RoleA, self).run(message) - except Exception as exp: - print("exp ", exp) - self.serialize(stg_path) - - -@pytest.mark.asyncio -async def test_role_serdes_interrupt(): - role_a = RoleA() - shutil.rmtree(serdes_path.joinpath("team"), ignore_errors=True) - - stg_path = serdes_path.joinpath(f"team/environment/roles/{role_a.__class__.__name__}_{role_a.name}") - await role_a.run( - message=Message(content="demo", cause_by=BossRequirement), - stg_path=stg_path - ) - assert role_a._rc.memory.count() == 2 - - assert stg_path.joinpath("actions/todo.json").exists() - - new_role_a: Role = Role.deserialize(stg_path) - assert new_role_a._rc.state == 1 - await role_a.run( - message=Message(content="demo", cause_by=BossRequirement), - stg_path=stg_path - ) - +def test_role_desc(): + role = Role(profile="Sales", desc="Best Seller") + assert role.profile == "Sales" + assert role._setting.desc == "Best Seller" diff --git a/tests/metagpt/serialize_deserialize/test_action.py b/tests/metagpt/serialize_deserialize/test_action.py new file mode 100644 index 000000000..b624dff5a --- /dev/null +++ b/tests/metagpt/serialize_deserialize/test_action.py @@ -0,0 +1,49 @@ +# -*- coding: utf-8 -*- +# @Date : 11/22/2023 11:48 AM +# @Author : stellahong (stellahong@fuzhi.ai) +# @Desc : +import pytest + +from metagpt.actions import Action, WritePRD, WriteTest +from metagpt.llm import LLM +from metagpt.provider.openai_api import OpenAIGPTAPI + + +def test_action_serialize(): + action = Action() + ser_action_dict = action.dict() + assert "name" in ser_action_dict + assert "llm" in ser_action_dict + + +@pytest.mark.asyncio +async def test_action_deserialize(): + action = Action() + serialized_data = action.dict() + assert isinstance(serialized_data["llm"], OpenAIGPTAPI) + + new_action = Action(**serialized_data) + + assert new_action.name == "" + assert new_action.llm == LLM() + assert len(await new_action._aask("who are you")) > 0 + + +def test_action_serdeser(): + action_info = WriteTest.ser_class() + assert action_info["action_class"] == "WriteTest" + + action_class = Action.deser_class(action_info) + assert action_class == WriteTest + + +def test_action_class_serdeser(): + name = "write test" + action_info = WriteTest(name=name).serialize() + assert action_info["name"] == name + + action_info = WriteTest(name=name, llm=LLM()).serialize() + assert action_info["name"] == name + + action = Action.deserialize(action_info) + assert action.name == name diff --git a/tests/metagpt/serialize_deserialize/test_actions.py b/tests/metagpt/serialize_deserialize/test_actions.py deleted file mode 100644 index 2fec2121a..000000000 --- a/tests/metagpt/serialize_deserialize/test_actions.py +++ /dev/null @@ -1,26 +0,0 @@ -# -*- coding: utf-8 -*- -# @Date : 11/22/2023 11:48 AM -# @Author : stellahong (stellahong@fuzhi.ai) -# @Desc : -import pytest - -from metagpt.actions import Action -from metagpt.llm import LLM - - -def test_action_serialize(): - action = Action() - ser_action_dict = action.dict() - assert "name" in ser_action_dict - assert "llm" in ser_action_dict - - -@pytest.mark.asyncio -async def test_action_deserialize(): - action = Action() - serialized_data = action.dict() - - new_action = Action(**serialized_data) - assert new_action.name == "" - assert new_action.llm == LLM() - assert len(await new_action._aask("who are you")) > 0 diff --git a/tests/metagpt/serialize_deserialize/test_architect_deserialize.py b/tests/metagpt/serialize_deserialize/test_architect_deserialize.py index d0ee3bc99..fb58f0a3a 100644 --- a/tests/metagpt/serialize_deserialize/test_architect_deserialize.py +++ b/tests/metagpt/serialize_deserialize/test_architect_deserialize.py @@ -25,4 +25,4 @@ async def test_architect_deserialize(): assert new_role.name == "Bob" assert len(new_role._actions) == 1 assert isinstance(new_role._actions[0], Action) - await new_role._actions[0].run(context="write a cli snake game") \ No newline at end of file + await new_role._actions[0].run(context="write a cli snake game") diff --git a/tests/metagpt/serialize_deserialize/test_environment.py b/tests/metagpt/serialize_deserialize/test_environment.py new file mode 100644 index 000000000..15336eb6a --- /dev/null +++ b/tests/metagpt/serialize_deserialize/test_environment.py @@ -0,0 +1,91 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : + +from pathlib import Path +import shutil + +from metagpt.schema import Message +from metagpt.actions.action_output import ActionOutput +from metagpt.roles.project_manager import ProjectManager +from metagpt.actions.add_requirement import BossRequirement +from metagpt.actions.project_management import WriteTasks +from metagpt.environment import Environment +from tests.metagpt.serialize_deserialize.test_serdeser_base import RoleC, ActionOK, serdeser_path + + +def test_env_serialize(): + env = Environment() + ser_env_dict = env.dict() + assert "roles" in ser_env_dict + assert "memory" in ser_env_dict + + +def test_env_deserialize(): + env = Environment() + env.publish_message(message=Message(content="test env serialize")) + ser_env_dict = env.dict() + new_env = Environment(**ser_env_dict) + assert len(new_env.roles) == 0 + assert new_env.memory.storage[0].content == "test env serialize" + assert len(new_env.history) == 25 + + +def test_environment_serdeser(): + out_mapping = {"field1": (list[str], ...)} + out_data = {"field1": ["field1 value1", "field1 value2"]} + ic_obj = ActionOutput.create_model_class("prd", out_mapping) + + message = Message( + content="prd", + instruct_content=ic_obj(**out_data), + role="product manager", + cause_by=BossRequirement + ) + + environment = Environment() + role_c = RoleC() + environment.add_role(role_c) + environment.publish_message(message) + + ser_data = environment.dict() + assert ser_data["roles"]["Role C"]["name"] == "RoleC" + + new_env: Environment = Environment(**ser_data) + assert len(new_env.roles) == 1 + + assert new_env.memory.count() == 1 + assert new_env.memory.storage[0].instruct_content == ic_obj(**out_data) + assert list(new_env.roles.values())[0]._states == list(environment.roles.values())[0]._states + assert list(new_env.roles.values())[0]._actions == list(environment.roles.values())[0]._actions + assert isinstance(list(environment.roles.values())[0]._actions[0], ActionOK) + assert type(list(new_env.roles.values())[0]._actions[0]) == ActionOK + + +def test_environment_serdeser_v2(): + environment = Environment() + pm = ProjectManager() + environment.add_role(pm) + + ser_data = environment.dict() + + new_env: Environment = Environment(**ser_data) + role = new_env.get_role(pm.profile) + assert isinstance(role, ProjectManager) + assert isinstance(role._actions[0], WriteTasks) + assert isinstance(list(new_env.roles.values())[0]._actions[0], WriteTasks) + + +def test_environment_serdeser_save(): + environment = Environment() + role_c = RoleC() + + shutil.rmtree(serdeser_path.joinpath("team"), ignore_errors=True) + + stg_path = serdeser_path.joinpath("team/environment") + environment.add_role(role_c) + environment.serialize(stg_path) + + new_env: Environment = Environment.deserialize(stg_path) + assert len(new_env.roles) == 1 + assert type(list(new_env.roles.values())[0]._actions[0]) == ActionOK diff --git a/tests/metagpt/memory/test_memory.py b/tests/metagpt/serialize_deserialize/test_memory.py similarity index 52% rename from tests/metagpt/memory/test_memory.py rename to tests/metagpt/serialize_deserialize/test_memory.py index bda79ded1..e24f31af3 100644 --- a/tests/metagpt/memory/test_memory.py +++ b/tests/metagpt/serialize_deserialize/test_memory.py @@ -3,6 +3,7 @@ # @Desc : unittest of memory from pathlib import Path +from pydantic import BaseModel from metagpt.schema import Message from metagpt.memory.memory import Memory @@ -10,10 +11,36 @@ from metagpt.actions.action_output import ActionOutput from metagpt.actions.design_api import WriteDesign from metagpt.actions.add_requirement import BossRequirement -serdes_path = Path(__file__).absolute().parent.joinpath("../../data/serdes_storage") +from tests.metagpt.serialize_deserialize.test_serdeser_base import serdeser_path -def test_memory_serdes(): +def test_memory_serdeser(): + msg1 = Message(role="Boss", + content="write a snake game", + cause_by=BossRequirement) + + out_mapping = {"field2": (list[str], ...)} + out_data = {"field2": ["field2 value1", "field2 value2"]} + ic_obj = ActionOutput.create_model_class("system_design", out_mapping) + msg2 = Message(role="Architect", + instruct_content=ic_obj(**out_data), + content="system design content", + cause_by=WriteDesign) + + memory = Memory() + memory.add_batch([msg1, msg2]) + ser_data = memory.dict() + + new_memory = Memory(**ser_data) + assert new_memory.count() == 2 + new_msg2 = new_memory.get(2)[0] + assert isinstance(new_msg2, BaseModel) + assert isinstance(new_memory.storage[-1], BaseModel) + assert new_memory.storage[-1].cause_by == WriteDesign + assert new_msg2.role == "Boss" + + +def test_memory_serdeser_save(): msg1 = Message(role="User", content="write a 2048 game", cause_by=BossRequirement) @@ -29,7 +56,7 @@ def test_memory_serdes(): memory = Memory() memory.add_batch([msg1, msg2]) - stg_path = serdes_path.joinpath("team/environment") + stg_path = serdeser_path.joinpath("team/environment") memory.serialize(stg_path) assert stg_path.joinpath("memory.json").exists() @@ -38,5 +65,6 @@ def test_memory_serdes(): new_msg2 = new_memory.get(1)[0] assert new_msg2.instruct_content.field1 == ["field1 value1", "field1 value2"] assert new_msg2.cause_by == WriteDesign + assert len(new_memory.index) == 2 stg_path.joinpath("memory.json").unlink() diff --git a/tests/metagpt/serialize_deserialize/test_product_manager.py b/tests/metagpt/serialize_deserialize/test_product_manager.py index 2aed87a28..54584cf96 100644 --- a/tests/metagpt/serialize_deserialize/test_product_manager.py +++ b/tests/metagpt/serialize_deserialize/test_product_manager.py @@ -15,8 +15,8 @@ async def test_product_manager_deserialize(): ser_role_dict = role.dict(by_alias=True) new_role = ProductManager(**ser_role_dict) # new_role = ProductManager().deserialize(ser_role_dict) - + assert new_role.name == "Alice" assert len(new_role._actions) == 1 assert isinstance(new_role._actions[0], Action) - await new_role._actions[0].run([Message(content="write a cli snake game")]) \ No newline at end of file + await new_role._actions[0].run([Message(content="write a cli snake game")]) diff --git a/tests/metagpt/serialize_deserialize/test_project_manager.py b/tests/metagpt/serialize_deserialize/test_project_manager.py index fbc0dcc08..21fafa72e 100644 --- a/tests/metagpt/serialize_deserialize/test_project_manager.py +++ b/tests/metagpt/serialize_deserialize/test_project_manager.py @@ -6,6 +6,7 @@ import pytest from metagpt.roles.project_manager import ProjectManager from metagpt.actions.action import Action +from metagpt.actions.project_management import WriteTasks def test_project_manager_serialize(): @@ -20,9 +21,10 @@ def test_project_manager_serialize(): async def test_project_manager_deserialize(): role = ProjectManager() ser_role_dict = role.dict(by_alias=True) + new_role = ProjectManager(**ser_role_dict) - # new_role = ProjectManager().deserialize(ser_role_dict) assert new_role.name == "Eve" assert len(new_role._actions) == 1 assert isinstance(new_role._actions[0], Action) - await new_role._actions[0].run(context="write a cli snake game") \ No newline at end of file + assert isinstance(new_role._actions[0], WriteTasks) + # await new_role._actions[0].run(context="write a cli snake game") diff --git a/tests/metagpt/serialize_deserialize/test_role.py b/tests/metagpt/serialize_deserialize/test_role.py index 0e438d1a2..f260dea3a 100644 --- a/tests/metagpt/serialize_deserialize/test_role.py +++ b/tests/metagpt/serialize_deserialize/test_role.py @@ -2,12 +2,22 @@ # @Date : 11/23/2023 4:49 PM # @Author : stellahong (stellahong@fuzhi.ai) # @Desc : + +from pathlib import Path +import shutil import pytest +from metagpt.logs import logger from metagpt.roles.role import Role +from metagpt.actions import WriteCode, WriteCodeReview +from metagpt.schema import Message +from metagpt.actions.add_requirement import BossRequirement +from metagpt.roles.product_manager import ProductManager +from metagpt.const import SERDESER_PATH from metagpt.roles.engineer import Engineer +from metagpt.utils.utils import format_trackback_info -from metagpt.actions.action import Action +from tests.metagpt.serialize_deserialize.test_serdeser_base import RoleC, serdeser_path def test_role_serialize(): @@ -30,12 +40,53 @@ def test_engineer_serialize(): async def test_engineer_deserialize(): role = Engineer(use_code_review=True) ser_role_dict = role.dict(by_alias=True) - # new_role = Engineer().deserialize(ser_role_dict) - # also can be deserialized in this way: + new_role = Engineer(**ser_role_dict) assert new_role.name == "Alex" assert new_role.use_code_review is True assert len(new_role._actions) == 2 - assert isinstance(new_role._actions[0], Action) - assert isinstance(new_role._actions[1], Action) - await new_role._actions[0].run(context="write a cli snake game", filename="test_code") + assert isinstance(new_role._actions[0], WriteCode) + assert isinstance(new_role._actions[1], WriteCodeReview) + # await new_role._actions[0].run(context="write a cli snake game", filename="test_code") + + +def test_role_serdeser_save(): + stg_path_prefix = serdeser_path.joinpath("team/environment/roles/") + shutil.rmtree(serdeser_path.joinpath("team"), ignore_errors=True) + + pm = ProductManager() + role_tag = f"{pm.__class__.__name__}_{pm.name}" + stg_path = stg_path_prefix.joinpath(role_tag) + pm.serialize(stg_path) + assert stg_path.joinpath("actions/actions_info.json").exists() + + new_pm = Role.deserialize(stg_path) + assert new_pm.name == pm.name + assert len(new_pm.get_memories(1)) == 0 + + +@pytest.mark.asyncio +async def test_role_serdeser_interrupt(): + role_c = RoleC() + shutil.rmtree(SERDESER_PATH.joinpath("team"), ignore_errors=True) + + stg_path = SERDESER_PATH.joinpath(f"team/environment/roles/{role_c.__class__.__name__}_{role_c.name}") + try: + await role_c.run( + message=Message(content="demo", cause_by=BossRequirement) + ) + except Exception as exp: + logger.error(f"Exception in `role_a.run`, detail: {format_trackback_info()}") + role_c.serialize(stg_path) + + assert role_c._rc.memory.count() == 2 + + assert stg_path.joinpath("actions/todo.json").exists() + + new_role_a: Role = Role.deserialize(stg_path) + assert new_role_a._rc.state == 1 + + with pytest.raises(Exception): + await role_c.run( + message=Message(content="demo", cause_by=BossRequirement) + ) diff --git a/tests/metagpt/serialize_deserialize/test_schema.py b/tests/metagpt/serialize_deserialize/test_schema.py new file mode 100644 index 000000000..74b134cad --- /dev/null +++ b/tests/metagpt/serialize_deserialize/test_schema.py @@ -0,0 +1,49 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : unittest of schema ser&deser + +from metagpt.schema import Message +from metagpt.actions.action_output import ActionOutput +from metagpt.actions.write_code import WriteCode + +from tests.metagpt.serialize_deserialize.test_serdeser_base import MockMessage + + +def test_message_serdeser(): + out_mapping = {"field3": (str, ...), "field4": (list[str], ...)} + out_data = {"field3": "field3 value3", "field4": ["field4 value1", "field4 value2"]} + ic_obj = ActionOutput.create_model_class("code", out_mapping) + + message = Message( + content="code", + instruct_content=ic_obj(**out_data), + role="engineer", + cause_by=WriteCode + ) + ser_data = message.dict() + assert ser_data["cause_by"] == { + "action_class": "WriteCode", + "module_name": "metagpt.actions.write_code" + } + assert ser_data["instruct_content"]["class"] == "code" + + new_message = Message(**ser_data) + assert new_message.cause_by == WriteCode + assert new_message.cause_by in [WriteCode] + assert new_message.instruct_content == ic_obj(**out_data) + + +def test_message_without_postprocess(): + """ to explain `instruct_content` should be postprocessed """ + out_mapping = {"field1": (list[str], ...)} + out_data = {"field1": ["field1 value1", "field1 value2"]} + ic_obj = ActionOutput.create_model_class("code", out_mapping) + message = MockMessage( + content="code", + instruct_content=ic_obj(**out_data) + ) + ser_data = message.dict() + assert ser_data["instruct_content"] == {"field1": ["field1 value1", "field1 value2"]} + + new_message = MockMessage(**ser_data) + assert new_message.instruct_content != ic_obj(**out_data) diff --git a/tests/metagpt/serialize_deserialize/test_serdeser_base.py b/tests/metagpt/serialize_deserialize/test_serdeser_base.py new file mode 100644 index 000000000..35bad6cd9 --- /dev/null +++ b/tests/metagpt/serialize_deserialize/test_serdeser_base.py @@ -0,0 +1,88 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : base test actions / roles used in unittest + +from pydantic import BaseModel, Field +from pathlib import Path + +from metagpt.actions.action import Action +from metagpt.roles.role import Role, RoleReactMode +from metagpt.actions.add_requirement import BossRequirement + + +serdeser_path = Path(__file__).absolute().parent.joinpath("../../data/serdeser_storage") + + +class MockMessage(BaseModel): + """ to test normal dict without postprocess """ + content: str = "" + instruct_content: BaseModel = Field(default=None) + + +class ActionPass(Action): + name: str = "ActionPass" + + async def run(self, messages: list["Message"]): + return "pass" + + +class ActionOK(Action): + name: str = "ActionOK" + + async def run(self, messages: list["Message"]): + return "ok" + + +class ActionRaise(Action): + name: str = "ActionRaise" + + async def run(self, messages: list["Message"]): + raise RuntimeError("parse error in ActionRaise") + + +class RoleA(Role): + + name: str = Field(default="RoleA") + profile: str = Field(default="Role A") + goal: str = "RoleA's goal" + constraints: str = "RoleA's constraints" + + def __init__(self, **kwargs): + super(RoleA, self).__init__(**kwargs) + self._init_actions([ActionPass]) + self._watch([BossRequirement]) + + async def run(self, message: "Message" = None): + await super(RoleA, self).run(message) + + +class RoleB(Role): + name: str = Field(default="RoleB") + profile: str = Field(default="Role B") + goal: str = "RoleB's goal" + constraints: str = "RoleB's constraints" + + def __init__(self, **kwargs): + super(RoleB, self).__init__(**kwargs) + self._init_actions([ActionOK, ActionRaise]) + self._watch([ActionPass]) + self._rc.react_mode = RoleReactMode.BY_ORDER + + async def run(self, message: "Message" = None): + await super(RoleB, self).run(message) + + +class RoleC(Role): + name: str = Field(default="RoleC") + profile: str = Field(default="Role C") + goal: str = "RoleC's goal" + constraints: str = "RoleC's constraints" + + def __init__(self, **kwargs): + super(RoleC, self).__init__(**kwargs) + self._init_actions([ActionOK, ActionRaise]) + self._watch([BossRequirement]) + self._rc.react_mode = RoleReactMode.BY_ORDER + + async def run(self, message: "Message" = None): + await super(RoleC, self).run(message) diff --git a/tests/metagpt/serialize_deserialize/test_team.py b/tests/metagpt/serialize_deserialize/test_team.py index 44a75d262..e9122ebc0 100644 --- a/tests/metagpt/serialize_deserialize/test_team.py +++ b/tests/metagpt/serialize_deserialize/test_team.py @@ -2,46 +2,104 @@ # @Date : 11/27/2023 10:07 AM # @Author : stellahong (stellahong@fuzhi.ai) # @Desc : + +from pathlib import Path +import shutil import pytest -from metagpt.environment import Environment -from metagpt.schema import Message -from metagpt.software_company import SoftwareCompany from metagpt.roles import ProjectManager, ProductManager, Architect +from metagpt.team import Team +from metagpt.const import SERDESER_PATH + +from tests.metagpt.serialize_deserialize.test_serdeser_base import RoleA, RoleB, RoleC, serdeser_path -def test_env_serialize(): - env = Environment() - ser_env_dict = env.dict() - assert "roles" in ser_env_dict - assert "memory" in ser_env_dict - assert "memory" in ser_env_dict +def test_team_deserialize(): + company = Team() - -def test_env_deserialize(): - env = Environment() - env.publish_message(message=Message(content="test env serialize")) - ser_env_dict = env.dict() - new_env = Environment(**ser_env_dict) - assert len(new_env.roles) == 0 - assert new_env.memory.storage[0].content == "test env serialize" - assert len(new_env.history) == 25 - - -def test_softwarecompany_deserialize(): - team = SoftwareCompany() - team.hire( + pm = ProductManager() + arch = Architect() + company.hire( [ - ProductManager(), - Architect(), + pm, + arch, ProjectManager(), ] ) - assert len(team.environment.get_roles()) == 3 - ser_team_dict = team.dict() - new_team = SoftwareCompany(**ser_team_dict) - - assert len(new_team.environment.get_roles()) == 3 - assert new_team.environment.get_role('Product Manager') is not None - assert new_team.environment.get_role('Product Manager') is not None - assert new_team.environment.get_role('Architect') is not None + assert len(company.environment.get_roles()) == 3 + ser_company = company.dict() + new_company = Team(**ser_company) + + assert len(new_company.environment.get_roles()) == 3 + assert new_company.environment.get_role(pm.profile) is not None + + new_pm = new_company.environment.get_role(pm.profile) + assert type(new_pm) == ProductManager + assert new_company.environment.get_role(pm.profile) is not None + assert new_company.environment.get_role(arch.profile) is not None + + +def test_team_serdeser(): + company = Team() + company.hire([RoleC()]) + + stg_path = serdeser_path.joinpath("team") + shutil.rmtree(stg_path, ignore_errors=True) + + company.serialize(stg_path=stg_path) + + new_company = Team.deserialize(stg_path) + + assert len(new_company.environment.roles) == 1 + + +@pytest.mark.asyncio +async def test_team_recover(): + idea = "write a snake game" + stg_path = SERDESER_PATH.joinpath("team") + shutil.rmtree(stg_path, ignore_errors=True) + + company = Team() + company.hire([RoleC()]) + company.start_project(idea) + await company.run(n_round=4) + + ser_data = company.dict() + new_company = Team(**ser_data) + assert new_company.environment.memory.count() == 1 + assert type(list(new_company.environment.roles.values())[0]._actions[0]) == ActionOK + + new_company.start_project(idea) + await new_company.run(n_round=4) + + +@pytest.mark.asyncio +async def test_team_recover_save(): + idea = "write a 2048 web game" + stg_path = SERDESER_PATH.joinpath("team") + shutil.rmtree(stg_path, ignore_errors=True) + + company = Team() + company.hire([RoleC()]) + company.start_project(idea) + await company.run(n_round=4) + + new_company = Team.recover(stg_path) + new_company.start_project(idea) + await new_company.run(n_round=4) + + +@pytest.mark.asyncio +async def test_team_recover_multi_roles_save(): + idea = "write a snake game" + stg_path = SERDESER_PATH.joinpath("team") + shutil.rmtree(stg_path, ignore_errors=True) + + company = Team() + company.hire([RoleA(), RoleB()]) + company.start_project(idea) + await company.run(n_round=4) + + new_company = Team.recover(stg_path) + new_company.start_project(idea) + await new_company.run(n_round=4) diff --git a/tests/metagpt/serialize_deserialize/test_wrire_prd.py b/tests/metagpt/serialize_deserialize/test_wrire_prd.py index baa08ed76..96b4d19ad 100644 --- a/tests/metagpt/serialize_deserialize/test_wrire_prd.py +++ b/tests/metagpt/serialize_deserialize/test_wrire_prd.py @@ -25,4 +25,3 @@ async def test_action_deserialize(): assert new_action.name == "" assert new_action.llm == LLM() assert len(await new_action.run([Message(content="write a cli snake game")])) > 0 - diff --git a/tests/metagpt/serialize_deserialize/test_write_code.py b/tests/metagpt/serialize_deserialize/test_write_code.py index 9d659caaf..7f4799014 100644 --- a/tests/metagpt/serialize_deserialize/test_write_code.py +++ b/tests/metagpt/serialize_deserialize/test_write_code.py @@ -43,4 +43,4 @@ async def test_write_code_review_deserialize(): assert new_action.name == "WriteCodeReview" assert new_action.llm == LLM() - await new_action.run(context="write a cli snake game", code =code, filename="test_rewrite_code") \ No newline at end of file + await new_action.run(context="write a cli snake game", code=code, filename="test_rewrite_code") diff --git a/tests/metagpt/test_environment.py b/tests/metagpt/test_environment.py index 3cc2d8a7a..9f69e6189 100644 --- a/tests/metagpt/test_environment.py +++ b/tests/metagpt/test_environment.py @@ -8,17 +8,15 @@ import pytest from pathlib import Path -import shutil from metagpt.actions import BossRequirement from metagpt.environment import Environment from metagpt.logs import logger from metagpt.roles import Architect, ProductManager, Role from metagpt.schema import Message -from tests.metagpt.roles.test_role import RoleA -serdes_path = Path(__file__).absolute().parent.joinpath("../data/serdes_storage") +serdeser_path = Path(__file__).absolute().parent.joinpath("../data/serdeser_storage") @pytest.fixture @@ -27,14 +25,23 @@ def env(): def test_add_role(env: Environment): - role = ProductManager("Alice", "product manager", "create a new product", "limited resources") + role = ProductManager(name="Alice", + profile="product manager", + goal="create a new product", + constraints="limited resources") env.add_role(role) assert env.get_role(role.profile) == role def test_get_roles(env: Environment): - role1 = Role("Alice", "product manager", "create a new product", "limited resources") - role2 = Role("Bob", "engineer", "develop the new product", "short deadline") + role1 = Role(name="Alice", + profile="product manager", + goal="create a new product", + constraints="limited resources") + role2 = Role(name="Bob", + profile="engineer", + goal="develop the new product", + constraints="short deadline") env.add_role(role1) env.add_role(role2) roles = env.get_roles() @@ -43,8 +50,14 @@ def test_get_roles(env: Environment): @pytest.mark.asyncio async def test_publish_and_process_message(env: Environment): - product_manager = ProductManager("Alice", "Product Manager", "做AI Native产品", "资源有限") - architect = Architect("Bob", "Architect", "设计一个可用、高效、较低成本的系统,包括数据结构与接口", "资源有限,需要节省成本") + product_manager = ProductManager(name="Alice", + profile="Product Manager", + goal="做AI Native产品", + constraints="资源有限") + architect = Architect(name="Bob", + profile="Architect", + goal="设计一个可用、高效、较低成本的系统,包括数据结构与接口", + constraints="资源有限,需要节省成本") env.add_roles([product_manager, architect]) env.publish_message(Message(role="BOSS", content="需要一个基于LLM做总结的搜索引擎", cause_by=BossRequirement)) @@ -52,18 +65,3 @@ async def test_publish_and_process_message(env: Environment): await env.run(k=2) logger.info(f"{env.history=}") assert len(env.history) > 10 - - -def test_environment_serdes(): - environment = Environment() - role_a = RoleA() - - shutil.rmtree(serdes_path.joinpath("team"), ignore_errors=True) - - stg_path = serdes_path.joinpath("team/environment") - environment.add_role(role_a) - environment.serialize(stg_path) - - new_env: Environment = Environment() - new_env.deserialize(stg_path) - assert len(new_env.roles) == 1 diff --git a/tests/metagpt/test_role.py b/tests/metagpt/test_role.py deleted file mode 100644 index 11fd804ec..000000000 --- a/tests/metagpt/test_role.py +++ /dev/null @@ -1,14 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- -""" -@Time : 2023/5/11 14:44 -@Author : alexanderwu -@File : test_role.py -""" -from metagpt.roles import Role - - -def test_role_desc(): - i = Role(profile='Sales', desc='Best Seller') - assert i.profile == 'Sales' - assert i._setting.desc == 'Best Seller' diff --git a/tests/metagpt/test_schema.py b/tests/metagpt/test_schema.py index f515326e8..c70c93cfc 100644 --- a/tests/metagpt/test_schema.py +++ b/tests/metagpt/test_schema.py @@ -25,7 +25,7 @@ def test_messages(): assert all([i in text for i in roles]) -def test_message_serdes(): +def test_message_serdeser(): out_mapping = {"field3": (str, ...), "field4": (list[str], ...)} out_data = {"field3": "field3 value3", "field4": ["field4 value1", "field4 value2"]} ic_obj = ActionOutput.create_model_class("code", out_mapping) @@ -37,7 +37,7 @@ def test_message_serdes(): cause_by=WriteCode ) message_dict = serialize_general_message(message) - assert message_dict["cause_by"] == {"action_class": "WriteCode"} + assert message_dict["cause_by"] == {"action_class": "WriteCode", "module_name": "metagpt.actions.write_code"} assert message_dict["instruct_content"] == { "class": "code", "mapping": { diff --git a/tests/metagpt/test_team.py b/tests/metagpt/test_team.py index ab201152c..efd035bb2 100644 --- a/tests/metagpt/test_team.py +++ b/tests/metagpt/test_team.py @@ -2,26 +2,12 @@ # -*- coding: utf-8 -*- # @Desc : unittest of team -from pathlib import Path -import shutil - from metagpt.team import Team - -from tests.metagpt.roles.test_role import RoleA - -serdes_path = Path(__file__).absolute().parent.joinpath("../data/serdes_storage") +from metagpt.roles.project_manager import ProjectManager -def test_team_serdes(): +def test_team(): company = Team() - company.hire([RoleA()]) + company.hire([ProjectManager()]) - stg_path = serdes_path.joinpath("team") - shutil.rmtree(stg_path, ignore_errors=True) - - company.serialize(stg_path=stg_path) - - new_company = Team() - new_company.deserialize(stg_path) - - assert len(new_company.environment.roles) == 1 + assert len(company.environment.roles) == 1 From 5e3607f85bc4fec0ff97c57ff7d866f108e3c9c3 Mon Sep 17 00:00:00 2001 From: better629 Date: Thu, 30 Nov 2023 15:18:24 +0800 Subject: [PATCH 06/14] update environment/message to BaseModel, update the ser&deser of roles/actions --- metagpt/actions/action.py | 29 ++++- metagpt/actions/design_api.py | 8 +- metagpt/actions/project_management.py | 3 +- metagpt/actions/search_and_summarize.py | 15 ++- metagpt/actions/write_code.py | 3 +- metagpt/actions/write_code_review.py | 4 +- metagpt/actions/write_prd.py | 6 +- metagpt/actions/write_test.py | 11 +- metagpt/const.py | 2 +- metagpt/environment.py | 39 +++++-- metagpt/memory/longterm_memory.py | 14 ++- metagpt/memory/memory.py | 79 ++++++++++---- metagpt/roles/architect.py | 4 +- metagpt/roles/customer_service.py | 19 ++-- metagpt/roles/engineer.py | 4 +- metagpt/roles/product_manager.py | 5 +- metagpt/roles/project_manager.py | 4 +- metagpt/roles/qa_engineer.py | 16 +-- metagpt/roles/role.py | 130 +++++++++++++--------- metagpt/roles/sales.py | 31 +++--- metagpt/roles/seacher.py | 21 ++-- metagpt/schema.py | 138 ++++++++++++++---------- metagpt/team.py | 39 ++++--- metagpt/utils/serialize.py | 26 +++-- metagpt/utils/utils.py | 43 ++++++++ startup.py | 17 +-- 26 files changed, 458 insertions(+), 252 deletions(-) diff --git a/metagpt/actions/action.py b/metagpt/actions/action.py index aefe6d39d..7a7f194f4 100644 --- a/metagpt/actions/action.py +++ b/metagpt/actions/action.py @@ -5,8 +5,9 @@ @Author : alexanderwu @File : action.py """ + +from __future__ import annotations import re -from abc import ABC from typing import Optional, Any from pydantic import BaseModel, Field @@ -14,25 +15,43 @@ from tenacity import retry, stop_after_attempt, wait_fixed from metagpt.actions.action_output import ActionOutput from metagpt.llm import LLM +from metagpt.provider.base_gpt_api import BaseGPTAPI from metagpt.logs import logger from metagpt.utils.common import OutputParser from metagpt.utils.custom_decoder import CustomDecoder from metagpt.utils.utils import import_class +action_subclass_registry = {} + + class Action(BaseModel): name: str = "" - llm: LLM = Field(default_factory=LLM) + llm: BaseGPTAPI = Field(default_factory=LLM, exclude=True) context = "" prefix = "" profile = "" desc = "" content: Optional[str] = None instruct_content: Optional[str] = None + + # builtin variables + builtin_class_name: str = "" + + class Config: + arbitrary_types_allowed = True def __init__(self, **kwargs: Any): super().__init__(**kwargs) - + + # deserialize child classes dynamically for inherited `action` + object.__setattr__(self, "builtin_class_name", self.__class__.__name__) + self.__fields__["builtin_class_name"].default = self.__class__.__name__ + + def __init_subclass__(cls, **kwargs: Any) -> None: + super().__init_subclass__(**kwargs) + action_subclass_registry[cls.__name__] = cls + def set_prefix(self, prefix, profile): """Set prefix for later usage""" self.prefix = prefix @@ -52,14 +71,14 @@ class Action(BaseModel): } @classmethod - def deserialize(cls, action_dict: dict): + def deserialize(cls, action_dict: dict) -> "Action": action_class_str = action_dict.pop("action_class") module_name = action_dict.pop("module_name") action_class = import_class(action_class_str, module_name) return action_class(**action_dict) @classmethod - def ser_class(cls): + def ser_class(cls) -> dict: """ serialize class type""" return { "action_class": cls.__name__, diff --git a/metagpt/actions/design_api.py b/metagpt/actions/design_api.py index 30df70ce7..015678baa 100644 --- a/metagpt/actions/design_api.py +++ b/metagpt/actions/design_api.py @@ -13,6 +13,7 @@ from pydantic import Field from metagpt.actions import Action, ActionOutput from metagpt.llm import LLM +from metagpt.provider.base_gpt_api import BaseGPTAPI from metagpt.config import CONFIG from metagpt.const import WORKSPACE_ROOT from metagpt.logs import logger @@ -155,12 +156,11 @@ OUTPUT_MAPPING = { class WriteDesign(Action): name: str = "" context: Optional[str] = None - llm: LLM = Field(default_factory=LLM) + llm: BaseGPTAPI = Field(default_factory=LLM) desc: str = "Based on the PRD, think about the system design, and design the corresponding APIs, " "data structures, library tables, processes, and paths. Please provide your design, feedback " "clearly and in detail." - def recreate_workspace(self, workspace: Path): try: shutil.rmtree(workspace) @@ -168,7 +168,6 @@ class WriteDesign(Action): pass # Folder does not exist, but we don't care workspace.mkdir(parents=True, exist_ok=True) - async def _save_prd(self, docs_path, resources_path, context): prd_file = docs_path / "prd.md" if context[-1].instruct_content and context[-1].instruct_content.dict()["Competitive Quadrant Chart"]: @@ -179,7 +178,6 @@ class WriteDesign(Action): logger.info(f"Saving PRD to {prd_file}") prd_file.write_text(json_to_markdown(context[-1].instruct_content.dict())) - async def _save_system_design(self, docs_path, resources_path, system_design): data_api_design = system_design.instruct_content.dict()[ "Data structures and interface definitions" @@ -193,7 +191,6 @@ class WriteDesign(Action): logger.info(f"Saving System Designs to {system_design_file}") system_design_file.write_text((json_to_markdown(system_design.instruct_content.dict()))) - async def _save(self, context, system_design): if isinstance(system_design, ActionOutput): ws_name = system_design.instruct_content.dict()["Python package name"] @@ -211,7 +208,6 @@ class WriteDesign(Action): logger.error(f"Failed to save PRD {e}") await self._save_system_design(docs_path, resources_path, system_design) - async def run(self, context, format=CONFIG.prompt_format): prompt_template, format_example = get_template(templates, format) prompt = prompt_template.format(context=context, format_example=format_example) diff --git a/metagpt/actions/project_management.py b/metagpt/actions/project_management.py index b72507ee3..cf44906cd 100644 --- a/metagpt/actions/project_management.py +++ b/metagpt/actions/project_management.py @@ -11,6 +11,7 @@ from pydantic import Field from metagpt.actions.action import Action from metagpt.llm import LLM +from metagpt.provider.base_gpt_api import BaseGPTAPI from metagpt.config import CONFIG from metagpt.const import WORKSPACE_ROOT from metagpt.utils.common import CodeParser @@ -168,7 +169,7 @@ OUTPUT_MAPPING = { class WriteTasks(Action): name: str = "CreateTasks" context: Optional[str] = None - llm: LLM = Field(default_factory=LLM) + llm: BaseGPTAPI = Field(default_factory=LLM) def _save(self, context, rsp): try: diff --git a/metagpt/actions/search_and_summarize.py b/metagpt/actions/search_and_summarize.py index 0580303e6..6b0c1f717 100644 --- a/metagpt/actions/search_and_summarize.py +++ b/metagpt/actions/search_and_summarize.py @@ -8,14 +8,15 @@ import pydantic from typing import Optional, Any from pydantic import BaseModel, Field +from pydantic import root_validator from metagpt.actions import Action from metagpt.llm import LLM -from metagpt.config import Config +from metagpt.provider.base_gpt_api import BaseGPTAPI +from metagpt.config import Config, CONFIG from metagpt.logs import logger from metagpt.schema import Message from metagpt.tools.search_engine import SearchEngine -from pydantic import root_validator SEARCH_AND_SUMMARIZE_SYSTEM = """### Requirements 1. Please summarize the latest dialogue based on the reference information (secondary) and dialogue history (primary). Do not include text that is irrelevant to the conversation. @@ -106,13 +107,13 @@ You are a member of a professional butler team and will provide helpful suggesti class SearchAndSummarize(Action): name: str = "" content: Optional[str] = None - llm: None = Field(default_factory=LLM) + llm: BaseGPTAPI = Field(default_factory=LLM) config: None = Field(default_factory=Config) - engine: Optional[str] = None + engine: Optional[str] = CONFIG.search_engine search_func: Optional[str] = None + search_engine: SearchEngine = None result = "" - @root_validator def validate_engine_and_run_func(cls, values): @@ -130,9 +131,7 @@ class SearchAndSummarize(Action): values['search_engine'] = search_engine return values - - - + async def run(self, context: list[Message], system_text=SEARCH_AND_SUMMARIZE_SYSTEM) -> str: print(context) if self.search_engine is None: diff --git a/metagpt/actions/write_code.py b/metagpt/actions/write_code.py index 2dc240591..10487e53a 100644 --- a/metagpt/actions/write_code.py +++ b/metagpt/actions/write_code.py @@ -13,6 +13,7 @@ from tenacity import retry, stop_after_attempt, wait_fixed from metagpt.actions import WriteDesign from metagpt.actions.action import Action from metagpt.llm import LLM +from metagpt.provider.base_gpt_api import BaseGPTAPI from metagpt.const import WORKSPACE_ROOT from metagpt.logs import logger from metagpt.schema import Message @@ -50,7 +51,7 @@ ATTENTION: Use '##' to SPLIT SECTIONS, not '#'. Output format carefully referenc class WriteCode(Action): name: str = "WriteCode" context: Optional[str] = None - llm: LLM = Field(default_factory=LLM) + llm: BaseGPTAPI = Field(default_factory=LLM) def _is_invalid(self, filename): return any(i in filename for i in ["mp3", "wav"]) diff --git a/metagpt/actions/write_code_review.py b/metagpt/actions/write_code_review.py index 3d86d7c63..79e462f76 100644 --- a/metagpt/actions/write_code_review.py +++ b/metagpt/actions/write_code_review.py @@ -12,7 +12,7 @@ from tenacity import retry, stop_after_attempt, wait_fixed from metagpt.llm import LLM from metagpt.actions.action import Action from metagpt.logs import logger -from metagpt.schema import Message +from metagpt.provider.base_gpt_api import BaseGPTAPI from metagpt.utils.common import CodeParser PROMPT_TEMPLATE = """ @@ -67,7 +67,7 @@ FORMAT_EXAMPLE = """ class WriteCodeReview(Action): name: str = "WriteCodeReview" context: Optional[str] = None - llm: LLM = Field(default_factory=LLM) + llm: BaseGPTAPI = Field(default_factory=LLM) @retry(stop=stop_after_attempt(2), wait=wait_fixed(1)) async def write_code(self, prompt): diff --git a/metagpt/actions/write_prd.py b/metagpt/actions/write_prd.py index 660d7fb95..450bed7e7 100644 --- a/metagpt/actions/write_prd.py +++ b/metagpt/actions/write_prd.py @@ -11,6 +11,7 @@ from pydantic import BaseModel, Field from metagpt.actions import Action, ActionOutput from metagpt.llm import LLM +from metagpt.provider.base_gpt_api import BaseGPTAPI from metagpt.actions.search_and_summarize import SearchAndSummarize from metagpt.config import CONFIG from metagpt.logs import logger @@ -224,12 +225,9 @@ OUTPUT_MAPPING = { class WritePRD(Action): name: str = "" content: Optional[str] = None - llm: LLM = Field(default_factory=LLM) + llm: BaseGPTAPI = Field(default_factory=LLM) assistant_search_action: Action = None - def __init__(self, **kwargs): - super().__init__(**kwargs) - async def run(self, requirements, format=CONFIG.prompt_format, *args, **kwargs) -> ActionOutput: # self.assistant_search_action = SearchAndSummarize() if self.assistant_search_action is None: diff --git a/metagpt/actions/write_test.py b/metagpt/actions/write_test.py index 35ff36dc2..6c902444a 100644 --- a/metagpt/actions/write_test.py +++ b/metagpt/actions/write_test.py @@ -5,6 +5,12 @@ @Author : alexanderwu @File : environment.py """ + +from typing import Optional +from pydantic import Field + +from metagpt.llm import LLM +from metagpt.provider.base_gpt_api import BaseGPTAPI from metagpt.actions.action import Action from metagpt.logs import logger from metagpt.utils.common import CodeParser @@ -31,8 +37,9 @@ you should correctly import the necessary classes based on these file locations! class WriteTest(Action): - def __init__(self, name="WriteTest", context=None, llm=None): - super().__init__(name, context, llm) + name: str = "WriteTest" + context: Optional[str] = None + llm: BaseGPTAPI = Field(default_factory=LLM) async def write_code(self, prompt): code_rsp = await self._aask(prompt) diff --git a/metagpt/const.py b/metagpt/const.py index 711546d03..4b063a3dd 100644 --- a/metagpt/const.py +++ b/metagpt/const.py @@ -42,7 +42,7 @@ TMP = PROJECT_ROOT / "tmp" RESEARCH_PATH = DATA_PATH / "research" TUTORIAL_PATH = DATA_PATH / "tutorial_docx" INVOICE_OCR_TABLE_PATH = DATA_PATH / "invoice_table" -SERDES_PATH = WORKSPACE_ROOT / "storage" # TODO to store `storage` under the individual generated project +SERDESER_PATH = WORKSPACE_ROOT / "storage" # TODO to store `storage` under the individual generated project SKILL_DIRECTORY = PROJECT_ROOT / "metagpt/skills" diff --git a/metagpt/environment.py b/metagpt/environment.py index e867ad6fc..bade53f50 100644 --- a/metagpt/environment.py +++ b/metagpt/environment.py @@ -12,7 +12,7 @@ from pathlib import Path from pydantic import BaseModel, Field from metagpt.memory import Memory -from metagpt.roles import Role +from metagpt.roles.role import Role, role_subclass_registry from metagpt.schema import Message from metagpt.utils.utils import read_json_file, write_json_file @@ -30,6 +30,19 @@ class Environment(BaseModel): class Config: arbitrary_types_allowed = True + def __init__(self, **kwargs): + for role_key, role in kwargs.get("roles", {}).items(): + current_role = kwargs["roles"][role_key] + if isinstance(current_role, dict): + item_class_name = current_role.get("builtin_class_name", None) + for name, subclass in role_subclass_registry.items(): + registery_class_name = subclass.__fields__["builtin_class_name"].default + if item_class_name == registery_class_name: + current_role = subclass(**current_role) + break + kwargs["roles"][role_key] = current_role + super().__init__(**kwargs) + def serialize(self, stg_path: Path): roles_path = stg_path.joinpath("roles.json") roles_info = [] @@ -46,33 +59,39 @@ class Environment(BaseModel): history_path = stg_path.joinpath("history.json") write_json_file(history_path, {"content": self.history}) - def deserialize(self, stg_path: Path): + @classmethod + def deserialize(cls, stg_path: Path) -> "Environment": """ stg_path: ./storage/team/environment/ """ roles_path = stg_path.joinpath("roles.json") roles_info = read_json_file(roles_path) + roles = [] for role_info in roles_info: role_class = role_info.get("role_class") role_name = role_info.get("role_name") role_path = stg_path.joinpath(f"roles/{role_class}_{role_name}") role = Role.deserialize(role_path) - - self.add_role(role) + roles.append(role) memory = Memory.deserialize(stg_path) - self.memory = memory - history_path = stg_path.joinpath("history.json") - history = read_json_file(history_path) - self.history = history.get("content") + history = read_json_file(stg_path.joinpath("history.json")) + history = history.get("content") + + environment = Environment(**{ + "memory": memory, + "history": history + }) + environment.add_roles(roles) + return environment def add_role(self, role: Role): - """增加一个在当前环境的角色, 默认为profile/role_profile + """增加一个在当前环境的角色, 默认为profile Add a role in the current environment """ role.set_env(self) # use alias - self.roles[role.role_profile] = role + self.roles[role.profile] = role def add_roles(self, roles: Iterable[Role]): """增加一批在当前环境的角色 diff --git a/metagpt/memory/longterm_memory.py b/metagpt/memory/longterm_memory.py index f8abea5f3..5d149ee7a 100644 --- a/metagpt/memory/longterm_memory.py +++ b/metagpt/memory/longterm_memory.py @@ -2,6 +2,9 @@ # -*- coding: utf-8 -*- # @Desc : the implement of Long-term memory +from typing import Optional +from pydantic import Field + from metagpt.logs import logger from metagpt.memory import Memory from metagpt.memory.memory_storage import MemoryStorage @@ -15,11 +18,12 @@ class LongTermMemory(Memory): - update memory when it changed """ - def __init__(self): - self.memory_storage: MemoryStorage = MemoryStorage() - super(LongTermMemory, self).__init__() - self.rc = None # RoleContext - self.msg_from_recover = False + memory_storage: MemoryStorage = Field(default_factory=MemoryStorage) + rc: Optional["RoleContext"] = None + msg_from_recover: bool = False + + class Config: + arbitrary_types_allowed = True def recover_memory(self, role_id: str, rc: "RoleContext"): messages = self.memory_storage.recover_memory(role_id) diff --git a/metagpt/memory/memory.py b/metagpt/memory/memory.py index a839bb038..c88cc750e 100644 --- a/metagpt/memory/memory.py +++ b/metagpt/memory/memory.py @@ -5,34 +5,65 @@ @Author : alexanderwu @File : memory.py """ +import copy from collections import defaultdict -from typing import Iterable, Type +from typing import Iterable, Type, Union, Optional from pathlib import Path +from pydantic import BaseModel, Field +import json from metagpt.actions import Action from metagpt.schema import Message from metagpt.utils.utils import read_json_file, write_json_file -from metagpt.utils.serialize import serialize_general_message, deserialize_general_message +from metagpt.utils.utils import import_class -class Memory: +class Memory(BaseModel): """The most basic memory: super-memory""" - def __init__(self): - """Initialize an empty storage list and an empty index dictionary""" - self.storage: list[Message] = [] - self.index: dict[Type[Action], list[Message]] = defaultdict(list) + storage: list[Message] = Field(default=[]) + index: dict[Type[Action], list[Message]] = Field(default_factory=defaultdict(list)) + + def __init__(self, **kwargs): + index = kwargs.get("index", {}) + new_index = defaultdict(list) + for action_str, value in index.items(): + action_dict = json.loads(action_str) + action_class = import_class("Action", "metagpt.actions.action") + action_obj = action_class.deser_class(action_dict) + new_index[action_obj] = [Message(**item_dict) for item_dict in value] + kwargs["index"] = new_index + super(Memory, self).__init__(**kwargs) + self.index = new_index + + def dict(self, + *, + include: Optional[Union["AbstractSetIntStr", "MappingIntStrAny"]] = None, + exclude: Optional[Union["AbstractSetIntStr", "MappingIntStrAny"]] = None, + by_alias: bool = False, + skip_defaults: Optional[bool] = None, + exclude_unset: bool = False, + exclude_defaults: bool = False, + exclude_none: bool = False) -> "DictStrAny": + """ overwrite the `dict` to dump dynamic pydantic model""" + obj_dict = super(Memory, self).dict(include=include, + exclude=exclude, + by_alias=by_alias, + skip_defaults=skip_defaults, + exclude_unset=exclude_unset, + exclude_defaults=exclude_defaults, + exclude_none=exclude_none) + new_obj_dict = copy.deepcopy(obj_dict) + new_obj_dict["index"] = {} + for action, value in obj_dict["index"].items(): + action_ser = json.dumps(action.ser_class()) + new_obj_dict["index"][action_ser] = value + return new_obj_dict def serialize(self, stg_path: Path): """ stg_path = ./storage/team/environment/ or ./storage/team/environment/roles/{role_class}_{role_name}/ """ memory_path = stg_path.joinpath("memory.json") - - storage = [] - for message in self.storage: - # msg_dict = message.serialize() - msg_dict = serialize_general_message(message) - storage.append(msg_dict) - + storage = self.dict() write_json_file(memory_path, storage) @classmethod @@ -40,13 +71,8 @@ class Memory: """ stg_path = ./storage/team/environment/ or ./storage/team/environment/roles/{role_class}_{role_name}/""" memory_path = stg_path.joinpath("memory.json") - memory = Memory() - memory_list = read_json_file(memory_path) - for message in memory_list: - # distinguish instruct_content type in message - # msg = Message.deserialize(message) - msg = deserialize_general_message(message) - memory.add(msg) + memory_dict = read_json_file(memory_path) + memory = Memory(**memory_dict) return memory @@ -70,6 +96,16 @@ class Memory: """Return all messages containing a specified content""" return [message for message in self.storage if content in message.content] + def delete_newest(self) -> "Message": + """ delete the newest message from the storage""" + if len(self.storage) > 0: + newest_msg = self.storage.pop() + if newest_msg.cause_by and newest_msg in self.index[newest_msg.cause_by]: + self.index[newest_msg.cause_by].remove(newest_msg) + else: + newest_msg = None + return newest_msg + def delete(self, message: Message): """Delete the specified message from storage, while updating the index""" self.storage.remove(message) @@ -115,4 +151,3 @@ class Memory: continue rsp += self.index[action] return rsp - \ No newline at end of file diff --git a/metagpt/roles/architect.py b/metagpt/roles/architect.py index face22a68..09d52edbe 100644 --- a/metagpt/roles/architect.py +++ b/metagpt/roles/architect.py @@ -22,8 +22,8 @@ class Architect(Role): goal (str): Primary goal or responsibility of the architect. constraints (str): Constraints or guidelines for the architect. """ - name: str = "Bob" - role_profile: str = Field(default="Architect" , alias='profile') + name: str = Field(default="Bob") + profile: str = Field(default="Architect") goal: str = "Design a concise, usable, complete python system" constraints: str = "Try to specify good open source tools as much as possible" diff --git a/metagpt/roles/customer_service.py b/metagpt/roles/customer_service.py index 4547f8190..62792696f 100644 --- a/metagpt/roles/customer_service.py +++ b/metagpt/roles/customer_service.py @@ -5,6 +5,9 @@ @Author : alexanderwu @File : sales.py """ +from typing import Optional +from pydantic import Field + from metagpt.roles import Sales # from metagpt.actions import SearchAndSummarize @@ -24,12 +27,14 @@ DESC = """ class CustomerService(Sales): + + name: str = Field(default="Xiaomei") + profile: str = Field(default="Human customer service") + desc: str = DESC, + + store: Optional[str] = None + def __init__( self, - name="Xiaomei", - profile="Human customer service", - desc=DESC, - store=None - ): - super().__init__(name, profile, desc=desc, store=store) - \ No newline at end of file + **kwargs): + super().__init__(**kwargs) diff --git a/metagpt/roles/engineer.py b/metagpt/roles/engineer.py index 129bedeb8..e90f586f0 100644 --- a/metagpt/roles/engineer.py +++ b/metagpt/roles/engineer.py @@ -60,8 +60,8 @@ class Engineer(Role): use_code_review (bool): Whether to use code review. todos (list): List of tasks. """ - name: str = "Alex" - role_profile: str = Field(default="Engineer", alias='profile') + name: str = Field(default="Alex") + profile: str = Field(default="Engineer") goal: str = "Write elegant, readable, extensible, efficient code" constraints: str = "The code should conform to standards like PEP8 and be modular and maintainable" n_borg: int = 1 diff --git a/metagpt/roles/product_manager.py b/metagpt/roles/product_manager.py index b099fb4d9..6f68fe5ba 100644 --- a/metagpt/roles/product_manager.py +++ b/metagpt/roles/product_manager.py @@ -21,10 +21,11 @@ class ProductManager(Role): goal (str): Goal of the product manager. constraints (str): Constraints or limitations for the product manager. """ - name: str = "Alice" - role_profile: str = Field(default="Product Manager", alias='profile') + name: str = Field(default="Alice") + profile: str = Field(default="Product Manager") goal: str = "Efficiently create a successful product" constraints: str = "" + """ Represents a Product Manager role responsible for product development and management. """ diff --git a/metagpt/roles/project_manager.py b/metagpt/roles/project_manager.py index a2b227f22..c8e785d85 100644 --- a/metagpt/roles/project_manager.py +++ b/metagpt/roles/project_manager.py @@ -22,8 +22,8 @@ class ProjectManager(Role): goal (str): Goal of the project manager. constraints (str): Constraints or limitations for the project manager. """ - name: str = "Eve" - role_profile: str = Field(default="Project Manager", alias='profile') + name: str = Field(default="Eve") + profile: str = Field(default="Project Manager") goal: str = "Improve team efficiency and deliver with quality and quantity" constraints: str = "" diff --git a/metagpt/roles/qa_engineer.py b/metagpt/roles/qa_engineer.py index a763c2ce8..bad3f2409 100644 --- a/metagpt/roles/qa_engineer.py +++ b/metagpt/roles/qa_engineer.py @@ -7,6 +7,7 @@ """ import os from pathlib import Path +from pydantic import Field from metagpt.actions import ( DebugError, @@ -25,21 +26,22 @@ from metagpt.utils.special_tokens import FILENAME_CODE_SEP, MSG_SEP class QaEngineer(Role): + name: str = Field(default="Edward") + profile: str = Field(default="QaEngineer") + goal: str = "Write comprehensive and robust tests to ensure codes will work as expected without bugs" + constraints: str = "The test code you write should conform to code standard like PEP8, be modular, easy to read and maintain" + test_round_allowed: int = 5 + def __init__( self, - name="Edward", - profile="QaEngineer", - goal="Write comprehensive and robust tests to ensure codes will work as expected without bugs", - constraints="The test code you write should conform to code standard like PEP8, be modular, easy to read and maintain", - test_round_allowed=5, + **kwargs ): - super().__init__(name, profile, goal, constraints) + super().__init__(**kwargs) self._init_actions( [WriteTest] ) # FIXME: a bit hack here, only init one action to circumvent _think() logic, will overwrite _think() in future updates self._watch([WriteCode, WriteCodeReview, WriteTest, RunCode, DebugError]) self.test_round = 0 - self.test_round_allowed = test_round_allowed @classmethod def parse_workspace(cls, system_design_msg: Message) -> str: diff --git a/metagpt/roles/role.py b/metagpt/roles/role.py index e9371c2c0..b6332aa4c 100644 --- a/metagpt/roles/role.py +++ b/metagpt/roles/role.py @@ -6,27 +6,29 @@ @File : role.py """ +from __future__ import annotations from enum import Enum from pathlib import Path -from __future__ import annotations from typing import ( Iterable, - Type + Type, + Any ) -import re -from pydantic import BaseModel, Field -from importlib import import_module +from pydantic import BaseModel, Field, validator # from metagpt.environment import Environment from metagpt.config import CONFIG -from metagpt.actions import Action, ActionOutput +from metagpt.actions.action import Action, ActionOutput, action_subclass_registry from metagpt.llm import LLM +from metagpt.provider.base_gpt_api import BaseGPTAPI from metagpt.logs import logger from metagpt.memory import Memory, LongTermMemory from metagpt.schema import Message from metagpt.provider.human_provider import HumanProvider -from metagpt.utils.utils import read_json_file, write_json_file, import_class +from metagpt.utils.utils import read_json_file, write_json_file, import_class, role_raise_decorator +from metagpt.const import SERDESER_PATH + PREFIX_TEMPLATE = """You are a {profile}, named {name}, your goal is {goal}, and the constraint is {constraints}. """ @@ -57,6 +59,7 @@ ROLE_TEMPLATE = """Your response should be based on the previous conversation hi {name}: {result} """ + class RoleReactMode(str, Enum): REACT = "react" BY_ORDER = "by_order" @@ -74,6 +77,7 @@ class RoleSetting(BaseModel): goal: str = "" constraints: str = "" desc: str = "" + is_human: bool = False def __str__(self): return f"{self.name}({self.profile})" @@ -84,10 +88,10 @@ class RoleSetting(BaseModel): class RoleContext(BaseModel): """Role Runtime Context""" - env: 'Environment' = Field(default=None) + env: "Environment" = Field(default=None) memory: Memory = Field(default_factory=Memory) long_term_memory: LongTermMemory = Field(default_factory=LongTermMemory) - state: int = Field(default=0) + state: int = Field(default=-1) # -1 indicates initial or termination state where todo is None todo: Action = Field(default=None) watch: set[Type[Action]] = Field(default_factory=set) news: list[Type[Message]] = Field(default=[]) @@ -112,53 +116,86 @@ class RoleContext(BaseModel): return self.memory.get() +role_subclass_registry = {} + + class Role(BaseModel): """Role/Agent""" - name: str = "" profile: str = "" goal: str = "" constraints: str = "" desc: str = "" - _setting: RoleSetting = Field(default_factory=RoleSetting, alias="_setting") - _setting = RoleSetting(name=name, profile=profile, goal=goal, constraints=constraints) + is_human: bool = False + + _llm: BaseGPTAPI = Field(default_factory=LLM, exclude=True) + _setting: RoleSetting = Field(default_factory=RoleSetting, alias=True) _role_id: str = "" - _states: list = Field(default=[]) - _actions: list = Field(default=[]) - _actions_type: list = Field(default=[]) + _states: list[str] = Field(default=[]) + _actions: list[Action] = Field(default=[]) _rc: RoleContext = RoleContext() - + + # builtin variables + recovered: bool = False # to tag if a recovered role + builtin_class_name: str = "" + _private_attributes = { - "_setting": _setting, + "_llm": LLM() if not is_human else HumanProvider(), "_role_id": _role_id, "_states": [], - "_actions": [], - "_actions_type": [] # 用于记录和序列化 + "_actions": [] } - + class Config: arbitrary_types_allowed = True - - def __init__(self, **kwargs): + exclude = ["_llm"] + + def __init__(self, **kwargs: Any): + for index in range(len(kwargs.get("_actions", []))): + current_action = kwargs["_actions"][index] + if isinstance(current_action, dict): + item_class_name = current_action.get("builtin_class_name", None) + for name, subclass in action_subclass_registry.items(): + registery_class_name = subclass.__fields__["builtin_class_name"].default + if item_class_name == registery_class_name: + current_action = subclass(**current_action) + break + kwargs["_actions"][index] = current_action + super().__init__(**kwargs) + # 关于私有变量的初始化 https://github.com/pydantic/pydantic/issues/655 + self._private_attributes["_llm"] = LLM() if not self.is_human else HumanProvider() + self._private_attributes["_setting"] = RoleSetting(name=self.name, profile=self.profile, goal=self.goal, + desc=self.desc, constraints=self.constraints, + is_human=self.is_human) for key in self._private_attributes.keys(): if key in kwargs: object.__setattr__(self, key, kwargs[key]) - if key =="_setting": - _setting = RoleSetting(**kwargs[key]) - object.__setattr__(self, '_setting', _setting) + if key == "_setting": + setting = RoleSetting(**kwargs[key]) + object.__setattr__(self, "_setting", setting) elif key == "_rc": _rc = RoleContext - object.__setattr__(self, '_rc', _rc) + object.__setattr__(self, "_rc", _rc) else: object.__setattr__(self, key, self._private_attributes[key]) + + # deserialize child classes dynamically for inherited `role` + object.__setattr__(self, "builtin_class_name", self.__class__.__name__) + self.__fields__["builtin_class_name"].default = self.__class__.__name__ + + def __init_subclass__(cls, **kwargs: Any) -> None: + super().__init_subclass__(**kwargs) + role_subclass_registry[cls.__name__] = cls def _reset(self): - object.__setattr__(self, '_states', []) - object.__setattr__(self, '_actions', []) + object.__setattr__(self, "_states", []) + object.__setattr__(self, "_actions", []) - def serialize(self, stg_path: Path): + def serialize(self, stg_path: Path = None): + stg_path = SERDESER_PATH.joinpath(f"team/environment/roles/{self.__class__.__name__}_{self.name}") \ + if stg_path is None else stg_path role_info_path = stg_path.joinpath("role_info.json") role_info = { "role_class": self.__class__.__name__, @@ -207,7 +244,7 @@ class Role(BaseModel): actions = [] actions_info = read_json_file(actions_info_path) for action_info in actions_info: - action = Action.deserialize(action_info) + action = Action.deser_class(action_info) actions.append(action) watches_info_path = stg_path.joinpath("watches/watches_info.json") @@ -238,12 +275,8 @@ class Role(BaseModel): return role - def _reset(self): - self._states = [] - self._actions = [] - def set_recovered(self, recovered: bool = False): - self._recovered = recovered + self.recovered = recovered def set_memory(self, memory: Memory): self._rc.memory = memory @@ -256,7 +289,8 @@ class Role(BaseModel): for idx, action in enumerate(actions): if not isinstance(action, Action): ## 默认初始化 - i = action("", llm=self._llm) + # import pdb; pdb.set_trace() + i = action(name="", llm=self._llm) else: if self._setting.is_human and not isinstance(action.llm, HumanProvider): logger.warning(f"is_human attribute does not take effect," @@ -265,8 +299,6 @@ class Role(BaseModel): i.set_prefix(self._get_prefix(), self.profile) self._actions.append(i) self._states.append(f"{idx}. {action}") - action_title = action.schema()["title"] - self._actions_type.append(action_title) def set_react_mode(self, react_mode: RoleReactMode, max_react_loop: int = 1): self._set_react_mode(react_mode, max_react_loop) @@ -310,19 +342,10 @@ class Role(BaseModel): logger.debug(self._actions) self._rc.todo = self._actions[self._rc.state] if state >= 0 else None - def set_env(self, env: 'Environment'): + def set_env(self, env: "Environment"): """Set the environment in which the role works. The role can talk to the environment and can also receive messages by observing.""" self._rc.env = env - @property - def name(self): - return self._setting.name - - @property - def profile(self): - """Get the role description (position)""" - return self._setting.profile - def _get_prefix(self): """Get the role prefix""" if self._setting.desc: @@ -347,7 +370,7 @@ class Role(BaseModel): logger.debug(f"{prompt=}") if (not next_state.isdigit() and next_state != "-1") \ or int(next_state) not in range(-1, len(self._states)): - logger.warning(f'Invalid answer of state, {next_state=}, will be set to -1') + logger.warning(f"Invalid answer of state, {next_state=}, will be set to -1") next_state = -1 else: next_state = int(next_state) @@ -384,7 +407,7 @@ class Role(BaseModel): news_text = [f"{i.role}: {i.content[:20]}..." for i in self._rc.news] if news_text: - logger.debug(f'{self._setting} observed: {news_text}') + logger.debug(f"{self._setting} observed: {news_text}") return len(self._rc.news) def _publish_message(self, msg): @@ -400,7 +423,7 @@ class Role(BaseModel): Use llm to select actions in _think dynamically """ actions_taken = 0 - rsp = Message("No actions taken yet") # will be overwritten after Role _act + rsp = Message(content="No actions taken yet") # will be overwritten after Role _act while actions_taken < self._rc.max_react_loop: # think await self._think() @@ -410,7 +433,7 @@ class Role(BaseModel): logger.debug(f"{self._setting}: {self._rc.state=}, will do {self._rc.todo}") rsp = await self._act() actions_taken += 1 - return rsp # return output from the last action + return rsp # return output from the last action async def _act_by_order(self) -> Message: """switch action each time by order defined in _init_actions, i.e. _act (Action1) -> _act (Action2) -> ...""" @@ -454,7 +477,8 @@ class Role(BaseModel): def get_memories(self, k=0) -> list[Message]: """A wrapper to return the most recent k memories of this role, return all when k=0""" return self._rc.memory.get(k=k) - + + @role_raise_decorator async def run(self, message=None): """Observe, and think and act based on the results of the observation""" if message: diff --git a/metagpt/roles/sales.py b/metagpt/roles/sales.py index a45ad6f1b..dd360d82a 100644 --- a/metagpt/roles/sales.py +++ b/metagpt/roles/sales.py @@ -5,26 +5,34 @@ @Author : alexanderwu @File : sales.py """ + +from typing import Optional +from pydantic import Field + from metagpt.actions import SearchAndSummarize from metagpt.roles import Role from metagpt.tools import SearchEngineType class Sales(Role): + + name: str = Field(default="Xiaomei") + profile: str = Field(default="Retail sales guide") + desc: str = "I am a sales guide in retail. My name is Xiaomei. I will answer some customer questions next, and I " + "will answer questions only based on the information in the knowledge base." + "If I feel that you can't get the answer from the reference material, then I will directly reply that" + " I don't know, and I won't tell you that this is from the knowledge base," + "but pretend to be what I know. Note that each of my replies will be replied in the tone of a " + "professional guide", + + store: Optional[str] = None + def __init__( self, - name="Xiaomei", - profile="Retail sales guide", - desc="I am a sales guide in retail. My name is Xiaomei. I will answer some customer questions next, and I " - "will answer questions only based on the information in the knowledge base." - "If I feel that you can't get the answer from the reference material, then I will directly reply that" - " I don't know, and I won't tell you that this is from the knowledge base," - "but pretend to be what I know. Note that each of my replies will be replied in the tone of a " - "professional guide", - store=None + **kwargs ): - super().__init__(name, profile, desc=desc) - self._set_store(store) + super().__init__(**kwargs) + self._set_store(self.store) def _set_store(self, store): if store: @@ -32,4 +40,3 @@ class Sales(Role): else: action = SearchAndSummarize() self._init_actions([action]) - \ No newline at end of file diff --git a/metagpt/roles/seacher.py b/metagpt/roles/seacher.py index 0b6e089da..e8f291d0d 100644 --- a/metagpt/roles/seacher.py +++ b/metagpt/roles/seacher.py @@ -5,6 +5,9 @@ @Author : alexanderwu @File : seacher.py """ + +from pydantic import Field + from metagpt.actions import ActionOutput, SearchAndSummarize from metagpt.logs import logger from metagpt.roles import Role @@ -23,14 +26,14 @@ class Searcher(Role): constraints (str): Constraints or limitations for the searcher. engine (SearchEngineType): The type of search engine to use. """ + + name: str = Field(default="Alice") + profile: str = Field(default="Smart Assistant") + goal: str = "Provide search services for users" + constraints: str = "Answer is rich and complete" + engine: SearchEngineType = SearchEngineType.SERPAPI_GOOGLE - def __init__(self, - name: str = 'Alice', - profile: str = 'Smart Assistant', - goal: str = 'Provide search services for users', - constraints: str = 'Answer is rich and complete', - engine=SearchEngineType.SERPAPI_GOOGLE, - **kwargs) -> None: + def __init__(self, **kwargs) -> None: """ Initializes the Searcher role with given attributes. @@ -41,8 +44,8 @@ class Searcher(Role): constraints (str): Constraints or limitations for the searcher. engine (SearchEngineType): The type of search engine to use. """ - super().__init__(name, profile, goal, constraints, **kwargs) - self._init_actions([SearchAndSummarize(engine=engine)]) + super().__init__(**kwargs) + self._init_actions([SearchAndSummarize(engine=self.engine)]) def set_search_func(self, search_func): """Sets a custom search function for the searcher.""" diff --git a/metagpt/schema.py b/metagpt/schema.py index 3374a7241..60aa819b0 100644 --- a/metagpt/schema.py +++ b/metagpt/schema.py @@ -5,18 +5,17 @@ @Author : alexanderwu @File : schema.py """ -from __future__ import annotations from dataclasses import dataclass, field -from typing import Type, TypedDict -import copy +from typing import Type, TypedDict, Union, Optional -from pydantic import BaseModel +from pydantic import BaseModel, Field +from pydantic.main import ModelMetaclass from metagpt.logs import logger -# from metagpt.utils.serialize import actionoutout_schema_to_mapping -# from metagpt.actions.action_output import ActionOutput -# from metagpt.actions.action import Action +from metagpt.utils.serialize import actionoutout_schema_to_mapping, actionoutput_mapping_to_str, \ + actionoutput_str_to_mapping +from metagpt.utils.utils import import_class class RawMessage(TypedDict): @@ -24,16 +23,72 @@ class RawMessage(TypedDict): role: str -@dataclass -class Message: - """list[: ]""" - content: str - instruct_content: BaseModel = field(default=None) - role: str = field(default='user') # system / user / assistant - cause_by: Type["Action"] = field(default="") - sent_from: str = field(default="") - send_to: str = field(default="") - restricted_to: str = field(default="") +class Message(BaseModel): + content: str = "" + instruct_content: BaseModel = Field(default=None) + role: str = "user" # system / user / assistant + cause_by: Type["Action"] = Field(default=None) + sent_from: str = "" + send_to: str = "" + restricted_to: str = "" + + def __init__(self, **kwargs): + instruct_content = kwargs.get("instruct_content", None) + cause_by = kwargs.get("cause_by", None) + if instruct_content and not isinstance(instruct_content, BaseModel): + ic = instruct_content + mapping = actionoutput_str_to_mapping(ic["mapping"]) + + actionoutput_class = import_class("ActionOutput", "metagpt.actions.action_output") + ic_obj = actionoutput_class.create_model_class(class_name=ic["class"], mapping=mapping) + ic_new = ic_obj(**ic["value"]) + kwargs["instruct_content"] = ic_new + if cause_by and not isinstance(cause_by, ModelMetaclass): + action_class = import_class("Action", "metagpt.actions.action") + kwargs["cause_by"] = action_class.deser_class(cause_by) + super(Message, self).__init__(**kwargs) + + def dict(self, + *, + include: Optional[Union["AbstractSetIntStr", "MappingIntStrAny"]] = None, + exclude: Optional[Union["AbstractSetIntStr", "MappingIntStrAny"]] = None, + by_alias: bool = False, + skip_defaults: Optional[bool] = None, + exclude_unset: bool = False, + exclude_defaults: bool = False, + exclude_none: bool = False) -> "DictStrAny": + """ overwrite the `dict` to dump dynamic pydantic model""" + obj_dict = super(Message, self).dict(include=include, + exclude=exclude, + by_alias=by_alias, + skip_defaults=skip_defaults, + exclude_unset=exclude_unset, + exclude_defaults=exclude_defaults, + exclude_none=exclude_none) + ic = self.instruct_content # deal custom-defined action + if ic: + schema = ic.schema() + mapping = actionoutout_schema_to_mapping(schema) + mapping = actionoutput_mapping_to_str(mapping) + + obj_dict["instruct_content"] = {"class": schema["title"], "mapping": mapping, "value": ic.dict()} + cb = self.cause_by + if cb: + obj_dict["cause_by"] = cb.ser_class() + return obj_dict + +# +# +# @dataclass +# class Message: +# """list[: ]""" +# content: str +# instruct_content: BaseModel = field(default=None) +# role: str = field(default='user') # system / user / assistant +# cause_by: Type["Action"] = field(default="") +# sent_from: str = field(default="") +# send_to: str = field(default="") +# restricted_to: str = field(default="") def __str__(self): # prefix = '-'.join([self.role, str(self.cause_by)]) @@ -42,45 +97,16 @@ class Message: def __repr__(self): return self.__str__() - # def serialize(self): - # message_cp: Message = copy.deepcopy(self) - # ic = message_cp.instruct_content - # if ic: - # # model create by pydantic create_model like `pydantic.main.prd`, can't pickle.dump directly - # schema = ic.schema() - # mapping = actionoutout_schema_to_mapping(schema) - # - # message_cp.instruct_content = {"class": schema["title"], "mapping": mapping, "value": ic.dict()} - # cb = message_cp.cause_by - # if cb: - # message_cp.cause_by = cb.serialize() - # - # return message_cp.dict() - # - # @classmethod - # def deserialize(cls, message_dict: dict): - # instruct_content = message_dict.get("instruct_content") - # if instruct_content: - # ic = instruct_content - # ic_obj = ActionOutput.create_model_class(class_name=ic["class"], mapping=ic["mapping"]) - # ic_new = ic_obj(**ic["value"]) - # message_dict.instruct_content = ic_new - # cause_by = message_dict.get("cause_by") - # if cause_by: - # message_dict.cause_by = Action.deserialize(cause_by) - # - # return Message(**message_dict) - - def dict(self): - return { - "content": self.content, - "instruct_content": self.instruct_content, - "role": self.role, - "cause_by": self.cause_by, - "sent_from": self.sent_from, - "send_to": self.send_to, - "restricted_to": self.restricted_to - } + # def dict(self): + # return { + # "content": self.content, + # "instruct_content": self.instruct_content, + # "role": self.role, + # "cause_by": self.cause_by, + # "sent_from": self.sent_from, + # "send_to": self.send_to, + # "restricted_to": self.restricted_to + # } def to_dict(self) -> dict: return { diff --git a/metagpt/team.py b/metagpt/team.py index 3b76e5ff4..795019b92 100644 --- a/metagpt/team.py +++ b/metagpt/team.py @@ -15,7 +15,8 @@ from metagpt.logs import logger from metagpt.roles import Role from metagpt.schema import Message from metagpt.utils.common import NoMoneyException -from metagpt.utils.utils import read_json_file, write_json_file +from metagpt.utils.utils import read_json_file, write_json_file, serialize_decorator +from metagpt.const import SERDESER_PATH class Team(BaseModel): @@ -30,29 +31,35 @@ class Team(BaseModel): class Config: arbitrary_types_allowed = True - def serialize(self, stg_path: Path): + def serialize(self, stg_path: Path = None): + stg_path = SERDESER_PATH.joinpath("team") if stg_path is None else stg_path + team_info_path = stg_path.joinpath("team_info.json") - write_json_file(team_info_path, { - "idea": self.idea, - "investment": self.investment - }) + write_json_file(team_info_path, self.dict(exclude={"environment": True})) - self.environment.serialize(stg_path.joinpath("environment")) + self.environment.serialize(stg_path.joinpath("environment")) # save environment alone - def deserialize(self, stg_path: Path): + @classmethod + def recover(cls, stg_path: Path) -> "Team": + return cls.deserialize(stg_path) + + @classmethod + def deserialize(cls, stg_path: Path) -> "Team": """ stg_path = ./storage/team """ # recover team_info team_info_path = stg_path.joinpath("team_info.json") if not team_info_path.exists(): - logger.error("recover storage not exist, not to recover and continue run the old project.") - team_info = read_json_file(team_info_path) - self.investment = team_info.get("investment", 10.0) - self.idea = team_info.get("idea", "") + raise FileNotFoundError("recover storage meta file `team_info.json` not exist, " + "not to recover and please start a new project.") + + team_info: dict = read_json_file(team_info_path) # recover environment - environment_path = stg_path.joinpath("environment") - self.environment = Environment() - self.environment.deserialize(stg_path=environment_path) + environment = Environment.deserialize(stg_path=stg_path.joinpath("environment")) + team_info.update({"environment": environment}) + + team = Team(**team_info) + return team def hire(self, roles: list[Role]): """Hire roles to cooperate""" @@ -76,6 +83,7 @@ class Team(BaseModel): def _save(self): logger.info(self.json()) + @serialize_decorator async def run(self, n_round=3): """Run company until target round or no money""" while n_round > 0: @@ -85,4 +93,3 @@ class Team(BaseModel): self._check_balance() await self.environment.run() return self.environment.history - \ No newline at end of file diff --git a/metagpt/utils/serialize.py b/metagpt/utils/serialize.py index 56a866f2e..9a7049214 100644 --- a/metagpt/utils/serialize.py +++ b/metagpt/utils/serialize.py @@ -5,9 +5,7 @@ import copy import pickle -from metagpt.actions.action_output import ActionOutput -from metagpt.schema import Message -from metagpt.actions.action import Action +from metagpt.utils.utils import import_class def actionoutout_schema_to_mapping(schema: dict) -> dict: @@ -59,7 +57,7 @@ def actionoutput_str_to_mapping(mapping: dict) -> dict: return new_mapping -def serialize_general_message(message: Message) -> dict: +def serialize_general_message(message: "Message") -> dict: """ serialize Message, not to save""" message_cp = copy.deepcopy(message) ic = message_cp.instruct_content @@ -76,7 +74,7 @@ def serialize_general_message(message: Message) -> dict: return message_cp.dict() -def serialize_message(message: Message): +def serialize_message(message: "Message"): message_cp = copy.deepcopy(message) # avoid `instruct_content` value update by reference ic = message_cp.instruct_content if ic: @@ -90,29 +88,35 @@ def serialize_message(message: Message): return msg_ser -def deserialize_general_message(message_dict: dict) -> Message: +def deserialize_general_message(message_dict: dict) -> "Message": """ deserialize Message, not to load""" instruct_content = message_dict.pop("instruct_content") cause_by = message_dict.pop("cause_by") - message = Message(**message_dict) + message_cls = import_class("Message", "metagpt.schema") + message = message_cls(**message_dict) if instruct_content: ic = instruct_content mapping = actionoutput_str_to_mapping(ic["mapping"]) - ic_obj = ActionOutput.create_model_class(class_name=ic["class"], mapping=mapping) + + actionoutput_class = import_class("ActionOutput", "metagpt.actions.action_output") + ic_obj = actionoutput_class.create_model_class(class_name=ic["class"], mapping=mapping) ic_new = ic_obj(**ic["value"]) message.instruct_content = ic_new if cause_by: - message.cause_by = Action.deser_class(cause_by) + action_class = import_class("Action", "metagpt.actions.action") + message.cause_by = action_class.deser_class(cause_by) return message -def deserialize_message(message_ser: str) -> Message: +def deserialize_message(message_ser: str) -> "Message": message = pickle.loads(message_ser) if message.instruct_content: ic = message.instruct_content - ic_obj = ActionOutput.create_model_class(class_name=ic["class"], mapping=ic["mapping"]) + + actionoutput_class = import_class("ActionOutput", "metagpt.actions.action_output") + ic_obj = actionoutput_class.create_model_class(class_name=ic["class"], mapping=ic["mapping"]) ic_new = ic_obj(**ic["value"]) message.instruct_content = ic_new diff --git a/metagpt/utils/utils.py b/metagpt/utils/utils.py index 81ceea884..1cf618ba0 100644 --- a/metagpt/utils/utils.py +++ b/metagpt/utils/utils.py @@ -6,6 +6,9 @@ from typing import Any import json from pathlib import Path import importlib +import traceback + +from metagpt.logs import logger def read_json_file(json_file: str, encoding=None) -> list[Any]: @@ -39,3 +42,43 @@ def import_class_inst(class_name: str, module_name: str, *args, **kwargs) -> obj a_class = import_class(class_name, module_name) class_inst = a_class(*args, **kwargs) return class_inst + + +def format_trackback_info(limit: int = 2): + return traceback.format_exc(limit=limit) + + +def serialize_decorator(func): + async def wrapper(self, *args, **kwargs): + try: + return await func(self, *args, **kwargs) + except KeyboardInterrupt as kbi: + logger.error(f"KeyboardInterrupt occurs, start to serialize the project, exp:\n{format_trackback_info()}") + self.serialize() # Team.serialize + except Exception as exp: + logger.error(f"Exception occurs, start to serialize the project, exp:\n{format_trackback_info()}") + self.serialize() # Team.serialize + + return wrapper + + +def role_raise_decorator(func): + async def wrapper(self, *args, **kwargs): + try: + return await func(self, *args, **kwargs) + except KeyboardInterrupt as kbi: + logger.error(f"KeyboardInterrupt: {kbi} occurs, start to serialize the project") + if self._rc.env: + newest_msgs = self._rc.env.memory.get(1) + if len(newest_msgs) > 0: + self._rc.memory.delete(newest_msgs[0]) + except Exception as exp: + if self._rc.env: + newest_msgs = self._rc.env.memory.get(1) + if len(newest_msgs) > 0: + logger.warning("There is a exception in role's execution, in order to resume, " + "we delete the newest role communication message in the role's memory.") + self._rc.memory.delete(newest_msgs[0]) # remove newest msg of the role to make it observed again + raise Exception(format_trackback_info(limit=None)) # raise again to make it captured outside + + return wrapper diff --git a/startup.py b/startup.py index 9f753d553..c4928a1b5 100644 --- a/startup.py +++ b/startup.py @@ -1,10 +1,11 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- + +from typing import Optional import asyncio - import fire +from pathlib import Path -from metagpt.const import SERDES_PATH from metagpt.roles import ( Architect, Engineer, @@ -22,11 +23,11 @@ async def startup( code_review: bool = False, run_tests: bool = False, implement: bool = True, - recover_path: bool = False, + recover_path: Optional[str] = None, ): """Run a startup. Be a boss.""" - company = Team() if not recover_path: + company = Team() company.hire( [ ProductManager(), @@ -45,8 +46,12 @@ async def startup( # (bug fixing capability comes soon!) company.hire([QaEngineer()]) else: - stg_path = SERDES_PATH.joinpath("team") - company.deserialize(stg_path=stg_path) + # # stg_path = SERDESER_PATH.joinpath("team") + stg_path = Path(recover_path) + if not stg_path.exists() or not str(stg_path).endswith("team"): + raise FileNotFoundError(f"{recover_path} not exists or not endswith `team`") + + company = Team.recover(stg_path=stg_path) idea = company.idea # use original idea company.invest(investment) From caacfcff7a541b7e69928cb0ed078fd98f89b55b Mon Sep 17 00:00:00 2001 From: better629 Date: Thu, 30 Nov 2023 19:30:02 +0800 Subject: [PATCH 07/14] fix ut of serialize_deserialize --- .../serialize_deserialize/test_action.py | 3 +-- .../test_product_manager.py | 1 - .../serialize_deserialize/test_role.py | 10 ++++++++- .../test_serdeser_base.py | 21 +++++++++++++------ .../serialize_deserialize/test_team.py | 2 +- .../serialize_deserialize/test_wrire_prd.py | 4 ++-- .../serialize_deserialize/test_write_code.py | 2 -- .../test_write_design.py | 3 +-- 8 files changed, 29 insertions(+), 17 deletions(-) diff --git a/tests/metagpt/serialize_deserialize/test_action.py b/tests/metagpt/serialize_deserialize/test_action.py index b624dff5a..0138d41ce 100644 --- a/tests/metagpt/serialize_deserialize/test_action.py +++ b/tests/metagpt/serialize_deserialize/test_action.py @@ -13,14 +13,13 @@ def test_action_serialize(): action = Action() ser_action_dict = action.dict() assert "name" in ser_action_dict - assert "llm" in ser_action_dict + assert "llm" not in ser_action_dict @pytest.mark.asyncio async def test_action_deserialize(): action = Action() serialized_data = action.dict() - assert isinstance(serialized_data["llm"], OpenAIGPTAPI) new_action = Action(**serialized_data) diff --git a/tests/metagpt/serialize_deserialize/test_product_manager.py b/tests/metagpt/serialize_deserialize/test_product_manager.py index 54584cf96..25bc07a11 100644 --- a/tests/metagpt/serialize_deserialize/test_product_manager.py +++ b/tests/metagpt/serialize_deserialize/test_product_manager.py @@ -14,7 +14,6 @@ async def test_product_manager_deserialize(): role = ProductManager() ser_role_dict = role.dict(by_alias=True) new_role = ProductManager(**ser_role_dict) - # new_role = ProductManager().deserialize(ser_role_dict) assert new_role.name == "Alice" assert len(new_role._actions) == 1 diff --git a/tests/metagpt/serialize_deserialize/test_role.py b/tests/metagpt/serialize_deserialize/test_role.py index f260dea3a..c21b9cc2e 100644 --- a/tests/metagpt/serialize_deserialize/test_role.py +++ b/tests/metagpt/serialize_deserialize/test_role.py @@ -17,7 +17,15 @@ from metagpt.const import SERDESER_PATH from metagpt.roles.engineer import Engineer from metagpt.utils.utils import format_trackback_info -from tests.metagpt.serialize_deserialize.test_serdeser_base import RoleC, serdeser_path +from tests.metagpt.serialize_deserialize.test_serdeser_base import RoleA, RoleB, RoleC, serdeser_path + + +def test_roles(): + role_a = RoleA() + assert len(role_a._rc.watch) == 1 + role_b = RoleB() + assert len(role_a._rc.watch) == 1 + assert len(role_b._rc.watch) == 1 def test_role_serialize(): diff --git a/tests/metagpt/serialize_deserialize/test_serdeser_base.py b/tests/metagpt/serialize_deserialize/test_serdeser_base.py index 35bad6cd9..00d894b3d 100644 --- a/tests/metagpt/serialize_deserialize/test_serdeser_base.py +++ b/tests/metagpt/serialize_deserialize/test_serdeser_base.py @@ -8,6 +8,7 @@ from pathlib import Path from metagpt.actions.action import Action from metagpt.roles.role import Role, RoleReactMode from metagpt.actions.add_requirement import BossRequirement +from metagpt.actions.action_output import ActionOutput serdeser_path = Path(__file__).absolute().parent.joinpath("../../data/serdeser_storage") @@ -22,21 +23,27 @@ class MockMessage(BaseModel): class ActionPass(Action): name: str = "ActionPass" - async def run(self, messages: list["Message"]): - return "pass" + async def run(self, messages: list["Message"]) -> ActionOutput: + output_mapping = { + "result": (str, ...) + } + pass_class = ActionOutput.create_model_class("pass", output_mapping) + pass_output = ActionOutput("ActionPass run passed", pass_class(**{"result": "pass result"})) + + return pass_output class ActionOK(Action): name: str = "ActionOK" - async def run(self, messages: list["Message"]): + async def run(self, messages: list["Message"]) -> str: return "ok" class ActionRaise(Action): name: str = "ActionRaise" - async def run(self, messages: list["Message"]): + async def run(self, messages: list["Message"]) -> str: raise RuntimeError("parse error in ActionRaise") @@ -48,7 +55,8 @@ class RoleA(Role): constraints: str = "RoleA's constraints" def __init__(self, **kwargs): - super(RoleA, self).__init__(**kwargs) + # super(RoleA, self).__init__(**kwargs) + super().__init__(**kwargs) self._init_actions([ActionPass]) self._watch([BossRequirement]) @@ -63,7 +71,8 @@ class RoleB(Role): constraints: str = "RoleB's constraints" def __init__(self, **kwargs): - super(RoleB, self).__init__(**kwargs) + # super(RoleB, self).__init__(**kwargs) + super().__init__(**kwargs) self._init_actions([ActionOK, ActionRaise]) self._watch([ActionPass]) self._rc.react_mode = RoleReactMode.BY_ORDER diff --git a/tests/metagpt/serialize_deserialize/test_team.py b/tests/metagpt/serialize_deserialize/test_team.py index e9122ebc0..b8972135b 100644 --- a/tests/metagpt/serialize_deserialize/test_team.py +++ b/tests/metagpt/serialize_deserialize/test_team.py @@ -11,7 +11,7 @@ from metagpt.roles import ProjectManager, ProductManager, Architect from metagpt.team import Team from metagpt.const import SERDESER_PATH -from tests.metagpt.serialize_deserialize.test_serdeser_base import RoleA, RoleB, RoleC, serdeser_path +from tests.metagpt.serialize_deserialize.test_serdeser_base import RoleA, RoleB, RoleC, serdeser_path, ActionOK def test_team_deserialize(): diff --git a/tests/metagpt/serialize_deserialize/test_wrire_prd.py b/tests/metagpt/serialize_deserialize/test_wrire_prd.py index 96b4d19ad..05a86cb7f 100644 --- a/tests/metagpt/serialize_deserialize/test_wrire_prd.py +++ b/tests/metagpt/serialize_deserialize/test_wrire_prd.py @@ -21,7 +21,7 @@ async def test_action_deserialize(): action = WritePRD() serialized_data = action.dict() new_action = WritePRD(**serialized_data) - # new_action = WritePRD().deserialize(serialized_data) assert new_action.name == "" assert new_action.llm == LLM() - assert len(await new_action.run([Message(content="write a cli snake game")])) > 0 + action_output = await new_action.run([Message(content="write a cli snake game")]) + assert len(action_output.content) > 0 diff --git a/tests/metagpt/serialize_deserialize/test_write_code.py b/tests/metagpt/serialize_deserialize/test_write_code.py index 7f4799014..4e3b712c0 100644 --- a/tests/metagpt/serialize_deserialize/test_write_code.py +++ b/tests/metagpt/serialize_deserialize/test_write_code.py @@ -27,7 +27,6 @@ async def test_write_code_deserialize(): action = WriteCode() serialized_data = action.dict() new_action = WriteCode(**serialized_data) - # new_action = WriteCode().deserialize(serialized_data) assert new_action.name == "WriteCode" assert new_action.llm == LLM() await new_action.run(context="write a cli snake game", filename="test_code") @@ -38,7 +37,6 @@ async def test_write_code_review_deserialize(): action = WriteCodeReview() serialized_data = action.dict() new_action = WriteCodeReview(**serialized_data) - # new_action = WriteCodeReview().deserialize(serialized_data) code = await WriteCode().run(context="write a cli snake game", filename="test_code") assert new_action.name == "WriteCodeReview" diff --git a/tests/metagpt/serialize_deserialize/test_write_design.py b/tests/metagpt/serialize_deserialize/test_write_design.py index e6e236676..5b2a30ed3 100644 --- a/tests/metagpt/serialize_deserialize/test_write_design.py +++ b/tests/metagpt/serialize_deserialize/test_write_design.py @@ -26,7 +26,7 @@ def test_write_task_serialize(): async def test_write_design_deserialize(): action = WriteDesign() serialized_data = action.dict() - new_action = WriteDesign().deserialize(serialized_data) + new_action = WriteDesign(**serialized_data) assert new_action.name == "" assert new_action.llm == LLM() await new_action.run(context="write a cli snake game") @@ -37,7 +37,6 @@ async def test_write_task_deserialize(): action = WriteTasks() serialized_data = action.dict() new_action = WriteTasks(**serialized_data) - # new_action = WriteTasks().deserialize(serialized_data) assert new_action.name == "CreateTasks" assert new_action.llm == LLM() await new_action.run(context="write a cli snake game") From c70c8358d334d8297a0a33b95223d604c84096cd Mon Sep 17 00:00:00 2001 From: better629 Date: Thu, 30 Nov 2023 19:31:26 +0800 Subject: [PATCH 08/14] fix actions/roles ser&deser --- metagpt/actions/search_and_summarize.py | 16 +++++++--------- metagpt/actions/write_prd.py | 15 ++++++--------- metagpt/roles/role.py | 20 ++++++++++++++------ metagpt/utils/utils.py | 4 +++- 4 files changed, 30 insertions(+), 25 deletions(-) diff --git a/metagpt/actions/search_and_summarize.py b/metagpt/actions/search_and_summarize.py index 6b0c1f717..32444b302 100644 --- a/metagpt/actions/search_and_summarize.py +++ b/metagpt/actions/search_and_summarize.py @@ -117,23 +117,21 @@ class SearchAndSummarize(Action): @root_validator def validate_engine_and_run_func(cls, values): - engine = values.get('engine') - search_func = values.get('search_func') + engine = values.get("engine") + search_func = values.get("search_func") config = Config() if engine is None: engine = config.search_engine - config_data = { - 'engine': engine, - 'run_func': search_func - } - search_engine = SearchEngine(**config_data) + try: + search_engine = SearchEngine(engine=engine, run_func=search_func) + except pydantic.ValidationError: + search_engine = None - values['search_engine'] = search_engine + values["search_engine"] = search_engine return values async def run(self, context: list[Message], system_text=SEARCH_AND_SUMMARIZE_SYSTEM) -> str: - print(context) if self.search_engine is None: logger.warning("Configure one of SERPAPI_API_KEY, SERPER_API_KEY, GOOGLE_API_KEY to unlock full feature") return "" diff --git a/metagpt/actions/write_prd.py b/metagpt/actions/write_prd.py index 450bed7e7..86f0ad9a6 100644 --- a/metagpt/actions/write_prd.py +++ b/metagpt/actions/write_prd.py @@ -226,17 +226,14 @@ class WritePRD(Action): name: str = "" content: Optional[str] = None llm: BaseGPTAPI = Field(default_factory=LLM) - assistant_search_action: Action = None async def run(self, requirements, format=CONFIG.prompt_format, *args, **kwargs) -> ActionOutput: - # self.assistant_search_action = SearchAndSummarize() - if self.assistant_search_action is None: - self.assistant_search_action = SearchAndSummarize() - # self.assistant_search_action = SearchAndSummarize() - rsp = await self.assistant_search_action.run(context=requirements) - info = f"### Search Results\n{self.assistant_search_action.result}\n\n### Search Summary\n{rsp}" - if self.assistant_search_action.result: - logger.info(self.assistant_search_action.result) + sas = SearchAndSummarize() + # rsp = await sas.run(context=requirements, system_text=SEARCH_AND_SUMMARIZE_SYSTEM_EN_US) + rsp = "" + info = f"### Search Results\n{sas.result}\n\n### Search Summary\n{rsp}" + if sas.result: + logger.info(sas.result) logger.info(rsp) prompt_template, format_example = get_template(templates, format) diff --git a/metagpt/roles/role.py b/metagpt/roles/role.py index b6332aa4c..38f564caa 100644 --- a/metagpt/roles/role.py +++ b/metagpt/roles/role.py @@ -88,7 +88,7 @@ class RoleSetting(BaseModel): class RoleContext(BaseModel): """Role Runtime Context""" - env: "Environment" = Field(default=None) + env: "Environment" = Field(default=None, exclude=True) memory: Memory = Field(default_factory=Memory) long_term_memory: LongTermMemory = Field(default_factory=LongTermMemory) state: int = Field(default=-1) # -1 indicates initial or termination state where todo is None @@ -133,7 +133,7 @@ class Role(BaseModel): _role_id: str = "" _states: list[str] = Field(default=[]) _actions: list[Action] = Field(default=[]) - _rc: RoleContext = RoleContext() + _rc: RoleContext = Field(default=RoleContext, exclude=True) # builtin variables recovered: bool = False # to tag if a recovered role @@ -143,7 +143,8 @@ class Role(BaseModel): "_llm": LLM() if not is_human else HumanProvider(), "_role_id": _role_id, "_states": [], - "_actions": [] + "_actions": [], + "_rc": RoleContext() } class Config: @@ -169,6 +170,8 @@ class Role(BaseModel): self._private_attributes["_setting"] = RoleSetting(name=self.name, profile=self.profile, goal=self.goal, desc=self.desc, constraints=self.constraints, is_human=self.is_human) + self._private_attributes["_role_id"] = str(self._setting) + for key in self._private_attributes.keys(): if key in kwargs: object.__setattr__(self, key, kwargs[key]) @@ -176,10 +179,15 @@ class Role(BaseModel): setting = RoleSetting(**kwargs[key]) object.__setattr__(self, "_setting", setting) elif key == "_rc": - _rc = RoleContext + _rc = RoleContext() object.__setattr__(self, "_rc", _rc) else: - object.__setattr__(self, key, self._private_attributes[key]) + if key == "_rc": + # # Warning, if use self._private_attributes["_rc"], + # # self._rc will be a shared object between roles, so init one or reset it inside `_reset` + object.__setattr__(self, key, RoleContext()) + else: + object.__setattr__(self, key, self._private_attributes[key]) # deserialize child classes dynamically for inherited `role` object.__setattr__(self, "builtin_class_name", self.__class__.__name__) @@ -192,6 +200,7 @@ class Role(BaseModel): def _reset(self): object.__setattr__(self, "_states", []) object.__setattr__(self, "_actions", []) + # object.__setattr__(self, "_rc", RoleContext()) def serialize(self, stg_path: Path = None): stg_path = SERDESER_PATH.joinpath(f"team/environment/roles/{self.__class__.__name__}_{self.name}") \ @@ -289,7 +298,6 @@ class Role(BaseModel): for idx, action in enumerate(actions): if not isinstance(action, Action): ## 默认初始化 - # import pdb; pdb.set_trace() i = action(name="", llm=self._llm) else: if self._setting.is_human and not isinstance(action.llm, HumanProvider): diff --git a/metagpt/utils/utils.py b/metagpt/utils/utils.py index 1cf618ba0..b72dabf7e 100644 --- a/metagpt/utils/utils.py +++ b/metagpt/utils/utils.py @@ -51,7 +51,9 @@ def format_trackback_info(limit: int = 2): def serialize_decorator(func): async def wrapper(self, *args, **kwargs): try: - return await func(self, *args, **kwargs) + result = await func(self, *args, **kwargs) + self.serialize() # Team.serialize + return result except KeyboardInterrupt as kbi: logger.error(f"KeyboardInterrupt occurs, start to serialize the project, exp:\n{format_trackback_info()}") self.serialize() # Team.serialize From 6208400f71ee926ed422aed9ed2cc160d7a0de4e Mon Sep 17 00:00:00 2001 From: better629 Date: Thu, 30 Nov 2023 21:42:09 +0800 Subject: [PATCH 09/14] fix role._rc init --- metagpt/environment.py | 4 ++++ metagpt/roles/role.py | 11 ++++++----- .../serialize_deserialize/test_team.py | 19 ++++++++++++++++--- 3 files changed, 26 insertions(+), 8 deletions(-) diff --git a/metagpt/environment.py b/metagpt/environment.py index bade53f50..bff12210d 100644 --- a/metagpt/environment.py +++ b/metagpt/environment.py @@ -31,6 +31,7 @@ class Environment(BaseModel): arbitrary_types_allowed = True def __init__(self, **kwargs): + roles = [] for role_key, role in kwargs.get("roles", {}).items(): current_role = kwargs["roles"][role_key] if isinstance(current_role, dict): @@ -41,8 +42,11 @@ class Environment(BaseModel): current_role = subclass(**current_role) break kwargs["roles"][role_key] = current_role + roles.append(current_role) super().__init__(**kwargs) + self.add_roles(roles) # add_roles again to init the Role.set_env + def serialize(self, stg_path: Path): roles_path = stg_path.joinpath("roles.json") roles_info = [] diff --git a/metagpt/roles/role.py b/metagpt/roles/role.py index 38f564caa..b78597d01 100644 --- a/metagpt/roles/role.py +++ b/metagpt/roles/role.py @@ -88,13 +88,14 @@ class RoleSetting(BaseModel): class RoleContext(BaseModel): """Role Runtime Context""" + # # env exclude=True to avoid `RecursionError: maximum recursion depth exceeded in comparison` env: "Environment" = Field(default=None, exclude=True) memory: Memory = Field(default_factory=Memory) - long_term_memory: LongTermMemory = Field(default_factory=LongTermMemory) + long_term_memory: LongTermMemory = Field(default_factory=LongTermMemory, exclude=True) # TODO not used now state: int = Field(default=-1) # -1 indicates initial or termination state where todo is None todo: Action = Field(default=None) watch: set[Type[Action]] = Field(default_factory=set) - news: list[Type[Message]] = Field(default=[]) + news: list[Type[Message]] = Field(default=[], exclude=True) # TODO not used react_mode: RoleReactMode = RoleReactMode.REACT # see `Role._set_react_mode` for definitions of the following two attributes max_react_loop: int = 1 @@ -128,12 +129,12 @@ class Role(BaseModel): desc: str = "" is_human: bool = False - _llm: BaseGPTAPI = Field(default_factory=LLM, exclude=True) + _llm: BaseGPTAPI = Field(default_factory=LLM) _setting: RoleSetting = Field(default_factory=RoleSetting, alias=True) _role_id: str = "" _states: list[str] = Field(default=[]) _actions: list[Action] = Field(default=[]) - _rc: RoleContext = Field(default=RoleContext, exclude=True) + _rc: RoleContext = Field(default=RoleContext) # builtin variables recovered: bool = False # to tag if a recovered role @@ -179,7 +180,7 @@ class Role(BaseModel): setting = RoleSetting(**kwargs[key]) object.__setattr__(self, "_setting", setting) elif key == "_rc": - _rc = RoleContext() + _rc = RoleContext(**kwargs["_rc"]) object.__setattr__(self, "_rc", _rc) else: if key == "_rc": diff --git a/tests/metagpt/serialize_deserialize/test_team.py b/tests/metagpt/serialize_deserialize/test_team.py index b8972135b..e5ec20f2e 100644 --- a/tests/metagpt/serialize_deserialize/test_team.py +++ b/tests/metagpt/serialize_deserialize/test_team.py @@ -39,7 +39,7 @@ def test_team_deserialize(): assert new_company.environment.get_role(arch.profile) is not None -def test_team_serdeser(): +def test_team_serdeser_save(): company = Team() company.hire([RoleC()]) @@ -60,12 +60,19 @@ async def test_team_recover(): shutil.rmtree(stg_path, ignore_errors=True) company = Team() - company.hire([RoleC()]) + role_c = RoleC() + company.hire([role_c]) company.start_project(idea) await company.run(n_round=4) ser_data = company.dict() new_company = Team(**ser_data) + + new_role_c = new_company.environment.get_role(role_c.profile) + assert new_role_c._rc.memory == role_c._rc.memory + assert new_role_c._rc.env != role_c._rc.env # due to Action raise, role's memory has been changed. + assert new_role_c._rc.env.memory == role_c._rc.env.memory + assert new_company.environment.memory.count() == 1 assert type(list(new_company.environment.roles.values())[0]._actions[0]) == ActionOK @@ -80,11 +87,17 @@ async def test_team_recover_save(): shutil.rmtree(stg_path, ignore_errors=True) company = Team() - company.hire([RoleC()]) + role_c = RoleC() + company.hire([role_c]) company.start_project(idea) await company.run(n_round=4) new_company = Team.recover(stg_path) + new_role_c = new_company.environment.get_role(role_c.profile) + assert new_role_c._rc.memory == role_c._rc.memory + assert new_role_c._rc.env != role_c._rc.env # due to Action raise, role's memory has been changed. + assert new_role_c._rc.env.memory == role_c._rc.env.memory + new_company.start_project(idea) await new_company.run(n_round=4) From f563b2c60809d45db87387956586acd18ddc9201 Mon Sep 17 00:00:00 2001 From: better629 Date: Fri, 1 Dec 2023 14:43:45 +0800 Subject: [PATCH 10/14] simplify some ser&desr code --- metagpt/actions/action.py | 20 ++----- metagpt/environment.py | 6 +- metagpt/memory/memory.py | 18 +----- metagpt/roles/role.py | 114 ++++++++++++++------------------------ metagpt/schema.py | 42 +------------- 5 files changed, 54 insertions(+), 146 deletions(-) diff --git a/metagpt/actions/action.py b/metagpt/actions/action.py index 7a7f194f4..692a2a6e5 100644 --- a/metagpt/actions/action.py +++ b/metagpt/actions/action.py @@ -52,6 +52,12 @@ class Action(BaseModel): super().__init_subclass__(**kwargs) action_subclass_registry[cls.__name__] = cls + def dict(self, *args, **kwargs) -> "DictStrAny": + obj_dict = super(Action, self).dict(*args, **kwargs) + if "llm" in obj_dict: + obj_dict.pop("llm") + return obj_dict + def set_prefix(self, prefix, profile): """Set prefix for later usage""" self.prefix = prefix @@ -63,20 +69,6 @@ class Action(BaseModel): def __repr__(self): return self.__str__() - def serialize(self): - return { - "action_class": self.__class__.__name__, - "module_name": self.__module__, - "name": self.name - } - - @classmethod - def deserialize(cls, action_dict: dict) -> "Action": - action_class_str = action_dict.pop("action_class") - module_name = action_dict.pop("module_name") - action_class = import_class(action_class_str, module_name) - return action_class(**action_dict) - @classmethod def ser_class(cls) -> dict: """ serialize class type""" diff --git a/metagpt/environment.py b/metagpt/environment.py index bff12210d..3174cfc10 100644 --- a/metagpt/environment.py +++ b/metagpt/environment.py @@ -70,10 +70,8 @@ class Environment(BaseModel): roles_info = read_json_file(roles_path) roles = [] for role_info in roles_info: - role_class = role_info.get("role_class") - role_name = role_info.get("role_name") - - role_path = stg_path.joinpath(f"roles/{role_class}_{role_name}") + # role stored in ./environment/roles/{role_class}_{role_name} + role_path = stg_path.joinpath(f'roles/{role_info.get("role_class")}_{role_info.get("role_name")}') role = Role.deserialize(role_path) roles.append(role) diff --git a/metagpt/memory/memory.py b/metagpt/memory/memory.py index c88cc750e..ed30cde18 100644 --- a/metagpt/memory/memory.py +++ b/metagpt/memory/memory.py @@ -36,23 +36,9 @@ class Memory(BaseModel): super(Memory, self).__init__(**kwargs) self.index = new_index - def dict(self, - *, - include: Optional[Union["AbstractSetIntStr", "MappingIntStrAny"]] = None, - exclude: Optional[Union["AbstractSetIntStr", "MappingIntStrAny"]] = None, - by_alias: bool = False, - skip_defaults: Optional[bool] = None, - exclude_unset: bool = False, - exclude_defaults: bool = False, - exclude_none: bool = False) -> "DictStrAny": + def dict(self, *args, **kwargs) -> "DictStrAny": """ overwrite the `dict` to dump dynamic pydantic model""" - obj_dict = super(Memory, self).dict(include=include, - exclude=exclude, - by_alias=by_alias, - skip_defaults=skip_defaults, - exclude_unset=exclude_unset, - exclude_defaults=exclude_defaults, - exclude_none=exclude_none) + obj_dict = super(Memory, self).dict(*args, **kwargs) new_obj_dict = copy.deepcopy(obj_dict) new_obj_dict["index"] = {} for action, value in obj_dict["index"].items(): diff --git a/metagpt/roles/role.py b/metagpt/roles/role.py index b78597d01..4e669772e 100644 --- a/metagpt/roles/role.py +++ b/metagpt/roles/role.py @@ -93,7 +93,7 @@ class RoleContext(BaseModel): memory: Memory = Field(default_factory=Memory) long_term_memory: LongTermMemory = Field(default_factory=LongTermMemory, exclude=True) # TODO not used now state: int = Field(default=-1) # -1 indicates initial or termination state where todo is None - todo: Action = Field(default=None) + todo: Action = Field(default=None, exclude=True) watch: set[Type[Action]] = Field(default_factory=set) news: list[Type[Message]] = Field(default=[], exclude=True) # TODO not used react_mode: RoleReactMode = RoleReactMode.REACT # see `Role._set_react_mode` for definitions of the following two attributes @@ -101,7 +101,25 @@ class RoleContext(BaseModel): class Config: arbitrary_types_allowed = True - + + def __init__(self, **kwargs): + watch_info = kwargs.get("watch", set()) + watch = set() + for item in watch_info: + action = Action.deser_class(item) + watch.update([action]) + kwargs["watch"] = watch + super(RoleContext, self).__init__(**kwargs) + + def dict(self, *args, **kwargs) -> "DictStrAny": + obj_dict = super(RoleContext, self).dict(*args, **kwargs) + watch = obj_dict.get("watch", set()) + watch_info = [] + for item in watch: + watch_info.append(item.ser_class()) + obj_dict["watch"] = watch_info + return obj_dict + def check(self, role_id: str): if hasattr(CONFIG, "long_term_memory") and CONFIG.long_term_memory: self.long_term_memory.recover_memory(role_id, self) @@ -130,7 +148,6 @@ class Role(BaseModel): is_human: bool = False _llm: BaseGPTAPI = Field(default_factory=LLM) - _setting: RoleSetting = Field(default_factory=RoleSetting, alias=True) _role_id: str = "" _states: list[str] = Field(default=[]) _actions: list[Action] = Field(default=[]) @@ -168,18 +185,12 @@ class Role(BaseModel): # 关于私有变量的初始化 https://github.com/pydantic/pydantic/issues/655 self._private_attributes["_llm"] = LLM() if not self.is_human else HumanProvider() - self._private_attributes["_setting"] = RoleSetting(name=self.name, profile=self.profile, goal=self.goal, - desc=self.desc, constraints=self.constraints, - is_human=self.is_human) self._private_attributes["_role_id"] = str(self._setting) for key in self._private_attributes.keys(): if key in kwargs: object.__setattr__(self, key, kwargs[key]) - if key == "_setting": - setting = RoleSetting(**kwargs[key]) - object.__setattr__(self, "_setting", setting) - elif key == "_rc": + if key == "_rc": _rc = RoleContext(**kwargs["_rc"]) object.__setattr__(self, "_rc", _rc) else: @@ -203,41 +214,23 @@ class Role(BaseModel): object.__setattr__(self, "_actions", []) # object.__setattr__(self, "_rc", RoleContext()) + @property + def _setting(self): + return f"{self.name}({self.profile})" + def serialize(self, stg_path: Path = None): stg_path = SERDESER_PATH.joinpath(f"team/environment/roles/{self.__class__.__name__}_{self.name}") \ if stg_path is None else stg_path - role_info_path = stg_path.joinpath("role_info.json") - role_info = { + + role_info = self.dict(exclude={"_rc": {"memory": True}, "_llm": True}) + role_info.update({ "role_class": self.__class__.__name__, "module_name": self.__module__ - } - setting = self._setting.dict() - setting.pop("desc") - setting.pop("is_human") # not all inherited roles have this atrr - role_info.update(setting) + }) + role_info_path = stg_path.joinpath("role_info.json") write_json_file(role_info_path, role_info) - actions_info_path = stg_path.joinpath("actions/actions_info.json") - actions_info = [] - for action in self._actions: - actions_info.append(action.serialize()) - write_json_file(actions_info_path, actions_info) - - watches_info_path = stg_path.joinpath("watches/watches_info.json") - watches_info = [] - for watch in self._rc.watch: - watches_info.append(watch.ser_class()) - write_json_file(watches_info_path, watches_info) - - actions_todo_path = stg_path.joinpath("actions/todo.json") - actions_todo = { - "cur_state": self._rc.state, - "react_mode": self._rc.react_mode.value, - "max_react_loop": self._rc.max_react_loop - } - write_json_file(actions_todo_path, actions_todo) - - self._rc.memory.serialize(stg_path) + self._rc.memory.serialize(stg_path) # serialize role's memory alone @classmethod def deserialize(cls, stg_path: Path) -> "Role": @@ -250,35 +243,7 @@ class Role(BaseModel): role_class = import_class(class_name=role_class_str, module_name=module_name) role = role_class(**role_info) # initiate particular Role - actions_info_path = stg_path.joinpath("actions/actions_info.json") - actions = [] - actions_info = read_json_file(actions_info_path) - for action_info in actions_info: - action = Action.deser_class(action_info) - actions.append(action) - - watches_info_path = stg_path.joinpath("watches/watches_info.json") - watches = [] - watches_info = read_json_file(watches_info_path) - for watch_info in watches_info: - action = Action.deser_class(watch_info) - watches.append(action) - - role.init_actions(actions) - role.watch(watches) - - actions_todo_path = stg_path.joinpath("actions/todo.json") - # recover self._rc.state - actions_todo = read_json_file(actions_todo_path) - max_react_loop = actions_todo.get("max_react_loop", 1) - cur_state = actions_todo.get("cur_state", -1) - role.set_state(cur_state) - role.set_recovered(True) - react_mode_str = actions_todo.get("react_mode", RoleReactMode.REACT.value) - if react_mode_str not in RoleReactMode.values(): - logger.warning(f"ReactMode: {react_mode_str} not in {RoleReactMode.values()}, use react as default") - react_mode_str = RoleReactMode.REACT.value - role.set_react_mode(RoleReactMode(react_mode_str), max_react_loop) + role.set_recovered(True) # set True to make a tag role_memory = Memory.deserialize(stg_path) role.set_memory(role_memory) @@ -299,9 +264,9 @@ class Role(BaseModel): for idx, action in enumerate(actions): if not isinstance(action, Action): ## 默认初始化 - i = action(name="", llm=self._llm) + i = action(llm=self._llm) else: - if self._setting.is_human and not isinstance(action.llm, HumanProvider): + if self.is_human and not isinstance(action.llm, HumanProvider): logger.warning(f"is_human attribute does not take effect," f"as Role's {str(action)} was initialized using LLM, try passing in Action classes instead of initialized instances") i = action @@ -357,9 +322,14 @@ class Role(BaseModel): def _get_prefix(self): """Get the role prefix""" - if self._setting.desc: - return self._setting.desc - return PREFIX_TEMPLATE.format(**self._setting.dict()) + if self.desc: + return self.desc + return PREFIX_TEMPLATE.format(**{ + "profile": self.profile, + "name": self.name, + "goal": self.goal, + "constraints": self.constraints + }) async def _think(self) -> None: """Think about what to do and decide on the next action""" diff --git a/metagpt/schema.py b/metagpt/schema.py index 60aa819b0..3a5bea7e9 100644 --- a/metagpt/schema.py +++ b/metagpt/schema.py @@ -48,23 +48,9 @@ class Message(BaseModel): kwargs["cause_by"] = action_class.deser_class(cause_by) super(Message, self).__init__(**kwargs) - def dict(self, - *, - include: Optional[Union["AbstractSetIntStr", "MappingIntStrAny"]] = None, - exclude: Optional[Union["AbstractSetIntStr", "MappingIntStrAny"]] = None, - by_alias: bool = False, - skip_defaults: Optional[bool] = None, - exclude_unset: bool = False, - exclude_defaults: bool = False, - exclude_none: bool = False) -> "DictStrAny": + def dict(self, *args, **kwargs) -> "DictStrAny": """ overwrite the `dict` to dump dynamic pydantic model""" - obj_dict = super(Message, self).dict(include=include, - exclude=exclude, - by_alias=by_alias, - skip_defaults=skip_defaults, - exclude_unset=exclude_unset, - exclude_defaults=exclude_defaults, - exclude_none=exclude_none) + obj_dict = super(Message, self).dict(*args, **kwargs) ic = self.instruct_content # deal custom-defined action if ic: schema = ic.schema() @@ -77,19 +63,6 @@ class Message(BaseModel): obj_dict["cause_by"] = cb.ser_class() return obj_dict -# -# -# @dataclass -# class Message: -# """list[: ]""" -# content: str -# instruct_content: BaseModel = field(default=None) -# role: str = field(default='user') # system / user / assistant -# cause_by: Type["Action"] = field(default="") -# sent_from: str = field(default="") -# send_to: str = field(default="") -# restricted_to: str = field(default="") - def __str__(self): # prefix = '-'.join([self.role, str(self.cause_by)]) return f"{self.role}: {self.content}" @@ -97,17 +70,6 @@ class Message(BaseModel): def __repr__(self): return self.__str__() - # def dict(self): - # return { - # "content": self.content, - # "instruct_content": self.instruct_content, - # "role": self.role, - # "cause_by": self.cause_by, - # "sent_from": self.sent_from, - # "send_to": self.send_to, - # "restricted_to": self.restricted_to - # } - def to_dict(self) -> dict: return { "role": self.role, From 0e8eda683e991f8ea7f80ccb09da2fc9a208a265 Mon Sep 17 00:00:00 2001 From: better629 Date: Fri, 1 Dec 2023 14:45:06 +0800 Subject: [PATCH 11/14] update ut after simplification --- tests/metagpt/serialize_deserialize/test_action.py | 14 +------------- tests/metagpt/serialize_deserialize/test_role.py | 3 --- .../serialize_deserialize/test_serdeser_base.py | 6 +++--- tests/metagpt/serialize_deserialize/test_team.py | 2 +- .../serialize_deserialize/test_wrire_prd.py | 2 +- .../serialize_deserialize/test_write_code.py | 4 ++-- .../serialize_deserialize/test_write_design.py | 4 ++-- 7 files changed, 10 insertions(+), 25 deletions(-) diff --git a/tests/metagpt/serialize_deserialize/test_action.py b/tests/metagpt/serialize_deserialize/test_action.py index 0138d41ce..16369bb61 100644 --- a/tests/metagpt/serialize_deserialize/test_action.py +++ b/tests/metagpt/serialize_deserialize/test_action.py @@ -13,7 +13,7 @@ def test_action_serialize(): action = Action() ser_action_dict = action.dict() assert "name" in ser_action_dict - assert "llm" not in ser_action_dict + # assert "llm" not in ser_action_dict # not export @pytest.mark.asyncio @@ -34,15 +34,3 @@ def test_action_serdeser(): action_class = Action.deser_class(action_info) assert action_class == WriteTest - - -def test_action_class_serdeser(): - name = "write test" - action_info = WriteTest(name=name).serialize() - assert action_info["name"] == name - - action_info = WriteTest(name=name, llm=LLM()).serialize() - assert action_info["name"] == name - - action = Action.deserialize(action_info) - assert action.name == name diff --git a/tests/metagpt/serialize_deserialize/test_role.py b/tests/metagpt/serialize_deserialize/test_role.py index c21b9cc2e..61684ba9d 100644 --- a/tests/metagpt/serialize_deserialize/test_role.py +++ b/tests/metagpt/serialize_deserialize/test_role.py @@ -66,7 +66,6 @@ def test_role_serdeser_save(): role_tag = f"{pm.__class__.__name__}_{pm.name}" stg_path = stg_path_prefix.joinpath(role_tag) pm.serialize(stg_path) - assert stg_path.joinpath("actions/actions_info.json").exists() new_pm = Role.deserialize(stg_path) assert new_pm.name == pm.name @@ -89,8 +88,6 @@ async def test_role_serdeser_interrupt(): assert role_c._rc.memory.count() == 2 - assert stg_path.joinpath("actions/todo.json").exists() - new_role_a: Role = Role.deserialize(stg_path) assert new_role_a._rc.state == 1 diff --git a/tests/metagpt/serialize_deserialize/test_serdeser_base.py b/tests/metagpt/serialize_deserialize/test_serdeser_base.py index 00d894b3d..74f9fea87 100644 --- a/tests/metagpt/serialize_deserialize/test_serdeser_base.py +++ b/tests/metagpt/serialize_deserialize/test_serdeser_base.py @@ -21,7 +21,7 @@ class MockMessage(BaseModel): class ActionPass(Action): - name: str = "ActionPass" + name: str = Field(default="ActionPass") async def run(self, messages: list["Message"]) -> ActionOutput: output_mapping = { @@ -34,14 +34,14 @@ class ActionPass(Action): class ActionOK(Action): - name: str = "ActionOK" + name: str = Field(default="ActionOK") async def run(self, messages: list["Message"]) -> str: return "ok" class ActionRaise(Action): - name: str = "ActionRaise" + name: str = Field(default="ActionRaise") async def run(self, messages: list["Message"]) -> str: raise RuntimeError("parse error in ActionRaise") diff --git a/tests/metagpt/serialize_deserialize/test_team.py b/tests/metagpt/serialize_deserialize/test_team.py index e5ec20f2e..28728e1b5 100644 --- a/tests/metagpt/serialize_deserialize/test_team.py +++ b/tests/metagpt/serialize_deserialize/test_team.py @@ -70,7 +70,7 @@ async def test_team_recover(): new_role_c = new_company.environment.get_role(role_c.profile) assert new_role_c._rc.memory == role_c._rc.memory - assert new_role_c._rc.env != role_c._rc.env # due to Action raise, role's memory has been changed. + assert new_role_c._rc.env == role_c._rc.env # TODO check again assert new_role_c._rc.env.memory == role_c._rc.env.memory assert new_company.environment.memory.count() == 1 diff --git a/tests/metagpt/serialize_deserialize/test_wrire_prd.py b/tests/metagpt/serialize_deserialize/test_wrire_prd.py index 05a86cb7f..0b9dfa9d8 100644 --- a/tests/metagpt/serialize_deserialize/test_wrire_prd.py +++ b/tests/metagpt/serialize_deserialize/test_wrire_prd.py @@ -13,7 +13,7 @@ def test_action_serialize(): action = WritePRD() ser_action_dict = action.dict() assert "name" in ser_action_dict - assert "llm" in ser_action_dict + # assert "llm" in ser_action_dict # not export @pytest.mark.asyncio diff --git a/tests/metagpt/serialize_deserialize/test_write_code.py b/tests/metagpt/serialize_deserialize/test_write_code.py index 4e3b712c0..5552ffd7f 100644 --- a/tests/metagpt/serialize_deserialize/test_write_code.py +++ b/tests/metagpt/serialize_deserialize/test_write_code.py @@ -12,14 +12,14 @@ def test_write_design_serialize(): action = WriteCode() ser_action_dict = action.dict() assert ser_action_dict["name"] == "WriteCode" - assert "llm" in ser_action_dict + # assert "llm" in ser_action_dict # not export def test_write_task_serialize(): action = WriteCodeReview() ser_action_dict = action.dict() assert ser_action_dict["name"] == "WriteCodeReview" - assert "llm" in ser_action_dict + # assert "llm" in ser_action_dict # not export @pytest.mark.asyncio diff --git a/tests/metagpt/serialize_deserialize/test_write_design.py b/tests/metagpt/serialize_deserialize/test_write_design.py index 5b2a30ed3..080896c98 100644 --- a/tests/metagpt/serialize_deserialize/test_write_design.py +++ b/tests/metagpt/serialize_deserialize/test_write_design.py @@ -12,14 +12,14 @@ def test_write_design_serialize(): action = WriteDesign() ser_action_dict = action.dict() assert "name" in ser_action_dict - assert "llm" in ser_action_dict + # assert "llm" in ser_action_dict # not export def test_write_task_serialize(): action = WriteTasks() ser_action_dict = action.dict() assert "name" in ser_action_dict - assert "llm" in ser_action_dict + # assert "llm" in ser_action_dict # not export @pytest.mark.asyncio From c7a5bea2b157d2fca2641369a14a415fd935f83f Mon Sep 17 00:00:00 2001 From: better629 Date: Fri, 1 Dec 2023 15:30:28 +0800 Subject: [PATCH 12/14] update --- tests/metagpt/serialize_deserialize/test_team.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/tests/metagpt/serialize_deserialize/test_team.py b/tests/metagpt/serialize_deserialize/test_team.py index 28728e1b5..9c4eb8170 100644 --- a/tests/metagpt/serialize_deserialize/test_team.py +++ b/tests/metagpt/serialize_deserialize/test_team.py @@ -70,7 +70,7 @@ async def test_team_recover(): new_role_c = new_company.environment.get_role(role_c.profile) assert new_role_c._rc.memory == role_c._rc.memory - assert new_role_c._rc.env == role_c._rc.env # TODO check again + assert new_role_c._rc.env == role_c._rc.env assert new_role_c._rc.env.memory == role_c._rc.env.memory assert new_company.environment.memory.count() == 1 @@ -95,7 +95,10 @@ async def test_team_recover_save(): new_company = Team.recover(stg_path) new_role_c = new_company.environment.get_role(role_c.profile) assert new_role_c._rc.memory == role_c._rc.memory - assert new_role_c._rc.env != role_c._rc.env # due to Action raise, role's memory has been changed. + assert new_role_c._rc.env != role_c._rc.env + assert new_role_c.recovered != role_c.recovered # here cause previous ut is `!=` + assert new_role_c._rc.todo != role_c._rc.todo # serialize exclude `_rc.todo` + assert new_role_c._rc.news != role_c._rc.news # serialize exclude `_rc.news` assert new_role_c._rc.env.memory == role_c._rc.env.memory new_company.start_project(idea) From bcba1393b4e1e3445031cd4779fcb441f1fad8d7 Mon Sep 17 00:00:00 2001 From: better629 Date: Fri, 1 Dec 2023 20:35:48 +0800 Subject: [PATCH 13/14] update asyncio.sleep to make it async --- .../test_serdeser_base.py | 18 +++++------------- 1 file changed, 5 insertions(+), 13 deletions(-) diff --git a/tests/metagpt/serialize_deserialize/test_serdeser_base.py b/tests/metagpt/serialize_deserialize/test_serdeser_base.py index 74f9fea87..298c13823 100644 --- a/tests/metagpt/serialize_deserialize/test_serdeser_base.py +++ b/tests/metagpt/serialize_deserialize/test_serdeser_base.py @@ -4,6 +4,7 @@ from pydantic import BaseModel, Field from pathlib import Path +import asyncio from metagpt.actions.action import Action from metagpt.roles.role import Role, RoleReactMode @@ -24,6 +25,7 @@ class ActionPass(Action): name: str = Field(default="ActionPass") async def run(self, messages: list["Message"]) -> ActionOutput: + await asyncio.sleep(5) # sleep to make other roles can watch the executed Message output_mapping = { "result": (str, ...) } @@ -37,6 +39,7 @@ class ActionOK(Action): name: str = Field(default="ActionOK") async def run(self, messages: list["Message"]) -> str: + await asyncio.sleep(5) return "ok" @@ -55,14 +58,10 @@ class RoleA(Role): constraints: str = "RoleA's constraints" def __init__(self, **kwargs): - # super(RoleA, self).__init__(**kwargs) - super().__init__(**kwargs) + super(RoleA, self).__init__(**kwargs) self._init_actions([ActionPass]) self._watch([BossRequirement]) - async def run(self, message: "Message" = None): - await super(RoleA, self).run(message) - class RoleB(Role): name: str = Field(default="RoleB") @@ -71,15 +70,11 @@ class RoleB(Role): constraints: str = "RoleB's constraints" def __init__(self, **kwargs): - # super(RoleB, self).__init__(**kwargs) - super().__init__(**kwargs) + super(RoleB, self).__init__(**kwargs) self._init_actions([ActionOK, ActionRaise]) self._watch([ActionPass]) self._rc.react_mode = RoleReactMode.BY_ORDER - async def run(self, message: "Message" = None): - await super(RoleB, self).run(message) - class RoleC(Role): name: str = Field(default="RoleC") @@ -92,6 +87,3 @@ class RoleC(Role): self._init_actions([ActionOK, ActionRaise]) self._watch([BossRequirement]) self._rc.react_mode = RoleReactMode.BY_ORDER - - async def run(self, message: "Message" = None): - await super(RoleC, self).run(message) From cb81561b69749596b16cdee7e6e3ed4128cd6685 Mon Sep 17 00:00:00 2001 From: better629 Date: Fri, 1 Dec 2023 21:07:47 +0800 Subject: [PATCH 14/14] fix when RoleReactMode=REACT --- metagpt/roles/role.py | 4 ++-- metagpt/utils/utils.py | 1 + 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/metagpt/roles/role.py b/metagpt/roles/role.py index 4e669772e..5b998bf9a 100644 --- a/metagpt/roles/role.py +++ b/metagpt/roles/role.py @@ -337,9 +337,9 @@ class Role(BaseModel): # If there is only one action, then only this one can be performed self._set_state(0) return - if self._recovered and self._rc.state >= 0: + if self.recovered and self._rc.state >= 0: self._set_state(self._rc.state) # action to run from recovered state - self._recovered = False # avoid max_react_loop out of work + self.recovered = False # avoid max_react_loop out of work return prompt = self._get_prefix() diff --git a/metagpt/utils/utils.py b/metagpt/utils/utils.py index b72dabf7e..c1416c352 100644 --- a/metagpt/utils/utils.py +++ b/metagpt/utils/utils.py @@ -74,6 +74,7 @@ def role_raise_decorator(func): newest_msgs = self._rc.env.memory.get(1) if len(newest_msgs) > 0: self._rc.memory.delete(newest_msgs[0]) + raise Exception(format_trackback_info(limit=None)) # raise again to make it captured outside except Exception as exp: if self._rc.env: newest_msgs = self._rc.env.memory.get(1)