From 4702059caf3c76b05d2a6c7c119a56fbd03a8db9 Mon Sep 17 00:00:00 2001 From: stellahsr Date: Mon, 27 Nov 2023 21:12:50 +0800 Subject: [PATCH 001/167] update basic code for serialize --- metagpt/actions/action.py | 61 +++--- metagpt/actions/design_api.py | 30 ++- metagpt/actions/project_management.py | 27 ++- metagpt/actions/search_and_summarize.py | 52 +++-- metagpt/actions/write_code.py | 15 +- metagpt/actions/write_code_review.py | 13 +- metagpt/actions/write_prd.py | 32 ++- metagpt/environment.py | 5 +- metagpt/roles/architect.py | 18 +- metagpt/roles/engineer.py | 76 +++---- metagpt/roles/product_manager.py | 36 ++-- metagpt/roles/project_manager.py | 27 +-- metagpt/roles/role.py | 271 +++++++++++------------- 13 files changed, 342 insertions(+), 321 deletions(-) diff --git a/metagpt/actions/action.py b/metagpt/actions/action.py index 790295d55..7bb5a151b 100644 --- a/metagpt/actions/action.py +++ b/metagpt/actions/action.py @@ -7,8 +7,9 @@ """ import re from abc import ABC -from typing import Optional +from typing import Optional, Any +from pydantic import BaseModel, Field from tenacity import retry, stop_after_attempt, wait_fixed from metagpt.actions.action_output import ActionOutput @@ -18,45 +19,45 @@ from metagpt.utils.common import OutputParser from metagpt.utils.custom_decoder import CustomDecoder -class Action(ABC): - def __init__(self, name: str = "", context=None, llm: LLM = None): - self.name: str = name - if llm is None: - llm = LLM() - self.llm = llm - self.context = context - self.prefix = "" - self.profile = "" - self.desc = "" - self.content = "" - self.instruct_content = None - +class Action(BaseModel): + name: str = "" + llm: LLM = Field(default_factory=LLM) + context = "" + prefix = "" + profile = "" + desc = "" + content: Optional[str] = None + instruct_content: Optional[str] = None + + def __init__(self, **kwargs: Any): + super().__init__(**kwargs) + def set_prefix(self, prefix, profile): """Set prefix for later usage""" self.prefix = prefix self.profile = profile - + def __str__(self): return self.__class__.__name__ - + def __repr__(self): return self.__str__() - + async def _aask(self, prompt: str, system_msgs: Optional[list[str]] = None) -> str: """Append default prefix""" if not system_msgs: system_msgs = [] system_msgs.append(self.prefix) return await self.llm.aask(prompt, system_msgs) - + @retry(stop=stop_after_attempt(3), wait=wait_fixed(1)) async def _aask_v1( - self, - prompt: str, - output_class_name: str, - output_data_mapping: dict, - system_msgs: Optional[list[str]] = None, - format="markdown", # compatible to original format + self, + prompt: str, + output_class_name: str, + output_data_mapping: dict, + system_msgs: Optional[list[str]] = None, + format="markdown", # compatible to original format ) -> ActionOutput: """Append default prefix""" if not system_msgs: @@ -65,25 +66,25 @@ class Action(ABC): content = await self.llm.aask(prompt, system_msgs) logger.debug(content) output_class = ActionOutput.create_model_class(output_class_name, output_data_mapping) - + if format == "json": pattern = r"\[CONTENT\](\s*\{.*?\}\s*)\[/CONTENT\]" matches = re.findall(pattern, content, re.DOTALL) - + for match in matches: if match: content = match break - + parsed_data = CustomDecoder(strict=False).decode(content) - + else: # using markdown parser parsed_data = OutputParser.parse_data_with_mapping(content, output_data_mapping) - + logger.debug(parsed_data) instruct_content = output_class(**parsed_data) return ActionOutput(content, instruct_content) - + async def run(self, *args, **kwargs): """Run action""" raise NotImplementedError("The run method should be implemented in a subclass.") diff --git a/metagpt/actions/design_api.py b/metagpt/actions/design_api.py index 75df8b909..30df70ce7 100644 --- a/metagpt/actions/design_api.py +++ b/metagpt/actions/design_api.py @@ -7,9 +7,12 @@ """ import shutil from pathlib import Path -from typing import List +from typing import List, Optional, Any + +from pydantic import Field from metagpt.actions import Action, ActionOutput +from metagpt.llm import LLM from metagpt.config import CONFIG from metagpt.const import WORKSPACE_ROOT from metagpt.logs import logger @@ -150,13 +153,13 @@ OUTPUT_MAPPING = { class WriteDesign(Action): - def __init__(self, name, context=None, llm=None): - super().__init__(name, context, llm) - self.desc = ( - "Based on the PRD, think about the system design, and design the corresponding APIs, " - "data structures, library tables, processes, and paths. Please provide your design, feedback " - "clearly and in detail." - ) + name: str = "" + context: Optional[str] = None + llm: LLM = Field(default_factory=LLM) + desc: str = "Based on the PRD, think about the system design, and design the corresponding APIs, " + "data structures, library tables, processes, and paths. Please provide your design, feedback " + "clearly and in detail." + def recreate_workspace(self, workspace: Path): try: @@ -165,16 +168,18 @@ class WriteDesign(Action): pass # Folder does not exist, but we don't care workspace.mkdir(parents=True, exist_ok=True) + async def _save_prd(self, docs_path, resources_path, context): prd_file = docs_path / "prd.md" if context[-1].instruct_content and context[-1].instruct_content.dict()["Competitive Quadrant Chart"]: quadrant_chart = context[-1].instruct_content.dict()["Competitive Quadrant Chart"] await mermaid_to_file(quadrant_chart, resources_path / "competitive_analysis") - + if context[-1].instruct_content: logger.info(f"Saving PRD to {prd_file}") prd_file.write_text(json_to_markdown(context[-1].instruct_content.dict())) + async def _save_system_design(self, docs_path, resources_path, system_design): data_api_design = system_design.instruct_content.dict()[ "Data structures and interface definitions" @@ -188,6 +193,7 @@ class WriteDesign(Action): logger.info(f"Saving System Designs to {system_design_file}") system_design_file.write_text((json_to_markdown(system_design.instruct_content.dict()))) + async def _save(self, context, system_design): if isinstance(system_design, ActionOutput): ws_name = system_design.instruct_content.dict()["Python package name"] @@ -199,9 +205,13 @@ class WriteDesign(Action): resources_path = workspace / "resources" docs_path.mkdir(parents=True, exist_ok=True) resources_path.mkdir(parents=True, exist_ok=True) - await self._save_prd(docs_path, resources_path, context) + try: + await self._save_prd(docs_path, resources_path, context) + except Exception as e: + logger.error(f"Failed to save PRD {e}") await self._save_system_design(docs_path, resources_path, system_design) + async def run(self, context, format=CONFIG.prompt_format): prompt_template, format_example = get_template(templates, format) prompt = prompt_template.format(context=context, format_example=format_example) diff --git a/metagpt/actions/project_management.py b/metagpt/actions/project_management.py index b395fa64e..b72507ee3 100644 --- a/metagpt/actions/project_management.py +++ b/metagpt/actions/project_management.py @@ -5,9 +5,12 @@ @Author : alexanderwu @File : project_management.py """ -from typing import List +from typing import List, Optional, Any + +from pydantic import Field from metagpt.actions.action import Action +from metagpt.llm import LLM from metagpt.config import CONFIG from metagpt.const import WORKSPACE_ROOT from metagpt.utils.common import CodeParser @@ -163,21 +166,25 @@ OUTPUT_MAPPING = { class WriteTasks(Action): - def __init__(self, name="CreateTasks", context=None, llm=None): - super().__init__(name, context, llm) - + name: str = "CreateTasks" + context: Optional[str] = None + llm: LLM = Field(default_factory=LLM) + def _save(self, context, rsp): - if context[-1].instruct_content: - ws_name = context[-1].instruct_content.dict()["Python package name"] - else: - ws_name = CodeParser.parse_str(block="Python package name", text=context[-1].content) + try: + if context[-1].instruct_content: + ws_name = context[-1].instruct_content.dict()["Python package name"] + else: + ws_name = CodeParser.parse_str(block="Python package name", text=context[-1].content) + except: + ws_name = "cli_snake_game" # fixme: 应该透传 file_path = WORKSPACE_ROOT / ws_name / "docs/api_spec_and_tasks.md" file_path.write_text(json_to_markdown(rsp.instruct_content.dict())) - + # Write requirements.txt requirements_path = WORKSPACE_ROOT / ws_name / "requirements.txt" requirements_path.write_text("\n".join(rsp.instruct_content.dict().get("Required Python third-party packages"))) - + async def run(self, context, format=CONFIG.prompt_format): prompt_template, format_example = get_template(templates, format) prompt = prompt_template.format(context=context, format_example=format_example) diff --git a/metagpt/actions/search_and_summarize.py b/metagpt/actions/search_and_summarize.py index 069f2a977..0580303e6 100644 --- a/metagpt/actions/search_and_summarize.py +++ b/metagpt/actions/search_and_summarize.py @@ -6,12 +6,16 @@ @File : search_google.py """ import pydantic +from typing import Optional, Any +from pydantic import BaseModel, Field from metagpt.actions import Action +from metagpt.llm import LLM from metagpt.config import Config from metagpt.logs import logger from metagpt.schema import Message from metagpt.tools.search_engine import SearchEngine +from pydantic import root_validator SEARCH_AND_SUMMARIZE_SYSTEM = """### Requirements 1. Please summarize the latest dialogue based on the reference information (secondary) and dialogue history (primary). Do not include text that is irrelevant to the conversation. @@ -54,7 +58,6 @@ SEARCH_AND_SUMMARIZE_PROMPT = """ """ - SEARCH_AND_SUMMARIZE_SALES_SYSTEM = """## Requirements 1. Please summarize the latest dialogue based on the reference information (secondary) and dialogue history (primary). Do not include text that is irrelevant to the conversation. - The context is for reference only. If it is irrelevant to the user's search request history, please reduce its reference and usage. @@ -101,23 +104,41 @@ You are a member of a professional butler team and will provide helpful suggesti class SearchAndSummarize(Action): - def __init__(self, name="", context=None, llm=None, engine=None, search_func=None): - self.config = Config() - self.engine = engine or self.config.search_engine + name: str = "" + content: Optional[str] = None + llm: None = Field(default_factory=LLM) + config: None = Field(default_factory=Config) + engine: Optional[str] = None + search_func: Optional[str] = None - try: - self.search_engine = SearchEngine(self.engine, run_func=search_func) - except pydantic.ValidationError: - self.search_engine = None + result = "" + - self.result = "" - super().__init__(name, context, llm) + @root_validator + def validate_engine_and_run_func(cls, values): + engine = values.get('engine') + search_func = values.get('search_func') + config = Config() + + if engine is None: + engine = config.search_engine + config_data = { + 'engine': engine, + 'run_func': search_func + } + search_engine = SearchEngine(**config_data) + values['search_engine'] = search_engine + return values + + + async def run(self, context: list[Message], system_text=SEARCH_AND_SUMMARIZE_SYSTEM) -> str: + print(context) if self.search_engine is None: logger.warning("Configure one of SERPAPI_API_KEY, SERPER_API_KEY, GOOGLE_API_KEY to unlock full feature") return "" - + query = context[-1].content # logger.debug(query) rsp = await self.search_engine.run(query) @@ -126,9 +147,9 @@ class SearchAndSummarize(Action): logger.error("empty rsp...") return "" # logger.info(rsp) - + system_prompt = [system_text] - + prompt = SEARCH_AND_SUMMARIZE_PROMPT.format( # PREFIX = self.prefix, ROLE=self.profile, @@ -140,4 +161,7 @@ class SearchAndSummarize(Action): logger.debug(prompt) logger.debug(result) return result - \ No newline at end of file + + +if __name__ == "__main__": + action = SearchAndSummarize() diff --git a/metagpt/actions/write_code.py b/metagpt/actions/write_code.py index c000805c5..2dc240591 100644 --- a/metagpt/actions/write_code.py +++ b/metagpt/actions/write_code.py @@ -5,13 +5,18 @@ @Author : alexanderwu @File : write_code.py """ +from typing import List, Optional, Any + +from pydantic import Field +from tenacity import retry, stop_after_attempt, wait_fixed + from metagpt.actions import WriteDesign from metagpt.actions.action import Action +from metagpt.llm import LLM from metagpt.const import WORKSPACE_ROOT from metagpt.logs import logger from metagpt.schema import Message from metagpt.utils.common import CodeParser -from tenacity import retry, stop_after_attempt, wait_fixed PROMPT_TEMPLATE = """ NOTICE @@ -43,9 +48,10 @@ ATTENTION: Use '##' to SPLIT SECTIONS, not '#'. Output format carefully referenc class WriteCode(Action): - def __init__(self, name="WriteCode", context: list[Message] = None, llm=None): - super().__init__(name, context, llm) - + name: str = "WriteCode" + context: Optional[str] = None + llm: LLM = Field(default_factory=LLM) + def _is_invalid(self, filename): return any(i in filename for i in ["mp3", "wav"]) @@ -79,4 +85,3 @@ class WriteCode(Action): # code_rsp = await self._aask_v1(prompt, "code_rsp", OUTPUT_MAPPING) # self._save(context, filename, code) return code - \ No newline at end of file diff --git a/metagpt/actions/write_code_review.py b/metagpt/actions/write_code_review.py index 4ff4d6cf6..3d86d7c63 100644 --- a/metagpt/actions/write_code_review.py +++ b/metagpt/actions/write_code_review.py @@ -5,12 +5,15 @@ @Author : alexanderwu @File : write_code_review.py """ +from typing import List, Optional, Any +from pydantic import Field +from tenacity import retry, stop_after_attempt, wait_fixed +from metagpt.llm import LLM from metagpt.actions.action import Action from metagpt.logs import logger from metagpt.schema import Message from metagpt.utils.common import CodeParser -from tenacity import retry, stop_after_attempt, wait_fixed PROMPT_TEMPLATE = """ NOTICE @@ -62,9 +65,10 @@ FORMAT_EXAMPLE = """ class WriteCodeReview(Action): - def __init__(self, name="WriteCodeReview", context: list[Message] = None, llm=None): - super().__init__(name, context, llm) - + name: str = "WriteCodeReview" + context: Optional[str] = None + llm: LLM = Field(default_factory=LLM) + @retry(stop=stop_after_attempt(2), wait=wait_fixed(1)) async def write_code(self, prompt): code_rsp = await self._aask(prompt) @@ -79,4 +83,3 @@ class WriteCodeReview(Action): # code_rsp = await self._aask_v1(prompt, "code_rsp", OUTPUT_MAPPING) # self._save(context, filename, code) return code - \ No newline at end of file diff --git a/metagpt/actions/write_prd.py b/metagpt/actions/write_prd.py index bd04ca79e..660d7fb95 100644 --- a/metagpt/actions/write_prd.py +++ b/metagpt/actions/write_prd.py @@ -5,9 +5,12 @@ @Author : alexanderwu @File : write_prd.py """ -from typing import List +from typing import List, Optional, Any + +from pydantic import BaseModel, Field from metagpt.actions import Action, ActionOutput +from metagpt.llm import LLM from metagpt.actions.search_and_summarize import SearchAndSummarize from metagpt.config import CONFIG from metagpt.logs import logger @@ -219,18 +222,25 @@ OUTPUT_MAPPING = { class WritePRD(Action): - def __init__(self, name="", context=None, llm=None): - super().__init__(name, context, llm) - + name: str = "" + content: Optional[str] = None + llm: LLM = Field(default_factory=LLM) + assistant_search_action: Action = None + + def __init__(self, **kwargs): + super().__init__(**kwargs) + async def run(self, requirements, format=CONFIG.prompt_format, *args, **kwargs) -> ActionOutput: - sas = SearchAndSummarize() - # rsp = await sas.run(context=requirements, system_text=SEARCH_AND_SUMMARIZE_SYSTEM_EN_US) - rsp = "" - info = f"### Search Results\n{sas.result}\n\n### Search Summary\n{rsp}" - if sas.result: - logger.info(sas.result) + # self.assistant_search_action = SearchAndSummarize() + if self.assistant_search_action is None: + self.assistant_search_action = SearchAndSummarize() + # self.assistant_search_action = SearchAndSummarize() + rsp = await self.assistant_search_action.run(context=requirements) + info = f"### Search Results\n{self.assistant_search_action.result}\n\n### Search Summary\n{rsp}" + if self.assistant_search_action.result: + logger.info(self.assistant_search_action.result) logger.info(rsp) - + prompt_template, format_example = get_template(templates, format) prompt = prompt_template.format( requirements=requirements, search_information=info, format_example=format_example diff --git a/metagpt/environment.py b/metagpt/environment.py index 24e6ada2f..88ff145e0 100644 --- a/metagpt/environment.py +++ b/metagpt/environment.py @@ -29,11 +29,12 @@ class Environment(BaseModel): arbitrary_types_allowed = True def add_role(self, role: Role): - """增加一个在当前环境的角色 + """增加一个在当前环境的角色, 默认为profile/role_profile Add a role in the current environment """ role.set_env(self) - self.roles[role.profile] = role + # use alias + self.roles[role.role_profile] = role def add_roles(self, roles: Iterable[Role]): """增加一批在当前环境的角色 diff --git a/metagpt/roles/architect.py b/metagpt/roles/architect.py index 15d5fe5b1..face22a68 100644 --- a/metagpt/roles/architect.py +++ b/metagpt/roles/architect.py @@ -5,10 +5,11 @@ @Author : alexanderwu @File : architect.py """ +from pydantic import Field from metagpt.actions import WritePRD from metagpt.actions.design_api import WriteDesign -from metagpt.roles import Role +from metagpt.roles.role import Role class Architect(Role): @@ -21,17 +22,16 @@ class Architect(Role): goal (str): Primary goal or responsibility of the architect. constraints (str): Constraints or guidelines for the architect. """ + name: str = "Bob" + role_profile: str = Field(default="Architect" , alias='profile') + goal: str = "Design a concise, usable, complete python system" + constraints: str = "Try to specify good open source tools as much as possible" def __init__( - self, - name: str = "Bob", - profile: str = "Architect", - goal: str = "Design a concise, usable, complete python system", - constraints: str = "Try to specify good open source tools as much as possible", + self, + **kwargs ) -> None: - """Initializes the Architect with given attributes.""" - super().__init__(name, profile, goal, constraints) - + super().__init__(**kwargs) # Initialize actions specific to the Architect role self._init_actions([WriteDesign]) diff --git a/metagpt/roles/engineer.py b/metagpt/roles/engineer.py index 1f6685b38..129bedeb8 100644 --- a/metagpt/roles/engineer.py +++ b/metagpt/roles/engineer.py @@ -9,11 +9,12 @@ import asyncio import shutil from collections import OrderedDict from pathlib import Path +from pydantic import Field from metagpt.actions import WriteCode, WriteCodeReview, WriteDesign, WriteTasks from metagpt.const import WORKSPACE_ROOT from metagpt.logs import logger -from metagpt.roles import Role +from metagpt.roles.role import Role from metagpt.schema import Message from metagpt.utils.common import CodeParser from metagpt.utils.special_tokens import FILENAME_CODE_SEP, MSG_SEP @@ -23,7 +24,7 @@ async def gather_ordered_k(coros, k) -> list: tasks = OrderedDict() results = [None] * len(coros) done_queue = asyncio.Queue() - + for i, coro in enumerate(coros): if len(tasks) >= k: done, _ = await asyncio.wait(tasks.keys(), return_when=asyncio.FIRST_COMPLETED) @@ -32,17 +33,17 @@ async def gather_ordered_k(coros, k) -> list: await done_queue.put((index, task.result())) task = asyncio.create_task(coro) tasks[task] = i - + if tasks: done, _ = await asyncio.wait(tasks.keys()) for task in done: index = tasks[task] await done_queue.put((index, task.result())) - + while not done_queue.empty(): index, result = await done_queue.get() results[index] = result - + return results @@ -59,42 +60,42 @@ class Engineer(Role): use_code_review (bool): Whether to use code review. todos (list): List of tasks. """ - + name: str = "Alex" + role_profile: str = Field(default="Engineer", alias='profile') + goal: str = "Write elegant, readable, extensible, efficient code" + constraints: str = "The code should conform to standards like PEP8 and be modular and maintainable" + n_borg: int = 1 + use_code_review: bool = False + todos: list = [] + def __init__( - self, - name: str = "Alex", - profile: str = "Engineer", - goal: str = "Write elegant, readable, extensible, efficient code", - constraints: str = "The code should conform to standards like PEP8 and be modular and maintainable", - n_borg: int = 1, - use_code_review: bool = False, + self, + **kwargs ) -> None: - """Initializes the Engineer role with given attributes.""" - super().__init__(name, profile, goal, constraints) - self._init_actions([WriteCode]) - self.use_code_review = use_code_review + super().__init__(**kwargs) + + actions = [WriteCode] if self.use_code_review: - self._init_actions([WriteCode, WriteCodeReview]) + actions = [WriteCode, WriteCodeReview] + self._init_actions(actions) self._watch([WriteTasks]) - self.todos = [] - self.n_borg = n_borg - + @classmethod def parse_tasks(self, task_msg: Message) -> list[str]: if task_msg.instruct_content: return task_msg.instruct_content.dict().get("Task list") return CodeParser.parse_file_list(block="Task list", text=task_msg.content) - + @classmethod def parse_code(self, code_text: str) -> str: return CodeParser.parse_code(block="", text=code_text) - + @classmethod def parse_workspace(cls, system_design_msg: Message) -> str: if system_design_msg.instruct_content: return system_design_msg.instruct_content.dict().get("Python package name").strip().strip("'").strip('"') return CodeParser.parse_str(block="Python package name", text=system_design_msg.content) - + def get_workspace(self) -> Path: msg = self._rc.memory.get_by_action(WriteDesign)[-1] if not msg: @@ -102,7 +103,7 @@ class Engineer(Role): workspace = self.parse_workspace(msg) # Codes are written in workspace/{package_name}/{package_name} return WORKSPACE_ROOT / workspace / workspace - + def recreate_workspace(self): workspace = self.get_workspace() try: @@ -110,7 +111,7 @@ class Engineer(Role): except FileNotFoundError: pass # The folder does not exist, but we don't care workspace.mkdir(parents=True, exist_ok=True) - + def write_file(self, filename: str, code: str): workspace = self.get_workspace() filename = filename.replace('"', "").replace("\n", "") @@ -118,12 +119,12 @@ class Engineer(Role): file.parent.mkdir(parents=True, exist_ok=True) file.write_text(code) return file - + def recv(self, message: Message) -> None: self._rc.memory.add(message) if message in self._rc.important_memory: self.todos = self.parse_tasks(message) - + async def _act_mp(self) -> Message: # self.recreate_workspace() todo_coros = [] @@ -132,7 +133,7 @@ class Engineer(Role): context=self._rc.memory.get_by_actions([WriteTasks, WriteDesign]), filename=todo ) todo_coros.append(todo_coro) - + rsps = await gather_ordered_k(todo_coros, self.n_borg) for todo, code_rsp in zip(self.todos, rsps): _ = self.parse_code(code_rsp) @@ -142,11 +143,11 @@ class Engineer(Role): msg = Message(content=code_rsp, role=self.profile, cause_by=type(self._rc.todo)) self._rc.memory.add(msg) del self.todos[0] - + logger.info(f"Done {self.get_workspace()} generating.") msg = Message(content="all done.", role=self.profile, cause_by=type(self._rc.todo)) return msg - + async def _act_sp(self) -> Message: code_msg_all = [] # gather all code info, will pass to qa_engineer for tests later for todo in self.todos: @@ -157,16 +158,16 @@ class Engineer(Role): file_path = self.write_file(todo, code) msg = Message(content=code, role=self.profile, cause_by=type(self._rc.todo)) self._rc.memory.add(msg) - + code_msg = todo + FILENAME_CODE_SEP + str(file_path) code_msg_all.append(code_msg) - + logger.info(f"Done {self.get_workspace()} generating.") msg = Message( content=MSG_SEP.join(code_msg_all), role=self.profile, cause_by=type(self._rc.todo), send_to="QaEngineer" ) return msg - + async def _act_sp_precision(self) -> Message: code_msg_all = [] # gather all code info, will pass to qa_engineer for tests later for todo in self.todos: @@ -195,19 +196,18 @@ class Engineer(Role): file_path = self.write_file(todo, code) msg = Message(content=code, role=self.profile, cause_by=WriteCode) self._rc.memory.add(msg) - + code_msg = todo + FILENAME_CODE_SEP + str(file_path) code_msg_all.append(code_msg) - + logger.info(f"Done {self.get_workspace()} generating.") msg = Message( content=MSG_SEP.join(code_msg_all), role=self.profile, cause_by=type(self._rc.todo), send_to="QaEngineer" ) return msg - + async def _act(self) -> Message: """Determines the mode of action based on whether code review is used.""" - logger.info(f"{self._setting}: ready to WriteCode") if self.use_code_review: return await self._act_sp_precision() return await self._act_sp() diff --git a/metagpt/roles/product_manager.py b/metagpt/roles/product_manager.py index a58ea5385..b099fb4d9 100644 --- a/metagpt/roles/product_manager.py +++ b/metagpt/roles/product_manager.py @@ -5,37 +5,33 @@ @Author : alexanderwu @File : product_manager.py """ +from pydantic import Field + from metagpt.actions import BossRequirement, WritePRD -from metagpt.roles import Role +from metagpt.roles.role import Role class ProductManager(Role): """ - Represents a Product Manager role responsible for product development and management. + Initializes the ProductManager role with given attributes. - Attributes: + Args: name (str): Name of the product manager. - profile (str): Role profile, default is 'Product Manager'. + profile (str): Role profile. goal (str): Goal of the product manager. constraints (str): Constraints or limitations for the product manager. """ - + name: str = "Alice" + role_profile: str = Field(default="Product Manager", alias='profile') + goal: str = "Efficiently create a successful product" + constraints: str = "" + """ + Represents a Product Manager role responsible for product development and management. + """ def __init__( - self, - name: str = "Alice", - profile: str = "Product Manager", - goal: str = "Efficiently create a successful product", - constraints: str = "", + self, + **kwargs ) -> None: - """ - Initializes the ProductManager role with given attributes. - - Args: - name (str): Name of the product manager. - profile (str): Role profile. - goal (str): Goal of the product manager. - constraints (str): Constraints or limitations for the product manager. - """ - super().__init__(name, profile, goal, constraints) + super().__init__(**kwargs) self._init_actions([WritePRD]) self._watch([BossRequirement]) diff --git a/metagpt/roles/project_manager.py b/metagpt/roles/project_manager.py index 7e7c5699d..a2b227f22 100644 --- a/metagpt/roles/project_manager.py +++ b/metagpt/roles/project_manager.py @@ -5,9 +5,11 @@ @Author : alexanderwu @File : project_manager.py """ +from pydantic import Field + from metagpt.actions import WriteTasks from metagpt.actions.design_api import WriteDesign -from metagpt.roles import Role +from metagpt.roles.role import Role class ProjectManager(Role): @@ -20,23 +22,16 @@ class ProjectManager(Role): goal (str): Goal of the project manager. constraints (str): Constraints or limitations for the project manager. """ + name: str = "Eve" + role_profile: str = Field(default="Project Manager", alias='profile') + + goal: str = "Improve team efficiency and deliver with quality and quantity" + constraints: str = "" def __init__( - self, - name: str = "Eve", - profile: str = "Project Manager", - goal: str = "Improve team efficiency and deliver with quality and quantity", - constraints: str = "", + self, + **kwargs ) -> None: - """ - Initializes the ProjectManager role with given attributes. - - Args: - name (str): Name of the project manager. - profile (str): Role profile. - goal (str): Goal of the project manager. - constraints (str): Constraints or limitations for the project manager. - """ - super().__init__(name, profile, goal, constraints) + super().__init__(**kwargs) self._init_actions([WriteTasks]) self._watch([WriteDesign]) diff --git a/metagpt/roles/role.py b/metagpt/roles/role.py index b96c361c0..9aae64188 100644 --- a/metagpt/roles/role.py +++ b/metagpt/roles/role.py @@ -5,17 +5,26 @@ @Author : alexanderwu @File : role.py """ + from __future__ import annotations -from typing import Iterable, Type, Union -from enum import Enum - +import sys +from types import SimpleNamespace +from typing import ( + Dict, + Optional, + Union, + Iterable, + Type +) +import re from pydantic import BaseModel, Field +from importlib import import_module # from metagpt.environment import Environment from metagpt.config import CONFIG from metagpt.actions import Action, ActionOutput -from metagpt.llm import LLM, HumanProvider +from metagpt.llm import LLM from metagpt.logs import logger from metagpt.memory import Memory, LongTermMemory from metagpt.schema import Message @@ -28,14 +37,12 @@ Please note that only the text between the first and second "===" is information {history} === -Your previous stage: {previous_state} - -Now choose one of the following stages you need to go to in the next step: +You can now choose one of the following stages to decide the stage you need to go in the next step: {states} Just answer a number between 0-{n_states}, choose the most suitable stage according to the understanding of the conversation. Please note that the answer only needs a number, no need to add any other text. -If you think you have completed your goal and don't need to go to any of the stages, return -1. +If there is no conversation record, choose 0. Do not answer anything else, and do not add any other information in your answer. """ @@ -49,27 +56,18 @@ ROLE_TEMPLATE = """Your response should be based on the previous conversation hi {name}: {result} """ -class RoleReactMode(str, Enum): - REACT = "react" - BY_ORDER = "by_order" - PLAN_AND_ACT = "plan_and_act" - - @classmethod - def values(cls): - return [item.value for item in cls] class RoleSetting(BaseModel): """Role Settings""" - name: str - profile: str - goal: str - constraints: str - desc: str - is_human: bool - + name: str = "" + profile: str = "" + goal: str = "" + constraints: str = "" + desc: str = "" + def __str__(self): return f"{self.name}({self.profile})" - + def __repr__(self): return self.__str__() @@ -79,109 +77,128 @@ class RoleContext(BaseModel): env: 'Environment' = Field(default=None) memory: Memory = Field(default_factory=Memory) long_term_memory: LongTermMemory = Field(default_factory=LongTermMemory) - state: int = Field(default=-1) # -1 indicates initial or termination state where todo is None + state: int = Field(default=0) todo: Action = Field(default=None) watch: set[Type[Action]] = Field(default_factory=set) news: list[Type[Message]] = Field(default=[]) - react_mode: RoleReactMode = RoleReactMode.REACT # see `Role._set_react_mode` for definitions of the following two attributes - max_react_loop: int = 1 - + class Config: arbitrary_types_allowed = True - + def check(self, role_id: str): if hasattr(CONFIG, "long_term_memory") and CONFIG.long_term_memory: self.long_term_memory.recover_memory(role_id, self) self.memory = self.long_term_memory # use memory to act as long_term_memory for unify operation - + @property def important_memory(self) -> list[Message]: """Get the information corresponding to the watched actions""" return self.memory.get_by_actions(self.watch) - + @property def history(self) -> list[Message]: return self.memory.get() -class Role: +class Role(BaseModel): """Role/Agent""" - - def __init__(self, name="", profile="", goal="", constraints="", desc="", is_human=False): - self._llm = LLM() if not is_human else HumanProvider() - self._setting = RoleSetting(name=name, profile=profile, goal=goal, - constraints=constraints, desc=desc, is_human=is_human) - self._states = [] - self._actions = [] - self._role_id = str(self._setting) - self._rc = RoleContext() - + name: str = "" + profile: str = "" + goal: str = "" + constraints: str = "" + desc: str = "" + _setting: RoleSetting = Field(default_factory=RoleSetting, alias="_setting") + _setting = RoleSetting(name=name, profile=profile, goal=goal, constraints=constraints) + _role_id: str = "" + _states: list = Field(default=[]) + _actions: list = Field(default=[]) + _actions_type: list = Field(default=[]) + _rc: RoleContext = RoleContext() + + _private_attributes = { + '_setting': _setting, + '_role_id': _role_id, + '_states': [], + '_actions': [], + '_actions_type': [] # 用于记录和序列化 + } + + class Config: + arbitrary_types_allowed = True + + def __init__(self, **kwargs): + super().__init__(**kwargs) + # 关于私有变量的初始化 https://github.com/pydantic/pydantic/issues/655 + for key in self._private_attributes.keys(): + if key in kwargs: + object.__setattr__(self, key, kwargs[key]) + if key =="_setting": + _setting = RoleSetting(**kwargs[key]) + object.__setattr__(self, '_setting', _setting) + elif key == "_rc": + _rc = RoleContext + object.__setattr__(self, '_rc', _rc) + else: + object.__setattr__(self, key, self._private_attributes[key]) + def _reset(self): - self._states = [] - self._actions = [] + object.__setattr__(self, '_states', []) + object.__setattr__(self, '_actions', []) + + + @staticmethod + def _process_class(class_str, module_name): + cleaned_string = re.sub(r"[<>']", "", class_str).replace("class ", "") + package_name = "metagpt" + file_name = cleaned_string.replace(package_name, "").replace("." + module_name, "") + print(file_name) + # print("\n", sys.modules) + module_file = import_module(file_name, package=package_name) + module = getattr(module_file, module_name) + return module + def _init_actions(self, actions): self._reset() for idx, action in enumerate(actions): if not isinstance(action, Action): - i = action("", llm=self._llm) + ## 默认初始化 + i = action() else: - if self._setting.is_human and not isinstance(action.llm, HumanProvider): - logger.warning(f"is_human attribute does not take effect," - f"as Role's {str(action)} was initialized using LLM, try passing in Action classes instead of initialized instances") i = action i.set_prefix(self._get_prefix(), self.profile) self._actions.append(i) self._states.append(f"{idx}. {action}") - - def _set_react_mode(self, react_mode: str, max_react_loop: int = 1): - """Set strategy of the Role reacting to observed Message. Variation lies in how - this Role elects action to perform during the _think stage, especially if it is capable of multiple Actions. - - Args: - react_mode (str): Mode for choosing action during the _think stage, can be one of: - "react": standard think-act loop in the ReAct paper, alternating thinking and acting to solve the task, i.e. _think -> _act -> _think -> _act -> ... - Use llm to select actions in _think dynamically; - "by_order": switch action each time by order defined in _init_actions, i.e. _act (Action1) -> _act (Action2) -> ...; - "plan_and_act": first plan, then execute an action sequence, i.e. _think (of a plan) -> _act -> _act -> ... - Use llm to come up with the plan dynamically. - Defaults to "react". - max_react_loop (int): Maximum react cycles to execute, used to prevent the agent from reacting forever. - Take effect only when react_mode is react, in which we use llm to choose actions, including termination. - Defaults to 1, i.e. _think -> _act (-> return result and end) - """ - assert react_mode in RoleReactMode.values(), f"react_mode must be one of {RoleReactMode.values()}" - self._rc.react_mode = react_mode - if react_mode == RoleReactMode.REACT: - self._rc.max_react_loop = max_react_loop - + action_title = action.schema()["title"] + self._actions_type.append(action_title) + def _watch(self, actions: Iterable[Type[Action]]): """Listen to the corresponding behaviors""" self._rc.watch.update(actions) # check RoleContext after adding watch actions self._rc.check(self._role_id) - - def _set_state(self, state: int): + + def _set_state(self, state): """Update the current state.""" self._rc.state = state logger.debug(self._actions) - self._rc.todo = self._actions[self._rc.state] if state >= 0 else None - + self._rc.todo = self._actions[self._rc.state] + def set_env(self, env: 'Environment'): """Set the environment in which the role works. The role can talk to the environment and can also receive messages by observing.""" self._rc.env = env - + @property def profile(self): """Get the role description (position)""" return self._setting.profile - + def _get_prefix(self): """Get the role prefix""" if self._setting.desc: return self._setting.desc return PREFIX_TEMPLATE.format(**self._setting.dict()) - + async def _think(self) -> None: """Think about what to do and decide on the next action""" if len(self._actions) == 1: @@ -190,104 +207,60 @@ class Role: return prompt = self._get_prefix() prompt += STATE_TEMPLATE.format(history=self._rc.history, states="\n".join(self._states), - n_states=len(self._states) - 1, previous_state=self._rc.state) - # print(prompt) + n_states=len(self._states) - 1) next_state = await self._llm.aask(prompt) logger.debug(f"{prompt=}") - if (not next_state.isdigit() and next_state != "-1") \ - or int(next_state) not in range(-1, len(self._states)): - logger.warning(f'Invalid answer of state, {next_state=}, will be set to -1') - next_state = -1 - else: - next_state = int(next_state) - if next_state == -1: - logger.info(f"End actions with {next_state=}") - self._set_state(next_state) - + if not next_state.isdigit() or int(next_state) not in range(len(self._states)): + logger.warning(f'Invalid answer of state, {next_state=}') + next_state = "0" + self._set_state(int(next_state)) + async def _act(self) -> Message: - # prompt = self.get_prefix() - # prompt += ROLE_TEMPLATE.format(name=self.profile, state=self.states[self.state], result=response, - # history=self.history) - logger.info(f"{self._setting}: ready to {self._rc.todo}") response = await self._rc.todo.run(self._rc.important_memory) # logger.info(response) if isinstance(response, ActionOutput): msg = Message(content=response.content, instruct_content=response.instruct_content, - role=self.profile, cause_by=type(self._rc.todo)) + role=self.profile, cause_by=type(self._rc.todo)) else: msg = Message(content=response, role=self.profile, cause_by=type(self._rc.todo)) self._rc.memory.add(msg) # logger.debug(f"{response}") - + return msg - + async def _observe(self) -> int: """Observe from the environment, obtain important information, and add it to memory""" if not self._rc.env: return 0 env_msgs = self._rc.env.memory.get() - + observed = self._rc.env.memory.get_by_actions(self._rc.watch) - self._rc.news = self._rc.memory.find_news(observed) # find news (previously unseen messages) from observed messages - + self._rc.news = self._rc.memory.find_news( + observed) # find news (previously unseen messages) from observed messages + for i in env_msgs: self.recv(i) - + news_text = [f"{i.role}: {i.content[:20]}..." for i in self._rc.news] if news_text: logger.debug(f'{self._setting} observed: {news_text}') return len(self._rc.news) - + def _publish_message(self, msg): """If the role belongs to env, then the role's messages will be broadcast to env""" if not self._rc.env: # If env does not exist, do not publish the message return self._rc.env.publish_message(msg) - + async def _react(self) -> Message: - """Think first, then act, until the Role _think it is time to stop and requires no more todo. - This is the standard think-act loop in the ReAct paper, which alternates thinking and acting in task solving, i.e. _think -> _act -> _think -> _act -> ... - Use llm to select actions in _think dynamically - """ - actions_taken = 0 - rsp = Message("No actions taken yet") # will be overwritten after Role _act - while actions_taken < self._rc.max_react_loop: - # think - await self._think() - if self._rc.todo is None: - break - # act - logger.debug(f"{self._setting}: {self._rc.state=}, will do {self._rc.todo}") - rsp = await self._act() - actions_taken += 1 - return rsp # return output from the last action - - async def _act_by_order(self) -> Message: - """switch action each time by order defined in _init_actions, i.e. _act (Action1) -> _act (Action2) -> ...""" - for i in range(len(self._states)): - self._set_state(i) - rsp = await self._act() - return rsp # return output from the last action - - async def _plan_and_act(self) -> Message: - """first plan, then execute an action sequence, i.e. _think (of a plan) -> _act -> _act -> ... Use llm to come up with the plan dynamically.""" - # TODO: to be implemented - return Message("") - - async def react(self) -> Message: - """Entry to one of three strategies by which Role reacts to the observed Message""" - if self._rc.react_mode == RoleReactMode.REACT: - rsp = await self._react() - elif self._rc.react_mode == RoleReactMode.BY_ORDER: - rsp = await self._act_by_order() - elif self._rc.react_mode == RoleReactMode.PLAN_AND_ACT: - rsp = await self._plan_and_act() - self._set_state(state=-1) # current reaction is complete, reset state to -1 and todo back to None - return rsp - + """Think first, then act""" + await self._think() + logger.debug(f"{self._setting}: {self._rc.state=}, will do {self._rc.todo}") + return await self._act() + def recv(self, message: Message) -> None: """add message to history.""" # self._history += f"\n{message}" @@ -295,18 +268,14 @@ class Role: if message in self._rc.memory.get(): return self._rc.memory.add(message) - + async def handle(self, message: Message) -> Message: """Receive information and reply with actions""" # logger.debug(f"{self.name=}, {self.profile=}, {message.role=}") self.recv(message) - + return await self._react() - - def get_memories(self, k=0) -> list[Message]: - """A wrapper to return the most recent k memories of this role, return all when k=0""" - return self._rc.memory.get(k=k) - + async def run(self, message=None): """Observe, and think and act based on the results of the observation""" if message: @@ -320,8 +289,8 @@ class Role: # If there is no new information, suspend and wait logger.debug(f"{self._setting}: no news. waiting.") return - - rsp = await self.react() + + rsp = await self._react() # Publish the reply to the environment, waiting for the next subscriber to process self._publish_message(rsp) return rsp From 0dd63e4b2363d30d6c7e5db1705e749f00c9f82f Mon Sep 17 00:00:00 2001 From: stellahsr Date: Mon, 27 Nov 2023 21:13:19 +0800 Subject: [PATCH 002/167] update test cases for serialize_deserialize --- .../metagpt/serialize_deserialize/__init__.py | 4 ++ .../serialize_deserialize/test_actions.py | 24 ++++++++++ .../test_architect_deserialize.py | 26 ++++++++++ .../test_product_manager.py | 21 +++++++++ .../test_project_manager.py | 26 ++++++++++ .../serialize_deserialize/test_role.py | 41 ++++++++++++++++ .../serialize_deserialize/test_team.py | 47 +++++++++++++++++++ .../serialize_deserialize/test_wrire_prd.py | 28 +++++++++++ .../serialize_deserialize/test_write_code.py | 42 +++++++++++++++++ .../test_write_design.py | 39 +++++++++++++++ 10 files changed, 298 insertions(+) create mode 100644 tests/metagpt/serialize_deserialize/__init__.py create mode 100644 tests/metagpt/serialize_deserialize/test_actions.py create mode 100644 tests/metagpt/serialize_deserialize/test_architect_deserialize.py create mode 100644 tests/metagpt/serialize_deserialize/test_product_manager.py create mode 100644 tests/metagpt/serialize_deserialize/test_project_manager.py create mode 100644 tests/metagpt/serialize_deserialize/test_role.py create mode 100644 tests/metagpt/serialize_deserialize/test_team.py create mode 100644 tests/metagpt/serialize_deserialize/test_wrire_prd.py create mode 100644 tests/metagpt/serialize_deserialize/test_write_code.py create mode 100644 tests/metagpt/serialize_deserialize/test_write_design.py diff --git a/tests/metagpt/serialize_deserialize/__init__.py b/tests/metagpt/serialize_deserialize/__init__.py new file mode 100644 index 000000000..78f454fb5 --- /dev/null +++ b/tests/metagpt/serialize_deserialize/__init__.py @@ -0,0 +1,4 @@ +# -*- coding: utf-8 -*- +# @Date : 11/22/2023 11:48 AM +# @Author : stellahong (stellahong@fuzhi.ai) +# @Desc : diff --git a/tests/metagpt/serialize_deserialize/test_actions.py b/tests/metagpt/serialize_deserialize/test_actions.py new file mode 100644 index 000000000..e2efa982b --- /dev/null +++ b/tests/metagpt/serialize_deserialize/test_actions.py @@ -0,0 +1,24 @@ +# -*- coding: utf-8 -*- +# @Date : 11/22/2023 11:48 AM +# @Author : stellahong (stellahong@fuzhi.ai) +# @Desc : +import pytest + +from metagpt.actions import Action +from metagpt.llm import LLM + +def test_action_serialize(): + action = Action() + ser_action_dict = action.dict() + assert "name" in ser_action_dict + assert "llm" in ser_action_dict + +@pytest.mark.asyncio +async def test_action_deserialize(): + action = Action() + serialized_data = action.dict() + + new_action = Action(**serialized_data) + assert new_action.name == "" + assert new_action.llm == LLM() + assert len(await new_action._aask("who are you")) > 0 diff --git a/tests/metagpt/serialize_deserialize/test_architect_deserialize.py b/tests/metagpt/serialize_deserialize/test_architect_deserialize.py new file mode 100644 index 000000000..cff1bbadd --- /dev/null +++ b/tests/metagpt/serialize_deserialize/test_architect_deserialize.py @@ -0,0 +1,26 @@ +# -*- coding: utf-8 -*- +# @Date : 11/26/2023 2:04 PM +# @Author : stellahong (stellahong@fuzhi.ai) +# @Desc : +import pytest + +from metagpt.roles.architect import Architect +from metagpt.actions.action import Action + +def test_architect_serialize(): + role = Architect() + ser_role_dict = role.dict(by_alias=True) + assert "name" in ser_role_dict + assert "_states" in ser_role_dict + assert "_actions" in ser_role_dict + +@pytest.mark.asyncio +async def test_architect_deserialize(): + role = Architect() + ser_role_dict = role.dict(by_alias=True) + new_role = Architect(**ser_role_dict) + # new_role = Architect.deserialize(ser_role_dict) + assert new_role.name == "Bob" + assert len(new_role._actions) == 1 + assert isinstance(new_role._actions[0], Action) + await new_role._actions[0].run(context="write a cli snake game") \ No newline at end of file diff --git a/tests/metagpt/serialize_deserialize/test_product_manager.py b/tests/metagpt/serialize_deserialize/test_product_manager.py new file mode 100644 index 000000000..978c50e5e --- /dev/null +++ b/tests/metagpt/serialize_deserialize/test_product_manager.py @@ -0,0 +1,21 @@ +# -*- coding: utf-8 -*- +# @Date : 11/26/2023 2:07 PM +# @Author : stellahong (stellahong@fuzhi.ai) +# @Desc : +import pytest + +from metagpt.roles.product_manager import ProductManager +from metagpt.actions.action import Action +from metagpt.schema import Message + +@pytest.mark.asyncio +async def test_product_manager_deserialize(): + role = ProductManager() + ser_role_dict = role.dict(by_alias=True) + new_role = ProductManager(**ser_role_dict) + # new_role = ProductManager().deserialize(ser_role_dict) + + assert new_role.name == "Alice" + assert len(new_role._actions) == 1 + assert isinstance(new_role._actions[0], Action) + await new_role._actions[0].run([Message(content="write a cli snake game")]) \ No newline at end of file diff --git a/tests/metagpt/serialize_deserialize/test_project_manager.py b/tests/metagpt/serialize_deserialize/test_project_manager.py new file mode 100644 index 000000000..590bd8109 --- /dev/null +++ b/tests/metagpt/serialize_deserialize/test_project_manager.py @@ -0,0 +1,26 @@ +# -*- coding: utf-8 -*- +# @Date : 11/26/2023 2:06 PM +# @Author : stellahong (stellahong@fuzhi.ai) +# @Desc : +import pytest + +from metagpt.roles.project_manager import ProjectManager +from metagpt.actions.action import Action + +def test_project_manager_serialize(): + role = ProjectManager() + ser_role_dict = role.dict(by_alias=True) + assert "name" in ser_role_dict + assert "_states" in ser_role_dict + assert "_actions" in ser_role_dict + +@pytest.mark.asyncio +async def test_project_manager_deserialize(): + role = ProjectManager() + ser_role_dict = role.dict(by_alias=True) + new_role = ProjectManager(**ser_role_dict) + # new_role = ProjectManager().deserialize(ser_role_dict) + assert new_role.name == "Eve" + assert len(new_role._actions) == 1 + assert isinstance(new_role._actions[0], Action) + await new_role._actions[0].run(context="write a cli snake game") \ No newline at end of file diff --git a/tests/metagpt/serialize_deserialize/test_role.py b/tests/metagpt/serialize_deserialize/test_role.py new file mode 100644 index 000000000..432c9acb7 --- /dev/null +++ b/tests/metagpt/serialize_deserialize/test_role.py @@ -0,0 +1,41 @@ +# -*- coding: utf-8 -*- +# @Date : 11/23/2023 4:49 PM +# @Author : stellahong (stellahong@fuzhi.ai) +# @Desc : +import pytest + +from metagpt.roles.role import Role +from metagpt.roles.engineer import Engineer + +from metagpt.actions.action import Action + + +def test_role_serialize(): + role = Role() + ser_role_dict = role.dict(by_alias=True) + assert "name" in ser_role_dict + assert "_states" in ser_role_dict + assert "_actions" in ser_role_dict + + +def test_engineer_serialize(): + role = Engineer() + ser_role_dict = role.dict(by_alias=True) + assert "name" in ser_role_dict + assert "_states" in ser_role_dict + assert "_actions" in ser_role_dict + + +@pytest.mark.asyncio +async def test_engineer_deserialize(): + role = Engineer(use_code_review=True) + ser_role_dict = role.dict(by_alias=True) + # new_role = Engineer().deserialize(ser_role_dict) + # also can be deserialized in this way: + new_role = Engineer(**ser_role_dict) + assert new_role.name == "Alex" + assert new_role.use_code_review == True + assert len(new_role._actions) == 2 + assert isinstance(new_role._actions[0], Action) + assert isinstance(new_role._actions[1], Action) + await new_role._actions[0].run(context="write a cli snake game", filename="test_code") diff --git a/tests/metagpt/serialize_deserialize/test_team.py b/tests/metagpt/serialize_deserialize/test_team.py new file mode 100644 index 000000000..44a75d262 --- /dev/null +++ b/tests/metagpt/serialize_deserialize/test_team.py @@ -0,0 +1,47 @@ +# -*- coding: utf-8 -*- +# @Date : 11/27/2023 10:07 AM +# @Author : stellahong (stellahong@fuzhi.ai) +# @Desc : +import pytest + +from metagpt.environment import Environment +from metagpt.schema import Message +from metagpt.software_company import SoftwareCompany +from metagpt.roles import ProjectManager, ProductManager, Architect + + +def test_env_serialize(): + env = Environment() + ser_env_dict = env.dict() + assert "roles" in ser_env_dict + assert "memory" in ser_env_dict + assert "memory" in ser_env_dict + + +def test_env_deserialize(): + env = Environment() + env.publish_message(message=Message(content="test env serialize")) + ser_env_dict = env.dict() + new_env = Environment(**ser_env_dict) + assert len(new_env.roles) == 0 + assert new_env.memory.storage[0].content == "test env serialize" + assert len(new_env.history) == 25 + + +def test_softwarecompany_deserialize(): + team = SoftwareCompany() + team.hire( + [ + ProductManager(), + Architect(), + ProjectManager(), + ] + ) + assert len(team.environment.get_roles()) == 3 + ser_team_dict = team.dict() + new_team = SoftwareCompany(**ser_team_dict) + + assert len(new_team.environment.get_roles()) == 3 + assert new_team.environment.get_role('Product Manager') is not None + assert new_team.environment.get_role('Product Manager') is not None + assert new_team.environment.get_role('Architect') is not None diff --git a/tests/metagpt/serialize_deserialize/test_wrire_prd.py b/tests/metagpt/serialize_deserialize/test_wrire_prd.py new file mode 100644 index 000000000..9b2653820 --- /dev/null +++ b/tests/metagpt/serialize_deserialize/test_wrire_prd.py @@ -0,0 +1,28 @@ +# -*- coding: utf-8 -*- +# @Date : 11/22/2023 1:47 PM +# @Author : stellahong (stellahong@fuzhi.ai) +# @Desc : +import pytest + +from metagpt.actions import WritePRD +from metagpt.llm import LLM +from metagpt.schema import Message + + +def test_action_serialize(): + action = WritePRD() + ser_action_dict = action.dict() + assert "name" in ser_action_dict + assert "llm" in ser_action_dict + + +@pytest.mark.asyncio +async def test_action_deserialize(): + action = WritePRD() + serialized_data = action.dict() + new_action = WritePRD(**serialized_data) + # new_action = WritePRD().deserialize(serialized_data) + assert new_action.name == "" + assert new_action.llm == LLM() + assert len(await new_action.run([Message(content="write a cli snake game")]))>0 + diff --git a/tests/metagpt/serialize_deserialize/test_write_code.py b/tests/metagpt/serialize_deserialize/test_write_code.py new file mode 100644 index 000000000..0b1f1dc7c --- /dev/null +++ b/tests/metagpt/serialize_deserialize/test_write_code.py @@ -0,0 +1,42 @@ +# -*- coding: utf-8 -*- +# @Date : 11/23/2023 10:56 AM +# @Author : stellahong (stellahong@fuzhi.ai) +# @Desc : +import pytest + +from metagpt.actions import WriteCode, WriteCodeReview +from metagpt.llm import LLM + +def test_write_design_serialize(): + action = WriteCode() + ser_action_dict = action.dict() + assert ser_action_dict["name"] == "WriteCode" + assert "llm" in ser_action_dict + +def test_write_task_serialize(): + action = WriteCodeReview() + ser_action_dict = action.dict() + assert ser_action_dict["name"] == "WriteCodeReview" + assert "llm" in ser_action_dict + +@pytest.mark.asyncio +async def test_write_code_deserialize(): + action = WriteCode() + serialized_data = action.dict() + new_action = WriteCode(**serialized_data) + # new_action = WriteCode().deserialize(serialized_data) + assert new_action.name == "WriteCode" + assert new_action.llm == LLM() + await new_action.run(context="write a cli snake game", filename="test_code") + +@pytest.mark.asyncio +async def test_write_code_review_deserialize(): + action = WriteCodeReview() + serialized_data = action.dict() + new_action = WriteCodeReview(**serialized_data) + # new_action = WriteCodeReview().deserialize(serialized_data) + code = await WriteCode().run(context="write a cli snake game", filename="test_code") + + assert new_action.name == "WriteCodeReview" + assert new_action.llm == LLM() + await new_action.run(context="write a cli snake game", code =code, filename="test_rewrite_code") \ No newline at end of file diff --git a/tests/metagpt/serialize_deserialize/test_write_design.py b/tests/metagpt/serialize_deserialize/test_write_design.py new file mode 100644 index 000000000..56bf78a63 --- /dev/null +++ b/tests/metagpt/serialize_deserialize/test_write_design.py @@ -0,0 +1,39 @@ +# -*- coding: utf-8 -*- +# @Date : 11/22/2023 8:19 PM +# @Author : stellahong (stellahong@fuzhi.ai) +# @Desc : +import pytest + +from metagpt.actions import WriteDesign, WriteTasks +from metagpt.llm import LLM + +def test_write_design_serialize(): + action = WriteDesign() + ser_action_dict = action.dict() + assert "name" in ser_action_dict + assert "llm" in ser_action_dict + +def test_write_task_serialize(): + action = WriteTasks() + ser_action_dict = action.dict() + assert "name" in ser_action_dict + assert "llm" in ser_action_dict + +@pytest.mark.asyncio +async def test_write_design_deserialize(): + action = WriteDesign() + serialized_data = action.dict() + new_action = WriteDesign().deserialize(serialized_data) + assert new_action.name == "" + assert new_action.llm == LLM() + await new_action.run(context="write a cli snake game") + +@pytest.mark.asyncio +async def test_write_task_deserialize(): + action = WriteTasks() + serialized_data = action.dict() + new_action = WriteTasks(**serialized_data) + # new_action = WriteTasks().deserialize(serialized_data) + assert new_action.name == "CreateTasks" + assert new_action.llm == LLM() + await new_action.run(context="write a cli snake game") \ No newline at end of file From d99b4c62e33d1c37cb832c04030697d37a90be66 Mon Sep 17 00:00:00 2001 From: better629 Date: Tue, 28 Nov 2023 09:29:00 +0800 Subject: [PATCH 003/167] add mg ser&deser --- metagpt/actions/action.py | 32 ++++++++ metagpt/const.py | 1 + metagpt/environment.py | 38 +++++++++ metagpt/memory/memory.py | 30 +++++++ metagpt/roles/role.py | 115 ++++++++++++++++++++++++++- metagpt/schema.py | 44 ++++++++++ metagpt/team.py | 26 ++++++ metagpt/utils/serialize.py | 62 +++++++++++++-- metagpt/utils/utils.py | 41 ++++++++++ startup.py | 41 ++++++---- tests/metagpt/actions/test_action.py | 17 ++++ tests/metagpt/memory/test_memory.py | 42 ++++++++++ tests/metagpt/roles/test_role.py | 85 ++++++++++++++++++++ tests/metagpt/test_environment.py | 29 +++++-- tests/metagpt/test_schema.py | 42 ++++++++++ tests/metagpt/test_team.py | 27 +++++++ 16 files changed, 641 insertions(+), 31 deletions(-) create mode 100644 metagpt/utils/utils.py create mode 100644 tests/metagpt/memory/test_memory.py create mode 100644 tests/metagpt/roles/test_role.py create mode 100644 tests/metagpt/test_team.py diff --git a/metagpt/actions/action.py b/metagpt/actions/action.py index 790295d55..a538baa77 100644 --- a/metagpt/actions/action.py +++ b/metagpt/actions/action.py @@ -9,6 +9,7 @@ import re from abc import ABC from typing import Optional +import importlib from tenacity import retry, stop_after_attempt, wait_fixed from metagpt.actions.action_output import ActionOutput @@ -16,6 +17,7 @@ from metagpt.llm import LLM from metagpt.logs import logger from metagpt.utils.common import OutputParser from metagpt.utils.custom_decoder import CustomDecoder +from metagpt.utils.utils import import_class class Action(ABC): @@ -42,6 +44,36 @@ class Action(ABC): def __repr__(self): return self.__str__() + def serialize(self): + return { + "action_class": self.__class__.__name__, + "module_name": self.__module__, + "name": self.name + } + + @classmethod + def deserialize(cls, action_dict: dict): + action_class_str = action_dict.pop("action_class") + module_name = action_dict.pop("module_name") + action_class = import_class(action_class_str, module_name) + return action_class(**action_dict) + + @classmethod + def ser_class(cls): + """ serialize class type""" + return { + "action_class": cls.__name__, + "module_name": cls.__module__ + } + + @classmethod + def deser_class(cls, action_dict: dict): + """ deserialize class type """ + action_class_str = action_dict.pop("action_class") + module_name = action_dict.pop("module_name") + action_class = import_class(action_class_str, module_name) + return action_class + async def _aask(self, prompt: str, system_msgs: Optional[list[str]] = None) -> str: """Append default prefix""" if not system_msgs: diff --git a/metagpt/const.py b/metagpt/const.py index 407ce803a..711546d03 100644 --- a/metagpt/const.py +++ b/metagpt/const.py @@ -42,6 +42,7 @@ TMP = PROJECT_ROOT / "tmp" RESEARCH_PATH = DATA_PATH / "research" TUTORIAL_PATH = DATA_PATH / "tutorial_docx" INVOICE_OCR_TABLE_PATH = DATA_PATH / "invoice_table" +SERDES_PATH = WORKSPACE_ROOT / "storage" # TODO to store `storage` under the individual generated project SKILL_DIRECTORY = PROJECT_ROOT / "metagpt/skills" diff --git a/metagpt/environment.py b/metagpt/environment.py index 24e6ada2f..d1fa561f0 100644 --- a/metagpt/environment.py +++ b/metagpt/environment.py @@ -7,12 +7,14 @@ """ import asyncio from typing import Iterable +from pathlib import Path from pydantic import BaseModel, Field from metagpt.memory import Memory from metagpt.roles import Role from metagpt.schema import Message +from metagpt.utils.utils import read_json_file, write_json_file class Environment(BaseModel): @@ -28,6 +30,42 @@ class Environment(BaseModel): class Config: arbitrary_types_allowed = True + def serialize(self, stg_path: Path): + roles_path = stg_path.joinpath("roles.json") + roles_info = [] + for role_key, role in self.roles.items(): + roles_info.append({ + "role_class": role.__class__.__name__, + "module_name": role.__module__, + "role_name": role.name + }) + role.serialize(stg_path=stg_path.joinpath(f"roles/{role.__class__.__name__}_{role.name}")) + write_json_file(roles_path, roles_info) + + self.memory.serialize(stg_path) + history_path = stg_path.joinpath("history.json") + write_json_file(history_path, {"content": self.history}) + + def deserialize(self, stg_path: Path): + """ stg_path: ./storage/team/environment/ """ + roles_path = stg_path.joinpath("roles.json") + roles_info = read_json_file(roles_path) + for role_info in roles_info: + role_class = role_info.get("role_class") + role_name = role_info.get("role_name") + + role_path = stg_path.joinpath(f"roles/{role_class}_{role_name}") + role = Role.deserialize(role_path) + + self.add_role(role) + + memory = Memory.deserialize(stg_path) + self.memory = memory + + history_path = stg_path.joinpath("history.json") + history = read_json_file(history_path) + self.history = history.get("content") + def add_role(self, role: Role): """增加一个在当前环境的角色 Add a role in the current environment diff --git a/metagpt/memory/memory.py b/metagpt/memory/memory.py index c818fa707..a839bb038 100644 --- a/metagpt/memory/memory.py +++ b/metagpt/memory/memory.py @@ -7,9 +7,12 @@ """ from collections import defaultdict from typing import Iterable, Type +from pathlib import Path from metagpt.actions import Action from metagpt.schema import Message +from metagpt.utils.utils import read_json_file, write_json_file +from metagpt.utils.serialize import serialize_general_message, deserialize_general_message class Memory: @@ -20,6 +23,33 @@ class Memory: self.storage: list[Message] = [] self.index: dict[Type[Action], list[Message]] = defaultdict(list) + def serialize(self, stg_path: Path): + """ stg_path = ./storage/team/environment/ or ./storage/team/environment/roles/{role_class}_{role_name}/ """ + memory_path = stg_path.joinpath("memory.json") + + storage = [] + for message in self.storage: + # msg_dict = message.serialize() + msg_dict = serialize_general_message(message) + storage.append(msg_dict) + + write_json_file(memory_path, storage) + + @classmethod + def deserialize(cls, stg_path: Path) -> "Memory": + """ stg_path = ./storage/team/environment/ or ./storage/team/environment/roles/{role_class}_{role_name}/""" + memory_path = stg_path.joinpath("memory.json") + + memory = Memory() + memory_list = read_json_file(memory_path) + for message in memory_list: + # distinguish instruct_content type in message + # msg = Message.deserialize(message) + msg = deserialize_general_message(message) + memory.add(msg) + + return memory + def add(self, message: Message): """Add a new message to storage, while updating the index""" if message in self.storage: diff --git a/metagpt/roles/role.py b/metagpt/roles/role.py index b96c361c0..9b0613fd5 100644 --- a/metagpt/roles/role.py +++ b/metagpt/roles/role.py @@ -9,8 +9,9 @@ from __future__ import annotations from typing import Iterable, Type, Union from enum import Enum - +from pathlib import Path from pydantic import BaseModel, Field +import importlib # from metagpt.environment import Environment from metagpt.config import CONFIG @@ -19,6 +20,7 @@ from metagpt.llm import LLM, HumanProvider from metagpt.logs import logger from metagpt.memory import Memory, LongTermMemory from metagpt.schema import Message +from metagpt.utils.utils import read_json_file, write_json_file, import_class PREFIX_TEMPLATE = """You are a {profile}, named {name}, your goal is {goal}, and the constraint is {constraints}. """ @@ -115,11 +117,101 @@ class Role: self._actions = [] self._role_id = str(self._setting) self._rc = RoleContext() + self._recovered = False + + def serialize(self, stg_path: Path): + role_info_path = stg_path.joinpath("role_info.json") + role_info = { + "role_class": self.__class__.__name__, + "module_name": self.__module__ + } + setting = self._setting.dict() + setting.pop("desc") + setting.pop("is_human") # not all inherited roles have this atrr + role_info.update(setting) + write_json_file(role_info_path, role_info) + + actions_info_path = stg_path.joinpath("actions/actions_info.json") + actions_info = [] + for action in self._actions: + actions_info.append(action.serialize()) + write_json_file(actions_info_path, actions_info) + + watches_info_path = stg_path.joinpath("watches/watches_info.json") + watches_info = [] + for watch in self._rc.watch: + watches_info.append(watch.ser_class()) + write_json_file(watches_info_path, watches_info) + + actions_todo_path = stg_path.joinpath("actions/todo.json") + actions_todo = { + "cur_state": self._rc.state, + "react_mode": self._rc.react_mode.value, + "max_react_loop": self._rc.max_react_loop + } + write_json_file(actions_todo_path, actions_todo) + + self._rc.memory.serialize(stg_path) + + @classmethod + def deserialize(cls, stg_path: Path) -> "Role": + """ stg_path = ./storage/team/environment/roles/{role_class}_{role_name}""" + role_info_path = stg_path.joinpath("role_info.json") + role_info = read_json_file(role_info_path) + + role_class_str = role_info.pop("role_class") + module_name = role_info.pop("module_name") + role_class = import_class(class_name=role_class_str, module_name=module_name) + + role = role_class(**role_info) # initiate particular Role + actions_info_path = stg_path.joinpath("actions/actions_info.json") + actions = [] + actions_info = read_json_file(actions_info_path) + for action_info in actions_info: + action = Action.deserialize(action_info) + actions.append(action) + + watches_info_path = stg_path.joinpath("watches/watches_info.json") + watches = [] + watches_info = read_json_file(watches_info_path) + for watch_info in watches_info: + action = Action.deser_class(watch_info) + watches.append(action) + + role.init_actions(actions) + role.watch(watches) + + actions_todo_path = stg_path.joinpath("actions/todo.json") + # recover self._rc.state + actions_todo = read_json_file(actions_todo_path) + max_react_loop = actions_todo.get("max_react_loop", 1) + cur_state = actions_todo.get("cur_state", -1) + role.set_state(cur_state) + role.set_recovered(True) + react_mode_str = actions_todo.get("react_mode", RoleReactMode.REACT.value) + if react_mode_str not in RoleReactMode.values(): + logger.warning(f"ReactMode: {react_mode_str} not in {RoleReactMode.values()}, use react as default") + react_mode_str = RoleReactMode.REACT.value + role.set_react_mode(RoleReactMode(react_mode_str), max_react_loop) + + role_memory = Memory.deserialize(stg_path) + role.set_memory(role_memory) + + return role def _reset(self): self._states = [] self._actions = [] + def set_recovered(self, recovered: bool = False): + self._recovered = recovered + + def set_memory(self, memory: Memory): + self._rc.memory = memory + + def init_actions(self, actions): + self._init_actions(actions) + def _init_actions(self, actions): self._reset() for idx, action in enumerate(actions): @@ -134,6 +226,9 @@ class Role: self._actions.append(i) self._states.append(f"{idx}. {action}") + def set_react_mode(self, react_mode: RoleReactMode, max_react_loop: int = 1): + self._set_react_mode(react_mode, max_react_loop) + def _set_react_mode(self, react_mode: str, max_react_loop: int = 1): """Set strategy of the Role reacting to observed Message. Variation lies in how this Role elects action to perform during the _think stage, especially if it is capable of multiple Actions. @@ -155,12 +250,18 @@ class Role: if react_mode == RoleReactMode.REACT: self._rc.max_react_loop = max_react_loop + def watch(self, actions: Iterable[Type[Action]]): + self._watch(actions) + def _watch(self, actions: Iterable[Type[Action]]): """Listen to the corresponding behaviors""" self._rc.watch.update(actions) # check RoleContext after adding watch actions self._rc.check(self._role_id) + def set_state(self, state: int): + self._set_state(state) + def _set_state(self, state: int): """Update the current state.""" self._rc.state = state @@ -171,6 +272,10 @@ class Role: """Set the environment in which the role works. The role can talk to the environment and can also receive messages by observing.""" self._rc.env = env + @property + def name(self): + return self._setting.name + @property def profile(self): """Get the role description (position)""" @@ -188,6 +293,11 @@ class Role: # If there is only one action, then only this one can be performed self._set_state(0) return + if self._recovered and self._rc.state >= 0: + self._set_state(self._rc.state) # action to run from recovered state + self._recovered = False # avoid max_react_loop out of work + return + prompt = self._get_prefix() prompt += STATE_TEMPLATE.format(history=self._rc.history, states="\n".join(self._states), n_states=len(self._states) - 1, previous_state=self._rc.state) @@ -267,7 +377,8 @@ class Role: async def _act_by_order(self) -> Message: """switch action each time by order defined in _init_actions, i.e. _act (Action1) -> _act (Action2) -> ...""" - for i in range(len(self._states)): + start_idx = self._rc.state if self._rc.state >= 0 else 0 # action to run from recovered state + for i in range(start_idx, len(self._states)): self._set_state(i) rsp = await self._act() return rsp # return output from the last action diff --git a/metagpt/schema.py b/metagpt/schema.py index bdca093c2..3374a7241 100644 --- a/metagpt/schema.py +++ b/metagpt/schema.py @@ -9,10 +9,14 @@ from __future__ import annotations from dataclasses import dataclass, field from typing import Type, TypedDict +import copy from pydantic import BaseModel from metagpt.logs import logger +# from metagpt.utils.serialize import actionoutout_schema_to_mapping +# from metagpt.actions.action_output import ActionOutput +# from metagpt.actions.action import Action class RawMessage(TypedDict): @@ -38,6 +42,46 @@ class Message: def __repr__(self): return self.__str__() + # def serialize(self): + # message_cp: Message = copy.deepcopy(self) + # ic = message_cp.instruct_content + # if ic: + # # model create by pydantic create_model like `pydantic.main.prd`, can't pickle.dump directly + # schema = ic.schema() + # mapping = actionoutout_schema_to_mapping(schema) + # + # message_cp.instruct_content = {"class": schema["title"], "mapping": mapping, "value": ic.dict()} + # cb = message_cp.cause_by + # if cb: + # message_cp.cause_by = cb.serialize() + # + # return message_cp.dict() + # + # @classmethod + # def deserialize(cls, message_dict: dict): + # instruct_content = message_dict.get("instruct_content") + # if instruct_content: + # ic = instruct_content + # ic_obj = ActionOutput.create_model_class(class_name=ic["class"], mapping=ic["mapping"]) + # ic_new = ic_obj(**ic["value"]) + # message_dict.instruct_content = ic_new + # cause_by = message_dict.get("cause_by") + # if cause_by: + # message_dict.cause_by = Action.deserialize(cause_by) + # + # return Message(**message_dict) + + def dict(self): + return { + "content": self.content, + "instruct_content": self.instruct_content, + "role": self.role, + "cause_by": self.cause_by, + "sent_from": self.sent_from, + "send_to": self.send_to, + "restricted_to": self.restricted_to + } + def to_dict(self) -> dict: return { "role": self.role, diff --git a/metagpt/team.py b/metagpt/team.py index 67d3ecec8..3b76e5ff4 100644 --- a/metagpt/team.py +++ b/metagpt/team.py @@ -5,6 +5,7 @@ @Author : alexanderwu @File : software_company.py """ +from pathlib import Path from pydantic import BaseModel, Field from metagpt.actions import BossRequirement @@ -14,6 +15,7 @@ from metagpt.logs import logger from metagpt.roles import Role from metagpt.schema import Message from metagpt.utils.common import NoMoneyException +from metagpt.utils.utils import read_json_file, write_json_file class Team(BaseModel): @@ -28,6 +30,30 @@ class Team(BaseModel): class Config: arbitrary_types_allowed = True + def serialize(self, stg_path: Path): + team_info_path = stg_path.joinpath("team_info.json") + write_json_file(team_info_path, { + "idea": self.idea, + "investment": self.investment + }) + + self.environment.serialize(stg_path.joinpath("environment")) + + def deserialize(self, stg_path: Path): + """ stg_path = ./storage/team """ + # recover team_info + team_info_path = stg_path.joinpath("team_info.json") + if not team_info_path.exists(): + logger.error("recover storage not exist, not to recover and continue run the old project.") + team_info = read_json_file(team_info_path) + self.investment = team_info.get("investment", 10.0) + self.idea = team_info.get("idea", "") + + # recover environment + environment_path = stg_path.joinpath("environment") + self.environment = Environment() + self.environment.deserialize(stg_path=environment_path) + def hire(self, roles: list[Role]): """Hire roles to cooperate""" self.environment.add_roles(roles) diff --git a/metagpt/utils/serialize.py b/metagpt/utils/serialize.py index 124176fcb..56a866f2e 100644 --- a/metagpt/utils/serialize.py +++ b/metagpt/utils/serialize.py @@ -4,13 +4,13 @@ import copy import pickle -from typing import Dict, List from metagpt.actions.action_output import ActionOutput from metagpt.schema import Message +from metagpt.actions.action import Action -def actionoutout_schema_to_mapping(schema: Dict) -> Dict: +def actionoutout_schema_to_mapping(schema: dict) -> dict: """ directly traverse the `properties` in the first level. schema structure likes @@ -35,13 +35,47 @@ def actionoutout_schema_to_mapping(schema: Dict) -> Dict: if property["type"] == "string": mapping[field] = (str, ...) elif property["type"] == "array" and property["items"]["type"] == "string": - mapping[field] = (List[str], ...) + mapping[field] = (list[str], ...) elif property["type"] == "array" and property["items"]["type"] == "array": - # here only consider the `List[List[str]]` situation - mapping[field] = (List[List[str]], ...) + # here only consider the `list[list[str]]` situation + mapping[field] = (list[list[str]], ...) return mapping +def actionoutput_mapping_to_str(mapping: dict) -> dict: + new_mapping = {} + for key, value in mapping.items(): + new_mapping[key] = str(value) + return new_mapping + + +def actionoutput_str_to_mapping(mapping: dict) -> dict: + new_mapping = {} + for key, value in mapping.items(): + if value == "(, Ellipsis)": + new_mapping[key] = (str, ...) + else: + new_mapping[key] = eval(value) # `"'(list[str], Ellipsis)"` to `(list[str], ...)` + return new_mapping + + +def serialize_general_message(message: Message) -> dict: + """ serialize Message, not to save""" + message_cp = copy.deepcopy(message) + ic = message_cp.instruct_content + if ic: + # model create by pydantic create_model like `pydantic.main.prd`, can't pickle.dump directly + schema = ic.schema() + mapping = actionoutout_schema_to_mapping(schema) + mapping = actionoutput_mapping_to_str(mapping) + + message_cp.instruct_content = {"class": schema["title"], "mapping": mapping, "value": ic.dict()} + cb = message_cp.cause_by + if cb: + message_cp.cause_by = cb.ser_class() + return message_cp.dict() + + def serialize_message(message: Message): message_cp = copy.deepcopy(message) # avoid `instruct_content` value update by reference ic = message_cp.instruct_content @@ -56,6 +90,24 @@ def serialize_message(message: Message): return msg_ser +def deserialize_general_message(message_dict: dict) -> Message: + """ deserialize Message, not to load""" + instruct_content = message_dict.pop("instruct_content") + cause_by = message_dict.pop("cause_by") + + message = Message(**message_dict) + if instruct_content: + ic = instruct_content + mapping = actionoutput_str_to_mapping(ic["mapping"]) + ic_obj = ActionOutput.create_model_class(class_name=ic["class"], mapping=mapping) + ic_new = ic_obj(**ic["value"]) + message.instruct_content = ic_new + if cause_by: + message.cause_by = Action.deser_class(cause_by) + + return message + + def deserialize_message(message_ser: str) -> Message: message = pickle.loads(message_ser) if message.instruct_content: diff --git a/metagpt/utils/utils.py b/metagpt/utils/utils.py new file mode 100644 index 000000000..81ceea884 --- /dev/null +++ b/metagpt/utils/utils.py @@ -0,0 +1,41 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : + +from typing import Any +import json +from pathlib import Path +import importlib + + +def read_json_file(json_file: str, encoding=None) -> list[Any]: + if not Path(json_file).exists(): + raise FileNotFoundError(f"json_file: {json_file} not exist, return []") + + with open(json_file, "r", encoding=encoding) as fin: + try: + data = json.load(fin) + except Exception as exp: + raise ValueError(f"read json file: {json_file} failed") + return data + + +def write_json_file(json_file: str, data: list, encoding=None): + folder_path = Path(json_file).parent + if not folder_path.exists(): + folder_path.mkdir(parents=True, exist_ok=True) + + with open(json_file, "w", encoding=encoding) as fout: + json.dump(data, fout, ensure_ascii=False, indent=4) + + +def import_class(class_name: str, module_name: str) -> type: + module = importlib.import_module(module_name) + a_class = getattr(module, class_name) + return a_class + + +def import_class_inst(class_name: str, module_name: str, *args, **kwargs) -> object: + a_class = import_class(class_name, module_name) + class_inst = a_class(*args, **kwargs) + return class_inst diff --git a/startup.py b/startup.py index e9fbf94d3..9f753d553 100644 --- a/startup.py +++ b/startup.py @@ -4,6 +4,7 @@ import asyncio import fire +from metagpt.const import SERDES_PATH from metagpt.roles import ( Architect, Engineer, @@ -21,26 +22,32 @@ async def startup( code_review: bool = False, run_tests: bool = False, implement: bool = True, + recover_path: bool = False, ): """Run a startup. Be a boss.""" company = Team() - company.hire( - [ - ProductManager(), - Architect(), - ProjectManager(), - ] - ) + if not recover_path: + company.hire( + [ + ProductManager(), + Architect(), + ProjectManager(), + ] + ) - # if implement or code_review - if implement or code_review: - # developing features: implement the idea - company.hire([Engineer(n_borg=5, use_code_review=code_review)]) + # if implement or code_review + if implement or code_review: + # developing features: implement the idea + company.hire([Engineer(n_borg=5, use_code_review=code_review)]) - if run_tests: - # developing features: run tests on the spot and identify bugs - # (bug fixing capability comes soon!) - company.hire([QaEngineer()]) + if run_tests: + # developing features: run tests on the spot and identify bugs + # (bug fixing capability comes soon!) + company.hire([QaEngineer()]) + else: + stg_path = SERDES_PATH.joinpath("team") + company.deserialize(stg_path=stg_path) + idea = company.idea # use original idea company.invest(investment) company.start_project(idea) @@ -54,6 +61,7 @@ def main( code_review: bool = True, run_tests: bool = False, implement: bool = True, + recover_path: str = None, ): """ We are a software startup comprised of AI. By investing in us, @@ -63,9 +71,10 @@ def main( a certain dollar amount to this AI company. :param n_round: :param code_review: Whether to use code review. + :param recover_path: recover the project from existing serialized storage :return: """ - asyncio.run(startup(idea, investment, n_round, code_review, run_tests, implement)) + asyncio.run(startup(idea, investment, n_round, code_review, run_tests, implement, recover_path)) if __name__ == "__main__": diff --git a/tests/metagpt/actions/test_action.py b/tests/metagpt/actions/test_action.py index 9775630cc..4468a6f6f 100644 --- a/tests/metagpt/actions/test_action.py +++ b/tests/metagpt/actions/test_action.py @@ -11,3 +11,20 @@ from metagpt.actions import Action, WritePRD, WriteTest def test_action_repr(): actions = [Action(), WriteTest(), WritePRD()] assert "WriteTest" in str(actions) + + +def test_action_serdes(): + action_info = WriteTest.ser_class() + assert action_info["action_class"] == "WriteTest" + + action_class = Action.deser_class(action_info) + assert action_class == WriteTest + + +def test_action_class_serdes(): + name = "write test" + action_info = WriteTest(name=name).serialize() + assert action_info["name"] == name + + action = Action.deserialize(action_info) + assert action.name == name diff --git a/tests/metagpt/memory/test_memory.py b/tests/metagpt/memory/test_memory.py new file mode 100644 index 000000000..bda79ded1 --- /dev/null +++ b/tests/metagpt/memory/test_memory.py @@ -0,0 +1,42 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : unittest of memory + +from pathlib import Path + +from metagpt.schema import Message +from metagpt.memory.memory import Memory +from metagpt.actions.action_output import ActionOutput +from metagpt.actions.design_api import WriteDesign +from metagpt.actions.add_requirement import BossRequirement + +serdes_path = Path(__file__).absolute().parent.joinpath("../../data/serdes_storage") + + +def test_memory_serdes(): + msg1 = Message(role="User", + content="write a 2048 game", + cause_by=BossRequirement) + + out_mapping = {"field1": (list[str], ...)} + out_data = {"field1": ["field1 value1", "field1 value2"]} + ic_obj = ActionOutput.create_model_class("system_design", out_mapping) + msg2 = Message(role="Architect", + instruct_content=ic_obj(**out_data), + content="system design content", + cause_by=WriteDesign) + + memory = Memory() + memory.add_batch([msg1, msg2]) + + stg_path = serdes_path.joinpath("team/environment") + memory.serialize(stg_path) + assert stg_path.joinpath("memory.json").exists() + + new_memory = Memory.deserialize(stg_path) + assert new_memory.count() == 2 + new_msg2 = new_memory.get(1)[0] + assert new_msg2.instruct_content.field1 == ["field1 value1", "field1 value2"] + assert new_msg2.cause_by == WriteDesign + + stg_path.joinpath("memory.json").unlink() diff --git a/tests/metagpt/roles/test_role.py b/tests/metagpt/roles/test_role.py new file mode 100644 index 000000000..a19ad9cb5 --- /dev/null +++ b/tests/metagpt/roles/test_role.py @@ -0,0 +1,85 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : unittest of Role + +from pathlib import Path +import shutil +import pytest + +from metagpt.roles.role import Role, RoleReactMode +from metagpt.actions.action import Action +from metagpt.schema import Message +from metagpt.actions.add_requirement import BossRequirement +from metagpt.roles.product_manager import ProductManager + +serdes_path = Path(__file__).absolute().parent.joinpath("../../data/serdes_storage") + + +def test_role_serdes(): + stg_path_prefix = serdes_path.joinpath("team/environment/roles/") + shutil.rmtree(serdes_path.joinpath("team"), ignore_errors=True) + + pm = ProductManager() + role_tag = f"{pm.__class__.__name__}_{pm.name}" + stg_path = stg_path_prefix.joinpath(role_tag) + pm.serialize(stg_path) + assert stg_path.joinpath("actions/actions_info.json").exists() + + new_pm = Role.deserialize(stg_path) + assert new_pm.name == pm.name + assert len(new_pm.get_memories(1)) == 0 + + +class ActionOK(Action): + + async def run(self, messages: list["Message"]): + return "ok" + + +class ActionRaise(Action): + + async def run(self, messages: list["Message"]): + raise RuntimeError("parse error") + + +class RoleA(Role): + + def __init__(self, + name: str = "RoleA", + profile: str = "Role A", + goal: str = "", + constraints: str = ""): + super(RoleA, self).__init__(name=name, profile=profile, goal=goal, constraints=constraints) + self._init_actions([ActionOK, ActionRaise]) + self._watch([BossRequirement]) + self._rc.react_mode = RoleReactMode.BY_ORDER + + async def run(self, message: "Message" = None, stg_path: str = None): + try: + await super(RoleA, self).run(message) + except Exception as exp: + print("exp ", exp) + self.serialize(stg_path) + + +@pytest.mark.asyncio +async def test_role_serdes_interrupt(): + role_a = RoleA() + shutil.rmtree(serdes_path.joinpath("team"), ignore_errors=True) + + stg_path = serdes_path.joinpath(f"team/environment/roles/{role_a.__class__.__name__}_{role_a.name}") + await role_a.run( + message=Message(content="demo", cause_by=BossRequirement), + stg_path=stg_path + ) + assert role_a._rc.memory.count() == 2 + + assert stg_path.joinpath("actions/todo.json").exists() + + new_role_a: Role = Role.deserialize(stg_path) + assert new_role_a._rc.state == 1 + await role_a.run( + message=Message(content="demo", cause_by=BossRequirement), + stg_path=stg_path + ) + diff --git a/tests/metagpt/test_environment.py b/tests/metagpt/test_environment.py index a0f1f6257..3cc2d8a7a 100644 --- a/tests/metagpt/test_environment.py +++ b/tests/metagpt/test_environment.py @@ -7,13 +7,18 @@ """ import pytest +from pathlib import Path +import shutil from metagpt.actions import BossRequirement from metagpt.environment import Environment from metagpt.logs import logger -from metagpt.manager import Manager from metagpt.roles import Architect, ProductManager, Role from metagpt.schema import Message +from tests.metagpt.roles.test_role import RoleA + + +serdes_path = Path(__file__).absolute().parent.joinpath("../data/serdes_storage") @pytest.fixture @@ -36,21 +41,29 @@ def test_get_roles(env: Environment): assert roles == {role1.profile: role1, role2.profile: role2} -def test_set_manager(env: Environment): - manager = Manager() - env.set_manager(manager) - assert env.manager == manager - - @pytest.mark.asyncio async def test_publish_and_process_message(env: Environment): product_manager = ProductManager("Alice", "Product Manager", "做AI Native产品", "资源有限") architect = Architect("Bob", "Architect", "设计一个可用、高效、较低成本的系统,包括数据结构与接口", "资源有限,需要节省成本") env.add_roles([product_manager, architect]) - env.set_manager(Manager()) env.publish_message(Message(role="BOSS", content="需要一个基于LLM做总结的搜索引擎", cause_by=BossRequirement)) await env.run(k=2) logger.info(f"{env.history=}") assert len(env.history) > 10 + + +def test_environment_serdes(): + environment = Environment() + role_a = RoleA() + + shutil.rmtree(serdes_path.joinpath("team"), ignore_errors=True) + + stg_path = serdes_path.joinpath("team/environment") + environment.add_role(role_a) + environment.serialize(stg_path) + + new_env: Environment = Environment() + new_env.deserialize(stg_path) + assert len(new_env.roles) == 1 diff --git a/tests/metagpt/test_schema.py b/tests/metagpt/test_schema.py index 12666e0d3..f515326e8 100644 --- a/tests/metagpt/test_schema.py +++ b/tests/metagpt/test_schema.py @@ -5,7 +5,11 @@ @Author : alexanderwu @File : test_schema.py """ + from metagpt.schema import AIMessage, Message, SystemMessage, UserMessage +from metagpt.actions.action_output import ActionOutput +from metagpt.actions.write_code import WriteCode +from metagpt.utils.serialize import serialize_general_message, deserialize_general_message def test_messages(): @@ -19,3 +23,41 @@ def test_messages(): text = str(msgs) roles = ['user', 'system', 'assistant', 'QA'] assert all([i in text for i in roles]) + + +def test_message_serdes(): + out_mapping = {"field3": (str, ...), "field4": (list[str], ...)} + out_data = {"field3": "field3 value3", "field4": ["field4 value1", "field4 value2"]} + ic_obj = ActionOutput.create_model_class("code", out_mapping) + + message = Message( + content="code", + instruct_content=ic_obj(**out_data), + role="engineer", + cause_by=WriteCode + ) + message_dict = serialize_general_message(message) + assert message_dict["cause_by"] == {"action_class": "WriteCode"} + assert message_dict["instruct_content"] == { + "class": "code", + "mapping": { + "field3": "(, Ellipsis)", + "field4": "(list[str], Ellipsis)" + }, + "value": { + "field3": "field3 value3", + "field4": ["field4 value1", "field4 value2"] + } + } + + new_message = deserialize_general_message(message_dict) + assert new_message.content == message.content + assert new_message.instruct_content == message.instruct_content + assert new_message.cause_by == message.cause_by + assert new_message.instruct_content.field3 == out_data["field3"] + + message = Message(content="code") + message_dict = serialize_general_message(message) + new_message = deserialize_general_message(message_dict) + assert new_message.instruct_content is None + assert new_message.cause_by == "" diff --git a/tests/metagpt/test_team.py b/tests/metagpt/test_team.py new file mode 100644 index 000000000..ab201152c --- /dev/null +++ b/tests/metagpt/test_team.py @@ -0,0 +1,27 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : unittest of team + +from pathlib import Path +import shutil + +from metagpt.team import Team + +from tests.metagpt.roles.test_role import RoleA + +serdes_path = Path(__file__).absolute().parent.joinpath("../data/serdes_storage") + + +def test_team_serdes(): + company = Team() + company.hire([RoleA()]) + + stg_path = serdes_path.joinpath("team") + shutil.rmtree(stg_path, ignore_errors=True) + + company.serialize(stg_path=stg_path) + + new_company = Team() + new_company.deserialize(stg_path) + + assert len(new_company.environment.roles) == 1 From 39e4aa98ab6101ee1016cc8584c4f36977498077 Mon Sep 17 00:00:00 2001 From: better629 Date: Tue, 28 Nov 2023 10:47:19 +0800 Subject: [PATCH 004/167] fix role and format ut of serialize_deserialize --- metagpt/roles/role.py | 29 +++++-------------- .../serialize_deserialize/test_actions.py | 2 ++ .../test_architect_deserialize.py | 2 ++ .../test_product_manager.py | 1 + .../test_project_manager.py | 2 ++ .../serialize_deserialize/test_role.py | 2 +- .../serialize_deserialize/test_wrire_prd.py | 4 +-- .../serialize_deserialize/test_write_code.py | 6 +++- .../test_write_design.py | 6 +++- 9 files changed, 27 insertions(+), 27 deletions(-) diff --git a/metagpt/roles/role.py b/metagpt/roles/role.py index eb5539f43..e9371c2c0 100644 --- a/metagpt/roles/role.py +++ b/metagpt/roles/role.py @@ -6,16 +6,11 @@ @File : role.py """ -import sys from enum import Enum -import importlib +from pathlib import Path from __future__ import annotations -from types import SimpleNamespace from typing import ( - Dict, - Optional, - Union, Iterable, Type ) @@ -30,6 +25,7 @@ from metagpt.llm import LLM from metagpt.logs import logger from metagpt.memory import Memory, LongTermMemory from metagpt.schema import Message +from metagpt.provider.human_provider import HumanProvider from metagpt.utils.utils import read_json_file, write_json_file, import_class PREFIX_TEMPLATE = """You are a {profile}, named {name}, your goal is {goal}, and the constraint is {constraints}. """ @@ -133,11 +129,11 @@ class Role(BaseModel): _rc: RoleContext = RoleContext() _private_attributes = { - "_setting': _setting, - "_role_id': _role_id, - '_states': [], - '_actions': [], - '_actions_type': [] # 用于记录和序列化 + "_setting": _setting, + "_role_id": _role_id, + "_states": [], + "_actions": [], + "_actions_type": [] # 用于记录和序列化 } class Config: @@ -162,17 +158,6 @@ class Role(BaseModel): object.__setattr__(self, '_states', []) object.__setattr__(self, '_actions', []) - @staticmethod - def _process_class(class_str, module_name): - cleaned_string = re.sub(r"[<>']", "", class_str).replace("class ", "") - package_name = "metagpt" - file_name = cleaned_string.replace(package_name, "").replace("." + module_name, "") - print(file_name) - # print("\n", sys.modules) - module_file = import_module(file_name, package=package_name) - module = getattr(module_file, module_name) - return module - def serialize(self, stg_path: Path): role_info_path = stg_path.joinpath("role_info.json") role_info = { diff --git a/tests/metagpt/serialize_deserialize/test_actions.py b/tests/metagpt/serialize_deserialize/test_actions.py index e2efa982b..2fec2121a 100644 --- a/tests/metagpt/serialize_deserialize/test_actions.py +++ b/tests/metagpt/serialize_deserialize/test_actions.py @@ -7,12 +7,14 @@ import pytest from metagpt.actions import Action from metagpt.llm import LLM + def test_action_serialize(): action = Action() ser_action_dict = action.dict() assert "name" in ser_action_dict assert "llm" in ser_action_dict + @pytest.mark.asyncio async def test_action_deserialize(): action = Action() diff --git a/tests/metagpt/serialize_deserialize/test_architect_deserialize.py b/tests/metagpt/serialize_deserialize/test_architect_deserialize.py index cff1bbadd..d0ee3bc99 100644 --- a/tests/metagpt/serialize_deserialize/test_architect_deserialize.py +++ b/tests/metagpt/serialize_deserialize/test_architect_deserialize.py @@ -7,6 +7,7 @@ import pytest from metagpt.roles.architect import Architect from metagpt.actions.action import Action + def test_architect_serialize(): role = Architect() ser_role_dict = role.dict(by_alias=True) @@ -14,6 +15,7 @@ def test_architect_serialize(): assert "_states" in ser_role_dict assert "_actions" in ser_role_dict + @pytest.mark.asyncio async def test_architect_deserialize(): role = Architect() diff --git a/tests/metagpt/serialize_deserialize/test_product_manager.py b/tests/metagpt/serialize_deserialize/test_product_manager.py index 978c50e5e..2aed87a28 100644 --- a/tests/metagpt/serialize_deserialize/test_product_manager.py +++ b/tests/metagpt/serialize_deserialize/test_product_manager.py @@ -8,6 +8,7 @@ from metagpt.roles.product_manager import ProductManager from metagpt.actions.action import Action from metagpt.schema import Message + @pytest.mark.asyncio async def test_product_manager_deserialize(): role = ProductManager() diff --git a/tests/metagpt/serialize_deserialize/test_project_manager.py b/tests/metagpt/serialize_deserialize/test_project_manager.py index 590bd8109..fbc0dcc08 100644 --- a/tests/metagpt/serialize_deserialize/test_project_manager.py +++ b/tests/metagpt/serialize_deserialize/test_project_manager.py @@ -7,6 +7,7 @@ import pytest from metagpt.roles.project_manager import ProjectManager from metagpt.actions.action import Action + def test_project_manager_serialize(): role = ProjectManager() ser_role_dict = role.dict(by_alias=True) @@ -14,6 +15,7 @@ def test_project_manager_serialize(): assert "_states" in ser_role_dict assert "_actions" in ser_role_dict + @pytest.mark.asyncio async def test_project_manager_deserialize(): role = ProjectManager() diff --git a/tests/metagpt/serialize_deserialize/test_role.py b/tests/metagpt/serialize_deserialize/test_role.py index 432c9acb7..0e438d1a2 100644 --- a/tests/metagpt/serialize_deserialize/test_role.py +++ b/tests/metagpt/serialize_deserialize/test_role.py @@ -34,7 +34,7 @@ async def test_engineer_deserialize(): # also can be deserialized in this way: new_role = Engineer(**ser_role_dict) assert new_role.name == "Alex" - assert new_role.use_code_review == True + assert new_role.use_code_review is True assert len(new_role._actions) == 2 assert isinstance(new_role._actions[0], Action) assert isinstance(new_role._actions[1], Action) diff --git a/tests/metagpt/serialize_deserialize/test_wrire_prd.py b/tests/metagpt/serialize_deserialize/test_wrire_prd.py index 9b2653820..baa08ed76 100644 --- a/tests/metagpt/serialize_deserialize/test_wrire_prd.py +++ b/tests/metagpt/serialize_deserialize/test_wrire_prd.py @@ -24,5 +24,5 @@ async def test_action_deserialize(): # new_action = WritePRD().deserialize(serialized_data) assert new_action.name == "" assert new_action.llm == LLM() - assert len(await new_action.run([Message(content="write a cli snake game")]))>0 - + assert len(await new_action.run([Message(content="write a cli snake game")])) > 0 + diff --git a/tests/metagpt/serialize_deserialize/test_write_code.py b/tests/metagpt/serialize_deserialize/test_write_code.py index 0b1f1dc7c..9d659caaf 100644 --- a/tests/metagpt/serialize_deserialize/test_write_code.py +++ b/tests/metagpt/serialize_deserialize/test_write_code.py @@ -7,18 +7,21 @@ import pytest from metagpt.actions import WriteCode, WriteCodeReview from metagpt.llm import LLM + def test_write_design_serialize(): action = WriteCode() ser_action_dict = action.dict() assert ser_action_dict["name"] == "WriteCode" assert "llm" in ser_action_dict + def test_write_task_serialize(): action = WriteCodeReview() ser_action_dict = action.dict() assert ser_action_dict["name"] == "WriteCodeReview" assert "llm" in ser_action_dict - + + @pytest.mark.asyncio async def test_write_code_deserialize(): action = WriteCode() @@ -29,6 +32,7 @@ async def test_write_code_deserialize(): assert new_action.llm == LLM() await new_action.run(context="write a cli snake game", filename="test_code") + @pytest.mark.asyncio async def test_write_code_review_deserialize(): action = WriteCodeReview() diff --git a/tests/metagpt/serialize_deserialize/test_write_design.py b/tests/metagpt/serialize_deserialize/test_write_design.py index 56bf78a63..e6e236676 100644 --- a/tests/metagpt/serialize_deserialize/test_write_design.py +++ b/tests/metagpt/serialize_deserialize/test_write_design.py @@ -7,18 +7,21 @@ import pytest from metagpt.actions import WriteDesign, WriteTasks from metagpt.llm import LLM + def test_write_design_serialize(): action = WriteDesign() ser_action_dict = action.dict() assert "name" in ser_action_dict assert "llm" in ser_action_dict + def test_write_task_serialize(): action = WriteTasks() ser_action_dict = action.dict() assert "name" in ser_action_dict assert "llm" in ser_action_dict + @pytest.mark.asyncio async def test_write_design_deserialize(): action = WriteDesign() @@ -28,6 +31,7 @@ async def test_write_design_deserialize(): assert new_action.llm == LLM() await new_action.run(context="write a cli snake game") + @pytest.mark.asyncio async def test_write_task_deserialize(): action = WriteTasks() @@ -36,4 +40,4 @@ async def test_write_task_deserialize(): # new_action = WriteTasks().deserialize(serialized_data) assert new_action.name == "CreateTasks" assert new_action.llm == LLM() - await new_action.run(context="write a cli snake game") \ No newline at end of file + await new_action.run(context="write a cli snake game") From 5f69878a08ead0f3a9c4e743c8a226c902aec076 Mon Sep 17 00:00:00 2001 From: seehi <6580@pm.me> Date: Wed, 29 Nov 2023 20:23:15 +0800 Subject: [PATCH 005/167] openai.api_base -> openai.base_url --- config/config.yaml | 10 +++++----- docs/FAQ-EN.md | 8 ++++---- docs/README_JA.md | 2 +- docs/tutorial/usage.md | 2 +- docs/tutorial/usage_cn.md | 2 +- metagpt/config.py | 10 ++++++---- 6 files changed, 18 insertions(+), 16 deletions(-) diff --git a/config/config.yaml b/config/config.yaml index bed67083c..249552693 100644 --- a/config/config.yaml +++ b/config/config.yaml @@ -2,10 +2,10 @@ # The configuration of key.yaml has a higher priority and will not enter git #### if OpenAI -## The official OPENAI_API_BASE is https://api.openai.com/v1 -## If the official OPENAI_API_BASE is not available, we recommend using the [openai-forward](https://github.com/beidongjiedeguang/openai-forward). -## Or, you can configure OPENAI_PROXY to access official OPENAI_API_BASE. -OPENAI_API_BASE: "https://api.openai.com/v1" +## The official OPENAI_BASE_URL is https://api.openai.com/v1/ +## If the official OPENAI_BASE_URL is not available, we recommend using the [openai-forward](https://github.com/beidongjiedeguang/openai-forward). +## Or, you can configure OPENAI_PROXY to access official OPENAI_BASE_URL. +OPENAI_BASE_URL: "https://api.openai.com/v1/" #OPENAI_PROXY: "http://127.0.0.1:8118" #OPENAI_API_KEY: "YOUR_API_KEY" # set the value to sk-xxx if you host the openai interface for open llm model OPENAI_API_MODEL: "gpt-4" @@ -25,7 +25,7 @@ RPM: 10 #### if AZURE, check https://github.com/openai/openai-cookbook/blob/main/examples/azure/chat.ipynb #### You can use ENGINE or DEPLOYMENT mode #OPENAI_API_TYPE: "azure" -#OPENAI_API_BASE: "YOUR_AZURE_ENDPOINT" +#OPENAI_BASE_URL: "YOUR_AZURE_ENDPOINT" #OPENAI_API_KEY: "YOUR_AZURE_API_KEY" #OPENAI_API_VERSION: "YOUR_AZURE_API_VERSION" #DEPLOYMENT_NAME: "YOUR_DEPLOYMENT_NAME" diff --git a/docs/FAQ-EN.md b/docs/FAQ-EN.md index f9df50caf..1c5b4a86a 100644 --- a/docs/FAQ-EN.md +++ b/docs/FAQ-EN.md @@ -83,10 +83,10 @@ 1. PRD stuck / unable to access/ connection interrupted - 1. The official OPENAI_API_BASE address is `https://api.openai.com/v1` - 1. If the official OPENAI_API_BASE address is inaccessible in your environment (this can be verified with curl), it's recommended to configure using the reverse proxy OPENAI_API_BASE provided by libraries such as openai-forward. For instance, `OPENAI_API_BASE: "``https://api.openai-forward.com/v1``"` - 1. If the official OPENAI_API_BASE address is inaccessible in your environment (again, verifiable via curl), another option is to configure the OPENAI_PROXY parameter. This way, you can access the official OPENAI_API_BASE via a local proxy. If you don't need to access via a proxy, please do not enable this configuration; if accessing through a proxy is required, modify it to the correct proxy address. Note that when OPENAI_PROXY is enabled, don't set OPENAI_API_BASE. - 1. Note: OpenAI's default API design ends with a v1. An example of the correct configuration is: `OPENAI_API_BASE: "``https://api.openai.com/v1``"` + 1. The official OPENAI_BASE_URL address is `https://api.openai.com/v1/` + 1. If the official OPENAI_BASE_URL address is inaccessible in your environment (this can be verified with curl), it's recommended to configure using the reverse proxy OPENAI_BASE_URL provided by libraries such as openai-forward. For instance, `OPENAI_BASE_URL: "``https://api.openai-forward.com/v1/``"` + 1. If the official OPENAI_BASE_URL address is inaccessible in your environment (again, verifiable via curl), another option is to configure the OPENAI_PROXY parameter. This way, you can access the official OPENAI_BASE_URL via a local proxy. If you don't need to access via a proxy, please do not enable this configuration; if accessing through a proxy is required, modify it to the correct proxy address. Note that when OPENAI_PROXY is enabled, don't set OPENAI_BASE_URL. + 1. Note: OpenAI's default API design ends with a v1. An example of the correct configuration is: `OPENAI_BASE_URL: "``https://api.openai.com/v1/``"` 1. Absolutely! How can I assist you today? diff --git a/docs/README_JA.md b/docs/README_JA.md index 411d190b4..33b08b770 100644 --- a/docs/README_JA.md +++ b/docs/README_JA.md @@ -219,7 +219,7 @@ # 設定ファイルをコピーし、必要な修正を加える。 | 変数名 | config/key.yaml | env | | --------------------------------------- | ----------------------------------------- | ----------------------------------------------- | | OPENAI_API_KEY # 自分のキーに置き換える | OPENAI_API_KEY: "sk-..." | export OPENAI_API_KEY="sk-..." | -| OPENAI_API_BASE # オプション | OPENAI_API_BASE: "https:///v1" | export OPENAI_API_BASE="https:///v1" | +| OPENAI_BASE_URL # オプション | OPENAI_BASE_URL: "https:///v1/" | export OPENAI_BASE_URL="https:///v1/" | ## チュートリアル: スタートアップの開始 diff --git a/docs/tutorial/usage.md b/docs/tutorial/usage.md index ee87b65c9..f8a25c84f 100644 --- a/docs/tutorial/usage.md +++ b/docs/tutorial/usage.md @@ -13,7 +13,7 @@ # Copy the configuration file and make the necessary modifications. | Variable Name | config/key.yaml | env | | ------------------------------------------ | ----------------------------------------- | ----------------------------------------------- | | OPENAI_API_KEY # Replace with your own key | OPENAI_API_KEY: "sk-..." | export OPENAI_API_KEY="sk-..." | -| OPENAI_API_BASE # Optional | OPENAI_API_BASE: "https:///v1" | export OPENAI_API_BASE="https:///v1" | +| OPENAI_BASE_URL # Optional | OPENAI_BASE_URL: "https:///v1/" | export OPENAI_BASE_URL="https:///v1/" | ### Initiating a startup diff --git a/docs/tutorial/usage_cn.md b/docs/tutorial/usage_cn.md index 4b3bdd2c3..ddd1c2267 100644 --- a/docs/tutorial/usage_cn.md +++ b/docs/tutorial/usage_cn.md @@ -13,7 +13,7 @@ # 复制配置文件并进行必要的修改 | 变量名 | config/key.yaml | env | | ----------------------------------- | ----------------------------------------- | ----------------------------------------------- | | OPENAI_API_KEY # 用您自己的密钥替换 | OPENAI_API_KEY: "sk-..." | export OPENAI_API_KEY="sk-..." | -| OPENAI_API_BASE # 可选 | OPENAI_API_BASE: "https:///v1" | export OPENAI_API_BASE="https:///v1" | +| OPENAI_BASE_URL # 可选 | OPENAI_BASE_URL: "https:///v1/" | export OPENAI_BASE_URL="https:///v1/" | ### 示例:启动一个创业公司 diff --git a/metagpt/config.py b/metagpt/config.py index 3f9e742bd..a6ecab5ff 100644 --- a/metagpt/config.py +++ b/metagpt/config.py @@ -46,11 +46,13 @@ class Config(metaclass=Singleton): self.openai_api_key = self._get("OPENAI_API_KEY") self.anthropic_api_key = self._get("Anthropic_API_KEY") self.zhipuai_api_key = self._get("ZHIPUAI_API_KEY") - if (not self.openai_api_key or "YOUR_API_KEY" == self.openai_api_key) and \ - (not self.anthropic_api_key or "YOUR_API_KEY" == self.anthropic_api_key) and \ - (not self.zhipuai_api_key or "YOUR_API_KEY" == self.zhipuai_api_key): + if ( + (not self.openai_api_key or "YOUR_API_KEY" == self.openai_api_key) + and (not self.anthropic_api_key or "YOUR_API_KEY" == self.anthropic_api_key) + and (not self.zhipuai_api_key or "YOUR_API_KEY" == self.zhipuai_api_key) + ): raise NotConfiguredException("Set OPENAI_API_KEY or Anthropic_API_KEY or ZHIPUAI_API_KEY first") - self.openai_api_base = self._get("OPENAI_API_BASE") + self.openai_api_base = self._get("OPENAI_BASE_URL") openai_proxy = self._get("OPENAI_PROXY") or self.global_proxy if openai_proxy: openai.proxy = openai_proxy From 9e5c873d77754f24a7b36be0e697975d30efed04 Mon Sep 17 00:00:00 2001 From: better629 Date: Thu, 30 Nov 2023 15:10:38 +0800 Subject: [PATCH 006/167] update unittest of ser&deser --- tests/metagpt/actions/test_action.py | 17 --- tests/metagpt/roles/test_role.py | 84 +----------- .../serialize_deserialize/test_action.py | 49 +++++++ .../serialize_deserialize/test_actions.py | 26 ---- .../test_architect_deserialize.py | 2 +- .../serialize_deserialize/test_environment.py | 91 +++++++++++++ .../test_memory.py | 34 ++++- .../test_product_manager.py | 4 +- .../test_project_manager.py | 6 +- .../serialize_deserialize/test_role.py | 63 ++++++++- .../serialize_deserialize/test_schema.py | 49 +++++++ .../test_serdeser_base.py | 88 +++++++++++++ .../serialize_deserialize/test_team.py | 124 +++++++++++++----- .../serialize_deserialize/test_wrire_prd.py | 1 - .../serialize_deserialize/test_write_code.py | 2 +- tests/metagpt/test_environment.py | 44 +++---- tests/metagpt/test_role.py | 14 -- tests/metagpt/test_schema.py | 4 +- tests/metagpt/test_team.py | 22 +--- 19 files changed, 496 insertions(+), 228 deletions(-) create mode 100644 tests/metagpt/serialize_deserialize/test_action.py delete mode 100644 tests/metagpt/serialize_deserialize/test_actions.py create mode 100644 tests/metagpt/serialize_deserialize/test_environment.py rename tests/metagpt/{memory => serialize_deserialize}/test_memory.py (52%) create mode 100644 tests/metagpt/serialize_deserialize/test_schema.py create mode 100644 tests/metagpt/serialize_deserialize/test_serdeser_base.py delete mode 100644 tests/metagpt/test_role.py diff --git a/tests/metagpt/actions/test_action.py b/tests/metagpt/actions/test_action.py index 4468a6f6f..9775630cc 100644 --- a/tests/metagpt/actions/test_action.py +++ b/tests/metagpt/actions/test_action.py @@ -11,20 +11,3 @@ from metagpt.actions import Action, WritePRD, WriteTest def test_action_repr(): actions = [Action(), WriteTest(), WritePRD()] assert "WriteTest" in str(actions) - - -def test_action_serdes(): - action_info = WriteTest.ser_class() - assert action_info["action_class"] == "WriteTest" - - action_class = Action.deser_class(action_info) - assert action_class == WriteTest - - -def test_action_class_serdes(): - name = "write test" - action_info = WriteTest(name=name).serialize() - assert action_info["name"] == name - - action = Action.deserialize(action_info) - assert action.name == name diff --git a/tests/metagpt/roles/test_role.py b/tests/metagpt/roles/test_role.py index a19ad9cb5..72cd84a9a 100644 --- a/tests/metagpt/roles/test_role.py +++ b/tests/metagpt/roles/test_role.py @@ -2,84 +2,10 @@ # -*- coding: utf-8 -*- # @Desc : unittest of Role -from pathlib import Path -import shutil -import pytest - -from metagpt.roles.role import Role, RoleReactMode -from metagpt.actions.action import Action -from metagpt.schema import Message -from metagpt.actions.add_requirement import BossRequirement -from metagpt.roles.product_manager import ProductManager - -serdes_path = Path(__file__).absolute().parent.joinpath("../../data/serdes_storage") +from metagpt.roles.role import Role -def test_role_serdes(): - stg_path_prefix = serdes_path.joinpath("team/environment/roles/") - shutil.rmtree(serdes_path.joinpath("team"), ignore_errors=True) - - pm = ProductManager() - role_tag = f"{pm.__class__.__name__}_{pm.name}" - stg_path = stg_path_prefix.joinpath(role_tag) - pm.serialize(stg_path) - assert stg_path.joinpath("actions/actions_info.json").exists() - - new_pm = Role.deserialize(stg_path) - assert new_pm.name == pm.name - assert len(new_pm.get_memories(1)) == 0 - - -class ActionOK(Action): - - async def run(self, messages: list["Message"]): - return "ok" - - -class ActionRaise(Action): - - async def run(self, messages: list["Message"]): - raise RuntimeError("parse error") - - -class RoleA(Role): - - def __init__(self, - name: str = "RoleA", - profile: str = "Role A", - goal: str = "", - constraints: str = ""): - super(RoleA, self).__init__(name=name, profile=profile, goal=goal, constraints=constraints) - self._init_actions([ActionOK, ActionRaise]) - self._watch([BossRequirement]) - self._rc.react_mode = RoleReactMode.BY_ORDER - - async def run(self, message: "Message" = None, stg_path: str = None): - try: - await super(RoleA, self).run(message) - except Exception as exp: - print("exp ", exp) - self.serialize(stg_path) - - -@pytest.mark.asyncio -async def test_role_serdes_interrupt(): - role_a = RoleA() - shutil.rmtree(serdes_path.joinpath("team"), ignore_errors=True) - - stg_path = serdes_path.joinpath(f"team/environment/roles/{role_a.__class__.__name__}_{role_a.name}") - await role_a.run( - message=Message(content="demo", cause_by=BossRequirement), - stg_path=stg_path - ) - assert role_a._rc.memory.count() == 2 - - assert stg_path.joinpath("actions/todo.json").exists() - - new_role_a: Role = Role.deserialize(stg_path) - assert new_role_a._rc.state == 1 - await role_a.run( - message=Message(content="demo", cause_by=BossRequirement), - stg_path=stg_path - ) - +def test_role_desc(): + role = Role(profile="Sales", desc="Best Seller") + assert role.profile == "Sales" + assert role._setting.desc == "Best Seller" diff --git a/tests/metagpt/serialize_deserialize/test_action.py b/tests/metagpt/serialize_deserialize/test_action.py new file mode 100644 index 000000000..b624dff5a --- /dev/null +++ b/tests/metagpt/serialize_deserialize/test_action.py @@ -0,0 +1,49 @@ +# -*- coding: utf-8 -*- +# @Date : 11/22/2023 11:48 AM +# @Author : stellahong (stellahong@fuzhi.ai) +# @Desc : +import pytest + +from metagpt.actions import Action, WritePRD, WriteTest +from metagpt.llm import LLM +from metagpt.provider.openai_api import OpenAIGPTAPI + + +def test_action_serialize(): + action = Action() + ser_action_dict = action.dict() + assert "name" in ser_action_dict + assert "llm" in ser_action_dict + + +@pytest.mark.asyncio +async def test_action_deserialize(): + action = Action() + serialized_data = action.dict() + assert isinstance(serialized_data["llm"], OpenAIGPTAPI) + + new_action = Action(**serialized_data) + + assert new_action.name == "" + assert new_action.llm == LLM() + assert len(await new_action._aask("who are you")) > 0 + + +def test_action_serdeser(): + action_info = WriteTest.ser_class() + assert action_info["action_class"] == "WriteTest" + + action_class = Action.deser_class(action_info) + assert action_class == WriteTest + + +def test_action_class_serdeser(): + name = "write test" + action_info = WriteTest(name=name).serialize() + assert action_info["name"] == name + + action_info = WriteTest(name=name, llm=LLM()).serialize() + assert action_info["name"] == name + + action = Action.deserialize(action_info) + assert action.name == name diff --git a/tests/metagpt/serialize_deserialize/test_actions.py b/tests/metagpt/serialize_deserialize/test_actions.py deleted file mode 100644 index 2fec2121a..000000000 --- a/tests/metagpt/serialize_deserialize/test_actions.py +++ /dev/null @@ -1,26 +0,0 @@ -# -*- coding: utf-8 -*- -# @Date : 11/22/2023 11:48 AM -# @Author : stellahong (stellahong@fuzhi.ai) -# @Desc : -import pytest - -from metagpt.actions import Action -from metagpt.llm import LLM - - -def test_action_serialize(): - action = Action() - ser_action_dict = action.dict() - assert "name" in ser_action_dict - assert "llm" in ser_action_dict - - -@pytest.mark.asyncio -async def test_action_deserialize(): - action = Action() - serialized_data = action.dict() - - new_action = Action(**serialized_data) - assert new_action.name == "" - assert new_action.llm == LLM() - assert len(await new_action._aask("who are you")) > 0 diff --git a/tests/metagpt/serialize_deserialize/test_architect_deserialize.py b/tests/metagpt/serialize_deserialize/test_architect_deserialize.py index d0ee3bc99..fb58f0a3a 100644 --- a/tests/metagpt/serialize_deserialize/test_architect_deserialize.py +++ b/tests/metagpt/serialize_deserialize/test_architect_deserialize.py @@ -25,4 +25,4 @@ async def test_architect_deserialize(): assert new_role.name == "Bob" assert len(new_role._actions) == 1 assert isinstance(new_role._actions[0], Action) - await new_role._actions[0].run(context="write a cli snake game") \ No newline at end of file + await new_role._actions[0].run(context="write a cli snake game") diff --git a/tests/metagpt/serialize_deserialize/test_environment.py b/tests/metagpt/serialize_deserialize/test_environment.py new file mode 100644 index 000000000..15336eb6a --- /dev/null +++ b/tests/metagpt/serialize_deserialize/test_environment.py @@ -0,0 +1,91 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : + +from pathlib import Path +import shutil + +from metagpt.schema import Message +from metagpt.actions.action_output import ActionOutput +from metagpt.roles.project_manager import ProjectManager +from metagpt.actions.add_requirement import BossRequirement +from metagpt.actions.project_management import WriteTasks +from metagpt.environment import Environment +from tests.metagpt.serialize_deserialize.test_serdeser_base import RoleC, ActionOK, serdeser_path + + +def test_env_serialize(): + env = Environment() + ser_env_dict = env.dict() + assert "roles" in ser_env_dict + assert "memory" in ser_env_dict + + +def test_env_deserialize(): + env = Environment() + env.publish_message(message=Message(content="test env serialize")) + ser_env_dict = env.dict() + new_env = Environment(**ser_env_dict) + assert len(new_env.roles) == 0 + assert new_env.memory.storage[0].content == "test env serialize" + assert len(new_env.history) == 25 + + +def test_environment_serdeser(): + out_mapping = {"field1": (list[str], ...)} + out_data = {"field1": ["field1 value1", "field1 value2"]} + ic_obj = ActionOutput.create_model_class("prd", out_mapping) + + message = Message( + content="prd", + instruct_content=ic_obj(**out_data), + role="product manager", + cause_by=BossRequirement + ) + + environment = Environment() + role_c = RoleC() + environment.add_role(role_c) + environment.publish_message(message) + + ser_data = environment.dict() + assert ser_data["roles"]["Role C"]["name"] == "RoleC" + + new_env: Environment = Environment(**ser_data) + assert len(new_env.roles) == 1 + + assert new_env.memory.count() == 1 + assert new_env.memory.storage[0].instruct_content == ic_obj(**out_data) + assert list(new_env.roles.values())[0]._states == list(environment.roles.values())[0]._states + assert list(new_env.roles.values())[0]._actions == list(environment.roles.values())[0]._actions + assert isinstance(list(environment.roles.values())[0]._actions[0], ActionOK) + assert type(list(new_env.roles.values())[0]._actions[0]) == ActionOK + + +def test_environment_serdeser_v2(): + environment = Environment() + pm = ProjectManager() + environment.add_role(pm) + + ser_data = environment.dict() + + new_env: Environment = Environment(**ser_data) + role = new_env.get_role(pm.profile) + assert isinstance(role, ProjectManager) + assert isinstance(role._actions[0], WriteTasks) + assert isinstance(list(new_env.roles.values())[0]._actions[0], WriteTasks) + + +def test_environment_serdeser_save(): + environment = Environment() + role_c = RoleC() + + shutil.rmtree(serdeser_path.joinpath("team"), ignore_errors=True) + + stg_path = serdeser_path.joinpath("team/environment") + environment.add_role(role_c) + environment.serialize(stg_path) + + new_env: Environment = Environment.deserialize(stg_path) + assert len(new_env.roles) == 1 + assert type(list(new_env.roles.values())[0]._actions[0]) == ActionOK diff --git a/tests/metagpt/memory/test_memory.py b/tests/metagpt/serialize_deserialize/test_memory.py similarity index 52% rename from tests/metagpt/memory/test_memory.py rename to tests/metagpt/serialize_deserialize/test_memory.py index bda79ded1..e24f31af3 100644 --- a/tests/metagpt/memory/test_memory.py +++ b/tests/metagpt/serialize_deserialize/test_memory.py @@ -3,6 +3,7 @@ # @Desc : unittest of memory from pathlib import Path +from pydantic import BaseModel from metagpt.schema import Message from metagpt.memory.memory import Memory @@ -10,10 +11,36 @@ from metagpt.actions.action_output import ActionOutput from metagpt.actions.design_api import WriteDesign from metagpt.actions.add_requirement import BossRequirement -serdes_path = Path(__file__).absolute().parent.joinpath("../../data/serdes_storage") +from tests.metagpt.serialize_deserialize.test_serdeser_base import serdeser_path -def test_memory_serdes(): +def test_memory_serdeser(): + msg1 = Message(role="Boss", + content="write a snake game", + cause_by=BossRequirement) + + out_mapping = {"field2": (list[str], ...)} + out_data = {"field2": ["field2 value1", "field2 value2"]} + ic_obj = ActionOutput.create_model_class("system_design", out_mapping) + msg2 = Message(role="Architect", + instruct_content=ic_obj(**out_data), + content="system design content", + cause_by=WriteDesign) + + memory = Memory() + memory.add_batch([msg1, msg2]) + ser_data = memory.dict() + + new_memory = Memory(**ser_data) + assert new_memory.count() == 2 + new_msg2 = new_memory.get(2)[0] + assert isinstance(new_msg2, BaseModel) + assert isinstance(new_memory.storage[-1], BaseModel) + assert new_memory.storage[-1].cause_by == WriteDesign + assert new_msg2.role == "Boss" + + +def test_memory_serdeser_save(): msg1 = Message(role="User", content="write a 2048 game", cause_by=BossRequirement) @@ -29,7 +56,7 @@ def test_memory_serdes(): memory = Memory() memory.add_batch([msg1, msg2]) - stg_path = serdes_path.joinpath("team/environment") + stg_path = serdeser_path.joinpath("team/environment") memory.serialize(stg_path) assert stg_path.joinpath("memory.json").exists() @@ -38,5 +65,6 @@ def test_memory_serdes(): new_msg2 = new_memory.get(1)[0] assert new_msg2.instruct_content.field1 == ["field1 value1", "field1 value2"] assert new_msg2.cause_by == WriteDesign + assert len(new_memory.index) == 2 stg_path.joinpath("memory.json").unlink() diff --git a/tests/metagpt/serialize_deserialize/test_product_manager.py b/tests/metagpt/serialize_deserialize/test_product_manager.py index 2aed87a28..54584cf96 100644 --- a/tests/metagpt/serialize_deserialize/test_product_manager.py +++ b/tests/metagpt/serialize_deserialize/test_product_manager.py @@ -15,8 +15,8 @@ async def test_product_manager_deserialize(): ser_role_dict = role.dict(by_alias=True) new_role = ProductManager(**ser_role_dict) # new_role = ProductManager().deserialize(ser_role_dict) - + assert new_role.name == "Alice" assert len(new_role._actions) == 1 assert isinstance(new_role._actions[0], Action) - await new_role._actions[0].run([Message(content="write a cli snake game")]) \ No newline at end of file + await new_role._actions[0].run([Message(content="write a cli snake game")]) diff --git a/tests/metagpt/serialize_deserialize/test_project_manager.py b/tests/metagpt/serialize_deserialize/test_project_manager.py index fbc0dcc08..21fafa72e 100644 --- a/tests/metagpt/serialize_deserialize/test_project_manager.py +++ b/tests/metagpt/serialize_deserialize/test_project_manager.py @@ -6,6 +6,7 @@ import pytest from metagpt.roles.project_manager import ProjectManager from metagpt.actions.action import Action +from metagpt.actions.project_management import WriteTasks def test_project_manager_serialize(): @@ -20,9 +21,10 @@ def test_project_manager_serialize(): async def test_project_manager_deserialize(): role = ProjectManager() ser_role_dict = role.dict(by_alias=True) + new_role = ProjectManager(**ser_role_dict) - # new_role = ProjectManager().deserialize(ser_role_dict) assert new_role.name == "Eve" assert len(new_role._actions) == 1 assert isinstance(new_role._actions[0], Action) - await new_role._actions[0].run(context="write a cli snake game") \ No newline at end of file + assert isinstance(new_role._actions[0], WriteTasks) + # await new_role._actions[0].run(context="write a cli snake game") diff --git a/tests/metagpt/serialize_deserialize/test_role.py b/tests/metagpt/serialize_deserialize/test_role.py index 0e438d1a2..f260dea3a 100644 --- a/tests/metagpt/serialize_deserialize/test_role.py +++ b/tests/metagpt/serialize_deserialize/test_role.py @@ -2,12 +2,22 @@ # @Date : 11/23/2023 4:49 PM # @Author : stellahong (stellahong@fuzhi.ai) # @Desc : + +from pathlib import Path +import shutil import pytest +from metagpt.logs import logger from metagpt.roles.role import Role +from metagpt.actions import WriteCode, WriteCodeReview +from metagpt.schema import Message +from metagpt.actions.add_requirement import BossRequirement +from metagpt.roles.product_manager import ProductManager +from metagpt.const import SERDESER_PATH from metagpt.roles.engineer import Engineer +from metagpt.utils.utils import format_trackback_info -from metagpt.actions.action import Action +from tests.metagpt.serialize_deserialize.test_serdeser_base import RoleC, serdeser_path def test_role_serialize(): @@ -30,12 +40,53 @@ def test_engineer_serialize(): async def test_engineer_deserialize(): role = Engineer(use_code_review=True) ser_role_dict = role.dict(by_alias=True) - # new_role = Engineer().deserialize(ser_role_dict) - # also can be deserialized in this way: + new_role = Engineer(**ser_role_dict) assert new_role.name == "Alex" assert new_role.use_code_review is True assert len(new_role._actions) == 2 - assert isinstance(new_role._actions[0], Action) - assert isinstance(new_role._actions[1], Action) - await new_role._actions[0].run(context="write a cli snake game", filename="test_code") + assert isinstance(new_role._actions[0], WriteCode) + assert isinstance(new_role._actions[1], WriteCodeReview) + # await new_role._actions[0].run(context="write a cli snake game", filename="test_code") + + +def test_role_serdeser_save(): + stg_path_prefix = serdeser_path.joinpath("team/environment/roles/") + shutil.rmtree(serdeser_path.joinpath("team"), ignore_errors=True) + + pm = ProductManager() + role_tag = f"{pm.__class__.__name__}_{pm.name}" + stg_path = stg_path_prefix.joinpath(role_tag) + pm.serialize(stg_path) + assert stg_path.joinpath("actions/actions_info.json").exists() + + new_pm = Role.deserialize(stg_path) + assert new_pm.name == pm.name + assert len(new_pm.get_memories(1)) == 0 + + +@pytest.mark.asyncio +async def test_role_serdeser_interrupt(): + role_c = RoleC() + shutil.rmtree(SERDESER_PATH.joinpath("team"), ignore_errors=True) + + stg_path = SERDESER_PATH.joinpath(f"team/environment/roles/{role_c.__class__.__name__}_{role_c.name}") + try: + await role_c.run( + message=Message(content="demo", cause_by=BossRequirement) + ) + except Exception as exp: + logger.error(f"Exception in `role_a.run`, detail: {format_trackback_info()}") + role_c.serialize(stg_path) + + assert role_c._rc.memory.count() == 2 + + assert stg_path.joinpath("actions/todo.json").exists() + + new_role_a: Role = Role.deserialize(stg_path) + assert new_role_a._rc.state == 1 + + with pytest.raises(Exception): + await role_c.run( + message=Message(content="demo", cause_by=BossRequirement) + ) diff --git a/tests/metagpt/serialize_deserialize/test_schema.py b/tests/metagpt/serialize_deserialize/test_schema.py new file mode 100644 index 000000000..74b134cad --- /dev/null +++ b/tests/metagpt/serialize_deserialize/test_schema.py @@ -0,0 +1,49 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : unittest of schema ser&deser + +from metagpt.schema import Message +from metagpt.actions.action_output import ActionOutput +from metagpt.actions.write_code import WriteCode + +from tests.metagpt.serialize_deserialize.test_serdeser_base import MockMessage + + +def test_message_serdeser(): + out_mapping = {"field3": (str, ...), "field4": (list[str], ...)} + out_data = {"field3": "field3 value3", "field4": ["field4 value1", "field4 value2"]} + ic_obj = ActionOutput.create_model_class("code", out_mapping) + + message = Message( + content="code", + instruct_content=ic_obj(**out_data), + role="engineer", + cause_by=WriteCode + ) + ser_data = message.dict() + assert ser_data["cause_by"] == { + "action_class": "WriteCode", + "module_name": "metagpt.actions.write_code" + } + assert ser_data["instruct_content"]["class"] == "code" + + new_message = Message(**ser_data) + assert new_message.cause_by == WriteCode + assert new_message.cause_by in [WriteCode] + assert new_message.instruct_content == ic_obj(**out_data) + + +def test_message_without_postprocess(): + """ to explain `instruct_content` should be postprocessed """ + out_mapping = {"field1": (list[str], ...)} + out_data = {"field1": ["field1 value1", "field1 value2"]} + ic_obj = ActionOutput.create_model_class("code", out_mapping) + message = MockMessage( + content="code", + instruct_content=ic_obj(**out_data) + ) + ser_data = message.dict() + assert ser_data["instruct_content"] == {"field1": ["field1 value1", "field1 value2"]} + + new_message = MockMessage(**ser_data) + assert new_message.instruct_content != ic_obj(**out_data) diff --git a/tests/metagpt/serialize_deserialize/test_serdeser_base.py b/tests/metagpt/serialize_deserialize/test_serdeser_base.py new file mode 100644 index 000000000..35bad6cd9 --- /dev/null +++ b/tests/metagpt/serialize_deserialize/test_serdeser_base.py @@ -0,0 +1,88 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : base test actions / roles used in unittest + +from pydantic import BaseModel, Field +from pathlib import Path + +from metagpt.actions.action import Action +from metagpt.roles.role import Role, RoleReactMode +from metagpt.actions.add_requirement import BossRequirement + + +serdeser_path = Path(__file__).absolute().parent.joinpath("../../data/serdeser_storage") + + +class MockMessage(BaseModel): + """ to test normal dict without postprocess """ + content: str = "" + instruct_content: BaseModel = Field(default=None) + + +class ActionPass(Action): + name: str = "ActionPass" + + async def run(self, messages: list["Message"]): + return "pass" + + +class ActionOK(Action): + name: str = "ActionOK" + + async def run(self, messages: list["Message"]): + return "ok" + + +class ActionRaise(Action): + name: str = "ActionRaise" + + async def run(self, messages: list["Message"]): + raise RuntimeError("parse error in ActionRaise") + + +class RoleA(Role): + + name: str = Field(default="RoleA") + profile: str = Field(default="Role A") + goal: str = "RoleA's goal" + constraints: str = "RoleA's constraints" + + def __init__(self, **kwargs): + super(RoleA, self).__init__(**kwargs) + self._init_actions([ActionPass]) + self._watch([BossRequirement]) + + async def run(self, message: "Message" = None): + await super(RoleA, self).run(message) + + +class RoleB(Role): + name: str = Field(default="RoleB") + profile: str = Field(default="Role B") + goal: str = "RoleB's goal" + constraints: str = "RoleB's constraints" + + def __init__(self, **kwargs): + super(RoleB, self).__init__(**kwargs) + self._init_actions([ActionOK, ActionRaise]) + self._watch([ActionPass]) + self._rc.react_mode = RoleReactMode.BY_ORDER + + async def run(self, message: "Message" = None): + await super(RoleB, self).run(message) + + +class RoleC(Role): + name: str = Field(default="RoleC") + profile: str = Field(default="Role C") + goal: str = "RoleC's goal" + constraints: str = "RoleC's constraints" + + def __init__(self, **kwargs): + super(RoleC, self).__init__(**kwargs) + self._init_actions([ActionOK, ActionRaise]) + self._watch([BossRequirement]) + self._rc.react_mode = RoleReactMode.BY_ORDER + + async def run(self, message: "Message" = None): + await super(RoleC, self).run(message) diff --git a/tests/metagpt/serialize_deserialize/test_team.py b/tests/metagpt/serialize_deserialize/test_team.py index 44a75d262..e9122ebc0 100644 --- a/tests/metagpt/serialize_deserialize/test_team.py +++ b/tests/metagpt/serialize_deserialize/test_team.py @@ -2,46 +2,104 @@ # @Date : 11/27/2023 10:07 AM # @Author : stellahong (stellahong@fuzhi.ai) # @Desc : + +from pathlib import Path +import shutil import pytest -from metagpt.environment import Environment -from metagpt.schema import Message -from metagpt.software_company import SoftwareCompany from metagpt.roles import ProjectManager, ProductManager, Architect +from metagpt.team import Team +from metagpt.const import SERDESER_PATH + +from tests.metagpt.serialize_deserialize.test_serdeser_base import RoleA, RoleB, RoleC, serdeser_path -def test_env_serialize(): - env = Environment() - ser_env_dict = env.dict() - assert "roles" in ser_env_dict - assert "memory" in ser_env_dict - assert "memory" in ser_env_dict +def test_team_deserialize(): + company = Team() - -def test_env_deserialize(): - env = Environment() - env.publish_message(message=Message(content="test env serialize")) - ser_env_dict = env.dict() - new_env = Environment(**ser_env_dict) - assert len(new_env.roles) == 0 - assert new_env.memory.storage[0].content == "test env serialize" - assert len(new_env.history) == 25 - - -def test_softwarecompany_deserialize(): - team = SoftwareCompany() - team.hire( + pm = ProductManager() + arch = Architect() + company.hire( [ - ProductManager(), - Architect(), + pm, + arch, ProjectManager(), ] ) - assert len(team.environment.get_roles()) == 3 - ser_team_dict = team.dict() - new_team = SoftwareCompany(**ser_team_dict) - - assert len(new_team.environment.get_roles()) == 3 - assert new_team.environment.get_role('Product Manager') is not None - assert new_team.environment.get_role('Product Manager') is not None - assert new_team.environment.get_role('Architect') is not None + assert len(company.environment.get_roles()) == 3 + ser_company = company.dict() + new_company = Team(**ser_company) + + assert len(new_company.environment.get_roles()) == 3 + assert new_company.environment.get_role(pm.profile) is not None + + new_pm = new_company.environment.get_role(pm.profile) + assert type(new_pm) == ProductManager + assert new_company.environment.get_role(pm.profile) is not None + assert new_company.environment.get_role(arch.profile) is not None + + +def test_team_serdeser(): + company = Team() + company.hire([RoleC()]) + + stg_path = serdeser_path.joinpath("team") + shutil.rmtree(stg_path, ignore_errors=True) + + company.serialize(stg_path=stg_path) + + new_company = Team.deserialize(stg_path) + + assert len(new_company.environment.roles) == 1 + + +@pytest.mark.asyncio +async def test_team_recover(): + idea = "write a snake game" + stg_path = SERDESER_PATH.joinpath("team") + shutil.rmtree(stg_path, ignore_errors=True) + + company = Team() + company.hire([RoleC()]) + company.start_project(idea) + await company.run(n_round=4) + + ser_data = company.dict() + new_company = Team(**ser_data) + assert new_company.environment.memory.count() == 1 + assert type(list(new_company.environment.roles.values())[0]._actions[0]) == ActionOK + + new_company.start_project(idea) + await new_company.run(n_round=4) + + +@pytest.mark.asyncio +async def test_team_recover_save(): + idea = "write a 2048 web game" + stg_path = SERDESER_PATH.joinpath("team") + shutil.rmtree(stg_path, ignore_errors=True) + + company = Team() + company.hire([RoleC()]) + company.start_project(idea) + await company.run(n_round=4) + + new_company = Team.recover(stg_path) + new_company.start_project(idea) + await new_company.run(n_round=4) + + +@pytest.mark.asyncio +async def test_team_recover_multi_roles_save(): + idea = "write a snake game" + stg_path = SERDESER_PATH.joinpath("team") + shutil.rmtree(stg_path, ignore_errors=True) + + company = Team() + company.hire([RoleA(), RoleB()]) + company.start_project(idea) + await company.run(n_round=4) + + new_company = Team.recover(stg_path) + new_company.start_project(idea) + await new_company.run(n_round=4) diff --git a/tests/metagpt/serialize_deserialize/test_wrire_prd.py b/tests/metagpt/serialize_deserialize/test_wrire_prd.py index baa08ed76..96b4d19ad 100644 --- a/tests/metagpt/serialize_deserialize/test_wrire_prd.py +++ b/tests/metagpt/serialize_deserialize/test_wrire_prd.py @@ -25,4 +25,3 @@ async def test_action_deserialize(): assert new_action.name == "" assert new_action.llm == LLM() assert len(await new_action.run([Message(content="write a cli snake game")])) > 0 - diff --git a/tests/metagpt/serialize_deserialize/test_write_code.py b/tests/metagpt/serialize_deserialize/test_write_code.py index 9d659caaf..7f4799014 100644 --- a/tests/metagpt/serialize_deserialize/test_write_code.py +++ b/tests/metagpt/serialize_deserialize/test_write_code.py @@ -43,4 +43,4 @@ async def test_write_code_review_deserialize(): assert new_action.name == "WriteCodeReview" assert new_action.llm == LLM() - await new_action.run(context="write a cli snake game", code =code, filename="test_rewrite_code") \ No newline at end of file + await new_action.run(context="write a cli snake game", code=code, filename="test_rewrite_code") diff --git a/tests/metagpt/test_environment.py b/tests/metagpt/test_environment.py index 3cc2d8a7a..9f69e6189 100644 --- a/tests/metagpt/test_environment.py +++ b/tests/metagpt/test_environment.py @@ -8,17 +8,15 @@ import pytest from pathlib import Path -import shutil from metagpt.actions import BossRequirement from metagpt.environment import Environment from metagpt.logs import logger from metagpt.roles import Architect, ProductManager, Role from metagpt.schema import Message -from tests.metagpt.roles.test_role import RoleA -serdes_path = Path(__file__).absolute().parent.joinpath("../data/serdes_storage") +serdeser_path = Path(__file__).absolute().parent.joinpath("../data/serdeser_storage") @pytest.fixture @@ -27,14 +25,23 @@ def env(): def test_add_role(env: Environment): - role = ProductManager("Alice", "product manager", "create a new product", "limited resources") + role = ProductManager(name="Alice", + profile="product manager", + goal="create a new product", + constraints="limited resources") env.add_role(role) assert env.get_role(role.profile) == role def test_get_roles(env: Environment): - role1 = Role("Alice", "product manager", "create a new product", "limited resources") - role2 = Role("Bob", "engineer", "develop the new product", "short deadline") + role1 = Role(name="Alice", + profile="product manager", + goal="create a new product", + constraints="limited resources") + role2 = Role(name="Bob", + profile="engineer", + goal="develop the new product", + constraints="short deadline") env.add_role(role1) env.add_role(role2) roles = env.get_roles() @@ -43,8 +50,14 @@ def test_get_roles(env: Environment): @pytest.mark.asyncio async def test_publish_and_process_message(env: Environment): - product_manager = ProductManager("Alice", "Product Manager", "做AI Native产品", "资源有限") - architect = Architect("Bob", "Architect", "设计一个可用、高效、较低成本的系统,包括数据结构与接口", "资源有限,需要节省成本") + product_manager = ProductManager(name="Alice", + profile="Product Manager", + goal="做AI Native产品", + constraints="资源有限") + architect = Architect(name="Bob", + profile="Architect", + goal="设计一个可用、高效、较低成本的系统,包括数据结构与接口", + constraints="资源有限,需要节省成本") env.add_roles([product_manager, architect]) env.publish_message(Message(role="BOSS", content="需要一个基于LLM做总结的搜索引擎", cause_by=BossRequirement)) @@ -52,18 +65,3 @@ async def test_publish_and_process_message(env: Environment): await env.run(k=2) logger.info(f"{env.history=}") assert len(env.history) > 10 - - -def test_environment_serdes(): - environment = Environment() - role_a = RoleA() - - shutil.rmtree(serdes_path.joinpath("team"), ignore_errors=True) - - stg_path = serdes_path.joinpath("team/environment") - environment.add_role(role_a) - environment.serialize(stg_path) - - new_env: Environment = Environment() - new_env.deserialize(stg_path) - assert len(new_env.roles) == 1 diff --git a/tests/metagpt/test_role.py b/tests/metagpt/test_role.py deleted file mode 100644 index 11fd804ec..000000000 --- a/tests/metagpt/test_role.py +++ /dev/null @@ -1,14 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- -""" -@Time : 2023/5/11 14:44 -@Author : alexanderwu -@File : test_role.py -""" -from metagpt.roles import Role - - -def test_role_desc(): - i = Role(profile='Sales', desc='Best Seller') - assert i.profile == 'Sales' - assert i._setting.desc == 'Best Seller' diff --git a/tests/metagpt/test_schema.py b/tests/metagpt/test_schema.py index f515326e8..c70c93cfc 100644 --- a/tests/metagpt/test_schema.py +++ b/tests/metagpt/test_schema.py @@ -25,7 +25,7 @@ def test_messages(): assert all([i in text for i in roles]) -def test_message_serdes(): +def test_message_serdeser(): out_mapping = {"field3": (str, ...), "field4": (list[str], ...)} out_data = {"field3": "field3 value3", "field4": ["field4 value1", "field4 value2"]} ic_obj = ActionOutput.create_model_class("code", out_mapping) @@ -37,7 +37,7 @@ def test_message_serdes(): cause_by=WriteCode ) message_dict = serialize_general_message(message) - assert message_dict["cause_by"] == {"action_class": "WriteCode"} + assert message_dict["cause_by"] == {"action_class": "WriteCode", "module_name": "metagpt.actions.write_code"} assert message_dict["instruct_content"] == { "class": "code", "mapping": { diff --git a/tests/metagpt/test_team.py b/tests/metagpt/test_team.py index ab201152c..efd035bb2 100644 --- a/tests/metagpt/test_team.py +++ b/tests/metagpt/test_team.py @@ -2,26 +2,12 @@ # -*- coding: utf-8 -*- # @Desc : unittest of team -from pathlib import Path -import shutil - from metagpt.team import Team - -from tests.metagpt.roles.test_role import RoleA - -serdes_path = Path(__file__).absolute().parent.joinpath("../data/serdes_storage") +from metagpt.roles.project_manager import ProjectManager -def test_team_serdes(): +def test_team(): company = Team() - company.hire([RoleA()]) + company.hire([ProjectManager()]) - stg_path = serdes_path.joinpath("team") - shutil.rmtree(stg_path, ignore_errors=True) - - company.serialize(stg_path=stg_path) - - new_company = Team() - new_company.deserialize(stg_path) - - assert len(new_company.environment.roles) == 1 + assert len(company.environment.roles) == 1 From 5e3607f85bc4fec0ff97c57ff7d866f108e3c9c3 Mon Sep 17 00:00:00 2001 From: better629 Date: Thu, 30 Nov 2023 15:18:24 +0800 Subject: [PATCH 007/167] update environment/message to BaseModel, update the ser&deser of roles/actions --- metagpt/actions/action.py | 29 ++++- metagpt/actions/design_api.py | 8 +- metagpt/actions/project_management.py | 3 +- metagpt/actions/search_and_summarize.py | 15 ++- metagpt/actions/write_code.py | 3 +- metagpt/actions/write_code_review.py | 4 +- metagpt/actions/write_prd.py | 6 +- metagpt/actions/write_test.py | 11 +- metagpt/const.py | 2 +- metagpt/environment.py | 39 +++++-- metagpt/memory/longterm_memory.py | 14 ++- metagpt/memory/memory.py | 79 ++++++++++---- metagpt/roles/architect.py | 4 +- metagpt/roles/customer_service.py | 19 ++-- metagpt/roles/engineer.py | 4 +- metagpt/roles/product_manager.py | 5 +- metagpt/roles/project_manager.py | 4 +- metagpt/roles/qa_engineer.py | 16 +-- metagpt/roles/role.py | 130 +++++++++++++--------- metagpt/roles/sales.py | 31 +++--- metagpt/roles/seacher.py | 21 ++-- metagpt/schema.py | 138 ++++++++++++++---------- metagpt/team.py | 39 ++++--- metagpt/utils/serialize.py | 26 +++-- metagpt/utils/utils.py | 43 ++++++++ startup.py | 17 +-- 26 files changed, 458 insertions(+), 252 deletions(-) diff --git a/metagpt/actions/action.py b/metagpt/actions/action.py index aefe6d39d..7a7f194f4 100644 --- a/metagpt/actions/action.py +++ b/metagpt/actions/action.py @@ -5,8 +5,9 @@ @Author : alexanderwu @File : action.py """ + +from __future__ import annotations import re -from abc import ABC from typing import Optional, Any from pydantic import BaseModel, Field @@ -14,25 +15,43 @@ from tenacity import retry, stop_after_attempt, wait_fixed from metagpt.actions.action_output import ActionOutput from metagpt.llm import LLM +from metagpt.provider.base_gpt_api import BaseGPTAPI from metagpt.logs import logger from metagpt.utils.common import OutputParser from metagpt.utils.custom_decoder import CustomDecoder from metagpt.utils.utils import import_class +action_subclass_registry = {} + + class Action(BaseModel): name: str = "" - llm: LLM = Field(default_factory=LLM) + llm: BaseGPTAPI = Field(default_factory=LLM, exclude=True) context = "" prefix = "" profile = "" desc = "" content: Optional[str] = None instruct_content: Optional[str] = None + + # builtin variables + builtin_class_name: str = "" + + class Config: + arbitrary_types_allowed = True def __init__(self, **kwargs: Any): super().__init__(**kwargs) - + + # deserialize child classes dynamically for inherited `action` + object.__setattr__(self, "builtin_class_name", self.__class__.__name__) + self.__fields__["builtin_class_name"].default = self.__class__.__name__ + + def __init_subclass__(cls, **kwargs: Any) -> None: + super().__init_subclass__(**kwargs) + action_subclass_registry[cls.__name__] = cls + def set_prefix(self, prefix, profile): """Set prefix for later usage""" self.prefix = prefix @@ -52,14 +71,14 @@ class Action(BaseModel): } @classmethod - def deserialize(cls, action_dict: dict): + def deserialize(cls, action_dict: dict) -> "Action": action_class_str = action_dict.pop("action_class") module_name = action_dict.pop("module_name") action_class = import_class(action_class_str, module_name) return action_class(**action_dict) @classmethod - def ser_class(cls): + def ser_class(cls) -> dict: """ serialize class type""" return { "action_class": cls.__name__, diff --git a/metagpt/actions/design_api.py b/metagpt/actions/design_api.py index 30df70ce7..015678baa 100644 --- a/metagpt/actions/design_api.py +++ b/metagpt/actions/design_api.py @@ -13,6 +13,7 @@ from pydantic import Field from metagpt.actions import Action, ActionOutput from metagpt.llm import LLM +from metagpt.provider.base_gpt_api import BaseGPTAPI from metagpt.config import CONFIG from metagpt.const import WORKSPACE_ROOT from metagpt.logs import logger @@ -155,12 +156,11 @@ OUTPUT_MAPPING = { class WriteDesign(Action): name: str = "" context: Optional[str] = None - llm: LLM = Field(default_factory=LLM) + llm: BaseGPTAPI = Field(default_factory=LLM) desc: str = "Based on the PRD, think about the system design, and design the corresponding APIs, " "data structures, library tables, processes, and paths. Please provide your design, feedback " "clearly and in detail." - def recreate_workspace(self, workspace: Path): try: shutil.rmtree(workspace) @@ -168,7 +168,6 @@ class WriteDesign(Action): pass # Folder does not exist, but we don't care workspace.mkdir(parents=True, exist_ok=True) - async def _save_prd(self, docs_path, resources_path, context): prd_file = docs_path / "prd.md" if context[-1].instruct_content and context[-1].instruct_content.dict()["Competitive Quadrant Chart"]: @@ -179,7 +178,6 @@ class WriteDesign(Action): logger.info(f"Saving PRD to {prd_file}") prd_file.write_text(json_to_markdown(context[-1].instruct_content.dict())) - async def _save_system_design(self, docs_path, resources_path, system_design): data_api_design = system_design.instruct_content.dict()[ "Data structures and interface definitions" @@ -193,7 +191,6 @@ class WriteDesign(Action): logger.info(f"Saving System Designs to {system_design_file}") system_design_file.write_text((json_to_markdown(system_design.instruct_content.dict()))) - async def _save(self, context, system_design): if isinstance(system_design, ActionOutput): ws_name = system_design.instruct_content.dict()["Python package name"] @@ -211,7 +208,6 @@ class WriteDesign(Action): logger.error(f"Failed to save PRD {e}") await self._save_system_design(docs_path, resources_path, system_design) - async def run(self, context, format=CONFIG.prompt_format): prompt_template, format_example = get_template(templates, format) prompt = prompt_template.format(context=context, format_example=format_example) diff --git a/metagpt/actions/project_management.py b/metagpt/actions/project_management.py index b72507ee3..cf44906cd 100644 --- a/metagpt/actions/project_management.py +++ b/metagpt/actions/project_management.py @@ -11,6 +11,7 @@ from pydantic import Field from metagpt.actions.action import Action from metagpt.llm import LLM +from metagpt.provider.base_gpt_api import BaseGPTAPI from metagpt.config import CONFIG from metagpt.const import WORKSPACE_ROOT from metagpt.utils.common import CodeParser @@ -168,7 +169,7 @@ OUTPUT_MAPPING = { class WriteTasks(Action): name: str = "CreateTasks" context: Optional[str] = None - llm: LLM = Field(default_factory=LLM) + llm: BaseGPTAPI = Field(default_factory=LLM) def _save(self, context, rsp): try: diff --git a/metagpt/actions/search_and_summarize.py b/metagpt/actions/search_and_summarize.py index 0580303e6..6b0c1f717 100644 --- a/metagpt/actions/search_and_summarize.py +++ b/metagpt/actions/search_and_summarize.py @@ -8,14 +8,15 @@ import pydantic from typing import Optional, Any from pydantic import BaseModel, Field +from pydantic import root_validator from metagpt.actions import Action from metagpt.llm import LLM -from metagpt.config import Config +from metagpt.provider.base_gpt_api import BaseGPTAPI +from metagpt.config import Config, CONFIG from metagpt.logs import logger from metagpt.schema import Message from metagpt.tools.search_engine import SearchEngine -from pydantic import root_validator SEARCH_AND_SUMMARIZE_SYSTEM = """### Requirements 1. Please summarize the latest dialogue based on the reference information (secondary) and dialogue history (primary). Do not include text that is irrelevant to the conversation. @@ -106,13 +107,13 @@ You are a member of a professional butler team and will provide helpful suggesti class SearchAndSummarize(Action): name: str = "" content: Optional[str] = None - llm: None = Field(default_factory=LLM) + llm: BaseGPTAPI = Field(default_factory=LLM) config: None = Field(default_factory=Config) - engine: Optional[str] = None + engine: Optional[str] = CONFIG.search_engine search_func: Optional[str] = None + search_engine: SearchEngine = None result = "" - @root_validator def validate_engine_and_run_func(cls, values): @@ -130,9 +131,7 @@ class SearchAndSummarize(Action): values['search_engine'] = search_engine return values - - - + async def run(self, context: list[Message], system_text=SEARCH_AND_SUMMARIZE_SYSTEM) -> str: print(context) if self.search_engine is None: diff --git a/metagpt/actions/write_code.py b/metagpt/actions/write_code.py index 2dc240591..10487e53a 100644 --- a/metagpt/actions/write_code.py +++ b/metagpt/actions/write_code.py @@ -13,6 +13,7 @@ from tenacity import retry, stop_after_attempt, wait_fixed from metagpt.actions import WriteDesign from metagpt.actions.action import Action from metagpt.llm import LLM +from metagpt.provider.base_gpt_api import BaseGPTAPI from metagpt.const import WORKSPACE_ROOT from metagpt.logs import logger from metagpt.schema import Message @@ -50,7 +51,7 @@ ATTENTION: Use '##' to SPLIT SECTIONS, not '#'. Output format carefully referenc class WriteCode(Action): name: str = "WriteCode" context: Optional[str] = None - llm: LLM = Field(default_factory=LLM) + llm: BaseGPTAPI = Field(default_factory=LLM) def _is_invalid(self, filename): return any(i in filename for i in ["mp3", "wav"]) diff --git a/metagpt/actions/write_code_review.py b/metagpt/actions/write_code_review.py index 3d86d7c63..79e462f76 100644 --- a/metagpt/actions/write_code_review.py +++ b/metagpt/actions/write_code_review.py @@ -12,7 +12,7 @@ from tenacity import retry, stop_after_attempt, wait_fixed from metagpt.llm import LLM from metagpt.actions.action import Action from metagpt.logs import logger -from metagpt.schema import Message +from metagpt.provider.base_gpt_api import BaseGPTAPI from metagpt.utils.common import CodeParser PROMPT_TEMPLATE = """ @@ -67,7 +67,7 @@ FORMAT_EXAMPLE = """ class WriteCodeReview(Action): name: str = "WriteCodeReview" context: Optional[str] = None - llm: LLM = Field(default_factory=LLM) + llm: BaseGPTAPI = Field(default_factory=LLM) @retry(stop=stop_after_attempt(2), wait=wait_fixed(1)) async def write_code(self, prompt): diff --git a/metagpt/actions/write_prd.py b/metagpt/actions/write_prd.py index 660d7fb95..450bed7e7 100644 --- a/metagpt/actions/write_prd.py +++ b/metagpt/actions/write_prd.py @@ -11,6 +11,7 @@ from pydantic import BaseModel, Field from metagpt.actions import Action, ActionOutput from metagpt.llm import LLM +from metagpt.provider.base_gpt_api import BaseGPTAPI from metagpt.actions.search_and_summarize import SearchAndSummarize from metagpt.config import CONFIG from metagpt.logs import logger @@ -224,12 +225,9 @@ OUTPUT_MAPPING = { class WritePRD(Action): name: str = "" content: Optional[str] = None - llm: LLM = Field(default_factory=LLM) + llm: BaseGPTAPI = Field(default_factory=LLM) assistant_search_action: Action = None - def __init__(self, **kwargs): - super().__init__(**kwargs) - async def run(self, requirements, format=CONFIG.prompt_format, *args, **kwargs) -> ActionOutput: # self.assistant_search_action = SearchAndSummarize() if self.assistant_search_action is None: diff --git a/metagpt/actions/write_test.py b/metagpt/actions/write_test.py index 35ff36dc2..6c902444a 100644 --- a/metagpt/actions/write_test.py +++ b/metagpt/actions/write_test.py @@ -5,6 +5,12 @@ @Author : alexanderwu @File : environment.py """ + +from typing import Optional +from pydantic import Field + +from metagpt.llm import LLM +from metagpt.provider.base_gpt_api import BaseGPTAPI from metagpt.actions.action import Action from metagpt.logs import logger from metagpt.utils.common import CodeParser @@ -31,8 +37,9 @@ you should correctly import the necessary classes based on these file locations! class WriteTest(Action): - def __init__(self, name="WriteTest", context=None, llm=None): - super().__init__(name, context, llm) + name: str = "WriteTest" + context: Optional[str] = None + llm: BaseGPTAPI = Field(default_factory=LLM) async def write_code(self, prompt): code_rsp = await self._aask(prompt) diff --git a/metagpt/const.py b/metagpt/const.py index 711546d03..4b063a3dd 100644 --- a/metagpt/const.py +++ b/metagpt/const.py @@ -42,7 +42,7 @@ TMP = PROJECT_ROOT / "tmp" RESEARCH_PATH = DATA_PATH / "research" TUTORIAL_PATH = DATA_PATH / "tutorial_docx" INVOICE_OCR_TABLE_PATH = DATA_PATH / "invoice_table" -SERDES_PATH = WORKSPACE_ROOT / "storage" # TODO to store `storage` under the individual generated project +SERDESER_PATH = WORKSPACE_ROOT / "storage" # TODO to store `storage` under the individual generated project SKILL_DIRECTORY = PROJECT_ROOT / "metagpt/skills" diff --git a/metagpt/environment.py b/metagpt/environment.py index e867ad6fc..bade53f50 100644 --- a/metagpt/environment.py +++ b/metagpt/environment.py @@ -12,7 +12,7 @@ from pathlib import Path from pydantic import BaseModel, Field from metagpt.memory import Memory -from metagpt.roles import Role +from metagpt.roles.role import Role, role_subclass_registry from metagpt.schema import Message from metagpt.utils.utils import read_json_file, write_json_file @@ -30,6 +30,19 @@ class Environment(BaseModel): class Config: arbitrary_types_allowed = True + def __init__(self, **kwargs): + for role_key, role in kwargs.get("roles", {}).items(): + current_role = kwargs["roles"][role_key] + if isinstance(current_role, dict): + item_class_name = current_role.get("builtin_class_name", None) + for name, subclass in role_subclass_registry.items(): + registery_class_name = subclass.__fields__["builtin_class_name"].default + if item_class_name == registery_class_name: + current_role = subclass(**current_role) + break + kwargs["roles"][role_key] = current_role + super().__init__(**kwargs) + def serialize(self, stg_path: Path): roles_path = stg_path.joinpath("roles.json") roles_info = [] @@ -46,33 +59,39 @@ class Environment(BaseModel): history_path = stg_path.joinpath("history.json") write_json_file(history_path, {"content": self.history}) - def deserialize(self, stg_path: Path): + @classmethod + def deserialize(cls, stg_path: Path) -> "Environment": """ stg_path: ./storage/team/environment/ """ roles_path = stg_path.joinpath("roles.json") roles_info = read_json_file(roles_path) + roles = [] for role_info in roles_info: role_class = role_info.get("role_class") role_name = role_info.get("role_name") role_path = stg_path.joinpath(f"roles/{role_class}_{role_name}") role = Role.deserialize(role_path) - - self.add_role(role) + roles.append(role) memory = Memory.deserialize(stg_path) - self.memory = memory - history_path = stg_path.joinpath("history.json") - history = read_json_file(history_path) - self.history = history.get("content") + history = read_json_file(stg_path.joinpath("history.json")) + history = history.get("content") + + environment = Environment(**{ + "memory": memory, + "history": history + }) + environment.add_roles(roles) + return environment def add_role(self, role: Role): - """增加一个在当前环境的角色, 默认为profile/role_profile + """增加一个在当前环境的角色, 默认为profile Add a role in the current environment """ role.set_env(self) # use alias - self.roles[role.role_profile] = role + self.roles[role.profile] = role def add_roles(self, roles: Iterable[Role]): """增加一批在当前环境的角色 diff --git a/metagpt/memory/longterm_memory.py b/metagpt/memory/longterm_memory.py index f8abea5f3..5d149ee7a 100644 --- a/metagpt/memory/longterm_memory.py +++ b/metagpt/memory/longterm_memory.py @@ -2,6 +2,9 @@ # -*- coding: utf-8 -*- # @Desc : the implement of Long-term memory +from typing import Optional +from pydantic import Field + from metagpt.logs import logger from metagpt.memory import Memory from metagpt.memory.memory_storage import MemoryStorage @@ -15,11 +18,12 @@ class LongTermMemory(Memory): - update memory when it changed """ - def __init__(self): - self.memory_storage: MemoryStorage = MemoryStorage() - super(LongTermMemory, self).__init__() - self.rc = None # RoleContext - self.msg_from_recover = False + memory_storage: MemoryStorage = Field(default_factory=MemoryStorage) + rc: Optional["RoleContext"] = None + msg_from_recover: bool = False + + class Config: + arbitrary_types_allowed = True def recover_memory(self, role_id: str, rc: "RoleContext"): messages = self.memory_storage.recover_memory(role_id) diff --git a/metagpt/memory/memory.py b/metagpt/memory/memory.py index a839bb038..c88cc750e 100644 --- a/metagpt/memory/memory.py +++ b/metagpt/memory/memory.py @@ -5,34 +5,65 @@ @Author : alexanderwu @File : memory.py """ +import copy from collections import defaultdict -from typing import Iterable, Type +from typing import Iterable, Type, Union, Optional from pathlib import Path +from pydantic import BaseModel, Field +import json from metagpt.actions import Action from metagpt.schema import Message from metagpt.utils.utils import read_json_file, write_json_file -from metagpt.utils.serialize import serialize_general_message, deserialize_general_message +from metagpt.utils.utils import import_class -class Memory: +class Memory(BaseModel): """The most basic memory: super-memory""" - def __init__(self): - """Initialize an empty storage list and an empty index dictionary""" - self.storage: list[Message] = [] - self.index: dict[Type[Action], list[Message]] = defaultdict(list) + storage: list[Message] = Field(default=[]) + index: dict[Type[Action], list[Message]] = Field(default_factory=defaultdict(list)) + + def __init__(self, **kwargs): + index = kwargs.get("index", {}) + new_index = defaultdict(list) + for action_str, value in index.items(): + action_dict = json.loads(action_str) + action_class = import_class("Action", "metagpt.actions.action") + action_obj = action_class.deser_class(action_dict) + new_index[action_obj] = [Message(**item_dict) for item_dict in value] + kwargs["index"] = new_index + super(Memory, self).__init__(**kwargs) + self.index = new_index + + def dict(self, + *, + include: Optional[Union["AbstractSetIntStr", "MappingIntStrAny"]] = None, + exclude: Optional[Union["AbstractSetIntStr", "MappingIntStrAny"]] = None, + by_alias: bool = False, + skip_defaults: Optional[bool] = None, + exclude_unset: bool = False, + exclude_defaults: bool = False, + exclude_none: bool = False) -> "DictStrAny": + """ overwrite the `dict` to dump dynamic pydantic model""" + obj_dict = super(Memory, self).dict(include=include, + exclude=exclude, + by_alias=by_alias, + skip_defaults=skip_defaults, + exclude_unset=exclude_unset, + exclude_defaults=exclude_defaults, + exclude_none=exclude_none) + new_obj_dict = copy.deepcopy(obj_dict) + new_obj_dict["index"] = {} + for action, value in obj_dict["index"].items(): + action_ser = json.dumps(action.ser_class()) + new_obj_dict["index"][action_ser] = value + return new_obj_dict def serialize(self, stg_path: Path): """ stg_path = ./storage/team/environment/ or ./storage/team/environment/roles/{role_class}_{role_name}/ """ memory_path = stg_path.joinpath("memory.json") - - storage = [] - for message in self.storage: - # msg_dict = message.serialize() - msg_dict = serialize_general_message(message) - storage.append(msg_dict) - + storage = self.dict() write_json_file(memory_path, storage) @classmethod @@ -40,13 +71,8 @@ class Memory: """ stg_path = ./storage/team/environment/ or ./storage/team/environment/roles/{role_class}_{role_name}/""" memory_path = stg_path.joinpath("memory.json") - memory = Memory() - memory_list = read_json_file(memory_path) - for message in memory_list: - # distinguish instruct_content type in message - # msg = Message.deserialize(message) - msg = deserialize_general_message(message) - memory.add(msg) + memory_dict = read_json_file(memory_path) + memory = Memory(**memory_dict) return memory @@ -70,6 +96,16 @@ class Memory: """Return all messages containing a specified content""" return [message for message in self.storage if content in message.content] + def delete_newest(self) -> "Message": + """ delete the newest message from the storage""" + if len(self.storage) > 0: + newest_msg = self.storage.pop() + if newest_msg.cause_by and newest_msg in self.index[newest_msg.cause_by]: + self.index[newest_msg.cause_by].remove(newest_msg) + else: + newest_msg = None + return newest_msg + def delete(self, message: Message): """Delete the specified message from storage, while updating the index""" self.storage.remove(message) @@ -115,4 +151,3 @@ class Memory: continue rsp += self.index[action] return rsp - \ No newline at end of file diff --git a/metagpt/roles/architect.py b/metagpt/roles/architect.py index face22a68..09d52edbe 100644 --- a/metagpt/roles/architect.py +++ b/metagpt/roles/architect.py @@ -22,8 +22,8 @@ class Architect(Role): goal (str): Primary goal or responsibility of the architect. constraints (str): Constraints or guidelines for the architect. """ - name: str = "Bob" - role_profile: str = Field(default="Architect" , alias='profile') + name: str = Field(default="Bob") + profile: str = Field(default="Architect") goal: str = "Design a concise, usable, complete python system" constraints: str = "Try to specify good open source tools as much as possible" diff --git a/metagpt/roles/customer_service.py b/metagpt/roles/customer_service.py index 4547f8190..62792696f 100644 --- a/metagpt/roles/customer_service.py +++ b/metagpt/roles/customer_service.py @@ -5,6 +5,9 @@ @Author : alexanderwu @File : sales.py """ +from typing import Optional +from pydantic import Field + from metagpt.roles import Sales # from metagpt.actions import SearchAndSummarize @@ -24,12 +27,14 @@ DESC = """ class CustomerService(Sales): + + name: str = Field(default="Xiaomei") + profile: str = Field(default="Human customer service") + desc: str = DESC, + + store: Optional[str] = None + def __init__( self, - name="Xiaomei", - profile="Human customer service", - desc=DESC, - store=None - ): - super().__init__(name, profile, desc=desc, store=store) - \ No newline at end of file + **kwargs): + super().__init__(**kwargs) diff --git a/metagpt/roles/engineer.py b/metagpt/roles/engineer.py index 129bedeb8..e90f586f0 100644 --- a/metagpt/roles/engineer.py +++ b/metagpt/roles/engineer.py @@ -60,8 +60,8 @@ class Engineer(Role): use_code_review (bool): Whether to use code review. todos (list): List of tasks. """ - name: str = "Alex" - role_profile: str = Field(default="Engineer", alias='profile') + name: str = Field(default="Alex") + profile: str = Field(default="Engineer") goal: str = "Write elegant, readable, extensible, efficient code" constraints: str = "The code should conform to standards like PEP8 and be modular and maintainable" n_borg: int = 1 diff --git a/metagpt/roles/product_manager.py b/metagpt/roles/product_manager.py index b099fb4d9..6f68fe5ba 100644 --- a/metagpt/roles/product_manager.py +++ b/metagpt/roles/product_manager.py @@ -21,10 +21,11 @@ class ProductManager(Role): goal (str): Goal of the product manager. constraints (str): Constraints or limitations for the product manager. """ - name: str = "Alice" - role_profile: str = Field(default="Product Manager", alias='profile') + name: str = Field(default="Alice") + profile: str = Field(default="Product Manager") goal: str = "Efficiently create a successful product" constraints: str = "" + """ Represents a Product Manager role responsible for product development and management. """ diff --git a/metagpt/roles/project_manager.py b/metagpt/roles/project_manager.py index a2b227f22..c8e785d85 100644 --- a/metagpt/roles/project_manager.py +++ b/metagpt/roles/project_manager.py @@ -22,8 +22,8 @@ class ProjectManager(Role): goal (str): Goal of the project manager. constraints (str): Constraints or limitations for the project manager. """ - name: str = "Eve" - role_profile: str = Field(default="Project Manager", alias='profile') + name: str = Field(default="Eve") + profile: str = Field(default="Project Manager") goal: str = "Improve team efficiency and deliver with quality and quantity" constraints: str = "" diff --git a/metagpt/roles/qa_engineer.py b/metagpt/roles/qa_engineer.py index a763c2ce8..bad3f2409 100644 --- a/metagpt/roles/qa_engineer.py +++ b/metagpt/roles/qa_engineer.py @@ -7,6 +7,7 @@ """ import os from pathlib import Path +from pydantic import Field from metagpt.actions import ( DebugError, @@ -25,21 +26,22 @@ from metagpt.utils.special_tokens import FILENAME_CODE_SEP, MSG_SEP class QaEngineer(Role): + name: str = Field(default="Edward") + profile: str = Field(default="QaEngineer") + goal: str = "Write comprehensive and robust tests to ensure codes will work as expected without bugs" + constraints: str = "The test code you write should conform to code standard like PEP8, be modular, easy to read and maintain" + test_round_allowed: int = 5 + def __init__( self, - name="Edward", - profile="QaEngineer", - goal="Write comprehensive and robust tests to ensure codes will work as expected without bugs", - constraints="The test code you write should conform to code standard like PEP8, be modular, easy to read and maintain", - test_round_allowed=5, + **kwargs ): - super().__init__(name, profile, goal, constraints) + super().__init__(**kwargs) self._init_actions( [WriteTest] ) # FIXME: a bit hack here, only init one action to circumvent _think() logic, will overwrite _think() in future updates self._watch([WriteCode, WriteCodeReview, WriteTest, RunCode, DebugError]) self.test_round = 0 - self.test_round_allowed = test_round_allowed @classmethod def parse_workspace(cls, system_design_msg: Message) -> str: diff --git a/metagpt/roles/role.py b/metagpt/roles/role.py index e9371c2c0..b6332aa4c 100644 --- a/metagpt/roles/role.py +++ b/metagpt/roles/role.py @@ -6,27 +6,29 @@ @File : role.py """ +from __future__ import annotations from enum import Enum from pathlib import Path -from __future__ import annotations from typing import ( Iterable, - Type + Type, + Any ) -import re -from pydantic import BaseModel, Field -from importlib import import_module +from pydantic import BaseModel, Field, validator # from metagpt.environment import Environment from metagpt.config import CONFIG -from metagpt.actions import Action, ActionOutput +from metagpt.actions.action import Action, ActionOutput, action_subclass_registry from metagpt.llm import LLM +from metagpt.provider.base_gpt_api import BaseGPTAPI from metagpt.logs import logger from metagpt.memory import Memory, LongTermMemory from metagpt.schema import Message from metagpt.provider.human_provider import HumanProvider -from metagpt.utils.utils import read_json_file, write_json_file, import_class +from metagpt.utils.utils import read_json_file, write_json_file, import_class, role_raise_decorator +from metagpt.const import SERDESER_PATH + PREFIX_TEMPLATE = """You are a {profile}, named {name}, your goal is {goal}, and the constraint is {constraints}. """ @@ -57,6 +59,7 @@ ROLE_TEMPLATE = """Your response should be based on the previous conversation hi {name}: {result} """ + class RoleReactMode(str, Enum): REACT = "react" BY_ORDER = "by_order" @@ -74,6 +77,7 @@ class RoleSetting(BaseModel): goal: str = "" constraints: str = "" desc: str = "" + is_human: bool = False def __str__(self): return f"{self.name}({self.profile})" @@ -84,10 +88,10 @@ class RoleSetting(BaseModel): class RoleContext(BaseModel): """Role Runtime Context""" - env: 'Environment' = Field(default=None) + env: "Environment" = Field(default=None) memory: Memory = Field(default_factory=Memory) long_term_memory: LongTermMemory = Field(default_factory=LongTermMemory) - state: int = Field(default=0) + state: int = Field(default=-1) # -1 indicates initial or termination state where todo is None todo: Action = Field(default=None) watch: set[Type[Action]] = Field(default_factory=set) news: list[Type[Message]] = Field(default=[]) @@ -112,53 +116,86 @@ class RoleContext(BaseModel): return self.memory.get() +role_subclass_registry = {} + + class Role(BaseModel): """Role/Agent""" - name: str = "" profile: str = "" goal: str = "" constraints: str = "" desc: str = "" - _setting: RoleSetting = Field(default_factory=RoleSetting, alias="_setting") - _setting = RoleSetting(name=name, profile=profile, goal=goal, constraints=constraints) + is_human: bool = False + + _llm: BaseGPTAPI = Field(default_factory=LLM, exclude=True) + _setting: RoleSetting = Field(default_factory=RoleSetting, alias=True) _role_id: str = "" - _states: list = Field(default=[]) - _actions: list = Field(default=[]) - _actions_type: list = Field(default=[]) + _states: list[str] = Field(default=[]) + _actions: list[Action] = Field(default=[]) _rc: RoleContext = RoleContext() - + + # builtin variables + recovered: bool = False # to tag if a recovered role + builtin_class_name: str = "" + _private_attributes = { - "_setting": _setting, + "_llm": LLM() if not is_human else HumanProvider(), "_role_id": _role_id, "_states": [], - "_actions": [], - "_actions_type": [] # 用于记录和序列化 + "_actions": [] } - + class Config: arbitrary_types_allowed = True - - def __init__(self, **kwargs): + exclude = ["_llm"] + + def __init__(self, **kwargs: Any): + for index in range(len(kwargs.get("_actions", []))): + current_action = kwargs["_actions"][index] + if isinstance(current_action, dict): + item_class_name = current_action.get("builtin_class_name", None) + for name, subclass in action_subclass_registry.items(): + registery_class_name = subclass.__fields__["builtin_class_name"].default + if item_class_name == registery_class_name: + current_action = subclass(**current_action) + break + kwargs["_actions"][index] = current_action + super().__init__(**kwargs) + # 关于私有变量的初始化 https://github.com/pydantic/pydantic/issues/655 + self._private_attributes["_llm"] = LLM() if not self.is_human else HumanProvider() + self._private_attributes["_setting"] = RoleSetting(name=self.name, profile=self.profile, goal=self.goal, + desc=self.desc, constraints=self.constraints, + is_human=self.is_human) for key in self._private_attributes.keys(): if key in kwargs: object.__setattr__(self, key, kwargs[key]) - if key =="_setting": - _setting = RoleSetting(**kwargs[key]) - object.__setattr__(self, '_setting', _setting) + if key == "_setting": + setting = RoleSetting(**kwargs[key]) + object.__setattr__(self, "_setting", setting) elif key == "_rc": _rc = RoleContext - object.__setattr__(self, '_rc', _rc) + object.__setattr__(self, "_rc", _rc) else: object.__setattr__(self, key, self._private_attributes[key]) + + # deserialize child classes dynamically for inherited `role` + object.__setattr__(self, "builtin_class_name", self.__class__.__name__) + self.__fields__["builtin_class_name"].default = self.__class__.__name__ + + def __init_subclass__(cls, **kwargs: Any) -> None: + super().__init_subclass__(**kwargs) + role_subclass_registry[cls.__name__] = cls def _reset(self): - object.__setattr__(self, '_states', []) - object.__setattr__(self, '_actions', []) + object.__setattr__(self, "_states", []) + object.__setattr__(self, "_actions", []) - def serialize(self, stg_path: Path): + def serialize(self, stg_path: Path = None): + stg_path = SERDESER_PATH.joinpath(f"team/environment/roles/{self.__class__.__name__}_{self.name}") \ + if stg_path is None else stg_path role_info_path = stg_path.joinpath("role_info.json") role_info = { "role_class": self.__class__.__name__, @@ -207,7 +244,7 @@ class Role(BaseModel): actions = [] actions_info = read_json_file(actions_info_path) for action_info in actions_info: - action = Action.deserialize(action_info) + action = Action.deser_class(action_info) actions.append(action) watches_info_path = stg_path.joinpath("watches/watches_info.json") @@ -238,12 +275,8 @@ class Role(BaseModel): return role - def _reset(self): - self._states = [] - self._actions = [] - def set_recovered(self, recovered: bool = False): - self._recovered = recovered + self.recovered = recovered def set_memory(self, memory: Memory): self._rc.memory = memory @@ -256,7 +289,8 @@ class Role(BaseModel): for idx, action in enumerate(actions): if not isinstance(action, Action): ## 默认初始化 - i = action("", llm=self._llm) + # import pdb; pdb.set_trace() + i = action(name="", llm=self._llm) else: if self._setting.is_human and not isinstance(action.llm, HumanProvider): logger.warning(f"is_human attribute does not take effect," @@ -265,8 +299,6 @@ class Role(BaseModel): i.set_prefix(self._get_prefix(), self.profile) self._actions.append(i) self._states.append(f"{idx}. {action}") - action_title = action.schema()["title"] - self._actions_type.append(action_title) def set_react_mode(self, react_mode: RoleReactMode, max_react_loop: int = 1): self._set_react_mode(react_mode, max_react_loop) @@ -310,19 +342,10 @@ class Role(BaseModel): logger.debug(self._actions) self._rc.todo = self._actions[self._rc.state] if state >= 0 else None - def set_env(self, env: 'Environment'): + def set_env(self, env: "Environment"): """Set the environment in which the role works. The role can talk to the environment and can also receive messages by observing.""" self._rc.env = env - @property - def name(self): - return self._setting.name - - @property - def profile(self): - """Get the role description (position)""" - return self._setting.profile - def _get_prefix(self): """Get the role prefix""" if self._setting.desc: @@ -347,7 +370,7 @@ class Role(BaseModel): logger.debug(f"{prompt=}") if (not next_state.isdigit() and next_state != "-1") \ or int(next_state) not in range(-1, len(self._states)): - logger.warning(f'Invalid answer of state, {next_state=}, will be set to -1') + logger.warning(f"Invalid answer of state, {next_state=}, will be set to -1") next_state = -1 else: next_state = int(next_state) @@ -384,7 +407,7 @@ class Role(BaseModel): news_text = [f"{i.role}: {i.content[:20]}..." for i in self._rc.news] if news_text: - logger.debug(f'{self._setting} observed: {news_text}') + logger.debug(f"{self._setting} observed: {news_text}") return len(self._rc.news) def _publish_message(self, msg): @@ -400,7 +423,7 @@ class Role(BaseModel): Use llm to select actions in _think dynamically """ actions_taken = 0 - rsp = Message("No actions taken yet") # will be overwritten after Role _act + rsp = Message(content="No actions taken yet") # will be overwritten after Role _act while actions_taken < self._rc.max_react_loop: # think await self._think() @@ -410,7 +433,7 @@ class Role(BaseModel): logger.debug(f"{self._setting}: {self._rc.state=}, will do {self._rc.todo}") rsp = await self._act() actions_taken += 1 - return rsp # return output from the last action + return rsp # return output from the last action async def _act_by_order(self) -> Message: """switch action each time by order defined in _init_actions, i.e. _act (Action1) -> _act (Action2) -> ...""" @@ -454,7 +477,8 @@ class Role(BaseModel): def get_memories(self, k=0) -> list[Message]: """A wrapper to return the most recent k memories of this role, return all when k=0""" return self._rc.memory.get(k=k) - + + @role_raise_decorator async def run(self, message=None): """Observe, and think and act based on the results of the observation""" if message: diff --git a/metagpt/roles/sales.py b/metagpt/roles/sales.py index a45ad6f1b..dd360d82a 100644 --- a/metagpt/roles/sales.py +++ b/metagpt/roles/sales.py @@ -5,26 +5,34 @@ @Author : alexanderwu @File : sales.py """ + +from typing import Optional +from pydantic import Field + from metagpt.actions import SearchAndSummarize from metagpt.roles import Role from metagpt.tools import SearchEngineType class Sales(Role): + + name: str = Field(default="Xiaomei") + profile: str = Field(default="Retail sales guide") + desc: str = "I am a sales guide in retail. My name is Xiaomei. I will answer some customer questions next, and I " + "will answer questions only based on the information in the knowledge base." + "If I feel that you can't get the answer from the reference material, then I will directly reply that" + " I don't know, and I won't tell you that this is from the knowledge base," + "but pretend to be what I know. Note that each of my replies will be replied in the tone of a " + "professional guide", + + store: Optional[str] = None + def __init__( self, - name="Xiaomei", - profile="Retail sales guide", - desc="I am a sales guide in retail. My name is Xiaomei. I will answer some customer questions next, and I " - "will answer questions only based on the information in the knowledge base." - "If I feel that you can't get the answer from the reference material, then I will directly reply that" - " I don't know, and I won't tell you that this is from the knowledge base," - "but pretend to be what I know. Note that each of my replies will be replied in the tone of a " - "professional guide", - store=None + **kwargs ): - super().__init__(name, profile, desc=desc) - self._set_store(store) + super().__init__(**kwargs) + self._set_store(self.store) def _set_store(self, store): if store: @@ -32,4 +40,3 @@ class Sales(Role): else: action = SearchAndSummarize() self._init_actions([action]) - \ No newline at end of file diff --git a/metagpt/roles/seacher.py b/metagpt/roles/seacher.py index 0b6e089da..e8f291d0d 100644 --- a/metagpt/roles/seacher.py +++ b/metagpt/roles/seacher.py @@ -5,6 +5,9 @@ @Author : alexanderwu @File : seacher.py """ + +from pydantic import Field + from metagpt.actions import ActionOutput, SearchAndSummarize from metagpt.logs import logger from metagpt.roles import Role @@ -23,14 +26,14 @@ class Searcher(Role): constraints (str): Constraints or limitations for the searcher. engine (SearchEngineType): The type of search engine to use. """ + + name: str = Field(default="Alice") + profile: str = Field(default="Smart Assistant") + goal: str = "Provide search services for users" + constraints: str = "Answer is rich and complete" + engine: SearchEngineType = SearchEngineType.SERPAPI_GOOGLE - def __init__(self, - name: str = 'Alice', - profile: str = 'Smart Assistant', - goal: str = 'Provide search services for users', - constraints: str = 'Answer is rich and complete', - engine=SearchEngineType.SERPAPI_GOOGLE, - **kwargs) -> None: + def __init__(self, **kwargs) -> None: """ Initializes the Searcher role with given attributes. @@ -41,8 +44,8 @@ class Searcher(Role): constraints (str): Constraints or limitations for the searcher. engine (SearchEngineType): The type of search engine to use. """ - super().__init__(name, profile, goal, constraints, **kwargs) - self._init_actions([SearchAndSummarize(engine=engine)]) + super().__init__(**kwargs) + self._init_actions([SearchAndSummarize(engine=self.engine)]) def set_search_func(self, search_func): """Sets a custom search function for the searcher.""" diff --git a/metagpt/schema.py b/metagpt/schema.py index 3374a7241..60aa819b0 100644 --- a/metagpt/schema.py +++ b/metagpt/schema.py @@ -5,18 +5,17 @@ @Author : alexanderwu @File : schema.py """ -from __future__ import annotations from dataclasses import dataclass, field -from typing import Type, TypedDict -import copy +from typing import Type, TypedDict, Union, Optional -from pydantic import BaseModel +from pydantic import BaseModel, Field +from pydantic.main import ModelMetaclass from metagpt.logs import logger -# from metagpt.utils.serialize import actionoutout_schema_to_mapping -# from metagpt.actions.action_output import ActionOutput -# from metagpt.actions.action import Action +from metagpt.utils.serialize import actionoutout_schema_to_mapping, actionoutput_mapping_to_str, \ + actionoutput_str_to_mapping +from metagpt.utils.utils import import_class class RawMessage(TypedDict): @@ -24,16 +23,72 @@ class RawMessage(TypedDict): role: str -@dataclass -class Message: - """list[: ]""" - content: str - instruct_content: BaseModel = field(default=None) - role: str = field(default='user') # system / user / assistant - cause_by: Type["Action"] = field(default="") - sent_from: str = field(default="") - send_to: str = field(default="") - restricted_to: str = field(default="") +class Message(BaseModel): + content: str = "" + instruct_content: BaseModel = Field(default=None) + role: str = "user" # system / user / assistant + cause_by: Type["Action"] = Field(default=None) + sent_from: str = "" + send_to: str = "" + restricted_to: str = "" + + def __init__(self, **kwargs): + instruct_content = kwargs.get("instruct_content", None) + cause_by = kwargs.get("cause_by", None) + if instruct_content and not isinstance(instruct_content, BaseModel): + ic = instruct_content + mapping = actionoutput_str_to_mapping(ic["mapping"]) + + actionoutput_class = import_class("ActionOutput", "metagpt.actions.action_output") + ic_obj = actionoutput_class.create_model_class(class_name=ic["class"], mapping=mapping) + ic_new = ic_obj(**ic["value"]) + kwargs["instruct_content"] = ic_new + if cause_by and not isinstance(cause_by, ModelMetaclass): + action_class = import_class("Action", "metagpt.actions.action") + kwargs["cause_by"] = action_class.deser_class(cause_by) + super(Message, self).__init__(**kwargs) + + def dict(self, + *, + include: Optional[Union["AbstractSetIntStr", "MappingIntStrAny"]] = None, + exclude: Optional[Union["AbstractSetIntStr", "MappingIntStrAny"]] = None, + by_alias: bool = False, + skip_defaults: Optional[bool] = None, + exclude_unset: bool = False, + exclude_defaults: bool = False, + exclude_none: bool = False) -> "DictStrAny": + """ overwrite the `dict` to dump dynamic pydantic model""" + obj_dict = super(Message, self).dict(include=include, + exclude=exclude, + by_alias=by_alias, + skip_defaults=skip_defaults, + exclude_unset=exclude_unset, + exclude_defaults=exclude_defaults, + exclude_none=exclude_none) + ic = self.instruct_content # deal custom-defined action + if ic: + schema = ic.schema() + mapping = actionoutout_schema_to_mapping(schema) + mapping = actionoutput_mapping_to_str(mapping) + + obj_dict["instruct_content"] = {"class": schema["title"], "mapping": mapping, "value": ic.dict()} + cb = self.cause_by + if cb: + obj_dict["cause_by"] = cb.ser_class() + return obj_dict + +# +# +# @dataclass +# class Message: +# """list[: ]""" +# content: str +# instruct_content: BaseModel = field(default=None) +# role: str = field(default='user') # system / user / assistant +# cause_by: Type["Action"] = field(default="") +# sent_from: str = field(default="") +# send_to: str = field(default="") +# restricted_to: str = field(default="") def __str__(self): # prefix = '-'.join([self.role, str(self.cause_by)]) @@ -42,45 +97,16 @@ class Message: def __repr__(self): return self.__str__() - # def serialize(self): - # message_cp: Message = copy.deepcopy(self) - # ic = message_cp.instruct_content - # if ic: - # # model create by pydantic create_model like `pydantic.main.prd`, can't pickle.dump directly - # schema = ic.schema() - # mapping = actionoutout_schema_to_mapping(schema) - # - # message_cp.instruct_content = {"class": schema["title"], "mapping": mapping, "value": ic.dict()} - # cb = message_cp.cause_by - # if cb: - # message_cp.cause_by = cb.serialize() - # - # return message_cp.dict() - # - # @classmethod - # def deserialize(cls, message_dict: dict): - # instruct_content = message_dict.get("instruct_content") - # if instruct_content: - # ic = instruct_content - # ic_obj = ActionOutput.create_model_class(class_name=ic["class"], mapping=ic["mapping"]) - # ic_new = ic_obj(**ic["value"]) - # message_dict.instruct_content = ic_new - # cause_by = message_dict.get("cause_by") - # if cause_by: - # message_dict.cause_by = Action.deserialize(cause_by) - # - # return Message(**message_dict) - - def dict(self): - return { - "content": self.content, - "instruct_content": self.instruct_content, - "role": self.role, - "cause_by": self.cause_by, - "sent_from": self.sent_from, - "send_to": self.send_to, - "restricted_to": self.restricted_to - } + # def dict(self): + # return { + # "content": self.content, + # "instruct_content": self.instruct_content, + # "role": self.role, + # "cause_by": self.cause_by, + # "sent_from": self.sent_from, + # "send_to": self.send_to, + # "restricted_to": self.restricted_to + # } def to_dict(self) -> dict: return { diff --git a/metagpt/team.py b/metagpt/team.py index 3b76e5ff4..795019b92 100644 --- a/metagpt/team.py +++ b/metagpt/team.py @@ -15,7 +15,8 @@ from metagpt.logs import logger from metagpt.roles import Role from metagpt.schema import Message from metagpt.utils.common import NoMoneyException -from metagpt.utils.utils import read_json_file, write_json_file +from metagpt.utils.utils import read_json_file, write_json_file, serialize_decorator +from metagpt.const import SERDESER_PATH class Team(BaseModel): @@ -30,29 +31,35 @@ class Team(BaseModel): class Config: arbitrary_types_allowed = True - def serialize(self, stg_path: Path): + def serialize(self, stg_path: Path = None): + stg_path = SERDESER_PATH.joinpath("team") if stg_path is None else stg_path + team_info_path = stg_path.joinpath("team_info.json") - write_json_file(team_info_path, { - "idea": self.idea, - "investment": self.investment - }) + write_json_file(team_info_path, self.dict(exclude={"environment": True})) - self.environment.serialize(stg_path.joinpath("environment")) + self.environment.serialize(stg_path.joinpath("environment")) # save environment alone - def deserialize(self, stg_path: Path): + @classmethod + def recover(cls, stg_path: Path) -> "Team": + return cls.deserialize(stg_path) + + @classmethod + def deserialize(cls, stg_path: Path) -> "Team": """ stg_path = ./storage/team """ # recover team_info team_info_path = stg_path.joinpath("team_info.json") if not team_info_path.exists(): - logger.error("recover storage not exist, not to recover and continue run the old project.") - team_info = read_json_file(team_info_path) - self.investment = team_info.get("investment", 10.0) - self.idea = team_info.get("idea", "") + raise FileNotFoundError("recover storage meta file `team_info.json` not exist, " + "not to recover and please start a new project.") + + team_info: dict = read_json_file(team_info_path) # recover environment - environment_path = stg_path.joinpath("environment") - self.environment = Environment() - self.environment.deserialize(stg_path=environment_path) + environment = Environment.deserialize(stg_path=stg_path.joinpath("environment")) + team_info.update({"environment": environment}) + + team = Team(**team_info) + return team def hire(self, roles: list[Role]): """Hire roles to cooperate""" @@ -76,6 +83,7 @@ class Team(BaseModel): def _save(self): logger.info(self.json()) + @serialize_decorator async def run(self, n_round=3): """Run company until target round or no money""" while n_round > 0: @@ -85,4 +93,3 @@ class Team(BaseModel): self._check_balance() await self.environment.run() return self.environment.history - \ No newline at end of file diff --git a/metagpt/utils/serialize.py b/metagpt/utils/serialize.py index 56a866f2e..9a7049214 100644 --- a/metagpt/utils/serialize.py +++ b/metagpt/utils/serialize.py @@ -5,9 +5,7 @@ import copy import pickle -from metagpt.actions.action_output import ActionOutput -from metagpt.schema import Message -from metagpt.actions.action import Action +from metagpt.utils.utils import import_class def actionoutout_schema_to_mapping(schema: dict) -> dict: @@ -59,7 +57,7 @@ def actionoutput_str_to_mapping(mapping: dict) -> dict: return new_mapping -def serialize_general_message(message: Message) -> dict: +def serialize_general_message(message: "Message") -> dict: """ serialize Message, not to save""" message_cp = copy.deepcopy(message) ic = message_cp.instruct_content @@ -76,7 +74,7 @@ def serialize_general_message(message: Message) -> dict: return message_cp.dict() -def serialize_message(message: Message): +def serialize_message(message: "Message"): message_cp = copy.deepcopy(message) # avoid `instruct_content` value update by reference ic = message_cp.instruct_content if ic: @@ -90,29 +88,35 @@ def serialize_message(message: Message): return msg_ser -def deserialize_general_message(message_dict: dict) -> Message: +def deserialize_general_message(message_dict: dict) -> "Message": """ deserialize Message, not to load""" instruct_content = message_dict.pop("instruct_content") cause_by = message_dict.pop("cause_by") - message = Message(**message_dict) + message_cls = import_class("Message", "metagpt.schema") + message = message_cls(**message_dict) if instruct_content: ic = instruct_content mapping = actionoutput_str_to_mapping(ic["mapping"]) - ic_obj = ActionOutput.create_model_class(class_name=ic["class"], mapping=mapping) + + actionoutput_class = import_class("ActionOutput", "metagpt.actions.action_output") + ic_obj = actionoutput_class.create_model_class(class_name=ic["class"], mapping=mapping) ic_new = ic_obj(**ic["value"]) message.instruct_content = ic_new if cause_by: - message.cause_by = Action.deser_class(cause_by) + action_class = import_class("Action", "metagpt.actions.action") + message.cause_by = action_class.deser_class(cause_by) return message -def deserialize_message(message_ser: str) -> Message: +def deserialize_message(message_ser: str) -> "Message": message = pickle.loads(message_ser) if message.instruct_content: ic = message.instruct_content - ic_obj = ActionOutput.create_model_class(class_name=ic["class"], mapping=ic["mapping"]) + + actionoutput_class = import_class("ActionOutput", "metagpt.actions.action_output") + ic_obj = actionoutput_class.create_model_class(class_name=ic["class"], mapping=ic["mapping"]) ic_new = ic_obj(**ic["value"]) message.instruct_content = ic_new diff --git a/metagpt/utils/utils.py b/metagpt/utils/utils.py index 81ceea884..1cf618ba0 100644 --- a/metagpt/utils/utils.py +++ b/metagpt/utils/utils.py @@ -6,6 +6,9 @@ from typing import Any import json from pathlib import Path import importlib +import traceback + +from metagpt.logs import logger def read_json_file(json_file: str, encoding=None) -> list[Any]: @@ -39,3 +42,43 @@ def import_class_inst(class_name: str, module_name: str, *args, **kwargs) -> obj a_class = import_class(class_name, module_name) class_inst = a_class(*args, **kwargs) return class_inst + + +def format_trackback_info(limit: int = 2): + return traceback.format_exc(limit=limit) + + +def serialize_decorator(func): + async def wrapper(self, *args, **kwargs): + try: + return await func(self, *args, **kwargs) + except KeyboardInterrupt as kbi: + logger.error(f"KeyboardInterrupt occurs, start to serialize the project, exp:\n{format_trackback_info()}") + self.serialize() # Team.serialize + except Exception as exp: + logger.error(f"Exception occurs, start to serialize the project, exp:\n{format_trackback_info()}") + self.serialize() # Team.serialize + + return wrapper + + +def role_raise_decorator(func): + async def wrapper(self, *args, **kwargs): + try: + return await func(self, *args, **kwargs) + except KeyboardInterrupt as kbi: + logger.error(f"KeyboardInterrupt: {kbi} occurs, start to serialize the project") + if self._rc.env: + newest_msgs = self._rc.env.memory.get(1) + if len(newest_msgs) > 0: + self._rc.memory.delete(newest_msgs[0]) + except Exception as exp: + if self._rc.env: + newest_msgs = self._rc.env.memory.get(1) + if len(newest_msgs) > 0: + logger.warning("There is a exception in role's execution, in order to resume, " + "we delete the newest role communication message in the role's memory.") + self._rc.memory.delete(newest_msgs[0]) # remove newest msg of the role to make it observed again + raise Exception(format_trackback_info(limit=None)) # raise again to make it captured outside + + return wrapper diff --git a/startup.py b/startup.py index 9f753d553..c4928a1b5 100644 --- a/startup.py +++ b/startup.py @@ -1,10 +1,11 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- + +from typing import Optional import asyncio - import fire +from pathlib import Path -from metagpt.const import SERDES_PATH from metagpt.roles import ( Architect, Engineer, @@ -22,11 +23,11 @@ async def startup( code_review: bool = False, run_tests: bool = False, implement: bool = True, - recover_path: bool = False, + recover_path: Optional[str] = None, ): """Run a startup. Be a boss.""" - company = Team() if not recover_path: + company = Team() company.hire( [ ProductManager(), @@ -45,8 +46,12 @@ async def startup( # (bug fixing capability comes soon!) company.hire([QaEngineer()]) else: - stg_path = SERDES_PATH.joinpath("team") - company.deserialize(stg_path=stg_path) + # # stg_path = SERDESER_PATH.joinpath("team") + stg_path = Path(recover_path) + if not stg_path.exists() or not str(stg_path).endswith("team"): + raise FileNotFoundError(f"{recover_path} not exists or not endswith `team`") + + company = Team.recover(stg_path=stg_path) idea = company.idea # use original idea company.invest(investment) From caacfcff7a541b7e69928cb0ed078fd98f89b55b Mon Sep 17 00:00:00 2001 From: better629 Date: Thu, 30 Nov 2023 19:30:02 +0800 Subject: [PATCH 008/167] fix ut of serialize_deserialize --- .../serialize_deserialize/test_action.py | 3 +-- .../test_product_manager.py | 1 - .../serialize_deserialize/test_role.py | 10 ++++++++- .../test_serdeser_base.py | 21 +++++++++++++------ .../serialize_deserialize/test_team.py | 2 +- .../serialize_deserialize/test_wrire_prd.py | 4 ++-- .../serialize_deserialize/test_write_code.py | 2 -- .../test_write_design.py | 3 +-- 8 files changed, 29 insertions(+), 17 deletions(-) diff --git a/tests/metagpt/serialize_deserialize/test_action.py b/tests/metagpt/serialize_deserialize/test_action.py index b624dff5a..0138d41ce 100644 --- a/tests/metagpt/serialize_deserialize/test_action.py +++ b/tests/metagpt/serialize_deserialize/test_action.py @@ -13,14 +13,13 @@ def test_action_serialize(): action = Action() ser_action_dict = action.dict() assert "name" in ser_action_dict - assert "llm" in ser_action_dict + assert "llm" not in ser_action_dict @pytest.mark.asyncio async def test_action_deserialize(): action = Action() serialized_data = action.dict() - assert isinstance(serialized_data["llm"], OpenAIGPTAPI) new_action = Action(**serialized_data) diff --git a/tests/metagpt/serialize_deserialize/test_product_manager.py b/tests/metagpt/serialize_deserialize/test_product_manager.py index 54584cf96..25bc07a11 100644 --- a/tests/metagpt/serialize_deserialize/test_product_manager.py +++ b/tests/metagpt/serialize_deserialize/test_product_manager.py @@ -14,7 +14,6 @@ async def test_product_manager_deserialize(): role = ProductManager() ser_role_dict = role.dict(by_alias=True) new_role = ProductManager(**ser_role_dict) - # new_role = ProductManager().deserialize(ser_role_dict) assert new_role.name == "Alice" assert len(new_role._actions) == 1 diff --git a/tests/metagpt/serialize_deserialize/test_role.py b/tests/metagpt/serialize_deserialize/test_role.py index f260dea3a..c21b9cc2e 100644 --- a/tests/metagpt/serialize_deserialize/test_role.py +++ b/tests/metagpt/serialize_deserialize/test_role.py @@ -17,7 +17,15 @@ from metagpt.const import SERDESER_PATH from metagpt.roles.engineer import Engineer from metagpt.utils.utils import format_trackback_info -from tests.metagpt.serialize_deserialize.test_serdeser_base import RoleC, serdeser_path +from tests.metagpt.serialize_deserialize.test_serdeser_base import RoleA, RoleB, RoleC, serdeser_path + + +def test_roles(): + role_a = RoleA() + assert len(role_a._rc.watch) == 1 + role_b = RoleB() + assert len(role_a._rc.watch) == 1 + assert len(role_b._rc.watch) == 1 def test_role_serialize(): diff --git a/tests/metagpt/serialize_deserialize/test_serdeser_base.py b/tests/metagpt/serialize_deserialize/test_serdeser_base.py index 35bad6cd9..00d894b3d 100644 --- a/tests/metagpt/serialize_deserialize/test_serdeser_base.py +++ b/tests/metagpt/serialize_deserialize/test_serdeser_base.py @@ -8,6 +8,7 @@ from pathlib import Path from metagpt.actions.action import Action from metagpt.roles.role import Role, RoleReactMode from metagpt.actions.add_requirement import BossRequirement +from metagpt.actions.action_output import ActionOutput serdeser_path = Path(__file__).absolute().parent.joinpath("../../data/serdeser_storage") @@ -22,21 +23,27 @@ class MockMessage(BaseModel): class ActionPass(Action): name: str = "ActionPass" - async def run(self, messages: list["Message"]): - return "pass" + async def run(self, messages: list["Message"]) -> ActionOutput: + output_mapping = { + "result": (str, ...) + } + pass_class = ActionOutput.create_model_class("pass", output_mapping) + pass_output = ActionOutput("ActionPass run passed", pass_class(**{"result": "pass result"})) + + return pass_output class ActionOK(Action): name: str = "ActionOK" - async def run(self, messages: list["Message"]): + async def run(self, messages: list["Message"]) -> str: return "ok" class ActionRaise(Action): name: str = "ActionRaise" - async def run(self, messages: list["Message"]): + async def run(self, messages: list["Message"]) -> str: raise RuntimeError("parse error in ActionRaise") @@ -48,7 +55,8 @@ class RoleA(Role): constraints: str = "RoleA's constraints" def __init__(self, **kwargs): - super(RoleA, self).__init__(**kwargs) + # super(RoleA, self).__init__(**kwargs) + super().__init__(**kwargs) self._init_actions([ActionPass]) self._watch([BossRequirement]) @@ -63,7 +71,8 @@ class RoleB(Role): constraints: str = "RoleB's constraints" def __init__(self, **kwargs): - super(RoleB, self).__init__(**kwargs) + # super(RoleB, self).__init__(**kwargs) + super().__init__(**kwargs) self._init_actions([ActionOK, ActionRaise]) self._watch([ActionPass]) self._rc.react_mode = RoleReactMode.BY_ORDER diff --git a/tests/metagpt/serialize_deserialize/test_team.py b/tests/metagpt/serialize_deserialize/test_team.py index e9122ebc0..b8972135b 100644 --- a/tests/metagpt/serialize_deserialize/test_team.py +++ b/tests/metagpt/serialize_deserialize/test_team.py @@ -11,7 +11,7 @@ from metagpt.roles import ProjectManager, ProductManager, Architect from metagpt.team import Team from metagpt.const import SERDESER_PATH -from tests.metagpt.serialize_deserialize.test_serdeser_base import RoleA, RoleB, RoleC, serdeser_path +from tests.metagpt.serialize_deserialize.test_serdeser_base import RoleA, RoleB, RoleC, serdeser_path, ActionOK def test_team_deserialize(): diff --git a/tests/metagpt/serialize_deserialize/test_wrire_prd.py b/tests/metagpt/serialize_deserialize/test_wrire_prd.py index 96b4d19ad..05a86cb7f 100644 --- a/tests/metagpt/serialize_deserialize/test_wrire_prd.py +++ b/tests/metagpt/serialize_deserialize/test_wrire_prd.py @@ -21,7 +21,7 @@ async def test_action_deserialize(): action = WritePRD() serialized_data = action.dict() new_action = WritePRD(**serialized_data) - # new_action = WritePRD().deserialize(serialized_data) assert new_action.name == "" assert new_action.llm == LLM() - assert len(await new_action.run([Message(content="write a cli snake game")])) > 0 + action_output = await new_action.run([Message(content="write a cli snake game")]) + assert len(action_output.content) > 0 diff --git a/tests/metagpt/serialize_deserialize/test_write_code.py b/tests/metagpt/serialize_deserialize/test_write_code.py index 7f4799014..4e3b712c0 100644 --- a/tests/metagpt/serialize_deserialize/test_write_code.py +++ b/tests/metagpt/serialize_deserialize/test_write_code.py @@ -27,7 +27,6 @@ async def test_write_code_deserialize(): action = WriteCode() serialized_data = action.dict() new_action = WriteCode(**serialized_data) - # new_action = WriteCode().deserialize(serialized_data) assert new_action.name == "WriteCode" assert new_action.llm == LLM() await new_action.run(context="write a cli snake game", filename="test_code") @@ -38,7 +37,6 @@ async def test_write_code_review_deserialize(): action = WriteCodeReview() serialized_data = action.dict() new_action = WriteCodeReview(**serialized_data) - # new_action = WriteCodeReview().deserialize(serialized_data) code = await WriteCode().run(context="write a cli snake game", filename="test_code") assert new_action.name == "WriteCodeReview" diff --git a/tests/metagpt/serialize_deserialize/test_write_design.py b/tests/metagpt/serialize_deserialize/test_write_design.py index e6e236676..5b2a30ed3 100644 --- a/tests/metagpt/serialize_deserialize/test_write_design.py +++ b/tests/metagpt/serialize_deserialize/test_write_design.py @@ -26,7 +26,7 @@ def test_write_task_serialize(): async def test_write_design_deserialize(): action = WriteDesign() serialized_data = action.dict() - new_action = WriteDesign().deserialize(serialized_data) + new_action = WriteDesign(**serialized_data) assert new_action.name == "" assert new_action.llm == LLM() await new_action.run(context="write a cli snake game") @@ -37,7 +37,6 @@ async def test_write_task_deserialize(): action = WriteTasks() serialized_data = action.dict() new_action = WriteTasks(**serialized_data) - # new_action = WriteTasks().deserialize(serialized_data) assert new_action.name == "CreateTasks" assert new_action.llm == LLM() await new_action.run(context="write a cli snake game") From c70c8358d334d8297a0a33b95223d604c84096cd Mon Sep 17 00:00:00 2001 From: better629 Date: Thu, 30 Nov 2023 19:31:26 +0800 Subject: [PATCH 009/167] fix actions/roles ser&deser --- metagpt/actions/search_and_summarize.py | 16 +++++++--------- metagpt/actions/write_prd.py | 15 ++++++--------- metagpt/roles/role.py | 20 ++++++++++++++------ metagpt/utils/utils.py | 4 +++- 4 files changed, 30 insertions(+), 25 deletions(-) diff --git a/metagpt/actions/search_and_summarize.py b/metagpt/actions/search_and_summarize.py index 6b0c1f717..32444b302 100644 --- a/metagpt/actions/search_and_summarize.py +++ b/metagpt/actions/search_and_summarize.py @@ -117,23 +117,21 @@ class SearchAndSummarize(Action): @root_validator def validate_engine_and_run_func(cls, values): - engine = values.get('engine') - search_func = values.get('search_func') + engine = values.get("engine") + search_func = values.get("search_func") config = Config() if engine is None: engine = config.search_engine - config_data = { - 'engine': engine, - 'run_func': search_func - } - search_engine = SearchEngine(**config_data) + try: + search_engine = SearchEngine(engine=engine, run_func=search_func) + except pydantic.ValidationError: + search_engine = None - values['search_engine'] = search_engine + values["search_engine"] = search_engine return values async def run(self, context: list[Message], system_text=SEARCH_AND_SUMMARIZE_SYSTEM) -> str: - print(context) if self.search_engine is None: logger.warning("Configure one of SERPAPI_API_KEY, SERPER_API_KEY, GOOGLE_API_KEY to unlock full feature") return "" diff --git a/metagpt/actions/write_prd.py b/metagpt/actions/write_prd.py index 450bed7e7..86f0ad9a6 100644 --- a/metagpt/actions/write_prd.py +++ b/metagpt/actions/write_prd.py @@ -226,17 +226,14 @@ class WritePRD(Action): name: str = "" content: Optional[str] = None llm: BaseGPTAPI = Field(default_factory=LLM) - assistant_search_action: Action = None async def run(self, requirements, format=CONFIG.prompt_format, *args, **kwargs) -> ActionOutput: - # self.assistant_search_action = SearchAndSummarize() - if self.assistant_search_action is None: - self.assistant_search_action = SearchAndSummarize() - # self.assistant_search_action = SearchAndSummarize() - rsp = await self.assistant_search_action.run(context=requirements) - info = f"### Search Results\n{self.assistant_search_action.result}\n\n### Search Summary\n{rsp}" - if self.assistant_search_action.result: - logger.info(self.assistant_search_action.result) + sas = SearchAndSummarize() + # rsp = await sas.run(context=requirements, system_text=SEARCH_AND_SUMMARIZE_SYSTEM_EN_US) + rsp = "" + info = f"### Search Results\n{sas.result}\n\n### Search Summary\n{rsp}" + if sas.result: + logger.info(sas.result) logger.info(rsp) prompt_template, format_example = get_template(templates, format) diff --git a/metagpt/roles/role.py b/metagpt/roles/role.py index b6332aa4c..38f564caa 100644 --- a/metagpt/roles/role.py +++ b/metagpt/roles/role.py @@ -88,7 +88,7 @@ class RoleSetting(BaseModel): class RoleContext(BaseModel): """Role Runtime Context""" - env: "Environment" = Field(default=None) + env: "Environment" = Field(default=None, exclude=True) memory: Memory = Field(default_factory=Memory) long_term_memory: LongTermMemory = Field(default_factory=LongTermMemory) state: int = Field(default=-1) # -1 indicates initial or termination state where todo is None @@ -133,7 +133,7 @@ class Role(BaseModel): _role_id: str = "" _states: list[str] = Field(default=[]) _actions: list[Action] = Field(default=[]) - _rc: RoleContext = RoleContext() + _rc: RoleContext = Field(default=RoleContext, exclude=True) # builtin variables recovered: bool = False # to tag if a recovered role @@ -143,7 +143,8 @@ class Role(BaseModel): "_llm": LLM() if not is_human else HumanProvider(), "_role_id": _role_id, "_states": [], - "_actions": [] + "_actions": [], + "_rc": RoleContext() } class Config: @@ -169,6 +170,8 @@ class Role(BaseModel): self._private_attributes["_setting"] = RoleSetting(name=self.name, profile=self.profile, goal=self.goal, desc=self.desc, constraints=self.constraints, is_human=self.is_human) + self._private_attributes["_role_id"] = str(self._setting) + for key in self._private_attributes.keys(): if key in kwargs: object.__setattr__(self, key, kwargs[key]) @@ -176,10 +179,15 @@ class Role(BaseModel): setting = RoleSetting(**kwargs[key]) object.__setattr__(self, "_setting", setting) elif key == "_rc": - _rc = RoleContext + _rc = RoleContext() object.__setattr__(self, "_rc", _rc) else: - object.__setattr__(self, key, self._private_attributes[key]) + if key == "_rc": + # # Warning, if use self._private_attributes["_rc"], + # # self._rc will be a shared object between roles, so init one or reset it inside `_reset` + object.__setattr__(self, key, RoleContext()) + else: + object.__setattr__(self, key, self._private_attributes[key]) # deserialize child classes dynamically for inherited `role` object.__setattr__(self, "builtin_class_name", self.__class__.__name__) @@ -192,6 +200,7 @@ class Role(BaseModel): def _reset(self): object.__setattr__(self, "_states", []) object.__setattr__(self, "_actions", []) + # object.__setattr__(self, "_rc", RoleContext()) def serialize(self, stg_path: Path = None): stg_path = SERDESER_PATH.joinpath(f"team/environment/roles/{self.__class__.__name__}_{self.name}") \ @@ -289,7 +298,6 @@ class Role(BaseModel): for idx, action in enumerate(actions): if not isinstance(action, Action): ## 默认初始化 - # import pdb; pdb.set_trace() i = action(name="", llm=self._llm) else: if self._setting.is_human and not isinstance(action.llm, HumanProvider): diff --git a/metagpt/utils/utils.py b/metagpt/utils/utils.py index 1cf618ba0..b72dabf7e 100644 --- a/metagpt/utils/utils.py +++ b/metagpt/utils/utils.py @@ -51,7 +51,9 @@ def format_trackback_info(limit: int = 2): def serialize_decorator(func): async def wrapper(self, *args, **kwargs): try: - return await func(self, *args, **kwargs) + result = await func(self, *args, **kwargs) + self.serialize() # Team.serialize + return result except KeyboardInterrupt as kbi: logger.error(f"KeyboardInterrupt occurs, start to serialize the project, exp:\n{format_trackback_info()}") self.serialize() # Team.serialize From 6208400f71ee926ed422aed9ed2cc160d7a0de4e Mon Sep 17 00:00:00 2001 From: better629 Date: Thu, 30 Nov 2023 21:42:09 +0800 Subject: [PATCH 010/167] fix role._rc init --- metagpt/environment.py | 4 ++++ metagpt/roles/role.py | 11 ++++++----- .../serialize_deserialize/test_team.py | 19 ++++++++++++++++--- 3 files changed, 26 insertions(+), 8 deletions(-) diff --git a/metagpt/environment.py b/metagpt/environment.py index bade53f50..bff12210d 100644 --- a/metagpt/environment.py +++ b/metagpt/environment.py @@ -31,6 +31,7 @@ class Environment(BaseModel): arbitrary_types_allowed = True def __init__(self, **kwargs): + roles = [] for role_key, role in kwargs.get("roles", {}).items(): current_role = kwargs["roles"][role_key] if isinstance(current_role, dict): @@ -41,8 +42,11 @@ class Environment(BaseModel): current_role = subclass(**current_role) break kwargs["roles"][role_key] = current_role + roles.append(current_role) super().__init__(**kwargs) + self.add_roles(roles) # add_roles again to init the Role.set_env + def serialize(self, stg_path: Path): roles_path = stg_path.joinpath("roles.json") roles_info = [] diff --git a/metagpt/roles/role.py b/metagpt/roles/role.py index 38f564caa..b78597d01 100644 --- a/metagpt/roles/role.py +++ b/metagpt/roles/role.py @@ -88,13 +88,14 @@ class RoleSetting(BaseModel): class RoleContext(BaseModel): """Role Runtime Context""" + # # env exclude=True to avoid `RecursionError: maximum recursion depth exceeded in comparison` env: "Environment" = Field(default=None, exclude=True) memory: Memory = Field(default_factory=Memory) - long_term_memory: LongTermMemory = Field(default_factory=LongTermMemory) + long_term_memory: LongTermMemory = Field(default_factory=LongTermMemory, exclude=True) # TODO not used now state: int = Field(default=-1) # -1 indicates initial or termination state where todo is None todo: Action = Field(default=None) watch: set[Type[Action]] = Field(default_factory=set) - news: list[Type[Message]] = Field(default=[]) + news: list[Type[Message]] = Field(default=[], exclude=True) # TODO not used react_mode: RoleReactMode = RoleReactMode.REACT # see `Role._set_react_mode` for definitions of the following two attributes max_react_loop: int = 1 @@ -128,12 +129,12 @@ class Role(BaseModel): desc: str = "" is_human: bool = False - _llm: BaseGPTAPI = Field(default_factory=LLM, exclude=True) + _llm: BaseGPTAPI = Field(default_factory=LLM) _setting: RoleSetting = Field(default_factory=RoleSetting, alias=True) _role_id: str = "" _states: list[str] = Field(default=[]) _actions: list[Action] = Field(default=[]) - _rc: RoleContext = Field(default=RoleContext, exclude=True) + _rc: RoleContext = Field(default=RoleContext) # builtin variables recovered: bool = False # to tag if a recovered role @@ -179,7 +180,7 @@ class Role(BaseModel): setting = RoleSetting(**kwargs[key]) object.__setattr__(self, "_setting", setting) elif key == "_rc": - _rc = RoleContext() + _rc = RoleContext(**kwargs["_rc"]) object.__setattr__(self, "_rc", _rc) else: if key == "_rc": diff --git a/tests/metagpt/serialize_deserialize/test_team.py b/tests/metagpt/serialize_deserialize/test_team.py index b8972135b..e5ec20f2e 100644 --- a/tests/metagpt/serialize_deserialize/test_team.py +++ b/tests/metagpt/serialize_deserialize/test_team.py @@ -39,7 +39,7 @@ def test_team_deserialize(): assert new_company.environment.get_role(arch.profile) is not None -def test_team_serdeser(): +def test_team_serdeser_save(): company = Team() company.hire([RoleC()]) @@ -60,12 +60,19 @@ async def test_team_recover(): shutil.rmtree(stg_path, ignore_errors=True) company = Team() - company.hire([RoleC()]) + role_c = RoleC() + company.hire([role_c]) company.start_project(idea) await company.run(n_round=4) ser_data = company.dict() new_company = Team(**ser_data) + + new_role_c = new_company.environment.get_role(role_c.profile) + assert new_role_c._rc.memory == role_c._rc.memory + assert new_role_c._rc.env != role_c._rc.env # due to Action raise, role's memory has been changed. + assert new_role_c._rc.env.memory == role_c._rc.env.memory + assert new_company.environment.memory.count() == 1 assert type(list(new_company.environment.roles.values())[0]._actions[0]) == ActionOK @@ -80,11 +87,17 @@ async def test_team_recover_save(): shutil.rmtree(stg_path, ignore_errors=True) company = Team() - company.hire([RoleC()]) + role_c = RoleC() + company.hire([role_c]) company.start_project(idea) await company.run(n_round=4) new_company = Team.recover(stg_path) + new_role_c = new_company.environment.get_role(role_c.profile) + assert new_role_c._rc.memory == role_c._rc.memory + assert new_role_c._rc.env != role_c._rc.env # due to Action raise, role's memory has been changed. + assert new_role_c._rc.env.memory == role_c._rc.env.memory + new_company.start_project(idea) await new_company.run(n_round=4) From f563b2c60809d45db87387956586acd18ddc9201 Mon Sep 17 00:00:00 2001 From: better629 Date: Fri, 1 Dec 2023 14:43:45 +0800 Subject: [PATCH 011/167] simplify some ser&desr code --- metagpt/actions/action.py | 20 ++----- metagpt/environment.py | 6 +- metagpt/memory/memory.py | 18 +----- metagpt/roles/role.py | 114 ++++++++++++++------------------------ metagpt/schema.py | 42 +------------- 5 files changed, 54 insertions(+), 146 deletions(-) diff --git a/metagpt/actions/action.py b/metagpt/actions/action.py index 7a7f194f4..692a2a6e5 100644 --- a/metagpt/actions/action.py +++ b/metagpt/actions/action.py @@ -52,6 +52,12 @@ class Action(BaseModel): super().__init_subclass__(**kwargs) action_subclass_registry[cls.__name__] = cls + def dict(self, *args, **kwargs) -> "DictStrAny": + obj_dict = super(Action, self).dict(*args, **kwargs) + if "llm" in obj_dict: + obj_dict.pop("llm") + return obj_dict + def set_prefix(self, prefix, profile): """Set prefix for later usage""" self.prefix = prefix @@ -63,20 +69,6 @@ class Action(BaseModel): def __repr__(self): return self.__str__() - def serialize(self): - return { - "action_class": self.__class__.__name__, - "module_name": self.__module__, - "name": self.name - } - - @classmethod - def deserialize(cls, action_dict: dict) -> "Action": - action_class_str = action_dict.pop("action_class") - module_name = action_dict.pop("module_name") - action_class = import_class(action_class_str, module_name) - return action_class(**action_dict) - @classmethod def ser_class(cls) -> dict: """ serialize class type""" diff --git a/metagpt/environment.py b/metagpt/environment.py index bff12210d..3174cfc10 100644 --- a/metagpt/environment.py +++ b/metagpt/environment.py @@ -70,10 +70,8 @@ class Environment(BaseModel): roles_info = read_json_file(roles_path) roles = [] for role_info in roles_info: - role_class = role_info.get("role_class") - role_name = role_info.get("role_name") - - role_path = stg_path.joinpath(f"roles/{role_class}_{role_name}") + # role stored in ./environment/roles/{role_class}_{role_name} + role_path = stg_path.joinpath(f'roles/{role_info.get("role_class")}_{role_info.get("role_name")}') role = Role.deserialize(role_path) roles.append(role) diff --git a/metagpt/memory/memory.py b/metagpt/memory/memory.py index c88cc750e..ed30cde18 100644 --- a/metagpt/memory/memory.py +++ b/metagpt/memory/memory.py @@ -36,23 +36,9 @@ class Memory(BaseModel): super(Memory, self).__init__(**kwargs) self.index = new_index - def dict(self, - *, - include: Optional[Union["AbstractSetIntStr", "MappingIntStrAny"]] = None, - exclude: Optional[Union["AbstractSetIntStr", "MappingIntStrAny"]] = None, - by_alias: bool = False, - skip_defaults: Optional[bool] = None, - exclude_unset: bool = False, - exclude_defaults: bool = False, - exclude_none: bool = False) -> "DictStrAny": + def dict(self, *args, **kwargs) -> "DictStrAny": """ overwrite the `dict` to dump dynamic pydantic model""" - obj_dict = super(Memory, self).dict(include=include, - exclude=exclude, - by_alias=by_alias, - skip_defaults=skip_defaults, - exclude_unset=exclude_unset, - exclude_defaults=exclude_defaults, - exclude_none=exclude_none) + obj_dict = super(Memory, self).dict(*args, **kwargs) new_obj_dict = copy.deepcopy(obj_dict) new_obj_dict["index"] = {} for action, value in obj_dict["index"].items(): diff --git a/metagpt/roles/role.py b/metagpt/roles/role.py index b78597d01..4e669772e 100644 --- a/metagpt/roles/role.py +++ b/metagpt/roles/role.py @@ -93,7 +93,7 @@ class RoleContext(BaseModel): memory: Memory = Field(default_factory=Memory) long_term_memory: LongTermMemory = Field(default_factory=LongTermMemory, exclude=True) # TODO not used now state: int = Field(default=-1) # -1 indicates initial or termination state where todo is None - todo: Action = Field(default=None) + todo: Action = Field(default=None, exclude=True) watch: set[Type[Action]] = Field(default_factory=set) news: list[Type[Message]] = Field(default=[], exclude=True) # TODO not used react_mode: RoleReactMode = RoleReactMode.REACT # see `Role._set_react_mode` for definitions of the following two attributes @@ -101,7 +101,25 @@ class RoleContext(BaseModel): class Config: arbitrary_types_allowed = True - + + def __init__(self, **kwargs): + watch_info = kwargs.get("watch", set()) + watch = set() + for item in watch_info: + action = Action.deser_class(item) + watch.update([action]) + kwargs["watch"] = watch + super(RoleContext, self).__init__(**kwargs) + + def dict(self, *args, **kwargs) -> "DictStrAny": + obj_dict = super(RoleContext, self).dict(*args, **kwargs) + watch = obj_dict.get("watch", set()) + watch_info = [] + for item in watch: + watch_info.append(item.ser_class()) + obj_dict["watch"] = watch_info + return obj_dict + def check(self, role_id: str): if hasattr(CONFIG, "long_term_memory") and CONFIG.long_term_memory: self.long_term_memory.recover_memory(role_id, self) @@ -130,7 +148,6 @@ class Role(BaseModel): is_human: bool = False _llm: BaseGPTAPI = Field(default_factory=LLM) - _setting: RoleSetting = Field(default_factory=RoleSetting, alias=True) _role_id: str = "" _states: list[str] = Field(default=[]) _actions: list[Action] = Field(default=[]) @@ -168,18 +185,12 @@ class Role(BaseModel): # 关于私有变量的初始化 https://github.com/pydantic/pydantic/issues/655 self._private_attributes["_llm"] = LLM() if not self.is_human else HumanProvider() - self._private_attributes["_setting"] = RoleSetting(name=self.name, profile=self.profile, goal=self.goal, - desc=self.desc, constraints=self.constraints, - is_human=self.is_human) self._private_attributes["_role_id"] = str(self._setting) for key in self._private_attributes.keys(): if key in kwargs: object.__setattr__(self, key, kwargs[key]) - if key == "_setting": - setting = RoleSetting(**kwargs[key]) - object.__setattr__(self, "_setting", setting) - elif key == "_rc": + if key == "_rc": _rc = RoleContext(**kwargs["_rc"]) object.__setattr__(self, "_rc", _rc) else: @@ -203,41 +214,23 @@ class Role(BaseModel): object.__setattr__(self, "_actions", []) # object.__setattr__(self, "_rc", RoleContext()) + @property + def _setting(self): + return f"{self.name}({self.profile})" + def serialize(self, stg_path: Path = None): stg_path = SERDESER_PATH.joinpath(f"team/environment/roles/{self.__class__.__name__}_{self.name}") \ if stg_path is None else stg_path - role_info_path = stg_path.joinpath("role_info.json") - role_info = { + + role_info = self.dict(exclude={"_rc": {"memory": True}, "_llm": True}) + role_info.update({ "role_class": self.__class__.__name__, "module_name": self.__module__ - } - setting = self._setting.dict() - setting.pop("desc") - setting.pop("is_human") # not all inherited roles have this atrr - role_info.update(setting) + }) + role_info_path = stg_path.joinpath("role_info.json") write_json_file(role_info_path, role_info) - actions_info_path = stg_path.joinpath("actions/actions_info.json") - actions_info = [] - for action in self._actions: - actions_info.append(action.serialize()) - write_json_file(actions_info_path, actions_info) - - watches_info_path = stg_path.joinpath("watches/watches_info.json") - watches_info = [] - for watch in self._rc.watch: - watches_info.append(watch.ser_class()) - write_json_file(watches_info_path, watches_info) - - actions_todo_path = stg_path.joinpath("actions/todo.json") - actions_todo = { - "cur_state": self._rc.state, - "react_mode": self._rc.react_mode.value, - "max_react_loop": self._rc.max_react_loop - } - write_json_file(actions_todo_path, actions_todo) - - self._rc.memory.serialize(stg_path) + self._rc.memory.serialize(stg_path) # serialize role's memory alone @classmethod def deserialize(cls, stg_path: Path) -> "Role": @@ -250,35 +243,7 @@ class Role(BaseModel): role_class = import_class(class_name=role_class_str, module_name=module_name) role = role_class(**role_info) # initiate particular Role - actions_info_path = stg_path.joinpath("actions/actions_info.json") - actions = [] - actions_info = read_json_file(actions_info_path) - for action_info in actions_info: - action = Action.deser_class(action_info) - actions.append(action) - - watches_info_path = stg_path.joinpath("watches/watches_info.json") - watches = [] - watches_info = read_json_file(watches_info_path) - for watch_info in watches_info: - action = Action.deser_class(watch_info) - watches.append(action) - - role.init_actions(actions) - role.watch(watches) - - actions_todo_path = stg_path.joinpath("actions/todo.json") - # recover self._rc.state - actions_todo = read_json_file(actions_todo_path) - max_react_loop = actions_todo.get("max_react_loop", 1) - cur_state = actions_todo.get("cur_state", -1) - role.set_state(cur_state) - role.set_recovered(True) - react_mode_str = actions_todo.get("react_mode", RoleReactMode.REACT.value) - if react_mode_str not in RoleReactMode.values(): - logger.warning(f"ReactMode: {react_mode_str} not in {RoleReactMode.values()}, use react as default") - react_mode_str = RoleReactMode.REACT.value - role.set_react_mode(RoleReactMode(react_mode_str), max_react_loop) + role.set_recovered(True) # set True to make a tag role_memory = Memory.deserialize(stg_path) role.set_memory(role_memory) @@ -299,9 +264,9 @@ class Role(BaseModel): for idx, action in enumerate(actions): if not isinstance(action, Action): ## 默认初始化 - i = action(name="", llm=self._llm) + i = action(llm=self._llm) else: - if self._setting.is_human and not isinstance(action.llm, HumanProvider): + if self.is_human and not isinstance(action.llm, HumanProvider): logger.warning(f"is_human attribute does not take effect," f"as Role's {str(action)} was initialized using LLM, try passing in Action classes instead of initialized instances") i = action @@ -357,9 +322,14 @@ class Role(BaseModel): def _get_prefix(self): """Get the role prefix""" - if self._setting.desc: - return self._setting.desc - return PREFIX_TEMPLATE.format(**self._setting.dict()) + if self.desc: + return self.desc + return PREFIX_TEMPLATE.format(**{ + "profile": self.profile, + "name": self.name, + "goal": self.goal, + "constraints": self.constraints + }) async def _think(self) -> None: """Think about what to do and decide on the next action""" diff --git a/metagpt/schema.py b/metagpt/schema.py index 60aa819b0..3a5bea7e9 100644 --- a/metagpt/schema.py +++ b/metagpt/schema.py @@ -48,23 +48,9 @@ class Message(BaseModel): kwargs["cause_by"] = action_class.deser_class(cause_by) super(Message, self).__init__(**kwargs) - def dict(self, - *, - include: Optional[Union["AbstractSetIntStr", "MappingIntStrAny"]] = None, - exclude: Optional[Union["AbstractSetIntStr", "MappingIntStrAny"]] = None, - by_alias: bool = False, - skip_defaults: Optional[bool] = None, - exclude_unset: bool = False, - exclude_defaults: bool = False, - exclude_none: bool = False) -> "DictStrAny": + def dict(self, *args, **kwargs) -> "DictStrAny": """ overwrite the `dict` to dump dynamic pydantic model""" - obj_dict = super(Message, self).dict(include=include, - exclude=exclude, - by_alias=by_alias, - skip_defaults=skip_defaults, - exclude_unset=exclude_unset, - exclude_defaults=exclude_defaults, - exclude_none=exclude_none) + obj_dict = super(Message, self).dict(*args, **kwargs) ic = self.instruct_content # deal custom-defined action if ic: schema = ic.schema() @@ -77,19 +63,6 @@ class Message(BaseModel): obj_dict["cause_by"] = cb.ser_class() return obj_dict -# -# -# @dataclass -# class Message: -# """list[: ]""" -# content: str -# instruct_content: BaseModel = field(default=None) -# role: str = field(default='user') # system / user / assistant -# cause_by: Type["Action"] = field(default="") -# sent_from: str = field(default="") -# send_to: str = field(default="") -# restricted_to: str = field(default="") - def __str__(self): # prefix = '-'.join([self.role, str(self.cause_by)]) return f"{self.role}: {self.content}" @@ -97,17 +70,6 @@ class Message(BaseModel): def __repr__(self): return self.__str__() - # def dict(self): - # return { - # "content": self.content, - # "instruct_content": self.instruct_content, - # "role": self.role, - # "cause_by": self.cause_by, - # "sent_from": self.sent_from, - # "send_to": self.send_to, - # "restricted_to": self.restricted_to - # } - def to_dict(self) -> dict: return { "role": self.role, From 0e8eda683e991f8ea7f80ccb09da2fc9a208a265 Mon Sep 17 00:00:00 2001 From: better629 Date: Fri, 1 Dec 2023 14:45:06 +0800 Subject: [PATCH 012/167] update ut after simplification --- tests/metagpt/serialize_deserialize/test_action.py | 14 +------------- tests/metagpt/serialize_deserialize/test_role.py | 3 --- .../serialize_deserialize/test_serdeser_base.py | 6 +++--- tests/metagpt/serialize_deserialize/test_team.py | 2 +- .../serialize_deserialize/test_wrire_prd.py | 2 +- .../serialize_deserialize/test_write_code.py | 4 ++-- .../serialize_deserialize/test_write_design.py | 4 ++-- 7 files changed, 10 insertions(+), 25 deletions(-) diff --git a/tests/metagpt/serialize_deserialize/test_action.py b/tests/metagpt/serialize_deserialize/test_action.py index 0138d41ce..16369bb61 100644 --- a/tests/metagpt/serialize_deserialize/test_action.py +++ b/tests/metagpt/serialize_deserialize/test_action.py @@ -13,7 +13,7 @@ def test_action_serialize(): action = Action() ser_action_dict = action.dict() assert "name" in ser_action_dict - assert "llm" not in ser_action_dict + # assert "llm" not in ser_action_dict # not export @pytest.mark.asyncio @@ -34,15 +34,3 @@ def test_action_serdeser(): action_class = Action.deser_class(action_info) assert action_class == WriteTest - - -def test_action_class_serdeser(): - name = "write test" - action_info = WriteTest(name=name).serialize() - assert action_info["name"] == name - - action_info = WriteTest(name=name, llm=LLM()).serialize() - assert action_info["name"] == name - - action = Action.deserialize(action_info) - assert action.name == name diff --git a/tests/metagpt/serialize_deserialize/test_role.py b/tests/metagpt/serialize_deserialize/test_role.py index c21b9cc2e..61684ba9d 100644 --- a/tests/metagpt/serialize_deserialize/test_role.py +++ b/tests/metagpt/serialize_deserialize/test_role.py @@ -66,7 +66,6 @@ def test_role_serdeser_save(): role_tag = f"{pm.__class__.__name__}_{pm.name}" stg_path = stg_path_prefix.joinpath(role_tag) pm.serialize(stg_path) - assert stg_path.joinpath("actions/actions_info.json").exists() new_pm = Role.deserialize(stg_path) assert new_pm.name == pm.name @@ -89,8 +88,6 @@ async def test_role_serdeser_interrupt(): assert role_c._rc.memory.count() == 2 - assert stg_path.joinpath("actions/todo.json").exists() - new_role_a: Role = Role.deserialize(stg_path) assert new_role_a._rc.state == 1 diff --git a/tests/metagpt/serialize_deserialize/test_serdeser_base.py b/tests/metagpt/serialize_deserialize/test_serdeser_base.py index 00d894b3d..74f9fea87 100644 --- a/tests/metagpt/serialize_deserialize/test_serdeser_base.py +++ b/tests/metagpt/serialize_deserialize/test_serdeser_base.py @@ -21,7 +21,7 @@ class MockMessage(BaseModel): class ActionPass(Action): - name: str = "ActionPass" + name: str = Field(default="ActionPass") async def run(self, messages: list["Message"]) -> ActionOutput: output_mapping = { @@ -34,14 +34,14 @@ class ActionPass(Action): class ActionOK(Action): - name: str = "ActionOK" + name: str = Field(default="ActionOK") async def run(self, messages: list["Message"]) -> str: return "ok" class ActionRaise(Action): - name: str = "ActionRaise" + name: str = Field(default="ActionRaise") async def run(self, messages: list["Message"]) -> str: raise RuntimeError("parse error in ActionRaise") diff --git a/tests/metagpt/serialize_deserialize/test_team.py b/tests/metagpt/serialize_deserialize/test_team.py index e5ec20f2e..28728e1b5 100644 --- a/tests/metagpt/serialize_deserialize/test_team.py +++ b/tests/metagpt/serialize_deserialize/test_team.py @@ -70,7 +70,7 @@ async def test_team_recover(): new_role_c = new_company.environment.get_role(role_c.profile) assert new_role_c._rc.memory == role_c._rc.memory - assert new_role_c._rc.env != role_c._rc.env # due to Action raise, role's memory has been changed. + assert new_role_c._rc.env == role_c._rc.env # TODO check again assert new_role_c._rc.env.memory == role_c._rc.env.memory assert new_company.environment.memory.count() == 1 diff --git a/tests/metagpt/serialize_deserialize/test_wrire_prd.py b/tests/metagpt/serialize_deserialize/test_wrire_prd.py index 05a86cb7f..0b9dfa9d8 100644 --- a/tests/metagpt/serialize_deserialize/test_wrire_prd.py +++ b/tests/metagpt/serialize_deserialize/test_wrire_prd.py @@ -13,7 +13,7 @@ def test_action_serialize(): action = WritePRD() ser_action_dict = action.dict() assert "name" in ser_action_dict - assert "llm" in ser_action_dict + # assert "llm" in ser_action_dict # not export @pytest.mark.asyncio diff --git a/tests/metagpt/serialize_deserialize/test_write_code.py b/tests/metagpt/serialize_deserialize/test_write_code.py index 4e3b712c0..5552ffd7f 100644 --- a/tests/metagpt/serialize_deserialize/test_write_code.py +++ b/tests/metagpt/serialize_deserialize/test_write_code.py @@ -12,14 +12,14 @@ def test_write_design_serialize(): action = WriteCode() ser_action_dict = action.dict() assert ser_action_dict["name"] == "WriteCode" - assert "llm" in ser_action_dict + # assert "llm" in ser_action_dict # not export def test_write_task_serialize(): action = WriteCodeReview() ser_action_dict = action.dict() assert ser_action_dict["name"] == "WriteCodeReview" - assert "llm" in ser_action_dict + # assert "llm" in ser_action_dict # not export @pytest.mark.asyncio diff --git a/tests/metagpt/serialize_deserialize/test_write_design.py b/tests/metagpt/serialize_deserialize/test_write_design.py index 5b2a30ed3..080896c98 100644 --- a/tests/metagpt/serialize_deserialize/test_write_design.py +++ b/tests/metagpt/serialize_deserialize/test_write_design.py @@ -12,14 +12,14 @@ def test_write_design_serialize(): action = WriteDesign() ser_action_dict = action.dict() assert "name" in ser_action_dict - assert "llm" in ser_action_dict + # assert "llm" in ser_action_dict # not export def test_write_task_serialize(): action = WriteTasks() ser_action_dict = action.dict() assert "name" in ser_action_dict - assert "llm" in ser_action_dict + # assert "llm" in ser_action_dict # not export @pytest.mark.asyncio From c7a5bea2b157d2fca2641369a14a415fd935f83f Mon Sep 17 00:00:00 2001 From: better629 Date: Fri, 1 Dec 2023 15:30:28 +0800 Subject: [PATCH 013/167] update --- tests/metagpt/serialize_deserialize/test_team.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/tests/metagpt/serialize_deserialize/test_team.py b/tests/metagpt/serialize_deserialize/test_team.py index 28728e1b5..9c4eb8170 100644 --- a/tests/metagpt/serialize_deserialize/test_team.py +++ b/tests/metagpt/serialize_deserialize/test_team.py @@ -70,7 +70,7 @@ async def test_team_recover(): new_role_c = new_company.environment.get_role(role_c.profile) assert new_role_c._rc.memory == role_c._rc.memory - assert new_role_c._rc.env == role_c._rc.env # TODO check again + assert new_role_c._rc.env == role_c._rc.env assert new_role_c._rc.env.memory == role_c._rc.env.memory assert new_company.environment.memory.count() == 1 @@ -95,7 +95,10 @@ async def test_team_recover_save(): new_company = Team.recover(stg_path) new_role_c = new_company.environment.get_role(role_c.profile) assert new_role_c._rc.memory == role_c._rc.memory - assert new_role_c._rc.env != role_c._rc.env # due to Action raise, role's memory has been changed. + assert new_role_c._rc.env != role_c._rc.env + assert new_role_c.recovered != role_c.recovered # here cause previous ut is `!=` + assert new_role_c._rc.todo != role_c._rc.todo # serialize exclude `_rc.todo` + assert new_role_c._rc.news != role_c._rc.news # serialize exclude `_rc.news` assert new_role_c._rc.env.memory == role_c._rc.env.memory new_company.start_project(idea) From bcba1393b4e1e3445031cd4779fcb441f1fad8d7 Mon Sep 17 00:00:00 2001 From: better629 Date: Fri, 1 Dec 2023 20:35:48 +0800 Subject: [PATCH 014/167] update asyncio.sleep to make it async --- .../test_serdeser_base.py | 18 +++++------------- 1 file changed, 5 insertions(+), 13 deletions(-) diff --git a/tests/metagpt/serialize_deserialize/test_serdeser_base.py b/tests/metagpt/serialize_deserialize/test_serdeser_base.py index 74f9fea87..298c13823 100644 --- a/tests/metagpt/serialize_deserialize/test_serdeser_base.py +++ b/tests/metagpt/serialize_deserialize/test_serdeser_base.py @@ -4,6 +4,7 @@ from pydantic import BaseModel, Field from pathlib import Path +import asyncio from metagpt.actions.action import Action from metagpt.roles.role import Role, RoleReactMode @@ -24,6 +25,7 @@ class ActionPass(Action): name: str = Field(default="ActionPass") async def run(self, messages: list["Message"]) -> ActionOutput: + await asyncio.sleep(5) # sleep to make other roles can watch the executed Message output_mapping = { "result": (str, ...) } @@ -37,6 +39,7 @@ class ActionOK(Action): name: str = Field(default="ActionOK") async def run(self, messages: list["Message"]) -> str: + await asyncio.sleep(5) return "ok" @@ -55,14 +58,10 @@ class RoleA(Role): constraints: str = "RoleA's constraints" def __init__(self, **kwargs): - # super(RoleA, self).__init__(**kwargs) - super().__init__(**kwargs) + super(RoleA, self).__init__(**kwargs) self._init_actions([ActionPass]) self._watch([BossRequirement]) - async def run(self, message: "Message" = None): - await super(RoleA, self).run(message) - class RoleB(Role): name: str = Field(default="RoleB") @@ -71,15 +70,11 @@ class RoleB(Role): constraints: str = "RoleB's constraints" def __init__(self, **kwargs): - # super(RoleB, self).__init__(**kwargs) - super().__init__(**kwargs) + super(RoleB, self).__init__(**kwargs) self._init_actions([ActionOK, ActionRaise]) self._watch([ActionPass]) self._rc.react_mode = RoleReactMode.BY_ORDER - async def run(self, message: "Message" = None): - await super(RoleB, self).run(message) - class RoleC(Role): name: str = Field(default="RoleC") @@ -92,6 +87,3 @@ class RoleC(Role): self._init_actions([ActionOK, ActionRaise]) self._watch([BossRequirement]) self._rc.react_mode = RoleReactMode.BY_ORDER - - async def run(self, message: "Message" = None): - await super(RoleC, self).run(message) From cb81561b69749596b16cdee7e6e3ed4128cd6685 Mon Sep 17 00:00:00 2001 From: better629 Date: Fri, 1 Dec 2023 21:07:47 +0800 Subject: [PATCH 015/167] fix when RoleReactMode=REACT --- metagpt/roles/role.py | 4 ++-- metagpt/utils/utils.py | 1 + 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/metagpt/roles/role.py b/metagpt/roles/role.py index 4e669772e..5b998bf9a 100644 --- a/metagpt/roles/role.py +++ b/metagpt/roles/role.py @@ -337,9 +337,9 @@ class Role(BaseModel): # If there is only one action, then only this one can be performed self._set_state(0) return - if self._recovered and self._rc.state >= 0: + if self.recovered and self._rc.state >= 0: self._set_state(self._rc.state) # action to run from recovered state - self._recovered = False # avoid max_react_loop out of work + self.recovered = False # avoid max_react_loop out of work return prompt = self._get_prefix() diff --git a/metagpt/utils/utils.py b/metagpt/utils/utils.py index b72dabf7e..c1416c352 100644 --- a/metagpt/utils/utils.py +++ b/metagpt/utils/utils.py @@ -74,6 +74,7 @@ def role_raise_decorator(func): newest_msgs = self._rc.env.memory.get(1) if len(newest_msgs) > 0: self._rc.memory.delete(newest_msgs[0]) + raise Exception(format_trackback_info(limit=None)) # raise again to make it captured outside except Exception as exp: if self._rc.env: newest_msgs = self._rc.env.memory.get(1) From 9f9b7ebe17b09d7bd952173e407dca565e064bb4 Mon Sep 17 00:00:00 2001 From: Stitch-z <284618289@qq.com> Date: Sat, 2 Dec 2023 14:39:51 +0800 Subject: [PATCH 016/167] update: optimize the action code for writing tutorials. --- examples/write_tutorial.py | 2 ++ metagpt/roles/tutorial_assistant.py | 30 +++++------------------------ 2 files changed, 7 insertions(+), 25 deletions(-) diff --git a/examples/write_tutorial.py b/examples/write_tutorial.py index 71ece5527..8d2b25103 100644 --- a/examples/write_tutorial.py +++ b/examples/write_tutorial.py @@ -1,10 +1,12 @@ #!/usr/bin/env python3 # _*_ coding: utf-8 _*_ + """ @Time : 2023/9/4 21:40:57 @Author : Stitch-z @File : tutorial_assistant.py """ + import asyncio from metagpt.roles.tutorial_assistant import TutorialAssistant diff --git a/metagpt/roles/tutorial_assistant.py b/metagpt/roles/tutorial_assistant.py index 9a7df4f4d..7c9450997 100644 --- a/metagpt/roles/tutorial_assistant.py +++ b/metagpt/roles/tutorial_assistant.py @@ -42,17 +42,7 @@ class TutorialAssistant(Role): self.main_title = "" self.total_content = "" self.language = language - - async def _think(self) -> None: - """Determine the next action to be taken by the role.""" - if self._rc.todo is None: - self._set_state(0) - return - - if self._rc.state + 1 < len(self._states): - self._set_state(self._rc.state + 1) - else: - self._rc.todo = None + self._set_react_mode(react_mode="by_order") async def _handle_directory(self, titles: Dict) -> Message: """Handle the directories for the tutorial document. @@ -75,8 +65,6 @@ class TutorialAssistant(Role): for second_dir in first_dir[key]: directory += f" - {second_dir}\n" self._init_actions(actions) - self._rc.todo = None - return Message(content=directory) async def _act(self) -> Message: """Perform an action as determined by the role. @@ -90,7 +78,8 @@ class TutorialAssistant(Role): self.topic = msg.content resp = await todo.run(topic=self.topic) logger.info(resp) - return await self._handle_directory(resp) + await self._handle_directory(resp) + return await super().react() resp = await todo.run(topic=self.topic) logger.info(resp) if self.total_content != "": @@ -98,17 +87,8 @@ class TutorialAssistant(Role): self.total_content += resp return Message(content=resp, role=self.profile) - async def _react(self) -> Message: - """Execute the assistant's think and actions. - - Returns: - A message containing the final result of the assistant's actions. - """ - while True: - await self._think() - if self._rc.todo is None: - break - msg = await self._act() + async def react(self) -> Message: + msg = await super().react() root_path = TUTORIAL_PATH / datetime.now().strftime("%Y-%m-%d_%H-%M-%S") await File.write(root_path, f"{self.main_title}.md", self.total_content.encode('utf-8')) return msg From eaf531e0ac44edd4360f550b960a977725bb0edd Mon Sep 17 00:00:00 2001 From: seehi <6580@pm.me> Date: Tue, 5 Dec 2023 11:26:54 +0800 Subject: [PATCH 017/167] support new openai package --- metagpt/config.py | 8 +- metagpt/provider/general_api_base.py | 718 ++++++++++++++++++++ metagpt/provider/general_api_requestor.py | 22 +- metagpt/provider/openai_api.py | 126 ++-- metagpt/provider/zhipuai/zhipu_model_api.py | 35 +- metagpt/tools/code_interpreter.py | 62 +- metagpt/utils/make_sk_kernel.py | 6 +- requirements.txt | 4 +- tests/metagpt/provider/test_zhipuai_api.py | 22 +- 9 files changed, 866 insertions(+), 137 deletions(-) create mode 100644 metagpt/provider/general_api_base.py diff --git a/metagpt/config.py b/metagpt/config.py index a6ecab5ff..4306445ef 100644 --- a/metagpt/config.py +++ b/metagpt/config.py @@ -5,7 +5,6 @@ Provide configuration, singleton """ import os -import openai import yaml from metagpt.const import PROJECT_ROOT @@ -52,11 +51,8 @@ class Config(metaclass=Singleton): and (not self.zhipuai_api_key or "YOUR_API_KEY" == self.zhipuai_api_key) ): raise NotConfiguredException("Set OPENAI_API_KEY or Anthropic_API_KEY or ZHIPUAI_API_KEY first") - self.openai_api_base = self._get("OPENAI_BASE_URL") - openai_proxy = self._get("OPENAI_PROXY") or self.global_proxy - if openai_proxy: - openai.proxy = openai_proxy - openai.api_base = self.openai_api_base + self.openai_base_url = self._get("OPENAI_BASE_URL") + self.openai_proxy = self._get("OPENAI_PROXY") or self.global_proxy self.openai_api_type = self._get("OPENAI_API_TYPE") self.openai_api_version = self._get("OPENAI_API_VERSION") self.openai_api_rpm = self._get("RPM", 3) diff --git a/metagpt/provider/general_api_base.py b/metagpt/provider/general_api_base.py new file mode 100644 index 000000000..da16e942d --- /dev/null +++ b/metagpt/provider/general_api_base.py @@ -0,0 +1,718 @@ +import asyncio +import json +import os +import platform +import re +import sys +import threading +import time +from contextlib import asynccontextmanager +from enum import Enum +from typing import ( + AsyncGenerator, + AsyncIterator, + Callable, + Dict, + Iterator, + Optional, + Tuple, + Union, + overload, +) +from urllib.parse import urlencode, urlsplit, urlunsplit + +import aiohttp +import requests + +if sys.version_info >= (3, 8): + from typing import Literal +else: + from typing_extensions import Literal + +import logging + +import openai +from openai import version + +logger = logging.getLogger("openai") + +TIMEOUT_SECS = 600 +MAX_SESSION_LIFETIME_SECS = 180 +MAX_CONNECTION_RETRIES = 2 + +# Has one attribute per thread, 'session'. +_thread_context = threading.local() + +OPENAI_LOG = os.environ.get("OPENAI_LOG") +OPENAI_LOG = "debug" + + +class ApiType(Enum): + AZURE = 1 + OPEN_AI = 2 + AZURE_AD = 3 + + @staticmethod + def from_str(label): + if label.lower() == "azure": + return ApiType.AZURE + elif label.lower() in ("azure_ad", "azuread"): + return ApiType.AZURE_AD + elif label.lower() in ("open_ai", "openai"): + return ApiType.OPEN_AI + else: + raise openai.OpenAIError( + "The API type provided in invalid. Please select one of the supported API types: 'azure', 'azure_ad', 'open_ai'" + ) + + +api_key_to_header = ( + lambda api, key: {"Authorization": f"Bearer {key}"} + if api in (ApiType.OPEN_AI, ApiType.AZURE_AD) + else {"api-key": f"{key}"} +) + + +def _console_log_level(): + if OPENAI_LOG in ["debug", "info"]: + return OPENAI_LOG + else: + return None + + +def log_debug(message, **params): + msg = logfmt(dict(message=message, **params)) + if _console_log_level() == "debug": + print(msg, file=sys.stderr) + logger.debug(msg) + + +def log_info(message, **params): + msg = logfmt(dict(message=message, **params)) + if _console_log_level() in ["debug", "info"]: + print(msg, file=sys.stderr) + logger.info(msg) + + +def log_warn(message, **params): + msg = logfmt(dict(message=message, **params)) + print(msg, file=sys.stderr) + logger.warn(msg) + + +def logfmt(props): + def fmt(key, val): + # Handle case where val is a bytes or bytesarray + if hasattr(val, "decode"): + val = val.decode("utf-8") + # Check if val is already a string to avoid re-encoding into ascii. + if not isinstance(val, str): + val = str(val) + if re.search(r"\s", val): + val = repr(val) + # key should already be a string + if re.search(r"\s", key): + key = repr(key) + return "{key}={val}".format(key=key, val=val) + + return " ".join([fmt(key, val) for key, val in sorted(props.items())]) + + +class OpenAIResponse: + def __init__(self, data, headers): + self._headers = headers + self.data = data + + @property + def request_id(self) -> Optional[str]: + return self._headers.get("request-id") + + @property + def retry_after(self) -> Optional[int]: + try: + return int(self._headers.get("retry-after")) + except TypeError: + return None + + @property + def operation_location(self) -> Optional[str]: + return self._headers.get("operation-location") + + @property + def organization(self) -> Optional[str]: + return self._headers.get("OpenAI-Organization") + + @property + def response_ms(self) -> Optional[int]: + h = self._headers.get("Openai-Processing-Ms") + return None if h is None else round(float(h)) + + +def _build_api_url(url, query): + scheme, netloc, path, base_query, fragment = urlsplit(url) + + if base_query: + query = "%s&%s" % (base_query, query) + + return urlunsplit((scheme, netloc, path, query, fragment)) + + +def _requests_proxies_arg(proxy) -> Optional[Dict[str, str]]: + """Returns a value suitable for the 'proxies' argument to 'requests.request.""" + if proxy is None: + return None + elif isinstance(proxy, str): + return {"http": proxy, "https": proxy} + elif isinstance(proxy, dict): + return proxy.copy() + else: + raise ValueError( + "'openai.proxy' must be specified as either a string URL or a dict with string URL under the https and/or http keys." + ) + + +def _aiohttp_proxies_arg(proxy) -> Optional[str]: + """Returns a value suitable for the 'proxies' argument to 'aiohttp.ClientSession.request.""" + if proxy is None: + return None + elif isinstance(proxy, str): + return proxy + elif isinstance(proxy, dict): + return proxy["https"] if "https" in proxy else proxy["http"] + else: + raise ValueError( + "'openai.proxy' must be specified as either a string URL or a dict with string URL under the https and/or http keys." + ) + + +def _make_session() -> requests.Session: + s = requests.Session() + s.mount( + "https://", + requests.adapters.HTTPAdapter(max_retries=MAX_CONNECTION_RETRIES), + ) + return s + + +def parse_stream_helper(line: bytes) -> Optional[str]: + if line: + if line.strip() == b"data: [DONE]": + # return here will cause GeneratorExit exception in urllib3 + # and it will close http connection with TCP Reset + return None + if line.startswith(b"data: "): + line = line[len(b"data: ") :] + return line.decode("utf-8") + else: + return None + return None + + +def parse_stream(rbody: Iterator[bytes]) -> Iterator[str]: + for line in rbody: + _line = parse_stream_helper(line) + if _line is not None: + yield _line + + +async def parse_stream_async(rbody: aiohttp.StreamReader): + async for line in rbody: + _line = parse_stream_helper(line) + if _line is not None: + yield _line + + +class APIRequestor: + def __init__( + self, + key=None, + base_url=None, + api_type=None, + api_version=None, + organization=None, + ): + self.base_url = base_url or openai.base_url + self.api_key = key or openai.api_key + self.api_type = ApiType.from_str(api_type) if api_type else ApiType.from_str("openai") + self.api_version = api_version or openai.api_version + self.organization = organization or openai.organization + + def _check_polling_response(self, response: OpenAIResponse, predicate: Callable[[OpenAIResponse], bool]): + if not predicate(response): + return + error_data = response.data["error"] + message = error_data.get("message", "Operation failed") + code = error_data.get("code") + raise openai.APIError(message=message, body=dict(code=code)) + + def _poll( + self, method, url, until, failed, params=None, headers=None, interval=None, delay=None + ) -> Tuple[Iterator[OpenAIResponse], bool, str]: + if delay: + time.sleep(delay) + + response, b, api_key = self.request(method, url, params, headers) + self._check_polling_response(response, failed) + start_time = time.time() + while not until(response): + if time.time() - start_time > TIMEOUT_SECS: + raise openai.APITimeoutError("Operation polling timed out.") + + time.sleep(interval or response.retry_after or 10) + response, b, api_key = self.request(method, url, params, headers) + self._check_polling_response(response, failed) + + response.data = response.data["result"] + return response, b, api_key + + async def _apoll( + self, method, url, until, failed, params=None, headers=None, interval=None, delay=None + ) -> Tuple[Iterator[OpenAIResponse], bool, str]: + if delay: + await asyncio.sleep(delay) + + response, b, api_key = await self.arequest(method, url, params, headers) + self._check_polling_response(response, failed) + start_time = time.time() + while not until(response): + if time.time() - start_time > TIMEOUT_SECS: + raise openai.APITimeoutError("Operation polling timed out.") + + await asyncio.sleep(interval or response.retry_after or 10) + response, b, api_key = await self.arequest(method, url, params, headers) + self._check_polling_response(response, failed) + + response.data = response.data["result"] + return response, b, api_key + + @overload + def request( + self, + method, + url, + params, + headers, + files, + stream: Literal[True], + request_id: Optional[str] = ..., + request_timeout: Optional[Union[float, Tuple[float, float]]] = ..., + ) -> Tuple[Iterator[OpenAIResponse], bool, str]: + pass + + @overload + def request( + self, + method, + url, + params=..., + headers=..., + files=..., + *, + stream: Literal[True], + request_id: Optional[str] = ..., + request_timeout: Optional[Union[float, Tuple[float, float]]] = ..., + ) -> Tuple[Iterator[OpenAIResponse], bool, str]: + pass + + @overload + def request( + self, + method, + url, + params=..., + headers=..., + files=..., + stream: Literal[False] = ..., + request_id: Optional[str] = ..., + request_timeout: Optional[Union[float, Tuple[float, float]]] = ..., + ) -> Tuple[OpenAIResponse, bool, str]: + pass + + @overload + def request( + self, + method, + url, + params=..., + headers=..., + files=..., + stream: bool = ..., + request_id: Optional[str] = ..., + request_timeout: Optional[Union[float, Tuple[float, float]]] = ..., + ) -> Tuple[Union[OpenAIResponse, Iterator[OpenAIResponse]], bool, str]: + pass + + def request( + self, + method, + url, + params=None, + headers=None, + files=None, + stream: bool = False, + request_id: Optional[str] = None, + request_timeout: Optional[Union[float, Tuple[float, float]]] = None, + ) -> Tuple[Union[OpenAIResponse, Iterator[OpenAIResponse]], bool, str]: + result = self.request_raw( + method.lower(), + url, + params=params, + supplied_headers=headers, + files=files, + stream=stream, + request_id=request_id, + request_timeout=request_timeout, + ) + resp, got_stream = self._interpret_response(result, stream) + return resp, got_stream, self.api_key + + @overload + async def arequest( + self, + method, + url, + params, + headers, + files, + stream: Literal[True], + request_id: Optional[str] = ..., + request_timeout: Optional[Union[float, Tuple[float, float]]] = ..., + ) -> Tuple[AsyncGenerator[OpenAIResponse, None], bool, str]: + pass + + @overload + async def arequest( + self, + method, + url, + params=..., + headers=..., + files=..., + *, + stream: Literal[True], + request_id: Optional[str] = ..., + request_timeout: Optional[Union[float, Tuple[float, float]]] = ..., + ) -> Tuple[AsyncGenerator[OpenAIResponse, None], bool, str]: + pass + + @overload + async def arequest( + self, + method, + url, + params=..., + headers=..., + files=..., + stream: Literal[False] = ..., + request_id: Optional[str] = ..., + request_timeout: Optional[Union[float, Tuple[float, float]]] = ..., + ) -> Tuple[OpenAIResponse, bool, str]: + pass + + @overload + async def arequest( + self, + method, + url, + params=..., + headers=..., + files=..., + stream: bool = ..., + request_id: Optional[str] = ..., + request_timeout: Optional[Union[float, Tuple[float, float]]] = ..., + ) -> Tuple[Union[OpenAIResponse, AsyncGenerator[OpenAIResponse, None]], bool, str]: + pass + + async def arequest( + self, + method, + url, + params=None, + headers=None, + files=None, + stream: bool = False, + request_id: Optional[str] = None, + request_timeout: Optional[Union[float, Tuple[float, float]]] = None, + ) -> Tuple[Union[OpenAIResponse, AsyncGenerator[OpenAIResponse, None]], bool, str]: + ctx = aiohttp_session() + session = await ctx.__aenter__() + try: + result = await self.arequest_raw( + method.lower(), + url, + session, + params=params, + supplied_headers=headers, + files=files, + request_id=request_id, + request_timeout=request_timeout, + ) + resp, got_stream = await self._interpret_async_response(result, stream) + except Exception: + await ctx.__aexit__(None, None, None) + raise + if got_stream: + + async def wrap_resp(): + assert isinstance(resp, AsyncGenerator) + try: + async for r in resp: + yield r + finally: + await ctx.__aexit__(None, None, None) + + return wrap_resp(), got_stream, self.api_key + else: + await ctx.__aexit__(None, None, None) + return resp, got_stream, self.api_key + + def handle_error_response(self, rbody, rcode, resp, rheaders, stream_error=False): + try: + error_data = resp["error"] + except (KeyError, TypeError): + raise openai.APIError( + "Invalid response object from API: %r (HTTP response code " "was %d)" % (rbody, rcode) + ) + + if "internal_message" in error_data: + error_data["message"] += "\n\n" + error_data["internal_message"] + + log_info( + "OpenAI API error received", + error_code=error_data.get("code"), + error_type=error_data.get("type"), + error_message=error_data.get("message"), + error_param=error_data.get("param"), + stream_error=stream_error, + ) + + # Rate limits were previously coded as 400's with code 'rate_limit' + if rcode == 429: + return openai.RateLimitError(f"{error_data.get('message')} {rbody} {rcode} {resp} {rheaders}", body=rbody) + elif rcode in [400, 404, 415]: + return openai.BadRequestError( + message=f'{error_data.get("message")}, {error_data.get("param")}, {error_data.get("code")} {rbody} {rcode} {resp} {rheaders}', + body=rbody, + ) + elif rcode == 401: + return openai.AuthenticationError( + f"{error_data.get('message')} {rbody} {rcode} {resp} {rheaders}", body=rbody + ) + elif rcode == 403: + return openai.PermissionDeniedError( + f"{error_data.get('message')} {rbody} {rcode} {resp} {rheaders}", body=rbody + ) + elif rcode == 409: + return openai.ConflictError(f"{error_data.get('message')} {rbody} {rcode} {resp} {rheaders}", body=rbody) + elif stream_error: + # TODO: we will soon attach status codes to stream errors + parts = [error_data.get("message"), "(Error occurred while streaming.)"] + message = " ".join([p for p in parts if p is not None]) + return openai.APIError(f"{message} {rbody} {rcode} {resp} {rheaders}", body=rbody) + else: + return openai.APIError( + f"{error_data.get('message')} {rbody} {rcode} {resp} {rheaders}", + body=rbody, + ) + + def request_headers(self, method: str, extra, request_id: Optional[str]) -> Dict[str, str]: + user_agent = "OpenAI/v1 PythonBindings/%s" % (version.VERSION,) + + uname_without_node = " ".join(v for k, v in platform.uname()._asdict().items() if k != "node") + ua = { + "bindings_version": version.VERSION, + "httplib": "requests", + "lang": "python", + "lang_version": platform.python_version(), + "platform": platform.platform(), + "publisher": "openai", + "uname": uname_without_node, + } + + headers = { + "X-OpenAI-Client-User-Agent": json.dumps(ua), + "User-Agent": user_agent, + } + + headers.update(api_key_to_header(self.api_type, self.api_key)) + + if self.organization: + headers["OpenAI-Organization"] = self.organization + + if self.api_version is not None and self.api_type == ApiType.OPEN_AI: + headers["OpenAI-Version"] = self.api_version + if request_id is not None: + headers["X-Request-Id"] = request_id + headers.update(extra) + + return headers + + def _validate_headers(self, supplied_headers: Optional[Dict[str, str]]) -> Dict[str, str]: + headers: Dict[str, str] = {} + if supplied_headers is None: + return headers + + if not isinstance(supplied_headers, dict): + raise TypeError("Headers must be a dictionary") + + for k, v in supplied_headers.items(): + if not isinstance(k, str): + raise TypeError("Header keys must be strings") + if not isinstance(v, str): + raise TypeError("Header values must be strings") + headers[k] = v + + # NOTE: It is possible to do more validation of the headers, but a request could always + # be made to the API manually with invalid headers, so we need to handle them server side. + + return headers + + def _prepare_request_raw( + self, + url, + supplied_headers, + method, + params, + files, + request_id: Optional[str], + ) -> Tuple[str, Dict[str, str], Optional[bytes]]: + abs_url = "%s%s" % (self.base_url, url) + headers = self._validate_headers(supplied_headers) + + data = None + if method == "get" or method == "delete": + if params: + encoded_params = urlencode([(k, v) for k, v in params.items() if v is not None]) + abs_url = _build_api_url(abs_url, encoded_params) + elif method in {"post", "put"}: + if params and files: + data = params + if params and not files: + data = json.dumps(params).encode() + headers["Content-Type"] = "application/json" + else: + raise openai.APIConnectionError( + "Unrecognized HTTP method %r. This may indicate a bug in the " + "OpenAI bindings. Please contact us through our help center at help.openai.com for " + "assistance." % (method,) + ) + + headers = self.request_headers(method, headers, request_id) + + log_debug("Request to OpenAI API", method=method, path=abs_url) + log_debug("Post details", data=data, api_version=self.api_version) + + return abs_url, headers, data + + def request_raw( + self, + method, + url, + *, + params=None, + supplied_headers: Optional[Dict[str, str]] = None, + files=None, + stream: bool = False, + request_id: Optional[str] = None, + request_timeout: Optional[Union[float, Tuple[float, float]]] = None, + ) -> requests.Response: + abs_url, headers, data = self._prepare_request_raw(url, supplied_headers, method, params, files, request_id) + + if not hasattr(_thread_context, "session"): + _thread_context.session = _make_session() + _thread_context.session_create_time = time.time() + elif time.time() - getattr(_thread_context, "session_create_time", 0) >= MAX_SESSION_LIFETIME_SECS: + _thread_context.session.close() + _thread_context.session = _make_session() + _thread_context.session_create_time = time.time() + try: + result = _thread_context.session.request( + method, + abs_url, + headers=headers, + data=data, + files=files, + stream=stream, + timeout=request_timeout if request_timeout else TIMEOUT_SECS, + proxies=_thread_context.session.proxies, + ) + except requests.exceptions.Timeout as e: + raise openai.APITimeoutError("Request timed out: {}".format(e)) from e + except requests.exceptions.RequestException as e: + raise openai.APIConnectionError("Error communicating with OpenAI: {}".format(e)) from e + log_debug( + "OpenAI API response", + path=abs_url, + response_code=result.status_code, + processing_ms=result.headers.get("OpenAI-Processing-Ms"), + request_id=result.headers.get("X-Request-Id"), + ) + return result + + async def arequest_raw( + self, + method, + url, + session, + *, + params=None, + supplied_headers: Optional[Dict[str, str]] = None, + files=None, + request_id: Optional[str] = None, + request_timeout: Optional[Union[float, Tuple[float, float]]] = None, + ) -> aiohttp.ClientResponse: + abs_url, headers, data = self._prepare_request_raw(url, supplied_headers, method, params, files, request_id) + + if isinstance(request_timeout, tuple): + timeout = aiohttp.ClientTimeout( + connect=request_timeout[0], + total=request_timeout[1], + ) + else: + timeout = aiohttp.ClientTimeout(total=request_timeout if request_timeout else TIMEOUT_SECS) + + if files: + # TODO: Use `aiohttp.MultipartWriter` to create the multipart form data here. + # For now we use the private `requests` method that is known to have worked so far. + data, content_type = requests.models.RequestEncodingMixin._encode_files(files, data) # type: ignore + headers["Content-Type"] = content_type + request_kwargs = { + "method": method, + "url": abs_url, + "headers": headers, + "data": data, + "timeout": timeout, + } + try: + result = await session.request(**request_kwargs) + log_info( + "OpenAI API response", + path=abs_url, + response_code=result.status, + processing_ms=result.headers.get("OpenAI-Processing-Ms"), + request_id=result.headers.get("X-Request-Id"), + ) + return result + except (aiohttp.ServerTimeoutError, asyncio.TimeoutError) as e: + raise openai.APITimeoutError("Request timed out") from e + except aiohttp.ClientError as e: + raise openai.APIConnectionError("Error communicating with OpenAI") from e + + def _interpret_response( + self, result: requests.Response, stream: bool + ) -> Tuple[Union[OpenAIResponse, Iterator[OpenAIResponse]], bool]: + """Returns the response(s) and a bool indicating whether it is a stream.""" + + async def _interpret_async_response( + self, result: aiohttp.ClientResponse, stream: bool + ) -> Tuple[Union[OpenAIResponse, AsyncGenerator[OpenAIResponse, None]], bool]: + """Returns the response(s) and a bool indicating whether it is a stream.""" + + def _interpret_response_line(self, rbody: str, rcode: int, rheaders, stream: bool) -> OpenAIResponse: + ... + + +@asynccontextmanager +async def aiohttp_session() -> AsyncIterator[aiohttp.ClientSession]: + async with aiohttp.ClientSession() as session: + yield session diff --git a/metagpt/provider/general_api_requestor.py b/metagpt/provider/general_api_requestor.py index 150f2f1e0..f8321cc6b 100644 --- a/metagpt/provider/general_api_requestor.py +++ b/metagpt/provider/general_api_requestor.py @@ -2,20 +2,20 @@ # -*- coding: utf-8 -*- # @Desc : General Async API for http-based LLM model -from typing import AsyncGenerator, Tuple, Union, Optional, Literal -import aiohttp import asyncio +from typing import AsyncGenerator, Tuple, Union -from openai.api_requestor import APIRequestor +import aiohttp from metagpt.logs import logger +from metagpt.provider.general_api_base import APIRequestor class GeneralAPIRequestor(APIRequestor): """ usage - # full_url = "{api_base}{url}" - requester = GeneralAPIRequestor(api_base=api_base) + # full_url = "{base_url}{url}" + requester = GeneralAPIRequestor(base_url=base_url) result, _, api_key = await requester.arequest( method=method, url=url, @@ -26,9 +26,7 @@ class GeneralAPIRequestor(APIRequestor): ) """ - def _interpret_response_line( - self, rbody: str, rcode: int, rheaders, stream: bool - ) -> str: + def _interpret_response_line(self, rbody: str, rcode: int, rheaders, stream: bool) -> str: # just do nothing to meet the APIRequestor process and return the raw data # due to the openai sdk will convert the data into OpenAIResponse which we don't need in general cases. @@ -39,11 +37,9 @@ class GeneralAPIRequestor(APIRequestor): ) -> Tuple[Union[str, AsyncGenerator[str, None]], bool]: if stream and "text/event-stream" in result.headers.get("Content-Type", ""): return ( - self._interpret_response_line( - line, result.status, result.headers, stream=True - ) - async for line in result.content - ), True + self._interpret_response_line(line, result.status, result.headers, stream=True) + async for line in result.content + ), True else: try: await result.read() diff --git a/metagpt/provider/openai_api.py b/metagpt/provider/openai_api.py index 34e5693f8..3853e0ea6 100644 --- a/metagpt/provider/openai_api.py +++ b/metagpt/provider/openai_api.py @@ -5,11 +5,14 @@ @File : openai.py """ import asyncio +import json import time from typing import NamedTuple, Union -import openai -from openai.error import APIConnectionError +import httpx +from openai import APIConnectionError, AsyncOpenAI, AsyncStream, OpenAI +from openai.types import CompletionUsage +from openai.types.chat import ChatCompletion, ChatCompletionChunk from tenacity import ( after_log, retry, @@ -18,7 +21,7 @@ from tenacity import ( wait_fixed, ) -from metagpt.config import CONFIG +from metagpt.config import CONFIG, Config from metagpt.logs import logger from metagpt.provider.base_gpt_api import BaseGPTAPI from metagpt.provider.constant import GENERAL_FUNCTION_SCHEMA, GENERAL_TOOL_CHOICE @@ -144,23 +147,40 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter): def __init__(self): self.__init_openai(CONFIG) - self.llm = openai self.model = CONFIG.openai_api_model self.auto_max_tokens = False self._cost_manager = CostManager() RateLimiter.__init__(self, rpm=self.rpm) - def __init_openai(self, config): - openai.api_key = config.openai_api_key - if config.openai_api_base: - openai.api_base = config.openai_api_base - if config.openai_api_type: - openai.api_type = config.openai_api_type - openai.api_version = config.openai_api_version + def __init_openai(self, config: Config): + client_kwargs, async_client_kwargs = self.__make_client_args(config) + + self.client = OpenAI(**client_kwargs) + self.async_client = AsyncOpenAI(**async_client_kwargs) + self.rpm = int(config.get("RPM", 10)) + def __make_client_args(self, config: Config): + mapping = { + "api_key": "openai_api_key", + "base_url": "openai_base_url", + } + + kwargs = {key: getattr(config, mapping[key]) for key in mapping if getattr(config, mapping[key], None)} + async_kwargs = kwargs.copy() + + # need http_client to support proxy + if config.openai_proxy: + httpx_args = dict(base_url=kwargs["base_url"], proxies=config.openai_proxy) + kwargs["http_client"] = httpx.Client(**httpx_args) + async_kwargs["http_client"] = httpx.AsyncClient(**httpx_args) + + return kwargs, async_kwargs + async def _achat_completion_stream(self, messages: list[dict]) -> str: - response = await openai.ChatCompletion.acreate(**self._cons_kwargs(messages), stream=True) + response: AsyncStream[ChatCompletionChunk] = await self.async_client.chat.completions.create( + **self._cons_kwargs(messages), stream=True + ) # create variables to collect the stream of chunks collected_chunks = [] @@ -168,15 +188,14 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter): # iterate through the stream of events async for chunk in response: collected_chunks.append(chunk) # save the event response - choices = chunk["choices"] - if len(choices) > 0: - chunk_message = chunk["choices"][0].get("delta", {}) # extract the message + if chunk.choices: + chunk_message = chunk.choices[0].delta # extract the message collected_messages.append(chunk_message) # save the message - if "content" in chunk_message: - print(chunk_message["content"], end="") + if chunk_message.content: + print(chunk_message.content, end="") print() - full_reply_content = "".join([m.get("content", "") for m in collected_messages]) + full_reply_content = "".join([m.content for m in collected_messages if m.content]) usage = self._calc_usage(messages, full_reply_content) self._update_costs(usage) return full_reply_content @@ -208,24 +227,20 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter): kwargs.update(kwargs_mode) return kwargs - async def _achat_completion(self, messages: list[dict]) -> dict: - rsp = await self.llm.ChatCompletion.acreate(**self._cons_kwargs(messages)) - self._update_costs(rsp.get("usage")) + async def _achat_completion(self, messages: list[dict]) -> ChatCompletion: + rsp: ChatCompletion = await self.async_client.chat.completions.create(**self._cons_kwargs(messages)) + self._update_costs(rsp.usage) return rsp - def _chat_completion(self, messages: list[dict]) -> dict: - rsp = self.llm.ChatCompletion.create(**self._cons_kwargs(messages)) - self._update_costs(rsp) + def _chat_completion(self, messages: list[dict]) -> ChatCompletion: + rsp: ChatCompletion = self.client.chat.completions.create(**self._cons_kwargs(messages)) + self._update_costs(rsp.usage) return rsp - def completion(self, messages: list[dict]) -> dict: - # if isinstance(messages[0], Message): - # messages = self.messages_to_dict(messages) + def completion(self, messages: list[dict]) -> ChatCompletion: return self._chat_completion(messages) - async def acompletion(self, messages: list[dict]) -> dict: - # if isinstance(messages[0], Message): - # messages = self.messages_to_dict(messages) + async def acompletion(self, messages: list[dict]) -> ChatCompletion: return await self._achat_completion(messages) @retry( @@ -255,14 +270,16 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter): return self._cons_kwargs(messages, **kwargs) - def _chat_completion_function(self, messages: list[dict], **kwargs) -> dict: - rsp = self.llm.ChatCompletion.create(**self._func_configs(messages, **kwargs)) - self._update_costs(rsp.get("usage")) + def _chat_completion_function(self, messages: list[dict], **kwargs) -> ChatCompletion: + rsp: ChatCompletion = self.client.chat.completions.create(**self._func_configs(messages, **kwargs)) + self._update_costs(rsp.usage) return rsp - async def _achat_completion_function(self, messages: list[dict], **chat_configs) -> dict: - rsp = await self.llm.ChatCompletion.acreate(**self._func_configs(messages, **chat_configs)) - self._update_costs(rsp.get("usage")) + async def _achat_completion_function(self, messages: list[dict], **chat_configs) -> ChatCompletion: + rsp: ChatCompletion = await self.async_client.chat.completions.create( + **self._func_configs(messages, **chat_configs) + ) + self._update_costs(rsp.usage) return rsp def _process_message(self, messages: Union[str, Message, list[dict], list[Message], list[str]]) -> list[dict]: @@ -317,21 +334,34 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter): rsp = await self._achat_completion_function(messages, **kwargs) return self.get_choice_function_arguments(rsp) - def _calc_usage(self, messages: list[dict], rsp: str) -> dict: - usage = {} + def get_choice_function_arguments(self, rsp: ChatCompletion) -> dict: + """Required to provide the first function arguments of choice. + + :return dict: return the first function arguments of choice, for example, + {'language': 'python', 'code': "print('Hello, World!')"} + """ + try: + return json.loads(rsp.choices[0].message.tool_calls[0].function.arguments) + except json.JSONDecodeError: + return {} + + def get_choice_text(self, rsp: ChatCompletion) -> str: + """Required to provide the first text of choice""" + return rsp.choices[0].message.content if rsp.choices else "" + + def _calc_usage(self, messages: list[dict], rsp: str) -> CompletionUsage: + usage = CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0) if CONFIG.calc_usage: try: - prompt_tokens = count_message_tokens(messages, self.model) - completion_tokens = count_string_tokens(rsp, self.model) - usage["prompt_tokens"] = prompt_tokens - usage["completion_tokens"] = completion_tokens + usage.prompt_tokens = count_message_tokens(messages, self.model) + usage.completion_tokens = count_string_tokens(rsp, self.model) return usage except Exception as e: logger.error("usage calculation failed!", e) else: return usage - async def acompletion_batch(self, batch: list[list[dict]]) -> list[dict]: + async def acompletion_batch(self, batch: list[list[dict]]) -> list[ChatCompletion]: """Return full JSON""" split_batches = self.split_batches(batch) all_results = [] @@ -357,12 +387,10 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter): logger.info(f"Result of task {idx}: {result}") return results - def _update_costs(self, usage: dict): + def _update_costs(self, usage: CompletionUsage): if CONFIG.calc_usage: try: - prompt_tokens = int(usage["prompt_tokens"]) - completion_tokens = int(usage["completion_tokens"]) - self._cost_manager.update_cost(prompt_tokens, completion_tokens, self.model) + self._cost_manager.update_cost(usage.prompt_tokens, usage.completion_tokens, self.model) except Exception as e: logger.error("updating costs failed!", e) @@ -385,7 +413,7 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter): logger.error(f"moderating failed:{e}") def _moderation(self, content: Union[str, list[str]]): - rsp = self.llm.Moderation.create(input=content) + rsp = self.client.moderations.create(input=content) return rsp async def amoderation(self, content: Union[str, list[str]]): @@ -399,5 +427,5 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter): logger.error(f"moderating failed:{e}") async def _amoderation(self, content: Union[str, list[str]]): - rsp = await self.llm.Moderation.acreate(input=content) + rsp = await self.async_client.moderations.create(input=content) return rsp diff --git a/metagpt/provider/zhipuai/zhipu_model_api.py b/metagpt/provider/zhipuai/zhipu_model_api.py index 618b2e865..19eb52530 100644 --- a/metagpt/provider/zhipuai/zhipu_model_api.py +++ b/metagpt/provider/zhipuai/zhipu_model_api.py @@ -3,15 +3,14 @@ # @Desc : zhipu model api to support sync & async for invoke & sse_invoke import zhipuai -from zhipuai.model_api.api import ModelAPI, InvokeType +from zhipuai.model_api.api import InvokeType, ModelAPI from zhipuai.utils.http_client import headers as zhipuai_default_headers -from metagpt.provider.zhipuai.async_sse_client import AsyncSSEClient from metagpt.provider.general_api_requestor import GeneralAPIRequestor +from metagpt.provider.zhipuai.async_sse_client import AsyncSSEClient class ZhiPuModelAPI(ModelAPI): - @classmethod def get_header(cls) -> dict: token = cls._generate_token() @@ -21,9 +20,7 @@ class ZhiPuModelAPI(ModelAPI): @classmethod def get_sse_header(cls) -> dict: token = cls._generate_token() - headers = { - "Authorization": token - } + headers = {"Authorization": token} return headers @classmethod @@ -44,36 +41,32 @@ class ZhiPuModelAPI(ModelAPI): # TODO to make the async request to be more generic for models in http mode. assert method in ["post", "get"] - api_base, url = cls.split_zhipu_api_url(invoke_type, kwargs) - requester = GeneralAPIRequestor(api_base=api_base) + base_url, url = cls.split_zhipu_api_url(invoke_type, kwargs) + requester = GeneralAPIRequestor(base_url=base_url) result, _, api_key = await requester.arequest( method=method, url=url, headers=headers, stream=stream, params=kwargs, - request_timeout=zhipuai.api_timeout_seconds + request_timeout=zhipuai.api_timeout_seconds, ) return result @classmethod async def ainvoke(cls, **kwargs) -> dict: - """ async invoke different from raw method `async_invoke` which get the final result by task_id""" + """async invoke different from raw method `async_invoke` which get the final result by task_id""" headers = cls.get_header() - resp = await cls.arequest(invoke_type=InvokeType.SYNC, - stream=False, - method="post", - headers=headers, - kwargs=kwargs) + resp = await cls.arequest( + invoke_type=InvokeType.SYNC, stream=False, method="post", headers=headers, kwargs=kwargs + ) return resp @classmethod async def asse_invoke(cls, **kwargs) -> AsyncSSEClient: - """ async sse_invoke """ + """async sse_invoke""" headers = cls.get_sse_header() - return AsyncSSEClient(await cls.arequest(invoke_type=InvokeType.SSE, - stream=True, - method="post", - headers=headers, - kwargs=kwargs)) + return AsyncSSEClient( + await cls.arequest(invoke_type=InvokeType.SSE, stream=True, method="post", headers=headers, kwargs=kwargs) + ) diff --git a/metagpt/tools/code_interpreter.py b/metagpt/tools/code_interpreter.py index e41eaab72..9575d6c13 100644 --- a/metagpt/tools/code_interpreter.py +++ b/metagpt/tools/code_interpreter.py @@ -1,22 +1,26 @@ +import inspect import re -from typing import List, Callable, Dict +import textwrap from pathlib import Path +from typing import Callable, Dict, List import wrapt -import textwrap -import inspect from interpreter.core.core import Interpreter -from metagpt.logs import logger +from metagpt.actions.clone_function import ( + CloneFunction, + run_function_code, + run_function_script, +) from metagpt.config import CONFIG +from metagpt.logs import logger from metagpt.utils.highlight import highlight -from metagpt.actions.clone_function import CloneFunction, run_function_code, run_function_script def extract_python_code(code: str): """Extract code blocks: If the code comments are the same, only the last code block is kept.""" # Use regular expressions to match comment blocks and related code. - pattern = r'(#\s[^\n]*)\n(.*?)(?=\n\s*#|$)' + pattern = r"(#\s[^\n]*)\n(.*?)(?=\n\s*#|$)" matches = re.findall(pattern, code, re.DOTALL) # Extract the last code block when encountering the same comment. @@ -25,8 +29,8 @@ def extract_python_code(code: str): unique_comments[comment] = code_block # concatenate into functional form - result_code = '\n'.join([f"{comment}\n{code_block}" for comment, code_block in unique_comments.items()]) - header_code = code[:code.find("#")] + result_code = "\n".join([f"{comment}\n{code_block}" for comment, code_block in unique_comments.items()]) + header_code = code[: code.find("#")] code = header_code + result_code logger.info(f"Extract python code: \n {highlight(code)}") @@ -36,12 +40,12 @@ def extract_python_code(code: str): class OpenCodeInterpreter(object): """https://github.com/KillianLucas/open-interpreter""" + def __init__(self, auto_run: bool = True) -> None: interpreter = Interpreter() interpreter.auto_run = auto_run interpreter.model = CONFIG.openai_api_model or "gpt-3.5-turbo" interpreter.api_key = CONFIG.openai_api_key - # interpreter.api_base = CONFIG.openai_api_base self.interpreter = interpreter def chat(self, query: str, reset: bool = True): @@ -50,15 +54,16 @@ class OpenCodeInterpreter(object): return self.interpreter.chat(query) @staticmethod - def extract_function(query_respond: List, function_name: str, *, language: str = 'python', - function_format: str = None) -> str: + def extract_function( + query_respond: List, function_name: str, *, language: str = "python", function_format: str = None + ) -> str: """create a function from query_respond.""" - if language not in ('python'): + if language not in ("python"): raise NotImplementedError(f"Not support to parse language {language}!") # set function form if function_format is None: - assert language == 'python', f"Expect python language for default function_format, but got {language}." + assert language == "python", f"Expect python language for default function_format, but got {language}." function_format = """def {function_name}():\n{code}""" # Extract the code module in the open-interpreter respond message. # The query_respond of open-interpreter before v0.1.4 is: @@ -68,25 +73,29 @@ class OpenCodeInterpreter(object): # "parsed_arguments": {"language": "python", "code": code of first plan} # ...] if "function_call" in query_respond[1]: - code = [item['function_call']['parsed_arguments']['code'] for item in query_respond - if "function_call" in item - and "parsed_arguments" in item["function_call"] - and 'language' in item["function_call"]['parsed_arguments'] - and item["function_call"]['parsed_arguments']['language'] == language] + code = [ + item["function_call"]["parsed_arguments"]["code"] + for item in query_respond + if "function_call" in item + and "parsed_arguments" in item["function_call"] + and "language" in item["function_call"]["parsed_arguments"] + and item["function_call"]["parsed_arguments"]["language"] == language + ] # The query_respond of open-interpreter v0.1.7 is: # [{'role': 'user', 'message': your query string}, # {'role': 'assistant', 'message': plan from llm, 'language': 'python', # 'code': code of first plan, 'output': output of first plan code}, # ...] elif "code" in query_respond[1]: - code = [item['code'] for item in query_respond - if "code" in item - and 'language' in item - and item['language'] == language] + code = [ + item["code"] + for item in query_respond + if "code" in item and "language" in item and item["language"] == language + ] else: raise ValueError(f"Unexpect message format in query_respond: {query_respond[1].keys()}") # add indent. - indented_code_str = textwrap.indent("\n".join(code), ' ' * 4) + indented_code_str = textwrap.indent("\n".join(code), " " * 4) # Return the code after deduplication. if language == "python": return extract_python_code(function_format.format(function_name=function_name, code=indented_code_str)) @@ -115,13 +124,13 @@ class OpenInterpreterDecorator(object): def _have_code(self, rsp: List[Dict]): # Is there any code generated? - return 'code' in rsp[1] and rsp[1]['code'] not in ("", None) + return "code" in rsp[1] and rsp[1]["code"] not in ("", None) def _is_faild_plan(self, rsp: List[Dict]): # is faild plan? - func_code = OpenCodeInterpreter.extract_function(rsp, 'function') + func_code = OpenCodeInterpreter.extract_function(rsp, "function") # If there is no more than 1 '\n', the plan execution fails. - if isinstance(func_code, str) and func_code.count('\n') <= 1: + if isinstance(func_code, str) and func_code.count("\n") <= 1: return True return False @@ -184,4 +193,5 @@ class OpenInterpreterDecorator(object): logger.error(f"Could not evaluate Python code \n{logger_code}: \nError: {e}") raise Exception("Could not evaluate Python code", e) return res + return wrapper(wrapped) diff --git a/metagpt/utils/make_sk_kernel.py b/metagpt/utils/make_sk_kernel.py index 5e919abeb..83b4005ec 100644 --- a/metagpt/utils/make_sk_kernel.py +++ b/metagpt/utils/make_sk_kernel.py @@ -21,14 +21,12 @@ def make_sk_kernel(): if CONFIG.openai_api_type == "azure": kernel.add_chat_service( "chat_completion", - AzureChatCompletion(CONFIG.deployment_name, CONFIG.openai_api_base, CONFIG.openai_api_key), + AzureChatCompletion(CONFIG.deployment_name, CONFIG.openai_base_url, CONFIG.openai_api_key), ) else: kernel.add_chat_service( "chat_completion", - OpenAIChatCompletion( - CONFIG.openai_api_model, CONFIG.openai_api_key, org_id=None, endpoint=CONFIG.openai_api_base - ), + OpenAIChatCompletion(CONFIG.openai_api_model, CONFIG.openai_api_key), ) return kernel diff --git a/requirements.txt b/requirements.txt index f0169d7fa..94aedbec7 100644 --- a/requirements.txt +++ b/requirements.txt @@ -14,7 +14,7 @@ langchain==0.0.231 loguru==0.6.0 meilisearch==0.21.0 numpy==1.24.3 -openai>=0.28.0 +openai>=1.0.0 openpyxl beautifulsoup4==4.12.2 pandas==2.0.3 @@ -41,7 +41,7 @@ qdrant-client==1.4.0 pytest-mock==3.11.1 open-interpreter==0.1.7; python_version>"3.9" ta==0.10.2 -semantic-kernel==0.3.13.dev0 +semantic-kernel==0.4.0.dev0 wrapt==1.15.0 websocket-client==0.58.0 zhipuai==1.0.7 diff --git a/tests/metagpt/provider/test_zhipuai_api.py b/tests/metagpt/provider/test_zhipuai_api.py index 6a0c70de5..08c95a337 100644 --- a/tests/metagpt/provider/test_zhipuai_api.py +++ b/tests/metagpt/provider/test_zhipuai_api.py @@ -6,27 +6,17 @@ import pytest from metagpt.provider.zhipuai_api import ZhiPuAIGPTAPI +default_resp = {"code": 200, "data": {"choices": [{"role": "assistant", "content": "I'm chatglm-turbo"}]}} -default_resp = { - "code": 200, - "data": { - "choices": [ - {"role": "assistant", "content": "I'm chatglm-turbo"} - ] - } -} - -messages = [ - {"role": "user", "content": "who are you"} -] +messages = [{"role": "user", "content": "who are you"}] def mock_llm_ask(self, messages: list[dict]) -> dict: return default_resp -def test_zhipuai_completion(mocker): - mocker.patch("metagpt.provider.zhipuai_api.ZhiPuAIGPTAPI.completion", mock_llm_ask) +def test_zhipuai_completion(monkeypatch: pytest.MonkeyPatch): + monkeypatch.setattr(ZhiPuAIGPTAPI, "completion", mock_llm_ask) resp = ZhiPuAIGPTAPI().completion(messages) assert resp["code"] == 200 @@ -38,8 +28,8 @@ async def mock_llm_aask(self, messgaes: list[dict], stream: bool = False) -> dic @pytest.mark.asyncio -async def test_zhipuai_acompletion(mocker): - mocker.patch("metagpt.provider.zhipuai_api.ZhiPuAIGPTAPI.acompletion_text", mock_llm_aask) +async def test_zhipuai_acompletion(monkeypatch: pytest.MonkeyPatch): + monkeypatch.setattr(ZhiPuAIGPTAPI, "acompletion_text", mock_llm_aask) resp = await ZhiPuAIGPTAPI().acompletion_text(messages, stream=False) From 09134c9c725c1289eec7152d16690c9a3a6aa3e2 Mon Sep 17 00:00:00 2001 From: seehi <6580@pm.me> Date: Tue, 5 Dec 2023 15:27:57 +0800 Subject: [PATCH 018/167] support new openai package --- config/config.yaml | 4 +-- docs/FAQ-EN.md | 6 ++-- docs/README_JA.md | 2 +- docs/tutorial/usage.md | 2 +- docs/tutorial/usage_cn.md | 2 +- metagpt/provider/openai_api.py | 25 +++++++++++----- metagpt/utils/common.py | 6 ++++ tests/metagpt/provider/test_openai.py | 41 +++++++++++++++++++++++++++ 8 files changed, 73 insertions(+), 15 deletions(-) diff --git a/config/config.yaml b/config/config.yaml index 249552693..9ef923366 100644 --- a/config/config.yaml +++ b/config/config.yaml @@ -2,10 +2,10 @@ # The configuration of key.yaml has a higher priority and will not enter git #### if OpenAI -## The official OPENAI_BASE_URL is https://api.openai.com/v1/ +## The official OPENAI_BASE_URL is https://api.openai.com/v1 ## If the official OPENAI_BASE_URL is not available, we recommend using the [openai-forward](https://github.com/beidongjiedeguang/openai-forward). ## Or, you can configure OPENAI_PROXY to access official OPENAI_BASE_URL. -OPENAI_BASE_URL: "https://api.openai.com/v1/" +OPENAI_BASE_URL: "https://api.openai.com/v1" #OPENAI_PROXY: "http://127.0.0.1:8118" #OPENAI_API_KEY: "YOUR_API_KEY" # set the value to sk-xxx if you host the openai interface for open llm model OPENAI_API_MODEL: "gpt-4" diff --git a/docs/FAQ-EN.md b/docs/FAQ-EN.md index 1c5b4a86a..fe2def1e1 100644 --- a/docs/FAQ-EN.md +++ b/docs/FAQ-EN.md @@ -83,10 +83,10 @@ 1. PRD stuck / unable to access/ connection interrupted - 1. The official OPENAI_BASE_URL address is `https://api.openai.com/v1/` - 1. If the official OPENAI_BASE_URL address is inaccessible in your environment (this can be verified with curl), it's recommended to configure using the reverse proxy OPENAI_BASE_URL provided by libraries such as openai-forward. For instance, `OPENAI_BASE_URL: "``https://api.openai-forward.com/v1/``"` + 1. The official OPENAI_BASE_URL address is `https://api.openai.com/v1` + 1. If the official OPENAI_BASE_URL address is inaccessible in your environment (this can be verified with curl), it's recommended to configure using the reverse proxy OPENAI_BASE_URL provided by libraries such as openai-forward. For instance, `OPENAI_BASE_URL: "``https://api.openai-forward.com/v1``"` 1. If the official OPENAI_BASE_URL address is inaccessible in your environment (again, verifiable via curl), another option is to configure the OPENAI_PROXY parameter. This way, you can access the official OPENAI_BASE_URL via a local proxy. If you don't need to access via a proxy, please do not enable this configuration; if accessing through a proxy is required, modify it to the correct proxy address. Note that when OPENAI_PROXY is enabled, don't set OPENAI_BASE_URL. - 1. Note: OpenAI's default API design ends with a v1. An example of the correct configuration is: `OPENAI_BASE_URL: "``https://api.openai.com/v1/``"` + 1. Note: OpenAI's default API design ends with a v1. An example of the correct configuration is: `OPENAI_BASE_URL: "``https://api.openai.com/v1``"` 1. Absolutely! How can I assist you today? diff --git a/docs/README_JA.md b/docs/README_JA.md index 33b08b770..14e7c3111 100644 --- a/docs/README_JA.md +++ b/docs/README_JA.md @@ -219,7 +219,7 @@ # 設定ファイルをコピーし、必要な修正を加える。 | 変数名 | config/key.yaml | env | | --------------------------------------- | ----------------------------------------- | ----------------------------------------------- | | OPENAI_API_KEY # 自分のキーに置き換える | OPENAI_API_KEY: "sk-..." | export OPENAI_API_KEY="sk-..." | -| OPENAI_BASE_URL # オプション | OPENAI_BASE_URL: "https:///v1/" | export OPENAI_BASE_URL="https:///v1/" | +| OPENAI_BASE_URL # オプション | OPENAI_BASE_URL: "https:///v1" | export OPENAI_BASE_URL="https:///v1" | ## チュートリアル: スタートアップの開始 diff --git a/docs/tutorial/usage.md b/docs/tutorial/usage.md index f8a25c84f..e6b4a7cc5 100644 --- a/docs/tutorial/usage.md +++ b/docs/tutorial/usage.md @@ -13,7 +13,7 @@ # Copy the configuration file and make the necessary modifications. | Variable Name | config/key.yaml | env | | ------------------------------------------ | ----------------------------------------- | ----------------------------------------------- | | OPENAI_API_KEY # Replace with your own key | OPENAI_API_KEY: "sk-..." | export OPENAI_API_KEY="sk-..." | -| OPENAI_BASE_URL # Optional | OPENAI_BASE_URL: "https:///v1/" | export OPENAI_BASE_URL="https:///v1/" | +| OPENAI_BASE_URL # Optional | OPENAI_BASE_URL: "https:///v1" | export OPENAI_BASE_URL="https:///v1" | ### Initiating a startup diff --git a/docs/tutorial/usage_cn.md b/docs/tutorial/usage_cn.md index ddd1c2267..195eec674 100644 --- a/docs/tutorial/usage_cn.md +++ b/docs/tutorial/usage_cn.md @@ -13,7 +13,7 @@ # 复制配置文件并进行必要的修改 | 变量名 | config/key.yaml | env | | ----------------------------------- | ----------------------------------------- | ----------------------------------------------- | | OPENAI_API_KEY # 用您自己的密钥替换 | OPENAI_API_KEY: "sk-..." | export OPENAI_API_KEY="sk-..." | -| OPENAI_BASE_URL # 可选 | OPENAI_BASE_URL: "https:///v1/" | export OPENAI_BASE_URL="https:///v1/" | +| OPENAI_BASE_URL # 可选 | OPENAI_BASE_URL: "https:///v1" | export OPENAI_BASE_URL="https:///v1" | ### 示例:启动一个创业公司 diff --git a/metagpt/provider/openai_api.py b/metagpt/provider/openai_api.py index 3853e0ea6..98551c370 100644 --- a/metagpt/provider/openai_api.py +++ b/metagpt/provider/openai_api.py @@ -26,6 +26,7 @@ from metagpt.logs import logger from metagpt.provider.base_gpt_api import BaseGPTAPI from metagpt.provider.constant import GENERAL_FUNCTION_SCHEMA, GENERAL_TOOL_CHOICE from metagpt.schema import Message +from metagpt.utils.common import ensure_trailing_slash from metagpt.utils.singleton import Singleton from metagpt.utils.token_counter import ( TOKEN_COSTS, @@ -153,27 +154,37 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter): RateLimiter.__init__(self, rpm=self.rpm) def __init_openai(self, config: Config): - client_kwargs, async_client_kwargs = self.__make_client_args(config) + client_kwargs, async_client_kwargs = self._make_client_kwargs(config) self.client = OpenAI(**client_kwargs) self.async_client = AsyncOpenAI(**async_client_kwargs) self.rpm = int(config.get("RPM", 10)) - def __make_client_args(self, config: Config): + def _make_client_kwargs(self, config: Config) -> (dict, dict): mapping = { "api_key": "openai_api_key", "base_url": "openai_base_url", } + kwargs = {} + for key, attr in mapping.items(): + value = getattr(config, attr, None) + if value: + kwargs[key] = value + + if config.openai_base_url: + kwargs["base_url"] = ensure_trailing_slash(config.openai_base_url) - kwargs = {key: getattr(config, mapping[key]) for key in mapping if getattr(config, mapping[key], None)} async_kwargs = kwargs.copy() - # need http_client to support proxy + # Create http_client if proxy is specified if config.openai_proxy: - httpx_args = dict(base_url=kwargs["base_url"], proxies=config.openai_proxy) - kwargs["http_client"] = httpx.Client(**httpx_args) - async_kwargs["http_client"] = httpx.AsyncClient(**httpx_args) + params = {"proxies": config.openai_proxy} + if config.openai_base_url: + params["base_url"] = config.openai_base_url + + kwargs["http_client"] = httpx.Client(**params) + async_kwargs["http_client"] = httpx.AsyncClient(**params) return kwargs, async_kwargs diff --git a/metagpt/utils/common.py b/metagpt/utils/common.py index f09666beb..c69a0fe10 100644 --- a/metagpt/utils/common.py +++ b/metagpt/utils/common.py @@ -305,3 +305,9 @@ def parse_recipient(text): pattern = r"## Send To:\s*([A-Za-z]+)\s*?" # hard code for now recipient = re.search(pattern, text) return recipient.group(1) if recipient else "" + + +def ensure_trailing_slash(url): + if not url: + return url + return url if url.endswith("/") else url + "/" diff --git a/tests/metagpt/provider/test_openai.py b/tests/metagpt/provider/test_openai.py index 2b0af37b5..3e8dbf7e7 100644 --- a/tests/metagpt/provider/test_openai.py +++ b/tests/metagpt/provider/test_openai.py @@ -1,4 +1,5 @@ import pytest +from httpx import AsyncClient, Client from metagpt.provider.openai_api import OpenAIGPTAPI from metagpt.schema import UserMessage @@ -78,3 +79,43 @@ def test_ask_code_list_str(): assert "language" in rsp assert "code" in rsp assert len(rsp["code"]) > 0 + + +def test_make_client_kwargs(): + class Config: + openai_api_key = "test_key" + openai_base_url = "test_url" + openai_proxy = "http://test_proxy" + + config = Config() + obj = OpenAIGPTAPI() + kwargs, async_kwargs = obj._make_client_kwargs(config) + + assert kwargs["api_key"] == "test_key" + assert kwargs["base_url"] == "test_url/" + assert isinstance(kwargs["http_client"], Client) + assert kwargs["http_client"].base_url == "test_url/" + + assert async_kwargs["api_key"] == "test_key" + assert async_kwargs["base_url"] == "test_url/" + assert isinstance(async_kwargs["http_client"], AsyncClient) + assert async_kwargs["http_client"].base_url == "test_url/" + + +def test_make_client_kwargs_no_proxy(): + class Config: + openai_api_key = "test_key" + openai_base_url = "test_url" + openai_proxy = None + + config = Config() + obj = OpenAIGPTAPI() + kwargs, async_kwargs = obj._make_client_kwargs(config) + + assert kwargs["api_key"] == "test_key" + assert kwargs["base_url"] == "test_url/" + assert "http_client" not in kwargs + + assert async_kwargs["api_key"] == "test_key" + assert async_kwargs["base_url"] == "test_url/" + assert "http_client" not in async_kwargs From 0d8b9cdc89ebf17f7d282e8f35745a17451d68ee Mon Sep 17 00:00:00 2001 From: seehi <6580@pm.me> Date: Tue, 5 Dec 2023 15:36:38 +0800 Subject: [PATCH 019/167] support new openai package --- metagpt/provider/openai_api.py | 1 + 1 file changed, 1 insertion(+) diff --git a/metagpt/provider/openai_api.py b/metagpt/provider/openai_api.py index 98551c370..733048b67 100644 --- a/metagpt/provider/openai_api.py +++ b/metagpt/provider/openai_api.py @@ -172,6 +172,7 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter): if value: kwargs[key] = value + # OpenAI v1 requires the base_url to end with / if config.openai_base_url: kwargs["base_url"] = ensure_trailing_slash(config.openai_base_url) From f03a6d802978f7a56279f9852af607a71357d3e3 Mon Sep 17 00:00:00 2001 From: seehi <6580@pm.me> Date: Tue, 5 Dec 2023 16:21:34 +0800 Subject: [PATCH 020/167] support new openai package --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 94aedbec7..93b7319f9 100644 --- a/requirements.txt +++ b/requirements.txt @@ -14,7 +14,7 @@ langchain==0.0.231 loguru==0.6.0 meilisearch==0.21.0 numpy==1.24.3 -openai>=1.0.0 +openai~=1.3 openpyxl beautifulsoup4==4.12.2 pandas==2.0.3 From a617aab65b506a35c3edd3586845d3307427fff1 Mon Sep 17 00:00:00 2001 From: seehi <6580@pm.me> Date: Wed, 6 Dec 2023 11:58:13 +0800 Subject: [PATCH 021/167] azure client --- metagpt/provider/openai_api.py | 71 ++++++++++++--------- metagpt/utils/common.py | 6 -- tests/metagpt/provider/test_openai.py | 88 +++++++++++++++++---------- 3 files changed, 98 insertions(+), 67 deletions(-) diff --git a/metagpt/provider/openai_api.py b/metagpt/provider/openai_api.py index 733048b67..7fdc6ece0 100644 --- a/metagpt/provider/openai_api.py +++ b/metagpt/provider/openai_api.py @@ -10,7 +10,14 @@ import time from typing import NamedTuple, Union import httpx -from openai import APIConnectionError, AsyncOpenAI, AsyncStream, OpenAI +from openai import ( + APIConnectionError, + AsyncAzureOpenAI, + AsyncOpenAI, + AsyncStream, + AzureOpenAI, + OpenAI, +) from openai.types import CompletionUsage from openai.types.chat import ChatCompletion, ChatCompletionChunk from tenacity import ( @@ -26,7 +33,6 @@ from metagpt.logs import logger from metagpt.provider.base_gpt_api import BaseGPTAPI from metagpt.provider.constant import GENERAL_FUNCTION_SCHEMA, GENERAL_TOOL_CHOICE from metagpt.schema import Message -from metagpt.utils.common import ensure_trailing_slash from metagpt.utils.singleton import Singleton from metagpt.utils.token_counter import ( TOKEN_COSTS, @@ -154,40 +160,49 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter): RateLimiter.__init__(self, rpm=self.rpm) def __init_openai(self, config: Config): - client_kwargs, async_client_kwargs = self._make_client_kwargs(config) - - self.client = OpenAI(**client_kwargs) - self.async_client = AsyncOpenAI(**async_client_kwargs) - + self._make_client(config) self.rpm = int(config.get("RPM", 10)) - def _make_client_kwargs(self, config: Config) -> (dict, dict): - mapping = { - "api_key": "openai_api_key", - "base_url": "openai_base_url", - } - kwargs = {} - for key, attr in mapping.items(): - value = getattr(config, attr, None) - if value: - kwargs[key] = value + def _make_client(self, config: Config): + kwargs, async_kwargs = self._make_client_kwargs(config) - # OpenAI v1 requires the base_url to end with / - if config.openai_base_url: - kwargs["base_url"] = ensure_trailing_slash(config.openai_base_url) + if self._is_azure(config): + self.client = AzureOpenAI(**kwargs) + self.async_client = AsyncAzureOpenAI(**async_kwargs) + else: + self.client = OpenAI(**kwargs) + self.async_client = AsyncOpenAI(**async_kwargs) + + def _make_client_kwargs(self, config: Config) -> (dict, dict): + if self._is_azure(config): + kwargs = dict( + api_key=config.openai_api_key, + api_version=config.openai_api_version, + azure_endpoint=config.openai_base_url, + ) + else: + kwargs = dict(api_key=config.openai_api_key, base_url=config.openai_base_url) async_kwargs = kwargs.copy() - # Create http_client if proxy is specified + # to use proxy, openai v1 needs http_client + proxy_params = self._get_proxy_params(config) + if proxy_params: + kwargs["http_client"] = httpx.Client(**proxy_params) + async_kwargs["http_client"] = httpx.AsyncClient(**proxy_params) + + return kwargs, async_kwargs + + def _is_azure(self, config: Config) -> bool: + return config.openai_api_type == "azure" + + def _get_proxy_params(self, config: Config) -> dict: + params = {} if config.openai_proxy: params = {"proxies": config.openai_proxy} if config.openai_base_url: params["base_url"] = config.openai_base_url - - kwargs["http_client"] = httpx.Client(**params) - async_kwargs["http_client"] = httpx.AsyncClient(**params) - - return kwargs, async_kwargs + return params async def _achat_completion_stream(self, messages: list[dict]) -> str: response: AsyncStream[ChatCompletionChunk] = await self.async_client.chat.completions.create( @@ -230,9 +245,7 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter): elif not CONFIG.deployment_name and not CONFIG.deployment_id: raise ValueError("You must specify `DEPLOYMENT_NAME` or `DEPLOYMENT_ID` parameter") kwargs_mode = ( - {"engine": CONFIG.deployment_name} - if CONFIG.deployment_name - else {"deployment_id": CONFIG.deployment_id} + {"model": CONFIG.deployment_name} if CONFIG.deployment_name else {"deployment_id": CONFIG.deployment_id} ) else: kwargs_mode = {"model": self.model} diff --git a/metagpt/utils/common.py b/metagpt/utils/common.py index c69a0fe10..f09666beb 100644 --- a/metagpt/utils/common.py +++ b/metagpt/utils/common.py @@ -305,9 +305,3 @@ def parse_recipient(text): pattern = r"## Send To:\s*([A-Za-z]+)\s*?" # hard code for now recipient = re.search(pattern, text) return recipient.group(1) if recipient else "" - - -def ensure_trailing_slash(url): - if not url: - return url - return url if url.endswith("/") else url + "/" diff --git a/tests/metagpt/provider/test_openai.py b/tests/metagpt/provider/test_openai.py index 3e8dbf7e7..8d853f11c 100644 --- a/tests/metagpt/provider/test_openai.py +++ b/tests/metagpt/provider/test_openai.py @@ -1,5 +1,6 @@ +from unittest.mock import Mock + import pytest -from httpx import AsyncClient, Client from metagpt.provider.openai_api import OpenAIGPTAPI from metagpt.schema import UserMessage @@ -81,41 +82,64 @@ def test_ask_code_list_str(): assert len(rsp["code"]) > 0 -def test_make_client_kwargs(): - class Config: - openai_api_key = "test_key" - openai_base_url = "test_url" - openai_proxy = "http://test_proxy" +class TestOpenAI: + @pytest.fixture + def config(self): + return Mock(openai_api_key="test_key", openai_base_url="test_url", openai_proxy=None, openai_api_type="other") - config = Config() - obj = OpenAIGPTAPI() - kwargs, async_kwargs = obj._make_client_kwargs(config) + @pytest.fixture + def config_azure(self): + return Mock( + openai_api_key="test_key", + openai_api_version="test_version", + openai_base_url="test_url", + openai_proxy=None, + openai_api_type="azure", + ) - assert kwargs["api_key"] == "test_key" - assert kwargs["base_url"] == "test_url/" - assert isinstance(kwargs["http_client"], Client) - assert kwargs["http_client"].base_url == "test_url/" + @pytest.fixture + def config_proxy(self): + return Mock( + openai_api_key="test_key", + openai_base_url="test_url", + openai_proxy="http://proxy.com", + openai_api_type="other", + ) - assert async_kwargs["api_key"] == "test_key" - assert async_kwargs["base_url"] == "test_url/" - assert isinstance(async_kwargs["http_client"], AsyncClient) - assert async_kwargs["http_client"].base_url == "test_url/" + @pytest.fixture + def config_azure_proxy(self): + return Mock( + openai_api_key="test_key", + openai_api_version="test_version", + openai_base_url="test_url", + openai_proxy="http://proxy.com", + openai_api_type="azure", + ) + def test_make_client_kwargs_without_proxy(self, config): + instance = OpenAIGPTAPI() + kwargs, async_kwargs = instance._make_client_kwargs(config) + assert kwargs == {"api_key": "test_key", "base_url": "test_url"} + assert async_kwargs == {"api_key": "test_key", "base_url": "test_url"} + assert "http_client" not in kwargs + assert "http_client" not in async_kwargs -def test_make_client_kwargs_no_proxy(): - class Config: - openai_api_key = "test_key" - openai_base_url = "test_url" - openai_proxy = None + def test_make_client_kwargs_without_proxy_azure(self, config_azure): + instance = OpenAIGPTAPI() + kwargs, async_kwargs = instance._make_client_kwargs(config_azure) + assert kwargs == {"api_key": "test_key", "api_version": "test_version", "azure_endpoint": "test_url"} + assert async_kwargs == {"api_key": "test_key", "api_version": "test_version", "azure_endpoint": "test_url"} + assert "http_client" not in kwargs + assert "http_client" not in async_kwargs - config = Config() - obj = OpenAIGPTAPI() - kwargs, async_kwargs = obj._make_client_kwargs(config) + def test_make_client_kwargs_with_proxy(self, config_proxy): + instance = OpenAIGPTAPI() + kwargs, async_kwargs = instance._make_client_kwargs(config_proxy) + assert "http_client" in kwargs + assert "http_client" in async_kwargs - assert kwargs["api_key"] == "test_key" - assert kwargs["base_url"] == "test_url/" - assert "http_client" not in kwargs - - assert async_kwargs["api_key"] == "test_key" - assert async_kwargs["base_url"] == "test_url/" - assert "http_client" not in async_kwargs + def test_make_client_kwargs_with_proxy_azure(self, config_azure_proxy): + instance = OpenAIGPTAPI() + kwargs, async_kwargs = instance._make_client_kwargs(config_azure_proxy) + assert "http_client" in kwargs + assert "http_client" in async_kwargs From ad347e0717c3783163249753c7c196e6eb199525 Mon Sep 17 00:00:00 2001 From: seehi <6580@pm.me> Date: Wed, 6 Dec 2023 16:06:17 +0800 Subject: [PATCH 022/167] upgrade tiktoken to support azure --- config/config.yaml | 2 - metagpt/config.py | 3 +- metagpt/provider/openai_api.py | 66 +++++++++++++-------------- metagpt/utils/token_counter.py | 10 +++- requirements.txt | 2 +- tests/metagpt/provider/test_openai.py | 12 +++-- 6 files changed, 50 insertions(+), 45 deletions(-) diff --git a/config/config.yaml b/config/config.yaml index 9ef923366..2846467ed 100644 --- a/config/config.yaml +++ b/config/config.yaml @@ -23,13 +23,11 @@ RPM: 10 #Anthropic_API_KEY: "YOUR_API_KEY" #### if AZURE, check https://github.com/openai/openai-cookbook/blob/main/examples/azure/chat.ipynb -#### You can use ENGINE or DEPLOYMENT mode #OPENAI_API_TYPE: "azure" #OPENAI_BASE_URL: "YOUR_AZURE_ENDPOINT" #OPENAI_API_KEY: "YOUR_AZURE_API_KEY" #OPENAI_API_VERSION: "YOUR_AZURE_API_VERSION" #DEPLOYMENT_NAME: "YOUR_DEPLOYMENT_NAME" -#DEPLOYMENT_ID: "YOUR_DEPLOYMENT_ID" #### if zhipuai from `https://open.bigmodel.cn`. You can set here or export API_KEY="YOUR_API_KEY" # ZHIPUAI_API_KEY: "YOUR_API_KEY" diff --git a/metagpt/config.py b/metagpt/config.py index 4306445ef..4f53a0ff3 100644 --- a/metagpt/config.py +++ b/metagpt/config.py @@ -58,8 +58,7 @@ class Config(metaclass=Singleton): self.openai_api_rpm = self._get("RPM", 3) self.openai_api_model = self._get("OPENAI_API_MODEL", "gpt-4") self.max_tokens_rsp = self._get("MAX_TOKENS", 2048) - self.deployment_name = self._get("DEPLOYMENT_NAME") - self.deployment_id = self._get("DEPLOYMENT_ID") + self.deployment_name = self._get("DEPLOYMENT_NAME", "gpt-4") self.spark_appid = self._get("SPARK_APPID") self.spark_api_secret = self._get("SPARK_API_SECRET") diff --git a/metagpt/provider/openai_api.py b/metagpt/provider/openai_api.py index 7fdc6ece0..6564dcde4 100644 --- a/metagpt/provider/openai_api.py +++ b/metagpt/provider/openai_api.py @@ -153,55 +153,63 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter): """ def __init__(self): - self.__init_openai(CONFIG) - self.model = CONFIG.openai_api_model + self.config: Config = CONFIG + self.__init_openai() self.auto_max_tokens = False self._cost_manager = CostManager() RateLimiter.__init__(self, rpm=self.rpm) - def __init_openai(self, config: Config): - self._make_client(config) - self.rpm = int(config.get("RPM", 10)) + @property + def model(self): + if self._is_azure(): + return self.config.deployment_name - def _make_client(self, config: Config): - kwargs, async_kwargs = self._make_client_kwargs(config) + return self.config.openai_api_model - if self._is_azure(config): + def __init_openai(self): + self._make_client() + self.rpm = int(self.config.get("RPM", 10)) + + def _make_client(self): + kwargs, async_kwargs = self._make_client_kwargs() + + if self._is_azure(): self.client = AzureOpenAI(**kwargs) self.async_client = AsyncAzureOpenAI(**async_kwargs) else: self.client = OpenAI(**kwargs) self.async_client = AsyncOpenAI(**async_kwargs) - def _make_client_kwargs(self, config: Config) -> (dict, dict): - if self._is_azure(config): + def _make_client_kwargs(self) -> (dict, dict): + if self._is_azure(): kwargs = dict( - api_key=config.openai_api_key, - api_version=config.openai_api_version, - azure_endpoint=config.openai_base_url, + api_key=self.config.openai_api_key, + api_version=self.config.openai_api_version, + azure_endpoint=self.config.openai_base_url, ) else: - kwargs = dict(api_key=config.openai_api_key, base_url=config.openai_base_url) + kwargs = dict(api_key=self.config.openai_api_key, base_url=self.config.openai_base_url) async_kwargs = kwargs.copy() # to use proxy, openai v1 needs http_client - proxy_params = self._get_proxy_params(config) + proxy_params = self._get_proxy_params() if proxy_params: kwargs["http_client"] = httpx.Client(**proxy_params) async_kwargs["http_client"] = httpx.AsyncClient(**proxy_params) return kwargs, async_kwargs - def _is_azure(self, config: Config) -> bool: - return config.openai_api_type == "azure" + def _is_azure(self) -> bool: + return self.config.openai_api_type == "azure" - def _get_proxy_params(self, config: Config) -> dict: + def _get_proxy_params(self) -> dict: params = {} - if config.openai_proxy: - params = {"proxies": config.openai_proxy} - if config.openai_base_url: - params["base_url"] = config.openai_base_url + if self.config.openai_proxy: + params = {"proxies": self.config.openai_proxy} + if self.config.openai_base_url: + params["base_url"] = self.config.openai_base_url + return params async def _achat_completion_stream(self, messages: list[dict]) -> str: @@ -235,21 +243,11 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter): "stop": None, "temperature": 0.3, "timeout": 3, + "model": self.model, } if configs: kwargs.update(configs) - if CONFIG.openai_api_type == "azure": - if CONFIG.deployment_name and CONFIG.deployment_id: - raise ValueError("You can only use one of the `deployment_id` or `deployment_name` model") - elif not CONFIG.deployment_name and not CONFIG.deployment_id: - raise ValueError("You must specify `DEPLOYMENT_NAME` or `DEPLOYMENT_ID` parameter") - kwargs_mode = ( - {"model": CONFIG.deployment_name} if CONFIG.deployment_name else {"deployment_id": CONFIG.deployment_id} - ) - else: - kwargs_mode = {"model": self.model} - kwargs.update(kwargs_mode) return kwargs async def _achat_completion(self, messages: list[dict]) -> ChatCompletion: @@ -382,7 +380,7 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter): usage.completion_tokens = count_string_tokens(rsp, self.model) return usage except Exception as e: - logger.error("usage calculation failed!", e) + logger.error(f"usage calculation failed!: {e}") else: return usage diff --git a/metagpt/utils/token_counter.py b/metagpt/utils/token_counter.py index 1af96f272..21de43501 100644 --- a/metagpt/utils/token_counter.py +++ b/metagpt/utils/token_counter.py @@ -16,13 +16,15 @@ TOKEN_COSTS = { "gpt-3.5-turbo-0613": {"prompt": 0.0015, "completion": 0.002}, "gpt-3.5-turbo-16k": {"prompt": 0.003, "completion": 0.004}, "gpt-3.5-turbo-16k-0613": {"prompt": 0.003, "completion": 0.004}, + "gpt-35-turbo": {"prompt": 0.0015, "completion": 0.002}, + "gpt-35-turbo-16k": {"prompt": 0.003, "completion": 0.004}, "gpt-4-0314": {"prompt": 0.03, "completion": 0.06}, "gpt-4": {"prompt": 0.03, "completion": 0.06}, "gpt-4-32k": {"prompt": 0.06, "completion": 0.12}, "gpt-4-32k-0314": {"prompt": 0.06, "completion": 0.12}, "gpt-4-0613": {"prompt": 0.06, "completion": 0.12}, "text-embedding-ada-002": {"prompt": 0.0004, "completion": 0.0}, - "chatglm_turbo": {"prompt": 0.0, "completion": 0.00069} # 32k version, prompt + completion tokens=0.005¥/k-tokens + "chatglm_turbo": {"prompt": 0.0, "completion": 0.00069}, # 32k version, prompt + completion tokens=0.005¥/k-tokens } @@ -32,13 +34,15 @@ TOKEN_MAX = { "gpt-3.5-turbo-0613": 4096, "gpt-3.5-turbo-16k": 16384, "gpt-3.5-turbo-16k-0613": 16384, + "gpt-35-turbo": 4096, + "gpt-35-turbo-16k": 16384, "gpt-4-0314": 8192, "gpt-4": 8192, "gpt-4-32k": 32768, "gpt-4-32k-0314": 32768, "gpt-4-0613": 8192, "text-embedding-ada-002": 8192, - "chatglm_turbo": 32768 + "chatglm_turbo": 32768, } @@ -52,6 +56,8 @@ def count_message_tokens(messages, model="gpt-3.5-turbo-0613"): if model in { "gpt-3.5-turbo-0613", "gpt-3.5-turbo-16k-0613", + "gpt-35-turbo", + "gpt-35-turbo-16k", "gpt-4-0314", "gpt-4-32k-0314", "gpt-4-0613", diff --git a/requirements.txt b/requirements.txt index 93b7319f9..c57fb6c2c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -27,7 +27,7 @@ PyYAML==6.0.1 # sentence_transformers==2.2.2 setuptools==65.6.3 tenacity==8.2.2 -tiktoken==0.4.0 +tiktoken==0.5.2 tqdm==4.64.0 #unstructured[local-inference] # playwright diff --git a/tests/metagpt/provider/test_openai.py b/tests/metagpt/provider/test_openai.py index 8d853f11c..332d554cf 100644 --- a/tests/metagpt/provider/test_openai.py +++ b/tests/metagpt/provider/test_openai.py @@ -118,7 +118,8 @@ class TestOpenAI: def test_make_client_kwargs_without_proxy(self, config): instance = OpenAIGPTAPI() - kwargs, async_kwargs = instance._make_client_kwargs(config) + instance.config = config + kwargs, async_kwargs = instance._make_client_kwargs() assert kwargs == {"api_key": "test_key", "base_url": "test_url"} assert async_kwargs == {"api_key": "test_key", "base_url": "test_url"} assert "http_client" not in kwargs @@ -126,7 +127,8 @@ class TestOpenAI: def test_make_client_kwargs_without_proxy_azure(self, config_azure): instance = OpenAIGPTAPI() - kwargs, async_kwargs = instance._make_client_kwargs(config_azure) + instance.config = config_azure + kwargs, async_kwargs = instance._make_client_kwargs() assert kwargs == {"api_key": "test_key", "api_version": "test_version", "azure_endpoint": "test_url"} assert async_kwargs == {"api_key": "test_key", "api_version": "test_version", "azure_endpoint": "test_url"} assert "http_client" not in kwargs @@ -134,12 +136,14 @@ class TestOpenAI: def test_make_client_kwargs_with_proxy(self, config_proxy): instance = OpenAIGPTAPI() - kwargs, async_kwargs = instance._make_client_kwargs(config_proxy) + instance.config = config_proxy + kwargs, async_kwargs = instance._make_client_kwargs() assert "http_client" in kwargs assert "http_client" in async_kwargs def test_make_client_kwargs_with_proxy_azure(self, config_azure_proxy): instance = OpenAIGPTAPI() - kwargs, async_kwargs = instance._make_client_kwargs(config_azure_proxy) + instance.config = config_azure_proxy + kwargs, async_kwargs = instance._make_client_kwargs() assert "http_client" in kwargs assert "http_client" in async_kwargs From f4505d0e397a816648bb5476e2bc6b3ee505bb8d Mon Sep 17 00:00:00 2001 From: seehi <6580@pm.me> Date: Wed, 6 Dec 2023 16:23:43 +0800 Subject: [PATCH 023/167] upgrade tiktoken to support azure --- metagpt/provider/openai_api.py | 18 +++++------------- 1 file changed, 5 insertions(+), 13 deletions(-) diff --git a/metagpt/provider/openai_api.py b/metagpt/provider/openai_api.py index 6564dcde4..9a328f386 100644 --- a/metagpt/provider/openai_api.py +++ b/metagpt/provider/openai_api.py @@ -159,21 +159,16 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter): self._cost_manager = CostManager() RateLimiter.__init__(self, rpm=self.rpm) - @property - def model(self): - if self._is_azure(): - return self.config.deployment_name - - return self.config.openai_api_model - def __init_openai(self): - self._make_client() + self.is_azure = self.config.openai_api_type == "azure" + self.model = self.config.deployment_name if self.is_azure else self.config.openai_api_model self.rpm = int(self.config.get("RPM", 10)) + self._make_client() def _make_client(self): kwargs, async_kwargs = self._make_client_kwargs() - if self._is_azure(): + if self.is_azure: self.client = AzureOpenAI(**kwargs) self.async_client = AsyncAzureOpenAI(**async_kwargs) else: @@ -181,7 +176,7 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter): self.async_client = AsyncOpenAI(**async_kwargs) def _make_client_kwargs(self) -> (dict, dict): - if self._is_azure(): + if self.is_azure: kwargs = dict( api_key=self.config.openai_api_key, api_version=self.config.openai_api_version, @@ -200,9 +195,6 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter): return kwargs, async_kwargs - def _is_azure(self) -> bool: - return self.config.openai_api_type == "azure" - def _get_proxy_params(self) -> dict: params = {} if self.config.openai_proxy: From 97f156b10d5fdb54dedc8076069d59bfae5713a7 Mon Sep 17 00:00:00 2001 From: seehi <6580@pm.me> Date: Thu, 7 Dec 2023 10:23:08 +0800 Subject: [PATCH 024/167] revert pytest.MonkeyPatch --- tests/metagpt/provider/test_zhipuai_api.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/metagpt/provider/test_zhipuai_api.py b/tests/metagpt/provider/test_zhipuai_api.py index 08c95a337..4684e8887 100644 --- a/tests/metagpt/provider/test_zhipuai_api.py +++ b/tests/metagpt/provider/test_zhipuai_api.py @@ -15,8 +15,8 @@ def mock_llm_ask(self, messages: list[dict]) -> dict: return default_resp -def test_zhipuai_completion(monkeypatch: pytest.MonkeyPatch): - monkeypatch.setattr(ZhiPuAIGPTAPI, "completion", mock_llm_ask) +def test_zhipuai_completion(mocker): + mocker.patch("metagpt.provider.zhipuai_api.ZhiPuAIGPTAPI.completion", mock_llm_ask) resp = ZhiPuAIGPTAPI().completion(messages) assert resp["code"] == 200 @@ -28,8 +28,8 @@ async def mock_llm_aask(self, messgaes: list[dict], stream: bool = False) -> dic @pytest.mark.asyncio -async def test_zhipuai_acompletion(monkeypatch: pytest.MonkeyPatch): - monkeypatch.setattr(ZhiPuAIGPTAPI, "acompletion_text", mock_llm_aask) +async def test_zhipuai_acompletion(mocker): + mocker.patch("metagpt.provider.zhipuai_api.ZhiPuAIGPTAPI.acompletion_text", mock_llm_aask) resp = await ZhiPuAIGPTAPI().acompletion_text(messages, stream=False) From d196bd0cc947aaf47520bfc3157df064a95d8ab5 Mon Sep 17 00:00:00 2001 From: paulaan Date: Sun, 10 Dec 2023 00:15:39 +0700 Subject: [PATCH 025/167] selenium config better performance --- metagpt/tools/web_browser_engine_selenium.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/metagpt/tools/web_browser_engine_selenium.py b/metagpt/tools/web_browser_engine_selenium.py index d727709b8..80b60a93c 100644 --- a/metagpt/tools/web_browser_engine_selenium.py +++ b/metagpt/tools/web_browser_engine_selenium.py @@ -104,6 +104,9 @@ def _gen_get_driver_func(browser_type, *args, executable_path=None): def _get_driver(): options = Options() options.add_argument("--headless") + options.add_argument("--no-sandbox") # This flag is important for running in a Docker container + options.add_argument("--disable-gpu") # This flag can help avoid renderer issue + options.add_argument("--disable-dev-shm-usage") # Overcome limited resource problems options.add_argument("--enable-javascript") if browser_type == "chrome": options.add_argument("--no-sandbox") From c92793c27ceae28cdc0fba67c39648b5cb42cabd Mon Sep 17 00:00:00 2001 From: paulaan Date: Sat, 9 Dec 2023 12:57:54 +0700 Subject: [PATCH 026/167] researcher allow override system prompt --- metagpt/roles/researcher.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/metagpt/roles/researcher.py b/metagpt/roles/researcher.py index c5512121a..f954c60bb 100644 --- a/metagpt/roles/researcher.py +++ b/metagpt/roles/researcher.py @@ -46,7 +46,7 @@ class Researcher(Role): else: topic = msg.content - research_system_text = get_research_system_text(topic, self.language) + research_system_text = self.research_system_text(topic) if isinstance(todo, CollectLinks): links = await todo.run(topic, 4, 4) ret = Message("", Report(topic=topic, links=links), role=self.profile, cause_by=type(todo)) @@ -64,6 +64,17 @@ class Researcher(Role): self._rc.memory.add(ret) return ret + def research_system_text(self, topic) -> str: + """ BACKWARD compatible + This allows sub-class able to define its own system prompt based on topic. + return the previous implementation to have backward compatible + Args: + topic: + language: + + Returns: str + """ + return get_research_system_text(topic, self.language) async def react(self) -> Message: msg = await super().react() report = msg.instruct_content From 6b2fb95e665064a53c5098f28c4771cd5d69d70b Mon Sep 17 00:00:00 2001 From: paulaan Date: Sat, 9 Dec 2023 12:58:51 +0700 Subject: [PATCH 027/167] reformat for code convention --- metagpt/roles/researcher.py | 1 + 1 file changed, 1 insertion(+) diff --git a/metagpt/roles/researcher.py b/metagpt/roles/researcher.py index f954c60bb..c60d54486 100644 --- a/metagpt/roles/researcher.py +++ b/metagpt/roles/researcher.py @@ -75,6 +75,7 @@ class Researcher(Role): Returns: str """ return get_research_system_text(topic, self.language) + async def react(self) -> Message: msg = await super().react() report = msg.instruct_content From 9d0f19aeee7a713530217e19eac414a9354d5355 Mon Sep 17 00:00:00 2001 From: paulaan Date: Sat, 9 Dec 2023 22:01:47 +0700 Subject: [PATCH 028/167] current task might swith different sys prompt --- metagpt/roles/researcher.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/metagpt/roles/researcher.py b/metagpt/roles/researcher.py index c60d54486..387999cff 100644 --- a/metagpt/roles/researcher.py +++ b/metagpt/roles/researcher.py @@ -4,7 +4,7 @@ import asyncio from pydantic import BaseModel -from metagpt.actions import CollectLinks, ConductResearch, WebBrowseAndSummarize +from metagpt.actions import Action, CollectLinks, ConductResearch, WebBrowseAndSummarize from metagpt.actions.research import get_research_system_text from metagpt.const import RESEARCH_PATH from metagpt.logs import logger @@ -46,7 +46,7 @@ class Researcher(Role): else: topic = msg.content - research_system_text = self.research_system_text(topic) + research_system_text = self.research_system_text(topic, todo) if isinstance(todo, CollectLinks): links = await todo.run(topic, 4, 4) ret = Message("", Report(topic=topic, links=links), role=self.profile, cause_by=type(todo)) @@ -64,7 +64,7 @@ class Researcher(Role): self._rc.memory.add(ret) return ret - def research_system_text(self, topic) -> str: + def research_system_text(self, topic, current_task: Action) -> str: """ BACKWARD compatible This allows sub-class able to define its own system prompt based on topic. return the previous implementation to have backward compatible From 00f8b47d3946c63d9e2da0045404509f1f440692 Mon Sep 17 00:00:00 2001 From: paulaan Date: Sun, 10 Dec 2023 00:42:38 +0700 Subject: [PATCH 029/167] move to chrome --- metagpt/tools/web_browser_engine_selenium.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/metagpt/tools/web_browser_engine_selenium.py b/metagpt/tools/web_browser_engine_selenium.py index 80b60a93c..074943892 100644 --- a/metagpt/tools/web_browser_engine_selenium.py +++ b/metagpt/tools/web_browser_engine_selenium.py @@ -104,11 +104,10 @@ def _gen_get_driver_func(browser_type, *args, executable_path=None): def _get_driver(): options = Options() options.add_argument("--headless") - options.add_argument("--no-sandbox") # This flag is important for running in a Docker container - options.add_argument("--disable-gpu") # This flag can help avoid renderer issue - options.add_argument("--disable-dev-shm-usage") # Overcome limited resource problems options.add_argument("--enable-javascript") if browser_type == "chrome": + options.add_argument("--disable-gpu") # This flag can help avoid renderer issue + options.add_argument("--disable-dev-shm-usage") # Overcome limited resource problems options.add_argument("--no-sandbox") for i in args: options.add_argument(i) From bef8d64193c4ce783432e6f958fd5c0858ea7e00 Mon Sep 17 00:00:00 2001 From: better629 Date: Thu, 14 Dec 2023 16:45:40 +0800 Subject: [PATCH 030/167] add google gemini --- config/config.yaml | 4 + metagpt/config.py | 8 +- metagpt/llm.py | 3 + metagpt/provider/google_gemini_api.py | 130 ++++++++++++++++++ metagpt/utils/token_counter.py | 7 +- requirements.txt | 1 + .../provider/test_google_gemini_api.py | 43 ++++++ 7 files changed, 192 insertions(+), 4 deletions(-) create mode 100644 metagpt/provider/google_gemini_api.py create mode 100644 tests/metagpt/provider/test_google_gemini_api.py diff --git a/config/config.yaml b/config/config.yaml index 080de4000..596a31341 100644 --- a/config/config.yaml +++ b/config/config.yaml @@ -34,6 +34,10 @@ RPM: 10 #### if zhipuai from `https://open.bigmodel.cn`. You can set here or export API_KEY="YOUR_API_KEY" # ZHIPUAI_API_KEY: "YOUR_API_KEY" +#### if Google Gemini from `https://ai.google.dev/` and API_KEY from `https://makersuite.google.com/app/apikey`. +#### You can set here or export GOOGLE_API_KEY="YOUR_API_KEY" +# GEMINI_API_KEY: "YOUR_API_KEY" + #### if use self-host open llm model with openai-compatible interface #OPEN_LLM_API_BASE: "http://127.0.0.1:8000/v1" #OPEN_LLM_API_MODEL: "llama2-13b" diff --git a/metagpt/config.py b/metagpt/config.py index 2ce75b013..3b46c8504 100644 --- a/metagpt/config.py +++ b/metagpt/config.py @@ -51,13 +51,17 @@ class Config(metaclass=Singleton): self.open_llm_api_model = self._get("OPEN_LLM_API_MODEL") self.fireworks_api_key = self._get("FIREWORKS_API_KEY") + + self.gemini_api_key = self._get("GEMINI_API_KEY") + if (not self.openai_api_key or "YOUR_API_KEY" == self.openai_api_key) and \ (not self.anthropic_api_key or "YOUR_API_KEY" == self.anthropic_api_key) and \ (not self.zhipuai_api_key or "YOUR_API_KEY" == self.zhipuai_api_key) and \ (not self.open_llm_api_base) and \ - (not self.fireworks_api_key or "YOUR_API_KEY" == self.fireworks_api_key): + (not self.fireworks_api_key or "YOUR_API_KEY" == self.fireworks_api_key) and \ + (not self.gemini_api_key or "YOUR_API_KEY" in self.gemini_api_key): raise NotConfiguredException("Set OPENAI_API_KEY or Anthropic_API_KEY or ZHIPUAI_API_KEY first " - "or FIREWORKS_API_KEY or OPEN_LLM_API_BASE") + "or FIREWORKS_API_KEY or OPEN_LLM_API_BASE or GEMINI_API_KEY") self.openai_api_base = self._get("OPENAI_API_BASE") openai_proxy = self._get("OPENAI_PROXY") or self.global_proxy if openai_proxy: diff --git a/metagpt/llm.py b/metagpt/llm.py index 7b490ec4a..b13fc723a 100644 --- a/metagpt/llm.py +++ b/metagpt/llm.py @@ -14,6 +14,7 @@ from metagpt.provider.spark_api import SparkAPI from metagpt.provider.open_llm_api import OpenLLMGPTAPI from metagpt.provider.fireworks_api import FireWorksGPTAPI from metagpt.provider.human_provider import HumanProvider +from metagpt.provider.google_gemini_api import GeminiGPTAPI def LLM() -> "BaseGPTAPI": @@ -29,6 +30,8 @@ def LLM() -> "BaseGPTAPI": llm = OpenLLMGPTAPI() elif CONFIG.fireworks_api_key: llm = FireWorksGPTAPI() + elif CONFIG.gemini_api_key: + llm = GeminiGPTAPI() else: raise RuntimeError("You should config a LLM configuration first") diff --git a/metagpt/provider/google_gemini_api.py b/metagpt/provider/google_gemini_api.py new file mode 100644 index 000000000..1c866ebad --- /dev/null +++ b/metagpt/provider/google_gemini_api.py @@ -0,0 +1,130 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : Google Gemini LLM from https://ai.google.dev/tutorials/python_quickstart + +from tenacity import ( + after_log, + retry, + retry_if_exception_type, + stop_after_attempt, + wait_fixed, +) +import google.generativeai as genai +from google.generativeai import client +from google.generativeai.types.generation_types import GenerateContentResponse, AsyncGenerateContentResponse +from google.generativeai.types.generation_types import GenerationConfig + +from metagpt.config import CONFIG +from metagpt.logs import logger +from metagpt.provider.base_gpt_api import BaseGPTAPI +from metagpt.provider.openai_api import log_and_reraise + + +class GeminiGPTAPI(BaseGPTAPI): + """ + Refs to `https://ai.google.dev/tutorials/python_quickstart` + """ + + use_system_prompt: bool = False # google gemini has no system prompt when use api + + def __init__(self): + self.__init_gemini(CONFIG) + self.model = "gemini-pro" # so far only one model + self.llm = genai.GenerativeModel(model_name=self.model) + + def __init_gemini(self, config: CONFIG): + genai.configure(api_key=config.gemini_api_key) + + def _user_msg(self, msg: str) -> dict[str, str]: + return {"role": "user", "parts": [msg]} + + def _assistant_msg(self, msg: str) -> dict[str, str]: + return {"role": "model", "parts": [msg]} + + def _const_kwargs(self, messages: list[dict], stream: bool = False) -> dict: + kwargs = { + "contents": messages, + "generation_config": GenerationConfig( + temperature=0.3 + ), + "stream": stream + } + return kwargs + + def _update_costs(self, usage: dict): + """ update each request's token cost """ + if CONFIG.calc_usage: + try: + prompt_tokens = int(usage.get("prompt_tokens", 0)) + completion_tokens = int(usage.get("completion_tokens", 0)) + self._cost_manager.update_cost(prompt_tokens, completion_tokens, self.model) + except Exception as e: + logger.error("google gemini updats costs failed!", e) + + def get_choice_text(self, resp: GenerateContentResponse) -> str: + return resp.text + + def get_usage(self, messages: list[dict], resp_text: str) -> dict: + prompt_resp = self.llm.count_tokens(contents=messages) + completion_resp = self.llm.count_tokens(contents={"parts": [resp_text]}) + usage = { + "prompt_tokens": prompt_resp.total_tokens, + "completion_tokens": completion_resp.total_tokens + } + return usage + + async def aget_usage(self, messages: list[dict], resp_text: str) -> dict: + # fix google-generativeai sdk + if self.llm._client is None: + self.llm._client = client.get_default_generative_client() + # TODO exception to fix + prompt_resp = await self.llm.count_tokens_async(contents=messages) + completion_resp = await self.llm.count_tokens_async(contents={"parts": [resp_text]}) + usage = { + "prompt_tokens": prompt_resp.total_tokens, + "completion_tokens": completion_resp.total_tokens + } + return usage + + def completion(self, messages: list[dict]) -> "GenerateContentResponse": + resp: GenerateContentResponse = self.llm.generate_content(**self._const_kwargs(messages)) + # usage = self.get_usage(messages, resp.text) + # self._update_costs(usage) + return resp + + async def _achat_completion(self, messages: list[dict]) -> "AsyncGenerateContentResponse": + resp: AsyncGenerateContentResponse = await self.llm.generate_content_async(**self._const_kwargs(messages)) + # usage = await self.aget_usage(messages, resp.text) + # self._update_costs(usage) + return resp + + async def acompletion(self, messages: list[dict]) -> dict: + return await self._achat_completion(messages) + + async def _achat_completion_stream(self, messages: list[dict]) -> str: + resp: AsyncGenerateContentResponse = await self.llm.generate_content_async(**self._const_kwargs(messages, + stream=True)) + collected_content = [] + async for chunk in resp: + content = chunk.text + print(content, end="") + collected_content.append(content) + + full_content = "".join(collected_content) + # usage = await self.aget_usage(messages, full_content) + # self._update_costs(usage) + return full_content + + @retry( + stop=stop_after_attempt(3), + wait=wait_fixed(1), + after=after_log(logger, logger.level("WARNING").name), + retry=retry_if_exception_type(ConnectionError), + retry_error_callback=log_and_reraise + ) + async def acompletion_text(self, messages: list[dict], stream=False) -> str: + """ response in async with stream or non-stream mode """ + if stream: + return await self._achat_completion_stream(messages) + resp = await self._achat_completion(messages) + return self.get_choice_text(resp) diff --git a/metagpt/utils/token_counter.py b/metagpt/utils/token_counter.py index ba63e90a9..6d9cbd137 100644 --- a/metagpt/utils/token_counter.py +++ b/metagpt/utils/token_counter.py @@ -7,6 +7,7 @@ ref1: https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb ref2: https://github.com/Significant-Gravitas/Auto-GPT/blob/master/autogpt/llm/token_counter.py ref3: https://github.com/hwchase17/langchain/blob/master/langchain/chat_models/openai.py +ref4: https://ai.google.dev/models/gemini """ import tiktoken @@ -24,7 +25,8 @@ TOKEN_COSTS = { "gpt-4-0613": {"prompt": 0.06, "completion": 0.12}, "gpt-4-1106-preview": {"prompt": 0.01, "completion": 0.03}, "text-embedding-ada-002": {"prompt": 0.0004, "completion": 0.0}, - "chatglm_turbo": {"prompt": 0.0, "completion": 0.00069} # 32k version, prompt + completion tokens=0.005¥/k-tokens + "chatglm_turbo": {"prompt": 0.0, "completion": 0.00069}, # 32k version, prompt + completion tokens=0.005¥/k-tokens + "gemini-pro": {"prompt": 0.00025, "completion": 0.0005} } @@ -42,7 +44,8 @@ TOKEN_MAX = { "gpt-4-0613": 8192, "gpt-4-1106-preview": 128000, "text-embedding-ada-002": 8192, - "chatglm_turbo": 32768 + "chatglm_turbo": 32768, + "gemini-pro": 32768 } diff --git a/requirements.txt b/requirements.txt index 14a9f485d..a2aaff48b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -45,3 +45,4 @@ semantic-kernel==0.3.13.dev0 wrapt==1.15.0 websocket-client==0.58.0 zhipuai==1.0.7 +google-generativeai==0.3.1 \ No newline at end of file diff --git a/tests/metagpt/provider/test_google_gemini_api.py b/tests/metagpt/provider/test_google_gemini_api.py new file mode 100644 index 000000000..32ed11ba5 --- /dev/null +++ b/tests/metagpt/provider/test_google_gemini_api.py @@ -0,0 +1,43 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : the unittest of google gemini api + +import pytest +from abc import ABC +from dataclasses import dataclass + +from metagpt.provider.google_gemini_api import GeminiGPTAPI + + +messages = [ + {"role": "user", "content": "who are you"} +] + + +@dataclass +class MockGeminiResponse(ABC): + text: str + + +default_resp = MockGeminiResponse(text="I'm gemini from google") + + +def mock_llm_ask(self, messages: list[dict]) -> MockGeminiResponse: + return default_resp + + +def test_gemini_completion(mocker): + mocker.patch("metagpt.provider.google_gemini_api.GeminiGPTAPI.completion", mock_llm_ask) + resp = GeminiGPTAPI().completion(messages) + assert resp.text == default_resp.text + + +async def mock_llm_aask(self, messgaes: list[dict]) -> MockGeminiResponse: + return default_resp + + +@pytest.mark.asyncio +async def test_gemini_acompletion(mocker): + mocker.patch("metagpt.provider.google_gemini_api.GeminiGPTAPI.acompletion", mock_llm_aask) + resp = await GeminiGPTAPI().acompletion(messages) + assert resp.text == default_resp.text From 9fb6e7c459a24489028ebe55a4ed2032d689eac1 Mon Sep 17 00:00:00 2001 From: better629 Date: Thu, 14 Dec 2023 16:54:56 +0800 Subject: [PATCH 031/167] update gemini user_msg doc --- metagpt/provider/google_gemini_api.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/metagpt/provider/google_gemini_api.py b/metagpt/provider/google_gemini_api.py index 1c866ebad..a69ffdc28 100644 --- a/metagpt/provider/google_gemini_api.py +++ b/metagpt/provider/google_gemini_api.py @@ -36,6 +36,8 @@ class GeminiGPTAPI(BaseGPTAPI): genai.configure(api_key=config.gemini_api_key) def _user_msg(self, msg: str) -> dict[str, str]: + # Not to change BaseGPTAPI default functions but update with Gemini's conversation format. + # You should follow the format. return {"role": "user", "parts": [msg]} def _assistant_msg(self, msg: str) -> dict[str, str]: From 4127ef85704a7771b484c8c73912e1919ef0be09 Mon Sep 17 00:00:00 2001 From: better629 Date: Fri, 15 Dec 2023 17:06:59 +0800 Subject: [PATCH 032/167] update gemini count_tokens --- metagpt/provider/google_gemini_api.py | 56 ++++++++++++++++++--------- metagpt/provider/zhipuai_api.py | 2 +- 2 files changed, 39 insertions(+), 19 deletions(-) diff --git a/metagpt/provider/google_gemini_api.py b/metagpt/provider/google_gemini_api.py index a69ffdc28..0ba1e86c1 100644 --- a/metagpt/provider/google_gemini_api.py +++ b/metagpt/provider/google_gemini_api.py @@ -10,14 +10,35 @@ from tenacity import ( wait_fixed, ) import google.generativeai as genai -from google.generativeai import client +from google.ai import generativelanguage as glm +from google.generativeai.types import content_types +from google.generativeai.generative_models import GenerativeModel from google.generativeai.types.generation_types import GenerateContentResponse, AsyncGenerateContentResponse from google.generativeai.types.generation_types import GenerationConfig from metagpt.config import CONFIG from metagpt.logs import logger from metagpt.provider.base_gpt_api import BaseGPTAPI -from metagpt.provider.openai_api import log_and_reraise +from metagpt.provider.openai_api import CostManager, log_and_reraise + + +class GeminiGenerativeModel(GenerativeModel): + """ + Due to `https://github.com/google/generative-ai-python/pull/123`, inherit a new class. + Will use default GenerativeModel if it fixed. + """ + + def count_tokens( + self, contents: content_types.ContentsType + ) -> glm.CountTokensResponse: + contents = content_types.to_contents(contents) + return self._client.count_tokens(model=self.model_name, contents=contents) + + async def count_tokens_async( + self, contents: content_types.ContentsType + ) -> glm.CountTokensResponse: + contents = content_types.to_contents(contents) + return await self._async_client.count_tokens(model=self.model_name, contents=contents) class GeminiGPTAPI(BaseGPTAPI): @@ -30,7 +51,8 @@ class GeminiGPTAPI(BaseGPTAPI): def __init__(self): self.__init_gemini(CONFIG) self.model = "gemini-pro" # so far only one model - self.llm = genai.GenerativeModel(model_name=self.model) + self.llm = GeminiGenerativeModel(model_name=self.model) + self._cost_manager = CostManager() def __init_gemini(self, config: CONFIG): genai.configure(api_key=config.gemini_api_key) @@ -61,14 +83,15 @@ class GeminiGPTAPI(BaseGPTAPI): completion_tokens = int(usage.get("completion_tokens", 0)) self._cost_manager.update_cost(prompt_tokens, completion_tokens, self.model) except Exception as e: - logger.error("google gemini updats costs failed!", e) + logger.error(f"google gemini updats costs failed! exp: {e}") def get_choice_text(self, resp: GenerateContentResponse) -> str: return resp.text def get_usage(self, messages: list[dict], resp_text: str) -> dict: - prompt_resp = self.llm.count_tokens(contents=messages) - completion_resp = self.llm.count_tokens(contents={"parts": [resp_text]}) + req_text = messages[-1]["parts"][0] if messages else "" + prompt_resp = self.llm.count_tokens(contents={"role": "user", "parts": [{"text": req_text}]}) + completion_resp = self.llm.count_tokens(contents={"role": "model", "parts": [{"text": resp_text}]}) usage = { "prompt_tokens": prompt_resp.total_tokens, "completion_tokens": completion_resp.total_tokens @@ -76,12 +99,9 @@ class GeminiGPTAPI(BaseGPTAPI): return usage async def aget_usage(self, messages: list[dict], resp_text: str) -> dict: - # fix google-generativeai sdk - if self.llm._client is None: - self.llm._client = client.get_default_generative_client() - # TODO exception to fix - prompt_resp = await self.llm.count_tokens_async(contents=messages) - completion_resp = await self.llm.count_tokens_async(contents={"parts": [resp_text]}) + req_text = messages[-1]["parts"][0] if messages else "" + prompt_resp = await self.llm.count_tokens_async(contents={"role": "user", "parts": [{"text": req_text}]}) + completion_resp = await self.llm.count_tokens_async(contents={"role": "model", "parts": [{"text": resp_text}]}) usage = { "prompt_tokens": prompt_resp.total_tokens, "completion_tokens": completion_resp.total_tokens @@ -90,14 +110,14 @@ class GeminiGPTAPI(BaseGPTAPI): def completion(self, messages: list[dict]) -> "GenerateContentResponse": resp: GenerateContentResponse = self.llm.generate_content(**self._const_kwargs(messages)) - # usage = self.get_usage(messages, resp.text) - # self._update_costs(usage) + usage = self.get_usage(messages, resp.text) + self._update_costs(usage) return resp async def _achat_completion(self, messages: list[dict]) -> "AsyncGenerateContentResponse": resp: AsyncGenerateContentResponse = await self.llm.generate_content_async(**self._const_kwargs(messages)) - # usage = await self.aget_usage(messages, resp.text) - # self._update_costs(usage) + usage = await self.aget_usage(messages, resp.text) + self._update_costs(usage) return resp async def acompletion(self, messages: list[dict]) -> dict: @@ -113,8 +133,8 @@ class GeminiGPTAPI(BaseGPTAPI): collected_content.append(content) full_content = "".join(collected_content) - # usage = await self.aget_usage(messages, full_content) - # self._update_costs(usage) + usage = await self.aget_usage(messages, full_content) + self._update_costs(usage) return full_content @retry( diff --git a/metagpt/provider/zhipuai_api.py b/metagpt/provider/zhipuai_api.py index 3161c0e88..3b24ca98f 100644 --- a/metagpt/provider/zhipuai_api.py +++ b/metagpt/provider/zhipuai_api.py @@ -65,7 +65,7 @@ class ZhiPuAIGPTAPI(BaseGPTAPI): completion_tokens = int(usage.get("completion_tokens", 0)) self._cost_manager.update_cost(prompt_tokens, completion_tokens, self.model) except Exception as e: - logger.error("zhipuai updats costs failed!", e) + logger.error(f"zhipuai updats costs failed! exp: {e}") def get_choice_text(self, resp: dict) -> str: """ get the first text of choice from llm response """ From 70cbfb1e480367bb9586a62b7b723c80c57aa4f0 Mon Sep 17 00:00:00 2001 From: better629 Date: Fri, 15 Dec 2023 17:30:25 +0800 Subject: [PATCH 033/167] retry use wait_random_exponential --- metagpt/provider/google_gemini_api.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/metagpt/provider/google_gemini_api.py b/metagpt/provider/google_gemini_api.py index 0ba1e86c1..b68e013a0 100644 --- a/metagpt/provider/google_gemini_api.py +++ b/metagpt/provider/google_gemini_api.py @@ -7,7 +7,7 @@ from tenacity import ( retry, retry_if_exception_type, stop_after_attempt, - wait_fixed, + wait_random_exponential, ) import google.generativeai as genai from google.ai import generativelanguage as glm @@ -139,7 +139,7 @@ class GeminiGPTAPI(BaseGPTAPI): @retry( stop=stop_after_attempt(3), - wait=wait_fixed(1), + wait=wait_random_exponential(min=1, max=60), after=after_log(logger, logger.level("WARNING").name), retry=retry_if_exception_type(ConnectionError), retry_error_callback=log_and_reraise From f9111e009ee132b0e30ba2070bfe0f6cac986f1c Mon Sep 17 00:00:00 2001 From: shenchucheng Date: Sun, 17 Dec 2023 15:01:54 +0800 Subject: [PATCH 034/167] update the docs link --- README.md | 22 +++++++++++----------- docs/README_CN.md | 20 ++++++++++---------- docs/ROADMAP.md | 2 +- 3 files changed, 22 insertions(+), 22 deletions(-) diff --git a/README.md b/README.md index b0faf85c7..7538824c5 100644 --- a/README.md +++ b/README.md @@ -62,7 +62,7 @@ # If executing, ensure that NPM is installed on your system. Then install mermai sudo npm install -g @mermaid-js/mermaid-cli ``` -detail installation please refer to [cli_install](https://docs.deepwisdom.ai/guide/get_started/installation.html#install-stable-version) +detail installation please refer to [cli_install](https://docs.deepwisdom.ai/main/en/guide/get_started/installation.html#install-stable-version) ### Docker installation > Note: In the Windows, you need to replace "/opt/metagpt" with a directory that Docker has permission to create, such as "D:\Users\x\metagpt" @@ -83,7 +83,7 @@ # Step 2: Run metagpt demo with container metagpt "Write a cli snake game" ``` -detail installation please refer to [docker_install](https://docs.deepwisdom.ai/guide/get_started/installation.html#install-with-docker) +detail installation please refer to [docker_install](https://docs.deepwisdom.ai/main/en/guide/get_started/installation.html#install-with-docker) ### QuickStart & Demo Video - Try it on [MetaGPT Huggingface Space](https://huggingface.co/spaces/deepwisdom/MetaGPT) @@ -94,19 +94,19 @@ ### QuickStart & Demo Video ## Tutorial -- 🗒 [Online Document](https://docs.deepwisdom.ai/) -- 💻 [Usage](https://docs.deepwisdom.ai/guide/get_started/quickstart.html) -- 🔎 [What can MetaGPT do?](https://docs.deepwisdom.ai/guide/get_started/introduction.html) +- 🗒 [Online Document](https://docs.deepwisdom.ai/main/en/) +- 💻 [Usage](https://docs.deepwisdom.ai/main/en/guide/get_started/quickstart.html) +- 🔎 [What can MetaGPT do?](https://docs.deepwisdom.ai/main/en/guide/get_started/introduction.html) - 🛠 How to build your own agents? - - [MetaGPT Usage & Development Guide | Agent 101](https://docs.deepwisdom.ai/guide/tutorials/agent_101.html) - - [MetaGPT Usage & Development Guide | MultiAgent 101](https://docs.deepwisdom.ai/guide/tutorials/multi_agent_101.html) + - [MetaGPT Usage & Development Guide | Agent 101](https://docs.deepwisdom.ai/main/en/guide/tutorials/agent_101.html) + - [MetaGPT Usage & Development Guide | MultiAgent 101](https://docs.deepwisdom.ai/main/en/guide/tutorials/multi_agent_101.html) - 🧑‍💻 Contribution - [Develop Roadmap](docs/ROADMAP.md) - 🔖 Use Cases - - [Debate](https://docs.deepwisdom.ai/guide/use_cases/multi_agent/debate.html) - - [Researcher](https://docs.deepwisdom.ai/guide/use_cases/agent/researcher.html) - - [Recepit Assistant](https://docs.deepwisdom.ai/guide/use_cases/agent/receipt_assistant.html) -- ❓ [FAQs](https://docs.deepwisdom.ai/guide/faq.html) + - [Debate](https://docs.deepwisdom.ai/main/en/guide/use_cases/multi_agent/debate.html) + - [Researcher](https://docs.deepwisdom.ai/main/en/guide/use_cases/agent/researcher.html) + - [Recepit Assistant](https://docs.deepwisdom.ai/main/en/guide/use_cases/agent/receipt_assistant.html) +- ❓ [FAQs](https://docs.deepwisdom.ai/main/en/guide/faq.html) ## Support diff --git a/docs/README_CN.md b/docs/README_CN.md index dd65c2a25..2855b5500 100644 --- a/docs/README_CN.md +++ b/docs/README_CN.md @@ -78,7 +78,7 @@ # 步骤2: 使用容器运行metagpt演示 metagpt "Write a cli snake game" ``` -详细的安装请安装 [docker_install](https://docs.deepwisdom.ai/zhcn/guide/get_started/installation.html#%E4%BD%BF%E7%94%A8docker%E5%AE%89%E8%A3%85) +详细的安装请安装 [docker_install](https://docs.deepwisdom.ai/main/zh/guide/get_started/installation.html#%E4%BD%BF%E7%94%A8docker%E5%AE%89%E8%A3%85) ### 快速开始的演示视频 - 在 [MetaGPT Huggingface Space](https://huggingface.co/spaces/deepwisdom/MetaGPT) 上进行体验 @@ -88,19 +88,19 @@ ### 快速开始的演示视频 https://github.com/geekan/MetaGPT/assets/34952977/34345016-5d13-489d-b9f9-b82ace413419 ## 教程 -- 🗒 [在线文档](https://docs.deepwisdom.ai/zhcn/) -- 💻 [如何使用](https://docs.deepwisdom.ai/zhcn/guide/get_started/quickstart.html) -- 🔎 [MetaGPT的能力及应用场景](https://docs.deepwisdom.ai/zhcn/guide/get_started/introduction.html) +- 🗒 [在线文档](https://docs.deepwisdom.ai/main/zh/) +- 💻 [如何使用](https://docs.deepwisdom.ai/main/zh/guide/get_started/quickstart.html) +- 🔎 [MetaGPT的能力及应用场景](https://docs.deepwisdom.ai/main/zh/guide/get_started/introduction.html) - 🛠 如何构建你自己的智能体? - - [MetaGPT的使用和开发教程 | 智能体入门](https://docs.deepwisdom.ai/zhcn/guide/tutorials/agent_101.html) - - [MetaGPT的使用和开发教程 | 多智能体入门](https://docs.deepwisdom.ai/zhcn/guide/tutorials/multi_agent_101.html) + - [MetaGPT的使用和开发教程 | 智能体入门](https://docs.deepwisdom.ai/main/zh/guide/tutorials/agent_101.html) + - [MetaGPT的使用和开发教程 | 多智能体入门](https://docs.deepwisdom.ai/main/zh/guide/tutorials/multi_agent_101.html) - 🧑‍💻 贡献 - [开发路线图](ROADMAP.md) - 🔖 示例 - - [辩论](https://docs.deepwisdom.ai/zhcn/guide/use_cases/multi_agent/debate.html) - - [调研员](https://docs.deepwisdom.ai/zhcn/guide/use_cases/agent/researcher.html) - - [票据助手](https://docs.deepwisdom.ai/zhcn/guide/use_cases/agent/receipt_assistant.html) -- ❓ [常见问题解答](https://docs.deepwisdom.ai/zhcn/guide/faq.html) + - [辩论](https://docs.deepwisdom.ai/main/zh/guide/use_cases/multi_agent/debate.html) + - [调研员](https://docs.deepwisdom.ai/main/zh/guide/use_cases/agent/researcher.html) + - [票据助手](https://docs.deepwisdom.ai/main/zh/guide/use_cases/agent/receipt_assistant.html) +- ❓ [常见问题解答](https://docs.deepwisdom.ai/main/zh/guide/faq.html) ## 支持 diff --git a/docs/ROADMAP.md b/docs/ROADMAP.md index afc9ff445..25eb4e3a1 100644 --- a/docs/ROADMAP.md +++ b/docs/ROADMAP.md @@ -21,7 +21,7 @@ ### Tasks 3. ~~Support human confirmation and modification during the process~~ (v0.3.0) New: Support human confirmation and modification with fewer constrainsts and a more user-friendly interface 4. Support process caching: Consider carefully whether to add server caching mechanism 5. ~~Resolve occasional failure to follow instruction under current prompts, causing code parsing errors, through stricter system prompts~~ (v0.4.0, with function call) - 6. Write documentation, describing the current features and usage at all levels (ongoing, continuously adding contents to [documentation site](https://docs.deepwisdom.ai/guide/get_started/introduction.html)) + 6. Write documentation, describing the current features and usage at all levels (ongoing, continuously adding contents to [documentation site](https://docs.deepwisdom.ai/main/en/guide/get_started/introduction.html)) 7. ~~Support Docker~~ 2. Features 1. Support a more standard and stable parser (need to analyze the format that the current LLM is better at) From 949bc747f92c368f47bd73966e0eba205d4f7a40 Mon Sep 17 00:00:00 2001 From: better629 Date: Tue, 28 Nov 2023 09:29:00 +0800 Subject: [PATCH 035/167] add mg ser&deser --- metagpt/actions/action.py | 31 +++++++ metagpt/const.py | 2 + metagpt/environment.py | 38 +++++++++ metagpt/memory/memory.py | 30 +++++++ metagpt/roles/role.py | 117 ++++++++++++++++++++++++++- metagpt/schema.py | 44 +++++++++- metagpt/team.py | 26 ++++++ metagpt/utils/serialize.py | 62 ++++++++++++-- metagpt/utils/utils.py | 38 ++++++++- startup.py | 81 +++++++++++++++++++ tests/metagpt/actions/test_action.py | 17 ++++ tests/metagpt/memory/test_memory.py | 42 ++++++++++ tests/metagpt/roles/test_role.py | 85 +++++++++++++++++++ tests/metagpt/test_environment.py | 27 +++++-- tests/metagpt/test_schema.py | 42 ++++++++++ tests/metagpt/test_team.py | 27 +++++++ 16 files changed, 693 insertions(+), 16 deletions(-) create mode 100644 startup.py create mode 100644 tests/metagpt/memory/test_memory.py create mode 100644 tests/metagpt/roles/test_role.py create mode 100644 tests/metagpt/test_team.py diff --git a/metagpt/actions/action.py b/metagpt/actions/action.py index 1534b1f4d..3bfb69de4 100644 --- a/metagpt/actions/action.py +++ b/metagpt/actions/action.py @@ -17,6 +17,7 @@ from metagpt.logs import logger from metagpt.provider.postprecess.llm_output_postprecess import llm_output_postprecess from metagpt.utils.common import OutputParser from metagpt.utils.utils import general_after_log +from metagpt.utils.utils import import_class class Action(ABC): @@ -51,6 +52,36 @@ class Action(ABC): def __repr__(self): return self.__str__() + def serialize(self): + return { + "action_class": self.__class__.__name__, + "module_name": self.__module__, + "name": self.name + } + + @classmethod + def deserialize(cls, action_dict: dict): + action_class_str = action_dict.pop("action_class") + module_name = action_dict.pop("module_name") + action_class = import_class(action_class_str, module_name) + return action_class(**action_dict) + + @classmethod + def ser_class(cls): + """ serialize class type""" + return { + "action_class": cls.__name__, + "module_name": cls.__module__ + } + + @classmethod + def deser_class(cls, action_dict: dict): + """ deserialize class type """ + action_class_str = action_dict.pop("action_class") + module_name = action_dict.pop("module_name") + action_class = import_class(action_class_str, module_name) + return action_class + async def _aask(self, prompt: str, system_msgs: Optional[list[str]] = None) -> str: """Append default prefix""" if not system_msgs: diff --git a/metagpt/const.py b/metagpt/const.py index 10de0ff66..b46bc15a4 100644 --- a/metagpt/const.py +++ b/metagpt/const.py @@ -60,6 +60,8 @@ SWAGGER_PATH = UT_PATH / "files/api/" UT_PY_PATH = UT_PATH / "files/ut/" API_QUESTIONS_PATH = UT_PATH / "files/question/" +SERDES_PATH = DEFAULT_WORKSPACE_ROOT / "storage" # TODO to store `storage` under the individual generated project + TMP = METAGPT_ROOT / "tmp" SOURCE_ROOT = METAGPT_ROOT / "metagpt" diff --git a/metagpt/environment.py b/metagpt/environment.py index 89b6f9d46..14da6cd95 100644 --- a/metagpt/environment.py +++ b/metagpt/environment.py @@ -13,6 +13,7 @@ """ import asyncio from typing import Iterable, Set +from pathlib import Path from pydantic import BaseModel, Field @@ -20,6 +21,7 @@ from metagpt.logs import logger from metagpt.roles import Role from metagpt.schema import Message from metagpt.utils.common import is_subscribed +from metagpt.utils.utils import read_json_file, write_json_file class Environment(BaseModel): @@ -35,6 +37,42 @@ class Environment(BaseModel): class Config: arbitrary_types_allowed = True + def serialize(self, stg_path: Path): + roles_path = stg_path.joinpath("roles.json") + roles_info = [] + for role_key, role in self.roles.items(): + roles_info.append({ + "role_class": role.__class__.__name__, + "module_name": role.__module__, + "role_name": role.name + }) + role.serialize(stg_path=stg_path.joinpath(f"roles/{role.__class__.__name__}_{role.name}")) + write_json_file(roles_path, roles_info) + + self.memory.serialize(stg_path) + history_path = stg_path.joinpath("history.json") + write_json_file(history_path, {"content": self.history}) + + def deserialize(self, stg_path: Path): + """ stg_path: ./storage/team/environment/ """ + roles_path = stg_path.joinpath("roles.json") + roles_info = read_json_file(roles_path) + for role_info in roles_info: + role_class = role_info.get("role_class") + role_name = role_info.get("role_name") + + role_path = stg_path.joinpath(f"roles/{role_class}_{role_name}") + role = Role.deserialize(role_path) + + self.add_role(role) + + memory = Memory.deserialize(stg_path) + self.memory = memory + + history_path = stg_path.joinpath("history.json") + history = read_json_file(history_path) + self.history = history.get("content") + def add_role(self, role: Role): """增加一个在当前环境的角色 Add a role in the current environment diff --git a/metagpt/memory/memory.py b/metagpt/memory/memory.py index 53b65fcf7..43bd33e59 100644 --- a/metagpt/memory/memory.py +++ b/metagpt/memory/memory.py @@ -8,9 +8,12 @@ """ from collections import defaultdict from typing import Iterable, Set +from pathlib import Path from metagpt.schema import Message from metagpt.utils.common import any_to_str, any_to_str_set +from metagpt.utils.utils import read_json_file, write_json_file +from metagpt.utils.serialize import serialize_general_message, deserialize_general_message class Memory: @@ -21,6 +24,33 @@ class Memory: self.storage: list[Message] = [] self.index: dict[str, list[Message]] = defaultdict(list) + def serialize(self, stg_path: Path): + """ stg_path = ./storage/team/environment/ or ./storage/team/environment/roles/{role_class}_{role_name}/ """ + memory_path = stg_path.joinpath("memory.json") + + storage = [] + for message in self.storage: + # msg_dict = message.serialize() + msg_dict = serialize_general_message(message) + storage.append(msg_dict) + + write_json_file(memory_path, storage) + + @classmethod + def deserialize(cls, stg_path: Path) -> "Memory": + """ stg_path = ./storage/team/environment/ or ./storage/team/environment/roles/{role_class}_{role_name}/""" + memory_path = stg_path.joinpath("memory.json") + + memory = Memory() + memory_list = read_json_file(memory_path) + for message in memory_list: + # distinguish instruct_content type in message + # msg = Message.deserialize(message) + msg = deserialize_general_message(message) + memory.add(msg) + + return memory + def add(self, message: Message): """Add a new message to storage, while updating the index""" if message in self.storage: diff --git a/metagpt/roles/role.py b/metagpt/roles/role.py index 1e7ebf711..bb3b2acfe 100644 --- a/metagpt/roles/role.py +++ b/metagpt/roles/role.py @@ -22,7 +22,7 @@ from __future__ import annotations from enum import Enum from typing import Iterable, Set, Type - +from pathlib import Path from pydantic import BaseModel, Field from metagpt.actions import Action, ActionOutput @@ -30,10 +30,12 @@ from metagpt.actions.action_node import ActionNode from metagpt.actions.add_requirement import UserRequirement from metagpt.llm import LLM, HumanProvider from metagpt.logs import logger -from metagpt.memory import Memory from metagpt.schema import Message, MessageQueue from metagpt.utils.common import any_to_str from metagpt.utils.repair_llm_raw_output import extract_state_value_from_output +from metagpt.memory import Memory +from metagpt.utils.utils import read_json_file, write_json_file, import_class + PREFIX_TEMPLATE = """You are a {profile}, named {name}, your goal is {goal}, and the constraint is {constraints}. """ @@ -152,6 +154,87 @@ class Role(metaclass=_RoleInjector): self._rc = RoleContext() self._subscription = {any_to_str(self), name} if name else {any_to_str(self)} + self._recovered = False + + def serialize(self, stg_path: Path): + role_info_path = stg_path.joinpath("role_info.json") + role_info = { + "role_class": self.__class__.__name__, + "module_name": self.__module__ + } + setting = self._setting.dict() + setting.pop("desc") + setting.pop("is_human") # not all inherited roles have this atrr + role_info.update(setting) + write_json_file(role_info_path, role_info) + + actions_info_path = stg_path.joinpath("actions/actions_info.json") + actions_info = [] + for action in self._actions: + actions_info.append(action.serialize()) + write_json_file(actions_info_path, actions_info) + + watches_info_path = stg_path.joinpath("watches/watches_info.json") + watches_info = [] + for watch in self._rc.watch: + watches_info.append(watch.ser_class()) + write_json_file(watches_info_path, watches_info) + + actions_todo_path = stg_path.joinpath("actions/todo.json") + actions_todo = { + "cur_state": self._rc.state, + "react_mode": self._rc.react_mode.value, + "max_react_loop": self._rc.max_react_loop + } + write_json_file(actions_todo_path, actions_todo) + + self._rc.memory.serialize(stg_path) + + @classmethod + def deserialize(cls, stg_path: Path) -> "Role": + """ stg_path = ./storage/team/environment/roles/{role_class}_{role_name}""" + role_info_path = stg_path.joinpath("role_info.json") + role_info = read_json_file(role_info_path) + + role_class_str = role_info.pop("role_class") + module_name = role_info.pop("module_name") + role_class = import_class(class_name=role_class_str, module_name=module_name) + + role = role_class(**role_info) # initiate particular Role + actions_info_path = stg_path.joinpath("actions/actions_info.json") + actions = [] + actions_info = read_json_file(actions_info_path) + for action_info in actions_info: + action = Action.deserialize(action_info) + actions.append(action) + + watches_info_path = stg_path.joinpath("watches/watches_info.json") + watches = [] + watches_info = read_json_file(watches_info_path) + for watch_info in watches_info: + action = Action.deser_class(watch_info) + watches.append(action) + + role.init_actions(actions) + role.watch(watches) + + actions_todo_path = stg_path.joinpath("actions/todo.json") + # recover self._rc.state + actions_todo = read_json_file(actions_todo_path) + max_react_loop = actions_todo.get("max_react_loop", 1) + cur_state = actions_todo.get("cur_state", -1) + role.set_state(cur_state) + role.set_recovered(True) + react_mode_str = actions_todo.get("react_mode", RoleReactMode.REACT.value) + if react_mode_str not in RoleReactMode.values(): + logger.warning(f"ReactMode: {react_mode_str} not in {RoleReactMode.values()}, use react as default") + react_mode_str = RoleReactMode.REACT.value + role.set_react_mode(RoleReactMode(react_mode_str), max_react_loop) + + role_memory = Memory.deserialize(stg_path) + role.set_memory(role_memory) + + return role def _reset(self): self._states = [] @@ -160,6 +243,15 @@ class Role(metaclass=_RoleInjector): def _init_action_system_message(self, action: Action): action.set_prefix(self._get_prefix(), self.profile) + def set_recovered(self, recovered: bool = False): + self._recovered = recovered + + def set_memory(self, memory: Memory): + self._rc.memory = memory + + def init_actions(self, actions): + self._init_actions(actions) + def _init_actions(self, actions): self._reset() for idx, action in enumerate(actions): @@ -178,6 +270,9 @@ class Role(metaclass=_RoleInjector): self._actions.append(i) self._states.append(f"{idx}. {action}") + def set_react_mode(self, react_mode: RoleReactMode, max_react_loop: int = 1): + self._set_react_mode(react_mode, max_react_loop) + def _set_react_mode(self, react_mode: str, max_react_loop: int = 1): """Set strategy of the Role reacting to observed Message. Variation lies in how this Role elects action to perform during the _think stage, especially if it is capable of multiple Actions. @@ -199,6 +294,9 @@ class Role(metaclass=_RoleInjector): if react_mode == RoleReactMode.REACT: self._rc.max_react_loop = max_react_loop + def watch(self, actions: Iterable[Type[Action]]): + self._watch(actions) + def _watch(self, actions: Iterable[Type[Action]]): """Watch Actions of interest. Role will select Messages caused by these Actions from its personal message buffer during _observe. @@ -217,6 +315,9 @@ class Role(metaclass=_RoleInjector): if self._rc.env: # According to the routing feature plan in Chapter 2.2.3.2 of RFC 113 self._rc.env.set_subscription(self, self._subscription) + def set_state(self, state: int): + self._set_state(state) + def _set_state(self, state: int): """Update the current state.""" self._rc.state = state @@ -230,6 +331,10 @@ class Role(metaclass=_RoleInjector): if env: env.set_subscription(self, self._subscription) + @property + def name(self): + return self._setting.name + @property def profile(self): """Get the role description (position)""" @@ -257,6 +362,11 @@ class Role(metaclass=_RoleInjector): # If there is only one action, then only this one can be performed self._set_state(0) return + if self._recovered and self._rc.state >= 0: + self._set_state(self._rc.state) # action to run from recovered state + self._recovered = False # avoid max_react_loop out of work + return + prompt = self._get_prefix() prompt += STATE_TEMPLATE.format( history=self._rc.history, @@ -349,7 +459,8 @@ class Role(metaclass=_RoleInjector): async def _act_by_order(self) -> Message: """switch action each time by order defined in _init_actions, i.e. _act (Action1) -> _act (Action2) -> ...""" - for i in range(len(self._states)): + start_idx = self._rc.state if self._rc.state >= 0 else 0 # action to run from recovered state + for i in range(start_idx, len(self._states)): self._set_state(i) rsp = await self._act() return rsp # return output from the last action diff --git a/metagpt/schema.py b/metagpt/schema.py index 5aec378e4..78e4a6031 100644 --- a/metagpt/schema.py +++ b/metagpt/schema.py @@ -22,7 +22,6 @@ from asyncio import Queue, QueueEmpty, wait_for from json import JSONDecodeError from pathlib import Path from typing import Dict, List, Optional, Set, TypedDict - from pydantic import BaseModel, Field from metagpt.config import CONFIG @@ -36,6 +35,9 @@ from metagpt.const import ( ) from metagpt.logs import logger from metagpt.utils.common import any_to_str, any_to_str_set +# from metagpt.utils.serialize import actionoutout_schema_to_mapping +# from metagpt.actions.action_output import ActionOutput +# from metagpt.actions.action import Action class RawMessage(TypedDict): @@ -155,6 +157,46 @@ class Message(BaseModel): def __repr__(self): return self.__str__() + # def serialize(self): + # message_cp: Message = copy.deepcopy(self) + # ic = message_cp.instruct_content + # if ic: + # # model create by pydantic create_model like `pydantic.main.prd`, can't pickle.dump directly + # schema = ic.schema() + # mapping = actionoutout_schema_to_mapping(schema) + # + # message_cp.instruct_content = {"class": schema["title"], "mapping": mapping, "value": ic.dict()} + # cb = message_cp.cause_by + # if cb: + # message_cp.cause_by = cb.serialize() + # + # return message_cp.dict() + # + # @classmethod + # def deserialize(cls, message_dict: dict): + # instruct_content = message_dict.get("instruct_content") + # if instruct_content: + # ic = instruct_content + # ic_obj = ActionOutput.create_model_class(class_name=ic["class"], mapping=ic["mapping"]) + # ic_new = ic_obj(**ic["value"]) + # message_dict.instruct_content = ic_new + # cause_by = message_dict.get("cause_by") + # if cause_by: + # message_dict.cause_by = Action.deserialize(cause_by) + # + # return Message(**message_dict) + + def dict(self): + return { + "content": self.content, + "instruct_content": self.instruct_content, + "role": self.role, + "cause_by": self.cause_by, + "sent_from": self.sent_from, + "send_to": self.send_to, + "restricted_to": self.restricted_to + } + def to_dict(self) -> dict: """Return a dict containing `role` and `content` for the LLM call.l""" return {"role": self.role, "content": self.content} diff --git a/metagpt/team.py b/metagpt/team.py index a5c405f80..02c48a138 100644 --- a/metagpt/team.py +++ b/metagpt/team.py @@ -7,6 +7,7 @@ @Modified By: mashenquan, 2023/11/27. Add an archiving operation after completing the project, as specified in Section 2.2.3.3 of RFC 135. """ +from pathlib import Path from pydantic import BaseModel, Field from metagpt.actions import UserRequirement @@ -17,6 +18,7 @@ from metagpt.logs import logger from metagpt.roles import Role from metagpt.schema import Message from metagpt.utils.common import NoMoneyException +from metagpt.utils.utils import read_json_file, write_json_file class Team(BaseModel): @@ -32,6 +34,30 @@ class Team(BaseModel): class Config: arbitrary_types_allowed = True + def serialize(self, stg_path: Path): + team_info_path = stg_path.joinpath("team_info.json") + write_json_file(team_info_path, { + "idea": self.idea, + "investment": self.investment + }) + + self.environment.serialize(stg_path.joinpath("environment")) + + def deserialize(self, stg_path: Path): + """ stg_path = ./storage/team """ + # recover team_info + team_info_path = stg_path.joinpath("team_info.json") + if not team_info_path.exists(): + logger.error("recover storage not exist, not to recover and continue run the old project.") + team_info = read_json_file(team_info_path) + self.investment = team_info.get("investment", 10.0) + self.idea = team_info.get("idea", "") + + # recover environment + environment_path = stg_path.joinpath("environment") + self.environment = Environment() + self.environment.deserialize(stg_path=environment_path) + def hire(self, roles: list[Role]): """Hire roles to cooperate""" self.env.add_roles(roles) diff --git a/metagpt/utils/serialize.py b/metagpt/utils/serialize.py index 124176fcb..56a866f2e 100644 --- a/metagpt/utils/serialize.py +++ b/metagpt/utils/serialize.py @@ -4,13 +4,13 @@ import copy import pickle -from typing import Dict, List from metagpt.actions.action_output import ActionOutput from metagpt.schema import Message +from metagpt.actions.action import Action -def actionoutout_schema_to_mapping(schema: Dict) -> Dict: +def actionoutout_schema_to_mapping(schema: dict) -> dict: """ directly traverse the `properties` in the first level. schema structure likes @@ -35,13 +35,47 @@ def actionoutout_schema_to_mapping(schema: Dict) -> Dict: if property["type"] == "string": mapping[field] = (str, ...) elif property["type"] == "array" and property["items"]["type"] == "string": - mapping[field] = (List[str], ...) + mapping[field] = (list[str], ...) elif property["type"] == "array" and property["items"]["type"] == "array": - # here only consider the `List[List[str]]` situation - mapping[field] = (List[List[str]], ...) + # here only consider the `list[list[str]]` situation + mapping[field] = (list[list[str]], ...) return mapping +def actionoutput_mapping_to_str(mapping: dict) -> dict: + new_mapping = {} + for key, value in mapping.items(): + new_mapping[key] = str(value) + return new_mapping + + +def actionoutput_str_to_mapping(mapping: dict) -> dict: + new_mapping = {} + for key, value in mapping.items(): + if value == "(, Ellipsis)": + new_mapping[key] = (str, ...) + else: + new_mapping[key] = eval(value) # `"'(list[str], Ellipsis)"` to `(list[str], ...)` + return new_mapping + + +def serialize_general_message(message: Message) -> dict: + """ serialize Message, not to save""" + message_cp = copy.deepcopy(message) + ic = message_cp.instruct_content + if ic: + # model create by pydantic create_model like `pydantic.main.prd`, can't pickle.dump directly + schema = ic.schema() + mapping = actionoutout_schema_to_mapping(schema) + mapping = actionoutput_mapping_to_str(mapping) + + message_cp.instruct_content = {"class": schema["title"], "mapping": mapping, "value": ic.dict()} + cb = message_cp.cause_by + if cb: + message_cp.cause_by = cb.ser_class() + return message_cp.dict() + + def serialize_message(message: Message): message_cp = copy.deepcopy(message) # avoid `instruct_content` value update by reference ic = message_cp.instruct_content @@ -56,6 +90,24 @@ def serialize_message(message: Message): return msg_ser +def deserialize_general_message(message_dict: dict) -> Message: + """ deserialize Message, not to load""" + instruct_content = message_dict.pop("instruct_content") + cause_by = message_dict.pop("cause_by") + + message = Message(**message_dict) + if instruct_content: + ic = instruct_content + mapping = actionoutput_str_to_mapping(ic["mapping"]) + ic_obj = ActionOutput.create_model_class(class_name=ic["class"], mapping=mapping) + ic_new = ic_obj(**ic["value"]) + message.instruct_content = ic_new + if cause_by: + message.cause_by = Action.deser_class(cause_by) + + return message + + def deserialize_message(message_ser: str) -> Message: message = pickle.loads(message_ser) if message.instruct_content: diff --git a/metagpt/utils/utils.py b/metagpt/utils/utils.py index 5ceed65d9..220e228c3 100644 --- a/metagpt/utils/utils.py +++ b/metagpt/utils/utils.py @@ -3,7 +3,10 @@ # @Desc : import typing - +from typing import Any +import json +from pathlib import Path +import importlib from tenacity import _utils @@ -20,3 +23,36 @@ def general_after_log(logger: "loguru.Logger", sec_format: str = "%0.3f") -> typ ) return log_it + + +def read_json_file(json_file: str, encoding=None) -> list[Any]: + if not Path(json_file).exists(): + raise FileNotFoundError(f"json_file: {json_file} not exist, return []") + + with open(json_file, "r", encoding=encoding) as fin: + try: + data = json.load(fin) + except Exception as exp: + raise ValueError(f"read json file: {json_file} failed") + return data + + +def write_json_file(json_file: str, data: list, encoding=None): + folder_path = Path(json_file).parent + if not folder_path.exists(): + folder_path.mkdir(parents=True, exist_ok=True) + + with open(json_file, "w", encoding=encoding) as fout: + json.dump(data, fout, ensure_ascii=False, indent=4) + + +def import_class(class_name: str, module_name: str) -> type: + module = importlib.import_module(module_name) + a_class = getattr(module, class_name) + return a_class + + +def import_class_inst(class_name: str, module_name: str, *args, **kwargs) -> object: + a_class = import_class(class_name, module_name) + class_inst = a_class(*args, **kwargs) + return class_inst diff --git a/startup.py b/startup.py new file mode 100644 index 000000000..9f753d553 --- /dev/null +++ b/startup.py @@ -0,0 +1,81 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +import asyncio + +import fire + +from metagpt.const import SERDES_PATH +from metagpt.roles import ( + Architect, + Engineer, + ProductManager, + ProjectManager, + QaEngineer, +) +from metagpt.team import Team + + +async def startup( + idea: str, + investment: float = 3.0, + n_round: int = 5, + code_review: bool = False, + run_tests: bool = False, + implement: bool = True, + recover_path: bool = False, +): + """Run a startup. Be a boss.""" + company = Team() + if not recover_path: + company.hire( + [ + ProductManager(), + Architect(), + ProjectManager(), + ] + ) + + # if implement or code_review + if implement or code_review: + # developing features: implement the idea + company.hire([Engineer(n_borg=5, use_code_review=code_review)]) + + if run_tests: + # developing features: run tests on the spot and identify bugs + # (bug fixing capability comes soon!) + company.hire([QaEngineer()]) + else: + stg_path = SERDES_PATH.joinpath("team") + company.deserialize(stg_path=stg_path) + idea = company.idea # use original idea + + company.invest(investment) + company.start_project(idea) + await company.run(n_round=n_round) + + +def main( + idea: str, + investment: float = 3.0, + n_round: int = 5, + code_review: bool = True, + run_tests: bool = False, + implement: bool = True, + recover_path: str = None, +): + """ + We are a software startup comprised of AI. By investing in us, + you are empowering a future filled with limitless possibilities. + :param idea: Your innovative idea, such as "Creating a snake game." + :param investment: As an investor, you have the opportunity to contribute + a certain dollar amount to this AI company. + :param n_round: + :param code_review: Whether to use code review. + :param recover_path: recover the project from existing serialized storage + :return: + """ + asyncio.run(startup(idea, investment, n_round, code_review, run_tests, implement, recover_path)) + + +if __name__ == "__main__": + fire.Fire(main) diff --git a/tests/metagpt/actions/test_action.py b/tests/metagpt/actions/test_action.py index 9775630cc..4468a6f6f 100644 --- a/tests/metagpt/actions/test_action.py +++ b/tests/metagpt/actions/test_action.py @@ -11,3 +11,20 @@ from metagpt.actions import Action, WritePRD, WriteTest def test_action_repr(): actions = [Action(), WriteTest(), WritePRD()] assert "WriteTest" in str(actions) + + +def test_action_serdes(): + action_info = WriteTest.ser_class() + assert action_info["action_class"] == "WriteTest" + + action_class = Action.deser_class(action_info) + assert action_class == WriteTest + + +def test_action_class_serdes(): + name = "write test" + action_info = WriteTest(name=name).serialize() + assert action_info["name"] == name + + action = Action.deserialize(action_info) + assert action.name == name diff --git a/tests/metagpt/memory/test_memory.py b/tests/metagpt/memory/test_memory.py new file mode 100644 index 000000000..bda79ded1 --- /dev/null +++ b/tests/metagpt/memory/test_memory.py @@ -0,0 +1,42 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : unittest of memory + +from pathlib import Path + +from metagpt.schema import Message +from metagpt.memory.memory import Memory +from metagpt.actions.action_output import ActionOutput +from metagpt.actions.design_api import WriteDesign +from metagpt.actions.add_requirement import BossRequirement + +serdes_path = Path(__file__).absolute().parent.joinpath("../../data/serdes_storage") + + +def test_memory_serdes(): + msg1 = Message(role="User", + content="write a 2048 game", + cause_by=BossRequirement) + + out_mapping = {"field1": (list[str], ...)} + out_data = {"field1": ["field1 value1", "field1 value2"]} + ic_obj = ActionOutput.create_model_class("system_design", out_mapping) + msg2 = Message(role="Architect", + instruct_content=ic_obj(**out_data), + content="system design content", + cause_by=WriteDesign) + + memory = Memory() + memory.add_batch([msg1, msg2]) + + stg_path = serdes_path.joinpath("team/environment") + memory.serialize(stg_path) + assert stg_path.joinpath("memory.json").exists() + + new_memory = Memory.deserialize(stg_path) + assert new_memory.count() == 2 + new_msg2 = new_memory.get(1)[0] + assert new_msg2.instruct_content.field1 == ["field1 value1", "field1 value2"] + assert new_msg2.cause_by == WriteDesign + + stg_path.joinpath("memory.json").unlink() diff --git a/tests/metagpt/roles/test_role.py b/tests/metagpt/roles/test_role.py new file mode 100644 index 000000000..a19ad9cb5 --- /dev/null +++ b/tests/metagpt/roles/test_role.py @@ -0,0 +1,85 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : unittest of Role + +from pathlib import Path +import shutil +import pytest + +from metagpt.roles.role import Role, RoleReactMode +from metagpt.actions.action import Action +from metagpt.schema import Message +from metagpt.actions.add_requirement import BossRequirement +from metagpt.roles.product_manager import ProductManager + +serdes_path = Path(__file__).absolute().parent.joinpath("../../data/serdes_storage") + + +def test_role_serdes(): + stg_path_prefix = serdes_path.joinpath("team/environment/roles/") + shutil.rmtree(serdes_path.joinpath("team"), ignore_errors=True) + + pm = ProductManager() + role_tag = f"{pm.__class__.__name__}_{pm.name}" + stg_path = stg_path_prefix.joinpath(role_tag) + pm.serialize(stg_path) + assert stg_path.joinpath("actions/actions_info.json").exists() + + new_pm = Role.deserialize(stg_path) + assert new_pm.name == pm.name + assert len(new_pm.get_memories(1)) == 0 + + +class ActionOK(Action): + + async def run(self, messages: list["Message"]): + return "ok" + + +class ActionRaise(Action): + + async def run(self, messages: list["Message"]): + raise RuntimeError("parse error") + + +class RoleA(Role): + + def __init__(self, + name: str = "RoleA", + profile: str = "Role A", + goal: str = "", + constraints: str = ""): + super(RoleA, self).__init__(name=name, profile=profile, goal=goal, constraints=constraints) + self._init_actions([ActionOK, ActionRaise]) + self._watch([BossRequirement]) + self._rc.react_mode = RoleReactMode.BY_ORDER + + async def run(self, message: "Message" = None, stg_path: str = None): + try: + await super(RoleA, self).run(message) + except Exception as exp: + print("exp ", exp) + self.serialize(stg_path) + + +@pytest.mark.asyncio +async def test_role_serdes_interrupt(): + role_a = RoleA() + shutil.rmtree(serdes_path.joinpath("team"), ignore_errors=True) + + stg_path = serdes_path.joinpath(f"team/environment/roles/{role_a.__class__.__name__}_{role_a.name}") + await role_a.run( + message=Message(content="demo", cause_by=BossRequirement), + stg_path=stg_path + ) + assert role_a._rc.memory.count() == 2 + + assert stg_path.joinpath("actions/todo.json").exists() + + new_role_a: Role = Role.deserialize(stg_path) + assert new_role_a._rc.state == 1 + await role_a.run( + message=Message(content="demo", cause_by=BossRequirement), + stg_path=stg_path + ) + diff --git a/tests/metagpt/test_environment.py b/tests/metagpt/test_environment.py index b27bc3da7..03236a08b 100644 --- a/tests/metagpt/test_environment.py +++ b/tests/metagpt/test_environment.py @@ -7,6 +7,8 @@ """ import pytest +from pathlib import Path +import shutil from metagpt.actions import UserRequirement from metagpt.environment import Environment @@ -14,6 +16,10 @@ from metagpt.logs import logger from metagpt.manager import Manager from metagpt.roles import Architect, ProductManager, Role from metagpt.schema import Message +from tests.metagpt.roles.test_role import RoleA + + +serdes_path = Path(__file__).absolute().parent.joinpath("../data/serdes_storage") @pytest.fixture @@ -36,12 +42,6 @@ def test_get_roles(env: Environment): assert roles == {role1.profile: role1, role2.profile: role2} -def test_set_manager(env: Environment): - manager = Manager() - env.set_manager(manager) - assert env.manager == manager - - @pytest.mark.asyncio async def test_publish_and_process_message(env: Environment): product_manager = ProductManager("Alice", "Product Manager", "做AI Native产品", "资源有限") @@ -54,3 +54,18 @@ async def test_publish_and_process_message(env: Environment): await env.run(k=2) logger.info(f"{env.history=}") assert len(env.history) > 10 + + +def test_environment_serdes(): + environment = Environment() + role_a = RoleA() + + shutil.rmtree(serdes_path.joinpath("team"), ignore_errors=True) + + stg_path = serdes_path.joinpath("team/environment") + environment.add_role(role_a) + environment.serialize(stg_path) + + new_env: Environment = Environment() + new_env.deserialize(stg_path) + assert len(new_env.roles) == 1 diff --git a/tests/metagpt/test_schema.py b/tests/metagpt/test_schema.py index 51ebd5baa..4a6f518b1 100644 --- a/tests/metagpt/test_schema.py +++ b/tests/metagpt/test_schema.py @@ -7,12 +7,16 @@ @Modified By: mashenquan, 2023-11-1. In line with Chapter 2.2.1 and 2.2.2 of RFC 116, introduce unit tests for the utilization of the new feature of `Message` class. """ + import json import pytest from metagpt.actions import Action from metagpt.schema import AIMessage, Message, SystemMessage, UserMessage +from metagpt.actions.action_output import ActionOutput +from metagpt.actions.write_code import WriteCode +from metagpt.utils.serialize import serialize_general_message, deserialize_general_message from metagpt.utils.common import get_class_name @@ -70,5 +74,43 @@ def test_routes(): assert m.send_to == {"e", get_class_name(Action)} +def test_message_serdes(): + out_mapping = {"field3": (str, ...), "field4": (list[str], ...)} + out_data = {"field3": "field3 value3", "field4": ["field4 value1", "field4 value2"]} + ic_obj = ActionOutput.create_model_class("code", out_mapping) + + message = Message( + content="code", + instruct_content=ic_obj(**out_data), + role="engineer", + cause_by=WriteCode + ) + message_dict = serialize_general_message(message) + assert message_dict["cause_by"] == {"action_class": "WriteCode"} + assert message_dict["instruct_content"] == { + "class": "code", + "mapping": { + "field3": "(, Ellipsis)", + "field4": "(list[str], Ellipsis)" + }, + "value": { + "field3": "field3 value3", + "field4": ["field4 value1", "field4 value2"] + } + } + + new_message = deserialize_general_message(message_dict) + assert new_message.content == message.content + assert new_message.instruct_content == message.instruct_content + assert new_message.cause_by == message.cause_by + assert new_message.instruct_content.field3 == out_data["field3"] + + message = Message(content="code") + message_dict = serialize_general_message(message) + new_message = deserialize_general_message(message_dict) + assert new_message.instruct_content is None + assert new_message.cause_by == "" + + if __name__ == "__main__": pytest.main([__file__, "-s"]) diff --git a/tests/metagpt/test_team.py b/tests/metagpt/test_team.py new file mode 100644 index 000000000..ab201152c --- /dev/null +++ b/tests/metagpt/test_team.py @@ -0,0 +1,27 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : unittest of team + +from pathlib import Path +import shutil + +from metagpt.team import Team + +from tests.metagpt.roles.test_role import RoleA + +serdes_path = Path(__file__).absolute().parent.joinpath("../data/serdes_storage") + + +def test_team_serdes(): + company = Team() + company.hire([RoleA()]) + + stg_path = serdes_path.joinpath("team") + shutil.rmtree(stg_path, ignore_errors=True) + + company.serialize(stg_path=stg_path) + + new_company = Team() + new_company.deserialize(stg_path) + + assert len(new_company.environment.roles) == 1 From c8570036fc92be30d2513a95c72ed9d0dc73bc55 Mon Sep 17 00:00:00 2001 From: stellahsr Date: Mon, 27 Nov 2023 21:12:50 +0800 Subject: [PATCH 036/167] update basic code for serialize --- metagpt/actions/action.py | 57 ++++--- metagpt/actions/design_api.py | 20 ++- metagpt/actions/project_management.py | 13 +- metagpt/actions/search_and_summarize.py | 44 ++++-- metagpt/actions/write_code.py | 14 +- metagpt/actions/write_code_review.py | 8 +- metagpt/actions/write_prd.py | 14 +- metagpt/const.py | 2 +- metagpt/environment.py | 26 ++-- metagpt/roles/architect.py | 21 ++- metagpt/roles/engineer.py | 33 ++-- metagpt/roles/product_manager.py | 41 +++-- metagpt/roles/project_manager.py | 30 ++-- metagpt/roles/role.py | 193 +++++++++++++----------- 14 files changed, 270 insertions(+), 246 deletions(-) diff --git a/metagpt/actions/action.py b/metagpt/actions/action.py index 3bfb69de4..e890ef76a 100644 --- a/metagpt/actions/action.py +++ b/metagpt/actions/action.py @@ -6,10 +6,9 @@ @File : action.py """ -from abc import ABC -from typing import Optional - +from typing import Optional, Any from tenacity import retry, stop_after_attempt, wait_random_exponential +from pydantic import BaseModel, Field from metagpt.actions.action_output import ActionOutput from metagpt.llm import LLM @@ -20,25 +19,22 @@ from metagpt.utils.utils import general_after_log from metagpt.utils.utils import import_class -class Action(ABC): - def __init__(self, name: str = "", context=None, llm: LLM = None): - self.name: str = name - if llm is None: - llm = LLM() - self.llm = llm - self.context = context - self.prefix = "" # aask*时会加上prefix,作为system_message - self.profile = "" # FIXME: USELESS - self.desc = "" # for skill manager - self.nodes = ... +action_subclass_registry = {} - # Output, useless - # self.content = "" - # self.instruct_content = None - # self.env = None - # def set_env(self, env): - # self.env = env +class Action(BaseModel): + name: str = "" + llm: LLM = Field(default_factory=LLM) + context = None + prefix = "" # aask*时会加上prefix,作为system_message + profile = "" # FIXME: USELESS + desc = "" # for skill manager + nodes = None + # content: Optional[str] = None + # instruct_content: Optional[str] = None + + def __init__(self, **kwargs: Any): + super().__init__(**kwargs) def set_prefix(self, prefix, profile): """Set prefix for later usage""" @@ -95,27 +91,26 @@ class Action(ABC): after=general_after_log(logger), ) async def _aask_v1( - self, - prompt: str, - output_class_name: str, - output_data_mapping: dict, - system_msgs: Optional[list[str]] = None, - format="markdown", # compatible to original format + self, + prompt: str, + output_class_name: str, + output_data_mapping: dict, + system_msgs: Optional[list[str]] = None, + format="markdown", # compatible to original format ) -> ActionOutput: content = await self.llm.aask(prompt, system_msgs) logger.debug(f"llm raw output:\n{content}") output_class = ActionOutput.create_model_class(output_class_name, output_data_mapping) - + if format == "json": parsed_data = llm_output_postprecess(output=content, schema=output_class.schema(), req_key="[/CONTENT]") - else: # using markdown parser parsed_data = OutputParser.parse_data_with_mapping(content, output_data_mapping) - - logger.debug(f"parsed_data:\n{parsed_data}") + + logger.debug(parsed_data) instruct_content = output_class(**parsed_data) return ActionOutput(content, instruct_content) - + async def run(self, *args, **kwargs): """Run action""" raise NotImplementedError("The run method should be implemented in a subclass.") diff --git a/metagpt/actions/design_api.py b/metagpt/actions/design_api.py index 5a5f52de7..a10ff1c9a 100644 --- a/metagpt/actions/design_api.py +++ b/metagpt/actions/design_api.py @@ -11,9 +11,12 @@ """ import json from pathlib import Path +from typing import Optional +from pydantic import Field from metagpt.actions import Action, ActionOutput from metagpt.actions.design_api_an import DESIGN_API_NODE +from metagpt.llm import LLM from metagpt.config import CONFIG from metagpt.const import ( DATA_API_DESIGN_FILE_REPO, @@ -25,12 +28,8 @@ from metagpt.const import ( from metagpt.logs import logger from metagpt.schema import Document, Documents from metagpt.utils.file_repository import FileRepository - -# from metagpt.utils.get_template import get_template from metagpt.utils.mermaid import mermaid_to_file -# from typing import List - NEW_REQ_TEMPLATE = """ ### Legacy Content @@ -42,13 +41,12 @@ NEW_REQ_TEMPLATE = """ class WriteDesign(Action): - def __init__(self, name, context=None, llm=None): - super().__init__(name, context, llm) - self.desc = ( - "Based on the PRD, think about the system design, and design the corresponding APIs, " - "data structures, library tables, processes, and paths. Please provide your design, feedback " - "clearly and in detail." - ) + name: str = "" + context: Optional[str] = None + llm: LLM = Field(default_factory=LLM) + desc: str = "Based on the PRD, think about the system design, and design the corresponding APIs, " + "data structures, library tables, processes, and paths. Please provide your design, feedback " + "clearly and in detail." async def run(self, with_messages, format=CONFIG.prompt_format): # Use `git diff` to identify which PRD documents have been modified in the `docs/prds` directory. diff --git a/metagpt/actions/project_management.py b/metagpt/actions/project_management.py index 1f14e7944..d830a4c15 100644 --- a/metagpt/actions/project_management.py +++ b/metagpt/actions/project_management.py @@ -9,11 +9,15 @@ 2. Move the document storage operations related to WritePRD from the save operation of WriteDesign. 3. According to the design in Section 2.2.3.5.4 of RFC 135, add incremental iteration functionality. """ + import json +from typing import List, Optional, Any +from pydantic import Field from metagpt.actions import ActionOutput from metagpt.actions.action import Action from metagpt.actions.project_management_an import PM_NODE +from metagpt.llm import LLM from metagpt.config import CONFIG from metagpt.const import ( PACKAGE_REQUIREMENTS_FILENAME, @@ -24,10 +28,8 @@ from metagpt.const import ( from metagpt.logs import logger from metagpt.schema import Document, Documents from metagpt.utils.file_repository import FileRepository +from metagpt.provider.base_gpt_api import BaseGPTAPI -# from typing import List - -# from metagpt.utils.get_template import get_template NEW_REQ_TEMPLATE = """ ### Legacy Content @@ -39,8 +41,9 @@ NEW_REQ_TEMPLATE = """ class WriteTasks(Action): - def __init__(self, name="CreateTasks", context=None, llm=None): - super().__init__(name, context, llm) + name: str = "CreateTasks" + context: Optional[str] = None + llm: BaseGPTAPI = Field(default_factory=LLM) async def run(self, with_messages, format=CONFIG.prompt_format): system_design_file_repo = CONFIG.git_repo.new_file_repository(SYSTEM_DESIGN_FILE_REPO) diff --git a/metagpt/actions/search_and_summarize.py b/metagpt/actions/search_and_summarize.py index 5e4cdaea0..7b549518e 100644 --- a/metagpt/actions/search_and_summarize.py +++ b/metagpt/actions/search_and_summarize.py @@ -6,12 +6,16 @@ @File : search_google.py """ import pydantic +from typing import Optional, Any +from pydantic import BaseModel, Field from metagpt.actions import Action +from metagpt.llm import LLM from metagpt.config import Config from metagpt.logs import logger from metagpt.schema import Message from metagpt.tools.search_engine import SearchEngine +from pydantic import root_validator SEARCH_AND_SUMMARIZE_SYSTEM = """### Requirements 1. Please summarize the latest dialogue based on the reference information (secondary) and dialogue history (primary). Do not include text that is irrelevant to the conversation. @@ -54,7 +58,6 @@ SEARCH_AND_SUMMARIZE_PROMPT = """ """ - SEARCH_AND_SUMMARIZE_SALES_SYSTEM = """## Requirements 1. Please summarize the latest dialogue based on the reference information (secondary) and dialogue history (primary). Do not include text that is irrelevant to the conversation. - The context is for reference only. If it is irrelevant to the user's search request history, please reduce its reference and usage. @@ -101,23 +104,38 @@ You are a member of a professional butler team and will provide helpful suggesti class SearchAndSummarize(Action): - def __init__(self, name="", context=None, llm=None, engine=None, search_func=None): - self.config = Config() - self.engine = engine or self.config.search_engine + name: str = "" + content: Optional[str] = None + llm: None = Field(default_factory=LLM) + config: None = Field(default_factory=Config) + engine: Optional[str] = None + search_func: Optional[str] = None + search_engine: SearchEngine = None - try: - self.search_engine = SearchEngine(self.engine, run_func=search_func) - except pydantic.ValidationError: - self.search_engine = None + result = "" - self.result = "" - super().__init__(name, context, llm) + @root_validator + def validate_engine_and_run_func(cls, values): + engine = values.get('engine') + search_func = values.get('search_func') + config = Config() + + if engine is None: + engine = config.search_engine + config_data = { + 'engine': engine, + 'run_func': search_func + } + search_engine = SearchEngine(**config_data) + + values['search_engine'] = search_engine + return values async def run(self, context: list[Message], system_text=SEARCH_AND_SUMMARIZE_SYSTEM) -> str: if self.search_engine is None: logger.warning("Configure one of SERPAPI_API_KEY, SERPER_API_KEY, GOOGLE_API_KEY to unlock full feature") return "" - + query = context[-1].content # logger.debug(query) rsp = await self.search_engine.run(query) @@ -126,9 +144,9 @@ class SearchAndSummarize(Action): logger.error("empty rsp...") return "" # logger.info(rsp) - + system_prompt = [system_text] - + prompt = SEARCH_AND_SUMMARIZE_PROMPT.format( # PREFIX = self.prefix, ROLE=self.profile, diff --git a/metagpt/actions/write_code.py b/metagpt/actions/write_code.py index 5960e2621..2d155e6bf 100644 --- a/metagpt/actions/write_code.py +++ b/metagpt/actions/write_code.py @@ -14,10 +14,17 @@ 3. Encapsulate the input of RunCode into RunCodeContext and encapsulate the output of RunCode into RunCodeResult to standardize and unify parameter passing between WriteCode, RunCode, and DebugError. """ + import json from tenacity import retry, stop_after_attempt, wait_random_exponential + + +from typing import List, Optional, Any +from pydantic import Field +from tenacity import retry, stop_after_attempt, wait_fixed + from metagpt.actions.action import Action from metagpt.config import CONFIG from metagpt.const import ( @@ -27,6 +34,8 @@ from metagpt.const import ( TASK_FILE_REPO, TEST_OUTPUTS_FILE_REPO, ) +from metagpt.actions import WriteDesign +from metagpt.llm import LLM from metagpt.logs import logger from metagpt.schema import CodingContext, Document, RunCodeResult from metagpt.utils.common import CodeParser @@ -84,8 +93,9 @@ ATTENTION: Use '##' to SPLIT SECTIONS, not '#'. Output format carefully referenc class WriteCode(Action): - def __init__(self, name="WriteCode", context=None, llm=None): - super().__init__(name, context, llm) + name: str = "WriteCode" + context: Optional[str] = None + llm: LLM = Field(default_factory=LLM) @retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6)) async def write_code(self, prompt) -> str: diff --git a/metagpt/actions/write_code_review.py b/metagpt/actions/write_code_review.py index 4b3e9aece..bf07d0a93 100644 --- a/metagpt/actions/write_code_review.py +++ b/metagpt/actions/write_code_review.py @@ -8,9 +8,12 @@ WriteCode object, rather than passing them in when calling the run function. """ +from typing import List, Optional, Any +from pydantic import Field from tenacity import retry, stop_after_attempt, wait_random_exponential from metagpt.actions import WriteCode +from metagpt.llm import LLM from metagpt.actions.action import Action from metagpt.config import CONFIG from metagpt.logs import logger @@ -119,8 +122,9 @@ REWRITE_CODE_TEMPLATE = """ class WriteCodeReview(Action): - def __init__(self, name="WriteCodeReview", context=None, llm=None): - super().__init__(name, context, llm) + name: str = "WriteCodeReview" + context: Optional[str] = None + llm: LLM = Field(default_factory=LLM) @retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6)) async def write_code_review_and_rewrite(self, context_prompt, cr_prompt, filename): diff --git a/metagpt/actions/write_prd.py b/metagpt/actions/write_prd.py index bb0cf8fb9..7f9089763 100644 --- a/metagpt/actions/write_prd.py +++ b/metagpt/actions/write_prd.py @@ -10,10 +10,13 @@ 3. Move the document storage operations related to WritePRD from the save operation of WriteDesign. @Modified By: mashenquan, 2023/12/5. Move the generation logic of the project name to WritePRD. """ + from __future__ import annotations import json from pathlib import Path +from typing import List, Optional, Any +from pydantic import BaseModel, Field from metagpt.actions import Action, ActionOutput from metagpt.actions.action_node import ActionNode @@ -23,6 +26,8 @@ from metagpt.actions.write_prd_an import ( WP_ISSUE_TYPE_NODE, WRITE_PRD_NODE, ) +from metagpt.llm import LLM +from metagpt.actions.search_and_summarize import SearchAndSummarize from metagpt.config import CONFIG from metagpt.const import ( BUGFIX_FILENAME, @@ -36,12 +41,8 @@ from metagpt.logs import logger from metagpt.schema import BugFixContext, Document, Documents, Message from metagpt.utils.common import CodeParser from metagpt.utils.file_repository import FileRepository - -# from metagpt.utils.get_template import get_template from metagpt.utils.mermaid import mermaid_to_file -# from typing import List - CONTEXT_TEMPLATE = """ ### Project Name @@ -64,8 +65,9 @@ NEW_REQ_TEMPLATE = """ class WritePRD(Action): - def __init__(self, name="", context=None, llm=None): - super().__init__(name, context, llm) + name: str = "" + content: Optional[str] = None + llm: LLM = Field(default_factory=LLM) async def run(self, with_messages, format=CONFIG.prompt_format, *args, **kwargs) -> ActionOutput | Message: # Determine which requirement documents need to be rewritten: Use LLM to assess whether new requirements are diff --git a/metagpt/const.py b/metagpt/const.py index b46bc15a4..9cf9726fc 100644 --- a/metagpt/const.py +++ b/metagpt/const.py @@ -60,7 +60,7 @@ SWAGGER_PATH = UT_PATH / "files/api/" UT_PY_PATH = UT_PATH / "files/ut/" API_QUESTIONS_PATH = UT_PATH / "files/question/" -SERDES_PATH = DEFAULT_WORKSPACE_ROOT / "storage" # TODO to store `storage` under the individual generated project +SERDESER_PATH = DEFAULT_WORKSPACE_ROOT / "storage" # TODO to store `storage` under the individual generated project TMP = METAGPT_ROOT / "tmp" diff --git a/metagpt/environment.py b/metagpt/environment.py index 14da6cd95..19197bd10 100644 --- a/metagpt/environment.py +++ b/metagpt/environment.py @@ -54,31 +54,33 @@ class Environment(BaseModel): write_json_file(history_path, {"content": self.history}) def deserialize(self, stg_path: Path): + """ stg_path: ./storage/team/environment/ """ """ stg_path: ./storage/team/environment/ """ roles_path = stg_path.joinpath("roles.json") roles_info = read_json_file(roles_path) + roles = [] for role_info in roles_info: - role_class = role_info.get("role_class") - role_name = role_info.get("role_name") - - role_path = stg_path.joinpath(f"roles/{role_class}_{role_name}") + # role stored in ./environment/roles/{role_class}_{role_name} + role_path = stg_path.joinpath(f'roles/{role_info.get("role_class")}_{role_info.get("role_name")}') role = Role.deserialize(role_path) + roles.append(role) - self.add_role(role) + history = read_json_file(stg_path.joinpath("history.json")) + history = history.get("content") - memory = Memory.deserialize(stg_path) - self.memory = memory - - history_path = stg_path.joinpath("history.json") - history = read_json_file(history_path) - self.history = history.get("content") + environment = Environment(**{ + "history": history + }) + environment.add_roles(roles) + return environment def add_role(self, role: Role): """增加一个在当前环境的角色 Add a role in the current environment """ role.set_env(self) - self.roles[role.profile] = role + # use alias + self.roles[role.role_profile] = role def add_roles(self, roles: Iterable[Role]): """增加一批在当前环境的角色 diff --git a/metagpt/roles/architect.py b/metagpt/roles/architect.py index fa91d393d..377531c8d 100644 --- a/metagpt/roles/architect.py +++ b/metagpt/roles/architect.py @@ -5,10 +5,11 @@ @Author : alexanderwu @File : architect.py """ +from pydantic import Field from metagpt.actions import WritePRD from metagpt.actions.design_api import WriteDesign -from metagpt.roles import Role +from metagpt.roles.role import Role class Architect(Role): @@ -21,18 +22,14 @@ class Architect(Role): goal (str): Primary goal or responsibility of the architect. constraints (str): Constraints or guidelines for the architect. """ + name: str = "Bob" + profile: str = Field(default="Architect", alias='profile') + goal: str = "design a concise, usable, complete software system" + constraints: str = "make sure the architecture is simple enough and use appropriate open source libraries." \ + "Use same language as user requirement" - def __init__( - self, - name: str = "Bob", - profile: str = "Architect", - goal: str = "design a concise, usable, complete software system", - constraints: str = "make sure the architecture is simple enough and use appropriate open source libraries." - "Use same language as user requirement" - ) -> None: - """Initializes the Architect with given attributes.""" - super().__init__(name, profile, goal, constraints) - + def __init__(self, **kwargs) -> None: + super().__init__(**kwargs) # Initialize actions specific to the Architect role self._init_actions([WriteDesign]) diff --git a/metagpt/roles/engineer.py b/metagpt/roles/engineer.py index f1e65b177..59ca18a17 100644 --- a/metagpt/roles/engineer.py +++ b/metagpt/roles/engineer.py @@ -16,8 +16,9 @@ @Modified By: mashenquan, 2023-12-5. Enhance the workflow to navigate to WriteCode or QaEngineer based on the results of SummarizeCode. """ -from __future__ import annotations +from __future__ import annotations +from pydantic import Field import json from collections import defaultdict from pathlib import Path @@ -44,9 +45,11 @@ from metagpt.schema import ( ) from metagpt.utils.common import any_to_str, any_to_str_set + IS_PASS_PROMPT = """ {context} +<<<<<<< HEAD ---- Does the above log indicate anything that needs to be done? If there are any tasks to be completed, please answer 'NO' along with the to-do list in JSON format; @@ -66,25 +69,21 @@ class Engineer(Role): n_borg (int): Number of borgs. use_code_review (bool): Whether to use code review. """ + name: str = "Alex" + role_profile: str = Field(default="Engineer", alias='profile') + goal: str = "write elegant, readable, extensible, efficient code" + constraints: str = "the code should conform to standards like google-style and be modular and maintainable. " \ + "Use same language as user requirement", + n_borg: int = 1 + use_code_review: bool = False + code_todos: list = [] + summarize_todos = [] + + def __init__(self, **kwargs) -> None: + super().__init__(**kwargs) - def __init__( - self, - name: str = "Alex", - profile: str = "Engineer", - goal: str = "write elegant, readable, extensible, efficient code", - constraints: str = "the code should conform to standards like google-style and be modular and maintainable. " - "Use same language as user requirement", - n_borg: int = 1, - use_code_review: bool = False, - ) -> None: - """Initializes the Engineer role with given attributes.""" - super().__init__(name, profile, goal, constraints) - self.use_code_review = use_code_review self._init_actions([WriteCode]) self._watch([WriteTasks, SummarizeCode, WriteCode, WriteCodeReview, FixBug]) - self.code_todos = [] - self.summarize_todos = [] - self.n_borg = n_borg @staticmethod def _parse_tasks(task_msg: Document) -> list[str]: diff --git a/metagpt/roles/product_manager.py b/metagpt/roles/product_manager.py index e5e9f2b5e..a49459fca 100644 --- a/metagpt/roles/product_manager.py +++ b/metagpt/roles/product_manager.py @@ -7,40 +7,33 @@ @Modified By: mashenquan, 2023/11/27. Add `PrepareDocuments` action according to Section 2.2.3.5.1 of RFC 135. """ +from pydantic import Field + from metagpt.actions import UserRequirement, WritePRD from metagpt.actions.prepare_documents import PrepareDocuments from metagpt.config import CONFIG -from metagpt.roles import Role +from metagpt.roles.role import Role class ProductManager(Role): """ - Represents a Product Manager role responsible for product development and management. + Represents a Project Manager role responsible for overseeing project execution and team efficiency. Attributes: - name (str): Name of the product manager. - profile (str): Role profile, default is 'Product Manager'. - goal (str): Goal of the product manager. - constraints (str): Constraints or limitations for the product manager. + name (str): Name of the project manager. + profile (str): Role profile, default is 'Project Manager'. + goal (str): Goal of the project manager. + constraints (str): Constraints or limitations for the project manager. """ - - def __init__( - self, - name: str = "Alice", - profile: str = "Product Manager", - goal: str = "efficiently create a successful product", - constraints: str = "use same language as user requirement", - ) -> None: - """ - Initializes the ProductManager role with given attributes. - - Args: - name (str): Name of the product manager. - profile (str): Role profile. - goal (str): Goal of the product manager. - constraints (str): Constraints or limitations for the product manager. - """ - super().__init__(name, profile, goal, constraints) + name: str = "Alice" + role_profile: str = Field(default="Product Manager", alias='profile') + goal: str = "efficiently create a successful product" + constraints: str = "use same language as user requiremen" + """ + Represents a Product Manager role responsible for product development and management. + """ + def __init__(self, **kwargs) -> None: + super().__init__(**kwargs) self._init_actions([PrepareDocuments, WritePRD]) self._watch([UserRequirement, PrepareDocuments]) diff --git a/metagpt/roles/project_manager.py b/metagpt/roles/project_manager.py index 5a2b9be50..211e41d3b 100644 --- a/metagpt/roles/project_manager.py +++ b/metagpt/roles/project_manager.py @@ -5,9 +5,11 @@ @Author : alexanderwu @File : project_manager.py """ +from pydantic import Field + from metagpt.actions import WriteTasks from metagpt.actions.design_api import WriteDesign -from metagpt.roles import Role +from metagpt.roles.role import Role class ProjectManager(Role): @@ -20,24 +22,14 @@ class ProjectManager(Role): goal (str): Goal of the project manager. constraints (str): Constraints or limitations for the project manager. """ + name: str = "Eve" + profile: str = Field(default="Project Manager") + + goal: str = "reak down tasks according to PRD/technical design, generate a task list, and analyze task " \ + "dependencies to start with the prerequisite modules" + constraints: str = "use same language as user requirement" - def __init__( - self, - name: str = "Eve", - profile: str = "Project Manager", - goal: str = "break down tasks according to PRD/technical design, generate a task list, and analyze task " - "dependencies to start with the prerequisite modules", - constraints: str = "use same language as user requirement", - ) -> None: - """ - Initializes the ProjectManager role with given attributes. - - Args: - name (str): Name of the project manager. - profile (str): Role profile. - goal (str): Goal of the project manager. - constraints (str): Constraints or limitations for the project manager. - """ - super().__init__(name, profile, goal, constraints) + def __init__(self, **kwargs) -> None: + super().__init__(**kwargs) self._init_actions([WriteTasks]) self._watch([WriteDesign]) diff --git a/metagpt/roles/role.py b/metagpt/roles/role.py index bb3b2acfe..07a78e4bb 100644 --- a/metagpt/roles/role.py +++ b/metagpt/roles/role.py @@ -18,14 +18,16 @@ @Modified By: mashenquan, 2023-11-4. According to the routing feature plan in Chapter 2.2.3.2 of RFC 113, the routing functionality is to be consolidated into the `Environment` class. """ + from __future__ import annotations + from enum import Enum from typing import Iterable, Set, Type from pathlib import Path from pydantic import BaseModel, Field -from metagpt.actions import Action, ActionOutput +from metagpt.actions.action import Action, ActionOutput, action_subclass_registry from metagpt.actions.action_node import ActionNode from metagpt.actions.add_requirement import UserRequirement from metagpt.llm import LLM, HumanProvider @@ -35,6 +37,8 @@ from metagpt.utils.common import any_to_str from metagpt.utils.repair_llm_raw_output import extract_state_value_from_output from metagpt.memory import Memory from metagpt.utils.utils import read_json_file, write_json_file, import_class +from metagpt.provider.base_gpt_api import BaseGPTAPI +from metagpt.const import SERDESER_PATH PREFIX_TEMPLATE = """You are a {profile}, named {name}, your goal is {goal}, and the constraint is {constraints}. """ @@ -45,14 +49,12 @@ Please note that only the text between the first and second "===" is information {history} === -Your previous stage: {previous_state} - -Now choose one of the following stages you need to go to in the next step: +You can now choose one of the following stages to decide the stage you need to go in the next step: {states} Just answer a number between 0-{n_states}, choose the most suitable stage according to the understanding of the conversation. Please note that the answer only needs a number, no need to add any other text. -If you think you have completed your goal and don't need to go to any of the stages, return -1. +If there is no conversation record, choose 0. Do not answer anything else, and do not add any other information in your answer. """ @@ -89,7 +91,7 @@ class RoleSetting(BaseModel): def __str__(self): return f"{self.name}({self.profile})" - + def __repr__(self): return self.__str__() @@ -112,7 +114,7 @@ class RoleContext(BaseModel): class Config: arbitrary_types_allowed = True - + def check(self, role_id: str): # if hasattr(CONFIG, "long_term_memory") and CONFIG.long_term_memory: # self.long_term_memory.recover_memory(role_id, self) @@ -123,7 +125,7 @@ class RoleContext(BaseModel): def important_memory(self) -> list[Message]: """Get the information corresponding to the watched actions""" return self.memory.get_by_actions(self.watch) - + @property def history(self) -> list[Message]: return self.memory.get() @@ -139,56 +141,99 @@ class _RoleInjector(type): return instance -class Role(metaclass=_RoleInjector): - """Role/Agent""" +role_subclass_registry = {} - def __init__(self, name="", profile="", goal="", constraints="", desc="", is_human=False): - self._llm = LLM() if not is_human else HumanProvider() - self._setting = RoleSetting( - name=name, profile=profile, goal=goal, constraints=constraints, desc=desc, is_human=is_human - ) - self._llm.system_prompt = self._get_prefix() - self._states = [] - self._actions = [] - self._role_id = str(self._setting) - self._rc = RoleContext() + +class Role(BaseModel): + """Role/Agent""" + name: str = "" + profile: str = "" + goal: str = "" + constraints: str = "" + desc: str = "" + is_human: bool = False + + _llm: BaseGPTAPI = Field(default_factory=LLM) + _role_id: str = "" + _states: list[str] = Field(default=[]) + _actions: list[Action] = Field(default=[]) + _rc: RoleContext = Field(default=RoleContext) + _subscription: tuple = set() + + # builtin variables + recovered: bool = False # to tag if a recovered role + builtin_class_name: str = "" + + _private_attributes = { + "_llm": LLM() if not is_human else HumanProvider(), + "_role_id": _role_id, + "_states": [], + "_actions": [], + "_rc": RoleContext() + } + + class Config: + arbitrary_types_allowed = True + exclude = ["_llm"] + + def __init__(self, **kwargs): + for index in range(len(kwargs.get("_actions", []))): + current_action = kwargs["_actions"][index] + if isinstance(current_action, dict): + item_class_name = current_action.get("builtin_class_name", None) + for name, subclass in action_subclass_registry.items(): + registery_class_name = subclass.__fields__["builtin_class_name"].default + if item_class_name == registery_class_name: + current_action = subclass(**current_action) + break + kwargs["_actions"][index] = current_action + + super().__init__(**kwargs) + + # 关于私有变量的初始化 https://github.com/pydantic/pydantic/issues/655 + self._private_attributes["_llm"] = LLM() if not self.is_human else HumanProvider() + self._private_attributes["_role_id"] = str(self._setting) self._subscription = {any_to_str(self), name} if name else {any_to_str(self)} - self._recovered = False + for key in self._private_attributes.keys(): + if key in kwargs: + object.__setattr__(self, key, kwargs[key]) + if key == "_rc": + _rc = RoleContext(**kwargs["_rc"]) + object.__setattr__(self, "_rc", _rc) + else: + if key == "_rc": + # # Warning, if use self._private_attributes["_rc"], + # # self._rc will be a shared object between roles, so init one or reset it inside `_reset` + object.__setattr__(self, key, RoleContext()) + else: + object.__setattr__(self, key, self._private_attributes[key]) + + # deserialize child classes dynamically for inherited `role` + object.__setattr__(self, "builtin_class_name", self.__class__.__name__) + self.__fields__["builtin_class_name"].default = self.__class__.__name__ + + def _reset(self): + object.__setattr__(self, '_states', []) + object.__setattr__(self, '_actions', []) + + @property + def _setting(self): + return f"{self.name}({self.profile})" def serialize(self, stg_path: Path): - role_info_path = stg_path.joinpath("role_info.json") - role_info = { + stg_path = SERDESER_PATH.joinpath(f"team/environment/roles/{self.__class__.__name__}_{self.name}") \ + if stg_path is None else stg_path + + role_info = self.dict(exclude={"_rc": {"memory": True}, "_llm": True}) + role_info.update({ "role_class": self.__class__.__name__, "module_name": self.__module__ - } - setting = self._setting.dict() - setting.pop("desc") - setting.pop("is_human") # not all inherited roles have this atrr - role_info.update(setting) + }) + role_info_path = stg_path.joinpath("role_info.json") write_json_file(role_info_path, role_info) - actions_info_path = stg_path.joinpath("actions/actions_info.json") - actions_info = [] - for action in self._actions: - actions_info.append(action.serialize()) - write_json_file(actions_info_path, actions_info) - - watches_info_path = stg_path.joinpath("watches/watches_info.json") - watches_info = [] - for watch in self._rc.watch: - watches_info.append(watch.ser_class()) - write_json_file(watches_info_path, watches_info) - - actions_todo_path = stg_path.joinpath("actions/todo.json") - actions_todo = { - "cur_state": self._rc.state, - "react_mode": self._rc.react_mode.value, - "max_react_loop": self._rc.max_react_loop - } - write_json_file(actions_todo_path, actions_todo) - - self._rc.memory.serialize(stg_path) + self._rc.memory.serialize(stg_path) # serialize role's memory alone @classmethod def deserialize(cls, stg_path: Path) -> "Role": @@ -201,45 +246,13 @@ class Role(metaclass=_RoleInjector): role_class = import_class(class_name=role_class_str, module_name=module_name) role = role_class(**role_info) # initiate particular Role - actions_info_path = stg_path.joinpath("actions/actions_info.json") - actions = [] - actions_info = read_json_file(actions_info_path) - for action_info in actions_info: - action = Action.deserialize(action_info) - actions.append(action) - - watches_info_path = stg_path.joinpath("watches/watches_info.json") - watches = [] - watches_info = read_json_file(watches_info_path) - for watch_info in watches_info: - action = Action.deser_class(watch_info) - watches.append(action) - - role.init_actions(actions) - role.watch(watches) - - actions_todo_path = stg_path.joinpath("actions/todo.json") - # recover self._rc.state - actions_todo = read_json_file(actions_todo_path) - max_react_loop = actions_todo.get("max_react_loop", 1) - cur_state = actions_todo.get("cur_state", -1) - role.set_state(cur_state) - role.set_recovered(True) - react_mode_str = actions_todo.get("react_mode", RoleReactMode.REACT.value) - if react_mode_str not in RoleReactMode.values(): - logger.warning(f"ReactMode: {react_mode_str} not in {RoleReactMode.values()}, use react as default") - react_mode_str = RoleReactMode.REACT.value - role.set_react_mode(RoleReactMode(react_mode_str), max_react_loop) + role.set_recovered(True) # set True to make a tag role_memory = Memory.deserialize(stg_path) role.set_memory(role_memory) return role - def _reset(self): - self._states = [] - self._actions = [] - def _init_action_system_message(self, action: Action): action.set_prefix(self._get_prefix(), self.profile) @@ -256,7 +269,8 @@ class Role(metaclass=_RoleInjector): self._reset() for idx, action in enumerate(actions): if not isinstance(action, Action): - i = action("", llm=self._llm) + ## 默认初始化 + i = action() else: if self._setting.is_human and not isinstance(action.llm, HumanProvider): logger.warning( @@ -331,10 +345,6 @@ class Role(metaclass=_RoleInjector): if env: env.set_subscription(self, self._subscription) - @property - def name(self): - return self._setting.name - @property def profile(self): """Get the role description (position)""" @@ -355,7 +365,7 @@ class Role(metaclass=_RoleInjector): if self._setting.desc: return self._setting.desc return PREFIX_TEMPLATE.format(**self._setting.dict()) - + async def _think(self) -> None: """Think about what to do and decide on the next action""" if len(self._actions) == 1: @@ -378,6 +388,7 @@ class Role(metaclass=_RoleInjector): next_state = await self._llm.aask(prompt) next_state = extract_state_value_from_output(next_state) logger.debug(f"{prompt=}") + if (not next_state.isdigit() and next_state != "-1") or int(next_state) not in range(-1, len(self._states)): logger.warning(f"Invalid answer of state, {next_state=}, will be set to -1") next_state = -1 @@ -423,8 +434,8 @@ class Role(metaclass=_RoleInjector): if news_text: logger.debug(f"{self._setting} observed: {news_text}") return len(self._rc.news) - - def publish_message(self, msg): + + def _publish_message(self, msg): """If the role belongs to env, then the role's messages will be broadcast to env""" if not msg: return @@ -501,7 +512,7 @@ class Role(metaclass=_RoleInjector): def get_memories(self, k=0) -> list[Message]: """A wrapper to return the most recent k memories of this role, return all when k=0""" return self._rc.memory.get(k=k) - + async def run(self, with_message=None): """Observe, and think and act based on the results of the observation""" if with_message: From 9608a20c7127f3034e58293343249401d61a59ac Mon Sep 17 00:00:00 2001 From: stellahsr Date: Mon, 27 Nov 2023 21:13:19 +0800 Subject: [PATCH 037/167] update test cases for serialize_deserialize --- .../metagpt/serialize_deserialize/__init__.py | 4 ++ .../serialize_deserialize/test_actions.py | 24 ++++++++++ .../test_architect_deserialize.py | 26 ++++++++++ .../test_product_manager.py | 21 +++++++++ .../test_project_manager.py | 26 ++++++++++ .../serialize_deserialize/test_role.py | 41 ++++++++++++++++ .../serialize_deserialize/test_team.py | 47 +++++++++++++++++++ .../serialize_deserialize/test_wrire_prd.py | 28 +++++++++++ .../serialize_deserialize/test_write_code.py | 42 +++++++++++++++++ .../test_write_design.py | 39 +++++++++++++++ 10 files changed, 298 insertions(+) create mode 100644 tests/metagpt/serialize_deserialize/__init__.py create mode 100644 tests/metagpt/serialize_deserialize/test_actions.py create mode 100644 tests/metagpt/serialize_deserialize/test_architect_deserialize.py create mode 100644 tests/metagpt/serialize_deserialize/test_product_manager.py create mode 100644 tests/metagpt/serialize_deserialize/test_project_manager.py create mode 100644 tests/metagpt/serialize_deserialize/test_role.py create mode 100644 tests/metagpt/serialize_deserialize/test_team.py create mode 100644 tests/metagpt/serialize_deserialize/test_wrire_prd.py create mode 100644 tests/metagpt/serialize_deserialize/test_write_code.py create mode 100644 tests/metagpt/serialize_deserialize/test_write_design.py diff --git a/tests/metagpt/serialize_deserialize/__init__.py b/tests/metagpt/serialize_deserialize/__init__.py new file mode 100644 index 000000000..78f454fb5 --- /dev/null +++ b/tests/metagpt/serialize_deserialize/__init__.py @@ -0,0 +1,4 @@ +# -*- coding: utf-8 -*- +# @Date : 11/22/2023 11:48 AM +# @Author : stellahong (stellahong@fuzhi.ai) +# @Desc : diff --git a/tests/metagpt/serialize_deserialize/test_actions.py b/tests/metagpt/serialize_deserialize/test_actions.py new file mode 100644 index 000000000..e2efa982b --- /dev/null +++ b/tests/metagpt/serialize_deserialize/test_actions.py @@ -0,0 +1,24 @@ +# -*- coding: utf-8 -*- +# @Date : 11/22/2023 11:48 AM +# @Author : stellahong (stellahong@fuzhi.ai) +# @Desc : +import pytest + +from metagpt.actions import Action +from metagpt.llm import LLM + +def test_action_serialize(): + action = Action() + ser_action_dict = action.dict() + assert "name" in ser_action_dict + assert "llm" in ser_action_dict + +@pytest.mark.asyncio +async def test_action_deserialize(): + action = Action() + serialized_data = action.dict() + + new_action = Action(**serialized_data) + assert new_action.name == "" + assert new_action.llm == LLM() + assert len(await new_action._aask("who are you")) > 0 diff --git a/tests/metagpt/serialize_deserialize/test_architect_deserialize.py b/tests/metagpt/serialize_deserialize/test_architect_deserialize.py new file mode 100644 index 000000000..cff1bbadd --- /dev/null +++ b/tests/metagpt/serialize_deserialize/test_architect_deserialize.py @@ -0,0 +1,26 @@ +# -*- coding: utf-8 -*- +# @Date : 11/26/2023 2:04 PM +# @Author : stellahong (stellahong@fuzhi.ai) +# @Desc : +import pytest + +from metagpt.roles.architect import Architect +from metagpt.actions.action import Action + +def test_architect_serialize(): + role = Architect() + ser_role_dict = role.dict(by_alias=True) + assert "name" in ser_role_dict + assert "_states" in ser_role_dict + assert "_actions" in ser_role_dict + +@pytest.mark.asyncio +async def test_architect_deserialize(): + role = Architect() + ser_role_dict = role.dict(by_alias=True) + new_role = Architect(**ser_role_dict) + # new_role = Architect.deserialize(ser_role_dict) + assert new_role.name == "Bob" + assert len(new_role._actions) == 1 + assert isinstance(new_role._actions[0], Action) + await new_role._actions[0].run(context="write a cli snake game") \ No newline at end of file diff --git a/tests/metagpt/serialize_deserialize/test_product_manager.py b/tests/metagpt/serialize_deserialize/test_product_manager.py new file mode 100644 index 000000000..978c50e5e --- /dev/null +++ b/tests/metagpt/serialize_deserialize/test_product_manager.py @@ -0,0 +1,21 @@ +# -*- coding: utf-8 -*- +# @Date : 11/26/2023 2:07 PM +# @Author : stellahong (stellahong@fuzhi.ai) +# @Desc : +import pytest + +from metagpt.roles.product_manager import ProductManager +from metagpt.actions.action import Action +from metagpt.schema import Message + +@pytest.mark.asyncio +async def test_product_manager_deserialize(): + role = ProductManager() + ser_role_dict = role.dict(by_alias=True) + new_role = ProductManager(**ser_role_dict) + # new_role = ProductManager().deserialize(ser_role_dict) + + assert new_role.name == "Alice" + assert len(new_role._actions) == 1 + assert isinstance(new_role._actions[0], Action) + await new_role._actions[0].run([Message(content="write a cli snake game")]) \ No newline at end of file diff --git a/tests/metagpt/serialize_deserialize/test_project_manager.py b/tests/metagpt/serialize_deserialize/test_project_manager.py new file mode 100644 index 000000000..590bd8109 --- /dev/null +++ b/tests/metagpt/serialize_deserialize/test_project_manager.py @@ -0,0 +1,26 @@ +# -*- coding: utf-8 -*- +# @Date : 11/26/2023 2:06 PM +# @Author : stellahong (stellahong@fuzhi.ai) +# @Desc : +import pytest + +from metagpt.roles.project_manager import ProjectManager +from metagpt.actions.action import Action + +def test_project_manager_serialize(): + role = ProjectManager() + ser_role_dict = role.dict(by_alias=True) + assert "name" in ser_role_dict + assert "_states" in ser_role_dict + assert "_actions" in ser_role_dict + +@pytest.mark.asyncio +async def test_project_manager_deserialize(): + role = ProjectManager() + ser_role_dict = role.dict(by_alias=True) + new_role = ProjectManager(**ser_role_dict) + # new_role = ProjectManager().deserialize(ser_role_dict) + assert new_role.name == "Eve" + assert len(new_role._actions) == 1 + assert isinstance(new_role._actions[0], Action) + await new_role._actions[0].run(context="write a cli snake game") \ No newline at end of file diff --git a/tests/metagpt/serialize_deserialize/test_role.py b/tests/metagpt/serialize_deserialize/test_role.py new file mode 100644 index 000000000..432c9acb7 --- /dev/null +++ b/tests/metagpt/serialize_deserialize/test_role.py @@ -0,0 +1,41 @@ +# -*- coding: utf-8 -*- +# @Date : 11/23/2023 4:49 PM +# @Author : stellahong (stellahong@fuzhi.ai) +# @Desc : +import pytest + +from metagpt.roles.role import Role +from metagpt.roles.engineer import Engineer + +from metagpt.actions.action import Action + + +def test_role_serialize(): + role = Role() + ser_role_dict = role.dict(by_alias=True) + assert "name" in ser_role_dict + assert "_states" in ser_role_dict + assert "_actions" in ser_role_dict + + +def test_engineer_serialize(): + role = Engineer() + ser_role_dict = role.dict(by_alias=True) + assert "name" in ser_role_dict + assert "_states" in ser_role_dict + assert "_actions" in ser_role_dict + + +@pytest.mark.asyncio +async def test_engineer_deserialize(): + role = Engineer(use_code_review=True) + ser_role_dict = role.dict(by_alias=True) + # new_role = Engineer().deserialize(ser_role_dict) + # also can be deserialized in this way: + new_role = Engineer(**ser_role_dict) + assert new_role.name == "Alex" + assert new_role.use_code_review == True + assert len(new_role._actions) == 2 + assert isinstance(new_role._actions[0], Action) + assert isinstance(new_role._actions[1], Action) + await new_role._actions[0].run(context="write a cli snake game", filename="test_code") diff --git a/tests/metagpt/serialize_deserialize/test_team.py b/tests/metagpt/serialize_deserialize/test_team.py new file mode 100644 index 000000000..44a75d262 --- /dev/null +++ b/tests/metagpt/serialize_deserialize/test_team.py @@ -0,0 +1,47 @@ +# -*- coding: utf-8 -*- +# @Date : 11/27/2023 10:07 AM +# @Author : stellahong (stellahong@fuzhi.ai) +# @Desc : +import pytest + +from metagpt.environment import Environment +from metagpt.schema import Message +from metagpt.software_company import SoftwareCompany +from metagpt.roles import ProjectManager, ProductManager, Architect + + +def test_env_serialize(): + env = Environment() + ser_env_dict = env.dict() + assert "roles" in ser_env_dict + assert "memory" in ser_env_dict + assert "memory" in ser_env_dict + + +def test_env_deserialize(): + env = Environment() + env.publish_message(message=Message(content="test env serialize")) + ser_env_dict = env.dict() + new_env = Environment(**ser_env_dict) + assert len(new_env.roles) == 0 + assert new_env.memory.storage[0].content == "test env serialize" + assert len(new_env.history) == 25 + + +def test_softwarecompany_deserialize(): + team = SoftwareCompany() + team.hire( + [ + ProductManager(), + Architect(), + ProjectManager(), + ] + ) + assert len(team.environment.get_roles()) == 3 + ser_team_dict = team.dict() + new_team = SoftwareCompany(**ser_team_dict) + + assert len(new_team.environment.get_roles()) == 3 + assert new_team.environment.get_role('Product Manager') is not None + assert new_team.environment.get_role('Product Manager') is not None + assert new_team.environment.get_role('Architect') is not None diff --git a/tests/metagpt/serialize_deserialize/test_wrire_prd.py b/tests/metagpt/serialize_deserialize/test_wrire_prd.py new file mode 100644 index 000000000..9b2653820 --- /dev/null +++ b/tests/metagpt/serialize_deserialize/test_wrire_prd.py @@ -0,0 +1,28 @@ +# -*- coding: utf-8 -*- +# @Date : 11/22/2023 1:47 PM +# @Author : stellahong (stellahong@fuzhi.ai) +# @Desc : +import pytest + +from metagpt.actions import WritePRD +from metagpt.llm import LLM +from metagpt.schema import Message + + +def test_action_serialize(): + action = WritePRD() + ser_action_dict = action.dict() + assert "name" in ser_action_dict + assert "llm" in ser_action_dict + + +@pytest.mark.asyncio +async def test_action_deserialize(): + action = WritePRD() + serialized_data = action.dict() + new_action = WritePRD(**serialized_data) + # new_action = WritePRD().deserialize(serialized_data) + assert new_action.name == "" + assert new_action.llm == LLM() + assert len(await new_action.run([Message(content="write a cli snake game")]))>0 + diff --git a/tests/metagpt/serialize_deserialize/test_write_code.py b/tests/metagpt/serialize_deserialize/test_write_code.py new file mode 100644 index 000000000..0b1f1dc7c --- /dev/null +++ b/tests/metagpt/serialize_deserialize/test_write_code.py @@ -0,0 +1,42 @@ +# -*- coding: utf-8 -*- +# @Date : 11/23/2023 10:56 AM +# @Author : stellahong (stellahong@fuzhi.ai) +# @Desc : +import pytest + +from metagpt.actions import WriteCode, WriteCodeReview +from metagpt.llm import LLM + +def test_write_design_serialize(): + action = WriteCode() + ser_action_dict = action.dict() + assert ser_action_dict["name"] == "WriteCode" + assert "llm" in ser_action_dict + +def test_write_task_serialize(): + action = WriteCodeReview() + ser_action_dict = action.dict() + assert ser_action_dict["name"] == "WriteCodeReview" + assert "llm" in ser_action_dict + +@pytest.mark.asyncio +async def test_write_code_deserialize(): + action = WriteCode() + serialized_data = action.dict() + new_action = WriteCode(**serialized_data) + # new_action = WriteCode().deserialize(serialized_data) + assert new_action.name == "WriteCode" + assert new_action.llm == LLM() + await new_action.run(context="write a cli snake game", filename="test_code") + +@pytest.mark.asyncio +async def test_write_code_review_deserialize(): + action = WriteCodeReview() + serialized_data = action.dict() + new_action = WriteCodeReview(**serialized_data) + # new_action = WriteCodeReview().deserialize(serialized_data) + code = await WriteCode().run(context="write a cli snake game", filename="test_code") + + assert new_action.name == "WriteCodeReview" + assert new_action.llm == LLM() + await new_action.run(context="write a cli snake game", code =code, filename="test_rewrite_code") \ No newline at end of file diff --git a/tests/metagpt/serialize_deserialize/test_write_design.py b/tests/metagpt/serialize_deserialize/test_write_design.py new file mode 100644 index 000000000..56bf78a63 --- /dev/null +++ b/tests/metagpt/serialize_deserialize/test_write_design.py @@ -0,0 +1,39 @@ +# -*- coding: utf-8 -*- +# @Date : 11/22/2023 8:19 PM +# @Author : stellahong (stellahong@fuzhi.ai) +# @Desc : +import pytest + +from metagpt.actions import WriteDesign, WriteTasks +from metagpt.llm import LLM + +def test_write_design_serialize(): + action = WriteDesign() + ser_action_dict = action.dict() + assert "name" in ser_action_dict + assert "llm" in ser_action_dict + +def test_write_task_serialize(): + action = WriteTasks() + ser_action_dict = action.dict() + assert "name" in ser_action_dict + assert "llm" in ser_action_dict + +@pytest.mark.asyncio +async def test_write_design_deserialize(): + action = WriteDesign() + serialized_data = action.dict() + new_action = WriteDesign().deserialize(serialized_data) + assert new_action.name == "" + assert new_action.llm == LLM() + await new_action.run(context="write a cli snake game") + +@pytest.mark.asyncio +async def test_write_task_deserialize(): + action = WriteTasks() + serialized_data = action.dict() + new_action = WriteTasks(**serialized_data) + # new_action = WriteTasks().deserialize(serialized_data) + assert new_action.name == "CreateTasks" + assert new_action.llm == LLM() + await new_action.run(context="write a cli snake game") \ No newline at end of file From c08f6d83d792bc66eafea7d0d1dca61db41b1916 Mon Sep 17 00:00:00 2001 From: better629 Date: Tue, 28 Nov 2023 10:47:19 +0800 Subject: [PATCH 038/167] fix role and format ut of serialize_deserialize --- metagpt/roles/role.py | 5 ++--- tests/metagpt/serialize_deserialize/test_actions.py | 2 ++ .../serialize_deserialize/test_architect_deserialize.py | 2 ++ tests/metagpt/serialize_deserialize/test_product_manager.py | 1 + tests/metagpt/serialize_deserialize/test_project_manager.py | 2 ++ tests/metagpt/serialize_deserialize/test_role.py | 2 +- tests/metagpt/serialize_deserialize/test_wrire_prd.py | 4 ++-- tests/metagpt/serialize_deserialize/test_write_code.py | 6 +++++- tests/metagpt/serialize_deserialize/test_write_design.py | 6 +++++- 9 files changed, 22 insertions(+), 8 deletions(-) diff --git a/metagpt/roles/role.py b/metagpt/roles/role.py index 07a78e4bb..f1d7df5e7 100644 --- a/metagpt/roles/role.py +++ b/metagpt/roles/role.py @@ -20,8 +20,6 @@ """ from __future__ import annotations - - from enum import Enum from typing import Iterable, Set, Type from pathlib import Path @@ -30,12 +28,13 @@ from pydantic import BaseModel, Field from metagpt.actions.action import Action, ActionOutput, action_subclass_registry from metagpt.actions.action_node import ActionNode from metagpt.actions.add_requirement import UserRequirement -from metagpt.llm import LLM, HumanProvider +from metagpt.llm import LLM from metagpt.logs import logger from metagpt.schema import Message, MessageQueue from metagpt.utils.common import any_to_str from metagpt.utils.repair_llm_raw_output import extract_state_value_from_output from metagpt.memory import Memory +from metagpt.provider.human_provider import HumanProvider from metagpt.utils.utils import read_json_file, write_json_file, import_class from metagpt.provider.base_gpt_api import BaseGPTAPI from metagpt.const import SERDESER_PATH diff --git a/tests/metagpt/serialize_deserialize/test_actions.py b/tests/metagpt/serialize_deserialize/test_actions.py index e2efa982b..2fec2121a 100644 --- a/tests/metagpt/serialize_deserialize/test_actions.py +++ b/tests/metagpt/serialize_deserialize/test_actions.py @@ -7,12 +7,14 @@ import pytest from metagpt.actions import Action from metagpt.llm import LLM + def test_action_serialize(): action = Action() ser_action_dict = action.dict() assert "name" in ser_action_dict assert "llm" in ser_action_dict + @pytest.mark.asyncio async def test_action_deserialize(): action = Action() diff --git a/tests/metagpt/serialize_deserialize/test_architect_deserialize.py b/tests/metagpt/serialize_deserialize/test_architect_deserialize.py index cff1bbadd..d0ee3bc99 100644 --- a/tests/metagpt/serialize_deserialize/test_architect_deserialize.py +++ b/tests/metagpt/serialize_deserialize/test_architect_deserialize.py @@ -7,6 +7,7 @@ import pytest from metagpt.roles.architect import Architect from metagpt.actions.action import Action + def test_architect_serialize(): role = Architect() ser_role_dict = role.dict(by_alias=True) @@ -14,6 +15,7 @@ def test_architect_serialize(): assert "_states" in ser_role_dict assert "_actions" in ser_role_dict + @pytest.mark.asyncio async def test_architect_deserialize(): role = Architect() diff --git a/tests/metagpt/serialize_deserialize/test_product_manager.py b/tests/metagpt/serialize_deserialize/test_product_manager.py index 978c50e5e..2aed87a28 100644 --- a/tests/metagpt/serialize_deserialize/test_product_manager.py +++ b/tests/metagpt/serialize_deserialize/test_product_manager.py @@ -8,6 +8,7 @@ from metagpt.roles.product_manager import ProductManager from metagpt.actions.action import Action from metagpt.schema import Message + @pytest.mark.asyncio async def test_product_manager_deserialize(): role = ProductManager() diff --git a/tests/metagpt/serialize_deserialize/test_project_manager.py b/tests/metagpt/serialize_deserialize/test_project_manager.py index 590bd8109..fbc0dcc08 100644 --- a/tests/metagpt/serialize_deserialize/test_project_manager.py +++ b/tests/metagpt/serialize_deserialize/test_project_manager.py @@ -7,6 +7,7 @@ import pytest from metagpt.roles.project_manager import ProjectManager from metagpt.actions.action import Action + def test_project_manager_serialize(): role = ProjectManager() ser_role_dict = role.dict(by_alias=True) @@ -14,6 +15,7 @@ def test_project_manager_serialize(): assert "_states" in ser_role_dict assert "_actions" in ser_role_dict + @pytest.mark.asyncio async def test_project_manager_deserialize(): role = ProjectManager() diff --git a/tests/metagpt/serialize_deserialize/test_role.py b/tests/metagpt/serialize_deserialize/test_role.py index 432c9acb7..0e438d1a2 100644 --- a/tests/metagpt/serialize_deserialize/test_role.py +++ b/tests/metagpt/serialize_deserialize/test_role.py @@ -34,7 +34,7 @@ async def test_engineer_deserialize(): # also can be deserialized in this way: new_role = Engineer(**ser_role_dict) assert new_role.name == "Alex" - assert new_role.use_code_review == True + assert new_role.use_code_review is True assert len(new_role._actions) == 2 assert isinstance(new_role._actions[0], Action) assert isinstance(new_role._actions[1], Action) diff --git a/tests/metagpt/serialize_deserialize/test_wrire_prd.py b/tests/metagpt/serialize_deserialize/test_wrire_prd.py index 9b2653820..baa08ed76 100644 --- a/tests/metagpt/serialize_deserialize/test_wrire_prd.py +++ b/tests/metagpt/serialize_deserialize/test_wrire_prd.py @@ -24,5 +24,5 @@ async def test_action_deserialize(): # new_action = WritePRD().deserialize(serialized_data) assert new_action.name == "" assert new_action.llm == LLM() - assert len(await new_action.run([Message(content="write a cli snake game")]))>0 - + assert len(await new_action.run([Message(content="write a cli snake game")])) > 0 + diff --git a/tests/metagpt/serialize_deserialize/test_write_code.py b/tests/metagpt/serialize_deserialize/test_write_code.py index 0b1f1dc7c..9d659caaf 100644 --- a/tests/metagpt/serialize_deserialize/test_write_code.py +++ b/tests/metagpt/serialize_deserialize/test_write_code.py @@ -7,18 +7,21 @@ import pytest from metagpt.actions import WriteCode, WriteCodeReview from metagpt.llm import LLM + def test_write_design_serialize(): action = WriteCode() ser_action_dict = action.dict() assert ser_action_dict["name"] == "WriteCode" assert "llm" in ser_action_dict + def test_write_task_serialize(): action = WriteCodeReview() ser_action_dict = action.dict() assert ser_action_dict["name"] == "WriteCodeReview" assert "llm" in ser_action_dict - + + @pytest.mark.asyncio async def test_write_code_deserialize(): action = WriteCode() @@ -29,6 +32,7 @@ async def test_write_code_deserialize(): assert new_action.llm == LLM() await new_action.run(context="write a cli snake game", filename="test_code") + @pytest.mark.asyncio async def test_write_code_review_deserialize(): action = WriteCodeReview() diff --git a/tests/metagpt/serialize_deserialize/test_write_design.py b/tests/metagpt/serialize_deserialize/test_write_design.py index 56bf78a63..e6e236676 100644 --- a/tests/metagpt/serialize_deserialize/test_write_design.py +++ b/tests/metagpt/serialize_deserialize/test_write_design.py @@ -7,18 +7,21 @@ import pytest from metagpt.actions import WriteDesign, WriteTasks from metagpt.llm import LLM + def test_write_design_serialize(): action = WriteDesign() ser_action_dict = action.dict() assert "name" in ser_action_dict assert "llm" in ser_action_dict + def test_write_task_serialize(): action = WriteTasks() ser_action_dict = action.dict() assert "name" in ser_action_dict assert "llm" in ser_action_dict + @pytest.mark.asyncio async def test_write_design_deserialize(): action = WriteDesign() @@ -28,6 +31,7 @@ async def test_write_design_deserialize(): assert new_action.llm == LLM() await new_action.run(context="write a cli snake game") + @pytest.mark.asyncio async def test_write_task_deserialize(): action = WriteTasks() @@ -36,4 +40,4 @@ async def test_write_task_deserialize(): # new_action = WriteTasks().deserialize(serialized_data) assert new_action.name == "CreateTasks" assert new_action.llm == LLM() - await new_action.run(context="write a cli snake game") \ No newline at end of file + await new_action.run(context="write a cli snake game") From f7d5102fa62b06ad728f86b32e68023f7c4baa3c Mon Sep 17 00:00:00 2001 From: better629 Date: Thu, 30 Nov 2023 15:10:38 +0800 Subject: [PATCH 039/167] update unittest of ser&deser --- tests/metagpt/actions/test_action.py | 17 --- tests/metagpt/roles/test_role.py | 84 +----------- .../serialize_deserialize/test_action.py | 49 +++++++ .../serialize_deserialize/test_actions.py | 26 ---- .../test_architect_deserialize.py | 2 +- .../serialize_deserialize/test_environment.py | 91 +++++++++++++ .../test_memory.py | 34 ++++- .../test_product_manager.py | 4 +- .../test_project_manager.py | 6 +- .../serialize_deserialize/test_role.py | 63 ++++++++- .../serialize_deserialize/test_schema.py | 49 +++++++ .../test_serdeser_base.py | 88 +++++++++++++ .../serialize_deserialize/test_team.py | 124 +++++++++++++----- .../serialize_deserialize/test_wrire_prd.py | 1 - .../serialize_deserialize/test_write_code.py | 2 +- tests/metagpt/test_environment.py | 44 +++---- tests/metagpt/test_schema.py | 4 +- tests/metagpt/test_team.py | 22 +--- 18 files changed, 496 insertions(+), 214 deletions(-) create mode 100644 tests/metagpt/serialize_deserialize/test_action.py delete mode 100644 tests/metagpt/serialize_deserialize/test_actions.py create mode 100644 tests/metagpt/serialize_deserialize/test_environment.py rename tests/metagpt/{memory => serialize_deserialize}/test_memory.py (52%) create mode 100644 tests/metagpt/serialize_deserialize/test_schema.py create mode 100644 tests/metagpt/serialize_deserialize/test_serdeser_base.py diff --git a/tests/metagpt/actions/test_action.py b/tests/metagpt/actions/test_action.py index 4468a6f6f..9775630cc 100644 --- a/tests/metagpt/actions/test_action.py +++ b/tests/metagpt/actions/test_action.py @@ -11,20 +11,3 @@ from metagpt.actions import Action, WritePRD, WriteTest def test_action_repr(): actions = [Action(), WriteTest(), WritePRD()] assert "WriteTest" in str(actions) - - -def test_action_serdes(): - action_info = WriteTest.ser_class() - assert action_info["action_class"] == "WriteTest" - - action_class = Action.deser_class(action_info) - assert action_class == WriteTest - - -def test_action_class_serdes(): - name = "write test" - action_info = WriteTest(name=name).serialize() - assert action_info["name"] == name - - action = Action.deserialize(action_info) - assert action.name == name diff --git a/tests/metagpt/roles/test_role.py b/tests/metagpt/roles/test_role.py index a19ad9cb5..72cd84a9a 100644 --- a/tests/metagpt/roles/test_role.py +++ b/tests/metagpt/roles/test_role.py @@ -2,84 +2,10 @@ # -*- coding: utf-8 -*- # @Desc : unittest of Role -from pathlib import Path -import shutil -import pytest - -from metagpt.roles.role import Role, RoleReactMode -from metagpt.actions.action import Action -from metagpt.schema import Message -from metagpt.actions.add_requirement import BossRequirement -from metagpt.roles.product_manager import ProductManager - -serdes_path = Path(__file__).absolute().parent.joinpath("../../data/serdes_storage") +from metagpt.roles.role import Role -def test_role_serdes(): - stg_path_prefix = serdes_path.joinpath("team/environment/roles/") - shutil.rmtree(serdes_path.joinpath("team"), ignore_errors=True) - - pm = ProductManager() - role_tag = f"{pm.__class__.__name__}_{pm.name}" - stg_path = stg_path_prefix.joinpath(role_tag) - pm.serialize(stg_path) - assert stg_path.joinpath("actions/actions_info.json").exists() - - new_pm = Role.deserialize(stg_path) - assert new_pm.name == pm.name - assert len(new_pm.get_memories(1)) == 0 - - -class ActionOK(Action): - - async def run(self, messages: list["Message"]): - return "ok" - - -class ActionRaise(Action): - - async def run(self, messages: list["Message"]): - raise RuntimeError("parse error") - - -class RoleA(Role): - - def __init__(self, - name: str = "RoleA", - profile: str = "Role A", - goal: str = "", - constraints: str = ""): - super(RoleA, self).__init__(name=name, profile=profile, goal=goal, constraints=constraints) - self._init_actions([ActionOK, ActionRaise]) - self._watch([BossRequirement]) - self._rc.react_mode = RoleReactMode.BY_ORDER - - async def run(self, message: "Message" = None, stg_path: str = None): - try: - await super(RoleA, self).run(message) - except Exception as exp: - print("exp ", exp) - self.serialize(stg_path) - - -@pytest.mark.asyncio -async def test_role_serdes_interrupt(): - role_a = RoleA() - shutil.rmtree(serdes_path.joinpath("team"), ignore_errors=True) - - stg_path = serdes_path.joinpath(f"team/environment/roles/{role_a.__class__.__name__}_{role_a.name}") - await role_a.run( - message=Message(content="demo", cause_by=BossRequirement), - stg_path=stg_path - ) - assert role_a._rc.memory.count() == 2 - - assert stg_path.joinpath("actions/todo.json").exists() - - new_role_a: Role = Role.deserialize(stg_path) - assert new_role_a._rc.state == 1 - await role_a.run( - message=Message(content="demo", cause_by=BossRequirement), - stg_path=stg_path - ) - +def test_role_desc(): + role = Role(profile="Sales", desc="Best Seller") + assert role.profile == "Sales" + assert role._setting.desc == "Best Seller" diff --git a/tests/metagpt/serialize_deserialize/test_action.py b/tests/metagpt/serialize_deserialize/test_action.py new file mode 100644 index 000000000..b624dff5a --- /dev/null +++ b/tests/metagpt/serialize_deserialize/test_action.py @@ -0,0 +1,49 @@ +# -*- coding: utf-8 -*- +# @Date : 11/22/2023 11:48 AM +# @Author : stellahong (stellahong@fuzhi.ai) +# @Desc : +import pytest + +from metagpt.actions import Action, WritePRD, WriteTest +from metagpt.llm import LLM +from metagpt.provider.openai_api import OpenAIGPTAPI + + +def test_action_serialize(): + action = Action() + ser_action_dict = action.dict() + assert "name" in ser_action_dict + assert "llm" in ser_action_dict + + +@pytest.mark.asyncio +async def test_action_deserialize(): + action = Action() + serialized_data = action.dict() + assert isinstance(serialized_data["llm"], OpenAIGPTAPI) + + new_action = Action(**serialized_data) + + assert new_action.name == "" + assert new_action.llm == LLM() + assert len(await new_action._aask("who are you")) > 0 + + +def test_action_serdeser(): + action_info = WriteTest.ser_class() + assert action_info["action_class"] == "WriteTest" + + action_class = Action.deser_class(action_info) + assert action_class == WriteTest + + +def test_action_class_serdeser(): + name = "write test" + action_info = WriteTest(name=name).serialize() + assert action_info["name"] == name + + action_info = WriteTest(name=name, llm=LLM()).serialize() + assert action_info["name"] == name + + action = Action.deserialize(action_info) + assert action.name == name diff --git a/tests/metagpt/serialize_deserialize/test_actions.py b/tests/metagpt/serialize_deserialize/test_actions.py deleted file mode 100644 index 2fec2121a..000000000 --- a/tests/metagpt/serialize_deserialize/test_actions.py +++ /dev/null @@ -1,26 +0,0 @@ -# -*- coding: utf-8 -*- -# @Date : 11/22/2023 11:48 AM -# @Author : stellahong (stellahong@fuzhi.ai) -# @Desc : -import pytest - -from metagpt.actions import Action -from metagpt.llm import LLM - - -def test_action_serialize(): - action = Action() - ser_action_dict = action.dict() - assert "name" in ser_action_dict - assert "llm" in ser_action_dict - - -@pytest.mark.asyncio -async def test_action_deserialize(): - action = Action() - serialized_data = action.dict() - - new_action = Action(**serialized_data) - assert new_action.name == "" - assert new_action.llm == LLM() - assert len(await new_action._aask("who are you")) > 0 diff --git a/tests/metagpt/serialize_deserialize/test_architect_deserialize.py b/tests/metagpt/serialize_deserialize/test_architect_deserialize.py index d0ee3bc99..fb58f0a3a 100644 --- a/tests/metagpt/serialize_deserialize/test_architect_deserialize.py +++ b/tests/metagpt/serialize_deserialize/test_architect_deserialize.py @@ -25,4 +25,4 @@ async def test_architect_deserialize(): assert new_role.name == "Bob" assert len(new_role._actions) == 1 assert isinstance(new_role._actions[0], Action) - await new_role._actions[0].run(context="write a cli snake game") \ No newline at end of file + await new_role._actions[0].run(context="write a cli snake game") diff --git a/tests/metagpt/serialize_deserialize/test_environment.py b/tests/metagpt/serialize_deserialize/test_environment.py new file mode 100644 index 000000000..15336eb6a --- /dev/null +++ b/tests/metagpt/serialize_deserialize/test_environment.py @@ -0,0 +1,91 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : + +from pathlib import Path +import shutil + +from metagpt.schema import Message +from metagpt.actions.action_output import ActionOutput +from metagpt.roles.project_manager import ProjectManager +from metagpt.actions.add_requirement import BossRequirement +from metagpt.actions.project_management import WriteTasks +from metagpt.environment import Environment +from tests.metagpt.serialize_deserialize.test_serdeser_base import RoleC, ActionOK, serdeser_path + + +def test_env_serialize(): + env = Environment() + ser_env_dict = env.dict() + assert "roles" in ser_env_dict + assert "memory" in ser_env_dict + + +def test_env_deserialize(): + env = Environment() + env.publish_message(message=Message(content="test env serialize")) + ser_env_dict = env.dict() + new_env = Environment(**ser_env_dict) + assert len(new_env.roles) == 0 + assert new_env.memory.storage[0].content == "test env serialize" + assert len(new_env.history) == 25 + + +def test_environment_serdeser(): + out_mapping = {"field1": (list[str], ...)} + out_data = {"field1": ["field1 value1", "field1 value2"]} + ic_obj = ActionOutput.create_model_class("prd", out_mapping) + + message = Message( + content="prd", + instruct_content=ic_obj(**out_data), + role="product manager", + cause_by=BossRequirement + ) + + environment = Environment() + role_c = RoleC() + environment.add_role(role_c) + environment.publish_message(message) + + ser_data = environment.dict() + assert ser_data["roles"]["Role C"]["name"] == "RoleC" + + new_env: Environment = Environment(**ser_data) + assert len(new_env.roles) == 1 + + assert new_env.memory.count() == 1 + assert new_env.memory.storage[0].instruct_content == ic_obj(**out_data) + assert list(new_env.roles.values())[0]._states == list(environment.roles.values())[0]._states + assert list(new_env.roles.values())[0]._actions == list(environment.roles.values())[0]._actions + assert isinstance(list(environment.roles.values())[0]._actions[0], ActionOK) + assert type(list(new_env.roles.values())[0]._actions[0]) == ActionOK + + +def test_environment_serdeser_v2(): + environment = Environment() + pm = ProjectManager() + environment.add_role(pm) + + ser_data = environment.dict() + + new_env: Environment = Environment(**ser_data) + role = new_env.get_role(pm.profile) + assert isinstance(role, ProjectManager) + assert isinstance(role._actions[0], WriteTasks) + assert isinstance(list(new_env.roles.values())[0]._actions[0], WriteTasks) + + +def test_environment_serdeser_save(): + environment = Environment() + role_c = RoleC() + + shutil.rmtree(serdeser_path.joinpath("team"), ignore_errors=True) + + stg_path = serdeser_path.joinpath("team/environment") + environment.add_role(role_c) + environment.serialize(stg_path) + + new_env: Environment = Environment.deserialize(stg_path) + assert len(new_env.roles) == 1 + assert type(list(new_env.roles.values())[0]._actions[0]) == ActionOK diff --git a/tests/metagpt/memory/test_memory.py b/tests/metagpt/serialize_deserialize/test_memory.py similarity index 52% rename from tests/metagpt/memory/test_memory.py rename to tests/metagpt/serialize_deserialize/test_memory.py index bda79ded1..e24f31af3 100644 --- a/tests/metagpt/memory/test_memory.py +++ b/tests/metagpt/serialize_deserialize/test_memory.py @@ -3,6 +3,7 @@ # @Desc : unittest of memory from pathlib import Path +from pydantic import BaseModel from metagpt.schema import Message from metagpt.memory.memory import Memory @@ -10,10 +11,36 @@ from metagpt.actions.action_output import ActionOutput from metagpt.actions.design_api import WriteDesign from metagpt.actions.add_requirement import BossRequirement -serdes_path = Path(__file__).absolute().parent.joinpath("../../data/serdes_storage") +from tests.metagpt.serialize_deserialize.test_serdeser_base import serdeser_path -def test_memory_serdes(): +def test_memory_serdeser(): + msg1 = Message(role="Boss", + content="write a snake game", + cause_by=BossRequirement) + + out_mapping = {"field2": (list[str], ...)} + out_data = {"field2": ["field2 value1", "field2 value2"]} + ic_obj = ActionOutput.create_model_class("system_design", out_mapping) + msg2 = Message(role="Architect", + instruct_content=ic_obj(**out_data), + content="system design content", + cause_by=WriteDesign) + + memory = Memory() + memory.add_batch([msg1, msg2]) + ser_data = memory.dict() + + new_memory = Memory(**ser_data) + assert new_memory.count() == 2 + new_msg2 = new_memory.get(2)[0] + assert isinstance(new_msg2, BaseModel) + assert isinstance(new_memory.storage[-1], BaseModel) + assert new_memory.storage[-1].cause_by == WriteDesign + assert new_msg2.role == "Boss" + + +def test_memory_serdeser_save(): msg1 = Message(role="User", content="write a 2048 game", cause_by=BossRequirement) @@ -29,7 +56,7 @@ def test_memory_serdes(): memory = Memory() memory.add_batch([msg1, msg2]) - stg_path = serdes_path.joinpath("team/environment") + stg_path = serdeser_path.joinpath("team/environment") memory.serialize(stg_path) assert stg_path.joinpath("memory.json").exists() @@ -38,5 +65,6 @@ def test_memory_serdes(): new_msg2 = new_memory.get(1)[0] assert new_msg2.instruct_content.field1 == ["field1 value1", "field1 value2"] assert new_msg2.cause_by == WriteDesign + assert len(new_memory.index) == 2 stg_path.joinpath("memory.json").unlink() diff --git a/tests/metagpt/serialize_deserialize/test_product_manager.py b/tests/metagpt/serialize_deserialize/test_product_manager.py index 2aed87a28..54584cf96 100644 --- a/tests/metagpt/serialize_deserialize/test_product_manager.py +++ b/tests/metagpt/serialize_deserialize/test_product_manager.py @@ -15,8 +15,8 @@ async def test_product_manager_deserialize(): ser_role_dict = role.dict(by_alias=True) new_role = ProductManager(**ser_role_dict) # new_role = ProductManager().deserialize(ser_role_dict) - + assert new_role.name == "Alice" assert len(new_role._actions) == 1 assert isinstance(new_role._actions[0], Action) - await new_role._actions[0].run([Message(content="write a cli snake game")]) \ No newline at end of file + await new_role._actions[0].run([Message(content="write a cli snake game")]) diff --git a/tests/metagpt/serialize_deserialize/test_project_manager.py b/tests/metagpt/serialize_deserialize/test_project_manager.py index fbc0dcc08..21fafa72e 100644 --- a/tests/metagpt/serialize_deserialize/test_project_manager.py +++ b/tests/metagpt/serialize_deserialize/test_project_manager.py @@ -6,6 +6,7 @@ import pytest from metagpt.roles.project_manager import ProjectManager from metagpt.actions.action import Action +from metagpt.actions.project_management import WriteTasks def test_project_manager_serialize(): @@ -20,9 +21,10 @@ def test_project_manager_serialize(): async def test_project_manager_deserialize(): role = ProjectManager() ser_role_dict = role.dict(by_alias=True) + new_role = ProjectManager(**ser_role_dict) - # new_role = ProjectManager().deserialize(ser_role_dict) assert new_role.name == "Eve" assert len(new_role._actions) == 1 assert isinstance(new_role._actions[0], Action) - await new_role._actions[0].run(context="write a cli snake game") \ No newline at end of file + assert isinstance(new_role._actions[0], WriteTasks) + # await new_role._actions[0].run(context="write a cli snake game") diff --git a/tests/metagpt/serialize_deserialize/test_role.py b/tests/metagpt/serialize_deserialize/test_role.py index 0e438d1a2..f260dea3a 100644 --- a/tests/metagpt/serialize_deserialize/test_role.py +++ b/tests/metagpt/serialize_deserialize/test_role.py @@ -2,12 +2,22 @@ # @Date : 11/23/2023 4:49 PM # @Author : stellahong (stellahong@fuzhi.ai) # @Desc : + +from pathlib import Path +import shutil import pytest +from metagpt.logs import logger from metagpt.roles.role import Role +from metagpt.actions import WriteCode, WriteCodeReview +from metagpt.schema import Message +from metagpt.actions.add_requirement import BossRequirement +from metagpt.roles.product_manager import ProductManager +from metagpt.const import SERDESER_PATH from metagpt.roles.engineer import Engineer +from metagpt.utils.utils import format_trackback_info -from metagpt.actions.action import Action +from tests.metagpt.serialize_deserialize.test_serdeser_base import RoleC, serdeser_path def test_role_serialize(): @@ -30,12 +40,53 @@ def test_engineer_serialize(): async def test_engineer_deserialize(): role = Engineer(use_code_review=True) ser_role_dict = role.dict(by_alias=True) - # new_role = Engineer().deserialize(ser_role_dict) - # also can be deserialized in this way: + new_role = Engineer(**ser_role_dict) assert new_role.name == "Alex" assert new_role.use_code_review is True assert len(new_role._actions) == 2 - assert isinstance(new_role._actions[0], Action) - assert isinstance(new_role._actions[1], Action) - await new_role._actions[0].run(context="write a cli snake game", filename="test_code") + assert isinstance(new_role._actions[0], WriteCode) + assert isinstance(new_role._actions[1], WriteCodeReview) + # await new_role._actions[0].run(context="write a cli snake game", filename="test_code") + + +def test_role_serdeser_save(): + stg_path_prefix = serdeser_path.joinpath("team/environment/roles/") + shutil.rmtree(serdeser_path.joinpath("team"), ignore_errors=True) + + pm = ProductManager() + role_tag = f"{pm.__class__.__name__}_{pm.name}" + stg_path = stg_path_prefix.joinpath(role_tag) + pm.serialize(stg_path) + assert stg_path.joinpath("actions/actions_info.json").exists() + + new_pm = Role.deserialize(stg_path) + assert new_pm.name == pm.name + assert len(new_pm.get_memories(1)) == 0 + + +@pytest.mark.asyncio +async def test_role_serdeser_interrupt(): + role_c = RoleC() + shutil.rmtree(SERDESER_PATH.joinpath("team"), ignore_errors=True) + + stg_path = SERDESER_PATH.joinpath(f"team/environment/roles/{role_c.__class__.__name__}_{role_c.name}") + try: + await role_c.run( + message=Message(content="demo", cause_by=BossRequirement) + ) + except Exception as exp: + logger.error(f"Exception in `role_a.run`, detail: {format_trackback_info()}") + role_c.serialize(stg_path) + + assert role_c._rc.memory.count() == 2 + + assert stg_path.joinpath("actions/todo.json").exists() + + new_role_a: Role = Role.deserialize(stg_path) + assert new_role_a._rc.state == 1 + + with pytest.raises(Exception): + await role_c.run( + message=Message(content="demo", cause_by=BossRequirement) + ) diff --git a/tests/metagpt/serialize_deserialize/test_schema.py b/tests/metagpt/serialize_deserialize/test_schema.py new file mode 100644 index 000000000..74b134cad --- /dev/null +++ b/tests/metagpt/serialize_deserialize/test_schema.py @@ -0,0 +1,49 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : unittest of schema ser&deser + +from metagpt.schema import Message +from metagpt.actions.action_output import ActionOutput +from metagpt.actions.write_code import WriteCode + +from tests.metagpt.serialize_deserialize.test_serdeser_base import MockMessage + + +def test_message_serdeser(): + out_mapping = {"field3": (str, ...), "field4": (list[str], ...)} + out_data = {"field3": "field3 value3", "field4": ["field4 value1", "field4 value2"]} + ic_obj = ActionOutput.create_model_class("code", out_mapping) + + message = Message( + content="code", + instruct_content=ic_obj(**out_data), + role="engineer", + cause_by=WriteCode + ) + ser_data = message.dict() + assert ser_data["cause_by"] == { + "action_class": "WriteCode", + "module_name": "metagpt.actions.write_code" + } + assert ser_data["instruct_content"]["class"] == "code" + + new_message = Message(**ser_data) + assert new_message.cause_by == WriteCode + assert new_message.cause_by in [WriteCode] + assert new_message.instruct_content == ic_obj(**out_data) + + +def test_message_without_postprocess(): + """ to explain `instruct_content` should be postprocessed """ + out_mapping = {"field1": (list[str], ...)} + out_data = {"field1": ["field1 value1", "field1 value2"]} + ic_obj = ActionOutput.create_model_class("code", out_mapping) + message = MockMessage( + content="code", + instruct_content=ic_obj(**out_data) + ) + ser_data = message.dict() + assert ser_data["instruct_content"] == {"field1": ["field1 value1", "field1 value2"]} + + new_message = MockMessage(**ser_data) + assert new_message.instruct_content != ic_obj(**out_data) diff --git a/tests/metagpt/serialize_deserialize/test_serdeser_base.py b/tests/metagpt/serialize_deserialize/test_serdeser_base.py new file mode 100644 index 000000000..35bad6cd9 --- /dev/null +++ b/tests/metagpt/serialize_deserialize/test_serdeser_base.py @@ -0,0 +1,88 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : base test actions / roles used in unittest + +from pydantic import BaseModel, Field +from pathlib import Path + +from metagpt.actions.action import Action +from metagpt.roles.role import Role, RoleReactMode +from metagpt.actions.add_requirement import BossRequirement + + +serdeser_path = Path(__file__).absolute().parent.joinpath("../../data/serdeser_storage") + + +class MockMessage(BaseModel): + """ to test normal dict without postprocess """ + content: str = "" + instruct_content: BaseModel = Field(default=None) + + +class ActionPass(Action): + name: str = "ActionPass" + + async def run(self, messages: list["Message"]): + return "pass" + + +class ActionOK(Action): + name: str = "ActionOK" + + async def run(self, messages: list["Message"]): + return "ok" + + +class ActionRaise(Action): + name: str = "ActionRaise" + + async def run(self, messages: list["Message"]): + raise RuntimeError("parse error in ActionRaise") + + +class RoleA(Role): + + name: str = Field(default="RoleA") + profile: str = Field(default="Role A") + goal: str = "RoleA's goal" + constraints: str = "RoleA's constraints" + + def __init__(self, **kwargs): + super(RoleA, self).__init__(**kwargs) + self._init_actions([ActionPass]) + self._watch([BossRequirement]) + + async def run(self, message: "Message" = None): + await super(RoleA, self).run(message) + + +class RoleB(Role): + name: str = Field(default="RoleB") + profile: str = Field(default="Role B") + goal: str = "RoleB's goal" + constraints: str = "RoleB's constraints" + + def __init__(self, **kwargs): + super(RoleB, self).__init__(**kwargs) + self._init_actions([ActionOK, ActionRaise]) + self._watch([ActionPass]) + self._rc.react_mode = RoleReactMode.BY_ORDER + + async def run(self, message: "Message" = None): + await super(RoleB, self).run(message) + + +class RoleC(Role): + name: str = Field(default="RoleC") + profile: str = Field(default="Role C") + goal: str = "RoleC's goal" + constraints: str = "RoleC's constraints" + + def __init__(self, **kwargs): + super(RoleC, self).__init__(**kwargs) + self._init_actions([ActionOK, ActionRaise]) + self._watch([BossRequirement]) + self._rc.react_mode = RoleReactMode.BY_ORDER + + async def run(self, message: "Message" = None): + await super(RoleC, self).run(message) diff --git a/tests/metagpt/serialize_deserialize/test_team.py b/tests/metagpt/serialize_deserialize/test_team.py index 44a75d262..e9122ebc0 100644 --- a/tests/metagpt/serialize_deserialize/test_team.py +++ b/tests/metagpt/serialize_deserialize/test_team.py @@ -2,46 +2,104 @@ # @Date : 11/27/2023 10:07 AM # @Author : stellahong (stellahong@fuzhi.ai) # @Desc : + +from pathlib import Path +import shutil import pytest -from metagpt.environment import Environment -from metagpt.schema import Message -from metagpt.software_company import SoftwareCompany from metagpt.roles import ProjectManager, ProductManager, Architect +from metagpt.team import Team +from metagpt.const import SERDESER_PATH + +from tests.metagpt.serialize_deserialize.test_serdeser_base import RoleA, RoleB, RoleC, serdeser_path -def test_env_serialize(): - env = Environment() - ser_env_dict = env.dict() - assert "roles" in ser_env_dict - assert "memory" in ser_env_dict - assert "memory" in ser_env_dict +def test_team_deserialize(): + company = Team() - -def test_env_deserialize(): - env = Environment() - env.publish_message(message=Message(content="test env serialize")) - ser_env_dict = env.dict() - new_env = Environment(**ser_env_dict) - assert len(new_env.roles) == 0 - assert new_env.memory.storage[0].content == "test env serialize" - assert len(new_env.history) == 25 - - -def test_softwarecompany_deserialize(): - team = SoftwareCompany() - team.hire( + pm = ProductManager() + arch = Architect() + company.hire( [ - ProductManager(), - Architect(), + pm, + arch, ProjectManager(), ] ) - assert len(team.environment.get_roles()) == 3 - ser_team_dict = team.dict() - new_team = SoftwareCompany(**ser_team_dict) - - assert len(new_team.environment.get_roles()) == 3 - assert new_team.environment.get_role('Product Manager') is not None - assert new_team.environment.get_role('Product Manager') is not None - assert new_team.environment.get_role('Architect') is not None + assert len(company.environment.get_roles()) == 3 + ser_company = company.dict() + new_company = Team(**ser_company) + + assert len(new_company.environment.get_roles()) == 3 + assert new_company.environment.get_role(pm.profile) is not None + + new_pm = new_company.environment.get_role(pm.profile) + assert type(new_pm) == ProductManager + assert new_company.environment.get_role(pm.profile) is not None + assert new_company.environment.get_role(arch.profile) is not None + + +def test_team_serdeser(): + company = Team() + company.hire([RoleC()]) + + stg_path = serdeser_path.joinpath("team") + shutil.rmtree(stg_path, ignore_errors=True) + + company.serialize(stg_path=stg_path) + + new_company = Team.deserialize(stg_path) + + assert len(new_company.environment.roles) == 1 + + +@pytest.mark.asyncio +async def test_team_recover(): + idea = "write a snake game" + stg_path = SERDESER_PATH.joinpath("team") + shutil.rmtree(stg_path, ignore_errors=True) + + company = Team() + company.hire([RoleC()]) + company.start_project(idea) + await company.run(n_round=4) + + ser_data = company.dict() + new_company = Team(**ser_data) + assert new_company.environment.memory.count() == 1 + assert type(list(new_company.environment.roles.values())[0]._actions[0]) == ActionOK + + new_company.start_project(idea) + await new_company.run(n_round=4) + + +@pytest.mark.asyncio +async def test_team_recover_save(): + idea = "write a 2048 web game" + stg_path = SERDESER_PATH.joinpath("team") + shutil.rmtree(stg_path, ignore_errors=True) + + company = Team() + company.hire([RoleC()]) + company.start_project(idea) + await company.run(n_round=4) + + new_company = Team.recover(stg_path) + new_company.start_project(idea) + await new_company.run(n_round=4) + + +@pytest.mark.asyncio +async def test_team_recover_multi_roles_save(): + idea = "write a snake game" + stg_path = SERDESER_PATH.joinpath("team") + shutil.rmtree(stg_path, ignore_errors=True) + + company = Team() + company.hire([RoleA(), RoleB()]) + company.start_project(idea) + await company.run(n_round=4) + + new_company = Team.recover(stg_path) + new_company.start_project(idea) + await new_company.run(n_round=4) diff --git a/tests/metagpt/serialize_deserialize/test_wrire_prd.py b/tests/metagpt/serialize_deserialize/test_wrire_prd.py index baa08ed76..96b4d19ad 100644 --- a/tests/metagpt/serialize_deserialize/test_wrire_prd.py +++ b/tests/metagpt/serialize_deserialize/test_wrire_prd.py @@ -25,4 +25,3 @@ async def test_action_deserialize(): assert new_action.name == "" assert new_action.llm == LLM() assert len(await new_action.run([Message(content="write a cli snake game")])) > 0 - diff --git a/tests/metagpt/serialize_deserialize/test_write_code.py b/tests/metagpt/serialize_deserialize/test_write_code.py index 9d659caaf..7f4799014 100644 --- a/tests/metagpt/serialize_deserialize/test_write_code.py +++ b/tests/metagpt/serialize_deserialize/test_write_code.py @@ -43,4 +43,4 @@ async def test_write_code_review_deserialize(): assert new_action.name == "WriteCodeReview" assert new_action.llm == LLM() - await new_action.run(context="write a cli snake game", code =code, filename="test_rewrite_code") \ No newline at end of file + await new_action.run(context="write a cli snake game", code=code, filename="test_rewrite_code") diff --git a/tests/metagpt/test_environment.py b/tests/metagpt/test_environment.py index 03236a08b..8aacdd77b 100644 --- a/tests/metagpt/test_environment.py +++ b/tests/metagpt/test_environment.py @@ -8,7 +8,6 @@ import pytest from pathlib import Path -import shutil from metagpt.actions import UserRequirement from metagpt.environment import Environment @@ -16,10 +15,9 @@ from metagpt.logs import logger from metagpt.manager import Manager from metagpt.roles import Architect, ProductManager, Role from metagpt.schema import Message -from tests.metagpt.roles.test_role import RoleA -serdes_path = Path(__file__).absolute().parent.joinpath("../data/serdes_storage") +serdeser_path = Path(__file__).absolute().parent.joinpath("../data/serdeser_storage") @pytest.fixture @@ -28,14 +26,23 @@ def env(): def test_add_role(env: Environment): - role = ProductManager("Alice", "product manager", "create a new product", "limited resources") + role = ProductManager(name="Alice", + profile="product manager", + goal="create a new product", + constraints="limited resources") env.add_role(role) assert env.get_role(role.profile) == role def test_get_roles(env: Environment): - role1 = Role("Alice", "product manager", "create a new product", "limited resources") - role2 = Role("Bob", "engineer", "develop the new product", "short deadline") + role1 = Role(name="Alice", + profile="product manager", + goal="create a new product", + constraints="limited resources") + role2 = Role(name="Bob", + profile="engineer", + goal="develop the new product", + constraints="short deadline") env.add_role(role1) env.add_role(role2) roles = env.get_roles() @@ -44,8 +51,14 @@ def test_get_roles(env: Environment): @pytest.mark.asyncio async def test_publish_and_process_message(env: Environment): - product_manager = ProductManager("Alice", "Product Manager", "做AI Native产品", "资源有限") - architect = Architect("Bob", "Architect", "设计一个可用、高效、较低成本的系统,包括数据结构与接口", "资源有限,需要节省成本") + product_manager = ProductManager(name="Alice", + profile="Product Manager", + goal="做AI Native产品", + constraints="资源有限") + architect = Architect(name="Bob", + profile="Architect", + goal="设计一个可用、高效、较低成本的系统,包括数据结构与接口", + constraints="资源有限,需要节省成本") env.add_roles([product_manager, architect]) env.set_manager(Manager()) @@ -54,18 +67,3 @@ async def test_publish_and_process_message(env: Environment): await env.run(k=2) logger.info(f"{env.history=}") assert len(env.history) > 10 - - -def test_environment_serdes(): - environment = Environment() - role_a = RoleA() - - shutil.rmtree(serdes_path.joinpath("team"), ignore_errors=True) - - stg_path = serdes_path.joinpath("team/environment") - environment.add_role(role_a) - environment.serialize(stg_path) - - new_env: Environment = Environment() - new_env.deserialize(stg_path) - assert len(new_env.roles) == 1 diff --git a/tests/metagpt/test_schema.py b/tests/metagpt/test_schema.py index 4a6f518b1..5eea789ea 100644 --- a/tests/metagpt/test_schema.py +++ b/tests/metagpt/test_schema.py @@ -74,7 +74,7 @@ def test_routes(): assert m.send_to == {"e", get_class_name(Action)} -def test_message_serdes(): +def test_message_serdeser(): out_mapping = {"field3": (str, ...), "field4": (list[str], ...)} out_data = {"field3": "field3 value3", "field4": ["field4 value1", "field4 value2"]} ic_obj = ActionOutput.create_model_class("code", out_mapping) @@ -86,7 +86,7 @@ def test_message_serdes(): cause_by=WriteCode ) message_dict = serialize_general_message(message) - assert message_dict["cause_by"] == {"action_class": "WriteCode"} + assert message_dict["cause_by"] == {"action_class": "WriteCode", "module_name": "metagpt.actions.write_code"} assert message_dict["instruct_content"] == { "class": "code", "mapping": { diff --git a/tests/metagpt/test_team.py b/tests/metagpt/test_team.py index ab201152c..efd035bb2 100644 --- a/tests/metagpt/test_team.py +++ b/tests/metagpt/test_team.py @@ -2,26 +2,12 @@ # -*- coding: utf-8 -*- # @Desc : unittest of team -from pathlib import Path -import shutil - from metagpt.team import Team - -from tests.metagpt.roles.test_role import RoleA - -serdes_path = Path(__file__).absolute().parent.joinpath("../data/serdes_storage") +from metagpt.roles.project_manager import ProjectManager -def test_team_serdes(): +def test_team(): company = Team() - company.hire([RoleA()]) + company.hire([ProjectManager()]) - stg_path = serdes_path.joinpath("team") - shutil.rmtree(stg_path, ignore_errors=True) - - company.serialize(stg_path=stg_path) - - new_company = Team() - new_company.deserialize(stg_path) - - assert len(new_company.environment.roles) == 1 + assert len(company.environment.roles) == 1 From 2abe99cf45ec07bf69c44ec4c374704a798fd4c6 Mon Sep 17 00:00:00 2001 From: better629 Date: Thu, 30 Nov 2023 15:18:24 +0800 Subject: [PATCH 040/167] update environment/message to BaseModel, update the ser&deser of roles/actions --- metagpt/actions/action.py | 28 ++++- metagpt/actions/design_api.py | 3 +- metagpt/actions/project_management.py | 1 + metagpt/actions/search_and_summarize.py | 7 +- metagpt/actions/write_code.py | 9 +- metagpt/actions/write_code_review.py | 3 +- metagpt/actions/write_prd.py | 3 +- metagpt/actions/write_test.py | 11 +- metagpt/environment.py | 20 +++- metagpt/memory/longterm_memory.py | 14 ++- metagpt/memory/memory.py | 64 +++++++---- metagpt/roles/customer_service.py | 16 ++- metagpt/roles/product_manager.py | 1 + metagpt/roles/project_manager.py | 2 +- metagpt/roles/qa_engineer.py | 24 +++-- metagpt/roles/role.py | 52 ++++++--- metagpt/roles/sales.py | 33 +++--- metagpt/roles/searcher.py | 23 ++-- metagpt/schema.py | 134 ++++++++++-------------- metagpt/team.py | 38 ++++--- metagpt/utils/serialize.py | 26 +++-- metagpt/utils/utils.py | 40 +++++++ startup.py | 17 +-- 23 files changed, 361 insertions(+), 208 deletions(-) diff --git a/metagpt/actions/action.py b/metagpt/actions/action.py index e890ef76a..499b5e794 100644 --- a/metagpt/actions/action.py +++ b/metagpt/actions/action.py @@ -6,12 +6,17 @@ @File : action.py """ +from __future__ import annotations +import re +from typing import Optional, Any + from typing import Optional, Any from tenacity import retry, stop_after_attempt, wait_random_exponential from pydantic import BaseModel, Field from metagpt.actions.action_output import ActionOutput from metagpt.llm import LLM +from metagpt.provider.base_gpt_api import BaseGPTAPI from metagpt.logs import logger from metagpt.provider.postprecess.llm_output_postprecess import llm_output_postprecess from metagpt.utils.common import OutputParser @@ -24,18 +29,31 @@ action_subclass_registry = {} class Action(BaseModel): name: str = "" - llm: LLM = Field(default_factory=LLM) - context = None + llm: BaseGPTAPI = Field(default_factory=LLM, exclude=True) + context = "" prefix = "" # aask*时会加上prefix,作为system_message profile = "" # FIXME: USELESS desc = "" # for skill manager - nodes = None # content: Optional[str] = None # instruct_content: Optional[str] = None + + # builtin variables + builtin_class_name: str = "" + + class Config: + arbitrary_types_allowed = True def __init__(self, **kwargs: Any): super().__init__(**kwargs) + # deserialize child classes dynamically for inherited `action` + object.__setattr__(self, "builtin_class_name", self.__class__.__name__) + self.__fields__["builtin_class_name"].default = self.__class__.__name__ + + def __init_subclass__(cls, **kwargs: Any) -> None: + super().__init_subclass__(**kwargs) + action_subclass_registry[cls.__name__] = cls + def set_prefix(self, prefix, profile): """Set prefix for later usage""" self.prefix = prefix @@ -56,14 +74,14 @@ class Action(BaseModel): } @classmethod - def deserialize(cls, action_dict: dict): + def deserialize(cls, action_dict: dict) -> "Action": action_class_str = action_dict.pop("action_class") module_name = action_dict.pop("module_name") action_class = import_class(action_class_str, module_name) return action_class(**action_dict) @classmethod - def ser_class(cls): + def ser_class(cls) -> dict: """ serialize class type""" return { "action_class": cls.__name__, diff --git a/metagpt/actions/design_api.py b/metagpt/actions/design_api.py index a10ff1c9a..504328582 100644 --- a/metagpt/actions/design_api.py +++ b/metagpt/actions/design_api.py @@ -17,6 +17,7 @@ from pydantic import Field from metagpt.actions import Action, ActionOutput from metagpt.actions.design_api_an import DESIGN_API_NODE from metagpt.llm import LLM +from metagpt.provider.base_gpt_api import BaseGPTAPI from metagpt.config import CONFIG from metagpt.const import ( DATA_API_DESIGN_FILE_REPO, @@ -43,7 +44,7 @@ NEW_REQ_TEMPLATE = """ class WriteDesign(Action): name: str = "" context: Optional[str] = None - llm: LLM = Field(default_factory=LLM) + llm: BaseGPTAPI = Field(default_factory=LLM) desc: str = "Based on the PRD, think about the system design, and design the corresponding APIs, " "data structures, library tables, processes, and paths. Please provide your design, feedback " "clearly and in detail." diff --git a/metagpt/actions/project_management.py b/metagpt/actions/project_management.py index d830a4c15..98a948b64 100644 --- a/metagpt/actions/project_management.py +++ b/metagpt/actions/project_management.py @@ -18,6 +18,7 @@ from metagpt.actions import ActionOutput from metagpt.actions.action import Action from metagpt.actions.project_management_an import PM_NODE from metagpt.llm import LLM +from metagpt.provider.base_gpt_api import BaseGPTAPI from metagpt.config import CONFIG from metagpt.const import ( PACKAGE_REQUIREMENTS_FILENAME, diff --git a/metagpt/actions/search_and_summarize.py b/metagpt/actions/search_and_summarize.py index 7b549518e..7bff1c113 100644 --- a/metagpt/actions/search_and_summarize.py +++ b/metagpt/actions/search_and_summarize.py @@ -11,7 +11,8 @@ from pydantic import BaseModel, Field from metagpt.actions import Action from metagpt.llm import LLM -from metagpt.config import Config +from metagpt.provider.base_gpt_api import BaseGPTAPI +from metagpt.config import Config, CONFIG from metagpt.logs import logger from metagpt.schema import Message from metagpt.tools.search_engine import SearchEngine @@ -106,9 +107,9 @@ You are a member of a professional butler team and will provide helpful suggesti class SearchAndSummarize(Action): name: str = "" content: Optional[str] = None - llm: None = Field(default_factory=LLM) + llm: BaseGPTAPI = Field(default_factory=LLM) config: None = Field(default_factory=Config) - engine: Optional[str] = None + engine: Optional[str] = CONFIG.search_engine search_func: Optional[str] = None search_engine: SearchEngine = None diff --git a/metagpt/actions/write_code.py b/metagpt/actions/write_code.py index 2d155e6bf..bad9a0890 100644 --- a/metagpt/actions/write_code.py +++ b/metagpt/actions/write_code.py @@ -16,14 +16,9 @@ """ import json - from tenacity import retry, stop_after_attempt, wait_random_exponential - - - from typing import List, Optional, Any from pydantic import Field -from tenacity import retry, stop_after_attempt, wait_fixed from metagpt.actions.action import Action from metagpt.config import CONFIG @@ -34,8 +29,8 @@ from metagpt.const import ( TASK_FILE_REPO, TEST_OUTPUTS_FILE_REPO, ) -from metagpt.actions import WriteDesign from metagpt.llm import LLM +from metagpt.provider.base_gpt_api import BaseGPTAPI from metagpt.logs import logger from metagpt.schema import CodingContext, Document, RunCodeResult from metagpt.utils.common import CodeParser @@ -95,7 +90,7 @@ ATTENTION: Use '##' to SPLIT SECTIONS, not '#'. Output format carefully referenc class WriteCode(Action): name: str = "WriteCode" context: Optional[str] = None - llm: LLM = Field(default_factory=LLM) + llm: BaseGPTAPI = Field(default_factory=LLM) @retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6)) async def write_code(self, prompt) -> str: diff --git a/metagpt/actions/write_code_review.py b/metagpt/actions/write_code_review.py index bf07d0a93..83225060a 100644 --- a/metagpt/actions/write_code_review.py +++ b/metagpt/actions/write_code_review.py @@ -18,6 +18,7 @@ from metagpt.actions.action import Action from metagpt.config import CONFIG from metagpt.logs import logger from metagpt.schema import CodingContext +from metagpt.provider.base_gpt_api import BaseGPTAPI from metagpt.utils.common import CodeParser PROMPT_TEMPLATE = """ @@ -124,7 +125,7 @@ REWRITE_CODE_TEMPLATE = """ class WriteCodeReview(Action): name: str = "WriteCodeReview" context: Optional[str] = None - llm: LLM = Field(default_factory=LLM) + llm: BaseGPTAPI = Field(default_factory=LLM) @retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6)) async def write_code_review_and_rewrite(self, context_prompt, cr_prompt, filename): diff --git a/metagpt/actions/write_prd.py b/metagpt/actions/write_prd.py index 7f9089763..8510733ac 100644 --- a/metagpt/actions/write_prd.py +++ b/metagpt/actions/write_prd.py @@ -27,6 +27,7 @@ from metagpt.actions.write_prd_an import ( WRITE_PRD_NODE, ) from metagpt.llm import LLM +from metagpt.provider.base_gpt_api import BaseGPTAPI from metagpt.actions.search_and_summarize import SearchAndSummarize from metagpt.config import CONFIG from metagpt.const import ( @@ -67,7 +68,7 @@ NEW_REQ_TEMPLATE = """ class WritePRD(Action): name: str = "" content: Optional[str] = None - llm: LLM = Field(default_factory=LLM) + llm: BaseGPTAPI = Field(default_factory=LLM) async def run(self, with_messages, format=CONFIG.prompt_format, *args, **kwargs) -> ActionOutput | Message: # Determine which requirement documents need to be rewritten: Use LLM to assess whether new requirements are diff --git a/metagpt/actions/write_test.py b/metagpt/actions/write_test.py index 9dd967788..fa3931ba6 100644 --- a/metagpt/actions/write_test.py +++ b/metagpt/actions/write_test.py @@ -7,6 +7,12 @@ @Modified By: mashenquan, 2023-11-27. Following the think-act principle, solidify the task parameters when creating the WriteTest object, rather than passing them in when calling the run function. """ + +from typing import Optional +from pydantic import Field + +from metagpt.llm import LLM +from metagpt.provider.base_gpt_api import BaseGPTAPI from metagpt.actions.action import Action from metagpt.config import CONFIG from metagpt.const import TEST_CODES_FILE_REPO @@ -36,8 +42,9 @@ you should correctly import the necessary classes based on these file locations! class WriteTest(Action): - def __init__(self, name="WriteTest", context=None, llm=None): - super().__init__(name, context, llm) + name: str = "WriteTest" + context: Optional[str] = None + llm: BaseGPTAPI = Field(default_factory=LLM) async def write_code(self, prompt): code_rsp = await self._aask(prompt) diff --git a/metagpt/environment.py b/metagpt/environment.py index 19197bd10..242581e17 100644 --- a/metagpt/environment.py +++ b/metagpt/environment.py @@ -19,6 +19,8 @@ from pydantic import BaseModel, Field from metagpt.logs import logger from metagpt.roles import Role +from metagpt.memory import Memory +from metagpt.roles.role import Role, role_subclass_registry from metagpt.schema import Message from metagpt.utils.common import is_subscribed from metagpt.utils.utils import read_json_file, write_json_file @@ -37,6 +39,19 @@ class Environment(BaseModel): class Config: arbitrary_types_allowed = True + def __init__(self, **kwargs): + for role_key, role in kwargs.get("roles", {}).items(): + current_role = kwargs["roles"][role_key] + if isinstance(current_role, dict): + item_class_name = current_role.get("builtin_class_name", None) + for name, subclass in role_subclass_registry.items(): + registery_class_name = subclass.__fields__["builtin_class_name"].default + if item_class_name == registery_class_name: + current_role = subclass(**current_role) + break + kwargs["roles"][role_key] = current_role + super().__init__(**kwargs) + def serialize(self, stg_path: Path): roles_path = stg_path.joinpath("roles.json") roles_info = [] @@ -53,7 +68,8 @@ class Environment(BaseModel): history_path = stg_path.joinpath("history.json") write_json_file(history_path, {"content": self.history}) - def deserialize(self, stg_path: Path): + @classmethod + def deserialize(cls, stg_path: Path) -> "Environment": """ stg_path: ./storage/team/environment/ """ """ stg_path: ./storage/team/environment/ """ roles_path = stg_path.joinpath("roles.json") @@ -80,7 +96,7 @@ class Environment(BaseModel): """ role.set_env(self) # use alias - self.roles[role.role_profile] = role + self.roles[role.profile] = role def add_roles(self, roles: Iterable[Role]): """增加一批在当前环境的角色 diff --git a/metagpt/memory/longterm_memory.py b/metagpt/memory/longterm_memory.py index 22032a86e..e8a5be395 100644 --- a/metagpt/memory/longterm_memory.py +++ b/metagpt/memory/longterm_memory.py @@ -4,6 +4,9 @@ @Desc : the implement of Long-term memory """ +from typing import Optional +from pydantic import Field + from metagpt.logs import logger from metagpt.memory import Memory from metagpt.memory.memory_storage import MemoryStorage @@ -17,11 +20,12 @@ class LongTermMemory(Memory): - update memory when it changed """ - def __init__(self): - self.memory_storage: MemoryStorage = MemoryStorage() - super(LongTermMemory, self).__init__() - self.rc = None # RoleContext - self.msg_from_recover = False + memory_storage: MemoryStorage = Field(default_factory=MemoryStorage) + rc: Optional["RoleContext"] = None + msg_from_recover: bool = False + + class Config: + arbitrary_types_allowed = True def recover_memory(self, role_id: str, rc: "RoleContext"): messages = self.memory_storage.recover_memory(role_id) diff --git a/metagpt/memory/memory.py b/metagpt/memory/memory.py index 43bd33e59..adef0d283 100644 --- a/metagpt/memory/memory.py +++ b/metagpt/memory/memory.py @@ -6,34 +6,51 @@ @File : memory.py @Modified By: mashenquan, 2023-11-1. According to RFC 116: Updated the type of index key. """ +import copy from collections import defaultdict -from typing import Iterable, Set +from typing import Iterable, Type, Union, Optional, Set from pathlib import Path +from pydantic import BaseModel, Field +import json from metagpt.schema import Message from metagpt.utils.common import any_to_str, any_to_str_set from metagpt.utils.utils import read_json_file, write_json_file -from metagpt.utils.serialize import serialize_general_message, deserialize_general_message +from metagpt.utils.utils import import_class -class Memory: +class Memory(BaseModel): """The most basic memory: super-memory""" - def __init__(self): - """Initialize an empty storage list and an empty index dictionary""" - self.storage: list[Message] = [] - self.index: dict[str, list[Message]] = defaultdict(list) + storage: list[Message] = Field(default=[]) + index: dict[str, list[Message]] = Field(default_factory=defaultdict(list)) + + def __init__(self, **kwargs): + index = kwargs.get("index", {}) + new_index = defaultdict(list) + for action_str, value in index.items(): + action_dict = json.loads(action_str) + action_class = import_class("Action", "metagpt.actions.action") + action_obj = action_class.deser_class(action_dict) + new_index[action_obj] = [Message(**item_dict) for item_dict in value] + kwargs["index"] = new_index + super(Memory, self).__init__(**kwargs) + self.index = new_index + + def dict(self, *args, **kwargs) -> "DictStrAny": + """ overwrite the `dict` to dump dynamic pydantic model""" + obj_dict = super(Memory, self).dict(*args, **kwargs) + new_obj_dict = copy.deepcopy(obj_dict) + new_obj_dict["index"] = {} + for action, value in obj_dict["index"].items(): + action_ser = json.dumps(action.ser_class()) + new_obj_dict["index"][action_ser] = value + return new_obj_dict def serialize(self, stg_path: Path): """ stg_path = ./storage/team/environment/ or ./storage/team/environment/roles/{role_class}_{role_name}/ """ memory_path = stg_path.joinpath("memory.json") - - storage = [] - for message in self.storage: - # msg_dict = message.serialize() - msg_dict = serialize_general_message(message) - storage.append(msg_dict) - + storage = self.dict() write_json_file(memory_path, storage) @classmethod @@ -41,13 +58,8 @@ class Memory: """ stg_path = ./storage/team/environment/ or ./storage/team/environment/roles/{role_class}_{role_name}/""" memory_path = stg_path.joinpath("memory.json") - memory = Memory() - memory_list = read_json_file(memory_path) - for message in memory_list: - # distinguish instruct_content type in message - # msg = Message.deserialize(message) - msg = deserialize_general_message(message) - memory.add(msg) + memory_dict = read_json_file(memory_path) + memory = Memory(**memory_dict) return memory @@ -71,6 +83,16 @@ class Memory: """Return all messages containing a specified content""" return [message for message in self.storage if content in message.content] + def delete_newest(self) -> "Message": + """ delete the newest message from the storage""" + if len(self.storage) > 0: + newest_msg = self.storage.pop() + if newest_msg.cause_by and newest_msg in self.index[newest_msg.cause_by]: + self.index[newest_msg.cause_by].remove(newest_msg) + else: + newest_msg = None + return newest_msg + def delete(self, message: Message): """Delete the specified message from storage, while updating the index""" self.storage.remove(message) diff --git a/metagpt/roles/customer_service.py b/metagpt/roles/customer_service.py index 188182d47..62792696f 100644 --- a/metagpt/roles/customer_service.py +++ b/metagpt/roles/customer_service.py @@ -5,6 +5,9 @@ @Author : alexanderwu @File : sales.py """ +from typing import Optional +from pydantic import Field + from metagpt.roles import Sales # from metagpt.actions import SearchAndSummarize @@ -24,5 +27,14 @@ DESC = """ class CustomerService(Sales): - def __init__(self, name="Xiaomei", profile="Human customer service", desc=DESC, store=None): - super().__init__(name, profile, desc=desc, store=store) + + name: str = Field(default="Xiaomei") + profile: str = Field(default="Human customer service") + desc: str = DESC, + + store: Optional[str] = None + + def __init__( + self, + **kwargs): + super().__init__(**kwargs) diff --git a/metagpt/roles/product_manager.py b/metagpt/roles/product_manager.py index a49459fca..30017b60d 100644 --- a/metagpt/roles/product_manager.py +++ b/metagpt/roles/product_manager.py @@ -29,6 +29,7 @@ class ProductManager(Role): role_profile: str = Field(default="Product Manager", alias='profile') goal: str = "efficiently create a successful product" constraints: str = "use same language as user requiremen" + """ Represents a Product Manager role responsible for product development and management. """ diff --git a/metagpt/roles/project_manager.py b/metagpt/roles/project_manager.py index 211e41d3b..b7ee1ed53 100644 --- a/metagpt/roles/project_manager.py +++ b/metagpt/roles/project_manager.py @@ -22,7 +22,7 @@ class ProjectManager(Role): goal (str): Goal of the project manager. constraints (str): Constraints or limitations for the project manager. """ - name: str = "Eve" + name: str = Field(default="Eve") profile: str = Field(default="Project Manager") goal: str = "reak down tasks according to PRD/technical design, generate a task list, and analyze task " \ diff --git a/metagpt/roles/qa_engineer.py b/metagpt/roles/qa_engineer.py index 4439b9b19..ec404570c 100644 --- a/metagpt/roles/qa_engineer.py +++ b/metagpt/roles/qa_engineer.py @@ -14,7 +14,9 @@ @Modified By: mashenquan, 2023-12-5. Enhance the workflow to navigate to WriteCode or QaEngineer based on the results of SummarizeCode. """ -from metagpt.actions import DebugError, RunCode, WriteTest + +from pydantic import Field + from metagpt.actions.summarize_code import SummarizeCode from metagpt.config import CONFIG from metagpt.const import ( @@ -22,6 +24,11 @@ from metagpt.const import ( TEST_CODES_FILE_REPO, TEST_OUTPUTS_FILE_REPO, ) +from metagpt.actions import ( + DebugError, + RunCode, + WriteTest, +) from metagpt.logs import logger from metagpt.roles import Role from metagpt.schema import Document, Message, RunCodeContext, TestingContext @@ -30,21 +37,22 @@ from metagpt.utils.file_repository import FileRepository class QaEngineer(Role): + name: str = Field(default="Edward") + profile: str = Field(default="QaEngineer") + goal: str = "Write comprehensive and robust tests to ensure codes will work as expected without bugs" + constraints: str = "The test code you write should conform to code standard like PEP8, be modular, easy to read and maintain" + test_round_allowed: int = 5 + def __init__( self, - name="Edward", - profile="QaEngineer", - goal="Write comprehensive and robust tests to ensure codes will work as expected without bugs", - constraints="The test code you write should conform to code standard like PEP8, be modular, easy to read and maintain", - test_round_allowed=5, + **kwargs ): - super().__init__(name, profile, goal, constraints) + super().__init__(**kwargs) self._init_actions( [WriteTest] ) # FIXME: a bit hack here, only init one action to circumvent _think() logic, will overwrite _think() in future updates self._watch([SummarizeCode, WriteTest, RunCode, DebugError]) self.test_round = 0 - self.test_round_allowed = test_round_allowed async def _write_test(self, message: Message) -> None: src_file_repo = CONFIG.git_repo.new_file_repository(CONFIG.src_workspace) diff --git a/metagpt/roles/role.py b/metagpt/roles/role.py index f1d7df5e7..114e9e599 100644 --- a/metagpt/roles/role.py +++ b/metagpt/roles/role.py @@ -28,15 +28,32 @@ from pydantic import BaseModel, Field from metagpt.actions.action import Action, ActionOutput, action_subclass_registry from metagpt.actions.action_node import ActionNode from metagpt.actions.add_requirement import UserRequirement + +from pathlib import Path + +from typing import ( + Iterable, + Type, + Any +) +from pydantic import BaseModel, Field, validator + +# from metagpt.environment import Environment +from metagpt.config import CONFIG +from metagpt.actions.action import Action, ActionOutput, action_subclass_registry from metagpt.llm import LLM +from metagpt.provider.base_gpt_api import BaseGPTAPI from metagpt.logs import logger from metagpt.schema import Message, MessageQueue from metagpt.utils.common import any_to_str from metagpt.utils.repair_llm_raw_output import extract_state_value_from_output from metagpt.memory import Memory from metagpt.provider.human_provider import HumanProvider + from metagpt.utils.utils import read_json_file, write_json_file, import_class from metagpt.provider.base_gpt_api import BaseGPTAPI + +from metagpt.utils.utils import read_json_file, write_json_file, import_class, role_raise_decorator from metagpt.const import SERDESER_PATH @@ -80,13 +97,12 @@ class RoleReactMode(str, Enum): class RoleSetting(BaseModel): """Role Settings""" - - name: str - profile: str - goal: str - constraints: str - desc: str - is_human: bool + name: str = "" + profile: str = "" + goal: str = "" + constraints: str = "" + desc: str = "" + is_human: bool = False def __str__(self): return f"{self.name}({self.profile})" @@ -174,8 +190,8 @@ class Role(BaseModel): class Config: arbitrary_types_allowed = True exclude = ["_llm"] - - def __init__(self, **kwargs): + + def __init__(self, **kwargs: Any): for index in range(len(kwargs.get("_actions", []))): current_action = kwargs["_actions"][index] if isinstance(current_action, dict): @@ -212,15 +228,19 @@ class Role(BaseModel): object.__setattr__(self, "builtin_class_name", self.__class__.__name__) self.__fields__["builtin_class_name"].default = self.__class__.__name__ + def __init_subclass__(cls, **kwargs: Any) -> None: + super().__init_subclass__(**kwargs) + role_subclass_registry[cls.__name__] = cls + def _reset(self): - object.__setattr__(self, '_states', []) - object.__setattr__(self, '_actions', []) + object.__setattr__(self, "_states", []) + object.__setattr__(self, "_actions", []) @property def _setting(self): return f"{self.name}({self.profile})" - def serialize(self, stg_path: Path): + def serialize(self, stg_path: Path = None): stg_path = SERDESER_PATH.joinpath(f"team/environment/roles/{self.__class__.__name__}_{self.name}") \ if stg_path is None else stg_path @@ -256,7 +276,7 @@ class Role(BaseModel): action.set_prefix(self._get_prefix(), self.profile) def set_recovered(self, recovered: bool = False): - self._recovered = recovered + self.recovered = recovered def set_memory(self, memory: Memory): self._rc.memory = memory @@ -269,7 +289,7 @@ class Role(BaseModel): for idx, action in enumerate(actions): if not isinstance(action, Action): ## 默认初始化 - i = action() + i = action(name="", llm=self._llm) else: if self._setting.is_human and not isinstance(action.llm, HumanProvider): logger.warning( @@ -358,6 +378,10 @@ class Role(BaseModel): def subscription(self) -> Set: """The labels for messages to be consumed by the Role object.""" return self._subscription + + def set_env(self, env: "Environment"): + """Set the environment in which the role works. The role can talk to the environment and can also receive messages by observing.""" + self._rc.env = env def _get_prefix(self): """Get the role prefix""" diff --git a/metagpt/roles/sales.py b/metagpt/roles/sales.py index d5aac1824..826413dc8 100644 --- a/metagpt/roles/sales.py +++ b/metagpt/roles/sales.py @@ -5,26 +5,31 @@ @Author : alexanderwu @File : sales.py """ + +from typing import Optional +from pydantic import Field + from metagpt.actions import SearchAndSummarize from metagpt.roles import Role from metagpt.tools import SearchEngineType class Sales(Role): - def __init__( - self, - name="Xiaomei", - profile="Retail sales guide", - desc="I am a sales guide in retail. My name is Xiaomei. I will answer some customer questions next, and I " - "will answer questions only based on the information in the knowledge base." - "If I feel that you can't get the answer from the reference material, then I will directly reply that" - " I don't know, and I won't tell you that this is from the knowledge base," - "but pretend to be what I know. Note that each of my replies will be replied in the tone of a " - "professional guide", - store=None, - ): - super().__init__(name, profile, desc=desc) - self._set_store(store) + + name: str = Field(default="Xiaomei") + profile: str = Field(default="Retail sales guide") + desc: str = "I am a sales guide in retail. My name is Xiaomei. I will answer some customer questions next, and I " + "will answer questions only based on the information in the knowledge base." + "If I feel that you can't get the answer from the reference material, then I will directly reply that" + " I don't know, and I won't tell you that this is from the knowledge base," + "but pretend to be what I know. Note that each of my replies will be replied in the tone of a " + "professional guide", + + store: Optional[str] = None + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self._set_store(self.store) def _set_store(self, store): if store: diff --git a/metagpt/roles/searcher.py b/metagpt/roles/searcher.py index 5760202ff..7d58ad922 100644 --- a/metagpt/roles/searcher.py +++ b/metagpt/roles/searcher.py @@ -7,6 +7,9 @@ @Modified By: mashenquan, 2023-11-1. According to Chapter 2.2.1 and 2.2.2 of RFC 116, change the data type of the `cause_by` value in the `Message` to a string to support the new message distribution feature. """ + +from pydantic import Field + from metagpt.actions import ActionOutput, SearchAndSummarize from metagpt.actions.action_node import ActionNode from metagpt.logs import logger @@ -27,15 +30,13 @@ class Searcher(Role): engine (SearchEngineType): The type of search engine to use. """ - def __init__( - self, - name: str = "Alice", - profile: str = "Smart Assistant", - goal: str = "Provide search services for users", - constraints: str = "Answer is rich and complete", - engine=SearchEngineType.SERPAPI_GOOGLE, - **kwargs, - ) -> None: + name: str = Field(default="Alice") + profile: str = Field(default="Smart Assistant") + goal: str = "Provide search services for users" + constraints: str = "Answer is rich and complete" + engine: SearchEngineType = SearchEngineType.SERPAPI_GOOGLE + + def __init__(self, **kwargs) -> None: """ Initializes the Searcher role with given attributes. @@ -46,8 +47,8 @@ class Searcher(Role): constraints (str): Constraints or limitations for the searcher. engine (SearchEngineType): The type of search engine to use. """ - super().__init__(name, profile, goal, constraints, **kwargs) - self._init_actions([SearchAndSummarize(engine=engine)]) + super().__init__(**kwargs) + self._init_actions([SearchAndSummarize(engine=self.engine)]) def set_search_func(self, search_func): """Sets a custom search function for the searcher.""" diff --git a/metagpt/schema.py b/metagpt/schema.py index 78e4a6031..a872481bb 100644 --- a/metagpt/schema.py +++ b/metagpt/schema.py @@ -12,7 +12,6 @@ between actions. 3. Add `id` to `Message` according to Section 2.2.3.1.1 of RFC 135. """ -from __future__ import annotations import asyncio import json @@ -24,6 +23,12 @@ from pathlib import Path from typing import Dict, List, Optional, Set, TypedDict from pydantic import BaseModel, Field +from dataclasses import dataclass, field +from typing import Type, TypedDict, Union, Optional + +from pydantic import BaseModel, Field +from pydantic.main import ModelMetaclass + from metagpt.config import CONFIG from metagpt.const import ( MESSAGE_ROUTE_CAUSE_BY, @@ -34,11 +39,16 @@ from metagpt.const import ( TASK_FILE_REPO, ) from metagpt.logs import logger + from metagpt.utils.common import any_to_str, any_to_str_set # from metagpt.utils.serialize import actionoutout_schema_to_mapping # from metagpt.actions.action_output import ActionOutput # from metagpt.actions.action import Action +from metagpt.utils.serialize import actionoutout_schema_to_mapping, actionoutput_mapping_to_str, \ + actionoutput_str_to_mapping +from metagpt.utils.utils import import_class + class RawMessage(TypedDict): content: str @@ -54,7 +64,7 @@ class Document(BaseModel): filename: str = "" content: str = "" - def get_meta(self) -> Document: + def get_meta(self) -> "Document"": """Get metadata of the document. :return: A new Document instance with the same root path and filename. @@ -104,39 +114,21 @@ class Message(BaseModel): sent_from: str = "" send_to: Set = Field(default_factory={MESSAGE_ROUTE_TO_ALL}) - def __init__( - self, - content, - instruct_content=None, - role="user", - cause_by="", - sent_from="", - send_to=MESSAGE_ROUTE_TO_ALL, - **kwargs, - ): - """ - Parameters not listed below will be stored as meta info, including custom parameters. - :param content: Message content. - :param instruct_content: Message content struct. - :param cause_by: Message producer - :param sent_from: Message route info tells who sent this message. - :param send_to: Specifies the target recipient or consumer for message delivery in the environment. - :param role: Message meta info tells who sent this message. - """ - if not cause_by: - from metagpt.actions import UserRequirement - cause_by = UserRequirement + def __init__(self, **kwargs): + instruct_content = kwargs.get("instruct_content", None) + cause_by = kwargs.get("cause_by", None) + if instruct_content and not isinstance(instruct_content, BaseModel): + ic = instruct_content + mapping = actionoutput_str_to_mapping(ic["mapping"]) - super().__init__( - id=uuid.uuid4().hex, - content=content, - instruct_content=instruct_content, - role=role, - cause_by=any_to_str(cause_by), - sent_from=any_to_str(sent_from), - send_to=any_to_str_set(send_to), - **kwargs, - ) + actionoutput_class = import_class("ActionOutput", "metagpt.actions.action_output") + ic_obj = actionoutput_class.create_model_class(class_name=ic["class"], mapping=mapping) + ic_new = ic_obj(**ic["value"]) + kwargs["instruct_content"] = ic_new + if cause_by and not isinstance(cause_by, ModelMetaclass): + action_class = import_class("Action", "metagpt.actions.action") + kwargs["cause_by"] = action_class.deser_class(cause_by) + super(Message, self).__init__(**kwargs) def __setattr__(self, key, val): """Override `@property.setter`, convert non-string parameters into string parameters.""" @@ -150,6 +142,21 @@ class Message(BaseModel): new_val = val super().__setattr__(key, new_val) + def dict(self, *args, **kwargs) -> "DictStrAny": + """ overwrite the `dict` to dump dynamic pydantic model""" + obj_dict = super(Message, self).dict(*args, **kwargs) + ic = self.instruct_content # deal custom-defined action + if ic: + schema = ic.schema() + mapping = actionoutout_schema_to_mapping(schema) + mapping = actionoutput_mapping_to_str(mapping) + + obj_dict["instruct_content"] = {"class": schema["title"], "mapping": mapping, "value": ic.dict()} + cb = self.cause_by + if cb: + obj_dict["cause_by"] = cb.ser_class() + return obj_dict + def __str__(self): # prefix = '-'.join([self.role, str(self.cause_by)]) return f"{self.role}: {self.content}" @@ -157,45 +164,16 @@ class Message(BaseModel): def __repr__(self): return self.__str__() - # def serialize(self): - # message_cp: Message = copy.deepcopy(self) - # ic = message_cp.instruct_content - # if ic: - # # model create by pydantic create_model like `pydantic.main.prd`, can't pickle.dump directly - # schema = ic.schema() - # mapping = actionoutout_schema_to_mapping(schema) - # - # message_cp.instruct_content = {"class": schema["title"], "mapping": mapping, "value": ic.dict()} - # cb = message_cp.cause_by - # if cb: - # message_cp.cause_by = cb.serialize() - # - # return message_cp.dict() - # - # @classmethod - # def deserialize(cls, message_dict: dict): - # instruct_content = message_dict.get("instruct_content") - # if instruct_content: - # ic = instruct_content - # ic_obj = ActionOutput.create_model_class(class_name=ic["class"], mapping=ic["mapping"]) - # ic_new = ic_obj(**ic["value"]) - # message_dict.instruct_content = ic_new - # cause_by = message_dict.get("cause_by") - # if cause_by: - # message_dict.cause_by = Action.deserialize(cause_by) - # - # return Message(**message_dict) - - def dict(self): - return { - "content": self.content, - "instruct_content": self.instruct_content, - "role": self.role, - "cause_by": self.cause_by, - "sent_from": self.sent_from, - "send_to": self.send_to, - "restricted_to": self.restricted_to - } + # def dict(self): + # return { + # "content": self.content, + # "instruct_content": self.instruct_content, + # "role": self.role, + # "cause_by": self.cause_by, + # "sent_from": self.sent_from, + # "send_to": self.send_to, + # "restricted_to": self.restricted_to + # } def to_dict(self) -> dict: """Return a dict containing `role` and `content` for the LLM call.l""" @@ -316,7 +294,7 @@ class CodingContext(BaseModel): code_doc: Optional[Document] @staticmethod - def loads(val: str) -> CodingContext | None: + def loads(val: str) -> "CodingContext" | None: try: m = json.loads(val) return CodingContext(**m) @@ -330,7 +308,7 @@ class TestingContext(BaseModel): test_doc: Optional[Document] @staticmethod - def loads(val: str) -> TestingContext | None: + def loads(val: str) -> "TestingContext" | None: try: m = json.loads(val) return TestingContext(**m) @@ -351,7 +329,7 @@ class RunCodeContext(BaseModel): output: Optional[str] @staticmethod - def loads(val: str) -> RunCodeContext | None: + def loads(val: str) -> "RunCodeContext" | None: try: m = json.loads(val) return RunCodeContext(**m) @@ -365,7 +343,7 @@ class RunCodeResult(BaseModel): stderr: str @staticmethod - def loads(val: str) -> RunCodeResult | None: + def loads(val: str) -> "RunCodeResult" | None: try: m = json.loads(val) return RunCodeResult(**m) @@ -380,7 +358,7 @@ class CodeSummarizeContext(BaseModel): reason: str = "" @staticmethod - def loads(filenames: List) -> CodeSummarizeContext: + def loads(filenames: List) -> "CodeSummarizeContext": ctx = CodeSummarizeContext() for filename in filenames: if Path(filename).is_relative_to(SYSTEM_DESIGN_FILE_REPO): diff --git a/metagpt/team.py b/metagpt/team.py index 02c48a138..87a6766f6 100644 --- a/metagpt/team.py +++ b/metagpt/team.py @@ -18,7 +18,8 @@ from metagpt.logs import logger from metagpt.roles import Role from metagpt.schema import Message from metagpt.utils.common import NoMoneyException -from metagpt.utils.utils import read_json_file, write_json_file +from metagpt.utils.utils import read_json_file, write_json_file, serialize_decorator +from metagpt.const import SERDESER_PATH class Team(BaseModel): @@ -34,29 +35,35 @@ class Team(BaseModel): class Config: arbitrary_types_allowed = True - def serialize(self, stg_path: Path): + def serialize(self, stg_path: Path = None): + stg_path = SERDESER_PATH.joinpath("team") if stg_path is None else stg_path + team_info_path = stg_path.joinpath("team_info.json") - write_json_file(team_info_path, { - "idea": self.idea, - "investment": self.investment - }) + write_json_file(team_info_path, self.dict(exclude={"environment": True})) - self.environment.serialize(stg_path.joinpath("environment")) + self.environment.serialize(stg_path.joinpath("environment")) # save environment alone - def deserialize(self, stg_path: Path): + @classmethod + def recover(cls, stg_path: Path) -> "Team": + return cls.deserialize(stg_path) + + @classmethod + def deserialize(cls, stg_path: Path) -> "Team": """ stg_path = ./storage/team """ # recover team_info team_info_path = stg_path.joinpath("team_info.json") if not team_info_path.exists(): - logger.error("recover storage not exist, not to recover and continue run the old project.") - team_info = read_json_file(team_info_path) - self.investment = team_info.get("investment", 10.0) - self.idea = team_info.get("idea", "") + raise FileNotFoundError("recover storage meta file `team_info.json` not exist, " + "not to recover and please start a new project.") + + team_info: dict = read_json_file(team_info_path) # recover environment - environment_path = stg_path.joinpath("environment") - self.environment = Environment() - self.environment.deserialize(stg_path=environment_path) + environment = Environment.deserialize(stg_path=stg_path.joinpath("environment")) + team_info.update({"environment": environment}) + + team = Team(**team_info) + return team def hire(self, roles: list[Role]): """Hire roles to cooperate""" @@ -84,6 +91,7 @@ class Team(BaseModel): def _save(self): logger.info(self.json(ensure_ascii=False)) + @serialize_decorator async def run(self, n_round=3): """Run company until target round or no money""" while n_round > 0: diff --git a/metagpt/utils/serialize.py b/metagpt/utils/serialize.py index 56a866f2e..9a7049214 100644 --- a/metagpt/utils/serialize.py +++ b/metagpt/utils/serialize.py @@ -5,9 +5,7 @@ import copy import pickle -from metagpt.actions.action_output import ActionOutput -from metagpt.schema import Message -from metagpt.actions.action import Action +from metagpt.utils.utils import import_class def actionoutout_schema_to_mapping(schema: dict) -> dict: @@ -59,7 +57,7 @@ def actionoutput_str_to_mapping(mapping: dict) -> dict: return new_mapping -def serialize_general_message(message: Message) -> dict: +def serialize_general_message(message: "Message") -> dict: """ serialize Message, not to save""" message_cp = copy.deepcopy(message) ic = message_cp.instruct_content @@ -76,7 +74,7 @@ def serialize_general_message(message: Message) -> dict: return message_cp.dict() -def serialize_message(message: Message): +def serialize_message(message: "Message"): message_cp = copy.deepcopy(message) # avoid `instruct_content` value update by reference ic = message_cp.instruct_content if ic: @@ -90,29 +88,35 @@ def serialize_message(message: Message): return msg_ser -def deserialize_general_message(message_dict: dict) -> Message: +def deserialize_general_message(message_dict: dict) -> "Message": """ deserialize Message, not to load""" instruct_content = message_dict.pop("instruct_content") cause_by = message_dict.pop("cause_by") - message = Message(**message_dict) + message_cls = import_class("Message", "metagpt.schema") + message = message_cls(**message_dict) if instruct_content: ic = instruct_content mapping = actionoutput_str_to_mapping(ic["mapping"]) - ic_obj = ActionOutput.create_model_class(class_name=ic["class"], mapping=mapping) + + actionoutput_class = import_class("ActionOutput", "metagpt.actions.action_output") + ic_obj = actionoutput_class.create_model_class(class_name=ic["class"], mapping=mapping) ic_new = ic_obj(**ic["value"]) message.instruct_content = ic_new if cause_by: - message.cause_by = Action.deser_class(cause_by) + action_class = import_class("Action", "metagpt.actions.action") + message.cause_by = action_class.deser_class(cause_by) return message -def deserialize_message(message_ser: str) -> Message: +def deserialize_message(message_ser: str) -> "Message": message = pickle.loads(message_ser) if message.instruct_content: ic = message.instruct_content - ic_obj = ActionOutput.create_model_class(class_name=ic["class"], mapping=ic["mapping"]) + + actionoutput_class = import_class("ActionOutput", "metagpt.actions.action_output") + ic_obj = actionoutput_class.create_model_class(class_name=ic["class"], mapping=ic["mapping"]) ic_new = ic_obj(**ic["value"]) message.instruct_content = ic_new diff --git a/metagpt/utils/utils.py b/metagpt/utils/utils.py index 220e228c3..ad5c7626a 100644 --- a/metagpt/utils/utils.py +++ b/metagpt/utils/utils.py @@ -56,3 +56,43 @@ def import_class_inst(class_name: str, module_name: str, *args, **kwargs) -> obj a_class = import_class(class_name, module_name) class_inst = a_class(*args, **kwargs) return class_inst + + +def format_trackback_info(limit: int = 2): + return traceback.format_exc(limit=limit) + + +def serialize_decorator(func): + async def wrapper(self, *args, **kwargs): + try: + return await func(self, *args, **kwargs) + except KeyboardInterrupt as kbi: + logger.error(f"KeyboardInterrupt occurs, start to serialize the project, exp:\n{format_trackback_info()}") + self.serialize() # Team.serialize + except Exception as exp: + logger.error(f"Exception occurs, start to serialize the project, exp:\n{format_trackback_info()}") + self.serialize() # Team.serialize + + return wrapper + + +def role_raise_decorator(func): + async def wrapper(self, *args, **kwargs): + try: + return await func(self, *args, **kwargs) + except KeyboardInterrupt as kbi: + logger.error(f"KeyboardInterrupt: {kbi} occurs, start to serialize the project") + if self._rc.env: + newest_msgs = self._rc.env.memory.get(1) + if len(newest_msgs) > 0: + self._rc.memory.delete(newest_msgs[0]) + except Exception as exp: + if self._rc.env: + newest_msgs = self._rc.env.memory.get(1) + if len(newest_msgs) > 0: + logger.warning("There is a exception in role's execution, in order to resume, " + "we delete the newest role communication message in the role's memory.") + self._rc.memory.delete(newest_msgs[0]) # remove newest msg of the role to make it observed again + raise Exception(format_trackback_info(limit=None)) # raise again to make it captured outside + + return wrapper diff --git a/startup.py b/startup.py index 9f753d553..c4928a1b5 100644 --- a/startup.py +++ b/startup.py @@ -1,10 +1,11 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- + +from typing import Optional import asyncio - import fire +from pathlib import Path -from metagpt.const import SERDES_PATH from metagpt.roles import ( Architect, Engineer, @@ -22,11 +23,11 @@ async def startup( code_review: bool = False, run_tests: bool = False, implement: bool = True, - recover_path: bool = False, + recover_path: Optional[str] = None, ): """Run a startup. Be a boss.""" - company = Team() if not recover_path: + company = Team() company.hire( [ ProductManager(), @@ -45,8 +46,12 @@ async def startup( # (bug fixing capability comes soon!) company.hire([QaEngineer()]) else: - stg_path = SERDES_PATH.joinpath("team") - company.deserialize(stg_path=stg_path) + # # stg_path = SERDESER_PATH.joinpath("team") + stg_path = Path(recover_path) + if not stg_path.exists() or not str(stg_path).endswith("team"): + raise FileNotFoundError(f"{recover_path} not exists or not endswith `team`") + + company = Team.recover(stg_path=stg_path) idea = company.idea # use original idea company.invest(investment) From a01766ae72d9d2ac7a113f51afbfd6e2d30e85e1 Mon Sep 17 00:00:00 2001 From: better629 Date: Thu, 30 Nov 2023 19:30:02 +0800 Subject: [PATCH 041/167] fix ut of serialize_deserialize --- .../serialize_deserialize/test_action.py | 3 +-- .../test_product_manager.py | 1 - .../serialize_deserialize/test_role.py | 10 ++++++++- .../test_serdeser_base.py | 21 +++++++++++++------ .../serialize_deserialize/test_team.py | 2 +- .../serialize_deserialize/test_wrire_prd.py | 4 ++-- .../serialize_deserialize/test_write_code.py | 2 -- .../test_write_design.py | 3 +-- 8 files changed, 29 insertions(+), 17 deletions(-) diff --git a/tests/metagpt/serialize_deserialize/test_action.py b/tests/metagpt/serialize_deserialize/test_action.py index b624dff5a..0138d41ce 100644 --- a/tests/metagpt/serialize_deserialize/test_action.py +++ b/tests/metagpt/serialize_deserialize/test_action.py @@ -13,14 +13,13 @@ def test_action_serialize(): action = Action() ser_action_dict = action.dict() assert "name" in ser_action_dict - assert "llm" in ser_action_dict + assert "llm" not in ser_action_dict @pytest.mark.asyncio async def test_action_deserialize(): action = Action() serialized_data = action.dict() - assert isinstance(serialized_data["llm"], OpenAIGPTAPI) new_action = Action(**serialized_data) diff --git a/tests/metagpt/serialize_deserialize/test_product_manager.py b/tests/metagpt/serialize_deserialize/test_product_manager.py index 54584cf96..25bc07a11 100644 --- a/tests/metagpt/serialize_deserialize/test_product_manager.py +++ b/tests/metagpt/serialize_deserialize/test_product_manager.py @@ -14,7 +14,6 @@ async def test_product_manager_deserialize(): role = ProductManager() ser_role_dict = role.dict(by_alias=True) new_role = ProductManager(**ser_role_dict) - # new_role = ProductManager().deserialize(ser_role_dict) assert new_role.name == "Alice" assert len(new_role._actions) == 1 diff --git a/tests/metagpt/serialize_deserialize/test_role.py b/tests/metagpt/serialize_deserialize/test_role.py index f260dea3a..c21b9cc2e 100644 --- a/tests/metagpt/serialize_deserialize/test_role.py +++ b/tests/metagpt/serialize_deserialize/test_role.py @@ -17,7 +17,15 @@ from metagpt.const import SERDESER_PATH from metagpt.roles.engineer import Engineer from metagpt.utils.utils import format_trackback_info -from tests.metagpt.serialize_deserialize.test_serdeser_base import RoleC, serdeser_path +from tests.metagpt.serialize_deserialize.test_serdeser_base import RoleA, RoleB, RoleC, serdeser_path + + +def test_roles(): + role_a = RoleA() + assert len(role_a._rc.watch) == 1 + role_b = RoleB() + assert len(role_a._rc.watch) == 1 + assert len(role_b._rc.watch) == 1 def test_role_serialize(): diff --git a/tests/metagpt/serialize_deserialize/test_serdeser_base.py b/tests/metagpt/serialize_deserialize/test_serdeser_base.py index 35bad6cd9..00d894b3d 100644 --- a/tests/metagpt/serialize_deserialize/test_serdeser_base.py +++ b/tests/metagpt/serialize_deserialize/test_serdeser_base.py @@ -8,6 +8,7 @@ from pathlib import Path from metagpt.actions.action import Action from metagpt.roles.role import Role, RoleReactMode from metagpt.actions.add_requirement import BossRequirement +from metagpt.actions.action_output import ActionOutput serdeser_path = Path(__file__).absolute().parent.joinpath("../../data/serdeser_storage") @@ -22,21 +23,27 @@ class MockMessage(BaseModel): class ActionPass(Action): name: str = "ActionPass" - async def run(self, messages: list["Message"]): - return "pass" + async def run(self, messages: list["Message"]) -> ActionOutput: + output_mapping = { + "result": (str, ...) + } + pass_class = ActionOutput.create_model_class("pass", output_mapping) + pass_output = ActionOutput("ActionPass run passed", pass_class(**{"result": "pass result"})) + + return pass_output class ActionOK(Action): name: str = "ActionOK" - async def run(self, messages: list["Message"]): + async def run(self, messages: list["Message"]) -> str: return "ok" class ActionRaise(Action): name: str = "ActionRaise" - async def run(self, messages: list["Message"]): + async def run(self, messages: list["Message"]) -> str: raise RuntimeError("parse error in ActionRaise") @@ -48,7 +55,8 @@ class RoleA(Role): constraints: str = "RoleA's constraints" def __init__(self, **kwargs): - super(RoleA, self).__init__(**kwargs) + # super(RoleA, self).__init__(**kwargs) + super().__init__(**kwargs) self._init_actions([ActionPass]) self._watch([BossRequirement]) @@ -63,7 +71,8 @@ class RoleB(Role): constraints: str = "RoleB's constraints" def __init__(self, **kwargs): - super(RoleB, self).__init__(**kwargs) + # super(RoleB, self).__init__(**kwargs) + super().__init__(**kwargs) self._init_actions([ActionOK, ActionRaise]) self._watch([ActionPass]) self._rc.react_mode = RoleReactMode.BY_ORDER diff --git a/tests/metagpt/serialize_deserialize/test_team.py b/tests/metagpt/serialize_deserialize/test_team.py index e9122ebc0..b8972135b 100644 --- a/tests/metagpt/serialize_deserialize/test_team.py +++ b/tests/metagpt/serialize_deserialize/test_team.py @@ -11,7 +11,7 @@ from metagpt.roles import ProjectManager, ProductManager, Architect from metagpt.team import Team from metagpt.const import SERDESER_PATH -from tests.metagpt.serialize_deserialize.test_serdeser_base import RoleA, RoleB, RoleC, serdeser_path +from tests.metagpt.serialize_deserialize.test_serdeser_base import RoleA, RoleB, RoleC, serdeser_path, ActionOK def test_team_deserialize(): diff --git a/tests/metagpt/serialize_deserialize/test_wrire_prd.py b/tests/metagpt/serialize_deserialize/test_wrire_prd.py index 96b4d19ad..05a86cb7f 100644 --- a/tests/metagpt/serialize_deserialize/test_wrire_prd.py +++ b/tests/metagpt/serialize_deserialize/test_wrire_prd.py @@ -21,7 +21,7 @@ async def test_action_deserialize(): action = WritePRD() serialized_data = action.dict() new_action = WritePRD(**serialized_data) - # new_action = WritePRD().deserialize(serialized_data) assert new_action.name == "" assert new_action.llm == LLM() - assert len(await new_action.run([Message(content="write a cli snake game")])) > 0 + action_output = await new_action.run([Message(content="write a cli snake game")]) + assert len(action_output.content) > 0 diff --git a/tests/metagpt/serialize_deserialize/test_write_code.py b/tests/metagpt/serialize_deserialize/test_write_code.py index 7f4799014..4e3b712c0 100644 --- a/tests/metagpt/serialize_deserialize/test_write_code.py +++ b/tests/metagpt/serialize_deserialize/test_write_code.py @@ -27,7 +27,6 @@ async def test_write_code_deserialize(): action = WriteCode() serialized_data = action.dict() new_action = WriteCode(**serialized_data) - # new_action = WriteCode().deserialize(serialized_data) assert new_action.name == "WriteCode" assert new_action.llm == LLM() await new_action.run(context="write a cli snake game", filename="test_code") @@ -38,7 +37,6 @@ async def test_write_code_review_deserialize(): action = WriteCodeReview() serialized_data = action.dict() new_action = WriteCodeReview(**serialized_data) - # new_action = WriteCodeReview().deserialize(serialized_data) code = await WriteCode().run(context="write a cli snake game", filename="test_code") assert new_action.name == "WriteCodeReview" diff --git a/tests/metagpt/serialize_deserialize/test_write_design.py b/tests/metagpt/serialize_deserialize/test_write_design.py index e6e236676..5b2a30ed3 100644 --- a/tests/metagpt/serialize_deserialize/test_write_design.py +++ b/tests/metagpt/serialize_deserialize/test_write_design.py @@ -26,7 +26,7 @@ def test_write_task_serialize(): async def test_write_design_deserialize(): action = WriteDesign() serialized_data = action.dict() - new_action = WriteDesign().deserialize(serialized_data) + new_action = WriteDesign(**serialized_data) assert new_action.name == "" assert new_action.llm == LLM() await new_action.run(context="write a cli snake game") @@ -37,7 +37,6 @@ async def test_write_task_deserialize(): action = WriteTasks() serialized_data = action.dict() new_action = WriteTasks(**serialized_data) - # new_action = WriteTasks().deserialize(serialized_data) assert new_action.name == "CreateTasks" assert new_action.llm == LLM() await new_action.run(context="write a cli snake game") From a6510c44fcb14eaecb42224d3398acdacbc13d30 Mon Sep 17 00:00:00 2001 From: better629 Date: Thu, 30 Nov 2023 19:31:26 +0800 Subject: [PATCH 042/167] fix actions/roles ser&deser --- metagpt/actions/search_and_summarize.py | 15 +++++++-------- metagpt/roles/role.py | 4 ++-- metagpt/utils/utils.py | 4 +++- 3 files changed, 12 insertions(+), 11 deletions(-) diff --git a/metagpt/actions/search_and_summarize.py b/metagpt/actions/search_and_summarize.py index 7bff1c113..aa4d0f654 100644 --- a/metagpt/actions/search_and_summarize.py +++ b/metagpt/actions/search_and_summarize.py @@ -117,19 +117,18 @@ class SearchAndSummarize(Action): @root_validator def validate_engine_and_run_func(cls, values): - engine = values.get('engine') - search_func = values.get('search_func') + engine = values.get("engine") + search_func = values.get("search_func") config = Config() if engine is None: engine = config.search_engine - config_data = { - 'engine': engine, - 'run_func': search_func - } - search_engine = SearchEngine(**config_data) + try: + search_engine = SearchEngine(engine=engine, run_func=search_func) + except pydantic.ValidationError: + search_engine = None - values['search_engine'] = search_engine + values["search_engine"] = search_engine return values async def run(self, context: list[Message], system_text=SEARCH_AND_SUMMARIZE_SYSTEM) -> str: diff --git a/metagpt/roles/role.py b/metagpt/roles/role.py index 114e9e599..e407003f5 100644 --- a/metagpt/roles/role.py +++ b/metagpt/roles/role.py @@ -113,8 +113,7 @@ class RoleSetting(BaseModel): class RoleContext(BaseModel): """Role Runtime Context""" - - env: "Environment" = Field(default=None) + env: "Environment" = Field(default=None, exclude=True) msg_buffer: MessageQueue = Field(default_factory=MessageQueue) # Message Buffer with Asynchronous Updates memory: Memory = Field(default_factory=Memory) # long_term_memory: LongTermMemory = Field(default_factory=LongTermMemory) @@ -235,6 +234,7 @@ class Role(BaseModel): def _reset(self): object.__setattr__(self, "_states", []) object.__setattr__(self, "_actions", []) + # object.__setattr__(self, "_rc", RoleContext()) @property def _setting(self): diff --git a/metagpt/utils/utils.py b/metagpt/utils/utils.py index ad5c7626a..b9a8dcb53 100644 --- a/metagpt/utils/utils.py +++ b/metagpt/utils/utils.py @@ -65,7 +65,9 @@ def format_trackback_info(limit: int = 2): def serialize_decorator(func): async def wrapper(self, *args, **kwargs): try: - return await func(self, *args, **kwargs) + result = await func(self, *args, **kwargs) + self.serialize() # Team.serialize + return result except KeyboardInterrupt as kbi: logger.error(f"KeyboardInterrupt occurs, start to serialize the project, exp:\n{format_trackback_info()}") self.serialize() # Team.serialize From 0a80752908deae92906f4b0337972790ada79756 Mon Sep 17 00:00:00 2001 From: better629 Date: Thu, 30 Nov 2023 21:42:09 +0800 Subject: [PATCH 043/167] fix role._rc init --- metagpt/environment.py | 4 ++++ metagpt/roles/role.py | 1 + .../serialize_deserialize/test_team.py | 19 ++++++++++++++++--- 3 files changed, 21 insertions(+), 3 deletions(-) diff --git a/metagpt/environment.py b/metagpt/environment.py index 242581e17..19c77a03d 100644 --- a/metagpt/environment.py +++ b/metagpt/environment.py @@ -40,6 +40,7 @@ class Environment(BaseModel): arbitrary_types_allowed = True def __init__(self, **kwargs): + roles = [] for role_key, role in kwargs.get("roles", {}).items(): current_role = kwargs["roles"][role_key] if isinstance(current_role, dict): @@ -50,8 +51,11 @@ class Environment(BaseModel): current_role = subclass(**current_role) break kwargs["roles"][role_key] = current_role + roles.append(current_role) super().__init__(**kwargs) + self.add_roles(roles) # add_roles again to init the Role.set_env + def serialize(self, stg_path: Path): roles_path = stg_path.joinpath("roles.json") roles_info = [] diff --git a/metagpt/roles/role.py b/metagpt/roles/role.py index e407003f5..6be800789 100644 --- a/metagpt/roles/role.py +++ b/metagpt/roles/role.py @@ -113,6 +113,7 @@ class RoleSetting(BaseModel): class RoleContext(BaseModel): """Role Runtime Context""" + # # env exclude=True to avoid `RecursionError: maximum recursion depth exceeded in comparison` env: "Environment" = Field(default=None, exclude=True) msg_buffer: MessageQueue = Field(default_factory=MessageQueue) # Message Buffer with Asynchronous Updates memory: Memory = Field(default_factory=Memory) diff --git a/tests/metagpt/serialize_deserialize/test_team.py b/tests/metagpt/serialize_deserialize/test_team.py index b8972135b..e5ec20f2e 100644 --- a/tests/metagpt/serialize_deserialize/test_team.py +++ b/tests/metagpt/serialize_deserialize/test_team.py @@ -39,7 +39,7 @@ def test_team_deserialize(): assert new_company.environment.get_role(arch.profile) is not None -def test_team_serdeser(): +def test_team_serdeser_save(): company = Team() company.hire([RoleC()]) @@ -60,12 +60,19 @@ async def test_team_recover(): shutil.rmtree(stg_path, ignore_errors=True) company = Team() - company.hire([RoleC()]) + role_c = RoleC() + company.hire([role_c]) company.start_project(idea) await company.run(n_round=4) ser_data = company.dict() new_company = Team(**ser_data) + + new_role_c = new_company.environment.get_role(role_c.profile) + assert new_role_c._rc.memory == role_c._rc.memory + assert new_role_c._rc.env != role_c._rc.env # due to Action raise, role's memory has been changed. + assert new_role_c._rc.env.memory == role_c._rc.env.memory + assert new_company.environment.memory.count() == 1 assert type(list(new_company.environment.roles.values())[0]._actions[0]) == ActionOK @@ -80,11 +87,17 @@ async def test_team_recover_save(): shutil.rmtree(stg_path, ignore_errors=True) company = Team() - company.hire([RoleC()]) + role_c = RoleC() + company.hire([role_c]) company.start_project(idea) await company.run(n_round=4) new_company = Team.recover(stg_path) + new_role_c = new_company.environment.get_role(role_c.profile) + assert new_role_c._rc.memory == role_c._rc.memory + assert new_role_c._rc.env != role_c._rc.env # due to Action raise, role's memory has been changed. + assert new_role_c._rc.env.memory == role_c._rc.env.memory + new_company.start_project(idea) await new_company.run(n_round=4) From 26ddddaadd8dada086d8bc6199320863ca7d3f51 Mon Sep 17 00:00:00 2001 From: better629 Date: Fri, 1 Dec 2023 14:43:45 +0800 Subject: [PATCH 044/167] simplify some ser&desr code --- metagpt/actions/action.py | 20 ++++++------------ metagpt/roles/role.py | 43 +++++++++++++++++++++++++++++---------- metagpt/schema.py | 13 +----------- 3 files changed, 39 insertions(+), 37 deletions(-) diff --git a/metagpt/actions/action.py b/metagpt/actions/action.py index 499b5e794..8b28ffd8e 100644 --- a/metagpt/actions/action.py +++ b/metagpt/actions/action.py @@ -54,6 +54,12 @@ class Action(BaseModel): super().__init_subclass__(**kwargs) action_subclass_registry[cls.__name__] = cls + def dict(self, *args, **kwargs) -> "DictStrAny": + obj_dict = super(Action, self).dict(*args, **kwargs) + if "llm" in obj_dict: + obj_dict.pop("llm") + return obj_dict + def set_prefix(self, prefix, profile): """Set prefix for later usage""" self.prefix = prefix @@ -66,20 +72,6 @@ class Action(BaseModel): def __repr__(self): return self.__str__() - def serialize(self): - return { - "action_class": self.__class__.__name__, - "module_name": self.__module__, - "name": self.name - } - - @classmethod - def deserialize(cls, action_dict: dict) -> "Action": - action_class_str = action_dict.pop("action_class") - module_name = action_dict.pop("module_name") - action_class = import_class(action_class_str, module_name) - return action_class(**action_dict) - @classmethod def ser_class(cls) -> dict: """ serialize class type""" diff --git a/metagpt/roles/role.py b/metagpt/roles/role.py index 6be800789..59b0f9cd6 100644 --- a/metagpt/roles/role.py +++ b/metagpt/roles/role.py @@ -119,17 +119,33 @@ class RoleContext(BaseModel): memory: Memory = Field(default_factory=Memory) # long_term_memory: LongTermMemory = Field(default_factory=LongTermMemory) state: int = Field(default=-1) # -1 indicates initial or termination state where todo is None - todo: Action = Field(default=None) - watch: set[str] = Field(default_factory=set) - news: list[Type[Message]] = Field(default=[]) - react_mode: RoleReactMode = ( - RoleReactMode.REACT - ) # see `Role._set_react_mode` for definitions of the following two attributes + todo: Action = Field(default=None, exclude=True) + watch: set[Type[Action]] = Field(default_factory=set) + news: list[Type[Message]] = Field(default=[], exclude=True) # TODO not used + react_mode: RoleReactMode = RoleReactMode.REACT # see `Role._set_react_mode` for definitions of the following two attributes max_react_loop: int = 1 class Config: arbitrary_types_allowed = True - + + def __init__(self, **kwargs): + watch_info = kwargs.get("watch", set()) + watch = set() + for item in watch_info: + action = Action.deser_class(item) + watch.update([action]) + kwargs["watch"] = watch + super(RoleContext, self).__init__(**kwargs) + + def dict(self, *args, **kwargs) -> "DictStrAny": + obj_dict = super(RoleContext, self).dict(*args, **kwargs) + watch = obj_dict.get("watch", set()) + watch_info = [] + for item in watch: + watch_info.append(item.ser_class()) + obj_dict["watch"] = watch_info + return obj_dict + def check(self, role_id: str): # if hasattr(CONFIG, "long_term_memory") and CONFIG.long_term_memory: # self.long_term_memory.recover_memory(role_id, self) @@ -290,7 +306,7 @@ class Role(BaseModel): for idx, action in enumerate(actions): if not isinstance(action, Action): ## 默认初始化 - i = action(name="", llm=self._llm) + i = action(llm=self._llm) else: if self._setting.is_human and not isinstance(action.llm, HumanProvider): logger.warning( @@ -386,9 +402,14 @@ class Role(BaseModel): def _get_prefix(self): """Get the role prefix""" - if self._setting.desc: - return self._setting.desc - return PREFIX_TEMPLATE.format(**self._setting.dict()) + if self.desc: + return self.desc + return PREFIX_TEMPLATE.format(**{ + "profile": self.profile, + "name": self.name, + "goal": self.goal, + "constraints": self.constraints + }) async def _think(self) -> None: """Think about what to do and decide on the next action""" diff --git a/metagpt/schema.py b/metagpt/schema.py index a872481bb..15dfb579c 100644 --- a/metagpt/schema.py +++ b/metagpt/schema.py @@ -64,7 +64,7 @@ class Document(BaseModel): filename: str = "" content: str = "" - def get_meta(self) -> "Document"": + def get_meta(self) -> "Document": """Get metadata of the document. :return: A new Document instance with the same root path and filename. @@ -164,17 +164,6 @@ class Message(BaseModel): def __repr__(self): return self.__str__() - # def dict(self): - # return { - # "content": self.content, - # "instruct_content": self.instruct_content, - # "role": self.role, - # "cause_by": self.cause_by, - # "sent_from": self.sent_from, - # "send_to": self.send_to, - # "restricted_to": self.restricted_to - # } - def to_dict(self) -> dict: """Return a dict containing `role` and `content` for the LLM call.l""" return {"role": self.role, "content": self.content} From 1514942d1d0058f85569fffdca10db2e9281613c Mon Sep 17 00:00:00 2001 From: better629 Date: Fri, 1 Dec 2023 14:45:06 +0800 Subject: [PATCH 045/167] update ut after simplification --- tests/metagpt/serialize_deserialize/test_action.py | 14 +------------- tests/metagpt/serialize_deserialize/test_role.py | 3 --- .../serialize_deserialize/test_serdeser_base.py | 6 +++--- tests/metagpt/serialize_deserialize/test_team.py | 2 +- .../serialize_deserialize/test_wrire_prd.py | 2 +- .../serialize_deserialize/test_write_code.py | 4 ++-- .../serialize_deserialize/test_write_design.py | 4 ++-- 7 files changed, 10 insertions(+), 25 deletions(-) diff --git a/tests/metagpt/serialize_deserialize/test_action.py b/tests/metagpt/serialize_deserialize/test_action.py index 0138d41ce..16369bb61 100644 --- a/tests/metagpt/serialize_deserialize/test_action.py +++ b/tests/metagpt/serialize_deserialize/test_action.py @@ -13,7 +13,7 @@ def test_action_serialize(): action = Action() ser_action_dict = action.dict() assert "name" in ser_action_dict - assert "llm" not in ser_action_dict + # assert "llm" not in ser_action_dict # not export @pytest.mark.asyncio @@ -34,15 +34,3 @@ def test_action_serdeser(): action_class = Action.deser_class(action_info) assert action_class == WriteTest - - -def test_action_class_serdeser(): - name = "write test" - action_info = WriteTest(name=name).serialize() - assert action_info["name"] == name - - action_info = WriteTest(name=name, llm=LLM()).serialize() - assert action_info["name"] == name - - action = Action.deserialize(action_info) - assert action.name == name diff --git a/tests/metagpt/serialize_deserialize/test_role.py b/tests/metagpt/serialize_deserialize/test_role.py index c21b9cc2e..61684ba9d 100644 --- a/tests/metagpt/serialize_deserialize/test_role.py +++ b/tests/metagpt/serialize_deserialize/test_role.py @@ -66,7 +66,6 @@ def test_role_serdeser_save(): role_tag = f"{pm.__class__.__name__}_{pm.name}" stg_path = stg_path_prefix.joinpath(role_tag) pm.serialize(stg_path) - assert stg_path.joinpath("actions/actions_info.json").exists() new_pm = Role.deserialize(stg_path) assert new_pm.name == pm.name @@ -89,8 +88,6 @@ async def test_role_serdeser_interrupt(): assert role_c._rc.memory.count() == 2 - assert stg_path.joinpath("actions/todo.json").exists() - new_role_a: Role = Role.deserialize(stg_path) assert new_role_a._rc.state == 1 diff --git a/tests/metagpt/serialize_deserialize/test_serdeser_base.py b/tests/metagpt/serialize_deserialize/test_serdeser_base.py index 00d894b3d..74f9fea87 100644 --- a/tests/metagpt/serialize_deserialize/test_serdeser_base.py +++ b/tests/metagpt/serialize_deserialize/test_serdeser_base.py @@ -21,7 +21,7 @@ class MockMessage(BaseModel): class ActionPass(Action): - name: str = "ActionPass" + name: str = Field(default="ActionPass") async def run(self, messages: list["Message"]) -> ActionOutput: output_mapping = { @@ -34,14 +34,14 @@ class ActionPass(Action): class ActionOK(Action): - name: str = "ActionOK" + name: str = Field(default="ActionOK") async def run(self, messages: list["Message"]) -> str: return "ok" class ActionRaise(Action): - name: str = "ActionRaise" + name: str = Field(default="ActionRaise") async def run(self, messages: list["Message"]) -> str: raise RuntimeError("parse error in ActionRaise") diff --git a/tests/metagpt/serialize_deserialize/test_team.py b/tests/metagpt/serialize_deserialize/test_team.py index e5ec20f2e..28728e1b5 100644 --- a/tests/metagpt/serialize_deserialize/test_team.py +++ b/tests/metagpt/serialize_deserialize/test_team.py @@ -70,7 +70,7 @@ async def test_team_recover(): new_role_c = new_company.environment.get_role(role_c.profile) assert new_role_c._rc.memory == role_c._rc.memory - assert new_role_c._rc.env != role_c._rc.env # due to Action raise, role's memory has been changed. + assert new_role_c._rc.env == role_c._rc.env # TODO check again assert new_role_c._rc.env.memory == role_c._rc.env.memory assert new_company.environment.memory.count() == 1 diff --git a/tests/metagpt/serialize_deserialize/test_wrire_prd.py b/tests/metagpt/serialize_deserialize/test_wrire_prd.py index 05a86cb7f..0b9dfa9d8 100644 --- a/tests/metagpt/serialize_deserialize/test_wrire_prd.py +++ b/tests/metagpt/serialize_deserialize/test_wrire_prd.py @@ -13,7 +13,7 @@ def test_action_serialize(): action = WritePRD() ser_action_dict = action.dict() assert "name" in ser_action_dict - assert "llm" in ser_action_dict + # assert "llm" in ser_action_dict # not export @pytest.mark.asyncio diff --git a/tests/metagpt/serialize_deserialize/test_write_code.py b/tests/metagpt/serialize_deserialize/test_write_code.py index 4e3b712c0..5552ffd7f 100644 --- a/tests/metagpt/serialize_deserialize/test_write_code.py +++ b/tests/metagpt/serialize_deserialize/test_write_code.py @@ -12,14 +12,14 @@ def test_write_design_serialize(): action = WriteCode() ser_action_dict = action.dict() assert ser_action_dict["name"] == "WriteCode" - assert "llm" in ser_action_dict + # assert "llm" in ser_action_dict # not export def test_write_task_serialize(): action = WriteCodeReview() ser_action_dict = action.dict() assert ser_action_dict["name"] == "WriteCodeReview" - assert "llm" in ser_action_dict + # assert "llm" in ser_action_dict # not export @pytest.mark.asyncio diff --git a/tests/metagpt/serialize_deserialize/test_write_design.py b/tests/metagpt/serialize_deserialize/test_write_design.py index 5b2a30ed3..080896c98 100644 --- a/tests/metagpt/serialize_deserialize/test_write_design.py +++ b/tests/metagpt/serialize_deserialize/test_write_design.py @@ -12,14 +12,14 @@ def test_write_design_serialize(): action = WriteDesign() ser_action_dict = action.dict() assert "name" in ser_action_dict - assert "llm" in ser_action_dict + # assert "llm" in ser_action_dict # not export def test_write_task_serialize(): action = WriteTasks() ser_action_dict = action.dict() assert "name" in ser_action_dict - assert "llm" in ser_action_dict + # assert "llm" in ser_action_dict # not export @pytest.mark.asyncio From a11096ef02efb43f056f77d21707dff97f8d72a3 Mon Sep 17 00:00:00 2001 From: better629 Date: Fri, 1 Dec 2023 15:30:28 +0800 Subject: [PATCH 046/167] update --- tests/metagpt/serialize_deserialize/test_team.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/tests/metagpt/serialize_deserialize/test_team.py b/tests/metagpt/serialize_deserialize/test_team.py index 28728e1b5..9c4eb8170 100644 --- a/tests/metagpt/serialize_deserialize/test_team.py +++ b/tests/metagpt/serialize_deserialize/test_team.py @@ -70,7 +70,7 @@ async def test_team_recover(): new_role_c = new_company.environment.get_role(role_c.profile) assert new_role_c._rc.memory == role_c._rc.memory - assert new_role_c._rc.env == role_c._rc.env # TODO check again + assert new_role_c._rc.env == role_c._rc.env assert new_role_c._rc.env.memory == role_c._rc.env.memory assert new_company.environment.memory.count() == 1 @@ -95,7 +95,10 @@ async def test_team_recover_save(): new_company = Team.recover(stg_path) new_role_c = new_company.environment.get_role(role_c.profile) assert new_role_c._rc.memory == role_c._rc.memory - assert new_role_c._rc.env != role_c._rc.env # due to Action raise, role's memory has been changed. + assert new_role_c._rc.env != role_c._rc.env + assert new_role_c.recovered != role_c.recovered # here cause previous ut is `!=` + assert new_role_c._rc.todo != role_c._rc.todo # serialize exclude `_rc.todo` + assert new_role_c._rc.news != role_c._rc.news # serialize exclude `_rc.news` assert new_role_c._rc.env.memory == role_c._rc.env.memory new_company.start_project(idea) From 0f2d96a7e2ad1028fe1f8baa3495be8c2e1fd5c7 Mon Sep 17 00:00:00 2001 From: better629 Date: Fri, 1 Dec 2023 20:35:48 +0800 Subject: [PATCH 047/167] update asyncio.sleep to make it async --- .../test_serdeser_base.py | 18 +++++------------- 1 file changed, 5 insertions(+), 13 deletions(-) diff --git a/tests/metagpt/serialize_deserialize/test_serdeser_base.py b/tests/metagpt/serialize_deserialize/test_serdeser_base.py index 74f9fea87..298c13823 100644 --- a/tests/metagpt/serialize_deserialize/test_serdeser_base.py +++ b/tests/metagpt/serialize_deserialize/test_serdeser_base.py @@ -4,6 +4,7 @@ from pydantic import BaseModel, Field from pathlib import Path +import asyncio from metagpt.actions.action import Action from metagpt.roles.role import Role, RoleReactMode @@ -24,6 +25,7 @@ class ActionPass(Action): name: str = Field(default="ActionPass") async def run(self, messages: list["Message"]) -> ActionOutput: + await asyncio.sleep(5) # sleep to make other roles can watch the executed Message output_mapping = { "result": (str, ...) } @@ -37,6 +39,7 @@ class ActionOK(Action): name: str = Field(default="ActionOK") async def run(self, messages: list["Message"]) -> str: + await asyncio.sleep(5) return "ok" @@ -55,14 +58,10 @@ class RoleA(Role): constraints: str = "RoleA's constraints" def __init__(self, **kwargs): - # super(RoleA, self).__init__(**kwargs) - super().__init__(**kwargs) + super(RoleA, self).__init__(**kwargs) self._init_actions([ActionPass]) self._watch([BossRequirement]) - async def run(self, message: "Message" = None): - await super(RoleA, self).run(message) - class RoleB(Role): name: str = Field(default="RoleB") @@ -71,15 +70,11 @@ class RoleB(Role): constraints: str = "RoleB's constraints" def __init__(self, **kwargs): - # super(RoleB, self).__init__(**kwargs) - super().__init__(**kwargs) + super(RoleB, self).__init__(**kwargs) self._init_actions([ActionOK, ActionRaise]) self._watch([ActionPass]) self._rc.react_mode = RoleReactMode.BY_ORDER - async def run(self, message: "Message" = None): - await super(RoleB, self).run(message) - class RoleC(Role): name: str = Field(default="RoleC") @@ -92,6 +87,3 @@ class RoleC(Role): self._init_actions([ActionOK, ActionRaise]) self._watch([BossRequirement]) self._rc.react_mode = RoleReactMode.BY_ORDER - - async def run(self, message: "Message" = None): - await super(RoleC, self).run(message) From 3679d77f0df68eeb7bd9d325eb671a20430a81c7 Mon Sep 17 00:00:00 2001 From: better629 Date: Fri, 1 Dec 2023 21:07:47 +0800 Subject: [PATCH 048/167] fix when RoleReactMode=REACT --- metagpt/roles/role.py | 4 ++-- metagpt/utils/utils.py | 1 + 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/metagpt/roles/role.py b/metagpt/roles/role.py index 59b0f9cd6..e63404939 100644 --- a/metagpt/roles/role.py +++ b/metagpt/roles/role.py @@ -417,9 +417,9 @@ class Role(BaseModel): # If there is only one action, then only this one can be performed self._set_state(0) return - if self._recovered and self._rc.state >= 0: + if self.recovered and self._rc.state >= 0: self._set_state(self._rc.state) # action to run from recovered state - self._recovered = False # avoid max_react_loop out of work + self.recovered = False # avoid max_react_loop out of work return prompt = self._get_prefix() diff --git a/metagpt/utils/utils.py b/metagpt/utils/utils.py index b9a8dcb53..33ca16944 100644 --- a/metagpt/utils/utils.py +++ b/metagpt/utils/utils.py @@ -88,6 +88,7 @@ def role_raise_decorator(func): newest_msgs = self._rc.env.memory.get(1) if len(newest_msgs) > 0: self._rc.memory.delete(newest_msgs[0]) + raise Exception(format_trackback_info(limit=None)) # raise again to make it captured outside except Exception as exp: if self._rc.env: newest_msgs = self._rc.env.memory.get(1) From b4322bca54b62bf32498af28f607f0413ae0fe0e Mon Sep 17 00:00:00 2001 From: better629 Date: Tue, 19 Dec 2023 13:55:45 +0800 Subject: [PATCH 049/167] update tests of serialize_deserialize --- .../serialize_deserialize/test_action.py | 3 +- .../test_architect_deserialize.py | 2 +- .../serialize_deserialize/test_environment.py | 12 +++--- .../serialize_deserialize/test_memory.py | 13 +++--- .../test_product_manager.py | 2 +- .../serialize_deserialize/test_role.py | 16 ++++--- .../serialize_deserialize/test_schema.py | 10 ++--- .../test_serdeser_base.py | 8 ++-- .../serialize_deserialize/test_team.py | 42 +++++++++---------- .../serialize_deserialize/test_write_code.py | 31 ++++---------- .../test_write_code_review.py | 37 ++++++++++++++++ .../test_write_design.py | 4 +- .../{test_wrire_prd.py => test_write_prd.py} | 3 +- tests/metagpt/test_schema.py | 8 +--- 14 files changed, 100 insertions(+), 91 deletions(-) create mode 100644 tests/metagpt/serialize_deserialize/test_write_code_review.py rename tests/metagpt/serialize_deserialize/{test_wrire_prd.py => test_write_prd.py} (87%) diff --git a/tests/metagpt/serialize_deserialize/test_action.py b/tests/metagpt/serialize_deserialize/test_action.py index 16369bb61..2db5d223c 100644 --- a/tests/metagpt/serialize_deserialize/test_action.py +++ b/tests/metagpt/serialize_deserialize/test_action.py @@ -4,9 +4,8 @@ # @Desc : import pytest -from metagpt.actions import Action, WritePRD, WriteTest +from metagpt.actions import Action, WriteTest from metagpt.llm import LLM -from metagpt.provider.openai_api import OpenAIGPTAPI def test_action_serialize(): diff --git a/tests/metagpt/serialize_deserialize/test_architect_deserialize.py b/tests/metagpt/serialize_deserialize/test_architect_deserialize.py index fb58f0a3a..66fba6167 100644 --- a/tests/metagpt/serialize_deserialize/test_architect_deserialize.py +++ b/tests/metagpt/serialize_deserialize/test_architect_deserialize.py @@ -25,4 +25,4 @@ async def test_architect_deserialize(): assert new_role.name == "Bob" assert len(new_role._actions) == 1 assert isinstance(new_role._actions[0], Action) - await new_role._actions[0].run(context="write a cli snake game") + await new_role._actions[0].run(with_messages="write a cli snake game") diff --git a/tests/metagpt/serialize_deserialize/test_environment.py b/tests/metagpt/serialize_deserialize/test_environment.py index 15336eb6a..4e3445047 100644 --- a/tests/metagpt/serialize_deserialize/test_environment.py +++ b/tests/metagpt/serialize_deserialize/test_environment.py @@ -8,9 +8,11 @@ import shutil from metagpt.schema import Message from metagpt.actions.action_output import ActionOutput from metagpt.roles.project_manager import ProjectManager -from metagpt.actions.add_requirement import BossRequirement +from metagpt.actions.add_requirement import UserRequirement from metagpt.actions.project_management import WriteTasks from metagpt.environment import Environment +from metagpt.utils.common import any_to_str + from tests.metagpt.serialize_deserialize.test_serdeser_base import RoleC, ActionOK, serdeser_path @@ -18,7 +20,6 @@ def test_env_serialize(): env = Environment() ser_env_dict = env.dict() assert "roles" in ser_env_dict - assert "memory" in ser_env_dict def test_env_deserialize(): @@ -27,7 +28,6 @@ def test_env_deserialize(): ser_env_dict = env.dict() new_env = Environment(**ser_env_dict) assert len(new_env.roles) == 0 - assert new_env.memory.storage[0].content == "test env serialize" assert len(new_env.history) == 25 @@ -40,7 +40,7 @@ def test_environment_serdeser(): content="prd", instruct_content=ic_obj(**out_data), role="product manager", - cause_by=BossRequirement + cause_by=any_to_str(UserRequirement) ) environment = Environment() @@ -54,8 +54,6 @@ def test_environment_serdeser(): new_env: Environment = Environment(**ser_data) assert len(new_env.roles) == 1 - assert new_env.memory.count() == 1 - assert new_env.memory.storage[0].instruct_content == ic_obj(**out_data) assert list(new_env.roles.values())[0]._states == list(environment.roles.values())[0]._states assert list(new_env.roles.values())[0]._actions == list(environment.roles.values())[0]._actions assert isinstance(list(environment.roles.values())[0]._actions[0], ActionOK) @@ -82,7 +80,7 @@ def test_environment_serdeser_save(): shutil.rmtree(serdeser_path.joinpath("team"), ignore_errors=True) - stg_path = serdeser_path.joinpath("team/environment") + stg_path = serdeser_path.joinpath("team", "environment") environment.add_role(role_c) environment.serialize(stg_path) diff --git a/tests/metagpt/serialize_deserialize/test_memory.py b/tests/metagpt/serialize_deserialize/test_memory.py index e24f31af3..50d30a94d 100644 --- a/tests/metagpt/serialize_deserialize/test_memory.py +++ b/tests/metagpt/serialize_deserialize/test_memory.py @@ -9,7 +9,8 @@ from metagpt.schema import Message from metagpt.memory.memory import Memory from metagpt.actions.action_output import ActionOutput from metagpt.actions.design_api import WriteDesign -from metagpt.actions.add_requirement import BossRequirement +from metagpt.actions.add_requirement import UserRequirement +from metagpt.utils.common import any_to_str from tests.metagpt.serialize_deserialize.test_serdeser_base import serdeser_path @@ -17,7 +18,7 @@ from tests.metagpt.serialize_deserialize.test_serdeser_base import serdeser_path def test_memory_serdeser(): msg1 = Message(role="Boss", content="write a snake game", - cause_by=BossRequirement) + cause_by=UserRequirement) out_mapping = {"field2": (list[str], ...)} out_data = {"field2": ["field2 value1", "field2 value2"]} @@ -36,14 +37,14 @@ def test_memory_serdeser(): new_msg2 = new_memory.get(2)[0] assert isinstance(new_msg2, BaseModel) assert isinstance(new_memory.storage[-1], BaseModel) - assert new_memory.storage[-1].cause_by == WriteDesign + assert new_memory.storage[-1].cause_by == any_to_str(WriteDesign) assert new_msg2.role == "Boss" def test_memory_serdeser_save(): msg1 = Message(role="User", content="write a 2048 game", - cause_by=BossRequirement) + cause_by=UserRequirement) out_mapping = {"field1": (list[str], ...)} out_data = {"field1": ["field1 value1", "field1 value2"]} @@ -56,7 +57,7 @@ def test_memory_serdeser_save(): memory = Memory() memory.add_batch([msg1, msg2]) - stg_path = serdeser_path.joinpath("team/environment") + stg_path = serdeser_path.joinpath("team", "environment") memory.serialize(stg_path) assert stg_path.joinpath("memory.json").exists() @@ -64,7 +65,7 @@ def test_memory_serdeser_save(): assert new_memory.count() == 2 new_msg2 = new_memory.get(1)[0] assert new_msg2.instruct_content.field1 == ["field1 value1", "field1 value2"] - assert new_msg2.cause_by == WriteDesign + assert new_msg2.cause_by == any_to_str(WriteDesign) assert len(new_memory.index) == 2 stg_path.joinpath("memory.json").unlink() diff --git a/tests/metagpt/serialize_deserialize/test_product_manager.py b/tests/metagpt/serialize_deserialize/test_product_manager.py index 25bc07a11..1d721282f 100644 --- a/tests/metagpt/serialize_deserialize/test_product_manager.py +++ b/tests/metagpt/serialize_deserialize/test_product_manager.py @@ -16,6 +16,6 @@ async def test_product_manager_deserialize(): new_role = ProductManager(**ser_role_dict) assert new_role.name == "Alice" - assert len(new_role._actions) == 1 + assert len(new_role._actions) == 2 assert isinstance(new_role._actions[0], Action) await new_role._actions[0].run([Message(content="write a cli snake game")]) diff --git a/tests/metagpt/serialize_deserialize/test_role.py b/tests/metagpt/serialize_deserialize/test_role.py index 61684ba9d..fe7b63ef3 100644 --- a/tests/metagpt/serialize_deserialize/test_role.py +++ b/tests/metagpt/serialize_deserialize/test_role.py @@ -3,15 +3,14 @@ # @Author : stellahong (stellahong@fuzhi.ai) # @Desc : -from pathlib import Path import shutil import pytest from metagpt.logs import logger from metagpt.roles.role import Role -from metagpt.actions import WriteCode, WriteCodeReview +from metagpt.actions import WriteCode from metagpt.schema import Message -from metagpt.actions.add_requirement import BossRequirement +from metagpt.actions.add_requirement import UserRequirement from metagpt.roles.product_manager import ProductManager from metagpt.const import SERDESER_PATH from metagpt.roles.engineer import Engineer @@ -52,14 +51,13 @@ async def test_engineer_deserialize(): new_role = Engineer(**ser_role_dict) assert new_role.name == "Alex" assert new_role.use_code_review is True - assert len(new_role._actions) == 2 + assert len(new_role._actions) == 1 assert isinstance(new_role._actions[0], WriteCode) - assert isinstance(new_role._actions[1], WriteCodeReview) # await new_role._actions[0].run(context="write a cli snake game", filename="test_code") def test_role_serdeser_save(): - stg_path_prefix = serdeser_path.joinpath("team/environment/roles/") + stg_path_prefix = serdeser_path.joinpath("team", "environment", "roles") shutil.rmtree(serdeser_path.joinpath("team"), ignore_errors=True) pm = ProductManager() @@ -77,10 +75,10 @@ async def test_role_serdeser_interrupt(): role_c = RoleC() shutil.rmtree(SERDESER_PATH.joinpath("team"), ignore_errors=True) - stg_path = SERDESER_PATH.joinpath(f"team/environment/roles/{role_c.__class__.__name__}_{role_c.name}") + stg_path = SERDESER_PATH.joinpath(f"team", "environment", "roles", "{role_c.__class__.__name__}_{role_c.name}") try: await role_c.run( - message=Message(content="demo", cause_by=BossRequirement) + with_message=Message(content="demo", cause_by=UserRequirement) ) except Exception as exp: logger.error(f"Exception in `role_a.run`, detail: {format_trackback_info()}") @@ -93,5 +91,5 @@ async def test_role_serdeser_interrupt(): with pytest.raises(Exception): await role_c.run( - message=Message(content="demo", cause_by=BossRequirement) + with_message=Message(content="demo", cause_by=UserRequirement) ) diff --git a/tests/metagpt/serialize_deserialize/test_schema.py b/tests/metagpt/serialize_deserialize/test_schema.py index 74b134cad..97ca4ea0c 100644 --- a/tests/metagpt/serialize_deserialize/test_schema.py +++ b/tests/metagpt/serialize_deserialize/test_schema.py @@ -5,6 +5,7 @@ from metagpt.schema import Message from metagpt.actions.action_output import ActionOutput from metagpt.actions.write_code import WriteCode +from metagpt.utils.common import any_to_str from tests.metagpt.serialize_deserialize.test_serdeser_base import MockMessage @@ -21,15 +22,12 @@ def test_message_serdeser(): cause_by=WriteCode ) ser_data = message.dict() - assert ser_data["cause_by"] == { - "action_class": "WriteCode", - "module_name": "metagpt.actions.write_code" - } + assert ser_data["cause_by"] == "metagpt.actions.write_code.WriteCode" assert ser_data["instruct_content"]["class"] == "code" new_message = Message(**ser_data) - assert new_message.cause_by == WriteCode - assert new_message.cause_by in [WriteCode] + assert new_message.cause_by == any_to_str(WriteCode) + assert new_message.cause_by in [any_to_str(WriteCode)] assert new_message.instruct_content == ic_obj(**out_data) diff --git a/tests/metagpt/serialize_deserialize/test_serdeser_base.py b/tests/metagpt/serialize_deserialize/test_serdeser_base.py index 298c13823..0363c519b 100644 --- a/tests/metagpt/serialize_deserialize/test_serdeser_base.py +++ b/tests/metagpt/serialize_deserialize/test_serdeser_base.py @@ -8,11 +8,11 @@ import asyncio from metagpt.actions.action import Action from metagpt.roles.role import Role, RoleReactMode -from metagpt.actions.add_requirement import BossRequirement +from metagpt.actions.add_requirement import UserRequirement from metagpt.actions.action_output import ActionOutput -serdeser_path = Path(__file__).absolute().parent.joinpath("../../data/serdeser_storage") +serdeser_path = Path(__file__).absolute().parent.joinpath("..", "..", "data", "serdeser_storage") class MockMessage(BaseModel): @@ -60,7 +60,7 @@ class RoleA(Role): def __init__(self, **kwargs): super(RoleA, self).__init__(**kwargs) self._init_actions([ActionPass]) - self._watch([BossRequirement]) + self._watch([UserRequirement]) class RoleB(Role): @@ -85,5 +85,5 @@ class RoleC(Role): def __init__(self, **kwargs): super(RoleC, self).__init__(**kwargs) self._init_actions([ActionOK, ActionRaise]) - self._watch([BossRequirement]) + self._watch([UserRequirement]) self._rc.react_mode = RoleReactMode.BY_ORDER diff --git a/tests/metagpt/serialize_deserialize/test_team.py b/tests/metagpt/serialize_deserialize/test_team.py index 9c4eb8170..777f0f381 100644 --- a/tests/metagpt/serialize_deserialize/test_team.py +++ b/tests/metagpt/serialize_deserialize/test_team.py @@ -26,17 +26,17 @@ def test_team_deserialize(): ProjectManager(), ] ) - assert len(company.environment.get_roles()) == 3 + assert len(company.env.get_roles()) == 3 ser_company = company.dict() new_company = Team(**ser_company) - assert len(new_company.environment.get_roles()) == 3 - assert new_company.environment.get_role(pm.profile) is not None + assert len(new_company.env.get_roles()) == 3 + assert new_company.env.get_role(pm.profile) is not None - new_pm = new_company.environment.get_role(pm.profile) + new_pm = new_company.env.get_role(pm.profile) assert type(new_pm) == ProductManager - assert new_company.environment.get_role(pm.profile) is not None - assert new_company.environment.get_role(arch.profile) is not None + assert new_company.env.get_role(pm.profile) is not None + assert new_company.env.get_role(arch.profile) is not None def test_team_serdeser_save(): @@ -50,7 +50,7 @@ def test_team_serdeser_save(): new_company = Team.deserialize(stg_path) - assert len(new_company.environment.roles) == 1 + assert len(new_company.env.roles) == 1 @pytest.mark.asyncio @@ -62,21 +62,18 @@ async def test_team_recover(): company = Team() role_c = RoleC() company.hire([role_c]) - company.start_project(idea) + company.run_project(idea) await company.run(n_round=4) ser_data = company.dict() new_company = Team(**ser_data) - new_role_c = new_company.environment.get_role(role_c.profile) - assert new_role_c._rc.memory == role_c._rc.memory - assert new_role_c._rc.env == role_c._rc.env - assert new_role_c._rc.env.memory == role_c._rc.env.memory + new_role_c = new_company.env.get_role(role_c.profile) + # assert new_role_c._rc.memory == role_c._rc.memory # TODO + assert new_role_c._rc.env != role_c._rc.env # TODO + assert type(list(new_company.env.roles.values())[0]._actions[0]) == ActionOK - assert new_company.environment.memory.count() == 1 - assert type(list(new_company.environment.roles.values())[0]._actions[0]) == ActionOK - - new_company.start_project(idea) + new_company.run_project(idea) await new_company.run(n_round=4) @@ -89,19 +86,18 @@ async def test_team_recover_save(): company = Team() role_c = RoleC() company.hire([role_c]) - company.start_project(idea) + company.run_project(idea) await company.run(n_round=4) new_company = Team.recover(stg_path) - new_role_c = new_company.environment.get_role(role_c.profile) - assert new_role_c._rc.memory == role_c._rc.memory + new_role_c = new_company.env.get_role(role_c.profile) + # assert new_role_c._rc.memory == role_c._rc.memory assert new_role_c._rc.env != role_c._rc.env assert new_role_c.recovered != role_c.recovered # here cause previous ut is `!=` assert new_role_c._rc.todo != role_c._rc.todo # serialize exclude `_rc.todo` assert new_role_c._rc.news != role_c._rc.news # serialize exclude `_rc.news` - assert new_role_c._rc.env.memory == role_c._rc.env.memory - new_company.start_project(idea) + new_company.run_project(idea) await new_company.run(n_round=4) @@ -113,9 +109,9 @@ async def test_team_recover_multi_roles_save(): company = Team() company.hire([RoleA(), RoleB()]) - company.start_project(idea) + company.run_project(idea) await company.run(n_round=4) new_company = Team.recover(stg_path) - new_company.start_project(idea) + new_company.run_project(idea) await new_company.run(n_round=4) diff --git a/tests/metagpt/serialize_deserialize/test_write_code.py b/tests/metagpt/serialize_deserialize/test_write_code.py index 5552ffd7f..0114c48da 100644 --- a/tests/metagpt/serialize_deserialize/test_write_code.py +++ b/tests/metagpt/serialize_deserialize/test_write_code.py @@ -2,10 +2,12 @@ # @Date : 11/23/2023 10:56 AM # @Author : stellahong (stellahong@fuzhi.ai) # @Desc : + import pytest -from metagpt.actions import WriteCode, WriteCodeReview +from metagpt.actions import WriteCode from metagpt.llm import LLM +from metagpt.schema import CodingContext, Document def test_write_design_serialize(): @@ -15,30 +17,15 @@ def test_write_design_serialize(): # assert "llm" in ser_action_dict # not export -def test_write_task_serialize(): - action = WriteCodeReview() - ser_action_dict = action.dict() - assert ser_action_dict["name"] == "WriteCodeReview" - # assert "llm" in ser_action_dict # not export - - @pytest.mark.asyncio async def test_write_code_deserialize(): - action = WriteCode() + context = CodingContext(filename="test_code.py", + design_doc=Document(content="write add function to calculate two numbers")) + doc = Document(content=context.json()) + action = WriteCode(context=doc) serialized_data = action.dict() new_action = WriteCode(**serialized_data) + assert new_action.name == "WriteCode" assert new_action.llm == LLM() - await new_action.run(context="write a cli snake game", filename="test_code") - - -@pytest.mark.asyncio -async def test_write_code_review_deserialize(): - action = WriteCodeReview() - serialized_data = action.dict() - new_action = WriteCodeReview(**serialized_data) - code = await WriteCode().run(context="write a cli snake game", filename="test_code") - - assert new_action.name == "WriteCodeReview" - assert new_action.llm == LLM() - await new_action.run(context="write a cli snake game", code=code, filename="test_rewrite_code") + await action.run() diff --git a/tests/metagpt/serialize_deserialize/test_write_code_review.py b/tests/metagpt/serialize_deserialize/test_write_code_review.py new file mode 100644 index 000000000..6ca4c6027 --- /dev/null +++ b/tests/metagpt/serialize_deserialize/test_write_code_review.py @@ -0,0 +1,37 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : unittest of WriteCodeReview SerDeser + +import pytest + +from metagpt.actions import WriteCodeReview +from metagpt.llm import LLM +from metagpt.schema import CodingContext, Document + + +def test_write_task_serialize(): + action = WriteCodeReview() + ser_action_dict = action.dict() + assert ser_action_dict["name"] == "WriteCodeReview" + # assert "llm" in ser_action_dict # not export + + +@pytest.mark.asyncio +async def test_write_code_review_deserialize(): + code_content = """ +def div(a: int, b: int = 0): + return a / b +""" + context = CodingContext( + filename="test_op.py", + design_doc=Document(content="divide two numbers"), + code_doc=Document(content=code_content) + ) + + action = WriteCodeReview(context=context) + serialized_data = action.dict() + new_action = WriteCodeReview(**serialized_data) + + assert new_action.name == "WriteCodeReview" + assert new_action.llm == LLM() + await new_action.run() diff --git a/tests/metagpt/serialize_deserialize/test_write_design.py b/tests/metagpt/serialize_deserialize/test_write_design.py index 080896c98..4e768ddd7 100644 --- a/tests/metagpt/serialize_deserialize/test_write_design.py +++ b/tests/metagpt/serialize_deserialize/test_write_design.py @@ -29,7 +29,7 @@ async def test_write_design_deserialize(): new_action = WriteDesign(**serialized_data) assert new_action.name == "" assert new_action.llm == LLM() - await new_action.run(context="write a cli snake game") + await new_action.run(with_messages="write a cli snake game") @pytest.mark.asyncio @@ -39,4 +39,4 @@ async def test_write_task_deserialize(): new_action = WriteTasks(**serialized_data) assert new_action.name == "CreateTasks" assert new_action.llm == LLM() - await new_action.run(context="write a cli snake game") + await new_action.run(with_messages="write a cli snake game") diff --git a/tests/metagpt/serialize_deserialize/test_wrire_prd.py b/tests/metagpt/serialize_deserialize/test_write_prd.py similarity index 87% rename from tests/metagpt/serialize_deserialize/test_wrire_prd.py rename to tests/metagpt/serialize_deserialize/test_write_prd.py index 0b9dfa9d8..d6d14f99a 100644 --- a/tests/metagpt/serialize_deserialize/test_wrire_prd.py +++ b/tests/metagpt/serialize_deserialize/test_write_prd.py @@ -2,6 +2,7 @@ # @Date : 11/22/2023 1:47 PM # @Author : stellahong (stellahong@fuzhi.ai) # @Desc : + import pytest from metagpt.actions import WritePRD @@ -23,5 +24,5 @@ async def test_action_deserialize(): new_action = WritePRD(**serialized_data) assert new_action.name == "" assert new_action.llm == LLM() - action_output = await new_action.run([Message(content="write a cli snake game")]) + action_output = await new_action.run(with_messages=Message(content="write a cli snake game")) assert len(action_output.content) > 0 diff --git a/tests/metagpt/test_schema.py b/tests/metagpt/test_schema.py index ca8b9043f..10343c192 100644 --- a/tests/metagpt/test_schema.py +++ b/tests/metagpt/test_schema.py @@ -8,22 +8,16 @@ the utilization of the new feature of `Message` class. """ -<<<<<<< HEAD import json - import pytest from metagpt.actions import Action -======= ->>>>>>> a69be36abf7beef1a989a707d1aa027948c07fee from metagpt.schema import AIMessage, Message, SystemMessage, UserMessage from metagpt.actions.action_output import ActionOutput from metagpt.actions.write_code import WriteCode from metagpt.utils.serialize import serialize_general_message, deserialize_general_message -<<<<<<< HEAD + from metagpt.utils.common import get_class_name -======= ->>>>>>> a69be36abf7beef1a989a707d1aa027948c07fee @pytest.mark.asyncio From 35ac28c30eae3ef9728bfd10c84bb3ae212c653e Mon Sep 17 00:00:00 2001 From: better629 Date: Tue, 19 Dec 2023 14:04:09 +0800 Subject: [PATCH 050/167] format serialize_deserialize tests code --- .../test_architect_deserialize.py | 2 +- .../metagpt/serialize_deserialize/test_environment.py | 6 ++---- tests/metagpt/serialize_deserialize/test_memory.py | 8 +++----- .../serialize_deserialize/test_product_manager.py | 2 +- .../serialize_deserialize/test_project_manager.py | 2 +- tests/metagpt/serialize_deserialize/test_role.py | 10 +++++----- tests/metagpt/serialize_deserialize/test_schema.py | 3 +-- .../serialize_deserialize/test_serdeser_base.py | 11 +++++------ tests/metagpt/serialize_deserialize/test_team.py | 5 ++--- 9 files changed, 21 insertions(+), 28 deletions(-) diff --git a/tests/metagpt/serialize_deserialize/test_architect_deserialize.py b/tests/metagpt/serialize_deserialize/test_architect_deserialize.py index 66fba6167..b92eba8a1 100644 --- a/tests/metagpt/serialize_deserialize/test_architect_deserialize.py +++ b/tests/metagpt/serialize_deserialize/test_architect_deserialize.py @@ -4,8 +4,8 @@ # @Desc : import pytest -from metagpt.roles.architect import Architect from metagpt.actions.action import Action +from metagpt.roles.architect import Architect def test_architect_serialize(): diff --git a/tests/metagpt/serialize_deserialize/test_environment.py b/tests/metagpt/serialize_deserialize/test_environment.py index 4e3445047..3a374460c 100644 --- a/tests/metagpt/serialize_deserialize/test_environment.py +++ b/tests/metagpt/serialize_deserialize/test_environment.py @@ -2,17 +2,15 @@ # -*- coding: utf-8 -*- # @Desc : -from pathlib import Path import shutil -from metagpt.schema import Message from metagpt.actions.action_output import ActionOutput -from metagpt.roles.project_manager import ProjectManager from metagpt.actions.add_requirement import UserRequirement from metagpt.actions.project_management import WriteTasks from metagpt.environment import Environment +from metagpt.roles.project_manager import ProjectManager +from metagpt.schema import Message from metagpt.utils.common import any_to_str - from tests.metagpt.serialize_deserialize.test_serdeser_base import RoleC, ActionOK, serdeser_path diff --git a/tests/metagpt/serialize_deserialize/test_memory.py b/tests/metagpt/serialize_deserialize/test_memory.py index 50d30a94d..47410c615 100644 --- a/tests/metagpt/serialize_deserialize/test_memory.py +++ b/tests/metagpt/serialize_deserialize/test_memory.py @@ -2,16 +2,14 @@ # -*- coding: utf-8 -*- # @Desc : unittest of memory -from pathlib import Path from pydantic import BaseModel -from metagpt.schema import Message -from metagpt.memory.memory import Memory from metagpt.actions.action_output import ActionOutput -from metagpt.actions.design_api import WriteDesign from metagpt.actions.add_requirement import UserRequirement +from metagpt.actions.design_api import WriteDesign +from metagpt.memory.memory import Memory +from metagpt.schema import Message from metagpt.utils.common import any_to_str - from tests.metagpt.serialize_deserialize.test_serdeser_base import serdeser_path diff --git a/tests/metagpt/serialize_deserialize/test_product_manager.py b/tests/metagpt/serialize_deserialize/test_product_manager.py index 1d721282f..b65e329d1 100644 --- a/tests/metagpt/serialize_deserialize/test_product_manager.py +++ b/tests/metagpt/serialize_deserialize/test_product_manager.py @@ -4,8 +4,8 @@ # @Desc : import pytest -from metagpt.roles.product_manager import ProductManager from metagpt.actions.action import Action +from metagpt.roles.product_manager import ProductManager from metagpt.schema import Message diff --git a/tests/metagpt/serialize_deserialize/test_project_manager.py b/tests/metagpt/serialize_deserialize/test_project_manager.py index 21fafa72e..e52e3f247 100644 --- a/tests/metagpt/serialize_deserialize/test_project_manager.py +++ b/tests/metagpt/serialize_deserialize/test_project_manager.py @@ -4,9 +4,9 @@ # @Desc : import pytest -from metagpt.roles.project_manager import ProjectManager from metagpt.actions.action import Action from metagpt.actions.project_management import WriteTasks +from metagpt.roles.project_manager import ProjectManager def test_project_manager_serialize(): diff --git a/tests/metagpt/serialize_deserialize/test_role.py b/tests/metagpt/serialize_deserialize/test_role.py index fe7b63ef3..f25403dc0 100644 --- a/tests/metagpt/serialize_deserialize/test_role.py +++ b/tests/metagpt/serialize_deserialize/test_role.py @@ -4,18 +4,18 @@ # @Desc : import shutil + import pytest -from metagpt.logs import logger -from metagpt.roles.role import Role from metagpt.actions import WriteCode -from metagpt.schema import Message from metagpt.actions.add_requirement import UserRequirement -from metagpt.roles.product_manager import ProductManager from metagpt.const import SERDESER_PATH +from metagpt.logs import logger from metagpt.roles.engineer import Engineer +from metagpt.roles.product_manager import ProductManager +from metagpt.roles.role import Role +from metagpt.schema import Message from metagpt.utils.utils import format_trackback_info - from tests.metagpt.serialize_deserialize.test_serdeser_base import RoleA, RoleB, RoleC, serdeser_path diff --git a/tests/metagpt/serialize_deserialize/test_schema.py b/tests/metagpt/serialize_deserialize/test_schema.py index 97ca4ea0c..02afa762d 100644 --- a/tests/metagpt/serialize_deserialize/test_schema.py +++ b/tests/metagpt/serialize_deserialize/test_schema.py @@ -2,11 +2,10 @@ # -*- coding: utf-8 -*- # @Desc : unittest of schema ser&deser -from metagpt.schema import Message from metagpt.actions.action_output import ActionOutput from metagpt.actions.write_code import WriteCode +from metagpt.schema import Message from metagpt.utils.common import any_to_str - from tests.metagpt.serialize_deserialize.test_serdeser_base import MockMessage diff --git a/tests/metagpt/serialize_deserialize/test_serdeser_base.py b/tests/metagpt/serialize_deserialize/test_serdeser_base.py index 0363c519b..20f708e30 100644 --- a/tests/metagpt/serialize_deserialize/test_serdeser_base.py +++ b/tests/metagpt/serialize_deserialize/test_serdeser_base.py @@ -2,15 +2,15 @@ # -*- coding: utf-8 -*- # @Desc : base test actions / roles used in unittest -from pydantic import BaseModel, Field -from pathlib import Path import asyncio +from pathlib import Path + +from pydantic import BaseModel, Field from metagpt.actions.action import Action -from metagpt.roles.role import Role, RoleReactMode -from metagpt.actions.add_requirement import UserRequirement from metagpt.actions.action_output import ActionOutput - +from metagpt.actions.add_requirement import UserRequirement +from metagpt.roles.role import Role, RoleReactMode serdeser_path = Path(__file__).absolute().parent.joinpath("..", "..", "data", "serdeser_storage") @@ -51,7 +51,6 @@ class ActionRaise(Action): class RoleA(Role): - name: str = Field(default="RoleA") profile: str = Field(default="Role A") goal: str = "RoleA's goal" diff --git a/tests/metagpt/serialize_deserialize/test_team.py b/tests/metagpt/serialize_deserialize/test_team.py index 777f0f381..01e0a6c70 100644 --- a/tests/metagpt/serialize_deserialize/test_team.py +++ b/tests/metagpt/serialize_deserialize/test_team.py @@ -3,14 +3,13 @@ # @Author : stellahong (stellahong@fuzhi.ai) # @Desc : -from pathlib import Path import shutil + import pytest +from metagpt.const import SERDESER_PATH from metagpt.roles import ProjectManager, ProductManager, Architect from metagpt.team import Team -from metagpt.const import SERDESER_PATH - from tests.metagpt.serialize_deserialize.test_serdeser_base import RoleA, RoleB, RoleC, serdeser_path, ActionOK From ebc4fe4b179acfe8c373afb8e2ee922e15fb06c6 Mon Sep 17 00:00:00 2001 From: better629 Date: Tue, 19 Dec 2023 14:22:52 +0800 Subject: [PATCH 051/167] update ser&deser after env_refactor --- metagpt/actions/action.py | 24 ++--- metagpt/actions/prepare_documents.py | 2 - metagpt/actions/write_code.py | 13 +-- metagpt/actions/write_code_review.py | 38 ++++---- metagpt/actions/write_prd.py | 18 ++-- metagpt/environment.py | 9 +- metagpt/memory/memory.py | 16 ++-- metagpt/roles/architect.py | 10 +-- metagpt/roles/engineer.py | 13 +-- metagpt/roles/product_manager.py | 3 +- metagpt/roles/project_manager.py | 2 - metagpt/roles/role.py | 129 +++++++++------------------ metagpt/schema.py | 63 +++++++------ metagpt/team.py | 9 +- metagpt/utils/utils.py | 3 +- 15 files changed, 152 insertions(+), 200 deletions(-) diff --git a/metagpt/actions/action.py b/metagpt/actions/action.py index c941d44b6..a21f575ea 100644 --- a/metagpt/actions/action.py +++ b/metagpt/actions/action.py @@ -7,23 +7,21 @@ """ from __future__ import annotations -import re -from typing import Optional, Any from typing import Optional, Any -from tenacity import retry, stop_after_attempt, wait_random_exponential + from pydantic import BaseModel, Field +from tenacity import retry, stop_after_attempt, wait_random_exponential from metagpt.actions.action_output import ActionOutput from metagpt.llm import LLM -from metagpt.provider.base_gpt_api import BaseGPTAPI from metagpt.logs import logger +from metagpt.provider.base_gpt_api import BaseGPTAPI from metagpt.provider.postprecess.llm_output_postprecess import llm_output_postprecess from metagpt.utils.common import OutputParser from metagpt.utils.utils import general_after_log from metagpt.utils.utils import import_class - action_subclass_registry = {} @@ -31,9 +29,10 @@ class Action(BaseModel): name: str = "" llm: BaseGPTAPI = Field(default_factory=LLM, exclude=True) context = "" - prefix = "" # aask*时会加上prefix,作为system_message + prefix = "" # aask*时会加上prefix,作为system_message profile = "" # FIXME: USELESS - desc = "" # for skill manager + desc = "" # for skill manager + nodes = [] # content: Optional[str] = None # instruct_content: Optional[str] = None @@ -42,7 +41,7 @@ class Action(BaseModel): class Config: arbitrary_types_allowed = True - + def __init__(self, **kwargs: Any): super().__init__(**kwargs) @@ -64,10 +63,11 @@ class Action(BaseModel): """Set prefix for later usage""" self.prefix = prefix self.profile = profile + return self def __str__(self): return self.__class__.__name__ - + def __repr__(self): return self.__str__() @@ -110,16 +110,16 @@ class Action(BaseModel): content = await self.llm.aask(prompt, system_msgs) logger.debug(f"llm raw output:\n{content}") output_class = ActionOutput.create_model_class(output_class_name, output_data_mapping) - + if format == "json": parsed_data = llm_output_postprecess(output=content, schema=output_class.schema(), req_key="[/CONTENT]") else: # using markdown parser parsed_data = OutputParser.parse_data_with_mapping(content, output_data_mapping) - + logger.debug(parsed_data) instruct_content = output_class(**parsed_data) return ActionOutput(content, instruct_content) - + async def run(self, *args, **kwargs): """Run action""" raise NotImplementedError("The run method should be implemented in a subclass.") diff --git a/metagpt/actions/prepare_documents.py b/metagpt/actions/prepare_documents.py index 8d3445ae4..af38b7eae 100644 --- a/metagpt/actions/prepare_documents.py +++ b/metagpt/actions/prepare_documents.py @@ -19,8 +19,6 @@ from metagpt.utils.git_repository import GitRepository class PrepareDocuments(Action): - def __init__(self, name="", context=None, llm=None): - super().__init__(name, context, llm) async def run(self, with_messages, **kwargs): if not CONFIG.git_repo: diff --git a/metagpt/actions/write_code.py b/metagpt/actions/write_code.py index bad9a0890..046f9f456 100644 --- a/metagpt/actions/write_code.py +++ b/metagpt/actions/write_code.py @@ -16,9 +16,10 @@ """ import json -from tenacity import retry, stop_after_attempt, wait_random_exponential -from typing import List, Optional, Any +from typing import Optional + from pydantic import Field +from tenacity import retry, stop_after_attempt, wait_random_exponential from metagpt.actions.action import Action from metagpt.config import CONFIG @@ -30,8 +31,8 @@ from metagpt.const import ( TEST_OUTPUTS_FILE_REPO, ) from metagpt.llm import LLM -from metagpt.provider.base_gpt_api import BaseGPTAPI from metagpt.logs import logger +from metagpt.provider.base_gpt_api import BaseGPTAPI from metagpt.schema import CodingContext, Document, RunCodeResult from metagpt.utils.common import CodeParser from metagpt.utils.file_repository import FileRepository @@ -89,7 +90,7 @@ ATTENTION: Use '##' to SPLIT SECTIONS, not '#'. Output format carefully referenc class WriteCode(Action): name: str = "WriteCode" - context: Optional[str] = None + context: Optional[Document] = None llm: BaseGPTAPI = Field(default_factory=LLM) @retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6)) @@ -131,7 +132,9 @@ class WriteCode(Action): logger.info(f"Writing {coding_context.filename}..") code = await self.write_code(prompt) if not coding_context.code_doc: - coding_context.code_doc = Document(filename=coding_context.filename, root_path=CONFIG.src_workspace) + # avoid root_path pydantic ValidationError if use WriteCode alone + root_path = CONFIG.src_workspace if CONFIG.src_workspace else "" + coding_context.code_doc = Document(filename=coding_context.filename, root_path=root_path) coding_context.code_doc.content = code return coding_context diff --git a/metagpt/actions/write_code_review.py b/metagpt/actions/write_code_review.py index 636f3f12a..f4ab0adfe 100644 --- a/metagpt/actions/write_code_review.py +++ b/metagpt/actions/write_code_review.py @@ -7,21 +7,19 @@ @Modified By: mashenquan, 2023/11/27. Following the think-act principle, solidify the task parameters when creating the WriteCode object, rather than passing them in when calling the run function. """ -from typing import List, Optional, Any -from pydantic import Field -from tenacity import retry, stop_after_attempt, wait_fixed -from typing import List, Optional, Any +from typing import Optional + from pydantic import Field from tenacity import retry, stop_after_attempt, wait_random_exponential from metagpt.actions import WriteCode -from metagpt.llm import LLM from metagpt.actions.action import Action from metagpt.config import CONFIG +from metagpt.llm import LLM from metagpt.logs import logger -from metagpt.schema import CodingContext from metagpt.provider.base_gpt_api import BaseGPTAPI +from metagpt.schema import CodingContext from metagpt.utils.common import CodeParser PROMPT_TEMPLATE = """ @@ -39,7 +37,6 @@ ATTENTION: Use '##' to SPLIT SECTIONS, not '#'. Output format carefully referenc ``` """ - EXAMPLE_AND_INSTRUCTION = """ {format_example} @@ -127,7 +124,7 @@ REWRITE_CODE_TEMPLATE = """ class WriteCodeReview(Action): name: str = "WriteCodeReview" - context: Optional[str] = None + context: Optional[CodingContext] = None llm: BaseGPTAPI = Field(default_factory=LLM) @retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6)) @@ -147,9 +144,15 @@ class WriteCodeReview(Action): iterative_code = self.context.code_doc.content k = CONFIG.code_review_k_times or 1 for i in range(k): - format_example = FORMAT_EXAMPLE.format(filename=self.context.code_doc.filename) - task_content = self.context.task_doc.content if self.context.task_doc else "" - code_context = await WriteCode.get_codes(self.context.task_doc, exclude=self.context.filename) + format_example = FORMAT_EXAMPLE.format( + filename=self.context.code_doc.filename + ) + task_content = ( + self.context.task_doc.content if self.context.task_doc else "" + ) + code_context = await WriteCode.get_codes( + self.context.task_doc, exclude=self.context.filename + ) context = "\n".join( [ "## System Design\n" + str(self.context.design_doc) + "\n", @@ -162,11 +165,16 @@ class WriteCodeReview(Action): code=iterative_code, filename=self.context.code_doc.filename, ) - cr_prompt = EXAMPLE_AND_INSTRUCTION.format(format_example=format_example, ) - logger.info( - f"Code review and rewrite {self.context.code_doc.filename}: {i+1}/{k} | {len(iterative_code)=}, {len(self.context.code_doc.content)=}" + cr_prompt = EXAMPLE_AND_INSTRUCTION.format( + format_example=format_example, + ) + logger.info( + f"Code review and rewrite {self.context.code_doc.filename}: {i + 1}/{k} | {len(iterative_code)=}, " + f"{len(self.context.code_doc.content)=}" + ) + result, rewrited_code = await self.write_code_review_and_rewrite( + context_prompt, cr_prompt, self.context.code_doc.filename ) - result, rewrited_code = await self.write_code_review_and_rewrite(context_prompt, cr_prompt, self.context.code_doc.filename) if "LBTM" in result: iterative_code = rewrited_code elif "LGTM" in result: diff --git a/metagpt/actions/write_prd.py b/metagpt/actions/write_prd.py index 8510733ac..e76e91272 100644 --- a/metagpt/actions/write_prd.py +++ b/metagpt/actions/write_prd.py @@ -15,8 +15,9 @@ from __future__ import annotations import json from pathlib import Path -from typing import List, Optional, Any -from pydantic import BaseModel, Field +from typing import Optional + +from pydantic import Field from metagpt.actions import Action, ActionOutput from metagpt.actions.action_node import ActionNode @@ -26,9 +27,6 @@ from metagpt.actions.write_prd_an import ( WP_ISSUE_TYPE_NODE, WRITE_PRD_NODE, ) -from metagpt.llm import LLM -from metagpt.provider.base_gpt_api import BaseGPTAPI -from metagpt.actions.search_and_summarize import SearchAndSummarize from metagpt.config import CONFIG from metagpt.const import ( BUGFIX_FILENAME, @@ -38,13 +36,14 @@ from metagpt.const import ( PRDS_FILE_REPO, REQUIREMENT_FILENAME, ) +from metagpt.llm import LLM from metagpt.logs import logger +from metagpt.provider.base_gpt_api import BaseGPTAPI from metagpt.schema import BugFixContext, Document, Documents, Message from metagpt.utils.common import CodeParser from metagpt.utils.file_repository import FileRepository from metagpt.utils.mermaid import mermaid_to_file - CONTEXT_TEMPLATE = """ ### Project Name {project_name} @@ -75,7 +74,7 @@ class WritePRD(Action): # related to the PRD. If they are related, rewrite the PRD. docs_file_repo = CONFIG.git_repo.new_file_repository(relative_path=DOCS_FILE_REPO) requirement_doc = await docs_file_repo.get(filename=REQUIREMENT_FILENAME) - if await self._is_bugfix(requirement_doc.content): + if requirement_doc and await self._is_bugfix(requirement_doc.content): await docs_file_repo.save(filename=BUGFIX_FILENAME, content=requirement_doc.content) await docs_file_repo.save(filename=REQUIREMENT_FILENAME, content="") bug_fix = BugFixContext(filename=BUGFIX_FILENAME) @@ -144,7 +143,8 @@ class WritePRD(Action): async def _update_prd(self, requirement_doc, prd_doc, prds_file_repo, *args, **kwargs) -> Document | None: if not prd_doc: - prd = await self._run_new_requirement(requirements=[requirement_doc.content], *args, **kwargs) + prd = await self._run_new_requirement(requirements=[requirement_doc.content if requirement_doc else ""], + *args, **kwargs) new_prd_doc = Document( root_path=PRDS_FILE_REPO, filename=FileRepository.new_filename() + ".json", @@ -166,7 +166,7 @@ class WritePRD(Action): if not quadrant_chart: return pathname = ( - CONFIG.git_repo.workdir / Path(COMPETITIVE_ANALYSIS_FILE_REPO) / Path(prd_doc.filename).with_suffix("") + CONFIG.git_repo.workdir / Path(COMPETITIVE_ANALYSIS_FILE_REPO) / Path(prd_doc.filename).with_suffix("") ) if not pathname.parent.exists(): pathname.parent.mkdir(parents=True, exist_ok=True) diff --git a/metagpt/environment.py b/metagpt/environment.py index 19c77a03d..4c8d7d5e5 100644 --- a/metagpt/environment.py +++ b/metagpt/environment.py @@ -12,14 +12,12 @@ functionality is to be consolidated into the `Environment` class. """ import asyncio -from typing import Iterable, Set from pathlib import Path +from typing import Iterable, Set from pydantic import BaseModel, Field from metagpt.logs import logger -from metagpt.roles import Role -from metagpt.memory import Memory from metagpt.roles.role import Role, role_subclass_registry from metagpt.schema import Message from metagpt.utils.common import is_subscribed @@ -29,7 +27,6 @@ from metagpt.utils.utils import read_json_file, write_json_file class Environment(BaseModel): """环境,承载一批角色,角色可以向环境发布消息,可以被其他角色观察到 Environment, hosting a batch of roles, roles can publish messages to the environment, and can be observed by other roles - """ roles: dict[str, Role] = Field(default_factory=dict) @@ -63,12 +60,11 @@ class Environment(BaseModel): roles_info.append({ "role_class": role.__class__.__name__, "module_name": role.__module__, - "role_name": role.name + "role_name": role.name, }) role.serialize(stg_path=stg_path.joinpath(f"roles/{role.__class__.__name__}_{role.name}")) write_json_file(roles_path, roles_info) - self.memory.serialize(stg_path) history_path = stg_path.joinpath("history.json") write_json_file(history_path, {"content": self.history}) @@ -92,6 +88,7 @@ class Environment(BaseModel): "history": history }) environment.add_roles(roles) + return environment def add_role(self, role: Role): diff --git a/metagpt/memory/memory.py b/metagpt/memory/memory.py index b647198e3..fe70358c9 100644 --- a/metagpt/memory/memory.py +++ b/metagpt/memory/memory.py @@ -8,16 +8,14 @@ """ import copy from collections import defaultdict - -from typing import Iterable, Type, Union, Optional, Set from pathlib import Path +from typing import Iterable, Set + from pydantic import BaseModel, Field -import json from metagpt.schema import Message from metagpt.utils.common import any_to_str, any_to_str_set from metagpt.utils.utils import read_json_file, write_json_file -from metagpt.utils.utils import import_class class Memory(BaseModel): @@ -30,10 +28,7 @@ class Memory(BaseModel): index = kwargs.get("index", {}) new_index = defaultdict(list) for action_str, value in index.items(): - action_dict = json.loads(action_str) - action_class = import_class("Action", "metagpt.actions.action") - action_obj = action_class.deser_class(action_dict) - new_index[action_obj] = [Message(**item_dict) for item_dict in value] + new_index[action_str] = [Message(**item_dict) for item_dict in value] kwargs["index"] = new_index super(Memory, self).__init__(**kwargs) self.index = new_index @@ -43,9 +38,8 @@ class Memory(BaseModel): obj_dict = super(Memory, self).dict(*args, **kwargs) new_obj_dict = copy.deepcopy(obj_dict) new_obj_dict["index"] = {} - for action, value in obj_dict["index"].items(): - action_ser = json.dumps(action.ser_class()) - new_obj_dict["index"][action_ser] = value + for action_str, value in obj_dict["index"].items(): + new_obj_dict["index"][action_str] = value return new_obj_dict def serialize(self, stg_path: Path): diff --git a/metagpt/roles/architect.py b/metagpt/roles/architect.py index 266ffc256..9edfe33d9 100644 --- a/metagpt/roles/architect.py +++ b/metagpt/roles/architect.py @@ -23,11 +23,11 @@ class Architect(Role): constraints (str): Constraints or guidelines for the architect. """ - name: str = "Bob" - profile: str = Field(default="Architect", alias='profile') - goal: str = "design a concise, usable, complete software system" - constraints: str = "make sure the architecture is simple enough and use appropriate open source libraries." \ - "Use same language as user requirement" + name: str = Field(default="Bob") + profile: str = Field(default="Architect") + goal: str = Field(default="design a concise, usable, complete software system") + constraints: str = Field(default="make sure the architecture is simple enough and use appropriate open source " + "libraries. Use same language as user requirement") def __init__(self, **kwargs) -> None: super().__init__(**kwargs) diff --git a/metagpt/roles/engineer.py b/metagpt/roles/engineer.py index ad3d0f66a..206afb38c 100644 --- a/metagpt/roles/engineer.py +++ b/metagpt/roles/engineer.py @@ -18,12 +18,14 @@ """ from __future__ import annotations -from pydantic import Field + import json from collections import defaultdict from pathlib import Path from typing import Set +from pydantic import Field + from metagpt.actions import Action, WriteCode, WriteCodeReview, WriteTasks from metagpt.actions.fix_bug import FixBug from metagpt.actions.summarize_code import SummarizeCode @@ -45,7 +47,6 @@ from metagpt.schema import ( ) from metagpt.utils.common import any_to_str, any_to_str_set - IS_PASS_PROMPT = """ {context} @@ -69,15 +70,15 @@ class Engineer(Role): use_code_review (bool): Whether to use code review. """ name: str = "Alex" - role_profile: str = Field(default="Engineer", alias='profile') + profile: str = Field(default="Engineer") goal: str = "write elegant, readable, extensible, efficient code" constraints: str = "the code should conform to standards like google-style and be modular and maintainable. " \ - "Use same language as user requirement", + "Use same language as user requirement" n_borg: int = 1 use_code_review: bool = False code_todos: list = [] summarize_todos = [] - + def __init__(self, **kwargs) -> None: super().__init__(**kwargs) @@ -211,7 +212,7 @@ class Engineer(Role): @staticmethod async def _new_coding_context( - filename, src_file_repo, task_file_repo, design_file_repo, dependency + filename, src_file_repo, task_file_repo, design_file_repo, dependency ) -> CodingContext: old_code_doc = await src_file_repo.get(filename) if not old_code_doc: diff --git a/metagpt/roles/product_manager.py b/metagpt/roles/product_manager.py index 30017b60d..d054b94f5 100644 --- a/metagpt/roles/product_manager.py +++ b/metagpt/roles/product_manager.py @@ -26,13 +26,14 @@ class ProductManager(Role): constraints (str): Constraints or limitations for the project manager. """ name: str = "Alice" - role_profile: str = Field(default="Product Manager", alias='profile') + profile: str = Field(default="Product Manager") goal: str = "efficiently create a successful product" constraints: str = "use same language as user requiremen" """ Represents a Product Manager role responsible for product development and management. """ + def __init__(self, **kwargs) -> None: super().__init__(**kwargs) diff --git a/metagpt/roles/project_manager.py b/metagpt/roles/project_manager.py index d885f2ee6..ec93e609b 100644 --- a/metagpt/roles/project_manager.py +++ b/metagpt/roles/project_manager.py @@ -24,8 +24,6 @@ class ProjectManager(Role): """ name: str = Field(default="Eve") profile: str = Field(default="Project Manager") - - goal: str = "reak down tasks according to PRD/technical design, generate a task list, and analyze task " \ "dependencies to start with the prerequisite modules" constraints: str = "use same language as user requirement" diff --git a/metagpt/roles/role.py b/metagpt/roles/role.py index bed5a38e7..dbbaf8713 100644 --- a/metagpt/roles/role.py +++ b/metagpt/roles/role.py @@ -20,42 +20,26 @@ """ from __future__ import annotations + from enum import Enum -from typing import Iterable, Set, Type from pathlib import Path +from typing import Iterable, Set, Type, Any + from pydantic import BaseModel, Field from metagpt.actions.action import Action, ActionOutput, action_subclass_registry from metagpt.actions.action_node import ActionNode from metagpt.actions.add_requirement import UserRequirement - -from pathlib import Path - -from typing import ( - Iterable, - Type, - Any -) -from pydantic import BaseModel, Field, validator - -# from metagpt.environment import Environment -from metagpt.config import CONFIG -from metagpt.actions.action import Action, ActionOutput, action_subclass_registry +from metagpt.const import SERDESER_PATH from metagpt.llm import LLM -from metagpt.provider.base_gpt_api import BaseGPTAPI from metagpt.logs import logger +from metagpt.memory import Memory +from metagpt.provider.base_gpt_api import BaseGPTAPI +from metagpt.provider.human_provider import HumanProvider from metagpt.schema import Message, MessageQueue from metagpt.utils.common import any_to_str from metagpt.utils.repair_llm_raw_output import extract_state_value_from_output -from metagpt.memory import Memory -from metagpt.provider.human_provider import HumanProvider - from metagpt.utils.utils import read_json_file, write_json_file, import_class -from metagpt.provider.base_gpt_api import BaseGPTAPI - -from metagpt.utils.utils import read_json_file, write_json_file, import_class, role_raise_decorator -from metagpt.const import SERDESER_PATH - PREFIX_TEMPLATE = """You are a {profile}, named {name}, your goal is {goal}, and the constraint is {constraints}. """ @@ -65,12 +49,14 @@ Please note that only the text between the first and second "===" is information {history} === -You can now choose one of the following stages to decide the stage you need to go in the next step: +Your previous stage: {previous_state} + +Now choose one of the following stages you need to go to in the next step: {states} Just answer a number between 0-{n_states}, choose the most suitable stage according to the understanding of the conversation. Please note that the answer only needs a number, no need to add any other text. -If there is no conversation record, choose 0. +If you think you have completed your goal and don't need to go to any of the stages, return -1. Do not answer anything else, and do not add any other information in your answer. """ @@ -106,7 +92,7 @@ class RoleSetting(BaseModel): def __str__(self): return f"{self.name}({self.profile})" - + def __repr__(self): return self.__str__() @@ -115,37 +101,21 @@ class RoleContext(BaseModel): """Role Runtime Context""" # # env exclude=True to avoid `RecursionError: maximum recursion depth exceeded in comparison` env: "Environment" = Field(default=None, exclude=True) - msg_buffer: MessageQueue = Field(default_factory=MessageQueue) # Message Buffer with Asynchronous Updates + # TODO judge if ser&deser + msg_buffer: MessageQueue = Field(default_factory=MessageQueue, + exclude=True) # Message Buffer with Asynchronous Updates memory: Memory = Field(default_factory=Memory) # long_term_memory: LongTermMemory = Field(default_factory=LongTermMemory) state: int = Field(default=-1) # -1 indicates initial or termination state where todo is None todo: Action = Field(default=None, exclude=True) - watch: set[Type[Action]] = Field(default_factory=set) + watch: set[str] = Field(default_factory=set) news: list[Type[Message]] = Field(default=[], exclude=True) # TODO not used - react_mode: RoleReactMode = RoleReactMode.REACT # see `Role._set_react_mode` for definitions of the following two attributes + react_mode: RoleReactMode = RoleReactMode.REACT # see `Role._set_react_mode` for definitions of the following two attributes max_react_loop: int = 1 - + class Config: arbitrary_types_allowed = True - def __init__(self, **kwargs): - watch_info = kwargs.get("watch", set()) - watch = set() - for item in watch_info: - action = Action.deser_class(item) - watch.update([action]) - kwargs["watch"] = watch - super(RoleContext, self).__init__(**kwargs) - - def dict(self, *args, **kwargs) -> "DictStrAny": - obj_dict = super(RoleContext, self).dict(*args, **kwargs) - watch = obj_dict.get("watch", set()) - watch_info = [] - for item in watch: - watch_info.append(item.ser_class()) - obj_dict["watch"] = watch_info - return obj_dict - def check(self, role_id: str): # if hasattr(CONFIG, "long_term_memory") and CONFIG.long_term_memory: # self.long_term_memory.recover_memory(role_id, self) @@ -156,26 +126,16 @@ class RoleContext(BaseModel): def important_memory(self) -> list[Message]: """Get the information corresponding to the watched actions""" return self.memory.get_by_actions(self.watch) - + @property def history(self) -> list[Message]: return self.memory.get() -class _RoleInjector(type): - def __call__(cls, *args, **kwargs): - instance = super().__call__(*args, **kwargs) - - if not instance._rc.watch: - instance._watch([UserRequirement]) - - return instance - - role_subclass_registry = {} -class Role(BaseModel, metaclass=_RoleInjector): +class Role(BaseModel): """Role/Agent""" name: str = "" profile: str = "" @@ -189,7 +149,7 @@ class Role(BaseModel, metaclass=_RoleInjector): _states: list[str] = Field(default=[]) _actions: list[Action] = Field(default=[]) _rc: RoleContext = Field(default=RoleContext) - _subscription: tuple = set() + _subscription: tuple[str] = set() # builtin variables recovered: bool = False # to tag if a recovered role @@ -203,6 +163,8 @@ class Role(BaseModel, metaclass=_RoleInjector): "_rc": RoleContext() } + __hash__ = object.__hash__ # support Role as hashable type in `Environment.members` + class Config: arbitrary_types_allowed = True exclude = ["_llm"] @@ -240,6 +202,9 @@ class Role(BaseModel, metaclass=_RoleInjector): else: object.__setattr__(self, key, self._private_attributes[key]) + if not self._rc.watch: + self._watch([UserRequirement]) + # deserialize child classes dynamically for inherited `role` object.__setattr__(self, "builtin_class_name", self.__class__.__name__) self.__fields__["builtin_class_name"].default = self.__class__.__name__ @@ -303,7 +268,7 @@ class Role(BaseModel, metaclass=_RoleInjector): def __init_subclass__(cls, **kwargs: Any) -> None: super().__init_subclass__(**kwargs) role_subclass_registry[cls.__name__] = cls - + def _reset(self): object.__setattr__(self, "_states", []) object.__setattr__(self, "_actions", []) @@ -338,7 +303,7 @@ class Role(BaseModel, metaclass=_RoleInjector): role_class = import_class(class_name=role_class_str, module_name=module_name) role = role_class(**role_info) # initiate particular Role - role.set_recovered(True) # set True to make a tag + role.set_recovered(True) # set True to make a tag role_memory = Memory.deserialize(stg_path) role.set_memory(role_memory) @@ -362,7 +327,7 @@ class Role(BaseModel, metaclass=_RoleInjector): for idx, action in enumerate(actions): if not isinstance(action, Action): ## 默认初始化 - i = action(llm=self._llm) + i = action(name="", llm=self._llm) else: if self._setting.is_human and not isinstance(action.llm, HumanProvider): logger.warning( @@ -437,24 +402,10 @@ class Role(BaseModel, metaclass=_RoleInjector): if env: env.set_subscription(self, self._subscription) - @property - def profile(self): - """Get the role description (position)""" - return self._setting.profile - - @property - def name(self): - """Get virtual user name""" - return self._setting.name - @property def subscription(self) -> Set: """The labels for messages to be consumed by the Role object.""" return self._subscription - - def set_env(self, env: "Environment"): - """Set the environment in which the role works. The role can talk to the environment and can also receive messages by observing.""" - self._rc.env = env def _get_prefix(self): """Get the role prefix""" @@ -466,7 +417,7 @@ class Role(BaseModel, metaclass=_RoleInjector): "goal": self.goal, "constraints": self.constraints }) - + async def _think(self) -> None: """Think about what to do and decide on the next action""" if len(self._actions) == 1: @@ -475,7 +426,7 @@ class Role(BaseModel, metaclass=_RoleInjector): return if self.recovered and self._rc.state >= 0: self._set_state(self._rc.state) # action to run from recovered state - self.recovered = False # avoid max_react_loop out of work + self.recovered = False # avoid max_react_loop out of work return prompt = self._get_prefix() @@ -498,7 +449,7 @@ class Role(BaseModel, metaclass=_RoleInjector): if next_state == -1: logger.info(f"End actions with {next_state=}") self._set_state(next_state) - + async def _act(self) -> Message: logger.info(f"{self._setting}: ready to {self._rc.todo}") response = await self._rc.todo.run(self._rc.important_memory) @@ -535,8 +486,8 @@ class Role(BaseModel, metaclass=_RoleInjector): if news_text: logger.debug(f"{self._setting} observed: {news_text}") return len(self._rc.news) - - def _publish_message(self, msg): + + def publish_message(self, msg): """If the role belongs to env, then the role's messages will be broadcast to env""" if not msg: return @@ -557,7 +508,7 @@ class Role(BaseModel, metaclass=_RoleInjector): Use llm to select actions in _think dynamically """ actions_taken = 0 - rsp = Message("No actions taken yet") # will be overwritten after Role _act + rsp = Message(content="No actions taken yet") # will be overwritten after Role _act while actions_taken < self._rc.max_react_loop: # think await self._think() @@ -580,7 +531,7 @@ class Role(BaseModel, metaclass=_RoleInjector): async def _plan_and_act(self) -> Message: """first plan, then execute an action sequence, i.e. _think (of a plan) -> _act -> _act -> ... Use llm to come up with the plan dynamically.""" # TODO: to be implemented - return Message("") + return Message(content="") async def react(self) -> Message: """Entry to one of three strategies by which Role reacts to the observed Message""" @@ -613,24 +564,24 @@ class Role(BaseModel, metaclass=_RoleInjector): def get_memories(self, k=0) -> list[Message]: """A wrapper to return the most recent k memories of this role, return all when k=0""" return self._rc.memory.get(k=k) - + async def run(self, with_message=None): """Observe, and think and act based on the results of the observation""" if with_message: msg = None if isinstance(with_message, str): - msg = Message(with_message) + msg = Message(content=with_message) elif isinstance(with_message, Message): msg = with_message elif isinstance(with_message, list): - msg = Message("\n".join(with_message)) + msg = Message(content="\n".join(with_message)) self.put_message(msg) if not await self._observe(): # If there is no new information, suspend and wait logger.debug(f"{self._setting}: no news. waiting.") return - + rsp = await self.react() # Reset the next action to be taken. diff --git a/metagpt/schema.py b/metagpt/schema.py index 962850547..690f64128 100644 --- a/metagpt/schema.py +++ b/metagpt/schema.py @@ -13,6 +13,8 @@ 3. Add `id` to `Message` according to Section 2.2.3.1.1 of RFC 135. """ +from __future__ import annotations + import asyncio import json import os.path @@ -20,14 +22,9 @@ import uuid from asyncio import Queue, QueueEmpty, wait_for from json import JSONDecodeError from pathlib import Path -from typing import Dict, List, Optional, Set, TypedDict -from pydantic import BaseModel, Field - -from dataclasses import dataclass, field -from typing import Type, TypedDict, Union, Optional +from typing import Dict, List, Set, TypedDict, Optional, Any from pydantic import BaseModel, Field -from pydantic.main import ModelMetaclass from metagpt.config import CONFIG from metagpt.const import ( @@ -39,15 +36,7 @@ from metagpt.const import ( TASK_FILE_REPO, ) from metagpt.logs import logger -from metagpt.utils.serialize import actionoutout_schema_to_mapping, actionoutput_mapping_to_str, \ - actionoutput_str_to_mapping -from metagpt.utils.utils import import_class - from metagpt.utils.common import any_to_str, any_to_str_set -# from metagpt.utils.serialize import actionoutout_schema_to_mapping -# from metagpt.actions.action_output import ActionOutput -# from metagpt.actions.action import Action - from metagpt.utils.serialize import actionoutout_schema_to_mapping, actionoutput_mapping_to_str, \ actionoutput_str_to_mapping from metagpt.utils.utils import import_class @@ -58,7 +47,6 @@ class RawMessage(TypedDict): role: str - class Document(BaseModel): """ Represents a document. @@ -68,7 +56,7 @@ class Document(BaseModel): filename: str = "" content: str = "" - def get_meta(self) -> "Document": + def get_meta(self) -> Document: """Get metadata of the document. :return: A new Document instance with the same root path and filename. @@ -120,7 +108,6 @@ class Message(BaseModel): def __init__(self, **kwargs): instruct_content = kwargs.get("instruct_content", None) - cause_by = kwargs.get("cause_by", None) if instruct_content and not isinstance(instruct_content, BaseModel): ic = instruct_content mapping = actionoutput_str_to_mapping(ic["mapping"]) @@ -129,9 +116,11 @@ class Message(BaseModel): ic_obj = actionoutput_class.create_model_class(class_name=ic["class"], mapping=mapping) ic_new = ic_obj(**ic["value"]) kwargs["instruct_content"] = ic_new - if cause_by and not isinstance(cause_by, ModelMetaclass): - action_class = import_class("Action", "metagpt.actions.action") - kwargs["cause_by"] = action_class.deser_class(cause_by) + + kwargs["id"] = uuid.uuid4().hex + kwargs["cause_by"] = any_to_str(kwargs.get("cause_by", "")) + kwargs["sent_from"] = any_to_str(kwargs.get("sent_from", "")) + kwargs["send_to"] = any_to_str_set(kwargs.get("send_to", {MESSAGE_ROUTE_TO_ALL})) super(Message, self).__init__(**kwargs) def __setattr__(self, key, val): @@ -156,9 +145,6 @@ class Message(BaseModel): mapping = actionoutput_mapping_to_str(mapping) obj_dict["instruct_content"] = {"class": schema["title"], "mapping": mapping, "value": ic.dict()} - cb = self.cause_by - if cb: - obj_dict["cause_by"] = cb.ser_class() return obj_dict def __str__(self): @@ -214,11 +200,24 @@ class AIMessage(Message): super().__init__(content=content, role="assistant") -class MessageQueue: +class MessageQueue(BaseModel): """Message queue which supports asynchronous updates.""" - def __init__(self): - self._queue = Queue() + _queue: Queue = Field(default_factory=Queue) + + _private_attributes = { + "_queue": Queue() + } + + class Config: + arbitrary_types_allowed = True + + def __init__(self, **kwargs: Any): + for key in self._private_attributes.keys(): + if key in kwargs: + object.__setattr__(self, key, kwargs[key]) + else: + object.__setattr__(self, key, self._private_attributes[key]) def pop(self) -> Message | None: """Pop one message from the queue.""" @@ -266,7 +265,7 @@ class MessageQueue: return json.dumps(lst) @staticmethod - def load(self, v) -> "MessageQueue": + def load(self, v) -> MessageQueue: """Convert the json string to the `MessageQueue` object.""" q = MessageQueue() try: @@ -287,7 +286,7 @@ class CodingContext(BaseModel): code_doc: Optional[Document] @staticmethod - def loads(val: str) -> "CodingContext" | None: + def loads(val: str) -> CodingContext | None: try: m = json.loads(val) return CodingContext(**m) @@ -301,7 +300,7 @@ class TestingContext(BaseModel): test_doc: Optional[Document] @staticmethod - def loads(val: str) -> "TestingContext" | None: + def loads(val: str) -> TestingContext | None: try: m = json.loads(val) return TestingContext(**m) @@ -322,7 +321,7 @@ class RunCodeContext(BaseModel): output: Optional[str] @staticmethod - def loads(val: str) -> "RunCodeContext" | None: + def loads(val: str) -> RunCodeContext | None: try: m = json.loads(val) return RunCodeContext(**m) @@ -336,7 +335,7 @@ class RunCodeResult(BaseModel): stderr: str @staticmethod - def loads(val: str) -> "RunCodeResult" | None: + def loads(val: str) -> RunCodeResult | None: try: m = json.loads(val) return RunCodeResult(**m) @@ -351,7 +350,7 @@ class CodeSummarizeContext(BaseModel): reason: str = "" @staticmethod - def loads(filenames: List) -> "CodeSummarizeContext": + def loads(filenames: List) -> CodeSummarizeContext: ctx = CodeSummarizeContext() for filename in filenames: if Path(filename).is_relative_to(SYSTEM_DESIGN_FILE_REPO): diff --git a/metagpt/team.py b/metagpt/team.py index bd02508c4..30e3dc618 100644 --- a/metagpt/team.py +++ b/metagpt/team.py @@ -8,18 +8,19 @@ Section 2.2.3.3 of RFC 135. """ from pathlib import Path + from pydantic import BaseModel, Field from metagpt.actions import UserRequirement from metagpt.config import CONFIG from metagpt.const import MESSAGE_ROUTE_TO_ALL +from metagpt.const import SERDESER_PATH from metagpt.environment import Environment from metagpt.logs import logger from metagpt.roles import Role from metagpt.schema import Message from metagpt.utils.common import NoMoneyException from metagpt.utils.utils import read_json_file, write_json_file, serialize_decorator -from metagpt.const import SERDESER_PATH class Team(BaseModel): @@ -39,9 +40,9 @@ class Team(BaseModel): stg_path = SERDESER_PATH.joinpath("team") if stg_path is None else stg_path team_info_path = stg_path.joinpath("team_info.json") - write_json_file(team_info_path, self.dict(exclude={"environment": True})) + write_json_file(team_info_path, self.dict(exclude={"env": True})) - self.environment.serialize(stg_path.joinpath("environment")) # save environment alone + self.env.serialize(stg_path.joinpath("environment")) # save environment alone @classmethod def recover(cls, stg_path: Path) -> "Team": @@ -60,7 +61,7 @@ class Team(BaseModel): # recover environment environment = Environment.deserialize(stg_path=stg_path.joinpath("environment")) - team_info.update({"environment": environment}) + team_info.update({"env": environment}) team = Team(**team_info) return team diff --git a/metagpt/utils/utils.py b/metagpt/utils/utils.py index 35df654d7..57da57b00 100644 --- a/metagpt/utils/utils.py +++ b/metagpt/utils/utils.py @@ -9,6 +9,7 @@ from pathlib import Path import importlib from tenacity import _utils import traceback +from pydantic.json import pydantic_encoder from metagpt.logs import logger @@ -46,7 +47,7 @@ def write_json_file(json_file: str, data: list, encoding=None): folder_path.mkdir(parents=True, exist_ok=True) with open(json_file, "w", encoding=encoding) as fout: - json.dump(data, fout, ensure_ascii=False, indent=4) + json.dump(data, fout, ensure_ascii=False, indent=4, default=pydantic_encoder) def import_class(class_name: str, module_name: str) -> type: From 57121ef395c2659f8b67be025e7e7fbcd621434e Mon Sep 17 00:00:00 2001 From: better629 Date: Tue, 19 Dec 2023 15:53:14 +0800 Subject: [PATCH 052/167] remove useless code and format code --- metagpt/actions/action.py | 16 ---- metagpt/actions/design_api.py | 21 ++--- metagpt/actions/prepare_documents.py | 8 ++ metagpt/actions/project_management.py | 9 +- metagpt/actions/write_prd.py | 2 +- metagpt/actions/write_prd_review.py | 26 ++++-- metagpt/environment.py | 5 +- metagpt/memory/memory.py | 10 -- metagpt/roles/product_manager.py | 8 +- metagpt/roles/project_manager.py | 2 +- metagpt/roles/role.py | 91 +------------------ metagpt/schema.py | 3 +- metagpt/utils/serialize.py | 6 -- .../serialize_deserialize/test_action.py | 8 -- 14 files changed, 50 insertions(+), 165 deletions(-) diff --git a/metagpt/actions/action.py b/metagpt/actions/action.py index a21f575ea..570863388 100644 --- a/metagpt/actions/action.py +++ b/metagpt/actions/action.py @@ -71,22 +71,6 @@ class Action(BaseModel): def __repr__(self): return self.__str__() - @classmethod - def ser_class(cls) -> dict: - """ serialize class type""" - return { - "action_class": cls.__name__, - "module_name": cls.__module__ - } - - @classmethod - def deser_class(cls, action_dict: dict): - """ deserialize class type """ - action_class_str = action_dict.pop("action_class") - module_name = action_dict.pop("module_name") - action_class = import_class(action_class_str, module_name) - return action_class - async def _aask(self, prompt: str, system_msgs: Optional[list[str]] = None) -> str: """Append default prefix""" if not system_msgs: diff --git a/metagpt/actions/design_api.py b/metagpt/actions/design_api.py index a13c5873a..c1778d53f 100644 --- a/metagpt/actions/design_api.py +++ b/metagpt/actions/design_api.py @@ -12,17 +12,11 @@ import json from pathlib import Path from typing import Optional + from pydantic import Field from metagpt.actions import Action, ActionOutput from metagpt.actions.design_api_an import DESIGN_API_NODE -from typing import List, Optional, Any - -from pydantic import Field - -from metagpt.actions import Action, ActionOutput -from metagpt.llm import LLM -from metagpt.provider.base_gpt_api import BaseGPTAPI from metagpt.config import CONFIG from metagpt.const import ( DATA_API_DESIGN_FILE_REPO, @@ -31,12 +25,13 @@ from metagpt.const import ( SYSTEM_DESIGN_FILE_REPO, SYSTEM_DESIGN_PDF_FILE_REPO, ) +from metagpt.llm import LLM from metagpt.logs import logger -from metagpt.schema import Document, Documents +from metagpt.provider.base_gpt_api import BaseGPTAPI +from metagpt.schema import Document, Documents, Message from metagpt.utils.file_repository import FileRepository from metagpt.utils.mermaid import mermaid_to_file - NEW_REQ_TEMPLATE = """ ### Legacy Content {old_design} @@ -50,11 +45,11 @@ class WriteDesign(Action): name: str = "" context: Optional[str] = None llm: BaseGPTAPI = Field(default_factory=LLM) - desc: str = "Based on the PRD, think about the system design, and design the corresponding APIs, " - "data structures, library tables, processes, and paths. Please provide your design, feedback " - "clearly and in detail." + desc: str = "Based on the PRD, think about the system design, and design the corresponding APIs, " \ + "data structures, library tables, processes, and paths. Please provide your design, feedback " \ + "clearly and in detail." - async def run(self, with_messages, format=CONFIG.prompt_format): + async def run(self, with_messages: Message, format: str = CONFIG.prompt_format): # Use `git diff` to identify which PRD documents have been modified in the `docs/prds` directory. prds_file_repo = CONFIG.git_repo.new_file_repository(PRDS_FILE_REPO) changed_prds = prds_file_repo.changed_files diff --git a/metagpt/actions/prepare_documents.py b/metagpt/actions/prepare_documents.py index af38b7eae..6bb18be7b 100644 --- a/metagpt/actions/prepare_documents.py +++ b/metagpt/actions/prepare_documents.py @@ -9,16 +9,24 @@ """ import shutil from pathlib import Path +from typing import Optional + +from pydantic import Field from metagpt.actions import Action, ActionOutput from metagpt.config import CONFIG from metagpt.const import DEFAULT_WORKSPACE_ROOT, DOCS_FILE_REPO, REQUIREMENT_FILENAME +from metagpt.llm import LLM +from metagpt.provider.base_gpt_api import BaseGPTAPI from metagpt.schema import Document from metagpt.utils.file_repository import FileRepository from metagpt.utils.git_repository import GitRepository class PrepareDocuments(Action): + name: str = "PrepareDocuments" + context: Optional[str] = None + llm: BaseGPTAPI = Field(default_factory=LLM) async def run(self, with_messages, **kwargs): if not CONFIG.git_repo: diff --git a/metagpt/actions/project_management.py b/metagpt/actions/project_management.py index 98a948b64..2727f7e7f 100644 --- a/metagpt/actions/project_management.py +++ b/metagpt/actions/project_management.py @@ -11,14 +11,13 @@ """ import json -from typing import List, Optional, Any +from typing import Optional + from pydantic import Field from metagpt.actions import ActionOutput from metagpt.actions.action import Action from metagpt.actions.project_management_an import PM_NODE -from metagpt.llm import LLM -from metagpt.provider.base_gpt_api import BaseGPTAPI from metagpt.config import CONFIG from metagpt.const import ( PACKAGE_REQUIREMENTS_FILENAME, @@ -26,11 +25,11 @@ from metagpt.const import ( TASK_FILE_REPO, TASK_PDF_FILE_REPO, ) +from metagpt.llm import LLM from metagpt.logs import logger +from metagpt.provider.base_gpt_api import BaseGPTAPI from metagpt.schema import Document, Documents from metagpt.utils.file_repository import FileRepository -from metagpt.provider.base_gpt_api import BaseGPTAPI - NEW_REQ_TEMPLATE = """ ### Legacy Content diff --git a/metagpt/actions/write_prd.py b/metagpt/actions/write_prd.py index e76e91272..f087d8650 100644 --- a/metagpt/actions/write_prd.py +++ b/metagpt/actions/write_prd.py @@ -166,7 +166,7 @@ class WritePRD(Action): if not quadrant_chart: return pathname = ( - CONFIG.git_repo.workdir / Path(COMPETITIVE_ANALYSIS_FILE_REPO) / Path(prd_doc.filename).with_suffix("") + CONFIG.git_repo.workdir / Path(COMPETITIVE_ANALYSIS_FILE_REPO) / Path(prd_doc.filename).with_suffix("") ) if not pathname.parent.exists(): pathname.parent.mkdir(parents=True, exist_ok=True) diff --git a/metagpt/actions/write_prd_review.py b/metagpt/actions/write_prd_review.py index 5ff9624c5..6ed73b6a2 100644 --- a/metagpt/actions/write_prd_review.py +++ b/metagpt/actions/write_prd_review.py @@ -5,20 +5,28 @@ @Author : alexanderwu @File : write_prd_review.py """ + +from typing import Optional + +from pydantic import Field + from metagpt.actions.action import Action +from metagpt.llm import LLM +from metagpt.provider.base_gpt_api import BaseGPTAPI class WritePRDReview(Action): - def __init__(self, name, context=None, llm=None): - super().__init__(name, context, llm) - self.prd = None - self.desc = "Based on the PRD, conduct a PRD Review, providing clear and detailed feedback" - self.prd_review_prompt_template = """ - Given the following Product Requirement Document (PRD): - {prd} + name: str = "" + context: Optional[str] = None + llm: BaseGPTAPI = Field(default_factory=LLM) + prd: Optional[str] = None + desc: str = "Based on the PRD, conduct a PRD Review, providing clear and detailed feedback" + prd_review_prompt_template: str = """ +Given the following Product Requirement Document (PRD): +{prd} - As a project manager, please review it and provide your feedback and suggestions. - """ +As a project manager, please review it and provide your feedback and suggestions. +""" async def run(self, prd): self.prd = prd diff --git a/metagpt/environment.py b/metagpt/environment.py index 4c8d7d5e5..9108cdf06 100644 --- a/metagpt/environment.py +++ b/metagpt/environment.py @@ -61,6 +61,7 @@ class Environment(BaseModel): "role_class": role.__class__.__name__, "module_name": role.__module__, "role_name": role.name, + "role_sub_tags": list(self.members.get(role)) }) role.serialize(stg_path=stg_path.joinpath(f"roles/{role.__class__.__name__}_{role.name}")) write_json_file(roles_path, roles_info) @@ -70,14 +71,13 @@ class Environment(BaseModel): @classmethod def deserialize(cls, stg_path: Path) -> "Environment": - """ stg_path: ./storage/team/environment/ """ """ stg_path: ./storage/team/environment/ """ roles_path = stg_path.joinpath("roles.json") roles_info = read_json_file(roles_path) roles = [] for role_info in roles_info: # role stored in ./environment/roles/{role_class}_{role_name} - role_path = stg_path.joinpath(f'roles/{role_info.get("role_class")}_{role_info.get("role_name")}') + role_path = stg_path.joinpath(f"roles/{role_info.get('role_class')}_{role_info.get('role_name')}") role = Role.deserialize(role_path) roles.append(role) @@ -96,7 +96,6 @@ class Environment(BaseModel): Add a role in the current environment """ role.set_env(self) - # use alias self.roles[role.profile] = role def add_roles(self, roles: Iterable[Role]): diff --git a/metagpt/memory/memory.py b/metagpt/memory/memory.py index fe70358c9..198c0970d 100644 --- a/metagpt/memory/memory.py +++ b/metagpt/memory/memory.py @@ -20,7 +20,6 @@ from metagpt.utils.utils import read_json_file, write_json_file class Memory(BaseModel): """The most basic memory: super-memory""" - storage: list[Message] = Field(default=[]) index: dict[str, list[Message]] = Field(default_factory=defaultdict(list)) @@ -33,15 +32,6 @@ class Memory(BaseModel): super(Memory, self).__init__(**kwargs) self.index = new_index - def dict(self, *args, **kwargs) -> "DictStrAny": - """ overwrite the `dict` to dump dynamic pydantic model""" - obj_dict = super(Memory, self).dict(*args, **kwargs) - new_obj_dict = copy.deepcopy(obj_dict) - new_obj_dict["index"] = {} - for action_str, value in obj_dict["index"].items(): - new_obj_dict["index"][action_str] = value - return new_obj_dict - def serialize(self, stg_path: Path): """ stg_path = ./storage/team/environment/ or ./storage/team/environment/roles/{role_class}_{role_name}/ """ memory_path = stg_path.joinpath("memory.json") diff --git a/metagpt/roles/product_manager.py b/metagpt/roles/product_manager.py index d054b94f5..11bda2127 100644 --- a/metagpt/roles/product_manager.py +++ b/metagpt/roles/product_manager.py @@ -17,7 +17,7 @@ from metagpt.roles.role import Role class ProductManager(Role): """ - Represents a Project Manager role responsible for overseeing project execution and team efficiency. + Represents a Product Manager role responsible for product development and management. Attributes: name (str): Name of the project manager. @@ -28,11 +28,7 @@ class ProductManager(Role): name: str = "Alice" profile: str = Field(default="Product Manager") goal: str = "efficiently create a successful product" - constraints: str = "use same language as user requiremen" - - """ - Represents a Product Manager role responsible for product development and management. - """ + constraints: str = "use same language as user requirement" def __init__(self, **kwargs) -> None: super().__init__(**kwargs) diff --git a/metagpt/roles/project_manager.py b/metagpt/roles/project_manager.py index ec93e609b..f98d28cb7 100644 --- a/metagpt/roles/project_manager.py +++ b/metagpt/roles/project_manager.py @@ -24,7 +24,7 @@ class ProjectManager(Role): """ name: str = Field(default="Eve") profile: str = Field(default="Project Manager") - goal: str = "reak down tasks according to PRD/technical design, generate a task list, and analyze task " \ + goal: str = "break down tasks according to PRD/technical design, generate a task list, and analyze task " \ "dependencies to start with the prerequisite modules" constraints: str = "use same language as user requirement" diff --git a/metagpt/roles/role.py b/metagpt/roles/role.py index dbbaf8713..9b1e0bf94 100644 --- a/metagpt/roles/role.py +++ b/metagpt/roles/role.py @@ -81,22 +81,6 @@ class RoleReactMode(str, Enum): return [item.value for item in cls] -class RoleSetting(BaseModel): - """Role Settings""" - name: str = "" - profile: str = "" - goal: str = "" - constraints: str = "" - desc: str = "" - is_human: bool = False - - def __str__(self): - return f"{self.name}({self.profile})" - - def __repr__(self): - return self.__str__() - - class RoleContext(BaseModel): """Role Runtime Context""" # # env exclude=True to avoid `RecursionError: maximum recursion depth exceeded in comparison` @@ -160,7 +144,8 @@ class Role(BaseModel): "_role_id": _role_id, "_states": [], "_actions": [], - "_rc": RoleContext() + "_rc": RoleContext(), + "_subscription": set() } __hash__ = object.__hash__ # support Role as hashable type in `Environment.members` @@ -186,7 +171,7 @@ class Role(BaseModel): # 关于私有变量的初始化 https://github.com/pydantic/pydantic/issues/655 self._private_attributes["_llm"] = LLM() if not self.is_human else HumanProvider() self._private_attributes["_role_id"] = str(self._setting) - self._subscription = {any_to_str(self), name} if name else {any_to_str(self)} + self._private_attributes["_subscription"] = {any_to_str(self), self.name} if self.name else {any_to_str(self)} for key in self._private_attributes.keys(): if key in kwargs: @@ -202,64 +187,7 @@ class Role(BaseModel): else: object.__setattr__(self, key, self._private_attributes[key]) - if not self._rc.watch: - self._watch([UserRequirement]) - - # deserialize child classes dynamically for inherited `role` - object.__setattr__(self, "builtin_class_name", self.__class__.__name__) - self.__fields__["builtin_class_name"].default = self.__class__.__name__ - - def __init_subclass__(cls, **kwargs: Any) -> None: - super().__init_subclass__(**kwargs) - role_subclass_registry[cls.__name__] = cls - - # builtin variables - recovered: bool = False # to tag if a recovered role - builtin_class_name: str = "" - - _private_attributes = { - "_llm": LLM() if not is_human else HumanProvider(), - "_role_id": _role_id, - "_states": [], - "_actions": [], - "_rc": RoleContext() - } - - class Config: - arbitrary_types_allowed = True - exclude = ["_llm"] - - def __init__(self, **kwargs: Any): - for index in range(len(kwargs.get("_actions", []))): - current_action = kwargs["_actions"][index] - if isinstance(current_action, dict): - item_class_name = current_action.get("builtin_class_name", None) - for name, subclass in action_subclass_registry.items(): - registery_class_name = subclass.__fields__["builtin_class_name"].default - if item_class_name == registery_class_name: - current_action = subclass(**current_action) - break - kwargs["_actions"][index] = current_action - - super().__init__(**kwargs) - - # 关于私有变量的初始化 https://github.com/pydantic/pydantic/issues/655 - self._private_attributes["_llm"] = LLM() if not self.is_human else HumanProvider() - self._private_attributes["_role_id"] = str(self._setting) - - for key in self._private_attributes.keys(): - if key in kwargs: - object.__setattr__(self, key, kwargs[key]) - if key == "_rc": - _rc = RoleContext(**kwargs["_rc"]) - object.__setattr__(self, "_rc", _rc) - else: - if key == "_rc": - # # Warning, if use self._private_attributes["_rc"], - # # self._rc will be a shared object between roles, so init one or reset it inside `_reset` - object.__setattr__(self, key, RoleContext()) - else: - object.__setattr__(self, key, self._private_attributes[key]) + self._llm.system_prompt = self._get_prefix() # deserialize child classes dynamically for inherited `role` object.__setattr__(self, "builtin_class_name", self.__class__.__name__) @@ -341,9 +269,6 @@ class Role(BaseModel): self._actions.append(i) self._states.append(f"{idx}. {action}") - def set_react_mode(self, react_mode: RoleReactMode, max_react_loop: int = 1): - self._set_react_mode(react_mode, max_react_loop) - def _set_react_mode(self, react_mode: str, max_react_loop: int = 1): """Set strategy of the Role reacting to observed Message. Variation lies in how this Role elects action to perform during the _think stage, especially if it is capable of multiple Actions. @@ -365,9 +290,6 @@ class Role(BaseModel): if react_mode == RoleReactMode.REACT: self._rc.max_react_loop = max_react_loop - def watch(self, actions: Iterable[Type[Action]]): - self._watch(actions) - def _watch(self, actions: Iterable[Type[Action]]): """Watch Actions of interest. Role will select Messages caused by these Actions from its personal message buffer during _observe. @@ -386,9 +308,6 @@ class Role(BaseModel): if self._rc.env: # According to the routing feature plan in Chapter 2.2.3.2 of RFC 113 self._rc.env.set_subscription(self, self._subscription) - def set_state(self, state: int): - self._set_state(state) - def _set_state(self, state: int): """Update the current state.""" self._rc.state = state @@ -436,7 +355,7 @@ class Role(BaseModel): n_states=len(self._states) - 1, previous_state=self._rc.state, ) - # print(prompt) + next_state = await self._llm.aask(prompt) next_state = extract_state_value_from_output(next_state) logger.debug(f"{prompt=}") diff --git a/metagpt/schema.py b/metagpt/schema.py index 690f64128..0ec9b5c60 100644 --- a/metagpt/schema.py +++ b/metagpt/schema.py @@ -26,6 +26,7 @@ from typing import Dict, List, Set, TypedDict, Optional, Any from pydantic import BaseModel, Field +from metagpt.actions import UserRequirement from metagpt.config import CONFIG from metagpt.const import ( MESSAGE_ROUTE_CAUSE_BY, @@ -118,7 +119,7 @@ class Message(BaseModel): kwargs["instruct_content"] = ic_new kwargs["id"] = uuid.uuid4().hex - kwargs["cause_by"] = any_to_str(kwargs.get("cause_by", "")) + kwargs["cause_by"] = any_to_str(kwargs.get("cause_by", UserRequirement)) kwargs["sent_from"] = any_to_str(kwargs.get("sent_from", "")) kwargs["send_to"] = any_to_str_set(kwargs.get("send_to", {MESSAGE_ROUTE_TO_ALL})) super(Message, self).__init__(**kwargs) diff --git a/metagpt/utils/serialize.py b/metagpt/utils/serialize.py index 9a7049214..93f584057 100644 --- a/metagpt/utils/serialize.py +++ b/metagpt/utils/serialize.py @@ -68,9 +68,6 @@ def serialize_general_message(message: "Message") -> dict: mapping = actionoutput_mapping_to_str(mapping) message_cp.instruct_content = {"class": schema["title"], "mapping": mapping, "value": ic.dict()} - cb = message_cp.cause_by - if cb: - message_cp.cause_by = cb.ser_class() return message_cp.dict() @@ -103,9 +100,6 @@ def deserialize_general_message(message_dict: dict) -> "Message": ic_obj = actionoutput_class.create_model_class(class_name=ic["class"], mapping=mapping) ic_new = ic_obj(**ic["value"]) message.instruct_content = ic_new - if cause_by: - action_class = import_class("Action", "metagpt.actions.action") - message.cause_by = action_class.deser_class(cause_by) return message diff --git a/tests/metagpt/serialize_deserialize/test_action.py b/tests/metagpt/serialize_deserialize/test_action.py index 2db5d223c..63d8e7b7c 100644 --- a/tests/metagpt/serialize_deserialize/test_action.py +++ b/tests/metagpt/serialize_deserialize/test_action.py @@ -25,11 +25,3 @@ async def test_action_deserialize(): assert new_action.name == "" assert new_action.llm == LLM() assert len(await new_action._aask("who are you")) > 0 - - -def test_action_serdeser(): - action_info = WriteTest.ser_class() - assert action_info["action_class"] == "WriteTest" - - action_class = Action.deser_class(action_info) - assert action_class == WriteTest From 93745b85ccfbe7b953c17a36867dc823ff2699c5 Mon Sep 17 00:00:00 2001 From: geekan Date: Tue, 19 Dec 2023 16:54:06 +0800 Subject: [PATCH 053/167] refine config --- config/config.yaml | 2 +- metagpt/config.py | 51 +++++++++++++++++++------------ metagpt/provider/anthropic_api.py | 4 +-- 3 files changed, 34 insertions(+), 23 deletions(-) diff --git a/config/config.yaml b/config/config.yaml index 8fd208c59..9a7207c1a 100644 --- a/config/config.yaml +++ b/config/config.yaml @@ -20,7 +20,7 @@ RPM: 10 #SPARK_URL : "ws://spark-api.xf-yun.com/v2.1/chat" #### if Anthropic -#Anthropic_API_KEY: "YOUR_API_KEY" +#ANTHROPIC_API_KEY: "YOUR_API_KEY" #### if AZURE, check https://github.com/openai/openai-cookbook/blob/main/examples/azure/chat.ipynb #### You can use ENGINE or DEPLOYMENT mode diff --git a/metagpt/config.py b/metagpt/config.py index 629a5b797..702a2ddc9 100644 --- a/metagpt/config.py +++ b/metagpt/config.py @@ -46,30 +46,41 @@ class Config(metaclass=Singleton): def __init__(self, yaml_file=default_yaml_file): self._init_with_config_files_and_env(yaml_file) - logger.debug("Config loading done.") self._update() + logger.debug("Config loading done.") logger.info(f"OpenAI API Model: {self.openai_api_model}") + @staticmethod + def _is_valid_llm_key(k) -> bool: + return k and k != "YOUR_API_KEY" + + def _check_llm_exists(self): + if not any( + [ + self._is_valid_llm_key(self.openai_api_key), + self._is_valid_llm_key(self.anthropic_api_key), + self._is_valid_llm_key(self.zhipuai_api_key), + self._is_valid_llm_key(self.fireworks_api_key), + self.open_llm_api_base, + ] + ): + raise NotConfiguredException( + "Set OPENAI_API_KEY or Anthropic_API_KEY or ZHIPUAI_API_KEY " + "or FIREWORKS_API_KEY or OPEN_LLM_API_BASE" + ) + def _update(self): # logger.info("Config loading done.") self.global_proxy = self._get("GLOBAL_PROXY") + self.openai_api_key = self._get("OPENAI_API_KEY") - self.anthropic_api_key = self._get("Anthropic_API_KEY") + self.anthropic_api_key = self._get("ANTHROPIC_API_KEY") self.zhipuai_api_key = self._get("ZHIPUAI_API_KEY") self.open_llm_api_base = self._get("OPEN_LLM_API_BASE") self.open_llm_api_model = self._get("OPEN_LLM_API_MODEL") self.fireworks_api_key = self._get("FIREWORKS_API_KEY") - if ( - (not self.openai_api_key or "YOUR_API_KEY" == self.openai_api_key) - and (not self.anthropic_api_key or "YOUR_API_KEY" == self.anthropic_api_key) - and (not self.zhipuai_api_key or "YOUR_API_KEY" == self.zhipuai_api_key) - and (not self.open_llm_api_base) - and (not self.fireworks_api_key or "YOUR_API_KEY" == self.fireworks_api_key) - ): - raise NotConfiguredException( - "Set OPENAI_API_KEY or Anthropic_API_KEY or ZHIPUAI_API_KEY first " - "or FIREWORKS_API_KEY or OPEN_LLM_API_BASE" - ) + self._check_llm_exists() + self.openai_api_base = self._get("OPENAI_API_BASE") self.openai_proxy = self._get("OPENAI_PROXY") or self.global_proxy self.openai_api_type = self._get("OPENAI_API_TYPE") @@ -89,7 +100,7 @@ class Config(metaclass=Singleton): self.fireworks_api_base = self._get("FIREWORKS_API_BASE") self.fireworks_api_model = self._get("FIREWORKS_API_MODEL") - self.claude_api_key = self._get("Anthropic_API_KEY") + self.claude_api_key = self._get("ANTHROPIC_API_KEY") self.serpapi_api_key = self._get("SERPAPI_API_KEY") self.serper_api_key = self._get("SERPER_API_KEY") self.google_api_key = self._get("GOOGLE_API_KEY") @@ -141,8 +152,8 @@ class Config(metaclass=Singleton): @staticmethod def _get(*args, **kwargs): - m = OPTIONS.get() - return m.get(*args, **kwargs) + i = OPTIONS.get() + return i.get(*args, **kwargs) def get(self, key, *args, **kwargs): """Search for a value in config/key.yaml, config/config.yaml, and env; raise an error if not found""" @@ -155,8 +166,8 @@ class Config(metaclass=Singleton): OPTIONS.get()[name] = value def __getattr__(self, name: str) -> Any: - m = OPTIONS.get() - return m.get(name) + i = OPTIONS.get() + return i.get(name) def set_context(self, options: dict): """Update current config""" @@ -175,8 +186,8 @@ class Config(metaclass=Singleton): def new_environ(self): """Return a new os.environ object""" env = os.environ.copy() - m = self.options - env.update({k: v for k, v in m.items() if isinstance(v, str)}) + i = self.options + env.update({k: v for k, v in i.items() if isinstance(v, str)}) return env diff --git a/metagpt/provider/anthropic_api.py b/metagpt/provider/anthropic_api.py index 03802a716..f5b06c855 100644 --- a/metagpt/provider/anthropic_api.py +++ b/metagpt/provider/anthropic_api.py @@ -14,7 +14,7 @@ from metagpt.config import CONFIG class Claude2: def ask(self, prompt): - client = Anthropic(api_key=CONFIG.claude_api_key) + client = Anthropic(api_key=CONFIG.anthropic_api_key) res = client.completions.create( model="claude-2", @@ -24,7 +24,7 @@ class Claude2: return res.completion async def aask(self, prompt): - client = Anthropic(api_key=CONFIG.claude_api_key) + client = Anthropic(api_key=CONFIG.anthropic_api_key) res = client.completions.create( model="claude-2", From 7f04ec2060da2ccdc3ca72a4d5e7e60377958b7d Mon Sep 17 00:00:00 2001 From: geekan Date: Tue, 19 Dec 2023 17:06:07 +0800 Subject: [PATCH 054/167] refine code --- metagpt/config.py | 8 ++++++++ metagpt/repo_parser.py | 2 +- metagpt/startup.py | 9 +++------ 3 files changed, 12 insertions(+), 7 deletions(-) diff --git a/metagpt/config.py b/metagpt/config.py index 702a2ddc9..48ac82a3a 100644 --- a/metagpt/config.py +++ b/metagpt/config.py @@ -130,6 +130,14 @@ class Config(metaclass=Singleton): self.workspace_path = Path(self._get("WORKSPACE_PATH", DEFAULT_WORKSPACE_ROOT)) self._ensure_workspace_exists() + def update_via_cli(self, project_path, project_name, inc, reqa_file, max_auto_summarize_code): + """update config via cli""" + self.project_path = project_path + self.project_name = project_name + self.inc = inc + self.reqa_file = reqa_file + self.max_auto_summarize_code = max_auto_summarize_code + def _ensure_workspace_exists(self): self.workspace_path.mkdir(parents=True, exist_ok=True) logger.debug(f"WORKSPACE_PATH set to {self.workspace_path}") diff --git a/metagpt/repo_parser.py b/metagpt/repo_parser.py index 9a1218ef1..3524a5bce 100644 --- a/metagpt/repo_parser.py +++ b/metagpt/repo_parser.py @@ -96,4 +96,4 @@ def error(): if __name__ == "__main__": - error() + main() diff --git a/metagpt/startup.py b/metagpt/startup.py index f930c386b..047f35cf6 100644 --- a/metagpt/startup.py +++ b/metagpt/startup.py @@ -27,7 +27,8 @@ def startup( reqa_file: str = typer.Option(default="", help="Specify the source file name for rewriting the quality test code."), max_auto_summarize_code: int = typer.Option( default=-1, - help="The maximum number of times the 'SummarizeCode' action is automatically invoked, with -1 indicating unlimited. This parameter is used for debugging the workflow.", + help="The maximum number of times the 'SummarizeCode' action is automatically invoked, with -1 indicating " + "unlimited. This parameter is used for debugging the workflow.", ), ): """Run a startup. Be a boss.""" @@ -41,14 +42,10 @@ def startup( from metagpt.team import Team # Use in the PrepareDocuments action according to Section 2.2.3.5.1 of RFC 135. - CONFIG.project_path = project_path if project_path: inc = True project_name = project_name or Path(project_path).name - CONFIG.project_name = project_name - CONFIG.inc = inc - CONFIG.reqa_file = reqa_file - CONFIG.max_auto_summarize_code = max_auto_summarize_code + CONFIG.update_via_cli(project_path, project_name, inc, reqa_file, max_auto_summarize_code) company = Team() company.hire( From 2bae7f2bfb116d9deeab3e6d6237da0a12bdd2be Mon Sep 17 00:00:00 2001 From: geekan Date: Tue, 19 Dec 2023 17:11:02 +0800 Subject: [PATCH 055/167] refine code --- metagpt/config.py | 13 +++++++++++++ metagpt/startup.py | 5 ----- 2 files changed, 13 insertions(+), 5 deletions(-) diff --git a/metagpt/config.py b/metagpt/config.py index 48ac82a3a..bdf580a1f 100644 --- a/metagpt/config.py +++ b/metagpt/config.py @@ -45,6 +45,7 @@ class Config(metaclass=Singleton): default_yaml_file = METAGPT_ROOT / "config/config.yaml" def __init__(self, yaml_file=default_yaml_file): + self._init_cli_paras() self._init_with_config_files_and_env(yaml_file) self._update() logger.debug("Config loading done.") @@ -130,8 +131,20 @@ class Config(metaclass=Singleton): self.workspace_path = Path(self._get("WORKSPACE_PATH", DEFAULT_WORKSPACE_ROOT)) self._ensure_workspace_exists() + def _init_cli_paras(self): + self.project_path = None + self.project_name = None + self.inc = None + self.reqa_file = None + self.max_auto_summarize_code = None + def update_via_cli(self, project_path, project_name, inc, reqa_file, max_auto_summarize_code): """update config via cli""" + + # Use in the PrepareDocuments action according to Section 2.2.3.5.1 of RFC 135. + if project_path: + inc = True + project_name = project_name or Path(project_path).name self.project_path = project_path self.project_name = project_name self.inc = inc diff --git a/metagpt/startup.py b/metagpt/startup.py index 047f35cf6..37526dbcc 100644 --- a/metagpt/startup.py +++ b/metagpt/startup.py @@ -1,7 +1,6 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- import asyncio -from pathlib import Path import typer @@ -41,10 +40,6 @@ def startup( ) from metagpt.team import Team - # Use in the PrepareDocuments action according to Section 2.2.3.5.1 of RFC 135. - if project_path: - inc = True - project_name = project_name or Path(project_path).name CONFIG.update_via_cli(project_path, project_name, inc, reqa_file, max_auto_summarize_code) company = Team() From 1213c5f88fe2ab257681d7f383e311c6bcbff925 Mon Sep 17 00:00:00 2001 From: geekan Date: Tue, 19 Dec 2023 17:14:50 +0800 Subject: [PATCH 056/167] fix comment --- metagpt/team.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/metagpt/team.py b/metagpt/team.py index a5c405f80..ddd145269 100644 --- a/metagpt/team.py +++ b/metagpt/team.py @@ -21,8 +21,8 @@ from metagpt.utils.common import NoMoneyException class Team(BaseModel): """ - Team: Possesses one or more roles (agents), SOP (Standard Operating Procedures), and a platform for instant messaging, - dedicated to perform any multi-agent activity, such as collaboratively writing executable code. + Team: Possesses one or more roles (agents), SOP (Standard Operating Procedures), and a env for instant messaging, + dedicated to env any multi-agent activity, such as collaboratively writing executable code. """ env: Environment = Field(default_factory=Environment) From f27461f7582ec1143f43718ae79373187e0c7684 Mon Sep 17 00:00:00 2001 From: geekan Date: Tue, 19 Dec 2023 17:55:34 +0800 Subject: [PATCH 057/167] add llm provider registry --- metagpt/config.py | 57 +++++++++++++---------- metagpt/llm.py | 21 +-------- metagpt/provider/fireworks_api.py | 4 +- metagpt/provider/llm_provider_registry.py | 34 ++++++++++++++ metagpt/provider/open_llm_api.py | 4 +- metagpt/provider/openai_api.py | 4 +- metagpt/provider/spark_api.py | 4 +- metagpt/provider/zhipuai_api.py | 4 +- metagpt/schema.py | 10 ++-- 9 files changed, 89 insertions(+), 53 deletions(-) create mode 100644 metagpt/provider/llm_provider_registry.py diff --git a/metagpt/config.py b/metagpt/config.py index bdf580a1f..a0d61b39f 100644 --- a/metagpt/config.py +++ b/metagpt/config.py @@ -8,6 +8,7 @@ Provide configuration, singleton """ import os from copy import deepcopy +from enum import Enum from pathlib import Path from typing import Any @@ -31,6 +32,15 @@ class NotConfiguredException(Exception): super().__init__(self.message) +class LLMProviderEnum(Enum): + OPENAI = "openai" + ANTHROPIC = "anthropic" + SPARK = "spark" + ZHIPUAI = "zhipuai" + FIREWORKS = "fireworks" + OPEN_LLM = "open_llm" + + class Config(metaclass=Singleton): """ Regular usage method: @@ -45,31 +55,37 @@ class Config(metaclass=Singleton): default_yaml_file = METAGPT_ROOT / "config/config.yaml" def __init__(self, yaml_file=default_yaml_file): - self._init_cli_paras() + # cli paras + self.project_path = "" + self.project_name = "" + self.inc = False + self.reqa_file = "" + self.max_auto_summarize_code = 0 + self._init_with_config_files_and_env(yaml_file) self._update() logger.debug("Config loading done.") logger.info(f"OpenAI API Model: {self.openai_api_model}") + def get_default_llm_provider_enum(self): + if self._is_valid_llm_key(self.openai_api_key): + llm = LLMProviderEnum.OPENAI + elif self._is_valid_llm_key(self.anthropic_api_key): + llm = LLMProviderEnum.ANTHROPIC + elif self._is_valid_llm_key(self.zhipuai_api_key): + llm = LLMProviderEnum.ZHIPUAI + elif self._is_valid_llm_key(self.fireworks_api_key): + llm = LLMProviderEnum.FIREWORKS + elif self.open_llm_api_base: + llm = LLMProviderEnum.OPEN_LLM + else: + raise NotConfiguredException("You should config a LLM configuration first") + return llm + @staticmethod def _is_valid_llm_key(k) -> bool: return k and k != "YOUR_API_KEY" - def _check_llm_exists(self): - if not any( - [ - self._is_valid_llm_key(self.openai_api_key), - self._is_valid_llm_key(self.anthropic_api_key), - self._is_valid_llm_key(self.zhipuai_api_key), - self._is_valid_llm_key(self.fireworks_api_key), - self.open_llm_api_base, - ] - ): - raise NotConfiguredException( - "Set OPENAI_API_KEY or Anthropic_API_KEY or ZHIPUAI_API_KEY " - "or FIREWORKS_API_KEY or OPEN_LLM_API_BASE" - ) - def _update(self): # logger.info("Config loading done.") self.global_proxy = self._get("GLOBAL_PROXY") @@ -80,7 +96,7 @@ class Config(metaclass=Singleton): self.open_llm_api_base = self._get("OPEN_LLM_API_BASE") self.open_llm_api_model = self._get("OPEN_LLM_API_MODEL") self.fireworks_api_key = self._get("FIREWORKS_API_KEY") - self._check_llm_exists() + _ = self.get_default_llm_provider_enum() self.openai_api_base = self._get("OPENAI_API_BASE") self.openai_proxy = self._get("OPENAI_PROXY") or self.global_proxy @@ -131,13 +147,6 @@ class Config(metaclass=Singleton): self.workspace_path = Path(self._get("WORKSPACE_PATH", DEFAULT_WORKSPACE_ROOT)) self._ensure_workspace_exists() - def _init_cli_paras(self): - self.project_path = None - self.project_name = None - self.inc = None - self.reqa_file = None - self.max_auto_summarize_code = None - def update_via_cli(self, project_path, project_name, inc, reqa_file, max_auto_summarize_code): """update config via cli""" diff --git a/metagpt/llm.py b/metagpt/llm.py index 7c0ad7975..e0c0716de 100644 --- a/metagpt/llm.py +++ b/metagpt/llm.py @@ -8,12 +8,8 @@ from metagpt.config import CONFIG from metagpt.provider.base_gpt_api import BaseGPTAPI -from metagpt.provider.fireworks_api import FireWorksGPTAPI from metagpt.provider.human_provider import HumanProvider -from metagpt.provider.open_llm_api import OpenLLMGPTAPI -from metagpt.provider.openai_api import OpenAIGPTAPI -from metagpt.provider.spark_api import SparkAPI -from metagpt.provider.zhipuai_api import ZhiPuAIGPTAPI +from metagpt.provider.llm_provider_registry import LLMProviderRegistry _ = HumanProvider() # Avoid pre-commit error @@ -21,17 +17,4 @@ _ = HumanProvider() # Avoid pre-commit error def LLM() -> BaseGPTAPI: """initialize different LLM instance according to the key field existence""" # TODO a little trick, can use registry to initialize LLM instance further - if CONFIG.openai_api_key: - llm = OpenAIGPTAPI() - elif CONFIG.spark_api_key: - llm = SparkAPI() - elif CONFIG.zhipuai_api_key: - llm = ZhiPuAIGPTAPI() - elif CONFIG.open_llm_api_base: - llm = OpenLLMGPTAPI() - elif CONFIG.fireworks_api_key: - llm = FireWorksGPTAPI() - else: - raise RuntimeError("You should config a LLM configuration first") - - return llm + return LLMProviderRegistry.get_provider(CONFIG.get_default_llm_provider_enum()) diff --git a/metagpt/provider/fireworks_api.py b/metagpt/provider/fireworks_api.py index 47ac9cf61..a76151666 100644 --- a/metagpt/provider/fireworks_api.py +++ b/metagpt/provider/fireworks_api.py @@ -4,10 +4,12 @@ import openai -from metagpt.config import CONFIG +from metagpt.config import CONFIG, LLMProviderEnum +from metagpt.provider.llm_provider_registry import register_provider from metagpt.provider.openai_api import CostManager, OpenAIGPTAPI, RateLimiter +@register_provider(LLMProviderEnum.FIREWORKS) class FireWorksGPTAPI(OpenAIGPTAPI): def __init__(self): self.__init_fireworks(CONFIG) diff --git a/metagpt/provider/llm_provider_registry.py b/metagpt/provider/llm_provider_registry.py new file mode 100644 index 000000000..2b3ef93a3 --- /dev/null +++ b/metagpt/provider/llm_provider_registry.py @@ -0,0 +1,34 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Time : 2023/12/19 17:26 +@Author : alexanderwu +@File : llm_provider_registry.py +""" +from metagpt.config import LLMProviderEnum + + +class LLMProviderRegistry: + def __init__(self): + self.providers = {} + + def register(self, key, provider_cls): + self.providers[key] = provider_cls + + def get_provider(self, enum: LLMProviderEnum): + """get provider instance according to the enum""" + return self.providers[enum]() + + +# Registry instance +LLM_REGISTRY = LLMProviderRegistry() + + +def register_provider(key): + """register provider to registry""" + + def decorator(cls): + LLM_REGISTRY.register(key, cls) + return cls + + return decorator diff --git a/metagpt/provider/open_llm_api.py b/metagpt/provider/open_llm_api.py index f421e30c8..bada0e294 100644 --- a/metagpt/provider/open_llm_api.py +++ b/metagpt/provider/open_llm_api.py @@ -4,8 +4,9 @@ import openai -from metagpt.config import CONFIG +from metagpt.config import CONFIG, LLMProviderEnum from metagpt.logs import logger +from metagpt.provider.llm_provider_registry import register_provider from metagpt.provider.openai_api import CostManager, OpenAIGPTAPI, RateLimiter @@ -31,6 +32,7 @@ class OpenLLMCostManager(CostManager): CONFIG.total_cost = self.total_cost +@register_provider(LLMProviderEnum.OPEN_LLM) class OpenLLMGPTAPI(OpenAIGPTAPI): def __init__(self): self.__init_openllm(CONFIG) diff --git a/metagpt/provider/openai_api.py b/metagpt/provider/openai_api.py index 86054881e..0be70b3ca 100644 --- a/metagpt/provider/openai_api.py +++ b/metagpt/provider/openai_api.py @@ -18,10 +18,11 @@ from tenacity import ( wait_random_exponential, ) -from metagpt.config import CONFIG +from metagpt.config import CONFIG, LLMProviderEnum from metagpt.logs import logger from metagpt.provider.base_gpt_api import BaseGPTAPI from metagpt.provider.constant import GENERAL_FUNCTION_SCHEMA, GENERAL_TOOL_CHOICE +from metagpt.provider.llm_provider_registry import register_provider from metagpt.schema import Message from metagpt.utils.singleton import Singleton from metagpt.utils.token_counter import ( @@ -137,6 +138,7 @@ See FAQ 5.8 raise retry_state.outcome.exception() +@register_provider(LLMProviderEnum.OPENAI) class OpenAIGPTAPI(BaseGPTAPI, RateLimiter): """ Check https://platform.openai.com/examples for examples diff --git a/metagpt/provider/spark_api.py b/metagpt/provider/spark_api.py index 60c86f4dc..484fa7956 100644 --- a/metagpt/provider/spark_api.py +++ b/metagpt/provider/spark_api.py @@ -19,11 +19,13 @@ from wsgiref.handlers import format_date_time import websocket # 使用websocket_client -from metagpt.config import CONFIG +from metagpt.config import CONFIG, LLMProviderEnum from metagpt.logs import logger from metagpt.provider.base_gpt_api import BaseGPTAPI +from metagpt.provider.llm_provider_registry import register_provider +@register_provider(LLMProviderEnum.SPARK) class SparkAPI(BaseGPTAPI): def __init__(self): logger.warning("当前方法无法支持异步运行。当你使用acompletion时,并不能并行访问。") diff --git a/metagpt/provider/zhipuai_api.py b/metagpt/provider/zhipuai_api.py index 92119b764..eef0e51e1 100644 --- a/metagpt/provider/zhipuai_api.py +++ b/metagpt/provider/zhipuai_api.py @@ -16,9 +16,10 @@ from tenacity import ( wait_random_exponential, ) -from metagpt.config import CONFIG +from metagpt.config import CONFIG, LLMProviderEnum from metagpt.logs import logger from metagpt.provider.base_gpt_api import BaseGPTAPI +from metagpt.provider.llm_provider_registry import register_provider from metagpt.provider.openai_api import CostManager, log_and_reraise from metagpt.provider.zhipuai.zhipu_model_api import ZhiPuModelAPI @@ -30,6 +31,7 @@ class ZhiPuEvent(Enum): FINISH = "finish" +@register_provider(LLMProviderEnum.ZHIPUAI) class ZhiPuAIGPTAPI(BaseGPTAPI): """ Refs to `https://open.bigmodel.cn/dev/api#chatglm_turbo` diff --git a/metagpt/schema.py b/metagpt/schema.py index b24f114b0..aacc2cebb 100644 --- a/metagpt/schema.py +++ b/metagpt/schema.py @@ -164,8 +164,8 @@ class Message(BaseModel): @handle_exception(exception_type=JSONDecodeError, default_return=None) def load(val): """Convert the json string to object.""" - d = json.loads(val) - return Message(**d) + i = json.loads(val) + return Message(**i) class UserMessage(Message): @@ -247,16 +247,16 @@ class MessageQueue: return json.dumps(lst) @staticmethod - def load(i) -> "MessageQueue": + def load(data) -> "MessageQueue": """Convert the json string to the `MessageQueue` object.""" queue = MessageQueue() try: - lst = json.loads(i) + lst = json.loads(data) for i in lst: msg = Message(**i) queue.push(msg) except JSONDecodeError as e: - logger.warning(f"JSON load failed: {i}, error:{e}") + logger.warning(f"JSON load failed: {data}, error:{e}") return queue From 25b8a6dcef768ed1e45489e2dd3a5462f37fd593 Mon Sep 17 00:00:00 2001 From: geekan Date: Tue, 19 Dec 2023 18:02:51 +0800 Subject: [PATCH 058/167] make registry work --- metagpt/llm.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/metagpt/llm.py b/metagpt/llm.py index e0c0716de..60f110a00 100644 --- a/metagpt/llm.py +++ b/metagpt/llm.py @@ -9,7 +9,7 @@ from metagpt.config import CONFIG from metagpt.provider.base_gpt_api import BaseGPTAPI from metagpt.provider.human_provider import HumanProvider -from metagpt.provider.llm_provider_registry import LLMProviderRegistry +from metagpt.provider.llm_provider_registry import LLM_REGISTRY _ = HumanProvider() # Avoid pre-commit error @@ -17,4 +17,4 @@ _ = HumanProvider() # Avoid pre-commit error def LLM() -> BaseGPTAPI: """initialize different LLM instance according to the key field existence""" # TODO a little trick, can use registry to initialize LLM instance further - return LLMProviderRegistry.get_provider(CONFIG.get_default_llm_provider_enum()) + return LLM_REGISTRY.get_provider(CONFIG.get_default_llm_provider_enum()) From 77735d6e612422911dedd86c40aebb2b7c69dcb3 Mon Sep 17 00:00:00 2001 From: geekan Date: Tue, 19 Dec 2023 18:04:12 +0800 Subject: [PATCH 059/167] make registry work --- metagpt/llm.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/metagpt/llm.py b/metagpt/llm.py index 60f110a00..8763642f0 100644 --- a/metagpt/llm.py +++ b/metagpt/llm.py @@ -6,7 +6,7 @@ @File : llm.py """ -from metagpt.config import CONFIG +from metagpt.config import CONFIG, LLMProviderEnum from metagpt.provider.base_gpt_api import BaseGPTAPI from metagpt.provider.human_provider import HumanProvider from metagpt.provider.llm_provider_registry import LLM_REGISTRY @@ -14,7 +14,6 @@ from metagpt.provider.llm_provider_registry import LLM_REGISTRY _ = HumanProvider() # Avoid pre-commit error -def LLM() -> BaseGPTAPI: - """initialize different LLM instance according to the key field existence""" - # TODO a little trick, can use registry to initialize LLM instance further - return LLM_REGISTRY.get_provider(CONFIG.get_default_llm_provider_enum()) +def LLM(provider: LLMProviderEnum = CONFIG.get_default_llm_provider_enum()) -> BaseGPTAPI: + """get the default llm provider""" + return LLM_REGISTRY.get_provider(provider) From 3baf47a3d64ebf9278ec5bee5e6ec524fdf9f666 Mon Sep 17 00:00:00 2001 From: geekan Date: Tue, 19 Dec 2023 18:50:55 +0800 Subject: [PATCH 060/167] refine code for isinstance --- metagpt/actions/write_prd.py | 2 +- metagpt/roles/role.py | 2 +- metagpt/roles/searcher.py | 2 +- metagpt/utils/common.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/metagpt/actions/write_prd.py b/metagpt/actions/write_prd.py index bb0cf8fb9..adba7decb 100644 --- a/metagpt/actions/write_prd.py +++ b/metagpt/actions/write_prd.py @@ -182,7 +182,7 @@ class WritePRD(Action): return if not CONFIG.project_name: - if isinstance(prd, ActionOutput) or isinstance(prd, ActionNode): + if isinstance(prd, (ActionOutput, ActionNode)): ws_name = prd.instruct_content.dict()["Project Name"] else: ws_name = CodeParser.parse_str(block="Project Name", text=prd) diff --git a/metagpt/roles/role.py b/metagpt/roles/role.py index 48688ad5f..e13bf454b 100644 --- a/metagpt/roles/role.py +++ b/metagpt/roles/role.py @@ -267,7 +267,7 @@ class Role: async def _act(self) -> Message: logger.info(f"{self._setting}: ready to {self._rc.todo}") response = await self._rc.todo.run(self._rc.important_memory) - if isinstance(response, ActionOutput) or isinstance(response, ActionNode): + if isinstance(response, (ActionOutput, ActionNode)): msg = Message( content=response.content, instruct_content=response.instruct_content, diff --git a/metagpt/roles/searcher.py b/metagpt/roles/searcher.py index 5760202ff..31de8e896 100644 --- a/metagpt/roles/searcher.py +++ b/metagpt/roles/searcher.py @@ -59,7 +59,7 @@ class Searcher(Role): logger.info(f"{self._setting}: ready to {self._rc.todo}") response = await self._rc.todo.run(self._rc.memory.get(k=0)) - if isinstance(response, ActionOutput) or isinstance(response, ActionNode): + if isinstance(response, (ActionOutput, ActionNode)): msg = Message( content=response.content, instruct_content=response.instruct_content, diff --git a/metagpt/utils/common.py b/metagpt/utils/common.py index bf435b74f..fa18694e3 100644 --- a/metagpt/utils/common.py +++ b/metagpt/utils/common.py @@ -197,7 +197,7 @@ class OutputParser: result = ast.literal_eval(structure_text) # Ensure the result matches the specified data type - if isinstance(result, list) or isinstance(result, dict): + if isinstance(result, (list, dict)): return result raise ValueError(f"The extracted structure is not a {data_type}.") From 5aa4ef5d836771b3335ded771626e44dfce74c2c Mon Sep 17 00:00:00 2001 From: geekan Date: Tue, 19 Dec 2023 18:54:04 +0800 Subject: [PATCH 061/167] fix typo --- metagpt/config.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/metagpt/config.py b/metagpt/config.py index d4e85ca7b..766024222 100644 --- a/metagpt/config.py +++ b/metagpt/config.py @@ -55,8 +55,7 @@ class Config(metaclass=Singleton): default_yaml_file = METAGPT_ROOT / "config/config.yaml" def __init__(self, yaml_file=default_yaml_file): - - golbal_options = OPTIONS.get() + global_options = OPTIONS.get() # cli paras self.project_path = "" self.project_name = "" @@ -66,7 +65,7 @@ class Config(metaclass=Singleton): self._init_with_config_files_and_env(yaml_file) self._update() - golbal_options.update(OPTIONS.get()) + global_options.update(OPTIONS.get()) logger.debug("Config loading done.") logger.info(f"OpenAI API Model: {self.openai_api_model}") From 9d1b628bce1de85b401bbb781c75707f7774dfba Mon Sep 17 00:00:00 2001 From: geekan Date: Tue, 19 Dec 2023 19:00:20 +0800 Subject: [PATCH 062/167] refine cli --- metagpt/startup.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/metagpt/startup.py b/metagpt/startup.py index a89b9c5e9..d6f3397bc 100644 --- a/metagpt/startup.py +++ b/metagpt/startup.py @@ -6,7 +6,7 @@ import typer from metagpt.config import CONFIG -app = typer.Typer() +app = typer.Typer(add_completion=False) @app.command() @@ -23,7 +23,9 @@ def startup( default="", help="Specify the directory path of the old version project to fulfill the " "incremental requirements.", ), - reqa_file: str = typer.Option(default="", help="Specify the source file name for rewriting the quality test code."), + reqa_file: str = typer.Option( + default="", help="Specify the source file name for rewriting the quality assurance " "code." + ), max_auto_summarize_code: int = typer.Option( default=0, help="The maximum number of times the 'SummarizeCode' action is automatically invoked, with -1 indicating " From 505133cacc587c5894f10bed149d774c41b857e2 Mon Sep 17 00:00:00 2001 From: geekan Date: Tue, 19 Dec 2023 19:00:39 +0800 Subject: [PATCH 063/167] refine cli --- metagpt/startup.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/metagpt/startup.py b/metagpt/startup.py index d6f3397bc..a1af90ffc 100644 --- a/metagpt/startup.py +++ b/metagpt/startup.py @@ -21,10 +21,10 @@ def startup( inc: bool = typer.Option(default=False, help="Incremental mode. Use it to coop with existing repo."), project_path: str = typer.Option( default="", - help="Specify the directory path of the old version project to fulfill the " "incremental requirements.", + help="Specify the directory path of the old version project to fulfill the incremental requirements.", ), reqa_file: str = typer.Option( - default="", help="Specify the source file name for rewriting the quality assurance " "code." + default="", help="Specify the source file name for rewriting the quality assurance code." ), max_auto_summarize_code: int = typer.Option( default=0, From 6dfa4e2c9e44d8db8e8e1c67646ae88d4547c968 Mon Sep 17 00:00:00 2001 From: geekan Date: Tue, 19 Dec 2023 19:15:30 +0800 Subject: [PATCH 064/167] fix pylint --- examples/agent_creator.py | 9 ++++----- metagpt/memory/longterm_memory.py | 10 +++++----- metagpt/memory/memory_storage.py | 2 +- metagpt/roles/product_manager.py | 2 +- metagpt/roles/qa_engineer.py | 2 +- 5 files changed, 12 insertions(+), 13 deletions(-) diff --git a/examples/agent_creator.py b/examples/agent_creator.py index 05417d24a..26af8a287 100644 --- a/examples/agent_creator.py +++ b/examples/agent_creator.py @@ -12,9 +12,8 @@ from metagpt.logs import logger from metagpt.roles import Role from metagpt.schema import Message -with open(METAGPT_ROOT / "examples/build_customized_agent.py", "r") as f: - # use official example script to guide AgentCreator - MULTI_ACTION_AGENT_CODE_EXAMPLE = f.read() +EXAMPLE_CODE_FILE = METAGPT_ROOT / "examples/build_customized_agent.py" +MULTI_ACTION_AGENT_CODE_EXAMPLE = EXAMPLE_CODE_FILE.read_text() class CreateAgent(Action): @@ -50,8 +49,8 @@ class CreateAgent(Action): match = re.search(pattern, rsp, re.DOTALL) code_text = match.group(1) if match else "" CONFIG.workspace_path.mkdir(parents=True, exist_ok=True) - with open(CONFIG.workspace_path / "agent_created_agent.py", "w") as f: - f.write(code_text) + new_file = CONFIG.workspace_path / "agent_created_agent.py" + new_file.write_text(code_text) return code_text diff --git a/metagpt/memory/longterm_memory.py b/metagpt/memory/longterm_memory.py index 22032a86e..ab2214261 100644 --- a/metagpt/memory/longterm_memory.py +++ b/metagpt/memory/longterm_memory.py @@ -19,7 +19,7 @@ class LongTermMemory(Memory): def __init__(self): self.memory_storage: MemoryStorage = MemoryStorage() - super(LongTermMemory, self).__init__() + super().__init__() self.rc = None # RoleContext self.msg_from_recover = False @@ -37,7 +37,7 @@ class LongTermMemory(Memory): self.msg_from_recover = False def add(self, message: Message): - super(LongTermMemory, self).add(message) + super().add(message) for action in self.rc.watch: if message.cause_by == action and not self.msg_from_recover: # currently, only add role's watching messages to its memory_storage @@ -50,7 +50,7 @@ class LongTermMemory(Memory): 1. find the short-term memory(stm) news 2. furthermore, filter out similar messages based on ltm(long-term memory), get the final news """ - stm_news = super(LongTermMemory, self).find_news(observed, k=k) # shot-term memory news + stm_news = super().find_news(observed, k=k) # shot-term memory news if not self.memory_storage.is_initialized: # memory_storage hasn't initialized, use default `find_news` to get stm_news return stm_news @@ -64,9 +64,9 @@ class LongTermMemory(Memory): return ltm_news[-k:] def delete(self, message: Message): - super(LongTermMemory, self).delete(message) + super().delete(message) # TODO delete message in memory_storage def clear(self): - super(LongTermMemory, self).clear() + super().clear() self.memory_storage.clean() diff --git a/metagpt/memory/memory_storage.py b/metagpt/memory/memory_storage.py index a213f6d7a..fafb33568 100644 --- a/metagpt/memory/memory_storage.py +++ b/metagpt/memory/memory_storage.py @@ -58,7 +58,7 @@ class MemoryStorage(FaissStore): return index_fpath, storage_fpath def persist(self): - super(MemoryStorage, self).persist() + super().persist() logger.debug(f"Agent {self.role_id} persist memory into local") def add(self, message: Message) -> bool: diff --git a/metagpt/roles/product_manager.py b/metagpt/roles/product_manager.py index e5e9f2b5e..7858d2caa 100644 --- a/metagpt/roles/product_manager.py +++ b/metagpt/roles/product_manager.py @@ -54,4 +54,4 @@ class ProductManager(Role): return self._rc.todo async def _observe(self, ignore_memory=False) -> int: - return await super(ProductManager, self)._observe(ignore_memory=True) + return await super()._observe(ignore_memory=True) diff --git a/metagpt/roles/qa_engineer.py b/metagpt/roles/qa_engineer.py index 4439b9b19..71b474a3b 100644 --- a/metagpt/roles/qa_engineer.py +++ b/metagpt/roles/qa_engineer.py @@ -178,4 +178,4 @@ class QaEngineer(Role): async def _observe(self, ignore_memory=False) -> int: # This role has events that trigger and execute themselves based on conditions, and cannot rely on the # content of memory to activate. - return await super(QaEngineer, self)._observe(ignore_memory=True) + return await super()._observe(ignore_memory=True) From c12cd7b9c6bd2d900fbd70072cd9731b86486e1b Mon Sep 17 00:00:00 2001 From: geekan Date: Tue, 19 Dec 2023 19:25:01 +0800 Subject: [PATCH 065/167] refine code --- metagpt/config.py | 30 ++++++++++++++---------------- 1 file changed, 14 insertions(+), 16 deletions(-) diff --git a/metagpt/config.py b/metagpt/config.py index 766024222..80a3a28f4 100644 --- a/metagpt/config.py +++ b/metagpt/config.py @@ -67,25 +67,23 @@ class Config(metaclass=Singleton): self._update() global_options.update(OPTIONS.get()) logger.debug("Config loading done.") - logger.info(f"OpenAI API Model: {self.openai_api_model}") - def get_default_llm_provider_enum(self): - if self._is_valid_llm_key(self.openai_api_key): - llm = LLMProviderEnum.OPENAI - elif self._is_valid_llm_key(self.anthropic_api_key): - llm = LLMProviderEnum.ANTHROPIC - elif self._is_valid_llm_key(self.zhipuai_api_key): - llm = LLMProviderEnum.ZHIPUAI - elif self._is_valid_llm_key(self.fireworks_api_key): - llm = LLMProviderEnum.FIREWORKS - elif self.open_llm_api_base: - llm = LLMProviderEnum.OPEN_LLM - else: - raise NotConfiguredException("You should config a LLM configuration first") - return llm + def get_default_llm_provider_enum(self) -> LLMProviderEnum: + for k, v in [ + (self.openai_api_key, LLMProviderEnum.OPENAI), + (self.anthropic_api_key, LLMProviderEnum.ANTHROPIC), + (self.zhipuai_api_key, LLMProviderEnum.ZHIPUAI), + (self.fireworks_api_key, LLMProviderEnum.FIREWORKS), + (self.open_llm_api_base, LLMProviderEnum.OPEN_LLM), # reuse logic. but not a key + ]: + if self._is_valid_llm_key(k): + if self.openai_api_model: + logger.info(f"OpenAI API Model: {self.openai_api_model}") + return v + raise NotConfiguredException("You should config a LLM configuration first") @staticmethod - def _is_valid_llm_key(k) -> bool: + def _is_valid_llm_key(k: str) -> bool: return k and k != "YOUR_API_KEY" def _update(self): From edb90690263b5b0aa91ecdf61e94476e6ff613c4 Mon Sep 17 00:00:00 2001 From: geekan Date: Tue, 19 Dec 2023 19:26:01 +0800 Subject: [PATCH 066/167] delete manager.py --- metagpt/manager.py | 66 ---------------------------------------------- 1 file changed, 66 deletions(-) delete mode 100644 metagpt/manager.py diff --git a/metagpt/manager.py b/metagpt/manager.py deleted file mode 100644 index a063608be..000000000 --- a/metagpt/manager.py +++ /dev/null @@ -1,66 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- -""" -@Time : 2023/5/11 14:42 -@Author : alexanderwu -@File : manager.py -""" -from metagpt.llm import LLM -from metagpt.logs import logger -from metagpt.schema import Message - - -class Manager: - def __init__(self, llm: LLM = LLM()): - self.llm = llm # Large Language Model - self.role_directions = { - "User": "Product Manager", - "Product Manager": "Architect", - "Architect": "Engineer", - "Engineer": "QA Engineer", - "QA Engineer": "Product Manager", - } - self.prompt_template = """ - Given the following message: - {message} - - And the current status of roles: - {roles} - - Which role should handle this message? - """ - - async def handle(self, message: Message, environment): - """ - 管理员处理信息,现在简单的将信息递交给下一个人 - The administrator processes the information, now simply passes the information on to the next person - :param message: - :param environment: - :return: - """ - # Get all roles from the environment - roles = environment.get_roles() - # logger.debug(f"{roles=}, {message=}") - - # Build a context for the LLM to understand the situation - # context = { - # "message": str(message), - # "roles": {role.name: role.get_info() for role in roles}, - # } - # Ask the LLM to decide which role should handle the message - # chosen_role_name = self.llm.ask(self.prompt_template.format(context)) - - # FIXME: 现在通过简单的字典决定流向,但之后还是应该有思考过程 - # The direction of flow is now determined by a simple dictionary, but there should still be a thought process afterwards - next_role_profile = self.role_directions[message.role] - # logger.debug(f"{next_role_profile}") - for _, role in roles.items(): - if next_role_profile == role.profile: - next_role = role - break - else: - logger.error(f"No available role can handle message: {message}.") - return - - # Find the chosen role and handle the message - return await next_role.handle(message) From 8a1237460eb1afd77be3d8db6d61adbcdcf271a2 Mon Sep 17 00:00:00 2001 From: geekan Date: Tue, 19 Dec 2023 19:27:11 +0800 Subject: [PATCH 067/167] remove useless fields --- metagpt/actions/action.py | 12 +----------- metagpt/actions/search_and_summarize.py | 3 +-- metagpt/roles/role.py | 2 +- 3 files changed, 3 insertions(+), 14 deletions(-) diff --git a/metagpt/actions/action.py b/metagpt/actions/action.py index 7bb26ea91..1292b6684 100644 --- a/metagpt/actions/action.py +++ b/metagpt/actions/action.py @@ -26,22 +26,12 @@ class Action(ABC): self.llm = llm self.context = context self.prefix = "" # aask*时会加上prefix,作为system_message - self.profile = "" # FIXME: USELESS self.desc = "" # for skill manager self.nodes = ... - # Output, useless - # self.content = "" - # self.instruct_content = None - # self.env = None - - # def set_env(self, env): - # self.env = env - - def set_prefix(self, prefix, profile): + def set_prefix(self, prefix): """Set prefix for later usage""" self.prefix = prefix - self.profile = profile return self def __str__(self): diff --git a/metagpt/actions/search_and_summarize.py b/metagpt/actions/search_and_summarize.py index 5e4cdaea0..a1d81bc65 100644 --- a/metagpt/actions/search_and_summarize.py +++ b/metagpt/actions/search_and_summarize.py @@ -130,8 +130,7 @@ class SearchAndSummarize(Action): system_prompt = [system_text] prompt = SEARCH_AND_SUMMARIZE_PROMPT.format( - # PREFIX = self.prefix, - ROLE=self.profile, + ROLE=self.prefix, CONTEXT=rsp, QUERY_HISTORY="\n".join([str(i) for i in context[:-1]]), QUERY=str(context[-1]), diff --git a/metagpt/roles/role.py b/metagpt/roles/role.py index e13bf454b..bf37a6637 100644 --- a/metagpt/roles/role.py +++ b/metagpt/roles/role.py @@ -146,7 +146,7 @@ class Role: self._actions = [] def _init_action_system_message(self, action: Action): - action.set_prefix(self._get_prefix(), self.profile) + action.set_prefix(self._get_prefix()) def _init_actions(self, actions): self._reset() From f0fd5ac59bd8be8e0083aa89a5d38d7cf3c3d639 Mon Sep 17 00:00:00 2001 From: geekan Date: Tue, 19 Dec 2023 21:17:02 +0800 Subject: [PATCH 068/167] refine a lot of code, fix pylint, use actionnode include ui, action _aask_v1, detail_mining, prepare_interview, etc. --- metagpt/actions/action.py | 48 +++----- metagpt/actions/action_node.py | 81 +++++--------- metagpt/actions/design_api.py | 10 +- metagpt/actions/detail_mining.py | 50 +++------ metagpt/actions/prepare_interview.py | 35 ++---- metagpt/actions/project_management.py | 10 +- metagpt/actions/write_prd.py | 8 +- metagpt/config.py | 2 +- metagpt/utils/get_template.py | 6 +- tests/metagpt/actions/test_detail_mining.py | 4 +- .../metagpt/actions/test_prepare_interview.py | 21 ++++ tests/metagpt/roles/ui_role.py | 104 +++++++++--------- 12 files changed, 163 insertions(+), 216 deletions(-) create mode 100644 tests/metagpt/actions/test_prepare_interview.py diff --git a/metagpt/actions/action.py b/metagpt/actions/action.py index 1292b6684..5c5884e8b 100644 --- a/metagpt/actions/action.py +++ b/metagpt/actions/action.py @@ -6,19 +6,26 @@ @File : action.py """ +from __future__ import annotations + from abc import ABC from typing import Optional -from tenacity import retry, stop_after_attempt, wait_random_exponential - -from metagpt.actions.action_output import ActionOutput +from metagpt.actions.action_node import ActionNode from metagpt.llm import LLM -from metagpt.logs import logger -from metagpt.provider.postprecess.llm_output_postprecess import llm_output_postprecess -from metagpt.utils.common import OutputParser, general_after_log +from metagpt.schema import BaseContext class Action(ABC): + """Action abstract class, requiring all inheritors to provide a series of standard capabilities""" + + name: str + llm: LLM + context: dict | BaseContext | str | None + prefix: str + desc: str + node: ActionNode | None + def __init__(self, name: str = "", context=None, llm: LLM = None): self.name: str = name if llm is None: @@ -27,7 +34,7 @@ class Action(ABC): self.context = context self.prefix = "" # aask*时会加上prefix,作为system_message self.desc = "" # for skill manager - self.nodes = ... + self.node = None def set_prefix(self, prefix): """Set prefix for later usage""" @@ -47,33 +54,6 @@ class Action(ABC): system_msgs.append(self.prefix) return await self.llm.aask(prompt, system_msgs) - @retry( - wait=wait_random_exponential(min=1, max=60), - stop=stop_after_attempt(6), - after=general_after_log(logger), - ) - async def _aask_v1( - self, - prompt: str, - output_class_name: str, - output_data_mapping: dict, - system_msgs: Optional[list[str]] = None, - format="markdown", # compatible to original format - ) -> ActionOutput: - content = await self.llm.aask(prompt, system_msgs) - logger.debug(f"llm raw output:\n{content}") - output_class = ActionOutput.create_model_class(output_class_name, output_data_mapping) - - if format == "json": - parsed_data = llm_output_postprecess(output=content, schema=output_class.schema(), req_key="[/CONTENT]") - - else: # using markdown parser - parsed_data = OutputParser.parse_data_with_mapping(content, output_data_mapping) - - logger.debug(f"parsed_data:\n{parsed_data}") - instruct_content = output_class(**parsed_data) - return ActionOutput(content, instruct_content) - async def run(self, *args, **kwargs): """Run action""" raise NotImplementedError("The run method should be implemented in a subclass.") diff --git a/metagpt/actions/action_node.py b/metagpt/actions/action_node.py index 6f1215920..0368d2df1 100644 --- a/metagpt/actions/action_node.py +++ b/metagpt/actions/action_node.py @@ -6,17 +6,15 @@ @File : action_node.py """ import json -import re -from typing import Any, Dict, List, Optional, Type +from typing import Dict, Generic, List, Optional, Type, TypeVar from pydantic import BaseModel, create_model, root_validator, validator from tenacity import retry, stop_after_attempt, wait_random_exponential -from metagpt.actions import ActionOutput from metagpt.llm import BaseGPTAPI from metagpt.logs import logger -from metagpt.utils.common import OutputParser -from metagpt.utils.custom_decoder import CustomDecoder +from metagpt.provider.postprecess.llm_output_postprecess import llm_output_postprecess +from metagpt.utils.common import OutputParser, general_after_log CONSTRAINT = """ - Language: Please use the same language as the user input. @@ -43,14 +41,17 @@ Fill in the above nodes based on the format example. """ -def dict_to_markdown(d, prefix="###", postfix="\n"): +def dict_to_markdown(d, prefix="-", postfix="\n"): markdown_str = "" for key, value in d.items(): markdown_str += f"{prefix} {key}: {value}{postfix}" return markdown_str -class ActionNode: +T = TypeVar("T") + + +class ActionNode(Generic[T]): """ActionNode is a tree of nodes.""" mode: str @@ -65,7 +66,7 @@ class ActionNode: expected_type: Type # such as str / int / float etc. # context: str # everything in the history. instruction: str # the instructions should be followed. - example: Any # example for In Context-Learning. + example: T # example for In Context-Learning. # Action Output content: str @@ -76,7 +77,7 @@ class ActionNode: key: str, expected_type: Type, instruction: str, - example: str, + example: T, content: str = "", children: dict[str, "ActionNode"] = None, ): @@ -148,29 +149,6 @@ class ActionNode: new_class.__root_validator_check_missing_fields = classmethod(check_missing_fields) return new_class - @classmethod - def create_model_class_v2(cls, class_name: str, mapping: Dict[str, Type]): - """基于pydantic v2的模型动态生成,用来检验结果类型正确性,待验证""" - new_class = create_model(class_name, **mapping) - - @model_validator(mode="before") - def check_missing_fields(data): - required_fields = set(mapping.keys()) - missing_fields = required_fields - set(data.keys()) - if missing_fields: - raise ValueError(f"Missing fields: {missing_fields}") - return data - - @field_validator("*") - def check_name(v: Any, field: str) -> Any: - if field not in mapping.keys(): - raise ValueError(f"Unrecognized block: {field}") - return v - - new_class.__model_validator_check_missing_fields = classmethod(check_missing_fields) - new_class.__field_validator_check_name = classmethod(check_name) - return new_class - def create_children_class(self): """使用object内有的字段直接生成model_class""" class_name = f"{self.key}_AN" @@ -245,6 +223,7 @@ class ActionNode: """ # FIXME: json instruction会带来格式问题,如:"Project name": "web_2048 # 项目名称使用下划线", + # compile example暂时不支持markdown self.instruction = self.compile_instruction(to="markdown", mode=mode) self.example = self.compile_example(to=to, tag="CONTENT", mode=mode) prompt = template.format( @@ -252,36 +231,32 @@ class ActionNode: ) return prompt - @retry(wait=wait_random_exponential(min=1, max=10), stop=stop_after_attempt(6)) + @retry( + wait=wait_random_exponential(min=1, max=60), + stop=stop_after_attempt(6), + after=general_after_log(logger), + ) async def _aask_v1( self, prompt: str, output_class_name: str, output_data_mapping: dict, system_msgs: Optional[list[str]] = None, - format="markdown", # compatible to original format - ) -> ActionOutput: + schema="markdown", # compatible to original format + ) -> (str, BaseModel): + """Use ActionOutput to wrap the output of aask""" content = await self.llm.aask(prompt, system_msgs) - logger.debug(content) - output_class = ActionOutput.create_model_class(output_class_name, output_data_mapping) - - if format == "json": - pattern = r"\[CONTENT\](\s*\{.*?\}\s*)\[/CONTENT\]" - matches = re.findall(pattern, content, re.DOTALL) - - for match in matches: - if match: - content = match - break - - parsed_data = CustomDecoder(strict=False).decode(content) + logger.debug(f"llm raw output:\n{content}") + output_class = self.create_model_class(output_class_name, output_data_mapping) + if schema == "json": + parsed_data = llm_output_postprecess(output=content, schema=output_class.schema(), req_key="[/CONTENT]") else: # using markdown parser parsed_data = OutputParser.parse_data_with_mapping(content, output_data_mapping) - logger.debug(parsed_data) + logger.debug(f"parsed_data:\n{parsed_data}") instruct_content = output_class(**parsed_data) - return ActionOutput(content, instruct_content) + return content, instruct_content def get(self, key): return self.instruct_content.dict()[key] @@ -302,9 +277,9 @@ class ActionNode: mapping = self.get_mapping(mode) class_name = f"{self.key}_AN" - output = await self._aask_v1(prompt, class_name, mapping, format=to) - self.content = output.content - self.instruct_content = output.instruct_content + content, scontent = await self._aask_v1(prompt, class_name, mapping, schema=to) + self.content = content + self.instruct_content = scontent return self async def fill(self, context, llm, to="json", mode="auto", strgy="simple"): diff --git a/metagpt/actions/design_api.py b/metagpt/actions/design_api.py index 5a5f52de7..f757ca856 100644 --- a/metagpt/actions/design_api.py +++ b/metagpt/actions/design_api.py @@ -50,7 +50,7 @@ class WriteDesign(Action): "clearly and in detail." ) - async def run(self, with_messages, format=CONFIG.prompt_format): + async def run(self, with_messages, schema=CONFIG.prompt_schema): # Use `git diff` to identify which PRD documents have been modified in the `docs/prds` directory. prds_file_repo = CONFIG.git_repo.new_file_repository(PRDS_FILE_REPO) changed_prds = prds_file_repo.changed_files @@ -80,13 +80,13 @@ class WriteDesign(Action): # leaving room for global optimization in subsequent steps. return ActionOutput(content=changed_files.json(), instruct_content=changed_files) - async def _new_system_design(self, context, format=CONFIG.prompt_format): - node = await DESIGN_API_NODE.fill(context=context, llm=self.llm, to=format) + async def _new_system_design(self, context, schema=CONFIG.prompt_schema): + node = await DESIGN_API_NODE.fill(context=context, llm=self.llm, to=schema) return node - async def _merge(self, prd_doc, system_design_doc, format=CONFIG.prompt_format): + async def _merge(self, prd_doc, system_design_doc, schema=CONFIG.prompt_schema): context = NEW_REQ_TEMPLATE.format(old_design=system_design_doc.content, context=prd_doc.content) - node = await DESIGN_API_NODE.fill(context=context, llm=self.llm, to=format) + node = await DESIGN_API_NODE.fill(context=context, llm=self.llm, to=schema) system_design_doc.content = node.instruct_content.json(ensure_ascii=False) return system_design_doc diff --git a/metagpt/actions/detail_mining.py b/metagpt/actions/detail_mining.py index 5afcf52c6..0314d30dd 100644 --- a/metagpt/actions/detail_mining.py +++ b/metagpt/actions/detail_mining.py @@ -5,47 +5,31 @@ @Author : fisherdeng @File : detail_mining.py """ -from metagpt.actions import Action, ActionOutput +from metagpt.actions import Action +from metagpt.actions.action_node import ActionNode -PROMPT_TEMPLATE = """ -##TOPIC +CONTEXT_TEMPLATE = """ +## TOPIC {topic} -##RECORD +## RECORD {record} - -##Format example -{format_example} ------ - -Task: Refer to the "##TOPIC" (discussion objectives) and "##RECORD" (discussion records) to further inquire about the details that interest you, within a word limit of 150 words. -Special Note 1: Your intention is solely to ask questions without endorsing or negating any individual's viewpoints. -Special Note 2: This output should only include the topic "##OUTPUT". Do not add, remove, or modify the topic. Begin the output with '##OUTPUT', followed by an immediate line break, and then proceed to provide the content in the specified format as outlined in the "##Format example" section. -Special Note 3: The output should be in the same language as the input. """ -FORMAT_EXAMPLE = """ -## - -##OUTPUT -...(Please provide the specific details you would like to inquire about here.) - -## - -## -""" -OUTPUT_MAPPING = { - "OUTPUT": (str, ...), -} +QUESTIONS = ActionNode( + key="Questions", + expected_type=list[str], + instruction="Task: Refer to the context to further inquire about the details that interest you, within a word limit" + " of 150 words. Please provide the specific details you would like to inquire about here", + example=["1. What ...", "2. How ...", "3. ..."], +) class DetailMining(Action): - """This class allows LLM to further mine noteworthy details based on specific "##TOPIC"(discussion topic) and "##RECORD" (discussion records), thereby deepening the discussion.""" + """This class allows LLM to further mine noteworthy details based on specific "##TOPIC"(discussion topic) and + "##RECORD" (discussion records), thereby deepening the discussion.""" - def __init__(self, name="", context=None, llm=None): - super().__init__(name, context, llm) - - async def run(self, topic, record) -> ActionOutput: - prompt = PROMPT_TEMPLATE.format(topic=topic, record=record, format_example=FORMAT_EXAMPLE) - rsp = await self._aask_v1(prompt, "detail_mining", OUTPUT_MAPPING) + async def run(self, topic, record): + context = CONTEXT_TEMPLATE.format(topic=topic, record=record) + rsp = await QUESTIONS.fill(context=context, llm=self.llm) return rsp diff --git a/metagpt/actions/prepare_interview.py b/metagpt/actions/prepare_interview.py index b2704616e..7ed42d590 100644 --- a/metagpt/actions/prepare_interview.py +++ b/metagpt/actions/prepare_interview.py @@ -6,35 +6,18 @@ @File : prepare_interview.py """ from metagpt.actions import Action +from metagpt.actions.action_node import ActionNode -PROMPT_TEMPLATE = """ -# Context -{context} - -## Format example ---- -Q1: question 1 here -References: - - point 1 - - point 2 - -Q2: question 2 here... ---- - ------ -Role: You are an interviewer of our company who is well-knonwn in frontend or backend develop; +QUESTIONS = ActionNode( + key="Questions", + expected_type=list[str], + instruction="""Role: You are an interviewer of our company who is well-knonwn in frontend or backend develop; Requirement: Provide a list of questions for the interviewer to ask the interviewee, by reading the resume of the interviewee in the context. -Attention: Provide as markdown block as the format above, at least 10 questions. -""" - -# prepare for a interview +Attention: Provide as markdown block as the format above, at least 10 questions.""", + example=["1. What ...", "2. How ..."], +) class PrepareInterview(Action): - def __init__(self, name, context=None, llm=None): - super().__init__(name, context, llm) - async def run(self, context): - prompt = PROMPT_TEMPLATE.format(context=context) - question_list = await self._aask_v1(prompt) - return question_list + return await QUESTIONS.fill(context=context, llm=self.llm) diff --git a/metagpt/actions/project_management.py b/metagpt/actions/project_management.py index 1f14e7944..fe2c8d537 100644 --- a/metagpt/actions/project_management.py +++ b/metagpt/actions/project_management.py @@ -42,7 +42,7 @@ class WriteTasks(Action): def __init__(self, name="CreateTasks", context=None, llm=None): super().__init__(name, context, llm) - async def run(self, with_messages, format=CONFIG.prompt_format): + async def run(self, with_messages, schema=CONFIG.prompt_schema): system_design_file_repo = CONFIG.git_repo.new_file_repository(SYSTEM_DESIGN_FILE_REPO) changed_system_designs = system_design_file_repo.changed_files @@ -89,16 +89,16 @@ class WriteTasks(Action): await self._save_pdf(task_doc=task_doc) return task_doc - async def _run_new_tasks(self, context, format=CONFIG.prompt_format): - node = await PM_NODE.fill(context, self.llm, format) + async def _run_new_tasks(self, context, schema=CONFIG.prompt_schema): + node = await PM_NODE.fill(context, self.llm, schema) # prompt_template, format_example = get_template(templates, format) # prompt = prompt_template.format(context=context, format_example=format_example) # rsp = await self._aask_v1(prompt, "task", OUTPUT_MAPPING, format=format) return node - async def _merge(self, system_design_doc, task_doc, format=CONFIG.prompt_format) -> Document: + async def _merge(self, system_design_doc, task_doc, schema=CONFIG.prompt_schema) -> Document: context = NEW_REQ_TEMPLATE.format(context=system_design_doc.content, old_tasks=task_doc.content) - node = await PM_NODE.fill(context, self.llm, format) + node = await PM_NODE.fill(context, self.llm, schema) task_doc.content = node.instruct_content.json(ensure_ascii=False) return task_doc diff --git a/metagpt/actions/write_prd.py b/metagpt/actions/write_prd.py index adba7decb..1cf21dbb7 100644 --- a/metagpt/actions/write_prd.py +++ b/metagpt/actions/write_prd.py @@ -111,7 +111,7 @@ class WritePRD(Action): # optimization in subsequent steps. return ActionOutput(content=change_files.json(), instruct_content=change_files) - async def _run_new_requirement(self, requirements, format=CONFIG.prompt_format) -> ActionOutput: + async def _run_new_requirement(self, requirements, schema=CONFIG.prompt_schema) -> ActionOutput: # sas = SearchAndSummarize() # # rsp = await sas.run(context=requirements, system_text=SEARCH_AND_SUMMARIZE_SYSTEM_EN_US) # rsp = "" @@ -121,7 +121,7 @@ class WritePRD(Action): # logger.info(rsp) project_name = CONFIG.project_name if CONFIG.project_name else "" context = CONTEXT_TEMPLATE.format(requirements=requirements, project_name=project_name) - node = await WRITE_PRD_NODE.fill(context=context, llm=self.llm, to=format) + node = await WRITE_PRD_NODE.fill(context=context, llm=self.llm, to=schema) await self._rename_workspace(node) return node @@ -130,11 +130,11 @@ class WritePRD(Action): node = await WP_IS_RELATIVE_NODE.fill(context, self.llm) return node.get("is_relative") == "YES" - async def _merge(self, new_requirement_doc, prd_doc, format=CONFIG.prompt_format) -> Document: + async def _merge(self, new_requirement_doc, prd_doc, schema=CONFIG.prompt_schema) -> Document: if not CONFIG.project_name: CONFIG.project_name = Path(CONFIG.project_path).name prompt = NEW_REQ_TEMPLATE.format(requirements=new_requirement_doc.content, old_prd=prd_doc.content) - node = await WRITE_PRD_NODE.fill(context=prompt, llm=self.llm, to=format) + node = await WRITE_PRD_NODE.fill(context=prompt, llm=self.llm, to=schema) prd_doc.content = node.instruct_content.json(ensure_ascii=False) await self._rename_workspace(node) return prd_doc diff --git a/metagpt/config.py b/metagpt/config.py index 80a3a28f4..131854a56 100644 --- a/metagpt/config.py +++ b/metagpt/config.py @@ -143,7 +143,7 @@ class Config(metaclass=Singleton): self.pyppeteer_executable_path = self._get("PYPPETEER_EXECUTABLE_PATH", "") self.repair_llm_output = self._get("REPAIR_LLM_OUTPUT", False) - self.prompt_format = self._get("PROMPT_FORMAT", "json") + self.prompt_schema = self._get("PROMPT_FORMAT", "json") self.workspace_path = Path(self._get("WORKSPACE_PATH", DEFAULT_WORKSPACE_ROOT)) self._ensure_workspace_exists() diff --git a/metagpt/utils/get_template.py b/metagpt/utils/get_template.py index 86c1915f7..7e05e5d5e 100644 --- a/metagpt/utils/get_template.py +++ b/metagpt/utils/get_template.py @@ -8,10 +8,10 @@ from metagpt.config import CONFIG -def get_template(templates, format=CONFIG.prompt_format): - selected_templates = templates.get(format) +def get_template(templates, schema=CONFIG.prompt_schema): + selected_templates = templates.get(schema) if selected_templates is None: - raise ValueError(f"Can't find {format} in passed in templates") + raise ValueError(f"Can't find {schema} in passed in templates") # Extract the selected templates prompt_template = selected_templates["PROMPT_TEMPLATE"] diff --git a/tests/metagpt/actions/test_detail_mining.py b/tests/metagpt/actions/test_detail_mining.py index 891dca6ca..30bcf9dfb 100644 --- a/tests/metagpt/actions/test_detail_mining.py +++ b/tests/metagpt/actions/test_detail_mining.py @@ -19,5 +19,5 @@ async def test_detail_mining(): rsp = await detail_mining.run(topic=topic, record=record) logger.info(f"{rsp.content=}") - assert "##OUTPUT" in rsp.content - assert "蛋糕" in rsp.content + assert "Questions" in rsp.content + assert "1." in rsp.content diff --git a/tests/metagpt/actions/test_prepare_interview.py b/tests/metagpt/actions/test_prepare_interview.py new file mode 100644 index 000000000..7c32882e0 --- /dev/null +++ b/tests/metagpt/actions/test_prepare_interview.py @@ -0,0 +1,21 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Time : 2023/9/13 00:26 +@Author : fisherdeng +@File : test_detail_mining.py +""" +import pytest + +from metagpt.actions.prepare_interview import PrepareInterview +from metagpt.logs import logger + + +@pytest.mark.asyncio +async def test_prepare_interview(): + action = PrepareInterview() + rsp = await action.run("I just graduated and hope to find a job as a Python engineer") + logger.info(f"{rsp.content=}") + + assert "Questions" in rsp.content + assert "1." in rsp.content diff --git a/tests/metagpt/roles/ui_role.py b/tests/metagpt/roles/ui_role.py index 8ac799bf3..0932efa1f 100644 --- a/tests/metagpt/roles/ui_role.py +++ b/tests/metagpt/roles/ui_role.py @@ -10,6 +10,7 @@ from importlib import import_module from metagpt.actions import Action, ActionOutput, WritePRD # from metagpt.const import WORKSPACE_ROOT +from metagpt.actions.action_node import ActionNode from metagpt.config import CONFIG from metagpt.logs import logger from metagpt.roles import Role @@ -17,44 +18,38 @@ from metagpt.schema import Message from metagpt.tools.sd_engine import SDEngine PROMPT_TEMPLATE = """ -# Context {context} -## Format example -{format_example} ------ -Role: You are a UserInterface Designer; the goal is to finish a UI design according to PRD, give a design description, and select specified elements and UI style. -Requirements: Based on the context, fill in the following missing information, provide detailed HTML and CSS code -Attention: Use '##' to split sections, not '#', and '## ' SHOULD WRITE BEFORE the code and triple quote. - -## UI Design Description:Provide as Plain text, place the design objective here -## Selected Elements:Provide as Plain text, up to 5 specified elements, clear and simple -## HTML Layout:Provide as Plain text, use standard HTML code -## CSS Styles (styles.css):Provide as Plain text,use standard css code -## Anything UNCLEAR:Provide as Plain text. Try to clarify it. - +## Role +You are a UserInterface Designer; the goal is to finish a UI design according to PRD, give a design description, and select specified elements and UI style. """ -FORMAT_EXAMPLE = """ +UI_DESIGN_DESC = ActionNode( + key="UI Design Desc", + expected_type=str, + instruction="place the design objective here", + example="Snake games are classic and addictive games with simple yet engaging elements. Here are the main elements" + " commonly found in snake games", +) -## UI Design Description -```Snake games are classic and addictive games with simple yet engaging elements. Here are the main elements commonly found in snake games ``` +SELECTED_ELEMENTS = ActionNode( + key="Selected Elements", + expected_type=list[str], + instruction="up to 5 specified elements, clear and simple", + example=[ + "Game Grid: The game grid is a rectangular...", + "Snake: The player controls a snake that moves across the grid...", + "Food: Food items (often represented as small objects or differently colored blocks)", + "Score: The player's score increases each time the snake eats a piece of food. The longer the snake becomes, the higher the score.", + "Game Over: The game ends when the snake collides with itself or an obstacle. At this point, the player's final score is displayed, and they are given the option to restart the game.", + ], +) -## Selected Elements - -Game Grid: The game grid is a rectangular... - -Snake: The player controls a snake that moves across the grid... - -Food: Food items (often represented as small objects or differently colored blocks) - -Score: The player's score increases each time the snake eats a piece of food. The longer the snake becomes, the higher the score. - -Game Over: The game ends when the snake collides with itself or an obstacle. At this point, the player's final score is displayed, and they are given the option to restart the game. - - -## HTML Layout - +HTML_LAYOUT = ActionNode( + key="HTML Layout", + expected_type=str, + instruction="use standard HTML code", + example=""" @@ -71,9 +66,14 @@ Game Over: The game ends when the snake collides with itself or an obstacle. At +""", +) -## CSS Styles (styles.css) -body { +CSS_STYLES = ActionNode( + key="CSS Styles", + expected_type=str, + instruction="use standard css code", + example="""body { display: flex; justify-content: center; align-items: center; @@ -121,19 +121,25 @@ body { color: #ff0000; display: none; } +""", +) -## Anything UNCLEAR -There are no unclear points. +ANYTHING_UNCLEAR = ActionNode( + key="Anything UNCLEAR", + expected_type=str, + instruction="Mention any aspects of the project that are unclear and try to clarify them.", + example="...", +) -""" +NODES = [ + UI_DESIGN_DESC, + SELECTED_ELEMENTS, + HTML_LAYOUT, + CSS_STYLES, + ANYTHING_UNCLEAR, +] -OUTPUT_MAPPING = { - "UI Design Description": (str, ...), - "Selected Elements": (str, ...), - "HTML Layout": (str, ...), - "CSS Styles (styles.css)": (str, ...), - "Anything UNCLEAR": (str, ...), -} +UI_DESIGN_NODE = ActionNode.from_children("UI_DESIGN", NODES) def load_engine(func): @@ -223,10 +229,8 @@ class UIDesign(Action): css_file_path = save_dir / "ui_design.css" html_file_path = save_dir / "ui_design.html" - with open(css_file_path, "w") as css_file: - css_file.write(css_content) - with open(html_file_path, "w") as html_file: - html_file.write(html_content) + css_file_path.write_text(css_content) + html_file_path.write_text(html_content) async def run(self, requirements: list[Message], *args, **kwargs) -> ActionOutput: """Run the UI Design action.""" @@ -234,9 +238,9 @@ class UIDesign(Action): context = requirements[-1].content ui_design_draft = self.parse_requirement(context=context) # todo: parse requirements str - prompt = PROMPT_TEMPLATE.format(context=ui_design_draft, format_example=FORMAT_EXAMPLE) + prompt = PROMPT_TEMPLATE.format(context=ui_design_draft) logger.info(prompt) - ui_describe = await self._aask_v1(prompt, "ui_design", OUTPUT_MAPPING) + ui_describe = await UI_DESIGN_NODE.fill(prompt) logger.info(ui_describe.content) logger.info(ui_describe.instruct_content) css = self.parse_css_code(context=ui_describe.content) From 09e2f05a6a553c32cfdcdb53ec680d73acda1af2 Mon Sep 17 00:00:00 2001 From: geekan Date: Tue, 19 Dec 2023 21:24:08 +0800 Subject: [PATCH 069/167] refactor action_output and action_node --- metagpt/actions/action_node.py | 4 ++-- metagpt/actions/action_output.py | 26 +-------------------- metagpt/actions/write_prd.py | 2 +- metagpt/utils/serialize.py | 4 ++-- tests/metagpt/actions/test_action_output.py | 6 ++--- tests/metagpt/memory/test_memory_storage.py | 4 ++-- tests/metagpt/utils/test_serialize.py | 4 ++-- 7 files changed, 13 insertions(+), 37 deletions(-) diff --git a/metagpt/actions/action_node.py b/metagpt/actions/action_node.py index 0368d2df1..865cb2d32 100644 --- a/metagpt/actions/action_node.py +++ b/metagpt/actions/action_node.py @@ -6,7 +6,7 @@ @File : action_node.py """ import json -from typing import Dict, Generic, List, Optional, Type, TypeVar +from typing import Any, Dict, Generic, List, Optional, Tuple, Type, TypeVar from pydantic import BaseModel, create_model, root_validator, validator from tenacity import retry, stop_after_attempt, wait_random_exponential @@ -127,7 +127,7 @@ class ActionNode(Generic[T]): return self.get_self_mapping() @classmethod - def create_model_class(cls, class_name: str, mapping: Dict[str, Type]): + def create_model_class(cls, class_name: str, mapping: Dict[str, Tuple[Type, Any]]): """基于pydantic v1的模型动态生成,用来检验结果类型正确性""" new_class = create_model(class_name, **mapping) diff --git a/metagpt/actions/action_output.py b/metagpt/actions/action_output.py index 25326d43b..6be8dac50 100644 --- a/metagpt/actions/action_output.py +++ b/metagpt/actions/action_output.py @@ -6,9 +6,7 @@ @File : action_output """ -from typing import Dict, Type - -from pydantic import BaseModel, create_model, root_validator, validator +from pydantic import BaseModel class ActionOutput: @@ -18,25 +16,3 @@ class ActionOutput: def __init__(self, content: str, instruct_content: BaseModel): self.content = content self.instruct_content = instruct_content - - @classmethod - def create_model_class(cls, class_name: str, mapping: Dict[str, Type]): - new_class = create_model(class_name, **mapping) - - @validator("*", allow_reuse=True) - def check_name(v, field): - if field.name not in mapping.keys(): - raise ValueError(f"Unrecognized block: {field.name}") - return v - - @root_validator(pre=True, allow_reuse=True) - def check_missing_fields(values): - required_fields = set(mapping.keys()) - missing_fields = required_fields - set(values.keys()) - if missing_fields: - raise ValueError(f"Missing fields: {missing_fields}") - return values - - new_class.__validator_check_name = classmethod(check_name) - new_class.__root_validator_check_missing_fields = classmethod(check_missing_fields) - return new_class diff --git a/metagpt/actions/write_prd.py b/metagpt/actions/write_prd.py index 1cf21dbb7..23925ff10 100644 --- a/metagpt/actions/write_prd.py +++ b/metagpt/actions/write_prd.py @@ -67,7 +67,7 @@ class WritePRD(Action): def __init__(self, name="", context=None, llm=None): super().__init__(name, context, llm) - async def run(self, with_messages, format=CONFIG.prompt_format, *args, **kwargs) -> ActionOutput | Message: + async def run(self, with_messages, schema=CONFIG.prompt_schema, *args, **kwargs) -> ActionOutput | Message: # Determine which requirement documents need to be rewritten: Use LLM to assess whether new requirements are # related to the PRD. If they are related, rewrite the PRD. docs_file_repo = CONFIG.git_repo.new_file_repository(relative_path=DOCS_FILE_REPO) diff --git a/metagpt/utils/serialize.py b/metagpt/utils/serialize.py index 124176fcb..5e52846e1 100644 --- a/metagpt/utils/serialize.py +++ b/metagpt/utils/serialize.py @@ -6,7 +6,7 @@ import copy import pickle from typing import Dict, List -from metagpt.actions.action_output import ActionOutput +from metagpt.actions.action_node import ActionNode from metagpt.schema import Message @@ -60,7 +60,7 @@ def deserialize_message(message_ser: str) -> Message: message = pickle.loads(message_ser) if message.instruct_content: ic = message.instruct_content - ic_obj = ActionOutput.create_model_class(class_name=ic["class"], mapping=ic["mapping"]) + ic_obj = ActionNode.create_model_class(class_name=ic["class"], mapping=ic["mapping"]) ic_new = ic_obj(**ic["value"]) message.instruct_content = ic_new diff --git a/tests/metagpt/actions/test_action_output.py b/tests/metagpt/actions/test_action_output.py index ef8e239bd..f1765cb03 100644 --- a/tests/metagpt/actions/test_action_output.py +++ b/tests/metagpt/actions/test_action_output.py @@ -7,7 +7,7 @@ """ from typing import List, Tuple -from metagpt.actions import ActionOutput +from metagpt.actions.action_node import ActionNode t_dict = { "Required Python third-party packages": '"""\nflask==1.1.2\npygame==2.0.1\n"""\n', @@ -37,12 +37,12 @@ WRITE_TASKS_OUTPUT_MAPPING = { def test_create_model_class(): - test_class = ActionOutput.create_model_class("test_class", WRITE_TASKS_OUTPUT_MAPPING) + test_class = ActionNode.create_model_class("test_class", WRITE_TASKS_OUTPUT_MAPPING) assert test_class.__name__ == "test_class" def test_create_model_class_with_mapping(): - t = ActionOutput.create_model_class("test_class_1", WRITE_TASKS_OUTPUT_MAPPING) + t = ActionNode.create_model_class("test_class_1", WRITE_TASKS_OUTPUT_MAPPING) t1 = t(**t_dict) value = t1.dict()["Task list"] assert value == ["game.py", "app.py", "static/css/styles.css", "static/js/script.js", "templates/index.html"] diff --git a/tests/metagpt/memory/test_memory_storage.py b/tests/metagpt/memory/test_memory_storage.py index c67ca689f..7b74eb512 100644 --- a/tests/metagpt/memory/test_memory_storage.py +++ b/tests/metagpt/memory/test_memory_storage.py @@ -8,7 +8,7 @@ from typing import List from metagpt.actions import UserRequirement, WritePRD -from metagpt.actions.action_output import ActionOutput +from metagpt.actions.action_node import ActionNode from metagpt.memory.memory_storage import MemoryStorage from metagpt.schema import Message @@ -42,7 +42,7 @@ def test_idea_message(): def test_actionout_message(): out_mapping = {"field1": (str, ...), "field2": (List[str], ...)} out_data = {"field1": "field1 value", "field2": ["field2 value1", "field2 value2"]} - ic_obj = ActionOutput.create_model_class("prd", out_mapping) + ic_obj = ActionNode.create_model_class("prd", out_mapping) role_id = "UTUser2(Architect)" content = "The user has requested the creation of a command-line interface (CLI) snake game" diff --git a/tests/metagpt/utils/test_serialize.py b/tests/metagpt/utils/test_serialize.py index ffa34866c..f027d53f8 100644 --- a/tests/metagpt/utils/test_serialize.py +++ b/tests/metagpt/utils/test_serialize.py @@ -7,7 +7,7 @@ from typing import List, Tuple from metagpt.actions import WritePRD -from metagpt.actions.action_output import ActionOutput +from metagpt.actions.action_node import ActionNode from metagpt.schema import Message from metagpt.utils.serialize import ( actionoutout_schema_to_mapping, @@ -54,7 +54,7 @@ def test_actionoutout_schema_to_mapping(): def test_serialize_and_deserialize_message(): out_mapping = {"field1": (str, ...), "field2": (List[str], ...)} out_data = {"field1": "field1 value", "field2": ["field2 value1", "field2 value2"]} - ic_obj = ActionOutput.create_model_class("prd", out_mapping) + ic_obj = ActionNode.create_model_class("prd", out_mapping) message = Message( content="prd demand", instruct_content=ic_obj(**out_data), role="user", cause_by=WritePRD From 33c58d97fef317afba757ba04ece00fd1830130d Mon Sep 17 00:00:00 2001 From: geekan Date: Tue, 19 Dec 2023 21:32:52 +0800 Subject: [PATCH 070/167] refine code --- metagpt/actions/action_node.py | 2 +- metagpt/actions/write_prd_an.py | 8 ++++---- metagpt/provider/postprecess/base_postprecess_plugin.py | 2 +- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/metagpt/actions/action_node.py b/metagpt/actions/action_node.py index 865cb2d32..790069369 100644 --- a/metagpt/actions/action_node.py +++ b/metagpt/actions/action_node.py @@ -232,7 +232,7 @@ class ActionNode(Generic[T]): return prompt @retry( - wait=wait_random_exponential(min=1, max=60), + wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6), after=general_after_log(logger), ) diff --git a/metagpt/actions/write_prd_an.py b/metagpt/actions/write_prd_an.py index d96c0aeac..edd94a463 100644 --- a/metagpt/actions/write_prd_an.py +++ b/metagpt/actions/write_prd_an.py @@ -47,7 +47,7 @@ PRODUCT_GOALS = ActionNode( USER_STORIES = ActionNode( key="User Stories", expected_type=list[str], - instruction="Provide up to five scenario-based user stories.", + instruction="Provide up to 3 to 5 scenario-based user stories.", example=[ "As a user, I want to be able to choose difficulty levels", "As a player, I want to see my score after each game", @@ -57,7 +57,7 @@ USER_STORIES = ActionNode( COMPETITIVE_ANALYSIS = ActionNode( key="Competitive Analysis", expected_type=list[str], - instruction="Provide analyses for up to seven competitive products.", + instruction="Provide 5 to 7 competitive products.", example=["Python Snake Game: Simple interface, lacks advanced features"], ) @@ -92,8 +92,8 @@ REQUIREMENT_ANALYSIS = ActionNode( REQUIREMENT_POOL = ActionNode( key="Requirement Pool", expected_type=list[list[str]], - instruction="List down the requirements with their priority (P0, P1, P2).", - example=[["P0", "..."], ["P1", "..."]], + instruction="List down the top-5 requirements with their priority (P0, P1, P2).", + example=[["P0", "The main code ..."], ["P0", "The game algorithm ..."]], ) UI_DESIGN_DRAFT = ActionNode( diff --git a/metagpt/provider/postprecess/base_postprecess_plugin.py b/metagpt/provider/postprecess/base_postprecess_plugin.py index 0d1cfbb11..721476507 100644 --- a/metagpt/provider/postprecess/base_postprecess_plugin.py +++ b/metagpt/provider/postprecess/base_postprecess_plugin.py @@ -44,7 +44,7 @@ class BasePostPrecessPlugin(object): def run_retry_parse_json_text(self, content: str) -> Union[dict, list]: """inherited class can re-implement the function""" - logger.info(f"extracted json CONTENT from output:\n{content}") + logger.debug(f"extracted json CONTENT from output:\n{content}") parsed_data = retry_parse_json_text(output=content) # should use output=content return parsed_data From 62f34db137dcd73b965e613497ca1dd2df1ddcd9 Mon Sep 17 00:00:00 2001 From: geekan Date: Tue, 19 Dec 2023 23:53:04 +0800 Subject: [PATCH 071/167] refine code. move azure tts to tool, refactor actions --- metagpt/actions/__init__.py | 2 - metagpt/actions/action.py | 5 ++- metagpt/actions/analyze_dep_libs.py | 37 ------------------- metagpt/actions/design_filenames.py | 30 --------------- ...detail_mining.py => generate_questions.py} | 18 ++------- metagpt/schema.py | 3 +- metagpt/{actions => tools}/azure_tts.py | 19 ++++------ tests/metagpt/actions/test_azure_tts.py | 4 +- tests/metagpt/actions/test_detail_mining.py | 20 ++++++---- 9 files changed, 32 insertions(+), 106 deletions(-) delete mode 100644 metagpt/actions/analyze_dep_libs.py delete mode 100644 metagpt/actions/design_filenames.py rename metagpt/actions/{detail_mining.py => generate_questions.py} (69%) rename metagpt/{actions => tools}/azure_tts.py (65%) diff --git a/metagpt/actions/__init__.py b/metagpt/actions/__init__.py index 79ff94b3e..c34c72ed2 100644 --- a/metagpt/actions/__init__.py +++ b/metagpt/actions/__init__.py @@ -13,7 +13,6 @@ from metagpt.actions.add_requirement import UserRequirement from metagpt.actions.debug_error import DebugError from metagpt.actions.design_api import WriteDesign from metagpt.actions.design_api_review import DesignReview -from metagpt.actions.design_filenames import DesignFilenames from metagpt.actions.project_management import AssignTasks, WriteTasks from metagpt.actions.research import CollectLinks, WebBrowseAndSummarize, ConductResearch from metagpt.actions.run_code import RunCode @@ -33,7 +32,6 @@ class ActionType(Enum): WRITE_PRD_REVIEW = WritePRDReview WRITE_DESIGN = WriteDesign DESIGN_REVIEW = DesignReview - DESIGN_FILENAMES = DesignFilenames WRTIE_CODE = WriteCode WRITE_CODE_REVIEW = WriteCodeReview WRITE_TEST = WriteTest diff --git a/metagpt/actions/action.py b/metagpt/actions/action.py index 5c5884e8b..a3a9c0195 100644 --- a/metagpt/actions/action.py +++ b/metagpt/actions/action.py @@ -13,7 +13,7 @@ from typing import Optional from metagpt.actions.action_node import ActionNode from metagpt.llm import LLM -from metagpt.schema import BaseContext +from metagpt.schema import CodingContext, CodeSummarizeContext, TestingContext, RunCodeContext class Action(ABC): @@ -21,7 +21,8 @@ class Action(ABC): name: str llm: LLM - context: dict | BaseContext | str | None + # FIXME: simplify context + context: dict | CodingContext | CodeSummarizeContext | TestingContext | RunCodeContext | str | None prefix: str desc: str node: ActionNode | None diff --git a/metagpt/actions/analyze_dep_libs.py b/metagpt/actions/analyze_dep_libs.py deleted file mode 100644 index 53d40200a..000000000 --- a/metagpt/actions/analyze_dep_libs.py +++ /dev/null @@ -1,37 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- -""" -@Time : 2023/5/19 12:01 -@Author : alexanderwu -@File : analyze_dep_libs.py -""" - -from metagpt.actions import Action - -PROMPT = """You are an AI developer, trying to write a program that generates code for users based on their intentions. - -For the user's prompt: - ---- -The API is: {prompt} ---- - -We decide the generated files are: {filepaths_string} - -Now that we have a file list, we need to understand the shared dependencies they have. -Please list and briefly describe the shared contents between the files we are generating, including exported variables, -data patterns, id names of all DOM elements that javascript functions will use, message names and function names. -Focus only on the names of shared dependencies, do not add any other explanations. -""" - - -class AnalyzeDepLibs(Action): - def __init__(self, name, context=None, llm=None): - super().__init__(name, context, llm) - self.desc = "Analyze the runtime dependencies of the program based on the context" - - async def run(self, requirement, filepaths_string): - # prompt = f"Below is the product requirement document (PRD):\n\n{prd}\n\n{PROMPT}" - prompt = PROMPT.format(prompt=requirement, filepaths_string=filepaths_string) - design_filenames = await self._aask(prompt) - return design_filenames diff --git a/metagpt/actions/design_filenames.py b/metagpt/actions/design_filenames.py deleted file mode 100644 index ffa171d7b..000000000 --- a/metagpt/actions/design_filenames.py +++ /dev/null @@ -1,30 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- -""" -@Time : 2023/5/19 11:50 -@Author : alexanderwu -@File : design_filenames.py -""" -from metagpt.actions import Action -from metagpt.logs import logger - -PROMPT = """You are an AI developer, trying to write a program that generates code for users based on their intentions. -When given their intentions, provide a complete and exhaustive list of file paths needed to write the program for the user. -Only list the file paths you will write and return them as a Python string list. -Do not add any other explanations, just return a Python string list.""" - - -class DesignFilenames(Action): - def __init__(self, name, context=None, llm=None): - super().__init__(name, context, llm) - self.desc = ( - "Based on the PRD, consider system design, and carry out the basic design of the corresponding " - "APIs, data structures, and database tables. Please give your design, feedback clearly and in detail." - ) - - async def run(self, prd): - prompt = f"The following is the Product Requirement Document (PRD):\n\n{prd}\n\n{PROMPT}" - design_filenames = await self._aask(prompt) - logger.debug(prompt) - logger.debug(design_filenames) - return design_filenames diff --git a/metagpt/actions/detail_mining.py b/metagpt/actions/generate_questions.py similarity index 69% rename from metagpt/actions/detail_mining.py rename to metagpt/actions/generate_questions.py index 0314d30dd..c38c463bc 100644 --- a/metagpt/actions/detail_mining.py +++ b/metagpt/actions/generate_questions.py @@ -3,19 +3,11 @@ """ @Time : 2023/9/12 17:45 @Author : fisherdeng -@File : detail_mining.py +@File : generate_questions.py """ from metagpt.actions import Action from metagpt.actions.action_node import ActionNode -CONTEXT_TEMPLATE = """ -## TOPIC -{topic} - -## RECORD -{record} -""" - QUESTIONS = ActionNode( key="Questions", expected_type=list[str], @@ -25,11 +17,9 @@ QUESTIONS = ActionNode( ) -class DetailMining(Action): +class GenerateQuestions(Action): """This class allows LLM to further mine noteworthy details based on specific "##TOPIC"(discussion topic) and "##RECORD" (discussion records), thereby deepening the discussion.""" - async def run(self, topic, record): - context = CONTEXT_TEMPLATE.format(topic=topic, record=record) - rsp = await QUESTIONS.fill(context=context, llm=self.llm) - return rsp + async def run(self, context): + return await QUESTIONS.fill(context=context, llm=self.llm) diff --git a/metagpt/schema.py b/metagpt/schema.py index aacc2cebb..d2f8d33e6 100644 --- a/metagpt/schema.py +++ b/metagpt/schema.py @@ -18,6 +18,7 @@ import asyncio import json import os.path import uuid +from abc import ABC from asyncio import Queue, QueueEmpty, wait_for from json import JSONDecodeError from pathlib import Path @@ -265,7 +266,7 @@ class MessageQueue: T = TypeVar("T", bound="BaseModel") -class BaseContext(BaseModel): +class BaseContext(BaseModel, ABC): @classmethod @handle_exception def loads(cls: Type[T], val: str) -> Optional[T]: diff --git a/metagpt/actions/azure_tts.py b/metagpt/tools/azure_tts.py similarity index 65% rename from metagpt/actions/azure_tts.py rename to metagpt/tools/azure_tts.py index daa3f6892..e59d98016 100644 --- a/metagpt/actions/azure_tts.py +++ b/metagpt/tools/azure_tts.py @@ -7,19 +7,16 @@ """ from azure.cognitiveservices.speech import AudioConfig, SpeechConfig, SpeechSynthesizer -from metagpt.actions.action import Action -from metagpt.config import Config +from metagpt.config import CONFIG -class AzureTTS(Action): - def __init__(self, name, context=None, llm=None): - super().__init__(name, context, llm) - self.config = Config() +class AzureTTS: + """https://learn.microsoft.com/zh-cn/azure/cognitive-services/speech-service/language-support?tabs=tts#voice-styles-and-roles""" - # Parameters reference: https://learn.microsoft.com/zh-cn/azure/cognitive-services/speech-service/language-support?tabs=tts#voice-styles-and-roles - def synthesize_speech(self, lang, voice, role, text, output_file): - subscription_key = self.config.get("AZURE_TTS_SUBSCRIPTION_KEY") - region = self.config.get("AZURE_TTS_REGION") + @classmethod + def synthesize_speech(cls, lang, voice, role, text, output_file): + subscription_key = CONFIG.get("AZURE_TTS_SUBSCRIPTION_KEY") + region = CONFIG.get("AZURE_TTS_REGION") speech_config = SpeechConfig(subscription=subscription_key, region=region) speech_config.speech_synthesis_voice_name = voice @@ -41,5 +38,5 @@ class AzureTTS(Action): if __name__ == "__main__": - azure_tts = AzureTTS("azure_tts") + azure_tts = AzureTTS() azure_tts.synthesize_speech("zh-CN", "zh-CN-YunxiNeural", "Boy", "Hello, I am Kaka", "output.wav") diff --git a/tests/metagpt/actions/test_azure_tts.py b/tests/metagpt/actions/test_azure_tts.py index bcafe10f5..9995e9691 100644 --- a/tests/metagpt/actions/test_azure_tts.py +++ b/tests/metagpt/actions/test_azure_tts.py @@ -5,11 +5,11 @@ @Author : alexanderwu @File : test_azure_tts.py """ -from metagpt.actions.azure_tts import AzureTTS +from metagpt.tools.azure_tts import AzureTTS def test_azure_tts(): - azure_tts = AzureTTS("azure_tts") + azure_tts = AzureTTS() azure_tts.synthesize_speech("zh-CN", "zh-CN-YunxiNeural", "Boy", "你好,我是卡卡", "output.wav") # 运行需要先配置 SUBSCRIPTION_KEY diff --git a/tests/metagpt/actions/test_detail_mining.py b/tests/metagpt/actions/test_detail_mining.py index 30bcf9dfb..a178ec840 100644 --- a/tests/metagpt/actions/test_detail_mining.py +++ b/tests/metagpt/actions/test_detail_mining.py @@ -3,20 +3,26 @@ """ @Time : 2023/9/13 00:26 @Author : fisherdeng -@File : test_detail_mining.py +@File : test_generate_questions.py """ import pytest -from metagpt.actions.detail_mining import DetailMining +from metagpt.actions.generate_questions import GenerateQuestions from metagpt.logs import logger +context = """ +## topic +如何做一个生日蛋糕 + +## record +我认为应该先准备好材料,然后再开始做蛋糕。 +""" + @pytest.mark.asyncio -async def test_detail_mining(): - topic = "如何做一个生日蛋糕" - record = "我认为应该先准备好材料,然后再开始做蛋糕。" - detail_mining = DetailMining("detail_mining") - rsp = await detail_mining.run(topic=topic, record=record) +async def test_generate_questions(): + detail_mining = GenerateQuestions() + rsp = await detail_mining.run(context) logger.info(f"{rsp.content=}") assert "Questions" in rsp.content From 0f78d4ea51d6e7d579dc7340e9b7e2039d0f5aa2 Mon Sep 17 00:00:00 2001 From: geekan Date: Tue, 19 Dec 2023 23:58:18 +0800 Subject: [PATCH 072/167] refine code --- metagpt/actions/action_node.py | 52 +++++++++++++++++----------------- metagpt/actions/design_api.py | 4 +-- metagpt/actions/write_prd.py | 4 +-- 3 files changed, 30 insertions(+), 30 deletions(-) diff --git a/metagpt/actions/action_node.py b/metagpt/actions/action_node.py index 790069369..092dd5755 100644 --- a/metagpt/actions/action_node.py +++ b/metagpt/actions/action_node.py @@ -112,15 +112,15 @@ class ActionNode(Generic[T]): obj.add_children(nodes) return obj - def get_children_mapping(self) -> Dict[str, Type]: + def get_children_mapping(self) -> Dict[str, Tuple[Type, Any]]: """获得子ActionNode的字典,以key索引""" return {k: (v.expected_type, ...) for k, v in self.children.items()} - def get_self_mapping(self) -> Dict[str, Type]: + def get_self_mapping(self) -> Dict[str, Tuple[Type, Any]]: """get self key: type mapping""" return {self.key: (self.expected_type, ...)} - def get_mapping(self, mode="children") -> Dict[str, Type]: + def get_mapping(self, mode="children") -> Dict[str, Tuple[Type, Any]]: """get key: type mapping under mode""" if mode == "children" or (mode == "auto" and self.children): return self.get_children_mapping() @@ -175,46 +175,46 @@ class ActionNode(Generic[T]): return node_dict # 遍历子节点并递归调用 to_dict 方法 - for child_key, child_node in self.children.items(): + for _, child_node in self.children.items(): node_dict.update(child_node.to_dict(format_func)) return node_dict - def compile_to(self, i: Dict, to) -> str: - if to == "json": + def compile_to(self, i: Dict, schema) -> str: + if schema == "json": return json.dumps(i, indent=4) - elif to == "markdown": + elif schema == "markdown": return dict_to_markdown(i) else: return str(i) - def tagging(self, text, to, tag="") -> str: + def tagging(self, text, schema, tag="") -> str: if not tag: return text - if to == "json": + if schema == "json": return f"[{tag}]\n" + text + f"\n[/{tag}]" else: return f"[{tag}]\n" + text + f"\n[/{tag}]" - def _compile_f(self, to, mode, tag, format_func) -> str: + def _compile_f(self, schema, mode, tag, format_func) -> str: nodes = self.to_dict(format_func=format_func, mode=mode) - text = self.compile_to(nodes, to) - return self.tagging(text, to, tag) + text = self.compile_to(nodes, schema) + return self.tagging(text, schema, tag) - def compile_instruction(self, to="raw", mode="children", tag="") -> str: + def compile_instruction(self, schema="raw", mode="children", tag="") -> str: """compile to raw/json/markdown template with all/root/children nodes""" format_func = lambda i: f"{i.expected_type} # {i.instruction}" - return self._compile_f(to, mode, tag, format_func) + return self._compile_f(schema, mode, tag, format_func) - def compile_example(self, to="raw", mode="children", tag="") -> str: + def compile_example(self, schema="raw", mode="children", tag="") -> str: """compile to raw/json/markdown examples with all/root/children nodes""" # 这里不能使用f-string,因为转译为str后再json.dumps会额外加上引号,无法作为有效的example # 错误示例:"File list": "['main.py', 'const.py', 'game.py']", 注意这里值不是list,而是str format_func = lambda i: i.example - return self._compile_f(to, mode, tag, format_func) + return self._compile_f(schema, mode, tag, format_func) - def compile(self, context, to="json", mode="children", template=SIMPLE_TEMPLATE) -> str: + def compile(self, context, schema="json", mode="children", template=SIMPLE_TEMPLATE) -> str: """ mode: all/root/children mode="children": 编译所有子节点为一个统一模板,包括instruction与example @@ -224,8 +224,8 @@ class ActionNode(Generic[T]): # FIXME: json instruction会带来格式问题,如:"Project name": "web_2048 # 项目名称使用下划线", # compile example暂时不支持markdown - self.instruction = self.compile_instruction(to="markdown", mode=mode) - self.example = self.compile_example(to=to, tag="CONTENT", mode=mode) + self.instruction = self.compile_instruction(schema="markdown", mode=mode) + self.example = self.compile_example(schema=schema, tag="CONTENT", mode=mode) prompt = template.format( context=context, example=self.example, instruction=self.instruction, constraint=CONSTRAINT ) @@ -272,22 +272,22 @@ class ActionNode(Generic[T]): def set_context(self, context): self.set_recursive("context", context) - async def simple_fill(self, to, mode): - prompt = self.compile(context=self.context, to=to, mode=mode) + async def simple_fill(self, schema, mode): + prompt = self.compile(context=self.context, schema=schema, mode=mode) mapping = self.get_mapping(mode) class_name = f"{self.key}_AN" - content, scontent = await self._aask_v1(prompt, class_name, mapping, schema=to) + content, scontent = await self._aask_v1(prompt, class_name, mapping, schema=schema) self.content = content self.instruct_content = scontent return self - async def fill(self, context, llm, to="json", mode="auto", strgy="simple"): + async def fill(self, context, llm, schema="json", mode="auto", strgy="simple"): """Fill the node(s) with mode. :param context: Everything we should know when filling node. :param llm: Large Language Model with pre-defined system message. - :param to: json/markdown, determine example and output format. + :param schema: json/markdown, determine example and output format. - json: it's easy to open source LLM with json format - markdown: when generating code, markdown is always better :param mode: auto/children/root @@ -303,12 +303,12 @@ class ActionNode(Generic[T]): self.set_context(context) if strgy == "simple": - return await self.simple_fill(to, mode) + return await self.simple_fill(schema, mode) elif strgy == "complex": # 这里隐式假设了拥有children tmp = {} for _, i in self.children.items(): - child = await i.simple_fill(to, mode) + child = await i.simple_fill(schema, mode) tmp.update(child.instruct_content.dict()) cls = self.create_children_class() self.instruct_content = cls(**tmp) diff --git a/metagpt/actions/design_api.py b/metagpt/actions/design_api.py index f757ca856..548725fde 100644 --- a/metagpt/actions/design_api.py +++ b/metagpt/actions/design_api.py @@ -81,12 +81,12 @@ class WriteDesign(Action): return ActionOutput(content=changed_files.json(), instruct_content=changed_files) async def _new_system_design(self, context, schema=CONFIG.prompt_schema): - node = await DESIGN_API_NODE.fill(context=context, llm=self.llm, to=schema) + node = await DESIGN_API_NODE.fill(context=context, llm=self.llm, schema=schema) return node async def _merge(self, prd_doc, system_design_doc, schema=CONFIG.prompt_schema): context = NEW_REQ_TEMPLATE.format(old_design=system_design_doc.content, context=prd_doc.content) - node = await DESIGN_API_NODE.fill(context=context, llm=self.llm, to=schema) + node = await DESIGN_API_NODE.fill(context=context, llm=self.llm, schema=schema) system_design_doc.content = node.instruct_content.json(ensure_ascii=False) return system_design_doc diff --git a/metagpt/actions/write_prd.py b/metagpt/actions/write_prd.py index 23925ff10..7c160fa89 100644 --- a/metagpt/actions/write_prd.py +++ b/metagpt/actions/write_prd.py @@ -121,7 +121,7 @@ class WritePRD(Action): # logger.info(rsp) project_name = CONFIG.project_name if CONFIG.project_name else "" context = CONTEXT_TEMPLATE.format(requirements=requirements, project_name=project_name) - node = await WRITE_PRD_NODE.fill(context=context, llm=self.llm, to=schema) + node = await WRITE_PRD_NODE.fill(context=context, llm=self.llm, schema=schema) await self._rename_workspace(node) return node @@ -134,7 +134,7 @@ class WritePRD(Action): if not CONFIG.project_name: CONFIG.project_name = Path(CONFIG.project_path).name prompt = NEW_REQ_TEMPLATE.format(requirements=new_requirement_doc.content, old_prd=prd_doc.content) - node = await WRITE_PRD_NODE.fill(context=prompt, llm=self.llm, to=schema) + node = await WRITE_PRD_NODE.fill(context=prompt, llm=self.llm, schema=schema) prd_doc.content = node.instruct_content.json(ensure_ascii=False) await self._rename_workspace(node) return prd_doc From d0382b0ba7dfa69c7aafb7f6619c81531637d728 Mon Sep 17 00:00:00 2001 From: geekan Date: Wed, 20 Dec 2023 00:34:57 +0800 Subject: [PATCH 073/167] refine devcontainer README --- .devcontainer/README.md | 41 ++++++++++++++++++----------------------- 1 file changed, 18 insertions(+), 23 deletions(-) diff --git a/.devcontainer/README.md b/.devcontainer/README.md index dd088aab1..be692c14d 100644 --- a/.devcontainer/README.md +++ b/.devcontainer/README.md @@ -1,39 +1,34 @@ -# Dev container +# Dev Container -This project includes a [dev container](https://containers.dev/), which lets you use a container as a full-featured dev environment. +This project includes a [Dev Container](https://containers.dev/), offering you a comprehensive and fully-featured development environment within a container. By leveraging the Dev Container configuration in this folder, you can seamlessly build and initiate MetaGPT locally. For detailed information, please refer to the main README in the home directory. -You can use the dev container configuration in this folder to build and start running MetaGPT locally! For more, refer to the main README under the home directory. -You can use it in [GitHub Codespaces](https://github.com/features/codespaces) or the [VS Code Dev Containers extension](https://marketplace.visualstudio.com/items?itemName=ms-vscode-remote.remote-containers). +You can utilize this Dev Container in [GitHub Codespaces](https://github.com/features/codespaces) or with the [VS Code Dev Containers extension](https://marketplace.visualstudio.com/items?itemName=ms-vscode-remote.remote-containers). ## GitHub Codespaces -Open in GitHub Codespaces +[![Open in GitHub Codespaces](https://github.com/codespaces/badge.svg)](https://codespaces.new/geekan/MetaGPT) -You may use the button above to open this repo in a Codespace +Click the button above to open this repository in a Codespace. For additional information, refer to the [GitHub documentation on creating a Codespace](https://docs.github.com/en/free-pro-team@latest/github/developing-online-with-codespaces/creating-a-codespace#creating-a-codespace). -For more info, check out the [GitHub documentation](https://docs.github.com/en/free-pro-team@latest/github/developing-online-with-codespaces/creating-a-codespace#creating-a-codespace). - ## VS Code Dev Containers -Open in Dev Containers +[![Open in Dev Containers](https://img.shields.io/static/v1?label=Dev%20Containers&message=Open&color=blue&logo=visualstudiocode)](https://vscode.dev/redirect?url=vscode://ms-vscode-remote.remote-containers/cloneInVolume?url=https://github.com/geekan/MetaGPT) -Note: If you click this link you will open the main repo and not your local cloned repo, you can use this link and replace with your username and cloned repo name: -https://vscode.dev/redirect?url=vscode://ms-vscode-remote.remote-containers/cloneInVolume?url=https://github.com/geekan/MetaGPT +Note: Clicking the link above opens the main repository. To open your local cloned repository, replace the URL with your username and cloned repository's name: `https://vscode.dev/redirect?url=vscode://ms-vscode-remote.remote-containers/cloneInVolume?url=https://github.com//` +If you have VS Code and Docker installed, use the button above to get started. This will prompt VS Code to install the Dev Containers extension if it's not already installed, clone the source code into a container volume, and set up a dev container for you. -If you already have VS Code and Docker installed, you can use the button above to get started. This will cause VS Code to automatically install the Dev Containers extension if needed, clone the source code into a container volume, and spin up a dev container for use. +Alternatively, follow these steps to open this repository in a container using the VS Code Dev Containers extension: -You can also follow these steps to open this repo in a container using the VS Code Dev Containers extension: +1. For first-time users of a development container, ensure your system meets the prerequisites (e.g., Docker installation) as outlined in the [getting started steps](https://aka.ms/vscode-remote/containers/getting-started). -1. If this is your first time using a development container, please ensure your system meets the pre-reqs (i.e. have Docker installed) in the [getting started steps](https://aka.ms/vscode-remote/containers/getting-started). - -2. Open a locally cloned copy of the code: - - - Fork and Clone this repository to your local filesystem. +2. To open a locally cloned copy of the code: + - Fork and clone this repository to your local file system. - Press F1 and select the **Dev Containers: Open Folder in Container...** command. - - Select the cloned copy of this folder, wait for the container to start, and try things out! + - Choose the cloned folder, wait for the container to initialize, and start exploring! -You can learn more in the [Dev Containers documentation](https://code.visualstudio.com/docs/devcontainers/containers). +Learn more in the [VS Code Dev Containers documentation](https://code.visualstudio.com/docs/devcontainers/containers). -## Tips and tricks +## Tips and Tricks -* If you are working with the same repository folder in a container and Windows, you'll want consistent line endings (otherwise you may see hundreds of changes in the SCM view). The `.gitattributes` file in the root of this repo will disable line ending conversion and should prevent this. See [tips and tricks](https://code.visualstudio.com/docs/devcontainers/tips-and-tricks#_resolving-git-line-ending-issues-in-containers-resulting-in-many-modified-files) for more info. -* If you'd like to review the contents of the image used in this dev container, you can check it out in the [devcontainers/images](https://github.com/devcontainers/images/tree/main/src/python) repo. +* When working with the same repository folder in both a container and on Windows, it's crucial to have consistent line endings to avoid numerous changes in the SCM view. The `.gitattributes` file in the root of this repository disables line ending conversion, helping to prevent this issue. For more information, see [resolving git line ending issues in containers](https://code.visualstudio.com/docs/devcontainers/tips-and-tricks#_resolving-git-line-ending-issues-in-containers-resulting-in-many-modified-files). + +* If you're curious about the contents of the image used in this Dev Container, you can review it in the [devcontainers/images](https://github.com/devcontainers/images/tree/main/src/python) repository. From 1a62148dc6ea684a9dc0da372dc5c1ba3ac785a9 Mon Sep 17 00:00:00 2001 From: geekan Date: Wed, 20 Dec 2023 00:35:15 +0800 Subject: [PATCH 074/167] add proper space --- .devcontainer/postCreateCommand.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.devcontainer/postCreateCommand.sh b/.devcontainer/postCreateCommand.sh index 46788e306..3901193cd 100644 --- a/.devcontainer/postCreateCommand.sh +++ b/.devcontainer/postCreateCommand.sh @@ -4,4 +4,4 @@ sudo npm install -g @mermaid-js/mermaid-cli # Step 2: Ensure that Python 3.9+ is installed on your system. You can check this by using: python --version -pip install -e. \ No newline at end of file +pip install -e . \ No newline at end of file From 6b235e536e6d5b2590db97cdcd4aece779227c13 Mon Sep 17 00:00:00 2001 From: geekan Date: Wed, 20 Dec 2023 00:39:35 +0800 Subject: [PATCH 075/167] .gitattributes: ensure lf --- .gitattributes | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/.gitattributes b/.gitattributes index 32555a806..7f1424434 100644 --- a/.gitattributes +++ b/.gitattributes @@ -1,2 +1,29 @@ +# HTML code is incorrectly calculated into statistics, so ignore them *.html linguist-detectable=false +# Auto detect text files and perform LF normalization +* text=auto eol=lf + +# Ensure shell scripts use LF (Linux style) line endings on Windows +*.sh text eol=lf + +# Treat specific binary files as binary and prevent line ending conversion +*.png binary +*.jpg binary +*.gif binary +*.ico binary + +# Preserve original line endings for specific document files +*.doc text eol=crlf +*.docx text eol=crlf +*.pdf binary + +# Ensure source code and script files use LF line endings +*.py text eol=lf +*.js text eol=lf +*.html text eol=lf +*.css text eol=lf + +# Specify custom diff driver for specific file types +*.md diff=markdown +*.json diff=json From efebc07e54374accd65c7a82c2c10fb4b1dfdb0a Mon Sep 17 00:00:00 2001 From: geekan Date: Wed, 20 Dec 2023 00:47:28 +0800 Subject: [PATCH 076/167] refine .gitignore and .pre-commit-config.yaml --- .gitignore | 8 +------- .pre-commit-config.yaml | 2 +- 2 files changed, 2 insertions(+), 8 deletions(-) diff --git a/.gitignore b/.gitignore index 0ac318ff5..c12506b0e 100644 --- a/.gitignore +++ b/.gitignore @@ -144,24 +144,18 @@ cython_debug/ allure-report allure-results -# idea +# idea / vscode / macos .idea .DS_Store .vscode -log.txt -docs/scripts/set_env.sh key.yaml -output.json data -data/output_add.json data.ms examples/nb/ .chroma *~$* workspace/* -*.mmd tmp -output.wav metagpt/roles/idea_agent.py .aider* diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index b1892a709..338f832ac 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -2,7 +2,7 @@ default_stages: [ commit ] # Install # 1. pip install pre-commit -# 2. pre-commit install(the first time you download the repo, it will be cached for future use) +# 2. pre-commit install repos: - repo: https://github.com/pycqa/isort rev: 5.11.5 From 3b7c2e48599b9837894de766eb7f6bb275752667 Mon Sep 17 00:00:00 2001 From: geekan Date: Wed, 20 Dec 2023 00:49:08 +0800 Subject: [PATCH 077/167] updating time of license --- LICENSE | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/LICENSE b/LICENSE index 5b0c000cd..67460e101 100644 --- a/LICENSE +++ b/LICENSE @@ -1,6 +1,6 @@ The MIT License -Copyright (c) Chenglin Wu +Copyright (c) 2023 Chenglin Wu Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal From 394055d7e6380b05f28ffaebf53b7ae50c9d79a6 Mon Sep 17 00:00:00 2001 From: geekan Date: Wed, 20 Dec 2023 00:53:36 +0800 Subject: [PATCH 078/167] align ruff.toml with black --- ruff.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ruff.toml b/ruff.toml index 7835865e0..21de5ee14 100644 --- a/ruff.toml +++ b/ruff.toml @@ -31,7 +31,7 @@ exclude = [ ] # Same as Black. -line-length = 119 +line-length = 120 # Allow unused variables when underscore-prefixed. dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$" From 5c7c522c623e56efbc89e47adfa5b59ebf775754 Mon Sep 17 00:00:00 2001 From: geekan Date: Wed, 20 Dec 2023 00:54:29 +0800 Subject: [PATCH 079/167] uncomment fire in requirements.txt due to usage in the example --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 515a4d88b..f5ef63c58 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,7 +6,7 @@ channels==4.0.0 # docx==0.2.4 #faiss==1.5.3 faiss_cpu==1.7.4 -# fire==0.4.0 +fire==0.4.0 typer # godot==0.1.1 # google_api_python_client==2.93.0 From 66c0bce60bfffb3727f27554ee0cbb5d0fac8817 Mon Sep 17 00:00:00 2001 From: geekan Date: Wed, 20 Dec 2023 00:58:56 +0800 Subject: [PATCH 080/167] add proper space --- Dockerfile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Dockerfile b/Dockerfile index c6e22989b..9eeacbccb 100644 --- a/Dockerfile +++ b/Dockerfile @@ -18,7 +18,7 @@ COPY . /app/metagpt WORKDIR /app/metagpt RUN mkdir workspace &&\ pip install --no-cache-dir -r requirements.txt &&\ - pip install -e. + pip install -e . # Running with an infinite loop using the tail command CMD ["sh", "-c", "tail -f /dev/null"] From 77ec9b823f985fc0f30bccb5a71b2eec18b77f1d Mon Sep 17 00:00:00 2001 From: geekan Date: Wed, 20 Dec 2023 00:59:23 +0800 Subject: [PATCH 081/167] remove duplicate string --- .dockerignore | 1 - 1 file changed, 1 deletion(-) diff --git a/.dockerignore b/.dockerignore index 2968dd34d..8c09eaf73 100644 --- a/.dockerignore +++ b/.dockerignore @@ -1,7 +1,6 @@ workspace tmp build -workspace dist data geckodriver.log From 68c8ef107347f713ee6f3433735374d175b98017 Mon Sep 17 00:00:00 2001 From: better629 Date: Wed, 20 Dec 2023 10:44:30 +0800 Subject: [PATCH 082/167] update ser&deser code --- metagpt/actions/action.py | 1 - metagpt/roles/role.py | 26 ++++-- metagpt/schema.py | 8 +- metagpt/startup.py | 37 +++++--- metagpt/utils/utils.py | 17 ++-- startup.py | 86 ------------------- .../serialize_deserialize/test_role.py | 2 +- .../serialize_deserialize/test_team.py | 14 ++- 8 files changed, 70 insertions(+), 121 deletions(-) delete mode 100644 startup.py diff --git a/metagpt/actions/action.py b/metagpt/actions/action.py index 570863388..8cba18945 100644 --- a/metagpt/actions/action.py +++ b/metagpt/actions/action.py @@ -20,7 +20,6 @@ from metagpt.provider.base_gpt_api import BaseGPTAPI from metagpt.provider.postprecess.llm_output_postprecess import llm_output_postprecess from metagpt.utils.common import OutputParser from metagpt.utils.utils import general_after_log -from metagpt.utils.utils import import_class action_subclass_registry = {} diff --git a/metagpt/roles/role.py b/metagpt/roles/role.py index 9b1e0bf94..09371ae08 100644 --- a/metagpt/roles/role.py +++ b/metagpt/roles/role.py @@ -39,7 +39,7 @@ from metagpt.provider.human_provider import HumanProvider from metagpt.schema import Message, MessageQueue from metagpt.utils.common import any_to_str from metagpt.utils.repair_llm_raw_output import extract_state_value_from_output -from metagpt.utils.utils import read_json_file, write_json_file, import_class +from metagpt.utils.utils import read_json_file, write_json_file, import_class, role_raise_decorator PREFIX_TEMPLATE = """You are a {profile}, named {name}, your goal is {goal}, and the constraint is {constraints}. """ @@ -137,6 +137,7 @@ class Role(BaseModel): # builtin variables recovered: bool = False # to tag if a recovered role + latest_observed_msg: Message = None # record the latest observed message when interrupted builtin_class_name: str = "" _private_attributes = { @@ -200,7 +201,6 @@ class Role(BaseModel): def _reset(self): object.__setattr__(self, "_states", []) object.__setattr__(self, "_actions", []) - # object.__setattr__(self, "_rc", RoleContext()) @property def _setting(self): @@ -210,7 +210,7 @@ class Role(BaseModel): stg_path = SERDESER_PATH.joinpath(f"team/environment/roles/{self.__class__.__name__}_{self.name}") \ if stg_path is None else stg_path - role_info = self.dict(exclude={"_rc": {"memory": True}, "_llm": True}) + role_info = self.dict(exclude={"_rc": {"memory": True, "msg_buffer": True}, "_llm": True}) role_info.update({ "role_class": self.__class__.__name__, "module_name": self.__module__ @@ -311,7 +311,7 @@ class Role(BaseModel): def _set_state(self, state: int): """Update the current state.""" self._rc.state = state - logger.debug(self._actions) + logger.debug(f"actions={self._actions}, state={state}") self._rc.todo = self._actions[self._rc.state] if state >= 0 else None def set_env(self, env: "Environment"): @@ -388,15 +388,30 @@ class Role(BaseModel): return msg + def _find_news(self, observed: list[Message], existed: list[Message]) -> list[Message]: + news = [] + # Warning, remove `id` here to make it work for recover + observed_pure = [msg.dict(exclude={"id": True}) for msg in observed] + existed_pure = [msg.dict(exclude={"id": True}) for msg in existed] + for idx, new in enumerate(observed_pure): + if new["cause_by"] in self._rc.watch and new not in existed_pure: + news.append(observed[idx]) + return news + async def _observe(self, ignore_memory=False) -> int: """Prepare new messages for processing from the message buffer and other sources.""" # Read unprocessed messages from the msg buffer. news = self._rc.msg_buffer.pop_all() + if self.recovered: + news = [self.latest_observed_msg] if self.latest_observed_msg else [] + else: + self.latest_observed_msg = news[-1] if len(news) > 0 else None # record the latest observed msg + # Store the read messages in your own memory to prevent duplicate processing. old_messages = [] if ignore_memory else self._rc.memory.get() self._rc.memory.add_batch(news) # Filter out messages of interest. - self._rc.news = [n for n in news if n.cause_by in self._rc.watch and n not in old_messages] + self._rc.news = self._find_news(news, old_messages) # Design Rules: # If you need to further categorize Message objects, you can do so using the Message.set_meta function. @@ -484,6 +499,7 @@ class Role(BaseModel): """A wrapper to return the most recent k memories of this role, return all when k=0""" return self._rc.memory.get(k=k) + @role_raise_decorator async def run(self, with_message=None): """Observe, and think and act based on the results of the observation""" if with_message: diff --git a/metagpt/schema.py b/metagpt/schema.py index 0ec9b5c60..0fdc24e02 100644 --- a/metagpt/schema.py +++ b/metagpt/schema.py @@ -26,7 +26,6 @@ from typing import Dict, List, Set, TypedDict, Optional, Any from pydantic import BaseModel, Field -from metagpt.actions import UserRequirement from metagpt.config import CONFIG from metagpt.const import ( MESSAGE_ROUTE_CAUSE_BY, @@ -118,8 +117,9 @@ class Message(BaseModel): ic_new = ic_obj(**ic["value"]) kwargs["instruct_content"] = ic_new - kwargs["id"] = uuid.uuid4().hex - kwargs["cause_by"] = any_to_str(kwargs.get("cause_by", UserRequirement)) + kwargs["id"] = kwargs.get("id", uuid.uuid4().hex) + kwargs["cause_by"] = any_to_str(kwargs.get("cause_by", + import_class("UserRequirement", "metagpt.actions.add_requirement"))) kwargs["sent_from"] = any_to_str(kwargs.get("sent_from", "")) kwargs["send_to"] = any_to_str_set(kwargs.get("send_to", {MESSAGE_ROUTE_TO_ALL})) super(Message, self).__init__(**kwargs) @@ -218,7 +218,7 @@ class MessageQueue(BaseModel): if key in kwargs: object.__setattr__(self, key, kwargs[key]) else: - object.__setattr__(self, key, self._private_attributes[key]) + object.__setattr__(self, key, Queue()) def pop(self) -> Message | None: """Pop one message from the queue.""" diff --git a/metagpt/startup.py b/metagpt/startup.py index f930c386b..17eb26665 100644 --- a/metagpt/startup.py +++ b/metagpt/startup.py @@ -27,8 +27,10 @@ def startup( reqa_file: str = typer.Option(default="", help="Specify the source file name for rewriting the quality test code."), max_auto_summarize_code: int = typer.Option( default=-1, - help="The maximum number of times the 'SummarizeCode' action is automatically invoked, with -1 indicating unlimited. This parameter is used for debugging the workflow.", + help="The maximum number of times the 'SummarizeCode' action is automatically invoked, " + "with -1 indicating unlimited. This parameter is used for debugging the workflow.", ), + recover_path: str = typer.Option(default=None, help="recover the project from existing serialized storage") ): """Run a startup. Be a boss.""" from metagpt.roles import ( @@ -50,20 +52,29 @@ def startup( CONFIG.reqa_file = reqa_file CONFIG.max_auto_summarize_code = max_auto_summarize_code - company = Team() - company.hire( - [ - ProductManager(), - Architect(), - ProjectManager(), - ] - ) + if not recover_path: + company = Team() + company.hire( + [ + ProductManager(), + Architect(), + ProjectManager(), + ] + ) - if implement or code_review: - company.hire([Engineer(n_borg=5, use_code_review=code_review)]) + if implement or code_review: + company.hire([Engineer(n_borg=5, use_code_review=code_review)]) - if run_tests: - company.hire([QaEngineer()]) + if run_tests: + company.hire([QaEngineer()]) + else: + # # stg_path = SERDESER_PATH.joinpath("team") + stg_path = Path(recover_path) + if not stg_path.exists() or not str(stg_path).endswith("team"): + raise FileNotFoundError(f"{recover_path} not exists or not endswith `team`") + + company = Team.recover(stg_path=stg_path) + idea = company.idea # use original idea company.invest(investment) company.run_project(idea) diff --git a/metagpt/utils/utils.py b/metagpt/utils/utils.py index 57da57b00..aa7c039c4 100644 --- a/metagpt/utils/utils.py +++ b/metagpt/utils/utils.py @@ -88,18 +88,15 @@ def role_raise_decorator(func): return await func(self, *args, **kwargs) except KeyboardInterrupt as kbi: logger.error(f"KeyboardInterrupt: {kbi} occurs, start to serialize the project") - if self._rc.env: - newest_msgs = self._rc.env.memory.get(1) - if len(newest_msgs) > 0: - self._rc.memory.delete(newest_msgs[0]) + if self.latest_observed_msg: + self._rc.memory.delete(self.latest_observed_msg) raise Exception(format_trackback_info(limit=None)) # raise again to make it captured outside except Exception as exp: - if self._rc.env: - newest_msgs = self._rc.env.memory.get(1) - if len(newest_msgs) > 0: - logger.warning("There is a exception in role's execution, in order to resume, " - "we delete the newest role communication message in the role's memory.") - self._rc.memory.delete(newest_msgs[0]) # remove newest msg of the role to make it observed again + if self.latest_observed_msg: + logger.warning("There is a exception in role's execution, in order to resume, " + "we delete the newest role communication message in the role's memory.") + # remove role newest observed msg to make it observed again + self._rc.memory.delete(self.latest_observed_msg) raise Exception(format_trackback_info(limit=None)) # raise again to make it captured outside return wrapper diff --git a/startup.py b/startup.py deleted file mode 100644 index c4928a1b5..000000000 --- a/startup.py +++ /dev/null @@ -1,86 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- - -from typing import Optional -import asyncio -import fire -from pathlib import Path - -from metagpt.roles import ( - Architect, - Engineer, - ProductManager, - ProjectManager, - QaEngineer, -) -from metagpt.team import Team - - -async def startup( - idea: str, - investment: float = 3.0, - n_round: int = 5, - code_review: bool = False, - run_tests: bool = False, - implement: bool = True, - recover_path: Optional[str] = None, -): - """Run a startup. Be a boss.""" - if not recover_path: - company = Team() - company.hire( - [ - ProductManager(), - Architect(), - ProjectManager(), - ] - ) - - # if implement or code_review - if implement or code_review: - # developing features: implement the idea - company.hire([Engineer(n_borg=5, use_code_review=code_review)]) - - if run_tests: - # developing features: run tests on the spot and identify bugs - # (bug fixing capability comes soon!) - company.hire([QaEngineer()]) - else: - # # stg_path = SERDESER_PATH.joinpath("team") - stg_path = Path(recover_path) - if not stg_path.exists() or not str(stg_path).endswith("team"): - raise FileNotFoundError(f"{recover_path} not exists or not endswith `team`") - - company = Team.recover(stg_path=stg_path) - idea = company.idea # use original idea - - company.invest(investment) - company.start_project(idea) - await company.run(n_round=n_round) - - -def main( - idea: str, - investment: float = 3.0, - n_round: int = 5, - code_review: bool = True, - run_tests: bool = False, - implement: bool = True, - recover_path: str = None, -): - """ - We are a software startup comprised of AI. By investing in us, - you are empowering a future filled with limitless possibilities. - :param idea: Your innovative idea, such as "Creating a snake game." - :param investment: As an investor, you have the opportunity to contribute - a certain dollar amount to this AI company. - :param n_round: - :param code_review: Whether to use code review. - :param recover_path: recover the project from existing serialized storage - :return: - """ - asyncio.run(startup(idea, investment, n_round, code_review, run_tests, implement, recover_path)) - - -if __name__ == "__main__": - fire.Fire(main) diff --git a/tests/metagpt/serialize_deserialize/test_role.py b/tests/metagpt/serialize_deserialize/test_role.py index f25403dc0..87cf75caa 100644 --- a/tests/metagpt/serialize_deserialize/test_role.py +++ b/tests/metagpt/serialize_deserialize/test_role.py @@ -84,7 +84,7 @@ async def test_role_serdeser_interrupt(): logger.error(f"Exception in `role_a.run`, detail: {format_trackback_info()}") role_c.serialize(stg_path) - assert role_c._rc.memory.count() == 2 + assert role_c._rc.memory.count() == 1 new_role_a: Role = Role.deserialize(stg_path) assert new_role_a._rc.state == 1 diff --git a/tests/metagpt/serialize_deserialize/test_team.py b/tests/metagpt/serialize_deserialize/test_team.py index 01e0a6c70..e87df9b52 100644 --- a/tests/metagpt/serialize_deserialize/test_team.py +++ b/tests/metagpt/serialize_deserialize/test_team.py @@ -106,11 +106,23 @@ async def test_team_recover_multi_roles_save(): stg_path = SERDESER_PATH.joinpath("team") shutil.rmtree(stg_path, ignore_errors=True) + role_a = RoleA() + role_b = RoleB() + + assert role_a.subscription == {"tests.metagpt.serialize_deserialize.test_serdeser_base.RoleA", + "RoleA"} + assert role_b.subscription == {"tests.metagpt.serialize_deserialize.test_serdeser_base.RoleB", + "RoleB"} + assert role_b._rc.watch == {"tests.metagpt.serialize_deserialize.test_serdeser_base.ActionPass"} + company = Team() - company.hire([RoleA(), RoleB()]) + company.hire([role_a, role_b]) company.run_project(idea) await company.run(n_round=4) new_company = Team.recover(stg_path) new_company.run_project(idea) + + assert new_company.env.get_role(role_b.profile)._rc.state == 1 + await new_company.run(n_round=4) From 32af743b36a8e31cf3c4a063a2869ea7da40a6f8 Mon Sep 17 00:00:00 2001 From: better629 Date: Wed, 20 Dec 2023 10:54:49 +0800 Subject: [PATCH 083/167] rm metagpt/utils/utils.py --- metagpt/actions/action.py | 4 +- metagpt/environment.py | 3 +- metagpt/memory/memory.py | 3 +- .../postprecess/base_postprecess_plugin.py | 2 +- metagpt/roles/role.py | 3 +- metagpt/schema.py | 3 +- metagpt/team.py | 3 +- metagpt/utils/common.py | 99 ++++++++++++++++- metagpt/utils/repair_llm_raw_output.py | 2 +- metagpt/utils/serialize.py | 2 +- metagpt/utils/utils.py | 102 ------------------ .../serialize_deserialize/test_role.py | 2 +- 12 files changed, 109 insertions(+), 119 deletions(-) delete mode 100644 metagpt/utils/utils.py diff --git a/metagpt/actions/action.py b/metagpt/actions/action.py index 8cba18945..9c7fb06e1 100644 --- a/metagpt/actions/action.py +++ b/metagpt/actions/action.py @@ -18,8 +18,8 @@ from metagpt.llm import LLM from metagpt.logs import logger from metagpt.provider.base_gpt_api import BaseGPTAPI from metagpt.provider.postprecess.llm_output_postprecess import llm_output_postprecess -from metagpt.utils.common import OutputParser -from metagpt.utils.utils import general_after_log +from metagpt.utils.common import OutputParser, general_after_log + action_subclass_registry = {} diff --git a/metagpt/environment.py b/metagpt/environment.py index 9108cdf06..a3cbe6978 100644 --- a/metagpt/environment.py +++ b/metagpt/environment.py @@ -20,8 +20,7 @@ from pydantic import BaseModel, Field from metagpt.logs import logger from metagpt.roles.role import Role, role_subclass_registry from metagpt.schema import Message -from metagpt.utils.common import is_subscribed -from metagpt.utils.utils import read_json_file, write_json_file +from metagpt.utils.common import is_subscribed, read_json_file, write_json_file class Environment(BaseModel): diff --git a/metagpt/memory/memory.py b/metagpt/memory/memory.py index 198c0970d..66ab5d4e9 100644 --- a/metagpt/memory/memory.py +++ b/metagpt/memory/memory.py @@ -14,8 +14,7 @@ from typing import Iterable, Set from pydantic import BaseModel, Field from metagpt.schema import Message -from metagpt.utils.common import any_to_str, any_to_str_set -from metagpt.utils.utils import read_json_file, write_json_file +from metagpt.utils.common import any_to_str, any_to_str_set, read_json_file, write_json_file class Memory(BaseModel): diff --git a/metagpt/provider/postprecess/base_postprecess_plugin.py b/metagpt/provider/postprecess/base_postprecess_plugin.py index 0d1cfbb11..afcef2531 100644 --- a/metagpt/provider/postprecess/base_postprecess_plugin.py +++ b/metagpt/provider/postprecess/base_postprecess_plugin.py @@ -44,7 +44,7 @@ class BasePostPrecessPlugin(object): def run_retry_parse_json_text(self, content: str) -> Union[dict, list]: """inherited class can re-implement the function""" - logger.info(f"extracted json CONTENT from output:\n{content}") + # logger.info(f"extracted json CONTENT from output:\n{content}") parsed_data = retry_parse_json_text(output=content) # should use output=content return parsed_data diff --git a/metagpt/roles/role.py b/metagpt/roles/role.py index 09371ae08..efe3bcbd4 100644 --- a/metagpt/roles/role.py +++ b/metagpt/roles/role.py @@ -37,9 +37,8 @@ from metagpt.memory import Memory from metagpt.provider.base_gpt_api import BaseGPTAPI from metagpt.provider.human_provider import HumanProvider from metagpt.schema import Message, MessageQueue -from metagpt.utils.common import any_to_str +from metagpt.utils.common import any_to_str, read_json_file, write_json_file, import_class, role_raise_decorator from metagpt.utils.repair_llm_raw_output import extract_state_value_from_output -from metagpt.utils.utils import read_json_file, write_json_file, import_class, role_raise_decorator PREFIX_TEMPLATE = """You are a {profile}, named {name}, your goal is {goal}, and the constraint is {constraints}. """ diff --git a/metagpt/schema.py b/metagpt/schema.py index 0fdc24e02..1c1fdd94d 100644 --- a/metagpt/schema.py +++ b/metagpt/schema.py @@ -36,10 +36,9 @@ from metagpt.const import ( TASK_FILE_REPO, ) from metagpt.logs import logger -from metagpt.utils.common import any_to_str, any_to_str_set +from metagpt.utils.common import any_to_str, any_to_str_set, import_class from metagpt.utils.serialize import actionoutout_schema_to_mapping, actionoutput_mapping_to_str, \ actionoutput_str_to_mapping -from metagpt.utils.utils import import_class class RawMessage(TypedDict): diff --git a/metagpt/team.py b/metagpt/team.py index 30e3dc618..383f2da36 100644 --- a/metagpt/team.py +++ b/metagpt/team.py @@ -19,8 +19,7 @@ from metagpt.environment import Environment from metagpt.logs import logger from metagpt.roles import Role from metagpt.schema import Message -from metagpt.utils.common import NoMoneyException -from metagpt.utils.utils import read_json_file, write_json_file, serialize_decorator +from metagpt.utils.common import NoMoneyException, read_json_file, write_json_file, serialize_decorator class Team(BaseModel): diff --git a/metagpt/utils/common.py b/metagpt/utils/common.py index a9bdd6e2d..c909180cc 100644 --- a/metagpt/utils/common.py +++ b/metagpt/utils/common.py @@ -13,12 +13,21 @@ from __future__ import annotations import ast import contextlib +import importlib import inspect +import json import os import platform import re +import traceback +import typing +from pathlib import Path +from typing import Any from typing import List, Tuple, Union +from pydantic.json import pydantic_encoder +from tenacity import _utils + from metagpt.const import MESSAGE_ROUTE_TO_ALL from metagpt.logs import logger @@ -184,7 +193,7 @@ class OutputParser: if start_index != -1 and end_index != -1: # Extract the structure part - structure_text = text[start_index : end_index + 1] + structure_text = text[start_index: end_index + 1] try: # Attempt to convert the text to a Python data type using ast.literal_eval @@ -363,3 +372,91 @@ def is_subscribed(message, tags): if t in message.send_to: return True return False + + +def general_after_log(logger: "loguru.Logger", sec_format: str = "%0.3f") -> typing.Callable[["RetryCallState"], None]: + def log_it(retry_state: "RetryCallState") -> None: + if retry_state.fn is None: + fn_name = "" + else: + fn_name = _utils.get_callback_name(retry_state.fn) + logger.error( + f"Finished call to '{fn_name}' after {sec_format % retry_state.seconds_since_start}(s), " + f"this was the {_utils.to_ordinal(retry_state.attempt_number)} time calling it. " + f"exp: {retry_state.outcome.exception()}" + ) + + return log_it + + +def read_json_file(json_file: str, encoding=None) -> list[Any]: + if not Path(json_file).exists(): + raise FileNotFoundError(f"json_file: {json_file} not exist, return []") + + with open(json_file, "r", encoding=encoding) as fin: + try: + data = json.load(fin) + except Exception as exp: + raise ValueError(f"read json file: {json_file} failed") + return data + + +def write_json_file(json_file: str, data: list, encoding=None): + folder_path = Path(json_file).parent + if not folder_path.exists(): + folder_path.mkdir(parents=True, exist_ok=True) + + with open(json_file, "w", encoding=encoding) as fout: + json.dump(data, fout, ensure_ascii=False, indent=4, default=pydantic_encoder) + + +def import_class(class_name: str, module_name: str) -> type: + module = importlib.import_module(module_name) + a_class = getattr(module, class_name) + return a_class + + +def import_class_inst(class_name: str, module_name: str, *args, **kwargs) -> object: + a_class = import_class(class_name, module_name) + class_inst = a_class(*args, **kwargs) + return class_inst + + +def format_trackback_info(limit: int = 2): + return traceback.format_exc(limit=limit) + + +def serialize_decorator(func): + async def wrapper(self, *args, **kwargs): + try: + result = await func(self, *args, **kwargs) + self.serialize() # Team.serialize + return result + except KeyboardInterrupt as kbi: + logger.error(f"KeyboardInterrupt occurs, start to serialize the project, exp:\n{format_trackback_info()}") + self.serialize() # Team.serialize + except Exception as exp: + logger.error(f"Exception occurs, start to serialize the project, exp:\n{format_trackback_info()}") + self.serialize() # Team.serialize + + return wrapper + + +def role_raise_decorator(func): + async def wrapper(self, *args, **kwargs): + try: + return await func(self, *args, **kwargs) + except KeyboardInterrupt as kbi: + logger.error(f"KeyboardInterrupt: {kbi} occurs, start to serialize the project") + if self.latest_observed_msg: + self._rc.memory.delete(self.latest_observed_msg) + raise Exception(format_trackback_info(limit=None)) # raise again to make it captured outside + except Exception as exp: + if self.latest_observed_msg: + logger.warning("There is a exception in role's execution, in order to resume, " + "we delete the newest role communication message in the role's memory.") + # remove role newest observed msg to make it observed again + self._rc.memory.delete(self.latest_observed_msg) + raise Exception(format_trackback_info(limit=None)) # raise again to make it captured outside + + return wrapper diff --git a/metagpt/utils/repair_llm_raw_output.py b/metagpt/utils/repair_llm_raw_output.py index 4aafd8e66..67ad4e963 100644 --- a/metagpt/utils/repair_llm_raw_output.py +++ b/metagpt/utils/repair_llm_raw_output.py @@ -253,7 +253,7 @@ def retry_parse_json_text(output: str) -> Union[list, dict]: if CONFIG.repair_llm_output is True, the _aask_v1 and the retry_parse_json_text will loop for {x=3*3} times. it's a two-layer retry cycle """ - logger.debug(f"output to json decode:\n{output}") + # logger.debug(f"output to json decode:\n{output}") # if CONFIG.repair_llm_output is True, it will try to fix output until the retry break parsed_data = CustomDecoder(strict=False).decode(output) diff --git a/metagpt/utils/serialize.py b/metagpt/utils/serialize.py index 93f584057..9a758da34 100644 --- a/metagpt/utils/serialize.py +++ b/metagpt/utils/serialize.py @@ -5,7 +5,7 @@ import copy import pickle -from metagpt.utils.utils import import_class +from metagpt.utils.common import import_class def actionoutout_schema_to_mapping(schema: dict) -> dict: diff --git a/metagpt/utils/utils.py b/metagpt/utils/utils.py deleted file mode 100644 index aa7c039c4..000000000 --- a/metagpt/utils/utils.py +++ /dev/null @@ -1,102 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- -# @Desc : - -import typing -from typing import Any -import json -from pathlib import Path -import importlib -from tenacity import _utils -import traceback -from pydantic.json import pydantic_encoder - -from metagpt.logs import logger - - -def general_after_log(logger: "loguru.Logger", sec_format: str = "%0.3f") -> typing.Callable[["RetryCallState"], None]: - def log_it(retry_state: "RetryCallState") -> None: - if retry_state.fn is None: - fn_name = "" - else: - fn_name = _utils.get_callback_name(retry_state.fn) - logger.error( - f"Finished call to '{fn_name}' after {sec_format % retry_state.seconds_since_start}(s), " - f"this was the {_utils.to_ordinal(retry_state.attempt_number)} time calling it. " - f"exp: {retry_state.outcome.exception()}" - ) - - return log_it - - -def read_json_file(json_file: str, encoding=None) -> list[Any]: - if not Path(json_file).exists(): - raise FileNotFoundError(f"json_file: {json_file} not exist, return []") - - with open(json_file, "r", encoding=encoding) as fin: - try: - data = json.load(fin) - except Exception as exp: - raise ValueError(f"read json file: {json_file} failed") - return data - - -def write_json_file(json_file: str, data: list, encoding=None): - folder_path = Path(json_file).parent - if not folder_path.exists(): - folder_path.mkdir(parents=True, exist_ok=True) - - with open(json_file, "w", encoding=encoding) as fout: - json.dump(data, fout, ensure_ascii=False, indent=4, default=pydantic_encoder) - - -def import_class(class_name: str, module_name: str) -> type: - module = importlib.import_module(module_name) - a_class = getattr(module, class_name) - return a_class - - -def import_class_inst(class_name: str, module_name: str, *args, **kwargs) -> object: - a_class = import_class(class_name, module_name) - class_inst = a_class(*args, **kwargs) - return class_inst - - -def format_trackback_info(limit: int = 2): - return traceback.format_exc(limit=limit) - - -def serialize_decorator(func): - async def wrapper(self, *args, **kwargs): - try: - result = await func(self, *args, **kwargs) - self.serialize() # Team.serialize - return result - except KeyboardInterrupt as kbi: - logger.error(f"KeyboardInterrupt occurs, start to serialize the project, exp:\n{format_trackback_info()}") - self.serialize() # Team.serialize - except Exception as exp: - logger.error(f"Exception occurs, start to serialize the project, exp:\n{format_trackback_info()}") - self.serialize() # Team.serialize - - return wrapper - - -def role_raise_decorator(func): - async def wrapper(self, *args, **kwargs): - try: - return await func(self, *args, **kwargs) - except KeyboardInterrupt as kbi: - logger.error(f"KeyboardInterrupt: {kbi} occurs, start to serialize the project") - if self.latest_observed_msg: - self._rc.memory.delete(self.latest_observed_msg) - raise Exception(format_trackback_info(limit=None)) # raise again to make it captured outside - except Exception as exp: - if self.latest_observed_msg: - logger.warning("There is a exception in role's execution, in order to resume, " - "we delete the newest role communication message in the role's memory.") - # remove role newest observed msg to make it observed again - self._rc.memory.delete(self.latest_observed_msg) - raise Exception(format_trackback_info(limit=None)) # raise again to make it captured outside - - return wrapper diff --git a/tests/metagpt/serialize_deserialize/test_role.py b/tests/metagpt/serialize_deserialize/test_role.py index 87cf75caa..88c7f7d8b 100644 --- a/tests/metagpt/serialize_deserialize/test_role.py +++ b/tests/metagpt/serialize_deserialize/test_role.py @@ -15,7 +15,7 @@ from metagpt.roles.engineer import Engineer from metagpt.roles.product_manager import ProductManager from metagpt.roles.role import Role from metagpt.schema import Message -from metagpt.utils.utils import format_trackback_info +from metagpt.utils.common import format_trackback_info from tests.metagpt.serialize_deserialize.test_serdeser_base import RoleA, RoleB, RoleC, serdeser_path From b3750d5947894779fbaff392b242e722e57a05d6 Mon Sep 17 00:00:00 2001 From: geekan Date: Wed, 20 Dec 2023 11:52:11 +0800 Subject: [PATCH 084/167] refine code for prepare document. remove useless logic --- metagpt/actions/prepare_documents.py | 29 ++++++++++++++-------------- 1 file changed, 15 insertions(+), 14 deletions(-) diff --git a/metagpt/actions/prepare_documents.py b/metagpt/actions/prepare_documents.py index 8d3445ae4..3c0885954 100644 --- a/metagpt/actions/prepare_documents.py +++ b/metagpt/actions/prepare_documents.py @@ -12,28 +12,29 @@ from pathlib import Path from metagpt.actions import Action, ActionOutput from metagpt.config import CONFIG -from metagpt.const import DEFAULT_WORKSPACE_ROOT, DOCS_FILE_REPO, REQUIREMENT_FILENAME +from metagpt.const import DOCS_FILE_REPO, REQUIREMENT_FILENAME from metagpt.schema import Document from metagpt.utils.file_repository import FileRepository from metagpt.utils.git_repository import GitRepository class PrepareDocuments(Action): - def __init__(self, name="", context=None, llm=None): - super().__init__(name, context, llm) + """PrepareDocuments Action: initialize project folder and add new requirements to docs/requirements.txt.""" + + def _init_repo(self): + """Initialize the Git environment.""" + path = CONFIG.project_path + if not path: + name = CONFIG.project_name or FileRepository.new_filename() + path = Path(CONFIG.workspace_path) / name + + if path.exists() and not CONFIG.inc: + shutil.rmtree(path) + CONFIG.git_repo = GitRepository(local_path=path, auto_init=True) async def run(self, with_messages, **kwargs): - if not CONFIG.git_repo: - # Create and initialize the workspace folder, initialize the Git environment. - project_name = CONFIG.project_name or FileRepository.new_filename() - workdir = CONFIG.project_path - if not workdir and CONFIG.workspace_path: - workdir = Path(CONFIG.workspace_path) / project_name - workdir = Path(workdir or DEFAULT_WORKSPACE_ROOT / project_name) - if not CONFIG.inc and workdir.exists(): - shutil.rmtree(workdir) - CONFIG.git_repo = GitRepository() - CONFIG.git_repo.open(local_path=workdir, auto_init=True) + """Create and initialize the workspace folder, initialize the Git environment.""" + self._init_repo() # Write the newly added requirements from the main parameter idea to `docs/requirement.txt`. doc = Document(root_path=DOCS_FILE_REPO, filename=REQUIREMENT_FILENAME, content=with_messages[0].content) From f365348f49815c85fe4ca163647e66ad56ccd73f Mon Sep 17 00:00:00 2001 From: geekan Date: Wed, 20 Dec 2023 11:59:59 +0800 Subject: [PATCH 085/167] add .pylintrc --- docs/.pylintrc | 639 +++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 639 insertions(+) create mode 100644 docs/.pylintrc diff --git a/docs/.pylintrc b/docs/.pylintrc new file mode 100644 index 000000000..9e8488bc7 --- /dev/null +++ b/docs/.pylintrc @@ -0,0 +1,639 @@ +[MAIN] + +# Analyse import fallback blocks. This can be used to support both Python 2 and +# 3 compatible code, which means that the block might have code that exists +# only in one or another interpreter, leading to false positives when analysed. +analyse-fallback-blocks=no + +# Clear in-memory caches upon conclusion of linting. Useful if running pylint +# in a server-like mode. +clear-cache-post-run=no + +# Load and enable all available extensions. Use --list-extensions to see a list +# all available extensions. +#enable-all-extensions= + +# In error mode, messages with a category besides ERROR or FATAL are +# suppressed, and no reports are done by default. Error mode is compatible with +# disabling specific errors. +#errors-only= + +# Always return a 0 (non-error) status code, even if lint errors are found. +# This is primarily useful in continuous integration scripts. +#exit-zero= + +# A comma-separated list of package or module names from where C extensions may +# be loaded. Extensions are loading into the active Python interpreter and may +# run arbitrary code. +extension-pkg-allow-list= + +# A comma-separated list of package or module names from where C extensions may +# be loaded. Extensions are loading into the active Python interpreter and may +# run arbitrary code. (This is an alternative name to extension-pkg-allow-list +# for backward compatibility.) +extension-pkg-whitelist=pydantic + +# Return non-zero exit code if any of these messages/categories are detected, +# even if score is above --fail-under value. Syntax same as enable. Messages +# specified are enabled, while categories only check already-enabled messages. +fail-on= + +# Specify a score threshold under which the program will exit with error. +fail-under=10 + +# Interpret the stdin as a python script, whose filename needs to be passed as +# the module_or_package argument. +#from-stdin= + +# Files or directories to be skipped. They should be base names, not paths. +ignore=CVS + +# Add files or directories matching the regular expressions patterns to the +# ignore-list. The regex matches against paths and can be in Posix or Windows +# format. Because '\\' represents the directory delimiter on Windows systems, +# it can't be used as an escape character. +ignore-paths= + +# Files or directories matching the regular expression patterns are skipped. +# The regex matches against base names, not paths. The default value ignores +# Emacs file locks +#ignore-patterns=^\.# +ignore-patterns=(.)*_test\.py,test_(.)*\.py + + +# List of module names for which member attributes should not be checked +# (useful for modules/projects where namespaces are manipulated during runtime +# and thus existing member attributes cannot be deduced by static analysis). It +# supports qualified module names, as well as Unix pattern matching. +ignored-modules= + +# Python code to execute, usually for sys.path manipulation such as +# pygtk.require(). +#init-hook= + +# Use multiple processes to speed up Pylint. Specifying 0 will auto-detect the +# number of processors available to use, and will cap the count on Windows to +# avoid hangs. +jobs=1 + +# Control the amount of potential inferred values when inferring a single +# object. This can help the performance when dealing with large functions or +# complex, nested conditions. +limit-inference-results=120 + +# List of plugins (as comma separated values of python module names) to load, +# usually to register additional checkers. +load-plugins= + +# Pickle collected data for later comparisons. +persistent=yes + +# Minimum Python version to use for version dependent checks. Will default to +# the version used to run pylint. +py-version=3.9 + +# Discover python modules and packages in the file system subtree. +recursive=no + +# Add paths to the list of the source roots. Supports globbing patterns. The +# source root is an absolute path or a path relative to the current working +# directory used to determine a package namespace for modules located under the +# source root. +source-roots= + +# When enabled, pylint would attempt to guess common misconfiguration and emit +# user-friendly hints instead of false-positive error messages. +suggestion-mode=yes + +# Allow loading of arbitrary C extensions. Extensions are imported into the +# active Python interpreter and may run arbitrary code. +unsafe-load-any-extension=no + +# In verbose mode, extra non-checker-related info will be displayed. +#verbose= + + +[BASIC] + +# Naming style matching correct argument names. +argument-naming-style=snake_case + +# Regular expression matching correct argument names. Overrides argument- +# naming-style. If left empty, argument names will be checked with the set +# naming style. +#argument-rgx= + +# Naming style matching correct attribute names. +attr-naming-style=snake_case + +# Regular expression matching correct attribute names. Overrides attr-naming- +# style. If left empty, attribute names will be checked with the set naming +# style. +#attr-rgx= + +# Bad variable names which should always be refused, separated by a comma. +bad-names=foo, + bar, + baz, + toto, + tutu, + tata + +# Bad variable names regexes, separated by a comma. If names match any regex, +# they will always be refused +bad-names-rgxs= + +# Naming style matching correct class attribute names. +class-attribute-naming-style=any + +# Regular expression matching correct class attribute names. Overrides class- +# attribute-naming-style. If left empty, class attribute names will be checked +# with the set naming style. +#class-attribute-rgx= + +# Naming style matching correct class constant names. +class-const-naming-style=UPPER_CASE + +# Regular expression matching correct class constant names. Overrides class- +# const-naming-style. If left empty, class constant names will be checked with +# the set naming style. +#class-const-rgx= + +# Naming style matching correct class names. +class-naming-style=PascalCase + +# Regular expression matching correct class names. Overrides class-naming- +# style. If left empty, class names will be checked with the set naming style. +#class-rgx= + +# Naming style matching correct constant names. +const-naming-style=UPPER_CASE + +# Regular expression matching correct constant names. Overrides const-naming- +# style. If left empty, constant names will be checked with the set naming +# style. +#const-rgx= + +# Minimum line length for functions/classes that require docstrings, shorter +# ones are exempt. +docstring-min-length=-1 + +# Naming style matching correct function names. +function-naming-style=snake_case + +# Regular expression matching correct function names. Overrides function- +# naming-style. If left empty, function names will be checked with the set +# naming style. +#function-rgx= + +# Good variable names which should always be accepted, separated by a comma. +good-names=i, + j, + k, + v, + e, + d, + m, + df, + ex, + Run, + _ + +# Good variable names regexes, separated by a comma. If names match any regex, +# they will always be accepted +good-names-rgxs= + +# Include a hint for the correct naming format with invalid-name. +include-naming-hint=no + +# Naming style matching correct inline iteration names. +inlinevar-naming-style=any + +# Regular expression matching correct inline iteration names. Overrides +# inlinevar-naming-style. If left empty, inline iteration names will be checked +# with the set naming style. +#inlinevar-rgx= + +# Naming style matching correct method names. +method-naming-style=snake_case + +# Regular expression matching correct method names. Overrides method-naming- +# style. If left empty, method names will be checked with the set naming style. +#method-rgx= + +# Naming style matching correct module names. +module-naming-style=snake_case + +# Regular expression matching correct module names. Overrides module-naming- +# style. If left empty, module names will be checked with the set naming style. +#module-rgx= + +# Colon-delimited sets of names that determine each other's naming style when +# the name regexes allow several styles. +name-group= + +# Regular expression which should only match function or class names that do +# not require a docstring. +no-docstring-rgx=^_ + +# List of decorators that produce properties, such as abc.abstractproperty. Add +# to this list to register other decorators that produce valid properties. +# These decorators are taken in consideration only for invalid-name. +property-classes=abc.abstractproperty + +# Regular expression matching correct type alias names. If left empty, type +# alias names will be checked with the set naming style. +#typealias-rgx= + +# Regular expression matching correct type variable names. If left empty, type +# variable names will be checked with the set naming style. +#typevar-rgx= + +# Naming style matching correct variable names. +variable-naming-style=snake_case + +# Regular expression matching correct variable names. Overrides variable- +# naming-style. If left empty, variable names will be checked with the set +# naming style. +#variable-rgx= + + +[CLASSES] + +# Warn about protected attribute access inside special methods +check-protected-access-in-special-methods=no + +# List of method names used to declare (i.e. assign) instance attributes. +defining-attr-methods=__init__, + __new__, + setUp, + __post_init__ + +# List of member names, which should be excluded from the protected access +# warning. +exclude-protected=_asdict,_fields,_replace,_source,_make,os._exit + +# List of valid names for the first argument in a class method. +valid-classmethod-first-arg=cls + +# List of valid names for the first argument in a metaclass class method. +valid-metaclass-classmethod-first-arg=mcs + + +[DESIGN] + +# List of regular expressions of class ancestor names to ignore when counting +# public methods (see R0903) +exclude-too-few-public-methods= + +# List of qualified class names to ignore when counting class parents (see +# R0901) +ignored-parents= + +# Maximum number of arguments for function / method. +max-args=5 + +# Maximum number of attributes for a class (see R0902). +max-attributes=7 + +# Maximum number of boolean expressions in an if statement (see R0916). +max-bool-expr=5 + +# Maximum number of branch for function / method body. +max-branches=12 + +# Maximum number of locals for function / method body. +max-locals=15 + +# Maximum number of parents for a class (see R0901). +max-parents=7 + +# Maximum number of public methods for a class (see R0904). +max-public-methods=20 + +# Maximum number of return / yield for function / method body. +max-returns=6 + +# Maximum number of statements in function / method body. +max-statements=50 + +# Minimum number of public methods for a class (see R0903). +min-public-methods=2 + + +[EXCEPTIONS] + +# Exceptions that will emit a warning when caught. +overgeneral-exceptions=builtins.BaseException,builtins.Exception + + +[FORMAT] + +# Expected format of line ending, e.g. empty (any line ending), LF or CRLF. +expected-line-ending-format= + +# Regexp for a line that is allowed to be longer than the limit. +ignore-long-lines=^\s*(# )??$ + +# Number of spaces of indent required inside a hanging or continued line. +indent-after-paren=4 + +# String used as indentation unit. This is usually " " (4 spaces) or "\t" (1 +# tab). +indent-string=' ' + +# Maximum number of characters on a single line. +max-line-length=120 + +# Maximum number of lines in a module. +max-module-lines=1000 + +# Allow the body of a class to be on the same line as the declaration if body +# contains single statement. +single-line-class-stmt=no + +# Allow the body of an if to be on the same line as the test if there is no +# else. +single-line-if-stmt=no + + +[IMPORTS] + +# List of modules that can be imported at any level, not just the top level +# one. +allow-any-import-level= + +# Allow explicit reexports by alias from a package __init__. +allow-reexport-from-package=no + +# Allow wildcard imports from modules that define __all__. +allow-wildcard-with-all=no + +# Deprecated modules which should not be used, separated by a comma. +deprecated-modules= + +# Output a graph (.gv or any supported image format) of external dependencies +# to the given file (report RP0402 must not be disabled). +ext-import-graph= + +# Output a graph (.gv or any supported image format) of all (i.e. internal and +# external) dependencies to the given file (report RP0402 must not be +# disabled). +import-graph= + +# Output a graph (.gv or any supported image format) of internal dependencies +# to the given file (report RP0402 must not be disabled). +int-import-graph= + +# Force import order to recognize a module as part of the standard +# compatibility libraries. +known-standard-library= + +# Force import order to recognize a module as part of a third party library. +known-third-party=enchant + +# Couples of modules and preferred modules, separated by a comma. +preferred-modules= + + +[LOGGING] + +# The type of string formatting that logging methods do. `old` means using % +# formatting, `new` is for `{}` formatting. +logging-format-style=old + +# Logging modules to check that the string format arguments are in logging +# function parameter format. +logging-modules=logging + + +[MESSAGES CONTROL] + +# Only show warnings with the listed confidence levels. Leave empty to show +# all. Valid levels: HIGH, CONTROL_FLOW, INFERENCE, INFERENCE_FAILURE, +# UNDEFINED. +confidence=HIGH, + CONTROL_FLOW, + INFERENCE, + INFERENCE_FAILURE, + UNDEFINED + +# Disable the message, report, category or checker with the given id(s). You +# can either give multiple identifiers separated by comma (,) or put this +# option multiple times (only on the command line, not in the configuration +# file where it should appear only once). You can also use "--disable=all" to +# disable everything first and then re-enable specific checks. For example, if +# you want to run only the similarities checker, you can use "--disable=all +# --enable=similarities". If you want to run only the classes checker, but have +# no Warning level messages displayed, use "--disable=all --enable=classes +# --disable=W". +disable=raw-checker-failed, + bad-inline-option, + locally-disabled, + file-ignored, + suppressed-message, + useless-suppression, + deprecated-pragma, + use-symbolic-message-instead, + expression-not-assigned, + pointless-statement + +# Enable the message, report, category or checker with the given id(s). You can +# either give multiple identifier separated by comma (,) or put this option +# multiple time (only on the command line, not in the configuration file where +# it should appear only once). See also the "--disable" option for examples. +enable=c-extension-no-member + + +[METHOD_ARGS] + +# List of qualified names (i.e., library.method) which require a timeout +# parameter e.g. 'requests.api.get,requests.api.post' +timeout-methods=requests.api.delete,requests.api.get,requests.api.head,requests.api.options,requests.api.patch,requests.api.post,requests.api.put,requests.api.request + + +[MISCELLANEOUS] + +# List of note tags to take in consideration, separated by a comma. +notes=FIXME, + XXX, + TODO + +# Regular expression of note tags to take in consideration. +notes-rgx= + + +[REFACTORING] + +# Maximum number of nested blocks for function / method body +max-nested-blocks=5 + +# Complete name of functions that never returns. When checking for +# inconsistent-return-statements if a never returning function is called then +# it will be considered as an explicit return statement and no message will be +# printed. +never-returning-functions=sys.exit,argparse.parse_error + + +[REPORTS] + +# Python expression which should return a score less than or equal to 10. You +# have access to the variables 'fatal', 'error', 'warning', 'refactor', +# 'convention', and 'info' which contain the number of messages in each +# category, as well as 'statement' which is the total number of statements +# analyzed. This score is used by the global evaluation report (RP0004). +evaluation=max(0, 0 if fatal else 10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10)) + +# Template used to display messages. This is a python new-style format string +# used to format the message information. See doc for all details. +msg-template= + +# Set the output format. Available formats are text, parseable, colorized, json +# and msvs (visual studio). You can also give a reporter class, e.g. +# mypackage.mymodule.MyReporterClass. +#output-format= + +# Tells whether to display a full report or only the messages. +reports=no + +# Activate the evaluation score. +score=yes + + +[SIMILARITIES] + +# Comments are removed from the similarity computation +ignore-comments=yes + +# Docstrings are removed from the similarity computation +ignore-docstrings=yes + +# Imports are removed from the similarity computation +ignore-imports=yes + +# Signatures are removed from the similarity computation +ignore-signatures=yes + +# Minimum lines number of a similarity. +min-similarity-lines=4 + + +[SPELLING] + +# Limits count of emitted suggestions for spelling mistakes. +max-spelling-suggestions=4 + +# Spelling dictionary name. No available dictionaries : You need to install +# both the python package and the system dependency for enchant to work.. +spelling-dict= + +# List of comma separated words that should be considered directives if they +# appear at the beginning of a comment and should not be checked. +spelling-ignore-comment-directives=fmt: on,fmt: off,noqa:,noqa,nosec,isort:skip,mypy: + +# List of comma separated words that should not be checked. +spelling-ignore-words= + +# A path to a file that contains the private dictionary; one word per line. +spelling-private-dict-file= + +# Tells whether to store unknown words to the private dictionary (see the +# --spelling-private-dict-file option) instead of raising a message. +spelling-store-unknown-words=no + + +[STRING] + +# This flag controls whether inconsistent-quotes generates a warning when the +# character used as a quote delimiter is used inconsistently within a module. +check-quote-consistency=no + +# This flag controls whether the implicit-str-concat should generate a warning +# on implicit string concatenation in sequences defined over several lines. +check-str-concat-over-line-jumps=no + + +[TYPECHECK] + +# List of decorators that produce context managers, such as +# contextlib.contextmanager. Add to this list to register other decorators that +# produce valid context managers. +contextmanager-decorators=contextlib.contextmanager + +# List of members which are set dynamically and missed by pylint inference +# system, and so shouldn't trigger E1101 when accessed. Python regular +# expressions are accepted. +generated-members= + +# Tells whether to warn about missing members when the owner of the attribute +# is inferred to be None. +ignore-none=yes + +# This flag controls whether pylint should warn about no-member and similar +# checks whenever an opaque object is returned when inferring. The inference +# can return multiple potential results while evaluating a Python object, but +# some branches might not be evaluated, which results in partial inference. In +# that case, it might be useful to still emit no-member and other checks for +# the rest of the inferred objects. +ignore-on-opaque-inference=yes + +# List of symbolic message names to ignore for Mixin members. +ignored-checks-for-mixins=no-member, + not-async-context-manager, + not-context-manager, + attribute-defined-outside-init + +# List of class names for which member attributes should not be checked (useful +# for classes with dynamically set attributes). This supports the use of +# qualified names. +ignored-classes=optparse.Values,thread._local,_thread._local,argparse.Namespace + +# Show a hint with possible names when a member name was not found. The aspect +# of finding the hint is based on edit distance. +missing-member-hint=yes + +# The minimum edit distance a name should have in order to be considered a +# similar match for a missing member name. +missing-member-hint-distance=1 + +# The total number of similar names that should be taken in consideration when +# showing a hint for a missing member. +missing-member-max-choices=1 + +# Regex pattern to define which classes are considered mixins. +mixin-class-rgx=.*[Mm]ixin + +# List of decorators that change the signature of a decorated function. +signature-mutators= + + +[VARIABLES] + +# List of additional names supposed to be defined in builtins. Remember that +# you should avoid defining new builtins when possible. +additional-builtins= + +# Tells whether unused global variables should be treated as a violation. +allow-global-unused-variables=yes + +# List of names allowed to shadow builtins +allowed-redefined-builtins= + +# List of strings which can identify a callback function by name. A callback +# name must start or end with one of those strings. +callbacks=cb_, + _cb + +# A regular expression matching the name of dummy variables (i.e. expected to +# not be used). +dummy-variables-rgx=_+$|(_[a-zA-Z0-9_]*[a-zA-Z0-9]+?$)|dummy|^ignored_|^unused_ + +# Argument names that match this expression will be ignored. +ignored-argument-names=_.*|^ignored_|^unused_ + +# Tells whether we should check for unused import in __init__ files. +init-import=no + +# List of qualified module names which can have objects that can redefine +# builtins. +redefining-builtins-modules=six.moves,past.builtins,future.builtins,builtins,io From 1ab0ae99a90c54b6c8d104684a5127f91710e04c Mon Sep 17 00:00:00 2001 From: geekan Date: Wed, 20 Dec 2023 12:48:57 +0800 Subject: [PATCH 086/167] refine sop --- metagpt/actions/write_prd_an.py | 21 ++++++++++++++------- metagpt/roles/product_manager.py | 4 ++-- 2 files changed, 16 insertions(+), 9 deletions(-) diff --git a/metagpt/actions/write_prd_an.py b/metagpt/actions/write_prd_an.py index edd94a463..8698c739f 100644 --- a/metagpt/actions/write_prd_an.py +++ b/metagpt/actions/write_prd_an.py @@ -26,8 +26,8 @@ PROGRAMMING_LANGUAGE = ActionNode( ORIGINAL_REQUIREMENTS = ActionNode( key="Original Requirements", expected_type=str, - instruction="Place the polished, complete original requirements here.", - example="The game should have a leaderboard and multiple difficulty levels.", + instruction="Place the original user's requirements here.", + example="Create a 2048 game", ) PROJECT_NAME = ActionNode( @@ -41,7 +41,7 @@ PRODUCT_GOALS = ActionNode( key="Product Goals", expected_type=list[str], instruction="Provide up to three clear, orthogonal product goals.", - example=["Create an engaging user experience", "Ensure high performance", "Provide customizable features"], + example=["Create an engaging user experience", "Improve accessibility, be responsive", "More beautiful UI"], ) USER_STORIES = ActionNode( @@ -49,8 +49,11 @@ USER_STORIES = ActionNode( expected_type=list[str], instruction="Provide up to 3 to 5 scenario-based user stories.", example=[ - "As a user, I want to be able to choose difficulty levels", + "As a player, I want to be able to choose difficulty levels", "As a player, I want to see my score after each game", + "As a player, I want to get restart button when I lose", + "As a player, I want to see beautiful UI that make me feel good", + "As a player, I want to play game via mobile phone", ], ) @@ -58,7 +61,11 @@ COMPETITIVE_ANALYSIS = ActionNode( key="Competitive Analysis", expected_type=list[str], instruction="Provide 5 to 7 competitive products.", - example=["Python Snake Game: Simple interface, lacks advanced features"], + example=[ + "2048 Game A: Simple interface, lacks responsive features", + "play2048.co: Beautiful and responsive UI with my best score shown", + "2048game.com: Responsive UI with my best score shown, but many ads", + ], ) COMPETITIVE_QUADRANT_CHART = ActionNode( @@ -86,7 +93,7 @@ REQUIREMENT_ANALYSIS = ActionNode( key="Requirement Analysis", expected_type=str, instruction="Provide a detailed analysis of the requirements.", - example="The product should be user-friendly.", + example="", ) REQUIREMENT_POOL = ActionNode( @@ -107,7 +114,7 @@ ANYTHING_UNCLEAR = ActionNode( key="Anything UNCLEAR", expected_type=str, instruction="Mention any aspects of the project that are unclear and try to clarify them.", - example="...", + example="", ) ISSUE_TYPE = ActionNode( diff --git a/metagpt/roles/product_manager.py b/metagpt/roles/product_manager.py index 7858d2caa..61263cb50 100644 --- a/metagpt/roles/product_manager.py +++ b/metagpt/roles/product_manager.py @@ -28,8 +28,8 @@ class ProductManager(Role): self, name: str = "Alice", profile: str = "Product Manager", - goal: str = "efficiently create a successful product", - constraints: str = "use same language as user requirement", + goal: str = "efficiently create a successful product that meets market demands and user expectations", + constraints: str = "utilize the same language as the user requirements for seamless communication", ) -> None: """ Initializes the ProductManager role with given attributes. From de02894578a4adc5b4de404549d46d2291181899 Mon Sep 17 00:00:00 2001 From: garylin2099 Date: Sun, 17 Dec 2023 13:52:37 +0800 Subject: [PATCH 087/167] patch release v0.5.1 --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 730fffd35..73a05eeae 100644 --- a/setup.py +++ b/setup.py @@ -30,7 +30,7 @@ with open(path.join(here, "requirements.txt"), encoding="utf-8") as f: setup( name="metagpt", - version="0.5.0", + version="0.5.1", description="The Multi-Role Meta Programming Framework", long_description=long_description, long_description_content_type="text/markdown", From e8a848a6145166ef39a7be1e2dd5f8cb4e05a733 Mon Sep 17 00:00:00 2001 From: shenchucheng Date: Sun, 17 Dec 2023 14:41:59 +0800 Subject: [PATCH 088/167] add deprecated warnings for the start_project method --- metagpt/team.py | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/metagpt/team.py b/metagpt/team.py index 383f2da36..9aa89ee2b 100644 --- a/metagpt/team.py +++ b/metagpt/team.py @@ -3,12 +3,13 @@ """ @Time : 2023/5/12 00:30 @Author : alexanderwu -@File : software_company.py +@File : team.py @Modified By: mashenquan, 2023/11/27. Add an archiving operation after completing the project, as specified in Section 2.2.3.3 of RFC 135. """ -from pathlib import Path +from pathlib import Path +import warnings from pydantic import BaseModel, Field from metagpt.actions import UserRequirement @@ -80,7 +81,7 @@ class Team(BaseModel): raise NoMoneyException(CONFIG.total_cost, f"Insufficient funds: {CONFIG.max_budget}") def run_project(self, idea, send_to: str = ""): - """Start a project from publishing user requirement.""" + """Run a project from publishing user requirement.""" self.idea = idea # Human requirement. @@ -88,6 +89,16 @@ class Team(BaseModel): Message(role="Human", content=idea, cause_by=UserRequirement, send_to=send_to or MESSAGE_ROUTE_TO_ALL) ) + def start_project(self, idea, send_to: str = ""): + """ + Deprecated: This method will be removed in the future. + Please use the `run_project` method instead. + """ + warnings.warn("The 'start_project' method is deprecated and will be removed in the future. " + "Please use the 'run_project' method instead.", + DeprecationWarning, stacklevel=2) + return self.run_project(idea=idea, send_to=send_to) + def _save(self): logger.info(self.json(ensure_ascii=False)) From 31f1be98a0aa95a94ae307186143a6258d901a2c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=8E=98=E6=9D=83=20=E9=A9=AC?= Date: Mon, 18 Dec 2023 16:13:21 +0800 Subject: [PATCH 089/167] fixbug: recursive user requirement dead loop --- metagpt/roles/role.py | 9 +++++---- tests/metagpt/test_role.py | 6 +++--- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/metagpt/roles/role.py b/metagpt/roles/role.py index efe3bcbd4..3a8721004 100644 --- a/metagpt/roles/role.py +++ b/metagpt/roles/role.py @@ -27,15 +27,15 @@ from typing import Iterable, Set, Type, Any from pydantic import BaseModel, Field + from metagpt.actions.action import Action, ActionOutput, action_subclass_registry from metagpt.actions.action_node import ActionNode from metagpt.actions.add_requirement import UserRequirement from metagpt.const import SERDESER_PATH -from metagpt.llm import LLM +from metagpt.llm import LLM, HumanProvider from metagpt.logs import logger from metagpt.memory import Memory from metagpt.provider.base_gpt_api import BaseGPTAPI -from metagpt.provider.human_provider import HumanProvider from metagpt.schema import Message, MessageQueue from metagpt.utils.common import any_to_str, read_json_file, write_json_file, import_class, role_raise_decorator from metagpt.utils.repair_llm_raw_output import extract_state_value_from_output @@ -293,8 +293,7 @@ class Role(BaseModel): """Watch Actions of interest. Role will select Messages caused by these Actions from its personal message buffer during _observe. """ - tags = {any_to_str(t) for t in actions} - self._rc.watch.update(tags) + self._rc.watch = {any_to_str(t) for t in actions} # check RoleContext after adding watch actions self._rc.check(self._role_id) @@ -509,6 +508,8 @@ class Role(BaseModel): msg = with_message elif isinstance(with_message, list): msg = Message(content="\n".join(with_message)) + if not msg.cause_by: + msg.cause_by = UserRequirement self.put_message(msg) if not await self._observe(): diff --git a/tests/metagpt/test_role.py b/tests/metagpt/test_role.py index 8fac2503c..611d321fc 100644 --- a/tests/metagpt/test_role.py +++ b/tests/metagpt/test_role.py @@ -14,11 +14,11 @@ import uuid import pytest from pydantic import BaseModel -from metagpt.actions import Action, ActionOutput +from metagpt.actions import Action, ActionOutput, UserRequirement from metagpt.environment import Environment from metagpt.roles import Role from metagpt.schema import Message -from metagpt.utils.common import get_class_name +from metagpt.utils.common import any_to_str, get_class_name class MockAction(Action): @@ -60,7 +60,7 @@ async def test_react(): name=seed.name, profile=seed.profile, goal=seed.goal, constraints=seed.constraints, desc=seed.desc ) role.subscribe({seed.subscription}) - assert role._rc.watch == set({}) + assert role._rc.watch == {any_to_str(UserRequirement)} assert role.name == seed.name assert role.profile == seed.profile assert role._setting.goal == seed.goal From f2e1053b489c2bedca3f05e2487c6913d31fb8f8 Mon Sep 17 00:00:00 2001 From: garylin2099 Date: Mon, 18 Dec 2023 19:26:38 +0800 Subject: [PATCH 090/167] update version and roadmap --- docs/ROADMAP.md | 8 ++++---- setup.py | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/docs/ROADMAP.md b/docs/ROADMAP.md index afc9ff445..3cb03f374 100644 --- a/docs/ROADMAP.md +++ b/docs/ROADMAP.md @@ -30,10 +30,10 @@ ### Tasks 4. Complete the design and implementation of module breakdown 5. Support various modes of memory: clearly distinguish between long-term and short-term memory 6. Perfect the test role, and carry out necessary interactions with humans - 7. Allowing natural communication between roles (expected v0.5.0) + 7. ~~Allowing natural communication between roles~~ (v0.5.0) 8. Implement SkillManager and the process of incremental Skill learning (experimentation done with game agents) 9. Automatically get RPM and configure it by calling the corresponding openai page, so that each key does not need to be manually configured - 10. IMPORTANT: Support incremental development (expected v0.5.0) + 10. ~~IMPORTANT: Support incremental development~~ (v0.5.0) 3. Strategies 1. Support ReAct strategy (experimentation done with game agents) 2. Support CoT strategy (experimentation done with game agents) @@ -45,8 +45,8 @@ ### Tasks 2. Implementation: Knowledge search, supporting 10+ data formats 3. Implementation: Data EDA (expected v0.6.0) 4. Implementation: Review - 5. Implementation: Add Document (expected v0.5.0) - 6. Implementation: Delete Document (expected v0.5.0) + 5. ~~Implementation~~: Add Document (v0.5.0) + 6. ~~Implementation~~: Delete Document (v0.5.0) 7. Implementation: Self-training 8. ~~Implementation: DebugError~~ (v0.2.1) 9. Implementation: Generate reliable unit tests based on YAPI diff --git a/setup.py b/setup.py index 73a05eeae..57290f4cd 100644 --- a/setup.py +++ b/setup.py @@ -30,7 +30,7 @@ with open(path.join(here, "requirements.txt"), encoding="utf-8") as f: setup( name="metagpt", - version="0.5.1", + version="0.5.2", description="The Multi-Role Meta Programming Framework", long_description=long_description, long_description_content_type="text/markdown", From 548e6d5f25d6263f471b3f6a76ffd1749a2213f7 Mon Sep 17 00:00:00 2001 From: geekan Date: Tue, 19 Dec 2023 10:52:16 +0800 Subject: [PATCH 091/167] remove requirements-ocr.txt and place the optional setup to setup.py --- requirements-ocr.txt | 4 ---- setup.py | 1 + 2 files changed, 1 insertion(+), 4 deletions(-) delete mode 100644 requirements-ocr.txt diff --git a/requirements-ocr.txt b/requirements-ocr.txt deleted file mode 100644 index cf6103afc..000000000 --- a/requirements-ocr.txt +++ /dev/null @@ -1,4 +0,0 @@ -paddlepaddle==2.4.2 -paddleocr>=2.0.1 -tabulate==0.9.0 --r requirements.txt diff --git a/setup.py b/setup.py index 57290f4cd..a06530015 100644 --- a/setup.py +++ b/setup.py @@ -48,6 +48,7 @@ setup( "search-google": ["google-api-python-client==2.94.0"], "search-ddg": ["duckduckgo-search==3.8.5"], "pyppeteer": ["pyppeteer>=1.0.2"], + "ocr": ["paddlepaddle==2.4.2", "paddleocr>=2.0.1", "tabulate==0.9.0"], }, cmdclass={ "install_mermaid": InstallMermaidCLI, From 4e6d1a00f87378a04465d43e81d248c7219447cf Mon Sep 17 00:00:00 2001 From: geekan Date: Tue, 19 Dec 2023 11:01:20 +0800 Subject: [PATCH 092/167] use pre-commit --- metagpt/actions/action_node.py | 12 ++++++++++-- metagpt/actions/project_management_an.py | 2 +- metagpt/roles/project_manager.py | 1 + 3 files changed, 12 insertions(+), 3 deletions(-) diff --git a/metagpt/actions/action_node.py b/metagpt/actions/action_node.py index fb7d621d8..9bb12fc84 100644 --- a/metagpt/actions/action_node.py +++ b/metagpt/actions/action_node.py @@ -52,6 +52,7 @@ def dict_to_markdown(d, prefix="-", postfix="\n"): class ActionNode: """ActionNode is a tree of nodes.""" + mode: str # Action Context @@ -70,8 +71,15 @@ class ActionNode: content: str instruct_content: BaseModel - def __init__(self, key: str, expected_type: Type, instruction: str, example: str, content: str = "", - children: dict[str, "ActionNode"] = None): + def __init__( + self, + key: str, + expected_type: Type, + instruction: str, + example: str, + content: str = "", + children: dict[str, "ActionNode"] = None, + ): self.key = key self.expected_type = expected_type self.instruction = instruction diff --git a/metagpt/actions/project_management_an.py b/metagpt/actions/project_management_an.py index 970cb0594..6208c1051 100644 --- a/metagpt/actions/project_management_an.py +++ b/metagpt/actions/project_management_an.py @@ -44,7 +44,7 @@ FULL_API_SPEC = ActionNode( key="Full API spec", expected_type=str, instruction="Describe all APIs using OpenAPI 3.0 spec that may be used by both frontend and backend. If front-end " - "and back-end communication is not required, leave it blank.", + "and back-end communication is not required, leave it blank.", example="openapi: 3.0.0 ...", ) diff --git a/metagpt/roles/project_manager.py b/metagpt/roles/project_manager.py index f98d28cb7..42564cd70 100644 --- a/metagpt/roles/project_manager.py +++ b/metagpt/roles/project_manager.py @@ -30,5 +30,6 @@ class ProjectManager(Role): def __init__(self, **kwargs) -> None: super().__init__(**kwargs) + self._init_actions([WriteTasks]) self._watch([WriteDesign]) From b14b3f4dd9e4a3d4fd2ffef85871e483c61677ca Mon Sep 17 00:00:00 2001 From: geekan Date: Tue, 19 Dec 2023 11:10:17 +0800 Subject: [PATCH 093/167] setup.py: update --- setup.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/setup.py b/setup.py index a06530015..8ef2a6946 100644 --- a/setup.py +++ b/setup.py @@ -31,14 +31,14 @@ with open(path.join(here, "requirements.txt"), encoding="utf-8") as f: setup( name="metagpt", version="0.5.2", - description="The Multi-Role Meta Programming Framework", + description="The Multi-Agent Framework", long_description=long_description, long_description_content_type="text/markdown", url="https://github.com/geekan/MetaGPT", author="Alexander Wu", author_email="alexanderwu@deepwisdom.ai", license="MIT", - keywords="metagpt multi-role multi-agent programming gpt llm metaprogramming", + keywords="metagpt multi-agent multi-role programming gpt llm metaprogramming", packages=find_packages(exclude=["contrib", "docs", "examples", "tests*"]), python_requires=">=3.9", install_requires=requirements, From 2296aea055be706a3d80c2441410aec2f6cd97c9 Mon Sep 17 00:00:00 2001 From: geekan Date: Tue, 19 Dec 2023 11:22:21 +0800 Subject: [PATCH 094/167] delete inspect_module.py because we have ast tree parser --- metagpt/inspect_module.py | 28 ---------------------------- 1 file changed, 28 deletions(-) delete mode 100644 metagpt/inspect_module.py diff --git a/metagpt/inspect_module.py b/metagpt/inspect_module.py deleted file mode 100644 index 48ceffc57..000000000 --- a/metagpt/inspect_module.py +++ /dev/null @@ -1,28 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- -""" -@Time : 2023/5/28 14:54 -@Author : alexanderwu -@File : inspect_module.py -""" - -import inspect - -import metagpt # replace with your module - - -def print_classes_and_functions(module): - """FIXME: NOT WORK..""" - for name, obj in inspect.getmembers(module): - if inspect.isclass(obj): - print(f"Class: {name}") - elif inspect.isfunction(obj): - print(f"Function: {name}") - else: - print(name) - - print(dir(module)) - - -if __name__ == "__main__": - print_classes_and_functions(metagpt) From f371e3a49979e87be1ce64b23b5d094b102cd271 Mon Sep 17 00:00:00 2001 From: geekan Date: Tue, 19 Dec 2023 11:49:06 +0800 Subject: [PATCH 095/167] token_counter: add gpt-3.5-turbo-16k in list and add comment for them --- metagpt/utils/token_counter.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/metagpt/utils/token_counter.py b/metagpt/utils/token_counter.py index 266a53268..ebfb85de7 100644 --- a/metagpt/utils/token_counter.py +++ b/metagpt/utils/token_counter.py @@ -56,6 +56,7 @@ def count_message_tokens(messages, model="gpt-3.5-turbo-0613"): if model in { "gpt-3.5-turbo-0613", "gpt-3.5-turbo-16k-0613", + "gpt-3.5-turbo-16k", "gpt-3.5-turbo-1106", "gpt-4-0314", "gpt-4-32k-0314", @@ -63,7 +64,7 @@ def count_message_tokens(messages, model="gpt-3.5-turbo-0613"): "gpt-4-32k-0613", "gpt-4-1106-preview", }: - tokens_per_message = 3 + tokens_per_message = 3 # # every reply is primed with <|start|>assistant<|message|> tokens_per_name = 1 elif model == "gpt-3.5-turbo-0301": tokens_per_message = 4 # every message follows <|start|>{role/name}\n{content}<|end|>\n From e8cb7991c447ff9e24303111b435ef0c1ebe7051 Mon Sep 17 00:00:00 2001 From: geekan Date: Tue, 19 Dec 2023 11:52:23 +0800 Subject: [PATCH 096/167] openai_api: refine logic --- metagpt/provider/openai_api.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/metagpt/provider/openai_api.py b/metagpt/provider/openai_api.py index a73bb0aa0..86054881e 100644 --- a/metagpt/provider/openai_api.py +++ b/metagpt/provider/openai_api.py @@ -329,7 +329,8 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter): usage["completion_tokens"] = completion_tokens return usage except Exception as e: - logger.error("usage calculation failed!", e) + logger.error(f"{self.model} usage calculation failed!", e) + return {} else: return usage @@ -360,7 +361,7 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter): return results def _update_costs(self, usage: dict): - if CONFIG.calc_usage: + if CONFIG.calc_usage and usage: try: prompt_tokens = int(usage["prompt_tokens"]) completion_tokens = int(usage["completion_tokens"]) From f71753ba0dc7fcfacc3456755a0fa6a19d7b8374 Mon Sep 17 00:00:00 2001 From: geekan Date: Tue, 19 Dec 2023 13:51:51 +0800 Subject: [PATCH 097/167] add function import, avoid "import" --- metagpt/utils/common.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/metagpt/utils/common.py b/metagpt/utils/common.py index c909180cc..6301cd6a3 100644 --- a/metagpt/utils/common.py +++ b/metagpt/utils/common.py @@ -25,8 +25,9 @@ from pathlib import Path from typing import Any from typing import List, Tuple, Union +import loguru from pydantic.json import pydantic_encoder -from tenacity import _utils +from tenacity import RetryCallState, _utils from metagpt.const import MESSAGE_ROUTE_TO_ALL from metagpt.logs import logger From 8f649252900a8f1e7977cdf2eea8da9a8d4518dc Mon Sep 17 00:00:00 2001 From: geekan Date: Tue, 19 Dec 2023 14:17:54 +0800 Subject: [PATCH 098/167] refine utils code --- metagpt/utils/common.py | 51 ++++++++++++++++++++++++------------ tests/metagpt/test_role.py | 8 +++--- tests/metagpt/test_schema.py | 9 +++---- 3 files changed, 42 insertions(+), 26 deletions(-) diff --git a/metagpt/utils/common.py b/metagpt/utils/common.py index 6301cd6a3..08df480ee 100644 --- a/metagpt/utils/common.py +++ b/metagpt/utils/common.py @@ -301,9 +301,6 @@ class NoMoneyException(Exception): def print_members(module, indent=0): """ https://stackoverflow.com/questions/1796180/how-can-i-get-a-list-of-all-classes-within-current-module-in-python - :param module: - :param indent: - :return: """ prefix = " " * indent for name, obj in inspect.getmembers(module): @@ -321,6 +318,7 @@ def print_members(module, indent=0): def parse_recipient(text): + # FIXME: use ActionNode instead. pattern = r"## Send To:\s*([A-Za-z]+)\s*?" # hard code for now recipient = re.search(pattern, text) if recipient: @@ -337,18 +335,12 @@ def get_class_name(cls) -> str: return f"{cls.__module__}.{cls.__name__}" -def get_object_name(obj) -> str: - """Return class name of the object""" - cls = type(obj) - return f"{cls.__module__}.{cls.__name__}" - - -def any_to_str(val) -> str: +def any_to_str(val: str | typing.Callable) -> str: """Return the class name or the class name of the object, or 'val' if it's a string type.""" if isinstance(val, str): return val if not callable(val): - return get_object_name(val) + return get_class_name(type(val)) return get_class_name(val) @@ -356,32 +348,57 @@ def any_to_str(val) -> str: def any_to_str_set(val) -> set: """Convert any type to string set.""" res = set() - if isinstance(val, dict) or isinstance(val, list) or isinstance(val, set) or isinstance(val, tuple): + + # Check if the value is iterable, but not a string (since strings are technically iterable) + if isinstance(val, (dict, list, set, tuple)): + # Special handling for dictionaries to iterate over values + if isinstance(val, dict): + val = val.values() + for i in val: res.add(any_to_str(i)) else: res.add(any_to_str(val)) + return res -def is_subscribed(message, tags): +def is_subscribed(message: "Message", tags: set): """Return whether it's consumer""" if MESSAGE_ROUTE_TO_ALL in message.send_to: return True - for t in tags: - if t in message.send_to: + for i in tags: + if i in message.send_to: return True return False -def general_after_log(logger: "loguru.Logger", sec_format: str = "%0.3f") -> typing.Callable[["RetryCallState"], None]: +def general_after_log(i: "loguru.Logger", sec_format: str = "%0.3f") -> typing.Callable[["RetryCallState"], None]: + """ + Generates a logging function to be used after a call is retried. + + This generated function logs an error message with the outcome of the retried function call. It includes + the name of the function, the time taken for the call in seconds (formatted according to `sec_format`), + the number of attempts made, and the exception raised, if any. + + :param i: A Logger instance from the loguru library used to log the error message. + :param sec_format: A string format specifier for how to format the number of seconds since the start of the call. + Defaults to three decimal places. + :return: A callable that accepts a RetryCallState object and returns None. This callable logs the details + of the retried call. + """ + def log_it(retry_state: "RetryCallState") -> None: + # If the function name is not known, default to "" if retry_state.fn is None: fn_name = "" else: + # Retrieve the callable's name using a utility function fn_name = _utils.get_callback_name(retry_state.fn) - logger.error( + + # Log an error message with the function name, time since start, attempt number, and the exception + i.error( f"Finished call to '{fn_name}' after {sec_format % retry_state.seconds_since_start}(s), " f"this was the {_utils.to_ordinal(retry_state.attempt_number)} time calling it. " f"exp: {retry_state.outcome.exception()}" diff --git a/tests/metagpt/test_role.py b/tests/metagpt/test_role.py index 611d321fc..dbe45130d 100644 --- a/tests/metagpt/test_role.py +++ b/tests/metagpt/test_role.py @@ -18,7 +18,7 @@ from metagpt.actions import Action, ActionOutput, UserRequirement from metagpt.environment import Environment from metagpt.roles import Role from metagpt.schema import Message -from metagpt.utils.common import any_to_str, get_class_name +from metagpt.utils.common import any_to_str class MockAction(Action): @@ -88,13 +88,13 @@ async def test_react(): @pytest.mark.asyncio async def test_msg_to(): m = Message(content="a", send_to=["a", MockRole, Message]) - assert m.send_to == set({"a", get_class_name(MockRole), get_class_name(Message)}) + assert m.send_to == {"a", any_to_str(MockRole), any_to_str(Message)} m = Message(content="a", cause_by=MockAction, send_to={"a", MockRole, Message}) - assert m.send_to == set({"a", get_class_name(MockRole), get_class_name(Message)}) + assert m.send_to == {"a", any_to_str(MockRole), any_to_str(Message)} m = Message(content="a", send_to=("a", MockRole, Message)) - assert m.send_to == set({"a", get_class_name(MockRole), get_class_name(Message)}) + assert m.send_to == {"a", any_to_str(MockRole), any_to_str(Message)} if __name__ == "__main__": diff --git a/tests/metagpt/test_schema.py b/tests/metagpt/test_schema.py index 10343c192..c8602d953 100644 --- a/tests/metagpt/test_schema.py +++ b/tests/metagpt/test_schema.py @@ -16,8 +16,7 @@ from metagpt.schema import AIMessage, Message, SystemMessage, UserMessage from metagpt.actions.action_output import ActionOutput from metagpt.actions.write_code import WriteCode from metagpt.utils.serialize import serialize_general_message, deserialize_general_message - -from metagpt.utils.common import get_class_name +from metagpt.utils.common import any_to_str @pytest.mark.asyncio @@ -58,9 +57,9 @@ def test_message(): m.cause_by = "Message" assert m.cause_by == "Message" m.cause_by = Action - assert m.cause_by == get_class_name(Action) + assert m.cause_by == any_to_str(Action) m.cause_by = Action() - assert m.cause_by == get_class_name(Action) + assert m.cause_by == any_to_str(Action) m.content = "b" assert m.content == "b" @@ -71,7 +70,7 @@ def test_routes(): m.send_to = "b" assert m.send_to == {"b"} m.send_to = {"e", Action} - assert m.send_to == {"e", get_class_name(Action)} + assert m.send_to == {"e", any_to_str(Action)} def test_message_serdeser(): From 5c341cb05383685c3d4403d22495850af33c8b3f Mon Sep 17 00:00:00 2001 From: geekan Date: Tue, 19 Dec 2023 16:16:52 +0800 Subject: [PATCH 099/167] refine code: use handle_exception function instead of in-function duplicate code frags --- metagpt/actions/action_node.py | 2 +- metagpt/actions/run_code.py | 30 ++++----- metagpt/config.py | 1 + metagpt/repo_parser.py | 19 ++++-- metagpt/schema.py | 78 ++++++++-------------- metagpt/tools/search_engine_meilisearch.py | 12 ++-- metagpt/utils/common.py | 10 +++ metagpt/utils/custom_decoder.py | 2 +- metagpt/utils/dependency_file.py | 20 ++---- metagpt/utils/exceptions.py | 59 ++++++++++++++++ metagpt/utils/file.py | 45 ++++++------- metagpt/utils/file_repository.py | 11 +-- 12 files changed, 159 insertions(+), 130 deletions(-) create mode 100644 metagpt/utils/exceptions.py diff --git a/metagpt/actions/action_node.py b/metagpt/actions/action_node.py index 9bb12fc84..6f1215920 100644 --- a/metagpt/actions/action_node.py +++ b/metagpt/actions/action_node.py @@ -43,7 +43,7 @@ Fill in the above nodes based on the format example. """ -def dict_to_markdown(d, prefix="-", postfix="\n"): +def dict_to_markdown(d, prefix="###", postfix="\n"): markdown_str = "" for key, value in d.items(): markdown_str += f"{prefix} {key}: {value}{postfix}" diff --git a/metagpt/actions/run_code.py b/metagpt/actions/run_code.py index fa13a0980..1b9fd252f 100644 --- a/metagpt/actions/run_code.py +++ b/metagpt/actions/run_code.py @@ -16,13 +16,13 @@ class. """ import subprocess -import traceback from typing import Tuple from metagpt.actions.action import Action from metagpt.config import CONFIG from metagpt.logs import logger from metagpt.schema import RunCodeResult +from metagpt.utils.exceptions import handle_exception PROMPT_TEMPLATE = """ Role: You are a senior development and qa engineer, your role is summarize the code running result. @@ -78,15 +78,12 @@ class RunCode(Action): super().__init__(name, context, llm) @classmethod + @handle_exception async def run_text(cls, code) -> Tuple[str, str]: - try: - # We will document_store the result in this dictionary - namespace = {} - exec(code, namespace) - return namespace.get("result", ""), "" - except Exception: - # If there is an error in the code, return the error message - return "", traceback.format_exc() + # We will document_store the result in this dictionary + namespace = {} + exec(code, namespace) + return namespace.get("result", ""), "" @classmethod async def run_script(cls, working_directory, additional_python_paths=[], command=[]) -> Tuple[str, str]: @@ -145,18 +142,17 @@ class RunCode(Action): rsp = await self._aask(prompt) return RunCodeResult(summary=rsp, stdout=outs, stderr=errs) + @staticmethod + @handle_exception(exception_type=subprocess.CalledProcessError) + def _install_via_subprocess(cmd, check, cwd, env): + return subprocess.run(cmd, check=check, cwd=cwd, env=env) + @staticmethod def _install_dependencies(working_directory, env): install_command = ["python", "-m", "pip", "install", "-r", "requirements.txt"] logger.info(" ".join(install_command)) - try: - subprocess.run(install_command, check=True, cwd=working_directory, env=env) - except subprocess.CalledProcessError as e: - logger.warning(f"{e}") + RunCode._install_via_subprocess(install_command, check=True, cwd=working_directory, env=env) install_pytest_command = ["python", "-m", "pip", "install", "pytest"] logger.info(" ".join(install_pytest_command)) - try: - subprocess.run(install_pytest_command, check=True, cwd=working_directory, env=env) - except subprocess.CalledProcessError as e: - logger.warning(f"{e}") + RunCode._install_via_subprocess(install_pytest_command, check=True, cwd=working_directory, env=env) diff --git a/metagpt/config.py b/metagpt/config.py index d7f5c1249..d6e6d8b88 100644 --- a/metagpt/config.py +++ b/metagpt/config.py @@ -139,6 +139,7 @@ class Config(metaclass=Singleton): continue configs.update(yaml_data) OPTIONS.set(configs) + logger.info(f"Default OpenAI API Model: {self.openai_api_model}") @staticmethod def _get(*args, **kwargs): diff --git a/metagpt/repo_parser.py b/metagpt/repo_parser.py index b84dbab9a..9a1218ef1 100644 --- a/metagpt/repo_parser.py +++ b/metagpt/repo_parser.py @@ -15,17 +15,17 @@ from pydantic import BaseModel, Field from metagpt.config import CONFIG from metagpt.logs import logger +from metagpt.utils.exceptions import handle_exception class RepoParser(BaseModel): base_directory: Path = Field(default=None) - def parse_file(self, file_path): + @classmethod + @handle_exception(exception_type=Exception, default_return=[]) + def _parse_file(cls, file_path: Path) -> list: """Parse a Python file in the repository.""" - try: - return ast.parse(file_path.read_text()).body - except: - return [] + return ast.parse(file_path.read_text()).body def extract_class_and_function_info(self, tree, file_path): """Extract class, function, and global variable information from the AST.""" @@ -52,7 +52,7 @@ class RepoParser(BaseModel): files_classes = [] directory = self.base_directory for path in directory.rglob("*.py"): - tree = self.parse_file(path) + tree = self._parse_file(path) file_info = self.extract_class_and_function_info(tree, path) files_classes.append(file_info) @@ -90,5 +90,10 @@ def main(): logger.info(pformat(symbols)) +def error(): + """raise Exception and logs it""" + RepoParser._parse_file(Path("test.py")) + + if __name__ == "__main__": - main() + error() diff --git a/metagpt/schema.py b/metagpt/schema.py index 1c1fdd94d..c026ea1d9 100644 --- a/metagpt/schema.py +++ b/metagpt/schema.py @@ -22,7 +22,7 @@ import uuid from asyncio import Queue, QueueEmpty, wait_for from json import JSONDecodeError from pathlib import Path -from typing import Dict, List, Set, TypedDict, Optional, Any +from typing import Dict, List, Optional, Set, Type, TypedDict, TypeVar, Any from pydantic import BaseModel, Field @@ -39,6 +39,7 @@ from metagpt.logs import logger from metagpt.utils.common import any_to_str, any_to_str_set, import_class from metagpt.utils.serialize import actionoutout_schema_to_mapping, actionoutput_mapping_to_str, \ actionoutput_str_to_mapping +from metagpt.utils.exceptions import handle_exception class RawMessage(TypedDict): @@ -163,14 +164,11 @@ class Message(BaseModel): return self.json(exclude_none=True) @staticmethod + @handle_exception(exception_type=JSONDecodeError, default_return=None) def load(val): """Convert the json string to object.""" - try: - d = json.loads(val) - return Message(**d) - except JSONDecodeError as err: - logger.error(f"parse json failed: {val}, error:{err}") - return None + d = json.loads(val) + return Message(**d) class UserMessage(Message): @@ -265,50 +263,46 @@ class MessageQueue(BaseModel): return json.dumps(lst) @staticmethod - def load(self, v) -> MessageQueue: + def load(i) -> "MessageQueue": """Convert the json string to the `MessageQueue` object.""" - q = MessageQueue() + queue = MessageQueue() try: - lst = json.loads(v) + lst = json.loads(i) for i in lst: msg = Message(**i) - q.push(msg) + queue.push(msg) except JSONDecodeError as e: - logger.warning(f"JSON load failed: {v}, error:{e}") + logger.warning(f"JSON load failed: {i}, error:{e}") - return q + return queue -class CodingContext(BaseModel): +# 定义一个泛型类型变量 +T = TypeVar("T", bound="BaseModel") + + +class BaseContext(BaseModel): + @staticmethod + @handle_exception + def loads(val: str, cls: Type[T]) -> Optional[T]: + m = json.loads(val) + return cls(**m) + + +class CodingContext(BaseContext): filename: str design_doc: Optional[Document] task_doc: Optional[Document] code_doc: Optional[Document] - @staticmethod - def loads(val: str) -> CodingContext | None: - try: - m = json.loads(val) - return CodingContext(**m) - except Exception: - return None - -class TestingContext(BaseModel): +class TestingContext(BaseContext): filename: str code_doc: Document test_doc: Optional[Document] - @staticmethod - def loads(val: str) -> TestingContext | None: - try: - m = json.loads(val) - return TestingContext(**m) - except Exception: - return None - -class RunCodeContext(BaseModel): +class RunCodeContext(BaseContext): mode: str = "script" code: Optional[str] code_filename: str = "" @@ -320,28 +314,12 @@ class RunCodeContext(BaseModel): output_filename: Optional[str] output: Optional[str] - @staticmethod - def loads(val: str) -> RunCodeContext | None: - try: - m = json.loads(val) - return RunCodeContext(**m) - except Exception: - return None - -class RunCodeResult(BaseModel): +class RunCodeResult(BaseContext): summary: str stdout: str stderr: str - @staticmethod - def loads(val: str) -> RunCodeResult | None: - try: - m = json.loads(val) - return RunCodeResult(**m) - except Exception: - return None - class CodeSummarizeContext(BaseModel): design_filename: str = "" @@ -365,5 +343,5 @@ class CodeSummarizeContext(BaseModel): return hash((self.design_filename, self.task_filename)) -class BugFixContext(BaseModel): +class BugFixContext(BaseContext): filename: str = "" diff --git a/metagpt/tools/search_engine_meilisearch.py b/metagpt/tools/search_engine_meilisearch.py index f7c1c685a..ea6db4dbd 100644 --- a/metagpt/tools/search_engine_meilisearch.py +++ b/metagpt/tools/search_engine_meilisearch.py @@ -11,6 +11,8 @@ from typing import List import meilisearch from meilisearch.index import Index +from metagpt.utils.exceptions import handle_exception + class DataSource: def __init__(self, name: str, url: str): @@ -34,11 +36,7 @@ class MeilisearchEngine: index.add_documents(documents) self.set_index(index) + @handle_exception(exception_type=Exception, default_return=[]) def search(self, query): - try: - search_results = self._index.search(query) - return search_results["hits"] - except Exception as e: - # Handle MeiliSearch API errors - print(f"MeiliSearch API error: {e}") - return [] + search_results = self._index.search(query) + return search_results["hits"] diff --git a/metagpt/utils/common.py b/metagpt/utils/common.py index 08df480ee..0060950dc 100644 --- a/metagpt/utils/common.py +++ b/metagpt/utils/common.py @@ -25,12 +25,14 @@ from pathlib import Path from typing import Any from typing import List, Tuple, Union +import aiofiles import loguru from pydantic.json import pydantic_encoder from tenacity import RetryCallState, _utils from metagpt.const import MESSAGE_ROUTE_TO_ALL from metagpt.logs import logger +from metagpt.utils.exceptions import handle_exception def check_cmd_exists(command) -> int: @@ -478,3 +480,11 @@ def role_raise_decorator(func): raise Exception(format_trackback_info(limit=None)) # raise again to make it captured outside return wrapper + + +@handle_exception +async def aread(file_path: str) -> str: + """Read file asynchronously.""" + async with aiofiles.open(str(file_path), mode="r") as reader: + content = await reader.read() + return content diff --git a/metagpt/utils/custom_decoder.py b/metagpt/utils/custom_decoder.py index 373d16356..eb01a1115 100644 --- a/metagpt/utils/custom_decoder.py +++ b/metagpt/utils/custom_decoder.py @@ -25,7 +25,7 @@ def py_make_scanner(context): except IndexError: raise StopIteration(idx) from None - if nextchar == '"' or nextchar == "'": + if nextchar in ("'", '"'): if idx + 2 < len(string) and string[idx + 1] == nextchar and string[idx + 2] == nextchar: # Handle the case where the next two characters are the same as nextchar return parse_string(string, idx + 3, strict, delimiter=nextchar * 3) # triple quote diff --git a/metagpt/utils/dependency_file.py b/metagpt/utils/dependency_file.py index e8347d567..d03444f0e 100644 --- a/metagpt/utils/dependency_file.py +++ b/metagpt/utils/dependency_file.py @@ -15,7 +15,8 @@ from typing import Set import aiofiles from metagpt.config import CONFIG -from metagpt.logs import logger +from metagpt.utils.common import aread +from metagpt.utils.exceptions import handle_exception class DependencyFile: @@ -36,21 +37,14 @@ class DependencyFile: """Load dependencies from the file asynchronously.""" if not self._filename.exists(): return - try: - async with aiofiles.open(str(self._filename), mode="r") as reader: - data = await reader.read() - self._dependencies = json.loads(data) - except Exception as e: - logger.error(f"Failed to load {str(self._filename)}, error:{e}") + self._dependencies = await aread(self._filename) + @handle_exception async def save(self): """Save dependencies to the file asynchronously.""" - try: - data = json.dumps(self._dependencies) - async with aiofiles.open(str(self._filename), mode="w") as writer: - await writer.write(data) - except Exception as e: - logger.error(f"Failed to save {str(self._filename)}, error:{e}") + data = json.dumps(self._dependencies) + async with aiofiles.open(str(self._filename), mode="w") as writer: + await writer.write(data) async def update(self, filename: Path | str, dependencies: Set[Path | str], persist=True): """Update dependencies for a file asynchronously. diff --git a/metagpt/utils/exceptions.py b/metagpt/utils/exceptions.py new file mode 100644 index 000000000..b4b5aa590 --- /dev/null +++ b/metagpt/utils/exceptions.py @@ -0,0 +1,59 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Time : 2023/12/19 14:46 +@Author : alexanderwu +@File : exceptions.py +""" + + +import asyncio +import functools +import traceback +from typing import Any, Callable, Tuple, Type, TypeVar, Union + +from metagpt.logs import logger + +ReturnType = TypeVar("ReturnType") + + +def handle_exception( + _func: Callable[..., ReturnType] = None, + *, + exception_type: Union[Type[Exception], Tuple[Type[Exception], ...]] = Exception, + default_return: Any = None, +) -> Callable[..., ReturnType]: + """handle exception, return default value""" + + def decorator(func: Callable[..., ReturnType]) -> Callable[..., ReturnType]: + @functools.wraps(func) + async def async_wrapper(*args: Any, **kwargs: Any) -> ReturnType: + try: + return await func(*args, **kwargs) + except exception_type as e: + logger.opt(depth=1).error( + f"Calling {func.__name__} with args: {args}, kwargs: {kwargs} failed: {e}, " + f"stack: {traceback.format_exc()}" + ) + return default_return + + @functools.wraps(func) + def sync_wrapper(*args: Any, **kwargs: Any) -> ReturnType: + try: + return func(*args, **kwargs) + except exception_type as e: + logger.opt(depth=1).error( + f"Calling {func.__name__} with args: {args}, kwargs: {kwargs} failed: {e}, " + f"stack: {traceback.format_exc()}" + ) + return default_return + + if asyncio.iscoroutinefunction(func): + return async_wrapper + else: + return sync_wrapper + + if _func is None: + return decorator + else: + return decorator(_func) diff --git a/metagpt/utils/file.py b/metagpt/utils/file.py index 6bb9a1a97..f62b44eb8 100644 --- a/metagpt/utils/file.py +++ b/metagpt/utils/file.py @@ -11,6 +11,7 @@ from pathlib import Path import aiofiles from metagpt.logs import logger +from metagpt.utils.exceptions import handle_exception class File: @@ -19,6 +20,7 @@ class File: CHUNK_SIZE = 64 * 1024 @classmethod + @handle_exception async def write(cls, root_path: Path, filename: str, content: bytes) -> Path: """Write the file content to the local specified path. @@ -33,18 +35,15 @@ class File: Raises: Exception: If an unexpected error occurs during the file writing process. """ - try: - root_path.mkdir(parents=True, exist_ok=True) - full_path = root_path / filename - async with aiofiles.open(full_path, mode="wb") as writer: - await writer.write(content) - logger.debug(f"Successfully write file: {full_path}") - return full_path - except Exception as e: - logger.error(f"Error writing file: {e}") - raise e + root_path.mkdir(parents=True, exist_ok=True) + full_path = root_path / filename + async with aiofiles.open(full_path, mode="wb") as writer: + await writer.write(content) + logger.debug(f"Successfully write file: {full_path}") + return full_path @classmethod + @handle_exception async def read(cls, file_path: Path, chunk_size: int = None) -> bytes: """Partitioning read the file content from the local specified path. @@ -58,18 +57,14 @@ class File: Raises: Exception: If an unexpected error occurs during the file reading process. """ - try: - chunk_size = chunk_size or cls.CHUNK_SIZE - async with aiofiles.open(file_path, mode="rb") as reader: - chunks = list() - while True: - chunk = await reader.read(chunk_size) - if not chunk: - break - chunks.append(chunk) - content = b"".join(chunks) - logger.debug(f"Successfully read file, the path of file: {file_path}") - return content - except Exception as e: - logger.error(f"Error reading file: {e}") - raise e + chunk_size = chunk_size or cls.CHUNK_SIZE + async with aiofiles.open(file_path, mode="rb") as reader: + chunks = list() + while True: + chunk = await reader.read(chunk_size) + if not chunk: + break + chunks.append(chunk) + content = b"".join(chunks) + logger.debug(f"Successfully read file, the path of file: {file_path}") + return content diff --git a/metagpt/utils/file_repository.py b/metagpt/utils/file_repository.py index 2eca799a8..099556a6b 100644 --- a/metagpt/utils/file_repository.py +++ b/metagpt/utils/file_repository.py @@ -19,6 +19,7 @@ import aiofiles from metagpt.config import CONFIG from metagpt.logs import logger from metagpt.schema import Document +from metagpt.utils.common import aread from metagpt.utils.json_to_markdown import json_to_markdown @@ -97,15 +98,7 @@ class FileRepository: path_name = self.workdir / filename if not path_name.exists(): return None - try: - async with aiofiles.open(str(path_name), mode="r") as reader: - doc.content = await reader.read() - except FileNotFoundError as e: - logger.info(f"open {str(path_name)} failed:{e}") - return None - except Exception as e: - logger.info(f"open {str(path_name)} failed:{e}") - return None + doc.content = await aread(path_name) return doc async def get_all(self) -> List[Document]: From 437abd1754603a6037fce7f2d1c8cbaa46c56116 Mon Sep 17 00:00:00 2001 From: geekan Date: Tue, 19 Dec 2023 16:22:29 +0800 Subject: [PATCH 100/167] bug fix and proper log --- metagpt/config.py | 3 +-- metagpt/utils/dependency_file.py | 2 +- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/metagpt/config.py b/metagpt/config.py index d6e6d8b88..5f2be971a 100644 --- a/metagpt/config.py +++ b/metagpt/config.py @@ -76,7 +76,7 @@ class Config(metaclass=Singleton): self.openai_api_type = self._get("OPENAI_API_TYPE") self.openai_api_version = self._get("OPENAI_API_VERSION") self.openai_api_rpm = self._get("RPM", 3) - self.openai_api_model = self._get("OPENAI_API_MODEL", "gpt-4") + self.openai_api_model = self._get("OPENAI_API_MODEL", "gpt-4-1106-preview") self.max_tokens_rsp = self._get("MAX_TOKENS", 2048) self.deployment_name = self._get("DEPLOYMENT_NAME") self.deployment_id = self._get("DEPLOYMENT_ID") @@ -139,7 +139,6 @@ class Config(metaclass=Singleton): continue configs.update(yaml_data) OPTIONS.set(configs) - logger.info(f"Default OpenAI API Model: {self.openai_api_model}") @staticmethod def _get(*args, **kwargs): diff --git a/metagpt/utils/dependency_file.py b/metagpt/utils/dependency_file.py index d03444f0e..8a6575e9e 100644 --- a/metagpt/utils/dependency_file.py +++ b/metagpt/utils/dependency_file.py @@ -37,7 +37,7 @@ class DependencyFile: """Load dependencies from the file asynchronously.""" if not self._filename.exists(): return - self._dependencies = await aread(self._filename) + self._dependencies = json.loads(await aread(self._filename)) @handle_exception async def save(self): From 9ca0d57a91bea18a19cf9b80b8854d00b310b67a Mon Sep 17 00:00:00 2001 From: geekan Date: Tue, 19 Dec 2023 16:31:38 +0800 Subject: [PATCH 101/167] bug fix and proper log --- metagpt/schema.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/metagpt/schema.py b/metagpt/schema.py index c026ea1d9..991ceaae0 100644 --- a/metagpt/schema.py +++ b/metagpt/schema.py @@ -282,11 +282,11 @@ T = TypeVar("T", bound="BaseModel") class BaseContext(BaseModel): - @staticmethod + @classmethod @handle_exception - def loads(val: str, cls: Type[T]) -> Optional[T]: - m = json.loads(val) - return cls(**m) + def loads(cls: Type[T], val: str) -> Optional[T]: + i = json.loads(val) + return cls(**i) class CodingContext(BaseContext): From b43d8462deb4c35d997b8c2ae3d797a0cb1853f6 Mon Sep 17 00:00:00 2001 From: geekan Date: Tue, 19 Dec 2023 16:54:06 +0800 Subject: [PATCH 102/167] refine config --- config/config.yaml | 2 +- metagpt/config.py | 51 +++++++++++++++++++------------ metagpt/provider/anthropic_api.py | 4 +-- 3 files changed, 34 insertions(+), 23 deletions(-) diff --git a/config/config.yaml b/config/config.yaml index dc4c4ea5a..f547462ba 100644 --- a/config/config.yaml +++ b/config/config.yaml @@ -23,7 +23,7 @@ RPM: 10 #SPARK_URL : "ws://spark-api.xf-yun.com/v2.1/chat" #### if Anthropic -#Anthropic_API_KEY: "YOUR_API_KEY" +#ANTHROPIC_API_KEY: "YOUR_API_KEY" #### if AZURE, check https://github.com/openai/openai-cookbook/blob/main/examples/azure/chat.ipynb #### You can use ENGINE or DEPLOYMENT mode diff --git a/metagpt/config.py b/metagpt/config.py index 5f2be971a..386c4784e 100644 --- a/metagpt/config.py +++ b/metagpt/config.py @@ -47,30 +47,41 @@ class Config(metaclass=Singleton): def __init__(self, yaml_file=default_yaml_file): golbal_options = OPTIONS.get() self._init_with_config_files_and_env(yaml_file) - logger.debug("Config loading done.") self._update() golbal_options.update(OPTIONS.get()) + logger.debug("Config loading done.") + + @staticmethod + def _is_valid_llm_key(k) -> bool: + return k and k != "YOUR_API_KEY" + + def _check_llm_exists(self): + if not any( + [ + self._is_valid_llm_key(self.openai_api_key), + self._is_valid_llm_key(self.anthropic_api_key), + self._is_valid_llm_key(self.zhipuai_api_key), + self._is_valid_llm_key(self.fireworks_api_key), + self.open_llm_api_base, + ] + ): + raise NotConfiguredException( + "Set OPENAI_API_KEY or Anthropic_API_KEY or ZHIPUAI_API_KEY " + "or FIREWORKS_API_KEY or OPEN_LLM_API_BASE" + ) def _update(self): # logger.info("Config loading done.") self.global_proxy = self._get("GLOBAL_PROXY") + self.openai_api_key = self._get("OPENAI_API_KEY") - self.anthropic_api_key = self._get("Anthropic_API_KEY") + self.anthropic_api_key = self._get("ANTHROPIC_API_KEY") self.zhipuai_api_key = self._get("ZHIPUAI_API_KEY") self.open_llm_api_base = self._get("OPEN_LLM_API_BASE") self.open_llm_api_model = self._get("OPEN_LLM_API_MODEL") self.fireworks_api_key = self._get("FIREWORKS_API_KEY") - if ( - (not self.openai_api_key or "YOUR_API_KEY" == self.openai_api_key) - and (not self.anthropic_api_key or "YOUR_API_KEY" == self.anthropic_api_key) - and (not self.zhipuai_api_key or "YOUR_API_KEY" == self.zhipuai_api_key) - and (not self.open_llm_api_base) - and (not self.fireworks_api_key or "YOUR_API_KEY" == self.fireworks_api_key) - ): - raise NotConfiguredException( - "Set OPENAI_API_KEY or Anthropic_API_KEY or ZHIPUAI_API_KEY first " - "or FIREWORKS_API_KEY or OPEN_LLM_API_BASE" - ) + self._check_llm_exists() + self.openai_api_base = self._get("OPENAI_API_BASE") self.openai_proxy = self._get("OPENAI_PROXY") or self.global_proxy self.openai_api_type = self._get("OPENAI_API_TYPE") @@ -90,7 +101,7 @@ class Config(metaclass=Singleton): self.fireworks_api_base = self._get("FIREWORKS_API_BASE") self.fireworks_api_model = self._get("FIREWORKS_API_MODEL") - self.claude_api_key = self._get("Anthropic_API_KEY") + self.claude_api_key = self._get("ANTHROPIC_API_KEY") self.serpapi_api_key = self._get("SERPAPI_API_KEY") self.serper_api_key = self._get("SERPER_API_KEY") self.google_api_key = self._get("GOOGLE_API_KEY") @@ -142,8 +153,8 @@ class Config(metaclass=Singleton): @staticmethod def _get(*args, **kwargs): - m = OPTIONS.get() - return m.get(*args, **kwargs) + i = OPTIONS.get() + return i.get(*args, **kwargs) def get(self, key, *args, **kwargs): """Search for a value in config/key.yaml, config/config.yaml, and env; raise an error if not found""" @@ -156,8 +167,8 @@ class Config(metaclass=Singleton): OPTIONS.get()[name] = value def __getattr__(self, name: str) -> Any: - m = OPTIONS.get() - return m.get(name) + i = OPTIONS.get() + return i.get(name) def set_context(self, options: dict): """Update current config""" @@ -176,8 +187,8 @@ class Config(metaclass=Singleton): def new_environ(self): """Return a new os.environ object""" env = os.environ.copy() - m = self.options - env.update({k: v for k, v in m.items() if isinstance(v, str)}) + i = self.options + env.update({k: v for k, v in i.items() if isinstance(v, str)}) return env diff --git a/metagpt/provider/anthropic_api.py b/metagpt/provider/anthropic_api.py index 03802a716..f5b06c855 100644 --- a/metagpt/provider/anthropic_api.py +++ b/metagpt/provider/anthropic_api.py @@ -14,7 +14,7 @@ from metagpt.config import CONFIG class Claude2: def ask(self, prompt): - client = Anthropic(api_key=CONFIG.claude_api_key) + client = Anthropic(api_key=CONFIG.anthropic_api_key) res = client.completions.create( model="claude-2", @@ -24,7 +24,7 @@ class Claude2: return res.completion async def aask(self, prompt): - client = Anthropic(api_key=CONFIG.claude_api_key) + client = Anthropic(api_key=CONFIG.anthropic_api_key) res = client.completions.create( model="claude-2", From 67de3132483409c9ed3b85809bcb5cfc7276d347 Mon Sep 17 00:00:00 2001 From: geekan Date: Tue, 19 Dec 2023 17:06:07 +0800 Subject: [PATCH 103/167] refine code --- metagpt/config.py | 8 ++++++++ metagpt/repo_parser.py | 2 +- metagpt/startup.py | 10 +++------- 3 files changed, 12 insertions(+), 8 deletions(-) diff --git a/metagpt/config.py b/metagpt/config.py index 386c4784e..50ad6a3b2 100644 --- a/metagpt/config.py +++ b/metagpt/config.py @@ -131,6 +131,14 @@ class Config(metaclass=Singleton): self.workspace_path = Path(self._get("WORKSPACE_PATH", DEFAULT_WORKSPACE_ROOT)) self._ensure_workspace_exists() + def update_via_cli(self, project_path, project_name, inc, reqa_file, max_auto_summarize_code): + """update config via cli""" + self.project_path = project_path + self.project_name = project_name + self.inc = inc + self.reqa_file = reqa_file + self.max_auto_summarize_code = max_auto_summarize_code + def _ensure_workspace_exists(self): self.workspace_path.mkdir(parents=True, exist_ok=True) logger.debug(f"WORKSPACE_PATH set to {self.workspace_path}") diff --git a/metagpt/repo_parser.py b/metagpt/repo_parser.py index 9a1218ef1..3524a5bce 100644 --- a/metagpt/repo_parser.py +++ b/metagpt/repo_parser.py @@ -96,4 +96,4 @@ def error(): if __name__ == "__main__": - error() + main() diff --git a/metagpt/startup.py b/metagpt/startup.py index 17eb26665..6ae47213e 100644 --- a/metagpt/startup.py +++ b/metagpt/startup.py @@ -27,8 +27,8 @@ def startup( reqa_file: str = typer.Option(default="", help="Specify the source file name for rewriting the quality test code."), max_auto_summarize_code: int = typer.Option( default=-1, - help="The maximum number of times the 'SummarizeCode' action is automatically invoked, " - "with -1 indicating unlimited. This parameter is used for debugging the workflow.", + help="The maximum number of times the 'SummarizeCode' action is automatically invoked, with -1 indicating " + "unlimited. This parameter is used for debugging the workflow.", ), recover_path: str = typer.Option(default=None, help="recover the project from existing serialized storage") ): @@ -43,14 +43,10 @@ def startup( from metagpt.team import Team # Use in the PrepareDocuments action according to Section 2.2.3.5.1 of RFC 135. - CONFIG.project_path = project_path if project_path: inc = True project_name = project_name or Path(project_path).name - CONFIG.project_name = project_name - CONFIG.inc = inc - CONFIG.reqa_file = reqa_file - CONFIG.max_auto_summarize_code = max_auto_summarize_code + CONFIG.update_via_cli(project_path, project_name, inc, reqa_file, max_auto_summarize_code) if not recover_path: company = Team() From 1162f21b6ceef6e09c84b55927cc72c4930d03d1 Mon Sep 17 00:00:00 2001 From: geekan Date: Tue, 19 Dec 2023 17:11:02 +0800 Subject: [PATCH 104/167] refine code --- metagpt/config.py | 12 ++++++++++++ metagpt/startup.py | 5 ----- 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/metagpt/config.py b/metagpt/config.py index 50ad6a3b2..68b7a2a96 100644 --- a/metagpt/config.py +++ b/metagpt/config.py @@ -131,8 +131,20 @@ class Config(metaclass=Singleton): self.workspace_path = Path(self._get("WORKSPACE_PATH", DEFAULT_WORKSPACE_ROOT)) self._ensure_workspace_exists() + def _init_cli_paras(self): + self.project_path = None + self.project_name = None + self.inc = None + self.reqa_file = None + self.max_auto_summarize_code = None + def update_via_cli(self, project_path, project_name, inc, reqa_file, max_auto_summarize_code): """update config via cli""" + + # Use in the PrepareDocuments action according to Section 2.2.3.5.1 of RFC 135. + if project_path: + inc = True + project_name = project_name or Path(project_path).name self.project_path = project_path self.project_name = project_name self.inc = inc diff --git a/metagpt/startup.py b/metagpt/startup.py index 6ae47213e..a25b71cd0 100644 --- a/metagpt/startup.py +++ b/metagpt/startup.py @@ -1,7 +1,6 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- import asyncio -from pathlib import Path import typer @@ -42,10 +41,6 @@ def startup( ) from metagpt.team import Team - # Use in the PrepareDocuments action according to Section 2.2.3.5.1 of RFC 135. - if project_path: - inc = True - project_name = project_name or Path(project_path).name CONFIG.update_via_cli(project_path, project_name, inc, reqa_file, max_auto_summarize_code) if not recover_path: From bd12087be4dd16f4b460d3bd0b4a7b6fb41eaa9a Mon Sep 17 00:00:00 2001 From: geekan Date: Tue, 19 Dec 2023 17:14:50 +0800 Subject: [PATCH 105/167] fix comment --- metagpt/team.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/metagpt/team.py b/metagpt/team.py index 9aa89ee2b..1df3c4052 100644 --- a/metagpt/team.py +++ b/metagpt/team.py @@ -25,8 +25,8 @@ from metagpt.utils.common import NoMoneyException, read_json_file, write_json_fi class Team(BaseModel): """ - Team: Possesses one or more roles (agents), SOP (Standard Operating Procedures), and a platform for instant messaging, - dedicated to perform any multi-agent activity, such as collaboratively writing executable code. + Team: Possesses one or more roles (agents), SOP (Standard Operating Procedures), and a env for instant messaging, + dedicated to env any multi-agent activity, such as collaboratively writing executable code. """ env: Environment = Field(default_factory=Environment) From f32f9c82e54581241de42bcf21a5d2efcd12c9e1 Mon Sep 17 00:00:00 2001 From: geekan Date: Tue, 19 Dec 2023 17:55:34 +0800 Subject: [PATCH 106/167] add llm provider registry --- metagpt/config.py | 56 +++++++++++++---------- metagpt/llm.py | 21 +-------- metagpt/provider/fireworks_api.py | 4 +- metagpt/provider/llm_provider_registry.py | 34 ++++++++++++++ metagpt/provider/open_llm_api.py | 4 +- metagpt/provider/openai_api.py | 4 +- metagpt/provider/spark_api.py | 4 +- metagpt/provider/zhipuai_api.py | 4 +- metagpt/schema.py | 10 ++-- 9 files changed, 89 insertions(+), 52 deletions(-) create mode 100644 metagpt/provider/llm_provider_registry.py diff --git a/metagpt/config.py b/metagpt/config.py index 68b7a2a96..c8346ccdc 100644 --- a/metagpt/config.py +++ b/metagpt/config.py @@ -8,6 +8,7 @@ Provide configuration, singleton """ import os from copy import deepcopy +from enum import Enum from pathlib import Path from typing import Any @@ -31,6 +32,15 @@ class NotConfiguredException(Exception): super().__init__(self.message) +class LLMProviderEnum(Enum): + OPENAI = "openai" + ANTHROPIC = "anthropic" + SPARK = "spark" + ZHIPUAI = "zhipuai" + FIREWORKS = "fireworks" + OPEN_LLM = "open_llm" + + class Config(metaclass=Singleton): """ Regular usage method: @@ -46,30 +56,37 @@ class Config(metaclass=Singleton): def __init__(self, yaml_file=default_yaml_file): golbal_options = OPTIONS.get() + # cli paras + self.project_path = "" + self.project_name = "" + self.inc = False + self.reqa_file = "" + self.max_auto_summarize_code = 0 + self._init_with_config_files_and_env(yaml_file) self._update() golbal_options.update(OPTIONS.get()) logger.debug("Config loading done.") + def get_default_llm_provider_enum(self): + if self._is_valid_llm_key(self.openai_api_key): + llm = LLMProviderEnum.OPENAI + elif self._is_valid_llm_key(self.anthropic_api_key): + llm = LLMProviderEnum.ANTHROPIC + elif self._is_valid_llm_key(self.zhipuai_api_key): + llm = LLMProviderEnum.ZHIPUAI + elif self._is_valid_llm_key(self.fireworks_api_key): + llm = LLMProviderEnum.FIREWORKS + elif self.open_llm_api_base: + llm = LLMProviderEnum.OPEN_LLM + else: + raise NotConfiguredException("You should config a LLM configuration first") + return llm + @staticmethod def _is_valid_llm_key(k) -> bool: return k and k != "YOUR_API_KEY" - def _check_llm_exists(self): - if not any( - [ - self._is_valid_llm_key(self.openai_api_key), - self._is_valid_llm_key(self.anthropic_api_key), - self._is_valid_llm_key(self.zhipuai_api_key), - self._is_valid_llm_key(self.fireworks_api_key), - self.open_llm_api_base, - ] - ): - raise NotConfiguredException( - "Set OPENAI_API_KEY or Anthropic_API_KEY or ZHIPUAI_API_KEY " - "or FIREWORKS_API_KEY or OPEN_LLM_API_BASE" - ) - def _update(self): # logger.info("Config loading done.") self.global_proxy = self._get("GLOBAL_PROXY") @@ -80,7 +97,7 @@ class Config(metaclass=Singleton): self.open_llm_api_base = self._get("OPEN_LLM_API_BASE") self.open_llm_api_model = self._get("OPEN_LLM_API_MODEL") self.fireworks_api_key = self._get("FIREWORKS_API_KEY") - self._check_llm_exists() + _ = self.get_default_llm_provider_enum() self.openai_api_base = self._get("OPENAI_API_BASE") self.openai_proxy = self._get("OPENAI_PROXY") or self.global_proxy @@ -131,13 +148,6 @@ class Config(metaclass=Singleton): self.workspace_path = Path(self._get("WORKSPACE_PATH", DEFAULT_WORKSPACE_ROOT)) self._ensure_workspace_exists() - def _init_cli_paras(self): - self.project_path = None - self.project_name = None - self.inc = None - self.reqa_file = None - self.max_auto_summarize_code = None - def update_via_cli(self, project_path, project_name, inc, reqa_file, max_auto_summarize_code): """update config via cli""" diff --git a/metagpt/llm.py b/metagpt/llm.py index 7c0ad7975..e0c0716de 100644 --- a/metagpt/llm.py +++ b/metagpt/llm.py @@ -8,12 +8,8 @@ from metagpt.config import CONFIG from metagpt.provider.base_gpt_api import BaseGPTAPI -from metagpt.provider.fireworks_api import FireWorksGPTAPI from metagpt.provider.human_provider import HumanProvider -from metagpt.provider.open_llm_api import OpenLLMGPTAPI -from metagpt.provider.openai_api import OpenAIGPTAPI -from metagpt.provider.spark_api import SparkAPI -from metagpt.provider.zhipuai_api import ZhiPuAIGPTAPI +from metagpt.provider.llm_provider_registry import LLMProviderRegistry _ = HumanProvider() # Avoid pre-commit error @@ -21,17 +17,4 @@ _ = HumanProvider() # Avoid pre-commit error def LLM() -> BaseGPTAPI: """initialize different LLM instance according to the key field existence""" # TODO a little trick, can use registry to initialize LLM instance further - if CONFIG.openai_api_key: - llm = OpenAIGPTAPI() - elif CONFIG.spark_api_key: - llm = SparkAPI() - elif CONFIG.zhipuai_api_key: - llm = ZhiPuAIGPTAPI() - elif CONFIG.open_llm_api_base: - llm = OpenLLMGPTAPI() - elif CONFIG.fireworks_api_key: - llm = FireWorksGPTAPI() - else: - raise RuntimeError("You should config a LLM configuration first") - - return llm + return LLMProviderRegistry.get_provider(CONFIG.get_default_llm_provider_enum()) diff --git a/metagpt/provider/fireworks_api.py b/metagpt/provider/fireworks_api.py index 47ac9cf61..a76151666 100644 --- a/metagpt/provider/fireworks_api.py +++ b/metagpt/provider/fireworks_api.py @@ -4,10 +4,12 @@ import openai -from metagpt.config import CONFIG +from metagpt.config import CONFIG, LLMProviderEnum +from metagpt.provider.llm_provider_registry import register_provider from metagpt.provider.openai_api import CostManager, OpenAIGPTAPI, RateLimiter +@register_provider(LLMProviderEnum.FIREWORKS) class FireWorksGPTAPI(OpenAIGPTAPI): def __init__(self): self.__init_fireworks(CONFIG) diff --git a/metagpt/provider/llm_provider_registry.py b/metagpt/provider/llm_provider_registry.py new file mode 100644 index 000000000..2b3ef93a3 --- /dev/null +++ b/metagpt/provider/llm_provider_registry.py @@ -0,0 +1,34 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Time : 2023/12/19 17:26 +@Author : alexanderwu +@File : llm_provider_registry.py +""" +from metagpt.config import LLMProviderEnum + + +class LLMProviderRegistry: + def __init__(self): + self.providers = {} + + def register(self, key, provider_cls): + self.providers[key] = provider_cls + + def get_provider(self, enum: LLMProviderEnum): + """get provider instance according to the enum""" + return self.providers[enum]() + + +# Registry instance +LLM_REGISTRY = LLMProviderRegistry() + + +def register_provider(key): + """register provider to registry""" + + def decorator(cls): + LLM_REGISTRY.register(key, cls) + return cls + + return decorator diff --git a/metagpt/provider/open_llm_api.py b/metagpt/provider/open_llm_api.py index f421e30c8..bada0e294 100644 --- a/metagpt/provider/open_llm_api.py +++ b/metagpt/provider/open_llm_api.py @@ -4,8 +4,9 @@ import openai -from metagpt.config import CONFIG +from metagpt.config import CONFIG, LLMProviderEnum from metagpt.logs import logger +from metagpt.provider.llm_provider_registry import register_provider from metagpt.provider.openai_api import CostManager, OpenAIGPTAPI, RateLimiter @@ -31,6 +32,7 @@ class OpenLLMCostManager(CostManager): CONFIG.total_cost = self.total_cost +@register_provider(LLMProviderEnum.OPEN_LLM) class OpenLLMGPTAPI(OpenAIGPTAPI): def __init__(self): self.__init_openllm(CONFIG) diff --git a/metagpt/provider/openai_api.py b/metagpt/provider/openai_api.py index 86054881e..0be70b3ca 100644 --- a/metagpt/provider/openai_api.py +++ b/metagpt/provider/openai_api.py @@ -18,10 +18,11 @@ from tenacity import ( wait_random_exponential, ) -from metagpt.config import CONFIG +from metagpt.config import CONFIG, LLMProviderEnum from metagpt.logs import logger from metagpt.provider.base_gpt_api import BaseGPTAPI from metagpt.provider.constant import GENERAL_FUNCTION_SCHEMA, GENERAL_TOOL_CHOICE +from metagpt.provider.llm_provider_registry import register_provider from metagpt.schema import Message from metagpt.utils.singleton import Singleton from metagpt.utils.token_counter import ( @@ -137,6 +138,7 @@ See FAQ 5.8 raise retry_state.outcome.exception() +@register_provider(LLMProviderEnum.OPENAI) class OpenAIGPTAPI(BaseGPTAPI, RateLimiter): """ Check https://platform.openai.com/examples for examples diff --git a/metagpt/provider/spark_api.py b/metagpt/provider/spark_api.py index 60c86f4dc..484fa7956 100644 --- a/metagpt/provider/spark_api.py +++ b/metagpt/provider/spark_api.py @@ -19,11 +19,13 @@ from wsgiref.handlers import format_date_time import websocket # 使用websocket_client -from metagpt.config import CONFIG +from metagpt.config import CONFIG, LLMProviderEnum from metagpt.logs import logger from metagpt.provider.base_gpt_api import BaseGPTAPI +from metagpt.provider.llm_provider_registry import register_provider +@register_provider(LLMProviderEnum.SPARK) class SparkAPI(BaseGPTAPI): def __init__(self): logger.warning("当前方法无法支持异步运行。当你使用acompletion时,并不能并行访问。") diff --git a/metagpt/provider/zhipuai_api.py b/metagpt/provider/zhipuai_api.py index 92119b764..eef0e51e1 100644 --- a/metagpt/provider/zhipuai_api.py +++ b/metagpt/provider/zhipuai_api.py @@ -16,9 +16,10 @@ from tenacity import ( wait_random_exponential, ) -from metagpt.config import CONFIG +from metagpt.config import CONFIG, LLMProviderEnum from metagpt.logs import logger from metagpt.provider.base_gpt_api import BaseGPTAPI +from metagpt.provider.llm_provider_registry import register_provider from metagpt.provider.openai_api import CostManager, log_and_reraise from metagpt.provider.zhipuai.zhipu_model_api import ZhiPuModelAPI @@ -30,6 +31,7 @@ class ZhiPuEvent(Enum): FINISH = "finish" +@register_provider(LLMProviderEnum.ZHIPUAI) class ZhiPuAIGPTAPI(BaseGPTAPI): """ Refs to `https://open.bigmodel.cn/dev/api#chatglm_turbo` diff --git a/metagpt/schema.py b/metagpt/schema.py index 991ceaae0..59203c404 100644 --- a/metagpt/schema.py +++ b/metagpt/schema.py @@ -167,8 +167,8 @@ class Message(BaseModel): @handle_exception(exception_type=JSONDecodeError, default_return=None) def load(val): """Convert the json string to object.""" - d = json.loads(val) - return Message(**d) + i = json.loads(val) + return Message(**i) class UserMessage(Message): @@ -263,16 +263,16 @@ class MessageQueue(BaseModel): return json.dumps(lst) @staticmethod - def load(i) -> "MessageQueue": + def load(data) -> "MessageQueue": """Convert the json string to the `MessageQueue` object.""" queue = MessageQueue() try: - lst = json.loads(i) + lst = json.loads(data) for i in lst: msg = Message(**i) queue.push(msg) except JSONDecodeError as e: - logger.warning(f"JSON load failed: {i}, error:{e}") + logger.warning(f"JSON load failed: {data}, error:{e}") return queue From af59323a692c90a2cfe58eede0dd3189ac6568b7 Mon Sep 17 00:00:00 2001 From: geekan Date: Tue, 19 Dec 2023 18:02:51 +0800 Subject: [PATCH 107/167] make registry work --- metagpt/llm.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/metagpt/llm.py b/metagpt/llm.py index e0c0716de..60f110a00 100644 --- a/metagpt/llm.py +++ b/metagpt/llm.py @@ -9,7 +9,7 @@ from metagpt.config import CONFIG from metagpt.provider.base_gpt_api import BaseGPTAPI from metagpt.provider.human_provider import HumanProvider -from metagpt.provider.llm_provider_registry import LLMProviderRegistry +from metagpt.provider.llm_provider_registry import LLM_REGISTRY _ = HumanProvider() # Avoid pre-commit error @@ -17,4 +17,4 @@ _ = HumanProvider() # Avoid pre-commit error def LLM() -> BaseGPTAPI: """initialize different LLM instance according to the key field existence""" # TODO a little trick, can use registry to initialize LLM instance further - return LLMProviderRegistry.get_provider(CONFIG.get_default_llm_provider_enum()) + return LLM_REGISTRY.get_provider(CONFIG.get_default_llm_provider_enum()) From fc829edc45571f9ca3b5d3212a4f49e46d77a4eb Mon Sep 17 00:00:00 2001 From: geekan Date: Tue, 19 Dec 2023 18:04:12 +0800 Subject: [PATCH 108/167] make registry work --- metagpt/llm.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/metagpt/llm.py b/metagpt/llm.py index 60f110a00..8763642f0 100644 --- a/metagpt/llm.py +++ b/metagpt/llm.py @@ -6,7 +6,7 @@ @File : llm.py """ -from metagpt.config import CONFIG +from metagpt.config import CONFIG, LLMProviderEnum from metagpt.provider.base_gpt_api import BaseGPTAPI from metagpt.provider.human_provider import HumanProvider from metagpt.provider.llm_provider_registry import LLM_REGISTRY @@ -14,7 +14,6 @@ from metagpt.provider.llm_provider_registry import LLM_REGISTRY _ = HumanProvider() # Avoid pre-commit error -def LLM() -> BaseGPTAPI: - """initialize different LLM instance according to the key field existence""" - # TODO a little trick, can use registry to initialize LLM instance further - return LLM_REGISTRY.get_provider(CONFIG.get_default_llm_provider_enum()) +def LLM(provider: LLMProviderEnum = CONFIG.get_default_llm_provider_enum()) -> BaseGPTAPI: + """get the default llm provider""" + return LLM_REGISTRY.get_provider(provider) From 06d8dccc16cd6b1694b97960708d1e73c130b7c7 Mon Sep 17 00:00:00 2001 From: geekan Date: Tue, 19 Dec 2023 18:50:55 +0800 Subject: [PATCH 109/167] refine code for isinstance --- metagpt/actions/write_prd.py | 2 +- metagpt/roles/role.py | 2 +- metagpt/roles/searcher.py | 2 +- metagpt/utils/common.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/metagpt/actions/write_prd.py b/metagpt/actions/write_prd.py index f087d8650..0febb2656 100644 --- a/metagpt/actions/write_prd.py +++ b/metagpt/actions/write_prd.py @@ -185,7 +185,7 @@ class WritePRD(Action): return if not CONFIG.project_name: - if isinstance(prd, ActionOutput) or isinstance(prd, ActionNode): + if isinstance(prd, (ActionOutput, ActionNode)): ws_name = prd.instruct_content.dict()["Project Name"] else: ws_name = CodeParser.parse_str(block="Project Name", text=prd) diff --git a/metagpt/roles/role.py b/metagpt/roles/role.py index 3a8721004..fa09999e5 100644 --- a/metagpt/roles/role.py +++ b/metagpt/roles/role.py @@ -370,7 +370,7 @@ class Role(BaseModel): async def _act(self) -> Message: logger.info(f"{self._setting}: ready to {self._rc.todo}") response = await self._rc.todo.run(self._rc.important_memory) - if isinstance(response, ActionOutput) or isinstance(response, ActionNode): + if isinstance(response, (ActionOutput, ActionNode)): msg = Message( content=response.content, instruct_content=response.instruct_content, diff --git a/metagpt/roles/searcher.py b/metagpt/roles/searcher.py index 7d58ad922..a5c399f47 100644 --- a/metagpt/roles/searcher.py +++ b/metagpt/roles/searcher.py @@ -60,7 +60,7 @@ class Searcher(Role): logger.info(f"{self._setting}: ready to {self._rc.todo}") response = await self._rc.todo.run(self._rc.memory.get(k=0)) - if isinstance(response, ActionOutput) or isinstance(response, ActionNode): + if isinstance(response, (ActionOutput, ActionNode)): msg = Message( content=response.content, instruct_content=response.instruct_content, diff --git a/metagpt/utils/common.py b/metagpt/utils/common.py index 0060950dc..a445c9f31 100644 --- a/metagpt/utils/common.py +++ b/metagpt/utils/common.py @@ -203,7 +203,7 @@ class OutputParser: result = ast.literal_eval(structure_text) # Ensure the result matches the specified data type - if isinstance(result, list) or isinstance(result, dict): + if isinstance(result, (list, dict)): return result raise ValueError(f"The extracted structure is not a {data_type}.") From da91fb18c0d135cac509df35d8f53098a3c1f00d Mon Sep 17 00:00:00 2001 From: geekan Date: Tue, 19 Dec 2023 18:54:04 +0800 Subject: [PATCH 110/167] fix typo --- metagpt/config.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/metagpt/config.py b/metagpt/config.py index c8346ccdc..8ed957808 100644 --- a/metagpt/config.py +++ b/metagpt/config.py @@ -55,7 +55,7 @@ class Config(metaclass=Singleton): default_yaml_file = METAGPT_ROOT / "config/config.yaml" def __init__(self, yaml_file=default_yaml_file): - golbal_options = OPTIONS.get() + global_options = OPTIONS.get() # cli paras self.project_path = "" self.project_name = "" @@ -65,7 +65,7 @@ class Config(metaclass=Singleton): self._init_with_config_files_and_env(yaml_file) self._update() - golbal_options.update(OPTIONS.get()) + global_options.update(OPTIONS.get()) logger.debug("Config loading done.") def get_default_llm_provider_enum(self): From acb968663f6ba10e4a621e53ab6d2255163a6519 Mon Sep 17 00:00:00 2001 From: geekan Date: Tue, 19 Dec 2023 19:00:20 +0800 Subject: [PATCH 111/167] refine cli --- metagpt/startup.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/metagpt/startup.py b/metagpt/startup.py index a25b71cd0..9c17edc1c 100644 --- a/metagpt/startup.py +++ b/metagpt/startup.py @@ -6,7 +6,7 @@ import typer from metagpt.config import CONFIG -app = typer.Typer() +app = typer.Typer(add_completion=False) @app.command() @@ -23,7 +23,9 @@ def startup( default="", help="Specify the directory path of the old version project to fulfill the " "incremental requirements.", ), - reqa_file: str = typer.Option(default="", help="Specify the source file name for rewriting the quality test code."), + reqa_file: str = typer.Option( + default="", help="Specify the source file name for rewriting the quality assurance " "code." + ), max_auto_summarize_code: int = typer.Option( default=-1, help="The maximum number of times the 'SummarizeCode' action is automatically invoked, with -1 indicating " From b5b1ef7ead978303e27364a4d52cf090322a9743 Mon Sep 17 00:00:00 2001 From: geekan Date: Tue, 19 Dec 2023 19:00:39 +0800 Subject: [PATCH 112/167] refine cli --- metagpt/startup.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/metagpt/startup.py b/metagpt/startup.py index 9c17edc1c..b66f9e305 100644 --- a/metagpt/startup.py +++ b/metagpt/startup.py @@ -21,10 +21,10 @@ def startup( inc: bool = typer.Option(default=False, help="Incremental mode. Use it to coop with existing repo."), project_path: str = typer.Option( default="", - help="Specify the directory path of the old version project to fulfill the " "incremental requirements.", + help="Specify the directory path of the old version project to fulfill the incremental requirements.", ), reqa_file: str = typer.Option( - default="", help="Specify the source file name for rewriting the quality assurance " "code." + default="", help="Specify the source file name for rewriting the quality assurance code." ), max_auto_summarize_code: int = typer.Option( default=-1, From 79bb44b0b7978c590143f1a8b1775747d5490a66 Mon Sep 17 00:00:00 2001 From: geekan Date: Tue, 19 Dec 2023 19:15:30 +0800 Subject: [PATCH 113/167] fix pylint --- examples/agent_creator.py | 9 ++++----- metagpt/memory/longterm_memory.py | 8 ++++---- metagpt/memory/memory_storage.py | 2 +- metagpt/roles/product_manager.py | 2 +- metagpt/roles/qa_engineer.py | 2 +- 5 files changed, 11 insertions(+), 12 deletions(-) diff --git a/examples/agent_creator.py b/examples/agent_creator.py index 05417d24a..26af8a287 100644 --- a/examples/agent_creator.py +++ b/examples/agent_creator.py @@ -12,9 +12,8 @@ from metagpt.logs import logger from metagpt.roles import Role from metagpt.schema import Message -with open(METAGPT_ROOT / "examples/build_customized_agent.py", "r") as f: - # use official example script to guide AgentCreator - MULTI_ACTION_AGENT_CODE_EXAMPLE = f.read() +EXAMPLE_CODE_FILE = METAGPT_ROOT / "examples/build_customized_agent.py" +MULTI_ACTION_AGENT_CODE_EXAMPLE = EXAMPLE_CODE_FILE.read_text() class CreateAgent(Action): @@ -50,8 +49,8 @@ class CreateAgent(Action): match = re.search(pattern, rsp, re.DOTALL) code_text = match.group(1) if match else "" CONFIG.workspace_path.mkdir(parents=True, exist_ok=True) - with open(CONFIG.workspace_path / "agent_created_agent.py", "w") as f: - f.write(code_text) + new_file = CONFIG.workspace_path / "agent_created_agent.py" + new_file.write_text(code_text) return code_text diff --git a/metagpt/memory/longterm_memory.py b/metagpt/memory/longterm_memory.py index d36188f0c..069740054 100644 --- a/metagpt/memory/longterm_memory.py +++ b/metagpt/memory/longterm_memory.py @@ -44,7 +44,7 @@ class LongTermMemory(Memory): self.msg_from_recover = False def add(self, message: Message): - super(LongTermMemory, self).add(message) + super().add(message) for action in self.rc.watch: if message.cause_by == action and not self.msg_from_recover: # currently, only add role's watching messages to its memory_storage @@ -57,7 +57,7 @@ class LongTermMemory(Memory): 1. find the short-term memory(stm) news 2. furthermore, filter out similar messages based on ltm(long-term memory), get the final news """ - stm_news = super(LongTermMemory, self).find_news(observed, k=k) # shot-term memory news + stm_news = super().find_news(observed, k=k) # shot-term memory news if not self.memory_storage.is_initialized: # memory_storage hasn't initialized, use default `find_news` to get stm_news return stm_news @@ -71,9 +71,9 @@ class LongTermMemory(Memory): return ltm_news[-k:] def delete(self, message: Message): - super(LongTermMemory, self).delete(message) + super().delete(message) # TODO delete message in memory_storage def clear(self): - super(LongTermMemory, self).clear() + super().clear() self.memory_storage.clean() diff --git a/metagpt/memory/memory_storage.py b/metagpt/memory/memory_storage.py index a213f6d7a..fafb33568 100644 --- a/metagpt/memory/memory_storage.py +++ b/metagpt/memory/memory_storage.py @@ -58,7 +58,7 @@ class MemoryStorage(FaissStore): return index_fpath, storage_fpath def persist(self): - super(MemoryStorage, self).persist() + super().persist() logger.debug(f"Agent {self.role_id} persist memory into local") def add(self, message: Message) -> bool: diff --git a/metagpt/roles/product_manager.py b/metagpt/roles/product_manager.py index 11bda2127..6dba21fe1 100644 --- a/metagpt/roles/product_manager.py +++ b/metagpt/roles/product_manager.py @@ -45,4 +45,4 @@ class ProductManager(Role): return self._rc.todo async def _observe(self, ignore_memory=False) -> int: - return await super(ProductManager, self)._observe(ignore_memory=True) + return await super()._observe(ignore_memory=True) diff --git a/metagpt/roles/qa_engineer.py b/metagpt/roles/qa_engineer.py index ec404570c..acb79ab80 100644 --- a/metagpt/roles/qa_engineer.py +++ b/metagpt/roles/qa_engineer.py @@ -186,4 +186,4 @@ class QaEngineer(Role): async def _observe(self, ignore_memory=False) -> int: # This role has events that trigger and execute themselves based on conditions, and cannot rely on the # content of memory to activate. - return await super(QaEngineer, self)._observe(ignore_memory=True) + return await super()._observe(ignore_memory=True) From 3920982786bdfee81443639f7f1c060da474ca24 Mon Sep 17 00:00:00 2001 From: geekan Date: Tue, 19 Dec 2023 19:25:01 +0800 Subject: [PATCH 114/167] refine code --- metagpt/config.py | 29 ++++++++++++++--------------- 1 file changed, 14 insertions(+), 15 deletions(-) diff --git a/metagpt/config.py b/metagpt/config.py index 8ed957808..80a3a28f4 100644 --- a/metagpt/config.py +++ b/metagpt/config.py @@ -68,23 +68,22 @@ class Config(metaclass=Singleton): global_options.update(OPTIONS.get()) logger.debug("Config loading done.") - def get_default_llm_provider_enum(self): - if self._is_valid_llm_key(self.openai_api_key): - llm = LLMProviderEnum.OPENAI - elif self._is_valid_llm_key(self.anthropic_api_key): - llm = LLMProviderEnum.ANTHROPIC - elif self._is_valid_llm_key(self.zhipuai_api_key): - llm = LLMProviderEnum.ZHIPUAI - elif self._is_valid_llm_key(self.fireworks_api_key): - llm = LLMProviderEnum.FIREWORKS - elif self.open_llm_api_base: - llm = LLMProviderEnum.OPEN_LLM - else: - raise NotConfiguredException("You should config a LLM configuration first") - return llm + def get_default_llm_provider_enum(self) -> LLMProviderEnum: + for k, v in [ + (self.openai_api_key, LLMProviderEnum.OPENAI), + (self.anthropic_api_key, LLMProviderEnum.ANTHROPIC), + (self.zhipuai_api_key, LLMProviderEnum.ZHIPUAI), + (self.fireworks_api_key, LLMProviderEnum.FIREWORKS), + (self.open_llm_api_base, LLMProviderEnum.OPEN_LLM), # reuse logic. but not a key + ]: + if self._is_valid_llm_key(k): + if self.openai_api_model: + logger.info(f"OpenAI API Model: {self.openai_api_model}") + return v + raise NotConfiguredException("You should config a LLM configuration first") @staticmethod - def _is_valid_llm_key(k) -> bool: + def _is_valid_llm_key(k: str) -> bool: return k and k != "YOUR_API_KEY" def _update(self): From 029eed1792d555b4b373264a2ba8f12d0b81c7aa Mon Sep 17 00:00:00 2001 From: geekan Date: Tue, 19 Dec 2023 19:26:01 +0800 Subject: [PATCH 115/167] delete manager.py --- metagpt/manager.py | 66 ---------------------------------------------- 1 file changed, 66 deletions(-) delete mode 100644 metagpt/manager.py diff --git a/metagpt/manager.py b/metagpt/manager.py deleted file mode 100644 index a063608be..000000000 --- a/metagpt/manager.py +++ /dev/null @@ -1,66 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- -""" -@Time : 2023/5/11 14:42 -@Author : alexanderwu -@File : manager.py -""" -from metagpt.llm import LLM -from metagpt.logs import logger -from metagpt.schema import Message - - -class Manager: - def __init__(self, llm: LLM = LLM()): - self.llm = llm # Large Language Model - self.role_directions = { - "User": "Product Manager", - "Product Manager": "Architect", - "Architect": "Engineer", - "Engineer": "QA Engineer", - "QA Engineer": "Product Manager", - } - self.prompt_template = """ - Given the following message: - {message} - - And the current status of roles: - {roles} - - Which role should handle this message? - """ - - async def handle(self, message: Message, environment): - """ - 管理员处理信息,现在简单的将信息递交给下一个人 - The administrator processes the information, now simply passes the information on to the next person - :param message: - :param environment: - :return: - """ - # Get all roles from the environment - roles = environment.get_roles() - # logger.debug(f"{roles=}, {message=}") - - # Build a context for the LLM to understand the situation - # context = { - # "message": str(message), - # "roles": {role.name: role.get_info() for role in roles}, - # } - # Ask the LLM to decide which role should handle the message - # chosen_role_name = self.llm.ask(self.prompt_template.format(context)) - - # FIXME: 现在通过简单的字典决定流向,但之后还是应该有思考过程 - # The direction of flow is now determined by a simple dictionary, but there should still be a thought process afterwards - next_role_profile = self.role_directions[message.role] - # logger.debug(f"{next_role_profile}") - for _, role in roles.items(): - if next_role_profile == role.profile: - next_role = role - break - else: - logger.error(f"No available role can handle message: {message}.") - return - - # Find the chosen role and handle the message - return await next_role.handle(message) From 25ea21321fbf5f1212289e77de46875037ecaa85 Mon Sep 17 00:00:00 2001 From: geekan Date: Tue, 19 Dec 2023 19:27:11 +0800 Subject: [PATCH 116/167] remove useless fields --- metagpt/actions/action.py | 9 +++------ metagpt/actions/search_and_summarize.py | 3 +-- metagpt/roles/role.py | 2 +- 3 files changed, 5 insertions(+), 9 deletions(-) diff --git a/metagpt/actions/action.py b/metagpt/actions/action.py index 9c7fb06e1..ba1bb48de 100644 --- a/metagpt/actions/action.py +++ b/metagpt/actions/action.py @@ -14,6 +14,7 @@ from pydantic import BaseModel, Field from tenacity import retry, stop_after_attempt, wait_random_exponential from metagpt.actions.action_output import ActionOutput +from metagpt.actions.action_node import ActionNode from metagpt.llm import LLM from metagpt.logs import logger from metagpt.provider.base_gpt_api import BaseGPTAPI @@ -29,11 +30,8 @@ class Action(BaseModel): llm: BaseGPTAPI = Field(default_factory=LLM, exclude=True) context = "" prefix = "" # aask*时会加上prefix,作为system_message - profile = "" # FIXME: USELESS desc = "" # for skill manager - nodes = [] - # content: Optional[str] = None - # instruct_content: Optional[str] = None + node: ActionNode = Field(default_factory=ActionNode) # builtin variables builtin_class_name: str = "" @@ -58,10 +56,9 @@ class Action(BaseModel): obj_dict.pop("llm") return obj_dict - def set_prefix(self, prefix, profile): + def set_prefix(self, prefix): """Set prefix for later usage""" self.prefix = prefix - self.profile = profile return self def __str__(self): diff --git a/metagpt/actions/search_and_summarize.py b/metagpt/actions/search_and_summarize.py index aa4d0f654..3f110c370 100644 --- a/metagpt/actions/search_and_summarize.py +++ b/metagpt/actions/search_and_summarize.py @@ -148,8 +148,7 @@ class SearchAndSummarize(Action): system_prompt = [system_text] prompt = SEARCH_AND_SUMMARIZE_PROMPT.format( - # PREFIX = self.prefix, - ROLE=self.profile, + ROLE=self.prefix, CONTEXT=rsp, QUERY_HISTORY="\n".join([str(i) for i in context[:-1]]), QUERY=str(context[-1]), diff --git a/metagpt/roles/role.py b/metagpt/roles/role.py index fa09999e5..e57f21ec3 100644 --- a/metagpt/roles/role.py +++ b/metagpt/roles/role.py @@ -238,7 +238,7 @@ class Role(BaseModel): return role def _init_action_system_message(self, action: Action): - action.set_prefix(self._get_prefix(), self.profile) + action.set_prefix(self._get_prefix()) def set_recovered(self, recovered: bool = False): self.recovered = recovered From a75ab7971fad845f1d07c0fe455cc1a398ec54b4 Mon Sep 17 00:00:00 2001 From: geekan Date: Tue, 19 Dec 2023 21:17:02 +0800 Subject: [PATCH 117/167] refine a lot of code, fix pylint, use actionnode include ui, action _aask_v1, detail_mining, prepare_interview, etc. --- metagpt/actions/action.py | 34 +----- metagpt/actions/action_node.py | 81 +++++--------- metagpt/actions/design_api.py | 10 +- metagpt/actions/detail_mining.py | 50 +++------ metagpt/actions/prepare_interview.py | 35 ++---- metagpt/actions/project_management.py | 10 +- metagpt/actions/write_prd.py | 8 +- metagpt/config.py | 2 +- metagpt/utils/get_template.py | 6 +- tests/metagpt/actions/test_detail_mining.py | 4 +- .../metagpt/actions/test_prepare_interview.py | 21 ++++ tests/metagpt/roles/ui_role.py | 104 +++++++++--------- 12 files changed, 150 insertions(+), 215 deletions(-) create mode 100644 tests/metagpt/actions/test_prepare_interview.py diff --git a/metagpt/actions/action.py b/metagpt/actions/action.py index ba1bb48de..1fcc8fc80 100644 --- a/metagpt/actions/action.py +++ b/metagpt/actions/action.py @@ -11,15 +11,9 @@ from __future__ import annotations from typing import Optional, Any from pydantic import BaseModel, Field -from tenacity import retry, stop_after_attempt, wait_random_exponential - -from metagpt.actions.action_output import ActionOutput from metagpt.actions.action_node import ActionNode from metagpt.llm import LLM -from metagpt.logs import logger from metagpt.provider.base_gpt_api import BaseGPTAPI -from metagpt.provider.postprecess.llm_output_postprecess import llm_output_postprecess -from metagpt.utils.common import OutputParser, general_after_log action_subclass_registry = {} @@ -31,7 +25,7 @@ class Action(BaseModel): context = "" prefix = "" # aask*时会加上prefix,作为system_message desc = "" # for skill manager - node: ActionNode = Field(default_factory=ActionNode) + node: ActionNode = Field(default_factory=ActionNode, exclude=True) # builtin variables builtin_class_name: str = "" @@ -74,32 +68,6 @@ class Action(BaseModel): system_msgs.append(self.prefix) return await self.llm.aask(prompt, system_msgs) - @retry( - wait=wait_random_exponential(min=1, max=60), - stop=stop_after_attempt(6), - after=general_after_log(logger), - ) - async def _aask_v1( - self, - prompt: str, - output_class_name: str, - output_data_mapping: dict, - system_msgs: Optional[list[str]] = None, - format="markdown", # compatible to original format - ) -> ActionOutput: - content = await self.llm.aask(prompt, system_msgs) - logger.debug(f"llm raw output:\n{content}") - output_class = ActionOutput.create_model_class(output_class_name, output_data_mapping) - - if format == "json": - parsed_data = llm_output_postprecess(output=content, schema=output_class.schema(), req_key="[/CONTENT]") - else: # using markdown parser - parsed_data = OutputParser.parse_data_with_mapping(content, output_data_mapping) - - logger.debug(parsed_data) - instruct_content = output_class(**parsed_data) - return ActionOutput(content, instruct_content) - async def run(self, *args, **kwargs): """Run action""" raise NotImplementedError("The run method should be implemented in a subclass.") diff --git a/metagpt/actions/action_node.py b/metagpt/actions/action_node.py index 6f1215920..0368d2df1 100644 --- a/metagpt/actions/action_node.py +++ b/metagpt/actions/action_node.py @@ -6,17 +6,15 @@ @File : action_node.py """ import json -import re -from typing import Any, Dict, List, Optional, Type +from typing import Dict, Generic, List, Optional, Type, TypeVar from pydantic import BaseModel, create_model, root_validator, validator from tenacity import retry, stop_after_attempt, wait_random_exponential -from metagpt.actions import ActionOutput from metagpt.llm import BaseGPTAPI from metagpt.logs import logger -from metagpt.utils.common import OutputParser -from metagpt.utils.custom_decoder import CustomDecoder +from metagpt.provider.postprecess.llm_output_postprecess import llm_output_postprecess +from metagpt.utils.common import OutputParser, general_after_log CONSTRAINT = """ - Language: Please use the same language as the user input. @@ -43,14 +41,17 @@ Fill in the above nodes based on the format example. """ -def dict_to_markdown(d, prefix="###", postfix="\n"): +def dict_to_markdown(d, prefix="-", postfix="\n"): markdown_str = "" for key, value in d.items(): markdown_str += f"{prefix} {key}: {value}{postfix}" return markdown_str -class ActionNode: +T = TypeVar("T") + + +class ActionNode(Generic[T]): """ActionNode is a tree of nodes.""" mode: str @@ -65,7 +66,7 @@ class ActionNode: expected_type: Type # such as str / int / float etc. # context: str # everything in the history. instruction: str # the instructions should be followed. - example: Any # example for In Context-Learning. + example: T # example for In Context-Learning. # Action Output content: str @@ -76,7 +77,7 @@ class ActionNode: key: str, expected_type: Type, instruction: str, - example: str, + example: T, content: str = "", children: dict[str, "ActionNode"] = None, ): @@ -148,29 +149,6 @@ class ActionNode: new_class.__root_validator_check_missing_fields = classmethod(check_missing_fields) return new_class - @classmethod - def create_model_class_v2(cls, class_name: str, mapping: Dict[str, Type]): - """基于pydantic v2的模型动态生成,用来检验结果类型正确性,待验证""" - new_class = create_model(class_name, **mapping) - - @model_validator(mode="before") - def check_missing_fields(data): - required_fields = set(mapping.keys()) - missing_fields = required_fields - set(data.keys()) - if missing_fields: - raise ValueError(f"Missing fields: {missing_fields}") - return data - - @field_validator("*") - def check_name(v: Any, field: str) -> Any: - if field not in mapping.keys(): - raise ValueError(f"Unrecognized block: {field}") - return v - - new_class.__model_validator_check_missing_fields = classmethod(check_missing_fields) - new_class.__field_validator_check_name = classmethod(check_name) - return new_class - def create_children_class(self): """使用object内有的字段直接生成model_class""" class_name = f"{self.key}_AN" @@ -245,6 +223,7 @@ class ActionNode: """ # FIXME: json instruction会带来格式问题,如:"Project name": "web_2048 # 项目名称使用下划线", + # compile example暂时不支持markdown self.instruction = self.compile_instruction(to="markdown", mode=mode) self.example = self.compile_example(to=to, tag="CONTENT", mode=mode) prompt = template.format( @@ -252,36 +231,32 @@ class ActionNode: ) return prompt - @retry(wait=wait_random_exponential(min=1, max=10), stop=stop_after_attempt(6)) + @retry( + wait=wait_random_exponential(min=1, max=60), + stop=stop_after_attempt(6), + after=general_after_log(logger), + ) async def _aask_v1( self, prompt: str, output_class_name: str, output_data_mapping: dict, system_msgs: Optional[list[str]] = None, - format="markdown", # compatible to original format - ) -> ActionOutput: + schema="markdown", # compatible to original format + ) -> (str, BaseModel): + """Use ActionOutput to wrap the output of aask""" content = await self.llm.aask(prompt, system_msgs) - logger.debug(content) - output_class = ActionOutput.create_model_class(output_class_name, output_data_mapping) - - if format == "json": - pattern = r"\[CONTENT\](\s*\{.*?\}\s*)\[/CONTENT\]" - matches = re.findall(pattern, content, re.DOTALL) - - for match in matches: - if match: - content = match - break - - parsed_data = CustomDecoder(strict=False).decode(content) + logger.debug(f"llm raw output:\n{content}") + output_class = self.create_model_class(output_class_name, output_data_mapping) + if schema == "json": + parsed_data = llm_output_postprecess(output=content, schema=output_class.schema(), req_key="[/CONTENT]") else: # using markdown parser parsed_data = OutputParser.parse_data_with_mapping(content, output_data_mapping) - logger.debug(parsed_data) + logger.debug(f"parsed_data:\n{parsed_data}") instruct_content = output_class(**parsed_data) - return ActionOutput(content, instruct_content) + return content, instruct_content def get(self, key): return self.instruct_content.dict()[key] @@ -302,9 +277,9 @@ class ActionNode: mapping = self.get_mapping(mode) class_name = f"{self.key}_AN" - output = await self._aask_v1(prompt, class_name, mapping, format=to) - self.content = output.content - self.instruct_content = output.instruct_content + content, scontent = await self._aask_v1(prompt, class_name, mapping, schema=to) + self.content = content + self.instruct_content = scontent return self async def fill(self, context, llm, to="json", mode="auto", strgy="simple"): diff --git a/metagpt/actions/design_api.py b/metagpt/actions/design_api.py index c1778d53f..49c5a019d 100644 --- a/metagpt/actions/design_api.py +++ b/metagpt/actions/design_api.py @@ -49,7 +49,7 @@ class WriteDesign(Action): "data structures, library tables, processes, and paths. Please provide your design, feedback " \ "clearly and in detail." - async def run(self, with_messages: Message, format: str = CONFIG.prompt_format): + async def run(self, with_messages: Message, schema: str = CONFIG.prompt_schema): # Use `git diff` to identify which PRD documents have been modified in the `docs/prds` directory. prds_file_repo = CONFIG.git_repo.new_file_repository(PRDS_FILE_REPO) changed_prds = prds_file_repo.changed_files @@ -79,13 +79,13 @@ class WriteDesign(Action): # leaving room for global optimization in subsequent steps. return ActionOutput(content=changed_files.json(), instruct_content=changed_files) - async def _new_system_design(self, context, format=CONFIG.prompt_format): - node = await DESIGN_API_NODE.fill(context=context, llm=self.llm, to=format) + async def _new_system_design(self, context, schema=CONFIG.prompt_schema): + node = await DESIGN_API_NODE.fill(context=context, llm=self.llm, to=schema) return node - async def _merge(self, prd_doc, system_design_doc, format=CONFIG.prompt_format): + async def _merge(self, prd_doc, system_design_doc, schema=CONFIG.prompt_schema): context = NEW_REQ_TEMPLATE.format(old_design=system_design_doc.content, context=prd_doc.content) - node = await DESIGN_API_NODE.fill(context=context, llm=self.llm, to=format) + node = await DESIGN_API_NODE.fill(context=context, llm=self.llm, to=schema) system_design_doc.content = node.instruct_content.json(ensure_ascii=False) return system_design_doc diff --git a/metagpt/actions/detail_mining.py b/metagpt/actions/detail_mining.py index 5afcf52c6..0314d30dd 100644 --- a/metagpt/actions/detail_mining.py +++ b/metagpt/actions/detail_mining.py @@ -5,47 +5,31 @@ @Author : fisherdeng @File : detail_mining.py """ -from metagpt.actions import Action, ActionOutput +from metagpt.actions import Action +from metagpt.actions.action_node import ActionNode -PROMPT_TEMPLATE = """ -##TOPIC +CONTEXT_TEMPLATE = """ +## TOPIC {topic} -##RECORD +## RECORD {record} - -##Format example -{format_example} ------ - -Task: Refer to the "##TOPIC" (discussion objectives) and "##RECORD" (discussion records) to further inquire about the details that interest you, within a word limit of 150 words. -Special Note 1: Your intention is solely to ask questions without endorsing or negating any individual's viewpoints. -Special Note 2: This output should only include the topic "##OUTPUT". Do not add, remove, or modify the topic. Begin the output with '##OUTPUT', followed by an immediate line break, and then proceed to provide the content in the specified format as outlined in the "##Format example" section. -Special Note 3: The output should be in the same language as the input. """ -FORMAT_EXAMPLE = """ -## - -##OUTPUT -...(Please provide the specific details you would like to inquire about here.) - -## - -## -""" -OUTPUT_MAPPING = { - "OUTPUT": (str, ...), -} +QUESTIONS = ActionNode( + key="Questions", + expected_type=list[str], + instruction="Task: Refer to the context to further inquire about the details that interest you, within a word limit" + " of 150 words. Please provide the specific details you would like to inquire about here", + example=["1. What ...", "2. How ...", "3. ..."], +) class DetailMining(Action): - """This class allows LLM to further mine noteworthy details based on specific "##TOPIC"(discussion topic) and "##RECORD" (discussion records), thereby deepening the discussion.""" + """This class allows LLM to further mine noteworthy details based on specific "##TOPIC"(discussion topic) and + "##RECORD" (discussion records), thereby deepening the discussion.""" - def __init__(self, name="", context=None, llm=None): - super().__init__(name, context, llm) - - async def run(self, topic, record) -> ActionOutput: - prompt = PROMPT_TEMPLATE.format(topic=topic, record=record, format_example=FORMAT_EXAMPLE) - rsp = await self._aask_v1(prompt, "detail_mining", OUTPUT_MAPPING) + async def run(self, topic, record): + context = CONTEXT_TEMPLATE.format(topic=topic, record=record) + rsp = await QUESTIONS.fill(context=context, llm=self.llm) return rsp diff --git a/metagpt/actions/prepare_interview.py b/metagpt/actions/prepare_interview.py index b2704616e..7ed42d590 100644 --- a/metagpt/actions/prepare_interview.py +++ b/metagpt/actions/prepare_interview.py @@ -6,35 +6,18 @@ @File : prepare_interview.py """ from metagpt.actions import Action +from metagpt.actions.action_node import ActionNode -PROMPT_TEMPLATE = """ -# Context -{context} - -## Format example ---- -Q1: question 1 here -References: - - point 1 - - point 2 - -Q2: question 2 here... ---- - ------ -Role: You are an interviewer of our company who is well-knonwn in frontend or backend develop; +QUESTIONS = ActionNode( + key="Questions", + expected_type=list[str], + instruction="""Role: You are an interviewer of our company who is well-knonwn in frontend or backend develop; Requirement: Provide a list of questions for the interviewer to ask the interviewee, by reading the resume of the interviewee in the context. -Attention: Provide as markdown block as the format above, at least 10 questions. -""" - -# prepare for a interview +Attention: Provide as markdown block as the format above, at least 10 questions.""", + example=["1. What ...", "2. How ..."], +) class PrepareInterview(Action): - def __init__(self, name, context=None, llm=None): - super().__init__(name, context, llm) - async def run(self, context): - prompt = PROMPT_TEMPLATE.format(context=context) - question_list = await self._aask_v1(prompt) - return question_list + return await QUESTIONS.fill(context=context, llm=self.llm) diff --git a/metagpt/actions/project_management.py b/metagpt/actions/project_management.py index 2727f7e7f..095881e60 100644 --- a/metagpt/actions/project_management.py +++ b/metagpt/actions/project_management.py @@ -45,7 +45,7 @@ class WriteTasks(Action): context: Optional[str] = None llm: BaseGPTAPI = Field(default_factory=LLM) - async def run(self, with_messages, format=CONFIG.prompt_format): + async def run(self, with_messages, schema=CONFIG.prompt_schema): system_design_file_repo = CONFIG.git_repo.new_file_repository(SYSTEM_DESIGN_FILE_REPO) changed_system_designs = system_design_file_repo.changed_files @@ -92,16 +92,16 @@ class WriteTasks(Action): await self._save_pdf(task_doc=task_doc) return task_doc - async def _run_new_tasks(self, context, format=CONFIG.prompt_format): - node = await PM_NODE.fill(context, self.llm, format) + async def _run_new_tasks(self, context, schema=CONFIG.prompt_schema): + node = await PM_NODE.fill(context, self.llm, schema) # prompt_template, format_example = get_template(templates, format) # prompt = prompt_template.format(context=context, format_example=format_example) # rsp = await self._aask_v1(prompt, "task", OUTPUT_MAPPING, format=format) return node - async def _merge(self, system_design_doc, task_doc, format=CONFIG.prompt_format) -> Document: + async def _merge(self, system_design_doc, task_doc, schema=CONFIG.prompt_schema) -> Document: context = NEW_REQ_TEMPLATE.format(context=system_design_doc.content, old_tasks=task_doc.content) - node = await PM_NODE.fill(context, self.llm, format) + node = await PM_NODE.fill(context, self.llm, schema) task_doc.content = node.instruct_content.json(ensure_ascii=False) return task_doc diff --git a/metagpt/actions/write_prd.py b/metagpt/actions/write_prd.py index 0febb2656..ae1e0379c 100644 --- a/metagpt/actions/write_prd.py +++ b/metagpt/actions/write_prd.py @@ -113,7 +113,7 @@ class WritePRD(Action): # optimization in subsequent steps. return ActionOutput(content=change_files.json(), instruct_content=change_files) - async def _run_new_requirement(self, requirements, format=CONFIG.prompt_format) -> ActionOutput: + async def _run_new_requirement(self, requirements, schema=CONFIG.prompt_schema) -> ActionOutput: # sas = SearchAndSummarize() # # rsp = await sas.run(context=requirements, system_text=SEARCH_AND_SUMMARIZE_SYSTEM_EN_US) # rsp = "" @@ -123,7 +123,7 @@ class WritePRD(Action): # logger.info(rsp) project_name = CONFIG.project_name if CONFIG.project_name else "" context = CONTEXT_TEMPLATE.format(requirements=requirements, project_name=project_name) - node = await WRITE_PRD_NODE.fill(context=context, llm=self.llm, to=format) + node = await WRITE_PRD_NODE.fill(context=context, llm=self.llm, to=schema) await self._rename_workspace(node) return node @@ -132,11 +132,11 @@ class WritePRD(Action): node = await WP_IS_RELATIVE_NODE.fill(context, self.llm) return node.get("is_relative") == "YES" - async def _merge(self, new_requirement_doc, prd_doc, format=CONFIG.prompt_format) -> Document: + async def _merge(self, new_requirement_doc, prd_doc, schema=CONFIG.prompt_schema) -> Document: if not CONFIG.project_name: CONFIG.project_name = Path(CONFIG.project_path).name prompt = NEW_REQ_TEMPLATE.format(requirements=new_requirement_doc.content, old_prd=prd_doc.content) - node = await WRITE_PRD_NODE.fill(context=prompt, llm=self.llm, to=format) + node = await WRITE_PRD_NODE.fill(context=prompt, llm=self.llm, to=schema) prd_doc.content = node.instruct_content.json(ensure_ascii=False) await self._rename_workspace(node) return prd_doc diff --git a/metagpt/config.py b/metagpt/config.py index 80a3a28f4..131854a56 100644 --- a/metagpt/config.py +++ b/metagpt/config.py @@ -143,7 +143,7 @@ class Config(metaclass=Singleton): self.pyppeteer_executable_path = self._get("PYPPETEER_EXECUTABLE_PATH", "") self.repair_llm_output = self._get("REPAIR_LLM_OUTPUT", False) - self.prompt_format = self._get("PROMPT_FORMAT", "json") + self.prompt_schema = self._get("PROMPT_FORMAT", "json") self.workspace_path = Path(self._get("WORKSPACE_PATH", DEFAULT_WORKSPACE_ROOT)) self._ensure_workspace_exists() diff --git a/metagpt/utils/get_template.py b/metagpt/utils/get_template.py index 86c1915f7..7e05e5d5e 100644 --- a/metagpt/utils/get_template.py +++ b/metagpt/utils/get_template.py @@ -8,10 +8,10 @@ from metagpt.config import CONFIG -def get_template(templates, format=CONFIG.prompt_format): - selected_templates = templates.get(format) +def get_template(templates, schema=CONFIG.prompt_schema): + selected_templates = templates.get(schema) if selected_templates is None: - raise ValueError(f"Can't find {format} in passed in templates") + raise ValueError(f"Can't find {schema} in passed in templates") # Extract the selected templates prompt_template = selected_templates["PROMPT_TEMPLATE"] diff --git a/tests/metagpt/actions/test_detail_mining.py b/tests/metagpt/actions/test_detail_mining.py index 891dca6ca..30bcf9dfb 100644 --- a/tests/metagpt/actions/test_detail_mining.py +++ b/tests/metagpt/actions/test_detail_mining.py @@ -19,5 +19,5 @@ async def test_detail_mining(): rsp = await detail_mining.run(topic=topic, record=record) logger.info(f"{rsp.content=}") - assert "##OUTPUT" in rsp.content - assert "蛋糕" in rsp.content + assert "Questions" in rsp.content + assert "1." in rsp.content diff --git a/tests/metagpt/actions/test_prepare_interview.py b/tests/metagpt/actions/test_prepare_interview.py new file mode 100644 index 000000000..7c32882e0 --- /dev/null +++ b/tests/metagpt/actions/test_prepare_interview.py @@ -0,0 +1,21 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Time : 2023/9/13 00:26 +@Author : fisherdeng +@File : test_detail_mining.py +""" +import pytest + +from metagpt.actions.prepare_interview import PrepareInterview +from metagpt.logs import logger + + +@pytest.mark.asyncio +async def test_prepare_interview(): + action = PrepareInterview() + rsp = await action.run("I just graduated and hope to find a job as a Python engineer") + logger.info(f"{rsp.content=}") + + assert "Questions" in rsp.content + assert "1." in rsp.content diff --git a/tests/metagpt/roles/ui_role.py b/tests/metagpt/roles/ui_role.py index 8ac799bf3..0932efa1f 100644 --- a/tests/metagpt/roles/ui_role.py +++ b/tests/metagpt/roles/ui_role.py @@ -10,6 +10,7 @@ from importlib import import_module from metagpt.actions import Action, ActionOutput, WritePRD # from metagpt.const import WORKSPACE_ROOT +from metagpt.actions.action_node import ActionNode from metagpt.config import CONFIG from metagpt.logs import logger from metagpt.roles import Role @@ -17,44 +18,38 @@ from metagpt.schema import Message from metagpt.tools.sd_engine import SDEngine PROMPT_TEMPLATE = """ -# Context {context} -## Format example -{format_example} ------ -Role: You are a UserInterface Designer; the goal is to finish a UI design according to PRD, give a design description, and select specified elements and UI style. -Requirements: Based on the context, fill in the following missing information, provide detailed HTML and CSS code -Attention: Use '##' to split sections, not '#', and '## ' SHOULD WRITE BEFORE the code and triple quote. - -## UI Design Description:Provide as Plain text, place the design objective here -## Selected Elements:Provide as Plain text, up to 5 specified elements, clear and simple -## HTML Layout:Provide as Plain text, use standard HTML code -## CSS Styles (styles.css):Provide as Plain text,use standard css code -## Anything UNCLEAR:Provide as Plain text. Try to clarify it. - +## Role +You are a UserInterface Designer; the goal is to finish a UI design according to PRD, give a design description, and select specified elements and UI style. """ -FORMAT_EXAMPLE = """ +UI_DESIGN_DESC = ActionNode( + key="UI Design Desc", + expected_type=str, + instruction="place the design objective here", + example="Snake games are classic and addictive games with simple yet engaging elements. Here are the main elements" + " commonly found in snake games", +) -## UI Design Description -```Snake games are classic and addictive games with simple yet engaging elements. Here are the main elements commonly found in snake games ``` +SELECTED_ELEMENTS = ActionNode( + key="Selected Elements", + expected_type=list[str], + instruction="up to 5 specified elements, clear and simple", + example=[ + "Game Grid: The game grid is a rectangular...", + "Snake: The player controls a snake that moves across the grid...", + "Food: Food items (often represented as small objects or differently colored blocks)", + "Score: The player's score increases each time the snake eats a piece of food. The longer the snake becomes, the higher the score.", + "Game Over: The game ends when the snake collides with itself or an obstacle. At this point, the player's final score is displayed, and they are given the option to restart the game.", + ], +) -## Selected Elements - -Game Grid: The game grid is a rectangular... - -Snake: The player controls a snake that moves across the grid... - -Food: Food items (often represented as small objects or differently colored blocks) - -Score: The player's score increases each time the snake eats a piece of food. The longer the snake becomes, the higher the score. - -Game Over: The game ends when the snake collides with itself or an obstacle. At this point, the player's final score is displayed, and they are given the option to restart the game. - - -## HTML Layout - +HTML_LAYOUT = ActionNode( + key="HTML Layout", + expected_type=str, + instruction="use standard HTML code", + example=""" @@ -71,9 +66,14 @@ Game Over: The game ends when the snake collides with itself or an obstacle. At +""", +) -## CSS Styles (styles.css) -body { +CSS_STYLES = ActionNode( + key="CSS Styles", + expected_type=str, + instruction="use standard css code", + example="""body { display: flex; justify-content: center; align-items: center; @@ -121,19 +121,25 @@ body { color: #ff0000; display: none; } +""", +) -## Anything UNCLEAR -There are no unclear points. +ANYTHING_UNCLEAR = ActionNode( + key="Anything UNCLEAR", + expected_type=str, + instruction="Mention any aspects of the project that are unclear and try to clarify them.", + example="...", +) -""" +NODES = [ + UI_DESIGN_DESC, + SELECTED_ELEMENTS, + HTML_LAYOUT, + CSS_STYLES, + ANYTHING_UNCLEAR, +] -OUTPUT_MAPPING = { - "UI Design Description": (str, ...), - "Selected Elements": (str, ...), - "HTML Layout": (str, ...), - "CSS Styles (styles.css)": (str, ...), - "Anything UNCLEAR": (str, ...), -} +UI_DESIGN_NODE = ActionNode.from_children("UI_DESIGN", NODES) def load_engine(func): @@ -223,10 +229,8 @@ class UIDesign(Action): css_file_path = save_dir / "ui_design.css" html_file_path = save_dir / "ui_design.html" - with open(css_file_path, "w") as css_file: - css_file.write(css_content) - with open(html_file_path, "w") as html_file: - html_file.write(html_content) + css_file_path.write_text(css_content) + html_file_path.write_text(html_content) async def run(self, requirements: list[Message], *args, **kwargs) -> ActionOutput: """Run the UI Design action.""" @@ -234,9 +238,9 @@ class UIDesign(Action): context = requirements[-1].content ui_design_draft = self.parse_requirement(context=context) # todo: parse requirements str - prompt = PROMPT_TEMPLATE.format(context=ui_design_draft, format_example=FORMAT_EXAMPLE) + prompt = PROMPT_TEMPLATE.format(context=ui_design_draft) logger.info(prompt) - ui_describe = await self._aask_v1(prompt, "ui_design", OUTPUT_MAPPING) + ui_describe = await UI_DESIGN_NODE.fill(prompt) logger.info(ui_describe.content) logger.info(ui_describe.instruct_content) css = self.parse_css_code(context=ui_describe.content) From d159bfc4e195a6a72ff5b54dcbea9f36c36373fd Mon Sep 17 00:00:00 2001 From: geekan Date: Tue, 19 Dec 2023 21:24:08 +0800 Subject: [PATCH 118/167] refactor action_output and action_node --- metagpt/actions/action_node.py | 4 ++-- metagpt/actions/action_output.py | 26 +-------------------- metagpt/actions/write_prd.py | 2 +- metagpt/utils/serialize.py | 11 +++++---- tests/metagpt/actions/test_action_output.py | 6 ++--- tests/metagpt/memory/test_memory_storage.py | 4 ++-- tests/metagpt/utils/test_serialize.py | 4 ++-- 7 files changed, 18 insertions(+), 39 deletions(-) diff --git a/metagpt/actions/action_node.py b/metagpt/actions/action_node.py index 0368d2df1..865cb2d32 100644 --- a/metagpt/actions/action_node.py +++ b/metagpt/actions/action_node.py @@ -6,7 +6,7 @@ @File : action_node.py """ import json -from typing import Dict, Generic, List, Optional, Type, TypeVar +from typing import Any, Dict, Generic, List, Optional, Tuple, Type, TypeVar from pydantic import BaseModel, create_model, root_validator, validator from tenacity import retry, stop_after_attempt, wait_random_exponential @@ -127,7 +127,7 @@ class ActionNode(Generic[T]): return self.get_self_mapping() @classmethod - def create_model_class(cls, class_name: str, mapping: Dict[str, Type]): + def create_model_class(cls, class_name: str, mapping: Dict[str, Tuple[Type, Any]]): """基于pydantic v1的模型动态生成,用来检验结果类型正确性""" new_class = create_model(class_name, **mapping) diff --git a/metagpt/actions/action_output.py b/metagpt/actions/action_output.py index 25326d43b..6be8dac50 100644 --- a/metagpt/actions/action_output.py +++ b/metagpt/actions/action_output.py @@ -6,9 +6,7 @@ @File : action_output """ -from typing import Dict, Type - -from pydantic import BaseModel, create_model, root_validator, validator +from pydantic import BaseModel class ActionOutput: @@ -18,25 +16,3 @@ class ActionOutput: def __init__(self, content: str, instruct_content: BaseModel): self.content = content self.instruct_content = instruct_content - - @classmethod - def create_model_class(cls, class_name: str, mapping: Dict[str, Type]): - new_class = create_model(class_name, **mapping) - - @validator("*", allow_reuse=True) - def check_name(v, field): - if field.name not in mapping.keys(): - raise ValueError(f"Unrecognized block: {field.name}") - return v - - @root_validator(pre=True, allow_reuse=True) - def check_missing_fields(values): - required_fields = set(mapping.keys()) - missing_fields = required_fields - set(values.keys()) - if missing_fields: - raise ValueError(f"Missing fields: {missing_fields}") - return values - - new_class.__validator_check_name = classmethod(check_name) - new_class.__root_validator_check_missing_fields = classmethod(check_missing_fields) - return new_class diff --git a/metagpt/actions/write_prd.py b/metagpt/actions/write_prd.py index ae1e0379c..411051199 100644 --- a/metagpt/actions/write_prd.py +++ b/metagpt/actions/write_prd.py @@ -69,7 +69,7 @@ class WritePRD(Action): content: Optional[str] = None llm: BaseGPTAPI = Field(default_factory=LLM) - async def run(self, with_messages, format=CONFIG.prompt_format, *args, **kwargs) -> ActionOutput | Message: + async def run(self, with_messages, schema=CONFIG.prompt_schema, *args, **kwargs) -> ActionOutput | Message: # Determine which requirement documents need to be rewritten: Use LLM to assess whether new requirements are # related to the PRD. If they are related, rewrite the PRD. docs_file_repo = CONFIG.git_repo.new_file_repository(relative_path=DOCS_FILE_REPO) diff --git a/metagpt/utils/serialize.py b/metagpt/utils/serialize.py index 9a758da34..d4db5985b 100644 --- a/metagpt/utils/serialize.py +++ b/metagpt/utils/serialize.py @@ -5,7 +5,12 @@ import copy import pickle +<<<<<<< HEAD from metagpt.utils.common import import_class +======= +from metagpt.actions.action_node import ActionNode +from metagpt.schema import Message +>>>>>>> 09e2f05 (refactor action_output and action_node) def actionoutout_schema_to_mapping(schema: dict) -> dict: @@ -104,13 +109,11 @@ def deserialize_general_message(message_dict: dict) -> "Message": return message -def deserialize_message(message_ser: str) -> "Message": +def deserialize_message(message_ser: str) -> Message: message = pickle.loads(message_ser) if message.instruct_content: ic = message.instruct_content - - actionoutput_class = import_class("ActionOutput", "metagpt.actions.action_output") - ic_obj = actionoutput_class.create_model_class(class_name=ic["class"], mapping=ic["mapping"]) + ic_obj = ActionNode.create_model_class(class_name=ic["class"], mapping=ic["mapping"]) ic_new = ic_obj(**ic["value"]) message.instruct_content = ic_new diff --git a/tests/metagpt/actions/test_action_output.py b/tests/metagpt/actions/test_action_output.py index ef8e239bd..f1765cb03 100644 --- a/tests/metagpt/actions/test_action_output.py +++ b/tests/metagpt/actions/test_action_output.py @@ -7,7 +7,7 @@ """ from typing import List, Tuple -from metagpt.actions import ActionOutput +from metagpt.actions.action_node import ActionNode t_dict = { "Required Python third-party packages": '"""\nflask==1.1.2\npygame==2.0.1\n"""\n', @@ -37,12 +37,12 @@ WRITE_TASKS_OUTPUT_MAPPING = { def test_create_model_class(): - test_class = ActionOutput.create_model_class("test_class", WRITE_TASKS_OUTPUT_MAPPING) + test_class = ActionNode.create_model_class("test_class", WRITE_TASKS_OUTPUT_MAPPING) assert test_class.__name__ == "test_class" def test_create_model_class_with_mapping(): - t = ActionOutput.create_model_class("test_class_1", WRITE_TASKS_OUTPUT_MAPPING) + t = ActionNode.create_model_class("test_class_1", WRITE_TASKS_OUTPUT_MAPPING) t1 = t(**t_dict) value = t1.dict()["Task list"] assert value == ["game.py", "app.py", "static/css/styles.css", "static/js/script.js", "templates/index.html"] diff --git a/tests/metagpt/memory/test_memory_storage.py b/tests/metagpt/memory/test_memory_storage.py index c67ca689f..7b74eb512 100644 --- a/tests/metagpt/memory/test_memory_storage.py +++ b/tests/metagpt/memory/test_memory_storage.py @@ -8,7 +8,7 @@ from typing import List from metagpt.actions import UserRequirement, WritePRD -from metagpt.actions.action_output import ActionOutput +from metagpt.actions.action_node import ActionNode from metagpt.memory.memory_storage import MemoryStorage from metagpt.schema import Message @@ -42,7 +42,7 @@ def test_idea_message(): def test_actionout_message(): out_mapping = {"field1": (str, ...), "field2": (List[str], ...)} out_data = {"field1": "field1 value", "field2": ["field2 value1", "field2 value2"]} - ic_obj = ActionOutput.create_model_class("prd", out_mapping) + ic_obj = ActionNode.create_model_class("prd", out_mapping) role_id = "UTUser2(Architect)" content = "The user has requested the creation of a command-line interface (CLI) snake game" diff --git a/tests/metagpt/utils/test_serialize.py b/tests/metagpt/utils/test_serialize.py index ffa34866c..f027d53f8 100644 --- a/tests/metagpt/utils/test_serialize.py +++ b/tests/metagpt/utils/test_serialize.py @@ -7,7 +7,7 @@ from typing import List, Tuple from metagpt.actions import WritePRD -from metagpt.actions.action_output import ActionOutput +from metagpt.actions.action_node import ActionNode from metagpt.schema import Message from metagpt.utils.serialize import ( actionoutout_schema_to_mapping, @@ -54,7 +54,7 @@ def test_actionoutout_schema_to_mapping(): def test_serialize_and_deserialize_message(): out_mapping = {"field1": (str, ...), "field2": (List[str], ...)} out_data = {"field1": "field1 value", "field2": ["field2 value1", "field2 value2"]} - ic_obj = ActionOutput.create_model_class("prd", out_mapping) + ic_obj = ActionNode.create_model_class("prd", out_mapping) message = Message( content="prd demand", instruct_content=ic_obj(**out_data), role="user", cause_by=WritePRD From a06acbbbe8fc8ada928cfd82bdeab36ecba9e5c9 Mon Sep 17 00:00:00 2001 From: geekan Date: Tue, 19 Dec 2023 21:32:52 +0800 Subject: [PATCH 119/167] refine code --- metagpt/actions/action_node.py | 2 +- metagpt/actions/write_prd_an.py | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/metagpt/actions/action_node.py b/metagpt/actions/action_node.py index 865cb2d32..790069369 100644 --- a/metagpt/actions/action_node.py +++ b/metagpt/actions/action_node.py @@ -232,7 +232,7 @@ class ActionNode(Generic[T]): return prompt @retry( - wait=wait_random_exponential(min=1, max=60), + wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6), after=general_after_log(logger), ) diff --git a/metagpt/actions/write_prd_an.py b/metagpt/actions/write_prd_an.py index d96c0aeac..edd94a463 100644 --- a/metagpt/actions/write_prd_an.py +++ b/metagpt/actions/write_prd_an.py @@ -47,7 +47,7 @@ PRODUCT_GOALS = ActionNode( USER_STORIES = ActionNode( key="User Stories", expected_type=list[str], - instruction="Provide up to five scenario-based user stories.", + instruction="Provide up to 3 to 5 scenario-based user stories.", example=[ "As a user, I want to be able to choose difficulty levels", "As a player, I want to see my score after each game", @@ -57,7 +57,7 @@ USER_STORIES = ActionNode( COMPETITIVE_ANALYSIS = ActionNode( key="Competitive Analysis", expected_type=list[str], - instruction="Provide analyses for up to seven competitive products.", + instruction="Provide 5 to 7 competitive products.", example=["Python Snake Game: Simple interface, lacks advanced features"], ) @@ -92,8 +92,8 @@ REQUIREMENT_ANALYSIS = ActionNode( REQUIREMENT_POOL = ActionNode( key="Requirement Pool", expected_type=list[list[str]], - instruction="List down the requirements with their priority (P0, P1, P2).", - example=[["P0", "..."], ["P1", "..."]], + instruction="List down the top-5 requirements with their priority (P0, P1, P2).", + example=[["P0", "The main code ..."], ["P0", "The game algorithm ..."]], ) UI_DESIGN_DRAFT = ActionNode( From 4d78dbce406dc85e90eb865037b883de278390d5 Mon Sep 17 00:00:00 2001 From: geekan Date: Tue, 19 Dec 2023 23:53:04 +0800 Subject: [PATCH 120/167] refine code. move azure tts to tool, refactor actions --- metagpt/actions/__init__.py | 2 - metagpt/actions/action.py | 3 +- metagpt/actions/analyze_dep_libs.py | 37 ------------------- metagpt/actions/design_filenames.py | 30 --------------- ...detail_mining.py => generate_questions.py} | 18 ++------- metagpt/schema.py | 3 +- metagpt/{actions => tools}/azure_tts.py | 19 ++++------ metagpt/utils/serialize.py | 4 +- tests/metagpt/actions/test_azure_tts.py | 4 +- tests/metagpt/actions/test_detail_mining.py | 20 ++++++---- 10 files changed, 32 insertions(+), 108 deletions(-) delete mode 100644 metagpt/actions/analyze_dep_libs.py delete mode 100644 metagpt/actions/design_filenames.py rename metagpt/actions/{detail_mining.py => generate_questions.py} (69%) rename metagpt/{actions => tools}/azure_tts.py (65%) diff --git a/metagpt/actions/__init__.py b/metagpt/actions/__init__.py index 79ff94b3e..c34c72ed2 100644 --- a/metagpt/actions/__init__.py +++ b/metagpt/actions/__init__.py @@ -13,7 +13,6 @@ from metagpt.actions.add_requirement import UserRequirement from metagpt.actions.debug_error import DebugError from metagpt.actions.design_api import WriteDesign from metagpt.actions.design_api_review import DesignReview -from metagpt.actions.design_filenames import DesignFilenames from metagpt.actions.project_management import AssignTasks, WriteTasks from metagpt.actions.research import CollectLinks, WebBrowseAndSummarize, ConductResearch from metagpt.actions.run_code import RunCode @@ -33,7 +32,6 @@ class ActionType(Enum): WRITE_PRD_REVIEW = WritePRDReview WRITE_DESIGN = WriteDesign DESIGN_REVIEW = DesignReview - DESIGN_FILENAMES = DesignFilenames WRTIE_CODE = WriteCode WRITE_CODE_REVIEW = WriteCodeReview WRITE_TEST = WriteTest diff --git a/metagpt/actions/action.py b/metagpt/actions/action.py index 1fcc8fc80..e18983d7d 100644 --- a/metagpt/actions/action.py +++ b/metagpt/actions/action.py @@ -14,6 +14,7 @@ from pydantic import BaseModel, Field from metagpt.actions.action_node import ActionNode from metagpt.llm import LLM from metagpt.provider.base_gpt_api import BaseGPTAPI +from metagpt.schema import CodingContext, CodeSummarizeContext, TestingContext, RunCodeContext action_subclass_registry = {} @@ -22,7 +23,7 @@ action_subclass_registry = {} class Action(BaseModel): name: str = "" llm: BaseGPTAPI = Field(default_factory=LLM, exclude=True) - context = "" + context: dict | CodingContext | CodeSummarizeContext | TestingContext | RunCodeContext | str | None = "" prefix = "" # aask*时会加上prefix,作为system_message desc = "" # for skill manager node: ActionNode = Field(default_factory=ActionNode, exclude=True) diff --git a/metagpt/actions/analyze_dep_libs.py b/metagpt/actions/analyze_dep_libs.py deleted file mode 100644 index 53d40200a..000000000 --- a/metagpt/actions/analyze_dep_libs.py +++ /dev/null @@ -1,37 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- -""" -@Time : 2023/5/19 12:01 -@Author : alexanderwu -@File : analyze_dep_libs.py -""" - -from metagpt.actions import Action - -PROMPT = """You are an AI developer, trying to write a program that generates code for users based on their intentions. - -For the user's prompt: - ---- -The API is: {prompt} ---- - -We decide the generated files are: {filepaths_string} - -Now that we have a file list, we need to understand the shared dependencies they have. -Please list and briefly describe the shared contents between the files we are generating, including exported variables, -data patterns, id names of all DOM elements that javascript functions will use, message names and function names. -Focus only on the names of shared dependencies, do not add any other explanations. -""" - - -class AnalyzeDepLibs(Action): - def __init__(self, name, context=None, llm=None): - super().__init__(name, context, llm) - self.desc = "Analyze the runtime dependencies of the program based on the context" - - async def run(self, requirement, filepaths_string): - # prompt = f"Below is the product requirement document (PRD):\n\n{prd}\n\n{PROMPT}" - prompt = PROMPT.format(prompt=requirement, filepaths_string=filepaths_string) - design_filenames = await self._aask(prompt) - return design_filenames diff --git a/metagpt/actions/design_filenames.py b/metagpt/actions/design_filenames.py deleted file mode 100644 index ffa171d7b..000000000 --- a/metagpt/actions/design_filenames.py +++ /dev/null @@ -1,30 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- -""" -@Time : 2023/5/19 11:50 -@Author : alexanderwu -@File : design_filenames.py -""" -from metagpt.actions import Action -from metagpt.logs import logger - -PROMPT = """You are an AI developer, trying to write a program that generates code for users based on their intentions. -When given their intentions, provide a complete and exhaustive list of file paths needed to write the program for the user. -Only list the file paths you will write and return them as a Python string list. -Do not add any other explanations, just return a Python string list.""" - - -class DesignFilenames(Action): - def __init__(self, name, context=None, llm=None): - super().__init__(name, context, llm) - self.desc = ( - "Based on the PRD, consider system design, and carry out the basic design of the corresponding " - "APIs, data structures, and database tables. Please give your design, feedback clearly and in detail." - ) - - async def run(self, prd): - prompt = f"The following is the Product Requirement Document (PRD):\n\n{prd}\n\n{PROMPT}" - design_filenames = await self._aask(prompt) - logger.debug(prompt) - logger.debug(design_filenames) - return design_filenames diff --git a/metagpt/actions/detail_mining.py b/metagpt/actions/generate_questions.py similarity index 69% rename from metagpt/actions/detail_mining.py rename to metagpt/actions/generate_questions.py index 0314d30dd..c38c463bc 100644 --- a/metagpt/actions/detail_mining.py +++ b/metagpt/actions/generate_questions.py @@ -3,19 +3,11 @@ """ @Time : 2023/9/12 17:45 @Author : fisherdeng -@File : detail_mining.py +@File : generate_questions.py """ from metagpt.actions import Action from metagpt.actions.action_node import ActionNode -CONTEXT_TEMPLATE = """ -## TOPIC -{topic} - -## RECORD -{record} -""" - QUESTIONS = ActionNode( key="Questions", expected_type=list[str], @@ -25,11 +17,9 @@ QUESTIONS = ActionNode( ) -class DetailMining(Action): +class GenerateQuestions(Action): """This class allows LLM to further mine noteworthy details based on specific "##TOPIC"(discussion topic) and "##RECORD" (discussion records), thereby deepening the discussion.""" - async def run(self, topic, record): - context = CONTEXT_TEMPLATE.format(topic=topic, record=record) - rsp = await QUESTIONS.fill(context=context, llm=self.llm) - return rsp + async def run(self, context): + return await QUESTIONS.fill(context=context, llm=self.llm) diff --git a/metagpt/schema.py b/metagpt/schema.py index 59203c404..327bfd2d1 100644 --- a/metagpt/schema.py +++ b/metagpt/schema.py @@ -19,6 +19,7 @@ import asyncio import json import os.path import uuid +from abc import ABC from asyncio import Queue, QueueEmpty, wait_for from json import JSONDecodeError from pathlib import Path @@ -281,7 +282,7 @@ class MessageQueue(BaseModel): T = TypeVar("T", bound="BaseModel") -class BaseContext(BaseModel): +class BaseContext(BaseModel, ABC): @classmethod @handle_exception def loads(cls: Type[T], val: str) -> Optional[T]: diff --git a/metagpt/actions/azure_tts.py b/metagpt/tools/azure_tts.py similarity index 65% rename from metagpt/actions/azure_tts.py rename to metagpt/tools/azure_tts.py index daa3f6892..e59d98016 100644 --- a/metagpt/actions/azure_tts.py +++ b/metagpt/tools/azure_tts.py @@ -7,19 +7,16 @@ """ from azure.cognitiveservices.speech import AudioConfig, SpeechConfig, SpeechSynthesizer -from metagpt.actions.action import Action -from metagpt.config import Config +from metagpt.config import CONFIG -class AzureTTS(Action): - def __init__(self, name, context=None, llm=None): - super().__init__(name, context, llm) - self.config = Config() +class AzureTTS: + """https://learn.microsoft.com/zh-cn/azure/cognitive-services/speech-service/language-support?tabs=tts#voice-styles-and-roles""" - # Parameters reference: https://learn.microsoft.com/zh-cn/azure/cognitive-services/speech-service/language-support?tabs=tts#voice-styles-and-roles - def synthesize_speech(self, lang, voice, role, text, output_file): - subscription_key = self.config.get("AZURE_TTS_SUBSCRIPTION_KEY") - region = self.config.get("AZURE_TTS_REGION") + @classmethod + def synthesize_speech(cls, lang, voice, role, text, output_file): + subscription_key = CONFIG.get("AZURE_TTS_SUBSCRIPTION_KEY") + region = CONFIG.get("AZURE_TTS_REGION") speech_config = SpeechConfig(subscription=subscription_key, region=region) speech_config.speech_synthesis_voice_name = voice @@ -41,5 +38,5 @@ class AzureTTS(Action): if __name__ == "__main__": - azure_tts = AzureTTS("azure_tts") + azure_tts = AzureTTS() azure_tts.synthesize_speech("zh-CN", "zh-CN-YunxiNeural", "Boy", "Hello, I am Kaka", "output.wav") diff --git a/metagpt/utils/serialize.py b/metagpt/utils/serialize.py index d4db5985b..8ad46a120 100644 --- a/metagpt/utils/serialize.py +++ b/metagpt/utils/serialize.py @@ -5,12 +5,10 @@ import copy import pickle -<<<<<<< HEAD + from metagpt.utils.common import import_class -======= from metagpt.actions.action_node import ActionNode from metagpt.schema import Message ->>>>>>> 09e2f05 (refactor action_output and action_node) def actionoutout_schema_to_mapping(schema: dict) -> dict: diff --git a/tests/metagpt/actions/test_azure_tts.py b/tests/metagpt/actions/test_azure_tts.py index bcafe10f5..9995e9691 100644 --- a/tests/metagpt/actions/test_azure_tts.py +++ b/tests/metagpt/actions/test_azure_tts.py @@ -5,11 +5,11 @@ @Author : alexanderwu @File : test_azure_tts.py """ -from metagpt.actions.azure_tts import AzureTTS +from metagpt.tools.azure_tts import AzureTTS def test_azure_tts(): - azure_tts = AzureTTS("azure_tts") + azure_tts = AzureTTS() azure_tts.synthesize_speech("zh-CN", "zh-CN-YunxiNeural", "Boy", "你好,我是卡卡", "output.wav") # 运行需要先配置 SUBSCRIPTION_KEY diff --git a/tests/metagpt/actions/test_detail_mining.py b/tests/metagpt/actions/test_detail_mining.py index 30bcf9dfb..a178ec840 100644 --- a/tests/metagpt/actions/test_detail_mining.py +++ b/tests/metagpt/actions/test_detail_mining.py @@ -3,20 +3,26 @@ """ @Time : 2023/9/13 00:26 @Author : fisherdeng -@File : test_detail_mining.py +@File : test_generate_questions.py """ import pytest -from metagpt.actions.detail_mining import DetailMining +from metagpt.actions.generate_questions import GenerateQuestions from metagpt.logs import logger +context = """ +## topic +如何做一个生日蛋糕 + +## record +我认为应该先准备好材料,然后再开始做蛋糕。 +""" + @pytest.mark.asyncio -async def test_detail_mining(): - topic = "如何做一个生日蛋糕" - record = "我认为应该先准备好材料,然后再开始做蛋糕。" - detail_mining = DetailMining("detail_mining") - rsp = await detail_mining.run(topic=topic, record=record) +async def test_generate_questions(): + detail_mining = GenerateQuestions() + rsp = await detail_mining.run(context) logger.info(f"{rsp.content=}") assert "Questions" in rsp.content From b4af3b6270fe1b4d6f57283964e29e6a0d8b1a19 Mon Sep 17 00:00:00 2001 From: geekan Date: Tue, 19 Dec 2023 23:58:18 +0800 Subject: [PATCH 121/167] refine code --- metagpt/actions/action_node.py | 52 +++++++++++++++++----------------- metagpt/actions/design_api.py | 4 +-- metagpt/actions/write_prd.py | 4 +-- 3 files changed, 30 insertions(+), 30 deletions(-) diff --git a/metagpt/actions/action_node.py b/metagpt/actions/action_node.py index 790069369..092dd5755 100644 --- a/metagpt/actions/action_node.py +++ b/metagpt/actions/action_node.py @@ -112,15 +112,15 @@ class ActionNode(Generic[T]): obj.add_children(nodes) return obj - def get_children_mapping(self) -> Dict[str, Type]: + def get_children_mapping(self) -> Dict[str, Tuple[Type, Any]]: """获得子ActionNode的字典,以key索引""" return {k: (v.expected_type, ...) for k, v in self.children.items()} - def get_self_mapping(self) -> Dict[str, Type]: + def get_self_mapping(self) -> Dict[str, Tuple[Type, Any]]: """get self key: type mapping""" return {self.key: (self.expected_type, ...)} - def get_mapping(self, mode="children") -> Dict[str, Type]: + def get_mapping(self, mode="children") -> Dict[str, Tuple[Type, Any]]: """get key: type mapping under mode""" if mode == "children" or (mode == "auto" and self.children): return self.get_children_mapping() @@ -175,46 +175,46 @@ class ActionNode(Generic[T]): return node_dict # 遍历子节点并递归调用 to_dict 方法 - for child_key, child_node in self.children.items(): + for _, child_node in self.children.items(): node_dict.update(child_node.to_dict(format_func)) return node_dict - def compile_to(self, i: Dict, to) -> str: - if to == "json": + def compile_to(self, i: Dict, schema) -> str: + if schema == "json": return json.dumps(i, indent=4) - elif to == "markdown": + elif schema == "markdown": return dict_to_markdown(i) else: return str(i) - def tagging(self, text, to, tag="") -> str: + def tagging(self, text, schema, tag="") -> str: if not tag: return text - if to == "json": + if schema == "json": return f"[{tag}]\n" + text + f"\n[/{tag}]" else: return f"[{tag}]\n" + text + f"\n[/{tag}]" - def _compile_f(self, to, mode, tag, format_func) -> str: + def _compile_f(self, schema, mode, tag, format_func) -> str: nodes = self.to_dict(format_func=format_func, mode=mode) - text = self.compile_to(nodes, to) - return self.tagging(text, to, tag) + text = self.compile_to(nodes, schema) + return self.tagging(text, schema, tag) - def compile_instruction(self, to="raw", mode="children", tag="") -> str: + def compile_instruction(self, schema="raw", mode="children", tag="") -> str: """compile to raw/json/markdown template with all/root/children nodes""" format_func = lambda i: f"{i.expected_type} # {i.instruction}" - return self._compile_f(to, mode, tag, format_func) + return self._compile_f(schema, mode, tag, format_func) - def compile_example(self, to="raw", mode="children", tag="") -> str: + def compile_example(self, schema="raw", mode="children", tag="") -> str: """compile to raw/json/markdown examples with all/root/children nodes""" # 这里不能使用f-string,因为转译为str后再json.dumps会额外加上引号,无法作为有效的example # 错误示例:"File list": "['main.py', 'const.py', 'game.py']", 注意这里值不是list,而是str format_func = lambda i: i.example - return self._compile_f(to, mode, tag, format_func) + return self._compile_f(schema, mode, tag, format_func) - def compile(self, context, to="json", mode="children", template=SIMPLE_TEMPLATE) -> str: + def compile(self, context, schema="json", mode="children", template=SIMPLE_TEMPLATE) -> str: """ mode: all/root/children mode="children": 编译所有子节点为一个统一模板,包括instruction与example @@ -224,8 +224,8 @@ class ActionNode(Generic[T]): # FIXME: json instruction会带来格式问题,如:"Project name": "web_2048 # 项目名称使用下划线", # compile example暂时不支持markdown - self.instruction = self.compile_instruction(to="markdown", mode=mode) - self.example = self.compile_example(to=to, tag="CONTENT", mode=mode) + self.instruction = self.compile_instruction(schema="markdown", mode=mode) + self.example = self.compile_example(schema=schema, tag="CONTENT", mode=mode) prompt = template.format( context=context, example=self.example, instruction=self.instruction, constraint=CONSTRAINT ) @@ -272,22 +272,22 @@ class ActionNode(Generic[T]): def set_context(self, context): self.set_recursive("context", context) - async def simple_fill(self, to, mode): - prompt = self.compile(context=self.context, to=to, mode=mode) + async def simple_fill(self, schema, mode): + prompt = self.compile(context=self.context, schema=schema, mode=mode) mapping = self.get_mapping(mode) class_name = f"{self.key}_AN" - content, scontent = await self._aask_v1(prompt, class_name, mapping, schema=to) + content, scontent = await self._aask_v1(prompt, class_name, mapping, schema=schema) self.content = content self.instruct_content = scontent return self - async def fill(self, context, llm, to="json", mode="auto", strgy="simple"): + async def fill(self, context, llm, schema="json", mode="auto", strgy="simple"): """Fill the node(s) with mode. :param context: Everything we should know when filling node. :param llm: Large Language Model with pre-defined system message. - :param to: json/markdown, determine example and output format. + :param schema: json/markdown, determine example and output format. - json: it's easy to open source LLM with json format - markdown: when generating code, markdown is always better :param mode: auto/children/root @@ -303,12 +303,12 @@ class ActionNode(Generic[T]): self.set_context(context) if strgy == "simple": - return await self.simple_fill(to, mode) + return await self.simple_fill(schema, mode) elif strgy == "complex": # 这里隐式假设了拥有children tmp = {} for _, i in self.children.items(): - child = await i.simple_fill(to, mode) + child = await i.simple_fill(schema, mode) tmp.update(child.instruct_content.dict()) cls = self.create_children_class() self.instruct_content = cls(**tmp) diff --git a/metagpt/actions/design_api.py b/metagpt/actions/design_api.py index 49c5a019d..f5e122356 100644 --- a/metagpt/actions/design_api.py +++ b/metagpt/actions/design_api.py @@ -80,12 +80,12 @@ class WriteDesign(Action): return ActionOutput(content=changed_files.json(), instruct_content=changed_files) async def _new_system_design(self, context, schema=CONFIG.prompt_schema): - node = await DESIGN_API_NODE.fill(context=context, llm=self.llm, to=schema) + node = await DESIGN_API_NODE.fill(context=context, llm=self.llm, schema=schema) return node async def _merge(self, prd_doc, system_design_doc, schema=CONFIG.prompt_schema): context = NEW_REQ_TEMPLATE.format(old_design=system_design_doc.content, context=prd_doc.content) - node = await DESIGN_API_NODE.fill(context=context, llm=self.llm, to=schema) + node = await DESIGN_API_NODE.fill(context=context, llm=self.llm, schema=schema) system_design_doc.content = node.instruct_content.json(ensure_ascii=False) return system_design_doc diff --git a/metagpt/actions/write_prd.py b/metagpt/actions/write_prd.py index 411051199..df66e6442 100644 --- a/metagpt/actions/write_prd.py +++ b/metagpt/actions/write_prd.py @@ -123,7 +123,7 @@ class WritePRD(Action): # logger.info(rsp) project_name = CONFIG.project_name if CONFIG.project_name else "" context = CONTEXT_TEMPLATE.format(requirements=requirements, project_name=project_name) - node = await WRITE_PRD_NODE.fill(context=context, llm=self.llm, to=schema) + node = await WRITE_PRD_NODE.fill(context=context, llm=self.llm, schema=schema) await self._rename_workspace(node) return node @@ -136,7 +136,7 @@ class WritePRD(Action): if not CONFIG.project_name: CONFIG.project_name = Path(CONFIG.project_path).name prompt = NEW_REQ_TEMPLATE.format(requirements=new_requirement_doc.content, old_prd=prd_doc.content) - node = await WRITE_PRD_NODE.fill(context=prompt, llm=self.llm, to=schema) + node = await WRITE_PRD_NODE.fill(context=prompt, llm=self.llm, schema=schema) prd_doc.content = node.instruct_content.json(ensure_ascii=False) await self._rename_workspace(node) return prd_doc From 8107861302f4a31a624372ce4a1f59ed64f0276f Mon Sep 17 00:00:00 2001 From: geekan Date: Wed, 20 Dec 2023 00:34:57 +0800 Subject: [PATCH 122/167] refine devcontainer README --- .devcontainer/README.md | 41 ++++++++++++++++++----------------------- 1 file changed, 18 insertions(+), 23 deletions(-) diff --git a/.devcontainer/README.md b/.devcontainer/README.md index dd088aab1..be692c14d 100644 --- a/.devcontainer/README.md +++ b/.devcontainer/README.md @@ -1,39 +1,34 @@ -# Dev container +# Dev Container -This project includes a [dev container](https://containers.dev/), which lets you use a container as a full-featured dev environment. +This project includes a [Dev Container](https://containers.dev/), offering you a comprehensive and fully-featured development environment within a container. By leveraging the Dev Container configuration in this folder, you can seamlessly build and initiate MetaGPT locally. For detailed information, please refer to the main README in the home directory. -You can use the dev container configuration in this folder to build and start running MetaGPT locally! For more, refer to the main README under the home directory. -You can use it in [GitHub Codespaces](https://github.com/features/codespaces) or the [VS Code Dev Containers extension](https://marketplace.visualstudio.com/items?itemName=ms-vscode-remote.remote-containers). +You can utilize this Dev Container in [GitHub Codespaces](https://github.com/features/codespaces) or with the [VS Code Dev Containers extension](https://marketplace.visualstudio.com/items?itemName=ms-vscode-remote.remote-containers). ## GitHub Codespaces -Open in GitHub Codespaces +[![Open in GitHub Codespaces](https://github.com/codespaces/badge.svg)](https://codespaces.new/geekan/MetaGPT) -You may use the button above to open this repo in a Codespace +Click the button above to open this repository in a Codespace. For additional information, refer to the [GitHub documentation on creating a Codespace](https://docs.github.com/en/free-pro-team@latest/github/developing-online-with-codespaces/creating-a-codespace#creating-a-codespace). -For more info, check out the [GitHub documentation](https://docs.github.com/en/free-pro-team@latest/github/developing-online-with-codespaces/creating-a-codespace#creating-a-codespace). - ## VS Code Dev Containers -Open in Dev Containers +[![Open in Dev Containers](https://img.shields.io/static/v1?label=Dev%20Containers&message=Open&color=blue&logo=visualstudiocode)](https://vscode.dev/redirect?url=vscode://ms-vscode-remote.remote-containers/cloneInVolume?url=https://github.com/geekan/MetaGPT) -Note: If you click this link you will open the main repo and not your local cloned repo, you can use this link and replace with your username and cloned repo name: -https://vscode.dev/redirect?url=vscode://ms-vscode-remote.remote-containers/cloneInVolume?url=https://github.com/geekan/MetaGPT +Note: Clicking the link above opens the main repository. To open your local cloned repository, replace the URL with your username and cloned repository's name: `https://vscode.dev/redirect?url=vscode://ms-vscode-remote.remote-containers/cloneInVolume?url=https://github.com//` +If you have VS Code and Docker installed, use the button above to get started. This will prompt VS Code to install the Dev Containers extension if it's not already installed, clone the source code into a container volume, and set up a dev container for you. -If you already have VS Code and Docker installed, you can use the button above to get started. This will cause VS Code to automatically install the Dev Containers extension if needed, clone the source code into a container volume, and spin up a dev container for use. +Alternatively, follow these steps to open this repository in a container using the VS Code Dev Containers extension: -You can also follow these steps to open this repo in a container using the VS Code Dev Containers extension: +1. For first-time users of a development container, ensure your system meets the prerequisites (e.g., Docker installation) as outlined in the [getting started steps](https://aka.ms/vscode-remote/containers/getting-started). -1. If this is your first time using a development container, please ensure your system meets the pre-reqs (i.e. have Docker installed) in the [getting started steps](https://aka.ms/vscode-remote/containers/getting-started). - -2. Open a locally cloned copy of the code: - - - Fork and Clone this repository to your local filesystem. +2. To open a locally cloned copy of the code: + - Fork and clone this repository to your local file system. - Press F1 and select the **Dev Containers: Open Folder in Container...** command. - - Select the cloned copy of this folder, wait for the container to start, and try things out! + - Choose the cloned folder, wait for the container to initialize, and start exploring! -You can learn more in the [Dev Containers documentation](https://code.visualstudio.com/docs/devcontainers/containers). +Learn more in the [VS Code Dev Containers documentation](https://code.visualstudio.com/docs/devcontainers/containers). -## Tips and tricks +## Tips and Tricks -* If you are working with the same repository folder in a container and Windows, you'll want consistent line endings (otherwise you may see hundreds of changes in the SCM view). The `.gitattributes` file in the root of this repo will disable line ending conversion and should prevent this. See [tips and tricks](https://code.visualstudio.com/docs/devcontainers/tips-and-tricks#_resolving-git-line-ending-issues-in-containers-resulting-in-many-modified-files) for more info. -* If you'd like to review the contents of the image used in this dev container, you can check it out in the [devcontainers/images](https://github.com/devcontainers/images/tree/main/src/python) repo. +* When working with the same repository folder in both a container and on Windows, it's crucial to have consistent line endings to avoid numerous changes in the SCM view. The `.gitattributes` file in the root of this repository disables line ending conversion, helping to prevent this issue. For more information, see [resolving git line ending issues in containers](https://code.visualstudio.com/docs/devcontainers/tips-and-tricks#_resolving-git-line-ending-issues-in-containers-resulting-in-many-modified-files). + +* If you're curious about the contents of the image used in this Dev Container, you can review it in the [devcontainers/images](https://github.com/devcontainers/images/tree/main/src/python) repository. From a7b909e6fe816cb1840a480614f7a77074275ca8 Mon Sep 17 00:00:00 2001 From: geekan Date: Wed, 20 Dec 2023 00:35:15 +0800 Subject: [PATCH 123/167] add proper space --- .devcontainer/postCreateCommand.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.devcontainer/postCreateCommand.sh b/.devcontainer/postCreateCommand.sh index 46788e306..3901193cd 100644 --- a/.devcontainer/postCreateCommand.sh +++ b/.devcontainer/postCreateCommand.sh @@ -4,4 +4,4 @@ sudo npm install -g @mermaid-js/mermaid-cli # Step 2: Ensure that Python 3.9+ is installed on your system. You can check this by using: python --version -pip install -e. \ No newline at end of file +pip install -e . \ No newline at end of file From 111e820722ed814d8321e34d3604e52ba96a5436 Mon Sep 17 00:00:00 2001 From: geekan Date: Wed, 20 Dec 2023 00:39:35 +0800 Subject: [PATCH 124/167] .gitattributes: ensure lf --- .gitattributes | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/.gitattributes b/.gitattributes index 32555a806..7f1424434 100644 --- a/.gitattributes +++ b/.gitattributes @@ -1,2 +1,29 @@ +# HTML code is incorrectly calculated into statistics, so ignore them *.html linguist-detectable=false +# Auto detect text files and perform LF normalization +* text=auto eol=lf + +# Ensure shell scripts use LF (Linux style) line endings on Windows +*.sh text eol=lf + +# Treat specific binary files as binary and prevent line ending conversion +*.png binary +*.jpg binary +*.gif binary +*.ico binary + +# Preserve original line endings for specific document files +*.doc text eol=crlf +*.docx text eol=crlf +*.pdf binary + +# Ensure source code and script files use LF line endings +*.py text eol=lf +*.js text eol=lf +*.html text eol=lf +*.css text eol=lf + +# Specify custom diff driver for specific file types +*.md diff=markdown +*.json diff=json From 250c5503de0374d37c9d153a75f7f84708bc2319 Mon Sep 17 00:00:00 2001 From: geekan Date: Wed, 20 Dec 2023 00:47:28 +0800 Subject: [PATCH 125/167] refine .gitignore and .pre-commit-config.yaml --- .gitignore | 8 +------- .pre-commit-config.yaml | 2 +- 2 files changed, 2 insertions(+), 8 deletions(-) diff --git a/.gitignore b/.gitignore index 0ac318ff5..c12506b0e 100644 --- a/.gitignore +++ b/.gitignore @@ -144,24 +144,18 @@ cython_debug/ allure-report allure-results -# idea +# idea / vscode / macos .idea .DS_Store .vscode -log.txt -docs/scripts/set_env.sh key.yaml -output.json data -data/output_add.json data.ms examples/nb/ .chroma *~$* workspace/* -*.mmd tmp -output.wav metagpt/roles/idea_agent.py .aider* diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index b1892a709..338f832ac 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -2,7 +2,7 @@ default_stages: [ commit ] # Install # 1. pip install pre-commit -# 2. pre-commit install(the first time you download the repo, it will be cached for future use) +# 2. pre-commit install repos: - repo: https://github.com/pycqa/isort rev: 5.11.5 From ec6493a748bce00b768a81caea2ff59cf729c40b Mon Sep 17 00:00:00 2001 From: geekan Date: Wed, 20 Dec 2023 00:49:08 +0800 Subject: [PATCH 126/167] updating time of license --- LICENSE | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/LICENSE b/LICENSE index 5b0c000cd..67460e101 100644 --- a/LICENSE +++ b/LICENSE @@ -1,6 +1,6 @@ The MIT License -Copyright (c) Chenglin Wu +Copyright (c) 2023 Chenglin Wu Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal From d85adbd6402d85425e9891aa10a060d77b9af489 Mon Sep 17 00:00:00 2001 From: geekan Date: Wed, 20 Dec 2023 00:53:36 +0800 Subject: [PATCH 127/167] align ruff.toml with black --- ruff.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ruff.toml b/ruff.toml index 7835865e0..21de5ee14 100644 --- a/ruff.toml +++ b/ruff.toml @@ -31,7 +31,7 @@ exclude = [ ] # Same as Black. -line-length = 119 +line-length = 120 # Allow unused variables when underscore-prefixed. dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$" From 3a44b89ad882297d33c441745ec80e686ccc29a6 Mon Sep 17 00:00:00 2001 From: geekan Date: Wed, 20 Dec 2023 00:54:29 +0800 Subject: [PATCH 128/167] uncomment fire in requirements.txt due to usage in the example --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 515a4d88b..f5ef63c58 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,7 +6,7 @@ channels==4.0.0 # docx==0.2.4 #faiss==1.5.3 faiss_cpu==1.7.4 -# fire==0.4.0 +fire==0.4.0 typer # godot==0.1.1 # google_api_python_client==2.93.0 From de23c23839b29d04209fb2781cf702043e9c16c7 Mon Sep 17 00:00:00 2001 From: geekan Date: Wed, 20 Dec 2023 00:58:56 +0800 Subject: [PATCH 129/167] add proper space --- Dockerfile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Dockerfile b/Dockerfile index c6e22989b..9eeacbccb 100644 --- a/Dockerfile +++ b/Dockerfile @@ -18,7 +18,7 @@ COPY . /app/metagpt WORKDIR /app/metagpt RUN mkdir workspace &&\ pip install --no-cache-dir -r requirements.txt &&\ - pip install -e. + pip install -e . # Running with an infinite loop using the tail command CMD ["sh", "-c", "tail -f /dev/null"] From 2abc211e0d10a9e92ca79c7bc717985e206bb61b Mon Sep 17 00:00:00 2001 From: geekan Date: Wed, 20 Dec 2023 00:59:23 +0800 Subject: [PATCH 130/167] remove duplicate string --- .dockerignore | 1 - 1 file changed, 1 deletion(-) diff --git a/.dockerignore b/.dockerignore index 2968dd34d..8c09eaf73 100644 --- a/.dockerignore +++ b/.dockerignore @@ -1,7 +1,6 @@ workspace tmp build -workspace dist data geckodriver.log From 9eaf08b7dd47398be1c4a4c1fd810a529129e7d5 Mon Sep 17 00:00:00 2001 From: geekan Date: Wed, 20 Dec 2023 11:52:11 +0800 Subject: [PATCH 131/167] refine code for prepare document. remove useless logic --- metagpt/actions/prepare_documents.py | 28 ++++++++++++++++------------ 1 file changed, 16 insertions(+), 12 deletions(-) diff --git a/metagpt/actions/prepare_documents.py b/metagpt/actions/prepare_documents.py index 6bb18be7b..696dc9a89 100644 --- a/metagpt/actions/prepare_documents.py +++ b/metagpt/actions/prepare_documents.py @@ -15,7 +15,7 @@ from pydantic import Field from metagpt.actions import Action, ActionOutput from metagpt.config import CONFIG -from metagpt.const import DEFAULT_WORKSPACE_ROOT, DOCS_FILE_REPO, REQUIREMENT_FILENAME +from metagpt.const import DOCS_FILE_REPO, REQUIREMENT_FILENAME from metagpt.llm import LLM from metagpt.provider.base_gpt_api import BaseGPTAPI from metagpt.schema import Document @@ -24,22 +24,26 @@ from metagpt.utils.git_repository import GitRepository class PrepareDocuments(Action): + """PrepareDocuments Action: initialize project folder and add new requirements to docs/requirements.txt.""" + name: str = "PrepareDocuments" context: Optional[str] = None llm: BaseGPTAPI = Field(default_factory=LLM) + def _init_repo(self): + """Initialize the Git environment.""" + path = CONFIG.project_path + if not path: + name = CONFIG.project_name or FileRepository.new_filename() + path = Path(CONFIG.workspace_path) / name + + if path.exists() and not CONFIG.inc: + shutil.rmtree(path) + CONFIG.git_repo = GitRepository(local_path=path, auto_init=True) + async def run(self, with_messages, **kwargs): - if not CONFIG.git_repo: - # Create and initialize the workspace folder, initialize the Git environment. - project_name = CONFIG.project_name or FileRepository.new_filename() - workdir = CONFIG.project_path - if not workdir and CONFIG.workspace_path: - workdir = Path(CONFIG.workspace_path) / project_name - workdir = Path(workdir or DEFAULT_WORKSPACE_ROOT / project_name) - if not CONFIG.inc and workdir.exists(): - shutil.rmtree(workdir) - CONFIG.git_repo = GitRepository() - CONFIG.git_repo.open(local_path=workdir, auto_init=True) + """Create and initialize the workspace folder, initialize the Git environment.""" + self._init_repo() # Write the newly added requirements from the main parameter idea to `docs/requirement.txt`. doc = Document(root_path=DOCS_FILE_REPO, filename=REQUIREMENT_FILENAME, content=with_messages[0].content) From 608e0e9f16e1f1d2d081dd784621bdf23b684446 Mon Sep 17 00:00:00 2001 From: geekan Date: Wed, 20 Dec 2023 11:59:59 +0800 Subject: [PATCH 132/167] add .pylintrc --- docs/.pylintrc | 639 +++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 639 insertions(+) create mode 100644 docs/.pylintrc diff --git a/docs/.pylintrc b/docs/.pylintrc new file mode 100644 index 000000000..9e8488bc7 --- /dev/null +++ b/docs/.pylintrc @@ -0,0 +1,639 @@ +[MAIN] + +# Analyse import fallback blocks. This can be used to support both Python 2 and +# 3 compatible code, which means that the block might have code that exists +# only in one or another interpreter, leading to false positives when analysed. +analyse-fallback-blocks=no + +# Clear in-memory caches upon conclusion of linting. Useful if running pylint +# in a server-like mode. +clear-cache-post-run=no + +# Load and enable all available extensions. Use --list-extensions to see a list +# all available extensions. +#enable-all-extensions= + +# In error mode, messages with a category besides ERROR or FATAL are +# suppressed, and no reports are done by default. Error mode is compatible with +# disabling specific errors. +#errors-only= + +# Always return a 0 (non-error) status code, even if lint errors are found. +# This is primarily useful in continuous integration scripts. +#exit-zero= + +# A comma-separated list of package or module names from where C extensions may +# be loaded. Extensions are loading into the active Python interpreter and may +# run arbitrary code. +extension-pkg-allow-list= + +# A comma-separated list of package or module names from where C extensions may +# be loaded. Extensions are loading into the active Python interpreter and may +# run arbitrary code. (This is an alternative name to extension-pkg-allow-list +# for backward compatibility.) +extension-pkg-whitelist=pydantic + +# Return non-zero exit code if any of these messages/categories are detected, +# even if score is above --fail-under value. Syntax same as enable. Messages +# specified are enabled, while categories only check already-enabled messages. +fail-on= + +# Specify a score threshold under which the program will exit with error. +fail-under=10 + +# Interpret the stdin as a python script, whose filename needs to be passed as +# the module_or_package argument. +#from-stdin= + +# Files or directories to be skipped. They should be base names, not paths. +ignore=CVS + +# Add files or directories matching the regular expressions patterns to the +# ignore-list. The regex matches against paths and can be in Posix or Windows +# format. Because '\\' represents the directory delimiter on Windows systems, +# it can't be used as an escape character. +ignore-paths= + +# Files or directories matching the regular expression patterns are skipped. +# The regex matches against base names, not paths. The default value ignores +# Emacs file locks +#ignore-patterns=^\.# +ignore-patterns=(.)*_test\.py,test_(.)*\.py + + +# List of module names for which member attributes should not be checked +# (useful for modules/projects where namespaces are manipulated during runtime +# and thus existing member attributes cannot be deduced by static analysis). It +# supports qualified module names, as well as Unix pattern matching. +ignored-modules= + +# Python code to execute, usually for sys.path manipulation such as +# pygtk.require(). +#init-hook= + +# Use multiple processes to speed up Pylint. Specifying 0 will auto-detect the +# number of processors available to use, and will cap the count on Windows to +# avoid hangs. +jobs=1 + +# Control the amount of potential inferred values when inferring a single +# object. This can help the performance when dealing with large functions or +# complex, nested conditions. +limit-inference-results=120 + +# List of plugins (as comma separated values of python module names) to load, +# usually to register additional checkers. +load-plugins= + +# Pickle collected data for later comparisons. +persistent=yes + +# Minimum Python version to use for version dependent checks. Will default to +# the version used to run pylint. +py-version=3.9 + +# Discover python modules and packages in the file system subtree. +recursive=no + +# Add paths to the list of the source roots. Supports globbing patterns. The +# source root is an absolute path or a path relative to the current working +# directory used to determine a package namespace for modules located under the +# source root. +source-roots= + +# When enabled, pylint would attempt to guess common misconfiguration and emit +# user-friendly hints instead of false-positive error messages. +suggestion-mode=yes + +# Allow loading of arbitrary C extensions. Extensions are imported into the +# active Python interpreter and may run arbitrary code. +unsafe-load-any-extension=no + +# In verbose mode, extra non-checker-related info will be displayed. +#verbose= + + +[BASIC] + +# Naming style matching correct argument names. +argument-naming-style=snake_case + +# Regular expression matching correct argument names. Overrides argument- +# naming-style. If left empty, argument names will be checked with the set +# naming style. +#argument-rgx= + +# Naming style matching correct attribute names. +attr-naming-style=snake_case + +# Regular expression matching correct attribute names. Overrides attr-naming- +# style. If left empty, attribute names will be checked with the set naming +# style. +#attr-rgx= + +# Bad variable names which should always be refused, separated by a comma. +bad-names=foo, + bar, + baz, + toto, + tutu, + tata + +# Bad variable names regexes, separated by a comma. If names match any regex, +# they will always be refused +bad-names-rgxs= + +# Naming style matching correct class attribute names. +class-attribute-naming-style=any + +# Regular expression matching correct class attribute names. Overrides class- +# attribute-naming-style. If left empty, class attribute names will be checked +# with the set naming style. +#class-attribute-rgx= + +# Naming style matching correct class constant names. +class-const-naming-style=UPPER_CASE + +# Regular expression matching correct class constant names. Overrides class- +# const-naming-style. If left empty, class constant names will be checked with +# the set naming style. +#class-const-rgx= + +# Naming style matching correct class names. +class-naming-style=PascalCase + +# Regular expression matching correct class names. Overrides class-naming- +# style. If left empty, class names will be checked with the set naming style. +#class-rgx= + +# Naming style matching correct constant names. +const-naming-style=UPPER_CASE + +# Regular expression matching correct constant names. Overrides const-naming- +# style. If left empty, constant names will be checked with the set naming +# style. +#const-rgx= + +# Minimum line length for functions/classes that require docstrings, shorter +# ones are exempt. +docstring-min-length=-1 + +# Naming style matching correct function names. +function-naming-style=snake_case + +# Regular expression matching correct function names. Overrides function- +# naming-style. If left empty, function names will be checked with the set +# naming style. +#function-rgx= + +# Good variable names which should always be accepted, separated by a comma. +good-names=i, + j, + k, + v, + e, + d, + m, + df, + ex, + Run, + _ + +# Good variable names regexes, separated by a comma. If names match any regex, +# they will always be accepted +good-names-rgxs= + +# Include a hint for the correct naming format with invalid-name. +include-naming-hint=no + +# Naming style matching correct inline iteration names. +inlinevar-naming-style=any + +# Regular expression matching correct inline iteration names. Overrides +# inlinevar-naming-style. If left empty, inline iteration names will be checked +# with the set naming style. +#inlinevar-rgx= + +# Naming style matching correct method names. +method-naming-style=snake_case + +# Regular expression matching correct method names. Overrides method-naming- +# style. If left empty, method names will be checked with the set naming style. +#method-rgx= + +# Naming style matching correct module names. +module-naming-style=snake_case + +# Regular expression matching correct module names. Overrides module-naming- +# style. If left empty, module names will be checked with the set naming style. +#module-rgx= + +# Colon-delimited sets of names that determine each other's naming style when +# the name regexes allow several styles. +name-group= + +# Regular expression which should only match function or class names that do +# not require a docstring. +no-docstring-rgx=^_ + +# List of decorators that produce properties, such as abc.abstractproperty. Add +# to this list to register other decorators that produce valid properties. +# These decorators are taken in consideration only for invalid-name. +property-classes=abc.abstractproperty + +# Regular expression matching correct type alias names. If left empty, type +# alias names will be checked with the set naming style. +#typealias-rgx= + +# Regular expression matching correct type variable names. If left empty, type +# variable names will be checked with the set naming style. +#typevar-rgx= + +# Naming style matching correct variable names. +variable-naming-style=snake_case + +# Regular expression matching correct variable names. Overrides variable- +# naming-style. If left empty, variable names will be checked with the set +# naming style. +#variable-rgx= + + +[CLASSES] + +# Warn about protected attribute access inside special methods +check-protected-access-in-special-methods=no + +# List of method names used to declare (i.e. assign) instance attributes. +defining-attr-methods=__init__, + __new__, + setUp, + __post_init__ + +# List of member names, which should be excluded from the protected access +# warning. +exclude-protected=_asdict,_fields,_replace,_source,_make,os._exit + +# List of valid names for the first argument in a class method. +valid-classmethod-first-arg=cls + +# List of valid names for the first argument in a metaclass class method. +valid-metaclass-classmethod-first-arg=mcs + + +[DESIGN] + +# List of regular expressions of class ancestor names to ignore when counting +# public methods (see R0903) +exclude-too-few-public-methods= + +# List of qualified class names to ignore when counting class parents (see +# R0901) +ignored-parents= + +# Maximum number of arguments for function / method. +max-args=5 + +# Maximum number of attributes for a class (see R0902). +max-attributes=7 + +# Maximum number of boolean expressions in an if statement (see R0916). +max-bool-expr=5 + +# Maximum number of branch for function / method body. +max-branches=12 + +# Maximum number of locals for function / method body. +max-locals=15 + +# Maximum number of parents for a class (see R0901). +max-parents=7 + +# Maximum number of public methods for a class (see R0904). +max-public-methods=20 + +# Maximum number of return / yield for function / method body. +max-returns=6 + +# Maximum number of statements in function / method body. +max-statements=50 + +# Minimum number of public methods for a class (see R0903). +min-public-methods=2 + + +[EXCEPTIONS] + +# Exceptions that will emit a warning when caught. +overgeneral-exceptions=builtins.BaseException,builtins.Exception + + +[FORMAT] + +# Expected format of line ending, e.g. empty (any line ending), LF or CRLF. +expected-line-ending-format= + +# Regexp for a line that is allowed to be longer than the limit. +ignore-long-lines=^\s*(# )??$ + +# Number of spaces of indent required inside a hanging or continued line. +indent-after-paren=4 + +# String used as indentation unit. This is usually " " (4 spaces) or "\t" (1 +# tab). +indent-string=' ' + +# Maximum number of characters on a single line. +max-line-length=120 + +# Maximum number of lines in a module. +max-module-lines=1000 + +# Allow the body of a class to be on the same line as the declaration if body +# contains single statement. +single-line-class-stmt=no + +# Allow the body of an if to be on the same line as the test if there is no +# else. +single-line-if-stmt=no + + +[IMPORTS] + +# List of modules that can be imported at any level, not just the top level +# one. +allow-any-import-level= + +# Allow explicit reexports by alias from a package __init__. +allow-reexport-from-package=no + +# Allow wildcard imports from modules that define __all__. +allow-wildcard-with-all=no + +# Deprecated modules which should not be used, separated by a comma. +deprecated-modules= + +# Output a graph (.gv or any supported image format) of external dependencies +# to the given file (report RP0402 must not be disabled). +ext-import-graph= + +# Output a graph (.gv or any supported image format) of all (i.e. internal and +# external) dependencies to the given file (report RP0402 must not be +# disabled). +import-graph= + +# Output a graph (.gv or any supported image format) of internal dependencies +# to the given file (report RP0402 must not be disabled). +int-import-graph= + +# Force import order to recognize a module as part of the standard +# compatibility libraries. +known-standard-library= + +# Force import order to recognize a module as part of a third party library. +known-third-party=enchant + +# Couples of modules and preferred modules, separated by a comma. +preferred-modules= + + +[LOGGING] + +# The type of string formatting that logging methods do. `old` means using % +# formatting, `new` is for `{}` formatting. +logging-format-style=old + +# Logging modules to check that the string format arguments are in logging +# function parameter format. +logging-modules=logging + + +[MESSAGES CONTROL] + +# Only show warnings with the listed confidence levels. Leave empty to show +# all. Valid levels: HIGH, CONTROL_FLOW, INFERENCE, INFERENCE_FAILURE, +# UNDEFINED. +confidence=HIGH, + CONTROL_FLOW, + INFERENCE, + INFERENCE_FAILURE, + UNDEFINED + +# Disable the message, report, category or checker with the given id(s). You +# can either give multiple identifiers separated by comma (,) or put this +# option multiple times (only on the command line, not in the configuration +# file where it should appear only once). You can also use "--disable=all" to +# disable everything first and then re-enable specific checks. For example, if +# you want to run only the similarities checker, you can use "--disable=all +# --enable=similarities". If you want to run only the classes checker, but have +# no Warning level messages displayed, use "--disable=all --enable=classes +# --disable=W". +disable=raw-checker-failed, + bad-inline-option, + locally-disabled, + file-ignored, + suppressed-message, + useless-suppression, + deprecated-pragma, + use-symbolic-message-instead, + expression-not-assigned, + pointless-statement + +# Enable the message, report, category or checker with the given id(s). You can +# either give multiple identifier separated by comma (,) or put this option +# multiple time (only on the command line, not in the configuration file where +# it should appear only once). See also the "--disable" option for examples. +enable=c-extension-no-member + + +[METHOD_ARGS] + +# List of qualified names (i.e., library.method) which require a timeout +# parameter e.g. 'requests.api.get,requests.api.post' +timeout-methods=requests.api.delete,requests.api.get,requests.api.head,requests.api.options,requests.api.patch,requests.api.post,requests.api.put,requests.api.request + + +[MISCELLANEOUS] + +# List of note tags to take in consideration, separated by a comma. +notes=FIXME, + XXX, + TODO + +# Regular expression of note tags to take in consideration. +notes-rgx= + + +[REFACTORING] + +# Maximum number of nested blocks for function / method body +max-nested-blocks=5 + +# Complete name of functions that never returns. When checking for +# inconsistent-return-statements if a never returning function is called then +# it will be considered as an explicit return statement and no message will be +# printed. +never-returning-functions=sys.exit,argparse.parse_error + + +[REPORTS] + +# Python expression which should return a score less than or equal to 10. You +# have access to the variables 'fatal', 'error', 'warning', 'refactor', +# 'convention', and 'info' which contain the number of messages in each +# category, as well as 'statement' which is the total number of statements +# analyzed. This score is used by the global evaluation report (RP0004). +evaluation=max(0, 0 if fatal else 10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10)) + +# Template used to display messages. This is a python new-style format string +# used to format the message information. See doc for all details. +msg-template= + +# Set the output format. Available formats are text, parseable, colorized, json +# and msvs (visual studio). You can also give a reporter class, e.g. +# mypackage.mymodule.MyReporterClass. +#output-format= + +# Tells whether to display a full report or only the messages. +reports=no + +# Activate the evaluation score. +score=yes + + +[SIMILARITIES] + +# Comments are removed from the similarity computation +ignore-comments=yes + +# Docstrings are removed from the similarity computation +ignore-docstrings=yes + +# Imports are removed from the similarity computation +ignore-imports=yes + +# Signatures are removed from the similarity computation +ignore-signatures=yes + +# Minimum lines number of a similarity. +min-similarity-lines=4 + + +[SPELLING] + +# Limits count of emitted suggestions for spelling mistakes. +max-spelling-suggestions=4 + +# Spelling dictionary name. No available dictionaries : You need to install +# both the python package and the system dependency for enchant to work.. +spelling-dict= + +# List of comma separated words that should be considered directives if they +# appear at the beginning of a comment and should not be checked. +spelling-ignore-comment-directives=fmt: on,fmt: off,noqa:,noqa,nosec,isort:skip,mypy: + +# List of comma separated words that should not be checked. +spelling-ignore-words= + +# A path to a file that contains the private dictionary; one word per line. +spelling-private-dict-file= + +# Tells whether to store unknown words to the private dictionary (see the +# --spelling-private-dict-file option) instead of raising a message. +spelling-store-unknown-words=no + + +[STRING] + +# This flag controls whether inconsistent-quotes generates a warning when the +# character used as a quote delimiter is used inconsistently within a module. +check-quote-consistency=no + +# This flag controls whether the implicit-str-concat should generate a warning +# on implicit string concatenation in sequences defined over several lines. +check-str-concat-over-line-jumps=no + + +[TYPECHECK] + +# List of decorators that produce context managers, such as +# contextlib.contextmanager. Add to this list to register other decorators that +# produce valid context managers. +contextmanager-decorators=contextlib.contextmanager + +# List of members which are set dynamically and missed by pylint inference +# system, and so shouldn't trigger E1101 when accessed. Python regular +# expressions are accepted. +generated-members= + +# Tells whether to warn about missing members when the owner of the attribute +# is inferred to be None. +ignore-none=yes + +# This flag controls whether pylint should warn about no-member and similar +# checks whenever an opaque object is returned when inferring. The inference +# can return multiple potential results while evaluating a Python object, but +# some branches might not be evaluated, which results in partial inference. In +# that case, it might be useful to still emit no-member and other checks for +# the rest of the inferred objects. +ignore-on-opaque-inference=yes + +# List of symbolic message names to ignore for Mixin members. +ignored-checks-for-mixins=no-member, + not-async-context-manager, + not-context-manager, + attribute-defined-outside-init + +# List of class names for which member attributes should not be checked (useful +# for classes with dynamically set attributes). This supports the use of +# qualified names. +ignored-classes=optparse.Values,thread._local,_thread._local,argparse.Namespace + +# Show a hint with possible names when a member name was not found. The aspect +# of finding the hint is based on edit distance. +missing-member-hint=yes + +# The minimum edit distance a name should have in order to be considered a +# similar match for a missing member name. +missing-member-hint-distance=1 + +# The total number of similar names that should be taken in consideration when +# showing a hint for a missing member. +missing-member-max-choices=1 + +# Regex pattern to define which classes are considered mixins. +mixin-class-rgx=.*[Mm]ixin + +# List of decorators that change the signature of a decorated function. +signature-mutators= + + +[VARIABLES] + +# List of additional names supposed to be defined in builtins. Remember that +# you should avoid defining new builtins when possible. +additional-builtins= + +# Tells whether unused global variables should be treated as a violation. +allow-global-unused-variables=yes + +# List of names allowed to shadow builtins +allowed-redefined-builtins= + +# List of strings which can identify a callback function by name. A callback +# name must start or end with one of those strings. +callbacks=cb_, + _cb + +# A regular expression matching the name of dummy variables (i.e. expected to +# not be used). +dummy-variables-rgx=_+$|(_[a-zA-Z0-9_]*[a-zA-Z0-9]+?$)|dummy|^ignored_|^unused_ + +# Argument names that match this expression will be ignored. +ignored-argument-names=_.*|^ignored_|^unused_ + +# Tells whether we should check for unused import in __init__ files. +init-import=no + +# List of qualified module names which can have objects that can redefine +# builtins. +redefining-builtins-modules=six.moves,past.builtins,future.builtins,builtins,io From d5913970d545d08218285c25fddc9c5e0d625ec7 Mon Sep 17 00:00:00 2001 From: geekan Date: Wed, 20 Dec 2023 12:48:57 +0800 Subject: [PATCH 133/167] refine sop --- metagpt/actions/write_prd_an.py | 21 ++++++++++++++------- metagpt/roles/product_manager.py | 4 ++-- 2 files changed, 16 insertions(+), 9 deletions(-) diff --git a/metagpt/actions/write_prd_an.py b/metagpt/actions/write_prd_an.py index edd94a463..8698c739f 100644 --- a/metagpt/actions/write_prd_an.py +++ b/metagpt/actions/write_prd_an.py @@ -26,8 +26,8 @@ PROGRAMMING_LANGUAGE = ActionNode( ORIGINAL_REQUIREMENTS = ActionNode( key="Original Requirements", expected_type=str, - instruction="Place the polished, complete original requirements here.", - example="The game should have a leaderboard and multiple difficulty levels.", + instruction="Place the original user's requirements here.", + example="Create a 2048 game", ) PROJECT_NAME = ActionNode( @@ -41,7 +41,7 @@ PRODUCT_GOALS = ActionNode( key="Product Goals", expected_type=list[str], instruction="Provide up to three clear, orthogonal product goals.", - example=["Create an engaging user experience", "Ensure high performance", "Provide customizable features"], + example=["Create an engaging user experience", "Improve accessibility, be responsive", "More beautiful UI"], ) USER_STORIES = ActionNode( @@ -49,8 +49,11 @@ USER_STORIES = ActionNode( expected_type=list[str], instruction="Provide up to 3 to 5 scenario-based user stories.", example=[ - "As a user, I want to be able to choose difficulty levels", + "As a player, I want to be able to choose difficulty levels", "As a player, I want to see my score after each game", + "As a player, I want to get restart button when I lose", + "As a player, I want to see beautiful UI that make me feel good", + "As a player, I want to play game via mobile phone", ], ) @@ -58,7 +61,11 @@ COMPETITIVE_ANALYSIS = ActionNode( key="Competitive Analysis", expected_type=list[str], instruction="Provide 5 to 7 competitive products.", - example=["Python Snake Game: Simple interface, lacks advanced features"], + example=[ + "2048 Game A: Simple interface, lacks responsive features", + "play2048.co: Beautiful and responsive UI with my best score shown", + "2048game.com: Responsive UI with my best score shown, but many ads", + ], ) COMPETITIVE_QUADRANT_CHART = ActionNode( @@ -86,7 +93,7 @@ REQUIREMENT_ANALYSIS = ActionNode( key="Requirement Analysis", expected_type=str, instruction="Provide a detailed analysis of the requirements.", - example="The product should be user-friendly.", + example="", ) REQUIREMENT_POOL = ActionNode( @@ -107,7 +114,7 @@ ANYTHING_UNCLEAR = ActionNode( key="Anything UNCLEAR", expected_type=str, instruction="Mention any aspects of the project that are unclear and try to clarify them.", - example="...", + example="", ) ISSUE_TYPE = ActionNode( diff --git a/metagpt/roles/product_manager.py b/metagpt/roles/product_manager.py index 6dba21fe1..72e5a9be5 100644 --- a/metagpt/roles/product_manager.py +++ b/metagpt/roles/product_manager.py @@ -27,8 +27,8 @@ class ProductManager(Role): """ name: str = "Alice" profile: str = Field(default="Product Manager") - goal: str = "efficiently create a successful product" - constraints: str = "use same language as user requirement" + goal: str = "efficiently create a successful product that meets market demands and user expectations" + constraints: str = "utilize the same language as the user requirements for seamless communication" def __init__(self, **kwargs) -> None: super().__init__(**kwargs) From 6959d40e6d265c0de99ebf057bbc4434febf2a22 Mon Sep 17 00:00:00 2001 From: geekan Date: Wed, 20 Dec 2023 15:04:25 +0800 Subject: [PATCH 134/167] add write_review action and its test --- metagpt/actions/action_node.py | 4 +- metagpt/actions/write_review.py | 40 ++++++++++++++++ metagpt/utils/common.py | 25 +++++++++- tests/metagpt/actions/test_write_review.py | 53 ++++++++++++++++++++++ 4 files changed, 119 insertions(+), 3 deletions(-) create mode 100644 metagpt/actions/write_review.py create mode 100644 tests/metagpt/actions/test_write_review.py diff --git a/metagpt/actions/action_node.py b/metagpt/actions/action_node.py index 092dd5755..58688aefa 100644 --- a/metagpt/actions/action_node.py +++ b/metagpt/actions/action_node.py @@ -41,10 +41,10 @@ Fill in the above nodes based on the format example. """ -def dict_to_markdown(d, prefix="-", postfix="\n"): +def dict_to_markdown(d, prefix="##", kv_sep="\n", postfix="\n"): markdown_str = "" for key, value in d.items(): - markdown_str += f"{prefix} {key}: {value}{postfix}" + markdown_str += f"{prefix}{key}{kv_sep}{value}{postfix}" return markdown_str diff --git a/metagpt/actions/write_review.py b/metagpt/actions/write_review.py new file mode 100644 index 000000000..94dd9951b --- /dev/null +++ b/metagpt/actions/write_review.py @@ -0,0 +1,40 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Author : alexanderwu +@File : write_review.py +""" +from typing import List + +from metagpt.actions import Action +from metagpt.actions.action_node import ActionNode + +# from metagpt.llm import LLM + +REVIEW = ActionNode( + key="Review", + expected_type=List[str], + instruction="Act as an experienced Reviewer and review the given output. Ask a series of critical questions, " + "concisely and clearly, to help the writer improve their work.", + example=[ + "This is a good PRD, but I think it can be improved by adding more details.", + ], +) + +LGTM = ActionNode( + key="LGTM", + expected_type=str, + instruction="If the output is good enough, give a LGTM (Looks Good To Me) to the writer, " + "else LBTM (Looks Bad To Me).", + example="LGTM", +) + +WRITE_REVIEW_NODE = ActionNode.from_children("WRITE_REVIEW_NODE", [REVIEW, LGTM]) + + +class WriteReview(Action): + """This class allows LLM to further mine noteworthy details based on specific "##TOPIC"(discussion topic) and + "##RECORD" (discussion records), thereby deepening the discussion.""" + + async def run(self, context): + return await WRITE_REVIEW_NODE.fill(context=context, llm=self.llm, schema="markdown") diff --git a/metagpt/utils/common.py b/metagpt/utils/common.py index fa18694e3..a290c7db7 100644 --- a/metagpt/utils/common.py +++ b/metagpt/utils/common.py @@ -18,7 +18,7 @@ import os import platform import re import typing -from typing import List, Tuple, Union +from typing import List, Tuple, Union, get_args, get_origin import aiofiles import loguru @@ -129,8 +129,31 @@ class OutputParser: parsed_data[block] = content return parsed_data + @staticmethod + def extract_content(text, tag="CONTENT"): + # Use regular expression to extract content between [CONTENT] and [/CONTENT] + extracted_content = re.search(rf"\[{tag}\](.*?)\[/{tag}\]", text, re.DOTALL) + + if extracted_content: + return extracted_content.group(1).strip() + else: + return "No content found between [CONTENT] and [/CONTENT] tags." + + @staticmethod + def is_supported_list_type(i): + origin = get_origin(i) + if origin is not List: + return False + + args = get_args(i) + if args == (str,) or args == (Tuple[str, str],) or args == (List[str],): + return True + + return False + @classmethod def parse_data_with_mapping(cls, data, mapping): + data = cls.extract_content(text=data) block_dict = cls.parse_blocks(data) parsed_data = {} for block, content in block_dict.items(): diff --git a/tests/metagpt/actions/test_write_review.py b/tests/metagpt/actions/test_write_review.py new file mode 100644 index 000000000..2d188b720 --- /dev/null +++ b/tests/metagpt/actions/test_write_review.py @@ -0,0 +1,53 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Time : 2023/12/20 15:01 +@Author : alexanderwu +@File : test_write_review.py +""" +import pytest + +from metagpt.actions.write_review import WriteReview + +CONTEXT = """ +{ + "Language": "zh_cn", + "Programming Language": "Python", + "Original Requirements": "写一个简单的2048", + "Project Name": "game_2048", + "Product Goals": [ + "创建一个引人入胜的用户体验", + "确保高性能", + "提供可定制的功能" + ], + "User Stories": [ + "作为用户,我希望能够选择不同的难度级别", + "作为玩家,我希望在每局游戏结束后能看到我的得分" + ], + "Competitive Analysis": [ + "Python Snake Game: 界面简单,缺乏高级功能" + ], + "Competitive Quadrant Chart": "quadrantChart\n title \"Reach and engagement of campaigns\"\n x-axis \"Low Reach\" --> \"High Reach\"\n y-axis \"Low Engagement\" --> \"High Engagement\"\n quadrant-1 \"我们应该扩展\"\n quadrant-2 \"需要推广\"\n quadrant-3 \"重新评估\"\n quadrant-4 \"可能需要改进\"\n \"Campaign A\": [0.3, 0.6]\n \"Campaign B\": [0.45, 0.23]\n \"Campaign C\": [0.57, 0.69]\n \"Campaign D\": [0.78, 0.34]\n \"Campaign E\": [0.40, 0.34]\n \"Campaign F\": [0.35, 0.78]\n \"Our Target Product\": [0.5, 0.6]", + "Requirement Analysis": "产品应该用户友好。", + "Requirement Pool": [ + [ + "P0", + "主要代码..." + ], + [ + "P0", + "游戏算法..." + ] + ], + "UI Design draft": "基本功能描述,简单的风格和布局。", + "Anything UNCLEAR": "..." +} +""" + + +@pytest.mark.asyncio +async def test_write_review(): + write_review = WriteReview() + review = await write_review.run(CONTEXT) + assert review.instruct_content + assert review.get("LGTM") in ["LGTM", "LBTM"] From 8bec6e98cc8ad5ef1e4d0bb5f0407d08adb682ac Mon Sep 17 00:00:00 2001 From: geekan Date: Wed, 20 Dec 2023 15:20:39 +0800 Subject: [PATCH 135/167] use typing.List instead of list --- metagpt/actions/action_node.py | 3 +++ metagpt/actions/design_api_an.py | 4 +++- metagpt/actions/project_management_an.py | 10 ++++++---- metagpt/actions/write_prd_an.py | 9 +++++---- metagpt/actions/write_review.py | 4 +--- 5 files changed, 18 insertions(+), 12 deletions(-) diff --git a/metagpt/actions/action_node.py b/metagpt/actions/action_node.py index 58688aefa..4376e09ed 100644 --- a/metagpt/actions/action_node.py +++ b/metagpt/actions/action_node.py @@ -4,6 +4,9 @@ @Time : 2023/12/11 18:45 @Author : alexanderwu @File : action_node.py + +NOTE: You should use typing.List instead of list to do type annotation. Because in the markdown extraction process, + we can use typing to extract the type of the node, but we cannot use built-in list to extract. """ import json from typing import Any, Dict, Generic, List, Optional, Tuple, Type, TypeVar diff --git a/metagpt/actions/design_api_an.py b/metagpt/actions/design_api_an.py index 0a303cdd5..7d6802381 100644 --- a/metagpt/actions/design_api_an.py +++ b/metagpt/actions/design_api_an.py @@ -5,6 +5,8 @@ @Author : alexanderwu @File : design_api_an.py """ +from typing import List + from metagpt.actions.action_node import ActionNode from metagpt.logs import logger from metagpt.utils.mermaid import MMC1, MMC2 @@ -22,7 +24,7 @@ PROJECT_NAME = ActionNode( FILE_LIST = ActionNode( key="File list", - expected_type=list[str], + expected_type=List[str], instruction="Only need relative paths. ALWAYS write a main.py or app.py here", example=["main.py", "game.py"], ) diff --git a/metagpt/actions/project_management_an.py b/metagpt/actions/project_management_an.py index 6208c1051..215a67202 100644 --- a/metagpt/actions/project_management_an.py +++ b/metagpt/actions/project_management_an.py @@ -5,26 +5,28 @@ @Author : alexanderwu @File : project_management_an.py """ +from typing import List + from metagpt.actions.action_node import ActionNode from metagpt.logs import logger REQUIRED_PYTHON_PACKAGES = ActionNode( key="Required Python packages", - expected_type=list[str], + expected_type=List[str], instruction="Provide required Python packages in requirements.txt format.", example=["flask==1.1.2", "bcrypt==3.2.0"], ) REQUIRED_OTHER_LANGUAGE_PACKAGES = ActionNode( key="Required Other language third-party packages", - expected_type=list[str], + expected_type=List[str], instruction="List down the required packages for languages other than Python.", example=["No third-party dependencies required"], ) LOGIC_ANALYSIS = ActionNode( key="Logic Analysis", - expected_type=list[list[str]], + expected_type=List[List[str]], instruction="Provide a list of files with the classes/methods/functions to be implemented, " "including dependency analysis and imports.", example=[ @@ -35,7 +37,7 @@ LOGIC_ANALYSIS = ActionNode( TASK_LIST = ActionNode( key="Task list", - expected_type=list[str], + expected_type=List[str], instruction="Break down the tasks into a list of filenames, prioritized by dependency order.", example=["game.py", "main.py"], ) diff --git a/metagpt/actions/write_prd_an.py b/metagpt/actions/write_prd_an.py index 8698c739f..d58d72f64 100644 --- a/metagpt/actions/write_prd_an.py +++ b/metagpt/actions/write_prd_an.py @@ -5,6 +5,7 @@ @Author : alexanderwu @File : write_prd_an.py """ +from typing import List from metagpt.actions.action_node import ActionNode from metagpt.logs import logger @@ -39,14 +40,14 @@ PROJECT_NAME = ActionNode( PRODUCT_GOALS = ActionNode( key="Product Goals", - expected_type=list[str], + expected_type=List[str], instruction="Provide up to three clear, orthogonal product goals.", example=["Create an engaging user experience", "Improve accessibility, be responsive", "More beautiful UI"], ) USER_STORIES = ActionNode( key="User Stories", - expected_type=list[str], + expected_type=List[str], instruction="Provide up to 3 to 5 scenario-based user stories.", example=[ "As a player, I want to be able to choose difficulty levels", @@ -59,7 +60,7 @@ USER_STORIES = ActionNode( COMPETITIVE_ANALYSIS = ActionNode( key="Competitive Analysis", - expected_type=list[str], + expected_type=List[str], instruction="Provide 5 to 7 competitive products.", example=[ "2048 Game A: Simple interface, lacks responsive features", @@ -98,7 +99,7 @@ REQUIREMENT_ANALYSIS = ActionNode( REQUIREMENT_POOL = ActionNode( key="Requirement Pool", - expected_type=list[list[str]], + expected_type=List[List[str]], instruction="List down the top-5 requirements with their priority (P0, P1, P2).", example=[["P0", "The main code ..."], ["P0", "The game algorithm ..."]], ) diff --git a/metagpt/actions/write_review.py b/metagpt/actions/write_review.py index 94dd9951b..13690a1a5 100644 --- a/metagpt/actions/write_review.py +++ b/metagpt/actions/write_review.py @@ -9,8 +9,6 @@ from typing import List from metagpt.actions import Action from metagpt.actions.action_node import ActionNode -# from metagpt.llm import LLM - REVIEW = ActionNode( key="Review", expected_type=List[str], @@ -24,7 +22,7 @@ REVIEW = ActionNode( LGTM = ActionNode( key="LGTM", expected_type=str, - instruction="If the output is good enough, give a LGTM (Looks Good To Me) to the writer, " + instruction="LGTM/LBTM. If the output is good enough, give a LGTM (Looks Good To Me) to the writer, " "else LBTM (Looks Bad To Me).", example="LGTM", ) From 3f0f008690d1c19ab379cf2925603f55d6599c10 Mon Sep 17 00:00:00 2001 From: better629 Date: Wed, 20 Dec 2023 15:59:15 +0800 Subject: [PATCH 136/167] update ActionOutput.create_model_class to ActionNode.create_model_class --- tests/metagpt/serialize_deserialize/test_action.py | 2 +- tests/metagpt/serialize_deserialize/test_environment.py | 4 ++-- tests/metagpt/serialize_deserialize/test_memory.py | 6 +++--- tests/metagpt/serialize_deserialize/test_schema.py | 6 +++--- .../metagpt/serialize_deserialize/test_serdeser_base.py | 6 +++--- .../serialize_deserialize/test_write_code_review.py | 9 ++------- tests/metagpt/test_schema.py | 4 ++-- 7 files changed, 16 insertions(+), 21 deletions(-) diff --git a/tests/metagpt/serialize_deserialize/test_action.py b/tests/metagpt/serialize_deserialize/test_action.py index 63d8e7b7c..14d558c13 100644 --- a/tests/metagpt/serialize_deserialize/test_action.py +++ b/tests/metagpt/serialize_deserialize/test_action.py @@ -4,7 +4,7 @@ # @Desc : import pytest -from metagpt.actions import Action, WriteTest +from metagpt.actions import Action from metagpt.llm import LLM diff --git a/tests/metagpt/serialize_deserialize/test_environment.py b/tests/metagpt/serialize_deserialize/test_environment.py index 3a374460c..b741b9c4b 100644 --- a/tests/metagpt/serialize_deserialize/test_environment.py +++ b/tests/metagpt/serialize_deserialize/test_environment.py @@ -4,7 +4,7 @@ import shutil -from metagpt.actions.action_output import ActionOutput +from metagpt.actions.action_node import ActionNode from metagpt.actions.add_requirement import UserRequirement from metagpt.actions.project_management import WriteTasks from metagpt.environment import Environment @@ -32,7 +32,7 @@ def test_env_deserialize(): def test_environment_serdeser(): out_mapping = {"field1": (list[str], ...)} out_data = {"field1": ["field1 value1", "field1 value2"]} - ic_obj = ActionOutput.create_model_class("prd", out_mapping) + ic_obj = ActionNode.create_model_class("prd", out_mapping) message = Message( content="prd", diff --git a/tests/metagpt/serialize_deserialize/test_memory.py b/tests/metagpt/serialize_deserialize/test_memory.py index 47410c615..0d756518b 100644 --- a/tests/metagpt/serialize_deserialize/test_memory.py +++ b/tests/metagpt/serialize_deserialize/test_memory.py @@ -4,7 +4,7 @@ from pydantic import BaseModel -from metagpt.actions.action_output import ActionOutput +from metagpt.actions.action_node import ActionNode from metagpt.actions.add_requirement import UserRequirement from metagpt.actions.design_api import WriteDesign from metagpt.memory.memory import Memory @@ -20,7 +20,7 @@ def test_memory_serdeser(): out_mapping = {"field2": (list[str], ...)} out_data = {"field2": ["field2 value1", "field2 value2"]} - ic_obj = ActionOutput.create_model_class("system_design", out_mapping) + ic_obj = ActionNode.create_model_class("system_design", out_mapping) msg2 = Message(role="Architect", instruct_content=ic_obj(**out_data), content="system design content", @@ -46,7 +46,7 @@ def test_memory_serdeser_save(): out_mapping = {"field1": (list[str], ...)} out_data = {"field1": ["field1 value1", "field1 value2"]} - ic_obj = ActionOutput.create_model_class("system_design", out_mapping) + ic_obj = ActionNode.create_model_class("system_design", out_mapping) msg2 = Message(role="Architect", instruct_content=ic_obj(**out_data), content="system design content", diff --git a/tests/metagpt/serialize_deserialize/test_schema.py b/tests/metagpt/serialize_deserialize/test_schema.py index 02afa762d..72b7153a7 100644 --- a/tests/metagpt/serialize_deserialize/test_schema.py +++ b/tests/metagpt/serialize_deserialize/test_schema.py @@ -2,7 +2,7 @@ # -*- coding: utf-8 -*- # @Desc : unittest of schema ser&deser -from metagpt.actions.action_output import ActionOutput +from metagpt.actions.action_node import ActionNode from metagpt.actions.write_code import WriteCode from metagpt.schema import Message from metagpt.utils.common import any_to_str @@ -12,7 +12,7 @@ from tests.metagpt.serialize_deserialize.test_serdeser_base import MockMessage def test_message_serdeser(): out_mapping = {"field3": (str, ...), "field4": (list[str], ...)} out_data = {"field3": "field3 value3", "field4": ["field4 value1", "field4 value2"]} - ic_obj = ActionOutput.create_model_class("code", out_mapping) + ic_obj = ActionNode.create_model_class("code", out_mapping) message = Message( content="code", @@ -34,7 +34,7 @@ def test_message_without_postprocess(): """ to explain `instruct_content` should be postprocessed """ out_mapping = {"field1": (list[str], ...)} out_data = {"field1": ["field1 value1", "field1 value2"]} - ic_obj = ActionOutput.create_model_class("code", out_mapping) + ic_obj = ActionNode.create_model_class("code", out_mapping) message = MockMessage( content="code", instruct_content=ic_obj(**out_data) diff --git a/tests/metagpt/serialize_deserialize/test_serdeser_base.py b/tests/metagpt/serialize_deserialize/test_serdeser_base.py index 20f708e30..eac083cf9 100644 --- a/tests/metagpt/serialize_deserialize/test_serdeser_base.py +++ b/tests/metagpt/serialize_deserialize/test_serdeser_base.py @@ -7,8 +7,8 @@ from pathlib import Path from pydantic import BaseModel, Field -from metagpt.actions.action import Action -from metagpt.actions.action_output import ActionOutput +from metagpt.actions import Action, ActionOutput +from metagpt.actions.action_node import ActionNode from metagpt.actions.add_requirement import UserRequirement from metagpt.roles.role import Role, RoleReactMode @@ -29,7 +29,7 @@ class ActionPass(Action): output_mapping = { "result": (str, ...) } - pass_class = ActionOutput.create_model_class("pass", output_mapping) + pass_class = ActionNode.create_model_class("pass", output_mapping) pass_output = ActionOutput("ActionPass run passed", pass_class(**{"result": "pass result"})) return pass_output diff --git a/tests/metagpt/serialize_deserialize/test_write_code_review.py b/tests/metagpt/serialize_deserialize/test_write_code_review.py index 6ca4c6027..a15b744db 100644 --- a/tests/metagpt/serialize_deserialize/test_write_code_review.py +++ b/tests/metagpt/serialize_deserialize/test_write_code_review.py @@ -9,13 +9,6 @@ from metagpt.llm import LLM from metagpt.schema import CodingContext, Document -def test_write_task_serialize(): - action = WriteCodeReview() - ser_action_dict = action.dict() - assert ser_action_dict["name"] == "WriteCodeReview" - # assert "llm" in ser_action_dict # not export - - @pytest.mark.asyncio async def test_write_code_review_deserialize(): code_content = """ @@ -30,6 +23,8 @@ def div(a: int, b: int = 0): action = WriteCodeReview(context=context) serialized_data = action.dict() + assert serialized_data["name"] == "WriteCodeReview" + new_action = WriteCodeReview(**serialized_data) assert new_action.name == "WriteCodeReview" diff --git a/tests/metagpt/test_schema.py b/tests/metagpt/test_schema.py index c8602d953..054a92de1 100644 --- a/tests/metagpt/test_schema.py +++ b/tests/metagpt/test_schema.py @@ -13,7 +13,7 @@ import pytest from metagpt.actions import Action from metagpt.schema import AIMessage, Message, SystemMessage, UserMessage -from metagpt.actions.action_output import ActionOutput +from metagpt.actions.action_node import ActionNode from metagpt.actions.write_code import WriteCode from metagpt.utils.serialize import serialize_general_message, deserialize_general_message from metagpt.utils.common import any_to_str @@ -76,7 +76,7 @@ def test_routes(): def test_message_serdeser(): out_mapping = {"field3": (str, ...), "field4": (list[str], ...)} out_data = {"field3": "field3 value3", "field4": ["field4 value1", "field4 value2"]} - ic_obj = ActionOutput.create_model_class("code", out_mapping) + ic_obj = ActionNode.create_model_class("code", out_mapping) message = Message( content="code", From 15279376d40ec59405295af2c80b9c7c96ddd294 Mon Sep 17 00:00:00 2001 From: better629 Date: Wed, 20 Dec 2023 16:01:17 +0800 Subject: [PATCH 137/167] rebase update after #589 --- metagpt/actions/action.py | 5 ++--- metagpt/actions/debug_error.py | 10 +++++++--- metagpt/actions/fix_bug.py | 1 + metagpt/actions/run_code.py | 10 +++++++--- metagpt/actions/summarize_code.py | 8 ++++++-- metagpt/actions/write_code.py | 3 +-- metagpt/actions/write_code_review.py | 4 +--- metagpt/roles/qa_engineer.py | 25 ++++++++++++------------- metagpt/roles/role.py | 3 ++- metagpt/schema.py | 4 ++-- metagpt/utils/serialize.py | 13 +++++-------- 11 files changed, 46 insertions(+), 40 deletions(-) diff --git a/metagpt/actions/action.py b/metagpt/actions/action.py index e18983d7d..535c25cb9 100644 --- a/metagpt/actions/action.py +++ b/metagpt/actions/action.py @@ -11,12 +11,11 @@ from __future__ import annotations from typing import Optional, Any from pydantic import BaseModel, Field -from metagpt.actions.action_node import ActionNode + from metagpt.llm import LLM from metagpt.provider.base_gpt_api import BaseGPTAPI from metagpt.schema import CodingContext, CodeSummarizeContext, TestingContext, RunCodeContext - action_subclass_registry = {} @@ -26,7 +25,7 @@ class Action(BaseModel): context: dict | CodingContext | CodeSummarizeContext | TestingContext | RunCodeContext | str | None = "" prefix = "" # aask*时会加上prefix,作为system_message desc = "" # for skill manager - node: ActionNode = Field(default_factory=ActionNode, exclude=True) + # node: ActionNode = Field(default_factory=ActionNode, exclude=True) # builtin variables builtin_class_name: str = "" diff --git a/metagpt/actions/debug_error.py b/metagpt/actions/debug_error.py index 39f3bc1bc..839acdc2e 100644 --- a/metagpt/actions/debug_error.py +++ b/metagpt/actions/debug_error.py @@ -10,11 +10,14 @@ """ import re +from pydantic import Field + from metagpt.actions.action import Action from metagpt.config import CONFIG from metagpt.const import TEST_CODES_FILE_REPO, TEST_OUTPUTS_FILE_REPO +from metagpt.llm import LLM, BaseGPTAPI from metagpt.logs import logger -from metagpt.schema import RunCodeResult +from metagpt.schema import RunCodeResult, RunCodeContext from metagpt.utils.common import CodeParser from metagpt.utils.file_repository import FileRepository @@ -47,8 +50,9 @@ Now you should start rewriting the code: class DebugError(Action): - def __init__(self, name="DebugError", context=None, llm=None): - super().__init__(name, context, llm) + name: str = "DebugError" + context: RunCodeContext = Field(default_factory=RunCodeContext) + llm: BaseGPTAPI = Field(default_factory=LLM) async def run(self, *args, **kwargs) -> str: output_doc = await FileRepository.get_file( diff --git a/metagpt/actions/fix_bug.py b/metagpt/actions/fix_bug.py index 6bd550d3d..eea40c91a 100644 --- a/metagpt/actions/fix_bug.py +++ b/metagpt/actions/fix_bug.py @@ -9,6 +9,7 @@ from metagpt.actions import Action class FixBug(Action): """Fix bug action without any implementation details""" + name: str = "FixBug" async def run(self, *args, **kwargs): raise NotImplementedError diff --git a/metagpt/actions/run_code.py b/metagpt/actions/run_code.py index 1b9fd252f..ea16c8891 100644 --- a/metagpt/actions/run_code.py +++ b/metagpt/actions/run_code.py @@ -18,10 +18,13 @@ import subprocess from typing import Tuple +from pydantic import Field + from metagpt.actions.action import Action from metagpt.config import CONFIG +from metagpt.llm import LLM, BaseGPTAPI from metagpt.logs import logger -from metagpt.schema import RunCodeResult +from metagpt.schema import RunCodeResult, RunCodeContext from metagpt.utils.exceptions import handle_exception PROMPT_TEMPLATE = """ @@ -74,8 +77,9 @@ standard errors: class RunCode(Action): - def __init__(self, name="RunCode", context=None, llm=None): - super().__init__(name, context, llm) + name: str = "RunCode" + context: RunCodeContext = Field(default_factory=RunCodeContext) + llm: BaseGPTAPI = Field(default_factory=LLM) @classmethod @handle_exception diff --git a/metagpt/actions/summarize_code.py b/metagpt/actions/summarize_code.py index f8d8d2b47..0aec15937 100644 --- a/metagpt/actions/summarize_code.py +++ b/metagpt/actions/summarize_code.py @@ -7,12 +7,15 @@ """ from pathlib import Path +from pydantic import Field from tenacity import retry, stop_after_attempt, wait_random_exponential from metagpt.actions.action import Action from metagpt.config import CONFIG from metagpt.const import SYSTEM_DESIGN_FILE_REPO, TASK_FILE_REPO +from metagpt.llm import LLM, BaseGPTAPI from metagpt.logs import logger +from metagpt.schema import CodeSummarizeContext from metagpt.utils.file_repository import FileRepository PROMPT_TEMPLATE = """ @@ -89,8 +92,9 @@ flowchart TB class SummarizeCode(Action): - def __init__(self, name="SummarizeCode", context=None, llm=None): - super().__init__(name, context, llm) + name: str = "SummarizeCode" + context: CodeSummarizeContext = Field(default_factory=CodeSummarizeContext) + llm: BaseGPTAPI = Field(default_factory=LLM) @retry(stop=stop_after_attempt(2), wait=wait_random_exponential(min=1, max=60)) async def summarize_code(self, prompt): diff --git a/metagpt/actions/write_code.py b/metagpt/actions/write_code.py index 046f9f456..4d0690e0f 100644 --- a/metagpt/actions/write_code.py +++ b/metagpt/actions/write_code.py @@ -16,7 +16,6 @@ """ import json -from typing import Optional from pydantic import Field from tenacity import retry, stop_after_attempt, wait_random_exponential @@ -90,7 +89,7 @@ ATTENTION: Use '##' to SPLIT SECTIONS, not '#'. Output format carefully referenc class WriteCode(Action): name: str = "WriteCode" - context: Optional[Document] = None + context: Document = Field(default_factory=Document) llm: BaseGPTAPI = Field(default_factory=LLM) @retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6)) diff --git a/metagpt/actions/write_code_review.py b/metagpt/actions/write_code_review.py index f4ab0adfe..580069b74 100644 --- a/metagpt/actions/write_code_review.py +++ b/metagpt/actions/write_code_review.py @@ -8,8 +8,6 @@ WriteCode object, rather than passing them in when calling the run function. """ -from typing import Optional - from pydantic import Field from tenacity import retry, stop_after_attempt, wait_random_exponential @@ -124,7 +122,7 @@ REWRITE_CODE_TEMPLATE = """ class WriteCodeReview(Action): name: str = "WriteCodeReview" - context: Optional[CodingContext] = None + context: CodingContext = Field(default_factory=CodingContext) llm: BaseGPTAPI = Field(default_factory=LLM) @retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6)) diff --git a/metagpt/roles/qa_engineer.py b/metagpt/roles/qa_engineer.py index acb79ab80..893faa9dd 100644 --- a/metagpt/roles/qa_engineer.py +++ b/metagpt/roles/qa_engineer.py @@ -17,6 +17,11 @@ from pydantic import Field +from metagpt.actions import ( + DebugError, + RunCode, + WriteTest, +) from metagpt.actions.summarize_code import SummarizeCode from metagpt.config import CONFIG from metagpt.const import ( @@ -24,11 +29,6 @@ from metagpt.const import ( TEST_CODES_FILE_REPO, TEST_OUTPUTS_FILE_REPO, ) -from metagpt.actions import ( - DebugError, - RunCode, - WriteTest, -) from metagpt.logs import logger from metagpt.roles import Role from metagpt.schema import Document, Message, RunCodeContext, TestingContext @@ -40,17 +40,16 @@ class QaEngineer(Role): name: str = Field(default="Edward") profile: str = Field(default="QaEngineer") goal: str = "Write comprehensive and robust tests to ensure codes will work as expected without bugs" - constraints: str = "The test code you write should conform to code standard like PEP8, be modular, easy to read and maintain" + constraints: str = "The test code you write should conform to code standard like PEP8, be modular, " \ + "easy to read and maintain" test_round_allowed: int = 5 - def __init__( - self, - **kwargs - ): + def __init__(self, **kwargs): super().__init__(**kwargs) - self._init_actions( - [WriteTest] - ) # FIXME: a bit hack here, only init one action to circumvent _think() logic, will overwrite _think() in future updates + + # FIXME: a bit hack here, only init one action to circumvent _think() logic, + # will overwrite _think() in future updates + self._init_actions([WriteTest]) self._watch([SummarizeCode, WriteTest, RunCode, DebugError]) self.test_round = 0 diff --git a/metagpt/roles/role.py b/metagpt/roles/role.py index 0bc129174..4bce64245 100644 --- a/metagpt/roles/role.py +++ b/metagpt/roles/role.py @@ -27,7 +27,8 @@ from typing import Iterable, Set, Type, Any from pydantic import BaseModel, Field -from metagpt.actions.action import Action, ActionOutput, action_subclass_registry +from metagpt.actions import Action, ActionOutput +from metagpt.actions.action import action_subclass_registry from metagpt.actions.action_node import ActionNode from metagpt.actions.add_requirement import UserRequirement from metagpt.const import SERDESER_PATH diff --git a/metagpt/schema.py b/metagpt/schema.py index 327bfd2d1..e5df6fb10 100644 --- a/metagpt/schema.py +++ b/metagpt/schema.py @@ -113,8 +113,8 @@ class Message(BaseModel): ic = instruct_content mapping = actionoutput_str_to_mapping(ic["mapping"]) - actionoutput_class = import_class("ActionOutput", "metagpt.actions.action_output") - ic_obj = actionoutput_class.create_model_class(class_name=ic["class"], mapping=mapping) + actionnode_class = import_class("ActionNode", "metagpt.actions.action_node") # avoid circular import + ic_obj = actionnode_class.create_model_class(class_name=ic["class"], mapping=mapping) ic_new = ic_obj(**ic["value"]) kwargs["instruct_content"] = ic_new diff --git a/metagpt/utils/serialize.py b/metagpt/utils/serialize.py index 7bfd55008..1d90e8de8 100644 --- a/metagpt/utils/serialize.py +++ b/metagpt/utils/serialize.py @@ -6,8 +6,6 @@ import copy import pickle from metagpt.utils.common import import_class -from metagpt.actions.action_node import ActionNode -from metagpt.schema import Message def actionoutout_schema_to_mapping(schema: dict) -> dict: @@ -90,27 +88,26 @@ def serialize_message(message: "Message"): def deserialize_general_message(message_dict: dict) -> "Message": """ deserialize Message, not to load""" instruct_content = message_dict.pop("instruct_content") - cause_by = message_dict.pop("cause_by") message_cls = import_class("Message", "metagpt.schema") message = message_cls(**message_dict) if instruct_content: ic = instruct_content mapping = actionoutput_str_to_mapping(ic["mapping"]) - - actionoutput_class = import_class("ActionOutput", "metagpt.actions.action_output") - ic_obj = actionoutput_class.create_model_class(class_name=ic["class"], mapping=mapping) + actionnode_class = import_class("ActionNode", "metagpt.actions.action_node") # avoid circular import + ic_obj = actionnode_class.create_model_class(class_name=ic["class"], mapping=mapping) ic_new = ic_obj(**ic["value"]) message.instruct_content = ic_new return message -def deserialize_message(message_ser: str) -> Message: +def deserialize_message(message_ser: str) -> "Message": message = pickle.loads(message_ser) if message.instruct_content: ic = message.instruct_content - ic_obj = ActionNode.create_model_class(class_name=ic["class"], mapping=ic["mapping"]) + actionnode_class = import_class("ActionNode", "metagpt.actions.action_node") # avoid circular import + ic_obj = actionnode_class.create_model_class(class_name=ic["class"], mapping=ic["mapping"]) ic_new = ic_obj(**ic["value"]) message.instruct_content = ic_new From 6877fa444feee9b3e00ede2d426e65c8a0b20446 Mon Sep 17 00:00:00 2001 From: better629 Date: Wed, 20 Dec 2023 18:55:29 +0800 Subject: [PATCH 138/167] deal with nested BaseModel --- metagpt/schema.py | 18 +++++++++++------- metagpt/utils/common.py | 4 +--- metagpt/utils/serialize.py | 2 +- .../metagpt/serialize_deserialize/test_team.py | 3 +++ 4 files changed, 16 insertions(+), 11 deletions(-) diff --git a/metagpt/schema.py b/metagpt/schema.py index e5df6fb10..1bb07aa95 100644 --- a/metagpt/schema.py +++ b/metagpt/schema.py @@ -108,9 +108,9 @@ class Message(BaseModel): send_to: Set = Field(default_factory={MESSAGE_ROUTE_TO_ALL}) def __init__(self, **kwargs): - instruct_content = kwargs.get("instruct_content", None) - if instruct_content and not isinstance(instruct_content, BaseModel): - ic = instruct_content + ic = kwargs.get("instruct_content", None) + if ic and not isinstance(ic, BaseModel) and "class" in ic: + # compatible with custom-defined ActionOutput mapping = actionoutput_str_to_mapping(ic["mapping"]) actionnode_class = import_class("ActionNode", "metagpt.actions.action_node") # avoid circular import @@ -140,13 +140,17 @@ class Message(BaseModel): def dict(self, *args, **kwargs) -> "DictStrAny": """ overwrite the `dict` to dump dynamic pydantic model""" obj_dict = super(Message, self).dict(*args, **kwargs) - ic = self.instruct_content # deal custom-defined action + ic = self.instruct_content if ic: + # compatible with custom-defined ActionOutput schema = ic.schema() - mapping = actionoutout_schema_to_mapping(schema) - mapping = actionoutput_mapping_to_str(mapping) + # `Documents` contain definitions + if "definitions" not in schema: + # TODO refine with nested BaseModel + mapping = actionoutout_schema_to_mapping(schema) + mapping = actionoutput_mapping_to_str(mapping) - obj_dict["instruct_content"] = {"class": schema["title"], "mapping": mapping, "value": ic.dict()} + obj_dict["instruct_content"] = {"class": schema["title"], "mapping": mapping, "value": ic.dict()} return obj_dict def __str__(self): diff --git a/metagpt/utils/common.py b/metagpt/utils/common.py index a445c9f31..ab7a3d99e 100644 --- a/metagpt/utils/common.py +++ b/metagpt/utils/common.py @@ -450,14 +450,12 @@ def serialize_decorator(func): async def wrapper(self, *args, **kwargs): try: result = await func(self, *args, **kwargs) - self.serialize() # Team.serialize return result except KeyboardInterrupt as kbi: logger.error(f"KeyboardInterrupt occurs, start to serialize the project, exp:\n{format_trackback_info()}") - self.serialize() # Team.serialize except Exception as exp: logger.error(f"Exception occurs, start to serialize the project, exp:\n{format_trackback_info()}") - self.serialize() # Team.serialize + self.serialize() # Team.serialize return wrapper diff --git a/metagpt/utils/serialize.py b/metagpt/utils/serialize.py index 1d90e8de8..a52dc8f45 100644 --- a/metagpt/utils/serialize.py +++ b/metagpt/utils/serialize.py @@ -62,7 +62,7 @@ def serialize_general_message(message: "Message") -> dict: message_cp = copy.deepcopy(message) ic = message_cp.instruct_content if ic: - # model create by pydantic create_model like `pydantic.main.prd`, can't pickle.dump directly + # model create by pydantic create_model like `pydantic.main.prd`, can't load directly schema = ic.schema() mapping = actionoutout_schema_to_mapping(schema) mapping = actionoutput_mapping_to_str(mapping) diff --git a/tests/metagpt/serialize_deserialize/test_team.py b/tests/metagpt/serialize_deserialize/test_team.py index e87df9b52..d6a477b0e 100644 --- a/tests/metagpt/serialize_deserialize/test_team.py +++ b/tests/metagpt/serialize_deserialize/test_team.py @@ -10,6 +10,7 @@ import pytest from metagpt.const import SERDESER_PATH from metagpt.roles import ProjectManager, ProductManager, Architect from metagpt.team import Team +from metagpt.logs import logger from tests.metagpt.serialize_deserialize.test_serdeser_base import RoleA, RoleB, RoleC, serdeser_path, ActionOK @@ -120,6 +121,8 @@ async def test_team_recover_multi_roles_save(): company.run_project(idea) await company.run(n_round=4) + logger.info("Team recovered") + new_company = Team.recover(stg_path) new_company.run_project(idea) From 0543c0f76b18680031a59ce5cccd5e1a1899cb58 Mon Sep 17 00:00:00 2001 From: better629 Date: Thu, 21 Dec 2023 00:16:28 +0800 Subject: [PATCH 139/167] just use deserialize instead of recover --- metagpt/startup.py | 2 +- tests/metagpt/serialize_deserialize/test_team.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/metagpt/startup.py b/metagpt/startup.py index 5a3e482a4..59e0cb199 100644 --- a/metagpt/startup.py +++ b/metagpt/startup.py @@ -67,7 +67,7 @@ def startup( if not stg_path.exists() or not str(stg_path).endswith("team"): raise FileNotFoundError(f"{recover_path} not exists or not endswith `team`") - company = Team.recover(stg_path=stg_path) + company = Team.deserialize(stg_path=stg_path) idea = company.idea # use original idea company.invest(investment) diff --git a/tests/metagpt/serialize_deserialize/test_team.py b/tests/metagpt/serialize_deserialize/test_team.py index d6a477b0e..db6001325 100644 --- a/tests/metagpt/serialize_deserialize/test_team.py +++ b/tests/metagpt/serialize_deserialize/test_team.py @@ -89,7 +89,7 @@ async def test_team_recover_save(): company.run_project(idea) await company.run(n_round=4) - new_company = Team.recover(stg_path) + new_company = Team.deserialize(stg_path) new_role_c = new_company.env.get_role(role_c.profile) # assert new_role_c._rc.memory == role_c._rc.memory assert new_role_c._rc.env != role_c._rc.env @@ -123,7 +123,7 @@ async def test_team_recover_multi_roles_save(): logger.info("Team recovered") - new_company = Team.recover(stg_path) + new_company = Team.deserialize(stg_path) new_company.run_project(idea) assert new_company.env.get_role(role_b.profile)._rc.state == 1 From 24060ea8a65d45e32d816b4ad596e74f3f4a78fe Mon Sep 17 00:00:00 2001 From: better629 Date: Thu, 21 Dec 2023 00:18:09 +0800 Subject: [PATCH 140/167] update use Field with uniform rule: define default_factory or exclude, use Field --- metagpt/environment.py | 2 +- metagpt/memory/memory.py | 2 +- metagpt/roles/architect.py | 10 +++++----- metagpt/roles/customer_service.py | 6 +++--- metagpt/roles/engineer.py | 2 +- metagpt/roles/product_manager.py | 10 +++++----- metagpt/roles/project_manager.py | 4 ++-- metagpt/roles/qa_engineer.py | 4 ++-- metagpt/roles/role.py | 6 +++--- metagpt/roles/sales.py | 6 +++--- metagpt/schema.py | 2 +- 11 files changed, 27 insertions(+), 27 deletions(-) diff --git a/metagpt/environment.py b/metagpt/environment.py index a3cbe6978..ab296557f 100644 --- a/metagpt/environment.py +++ b/metagpt/environment.py @@ -30,7 +30,7 @@ class Environment(BaseModel): roles: dict[str, Role] = Field(default_factory=dict) members: dict[Role, Set] = Field(default_factory=dict) - history: str = Field(default="") # For debug + history: str = "" # For debug class Config: arbitrary_types_allowed = True diff --git a/metagpt/memory/memory.py b/metagpt/memory/memory.py index 66ab5d4e9..076db832a 100644 --- a/metagpt/memory/memory.py +++ b/metagpt/memory/memory.py @@ -19,7 +19,7 @@ from metagpt.utils.common import any_to_str, any_to_str_set, read_json_file, wri class Memory(BaseModel): """The most basic memory: super-memory""" - storage: list[Message] = Field(default=[]) + storage: list[Message] = [] index: dict[str, list[Message]] = Field(default_factory=defaultdict(list)) def __init__(self, **kwargs): diff --git a/metagpt/roles/architect.py b/metagpt/roles/architect.py index a36cd6e93..bd6cd110b 100644 --- a/metagpt/roles/architect.py +++ b/metagpt/roles/architect.py @@ -22,11 +22,11 @@ class Architect(Role): goal (str): Primary goal or responsibility of the architect. constraints (str): Constraints or guidelines for the architect. """ - name: str = Field(default="Bob") - profile: str = Field(default="Architect") - goal: str = Field(default="design a concise, usable, complete software system") - constraints: str = Field(default="make sure the architecture is simple enough and use appropriate open source " - "libraries. Use same language as user requirement") + name: str = "Bob" + profile: str = "Architect" + goal: str = "design a concise, usable, complete software system" + constraints: str = "make sure the architecture is simple enough and use appropriate open source " \ + "libraries. Use same language as user requirement" def __init__(self, **kwargs) -> None: super().__init__(**kwargs) diff --git a/metagpt/roles/customer_service.py b/metagpt/roles/customer_service.py index 62792696f..b2033ac0b 100644 --- a/metagpt/roles/customer_service.py +++ b/metagpt/roles/customer_service.py @@ -28,9 +28,9 @@ DESC = """ class CustomerService(Sales): - name: str = Field(default="Xiaomei") - profile: str = Field(default="Human customer service") - desc: str = DESC, + name: str = "Xiaomei" + profile: str = "Human customer service" + desc: str = DESC store: Optional[str] = None diff --git a/metagpt/roles/engineer.py b/metagpt/roles/engineer.py index 206afb38c..337184068 100644 --- a/metagpt/roles/engineer.py +++ b/metagpt/roles/engineer.py @@ -70,7 +70,7 @@ class Engineer(Role): use_code_review (bool): Whether to use code review. """ name: str = "Alex" - profile: str = Field(default="Engineer") + profile: str = "Engineer" goal: str = "write elegant, readable, extensible, efficient code" constraints: str = "the code should conform to standards like google-style and be modular and maintainable. " \ "Use same language as user requirement" diff --git a/metagpt/roles/product_manager.py b/metagpt/roles/product_manager.py index 72e5a9be5..6369688a5 100644 --- a/metagpt/roles/product_manager.py +++ b/metagpt/roles/product_manager.py @@ -20,13 +20,13 @@ class ProductManager(Role): Represents a Product Manager role responsible for product development and management. Attributes: - name (str): Name of the project manager. - profile (str): Role profile, default is 'Project Manager'. - goal (str): Goal of the project manager. - constraints (str): Constraints or limitations for the project manager. + name (str): Name of the product manager. + profile (str): Role profile, default is 'Product Manager'. + goal (str): Goal of the product manager. + constraints (str): Constraints or limitations for the product manager. """ name: str = "Alice" - profile: str = Field(default="Product Manager") + profile: str = "Product Manager" goal: str = "efficiently create a successful product that meets market demands and user expectations" constraints: str = "utilize the same language as the user requirements for seamless communication" diff --git a/metagpt/roles/project_manager.py b/metagpt/roles/project_manager.py index 42564cd70..bf572d1f8 100644 --- a/metagpt/roles/project_manager.py +++ b/metagpt/roles/project_manager.py @@ -22,8 +22,8 @@ class ProjectManager(Role): goal (str): Goal of the project manager. constraints (str): Constraints or limitations for the project manager. """ - name: str = Field(default="Eve") - profile: str = Field(default="Project Manager") + name: str = "Eve" + profile: str = "Project Manager" goal: str = "break down tasks according to PRD/technical design, generate a task list, and analyze task " \ "dependencies to start with the prerequisite modules" constraints: str = "use same language as user requirement" diff --git a/metagpt/roles/qa_engineer.py b/metagpt/roles/qa_engineer.py index 893faa9dd..369e3dc63 100644 --- a/metagpt/roles/qa_engineer.py +++ b/metagpt/roles/qa_engineer.py @@ -37,8 +37,8 @@ from metagpt.utils.file_repository import FileRepository class QaEngineer(Role): - name: str = Field(default="Edward") - profile: str = Field(default="QaEngineer") + name: str = "Edward" + profile: str = "QaEngineer" goal: str = "Write comprehensive and robust tests to ensure codes will work as expected without bugs" constraints: str = "The test code you write should conform to code standard like PEP8, be modular, " \ "easy to read and maintain" diff --git a/metagpt/roles/role.py b/metagpt/roles/role.py index 4bce64245..f87c4e250 100644 --- a/metagpt/roles/role.py +++ b/metagpt/roles/role.py @@ -129,9 +129,9 @@ class Role(BaseModel): _llm: BaseGPTAPI = Field(default_factory=LLM) _role_id: str = "" - _states: list[str] = Field(default=[]) - _actions: list[Action] = Field(default=[]) - _rc: RoleContext = Field(default=RoleContext) + _states: list[str] = [] + _actions: list[Action] = [] + _rc: RoleContext = Field(default_factory=RoleContext) _subscription: tuple[str] = set() # builtin variables diff --git a/metagpt/roles/sales.py b/metagpt/roles/sales.py index 826413dc8..fd5a42915 100644 --- a/metagpt/roles/sales.py +++ b/metagpt/roles/sales.py @@ -16,14 +16,14 @@ from metagpt.tools import SearchEngineType class Sales(Role): - name: str = Field(default="Xiaomei") - profile: str = Field(default="Retail sales guide") + name: str = "Xiaomei" + profile: str = "Retail sales guide" desc: str = "I am a sales guide in retail. My name is Xiaomei. I will answer some customer questions next, and I " "will answer questions only based on the information in the knowledge base." "If I feel that you can't get the answer from the reference material, then I will directly reply that" " I don't know, and I won't tell you that this is from the knowledge base," "but pretend to be what I know. Note that each of my replies will be replied in the tone of a " - "professional guide", + "professional guide" store: Optional[str] = None diff --git a/metagpt/schema.py b/metagpt/schema.py index 1bb07aa95..5103a4f20 100644 --- a/metagpt/schema.py +++ b/metagpt/schema.py @@ -101,7 +101,7 @@ class Message(BaseModel): id: str # According to Section 2.2.3.1.1 of RFC 135 content: str - instruct_content: BaseModel = Field(default=None) + instruct_content: BaseModel = None role: str = "user" # system / user / assistant cause_by: str = "" sent_from: str = "" From 2178cecd25916a53c77695eb25c46d2f472ff1b1 Mon Sep 17 00:00:00 2001 From: better629 Date: Thu, 21 Dec 2023 00:34:53 +0800 Subject: [PATCH 141/167] rm useless functions in serialize.py --- metagpt/utils/serialize.py | 31 ------------------------------- tests/metagpt/test_schema.py | 27 +++++++++++++-------------- 2 files changed, 13 insertions(+), 45 deletions(-) diff --git a/metagpt/utils/serialize.py b/metagpt/utils/serialize.py index a52dc8f45..3939b1306 100644 --- a/metagpt/utils/serialize.py +++ b/metagpt/utils/serialize.py @@ -57,20 +57,6 @@ def actionoutput_str_to_mapping(mapping: dict) -> dict: return new_mapping -def serialize_general_message(message: "Message") -> dict: - """ serialize Message, not to save""" - message_cp = copy.deepcopy(message) - ic = message_cp.instruct_content - if ic: - # model create by pydantic create_model like `pydantic.main.prd`, can't load directly - schema = ic.schema() - mapping = actionoutout_schema_to_mapping(schema) - mapping = actionoutput_mapping_to_str(mapping) - - message_cp.instruct_content = {"class": schema["title"], "mapping": mapping, "value": ic.dict()} - return message_cp.dict() - - def serialize_message(message: "Message"): message_cp = copy.deepcopy(message) # avoid `instruct_content` value update by reference ic = message_cp.instruct_content @@ -85,23 +71,6 @@ def serialize_message(message: "Message"): return msg_ser -def deserialize_general_message(message_dict: dict) -> "Message": - """ deserialize Message, not to load""" - instruct_content = message_dict.pop("instruct_content") - - message_cls = import_class("Message", "metagpt.schema") - message = message_cls(**message_dict) - if instruct_content: - ic = instruct_content - mapping = actionoutput_str_to_mapping(ic["mapping"]) - actionnode_class = import_class("ActionNode", "metagpt.actions.action_node") # avoid circular import - ic_obj = actionnode_class.create_model_class(class_name=ic["class"], mapping=mapping) - ic_new = ic_obj(**ic["value"]) - message.instruct_content = ic_new - - return message - - def deserialize_message(message_ser: str) -> "Message": message = pickle.loads(message_ser) if message.instruct_content: diff --git a/tests/metagpt/test_schema.py b/tests/metagpt/test_schema.py index 054a92de1..ef706abfa 100644 --- a/tests/metagpt/test_schema.py +++ b/tests/metagpt/test_schema.py @@ -15,7 +15,6 @@ from metagpt.actions import Action from metagpt.schema import AIMessage, Message, SystemMessage, UserMessage from metagpt.actions.action_node import ActionNode from metagpt.actions.write_code import WriteCode -from metagpt.utils.serialize import serialize_general_message, deserialize_general_message from metagpt.utils.common import any_to_str @@ -23,10 +22,10 @@ from metagpt.utils.common import any_to_str def test_messages(): test_content = "test_message" msgs = [ - UserMessage(test_content), - SystemMessage(test_content), - AIMessage(test_content), - Message(test_content, role="QA"), + UserMessage(content=test_content), + SystemMessage(content=test_content), + AIMessage(content=test_content), + Message(content=test_content, role="QA"), ] text = str(msgs) roles = ["user", "system", "assistant", "QA"] @@ -35,7 +34,7 @@ def test_messages(): @pytest.mark.asyncio def test_message(): - m = Message("a", role="v1") + m = Message(content="a", role="v1") v = m.dump() d = json.loads(v) assert d @@ -48,7 +47,7 @@ def test_message(): assert m.content == "a" assert m.role == "v2" - m = Message("a", role="b", cause_by="c", x="d", send_to="c") + m = Message(content="a", role="b", cause_by="c", x="d", send_to="c") assert m.content == "a" assert m.role == "b" assert m.send_to == {"c"} @@ -66,7 +65,7 @@ def test_message(): @pytest.mark.asyncio def test_routes(): - m = Message("a", role="b", cause_by="c", x="d", send_to="c") + m = Message(content="a", role="b", cause_by="c", x="d", send_to="c") m.send_to = "b" assert m.send_to == {"b"} m.send_to = {"e", Action} @@ -84,8 +83,8 @@ def test_message_serdeser(): role="engineer", cause_by=WriteCode ) - message_dict = serialize_general_message(message) - assert message_dict["cause_by"] == {"action_class": "WriteCode", "module_name": "metagpt.actions.write_code"} + message_dict = message.dict() + assert message_dict["cause_by"] == "metagpt.actions.write_code.WriteCode" assert message_dict["instruct_content"] == { "class": "code", "mapping": { @@ -98,14 +97,14 @@ def test_message_serdeser(): } } - new_message = deserialize_general_message(message_dict) + new_message = Message(**message_dict) assert new_message.content == message.content assert new_message.instruct_content == message.instruct_content assert new_message.cause_by == message.cause_by assert new_message.instruct_content.field3 == out_data["field3"] message = Message(content="code") - message_dict = serialize_general_message(message) - new_message = deserialize_general_message(message_dict) + message_dict = message.dict() + new_message = Message(**message_dict) assert new_message.instruct_content is None - assert new_message.cause_by == "" + assert new_message.cause_by == "metagpt.actions.add_requirement.UserRequirement" From fa1af925376b12a11f8c5e585bdb0a101f027792 Mon Sep 17 00:00:00 2001 From: voidking Date: Tue, 19 Dec 2023 20:34:53 +0800 Subject: [PATCH 142/167] =?UTF-8?q?feature:=20=E6=94=AF=E6=8C=81pre-commit?= =?UTF-8?q?=E6=A3=80=E6=9F=A5=E4=BB=A3=E7=A0=81=E8=A7=84=E8=8C=83?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .github/workflows/pre-commit.yaml | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) create mode 100644 .github/workflows/pre-commit.yaml diff --git a/.github/workflows/pre-commit.yaml b/.github/workflows/pre-commit.yaml new file mode 100644 index 000000000..ed4bbb144 --- /dev/null +++ b/.github/workflows/pre-commit.yaml @@ -0,0 +1,30 @@ +name: Pre-commit checks + +on: + pull_request: + branches: + - '**' + push: + branches: + - '**' + +jobs: + pre-commit-check: + runs-on: ubuntu-latest + steps: + - name: Checkout Source Code + uses: actions/checkout@v2 + + - name: Setup Python + uses: actions/setup-python@v2 + with: + python-version: '3.9.17' + + - name: Install pre-commit + run: pip install pre-commit + + - name: Initialize pre-commit + run: pre-commit install + + - name: Run pre-commit hooks + run: pre-commit run --all-files \ No newline at end of file From bf9fa4476549d2b57fbe62f5a6df9d2825d46a21 Mon Sep 17 00:00:00 2001 From: voidking Date: Wed, 20 Dec 2023 14:07:52 +0800 Subject: [PATCH 143/167] =?UTF-8?q?=E4=BF=AE=E6=94=B9=20metagpt/team.py=20?= =?UTF-8?q?=E7=AC=A6=E5=90=88=E4=BB=A3=E7=A0=81=E8=A7=84=E8=8C=83?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- metagpt/team.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/metagpt/team.py b/metagpt/team.py index 1df3c4052..0c1efb812 100644 --- a/metagpt/team.py +++ b/metagpt/team.py @@ -10,6 +10,7 @@ from pathlib import Path import warnings + from pydantic import BaseModel, Field from metagpt.actions import UserRequirement @@ -94,9 +95,12 @@ class Team(BaseModel): Deprecated: This method will be removed in the future. Please use the `run_project` method instead. """ - warnings.warn("The 'start_project' method is deprecated and will be removed in the future. " - "Please use the 'run_project' method instead.", - DeprecationWarning, stacklevel=2) + warnings.warn( + "The 'start_project' method is deprecated and will be removed in the future. " + "Please use the 'run_project' method instead.", + DeprecationWarning, + stacklevel=2, + ) return self.run_project(idea=idea, send_to=send_to) def _save(self): From 4929e41f18cb047bf583fd43d25a16bacb886d93 Mon Sep 17 00:00:00 2001 From: voidking Date: Thu, 21 Dec 2023 10:48:46 +0800 Subject: [PATCH 144/167] run pre-commit to find potential issues and fix them --- metagpt/actions/action.py | 9 +++- metagpt/actions/debug_error.py | 2 +- metagpt/actions/design_api.py | 8 ++-- metagpt/actions/fix_bug.py | 1 + metagpt/actions/prepare_documents.py | 1 + metagpt/actions/run_code.py | 2 +- metagpt/actions/search_and_summarize.py | 18 ++++---- metagpt/actions/write_code_review.py | 12 ++--- metagpt/actions/write_prd.py | 5 +- metagpt/actions/write_test.py | 5 +- metagpt/environment.py | 20 ++++---- metagpt/memory/longterm_memory.py | 3 +- metagpt/memory/memory.py | 15 ++++-- .../postprecess/base_postprecess_plugin.py | 1 - metagpt/roles/architect.py | 8 ++-- metagpt/roles/customer_service.py | 6 +-- metagpt/roles/engineer.py | 11 +++-- metagpt/roles/product_manager.py | 2 +- metagpt/roles/project_manager.py | 8 ++-- metagpt/roles/qa_engineer.py | 15 +++--- metagpt/roles/role.py | 46 +++++++++++-------- metagpt/roles/sales.py | 2 - metagpt/roles/searcher.py | 2 +- metagpt/schema.py | 20 ++++---- metagpt/startup.py | 4 +- metagpt/team.py | 20 +++++--- metagpt/utils/common.py | 25 +++++----- .../serialize_deserialize/test_environment.py | 11 +++-- .../serialize_deserialize/test_memory.py | 22 ++++----- .../serialize_deserialize/test_role.py | 19 ++++---- .../serialize_deserialize/test_schema.py | 14 ++---- .../test_serdeser_base.py | 7 ++- .../serialize_deserialize/test_team.py | 18 +++++--- .../serialize_deserialize/test_write_code.py | 5 +- .../test_write_code_review.py | 2 +- tests/metagpt/test_environment.py | 33 +++++-------- tests/metagpt/test_schema.py | 20 ++------ tests/metagpt/test_team.py | 2 +- 38 files changed, 209 insertions(+), 215 deletions(-) diff --git a/metagpt/actions/action.py b/metagpt/actions/action.py index 535c25cb9..62434e7f8 100644 --- a/metagpt/actions/action.py +++ b/metagpt/actions/action.py @@ -8,13 +8,18 @@ from __future__ import annotations -from typing import Optional, Any +from typing import Any, Optional from pydantic import BaseModel, Field from metagpt.llm import LLM from metagpt.provider.base_gpt_api import BaseGPTAPI -from metagpt.schema import CodingContext, CodeSummarizeContext, TestingContext, RunCodeContext +from metagpt.schema import ( + CodeSummarizeContext, + CodingContext, + RunCodeContext, + TestingContext, +) action_subclass_registry = {} diff --git a/metagpt/actions/debug_error.py b/metagpt/actions/debug_error.py index 839acdc2e..9dc6862f9 100644 --- a/metagpt/actions/debug_error.py +++ b/metagpt/actions/debug_error.py @@ -17,7 +17,7 @@ from metagpt.config import CONFIG from metagpt.const import TEST_CODES_FILE_REPO, TEST_OUTPUTS_FILE_REPO from metagpt.llm import LLM, BaseGPTAPI from metagpt.logs import logger -from metagpt.schema import RunCodeResult, RunCodeContext +from metagpt.schema import RunCodeContext, RunCodeResult from metagpt.utils.common import CodeParser from metagpt.utils.file_repository import FileRepository diff --git a/metagpt/actions/design_api.py b/metagpt/actions/design_api.py index f5e122356..055365421 100644 --- a/metagpt/actions/design_api.py +++ b/metagpt/actions/design_api.py @@ -45,9 +45,11 @@ class WriteDesign(Action): name: str = "" context: Optional[str] = None llm: BaseGPTAPI = Field(default_factory=LLM) - desc: str = "Based on the PRD, think about the system design, and design the corresponding APIs, " \ - "data structures, library tables, processes, and paths. Please provide your design, feedback " \ - "clearly and in detail." + desc: str = ( + "Based on the PRD, think about the system design, and design the corresponding APIs, " + "data structures, library tables, processes, and paths. Please provide your design, feedback " + "clearly and in detail." + ) async def run(self, with_messages: Message, schema: str = CONFIG.prompt_schema): # Use `git diff` to identify which PRD documents have been modified in the `docs/prds` directory. diff --git a/metagpt/actions/fix_bug.py b/metagpt/actions/fix_bug.py index eea40c91a..56b488218 100644 --- a/metagpt/actions/fix_bug.py +++ b/metagpt/actions/fix_bug.py @@ -9,6 +9,7 @@ from metagpt.actions import Action class FixBug(Action): """Fix bug action without any implementation details""" + name: str = "FixBug" async def run(self, *args, **kwargs): diff --git a/metagpt/actions/prepare_documents.py b/metagpt/actions/prepare_documents.py index 9b5128cbd..696dc9a89 100644 --- a/metagpt/actions/prepare_documents.py +++ b/metagpt/actions/prepare_documents.py @@ -25,6 +25,7 @@ from metagpt.utils.git_repository import GitRepository class PrepareDocuments(Action): """PrepareDocuments Action: initialize project folder and add new requirements to docs/requirements.txt.""" + name: str = "PrepareDocuments" context: Optional[str] = None llm: BaseGPTAPI = Field(default_factory=LLM) diff --git a/metagpt/actions/run_code.py b/metagpt/actions/run_code.py index ea16c8891..bca9b337d 100644 --- a/metagpt/actions/run_code.py +++ b/metagpt/actions/run_code.py @@ -24,7 +24,7 @@ from metagpt.actions.action import Action from metagpt.config import CONFIG from metagpt.llm import LLM, BaseGPTAPI from metagpt.logs import logger -from metagpt.schema import RunCodeResult, RunCodeContext +from metagpt.schema import RunCodeContext, RunCodeResult from metagpt.utils.exceptions import handle_exception PROMPT_TEMPLATE = """ diff --git a/metagpt/actions/search_and_summarize.py b/metagpt/actions/search_and_summarize.py index 3f110c370..6ab7becb6 100644 --- a/metagpt/actions/search_and_summarize.py +++ b/metagpt/actions/search_and_summarize.py @@ -5,18 +5,18 @@ @Author : alexanderwu @File : search_google.py """ +from typing import Optional + import pydantic -from typing import Optional, Any -from pydantic import BaseModel, Field +from pydantic import Field, root_validator from metagpt.actions import Action +from metagpt.config import CONFIG, Config from metagpt.llm import LLM -from metagpt.provider.base_gpt_api import BaseGPTAPI -from metagpt.config import Config, CONFIG from metagpt.logs import logger +from metagpt.provider.base_gpt_api import BaseGPTAPI from metagpt.schema import Message from metagpt.tools.search_engine import SearchEngine -from pydantic import root_validator SEARCH_AND_SUMMARIZE_SYSTEM = """### Requirements 1. Please summarize the latest dialogue based on the reference information (secondary) and dialogue history (primary). Do not include text that is irrelevant to the conversation. @@ -120,7 +120,7 @@ class SearchAndSummarize(Action): engine = values.get("engine") search_func = values.get("search_func") config = Config() - + if engine is None: engine = config.search_engine try: @@ -135,7 +135,7 @@ class SearchAndSummarize(Action): if self.search_engine is None: logger.warning("Configure one of SERPAPI_API_KEY, SERPER_API_KEY, GOOGLE_API_KEY to unlock full feature") return "" - + query = context[-1].content # logger.debug(query) rsp = await self.search_engine.run(query) @@ -144,9 +144,9 @@ class SearchAndSummarize(Action): logger.error("empty rsp...") return "" # logger.info(rsp) - + system_prompt = [system_text] - + prompt = SEARCH_AND_SUMMARIZE_PROMPT.format( ROLE=self.prefix, CONTEXT=rsp, diff --git a/metagpt/actions/write_code_review.py b/metagpt/actions/write_code_review.py index 580069b74..1eba672a5 100644 --- a/metagpt/actions/write_code_review.py +++ b/metagpt/actions/write_code_review.py @@ -142,15 +142,9 @@ class WriteCodeReview(Action): iterative_code = self.context.code_doc.content k = CONFIG.code_review_k_times or 1 for i in range(k): - format_example = FORMAT_EXAMPLE.format( - filename=self.context.code_doc.filename - ) - task_content = ( - self.context.task_doc.content if self.context.task_doc else "" - ) - code_context = await WriteCode.get_codes( - self.context.task_doc, exclude=self.context.filename - ) + format_example = FORMAT_EXAMPLE.format(filename=self.context.code_doc.filename) + task_content = self.context.task_doc.content if self.context.task_doc else "" + code_context = await WriteCode.get_codes(self.context.task_doc, exclude=self.context.filename) context = "\n".join( [ "## System Design\n" + str(self.context.design_doc) + "\n", diff --git a/metagpt/actions/write_prd.py b/metagpt/actions/write_prd.py index df66e6442..1223e5486 100644 --- a/metagpt/actions/write_prd.py +++ b/metagpt/actions/write_prd.py @@ -143,8 +143,9 @@ class WritePRD(Action): async def _update_prd(self, requirement_doc, prd_doc, prds_file_repo, *args, **kwargs) -> Document | None: if not prd_doc: - prd = await self._run_new_requirement(requirements=[requirement_doc.content if requirement_doc else ""], - *args, **kwargs) + prd = await self._run_new_requirement( + requirements=[requirement_doc.content if requirement_doc else ""], *args, **kwargs + ) new_prd_doc = Document( root_path=PRDS_FILE_REPO, filename=FileRepository.new_filename() + ".json", diff --git a/metagpt/actions/write_test.py b/metagpt/actions/write_test.py index fa3931ba6..9eb0bdbb6 100644 --- a/metagpt/actions/write_test.py +++ b/metagpt/actions/write_test.py @@ -9,14 +9,15 @@ """ from typing import Optional + from pydantic import Field -from metagpt.llm import LLM -from metagpt.provider.base_gpt_api import BaseGPTAPI from metagpt.actions.action import Action from metagpt.config import CONFIG from metagpt.const import TEST_CODES_FILE_REPO +from metagpt.llm import LLM from metagpt.logs import logger +from metagpt.provider.base_gpt_api import BaseGPTAPI from metagpt.schema import Document, TestingContext from metagpt.utils.common import CodeParser diff --git a/metagpt/environment.py b/metagpt/environment.py index ab296557f..58569ec08 100644 --- a/metagpt/environment.py +++ b/metagpt/environment.py @@ -56,12 +56,14 @@ class Environment(BaseModel): roles_path = stg_path.joinpath("roles.json") roles_info = [] for role_key, role in self.roles.items(): - roles_info.append({ - "role_class": role.__class__.__name__, - "module_name": role.__module__, - "role_name": role.name, - "role_sub_tags": list(self.members.get(role)) - }) + roles_info.append( + { + "role_class": role.__class__.__name__, + "module_name": role.__module__, + "role_name": role.name, + "role_sub_tags": list(self.members.get(role)), + } + ) role.serialize(stg_path=stg_path.joinpath(f"roles/{role.__class__.__name__}_{role.name}")) write_json_file(roles_path, roles_info) @@ -70,7 +72,7 @@ class Environment(BaseModel): @classmethod def deserialize(cls, stg_path: Path) -> "Environment": - """ stg_path: ./storage/team/environment/ """ + """stg_path: ./storage/team/environment/""" roles_path = stg_path.joinpath("roles.json") roles_info = read_json_file(roles_path) roles = [] @@ -83,9 +85,7 @@ class Environment(BaseModel): history = read_json_file(stg_path.joinpath("history.json")) history = history.get("content") - environment = Environment(**{ - "history": history - }) + environment = Environment(**{"history": history}) environment.add_roles(roles) return environment diff --git a/metagpt/memory/longterm_memory.py b/metagpt/memory/longterm_memory.py index 76a8deabb..710074f81 100644 --- a/metagpt/memory/longterm_memory.py +++ b/metagpt/memory/longterm_memory.py @@ -5,9 +5,7 @@ """ from typing import Optional -from pydantic import Field -from typing import Optional from pydantic import Field from metagpt.logs import logger @@ -22,6 +20,7 @@ class LongTermMemory(Memory): - recover memory when it staruped - update memory when it changed """ + memory_storage: MemoryStorage = Field(default_factory=MemoryStorage) rc: Optional["RoleContext"] = None msg_from_recover: bool = False diff --git a/metagpt/memory/memory.py b/metagpt/memory/memory.py index 076db832a..e9891ed00 100644 --- a/metagpt/memory/memory.py +++ b/metagpt/memory/memory.py @@ -6,7 +6,6 @@ @File : memory.py @Modified By: mashenquan, 2023-11-1. According to RFC 116: Updated the type of index key. """ -import copy from collections import defaultdict from pathlib import Path from typing import Iterable, Set @@ -14,11 +13,17 @@ from typing import Iterable, Set from pydantic import BaseModel, Field from metagpt.schema import Message -from metagpt.utils.common import any_to_str, any_to_str_set, read_json_file, write_json_file +from metagpt.utils.common import ( + any_to_str, + any_to_str_set, + read_json_file, + write_json_file, +) class Memory(BaseModel): """The most basic memory: super-memory""" + storage: list[Message] = [] index: dict[str, list[Message]] = Field(default_factory=defaultdict(list)) @@ -32,14 +37,14 @@ class Memory(BaseModel): self.index = new_index def serialize(self, stg_path: Path): - """ stg_path = ./storage/team/environment/ or ./storage/team/environment/roles/{role_class}_{role_name}/ """ + """stg_path = ./storage/team/environment/ or ./storage/team/environment/roles/{role_class}_{role_name}/""" memory_path = stg_path.joinpath("memory.json") storage = self.dict() write_json_file(memory_path, storage) @classmethod def deserialize(cls, stg_path: Path) -> "Memory": - """ stg_path = ./storage/team/environment/ or ./storage/team/environment/roles/{role_class}_{role_name}/""" + """stg_path = ./storage/team/environment/ or ./storage/team/environment/roles/{role_class}_{role_name}/""" memory_path = stg_path.joinpath("memory.json") memory_dict = read_json_file(memory_path) @@ -68,7 +73,7 @@ class Memory(BaseModel): return [message for message in self.storage if content in message.content] def delete_newest(self) -> "Message": - """ delete the newest message from the storage""" + """delete the newest message from the storage""" if len(self.storage) > 0: newest_msg = self.storage.pop() if newest_msg.cause_by and newest_msg in self.index[newest_msg.cause_by]: diff --git a/metagpt/provider/postprecess/base_postprecess_plugin.py b/metagpt/provider/postprecess/base_postprecess_plugin.py index afcef2531..46646be91 100644 --- a/metagpt/provider/postprecess/base_postprecess_plugin.py +++ b/metagpt/provider/postprecess/base_postprecess_plugin.py @@ -4,7 +4,6 @@ from typing import Union -from metagpt.logs import logger from metagpt.utils.repair_llm_raw_output import ( RepairType, extract_content_from_output, diff --git a/metagpt/roles/architect.py b/metagpt/roles/architect.py index bd6cd110b..c6ceaccb7 100644 --- a/metagpt/roles/architect.py +++ b/metagpt/roles/architect.py @@ -5,7 +5,6 @@ @Author : alexanderwu @File : architect.py """ -from pydantic import Field from metagpt.actions import WritePRD from metagpt.actions.design_api import WriteDesign @@ -22,11 +21,14 @@ class Architect(Role): goal (str): Primary goal or responsibility of the architect. constraints (str): Constraints or guidelines for the architect. """ + name: str = "Bob" profile: str = "Architect" goal: str = "design a concise, usable, complete software system" - constraints: str = "make sure the architecture is simple enough and use appropriate open source " \ - "libraries. Use same language as user requirement" + constraints: str = ( + "make sure the architecture is simple enough and use appropriate open source " + "libraries. Use same language as user requirement" + ) def __init__(self, **kwargs) -> None: super().__init__(**kwargs) diff --git a/metagpt/roles/customer_service.py b/metagpt/roles/customer_service.py index b2033ac0b..777f62731 100644 --- a/metagpt/roles/customer_service.py +++ b/metagpt/roles/customer_service.py @@ -6,7 +6,6 @@ @File : sales.py """ from typing import Optional -from pydantic import Field from metagpt.roles import Sales @@ -27,14 +26,11 @@ DESC = """ class CustomerService(Sales): - name: str = "Xiaomei" profile: str = "Human customer service" desc: str = DESC store: Optional[str] = None - def __init__( - self, - **kwargs): + def __init__(self, **kwargs): super().__init__(**kwargs) diff --git a/metagpt/roles/engineer.py b/metagpt/roles/engineer.py index 337184068..e0234f378 100644 --- a/metagpt/roles/engineer.py +++ b/metagpt/roles/engineer.py @@ -24,8 +24,6 @@ from collections import defaultdict from pathlib import Path from typing import Set -from pydantic import Field - from metagpt.actions import Action, WriteCode, WriteCodeReview, WriteTasks from metagpt.actions.fix_bug import FixBug from metagpt.actions.summarize_code import SummarizeCode @@ -69,11 +67,14 @@ class Engineer(Role): n_borg (int): Number of borgs. use_code_review (bool): Whether to use code review. """ + name: str = "Alex" profile: str = "Engineer" goal: str = "write elegant, readable, extensible, efficient code" - constraints: str = "the code should conform to standards like google-style and be modular and maintainable. " \ - "Use same language as user requirement" + constraints: str = ( + "the code should conform to standards like google-style and be modular and maintainable. " + "Use same language as user requirement" + ) n_borg: int = 1 use_code_review: bool = False code_todos: list = [] @@ -212,7 +213,7 @@ class Engineer(Role): @staticmethod async def _new_coding_context( - filename, src_file_repo, task_file_repo, design_file_repo, dependency + filename, src_file_repo, task_file_repo, design_file_repo, dependency ) -> CodingContext: old_code_doc = await src_file_repo.get(filename) if not old_code_doc: diff --git a/metagpt/roles/product_manager.py b/metagpt/roles/product_manager.py index 6369688a5..c794ad2eb 100644 --- a/metagpt/roles/product_manager.py +++ b/metagpt/roles/product_manager.py @@ -7,7 +7,6 @@ @Modified By: mashenquan, 2023/11/27. Add `PrepareDocuments` action according to Section 2.2.3.5.1 of RFC 135. """ -from pydantic import Field from metagpt.actions import UserRequirement, WritePRD from metagpt.actions.prepare_documents import PrepareDocuments @@ -25,6 +24,7 @@ class ProductManager(Role): goal (str): Goal of the product manager. constraints (str): Constraints or limitations for the product manager. """ + name: str = "Alice" profile: str = "Product Manager" goal: str = "efficiently create a successful product that meets market demands and user expectations" diff --git a/metagpt/roles/project_manager.py b/metagpt/roles/project_manager.py index bf572d1f8..1fad4afc2 100644 --- a/metagpt/roles/project_manager.py +++ b/metagpt/roles/project_manager.py @@ -5,7 +5,6 @@ @Author : alexanderwu @File : project_manager.py """ -from pydantic import Field from metagpt.actions import WriteTasks from metagpt.actions.design_api import WriteDesign @@ -22,10 +21,13 @@ class ProjectManager(Role): goal (str): Goal of the project manager. constraints (str): Constraints or limitations for the project manager. """ + name: str = "Eve" profile: str = "Project Manager" - goal: str = "break down tasks according to PRD/technical design, generate a task list, and analyze task " \ - "dependencies to start with the prerequisite modules" + goal: str = ( + "break down tasks according to PRD/technical design, generate a task list, and analyze task " + "dependencies to start with the prerequisite modules" + ) constraints: str = "use same language as user requirement" def __init__(self, **kwargs) -> None: diff --git a/metagpt/roles/qa_engineer.py b/metagpt/roles/qa_engineer.py index 369e3dc63..5e509300b 100644 --- a/metagpt/roles/qa_engineer.py +++ b/metagpt/roles/qa_engineer.py @@ -15,13 +15,8 @@ of SummarizeCode. """ -from pydantic import Field -from metagpt.actions import ( - DebugError, - RunCode, - WriteTest, -) +from metagpt.actions import DebugError, RunCode, WriteTest from metagpt.actions.summarize_code import SummarizeCode from metagpt.config import CONFIG from metagpt.const import ( @@ -40,8 +35,9 @@ class QaEngineer(Role): name: str = "Edward" profile: str = "QaEngineer" goal: str = "Write comprehensive and robust tests to ensure codes will work as expected without bugs" - constraints: str = "The test code you write should conform to code standard like PEP8, be modular, " \ - "easy to read and maintain" + constraints: str = ( + "The test code you write should conform to code standard like PEP8, be modular, " "easy to read and maintain" + ) test_round_allowed: int = 5 def __init__(self, **kwargs): @@ -118,7 +114,8 @@ class QaEngineer(Role): ) run_code_context.code = None run_code_context.test_code = None - recipient = parse_recipient(result.summary) # the recipient might be Engineer or myself + # the recipient might be Engineer or myself + recipient = parse_recipient(result.summary) mappings = {"Engineer": "Alex", "QaEngineer": "Edward"} self.publish_message( Message( diff --git a/metagpt/roles/role.py b/metagpt/roles/role.py index f87c4e250..8c5743467 100644 --- a/metagpt/roles/role.py +++ b/metagpt/roles/role.py @@ -23,7 +23,7 @@ from __future__ import annotations from enum import Enum from pathlib import Path -from typing import Iterable, Set, Type, Any +from typing import Any, Iterable, Set, Type from pydantic import BaseModel, Field @@ -37,7 +37,13 @@ from metagpt.logs import logger from metagpt.memory import Memory from metagpt.provider.base_gpt_api import BaseGPTAPI from metagpt.schema import Message, MessageQueue -from metagpt.utils.common import any_to_str, read_json_file, write_json_file, import_class, role_raise_decorator +from metagpt.utils.common import ( + any_to_str, + import_class, + read_json_file, + role_raise_decorator, + write_json_file, +) from metagpt.utils.repair_llm_raw_output import extract_state_value_from_output PREFIX_TEMPLATE = """You are a {profile}, named {name}, your goal is {goal}, and the constraint is {constraints}. """ @@ -82,18 +88,22 @@ class RoleReactMode(str, Enum): class RoleContext(BaseModel): """Role Runtime Context""" + # # env exclude=True to avoid `RecursionError: maximum recursion depth exceeded in comparison` env: "Environment" = Field(default=None, exclude=True) # TODO judge if ser&deser - msg_buffer: MessageQueue = Field(default_factory=MessageQueue, - exclude=True) # Message Buffer with Asynchronous Updates + msg_buffer: MessageQueue = Field( + default_factory=MessageQueue, exclude=True + ) # Message Buffer with Asynchronous Updates memory: Memory = Field(default_factory=Memory) # long_term_memory: LongTermMemory = Field(default_factory=LongTermMemory) state: int = Field(default=-1) # -1 indicates initial or termination state where todo is None todo: Action = Field(default=None, exclude=True) watch: set[str] = Field(default_factory=set) news: list[Type[Message]] = Field(default=[], exclude=True) # TODO not used - react_mode: RoleReactMode = RoleReactMode.REACT # see `Role._set_react_mode` for definitions of the following two attributes + react_mode: RoleReactMode = ( + RoleReactMode.REACT + ) # see `Role._set_react_mode` for definitions of the following two attributes max_react_loop: int = 1 class Config: @@ -120,6 +130,7 @@ role_subclass_registry = {} class Role(BaseModel): """Role/Agent""" + name: str = "" profile: str = "" goal: str = "" @@ -145,7 +156,7 @@ class Role(BaseModel): "_states": [], "_actions": [], "_rc": RoleContext(), - "_subscription": set() + "_subscription": set(), } __hash__ = object.__hash__ # support Role as hashable type in `Environment.members` @@ -206,14 +217,14 @@ class Role(BaseModel): return f"{self.name}({self.profile})" def serialize(self, stg_path: Path = None): - stg_path = SERDESER_PATH.joinpath(f"team/environment/roles/{self.__class__.__name__}_{self.name}") \ - if stg_path is None else stg_path + stg_path = ( + SERDESER_PATH.joinpath(f"team/environment/roles/{self.__class__.__name__}_{self.name}") + if stg_path is None + else stg_path + ) role_info = self.dict(exclude={"_rc": {"memory": True, "msg_buffer": True}, "_llm": True}) - role_info.update({ - "role_class": self.__class__.__name__, - "module_name": self.__module__ - }) + role_info.update({"role_class": self.__class__.__name__, "module_name": self.__module__}) role_info_path = stg_path.joinpath("role_info.json") write_json_file(role_info_path, role_info) @@ -221,7 +232,7 @@ class Role(BaseModel): @classmethod def deserialize(cls, stg_path: Path) -> "Role": - """ stg_path = ./storage/team/environment/roles/{role_class}_{role_name}""" + """stg_path = ./storage/team/environment/roles/{role_class}_{role_name}""" role_info_path = stg_path.joinpath("role_info.json") role_info = read_json_file(role_info_path) @@ -328,12 +339,9 @@ class Role(BaseModel): """Get the role prefix""" if self.desc: return self.desc - return PREFIX_TEMPLATE.format(**{ - "profile": self.profile, - "name": self.name, - "goal": self.goal, - "constraints": self.constraints - }) + return PREFIX_TEMPLATE.format( + **{"profile": self.profile, "name": self.name, "goal": self.goal, "constraints": self.constraints} + ) async def _think(self) -> None: """Think about what to do and decide on the next action""" diff --git a/metagpt/roles/sales.py b/metagpt/roles/sales.py index fd5a42915..ba0a6fc6b 100644 --- a/metagpt/roles/sales.py +++ b/metagpt/roles/sales.py @@ -7,7 +7,6 @@ """ from typing import Optional -from pydantic import Field from metagpt.actions import SearchAndSummarize from metagpt.roles import Role @@ -15,7 +14,6 @@ from metagpt.tools import SearchEngineType class Sales(Role): - name: str = "Xiaomei" profile: str = "Retail sales guide" desc: str = "I am a sales guide in retail. My name is Xiaomei. I will answer some customer questions next, and I " diff --git a/metagpt/roles/searcher.py b/metagpt/roles/searcher.py index a5c399f47..a2136064f 100644 --- a/metagpt/roles/searcher.py +++ b/metagpt/roles/searcher.py @@ -35,7 +35,7 @@ class Searcher(Role): goal: str = "Provide search services for users" constraints: str = "Answer is rich and complete" engine: SearchEngineType = SearchEngineType.SERPAPI_GOOGLE - + def __init__(self, **kwargs) -> None: """ Initializes the Searcher role with given attributes. diff --git a/metagpt/schema.py b/metagpt/schema.py index 5103a4f20..4a9df7fe2 100644 --- a/metagpt/schema.py +++ b/metagpt/schema.py @@ -23,7 +23,7 @@ from abc import ABC from asyncio import Queue, QueueEmpty, wait_for from json import JSONDecodeError from pathlib import Path -from typing import Dict, List, Optional, Set, Type, TypedDict, TypeVar, Any +from typing import Any, Dict, List, Optional, Set, Type, TypedDict, TypeVar from pydantic import BaseModel, Field @@ -38,9 +38,12 @@ from metagpt.const import ( ) from metagpt.logs import logger from metagpt.utils.common import any_to_str, any_to_str_set, import_class -from metagpt.utils.serialize import actionoutout_schema_to_mapping, actionoutput_mapping_to_str, \ - actionoutput_str_to_mapping from metagpt.utils.exceptions import handle_exception +from metagpt.utils.serialize import ( + actionoutout_schema_to_mapping, + actionoutput_mapping_to_str, + actionoutput_str_to_mapping, +) class RawMessage(TypedDict): @@ -119,8 +122,9 @@ class Message(BaseModel): kwargs["instruct_content"] = ic_new kwargs["id"] = kwargs.get("id", uuid.uuid4().hex) - kwargs["cause_by"] = any_to_str(kwargs.get("cause_by", - import_class("UserRequirement", "metagpt.actions.add_requirement"))) + kwargs["cause_by"] = any_to_str( + kwargs.get("cause_by", import_class("UserRequirement", "metagpt.actions.add_requirement")) + ) kwargs["sent_from"] = any_to_str(kwargs.get("sent_from", "")) kwargs["send_to"] = any_to_str_set(kwargs.get("send_to", {MESSAGE_ROUTE_TO_ALL})) super(Message, self).__init__(**kwargs) @@ -138,7 +142,7 @@ class Message(BaseModel): super().__setattr__(key, new_val) def dict(self, *args, **kwargs) -> "DictStrAny": - """ overwrite the `dict` to dump dynamic pydantic model""" + """overwrite the `dict` to dump dynamic pydantic model""" obj_dict = super(Message, self).dict(*args, **kwargs) ic = self.instruct_content if ic: @@ -208,9 +212,7 @@ class MessageQueue(BaseModel): _queue: Queue = Field(default_factory=Queue) - _private_attributes = { - "_queue": Queue() - } + _private_attributes = {"_queue": Queue()} class Config: arbitrary_types_allowed = True diff --git a/metagpt/startup.py b/metagpt/startup.py index 59e0cb199..767a19a9d 100644 --- a/metagpt/startup.py +++ b/metagpt/startup.py @@ -1,9 +1,9 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- import asyncio +from pathlib import Path import typer -from pathlib import Path from metagpt.config import CONFIG @@ -32,7 +32,7 @@ def startup( help="The maximum number of times the 'SummarizeCode' action is automatically invoked, with -1 indicating " "unlimited. This parameter is used for debugging the workflow.", ), - recover_path: str = typer.Option(default=None, help="recover the project from existing serialized storage") + recover_path: str = typer.Option(default=None, help="recover the project from existing serialized storage"), ): """Run a startup. Be a boss.""" from metagpt.roles import ( diff --git a/metagpt/team.py b/metagpt/team.py index 0c1efb812..8b92ed47a 100644 --- a/metagpt/team.py +++ b/metagpt/team.py @@ -8,20 +8,24 @@ Section 2.2.3.3 of RFC 135. """ -from pathlib import Path import warnings +from pathlib import Path from pydantic import BaseModel, Field from metagpt.actions import UserRequirement from metagpt.config import CONFIG -from metagpt.const import MESSAGE_ROUTE_TO_ALL -from metagpt.const import SERDESER_PATH +from metagpt.const import MESSAGE_ROUTE_TO_ALL, SERDESER_PATH from metagpt.environment import Environment from metagpt.logs import logger from metagpt.roles import Role from metagpt.schema import Message -from metagpt.utils.common import NoMoneyException, read_json_file, write_json_file, serialize_decorator +from metagpt.utils.common import ( + NoMoneyException, + read_json_file, + serialize_decorator, + write_json_file, +) class Team(BaseModel): @@ -51,12 +55,14 @@ class Team(BaseModel): @classmethod def deserialize(cls, stg_path: Path) -> "Team": - """ stg_path = ./storage/team """ + """stg_path = ./storage/team""" # recover team_info team_info_path = stg_path.joinpath("team_info.json") if not team_info_path.exists(): - raise FileNotFoundError("recover storage meta file `team_info.json` not exist, " - "not to recover and please start a new project.") + raise FileNotFoundError( + "recover storage meta file `team_info.json` not exist, " + "not to recover and please start a new project." + ) team_info: dict = read_json_file(team_info_path) diff --git a/metagpt/utils/common.py b/metagpt/utils/common.py index e123e8fd9..ea3316d66 100644 --- a/metagpt/utils/common.py +++ b/metagpt/utils/common.py @@ -22,8 +22,7 @@ import re import traceback import typing from pathlib import Path -from typing import Any -from typing import List, Tuple, Union, get_args, get_origin +from typing import Any, List, Tuple, Union, get_args, get_origin import aiofiles import loguru @@ -219,7 +218,7 @@ class OutputParser: if start_index != -1 and end_index != -1: # Extract the structure part - structure_text = text[start_index: end_index + 1] + structure_text = text[start_index : end_index + 1] try: # Attempt to convert the text to a Python data type using ast.literal_eval @@ -439,7 +438,7 @@ def read_json_file(json_file: str, encoding=None) -> list[Any]: with open(json_file, "r", encoding=encoding) as fin: try: data = json.load(fin) - except Exception as exp: + except Exception: raise ValueError(f"read json file: {json_file} failed") return data @@ -474,9 +473,9 @@ def serialize_decorator(func): try: result = await func(self, *args, **kwargs) return result - except KeyboardInterrupt as kbi: + except KeyboardInterrupt: logger.error(f"KeyboardInterrupt occurs, start to serialize the project, exp:\n{format_trackback_info()}") - except Exception as exp: + except Exception: logger.error(f"Exception occurs, start to serialize the project, exp:\n{format_trackback_info()}") self.serialize() # Team.serialize @@ -491,14 +490,18 @@ def role_raise_decorator(func): logger.error(f"KeyboardInterrupt: {kbi} occurs, start to serialize the project") if self.latest_observed_msg: self._rc.memory.delete(self.latest_observed_msg) - raise Exception(format_trackback_info(limit=None)) # raise again to make it captured outside - except Exception as exp: + # raise again to make it captured outside + raise Exception(format_trackback_info(limit=None)) + except Exception: if self.latest_observed_msg: - logger.warning("There is a exception in role's execution, in order to resume, " - "we delete the newest role communication message in the role's memory.") + logger.warning( + "There is a exception in role's execution, in order to resume, " + "we delete the newest role communication message in the role's memory." + ) # remove role newest observed msg to make it observed again self._rc.memory.delete(self.latest_observed_msg) - raise Exception(format_trackback_info(limit=None)) # raise again to make it captured outside + # raise again to make it captured outside + raise Exception(format_trackback_info(limit=None)) return wrapper diff --git a/tests/metagpt/serialize_deserialize/test_environment.py b/tests/metagpt/serialize_deserialize/test_environment.py index b741b9c4b..096c1dd68 100644 --- a/tests/metagpt/serialize_deserialize/test_environment.py +++ b/tests/metagpt/serialize_deserialize/test_environment.py @@ -11,7 +11,11 @@ from metagpt.environment import Environment from metagpt.roles.project_manager import ProjectManager from metagpt.schema import Message from metagpt.utils.common import any_to_str -from tests.metagpt.serialize_deserialize.test_serdeser_base import RoleC, ActionOK, serdeser_path +from tests.metagpt.serialize_deserialize.test_serdeser_base import ( + ActionOK, + RoleC, + serdeser_path, +) def test_env_serialize(): @@ -35,10 +39,7 @@ def test_environment_serdeser(): ic_obj = ActionNode.create_model_class("prd", out_mapping) message = Message( - content="prd", - instruct_content=ic_obj(**out_data), - role="product manager", - cause_by=any_to_str(UserRequirement) + content="prd", instruct_content=ic_obj(**out_data), role="product manager", cause_by=any_to_str(UserRequirement) ) environment = Environment() diff --git a/tests/metagpt/serialize_deserialize/test_memory.py b/tests/metagpt/serialize_deserialize/test_memory.py index 0d756518b..5a40f5c3b 100644 --- a/tests/metagpt/serialize_deserialize/test_memory.py +++ b/tests/metagpt/serialize_deserialize/test_memory.py @@ -14,17 +14,14 @@ from tests.metagpt.serialize_deserialize.test_serdeser_base import serdeser_path def test_memory_serdeser(): - msg1 = Message(role="Boss", - content="write a snake game", - cause_by=UserRequirement) + msg1 = Message(role="Boss", content="write a snake game", cause_by=UserRequirement) out_mapping = {"field2": (list[str], ...)} out_data = {"field2": ["field2 value1", "field2 value2"]} ic_obj = ActionNode.create_model_class("system_design", out_mapping) - msg2 = Message(role="Architect", - instruct_content=ic_obj(**out_data), - content="system design content", - cause_by=WriteDesign) + msg2 = Message( + role="Architect", instruct_content=ic_obj(**out_data), content="system design content", cause_by=WriteDesign + ) memory = Memory() memory.add_batch([msg1, msg2]) @@ -40,17 +37,14 @@ def test_memory_serdeser(): def test_memory_serdeser_save(): - msg1 = Message(role="User", - content="write a 2048 game", - cause_by=UserRequirement) + msg1 = Message(role="User", content="write a 2048 game", cause_by=UserRequirement) out_mapping = {"field1": (list[str], ...)} out_data = {"field1": ["field1 value1", "field1 value2"]} ic_obj = ActionNode.create_model_class("system_design", out_mapping) - msg2 = Message(role="Architect", - instruct_content=ic_obj(**out_data), - content="system design content", - cause_by=WriteDesign) + msg2 = Message( + role="Architect", instruct_content=ic_obj(**out_data), content="system design content", cause_by=WriteDesign + ) memory = Memory() memory.add_batch([msg1, msg2]) diff --git a/tests/metagpt/serialize_deserialize/test_role.py b/tests/metagpt/serialize_deserialize/test_role.py index 88c7f7d8b..72da8a6fc 100644 --- a/tests/metagpt/serialize_deserialize/test_role.py +++ b/tests/metagpt/serialize_deserialize/test_role.py @@ -16,7 +16,12 @@ from metagpt.roles.product_manager import ProductManager from metagpt.roles.role import Role from metagpt.schema import Message from metagpt.utils.common import format_trackback_info -from tests.metagpt.serialize_deserialize.test_serdeser_base import RoleA, RoleB, RoleC, serdeser_path +from tests.metagpt.serialize_deserialize.test_serdeser_base import ( + RoleA, + RoleB, + RoleC, + serdeser_path, +) def test_roles(): @@ -75,12 +80,10 @@ async def test_role_serdeser_interrupt(): role_c = RoleC() shutil.rmtree(SERDESER_PATH.joinpath("team"), ignore_errors=True) - stg_path = SERDESER_PATH.joinpath(f"team", "environment", "roles", "{role_c.__class__.__name__}_{role_c.name}") + stg_path = SERDESER_PATH.joinpath("team", "environment", "roles", f"{role_c.__class__.__name__}_{role_c.name}") try: - await role_c.run( - with_message=Message(content="demo", cause_by=UserRequirement) - ) - except Exception as exp: + await role_c.run(with_message=Message(content="demo", cause_by=UserRequirement)) + except Exception: logger.error(f"Exception in `role_a.run`, detail: {format_trackback_info()}") role_c.serialize(stg_path) @@ -90,6 +93,4 @@ async def test_role_serdeser_interrupt(): assert new_role_a._rc.state == 1 with pytest.raises(Exception): - await role_c.run( - with_message=Message(content="demo", cause_by=UserRequirement) - ) + await role_c.run(with_message=Message(content="demo", cause_by=UserRequirement)) diff --git a/tests/metagpt/serialize_deserialize/test_schema.py b/tests/metagpt/serialize_deserialize/test_schema.py index 72b7153a7..0358265a9 100644 --- a/tests/metagpt/serialize_deserialize/test_schema.py +++ b/tests/metagpt/serialize_deserialize/test_schema.py @@ -14,12 +14,7 @@ def test_message_serdeser(): out_data = {"field3": "field3 value3", "field4": ["field4 value1", "field4 value2"]} ic_obj = ActionNode.create_model_class("code", out_mapping) - message = Message( - content="code", - instruct_content=ic_obj(**out_data), - role="engineer", - cause_by=WriteCode - ) + message = Message(content="code", instruct_content=ic_obj(**out_data), role="engineer", cause_by=WriteCode) ser_data = message.dict() assert ser_data["cause_by"] == "metagpt.actions.write_code.WriteCode" assert ser_data["instruct_content"]["class"] == "code" @@ -31,14 +26,11 @@ def test_message_serdeser(): def test_message_without_postprocess(): - """ to explain `instruct_content` should be postprocessed """ + """to explain `instruct_content` should be postprocessed""" out_mapping = {"field1": (list[str], ...)} out_data = {"field1": ["field1 value1", "field1 value2"]} ic_obj = ActionNode.create_model_class("code", out_mapping) - message = MockMessage( - content="code", - instruct_content=ic_obj(**out_data) - ) + message = MockMessage(content="code", instruct_content=ic_obj(**out_data)) ser_data = message.dict() assert ser_data["instruct_content"] == {"field1": ["field1 value1", "field1 value2"]} diff --git a/tests/metagpt/serialize_deserialize/test_serdeser_base.py b/tests/metagpt/serialize_deserialize/test_serdeser_base.py index eac083cf9..a66813489 100644 --- a/tests/metagpt/serialize_deserialize/test_serdeser_base.py +++ b/tests/metagpt/serialize_deserialize/test_serdeser_base.py @@ -16,7 +16,8 @@ serdeser_path = Path(__file__).absolute().parent.joinpath("..", "..", "data", "s class MockMessage(BaseModel): - """ to test normal dict without postprocess """ + """to test normal dict without postprocess""" + content: str = "" instruct_content: BaseModel = Field(default=None) @@ -26,9 +27,7 @@ class ActionPass(Action): async def run(self, messages: list["Message"]) -> ActionOutput: await asyncio.sleep(5) # sleep to make other roles can watch the executed Message - output_mapping = { - "result": (str, ...) - } + output_mapping = {"result": (str, ...)} pass_class = ActionNode.create_model_class("pass", output_mapping) pass_output = ActionOutput("ActionPass run passed", pass_class(**{"result": "pass result"})) diff --git a/tests/metagpt/serialize_deserialize/test_team.py b/tests/metagpt/serialize_deserialize/test_team.py index db6001325..dc41fa4ed 100644 --- a/tests/metagpt/serialize_deserialize/test_team.py +++ b/tests/metagpt/serialize_deserialize/test_team.py @@ -8,10 +8,16 @@ import shutil import pytest from metagpt.const import SERDESER_PATH -from metagpt.roles import ProjectManager, ProductManager, Architect -from metagpt.team import Team from metagpt.logs import logger -from tests.metagpt.serialize_deserialize.test_serdeser_base import RoleA, RoleB, RoleC, serdeser_path, ActionOK +from metagpt.roles import Architect, ProductManager, ProjectManager +from metagpt.team import Team +from tests.metagpt.serialize_deserialize.test_serdeser_base import ( + ActionOK, + RoleA, + RoleB, + RoleC, + serdeser_path, +) def test_team_deserialize(): @@ -110,10 +116,8 @@ async def test_team_recover_multi_roles_save(): role_a = RoleA() role_b = RoleB() - assert role_a.subscription == {"tests.metagpt.serialize_deserialize.test_serdeser_base.RoleA", - "RoleA"} - assert role_b.subscription == {"tests.metagpt.serialize_deserialize.test_serdeser_base.RoleB", - "RoleB"} + assert role_a.subscription == {"tests.metagpt.serialize_deserialize.test_serdeser_base.RoleA", "RoleA"} + assert role_b.subscription == {"tests.metagpt.serialize_deserialize.test_serdeser_base.RoleB", "RoleB"} assert role_b._rc.watch == {"tests.metagpt.serialize_deserialize.test_serdeser_base.ActionPass"} company = Team() diff --git a/tests/metagpt/serialize_deserialize/test_write_code.py b/tests/metagpt/serialize_deserialize/test_write_code.py index 0114c48da..65b8f456a 100644 --- a/tests/metagpt/serialize_deserialize/test_write_code.py +++ b/tests/metagpt/serialize_deserialize/test_write_code.py @@ -19,8 +19,9 @@ def test_write_design_serialize(): @pytest.mark.asyncio async def test_write_code_deserialize(): - context = CodingContext(filename="test_code.py", - design_doc=Document(content="write add function to calculate two numbers")) + context = CodingContext( + filename="test_code.py", design_doc=Document(content="write add function to calculate two numbers") + ) doc = Document(content=context.json()) action = WriteCode(context=doc) serialized_data = action.dict() diff --git a/tests/metagpt/serialize_deserialize/test_write_code_review.py b/tests/metagpt/serialize_deserialize/test_write_code_review.py index a15b744db..01026590c 100644 --- a/tests/metagpt/serialize_deserialize/test_write_code_review.py +++ b/tests/metagpt/serialize_deserialize/test_write_code_review.py @@ -18,7 +18,7 @@ def div(a: int, b: int = 0): context = CodingContext( filename="test_op.py", design_doc=Document(content="divide two numbers"), - code_doc=Document(content=code_content) + code_doc=Document(content=code_content), ) action = WriteCodeReview(context=context) diff --git a/tests/metagpt/test_environment.py b/tests/metagpt/test_environment.py index ee322368e..56e2b4fc3 100644 --- a/tests/metagpt/test_environment.py +++ b/tests/metagpt/test_environment.py @@ -6,9 +6,10 @@ @File : test_environment.py """ -import pytest from pathlib import Path +import pytest + from metagpt.actions import UserRequirement from metagpt.environment import Environment from metagpt.logs import logger @@ -16,7 +17,6 @@ from metagpt.manager import Manager from metagpt.roles import Architect, ProductManager, Role from metagpt.schema import Message - serdeser_path = Path(__file__).absolute().parent.joinpath("../data/serdeser_storage") @@ -26,23 +26,16 @@ def env(): def test_add_role(env: Environment): - role = ProductManager(name="Alice", - profile="product manager", - goal="create a new product", - constraints="limited resources") + role = ProductManager( + name="Alice", profile="product manager", goal="create a new product", constraints="limited resources" + ) env.add_role(role) assert env.get_role(role.profile) == role def test_get_roles(env: Environment): - role1 = Role(name="Alice", - profile="product manager", - goal="create a new product", - constraints="limited resources") - role2 = Role(name="Bob", - profile="engineer", - goal="develop the new product", - constraints="short deadline") + role1 = Role(name="Alice", profile="product manager", goal="create a new product", constraints="limited resources") + role2 = Role(name="Bob", profile="engineer", goal="develop the new product", constraints="short deadline") env.add_role(role1) env.add_role(role2) roles = env.get_roles() @@ -51,14 +44,10 @@ def test_get_roles(env: Environment): @pytest.mark.asyncio async def test_publish_and_process_message(env: Environment): - product_manager = ProductManager(name="Alice", - profile="Product Manager", - goal="做AI Native产品", - constraints="资源有限") - architect = Architect(name="Bob", - profile="Architect", - goal="设计一个可用、高效、较低成本的系统,包括数据结构与接口", - constraints="资源有限,需要节省成本") + product_manager = ProductManager(name="Alice", profile="Product Manager", goal="做AI Native产品", constraints="资源有限") + architect = Architect( + name="Bob", profile="Architect", goal="设计一个可用、高效、较低成本的系统,包括数据结构与接口", constraints="资源有限,需要节省成本" + ) env.add_roles([product_manager, architect]) diff --git a/tests/metagpt/test_schema.py b/tests/metagpt/test_schema.py index ef706abfa..1742757e8 100644 --- a/tests/metagpt/test_schema.py +++ b/tests/metagpt/test_schema.py @@ -9,12 +9,13 @@ """ import json + import pytest from metagpt.actions import Action -from metagpt.schema import AIMessage, Message, SystemMessage, UserMessage from metagpt.actions.action_node import ActionNode from metagpt.actions.write_code import WriteCode +from metagpt.schema import AIMessage, Message, SystemMessage, UserMessage from metagpt.utils.common import any_to_str @@ -77,24 +78,13 @@ def test_message_serdeser(): out_data = {"field3": "field3 value3", "field4": ["field4 value1", "field4 value2"]} ic_obj = ActionNode.create_model_class("code", out_mapping) - message = Message( - content="code", - instruct_content=ic_obj(**out_data), - role="engineer", - cause_by=WriteCode - ) + message = Message(content="code", instruct_content=ic_obj(**out_data), role="engineer", cause_by=WriteCode) message_dict = message.dict() assert message_dict["cause_by"] == "metagpt.actions.write_code.WriteCode" assert message_dict["instruct_content"] == { "class": "code", - "mapping": { - "field3": "(, Ellipsis)", - "field4": "(list[str], Ellipsis)" - }, - "value": { - "field3": "field3 value3", - "field4": ["field4 value1", "field4 value2"] - } + "mapping": {"field3": "(, Ellipsis)", "field4": "(list[str], Ellipsis)"}, + "value": {"field3": "field3 value3", "field4": ["field4 value1", "field4 value2"]}, } new_message = Message(**message_dict) diff --git a/tests/metagpt/test_team.py b/tests/metagpt/test_team.py index efd035bb2..930306b5e 100644 --- a/tests/metagpt/test_team.py +++ b/tests/metagpt/test_team.py @@ -2,8 +2,8 @@ # -*- coding: utf-8 -*- # @Desc : unittest of team -from metagpt.team import Team from metagpt.roles.project_manager import ProjectManager +from metagpt.team import Team def test_team(): From f4198dc1116ff7ace820b56513556afe7e216354 Mon Sep 17 00:00:00 2001 From: geekan Date: Thu, 21 Dec 2023 11:03:13 +0800 Subject: [PATCH 145/167] refine action node and add some experiment --- metagpt/actions/action_node.py | 57 +-- metagpt/actions/write_code_an_draft.py | 591 +++++++++++++++++++++++++ metagpt/actions/write_review.py | 5 +- metagpt/utils/common.py | 3 +- tests/metagpt/test_prompt.py | 342 ++++++++++++++ 5 files changed, 968 insertions(+), 30 deletions(-) create mode 100644 metagpt/actions/write_code_an_draft.py create mode 100644 tests/metagpt/test_prompt.py diff --git a/metagpt/actions/action_node.py b/metagpt/actions/action_node.py index 4376e09ed..8a0aaf146 100644 --- a/metagpt/actions/action_node.py +++ b/metagpt/actions/action_node.py @@ -9,7 +9,7 @@ NOTE: You should use typing.List instead of list to do type annotation. Because we can use typing to extract the type of the node, but we cannot use built-in list to extract. """ import json -from typing import Any, Dict, Generic, List, Optional, Tuple, Type, TypeVar +from typing import Any, Dict, List, Optional, Tuple, Type from pydantic import BaseModel, create_model, root_validator, validator from tenacity import retry, stop_after_attempt, wait_random_exponential @@ -19,10 +19,11 @@ from metagpt.logs import logger from metagpt.provider.postprecess.llm_output_postprecess import llm_output_postprecess from metagpt.utils.common import OutputParser, general_after_log -CONSTRAINT = """ -- Language: Please use the same language as the user input. -- Format: output wrapped inside [CONTENT][/CONTENT] as format example, nothing else. -""" +TAG = "CONTENT" + +LANGUAGE_CONSTRAINT = "Language: Please use the same language as the user input." +FORMAT_CONSTRAINT = f"Format: output wrapped inside [{TAG}][/{TAG}] like format example, nothing else." + SIMPLE_TEMPLATE = """ ## context @@ -33,28 +34,25 @@ SIMPLE_TEMPLATE = """ ## format example {example} -## nodes: ": # " +## nodes: ": # " {instruction} ## constraint {constraint} ## action -Fill in the above nodes based on the format example. +Follow instructions of nodes, generate output and make sure it follows the format example. """ -def dict_to_markdown(d, prefix="##", kv_sep="\n", postfix="\n"): +def dict_to_markdown(d, prefix="- ", kv_sep="\n", postfix="\n"): markdown_str = "" for key, value in d.items(): markdown_str += f"{prefix}{key}{kv_sep}{value}{postfix}" return markdown_str -T = TypeVar("T") - - -class ActionNode(Generic[T]): +class ActionNode: """ActionNode is a tree of nodes.""" mode: str @@ -69,7 +67,7 @@ class ActionNode(Generic[T]): expected_type: Type # such as str / int / float etc. # context: str # everything in the history. instruction: str # the instructions should be followed. - example: T # example for In Context-Learning. + example: Any # example for In Context-Learning. # Action Output content: str @@ -80,7 +78,7 @@ class ActionNode(Generic[T]): key: str, expected_type: Type, instruction: str, - example: T, + example: Any, content: str = "", children: dict[str, "ActionNode"] = None, ): @@ -183,11 +181,11 @@ class ActionNode(Generic[T]): return node_dict - def compile_to(self, i: Dict, schema) -> str: + def compile_to(self, i: Dict, schema, kv_sep) -> str: if schema == "json": return json.dumps(i, indent=4) elif schema == "markdown": - return dict_to_markdown(i) + return dict_to_markdown(i, kv_sep=kv_sep) else: return str(i) @@ -196,26 +194,26 @@ class ActionNode(Generic[T]): return text if schema == "json": return f"[{tag}]\n" + text + f"\n[/{tag}]" - else: + else: # markdown return f"[{tag}]\n" + text + f"\n[/{tag}]" - def _compile_f(self, schema, mode, tag, format_func) -> str: + def _compile_f(self, schema, mode, tag, format_func, kv_sep) -> str: nodes = self.to_dict(format_func=format_func, mode=mode) - text = self.compile_to(nodes, schema) + text = self.compile_to(nodes, schema, kv_sep) return self.tagging(text, schema, tag) - def compile_instruction(self, schema="raw", mode="children", tag="") -> str: + def compile_instruction(self, schema="markdown", mode="children", tag="") -> str: """compile to raw/json/markdown template with all/root/children nodes""" format_func = lambda i: f"{i.expected_type} # {i.instruction}" - return self._compile_f(schema, mode, tag, format_func) + return self._compile_f(schema, mode, tag, format_func, kv_sep=": ") - def compile_example(self, schema="raw", mode="children", tag="") -> str: + def compile_example(self, schema="json", mode="children", tag="") -> str: """compile to raw/json/markdown examples with all/root/children nodes""" # 这里不能使用f-string,因为转译为str后再json.dumps会额外加上引号,无法作为有效的example # 错误示例:"File list": "['main.py', 'const.py', 'game.py']", 注意这里值不是list,而是str format_func = lambda i: i.example - return self._compile_f(schema, mode, tag, format_func) + return self._compile_f(schema, mode, tag, format_func, kv_sep="\n") def compile(self, context, schema="json", mode="children", template=SIMPLE_TEMPLATE) -> str: """ @@ -228,9 +226,16 @@ class ActionNode(Generic[T]): # FIXME: json instruction会带来格式问题,如:"Project name": "web_2048 # 项目名称使用下划线", # compile example暂时不支持markdown self.instruction = self.compile_instruction(schema="markdown", mode=mode) - self.example = self.compile_example(schema=schema, tag="CONTENT", mode=mode) + self.example = self.compile_example(schema=schema, tag=TAG, mode=mode) + # nodes = ", ".join(self.to_dict(mode=mode).keys()) + constraints = [LANGUAGE_CONSTRAINT, FORMAT_CONSTRAINT] + constraint = "\n".join(constraints) + prompt = template.format( - context=context, example=self.example, instruction=self.instruction, constraint=CONSTRAINT + context=context, + example=self.example, + instruction=self.instruction, + constraint=constraint, ) return prompt @@ -253,7 +258,7 @@ class ActionNode(Generic[T]): output_class = self.create_model_class(output_class_name, output_data_mapping) if schema == "json": - parsed_data = llm_output_postprecess(output=content, schema=output_class.schema(), req_key="[/CONTENT]") + parsed_data = llm_output_postprecess(output=content, schema=output_class.schema(), req_key=f"[/{TAG}]") else: # using markdown parser parsed_data = OutputParser.parse_data_with_mapping(content, output_data_mapping) diff --git a/metagpt/actions/write_code_an_draft.py b/metagpt/actions/write_code_an_draft.py new file mode 100644 index 000000000..968c8924b --- /dev/null +++ b/metagpt/actions/write_code_an_draft.py @@ -0,0 +1,591 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Author : alexanderwu +@File : write_review.py +""" +import asyncio +from typing import List + +from metagpt.actions import Action +from metagpt.actions.action_node import ActionNode + +REVIEW = ActionNode( + key="Review", + expected_type=List[str], + instruction="Act as an experienced reviewer and critically assess the given output. Provide specific and" + " constructive feedback, highlighting areas for improvement and suggesting changes.", + example=[ + "The logic in the function `calculate_total` seems flawed. Shouldn't it consider the discount rate as well?", + "The TODO function is not implemented yet? Should we implement it before commit?", + ], +) + +LGTM = ActionNode( + key="LGTM", + expected_type=str, + instruction="LGTM/LBTM. If the code is fully implemented, " + "give a LGTM (Looks Good To Me), otherwise provide a LBTM (Looks Bad To Me).", + example="LBTM", +) + +ACTIONS = ActionNode( + key="Actions", + expected_type=str, + instruction="Based on the code review outcome, suggest actionable steps. This can include code changes, " + "refactoring suggestions, or any follow-up tasks.", + example="""1. Refactor the `process_data` method to improve readability and efficiency. +2. Cover edge cases in the `validate_user` function. +3. Implement a the TODO in the `calculate_total` function. +4. Fix the `handle_events` method to update the game state only if a move is successful. + ```python + def handle_events(self): + for event in pygame.event.get(): + if event.type == pygame.QUIT: + return False + if event.type == pygame.KEYDOWN: + moved = False + if event.key == pygame.K_UP: + moved = self.game.move('UP') + elif event.key == pygame.K_DOWN: + moved = self.game.move('DOWN') + elif event.key == pygame.K_LEFT: + moved = self.game.move('LEFT') + elif event.key == pygame.K_RIGHT: + moved = self.game.move('RIGHT') + if moved: + # Update the game state only if a move was successful + self.render() + return True + ``` +""", +) + +WRITE_DRAFT = ActionNode( + key="WriteDraft", + expected_type=str, + instruction="Could you write draft code for move function in order to implement it?", + example="Draft: ...", +) + + +WRITE_MOVE_FUNCTION = ActionNode( + key="WriteFunction", + expected_type=str, + instruction="write code for the function not implemented.", + example=""" +```Code +... +``` +""", +) + + +REWRITE_CODE = ActionNode( + key="RewriteCode", + expected_type=str, + instruction="""rewrite code based on the Review and Actions""", + example=""" +```python +## example.py +def calculate_total(price, quantity): + total = price * quantity +``` +""", +) + + +CODE_REVIEW_CONTEXT = """ +# System +Role: You are a professional software engineer, and your main task is to review and revise the code. You need to ensure that the code conforms to the google-style standards, is elegantly designed and modularized, easy to read and maintain. +Language: Please use the same language as the user requirement, but the title and code should be still in English. For example, if the user speaks Chinese, the specific text of your answer should also be in Chinese. + +# Context +## System Design +{"Implementation approach": "我们将使用HTML、CSS和JavaScript来实现这个单机的响应式2048游戏。为了确保游戏性能流畅和响应式设计,我们会选择使用Vue.js框架,因为它易于上手且适合构建交互式界面。我们还将使用localStorage来记录玩家的最高分。", "File list": ["index.html", "styles.css", "main.js", "game.js", "storage.js"], "Data structures and interfaces": "classDiagram\ + class Game {\ + -board Array\ + -score Number\ + -bestScore Number\ + +constructor()\ + +startGame()\ + +move(direction: String)\ + +getBoard() Array\ + +getScore() Number\ + +getBestScore() Number\ + +setBestScore(score: Number)\ + }\ + class Storage {\ + +getBestScore() Number\ + +setBestScore(score: Number)\ + }\ + class Main {\ + +init()\ + +bindEvents()\ + }\ + Game --> Storage : uses\ + Main --> Game : uses", "Program call flow": "sequenceDiagram\ + participant M as Main\ + participant G as Game\ + participant S as Storage\ + M->>G: init()\ + G->>S: getBestScore()\ + S-->>G: return bestScore\ + M->>G: bindEvents()\ + M->>G: startGame()\ + loop Game Loop\ + M->>G: move(direction)\ + G->>S: setBestScore(score)\ + S-->>G: return\ + end", "Anything UNCLEAR": "目前项目要求明确,没有不清楚的地方。"} + +## Tasks +{"Required Python packages": ["无需Python包"], "Required Other language third-party packages": ["vue.js"], "Logic Analysis": [["index.html", "作为游戏的入口文件和主要的HTML结构"], ["styles.css", "包含所有的CSS样式,确保游戏界面美观"], ["main.js", "包含Main类,负责初始化游戏和绑定事件"], ["game.js", "包含Game类,负责游戏逻辑,如开始游戏、移动方块等"], ["storage.js", "包含Storage类,用于获取和设置玩家的最高分"]], "Task list": ["index.html", "styles.css", "storage.js", "game.js", "main.js"], "Full API spec": "", "Shared Knowledge": "\'game.js\' 包含游戏逻辑相关的函数,被 \'main.js\' 调用。", "Anything UNCLEAR": "目前项目要求明确,没有不清楚的地方。"} + +## Code Files +----- index.html + + + + + + 2048游戏 + + + + +
+

2048

+
+
+
分数
+
{{ score }}
+
+
+
最高分
+
{{ bestScore }}
+
+
+
+
+
+ {{ cell !== 0 ? cell : \'\' }} +
+
+
+ +
+ + + + + + + + +----- styles.css +/* styles.css */ +body, html { + margin: 0; + padding: 0; + font-family: \'Arial\', sans-serif; +} + +#app { + text-align: center; + font-size: 18px; + color: #776e65; +} + +h1 { + color: #776e65; + font-size: 72px; + font-weight: bold; + margin: 20px 0; +} + +.scores-container { + display: flex; + justify-content: center; + margin-bottom: 20px; +} + +.score-container, .best-container { + background: #bbada0; + padding: 10px; + border-radius: 5px; + margin: 0 10px; + min-width: 100px; + text-align: center; +} + +.score-header, .best-header { + color: #eee4da; + font-size: 18px; + margin-bottom: 5px; +} + +.game-container { + max-width: 500px; + margin: 0 auto 20px; + background: #bbada0; + padding: 15px; + border-radius: 10px; + position: relative; +} + +.grid-row { + display: flex; +} + +.grid-cell { + background: #cdc1b4; + width: 100px; + height: 100px; + margin: 5px; + display: flex; + justify-content: center; + align-items: center; + font-size: 35px; + font-weight: bold; + color: #776e65; + border-radius: 3px; +} + +/* Dynamic classes for different number cells */ +.number-cell-2 { + background: #eee4da; +} + +.number-cell-4 { + background: #ede0c8; +} + +.number-cell-8 { + background: #f2b179; + color: #f9f6f2; +} + +.number-cell-16 { + background: #f59563; + color: #f9f6f2; +} + +.number-cell-32 { + background: #f67c5f; + color: #f9f6f2; +} + +.number-cell-64 { + background: #f65e3b; + color: #f9f6f2; +} + +.number-cell-128 { + background: #edcf72; + color: #f9f6f2; +} + +.number-cell-256 { + background: #edcc61; + color: #f9f6f2; +} + +.number-cell-512 { + background: #edc850; + color: #f9f6f2; +} + +.number-cell-1024 { + background: #edc53f; + color: #f9f6f2; +} + +.number-cell-2048 { + background: #edc22e; + color: #f9f6f2; +} + +/* Larger numbers need smaller font sizes */ +.number-cell-1024, .number-cell-2048 { + font-size: 30px; +} + +button { + background-color: #8f7a66; + color: #f9f6f2; + border: none; + border-radius: 3px; + padding: 10px 20px; + font-size: 18px; + cursor: pointer; + outline: none; +} + +button:hover { + background-color: #9f8b76; +} + +----- storage.js +## storage.js +class Storage { + // 获取最高分 + getBestScore() { + // 尝试从localStorage中获取最高分,如果不存在则默认为0 + const bestScore = localStorage.getItem(\'bestScore\'); + return bestScore ? Number(bestScore) : 0; + } + + // 设置最高分 + setBestScore(score) { + // 将最高分设置到localStorage中 + localStorage.setItem(\'bestScore\', score.toString()); + } +} + + + +## Code to be Reviewed: game.js +```Code +## game.js +class Game { + constructor() { + this.board = this.createEmptyBoard(); + this.score = 0; + this.bestScore = 0; + } + + createEmptyBoard() { + const board = []; + for (let i = 0; i < 4; i++) { + board[i] = [0, 0, 0, 0]; + } + return board; + } + + startGame() { + this.board = this.createEmptyBoard(); + this.score = 0; + this.addRandomTile(); + this.addRandomTile(); + } + + addRandomTile() { + let emptyCells = []; + for (let r = 0; r < 4; r++) { + for (let c = 0; c < 4; c++) { + if (this.board[r][c] === 0) { + emptyCells.push({ r, c }); + } + } + } + if (emptyCells.length > 0) { + let randomCell = emptyCells[Math.floor(Math.random() * emptyCells.length)]; + this.board[randomCell.r][randomCell.c] = Math.random() < 0.9 ? 2 : 4; + } + } + + move(direction) { + // This function will handle the logic for moving tiles + // in the specified direction and merging them + // It will also update the score and add a new random tile if the move is successful + // The actual implementation of this function is complex and would require + // a significant amount of code to handle all the cases for moving and merging tiles + // For the purposes of this example, we will not implement the full logic + // Instead, we will just call addRandomTile to simulate a move + this.addRandomTile(); + } + + getBoard() { + return this.board; + } + + getScore() { + return this.score; + } + + getBestScore() { + return this.bestScore; + } + + setBestScore(score) { + this.bestScore = score; + } +} + +``` +""" + + +CODE_REVIEW_SMALLEST_CONTEXT = """ +## Code to be Reviewed: game.js +```Code +// game.js +class Game { + constructor() { + this.board = this.createEmptyBoard(); + this.score = 0; + this.bestScore = 0; + } + + createEmptyBoard() { + const board = []; + for (let i = 0; i < 4; i++) { + board[i] = [0, 0, 0, 0]; + } + return board; + } + + startGame() { + this.board = this.createEmptyBoard(); + this.score = 0; + this.addRandomTile(); + this.addRandomTile(); + } + + addRandomTile() { + let emptyCells = []; + for (let r = 0; r < 4; r++) { + for (let c = 0; c < 4; c++) { + if (this.board[r][c] === 0) { + emptyCells.push({ r, c }); + } + } + } + if (emptyCells.length > 0) { + let randomCell = emptyCells[Math.floor(Math.random() * emptyCells.length)]; + this.board[randomCell.r][randomCell.c] = Math.random() < 0.9 ? 2 : 4; + } + } + + move(direction) { + // This function will handle the logic for moving tiles + // in the specified direction and merging them + // It will also update the score and add a new random tile if the move is successful + // The actual implementation of this function is complex and would require + // a significant amount of code to handle all the cases for moving and merging tiles + // For the purposes of this example, we will not implement the full logic + // Instead, we will just call addRandomTile to simulate a move + this.addRandomTile(); + } + + getBoard() { + return this.board; + } + + getScore() { + return this.score; + } + + getBestScore() { + return this.bestScore; + } + + setBestScore(score) { + this.bestScore = score; + } +} + +``` +""" + + +CODE_REVIEW_SAMPLE = """ +## Code Review: game.js +1. The code partially implements the requirements. The `Game` class is missing the full implementation of the `move` method, which is crucial for the game\'s functionality. +2. The code logic is not completely correct. The `move` method is not implemented, which means the game cannot process player moves. +3. The existing code follows the "Data structures and interfaces" in terms of class structure but lacks full method implementations. +4. Not all functions are implemented. The `move` method is incomplete and does not handle the logic for moving and merging tiles. +5. All necessary pre-dependencies seem to be imported since the code does not indicate the need for additional imports. +6. The methods from other files (such as `Storage`) are not being used in the provided code snippet, but the class structure suggests that they will be used correctly. + +## Actions +1. Implement the `move` method to handle tile movements and merging. This is a complex task that requires careful consideration of the game\'s rules and logic. Here is a simplified version of how one might begin to implement the `move` method: + ```javascript + move(direction) { + // Simplified logic for moving tiles up + if (direction === \'up\') { + for (let col = 0; col < 4; col++) { + let tiles = this.board.map(row => row[col]).filter(val => val !== 0); + let merged = []; + for (let i = 0; i < tiles.length; i++) { + if (tiles[i] === tiles[i + 1]) { + tiles[i] *= 2; + this.score += tiles[i]; + tiles[i + 1] = 0; + merged.push(i); + } + } + tiles = tiles.filter(val => val !== 0); + while (tiles.length < 4) { + tiles.push(0); + } + for (let row = 0; row < 4; row++) { + this.board[row][col] = tiles[row]; + } + } + } + // Additional logic needed for \'down\', \'left\', \'right\' + // ... + this.addRandomTile(); + } + ``` +2. Integrate the `Storage` class methods to handle the best score. This means updating the `startGame` and `setBestScore` methods to use `Storage` for retrieving and setting the best score: + ```javascript + startGame() { + this.board = this.createEmptyBoard(); + this.score = 0; + this.bestScore = new Storage().getBestScore(); // Retrieve the best score from storage + this.addRandomTile(); + this.addRandomTile(); + } + + setBestScore(score) { + if (score > this.bestScore) { + this.bestScore = score; + new Storage().setBestScore(score); // Set the new best score in storage + } + } + ``` + +## Code Review Result +LBTM + +``` +""" + + +WRITE_CODE_NODE = ActionNode.from_children("WRITE_REVIEW_NODE", [REVIEW, LGTM, ACTIONS]) +WRITE_MOVE_NODE = ActionNode.from_children("WRITE_MOVE_NODE", [WRITE_DRAFT, WRITE_MOVE_FUNCTION]) + + +CR_FOR_MOVE_FUNCTION_BY_3 = """ +The move function implementation provided appears to be well-structured and follows a clear logic for moving and merging tiles in the specified direction. However, there are a few potential improvements that could be made to enhance the code: + +1. Encapsulation: The logic for moving and merging tiles could be encapsulated into smaller, reusable functions to improve readability and maintainability. + +2. Magic Numbers: There are some magic numbers (e.g., 4, 3) used in the loops that could be replaced with named constants for improved readability and easier maintenance. + +3. Comments: Adding comments to explain the logic and purpose of each section of the code can improve understanding for future developers who may need to work on or maintain the code. + +4. Error Handling: It's important to consider error handling for unexpected input or edge cases to ensure the function behaves as expected in all scenarios. + +Overall, the code could benefit from refactoring to improve readability, maintainability, and extensibility. If you would like, I can provide a refactored version of the move function that addresses these considerations. +""" + + +class WriteCodeAN(Action): + """Write a code review for the context.""" + + async def run(self, context): + self.llm.system_prompt = "You are an outstanding engineer and can implement any code" + return await WRITE_MOVE_FUNCTION.fill(context=context, llm=self.llm, schema="json") + # return await WRITE_CODE_NODE.fill(context=context, llm=self.llm, schema="markdown") + + +async def main(): + await WriteCodeAN().run(CODE_REVIEW_SMALLEST_CONTEXT) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/metagpt/actions/write_review.py b/metagpt/actions/write_review.py index 13690a1a5..8a4856317 100644 --- a/metagpt/actions/write_review.py +++ b/metagpt/actions/write_review.py @@ -31,8 +31,7 @@ WRITE_REVIEW_NODE = ActionNode.from_children("WRITE_REVIEW_NODE", [REVIEW, LGTM] class WriteReview(Action): - """This class allows LLM to further mine noteworthy details based on specific "##TOPIC"(discussion topic) and - "##RECORD" (discussion records), thereby deepening the discussion.""" + """Write a review for the given context.""" async def run(self, context): - return await WRITE_REVIEW_NODE.fill(context=context, llm=self.llm, schema="markdown") + return await WRITE_REVIEW_NODE.fill(context=context, llm=self.llm, schema="json") diff --git a/metagpt/utils/common.py b/metagpt/utils/common.py index ea3316d66..e5d4573e8 100644 --- a/metagpt/utils/common.py +++ b/metagpt/utils/common.py @@ -158,7 +158,8 @@ class OutputParser: @classmethod def parse_data_with_mapping(cls, data, mapping): - data = cls.extract_content(text=data) + if "[CONTENT]" in data: + data = cls.extract_content(text=data) block_dict = cls.parse_blocks(data) parsed_data = {} for block, content in block_dict.items(): diff --git a/tests/metagpt/test_prompt.py b/tests/metagpt/test_prompt.py new file mode 100644 index 000000000..f7b1cc68e --- /dev/null +++ b/tests/metagpt/test_prompt.py @@ -0,0 +1,342 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Time : 2023/5/11 14:45 +@Author : alexanderwu +@File : test_llm.py +""" + +import pytest + +from metagpt.llm import LLM + +CODE_REVIEW_SMALLEST_CONTEXT = """ +## game.js +```Code +// game.js +class Game { + constructor() { + this.board = this.createEmptyBoard(); + this.score = 0; + this.bestScore = 0; + } + + createEmptyBoard() { + const board = []; + for (let i = 0; i < 4; i++) { + board[i] = [0, 0, 0, 0]; + } + return board; + } + + startGame() { + this.board = this.createEmptyBoard(); + this.score = 0; + this.addRandomTile(); + this.addRandomTile(); + } + + addRandomTile() { + let emptyCells = []; + for (let r = 0; r < 4; r++) { + for (let c = 0; c < 4; c++) { + if (this.board[r][c] === 0) { + emptyCells.push({ r, c }); + } + } + } + if (emptyCells.length > 0) { + let randomCell = emptyCells[Math.floor(Math.random() * emptyCells.length)]; + this.board[randomCell.r][randomCell.c] = Math.random() < 0.9 ? 2 : 4; + } + } + + move(direction) { + // This function will handle the logic for moving tiles + // in the specified direction and merging them + // It will also update the score and add a new random tile if the move is successful + // The actual implementation of this function is complex and would require + // a significant amount of code to handle all the cases for moving and merging tiles + // For the purposes of this example, we will not implement the full logic + // Instead, we will just call addRandomTile to simulate a move + this.addRandomTile(); + } + + getBoard() { + return this.board; + } + + getScore() { + return this.score; + } + + getBestScore() { + return this.bestScore; + } + + setBestScore(score) { + this.bestScore = score; + } +} + +``` +""" + +MOVE_DRAFT = """ +## move function draft + +```javascript +move(direction) { + let moved = false; + switch (direction) { + case 'up': + for (let c = 0; c < 4; c++) { + for (let r = 1; r < 4; r++) { + if (this.board[r][c] !== 0) { + let row = r; + while (row > 0 && this.board[row - 1][c] === 0) { + this.board[row - 1][c] = this.board[row][c]; + this.board[row][c] = 0; + row--; + moved = true; + } + if (row > 0 && this.board[row - 1][c] === this.board[row][c]) { + this.board[row - 1][c] *= 2; + this.board[row][c] = 0; + this.score += this.board[row - 1][c]; + moved = true; + } + } + } + } + break; + case 'down': + // Implement logic for moving tiles down + // Similar to the 'up' case but iterating in reverse order + // and checking for merging in the opposite direction + break; + case 'left': + // Implement logic for moving tiles left + // Similar to the 'up' case but iterating over columns first + // and checking for merging in the opposite direction + break; + case 'right': + // Implement logic for moving tiles right + // Similar to the 'up' case but iterating over columns in reverse order + // and checking for merging in the opposite direction + break; + } + + if (moved) { + this.addRandomTile(); + } +} +``` +""" + +FUNCTION_TO_MERMAID_CLASS = """ +## context +``` +class UIDesign(Action): + #Class representing the UI Design action. + def __init__(self, name, context=None, llm=None): + super().__init__(name, context, llm) # 需要调用LLM进一步丰富UI设计的prompt + @parse + def parse_requirement(self, context: str): + #Parse UI Design draft from the context using regex. + pattern = r"## UI Design draft.*?\n(.*?)## Anything UNCLEAR" + return context, pattern + @parse + def parse_ui_elements(self, context: str): + #Parse Selected Elements from the context using regex. + pattern = r"## Selected Elements.*?\n(.*?)## HTML Layout" + return context, pattern + @parse + def parse_css_code(self, context: str): + pattern = r"```css.*?\n(.*?)## Anything UNCLEAR" + return context, pattern + @parse + def parse_html_code(self, context: str): + pattern = r"```html.*?\n(.*?)```" + return context, pattern + async def draw_icons(self, context, *args, **kwargs): + #Draw icons using SDEngine. + engine = SDEngine() + icon_prompts = self.parse_ui_elements(context) + icons = icon_prompts.split("\n") + icons = [s for s in icons if len(s.strip()) > 0] + prompts_batch = [] + for icon_prompt in icons: + # fixme: 添加icon lora + prompt = engine.construct_payload(icon_prompt + ".") + prompts_batch.append(prompt) + await engine.run_t2i(prompts_batch) + logger.info("Finish icon design using StableDiffusion API") + async def _save(self, css_content, html_content): + save_dir = CONFIG.workspace_path / "resources" / "codes" + if not os.path.exists(save_dir): + os.makedirs(save_dir, exist_ok=True) + # Save CSS and HTML content to files + css_file_path = save_dir / "ui_design.css" + html_file_path = save_dir / "ui_design.html" + with open(css_file_path, "w") as css_file: + css_file.write(css_content) + with open(html_file_path, "w") as html_file: + html_file.write(html_content) + async def run(self, requirements: list[Message], *args, **kwargs) -> ActionOutput: + #Run the UI Design action. + # fixme: update prompt (根据需求细化prompt) + context = requirements[-1].content + ui_design_draft = self.parse_requirement(context=context) + # todo: parse requirements str + prompt = PROMPT_TEMPLATE.format(context=ui_design_draft, format_example=FORMAT_EXAMPLE) + logger.info(prompt) + ui_describe = await self._aask_v1(prompt, "ui_design", OUTPUT_MAPPING) + logger.info(ui_describe.content) + logger.info(ui_describe.instruct_content) + css = self.parse_css_code(context=ui_describe.content) + html = self.parse_html_code(context=ui_describe.content) + await self._save(css_content=css, html_content=html) + await self.draw_icons(ui_describe.content) + return ui_describe +``` +----- +## format example +[CONTENT] +{ + "ClassView": "classDiagram\n class A {\n -int x\n +int y\n -int speed\n -int direction\n +__init__(x: int, y: int, speed: int, direction: int)\n +change_direction(new_direction: int) None\n +move() None\n }\n " +} +[/CONTENT] +## nodes: ": # " +- ClassView: # Generate the mermaid class diagram corresponding to source code in "context." +## constraint +- Language: Please use the same language as the user input. +- Format: output wrapped inside [CONTENT][/CONTENT] as format example, nothing else. +## action +Fill in the above nodes(ClassView) based on the format example. +""" + +MOVE_FUNCTION = """ +## move function implementation + +```javascript +move(direction) { + let moved = false; + switch (direction) { + case 'up': + for (let c = 0; c < 4; c++) { + for (let r = 1; r < 4; r++) { + if (this.board[r][c] !== 0) { + let row = r; + while (row > 0 && this.board[row - 1][c] === 0) { + this.board[row - 1][c] = this.board[row][c]; + this.board[row][c] = 0; + row--; + moved = true; + } + if (row > 0 && this.board[row - 1][c] === this.board[row][c]) { + this.board[row - 1][c] *= 2; + this.board[row][c] = 0; + this.score += this.board[row - 1][c]; + moved = true; + } + } + } + } + break; + case 'down': + for (let c = 0; c < 4; c++) { + for (let r = 2; r >= 0; r--) { + if (this.board[r][c] !== 0) { + let row = r; + while (row < 3 && this.board[row + 1][c] === 0) { + this.board[row + 1][c] = this.board[row][c]; + this.board[row][c] = 0; + row++; + moved = true; + } + if (row < 3 && this.board[row + 1][c] === this.board[row][c]) { + this.board[row + 1][c] *= 2; + this.board[row][c] = 0; + this.score += this.board[row + 1][c]; + moved = true; + } + } + } + } + break; + case 'left': + for (let r = 0; r < 4; r++) { + for (let c = 1; c < 4; c++) { + if (this.board[r][c] !== 0) { + let col = c; + while (col > 0 && this.board[r][col - 1] === 0) { + this.board[r][col - 1] = this.board[r][col]; + this.board[r][col] = 0; + col--; + moved = true; + } + if (col > 0 && this.board[r][col - 1] === this.board[r][col]) { + this.board[r][col - 1] *= 2; + this.board[r][col] = 0; + this.score += this.board[r][col - 1]; + moved = true; + } + } + } + } + break; + case 'right': + for (let r = 0; r < 4; r++) { + for (let c = 2; c >= 0; c--) { + if (this.board[r][c] !== 0) { + let col = c; + while (col < 3 && this.board[r][col + 1] === 0) { + this.board[r][col + 1] = this.board[r][col]; + this.board[r][col] = 0; + col++; + moved = true; + } + if (col < 3 && this.board[r][col + 1] === this.board[r][col]) { + this.board[r][col + 1] *= 2; + this.board[r][col] = 0; + this.score += this.board[r][col + 1]; + moved = true; + } + } + } + } + break; + } + + if (moved) { + this.addRandomTile(); + } +} +``` +""" + + +@pytest.fixture() +def llm(): + return LLM() + + +@pytest.mark.asyncio +async def test_llm_code_review(llm): + choices = [ + "Please review the move function code above. Should it be refactor?", + "Please implement the move function", + "Please write a draft for the move function in order to implement it", + ] + # prompt = CODE_REVIEW_SMALLEST_CONTEXT+ "\n\n" + MOVE_DRAFT + "\n\n" + choices[1] + # rsp = await llm.aask(prompt) + + prompt = CODE_REVIEW_SMALLEST_CONTEXT + "\n\n" + MOVE_FUNCTION + "\n\n" + choices[0] + prompt = FUNCTION_TO_MERMAID_CLASS + + _ = await llm.aask(prompt) + + +# if __name__ == "__main__": +# pytest.main([__file__, "-s"]) From e772ffdc1e8836d45428f1050dc50dad0c1a843b Mon Sep 17 00:00:00 2001 From: geekan Date: Thu, 21 Dec 2023 11:05:24 +0800 Subject: [PATCH 146/167] fix pydantic not support future issue --- metagpt/actions/action.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/metagpt/actions/action.py b/metagpt/actions/action.py index 62434e7f8..cd2b5148f 100644 --- a/metagpt/actions/action.py +++ b/metagpt/actions/action.py @@ -8,7 +8,7 @@ from __future__ import annotations -from typing import Any, Optional +from typing import Any, Optional, Union from pydantic import BaseModel, Field @@ -27,7 +27,7 @@ action_subclass_registry = {} class Action(BaseModel): name: str = "" llm: BaseGPTAPI = Field(default_factory=LLM, exclude=True) - context: dict | CodingContext | CodeSummarizeContext | TestingContext | RunCodeContext | str | None = "" + context: Union[dict, CodingContext, CodeSummarizeContext, TestingContext, RunCodeContext, str, None] = "" prefix = "" # aask*时会加上prefix,作为system_message desc = "" # for skill manager # node: ActionNode = Field(default_factory=ActionNode, exclude=True) From 1564b1bf14000b014c928c35b0e286718225b31e Mon Sep 17 00:00:00 2001 From: seehi <6580@pm.me> Date: Thu, 21 Dec 2023 11:47:29 +0800 Subject: [PATCH 147/167] upgrade openai 1.3.5 to 1.6.0 --- metagpt/provider/openai_api.py | 6 +++--- requirements.txt | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/metagpt/provider/openai_api.py b/metagpt/provider/openai_api.py index 9a328f386..e8023b717 100644 --- a/metagpt/provider/openai_api.py +++ b/metagpt/provider/openai_api.py @@ -9,7 +9,6 @@ import json import time from typing import NamedTuple, Union -import httpx from openai import ( APIConnectionError, AsyncAzureOpenAI, @@ -18,6 +17,7 @@ from openai import ( AzureOpenAI, OpenAI, ) +from openai._base_client import AsyncHttpxClientWrapper, SyncHttpxClientWrapper from openai.types import CompletionUsage from openai.types.chat import ChatCompletion, ChatCompletionChunk from tenacity import ( @@ -190,8 +190,8 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter): # to use proxy, openai v1 needs http_client proxy_params = self._get_proxy_params() if proxy_params: - kwargs["http_client"] = httpx.Client(**proxy_params) - async_kwargs["http_client"] = httpx.AsyncClient(**proxy_params) + kwargs["http_client"] = SyncHttpxClientWrapper(**proxy_params) + async_kwargs["http_client"] = AsyncHttpxClientWrapper(**proxy_params) return kwargs, async_kwargs diff --git a/requirements.txt b/requirements.txt index c57fb6c2c..fd7a31607 100644 --- a/requirements.txt +++ b/requirements.txt @@ -14,7 +14,7 @@ langchain==0.0.231 loguru==0.6.0 meilisearch==0.21.0 numpy==1.24.3 -openai~=1.3 +openai==1.6.0 openpyxl beautifulsoup4==4.12.2 pandas==2.0.3 From c4fbc478d22ee0a1794e619866164f37d322ee73 Mon Sep 17 00:00:00 2001 From: better629 Date: Thu, 14 Dec 2023 16:45:40 +0800 Subject: [PATCH 148/167] add google gemini --- config/config.yaml | 4 + metagpt/config.py | 6 +- metagpt/provider/google_gemini_api.py | 130 ++++++++++++++++++ metagpt/utils/token_counter.py | 3 + requirements.txt | 1 + .../provider/test_google_gemini_api.py | 43 ++++++ 6 files changed, 186 insertions(+), 1 deletion(-) create mode 100644 metagpt/provider/google_gemini_api.py create mode 100644 tests/metagpt/provider/test_google_gemini_api.py diff --git a/config/config.yaml b/config/config.yaml index f547462ba..fc113370d 100644 --- a/config/config.yaml +++ b/config/config.yaml @@ -37,6 +37,10 @@ RPM: 10 #### if zhipuai from `https://open.bigmodel.cn`. You can set here or export API_KEY="YOUR_API_KEY" # ZHIPUAI_API_KEY: "YOUR_API_KEY" +#### if Google Gemini from `https://ai.google.dev/` and API_KEY from `https://makersuite.google.com/app/apikey`. +#### You can set here or export GOOGLE_API_KEY="YOUR_API_KEY" +# GEMINI_API_KEY: "YOUR_API_KEY" + #### if use self-host open llm model with openai-compatible interface #OPEN_LLM_API_BASE: "http://127.0.0.1:8000/v1" #OPEN_LLM_API_MODEL: "llama2-13b" diff --git a/metagpt/config.py b/metagpt/config.py index 131854a56..6ab537296 100644 --- a/metagpt/config.py +++ b/metagpt/config.py @@ -39,6 +39,7 @@ class LLMProviderEnum(Enum): ZHIPUAI = "zhipuai" FIREWORKS = "fireworks" OPEN_LLM = "open_llm" + GEMINI = "gemini" class Config(metaclass=Singleton): @@ -74,7 +75,8 @@ class Config(metaclass=Singleton): (self.anthropic_api_key, LLMProviderEnum.ANTHROPIC), (self.zhipuai_api_key, LLMProviderEnum.ZHIPUAI), (self.fireworks_api_key, LLMProviderEnum.FIREWORKS), - (self.open_llm_api_base, LLMProviderEnum.OPEN_LLM), # reuse logic. but not a key + (self.open_llm_api_base, LLMProviderEnum.OPEN_LLM), + (self.gemini_api_key, LLMProviderEnum.GEMINI), # reuse logic. but not a key ]: if self._is_valid_llm_key(k): if self.openai_api_model: @@ -96,6 +98,8 @@ class Config(metaclass=Singleton): self.open_llm_api_base = self._get("OPEN_LLM_API_BASE") self.open_llm_api_model = self._get("OPEN_LLM_API_MODEL") self.fireworks_api_key = self._get("FIREWORKS_API_KEY") + self.gemini_api_key = self._get("GEMINI_API_KEY") + _ = self.get_default_llm_provider_enum() self.openai_api_base = self._get("OPENAI_API_BASE") diff --git a/metagpt/provider/google_gemini_api.py b/metagpt/provider/google_gemini_api.py new file mode 100644 index 000000000..1c866ebad --- /dev/null +++ b/metagpt/provider/google_gemini_api.py @@ -0,0 +1,130 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : Google Gemini LLM from https://ai.google.dev/tutorials/python_quickstart + +from tenacity import ( + after_log, + retry, + retry_if_exception_type, + stop_after_attempt, + wait_fixed, +) +import google.generativeai as genai +from google.generativeai import client +from google.generativeai.types.generation_types import GenerateContentResponse, AsyncGenerateContentResponse +from google.generativeai.types.generation_types import GenerationConfig + +from metagpt.config import CONFIG +from metagpt.logs import logger +from metagpt.provider.base_gpt_api import BaseGPTAPI +from metagpt.provider.openai_api import log_and_reraise + + +class GeminiGPTAPI(BaseGPTAPI): + """ + Refs to `https://ai.google.dev/tutorials/python_quickstart` + """ + + use_system_prompt: bool = False # google gemini has no system prompt when use api + + def __init__(self): + self.__init_gemini(CONFIG) + self.model = "gemini-pro" # so far only one model + self.llm = genai.GenerativeModel(model_name=self.model) + + def __init_gemini(self, config: CONFIG): + genai.configure(api_key=config.gemini_api_key) + + def _user_msg(self, msg: str) -> dict[str, str]: + return {"role": "user", "parts": [msg]} + + def _assistant_msg(self, msg: str) -> dict[str, str]: + return {"role": "model", "parts": [msg]} + + def _const_kwargs(self, messages: list[dict], stream: bool = False) -> dict: + kwargs = { + "contents": messages, + "generation_config": GenerationConfig( + temperature=0.3 + ), + "stream": stream + } + return kwargs + + def _update_costs(self, usage: dict): + """ update each request's token cost """ + if CONFIG.calc_usage: + try: + prompt_tokens = int(usage.get("prompt_tokens", 0)) + completion_tokens = int(usage.get("completion_tokens", 0)) + self._cost_manager.update_cost(prompt_tokens, completion_tokens, self.model) + except Exception as e: + logger.error("google gemini updats costs failed!", e) + + def get_choice_text(self, resp: GenerateContentResponse) -> str: + return resp.text + + def get_usage(self, messages: list[dict], resp_text: str) -> dict: + prompt_resp = self.llm.count_tokens(contents=messages) + completion_resp = self.llm.count_tokens(contents={"parts": [resp_text]}) + usage = { + "prompt_tokens": prompt_resp.total_tokens, + "completion_tokens": completion_resp.total_tokens + } + return usage + + async def aget_usage(self, messages: list[dict], resp_text: str) -> dict: + # fix google-generativeai sdk + if self.llm._client is None: + self.llm._client = client.get_default_generative_client() + # TODO exception to fix + prompt_resp = await self.llm.count_tokens_async(contents=messages) + completion_resp = await self.llm.count_tokens_async(contents={"parts": [resp_text]}) + usage = { + "prompt_tokens": prompt_resp.total_tokens, + "completion_tokens": completion_resp.total_tokens + } + return usage + + def completion(self, messages: list[dict]) -> "GenerateContentResponse": + resp: GenerateContentResponse = self.llm.generate_content(**self._const_kwargs(messages)) + # usage = self.get_usage(messages, resp.text) + # self._update_costs(usage) + return resp + + async def _achat_completion(self, messages: list[dict]) -> "AsyncGenerateContentResponse": + resp: AsyncGenerateContentResponse = await self.llm.generate_content_async(**self._const_kwargs(messages)) + # usage = await self.aget_usage(messages, resp.text) + # self._update_costs(usage) + return resp + + async def acompletion(self, messages: list[dict]) -> dict: + return await self._achat_completion(messages) + + async def _achat_completion_stream(self, messages: list[dict]) -> str: + resp: AsyncGenerateContentResponse = await self.llm.generate_content_async(**self._const_kwargs(messages, + stream=True)) + collected_content = [] + async for chunk in resp: + content = chunk.text + print(content, end="") + collected_content.append(content) + + full_content = "".join(collected_content) + # usage = await self.aget_usage(messages, full_content) + # self._update_costs(usage) + return full_content + + @retry( + stop=stop_after_attempt(3), + wait=wait_fixed(1), + after=after_log(logger, logger.level("WARNING").name), + retry=retry_if_exception_type(ConnectionError), + retry_error_callback=log_and_reraise + ) + async def acompletion_text(self, messages: list[dict], stream=False) -> str: + """ response in async with stream or non-stream mode """ + if stream: + return await self._achat_completion_stream(messages) + resp = await self._achat_completion(messages) + return self.get_choice_text(resp) diff --git a/metagpt/utils/token_counter.py b/metagpt/utils/token_counter.py index ebfb85de7..512ff784c 100644 --- a/metagpt/utils/token_counter.py +++ b/metagpt/utils/token_counter.py @@ -7,6 +7,7 @@ ref1: https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb ref2: https://github.com/Significant-Gravitas/Auto-GPT/blob/master/autogpt/llm/token_counter.py ref3: https://github.com/hwchase17/langchain/blob/master/langchain/chat_models/openai.py +ref4: https://ai.google.dev/models/gemini """ import tiktoken @@ -25,6 +26,7 @@ TOKEN_COSTS = { "gpt-4-1106-preview": {"prompt": 0.01, "completion": 0.03}, "text-embedding-ada-002": {"prompt": 0.0004, "completion": 0.0}, "chatglm_turbo": {"prompt": 0.0, "completion": 0.00069}, # 32k version, prompt + completion tokens=0.005¥/k-tokens + "gemini-pro": {"prompt": 0.00025, "completion": 0.0005} } @@ -43,6 +45,7 @@ TOKEN_MAX = { "gpt-4-1106-preview": 128000, "text-embedding-ada-002": 8192, "chatglm_turbo": 32768, + "gemini-pro": 32768 } diff --git a/requirements.txt b/requirements.txt index f5ef63c58..2b4e064ef 100644 --- a/requirements.txt +++ b/requirements.txt @@ -49,3 +49,4 @@ aiofiles==23.2.1 gitpython==3.1.40 zhipuai==1.0.7 gitignore-parser==0.1.9 +google-generativeai==0.3.1 diff --git a/tests/metagpt/provider/test_google_gemini_api.py b/tests/metagpt/provider/test_google_gemini_api.py new file mode 100644 index 000000000..32ed11ba5 --- /dev/null +++ b/tests/metagpt/provider/test_google_gemini_api.py @@ -0,0 +1,43 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : the unittest of google gemini api + +import pytest +from abc import ABC +from dataclasses import dataclass + +from metagpt.provider.google_gemini_api import GeminiGPTAPI + + +messages = [ + {"role": "user", "content": "who are you"} +] + + +@dataclass +class MockGeminiResponse(ABC): + text: str + + +default_resp = MockGeminiResponse(text="I'm gemini from google") + + +def mock_llm_ask(self, messages: list[dict]) -> MockGeminiResponse: + return default_resp + + +def test_gemini_completion(mocker): + mocker.patch("metagpt.provider.google_gemini_api.GeminiGPTAPI.completion", mock_llm_ask) + resp = GeminiGPTAPI().completion(messages) + assert resp.text == default_resp.text + + +async def mock_llm_aask(self, messgaes: list[dict]) -> MockGeminiResponse: + return default_resp + + +@pytest.mark.asyncio +async def test_gemini_acompletion(mocker): + mocker.patch("metagpt.provider.google_gemini_api.GeminiGPTAPI.acompletion", mock_llm_aask) + resp = await GeminiGPTAPI().acompletion(messages) + assert resp.text == default_resp.text From 91d1ab20cc21eccf5966cf507b08087af4cadda6 Mon Sep 17 00:00:00 2001 From: better629 Date: Thu, 14 Dec 2023 16:54:56 +0800 Subject: [PATCH 149/167] update gemini user_msg doc --- metagpt/provider/google_gemini_api.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/metagpt/provider/google_gemini_api.py b/metagpt/provider/google_gemini_api.py index 1c866ebad..a69ffdc28 100644 --- a/metagpt/provider/google_gemini_api.py +++ b/metagpt/provider/google_gemini_api.py @@ -36,6 +36,8 @@ class GeminiGPTAPI(BaseGPTAPI): genai.configure(api_key=config.gemini_api_key) def _user_msg(self, msg: str) -> dict[str, str]: + # Not to change BaseGPTAPI default functions but update with Gemini's conversation format. + # You should follow the format. return {"role": "user", "parts": [msg]} def _assistant_msg(self, msg: str) -> dict[str, str]: From 02090af7cb1a315b2b59ea843fa7aa8bb816cf4e Mon Sep 17 00:00:00 2001 From: better629 Date: Fri, 15 Dec 2023 17:06:59 +0800 Subject: [PATCH 150/167] update gemini count_tokens --- metagpt/provider/google_gemini_api.py | 56 ++++++++++++++++++--------- metagpt/provider/zhipuai_api.py | 2 +- 2 files changed, 39 insertions(+), 19 deletions(-) diff --git a/metagpt/provider/google_gemini_api.py b/metagpt/provider/google_gemini_api.py index a69ffdc28..0ba1e86c1 100644 --- a/metagpt/provider/google_gemini_api.py +++ b/metagpt/provider/google_gemini_api.py @@ -10,14 +10,35 @@ from tenacity import ( wait_fixed, ) import google.generativeai as genai -from google.generativeai import client +from google.ai import generativelanguage as glm +from google.generativeai.types import content_types +from google.generativeai.generative_models import GenerativeModel from google.generativeai.types.generation_types import GenerateContentResponse, AsyncGenerateContentResponse from google.generativeai.types.generation_types import GenerationConfig from metagpt.config import CONFIG from metagpt.logs import logger from metagpt.provider.base_gpt_api import BaseGPTAPI -from metagpt.provider.openai_api import log_and_reraise +from metagpt.provider.openai_api import CostManager, log_and_reraise + + +class GeminiGenerativeModel(GenerativeModel): + """ + Due to `https://github.com/google/generative-ai-python/pull/123`, inherit a new class. + Will use default GenerativeModel if it fixed. + """ + + def count_tokens( + self, contents: content_types.ContentsType + ) -> glm.CountTokensResponse: + contents = content_types.to_contents(contents) + return self._client.count_tokens(model=self.model_name, contents=contents) + + async def count_tokens_async( + self, contents: content_types.ContentsType + ) -> glm.CountTokensResponse: + contents = content_types.to_contents(contents) + return await self._async_client.count_tokens(model=self.model_name, contents=contents) class GeminiGPTAPI(BaseGPTAPI): @@ -30,7 +51,8 @@ class GeminiGPTAPI(BaseGPTAPI): def __init__(self): self.__init_gemini(CONFIG) self.model = "gemini-pro" # so far only one model - self.llm = genai.GenerativeModel(model_name=self.model) + self.llm = GeminiGenerativeModel(model_name=self.model) + self._cost_manager = CostManager() def __init_gemini(self, config: CONFIG): genai.configure(api_key=config.gemini_api_key) @@ -61,14 +83,15 @@ class GeminiGPTAPI(BaseGPTAPI): completion_tokens = int(usage.get("completion_tokens", 0)) self._cost_manager.update_cost(prompt_tokens, completion_tokens, self.model) except Exception as e: - logger.error("google gemini updats costs failed!", e) + logger.error(f"google gemini updats costs failed! exp: {e}") def get_choice_text(self, resp: GenerateContentResponse) -> str: return resp.text def get_usage(self, messages: list[dict], resp_text: str) -> dict: - prompt_resp = self.llm.count_tokens(contents=messages) - completion_resp = self.llm.count_tokens(contents={"parts": [resp_text]}) + req_text = messages[-1]["parts"][0] if messages else "" + prompt_resp = self.llm.count_tokens(contents={"role": "user", "parts": [{"text": req_text}]}) + completion_resp = self.llm.count_tokens(contents={"role": "model", "parts": [{"text": resp_text}]}) usage = { "prompt_tokens": prompt_resp.total_tokens, "completion_tokens": completion_resp.total_tokens @@ -76,12 +99,9 @@ class GeminiGPTAPI(BaseGPTAPI): return usage async def aget_usage(self, messages: list[dict], resp_text: str) -> dict: - # fix google-generativeai sdk - if self.llm._client is None: - self.llm._client = client.get_default_generative_client() - # TODO exception to fix - prompt_resp = await self.llm.count_tokens_async(contents=messages) - completion_resp = await self.llm.count_tokens_async(contents={"parts": [resp_text]}) + req_text = messages[-1]["parts"][0] if messages else "" + prompt_resp = await self.llm.count_tokens_async(contents={"role": "user", "parts": [{"text": req_text}]}) + completion_resp = await self.llm.count_tokens_async(contents={"role": "model", "parts": [{"text": resp_text}]}) usage = { "prompt_tokens": prompt_resp.total_tokens, "completion_tokens": completion_resp.total_tokens @@ -90,14 +110,14 @@ class GeminiGPTAPI(BaseGPTAPI): def completion(self, messages: list[dict]) -> "GenerateContentResponse": resp: GenerateContentResponse = self.llm.generate_content(**self._const_kwargs(messages)) - # usage = self.get_usage(messages, resp.text) - # self._update_costs(usage) + usage = self.get_usage(messages, resp.text) + self._update_costs(usage) return resp async def _achat_completion(self, messages: list[dict]) -> "AsyncGenerateContentResponse": resp: AsyncGenerateContentResponse = await self.llm.generate_content_async(**self._const_kwargs(messages)) - # usage = await self.aget_usage(messages, resp.text) - # self._update_costs(usage) + usage = await self.aget_usage(messages, resp.text) + self._update_costs(usage) return resp async def acompletion(self, messages: list[dict]) -> dict: @@ -113,8 +133,8 @@ class GeminiGPTAPI(BaseGPTAPI): collected_content.append(content) full_content = "".join(collected_content) - # usage = await self.aget_usage(messages, full_content) - # self._update_costs(usage) + usage = await self.aget_usage(messages, full_content) + self._update_costs(usage) return full_content @retry( diff --git a/metagpt/provider/zhipuai_api.py b/metagpt/provider/zhipuai_api.py index eef0e51e1..60d9a0777 100644 --- a/metagpt/provider/zhipuai_api.py +++ b/metagpt/provider/zhipuai_api.py @@ -63,7 +63,7 @@ class ZhiPuAIGPTAPI(BaseGPTAPI): completion_tokens = int(usage.get("completion_tokens", 0)) self._cost_manager.update_cost(prompt_tokens, completion_tokens, self.model) except Exception as e: - logger.error("zhipuai updats costs failed!", e) + logger.error(f"zhipuai updats costs failed! exp: {e}") def get_choice_text(self, resp: dict) -> str: """get the first text of choice from llm response""" From e5a7fdfe3b7168341a7b5b1903288fdbe99a7dd1 Mon Sep 17 00:00:00 2001 From: better629 Date: Fri, 15 Dec 2023 17:30:25 +0800 Subject: [PATCH 151/167] retry use wait_random_exponential --- metagpt/provider/google_gemini_api.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/metagpt/provider/google_gemini_api.py b/metagpt/provider/google_gemini_api.py index 0ba1e86c1..b68e013a0 100644 --- a/metagpt/provider/google_gemini_api.py +++ b/metagpt/provider/google_gemini_api.py @@ -7,7 +7,7 @@ from tenacity import ( retry, retry_if_exception_type, stop_after_attempt, - wait_fixed, + wait_random_exponential, ) import google.generativeai as genai from google.ai import generativelanguage as glm @@ -139,7 +139,7 @@ class GeminiGPTAPI(BaseGPTAPI): @retry( stop=stop_after_attempt(3), - wait=wait_fixed(1), + wait=wait_random_exponential(min=1, max=60), after=after_log(logger, logger.level("WARNING").name), retry=retry_if_exception_type(ConnectionError), retry_error_callback=log_and_reraise From 163da9a2e7dd19de9be4746a243fb45c1ba9afdd Mon Sep 17 00:00:00 2001 From: seehi <6580@pm.me> Date: Thu, 21 Dec 2023 12:44:43 +0800 Subject: [PATCH 152/167] format code --- metagpt/provider/openai_api.py | 1 - metagpt/roles/researcher.py | 2 +- metagpt/tools/web_browser_engine_selenium.py | 4 ++-- 3 files changed, 3 insertions(+), 4 deletions(-) diff --git a/metagpt/provider/openai_api.py b/metagpt/provider/openai_api.py index ed1afd6e7..dbafa31b7 100644 --- a/metagpt/provider/openai_api.py +++ b/metagpt/provider/openai_api.py @@ -28,7 +28,6 @@ from tenacity import ( wait_random_exponential, ) - from metagpt.config import CONFIG, Config, LLMProviderEnum from metagpt.logs import logger from metagpt.provider.base_gpt_api import BaseGPTAPI diff --git a/metagpt/roles/researcher.py b/metagpt/roles/researcher.py index 52c55f0ca..e894d1a57 100644 --- a/metagpt/roles/researcher.py +++ b/metagpt/roles/researcher.py @@ -70,7 +70,7 @@ class Researcher(Role): return ret def research_system_text(self, topic, current_task: Action) -> str: - """ BACKWARD compatible + """BACKWARD compatible This allows sub-class able to define its own system prompt based on topic. return the previous implementation to have backward compatible Args: diff --git a/metagpt/tools/web_browser_engine_selenium.py b/metagpt/tools/web_browser_engine_selenium.py index 074943892..decab2b7d 100644 --- a/metagpt/tools/web_browser_engine_selenium.py +++ b/metagpt/tools/web_browser_engine_selenium.py @@ -106,8 +106,8 @@ def _gen_get_driver_func(browser_type, *args, executable_path=None): options.add_argument("--headless") options.add_argument("--enable-javascript") if browser_type == "chrome": - options.add_argument("--disable-gpu") # This flag can help avoid renderer issue - options.add_argument("--disable-dev-shm-usage") # Overcome limited resource problems + options.add_argument("--disable-gpu") # This flag can help avoid renderer issue + options.add_argument("--disable-dev-shm-usage") # Overcome limited resource problems options.add_argument("--no-sandbox") for i in args: options.add_argument(i) From f3eb9f638efd3bdd08023a996985098959116dfd Mon Sep 17 00:00:00 2001 From: better629 Date: Thu, 21 Dec 2023 12:55:45 +0800 Subject: [PATCH 153/167] add other llm for LLMProviderRegistry --- metagpt/config.py | 2 +- metagpt/provider/__init__.py | 13 +++++++++++-- metagpt/provider/google_gemini_api.py | 20 +++++++++++--------- 3 files changed, 23 insertions(+), 12 deletions(-) diff --git a/metagpt/config.py b/metagpt/config.py index 6ab537296..27d4488e0 100644 --- a/metagpt/config.py +++ b/metagpt/config.py @@ -79,7 +79,7 @@ class Config(metaclass=Singleton): (self.gemini_api_key, LLMProviderEnum.GEMINI), # reuse logic. but not a key ]: if self._is_valid_llm_key(k): - if self.openai_api_model: + if self.openai_api_key and self.openai_api_model: logger.info(f"OpenAI API Model: {self.openai_api_model}") return v raise NotConfiguredException("You should config a LLM configuration first") diff --git a/metagpt/provider/__init__.py b/metagpt/provider/__init__.py index 56dc19b4b..028c6f837 100644 --- a/metagpt/provider/__init__.py +++ b/metagpt/provider/__init__.py @@ -6,7 +6,16 @@ @File : __init__.py """ +from metagpt.provider.fireworks_api import FireWorksGPTAPI +from metagpt.provider.google_gemini_api import GeminiGPTAPI +from metagpt.provider.open_llm_api import OpenLLMGPTAPI from metagpt.provider.openai_api import OpenAIGPTAPI +from metagpt.provider.zhipuai_api import ZhiPuAIGPTAPI - -__all__ = ["OpenAIGPTAPI"] +__all__ = [ + "FireWorksGPTAPI", + "GeminiGPTAPI", + "OpenLLMGPTAPI", + "OpenAIGPTAPI", + "ZhiPuAIGPTAPI" +] diff --git a/metagpt/provider/google_gemini_api.py b/metagpt/provider/google_gemini_api.py index b68e013a0..213b53263 100644 --- a/metagpt/provider/google_gemini_api.py +++ b/metagpt/provider/google_gemini_api.py @@ -2,6 +2,12 @@ # -*- coding: utf-8 -*- # @Desc : Google Gemini LLM from https://ai.google.dev/tutorials/python_quickstart +import google.generativeai as genai +from google.ai import generativelanguage as glm +from google.generativeai.generative_models import GenerativeModel +from google.generativeai.types import content_types +from google.generativeai.types.generation_types import GenerateContentResponse, AsyncGenerateContentResponse +from google.generativeai.types.generation_types import GenerationConfig from tenacity import ( after_log, retry, @@ -9,16 +15,11 @@ from tenacity import ( stop_after_attempt, wait_random_exponential, ) -import google.generativeai as genai -from google.ai import generativelanguage as glm -from google.generativeai.types import content_types -from google.generativeai.generative_models import GenerativeModel -from google.generativeai.types.generation_types import GenerateContentResponse, AsyncGenerateContentResponse -from google.generativeai.types.generation_types import GenerationConfig -from metagpt.config import CONFIG +from metagpt.config import CONFIG, LLMProviderEnum from metagpt.logs import logger from metagpt.provider.base_gpt_api import BaseGPTAPI +from metagpt.provider.llm_provider_registry import register_provider from metagpt.provider.openai_api import CostManager, log_and_reraise @@ -29,18 +30,19 @@ class GeminiGenerativeModel(GenerativeModel): """ def count_tokens( - self, contents: content_types.ContentsType + self, contents: content_types.ContentsType ) -> glm.CountTokensResponse: contents = content_types.to_contents(contents) return self._client.count_tokens(model=self.model_name, contents=contents) async def count_tokens_async( - self, contents: content_types.ContentsType + self, contents: content_types.ContentsType ) -> glm.CountTokensResponse: contents = content_types.to_contents(contents) return await self._async_client.count_tokens(model=self.model_name, contents=contents) +@register_provider(LLMProviderEnum.GEMINI) class GeminiGPTAPI(BaseGPTAPI): """ Refs to `https://ai.google.dev/tutorials/python_quickstart` From bd3d5fe1f3088b9233d7738b278462161397f2ec Mon Sep 17 00:00:00 2001 From: geekan Date: Thu, 21 Dec 2023 13:59:00 +0800 Subject: [PATCH 154/167] fix installation --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 0e8e3650b..be7c477bf 100644 --- a/requirements.txt +++ b/requirements.txt @@ -36,7 +36,7 @@ tqdm==4.64.0 # webdriver_manager<3.9 anthropic==0.3.6 typing-inspect==0.8.0 -typing_extensions==4.5.0 +typing_extensions==4.7.0 libcst==1.0.1 qdrant-client==1.4.0 pytest-mock==3.11.1 From 44e648eabf7f56478a465e83726fc37396ad4641 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=88=98=E6=A3=92=E6=A3=92?= Date: Thu, 21 Dec 2023 14:17:05 +0800 Subject: [PATCH 155/167] Message(msg) -> Message(content=msg) --- metagpt/provider/openai_api.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/metagpt/provider/openai_api.py b/metagpt/provider/openai_api.py index dbafa31b7..f6661e79a 100644 --- a/metagpt/provider/openai_api.py +++ b/metagpt/provider/openai_api.py @@ -302,7 +302,7 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter): def _process_message(self, messages: Union[str, Message, list[dict], list[Message], list[str]]) -> list[dict]: """convert messages to list[dict].""" if isinstance(messages, list): - messages = [Message(msg) if isinstance(msg, str) else msg for msg in messages] + messages = [Message(content=msg) if isinstance(msg, str) else msg for msg in messages] return [msg if isinstance(msg, dict) else msg.to_dict() for msg in messages] if isinstance(messages, Message): From bdb427d5b785222701ef2e49c09bb0a2a1b40654 Mon Sep 17 00:00:00 2001 From: better629 Date: Thu, 21 Dec 2023 14:18:50 +0800 Subject: [PATCH 156/167] add gemini minimal python version warning --- metagpt/config.py | 5 +++++ metagpt/provider/google_gemini_api.py | 3 +-- metagpt/utils/common.py | 9 ++++++++- 3 files changed, 14 insertions(+), 3 deletions(-) diff --git a/metagpt/config.py b/metagpt/config.py index 27d4488e0..727b37b9c 100644 --- a/metagpt/config.py +++ b/metagpt/config.py @@ -7,6 +7,7 @@ Provide configuration, singleton 2. Add the parameter `src_workspace` for the old version project path. """ import os +import warnings from copy import deepcopy from enum import Enum from pathlib import Path @@ -17,6 +18,7 @@ import yaml from metagpt.const import DEFAULT_WORKSPACE_ROOT, METAGPT_ROOT, OPTIONS from metagpt.logs import logger from metagpt.tools import SearchEngineType, WebBrowserEngineType +from metagpt.utils.common import require_python_version from metagpt.utils.singleton import Singleton @@ -79,6 +81,9 @@ class Config(metaclass=Singleton): (self.gemini_api_key, LLMProviderEnum.GEMINI), # reuse logic. but not a key ]: if self._is_valid_llm_key(k): + logger.info(f"Use LLMProvider: {v.value}") + if v == LLMProviderEnum.GEMINI and not require_python_version(req_version=(3, 10)): + warnings.warn("Use Gemini requires Python >= 3.10") if self.openai_api_key and self.openai_api_model: logger.info(f"OpenAI API Model: {self.openai_api_model}") return v diff --git a/metagpt/provider/google_gemini_api.py b/metagpt/provider/google_gemini_api.py index 213b53263..10215e2d9 100644 --- a/metagpt/provider/google_gemini_api.py +++ b/metagpt/provider/google_gemini_api.py @@ -48,9 +48,8 @@ class GeminiGPTAPI(BaseGPTAPI): Refs to `https://ai.google.dev/tutorials/python_quickstart` """ - use_system_prompt: bool = False # google gemini has no system prompt when use api - def __init__(self): + self.use_system_prompt = False # google gemini has no system prompt when use api self.__init_gemini(CONFIG) self.model = "gemini-pro" # so far only one model self.llm = GeminiGenerativeModel(model_name=self.model) diff --git a/metagpt/utils/common.py b/metagpt/utils/common.py index e5d4573e8..eec4176df 100644 --- a/metagpt/utils/common.py +++ b/metagpt/utils/common.py @@ -19,6 +19,7 @@ import json import os import platform import re +import sys import traceback import typing from pathlib import Path @@ -47,6 +48,12 @@ def check_cmd_exists(command) -> int: return result +def require_python_version(req_version: tuple[int]) -> bool: + if not (2 <= len(req_version) <= 3): + raise ValueError("req_version should be (3, 9) or (3, 10, 13)") + return True if sys.version_info > req_version else False + + class OutputParser: @classmethod def parse_blocks(cls, text: str): @@ -219,7 +226,7 @@ class OutputParser: if start_index != -1 and end_index != -1: # Extract the structure part - structure_text = text[start_index : end_index + 1] + structure_text = text[start_index: end_index + 1] try: # Attempt to convert the text to a Python data type using ast.literal_eval From d46b7c4018b107d693937a0228ec43d761d66ae0 Mon Sep 17 00:00:00 2001 From: geekan Date: Thu, 21 Dec 2023 14:45:53 +0800 Subject: [PATCH 157/167] fix moderation, remove claude from LLM, refine exceptions handler --- examples/llm_hello_world.py | 9 +++------ metagpt/provider/openai_api.py | 29 +++-------------------------- metagpt/tools/moderation.py | 22 ++++++++++------------ metagpt/utils/exceptions.py | 6 ++++-- 4 files changed, 20 insertions(+), 46 deletions(-) diff --git a/examples/llm_hello_world.py b/examples/llm_hello_world.py index 677098399..76be1cc90 100644 --- a/examples/llm_hello_world.py +++ b/examples/llm_hello_world.py @@ -7,23 +7,20 @@ """ import asyncio -from metagpt.llm import LLM, Claude +from metagpt.llm import LLM from metagpt.logs import logger async def main(): llm = LLM() - claude = Claude() - logger.info(await claude.aask("你好,请进行自我介绍")) logger.info(await llm.aask("hello world")) logger.info(await llm.aask_batch(["hi", "write python hello world."])) hello_msg = [{"role": "user", "content": "count from 1 to 10. split by newline."}] logger.info(await llm.acompletion(hello_msg)) - logger.info(await llm.acompletion_batch([hello_msg])) - logger.info(await llm.acompletion_batch_text([hello_msg])) - logger.info(await llm.acompletion_text(hello_msg)) + + # streaming mode, much slower await llm.acompletion_text(hello_msg, stream=True) diff --git a/metagpt/provider/openai_api.py b/metagpt/provider/openai_api.py index dbafa31b7..b6c1fbe55 100644 --- a/metagpt/provider/openai_api.py +++ b/metagpt/provider/openai_api.py @@ -34,6 +34,7 @@ from metagpt.provider.base_gpt_api import BaseGPTAPI from metagpt.provider.constant import GENERAL_FUNCTION_SCHEMA, GENERAL_TOOL_CHOICE from metagpt.provider.llm_provider_registry import register_provider from metagpt.schema import Message +from metagpt.utils.exceptions import handle_exception from metagpt.utils.singleton import Singleton from metagpt.utils.token_counter import ( TOKEN_COSTS, @@ -420,30 +421,6 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter): return CONFIG.max_tokens_rsp return get_max_completion_tokens(messages, self.model, CONFIG.max_tokens_rsp) - def moderation(self, content: Union[str, list[str]]): - try: - if not content: - logger.error("content cannot be empty!") - else: - rsp = self._moderation(content=content) - return rsp - except Exception as e: - logger.error(f"moderating failed:{e}") - - def _moderation(self, content: Union[str, list[str]]): - rsp = self.client.moderations.create(input=content) - return rsp - + @handle_exception async def amoderation(self, content: Union[str, list[str]]): - try: - if not content: - logger.error("content cannot be empty!") - else: - rsp = await self._amoderation(content=content) - return rsp - except Exception as e: - logger.error(f"moderating failed:{e}") - - async def _amoderation(self, content: Union[str, list[str]]): - rsp = await self.async_client.moderations.create(input=content) - return rsp + return await self.async_client.moderations.create(input=content) diff --git a/metagpt/tools/moderation.py b/metagpt/tools/moderation.py index c56a6afc4..5532e4f66 100644 --- a/metagpt/tools/moderation.py +++ b/metagpt/tools/moderation.py @@ -5,6 +5,7 @@ @Author : zhanglei @File : moderation.py """ +import asyncio from typing import Union from metagpt.llm import LLM @@ -14,16 +15,6 @@ class Moderation: def __init__(self): self.llm = LLM() - def moderation(self, content: Union[str, list[str]]): - resp = [] - if content: - moderation_results = self.llm.moderation(content=content) - results = moderation_results.results - for item in results: - resp.append(item.flagged) - - return resp - async def amoderation(self, content: Union[str, list[str]]): resp = [] if content: @@ -35,6 +26,13 @@ class Moderation: return resp -if __name__ == "__main__": +async def main(): moderation = Moderation() - print(moderation.moderation(content=["I will kill you", "The weather is really nice today", "I want to hit you"])) + rsp = await moderation.amoderation( + content=["I will kill you", "The weather is really nice today", "I want to hit you"] + ) + print(rsp) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/metagpt/utils/exceptions.py b/metagpt/utils/exceptions.py index b4b5aa590..70ed45910 100644 --- a/metagpt/utils/exceptions.py +++ b/metagpt/utils/exceptions.py @@ -21,6 +21,7 @@ def handle_exception( _func: Callable[..., ReturnType] = None, *, exception_type: Union[Type[Exception], Tuple[Type[Exception], ...]] = Exception, + exception_msg: str = "", default_return: Any = None, ) -> Callable[..., ReturnType]: """handle exception, return default value""" @@ -32,8 +33,9 @@ def handle_exception( return await func(*args, **kwargs) except exception_type as e: logger.opt(depth=1).error( - f"Calling {func.__name__} with args: {args}, kwargs: {kwargs} failed: {e}, " - f"stack: {traceback.format_exc()}" + f"{e}: {exception_msg}, " + f"\nCalling {func.__name__} with args: {args}, kwargs: {kwargs} " + f"\nStack: {traceback.format_exc()}" ) return default_return From 18a195a3678dd5c23c9666a57742eeb5bdec943a Mon Sep 17 00:00:00 2001 From: better629 Date: Thu, 21 Dec 2023 14:46:33 +0800 Subject: [PATCH 158/167] update config --- metagpt/config.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/metagpt/config.py b/metagpt/config.py index be0d6ec41..963fe3b05 100644 --- a/metagpt/config.py +++ b/metagpt/config.py @@ -81,7 +81,7 @@ class Config(metaclass=Singleton): (self.gemini_api_key, LLMProviderEnum.GEMINI), # reuse logic. but not a key ]: if self._is_valid_llm_key(k): - logger.info(f"Use LLMProvider: {v.value}") + # logger.debug(f"Use LLMProvider: {v.value}") if v == LLMProviderEnum.GEMINI and not require_python_version(req_version=(3, 10)): warnings.warn("Use Gemini requires Python >= 3.10") if self.openai_api_key and self.openai_api_model: @@ -94,7 +94,6 @@ class Config(metaclass=Singleton): return k and k != "YOUR_API_KEY" def _update(self): - # logger.info("Config loading done.") self.global_proxy = self._get("GLOBAL_PROXY") self.openai_api_key = self._get("OPENAI_API_KEY") From 6af9fecf65cb80f35d8fb1d56d6a6a01fe3504a5 Mon Sep 17 00:00:00 2001 From: better629 Date: Thu, 21 Dec 2023 15:06:59 +0800 Subject: [PATCH 159/167] fix format --- metagpt/provider/__init__.py | 8 +--- metagpt/provider/google_gemini_api.py | 44 +++++++------------ metagpt/roles/researcher.py | 2 +- metagpt/tools/web_browser_engine_selenium.py | 4 +- metagpt/utils/common.py | 2 +- metagpt/utils/token_counter.py | 4 +- .../provider/test_google_gemini_api.py | 8 ++-- 7 files changed, 26 insertions(+), 46 deletions(-) diff --git a/metagpt/provider/__init__.py b/metagpt/provider/__init__.py index 028c6f837..a9f46eb03 100644 --- a/metagpt/provider/__init__.py +++ b/metagpt/provider/__init__.py @@ -12,10 +12,4 @@ from metagpt.provider.open_llm_api import OpenLLMGPTAPI from metagpt.provider.openai_api import OpenAIGPTAPI from metagpt.provider.zhipuai_api import ZhiPuAIGPTAPI -__all__ = [ - "FireWorksGPTAPI", - "GeminiGPTAPI", - "OpenLLMGPTAPI", - "OpenAIGPTAPI", - "ZhiPuAIGPTAPI" -] +__all__ = ["FireWorksGPTAPI", "GeminiGPTAPI", "OpenLLMGPTAPI", "OpenAIGPTAPI", "ZhiPuAIGPTAPI"] diff --git a/metagpt/provider/google_gemini_api.py b/metagpt/provider/google_gemini_api.py index 631da1052..682f7b507 100644 --- a/metagpt/provider/google_gemini_api.py +++ b/metagpt/provider/google_gemini_api.py @@ -6,8 +6,11 @@ import google.generativeai as genai from google.ai import generativelanguage as glm from google.generativeai.generative_models import GenerativeModel from google.generativeai.types import content_types -from google.generativeai.types.generation_types import GenerateContentResponse, AsyncGenerateContentResponse -from google.generativeai.types.generation_types import GenerationConfig +from google.generativeai.types.generation_types import ( + AsyncGenerateContentResponse, + GenerateContentResponse, + GenerationConfig, +) from tenacity import ( after_log, retry, @@ -29,15 +32,11 @@ class GeminiGenerativeModel(GenerativeModel): Will use default GenerativeModel if it fixed. """ - def count_tokens( - self, contents: content_types.ContentsType - ) -> glm.CountTokensResponse: + def count_tokens(self, contents: content_types.ContentsType) -> glm.CountTokensResponse: contents = content_types.to_contents(contents) return self._client.count_tokens(model=self.model_name, contents=contents) - async def count_tokens_async( - self, contents: content_types.ContentsType - ) -> glm.CountTokensResponse: + async def count_tokens_async(self, contents: content_types.ContentsType) -> glm.CountTokensResponse: contents = content_types.to_contents(contents) return await self._async_client.count_tokens(model=self.model_name, contents=contents) @@ -68,17 +67,11 @@ class GeminiGPTAPI(BaseGPTAPI): return {"role": "model", "parts": [msg]} def _const_kwargs(self, messages: list[dict], stream: bool = False) -> dict: - kwargs = { - "contents": messages, - "generation_config": GenerationConfig( - temperature=0.3 - ), - "stream": stream - } + kwargs = {"contents": messages, "generation_config": GenerationConfig(temperature=0.3), "stream": stream} return kwargs def _update_costs(self, usage: dict): - """ update each request's token cost """ + """update each request's token cost""" if CONFIG.calc_usage: try: prompt_tokens = int(usage.get("prompt_tokens", 0)) @@ -94,20 +87,14 @@ class GeminiGPTAPI(BaseGPTAPI): req_text = messages[-1]["parts"][0] if messages else "" prompt_resp = self.llm.count_tokens(contents={"role": "user", "parts": [{"text": req_text}]}) completion_resp = self.llm.count_tokens(contents={"role": "model", "parts": [{"text": resp_text}]}) - usage = { - "prompt_tokens": prompt_resp.total_tokens, - "completion_tokens": completion_resp.total_tokens - } + usage = {"prompt_tokens": prompt_resp.total_tokens, "completion_tokens": completion_resp.total_tokens} return usage async def aget_usage(self, messages: list[dict], resp_text: str) -> dict: req_text = messages[-1]["parts"][0] if messages else "" prompt_resp = await self.llm.count_tokens_async(contents={"role": "user", "parts": [{"text": req_text}]}) completion_resp = await self.llm.count_tokens_async(contents={"role": "model", "parts": [{"text": resp_text}]}) - usage = { - "prompt_tokens": prompt_resp.total_tokens, - "completion_tokens": completion_resp.total_tokens - } + usage = {"prompt_tokens": prompt_resp.total_tokens, "completion_tokens": completion_resp.total_tokens} return usage def completion(self, messages: list[dict]) -> "GenerateContentResponse": @@ -126,8 +113,9 @@ class GeminiGPTAPI(BaseGPTAPI): return await self._achat_completion(messages) async def _achat_completion_stream(self, messages: list[dict]) -> str: - resp: AsyncGenerateContentResponse = await self.llm.generate_content_async(**self._const_kwargs(messages, - stream=True)) + resp: AsyncGenerateContentResponse = await self.llm.generate_content_async( + **self._const_kwargs(messages, stream=True) + ) collected_content = [] async for chunk in resp: content = chunk.text @@ -144,10 +132,10 @@ class GeminiGPTAPI(BaseGPTAPI): wait=wait_random_exponential(min=1, max=60), after=after_log(logger, logger.level("WARNING").name), retry=retry_if_exception_type(ConnectionError), - retry_error_callback=log_and_reraise + retry_error_callback=log_and_reraise, ) async def acompletion_text(self, messages: list[dict], stream=False) -> str: - """ response in async with stream or non-stream mode """ + """response in async with stream or non-stream mode""" if stream: return await self._achat_completion_stream(messages) resp = await self._achat_completion(messages) diff --git a/metagpt/roles/researcher.py b/metagpt/roles/researcher.py index 52c55f0ca..e894d1a57 100644 --- a/metagpt/roles/researcher.py +++ b/metagpt/roles/researcher.py @@ -70,7 +70,7 @@ class Researcher(Role): return ret def research_system_text(self, topic, current_task: Action) -> str: - """ BACKWARD compatible + """BACKWARD compatible This allows sub-class able to define its own system prompt based on topic. return the previous implementation to have backward compatible Args: diff --git a/metagpt/tools/web_browser_engine_selenium.py b/metagpt/tools/web_browser_engine_selenium.py index 074943892..decab2b7d 100644 --- a/metagpt/tools/web_browser_engine_selenium.py +++ b/metagpt/tools/web_browser_engine_selenium.py @@ -106,8 +106,8 @@ def _gen_get_driver_func(browser_type, *args, executable_path=None): options.add_argument("--headless") options.add_argument("--enable-javascript") if browser_type == "chrome": - options.add_argument("--disable-gpu") # This flag can help avoid renderer issue - options.add_argument("--disable-dev-shm-usage") # Overcome limited resource problems + options.add_argument("--disable-gpu") # This flag can help avoid renderer issue + options.add_argument("--disable-dev-shm-usage") # Overcome limited resource problems options.add_argument("--no-sandbox") for i in args: options.add_argument(i) diff --git a/metagpt/utils/common.py b/metagpt/utils/common.py index eec4176df..8db7a80a1 100644 --- a/metagpt/utils/common.py +++ b/metagpt/utils/common.py @@ -226,7 +226,7 @@ class OutputParser: if start_index != -1 and end_index != -1: # Extract the structure part - structure_text = text[start_index: end_index + 1] + structure_text = text[start_index : end_index + 1] try: # Attempt to convert the text to a Python data type using ast.literal_eval diff --git a/metagpt/utils/token_counter.py b/metagpt/utils/token_counter.py index 512ff784c..c29fa7d43 100644 --- a/metagpt/utils/token_counter.py +++ b/metagpt/utils/token_counter.py @@ -26,7 +26,7 @@ TOKEN_COSTS = { "gpt-4-1106-preview": {"prompt": 0.01, "completion": 0.03}, "text-embedding-ada-002": {"prompt": 0.0004, "completion": 0.0}, "chatglm_turbo": {"prompt": 0.0, "completion": 0.00069}, # 32k version, prompt + completion tokens=0.005¥/k-tokens - "gemini-pro": {"prompt": 0.00025, "completion": 0.0005} + "gemini-pro": {"prompt": 0.00025, "completion": 0.0005}, } @@ -45,7 +45,7 @@ TOKEN_MAX = { "gpt-4-1106-preview": 128000, "text-embedding-ada-002": 8192, "chatglm_turbo": 32768, - "gemini-pro": 32768 + "gemini-pro": 32768, } diff --git a/tests/metagpt/provider/test_google_gemini_api.py b/tests/metagpt/provider/test_google_gemini_api.py index 32ed11ba5..229d9b9a7 100644 --- a/tests/metagpt/provider/test_google_gemini_api.py +++ b/tests/metagpt/provider/test_google_gemini_api.py @@ -2,16 +2,14 @@ # -*- coding: utf-8 -*- # @Desc : the unittest of google gemini api -import pytest from abc import ABC from dataclasses import dataclass +import pytest + from metagpt.provider.google_gemini_api import GeminiGPTAPI - -messages = [ - {"role": "user", "content": "who are you"} -] +messages = [{"role": "user", "content": "who are you"}] @dataclass From 5eac57a379a33bab569e9ff443656bda37f07d30 Mon Sep 17 00:00:00 2001 From: better629 Date: Thu, 21 Dec 2023 16:30:53 +0800 Subject: [PATCH 160/167] add issue and pr template --- .github/ISSUE_TEMPLATE/config.yaml | 5 ++++ .../ISSUE_TEMPLATE/request_new_features.md | 14 +++++++++ .github/ISSUE_TEMPLATE/show_me_the_bug.md | 29 +++++++++++++++++++ .github/PULL_REQUEST_TEMPLATE.md | 15 ++++++++++ 4 files changed, 63 insertions(+) create mode 100644 .github/ISSUE_TEMPLATE/config.yaml create mode 100644 .github/ISSUE_TEMPLATE/request_new_features.md create mode 100644 .github/ISSUE_TEMPLATE/show_me_the_bug.md create mode 100644 .github/PULL_REQUEST_TEMPLATE.md diff --git a/.github/ISSUE_TEMPLATE/config.yaml b/.github/ISSUE_TEMPLATE/config.yaml new file mode 100644 index 000000000..622f76f1a --- /dev/null +++ b/.github/ISSUE_TEMPLATE/config.yaml @@ -0,0 +1,5 @@ +blank_issues_enabled: false +contact_links: + - name: "📑 Read online docs" + url: https://docs.deepwisdom.ai/ + about: Find the tutorials, use cases and blogs from the doc site. \ No newline at end of file diff --git a/.github/ISSUE_TEMPLATE/request_new_features.md b/.github/ISSUE_TEMPLATE/request_new_features.md new file mode 100644 index 000000000..c725cf6d2 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/request_new_features.md @@ -0,0 +1,14 @@ +--- +name: "🤔 Request new features" +about: There are some ideas or demands want to discuss with the official and hope to be implemented in the future. +title: '' +labels: kind/features +assignees: '' +--- + +**Feature description** + + +**Your Feature** + + diff --git a/.github/ISSUE_TEMPLATE/show_me_the_bug.md b/.github/ISSUE_TEMPLATE/show_me_the_bug.md new file mode 100644 index 000000000..504a2bd12 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/show_me_the_bug.md @@ -0,0 +1,29 @@ +--- +name: "🪲 Show me the Bug" +about: Something happened when I use MetaGPT, I want to report it and hope to get help from the official and community. +title: '' +labels: kind/bug +assignees: '' +--- + +**Bug description** + + +**Bug solved method** + + + +**Environment information** + + +- LLM type and model name: +- System version: +- Python version: + + + +- packages version: +- installation method: + +**Screenshots or logs** + diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md new file mode 100644 index 000000000..1def6935c --- /dev/null +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -0,0 +1,15 @@ +**Features** + + + +- xx +- yy + +**Feature Docs** + + +**Influence** + + +**Result** + \ No newline at end of file From ae2985d7e63973fa9646c768bca1dc52cd52d895 Mon Sep 17 00:00:00 2001 From: better629 Date: Thu, 21 Dec 2023 16:34:02 +0800 Subject: [PATCH 161/167] update --- .github/PULL_REQUEST_TEMPLATE.md | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md index 1def6935c..f5b280994 100644 --- a/.github/PULL_REQUEST_TEMPLATE.md +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -1,3 +1,4 @@ + **Features** @@ -12,4 +13,7 @@ **Result** - \ No newline at end of file + + +**Other** + \ No newline at end of file From 561263183ab796c4854f20e5d6cf09d5af83671e Mon Sep 17 00:00:00 2001 From: geekan Date: Thu, 21 Dec 2023 16:47:16 +0800 Subject: [PATCH 162/167] remove oi install --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index e14c6bd3e..eaff5c4b2 100644 --- a/requirements.txt +++ b/requirements.txt @@ -40,7 +40,7 @@ typing_extensions==4.7.0 libcst==1.0.1 qdrant-client==1.4.0 pytest-mock==3.11.1 -open-interpreter==0.1.7; python_version>"3.9" +# open-interpreter==0.1.7; python_version>"3.9" ta==0.10.2 semantic-kernel==0.4.0.dev0 wrapt==1.15.0 From 64c5673d6a0a06d3b349e62644ef508ec870fa51 Mon Sep 17 00:00:00 2001 From: better629 Date: Thu, 21 Dec 2023 17:27:09 +0800 Subject: [PATCH 163/167] support Message() without content param --- examples/search_kb.py | 4 ++-- metagpt/roles/researcher.py | 15 ++++++++++++--- metagpt/schema.py | 3 ++- metagpt/subscription.py | 2 +- tests/metagpt/test_message.py | 2 +- tests/metagpt/test_subscription.py | 10 +++++----- tests/metagpt/utils/test_common.py | 2 +- 7 files changed, 24 insertions(+), 14 deletions(-) diff --git a/examples/search_kb.py b/examples/search_kb.py index 7a9911ca2..5d61bbe02 100644 --- a/examples/search_kb.py +++ b/examples/search_kb.py @@ -31,8 +31,8 @@ async def search(): role = Sales(profile="Sales", store=store) role._watch({Action}) queries = [ - Message("Which facial cleanser is good for oily skin?", cause_by=Action), - Message("Is L'Oreal good to use?", cause_by=Action), + Message(content="Which facial cleanser is good for oily skin?", cause_by=Action), + Message(content="Is L'Oreal good to use?", cause_by=Action), ] for query in queries: logger.info(f"User: {query}") diff --git a/metagpt/roles/researcher.py b/metagpt/roles/researcher.py index e894d1a57..162d72b9b 100644 --- a/metagpt/roles/researcher.py +++ b/metagpt/roles/researcher.py @@ -54,18 +54,27 @@ class Researcher(Role): research_system_text = self.research_system_text(topic, todo) if isinstance(todo, CollectLinks): links = await todo.run(topic, 4, 4) - ret = Message("", Report(topic=topic, links=links), role=self.profile, cause_by=todo) + ret = Message( + content="", instruct_content=Report(topic=topic, links=links), role=self.profile, cause_by=todo + ) elif isinstance(todo, WebBrowseAndSummarize): links = instruct_content.links todos = (todo.run(*url, query=query, system_text=research_system_text) for (query, url) in links.items()) summaries = await asyncio.gather(*todos) summaries = list((url, summary) for i in summaries for (url, summary) in i.items() if summary) - ret = Message("", Report(topic=topic, summaries=summaries), role=self.profile, cause_by=todo) + ret = Message( + content="", instruct_content=Report(topic=topic, summaries=summaries), role=self.profile, cause_by=todo + ) else: summaries = instruct_content.summaries summary_text = "\n---\n".join(f"url: {url}\nsummary: {summary}" for (url, summary) in summaries) content = await self._rc.todo.run(topic, summary_text, system_text=research_system_text) - ret = Message("", Report(topic=topic, content=content), role=self.profile, cause_by=self._rc.todo) + ret = Message( + content="", + instruct_content=Report(topic=topic, content=content), + role=self.profile, + cause_by=self._rc.todo, + ) self._rc.memory.add(ret) return ret diff --git a/metagpt/schema.py b/metagpt/schema.py index 4a9df7fe2..d3c836d8e 100644 --- a/metagpt/schema.py +++ b/metagpt/schema.py @@ -110,7 +110,7 @@ class Message(BaseModel): sent_from: str = "" send_to: Set = Field(default_factory={MESSAGE_ROUTE_TO_ALL}) - def __init__(self, **kwargs): + def __init__(self, content: str = "", **kwargs): ic = kwargs.get("instruct_content", None) if ic and not isinstance(ic, BaseModel) and "class" in ic: # compatible with custom-defined ActionOutput @@ -122,6 +122,7 @@ class Message(BaseModel): kwargs["instruct_content"] = ic_new kwargs["id"] = kwargs.get("id", uuid.uuid4().hex) + kwargs["content"] = kwargs.get("content", content) kwargs["cause_by"] = any_to_str( kwargs.get("cause_by", import_class("UserRequirement", "metagpt.actions.add_requirement")) ) diff --git a/metagpt/subscription.py b/metagpt/subscription.py index 0d2b30821..607cbdb8d 100644 --- a/metagpt/subscription.py +++ b/metagpt/subscription.py @@ -19,7 +19,7 @@ class SubscriptionRunner(BaseModel): >>> async def trigger(): ... while True: - ... yield Message("the latest news about OpenAI") + ... yield Message(content="the latest news about OpenAI") ... await asyncio.sleep(3600 * 24) >>> async def callback(msg: Message): diff --git a/tests/metagpt/test_message.py b/tests/metagpt/test_message.py index 04d85d9e4..8f267ba54 100644 --- a/tests/metagpt/test_message.py +++ b/tests/metagpt/test_message.py @@ -23,7 +23,7 @@ def test_all_messages(): UserMessage(test_content), SystemMessage(test_content), AIMessage(test_content), - Message(test_content, role="QA"), + Message(content=test_content, role="QA"), ] for msg in msgs: assert msg.content == test_content diff --git a/tests/metagpt/test_subscription.py b/tests/metagpt/test_subscription.py index 2e898424d..75e06411c 100644 --- a/tests/metagpt/test_subscription.py +++ b/tests/metagpt/test_subscription.py @@ -13,12 +13,12 @@ async def test_subscription_run(): async def trigger(): while True: - yield Message("the latest news about OpenAI") + yield Message(content="the latest news about OpenAI") await asyncio.sleep(3600 * 24) class MockRole(Role): async def run(self, message=None): - return Message("") + return Message(content="") async def callback(message): nonlocal callback_done @@ -61,11 +61,11 @@ async def test_subscription_run(): async def test_subscription_run_error(loguru_caplog): async def trigger1(): while True: - yield Message("the latest news about OpenAI") + yield Message(content="the latest news about OpenAI") await asyncio.sleep(3600 * 24) async def trigger2(): - yield Message("the latest news about OpenAI") + yield Message(content="the latest news about OpenAI") class MockRole1(Role): async def run(self, message=None): @@ -73,7 +73,7 @@ async def test_subscription_run_error(loguru_caplog): class MockRole2(Role): async def run(self, message=None): - return Message("") + return Message(content="") async def callback(msg: Message): print(msg) diff --git a/tests/metagpt/utils/test_common.py b/tests/metagpt/utils/test_common.py index 4bd38db63..0ab34437d 100644 --- a/tests/metagpt/utils/test_common.py +++ b/tests/metagpt/utils/test_common.py @@ -47,7 +47,7 @@ class TestGetProjectRoot: Input(x=RunCode, want="metagpt.actions.run_code.RunCode"), Input(x=RunCode(), want="metagpt.actions.run_code.RunCode"), Input(x=Message, want="metagpt.schema.Message"), - Input(x=Message(""), want="metagpt.schema.Message"), + Input(x=Message(content=""), want="metagpt.schema.Message"), Input(x="A", want="A"), ] for i in inputs: From 8d26af8466cd9a750d3dda0fa4084ea7fc574b7a Mon Sep 17 00:00:00 2001 From: better629 Date: Thu, 21 Dec 2023 18:05:34 +0800 Subject: [PATCH 164/167] fix bug of missing test_round --- metagpt/roles/qa_engineer.py | 1 + 1 file changed, 1 insertion(+) diff --git a/metagpt/roles/qa_engineer.py b/metagpt/roles/qa_engineer.py index 5e509300b..39246364e 100644 --- a/metagpt/roles/qa_engineer.py +++ b/metagpt/roles/qa_engineer.py @@ -39,6 +39,7 @@ class QaEngineer(Role): "The test code you write should conform to code standard like PEP8, be modular, " "easy to read and maintain" ) test_round_allowed: int = 5 + test_round: int = 0 def __init__(self, **kwargs): super().__init__(**kwargs) From 7e0a2fabc71ebaa5e6fced59b047bf496eb16c63 Mon Sep 17 00:00:00 2001 From: geekan Date: Thu, 21 Dec 2023 18:21:08 +0800 Subject: [PATCH 165/167] update readme, add link to 0.5 version --- README.md | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/README.md b/README.md index 7538824c5..dc6e3dd69 100644 --- a/README.md +++ b/README.md @@ -34,7 +34,7 @@ # MetaGPT: The Multi-Agent Framework

Software Company Multi-Role Schematic (Gradually Implementing)

## News -- Dec 15: v0.5.0 is released! We introduce **incremental development**, facilitating agents to build up larger projects on top of their previous efforts or exisiting human codebase. We also launch a whole collection of important features, including multilingual support (experimental), multiple programming languages support (experimental), incremental development (experimental), CLI support, pip support, enhanced code review, documentation mechanism, and optimized messaging mechanism! +- Dec 15: [v0.5.0](https://github.com/geekan/MetaGPT/releases/tag/v0.5.0) is released! We introduce **incremental development**, facilitating agents to build up larger projects on top of their previous efforts or exisiting human codebase. We also launch a whole collection of important features, including multilingual support (experimental), multiple programming languages support (experimental), incremental development (experimental), CLI support, pip support, enhanced code review, documentation mechanism, and optimized messaging mechanism! ## Install @@ -50,13 +50,17 @@ # conda activate metagpt # Step 2: Clone the repository to your local machine for latest version, and install it. git clone https://github.com/geekan/MetaGPT.git cd MetaGPT -pip3 install -e. # or pip3 install metagpt # for stable version +pip3 install -e . # or pip3 install metagpt # for stable version -# Step 3: run metagpt cli -# setup your OPENAI_API_KEY in key.yaml copy from config.yaml -metagpt "Write a cli snake game" +# Step 3: setup your OPENAI_API_KEY, or make sure it existed in the env +mkdir ~/.metagpt/key.yaml +cp config/config.yaml ~/.metagpt/key.yaml +vim ~/.metagpt/key.yaml -# Step 4 [Optional]: If you want to save the artifacts like diagrams such as quadrant chart, system designs, sequence flow in the workspace, you can execute the step before Step 3. By default, the framework is compatible, and the entire process can be run completely without executing this step. +# Step 4: run metagpt cli +metagpt "Create a 2048 game in python" + +# Step 5 [Optional]: If you want to save the artifacts like diagrams such as quadrant chart, system designs, sequence flow in the workspace, you can execute the step before Step 3. By default, the framework is compatible, and the entire process can be run completely without executing this step. # If executing, ensure that NPM is installed on your system. Then install mermaid-js. (If you don't have npm in your computer, please go to the Node.js official website to install Node.js https://nodejs.org/ and then you will have npm tool in your computer.) npm --version sudo npm install -g @mermaid-js/mermaid-cli From 73411d21ebfcf723f5f1158acddb142f98846ccb Mon Sep 17 00:00:00 2001 From: geekan Date: Thu, 21 Dec 2023 19:40:27 +0800 Subject: [PATCH 166/167] refine README --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index dc6e3dd69..19971acca 100644 --- a/README.md +++ b/README.md @@ -34,7 +34,7 @@ # MetaGPT: The Multi-Agent Framework

Software Company Multi-Role Schematic (Gradually Implementing)

## News -- Dec 15: [v0.5.0](https://github.com/geekan/MetaGPT/releases/tag/v0.5.0) is released! We introduce **incremental development**, facilitating agents to build up larger projects on top of their previous efforts or exisiting human codebase. We also launch a whole collection of important features, including multilingual support (experimental), multiple programming languages support (experimental), incremental development (experimental), CLI support, pip support, enhanced code review, documentation mechanism, and optimized messaging mechanism! +- Dec 15: [v0.5.0](https://github.com/geekan/MetaGPT/releases/tag/v0.5.0) is released! We introduce **incremental development**, facilitating agents to build up larger projects on top of their previous efforts or exisiting codebase. We also launch a whole collection of important features, including **multilingual support** (experimental), multiple **programming languages support** (experimental), **incremental development** (experimental), CLI support, pip support, enhanced code review, documentation mechanism, and optimized messaging mechanism! ## Install From 139c7c363f593b60c0f80cf43b78b81f8885a95b Mon Sep 17 00:00:00 2001 From: geekan Date: Thu, 21 Dec 2023 22:24:26 +0800 Subject: [PATCH 167/167] fix bugs --- README.md | 2 +- examples/search_with_specific_engine.py | 7 ++++--- metagpt/actions/search_and_summarize.py | 3 ++- metagpt/roles/role.py | 2 +- metagpt/roles/sales.py | 2 +- metagpt/roles/searcher.py | 2 +- 6 files changed, 10 insertions(+), 8 deletions(-) diff --git a/README.md b/README.md index 19971acca..a03c1eabf 100644 --- a/README.md +++ b/README.md @@ -53,7 +53,7 @@ # Step 2: Clone the repository to your local machine for latest version, and ins pip3 install -e . # or pip3 install metagpt # for stable version # Step 3: setup your OPENAI_API_KEY, or make sure it existed in the env -mkdir ~/.metagpt/key.yaml +mkdir ~/.metagpt cp config/config.yaml ~/.metagpt/key.yaml vim ~/.metagpt/key.yaml diff --git a/examples/search_with_specific_engine.py b/examples/search_with_specific_engine.py index 334a7821f..923f538ed 100644 --- a/examples/search_with_specific_engine.py +++ b/examples/search_with_specific_engine.py @@ -5,12 +5,13 @@ from metagpt.tools import SearchEngineType async def main(): + question = "What are the most interesting human facts?" # Serper API - # await Searcher(engine = SearchEngineType.SERPER_GOOGLE).run(["What are some good sun protection products?","What are some of the best beaches?"]) + # await Searcher(engine=SearchEngineType.SERPER_GOOGLE).run(question) # SerpAPI - # await Searcher(engine=SearchEngineType.SERPAPI_GOOGLE).run("What are the best ski brands for skiers?") + # await Searcher(engine=SearchEngineType.SERPAPI_GOOGLE).run(question) # Google API - await Searcher(engine=SearchEngineType.DIRECT_GOOGLE).run("What are the most interesting human facts?") + await Searcher(engine=SearchEngineType.DIRECT_GOOGLE).run(question) if __name__ == "__main__": diff --git a/metagpt/actions/search_and_summarize.py b/metagpt/actions/search_and_summarize.py index 6ab7becb6..bc1319291 100644 --- a/metagpt/actions/search_and_summarize.py +++ b/metagpt/actions/search_and_summarize.py @@ -16,6 +16,7 @@ from metagpt.llm import LLM from metagpt.logs import logger from metagpt.provider.base_gpt_api import BaseGPTAPI from metagpt.schema import Message +from metagpt.tools import SearchEngineType from metagpt.tools.search_engine import SearchEngine SEARCH_AND_SUMMARIZE_SYSTEM = """### Requirements @@ -109,7 +110,7 @@ class SearchAndSummarize(Action): content: Optional[str] = None llm: BaseGPTAPI = Field(default_factory=LLM) config: None = Field(default_factory=Config) - engine: Optional[str] = CONFIG.search_engine + engine: Optional[SearchEngineType] = CONFIG.search_engine search_func: Optional[str] = None search_engine: SearchEngine = None diff --git a/metagpt/roles/role.py b/metagpt/roles/role.py index 8c5743467..b9fde7d05 100644 --- a/metagpt/roles/role.py +++ b/metagpt/roles/role.py @@ -267,7 +267,7 @@ class Role(BaseModel): ## 默认初始化 i = action(name="", llm=self._llm) else: - if self._setting.is_human and not isinstance(action.llm, HumanProvider): + if self.is_human and not isinstance(action.llm, HumanProvider): logger.warning( f"is_human attribute does not take effect, " f"as Role's {str(action)} was initialized using LLM, " diff --git a/metagpt/roles/sales.py b/metagpt/roles/sales.py index ba0a6fc6b..76abf10f3 100644 --- a/metagpt/roles/sales.py +++ b/metagpt/roles/sales.py @@ -31,7 +31,7 @@ class Sales(Role): def _set_store(self, store): if store: - action = SearchAndSummarize("", engine=SearchEngineType.CUSTOM_ENGINE, search_func=store.asearch) + action = SearchAndSummarize(name="", engine=SearchEngineType.CUSTOM_ENGINE, search_func=store.asearch) else: action = SearchAndSummarize() self._init_actions([action]) diff --git a/metagpt/roles/searcher.py b/metagpt/roles/searcher.py index a2136064f..e4a672176 100644 --- a/metagpt/roles/searcher.py +++ b/metagpt/roles/searcher.py @@ -52,7 +52,7 @@ class Searcher(Role): def set_search_func(self, search_func): """Sets a custom search function for the searcher.""" - action = SearchAndSummarize("", engine=SearchEngineType.CUSTOM_ENGINE, search_func=search_func) + action = SearchAndSummarize(name="", engine=SearchEngineType.CUSTOM_ENGINE, search_func=search_func) self._init_actions([action]) async def _act_sp(self) -> Message: