diff --git a/metagpt/actions/__init__.py b/metagpt/actions/__init__.py index c56f25e31..b004bd58e 100644 --- a/metagpt/actions/__init__.py +++ b/metagpt/actions/__init__.py @@ -15,6 +15,7 @@ from metagpt.actions.design_api import WriteDesign from metagpt.actions.design_api_review import DesignReview from metagpt.actions.design_filenames import DesignFilenames from metagpt.actions.project_management import AssignTasks, WriteTasks +from metagpt.actions.research import CollectLinks, WebBrowseAndSummarize, ConductResearch from metagpt.actions.run_code import RunCode from metagpt.actions.search_and_summarize import SearchAndSummarize from metagpt.actions.write_code import WriteCode @@ -22,11 +23,11 @@ from metagpt.actions.write_code_review import WriteCodeReview from metagpt.actions.write_prd import WritePRD from metagpt.actions.write_prd_review import WritePRDReview from metagpt.actions.write_test import WriteTest -from metagpt.actions.research import CollectLinks, WebBrowseAndSummarize, ConductResearch class ActionType(Enum): """All types of Actions, used for indexing.""" + ADD_REQUIREMENT = BossRequirement WRITE_PRD = WritePRD WRITE_PRD_REVIEW = WritePRDReview @@ -44,3 +45,10 @@ class ActionType(Enum): COLLECT_LINKS = CollectLinks WEB_BROWSE_AND_SUMMARIZE = WebBrowseAndSummarize CONDUCT_RESEARCH = ConductResearch + + +__all__ = [ + "ActionType", + "Action", + "ActionOutput", +] diff --git a/metagpt/actions/run_code.py b/metagpt/actions/run_code.py index f14a6a8e7..f69d2cd1a 100644 --- a/metagpt/actions/run_code.py +++ b/metagpt/actions/run_code.py @@ -5,13 +5,13 @@ @Author : alexanderwu @File : run_code.py """ -import traceback import os import subprocess -from typing import List, Tuple +import traceback +from typing import Tuple -from metagpt.logs import logger from metagpt.actions.action import Action +from metagpt.logs import logger PROMPT_TEMPLATE = """ Role: You are a senior development and qa engineer, your role is summarize the code running result. @@ -55,6 +55,7 @@ standard output: {outs}; standard errors: {errs}; """ + class RunCode(Action): def __init__(self, name="RunCode", context=None, llm=None): super().__init__(name, context, llm) @@ -65,7 +66,7 @@ class RunCode(Action): # We will document_store the result in this dictionary namespace = {} exec(code, namespace) - return namespace.get('result', ""), "" + return namespace.get("result", ""), "" except Exception: # If there is an error in the code, return the error message return "", traceback.format_exc() @@ -81,10 +82,12 @@ class RunCode(Action): # Modify the PYTHONPATH environment variable additional_python_paths = [working_directory] + additional_python_paths additional_python_paths = ":".join(additional_python_paths) - env['PYTHONPATH'] = additional_python_paths + ':' + env.get('PYTHONPATH', '') + env["PYTHONPATH"] = additional_python_paths + ":" + env.get("PYTHONPATH", "") # Start the subprocess - process = subprocess.Popen(command, cwd=working_directory, stdout=subprocess.PIPE, stderr=subprocess.PIPE, env=env) + process = subprocess.Popen( + command, cwd=working_directory, stdout=subprocess.PIPE, stderr=subprocess.PIPE, env=env + ) try: # Wait for the process to complete, with a timeout @@ -93,7 +96,7 @@ class RunCode(Action): logger.info("The command did not complete within the given timeout.") process.kill() # Kill the process if it times out stdout, stderr = process.communicate() - return stdout.decode('utf-8'), stderr.decode('utf-8') + return stdout.decode("utf-8"), stderr.decode("utf-8") async def run( self, code, mode="script", code_file_name="", test_code="", test_file_name="", command=[], **kwargs @@ -108,11 +111,13 @@ class RunCode(Action): logger.info(f"{errs=}") context = CONTEXT.format( - code=code, code_file_name=code_file_name, - test_code=test_code, test_file_name=test_file_name, + code=code, + code_file_name=code_file_name, + test_code=test_code, + test_file_name=test_file_name, command=" ".join(command), - outs=outs[:500], # outs might be long but they are not important, truncate them to avoid token overflow - errs=errs[:10000] # truncate errors to avoid token overflow + outs=outs[:500], # outs might be long but they are not important, truncate them to avoid token overflow + errs=errs[:10000], # truncate errors to avoid token overflow ) prompt = PROMPT_TEMPLATE.format(context=context) diff --git a/metagpt/actions/write_test.py b/metagpt/actions/write_test.py index e1c1571c3..5e50fdb55 100644 --- a/metagpt/actions/write_test.py +++ b/metagpt/actions/write_test.py @@ -5,7 +5,6 @@ @Author : alexanderwu @File : write_test.py """ -from metagpt.logs import logger from metagpt.actions.action import Action from metagpt.utils.common import CodeParser @@ -29,6 +28,7 @@ you should correctly import the necessary classes based on these file locations! ## {test_file_name}: Write test code with triple quoto. Do your best to implement THIS ONLY ONE FILE. """ + class WriteTest(Action): def __init__(self, name="WriteTest", context=None, llm=None): super().__init__(name, context, llm) @@ -43,7 +43,7 @@ class WriteTest(Action): code_to_test=code_to_test, test_file_name=test_file_name, source_file_path=source_file_path, - workspace=workspace + workspace=workspace, ) code = await self.write_code(prompt) return code diff --git a/metagpt/document_store/__init__.py b/metagpt/document_store/__init__.py index 7d7c6e5e9..766e141a5 100644 --- a/metagpt/document_store/__init__.py +++ b/metagpt/document_store/__init__.py @@ -7,3 +7,5 @@ """ from metagpt.document_store.faiss_store import FaissStore + +__all__ = ["FaissStore"] diff --git a/metagpt/memory/__init__.py b/metagpt/memory/__init__.py index 2eff0d890..710930626 100644 --- a/metagpt/memory/__init__.py +++ b/metagpt/memory/__init__.py @@ -9,3 +9,8 @@ from metagpt.memory.memory import Memory from metagpt.memory.longterm_memory import LongTermMemory + +__all__ = [ + "Memory", + "LongTermMemory", +] diff --git a/metagpt/memory/longterm_memory.py b/metagpt/memory/longterm_memory.py index 154fcfbda..3c2963613 100644 --- a/metagpt/memory/longterm_memory.py +++ b/metagpt/memory/longterm_memory.py @@ -2,12 +2,10 @@ # -*- coding: utf-8 -*- # @Desc : the implement of Long-term memory -from typing import Iterable, Type - from metagpt.logs import logger -from metagpt.schema import Message from metagpt.memory import Memory from metagpt.memory.memory_storage import MemoryStorage +from metagpt.schema import Message class LongTermMemory(Memory): @@ -27,10 +25,11 @@ class LongTermMemory(Memory): messages = self.memory_storage.recover_memory(role_id) self.rc = rc if not self.memory_storage.is_initialized: - logger.warning(f'It may the first time to run Agent {role_id}, the long-term memory is empty') + logger.warning(f"It may the first time to run Agent {role_id}, the long-term memory is empty") else: - logger.warning(f'Agent {role_id} has existed memory storage with {len(messages)} messages ' - f'and has recovered them.') + logger.warning( + f"Agent {role_id} has existed memory storage with {len(messages)} messages " f"and has recovered them." + ) self.msg_from_recover = True self.add_batch(messages) self.msg_from_recover = False diff --git a/metagpt/provider/__init__.py b/metagpt/provider/__init__.py index 785dbdd66..56dc19b4b 100644 --- a/metagpt/provider/__init__.py +++ b/metagpt/provider/__init__.py @@ -7,3 +7,6 @@ """ from metagpt.provider.openai_api import OpenAIGPTAPI + + +__all__ = ["OpenAIGPTAPI"] diff --git a/metagpt/roles/__init__.py b/metagpt/roles/__init__.py index b1911df06..318a61090 100644 --- a/metagpt/roles/__init__.py +++ b/metagpt/roles/__init__.py @@ -8,10 +8,21 @@ from metagpt.roles.role import Role from metagpt.roles.architect import Architect -from metagpt.roles.product_manager import ProductManager from metagpt.roles.project_manager import ProjectManager from metagpt.roles.engineer import Engineer from metagpt.roles.qa_engineer import QaEngineer from metagpt.roles.seacher import Searcher from metagpt.roles.sales import Sales from metagpt.roles.customer_service import CustomerService + + +__all__ = [ + "Role", + "Architect", + "ProjectManager", + "Engineer", + "QaEngineer", + "Searcher", + "Sales", + "CustomerService", +] diff --git a/metagpt/roles/qa_engineer.py b/metagpt/roles/qa_engineer.py index 5e12a1abd..65bf2cc5b 100644 --- a/metagpt/roles/qa_engineer.py +++ b/metagpt/roles/qa_engineer.py @@ -6,40 +6,44 @@ @File : qa_engineer.py """ import os -import re from pathlib import Path -from typing import Type -from metagpt.actions import WriteTest, WriteCode, WriteDesign, RunCode, DebugError +from metagpt.actions import DebugError, RunCode, WriteCode, WriteDesign, WriteTest from metagpt.const import WORKSPACE_ROOT from metagpt.logs import logger from metagpt.roles import Role from metagpt.schema import Message -from metagpt.roles.engineer import Engineer from metagpt.utils.common import CodeParser, parse_recipient -from metagpt.utils.special_tokens import MSG_SEP, FILENAME_CODE_SEP +from metagpt.utils.special_tokens import FILENAME_CODE_SEP, MSG_SEP + class QaEngineer(Role): - def __init__(self, name="Edward", profile="QaEngineer", - goal="Write comprehensive and robust tests to ensure codes will work as expected without bugs", - constraints="The test code you write should conform to code standard like PEP8, be modular, easy to read and maintain", - test_round_allowed=5): + def __init__( + self, + name="Edward", + profile="QaEngineer", + goal="Write comprehensive and robust tests to ensure codes will work as expected without bugs", + constraints="The test code you write should conform to code standard like PEP8, be modular, easy to read and maintain", + test_round_allowed=5, + ): super().__init__(name, profile, goal, constraints) - self._init_actions([WriteTest]) # FIXME: a bit hack here, only init one action to circumvent _think() logic, will overwrite _think() in future updates + self._init_actions( + [WriteTest] + ) # FIXME: a bit hack here, only init one action to circumvent _think() logic, will overwrite _think() in future updates self._watch([WriteCode, WriteTest, RunCode, DebugError]) self.test_round = 0 self.test_round_allowed = test_round_allowed - + @classmethod def parse_workspace(cls, system_design_msg: Message) -> str: if not system_design_msg.instruct_content: return system_design_msg.instruct_content.dict().get("Python package name") return CodeParser.parse_str(block="Python package name", text=system_design_msg.content) - + def get_workspace(self, return_proj_dir=True) -> Path: msg = self._rc.memory.get_by_action(WriteDesign)[-1] if not msg: - return WORKSPACE_ROOT / 'src' + return WORKSPACE_ROOT / "src" workspace = self.parse_workspace(msg) # project directory: workspace/{package_name}, which contains package source code folder, tests folder, resources folder, etc. if return_proj_dir: @@ -48,49 +52,52 @@ class QaEngineer(Role): return WORKSPACE_ROOT / workspace / workspace def write_file(self, filename: str, code: str): - workspace = self.get_workspace() / 'tests' + workspace = self.get_workspace() / "tests" file = workspace / filename file.parent.mkdir(parents=True, exist_ok=True) file.write_text(code) async def _write_test(self, message: Message) -> None: - code_msgs = message.content.split(MSG_SEP) - result_msg_all = [] + # result_msg_all = [] for code_msg in code_msgs: - # write tests file_name, file_path = code_msg.split(FILENAME_CODE_SEP) code_to_test = open(file_path, "r").read() if "test" in file_name: - continue # Engineer might write some test files, skip testing a test file + continue # Engineer might write some test files, skip testing a test file test_file_name = "test_" + file_name test_file_path = self.get_workspace() / "tests" / test_file_name - logger.info(f'Writing {test_file_name}..') + logger.info(f"Writing {test_file_name}..") test_code = await WriteTest().run( code_to_test=code_to_test, test_file_name=test_file_name, # source_file_name=file_name, source_file_path=file_path, - workspace=self.get_workspace() + workspace=self.get_workspace(), ) self.write_file(test_file_name, test_code) # prepare context for run tests in next round - command = ['python', f'tests/{test_file_name}'] + command = ["python", f"tests/{test_file_name}"] file_info = { - "file_name": file_name, "file_path": str(file_path), - "test_file_name": test_file_name, "test_file_path": str(test_file_path), - "command": command + "file_name": file_name, + "file_path": str(file_path), + "test_file_name": test_file_name, + "test_file_path": str(test_file_path), + "command": command, } msg = Message( - content=str(file_info), role=self.profile, cause_by=WriteTest, - sent_from=self.profile, send_to=self.profile + content=str(file_info), + role=self.profile, + cause_by=WriteTest, + sent_from=self.profile, + send_to=self.profile, ) self._publish_message(msg) - - logger.info(f'Done {self.get_workspace()}/tests generating.') - + + logger.info(f"Done {self.get_workspace()}/tests generating.") + async def _run_code(self, msg): file_info = eval(msg.content) development_file_path = file_info["file_path"] @@ -110,17 +117,14 @@ class QaEngineer(Role): test_code=test_code, test_file_name=file_info["test_file_name"], command=file_info["command"], - working_directory=proj_dir, # workspace/package_name, will run tests/test_xxx.py here - additional_python_paths=[development_code_dir], # workspace/package_name/package_name, - # import statement inside package code needs this + working_directory=proj_dir, # workspace/package_name, will run tests/test_xxx.py here + additional_python_paths=[development_code_dir], # workspace/package_name/package_name, + # import statement inside package code needs this ) - recipient = parse_recipient(result_msg) # the recipient might be Engineer or myself + recipient = parse_recipient(result_msg) # the recipient might be Engineer or myself content = str(file_info) + FILENAME_CODE_SEP + result_msg - msg = Message( - content=content, role=self.profile, cause_by=RunCode, - sent_from=self.profile, send_to=recipient - ) + msg = Message(content=content, role=self.profile, cause_by=RunCode, sent_from=self.profile, send_to=recipient) self._publish_message(msg) async def _debug_error(self, msg): @@ -128,21 +132,27 @@ class QaEngineer(Role): file_name, code = await DebugError().run(context) if file_name: self.write_file(file_name, code) - recipient = msg.sent_from # send back to the one who ran the code for another run, might be one's self - msg = Message(content=file_info, role=self.profile, cause_by=DebugError, sent_from=self.profile, send_to=recipient) + recipient = msg.sent_from # send back to the one who ran the code for another run, might be one's self + msg = Message( + content=file_info, role=self.profile, cause_by=DebugError, sent_from=self.profile, send_to=recipient + ) self._publish_message(msg) - + async def _observe(self) -> int: await super()._observe() - self._rc.news = [msg for msg in self._rc.news \ - if msg.send_to == self.profile] # only relevant msgs count as observed news + self._rc.news = [ + msg for msg in self._rc.news if msg.send_to == self.profile + ] # only relevant msgs count as observed news return len(self._rc.news) async def _act(self) -> Message: if self.test_round > self.test_round_allowed: result_msg = Message( content=f"Exceeding {self.test_round_allowed} rounds of tests, skip (writing code counts as a round, too)", - role=self.profile, cause_by=WriteTest, sent_from=self.profile, send_to="" + role=self.profile, + cause_by=WriteTest, + sent_from=self.profile, + send_to="", ) return result_msg @@ -161,6 +171,9 @@ class QaEngineer(Role): self.test_round += 1 result_msg = Message( content=f"Round {self.test_round} of tests done", - role=self.profile, cause_by=WriteTest, sent_from=self.profile, send_to="" + role=self.profile, + cause_by=WriteTest, + sent_from=self.profile, + send_to="", ) return result_msg diff --git a/metagpt/tools/sd_engine.py b/metagpt/tools/sd_engine.py index e462f1bda..a63dbe5ac 100644 --- a/metagpt/tools/sd_engine.py +++ b/metagpt/tools/sd_engine.py @@ -2,29 +2,27 @@ # @Date : 2023/7/19 16:28 # @Author : stellahong (stellahong@fuzhi.ai) # @Desc : -import os import asyncio +import base64 +import io +import json +import os from os.path import join from typing import List -import json -import io -import base64 from aiohttp import ClientSession from PIL import Image, PngImagePlugin -from metagpt.logs import logger from metagpt.config import Config from metagpt.const import WORKSPACE_ROOT +from metagpt.logs import logger config = Config() payload = { "prompt": "", "negative_prompt": "(easynegative:0.8),black, dark,Low resolution", - "override_settings": { - "sd_model_checkpoint": "galaxytimemachinesGTM_photoV20" - }, + "override_settings": {"sd_model_checkpoint": "galaxytimemachinesGTM_photoV20"}, "seed": -1, "batch_size": 1, "n_iter": 1, @@ -36,21 +34,20 @@ payload = { "tiling": False, "do_not_save_samples": False, "do_not_save_grid": False, - 'enable_hr': False, - 'hr_scale': 2, - 'hr_upscaler': 'Latent', - 'hr_second_pass_steps': 0, - 'hr_resize_x': 0, - 'hr_resize_y': 0, - 'hr_upscale_to_x': 0, - 'hr_upscale_to_y': 0, - 'truncate_x': 0, - 'truncate_y': 0, - 'applied_old_hires_behavior_to': None, + "enable_hr": False, + "hr_scale": 2, + "hr_upscaler": "Latent", + "hr_second_pass_steps": 0, + "hr_resize_x": 0, + "hr_resize_y": 0, + "hr_upscale_to_x": 0, + "hr_upscale_to_y": 0, + "truncate_x": 0, + "truncate_y": 0, + "applied_old_hires_behavior_to": None, "eta": None, - "sampler_index": "DPM++ SDE Karras", - "alwayson_scripts": {} + "alwayson_scripts": {}, } default_negative_prompt = "(easynegative:0.8),black, dark,Low resolution" @@ -60,14 +57,20 @@ class SDEngine: def __init__(self): # Initialize the SDEngine with configuration self.config = Config() - self.sd_url = self.config.get('SD_URL') + self.sd_url = self.config.get("SD_URL") self.sd_t2i_url = f"{self.sd_url}{self.config.get('SD_T2I_API')}" # Define default payload settings for SD API self.payload = payload logger.info(self.sd_t2i_url) - - def construct_payload(self, prompt, negtive_prompt=default_negative_prompt, width=512, height=512, - sd_model="galaxytimemachinesGTM_photoV20"): + + def construct_payload( + self, + prompt, + negtive_prompt=default_negative_prompt, + width=512, + height=512, + sd_model="galaxytimemachinesGTM_photoV20", + ): # Configure the payload with provided inputs self.payload["prompt"] = prompt self.payload["negtive_prompt"] = negtive_prompt @@ -76,13 +79,13 @@ class SDEngine: self.payload["override_settings"]["sd_model_checkpoint"] = sd_model logger.info(f"call sd payload is {self.payload}") return self.payload - + def _save(self, imgs, save_name=""): - save_dir = WORKSPACE_ROOT / "resources"/"SD_Output" + save_dir = WORKSPACE_ROOT / "resources" / "SD_Output" if not os.path.exists(save_dir): os.makedirs(save_dir, exist_ok=True) batch_decode_base64_to_image(imgs, save_dir, save_name=save_name) - + async def run_t2i(self, prompts: List): # Asynchronously run the SD API for multiple prompts session = ClientSession() @@ -90,25 +93,26 @@ class SDEngine: results = await self.run(url=self.sd_t2i_url, payload=payload, session=session) self._save(results, save_name=f"output_{payload_idx}") await session.close() - + async def run(self, url, payload, session): # Perform the HTTP POST request to the SD API async with session.post(url, json=payload, timeout=600) as rsp: data = await rsp.read() - + rsp_json = json.loads(data) - imgs = rsp_json['images'] + imgs = rsp_json["images"] logger.info(f"callback rsp json is {rsp_json.keys()}") return imgs - + async def run_i2i(self): # todo: 添加图生图接口调用 raise NotImplementedError - + async def run_sam(self): # todo:添加SAM接口调用 raise NotImplementedError + def decode_base64_to_image(img, save_name): image = Image.open(io.BytesIO(base64.b64decode(img.split(",", 1)[0]))) pnginfo = PngImagePlugin.PngInfo() @@ -124,12 +128,10 @@ def batch_decode_base64_to_image(imgs, save_dir="", save_name=""): if __name__ == "__main__": - import asyncio - engine = SDEngine() prompt = "pixel style, game design, a game interface should be minimalistic and intuitive with the score and high score displayed at the top. The snake and its food should be easily distinguishable. The game should have a simple color scheme, with a contrasting color for the snake and its food. Complete interface boundary" - + engine.construct_payload(prompt) - + event_loop = asyncio.get_event_loop() event_loop.run_until_complete(engine.run_t2i(prompt)) diff --git a/metagpt/utils/__init__.py b/metagpt/utils/__init__.py index 579308a3b..f13175cf8 100644 --- a/metagpt/utils/__init__.py +++ b/metagpt/utils/__init__.py @@ -13,3 +13,12 @@ from metagpt.utils.token_counter import ( count_message_tokens, count_string_tokens, ) + + +__all__ = [ + "read_docx", + "Singleton", + "TOKEN_COSTS", + "count_message_tokens", + "count_string_tokens", +] diff --git a/metagpt/utils/mermaid.py b/metagpt/utils/mermaid.py index 3788b4743..24aabe8ae 100644 --- a/metagpt/utils/mermaid.py +++ b/metagpt/utils/mermaid.py @@ -5,9 +5,9 @@ @Author : alexanderwu @File : mermaid.py """ -import os import subprocess from pathlib import Path + from metagpt.config import CONFIG from metagpt.const import PROJECT_ROOT from metagpt.logs import logger @@ -24,25 +24,36 @@ def mermaid_to_file(mermaid_code, output_file_without_suffix, width=2048, height :return: 0 if succed, -1 if failed """ # Write the Mermaid code to a temporary file - tmp = Path(f'{output_file_without_suffix}.mmd') - tmp.write_text(mermaid_code, encoding='utf-8') + tmp = Path(f"{output_file_without_suffix}.mmd") + tmp.write_text(mermaid_code, encoding="utf-8") - if check_cmd_exists('mmdc') != 0: - logger.warning( - "RUN `npm install -g @mermaid-js/mermaid-cli` to install mmdc") + if check_cmd_exists("mmdc") != 0: + logger.warning("RUN `npm install -g @mermaid-js/mermaid-cli` to install mmdc") return -1 - for suffix in ['pdf', 'svg', 'png']: - output_file = f'{output_file_without_suffix}.{suffix}' + for suffix in ["pdf", "svg", "png"]: + output_file = f"{output_file_without_suffix}.{suffix}" # Call the `mmdc` command to convert the Mermaid code to a PNG logger.info(f"Generating {output_file}..") if CONFIG.puppeteer_config: - subprocess.run([CONFIG.mmdc, '-p', CONFIG.puppeteer_config, '-i', str(tmp), '-o', - output_file, '-w', str(width), '-H', str(height)]) + subprocess.run( + [ + CONFIG.mmdc, + "-p", + CONFIG.puppeteer_config, + "-i", + str(tmp), + "-o", + output_file, + "-w", + str(width), + "-H", + str(height), + ] + ) else: - subprocess.run([CONFIG.mmdc, '-i', str(tmp), '-o', - output_file, '-w', str(width), '-H', str(height)]) + subprocess.run([CONFIG.mmdc, "-i", str(tmp), "-o", output_file, "-w", str(width), "-H", str(height)]) return 0 @@ -97,7 +108,7 @@ MMC2 = """sequenceDiagram SE-->>M: return summary""" -if __name__ == '__main__': +if __name__ == "__main__": # logger.info(print_members(print_members)) - mermaid_to_file(MMC1, PROJECT_ROOT / 'tmp/1.png') - mermaid_to_file(MMC2, PROJECT_ROOT / 'tmp/2.png') + mermaid_to_file(MMC1, PROJECT_ROOT / "tmp/1.png") + mermaid_to_file(MMC2, PROJECT_ROOT / "tmp/2.png") diff --git a/metagpt/utils/serialize.py b/metagpt/utils/serialize.py index 34dee7098..ffafca8cd 100644 --- a/metagpt/utils/serialize.py +++ b/metagpt/utils/serialize.py @@ -3,14 +3,11 @@ # @Desc : the implement of serialization and deserialization import copy -from typing import Tuple, List, Type, Union, Dict import pickle -from collections import defaultdict -from pydantic import create_model +from typing import Dict, List, Tuple -from metagpt.schema import Message -from metagpt.actions.action import Action from metagpt.actions.action_output import ActionOutput +from metagpt.schema import Message def actionoutout_schema_to_mapping(schema: Dict) -> Dict: @@ -34,12 +31,12 @@ def actionoutout_schema_to_mapping(schema: Dict) -> Dict: ``` """ mapping = dict() - for field, property in schema['properties'].items(): - if property['type'] == 'string': + for field, property in schema["properties"].items(): + if property["type"] == "string": mapping[field] = (str, ...) - elif property['type'] == 'array' and property['items']['type'] == 'string': + elif property["type"] == "array" and property["items"]["type"] == "string": mapping[field] = (List[str], ...) - elif property['type'] == 'array' and property['items']['type'] == 'array': + elif property["type"] == "array" and property["items"]["type"] == "array": # here only consider the `Tuple[str, str]` situation mapping[field] = (List[Tuple[str, str]], ...) return mapping @@ -53,11 +50,7 @@ def serialize_message(message: Message): schema = ic.schema() mapping = actionoutout_schema_to_mapping(schema) - message_cp.instruct_content = { - 'class': schema['title'], - 'mapping': mapping, - 'value': ic.dict() - } + message_cp.instruct_content = {"class": schema["title"], "mapping": mapping, "value": ic.dict()} msg_ser = pickle.dumps(message_cp) return msg_ser @@ -67,9 +60,8 @@ def deserialize_message(message_ser: str) -> Message: message = pickle.loads(message_ser) if message.instruct_content: ic = message.instruct_content - ic_obj = ActionOutput.create_model_class(class_name=ic['class'], - mapping=ic['mapping']) - ic_new = ic_obj(**ic['value']) + ic_obj = ActionOutput.create_model_class(class_name=ic["class"], mapping=ic["mapping"]) + ic_new = ic_obj(**ic["value"]) message.instruct_content = ic_new return message diff --git a/pyproject.toml b/pyproject.toml index 72bc26543..ed7c2769e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,7 +24,7 @@ target-version = ['py39'] [tool.ruff] # Enable pycodestyle (`E`) and Pyflakes (`F`) codes by default. select = ["E", "F"] -ignore = ["E501", "E712", "E722", "F821"] +ignore = ["E501", "E712", "E722", "F821", "E731"] # Allow autofix for all enabled rules (when `--fix`) is provided. fixable = ["A", "B", "C", "D", "E", "F", "G", "I", "N", "Q", "S", "T", "W", "ANN", "ARG", "BLE", "COM", "DJ", "DTZ", "EM", "ERA", "EXE", "FBT", "ICN", "INP", "ISC", "NPY", "PD", "PGH", "PIE", "PL", "PT", "PTH", "PYI", "RET", "RSE", "RUF", "SIM", "SLF", "TCH", "TID", "TRY", "UP", "YTT"] diff --git a/tests/metagpt/actions/test_run_code.py b/tests/metagpt/actions/test_run_code.py index 489da28c6..1e451cb14 100644 --- a/tests/metagpt/actions/test_run_code.py +++ b/tests/metagpt/actions/test_run_code.py @@ -6,24 +6,23 @@ @File : test_run_code.py """ import pytest -import asyncio + from metagpt.actions.run_code import RunCode + @pytest.mark.asyncio async def test_run_text(): - action = RunCode() - result, errs = await RunCode.run_text('result = 1 + 1') + result, errs = await RunCode.run_text("result = 1 + 1") assert result == 2 assert errs == "" - result, errs = await RunCode.run_text('result = 1 / 0') + result, errs = await RunCode.run_text("result = 1 / 0") assert result == "" assert "ZeroDivisionError" in errs + @pytest.mark.asyncio async def test_run_script(): - action = RunCode() - # Successful command out, err = await RunCode.run_script(".", command=["echo", "Hello World"]) assert out.strip() == "Hello World" @@ -33,6 +32,7 @@ async def test_run_script(): out, err = await RunCode.run_script(".", command=["python", "-c", "print(1/0)"]) assert "ZeroDivisionError" in err + @pytest.mark.asyncio async def test_run(): action = RunCode() @@ -47,10 +47,11 @@ async def test_run(): test_file_name="", command=["echo", "Hello World"], working_directory=".", - additional_python_paths=[] + additional_python_paths=[], ) assert "PASS" in result + @pytest.mark.asyncio async def test_run_failure(): action = RunCode() @@ -65,6 +66,6 @@ async def test_run_failure(): test_file_name="", command=["python", "-c", "print(1/0)"], working_directory=".", - additional_python_paths=[] + additional_python_paths=[], ) - assert "FAIL" in result \ No newline at end of file + assert "FAIL" in result diff --git a/tests/metagpt/actions/test_write_code_review.py b/tests/metagpt/actions/test_write_code_review.py index cee7eb941..21bc563ec 100644 --- a/tests/metagpt/actions/test_write_code_review.py +++ b/tests/metagpt/actions/test_write_code_review.py @@ -8,8 +8,6 @@ import pytest from metagpt.actions.write_code_review import WriteCodeReview -from metagpt.logs import logger -from tests.metagpt.actions.mock import SEARCH_CODE_SAMPLE @pytest.mark.asyncio @@ -20,11 +18,7 @@ def add(a, b): """ # write_code_review = WriteCodeReview("write_code_review") - code = await WriteCodeReview().run( - context="编写一个从a加b的函数,返回a+b", - code=code, - filename="math.py" - ) + code = await WriteCodeReview().run(context="编写一个从a加b的函数,返回a+b", code=code, filename="math.py") # 我们不能精确地预测生成的代码评审,但我们可以检查返回的是否为字符串 assert isinstance(code, str) @@ -33,6 +27,7 @@ def add(a, b): captured = capfd.readouterr() print(f"输出内容: {captured.out}") + # @pytest.mark.asyncio # async def test_write_code_review_directly(): # code = SEARCH_CODE_SAMPLE diff --git a/tests/metagpt/roles/ui_role.py b/tests/metagpt/roles/ui_role.py index 101be9c69..a45a89cde 100644 --- a/tests/metagpt/roles/ui_role.py +++ b/tests/metagpt/roles/ui_role.py @@ -2,22 +2,19 @@ # @Date : 2023/7/15 16:40 # @Author : stellahong (stellahong@fuzhi.ai) # @Desc : -import re import os -from importlib import import_module +import re from functools import wraps +from importlib import import_module -from metagpt.logs import logger -from metagpt.actions import Action, ActionOutput -from metagpt.roles import ProductManager, Role -from metagpt.schema import Message +from metagpt.actions import Action, ActionOutput, WritePRD from metagpt.const import WORKSPACE_ROOT - -from metagpt.actions import WritePRD -from metagpt.software_company import SoftwareCompany +from metagpt.logs import logger +from metagpt.roles import Role +from metagpt.schema import Message from metagpt.tools.sd_engine import SDEngine -PROMPT_TEMPLATE = ''' +PROMPT_TEMPLATE = """ # Context {context} @@ -34,9 +31,9 @@ Attention: Use '##' to split sections, not '#', and '## ' SHOULD W ## CSS Styles (styles.css):Provide as Plain text,use standard css code ## Anything UNCLEAR:Provide as Plain text. Make clear here. -''' +""" -FORMAT_EXAMPLE = ''' +FORMAT_EXAMPLE = """ ## UI Design Description ```Snake games are classic and addictive games with simple yet engaging elements. Here are the main elements commonly found in snake games ``` @@ -126,7 +123,7 @@ body { ## Anything UNCLEAR There are no unclear points. -''' +""" OUTPUT_MAPPING = { "UI Design Description": (str, ...), @@ -139,25 +136,25 @@ OUTPUT_MAPPING = { def load_engine(func): """Decorator to load an engine by file name and engine name.""" - + @wraps(func) def wrapper(*args, **kwargs): file_name, engine_name = func(*args, **kwargs) - engine_file = import_module(file_name, package='metagpt') + engine_file = import_module(file_name, package="metagpt") ip_module_cls = getattr(engine_file, engine_name) try: engine = ip_module_cls() except: engine = None - + return engine - + return wrapper def parse(func): """Decorator to parse information using regex pattern.""" - + @wraps(func) def wrapper(*args, **kwargs): context, pattern = func(*args, **kwargs) @@ -168,30 +165,30 @@ def parse(func): else: text_info = context logger.info("未找到匹配的内容") - + return text_info - + return wrapper class UIDesign(Action): """Class representing the UI Design action.""" - + def __init__(self, name, context=None, llm=None): super().__init__(name, context, llm) # 需要调用LLM进一步丰富UI设计的prompt - + @parse def parse_requirement(self, context: str): """Parse UI Design draft from the context using regex.""" pattern = r"## UI Design draft.*?\n(.*?)## Anything UNCLEAR" return context, pattern - + @parse def parse_ui_elements(self, context: str): """Parse Selected Elements from the context using regex.""" pattern = r"## Selected Elements.*?\n(.*?)## HTML Layout" return context, pattern - + @parse def parse_css_code(self, context: str): pattern = r"```css.*?\n(.*?)## Anything UNCLEAR" @@ -201,7 +198,7 @@ class UIDesign(Action): def parse_html_code(self, context: str): pattern = r"```html.*?\n(.*?)```" return context, pattern - + async def draw_icons(self, context, *args, **kwargs): """Draw icons using SDEngine.""" engine = SDEngine() @@ -215,20 +212,20 @@ class UIDesign(Action): prompts_batch.append(prompt) await engine.run_t2i(prompts_batch) logger.info("Finish icon design using StableDiffusion API") - + async def _save(self, css_content, html_content): - save_dir = WORKSPACE_ROOT / "resources" / 'codes' + save_dir = WORKSPACE_ROOT / "resources" / "codes" if not os.path.exists(save_dir): os.makedirs(save_dir, exist_ok=True) # Save CSS and HTML content to files - css_file_path = save_dir / f"ui_design.css" - html_file_path = save_dir / f"ui_design.html" - - with open(css_file_path, 'w') as css_file: + css_file_path = save_dir / "ui_design.css" + html_file_path = save_dir / "ui_design.html" + + with open(css_file_path, "w") as css_file: css_file.write(css_content) - with open(html_file_path, 'w') as html_file: + with open(html_file_path, "w") as html_file: html_file.write(html_content) - + async def run(self, requirements: list[Message], *args, **kwargs) -> ActionOutput: """Run the UI Design action.""" # fixme: update prompt (根据需求细化prompt) @@ -249,23 +246,27 @@ class UIDesign(Action): class UI(Role): """Class representing the UI Role.""" - - def __init__(self, name="Catherine", profile="UI Design", - goal="Finish a workable and good User Interface design based on a product design", - constraints="Give clear layout description and use standard icons to finish the design", - skills=["SD"]): + + def __init__( + self, + name="Catherine", + profile="UI Design", + goal="Finish a workable and good User Interface design based on a product design", + constraints="Give clear layout description and use standard icons to finish the design", + skills=["SD"], + ): super().__init__(name, profile, goal, constraints) self.load_skills(skills) self._init_actions([UIDesign]) self._watch([WritePRD]) - + @load_engine def load_sd_engine(self): """Load the SDEngine.""" file_name = ".tools.sd_engine" engine_name = "SDEngine" return file_name, engine_name - + def load_skills(self, skills): """Load skills for the UI Role.""" # todo: 添加其他出图engine @@ -273,4 +274,3 @@ class UI(Role): if skill == "SD": self.sd_engine = self.load_sd_engine() logger.info(f"load skill engine {self.sd_engine}") - diff --git a/tests/metagpt/tools/test_web_browser_engine.py b/tests/metagpt/tools/test_web_browser_engine.py index 57335de9c..b08d0ca10 100644 --- a/tests/metagpt/tools/test_web_browser_engine.py +++ b/tests/metagpt/tools/test_web_browser_engine.py @@ -1,6 +1,6 @@ import pytest -from metagpt.config import Config -from metagpt.tools import web_browser_engine, WebBrowserEngineType + +from metagpt.tools import WebBrowserEngineType, web_browser_engine @pytest.mark.asyncio diff --git a/tests/metagpt/utils/test_serialize.py b/tests/metagpt/utils/test_serialize.py index de8ccba4c..69f317f79 100644 --- a/tests/metagpt/utils/test_serialize.py +++ b/tests/metagpt/utils/test_serialize.py @@ -3,94 +3,64 @@ # @Desc : the unittest of serialize from typing import List, Tuple -import pytest -from pydantic import create_model - -from metagpt.actions.action_output import ActionOutput from metagpt.actions import WritePRD +from metagpt.actions.action_output import ActionOutput from metagpt.schema import Message -from metagpt.utils.serialize import actionoutout_schema_to_mapping, serialize_message, deserialize_message +from metagpt.utils.serialize import ( + actionoutout_schema_to_mapping, + deserialize_message, + serialize_message, +) def test_actionoutout_schema_to_mapping(): - schema = { - 'title': 'test', - 'type': 'object', - 'properties': { - 'field': { - 'title': 'field', - 'type': 'string' - } - } - } + schema = {"title": "test", "type": "object", "properties": {"field": {"title": "field", "type": "string"}}} mapping = actionoutout_schema_to_mapping(schema) - assert mapping['field'] == (str, ...) + assert mapping["field"] == (str, ...) schema = { - 'title': 'test', - 'type': 'object', - 'properties': { - 'field': { - 'title': 'field', - 'type': 'array', - 'items': { - 'type': 'string' - } - } - } + "title": "test", + "type": "object", + "properties": {"field": {"title": "field", "type": "array", "items": {"type": "string"}}}, } mapping = actionoutout_schema_to_mapping(schema) - assert mapping['field'] == (List[str], ...) + assert mapping["field"] == (List[str], ...) schema = { - 'title': 'test', - 'type': 'object', - 'properties': { - 'field': { - 'title': 'field', - 'type': 'array', - 'items': { - 'type': 'array', - 'minItems': 2, - 'maxItems': 2, - 'items': [ - { - 'type': 'string' - }, - { - 'type': 'string' - } - ] - } + "title": "test", + "type": "object", + "properties": { + "field": { + "title": "field", + "type": "array", + "items": { + "type": "array", + "minItems": 2, + "maxItems": 2, + "items": [{"type": "string"}, {"type": "string"}], + }, } - } + }, } mapping = actionoutout_schema_to_mapping(schema) - assert mapping['field'] == (List[Tuple[str, str]], ...) + assert mapping["field"] == (List[Tuple[str, str]], ...) assert True, True def test_serialize_and_deserialize_message(): - out_mapping = { - 'field1': (str, ...), - 'field2': (List[str], ...) - } - out_data = { - 'field1': 'field1 value', - 'field2': ['field2 value1', 'field2 value2'] - } - ic_obj = ActionOutput.create_model_class('prd', out_mapping) + out_mapping = {"field1": (str, ...), "field2": (List[str], ...)} + out_data = {"field1": "field1 value", "field2": ["field2 value1", "field2 value2"]} + ic_obj = ActionOutput.create_model_class("prd", out_mapping) - message = Message(content='prd demand', - instruct_content=ic_obj(**out_data), - role='user', - cause_by=WritePRD) # WritePRD as test action + message = Message( + content="prd demand", instruct_content=ic_obj(**out_data), role="user", cause_by=WritePRD + ) # WritePRD as test action message_ser = serialize_message(message) new_message = deserialize_message(message_ser) assert new_message.content == message.content assert new_message.cause_by == message.cause_by - assert new_message.instruct_content.field1 == out_data['field1'] + assert new_message.instruct_content.field1 == out_data["field1"]