diff --git a/metagpt/actions/action.py b/metagpt/actions/action.py index a3a9c0195..535c25cb9 100644 --- a/metagpt/actions/action.py +++ b/metagpt/actions/action.py @@ -8,34 +8,47 @@ from __future__ import annotations -from abc import ABC -from typing import Optional +from typing import Optional, Any + +from pydantic import BaseModel, Field -from metagpt.actions.action_node import ActionNode from metagpt.llm import LLM +from metagpt.provider.base_gpt_api import BaseGPTAPI from metagpt.schema import CodingContext, CodeSummarizeContext, TestingContext, RunCodeContext +action_subclass_registry = {} -class Action(ABC): - """Action abstract class, requiring all inheritors to provide a series of standard capabilities""" - name: str - llm: LLM - # FIXME: simplify context - context: dict | CodingContext | CodeSummarizeContext | TestingContext | RunCodeContext | str | None - prefix: str - desc: str - node: ActionNode | None +class Action(BaseModel): + name: str = "" + llm: BaseGPTAPI = Field(default_factory=LLM, exclude=True) + context: dict | CodingContext | CodeSummarizeContext | TestingContext | RunCodeContext | str | None = "" + prefix = "" # aask*时会加上prefix,作为system_message + desc = "" # for skill manager + # node: ActionNode = Field(default_factory=ActionNode, exclude=True) - def __init__(self, name: str = "", context=None, llm: LLM = None): - self.name: str = name - if llm is None: - llm = LLM() - self.llm = llm - self.context = context - self.prefix = "" # aask*时会加上prefix,作为system_message - self.desc = "" # for skill manager - self.node = None + # builtin variables + builtin_class_name: str = "" + + class Config: + arbitrary_types_allowed = True + + def __init__(self, **kwargs: Any): + super().__init__(**kwargs) + + # deserialize child classes dynamically for inherited `action` + object.__setattr__(self, "builtin_class_name", self.__class__.__name__) + self.__fields__["builtin_class_name"].default = self.__class__.__name__ + + def __init_subclass__(cls, **kwargs: Any) -> None: + super().__init_subclass__(**kwargs) + action_subclass_registry[cls.__name__] = cls + + def dict(self, *args, **kwargs) -> "DictStrAny": + obj_dict = super(Action, self).dict(*args, **kwargs) + if "llm" in obj_dict: + obj_dict.pop("llm") + return obj_dict def set_prefix(self, prefix): """Set prefix for later usage""" diff --git a/metagpt/actions/debug_error.py b/metagpt/actions/debug_error.py index 39f3bc1bc..839acdc2e 100644 --- a/metagpt/actions/debug_error.py +++ b/metagpt/actions/debug_error.py @@ -10,11 +10,14 @@ """ import re +from pydantic import Field + from metagpt.actions.action import Action from metagpt.config import CONFIG from metagpt.const import TEST_CODES_FILE_REPO, TEST_OUTPUTS_FILE_REPO +from metagpt.llm import LLM, BaseGPTAPI from metagpt.logs import logger -from metagpt.schema import RunCodeResult +from metagpt.schema import RunCodeResult, RunCodeContext from metagpt.utils.common import CodeParser from metagpt.utils.file_repository import FileRepository @@ -47,8 +50,9 @@ Now you should start rewriting the code: class DebugError(Action): - def __init__(self, name="DebugError", context=None, llm=None): - super().__init__(name, context, llm) + name: str = "DebugError" + context: RunCodeContext = Field(default_factory=RunCodeContext) + llm: BaseGPTAPI = Field(default_factory=LLM) async def run(self, *args, **kwargs) -> str: output_doc = await FileRepository.get_file( diff --git a/metagpt/actions/design_api.py b/metagpt/actions/design_api.py index 548725fde..f5e122356 100644 --- a/metagpt/actions/design_api.py +++ b/metagpt/actions/design_api.py @@ -11,6 +11,9 @@ """ import json from pathlib import Path +from typing import Optional + +from pydantic import Field from metagpt.actions import Action, ActionOutput from metagpt.actions.design_api_an import DESIGN_API_NODE @@ -22,16 +25,13 @@ from metagpt.const import ( SYSTEM_DESIGN_FILE_REPO, SYSTEM_DESIGN_PDF_FILE_REPO, ) +from metagpt.llm import LLM from metagpt.logs import logger -from metagpt.schema import Document, Documents +from metagpt.provider.base_gpt_api import BaseGPTAPI +from metagpt.schema import Document, Documents, Message from metagpt.utils.file_repository import FileRepository - -# from metagpt.utils.get_template import get_template from metagpt.utils.mermaid import mermaid_to_file -# from typing import List - - NEW_REQ_TEMPLATE = """ ### Legacy Content {old_design} @@ -42,15 +42,14 @@ NEW_REQ_TEMPLATE = """ class WriteDesign(Action): - def __init__(self, name, context=None, llm=None): - super().__init__(name, context, llm) - self.desc = ( - "Based on the PRD, think about the system design, and design the corresponding APIs, " - "data structures, library tables, processes, and paths. Please provide your design, feedback " - "clearly and in detail." - ) + name: str = "" + context: Optional[str] = None + llm: BaseGPTAPI = Field(default_factory=LLM) + desc: str = "Based on the PRD, think about the system design, and design the corresponding APIs, " \ + "data structures, library tables, processes, and paths. Please provide your design, feedback " \ + "clearly and in detail." - async def run(self, with_messages, schema=CONFIG.prompt_schema): + async def run(self, with_messages: Message, schema: str = CONFIG.prompt_schema): # Use `git diff` to identify which PRD documents have been modified in the `docs/prds` directory. prds_file_repo = CONFIG.git_repo.new_file_repository(PRDS_FILE_REPO) changed_prds = prds_file_repo.changed_files diff --git a/metagpt/actions/fix_bug.py b/metagpt/actions/fix_bug.py index 6bd550d3d..eea40c91a 100644 --- a/metagpt/actions/fix_bug.py +++ b/metagpt/actions/fix_bug.py @@ -9,6 +9,7 @@ from metagpt.actions import Action class FixBug(Action): """Fix bug action without any implementation details""" + name: str = "FixBug" async def run(self, *args, **kwargs): raise NotImplementedError diff --git a/metagpt/actions/prepare_documents.py b/metagpt/actions/prepare_documents.py index 3c0885954..9b5128cbd 100644 --- a/metagpt/actions/prepare_documents.py +++ b/metagpt/actions/prepare_documents.py @@ -9,10 +9,15 @@ """ import shutil from pathlib import Path +from typing import Optional + +from pydantic import Field from metagpt.actions import Action, ActionOutput from metagpt.config import CONFIG from metagpt.const import DOCS_FILE_REPO, REQUIREMENT_FILENAME +from metagpt.llm import LLM +from metagpt.provider.base_gpt_api import BaseGPTAPI from metagpt.schema import Document from metagpt.utils.file_repository import FileRepository from metagpt.utils.git_repository import GitRepository @@ -20,6 +25,9 @@ from metagpt.utils.git_repository import GitRepository class PrepareDocuments(Action): """PrepareDocuments Action: initialize project folder and add new requirements to docs/requirements.txt.""" + name: str = "PrepareDocuments" + context: Optional[str] = None + llm: BaseGPTAPI = Field(default_factory=LLM) def _init_repo(self): """Initialize the Git environment.""" diff --git a/metagpt/actions/project_management.py b/metagpt/actions/project_management.py index fe2c8d537..095881e60 100644 --- a/metagpt/actions/project_management.py +++ b/metagpt/actions/project_management.py @@ -9,7 +9,11 @@ 2. Move the document storage operations related to WritePRD from the save operation of WriteDesign. 3. According to the design in Section 2.2.3.5.4 of RFC 135, add incremental iteration functionality. """ + import json +from typing import Optional + +from pydantic import Field from metagpt.actions import ActionOutput from metagpt.actions.action import Action @@ -21,14 +25,12 @@ from metagpt.const import ( TASK_FILE_REPO, TASK_PDF_FILE_REPO, ) +from metagpt.llm import LLM from metagpt.logs import logger +from metagpt.provider.base_gpt_api import BaseGPTAPI from metagpt.schema import Document, Documents from metagpt.utils.file_repository import FileRepository -# from typing import List - -# from metagpt.utils.get_template import get_template - NEW_REQ_TEMPLATE = """ ### Legacy Content {old_tasks} @@ -39,8 +41,9 @@ NEW_REQ_TEMPLATE = """ class WriteTasks(Action): - def __init__(self, name="CreateTasks", context=None, llm=None): - super().__init__(name, context, llm) + name: str = "CreateTasks" + context: Optional[str] = None + llm: BaseGPTAPI = Field(default_factory=LLM) async def run(self, with_messages, schema=CONFIG.prompt_schema): system_design_file_repo = CONFIG.git_repo.new_file_repository(SYSTEM_DESIGN_FILE_REPO) diff --git a/metagpt/actions/run_code.py b/metagpt/actions/run_code.py index 1b9fd252f..ea16c8891 100644 --- a/metagpt/actions/run_code.py +++ b/metagpt/actions/run_code.py @@ -18,10 +18,13 @@ import subprocess from typing import Tuple +from pydantic import Field + from metagpt.actions.action import Action from metagpt.config import CONFIG +from metagpt.llm import LLM, BaseGPTAPI from metagpt.logs import logger -from metagpt.schema import RunCodeResult +from metagpt.schema import RunCodeResult, RunCodeContext from metagpt.utils.exceptions import handle_exception PROMPT_TEMPLATE = """ @@ -74,8 +77,9 @@ standard errors: class RunCode(Action): - def __init__(self, name="RunCode", context=None, llm=None): - super().__init__(name, context, llm) + name: str = "RunCode" + context: RunCodeContext = Field(default_factory=RunCodeContext) + llm: BaseGPTAPI = Field(default_factory=LLM) @classmethod @handle_exception diff --git a/metagpt/actions/search_and_summarize.py b/metagpt/actions/search_and_summarize.py index a1d81bc65..3f110c370 100644 --- a/metagpt/actions/search_and_summarize.py +++ b/metagpt/actions/search_and_summarize.py @@ -6,12 +6,17 @@ @File : search_google.py """ import pydantic +from typing import Optional, Any +from pydantic import BaseModel, Field from metagpt.actions import Action -from metagpt.config import Config +from metagpt.llm import LLM +from metagpt.provider.base_gpt_api import BaseGPTAPI +from metagpt.config import Config, CONFIG from metagpt.logs import logger from metagpt.schema import Message from metagpt.tools.search_engine import SearchEngine +from pydantic import root_validator SEARCH_AND_SUMMARIZE_SYSTEM = """### Requirements 1. Please summarize the latest dialogue based on the reference information (secondary) and dialogue history (primary). Do not include text that is irrelevant to the conversation. @@ -54,7 +59,6 @@ SEARCH_AND_SUMMARIZE_PROMPT = """ """ - SEARCH_AND_SUMMARIZE_SALES_SYSTEM = """## Requirements 1. Please summarize the latest dialogue based on the reference information (secondary) and dialogue history (primary). Do not include text that is irrelevant to the conversation. - The context is for reference only. If it is irrelevant to the user's search request history, please reduce its reference and usage. @@ -101,23 +105,37 @@ You are a member of a professional butler team and will provide helpful suggesti class SearchAndSummarize(Action): - def __init__(self, name="", context=None, llm=None, engine=None, search_func=None): - self.config = Config() - self.engine = engine or self.config.search_engine + name: str = "" + content: Optional[str] = None + llm: BaseGPTAPI = Field(default_factory=LLM) + config: None = Field(default_factory=Config) + engine: Optional[str] = CONFIG.search_engine + search_func: Optional[str] = None + search_engine: SearchEngine = None + result = "" + + @root_validator + def validate_engine_and_run_func(cls, values): + engine = values.get("engine") + search_func = values.get("search_func") + config = Config() + + if engine is None: + engine = config.search_engine try: - self.search_engine = SearchEngine(self.engine, run_func=search_func) + search_engine = SearchEngine(engine=engine, run_func=search_func) except pydantic.ValidationError: - self.search_engine = None + search_engine = None - self.result = "" - super().__init__(name, context, llm) + values["search_engine"] = search_engine + return values async def run(self, context: list[Message], system_text=SEARCH_AND_SUMMARIZE_SYSTEM) -> str: if self.search_engine is None: logger.warning("Configure one of SERPAPI_API_KEY, SERPER_API_KEY, GOOGLE_API_KEY to unlock full feature") return "" - + query = context[-1].content # logger.debug(query) rsp = await self.search_engine.run(query) @@ -126,9 +144,9 @@ class SearchAndSummarize(Action): logger.error("empty rsp...") return "" # logger.info(rsp) - + system_prompt = [system_text] - + prompt = SEARCH_AND_SUMMARIZE_PROMPT.format( ROLE=self.prefix, CONTEXT=rsp, diff --git a/metagpt/actions/summarize_code.py b/metagpt/actions/summarize_code.py index f8d8d2b47..0aec15937 100644 --- a/metagpt/actions/summarize_code.py +++ b/metagpt/actions/summarize_code.py @@ -7,12 +7,15 @@ """ from pathlib import Path +from pydantic import Field from tenacity import retry, stop_after_attempt, wait_random_exponential from metagpt.actions.action import Action from metagpt.config import CONFIG from metagpt.const import SYSTEM_DESIGN_FILE_REPO, TASK_FILE_REPO +from metagpt.llm import LLM, BaseGPTAPI from metagpt.logs import logger +from metagpt.schema import CodeSummarizeContext from metagpt.utils.file_repository import FileRepository PROMPT_TEMPLATE = """ @@ -89,8 +92,9 @@ flowchart TB class SummarizeCode(Action): - def __init__(self, name="SummarizeCode", context=None, llm=None): - super().__init__(name, context, llm) + name: str = "SummarizeCode" + context: CodeSummarizeContext = Field(default_factory=CodeSummarizeContext) + llm: BaseGPTAPI = Field(default_factory=LLM) @retry(stop=stop_after_attempt(2), wait=wait_random_exponential(min=1, max=60)) async def summarize_code(self, prompt): diff --git a/metagpt/actions/write_code.py b/metagpt/actions/write_code.py index 5960e2621..4d0690e0f 100644 --- a/metagpt/actions/write_code.py +++ b/metagpt/actions/write_code.py @@ -14,8 +14,10 @@ 3. Encapsulate the input of RunCode into RunCodeContext and encapsulate the output of RunCode into RunCodeResult to standardize and unify parameter passing between WriteCode, RunCode, and DebugError. """ + import json +from pydantic import Field from tenacity import retry, stop_after_attempt, wait_random_exponential from metagpt.actions.action import Action @@ -27,7 +29,9 @@ from metagpt.const import ( TASK_FILE_REPO, TEST_OUTPUTS_FILE_REPO, ) +from metagpt.llm import LLM from metagpt.logs import logger +from metagpt.provider.base_gpt_api import BaseGPTAPI from metagpt.schema import CodingContext, Document, RunCodeResult from metagpt.utils.common import CodeParser from metagpt.utils.file_repository import FileRepository @@ -84,8 +88,9 @@ ATTENTION: Use '##' to SPLIT SECTIONS, not '#'. Output format carefully referenc class WriteCode(Action): - def __init__(self, name="WriteCode", context=None, llm=None): - super().__init__(name, context, llm) + name: str = "WriteCode" + context: Document = Field(default_factory=Document) + llm: BaseGPTAPI = Field(default_factory=LLM) @retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6)) async def write_code(self, prompt) -> str: @@ -126,7 +131,9 @@ class WriteCode(Action): logger.info(f"Writing {coding_context.filename}..") code = await self.write_code(prompt) if not coding_context.code_doc: - coding_context.code_doc = Document(filename=coding_context.filename, root_path=CONFIG.src_workspace) + # avoid root_path pydantic ValidationError if use WriteCode alone + root_path = CONFIG.src_workspace if CONFIG.src_workspace else "" + coding_context.code_doc = Document(filename=coding_context.filename, root_path=root_path) coding_context.code_doc.content = code return coding_context diff --git a/metagpt/actions/write_code_review.py b/metagpt/actions/write_code_review.py index 365c87063..580069b74 100644 --- a/metagpt/actions/write_code_review.py +++ b/metagpt/actions/write_code_review.py @@ -8,12 +8,15 @@ WriteCode object, rather than passing them in when calling the run function. """ +from pydantic import Field from tenacity import retry, stop_after_attempt, wait_random_exponential from metagpt.actions import WriteCode from metagpt.actions.action import Action from metagpt.config import CONFIG +from metagpt.llm import LLM from metagpt.logs import logger +from metagpt.provider.base_gpt_api import BaseGPTAPI from metagpt.schema import CodingContext from metagpt.utils.common import CodeParser @@ -32,7 +35,6 @@ ATTENTION: Use '##' to SPLIT SECTIONS, not '#'. Output format carefully referenc ``` """ - EXAMPLE_AND_INSTRUCTION = """ {format_example} @@ -119,8 +121,9 @@ REWRITE_CODE_TEMPLATE = """ class WriteCodeReview(Action): - def __init__(self, name="WriteCodeReview", context=None, llm=None): - super().__init__(name, context, llm) + name: str = "WriteCodeReview" + context: CodingContext = Field(default_factory=CodingContext) + llm: BaseGPTAPI = Field(default_factory=LLM) @retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6)) async def write_code_review_and_rewrite(self, context_prompt, cr_prompt, filename): @@ -139,9 +142,15 @@ class WriteCodeReview(Action): iterative_code = self.context.code_doc.content k = CONFIG.code_review_k_times or 1 for i in range(k): - format_example = FORMAT_EXAMPLE.format(filename=self.context.code_doc.filename) - task_content = self.context.task_doc.content if self.context.task_doc else "" - code_context = await WriteCode.get_codes(self.context.task_doc, exclude=self.context.filename) + format_example = FORMAT_EXAMPLE.format( + filename=self.context.code_doc.filename + ) + task_content = ( + self.context.task_doc.content if self.context.task_doc else "" + ) + code_context = await WriteCode.get_codes( + self.context.task_doc, exclude=self.context.filename + ) context = "\n".join( [ "## System Design\n" + str(self.context.design_doc) + "\n", @@ -158,7 +167,8 @@ class WriteCodeReview(Action): format_example=format_example, ) logger.info( - f"Code review and rewrite {self.context.code_doc.filename}: {i+1}/{k} | {len(iterative_code)=}, {len(self.context.code_doc.content)=}" + f"Code review and rewrite {self.context.code_doc.filename}: {i + 1}/{k} | {len(iterative_code)=}, " + f"{len(self.context.code_doc.content)=}" ) result, rewrited_code = await self.write_code_review_and_rewrite( context_prompt, cr_prompt, self.context.code_doc.filename diff --git a/metagpt/actions/write_prd.py b/metagpt/actions/write_prd.py index 7c160fa89..df66e6442 100644 --- a/metagpt/actions/write_prd.py +++ b/metagpt/actions/write_prd.py @@ -10,10 +10,14 @@ 3. Move the document storage operations related to WritePRD from the save operation of WriteDesign. @Modified By: mashenquan, 2023/12/5. Move the generation logic of the project name to WritePRD. """ + from __future__ import annotations import json from pathlib import Path +from typing import Optional + +from pydantic import Field from metagpt.actions import Action, ActionOutput from metagpt.actions.action_node import ActionNode @@ -32,17 +36,14 @@ from metagpt.const import ( PRDS_FILE_REPO, REQUIREMENT_FILENAME, ) +from metagpt.llm import LLM from metagpt.logs import logger +from metagpt.provider.base_gpt_api import BaseGPTAPI from metagpt.schema import BugFixContext, Document, Documents, Message from metagpt.utils.common import CodeParser from metagpt.utils.file_repository import FileRepository - -# from metagpt.utils.get_template import get_template from metagpt.utils.mermaid import mermaid_to_file -# from typing import List - - CONTEXT_TEMPLATE = """ ### Project Name {project_name} @@ -64,15 +65,16 @@ NEW_REQ_TEMPLATE = """ class WritePRD(Action): - def __init__(self, name="", context=None, llm=None): - super().__init__(name, context, llm) + name: str = "" + content: Optional[str] = None + llm: BaseGPTAPI = Field(default_factory=LLM) async def run(self, with_messages, schema=CONFIG.prompt_schema, *args, **kwargs) -> ActionOutput | Message: # Determine which requirement documents need to be rewritten: Use LLM to assess whether new requirements are # related to the PRD. If they are related, rewrite the PRD. docs_file_repo = CONFIG.git_repo.new_file_repository(relative_path=DOCS_FILE_REPO) requirement_doc = await docs_file_repo.get(filename=REQUIREMENT_FILENAME) - if await self._is_bugfix(requirement_doc.content): + if requirement_doc and await self._is_bugfix(requirement_doc.content): await docs_file_repo.save(filename=BUGFIX_FILENAME, content=requirement_doc.content) await docs_file_repo.save(filename=REQUIREMENT_FILENAME, content="") bug_fix = BugFixContext(filename=BUGFIX_FILENAME) @@ -141,7 +143,8 @@ class WritePRD(Action): async def _update_prd(self, requirement_doc, prd_doc, prds_file_repo, *args, **kwargs) -> Document | None: if not prd_doc: - prd = await self._run_new_requirement(requirements=[requirement_doc.content], *args, **kwargs) + prd = await self._run_new_requirement(requirements=[requirement_doc.content if requirement_doc else ""], + *args, **kwargs) new_prd_doc = Document( root_path=PRDS_FILE_REPO, filename=FileRepository.new_filename() + ".json", diff --git a/metagpt/actions/write_prd_review.py b/metagpt/actions/write_prd_review.py index 5ff9624c5..6ed73b6a2 100644 --- a/metagpt/actions/write_prd_review.py +++ b/metagpt/actions/write_prd_review.py @@ -5,20 +5,28 @@ @Author : alexanderwu @File : write_prd_review.py """ + +from typing import Optional + +from pydantic import Field + from metagpt.actions.action import Action +from metagpt.llm import LLM +from metagpt.provider.base_gpt_api import BaseGPTAPI class WritePRDReview(Action): - def __init__(self, name, context=None, llm=None): - super().__init__(name, context, llm) - self.prd = None - self.desc = "Based on the PRD, conduct a PRD Review, providing clear and detailed feedback" - self.prd_review_prompt_template = """ - Given the following Product Requirement Document (PRD): - {prd} + name: str = "" + context: Optional[str] = None + llm: BaseGPTAPI = Field(default_factory=LLM) + prd: Optional[str] = None + desc: str = "Based on the PRD, conduct a PRD Review, providing clear and detailed feedback" + prd_review_prompt_template: str = """ +Given the following Product Requirement Document (PRD): +{prd} - As a project manager, please review it and provide your feedback and suggestions. - """ +As a project manager, please review it and provide your feedback and suggestions. +""" async def run(self, prd): self.prd = prd diff --git a/metagpt/actions/write_test.py b/metagpt/actions/write_test.py index 9dd967788..fa3931ba6 100644 --- a/metagpt/actions/write_test.py +++ b/metagpt/actions/write_test.py @@ -7,6 +7,12 @@ @Modified By: mashenquan, 2023-11-27. Following the think-act principle, solidify the task parameters when creating the WriteTest object, rather than passing them in when calling the run function. """ + +from typing import Optional +from pydantic import Field + +from metagpt.llm import LLM +from metagpt.provider.base_gpt_api import BaseGPTAPI from metagpt.actions.action import Action from metagpt.config import CONFIG from metagpt.const import TEST_CODES_FILE_REPO @@ -36,8 +42,9 @@ you should correctly import the necessary classes based on these file locations! class WriteTest(Action): - def __init__(self, name="WriteTest", context=None, llm=None): - super().__init__(name, context, llm) + name: str = "WriteTest" + context: Optional[str] = None + llm: BaseGPTAPI = Field(default_factory=LLM) async def write_code(self, prompt): code_rsp = await self._aask(prompt) diff --git a/metagpt/const.py b/metagpt/const.py index 10de0ff66..3b4f2ae4b 100644 --- a/metagpt/const.py +++ b/metagpt/const.py @@ -55,11 +55,14 @@ DATA_PATH = METAGPT_ROOT / "data" RESEARCH_PATH = DATA_PATH / "research" TUTORIAL_PATH = DATA_PATH / "tutorial_docx" INVOICE_OCR_TABLE_PATH = DATA_PATH / "invoice_table" + UT_PATH = DATA_PATH / "ut" SWAGGER_PATH = UT_PATH / "files/api/" UT_PY_PATH = UT_PATH / "files/ut/" API_QUESTIONS_PATH = UT_PATH / "files/question/" +SERDESER_PATH = DEFAULT_WORKSPACE_ROOT / "storage" # TODO to store `storage` under the individual generated project + TMP = METAGPT_ROOT / "tmp" SOURCE_ROOT = METAGPT_ROOT / "metagpt" diff --git a/metagpt/environment.py b/metagpt/environment.py index 89b6f9d46..ab296557f 100644 --- a/metagpt/environment.py +++ b/metagpt/environment.py @@ -12,29 +12,84 @@ functionality is to be consolidated into the `Environment` class. """ import asyncio +from pathlib import Path from typing import Iterable, Set from pydantic import BaseModel, Field from metagpt.logs import logger -from metagpt.roles import Role +from metagpt.roles.role import Role, role_subclass_registry from metagpt.schema import Message -from metagpt.utils.common import is_subscribed +from metagpt.utils.common import is_subscribed, read_json_file, write_json_file class Environment(BaseModel): """环境,承载一批角色,角色可以向环境发布消息,可以被其他角色观察到 Environment, hosting a batch of roles, roles can publish messages to the environment, and can be observed by other roles - """ roles: dict[str, Role] = Field(default_factory=dict) members: dict[Role, Set] = Field(default_factory=dict) - history: str = Field(default="") # For debug + history: str = "" # For debug class Config: arbitrary_types_allowed = True + def __init__(self, **kwargs): + roles = [] + for role_key, role in kwargs.get("roles", {}).items(): + current_role = kwargs["roles"][role_key] + if isinstance(current_role, dict): + item_class_name = current_role.get("builtin_class_name", None) + for name, subclass in role_subclass_registry.items(): + registery_class_name = subclass.__fields__["builtin_class_name"].default + if item_class_name == registery_class_name: + current_role = subclass(**current_role) + break + kwargs["roles"][role_key] = current_role + roles.append(current_role) + super().__init__(**kwargs) + + self.add_roles(roles) # add_roles again to init the Role.set_env + + def serialize(self, stg_path: Path): + roles_path = stg_path.joinpath("roles.json") + roles_info = [] + for role_key, role in self.roles.items(): + roles_info.append({ + "role_class": role.__class__.__name__, + "module_name": role.__module__, + "role_name": role.name, + "role_sub_tags": list(self.members.get(role)) + }) + role.serialize(stg_path=stg_path.joinpath(f"roles/{role.__class__.__name__}_{role.name}")) + write_json_file(roles_path, roles_info) + + history_path = stg_path.joinpath("history.json") + write_json_file(history_path, {"content": self.history}) + + @classmethod + def deserialize(cls, stg_path: Path) -> "Environment": + """ stg_path: ./storage/team/environment/ """ + roles_path = stg_path.joinpath("roles.json") + roles_info = read_json_file(roles_path) + roles = [] + for role_info in roles_info: + # role stored in ./environment/roles/{role_class}_{role_name} + role_path = stg_path.joinpath(f"roles/{role_info.get('role_class')}_{role_info.get('role_name')}") + role = Role.deserialize(role_path) + roles.append(role) + + history = read_json_file(stg_path.joinpath("history.json")) + history = history.get("content") + + environment = Environment(**{ + "history": history + }) + environment.add_roles(roles) + + return environment + def add_role(self, role: Role): """增加一个在当前环境的角色 Add a role in the current environment diff --git a/metagpt/memory/longterm_memory.py b/metagpt/memory/longterm_memory.py index ab2214261..76a8deabb 100644 --- a/metagpt/memory/longterm_memory.py +++ b/metagpt/memory/longterm_memory.py @@ -4,6 +4,12 @@ @Desc : the implement of Long-term memory """ +from typing import Optional +from pydantic import Field + +from typing import Optional +from pydantic import Field + from metagpt.logs import logger from metagpt.memory import Memory from metagpt.memory.memory_storage import MemoryStorage @@ -16,12 +22,12 @@ class LongTermMemory(Memory): - recover memory when it staruped - update memory when it changed """ + memory_storage: MemoryStorage = Field(default_factory=MemoryStorage) + rc: Optional["RoleContext"] = None + msg_from_recover: bool = False - def __init__(self): - self.memory_storage: MemoryStorage = MemoryStorage() - super().__init__() - self.rc = None # RoleContext - self.msg_from_recover = False + class Config: + arbitrary_types_allowed = True def recover_memory(self, role_id: str, rc: "RoleContext"): messages = self.memory_storage.recover_memory(role_id) diff --git a/metagpt/memory/memory.py b/metagpt/memory/memory.py index 53b65fcf7..076db832a 100644 --- a/metagpt/memory/memory.py +++ b/metagpt/memory/memory.py @@ -6,20 +6,46 @@ @File : memory.py @Modified By: mashenquan, 2023-11-1. According to RFC 116: Updated the type of index key. """ +import copy from collections import defaultdict +from pathlib import Path from typing import Iterable, Set +from pydantic import BaseModel, Field + from metagpt.schema import Message -from metagpt.utils.common import any_to_str, any_to_str_set +from metagpt.utils.common import any_to_str, any_to_str_set, read_json_file, write_json_file -class Memory: +class Memory(BaseModel): """The most basic memory: super-memory""" + storage: list[Message] = [] + index: dict[str, list[Message]] = Field(default_factory=defaultdict(list)) - def __init__(self): - """Initialize an empty storage list and an empty index dictionary""" - self.storage: list[Message] = [] - self.index: dict[str, list[Message]] = defaultdict(list) + def __init__(self, **kwargs): + index = kwargs.get("index", {}) + new_index = defaultdict(list) + for action_str, value in index.items(): + new_index[action_str] = [Message(**item_dict) for item_dict in value] + kwargs["index"] = new_index + super(Memory, self).__init__(**kwargs) + self.index = new_index + + def serialize(self, stg_path: Path): + """ stg_path = ./storage/team/environment/ or ./storage/team/environment/roles/{role_class}_{role_name}/ """ + memory_path = stg_path.joinpath("memory.json") + storage = self.dict() + write_json_file(memory_path, storage) + + @classmethod + def deserialize(cls, stg_path: Path) -> "Memory": + """ stg_path = ./storage/team/environment/ or ./storage/team/environment/roles/{role_class}_{role_name}/""" + memory_path = stg_path.joinpath("memory.json") + + memory_dict = read_json_file(memory_path) + memory = Memory(**memory_dict) + + return memory def add(self, message: Message): """Add a new message to storage, while updating the index""" @@ -41,6 +67,16 @@ class Memory: """Return all messages containing a specified content""" return [message for message in self.storage if content in message.content] + def delete_newest(self) -> "Message": + """ delete the newest message from the storage""" + if len(self.storage) > 0: + newest_msg = self.storage.pop() + if newest_msg.cause_by and newest_msg in self.index[newest_msg.cause_by]: + self.index[newest_msg.cause_by].remove(newest_msg) + else: + newest_msg = None + return newest_msg + def delete(self, message: Message): """Delete the specified message from storage, while updating the index""" self.storage.remove(message) diff --git a/metagpt/provider/postprecess/base_postprecess_plugin.py b/metagpt/provider/postprecess/base_postprecess_plugin.py index 721476507..afcef2531 100644 --- a/metagpt/provider/postprecess/base_postprecess_plugin.py +++ b/metagpt/provider/postprecess/base_postprecess_plugin.py @@ -44,7 +44,7 @@ class BasePostPrecessPlugin(object): def run_retry_parse_json_text(self, content: str) -> Union[dict, list]: """inherited class can re-implement the function""" - logger.debug(f"extracted json CONTENT from output:\n{content}") + # logger.info(f"extracted json CONTENT from output:\n{content}") parsed_data = retry_parse_json_text(output=content) # should use output=content return parsed_data diff --git a/metagpt/roles/architect.py b/metagpt/roles/architect.py index fce6c3425..bd6cd110b 100644 --- a/metagpt/roles/architect.py +++ b/metagpt/roles/architect.py @@ -5,10 +5,11 @@ @Author : alexanderwu @File : architect.py """ +from pydantic import Field from metagpt.actions import WritePRD from metagpt.actions.design_api import WriteDesign -from metagpt.roles import Role +from metagpt.roles.role import Role class Architect(Role): @@ -21,18 +22,14 @@ class Architect(Role): goal (str): Primary goal or responsibility of the architect. constraints (str): Constraints or guidelines for the architect. """ + name: str = "Bob" + profile: str = "Architect" + goal: str = "design a concise, usable, complete software system" + constraints: str = "make sure the architecture is simple enough and use appropriate open source " \ + "libraries. Use same language as user requirement" - def __init__( - self, - name: str = "Bob", - profile: str = "Architect", - goal: str = "design a concise, usable, complete software system", - constraints: str = "make sure the architecture is simple enough and use appropriate open source libraries." - "Use same language as user requirement", - ) -> None: - """Initializes the Architect with given attributes.""" - super().__init__(name, profile, goal, constraints) - + def __init__(self, **kwargs) -> None: + super().__init__(**kwargs) # Initialize actions specific to the Architect role self._init_actions([WriteDesign]) diff --git a/metagpt/roles/customer_service.py b/metagpt/roles/customer_service.py index 188182d47..b2033ac0b 100644 --- a/metagpt/roles/customer_service.py +++ b/metagpt/roles/customer_service.py @@ -5,6 +5,9 @@ @Author : alexanderwu @File : sales.py """ +from typing import Optional +from pydantic import Field + from metagpt.roles import Sales # from metagpt.actions import SearchAndSummarize @@ -24,5 +27,14 @@ DESC = """ class CustomerService(Sales): - def __init__(self, name="Xiaomei", profile="Human customer service", desc=DESC, store=None): - super().__init__(name, profile, desc=desc, store=store) + + name: str = "Xiaomei" + profile: str = "Human customer service" + desc: str = DESC + + store: Optional[str] = None + + def __init__( + self, + **kwargs): + super().__init__(**kwargs) diff --git a/metagpt/roles/engineer.py b/metagpt/roles/engineer.py index 2620fe4e5..337184068 100644 --- a/metagpt/roles/engineer.py +++ b/metagpt/roles/engineer.py @@ -16,6 +16,7 @@ @Modified By: mashenquan, 2023-12-5. Enhance the workflow to navigate to WriteCode or QaEngineer based on the results of SummarizeCode. """ + from __future__ import annotations import json @@ -23,6 +24,8 @@ from collections import defaultdict from pathlib import Path from typing import Set +from pydantic import Field + from metagpt.actions import Action, WriteCode, WriteCodeReview, WriteTasks from metagpt.actions.fix_bug import FixBug from metagpt.actions.summarize_code import SummarizeCode @@ -66,25 +69,21 @@ class Engineer(Role): n_borg (int): Number of borgs. use_code_review (bool): Whether to use code review. """ + name: str = "Alex" + profile: str = "Engineer" + goal: str = "write elegant, readable, extensible, efficient code" + constraints: str = "the code should conform to standards like google-style and be modular and maintainable. " \ + "Use same language as user requirement" + n_borg: int = 1 + use_code_review: bool = False + code_todos: list = [] + summarize_todos = [] + + def __init__(self, **kwargs) -> None: + super().__init__(**kwargs) - def __init__( - self, - name: str = "Alex", - profile: str = "Engineer", - goal: str = "write elegant, readable, extensible, efficient code", - constraints: str = "the code should conform to standards like google-style and be modular and maintainable. " - "Use same language as user requirement", - n_borg: int = 1, - use_code_review: bool = False, - ) -> None: - """Initializes the Engineer role with given attributes.""" - super().__init__(name, profile, goal, constraints) - self.use_code_review = use_code_review self._init_actions([WriteCode]) self._watch([WriteTasks, SummarizeCode, WriteCode, WriteCodeReview, FixBug]) - self.code_todos = [] - self.summarize_todos = [] - self.n_borg = n_borg @staticmethod def _parse_tasks(task_msg: Document) -> list[str]: @@ -213,7 +212,7 @@ class Engineer(Role): @staticmethod async def _new_coding_context( - filename, src_file_repo, task_file_repo, design_file_repo, dependency + filename, src_file_repo, task_file_repo, design_file_repo, dependency ) -> CodingContext: old_code_doc = await src_file_repo.get(filename) if not old_code_doc: diff --git a/metagpt/roles/product_manager.py b/metagpt/roles/product_manager.py index 61263cb50..6369688a5 100644 --- a/metagpt/roles/product_manager.py +++ b/metagpt/roles/product_manager.py @@ -7,10 +7,12 @@ @Modified By: mashenquan, 2023/11/27. Add `PrepareDocuments` action according to Section 2.2.3.5.1 of RFC 135. """ +from pydantic import Field + from metagpt.actions import UserRequirement, WritePRD from metagpt.actions.prepare_documents import PrepareDocuments from metagpt.config import CONFIG -from metagpt.roles import Role +from metagpt.roles.role import Role class ProductManager(Role): @@ -23,24 +25,13 @@ class ProductManager(Role): goal (str): Goal of the product manager. constraints (str): Constraints or limitations for the product manager. """ + name: str = "Alice" + profile: str = "Product Manager" + goal: str = "efficiently create a successful product that meets market demands and user expectations" + constraints: str = "utilize the same language as the user requirements for seamless communication" - def __init__( - self, - name: str = "Alice", - profile: str = "Product Manager", - goal: str = "efficiently create a successful product that meets market demands and user expectations", - constraints: str = "utilize the same language as the user requirements for seamless communication", - ) -> None: - """ - Initializes the ProductManager role with given attributes. - - Args: - name (str): Name of the product manager. - profile (str): Role profile. - goal (str): Goal of the product manager. - constraints (str): Constraints or limitations for the product manager. - """ - super().__init__(name, profile, goal, constraints) + def __init__(self, **kwargs) -> None: + super().__init__(**kwargs) self._init_actions([PrepareDocuments, WritePRD]) self._watch([UserRequirement, PrepareDocuments]) diff --git a/metagpt/roles/project_manager.py b/metagpt/roles/project_manager.py index 657737513..bf572d1f8 100644 --- a/metagpt/roles/project_manager.py +++ b/metagpt/roles/project_manager.py @@ -5,9 +5,11 @@ @Author : alexanderwu @File : project_manager.py """ +from pydantic import Field + from metagpt.actions import WriteTasks from metagpt.actions.design_api import WriteDesign -from metagpt.roles import Role +from metagpt.roles.role import Role class ProjectManager(Role): @@ -20,24 +22,14 @@ class ProjectManager(Role): goal (str): Goal of the project manager. constraints (str): Constraints or limitations for the project manager. """ + name: str = "Eve" + profile: str = "Project Manager" + goal: str = "break down tasks according to PRD/technical design, generate a task list, and analyze task " \ + "dependencies to start with the prerequisite modules" + constraints: str = "use same language as user requirement" - def __init__( - self, - name: str = "Eve", - profile: str = "Project Manager", - goal: str = "break down tasks according to PRD/technical design, generate a task list, and analyze task " - "dependencies to start with the prerequisite modules", - constraints: str = "use same language as user requirement", - ) -> None: - """ - Initializes the ProjectManager role with given attributes. + def __init__(self, **kwargs) -> None: + super().__init__(**kwargs) - Args: - name (str): Name of the project manager. - profile (str): Role profile. - goal (str): Goal of the project manager. - constraints (str): Constraints or limitations for the project manager. - """ - super().__init__(name, profile, goal, constraints) self._init_actions([WriteTasks]) self._watch([WriteDesign]) diff --git a/metagpt/roles/qa_engineer.py b/metagpt/roles/qa_engineer.py index 71b474a3b..369e3dc63 100644 --- a/metagpt/roles/qa_engineer.py +++ b/metagpt/roles/qa_engineer.py @@ -14,7 +14,14 @@ @Modified By: mashenquan, 2023-12-5. Enhance the workflow to navigate to WriteCode or QaEngineer based on the results of SummarizeCode. """ -from metagpt.actions import DebugError, RunCode, WriteTest + +from pydantic import Field + +from metagpt.actions import ( + DebugError, + RunCode, + WriteTest, +) from metagpt.actions.summarize_code import SummarizeCode from metagpt.config import CONFIG from metagpt.const import ( @@ -30,21 +37,21 @@ from metagpt.utils.file_repository import FileRepository class QaEngineer(Role): - def __init__( - self, - name="Edward", - profile="QaEngineer", - goal="Write comprehensive and robust tests to ensure codes will work as expected without bugs", - constraints="The test code you write should conform to code standard like PEP8, be modular, easy to read and maintain", - test_round_allowed=5, - ): - super().__init__(name, profile, goal, constraints) - self._init_actions( - [WriteTest] - ) # FIXME: a bit hack here, only init one action to circumvent _think() logic, will overwrite _think() in future updates + name: str = "Edward" + profile: str = "QaEngineer" + goal: str = "Write comprehensive and robust tests to ensure codes will work as expected without bugs" + constraints: str = "The test code you write should conform to code standard like PEP8, be modular, " \ + "easy to read and maintain" + test_round_allowed: int = 5 + + def __init__(self, **kwargs): + super().__init__(**kwargs) + + # FIXME: a bit hack here, only init one action to circumvent _think() logic, + # will overwrite _think() in future updates + self._init_actions([WriteTest]) self._watch([SummarizeCode, WriteTest, RunCode, DebugError]) self.test_round = 0 - self.test_round_allowed = test_round_allowed async def _write_test(self, message: Message) -> None: src_file_repo = CONFIG.git_repo.new_file_repository(CONFIG.src_workspace) diff --git a/metagpt/roles/role.py b/metagpt/roles/role.py index bf37a6637..f87c4e250 100644 --- a/metagpt/roles/role.py +++ b/metagpt/roles/role.py @@ -18,20 +18,26 @@ @Modified By: mashenquan, 2023-11-4. According to the routing feature plan in Chapter 2.2.3.2 of RFC 113, the routing functionality is to be consolidated into the `Environment` class. """ + from __future__ import annotations from enum import Enum -from typing import Iterable, Set, Type +from pathlib import Path +from typing import Iterable, Set, Type, Any from pydantic import BaseModel, Field -from metagpt.actions import Action, ActionOutput, UserRequirement +from metagpt.actions import Action, ActionOutput +from metagpt.actions.action import action_subclass_registry from metagpt.actions.action_node import ActionNode +from metagpt.actions.add_requirement import UserRequirement +from metagpt.const import SERDESER_PATH from metagpt.llm import LLM, HumanProvider from metagpt.logs import logger from metagpt.memory import Memory +from metagpt.provider.base_gpt_api import BaseGPTAPI from metagpt.schema import Message, MessageQueue -from metagpt.utils.common import any_to_str +from metagpt.utils.common import any_to_str, read_json_file, write_json_file, import_class, role_raise_decorator from metagpt.utils.repair_llm_raw_output import extract_state_value_from_output PREFIX_TEMPLATE = """You are a {profile}, named {name}, your goal is {goal}, and the constraint is {constraints}. """ @@ -74,37 +80,20 @@ class RoleReactMode(str, Enum): return [item.value for item in cls] -class RoleSetting(BaseModel): - """Role Settings""" - - name: str - profile: str - goal: str - constraints: str - desc: str - is_human: bool - - def __str__(self): - return f"{self.name}({self.profile})" - - def __repr__(self): - return self.__str__() - - class RoleContext(BaseModel): """Role Runtime Context""" - - env: "Environment" = Field(default=None) - msg_buffer: MessageQueue = Field(default_factory=MessageQueue) # Message Buffer with Asynchronous Updates + # # env exclude=True to avoid `RecursionError: maximum recursion depth exceeded in comparison` + env: "Environment" = Field(default=None, exclude=True) + # TODO judge if ser&deser + msg_buffer: MessageQueue = Field(default_factory=MessageQueue, + exclude=True) # Message Buffer with Asynchronous Updates memory: Memory = Field(default_factory=Memory) # long_term_memory: LongTermMemory = Field(default_factory=LongTermMemory) state: int = Field(default=-1) # -1 indicates initial or termination state where todo is None - todo: Action = Field(default=None) + todo: Action = Field(default=None, exclude=True) watch: set[str] = Field(default_factory=set) - news: list[Type[Message]] = Field(default=[]) - react_mode: RoleReactMode = ( - RoleReactMode.REACT - ) # see `Role._set_react_mode` for definitions of the following two attributes + news: list[Type[Message]] = Field(default=[], exclude=True) # TODO not used + react_mode: RoleReactMode = RoleReactMode.REACT # see `Role._set_react_mode` for definitions of the following two attributes max_react_loop: int = 1 class Config: @@ -126,33 +115,146 @@ class RoleContext(BaseModel): return self.memory.get() -class Role: - """Role/Agent""" +role_subclass_registry = {} + + +class Role(BaseModel): + """Role/Agent""" + name: str = "" + profile: str = "" + goal: str = "" + constraints: str = "" + desc: str = "" + is_human: bool = False + + _llm: BaseGPTAPI = Field(default_factory=LLM) + _role_id: str = "" + _states: list[str] = [] + _actions: list[Action] = [] + _rc: RoleContext = Field(default_factory=RoleContext) + _subscription: tuple[str] = set() + + # builtin variables + recovered: bool = False # to tag if a recovered role + latest_observed_msg: Message = None # record the latest observed message when interrupted + builtin_class_name: str = "" + + _private_attributes = { + "_llm": LLM() if not is_human else HumanProvider(), + "_role_id": _role_id, + "_states": [], + "_actions": [], + "_rc": RoleContext(), + "_subscription": set() + } + + __hash__ = object.__hash__ # support Role as hashable type in `Environment.members` + + class Config: + arbitrary_types_allowed = True + exclude = ["_llm"] + + def __init__(self, **kwargs: Any): + for index in range(len(kwargs.get("_actions", []))): + current_action = kwargs["_actions"][index] + if isinstance(current_action, dict): + item_class_name = current_action.get("builtin_class_name", None) + for name, subclass in action_subclass_registry.items(): + registery_class_name = subclass.__fields__["builtin_class_name"].default + if item_class_name == registery_class_name: + current_action = subclass(**current_action) + break + kwargs["_actions"][index] = current_action + + super().__init__(**kwargs) + + # 关于私有变量的初始化 https://github.com/pydantic/pydantic/issues/655 + self._private_attributes["_llm"] = LLM() if not self.is_human else HumanProvider() + self._private_attributes["_role_id"] = str(self._setting) + self._private_attributes["_subscription"] = {any_to_str(self), self.name} if self.name else {any_to_str(self)} + + for key in self._private_attributes.keys(): + if key in kwargs: + object.__setattr__(self, key, kwargs[key]) + if key == "_rc": + _rc = RoleContext(**kwargs["_rc"]) + object.__setattr__(self, "_rc", _rc) + else: + if key == "_rc": + # # Warning, if use self._private_attributes["_rc"], + # # self._rc will be a shared object between roles, so init one or reset it inside `_reset` + object.__setattr__(self, key, RoleContext()) + else: + object.__setattr__(self, key, self._private_attributes[key]) - def __init__(self, name="", profile="", goal="", constraints="", desc="", is_human=False): - self._llm = LLM() if not is_human else HumanProvider() - self._setting = RoleSetting( - name=name, profile=profile, goal=goal, constraints=constraints, desc=desc, is_human=is_human - ) self._llm.system_prompt = self._get_prefix() - self._states = [] - self._actions = [] - self._role_id = str(self._setting) - self._rc = RoleContext(watch={any_to_str(UserRequirement)}) - self._subscription = {any_to_str(self), name} if name else {any_to_str(self)} + + # deserialize child classes dynamically for inherited `role` + object.__setattr__(self, "builtin_class_name", self.__class__.__name__) + self.__fields__["builtin_class_name"].default = self.__class__.__name__ + + def __init_subclass__(cls, **kwargs: Any) -> None: + super().__init_subclass__(**kwargs) + role_subclass_registry[cls.__name__] = cls def _reset(self): - self._states = [] - self._actions = [] + object.__setattr__(self, "_states", []) + object.__setattr__(self, "_actions", []) + + @property + def _setting(self): + return f"{self.name}({self.profile})" + + def serialize(self, stg_path: Path = None): + stg_path = SERDESER_PATH.joinpath(f"team/environment/roles/{self.__class__.__name__}_{self.name}") \ + if stg_path is None else stg_path + + role_info = self.dict(exclude={"_rc": {"memory": True, "msg_buffer": True}, "_llm": True}) + role_info.update({ + "role_class": self.__class__.__name__, + "module_name": self.__module__ + }) + role_info_path = stg_path.joinpath("role_info.json") + write_json_file(role_info_path, role_info) + + self._rc.memory.serialize(stg_path) # serialize role's memory alone + + @classmethod + def deserialize(cls, stg_path: Path) -> "Role": + """ stg_path = ./storage/team/environment/roles/{role_class}_{role_name}""" + role_info_path = stg_path.joinpath("role_info.json") + role_info = read_json_file(role_info_path) + + role_class_str = role_info.pop("role_class") + module_name = role_info.pop("module_name") + role_class = import_class(class_name=role_class_str, module_name=module_name) + + role = role_class(**role_info) # initiate particular Role + role.set_recovered(True) # set True to make a tag + + role_memory = Memory.deserialize(stg_path) + role.set_memory(role_memory) + + return role def _init_action_system_message(self, action: Action): action.set_prefix(self._get_prefix()) + def set_recovered(self, recovered: bool = False): + self.recovered = recovered + + def set_memory(self, memory: Memory): + self._rc.memory = memory + + def init_actions(self, actions): + self._init_actions(actions) + def _init_actions(self, actions): self._reset() for idx, action in enumerate(actions): if not isinstance(action, Action): - i = action("", llm=self._llm) + ## 默认初始化 + i = action(name="", llm=self._llm) else: if self._setting.is_human and not isinstance(action.llm, HumanProvider): logger.warning( @@ -207,7 +309,7 @@ class Role: def _set_state(self, state: int): """Update the current state.""" self._rc.state = state - logger.debug(self._actions) + logger.debug(f"actions={self._actions}, state={state}") self._rc.todo = self._actions[self._rc.state] if state >= 0 else None def set_env(self, env: "Environment"): @@ -217,16 +319,6 @@ class Role: if env: env.set_subscription(self, self._subscription) - @property - def profile(self): - """Get the role description (position)""" - return self._setting.profile - - @property - def name(self): - """Get virtual user name""" - return self._setting.name - @property def subscription(self) -> Set: """The labels for messages to be consumed by the Role object.""" @@ -234,9 +326,14 @@ class Role: def _get_prefix(self): """Get the role prefix""" - if self._setting.desc: - return self._setting.desc - return PREFIX_TEMPLATE.format(**self._setting.dict()) + if self.desc: + return self.desc + return PREFIX_TEMPLATE.format(**{ + "profile": self.profile, + "name": self.name, + "goal": self.goal, + "constraints": self.constraints + }) async def _think(self) -> None: """Think about what to do and decide on the next action""" @@ -244,6 +341,11 @@ class Role: # If there is only one action, then only this one can be performed self._set_state(0) return + if self.recovered and self._rc.state >= 0: + self._set_state(self._rc.state) # action to run from recovered state + self.recovered = False # avoid max_react_loop out of work + return + prompt = self._get_prefix() prompt += STATE_TEMPLATE.format( history=self._rc.history, @@ -251,10 +353,11 @@ class Role: n_states=len(self._states) - 1, previous_state=self._rc.state, ) - # print(prompt) + next_state = await self._llm.aask(prompt) next_state = extract_state_value_from_output(next_state) logger.debug(f"{prompt=}") + if (not next_state.isdigit() and next_state != "-1") or int(next_state) not in range(-1, len(self._states)): logger.warning(f"Invalid answer of state, {next_state=}, will be set to -1") next_state = -1 @@ -283,15 +386,30 @@ class Role: return msg + def _find_news(self, observed: list[Message], existed: list[Message]) -> list[Message]: + news = [] + # Warning, remove `id` here to make it work for recover + observed_pure = [msg.dict(exclude={"id": True}) for msg in observed] + existed_pure = [msg.dict(exclude={"id": True}) for msg in existed] + for idx, new in enumerate(observed_pure): + if new["cause_by"] in self._rc.watch and new not in existed_pure: + news.append(observed[idx]) + return news + async def _observe(self, ignore_memory=False) -> int: """Prepare new messages for processing from the message buffer and other sources.""" # Read unprocessed messages from the msg buffer. news = self._rc.msg_buffer.pop_all() + if self.recovered: + news = [self.latest_observed_msg] if self.latest_observed_msg else [] + else: + self.latest_observed_msg = news[-1] if len(news) > 0 else None # record the latest observed msg + # Store the read messages in your own memory to prevent duplicate processing. old_messages = [] if ignore_memory else self._rc.memory.get() self._rc.memory.add_batch(news) # Filter out messages of interest. - self._rc.news = [n for n in news if n.cause_by in self._rc.watch and n not in old_messages] + self._rc.news = self._find_news(news, old_messages) # Design Rules: # If you need to further categorize Message objects, you can do so using the Message.set_meta function. @@ -322,7 +440,7 @@ class Role: Use llm to select actions in _think dynamically """ actions_taken = 0 - rsp = Message("No actions taken yet") # will be overwritten after Role _act + rsp = Message(content="No actions taken yet") # will be overwritten after Role _act while actions_taken < self._rc.max_react_loop: # think await self._think() @@ -336,7 +454,8 @@ class Role: async def _act_by_order(self) -> Message: """switch action each time by order defined in _init_actions, i.e. _act (Action1) -> _act (Action2) -> ...""" - for i in range(len(self._states)): + start_idx = self._rc.state if self._rc.state >= 0 else 0 # action to run from recovered state + for i in range(start_idx, len(self._states)): self._set_state(i) rsp = await self._act() return rsp # return output from the last action @@ -344,7 +463,7 @@ class Role: async def _plan_and_act(self) -> Message: """first plan, then execute an action sequence, i.e. _think (of a plan) -> _act -> _act -> ... Use llm to come up with the plan dynamically.""" # TODO: to be implemented - return Message("") + return Message(content="") async def react(self) -> Message: """Entry to one of three strategies by which Role reacts to the observed Message""" @@ -378,16 +497,17 @@ class Role: """A wrapper to return the most recent k memories of this role, return all when k=0""" return self._rc.memory.get(k=k) + @role_raise_decorator async def run(self, with_message=None): """Observe, and think and act based on the results of the observation""" if with_message: msg = None if isinstance(with_message, str): - msg = Message(with_message) + msg = Message(content=with_message) elif isinstance(with_message, Message): msg = with_message elif isinstance(with_message, list): - msg = Message("\n".join(with_message)) + msg = Message(content="\n".join(with_message)) if not msg.cause_by: msg.cause_by = UserRequirement self.put_message(msg) diff --git a/metagpt/roles/sales.py b/metagpt/roles/sales.py index d5aac1824..fd5a42915 100644 --- a/metagpt/roles/sales.py +++ b/metagpt/roles/sales.py @@ -5,26 +5,31 @@ @Author : alexanderwu @File : sales.py """ + +from typing import Optional +from pydantic import Field + from metagpt.actions import SearchAndSummarize from metagpt.roles import Role from metagpt.tools import SearchEngineType class Sales(Role): - def __init__( - self, - name="Xiaomei", - profile="Retail sales guide", - desc="I am a sales guide in retail. My name is Xiaomei. I will answer some customer questions next, and I " - "will answer questions only based on the information in the knowledge base." - "If I feel that you can't get the answer from the reference material, then I will directly reply that" - " I don't know, and I won't tell you that this is from the knowledge base," - "but pretend to be what I know. Note that each of my replies will be replied in the tone of a " - "professional guide", - store=None, - ): - super().__init__(name, profile, desc=desc) - self._set_store(store) + + name: str = "Xiaomei" + profile: str = "Retail sales guide" + desc: str = "I am a sales guide in retail. My name is Xiaomei. I will answer some customer questions next, and I " + "will answer questions only based on the information in the knowledge base." + "If I feel that you can't get the answer from the reference material, then I will directly reply that" + " I don't know, and I won't tell you that this is from the knowledge base," + "but pretend to be what I know. Note that each of my replies will be replied in the tone of a " + "professional guide" + + store: Optional[str] = None + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self._set_store(self.store) def _set_store(self, store): if store: diff --git a/metagpt/roles/searcher.py b/metagpt/roles/searcher.py index 31de8e896..a5c399f47 100644 --- a/metagpt/roles/searcher.py +++ b/metagpt/roles/searcher.py @@ -7,6 +7,9 @@ @Modified By: mashenquan, 2023-11-1. According to Chapter 2.2.1 and 2.2.2 of RFC 116, change the data type of the `cause_by` value in the `Message` to a string to support the new message distribution feature. """ + +from pydantic import Field + from metagpt.actions import ActionOutput, SearchAndSummarize from metagpt.actions.action_node import ActionNode from metagpt.logs import logger @@ -27,15 +30,13 @@ class Searcher(Role): engine (SearchEngineType): The type of search engine to use. """ - def __init__( - self, - name: str = "Alice", - profile: str = "Smart Assistant", - goal: str = "Provide search services for users", - constraints: str = "Answer is rich and complete", - engine=SearchEngineType.SERPAPI_GOOGLE, - **kwargs, - ) -> None: + name: str = Field(default="Alice") + profile: str = Field(default="Smart Assistant") + goal: str = "Provide search services for users" + constraints: str = "Answer is rich and complete" + engine: SearchEngineType = SearchEngineType.SERPAPI_GOOGLE + + def __init__(self, **kwargs) -> None: """ Initializes the Searcher role with given attributes. @@ -46,8 +47,8 @@ class Searcher(Role): constraints (str): Constraints or limitations for the searcher. engine (SearchEngineType): The type of search engine to use. """ - super().__init__(name, profile, goal, constraints, **kwargs) - self._init_actions([SearchAndSummarize(engine=engine)]) + super().__init__(**kwargs) + self._init_actions([SearchAndSummarize(engine=self.engine)]) def set_search_func(self, search_func): """Sets a custom search function for the searcher.""" diff --git a/metagpt/schema.py b/metagpt/schema.py index d2f8d33e6..5103a4f20 100644 --- a/metagpt/schema.py +++ b/metagpt/schema.py @@ -12,6 +12,7 @@ between actions. 3. Add `id` to `Message` according to Section 2.2.3.1.1 of RFC 135. """ + from __future__ import annotations import asyncio @@ -22,7 +23,7 @@ from abc import ABC from asyncio import Queue, QueueEmpty, wait_for from json import JSONDecodeError from pathlib import Path -from typing import Dict, List, Optional, Set, Type, TypedDict, TypeVar +from typing import Dict, List, Optional, Set, Type, TypedDict, TypeVar, Any from pydantic import BaseModel, Field @@ -36,7 +37,9 @@ from metagpt.const import ( TASK_FILE_REPO, ) from metagpt.logs import logger -from metagpt.utils.common import any_to_str, any_to_str_set +from metagpt.utils.common import any_to_str, any_to_str_set, import_class +from metagpt.utils.serialize import actionoutout_schema_to_mapping, actionoutput_mapping_to_str, \ + actionoutput_str_to_mapping from metagpt.utils.exceptions import handle_exception @@ -98,41 +101,29 @@ class Message(BaseModel): id: str # According to Section 2.2.3.1.1 of RFC 135 content: str - instruct_content: BaseModel = Field(default=None) + instruct_content: BaseModel = None role: str = "user" # system / user / assistant cause_by: str = "" sent_from: str = "" send_to: Set = Field(default_factory={MESSAGE_ROUTE_TO_ALL}) - def __init__( - self, - content, - instruct_content=None, - role="user", - cause_by="", - sent_from="", - send_to=MESSAGE_ROUTE_TO_ALL, - **kwargs, - ): - """ - Parameters not listed below will be stored as meta info, including custom parameters. - :param content: Message content. - :param instruct_content: Message content struct. - :param cause_by: Message producer - :param sent_from: Message route info tells who sent this message. - :param send_to: Specifies the target recipient or consumer for message delivery in the environment. - :param role: Message meta info tells who sent this message. - """ - super().__init__( - id=uuid.uuid4().hex, - content=content, - instruct_content=instruct_content, - role=role, - cause_by=any_to_str(cause_by), - sent_from=any_to_str(sent_from), - send_to=any_to_str_set(send_to), - **kwargs, - ) + def __init__(self, **kwargs): + ic = kwargs.get("instruct_content", None) + if ic and not isinstance(ic, BaseModel) and "class" in ic: + # compatible with custom-defined ActionOutput + mapping = actionoutput_str_to_mapping(ic["mapping"]) + + actionnode_class = import_class("ActionNode", "metagpt.actions.action_node") # avoid circular import + ic_obj = actionnode_class.create_model_class(class_name=ic["class"], mapping=mapping) + ic_new = ic_obj(**ic["value"]) + kwargs["instruct_content"] = ic_new + + kwargs["id"] = kwargs.get("id", uuid.uuid4().hex) + kwargs["cause_by"] = any_to_str(kwargs.get("cause_by", + import_class("UserRequirement", "metagpt.actions.add_requirement"))) + kwargs["sent_from"] = any_to_str(kwargs.get("sent_from", "")) + kwargs["send_to"] = any_to_str_set(kwargs.get("send_to", {MESSAGE_ROUTE_TO_ALL})) + super(Message, self).__init__(**kwargs) def __setattr__(self, key, val): """Override `@property.setter`, convert non-string parameters into string parameters.""" @@ -146,6 +137,22 @@ class Message(BaseModel): new_val = val super().__setattr__(key, new_val) + def dict(self, *args, **kwargs) -> "DictStrAny": + """ overwrite the `dict` to dump dynamic pydantic model""" + obj_dict = super(Message, self).dict(*args, **kwargs) + ic = self.instruct_content + if ic: + # compatible with custom-defined ActionOutput + schema = ic.schema() + # `Documents` contain definitions + if "definitions" not in schema: + # TODO refine with nested BaseModel + mapping = actionoutout_schema_to_mapping(schema) + mapping = actionoutput_mapping_to_str(mapping) + + obj_dict["instruct_content"] = {"class": schema["title"], "mapping": mapping, "value": ic.dict()} + return obj_dict + def __str__(self): # prefix = '-'.join([self.role, str(self.cause_by)]) return f"{self.role}: {self.content}" @@ -196,11 +203,24 @@ class AIMessage(Message): super().__init__(content=content, role="assistant") -class MessageQueue: +class MessageQueue(BaseModel): """Message queue which supports asynchronous updates.""" - def __init__(self): - self._queue = Queue() + _queue: Queue = Field(default_factory=Queue) + + _private_attributes = { + "_queue": Queue() + } + + class Config: + arbitrary_types_allowed = True + + def __init__(self, **kwargs: Any): + for key in self._private_attributes.keys(): + if key in kwargs: + object.__setattr__(self, key, kwargs[key]) + else: + object.__setattr__(self, key, Queue()) def pop(self) -> Message | None: """Pop one message from the queue.""" diff --git a/metagpt/startup.py b/metagpt/startup.py index a1af90ffc..59e0cb199 100644 --- a/metagpt/startup.py +++ b/metagpt/startup.py @@ -3,6 +3,7 @@ import asyncio import typer +from pathlib import Path from metagpt.config import CONFIG @@ -31,6 +32,7 @@ def startup( help="The maximum number of times the 'SummarizeCode' action is automatically invoked, with -1 indicating " "unlimited. This parameter is used for debugging the workflow.", ), + recover_path: str = typer.Option(default=None, help="recover the project from existing serialized storage") ): """Run a startup. Be a boss.""" from metagpt.roles import ( @@ -44,20 +46,29 @@ def startup( CONFIG.update_via_cli(project_path, project_name, inc, reqa_file, max_auto_summarize_code) - company = Team() - company.hire( - [ - ProductManager(), - Architect(), - ProjectManager(), - ] - ) + if not recover_path: + company = Team() + company.hire( + [ + ProductManager(), + Architect(), + ProjectManager(), + ] + ) - if implement or code_review: - company.hire([Engineer(n_borg=5, use_code_review=code_review)]) + if implement or code_review: + company.hire([Engineer(n_borg=5, use_code_review=code_review)]) - if run_tests: - company.hire([QaEngineer()]) + if run_tests: + company.hire([QaEngineer()]) + else: + # # stg_path = SERDESER_PATH.joinpath("team") + stg_path = Path(recover_path) + if not stg_path.exists() or not str(stg_path).endswith("team"): + raise FileNotFoundError(f"{recover_path} not exists or not endswith `team`") + + company = Team.deserialize(stg_path=stg_path) + idea = company.idea # use original idea company.invest(investment) company.run_project(idea) diff --git a/metagpt/team.py b/metagpt/team.py index 570be4e6b..1df3c4052 100644 --- a/metagpt/team.py +++ b/metagpt/team.py @@ -7,17 +7,20 @@ @Modified By: mashenquan, 2023/11/27. Add an archiving operation after completing the project, as specified in Section 2.2.3.3 of RFC 135. """ + +from pathlib import Path import warnings from pydantic import BaseModel, Field from metagpt.actions import UserRequirement from metagpt.config import CONFIG from metagpt.const import MESSAGE_ROUTE_TO_ALL +from metagpt.const import SERDESER_PATH from metagpt.environment import Environment from metagpt.logs import logger from metagpt.roles import Role from metagpt.schema import Message -from metagpt.utils.common import NoMoneyException +from metagpt.utils.common import NoMoneyException, read_json_file, write_json_file, serialize_decorator class Team(BaseModel): @@ -33,6 +36,36 @@ class Team(BaseModel): class Config: arbitrary_types_allowed = True + def serialize(self, stg_path: Path = None): + stg_path = SERDESER_PATH.joinpath("team") if stg_path is None else stg_path + + team_info_path = stg_path.joinpath("team_info.json") + write_json_file(team_info_path, self.dict(exclude={"env": True})) + + self.env.serialize(stg_path.joinpath("environment")) # save environment alone + + @classmethod + def recover(cls, stg_path: Path) -> "Team": + return cls.deserialize(stg_path) + + @classmethod + def deserialize(cls, stg_path: Path) -> "Team": + """ stg_path = ./storage/team """ + # recover team_info + team_info_path = stg_path.joinpath("team_info.json") + if not team_info_path.exists(): + raise FileNotFoundError("recover storage meta file `team_info.json` not exist, " + "not to recover and please start a new project.") + + team_info: dict = read_json_file(team_info_path) + + # recover environment + environment = Environment.deserialize(stg_path=stg_path.joinpath("environment")) + team_info.update({"env": environment}) + + team = Team(**team_info) + return team + def hire(self, roles: list[Role]): """Hire roles to cooperate""" self.env.add_roles(roles) @@ -69,6 +102,7 @@ class Team(BaseModel): def _save(self): logger.info(self.json(ensure_ascii=False)) + @serialize_decorator async def run(self, n_round=3): """Run company until target round or no money""" while n_round > 0: @@ -76,6 +110,7 @@ class Team(BaseModel): n_round -= 1 logger.debug(f"max {n_round=} left.") self._check_balance() + await self.env.run() if CONFIG.git_repo: CONFIG.git_repo.archive() diff --git a/metagpt/utils/common.py b/metagpt/utils/common.py index a290c7db7..e123e8fd9 100644 --- a/metagpt/utils/common.py +++ b/metagpt/utils/common.py @@ -13,15 +13,21 @@ from __future__ import annotations import ast import contextlib +import importlib import inspect +import json import os import platform import re +import traceback import typing +from pathlib import Path +from typing import Any from typing import List, Tuple, Union, get_args, get_origin import aiofiles import loguru +from pydantic.json import pydantic_encoder from tenacity import RetryCallState, _utils from metagpt.const import MESSAGE_ROUTE_TO_ALL @@ -213,7 +219,7 @@ class OutputParser: if start_index != -1 and end_index != -1: # Extract the structure part - structure_text = text[start_index : end_index + 1] + structure_text = text[start_index: end_index + 1] try: # Attempt to convert the text to a Python data type using ast.literal_eval @@ -426,6 +432,77 @@ def general_after_log(i: "loguru.Logger", sec_format: str = "%0.3f") -> typing.C return log_it +def read_json_file(json_file: str, encoding=None) -> list[Any]: + if not Path(json_file).exists(): + raise FileNotFoundError(f"json_file: {json_file} not exist, return []") + + with open(json_file, "r", encoding=encoding) as fin: + try: + data = json.load(fin) + except Exception as exp: + raise ValueError(f"read json file: {json_file} failed") + return data + + +def write_json_file(json_file: str, data: list, encoding=None): + folder_path = Path(json_file).parent + if not folder_path.exists(): + folder_path.mkdir(parents=True, exist_ok=True) + + with open(json_file, "w", encoding=encoding) as fout: + json.dump(data, fout, ensure_ascii=False, indent=4, default=pydantic_encoder) + + +def import_class(class_name: str, module_name: str) -> type: + module = importlib.import_module(module_name) + a_class = getattr(module, class_name) + return a_class + + +def import_class_inst(class_name: str, module_name: str, *args, **kwargs) -> object: + a_class = import_class(class_name, module_name) + class_inst = a_class(*args, **kwargs) + return class_inst + + +def format_trackback_info(limit: int = 2): + return traceback.format_exc(limit=limit) + + +def serialize_decorator(func): + async def wrapper(self, *args, **kwargs): + try: + result = await func(self, *args, **kwargs) + return result + except KeyboardInterrupt as kbi: + logger.error(f"KeyboardInterrupt occurs, start to serialize the project, exp:\n{format_trackback_info()}") + except Exception as exp: + logger.error(f"Exception occurs, start to serialize the project, exp:\n{format_trackback_info()}") + self.serialize() # Team.serialize + + return wrapper + + +def role_raise_decorator(func): + async def wrapper(self, *args, **kwargs): + try: + return await func(self, *args, **kwargs) + except KeyboardInterrupt as kbi: + logger.error(f"KeyboardInterrupt: {kbi} occurs, start to serialize the project") + if self.latest_observed_msg: + self._rc.memory.delete(self.latest_observed_msg) + raise Exception(format_trackback_info(limit=None)) # raise again to make it captured outside + except Exception as exp: + if self.latest_observed_msg: + logger.warning("There is a exception in role's execution, in order to resume, " + "we delete the newest role communication message in the role's memory.") + # remove role newest observed msg to make it observed again + self._rc.memory.delete(self.latest_observed_msg) + raise Exception(format_trackback_info(limit=None)) # raise again to make it captured outside + + return wrapper + + @handle_exception async def aread(file_path: str) -> str: """Read file asynchronously.""" diff --git a/metagpt/utils/repair_llm_raw_output.py b/metagpt/utils/repair_llm_raw_output.py index 4aafd8e66..67ad4e963 100644 --- a/metagpt/utils/repair_llm_raw_output.py +++ b/metagpt/utils/repair_llm_raw_output.py @@ -253,7 +253,7 @@ def retry_parse_json_text(output: str) -> Union[list, dict]: if CONFIG.repair_llm_output is True, the _aask_v1 and the retry_parse_json_text will loop for {x=3*3} times. it's a two-layer retry cycle """ - logger.debug(f"output to json decode:\n{output}") + # logger.debug(f"output to json decode:\n{output}") # if CONFIG.repair_llm_output is True, it will try to fix output until the retry break parsed_data = CustomDecoder(strict=False).decode(output) diff --git a/metagpt/utils/serialize.py b/metagpt/utils/serialize.py index 5e52846e1..3939b1306 100644 --- a/metagpt/utils/serialize.py +++ b/metagpt/utils/serialize.py @@ -4,13 +4,11 @@ import copy import pickle -from typing import Dict, List -from metagpt.actions.action_node import ActionNode -from metagpt.schema import Message +from metagpt.utils.common import import_class -def actionoutout_schema_to_mapping(schema: Dict) -> Dict: +def actionoutout_schema_to_mapping(schema: dict) -> dict: """ directly traverse the `properties` in the first level. schema structure likes @@ -35,14 +33,31 @@ def actionoutout_schema_to_mapping(schema: Dict) -> Dict: if property["type"] == "string": mapping[field] = (str, ...) elif property["type"] == "array" and property["items"]["type"] == "string": - mapping[field] = (List[str], ...) + mapping[field] = (list[str], ...) elif property["type"] == "array" and property["items"]["type"] == "array": - # here only consider the `List[List[str]]` situation - mapping[field] = (List[List[str]], ...) + # here only consider the `list[list[str]]` situation + mapping[field] = (list[list[str]], ...) return mapping -def serialize_message(message: Message): +def actionoutput_mapping_to_str(mapping: dict) -> dict: + new_mapping = {} + for key, value in mapping.items(): + new_mapping[key] = str(value) + return new_mapping + + +def actionoutput_str_to_mapping(mapping: dict) -> dict: + new_mapping = {} + for key, value in mapping.items(): + if value == "(, Ellipsis)": + new_mapping[key] = (str, ...) + else: + new_mapping[key] = eval(value) # `"'(list[str], Ellipsis)"` to `(list[str], ...)` + return new_mapping + + +def serialize_message(message: "Message"): message_cp = copy.deepcopy(message) # avoid `instruct_content` value update by reference ic = message_cp.instruct_content if ic: @@ -56,11 +71,12 @@ def serialize_message(message: Message): return msg_ser -def deserialize_message(message_ser: str) -> Message: +def deserialize_message(message_ser: str) -> "Message": message = pickle.loads(message_ser) if message.instruct_content: ic = message.instruct_content - ic_obj = ActionNode.create_model_class(class_name=ic["class"], mapping=ic["mapping"]) + actionnode_class = import_class("ActionNode", "metagpt.actions.action_node") # avoid circular import + ic_obj = actionnode_class.create_model_class(class_name=ic["class"], mapping=ic["mapping"]) ic_new = ic_obj(**ic["value"]) message.instruct_content = ic_new diff --git a/tests/metagpt/roles/test_role.py b/tests/metagpt/roles/test_role.py new file mode 100644 index 000000000..72cd84a9a --- /dev/null +++ b/tests/metagpt/roles/test_role.py @@ -0,0 +1,11 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : unittest of Role + +from metagpt.roles.role import Role + + +def test_role_desc(): + role = Role(profile="Sales", desc="Best Seller") + assert role.profile == "Sales" + assert role._setting.desc == "Best Seller" diff --git a/tests/metagpt/serialize_deserialize/__init__.py b/tests/metagpt/serialize_deserialize/__init__.py new file mode 100644 index 000000000..78f454fb5 --- /dev/null +++ b/tests/metagpt/serialize_deserialize/__init__.py @@ -0,0 +1,4 @@ +# -*- coding: utf-8 -*- +# @Date : 11/22/2023 11:48 AM +# @Author : stellahong (stellahong@fuzhi.ai) +# @Desc : diff --git a/tests/metagpt/serialize_deserialize/test_action.py b/tests/metagpt/serialize_deserialize/test_action.py new file mode 100644 index 000000000..14d558c13 --- /dev/null +++ b/tests/metagpt/serialize_deserialize/test_action.py @@ -0,0 +1,27 @@ +# -*- coding: utf-8 -*- +# @Date : 11/22/2023 11:48 AM +# @Author : stellahong (stellahong@fuzhi.ai) +# @Desc : +import pytest + +from metagpt.actions import Action +from metagpt.llm import LLM + + +def test_action_serialize(): + action = Action() + ser_action_dict = action.dict() + assert "name" in ser_action_dict + # assert "llm" not in ser_action_dict # not export + + +@pytest.mark.asyncio +async def test_action_deserialize(): + action = Action() + serialized_data = action.dict() + + new_action = Action(**serialized_data) + + assert new_action.name == "" + assert new_action.llm == LLM() + assert len(await new_action._aask("who are you")) > 0 diff --git a/tests/metagpt/serialize_deserialize/test_architect_deserialize.py b/tests/metagpt/serialize_deserialize/test_architect_deserialize.py new file mode 100644 index 000000000..b92eba8a1 --- /dev/null +++ b/tests/metagpt/serialize_deserialize/test_architect_deserialize.py @@ -0,0 +1,28 @@ +# -*- coding: utf-8 -*- +# @Date : 11/26/2023 2:04 PM +# @Author : stellahong (stellahong@fuzhi.ai) +# @Desc : +import pytest + +from metagpt.actions.action import Action +from metagpt.roles.architect import Architect + + +def test_architect_serialize(): + role = Architect() + ser_role_dict = role.dict(by_alias=True) + assert "name" in ser_role_dict + assert "_states" in ser_role_dict + assert "_actions" in ser_role_dict + + +@pytest.mark.asyncio +async def test_architect_deserialize(): + role = Architect() + ser_role_dict = role.dict(by_alias=True) + new_role = Architect(**ser_role_dict) + # new_role = Architect.deserialize(ser_role_dict) + assert new_role.name == "Bob" + assert len(new_role._actions) == 1 + assert isinstance(new_role._actions[0], Action) + await new_role._actions[0].run(with_messages="write a cli snake game") diff --git a/tests/metagpt/serialize_deserialize/test_environment.py b/tests/metagpt/serialize_deserialize/test_environment.py new file mode 100644 index 000000000..b741b9c4b --- /dev/null +++ b/tests/metagpt/serialize_deserialize/test_environment.py @@ -0,0 +1,87 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : + +import shutil + +from metagpt.actions.action_node import ActionNode +from metagpt.actions.add_requirement import UserRequirement +from metagpt.actions.project_management import WriteTasks +from metagpt.environment import Environment +from metagpt.roles.project_manager import ProjectManager +from metagpt.schema import Message +from metagpt.utils.common import any_to_str +from tests.metagpt.serialize_deserialize.test_serdeser_base import RoleC, ActionOK, serdeser_path + + +def test_env_serialize(): + env = Environment() + ser_env_dict = env.dict() + assert "roles" in ser_env_dict + + +def test_env_deserialize(): + env = Environment() + env.publish_message(message=Message(content="test env serialize")) + ser_env_dict = env.dict() + new_env = Environment(**ser_env_dict) + assert len(new_env.roles) == 0 + assert len(new_env.history) == 25 + + +def test_environment_serdeser(): + out_mapping = {"field1": (list[str], ...)} + out_data = {"field1": ["field1 value1", "field1 value2"]} + ic_obj = ActionNode.create_model_class("prd", out_mapping) + + message = Message( + content="prd", + instruct_content=ic_obj(**out_data), + role="product manager", + cause_by=any_to_str(UserRequirement) + ) + + environment = Environment() + role_c = RoleC() + environment.add_role(role_c) + environment.publish_message(message) + + ser_data = environment.dict() + assert ser_data["roles"]["Role C"]["name"] == "RoleC" + + new_env: Environment = Environment(**ser_data) + assert len(new_env.roles) == 1 + + assert list(new_env.roles.values())[0]._states == list(environment.roles.values())[0]._states + assert list(new_env.roles.values())[0]._actions == list(environment.roles.values())[0]._actions + assert isinstance(list(environment.roles.values())[0]._actions[0], ActionOK) + assert type(list(new_env.roles.values())[0]._actions[0]) == ActionOK + + +def test_environment_serdeser_v2(): + environment = Environment() + pm = ProjectManager() + environment.add_role(pm) + + ser_data = environment.dict() + + new_env: Environment = Environment(**ser_data) + role = new_env.get_role(pm.profile) + assert isinstance(role, ProjectManager) + assert isinstance(role._actions[0], WriteTasks) + assert isinstance(list(new_env.roles.values())[0]._actions[0], WriteTasks) + + +def test_environment_serdeser_save(): + environment = Environment() + role_c = RoleC() + + shutil.rmtree(serdeser_path.joinpath("team"), ignore_errors=True) + + stg_path = serdeser_path.joinpath("team", "environment") + environment.add_role(role_c) + environment.serialize(stg_path) + + new_env: Environment = Environment.deserialize(stg_path) + assert len(new_env.roles) == 1 + assert type(list(new_env.roles.values())[0]._actions[0]) == ActionOK diff --git a/tests/metagpt/serialize_deserialize/test_memory.py b/tests/metagpt/serialize_deserialize/test_memory.py new file mode 100644 index 000000000..0d756518b --- /dev/null +++ b/tests/metagpt/serialize_deserialize/test_memory.py @@ -0,0 +1,69 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : unittest of memory + +from pydantic import BaseModel + +from metagpt.actions.action_node import ActionNode +from metagpt.actions.add_requirement import UserRequirement +from metagpt.actions.design_api import WriteDesign +from metagpt.memory.memory import Memory +from metagpt.schema import Message +from metagpt.utils.common import any_to_str +from tests.metagpt.serialize_deserialize.test_serdeser_base import serdeser_path + + +def test_memory_serdeser(): + msg1 = Message(role="Boss", + content="write a snake game", + cause_by=UserRequirement) + + out_mapping = {"field2": (list[str], ...)} + out_data = {"field2": ["field2 value1", "field2 value2"]} + ic_obj = ActionNode.create_model_class("system_design", out_mapping) + msg2 = Message(role="Architect", + instruct_content=ic_obj(**out_data), + content="system design content", + cause_by=WriteDesign) + + memory = Memory() + memory.add_batch([msg1, msg2]) + ser_data = memory.dict() + + new_memory = Memory(**ser_data) + assert new_memory.count() == 2 + new_msg2 = new_memory.get(2)[0] + assert isinstance(new_msg2, BaseModel) + assert isinstance(new_memory.storage[-1], BaseModel) + assert new_memory.storage[-1].cause_by == any_to_str(WriteDesign) + assert new_msg2.role == "Boss" + + +def test_memory_serdeser_save(): + msg1 = Message(role="User", + content="write a 2048 game", + cause_by=UserRequirement) + + out_mapping = {"field1": (list[str], ...)} + out_data = {"field1": ["field1 value1", "field1 value2"]} + ic_obj = ActionNode.create_model_class("system_design", out_mapping) + msg2 = Message(role="Architect", + instruct_content=ic_obj(**out_data), + content="system design content", + cause_by=WriteDesign) + + memory = Memory() + memory.add_batch([msg1, msg2]) + + stg_path = serdeser_path.joinpath("team", "environment") + memory.serialize(stg_path) + assert stg_path.joinpath("memory.json").exists() + + new_memory = Memory.deserialize(stg_path) + assert new_memory.count() == 2 + new_msg2 = new_memory.get(1)[0] + assert new_msg2.instruct_content.field1 == ["field1 value1", "field1 value2"] + assert new_msg2.cause_by == any_to_str(WriteDesign) + assert len(new_memory.index) == 2 + + stg_path.joinpath("memory.json").unlink() diff --git a/tests/metagpt/serialize_deserialize/test_product_manager.py b/tests/metagpt/serialize_deserialize/test_product_manager.py new file mode 100644 index 000000000..b65e329d1 --- /dev/null +++ b/tests/metagpt/serialize_deserialize/test_product_manager.py @@ -0,0 +1,21 @@ +# -*- coding: utf-8 -*- +# @Date : 11/26/2023 2:07 PM +# @Author : stellahong (stellahong@fuzhi.ai) +# @Desc : +import pytest + +from metagpt.actions.action import Action +from metagpt.roles.product_manager import ProductManager +from metagpt.schema import Message + + +@pytest.mark.asyncio +async def test_product_manager_deserialize(): + role = ProductManager() + ser_role_dict = role.dict(by_alias=True) + new_role = ProductManager(**ser_role_dict) + + assert new_role.name == "Alice" + assert len(new_role._actions) == 2 + assert isinstance(new_role._actions[0], Action) + await new_role._actions[0].run([Message(content="write a cli snake game")]) diff --git a/tests/metagpt/serialize_deserialize/test_project_manager.py b/tests/metagpt/serialize_deserialize/test_project_manager.py new file mode 100644 index 000000000..e52e3f247 --- /dev/null +++ b/tests/metagpt/serialize_deserialize/test_project_manager.py @@ -0,0 +1,30 @@ +# -*- coding: utf-8 -*- +# @Date : 11/26/2023 2:06 PM +# @Author : stellahong (stellahong@fuzhi.ai) +# @Desc : +import pytest + +from metagpt.actions.action import Action +from metagpt.actions.project_management import WriteTasks +from metagpt.roles.project_manager import ProjectManager + + +def test_project_manager_serialize(): + role = ProjectManager() + ser_role_dict = role.dict(by_alias=True) + assert "name" in ser_role_dict + assert "_states" in ser_role_dict + assert "_actions" in ser_role_dict + + +@pytest.mark.asyncio +async def test_project_manager_deserialize(): + role = ProjectManager() + ser_role_dict = role.dict(by_alias=True) + + new_role = ProjectManager(**ser_role_dict) + assert new_role.name == "Eve" + assert len(new_role._actions) == 1 + assert isinstance(new_role._actions[0], Action) + assert isinstance(new_role._actions[0], WriteTasks) + # await new_role._actions[0].run(context="write a cli snake game") diff --git a/tests/metagpt/serialize_deserialize/test_role.py b/tests/metagpt/serialize_deserialize/test_role.py new file mode 100644 index 000000000..88c7f7d8b --- /dev/null +++ b/tests/metagpt/serialize_deserialize/test_role.py @@ -0,0 +1,95 @@ +# -*- coding: utf-8 -*- +# @Date : 11/23/2023 4:49 PM +# @Author : stellahong (stellahong@fuzhi.ai) +# @Desc : + +import shutil + +import pytest + +from metagpt.actions import WriteCode +from metagpt.actions.add_requirement import UserRequirement +from metagpt.const import SERDESER_PATH +from metagpt.logs import logger +from metagpt.roles.engineer import Engineer +from metagpt.roles.product_manager import ProductManager +from metagpt.roles.role import Role +from metagpt.schema import Message +from metagpt.utils.common import format_trackback_info +from tests.metagpt.serialize_deserialize.test_serdeser_base import RoleA, RoleB, RoleC, serdeser_path + + +def test_roles(): + role_a = RoleA() + assert len(role_a._rc.watch) == 1 + role_b = RoleB() + assert len(role_a._rc.watch) == 1 + assert len(role_b._rc.watch) == 1 + + +def test_role_serialize(): + role = Role() + ser_role_dict = role.dict(by_alias=True) + assert "name" in ser_role_dict + assert "_states" in ser_role_dict + assert "_actions" in ser_role_dict + + +def test_engineer_serialize(): + role = Engineer() + ser_role_dict = role.dict(by_alias=True) + assert "name" in ser_role_dict + assert "_states" in ser_role_dict + assert "_actions" in ser_role_dict + + +@pytest.mark.asyncio +async def test_engineer_deserialize(): + role = Engineer(use_code_review=True) + ser_role_dict = role.dict(by_alias=True) + + new_role = Engineer(**ser_role_dict) + assert new_role.name == "Alex" + assert new_role.use_code_review is True + assert len(new_role._actions) == 1 + assert isinstance(new_role._actions[0], WriteCode) + # await new_role._actions[0].run(context="write a cli snake game", filename="test_code") + + +def test_role_serdeser_save(): + stg_path_prefix = serdeser_path.joinpath("team", "environment", "roles") + shutil.rmtree(serdeser_path.joinpath("team"), ignore_errors=True) + + pm = ProductManager() + role_tag = f"{pm.__class__.__name__}_{pm.name}" + stg_path = stg_path_prefix.joinpath(role_tag) + pm.serialize(stg_path) + + new_pm = Role.deserialize(stg_path) + assert new_pm.name == pm.name + assert len(new_pm.get_memories(1)) == 0 + + +@pytest.mark.asyncio +async def test_role_serdeser_interrupt(): + role_c = RoleC() + shutil.rmtree(SERDESER_PATH.joinpath("team"), ignore_errors=True) + + stg_path = SERDESER_PATH.joinpath(f"team", "environment", "roles", "{role_c.__class__.__name__}_{role_c.name}") + try: + await role_c.run( + with_message=Message(content="demo", cause_by=UserRequirement) + ) + except Exception as exp: + logger.error(f"Exception in `role_a.run`, detail: {format_trackback_info()}") + role_c.serialize(stg_path) + + assert role_c._rc.memory.count() == 1 + + new_role_a: Role = Role.deserialize(stg_path) + assert new_role_a._rc.state == 1 + + with pytest.raises(Exception): + await role_c.run( + with_message=Message(content="demo", cause_by=UserRequirement) + ) diff --git a/tests/metagpt/serialize_deserialize/test_schema.py b/tests/metagpt/serialize_deserialize/test_schema.py new file mode 100644 index 000000000..72b7153a7 --- /dev/null +++ b/tests/metagpt/serialize_deserialize/test_schema.py @@ -0,0 +1,46 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : unittest of schema ser&deser + +from metagpt.actions.action_node import ActionNode +from metagpt.actions.write_code import WriteCode +from metagpt.schema import Message +from metagpt.utils.common import any_to_str +from tests.metagpt.serialize_deserialize.test_serdeser_base import MockMessage + + +def test_message_serdeser(): + out_mapping = {"field3": (str, ...), "field4": (list[str], ...)} + out_data = {"field3": "field3 value3", "field4": ["field4 value1", "field4 value2"]} + ic_obj = ActionNode.create_model_class("code", out_mapping) + + message = Message( + content="code", + instruct_content=ic_obj(**out_data), + role="engineer", + cause_by=WriteCode + ) + ser_data = message.dict() + assert ser_data["cause_by"] == "metagpt.actions.write_code.WriteCode" + assert ser_data["instruct_content"]["class"] == "code" + + new_message = Message(**ser_data) + assert new_message.cause_by == any_to_str(WriteCode) + assert new_message.cause_by in [any_to_str(WriteCode)] + assert new_message.instruct_content == ic_obj(**out_data) + + +def test_message_without_postprocess(): + """ to explain `instruct_content` should be postprocessed """ + out_mapping = {"field1": (list[str], ...)} + out_data = {"field1": ["field1 value1", "field1 value2"]} + ic_obj = ActionNode.create_model_class("code", out_mapping) + message = MockMessage( + content="code", + instruct_content=ic_obj(**out_data) + ) + ser_data = message.dict() + assert ser_data["instruct_content"] == {"field1": ["field1 value1", "field1 value2"]} + + new_message = MockMessage(**ser_data) + assert new_message.instruct_content != ic_obj(**out_data) diff --git a/tests/metagpt/serialize_deserialize/test_serdeser_base.py b/tests/metagpt/serialize_deserialize/test_serdeser_base.py new file mode 100644 index 000000000..eac083cf9 --- /dev/null +++ b/tests/metagpt/serialize_deserialize/test_serdeser_base.py @@ -0,0 +1,88 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : base test actions / roles used in unittest + +import asyncio +from pathlib import Path + +from pydantic import BaseModel, Field + +from metagpt.actions import Action, ActionOutput +from metagpt.actions.action_node import ActionNode +from metagpt.actions.add_requirement import UserRequirement +from metagpt.roles.role import Role, RoleReactMode + +serdeser_path = Path(__file__).absolute().parent.joinpath("..", "..", "data", "serdeser_storage") + + +class MockMessage(BaseModel): + """ to test normal dict without postprocess """ + content: str = "" + instruct_content: BaseModel = Field(default=None) + + +class ActionPass(Action): + name: str = Field(default="ActionPass") + + async def run(self, messages: list["Message"]) -> ActionOutput: + await asyncio.sleep(5) # sleep to make other roles can watch the executed Message + output_mapping = { + "result": (str, ...) + } + pass_class = ActionNode.create_model_class("pass", output_mapping) + pass_output = ActionOutput("ActionPass run passed", pass_class(**{"result": "pass result"})) + + return pass_output + + +class ActionOK(Action): + name: str = Field(default="ActionOK") + + async def run(self, messages: list["Message"]) -> str: + await asyncio.sleep(5) + return "ok" + + +class ActionRaise(Action): + name: str = Field(default="ActionRaise") + + async def run(self, messages: list["Message"]) -> str: + raise RuntimeError("parse error in ActionRaise") + + +class RoleA(Role): + name: str = Field(default="RoleA") + profile: str = Field(default="Role A") + goal: str = "RoleA's goal" + constraints: str = "RoleA's constraints" + + def __init__(self, **kwargs): + super(RoleA, self).__init__(**kwargs) + self._init_actions([ActionPass]) + self._watch([UserRequirement]) + + +class RoleB(Role): + name: str = Field(default="RoleB") + profile: str = Field(default="Role B") + goal: str = "RoleB's goal" + constraints: str = "RoleB's constraints" + + def __init__(self, **kwargs): + super(RoleB, self).__init__(**kwargs) + self._init_actions([ActionOK, ActionRaise]) + self._watch([ActionPass]) + self._rc.react_mode = RoleReactMode.BY_ORDER + + +class RoleC(Role): + name: str = Field(default="RoleC") + profile: str = Field(default="Role C") + goal: str = "RoleC's goal" + constraints: str = "RoleC's constraints" + + def __init__(self, **kwargs): + super(RoleC, self).__init__(**kwargs) + self._init_actions([ActionOK, ActionRaise]) + self._watch([UserRequirement]) + self._rc.react_mode = RoleReactMode.BY_ORDER diff --git a/tests/metagpt/serialize_deserialize/test_team.py b/tests/metagpt/serialize_deserialize/test_team.py new file mode 100644 index 000000000..db6001325 --- /dev/null +++ b/tests/metagpt/serialize_deserialize/test_team.py @@ -0,0 +1,131 @@ +# -*- coding: utf-8 -*- +# @Date : 11/27/2023 10:07 AM +# @Author : stellahong (stellahong@fuzhi.ai) +# @Desc : + +import shutil + +import pytest + +from metagpt.const import SERDESER_PATH +from metagpt.roles import ProjectManager, ProductManager, Architect +from metagpt.team import Team +from metagpt.logs import logger +from tests.metagpt.serialize_deserialize.test_serdeser_base import RoleA, RoleB, RoleC, serdeser_path, ActionOK + + +def test_team_deserialize(): + company = Team() + + pm = ProductManager() + arch = Architect() + company.hire( + [ + pm, + arch, + ProjectManager(), + ] + ) + assert len(company.env.get_roles()) == 3 + ser_company = company.dict() + new_company = Team(**ser_company) + + assert len(new_company.env.get_roles()) == 3 + assert new_company.env.get_role(pm.profile) is not None + + new_pm = new_company.env.get_role(pm.profile) + assert type(new_pm) == ProductManager + assert new_company.env.get_role(pm.profile) is not None + assert new_company.env.get_role(arch.profile) is not None + + +def test_team_serdeser_save(): + company = Team() + company.hire([RoleC()]) + + stg_path = serdeser_path.joinpath("team") + shutil.rmtree(stg_path, ignore_errors=True) + + company.serialize(stg_path=stg_path) + + new_company = Team.deserialize(stg_path) + + assert len(new_company.env.roles) == 1 + + +@pytest.mark.asyncio +async def test_team_recover(): + idea = "write a snake game" + stg_path = SERDESER_PATH.joinpath("team") + shutil.rmtree(stg_path, ignore_errors=True) + + company = Team() + role_c = RoleC() + company.hire([role_c]) + company.run_project(idea) + await company.run(n_round=4) + + ser_data = company.dict() + new_company = Team(**ser_data) + + new_role_c = new_company.env.get_role(role_c.profile) + # assert new_role_c._rc.memory == role_c._rc.memory # TODO + assert new_role_c._rc.env != role_c._rc.env # TODO + assert type(list(new_company.env.roles.values())[0]._actions[0]) == ActionOK + + new_company.run_project(idea) + await new_company.run(n_round=4) + + +@pytest.mark.asyncio +async def test_team_recover_save(): + idea = "write a 2048 web game" + stg_path = SERDESER_PATH.joinpath("team") + shutil.rmtree(stg_path, ignore_errors=True) + + company = Team() + role_c = RoleC() + company.hire([role_c]) + company.run_project(idea) + await company.run(n_round=4) + + new_company = Team.deserialize(stg_path) + new_role_c = new_company.env.get_role(role_c.profile) + # assert new_role_c._rc.memory == role_c._rc.memory + assert new_role_c._rc.env != role_c._rc.env + assert new_role_c.recovered != role_c.recovered # here cause previous ut is `!=` + assert new_role_c._rc.todo != role_c._rc.todo # serialize exclude `_rc.todo` + assert new_role_c._rc.news != role_c._rc.news # serialize exclude `_rc.news` + + new_company.run_project(idea) + await new_company.run(n_round=4) + + +@pytest.mark.asyncio +async def test_team_recover_multi_roles_save(): + idea = "write a snake game" + stg_path = SERDESER_PATH.joinpath("team") + shutil.rmtree(stg_path, ignore_errors=True) + + role_a = RoleA() + role_b = RoleB() + + assert role_a.subscription == {"tests.metagpt.serialize_deserialize.test_serdeser_base.RoleA", + "RoleA"} + assert role_b.subscription == {"tests.metagpt.serialize_deserialize.test_serdeser_base.RoleB", + "RoleB"} + assert role_b._rc.watch == {"tests.metagpt.serialize_deserialize.test_serdeser_base.ActionPass"} + + company = Team() + company.hire([role_a, role_b]) + company.run_project(idea) + await company.run(n_round=4) + + logger.info("Team recovered") + + new_company = Team.deserialize(stg_path) + new_company.run_project(idea) + + assert new_company.env.get_role(role_b.profile)._rc.state == 1 + + await new_company.run(n_round=4) diff --git a/tests/metagpt/serialize_deserialize/test_write_code.py b/tests/metagpt/serialize_deserialize/test_write_code.py new file mode 100644 index 000000000..0114c48da --- /dev/null +++ b/tests/metagpt/serialize_deserialize/test_write_code.py @@ -0,0 +1,31 @@ +# -*- coding: utf-8 -*- +# @Date : 11/23/2023 10:56 AM +# @Author : stellahong (stellahong@fuzhi.ai) +# @Desc : + +import pytest + +from metagpt.actions import WriteCode +from metagpt.llm import LLM +from metagpt.schema import CodingContext, Document + + +def test_write_design_serialize(): + action = WriteCode() + ser_action_dict = action.dict() + assert ser_action_dict["name"] == "WriteCode" + # assert "llm" in ser_action_dict # not export + + +@pytest.mark.asyncio +async def test_write_code_deserialize(): + context = CodingContext(filename="test_code.py", + design_doc=Document(content="write add function to calculate two numbers")) + doc = Document(content=context.json()) + action = WriteCode(context=doc) + serialized_data = action.dict() + new_action = WriteCode(**serialized_data) + + assert new_action.name == "WriteCode" + assert new_action.llm == LLM() + await action.run() diff --git a/tests/metagpt/serialize_deserialize/test_write_code_review.py b/tests/metagpt/serialize_deserialize/test_write_code_review.py new file mode 100644 index 000000000..a15b744db --- /dev/null +++ b/tests/metagpt/serialize_deserialize/test_write_code_review.py @@ -0,0 +1,32 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : unittest of WriteCodeReview SerDeser + +import pytest + +from metagpt.actions import WriteCodeReview +from metagpt.llm import LLM +from metagpt.schema import CodingContext, Document + + +@pytest.mark.asyncio +async def test_write_code_review_deserialize(): + code_content = """ +def div(a: int, b: int = 0): + return a / b +""" + context = CodingContext( + filename="test_op.py", + design_doc=Document(content="divide two numbers"), + code_doc=Document(content=code_content) + ) + + action = WriteCodeReview(context=context) + serialized_data = action.dict() + assert serialized_data["name"] == "WriteCodeReview" + + new_action = WriteCodeReview(**serialized_data) + + assert new_action.name == "WriteCodeReview" + assert new_action.llm == LLM() + await new_action.run() diff --git a/tests/metagpt/serialize_deserialize/test_write_design.py b/tests/metagpt/serialize_deserialize/test_write_design.py new file mode 100644 index 000000000..4e768ddd7 --- /dev/null +++ b/tests/metagpt/serialize_deserialize/test_write_design.py @@ -0,0 +1,42 @@ +# -*- coding: utf-8 -*- +# @Date : 11/22/2023 8:19 PM +# @Author : stellahong (stellahong@fuzhi.ai) +# @Desc : +import pytest + +from metagpt.actions import WriteDesign, WriteTasks +from metagpt.llm import LLM + + +def test_write_design_serialize(): + action = WriteDesign() + ser_action_dict = action.dict() + assert "name" in ser_action_dict + # assert "llm" in ser_action_dict # not export + + +def test_write_task_serialize(): + action = WriteTasks() + ser_action_dict = action.dict() + assert "name" in ser_action_dict + # assert "llm" in ser_action_dict # not export + + +@pytest.mark.asyncio +async def test_write_design_deserialize(): + action = WriteDesign() + serialized_data = action.dict() + new_action = WriteDesign(**serialized_data) + assert new_action.name == "" + assert new_action.llm == LLM() + await new_action.run(with_messages="write a cli snake game") + + +@pytest.mark.asyncio +async def test_write_task_deserialize(): + action = WriteTasks() + serialized_data = action.dict() + new_action = WriteTasks(**serialized_data) + assert new_action.name == "CreateTasks" + assert new_action.llm == LLM() + await new_action.run(with_messages="write a cli snake game") diff --git a/tests/metagpt/serialize_deserialize/test_write_prd.py b/tests/metagpt/serialize_deserialize/test_write_prd.py new file mode 100644 index 000000000..d6d14f99a --- /dev/null +++ b/tests/metagpt/serialize_deserialize/test_write_prd.py @@ -0,0 +1,28 @@ +# -*- coding: utf-8 -*- +# @Date : 11/22/2023 1:47 PM +# @Author : stellahong (stellahong@fuzhi.ai) +# @Desc : + +import pytest + +from metagpt.actions import WritePRD +from metagpt.llm import LLM +from metagpt.schema import Message + + +def test_action_serialize(): + action = WritePRD() + ser_action_dict = action.dict() + assert "name" in ser_action_dict + # assert "llm" in ser_action_dict # not export + + +@pytest.mark.asyncio +async def test_action_deserialize(): + action = WritePRD() + serialized_data = action.dict() + new_action = WritePRD(**serialized_data) + assert new_action.name == "" + assert new_action.llm == LLM() + action_output = await new_action.run(with_messages=Message(content="write a cli snake game")) + assert len(action_output.content) > 0 diff --git a/tests/metagpt/test_environment.py b/tests/metagpt/test_environment.py index b27bc3da7..ee322368e 100644 --- a/tests/metagpt/test_environment.py +++ b/tests/metagpt/test_environment.py @@ -7,6 +7,7 @@ """ import pytest +from pathlib import Path from metagpt.actions import UserRequirement from metagpt.environment import Environment @@ -16,38 +17,51 @@ from metagpt.roles import Architect, ProductManager, Role from metagpt.schema import Message +serdeser_path = Path(__file__).absolute().parent.joinpath("../data/serdeser_storage") + + @pytest.fixture def env(): return Environment() def test_add_role(env: Environment): - role = ProductManager("Alice", "product manager", "create a new product", "limited resources") + role = ProductManager(name="Alice", + profile="product manager", + goal="create a new product", + constraints="limited resources") env.add_role(role) assert env.get_role(role.profile) == role def test_get_roles(env: Environment): - role1 = Role("Alice", "product manager", "create a new product", "limited resources") - role2 = Role("Bob", "engineer", "develop the new product", "short deadline") + role1 = Role(name="Alice", + profile="product manager", + goal="create a new product", + constraints="limited resources") + role2 = Role(name="Bob", + profile="engineer", + goal="develop the new product", + constraints="short deadline") env.add_role(role1) env.add_role(role2) roles = env.get_roles() assert roles == {role1.profile: role1, role2.profile: role2} -def test_set_manager(env: Environment): - manager = Manager() - env.set_manager(manager) - assert env.manager == manager - - @pytest.mark.asyncio async def test_publish_and_process_message(env: Environment): - product_manager = ProductManager("Alice", "Product Manager", "做AI Native产品", "资源有限") - architect = Architect("Bob", "Architect", "设计一个可用、高效、较低成本的系统,包括数据结构与接口", "资源有限,需要节省成本") + product_manager = ProductManager(name="Alice", + profile="Product Manager", + goal="做AI Native产品", + constraints="资源有限") + architect = Architect(name="Bob", + profile="Architect", + goal="设计一个可用、高效、较低成本的系统,包括数据结构与接口", + constraints="资源有限,需要节省成本") env.add_roles([product_manager, architect]) + env.set_manager(Manager()) env.publish_message(Message(role="User", content="需要一个基于LLM做总结的搜索引擎", cause_by=UserRequirement)) diff --git a/tests/metagpt/test_schema.py b/tests/metagpt/test_schema.py index 79421fab2..ef706abfa 100644 --- a/tests/metagpt/test_schema.py +++ b/tests/metagpt/test_schema.py @@ -7,12 +7,14 @@ @Modified By: mashenquan, 2023-11-1. In line with Chapter 2.2.1 and 2.2.2 of RFC 116, introduce unit tests for the utilization of the new feature of `Message` class. """ -import json +import json import pytest from metagpt.actions import Action from metagpt.schema import AIMessage, Message, SystemMessage, UserMessage +from metagpt.actions.action_node import ActionNode +from metagpt.actions.write_code import WriteCode from metagpt.utils.common import any_to_str @@ -20,10 +22,10 @@ from metagpt.utils.common import any_to_str def test_messages(): test_content = "test_message" msgs = [ - UserMessage(test_content), - SystemMessage(test_content), - AIMessage(test_content), - Message(test_content, role="QA"), + UserMessage(content=test_content), + SystemMessage(content=test_content), + AIMessage(content=test_content), + Message(content=test_content, role="QA"), ] text = str(msgs) roles = ["user", "system", "assistant", "QA"] @@ -32,7 +34,7 @@ def test_messages(): @pytest.mark.asyncio def test_message(): - m = Message("a", role="v1") + m = Message(content="a", role="v1") v = m.dump() d = json.loads(v) assert d @@ -45,7 +47,7 @@ def test_message(): assert m.content == "a" assert m.role == "v2" - m = Message("a", role="b", cause_by="c", x="d", send_to="c") + m = Message(content="a", role="b", cause_by="c", x="d", send_to="c") assert m.content == "a" assert m.role == "b" assert m.send_to == {"c"} @@ -63,12 +65,46 @@ def test_message(): @pytest.mark.asyncio def test_routes(): - m = Message("a", role="b", cause_by="c", x="d", send_to="c") + m = Message(content="a", role="b", cause_by="c", x="d", send_to="c") m.send_to = "b" assert m.send_to == {"b"} m.send_to = {"e", Action} assert m.send_to == {"e", any_to_str(Action)} -if __name__ == "__main__": - pytest.main([__file__, "-s"]) +def test_message_serdeser(): + out_mapping = {"field3": (str, ...), "field4": (list[str], ...)} + out_data = {"field3": "field3 value3", "field4": ["field4 value1", "field4 value2"]} + ic_obj = ActionNode.create_model_class("code", out_mapping) + + message = Message( + content="code", + instruct_content=ic_obj(**out_data), + role="engineer", + cause_by=WriteCode + ) + message_dict = message.dict() + assert message_dict["cause_by"] == "metagpt.actions.write_code.WriteCode" + assert message_dict["instruct_content"] == { + "class": "code", + "mapping": { + "field3": "(, Ellipsis)", + "field4": "(list[str], Ellipsis)" + }, + "value": { + "field3": "field3 value3", + "field4": ["field4 value1", "field4 value2"] + } + } + + new_message = Message(**message_dict) + assert new_message.content == message.content + assert new_message.instruct_content == message.instruct_content + assert new_message.cause_by == message.cause_by + assert new_message.instruct_content.field3 == out_data["field3"] + + message = Message(content="code") + message_dict = message.dict() + new_message = Message(**message_dict) + assert new_message.instruct_content is None + assert new_message.cause_by == "metagpt.actions.add_requirement.UserRequirement" diff --git a/tests/metagpt/test_team.py b/tests/metagpt/test_team.py new file mode 100644 index 000000000..efd035bb2 --- /dev/null +++ b/tests/metagpt/test_team.py @@ -0,0 +1,13 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : unittest of team + +from metagpt.team import Team +from metagpt.roles.project_manager import ProjectManager + + +def test_team(): + company = Team() + company.hire([ProjectManager()]) + + assert len(company.environment.roles) == 1